diff --git a/package.json b/package.json index 5287915..14912c4 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@ztimson/ai-utils", - "version": "0.1.21", + "version": "0.1.22", "description": "AI Utility library", "author": "Zak Timson", "license": "MIT", diff --git a/src/antrhopic.ts b/src/antrhopic.ts index 42117c2..8160dc4 100644 --- a/src/antrhopic.ts +++ b/src/antrhopic.ts @@ -73,12 +73,15 @@ export class Anthropic extends LLMProvider { stream: !!options.stream, }; - let resp: any; + let resp: any, isFirstMessage = true; const assistantMessages: string[] = []; do { resp = await this.client.messages.create(requestParams); + + // Streaming mode if(options.stream) { - if(assistantMessages.length) options.stream({text: '\n\n'}); + if(!isFirstMessage) options.stream({text: '\n\n'}); + else isFirstMessage = false; resp.content = []; for await (const chunk of resp) { if(controller.signal.aborted) break; @@ -105,8 +108,7 @@ export class Anthropic extends LLMProvider { } } - const textContent = resp.content.filter((c: any) => c.type == 'text').map((c: any) => c.text).join('\n\n'); - if(textContent) assistantMessages.push(textContent); + // Run tools const toolCalls = resp.content.filter((c: any) => c.type === 'tool_use'); if(toolCalls.length && !controller.signal.aborted) { history.push({role: 'assistant', content: resp.content}); @@ -122,12 +124,12 @@ export class Anthropic extends LLMProvider { } })); history.push({role: 'user', content: results}); - original.push({role: 'user', content: results}); requestParams.messages = history; } } while (!controller.signal.aborted && resp.content.some((c: any) => c.type === 'tool_use')); + if(options.stream) options.stream({done: true}); - res(this.toStandard([...original, {role: 'assistant', content: assistantMessages.join('\n\n'), timestamp: Date.now()}])); + res(this.toStandard([...history, {role: 'assistant', content: resp.content.filter((c: any) => c.type == 'text').map((c: any) => c.text).join('\n\n')}])); }); return Object.assign(response, {abort: () => controller.abort()}); diff --git a/src/ollama.ts b/src/ollama.ts index 6caeb2c..72e3edc 100644 --- a/src/ollama.ts +++ b/src/ollama.ts @@ -72,12 +72,12 @@ export class Ollama extends LLMProvider { })) } - let resp: any; - const loopMessages: any[] = []; + let resp: any, isFirstMessage = true; do { resp = await this.client.chat(requestParams); if(options.stream) { - if(loopMessages.length) options.stream({text: '\n\n'}); + if(!isFirstMessage) options.stream({text: '\n\n'}); + else isFirstMessage = false; resp.message = {role: 'assistant', content: '', tool_calls: []}; for await (const chunk of resp) { if(controller.signal.aborted) break; @@ -90,7 +90,6 @@ export class Ollama extends LLMProvider { } } - loopMessages.push({role: 'assistant', content: resp.message?.content, timestamp: Date.now()}); if(resp.message?.tool_calls?.length && !controller.signal.aborted) { history.push(resp.message); const results = await Promise.all(resp.message.tool_calls.map(async (toolCall: any) => { @@ -105,15 +104,12 @@ export class Ollama extends LLMProvider { } })); history.push(...results); - loopMessages.push(...results.map(r => ({...r, timestamp: Date.now()}))); requestParams.messages = history; } } while (!controller.signal.aborted && resp.message?.tool_calls?.length); - const combinedContent = loopMessages.filter(m => m.role === 'assistant') - .map(m => m.content).filter(c => c).join('\n\n'); if(options.stream) options.stream({done: true}); - res(this.toStandard([...history, {role: 'assistant', content: combinedContent, timestamp: Date.now()}])); + res(this.toStandard([...history, {role: 'assistant', content: resp.message?.content}])); }); return Object.assign(response, {abort: () => controller.abort()}); diff --git a/src/open-ai.ts b/src/open-ai.ts index 4862efc..55c6840 100644 --- a/src/open-ai.ts +++ b/src/open-ai.ts @@ -87,12 +87,12 @@ export class OpenAi extends LLMProvider { })) }; - let resp: any; - const loopMessages: any[] = []; + let resp: any, isFirstMessage = true; do { resp = await this.client.chat.completions.create(requestParams); if(options.stream) { - if(loopMessages.length) options.stream({text: '\n\n'}); + if(!isFirstMessage) options.stream({text: '\n\n'}); + else isFirstMessage = false; resp.choices = [{message: {content: '', tool_calls: []}}]; for await (const chunk of resp) { if(controller.signal.aborted) break; @@ -106,8 +106,6 @@ export class OpenAi extends LLMProvider { } } - loopMessages.push({role: 'assistant', content: resp.choices[0].message.content || '', timestamp: Date.now()}); - const toolCalls = resp.choices[0].message.tool_calls || []; if(toolCalls.length && !controller.signal.aborted) { history.push(resp.choices[0].message); @@ -123,15 +121,12 @@ export class OpenAi extends LLMProvider { } })); history.push(...results); - loopMessages.push(...results.map(r => ({...r, timestamp: Date.now()}))); requestParams.messages = history; } } while (!controller.signal.aborted && resp.choices?.[0]?.message?.tool_calls?.length); - const combinedContent = loopMessages.filter(m => m.role === 'assistant') - .map(m => m.content).filter(c => c).join('\n\n'); if(options.stream) options.stream({done: true}); - res(this.toStandard([...history, {role: 'assistant', content: combinedContent, timestamp: Date.now()}])); + res(this.toStandard([...history, {role: 'assistant', content: resp.choices[0].message.content || ''}])); }); return Object.assign(response, {abort: () => controller.abort()}); }