Merge remote-tracking branch 'upstream/main'
@@ -35,6 +35,8 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Record<string, string>) {
       return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY;
     case "xAI":
       return env.XAI_API_KEY || cloudflareEnv.XAI_API_KEY;
+    case "Cohere":
+      return env.COHERE_API_KEY;
     default:
       return "";
   }
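Worth noting: unlike the xAI case directly above it, the new Cohere branch reads only `env.COHERE_API_KEY` and does not fall back to the Cloudflare binding. A minimal standalone sketch of the resulting lookup order (`KeyEnv` and `lookupKey` are illustrative names, not from the commit):

```ts
// Sketch of the lookup order the switch implements.
type KeyEnv = Record<string, string | undefined>;

function lookupKey(provider: string, env: KeyEnv, cloudflareEnv: KeyEnv): string {
  switch (provider) {
    case 'xAI':
      return env.XAI_API_KEY || cloudflareEnv.XAI_API_KEY || '';
    case 'Cohere':
      return env.COHERE_API_KEY || ''; // no cloudflareEnv fallback in this hunk
    default:
      return '';
  }
}
```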
@@ -7,6 +7,11 @@ import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { ollama } from 'ollama-ai-provider';
 import { createOpenRouter } from "@openrouter/ai-sdk-provider";
 import { createMistral } from '@ai-sdk/mistral';
+import { createCohere } from '@ai-sdk/cohere'
+
+export const DEFAULT_NUM_CTX = process.env.DEFAULT_NUM_CTX ?
+  parseInt(process.env.DEFAULT_NUM_CTX, 10) :
+  32768;
 
 export function getAnthropicModel(apiKey: string, model: string) {
   const anthropic = createAnthropic({
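`DEFAULT_NUM_CTX` replaces a hard-coded 32768 context size with a deploy-time knob. A small sketch of the parsing semantics (variable names are illustrative); note that `parseInt` yields `NaN` for non-numeric values and the commit does not guard against that:

```ts
// Mirrors the export above: unset → 32768, numeric string → that value.
// A non-numeric string (e.g. "abc") yields NaN, which would flow into numCtx.
const raw = process.env.DEFAULT_NUM_CTX; // e.g. "24576"
const numCtx = raw ? parseInt(raw, 10) : 32768;

if (Number.isNaN(numCtx)) {
  throw new Error('DEFAULT_NUM_CTX must be a number');
}
```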
@@ -22,14 +27,16 @@ export function getOpenAILikeModel(baseURL: string, apiKey: string, model: string) {
     baseURL,
     apiKey,
   });
-  // console.log('OpenAI client created:', !!openai);
-  const client = openai(model);
-  // console.log('OpenAI model client:', !!client);
-  return client;
-  // return {
-  //   model: client,
-  //   provider: 'OpenAILike' // Correctly identifying the actual provider
-  // };
+
+  return openai(model);
 }
 
+export function getCohereAIModel(apiKey: string, model: string) {
+  const cohere = createCohere({
+    apiKey,
+  });
+
+  return cohere(model);
+}
+
 export function getOpenAIModel(apiKey: string, model: string) {
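`getCohereAIModel` mirrors the other provider factories: configure a client with the key, then resolve the model id. An inlined usage sketch (the model id is illustrative):

```ts
import { createCohere } from '@ai-sdk/cohere';

// What the factory above does, inlined; 'command-r-plus' is illustrative.
const cohere = createCohere({ apiKey: process.env.COHERE_API_KEY ?? '' });
const model = cohere('command-r-plus'); // a language model usable with streamText
```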
@@ -76,7 +83,7 @@ export function getHuggingFaceModel(apiKey: string, model: string) {
 
 export function getOllamaModel(baseURL: string, model: string) {
   let Ollama = ollama(model, {
-    numCtx: 32768,
+    numCtx: DEFAULT_NUM_CTX,
   });
 
   Ollama.config.baseURL = `${baseURL}/api`;
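With `DEFAULT_NUM_CTX` in place, the Ollama context window can be shrunk on memory-constrained hosts without touching code. A self-contained sketch of the same call (the model id is illustrative; the constant is re-derived here rather than imported):

```ts
import { ollama } from 'ollama-ai-provider';

// Re-derives the constant exported above so the sketch stands alone.
const DEFAULT_NUM_CTX = process.env.DEFAULT_NUM_CTX
  ? parseInt(process.env.DEFAULT_NUM_CTX, 10)
  : 32768;

// numCtx maps to Ollama's num_ctx option: the model's context window size.
const model = ollama('llama3.1', { numCtx: DEFAULT_NUM_CTX });
```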
@@ -150,6 +157,8 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
       return getLMStudioModel(baseURL, model);
     case 'xAI':
       return getXAIModel(apiKey, model);
+    case 'Cohere':
+      return getCohereAIModel(apiKey, model);
     default:
       return getOllamaModel(baseURL, model);
   }
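The provider switch now routes 'Cohere' to the new factory, while unknown providers still fall through to Ollama. A usage sketch against the function in the hunk above ('command-r' is an illustrative model id; the declares stand in for the request context):

```ts
declare const env: Env;
declare const apiKeys: Record<string, string> | undefined;

const cohereModel = getModel('Cohere', 'command-r', env, apiKeys);
const fallbackModel = getModel('NotARealProvider', 'llama3.1', env, apiKeys); // → Ollama
```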
@@ -58,7 +58,6 @@ function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
-
   return { model, provider, content: cleanedContent };
 }
 
 export function streamText(
   messages: Messages,
   env: Env,
@@ -68,8 +67,6 @@ export function streamText(
   let currentModel = DEFAULT_MODEL;
   let currentProvider = DEFAULT_PROVIDER;
 
-  // console.log('StreamText:', JSON.stringify(messages));
-
   const processedMessages = messages.map((message) => {
     if (message.role === 'user') {
       const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -83,25 +80,19 @@ export function streamText(
       return { ...message, content };
     }
 
     return message; // No changes for non-user messages
   });
   const modelDetails = MODEL_LIST.find((m) => m.name === currentModel);
 
-  // console.log('Message content:', messages[0].content);
-  // console.log('Extracted properties:', extractPropertiesFromMessage(messages[0]));
+  const dynamicMaxTokens =
+    modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
 
-  const llmClient = getModel(currentProvider, currentModel, env, apiKeys);
-  // console.log('LLM Client:', llmClient);
-
-  const llmConfig = {
-    ...options,
-    model: llmClient, //getModel(currentProvider, currentModel, env, apiKeys),
-    provider: currentProvider,
-    system: getSystemPrompt(),
-    maxTokens: MAX_TOKENS,
-    messages: convertToCoreMessages(processedMessages),
-  };
-
-  // console.log('LLM Config:', llmConfig);
-
-  return _streamText(llmConfig);
-}
+  return _streamText({
+    model: getModel(currentProvider, currentModel, env, apiKeys),
+    system: getSystemPrompt(),
+    maxTokens: dynamicMaxTokens,
+    messages: convertToCoreMessages(processedMessages),
+    ...options,
+  });
+}
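Two behavioral changes hide in this hunk besides the deleted debug logging: the token cap is now per-model (`maxTokenAllowed`, falling back to `MAX_TOKENS`) instead of always `MAX_TOKENS`, and `...options` moved from the front of the config object to the back, so caller-supplied options now override `system`, `maxTokens`, and `messages` rather than being overridden by them. The cap selection in isolation (declares stand in for this file's imports):

```ts
declare const MODEL_LIST: { name: string; maxTokenAllowed?: number }[];
declare const currentModel: string;
declare const MAX_TOKENS: number;

// Prefer the model's own limit; a missing (or zero) maxTokenAllowed falls
// back to the global MAX_TOKENS, matching the ternary in the hunk above.
const modelDetails = MODEL_LIST.find((m) => m.name === currentModel);
const dynamicMaxTokens =
  modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
```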
@@ -2,3 +2,4 @@ export * from './useMessageParser';
 export * from './usePromptEnhancer';
 export * from './useShortcuts';
 export * from './useSnapScroll';
+export { default } from './useViewport';
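Because the new line re-exports useViewport's default as the barrel's own default export, both import styles below resolve after this change (the `~/lib/hooks` alias is assumed from the `~/` imports elsewhere in this commit):

```ts
import useViewport from '~/lib/hooks'; // the new default re-export
import { usePromptEnhancer, useShortcuts } from '~/lib/hooks'; // existing star re-exports
```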
@@ -1,4 +1,5 @@
 import { useState } from 'react';
+import type { ProviderInfo } from '~/types/model';
 import { createScopedLogger } from '~/utils/logger';
 
 const logger = createScopedLogger('usePromptEnhancement');
@@ -13,54 +14,54 @@ export function usePromptEnhancer() {
   };
 
   const enhancePrompt = async (
     input: string,
     setInput: (value: string) => void,
     model: string,
-    provider: string,
-    apiKeys?: Record<string, string>
+    provider: ProviderInfo,
+    apiKeys?: Record<string, string>,
   ) => {
     setEnhancingPrompt(true);
     setPromptEnhanced(false);
 
     const requestBody: any = {
       message: input,
       model,
       provider,
     };
 
     if (apiKeys) {
       requestBody.apiKeys = apiKeys;
     }
 
     const response = await fetch('/api/enhancer', {
       method: 'POST',
       body: JSON.stringify(requestBody),
     });
 
     const reader = response.body?.getReader();
 
     const originalInput = input;
 
     if (reader) {
       const decoder = new TextDecoder();
 
       let _input = '';
       let _error;
 
       try {
         setInput('');
 
         while (true) {
           const { value, done } = await reader.read();
 
           if (done) {
             break;
           }
 
           _input += decoder.decode(value);
 
           logger.trace('Set input', _input);
 
           setInput(_input);
         }
       } catch (error) {
@@ -70,10 +71,10 @@ export function usePromptEnhancer() {
       if (_error) {
         logger.error(_error);
       }
 
       setEnhancingPrompt(false);
       setPromptEnhanced(true);
 
       setTimeout(() => {
         setInput(_input);
       });
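The visible change in these hunks is the signature: `provider` is now a `ProviderInfo` object rather than a string, and it is still placed whole into `requestBody`, so `/api/enhancer` receives the object form. A hypothetical call site (model id and key map are illustrative; `enhancePrompt` comes from the hook's return value):

```ts
import type { ProviderInfo } from '~/types/model';

declare const setInput: (value: string) => void; // from the consuming component
declare function enhancePrompt(
  input: string,
  setInput: (value: string) => void,
  model: string,
  provider: ProviderInfo,
  apiKeys?: Record<string, string>,
): Promise<void>;

const provider = { name: 'Cohere' } as ProviderInfo;
await enhancePrompt('make the hero section responsive', setInput, 'command-r', provider, {
  Cohere: 'my-cohere-key', // optional per-user key, forwarded via requestBody.apiKeys
});
```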
app/lib/hooks/useViewport.ts (new file, 18 lines)
@@ -0,0 +1,18 @@
+import { useState, useEffect } from 'react';
+
+const useViewport = (threshold = 1024) => {
+  const [isSmallViewport, setIsSmallViewport] = useState(window.innerWidth < threshold);
+
+  useEffect(() => {
+    const handleResize = () => setIsSmallViewport(window.innerWidth < threshold);
+    window.addEventListener('resize', handleResize);
+
+    return () => {
+      window.removeEventListener('resize', handleResize);
+    };
+  }, [threshold]);
+
+  return isSmallViewport;
+};
+
+export default useViewport;
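One caveat worth flagging on the new hook: the `useState` initializer reads `window.innerWidth`, which is fine in a purely client-rendered tree but would throw during server-side rendering, where `window` is undefined. A typical consumer (component and threshold are illustrative):

```tsx
import useViewport from '~/lib/hooks';

// Illustrative consumer: hides navigation below 1024px and re-renders
// on every window resize via the hook's listener.
function Sidebar() {
  const isSmallViewport = useViewport(1024);
  return isSmallViewport ? null : <nav>{/* links */}</nav>;
}
```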