// Mirror of https://github.com/transitive-bullshit/chatgpt-api (233 lines, 6.7 KiB, TypeScript)
import { jsonrepair } from 'jsonrepair'
|
|
import Mustache from 'mustache'
|
|
import { dedent } from 'ts-dedent'
|
|
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
|
import { printNode, zodToTs } from 'zod-to-ts'
|
|
|
|
import * as types from './types'
|
|
import { BaseTaskCallBuilder } from './task'
|
|
import {
|
|
extractJSONArrayFromString,
|
|
extractJSONObjectFromString
|
|
} from './utils'
|
|
|
|
/**
 * Shared base for task builders that are backed by a language model.
 *
 * On top of what `BaseTaskCallBuilder` provides, this class stores the
 * provider name, the model name, the model parameters and a list of examples,
 * and exposes fluent setters for the schemas, examples and model parameters.
 *
 * @typeParam TInput - zod schema (or raw zod shape) describing the input.
 * @typeParam TOutput - zod schema (or raw zod shape) describing the output.
 * @typeParam TModelParams - record of provider-specific model parameters.
 */
export abstract class BaseLLMCallBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
  TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskCallBuilder<TInput, TOutput> {
  protected _provider: string
  protected _model: string
  protected _modelParams: TModelParams
  protected _examples: types.LLMExample[]

  /**
   * @param options - forwarded to the parent constructor; `provider`, `model`,
   *   `modelParams` and `examples` are additionally stored on this instance.
   */
  constructor(options: types.BaseLLMOptions<TInput, TOutput, TModelParams>) {
    super(options)

    const { provider, model, modelParams, examples } = options
    this._provider = provider
    this._model = model
    this._modelParams = modelParams
    this._examples = examples
  }

  /**
   * Replaces the input schema.
   *
   * @param inputSchema - the new input schema or raw shape.
   * @returns this same instance, re-typed so that `TInput` becomes `U`.
   */
  input<U extends ZodRawShape | ZodTypeAny = TInput>(
    inputSchema: U
  ): BaseLLMCallBuilder<U, TOutput, TModelParams> {
    // The instance itself is unchanged; only its static type is swapped.
    const retyped = this as unknown as BaseLLMCallBuilder<
      U,
      TOutput,
      TModelParams
    >
    retyped._inputSchema = inputSchema
    return retyped
  }

  /**
   * Replaces the output schema.
   *
   * @param outputSchema - the new output schema or raw shape.
   * @returns this same instance, re-typed so that `TOutput` becomes `U`.
   */
  output<U extends ZodRawShape | ZodTypeAny = TOutput>(
    outputSchema: U
  ): BaseLLMCallBuilder<TInput, U, TModelParams> {
    // The instance itself is unchanged; only its static type is swapped.
    const retyped = this as unknown as BaseLLMCallBuilder<
      TInput,
      U,
      TModelParams
    >
    retyped._outputSchema = outputSchema
    return retyped
  }

  /**
   * Replaces the stored list of examples.
   *
   * @param examples - the examples to store.
   * @returns this builder, for chaining.
   */
  examples(examples: types.LLMExample[]) {
    this._examples = examples
    return this
  }

  /**
   * Shallow-merges `params` over the current model parameters.
   *
   * @param params - the subset of model parameters to override.
   * @returns this builder, for chaining.
   */
  modelParams(params: Partial<TModelParams>) {
    // The merge is shallow on the assumption that model params hold no nested
    // objects; nested values would call for a deep merge instead.
    const merged = Object.assign({}, this._modelParams, params)
    this._modelParams = merged
    return this
  }

  // TODO
  // abstract stream({
  //   input: TInput,
  //   onProgress: types.ProgressFunction
  // }): Promise<TOutput>
}
|
|
|
|
/**
 * Base class for builders that run a task through a chat-completion model.
 *
 * Subclasses implement `_createChatCompletion`. `call` renders the stored
 * message templates with the input, appends example and output-format system
 * messages, requests a completion, and coerces the reply to the output schema.
 *
 * @typeParam TInput - zod schema (or raw zod shape) describing the input.
 * @typeParam TOutput - zod schema (or raw zod shape) describing the output.
 * @typeParam TModelParams - record of provider-specific model parameters.
 * @typeParam TChatCompletionResponse - provider-specific completion payload.
 */
export abstract class BaseChatModelBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
  TModelParams extends Record<string, any> = Record<string, any>,
  TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
  _messages: types.ChatMessage[]

  /**
   * @param options - forwarded to the parent constructor; `messages` is
   *   additionally stored as the list of message templates.
   */
  constructor(options: types.ChatModelOptions<TInput, TOutput, TModelParams>) {
    super(options)

    this._messages = options.messages
  }

  /**
   * Sends the fully prepared messages to the underlying chat model.
   *
   * @param messages - the messages to send.
   * @returns the completion; `call` reads `message.content` from it.
   */
  protected abstract _createChatCompletion(
    messages: types.ChatMessage[]
  ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>

  /**
   * Runs the task once.
   *
   * @param input - values for the message templates; validated against the
   *   input schema when one is set.
   * @returns the model output parsed by the output schema, or the raw
   *   completion content when no output schema is set.
   * @throws whatever the zod `parse` calls, `jsonrepair` or `JSON.parse`
   *   throw, and an `Error` for unrecognized boolean or number output.
   */
  override async call(
    input?: types.ParsedData<TInput>
  ): Promise<types.ParsedData<TOutput>> {
    if (this._inputSchema) {
      const inputSchema =
        this._inputSchema instanceof z.ZodType
          ? this._inputSchema
          : z.object(this._inputSchema)

      // TODO: handle errors gracefully
      input = inputSchema.parse(input)
    }

    // TODO: validate input message variables against input schema

    // Render every template with the input, dropping messages that end up empty.
    // `map`/`filter` build a new array, so the pushes below never touch
    // `this._messages`.
    // NOTE(review): Mustache HTML-escapes `{{var}}` values by default, which may
    // be unwanted in plain-text prompts — confirm.
    const messages = this._messages
      .map((message) => {
        return {
          ...message,
          content: message.content
            ? Mustache.render(dedent(message.content), input).trim()
            : ''
        }
      })
      .filter((message) => message.content)

    if (this._examples?.length) {
      // TODO: smarter example selection
      for (const example of this._examples) {
        messages.push({
          role: 'system',
          content: `Example input: ${example.input}\n\nExample output: ${example.output}`
        })
      }
    }

    // Normalize the output schema once (a raw shape becomes a `z.object`) and
    // reuse it for both the format instructions and the response parsing.
    const outputSchema: ZodTypeAny | undefined = this._outputSchema
      ? this._outputSchema instanceof z.ZodType
        ? this._outputSchema
        : z.object(this._outputSchema)
      : undefined

    if (outputSchema) {
      const { node } = zodToTs(outputSchema)

      const tsTypeString = printNode(node, {
        removeComments: false,
        // TODO: this doesn't seem to actually work, so we're doing it manually below
        omitTrailingSemicolon: true,
        noEmitHelpers: true
      })
        // NOTE(review): as written this swaps one leading space for one space
        // (a no-op); presumably it once narrowed wider indentation — confirm.
        .replace(/^ /gm, ' ')
        .replace(/;$/gm, '')

      // Detect a plain string output from the schema class or the printed type.
      // The previous check compared `node.kind` to a hard-coded numeric
      // TypeScript SyntaxKind, which is not guaranteed stable across releases.
      const isRawString =
        outputSchema instanceof z.ZodString || tsTypeString.trim() === 'string'

      if (isRawString) {
        // handle raw strings differently
        messages.push({
          role: 'system',
          content: dedent`Output a raw string only, without any additional text.`
        })
      } else {
        // `dedent` strips the common indentation of the template lines below.
        messages.push({
          role: 'system',
          content: dedent`Do not output code. Output JSON only in the following TypeScript format:
            \`\`\`ts
            ${tsTypeString}
            \`\`\``
        })
      }
    }

    // TODO: filter/compress messages based on token counts

    // NOTE(review): debug logging kept as-is; consider a configurable logger.
    console.log('>>>')
    console.log(messages)

    const completion = await this._createChatCompletion(messages)
    let output: any = completion.message.content

    console.log('===')
    console.log(output)
    console.log('<<<')

    if (!outputSchema) {
      return output
    }

    if (outputSchema instanceof z.ZodArray) {
      // TODO: handle extraction / JSON errors; they currently propagate as-is
      const trimmedOutput = extractJSONArrayFromString(output)
      output = JSON.parse(jsonrepair(trimmedOutput ?? output))
    } else if (outputSchema instanceof z.ZodObject) {
      // TODO: handle extraction / JSON errors; they currently propagate as-is
      const trimmedOutput = extractJSONObjectFromString(output)
      output = JSON.parse(jsonrepair(trimmedOutput ?? output))
    } else if (outputSchema instanceof z.ZodBoolean) {
      output = output.toLowerCase().trim()

      // A Map (rather than an object literal) keeps inherited keys such as
      // "constructor" or "toString" from being mistaken for valid answers.
      const booleanOutputs = new Map<string, boolean>([
        ['true', true],
        ['false', false],
        ['yes', true],
        ['no', false],
        ['1', true],
        ['0', false]
      ])

      const booleanOutput = booleanOutputs.get(output)
      if (booleanOutput === undefined) {
        // TODO
        throw new Error(`invalid boolean output: ${output}`)
      }

      output = booleanOutput
    } else if (outputSchema instanceof z.ZodNumber) {
      output = output.trim()

      const numberOutput = outputSchema.isInt
        ? parseInt(output)
        : parseFloat(output)

      if (Number.isNaN(numberOutput)) {
        // TODO
        throw new Error(`invalid number output: ${output}`)
      }

      output = numberOutput
    }

    // TODO: handle errors, retry logic, and self-healing

    return outputSchema.parse(output)
  }
}
|