From f5319b4e174313ae44f4e8f51428c5453642b149 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Mon, 3 Jun 2024 15:06:41 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .eslintrc.json | 6 +- bin/scratch.ts | 2 +- package.json | 2 + pnpm-lock.yaml | 6 + src/index.ts | 2 +- src/message.test.ts | 55 +++++ src/message.ts | 344 ++++++++++++++++++++++++++++++++ src/sdks/langchain.ts | 2 +- src/services/dexa-client.ts | 24 ++- src/services/diffbot-client.ts | 1 + src/stringify-for-model.test.ts | 22 -- src/stringify-for-model.ts | 16 -- src/types.ts | 127 +----------- src/utils.test.ts | 29 ++- src/utils.ts | 29 +++ 15 files changed, 494 insertions(+), 173 deletions(-) create mode 100644 src/message.test.ts create mode 100644 src/message.ts delete mode 100644 src/stringify-for-model.test.ts delete mode 100644 src/stringify-for-model.ts diff --git a/.eslintrc.json b/.eslintrc.json index d03d820..6de124a 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -1,4 +1,8 @@ { "root": true, - "extends": ["@fisch0920/eslint-config/node"] + "extends": ["@fisch0920/eslint-config/node"], + "rules": { + "unicorn/no-static-only-class": "off", + "@typescript-eslint/naming-convention": "off" + } } diff --git a/bin/scratch.ts b/bin/scratch.ts index a0b4c31..5472bd0 100644 --- a/bin/scratch.ts +++ b/bin/scratch.ts @@ -71,7 +71,7 @@ async function main() { // }) const res = await diffbot.enhanceEntity({ type: 'Person', - name: 'Travis Fischer' + name: 'Kevin Raheja' }) console.log(JSON.stringify(res, null, 2)) } diff --git a/package.json b/package.json index 45fbe1c..42cd386 100644 --- a/package.json +++ b/package.json @@ -59,6 +59,7 @@ }, "dependencies": { "@nangohq/node": "^0.39.32", + "dedent": "^1.5.3", "delay": "^6.0.0", "jsonrepair": "^3.6.1", "ky": "^1.2.4", @@ -85,6 +86,7 @@ "np": "^10.0.5", "npm-run-all2": "^6.2.0", "only-allow": "^1.2.1", + "openai-fetch": "^2.0.3", "prettier": "^3.2.5", "restore-cursor": "^5.0.0", "ts-node": "^10.9.2", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7b9efbd..2da3744 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -14,6 +14,9 @@ importers: '@nangohq/node': specifier: ^0.39.32 version: 0.39.32 + dedent: + specifier: ^1.5.3 + version: 1.5.3 delay: specifier: ^6.0.0 version: 6.0.0 @@ -87,6 +90,9 @@ importers: only-allow: specifier: ^1.2.1 version: 1.2.1 + openai-fetch: + specifier: ^2.0.3 + version: 2.0.3 prettier: specifier: ^3.2.5 version: 3.3.0 diff --git a/src/index.ts b/src/index.ts index 269386e..50edc27 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,9 +3,9 @@ export * from './create-ai-function.js' export * from './create-ai-function.js' export * from './errors.js' export * from './fns.js' +export * from './message.js' export * from './parse-structured-output.js' export * from './services/index.js' -export * from './stringify-for-model.js' export type * from './types.js' export * from './utils.js' export * from './zod-to-json-schema.js' diff --git a/src/message.test.ts b/src/message.test.ts new file mode 100644 index 0000000..3da4301 --- /dev/null +++ b/src/message.test.ts @@ -0,0 +1,55 @@ +import type * as OpenAI from 'openai-fetch' +import { describe, expect, expectTypeOf, it } from 'vitest' + +import type * as types from './types.js' +import { Msg } from './message.js' + +describe('Msg', () => { + it('creates a message and fixes indentation', () => { + const msgContent = ` + Hello, World! + ` + const msg = Msg.system(msgContent) + expect(msg.role).toEqual('system') + expect(msg.content).toEqual('Hello, World!') + }) + + it('supports disabling indentation fixing', () => { + const msgContent = ` + Hello, World! + ` + const msg = Msg.system(msgContent, { cleanContent: false }) + expect(msg.content).toEqual('\n Hello, World!\n ') + }) + + it('handles tool calls request', () => { + const msg = Msg.toolCall([ + { + id: 'fake-tool-call-id', + type: 'function', + function: { + arguments: '{"prompt": "Hello, World!"}', + name: 'hello' + } + } + ]) + expectTypeOf(msg).toMatchTypeOf() + expect(Msg.isToolCall(msg)).toBe(true) + }) + + it('handles tool call response', () => { + const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id') + expectTypeOf(msg).toMatchTypeOf() + expect(Msg.isToolResult(msg)).toBe(true) + }) + + it('prompt message types should interop with openai-fetch message types', () => { + expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf() + expectTypeOf({} as types.Msg).toMatchTypeOf() + expectTypeOf({} as types.Msg.System).toMatchTypeOf() + expectTypeOf({} as types.Msg.User).toMatchTypeOf() + expectTypeOf({} as types.Msg.Assistant).toMatchTypeOf() + expectTypeOf({} as types.Msg.FuncCall).toMatchTypeOf() + expectTypeOf({} as types.Msg.FuncResult).toMatchTypeOf() + }) +}) diff --git a/src/message.ts b/src/message.ts new file mode 100644 index 0000000..6202c25 --- /dev/null +++ b/src/message.ts @@ -0,0 +1,344 @@ +import type { Jsonifiable } from 'type-fest' + +import { cleanStringForModel, stringifyForModel } from './utils.js' + +/** + * Generic/default OpenAI message without any narrowing applied. + */ +export interface Msg { + /** + * The contents of the message. `content` is required for all messages, and + * may be null for assistant messages with function calls. + */ + content: string | null + + /** + * The role of the messages author. One of `system`, `user`, `assistant`, + * 'tool', or `function`. + */ + role: Msg.Role + + /** + * The name and arguments of a function that should be called, as generated + * by the model. + */ + function_call?: Msg.Call.Function + + /** + * The tool calls generated by the model, such as function calls. + */ + tool_calls?: Msg.Call.Tool[] + + /** + * Tool call that this message is responding to. + */ + tool_call_id?: string + + /** + * The name of the author of this message. `name` is required if role is + * `function`, and it should be the name of the function whose response is in the + * `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of + * 64 characters. + */ + name?: string +} + +/** Narrowed OpenAI Message types. */ +export namespace Msg { + /** Possible roles for a message. */ + export type Role = 'system' | 'user' | 'assistant' | 'function' | 'tool' + + export namespace Call { + /** + * The name and arguments of a function that should be called, as generated + * by the model. + */ + export type Function = { + /** + * The arguments to call the function with, as generated by the model in + * JSON format. + */ + arguments: string + + /** The name of the function to call. */ + name: string + } + + /** The tool calls generated by the model, such as function calls. */ + export type Tool = { + /** The ID of the tool call. */ + id: string + + /** The type of the tool. Currently, only `function` is supported. */ + type: 'function' + + /** The function that the model called. */ + function: Call.Function + } + } + + /** Message with text content for the system. */ + export type System = { + role: 'system' + content: string + name?: string + } + + /** Message with text content from the user. */ + export type User = { + role: 'user' + name?: string + content: string + } + + /** Message with text content from the assistant. */ + export type Assistant = { + role: 'assistant' + name?: string + content: string + } + + /** Message with arguments to call a function. */ + export type FuncCall = { + role: 'assistant' + name?: string + content: null + function_call: Call.Function + } + + /** Message with the result of a function call. */ + export type FuncResult = { + role: 'function' + name: string + content: string + } + + /** Message with arguments to call one or more tools. */ + export type ToolCall = { + role: 'assistant' + name?: string + content: null + tool_calls: Call.Tool[] + } + + /** Message with the result of a tool call. */ + export type ToolResult = { + role: 'tool' + tool_call_id: string + content: string + } +} + +/** Utility functions for creating and checking message types. */ +export namespace Msg { + /** Create a system message. Cleans indentation and newlines by default. */ + export function system( + content: string, + opts?: { + /** Custom name for the message. */ + name?: string + /** Whether to clean extra newlines and indentation. Defaults to true. */ + cleanContent?: boolean + } + ): Msg.System { + const { name, cleanContent = true } = opts ?? {} + return { + role: 'system', + content: cleanContent ? cleanStringForModel(content) : content, + ...(name ? { name } : {}) + } + } + + /** Create a user message. Cleans indentation and newlines by default. */ + export function user( + content: string, + opts?: { + /** Custom name for the message. */ + name?: string + /** Whether to clean extra newlines and indentation. Defaults to true. */ + cleanContent?: boolean + } + ): Msg.User { + const { name, cleanContent = true } = opts ?? {} + return { + role: 'user', + content: cleanContent ? cleanStringForModel(content) : content, + ...(name ? { name } : {}) + } + } + + /** Create an assistant message. Cleans indentation and newlines by default. */ + export function assistant( + content: string, + opts?: { + /** Custom name for the message. */ + name?: string + /** Whether to clean extra newlines and indentation. Defaults to true. */ + cleanContent?: boolean + } + ): Msg.Assistant { + const { name, cleanContent = true } = opts ?? {} + return { + role: 'assistant', + content: cleanContent ? cleanStringForModel(content) : content, + ...(name ? { name } : {}) + } + } + + /** Create a function call message with argumets. */ + export function funcCall( + function_call: { + /** Name of the function to call. */ + name: string + /** Arguments to pass to the function. */ + arguments: string + }, + opts?: { + /** The name descriptor for the message.(message.name) */ + name?: string + } + ): Msg.FuncCall { + return { + ...opts, + role: 'assistant', + content: null, + function_call + } + } + + /** Create a function result message. */ + export function funcResult( + content: Jsonifiable, + name: string + ): Msg.FuncResult { + const contentString = stringifyForModel(content) + return { role: 'function', content: contentString, name } + } + + /** Create a function call message with argumets. */ + export function toolCall( + tool_calls: Msg.Call.Tool[], + opts?: { + /** The name descriptor for the message.(message.name) */ + name?: string + } + ): Msg.ToolCall { + return { + ...opts, + role: 'assistant', + content: null, + tool_calls + } + } + + /** Create a tool call result message. */ + export function toolResult( + content: Jsonifiable, + tool_call_id: string, + opts?: { + /** The name of the tool which was called */ + name?: string + } + ): Msg.ToolResult { + const contentString = stringifyForModel(content) + return { ...opts, role: 'tool', tool_call_id, content: contentString } + } + + /** Get the narrowed message from an EnrichedResponse. */ + export function getMessage( + // @TODO + response: any + // response: ChatModel.EnrichedResponse + ): Msg.Assistant | Msg.FuncCall | Msg.ToolCall { + const msg = response.choices[0].message as Msg + return narrowResponseMessage(msg) + } + + /** Narrow a message received from the API. It only responds with role=assistant */ + export function narrowResponseMessage( + msg: Msg + ): Msg.Assistant | Msg.FuncCall | Msg.ToolCall { + if (msg.content === null && msg.tool_calls != null) { + return Msg.toolCall(msg.tool_calls) + } else if (msg.content === null && msg.function_call != null) { + return Msg.funcCall(msg.function_call) + } else if (msg.content !== null) { + return Msg.assistant(msg.content) + } else { + // @TODO: probably don't want to error here + console.log('Invalid message', msg) + throw new Error('Invalid message') + } + } + + /** Check if a message is a system message. */ + export function isSystem(message: Msg): message is Msg.System { + return message.role === 'system' + } + /** Check if a message is a user message. */ + export function isUser(message: Msg): message is Msg.User { + return message.role === 'user' + } + /** Check if a message is an assistant message. */ + export function isAssistant(message: Msg): message is Msg.Assistant { + return message.role === 'assistant' && message.content !== null + } + /** Check if a message is a function call message with arguments. */ + export function isFuncCall(message: Msg): message is Msg.FuncCall { + return message.role === 'assistant' && message.function_call != null + } + /** Check if a message is a function result message. */ + export function isFuncResult(message: Msg): message is Msg.FuncResult { + return message.role === 'function' && message.name != null + } + /** Check if a message is a tool calls message. */ + export function isToolCall(message: Msg): message is Msg.ToolCall { + return message.role === 'assistant' && message.tool_calls != null + } + /** Check if a message is a tool call result message. */ + export function isToolResult(message: Msg): message is Msg.ToolResult { + return message.role === 'tool' && !!message.tool_call_id + } + + /** Narrow a ChatModel.Message to a specific type. */ + export function narrow(message: Msg.System): Msg.System + export function narrow(message: Msg.User): Msg.User + export function narrow(message: Msg.Assistant): Msg.Assistant + export function narrow(message: Msg.FuncCall): Msg.FuncCall + export function narrow(message: Msg.FuncResult): Msg.FuncResult + export function narrow(message: Msg.ToolCall): Msg.ToolCall + export function narrow(message: Msg.ToolResult): Msg.ToolResult + export function narrow( + message: Msg + ): + | Msg.System + | Msg.User + | Msg.Assistant + | Msg.FuncCall + | Msg.FuncResult + | Msg.ToolCall + | Msg.ToolResult { + if (isSystem(message)) { + return message + } + if (isUser(message)) { + return message + } + if (isAssistant(message)) { + return message + } + if (isFuncCall(message)) { + return message + } + if (isFuncResult(message)) { + return message + } + if (isToolCall(message)) { + return message + } + if (isToolResult(message)) { + return message + } + throw new Error('Invalid message type') + } +} diff --git a/src/sdks/langchain.ts b/src/sdks/langchain.ts index b273c91..f1b0cf4 100644 --- a/src/sdks/langchain.ts +++ b/src/sdks/langchain.ts @@ -2,7 +2,7 @@ import { DynamicStructuredTool } from '@langchain/core/tools' import type { AIFunctionLike } from '../types.js' import { AIFunctionSet } from '../ai-function-set.js' -import { stringifyForModel } from '../stringify-for-model.js' +import { stringifyForModel } from '../utils.js' /** * Converts a set of Agentic stdlib AI functions to an array of LangChain- diff --git a/src/services/dexa-client.ts b/src/services/dexa-client.ts index d43afb1..b6a0dac 100644 --- a/src/services/dexa-client.ts +++ b/src/services/dexa-client.ts @@ -1,9 +1,18 @@ import defaultKy, { type KyInstance } from 'ky' +import { z } from 'zod' -import type * as types from '../types.js' +import { aiFunction, AIFunctionsProvider } from '../fns.js' +import { Msg } from '../message.js' import { assert, getEnv } from '../utils.js' -export class DexaClient { +export namespace dexa { + export const AskDexaOptionsSchema = z.object({ + question: z.string().describe('The question to ask Dexa.') + }) + export type AskDexaOptions = z.infer +} + +export class DexaClient extends AIFunctionsProvider { readonly apiKey: string readonly apiBaseUrl: string readonly ky: KyInstance @@ -23,6 +32,7 @@ export class DexaClient { apiKey, 'DexaClient missing required "apiKey" (defaults to "DEXA_API_KEY")' ) + super() this.apiKey = apiKey this.apiBaseUrl = apiBaseUrl @@ -30,12 +40,18 @@ export class DexaClient { this.ky = ky.extend({ prefixUrl: this.apiBaseUrl, timeout: timeoutMs }) } - async askDexa({ messages }: { messages: types.Msg[] }) { + @aiFunction({ + name: 'ask_dexa', + description: + 'Answers questions based on knowledge of trusted experts and podcasters. Example experts include: Andrew Huberman, Tim Ferriss, Lex Fridman, Peter Attia, Seth Godin, Rhonda Patrick, Rick Rubin, and more.', + inputSchema: dexa.AskDexaOptionsSchema + }) + async askDexa(opts: dexa.AskDexaOptions) { return this.ky .post('api/ask-dexa', { json: { secret: this.apiKey, - messages + messages: [Msg.user(opts.question)] } }) .json() diff --git a/src/services/diffbot-client.ts b/src/services/diffbot-client.ts index bb729b0..549bd47 100644 --- a/src/services/diffbot-client.ts +++ b/src/services/diffbot-client.ts @@ -414,6 +414,7 @@ export namespace diffbot { locations?: Location[] location?: Location interests?: Interest[] + emailAddresses?: any age?: number crawlTimestamp?: number } diff --git a/src/stringify-for-model.test.ts b/src/stringify-for-model.test.ts deleted file mode 100644 index be6a689..0000000 --- a/src/stringify-for-model.test.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { describe, expect, it } from 'vitest' - -import { stringifyForModel } from './stringify-for-model.js' - -describe('stringifyForModel', () => { - it('handles basic objects', () => { - const input = { - foo: 'bar', - nala: ['is', 'cute'], - kittens: null, - cats: undefined, - paws: 4.3 - } - const result = stringifyForModel(input) - expect(result).toEqual(JSON.stringify(input, null)) - }) - - it('handles empty input', () => { - const result = stringifyForModel() - expect(result).toEqual('') - }) -}) diff --git a/src/stringify-for-model.ts b/src/stringify-for-model.ts deleted file mode 100644 index 8a3602d..0000000 --- a/src/stringify-for-model.ts +++ /dev/null @@ -1,16 +0,0 @@ -import type { Jsonifiable } from 'type-fest' - -/** - * Stringifies a JSON value in a way that's optimized for use with LLM prompts. - */ -export function stringifyForModel(jsonObject?: Jsonifiable): string { - if (jsonObject === undefined) { - return '' - } - - if (typeof jsonObject === 'string') { - return jsonObject - } - - return JSON.stringify(jsonObject, null, 0) -} diff --git a/src/types.ts b/src/types.ts index 4a27a1c..44f2e20 100644 --- a/src/types.ts +++ b/src/types.ts @@ -3,7 +3,9 @@ import type { z } from 'zod' import type { AIFunctionSet } from './ai-function-set.js' import type { AIFunctionsProvider } from './fns.js' +import type { Msg } from './message.js' +export type { Msg } from './message.js' export type { KyInstance } from 'ky' export type { ThrottledFunction } from 'p-throttle' @@ -70,128 +72,3 @@ export interface AITool< /** The tool spec for the OpenAI API `tools` property. */ spec: AIToolSpec } - -/** - * Generic/default OpenAI message without any narrowing applied. - */ -export interface Msg { - /** - * The contents of the message. `content` is required for all messages, and - * may be null for assistant messages with function calls. - */ - content: string | null - - /** - * The role of the messages author. One of `system`, `user`, `assistant`, - * 'tool', or `function`. - */ - role: Msg.Role - - /** - * The name and arguments of a function that should be called, as generated - * by the model. - */ - function_call?: Msg.Call.Function - - /** The tool calls generated by the model, such as function calls. */ - tool_calls?: Msg.Call.Tool[] - - /** - * Tool call that this message is responding to. - */ - tool_call_id?: string - - /** - * The name of the author of this message. `name` is required if role is - * `function`, and it should be the name of the function whose response is in the - * `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of - * 64 characters. - */ - name?: string -} - -/** Narrowed Message types. */ -export namespace Msg { - /** Possible roles for a message. */ - export type Role = 'system' | 'user' | 'assistant' | 'function' | 'tool' - - export namespace Call { - /** - * The name and arguments of a function that should be called, as generated - * by the model. - */ - export type Function = { - /** - * The arguments to call the function with, as generated by the model in - * JSON format. - */ - arguments: string - - /** The name of the function to call. */ - name: string - } - - /** The tool calls generated by the model, such as function calls. */ - export type Tool = { - /** The ID of the tool call. */ - id: string - - /** The type of the tool. Currently, only `function` is supported. */ - type: 'function' - - /** The function that the model called. */ - function: Call.Function - } - } - - /** Message with text content for the system. */ - export type System = { - role: 'system' - content: string - name?: string - } - - /** Message with text content from the user. */ - export type User = { - role: 'user' - name?: string - content: string - } - - /** Message with text content from the assistant. */ - export type Assistant = { - role: 'assistant' - name?: string - content: string - } - - /** Message with arguments to call a function. */ - export type FuncCall = { - role: 'assistant' - name?: string - content: null - function_call: Call.Function - } - - /** Message with the result of a function call. */ - export type FuncResult = { - role: 'function' - name: string - content: string - } - - /** Message with arguments to call one or more tools. */ - export type ToolCall = { - role: 'assistant' - name?: string - content: null - tool_calls: Call.Tool[] - } - - /** Message with the result of a tool call. */ - export type ToolResult = { - role: 'tool' - tool_call_id: string - content: string - } -} diff --git a/src/utils.test.ts b/src/utils.test.ts index 16314ce..b234243 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -1,9 +1,15 @@ import ky from 'ky' import pThrottle from 'p-throttle' -import { expect, test } from 'vitest' +import { describe, expect, test } from 'vitest' import { mockKyInstance } from './_utils.js' -import { omit, pick, sanitizeSearchParams, throttleKy } from './utils.js' +import { + omit, + pick, + sanitizeSearchParams, + stringifyForModel, + throttleKy +} from './utils.js' test('pick', () => { expect(pick({ a: 1, b: 2, c: 3 }, 'a', 'c')).toEqual({ a: 1, c: 3 }) @@ -79,3 +85,22 @@ test( timeout: 60_000 } ) + +describe('stringifyForModel', () => { + test('handles basic objects', () => { + const input = { + foo: 'bar', + nala: ['is', 'cute'], + kittens: null, + cats: undefined, + paws: 4.3 + } + const result = stringifyForModel(input) + expect(result).toEqual(JSON.stringify(input, null)) + }) + + test('handles empty input', () => { + const result = stringifyForModel() + expect(result).toEqual('') + }) +}) diff --git a/src/utils.ts b/src/utils.ts index 7aae335..10ae0b8 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,3 +1,6 @@ +import type { Jsonifiable } from 'type-fest' +import dedent from 'dedent' + import type * as types from './types.js' export { assert } from './assert.js' @@ -111,3 +114,29 @@ export function sanitizeSearchParams( }) ) } + +/** + * Stringifies a JSON value in a way that's optimized for use with LLM prompts. + */ +export function stringifyForModel(jsonObject?: Jsonifiable): string { + if (jsonObject === undefined) { + return '' + } + + if (typeof jsonObject === 'string') { + return jsonObject + } + + return JSON.stringify(jsonObject, null, 0) +} + +const dedenter = dedent.withOptions({ escapeSpecialCharacters: true }) + +/** + * Clean a string by removing extra newlines and indentation. + * + * @see: https://github.com/dmnd/dedent + */ +export function cleanStringForModel(text: string): string { + return dedenter(text).trim() +}