feat: improve type safety

old-agentic-v1^2
Philipp Burckhardt 2023-06-13 21:25:40 -04:00 zatwierdzone przez Travis Fischer
rodzic 5bddd1f04a
commit 0798a24fdc
5 zmienionych plików z 101 dodań i 40 usunięć

Wyświetl plik

@ -11,8 +11,9 @@ import {
} from './feedback' } from './feedback'
export class HumanFeedbackMechanismCLI< export class HumanFeedbackMechanismCLI<
T extends HumanFeedbackType T extends HumanFeedbackType,
> extends HumanFeedbackMechanism<T> { TOutput = any
> extends HumanFeedbackMechanism<T, TOutput> {
/** /**
* Prompt the user to select one of a list of options. * Prompt the user to select one of a list of options.
*/ */
@ -43,19 +44,31 @@ export class HumanFeedbackMechanismCLI<
}) })
} }
protected async selectOne(response: any[]): Promise<void> { protected async selectOne(
response: TOutput
): Promise<TOutput extends (infer U)[] ? U : never> {
if (!Array.isArray(response)) {
throw new Error('selectOne called on non-array response')
}
const choices = response.map((option) => ({ const choices = response.map((option) => ({
name: option, name: String(option),
value: option value: option
})) }))
return select({ message: 'Pick one output:', choices }) return select({ message: 'Pick one output:', choices })
} }
protected async selectN(response: any[]): Promise<any[]> { protected async selectN(
response: TOutput
): Promise<TOutput extends any[] ? TOutput : never> {
if (!Array.isArray(response)) {
throw new Error('selectN called on non-array response')
}
const choices = response.map((option) => ({ const choices = response.map((option) => ({
name: option, name: String(option),
value: option value: option
})) }))
return checkbox({ message: 'Select outputs:', choices }) return checkbox({ message: 'Select outputs:', choices }) as any
} }
} }

Wyświetl plik

@ -33,14 +33,15 @@ export const HumanFeedbackUserActionMessages: Record<
*/ */
export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN' export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN'
type HumanFeedbackMechanismConstructor<T extends HumanFeedbackType> = new ( type HumanFeedbackMechanismConstructor<
...args: any[] T extends HumanFeedbackType,
) => HumanFeedbackMechanism<T> TOutput = any
> = new (...args: any[]) => HumanFeedbackMechanism<T, TOutput>
/** /**
* Options for human feedback. * Options for human feedback.
*/ */
export type HumanFeedbackOptions<T extends HumanFeedbackType> = { export type HumanFeedbackOptions<T extends HumanFeedbackType, TOutput> = {
/** /**
* What type of feedback to request. * What type of feedback to request.
*/ */
@ -64,7 +65,7 @@ export type HumanFeedbackOptions<T extends HumanFeedbackType> = {
/** /**
* The human feedback mechanism to use for this task. * The human feedback mechanism to use for this task.
*/ */
mechanism?: HumanFeedbackMechanismConstructor<T> mechanism?: HumanFeedbackMechanismConstructor<T, TOutput>
} }
export interface BaseHumanFeedbackMetadata { export interface BaseHumanFeedbackMetadata {
@ -125,28 +126,35 @@ export type FeedbackTypeToMetadata<T extends HumanFeedbackType> =
? HumanFeedbackSelectOneMetadata ? HumanFeedbackSelectOneMetadata
: HumanFeedbackSelectNMetadata : HumanFeedbackSelectNMetadata
export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> { export abstract class HumanFeedbackMechanism<
T extends HumanFeedbackType,
TOutput
> {
protected _agentic: Agentic protected _agentic: Agentic
protected _task: BaseTask protected _task: BaseTask
protected _options: Required<HumanFeedbackOptions<T>> protected _options: Required<HumanFeedbackOptions<T, TOutput>>
constructor({ constructor({
task, task,
options options
}: { }: {
task: BaseTask task: BaseTask
options: Required<HumanFeedbackOptions<T>> options: Required<HumanFeedbackOptions<T, TOutput>>
}) { }) {
this._agentic = task.agentic this._agentic = task.agentic
this._task = task this._task = task
this._options = options this._options = options
} }
protected abstract selectOne(response: any): Promise<any> protected abstract selectOne(
output: TOutput
): Promise<TOutput extends any[] ? TOutput[0] : never>
protected abstract selectN(response: any): Promise<any> protected abstract selectN(
response: TOutput
): Promise<TOutput extends any[] ? TOutput : never>
protected abstract annotate(): Promise<string> protected abstract annotate(): Promise<string>
@ -162,8 +170,8 @@ export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
return this._task.outputSchema.parse(parsedOutput) return this._task.outputSchema.parse(parsedOutput)
} }
public async interact(response: any): Promise<FeedbackTypeToMetadata<T>> { public async interact(output: TOutput): Promise<FeedbackTypeToMetadata<T>> {
const stringified = JSON.stringify(response, null, 2) const stringified = JSON.stringify(output, null, 2)
const msg = [ const msg = [
'The following output was generated:', 'The following output was generated:',
'```', '```',
@ -216,9 +224,17 @@ export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
case HumanFeedbackUserActions.Select: case HumanFeedbackUserActions.Select:
if (this._options.type === 'selectN') { if (this._options.type === 'selectN') {
feedback.selected = await this.selectN(response) if (!Array.isArray(output)) {
throw new Error('Expected output to be an array')
}
feedback.selected = await this.selectN(output)
} else if (this._options.type === 'selectOne') { } else if (this._options.type === 'selectOne') {
feedback.chosen = await this.selectOne(response) if (!Array.isArray(output)) {
throw new Error('Expected output to be an array')
}
feedback.chosen = await this.selectOne(output)
} }
break break
@ -241,9 +257,9 @@ export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
} }
} }
export function withHumanFeedback<T, U, V extends HumanFeedbackType>( export function withHumanFeedback<TInput, TOutput, V extends HumanFeedbackType>(
task: BaseTask<T, U>, task: BaseTask<TInput, TOutput>,
options: HumanFeedbackOptions<V> = {} options: HumanFeedbackOptions<V, TOutput> = {}
) { ) {
task = task.clone() task = task.clone()
@ -251,7 +267,7 @@ export function withHumanFeedback<T, U, V extends HumanFeedbackType>(
const instanceDefaults = task.agentic.humanFeedbackDefaults const instanceDefaults = task.agentic.humanFeedbackDefaults
// Use Object.assign to merge the options, instance defaults, and hard-coded defaults // Use Object.assign to merge the options, instance defaults, and hard-coded defaults
const finalOptions: HumanFeedbackOptions<V> = Object.assign( const finalOptions: HumanFeedbackOptions<V, TOutput> = Object.assign(
{ {
type: 'confirm', type: 'confirm',
bail: false, bail: false,
@ -278,7 +294,7 @@ export function withHumanFeedback<T, U, V extends HumanFeedbackType>(
const originalCall = task.callWithMetadata.bind(task) const originalCall = task.callWithMetadata.bind(task)
task.callWithMetadata = async function (input?: T) { task.callWithMetadata = async function (input?: TInput) {
const response = await originalCall(input) const response = await originalCall(input)
const feedback = await feedbackMechanism.interact(response.result) const feedback = await feedbackMechanism.interact(response.result)

Wyświetl plik

@ -10,8 +10,9 @@ import {
} from './feedback' } from './feedback'
export class HumanFeedbackMechanismSlack< export class HumanFeedbackMechanismSlack<
T extends HumanFeedbackType T extends HumanFeedbackType,
> extends HumanFeedbackMechanism<T> { TOutput = any
> extends HumanFeedbackMechanism<T, TOutput> {
private slackClient: SlackClient private slackClient: SlackClient
constructor({ constructor({
@ -19,7 +20,7 @@ export class HumanFeedbackMechanismSlack<
options options
}: { }: {
task: BaseTask task: BaseTask
options: Required<HumanFeedbackOptions<T>> options: Required<HumanFeedbackOptions<T, TOutput>>
}) { }) {
super({ task, options }) super({ task, options })
this.slackClient = new SlackClient() this.slackClient = new SlackClient()
@ -68,7 +69,13 @@ export class HumanFeedbackMechanismSlack<
return choices[parseInt(response.text)] return choices[parseInt(response.text)]
} }
public async selectOne(response: any[]): Promise<any> { protected async selectOne(
response: TOutput
): Promise<TOutput extends (infer U)[] ? U : never> {
if (!Array.isArray(response)) {
throw new Error('selectOne called on non-array response')
}
const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply( const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply(
{ {
text: text:
@ -84,7 +91,13 @@ export class HumanFeedbackMechanismSlack<
return response[parseInt(selectedOutput)] return response[parseInt(selectedOutput)]
} }
public async selectN(response: any[]): Promise<any[]> { protected async selectN(
response: TOutput
): Promise<TOutput extends any[] ? TOutput : never> {
if (!Array.isArray(response)) {
throw new Error('selectN called on non-array response')
}
const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply( const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply(
{ {
text: text:
@ -107,6 +120,6 @@ export class HumanFeedbackMechanismSlack<
.map((choice) => parseInt(choice)) .map((choice) => parseInt(choice))
return response.filter((_, idx) => { return response.filter((_, idx) => {
return chosenOutputs.includes(idx) return chosenOutputs.includes(idx)
}) }) as any
} }
} }

Wyświetl plik

@ -10,8 +10,9 @@ import {
} from './feedback' } from './feedback'
export class HumanFeedbackMechanismTwilio< export class HumanFeedbackMechanismTwilio<
T extends HumanFeedbackType T extends HumanFeedbackType,
> extends HumanFeedbackMechanism<T> { TOutput = any
> extends HumanFeedbackMechanism<T, TOutput> {
private twilioClient: TwilioConversationClient private twilioClient: TwilioConversationClient
constructor({ constructor({
@ -19,7 +20,7 @@ export class HumanFeedbackMechanismTwilio<
options options
}: { }: {
task: BaseTask task: BaseTask
options: Required<HumanFeedbackOptions<T>> options: Required<HumanFeedbackOptions<T, TOutput>>
}) { }) {
super({ task, options }) super({ task, options })
this.twilioClient = new TwilioConversationClient() this.twilioClient = new TwilioConversationClient()
@ -71,7 +72,13 @@ export class HumanFeedbackMechanismTwilio<
return choices[parseInt(response.body)] return choices[parseInt(response.body)]
} }
public async selectOne(response: any[]): Promise<any> { protected async selectOne(
response: TOutput
): Promise<TOutput extends (infer U)[] ? U : never> {
if (!Array.isArray(response)) {
throw new Error('selectOne called on non-array response')
}
const { body: selectedOutput } = const { body: selectedOutput } =
await this.twilioClient.sendAndWaitForReply({ await this.twilioClient.sendAndWaitForReply({
name: 'human-feedback-select', name: 'human-feedback-select',
@ -87,7 +94,13 @@ export class HumanFeedbackMechanismTwilio<
return response[parseInt(selectedOutput)] return response[parseInt(selectedOutput)]
} }
public async selectN(response: any[]): Promise<any[]> { protected async selectN(
response: TOutput
): Promise<TOutput extends any[] ? TOutput : never> {
if (!Array.isArray(response)) {
throw new Error('selectN called on non-array response')
}
const { body: selectedOutput } = const { body: selectedOutput } =
await this.twilioClient.sendAndWaitForReply({ await this.twilioClient.sendAndWaitForReply({
name: 'human-feedback-select', name: 'human-feedback-select',
@ -110,6 +123,6 @@ export class HumanFeedbackMechanismTwilio<
.map((choice) => parseInt(choice)) .map((choice) => parseInt(choice))
return response.filter((_, idx) => { return response.filter((_, idx) => {
return chosenOutputs.includes(idx) return chosenOutputs.includes(idx)
}) }) as any
} }
} }

Wyświetl plik

@ -6,11 +6,14 @@ import type { JsonObject, JsonValue } from 'type-fest'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic' import type { Agentic } from './agentic'
import type {
FeedbackTypeToMetadata,
HumanFeedbackType
} from './human-feedback'
import type { Logger } from './logger' import type { Logger } from './logger'
import type { BaseTask } from './task' import type { BaseTask } from './task'
export { openai } export { anthropic, openai }
export { anthropic }
export type { Logger } export type { Logger }
export type { JsonObject, JsonValue } export type { JsonObject, JsonValue }
@ -102,6 +105,9 @@ export interface TaskResponseMetadata extends Record<string, any> {
error?: Error error?: Error
numRetries?: number numRetries?: number
callId?: string callId?: string
// human feedback info
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>
} }
export interface LLMTaskResponseMetadata< export interface LLMTaskResponseMetadata<