kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: improve input/output typing tests
rodzic
d788cc1600
commit
33069bed5e
|
@ -19,6 +19,7 @@ import {
|
||||||
extractJSONObjectFromString
|
extractJSONObjectFromString
|
||||||
} from '@/utils'
|
} from '@/utils'
|
||||||
|
|
||||||
|
// TODO: TInput should only be allowed to be an object
|
||||||
export abstract class BaseLLM<
|
export abstract class BaseLLM<
|
||||||
TInput = void,
|
TInput = void,
|
||||||
TOutput = string,
|
TOutput = string,
|
||||||
|
@ -142,6 +143,28 @@ export abstract class BaseChatModel<
|
||||||
this._messages = options.messages
|
this._messages = options.messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
|
||||||
|
input<U>(inputSchema: ZodType<U>): BaseChatModel<U, TOutput, TModelParams> {
|
||||||
|
const refinedInstance = this as unknown as BaseChatModel<
|
||||||
|
U,
|
||||||
|
TOutput,
|
||||||
|
TModelParams
|
||||||
|
>
|
||||||
|
refinedInstance._inputSchema = inputSchema
|
||||||
|
return refinedInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
|
||||||
|
output<U>(outputSchema: ZodType<U>): BaseChatModel<TInput, U, TModelParams> {
|
||||||
|
const refinedInstance = this as unknown as BaseChatModel<
|
||||||
|
TInput,
|
||||||
|
U,
|
||||||
|
TModelParams
|
||||||
|
>
|
||||||
|
refinedInstance._outputSchema = outputSchema
|
||||||
|
return refinedInstance
|
||||||
|
}
|
||||||
|
|
||||||
protected abstract _createChatCompletion(
|
protected abstract _createChatCompletion(
|
||||||
messages: types.ChatMessage[]
|
messages: types.ChatMessage[]
|
||||||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||||
|
|
|
@ -4,10 +4,48 @@ import sinon from 'sinon'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { OutputValidationError, TemplateValidationError } from '@/errors'
|
import { OutputValidationError, TemplateValidationError } from '@/errors'
|
||||||
import { OpenAIChatModel } from '@/llms/openai'
|
import { BaseChatModel, OpenAIChatModel } from '@/llms'
|
||||||
|
|
||||||
import { createTestAgenticRuntime } from './_utils'
|
import { createTestAgenticRuntime } from './_utils'
|
||||||
|
|
||||||
|
test('OpenAIChatModel ⇒ types', async (t) => {
|
||||||
|
const agentic = createTestAgenticRuntime()
|
||||||
|
const b = agentic.gpt4('test')
|
||||||
|
t.pass()
|
||||||
|
|
||||||
|
expectTypeOf(b).toMatchTypeOf<OpenAIChatModel<any, string>>()
|
||||||
|
|
||||||
|
expectTypeOf(
|
||||||
|
b.input(
|
||||||
|
z.object({
|
||||||
|
foo: z.string()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
).toMatchTypeOf<
|
||||||
|
BaseChatModel<
|
||||||
|
{
|
||||||
|
foo: string
|
||||||
|
},
|
||||||
|
string
|
||||||
|
>
|
||||||
|
>()
|
||||||
|
|
||||||
|
expectTypeOf(
|
||||||
|
b.output(
|
||||||
|
z.object({
|
||||||
|
bar: z.string().optional()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
).toMatchTypeOf<
|
||||||
|
BaseChatModel<
|
||||||
|
any,
|
||||||
|
{
|
||||||
|
bar?: string
|
||||||
|
}
|
||||||
|
>
|
||||||
|
>()
|
||||||
|
})
|
||||||
|
|
||||||
test('OpenAIChatModel ⇒ string output', async (t) => {
|
test('OpenAIChatModel ⇒ string output', async (t) => {
|
||||||
t.timeout(2 * 60 * 1000)
|
t.timeout(2 * 60 * 1000)
|
||||||
const agentic = createTestAgenticRuntime()
|
const agentic = createTestAgenticRuntime()
|
||||||
|
@ -165,7 +203,7 @@ test('OpenAIChatModel ⇒ template variables', async (t) => {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
test.only('OpenAIChatModel ⇒ missing template variable', async (t) => {
|
test('OpenAIChatModel ⇒ missing template variable', async (t) => {
|
||||||
t.timeout(2 * 60 * 1000)
|
t.timeout(2 * 60 * 1000)
|
||||||
const agentic = createTestAgenticRuntime()
|
const agentic = createTestAgenticRuntime()
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ test('TwilioConversationClient.createConversation', async (t) => {
|
||||||
const conversation = await client.createConversation(friendlyName)
|
const conversation = await client.createConversation(friendlyName)
|
||||||
t.is(conversation.friendly_name, friendlyName)
|
t.is(conversation.friendly_name, friendlyName)
|
||||||
|
|
||||||
client.deleteConversation(conversation.sid)
|
await client.deleteConversation(conversation.sid)
|
||||||
})
|
})
|
||||||
|
|
||||||
test('TwilioConversationClient.addParticipant', async (t) => {
|
test('TwilioConversationClient.addParticipant', async (t) => {
|
||||||
|
@ -90,7 +90,7 @@ test('TwilioConversationClient.sendAndWaitForReply', async (t) => {
|
||||||
await t.throwsAsync(
|
await t.throwsAsync(
|
||||||
async () => {
|
async () => {
|
||||||
await client.sendAndWaitForReply({
|
await client.sendAndWaitForReply({
|
||||||
recipientPhoneNumber: process.env.TWILIO_TEST_PHONE_NUMBER as string,
|
recipientPhoneNumber: process.env.TWILIO_TEST_PHONE_NUMBER!,
|
||||||
text: 'Please confirm by replying with "yes" or "no".',
|
text: 'Please confirm by replying with "yes" or "no".',
|
||||||
name: 'wait-for-reply-test',
|
name: 'wait-for-reply-test',
|
||||||
validate: (message) =>
|
validate: (message) =>
|
||||||
|
@ -122,7 +122,7 @@ test('TwilioConversationClient.sendAndWaitForReply.stopSignal', async (t) => {
|
||||||
async () => {
|
async () => {
|
||||||
const controller = new AbortController()
|
const controller = new AbortController()
|
||||||
const promise = client.sendAndWaitForReply({
|
const promise = client.sendAndWaitForReply({
|
||||||
recipientPhoneNumber: process.env.TWILIO_TEST_PHONE_NUMBER as string,
|
recipientPhoneNumber: process.env.TWILIO_TEST_PHONE_NUMBER!,
|
||||||
text: 'Please confirm by replying with "yes" or "no".',
|
text: 'Please confirm by replying with "yes" or "no".',
|
||||||
name: 'wait-for-reply-test',
|
name: 'wait-for-reply-test',
|
||||||
validate: (message) =>
|
validate: (message) =>
|
||||||
|
|
Ładowanie…
Reference in New Issue