kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: WIP add retry logic
rodzic
af1cc845b3
commit
9d54530880
|
@ -52,7 +52,7 @@ async function ExampleLLMQuery({ texts }: { texts: string[] }) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
ExampleLLMQuery({
|
const example = await ExampleLLMQuery({
|
||||||
texts: [
|
texts: [
|
||||||
'I went to this place and it was just so awful.',
|
'I went to this place and it was just so awful.',
|
||||||
'I had a great time.',
|
'I had a great time.',
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
import type { JsonObject } from 'type-fest'
|
||||||
|
import type { ZodError } from 'zod'
|
||||||
|
import { ValidationError, fromZodError } from 'zod-validation-error'
|
||||||
|
|
||||||
|
/**
 * Optional settings accepted by the error classes in this module.
 */
export type ErrorOptions = {
  /** HTTP status code for the error. */
  status?: number

  /** The original error (or other thrown value) that caused this error. */
  cause?: unknown

  /** Additional context, as a JSON object, to be attached to the error. */
  context?: JsonObject
}
|
||||||
|
|
||||||
|
/**
 * Base class for the errors defined in this module.
 *
 * Extends the built-in `Error` with an optional HTTP `status` and an optional
 * JSON `context`, and forwards an optional `cause` to the native `Error`
 * constructor.
 */
export class BaseError extends Error {
  /** HTTP status code; only assigned when a truthy `status` option is given. */
  status?: number

  /** Additional context; only assigned when a `context` option is given. */
  context?: JsonObject

  /**
   * @param message - Human-readable error message.
   * @param opts - Optional `status`, `cause` and `context` for the error.
   */
  constructor(message: string, opts: ErrorOptions = {}) {
    // Compare against `undefined` instead of testing truthiness so that falsy
    // causes (e.g. `0`, `''`, `false`, `null`) are still forwarded to `Error`.
    super(
      message,
      opts.cause !== undefined ? { cause: opts.cause } : undefined
    )

    // Ensure the name of this error is the same as the class name
    this.name = this.constructor.name

    // Set stack trace to caller (captureStackTrace is not available on every
    // runtime, hence the optional call)
    Error.captureStackTrace?.(this, this.constructor)

    // Status is used for Express error handling
    if (opts.status) {
      this.status = opts.status
    }

    // Add additional context to the error
    if (opts.context) {
      this.context = opts.context
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * An error caused by an OpenAI API call.
 *
 * When no truthy `status` is supplied, the status defaults to 500.
 */
export class OpenAIApiError extends BaseError {
  /**
   * @param message - Human-readable error message.
   * @param opts - Optional `status`, `cause` and `context` for the error.
   */
  constructor(message: string, opts: ErrorOptions = {}) {
    // Build a fresh options object rather than assigning to `opts.status`, so
    // the object passed in by the caller is never mutated.
    super(message, { ...opts, status: opts.status || 500 })

    Error.captureStackTrace?.(this, this.constructor)
  }
}
|
||||||
|
|
||||||
|
/**
 * Error wrapping a `ZodError`.
 *
 * The message is taken from the `ValidationError` produced by `fromZodError`,
 * and the original `ZodError` is passed along as the `cause`.
 */
export class ZodOutputValidationError extends BaseError {
  /** The `ValidationError` returned by `fromZodError` for the given `ZodError`. */
  validationError: ValidationError

  /**
   * @param zodError - The zod error to wrap.
   */
  constructor(zodError: ZodError) {
    const readableError = fromZodError(zodError)
    super(readableError.message, { cause: zodError })

    this.validationError = readableError

    Error.captureStackTrace?.(this, this.constructor)
  }
}
|
||||||
|
|
||||||
|
/**
 * Validation error carrying a free-form message and the same optional
 * `status`, `cause` and `context` as `BaseError`.
 */
export class OutputValidationError extends BaseError {
  /**
   * @param message - Human-readable description of the validation failure.
   * @param opts - Optional `status`, `cause` and `context` for the error.
   */
  constructor(message: string, opts: ErrorOptions = {}) {
    super(message, opts)
    Error.captureStackTrace?.(this, this.constructor)
  }
}
|
|
@ -1,10 +1,11 @@
|
||||||
import { jsonrepair } from 'jsonrepair'
|
import { JSONRepairError, jsonrepair } from 'jsonrepair'
|
||||||
import pMap from 'p-map'
|
import pMap from 'p-map'
|
||||||
import { dedent } from 'ts-dedent'
|
import { dedent } from 'ts-dedent'
|
||||||
import { type SetRequired } from 'type-fest'
|
import { type SetRequired } from 'type-fest'
|
||||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||||
import { printNode, zodToTs } from 'zod-to-ts'
|
import { printNode, zodToTs } from 'zod-to-ts'
|
||||||
|
|
||||||
|
import * as errors from '@/errors'
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
import { BaseTask } from '@/task'
|
import { BaseTask } from '@/task'
|
||||||
import { getCompiledTemplate } from '@/template'
|
import { getCompiledTemplate } from '@/template'
|
||||||
|
@ -237,13 +238,37 @@ export abstract class BaseChatModel<
|
||||||
: z.object(this._outputSchema)
|
: z.object(this._outputSchema)
|
||||||
|
|
||||||
if (outputSchema instanceof z.ZodArray) {
|
if (outputSchema instanceof z.ZodArray) {
|
||||||
// TODO: gracefully handle parse errors
|
try {
|
||||||
const trimmedOutput = extractJSONArrayFromString(output)
|
const trimmedOutput = extractJSONArrayFromString(output)
|
||||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||||
|
} catch (err: any) {
|
||||||
|
if (err instanceof JSONRepairError) {
|
||||||
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
|
} else if (err instanceof SyntaxError) {
|
||||||
|
throw new errors.OutputValidationError(
|
||||||
|
`Invalid JSON: ${err.message}`,
|
||||||
|
{ cause: err }
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (outputSchema instanceof z.ZodObject) {
|
} else if (outputSchema instanceof z.ZodObject) {
|
||||||
// TODO: gracefully handle parse errors
|
try {
|
||||||
const trimmedOutput = extractJSONObjectFromString(output)
|
const trimmedOutput = extractJSONObjectFromString(output)
|
||||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||||
|
} catch (err: any) {
|
||||||
|
if (err instanceof JSONRepairError) {
|
||||||
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
|
} else if (err instanceof SyntaxError) {
|
||||||
|
throw new errors.OutputValidationError(
|
||||||
|
`Invalid JSON: ${err.message}`,
|
||||||
|
{ cause: err }
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||||
output = output.toLowerCase().trim()
|
output = output.toLowerCase().trim()
|
||||||
const booleanOutputs = {
|
const booleanOutputs = {
|
||||||
|
@ -260,8 +285,9 @@ export abstract class BaseChatModel<
|
||||||
if (booleanOutput !== undefined) {
|
if (booleanOutput !== undefined) {
|
||||||
output = booleanOutput
|
output = booleanOutput
|
||||||
} else {
|
} else {
|
||||||
// TODO
|
throw new errors.OutputValidationError(
|
||||||
throw new Error(`invalid boolean output: ${output}`)
|
`Invalid boolean output: ${output}`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else if (outputSchema instanceof z.ZodNumber) {
|
} else if (outputSchema instanceof z.ZodNumber) {
|
||||||
output = output.trim()
|
output = output.trim()
|
||||||
|
@ -271,17 +297,21 @@ export abstract class BaseChatModel<
|
||||||
: parseFloat(output)
|
: parseFloat(output)
|
||||||
|
|
||||||
if (isNaN(numberOutput)) {
|
if (isNaN(numberOutput)) {
|
||||||
// TODO
|
throw new errors.OutputValidationError(
|
||||||
throw new Error(`invalid number output: ${output}`)
|
`Invalid number output: ${output}`
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
output = numberOutput
|
output = numberOutput
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: handle errors, retry logic, and self-healing
|
const safeResult = outputSchema.safeParse(output)
|
||||||
|
if (!safeResult.success) {
|
||||||
|
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
result: outputSchema.parse(output),
|
result: safeResult.data,
|
||||||
metadata: {
|
metadata: {
|
||||||
input,
|
input,
|
||||||
messages,
|
messages,
|
||||||
|
@ -312,6 +342,7 @@ export abstract class BaseChatModel<
|
||||||
|
|
||||||
const modelName = getModelNameForTiktoken(this._model)
|
const modelName = getModelNameForTiktoken(this._model)
|
||||||
|
|
||||||
|
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
|
||||||
if (modelName === 'gpt-3.5-turbo') {
|
if (modelName === 'gpt-3.5-turbo') {
|
||||||
tokensPerMessage = 4
|
tokensPerMessage = 4
|
||||||
tokensPerName = -1
|
tokensPerName = -1
|
||||||
|
|
34
src/task.ts
34
src/task.ts
|
@ -1,5 +1,7 @@
|
||||||
|
import pRetry from 'p-retry'
|
||||||
import { ZodRawShape, ZodTypeAny } from 'zod'
|
import { ZodRawShape, ZodTypeAny } from 'zod'
|
||||||
|
|
||||||
|
import * as errors from '@/errors'
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
import { Agentic } from '@/agentic'
|
import { Agentic } from '@/agentic'
|
||||||
|
|
||||||
|
@ -23,12 +25,15 @@ export abstract class BaseTask<
|
||||||
protected _agentic: Agentic
|
protected _agentic: Agentic
|
||||||
|
|
||||||
protected _timeoutMs?: number
|
protected _timeoutMs?: number
|
||||||
protected _retryConfig?: types.RetryConfig
|
protected _retryConfig: types.RetryConfig
|
||||||
|
|
||||||
/**
 * Stores the agentic instance, the optional timeout, and the retry config
 * taken from `options`.
 *
 * @param options - Task options; `retryConfig` is optional.
 */
constructor(options: types.BaseTaskOptions) {
  const { agentic, timeoutMs, retryConfig } = options

  this._agentic = agentic
  this._timeoutMs = timeoutMs

  // Fall back to a default retry config when the caller provides none
  this._retryConfig = retryConfig ?? { retries: 3, strategy: 'default' }
}
|
||||||
|
|
||||||
public get agentic(): Agentic {
|
public get agentic(): Agentic {
|
||||||
|
@ -60,7 +65,30 @@ export abstract class BaseTask<
|
||||||
/**
 * Invokes `_call` with the given input, retrying when it fails with a
 * `ZodOutputValidationError`.
 *
 * Any other error is rethrown immediately. Once the number of retries exceeds
 * `this._retryConfig.retries`, the most recent validation error is rethrown
 * instead of looping forever.
 *
 * @param input - Input forwarded unchanged to `_call` on every attempt.
 * @returns The response of the first `_call` that succeeds.
 * @throws The last `ZodOutputValidationError` once retries are exhausted, or
 * any non-validation error raised by `_call`.
 */
public async callWithMetadata(
  input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> {
  const metadata: types.TaskResponseMetadata = {
    input,
    numRetries: 0
  }

  // eslint-disable-next-line no-constant-condition
  while (true) {
    try {
      return await this._call(input)
    } catch (err: unknown) {
      // Only output-validation failures are retried
      if (!(err instanceof errors.ZodOutputValidationError)) {
        throw err
      }

      // TODO: handle errors, retry logic, and self-healing
      metadata.numRetries = (metadata.numRetries ?? 0) + 1

      // Give up once the retry budget is spent; surface the last failure
      if (metadata.numRetries > this._retryConfig.retries) {
        throw err
      }
    }
  }
}
|
||||||
|
|
||||||
protected abstract _call(
|
protected abstract _call(
|
||||||
|
|
31
src/types.ts
31
src/types.ts
|
@ -102,10 +102,39 @@ export interface LLMExample {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Retry settings for a task.
 */
export interface RetryConfig {
  /** Number of retries allowed. */
  retries: number

  /** Name of the retry strategy. */
  strategy: string
}
|
||||||
|
|
||||||
|
/**
 * Category of a task failure.
 *
 * The named literals list the known categories; the trailing `string` member
 * means any other string is accepted as well.
 */
export type TaskError = 'timeout' | 'provider' | 'validation' | 'unknown' | string
|
||||||
|
|
||||||
|
/**
 * Metadata describing a task invocation.
 *
 * All declared fields are optional, and arbitrary extra keys are permitted
 * through the `Record<string, any>` base.
 */
export interface TaskResponseMetadata extends Record<string, any> {
  // task info
  // - task name
  // - task id

  // config

  /** Input the task was called with. */
  input?: any
  stream?: boolean

  // execution info

  success?: boolean
  /** Number of retries performed so far. */
  numRetries?: number
  errorType?: TaskError
  error?: Error
}
|
||||||
|
|
||||||
|
/**
 * Task response metadata extended with an optional `completion` field.
 *
 * @typeParam TChatCompletionResponse - Type of the `completion` value.
 */
export interface LLMTaskResponseMetadata<
  TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends TaskResponseMetadata {
  /** Chat completion response, when present. */
  completion?: TChatCompletionResponse
}
|
||||||
|
|
||||||
export interface TaskResponse<
|
export interface TaskResponse<
|
||||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||||
TMetadata extends Record<string, any> = Record<string, any>
|
TMetadata extends Record<string, any> = Record<string, any>
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"compilerOptions": {
|
"compilerOptions": {
|
||||||
"target": "es2020",
|
"target": "es2020",
|
||||||
"lib": ["esnext"],
|
"lib": ["esnext", "es2022.error"],
|
||||||
"allowJs": true,
|
"allowJs": true,
|
||||||
"skipLibCheck": true,
|
"skipLibCheck": true,
|
||||||
"strict": true,
|
"strict": true,
|
||||||
|
|
Ładowanie…
Reference in New Issue