feat: minor improvements to types and task

Travis Fischer 2023-06-11 00:01:14 -07:00
rodzic cff74009c5
commit 113fdebee3
8 zmienionych plików z 93 dodań i 71 usunięć

Wyświetl plik

@ -4,8 +4,8 @@ export type Metadata = Record<string, unknown>;
export abstract class BaseTask< export abstract class BaseTask<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny TOutput extends ZodTypeAny = ZodTypeAny
> { > {
// ... // ...

Wyświetl plik

@ -1,4 +1,4 @@
import { ZodRawShape, ZodTypeAny } from 'zod' import { ZodTypeAny } from 'zod'
import { Agentic } from '@/agentic' import { Agentic } from '@/agentic'
import { BaseTask } from '@/task' import { BaseTask } from '@/task'
@ -37,8 +37,8 @@ export class HumanFeedbackMechanismCLI extends HumanFeedbackMechanism {
} }
export function withHumanFeedback< export function withHumanFeedback<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny TOutput extends ZodTypeAny = ZodTypeAny
>( >(
task: BaseTask<TInput, TOutput>, task: BaseTask<TInput, TOutput>,
options: HumanFeedbackOptions = { options: HumanFeedbackOptions = {

Wyświetl plik

@ -86,4 +86,19 @@ export class AnthropicChatModel<
response response
} }
} }
public override clone(): AnthropicChatModel<TInput, TOutput> {
return new AnthropicChatModel<TInput, TOutput>({
agentic: this._agentic,
timeoutMs: this._timeoutMs,
retryConfig: this._retryConfig,
inputSchema: this._inputSchema,
outputSchema: this._outputSchema,
provider: this._provider,
model: this._model,
examples: this._examples,
messages: this._messages,
...this._modelParams
})
}
} }

Wyświetl plik

@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
import pMap from 'p-map' import pMap from 'p-map'
import { dedent } from 'ts-dedent' import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest' import { type SetRequired } from 'type-fest'
import { ZodRawShape, ZodTypeAny, z } from 'zod' import { ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts' import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors' import * as errors from '@/errors'
@ -20,8 +20,8 @@ import {
} from '@/utils' } from '@/utils'
export abstract class BaseLLM< export abstract class BaseLLM<
TInput extends ZodRawShape | ZodTypeAny = z.ZodVoid, TInput extends ZodTypeAny = z.ZodVoid,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTask<TInput, TOutput> { > extends BaseTask<TInput, TOutput> {
protected _inputSchema: TInput | undefined protected _inputSchema: TInput | undefined
@ -50,7 +50,7 @@ export abstract class BaseLLM<
this._examples = options.examples this._examples = options.examples
} }
input<U extends ZodRawShape | ZodTypeAny = TInput>( input<U extends ZodTypeAny = TInput>(
inputSchema: U inputSchema: U
): BaseLLM<U, TOutput, TModelParams> { ): BaseLLM<U, TOutput, TModelParams> {
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema = ;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema =
@ -58,7 +58,7 @@ export abstract class BaseLLM<
return this as unknown as BaseLLM<U, TOutput, TModelParams> return this as unknown as BaseLLM<U, TOutput, TModelParams>
} }
output<U extends ZodRawShape | ZodTypeAny = TOutput>( output<U extends ZodTypeAny = TOutput>(
outputSchema: U outputSchema: U
): BaseLLM<TInput, U, TModelParams> { ): BaseLLM<TInput, U, TModelParams> {
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema = ;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema =
@ -125,8 +125,8 @@ export abstract class BaseLLM<
} }
export abstract class BaseChatModel< export abstract class BaseChatModel<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>, TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any> TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> { > extends BaseLLM<TInput, TOutput, TModelParams> {
@ -152,13 +152,8 @@ export abstract class BaseChatModel<
ctx?: types.TaskCallContext ctx?: types.TaskCallContext
) { ) {
if (this._inputSchema) { if (this._inputSchema) {
const inputSchema =
this._inputSchema instanceof z.ZodType
? this._inputSchema
: z.object(this._inputSchema)
// TODO: handle errors gracefully // TODO: handle errors gracefully
input = inputSchema.parse(input) input = this.inputSchema.parse(input)
} }
// TODO: validate input message variables against input schema // TODO: validate input message variables against input schema
@ -185,12 +180,7 @@ export abstract class BaseChatModel<
} }
if (this._outputSchema) { if (this._outputSchema) {
const outputSchema = const { node } = zodToTs(this._outputSchema)
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
const { node } = zodToTs(outputSchema)
if (node.kind === 152) { if (node.kind === 152) {
// handle raw strings differently // handle raw strings differently
@ -248,10 +238,7 @@ export abstract class BaseChatModel<
console.log('<<<') console.log('<<<')
if (this._outputSchema) { if (this._outputSchema) {
const outputSchema = const outputSchema = this._outputSchema
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
if (outputSchema instanceof z.ZodArray) { if (outputSchema instanceof z.ZodArray) {
try { try {

Wyświetl plik

@ -50,4 +50,19 @@ export class OpenAIChatModel<
messages messages
}) })
} }
public override clone(): OpenAIChatModel<TInput, TOutput> {
return new OpenAIChatModel<TInput, TOutput>({
agentic: this._agentic,
timeoutMs: this._timeoutMs,
retryConfig: this._retryConfig,
inputSchema: this._inputSchema,
outputSchema: this._outputSchema,
provider: this._provider,
model: this._model,
examples: this._examples,
messages: this._messages,
...this._modelParams
})
}
} }

Wyświetl plik

@ -1,5 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry' import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodRawShape, ZodTypeAny, z } from 'zod' import { ZodTypeAny } from 'zod'
import * as errors from '@/errors' import * as errors from '@/errors'
import * as types from '@/types' import * as types from '@/types'
@ -19,8 +19,8 @@ import { Agentic } from '@/agentic'
* - Invoking sub-agents * - Invoking sub-agents
*/ */
export abstract class BaseTask< export abstract class BaseTask<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny TOutput extends ZodTypeAny = ZodTypeAny
> { > {
protected _agentic: Agentic protected _agentic: Agentic
protected _id: string protected _id: string
@ -51,6 +51,24 @@ export abstract class BaseTask<
public abstract get name(): string public abstract get name(): string
public serialize(): types.SerializedTask {
return {
_taskName: this.name
// inputSchema: this.inputSchema.serialize()
}
}
// public abstract deserialize<
// TInput extends ZodTypeAny = ZodTypeAny,
// TOutput extends ZodTypeAny = ZodTypeAny
// >(data: types.SerializedTask): BaseTask<TInput, TOutput>
// TODO: is this really necessary?
public clone(): BaseTask<TInput, TOutput> {
// TODO: override in subclass if needed
throw new Error(`clone not implemented for task "${this.name}"`)
}
public retryConfig(retryConfig: types.RetryConfig) { public retryConfig(retryConfig: types.RetryConfig) {
this._retryConfig = retryConfig this._retryConfig = retryConfig
return this return this
@ -67,12 +85,7 @@ export abstract class BaseTask<
input?: types.ParsedData<TInput> input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
if (this.inputSchema) { if (this.inputSchema) {
const inputSchema = const safeInput = this.inputSchema.safeParse(input)
this.inputSchema instanceof z.ZodType
? this.inputSchema
: z.object(this.inputSchema)
const safeInput = inputSchema.safeParse(input)
if (!safeInput.success) { if (!safeInput.success) {
throw new Error(`Invalid input: ${safeInput.error.message}`) throw new Error(`Invalid input: ${safeInput.error.message}`)

Wyświetl plik

@ -38,13 +38,15 @@ export class MetaphorSearchTool extends BaseTask<
constructor({ constructor({
agentic, agentic,
metaphorClient = new MetaphorClient() metaphorClient = new MetaphorClient(),
...rest
}: { }: {
agentic: Agentic agentic: Agentic
metaphorClient?: MetaphorClient metaphorClient?: MetaphorClient
}) { } & types.BaseTaskOptions) {
super({ super({
agentic agentic,
...rest
}) })
this._metaphorClient = metaphorClient this._metaphorClient = metaphorClient
@ -66,11 +68,9 @@ export class MetaphorSearchTool extends BaseTask<
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema> ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
): Promise<MetaphorSearchToolOutput> { ): Promise<MetaphorSearchToolOutput> {
// TODO: test required inputs // TODO: test required inputs
const result = await this._metaphorClient.search({ return this._metaphorClient.search({
query: ctx.input!.query, query: ctx.input!.query,
numResults: ctx.input!.numResults numResults: ctx.input!.numResults
}) })
return result
} }
} }

Wyświetl plik

@ -1,33 +1,21 @@
import * as anthropic from '@anthropic-ai/sdk' import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch' import * as openai from 'openai-fetch'
import type { Options as RetryOptions } from 'p-retry' import type { Options as RetryOptions } from 'p-retry'
import { import type { JsonObject } from 'type-fest'
SafeParseReturnType, import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod'
ZodObject,
ZodRawShape,
ZodTypeAny,
output,
z
} from 'zod'
import type { Agentic } from './agentic' import type { Agentic } from './agentic'
export { openai } export { openai }
export { anthropic } export { anthropic }
export type ParsedData<T extends ZodRawShape | ZodTypeAny> = export type ParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
T extends ZodTypeAny ? output<T>
? output<T> : never
: T extends ZodRawShape
? output<ZodObject<T>>
: never
export type SafeParsedData<T extends ZodRawShape | ZodTypeAny> = export type SafeParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
T extends ZodTypeAny ? SafeParseReturnType<z.infer<T>, ParsedData<T>>
? SafeParseReturnType<z.infer<T>, ParsedData<T>> : never
: T extends ZodRawShape
? SafeParseReturnType<ZodObject<T>, ParsedData<T>>
: never
export interface BaseTaskOptions { export interface BaseTaskOptions {
agentic: Agentic agentic: Agentic
@ -43,8 +31,8 @@ export interface BaseTaskOptions {
} }
export interface BaseLLMOptions< export interface BaseLLMOptions<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskOptions { > extends BaseTaskOptions {
inputSchema?: TInput inputSchema?: TInput
@ -57,8 +45,8 @@ export interface BaseLLMOptions<
} }
export interface LLMOptions< export interface LLMOptions<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
promptTemplate?: string promptTemplate?: string
@ -81,8 +69,8 @@ export interface ChatMessage {
} }
export interface ChatModelOptions< export interface ChatModelOptions<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[] messages: ChatMessage[]
@ -132,7 +120,7 @@ export interface LLMTaskResponseMetadata<
} }
export interface TaskResponse< export interface TaskResponse<
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodTypeAny = z.ZodType<string>,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
result: ParsedData<TOutput> result: ParsedData<TOutput>
@ -140,7 +128,7 @@ export interface TaskResponse<
} }
export interface TaskCallContext< export interface TaskCallContext<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodTypeAny = ZodTypeAny,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
input?: ParsedData<TInput> input?: ParsedData<TInput>
@ -152,4 +140,8 @@ export interface TaskCallContext<
export type IDGeneratorFunction = () => string export type IDGeneratorFunction = () => string
export interface SerializedTask extends JsonObject {
_taskName: string
}
// export type ProgressFunction = (partialResponse: ChatMessage) => void // export type ProgressFunction = (partialResponse: ChatMessage) => void