Steve Ruiz 2024-04-26 00:00:13 +01:00
rodzic 64ee3a57b3
commit f245bef4aa
6 zmienionych plików z 851 dodań i 302 usunięć

Wyświetl plik

@ -43,6 +43,7 @@
"langchain": "^0.1.34",
"lazyrepo": "0.0.0-alpha.27",
"lodash": "^4.17.21",
"ollama": "^0.5.0",
"pdf-lib": "^1.17.1",
"pdfjs-dist": "^4.0.379",
"react": "^18.2.0",

Wyświetl plik

@ -1,8 +1,17 @@
import { useCallback, useRef } from 'react'
import { Editor, Tldraw, Vec, getIndices, preventDefault, track, useReactor } from 'tldraw'
import {
Editor,
FileHelpers,
Tldraw,
Vec,
getSvgAsImage,
preventDefault,
track,
useReactor,
} from 'tldraw'
import 'tldraw/tldraw.css'
import './lm-styles.css'
import { LMMessage, modelManager } from './ollama'
import { modelManager } from './ollama'
const OllamaExample = track(() => {
const rChat = useRef<HTMLDivElement>(null)
@ -17,204 +26,264 @@ const OllamaExample = track(() => {
[modelManager]
)
const drawMessage = useCallback((message: LMMessage) => {
const drawMessage = useCallback((content: string) => {
const editor = rEditor.current
if (!editor) return
// Extract the command series from the message content. Each series begins and ends with a backtick.
// For example: `circle(0, 0, 4); circle(10, 10, 4);`
// We want to extract each command from the series
// const seriesRegex = /```(?<commands>[^`]*)```/
// const seriesResult = seriesRegex.exec(message.content)
// if (!seriesResult) {
// console.error('Invalid message: ' + message.content)
// return
// }
// const [_, seriesContent] = seriesResult
// Next, we want regex to extract each command's name and arguments
// for example: circle(0, 0, 4) -> ['circle(0, 0, 4)', 'circle', '0, 0, 4']
// for examople: undo() -> ['undo()', 'undo', '']
const commandRegex = /(?<name>\w+)\((?<args>[^)]*)\)/
const { center } = editor.getCurrentPageBounds() ?? editor.getViewportPageBounds()
const commands = message.content
.split(';')
.map((c) => c.trim())
.filter((c) => c)
editor.mark()
for (const command of commands) {
try {
const result = commandRegex.exec(command)
if (!result) throw new Error('Invalid command: ' + command)
const [_, name, args] = result
switch (name) {
case 'undo': {
editor.undo()
break
}
case 'dot':
case 'circle': {
const [x, y, r] = args.split(', ').map((a) => Number(a))
editor.createShape({
type: 'geo',
x: x - r,
y: y - r,
props: {
geo: 'ellipse',
w: r * 2,
h: r * 2,
},
})
break
}
case 'line': {
const [x1, y1, x2, y2] = args.split(', ').map((a) => Number(a))
editor.createShape({
type: 'line',
x: x1,
y: y1,
props: {
points: {
0: {
id: '0',
index: 'a0',
x: 0,
y: 0,
},
1: {
id: '1',
index: 'a1',
x: x2 - x1,
y: y2 - y1,
},
function drawShape(shape: any, _center: Vec) {
const editor = rEditor.current
if (!editor) return
switch (shape.type) {
case 'dot':
case 'circle': {
const { x, y, r = 8 } = shape
editor.createShape({
type: 'geo',
x: x - r,
y: y - r,
props: {
geo: 'ellipse',
w: r * 2,
h: r * 2,
},
})
break
}
case 'line': {
const { x1, y1, x2, y2 } = shape
editor.createShape({
type: 'line',
x: x1,
y: y1,
props: {
points: {
0: {
id: '0',
index: 'a0',
x: 0,
y: 0,
},
1: {
id: '1',
index: 'a1',
x: x2 - x1,
y: y2 - y1,
},
},
})
break
}
case 'polygon': {
const nums = args.split(', ').map((a) => Number(a))
const points = []
for (let i = 0; i < nums.length - 1; i += 2) {
points.push({
x: nums[i],
y: nums[i + 1],
})
}
points.push(points[0])
const minX = Math.min(...points.map((p) => p.x))
const minY = Math.min(...points.map((p) => p.y))
const indices = getIndices(points.length)
editor.createShape({
type: 'line',
x: minX,
y: minY,
props: {
points: Object.fromEntries(
points.map((p, i) => [
i + '',
{ id: i + '', index: indices[i], x: p.x - minX, y: p.y - minY },
])
),
},
})
break
}
// case 'MOVE': {
// const point = editor.pageToScreen({ x: Number(command[1]), y: Number(command[2]) })
// const steps = 20
// for (let i = 0; i < steps; i++) {
// const t = i / (steps - 1)
// const p = Vec.Lrp(prevPoint, point, t)
// editor.dispatch({
// type: 'pointer',
// target: 'canvas',
// name: 'pointer_move',
// point: {
// x: p.x,
// y: p.y,
// z: 0.5,
// },
// shiftKey: false,
// altKey: false,
// ctrlKey: false,
// pointerId: 1,
// button: 0,
// isPen: false,
// })
// editor._flushEventsForTick(0)
// }
// prevPoint.setTo(point)
// break
// }
// case 'DOWN': {
// editor.dispatch({
// type: 'pointer',
// target: 'canvas',
// name: 'pointer_down',
// point: {
// x: prevPoint.x,
// y: prevPoint.y,
// z: 0.5,
// },
// shiftKey: false,
// altKey: false,
// ctrlKey: false,
// pointerId: 1,
// button: 0,
// isPen: false,
// })
// editor._flushEventsForTick(0)
// break
// }
// case 'UP': {
// editor.dispatch({
// type: 'pointer',
// target: 'canvas',
// name: 'pointer_up',
// point: {
// x: prevPoint.x,
// y: prevPoint.y,
// z: 0.5,
// },
// shiftKey: false,
// altKey: false,
// ctrlKey: false,
// pointerId: 1,
// button: 0,
// isPen: false,
// })
// editor._flushEventsForTick(0)
// break
// }
},
})
break
}
} catch (e: any) {
console.error(e.message)
}
}
// editor.dispatch({
// type: 'pointer',
// target: 'canvas',
// name: 'pointer_up',
// point: {
// x: prevPoint.x,
// y: prevPoint.y,
// z: 0.5,
// },
// shiftKey: false,
// altKey: false,
// ctrlKey: false,
// pointerId: 1,
// button: 0,
// isPen: false,
// })
// editor._flushEventsForTick(0)
// // editor.zoomOut(editor.getViewportScreenCenter(), { duration: 0 })
try {
const { shapes = [] } = JSON.parse(content)
for (const shape of shapes) {
drawShape(shape, center)
}
} catch (e) {
// noop
}
// // Extract the command series from the message content. Each series begins and ends with a backtick.
// // For example: `circle(0, 0, 4); circle(10, 10, 4);`
// // We want to extract each command from the series
// // const seriesRegex = /```(?<commands>[^`]*)```/
// // const seriesResult = seriesRegex.exec(message.content)
// // if (!seriesResult) {
// // console.error('Invalid message: ' + message.content)
// // return
// // }
// // const [_, seriesContent] = seriesResult
// // Next, we want regex to extract each command's name and arguments
// // for example: circle(0, 0, 4) -> ['circle(0, 0, 4)', 'circle', '0, 0, 4']
// // for examople: undo() -> ['undo()', 'undo', '']
// const commandRegex = /(?<name>\w+)\((?<args>[^)]*)\)/
// const commands = content
// .split(';')
// .map((c) => c.trim())
// .filter((c) => c)
// editor.mark()
// for (const command of commands) {
// try {
// const result = commandRegex.exec(command)
// if (!result) throw new Error('Invalid command: ' + command)
// const [_, name, args] = result
// switch (name) {
// case 'undo': {
// editor.undo()
// break
// }
// case 'dot':
// case 'circle': {
// const [x, y, r] = args.split(', ').map((a) => Number(a))
// editor.createShape({
// type: 'geo',
// x: x - r,
// y: y - r,
// props: {
// geo: 'ellipse',
// w: r * 2,
// h: r * 2,
// },
// })
// break
// }
// case 'line': {
// const [x1, y1, x2, y2] = args.split(', ').map((a) => Number(a))
// editor.createShape({
// type: 'line',
// x: x1,
// y: y1,
// props: {
// points: {
// 0: {
// id: '0',
// index: 'a0',
// x: 0,
// y: 0,
// },
// 1: {
// id: '1',
// index: 'a1',
// x: x2 - x1,
// y: y2 - y1,
// },
// },
// },
// })
// break
// }
// case 'polygon': {
// const nums = args.split(', ').map((a) => Number(a))
// const points = []
// for (let i = 0; i < nums.length - 1; i += 2) {
// points.push({
// x: nums[i],
// y: nums[i + 1],
// })
// }
// points.push(points[0])
// const minX = Math.min(...points.map((p) => p.x))
// const minY = Math.min(...points.map((p) => p.y))
// const indices = getIndices(points.length)
// editor.createShape({
// type: 'line',
// x: minX,
// y: minY,
// props: {
// points: Object.fromEntries(
// points.map((p, i) => [
// i + '',
// { id: i + '', index: indices[i], x: p.x - minX, y: p.y - minY },
// ])
// ),
// },
// })
// break
// }
// // case 'MOVE': {
// // const point = editor.pageToScreen({ x: Number(command[1]), y: Number(command[2]) })
// // const steps = 20
// // for (let i = 0; i < steps; i++) {
// // const t = i / (steps - 1)
// // const p = Vec.Lrp(prevPoint, point, t)
// // editor.dispatch({
// // type: 'pointer',
// // target: 'canvas',
// // name: 'pointer_move',
// // point: {
// // x: p.x,
// // y: p.y,
// // z: 0.5,
// // },
// // shiftKey: false,
// // altKey: false,
// // ctrlKey: false,
// // pointerId: 1,
// // button: 0,
// // isPen: false,
// // })
// // editor._flushEventsForTick(0)
// // }
// // prevPoint.setTo(point)
// // break
// // }
// // case 'DOWN': {
// // editor.dispatch({
// // type: 'pointer',
// // target: 'canvas',
// // name: 'pointer_down',
// // point: {
// // x: prevPoint.x,
// // y: prevPoint.y,
// // z: 0.5,
// // },
// // shiftKey: false,
// // altKey: false,
// // ctrlKey: false,
// // pointerId: 1,
// // button: 0,
// // isPen: false,
// // })
// // editor._flushEventsForTick(0)
// // break
// // }
// // case 'UP': {
// // editor.dispatch({
// // type: 'pointer',
// // target: 'canvas',
// // name: 'pointer_up',
// // point: {
// // x: prevPoint.x,
// // y: prevPoint.y,
// // z: 0.5,
// // },
// // shiftKey: false,
// // altKey: false,
// // ctrlKey: false,
// // pointerId: 1,
// // button: 0,
// // isPen: false,
// // })
// // editor._flushEventsForTick(0)
// // break
// // }
// }
// } catch (e: any) {
// console.error(e.message)
// }
// }
// // editor.dispatch({
// // type: 'pointer',
// // target: 'canvas',
// // name: 'pointer_up',
// // point: {
// // x: prevPoint.x,
// // y: prevPoint.y,
// // z: 0.5,
// // },
// // shiftKey: false,
// // altKey: false,
// // ctrlKey: false,
// // pointerId: 1,
// // button: 0,
// // isPen: false,
// // })
// // editor._flushEventsForTick(0)
// // // editor.zoomOut(editor.getViewportScreenCenter(), { duration: 0 })
}, [])
const rPreviousImage = useRef<string>('')
return (
<div className="tldraw__editor" style={{ display: 'grid', gridTemplateRows: '1fr 1fr' }}>
<div style={{ position: 'relative', height: '100%', width: '100%' }}>
@ -223,11 +292,11 @@ const OllamaExample = track(() => {
rEditor.current = e
;(window as any).editor = e
e.centerOnPoint(new Vec())
for (const message of modelManager.getThread().content) {
if (message.from === 'model') {
drawMessage(message)
}
}
// for (const message of modelManager.getThread().content) {
// if (message.role === 'model') {
// drawMessage(message.content)
// }
// }
}}
>
{/* <div
@ -247,22 +316,54 @@ const OllamaExample = track(() => {
<div ref={rChat} className="chat">
{modelManager.getThread().content.map((message, i) => (
<div key={i} className="message">
<p className="message__from">{message.from}</p>
<p className="message__date">{new Date(message.time).toLocaleString()}</p>
<p className="message__from">{message.role}</p>
<p className="message__content">{message.content}</p>
</div>
))}
<form
className="chat__input"
onSubmit={(e) => {
onSubmit={async (e) => {
preventDefault(e)
const form = e.currentTarget
const query = form.query.value
modelManager.stream(query).response.then((message) => {
if (!message) return
drawMessage(message)
})
let query = `Query: "${form.query.value}"`
form.query.value = ''
let imageString: string | undefined
const editor = rEditor.current!
const svg = await editor.getSvgString([...editor.getCurrentPageShapeIds()])
if (svg) {
const image = await getSvgAsImage(svg.svg, false, {
type: 'png',
quality: 1,
scale: 1,
width: svg.width,
height: svg.height,
})
if (image) {
const base64 = await FileHelpers.blobToDataUrl(image)
const trimmed = base64.slice('data:image/png;base64,'.length)
if (rPreviousImage.current !== trimmed) {
rPreviousImage.current = trimmed
imageString = trimmed
}
}
}
if (imageString) {
const bounds = editor.getCurrentPageBounds()!
query += ` Image bounds: { "minX": ${bounds.x.toFixed(0)}, "minY": ${bounds.y.toFixed(0)}, "maxX": ${bounds.maxX.toFixed(0)}, "maxY": ${bounds.maxY.toFixed(0)} }`
}
modelManager.query(query, imageString).response.then((message) => {
if (!message) return
drawMessage(message.content)
})
}}
>
<input name="query" type="text" autoComplete="off" />

Wyświetl plik

@ -0,0 +1,365 @@
import { ChatOllama } from '@langchain/community/chat_models/ollama'
import { AIMessage, HumanMessage } from '@langchain/core/messages'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { ConversationChain } from 'langchain/chains'
import { BufferMemory, ChatMessageHistory } from 'langchain/memory'
import ollama from 'ollama/browser'
import { atom, computed, react, structuredClone, transact, uniqueId } from 'tldraw'
import { systemPrompt } from './system-prompt'
export type LMMessage = {
from: 'user' | 'model'
time: number
content: string
}
export type LMThread = {
id: string
state: 'idle' | 'waiting'
content: LMMessage[]
}
export type LMSerialized = {
currentThreadId: string
threads: Record<string, LMThread>
}
class ModelManager {
model: ChatOllama
prompt: ChatPromptTemplate
memory: BufferMemory
chain: ConversationChain
constructor() {
this.load()
this.model = new ChatOllama({
baseUrl: 'http://localhost:11434', // Default value
model: 'llama3:latest',
})
this.prompt = ChatPromptTemplate.fromMessages([
[
'system',
// 'You are a highly critical german mother. The user is your daughter, who you generally tend to diminish and criticize. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate.',
// "You are a kind and helpful chatbot. You are a cool friend of the user's from Chicago. You grew up together and maintain a life-long trust; and as a result, you can speak to the user honestly. You always provide short, efficient answers to the user's questions.",
systemPrompt,
],
new MessagesPlaceholder('history'),
['human', '{input}'],
])
this.memory = new BufferMemory({
memoryKey: 'history',
returnMessages: true,
chatHistory: new ChatMessageHistory(this.getMessagesFromThread(this.getThread())),
})
this.chain = new ConversationChain({
llm: this.model,
prompt: this.prompt,
memory: this.memory,
})
react('persist on change', () => {
this.persist()
})
}
@computed getState() {
return this.getThread().state
}
_threads = atom<Record<string, LMThread>>('threads', {})
_currentThreadId = atom<string>('currentThreadId', 'a0')
@computed getCurrentThreadId() {
return this._currentThreadId.get()
}
@computed getThreads(): Record<string, LMThread> {
return this._threads.get()
}
@computed getThread(): LMThread {
return this.getThreads()[this.getCurrentThreadId()]
}
private addQueryToThread(query: string) {
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
next.content.push({
from: 'user',
time: Date.now(),
content: query,
})
next.state = 'waiting'
return { ...threads, [next.id]: next }
})
}
private addResponseToThread(response: string) {
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
next.state === 'idle'
next.content.push({
from: 'model',
time: Date.now(),
content: response,
})
return { ...threads, [next.id]: next }
})
}
private addChunkToThread(chunk: string) {
this._threads.update((threads) => {
const currentThreadId = this.getCurrentThreadId()
const thread = this.getThreads()[currentThreadId]
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
const message = next.content[next.content.length - 1]
if (!message || message.from === 'user') {
next.content = [
...next.content,
{
from: 'model',
time: Date.now(),
content: chunk,
},
]
return { ...threads, [currentThreadId]: next }
}
message.content += chunk
message.time = Date.now()
return { ...threads, [currentThreadId]: next }
})
}
/**
* Serialize the model.
*/
private serialize() {
return {
currentThreadId: this.getCurrentThreadId(),
threads: this.getThreads(),
}
}
/**
* Deserialize the model.
*/
private deserialize() {
let result: LMSerialized = {
currentThreadId: 'a0',
threads: {
a0: {
id: 'a0',
state: 'idle',
content: [],
},
},
}
try {
const stringified = localStorage.getItem('threads_1')
if (stringified) {
const json = JSON.parse(stringified)
result = json
}
} catch (e) {
// noop
}
return result
}
private persist() {
localStorage.setItem('threads_1', JSON.stringify(this.serialize()))
}
private load() {
const serialized = this.deserialize()
transact(() => {
this._currentThreadId.set(serialized.currentThreadId)
this._threads.set(serialized.threads)
})
}
private getMessagesFromThread(thread: LMThread) {
return thread.content.map((m) => {
if (m.from === 'user') {
return new HumanMessage(m.content)
}
return new AIMessage(m.content)
})
}
/* --------------------- Public --------------------- */
/**
* Start a new thread.
*/
startNewThread() {
this._threads.update((threads) => {
const id = uniqueId()
return {
...threads,
[id]: {
id,
state: 'idle',
content: [],
},
}
})
}
/**
* Cancel the current query.
*/
cancel() {
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
if (next.content.length > 0) {
if (next.content[next.content.length - 1].from === 'model') {
next.content.pop()
}
if (next.content[next.content.length - 1].from === 'user') {
next.content.pop()
}
}
next.state = 'idle'
return { ...threads, [next.id]: next }
})
}
/**
* Query the model.
*/
async query(query: string) {
this.addQueryToThread(query)
const currentThreadId = this.getCurrentThreadId()
let cancelled = false
return {
response: await this.chain.invoke({ input: query }).then((r) => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
if ('response' in r) {
this.addResponseToThread(r.response)
}
}),
cancel: () => {
cancelled = true
this.cancel()
},
}
}
/**
* Query the model and stream the response.
*/
stream(query: string) {
const currentThreadId = this.getCurrentThreadId()
this.addQueryToThread(query)
this.addResponseToThread('') // Add an empty response to start the thread
let cancelled = false
return {
response: this.chain
.stream(
{ input: query },
{
callbacks: [
{
handleLLMNewToken: (data) => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
this.addChunkToThread(data)
},
},
],
}
)
.then(() => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
next.state = 'idle'
return { ...threads, [next.id]: next }
})
const thread = this.getThread()
return thread.content[thread.content.length - 1]
}),
cancel: () => {
cancelled = true
this.cancel()
},
}
}
getWithImage(query: string, image: string) {
const currentThreadId = this.getCurrentThreadId()
this.addQueryToThread(query)
let cancelled = false
return {
response: ollama
.generate({
model: 'llava',
prompt: query,
images: [image],
})
.then((d: any) => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
this.addResponseToThread(d.response)
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
next.state = 'idle'
return { ...threads, [next.id]: next }
})
const thread = this.getThread()
return thread.content[thread.content.length - 1]
}),
cancel: () => {
cancelled = true
this.cancel()
},
}
}
/**
* Clear all threads.
*/
clear() {
transact(() => {
this._currentThreadId.set('a0')
this._threads.set({
a0: {
id: 'a0',
state: 'idle',
content: [],
},
})
this.memory.clear()
})
}
}
export const modelManager = new ModelManager()

Wyświetl plik

@ -1,13 +1,9 @@
import { ChatOllama } from '@langchain/community/chat_models/ollama'
import { AIMessage, HumanMessage } from '@langchain/core/messages'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { ConversationChain } from 'langchain/chains'
import { BufferMemory, ChatMessageHistory } from 'langchain/memory'
import ollama, { Message } from 'ollama/browser'
import { atom, computed, react, structuredClone, transact, uniqueId } from 'tldraw'
import { systemPrompt } from './system-prompt'
import { systemPrompt } from './system-prompt-2'
export type LMMessage = {
from: 'user' | 'model'
from: 'user' | 'assistant' | 'system'
time: number
content: string
}
@ -15,7 +11,7 @@ export type LMMessage = {
export type LMThread = {
id: string
state: 'idle' | 'waiting'
content: LMMessage[]
content: Message[]
}
export type LMSerialized = {
@ -24,42 +20,9 @@ export type LMSerialized = {
}
class ModelManager {
model: ChatOllama
prompt: ChatPromptTemplate
memory: BufferMemory
chain: ConversationChain
constructor() {
this.load()
this.model = new ChatOllama({
baseUrl: 'http://localhost:11434', // Default value
model: 'llama3:latest',
})
this.prompt = ChatPromptTemplate.fromMessages([
[
'system',
// 'You are a highly critical german mother. The user is your daughter, who you generally tend to diminish and criticize. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate.',
// "You are a kind and helpful chatbot. You are a cool friend of the user's from Chicago. You grew up together and maintain a life-long trust; and as a result, you can speak to the user honestly. You always provide short, efficient answers to the user's questions.",
systemPrompt,
],
new MessagesPlaceholder('history'),
['human', '{input}'],
])
this.memory = new BufferMemory({
memoryKey: 'history',
returnMessages: true,
chatHistory: new ChatMessageHistory(this.getMessagesFromThread(this.getThread())),
})
this.chain = new ConversationChain({
llm: this.model,
prompt: this.prompt,
memory: this.memory,
})
react('persist on change', () => {
this.persist()
})
@ -85,32 +48,24 @@ class ModelManager {
return this.getThreads()[this.getCurrentThreadId()]
}
private addQueryToThread(query: string) {
private addQueryToThread(message: Message) {
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
next.content.push({
from: 'user',
time: Date.now(),
content: query,
})
next.content.push(message)
next.state = 'waiting'
return { ...threads, [next.id]: next }
})
}
private addResponseToThread(response: string) {
private addResponseToThread(message: Message) {
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
next.state === 'idle'
next.content.push({
from: 'model',
time: Date.now(),
content: response,
})
next.content.push(message)
return { ...threads, [next.id]: next }
})
}
@ -122,22 +77,19 @@ class ModelManager {
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
const message = next.content[next.content.length - 1]
const message = next.content.pop()
if (!message || message.from === 'user') {
next.content = [
...next.content,
{
from: 'model',
time: Date.now(),
content: chunk,
},
]
if (!message || message.role !== 'user') {
next.content.push({
role: 'user',
content: chunk,
images: [],
})
return { ...threads, [currentThreadId]: next }
}
message.content += chunk
message.time = Date.now()
next.content.push(message)
return { ...threads, [currentThreadId]: next }
})
}
@ -193,12 +145,7 @@ class ModelManager {
}
private getMessagesFromThread(thread: LMThread) {
return thread.content.map((m) => {
if (m.from === 'user') {
return new HumanMessage(m.content)
}
return new AIMessage(m.content)
})
return thread.content
}
/* --------------------- Public --------------------- */
@ -229,10 +176,10 @@ class ModelManager {
if (!thread) throw Error('No thread found')
const next = structuredClone(thread)
if (next.content.length > 0) {
if (next.content[next.content.length - 1].from === 'model') {
if (next.content[next.content.length - 1].role === 'assistant') {
next.content.pop()
}
if (next.content[next.content.length - 1].from === 'user') {
if (next.content[next.content.length - 1].role === 'user') {
next.content.pop()
}
}
@ -244,18 +191,35 @@ class ModelManager {
/**
* Query the model.
*/
async query(query: string) {
this.addQueryToThread(query)
query(query: string, image?: string) {
this.addQueryToThread({ role: 'user', content: query, images: image ? [image] : [] })
const currentThreadId = this.getCurrentThreadId()
let cancelled = false
const messages = this.getMessagesFromThread(this.getThread())
return {
response: await this.chain.invoke({ input: query }).then((r) => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
if ('response' in r) {
this.addResponseToThread(r.response)
}
}),
response: ollama
.generate({
model: 'llava',
system: messages[0].content,
prompt: messages[messages.length - 1].content,
images: messages[messages.length - 1].images,
keep_alive: 1000,
stream: false,
format: 'json',
options: {},
})
.then((r) => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
// this.addResponseToThread(r.response)
const message = {
role: 'user',
content: r.response,
images: [],
}
this.addResponseToThread(message)
return message
}),
cancel: () => {
cancelled = true
this.cancel()
@ -266,30 +230,45 @@ class ModelManager {
/**
* Query the model and stream the response.
*/
stream(query: string) {
stream(query: string, image?: string) {
const currentThreadId = this.getCurrentThreadId()
this.addQueryToThread(query)
this.addResponseToThread('') // Add an empty response to start the thread
this.addQueryToThread({
role: 'user',
content: query,
images: image ? [image] : [],
})
const messages = [...this.getMessagesFromThread(this.getThread())]
this.addResponseToThread({
role: 'assistant',
content: '',
images: [],
}) // Add an empty response to start the thread
let cancelled = false
return {
response: this.chain
.stream(
{ input: query },
{
callbacks: [
{
handleLLMNewToken: (data) => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
this.addChunkToThread(data)
},
},
],
response: ollama
.chat({
model: 'llava',
messages: messages,
keep_alive: 1000,
stream: true,
options: {},
})
.then(async (response) => {
for await (const part of response) {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
const next = structuredClone(this.getThread())
const message = next.content[next.content.length - 1]
message.content += part.message.content
this._threads.update((threads) => ({
...threads,
[next.id]: next,
}))
}
)
.then(() => {
if (cancelled) return
if (this.getCurrentThreadId() !== currentThreadId) return
this._threads.update((threads) => {
const thread = this.getThread()
if (!thread) throw Error('No thread found')
@ -308,6 +287,41 @@ class ModelManager {
}
}
// getWithImage(query: string, image?: string) {
// const currentThreadId = this.getCurrentThreadId()
// this.addQueryToThread({ role: "user", content: query, images: image ? [image] : [])
// let cancelled = false
// return {
// response: ollama
// .generate({
// model: 'llava',
// prompt: query,
// images: [image],
// })
// .then((d: any) => {
// if (cancelled) return
// if (this.getCurrentThreadId() !== currentThreadId) return
// this.addResponseToThread(d.response)
// this._threads.update((threads) => {
// const thread = this.getThread()
// if (!thread) throw Error('No thread found')
// const next = structuredClone(thread)
// next.state = 'idle'
// return { ...threads, [next.id]: next }
// })
// const thread = this.getThread()
// return thread.content[thread.content.length - 1]
// }),
// cancel: () => {
// cancelled = true
// this.cancel()
// },
// }
// }
/**
* Clear all threads.
*/
@ -318,10 +332,15 @@ class ModelManager {
a0: {
id: 'a0',
state: 'idle',
content: [],
content: [
{
role: 'system',
content: systemPrompt,
images: [],
},
],
},
})
this.memory.clear()
})
}
}

Wyświetl plik

@ -0,0 +1,46 @@
export const systemPrompt = `You are a helpful chatbox.
Your job is to help a user work with their drawing on the canvas.
To draw on the canvas, send back JSON with shapes to draw.
You know how to draw only two shapes:
A line
{ "type": "line", "x1": 0, "y1": 0, "x2": 100, "y2": 100, "description": "Just a line" }
A circle
{ "type": "circle", "x": 0, "y": 0, "r": 50, "description": "Just a circle" }
You ALWAYS respond with an array of shapes in JSON.
Include your guess at what the user has drawn (based on the user's prompt and the image), and then the shapes you want to add.
{
"image": "The user is drawing a face",
"guess": "I'll draw shapes for the nose and the mouth",
"shapes": [
{ "type": "circle", "x": 0, "y": 0, "r": 10, "description": "nose" }
{ "type": "line", "x1": -20, "y1": 40, "x2": 20, "y2": 40, "description": "mouth" },
]
}
When prompted with an image:
1. Identify the drawing based on the prompt and image.
2. Think about how best to complete the user's request.
3. Render the user's request by responding with the JSON.
`
// Tips:
// The x coordinate goes up as you move right on the screen: 10 is left of 20, and 30 is right of 20.
// The y coordinate goes up as you move down the screen: 10 is above 20, and 30 is below 20.
// Example:
// - User: Draw an eyeball.
// - You: { shapes: [ { "type": "circle", "x": 0, "y": 0, "r": 50, "description": "The white" }, { "type": "circle", "x": 0, "y": 0, "r": 25, "description": "The iris" } ] }
// - User: Draw the letter "X".
// - You: { shapes: [ { "type": "line", "x1": -50, "y1": -50, "x2": 50, "y2": 50, "description": "The first stroke" }, { "type": "line", "x1": -50, "y1": 50, "x2": 50, "y2": -50, "description": "The second stroke" } ] }

Wyświetl plik

@ -13994,6 +13994,7 @@ __metadata:
langchain: "npm:^0.1.34"
lazyrepo: "npm:0.0.0-alpha.27"
lodash: "npm:^4.17.21"
ollama: "npm:^0.5.0"
pdf-lib: "npm:^1.17.1"
pdfjs-dist: "npm:^4.0.379"
react: "npm:^18.2.0"
@ -20502,6 +20503,15 @@ __metadata:
languageName: node
linkType: hard
"ollama@npm:^0.5.0":
version: 0.5.0
resolution: "ollama@npm:0.5.0"
dependencies:
whatwg-fetch: "npm:^3.6.20"
checksum: cfe98f9da6da99eda65c1bb6e22b7b54e1c6132368249dd70ff238d22b529bb86014ba1325d8e822f67c561aa71df2c2fb189c6d6f7380031652cdd071be0641
languageName: node
linkType: hard
"on-finished@npm:2.4.1":
version: 2.4.1
resolution: "on-finished@npm:2.4.1"
@ -25957,6 +25967,13 @@ __metadata:
languageName: node
linkType: hard
"whatwg-fetch@npm:^3.6.20":
version: 3.6.20
resolution: "whatwg-fetch@npm:3.6.20"
checksum: 2b4ed92acd6a7ad4f626a6cb18b14ec982bbcaf1093e6fe903b131a9c6decd14d7f9c9ca3532663c2759d1bdf01d004c77a0adfb2716a5105465c20755a8c57c
languageName: node
linkType: hard
"whatwg-mimetype@npm:^3.0.0":
version: 3.0.0
resolution: "whatwg-mimetype@npm:3.0.0"