diff --git a/src/llms/chat.ts b/src/llms/chat.ts index b9e1f72..5f36a95 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -21,6 +21,113 @@ import { getNumTokensForChatMessages } from './llm-utils' +const BOOLEAN_OUTPUTS = { + true: true, + false: false, + t: true, + f: false, + yes: true, + no: false, + y: true, + n: false, + '1': true, + '0': false +} + +function parseArrayOutput(output: string): Array { + try { + const trimmedOutput = extractJSONArrayFromString(output) + const parsedOutput = JSON.parse(jsonrepair(trimmedOutput ?? output)) + return parsedOutput + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON array: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } +} + +function parseObjectOutput(output) { + try { + const trimmedOutput = extractJSONObjectFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + + if (Array.isArray(output)) { + // TODO + output = output[0] + } + + return output + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON object: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } +} + +function parseBooleanOutput(output): boolean { + output = output + .toLowerCase() + .trim() + .replace(/[.!?]+$/, '') + + const booleanOutput = BOOLEAN_OUTPUTS[output] + + if (booleanOutput !== undefined) { + return booleanOutput + } else { + throw new errors.OutputValidationError(`Invalid boolean output: ${output}`) + } +} + +function parseNumberOutput(output, outputSchema: z.ZodNumber): number { + output = output.trim() + + const numberOutput = outputSchema.isInt + ? parseInt(output) + : parseFloat(output) + + if (isNaN(numberOutput)) { + throw new errors.OutputValidationError(`Invalid number output: ${output}`) + } + + return numberOutput +} + +function parseOutput(output: any, outputSchema: ZodType) { + if (outputSchema instanceof z.ZodArray) { + output = parseArrayOutput(output) + } else if (outputSchema instanceof z.ZodObject) { + output = parseObjectOutput(output) + } else if (outputSchema instanceof z.ZodBoolean) { + output = parseBooleanOutput(output) + } else if (outputSchema instanceof z.ZodNumber) { + output = parseNumberOutput(output, outputSchema) + } + + // TODO: fix typescript issue here with recursive types + const safeResult = (outputSchema.safeParse as any)(output) + + if (!safeResult.success) { + throw new errors.ZodOutputValidationError(safeResult.error) + } + + return safeResult.data +} + export abstract class BaseChatCompletion< TInput extends types.TaskInput = void, TOutput extends types.TaskOutput = string, @@ -141,47 +248,12 @@ export abstract class BaseChatCompletion< } } - if (this._outputSchema) { - // TODO: replace zod-to-ts with zod-to-json-schema? - const { node } = zodToTs(this._outputSchema) - - if (node.kind === 152) { - // handle raw strings differently - messages.push({ - role: 'system', - content: dedent`Output a raw string only, without any additional text.` - }) - } else { - const tsTypeString = printNode(node, { - removeComments: false, - // TODO: this doesn't seem to actually work, so we're doing it manually below - omitTrailingSemicolon: true, - noEmitHelpers: true - }) - .replace(/^ {4}/gm, ' ') - .replace(/;$/gm, '') - - const label = - this._outputSchema instanceof z.ZodArray - ? 'JSON array' - : this._outputSchema instanceof z.ZodObject - ? 'JSON object' - : this._outputSchema instanceof z.ZodNumber - ? 'number' - : this._outputSchema instanceof z.ZodString - ? 'string' - : this._outputSchema instanceof z.ZodBoolean - ? 'boolean' - : 'JSON value' - - messages.push({ - role: 'system', - content: dedent`Do not output code. Output a single ${label} in the following TypeScript format: - \`\`\`ts - ${tsTypeString} - \`\`\`` - }) - } + const outputMsg = this.outputMessage() + if (outputMsg !== null) { + messages.push({ + role: 'system', + content: outputMsg + }) } if (ctx?.retryMessage) { @@ -196,6 +268,50 @@ export abstract class BaseChatCompletion< return messages } + public outputMessage(): string | null { + const outputSchema = this._outputSchema + + if (!outputSchema) { + return null + } + + // TODO: replace zod-to-ts with zod-to-json-schema? + const { node } = zodToTs(outputSchema) + + if (node.kind === 152) { + // Handle raw strings differently: + return dedent`Output a raw string only, without any additional text.` + } + + const tsTypeString = printNode(node, { + removeComments: false, + // TODO: this doesn't seem to actually work, so we're doing it manually below + omitTrailingSemicolon: true, + noEmitHelpers: true + }) + .replace(/^ {4}/gm, ' ') + .replace(/;$/gm, '') + let label: string + if (outputSchema instanceof z.ZodArray) { + label = 'JSON array (minified)' + } else if (outputSchema instanceof z.ZodObject) { + label = 'JSON object (minified)' + } else if (outputSchema instanceof z.ZodNumber) { + label = 'number' + } else if (outputSchema instanceof z.ZodString) { + label = 'string' + } else if (outputSchema instanceof z.ZodBoolean) { + label = 'boolean' + } else { + label = 'JSON value' + } + + return dedent`Do not output code. Output a single ${label} in the following TypeScript format: + \`\`\`ts + ${tsTypeString} + \`\`\`` + } + protected override async _call( ctx: types.TaskCallContext ): Promise { @@ -356,89 +472,7 @@ export abstract class BaseChatCompletion< // console.log('<<<') if (this._outputSchema) { - const outputSchema = this._outputSchema - - if (outputSchema instanceof z.ZodArray) { - try { - const trimmedOutput = extractJSONArrayFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) - } catch (err: any) { - if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) - } else if (err instanceof SyntaxError) { - throw new errors.OutputValidationError( - `Invalid JSON array: ${err.message}`, - { cause: err } - ) - } else { - throw err - } - } - } else if (outputSchema instanceof z.ZodObject) { - try { - const trimmedOutput = extractJSONObjectFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) - - if (Array.isArray(output)) { - // TODO - output = output[0] - } - } catch (err: any) { - if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) - } else if (err instanceof SyntaxError) { - throw new errors.OutputValidationError( - `Invalid JSON object: ${err.message}`, - { cause: err } - ) - } else { - throw err - } - } - } else if (outputSchema instanceof z.ZodBoolean) { - output = output.toLowerCase().trim() - const booleanOutputs = { - true: true, - false: false, - yes: true, - no: false, - 1: true, - 0: false - } - - const booleanOutput = booleanOutputs[output] - - if (booleanOutput !== undefined) { - output = booleanOutput - } else { - throw new errors.OutputValidationError( - `Invalid boolean output: ${output}` - ) - } - } else if (outputSchema instanceof z.ZodNumber) { - output = output.trim() - - const numberOutput = outputSchema.isInt - ? parseInt(output) - : parseFloat(output) - - if (isNaN(numberOutput)) { - throw new errors.OutputValidationError( - `Invalid number output: ${output}` - ) - } else { - output = numberOutput - } - } - - // TODO: fix typescript issue here with recursive types - const safeResult = (outputSchema.safeParse as any)(output) - - if (!safeResult.success) { - throw new errors.ZodOutputValidationError(safeResult.error) - } - - return safeResult.data + return parseOutput(output, this._outputSchema) } else { return output }