feat: simplify Task input/output typing

Travis Fischer 2023-06-11 00:46:38 -07:00
rodzic 113fdebee3
commit 9fa5f26215
7 zmienionych plików z 53 dodań i 65 usunięć

Wyświetl plik

@ -1,6 +1,5 @@
import * as anthropic from '@anthropic-ai/sdk' import * as anthropic from '@anthropic-ai/sdk'
import { type SetOptional } from 'type-fest' import { type SetOptional } from 'type-fest'
import { ZodTypeAny, z } from 'zod'
import * as types from '@/types' import * as types from '@/types'
import { defaultAnthropicModel } from '@/constants' import { defaultAnthropicModel } from '@/constants'
@ -10,8 +9,8 @@ import { BaseChatModel } from './llm'
const defaultStopSequences = [anthropic.HUMAN_PROMPT] const defaultStopSequences = [anthropic.HUMAN_PROMPT]
export class AnthropicChatModel< export class AnthropicChatModel<
TInput extends ZodTypeAny = ZodTypeAny, TInput = any,
TOutput extends ZodTypeAny = z.ZodType<string> TOutput = string
> extends BaseChatModel< > extends BaseChatModel<
TInput, TInput,
TOutput, TOutput,

Wyświetl plik

@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
import pMap from 'p-map' import pMap from 'p-map'
import { dedent } from 'ts-dedent' import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest' import { type SetRequired } from 'type-fest'
import { ZodTypeAny, z } from 'zod' import { ZodType, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts' import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors' import * as errors from '@/errors'
@ -20,12 +20,12 @@ import {
} from '@/utils' } from '@/utils'
export abstract class BaseLLM< export abstract class BaseLLM<
TInput extends ZodTypeAny = z.ZodVoid, TInput = void,
TOutput extends ZodTypeAny = z.ZodType<string>, TOutput = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTask<TInput, TOutput> { > extends BaseTask<TInput, TOutput> {
protected _inputSchema: TInput | undefined protected _inputSchema: ZodType<TInput> | undefined
protected _outputSchema: TOutput | undefined protected _outputSchema: ZodType<TOutput> | undefined
protected _provider: string protected _provider: string
protected _model: string protected _model: string
@ -50,36 +50,33 @@ export abstract class BaseLLM<
this._examples = options.examples this._examples = options.examples
} }
input<U extends ZodTypeAny = TInput>( input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
inputSchema: U const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
): BaseLLM<U, TOutput, TModelParams> { refinedInstance._inputSchema = inputSchema
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema = return refinedInstance
inputSchema
return this as unknown as BaseLLM<U, TOutput, TModelParams>
} }
output<U extends ZodTypeAny = TOutput>( output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
outputSchema: U const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
): BaseLLM<TInput, U, TModelParams> { refinedInstance._outputSchema = outputSchema
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema = return refinedInstance
outputSchema
return this as unknown as BaseLLM<TInput, U, TModelParams>
} }
public override get inputSchema(): TInput { public override get inputSchema(): ZodType<TInput> {
if (this._inputSchema) { if (this._inputSchema) {
return this._inputSchema return this._inputSchema
} else { } else {
return z.void() as TInput // TODO: improve typing
return z.void() as unknown as ZodType<TInput>
} }
} }
public override get outputSchema(): TOutput { public override get outputSchema(): ZodType<TOutput> {
if (this._outputSchema) { if (this._outputSchema) {
return this._outputSchema return this._outputSchema
} else { } else {
// TODO: improve typing // TODO: improve typing
return z.string() as unknown as TOutput return z.string() as unknown as ZodType<TOutput>
} }
} }
@ -125,8 +122,8 @@ export abstract class BaseLLM<
} }
export abstract class BaseChatModel< export abstract class BaseChatModel<
TInput extends ZodTypeAny = ZodTypeAny, TInput = void,
TOutput extends ZodTypeAny = z.ZodType<string>, TOutput = string,
TModelParams extends Record<string, any> = Record<string, any>, TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any> TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> { > extends BaseLLM<TInput, TOutput, TModelParams> {
@ -148,8 +145,8 @@ export abstract class BaseChatModel<
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>> ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages( public async buildMessages(
input?: types.ParsedData<TInput>, input?: TInput,
ctx?: types.TaskCallContext ctx?: types.TaskCallContext<TInput>
) { ) {
if (this._inputSchema) { if (this._inputSchema) {
// TODO: handle errors gracefully // TODO: handle errors gracefully
@ -222,7 +219,7 @@ export abstract class BaseChatModel<
protected override async _call( protected override async _call(
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata> ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
): Promise<types.ParsedData<TOutput>> { ): Promise<TOutput> {
const messages = await this.buildMessages(ctx.input, ctx) const messages = await this.buildMessages(ctx.input, ctx)
console.log('>>>') console.log('>>>')

Wyświetl plik

@ -1,5 +1,4 @@
import { type SetOptional } from 'type-fest' import { type SetOptional } from 'type-fest'
import { ZodTypeAny, z } from 'zod'
import * as types from '@/types' import * as types from '@/types'
import { defaultOpenAIModel } from '@/constants' import { defaultOpenAIModel } from '@/constants'
@ -7,8 +6,8 @@ import { defaultOpenAIModel } from '@/constants'
import { BaseChatModel } from './llm' import { BaseChatModel } from './llm'
export class OpenAIChatModel< export class OpenAIChatModel<
TInput extends ZodTypeAny = ZodTypeAny, TInput = any,
TOutput extends ZodTypeAny = z.ZodType<string> TOutput = string
> extends BaseChatModel< > extends BaseChatModel<
TInput, TInput,
TOutput, TOutput,

Wyświetl plik

@ -1,5 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry' import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodTypeAny } from 'zod' import { ZodType } from 'zod'
import * as errors from '@/errors' import * as errors from '@/errors'
import * as types from '@/types' import * as types from '@/types'
@ -18,10 +18,7 @@ import { Agentic } from '@/agentic'
* - Native function calls * - Native function calls
* - Invoking sub-agents * - Invoking sub-agents
*/ */
export abstract class BaseTask< export abstract class BaseTask<TInput = void, TOutput = string> {
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
> {
protected _agentic: Agentic protected _agentic: Agentic
protected _id: string protected _id: string
@ -46,8 +43,8 @@ export abstract class BaseTask<
return this._id return this._id
} }
public abstract get inputSchema(): TInput public abstract get inputSchema(): ZodType<TInput>
public abstract get outputSchema(): TOutput public abstract get outputSchema(): ZodType<TOutput>
public abstract get name(): string public abstract get name(): string
@ -74,15 +71,13 @@ export abstract class BaseTask<
return this return this
} }
public async call( public async call(input?: TInput): Promise<TOutput> {
input?: types.ParsedData<TInput>
): Promise<types.ParsedData<TOutput>> {
const res = await this.callWithMetadata(input) const res = await this.callWithMetadata(input)
return res.result return res.result
} }
public async callWithMetadata( public async callWithMetadata(
input?: types.ParsedData<TInput> input?: TInput
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
if (this.inputSchema) { if (this.inputSchema) {
const safeInput = this.inputSchema.safeParse(input) const safeInput = this.inputSchema.safeParse(input)
@ -134,9 +129,7 @@ export abstract class BaseTask<
} }
} }
protected abstract _call( protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>
ctx: types.TaskCallContext<TInput>
): Promise<types.ParsedData<TOutput>>
// TODO // TODO
// abstract stream({ // abstract stream({

Wyświetl plik

@ -31,8 +31,8 @@ export type MetaphorSearchToolOutput = z.infer<
> >
export class MetaphorSearchTool extends BaseTask< export class MetaphorSearchTool extends BaseTask<
typeof MetaphorSearchToolInputSchema, MetaphorSearchToolInput,
typeof MetaphorSearchToolOutputSchema MetaphorSearchToolOutput
> { > {
_metaphorClient: MetaphorClient _metaphorClient: MetaphorClient
@ -65,7 +65,7 @@ export class MetaphorSearchTool extends BaseTask<
} }
protected override async _call( protected override async _call(
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema> ctx: types.TaskCallContext<MetaphorSearchToolInput>
): Promise<MetaphorSearchToolOutput> { ): Promise<MetaphorSearchToolOutput> {
// TODO: test required inputs // TODO: test required inputs
return this._metaphorClient.search({ return this._metaphorClient.search({

Wyświetl plik

@ -36,8 +36,8 @@ export type NovuNotificationToolOutput = z.infer<
> >
export class NovuNotificationTool extends BaseTask< export class NovuNotificationTool extends BaseTask<
typeof NovuNotificationToolInputSchema, NovuNotificationToolInput,
typeof NovuNotificationToolOutputSchema NovuNotificationToolOutput
> { > {
_novuClient: NovuClient _novuClient: NovuClient
@ -68,7 +68,7 @@ export class NovuNotificationTool extends BaseTask<
} }
protected override async _call( protected override async _call(
ctx: types.TaskCallContext<typeof NovuNotificationToolInputSchema> ctx: types.TaskCallContext<NovuNotificationToolInput>
): Promise<NovuNotificationToolOutput> { ): Promise<NovuNotificationToolOutput> {
return this._novuClient.triggerEvent(ctx.input!) return this._novuClient.triggerEvent(ctx.input!)
} }

Wyświetl plik

@ -2,7 +2,7 @@ import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch' import * as openai from 'openai-fetch'
import type { Options as RetryOptions } from 'p-retry' import type { Options as RetryOptions } from 'p-retry'
import type { JsonObject } from 'type-fest' import type { JsonObject } from 'type-fest'
import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic' import type { Agentic } from './agentic'
@ -31,12 +31,12 @@ export interface BaseTaskOptions {
} }
export interface BaseLLMOptions< export interface BaseLLMOptions<
TInput extends ZodTypeAny = ZodTypeAny, TInput = void,
TOutput extends ZodTypeAny = z.ZodType<string>, TOutput = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskOptions { > extends BaseTaskOptions {
inputSchema?: TInput inputSchema?: ZodType<TInput>
outputSchema?: TOutput outputSchema?: ZodType<TOutput>
provider?: string provider?: string
model?: string model?: string
@ -45,8 +45,8 @@ export interface BaseLLMOptions<
} }
export interface LLMOptions< export interface LLMOptions<
TInput extends ZodTypeAny = ZodTypeAny, TInput = void,
TOutput extends ZodTypeAny = z.ZodType<string>, TOutput = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
promptTemplate?: string promptTemplate?: string
@ -69,8 +69,8 @@ export interface ChatMessage {
} }
export interface ChatModelOptions< export interface ChatModelOptions<
TInput extends ZodTypeAny = ZodTypeAny, TInput = void,
TOutput extends ZodTypeAny = z.ZodType<string>, TOutput = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[] messages: ChatMessage[]
@ -120,18 +120,18 @@ export interface LLMTaskResponseMetadata<
} }
export interface TaskResponse< export interface TaskResponse<
TOutput extends ZodTypeAny = z.ZodType<string>, TOutput = string,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
result: ParsedData<TOutput> result: TOutput
metadata: TMetadata metadata: TMetadata
} }
export interface TaskCallContext< export interface TaskCallContext<
TInput extends ZodTypeAny = ZodTypeAny, TInput = void,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
input?: ParsedData<TInput> input?: TInput
retryMessage?: string retryMessage?: string
attemptNumber: number attemptNumber: number