kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: separate out functions, generate minified JSON, improve Boolean handling
rodzic
eeb2315879
commit
e4d929c11f
282
src/llms/chat.ts
282
src/llms/chat.ts
|
@ -21,6 +21,113 @@ import {
|
||||||
getNumTokensForChatMessages
|
getNumTokensForChatMessages
|
||||||
} from './llm-utils'
|
} from './llm-utils'
|
||||||
|
|
||||||
|
const BOOLEAN_OUTPUTS = {
|
||||||
|
true: true,
|
||||||
|
false: false,
|
||||||
|
t: true,
|
||||||
|
f: false,
|
||||||
|
yes: true,
|
||||||
|
no: false,
|
||||||
|
y: true,
|
||||||
|
n: false,
|
||||||
|
'1': true,
|
||||||
|
'0': false
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseArrayOutput(output: string): Array<any> {
|
||||||
|
try {
|
||||||
|
const trimmedOutput = extractJSONArrayFromString(output)
|
||||||
|
const parsedOutput = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||||
|
return parsedOutput
|
||||||
|
} catch (err: any) {
|
||||||
|
if (err instanceof JSONRepairError) {
|
||||||
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
|
} else if (err instanceof SyntaxError) {
|
||||||
|
throw new errors.OutputValidationError(
|
||||||
|
`Invalid JSON array: ${err.message}`,
|
||||||
|
{ cause: err }
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseObjectOutput(output) {
|
||||||
|
try {
|
||||||
|
const trimmedOutput = extractJSONObjectFromString(output)
|
||||||
|
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||||
|
|
||||||
|
if (Array.isArray(output)) {
|
||||||
|
// TODO
|
||||||
|
output = output[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
} catch (err: any) {
|
||||||
|
if (err instanceof JSONRepairError) {
|
||||||
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
|
} else if (err instanceof SyntaxError) {
|
||||||
|
throw new errors.OutputValidationError(
|
||||||
|
`Invalid JSON object: ${err.message}`,
|
||||||
|
{ cause: err }
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseBooleanOutput(output): boolean {
|
||||||
|
output = output
|
||||||
|
.toLowerCase()
|
||||||
|
.trim()
|
||||||
|
.replace(/[.!?]+$/, '')
|
||||||
|
|
||||||
|
const booleanOutput = BOOLEAN_OUTPUTS[output]
|
||||||
|
|
||||||
|
if (booleanOutput !== undefined) {
|
||||||
|
return booleanOutput
|
||||||
|
} else {
|
||||||
|
throw new errors.OutputValidationError(`Invalid boolean output: ${output}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseNumberOutput(output, outputSchema: z.ZodNumber): number {
|
||||||
|
output = output.trim()
|
||||||
|
|
||||||
|
const numberOutput = outputSchema.isInt
|
||||||
|
? parseInt(output)
|
||||||
|
: parseFloat(output)
|
||||||
|
|
||||||
|
if (isNaN(numberOutput)) {
|
||||||
|
throw new errors.OutputValidationError(`Invalid number output: ${output}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return numberOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseOutput(output: any, outputSchema: ZodType<any>) {
|
||||||
|
if (outputSchema instanceof z.ZodArray) {
|
||||||
|
output = parseArrayOutput(output)
|
||||||
|
} else if (outputSchema instanceof z.ZodObject) {
|
||||||
|
output = parseObjectOutput(output)
|
||||||
|
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||||
|
output = parseBooleanOutput(output)
|
||||||
|
} else if (outputSchema instanceof z.ZodNumber) {
|
||||||
|
output = parseNumberOutput(output, outputSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: fix typescript issue here with recursive types
|
||||||
|
const safeResult = (outputSchema.safeParse as any)(output)
|
||||||
|
|
||||||
|
if (!safeResult.success) {
|
||||||
|
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return safeResult.data
|
||||||
|
}
|
||||||
|
|
||||||
export abstract class BaseChatCompletion<
|
export abstract class BaseChatCompletion<
|
||||||
TInput extends types.TaskInput = void,
|
TInput extends types.TaskInput = void,
|
||||||
TOutput extends types.TaskOutput = string,
|
TOutput extends types.TaskOutput = string,
|
||||||
|
@ -141,47 +248,12 @@ export abstract class BaseChatCompletion<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this._outputSchema) {
|
const outputMsg = this.outputMessage()
|
||||||
// TODO: replace zod-to-ts with zod-to-json-schema?
|
if (outputMsg !== null) {
|
||||||
const { node } = zodToTs(this._outputSchema)
|
messages.push({
|
||||||
|
role: 'system',
|
||||||
if (node.kind === 152) {
|
content: outputMsg
|
||||||
// handle raw strings differently
|
})
|
||||||
messages.push({
|
|
||||||
role: 'system',
|
|
||||||
content: dedent`Output a raw string only, without any additional text.`
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
const tsTypeString = printNode(node, {
|
|
||||||
removeComments: false,
|
|
||||||
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
|
||||||
omitTrailingSemicolon: true,
|
|
||||||
noEmitHelpers: true
|
|
||||||
})
|
|
||||||
.replace(/^ {4}/gm, ' ')
|
|
||||||
.replace(/;$/gm, '')
|
|
||||||
|
|
||||||
const label =
|
|
||||||
this._outputSchema instanceof z.ZodArray
|
|
||||||
? 'JSON array'
|
|
||||||
: this._outputSchema instanceof z.ZodObject
|
|
||||||
? 'JSON object'
|
|
||||||
: this._outputSchema instanceof z.ZodNumber
|
|
||||||
? 'number'
|
|
||||||
: this._outputSchema instanceof z.ZodString
|
|
||||||
? 'string'
|
|
||||||
: this._outputSchema instanceof z.ZodBoolean
|
|
||||||
? 'boolean'
|
|
||||||
: 'JSON value'
|
|
||||||
|
|
||||||
messages.push({
|
|
||||||
role: 'system',
|
|
||||||
content: dedent`Do not output code. Output a single ${label} in the following TypeScript format:
|
|
||||||
\`\`\`ts
|
|
||||||
${tsTypeString}
|
|
||||||
\`\`\``
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx?.retryMessage) {
|
if (ctx?.retryMessage) {
|
||||||
|
@ -196,6 +268,50 @@ export abstract class BaseChatCompletion<
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public outputMessage(): string | null {
|
||||||
|
const outputSchema = this._outputSchema
|
||||||
|
|
||||||
|
if (!outputSchema) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: replace zod-to-ts with zod-to-json-schema?
|
||||||
|
const { node } = zodToTs(outputSchema)
|
||||||
|
|
||||||
|
if (node.kind === 152) {
|
||||||
|
// Handle raw strings differently:
|
||||||
|
return dedent`Output a raw string only, without any additional text.`
|
||||||
|
}
|
||||||
|
|
||||||
|
const tsTypeString = printNode(node, {
|
||||||
|
removeComments: false,
|
||||||
|
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
||||||
|
omitTrailingSemicolon: true,
|
||||||
|
noEmitHelpers: true
|
||||||
|
})
|
||||||
|
.replace(/^ {4}/gm, ' ')
|
||||||
|
.replace(/;$/gm, '')
|
||||||
|
let label: string
|
||||||
|
if (outputSchema instanceof z.ZodArray) {
|
||||||
|
label = 'JSON array (minified)'
|
||||||
|
} else if (outputSchema instanceof z.ZodObject) {
|
||||||
|
label = 'JSON object (minified)'
|
||||||
|
} else if (outputSchema instanceof z.ZodNumber) {
|
||||||
|
label = 'number'
|
||||||
|
} else if (outputSchema instanceof z.ZodString) {
|
||||||
|
label = 'string'
|
||||||
|
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||||
|
label = 'boolean'
|
||||||
|
} else {
|
||||||
|
label = 'JSON value'
|
||||||
|
}
|
||||||
|
|
||||||
|
return dedent`Do not output code. Output a single ${label} in the following TypeScript format:
|
||||||
|
\`\`\`ts
|
||||||
|
${tsTypeString}
|
||||||
|
\`\`\``
|
||||||
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
||||||
): Promise<TOutput> {
|
): Promise<TOutput> {
|
||||||
|
@ -356,89 +472,7 @@ export abstract class BaseChatCompletion<
|
||||||
// console.log('<<<')
|
// console.log('<<<')
|
||||||
|
|
||||||
if (this._outputSchema) {
|
if (this._outputSchema) {
|
||||||
const outputSchema = this._outputSchema
|
return parseOutput(output, this._outputSchema)
|
||||||
|
|
||||||
if (outputSchema instanceof z.ZodArray) {
|
|
||||||
try {
|
|
||||||
const trimmedOutput = extractJSONArrayFromString(output)
|
|
||||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
|
||||||
} catch (err: any) {
|
|
||||||
if (err instanceof JSONRepairError) {
|
|
||||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
|
||||||
} else if (err instanceof SyntaxError) {
|
|
||||||
throw new errors.OutputValidationError(
|
|
||||||
`Invalid JSON array: ${err.message}`,
|
|
||||||
{ cause: err }
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
throw err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (outputSchema instanceof z.ZodObject) {
|
|
||||||
try {
|
|
||||||
const trimmedOutput = extractJSONObjectFromString(output)
|
|
||||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
|
||||||
|
|
||||||
if (Array.isArray(output)) {
|
|
||||||
// TODO
|
|
||||||
output = output[0]
|
|
||||||
}
|
|
||||||
} catch (err: any) {
|
|
||||||
if (err instanceof JSONRepairError) {
|
|
||||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
|
||||||
} else if (err instanceof SyntaxError) {
|
|
||||||
throw new errors.OutputValidationError(
|
|
||||||
`Invalid JSON object: ${err.message}`,
|
|
||||||
{ cause: err }
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
throw err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
|
||||||
output = output.toLowerCase().trim()
|
|
||||||
const booleanOutputs = {
|
|
||||||
true: true,
|
|
||||||
false: false,
|
|
||||||
yes: true,
|
|
||||||
no: false,
|
|
||||||
1: true,
|
|
||||||
0: false
|
|
||||||
}
|
|
||||||
|
|
||||||
const booleanOutput = booleanOutputs[output]
|
|
||||||
|
|
||||||
if (booleanOutput !== undefined) {
|
|
||||||
output = booleanOutput
|
|
||||||
} else {
|
|
||||||
throw new errors.OutputValidationError(
|
|
||||||
`Invalid boolean output: ${output}`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else if (outputSchema instanceof z.ZodNumber) {
|
|
||||||
output = output.trim()
|
|
||||||
|
|
||||||
const numberOutput = outputSchema.isInt
|
|
||||||
? parseInt(output)
|
|
||||||
: parseFloat(output)
|
|
||||||
|
|
||||||
if (isNaN(numberOutput)) {
|
|
||||||
throw new errors.OutputValidationError(
|
|
||||||
`Invalid number output: ${output}`
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
output = numberOutput
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: fix typescript issue here with recursive types
|
|
||||||
const safeResult = (outputSchema.safeParse as any)(output)
|
|
||||||
|
|
||||||
if (!safeResult.success) {
|
|
||||||
throw new errors.ZodOutputValidationError(safeResult.error)
|
|
||||||
}
|
|
||||||
|
|
||||||
return safeResult.data
|
|
||||||
} else {
|
} else {
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
Ładowanie…
Reference in New Issue