kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add Replicate stable diffusion tool
rodzic
cd4a2a9236
commit
ef7e350ddc
|
@ -0,0 +1,37 @@
|
||||||
|
import 'dotenv/config'
|
||||||
|
import { OpenAIClient } from 'openai-fetch'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
import { Agentic, ReplicateStableDiffusionTool } from '@/index'
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const openai = new OpenAIClient({
|
||||||
|
apiKey: process.env.OPENAI_API_KEY!,
|
||||||
|
fetchOptions: {
|
||||||
|
timeout: false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
const agentic = new Agentic({ openai })
|
||||||
|
|
||||||
|
const topic = process.argv[2] || 'san francisco'
|
||||||
|
|
||||||
|
const res = await agentic
|
||||||
|
.gpt4(
|
||||||
|
`Generate {{numImages}} images of {{topic}}. Use prompts that are artistic and creative.`
|
||||||
|
)
|
||||||
|
.modelParams({ temperature: 1.0 })
|
||||||
|
.tools([new ReplicateStableDiffusionTool()])
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
numImages: z.number().default(5),
|
||||||
|
topic: z.string()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.call({
|
||||||
|
topic
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log(`\n\n\n${res}\n\n\n`)
|
||||||
|
}
|
||||||
|
|
||||||
|
main()
|
|
@ -63,6 +63,7 @@
|
||||||
"pino": "^8.14.1",
|
"pino": "^8.14.1",
|
||||||
"pino-pretty": "^10.0.0",
|
"pino-pretty": "^10.0.0",
|
||||||
"quick-lru": "^6.1.1",
|
"quick-lru": "^6.1.1",
|
||||||
|
"replicate": "^0.12.3",
|
||||||
"ts-dedent": "^2.2.0",
|
"ts-dedent": "^2.2.0",
|
||||||
"uuid": "^9.0.0",
|
"uuid": "^9.0.0",
|
||||||
"zod": "^3.21.4",
|
"zod": "^3.21.4",
|
||||||
|
|
|
@ -77,6 +77,9 @@ dependencies:
|
||||||
quick-lru:
|
quick-lru:
|
||||||
specifier: ^6.1.1
|
specifier: ^6.1.1
|
||||||
version: 6.1.1
|
version: 6.1.1
|
||||||
|
replicate:
|
||||||
|
specifier: ^0.12.3
|
||||||
|
version: 0.12.3
|
||||||
ts-dedent:
|
ts-dedent:
|
||||||
specifier: ^2.2.0
|
specifier: ^2.2.0
|
||||||
version: 2.2.0
|
version: 2.2.0
|
||||||
|
@ -3884,6 +3887,11 @@ packages:
|
||||||
functions-have-names: 1.2.3
|
functions-have-names: 1.2.3
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/replicate@0.12.3:
|
||||||
|
resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
|
||||||
|
engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/require-directory@2.1.1:
|
/require-directory@2.1.1:
|
||||||
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
|
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
|
||||||
engines: {node: '>=0.10.0'}
|
engines: {node: '>=0.10.0'}
|
||||||
|
|
|
@ -3,6 +3,7 @@ export * from './diffbot'
|
||||||
export * from './metaphor'
|
export * from './metaphor'
|
||||||
export * from './midjourney'
|
export * from './midjourney'
|
||||||
export * from './novu'
|
export * from './novu'
|
||||||
|
export * from './replicate'
|
||||||
export * from './search-and-crawl'
|
export * from './search-and-crawl'
|
||||||
export * from './serpapi'
|
export * from './serpapi'
|
||||||
export * from './weather'
|
export * from './weather'
|
||||||
|
|
|
@ -0,0 +1,113 @@
|
||||||
|
import pMap from 'p-map'
|
||||||
|
import Replicate from 'replicate'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
import * as types from '@/types'
|
||||||
|
import { getEnv } from '@/env'
|
||||||
|
import { BaseTask } from '@/task'
|
||||||
|
|
||||||
|
const REPLICATE_SD_MODEL =
|
||||||
|
'stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf'
|
||||||
|
|
||||||
|
const ReplicateInputSchema = z.object({
|
||||||
|
images: z.array(
|
||||||
|
z.object({
|
||||||
|
prompt: z
|
||||||
|
.string()
|
||||||
|
.describe(
|
||||||
|
'Simple, short, comma-separated list of phrases which describes the image you want to generate'
|
||||||
|
),
|
||||||
|
negativePrompt: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.describe(
|
||||||
|
'Simple, short, comma-separated list of phrases which describes qualities of the image you do NOT want to generate. Example: low quality, blurry, pixelated'
|
||||||
|
)
|
||||||
|
// seed: z.number().int().optional()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
type ReplicateInput = z.infer<typeof ReplicateInputSchema>
|
||||||
|
|
||||||
|
const ReplicateOutputSchema = z.object({
|
||||||
|
images: z.array(z.string())
|
||||||
|
})
|
||||||
|
type ReplicateOutput = z.infer<typeof ReplicateOutputSchema>
|
||||||
|
|
||||||
|
export class ReplicateStableDiffusionTool extends BaseTask<
|
||||||
|
ReplicateInput,
|
||||||
|
ReplicateOutput
|
||||||
|
> {
|
||||||
|
protected _replicateClient: Replicate
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
opts: {
|
||||||
|
replicate?: Replicate
|
||||||
|
} & types.BaseTaskOptions = {}
|
||||||
|
) {
|
||||||
|
super(opts)
|
||||||
|
|
||||||
|
this._replicateClient =
|
||||||
|
opts.replicate ||
|
||||||
|
new Replicate({
|
||||||
|
auth: getEnv('REPLICATE_API_KEY')!
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get inputSchema() {
|
||||||
|
return ReplicateInputSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get outputSchema() {
|
||||||
|
return ReplicateOutputSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get nameForModel(): string {
|
||||||
|
return 'replicateStableDiffusion'
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get descForModel(): string {
|
||||||
|
return 'Creates one or more images from a prompt using the Replicate stable diffusion API. Useful for generating images on the fly.'
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override async _call(
|
||||||
|
ctx: types.TaskCallContext<ReplicateInput>
|
||||||
|
): Promise<ReplicateOutput> {
|
||||||
|
const images = (
|
||||||
|
await pMap(
|
||||||
|
ctx.input!.images,
|
||||||
|
async (image) => {
|
||||||
|
try {
|
||||||
|
const input = {
|
||||||
|
prompt: image.prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
if (image.negativePrompt) {
|
||||||
|
input['negative_prompt'] = image.negativePrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('>>> replicate', image)
|
||||||
|
const output: any = await this._replicateClient.run(
|
||||||
|
REPLICATE_SD_MODEL,
|
||||||
|
{
|
||||||
|
input
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
console.log('<<< replicate', image, output)
|
||||||
|
return output
|
||||||
|
} catch (err) {
|
||||||
|
this._logger.error(err, 'Replicate API error')
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
concurrency: 5
|
||||||
|
}
|
||||||
|
)
|
||||||
|
).flat()
|
||||||
|
|
||||||
|
// console.log('replicate output', images)
|
||||||
|
return this.outputSchema.parse({ images })
|
||||||
|
}
|
||||||
|
}
|
Ładowanie…
Reference in New Issue