feat: add pre- and post-hooks and cause retries in feedback

PR-URL: #21
Philipp Burckhardt 2023-06-16 19:05:07 -04:00 zatwierdzone przez GitHub
commit 3cce379a2d
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
6 zmienionych plików z 191 dodań i 84 usunięć

Wyświetl plik

@ -0,0 +1,93 @@
import { OpenAIClient } from '@agentic/openai-fetch'
import 'dotenv/config'
import { z } from 'zod'
import {
Agentic,
NovuNotificationTool,
SerpAPITool,
withHumanFeedback
} from '@/index'
async function main() {
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
const agentic = new Agentic({ openai })
const question = 'How do I build a product that people will love?'
const task = withHumanFeedback(
agentic
.gpt4(
`Generate a list of {n} prominent experts that can answer the following question: {{question}}.`
)
.tools([new SerpAPITool()])
.output(
z.array(
z.object({
name: z.string(),
bio: z.string()
})
)
)
.input(
z.object({
question: z.string(),
n: z.number().int().default(5)
})
),
{
type: 'selectN'
}
)
const { metadata } = await task.callWithMetadata({
question
})
if (
metadata.feedback &&
metadata.feedback.type === 'selectN' &&
metadata.feedback.selected
) {
const answer = await agentic
.gpt4(
`Generate an answer to the following question: "{{question}}" from each of the following experts: {{#each experts}}
- {{this.name}}: {{this.bio}}
{{/each}}`
)
.output(
z.array(
z.object({
expert: z.string(),
answer: z.string()
})
)
)
.input(
z.object({
question: z.string(),
experts: z.array(z.object({ name: z.string(), bio: z.string() }))
})
)
.call({
question,
experts: metadata.feedback.selected
})
const message = answer.reduce((acc, { expert, answer }) => {
return `${acc}
${expert}: ${answer}`
}, '')
const notifier = new NovuNotificationTool()
await notifier.call({
name: 'send-email',
payload: {
subject: 'Experts have answered your question: ' + question,
message
},
to: [{ subscriberId: '123' }]
})
}
}
main()

Wyświetl plik

@ -1,51 +0,0 @@
/**
export type Metadata = Record<string, unknown>;
export abstract class BaseTask<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
> {
// ...
private _preHooks: ((input?: types.ParsedData<TInput>) => void | Promise<void>, metadata: types.Metadata)[] = [];
private _postHooks: ((result: types.ParsedData<TOutput>, metadata: types.Metadata) => void | Promise<void>)[] = [];
public registerPreHook(hook: (input?: types.ParsedData<TInput>) => void | Promise<void>): this {
this._preHooks.push(hook);
return this;
}
public registerPostHook(hook: (result: types.ParsedData<TOutput>) => void | Promise<void>): this {
this._postHooks.push(hook);
return this;
}
public async callWithMetadata(
input?: types.ParsedData<TInput>,
options: { dryRun?: boolean } = {}
): Promise<{result: types.ParsedData<TOutput> | undefined, metadata: types.Metadata}> {
const metadata: types.Metadata = {};
if (options.dryRun) {
return console.log( '// TODO: implement' )
}
for (const hook of this._preHooks) {
await hook(input);
}
const result = await this._call(input);
for (const hook of this._postHooks) {
await hook(result, metadata);
}
return {result, metadata};
}
}
**/

Wyświetl plik

@ -82,3 +82,14 @@ export class TemplateValidationError extends BaseError {
Error.captureStackTrace?.(this, this.constructor)
}
}
/**
* An error caused by the user declining an output.
*/
export class HumanFeedbackDeclineError extends BaseError {
constructor(message: string, opts: ErrorOptions = {}) {
super(message, opts)
Error.captureStackTrace?.(this, this.constructor)
}
}

Wyświetl plik

@ -1,5 +1,6 @@
import * as types from '@/types'
import { Agentic } from '@/agentic'
import { HumanFeedbackDeclineError } from '@/errors'
import { BaseTask } from '@/task'
import { HumanFeedbackMechanismCLI } from './cli'
@ -220,7 +221,9 @@ export abstract class HumanFeedbackMechanism<
? HumanFeedbackUserActions.Select
: await this._askUser(msg, choices)
const feedback: Record<string, any> = {}
const feedback: Record<string, any> = {
type: this._options.type
}
switch (choice) {
case HumanFeedbackUserActions.Accept:
@ -268,6 +271,26 @@ export abstract class HumanFeedbackMechanism<
}
}
if (
(Object.hasOwnProperty.call(feedback, 'accepted') &&
feedback.accepted === false) ||
(Object.hasOwnProperty.call(feedback, 'selected') &&
feedback.selected.length === 0)
) {
const errorMsg = [
'The output was declined by the human reviewer.',
'Output:',
'```',
stringified,
'```',
'',
'Please try again and return different output.'
].join('\n')
throw new HumanFeedbackDeclineError(errorMsg, {
context: feedback
})
}
return feedback as FeedbackTypeToMetadata<T>
}
}
@ -309,17 +332,10 @@ export function withHumanFeedback<
options: finalOptions
})
const originalCall = task.callWithMetadata.bind(task)
task.callWithMetadata = async function (input?: TInput) {
const response = await originalCall(input)
const feedback = await feedbackMechanism.interact(response.result)
response.metadata = { ...response.metadata, feedback }
return response
}
task.addAfterCallHook(async function onCall(output, ctx) {
const feedback = await feedbackMechanism.interact(output)
ctx.metadata = { ...ctx.metadata, feedback }
})
return task
}

Wyświetl plik

@ -29,6 +29,9 @@ export abstract class BaseTask<
protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig
private _preHooks: Array<types.BeforeCallHook<TInput>> = []
private _postHooks: Array<types.AfterCallHook<TInput, TOutput>> = []
constructor(options: types.BaseTaskOptions = {}) {
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
@ -74,6 +77,16 @@ export abstract class BaseTask<
return ''
}
public addBeforeCallHook(hook: types.BeforeCallHook<TInput>): this {
this._preHooks.push(hook)
return this
}
public addAfterCallHook(hook: types.AfterCallHook<TInput, TOutput>): this {
this._postHooks.push(hook)
return this
}
public validate() {
if (!this._agentic) {
throw new Error(
@ -136,33 +149,49 @@ export abstract class BaseTask<
}
}
const result = await pRetry(() => this._call(ctx), {
...this._retryConfig,
onFailedAttempt: async (err: FailedAttemptError) => {
this._logger.warn(
err,
`Task error "${this.nameForHuman}" failed attempt ${
err.attemptNumber
}${input ? ': ' + JSON.stringify(input) : ''}`
)
for (const hook of this._preHooks) {
await hook(ctx)
}
if (this._retryConfig.onFailedAttempt) {
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
const result = await pRetry(
async () => {
const result = await this._call(ctx)
for (const hook of this._postHooks) {
await hook(result, ctx)
}
// TODO: log this task error
ctx.attemptNumber = err.attemptNumber + 1
ctx.metadata.error = err
return result
},
{
...this._retryConfig,
onFailedAttempt: async (err: FailedAttemptError) => {
this._logger.warn(
err,
`Task error "${this.nameForHuman}" failed attempt ${
err.attemptNumber
}${input ? ': ' + JSON.stringify(input) : ''}`
)
if (err instanceof errors.ZodOutputValidationError) {
ctx.retryMessage = err.message
} else if (err instanceof errors.OutputValidationError) {
ctx.retryMessage = err.message
} else {
throw err
if (this._retryConfig.onFailedAttempt) {
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
}
// TODO: log this task error
ctx.attemptNumber = err.attemptNumber + 1
ctx.metadata.error = err
if (err instanceof errors.ZodOutputValidationError) {
ctx.retryMessage = err.message
} else if (err instanceof errors.OutputValidationError) {
ctx.retryMessage = err.message
} else if (err instanceof errors.HumanFeedbackDeclineError) {
ctx.retryMessage = err.message
} else {
throw err
}
}
}
})
)
ctx.metadata.success = true
ctx.metadata.numRetries = ctx.attemptNumber

Wyświetl plik

@ -152,3 +152,12 @@ export declare class CancelablePromise<T> extends Promise<T> {
}
// export type ProgressFunction = (partialResponse: ChatMessage) => void
export type BeforeCallHook<TInput extends TaskInput = void> = (
ctx: TaskCallContext<TInput>
) => void | Promise<void>
export type AfterCallHook<
TInput extends TaskInput = void,
TOutput extends TaskOutput = string
> = (output: TOutput, ctx: TaskCallContext<TInput>) => void | Promise<void>