Travis Fischer 2023-05-23 20:22:50 -07:00
parent 9076bb342e
commit 2c638ca207
4 changed files with 159 additions and 8 deletions

View file

@@ -25,15 +25,18 @@
   "dependencies": {
     "@dqbd/tiktoken": "^1.0.7",
     "dotenv-safe": "^8.2.0",
+    "mustache": "^4.2.0",
     "openai-fetch": "^1.2.1",
     "p-map": "^6.0.0",
     "parse-json": "^7.0.0",
     "type-fest": "^3.10.0",
     "zod": "^3.21.4",
+    "zod-to-ts": "^1.1.4",
     "zod-validation-error": "^1.3.0"
   },
   "devDependencies": {
     "@trivago/prettier-plugin-sort-imports": "^4.1.1",
+    "@types/mustache": "^4.2.2",
     "@types/node": "^20.2.0",
     "del-cli": "^5.0.0",
     "husky": "^8.0.3",

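mustache and zod-to-ts are the commit's two new runtime dependencies: Mustache renders {{placeholder}} prompt templates (exercised by the .input() example in the last file below), while zod-to-ts prints a Zod schema as TypeScript source so the expected output shape can be embedded in a prompt. A minimal sketch of the zod-to-ts half, using the schema from the example file (the Output identifier is illustrative, not from the commit):

    import { z } from 'zod'
    import { printNode, zodToTs } from 'zod-to-ts'

    // Print a Zod schema as a TypeScript type literal, suitable for pasting
    // into a system message so the model knows the expected response shape.
    const schema = z.object({ foo: z.string(), bar: z.number() })
    const { node } = zodToTs(schema, 'Output')

    console.log(printNode(node))
    // logs a type literal along the lines of: { foo: string; bar: number }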
View file

@@ -7,6 +7,9 @@ dependencies:
   dotenv-safe:
     specifier: ^8.2.0
     version: 8.2.0
+  mustache:
+    specifier: ^4.2.0
+    version: 4.2.0
   openai-fetch:
     specifier: ^1.2.1
     version: 1.2.1
@@ -22,6 +25,9 @@ dependencies:
   zod:
     specifier: ^3.21.4
     version: 3.21.4
+  zod-to-ts:
+    specifier: ^1.1.4
+    version: 1.1.4(typescript@5.0.4)(zod@3.21.4)
   zod-validation-error:
     specifier: ^1.3.0
     version: 1.3.0(zod@3.21.4)
@@ -30,6 +36,9 @@ devDependencies:
   '@trivago/prettier-plugin-sort-imports':
     specifier: ^4.1.1
     version: 4.1.1(prettier@2.8.8)
+  '@types/mustache':
+    specifier: ^4.2.2
+    version: 4.2.2
   '@types/node':
     specifier: ^20.2.0
     version: 20.2.0
@@ -473,6 +482,10 @@ packages:
     resolution: {integrity: sha512-jhuKLIRrhvCPLqwPcx6INqmKeiA5EWrsCOPhrlFSrbrmU4ZMPjj5Ul/oLCMDO98XRUIwVm78xICz4EPCektzeQ==}
     dev: true

+  /@types/mustache@4.2.2:
+    resolution: {integrity: sha512-MUSpfpW0yZbTgjekDbH0shMYBUD+X/uJJJMm9LXN1d5yjl5lCY1vN/eWKD6D1tOtjA6206K0zcIPnUaFMurdNA==}
+    dev: true
+
   /@types/node@20.2.0:
     resolution: {integrity: sha512-3iD2jaCCziTx04uudpJKwe39QxXgSUnpxXSvRQjRvHPxFQfmfP4NXIm/NURVeNlTCc+ru4WqjYGTmpXrW9uMlw==}
     dev: true
@@ -1642,6 +1655,11 @@ packages:
     resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==}
     dev: true

+  /mustache@4.2.0:
+    resolution: {integrity: sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==}
+    hasBin: true
+    dev: false
+
   /mz@2.7.0:
     resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==}
     dependencies:
@@ -2541,6 +2559,16 @@ packages:
     engines: {node: '>=10'}
     dev: true

+  /zod-to-ts@1.1.4(typescript@5.0.4)(zod@3.21.4):
+    resolution: {integrity: sha512-jsCg+pTNxLAdJOfW4ul+SpechdGYEJPPnssSbqWdR2LSIkotT22k+UvqPb1nEHwe/YbEcbUOlZUfGM0npgR+Jg==}
+    peerDependencies:
+      typescript: ^4.9.4 || ^5.0.2
+      zod: ^3
+    dependencies:
+      typescript: 5.0.4
+      zod: 3.21.4
+    dev: false
+
   /zod-validation-error@1.3.0(zod@3.21.4):
     resolution: {integrity: sha512-4WoQnuWnj06kwKR4A+cykRxFmy+CTvwMQO5ogTXLiVx1AuvYYmMjixh7sbkSsQTr1Fvtss6d5kVz8PGeMPUQjQ==}
     engines: {node: '>=16.0.0'}

View file

@@ -1,5 +1,7 @@
+import Mustache from 'mustache'
 import type { SetRequired } from 'type-fest'
 import { ZodRawShape, ZodTypeAny, z } from 'zod'
+import { printNode, zodToTs } from 'zod-to-ts'

 import * as types from './types'
@@ -161,16 +163,27 @@ export class OpenAIChatModelBuilder<
   override async call(
     input?: types.ParsedData<TInput>
   ): Promise<types.ParsedData<TOutput>> {
+    if (this._options.input) {
+      const inputSchema =
+        this._options.input instanceof z.ZodType
+          ? this._options.input
+          : z.object(this._options.input)
+
+      // TODO: handle errors gracefully
+      input = inputSchema.parse(input)
+    }
+
     // TODO: construct messages
+    const messages = this._messages
+
     const completion = await this._client.createChatCompletion({
-      model: defaultOpenAIModel, // TODO: this shouldn't be necessary
+      model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining
       ...this._options.modelParams,
-      messages: this._messages
+      messages
     })

     if (this._options.output) {
-      const schema =
+      const outputSchema =
         this._options.output instanceof z.ZodType
           ? this._options.output
           : z.object(this._options.output)
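Both the new input option and the existing output option accept either a ZodType instance or a raw shape object; the instanceof z.ZodType checks above normalize the raw-shape form with z.object(...). A standalone restatement of that normalization (normalizeSchema is an illustrative name, not from the diff):

    import { z } from 'zod'

    // Accept either a full Zod schema or a raw shape and normalize to a
    // schema, mirroring the ternary used in call() above.
    function normalizeSchema(schema: z.ZodTypeAny | z.ZodRawShape): z.ZodTypeAny {
      return schema instanceof z.ZodType ? schema : z.object(schema)
    }

    normalizeSchema(z.boolean()) // already a schema; used as-is
    normalizeSchema({ foo: z.string() }) // wrapped as z.object({ foo: z.string() })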
@@ -178,9 +191,106 @@
       // TODO: convert string => object if necessary
       // TODO: handle errors, retry logic, and self-healing
-      return schema.parse(completion.message.content)
+      return outputSchema.parse(completion.message.content)
     } else {
       return completion.message.content as any
     }
   }
+
+  protected async _buildMessages(text: string, opts: types.SendMessageOptions) {
+    const { systemMessage = this._systemMessage } = opts
+    let { parentMessageId } = opts
+
+    const userLabel = USER_LABEL_DEFAULT
+    const assistantLabel = ASSISTANT_LABEL_DEFAULT
+
+    const maxNumTokens = this._maxModelTokens - this._maxResponseTokens
+    let messages: types.openai.ChatCompletionRequestMessage[] = []
+
+    if (systemMessage) {
+      messages.push({
+        role: 'system',
+        content: systemMessage
+      })
+    }
+
+    const systemMessageOffset = messages.length
+    let nextMessages = text
+      ? messages.concat([
+          {
+            role: 'user',
+            content: text,
+            name: opts.name
+          }
+        ])
+      : messages
+
+    let numTokens = 0
+
+    do {
+      const prompt = nextMessages
+        .reduce((prompt, message) => {
+          switch (message.role) {
+            case 'system':
+              return prompt.concat([`Instructions:\n${message.content}`])
+            case 'user':
+              return prompt.concat([`${userLabel}:\n${message.content}`])
+            default:
+              return prompt.concat([`${assistantLabel}:\n${message.content}`])
+          }
+        }, [] as string[])
+        .join('\n\n')
+
+      const nextNumTokensEstimate = await this._getTokenCount(prompt)
+      const isValidPrompt = nextNumTokensEstimate <= maxNumTokens
+
+      if (prompt && !isValidPrompt) {
+        break
+      }
+
+      messages = nextMessages
+      numTokens = nextNumTokensEstimate
+
+      if (!isValidPrompt) {
+        break
+      }
+
+      if (!parentMessageId) {
+        break
+      }
+
+      const parentMessage = await this._getMessageById(parentMessageId)
+      if (!parentMessage) {
+        break
+      }
+
+      const parentMessageRole = parentMessage.role || 'user'
+
+      nextMessages = nextMessages.slice(0, systemMessageOffset).concat([
+        {
+          role: parentMessageRole,
+          content: parentMessage.text,
+          name: parentMessage.name
+        },
+        ...nextMessages.slice(systemMessageOffset)
+      ])
+
+      parentMessageId = parentMessage.parentMessageId
+    } while (true)
+
+    // Use up to 4096 tokens (prompt + response), but try to leave 1000 tokens
+    // for the response.
+    const maxTokens = Math.max(
+      1,
+      Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens)
+    )
+
+    return { messages, maxTokens, numTokens }
+  }
+
+  protected async _getTokenCount(text: string) {
+    // TODO: use a better fix in the tokenizer
+    text = text.replace(/<\|endoftext\|>/g, '')
+
+    return tokenizer.encode(text).length
+  }
 }
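The new _buildMessages helper walks the conversation backwards via parentMessageId, prepending ancestor messages until the estimated prompt would exceed the model window minus the response reservation, then sizes max_tokens from whatever budget remains. A self-contained sketch of that budgeting loop, assuming a 4096-token window and a 1000-token reservation (the constants and fitMessages are illustrative; the diff's version uses per-instance fields and this._getTokenCount):

    import { encoding_for_model } from '@dqbd/tiktoken'

    const MAX_MODEL_TOKENS = 4096 // assumed model window
    const MAX_RESPONSE_TOKENS = 1000 // assumed reservation for the reply

    const enc = encoding_for_model('gpt-3.5-turbo')
    const countTokens = (text: string) => enc.encode(text).length

    // Given messages ordered oldest-to-newest, keep the newest suffix that
    // fits the prompt budget, then derive max_tokens from what remains.
    function fitMessages(history: { role: string; content: string }[]) {
      const budget = MAX_MODEL_TOKENS - MAX_RESPONSE_TOKENS
      const kept: typeof history = []
      let numTokens = 0

      for (let i = history.length - 1; i >= 0; i--) {
        const cost = countTokens(`${history[i].role}:\n${history[i].content}\n\n`)
        if (numTokens + cost > budget) break
        kept.unshift(history[i])
        numTokens += cost
      }

      const maxTokens = Math.max(
        1,
        Math.min(MAX_MODEL_TOKENS - numTokens, MAX_RESPONSE_TOKENS)
      )
      return { messages: kept, maxTokens, numTokens }
    }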

View file

@@ -19,10 +19,20 @@ async function main() {
   console.log(ex0)

-  const ex1 = await $.gpt4(`give me fake data conforming to this schema`)
-    .output(z.object({ foo: z.string(), bar: z.number() }))
-    // .retry({ attempts: 3 })
-    .call()
+  const ex1 = await $.gpt4(
+    `give me fake data conforming to this schema`
+  ).output(z.object({ foo: z.string(), bar: z.number() }))
+  // .retry({ attempts: 3 })
+  // .call()
+
+  const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`)
+    .input(z.object({ foo: z.string() }))
+    .output(z.boolean())
+
+  await Promise.all([
+    getBoolean.call({ foo: 'foo' }),
+    getBoolean.call({ foo: 'bar' })
+  ])
+
   console.log(ex1)
 }
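The getBoolean example shows what the new dependencies enable together: a reusable, typed prompt where .input() validates the call arguments against a Zod schema and the {{foo}} placeholder is presumably filled in via Mustache before the request is sent. A hypothetical sketch of that resolve step (resolvePrompt is illustrative, not from the commit):

    import Mustache from 'mustache'
    import { z } from 'zod'

    // Validate inputs before spending tokens on a model call, then
    // substitute the validated values into the {{placeholder}} template.
    function resolvePrompt<T extends z.ZodTypeAny>(
      template: string,
      inputSchema: T,
      input: unknown
    ): string {
      const parsed = inputSchema.parse(input)
      return Mustache.render(template, parsed)
    }

    resolvePrompt(
      'give me a single boolean value {{foo}}',
      z.object({ foo: z.string() }),
      { foo: 'bar' }
    )
    // => 'give me a single boolean value bar'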