diff --git a/src/execution.ts b/src/execution.ts index 44e5710..d68f7c0 100644 --- a/src/execution.ts +++ b/src/execution.ts @@ -246,7 +246,7 @@ async function executeFromQueue(channel: string) { ...config.chatCompletionParams, messages: OpenAImessages, // FIXME: don't use new instance of FunctionManager - functions: new FunctionManager().getFunctionsForOpenAi(), + tools: new FunctionManager().getToolsForOpenAi(), }); logUsedTokens(answer, message, ++functionRanCounter); @@ -254,13 +254,11 @@ async function executeFromQueue(channel: string) { generatedMessage = answer.choices[0].message; if (!generatedMessage) throw new Error("Empty message received"); - // handle function calls - if (generatedMessage.function_call) { + // handle tool calls + if (generatedMessage.tool_calls !== undefined && generatedMessage.tool_calls.length > 0) { OpenAImessages.push(generatedMessage); // FIXME: don't use new instance of FunctionManager - OpenAImessages.push( - new FunctionManager().handleFunction(generatedMessage.function_call) - ); + OpenAImessages.push(...(await new FunctionManager().handleToolCalls(generatedMessage.tool_calls))); } } while (generatedMessage.function_call); diff --git a/src/funcitonManager.ts b/src/funcitonManager.ts index 11b1bac..6ab27d0 100644 --- a/src/funcitonManager.ts +++ b/src/funcitonManager.ts @@ -1,4 +1,9 @@ -import { ChatCompletionCreateParams, ChatCompletionMessage, ChatCompletionMessageParam } from "openai/resources/chat"; +import { FunctionDefinition } from "openai/resources"; +import { + ChatCompletionMessageParam +, ChatCompletionMessageToolCall +, ChatCompletionTool +} from "openai/resources/chat"; type parameterMap = { string: string, @@ -11,8 +16,10 @@ type OpenAIFunctionRequestData = { [name in keyof T]: T[name]; }; -type ChatCompletionFunctions = ChatCompletionCreateParams.Function; -type ChatCompletionFunctionCall = ChatCompletionMessage.FunctionCall; +type ChatCompletionToolDefinition = ChatCompletionTool; +type ChatCompletionToolCall = ChatCompletionMessageToolCall; + +type ChatCompletionFunctionDefinition = FunctionDefinition; /** * Represents the function that can be ran by the OpenAI model @@ -33,7 +40,7 @@ export interface OpenAIFunction { } export abstract class OpenAIFunction { - getSettings(): ChatCompletionFunctions { + getSettings(): ChatCompletionFunctionDefinition { return { name: this.name, description: this.description, @@ -41,7 +48,7 @@ export abstract class OpenAIFunction { }; } - abstract execute(data: OpenAIFunctionRequestData): string; + abstract execute(data: OpenAIFunctionRequestData): Promise; } /* @@ -54,47 +61,67 @@ export default class FunctionManager { // TODO: import functions from functions directory } - public getFunctions(): ChatCompletionFunctions[] { - const rvalue: ChatCompletionFunctions[] = []; + public getTools(): ChatCompletionToolDefinition[] { + const rvalue: ChatCompletionToolDefinition[] = []; for (const [, value] of this.store) { - rvalue.push(value.getSettings()); + rvalue.push({type: "function", function: value.getSettings()}); } return rvalue; } - public getFunctionsForOpenAi(): ChatCompletionFunctions[] | undefined { - const rvalue = this.getFunctions(); + public getToolsForOpenAi(): ChatCompletionTool[] | undefined { + const rvalue = this.getTools(); return rvalue.length > 0 ? rvalue : undefined; } - public handleFunction(request: ChatCompletionFunctionCall): ChatCompletionMessageParam { + public handleFunction(request: ChatCompletionToolCall): Promise { // eslint-disable-next-line @typescript-eslint/no-explicit-any let parsedArguments: any; - const functionToRun = this.store.get(request.name ?? ""); + const functionToRun = this.store.get(request.function.name); // check if the function is registered if (!functionToRun) { - return { + return Promise.resolve({ role: "system", - content: "Only use functions that were provided to you", - }; + content: `Only use functions that were provided to you (response for tool call ID: ${request.id})`, + }); } try { // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - parsedArguments = JSON.parse(request.arguments ?? ""); + parsedArguments = JSON.parse(request.function.arguments); } catch (e) { - console.error("Function arguments raw: " + request.arguments); - throw new Error(`Failed to parse the function JSON arguments when running function [${request.name}]`, {cause: e}); + console.error("Function arguments raw: " + request.function.arguments); + throw new Error(`Failed to parse the function JSON arguments when running function [${request.function.name}]`, {cause: e}); } // FIXME: Verify if the parsedArguments matches the requested function argument declaration. - return { - role: "function", - name: request.name, - // eslint-disable-next-line @typescript-eslint/no-unsafe-argument - content: functionToRun.execute(parsedArguments), - }; + // eslint-disable-next-line @typescript-eslint/no-unsafe-argument + return functionToRun.execute(parsedArguments).then(content => { + return { + role: "tool", + tool_call_id: request.id, + content: content, + }; + }); + } + + public handleToolCall(call: ChatCompletionToolCall): Promise { + if (call.type === "function") { + return this.handleFunction(call); + } + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + throw new Error(`Unsupported tool call type: ${call.type || "never"}`); + } + + public handleToolCalls(calls: ChatCompletionToolCall[]) { + const rvalue: Promise[] = []; + for (const call of calls) { + if (call.type === "function") { + rvalue.push(this.handleToolCall(call)); + } + } + return Promise.all(rvalue); } }