kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: simplify Task input/output typing
rodzic
113fdebee3
commit
9fa5f26215
|
@ -1,6 +1,5 @@
|
|||
import * as anthropic from '@anthropic-ai/sdk'
|
||||
import { type SetOptional } from 'type-fest'
|
||||
import { ZodTypeAny, z } from 'zod'
|
||||
|
||||
import * as types from '@/types'
|
||||
import { defaultAnthropicModel } from '@/constants'
|
||||
|
@ -10,8 +9,8 @@ import { BaseChatModel } from './llm'
|
|||
const defaultStopSequences = [anthropic.HUMAN_PROMPT]
|
||||
|
||||
export class AnthropicChatModel<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>
|
||||
TInput = any,
|
||||
TOutput = string
|
||||
> extends BaseChatModel<
|
||||
TInput,
|
||||
TOutput,
|
||||
|
|
|
@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
|
|||
import pMap from 'p-map'
|
||||
import { dedent } from 'ts-dedent'
|
||||
import { type SetRequired } from 'type-fest'
|
||||
import { ZodTypeAny, z } from 'zod'
|
||||
import { ZodType, z } from 'zod'
|
||||
import { printNode, zodToTs } from 'zod-to-ts'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
|
@ -20,12 +20,12 @@ import {
|
|||
} from '@/utils'
|
||||
|
||||
export abstract class BaseLLM<
|
||||
TInput extends ZodTypeAny = z.ZodVoid,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TInput = void,
|
||||
TOutput = string,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseTask<TInput, TOutput> {
|
||||
protected _inputSchema: TInput | undefined
|
||||
protected _outputSchema: TOutput | undefined
|
||||
protected _inputSchema: ZodType<TInput> | undefined
|
||||
protected _outputSchema: ZodType<TOutput> | undefined
|
||||
|
||||
protected _provider: string
|
||||
protected _model: string
|
||||
|
@ -50,36 +50,33 @@ export abstract class BaseLLM<
|
|||
this._examples = options.examples
|
||||
}
|
||||
|
||||
input<U extends ZodTypeAny = TInput>(
|
||||
inputSchema: U
|
||||
): BaseLLM<U, TOutput, TModelParams> {
|
||||
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema =
|
||||
inputSchema
|
||||
return this as unknown as BaseLLM<U, TOutput, TModelParams>
|
||||
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
|
||||
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
|
||||
refinedInstance._inputSchema = inputSchema
|
||||
return refinedInstance
|
||||
}
|
||||
|
||||
output<U extends ZodTypeAny = TOutput>(
|
||||
outputSchema: U
|
||||
): BaseLLM<TInput, U, TModelParams> {
|
||||
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema =
|
||||
outputSchema
|
||||
return this as unknown as BaseLLM<TInput, U, TModelParams>
|
||||
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
|
||||
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
|
||||
refinedInstance._outputSchema = outputSchema
|
||||
return refinedInstance
|
||||
}
|
||||
|
||||
public override get inputSchema(): TInput {
|
||||
public override get inputSchema(): ZodType<TInput> {
|
||||
if (this._inputSchema) {
|
||||
return this._inputSchema
|
||||
} else {
|
||||
return z.void() as TInput
|
||||
// TODO: improve typing
|
||||
return z.void() as unknown as ZodType<TInput>
|
||||
}
|
||||
}
|
||||
|
||||
public override get outputSchema(): TOutput {
|
||||
public override get outputSchema(): ZodType<TOutput> {
|
||||
if (this._outputSchema) {
|
||||
return this._outputSchema
|
||||
} else {
|
||||
// TODO: improve typing
|
||||
return z.string() as unknown as TOutput
|
||||
return z.string() as unknown as ZodType<TOutput>
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,8 +122,8 @@ export abstract class BaseLLM<
|
|||
}
|
||||
|
||||
export abstract class BaseChatModel<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TInput = void,
|
||||
TOutput = string,
|
||||
TModelParams extends Record<string, any> = Record<string, any>,
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
||||
|
@ -148,8 +145,8 @@ export abstract class BaseChatModel<
|
|||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||
|
||||
public async buildMessages(
|
||||
input?: types.ParsedData<TInput>,
|
||||
ctx?: types.TaskCallContext
|
||||
input?: TInput,
|
||||
ctx?: types.TaskCallContext<TInput>
|
||||
) {
|
||||
if (this._inputSchema) {
|
||||
// TODO: handle errors gracefully
|
||||
|
@ -222,7 +219,7 @@ export abstract class BaseChatModel<
|
|||
|
||||
protected override async _call(
|
||||
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
||||
): Promise<types.ParsedData<TOutput>> {
|
||||
): Promise<TOutput> {
|
||||
const messages = await this.buildMessages(ctx.input, ctx)
|
||||
|
||||
console.log('>>>')
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import { type SetOptional } from 'type-fest'
|
||||
import { ZodTypeAny, z } from 'zod'
|
||||
|
||||
import * as types from '@/types'
|
||||
import { defaultOpenAIModel } from '@/constants'
|
||||
|
@ -7,8 +6,8 @@ import { defaultOpenAIModel } from '@/constants'
|
|||
import { BaseChatModel } from './llm'
|
||||
|
||||
export class OpenAIChatModel<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>
|
||||
TInput = any,
|
||||
TOutput = string
|
||||
> extends BaseChatModel<
|
||||
TInput,
|
||||
TOutput,
|
||||
|
|
21
src/task.ts
21
src/task.ts
|
@ -1,5 +1,5 @@
|
|||
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||
import { ZodTypeAny } from 'zod'
|
||||
import { ZodType } from 'zod'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
import * as types from '@/types'
|
||||
|
@ -18,10 +18,7 @@ import { Agentic } from '@/agentic'
|
|||
* - Native function calls
|
||||
* - Invoking sub-agents
|
||||
*/
|
||||
export abstract class BaseTask<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = ZodTypeAny
|
||||
> {
|
||||
export abstract class BaseTask<TInput = void, TOutput = string> {
|
||||
protected _agentic: Agentic
|
||||
protected _id: string
|
||||
|
||||
|
@ -46,8 +43,8 @@ export abstract class BaseTask<
|
|||
return this._id
|
||||
}
|
||||
|
||||
public abstract get inputSchema(): TInput
|
||||
public abstract get outputSchema(): TOutput
|
||||
public abstract get inputSchema(): ZodType<TInput>
|
||||
public abstract get outputSchema(): ZodType<TOutput>
|
||||
|
||||
public abstract get name(): string
|
||||
|
||||
|
@ -74,15 +71,13 @@ export abstract class BaseTask<
|
|||
return this
|
||||
}
|
||||
|
||||
public async call(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.ParsedData<TOutput>> {
|
||||
public async call(input?: TInput): Promise<TOutput> {
|
||||
const res = await this.callWithMetadata(input)
|
||||
return res.result
|
||||
}
|
||||
|
||||
public async callWithMetadata(
|
||||
input?: types.ParsedData<TInput>
|
||||
input?: TInput
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
if (this.inputSchema) {
|
||||
const safeInput = this.inputSchema.safeParse(input)
|
||||
|
@ -134,9 +129,7 @@ export abstract class BaseTask<
|
|||
}
|
||||
}
|
||||
|
||||
protected abstract _call(
|
||||
ctx: types.TaskCallContext<TInput>
|
||||
): Promise<types.ParsedData<TOutput>>
|
||||
protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>
|
||||
|
||||
// TODO
|
||||
// abstract stream({
|
||||
|
|
|
@ -31,8 +31,8 @@ export type MetaphorSearchToolOutput = z.infer<
|
|||
>
|
||||
|
||||
export class MetaphorSearchTool extends BaseTask<
|
||||
typeof MetaphorSearchToolInputSchema,
|
||||
typeof MetaphorSearchToolOutputSchema
|
||||
MetaphorSearchToolInput,
|
||||
MetaphorSearchToolOutput
|
||||
> {
|
||||
_metaphorClient: MetaphorClient
|
||||
|
||||
|
@ -65,7 +65,7 @@ export class MetaphorSearchTool extends BaseTask<
|
|||
}
|
||||
|
||||
protected override async _call(
|
||||
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
|
||||
ctx: types.TaskCallContext<MetaphorSearchToolInput>
|
||||
): Promise<MetaphorSearchToolOutput> {
|
||||
// TODO: test required inputs
|
||||
return this._metaphorClient.search({
|
||||
|
|
|
@ -36,8 +36,8 @@ export type NovuNotificationToolOutput = z.infer<
|
|||
>
|
||||
|
||||
export class NovuNotificationTool extends BaseTask<
|
||||
typeof NovuNotificationToolInputSchema,
|
||||
typeof NovuNotificationToolOutputSchema
|
||||
NovuNotificationToolInput,
|
||||
NovuNotificationToolOutput
|
||||
> {
|
||||
_novuClient: NovuClient
|
||||
|
||||
|
@ -68,7 +68,7 @@ export class NovuNotificationTool extends BaseTask<
|
|||
}
|
||||
|
||||
protected override async _call(
|
||||
ctx: types.TaskCallContext<typeof NovuNotificationToolInputSchema>
|
||||
ctx: types.TaskCallContext<NovuNotificationToolInput>
|
||||
): Promise<NovuNotificationToolOutput> {
|
||||
return this._novuClient.triggerEvent(ctx.input!)
|
||||
}
|
||||
|
|
26
src/types.ts
26
src/types.ts
|
@ -2,7 +2,7 @@ import * as anthropic from '@anthropic-ai/sdk'
|
|||
import * as openai from 'openai-fetch'
|
||||
import type { Options as RetryOptions } from 'p-retry'
|
||||
import type { JsonObject } from 'type-fest'
|
||||
import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod'
|
||||
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||
|
||||
import type { Agentic } from './agentic'
|
||||
|
||||
|
@ -31,12 +31,12 @@ export interface BaseTaskOptions {
|
|||
}
|
||||
|
||||
export interface BaseLLMOptions<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TInput = void,
|
||||
TOutput = string,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseTaskOptions {
|
||||
inputSchema?: TInput
|
||||
outputSchema?: TOutput
|
||||
inputSchema?: ZodType<TInput>
|
||||
outputSchema?: ZodType<TOutput>
|
||||
|
||||
provider?: string
|
||||
model?: string
|
||||
|
@ -45,8 +45,8 @@ export interface BaseLLMOptions<
|
|||
}
|
||||
|
||||
export interface LLMOptions<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TInput = void,
|
||||
TOutput = string,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||
promptTemplate?: string
|
||||
|
@ -69,8 +69,8 @@ export interface ChatMessage {
|
|||
}
|
||||
|
||||
export interface ChatModelOptions<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TInput = void,
|
||||
TOutput = string,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||
messages: ChatMessage[]
|
||||
|
@ -120,18 +120,18 @@ export interface LLMTaskResponseMetadata<
|
|||
}
|
||||
|
||||
export interface TaskResponse<
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TOutput = string,
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
result: ParsedData<TOutput>
|
||||
result: TOutput
|
||||
metadata: TMetadata
|
||||
}
|
||||
|
||||
export interface TaskCallContext<
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TInput = void,
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
input?: ParsedData<TInput>
|
||||
input?: TInput
|
||||
retryMessage?: string
|
||||
|
||||
attemptNumber: number
|
||||
|
|
Ładowanie…
Reference in New Issue