Merge pull request #26 from transitive-bullshit/feature/hook-priorities

WIP feat: add support for hook priorities and return values
old-agentic-v1^2
Philipp Burckhardt 2023-06-24 10:32:12 -04:00 zatwierdzone przez GitHub
commit 63f9536eb9
3 zmienionych plików z 70 dodań i 13 usunięć

Wyświetl plik

@ -1,3 +1,4 @@
export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo' export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo'
export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1' export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1'
export const DEFAULT_BOT_NAME = 'Agentic Bot' export const DEFAULT_BOT_NAME = 'Agentic Bot'
export const SKIP_HOOKS = Symbol('SKIP_HOOKS')

Wyświetl plik

@ -4,6 +4,7 @@ import { ZodType } from 'zod'
import * as errors from './errors' import * as errors from './errors'
import * as types from './types' import * as types from './types'
import type { Agentic } from './agentic' import type { Agentic } from './agentic'
import { SKIP_HOOKS } from './constants'
import { import {
HumanFeedbackMechanismCLI, HumanFeedbackMechanismCLI,
HumanFeedbackOptions, HumanFeedbackOptions,
@ -34,8 +35,14 @@ export abstract class BaseTask<
protected _timeoutMs?: number protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig protected _retryConfig: types.RetryConfig
private _preHooks: Array<types.TaskBeforeCallHook<TInput>> = [] private _preHooks: Array<{
private _postHooks: Array<types.TaskAfterCallHook<TInput, TOutput>> = [] hook: types.TaskBeforeCallHook<TInput>
priority: number
}> = []
private _postHooks: Array<{
hook: types.TaskAfterCallHook<TInput, TOutput>
priority: number
}> = []
constructor(options: types.BaseTaskOptions = {}) { constructor(options: types.BaseTaskOptions = {}) {
this._agentic = options.agentic ?? globalThis.__agentic?.deref() this._agentic = options.agentic ?? globalThis.__agentic?.deref()
@ -82,15 +89,33 @@ export abstract class BaseTask<
return '' return ''
} }
public addBeforeCallHook(hook: types.TaskBeforeCallHook<TInput>): this { /**
this._preHooks.push(hook) * Adds a hook to be called before the task is invoked.
*
* @param hook - function to be called before the task is invoked
* @param priority - priority of the hook; higher priority hooks are called first
*/
public addBeforeCallHook(
hook: types.TaskBeforeCallHook<TInput>,
priority = 0
): this {
this._preHooks.push({ hook, priority })
this._preHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
return this return this
} }
/**
* Adds a hook to be called after the task is invoked.
*
* @param hook - function to be called after the task is invoked
* @param priority - priority of the hook; higher priority hooks are called first
*/
public addAfterCallHook( public addAfterCallHook(
hook: types.TaskAfterCallHook<TInput, TOutput> hook: types.TaskAfterCallHook<TInput, TOutput>,
priority = 0
): this { ): this {
this._postHooks.push(hook) this._postHooks.push({ hook, priority })
this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
return this return this
} }
@ -202,16 +227,36 @@ export abstract class BaseTask<
} }
} }
for (const preHook of this._preHooks) { for (const { hook: preHook } of this._preHooks) {
await preHook(ctx) const preHookResult = await preHook(ctx)
if (preHookResult === SKIP_HOOKS) {
break
} else if (preHookResult !== undefined) {
const output = this.outputSchema?.safeParse(preHookResult)
if (!output?.success) {
throw new Error(`Invalid preHook output: ${output?.error.message}`)
}
ctx.metadata.success = true
ctx.metadata.numRetries = ctx.attemptNumber
ctx.metadata.error = undefined
return {
result: output.data,
metadata: ctx.metadata
}
}
} }
const result = await pRetry( const result = await pRetry(
async () => { async () => {
const result = await this._call(ctx) const result = await this._call(ctx)
for (const postHook of this._postHooks) { for (const { hook: postHook } of this._postHooks) {
await postHook(result, ctx) const postHookResult = await postHook(result, ctx)
if (postHookResult === SKIP_HOOKS) {
break
}
} }
return result return result

Wyświetl plik

@ -6,6 +6,7 @@ import type { JsonObject, Jsonifiable } from 'type-fest'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic' import type { Agentic } from './agentic'
import { SKIP_HOOKS } from './constants'
import type { import type {
FeedbackTypeToMetadata, FeedbackTypeToMetadata,
HumanFeedbackType HumanFeedbackType
@ -173,11 +174,21 @@ export declare class CancelablePromise<T> extends Promise<T> {
// export type ProgressFunction = (partialResponse: ChatMessage) => void // export type ProgressFunction = (partialResponse: ChatMessage) => void
export type TaskBeforeCallHook<TInput extends TaskInput = void> = ( export type TaskBeforeCallHook<
TInput extends TaskInput = void,
TOutput extends TaskOutput = string
> = (
ctx: TaskCallContext<TInput> ctx: TaskCallContext<TInput>
) => void | Promise<void> ) =>
| void
| TOutput
| typeof SKIP_HOOKS
| Promise<void | TOutput | typeof SKIP_HOOKS>
export type TaskAfterCallHook< export type TaskAfterCallHook<
TInput extends TaskInput = void, TInput extends TaskInput = void,
TOutput extends TaskOutput = string TOutput extends TaskOutput = string
> = (output: TOutput, ctx: TaskCallContext<TInput>) => void | Promise<void> > = (
output: TOutput,
ctx: TaskCallContext<TInput>
) => void | typeof SKIP_HOOKS | Promise<void | typeof SKIP_HOOKS>