diff --git a/src/services/index.ts b/src/services/index.ts index 88569671..76fc4da5 100644 --- a/src/services/index.ts +++ b/src/services/index.ts @@ -5,5 +5,6 @@ export * from './novu' export * from './polygon' export * from './serpapi' export * from './slack' +export * from './stability' export * from './twilio-conversation' export * from './weather' diff --git a/src/services/metaphor.ts b/src/services/metaphor.ts index d079e2db..1ef5f959 100644 --- a/src/services/metaphor.ts +++ b/src/services/metaphor.ts @@ -105,19 +105,10 @@ export type MetaphorSearchOutput = { } export class MetaphorClient { - /** - * HTTP client for the Metaphor API. - */ readonly api: typeof defaultKy - /** - * Metaphor API key. - */ readonly apiKey: string - /** - * Metaphor API base URL. - */ readonly apiBaseUrl: string constructor({ diff --git a/src/services/stability.ts b/src/services/stability.ts new file mode 100644 index 00000000..7bdae637 --- /dev/null +++ b/src/services/stability.ts @@ -0,0 +1,381 @@ +import defaultKy from 'ky' + +import { getEnv } from '@/env' +import { isArray, isString } from '@/utils' + +export const STABILITY_API_BASE_URL = 'https://api.stability.ai' +export const STABILITY_DEFAULT_IMAGE_GENERATION_ENGINE_ID = + 'stable-diffusion-512-v2-1' +export const STABILITY_DEFAULT_IMAGE_UPSCALE_ENGINE_ID = 'esrgan-v1-x2plus' +export const STABILITY_DEFAULT_IMAGE_MASKING_ENGINE_ID = + 'stable-inpainting-512-v2-0' + +export interface StabilityTextToImageOptions { + textPrompts: StabilityTextToImagePrompt[] + + engineId?: string + width?: number + height?: number + cfgScale?: number + clipGuidancePreset?: string + sampler?: StabilityTextToImageSampler + samples?: number + seed?: number + steps?: number + stylePreset?: StabilityImageStylePreset +} + +export interface StabilityImageToImageOptions { + // if initImage is a string, it should be encoed as `binary` + initImage: string | Buffer | Blob + textPrompts: StabilityTextToImagePrompt[] + + initImageMode?: 'IMAGE_STRENGTH' | 'STEP_SCHEDULE' + imageStrength?: number + + engineId?: string + cfgScale?: number + clipGuidancePreset?: string + sampler?: StabilityTextToImageSampler + samples?: number + seed?: number + steps?: number + stylePreset?: StabilityImageStylePreset +} + +export interface StabilityImageToImageMaskingOptions + extends StabilityImageToImageOptions { + maskImage: string | Buffer | Blob + maskSource?: 'MASK_IMAGE_WHITE' | 'MASK_IMAGE_BLACK' | 'INIT_IMAGE_ALPHA' +} + +export interface StabilityImageUpscaleOptions { + image: string | Buffer | Blob + + // only available on certain engines + textPrompts?: StabilityTextToImagePrompt[] + + // should specify either width or height + width?: number + height?: number + + engineId?: string + + // only available on certain engines + cfgScale?: number + seed?: number + steps?: number +} + +export type StabilityTextToImagePrompt = { + text: string + weight?: number +} + +export type StabilityImageStylePreset = + | 'enhance' + | 'anime' + | 'photographic' + | 'digital-art' + | 'comic-book' + | 'fantasy-art' + | 'line-art' + | 'analog-film' + | 'neon-punk' + | 'isometric' + | 'low-poly' + | 'origami' + | 'modeling-compound' + | 'cinematic' + | '3d-model' + | 'pixel-art' + | 'tile-texture' + +export type StabilityTextToImageSampler = + | 'DDIM' + | 'DDPM' + | 'K_DPMPP_2M' + | 'K_DPMPP_2S_ANCESTRAL' + | 'K_DPM_2' + | 'K_DPM_2_ANCESTRAL' + | 'K_EULER' + | 'K_EULER_ANCESTRAL' + | 'K_HEUN' + | 'K_LMS' + +export interface StabilityEngine { + description: string + id: string + name: string + type: string +} +export type StabilityListEnginesResponse = StabilityEngine[] + +export interface StabilityTextToImageResponse { + artifacts: Array<{ + base64: string + finishReason: 'CONTENT_FILTERED' | 'ERROR' | 'SUCCESS' + seed: number + }> +} + +export type StabilityImageToImageResponse = StabilityTextToImageResponse + +export interface StabilityUserAccountResponse { + email: string + id: string + profile_picture?: string + organizations: Array<{ + id: string + is_default: boolean + name: string + role: string + }> +} + +export interface StabilityUserBalanceResponse { + credits: number +} + +export class StabilityClient { + public readonly api: typeof defaultKy + public readonly apiKey: string + public readonly apiBaseUrl: string + + constructor({ + apiKey = getEnv('STABILITY_API_KEY'), + apiBaseUrl = STABILITY_API_BASE_URL, + ky = defaultKy, + organizationId = getEnv('STABILITY_ORGANIZATION_ID'), + clientId = '@agentic/stability', + clientVersion + }: { + apiKey?: string + apiBaseUrl?: string + ky?: typeof defaultKy + organizationId?: string + clientId?: string + clientVersion?: string + } = {}) { + if (!apiKey) { + throw new Error(`Error StabilityClient missing required "apiKey"`) + } + + this.apiKey = apiKey + this.apiBaseUrl = apiBaseUrl + + this.api = ky.extend({ + prefixUrl: apiBaseUrl, + headers: { + Authorization: `Bearer ${this.apiKey}`, + Organization: organizationId, + 'Stability-Client-ID': clientId, + 'Stability-Client-Version': clientVersion + } + }) + } + + /** + * Generates a new image from a text prompt. Can also generate multiple images + * from an array of text prompts. + * + * @see https://platform.stability.ai/rest-api#tag/v1generation/operation/textToImage + */ + async textToImage( + promptOrTextToImageOptions: string | string[] | StabilityTextToImageOptions + ) { + const defaultOptions: Partial = { + engineId: STABILITY_DEFAULT_IMAGE_GENERATION_ENGINE_ID + } + + const options: StabilityTextToImageOptions = isString( + promptOrTextToImageOptions + ) + ? { + ...defaultOptions, + textPrompts: [ + { + text: promptOrTextToImageOptions + } + ] + } + : isArray(promptOrTextToImageOptions) + ? { + ...defaultOptions, + textPrompts: promptOrTextToImageOptions.map((text) => ({ text })) + } + : { + ...defaultOptions, + ...promptOrTextToImageOptions + } + + return this.api + .post(`v1/generation/${options.engineId}/text-to-image`, { + json: { + text_prompts: options.textPrompts, + width: options.width, + height: options.height, + cfg_scale: options.cfgScale, + clip_guidance_preset: options.clipGuidancePreset, + sampler: options.sampler, + samples: options.samples, + seed: options.seed, + steps: options.steps, + style_preset: options.stylePreset + } + }) + .json() + } + + /** + * Modifies an initial image based on a text prompt. + * + * @see https://platform.stability.ai/rest-api#tag/v1generation/operation/imageToImage + */ + async imageToImage(opts: StabilityImageToImageOptions) { + const { engineId = STABILITY_DEFAULT_IMAGE_GENERATION_ENGINE_ID } = opts + + const body = createFormData( + opts, + { + textPrompts: 'text_prompts', + initImageMode: 'init_image_mode', + cfgScale: 'cfg_scale', + clipGuidancePreset: 'clip_guidance_preset', + sampler: 'sampler', + samples: 'samples', + seed: 'seed', + steps: 'steps', + stylePreset: 'style_preset' + }, + { + initImage: 'init_image' + } + ) + + return this.api + .post(`v1/generation/${engineId}/image-to-image`, { + body + }) + .json() + } + + /** + * Creates a higher resolution version of an input image. + * + * @see https://platform.stability.ai/rest-api#tag/v1generation/operation/upscaleImage + */ + async upscaleImage(opts: StabilityImageUpscaleOptions) { + const { engineId = STABILITY_DEFAULT_IMAGE_UPSCALE_ENGINE_ID } = opts + + const body = createFormData( + opts, + { + textPrompts: 'text_prompts', + cfgScale: 'cfg_scale', + width: 'width', + height: 'height', + seed: 'seed', + steps: 'steps' + }, + { + image: 'image' + } + ) + + return this.api + .post(`v1/generation/${engineId}/image-to-image/upscale`, { + body + }) + .json() + } + + /** + * Selectively modifies portions of an initial image using a mask image. + * + * @see https://platform.stability.ai/rest-api#tag/v1generation/operation/masking + */ + async maskImage(opts: StabilityImageToImageMaskingOptions) { + const { engineId = STABILITY_DEFAULT_IMAGE_MASKING_ENGINE_ID } = opts + + const body = createFormData( + { ...opts, maskSource: 'MASK_IMAGE_BLACK' }, + { + textPrompts: 'text_prompts', + initImageMode: 'init_image_mode', + maskSource: 'mask_source', + cfgScale: 'cfg_scale', + clipGuidancePreset: 'clip_guidance_preset', + sampler: 'sampler', + samples: 'samples', + seed: 'seed', + steps: 'steps', + stylePreset: 'style_preset' + }, + { + initImage: 'init_image', + maskImage: 'mask_image' + } + ) + + return this.api + .post(`v1/generation/${engineId}/image-to-image/masking`, { + body + }) + .json() + } + + /** + * Lists the available engines (e.g., models). + * + * @see https://platform.stability.ai/rest-api#tag/v1engines + */ + async listEngines() { + return this.api.get('v1/engines/list').json() + } + + /** + * Gets information about the user associated with this account. + * + * @see https://platform.stability.ai/rest-api#tag/v1user/operation/userAccount + */ + async getUserAccount() { + return this.api.get('v1/user/account').json() + } + + /** + * Gets the credit balance of the account/organization associated with this account. + * + * @see https://platform.stability.ai/rest-api#tag/v1user/operation/userBalance + */ + async getUserBalance() { + return this.api.get('v1/user/balance').json() + } +} + +function createFormData( + data: Record, + jsonKeys: Record, + imageKeys?: Record +) { + const formData = new FormData() + + for (const [key, key2] of Object.entries(imageKeys || {})) { + const value = data[key] + if (value !== undefined) { + formData.append( + key2, + Buffer.isBuffer(value) ? value.toString('binary') : value + ) + } + } + + for (const [key, key2] of Object.entries(jsonKeys)) { + const value = data[key] + if (value !== undefined) { + formData.append(key2, JSON.stringify(value)) + } + } + + return formData +} diff --git a/src/utils.ts b/src/utils.ts index ef9ab554..9e44045d 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -246,3 +246,7 @@ export function isFunction(value: any): value is Function { export function isString(value: any): value is string { return typeof value === 'string' } + +export function isArray(value: any): value is any[] { + return Array.isArray(value) +} diff --git a/test/_utils.ts b/test/_utils.ts index feb038b8..de10a2b9 100644 --- a/test/_utils.ts +++ b/test/_utils.ts @@ -45,12 +45,12 @@ keyv.has = async (key, ...rest) => { function getCacheKeyForRequest(request: Request): string | null { const method = request.method.toLowerCase() - if (method === 'get') { + if (method === 'get' || method === 'head' || method === 'options') { const url = normalizeUrl(request.url) if (url) { const cacheParams = { - // TODO: request.headers isn't a normal JS object... + // TODO: request.headers isn't a normal JS object headers: { ...request.headers } } diff --git a/test/services/stability.test.ts b/test/services/stability.test.ts new file mode 100644 index 00000000..036cfbd0 --- /dev/null +++ b/test/services/stability.test.ts @@ -0,0 +1,73 @@ +import test from 'ava' + +// import fs from 'fs/promises' +import { StabilityClient } from '@/services' + +import { ky } from '../_utils' + +const isStabilityTestingEnabled = false + +test('StabilityClient.listEngines', async (t) => { + if (!process.env.STABILITY_API_KEY || !isStabilityTestingEnabled) { + return t.pass() + } + + t.timeout(2 * 60 * 1000) + const client = new StabilityClient({ ky }) + + const result = await client.listEngines() + // console.log(result) + t.true(Array.isArray(result)) +}) + +test('StabilityClient.textToImage string', async (t) => { + if (!process.env.STABILITY_API_KEY || !isStabilityTestingEnabled) { + return t.pass() + } + + t.timeout(2 * 60 * 1000) + const client = new StabilityClient({ ky }) + + const result = await client.textToImage('tiny baby kittens, kawaii, anime') + // console.log(result) + + t.is(result.artifacts.length, 1) + t.truthy(result.artifacts[0].base64) + t.is(result.artifacts[0].finishReason, 'SUCCESS') + + // await fs.writeFile( + // 'out.png', + // Buffer.from(result.artifacts[0].base64, 'base64') + // ) +}) + +test('StabilityClient.textToImage full params', async (t) => { + if (!process.env.STABILITY_API_KEY || !isStabilityTestingEnabled) { + return t.pass() + } + + t.timeout(2 * 60 * 1000) + const client = new StabilityClient({ ky }) + + const result = await client.textToImage({ + engineId: 'stable-diffusion-xl-beta-v2-2-2', + textPrompts: [{ text: 'smol kittens, kawaii, cute, anime' }], + width: 512, + height: 512, + cfgScale: 7, + clipGuidancePreset: 'FAST_BLUE', + stylePreset: 'anime', + samples: 1, + steps: 30 + }) + // console.log(result) + + t.is(result.artifacts.length, 1) + t.truthy(result.artifacts[0].base64) + t.is(result.artifacts[0].finishReason, 'SUCCESS') + + // await fs.writeFile( + // 'out.png', + // Buffer.from(result.artifacts[0].base64, 'base64') + // ) +}) diff --git a/test/services/weather.test.ts b/test/services/weather.test.ts index b192facd..3a497c4d 100644 --- a/test/services/weather.test.ts +++ b/test/services/weather.test.ts @@ -48,7 +48,7 @@ test('WeatherTool.call', async (t) => { // console.log(result) t.truthy(result.current) t.truthy(result.location) - t.is(result.location.name, 'Brooklyn') - t.is(result.location.region, 'New York') - t.is(result.location.country, 'USA') + t.is(result.location!.name, 'Brooklyn') + t.is(result.location!.region, 'New York') + t.is(result.location!.country, 'USA') })