feat: tools and BaseTask.agentic refactor

Travis Fischer 2023-06-14 17:40:48 -07:00
rodzic 94cb9041b1
commit f95dd54436
15 zmienionych plików z 159 dodań i 59 usunięć

13
examples/functions.ts vendored
Wyświetl plik

@ -2,16 +2,21 @@ import { OpenAIClient } from '@agentic/openai-fetch'
import 'dotenv/config'
import { z } from 'zod'
import { Agentic, CalculatorTool } from '@/index'
import { Agentic, CalculatorTool, WeatherTool } from '@/index'
async function main() {
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
const agentic = new Agentic({ openai })
const example = await agentic
.gpt3('What is 5 * 50?')
.tools([new CalculatorTool({ agentic })])
.output(z.object({ answer: z.number() }))
.gpt3('What is the temperature in san francisco today?')
.tools([new CalculatorTool(), new WeatherTool()])
.output(
z.object({
answer: z.number(),
units: z.union([z.literal('fahrenheit'), z.literal('celcius')])
})
)
.call()
console.log(example)
}

Wyświetl plik

@ -101,7 +101,7 @@
]
},
"ava": {
"snapshotDir": ".snapshots",
"snapshotDir": "test/.snapshots",
"extensions": {
"ts": "module"
},

Wyświetl plik

@ -1,11 +1,10 @@
import * as types from '@/types'
import { DEFAULT_OPENAI_MODEL } from '@/constants'
import { OpenAIChatCompletion } from '@/llms/openai'
import * as types from './types'
import { DEFAULT_OPENAI_MODEL } from './constants'
import {
HumanFeedbackMechanism,
HumanFeedbackMechanismCLI
} from './human-feedback'
import { OpenAIChatCompletion } from './llms/openai'
import { defaultIDGeneratorFn } from './utils'
export class Agentic {
@ -20,7 +19,9 @@ export class Agentic {
'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
>
protected _defaultHumanFeedbackMechamism?: HumanFeedbackMechanism
protected _idGeneratorFn: types.IDGeneratorFunction
protected _id: string
constructor(opts: {
openai?: types.openai.OpenAIClient
@ -33,6 +34,10 @@ export class Agentic {
defaultHumanFeedbackMechanism?: HumanFeedbackMechanism
idGeneratorFn?: types.IDGeneratorFunction
}) {
if (!globalThis.__agentic?.deref()) {
globalThis.__agentic = new WeakRef(this)
}
this._openai = opts.openai
this._anthropic = opts.anthropic
@ -59,6 +64,7 @@ export class Agentic {
new HumanFeedbackMechanismCLI({ agentic: this })
this._idGeneratorFn = opts.idGeneratorFn ?? defaultIDGeneratorFn
this._id = this._idGeneratorFn()
}
public get openai(): types.openai.OpenAIClient | undefined {

Wyświetl plik

@ -1,6 +1,6 @@
import * as types from '@/types'
import { Agentic } from '@/agentic'
import { BaseTask } from '@/task'
import * as types from './types'
import { Agentic } from './agentic'
import { BaseTask } from './task'
export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN'

Wyświetl plik

@ -38,7 +38,7 @@ export class AnthropicChatCompletion<
...options
})
if (this._agentic.anthropic) {
if (this._agentic?.anthropic) {
this._client = this._agentic.anthropic
} else {
throw new Error(

Wyświetl plik

@ -72,13 +72,38 @@ export abstract class BaseChatCompletion<
}
this._tools = tools
for (const tool of tools) {
tool.agentic = this.agentic
}
return this
}
/**
* Whether or not this chat completion model directly supports the use of tools.
*/
public get supportsTools(): boolean {
return false
}
public override validate() {
super.validate()
if (this._tools) {
for (const tool of this._tools) {
if (!tool.agentic) {
tool.agentic = this.agentic
} else if (tool.agentic !== this.agentic) {
throw new Error(
`Task "${this.nameForHuman}" has a different Agentic runtime instance than the tool "${tool.nameForHuman}"`
)
}
tool.validate()
}
}
}
protected abstract _createChatCompletion(
messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[]
@ -133,9 +158,22 @@ export abstract class BaseChatCompletion<
.replace(/^ {4}/gm, ' ')
.replace(/;$/gm, '')
const label =
this._outputSchema instanceof z.ZodArray
? 'JSON array'
: this._outputSchema instanceof z.ZodObject
? 'JSON object'
: this._outputSchema instanceof z.ZodNumber
? 'number'
: this._outputSchema instanceof z.ZodString
? 'string'
: this._outputSchema instanceof z.ZodBoolean
? 'boolean'
: 'JSON value'
messages.push({
role: 'system',
content: dedent`Do not output code. Output JSON only in the following TypeScript format:
content: dedent`Do not output code. Output a single ${label} in the following TypeScript format:
\`\`\`ts
${tsTypeString}
\`\`\``
@ -285,6 +323,11 @@ export abstract class BaseChatCompletion<
try {
const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
if (Array.isArray(output)) {
// TODO
output = output[0]
}
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })

Wyświetl plik

@ -38,7 +38,7 @@ export class OpenAIChatCompletion<
...options
})
if (this._agentic.openai) {
if (this._agentic?.openai) {
this._client = this._agentic.openai
} else {
throw new Error(
@ -93,6 +93,30 @@ export class OpenAIChatCompletion<
return openaiModelsSupportingFunctions.has(this._model)
}
public override validate() {
super.validate()
if (!this._client) {
throw new Error(
'OpenAIChatCompletion requires an OpenAI client to be configured on the Agentic runtime'
)
}
if (!this.supportsTools) {
if (this._tools) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support tools`
)
}
if (this._modelParams?.functions) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support functions`
)
}
}
}
protected override async _createChatCompletion(
messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[]

Wyświetl plik

@ -1,9 +1,10 @@
import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodType } from 'zod'
import * as errors from '@/errors'
import * as types from '@/types'
import { Agentic } from '@/agentic'
import * as errors from './errors'
import * as types from './types'
import type { Agentic } from './agentic'
import { defaultIDGeneratorFn, isValidTaskIdentifier } from './utils'
/**
* A `Task` is an async function call that may be non-deterministic. It has
@ -28,24 +29,27 @@ export abstract class BaseTask<
protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig
constructor(options: types.BaseTaskOptions) {
if (!options.agentic) {
throw new Error('Passing "agentic" is required when creating a Task')
}
constructor(options: types.BaseTaskOptions = {}) {
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
this._agentic = options.agentic
this._timeoutMs = options.timeoutMs
this._retryConfig = options.retryConfig ?? {
retries: 3,
strategy: 'default'
}
this._id = options.id ?? this._agentic.idGeneratorFn()
this._id =
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
}
public get agentic(): Agentic {
return this._agentic
}
public set agentic(agentic: Agentic) {
this._agentic = agentic
}
public get id(): string {
return this._id
}
@ -53,7 +57,10 @@ export abstract class BaseTask<
public abstract get inputSchema(): ZodType<TInput>
public abstract get outputSchema(): ZodType<TOutput>
public abstract get nameForModel(): string
public get nameForModel(): string {
const name = this.constructor.name
return name[0].toLowerCase() + name.slice(1)
}
public get nameForHuman(): string {
return this.constructor.name
@ -63,6 +70,19 @@ export abstract class BaseTask<
return ''
}
public validate() {
if (!this._agentic) {
throw new Error(
`Task "${this.nameForHuman}" is missing a required "agentic" instance`
)
}
const nameForModel = this.nameForModel
if (!isValidTaskIdentifier(nameForModel)) {
throw new Error(`Task field nameForModel "${nameForModel}" is invalid`)
}
}
// TODO: is this really necessary?
public clone(): BaseTask<TInput, TOutput> {
// TODO: override in subclass if needed
@ -88,6 +108,8 @@ export abstract class BaseTask<
public async callWithMetadata(
input?: TInput
): Promise<types.TaskResponse<TOutput>> {
this.validate()
if (this.inputSchema) {
const safeInput = this.inputSchema.safeParse(input)
@ -104,7 +126,7 @@ export abstract class BaseTask<
metadata: {
taskName: this.nameForModel,
taskId: this.id,
callId: this._agentic.idGeneratorFn()
callId: this._agentic!.idGeneratorFn()
}
}

Wyświetl plik

@ -19,7 +19,7 @@ export class CalculatorTool extends BaseTask<
CalculatorInput,
CalculatorOutput
> {
constructor(opts: types.BaseTaskOptions) {
constructor(opts: types.BaseTaskOptions = {}) {
super(opts)
}

Wyświetl plik

@ -1,26 +1,20 @@
import * as metaphor from '@/services/metaphor'
import * as types from '@/types'
import { Agentic } from '@/agentic'
import { BaseTask } from '@/task'
export class MetaphorSearchTool extends BaseTask<
metaphor.MetaphorSearchInput,
metaphor.MetaphorSearchOutput
> {
_metaphorClient: metaphor.MetaphorClient
protected _metaphorClient: metaphor.MetaphorClient
constructor({
agentic,
metaphorClient = new metaphor.MetaphorClient(),
...rest
...opts
}: {
agentic: Agentic
metaphorClient?: metaphor.MetaphorClient
} & types.BaseTaskOptions) {
super({
agentic,
...rest
})
} & types.BaseTaskOptions = {}) {
super(opts)
this._metaphorClient = metaphorClient
}
@ -34,7 +28,7 @@ export class MetaphorSearchTool extends BaseTask<
}
public override get nameForModel(): string {
return 'metaphor_web_search'
return 'metaphorWebSearch'
}
protected override async _call(

Wyświetl plik

@ -1,7 +1,6 @@
import { z } from 'zod'
import * as types from '@/types'
import { Agentic } from '@/agentic'
import { NovuClient } from '@/services/novu'
import { BaseTask } from '@/task'
@ -39,18 +38,15 @@ export class NovuNotificationTool extends BaseTask<
NovuNotificationToolInput,
NovuNotificationToolOutput
> {
_novuClient: NovuClient
protected _novuClient: NovuClient
constructor({
agentic,
novuClient = new NovuClient()
novuClient = new NovuClient(),
...opts
}: {
agentic: Agentic
novuClient?: NovuClient
}) {
super({
agentic
})
} & types.BaseTaskOptions = {}) {
super(opts)
this._novuClient = novuClient
}
@ -64,7 +60,7 @@ export class NovuNotificationTool extends BaseTask<
}
public override get nameForModel(): string {
return 'novu_send_notification'
return 'novuSendNotification'
}
protected override async _call(

Wyświetl plik

@ -9,7 +9,12 @@ export const WeatherInputSchema = z.object({
.string()
.describe(
'Location to get the weather for. Can be a city name like "Paris", a zipcode like "53121", an international postal code like "SW1", or a latitude and longitude like "48.8567,2.3508"'
)
),
units: z
.union([z.literal('imperial'), z.literal('metric')])
.default('imperial')
.optional()
})
export type WeatherInput = z.infer<typeof WeatherInputSchema>
@ -33,8 +38,8 @@ const ConditionSchema = z.object({
const CurrentSchema = z.object({
last_updated_epoch: z.number(),
last_updated: z.string(),
temp_c: z.number(),
temp_f: z.number(),
temp_c: z.number().describe('temperature in celsius'),
temp_f: z.number().describe('temperature in fahrenheit'),
is_day: z.number(),
condition: ConditionSchema,
wind_mph: z.number(),
@ -65,14 +70,19 @@ export type WeatherOutput = z.infer<typeof WeatherOutputSchema>
export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> {
client: WeatherClient
constructor(
opts: {
weather: WeatherClient
} & types.BaseTaskOptions
) {
constructor({
weather = new WeatherClient({ apiKey: process.env.WEATHER_API_KEY }),
...opts
}: {
weather?: WeatherClient
} & types.BaseTaskOptions = {}) {
super(opts)
this.client = opts.weather
if (!weather) {
throw new Error(`Error WeatherTool missing required "weather" client`)
}
this.client = weather
}
public override get inputSchema() {
@ -92,7 +102,7 @@ export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> {
}
public override get descForModel(): string {
return 'Useful for getting the current weather at a location'
return 'Useful for getting the current weather at a location.'
}
protected override async _call(

Wyświetl plik

@ -20,7 +20,7 @@ export type SafeParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
: never
export interface BaseTaskOptions {
agentic: Agentic
agentic?: Agentic
timeoutMs?: number
retryConfig?: RetryConfig