kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: simplify Task input/output typing
rodzic
113fdebee3
commit
9fa5f26215
|
@ -1,6 +1,5 @@
|
||||||
import * as anthropic from '@anthropic-ai/sdk'
|
import * as anthropic from '@anthropic-ai/sdk'
|
||||||
import { type SetOptional } from 'type-fest'
|
import { type SetOptional } from 'type-fest'
|
||||||
import { ZodTypeAny, z } from 'zod'
|
|
||||||
|
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
import { defaultAnthropicModel } from '@/constants'
|
import { defaultAnthropicModel } from '@/constants'
|
||||||
|
@ -10,8 +9,8 @@ import { BaseChatModel } from './llm'
|
||||||
const defaultStopSequences = [anthropic.HUMAN_PROMPT]
|
const defaultStopSequences = [anthropic.HUMAN_PROMPT]
|
||||||
|
|
||||||
export class AnthropicChatModel<
|
export class AnthropicChatModel<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = any,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>
|
TOutput = string
|
||||||
> extends BaseChatModel<
|
> extends BaseChatModel<
|
||||||
TInput,
|
TInput,
|
||||||
TOutput,
|
TOutput,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
|
||||||
import pMap from 'p-map'
|
import pMap from 'p-map'
|
||||||
import { dedent } from 'ts-dedent'
|
import { dedent } from 'ts-dedent'
|
||||||
import { type SetRequired } from 'type-fest'
|
import { type SetRequired } from 'type-fest'
|
||||||
import { ZodTypeAny, z } from 'zod'
|
import { ZodType, z } from 'zod'
|
||||||
import { printNode, zodToTs } from 'zod-to-ts'
|
import { printNode, zodToTs } from 'zod-to-ts'
|
||||||
|
|
||||||
import * as errors from '@/errors'
|
import * as errors from '@/errors'
|
||||||
|
@ -20,12 +20,12 @@ import {
|
||||||
} from '@/utils'
|
} from '@/utils'
|
||||||
|
|
||||||
export abstract class BaseLLM<
|
export abstract class BaseLLM<
|
||||||
TInput extends ZodTypeAny = z.ZodVoid,
|
TInput = void,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
TOutput = string,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>
|
TModelParams extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseTask<TInput, TOutput> {
|
> extends BaseTask<TInput, TOutput> {
|
||||||
protected _inputSchema: TInput | undefined
|
protected _inputSchema: ZodType<TInput> | undefined
|
||||||
protected _outputSchema: TOutput | undefined
|
protected _outputSchema: ZodType<TOutput> | undefined
|
||||||
|
|
||||||
protected _provider: string
|
protected _provider: string
|
||||||
protected _model: string
|
protected _model: string
|
||||||
|
@ -50,36 +50,33 @@ export abstract class BaseLLM<
|
||||||
this._examples = options.examples
|
this._examples = options.examples
|
||||||
}
|
}
|
||||||
|
|
||||||
input<U extends ZodTypeAny = TInput>(
|
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
|
||||||
inputSchema: U
|
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
|
||||||
): BaseLLM<U, TOutput, TModelParams> {
|
refinedInstance._inputSchema = inputSchema
|
||||||
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema =
|
return refinedInstance
|
||||||
inputSchema
|
|
||||||
return this as unknown as BaseLLM<U, TOutput, TModelParams>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
output<U extends ZodTypeAny = TOutput>(
|
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
|
||||||
outputSchema: U
|
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
|
||||||
): BaseLLM<TInput, U, TModelParams> {
|
refinedInstance._outputSchema = outputSchema
|
||||||
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema =
|
return refinedInstance
|
||||||
outputSchema
|
|
||||||
return this as unknown as BaseLLM<TInput, U, TModelParams>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get inputSchema(): TInput {
|
public override get inputSchema(): ZodType<TInput> {
|
||||||
if (this._inputSchema) {
|
if (this._inputSchema) {
|
||||||
return this._inputSchema
|
return this._inputSchema
|
||||||
} else {
|
} else {
|
||||||
return z.void() as TInput
|
// TODO: improve typing
|
||||||
|
return z.void() as unknown as ZodType<TInput>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get outputSchema(): TOutput {
|
public override get outputSchema(): ZodType<TOutput> {
|
||||||
if (this._outputSchema) {
|
if (this._outputSchema) {
|
||||||
return this._outputSchema
|
return this._outputSchema
|
||||||
} else {
|
} else {
|
||||||
// TODO: improve typing
|
// TODO: improve typing
|
||||||
return z.string() as unknown as TOutput
|
return z.string() as unknown as ZodType<TOutput>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,8 +122,8 @@ export abstract class BaseLLM<
|
||||||
}
|
}
|
||||||
|
|
||||||
export abstract class BaseChatModel<
|
export abstract class BaseChatModel<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = void,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
TOutput = string,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>,
|
TModelParams extends Record<string, any> = Record<string, any>,
|
||||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
||||||
|
@ -148,8 +145,8 @@ export abstract class BaseChatModel<
|
||||||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||||
|
|
||||||
public async buildMessages(
|
public async buildMessages(
|
||||||
input?: types.ParsedData<TInput>,
|
input?: TInput,
|
||||||
ctx?: types.TaskCallContext
|
ctx?: types.TaskCallContext<TInput>
|
||||||
) {
|
) {
|
||||||
if (this._inputSchema) {
|
if (this._inputSchema) {
|
||||||
// TODO: handle errors gracefully
|
// TODO: handle errors gracefully
|
||||||
|
@ -222,7 +219,7 @@ export abstract class BaseChatModel<
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
||||||
): Promise<types.ParsedData<TOutput>> {
|
): Promise<TOutput> {
|
||||||
const messages = await this.buildMessages(ctx.input, ctx)
|
const messages = await this.buildMessages(ctx.input, ctx)
|
||||||
|
|
||||||
console.log('>>>')
|
console.log('>>>')
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import { type SetOptional } from 'type-fest'
|
import { type SetOptional } from 'type-fest'
|
||||||
import { ZodTypeAny, z } from 'zod'
|
|
||||||
|
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
import { defaultOpenAIModel } from '@/constants'
|
import { defaultOpenAIModel } from '@/constants'
|
||||||
|
@ -7,8 +6,8 @@ import { defaultOpenAIModel } from '@/constants'
|
||||||
import { BaseChatModel } from './llm'
|
import { BaseChatModel } from './llm'
|
||||||
|
|
||||||
export class OpenAIChatModel<
|
export class OpenAIChatModel<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = any,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>
|
TOutput = string
|
||||||
> extends BaseChatModel<
|
> extends BaseChatModel<
|
||||||
TInput,
|
TInput,
|
||||||
TOutput,
|
TOutput,
|
||||||
|
|
21
src/task.ts
21
src/task.ts
|
@ -1,5 +1,5 @@
|
||||||
import pRetry, { FailedAttemptError } from 'p-retry'
|
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||||
import { ZodTypeAny } from 'zod'
|
import { ZodType } from 'zod'
|
||||||
|
|
||||||
import * as errors from '@/errors'
|
import * as errors from '@/errors'
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
|
@ -18,10 +18,7 @@ import { Agentic } from '@/agentic'
|
||||||
* - Native function calls
|
* - Native function calls
|
||||||
* - Invoking sub-agents
|
* - Invoking sub-agents
|
||||||
*/
|
*/
|
||||||
export abstract class BaseTask<
|
export abstract class BaseTask<TInput = void, TOutput = string> {
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
|
||||||
TOutput extends ZodTypeAny = ZodTypeAny
|
|
||||||
> {
|
|
||||||
protected _agentic: Agentic
|
protected _agentic: Agentic
|
||||||
protected _id: string
|
protected _id: string
|
||||||
|
|
||||||
|
@ -46,8 +43,8 @@ export abstract class BaseTask<
|
||||||
return this._id
|
return this._id
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract get inputSchema(): TInput
|
public abstract get inputSchema(): ZodType<TInput>
|
||||||
public abstract get outputSchema(): TOutput
|
public abstract get outputSchema(): ZodType<TOutput>
|
||||||
|
|
||||||
public abstract get name(): string
|
public abstract get name(): string
|
||||||
|
|
||||||
|
@ -74,15 +71,13 @@ export abstract class BaseTask<
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
public async call(
|
public async call(input?: TInput): Promise<TOutput> {
|
||||||
input?: types.ParsedData<TInput>
|
|
||||||
): Promise<types.ParsedData<TOutput>> {
|
|
||||||
const res = await this.callWithMetadata(input)
|
const res = await this.callWithMetadata(input)
|
||||||
return res.result
|
return res.result
|
||||||
}
|
}
|
||||||
|
|
||||||
public async callWithMetadata(
|
public async callWithMetadata(
|
||||||
input?: types.ParsedData<TInput>
|
input?: TInput
|
||||||
): Promise<types.TaskResponse<TOutput>> {
|
): Promise<types.TaskResponse<TOutput>> {
|
||||||
if (this.inputSchema) {
|
if (this.inputSchema) {
|
||||||
const safeInput = this.inputSchema.safeParse(input)
|
const safeInput = this.inputSchema.safeParse(input)
|
||||||
|
@ -134,9 +129,7 @@ export abstract class BaseTask<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract _call(
|
protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>
|
||||||
ctx: types.TaskCallContext<TInput>
|
|
||||||
): Promise<types.ParsedData<TOutput>>
|
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// abstract stream({
|
// abstract stream({
|
||||||
|
|
|
@ -31,8 +31,8 @@ export type MetaphorSearchToolOutput = z.infer<
|
||||||
>
|
>
|
||||||
|
|
||||||
export class MetaphorSearchTool extends BaseTask<
|
export class MetaphorSearchTool extends BaseTask<
|
||||||
typeof MetaphorSearchToolInputSchema,
|
MetaphorSearchToolInput,
|
||||||
typeof MetaphorSearchToolOutputSchema
|
MetaphorSearchToolOutput
|
||||||
> {
|
> {
|
||||||
_metaphorClient: MetaphorClient
|
_metaphorClient: MetaphorClient
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ export class MetaphorSearchTool extends BaseTask<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
|
ctx: types.TaskCallContext<MetaphorSearchToolInput>
|
||||||
): Promise<MetaphorSearchToolOutput> {
|
): Promise<MetaphorSearchToolOutput> {
|
||||||
// TODO: test required inputs
|
// TODO: test required inputs
|
||||||
return this._metaphorClient.search({
|
return this._metaphorClient.search({
|
||||||
|
|
|
@ -36,8 +36,8 @@ export type NovuNotificationToolOutput = z.infer<
|
||||||
>
|
>
|
||||||
|
|
||||||
export class NovuNotificationTool extends BaseTask<
|
export class NovuNotificationTool extends BaseTask<
|
||||||
typeof NovuNotificationToolInputSchema,
|
NovuNotificationToolInput,
|
||||||
typeof NovuNotificationToolOutputSchema
|
NovuNotificationToolOutput
|
||||||
> {
|
> {
|
||||||
_novuClient: NovuClient
|
_novuClient: NovuClient
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ export class NovuNotificationTool extends BaseTask<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
ctx: types.TaskCallContext<typeof NovuNotificationToolInputSchema>
|
ctx: types.TaskCallContext<NovuNotificationToolInput>
|
||||||
): Promise<NovuNotificationToolOutput> {
|
): Promise<NovuNotificationToolOutput> {
|
||||||
return this._novuClient.triggerEvent(ctx.input!)
|
return this._novuClient.triggerEvent(ctx.input!)
|
||||||
}
|
}
|
||||||
|
|
26
src/types.ts
26
src/types.ts
|
@ -2,7 +2,7 @@ import * as anthropic from '@anthropic-ai/sdk'
|
||||||
import * as openai from 'openai-fetch'
|
import * as openai from 'openai-fetch'
|
||||||
import type { Options as RetryOptions } from 'p-retry'
|
import type { Options as RetryOptions } from 'p-retry'
|
||||||
import type { JsonObject } from 'type-fest'
|
import type { JsonObject } from 'type-fest'
|
||||||
import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod'
|
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||||
|
|
||||||
import type { Agentic } from './agentic'
|
import type { Agentic } from './agentic'
|
||||||
|
|
||||||
|
@ -31,12 +31,12 @@ export interface BaseTaskOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface BaseLLMOptions<
|
export interface BaseLLMOptions<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = void,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
TOutput = string,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>
|
TModelParams extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseTaskOptions {
|
> extends BaseTaskOptions {
|
||||||
inputSchema?: TInput
|
inputSchema?: ZodType<TInput>
|
||||||
outputSchema?: TOutput
|
outputSchema?: ZodType<TOutput>
|
||||||
|
|
||||||
provider?: string
|
provider?: string
|
||||||
model?: string
|
model?: string
|
||||||
|
@ -45,8 +45,8 @@ export interface BaseLLMOptions<
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface LLMOptions<
|
export interface LLMOptions<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = void,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
TOutput = string,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>
|
TModelParams extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||||
promptTemplate?: string
|
promptTemplate?: string
|
||||||
|
@ -69,8 +69,8 @@ export interface ChatMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatModelOptions<
|
export interface ChatModelOptions<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = void,
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
TOutput = string,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>
|
TModelParams extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||||
messages: ChatMessage[]
|
messages: ChatMessage[]
|
||||||
|
@ -120,18 +120,18 @@ export interface LLMTaskResponseMetadata<
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TaskResponse<
|
export interface TaskResponse<
|
||||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
TOutput = string,
|
||||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||||
> {
|
> {
|
||||||
result: ParsedData<TOutput>
|
result: TOutput
|
||||||
metadata: TMetadata
|
metadata: TMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TaskCallContext<
|
export interface TaskCallContext<
|
||||||
TInput extends ZodTypeAny = ZodTypeAny,
|
TInput = void,
|
||||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||||
> {
|
> {
|
||||||
input?: ParsedData<TInput>
|
input?: TInput
|
||||||
retryMessage?: string
|
retryMessage?: string
|
||||||
|
|
||||||
attemptNumber: number
|
attemptNumber: number
|
||||||
|
|
Ładowanie…
Reference in New Issue