From 4174d7cdbf210367eec918e2327c57c30c50eaab Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Fri, 26 May 2023 12:16:13 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm.ts | 25 ++++++++++++++----------- src/openai.ts | 30 +++++++++++++++--------------- src/task.ts | 16 +++++++++++----- src/types.ts | 5 +++-- 4 files changed, 43 insertions(+), 33 deletions(-) diff --git a/src/llm.ts b/src/llm.ts index f62b80f..37c9691 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -8,12 +8,18 @@ export abstract class BaseLLMCallBuilder< TOutput extends ZodRawShape | ZodTypeAny = z.ZodType, TModelParams extends Record = Record > extends BaseTaskCallBuilder { - protected _options: types.BaseLLMOptions + protected _provider: string + protected _model: string + protected _modelParams: TModelParams + protected _examples: types.LLMExample[] constructor(options: types.BaseLLMOptions) { super(options) - this._options = options + this._provider = options.provider + this._model = options.model + this._modelParams = options.modelParams + this._examples = options.examples } override input( @@ -21,7 +27,7 @@ export abstract class BaseLLMCallBuilder< ): BaseLLMCallBuilder { ;( this as unknown as BaseLLMCallBuilder - )._options.input = inputSchema + )._inputSchema = inputSchema return this as unknown as BaseLLMCallBuilder } @@ -30,22 +36,19 @@ export abstract class BaseLLMCallBuilder< ): BaseLLMCallBuilder { ;( this as unknown as BaseLLMCallBuilder - )._options.output = outputSchema + )._outputSchema = outputSchema return this as unknown as BaseLLMCallBuilder } examples(examples: types.LLMExample[]) { - this._options.examples = examples + this._examples = examples return this } modelParams(params: Partial) { - // We assume that modelParams does not include nested objects; if it did, we would need to do a deep merge... - this._options.modelParams = Object.assign( - {}, - this._options.modelParams, - params - ) + // We assume that modelParams does not include nested objects. + // If it did, we would need to do a deep merge. + this._modelParams = Object.assign({}, this._modelParams, params) return this } diff --git a/src/openai.ts b/src/openai.ts index f817f37..9e30b5a 100644 --- a/src/openai.ts +++ b/src/openai.ts @@ -43,11 +43,11 @@ export class OpenAIChatModelBuilder< override async call( input?: types.ParsedData ): Promise> { - if (this._options.input) { + if (this._inputSchema) { const inputSchema = - this._options.input instanceof z.ZodType - ? this._options.input - : z.object(this._options.input) + this._inputSchema instanceof z.ZodType + ? this._inputSchema + : z.object(this._inputSchema) // TODO: handle errors gracefully input = inputSchema.parse(input) @@ -66,9 +66,9 @@ export class OpenAIChatModelBuilder< }) .filter((message) => message.content) - if (this._options.examples?.length) { + if (this._examples?.length) { // TODO: smarter example selection - for (const example of this._options.examples) { + for (const example of this._examples) { messages.push({ role: 'system', content: `Example input: ${example.input}\n\nExample output: ${example.output}` @@ -76,11 +76,11 @@ export class OpenAIChatModelBuilder< } } - if (this._options.output) { + if (this._outputSchema) { const outputSchema = - this._options.output instanceof z.ZodType - ? this._options.output - : z.object(this._options.output) + this._outputSchema instanceof z.ZodType + ? this._outputSchema + : z.object(this._outputSchema) const { node } = zodToTs(outputSchema) @@ -116,15 +116,15 @@ export class OpenAIChatModelBuilder< console.log(messages) const completion = await this._client.createChatCompletion({ model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining - ...this._options.modelParams, + ...this._outputSchema, messages }) - if (this._options.output) { + if (this._outputSchema) { const outputSchema = - this._options.output instanceof z.ZodType - ? this._options.output - : z.object(this._options.output) + this._outputSchema instanceof z.ZodType + ? this._outputSchema + : z.object(this._outputSchema) let output: any = completion.message.content console.log('===') diff --git a/src/task.ts b/src/task.ts index 00a7e1b..8af8fc9 100644 --- a/src/task.ts +++ b/src/task.ts @@ -6,16 +6,22 @@ export abstract class BaseTaskCallBuilder< TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TOutput extends ZodRawShape | ZodTypeAny = z.ZodTypeAny > { - protected _options: types.BaseTaskOptions + protected _inputSchema: TInput + protected _outputSchema: TOutput + protected _timeoutMs: number + protected _retryConfig: types.RetryConfig constructor(options: types.BaseTaskOptions) { - this._options = options + this._inputSchema = options.inputSchema + this._outputSchema = options.outputSchema + this._timeoutMs = options.timeoutMs + this._retryConfig = options.retryConfig } input( inputSchema: U ): BaseTaskCallBuilder { - ;(this as unknown as BaseTaskCallBuilder)._options.input = + ;(this as unknown as BaseTaskCallBuilder)._inputSchema = inputSchema return this as unknown as BaseTaskCallBuilder } @@ -23,13 +29,13 @@ export abstract class BaseTaskCallBuilder< output( outputSchema: U ): BaseTaskCallBuilder { - ;(this as unknown as BaseTaskCallBuilder)._options.output = + ;(this as unknown as BaseTaskCallBuilder)._outputSchema = outputSchema return this as unknown as BaseTaskCallBuilder } retry(retryConfig: types.RetryConfig) { - this._options.retryConfig = retryConfig + this._retryConfig = retryConfig return this } diff --git a/src/types.ts b/src/types.ts index ef74a8f..54614f3 100644 --- a/src/types.ts +++ b/src/types.ts @@ -28,8 +28,8 @@ export interface BaseTaskOptions< TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TOutput extends ZodRawShape | ZodTypeAny = z.ZodType > { - input?: TInput - output?: TOutput + inputSchema?: TInput + outputSchema?: TOutput timeoutMs?: number retryConfig?: RetryConfig @@ -37,6 +37,7 @@ export interface BaseTaskOptions< // TODO // caching config // logging config + // reference to agentic context } export interface BaseLLMOptions<