From 7596404ccc0d0edcc03f66bf258d18f1026d41f6 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Fri, 9 Jun 2023 22:42:20 -0700 Subject: [PATCH] =?UTF-8?q?=E2=98=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llms/llm.ts | 42 ++++++++++++++++++++---------------------- src/task.ts | 41 +++++++++++++++++++++++------------------ src/types.ts | 19 ++++++++++++++++--- 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/src/llms/llm.ts b/src/llms/llm.ts index def0c62d..e127e646 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -143,7 +143,10 @@ export abstract class BaseChatModel< messages: types.ChatMessage[] ): Promise> - public async buildMessages(input?: types.ParsedData) { + public async buildMessages( + input?: types.ParsedData, + ctx?: types.TaskCallContext + ) { if (this._inputSchema) { const inputSchema = this._inputSchema instanceof z.ZodType @@ -211,20 +214,29 @@ export abstract class BaseChatModel< } } + if (ctx?.retryMessage) { + messages.push({ + role: 'system', + content: ctx.retryMessage + }) + } + // TODO: filter/compress messages based on token counts return messages } protected override async _call( - input?: types.ParsedData - ): Promise> { - const messages = await this.buildMessages(input) + ctx: types.TaskCallContext + ): Promise> { + const messages = await this.buildMessages(ctx.input, ctx) console.log('>>>') console.log(messages) const completion = await this._createChatCompletion(messages) + ctx.metadata.completion = completion + let output: any = completion.message.content console.log('===') @@ -246,7 +258,7 @@ export abstract class BaseChatModel< throw new errors.OutputValidationError(err.message, { cause: err }) } else if (err instanceof SyntaxError) { throw new errors.OutputValidationError( - `Invalid JSON: ${err.message}`, + `Invalid JSON array: ${err.message}`, { cause: err } ) } else { @@ -262,7 +274,7 @@ export abstract class BaseChatModel< throw new errors.OutputValidationError(err.message, { cause: err }) } else if (err instanceof SyntaxError) { throw new errors.OutputValidationError( - `Invalid JSON: ${err.message}`, + `Invalid JSON object: ${err.message}`, { cause: err } ) } else { @@ -310,23 +322,9 @@ export abstract class BaseChatModel< throw new errors.ZodOutputValidationError(safeResult.error) } - return { - result: safeResult.data, - metadata: { - input, - messages, - completion - } - } + return safeResult.data } else { - return { - result: output, - metadata: { - input, - messages, - completion - } - } + return output } } diff --git a/src/task.ts b/src/task.ts index 83f66867..f3274f6d 100644 --- a/src/task.ts +++ b/src/task.ts @@ -1,4 +1,4 @@ -import pRetry from 'p-retry' +import pRetry, { FailedAttemptError } from 'p-retry' import { ZodRawShape, ZodTypeAny } from 'zod' import * as errors from '@/errors' @@ -65,35 +65,40 @@ export abstract class BaseTask< public async callWithMetadata( input?: types.ParsedData ): Promise> { - const metadata: types.TaskResponseMetadata = { + const ctx: types.TaskCallContext = { input, - numRetries: 0 + attemptNumber: 0, + metadata: {} } - do { - try { - const response = await this._call(input) - return response - } catch (err: any) { + const result = await pRetry(() => this._call(ctx), { + ...this._retryConfig, + onFailedAttempt: async (err: FailedAttemptError) => { + if (this._retryConfig.onFailedAttempt) { + await Promise.resolve(this._retryConfig.onFailedAttempt(err)) + } + + ctx.attemptNumber = err.attemptNumber + 1 + if (err instanceof errors.ZodOutputValidationError) { - // TODO + ctx.retryMessage = err.message + } else if (err instanceof errors.OutputValidationError) { + ctx.retryMessage = err.message } else { throw err } } + }) - // TODO: handle errors, retry logic, and self-healing - metadata.numRetries = (metadata.numRetries ?? 0) + 1 - if (metadata.numRetries > this._retryConfig.retries) { - } - - // eslint-disable-next-line no-constant-condition - } while (true) + return { + result, + metadata: ctx.metadata + } } protected abstract _call( - input?: types.ParsedData - ): Promise> + ctx: types.TaskCallContext + ): Promise> // TODO // abstract stream({ diff --git a/src/types.ts b/src/types.ts index efb9596e..22def8a4 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,5 +1,6 @@ import * as anthropic from '@anthropic-ai/sdk' import * as openai from 'openai-fetch' +import type { Options as RetryOptions } from 'p-retry' import { SafeParseReturnType, ZodObject, @@ -101,8 +102,7 @@ export interface LLMExample { output: string } -export interface RetryConfig { - retries: number +export interface RetryConfig extends RetryOptions { strategy: string } @@ -132,15 +132,28 @@ export interface TaskResponseMetadata extends Record { export interface LLMTaskResponseMetadata< TChatCompletionResponse extends Record = Record > extends TaskResponseMetadata { + messages?: ChatMessage[] completion?: TChatCompletionResponse } export interface TaskResponse< TOutput extends ZodRawShape | ZodTypeAny = z.ZodType, - TMetadata extends Record = Record + TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { result: ParsedData metadata: TMetadata } +export interface TaskCallContext< + TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, + TOutput extends ZodRawShape | ZodTypeAny = z.ZodType, + TMetadata extends TaskResponseMetadata = TaskResponseMetadata +> { + input?: ParsedData + retryMessage?: string + + attemptNumber: number + metadata: Partial +} + // export type ProgressFunction = (partialResponse: ChatMessage) => void