diff --git a/src/llm.ts b/src/llm.ts index 37c9691..e72d21d 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -1,7 +1,15 @@ +import { jsonrepair } from 'jsonrepair' +import Mustache from 'mustache' +import { dedent } from 'ts-dedent' import { ZodRawShape, ZodTypeAny, z } from 'zod' +import { printNode, zodToTs } from 'zod-to-ts' import * as types from './types' import { BaseTaskCallBuilder } from './task' +import { + extractJSONArrayFromString, + extractJSONObjectFromString +} from './utils' export abstract class BaseLLMCallBuilder< TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, @@ -59,10 +67,11 @@ export abstract class BaseLLMCallBuilder< // }): Promise } -export abstract class ChatModelBuilder< +export abstract class BaseChatModelBuilder< TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TOutput extends ZodRawShape | ZodTypeAny = z.ZodType, - TModelParams extends Record = Record + TModelParams extends Record = Record, + TChatCompletionResponse extends Record = Record > extends BaseLLMCallBuilder { _messages: types.ChatMessage[] @@ -71,4 +80,153 @@ export abstract class ChatModelBuilder< this._messages = options.messages } + + protected abstract _createChatCompletion( + messages: types.ChatMessage[] + ): Promise> + + override async call( + input?: types.ParsedData + ): Promise> { + if (this._inputSchema) { + const inputSchema = + this._inputSchema instanceof z.ZodType + ? this._inputSchema + : z.object(this._inputSchema) + + // TODO: handle errors gracefully + input = inputSchema.parse(input) + } + + // TODO: validate input message variables against input schema + + const messages = this._messages + .map((message) => { + return { + ...message, + content: message.content + ? Mustache.render(dedent(message.content), input).trim() + : '' + } + }) + .filter((message) => message.content) + + if (this._examples?.length) { + // TODO: smarter example selection + for (const example of this._examples) { + messages.push({ + role: 'system', + content: `Example input: ${example.input}\n\nExample output: ${example.output}` + }) + } + } + + if (this._outputSchema) { + const outputSchema = + this._outputSchema instanceof z.ZodType + ? this._outputSchema + : z.object(this._outputSchema) + + const { node } = zodToTs(outputSchema) + + if (node.kind === 152) { + // handle raw strings differently + messages.push({ + role: 'system', + content: dedent`Output a raw string only, without any additional text.` + }) + } else { + const tsTypeString = printNode(node, { + removeComments: false, + // TODO: this doesn't seem to actually work, so we're doing it manually below + omitTrailingSemicolon: true, + noEmitHelpers: true + }) + .replace(/^ /gm, ' ') + .replace(/;$/gm, '') + + messages.push({ + role: 'system', + content: dedent`Output JSON only in the following TypeScript format: + \`\`\`ts + ${tsTypeString} + \`\`\`` + }) + } + } + + // TODO: filter/compress messages based on token counts + + console.log('>>>') + console.log(messages) + + const completion = await this._createChatCompletion(messages) + let output: any = completion.message.content + + console.log('===') + console.log(output) + console.log('<<<') + + if (this._outputSchema) { + const outputSchema = + this._outputSchema instanceof z.ZodType + ? this._outputSchema + : z.object(this._outputSchema) + + if (outputSchema instanceof z.ZodArray) { + try { + const trimmedOutput = extractJSONArrayFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + } catch (err) { + // TODO + throw err + } + } else if (outputSchema instanceof z.ZodObject) { + try { + const trimmedOutput = extractJSONObjectFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + } catch (err) { + // TODO + throw err + } + } else if (outputSchema instanceof z.ZodBoolean) { + output = output.toLowerCase().trim() + const booleanOutputs = { + true: true, + false: false, + yes: true, + no: false, + 1: true, + 0: false + } + + const booleanOutput = booleanOutputs[output] + if (booleanOutput !== undefined) { + output = booleanOutput + } else { + // TODO + throw new Error(`invalid boolean output: ${output}`) + } + } else if (outputSchema instanceof z.ZodNumber) { + output = output.trim() + + const numberOutput = outputSchema.isInt + ? parseInt(output) + : parseFloat(output) + + if (isNaN(numberOutput)) { + // TODO + throw new Error(`invalid number output: ${output}`) + } else { + output = numberOutput + } + } + + // TODO: handle errors, retry logic, and self-healing + + return outputSchema.parse(output) + } else { + return output + } + } } diff --git a/src/openai.ts b/src/openai.ts index 9e30b5a..6fe0e86 100644 --- a/src/openai.ts +++ b/src/openai.ts @@ -1,25 +1,18 @@ -import { jsonrepair } from 'jsonrepair' -import Mustache from 'mustache' -import { dedent } from 'ts-dedent' import type { SetRequired } from 'type-fest' import { ZodRawShape, ZodTypeAny, z } from 'zod' -import { printNode, zodToTs } from 'zod-to-ts' import * as types from './types' import { defaultOpenAIModel } from './constants' -import { ChatModelBuilder } from './llm' -import { - extractJSONArrayFromString, - extractJSONObjectFromString -} from './utils' +import { BaseChatModelBuilder } from './llm' export class OpenAIChatModelBuilder< TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TOutput extends ZodRawShape | ZodTypeAny = z.ZodType -> extends ChatModelBuilder< +> extends BaseChatModelBuilder< TInput, TOutput, - SetRequired, 'model'> + SetRequired, 'model'>, + types.openai.ChatCompletionResponse > { _client: types.openai.OpenAIClient @@ -40,151 +33,20 @@ export class OpenAIChatModelBuilder< this._client = client } - override async call( - input?: types.ParsedData - ): Promise> { - if (this._inputSchema) { - const inputSchema = - this._inputSchema instanceof z.ZodType - ? this._inputSchema - : z.object(this._inputSchema) - - // TODO: handle errors gracefully - input = inputSchema.parse(input) - } - - // TODO: validate input message variables against input schema - - const messages = this._messages - .map((message) => { - return { - ...message, - content: message.content - ? Mustache.render(dedent(message.content), input).trim() - : '' - } - }) - .filter((message) => message.content) - - if (this._examples?.length) { - // TODO: smarter example selection - for (const example of this._examples) { - messages.push({ - role: 'system', - content: `Example input: ${example.input}\n\nExample output: ${example.output}` - }) - } - } - - if (this._outputSchema) { - const outputSchema = - this._outputSchema instanceof z.ZodType - ? this._outputSchema - : z.object(this._outputSchema) - - const { node } = zodToTs(outputSchema) - - if (node.kind === 152) { - // handle raw strings differently - messages.push({ - role: 'system', - content: dedent`Output a raw string only, without any additional text.` - }) - } else { - const tsTypeString = printNode(node, { - removeComments: false, - // TODO: this doesn't seem to actually work, so we're doing it manually below - omitTrailingSemicolon: true, - noEmitHelpers: true - }) - .replace(/^ /gm, ' ') - .replace(/;$/gm, '') - - messages.push({ - role: 'system', - content: dedent`Output JSON only in the following TypeScript format: - \`\`\`ts - ${tsTypeString} - \`\`\`` - }) - } - } - - // TODO: filter/compress messages based on token counts - - console.log('>>>') - console.log(messages) - const completion = await this._client.createChatCompletion({ - model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining - ...this._outputSchema, + protected override async _createChatCompletion( + messages: types.ChatMessage[] + ): Promise< + types.BaseChatCompletionResponse + > { + const response = await this._client.createChatCompletion({ + model: this._model, + ...this._modelParams, messages }) - if (this._outputSchema) { - const outputSchema = - this._outputSchema instanceof z.ZodType - ? this._outputSchema - : z.object(this._outputSchema) - - let output: any = completion.message.content - console.log('===') - console.log(output) - console.log('<<<') - - if (outputSchema instanceof z.ZodArray) { - try { - const trimmedOutput = extractJSONArrayFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) - } catch (err) { - // TODO - throw err - } - } else if (outputSchema instanceof z.ZodObject) { - try { - const trimmedOutput = extractJSONObjectFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) - } catch (err) { - // TODO - throw err - } - } else if (outputSchema instanceof z.ZodBoolean) { - output = output.toLowerCase().trim() - const booleanOutputs = { - true: true, - false: false, - yes: true, - no: false, - 1: true, - 0: false - } - - const booleanOutput = booleanOutputs[output] - if (booleanOutput !== undefined) { - output = booleanOutput - } else { - // TODO - throw new Error(`invalid boolean output: ${output}`) - } - } else if (outputSchema instanceof z.ZodNumber) { - output = output.trim() - - const numberOutput = outputSchema.isInt - ? parseInt(output) - : parseFloat(output) - - if (isNaN(numberOutput)) { - // TODO - throw new Error(`invalid number output: ${output}`) - } else { - output = numberOutput - } - } - - // TODO: handle errors, retry logic, and self-healing - - return outputSchema.parse(output) - } else { - return completion.message.content as any + return { + message: response.message, + response: response.response } } } diff --git a/src/types.ts b/src/types.ts index 54614f3..5b8c217 100644 --- a/src/types.ts +++ b/src/types.ts @@ -61,7 +61,13 @@ export interface LLMOptions< promptSuffix?: string } -export type ChatMessageRole = 'user' | 'system' | 'assistant' +// export type ChatMessageRole = 'user' | 'system' | 'assistant' +export const ChatMessageRoleSchema = z.union([ + z.literal('user'), + z.literal('system'), + z.literal('assistant') +]) +export type ChatMessageRole = z.infer export interface ChatMessage { role: ChatMessageRole @@ -76,6 +82,16 @@ export interface ChatModelOptions< messages: ChatMessage[] } +export interface BaseChatCompletionResponse< + TChatCompletionResponse extends Record = Record +> { + /** The completion message. */ + message: ChatMessage + + /** The raw response from the LLM provider. */ + response: TChatCompletionResponse +} + export interface LLMExample { input: string output: string