Travis Fischer 2023-05-23 23:15:59 -07:00
rodzic 2afaf4d9ad
commit c4a5006f26
5 zmienionych plików z 73 dodań i 18 usunięć

Wyświetl plik

@ -25,6 +25,7 @@
"dependencies": {
"@dqbd/tiktoken": "^1.0.7",
"dotenv-safe": "^8.2.0",
"jsonrepair": "^3.1.0",
"mustache": "^4.2.0",
"openai-fetch": "^1.2.1",
"p-map": "^6.0.0",

Wyświetl plik

@ -7,6 +7,9 @@ dependencies:
dotenv-safe:
specifier: ^8.2.0
version: 8.2.0
jsonrepair:
specifier: ^3.1.0
version: 3.1.0
mustache:
specifier: ^4.2.0
version: 4.2.0
@ -1461,6 +1464,11 @@ packages:
engines: {node: ^14.17.0 || ^16.13.0 || >=18.0.0}
dev: false
/jsonrepair@3.1.0:
resolution: {integrity: sha512-idqReg23J0PVRAADmZMc5xQM3xeOX5bTB6OTyMnzq33IXJXmn9iJuWIEvGmrN80rQf4d7uLTMEDwpzujNcI0Rg==}
hasBin: true
dev: false
/kind-of@6.0.3:
resolution: {integrity: sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==}
engines: {node: '>=0.10.0'}

Wyświetl plik

@ -1,3 +1,4 @@
import { jsonrepair } from 'jsonrepair'
import Mustache from 'mustache'
import { dedent } from 'ts-dedent'
import type { SetRequired } from 'type-fest'
@ -5,6 +6,10 @@ import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as types from './types'
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from './utils'
const defaultOpenAIModel = 'gpt-3.5-turbo'
@ -174,6 +179,8 @@ export class OpenAIChatModelBuilder<
input = inputSchema.parse(input)
}
// TODO: validate input message variables against input schema
const messages = this._messages
.map((message) => {
return {
@ -185,6 +192,24 @@ export class OpenAIChatModelBuilder<
})
.filter((message) => message.content)
if (this._options.output) {
const outputSchema =
this._options.output instanceof z.ZodType
? this._options.output
: z.object(this._options.output)
const { node } = zodToTs(outputSchema)
const tsTypeString = printNode(node)
messages.push({
role: 'system',
content: dedent`Output JSON only in the following format:
\`\`\`ts
${tsTypeString}
\`\`\``
})
}
// TODO: filter/compress messages based on token counts
const completion = await this._client.createChatCompletion({
@ -199,10 +224,28 @@ export class OpenAIChatModelBuilder<
? this._options.output
: z.object(this._options.output)
// TODO: convert string => object if necessary
let output: any = completion.message.content
if (outputSchema instanceof z.ZodArray) {
try {
const trimmedOutput = extractJSONArrayFromString(output)
output = jsonrepair(trimmedOutput ?? output)
} catch (err) {
// TODO
throw err
}
} else if (outputSchema instanceof z.ZodObject) {
try {
const trimmedOutput = extractJSONObjectFromString(output)
output = jsonrepair(trimmedOutput ?? output)
} catch (err) {
// TODO
throw err
}
}
// TODO: handle errors, retry logic, and self-healing
return outputSchema.parse(completion.message.content)
return outputSchema.parse(output)
} else {
return completion.message.content as any
}

Wyświetl plik

@ -12,28 +12,26 @@ async function main() {
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
const $ = new Agentic({ openai })
const ex0 = await $.gpt4(`give me a single boolean value`)
.output(z.boolean())
// .retry({ attempts: 3 })
.call()
console.log(ex0)
// const ex0 = await $.gpt4(`give me a single boolean value`)
// .output(z.boolean())
// // .retry({ attempts: 3 })
// .call()
// console.log(ex0)
const ex1 = await $.gpt4(`give me fake data conforming to this schema`)
.output(z.object({ foo: z.string(), bar: z.number() }))
// .retry({ attempts: 3 })
.call()
const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`)
.input(z.object({ foo: z.string() }))
.output(z.boolean())
await Promise.all([
getBoolean.call({ foo: 'foo' }),
getBoolean.call({ foo: 'bar' })
])
console.log(ex1)
// const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`)
// .input(z.object({ foo: z.string() }))
// .output(z.boolean())
// await Promise.all([
// getBoolean.call({ foo: 'foo' }),
// getBoolean.call({ foo: 'bar' })
// ])
}
main()

5
src/utils.ts 100644
Wyświetl plik

@ -0,0 +1,5 @@
export const extractJSONObjectFromString = (text: string): string | undefined =>
text.match(/\{(.|\n)*\}/gm)?.[0]
export const extractJSONArrayFromString = (text: string): string | undefined =>
text.match(/\[(.|\n)*\]/gm)?.[0]