From 5875a8858fe35b87f3981784328ba0fbdfbc490f Mon Sep 17 00:00:00 2001 From: Philipp Burckhardt Date: Fri, 16 Jun 2023 12:58:52 -0400 Subject: [PATCH] feat: add pre- and post-hooks and cause retries in feedback --- scratch/hooks.ts | 51 ---------------------- src/errors.ts | 11 +++++ src/human-feedback/feedback.ts | 36 +++++++++++----- src/task.ts | 77 ++++++++++++++++++++++++---------- 4 files changed, 92 insertions(+), 83 deletions(-) delete mode 100644 scratch/hooks.ts diff --git a/scratch/hooks.ts b/scratch/hooks.ts deleted file mode 100644 index 0ef8564..0000000 --- a/scratch/hooks.ts +++ /dev/null @@ -1,51 +0,0 @@ -/** - -export type Metadata = Record; - - -export abstract class BaseTask< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = ZodTypeAny -> { - -// ... - - private _preHooks: ((input?: types.ParsedData) => void | Promise, metadata: types.Metadata)[] = []; - private _postHooks: ((result: types.ParsedData, metadata: types.Metadata) => void | Promise)[] = []; - - - public registerPreHook(hook: (input?: types.ParsedData) => void | Promise): this { - this._preHooks.push(hook); - return this; - } - - public registerPostHook(hook: (result: types.ParsedData) => void | Promise): this { - this._postHooks.push(hook); - return this; - } - -public async callWithMetadata( - input?: types.ParsedData, - options: { dryRun?: boolean } = {} - ): Promise<{result: types.ParsedData | undefined, metadata: types.Metadata}> { - const metadata: types.Metadata = {}; - - if (options.dryRun) { - return console.log( '// TODO: implement' ) - } - - for (const hook of this._preHooks) { - await hook(input); - } - - const result = await this._call(input); - - for (const hook of this._postHooks) { - await hook(result, metadata); - } - - return {result, metadata}; - } -} - -**/ diff --git a/src/errors.ts b/src/errors.ts index 8ae22b1..5d17802 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -82,3 +82,14 @@ export class TemplateValidationError extends BaseError { Error.captureStackTrace?.(this, this.constructor) } } + +/** + * An error caused by the user declining an output. + */ +export class HumanFeedbackDeclineError extends BaseError { + constructor(message: string, opts: ErrorOptions = {}) { + super(message, opts) + + Error.captureStackTrace?.(this, this.constructor) + } +} diff --git a/src/human-feedback/feedback.ts b/src/human-feedback/feedback.ts index 5063eb2..518088b 100644 --- a/src/human-feedback/feedback.ts +++ b/src/human-feedback/feedback.ts @@ -1,5 +1,6 @@ import * as types from '@/types' import { Agentic } from '@/agentic' +import { HumanFeedbackDeclineError } from '@/errors' import { BaseTask } from '@/task' import { HumanFeedbackMechanismCLI } from './cli' @@ -268,6 +269,26 @@ export abstract class HumanFeedbackMechanism< } } + if ( + (Object.hasOwnProperty.call(feedback, 'accepted') && + feedback.accepted === false) || + (Object.hasOwnProperty.call(feedback, 'selected') && + feedback.selected.length === 0) + ) { + const errorMsg = [ + 'The output was declined by the human reviewer.', + 'Output:', + '```', + stringified, + '```', + '', + 'Please try again and return different output.' + ].join('\n') + throw new HumanFeedbackDeclineError(errorMsg, { + context: feedback + }) + } + return feedback as FeedbackTypeToMetadata } } @@ -309,17 +330,10 @@ export function withHumanFeedback< options: finalOptions }) - const originalCall = task.callWithMetadata.bind(task) - - task.callWithMetadata = async function (input?: TInput) { - const response = await originalCall(input) - - const feedback = await feedbackMechanism.interact(response.result) - - response.metadata = { ...response.metadata, feedback } - - return response - } + task.addAfterCallHook(async function onCall(output, ctx) { + const feedback = await feedbackMechanism.interact(output) + ctx.metadata = { ...ctx.metadata, feedback } + }) return task } diff --git a/src/task.ts b/src/task.ts index 224d15a..d3d8ce7 100644 --- a/src/task.ts +++ b/src/task.ts @@ -29,6 +29,11 @@ export abstract class BaseTask< protected _timeoutMs?: number protected _retryConfig: types.RetryConfig + private _preHooks: Array<(ctx: types.TaskCallContext) => void> = [] + private _postHooks: Array< + (output: TOutput, ctx: types.TaskCallContext) => void + > = [] + constructor(options: types.BaseTaskOptions = {}) { this._agentic = options.agentic ?? globalThis.__agentic?.deref() @@ -74,6 +79,20 @@ export abstract class BaseTask< return '' } + public addBeforeCallHook( + hook: (ctx: types.TaskCallContext) => Promise + ): this { + this._preHooks.push(hook) + return this + } + + public addAfterCallHook( + hook: (output: TOutput, ctx: types.TaskCallContext) => Promise + ): this { + this._postHooks.push(hook) + return this + } + public validate() { if (!this._agentic) { throw new Error( @@ -136,33 +155,49 @@ export abstract class BaseTask< } } - const result = await pRetry(() => this._call(ctx), { - ...this._retryConfig, - onFailedAttempt: async (err: FailedAttemptError) => { - this._logger.warn( - err, - `Task error "${this.nameForHuman}" failed attempt ${ - err.attemptNumber - }${input ? ': ' + JSON.stringify(input) : ''}` - ) + for (const hook of this._preHooks) { + await hook(ctx) + } - if (this._retryConfig.onFailedAttempt) { - await Promise.resolve(this._retryConfig.onFailedAttempt(err)) + const result = await pRetry( + async () => { + const result = await this._call(ctx) + for (const hook of this._postHooks) { + await hook(result, ctx) } - // TODO: log this task error - ctx.attemptNumber = err.attemptNumber + 1 - ctx.metadata.error = err + return result + }, + { + ...this._retryConfig, + onFailedAttempt: async (err: FailedAttemptError) => { + this._logger.warn( + err, + `Task error "${this.nameForHuman}" failed attempt ${ + err.attemptNumber + }${input ? ': ' + JSON.stringify(input) : ''}` + ) - if (err instanceof errors.ZodOutputValidationError) { - ctx.retryMessage = err.message - } else if (err instanceof errors.OutputValidationError) { - ctx.retryMessage = err.message - } else { - throw err + if (this._retryConfig.onFailedAttempt) { + await Promise.resolve(this._retryConfig.onFailedAttempt(err)) + } + + // TODO: log this task error + ctx.attemptNumber = err.attemptNumber + 1 + ctx.metadata.error = err + + if (err instanceof errors.ZodOutputValidationError) { + ctx.retryMessage = err.message + } else if (err instanceof errors.OutputValidationError) { + ctx.retryMessage = err.message + } else if (err instanceof errors.HumanFeedbackDeclineError) { + ctx.retryMessage = err.message + } else { + throw err + } } } - }) + ) ctx.metadata.success = true ctx.metadata.numRetries = ctx.attemptNumber