From 051b4f4928a86efcc9687c9a108223b3c18da297 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Mon, 12 Jun 2023 16:59:54 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=8D=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scratch/ts-fluent-problem.md | 35 +++ src/index.ts | 2 + src/llms/anthropic.ts | 2 +- src/llms/chat.ts | 399 +++++++++++++++++++++++++++++++++++ src/llms/index.ts | 1 + src/llms/llm.ts | 292 +------------------------ src/llms/openai.ts | 2 +- 7 files changed, 441 insertions(+), 292 deletions(-) create mode 100644 scratch/ts-fluent-problem.md create mode 100644 src/llms/chat.ts diff --git a/scratch/ts-fluent-problem.md b/scratch/ts-fluent-problem.md new file mode 100644 index 0000000..0f9d006 --- /dev/null +++ b/scratch/ts-fluent-problem.md @@ -0,0 +1,35 @@ +For the following TypeScript code: + +```ts +import { ZodType, z } from 'zod' + +class Super { + protected _inputSchema: ZodType | undefined + protected _outputSchema: ZodType | undefined + + input(inputSchema: ZodType): Super { + const refinedInstance = this as unknown as Super + refinedInstance._inputSchema = inputSchema + return refinedInstance + } + + output(outputSchema: ZodType): Super { + const refinedInstance = this as unknown as Super + refinedInstance._outputSchema = outputSchema + return refinedInstance + } +} + +class SubA extends Super {} +class SubB extends Super {} +``` + +```ts +const a = new SubA() +a.output() // SubA + +const b = new SubB() +b.output() // SubB +``` + +How can I change this implementation so `input` and `output` return the correct subclassed types? 
diff --git a/src/index.ts b/src/index.ts index 81a4b43..9dac168 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,5 +8,7 @@ export * from './human-feedback' export * from './services/metaphor' export * from './services/serpapi' export * from './services/novu' + +export * from './tools/calculator' export * from './tools/metaphor' export * from './tools/novu' diff --git a/src/llms/anthropic.ts b/src/llms/anthropic.ts index 1d60823..4fe4375 100644 --- a/src/llms/anthropic.ts +++ b/src/llms/anthropic.ts @@ -4,7 +4,7 @@ import { type SetOptional } from 'type-fest' import * as types from '@/types' import { DEFAULT_ANTHROPIC_MODEL } from '@/constants' -import { BaseChatModel } from './llm' +import { BaseChatModel } from './chat' const defaultStopSequences = [anthropic.HUMAN_PROMPT] diff --git a/src/llms/chat.ts b/src/llms/chat.ts new file mode 100644 index 0000000..1bfe709 --- /dev/null +++ b/src/llms/chat.ts @@ -0,0 +1,399 @@ +import { JSONRepairError, jsonrepair } from 'jsonrepair' +import pMap from 'p-map' +import { dedent } from 'ts-dedent' +import { type SetRequired } from 'type-fest' +import { ZodType, z } from 'zod' +import { printNode, zodToTs } from 'zod-to-ts' + +import * as errors from '@/errors' +import * as types from '@/types' +import { BaseTask } from '@/task' +import { getCompiledTemplate } from '@/template' +import { + Tokenizer, + getModelNameForTiktoken, + getTokenizerForModel +} from '@/tokenizer' +import { + extractJSONArrayFromString, + extractJSONObjectFromString +} from '@/utils' + +// TODO: TInput should only be allowed to be void or an object +export abstract class BaseLLM< + TInput = void, + TOutput = string, + TModelParams extends Record = Record +> extends BaseTask { + protected _inputSchema: ZodType | undefined + protected _outputSchema: ZodType | undefined + + protected _provider: string + protected _model: string + protected _modelParams: TModelParams | undefined + protected _examples: types.LLMExample[] | undefined + protected 
_tokenizerP?: Promise + + constructor( + options: SetRequired< + types.BaseLLMOptions, + 'provider' | 'model' + > + ) { + super(options) + + this._inputSchema = options.inputSchema + this._outputSchema = options.outputSchema + + this._provider = options.provider + this._model = options.model + this._modelParams = options.modelParams + this._examples = options.examples + } + + // TODO: use polymorphic `this` type to return correct BaseLLM subclass type + input(inputSchema: ZodType): BaseLLM { + const refinedInstance = this as unknown as BaseLLM + refinedInstance._inputSchema = inputSchema + return refinedInstance + } + + // TODO: use polymorphic `this` type to return correct BaseLLM subclass type + output(outputSchema: ZodType): BaseLLM { + const refinedInstance = this as unknown as BaseLLM + refinedInstance._outputSchema = outputSchema + return refinedInstance + } + + public override get inputSchema(): ZodType { + if (this._inputSchema) { + return this._inputSchema + } else { + // TODO: improve typing + return z.void() as unknown as ZodType + } + } + + public override get outputSchema(): ZodType { + if (this._outputSchema) { + return this._outputSchema + } else { + // TODO: improve typing + return z.string() as unknown as ZodType + } + } + + public override get name(): string { + return `${this._provider}:chat:${this._model}` + } + + examples(examples: types.LLMExample[]): this { + this._examples = examples + return this + } + + modelParams(params: Partial): this { + // We assume that modelParams does not include nested objects. + // If it did, we would need to do a deep merge. 
+ this._modelParams = { ...this._modelParams, ...params } as TModelParams + return this + } + + public async getNumTokens(text: string): Promise { + if (!this._tokenizerP) { + const model = this._model || 'gpt2' + + this._tokenizerP = getTokenizerForModel(model).catch((err) => { + console.warn( + `Failed to initialize tokenizer for model "${model}", falling back to approximate count`, + err + ) + + return null + }) + } + + const tokenizer = await this._tokenizerP + + if (tokenizer) { + return tokenizer.encode(text).length + } + + // fallback to approximate calculation if tokenizer is not available + return Math.ceil(text.length / 4) + } +} + +export abstract class BaseChatModel< + TInput = void, + TOutput = string, + TModelParams extends Record = Record, + TChatCompletionResponse extends Record = Record +> extends BaseLLM { + _messages: types.ChatMessage[] + + constructor( + options: SetRequired< + types.ChatModelOptions, + 'provider' | 'model' | 'messages' + > + ) { + super(options) + + this._messages = options.messages + } + + // TODO: use polymorphic `this` type to return correct BaseLLM subclass type + input(inputSchema: ZodType): BaseChatModel { + const refinedInstance = this as unknown as BaseChatModel< + U, + TOutput, + TModelParams + > + refinedInstance._inputSchema = inputSchema + return refinedInstance + } + + // TODO: use polymorphic `this` type to return correct BaseLLM subclass type + output(outputSchema: ZodType): BaseChatModel { + const refinedInstance = this as unknown as BaseChatModel< + TInput, + U, + TModelParams + > + refinedInstance._outputSchema = outputSchema + return refinedInstance + } + + protected abstract _createChatCompletion( + messages: types.ChatMessage[] + ): Promise> + + public async buildMessages( + input?: TInput, + ctx?: types.TaskCallContext + ) { + if (this._inputSchema) { + // TODO: handle errors gracefully + input = this.inputSchema.parse(input) + } + + // TODO: validate input message variables against input schema + 
console.log({ input }) + + const messages = this._messages + .map((message) => { + return { + ...message, + content: message.content + ? getCompiledTemplate(dedent(message.content))(input).trim() + : '' + } + }) + .filter((message) => message.content) + + if (this._examples?.length) { + // TODO: smarter example selection + for (const example of this._examples) { + messages.push({ + role: 'system', + content: `Example input: ${example.input}\n\nExample output: ${example.output}` + }) + } + } + + if (this._outputSchema) { + const { node } = zodToTs(this._outputSchema) + + if (node.kind === 152) { + // handle raw strings differently + messages.push({ + role: 'system', + content: dedent`Output a raw string only, without any additional text.` + }) + } else { + const tsTypeString = printNode(node, { + removeComments: false, + // TODO: this doesn't seem to actually work, so we're doing it manually below + omitTrailingSemicolon: true, + noEmitHelpers: true + }) + .replace(/^ {4}/gm, ' ') + .replace(/;$/gm, '') + + messages.push({ + role: 'system', + content: dedent`Do not output code. 
Output JSON only in the following TypeScript format: + \`\`\`ts + ${tsTypeString} + \`\`\`` + }) + } + } + + if (ctx?.retryMessage) { + messages.push({ + role: 'system', + content: ctx.retryMessage + }) + } + + // TODO: filter/compress messages based on token counts + + return messages + } + + protected override async _call( + ctx: types.TaskCallContext + ): Promise { + const messages = await this.buildMessages(ctx.input, ctx) + + console.log('>>>') + console.log(messages) + + const completion = await this._createChatCompletion(messages) + ctx.metadata.completion = completion + + let output: any = completion.message.content + + console.log('===') + console.log(output) + console.log('<<<') + + if (this._outputSchema) { + const outputSchema = this._outputSchema + + if (outputSchema instanceof z.ZodArray) { + try { + const trimmedOutput = extractJSONArrayFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON array: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } + } else if (outputSchema instanceof z.ZodObject) { + try { + const trimmedOutput = extractJSONObjectFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? 
output)) + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON object: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } + } else if (outputSchema instanceof z.ZodBoolean) { + output = output.toLowerCase().trim() + const booleanOutputs = { + true: true, + false: false, + yes: true, + no: false, + 1: true, + 0: false + } + + const booleanOutput = booleanOutputs[output] + + if (booleanOutput !== undefined) { + output = booleanOutput + } else { + throw new errors.OutputValidationError( + `Invalid boolean output: ${output}` + ) + } + } else if (outputSchema instanceof z.ZodNumber) { + output = output.trim() + + const numberOutput = outputSchema.isInt + ? parseInt(output) + : parseFloat(output) + + if (isNaN(numberOutput)) { + throw new errors.OutputValidationError( + `Invalid number output: ${output}` + ) + } else { + output = numberOutput + } + } + + const safeResult = outputSchema.safeParse(output) + + if (!safeResult.success) { + throw new errors.ZodOutputValidationError(safeResult.error) + } + + return safeResult.data + } else { + return output + } + } + + // TODO: this needs work + testing + // TODO: move to isolated file and/or module + public async getNumTokensForMessages(messages: types.ChatMessage[]): Promise<{ + numTokensTotal: number + numTokensPerMessage: number[] + }> { + let numTokensTotal = 0 + let tokensPerMessage = 0 + let tokensPerName = 0 + + const modelName = getModelNameForTiktoken(this._model) + + // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb + if (modelName === 'gpt-3.5-turbo') { + tokensPerMessage = 4 + tokensPerName = -1 + } else if (modelName.startsWith('gpt-4')) { + tokensPerMessage = 3 + tokensPerName = 1 + } else { + // TODO + tokensPerMessage = 4 + tokensPerName = -1 + } + + const 
numTokensPerMessage = await pMap( + messages, + async (message) => { + const [numTokensContent, numTokensRole, numTokensName] = + await Promise.all([ + this.getNumTokens(message.content), + this.getNumTokens(message.role), + message.name + ? this.getNumTokens(message.name).then((n) => n + tokensPerName) + : Promise.resolve(0) + ]) + + const numTokens = + tokensPerMessage + numTokensContent + numTokensRole + numTokensName + + numTokensTotal += numTokens + return numTokens + }, + { + concurrency: 8 + } + ) + + // TODO + numTokensTotal += 3 // every reply is primed with <|start|>assistant<|message|> + + return { numTokensTotal, numTokensPerMessage } + } +} diff --git a/src/llms/index.ts b/src/llms/index.ts index e2425cb..0b742fe 100644 --- a/src/llms/index.ts +++ b/src/llms/index.ts @@ -1,3 +1,4 @@ export * from './llm' +export * from './chat' export * from './openai' export * from './anthropic' diff --git a/src/llms/llm.ts b/src/llms/llm.ts index 945aefe..ab8d746 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -1,25 +1,11 @@ -import { JSONRepairError, jsonrepair } from 'jsonrepair' -import pMap from 'p-map' -import { dedent } from 'ts-dedent' import { type SetRequired } from 'type-fest' import { ZodType, z } from 'zod' -import { printNode, zodToTs } from 'zod-to-ts' -import * as errors from '@/errors' import * as types from '@/types' import { BaseTask } from '@/task' -import { getCompiledTemplate } from '@/template' -import { - Tokenizer, - getModelNameForTiktoken, - getTokenizerForModel -} from '@/tokenizer' -import { - extractJSONArrayFromString, - extractJSONObjectFromString -} from '@/utils' +import { Tokenizer, getTokenizerForModel } from '@/tokenizer' -// TODO: TInput should only be allowed to be an object +// TODO: TInput should only be allowed to be void or an object export abstract class BaseLLM< TInput = void, TOutput = string, @@ -123,277 +109,3 @@ export abstract class BaseLLM< return Math.ceil(text.length / 4) } } - -export abstract class 
BaseChatModel< - TInput = void, - TOutput = string, - TModelParams extends Record = Record, - TChatCompletionResponse extends Record = Record -> extends BaseLLM { - _messages: types.ChatMessage[] - - constructor( - options: SetRequired< - types.ChatModelOptions, - 'provider' | 'model' | 'messages' - > - ) { - super(options) - - this._messages = options.messages - } - - // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - input(inputSchema: ZodType): BaseChatModel { - const refinedInstance = this as unknown as BaseChatModel< - U, - TOutput, - TModelParams - > - refinedInstance._inputSchema = inputSchema - return refinedInstance - } - - // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - output(outputSchema: ZodType): BaseChatModel { - const refinedInstance = this as unknown as BaseChatModel< - TInput, - U, - TModelParams - > - refinedInstance._outputSchema = outputSchema - return refinedInstance - } - - protected abstract _createChatCompletion( - messages: types.ChatMessage[] - ): Promise> - - public async buildMessages( - input?: TInput, - ctx?: types.TaskCallContext - ) { - if (this._inputSchema) { - // TODO: handle errors gracefully - input = this.inputSchema.parse(input) - } - - // TODO: validate input message variables against input schema - console.log({ input }) - - const messages = this._messages - .map((message) => { - return { - ...message, - content: message.content - ? 
getCompiledTemplate(dedent(message.content))(input).trim() - : '' - } - }) - .filter((message) => message.content) - - if (this._examples?.length) { - // TODO: smarter example selection - for (const example of this._examples) { - messages.push({ - role: 'system', - content: `Example input: ${example.input}\n\nExample output: ${example.output}` - }) - } - } - - if (this._outputSchema) { - const { node } = zodToTs(this._outputSchema) - - if (node.kind === 152) { - // handle raw strings differently - messages.push({ - role: 'system', - content: dedent`Output a raw string only, without any additional text.` - }) - } else { - const tsTypeString = printNode(node, { - removeComments: false, - // TODO: this doesn't seem to actually work, so we're doing it manually below - omitTrailingSemicolon: true, - noEmitHelpers: true - }) - .replace(/^ {4}/gm, ' ') - .replace(/;$/gm, '') - - messages.push({ - role: 'system', - content: dedent`Do not output code. Output JSON only in the following TypeScript format: - \`\`\`ts - ${tsTypeString} - \`\`\`` - }) - } - } - - if (ctx?.retryMessage) { - messages.push({ - role: 'system', - content: ctx.retryMessage - }) - } - - // TODO: filter/compress messages based on token counts - - return messages - } - - protected override async _call( - ctx: types.TaskCallContext - ): Promise { - const messages = await this.buildMessages(ctx.input, ctx) - - console.log('>>>') - console.log(messages) - - const completion = await this._createChatCompletion(messages) - ctx.metadata.completion = completion - - let output: any = completion.message.content - - console.log('===') - console.log(output) - console.log('<<<') - - if (this._outputSchema) { - const outputSchema = this._outputSchema - - if (outputSchema instanceof z.ZodArray) { - try { - const trimmedOutput = extractJSONArrayFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? 
output)) - } catch (err: any) { - if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) - } else if (err instanceof SyntaxError) { - throw new errors.OutputValidationError( - `Invalid JSON array: ${err.message}`, - { cause: err } - ) - } else { - throw err - } - } - } else if (outputSchema instanceof z.ZodObject) { - try { - const trimmedOutput = extractJSONObjectFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) - } catch (err: any) { - if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) - } else if (err instanceof SyntaxError) { - throw new errors.OutputValidationError( - `Invalid JSON object: ${err.message}`, - { cause: err } - ) - } else { - throw err - } - } - } else if (outputSchema instanceof z.ZodBoolean) { - output = output.toLowerCase().trim() - const booleanOutputs = { - true: true, - false: false, - yes: true, - no: false, - 1: true, - 0: false - } - - const booleanOutput = booleanOutputs[output] - - if (booleanOutput !== undefined) { - output = booleanOutput - } else { - throw new errors.OutputValidationError( - `Invalid boolean output: ${output}` - ) - } - } else if (outputSchema instanceof z.ZodNumber) { - output = output.trim() - - const numberOutput = outputSchema.isInt - ? 
parseInt(output) - : parseFloat(output) - - if (isNaN(numberOutput)) { - throw new errors.OutputValidationError( - `Invalid number output: ${output}` - ) - } else { - output = numberOutput - } - } - - const safeResult = outputSchema.safeParse(output) - - if (!safeResult.success) { - throw new errors.ZodOutputValidationError(safeResult.error) - } - - return safeResult.data - } else { - return output - } - } - - // TODO: this needs work + testing - // TODO: move to isolated file and/or module - public async getNumTokensForMessages(messages: types.ChatMessage[]): Promise<{ - numTokensTotal: number - numTokensPerMessage: number[] - }> { - let numTokensTotal = 0 - let tokensPerMessage = 0 - let tokensPerName = 0 - - const modelName = getModelNameForTiktoken(this._model) - - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb - if (modelName === 'gpt-3.5-turbo') { - tokensPerMessage = 4 - tokensPerName = -1 - } else if (modelName.startsWith('gpt-4')) { - tokensPerMessage = 3 - tokensPerName = 1 - } else { - // TODO - tokensPerMessage = 4 - tokensPerName = -1 - } - - const numTokensPerMessage = await pMap( - messages, - async (message) => { - const [numTokensContent, numTokensRole, numTokensName] = - await Promise.all([ - this.getNumTokens(message.content), - this.getNumTokens(message.role), - message.name - ? 
this.getNumTokens(message.name).then((n) => n + tokensPerName) - : Promise.resolve(0) - ]) - - const numTokens = - tokensPerMessage + numTokensContent + numTokensRole + numTokensName - - numTokensTotal += numTokens - return numTokens - }, - { - concurrency: 8 - } - ) - - // TODO - numTokensTotal += 3 // every reply is primed with <|start|>assistant<|message|> - - return { numTokensTotal, numTokensPerMessage } - } -} diff --git a/src/llms/openai.ts b/src/llms/openai.ts index 5330cfc..3956fef 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -3,7 +3,7 @@ import { type SetOptional } from 'type-fest' import * as types from '@/types' import { DEFAULT_OPENAI_MODEL } from '@/constants' -import { BaseChatModel } from './llm' +import { BaseChatModel } from './chat' export class OpenAIChatModel< TInput = any,