From bc12a880a28229f39dd6a69dba5ce6599df9da71 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Wed, 5 Jun 2024 20:23:45 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=8D=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- readme.md | 2 +- src/services/midjourney-client.ts | 82 ++++++++++++++++++++++++++----- 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/readme.md b/readme.md index a6dc6c6..4d70b2d 100644 --- a/readme.md +++ b/readme.md @@ -128,6 +128,7 @@ The SDK-specific imports are all isolated to keep the main `@agentic/stdlib` as - diffbot - exa - firecrawl (WIP) +- midjourney - people data labs (WIP) - perigon - predict leads @@ -169,7 +170,6 @@ The SDK-specific imports are all isolated to keep the main `@agentic/stdlib` as - replicate - huggingface - [skyvern](https://github.com/Skyvern-AI/skyvern) - - midjourney - unstructured - pull from [langchain](https://github.com/langchain-ai/langchainjs/tree/main/langchain) - provide a converter for langchain `DynamicStructuredTool` diff --git a/src/services/midjourney-client.ts b/src/services/midjourney-client.ts index 3609e61..5ce2962 100644 --- a/src/services/midjourney-client.ts +++ b/src/services/midjourney-client.ts @@ -5,6 +5,8 @@ import { TimeoutError } from '../errors.js' import { aiFunction, AIFunctionsProvider } from '../fns.js' import { assert, delay, getEnv, pruneNullOrUndefined } from '../utils.js' +// TODO: add additional methods for upscaling, variations, etc. + export namespace midjourney { export const API_BASE_URL = 'https://cl.imagineapi.dev' @@ -28,6 +30,12 @@ export namespace midjourney { ref?: string upscaled?: string[] } + + export interface JobOptions { + wait?: boolean + timeoutMs?: number + intervalMs?: number + } } /** @@ -79,12 +87,22 @@ export class MidjourneyClient extends AIFunctionsProvider { }) }) async imagine( - promptOrOptions: string | { prompt: string } + promptOrOptions: + | string + | ({ + prompt: string + } & midjourney.JobOptions) ): Promise { - const options = - typeof promptOrOptions === 'string' - ? { prompt: promptOrOptions } - : promptOrOptions + const { + wait = true, + timeoutMs, + intervalMs, + ...options + } = typeof promptOrOptions === 'string' + ? ({ prompt: promptOrOptions } as { + prompt: string + } & midjourney.JobOptions) + : promptOrOptions const res = await this.ky .post('items/images', { @@ -92,15 +110,56 @@ export class MidjourneyClient extends AIFunctionsProvider { }) .json() - return pruneNullOrUndefined(res.data) + const job = pruneNullOrUndefined(res.data) + if (!wait) { + return job + } + + if (job.status === 'completed' || job.status === 'failed') { + return job + } + + return this.waitForJobById(job.id, { + timeoutMs, + intervalMs + }) } - async getJobById(jobId: string): Promise { + async getJobById( + jobIdOrOptions: + | string + | ({ + jobId: string + } & midjourney.JobOptions) + ): Promise { + const { + jobId, + wait = true, + timeoutMs, + intervalMs + } = typeof jobIdOrOptions === 'string' + ? ({ jobId: jobIdOrOptions } as { + jobId: string + } & midjourney.JobOptions) + : jobIdOrOptions + const res = await this.ky .get(`items/images/${jobId}`) .json() - return pruneNullOrUndefined(res.data) + const job = pruneNullOrUndefined(res.data) + if (!wait) { + return job + } + + if (job.status === 'completed' || job.status === 'failed') { + return job + } + + return this.waitForJobById(job.id, { + timeoutMs, + intervalMs + }) } async waitForJobById( @@ -108,11 +167,8 @@ export class MidjourneyClient extends AIFunctionsProvider { { timeoutMs = 5 * 60 * 1000, // 5 minutes intervalMs = 1000 - }: { - timeoutMs?: number - intervalMs?: number - } = {} - ) { + }: Omit = {} + ): Promise { const startTimeMs = Date.now() function checkForTimeout() {