functionManager: retrofit for tool calls api

This commit is contained in:
Wroclaw 2024-04-23 17:46:24 +02:00
parent 91232e99a7
commit 0e5c8d22cc
2 changed files with 55 additions and 30 deletions

View file

@ -246,7 +246,7 @@ async function executeFromQueue(channel: string) {
...config.chatCompletionParams, ...config.chatCompletionParams,
messages: OpenAImessages, messages: OpenAImessages,
// FIXME: don't use new instance of FunctionManager // FIXME: don't use new instance of FunctionManager
functions: new FunctionManager().getFunctionsForOpenAi(), tools: new FunctionManager().getToolsForOpenAi(),
}); });
logUsedTokens(answer, message, ++functionRanCounter); logUsedTokens(answer, message, ++functionRanCounter);
@ -254,13 +254,11 @@ async function executeFromQueue(channel: string) {
generatedMessage = answer.choices[0].message; generatedMessage = answer.choices[0].message;
if (!generatedMessage) throw new Error("Empty message received"); if (!generatedMessage) throw new Error("Empty message received");
// handle function calls // handle tool calls
if (generatedMessage.function_call) { if (generatedMessage.tool_calls !== undefined && generatedMessage.tool_calls.length > 0) {
OpenAImessages.push(generatedMessage); OpenAImessages.push(generatedMessage);
// FIXME: don't use new instance of FunctionManager // FIXME: don't use new instance of FunctionManager
OpenAImessages.push( OpenAImessages.push(...(await new FunctionManager().handleToolCalls(generatedMessage.tool_calls)));
new FunctionManager().handleFunction(generatedMessage.function_call)
);
} }
} while (generatedMessage.function_call); } while (generatedMessage.function_call);

View file

@ -1,4 +1,9 @@
import { ChatCompletionCreateParams, ChatCompletionMessage, ChatCompletionMessageParam } from "openai/resources/chat"; import { FunctionDefinition } from "openai/resources";
import {
ChatCompletionMessageParam
, ChatCompletionMessageToolCall
, ChatCompletionTool
} from "openai/resources/chat";
type parameterMap = { type parameterMap = {
string: string, string: string,
@ -11,8 +16,10 @@ type OpenAIFunctionRequestData<T extends nameTypeMap> = {
[name in keyof T]: T[name]; [name in keyof T]: T[name];
}; };
type ChatCompletionFunctions = ChatCompletionCreateParams.Function; type ChatCompletionToolDefinition = ChatCompletionTool;
type ChatCompletionFunctionCall = ChatCompletionMessage.FunctionCall; type ChatCompletionToolCall = ChatCompletionMessageToolCall;
type ChatCompletionFunctionDefinition = FunctionDefinition;
/** /**
* Represents the function that can be ran by the OpenAI model * Represents the function that can be ran by the OpenAI model
@ -33,7 +40,7 @@ export interface OpenAIFunction<T extends nameTypeMap = nameTypeMap> {
} }
export abstract class OpenAIFunction<T extends nameTypeMap = nameTypeMap> { export abstract class OpenAIFunction<T extends nameTypeMap = nameTypeMap> {
getSettings(): ChatCompletionFunctions { getSettings(): ChatCompletionFunctionDefinition {
return { return {
name: this.name, name: this.name,
description: this.description, description: this.description,
@ -41,7 +48,7 @@ export abstract class OpenAIFunction<T extends nameTypeMap = nameTypeMap> {
}; };
} }
abstract execute(data: OpenAIFunctionRequestData<T>): string; abstract execute(data: OpenAIFunctionRequestData<T>): Promise<string>;
} }
/* /*
@ -54,47 +61,67 @@ export default class FunctionManager {
// TODO: import functions from functions directory // TODO: import functions from functions directory
} }
public getFunctions(): ChatCompletionFunctions[] { public getTools(): ChatCompletionToolDefinition[] {
const rvalue: ChatCompletionFunctions[] = []; const rvalue: ChatCompletionToolDefinition[] = [];
for (const [, value] of this.store) { for (const [, value] of this.store) {
rvalue.push(value.getSettings()); rvalue.push({type: "function", function: value.getSettings()});
} }
return rvalue; return rvalue;
} }
public getFunctionsForOpenAi(): ChatCompletionFunctions[] | undefined { public getToolsForOpenAi(): ChatCompletionTool[] | undefined {
const rvalue = this.getFunctions(); const rvalue = this.getTools();
return rvalue.length > 0 ? rvalue : undefined; return rvalue.length > 0 ? rvalue : undefined;
} }
public handleFunction(request: ChatCompletionFunctionCall): ChatCompletionMessageParam { public handleFunction(request: ChatCompletionToolCall): Promise<ChatCompletionMessageParam> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
let parsedArguments: any; let parsedArguments: any;
const functionToRun = this.store.get(request.name ?? ""); const functionToRun = this.store.get(request.function.name);
// check if the function is registered // check if the function is registered
if (!functionToRun) { if (!functionToRun) {
return { return Promise.resolve({
role: "system", role: "system",
content: "Only use functions that were provided to you", content: `Only use functions that were provided to you (response for tool call ID: ${request.id})`,
}; });
} }
try { try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
parsedArguments = JSON.parse(request.arguments ?? ""); parsedArguments = JSON.parse(request.function.arguments);
} }
catch (e) { catch (e) {
console.error("Function arguments raw: " + request.arguments); console.error("Function arguments raw: " + request.function.arguments);
throw new Error(`Failed to parse the function JSON arguments when running function [${request.name}]`, {cause: e}); throw new Error(`Failed to parse the function JSON arguments when running function [${request.function.name}]`, {cause: e});
} }
// FIXME: Verify if the parsedArguments matches the requested function argument declaration. // FIXME: Verify if the parsedArguments matches the requested function argument declaration.
return {
role: "function",
name: request.name,
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
content: functionToRun.execute(parsedArguments), return functionToRun.execute(parsedArguments).then(content => {
return {
role: "tool",
tool_call_id: request.id,
content: content,
}; };
});
}
public handleToolCall(call: ChatCompletionToolCall): Promise<ChatCompletionMessageParam> {
if (call.type === "function") {
return this.handleFunction(call);
}
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw new Error(`Unsupported tool call type: ${call.type || "never"}`);
}
public handleToolCalls(calls: ChatCompletionToolCall[]) {
const rvalue: Promise<ChatCompletionMessageParam>[] = [];
for (const call of calls) {
if (call.type === "function") {
rvalue.push(this.handleToolCall(call));
}
}
return Promise.all(rvalue);
} }
} }