Merge branch 'main' into main

Author: Chris Mahoney
Date: 2024-11-21 20:39:08 -06:00 (committed via GitHub)

35 changed files with 755 additions and 200 deletions

@@ -23,6 +23,8 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re
       return env.GOOGLE_GENERATIVE_AI_API_KEY || cloudflareEnv.GOOGLE_GENERATIVE_AI_API_KEY;
     case 'Groq':
       return env.GROQ_API_KEY || cloudflareEnv.GROQ_API_KEY;
+    case 'HuggingFace':
+      return env.HuggingFace_API_KEY || cloudflareEnv.HuggingFace_API_KEY;
    case 'OpenRouter':
       return env.OPEN_ROUTER_API_KEY || cloudflareEnv.OPEN_ROUTER_API_KEY;
     case 'Deepseek':
@@ -33,6 +35,8 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re
       return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY;
     case "xAI":
       return env.XAI_API_KEY || cloudflareEnv.XAI_API_KEY;
+    case "Cohere":
+      return env.COHERE_API_KEY;
     default:
       return "";
   }
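
Note on the lookup order: each case prefers the server process env over the Cloudflare binding, and the (truncated) userApiKeys parameter suggests user-entered keys can shadow both. A hedged usage sketch; the override behavior is assumed from the signature, not shown in these hunks, and the key value is a placeholder:

// Hedged sketch: the userApiKeys override is assumed from the function
// signature; the key value below is a placeholder, not a real credential.
const key = getAPIKey(cloudflareEnv, 'HuggingFace', {
  HuggingFace: 'hf_user_entered_key', // hypothetical user-supplied key
});
// Without a user key, resolution presumably falls back to
// env.HuggingFace_API_KEY, then cloudflareEnv.HuggingFace_API_KEY.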

@@ -7,6 +7,7 @@ import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { ollama } from 'ollama-ai-provider';
 import { createOpenRouter } from "@openrouter/ai-sdk-provider";
 import { createMistral } from '@ai-sdk/mistral';
+import { createCohere } from '@ai-sdk/cohere'

 export const DEFAULT_NUM_CTX = process.env.DEFAULT_NUM_CTX ?
   parseInt(process.env.DEFAULT_NUM_CTX, 10) :
@@ -27,6 +28,15 @@ export function getOpenAILikeModel(baseURL:string,apiKey: string, model: string)
   return openai(model);
 }

+export function getCohereAIModel(apiKey:string, model: string){
+  const cohere = createCohere({
+    apiKey,
+  });
+
+  return cohere(model);
+}
+
 export function getOpenAIModel(apiKey: string, model: string) {
   const openai = createOpenAI({
     apiKey,
@@ -60,6 +70,15 @@ export function getGroqModel(apiKey: string, model: string) {
   return openai(model);
 }

+export function getHuggingFaceModel(apiKey: string, model: string) {
+  const openai = createOpenAI({
+    baseURL: 'https://api-inference.huggingface.co/v1/',
+    apiKey,
+  });
+
+  return openai(model);
+}
+
 export function getOllamaModel(baseURL: string, model: string) {
   let Ollama = ollama(model, {
     numCtx: DEFAULT_NUM_CTX,
@@ -103,6 +122,8 @@ export function getXAIModel(apiKey: string, model: string) {
   return openai(model);
 }

 export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
   const apiKey = getAPIKey(env, provider, apiKeys);
   const baseURL = getBaseURL(env, provider);
@@ -114,6 +135,8 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Re
       return getOpenAIModel(apiKey, model);
     case 'Groq':
       return getGroqModel(apiKey, model);
+    case 'HuggingFace':
+      return getHuggingFaceModel(apiKey, model);
     case 'OpenRouter':
       return getOpenRouterModel(apiKey, model);
     case 'Google':
@@ -128,6 +151,8 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Re
       return getLMStudioModel(baseURL, model);
     case 'xAI':
       return getXAIModel(apiKey, model);
+    case 'Cohere':
+      return getCohereAIModel(apiKey, model);
     default:
       return getOllamaModel(baseURL, model);
   }
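
The two new factories follow this file's existing pattern: Cohere gets a dedicated client from @ai-sdk/cohere, while HuggingFace reuses the OpenAI-compatible client pointed at the HF inference endpoint. A minimal sketch of dispatching to them through getModel; the model identifiers are illustrative, not taken from this diff:

// Hedged sketch: model names are examples only, and env/apiKeys are assumed
// to carry valid keys for the corresponding providers.
const cohereModel = getModel('Cohere', 'command-r-plus', env, apiKeys);
const hfModel = getModel('HuggingFace', 'Qwen/Qwen2.5-Coder-32B-Instruct', env, apiKeys);
// Both results plug straight into the _streamText call in stream-text.ts.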

@@ -88,7 +88,7 @@ You are Bolt, an expert AI assistant and exceptional senior software developer w
   Example:

   <${MODIFICATIONS_TAG_NAME}>
-    <diff path="/home/project/src/main.js">
+    <diff path="${WORK_DIR}/src/main.js">
       @@ -2,7 +2,10 @@
         return a + b;
       }
@@ -103,7 +103,7 @@ You are Bolt, an expert AI assistant and exceptional senior software developer w
       +
       +console.log('The End');
     </diff>
-    <file path="/home/project/package.json">
+    <file path="${WORK_DIR}/package.json">
       // full file content here
     </file>
   </${MODIFICATIONS_TAG_NAME}>

@@ -41,10 +41,9 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
   return { model, provider, content: cleanedContent };
 }

 export function streamText(
-  messages: Messages,
-  env: Env,
+  messages: Messages,
+  env: Env,
   options?: StreamingOptions,
   apiKeys?: Record<string, string>
 ) {
@@ -64,13 +63,22 @@ export function streamText(
       return { ...message, content };
     }

-    return message; // No changes for non-user messages
+    return message;
   });

+  const modelDetails = MODEL_LIST.find((m) => m.name === currentModel);
+
+  const dynamicMaxTokens =
+    modelDetails && modelDetails.maxTokenAllowed
+      ? modelDetails.maxTokenAllowed
+      : MAX_TOKENS;
+
   return _streamText({
     model: getModel(currentProvider, currentModel, env, apiKeys),
     system: getSystemPrompt(),
-    maxTokens: MAX_TOKENS,
+    maxTokens: dynamicMaxTokens,
     messages: convertToCoreMessages(processedMessages),
     ...options,
   });
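
The token cap now comes from the model catalog instead of the global MAX_TOKENS constant. A hypothetical MODEL_LIST entry showing the two fields this lookup actually reads; the full shape of a catalog entry is assumed, not shown here:

// Hedged sketch: only name and maxTokenAllowed are read by the find() above;
// the other fields are assumed for illustration.
const exampleModelEntry = {
  name: 'claude-3-5-sonnet-latest', // must equal currentModel for find() to match
  provider: 'Anthropic',
  maxTokenAllowed: 8000, // used as maxTokens; MAX_TOKENS applies when absent
};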

@@ -2,3 +2,4 @@ export * from './useMessageParser';
 export * from './usePromptEnhancer';
 export * from './useShortcuts';
 export * from './useSnapScroll';
+export { default } from './useViewport';

@@ -1,4 +1,5 @@
 import { useState } from 'react';
+import type { ProviderInfo } from '~/types/model';
 import { createScopedLogger } from '~/utils/logger';

 const logger = createScopedLogger('usePromptEnhancement');
@@ -13,54 +14,54 @@ export function usePromptEnhancer() {
   };

   const enhancePrompt = async (
-    input: string,
+    input: string,
     setInput: (value: string) => void,
     model: string,
-    provider: string,
-    apiKeys?: Record<string, string>
+    provider: ProviderInfo,
+    apiKeys?: Record<string, string>,
   ) => {
     setEnhancingPrompt(true);
     setPromptEnhanced(false);

     const requestBody: any = {
       message: input,
       model,
       provider,
     };

     if (apiKeys) {
       requestBody.apiKeys = apiKeys;
     }

     const response = await fetch('/api/enhancer', {
       method: 'POST',
       body: JSON.stringify(requestBody),
     });

     const reader = response.body?.getReader();
     const originalInput = input;

     if (reader) {
       const decoder = new TextDecoder();

       let _input = '';
       let _error;

       try {
         setInput('');

         while (true) {
           const { value, done } = await reader.read();

           if (done) {
             break;
           }

           _input += decoder.decode(value);
           logger.trace('Set input', _input);
           setInput(_input);
         }
       } catch (error) {
@@ -70,10 +71,10 @@ export function usePromptEnhancer() {
       if (_error) {
         logger.error(_error);
       }

       setEnhancingPrompt(false);
       setPromptEnhanced(true);

       setTimeout(() => {
         setInput(_input);
       });
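
The hook streams the enhanced prompt into the input as chunks arrive, then re-applies the final text in a trailing setTimeout. A hedged call-site sketch; that the hook returns enhancePrompt, and the minimal shape of ProviderInfo, are both assumptions rather than things shown in this diff:

// Hedged sketch: assumes usePromptEnhancer exposes enhancePrompt and that
// ProviderInfo carries at least a name field; model name is illustrative.
const { enhancePrompt } = usePromptEnhancer();
await enhancePrompt(
  input,
  setInput,
  'claude-3-5-sonnet-latest',            // model, illustrative
  { name: 'Anthropic' } as ProviderInfo, // provider object, shape assumed
  apiKeys,
);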

@@ -0,0 +1,18 @@
+import { useState, useEffect } from 'react';
+
+const useViewport = (threshold = 1024) => {
+  const [isSmallViewport, setIsSmallViewport] = useState(window.innerWidth < threshold);
+
+  useEffect(() => {
+    const handleResize = () => setIsSmallViewport(window.innerWidth < threshold);
+    window.addEventListener('resize', handleResize);
+
+    return () => {
+      window.removeEventListener('resize', handleResize);
+    };
+  }, [threshold]);
+
+  return isSmallViewport;
+};
+
+export default useViewport;
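
A typical consumer, switching layout below the default 1024px threshold. Since the hook reads window directly, it is client-side only:

// Hedged usage sketch in a client-rendered component.
function ResponsiveHint() {
  const isSmallViewport = useViewport(); // true below the default 1024px threshold
  return <span>{isSmallViewport ? 'compact layout' : 'full layout'}</span>;
}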

@@ -158,3 +158,50 @@ async function getUrlIds(db: IDBDatabase): Promise<string[]> {
     };
   });
 }
+
+export async function forkChat(db: IDBDatabase, chatId: string, messageId: string): Promise<string> {
+  const chat = await getMessages(db, chatId);
+
+  if (!chat) throw new Error('Chat not found');
+
+  // Find the index of the message to fork at
+  const messageIndex = chat.messages.findIndex(msg => msg.id === messageId);
+
+  if (messageIndex === -1) throw new Error('Message not found');
+
+  // Get messages up to and including the selected message
+  const messages = chat.messages.slice(0, messageIndex + 1);
+
+  // Generate new IDs
+  const newId = await getNextId(db);
+  const urlId = await getUrlId(db, newId);
+
+  // Create the forked chat
+  await setMessages(
+    db,
+    newId,
+    messages,
+    urlId,
+    chat.description ? `${chat.description} (fork)` : 'Forked chat'
+  );
+
+  return urlId;
+}
+
+export async function duplicateChat(db: IDBDatabase, id: string): Promise<string> {
+  const chat = await getMessages(db, id);
+
+  if (!chat) {
+    throw new Error('Chat not found');
+  }
+
+  const newId = await getNextId(db);
+  const newUrlId = await getUrlId(db, newId); // Get a new urlId for the duplicated chat
+
+  await setMessages(
+    db,
+    newId,
+    chat.messages,
+    newUrlId, // Use the new urlId
+    `${chat.description || 'Chat'} (copy)`
+  );
+
+  return newUrlId; // Return the urlId instead of id for navigation
+}
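
Both helpers return the new chat's urlId rather than its numeric id, so callers can navigate straight to the result. A hedged usage sketch with placeholder ids:

// Hedged sketch: db comes from openDatabase(), navigate from useNavigate();
// chatId and messageId are placeholders.
const forkedUrlId = await forkChat(db, chatId, messageId); // history up to messageId
const copiedUrlId = await duplicateChat(db, chatId);       // full copy, "(copy)" suffix
navigate(`/chat/${copiedUrlId}`);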

@@ -1,10 +1,10 @@
-import { useLoaderData, useNavigate } from '@remix-run/react';
+import { useLoaderData, useNavigate, useSearchParams } from '@remix-run/react';
 import { useState, useEffect } from 'react';
 import { atom } from 'nanostores';
 import type { Message } from 'ai';
 import { toast } from 'react-toastify';
 import { workbenchStore } from '~/lib/stores/workbench';
-import { getMessages, getNextId, getUrlId, openDatabase, setMessages } from './db';
+import { getMessages, getNextId, getUrlId, openDatabase, setMessages, duplicateChat } from './db';

 export interface ChatHistoryItem {
   id: string;
@@ -24,6 +24,7 @@ export const description = atom<string | undefined>(undefined);
 export function useChatHistory() {
   const navigate = useNavigate();
   const { id: mixedId } = useLoaderData<{ id?: string }>();
+  const [searchParams] = useSearchParams();

   const [initialMessages, setInitialMessages] = useState<Message[]>([]);
   const [ready, setReady] = useState<boolean>(false);
@@ -44,7 +45,12 @@ export function useChatHistory() {
       getMessages(db, mixedId)
         .then((storedMessages) => {
           if (storedMessages && storedMessages.messages.length > 0) {
-            setInitialMessages(storedMessages.messages);
+            const rewindId = searchParams.get('rewindTo');
+            const filteredMessages = rewindId
+              ? storedMessages.messages.slice(0, storedMessages.messages.findIndex((m) => m.id === rewindId) + 1)
+              : storedMessages.messages;
+
+            setInitialMessages(filteredMessages);
             setUrlId(storedMessages.urlId);
             description.set(storedMessages.description);
             chatId.set(storedMessages.id);
@@ -93,6 +99,19 @@ export function useChatHistory() {
       await setMessages(db, chatId.get() as string, messages, urlId, description.get());
     },
+    duplicateCurrentChat: async (listItemId:string) => {
+      if (!db || (!mixedId && !listItemId)) {
+        return;
+      }
+
+      try {
+        const newId = await duplicateChat(db, mixedId || listItemId);
+        navigate(`/chat/${newId}`);
+        toast.success('Chat duplicated successfully');
+      } catch (error) {
+        toast.error('Failed to duplicate chat');
+      }
+    }
   };
 }
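
Rewinding is driven entirely by the URL: loading a chat with a rewindTo query parameter truncates the restored history at that message. A hedged sketch of producing such a link; the identifiers are placeholders, and the route shape mirrors the navigate() call above:

// Hedged sketch: urlId and messageId are placeholders.
const rewindLink = `/chat/${urlId}?rewindTo=${messageId}`;
// Opening this link re-runs the slice() filtering in the loader effect above.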

@@ -94,7 +94,7 @@ export class ActionRunner {
     this.#updateAction(actionId, { ...action, ...data.action, executed: !isStreaming });

-    this.#currentExecutionPromise = this.#currentExecutionPromise
+    return this.#currentExecutionPromise = this.#currentExecutionPromise
       .then(() => {
         return this.#executeAction(actionId, isStreaming);
       })
@@ -119,7 +119,14 @@ export class ActionRunner {
         break;
       }
       case 'start': {
-        await this.#runStartAction(action)
+        // making the start app non-blocking
+        this.#runStartAction(action).then(()=>this.#updateAction(actionId, { status: 'complete' }))
+          .catch(()=>this.#updateAction(actionId, { status: 'failed', error: 'Action failed' }))
+
+        // adding a delay to avoid any race condition between 2 start actions
+        // i am up for a better approach
+        await new Promise(resolve=>setTimeout(resolve,2000))
+        return
         break;
       }
     }
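
Returning the reassigned #currentExecutionPromise lets callers await the queued action, while the start case now resolves out-of-band behind a fixed 2s delay that guards against two start actions racing. The underlying chaining idiom, isolated as a minimal sketch (not repo code):

// Minimal illustration of the promise-chain serialization idiom used above.
let queue: Promise<unknown> = Promise.resolve();

function enqueue<T>(task: () => Promise<T>): Promise<T> {
  const result = queue.then(() => task()); // start after the previous task settles
  queue = result.catch(() => undefined);   // swallow errors so the chain survives
  return result;                           // callers can still await / catch
}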

@@ -14,6 +14,7 @@ import { saveAs } from 'file-saver';
 import { Octokit, type RestEndpointMethodTypes } from "@octokit/rest";
 import * as nodePath from 'node:path';
 import type { WebContainerProcess } from '@webcontainer/api';
+import { extractRelativePath } from '~/utils/diff';

 export interface ArtifactState {
   id: string;
@@ -42,7 +43,7 @@ export class WorkbenchStore {
   modifiedFiles = new Set<string>();
   artifactIdList: string[] = [];
   #boltTerminal: { terminal: ITerminal; process: WebContainerProcess } | undefined;
+  #globalExecutionQueue=Promise.resolve();

   constructor() {
     if (import.meta.hot) {
       import.meta.hot.data.artifacts = this.artifacts;
@@ -52,6 +53,10 @@ export class WorkbenchStore {
     }
   }

+  addToExecutionQueue(callback: () => Promise<void>) {
+    this.#globalExecutionQueue=this.#globalExecutionQueue.then(()=>callback())
+  }
+
   get previews() {
     return this.#previewsStore.previews;
   }
@@ -255,8 +260,11 @@ export class WorkbenchStore {
     this.artifacts.setKey(messageId, { ...artifact, ...state });
   }

-  async addAction(data: ActionCallbackData) {
+  addAction(data: ActionCallbackData) {
+    this._addAction(data)
+
+    // this.addToExecutionQueue(()=>this._addAction(data))
+  }
+
+  async _addAction(data: ActionCallbackData) {
     const { messageId } = data;

     const artifact = this.#getArtifact(messageId);
@@ -265,10 +273,18 @@ export class WorkbenchStore {
       unreachable('Artifact not found');
     }

-    artifact.runner.addAction(data);
+    return artifact.runner.addAction(data);
   }

-  async runAction(data: ActionCallbackData, isStreaming: boolean = false) {
+  runAction(data: ActionCallbackData, isStreaming: boolean = false) {
+    if(isStreaming) {
+      this._runAction(data, isStreaming)
+    }
+    else{
+      this.addToExecutionQueue(()=>this._runAction(data, isStreaming))
+    }
+  }
+
+  async _runAction(data: ActionCallbackData, isStreaming: boolean = false) {
     const { messageId } = data;

     const artifact = this.#getArtifact(messageId);
@@ -293,11 +309,11 @@ export class WorkbenchStore {
       this.#editorStore.updateFile(fullPath, data.action.content);

       if (!isStreaming) {
         this.resetCurrentDocument();
         await artifact.runner.runAction(data);
         this.resetAllFileModifications();
       }
     } else {
-      artifact.runner.runAction(data);
+      await artifact.runner.runAction(data);
     }
   }
@@ -312,8 +328,7 @@ export class WorkbenchStore {
     for (const [filePath, dirent] of Object.entries(files)) {
       if (dirent?.type === 'file' && !dirent.isBinary) {
-        // remove '/home/project/' from the beginning of the path
-        const relativePath = filePath.replace(/^\/home\/project\//, '');
+        const relativePath = extractRelativePath(filePath);

         // split the path into segments
         const pathSegments = relativePath.split('/');
@@ -343,7 +358,7 @@ export class WorkbenchStore {
     for (const [filePath, dirent] of Object.entries(files)) {
       if (dirent?.type === 'file' && !dirent.isBinary) {
-        const relativePath = filePath.replace(/^\/home\/project\//, '');
+        const relativePath = extractRelativePath(filePath);
         const pathSegments = relativePath.split('/');

         let currentHandle = targetHandle;
@@ -417,7 +432,7 @@ export class WorkbenchStore {
           content: Buffer.from(dirent.content).toString('base64'),
           encoding: 'base64',
         });

-        return { path: filePath.replace(/^\/home\/project\//, ''), sha: blob.sha };
+        return { path: extractRelativePath(filePath), sha: blob.sha };
       }
     })
   );
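
extractRelativePath centralizes the inline '/home/project/' regex that previously appeared at three call sites. A hedged sketch of what the helper in ~/utils/diff presumably does, inferred only from the code it replaces; the real implementation may differ:

// Hedged sketch: mirrors the inline regex the three call sites used before.
const WORK_DIR = '/home/project'; // assumed to match the WebContainer workdir
function extractRelativePath(filePath: string): string {
  return filePath.startsWith(`${WORK_DIR}/`) ? filePath.slice(WORK_DIR.length + 1) : filePath;
}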