diff --git a/scratch/declarative-design-jsx-0.tsx b/scratch/declarative-design-jsx-0.tsx index e4ff702..96cdc3c 100644 --- a/scratch/declarative-design-jsx-0.tsx +++ b/scratch/declarative-design-jsx-0.tsx @@ -52,7 +52,7 @@ async function ExampleLLMQuery({ texts }: { texts: string[] }) { ) } -ExampleLLMQuery({ +const example = await ExampleLLMQuery({ texts: [ 'I went to this place and it was just so awful.', 'I had a great time.', diff --git a/src/errors.ts b/src/errors.ts new file mode 100644 index 0000000..1bb71c6 --- /dev/null +++ b/src/errors.ts @@ -0,0 +1,76 @@ +import type { JsonObject } from 'type-fest' +import type { ZodError } from 'zod' +import { ValidationError, fromZodError } from 'zod-validation-error' + +export type ErrorOptions = { + /** HTTP status code for the error. */ + status?: number + + /** The original error that caused this error. */ + cause?: unknown + + /** Additional context to be added to the error. */ + context?: JsonObject +} + +export class BaseError extends Error { + status?: number + context?: JsonObject + + constructor(message: string, opts: ErrorOptions = {}) { + if (opts.cause) { + super(message, { cause: opts.cause }) + } else { + super(message) + } + + // Ensure the name of this error is the same as the class name + this.name = this.constructor.name + + // Set stack trace to caller + Error.captureStackTrace?.(this, this.constructor) + + // Status is used for Express error handling + if (opts.status) { + this.status = opts.status + } + + // Add additional context to the error + if (opts.context) { + this.context = opts.context + } + } +} + +/** + * An error caused by an OpenAI API call. + */ +export class OpenAIApiError extends BaseError { + constructor(message: string, opts: ErrorOptions = {}) { + opts.status = opts.status || 500 + super(message, opts) + + Error.captureStackTrace?.(this, this.constructor) + } +} + +export class ZodOutputValidationError extends BaseError { + validationError: ValidationError + + constructor(zodError: ZodError) { + const validationError = fromZodError(zodError) + super(validationError.message, { cause: zodError }) + + Error.captureStackTrace?.(this, this.constructor) + + this.validationError = validationError + } +} + +export class OutputValidationError extends BaseError { + constructor(message: string, opts: ErrorOptions = {}) { + super(message, opts) + + Error.captureStackTrace?.(this, this.constructor) + } +} diff --git a/src/llms/llm.ts b/src/llms/llm.ts index 2882ac6..def0c62 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -1,10 +1,11 @@ -import { jsonrepair } from 'jsonrepair' +import { JSONRepairError, jsonrepair } from 'jsonrepair' import pMap from 'p-map' import { dedent } from 'ts-dedent' import { type SetRequired } from 'type-fest' import { ZodRawShape, ZodTypeAny, z } from 'zod' import { printNode, zodToTs } from 'zod-to-ts' +import * as errors from '@/errors' import * as types from '@/types' import { BaseTask } from '@/task' import { getCompiledTemplate } from '@/template' @@ -237,13 +238,37 @@ export abstract class BaseChatModel< : z.object(this._outputSchema) if (outputSchema instanceof z.ZodArray) { - // TODO: gracefully handle parse errors - const trimmedOutput = extractJSONArrayFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + try { + const trimmedOutput = extractJSONArrayFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } } else if (outputSchema instanceof z.ZodObject) { - // TODO: gracefully handle parse errors - const trimmedOutput = extractJSONObjectFromString(output) - output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + try { + const trimmedOutput = extractJSONObjectFromString(output) + output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } } else if (outputSchema instanceof z.ZodBoolean) { output = output.toLowerCase().trim() const booleanOutputs = { @@ -260,8 +285,9 @@ export abstract class BaseChatModel< if (booleanOutput !== undefined) { output = booleanOutput } else { - // TODO - throw new Error(`invalid boolean output: ${output}`) + throw new errors.OutputValidationError( + `Invalid boolean output: ${output}` + ) } } else if (outputSchema instanceof z.ZodNumber) { output = output.trim() @@ -271,17 +297,21 @@ export abstract class BaseChatModel< : parseFloat(output) if (isNaN(numberOutput)) { - // TODO - throw new Error(`invalid number output: ${output}`) + throw new errors.OutputValidationError( + `Invalid number output: ${output}` + ) } else { output = numberOutput } } - // TODO: handle errors, retry logic, and self-healing + const safeResult = outputSchema.safeParse(output) + if (!safeResult.success) { + throw new errors.ZodOutputValidationError(safeResult.error) + } return { - result: outputSchema.parse(output), + result: safeResult.data, metadata: { input, messages, @@ -312,6 +342,7 @@ export abstract class BaseChatModel< const modelName = getModelNameForTiktoken(this._model) + // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb if (modelName === 'gpt-3.5-turbo') { tokensPerMessage = 4 tokensPerName = -1 diff --git a/src/task.ts b/src/task.ts index 38e5d1a..83f6686 100644 --- a/src/task.ts +++ b/src/task.ts @@ -1,5 +1,7 @@ +import pRetry from 'p-retry' import { ZodRawShape, ZodTypeAny } from 'zod' +import * as errors from '@/errors' import * as types from '@/types' import { Agentic } from '@/agentic' @@ -23,12 +25,15 @@ export abstract class BaseTask< protected _agentic: Agentic protected _timeoutMs?: number - protected _retryConfig?: types.RetryConfig + protected _retryConfig: types.RetryConfig constructor(options: types.BaseTaskOptions) { this._agentic = options.agentic this._timeoutMs = options.timeoutMs - this._retryConfig = options.retryConfig + this._retryConfig = options.retryConfig ?? { + retries: 3, + strategy: 'default' + } } public get agentic(): Agentic { @@ -60,7 +65,30 @@ export abstract class BaseTask< public async callWithMetadata( input?: types.ParsedData ): Promise> { - return this._call(input) + const metadata: types.TaskResponseMetadata = { + input, + numRetries: 0 + } + + do { + try { + const response = await this._call(input) + return response + } catch (err: any) { + if (err instanceof errors.ZodOutputValidationError) { + // TODO + } else { + throw err + } + } + + // TODO: handle errors, retry logic, and self-healing + metadata.numRetries = (metadata.numRetries ?? 0) + 1 + if (metadata.numRetries > this._retryConfig.retries) { + } + + // eslint-disable-next-line no-constant-condition + } while (true) } protected abstract _call( diff --git a/src/types.ts b/src/types.ts index 7c8dde9..efb9596 100644 --- a/src/types.ts +++ b/src/types.ts @@ -102,10 +102,39 @@ export interface LLMExample { } export interface RetryConfig { - attempts: number + retries: number strategy: string } +export type TaskError = + | 'timeout' + | 'provider' + | 'validation' + | 'unknown' + | string + +export interface TaskResponseMetadata extends Record { + // task info + // - task name + // - task id + + // config + input?: any + stream?: boolean + + // execution info + success?: boolean + numRetries?: number + errorType?: TaskError + error?: Error +} + +export interface LLMTaskResponseMetadata< + TChatCompletionResponse extends Record = Record +> extends TaskResponseMetadata { + completion?: TChatCompletionResponse +} + export interface TaskResponse< TOutput extends ZodRawShape | ZodTypeAny = z.ZodType, TMetadata extends Record = Record diff --git a/tsconfig.json b/tsconfig.json index f18422b..d3b520b 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,7 +1,7 @@ { "compilerOptions": { "target": "es2020", - "lib": ["esnext"], + "lib": ["esnext", "es2022.error"], "allowJs": true, "skipLibCheck": true, "strict": true,