import { jsonrepair } from 'jsonrepair'
import Mustache from 'mustache'
import { dedent } from 'ts-dedent'
import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'

import * as types from './types'
import { BaseTaskCallBuilder } from './task'
import {
  extractJSONArrayFromString,
  extractJSONObjectFromString
} from './utils'

/**
 * Textual answers accepted for a boolean output schema, keyed by the
 * lower-cased, trimmed model output.
 *
 * A `Map` is used instead of an object literal so that strings such as
 * `"constructor"` cannot resolve to inherited `Object.prototype` members.
 */
const BOOLEAN_OUTPUTS: ReadonlyMap<string, boolean> = new Map([
  ['true', true],
  ['false', false],
  ['yes', true],
  ['no', false],
  ['1', true],
  ['0', false]
])

/**
 * Normalizes a schema that may be either a zod type or a raw zod shape into a
 * zod type, wrapping raw shapes with `z.object`.
 */
function toZodSchema(schema: ZodRawShape | ZodTypeAny): ZodTypeAny {
  return schema instanceof z.ZodType ? schema : z.object(schema)
}

/**
 * Builds the system-message text that tells the model which output format to
 * produce for `outputSchema`.
 *
 * @param outputSchema - Normalized zod schema describing the expected output.
 * @returns A "raw string only" instruction when the schema is a plain string,
 *   otherwise an instruction embedding the schema printed as a TypeScript type.
 */
function renderOutputInstruction(outputSchema: ZodTypeAny): string {
  const { node } = zodToTs(outputSchema)

  const tsTypeString = printNode(node, {
    removeComments: false,
    // TODO: this doesn't seem to actually work, so we're doing it manually below
    omitTrailingSemicolon: true,
    noEmitHelpers: true
  })
    // NOTE(review): as written this swaps one leading space for one space,
    // i.e. it is a no-op. Presumably the intended widths were lost (e.g.
    // re-indenting the printer output) — confirm before changing.
    .replace(/^ /gm, ' ')
    .replace(/;$/gm, '')

  // Detect "plain string" without comparing `node.kind` to a numeric literal:
  // numeric `SyntaxKind` values are renumbered between TypeScript releases, so
  // a hard-coded value silently stops matching after a compiler upgrade.
  const isRawString =
    outputSchema instanceof z.ZodString || tsTypeString.trim() === 'string'

  if (isRawString) {
    // handle raw strings differently
    return dedent`Output a raw string only, without any additional text.`
  }

  // The fences sit on their own lines so the type is a well-formed fenced block.
  return dedent`Output JSON only in the following TypeScript format:
    \`\`\`ts
    ${tsTypeString}
    \`\`\``
}

/**
 * Converts the raw text returned by the model into a value suitable for
 * validation against `outputSchema`.
 *
 * - Array / object schemas: the JSON array / object is extracted from the text
 *   (falling back to the whole text), repaired with `jsonrepair` and parsed.
 * - Boolean schemas: the text is lower-cased, trimmed and looked up in
 *   {@link BOOLEAN_OUTPUTS}.
 * - Number schemas: the text is trimmed and parsed as a base-10 integer when
 *   the schema is an int, otherwise as a float.
 * - Any other schema: the text is returned unchanged.
 *
 * @param outputSchema - Normalized zod schema describing the expected output.
 * @param raw - Raw completion text.
 * @returns The coerced value (not yet validated against the schema).
 * @throws Error if a boolean or number cannot be recognized; errors thrown by
 *   `jsonrepair` / `JSON.parse` propagate unchanged.
 */
function coerceOutput(outputSchema: ZodTypeAny, raw: string): unknown {
  // TODO: handle errors, retry logic, and self-healing
  if (outputSchema instanceof z.ZodArray) {
    return JSON.parse(jsonrepair(extractJSONArrayFromString(raw) ?? raw))
  }

  if (outputSchema instanceof z.ZodObject) {
    return JSON.parse(jsonrepair(extractJSONObjectFromString(raw) ?? raw))
  }

  if (outputSchema instanceof z.ZodBoolean) {
    const normalized = raw.toLowerCase().trim()
    const parsed = BOOLEAN_OUTPUTS.get(normalized)
    if (parsed === undefined) {
      throw new Error(`invalid boolean output: ${normalized}`)
    }
    return parsed
  }

  if (outputSchema instanceof z.ZodNumber) {
    const trimmed = raw.trim()
    // Explicit radix: model output is always interpreted as decimal.
    const parsed = outputSchema.isInt
      ? parseInt(trimmed, 10)
      : parseFloat(trimmed)
    if (Number.isNaN(parsed)) {
      throw new Error(`invalid number output: ${trimmed}`)
    }
    return parsed
  }

  return raw
}

/**
 * Base builder for LLM-backed calls. On top of the task builder it stores the
 * provider name, model name, model parameters and few-shot examples, and
 * exposes chainable setters for them.
 *
 * NOTE(review): the type arguments passed to `BaseTaskCallBuilder` and to the
 * `types.*` generics in this file are assumed from the type parameters visible
 * here — confirm their arity against `./task` and `./types`.
 */
export abstract class BaseLLMCallBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- `any` keeps interface-typed params assignable
  TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskCallBuilder<TInput, TOutput> {
  protected _provider: string
  protected _model: string
  protected _modelParams: TModelParams
  protected _examples: types.LLMExample[]

  /**
   * @param options - Builder options; `provider`, `model`, `modelParams` and
   *   `examples` are copied onto the instance, and the whole object is
   *   forwarded to the base constructor.
   */
  constructor(options: types.BaseLLMOptions<TInput, TOutput, TModelParams>) {
    super(options)

    this._provider = options.provider
    this._model = options.model
    this._modelParams = options.modelParams
    this._examples = options.examples
  }

  /**
   * Sets the input schema and returns this same instance, re-typed for the new
   * schema.
   */
  override input<U extends ZodRawShape | ZodTypeAny>(
    inputSchema: U
  ): BaseLLMCallBuilder<U, TOutput, TModelParams> {
    // The instance is reused; only its static type changes, hence the cast.
    const retyped = this as unknown as BaseLLMCallBuilder<
      U,
      TOutput,
      TModelParams
    >
    retyped._inputSchema = inputSchema
    return retyped
  }

  /**
   * Sets the output schema and returns this same instance, re-typed for the
   * new schema.
   */
  override output<U extends ZodRawShape | ZodTypeAny>(
    outputSchema: U
  ): BaseLLMCallBuilder<TInput, U, TModelParams> {
    // The instance is reused; only its static type changes, hence the cast.
    const retyped = this as unknown as BaseLLMCallBuilder<
      TInput,
      U,
      TModelParams
    >
    retyped._outputSchema = outputSchema
    return retyped
  }

  /** Replaces the stored examples and returns this builder for chaining. */
  examples(examples: types.LLMExample[]) {
    this._examples = examples
    return this
  }

  /**
   * Shallow-merges `params` over the current model parameters (later keys
   * win) and returns this builder for chaining.
   */
  modelParams(params: Partial<TModelParams>) {
    // We assume that modelParams does not include nested objects.
    // If it did, we would need to do a deep merge.
    this._modelParams = { ...this._modelParams, ...params }
    return this
  }

  // TODO
  // abstract stream({
  //   input: TInput,
  //   onProgress: types.ProgressFunction
  // }): Promise<TOutput>
}

/**
 * Base builder for chat-style models. Holds a list of message templates and
 * implements `call` on top of a provider-specific `_createChatCompletion`.
 */
export abstract class BaseChatModelBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- `any` keeps interface-typed params assignable
  TModelParams extends Record<string, any> = Record<string, any>,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- provider responses are opaque here
  TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
  _messages: types.ChatMessage[]

  /**
   * @param options - Chat model options; `messages` is stored on the instance
   *   and the whole object is forwarded to the base constructor.
   */
  constructor(options: types.ChatModelOptions<TInput, TOutput, TModelParams>) {
    super(options)

    this._messages = options.messages
  }

  /**
   * Sends the fully rendered messages to the underlying chat model.
   *
   * NOTE(review): the wrapper type name `types.BaseChatCompletionResponse` is
   * assumed; `call` only relies on the resolved value exposing
   * `message.content` — confirm against `./types`.
   */
  protected abstract _createChatCompletion(
    messages: types.ChatMessage[]
  ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>

  /**
   * Runs one chat completion.
   *
   * Steps: validate `input` against the input schema (if set); render each
   * message template with Mustache and drop messages whose content is empty;
   * append one system message per example; append a system message describing
   * the output format (if an output schema is set); request the completion;
   * coerce and validate the result against the output schema.
   *
   * @param input - Values used to render the message templates.
   * @returns The output parsed by the output schema, or the raw completion
   *   content when no output schema is set.
   * @throws Whatever the zod `parse` calls throw on invalid input/output, plus
   *   `Error` when a boolean or number output cannot be recognized.
   */
  override async call(
    input?: types.ParsedData<TInput>
  ): Promise<types.ParsedData<TOutput>> {
    if (this._inputSchema) {
      // TODO: handle errors gracefully
      input = toZodSchema(this._inputSchema).parse(input)
    }

    // TODO: validate input message variables against input schema
    const messages = this._messages
      .map((message) => ({
        ...message,
        content: message.content
          ? Mustache.render(dedent(message.content), input).trim()
          : ''
      }))
      .filter((message) => message.content)

    // TODO: smarter example selection
    for (const example of this._examples ?? []) {
      messages.push({
        role: 'system',
        content: `Example input: ${example.input}\n\nExample output: ${example.output}`
      })
    }

    // Normalize once; reused for both the prompt and the response parsing.
    const outputSchema = this._outputSchema
      ? toZodSchema(this._outputSchema)
      : undefined

    if (outputSchema) {
      messages.push({
        role: 'system',
        content: renderOutputInstruction(outputSchema)
      })
    }

    // TODO: filter/compress messages based on token counts

    // NOTE(review): debug logging; this prints the full prompt and the raw
    // completion to stdout on every call.
    console.log('>>>')
    console.log(messages)

    const completion = await this._createChatCompletion(messages)
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- content type is declared in ./types; presumably a string
    const rawOutput: any = completion.message.content

    console.log('===')
    console.log(rawOutput)
    console.log('<<<')

    if (!outputSchema) {
      return rawOutput
    }

    return outputSchema.parse(coerceOutput(outputSchema, rawOutput))
  }
}