feat: simplify Task input/output typing

Travis Fischer 2023-06-11 00:46:38 -07:00
rodzic 113fdebee3
commit 9fa5f26215
7 zmienionych plików z 53 dodań i 65 usunięć

Wyświetl plik

@ -1,6 +1,5 @@
import * as anthropic from '@anthropic-ai/sdk'
import { type SetOptional } from 'type-fest'
import { ZodTypeAny, z } from 'zod'
import * as types from '@/types'
import { defaultAnthropicModel } from '@/constants'
@ -10,8 +9,8 @@ import { BaseChatModel } from './llm'
const defaultStopSequences = [anthropic.HUMAN_PROMPT]
export class AnthropicChatModel<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>
TInput = any,
TOutput = string
> extends BaseChatModel<
TInput,
TOutput,

Wyświetl plik

@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
import pMap from 'p-map'
import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest'
import { ZodTypeAny, z } from 'zod'
import { ZodType, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors'
@ -20,12 +20,12 @@ import {
} from '@/utils'
export abstract class BaseLLM<
TInput extends ZodTypeAny = z.ZodVoid,
TOutput extends ZodTypeAny = z.ZodType<string>,
TInput = void,
TOutput = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTask<TInput, TOutput> {
protected _inputSchema: TInput | undefined
protected _outputSchema: TOutput | undefined
protected _inputSchema: ZodType<TInput> | undefined
protected _outputSchema: ZodType<TOutput> | undefined
protected _provider: string
protected _model: string
@ -50,36 +50,33 @@ export abstract class BaseLLM<
this._examples = options.examples
}
input<U extends ZodTypeAny = TInput>(
inputSchema: U
): BaseLLM<U, TOutput, TModelParams> {
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema =
inputSchema
return this as unknown as BaseLLM<U, TOutput, TModelParams>
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
refinedInstance._inputSchema = inputSchema
return refinedInstance
}
output<U extends ZodTypeAny = TOutput>(
outputSchema: U
): BaseLLM<TInput, U, TModelParams> {
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema =
outputSchema
return this as unknown as BaseLLM<TInput, U, TModelParams>
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
refinedInstance._outputSchema = outputSchema
return refinedInstance
}
public override get inputSchema(): TInput {
public override get inputSchema(): ZodType<TInput> {
if (this._inputSchema) {
return this._inputSchema
} else {
return z.void() as TInput
// TODO: improve typing
return z.void() as unknown as ZodType<TInput>
}
}
public override get outputSchema(): TOutput {
public override get outputSchema(): ZodType<TOutput> {
if (this._outputSchema) {
return this._outputSchema
} else {
// TODO: improve typing
return z.string() as unknown as TOutput
return z.string() as unknown as ZodType<TOutput>
}
}
@ -125,8 +122,8 @@ export abstract class BaseLLM<
}
export abstract class BaseChatModel<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TInput = void,
TOutput = string,
TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> {
@ -148,8 +145,8 @@ export abstract class BaseChatModel<
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages(
input?: types.ParsedData<TInput>,
ctx?: types.TaskCallContext
input?: TInput,
ctx?: types.TaskCallContext<TInput>
) {
if (this._inputSchema) {
// TODO: handle errors gracefully
@ -222,7 +219,7 @@ export abstract class BaseChatModel<
protected override async _call(
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
): Promise<types.ParsedData<TOutput>> {
): Promise<TOutput> {
const messages = await this.buildMessages(ctx.input, ctx)
console.log('>>>')

Wyświetl plik

@ -1,5 +1,4 @@
import { type SetOptional } from 'type-fest'
import { ZodTypeAny, z } from 'zod'
import * as types from '@/types'
import { defaultOpenAIModel } from '@/constants'
@ -7,8 +6,8 @@ import { defaultOpenAIModel } from '@/constants'
import { BaseChatModel } from './llm'
export class OpenAIChatModel<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>
TInput = any,
TOutput = string
> extends BaseChatModel<
TInput,
TOutput,

Wyświetl plik

@ -1,5 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodTypeAny } from 'zod'
import { ZodType } from 'zod'
import * as errors from '@/errors'
import * as types from '@/types'
@ -18,10 +18,7 @@ import { Agentic } from '@/agentic'
* - Native function calls
* - Invoking sub-agents
*/
export abstract class BaseTask<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
> {
export abstract class BaseTask<TInput = void, TOutput = string> {
protected _agentic: Agentic
protected _id: string
@ -46,8 +43,8 @@ export abstract class BaseTask<
return this._id
}
public abstract get inputSchema(): TInput
public abstract get outputSchema(): TOutput
public abstract get inputSchema(): ZodType<TInput>
public abstract get outputSchema(): ZodType<TOutput>
public abstract get name(): string
@ -74,15 +71,13 @@ export abstract class BaseTask<
return this
}
public async call(
input?: types.ParsedData<TInput>
): Promise<types.ParsedData<TOutput>> {
public async call(input?: TInput): Promise<TOutput> {
const res = await this.callWithMetadata(input)
return res.result
}
public async callWithMetadata(
input?: types.ParsedData<TInput>
input?: TInput
): Promise<types.TaskResponse<TOutput>> {
if (this.inputSchema) {
const safeInput = this.inputSchema.safeParse(input)
@ -134,9 +129,7 @@ export abstract class BaseTask<
}
}
protected abstract _call(
ctx: types.TaskCallContext<TInput>
): Promise<types.ParsedData<TOutput>>
protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>
// TODO
// abstract stream({

Wyświetl plik

@ -31,8 +31,8 @@ export type MetaphorSearchToolOutput = z.infer<
>
export class MetaphorSearchTool extends BaseTask<
typeof MetaphorSearchToolInputSchema,
typeof MetaphorSearchToolOutputSchema
MetaphorSearchToolInput,
MetaphorSearchToolOutput
> {
_metaphorClient: MetaphorClient
@ -65,7 +65,7 @@ export class MetaphorSearchTool extends BaseTask<
}
protected override async _call(
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
ctx: types.TaskCallContext<MetaphorSearchToolInput>
): Promise<MetaphorSearchToolOutput> {
// TODO: test required inputs
return this._metaphorClient.search({

Wyświetl plik

@ -36,8 +36,8 @@ export type NovuNotificationToolOutput = z.infer<
>
export class NovuNotificationTool extends BaseTask<
typeof NovuNotificationToolInputSchema,
typeof NovuNotificationToolOutputSchema
NovuNotificationToolInput,
NovuNotificationToolOutput
> {
_novuClient: NovuClient
@ -68,7 +68,7 @@ export class NovuNotificationTool extends BaseTask<
}
protected override async _call(
ctx: types.TaskCallContext<typeof NovuNotificationToolInputSchema>
ctx: types.TaskCallContext<NovuNotificationToolInput>
): Promise<NovuNotificationToolOutput> {
return this._novuClient.triggerEvent(ctx.input!)
}

Wyświetl plik

@ -2,7 +2,7 @@ import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch'
import type { Options as RetryOptions } from 'p-retry'
import type { JsonObject } from 'type-fest'
import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic'
@ -31,12 +31,12 @@ export interface BaseTaskOptions {
}
export interface BaseLLMOptions<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TInput = void,
TOutput = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskOptions {
inputSchema?: TInput
outputSchema?: TOutput
inputSchema?: ZodType<TInput>
outputSchema?: ZodType<TOutput>
provider?: string
model?: string
@ -45,8 +45,8 @@ export interface BaseLLMOptions<
}
export interface LLMOptions<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TInput = void,
TOutput = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
promptTemplate?: string
@ -69,8 +69,8 @@ export interface ChatMessage {
}
export interface ChatModelOptions<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TInput = void,
TOutput = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[]
@ -120,18 +120,18 @@ export interface LLMTaskResponseMetadata<
}
export interface TaskResponse<
TOutput extends ZodTypeAny = z.ZodType<string>,
TOutput = string,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
result: ParsedData<TOutput>
result: TOutput
metadata: TMetadata
}
export interface TaskCallContext<
TInput extends ZodTypeAny = ZodTypeAny,
TInput = void,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
input?: ParsedData<TInput>
input?: TInput
retryMessage?: string
attemptNumber: number