Merge branch 'main' into github-import

This commit is contained in:
Anirban Kar
2024-12-06 18:54:06 +05:30
committed by GitHub
29 changed files with 868 additions and 135 deletions

View File

@@ -51,7 +51,7 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re
export function getBaseURL(cloudflareEnv: Env, provider: string) {
switch (provider) {
case 'Together':
return env.TOGETHER_API_BASE_URL || cloudflareEnv.TOGETHER_API_BASE_URL;
return env.TOGETHER_API_BASE_URL || cloudflareEnv.TOGETHER_API_BASE_URL || 'https://api.together.xyz/v1';
case 'OpenAILike':
return env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL;
case 'LMStudio':

View File

@@ -128,7 +128,12 @@ export function getXAIModel(apiKey: OptionalApiKey, model: string) {
}
export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
const apiKey = getAPIKey(env, provider, apiKeys);
/*
* let apiKey; // Declare first
* let baseURL;
*/
const apiKey = getAPIKey(env, provider, apiKeys); // Then assign
const baseURL = getBaseURL(env, provider);
switch (provider) {

View File

@@ -1,11 +1,8 @@
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-nocheck TODO: Provider proper types
import { convertToCoreMessages, streamText as _streamText } from 'ai';
import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
@@ -26,24 +23,50 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
// Extract model
const modelMatch = message.content.match(MODEL_REGEX);
const textContent = Array.isArray(message.content)
? message.content.find((item) => item.type === 'text')?.text || ''
: message.content;
const modelMatch = textContent.match(MODEL_REGEX);
const providerMatch = textContent.match(PROVIDER_REGEX);
/*
* Extract model
* const modelMatch = message.content.match(MODEL_REGEX);
*/
const model = modelMatch ? modelMatch[1] : DEFAULT_MODEL;
// Extract provider
const providerMatch = message.content.match(PROVIDER_REGEX);
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
/*
* Extract provider
* const providerMatch = message.content.match(PROVIDER_REGEX);
*/
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER.name;
// Remove model and provider lines from content
const cleanedContent = message.content.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, '').trim();
const cleanedContent = Array.isArray(message.content)
? message.content.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text?.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, ''),
};
}
return item; // Preserve image_url and other types as is
})
: textContent.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, '');
return { model, provider, content: cleanedContent };
}
export function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
export async function streamText(
messages: Messages,
env: Env,
options?: StreamingOptions,
apiKeys?: Record<string, string>,
) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
let currentProvider = DEFAULT_PROVIDER.name;
const MODEL_LIST = await getModelList(apiKeys || {});
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -65,10 +88,10 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
return _streamText({
model: getModel(currentProvider, currentModel, env, apiKeys),
model: getModel(currentProvider, currentModel, env, apiKeys) as any,
system: getSystemPrompt(),
maxTokens: dynamicMaxTokens,
messages: convertToCoreMessages(processedMessages),
messages: convertToCoreMessages(processedMessages as any),
...options,
});
}

View File

@@ -29,6 +29,7 @@ exports[`StreamingMessageParser > valid artifacts with actions > should correctl
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -37,6 +38,7 @@ exports[`StreamingMessageParser > valid artifacts with actions > should correctl
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -96,6 +98,7 @@ exports[`StreamingMessageParser > valid artifacts with actions > should correctl
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -104,6 +107,7 @@ exports[`StreamingMessageParser > valid artifacts with actions > should correctl
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -112,6 +116,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -120,6 +125,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -128,6 +134,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": "bundled",
}
`;
@@ -136,6 +143,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": "bundled",
}
`;
@@ -144,6 +152,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -152,6 +161,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -160,6 +170,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -168,6 +179,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -176,6 +188,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -184,6 +197,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -192,6 +206,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -200,6 +215,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -208,6 +224,7 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;
@@ -216,5 +233,6 @@ exports[`StreamingMessageParser > valid artifacts without actions > should corre
"id": "artifact_1",
"messageId": "message_1",
"title": "Some title",
"type": undefined,
}
`;

View File

@@ -59,7 +59,11 @@ describe('StreamingMessageParser', () => {
},
],
[
['Some text before <boltArti', 'fact', ' title="Some title" id="artifact_1">foo</boltArtifact> Some more text'],
[
'Some text before <boltArti',
'fact',
' title="Some title" id="artifact_1" type="bundled" >foo</boltArtifact> Some more text',
],
{
output: 'Some text before Some more text',
callbacks: { onArtifactOpen: 1, onArtifactClose: 1, onActionOpen: 0, onActionClose: 0 },

View File

@@ -192,6 +192,7 @@ export class StreamingMessageParser {
const artifactTag = input.slice(i, openTagEnd + 1);
const artifactTitle = this.#extractAttribute(artifactTag, 'title') as string;
const type = this.#extractAttribute(artifactTag, 'type') as string;
const artifactId = this.#extractAttribute(artifactTag, 'id') as string;
if (!artifactTitle) {
@@ -207,6 +208,7 @@ export class StreamingMessageParser {
const currentArtifact = {
id: artifactId,
title: artifactTitle,
type,
} satisfies BoltArtifactData;
state.currentArtifact = currentArtifact;

View File

@@ -212,9 +212,5 @@ function isBinaryFile(buffer: Uint8Array | undefined) {
* array buffer.
*/
function convertToBuffer(view: Uint8Array): Buffer {
const buffer = new Uint8Array(view.buffer, view.byteOffset, view.byteLength);
Object.setPrototypeOf(buffer, Buffer.prototype);
return buffer as Buffer;
return Buffer.from(view.buffer, view.byteOffset, view.byteLength);
}

View File

@@ -19,6 +19,7 @@ import { description } from '~/lib/persistence';
export interface ArtifactState {
id: string;
title: string;
type?: string;
closed: boolean;
runner: ActionRunner;
}
@@ -230,7 +231,7 @@ export class WorkbenchStore {
// TODO: what do we wanna do and how do we wanna recover from this?
}
addArtifact({ messageId, title, id }: ArtifactCallbackData) {
addArtifact({ messageId, title, id, type }: ArtifactCallbackData) {
const artifact = this.#getArtifact(messageId);
if (artifact) {
@@ -245,6 +246,7 @@ export class WorkbenchStore {
id,
title,
closed: false,
type,
runner: new ActionRunner(webcontainer, () => this.boltTerminal),
});
}