feat: minor improvements to types and task

Travis Fischer 2023-06-11 00:01:14 -07:00
rodzic cff74009c5
commit 113fdebee3
8 zmienionych plików z 93 dodań i 71 usunięć

Wyświetl plik

@ -4,8 +4,8 @@ export type Metadata = Record<string, unknown>;
export abstract class BaseTask<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
> {
// ...

Wyświetl plik

@ -1,4 +1,4 @@
import { ZodRawShape, ZodTypeAny } from 'zod'
import { ZodTypeAny } from 'zod'
import { Agentic } from '@/agentic'
import { BaseTask } from '@/task'
@ -37,8 +37,8 @@ export class HumanFeedbackMechanismCLI extends HumanFeedbackMechanism {
}
export function withHumanFeedback<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
>(
task: BaseTask<TInput, TOutput>,
options: HumanFeedbackOptions = {

Wyświetl plik

@ -86,4 +86,19 @@ export class AnthropicChatModel<
response
}
}
public override clone(): AnthropicChatModel<TInput, TOutput> {
return new AnthropicChatModel<TInput, TOutput>({
agentic: this._agentic,
timeoutMs: this._timeoutMs,
retryConfig: this._retryConfig,
inputSchema: this._inputSchema,
outputSchema: this._outputSchema,
provider: this._provider,
model: this._model,
examples: this._examples,
messages: this._messages,
...this._modelParams
})
}
}

Wyświetl plik

@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
import pMap from 'p-map'
import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest'
import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors'
@ -20,8 +20,8 @@ import {
} from '@/utils'
export abstract class BaseLLM<
TInput extends ZodRawShape | ZodTypeAny = z.ZodVoid,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TInput extends ZodTypeAny = z.ZodVoid,
TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTask<TInput, TOutput> {
protected _inputSchema: TInput | undefined
@ -50,7 +50,7 @@ export abstract class BaseLLM<
this._examples = options.examples
}
input<U extends ZodRawShape | ZodTypeAny = TInput>(
input<U extends ZodTypeAny = TInput>(
inputSchema: U
): BaseLLM<U, TOutput, TModelParams> {
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema =
@ -58,7 +58,7 @@ export abstract class BaseLLM<
return this as unknown as BaseLLM<U, TOutput, TModelParams>
}
output<U extends ZodRawShape | ZodTypeAny = TOutput>(
output<U extends ZodTypeAny = TOutput>(
outputSchema: U
): BaseLLM<TInput, U, TModelParams> {
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema =
@ -125,8 +125,8 @@ export abstract class BaseLLM<
}
export abstract class BaseChatModel<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> {
@ -152,13 +152,8 @@ export abstract class BaseChatModel<
ctx?: types.TaskCallContext
) {
if (this._inputSchema) {
const inputSchema =
this._inputSchema instanceof z.ZodType
? this._inputSchema
: z.object(this._inputSchema)
// TODO: handle errors gracefully
input = inputSchema.parse(input)
input = this.inputSchema.parse(input)
}
// TODO: validate input message variables against input schema
@ -185,12 +180,7 @@ export abstract class BaseChatModel<
}
if (this._outputSchema) {
const outputSchema =
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
const { node } = zodToTs(outputSchema)
const { node } = zodToTs(this._outputSchema)
if (node.kind === 152) {
// handle raw strings differently
@ -248,10 +238,7 @@ export abstract class BaseChatModel<
console.log('<<<')
if (this._outputSchema) {
const outputSchema =
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
const outputSchema = this._outputSchema
if (outputSchema instanceof z.ZodArray) {
try {

Wyświetl plik

@ -50,4 +50,19 @@ export class OpenAIChatModel<
messages
})
}
public override clone(): OpenAIChatModel<TInput, TOutput> {
return new OpenAIChatModel<TInput, TOutput>({
agentic: this._agentic,
timeoutMs: this._timeoutMs,
retryConfig: this._retryConfig,
inputSchema: this._inputSchema,
outputSchema: this._outputSchema,
provider: this._provider,
model: this._model,
examples: this._examples,
messages: this._messages,
...this._modelParams
})
}
}

Wyświetl plik

@ -1,5 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { ZodTypeAny } from 'zod'
import * as errors from '@/errors'
import * as types from '@/types'
@ -19,8 +19,8 @@ import { Agentic } from '@/agentic'
* - Invoking sub-agents
*/
export abstract class BaseTask<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
> {
protected _agentic: Agentic
protected _id: string
@ -51,6 +51,24 @@ export abstract class BaseTask<
public abstract get name(): string
public serialize(): types.SerializedTask {
return {
_taskName: this.name
// inputSchema: this.inputSchema.serialize()
}
}
// public abstract deserialize<
// TInput extends ZodTypeAny = ZodTypeAny,
// TOutput extends ZodTypeAny = ZodTypeAny
// >(data: types.SerializedTask): BaseTask<TInput, TOutput>
// TODO: is this really necessary?
public clone(): BaseTask<TInput, TOutput> {
// TODO: override in subclass if needed
throw new Error(`clone not implemented for task "${this.name}"`)
}
public retryConfig(retryConfig: types.RetryConfig) {
this._retryConfig = retryConfig
return this
@ -67,12 +85,7 @@ export abstract class BaseTask<
input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> {
if (this.inputSchema) {
const inputSchema =
this.inputSchema instanceof z.ZodType
? this.inputSchema
: z.object(this.inputSchema)
const safeInput = inputSchema.safeParse(input)
const safeInput = this.inputSchema.safeParse(input)
if (!safeInput.success) {
throw new Error(`Invalid input: ${safeInput.error.message}`)

Wyświetl plik

@ -38,13 +38,15 @@ export class MetaphorSearchTool extends BaseTask<
constructor({
agentic,
metaphorClient = new MetaphorClient()
metaphorClient = new MetaphorClient(),
...rest
}: {
agentic: Agentic
metaphorClient?: MetaphorClient
}) {
} & types.BaseTaskOptions) {
super({
agentic
agentic,
...rest
})
this._metaphorClient = metaphorClient
@ -66,11 +68,9 @@ export class MetaphorSearchTool extends BaseTask<
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
): Promise<MetaphorSearchToolOutput> {
// TODO: test required inputs
const result = await this._metaphorClient.search({
return this._metaphorClient.search({
query: ctx.input!.query,
numResults: ctx.input!.numResults
})
return result
}
}

Wyświetl plik

@ -1,33 +1,21 @@
import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch'
import type { Options as RetryOptions } from 'p-retry'
import {
SafeParseReturnType,
ZodObject,
ZodRawShape,
ZodTypeAny,
output,
z
} from 'zod'
import type { JsonObject } from 'type-fest'
import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic'
export { openai }
export { anthropic }
export type ParsedData<T extends ZodRawShape | ZodTypeAny> =
T extends ZodTypeAny
? output<T>
: T extends ZodRawShape
? output<ZodObject<T>>
: never
export type ParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
? output<T>
: never
export type SafeParsedData<T extends ZodRawShape | ZodTypeAny> =
T extends ZodTypeAny
? SafeParseReturnType<z.infer<T>, ParsedData<T>>
: T extends ZodRawShape
? SafeParseReturnType<ZodObject<T>, ParsedData<T>>
: never
export type SafeParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
? SafeParseReturnType<z.infer<T>, ParsedData<T>>
: never
export interface BaseTaskOptions {
agentic: Agentic
@ -43,8 +31,8 @@ export interface BaseTaskOptions {
}
export interface BaseLLMOptions<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskOptions {
inputSchema?: TInput
@ -57,8 +45,8 @@ export interface BaseLLMOptions<
}
export interface LLMOptions<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
promptTemplate?: string
@ -81,8 +69,8 @@ export interface ChatMessage {
}
export interface ChatModelOptions<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[]
@ -132,7 +120,7 @@ export interface LLMTaskResponseMetadata<
}
export interface TaskResponse<
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TOutput extends ZodTypeAny = z.ZodType<string>,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
result: ParsedData<TOutput>
@ -140,7 +128,7 @@ export interface TaskResponse<
}
export interface TaskCallContext<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TInput extends ZodTypeAny = ZodTypeAny,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
input?: ParsedData<TInput>
@ -152,4 +140,8 @@ export interface TaskCallContext<
export type IDGeneratorFunction = () => string
export interface SerializedTask extends JsonObject {
_taskName: string
}
// export type ProgressFunction = (partialResponse: ChatMessage) => void