Merge branch 'main' into respect-provider-choice

This commit is contained in:
Cole Medin
2024-11-09 07:44:40 -06:00
committed by GitHub
11 changed files with 196 additions and 27 deletions

View File

@@ -2,12 +2,18 @@
// Preventing TS checks with files presented in the video for a better presentation.
import { env } from 'node:process';
export function getAPIKey(cloudflareEnv: Env, provider: string) {
export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Record<string, string>) {
/**
* The `cloudflareEnv` is only used when deployed or when previewing locally.
* In development the environment variables are available through `env`.
*/
// First check user-provided API keys
if (userApiKeys?.[provider]) {
return userApiKeys[provider];
}
// Fall back to environment variables
switch (provider) {
case 'Anthropic':
return env.ANTHROPIC_API_KEY || cloudflareEnv.ANTHROPIC_API_KEY;
@@ -25,6 +31,8 @@ export function getAPIKey(cloudflareEnv: Env, provider: string) {
return env.MISTRAL_API_KEY || cloudflareEnv.MISTRAL_API_KEY;
case "OpenAILike":
return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY;
case "xAI":
return env.XAI_API_KEY || cloudflareEnv.XAI_API_KEY;
default:
return "";
}

View File

@@ -58,7 +58,10 @@ export function getGroqModel(apiKey: string, model: string) {
}
export function getOllamaModel(baseURL: string, model: string) {
let Ollama = ollama(model);
let Ollama = ollama(model, {
numCtx: 32768,
});
Ollama.config.baseURL = `${baseURL}/api`;
return Ollama;
}
@@ -80,8 +83,16 @@ export function getOpenRouterModel(apiKey: string, model: string) {
return openRouter.chat(model);
}
export function getModel(provider: string, model: string, env: Env) {
const apiKey = getAPIKey(env, provider);
export function getXAIModel(apiKey: string, model: string) {
const openai = createOpenAI({
baseURL: 'https://api.x.ai/v1',
apiKey,
});
return openai(model);
}
export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
const apiKey = getAPIKey(env, provider, apiKeys);
const baseURL = getBaseURL(env, provider);
switch (provider) {
@@ -101,6 +112,8 @@ export function getModel(provider: string, model: string, env: Env) {
return getDeepseekModel(apiKey, model)
case 'Mistral':
return getMistralModel(apiKey, model);
case 'xAI':
return getXAIModel(apiKey, model);
default:
return getOllamaModel(baseURL, model);
}

View File

@@ -42,7 +42,12 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
return { model, provider, content: cleanedContent };
}
export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
export function streamText(
messages: Messages,
env: Env,
options?: StreamingOptions,
apiKeys?: Record<string, string>
) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
@@ -63,7 +68,7 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
});
return _streamText({
model: getModel(currentProvider, currentModel, env),
model: getModel(currentProvider, currentModel, env, apiKeys),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
messages: convertToCoreMessages(processedMessages),