This commit is contained in:
Karrot0
2024-11-10 21:08:06 -05:00
14 changed files with 238 additions and 115 deletions

View File

@@ -2,12 +2,18 @@
// Preventing TS checks with files presented in the video for a better presentation.
import { env } from 'node:process';
export function getAPIKey(cloudflareEnv: Env, provider: string) {
export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Record<string, string>) {
/**
* The `cloudflareEnv` is only used when deployed or when previewing locally.
* In development the environment variables are available through `env`.
*/
// First check user-provided API keys
if (userApiKeys?.[provider]) {
return userApiKeys[provider];
}
// Fall back to environment variables
switch (provider) {
case 'Anthropic':
return env.ANTHROPIC_API_KEY || cloudflareEnv.ANTHROPIC_API_KEY;

View File

@@ -100,9 +100,8 @@ export function getXAIModel(apiKey: string, model: string) {
return openai(model);
}
export function getModel(provider: string, model: string, env: Env) {
const apiKey = getAPIKey(env, provider);
export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
const apiKey = getAPIKey(env, provider, apiKeys);
const baseURL = getBaseURL(env, provider);
switch (provider) {

View File

@@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai';
import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
@@ -24,42 +24,53 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
function extractModelFromMessage(message: Message): { model: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\n\n/;
const match = message.content.match(modelRegex);
function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
// Extract model
const modelMatch = message.content.match(MODEL_REGEX);
const model = modelMatch ? modelMatch[1] : DEFAULT_MODEL;
if (match) {
const model = match[1];
const content = message.content.replace(modelRegex, '');
return { model, content };
}
// Extract provider
const providerMatch = message.content.match(PROVIDER_REGEX);
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
// Default model if not specified
return { model: DEFAULT_MODEL, content: message.content };
// Remove model and provider lines from content
const cleanedContent = message.content
.replace(MODEL_REGEX, '')
.replace(PROVIDER_REGEX, '')
.trim();
return { model, provider, content: cleanedContent };
}
export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
export function streamText(
messages: Messages,
env: Env,
options?: StreamingOptions,
apiKeys?: Record<string, string>
) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, content } = extractModelFromMessage(message);
if (model && MODEL_LIST.find((m) => m.name === model)) {
currentModel = model; // Update the current model
const { model, provider, content } = extractPropertiesFromMessage(message);
if (MODEL_LIST.find((m) => m.name === model)) {
currentModel = model;
}
currentProvider = provider;
return { ...message, content };
}
return message;
return message; // No changes for non-user messages
});
const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;
return _streamText({
model: getModel(provider, currentModel, env),
model: getModel(currentProvider, currentModel, env, apiKeys),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
// headers: {
// 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
// },
messages: convertToCoreMessages(processedMessages),
...options,
});