diff --git a/package.json b/package.json index 027385e9..f2738e7d 100644 --- a/package.json +++ b/package.json @@ -25,15 +25,18 @@ "dependencies": { "@dqbd/tiktoken": "^1.0.7", "dotenv-safe": "^8.2.0", + "mustache": "^4.2.0", "openai-fetch": "^1.2.1", "p-map": "^6.0.0", "parse-json": "^7.0.0", "type-fest": "^3.10.0", "zod": "^3.21.4", + "zod-to-ts": "^1.1.4", "zod-validation-error": "^1.3.0" }, "devDependencies": { "@trivago/prettier-plugin-sort-imports": "^4.1.1", + "@types/mustache": "^4.2.2", "@types/node": "^20.2.0", "del-cli": "^5.0.0", "husky": "^8.0.3", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 106ece55..d962da79 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -7,6 +7,9 @@ dependencies: dotenv-safe: specifier: ^8.2.0 version: 8.2.0 + mustache: + specifier: ^4.2.0 + version: 4.2.0 openai-fetch: specifier: ^1.2.1 version: 1.2.1 @@ -22,6 +25,9 @@ dependencies: zod: specifier: ^3.21.4 version: 3.21.4 + zod-to-ts: + specifier: ^1.1.4 + version: 1.1.4(typescript@5.0.4)(zod@3.21.4) zod-validation-error: specifier: ^1.3.0 version: 1.3.0(zod@3.21.4) @@ -30,6 +36,9 @@ devDependencies: '@trivago/prettier-plugin-sort-imports': specifier: ^4.1.1 version: 4.1.1(prettier@2.8.8) + '@types/mustache': + specifier: ^4.2.2 + version: 4.2.2 '@types/node': specifier: ^20.2.0 version: 20.2.0 @@ -473,6 +482,10 @@ packages: resolution: {integrity: sha512-jhuKLIRrhvCPLqwPcx6INqmKeiA5EWrsCOPhrlFSrbrmU4ZMPjj5Ul/oLCMDO98XRUIwVm78xICz4EPCektzeQ==} dev: true + /@types/mustache@4.2.2: + resolution: {integrity: sha512-MUSpfpW0yZbTgjekDbH0shMYBUD+X/uJJJMm9LXN1d5yjl5lCY1vN/eWKD6D1tOtjA6206K0zcIPnUaFMurdNA==} + dev: true + /@types/node@20.2.0: resolution: {integrity: sha512-3iD2jaCCziTx04uudpJKwe39QxXgSUnpxXSvRQjRvHPxFQfmfP4NXIm/NURVeNlTCc+ru4WqjYGTmpXrW9uMlw==} dev: true @@ -1642,6 +1655,11 @@ packages: resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} dev: true + /mustache@4.2.0: + resolution: {integrity: sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==} + hasBin: true + dev: false + /mz@2.7.0: resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==} dependencies: @@ -2541,6 +2559,16 @@ packages: engines: {node: '>=10'} dev: true + /zod-to-ts@1.1.4(typescript@5.0.4)(zod@3.21.4): + resolution: {integrity: sha512-jsCg+pTNxLAdJOfW4ul+SpechdGYEJPPnssSbqWdR2LSIkotT22k+UvqPb1nEHwe/YbEcbUOlZUfGM0npgR+Jg==} + peerDependencies: + typescript: ^4.9.4 || ^5.0.2 + zod: ^3 + dependencies: + typescript: 5.0.4 + zod: 3.21.4 + dev: false + /zod-validation-error@1.3.0(zod@3.21.4): resolution: {integrity: sha512-4WoQnuWnj06kwKR4A+cykRxFmy+CTvwMQO5ogTXLiVx1AuvYYmMjixh7sbkSsQTr1Fvtss6d5kVz8PGeMPUQjQ==} engines: {node: '>=16.0.0'} diff --git a/src/llm.ts b/src/llm.ts index cd5c2e30..0b8435bd 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -1,5 +1,7 @@ +import Mustache from 'mustache' import type { SetRequired } from 'type-fest' import { ZodRawShape, ZodTypeAny, z } from 'zod' +import { printNode, zodToTs } from 'zod-to-ts' import * as types from './types' @@ -161,16 +163,27 @@ export class OpenAIChatModelBuilder< override async call( input?: types.ParsedData ): Promise> { + if (this._options.input) { + const inputSchema = + this._options.input instanceof z.ZodType + ? 
this._options.input + : z.object(this._options.input) + + // TODO: handle errors gracefully + input = inputSchema.parse(input) + } + // TODO: construct messages + const messages = this._messages const completion = await this._client.createChatCompletion({ - model: defaultOpenAIModel, // TODO: this shouldn't be necessary + model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining ...this._options.modelParams, - messages: this._messages + messages }) if (this._options.output) { - const schema = + const outputSchema = this._options.output instanceof z.ZodType ? this._options.output : z.object(this._options.output) @@ -178,9 +191,106 @@ export class OpenAIChatModelBuilder< // TODO: convert string => object if necessary // TODO: handle errors, retry logic, and self-healing - return schema.parse(completion.message.content) + return outputSchema.parse(completion.message.content) } else { return completion.message.content as any } } + + protected async _buildMessages(text: string, opts: types.SendMessageOptions) { + const { systemMessage = this._systemMessage } = opts + let { parentMessageId } = opts + + const userLabel = USER_LABEL_DEFAULT + const assistantLabel = ASSISTANT_LABEL_DEFAULT + + const maxNumTokens = this._maxModelTokens - this._maxResponseTokens + let messages: types.openai.ChatCompletionRequestMessage[] = [] + + if (systemMessage) { + messages.push({ + role: 'system', + content: systemMessage + }) + } + + const systemMessageOffset = messages.length + let nextMessages = text + ? messages.concat([ + { + role: 'user', + content: text, + name: opts.name + } + ]) + : messages + let numTokens = 0 + + do { + const prompt = nextMessages + .reduce((prompt, message) => { + switch (message.role) { + case 'system': + return prompt.concat([`Instructions:\n${message.content}`]) + case 'user': + return prompt.concat([`${userLabel}:\n${message.content}`]) + default: + return prompt.concat([`${assistantLabel}:\n${message.content}`]) + } + }, [] as string[]) + .join('\n\n') + + const nextNumTokensEstimate = await this._getTokenCount(prompt) + const isValidPrompt = nextNumTokensEstimate <= maxNumTokens + + if (prompt && !isValidPrompt) { + break + } + + messages = nextMessages + numTokens = nextNumTokensEstimate + + if (!isValidPrompt) { + break + } + + if (!parentMessageId) { + break + } + + const parentMessage = await this._getMessageById(parentMessageId) + if (!parentMessage) { + break + } + + const parentMessageRole = parentMessage.role || 'user' + + nextMessages = nextMessages.slice(0, systemMessageOffset).concat([ + { + role: parentMessageRole, + content: parentMessage.text, + name: parentMessage.name + }, + ...nextMessages.slice(systemMessageOffset) + ]) + + parentMessageId = parentMessage.parentMessageId + } while (true) + + // Use up to 4096 tokens (prompt + response), but try to leave 1000 tokens + // for the response. 
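+    // i.e. clamp the response budget to whatever room is left after the prompt
+    // (per this._maxModelTokens / this._maxResponseTokens), never below 1 token.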
+ const maxTokens = Math.max( + 1, + Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens) + ) + + return { messages, maxTokens, numTokens } + } + + protected async _getTokenCount(text: string) { + // TODO: use a better fix in the tokenizer + text = text.replace(/<\|endoftext\|>/g, '') + + return tokenizer.encode(text).length + } } diff --git a/src/temp.ts b/src/temp.ts index f4732e3c..7966b291 100644 --- a/src/temp.ts +++ b/src/temp.ts @@ -19,10 +19,20 @@ async function main() { console.log(ex0) - const ex1 = await $.gpt4(`give me fake data conforming to this schema`) - .output(z.object({ foo: z.string(), bar: z.number() })) - // .retry({ attempts: 3 }) - .call() + const ex1 = await $.gpt4( + `give me fake data conforming to this schema` + ).output(z.object({ foo: z.string(), bar: z.number() })) + // .retry({ attempts: 3 }) + // .call() + + const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`) + .input(z.object({ foo: z.string() })) + .output(z.boolean()) + + await Promise.all([ + getBoolean.call({ foo: 'foo' }), + getBoolean.call({ foo: 'bar' }) + ]) console.log(ex1) }
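The new `mustache` and `zod-to-ts` dependencies aren't wired up yet — the `// TODO: construct messages` in `call()` is still a placeholder — so here is a rough sketch of how that step *might* look: render the user prompt template against the validated input with `Mustache.render`, then append a TypeScript rendering of the output schema via `zodToTs`/`printNode` so the model knows what shape to return. Everything below (`buildUserMessage` and its parameters) is hypothetical illustration, not code from this PR:

```ts
import Mustache from 'mustache'
import { z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'

// Hypothetical helper: validate the input, interpolate it into the prompt
// template, and (optionally) append the expected output type.
function buildUserMessage(
  template: string,
  input: unknown,
  inputSchema: z.ZodTypeAny,
  outputSchema?: z.ZodTypeAny
): string {
  // Validate the caller-supplied input before interpolating it into the prompt
  const view = inputSchema.parse(input)

  // Fill in {{placeholders}} such as the `{{foo}}` used in src/temp.ts
  let content = Mustache.render(template, view)

  if (outputSchema) {
    // Render the Zod output schema as a TypeScript type the model should match
    const { node } = zodToTs(outputSchema, 'Output')
    content += `\n\nRespond with JSON matching this TypeScript type:\n${printNode(node)}`
  }

  return content
}

// Example mirroring the `getBoolean` call in src/temp.ts
const message = buildUserMessage(
  'give me a single boolean value {{foo}}',
  { foo: 'bar' },
  z.object({ foo: z.string() }),
  z.boolean()
)
console.log(message)
```

Whether the schema hint belongs in the user message or the system message is an open design choice; this only shows where the two new dependencies could plug in.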