kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add pre- and post-hooks and cause retries in feedback
rodzic
9d3a7bc05f
commit
5875a8858f
|
@ -1,51 +0,0 @@
|
|||
/**
|
||||
|
||||
export type Metadata = Record<string, unknown>;
|
||||
|
||||
|
||||
export abstract class BaseTask<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = ZodTypeAny
|
||||
> {
|
||||
|
||||
// ...
|
||||
|
||||
private _preHooks: ((input?: types.ParsedData<TInput>) => void | Promise<void>, metadata: types.Metadata)[] = [];
|
||||
private _postHooks: ((result: types.ParsedData<TOutput>, metadata: types.Metadata) => void | Promise<void>)[] = [];
|
||||
|
||||
|
||||
public registerPreHook(hook: (input?: types.ParsedData<TInput>) => void | Promise<void>): this {
|
||||
this._preHooks.push(hook);
|
||||
return this;
|
||||
}
|
||||
|
||||
public registerPostHook(hook: (result: types.ParsedData<TOutput>) => void | Promise<void>): this {
|
||||
this._postHooks.push(hook);
|
||||
return this;
|
||||
}
|
||||
|
||||
public async callWithMetadata(
|
||||
input?: types.ParsedData<TInput>,
|
||||
options: { dryRun?: boolean } = {}
|
||||
): Promise<{result: types.ParsedData<TOutput> | undefined, metadata: types.Metadata}> {
|
||||
const metadata: types.Metadata = {};
|
||||
|
||||
if (options.dryRun) {
|
||||
return console.log( '// TODO: implement' )
|
||||
}
|
||||
|
||||
for (const hook of this._preHooks) {
|
||||
await hook(input);
|
||||
}
|
||||
|
||||
const result = await this._call(input);
|
||||
|
||||
for (const hook of this._postHooks) {
|
||||
await hook(result, metadata);
|
||||
}
|
||||
|
||||
return {result, metadata};
|
||||
}
|
||||
}
|
||||
|
||||
**/
|
|
@ -82,3 +82,14 @@ export class TemplateValidationError extends BaseError {
|
|||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An error caused by the user declining an output.
|
||||
*/
|
||||
export class HumanFeedbackDeclineError extends BaseError {
|
||||
constructor(message: string, opts: ErrorOptions = {}) {
|
||||
super(message, opts)
|
||||
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
import { HumanFeedbackDeclineError } from '@/errors'
|
||||
import { BaseTask } from '@/task'
|
||||
|
||||
import { HumanFeedbackMechanismCLI } from './cli'
|
||||
|
@ -268,6 +269,26 @@ export abstract class HumanFeedbackMechanism<
|
|||
}
|
||||
}
|
||||
|
||||
if (
|
||||
(Object.hasOwnProperty.call(feedback, 'accepted') &&
|
||||
feedback.accepted === false) ||
|
||||
(Object.hasOwnProperty.call(feedback, 'selected') &&
|
||||
feedback.selected.length === 0)
|
||||
) {
|
||||
const errorMsg = [
|
||||
'The output was declined by the human reviewer.',
|
||||
'Output:',
|
||||
'```',
|
||||
stringified,
|
||||
'```',
|
||||
'',
|
||||
'Please try again and return different output.'
|
||||
].join('\n')
|
||||
throw new HumanFeedbackDeclineError(errorMsg, {
|
||||
context: feedback
|
||||
})
|
||||
}
|
||||
|
||||
return feedback as FeedbackTypeToMetadata<T>
|
||||
}
|
||||
}
|
||||
|
@ -309,17 +330,10 @@ export function withHumanFeedback<
|
|||
options: finalOptions
|
||||
})
|
||||
|
||||
const originalCall = task.callWithMetadata.bind(task)
|
||||
|
||||
task.callWithMetadata = async function (input?: TInput) {
|
||||
const response = await originalCall(input)
|
||||
|
||||
const feedback = await feedbackMechanism.interact(response.result)
|
||||
|
||||
response.metadata = { ...response.metadata, feedback }
|
||||
|
||||
return response
|
||||
}
|
||||
task.addAfterCallHook(async function onCall(output, ctx) {
|
||||
const feedback = await feedbackMechanism.interact(output)
|
||||
ctx.metadata = { ...ctx.metadata, feedback }
|
||||
})
|
||||
|
||||
return task
|
||||
}
|
||||
|
|
77
src/task.ts
77
src/task.ts
|
@ -29,6 +29,11 @@ export abstract class BaseTask<
|
|||
protected _timeoutMs?: number
|
||||
protected _retryConfig: types.RetryConfig
|
||||
|
||||
private _preHooks: Array<(ctx: types.TaskCallContext<TInput>) => void> = []
|
||||
private _postHooks: Array<
|
||||
(output: TOutput, ctx: types.TaskCallContext<TInput>) => void
|
||||
> = []
|
||||
|
||||
constructor(options: types.BaseTaskOptions = {}) {
|
||||
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
|
||||
|
||||
|
@ -74,6 +79,20 @@ export abstract class BaseTask<
|
|||
return ''
|
||||
}
|
||||
|
||||
public addBeforeCallHook(
|
||||
hook: (ctx: types.TaskCallContext<TInput>) => Promise<void>
|
||||
): this {
|
||||
this._preHooks.push(hook)
|
||||
return this
|
||||
}
|
||||
|
||||
public addAfterCallHook(
|
||||
hook: (output: TOutput, ctx: types.TaskCallContext<TInput>) => Promise<void>
|
||||
): this {
|
||||
this._postHooks.push(hook)
|
||||
return this
|
||||
}
|
||||
|
||||
public validate() {
|
||||
if (!this._agentic) {
|
||||
throw new Error(
|
||||
|
@ -136,33 +155,49 @@ export abstract class BaseTask<
|
|||
}
|
||||
}
|
||||
|
||||
const result = await pRetry(() => this._call(ctx), {
|
||||
...this._retryConfig,
|
||||
onFailedAttempt: async (err: FailedAttemptError) => {
|
||||
this._logger.warn(
|
||||
err,
|
||||
`Task error "${this.nameForHuman}" failed attempt ${
|
||||
err.attemptNumber
|
||||
}${input ? ': ' + JSON.stringify(input) : ''}`
|
||||
)
|
||||
for (const hook of this._preHooks) {
|
||||
await hook(ctx)
|
||||
}
|
||||
|
||||
if (this._retryConfig.onFailedAttempt) {
|
||||
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
|
||||
const result = await pRetry(
|
||||
async () => {
|
||||
const result = await this._call(ctx)
|
||||
for (const hook of this._postHooks) {
|
||||
await hook(result, ctx)
|
||||
}
|
||||
|
||||
// TODO: log this task error
|
||||
ctx.attemptNumber = err.attemptNumber + 1
|
||||
ctx.metadata.error = err
|
||||
return result
|
||||
},
|
||||
{
|
||||
...this._retryConfig,
|
||||
onFailedAttempt: async (err: FailedAttemptError) => {
|
||||
this._logger.warn(
|
||||
err,
|
||||
`Task error "${this.nameForHuman}" failed attempt ${
|
||||
err.attemptNumber
|
||||
}${input ? ': ' + JSON.stringify(input) : ''}`
|
||||
)
|
||||
|
||||
if (err instanceof errors.ZodOutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.OutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else {
|
||||
throw err
|
||||
if (this._retryConfig.onFailedAttempt) {
|
||||
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
|
||||
}
|
||||
|
||||
// TODO: log this task error
|
||||
ctx.attemptNumber = err.attemptNumber + 1
|
||||
ctx.metadata.error = err
|
||||
|
||||
if (err instanceof errors.ZodOutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.OutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.HumanFeedbackDeclineError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
ctx.metadata.success = true
|
||||
ctx.metadata.numRetries = ctx.attemptNumber
|
||||
|
|
Ładowanie…
Reference in New Issue