diff --git a/src/human-feedback/cli.ts b/src/human-feedback/cli.ts index ad795bf..cacd22d 100644 --- a/src/human-feedback/cli.ts +++ b/src/human-feedback/cli.ts @@ -11,8 +11,9 @@ import { } from './feedback' export class HumanFeedbackMechanismCLI< - T extends HumanFeedbackType -> extends HumanFeedbackMechanism { + T extends HumanFeedbackType, + TOutput = any +> extends HumanFeedbackMechanism { /** * Prompt the user to select one of a list of options. */ @@ -43,19 +44,31 @@ export class HumanFeedbackMechanismCLI< }) } - protected async selectOne(response: any[]): Promise { + protected async selectOne( + response: TOutput + ): Promise { + if (!Array.isArray(response)) { + throw new Error('selectOne called on non-array response') + } + const choices = response.map((option) => ({ - name: option, + name: String(option), value: option })) return select({ message: 'Pick one output:', choices }) } - protected async selectN(response: any[]): Promise { + protected async selectN( + response: TOutput + ): Promise { + if (!Array.isArray(response)) { + throw new Error('selectN called on non-array response') + } + const choices = response.map((option) => ({ - name: option, + name: String(option), value: option })) - return checkbox({ message: 'Select outputs:', choices }) + return checkbox({ message: 'Select outputs:', choices }) as any } } diff --git a/src/human-feedback/feedback.ts b/src/human-feedback/feedback.ts index 2e9c0b0..7936c18 100644 --- a/src/human-feedback/feedback.ts +++ b/src/human-feedback/feedback.ts @@ -33,14 +33,15 @@ export const HumanFeedbackUserActionMessages: Record< */ export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN' -type HumanFeedbackMechanismConstructor = new ( - ...args: any[] -) => HumanFeedbackMechanism +type HumanFeedbackMechanismConstructor< + T extends HumanFeedbackType, + TOutput = any +> = new (...args: any[]) => HumanFeedbackMechanism /** * Options for human feedback. */ -export type HumanFeedbackOptions = { +export type HumanFeedbackOptions = { /** * What type of feedback to request. */ @@ -64,7 +65,7 @@ export type HumanFeedbackOptions = { /** * The human feedback mechanism to use for this task. */ - mechanism?: HumanFeedbackMechanismConstructor + mechanism?: HumanFeedbackMechanismConstructor } export interface BaseHumanFeedbackMetadata { @@ -125,28 +126,35 @@ export type FeedbackTypeToMetadata = ? HumanFeedbackSelectOneMetadata : HumanFeedbackSelectNMetadata -export abstract class HumanFeedbackMechanism { +export abstract class HumanFeedbackMechanism< + T extends HumanFeedbackType, + TOutput +> { protected _agentic: Agentic protected _task: BaseTask - protected _options: Required> + protected _options: Required> constructor({ task, options }: { task: BaseTask - options: Required> + options: Required> }) { this._agentic = task.agentic this._task = task this._options = options } - protected abstract selectOne(response: any): Promise + protected abstract selectOne( + output: TOutput + ): Promise - protected abstract selectN(response: any): Promise + protected abstract selectN( + response: TOutput + ): Promise protected abstract annotate(): Promise @@ -162,8 +170,8 @@ export abstract class HumanFeedbackMechanism { return this._task.outputSchema.parse(parsedOutput) } - public async interact(response: any): Promise> { - const stringified = JSON.stringify(response, null, 2) + public async interact(output: TOutput): Promise> { + const stringified = JSON.stringify(output, null, 2) const msg = [ 'The following output was generated:', '```', @@ -216,9 +224,17 @@ export abstract class HumanFeedbackMechanism { case HumanFeedbackUserActions.Select: if (this._options.type === 'selectN') { - feedback.selected = await this.selectN(response) + if (!Array.isArray(output)) { + throw new Error('Expected output to be an array') + } + + feedback.selected = await this.selectN(output) } else if (this._options.type === 'selectOne') { - feedback.chosen = await this.selectOne(response) + if (!Array.isArray(output)) { + throw new Error('Expected output to be an array') + } + + feedback.chosen = await this.selectOne(output) } break @@ -241,9 +257,9 @@ export abstract class HumanFeedbackMechanism { } } -export function withHumanFeedback( - task: BaseTask, - options: HumanFeedbackOptions = {} +export function withHumanFeedback( + task: BaseTask, + options: HumanFeedbackOptions = {} ) { task = task.clone() @@ -251,7 +267,7 @@ export function withHumanFeedback( const instanceDefaults = task.agentic.humanFeedbackDefaults // Use Object.assign to merge the options, instance defaults, and hard-coded defaults - const finalOptions: HumanFeedbackOptions = Object.assign( + const finalOptions: HumanFeedbackOptions = Object.assign( { type: 'confirm', bail: false, @@ -278,7 +294,7 @@ export function withHumanFeedback( const originalCall = task.callWithMetadata.bind(task) - task.callWithMetadata = async function (input?: T) { + task.callWithMetadata = async function (input?: TInput) { const response = await originalCall(input) const feedback = await feedbackMechanism.interact(response.result) diff --git a/src/human-feedback/slack.ts b/src/human-feedback/slack.ts index 724d01c..4700202 100644 --- a/src/human-feedback/slack.ts +++ b/src/human-feedback/slack.ts @@ -10,8 +10,9 @@ import { } from './feedback' export class HumanFeedbackMechanismSlack< - T extends HumanFeedbackType -> extends HumanFeedbackMechanism { + T extends HumanFeedbackType, + TOutput = any +> extends HumanFeedbackMechanism { private slackClient: SlackClient constructor({ @@ -19,7 +20,7 @@ export class HumanFeedbackMechanismSlack< options }: { task: BaseTask - options: Required> + options: Required> }) { super({ task, options }) this.slackClient = new SlackClient() @@ -68,7 +69,13 @@ export class HumanFeedbackMechanismSlack< return choices[parseInt(response.text)] } - public async selectOne(response: any[]): Promise { + protected async selectOne( + response: TOutput + ): Promise { + if (!Array.isArray(response)) { + throw new Error('selectOne called on non-array response') + } + const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply( { text: @@ -84,7 +91,13 @@ export class HumanFeedbackMechanismSlack< return response[parseInt(selectedOutput)] } - public async selectN(response: any[]): Promise { + protected async selectN( + response: TOutput + ): Promise { + if (!Array.isArray(response)) { + throw new Error('selectN called on non-array response') + } + const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply( { text: @@ -107,6 +120,6 @@ export class HumanFeedbackMechanismSlack< .map((choice) => parseInt(choice)) return response.filter((_, idx) => { return chosenOutputs.includes(idx) - }) + }) as any } } diff --git a/src/human-feedback/twilio.ts b/src/human-feedback/twilio.ts index e41c739..895d10d 100644 --- a/src/human-feedback/twilio.ts +++ b/src/human-feedback/twilio.ts @@ -10,8 +10,9 @@ import { } from './feedback' export class HumanFeedbackMechanismTwilio< - T extends HumanFeedbackType -> extends HumanFeedbackMechanism { + T extends HumanFeedbackType, + TOutput = any +> extends HumanFeedbackMechanism { private twilioClient: TwilioConversationClient constructor({ @@ -19,7 +20,7 @@ export class HumanFeedbackMechanismTwilio< options }: { task: BaseTask - options: Required> + options: Required> }) { super({ task, options }) this.twilioClient = new TwilioConversationClient() @@ -71,7 +72,13 @@ export class HumanFeedbackMechanismTwilio< return choices[parseInt(response.body)] } - public async selectOne(response: any[]): Promise { + protected async selectOne( + response: TOutput + ): Promise { + if (!Array.isArray(response)) { + throw new Error('selectOne called on non-array response') + } + const { body: selectedOutput } = await this.twilioClient.sendAndWaitForReply({ name: 'human-feedback-select', @@ -87,7 +94,13 @@ export class HumanFeedbackMechanismTwilio< return response[parseInt(selectedOutput)] } - public async selectN(response: any[]): Promise { + protected async selectN( + response: TOutput + ): Promise { + if (!Array.isArray(response)) { + throw new Error('selectN called on non-array response') + } + const { body: selectedOutput } = await this.twilioClient.sendAndWaitForReply({ name: 'human-feedback-select', @@ -110,6 +123,6 @@ export class HumanFeedbackMechanismTwilio< .map((choice) => parseInt(choice)) return response.filter((_, idx) => { return chosenOutputs.includes(idx) - }) + }) as any } } diff --git a/src/types.ts b/src/types.ts index da5e8ee..5fc4752 100644 --- a/src/types.ts +++ b/src/types.ts @@ -6,11 +6,14 @@ import type { JsonObject, JsonValue } from 'type-fest' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import type { Agentic } from './agentic' +import type { + FeedbackTypeToMetadata, + HumanFeedbackType +} from './human-feedback' import type { Logger } from './logger' import type { BaseTask } from './task' -export { openai } -export { anthropic } +export { anthropic, openai } export type { Logger } export type { JsonObject, JsonValue } @@ -102,6 +105,9 @@ export interface TaskResponseMetadata extends Record { error?: Error numRetries?: number callId?: string + + // human feedback info + feedback?: FeedbackTypeToMetadata } export interface LLMTaskResponseMetadata<