Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Andrew Trokhymenko
2024-11-21 23:31:30 -05:00
22 changed files with 374 additions and 217 deletions

View File

@@ -35,6 +35,8 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re
return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY;
case "xAI":
return env.XAI_API_KEY || cloudflareEnv.XAI_API_KEY;
case "Cohere":
return env.COHERE_API_KEY;
default:
return "";
}

View File

@@ -7,6 +7,11 @@ import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { ollama } from 'ollama-ai-provider';
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { createMistral } from '@ai-sdk/mistral';
import { createCohere } from '@ai-sdk/cohere'
export const DEFAULT_NUM_CTX = process.env.DEFAULT_NUM_CTX ?
parseInt(process.env.DEFAULT_NUM_CTX, 10) :
32768;
export function getAnthropicModel(apiKey: string, model: string) {
const anthropic = createAnthropic({
@@ -22,14 +27,16 @@ export function getOpenAILikeModel(baseURL: string, apiKey: string, model: strin
baseURL,
apiKey,
});
// console.log('OpenAI client created:', !!openai);
const client = openai(model);
// console.log('OpenAI model client:', !!client);
return client;
// return {
// model: client,
// provider: 'OpenAILike' // Correctly identifying the actual provider
// };
return openai(model);
}
export function getCohereAIModel(apiKey:string, model: string){
const cohere = createCohere({
apiKey,
});
return cohere(model);
}
export function getOpenAIModel(apiKey: string, model: string) {
@@ -76,7 +83,7 @@ export function getHuggingFaceModel(apiKey: string, model: string) {
export function getOllamaModel(baseURL: string, model: string) {
let Ollama = ollama(model, {
numCtx: 32768,
numCtx: DEFAULT_NUM_CTX,
});
Ollama.config.baseURL = `${baseURL}/api`;
@@ -150,6 +157,8 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Re
return getLMStudioModel(baseURL, model);
case 'xAI':
return getXAIModel(apiKey, model);
case 'Cohere':
return getCohereAIModel(apiKey, model);
default:
return getOllamaModel(baseURL, model);
}

View File

@@ -58,7 +58,6 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
return { model, provider, content: cleanedContent };
}
export function streamText(
messages: Messages,
env: Env,
@@ -68,8 +67,6 @@ export function streamText(
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
// console.log('StreamText:', JSON.stringify(messages));
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -83,25 +80,19 @@ export function streamText(
return { ...message, content };
}
return message; // No changes for non-user messages
});
const modelDetails = MODEL_LIST.find((m) => m.name === currentModel);
// console.log('Message content:', messages[0].content);
// console.log('Extracted properties:', extractPropertiesFromMessage(messages[0]));
const dynamicMaxTokens =
modelDetails && modelDetails.maxTokenAllowed
? modelDetails.maxTokenAllowed
: MAX_TOKENS;
const llmClient = getModel(currentProvider, currentModel, env, apiKeys);
// console.log('LLM Client:', llmClient);
const llmConfig = {
...options,
model: llmClient, //getModel(currentProvider, currentModel, env, apiKeys),
provider: currentProvider,
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
messages: convertToCoreMessages(processedMessages),
};
// console.log('LLM Config:', llmConfig);
return _streamText(llmConfig);
}
return _streamText({
model: getModel(currentProvider, currentModel, env, apiKeys),
system: getSystemPrompt(),
maxTokens: dynamicMaxTokens,
messages: convertToCoreMessages(processedMessages),
...options,
});
}
)}