kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add pre- and post-hooks and cause retries in feedback
PR-URL: #21
commit
3cce379a2d
|
@ -0,0 +1,93 @@
|
|||
import { OpenAIClient } from '@agentic/openai-fetch'
|
||||
import 'dotenv/config'
|
||||
import { z } from 'zod'
|
||||
|
||||
import {
|
||||
Agentic,
|
||||
NovuNotificationTool,
|
||||
SerpAPITool,
|
||||
withHumanFeedback
|
||||
} from '@/index'
|
||||
|
||||
async function main() {
|
||||
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
|
||||
const agentic = new Agentic({ openai })
|
||||
|
||||
const question = 'How do I build a product that people will love?'
|
||||
|
||||
const task = withHumanFeedback(
|
||||
agentic
|
||||
.gpt4(
|
||||
`Generate a list of {n} prominent experts that can answer the following question: {{question}}.`
|
||||
)
|
||||
.tools([new SerpAPITool()])
|
||||
.output(
|
||||
z.array(
|
||||
z.object({
|
||||
name: z.string(),
|
||||
bio: z.string()
|
||||
})
|
||||
)
|
||||
)
|
||||
.input(
|
||||
z.object({
|
||||
question: z.string(),
|
||||
n: z.number().int().default(5)
|
||||
})
|
||||
),
|
||||
{
|
||||
type: 'selectN'
|
||||
}
|
||||
)
|
||||
const { metadata } = await task.callWithMetadata({
|
||||
question
|
||||
})
|
||||
|
||||
if (
|
||||
metadata.feedback &&
|
||||
metadata.feedback.type === 'selectN' &&
|
||||
metadata.feedback.selected
|
||||
) {
|
||||
const answer = await agentic
|
||||
.gpt4(
|
||||
`Generate an answer to the following question: "{{question}}" from each of the following experts: {{#each experts}}
|
||||
- {{this.name}}: {{this.bio}}
|
||||
{{/each}}`
|
||||
)
|
||||
.output(
|
||||
z.array(
|
||||
z.object({
|
||||
expert: z.string(),
|
||||
answer: z.string()
|
||||
})
|
||||
)
|
||||
)
|
||||
.input(
|
||||
z.object({
|
||||
question: z.string(),
|
||||
experts: z.array(z.object({ name: z.string(), bio: z.string() }))
|
||||
})
|
||||
)
|
||||
.call({
|
||||
question,
|
||||
experts: metadata.feedback.selected
|
||||
})
|
||||
|
||||
const message = answer.reduce((acc, { expert, answer }) => {
|
||||
return `${acc}
|
||||
|
||||
${expert}: ${answer}`
|
||||
}, '')
|
||||
const notifier = new NovuNotificationTool()
|
||||
await notifier.call({
|
||||
name: 'send-email',
|
||||
payload: {
|
||||
subject: 'Experts have answered your question: ' + question,
|
||||
message
|
||||
},
|
||||
to: [{ subscriberId: '123' }]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
main()
|
|
@ -1,51 +0,0 @@
|
|||
/**
|
||||
|
||||
export type Metadata = Record<string, unknown>;
|
||||
|
||||
|
||||
export abstract class BaseTask<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = ZodTypeAny
|
||||
> {
|
||||
|
||||
// ...
|
||||
|
||||
private _preHooks: ((input?: types.ParsedData<TInput>) => void | Promise<void>, metadata: types.Metadata)[] = [];
|
||||
private _postHooks: ((result: types.ParsedData<TOutput>, metadata: types.Metadata) => void | Promise<void>)[] = [];
|
||||
|
||||
|
||||
public registerPreHook(hook: (input?: types.ParsedData<TInput>) => void | Promise<void>): this {
|
||||
this._preHooks.push(hook);
|
||||
return this;
|
||||
}
|
||||
|
||||
public registerPostHook(hook: (result: types.ParsedData<TOutput>) => void | Promise<void>): this {
|
||||
this._postHooks.push(hook);
|
||||
return this;
|
||||
}
|
||||
|
||||
public async callWithMetadata(
|
||||
input?: types.ParsedData<TInput>,
|
||||
options: { dryRun?: boolean } = {}
|
||||
): Promise<{result: types.ParsedData<TOutput> | undefined, metadata: types.Metadata}> {
|
||||
const metadata: types.Metadata = {};
|
||||
|
||||
if (options.dryRun) {
|
||||
return console.log( '// TODO: implement' )
|
||||
}
|
||||
|
||||
for (const hook of this._preHooks) {
|
||||
await hook(input);
|
||||
}
|
||||
|
||||
const result = await this._call(input);
|
||||
|
||||
for (const hook of this._postHooks) {
|
||||
await hook(result, metadata);
|
||||
}
|
||||
|
||||
return {result, metadata};
|
||||
}
|
||||
}
|
||||
|
||||
**/
|
|
@ -82,3 +82,14 @@ export class TemplateValidationError extends BaseError {
|
|||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An error caused by the user declining an output.
|
||||
*/
|
||||
export class HumanFeedbackDeclineError extends BaseError {
|
||||
constructor(message: string, opts: ErrorOptions = {}) {
|
||||
super(message, opts)
|
||||
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
import { HumanFeedbackDeclineError } from '@/errors'
|
||||
import { BaseTask } from '@/task'
|
||||
|
||||
import { HumanFeedbackMechanismCLI } from './cli'
|
||||
|
@ -220,7 +221,9 @@ export abstract class HumanFeedbackMechanism<
|
|||
? HumanFeedbackUserActions.Select
|
||||
: await this._askUser(msg, choices)
|
||||
|
||||
const feedback: Record<string, any> = {}
|
||||
const feedback: Record<string, any> = {
|
||||
type: this._options.type
|
||||
}
|
||||
|
||||
switch (choice) {
|
||||
case HumanFeedbackUserActions.Accept:
|
||||
|
@ -268,6 +271,26 @@ export abstract class HumanFeedbackMechanism<
|
|||
}
|
||||
}
|
||||
|
||||
if (
|
||||
(Object.hasOwnProperty.call(feedback, 'accepted') &&
|
||||
feedback.accepted === false) ||
|
||||
(Object.hasOwnProperty.call(feedback, 'selected') &&
|
||||
feedback.selected.length === 0)
|
||||
) {
|
||||
const errorMsg = [
|
||||
'The output was declined by the human reviewer.',
|
||||
'Output:',
|
||||
'```',
|
||||
stringified,
|
||||
'```',
|
||||
'',
|
||||
'Please try again and return different output.'
|
||||
].join('\n')
|
||||
throw new HumanFeedbackDeclineError(errorMsg, {
|
||||
context: feedback
|
||||
})
|
||||
}
|
||||
|
||||
return feedback as FeedbackTypeToMetadata<T>
|
||||
}
|
||||
}
|
||||
|
@ -309,17 +332,10 @@ export function withHumanFeedback<
|
|||
options: finalOptions
|
||||
})
|
||||
|
||||
const originalCall = task.callWithMetadata.bind(task)
|
||||
|
||||
task.callWithMetadata = async function (input?: TInput) {
|
||||
const response = await originalCall(input)
|
||||
|
||||
const feedback = await feedbackMechanism.interact(response.result)
|
||||
|
||||
response.metadata = { ...response.metadata, feedback }
|
||||
|
||||
return response
|
||||
}
|
||||
task.addAfterCallHook(async function onCall(output, ctx) {
|
||||
const feedback = await feedbackMechanism.interact(output)
|
||||
ctx.metadata = { ...ctx.metadata, feedback }
|
||||
})
|
||||
|
||||
return task
|
||||
}
|
||||
|
|
71
src/task.ts
71
src/task.ts
|
@ -29,6 +29,9 @@ export abstract class BaseTask<
|
|||
protected _timeoutMs?: number
|
||||
protected _retryConfig: types.RetryConfig
|
||||
|
||||
private _preHooks: Array<types.BeforeCallHook<TInput>> = []
|
||||
private _postHooks: Array<types.AfterCallHook<TInput, TOutput>> = []
|
||||
|
||||
constructor(options: types.BaseTaskOptions = {}) {
|
||||
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
|
||||
|
||||
|
@ -74,6 +77,16 @@ export abstract class BaseTask<
|
|||
return ''
|
||||
}
|
||||
|
||||
public addBeforeCallHook(hook: types.BeforeCallHook<TInput>): this {
|
||||
this._preHooks.push(hook)
|
||||
return this
|
||||
}
|
||||
|
||||
public addAfterCallHook(hook: types.AfterCallHook<TInput, TOutput>): this {
|
||||
this._postHooks.push(hook)
|
||||
return this
|
||||
}
|
||||
|
||||
public validate() {
|
||||
if (!this._agentic) {
|
||||
throw new Error(
|
||||
|
@ -136,33 +149,49 @@ export abstract class BaseTask<
|
|||
}
|
||||
}
|
||||
|
||||
const result = await pRetry(() => this._call(ctx), {
|
||||
...this._retryConfig,
|
||||
onFailedAttempt: async (err: FailedAttemptError) => {
|
||||
this._logger.warn(
|
||||
err,
|
||||
`Task error "${this.nameForHuman}" failed attempt ${
|
||||
err.attemptNumber
|
||||
}${input ? ': ' + JSON.stringify(input) : ''}`
|
||||
)
|
||||
for (const hook of this._preHooks) {
|
||||
await hook(ctx)
|
||||
}
|
||||
|
||||
if (this._retryConfig.onFailedAttempt) {
|
||||
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
|
||||
const result = await pRetry(
|
||||
async () => {
|
||||
const result = await this._call(ctx)
|
||||
for (const hook of this._postHooks) {
|
||||
await hook(result, ctx)
|
||||
}
|
||||
|
||||
// TODO: log this task error
|
||||
ctx.attemptNumber = err.attemptNumber + 1
|
||||
ctx.metadata.error = err
|
||||
return result
|
||||
},
|
||||
{
|
||||
...this._retryConfig,
|
||||
onFailedAttempt: async (err: FailedAttemptError) => {
|
||||
this._logger.warn(
|
||||
err,
|
||||
`Task error "${this.nameForHuman}" failed attempt ${
|
||||
err.attemptNumber
|
||||
}${input ? ': ' + JSON.stringify(input) : ''}`
|
||||
)
|
||||
|
||||
if (err instanceof errors.ZodOutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.OutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else {
|
||||
throw err
|
||||
if (this._retryConfig.onFailedAttempt) {
|
||||
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
|
||||
}
|
||||
|
||||
// TODO: log this task error
|
||||
ctx.attemptNumber = err.attemptNumber + 1
|
||||
ctx.metadata.error = err
|
||||
|
||||
if (err instanceof errors.ZodOutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.OutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.HumanFeedbackDeclineError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
ctx.metadata.success = true
|
||||
ctx.metadata.numRetries = ctx.attemptNumber
|
||||
|
|
|
@ -152,3 +152,12 @@ export declare class CancelablePromise<T> extends Promise<T> {
|
|||
}
|
||||
|
||||
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
||||
|
||||
export type BeforeCallHook<TInput extends TaskInput = void> = (
|
||||
ctx: TaskCallContext<TInput>
|
||||
) => void | Promise<void>
|
||||
|
||||
export type AfterCallHook<
|
||||
TInput extends TaskInput = void,
|
||||
TOutput extends TaskOutput = string
|
||||
> = (output: TOutput, ctx: TaskCallContext<TInput>) => void | Promise<void>
|
||||
|
|
Ładowanie…
Reference in New Issue