diff --git a/package.json b/package.json index 7356ac65..5ec80e1b 100644 --- a/package.json +++ b/package.json @@ -25,6 +25,7 @@ "dependencies": { "@dqbd/tiktoken": "^1.0.7", "dotenv-safe": "^8.2.0", + "jsonrepair": "^3.1.0", "mustache": "^4.2.0", "openai-fetch": "^1.2.1", "p-map": "^6.0.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index f422f76d..9a0d8d7e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -7,6 +7,9 @@ dependencies: dotenv-safe: specifier: ^8.2.0 version: 8.2.0 + jsonrepair: + specifier: ^3.1.0 + version: 3.1.0 mustache: specifier: ^4.2.0 version: 4.2.0 @@ -1461,6 +1464,11 @@ packages: engines: {node: ^14.17.0 || ^16.13.0 || >=18.0.0} dev: false + /jsonrepair@3.1.0: + resolution: {integrity: sha512-idqReg23J0PVRAADmZMc5xQM3xeOX5bTB6OTyMnzq33IXJXmn9iJuWIEvGmrN80rQf4d7uLTMEDwpzujNcI0Rg==} + hasBin: true + dev: false + /kind-of@6.0.3: resolution: {integrity: sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==} engines: {node: '>=0.10.0'} diff --git a/src/llm.ts b/src/llm.ts index bc6d13ba..eeb7baea 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -1,3 +1,4 @@ +import { jsonrepair } from 'jsonrepair' import Mustache from 'mustache' import { dedent } from 'ts-dedent' import type { SetRequired } from 'type-fest' @@ -5,6 +6,10 @@ import { ZodRawShape, ZodTypeAny, z } from 'zod' import { printNode, zodToTs } from 'zod-to-ts' import * as types from './types' +import { + extractJSONArrayFromString, + extractJSONObjectFromString +} from './utils' const defaultOpenAIModel = 'gpt-3.5-turbo' @@ -174,6 +179,8 @@ export class OpenAIChatModelBuilder< input = inputSchema.parse(input) } + // TODO: validate input message variables against input schema + const messages = this._messages .map((message) => { return { @@ -185,6 +192,24 @@ export class OpenAIChatModelBuilder< }) .filter((message) => message.content) + if (this._options.output) { + const outputSchema = + this._options.output instanceof z.ZodType + ? this._options.output + : z.object(this._options.output) + + const { node } = zodToTs(outputSchema) + const tsTypeString = printNode(node) + + messages.push({ + role: 'system', + content: dedent`Output JSON only in the following format: + \`\`\`ts + ${tsTypeString} + \`\`\`` + }) + } + // TODO: filter/compress messages based on token counts const completion = await this._client.createChatCompletion({ @@ -199,10 +224,28 @@ export class OpenAIChatModelBuilder< ? this._options.output : z.object(this._options.output) - // TODO: convert string => object if necessary + let output: any = completion.message.content + if (outputSchema instanceof z.ZodArray) { + try { + const trimmedOutput = extractJSONArrayFromString(output) + output = jsonrepair(trimmedOutput ?? output) + } catch (err) { + // TODO + throw err + } + } else if (outputSchema instanceof z.ZodObject) { + try { + const trimmedOutput = extractJSONObjectFromString(output) + output = jsonrepair(trimmedOutput ?? output) + } catch (err) { + // TODO + throw err + } + } + // TODO: handle errors, retry logic, and self-healing - return outputSchema.parse(completion.message.content) + return outputSchema.parse(output) } else { return completion.message.content as any } diff --git a/src/temp.ts b/src/temp.ts index 53ce2477..4e257fae 100644 --- a/src/temp.ts +++ b/src/temp.ts @@ -12,28 +12,26 @@ async function main() { const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! }) const $ = new Agentic({ openai }) - const ex0 = await $.gpt4(`give me a single boolean value`) - .output(z.boolean()) - // .retry({ attempts: 3 }) - .call() - - console.log(ex0) + // const ex0 = await $.gpt4(`give me a single boolean value`) + // .output(z.boolean()) + // // .retry({ attempts: 3 }) + // .call() + // console.log(ex0) const ex1 = await $.gpt4(`give me fake data conforming to this schema`) .output(z.object({ foo: z.string(), bar: z.number() })) // .retry({ attempts: 3 }) .call() - - const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`) - .input(z.object({ foo: z.string() })) - .output(z.boolean()) - - await Promise.all([ - getBoolean.call({ foo: 'foo' }), - getBoolean.call({ foo: 'bar' }) - ]) - console.log(ex1) + + // const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`) + // .input(z.object({ foo: z.string() })) + // .output(z.boolean()) + + // await Promise.all([ + // getBoolean.call({ foo: 'foo' }), + // getBoolean.call({ foo: 'bar' }) + // ]) } main() diff --git a/src/utils.ts b/src/utils.ts new file mode 100644 index 00000000..d407dfe5 --- /dev/null +++ b/src/utils.ts @@ -0,0 +1,5 @@ +export const extractJSONObjectFromString = (text: string): string | undefined => + text.match(/\{(.|\n)*\}/gm)?.[0] + +export const extractJSONArrayFromString = (text: string): string | undefined => + text.match(/\[(.|\n)*\]/gm)?.[0]