From ffd8709e3858c46d09047ba97e55a6ae2ac2d341 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Wed, 5 Jun 2024 18:57:43 -0500 Subject: [PATCH] feat: add unofficial midjourney client --- legacy/bin/scratch.ts | 23 ++-- legacy/src/errors.ts | 2 + legacy/src/services/index.ts | 1 + legacy/src/services/midjourney-client.ts | 139 +++++++++++++++++++++++ legacy/src/url-utils.test.ts | 31 +++++ legacy/src/utils.ts | 10 ++ 6 files changed, 198 insertions(+), 8 deletions(-) create mode 100644 legacy/src/services/midjourney-client.ts create mode 100644 legacy/src/url-utils.test.ts diff --git a/legacy/bin/scratch.ts b/legacy/bin/scratch.ts index a90ef820..e53f72af 100644 --- a/legacy/bin/scratch.ts +++ b/legacy/bin/scratch.ts @@ -8,7 +8,7 @@ import restoreCursor from 'restore-cursor' // import { ProxycurlClient } from '../src/services/proxycurl-client.js' // import { WikipediaClient } from '../src/index.js' // import { PerigonClient } from '../src/index.js' -import { FirecrawlClient } from '../src/index.js' +// import { FirecrawlClient } from '../src/index.js' // import { ExaClient } from '../src/index.js' // import { DiffbotClient } from '../src/index.js' // import { WolframClient } from '../src/index.js' @@ -16,6 +16,7 @@ import { FirecrawlClient } from '../src/index.js' // createTwitterV2Client, // TwitterClient // } from '../src/services/twitter/index.js' +import { MidjourneyClient } from '../src/index.js' /** * Scratch pad for testing. @@ -56,13 +57,13 @@ async function main() { // }) // console.log(JSON.stringify(res, null, 2)) - const firecrawl = new FirecrawlClient() - const res = await firecrawl.scrapeUrl({ - url: 'https://www.bbc.com/news/articles/cp4475gwny1o' - // url: 'https://www.theguardian.com/technology/article/2024/jun/04/openai-google-ai-risks-letter' - // url: 'https://www.firecrawl.dev' - }) - console.log(JSON.stringify(res, null, 2)) + // const firecrawl = new FirecrawlClient() + // const res = await firecrawl.scrapeUrl({ + // url: 'https://www.bbc.com/news/articles/cp4475gwny1o' + // // url: 'https://www.theguardian.com/technology/article/2024/jun/04/openai-google-ai-risks-letter' + // // url: 'https://www.firecrawl.dev' + // }) + // console.log(JSON.stringify(res, null, 2)) // const exa = new ExaClient() // const res = await exa.search({ @@ -96,6 +97,12 @@ async function main() { // query: 'open source AI agents' // }) // console.log(res) + + const midjourney = new MidjourneyClient() + const res = await midjourney.imagine( + 'tiny lil baby kittens playing with an inquisitive AI robot, kawaii, anime' + ) + console.log(JSON.stringify(res, null, 2)) } try { diff --git a/legacy/src/errors.ts b/legacy/src/errors.ts index d3bdd21d..d6e43f6a 100644 --- a/legacy/src/errors.ts +++ b/legacy/src/errors.ts @@ -1,3 +1,5 @@ export class RetryableError extends Error {} export class ParseError extends RetryableError {} + +export class TimeoutError extends Error {} diff --git a/legacy/src/services/index.ts b/legacy/src/services/index.ts index 72fae17b..5e5484c7 100644 --- a/legacy/src/services/index.ts +++ b/legacy/src/services/index.ts @@ -3,6 +3,7 @@ export * from './dexa-client.js' export * from './diffbot-client.js' export * from './exa-client.js' export * from './firecrawl-client.js' +export * from './midjourney-client.js' export * from './people-data-labs-client.js' export * from './perigon-client.js' export * from './predict-leads-client.js' diff --git a/legacy/src/services/midjourney-client.ts b/legacy/src/services/midjourney-client.ts new file mode 100644 index 00000000..3609e61b --- /dev/null +++ b/legacy/src/services/midjourney-client.ts @@ -0,0 +1,139 @@ +import defaultKy, { type KyInstance } from 'ky' +import { z } from 'zod' + +import { TimeoutError } from '../errors.js' +import { aiFunction, AIFunctionsProvider } from '../fns.js' +import { assert, delay, getEnv, pruneNullOrUndefined } from '../utils.js' + +export namespace midjourney { + export const API_BASE_URL = 'https://cl.imagineapi.dev' + + export type JobStatus = 'pending' | 'in-progress' | 'completed' | 'failed' + + export interface ImagineResponse { + data: Job + } + + export interface Job { + id: string + prompt: string + status: JobStatus + user_created: string + date_created: string + results?: string + progress?: string + url?: string + error?: string + upscaled_urls?: string[] + ref?: string + upscaled?: string[] + } +} + +/** + * Unofficial Midjourney API client. + * + * @see https://www.imagineapi.dev + */ +export class MidjourneyClient extends AIFunctionsProvider { + readonly ky: KyInstance + readonly apiKey: string + readonly apiBaseUrl: string + + constructor({ + apiKey = getEnv('MIDJOURNEY_IMAGINE_API_KEY'), + apiBaseUrl = midjourney.API_BASE_URL, + ky = defaultKy + }: { + apiKey?: string + apiBaseUrl?: string + ky?: KyInstance + } = {}) { + assert( + apiKey, + 'MidjourneyClient missing required "apiKey" (defaults to "MIDJOURNEY_IMAGINE_API_KEY")' + ) + super() + + this.apiKey = apiKey + this.apiBaseUrl = apiBaseUrl + + this.ky = ky.extend({ + prefixUrl: apiBaseUrl, + headers: { + Authorization: `Bearer ${this.apiKey}` + } + }) + } + + @aiFunction({ + name: 'midjourney_create_images', + description: + 'Creates 4 images from a prompt using the Midjourney API. Useful for generating images on the fly.', + inputSchema: z.object({ + prompt: z + .string() + .describe( + 'Simple, short, comma-separated list of phrases which describe the image you want to generate' + ) + }) + }) + async imagine( + promptOrOptions: string | { prompt: string } + ): Promise { + const options = + typeof promptOrOptions === 'string' + ? { prompt: promptOrOptions } + : promptOrOptions + + const res = await this.ky + .post('items/images', { + json: { ...options } + }) + .json() + + return pruneNullOrUndefined(res.data) + } + + async getJobById(jobId: string): Promise { + const res = await this.ky + .get(`items/images/${jobId}`) + .json() + + return pruneNullOrUndefined(res.data) + } + + async waitForJobById( + jobId: string, + { + timeoutMs = 5 * 60 * 1000, // 5 minutes + intervalMs = 1000 + }: { + timeoutMs?: number + intervalMs?: number + } = {} + ) { + const startTimeMs = Date.now() + + function checkForTimeout() { + const elapsedTimeMs = Date.now() - startTimeMs + if (elapsedTimeMs >= timeoutMs) { + throw new TimeoutError( + `MidjourneyClient timeout waiting for job "${jobId}"` + ) + } + } + + do { + checkForTimeout() + + const job = await this.getJobById(jobId) + if (job.status === 'completed' || job.status === 'failed') { + return job + } + + checkForTimeout() + await delay(intervalMs) + } while (true) + } +} diff --git a/legacy/src/url-utils.test.ts b/legacy/src/url-utils.test.ts new file mode 100644 index 00000000..3523821d --- /dev/null +++ b/legacy/src/url-utils.test.ts @@ -0,0 +1,31 @@ +import { describe, expect, test } from 'vitest' + +import { normalizeUrl } from './url-utils.js' + +describe('normalizeUrl', () => { + test('valid urls', async () => { + expect(normalizeUrl('https://www.google.com')).toBe( + 'https://www.google.com' + ) + expect(normalizeUrl('//www.google.com')).toBe('https://www.google.com') + expect(normalizeUrl('https://www.google.com/foo?')).toBe( + 'https://www.google.com/foo' + ) + expect(normalizeUrl('https://www.google.com/?foo=bar&dog=cat')).toBe( + 'https://www.google.com/?dog=cat&foo=bar' + ) + expect(normalizeUrl('https://google.com/abc/123//')).toBe( + 'https://google.com/abc/123' + ) + }) + + test('invalid urls', async () => { + expect(normalizeUrl('/foo')).toBe(null) + expect(normalizeUrl('/foo/bar/baz')).toBe(null) + expect(normalizeUrl('://foo.com')).toBe(null) + expect(normalizeUrl('foo')).toBe(null) + expect(normalizeUrl('')).toBe(null) + expect(normalizeUrl(undefined as unknown as string)).toBe(null) + expect(normalizeUrl(null as unknown as string)).toBe(null) + }) +}) diff --git a/legacy/src/utils.ts b/legacy/src/utils.ts index 2f033e8a..dd440492 100644 --- a/legacy/src/utils.ts +++ b/legacy/src/utils.ts @@ -57,6 +57,16 @@ export function pruneUndefined>( ) as NonNullable } +export function pruneNullOrUndefined>( + obj: T +): NonNullable<{ [K in keyof T]: Exclude }> { + return Object.fromEntries( + Object.entries(obj).filter( + ([, value]) => value !== undefined && value !== null + ) + ) as NonNullable +} + export function getEnv(name: string): string | undefined { try { return typeof process !== 'undefined'