diff --git a/src/llms/anthropic.ts b/src/llms/anthropic.ts index 10f2b4ff..772d6b98 100644 --- a/src/llms/anthropic.ts +++ b/src/llms/anthropic.ts @@ -1,6 +1,5 @@ import * as anthropic from '@anthropic-ai/sdk' import { type SetOptional } from 'type-fest' -import { ZodTypeAny, z } from 'zod' import * as types from '@/types' import { defaultAnthropicModel } from '@/constants' @@ -10,8 +9,8 @@ import { BaseChatModel } from './llm' const defaultStopSequences = [anthropic.HUMAN_PROMPT] export class AnthropicChatModel< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = z.ZodType + TInput = any, + TOutput = string > extends BaseChatModel< TInput, TOutput, diff --git a/src/llms/llm.ts b/src/llms/llm.ts index 6a5b3ea9..c014242d 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair' import pMap from 'p-map' import { dedent } from 'ts-dedent' import { type SetRequired } from 'type-fest' -import { ZodTypeAny, z } from 'zod' +import { ZodType, z } from 'zod' import { printNode, zodToTs } from 'zod-to-ts' import * as errors from '@/errors' @@ -20,12 +20,12 @@ import { } from '@/utils' export abstract class BaseLLM< - TInput extends ZodTypeAny = z.ZodVoid, - TOutput extends ZodTypeAny = z.ZodType, + TInput = void, + TOutput = string, TModelParams extends Record = Record > extends BaseTask { - protected _inputSchema: TInput | undefined - protected _outputSchema: TOutput | undefined + protected _inputSchema: ZodType | undefined + protected _outputSchema: ZodType | undefined protected _provider: string protected _model: string @@ -50,36 +50,33 @@ export abstract class BaseLLM< this._examples = options.examples } - input( - inputSchema: U - ): BaseLLM { - ;(this as unknown as BaseLLM)._inputSchema = - inputSchema - return this as unknown as BaseLLM + input(inputSchema: ZodType): BaseLLM { + const refinedInstance = this as unknown as BaseLLM + refinedInstance._inputSchema = inputSchema + return refinedInstance } - output( - outputSchema: U - ): BaseLLM { - ;(this as unknown as BaseLLM)._outputSchema = - outputSchema - return this as unknown as BaseLLM + output(outputSchema: ZodType): BaseLLM { + const refinedInstance = this as unknown as BaseLLM + refinedInstance._outputSchema = outputSchema + return refinedInstance } - public override get inputSchema(): TInput { + public override get inputSchema(): ZodType { if (this._inputSchema) { return this._inputSchema } else { - return z.void() as TInput + // TODO: improve typing + return z.void() as unknown as ZodType } } - public override get outputSchema(): TOutput { + public override get outputSchema(): ZodType { if (this._outputSchema) { return this._outputSchema } else { // TODO: improve typing - return z.string() as unknown as TOutput + return z.string() as unknown as ZodType } } @@ -125,8 +122,8 @@ export abstract class BaseLLM< } export abstract class BaseChatModel< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = z.ZodType, + TInput = void, + TOutput = string, TModelParams extends Record = Record, TChatCompletionResponse extends Record = Record > extends BaseLLM { @@ -148,8 +145,8 @@ export abstract class BaseChatModel< ): Promise> public async buildMessages( - input?: types.ParsedData, - ctx?: types.TaskCallContext + input?: TInput, + ctx?: types.TaskCallContext ) { if (this._inputSchema) { // TODO: handle errors gracefully @@ -222,7 +219,7 @@ export abstract class BaseChatModel< protected override async _call( ctx: types.TaskCallContext - ): Promise> { + ): Promise { const messages = await this.buildMessages(ctx.input, ctx) console.log('>>>') diff --git a/src/llms/openai.ts b/src/llms/openai.ts index 547a897c..84b45aae 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -1,5 +1,4 @@ import { type SetOptional } from 'type-fest' -import { ZodTypeAny, z } from 'zod' import * as types from '@/types' import { defaultOpenAIModel } from '@/constants' @@ -7,8 +6,8 @@ import { defaultOpenAIModel } from '@/constants' import { BaseChatModel } from './llm' export class OpenAIChatModel< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = z.ZodType + TInput = any, + TOutput = string > extends BaseChatModel< TInput, TOutput, diff --git a/src/task.ts b/src/task.ts index e00a99f6..8518fd84 100644 --- a/src/task.ts +++ b/src/task.ts @@ -1,5 +1,5 @@ import pRetry, { FailedAttemptError } from 'p-retry' -import { ZodTypeAny } from 'zod' +import { ZodType } from 'zod' import * as errors from '@/errors' import * as types from '@/types' @@ -18,10 +18,7 @@ import { Agentic } from '@/agentic' * - Native function calls * - Invoking sub-agents */ -export abstract class BaseTask< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = ZodTypeAny -> { +export abstract class BaseTask { protected _agentic: Agentic protected _id: string @@ -46,8 +43,8 @@ export abstract class BaseTask< return this._id } - public abstract get inputSchema(): TInput - public abstract get outputSchema(): TOutput + public abstract get inputSchema(): ZodType + public abstract get outputSchema(): ZodType public abstract get name(): string @@ -74,15 +71,13 @@ export abstract class BaseTask< return this } - public async call( - input?: types.ParsedData - ): Promise> { + public async call(input?: TInput): Promise { const res = await this.callWithMetadata(input) return res.result } public async callWithMetadata( - input?: types.ParsedData + input?: TInput ): Promise> { if (this.inputSchema) { const safeInput = this.inputSchema.safeParse(input) @@ -134,9 +129,7 @@ export abstract class BaseTask< } } - protected abstract _call( - ctx: types.TaskCallContext - ): Promise> + protected abstract _call(ctx: types.TaskCallContext): Promise // TODO // abstract stream({ diff --git a/src/tools/metaphor.ts b/src/tools/metaphor.ts index 06c2cf9e..6422c895 100644 --- a/src/tools/metaphor.ts +++ b/src/tools/metaphor.ts @@ -31,8 +31,8 @@ export type MetaphorSearchToolOutput = z.infer< > export class MetaphorSearchTool extends BaseTask< - typeof MetaphorSearchToolInputSchema, - typeof MetaphorSearchToolOutputSchema + MetaphorSearchToolInput, + MetaphorSearchToolOutput > { _metaphorClient: MetaphorClient @@ -65,7 +65,7 @@ export class MetaphorSearchTool extends BaseTask< } protected override async _call( - ctx: types.TaskCallContext + ctx: types.TaskCallContext ): Promise { // TODO: test required inputs return this._metaphorClient.search({ diff --git a/src/tools/novu.ts b/src/tools/novu.ts index 40f639a0..cd7aae5e 100644 --- a/src/tools/novu.ts +++ b/src/tools/novu.ts @@ -36,8 +36,8 @@ export type NovuNotificationToolOutput = z.infer< > export class NovuNotificationTool extends BaseTask< - typeof NovuNotificationToolInputSchema, - typeof NovuNotificationToolOutputSchema + NovuNotificationToolInput, + NovuNotificationToolOutput > { _novuClient: NovuClient @@ -68,7 +68,7 @@ export class NovuNotificationTool extends BaseTask< } protected override async _call( - ctx: types.TaskCallContext + ctx: types.TaskCallContext ): Promise { return this._novuClient.triggerEvent(ctx.input!) } diff --git a/src/types.ts b/src/types.ts index a95f1336..57c51bc5 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2,7 +2,7 @@ import * as anthropic from '@anthropic-ai/sdk' import * as openai from 'openai-fetch' import type { Options as RetryOptions } from 'p-retry' import type { JsonObject } from 'type-fest' -import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod' +import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import type { Agentic } from './agentic' @@ -31,12 +31,12 @@ export interface BaseTaskOptions { } export interface BaseLLMOptions< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = z.ZodType, + TInput = void, + TOutput = string, TModelParams extends Record = Record > extends BaseTaskOptions { - inputSchema?: TInput - outputSchema?: TOutput + inputSchema?: ZodType + outputSchema?: ZodType provider?: string model?: string @@ -45,8 +45,8 @@ export interface BaseLLMOptions< } export interface LLMOptions< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = z.ZodType, + TInput = void, + TOutput = string, TModelParams extends Record = Record > extends BaseLLMOptions { promptTemplate?: string @@ -69,8 +69,8 @@ export interface ChatMessage { } export interface ChatModelOptions< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = z.ZodType, + TInput = void, + TOutput = string, TModelParams extends Record = Record > extends BaseLLMOptions { messages: ChatMessage[] @@ -120,18 +120,18 @@ export interface LLMTaskResponseMetadata< } export interface TaskResponse< - TOutput extends ZodTypeAny = z.ZodType, + TOutput = string, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { - result: ParsedData + result: TOutput metadata: TMetadata } export interface TaskCallContext< - TInput extends ZodTypeAny = ZodTypeAny, + TInput = void, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { - input?: ParsedData + input?: TInput retryMessage?: string attemptNumber: number