diff --git a/legacy/src/llms/chat.ts b/legacy/src/llms/chat.ts
index 1bfe7097..3583b0a7 100644
--- a/legacy/src/llms/chat.ts
+++ b/legacy/src/llms/chat.ts
@@ -7,122 +7,14 @@ import { printNode, zodToTs } from 'zod-to-ts'
 
 import * as errors from '@/errors'
 import * as types from '@/types'
-import { BaseTask } from '@/task'
 import { getCompiledTemplate } from '@/template'
-import {
-  Tokenizer,
-  getModelNameForTiktoken,
-  getTokenizerForModel
-} from '@/tokenizer'
+import { getModelNameForTiktoken } from '@/tokenizer'
 import {
   extractJSONArrayFromString,
   extractJSONObjectFromString
 } from '@/utils'
-
-// TODO: TInput should only be allowed to be void or an object
-export abstract class BaseLLM<
-  TInput = void,
-  TOutput = string,
-  TModelParams extends Record<string, any> = Record<string, any>
-> extends BaseTask<TInput, TOutput> {
-  protected _inputSchema: ZodType<TInput> | undefined
-  protected _outputSchema: ZodType<TOutput> | undefined
-
-  protected _provider: string
-  protected _model: string
-  protected _modelParams: TModelParams | undefined
-  protected _examples: types.LLMExample[] | undefined
-  protected _tokenizerP?: Promise<Tokenizer | null>
-
-  constructor(
-    options: SetRequired<
-      types.BaseLLMOptions<TInput, TOutput, TModelParams>,
-      'provider' | 'model'
-    >
-  ) {
-    super(options)
-
-    this._inputSchema = options.inputSchema
-    this._outputSchema = options.outputSchema
-
-    this._provider = options.provider
-    this._model = options.model
-    this._modelParams = options.modelParams
-    this._examples = options.examples
-  }
-
-  // TODO: use polymorphic `this` type to return correct BaseLLM subclass type
-  input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
-    const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
-    refinedInstance._inputSchema = inputSchema
-    return refinedInstance
-  }
-
-  // TODO: use polymorphic `this` type to return correct BaseLLM subclass type
-  output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
-    const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
-    refinedInstance._outputSchema = outputSchema
-    return refinedInstance
-  }
-
-  public override get inputSchema(): ZodType<TInput> {
-    if (this._inputSchema) {
-      return this._inputSchema
-    } else {
-      // TODO: improve typing
-      return z.void() as unknown as ZodType<TInput>
-    }
-  }
-
-  public override get outputSchema(): ZodType<TOutput> {
-    if (this._outputSchema) {
-      return this._outputSchema
-    } else {
-      // TODO: improve typing
-      return z.string() as unknown as ZodType<TOutput>
-    }
-  }
-
-  public override get name(): string {
-    return `${this._provider}:chat:${this._model}`
-  }
-
-  examples(examples: types.LLMExample[]): this {
-    this._examples = examples
-    return this
-  }
-
-  modelParams(params: Partial<TModelParams>): this {
-    // We assume that modelParams does not include nested objects.
-    // If it did, we would need to do a deep merge.
-    this._modelParams = { ...this._modelParams, ...params } as TModelParams
-    return this
-  }
-
-  public async getNumTokens(text: string): Promise<number> {
-    if (!this._tokenizerP) {
-      const model = this._model || 'gpt2'
-
-      this._tokenizerP = getTokenizerForModel(model).catch((err) => {
-        console.warn(
-          `Failed to initialize tokenizer for model "${model}", falling back to approximate count`,
-          err
-        )
-
-        return null
-      })
-    }
-
-    const tokenizer = await this._tokenizerP
-
-    if (tokenizer) {
-      return tokenizer.encode(text).length
-    }
-
-    // fallback to approximate calculation if tokenizer is not available
-    return Math.ceil(text.length / 4)
-  }
-}
+import { BaseLLM } from './llm'
 
 export abstract class BaseChatModel<
   TInput = void,
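
This hunk extracts `BaseLLM` out of `chat.ts` into a sibling `./llm` module (the new file itself isn't shown here). The most reusable piece of the removed class is the `getNumTokens` logic: a lazy, memoized tokenizer load whose failure degrades to a character-count heuristic instead of throwing. Below is a minimal standalone sketch of that pattern; `TokenCounter` and `loadTokenizer` are hypothetical names standing in for the class and `getTokenizerForModel`, not part of the repo's API:

```ts
interface Tokenizer {
  encode(text: string): number[]
}

// Hypothetical loader standing in for getTokenizerForModel.
declare function loadTokenizer(model: string): Promise<Tokenizer>

class TokenCounter {
  protected _tokenizerP?: Promise<Tokenizer | null>

  constructor(protected readonly model: string) {}

  async getNumTokens(text: string): Promise<number> {
    if (!this._tokenizerP) {
      // Cache the in-flight promise so the load happens at most once;
      // a failed load resolves to null rather than rejecting.
      this._tokenizerP = loadTokenizer(this.model).catch((err) => {
        console.warn(
          `Failed to initialize tokenizer for model "${this.model}"; using approximate count`,
          err
        )
        return null
      })
    }

    const tokenizer = await this._tokenizerP
    if (tokenizer) {
      return tokenizer.encode(text).length
    }

    // Rough heuristic: English text averages ~4 characters per token.
    return Math.ceil(text.length / 4)
  }
}
```

Note that caching the promise itself (rather than the resolved tokenizer) also deduplicates concurrent first calls: every caller awaits the same in-flight load.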