diff --git a/legacy/src/constants.ts b/legacy/src/constants.ts index e39b8aa1..dc85269f 100644 --- a/legacy/src/constants.ts +++ b/legacy/src/constants.ts @@ -1,3 +1,4 @@ export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo' export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1' export const DEFAULT_BOT_NAME = 'Agentic Bot' +export const SKIP_HOOKS = Symbol('SKIP_HOOKS') diff --git a/legacy/src/task.ts b/legacy/src/task.ts index 5f093075..ed1faacf 100644 --- a/legacy/src/task.ts +++ b/legacy/src/task.ts @@ -4,6 +4,7 @@ import { ZodType } from 'zod' import * as errors from './errors' import * as types from './types' import type { Agentic } from './agentic' +import { SKIP_HOOKS } from './constants' import { HumanFeedbackMechanismCLI, HumanFeedbackOptions, @@ -34,8 +35,14 @@ export abstract class BaseTask< protected _timeoutMs?: number protected _retryConfig: types.RetryConfig - private _preHooks: Array> = [] - private _postHooks: Array> = [] + private _preHooks: Array<{ + hook: types.TaskBeforeCallHook + priority: number + }> = [] + private _postHooks: Array<{ + hook: types.TaskAfterCallHook + priority: number + }> = [] constructor(options: types.BaseTaskOptions = {}) { this._agentic = options.agentic ?? globalThis.__agentic?.deref() @@ -82,15 +89,33 @@ export abstract class BaseTask< return '' } - public addBeforeCallHook(hook: types.TaskBeforeCallHook): this { - this._preHooks.push(hook) + /** + * Adds a hook to be called before the task is invoked. + * + * @param hook - function to be called before the task is invoked + * @param priority - priority of the hook; higher priority hooks are called first + */ + public addBeforeCallHook( + hook: types.TaskBeforeCallHook, + priority = 0 + ): this { + this._preHooks.push({ hook, priority }) + this._preHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019) return this } + /** + * Adds a hook to be called after the task is invoked. + * + * @param hook - function to be called after the task is invoked + * @param priority - priority of the hook; higher priority hooks are called first + */ public addAfterCallHook( - hook: types.TaskAfterCallHook + hook: types.TaskAfterCallHook, + priority = 0 ): this { - this._postHooks.push(hook) + this._postHooks.push({ hook, priority }) + this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019) return this } @@ -202,16 +227,36 @@ export abstract class BaseTask< } } - for (const preHook of this._preHooks) { - await preHook(ctx) + for (const { hook: preHook } of this._preHooks) { + const preHookResult = await preHook(ctx) + if (preHookResult === SKIP_HOOKS) { + break + } else if (preHookResult !== undefined) { + const output = this.outputSchema?.safeParse(preHookResult) + if (!output?.success) { + throw new Error(`Invalid preHook output: ${output?.error.message}`) + } + + ctx.metadata.success = true + ctx.metadata.numRetries = ctx.attemptNumber + ctx.metadata.error = undefined + + return { + result: output.data, + metadata: ctx.metadata + } + } } const result = await pRetry( async () => { const result = await this._call(ctx) - for (const postHook of this._postHooks) { - await postHook(result, ctx) + for (const { hook: postHook } of this._postHooks) { + const postHookResult = await postHook(result, ctx) + if (postHookResult === SKIP_HOOKS) { + break + } } return result diff --git a/legacy/src/types.ts b/legacy/src/types.ts index 3f8ecd93..26f9ba99 100644 --- a/legacy/src/types.ts +++ b/legacy/src/types.ts @@ -6,6 +6,7 @@ import type { JsonObject, Jsonifiable } from 'type-fest' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import type { Agentic } from './agentic' +import { SKIP_HOOKS } from './constants' import type { FeedbackTypeToMetadata, HumanFeedbackType @@ -173,11 +174,21 @@ export declare class CancelablePromise extends Promise { // export type ProgressFunction = (partialResponse: ChatMessage) => void -export type TaskBeforeCallHook = ( +export type TaskBeforeCallHook< + TInput extends TaskInput = void, + TOutput extends TaskOutput = string +> = ( ctx: TaskCallContext -) => void | Promise +) => + | void + | TOutput + | typeof SKIP_HOOKS + | Promise export type TaskAfterCallHook< TInput extends TaskInput = void, TOutput extends TaskOutput = string -> = (output: TOutput, ctx: TaskCallContext) => void | Promise +> = ( + output: TOutput, + ctx: TaskCallContext +) => void | typeof SKIP_HOOKS | Promise