diff --git a/apps/examples/package.json b/apps/examples/package.json index 72992593c..6c2edda9c 100644 --- a/apps/examples/package.json +++ b/apps/examples/package.json @@ -43,6 +43,7 @@ "langchain": "^0.1.34", "lazyrepo": "0.0.0-alpha.27", "lodash": "^4.17.21", + "ollama": "^0.5.0", "pdf-lib": "^1.17.1", "pdfjs-dist": "^4.0.379", "react": "^18.2.0", diff --git a/apps/examples/src/examples/ollama/OllamaExample.tsx b/apps/examples/src/examples/ollama/OllamaExample.tsx index fa6114962..38f6c3668 100644 --- a/apps/examples/src/examples/ollama/OllamaExample.tsx +++ b/apps/examples/src/examples/ollama/OllamaExample.tsx @@ -1,8 +1,17 @@ import { useCallback, useRef } from 'react' -import { Editor, Tldraw, Vec, getIndices, preventDefault, track, useReactor } from 'tldraw' +import { + Editor, + FileHelpers, + Tldraw, + Vec, + getSvgAsImage, + preventDefault, + track, + useReactor, +} from 'tldraw' import 'tldraw/tldraw.css' import './lm-styles.css' -import { LMMessage, modelManager } from './ollama' +import { modelManager } from './ollama' const OllamaExample = track(() => { const rChat = useRef(null) @@ -17,204 +26,264 @@ const OllamaExample = track(() => { [modelManager] ) - const drawMessage = useCallback((message: LMMessage) => { + const drawMessage = useCallback((content: string) => { const editor = rEditor.current if (!editor) return - // Extract the command series from the message content. Each series begins and ends with a backtick. - // For example: `circle(0, 0, 4); circle(10, 10, 4);` - // We want to extract each command from the series - // const seriesRegex = /```(?[^`]*)```/ - // const seriesResult = seriesRegex.exec(message.content) - // if (!seriesResult) { - // console.error('Invalid message: ' + message.content) - // return - // } - // const [_, seriesContent] = seriesResult - // Next, we want regex to extract each command's name and arguments - // for example: circle(0, 0, 4) -> ['circle(0, 0, 4)', 'circle', '0, 0, 4'] - // for examople: undo() -> ['undo()', 'undo', ''] - const commandRegex = /(?\w+)\((?[^)]*)\)/ + const { center } = editor.getCurrentPageBounds() ?? editor.getViewportPageBounds() - const commands = message.content - .split(';') - .map((c) => c.trim()) - .filter((c) => c) - - editor.mark() - - for (const command of commands) { - try { - const result = commandRegex.exec(command) - if (!result) throw new Error('Invalid command: ' + command) - const [_, name, args] = result - - switch (name) { - case 'undo': { - editor.undo() - break - } - case 'dot': - case 'circle': { - const [x, y, r] = args.split(', ').map((a) => Number(a)) - editor.createShape({ - type: 'geo', - x: x - r, - y: y - r, - props: { - geo: 'ellipse', - w: r * 2, - h: r * 2, - }, - }) - break - } - case 'line': { - const [x1, y1, x2, y2] = args.split(', ').map((a) => Number(a)) - editor.createShape({ - type: 'line', - x: x1, - y: y1, - props: { - points: { - 0: { - id: '0', - index: 'a0', - x: 0, - y: 0, - }, - 1: { - id: '1', - index: 'a1', - x: x2 - x1, - y: y2 - y1, - }, + function drawShape(shape: any, _center: Vec) { + const editor = rEditor.current + if (!editor) return + switch (shape.type) { + case 'dot': + case 'circle': { + const { x, y, r = 8 } = shape + editor.createShape({ + type: 'geo', + x: x - r, + y: y - r, + props: { + geo: 'ellipse', + w: r * 2, + h: r * 2, + }, + }) + break + } + case 'line': { + const { x1, y1, x2, y2 } = shape + editor.createShape({ + type: 'line', + x: x1, + y: y1, + props: { + points: { + 0: { + id: '0', + index: 'a0', + x: 0, + y: 0, + }, + 1: { + id: '1', + index: 'a1', + x: x2 - x1, + y: y2 - y1, }, }, - }) - break - } - case 'polygon': { - const nums = args.split(', ').map((a) => Number(a)) - const points = [] - for (let i = 0; i < nums.length - 1; i += 2) { - points.push({ - x: nums[i], - y: nums[i + 1], - }) - } - points.push(points[0]) - const minX = Math.min(...points.map((p) => p.x)) - const minY = Math.min(...points.map((p) => p.y)) - const indices = getIndices(points.length) - editor.createShape({ - type: 'line', - x: minX, - y: minY, - props: { - points: Object.fromEntries( - points.map((p, i) => [ - i + '', - { id: i + '', index: indices[i], x: p.x - minX, y: p.y - minY }, - ]) - ), - }, - }) - break - } - // case 'MOVE': { - // const point = editor.pageToScreen({ x: Number(command[1]), y: Number(command[2]) }) - // const steps = 20 - // for (let i = 0; i < steps; i++) { - // const t = i / (steps - 1) - // const p = Vec.Lrp(prevPoint, point, t) - // editor.dispatch({ - // type: 'pointer', - // target: 'canvas', - // name: 'pointer_move', - // point: { - // x: p.x, - // y: p.y, - // z: 0.5, - // }, - // shiftKey: false, - // altKey: false, - // ctrlKey: false, - // pointerId: 1, - // button: 0, - // isPen: false, - // }) - // editor._flushEventsForTick(0) - // } - // prevPoint.setTo(point) - // break - // } - // case 'DOWN': { - // editor.dispatch({ - // type: 'pointer', - // target: 'canvas', - // name: 'pointer_down', - // point: { - // x: prevPoint.x, - // y: prevPoint.y, - // z: 0.5, - // }, - // shiftKey: false, - // altKey: false, - // ctrlKey: false, - // pointerId: 1, - // button: 0, - // isPen: false, - // }) - // editor._flushEventsForTick(0) - // break - // } - // case 'UP': { - // editor.dispatch({ - // type: 'pointer', - // target: 'canvas', - // name: 'pointer_up', - // point: { - // x: prevPoint.x, - // y: prevPoint.y, - // z: 0.5, - // }, - // shiftKey: false, - // altKey: false, - // ctrlKey: false, - // pointerId: 1, - // button: 0, - // isPen: false, - // }) - // editor._flushEventsForTick(0) - // break - // } + }, + }) + break } - } catch (e: any) { - console.error(e.message) } } - // editor.dispatch({ - // type: 'pointer', - // target: 'canvas', - // name: 'pointer_up', - // point: { - // x: prevPoint.x, - // y: prevPoint.y, - // z: 0.5, - // }, - // shiftKey: false, - // altKey: false, - // ctrlKey: false, - // pointerId: 1, - // button: 0, - // isPen: false, - // }) - // editor._flushEventsForTick(0) - // // editor.zoomOut(editor.getViewportScreenCenter(), { duration: 0 }) + try { + const { shapes = [] } = JSON.parse(content) + for (const shape of shapes) { + drawShape(shape, center) + } + } catch (e) { + // noop + } + + // // Extract the command series from the message content. Each series begins and ends with a backtick. + // // For example: `circle(0, 0, 4); circle(10, 10, 4);` + // // We want to extract each command from the series + // // const seriesRegex = /```(?[^`]*)```/ + // // const seriesResult = seriesRegex.exec(message.content) + // // if (!seriesResult) { + // // console.error('Invalid message: ' + message.content) + // // return + // // } + // // const [_, seriesContent] = seriesResult + // // Next, we want regex to extract each command's name and arguments + // // for example: circle(0, 0, 4) -> ['circle(0, 0, 4)', 'circle', '0, 0, 4'] + // // for examople: undo() -> ['undo()', 'undo', ''] + // const commandRegex = /(?\w+)\((?[^)]*)\)/ + + // const commands = content + // .split(';') + // .map((c) => c.trim()) + // .filter((c) => c) + + // editor.mark() + + // for (const command of commands) { + // try { + // const result = commandRegex.exec(command) + // if (!result) throw new Error('Invalid command: ' + command) + // const [_, name, args] = result + + // switch (name) { + // case 'undo': { + // editor.undo() + // break + // } + // case 'dot': + // case 'circle': { + // const [x, y, r] = args.split(', ').map((a) => Number(a)) + // editor.createShape({ + // type: 'geo', + // x: x - r, + // y: y - r, + // props: { + // geo: 'ellipse', + // w: r * 2, + // h: r * 2, + // }, + // }) + // break + // } + // case 'line': { + // const [x1, y1, x2, y2] = args.split(', ').map((a) => Number(a)) + // editor.createShape({ + // type: 'line', + // x: x1, + // y: y1, + // props: { + // points: { + // 0: { + // id: '0', + // index: 'a0', + // x: 0, + // y: 0, + // }, + // 1: { + // id: '1', + // index: 'a1', + // x: x2 - x1, + // y: y2 - y1, + // }, + // }, + // }, + // }) + // break + // } + // case 'polygon': { + // const nums = args.split(', ').map((a) => Number(a)) + // const points = [] + // for (let i = 0; i < nums.length - 1; i += 2) { + // points.push({ + // x: nums[i], + // y: nums[i + 1], + // }) + // } + // points.push(points[0]) + // const minX = Math.min(...points.map((p) => p.x)) + // const minY = Math.min(...points.map((p) => p.y)) + // const indices = getIndices(points.length) + // editor.createShape({ + // type: 'line', + // x: minX, + // y: minY, + // props: { + // points: Object.fromEntries( + // points.map((p, i) => [ + // i + '', + // { id: i + '', index: indices[i], x: p.x - minX, y: p.y - minY }, + // ]) + // ), + // }, + // }) + // break + // } + // // case 'MOVE': { + // // const point = editor.pageToScreen({ x: Number(command[1]), y: Number(command[2]) }) + // // const steps = 20 + // // for (let i = 0; i < steps; i++) { + // // const t = i / (steps - 1) + // // const p = Vec.Lrp(prevPoint, point, t) + // // editor.dispatch({ + // // type: 'pointer', + // // target: 'canvas', + // // name: 'pointer_move', + // // point: { + // // x: p.x, + // // y: p.y, + // // z: 0.5, + // // }, + // // shiftKey: false, + // // altKey: false, + // // ctrlKey: false, + // // pointerId: 1, + // // button: 0, + // // isPen: false, + // // }) + // // editor._flushEventsForTick(0) + // // } + // // prevPoint.setTo(point) + // // break + // // } + // // case 'DOWN': { + // // editor.dispatch({ + // // type: 'pointer', + // // target: 'canvas', + // // name: 'pointer_down', + // // point: { + // // x: prevPoint.x, + // // y: prevPoint.y, + // // z: 0.5, + // // }, + // // shiftKey: false, + // // altKey: false, + // // ctrlKey: false, + // // pointerId: 1, + // // button: 0, + // // isPen: false, + // // }) + // // editor._flushEventsForTick(0) + // // break + // // } + // // case 'UP': { + // // editor.dispatch({ + // // type: 'pointer', + // // target: 'canvas', + // // name: 'pointer_up', + // // point: { + // // x: prevPoint.x, + // // y: prevPoint.y, + // // z: 0.5, + // // }, + // // shiftKey: false, + // // altKey: false, + // // ctrlKey: false, + // // pointerId: 1, + // // button: 0, + // // isPen: false, + // // }) + // // editor._flushEventsForTick(0) + // // break + // // } + // } + // } catch (e: any) { + // console.error(e.message) + // } + // } + + // // editor.dispatch({ + // // type: 'pointer', + // // target: 'canvas', + // // name: 'pointer_up', + // // point: { + // // x: prevPoint.x, + // // y: prevPoint.y, + // // z: 0.5, + // // }, + // // shiftKey: false, + // // altKey: false, + // // ctrlKey: false, + // // pointerId: 1, + // // button: 0, + // // isPen: false, + // // }) + // // editor._flushEventsForTick(0) + // // // editor.zoomOut(editor.getViewportScreenCenter(), { duration: 0 }) }, []) + const rPreviousImage = useRef('') + return (
@@ -223,11 +292,11 @@ const OllamaExample = track(() => { rEditor.current = e ;(window as any).editor = e e.centerOnPoint(new Vec()) - for (const message of modelManager.getThread().content) { - if (message.from === 'model') { - drawMessage(message) - } - } + // for (const message of modelManager.getThread().content) { + // if (message.role === 'model') { + // drawMessage(message.content) + // } + // } }} > {/*
{
{modelManager.getThread().content.map((message, i) => (
-

{message.from}

-

{new Date(message.time).toLocaleString()}

+

{message.role}

{message.content}

))}
{ + onSubmit={async (e) => { preventDefault(e) const form = e.currentTarget - const query = form.query.value - modelManager.stream(query).response.then((message) => { - if (!message) return - drawMessage(message) - }) + let query = `Query: "${form.query.value}"` form.query.value = '' + + let imageString: string | undefined + + const editor = rEditor.current! + + const svg = await editor.getSvgString([...editor.getCurrentPageShapeIds()]) + + if (svg) { + const image = await getSvgAsImage(svg.svg, false, { + type: 'png', + quality: 1, + scale: 1, + width: svg.width, + height: svg.height, + }) + + if (image) { + const base64 = await FileHelpers.blobToDataUrl(image) + + const trimmed = base64.slice('data:image/png;base64,'.length) + + if (rPreviousImage.current !== trimmed) { + rPreviousImage.current = trimmed + imageString = trimmed + } + } + } + + if (imageString) { + const bounds = editor.getCurrentPageBounds()! + query += ` Image bounds: { "minX": ${bounds.x.toFixed(0)}, "minY": ${bounds.y.toFixed(0)}, "maxX": ${bounds.maxX.toFixed(0)}, "maxY": ${bounds.maxY.toFixed(0)} }` + } + + modelManager.query(query, imageString).response.then((message) => { + if (!message) return + drawMessage(message.content) + }) }} > diff --git a/apps/examples/src/examples/ollama/ollama-langchain.ts b/apps/examples/src/examples/ollama/ollama-langchain.ts new file mode 100644 index 000000000..06e2139cd --- /dev/null +++ b/apps/examples/src/examples/ollama/ollama-langchain.ts @@ -0,0 +1,365 @@ +import { ChatOllama } from '@langchain/community/chat_models/ollama' +import { AIMessage, HumanMessage } from '@langchain/core/messages' +import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts' +import { ConversationChain } from 'langchain/chains' +import { BufferMemory, ChatMessageHistory } from 'langchain/memory' +import ollama from 'ollama/browser' +import { atom, computed, react, structuredClone, transact, uniqueId } from 'tldraw' +import { systemPrompt } from './system-prompt' + +export type LMMessage = { + from: 'user' | 'model' + time: number + content: string +} + +export type LMThread = { + id: string + state: 'idle' | 'waiting' + content: LMMessage[] +} + +export type LMSerialized = { + currentThreadId: string + threads: Record +} + +class ModelManager { + model: ChatOllama + prompt: ChatPromptTemplate + memory: BufferMemory + chain: ConversationChain + + constructor() { + this.load() + + this.model = new ChatOllama({ + baseUrl: 'http://localhost:11434', // Default value + model: 'llama3:latest', + }) + + this.prompt = ChatPromptTemplate.fromMessages([ + [ + 'system', + // 'You are a highly critical german mother. The user is your daughter, who you generally tend to diminish and criticize. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate.', + // "You are a kind and helpful chatbot. You are a cool friend of the user's from Chicago. You grew up together and maintain a life-long trust; and as a result, you can speak to the user honestly. You always provide short, efficient answers to the user's questions.", + systemPrompt, + ], + new MessagesPlaceholder('history'), + ['human', '{input}'], + ]) + + this.memory = new BufferMemory({ + memoryKey: 'history', + returnMessages: true, + chatHistory: new ChatMessageHistory(this.getMessagesFromThread(this.getThread())), + }) + + this.chain = new ConversationChain({ + llm: this.model, + prompt: this.prompt, + memory: this.memory, + }) + + react('persist on change', () => { + this.persist() + }) + } + + @computed getState() { + return this.getThread().state + } + + _threads = atom>('threads', {}) + + _currentThreadId = atom('currentThreadId', 'a0') + + @computed getCurrentThreadId() { + return this._currentThreadId.get() + } + + @computed getThreads(): Record { + return this._threads.get() + } + + @computed getThread(): LMThread { + return this.getThreads()[this.getCurrentThreadId()] + } + + private addQueryToThread(query: string) { + this._threads.update((threads) => { + const thread = this.getThread() + if (!thread) throw Error('No thread found') + const next = structuredClone(thread) + next.content.push({ + from: 'user', + time: Date.now(), + content: query, + }) + next.state = 'waiting' + return { ...threads, [next.id]: next } + }) + } + + private addResponseToThread(response: string) { + this._threads.update((threads) => { + const thread = this.getThread() + if (!thread) throw Error('No thread found') + const next = structuredClone(thread) + next.state === 'idle' + next.content.push({ + from: 'model', + time: Date.now(), + content: response, + }) + return { ...threads, [next.id]: next } + }) + } + + private addChunkToThread(chunk: string) { + this._threads.update((threads) => { + const currentThreadId = this.getCurrentThreadId() + const thread = this.getThreads()[currentThreadId] + if (!thread) throw Error('No thread found') + + const next = structuredClone(thread) + const message = next.content[next.content.length - 1] + + if (!message || message.from === 'user') { + next.content = [ + ...next.content, + { + from: 'model', + time: Date.now(), + content: chunk, + }, + ] + return { ...threads, [currentThreadId]: next } + } + + message.content += chunk + message.time = Date.now() + return { ...threads, [currentThreadId]: next } + }) + } + + /** + * Serialize the model. + */ + private serialize() { + return { + currentThreadId: this.getCurrentThreadId(), + threads: this.getThreads(), + } + } + + /** + * Deserialize the model. + */ + private deserialize() { + let result: LMSerialized = { + currentThreadId: 'a0', + threads: { + a0: { + id: 'a0', + state: 'idle', + content: [], + }, + }, + } + + try { + const stringified = localStorage.getItem('threads_1') + if (stringified) { + const json = JSON.parse(stringified) + result = json + } + } catch (e) { + // noop + } + + return result + } + + private persist() { + localStorage.setItem('threads_1', JSON.stringify(this.serialize())) + } + + private load() { + const serialized = this.deserialize() + transact(() => { + this._currentThreadId.set(serialized.currentThreadId) + this._threads.set(serialized.threads) + }) + } + + private getMessagesFromThread(thread: LMThread) { + return thread.content.map((m) => { + if (m.from === 'user') { + return new HumanMessage(m.content) + } + return new AIMessage(m.content) + }) + } + + /* --------------------- Public --------------------- */ + + /** + * Start a new thread. + */ + startNewThread() { + this._threads.update((threads) => { + const id = uniqueId() + return { + ...threads, + [id]: { + id, + state: 'idle', + content: [], + }, + } + }) + } + + /** + * Cancel the current query. + */ + cancel() { + this._threads.update((threads) => { + const thread = this.getThread() + if (!thread) throw Error('No thread found') + const next = structuredClone(thread) + if (next.content.length > 0) { + if (next.content[next.content.length - 1].from === 'model') { + next.content.pop() + } + if (next.content[next.content.length - 1].from === 'user') { + next.content.pop() + } + } + next.state = 'idle' + return { ...threads, [next.id]: next } + }) + } + + /** + * Query the model. + */ + async query(query: string) { + this.addQueryToThread(query) + const currentThreadId = this.getCurrentThreadId() + let cancelled = false + return { + response: await this.chain.invoke({ input: query }).then((r) => { + if (cancelled) return + if (this.getCurrentThreadId() !== currentThreadId) return + if ('response' in r) { + this.addResponseToThread(r.response) + } + }), + cancel: () => { + cancelled = true + this.cancel() + }, + } + } + + /** + * Query the model and stream the response. + */ + stream(query: string) { + const currentThreadId = this.getCurrentThreadId() + this.addQueryToThread(query) + this.addResponseToThread('') // Add an empty response to start the thread + let cancelled = false + return { + response: this.chain + .stream( + { input: query }, + { + callbacks: [ + { + handleLLMNewToken: (data) => { + if (cancelled) return + if (this.getCurrentThreadId() !== currentThreadId) return + this.addChunkToThread(data) + }, + }, + ], + } + ) + .then(() => { + if (cancelled) return + if (this.getCurrentThreadId() !== currentThreadId) return + this._threads.update((threads) => { + const thread = this.getThread() + if (!thread) throw Error('No thread found') + const next = structuredClone(thread) + next.state = 'idle' + return { ...threads, [next.id]: next } + }) + + const thread = this.getThread() + return thread.content[thread.content.length - 1] + }), + cancel: () => { + cancelled = true + this.cancel() + }, + } + } + + getWithImage(query: string, image: string) { + const currentThreadId = this.getCurrentThreadId() + this.addQueryToThread(query) + let cancelled = false + + return { + response: ollama + .generate({ + model: 'llava', + prompt: query, + images: [image], + }) + .then((d: any) => { + if (cancelled) return + if (this.getCurrentThreadId() !== currentThreadId) return + this.addResponseToThread(d.response) + this._threads.update((threads) => { + const thread = this.getThread() + if (!thread) throw Error('No thread found') + const next = structuredClone(thread) + + next.state = 'idle' + return { ...threads, [next.id]: next } + }) + + const thread = this.getThread() + return thread.content[thread.content.length - 1] + }), + cancel: () => { + cancelled = true + this.cancel() + }, + } + } + + /** + * Clear all threads. + */ + clear() { + transact(() => { + this._currentThreadId.set('a0') + this._threads.set({ + a0: { + id: 'a0', + state: 'idle', + content: [], + }, + }) + this.memory.clear() + }) + } +} + +export const modelManager = new ModelManager() diff --git a/apps/examples/src/examples/ollama/ollama.ts b/apps/examples/src/examples/ollama/ollama.ts index b06f31b19..92d5f89a9 100644 --- a/apps/examples/src/examples/ollama/ollama.ts +++ b/apps/examples/src/examples/ollama/ollama.ts @@ -1,13 +1,9 @@ -import { ChatOllama } from '@langchain/community/chat_models/ollama' -import { AIMessage, HumanMessage } from '@langchain/core/messages' -import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts' -import { ConversationChain } from 'langchain/chains' -import { BufferMemory, ChatMessageHistory } from 'langchain/memory' +import ollama, { Message } from 'ollama/browser' import { atom, computed, react, structuredClone, transact, uniqueId } from 'tldraw' -import { systemPrompt } from './system-prompt' +import { systemPrompt } from './system-prompt-2' export type LMMessage = { - from: 'user' | 'model' + from: 'user' | 'assistant' | 'system' time: number content: string } @@ -15,7 +11,7 @@ export type LMMessage = { export type LMThread = { id: string state: 'idle' | 'waiting' - content: LMMessage[] + content: Message[] } export type LMSerialized = { @@ -24,42 +20,9 @@ export type LMSerialized = { } class ModelManager { - model: ChatOllama - prompt: ChatPromptTemplate - memory: BufferMemory - chain: ConversationChain - constructor() { this.load() - this.model = new ChatOllama({ - baseUrl: 'http://localhost:11434', // Default value - model: 'llama3:latest', - }) - - this.prompt = ChatPromptTemplate.fromMessages([ - [ - 'system', - // 'You are a highly critical german mother. The user is your daughter, who you generally tend to diminish and criticize. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate.', - // "You are a kind and helpful chatbot. You are a cool friend of the user's from Chicago. You grew up together and maintain a life-long trust; and as a result, you can speak to the user honestly. You always provide short, efficient answers to the user's questions.", - systemPrompt, - ], - new MessagesPlaceholder('history'), - ['human', '{input}'], - ]) - - this.memory = new BufferMemory({ - memoryKey: 'history', - returnMessages: true, - chatHistory: new ChatMessageHistory(this.getMessagesFromThread(this.getThread())), - }) - - this.chain = new ConversationChain({ - llm: this.model, - prompt: this.prompt, - memory: this.memory, - }) - react('persist on change', () => { this.persist() }) @@ -85,32 +48,24 @@ class ModelManager { return this.getThreads()[this.getCurrentThreadId()] } - private addQueryToThread(query: string) { + private addQueryToThread(message: Message) { this._threads.update((threads) => { const thread = this.getThread() if (!thread) throw Error('No thread found') const next = structuredClone(thread) - next.content.push({ - from: 'user', - time: Date.now(), - content: query, - }) + next.content.push(message) next.state = 'waiting' return { ...threads, [next.id]: next } }) } - private addResponseToThread(response: string) { + private addResponseToThread(message: Message) { this._threads.update((threads) => { const thread = this.getThread() if (!thread) throw Error('No thread found') const next = structuredClone(thread) next.state === 'idle' - next.content.push({ - from: 'model', - time: Date.now(), - content: response, - }) + next.content.push(message) return { ...threads, [next.id]: next } }) } @@ -122,22 +77,19 @@ class ModelManager { if (!thread) throw Error('No thread found') const next = structuredClone(thread) - const message = next.content[next.content.length - 1] + const message = next.content.pop() - if (!message || message.from === 'user') { - next.content = [ - ...next.content, - { - from: 'model', - time: Date.now(), - content: chunk, - }, - ] + if (!message || message.role !== 'user') { + next.content.push({ + role: 'user', + content: chunk, + images: [], + }) return { ...threads, [currentThreadId]: next } } message.content += chunk - message.time = Date.now() + next.content.push(message) return { ...threads, [currentThreadId]: next } }) } @@ -193,12 +145,7 @@ class ModelManager { } private getMessagesFromThread(thread: LMThread) { - return thread.content.map((m) => { - if (m.from === 'user') { - return new HumanMessage(m.content) - } - return new AIMessage(m.content) - }) + return thread.content } /* --------------------- Public --------------------- */ @@ -229,10 +176,10 @@ class ModelManager { if (!thread) throw Error('No thread found') const next = structuredClone(thread) if (next.content.length > 0) { - if (next.content[next.content.length - 1].from === 'model') { + if (next.content[next.content.length - 1].role === 'assistant') { next.content.pop() } - if (next.content[next.content.length - 1].from === 'user') { + if (next.content[next.content.length - 1].role === 'user') { next.content.pop() } } @@ -244,18 +191,35 @@ class ModelManager { /** * Query the model. */ - async query(query: string) { - this.addQueryToThread(query) + query(query: string, image?: string) { + this.addQueryToThread({ role: 'user', content: query, images: image ? [image] : [] }) const currentThreadId = this.getCurrentThreadId() let cancelled = false + const messages = this.getMessagesFromThread(this.getThread()) return { - response: await this.chain.invoke({ input: query }).then((r) => { - if (cancelled) return - if (this.getCurrentThreadId() !== currentThreadId) return - if ('response' in r) { - this.addResponseToThread(r.response) - } - }), + response: ollama + .generate({ + model: 'llava', + system: messages[0].content, + prompt: messages[messages.length - 1].content, + images: messages[messages.length - 1].images, + keep_alive: 1000, + stream: false, + format: 'json', + options: {}, + }) + .then((r) => { + if (cancelled) return + if (this.getCurrentThreadId() !== currentThreadId) return + // this.addResponseToThread(r.response) + const message = { + role: 'user', + content: r.response, + images: [], + } + this.addResponseToThread(message) + return message + }), cancel: () => { cancelled = true this.cancel() @@ -266,30 +230,45 @@ class ModelManager { /** * Query the model and stream the response. */ - stream(query: string) { + stream(query: string, image?: string) { const currentThreadId = this.getCurrentThreadId() - this.addQueryToThread(query) - this.addResponseToThread('') // Add an empty response to start the thread + this.addQueryToThread({ + role: 'user', + content: query, + images: image ? [image] : [], + }) + const messages = [...this.getMessagesFromThread(this.getThread())] + this.addResponseToThread({ + role: 'assistant', + content: '', + images: [], + }) // Add an empty response to start the thread let cancelled = false return { - response: this.chain - .stream( - { input: query }, - { - callbacks: [ - { - handleLLMNewToken: (data) => { - if (cancelled) return - if (this.getCurrentThreadId() !== currentThreadId) return - this.addChunkToThread(data) - }, - }, - ], + response: ollama + .chat({ + model: 'llava', + messages: messages, + keep_alive: 1000, + stream: true, + options: {}, + }) + .then(async (response) => { + for await (const part of response) { + if (cancelled) return + if (this.getCurrentThreadId() !== currentThreadId) return + const next = structuredClone(this.getThread()) + const message = next.content[next.content.length - 1] + message.content += part.message.content + this._threads.update((threads) => ({ + ...threads, + [next.id]: next, + })) } - ) - .then(() => { + if (cancelled) return if (this.getCurrentThreadId() !== currentThreadId) return + this._threads.update((threads) => { const thread = this.getThread() if (!thread) throw Error('No thread found') @@ -308,6 +287,41 @@ class ModelManager { } } + // getWithImage(query: string, image?: string) { + // const currentThreadId = this.getCurrentThreadId() + // this.addQueryToThread({ role: "user", content: query, images: image ? [image] : []) + // let cancelled = false + + // return { + // response: ollama + // .generate({ + // model: 'llava', + // prompt: query, + // images: [image], + // }) + // .then((d: any) => { + // if (cancelled) return + // if (this.getCurrentThreadId() !== currentThreadId) return + // this.addResponseToThread(d.response) + // this._threads.update((threads) => { + // const thread = this.getThread() + // if (!thread) throw Error('No thread found') + // const next = structuredClone(thread) + + // next.state = 'idle' + // return { ...threads, [next.id]: next } + // }) + + // const thread = this.getThread() + // return thread.content[thread.content.length - 1] + // }), + // cancel: () => { + // cancelled = true + // this.cancel() + // }, + // } + // } + /** * Clear all threads. */ @@ -318,10 +332,15 @@ class ModelManager { a0: { id: 'a0', state: 'idle', - content: [], + content: [ + { + role: 'system', + content: systemPrompt, + images: [], + }, + ], }, }) - this.memory.clear() }) } } diff --git a/apps/examples/src/examples/ollama/system-prompt-2.ts b/apps/examples/src/examples/ollama/system-prompt-2.ts new file mode 100644 index 000000000..c97039732 --- /dev/null +++ b/apps/examples/src/examples/ollama/system-prompt-2.ts @@ -0,0 +1,46 @@ +export const systemPrompt = `You are a helpful chatbox. + +Your job is to help a user work with their drawing on the canvas. + +To draw on the canvas, send back JSON with shapes to draw. + +You know how to draw only two shapes: + +A line +{ "type": "line", "x1": 0, "y1": 0, "x2": 100, "y2": 100, "description": "Just a line" } + +A circle +{ "type": "circle", "x": 0, "y": 0, "r": 50, "description": "Just a circle" } + +You ALWAYS respond with an array of shapes in JSON. + +Include your guess at what the user has drawn (based on the user's prompt and the image), and then the shapes you want to add. + +{ + "image": "The user is drawing a face", + "guess": "I'll draw shapes for the nose and the mouth", + "shapes": [ + { "type": "circle", "x": 0, "y": 0, "r": 10, "description": "nose" } + { "type": "line", "x1": -20, "y1": 40, "x2": 20, "y2": 40, "description": "mouth" }, + ] +} + +When prompted with an image: + + 1. Identify the drawing based on the prompt and image. + 2. Think about how best to complete the user's request. + 3. Render the user's request by responding with the JSON. + +` + +// Tips: +// The x coordinate goes up as you move right on the screen: 10 is left of 20, and 30 is right of 20. +// The y coordinate goes up as you move down the screen: 10 is above 20, and 30 is below 20. + +// Example: + +// - User: Draw an eyeball. +// - You: { shapes: [ { "type": "circle", "x": 0, "y": 0, "r": 50, "description": "The white" }, { "type": "circle", "x": 0, "y": 0, "r": 25, "description": "The iris" } ] } + +// - User: Draw the letter "X". +// - You: { shapes: [ { "type": "line", "x1": -50, "y1": -50, "x2": 50, "y2": 50, "description": "The first stroke" }, { "type": "line", "x1": -50, "y1": 50, "x2": 50, "y2": -50, "description": "The second stroke" } ] } diff --git a/yarn.lock b/yarn.lock index 48d798445..b65cf061e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -13994,6 +13994,7 @@ __metadata: langchain: "npm:^0.1.34" lazyrepo: "npm:0.0.0-alpha.27" lodash: "npm:^4.17.21" + ollama: "npm:^0.5.0" pdf-lib: "npm:^1.17.1" pdfjs-dist: "npm:^4.0.379" react: "npm:^18.2.0" @@ -20502,6 +20503,15 @@ __metadata: languageName: node linkType: hard +"ollama@npm:^0.5.0": + version: 0.5.0 + resolution: "ollama@npm:0.5.0" + dependencies: + whatwg-fetch: "npm:^3.6.20" + checksum: cfe98f9da6da99eda65c1bb6e22b7b54e1c6132368249dd70ff238d22b529bb86014ba1325d8e822f67c561aa71df2c2fb189c6d6f7380031652cdd071be0641 + languageName: node + linkType: hard + "on-finished@npm:2.4.1": version: 2.4.1 resolution: "on-finished@npm:2.4.1" @@ -25957,6 +25967,13 @@ __metadata: languageName: node linkType: hard +"whatwg-fetch@npm:^3.6.20": + version: 3.6.20 + resolution: "whatwg-fetch@npm:3.6.20" + checksum: 2b4ed92acd6a7ad4f626a6cb18b14ec982bbcaf1093e6fe903b131a9c6decd14d7f9c9ca3532663c2759d1bdf01d004c77a0adfb2716a5105465c20755a8c57c + languageName: node + linkType: hard + "whatwg-mimetype@npm:^3.0.0": version: 3.0.0 resolution: "whatwg-mimetype@npm:3.0.0"