kopia lustrzana https://github.com/Tldraw/Tldraw
330 wiersze
8.1 KiB
TypeScript
330 wiersze
8.1 KiB
TypeScript
import { ChatOllama } from '@langchain/community/chat_models/ollama'
|
|
import { AIMessage, HumanMessage } from '@langchain/core/messages'
|
|
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
|
|
import { ConversationChain } from 'langchain/chains'
|
|
import { BufferMemory, ChatMessageHistory } from 'langchain/memory'
|
|
import { atom, computed, react, structuredClone, transact, uniqueId } from 'tldraw'
|
|
import { systemPrompt } from './system-prompt'
|
|
|
|
export type LMMessage = {
|
|
from: 'user' | 'model'
|
|
time: number
|
|
content: string
|
|
}
|
|
|
|
export type LMThread = {
|
|
id: string
|
|
state: 'idle' | 'waiting'
|
|
content: LMMessage[]
|
|
}
|
|
|
|
export type LMSerialized = {
|
|
currentThreadId: string
|
|
threads: Record<string, LMThread>
|
|
}
|
|
|
|
class ModelManager {
|
|
model: ChatOllama
|
|
prompt: ChatPromptTemplate
|
|
memory: BufferMemory
|
|
chain: ConversationChain
|
|
|
|
constructor() {
|
|
this.load()
|
|
|
|
this.model = new ChatOllama({
|
|
baseUrl: 'http://localhost:11434', // Default value
|
|
model: 'llama3:latest',
|
|
})
|
|
|
|
this.prompt = ChatPromptTemplate.fromMessages([
|
|
[
|
|
'system',
|
|
// 'You are a highly critical german mother. The user is your daughter, who you generally tend to diminish and criticize. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate. You are very strict and have high expectations for your daughter. You are a perfectionist and expect the same from your daughter. You are very critical of your daughter and often make her feel inadequate.',
|
|
// "You are a kind and helpful chatbot. You are a cool friend of the user's from Chicago. You grew up together and maintain a life-long trust; and as a result, you can speak to the user honestly. You always provide short, efficient answers to the user's questions.",
|
|
systemPrompt,
|
|
],
|
|
new MessagesPlaceholder('history'),
|
|
['human', '{input}'],
|
|
])
|
|
|
|
this.memory = new BufferMemory({
|
|
memoryKey: 'history',
|
|
returnMessages: true,
|
|
chatHistory: new ChatMessageHistory(this.getMessagesFromThread(this.getThread())),
|
|
})
|
|
|
|
this.chain = new ConversationChain({
|
|
llm: this.model,
|
|
prompt: this.prompt,
|
|
memory: this.memory,
|
|
})
|
|
|
|
react('persist on change', () => {
|
|
this.persist()
|
|
})
|
|
}
|
|
|
|
@computed getState() {
|
|
return this.getThread().state
|
|
}
|
|
|
|
_threads = atom<Record<string, LMThread>>('threads', {})
|
|
|
|
_currentThreadId = atom<string>('currentThreadId', 'a0')
|
|
|
|
@computed getCurrentThreadId() {
|
|
return this._currentThreadId.get()
|
|
}
|
|
|
|
@computed getThreads(): Record<string, LMThread> {
|
|
return this._threads.get()
|
|
}
|
|
|
|
@computed getThread(): LMThread {
|
|
return this.getThreads()[this.getCurrentThreadId()]
|
|
}
|
|
|
|
private addQueryToThread(query: string) {
|
|
this._threads.update((threads) => {
|
|
const thread = this.getThread()
|
|
if (!thread) throw Error('No thread found')
|
|
const next = structuredClone(thread)
|
|
next.content.push({
|
|
from: 'user',
|
|
time: Date.now(),
|
|
content: query,
|
|
})
|
|
next.state = 'waiting'
|
|
return { ...threads, [next.id]: next }
|
|
})
|
|
}
|
|
|
|
private addResponseToThread(response: string) {
|
|
this._threads.update((threads) => {
|
|
const thread = this.getThread()
|
|
if (!thread) throw Error('No thread found')
|
|
const next = structuredClone(thread)
|
|
next.state === 'idle'
|
|
next.content.push({
|
|
from: 'model',
|
|
time: Date.now(),
|
|
content: response,
|
|
})
|
|
return { ...threads, [next.id]: next }
|
|
})
|
|
}
|
|
|
|
private addChunkToThread(chunk: string) {
|
|
this._threads.update((threads) => {
|
|
const currentThreadId = this.getCurrentThreadId()
|
|
const thread = this.getThreads()[currentThreadId]
|
|
if (!thread) throw Error('No thread found')
|
|
|
|
const next = structuredClone(thread)
|
|
const message = next.content[next.content.length - 1]
|
|
|
|
if (!message || message.from === 'user') {
|
|
next.content = [
|
|
...next.content,
|
|
{
|
|
from: 'model',
|
|
time: Date.now(),
|
|
content: chunk,
|
|
},
|
|
]
|
|
return { ...threads, [currentThreadId]: next }
|
|
}
|
|
|
|
message.content += chunk
|
|
message.time = Date.now()
|
|
return { ...threads, [currentThreadId]: next }
|
|
})
|
|
}
|
|
|
|
/**
|
|
* Serialize the model.
|
|
*/
|
|
private serialize() {
|
|
return {
|
|
currentThreadId: this.getCurrentThreadId(),
|
|
threads: this.getThreads(),
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Deserialize the model.
|
|
*/
|
|
private deserialize() {
|
|
let result: LMSerialized = {
|
|
currentThreadId: 'a0',
|
|
threads: {
|
|
a0: {
|
|
id: 'a0',
|
|
state: 'idle',
|
|
content: [],
|
|
},
|
|
},
|
|
}
|
|
|
|
try {
|
|
const stringified = localStorage.getItem('threads_1')
|
|
if (stringified) {
|
|
const json = JSON.parse(stringified)
|
|
result = json
|
|
}
|
|
} catch (e) {
|
|
// noop
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
private persist() {
|
|
localStorage.setItem('threads_1', JSON.stringify(this.serialize()))
|
|
}
|
|
|
|
private load() {
|
|
const serialized = this.deserialize()
|
|
transact(() => {
|
|
this._currentThreadId.set(serialized.currentThreadId)
|
|
this._threads.set(serialized.threads)
|
|
})
|
|
}
|
|
|
|
private getMessagesFromThread(thread: LMThread) {
|
|
return thread.content.map((m) => {
|
|
if (m.from === 'user') {
|
|
return new HumanMessage(m.content)
|
|
}
|
|
return new AIMessage(m.content)
|
|
})
|
|
}
|
|
|
|
/* --------------------- Public --------------------- */
|
|
|
|
/**
|
|
* Start a new thread.
|
|
*/
|
|
startNewThread() {
|
|
this._threads.update((threads) => {
|
|
const id = uniqueId()
|
|
return {
|
|
...threads,
|
|
[id]: {
|
|
id,
|
|
state: 'idle',
|
|
content: [],
|
|
},
|
|
}
|
|
})
|
|
}
|
|
|
|
/**
|
|
* Cancel the current query.
|
|
*/
|
|
cancel() {
|
|
this._threads.update((threads) => {
|
|
const thread = this.getThread()
|
|
if (!thread) throw Error('No thread found')
|
|
const next = structuredClone(thread)
|
|
if (next.content.length > 0) {
|
|
if (next.content[next.content.length - 1].from === 'model') {
|
|
next.content.pop()
|
|
}
|
|
if (next.content[next.content.length - 1].from === 'user') {
|
|
next.content.pop()
|
|
}
|
|
}
|
|
next.state = 'idle'
|
|
return { ...threads, [next.id]: next }
|
|
})
|
|
}
|
|
|
|
/**
|
|
* Query the model.
|
|
*/
|
|
async query(query: string) {
|
|
this.addQueryToThread(query)
|
|
const currentThreadId = this.getCurrentThreadId()
|
|
let cancelled = false
|
|
return {
|
|
response: await this.chain.invoke({ input: query }).then((r) => {
|
|
if (cancelled) return
|
|
if (this.getCurrentThreadId() !== currentThreadId) return
|
|
if ('response' in r) {
|
|
this.addResponseToThread(r.response)
|
|
}
|
|
}),
|
|
cancel: () => {
|
|
cancelled = true
|
|
this.cancel()
|
|
},
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Query the model and stream the response.
|
|
*/
|
|
stream(query: string) {
|
|
const currentThreadId = this.getCurrentThreadId()
|
|
this.addQueryToThread(query)
|
|
this.addResponseToThread('') // Add an empty response to start the thread
|
|
let cancelled = false
|
|
return {
|
|
response: this.chain
|
|
.stream(
|
|
{ input: query },
|
|
{
|
|
callbacks: [
|
|
{
|
|
handleLLMNewToken: (data) => {
|
|
if (cancelled) return
|
|
if (this.getCurrentThreadId() !== currentThreadId) return
|
|
this.addChunkToThread(data)
|
|
},
|
|
},
|
|
],
|
|
}
|
|
)
|
|
.then(() => {
|
|
if (cancelled) return
|
|
if (this.getCurrentThreadId() !== currentThreadId) return
|
|
this._threads.update((threads) => {
|
|
const thread = this.getThread()
|
|
if (!thread) throw Error('No thread found')
|
|
const next = structuredClone(thread)
|
|
next.state = 'idle'
|
|
return { ...threads, [next.id]: next }
|
|
})
|
|
|
|
const thread = this.getThread()
|
|
return thread.content[thread.content.length - 1]
|
|
}),
|
|
cancel: () => {
|
|
cancelled = true
|
|
this.cancel()
|
|
},
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Clear all threads.
|
|
*/
|
|
clear() {
|
|
transact(() => {
|
|
this._currentThreadId.set('a0')
|
|
this._threads.set({
|
|
a0: {
|
|
id: 'a0',
|
|
state: 'idle',
|
|
content: [],
|
|
},
|
|
})
|
|
this.memory.clear()
|
|
})
|
|
}
|
|
}
|
|
|
|
export const modelManager = new ModelManager()
|