kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: WIP add retry logic
rodzic
af1cc845b3
commit
9d54530880
|
@ -52,7 +52,7 @@ async function ExampleLLMQuery({ texts }: { texts: string[] }) {
|
|||
)
|
||||
}
|
||||
|
||||
ExampleLLMQuery({
|
||||
const example = await ExampleLLMQuery({
|
||||
texts: [
|
||||
'I went to this place and it was just so awful.',
|
||||
'I had a great time.',
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
import type { JsonObject } from 'type-fest'
|
||||
import type { ZodError } from 'zod'
|
||||
import { ValidationError, fromZodError } from 'zod-validation-error'
|
||||
|
||||
export type ErrorOptions = {
|
||||
/** HTTP status code for the error. */
|
||||
status?: number
|
||||
|
||||
/** The original error that caused this error. */
|
||||
cause?: unknown
|
||||
|
||||
/** Additional context to be added to the error. */
|
||||
context?: JsonObject
|
||||
}
|
||||
|
||||
export class BaseError extends Error {
|
||||
status?: number
|
||||
context?: JsonObject
|
||||
|
||||
constructor(message: string, opts: ErrorOptions = {}) {
|
||||
if (opts.cause) {
|
||||
super(message, { cause: opts.cause })
|
||||
} else {
|
||||
super(message)
|
||||
}
|
||||
|
||||
// Ensure the name of this error is the same as the class name
|
||||
this.name = this.constructor.name
|
||||
|
||||
// Set stack trace to caller
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
|
||||
// Status is used for Express error handling
|
||||
if (opts.status) {
|
||||
this.status = opts.status
|
||||
}
|
||||
|
||||
// Add additional context to the error
|
||||
if (opts.context) {
|
||||
this.context = opts.context
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An error caused by an OpenAI API call.
|
||||
*/
|
||||
export class OpenAIApiError extends BaseError {
|
||||
constructor(message: string, opts: ErrorOptions = {}) {
|
||||
opts.status = opts.status || 500
|
||||
super(message, opts)
|
||||
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
||||
export class ZodOutputValidationError extends BaseError {
|
||||
validationError: ValidationError
|
||||
|
||||
constructor(zodError: ZodError) {
|
||||
const validationError = fromZodError(zodError)
|
||||
super(validationError.message, { cause: zodError })
|
||||
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
|
||||
this.validationError = validationError
|
||||
}
|
||||
}
|
||||
|
||||
export class OutputValidationError extends BaseError {
|
||||
constructor(message: string, opts: ErrorOptions = {}) {
|
||||
super(message, opts)
|
||||
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
|
@ -1,10 +1,11 @@
|
|||
import { jsonrepair } from 'jsonrepair'
|
||||
import { JSONRepairError, jsonrepair } from 'jsonrepair'
|
||||
import pMap from 'p-map'
|
||||
import { dedent } from 'ts-dedent'
|
||||
import { type SetRequired } from 'type-fest'
|
||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||
import { printNode, zodToTs } from 'zod-to-ts'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
import * as types from '@/types'
|
||||
import { BaseTask } from '@/task'
|
||||
import { getCompiledTemplate } from '@/template'
|
||||
|
@ -237,13 +238,37 @@ export abstract class BaseChatModel<
|
|||
: z.object(this._outputSchema)
|
||||
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
// TODO: gracefully handle parse errors
|
||||
const trimmedOutput = extractJSONArrayFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
try {
|
||||
const trimmedOutput = extractJSONArrayFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodObject) {
|
||||
// TODO: gracefully handle parse errors
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
try {
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||
output = output.toLowerCase().trim()
|
||||
const booleanOutputs = {
|
||||
|
@ -260,8 +285,9 @@ export abstract class BaseChatModel<
|
|||
if (booleanOutput !== undefined) {
|
||||
output = booleanOutput
|
||||
} else {
|
||||
// TODO
|
||||
throw new Error(`invalid boolean output: ${output}`)
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid boolean output: ${output}`
|
||||
)
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodNumber) {
|
||||
output = output.trim()
|
||||
|
@ -271,17 +297,21 @@ export abstract class BaseChatModel<
|
|||
: parseFloat(output)
|
||||
|
||||
if (isNaN(numberOutput)) {
|
||||
// TODO
|
||||
throw new Error(`invalid number output: ${output}`)
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid number output: ${output}`
|
||||
)
|
||||
} else {
|
||||
output = numberOutput
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle errors, retry logic, and self-healing
|
||||
const safeResult = outputSchema.safeParse(output)
|
||||
if (!safeResult.success) {
|
||||
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||
}
|
||||
|
||||
return {
|
||||
result: outputSchema.parse(output),
|
||||
result: safeResult.data,
|
||||
metadata: {
|
||||
input,
|
||||
messages,
|
||||
|
@ -312,6 +342,7 @@ export abstract class BaseChatModel<
|
|||
|
||||
const modelName = getModelNameForTiktoken(this._model)
|
||||
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
|
||||
if (modelName === 'gpt-3.5-turbo') {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1
|
||||
|
|
34
src/task.ts
34
src/task.ts
|
@ -1,5 +1,7 @@
|
|||
import pRetry from 'p-retry'
|
||||
import { ZodRawShape, ZodTypeAny } from 'zod'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
|
||||
|
@ -23,12 +25,15 @@ export abstract class BaseTask<
|
|||
protected _agentic: Agentic
|
||||
|
||||
protected _timeoutMs?: number
|
||||
protected _retryConfig?: types.RetryConfig
|
||||
protected _retryConfig: types.RetryConfig
|
||||
|
||||
constructor(options: types.BaseTaskOptions) {
|
||||
this._agentic = options.agentic
|
||||
this._timeoutMs = options.timeoutMs
|
||||
this._retryConfig = options.retryConfig
|
||||
this._retryConfig = options.retryConfig ?? {
|
||||
retries: 3,
|
||||
strategy: 'default'
|
||||
}
|
||||
}
|
||||
|
||||
public get agentic(): Agentic {
|
||||
|
@ -60,7 +65,30 @@ export abstract class BaseTask<
|
|||
public async callWithMetadata(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
return this._call(input)
|
||||
const metadata: types.TaskResponseMetadata = {
|
||||
input,
|
||||
numRetries: 0
|
||||
}
|
||||
|
||||
do {
|
||||
try {
|
||||
const response = await this._call(input)
|
||||
return response
|
||||
} catch (err: any) {
|
||||
if (err instanceof errors.ZodOutputValidationError) {
|
||||
// TODO
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle errors, retry logic, and self-healing
|
||||
metadata.numRetries = (metadata.numRetries ?? 0) + 1
|
||||
if (metadata.numRetries > this._retryConfig.retries) {
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
} while (true)
|
||||
}
|
||||
|
||||
protected abstract _call(
|
||||
|
|
31
src/types.ts
31
src/types.ts
|
@ -102,10 +102,39 @@ export interface LLMExample {
|
|||
}
|
||||
|
||||
export interface RetryConfig {
|
||||
attempts: number
|
||||
retries: number
|
||||
strategy: string
|
||||
}
|
||||
|
||||
export type TaskError =
|
||||
| 'timeout'
|
||||
| 'provider'
|
||||
| 'validation'
|
||||
| 'unknown'
|
||||
| string
|
||||
|
||||
export interface TaskResponseMetadata extends Record<string, any> {
|
||||
// task info
|
||||
// - task name
|
||||
// - task id
|
||||
|
||||
// config
|
||||
input?: any
|
||||
stream?: boolean
|
||||
|
||||
// execution info
|
||||
success?: boolean
|
||||
numRetries?: number
|
||||
errorType?: TaskError
|
||||
error?: Error
|
||||
}
|
||||
|
||||
export interface LLMTaskResponseMetadata<
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> extends TaskResponseMetadata {
|
||||
completion?: TChatCompletionResponse
|
||||
}
|
||||
|
||||
export interface TaskResponse<
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TMetadata extends Record<string, any> = Record<string, any>
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "es2020",
|
||||
"lib": ["esnext"],
|
||||
"lib": ["esnext", "es2022.error"],
|
||||
"allowJs": true,
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
|
|
Ładowanie…
Reference in New Issue