diff --git a/legacy/src/constants.ts b/legacy/src/constants.ts index e39b8aa1..dc85269f 100644 --- a/legacy/src/constants.ts +++ b/legacy/src/constants.ts @@ -1,3 +1,4 @@ export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo' export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1' export const DEFAULT_BOT_NAME = 'Agentic Bot' +export const SKIP_HOOKS = Symbol('SKIP_HOOKS') diff --git a/legacy/src/task.ts b/legacy/src/task.ts index 6252af27..00bc9a53 100644 --- a/legacy/src/task.ts +++ b/legacy/src/task.ts @@ -4,6 +4,7 @@ import { ZodType } from 'zod' import * as errors from './errors' import * as types from './types' import type { Agentic } from './agentic' +import { SKIP_HOOKS } from './constants' import { HumanFeedbackMechanismCLI, HumanFeedbackOptions, @@ -34,8 +35,16 @@ export abstract class BaseTask< protected _timeoutMs?: number protected _retryConfig: types.RetryConfig - private _preHooks: Array> = [] - private _postHooks: Array> = [] + private _preHooks: Array<{ + hook: types.TaskBeforeCallHook + priority: number + name: string + }> = [] + private _postHooks: Array<{ + hook: types.TaskAfterCallHook + priority: number + name: string + }> = [] constructor(options: types.BaseTaskOptions = {}) { this._agentic = options.agentic ?? globalThis.__agentic?.deref() @@ -82,15 +91,95 @@ export abstract class BaseTask< return '' } - public addBeforeCallHook(hook: types.TaskBeforeCallHook): this { - this._preHooks.push(hook) + /** + * Adds a hook to be called before the task is invoked. + * + * @param hook - function to be called before the task is invoked + * @param options - options for the hook; `priority` is used to determine the order in which hooks are called, with higher priority hooks being called first, and `name` is used to identify the hook + */ + public addBeforeCallHook( + hook: types.TaskBeforeCallHook, + { priority = 0, name }: { priority?: number; name?: string } = {} + ): this { + const hookName = name ?? `preHook_${this._preHooks.length}` + this._preHooks.push({ hook, priority, name: hookName }) + this._preHooks.sort((a, b) => b.priority - a.priority) return this } + /** + * Adds a hook to be called after the task is invoked. + * + * @param hook - function to be called after the task is invoked + * @param options - options for the hook; `priority` is used to determine the order in which hooks are called, with higher priority hooks being called first, and `name` is used to identify the hook + */ public addAfterCallHook( - hook: types.TaskAfterCallHook + hook: types.TaskAfterCallHook, + { priority = 0, name }: { priority?: number; name?: string } = {} ): this { - this._postHooks.push(hook) + const hookName = name ?? `postHook_${this._postHooks.length}` + this._postHooks.push({ hook, priority, name: hookName }) + this._postHooks.sort((a, b) => b.priority - a.priority) + return this + } + + /** + * Changes the priority of a before call hook. + * + * @param hookType - `before` + * @param hookOrName - hook or the name of the hook to change the priority of + * @param newPriority - new priority of the hook + */ + public changeHookPriority( + hookType: 'before', + hookOrName: types.TaskBeforeCallHook | string, + newPriority: number + ): this + + /** + * Changes the priority of a after call hook. + * + * @param hookType - `after` + * @param hookOrName - hook or the name of the hook to change the priority of + * @param newPriority - new priority of the hook + */ + public changeHookPriority( + hookType: 'after', + hookOrName: types.TaskAfterCallHook | string, + newPriority: number + ): this + + public changeHookPriority( + hookType: 'before' | 'after', + hookOrName: + | types.TaskBeforeCallHook + | types.TaskAfterCallHook + | string, + newPriority: number + ): this { + const hooks = hookType === 'before' ? this._preHooks : this._postHooks + + if (typeof hookOrName === 'string') { + const hookObj = hooks.find((h) => h.name === hookOrName) + if (!hookObj) { + throw new Error( + `Could not find a ${hookType}-call hook named "${hookOrName}" to change its priority` + ) + } + + hookObj.priority = newPriority + } else { + const hookObj = hooks.find((h) => h.hook === hookOrName) + if (!hookObj) { + throw new Error( + `Could not find the provided ${hookType}-call hook to change its priority` + ) + } + + hookObj.priority = newPriority + } + + hooks.sort((a, b) => b.priority - a.priority) return this } @@ -140,10 +229,13 @@ export abstract class BaseTask< options }) - this.addAfterCallHook(async (output, ctx) => { - const feedback = await feedbackMechanism.interact(output) - ctx.metadata = { ...ctx.metadata, feedback } - }) + this.addAfterCallHook( + async (output, ctx) => { + const feedback = await feedbackMechanism.interact(output) + ctx.metadata = { ...ctx.metadata, feedback } + }, + { name: 'humanFeedback' } + ) return this } @@ -194,16 +286,32 @@ export abstract class BaseTask< } } - for (const preHook of this._preHooks) { - await preHook(ctx) + for (const { hook: preHook } of this._preHooks) { + const preHookResult = await preHook(ctx) + if (preHookResult === SKIP_HOOKS) { + break + } else if (preHookResult !== undefined) { + const output = this.outputSchema?.safeParse(preHookResult) + if (!output?.success) { + throw new Error(`Invalid preHook output: ${output?.error.message}`) + } + + return { + result: output.data, + metadata: ctx.metadata + } + } } const result = await pRetry( async () => { const result = await this._call(ctx) - for (const postHook of this._postHooks) { - await postHook(result, ctx) + for (const { hook: postHook } of this._postHooks) { + const postHookResult = await postHook(result, ctx) + if (postHookResult === SKIP_HOOKS) { + break + } } return result diff --git a/legacy/src/types.ts b/legacy/src/types.ts index 4b457001..e3897228 100644 --- a/legacy/src/types.ts +++ b/legacy/src/types.ts @@ -6,6 +6,7 @@ import type { JsonObject, Jsonifiable } from 'type-fest' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import type { Agentic } from './agentic' +import { SKIP_HOOKS } from './constants' import type { FeedbackTypeToMetadata, HumanFeedbackType @@ -155,11 +156,21 @@ export declare class CancelablePromise extends Promise { // export type ProgressFunction = (partialResponse: ChatMessage) => void -export type TaskBeforeCallHook = ( +export type TaskBeforeCallHook< + TInput extends TaskInput = void, + TOutput extends TaskOutput = string +> = ( ctx: TaskCallContext -) => void | Promise +) => + | void + | TOutput + | typeof SKIP_HOOKS + | Promise export type TaskAfterCallHook< TInput extends TaskInput = void, TOutput extends TaskOutput = string -> = (output: TOutput, ctx: TaskCallContext) => void | Promise +> = ( + output: TOutput, + ctx: TaskCallContext +) => void | typeof SKIP_HOOKS | Promise