kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: refactoring openai completions
rodzic
4174d7cdbf
commit
ab9de9a725
162
src/llm.ts
162
src/llm.ts
|
@ -1,7 +1,15 @@
|
||||||
|
import { jsonrepair } from 'jsonrepair'
|
||||||
|
import Mustache from 'mustache'
|
||||||
|
import { dedent } from 'ts-dedent'
|
||||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||||
|
import { printNode, zodToTs } from 'zod-to-ts'
|
||||||
|
|
||||||
import * as types from './types'
|
import * as types from './types'
|
||||||
import { BaseTaskCallBuilder } from './task'
|
import { BaseTaskCallBuilder } from './task'
|
||||||
|
import {
|
||||||
|
extractJSONArrayFromString,
|
||||||
|
extractJSONObjectFromString
|
||||||
|
} from './utils'
|
||||||
|
|
||||||
export abstract class BaseLLMCallBuilder<
|
export abstract class BaseLLMCallBuilder<
|
||||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||||
|
@ -59,10 +67,11 @@ export abstract class BaseLLMCallBuilder<
|
||||||
// }): Promise<TOutput>
|
// }): Promise<TOutput>
|
||||||
}
|
}
|
||||||
|
|
||||||
export abstract class ChatModelBuilder<
|
export abstract class BaseChatModelBuilder<
|
||||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>
|
TModelParams extends Record<string, any> = Record<string, any>,
|
||||||
|
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
|
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
|
||||||
_messages: types.ChatMessage[]
|
_messages: types.ChatMessage[]
|
||||||
|
|
||||||
|
@ -71,4 +80,153 @@ export abstract class ChatModelBuilder<
|
||||||
|
|
||||||
this._messages = options.messages
|
this._messages = options.messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected abstract _createChatCompletion(
|
||||||
|
messages: types.ChatMessage[]
|
||||||
|
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||||
|
|
||||||
|
override async call(
|
||||||
|
input?: types.ParsedData<TInput>
|
||||||
|
): Promise<types.ParsedData<TOutput>> {
|
||||||
|
if (this._inputSchema) {
|
||||||
|
const inputSchema =
|
||||||
|
this._inputSchema instanceof z.ZodType
|
||||||
|
? this._inputSchema
|
||||||
|
: z.object(this._inputSchema)
|
||||||
|
|
||||||
|
// TODO: handle errors gracefully
|
||||||
|
input = inputSchema.parse(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: validate input message variables against input schema
|
||||||
|
|
||||||
|
const messages = this._messages
|
||||||
|
.map((message) => {
|
||||||
|
return {
|
||||||
|
...message,
|
||||||
|
content: message.content
|
||||||
|
? Mustache.render(dedent(message.content), input).trim()
|
||||||
|
: ''
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter((message) => message.content)
|
||||||
|
|
||||||
|
if (this._examples?.length) {
|
||||||
|
// TODO: smarter example selection
|
||||||
|
for (const example of this._examples) {
|
||||||
|
messages.push({
|
||||||
|
role: 'system',
|
||||||
|
content: `Example input: ${example.input}\n\nExample output: ${example.output}`
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this._outputSchema) {
|
||||||
|
const outputSchema =
|
||||||
|
this._outputSchema instanceof z.ZodType
|
||||||
|
? this._outputSchema
|
||||||
|
: z.object(this._outputSchema)
|
||||||
|
|
||||||
|
const { node } = zodToTs(outputSchema)
|
||||||
|
|
||||||
|
if (node.kind === 152) {
|
||||||
|
// handle raw strings differently
|
||||||
|
messages.push({
|
||||||
|
role: 'system',
|
||||||
|
content: dedent`Output a raw string only, without any additional text.`
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const tsTypeString = printNode(node, {
|
||||||
|
removeComments: false,
|
||||||
|
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
||||||
|
omitTrailingSemicolon: true,
|
||||||
|
noEmitHelpers: true
|
||||||
|
})
|
||||||
|
.replace(/^ /gm, ' ')
|
||||||
|
.replace(/;$/gm, '')
|
||||||
|
|
||||||
|
messages.push({
|
||||||
|
role: 'system',
|
||||||
|
content: dedent`Output JSON only in the following TypeScript format:
|
||||||
|
\`\`\`ts
|
||||||
|
${tsTypeString}
|
||||||
|
\`\`\``
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: filter/compress messages based on token counts
|
||||||
|
|
||||||
|
console.log('>>>')
|
||||||
|
console.log(messages)
|
||||||
|
|
||||||
|
const completion = await this._createChatCompletion(messages)
|
||||||
|
let output: any = completion.message.content
|
||||||
|
|
||||||
|
console.log('===')
|
||||||
|
console.log(output)
|
||||||
|
console.log('<<<')
|
||||||
|
|
||||||
|
if (this._outputSchema) {
|
||||||
|
const outputSchema =
|
||||||
|
this._outputSchema instanceof z.ZodType
|
||||||
|
? this._outputSchema
|
||||||
|
: z.object(this._outputSchema)
|
||||||
|
|
||||||
|
if (outputSchema instanceof z.ZodArray) {
|
||||||
|
try {
|
||||||
|
const trimmedOutput = extractJSONArrayFromString(output)
|
||||||
|
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||||
|
} catch (err) {
|
||||||
|
// TODO
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
} else if (outputSchema instanceof z.ZodObject) {
|
||||||
|
try {
|
||||||
|
const trimmedOutput = extractJSONObjectFromString(output)
|
||||||
|
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||||
|
} catch (err) {
|
||||||
|
// TODO
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||||
|
output = output.toLowerCase().trim()
|
||||||
|
const booleanOutputs = {
|
||||||
|
true: true,
|
||||||
|
false: false,
|
||||||
|
yes: true,
|
||||||
|
no: false,
|
||||||
|
1: true,
|
||||||
|
0: false
|
||||||
|
}
|
||||||
|
|
||||||
|
const booleanOutput = booleanOutputs[output]
|
||||||
|
if (booleanOutput !== undefined) {
|
||||||
|
output = booleanOutput
|
||||||
|
} else {
|
||||||
|
// TODO
|
||||||
|
throw new Error(`invalid boolean output: ${output}`)
|
||||||
|
}
|
||||||
|
} else if (outputSchema instanceof z.ZodNumber) {
|
||||||
|
output = output.trim()
|
||||||
|
|
||||||
|
const numberOutput = outputSchema.isInt
|
||||||
|
? parseInt(output)
|
||||||
|
: parseFloat(output)
|
||||||
|
|
||||||
|
if (isNaN(numberOutput)) {
|
||||||
|
// TODO
|
||||||
|
throw new Error(`invalid number output: ${output}`)
|
||||||
|
} else {
|
||||||
|
output = numberOutput
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle errors, retry logic, and self-healing
|
||||||
|
|
||||||
|
return outputSchema.parse(output)
|
||||||
|
} else {
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
168
src/openai.ts
168
src/openai.ts
|
@ -1,25 +1,18 @@
|
||||||
import { jsonrepair } from 'jsonrepair'
|
|
||||||
import Mustache from 'mustache'
|
|
||||||
import { dedent } from 'ts-dedent'
|
|
||||||
import type { SetRequired } from 'type-fest'
|
import type { SetRequired } from 'type-fest'
|
||||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||||
import { printNode, zodToTs } from 'zod-to-ts'
|
|
||||||
|
|
||||||
import * as types from './types'
|
import * as types from './types'
|
||||||
import { defaultOpenAIModel } from './constants'
|
import { defaultOpenAIModel } from './constants'
|
||||||
import { ChatModelBuilder } from './llm'
|
import { BaseChatModelBuilder } from './llm'
|
||||||
import {
|
|
||||||
extractJSONArrayFromString,
|
|
||||||
extractJSONObjectFromString
|
|
||||||
} from './utils'
|
|
||||||
|
|
||||||
export class OpenAIChatModelBuilder<
|
export class OpenAIChatModelBuilder<
|
||||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
|
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
|
||||||
> extends ChatModelBuilder<
|
> extends BaseChatModelBuilder<
|
||||||
TInput,
|
TInput,
|
||||||
TOutput,
|
TOutput,
|
||||||
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
|
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>,
|
||||||
|
types.openai.ChatCompletionResponse
|
||||||
> {
|
> {
|
||||||
_client: types.openai.OpenAIClient
|
_client: types.openai.OpenAIClient
|
||||||
|
|
||||||
|
@ -40,151 +33,20 @@ export class OpenAIChatModelBuilder<
|
||||||
this._client = client
|
this._client = client
|
||||||
}
|
}
|
||||||
|
|
||||||
override async call(
|
protected override async _createChatCompletion(
|
||||||
input?: types.ParsedData<TInput>
|
messages: types.ChatMessage[]
|
||||||
): Promise<types.ParsedData<TOutput>> {
|
): Promise<
|
||||||
if (this._inputSchema) {
|
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
|
||||||
const inputSchema =
|
> {
|
||||||
this._inputSchema instanceof z.ZodType
|
const response = await this._client.createChatCompletion({
|
||||||
? this._inputSchema
|
model: this._model,
|
||||||
: z.object(this._inputSchema)
|
...this._modelParams,
|
||||||
|
|
||||||
// TODO: handle errors gracefully
|
|
||||||
input = inputSchema.parse(input)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: validate input message variables against input schema
|
|
||||||
|
|
||||||
const messages = this._messages
|
|
||||||
.map((message) => {
|
|
||||||
return {
|
|
||||||
...message,
|
|
||||||
content: message.content
|
|
||||||
? Mustache.render(dedent(message.content), input).trim()
|
|
||||||
: ''
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter((message) => message.content)
|
|
||||||
|
|
||||||
if (this._examples?.length) {
|
|
||||||
// TODO: smarter example selection
|
|
||||||
for (const example of this._examples) {
|
|
||||||
messages.push({
|
|
||||||
role: 'system',
|
|
||||||
content: `Example input: ${example.input}\n\nExample output: ${example.output}`
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this._outputSchema) {
|
|
||||||
const outputSchema =
|
|
||||||
this._outputSchema instanceof z.ZodType
|
|
||||||
? this._outputSchema
|
|
||||||
: z.object(this._outputSchema)
|
|
||||||
|
|
||||||
const { node } = zodToTs(outputSchema)
|
|
||||||
|
|
||||||
if (node.kind === 152) {
|
|
||||||
// handle raw strings differently
|
|
||||||
messages.push({
|
|
||||||
role: 'system',
|
|
||||||
content: dedent`Output a raw string only, without any additional text.`
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
const tsTypeString = printNode(node, {
|
|
||||||
removeComments: false,
|
|
||||||
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
|
||||||
omitTrailingSemicolon: true,
|
|
||||||
noEmitHelpers: true
|
|
||||||
})
|
|
||||||
.replace(/^ /gm, ' ')
|
|
||||||
.replace(/;$/gm, '')
|
|
||||||
|
|
||||||
messages.push({
|
|
||||||
role: 'system',
|
|
||||||
content: dedent`Output JSON only in the following TypeScript format:
|
|
||||||
\`\`\`ts
|
|
||||||
${tsTypeString}
|
|
||||||
\`\`\``
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: filter/compress messages based on token counts
|
|
||||||
|
|
||||||
console.log('>>>')
|
|
||||||
console.log(messages)
|
|
||||||
const completion = await this._client.createChatCompletion({
|
|
||||||
model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining
|
|
||||||
...this._outputSchema,
|
|
||||||
messages
|
messages
|
||||||
})
|
})
|
||||||
|
|
||||||
if (this._outputSchema) {
|
return {
|
||||||
const outputSchema =
|
message: response.message,
|
||||||
this._outputSchema instanceof z.ZodType
|
response: response.response
|
||||||
? this._outputSchema
|
|
||||||
: z.object(this._outputSchema)
|
|
||||||
|
|
||||||
let output: any = completion.message.content
|
|
||||||
console.log('===')
|
|
||||||
console.log(output)
|
|
||||||
console.log('<<<')
|
|
||||||
|
|
||||||
if (outputSchema instanceof z.ZodArray) {
|
|
||||||
try {
|
|
||||||
const trimmedOutput = extractJSONArrayFromString(output)
|
|
||||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
|
||||||
} catch (err) {
|
|
||||||
// TODO
|
|
||||||
throw err
|
|
||||||
}
|
|
||||||
} else if (outputSchema instanceof z.ZodObject) {
|
|
||||||
try {
|
|
||||||
const trimmedOutput = extractJSONObjectFromString(output)
|
|
||||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
|
||||||
} catch (err) {
|
|
||||||
// TODO
|
|
||||||
throw err
|
|
||||||
}
|
|
||||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
|
||||||
output = output.toLowerCase().trim()
|
|
||||||
const booleanOutputs = {
|
|
||||||
true: true,
|
|
||||||
false: false,
|
|
||||||
yes: true,
|
|
||||||
no: false,
|
|
||||||
1: true,
|
|
||||||
0: false
|
|
||||||
}
|
|
||||||
|
|
||||||
const booleanOutput = booleanOutputs[output]
|
|
||||||
if (booleanOutput !== undefined) {
|
|
||||||
output = booleanOutput
|
|
||||||
} else {
|
|
||||||
// TODO
|
|
||||||
throw new Error(`invalid boolean output: ${output}`)
|
|
||||||
}
|
|
||||||
} else if (outputSchema instanceof z.ZodNumber) {
|
|
||||||
output = output.trim()
|
|
||||||
|
|
||||||
const numberOutput = outputSchema.isInt
|
|
||||||
? parseInt(output)
|
|
||||||
: parseFloat(output)
|
|
||||||
|
|
||||||
if (isNaN(numberOutput)) {
|
|
||||||
// TODO
|
|
||||||
throw new Error(`invalid number output: ${output}`)
|
|
||||||
} else {
|
|
||||||
output = numberOutput
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: handle errors, retry logic, and self-healing
|
|
||||||
|
|
||||||
return outputSchema.parse(output)
|
|
||||||
} else {
|
|
||||||
return completion.message.content as any
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
18
src/types.ts
18
src/types.ts
|
@ -61,7 +61,13 @@ export interface LLMOptions<
|
||||||
promptSuffix?: string
|
promptSuffix?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ChatMessageRole = 'user' | 'system' | 'assistant'
|
// export type ChatMessageRole = 'user' | 'system' | 'assistant'
|
||||||
|
export const ChatMessageRoleSchema = z.union([
|
||||||
|
z.literal('user'),
|
||||||
|
z.literal('system'),
|
||||||
|
z.literal('assistant')
|
||||||
|
])
|
||||||
|
export type ChatMessageRole = z.infer<typeof ChatMessageRoleSchema>
|
||||||
|
|
||||||
export interface ChatMessage {
|
export interface ChatMessage {
|
||||||
role: ChatMessageRole
|
role: ChatMessageRole
|
||||||
|
@ -76,6 +82,16 @@ export interface ChatModelOptions<
|
||||||
messages: ChatMessage[]
|
messages: ChatMessage[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface BaseChatCompletionResponse<
|
||||||
|
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||||
|
> {
|
||||||
|
/** The completion message. */
|
||||||
|
message: ChatMessage
|
||||||
|
|
||||||
|
/** The raw response from the LLM provider. */
|
||||||
|
response: TChatCompletionResponse
|
||||||
|
}
|
||||||
|
|
||||||
export interface LLMExample {
|
export interface LLMExample {
|
||||||
input: string
|
input: string
|
||||||
output: string
|
output: string
|
||||||
|
|
Ładowanie…
Reference in New Issue