kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: improve type safety
rodzic
5bddd1f04a
commit
0798a24fdc
|
@ -11,8 +11,9 @@ import {
|
||||||
} from './feedback'
|
} from './feedback'
|
||||||
|
|
||||||
export class HumanFeedbackMechanismCLI<
|
export class HumanFeedbackMechanismCLI<
|
||||||
T extends HumanFeedbackType
|
T extends HumanFeedbackType,
|
||||||
> extends HumanFeedbackMechanism<T> {
|
TOutput = any
|
||||||
|
> extends HumanFeedbackMechanism<T, TOutput> {
|
||||||
/**
|
/**
|
||||||
* Prompt the user to select one of a list of options.
|
* Prompt the user to select one of a list of options.
|
||||||
*/
|
*/
|
||||||
|
@ -43,19 +44,31 @@ export class HumanFeedbackMechanismCLI<
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
protected async selectOne(response: any[]): Promise<void> {
|
protected async selectOne(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends (infer U)[] ? U : never> {
|
||||||
|
if (!Array.isArray(response)) {
|
||||||
|
throw new Error('selectOne called on non-array response')
|
||||||
|
}
|
||||||
|
|
||||||
const choices = response.map((option) => ({
|
const choices = response.map((option) => ({
|
||||||
name: option,
|
name: String(option),
|
||||||
value: option
|
value: option
|
||||||
}))
|
}))
|
||||||
return select({ message: 'Pick one output:', choices })
|
return select({ message: 'Pick one output:', choices })
|
||||||
}
|
}
|
||||||
|
|
||||||
protected async selectN(response: any[]): Promise<any[]> {
|
protected async selectN(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends any[] ? TOutput : never> {
|
||||||
|
if (!Array.isArray(response)) {
|
||||||
|
throw new Error('selectN called on non-array response')
|
||||||
|
}
|
||||||
|
|
||||||
const choices = response.map((option) => ({
|
const choices = response.map((option) => ({
|
||||||
name: option,
|
name: String(option),
|
||||||
value: option
|
value: option
|
||||||
}))
|
}))
|
||||||
return checkbox({ message: 'Select outputs:', choices })
|
return checkbox({ message: 'Select outputs:', choices }) as any
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,14 +33,15 @@ export const HumanFeedbackUserActionMessages: Record<
|
||||||
*/
|
*/
|
||||||
export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN'
|
export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN'
|
||||||
|
|
||||||
type HumanFeedbackMechanismConstructor<T extends HumanFeedbackType> = new (
|
type HumanFeedbackMechanismConstructor<
|
||||||
...args: any[]
|
T extends HumanFeedbackType,
|
||||||
) => HumanFeedbackMechanism<T>
|
TOutput = any
|
||||||
|
> = new (...args: any[]) => HumanFeedbackMechanism<T, TOutput>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Options for human feedback.
|
* Options for human feedback.
|
||||||
*/
|
*/
|
||||||
export type HumanFeedbackOptions<T extends HumanFeedbackType> = {
|
export type HumanFeedbackOptions<T extends HumanFeedbackType, TOutput> = {
|
||||||
/**
|
/**
|
||||||
* What type of feedback to request.
|
* What type of feedback to request.
|
||||||
*/
|
*/
|
||||||
|
@ -64,7 +65,7 @@ export type HumanFeedbackOptions<T extends HumanFeedbackType> = {
|
||||||
/**
|
/**
|
||||||
* The human feedback mechanism to use for this task.
|
* The human feedback mechanism to use for this task.
|
||||||
*/
|
*/
|
||||||
mechanism?: HumanFeedbackMechanismConstructor<T>
|
mechanism?: HumanFeedbackMechanismConstructor<T, TOutput>
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface BaseHumanFeedbackMetadata {
|
export interface BaseHumanFeedbackMetadata {
|
||||||
|
@ -125,28 +126,35 @@ export type FeedbackTypeToMetadata<T extends HumanFeedbackType> =
|
||||||
? HumanFeedbackSelectOneMetadata
|
? HumanFeedbackSelectOneMetadata
|
||||||
: HumanFeedbackSelectNMetadata
|
: HumanFeedbackSelectNMetadata
|
||||||
|
|
||||||
export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
|
export abstract class HumanFeedbackMechanism<
|
||||||
|
T extends HumanFeedbackType,
|
||||||
|
TOutput
|
||||||
|
> {
|
||||||
protected _agentic: Agentic
|
protected _agentic: Agentic
|
||||||
|
|
||||||
protected _task: BaseTask
|
protected _task: BaseTask
|
||||||
|
|
||||||
protected _options: Required<HumanFeedbackOptions<T>>
|
protected _options: Required<HumanFeedbackOptions<T, TOutput>>
|
||||||
|
|
||||||
constructor({
|
constructor({
|
||||||
task,
|
task,
|
||||||
options
|
options
|
||||||
}: {
|
}: {
|
||||||
task: BaseTask
|
task: BaseTask
|
||||||
options: Required<HumanFeedbackOptions<T>>
|
options: Required<HumanFeedbackOptions<T, TOutput>>
|
||||||
}) {
|
}) {
|
||||||
this._agentic = task.agentic
|
this._agentic = task.agentic
|
||||||
this._task = task
|
this._task = task
|
||||||
this._options = options
|
this._options = options
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract selectOne(response: any): Promise<any>
|
protected abstract selectOne(
|
||||||
|
output: TOutput
|
||||||
|
): Promise<TOutput extends any[] ? TOutput[0] : never>
|
||||||
|
|
||||||
protected abstract selectN(response: any): Promise<any>
|
protected abstract selectN(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends any[] ? TOutput : never>
|
||||||
|
|
||||||
protected abstract annotate(): Promise<string>
|
protected abstract annotate(): Promise<string>
|
||||||
|
|
||||||
|
@ -162,8 +170,8 @@ export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
|
||||||
return this._task.outputSchema.parse(parsedOutput)
|
return this._task.outputSchema.parse(parsedOutput)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async interact(response: any): Promise<FeedbackTypeToMetadata<T>> {
|
public async interact(output: TOutput): Promise<FeedbackTypeToMetadata<T>> {
|
||||||
const stringified = JSON.stringify(response, null, 2)
|
const stringified = JSON.stringify(output, null, 2)
|
||||||
const msg = [
|
const msg = [
|
||||||
'The following output was generated:',
|
'The following output was generated:',
|
||||||
'```',
|
'```',
|
||||||
|
@ -216,9 +224,17 @@ export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
|
||||||
|
|
||||||
case HumanFeedbackUserActions.Select:
|
case HumanFeedbackUserActions.Select:
|
||||||
if (this._options.type === 'selectN') {
|
if (this._options.type === 'selectN') {
|
||||||
feedback.selected = await this.selectN(response)
|
if (!Array.isArray(output)) {
|
||||||
|
throw new Error('Expected output to be an array')
|
||||||
|
}
|
||||||
|
|
||||||
|
feedback.selected = await this.selectN(output)
|
||||||
} else if (this._options.type === 'selectOne') {
|
} else if (this._options.type === 'selectOne') {
|
||||||
feedback.chosen = await this.selectOne(response)
|
if (!Array.isArray(output)) {
|
||||||
|
throw new Error('Expected output to be an array')
|
||||||
|
}
|
||||||
|
|
||||||
|
feedback.chosen = await this.selectOne(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
break
|
break
|
||||||
|
@ -241,9 +257,9 @@ export abstract class HumanFeedbackMechanism<T extends HumanFeedbackType> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function withHumanFeedback<T, U, V extends HumanFeedbackType>(
|
export function withHumanFeedback<TInput, TOutput, V extends HumanFeedbackType>(
|
||||||
task: BaseTask<T, U>,
|
task: BaseTask<TInput, TOutput>,
|
||||||
options: HumanFeedbackOptions<V> = {}
|
options: HumanFeedbackOptions<V, TOutput> = {}
|
||||||
) {
|
) {
|
||||||
task = task.clone()
|
task = task.clone()
|
||||||
|
|
||||||
|
@ -251,7 +267,7 @@ export function withHumanFeedback<T, U, V extends HumanFeedbackType>(
|
||||||
const instanceDefaults = task.agentic.humanFeedbackDefaults
|
const instanceDefaults = task.agentic.humanFeedbackDefaults
|
||||||
|
|
||||||
// Use Object.assign to merge the options, instance defaults, and hard-coded defaults
|
// Use Object.assign to merge the options, instance defaults, and hard-coded defaults
|
||||||
const finalOptions: HumanFeedbackOptions<V> = Object.assign(
|
const finalOptions: HumanFeedbackOptions<V, TOutput> = Object.assign(
|
||||||
{
|
{
|
||||||
type: 'confirm',
|
type: 'confirm',
|
||||||
bail: false,
|
bail: false,
|
||||||
|
@ -278,7 +294,7 @@ export function withHumanFeedback<T, U, V extends HumanFeedbackType>(
|
||||||
|
|
||||||
const originalCall = task.callWithMetadata.bind(task)
|
const originalCall = task.callWithMetadata.bind(task)
|
||||||
|
|
||||||
task.callWithMetadata = async function (input?: T) {
|
task.callWithMetadata = async function (input?: TInput) {
|
||||||
const response = await originalCall(input)
|
const response = await originalCall(input)
|
||||||
|
|
||||||
const feedback = await feedbackMechanism.interact(response.result)
|
const feedback = await feedbackMechanism.interact(response.result)
|
||||||
|
|
|
@ -10,8 +10,9 @@ import {
|
||||||
} from './feedback'
|
} from './feedback'
|
||||||
|
|
||||||
export class HumanFeedbackMechanismSlack<
|
export class HumanFeedbackMechanismSlack<
|
||||||
T extends HumanFeedbackType
|
T extends HumanFeedbackType,
|
||||||
> extends HumanFeedbackMechanism<T> {
|
TOutput = any
|
||||||
|
> extends HumanFeedbackMechanism<T, TOutput> {
|
||||||
private slackClient: SlackClient
|
private slackClient: SlackClient
|
||||||
|
|
||||||
constructor({
|
constructor({
|
||||||
|
@ -19,7 +20,7 @@ export class HumanFeedbackMechanismSlack<
|
||||||
options
|
options
|
||||||
}: {
|
}: {
|
||||||
task: BaseTask
|
task: BaseTask
|
||||||
options: Required<HumanFeedbackOptions<T>>
|
options: Required<HumanFeedbackOptions<T, TOutput>>
|
||||||
}) {
|
}) {
|
||||||
super({ task, options })
|
super({ task, options })
|
||||||
this.slackClient = new SlackClient()
|
this.slackClient = new SlackClient()
|
||||||
|
@ -68,7 +69,13 @@ export class HumanFeedbackMechanismSlack<
|
||||||
return choices[parseInt(response.text)]
|
return choices[parseInt(response.text)]
|
||||||
}
|
}
|
||||||
|
|
||||||
public async selectOne(response: any[]): Promise<any> {
|
protected async selectOne(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends (infer U)[] ? U : never> {
|
||||||
|
if (!Array.isArray(response)) {
|
||||||
|
throw new Error('selectOne called on non-array response')
|
||||||
|
}
|
||||||
|
|
||||||
const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply(
|
const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply(
|
||||||
{
|
{
|
||||||
text:
|
text:
|
||||||
|
@ -84,7 +91,13 @@ export class HumanFeedbackMechanismSlack<
|
||||||
return response[parseInt(selectedOutput)]
|
return response[parseInt(selectedOutput)]
|
||||||
}
|
}
|
||||||
|
|
||||||
public async selectN(response: any[]): Promise<any[]> {
|
protected async selectN(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends any[] ? TOutput : never> {
|
||||||
|
if (!Array.isArray(response)) {
|
||||||
|
throw new Error('selectN called on non-array response')
|
||||||
|
}
|
||||||
|
|
||||||
const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply(
|
const { text: selectedOutput } = await this.slackClient.sendAndWaitForReply(
|
||||||
{
|
{
|
||||||
text:
|
text:
|
||||||
|
@ -107,6 +120,6 @@ export class HumanFeedbackMechanismSlack<
|
||||||
.map((choice) => parseInt(choice))
|
.map((choice) => parseInt(choice))
|
||||||
return response.filter((_, idx) => {
|
return response.filter((_, idx) => {
|
||||||
return chosenOutputs.includes(idx)
|
return chosenOutputs.includes(idx)
|
||||||
})
|
}) as any
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,8 +10,9 @@ import {
|
||||||
} from './feedback'
|
} from './feedback'
|
||||||
|
|
||||||
export class HumanFeedbackMechanismTwilio<
|
export class HumanFeedbackMechanismTwilio<
|
||||||
T extends HumanFeedbackType
|
T extends HumanFeedbackType,
|
||||||
> extends HumanFeedbackMechanism<T> {
|
TOutput = any
|
||||||
|
> extends HumanFeedbackMechanism<T, TOutput> {
|
||||||
private twilioClient: TwilioConversationClient
|
private twilioClient: TwilioConversationClient
|
||||||
|
|
||||||
constructor({
|
constructor({
|
||||||
|
@ -19,7 +20,7 @@ export class HumanFeedbackMechanismTwilio<
|
||||||
options
|
options
|
||||||
}: {
|
}: {
|
||||||
task: BaseTask
|
task: BaseTask
|
||||||
options: Required<HumanFeedbackOptions<T>>
|
options: Required<HumanFeedbackOptions<T, TOutput>>
|
||||||
}) {
|
}) {
|
||||||
super({ task, options })
|
super({ task, options })
|
||||||
this.twilioClient = new TwilioConversationClient()
|
this.twilioClient = new TwilioConversationClient()
|
||||||
|
@ -71,7 +72,13 @@ export class HumanFeedbackMechanismTwilio<
|
||||||
return choices[parseInt(response.body)]
|
return choices[parseInt(response.body)]
|
||||||
}
|
}
|
||||||
|
|
||||||
public async selectOne(response: any[]): Promise<any> {
|
protected async selectOne(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends (infer U)[] ? U : never> {
|
||||||
|
if (!Array.isArray(response)) {
|
||||||
|
throw new Error('selectOne called on non-array response')
|
||||||
|
}
|
||||||
|
|
||||||
const { body: selectedOutput } =
|
const { body: selectedOutput } =
|
||||||
await this.twilioClient.sendAndWaitForReply({
|
await this.twilioClient.sendAndWaitForReply({
|
||||||
name: 'human-feedback-select',
|
name: 'human-feedback-select',
|
||||||
|
@ -87,7 +94,13 @@ export class HumanFeedbackMechanismTwilio<
|
||||||
return response[parseInt(selectedOutput)]
|
return response[parseInt(selectedOutput)]
|
||||||
}
|
}
|
||||||
|
|
||||||
public async selectN(response: any[]): Promise<any[]> {
|
protected async selectN(
|
||||||
|
response: TOutput
|
||||||
|
): Promise<TOutput extends any[] ? TOutput : never> {
|
||||||
|
if (!Array.isArray(response)) {
|
||||||
|
throw new Error('selectN called on non-array response')
|
||||||
|
}
|
||||||
|
|
||||||
const { body: selectedOutput } =
|
const { body: selectedOutput } =
|
||||||
await this.twilioClient.sendAndWaitForReply({
|
await this.twilioClient.sendAndWaitForReply({
|
||||||
name: 'human-feedback-select',
|
name: 'human-feedback-select',
|
||||||
|
@ -110,6 +123,6 @@ export class HumanFeedbackMechanismTwilio<
|
||||||
.map((choice) => parseInt(choice))
|
.map((choice) => parseInt(choice))
|
||||||
return response.filter((_, idx) => {
|
return response.filter((_, idx) => {
|
||||||
return chosenOutputs.includes(idx)
|
return chosenOutputs.includes(idx)
|
||||||
})
|
}) as any
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
10
src/types.ts
10
src/types.ts
|
@ -6,11 +6,14 @@ import type { JsonObject, JsonValue } from 'type-fest'
|
||||||
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||||
|
|
||||||
import type { Agentic } from './agentic'
|
import type { Agentic } from './agentic'
|
||||||
|
import type {
|
||||||
|
FeedbackTypeToMetadata,
|
||||||
|
HumanFeedbackType
|
||||||
|
} from './human-feedback'
|
||||||
import type { Logger } from './logger'
|
import type { Logger } from './logger'
|
||||||
import type { BaseTask } from './task'
|
import type { BaseTask } from './task'
|
||||||
|
|
||||||
export { openai }
|
export { anthropic, openai }
|
||||||
export { anthropic }
|
|
||||||
|
|
||||||
export type { Logger }
|
export type { Logger }
|
||||||
export type { JsonObject, JsonValue }
|
export type { JsonObject, JsonValue }
|
||||||
|
@ -102,6 +105,9 @@ export interface TaskResponseMetadata extends Record<string, any> {
|
||||||
error?: Error
|
error?: Error
|
||||||
numRetries?: number
|
numRetries?: number
|
||||||
callId?: string
|
callId?: string
|
||||||
|
|
||||||
|
// human feedback info
|
||||||
|
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface LLMTaskResponseMetadata<
|
export interface LLMTaskResponseMetadata<
|
||||||
|
|
Ładowanie…
Reference in New Issue