diff --git a/src/llms/chat.ts b/src/llms/chat.ts index 7e6cb6e..0a75ac9 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -9,7 +9,11 @@ import * as types from '@/types' import { parseOutput } from '@/llms/parse-output' import { BaseTask } from '@/task' import { getCompiledTemplate } from '@/template' -import { extractFunctionIdentifierFromString, stringifyForModel } from '@/utils' +import { + extractFunctionIdentifierFromString, + isFunction, + stringifyForModel +} from '@/utils' import { BaseLLM } from './llm' import { @@ -23,7 +27,7 @@ export abstract class BaseChatCompletion< TModelParams extends Record = Record, TChatCompletionResponse extends Record = Record > extends BaseLLM { - protected _messages: types.ChatMessage[] + protected _messages: types.ChatMessageInput[] protected _tools?: BaseTask[] constructor( @@ -116,14 +120,17 @@ export abstract class BaseChatCompletion< input = this.inputSchema.parse(input) } - const messages = this._messages + const messages: types.ChatMessage[] = this._messages .map((message) => { return { ...message, content: message.content - ? getCompiledTemplate(dedent(message.content))(input).trim() + ? // support functions which return a string + isFunction(message.content) + ? message.content(input!) + : getCompiledTemplate(dedent(message.content))(input).trim() : '' - } + } as types.ChatMessage }) .filter((message) => message.content || message.function_call) @@ -170,6 +177,7 @@ export abstract class BaseChatCompletion< 'additionalProperties', '$schema' ]) + let label: string if (outputSchema instanceof z.ZodArray) { label = 'JSON array (minified)' diff --git a/src/llms/openai.ts b/src/llms/openai.ts index 09d6202..72f4a3d 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -14,7 +14,7 @@ const openaiModelsSupportingFunctions = new Set([ ]) export class OpenAIChatCompletion< - TInput extends types.TaskInput = any, + TInput extends types.TaskInput = void, TOutput extends types.TaskOutput = string > extends BaseChatCompletion< TInput, diff --git a/src/task.ts b/src/task.ts index 6252af2..02af003 100644 --- a/src/task.ts +++ b/src/task.ts @@ -156,7 +156,7 @@ export abstract class BaseTask< /** * Calls this task with the given `input` and returns the result only. */ - public async call(input?: TInput): Promise { + public async call(input: TInput): Promise { const res = await this.callWithMetadata(input) return res.result } @@ -165,7 +165,7 @@ export abstract class BaseTask< * Calls this task with the given `input` and returns the result along with metadata. */ public async callWithMetadata( - input?: TInput, + input: TInput, parentCtx?: types.TaskCallContext ): Promise> { this.validate() diff --git a/src/types.ts b/src/types.ts index 4b45700..9870c07 100644 --- a/src/types.ts +++ b/src/types.ts @@ -73,12 +73,18 @@ export interface LLMOptions< export type ChatMessage = openai.ChatMessage export type ChatMessageRole = openai.ChatMessageRole +export type ChatMessageInput = + | ChatMessage + | { + content: (input: TInput | any) => string + } + export interface ChatModelOptions< TInput extends TaskInput = void, TOutput extends TaskOutput = string, TModelParams extends Record = Record > extends BaseLLMOptions { - messages: ChatMessage[] + messages: ChatMessageInput[] tools?: BaseTask[] } @@ -136,7 +142,7 @@ export interface TaskCallContext< TInput extends TaskInput = void, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { - input?: TInput + input: TInput retryMessage?: string attemptNumber: number diff --git a/src/utils.ts b/src/utils.ts index 3e8a2c2..18c74c7 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -210,3 +210,8 @@ export function throttleKy( } }) } + +// eslint-disable-next-line @typescript-eslint/ban-types +export function isFunction(value: any): value is Function { + return typeof value === 'function' +} diff --git a/test/llms/openai.test.ts b/test/llms/openai.test.ts index 1b8da04..d2a015d 100644 --- a/test/llms/openai.test.ts +++ b/test/llms/openai.test.ts @@ -224,3 +224,29 @@ test('OpenAIChatCompletion ⇒ missing template variable', async (t) => { message: 'Template error: "numFacts" not defined in input - 1:10' }) }) + +test('OpenAIChatCompletion ⇒ function inputs', async (t) => { + t.timeout(2 * 60 * 1000) + const agentic = createTestAgenticRuntime() + + const builder = new OpenAIChatCompletion({ + agentic, + modelParams: { + temperature: 0, + max_tokens: 30 + }, + messages: [ + { + role: 'user', + content: (input) => `tell me about ${input.topic}` + } + ] + }) + .input(z.object({ topic: z.string() })) + .output(z.string()) + + const result = await builder.call({ topic: 'cats' }) + t.truthy(typeof result === 'string') + + expectTypeOf(result).toMatchTypeOf() +})