feat: improve openai function/task/tool support

Travis Fischer 2023-06-13 21:39:19 -07:00
rodzic c9d6619855
commit 0ace3bb9b8
20 zmienionych plików z 184 dodań i 46 usunięć

Wyświetl plik

@ -0,0 +1,26 @@
# Snapshot report for `test/llms/llm-utils.test.ts`
The actual snapshot is saved in `llm-utils.test.ts.snap`.
Generated by [AVA](https://avajs.dev).
## getChatMessageFunctionDefinitionFromTask
> Snapshot 1
{
description: 'Useful for getting the result of a math expression. The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.',
name: 'calculator',
parameters: {
properties: {
expression: {
description: 'mathematical expression to evaluate',
type: 'string',
},
},
required: [
'expression',
],
type: 'object',
},
}

Plik binarny nie jest wyświetlany.

Wyświetl plik

@ -1,5 +1,4 @@
import { ZodTypeAny } from 'zod'
import * as types from '@/types'
import { Agentic } from '@/agentic'
import { BaseTask } from '@/task'
@ -37,8 +36,8 @@ export class HumanFeedbackMechanismCLI extends HumanFeedbackMechanism {
}
export function withHumanFeedback<
TInput extends ZodTypeAny = ZodTypeAny,
TOutput extends ZodTypeAny = ZodTypeAny
TInput extends void | types.JsonObject = void,
TOutput extends types.JsonValue = string
>(
task: BaseTask<TInput, TOutput>,
options: HumanFeedbackOptions = {

Wyświetl plik

@ -9,8 +9,8 @@ import { BaseChatModel } from './chat'
const defaultStopSequences = [anthropic.HUMAN_PROMPT]
export class AnthropicChatModel<
TInput = any,
TOutput = string
TInput extends void | types.JsonObject = any,
TOutput extends types.JsonValue = string
> extends BaseChatModel<
TInput,
TOutput,

Wyświetl plik

@ -3,7 +3,6 @@ import pMap from 'p-map'
import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest'
import { ZodType, z } from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors'
@ -19,8 +18,8 @@ import { BaseTask } from '../task'
import { BaseLLM } from './llm'
export abstract class BaseChatModel<
TInput = void,
TOutput = string,
TInput extends void | types.JsonObject = void,
TOutput extends types.JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> {
@ -40,7 +39,9 @@ export abstract class BaseChatModel<
}
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
input<U>(inputSchema: ZodType<U>): BaseChatModel<U, TOutput, TModelParams> {
input<U extends void | types.JsonObject>(
inputSchema: ZodType<U>
): BaseChatModel<U, TOutput, TModelParams> {
const refinedInstance = this as unknown as BaseChatModel<
U,
TOutput,
@ -51,7 +52,9 @@ export abstract class BaseChatModel<
}
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
output<U>(outputSchema: ZodType<U>): BaseChatModel<TInput, U, TModelParams> {
output<U extends types.JsonValue>(
outputSchema: ZodType<U>
): BaseChatModel<TInput, U, TModelParams> {
const refinedInstance = this as unknown as BaseChatModel<
TInput,
U,
@ -72,14 +75,14 @@ export abstract class BaseChatModel<
return this
}
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public get supportsTools(): boolean {
return false
}
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages(
input?: TInput,
ctx?: types.TaskCallContext<TInput>
@ -239,7 +242,8 @@ export abstract class BaseChatModel<
}
}
const safeResult = outputSchema.safeParse(output)
// TODO: this doesn't bode well, batman...
const safeResult = (outputSchema.safeParse as any)(output)
if (!safeResult.success) {
throw new errors.ZodOutputValidationError(safeResult.error)

Wyświetl plik

@ -1,4 +1,5 @@
export * from './llm'
export * from './llm-utils'
export * from './chat'
export * from './openai'
export * from './anthropic'

Wyświetl plik

@ -0,0 +1,38 @@
import { zodToJsonSchema } from 'zod-to-json-schema'
import * as types from '@/types'
import { BaseTask } from '@/task'
import { isValidTaskIdentifier } from '@/utils'
export function getChatMessageFunctionDefinitionFromTask(
task: BaseTask<any, any>
): types.openai.ChatMessageFunction {
const name = task.nameForModel
if (!isValidTaskIdentifier(name)) {
throw new Error(`Invalid task name "${name}"`)
}
const jsonSchema = zodToJsonSchema(task.inputSchema, {
name,
$refStrategy: 'none'
})
const parameters: any = jsonSchema.definitions?.[name]
if (parameters) {
if (parameters.additionalProperties === false) {
delete parameters['additionalProperties']
}
}
return {
name,
description: task.descForModel || task.nameForHuman,
parameters
}
}
export function getChatMessageFunctionDefinitionsFromTasks(
tasks: BaseTask<any, any>[]
): types.openai.ChatMessageFunction[] {
return tasks.map(getChatMessageFunctionDefinitionFromTask)
}

Wyświetl plik

@ -7,8 +7,8 @@ import { Tokenizer, getTokenizerForModel } from '@/tokenizer'
// TODO: TInput should only be allowed to be void or an object
export abstract class BaseLLM<
TInput = void,
TOutput = string,
TInput extends void | types.JsonObject = void,
TOutput extends types.JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTask<TInput, TOutput> {
protected _inputSchema: ZodType<TInput> | undefined
@ -38,14 +38,18 @@ export abstract class BaseLLM<
}
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
input<U extends void | types.JsonObject>(
inputSchema: ZodType<U>
): BaseLLM<U, TOutput, TModelParams> {
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
refinedInstance._inputSchema = inputSchema
return refinedInstance
}
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
output<U extends types.JsonValue>(
outputSchema: ZodType<U>
): BaseLLM<TInput, U, TModelParams> {
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
refinedInstance._outputSchema = outputSchema
return refinedInstance

Wyświetl plik

@ -13,8 +13,8 @@ const openaiModelsSupportingFunctions = new Set([
])
export class OpenAIChatModel<
TInput = any,
TOutput = string
TInput extends void | types.JsonObject = any,
TOutput extends types.JsonValue = string
> extends BaseChatModel<
TInput,
TOutput,

Wyświetl plik

@ -18,7 +18,10 @@ import { Agentic } from '@/agentic'
* - Native function calls
* - Invoking sub-agents
*/
export abstract class BaseTask<TInput = void, TOutput = string> {
export abstract class BaseTask<
TInput extends void | types.JsonObject = void,
TOutput extends types.JsonValue = string
> {
protected _agentic: Agentic
protected _id: string
@ -26,6 +29,10 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
protected _retryConfig: types.RetryConfig
constructor(options: types.BaseTaskOptions) {
if (!options.agentic) {
throw new Error('Passing "agentic" is required when creating a Task')
}
this._agentic = options.agentic
this._timeoutMs = options.timeoutMs
this._retryConfig = options.retryConfig ?? {
@ -49,7 +56,7 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
public abstract get nameForModel(): string
public get nameForHuman(): string {
return this.nameForModel
return this.constructor.name
}
public get descForModel(): string {
@ -67,11 +74,17 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
return this
}
/**
* Calls this task with the given `input` and returns the result only.
*/
public async call(input?: TInput): Promise<TOutput> {
const res = await this.callWithMetadata(input)
return res.result
}
/**
* Calls this task with the given `input` and returns the result along with metadata.
*/
public async callWithMetadata(
input?: TInput
): Promise<types.TaskResponse<TOutput>> {
@ -126,6 +139,9 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
}
}
/**
* Subclasses must implement the core `_call` logic for this task.
*/
protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>
// TODO

Wyświetl plik

@ -4,7 +4,9 @@ import { z } from 'zod'
import * as types from '@/types'
import { BaseTask } from '@/task'
export const CalculatorInputSchema = z.string().describe('expression')
export const CalculatorInputSchema = z.object({
expression: z.string().describe('mathematical expression to evaluate')
})
export const CalculatorOutputSchema = z
.number()
.describe('result of calculating the expression')
@ -44,7 +46,7 @@ export class CalculatorTool extends BaseTask<
protected override async _call(
ctx: types.TaskCallContext<CalculatorInput>
): Promise<CalculatorOutput> {
const result = Parser.evaluate(ctx.input!)
const result = Parser.evaluate(ctx.input!.expression)
return result
}
}

Wyświetl plik

@ -7,7 +7,7 @@ import { BaseTask } from '@/task'
export const NovuNotificationToolInputSchema = z.object({
name: z.string(),
payload: z.record(z.unknown()),
payload: z.record(z.any()),
to: z.array(
z.object({
subscriberId: z.string(),

Wyświetl plik

@ -1,7 +1,7 @@
import * as openai from '@agentic/openai-fetch'
import * as anthropic from '@anthropic-ai/sdk'
import type { Options as RetryOptions } from 'p-retry'
import type { JsonObject } from 'type-fest'
import type { JsonObject, JsonValue } from 'type-fest'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic'
@ -9,6 +9,7 @@ import type { BaseTask } from './task'
export { openai }
export { anthropic }
export type { JsonObject, JsonValue }
export type ParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
? output<T>
@ -32,8 +33,8 @@ export interface BaseTaskOptions {
}
export interface BaseLLMOptions<
TInput = void,
TOutput = string,
TInput extends void | JsonObject = void,
TOutput extends JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskOptions {
inputSchema?: ZodType<TInput>
@ -46,8 +47,8 @@ export interface BaseLLMOptions<
}
export interface LLMOptions<
TInput = void,
TOutput = string,
TInput extends void | JsonObject = void,
TOutput extends JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
promptTemplate?: string
@ -59,8 +60,8 @@ export type ChatMessage = openai.ChatMessage
export type ChatMessageRole = openai.ChatMessageRole
export interface ChatModelOptions<
TInput = void,
TOutput = string,
TInput extends void | JsonObject = void,
TOutput extends JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[]
@ -105,7 +106,7 @@ export interface LLMTaskResponseMetadata<
}
export interface TaskResponse<
TOutput = string,
TOutput extends JsonValue = string,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
result: TOutput
@ -113,7 +114,7 @@ export interface TaskResponse<
}
export interface TaskCallContext<
TInput = void,
TInput extends void | JsonObject = void,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
input?: TInput

Wyświetl plik

@ -2,14 +2,22 @@ import { customAlphabet, urlAlphabet } from 'nanoid'
import * as types from './types'
export const extractJSONObjectFromString = (text: string): string | undefined =>
text.match(/\{(.|\n)*\}/gm)?.[0]
export function extractJSONObjectFromString(text: string): string | undefined {
return text.match(/\{(.|\n)*\}/gm)?.[0]
}
export const extractJSONArrayFromString = (text: string): string | undefined =>
text.match(/\[(.|\n)*\]/gm)?.[0]
export function extractJSONArrayFromString(text: string): string | undefined {
return text.match(/\[(.|\n)*\]/gm)?.[0]
}
export const sleep = (ms: number) =>
new Promise((resolve) => setTimeout(resolve, ms))
export function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms))
}
export const defaultIDGeneratorFn: types.IDGeneratorFunction =
customAlphabet(urlAlphabet)
const taskNameRegex = /^[a-zA-Z_][a-zA-Z0-9_-]{0,63}$/
export function isValidTaskIdentifier(id: string): boolean {
return !!id && taskNameRegex.test(id)
}

Wyświetl plik

@ -9,11 +9,11 @@ test('CalculatorTool', async (t) => {
const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({ agentic })
const res = await tool.call('1 + 1')
const res = await tool.call({ expression: '1 + 1' })
t.is(res, 2)
expectTypeOf(res).toMatchTypeOf<number>()
const res2 = await tool.callWithMetadata('cos(0)')
const res2 = await tool.callWithMetadata({ expression: 'cos(0)' })
t.is(res2.result, 1)
expectTypeOf(res2.result).toMatchTypeOf<number>()

Wyświetl plik

@ -3,7 +3,7 @@ import { expectTypeOf } from 'expect-type'
import { AnthropicChatModel } from '@/llms/anthropic'
import { createTestAgenticRuntime } from './_utils'
import { createTestAgenticRuntime } from '../_utils'
test('AnthropicChatModel ⇒ string output', async (t) => {
t.timeout(2 * 60 * 1000)

Wyświetl plik

@ -0,0 +1,19 @@
import test from 'ava'
import { getChatMessageFunctionDefinitionFromTask } from '@/llms/llm-utils'
import { CalculatorTool } from '@/tools/calculator'
import { createTestAgenticRuntime } from '../_utils'
test('getChatMessageFunctionDefinitionFromTask', async (t) => {
const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({ agentic })
const functionDefinition = getChatMessageFunctionDefinitionFromTask(tool)
t.is(functionDefinition.name, 'calculator')
t.is(functionDefinition.description, tool.descForModel)
console.log(JSON.stringify(functionDefinition, null, 2))
t.snapshot(functionDefinition)
})

Wyświetl plik

@ -6,7 +6,7 @@ import { z } from 'zod'
import { OutputValidationError, TemplateValidationError } from '@/errors'
import { BaseChatModel, OpenAIChatModel } from '@/llms'
import { createTestAgenticRuntime } from './_utils'
import { createTestAgenticRuntime } from '../_utils'
test('OpenAIChatModel ⇒ types', async (t) => {
const agentic = createTestAgenticRuntime()

20
test/utils.test.ts vendored 100644
Wyświetl plik

@ -0,0 +1,20 @@
import test from 'ava'
import { isValidTaskIdentifier } from '@/utils'
test('isValidTaskIdentifier - valid', async (t) => {
t.true(isValidTaskIdentifier('foo'))
t.true(isValidTaskIdentifier('foo_bar_179'))
t.true(isValidTaskIdentifier('fooBarBAZ'))
t.true(isValidTaskIdentifier('foo-bar-baz_'))
t.true(isValidTaskIdentifier('_'))
t.true(isValidTaskIdentifier('_foo___'))
})
test('isValidTaskIdentifier - invalid', async (t) => {
t.false(isValidTaskIdentifier(null as any))
t.false(isValidTaskIdentifier(''))
t.false(isValidTaskIdentifier('-'))
t.false(isValidTaskIdentifier('x'.repeat(65)))
t.false(isValidTaskIdentifier('-foo'))
})