diff --git a/package.json b/package.json index 4b86031..42ad8a5 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@ztimson/ai-utils", - "version": "0.8.3", + "version": "0.8.4", "description": "AI Utility library", "author": "Zak Timson", "license": "MIT", diff --git a/src/open-ai.ts b/src/open-ai.ts index e83fe53..5e1c4e9 100644 --- a/src/open-ai.ts +++ b/src/open-ai.ts @@ -103,15 +103,37 @@ export class OpenAi extends LLMProvider { if(options.stream) { if(!isFirstMessage) options.stream({text: '\n\n'}); else isFirstMessage = false; - resp.choices = [{message: {content: '', tool_calls: []}}]; + resp.choices = [{message: {role: 'assistant', content: '', tool_calls: []}}]; for await (const chunk of resp) { if(controller.signal.aborted) break; if(chunk.choices[0].delta.content) { resp.choices[0].message.content += chunk.choices[0].delta.content; options.stream({text: chunk.choices[0].delta.content}); } + if(chunk.choices[0].delta.tool_calls) { - resp.choices[0].message.tool_calls = chunk.choices[0].delta.tool_calls; + for(const deltaTC of chunk.choices[0].delta.tool_calls) { + const existing = resp.choices[0].message.tool_calls.find(tc => tc.index === deltaTC.index); + if(existing) { + if(deltaTC.id) existing.id = deltaTC.id; + if(deltaTC.type) existing.type = deltaTC.type; + if(deltaTC.function) { + if(!existing.function) existing.function = {}; + if(deltaTC.function.name) existing.function.name = deltaTC.function.name; + if(deltaTC.function.arguments) existing.function.arguments = (existing.function.arguments || '') + deltaTC.function.arguments; + } + } else { + resp.choices[0].message.tool_calls.push({ + index: deltaTC.index, + id: deltaTC.id || '', + type: deltaTC.type || 'function', + function: { + name: deltaTC.function?.name || '', + arguments: deltaTC.function?.arguments || '' + } + }); + } + } } } }