From 1bffd5e92a75dfa9db47ba8fb95c0702f0ab3213 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Tue, 14 Feb 2023 00:30:06 -0600 Subject: [PATCH] feat: add ability to override global "fetch" --- src/chatgpt-api.ts | 76 ++++++++++++++++++++++++++++------------------ src/fetch-sse.ts | 5 +-- src/fetch.ts | 4 --- src/types.ts | 2 ++ 4 files changed, 52 insertions(+), 35 deletions(-) diff --git a/src/chatgpt-api.ts b/src/chatgpt-api.ts index 2e3c915..eeca4f2 100644 --- a/src/chatgpt-api.ts +++ b/src/chatgpt-api.ts @@ -5,7 +5,7 @@ import QuickLRU from 'quick-lru' import { v4 as uuidv4 } from 'uuid' import * as types from './types' -import { fetch } from './fetch' +import { fetch as globalFetch } from './fetch' import { fetchSSE } from './fetch-sse' // Official model (costs money and is not fine-tuned for chat) @@ -27,6 +27,7 @@ export class ChatGPTAPI { protected _assistantLabel: string protected _endToken: string protected _sepToken: string + protected _fetch: types.FetchFn protected _getMessageById: types.GetMessageByIdFunction protected _upsertMessage: types.UpsertMessageFunction @@ -47,6 +48,7 @@ export class ChatGPTAPI { * @param messageStore - Optional [Keyv](https://github.com/jaredwray/keyv) store to persist chat messages to. If not provided, messages will be lost when the process exits. * @param getMessageById - Optional function to retrieve a message by its ID. If not provided, the default implementation will be used (using an in-memory `messageStore`). * @param upsertMessage - Optional function to insert or update a message. If not provided, the default implementation will be used (using an in-memory `messageStore`). + * @param fetch - Optional override for the `fetch` implementation to use. Defaults to the global `fetch` function. */ constructor(opts: { apiKey: string @@ -77,6 +79,8 @@ export class ChatGPTAPI { messageStore?: Keyv getMessageById?: types.GetMessageByIdFunction upsertMessage?: types.UpsertMessageFunction + + fetch?: types.FetchFn }) { const { apiKey, @@ -90,13 +94,15 @@ export class ChatGPTAPI { userLabel = USER_LABEL_DEFAULT, assistantLabel = ASSISTANT_LABEL_DEFAULT, getMessageById = this._defaultGetMessageById, - upsertMessage = this._defaultUpsertMessage + upsertMessage = this._defaultUpsertMessage, + fetch = globalFetch } = opts this._apiKey = apiKey this._apiBaseUrl = apiBaseUrl this._apiReverseProxyUrl = apiReverseProxyUrl this._debug = !!debug + this._fetch = fetch this._completionParams = { model: CHATGPT_MODEL, @@ -141,6 +147,14 @@ export class ChatGPTAPI { if (!this._apiKey) { throw new Error('ChatGPT invalid apiKey') } + + if (!this._fetch) { + throw new Error('Invalid environment; fetch is not defined') + } + + if (typeof this._fetch !== 'function') { + throw new Error('Invalid "fetch" is not a function') + } } /** @@ -229,40 +243,44 @@ export class ChatGPTAPI { } if (stream) { - fetchSSE(url, { - method: 'POST', - headers, - body: JSON.stringify(body), - signal: abortSignal, - onMessage: (data: string) => { - if (data === '[DONE]') { - result.text = result.text.trim() - return resolve(result) - } - - try { - const response: types.openai.CompletionResponse = - JSON.parse(data) - - if (response.id) { - result.id = response.id + fetchSSE( + url, + { + method: 'POST', + headers, + body: JSON.stringify(body), + signal: abortSignal, + onMessage: (data: string) => { + if (data === '[DONE]') { + result.text = result.text.trim() + return resolve(result) } - if (response?.choices?.length) { - result.text += response.choices[0].text - result.detail = response + try { + const response: types.openai.CompletionResponse = + JSON.parse(data) - onProgress?.(result) + if (response.id) { + result.id = response.id + } + + if (response?.choices?.length) { + result.text += response.choices[0].text + result.detail = response + + onProgress?.(result) + } + } catch (err) { + console.warn('ChatGPT stream SEE event unexpected error', err) + return reject(err) } - } catch (err) { - console.warn('ChatGPT stream SEE event unexpected error', err) - return reject(err) } - } - }).catch(reject) + }, + this._fetch + ).catch(reject) } else { try { - const res = await fetch(url, { + const res = await this._fetch(url, { method: 'POST', headers, body: JSON.stringify(body), diff --git a/src/fetch-sse.ts b/src/fetch-sse.ts index f9b4ed4..00410eb 100644 --- a/src/fetch-sse.ts +++ b/src/fetch-sse.ts @@ -1,12 +1,13 @@ import { createParser } from 'eventsource-parser' import * as types from './types' -import { fetch } from './fetch' +import { fetch as globalFetch } from './fetch' import { streamAsyncIterable } from './stream-async-iterable' export async function fetchSSE( url: string, - options: Parameters[1] & { onMessage: (data: string) => void } + options: Parameters[1] & { onMessage: (data: string) => void }, + fetch: types.FetchFn = globalFetch ) { const { onMessage, ...fetchOptions } = options const res = await fetch(url, fetchOptions) diff --git a/src/fetch.ts b/src/fetch.ts index f68f4ae..1dcd6a3 100644 --- a/src/fetch.ts +++ b/src/fetch.ts @@ -2,8 +2,4 @@ const fetch = globalThis.fetch -if (typeof fetch !== 'function') { - throw new Error('Invalid environment: global fetch not defined') -} - export { fetch } diff --git a/src/types.ts b/src/types.ts index e6fa650..d5e1e6d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,5 +1,7 @@ export type Role = 'user' | 'assistant' +export type FetchFn = typeof fetch + export type SendMessageOptions = { conversationId?: string parentMessageId?: string