116 lines
3.6 KiB
TypeScript
116 lines
3.6 KiB
TypeScript
import { FunctionDefinition } from "openai/resources";
|
|
import {
|
|
ChatCompletionMessageParam
|
|
, ChatCompletionMessageToolCall
|
|
, ChatCompletionTool
|
|
} from "openai/resources/chat";
|
|
import { type FromSchema, type JSONSchema } from "json-schema-to-ts";
|
|
|
|
/**
 * JSON Schema describing the arguments of a function.
 * The root schema is constrained to `type: "object"`.
 */
type OpenAIFunctionRequestData = (JSONSchema & {
    type: "object"
});

/** Local alias for the SDK's `ChatCompletionTool` (a tool definition sent to the model). */
type ChatCompletionToolDefinition = ChatCompletionTool;
/** Local alias for the SDK's `ChatCompletionMessageToolCall` (a tool call received from the model). */
type ChatCompletionToolCall = ChatCompletionMessageToolCall;

/** Local alias for the SDK's `FunctionDefinition` (name, description and parameter schema). */
type ChatCompletionFunctionDefinition = FunctionDefinition;
|
|
|
|
/**
 * Represents a function that can be run by the OpenAI model.
 *
 * This interface is declaration-merged with the abstract class of the same
 * name: it supplies the data members (`name`, `description`, `parameters`)
 * that the class reads in `getSettings()`. Nothing in the class initializes
 * these members, so concrete subclasses have to provide them.
 *
 * @typeParam T - JSON Schema (object-typed) describing the function arguments.
 */
export interface OpenAIFunction<
    T extends Readonly<OpenAIFunctionRequestData> = Readonly<OpenAIFunctionRequestData>
> {
    /** Function name reported to the model through `getSettings()`. */
    name: string,
    /** Optional human-readable description reported through `getSettings()`. */
    description?: string,
    /** JSON Schema of the arguments; also drives the argument type of `execute()`. */
    parameters: T,
}

export abstract class OpenAIFunction<
    T extends Readonly<OpenAIFunctionRequestData> = Readonly<OpenAIFunctionRequestData>
> {
    /**
     * Builds the function definition (name, description, parameter schema)
     * from this instance's members.
     *
     * @returns The definition in the shape of the SDK's `FunctionDefinition`.
     */
    getSettings(): ChatCompletionFunctionDefinition {
        return {
            name: this.name,
            description: this.description,
            // The schema type is asserted to the plain record type used by the SDK definition.
            parameters: this.parameters as Record<string, unknown>,
        };
    }

    /**
     * Runs the function.
     *
     * @param data - Arguments, typed from the `parameters` schema via `FromSchema`.
     * @returns A promise resolving to the textual result of the call.
     */
    abstract execute(data: FromSchema<T>): Promise<string>;
}
|
|
|
|
/**
 * Manages the functions exposed to the OpenAI model: keeps a registry of
 * them, converts them to tool definitions and dispatches incoming tool calls.
 */
export default class FunctionManager {
    /** Registered functions, keyed by the name used to look them up on a tool call. */
    store = new Map<string, OpenAIFunction>();

    constructor() {
        // TODO: import functions from functions directory
    }

    /**
     * Lists every registered function as a `"function"` tool definition.
     *
     * @returns One tool definition per entry in `store` (empty array if none).
     */
    public getTools(): ChatCompletionToolDefinition[] {
        const tools: ChatCompletionToolDefinition[] = [];
        for (const registered of this.store.values()) {
            tools.push({type: "function", function: registered.getSettings()});
        }
        return tools;
    }

    /**
     * Same as `getTools()`, but yields `undefined` instead of an empty array.
     * NOTE(review): presumably so the `tools` request field can be omitted
     * when nothing is registered - confirm against the caller.
     */
    public getToolsForOpenAi(): ChatCompletionTool[] | undefined {
        const tools = this.getTools();
        return tools.length > 0 ? tools : undefined;
    }

    /**
     * Executes the registered function named by a function tool call.
     *
     * @param request - The tool call received from the model.
     * @returns A `tool` message carrying the call's `tool_call_id`; its content
     *          is the function result, or a notice when the function is not registered.
     * @throws Rejects with an `Error` (with `cause`) when the arguments are not
     *         valid JSON, and with whatever the function's `execute()` rejects with.
     */
    public async handleFunction(request: ChatCompletionToolCall): Promise<ChatCompletionMessageParam> {
        const functionToRun = this.store.get(request.function.name);

        // check if the function is registered
        if (!functionToRun) {
            // Answer as a `tool` message bound to this call's ID: a tool call
            // left without a matching tool response makes the follow-up
            // chat completion request invalid.
            return {
                role: "tool",
                tool_call_id: request.id,
                content: `Only use functions that were provided to you (response for tool call ID: ${request.id})`,
            };
        }

        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        let parsedArguments: any;
        try {
            // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
            parsedArguments = JSON.parse(request.function.arguments);
        }
        catch (e) {
            console.error("Function arguments raw: " + request.function.arguments);
            // Inside an async method this surfaces as a rejected promise.
            throw new Error(`Failed to parse the function JSON arguments when running function [${request.function.name}]`, {cause: e});
        }

        // FIXME: Verify if the parsedArguments matches the requested function argument declaration.
        // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
        const content = await functionToRun.execute(parsedArguments);
        return {
            role: "tool",
            tool_call_id: request.id,
            content: content,
        };
    }

    /**
     * Dispatches a single tool call according to its type.
     *
     * @param call - The tool call received from the model.
     * @returns The message produced by `handleFunction()` for `"function"` calls.
     * @throws Rejects with an `Error` for any other tool call type.
     */
    public async handleToolCall(call: ChatCompletionToolCall): Promise<ChatCompletionMessageParam> {
        if (call.type === "function") {
            return this.handleFunction(call);
        }
        // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
        throw new Error(`Unsupported tool call type: ${call.type || "never"}`);
    }

    /**
     * Handles a batch of tool calls concurrently.
     *
     * Calls whose type is not `"function"` are skipped and produce no message.
     *
     * @param calls - Tool calls received from the model.
     * @returns The response messages, in the order of the handled calls;
     *          rejects as soon as any single call fails.
     */
    public handleToolCalls(calls: ChatCompletionToolCall[]): Promise<ChatCompletionMessageParam[]> {
        const pending: Promise<ChatCompletionMessageParam>[] = [];
        for (const call of calls) {
            if (call.type === "function") {
                pending.push(this.handleToolCall(call));
            }
        }
        return Promise.all(pending);
    }
}
|