From 72c499449d4ab7d801dbff792e6e9cdb6f737177 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Fri, 9 Jun 2023 23:11:10 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9A=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llms/llm.ts | 2 +- src/task.ts | 20 +++++++++++++++++--- src/tokenizer.ts | 4 ++-- src/tools/metaphor.ts | 17 ++++++----------- src/tools/novu.ts | 13 +++---------- src/types.ts | 1 - 6 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/llms/llm.ts b/src/llms/llm.ts index e127e64..663c9a6 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -227,7 +227,7 @@ export abstract class BaseChatModel< } protected override async _call( - ctx: types.TaskCallContext + ctx: types.TaskCallContext ): Promise> { const messages = await this.buildMessages(ctx.input, ctx) diff --git a/src/task.ts b/src/task.ts index f3274f6..efea6ca 100644 --- a/src/task.ts +++ b/src/task.ts @@ -1,5 +1,5 @@ import pRetry, { FailedAttemptError } from 'p-retry' -import { ZodRawShape, ZodTypeAny } from 'zod' +import { ZodRawShape, ZodTypeAny, z } from 'zod' import * as errors from '@/errors' import * as types from '@/types' @@ -65,7 +65,21 @@ export abstract class BaseTask< public async callWithMetadata( input?: types.ParsedData ): Promise> { - const ctx: types.TaskCallContext = { + if (this.inputSchema) { + const inputSchema = + this.inputSchema instanceof z.ZodType + ? this.inputSchema + : z.object(this.inputSchema) + + const safeInput = inputSchema.safeParse(input) + if (!safeInput.success) { + throw new Error(`Invalid input: ${safeInput.error.message}`) + } + + input = safeInput.data + } + + const ctx: types.TaskCallContext = { input, attemptNumber: 0, metadata: {} @@ -97,7 +111,7 @@ export abstract class BaseTask< } protected abstract _call( - ctx: types.TaskCallContext + ctx: types.TaskCallContext ): Promise> // TODO diff --git a/src/tokenizer.ts b/src/tokenizer.ts index 2a02da0..341cb43 100644 --- a/src/tokenizer.ts +++ b/src/tokenizer.ts @@ -151,13 +151,13 @@ export function getContextSizeForModel(model: string): number { } } -export const calculateMaxTokens = async ({ +export async function calculateMaxTokens({ prompt, modelName }: { prompt: string modelName: string -}) => { +}) { // fallback to approximate calculation if tiktoken is not available let numTokens = Math.ceil(prompt.length / 4) diff --git a/src/tools/metaphor.ts b/src/tools/metaphor.ts index 45155f1..2c59f08 100644 --- a/src/tools/metaphor.ts +++ b/src/tools/metaphor.ts @@ -59,19 +59,14 @@ export class MetaphorSearchTool extends BaseTask< } protected override async _call( - input: MetaphorSearchToolInput - ): Promise> { - // TODO: handle errors gracefully - input = this.inputSchema.parse(input) - + ctx: types.TaskCallContext + ): Promise { + // TODO: test required inputs const result = await this._metaphorClient.search({ - query: input.query, - numResults: input.numResults + query: ctx.input!.query, + numResults: ctx.input!.numResults }) - return { - result, - metadata: {} - } + return result } } diff --git a/src/tools/novu.ts b/src/tools/novu.ts index 9bf380f..1af1258 100644 --- a/src/tools/novu.ts +++ b/src/tools/novu.ts @@ -64,15 +64,8 @@ export class NovuNotificationTool extends BaseTask< } protected override async _call( - input: NovuNotificationToolInput - ): Promise> { - // TODO: handle errors gracefully - input = this.inputSchema.parse(input) - - const result = await this._novuClient.triggerEvent(input) - return { - result, - metadata: {} - } + ctx: types.TaskCallContext + ): Promise { + return this._novuClient.triggerEvent(ctx.input!) } } diff --git a/src/types.ts b/src/types.ts index 22def8a..6d2a2ac 100644 --- a/src/types.ts +++ b/src/types.ts @@ -146,7 +146,6 @@ export interface TaskResponse< export interface TaskCallContext< TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, - TOutput extends ZodRawShape | ZodTypeAny = z.ZodType, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { input?: ParsedData