feat: add pre- and post-hooks and cause retries in feedback

old-agentic-v1^2
Philipp Burckhardt 2023-06-16 12:58:52 -04:00
rodzic 9d3a7bc05f
commit 5875a8858f
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A2C3BCA4F31D1DDD
4 zmienionych plików z 92 dodań i 83 usunięć

Wyświetl plik

@ -1,51 +0,0 @@
/**
export type Metadata = Record<string, unknown>;
export abstract class BaseTask<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
> {
// ...
private _preHooks: ((input?: types.ParsedData<TInput>) => void | Promise<void>, metadata: types.Metadata)[] = [];
private _postHooks: ((result: types.ParsedData<TOutput>, metadata: types.Metadata) => void | Promise<void>)[] = [];
public registerPreHook(hook: (input?: types.ParsedData<TInput>) => void | Promise<void>): this {
this._preHooks.push(hook);
return this;
}
public registerPostHook(hook: (result: types.ParsedData<TOutput>) => void | Promise<void>): this {
this._postHooks.push(hook);
return this;
}
public async callWithMetadata(
input?: types.ParsedData<TInput>,
options: { dryRun?: boolean } = {}
): Promise<{result: types.ParsedData<TOutput> | undefined, metadata: types.Metadata}> {
const metadata: types.Metadata = {};
if (options.dryRun) {
return console.log( '// TODO: implement' )
}
for (const hook of this._preHooks) {
await hook(input);
}
const result = await this._call(input);
for (const hook of this._postHooks) {
await hook(result, metadata);
}
return {result, metadata};
}
}
**/

Wyświetl plik

@ -82,3 +82,14 @@ export class TemplateValidationError extends BaseError {
Error.captureStackTrace?.(this, this.constructor)
}
}
/**
* An error caused by the user declining an output.
*/
export class HumanFeedbackDeclineError extends BaseError {
constructor(message: string, opts: ErrorOptions = {}) {
super(message, opts)
Error.captureStackTrace?.(this, this.constructor)
}
}

Wyświetl plik

@ -1,5 +1,6 @@
import * as types from '@/types'
import { Agentic } from '@/agentic'
import { HumanFeedbackDeclineError } from '@/errors'
import { BaseTask } from '@/task'
import { HumanFeedbackMechanismCLI } from './cli'
@ -268,6 +269,26 @@ export abstract class HumanFeedbackMechanism<
}
}
if (
(Object.hasOwnProperty.call(feedback, 'accepted') &&
feedback.accepted === false) ||
(Object.hasOwnProperty.call(feedback, 'selected') &&
feedback.selected.length === 0)
) {
const errorMsg = [
'The output was declined by the human reviewer.',
'Output:',
'```',
stringified,
'```',
'',
'Please try again and return different output.'
].join('\n')
throw new HumanFeedbackDeclineError(errorMsg, {
context: feedback
})
}
return feedback as FeedbackTypeToMetadata<T>
}
}
@ -309,17 +330,10 @@ export function withHumanFeedback<
options: finalOptions
})
const originalCall = task.callWithMetadata.bind(task)
task.callWithMetadata = async function (input?: TInput) {
const response = await originalCall(input)
const feedback = await feedbackMechanism.interact(response.result)
response.metadata = { ...response.metadata, feedback }
return response
}
task.addAfterCallHook(async function onCall(output, ctx) {
const feedback = await feedbackMechanism.interact(output)
ctx.metadata = { ...ctx.metadata, feedback }
})
return task
}

Wyświetl plik

@ -29,6 +29,11 @@ export abstract class BaseTask<
protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig
private _preHooks: Array<(ctx: types.TaskCallContext<TInput>) => void> = []
private _postHooks: Array<
(output: TOutput, ctx: types.TaskCallContext<TInput>) => void
> = []
constructor(options: types.BaseTaskOptions = {}) {
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
@ -74,6 +79,20 @@ export abstract class BaseTask<
return ''
}
public addBeforeCallHook(
hook: (ctx: types.TaskCallContext<TInput>) => Promise<void>
): this {
this._preHooks.push(hook)
return this
}
public addAfterCallHook(
hook: (output: TOutput, ctx: types.TaskCallContext<TInput>) => Promise<void>
): this {
this._postHooks.push(hook)
return this
}
public validate() {
if (!this._agentic) {
throw new Error(
@ -136,33 +155,49 @@ export abstract class BaseTask<
}
}
const result = await pRetry(() => this._call(ctx), {
...this._retryConfig,
onFailedAttempt: async (err: FailedAttemptError) => {
this._logger.warn(
err,
`Task error "${this.nameForHuman}" failed attempt ${
err.attemptNumber
}${input ? ': ' + JSON.stringify(input) : ''}`
)
for (const hook of this._preHooks) {
await hook(ctx)
}
if (this._retryConfig.onFailedAttempt) {
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
const result = await pRetry(
async () => {
const result = await this._call(ctx)
for (const hook of this._postHooks) {
await hook(result, ctx)
}
// TODO: log this task error
ctx.attemptNumber = err.attemptNumber + 1
ctx.metadata.error = err
return result
},
{
...this._retryConfig,
onFailedAttempt: async (err: FailedAttemptError) => {
this._logger.warn(
err,
`Task error "${this.nameForHuman}" failed attempt ${
err.attemptNumber
}${input ? ': ' + JSON.stringify(input) : ''}`
)
if (err instanceof errors.ZodOutputValidationError) {
ctx.retryMessage = err.message
} else if (err instanceof errors.OutputValidationError) {
ctx.retryMessage = err.message
} else {
throw err
if (this._retryConfig.onFailedAttempt) {
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
}
// TODO: log this task error
ctx.attemptNumber = err.attemptNumber + 1
ctx.metadata.error = err
if (err instanceof errors.ZodOutputValidationError) {
ctx.retryMessage = err.message
} else if (err instanceof errors.OutputValidationError) {
ctx.retryMessage = err.message
} else if (err instanceof errors.HumanFeedbackDeclineError) {
ctx.retryMessage = err.message
} else {
throw err
}
}
}
})
)
ctx.metadata.success = true
ctx.metadata.numRetries = ctx.attemptNumber