From 2afaf4d9adc81afc9c6b70ed2ed281118fce0185 Mon Sep 17 00:00:00 2001
From: Travis Fischer
Date: Tue, 23 May 2023 22:28:38 -0700
Subject: [PATCH] =?UTF-8?q?=F0=9F=92=87?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 package.json     |   1 +
 pnpm-lock.yaml   |   8 ++++
 src/index.ts     |   3 +-
 src/llm.ts       | 110 ++++++----------------------------------------
 src/temp.ts      |   9 ++--
 src/tokenizer.ts |   5 +++
 src/utils.ts     |  39 -----------------
 7 files changed, 32 insertions(+), 143 deletions(-)
 create mode 100644 src/tokenizer.ts
 delete mode 100644 src/utils.ts

diff --git a/package.json b/package.json
index f2738e7..7356ac6 100644
--- a/package.json
+++ b/package.json
@@ -29,6 +29,7 @@
     "openai-fetch": "^1.2.1",
     "p-map": "^6.0.0",
     "parse-json": "^7.0.0",
+    "ts-dedent": "^2.2.0",
     "type-fest": "^3.10.0",
     "zod": "^3.21.4",
     "zod-to-ts": "^1.1.4",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index d962da7..f422f76 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -19,6 +19,9 @@ dependencies:
   parse-json:
     specifier: ^7.0.0
     version: 7.0.0(typescript@5.0.4)
+  ts-dedent:
+    specifier: ^2.2.0
+    version: 2.2.0
   type-fest:
     specifier: ^3.10.0
     version: 3.10.0(typescript@5.0.4)
@@ -2361,6 +2364,11 @@ packages:
     engines: {node: '>=12'}
     dev: true
 
+  /ts-dedent@2.2.0:
+    resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==}
+    engines: {node: '>=6.10'}
+    dev: false
+
   /ts-interface-checker@0.1.13:
     resolution: {integrity: sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==}
     dev: true
diff --git a/src/index.ts b/src/index.ts
index 9c56149..5e04fea 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1 +1,2 @@
-export * from './utils'
+export * from './llm'
+export * from './tokenizer'
diff --git a/src/llm.ts b/src/llm.ts
index 0b8435b..bc6d13b 100644
--- a/src/llm.ts
+++ b/src/llm.ts
@@ -1,4 +1,5 @@
 import Mustache from 'mustache'
+import { dedent } from 'ts-dedent'
 import type { SetRequired } from 'type-fest'
 import { ZodRawShape, ZodTypeAny, z } from 'zod'
 import { printNode, zodToTs } from 'zod-to-ts'
@@ -173,8 +174,18 @@ export class OpenAIChatModelBuilder<
       input = inputSchema.parse(input)
     }
 
-    // TODO: construct messages
     const messages = this._messages
+      .map((message) => {
+        return {
+          ...message,
+          content: message.content
+            ? Mustache.render(dedent(message.content), input).trim()
+            : ''
+        }
+      })
+      .filter((message) => message.content)
+
+    // TODO: filter/compress messages based on token counts
 
     const completion = await this._client.createChatCompletion({
       model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining
@@ -196,101 +207,4 @@
       return completion.message.content as any
     }
   }
-
-  protected async _buildMessages(text: string, opts: types.SendMessageOptions) {
-    const { systemMessage = this._systemMessage } = opts
-    let { parentMessageId } = opts
-
-    const userLabel = USER_LABEL_DEFAULT
-    const assistantLabel = ASSISTANT_LABEL_DEFAULT
-
-    const maxNumTokens = this._maxModelTokens - this._maxResponseTokens
-    let messages: types.openai.ChatCompletionRequestMessage[] = []
-
-    if (systemMessage) {
-      messages.push({
-        role: 'system',
-        content: systemMessage
-      })
-    }
-
-    const systemMessageOffset = messages.length
-    let nextMessages = text
-      ? messages.concat([
-          {
-            role: 'user',
-            content: text,
-            name: opts.name
-          }
-        ])
-      : messages
-    let numTokens = 0
-
-    do {
-      const prompt = nextMessages
-        .reduce((prompt, message) => {
-          switch (message.role) {
-            case 'system':
-              return prompt.concat([`Instructions:\n${message.content}`])
-            case 'user':
-              return prompt.concat([`${userLabel}:\n${message.content}`])
-            default:
-              return prompt.concat([`${assistantLabel}:\n${message.content}`])
-          }
-        }, [] as string[])
-        .join('\n\n')
-
-      const nextNumTokensEstimate = await this._getTokenCount(prompt)
-      const isValidPrompt = nextNumTokensEstimate <= maxNumTokens
-
-      if (prompt && !isValidPrompt) {
-        break
-      }
-
-      messages = nextMessages
-      numTokens = nextNumTokensEstimate
-
-      if (!isValidPrompt) {
-        break
-      }
-
-      if (!parentMessageId) {
-        break
-      }
-
-      const parentMessage = await this._getMessageById(parentMessageId)
-      if (!parentMessage) {
-        break
-      }
-
-      const parentMessageRole = parentMessage.role || 'user'
-
-      nextMessages = nextMessages.slice(0, systemMessageOffset).concat([
-        {
-          role: parentMessageRole,
-          content: parentMessage.text,
-          name: parentMessage.name
-        },
-        ...nextMessages.slice(systemMessageOffset)
-      ])
-
-      parentMessageId = parentMessage.parentMessageId
-    } while (true)
-
-    // Use up to 4096 tokens (prompt + response), but try to leave 1000 tokens
-    // for the response.
-    const maxTokens = Math.max(
-      1,
-      Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens)
-    )
-
-    return { messages, maxTokens, numTokens }
-  }
-
-  protected async _getTokenCount(text: string) {
-    // TODO: use a better fix in the tokenizer
-    text = text.replace(/<\|endoftext\|>/g, '')
-
-    return tokenizer.encode(text).length
-  }
 }
diff --git a/src/temp.ts b/src/temp.ts
index 7966b29..53ce247 100644
--- a/src/temp.ts
+++ b/src/temp.ts
@@ -19,11 +19,10 @@ async function main() {
 
   console.log(ex0)
 
-  const ex1 = await $.gpt4(
-    `give me fake data conforming to this schema`
-  ).output(z.object({ foo: z.string(), bar: z.number() }))
-  // .retry({ attempts: 3 })
-  // .call()
+  const ex1 = await $.gpt4(`give me fake data conforming to this schema`)
+    .output(z.object({ foo: z.string(), bar: z.number() }))
+    // .retry({ attempts: 3 })
+    .call()
 
   const getBoolean = $.gpt4(`give me a single boolean value {{foo}}`)
     .input(z.object({ foo: z.string() }))
diff --git a/src/tokenizer.ts b/src/tokenizer.ts
new file mode 100644
index 0000000..ba98904
--- /dev/null
+++ b/src/tokenizer.ts
@@ -0,0 +1,5 @@
+import { encoding_for_model } from '@dqbd/tiktoken'
+
+export function getTokenizerForModel(model: string) {
+  return encoding_for_model(model as any)
+}
diff --git a/src/utils.ts b/src/utils.ts
deleted file mode 100644
index 8a007b4..0000000
--- a/src/utils.ts
+++ /dev/null
@@ -1,39 +0,0 @@
-import dotenv from 'dotenv-safe'
-import { OpenAIClient } from 'openai-fetch'
-import { z } from 'zod'
-import { fromZodError } from 'zod-validation-error'
-
-dotenv.config()
-
-interface Temp {
-  contentType: string
-}
-
-async function main() {
-  const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY })
-
-  const outputSchema = z.object({})
-
-  const res = await openai.createChatCompletion({
-    model: 'gpt-4',
-    messages: [
-      {
-        role: 'system',
-        content: ''
-      }
-    ]
-  })
-
-  const out = await infer('give me a single boolean value', z.boolean(), {})
-}
-
-async function infer<T>(
-  prompt: string,
-  schema: z.ZodType<T>,
-  { retry }
-): Promise<T> {}
-
-main().catch((err) => {
-  console.error('error', err)
-  process.exit(1)
-})
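
The removed _buildMessages carried the token-budget logic (count the rendered prompt with the tokenizer, evict the oldest turns until it fits), which the new code path defers to the "TODO: filter/compress messages based on token counts" in src/llm.ts. Below is a minimal sketch of how that TODO could be implemented on top of the new getTokenizerForModel helper from src/tokenizer.ts; truncateMessages, the ChatMessage shape, and the flat per-message token overhead are illustrative assumptions, not part of this patch.

import { getTokenizerForModel } from './tokenizer'

interface ChatMessage {
  role: 'system' | 'user' | 'assistant'
  content: string
}

// Drops the oldest non-system messages until the estimated token count fits
// within `maxTokens`. The flat 4-token overhead per message approximates
// OpenAI's chat message framing and is an assumption, not an exact figure.
export function truncateMessages(
  messages: ChatMessage[],
  model: string,
  maxTokens: number
): ChatMessage[] {
  const tokenizer = getTokenizerForModel(model)

  try {
    const estimateTokens = (msgs: ChatMessage[]) =>
      msgs.reduce(
        (sum, msg) => sum + tokenizer.encode(msg.content).length + 4,
        0
      )

    const result = [...messages]
    while (result.length > 1 && estimateTokens(result) > maxTokens) {
      // Preserve system messages; evict the oldest user/assistant turn.
      const oldest = result.findIndex((msg) => msg.role !== 'system')
      if (oldest < 0) break
      result.splice(oldest, 1)
    }

    return result
  } finally {
    // @dqbd/tiktoken encoders are WASM-backed and must be freed explicitly.
    tokenizer.free()
  }
}

A caller would apply this just before createChatCompletion, using the same budget arithmetic as the deleted code (maxModelTokens - maxResponseTokens), e.g. truncateMessages(messages, defaultOpenAIModel, 4096 - 1000).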