feat: WIP add retry logic

old-agentic-v1^2
Travis Fischer 2023-06-09 10:44:02 -07:00
rodzic af1cc845b3
commit 9d54530880
6 zmienionych plików z 183 dodań i 19 usunięć

Wyświetl plik

@ -52,7 +52,7 @@ async function ExampleLLMQuery({ texts }: { texts: string[] }) {
)
}
ExampleLLMQuery({
const example = await ExampleLLMQuery({
texts: [
'I went to this place and it was just so awful.',
'I had a great time.',

76
src/errors.ts 100644
Wyświetl plik

@ -0,0 +1,76 @@
import type { JsonObject } from 'type-fest'
import type { ZodError } from 'zod'
import { ValidationError, fromZodError } from 'zod-validation-error'
export type ErrorOptions = {
/** HTTP status code for the error. */
status?: number
/** The original error that caused this error. */
cause?: unknown
/** Additional context to be added to the error. */
context?: JsonObject
}
export class BaseError extends Error {
status?: number
context?: JsonObject
constructor(message: string, opts: ErrorOptions = {}) {
if (opts.cause) {
super(message, { cause: opts.cause })
} else {
super(message)
}
// Ensure the name of this error is the same as the class name
this.name = this.constructor.name
// Set stack trace to caller
Error.captureStackTrace?.(this, this.constructor)
// Status is used for Express error handling
if (opts.status) {
this.status = opts.status
}
// Add additional context to the error
if (opts.context) {
this.context = opts.context
}
}
}
/**
* An error caused by an OpenAI API call.
*/
export class OpenAIApiError extends BaseError {
constructor(message: string, opts: ErrorOptions = {}) {
opts.status = opts.status || 500
super(message, opts)
Error.captureStackTrace?.(this, this.constructor)
}
}
export class ZodOutputValidationError extends BaseError {
validationError: ValidationError
constructor(zodError: ZodError) {
const validationError = fromZodError(zodError)
super(validationError.message, { cause: zodError })
Error.captureStackTrace?.(this, this.constructor)
this.validationError = validationError
}
}
export class OutputValidationError extends BaseError {
constructor(message: string, opts: ErrorOptions = {}) {
super(message, opts)
Error.captureStackTrace?.(this, this.constructor)
}
}

Wyświetl plik

@ -1,10 +1,11 @@
import { jsonrepair } from 'jsonrepair'
import { JSONRepairError, jsonrepair } from 'jsonrepair'
import pMap from 'p-map'
import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest'
import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors'
import * as types from '@/types'
import { BaseTask } from '@/task'
import { getCompiledTemplate } from '@/template'
@ -237,13 +238,37 @@ export abstract class BaseChatModel<
: z.object(this._outputSchema)
if (outputSchema instanceof z.ZodArray) {
// TODO: gracefully handle parse errors
const trimmedOutput = extractJSONArrayFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
try {
const trimmedOutput = extractJSONArrayFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON: ${err.message}`,
{ cause: err }
)
} else {
throw err
}
}
} else if (outputSchema instanceof z.ZodObject) {
// TODO: gracefully handle parse errors
const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
try {
const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON: ${err.message}`,
{ cause: err }
)
} else {
throw err
}
}
} else if (outputSchema instanceof z.ZodBoolean) {
output = output.toLowerCase().trim()
const booleanOutputs = {
@ -260,8 +285,9 @@ export abstract class BaseChatModel<
if (booleanOutput !== undefined) {
output = booleanOutput
} else {
// TODO
throw new Error(`invalid boolean output: ${output}`)
throw new errors.OutputValidationError(
`Invalid boolean output: ${output}`
)
}
} else if (outputSchema instanceof z.ZodNumber) {
output = output.trim()
@ -271,17 +297,21 @@ export abstract class BaseChatModel<
: parseFloat(output)
if (isNaN(numberOutput)) {
// TODO
throw new Error(`invalid number output: ${output}`)
throw new errors.OutputValidationError(
`Invalid number output: ${output}`
)
} else {
output = numberOutput
}
}
// TODO: handle errors, retry logic, and self-healing
const safeResult = outputSchema.safeParse(output)
if (!safeResult.success) {
throw new errors.ZodOutputValidationError(safeResult.error)
}
return {
result: outputSchema.parse(output),
result: safeResult.data,
metadata: {
input,
messages,
@ -312,6 +342,7 @@ export abstract class BaseChatModel<
const modelName = getModelNameForTiktoken(this._model)
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
if (modelName === 'gpt-3.5-turbo') {
tokensPerMessage = 4
tokensPerName = -1

Wyświetl plik

@ -1,5 +1,7 @@
import pRetry from 'p-retry'
import { ZodRawShape, ZodTypeAny } from 'zod'
import * as errors from '@/errors'
import * as types from '@/types'
import { Agentic } from '@/agentic'
@ -23,12 +25,15 @@ export abstract class BaseTask<
protected _agentic: Agentic
protected _timeoutMs?: number
protected _retryConfig?: types.RetryConfig
protected _retryConfig: types.RetryConfig
constructor(options: types.BaseTaskOptions) {
this._agentic = options.agentic
this._timeoutMs = options.timeoutMs
this._retryConfig = options.retryConfig
this._retryConfig = options.retryConfig ?? {
retries: 3,
strategy: 'default'
}
}
public get agentic(): Agentic {
@ -60,7 +65,30 @@ export abstract class BaseTask<
public async callWithMetadata(
input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> {
return this._call(input)
const metadata: types.TaskResponseMetadata = {
input,
numRetries: 0
}
do {
try {
const response = await this._call(input)
return response
} catch (err: any) {
if (err instanceof errors.ZodOutputValidationError) {
// TODO
} else {
throw err
}
}
// TODO: handle errors, retry logic, and self-healing
metadata.numRetries = (metadata.numRetries ?? 0) + 1
if (metadata.numRetries > this._retryConfig.retries) {
}
// eslint-disable-next-line no-constant-condition
} while (true)
}
protected abstract _call(

Wyświetl plik

@ -102,10 +102,39 @@ export interface LLMExample {
}
export interface RetryConfig {
attempts: number
retries: number
strategy: string
}
export type TaskError =
| 'timeout'
| 'provider'
| 'validation'
| 'unknown'
| string
export interface TaskResponseMetadata extends Record<string, any> {
// task info
// - task name
// - task id
// config
input?: any
stream?: boolean
// execution info
success?: boolean
numRetries?: number
errorType?: TaskError
error?: Error
}
export interface LLMTaskResponseMetadata<
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends TaskResponseMetadata {
completion?: TChatCompletionResponse
}
export interface TaskResponse<
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TMetadata extends Record<string, any> = Record<string, any>

Wyświetl plik

@ -1,7 +1,7 @@
{
"compilerOptions": {
"target": "es2020",
"lib": ["esnext"],
"lib": ["esnext", "es2022.error"],
"allowJs": true,
"skipLibCheck": true,
"strict": true,