From 9937498c12f917ddee2bf07e1033460864e5efd8 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Tue, 13 Jun 2023 21:50:44 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- legacy/src/llms/chat.ts | 65 +++---------------------------- legacy/src/llms/llm-utils.ts | 75 ++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 59 deletions(-) diff --git a/legacy/src/llms/chat.ts b/legacy/src/llms/chat.ts index 5127f202..5a1ade28 100644 --- a/legacy/src/llms/chat.ts +++ b/legacy/src/llms/chat.ts @@ -1,5 +1,4 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair' -import pMap from 'p-map' import { dedent } from 'ts-dedent' import { type SetRequired } from 'type-fest' import { ZodType, z } from 'zod' @@ -8,7 +7,6 @@ import { printNode, zodToTs } from 'zod-to-ts' import * as errors from '@/errors' import * as types from '@/types' import { getCompiledTemplate } from '@/template' -import { getModelNameForTiktoken } from '@/tokenizer' import { extractJSONArrayFromString, extractJSONObjectFromString @@ -16,6 +14,7 @@ import { import { BaseTask } from '../task' import { BaseLLM } from './llm' +import { getNumTokensForChatMessages } from './llm-utils' export abstract class BaseChatModel< TInput extends void | types.JsonObject = void, @@ -255,63 +254,11 @@ export abstract class BaseChatModel< } } - // TODO: this needs work + testing - // TODO: move to isolated file and/or module - public async getNumTokensForMessages(messages: types.ChatMessage[]): Promise<{ - numTokensTotal: number - numTokensPerMessage: number[] - }> { - let numTokensTotal = 0 - let tokensPerMessage = 0 - let tokensPerName = 0 - - const modelName = getModelNameForTiktoken(this._model) - - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb - if (modelName === 'gpt-3.5-turbo') { - tokensPerMessage = 4 - tokensPerName = -1 - } else if 
(modelName.startsWith('gpt-4')) { - tokensPerMessage = 3 - tokensPerName = 1 - } else { - // TODO - tokensPerMessage = 4 - tokensPerName = -1 - } - - const numTokensPerMessage = await pMap( + public async getNumTokensForMessages(messages: types.ChatMessage[]) { + return getNumTokensForChatMessages({ messages, - async (message) => { - let content = message.content || '' - if (message.function_call) { - // TODO: this case needs testing - content = message.function_call.arguments - } - - const [numTokensContent, numTokensRole, numTokensName] = - await Promise.all([ - this.getNumTokens(content), - this.getNumTokens(message.role), - message.name - ? this.getNumTokens(message.name).then((n) => n + tokensPerName) - : Promise.resolve(0) - ]) - - const numTokens = - tokensPerMessage + numTokensContent + numTokensRole + numTokensName - - numTokensTotal += numTokens - return numTokens - }, - { - concurrency: 8 - } - ) - - // TODO - numTokensTotal += 3 // every reply is primed with <|start|>assistant<|message|> - - return { numTokensTotal, numTokensPerMessage } + model: this._model, + getNumTokens: this.getNumTokens.bind(this) + }) } } diff --git a/legacy/src/llms/llm-utils.ts b/legacy/src/llms/llm-utils.ts index 25b0c4be..bde0025f 100644 --- a/legacy/src/llms/llm-utils.ts +++ b/legacy/src/llms/llm-utils.ts @@ -1,9 +1,84 @@ +import pMap from 'p-map' import { zodToJsonSchema } from 'zod-to-json-schema' import * as types from '@/types' import { BaseTask } from '@/task' +import { getModelNameForTiktoken } from '@/tokenizer' import { isValidTaskIdentifier } from '@/utils' +// TODO: this needs work + testing +// TODO: move to isolated module +export async function getNumTokensForChatMessages({ + messages, + model, + getNumTokens, + concurrency = 8 +}: { + messages: types.ChatMessage[] + model: string + getNumTokens: (text: string) => Promise<number> + concurrency?: number +}): Promise<{ + numTokensTotal: number + numTokensPerMessage: number[] +}> { + let numTokensTotal = 0 + let 
configNumTokensPerMessage = 0 + let configNumTokensPerName = 0 + + const modelName = getModelNameForTiktoken(model) + + // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb + if (modelName === 'gpt-3.5-turbo') { + configNumTokensPerMessage = 4 + configNumTokensPerName = -1 + } else if (modelName.startsWith('gpt-4')) { + configNumTokensPerMessage = 3 + configNumTokensPerName = 1 + } else { + // TODO + configNumTokensPerMessage = 4 + configNumTokensPerName = -1 + } + + const numTokensPerMessage = await pMap( + messages, + async (message) => { + let content = message.content || '' + if (message.function_call) { + // TODO: this case needs testing + content = message.function_call.arguments + } + + const [numTokensContent, numTokensRole, numTokensName] = + await Promise.all([ + getNumTokens(content), + getNumTokens(message.role), + message.name + ? getNumTokens(message.name).then((n) => n + configNumTokensPerName) + : Promise.resolve(0) + ]) + + const numTokens = + configNumTokensPerMessage + + numTokensContent + + numTokensRole + + numTokensName + + numTokensTotal += numTokens + return numTokens + }, + { + concurrency + } + ) + + // TODO + numTokensTotal += 3 // every reply is primed with <|start|>assistant<|message|> + + return { numTokensTotal, numTokensPerMessage } +} + export function getChatMessageFunctionDefinitionFromTask( task: BaseTask ): types.openai.ChatMessageFunction {