diff --git a/legacy/src/llms/chat.ts b/legacy/src/llms/chat.ts index 886380d4..7e6cb6e0 100644 --- a/legacy/src/llms/chat.ts +++ b/legacy/src/llms/chat.ts @@ -6,14 +6,10 @@ import { zodToJsonSchema } from 'zod-to-json-schema' import * as errors from '@/errors' import * as types from '@/types' +import { parseOutput } from '@/llms/parse-output' import { BaseTask } from '@/task' import { getCompiledTemplate } from '@/template' -import { - extractFunctionIdentifierFromString, - extractJSONArrayFromString, - extractJSONObjectFromString, - stringifyForModel -} from '@/utils' +import { extractFunctionIdentifierFromString, stringifyForModel } from '@/utils' import { BaseLLM } from './llm' import { @@ -21,113 +17,6 @@ import { getNumTokensForChatMessages } from './llm-utils' -const BOOLEAN_OUTPUTS = { - true: true, - false: false, - t: true, - f: false, - yes: true, - no: false, - y: true, - n: false, - '1': true, - '0': false -} - -function parseArrayOutput(output: string): Array { - try { - const trimmedOutput = extractJSONArrayFromString(output) - const parsedOutput = JSON.parse(jsonrepair(trimmedOutput ?? output)) - return parsedOutput - } catch (err: any) { - if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) - } else if (err instanceof SyntaxError) { - throw new errors.OutputValidationError( - `Invalid JSON array: ${err.message}`, - { cause: err } - ) - } else { - throw err - } - } -} - -function parseObjectOutput(output) { - try { - const trimmedOutput = extractJSONObjectFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) - - if (Array.isArray(output)) { - // TODO - output = output[0] - } - - return output - } catch (err: any) { - if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) - } else if (err instanceof SyntaxError) { - throw new errors.OutputValidationError( - `Invalid JSON object: ${err.message}`, - { cause: err } - ) - } else { - throw err - } - } -} - -function parseBooleanOutput(output): boolean { - output = output - .toLowerCase() - .trim() - .replace(/[.!?]+$/, '') - - const booleanOutput = BOOLEAN_OUTPUTS[output] - - if (booleanOutput !== undefined) { - return booleanOutput - } else { - throw new errors.OutputValidationError(`Invalid boolean output: ${output}`) - } -} - -function parseNumberOutput(output, outputSchema: z.ZodNumber): number { - output = output.trim() - - const numberOutput = outputSchema.isInt - ? parseInt(output) - : parseFloat(output) - - if (isNaN(numberOutput)) { - throw new errors.OutputValidationError(`Invalid number output: ${output}`) - } - - return numberOutput -} - -function parseOutput(output: any, outputSchema: ZodType) { - if (outputSchema instanceof z.ZodArray) { - output = parseArrayOutput(output) - } else if (outputSchema instanceof z.ZodObject) { - output = parseObjectOutput(output) - } else if (outputSchema instanceof z.ZodBoolean) { - output = parseBooleanOutput(output) - } else if (outputSchema instanceof z.ZodNumber) { - output = parseNumberOutput(output, outputSchema) - } - - // TODO: fix typescript issue here with recursive types - const safeResult = (outputSchema.safeParse as any)(output) - - if (!safeResult.success) { - throw new errors.ZodOutputValidationError(safeResult.error) - } - - return safeResult.data -} - export abstract class BaseChatCompletion< TInput extends types.TaskInput = void, TOutput extends types.TaskOutput = string, @@ -462,7 +351,7 @@ export abstract class BaseChatCompletion< // console.log('<<<') if (this._outputSchema) { - return parseOutput(output, this._outputSchema) + return parseOutput(output as string, this._outputSchema) } else { return output } diff --git a/legacy/src/llms/parse-output.ts b/legacy/src/llms/parse-output.ts new file mode 100644 index 00000000..8fbda33c --- /dev/null +++ b/legacy/src/llms/parse-output.ts @@ -0,0 +1,161 @@ +import { JSONRepairError, jsonrepair } from 'jsonrepair' +import { ZodType, z } from 'zod' + +import * as errors from '@/errors' +import { + extractJSONArrayFromString, + extractJSONObjectFromString +} from '@/utils' + +const BOOLEAN_OUTPUTS = { + true: true, + false: false, + t: true, + f: false, + yes: true, + no: false, + y: true, + n: false, + '1': true, + '0': false +} + +/** + * Parses an array output from a string. + * + * @param output - string to parse + * @returns parsed array + */ +export function parseArrayOutput(output: string): Array { + try { + const trimmedOutput = extractJSONArrayFromString(output) + const parsedOutput = JSON.parse(jsonrepair(trimmedOutput ?? output)) + if (!Array.isArray(parsedOutput)) { + throw new errors.OutputValidationError( + `Invalid JSON array: ${JSON.stringify(parsedOutput)}` + ) + } + + return parsedOutput + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON array: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } +} + +/** + * Parses an object output from a string. + * + * @param output - string to parse + * @returns parsed object + */ +export function parseObjectOutput(output: string) { + try { + const trimmedOutput = extractJSONObjectFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + + if (Array.isArray(output)) { + // TODO + output = output[0] + } else if (typeof output !== 'object') { + throw new errors.OutputValidationError( + `Invalid JSON object: ${JSON.stringify(output)}` + ) + } + + return output + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON object: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } +} + +/** + * Parses a boolean output from a string. + * + * @param output - string to parse + * @returns parsed boolean + */ +export function parseBooleanOutput(output: string): boolean { + output = output + .toLowerCase() + .trim() + .replace(/[.!?]+$/, '') + + const booleanOutput = BOOLEAN_OUTPUTS[output] + + if (booleanOutput !== undefined) { + return booleanOutput + } else { + throw new errors.OutputValidationError(`Invalid boolean output: ${output}`) + } +} + +/** + * Parses a number output from a string. + * + * @param output - string to parse + * @param outputSchema - zod number schema + * @returns parsed number + */ +export function parseNumberOutput( + output: string, + outputSchema: z.ZodNumber +): number { + output = output.trim() + + const numberOutput = outputSchema.isInt + ? parseInt(output) + : parseFloat(output) + + if (isNaN(numberOutput)) { + throw new errors.OutputValidationError(`Invalid number output: ${output}`) + } + + return numberOutput +} + +/** + * Parses an output value from a string. + * + * @param output - string to parse + * @param outputSchema - zod schema + * @returns parsed output + */ +export function parseOutput(output: string, outputSchema: ZodType) { + let result + if (outputSchema instanceof z.ZodArray) { + result = parseArrayOutput(output) + } else if (outputSchema instanceof z.ZodObject) { + result = parseObjectOutput(output) + } else if (outputSchema instanceof z.ZodBoolean) { + result = parseBooleanOutput(output) + } else if (outputSchema instanceof z.ZodNumber) { + result = parseNumberOutput(output, outputSchema) + } + + // TODO: fix typescript issue here with recursive types + const safeResult = (outputSchema.safeParse as any)(result) + + if (!safeResult.success) { + throw new errors.ZodOutputValidationError(safeResult.error) + } + + return safeResult.data +} diff --git a/legacy/test/.snapshots/test/llms/parse-output.ts.md b/legacy/test/.snapshots/test/llms/parse-output.ts.md new file mode 100644 index 00000000..54a5033e --- /dev/null +++ b/legacy/test/.snapshots/test/llms/parse-output.ts.md @@ -0,0 +1,226 @@ +# Snapshot report for `test/llms/parse-output.ts` + +The actual snapshot is saved in `parse-output.ts.snap`. + +Generated by [AVA](https://avajs.dev). + +## parseArrayOutput - handles valid arrays correctly + +> should return [1, 2, 3] for "[1,2,3]" + + [ + 1, + 2, + 3, + ] + +> should return ["a", "b", "c"] for "["a", "b", "c"] + + [ + 'a', + 'b', + 'c', + ] + +> should return [{"a": 1}, {"b": 2}] for [{"a": 1}, {"b": 2}] + + [ + { + a: 1, + }, + { + b: 2, + }, + ] + +## parseArrayOutput - handles arrays surrounded by text correctly + +> should return [1, 2, 3] for "The array is [1,2,3]" + + [ + 1, + 2, + 3, + ] + +> should return ["a", "b", "c"] for "Array: ["a", "b", "c"]. That's all!" + + [ + 'a', + 'b', + 'c', + ] + +> should return [{"a": 1}, {"b": 2}] for "This is the array [{"a": 1}, {"b": 2}] in the text" + + [ + { + a: 1, + }, + { + b: 2, + }, + ] + +## parseArrayOutput - handles and repairs broken JSON arrays correctly + +> should repair and return [1, "two", 3] for [1, "two, 3] + + [ + 1, + 'two, 3]', + ] + +> should repair and return ["a", "b", "c"] for Array: ["a, "b", "c"]. Error here! + + [ + 'a, ', + 'b', + ', ', + 'c', + ']', + ] + +> should repair and return {"arr": ["value1", "value2"]} for Array in text {"arr": ["value1, "value2"]} + + [ + 'value1, ', + 'value2', + ']', + ] + +## parseArrayOutput - throws error for invalid arrays + +> Snapshot 1 + + 'Invalid JSON array: "not a valid array"' + +## parseObjectOutput - handles valid objects correctly + +> should return {"a":1,"b":2,"c":3} for {"a":1,"b":2,"c":3} + + { + a: 1, + b: 2, + c: 3, + } + +> should return {"name":"John","age":30,"city":"New York"} for {"name":"John","age":30,"city":"New York"} + + { + age: 30, + city: 'New York', + name: 'John', + } + +## parseObjectOutput - handles objects surrounded by text correctly + +> should return {"a":1,"b":2,"c":3} for "The object is {"a":1,"b":2,"c":3}" + + { + a: 1, + b: 2, + c: 3, + } + +> should return {"name":"John","age":30,"city":"New York"} for "Object: {"name":"John","age":30,"city":"New York"}. That's all!" + + { + age: 30, + city: 'New York', + name: 'John', + } + +## parseObjectOutput - handles and repairs broken JSON objects correctly + +> should repair and return {"a":1, "b":2, "c":3} for {"a":1, "b":2, "c":3 + + { + a: 1, + b: 2, + c: 3, + } + +> should repair and return {"name":"John","age":30,"city":"New York"} for Object: {"name":"John,"age":30,"city":"New York"}. Error here! + + { + 'New York': '}', + age: ':30,', + city: ':', + name: 'John,', + } + +## parseObjectOutput - handles JSON array of objects + +> should return first object {"a":1,"b":2} for [{"a":1,"b":2},{"c":3,"d":4}] + + { + a: 1, + b: 2, + } + +## parseObjectOutput - throws error for invalid objects + +> Snapshot 1 + + 'Invalid JSON object: "not a valid object"' + +## parseBooleanOutput - handles `true` outputs correctly + +> should return true for "True" + + true + +> should return true for "TRUE" + + true + +> should return true for "true." + + true + +## parseBooleanOutput - handles `false` outputs correctly + +> should return false for "False" + + false + +> should return false for "FALSE" + + false + +> should return false for "false!" + + false + +## parseBooleanOutput - throws error for invalid outputs + +> Snapshot 1 + + 'Invalid boolean output: notbooleanvalue' + +## parseNumberOutput - handles integer outputs correctly + +> should return 42 for "42" + + 42 + +> should return -5 for " -5 " + + -5 + +## parseNumberOutput - handles float outputs correctly + +> should return 42.42 for "42.42" + + 42.42 + +> should return -5.5 for " -5.5 " + + -5.5 + +## parseNumberOutput - throws error for invalid outputs + +> Snapshot 1 + + 'Invalid number output: NotANumber' diff --git a/legacy/test/.snapshots/test/llms/parse-output.ts.snap b/legacy/test/.snapshots/test/llms/parse-output.ts.snap new file mode 100644 index 00000000..67e98c66 Binary files /dev/null and b/legacy/test/.snapshots/test/llms/parse-output.ts.snap differ diff --git a/legacy/test/llms/parse-output.ts b/legacy/test/llms/parse-output.ts new file mode 100644 index 00000000..94ee932c --- /dev/null +++ b/legacy/test/llms/parse-output.ts @@ -0,0 +1,193 @@ +import test from 'ava' +import { z } from 'zod' + +import { + parseArrayOutput, + parseBooleanOutput, + parseNumberOutput, + parseObjectOutput +} from '@/llms/parse-output' + +test('parseArrayOutput - handles valid arrays correctly', (t) => { + const output1 = parseArrayOutput('[1,2,3]') + const output2 = parseArrayOutput('["a", "b", "c"]') + const output3 = parseArrayOutput('[{"a": 1}, {"b": 2}]') + + t.snapshot(output1, 'should return [1, 2, 3] for "[1,2,3]"') + t.snapshot(output2, 'should return ["a", "b", "c"] for "["a", "b", "c"]') + t.snapshot( + output3, + 'should return [{"a": 1}, {"b": 2}] for [{"a": 1}, {"b": 2}]' + ) +}) + +test('parseArrayOutput - handles arrays surrounded by text correctly', (t) => { + const output1 = parseArrayOutput('The array is [1,2,3]') + const output2 = parseArrayOutput('Array: ["a", "b", "c"]. That\'s all!') + const output3 = parseArrayOutput( + 'This is the array [{"a": 1}, {"b": 2}] in the text' + ) + + t.snapshot(output1, 'should return [1, 2, 3] for "The array is [1,2,3]"') + t.snapshot( + output2, + 'should return ["a", "b", "c"] for "Array: ["a", "b", "c"]. That\'s all!"' + ) + t.snapshot( + output3, + 'should return [{"a": 1}, {"b": 2}] for "This is the array [{"a": 1}, {"b": 2}] in the text"' + ) +}) + +test('parseArrayOutput - handles and repairs broken JSON arrays correctly', (t) => { + const output1 = parseArrayOutput('[1, "two, 3]') + const output2 = parseArrayOutput('Array: ["a, "b", "c"]. Error here!') + const output3 = parseArrayOutput('Array in text {"arr": ["value1, "value2"]}') + + t.snapshot(output1, 'should repair and return [1, "two", 3] for [1, "two, 3]') + t.snapshot( + output2, + 'should repair and return ["a", "b", "c"] for Array: ["a, "b", "c"]. Error here!' + ) + t.snapshot( + output3, + 'should repair and return {"arr": ["value1", "value2"]} for Array in text {"arr": ["value1, "value2"]}' + ) +}) + +test('parseArrayOutput - throws error for invalid arrays', (t) => { + const error = t.throws( + () => { + parseArrayOutput('not a valid array') + }, + { instanceOf: Error } + ) + + t.snapshot(error?.message) +}) + +test('parseObjectOutput - handles valid objects correctly', (t) => { + const output1 = parseObjectOutput('{"a":1,"b":2,"c":3}') + const output2 = parseObjectOutput( + '{"name":"John","age":30,"city":"New York"}' + ) + + t.snapshot( + output1, + 'should return {"a":1,"b":2,"c":3} for {"a":1,"b":2,"c":3}' + ) + t.snapshot( + output2, + 'should return {"name":"John","age":30,"city":"New York"} for {"name":"John","age":30,"city":"New York"}' + ) +}) + +test('parseObjectOutput - handles objects surrounded by text correctly', (t) => { + const output1 = parseObjectOutput('The object is {"a":1,"b":2,"c":3}') + const output2 = parseObjectOutput( + 'Object: {"name":"John","age":30,"city":"New York"}. That\'s all!' + ) + + t.snapshot( + output1, + 'should return {"a":1,"b":2,"c":3} for "The object is {"a":1,"b":2,"c":3}"' + ) + t.snapshot( + output2, + 'should return {"name":"John","age":30,"city":"New York"} for "Object: {"name":"John","age":30,"city":"New York"}. That\'s all!"' + ) +}) + +test('parseObjectOutput - handles and repairs broken JSON objects correctly', (t) => { + const output1 = parseObjectOutput('{"a":1, "b":2, "c":3') + const output2 = parseObjectOutput( + 'Object: {"name":"John,"age":30,"city":"New York"}. Error here!' + ) + + t.snapshot( + output1, + 'should repair and return {"a":1, "b":2, "c":3} for {"a":1, "b":2, "c":3' + ) + t.snapshot( + output2, + 'should repair and return {"name":"John","age":30,"city":"New York"} for Object: {"name":"John,"age":30,"city":"New York"}. Error here!' + ) +}) + +test('parseObjectOutput - handles JSON array of objects', (t) => { + const output = parseObjectOutput('[{"a":1,"b":2},{"c":3,"d":4}]') + + t.snapshot( + output, + 'should return first object {"a":1,"b":2} for [{"a":1,"b":2},{"c":3,"d":4}]' + ) +}) + +test('parseObjectOutput - throws error for invalid objects', (t) => { + const error = t.throws( + () => { + parseObjectOutput('not a valid object') + }, + { instanceOf: Error } + ) + + t.snapshot(error?.message) +}) + +test('parseBooleanOutput - handles `true` outputs correctly', (t) => { + const output1 = parseBooleanOutput('True') + const output2 = parseBooleanOutput('TRUE') + const output3 = parseBooleanOutput('true.') + + t.snapshot(output1, 'should return true for "True"') + t.snapshot(output2, 'should return true for "TRUE"') + t.snapshot(output3, 'should return true for "true."') +}) + +test('parseBooleanOutput - handles `false` outputs correctly', (t) => { + const output1 = parseBooleanOutput('False') + const output2 = parseBooleanOutput('FALSE') + const output3 = parseBooleanOutput('false!') + + t.snapshot(output1, 'should return false for "False"') + t.snapshot(output2, 'should return false for "FALSE"') + t.snapshot(output3, 'should return false for "false!"') +}) + +test('parseBooleanOutput - throws error for invalid outputs', (t) => { + const error = t.throws( + () => { + parseBooleanOutput('NotBooleanValue') + }, + { instanceOf: Error } + ) + + t.snapshot(error?.message) +}) + +test('parseNumberOutput - handles integer outputs correctly', (t) => { + const output1 = parseNumberOutput('42', z.number().int()) + const output2 = parseNumberOutput(' -5 ', z.number().int()) + + t.snapshot(output1, 'should return 42 for "42"') + t.snapshot(output2, 'should return -5 for " -5 "') +}) + +test('parseNumberOutput - handles float outputs correctly', (t) => { + const output1 = parseNumberOutput('42.42', z.number()) + const output2 = parseNumberOutput(' -5.5 ', z.number()) + + t.snapshot(output1, 'should return 42.42 for "42.42"') + t.snapshot(output2, 'should return -5.5 for " -5.5 "') +}) + +test('parseNumberOutput - throws error for invalid outputs', (t) => { + const error = t.throws( + () => { + parseNumberOutput('NotANumber', z.number()) + }, + { instanceOf: Error } + ) + + t.snapshot(error?.message) +})