feat: enhance context handling by adding code context selection and implementing summary generation (#1091) #release
* feat: add context annotation types and enhance file handling in LLM processing * feat: enhance context handling by adding chatId to annotations and implementing summary generation * removed useless changes * feat: updated token counts to include optimization requests * prompt fix * logging added * useless logs removed
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
import { type ActionFunctionArgs } from '@remix-run/cloudflare';
|
||||
import { createDataStream } from 'ai';
|
||||
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
|
||||
import { createDataStream, generateId } from 'ai';
|
||||
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS, type FileMap } from '~/lib/.server/llm/constants';
|
||||
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
|
||||
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
|
||||
import SwitchableStream from '~/lib/.server/llm/switchable-stream';
|
||||
import type { IProviderSetting } from '~/types/model';
|
||||
import { createScopedLogger } from '~/utils/logger';
|
||||
import { getFilePaths, selectContext } from '~/lib/.server/llm/select-context';
|
||||
import type { ContextAnnotation, ProgressAnnotation } from '~/types/context';
|
||||
import { WORK_DIR } from '~/utils/constants';
|
||||
import { createSummary } from '~/lib/.server/llm/create-summary';
|
||||
|
||||
export async function action(args: ActionFunctionArgs) {
|
||||
return chatAction(args);
|
||||
@@ -52,23 +56,121 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
|
||||
promptTokens: 0,
|
||||
totalTokens: 0,
|
||||
};
|
||||
const encoder: TextEncoder = new TextEncoder();
|
||||
let progressCounter: number = 1;
|
||||
|
||||
try {
|
||||
const options: StreamingOptions = {
|
||||
toolChoice: 'none',
|
||||
onFinish: async ({ text: content, finishReason, usage }) => {
|
||||
logger.debug('usage', JSON.stringify(usage));
|
||||
const totalMessageContent = messages.reduce((acc, message) => acc + message.content, '');
|
||||
logger.debug(`Total message length: ${totalMessageContent.split(' ').length}, words`);
|
||||
|
||||
if (usage) {
|
||||
cumulativeUsage.completionTokens += usage.completionTokens || 0;
|
||||
cumulativeUsage.promptTokens += usage.promptTokens || 0;
|
||||
cumulativeUsage.totalTokens += usage.totalTokens || 0;
|
||||
const dataStream = createDataStream({
|
||||
async execute(dataStream) {
|
||||
const filePaths = getFilePaths(files || {});
|
||||
let filteredFiles: FileMap | undefined = undefined;
|
||||
let summary: string | undefined = undefined;
|
||||
|
||||
if (filePaths.length > 0 && contextOptimization) {
|
||||
dataStream.writeData('HI ');
|
||||
logger.debug('Generating Chat Summary');
|
||||
dataStream.writeMessageAnnotation({
|
||||
type: 'progress',
|
||||
value: progressCounter++,
|
||||
message: 'Generating Chat Summary',
|
||||
} as ProgressAnnotation);
|
||||
|
||||
// Create a summary of the chat
|
||||
console.log(`Messages count: ${messages.length}`);
|
||||
|
||||
summary = await createSummary({
|
||||
messages: [...messages],
|
||||
env: context.cloudflare?.env,
|
||||
apiKeys,
|
||||
providerSettings,
|
||||
promptId,
|
||||
contextOptimization,
|
||||
onFinish(resp) {
|
||||
if (resp.usage) {
|
||||
logger.debug('createSummary token usage', JSON.stringify(resp.usage));
|
||||
cumulativeUsage.completionTokens += resp.usage.completionTokens || 0;
|
||||
cumulativeUsage.promptTokens += resp.usage.promptTokens || 0;
|
||||
cumulativeUsage.totalTokens += resp.usage.totalTokens || 0;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
dataStream.writeMessageAnnotation({
|
||||
type: 'chatSummary',
|
||||
summary,
|
||||
chatId: messages.slice(-1)?.[0]?.id,
|
||||
} as ContextAnnotation);
|
||||
|
||||
// Update context buffer
|
||||
logger.debug('Updating Context Buffer');
|
||||
dataStream.writeMessageAnnotation({
|
||||
type: 'progress',
|
||||
value: progressCounter++,
|
||||
message: 'Updating Context Buffer',
|
||||
} as ProgressAnnotation);
|
||||
|
||||
// Select context files
|
||||
console.log(`Messages count: ${messages.length}`);
|
||||
filteredFiles = await selectContext({
|
||||
messages: [...messages],
|
||||
env: context.cloudflare?.env,
|
||||
apiKeys,
|
||||
files,
|
||||
providerSettings,
|
||||
promptId,
|
||||
contextOptimization,
|
||||
summary,
|
||||
onFinish(resp) {
|
||||
if (resp.usage) {
|
||||
logger.debug('selectContext token usage', JSON.stringify(resp.usage));
|
||||
cumulativeUsage.completionTokens += resp.usage.completionTokens || 0;
|
||||
cumulativeUsage.promptTokens += resp.usage.promptTokens || 0;
|
||||
cumulativeUsage.totalTokens += resp.usage.totalTokens || 0;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
if (filteredFiles) {
|
||||
logger.debug(`files in context : ${JSON.stringify(Object.keys(filteredFiles))}`);
|
||||
}
|
||||
|
||||
dataStream.writeMessageAnnotation({
|
||||
type: 'codeContext',
|
||||
files: Object.keys(filteredFiles).map((key) => {
|
||||
let path = key;
|
||||
|
||||
if (path.startsWith(WORK_DIR)) {
|
||||
path = path.replace(WORK_DIR, '');
|
||||
}
|
||||
|
||||
return path;
|
||||
}),
|
||||
} as ContextAnnotation);
|
||||
|
||||
dataStream.writeMessageAnnotation({
|
||||
type: 'progress',
|
||||
value: progressCounter++,
|
||||
message: 'Context Buffer Updated',
|
||||
} as ProgressAnnotation);
|
||||
logger.debug('Context Buffer Updated');
|
||||
}
|
||||
|
||||
if (finishReason !== 'length') {
|
||||
const encoder = new TextEncoder();
|
||||
const usageStream = createDataStream({
|
||||
async execute(dataStream) {
|
||||
// Stream the text
|
||||
const options: StreamingOptions = {
|
||||
toolChoice: 'none',
|
||||
onFinish: async ({ text: content, finishReason, usage }) => {
|
||||
logger.debug('usage', JSON.stringify(usage));
|
||||
|
||||
if (usage) {
|
||||
cumulativeUsage.completionTokens += usage.completionTokens || 0;
|
||||
cumulativeUsage.promptTokens += usage.promptTokens || 0;
|
||||
cumulativeUsage.totalTokens += usage.totalTokens || 0;
|
||||
}
|
||||
|
||||
if (finishReason !== 'length') {
|
||||
dataStream.writeMessageAnnotation({
|
||||
type: 'usage',
|
||||
value: {
|
||||
@@ -77,80 +179,89 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
|
||||
totalTokens: cumulativeUsage.totalTokens,
|
||||
},
|
||||
});
|
||||
},
|
||||
onError: (error: any) => `Custom error: ${error.message}`,
|
||||
}).pipeThrough(
|
||||
new TransformStream({
|
||||
transform: (chunk, controller) => {
|
||||
// Convert the string stream to a byte stream
|
||||
const str = typeof chunk === 'string' ? chunk : JSON.stringify(chunk);
|
||||
controller.enqueue(encoder.encode(str));
|
||||
},
|
||||
}),
|
||||
);
|
||||
await stream.switchSource(usageStream);
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
stream.close();
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
return;
|
||||
}
|
||||
// stream.close();
|
||||
return;
|
||||
}
|
||||
|
||||
if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
|
||||
throw Error('Cannot continue message: Maximum segments reached');
|
||||
}
|
||||
if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
|
||||
throw Error('Cannot continue message: Maximum segments reached');
|
||||
}
|
||||
|
||||
const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
|
||||
const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
|
||||
|
||||
logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);
|
||||
logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);
|
||||
|
||||
messages.push({ role: 'assistant', content });
|
||||
messages.push({ role: 'user', content: CONTINUE_PROMPT });
|
||||
messages.push({ id: generateId(), role: 'assistant', content });
|
||||
messages.push({ id: generateId(), role: 'user', content: CONTINUE_PROMPT });
|
||||
|
||||
const result = await streamText({
|
||||
messages,
|
||||
env: context.cloudflare?.env,
|
||||
options,
|
||||
apiKeys,
|
||||
files,
|
||||
providerSettings,
|
||||
promptId,
|
||||
contextOptimization,
|
||||
});
|
||||
|
||||
result.mergeIntoDataStream(dataStream);
|
||||
|
||||
(async () => {
|
||||
for await (const part of result.fullStream) {
|
||||
if (part.type === 'error') {
|
||||
const error: any = part.error;
|
||||
logger.error(`${error}`);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
const result = await streamText({
|
||||
messages,
|
||||
env: context.cloudflare.env,
|
||||
env: context.cloudflare?.env,
|
||||
options,
|
||||
apiKeys,
|
||||
files,
|
||||
providerSettings,
|
||||
promptId,
|
||||
contextOptimization,
|
||||
contextFiles: filteredFiles,
|
||||
summary,
|
||||
});
|
||||
|
||||
stream.switchSource(result.toDataStream());
|
||||
(async () => {
|
||||
for await (const part of result.fullStream) {
|
||||
if (part.type === 'error') {
|
||||
const error: any = part.error;
|
||||
logger.error(`${error}`);
|
||||
|
||||
return;
|
||||
return;
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
result.mergeIntoDataStream(dataStream);
|
||||
},
|
||||
};
|
||||
const totalMessageContent = messages.reduce((acc, message) => acc + message.content, '');
|
||||
logger.debug(`Total message length: ${totalMessageContent.split(' ').length}, words`);
|
||||
onError: (error: any) => `Custom error: ${error.message}`,
|
||||
}).pipeThrough(
|
||||
new TransformStream({
|
||||
transform: (chunk, controller) => {
|
||||
// Convert the string stream to a byte stream
|
||||
const str = typeof chunk === 'string' ? chunk : JSON.stringify(chunk);
|
||||
controller.enqueue(encoder.encode(str));
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await streamText({
|
||||
messages,
|
||||
env: context.cloudflare.env,
|
||||
options,
|
||||
apiKeys,
|
||||
files,
|
||||
providerSettings,
|
||||
promptId,
|
||||
contextOptimization,
|
||||
});
|
||||
|
||||
(async () => {
|
||||
for await (const part of result.fullStream) {
|
||||
if (part.type === 'error') {
|
||||
const error: any = part.error;
|
||||
logger.error(`${error}`);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
stream.switchSource(result.toDataStream());
|
||||
|
||||
// return createrespo
|
||||
return new Response(stream.readable, {
|
||||
return new Response(dataStream, {
|
||||
status: 200,
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream; charset=utf-8',
|
||||
|
||||
Reference in New Issue
Block a user