From cb2ea1a2eff7f22db9811022a7a38fef6448b0c2 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Thu, 15 Jun 2023 23:49:56 -0700 Subject: [PATCH] feat: improve robustness of serpapi, diffbot, and add ky rate limiting --- examples/llm-with-search.ts | 4 +- package.json | 4 +- pnpm-lock.yaml | 10 +- src/errors.ts | 6 +- src/human-feedback/feedback.ts | 4 +- src/llms/anthropic.ts | 4 +- src/llms/chat.ts | 43 +++++--- src/llms/llm.ts | 8 +- src/llms/openai.ts | 4 +- src/services/diffbot.ts | 32 +++++- src/services/serpapi.ts | 1 + src/task.ts | 4 +- src/tools/diffbot.ts | 125 +++++++++++++++++++++ src/tools/index.ts | 1 + src/tools/serpapi.ts | 31 +++--- src/types.ts | 27 +++-- src/utils.ts | 57 +++++++++- test/_utils.ts | 194 +++++++++++++++++++++++---------- test/services/diffbot.test.ts | 2 +- test/utils.test.ts | 54 ++++++++- 20 files changed, 498 insertions(+), 117 deletions(-) create mode 100644 src/tools/diffbot.ts diff --git a/examples/llm-with-search.ts b/examples/llm-with-search.ts index 5a08aa9..a1f61cb 100644 --- a/examples/llm-with-search.ts +++ b/examples/llm-with-search.ts @@ -2,7 +2,7 @@ import { OpenAIClient } from '@agentic/openai-fetch' import 'dotenv/config' import { z } from 'zod' -import { Agentic, SerpAPITool } from '@/index' +import { Agentic, DiffbotTool, SerpAPITool } from '@/index' async function main() { const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! }) @@ -12,7 +12,7 @@ async function main() { .gpt4( `Can you summarize the top {{numResults}} results for today's news about {{topic}}?` ) - .tools([new SerpAPITool()]) + .tools([new SerpAPITool(), new DiffbotTool()]) .input( z.object({ topic: z.string(), diff --git a/package.json b/package.json index f0998cf..6df91cc 100644 --- a/package.json +++ b/package.json @@ -56,6 +56,7 @@ "normalize-url": "^8.0.0", "p-map": "^6.0.0", "p-retry": "^5.1.2", + "p-throttle": "^5.1.0", "p-timeout": "^6.1.2", "pino": "^8.14.1", "pino-pretty": "^10.0.0", @@ -104,6 +105,7 @@ }, "ava": { "snapshotDir": "test/.snapshots", + "failFast": true, "extensions": { "ts": "module" }, @@ -126,4 +128,4 @@ "guardrails", "plugins" ] -} \ No newline at end of file +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index be62312..e9d231e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.1' +lockfileVersion: '6.0' settings: autoInstallPeers: true @@ -59,6 +59,9 @@ dependencies: p-retry: specifier: ^5.1.2 version: 5.1.2 + p-throttle: + specifier: ^5.1.0 + version: 5.1.0 p-timeout: specifier: ^6.1.2 version: 6.1.2 @@ -3183,6 +3186,11 @@ packages: retry: 0.13.1 dev: false + /p-throttle@5.1.0: + resolution: {integrity: sha512-+N+s2g01w1Zch4D0K3OpnPDqLOKmLcQ4BvIFq3JC0K29R28vUOjWpO+OJZBNt8X9i3pFCksZJZ0YXkUGjaFE6g==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + dev: false + /p-timeout@5.1.0: resolution: {integrity: sha512-auFDyzzzGZZZdHz3BtET9VEz0SE/uMEAx7uWfGPucfzEwwe/xH0iVeZibQmANYE/hp9T2+UUZT5m+BKyrDp3Ew==} engines: {node: '>=12'} diff --git a/src/errors.ts b/src/errors.ts index 63f09a2..8ae22b1 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -1,4 +1,4 @@ -import type { JsonObject } from 'type-fest' +import type { Jsonifiable } from 'type-fest' import type { ZodError } from 'zod' import { ValidationError, fromZodError } from 'zod-validation-error' @@ -10,12 +10,12 @@ export type ErrorOptions = { cause?: unknown /** Additional context to be added to the error. */ - context?: JsonObject + context?: Jsonifiable } export class BaseError extends Error { status?: number - context?: JsonObject + context?: Jsonifiable constructor(message: string, opts: ErrorOptions = {}) { if (opts.cause) { diff --git a/src/human-feedback/feedback.ts b/src/human-feedback/feedback.ts index af20800..5063eb2 100644 --- a/src/human-feedback/feedback.ts +++ b/src/human-feedback/feedback.ts @@ -273,8 +273,8 @@ export abstract class HumanFeedbackMechanism< } export function withHumanFeedback< - TInput extends void | types.JsonObject, - TOutput extends types.JsonValue, + TInput extends types.TaskInput, + TOutput extends types.TaskOutput, V extends HumanFeedbackType >( task: BaseTask, diff --git a/src/llms/anthropic.ts b/src/llms/anthropic.ts index 2a267cd..c79507c 100644 --- a/src/llms/anthropic.ts +++ b/src/llms/anthropic.ts @@ -9,8 +9,8 @@ import { BaseChatCompletion } from './chat' const defaultStopSequences = [anthropic.HUMAN_PROMPT] export class AnthropicChatCompletion< - TInput extends void | types.JsonObject = any, - TOutput extends types.JsonValue = string + TInput extends types.TaskInput = any, + TOutput extends types.TaskOutput = string > extends BaseChatCompletion< TInput, TOutput, diff --git a/src/llms/chat.ts b/src/llms/chat.ts index b08573e..bebbcb5 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -9,6 +9,7 @@ import * as types from '@/types' import { BaseTask } from '@/task' import { getCompiledTemplate } from '@/template' import { + extractFunctionIdentifierFromString, extractJSONArrayFromString, extractJSONObjectFromString } from '@/utils' @@ -20,8 +21,8 @@ import { } from './llm-utils' export abstract class BaseChatCompletion< - TInput extends void | types.JsonObject = void, - TOutput extends types.JsonValue = string, + TInput extends types.TaskInput = void, + TOutput extends types.TaskOutput = string, TModelParams extends Record = Record, TChatCompletionResponse extends Record = Record > extends BaseLLM { @@ -41,7 +42,7 @@ export abstract class BaseChatCompletion< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - input( + input( inputSchema: ZodType ): BaseChatCompletion { const refinedInstance = this as unknown as BaseChatCompletion< @@ -54,7 +55,7 @@ export abstract class BaseChatCompletion< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - output( + output( outputSchema: ZodType ): BaseChatCompletion { const refinedInstance = this as unknown as BaseChatCompletion< @@ -242,9 +243,10 @@ export abstract class BaseChatCompletion< `<<< Task createChatCompletion "${this.nameForHuman}"` ) ctx.metadata.completion = completion + const message = completion.message - if (completion.message.function_call) { - const functionCall = completion.message.function_call + if (message.function_call) { + const functionCall = message.function_call if (!isUsingTools) { // TODO: not sure what we should do in this case... @@ -252,16 +254,31 @@ export abstract class BaseChatCompletion< break } + const functionName = extractFunctionIdentifierFromString( + functionCall.name + ) + + if (!functionName) { + throw new errors.OutputValidationError( + `Unrecognized function call "${functionCall.name}"` + ) + } + const tool = this._tools!.find( - (tool) => tool.nameForModel === functionCall.name + (tool) => tool.nameForModel === functionName ) if (!tool) { throw new errors.OutputValidationError( - `Function not found "${functionCall.name}"` + `Function not found "${functionName}"` ) } + if (functionName !== functionCall.name) { + // fix function name hallucinations + functionCall.name = functionName + } + let functionArguments: any try { functionArguments = JSON.parse(jsonrepair(functionCall.arguments)) @@ -281,12 +298,12 @@ export abstract class BaseChatCompletion< } // console.log('>>> sub-task', { - // task: functionCall.name, + // task: functionName, // input: functionArguments // }) this._logger.info( { - task: functionCall.name, + task: functionName, input: functionArguments }, `>>> Sub-Task "${tool.nameForHuman}"` @@ -297,14 +314,14 @@ export abstract class BaseChatCompletion< this._logger.info( { - task: functionCall.name, + task: functionName, input: functionArguments, output: toolCallResponse.result }, `<<< Sub-Task "${tool.nameForHuman}"` ) // console.log('<<< sub-task', { - // task: functionCall.name, + // task: functionName, // input: functionArguments, // output: toolCallResponse.result // }) @@ -321,7 +338,7 @@ export abstract class BaseChatCompletion< messages.push(completion.message as any) messages.push({ role: 'function', - name: functionCall.name, + name: functionName, content: taskCallContent }) diff --git a/src/llms/llm.ts b/src/llms/llm.ts index bceb26c..5d042a5 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -7,8 +7,8 @@ import { Tokenizer, getTokenizerForModel } from '@/tokenizer' // TODO: TInput should only be allowed to be void or an object export abstract class BaseLLM< - TInput extends void | types.JsonObject = void, - TOutput extends types.JsonValue = string, + TInput extends types.TaskInput = void, + TOutput extends types.TaskOutput = string, TModelParams extends Record = Record > extends BaseTask { protected _inputSchema: ZodType | undefined @@ -38,7 +38,7 @@ export abstract class BaseLLM< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - input( + input( inputSchema: ZodType ): BaseLLM { const refinedInstance = this as unknown as BaseLLM @@ -47,7 +47,7 @@ export abstract class BaseLLM< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - output( + output( outputSchema: ZodType ): BaseLLM { const refinedInstance = this as unknown as BaseLLM diff --git a/src/llms/openai.ts b/src/llms/openai.ts index ee8f501..09d6202 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -14,8 +14,8 @@ const openaiModelsSupportingFunctions = new Set([ ]) export class OpenAIChatCompletion< - TInput extends void | types.JsonObject = any, - TOutput extends types.JsonValue = string + TInput extends types.TaskInput = any, + TOutput extends types.TaskOutput = string > extends BaseChatCompletion< TInput, TOutput, diff --git a/src/services/diffbot.ts b/src/services/diffbot.ts index 235a27c..a311be2 100644 --- a/src/services/diffbot.ts +++ b/src/services/diffbot.ts @@ -1,4 +1,7 @@ import defaultKy from 'ky' +import pThrottle from 'p-throttle' + +import { throttleKy } from '@/utils' export const DIFFBOT_API_BASE_URL = 'https://api.diffbot.com' export const DIFFBOT_KNOWLEDGE_GRAPH_API_BASE_URL = 'https://kg.diffbot.com' @@ -94,9 +97,17 @@ export interface DiffbotObject { categories?: DiffbotCategory[] authors: DiffbotAuthor[] breadcrumb?: DiffbotBreadcrumb[] + items?: DiffbotListItem[] meta?: any } +interface DiffbotListItem { + title: string + link: string + summary: string + image?: string +} + interface DiffbotAuthor { name: string link: string @@ -307,6 +318,12 @@ interface DiffbotSkill { diffbotUri: string } +const throttle = pThrottle({ + limit: 5, + interval: 1000, + strict: true +}) + export class DiffbotClient { api: typeof defaultKy apiKnowledgeGraph: typeof defaultKy @@ -336,8 +353,14 @@ export class DiffbotClient { this.apiBaseUrl = apiBaseUrl this.apiKnowledgeGraphBaseUrl = apiKnowledgeGraphBaseUrl - this.api = ky.extend({ prefixUrl: apiBaseUrl, timeout: timeoutMs }) - this.apiKnowledgeGraph = ky.extend({ + const throttledKy = throttleKy(ky, throttle) + + this.api = throttledKy.extend({ + prefixUrl: apiBaseUrl, + timeout: timeoutMs + }) + + this.apiKnowledgeGraph = throttledKy.extend({ prefixUrl: apiKnowledgeGraphBaseUrl, timeout: timeoutMs }) @@ -364,10 +387,13 @@ export class DiffbotClient { } } + console.log(`DiffbotClient._extract: ${endpoint}`, searchParams) + return this.api .get(endpoint, { searchParams, - headers + headers, + retry: 2 }) .json() } diff --git a/src/services/serpapi.ts b/src/services/serpapi.ts index 6904d87..ccf9985 100644 --- a/src/services/serpapi.ts +++ b/src/services/serpapi.ts @@ -656,6 +656,7 @@ export class SerpAPIClient { : queryOrOpts const { timeout, ...rest } = this.params + // console.log(options) return this.api .get('search', { searchParams: { diff --git a/src/task.ts b/src/task.ts index de8b9a0..224d15a 100644 --- a/src/task.ts +++ b/src/task.ts @@ -20,8 +20,8 @@ import { defaultIDGeneratorFn, isValidTaskIdentifier } from './utils' * - Invoking sub-agents */ export abstract class BaseTask< - TInput extends void | types.JsonObject = void, - TOutput extends types.JsonValue = string + TInput extends types.TaskInput = void, + TOutput extends types.TaskOutput = string > { protected _agentic: Agentic protected _id: string diff --git a/src/tools/diffbot.ts b/src/tools/diffbot.ts new file mode 100644 index 0000000..262e072 --- /dev/null +++ b/src/tools/diffbot.ts @@ -0,0 +1,125 @@ +import { z } from 'zod' + +import * as types from '@/types' +import { DiffbotClient } from '@/services/diffbot' +import { BaseTask } from '@/task' +import { omit, pick } from '@/utils' + +export const DiffbotInputSchema = z.object({ + url: z.string().describe('URL of page to scrape') +}) +export type DiffbotInput = z.infer + +export const DiffbotImageSchema = z.object({ + url: z.string().optional(), + naturalWidth: z.number().optional(), + naturalHeight: z.number().optional(), + width: z.number().optional(), + height: z.number().optional(), + isCached: z.boolean().optional(), + primary: z.boolean().optional() +}) + +export const DiffbotListItemSchema = z.object({ + title: z.string().optional(), + link: z.string().optional(), + summary: z.string().optional(), + image: z.string().optional() +}) + +export const DiffbotObjectSchema = z.object({ + type: z.string().optional(), + title: z.string().optional(), + siteName: z.string().optional(), + author: z.string().optional(), + authorUrl: z.string().optional(), + pageUrl: z.string().optional(), + date: z.string().optional(), + estimatedDate: z.string().optional(), + humanLanguage: z.string().optional(), + text: z.string().describe('core text content of the page').optional(), + tags: z.array(z.string()).optional(), + images: z.array(DiffbotImageSchema).optional(), + items: z.array(DiffbotListItemSchema).optional() +}) + +export const DiffbotOutputSchema = z.object({ + type: z.string().optional(), + title: z.string().optional(), + objects: z.array(DiffbotObjectSchema).optional() +}) +export type DiffbotOutput = z.infer + +export class DiffbotTool extends BaseTask { + protected _diffbotClient: DiffbotClient + + constructor( + opts: { + diffbot?: DiffbotClient + } & types.BaseTaskOptions = {} + ) { + super(opts) + + this._diffbotClient = + opts.diffbot ?? new DiffbotClient({ ky: opts.agentic?.ky }) + } + + public override get inputSchema() { + return DiffbotInputSchema + } + + public override get outputSchema() { + return DiffbotOutputSchema + } + + public override get nameForModel(): string { + return 'scrapeWebPage' + } + + public override get nameForHuman(): string { + return 'Diffbot Scrape Web Page' + } + + public override get descForModel(): string { + return 'Scrapes a web page for its content and structured data.' + } + + protected override async _call( + ctx: types.TaskCallContext + ): Promise { + const res = await this._diffbotClient.extractAnalyze({ + url: ctx.input!.url + }) + + this._logger.info(res, `Diffbot response for url "${ctx.input!.url}"`) + console.log(res) + + const pickedRes = { + type: res.type, + title: res.title, + objects: res.objects.map((obj) => ({ + ...pick( + obj, + 'type', + 'siteName', + 'author', + 'authorUrl', + 'pageUrl', + 'date', + 'estimatedDate', + 'humanLanguage', + 'items', + 'text' + ), + tags: obj.tags?.map((tag) => tag.label), + images: obj.images?.map((image) => omit(image, 'diffbotUri')) + })) + } + + this._logger.info( + pickedRes, + `Diffbot picked response for url "${ctx.input!.url}"` + ) + return this.outputSchema.parse(pickedRes) + } +} diff --git a/src/tools/index.ts b/src/tools/index.ts index 2bb73cb..d3c0f66 100644 --- a/src/tools/index.ts +++ b/src/tools/index.ts @@ -1,4 +1,5 @@ export * from './calculator' +export * from './diffbot' export * from './metaphor' export * from './novu' export * from './serpapi' diff --git a/src/tools/serpapi.ts b/src/tools/serpapi.ts index 8181e92..22124f5 100644 --- a/src/tools/serpapi.ts +++ b/src/tools/serpapi.ts @@ -11,32 +11,32 @@ export const SerpAPIInputSchema = z.object({ export type SerpAPIInput = z.infer export const SerpAPIOrganicSearchResult = z.object({ - position: z.number(), - title: z.string(), - link: z.string(), - displayed_link: z.string(), - snippet: z.string(), + position: z.number().optional(), + title: z.string().optional(), + link: z.string().optional(), + displayed_link: z.string().optional(), + snippet: z.string().optional(), source: z.string().optional(), date: z.string().optional() }) export const SerpAPIAnswerBox = z.object({ - type: z.string(), - title: z.string(), - link: z.string(), - displayed_link: z.string(), - snippet: z.string() + type: z.string().optional(), + title: z.string().optional(), + link: z.string().optional(), + displayed_link: z.string().optional(), + snippet: z.string().optional() }) export const SerpAPIKnowledgeGraph = z.object({ - type: z.string(), - description: z.string() + type: z.string().optional(), + description: z.string().optional() }) export const SerpAPIOutputSchema = z.object({ knowledgeGraph: SerpAPIKnowledgeGraph.optional(), answerBox: SerpAPIAnswerBox.optional(), - organicResults: z.array(SerpAPIOrganicSearchResult) + organicResults: z.array(SerpAPIOrganicSearchResult).optional() }) export type SerpAPIOutput = z.infer @@ -82,7 +82,10 @@ export class SerpAPITool extends BaseTask { num: ctx.input!.numResults }) - this._logger.debug(res, `SerpAPI response for query "${ctx.input!.query}"`) + this._logger.debug( + res, + `SerpAPI response for query ${JSON.stringify(ctx.input, null, 2)}"` + ) return this.outputSchema.parse({ knowledgeGraph: res.knowledge_graph, diff --git a/src/types.ts b/src/types.ts index 6474b6b..181bee2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2,7 +2,7 @@ import * as openai from '@agentic/openai-fetch' import * as anthropic from '@anthropic-ai/sdk' import ky from 'ky' import type { Options as RetryOptions } from 'p-retry' -import type { JsonObject, JsonValue } from 'type-fest' +import type { JsonObject, Jsonifiable } from 'type-fest' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import type { Agentic } from './agentic' @@ -15,9 +15,16 @@ import type { BaseTask } from './task' export { anthropic, openai } -export type { JsonObject, JsonValue, Logger } +export type { Jsonifiable, Logger } export type KyInstance = typeof ky +export type JsonifiableObject = + | { [Key in string]?: Jsonifiable } + | { toJSON: () => Jsonifiable } + +export type TaskInput = void | JsonifiableObject +export type TaskOutput = Jsonifiable + export type ParsedData = T extends ZodTypeAny ? output : never @@ -40,8 +47,8 @@ export interface BaseTaskOptions { } export interface BaseLLMOptions< - TInput extends void | JsonObject = void, - TOutput extends JsonValue = string, + TInput extends TaskInput = void, + TOutput extends TaskOutput = string, TModelParams extends Record = Record > extends BaseTaskOptions { inputSchema?: ZodType @@ -54,8 +61,8 @@ export interface BaseLLMOptions< } export interface LLMOptions< - TInput extends void | JsonObject = void, - TOutput extends JsonValue = string, + TInput extends TaskInput = void, + TOutput extends TaskOutput = string, TModelParams extends Record = Record > extends BaseLLMOptions { promptTemplate?: string @@ -67,8 +74,8 @@ export type ChatMessage = openai.ChatMessage export type ChatMessageRole = openai.ChatMessageRole export interface ChatModelOptions< - TInput extends void | JsonObject = void, - TOutput extends JsonValue = string, + TInput extends TaskInput = void, + TOutput extends TaskOutput = string, TModelParams extends Record = Record > extends BaseLLMOptions { messages: ChatMessage[] @@ -116,7 +123,7 @@ export interface LLMTaskResponseMetadata< } export interface TaskResponse< - TOutput extends JsonValue = string, + TOutput extends TaskOutput = string, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { result: TOutput @@ -124,7 +131,7 @@ export interface TaskResponse< } export interface TaskCallContext< - TInput extends void | JsonObject = void, + TInput extends TaskInput = void, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { input?: TInput diff --git a/src/utils.ts b/src/utils.ts index b0b58c3..819a2e8 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,5 @@ import { customAlphabet, urlAlphabet } from 'nanoid' +import type { ThrottledFunction } from 'p-throttle' import * as types from './types' @@ -43,6 +44,30 @@ export function isValidTaskIdentifier(id: string): boolean { return !!id && taskNameRegex.test(id) } +export function extractFunctionIdentifierFromString( + text: string +): string | undefined { + text = text?.trim() + + if (!text) { + return + } + + if (isValidTaskIdentifier(text)) { + return text + } + + const splits = text + .split(/[^a-zA-Z0-9_-]/) + .map((s) => { + s = s.trim() + return isValidTaskIdentifier(s) ? s : undefined + }) + .filter(Boolean) + + return splits[splits.length - 1] +} + /** * Chunks a string into an array of chunks. * @@ -81,7 +106,7 @@ export function chunkString(text: string, maxLength: number): string[] { * @param json - JSON value to stringify * @returns stringified value with all double quotes around object keys removed */ -export function stringifyForModel(json: types.JsonValue): string { +export function stringifyForModel(json: types.TaskOutput): string { const UNIQUE_PREFIX = defaultIDGeneratorFn() return ( JSON.stringify(json, replacer) @@ -112,3 +137,33 @@ export function stringifyForModel(json: types.JsonValue): string { return value } } + +export function pick(obj: T, ...keys: string[]): U { + return Object.fromEntries( + keys.filter((key) => key in obj).map((key) => [key, obj[key]]) + ) as U +} + +export function omit(obj: T, ...keys: string[]): U { + return Object.fromEntries( + Object.entries(obj).filter(([key]) => !keys.includes(key)) + ) as U +} + +const noop = () => undefined + +/** + * Throttles HTTP requests made by a ky instance. Very useful for enforcing rate limits. + */ +export function throttleKy( + ky: types.KyInstance, + throttleFn: ( + function_: (...args: Argument) => ReturnValue + ) => ThrottledFunction +) { + return ky.extend({ + hooks: { + beforeRequest: [throttleFn(noop)] + } + }) +} diff --git a/test/_utils.ts b/test/_utils.ts index afcb205..5c34e32 100644 --- a/test/_utils.ts +++ b/test/_utils.ts @@ -5,9 +5,10 @@ import 'dotenv/config' import hashObject from 'hash-obj' import Redis from 'ioredis' import Keyv from 'keyv' -import defaultKy from 'ky' +import defaultKy, { AfterResponseHook, BeforeRequestHook } from 'ky' import pMemoize from 'p-memoize' +import * as types from '@/types' import { Agentic } from '@/agentic' import { normalizeUrl } from '@/url-utils' @@ -62,6 +63,9 @@ function getCacheKeyForRequest(request: Request): string | null { return null } +const AGENTIC_TEST_CACHE_HEADER = 'x-agentic-test-cache' +const AGENTIC_TEST_MOCK_HEADER = 'x-agentic-test-mock' + /** * Custom `ky` instance that caches GET JSON requests. * @@ -69,68 +73,148 @@ function getCacheKeyForRequest(request: Request): string | null { * - support non-GET requests * - support non-JSON responses */ -export const ky = defaultKy.extend({ - hooks: { - beforeRequest: [ - async (request) => { - try { - // console.log(`beforeRequest ${request.method} ${request.url}`) +export function createTestKyInstance( + ky: types.KyInstance = defaultKy +): types.KyInstance { + return ky.extend({ + hooks: { + beforeRequest: [ + async (request) => { + try { + const cacheKey = getCacheKeyForRequest(request) + // console.log( + // `beforeRequest ${request.method} ${request.url} ⇒ ${cacheKey}` + // ) - const cacheKey = getCacheKeyForRequest(request) - // console.log({ cacheKey }) - if (!cacheKey) { - return + // console.log({ cacheKey }) + if (!cacheKey) { + return + } + + if (!(await keyv.has(cacheKey))) { + return + } + + const cachedResponse = await keyv.get(cacheKey) + // console.log({ cachedResponse }) + + if (!cachedResponse) { + return + } + + return new Response(JSON.stringify(cachedResponse), { + status: 200, + headers: { + 'Content-Type': 'application/json', + [AGENTIC_TEST_CACHE_HEADER]: '1' + } + }) + } catch (err) { + console.error('ky beforeResponse cache error', err) } - - if (!(await keyv.has(cacheKey))) { - return - } - - const cachedResponse = await keyv.get(cacheKey) - // console.log({ cachedResponse }) - - if (!cachedResponse) { - return - } - - return new Response(JSON.stringify(cachedResponse), { - status: 200, - headers: { 'Content-Type': 'application/json' } - }) - } catch (err) { - console.error('ky beforeResponse cache error', err) } - } - ], + ], - afterResponse: [ - async (request, _options, response) => { - try { - // console.log( - // `afterRequest ${request.method} ${request.url} ⇒ ${response.status}` - // ) + afterResponse: [ + async (request, _options, response) => { + try { + if (response.headers.get(AGENTIC_TEST_CACHE_HEADER)) { + // console.log('cached') + return + } - if (response.status < 200 || response.status >= 300) { - return + if (response.headers.get(AGENTIC_TEST_MOCK_HEADER)) { + // console.log('mocked') + return + } + + const contentType = response.headers.get('content-type') + // console.log( + // `afterRequest ${request.method} ${request.url} ⇒ ${response.status} ${contentType}` + // ) + + if (response.status < 200 || response.status >= 300) { + return + } + + if (contentType !== 'application/json') { + return + } + + const cacheKey = getCacheKeyForRequest(request) + // console.log({ cacheKey }) + if (!cacheKey) { + console.log('222') + return + } + + const responseBody = await response.json() + // console.log({ responseBody }) + + await keyv.set(cacheKey, responseBody) + } catch (err) { + console.error('ky afterResponse cache error', err) } - - const cacheKey = getCacheKeyForRequest(request) - // console.log({ cacheKey }) - if (!cacheKey) { - return - } - - const responseBody = await response.json() - // console.log({ responseBody }) - - await keyv.set(cacheKey, responseBody) - } catch (err) { - console.error('ky afterResponse cache error', err) } + ] + } + }) +} + +function defaultBeforeRequest(request: Request): Response { + return new Response( + JSON.stringify({ + url: request.url, + normalizedUrl: normalizeUrl(request.url), + method: request.method, + headers: request.headers + }), + { + status: 200, + headers: { + 'Content-Type': 'application/json', + [AGENTIC_TEST_MOCK_HEADER]: '1' } - ] - } -}) + } + ) +} + +export function mockKyInstance( + ky: types.KyInstance = defaultKy, + { + beforeRequest = defaultBeforeRequest, + afterResponse = null + }: { + beforeRequest?: BeforeRequestHook | null + afterResponse?: AfterResponseHook | null + } = {} +): types.KyInstance { + return ky.extend({ + hooks: { + beforeRequest: beforeRequest === null ? [] : [beforeRequest], + afterResponse: afterResponse === null ? [] : [afterResponse] + } + }) +} + +/* + * NOTE: ky hooks are appended when doing `ky.extend`, so if you already have a + * beforeRequest hook, it will be called before any passed to `ky.extend`. + * + * For example: + * + * ```ts + * // runs caching first, then mocking + * const ky0 = mockKyInstance(createTestKyInstance(ky)) + * + * // runs mocking first, then caching + * const ky1 = createTestKyInstance(mockKyInstance(ky)) + * + * // runs throttling first, then mocking + * const ky2 = mockKyInstance(throttleKy(ky, throttle)) + * ``` + */ +export const ky = createTestKyInstance() export class OpenAITestClient extends OpenAIClient { createChatCompletion = pMemoize(super.createChatCompletion, { diff --git a/test/services/diffbot.test.ts b/test/services/diffbot.test.ts index 86e1a79..f0261fe 100644 --- a/test/services/diffbot.test.ts +++ b/test/services/diffbot.test.ts @@ -36,7 +36,7 @@ test('Diffbot.extractArticle', async (t) => { t.is(result.objects[0].type, 'article') }) -test.only('Diffbot.knowledgeGraphSearch', async (t) => { +test('Diffbot.knowledgeGraphSearch', async (t) => { if (!process.env.DIFFBOT_API_KEY || isCI) { return t.pass() } diff --git a/test/utils.test.ts b/test/utils.test.ts index d9b13ec..0e4804e 100644 --- a/test/utils.test.ts +++ b/test/utils.test.ts @@ -1,15 +1,21 @@ import test from 'ava' +import ky from 'ky' +import pThrottle from 'p-throttle' import { chunkString, defaultIDGeneratorFn, + extractFunctionIdentifierFromString, extractJSONArrayFromString, extractJSONObjectFromString, isValidTaskIdentifier, sleep, - stringifyForModel + stringifyForModel, + throttleKy } from '@/utils' +import { mockKyInstance } from './_utils' + test('isValidTaskIdentifier - valid', async (t) => { t.true(isValidTaskIdentifier('foo')) t.true(isValidTaskIdentifier('foo_bar_179')) @@ -124,3 +130,49 @@ test('stringifyForModel should stringify objects with null values correctly', (t t.is(actualOutput, expectedOutput) }) + +test('extractFunctionIdentifierFromString valid', (t) => { + t.is(extractFunctionIdentifierFromString('foo'), 'foo') + t.is(extractFunctionIdentifierFromString('fooBar_BAZ'), 'fooBar_BAZ') + t.is(extractFunctionIdentifierFromString('functions.fooBar'), 'fooBar') + t.is( + extractFunctionIdentifierFromString('function fooBar1234_'), + 'fooBar1234_' + ) +}) + +test('extractFunctionIdentifierFromString invalid', (t) => { + t.is(extractFunctionIdentifierFromString(''), undefined) + t.is(extractFunctionIdentifierFromString(' '), undefined) + t.is(extractFunctionIdentifierFromString('.-'), undefined) +}) + +test('throttleKy should rate-limit requests to ky properly', async (t) => { + t.timeout(30_1000) + + const interval = 1000 + const throttle = pThrottle({ + limit: 1, + interval, + strict: true + }) + + const ky2 = mockKyInstance(throttleKy(ky, throttle)) + + const url = 'https://httpbin.org/get' + + for (let i = 0; i < 10; i++) { + const before = Date.now() + const res = await ky2.get(url) + const after = Date.now() + + const duration = after - before + // console.log(duration, res.status) + t.is(res.status, 200) + + // leave a bit of wiggle room for the interval + if (i > 0) { + t.true(duration >= interval - interval / 5) + } + } +})