From 8b3ebb4515daf4ba978148bda1551f758635c6c1 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Thu, 25 May 2023 17:51:45 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm.ts | 33 +++++++++++++++++++++------------ src/temp.ts | 3 ++- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/llm.ts b/src/llm.ts index 1c7ebf4..1e9eda2 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -209,22 +209,31 @@ export class OpenAIChatModelBuilder< : z.object(this._options.output) const { node } = zodToTs(outputSchema) - const tsTypeString = printNode(node, { - removeComments: true, - // TODO: this doesn't seem to actually work, so we're doing it manually below - omitTrailingSemicolon: true, - noEmitHelpers: true - }) - .replace(/^ /gm, ' ') - .replace(/;$/gm, '') - messages.push({ - role: 'system', - content: dedent`Output JSON only in the following format: + if (node.kind === 152) { + // ignore raw strings + messages.push({ + role: 'system', + content: dedent`Output a string` + }) + } else { + const tsTypeString = printNode(node, { + removeComments: false, + // TODO: this doesn't seem to actually work, so we're doing it manually below + omitTrailingSemicolon: true, + noEmitHelpers: true + }) + .replace(/^ /gm, ' ') + .replace(/;$/gm, '') + + messages.push({ + role: 'system', + content: dedent`Output JSON only in the following format: \`\`\`ts ${tsTypeString} \`\`\`` - }) + }) + } } // TODO: filter/compress messages based on token counts diff --git a/src/temp.ts b/src/temp.ts index 4e257fa..632c22b 100644 --- a/src/temp.ts +++ b/src/temp.ts @@ -18,8 +18,9 @@ async function main() { // .call() // console.log(ex0) - const ex1 = await $.gpt4(`give me fake data conforming to this schema`) + const ex1 = await $.gpt4(`give me fake data`) .output(z.object({ foo: z.string(), bar: z.number() })) + // .output(z.string()) // .retry({ attempts: 3 }) .call() console.log(ex1)