diff --git a/packages/editor/src/lib/editor/Editor.ts b/packages/editor/src/lib/editor/Editor.ts index 2ff5e8925..268d4d5f8 100644 --- a/packages/editor/src/lib/editor/Editor.ts +++ b/packages/editor/src/lib/editor/Editor.ts @@ -120,6 +120,7 @@ import { getReorderingShapesChanges } from '../utils/reorderShapes' import { applyRotationToSnapshotShapes, getRotationSnapshot } from '../utils/rotation' import { uniqueId } from '../utils/uniqueId' import { BindingUtil, TLBindingUtilConstructor } from './bindings/BindingUtil' +import { bindingsIndex } from './derivations/bindingsIndex' import { notVisibleShapes } from './derivations/notVisibleShapes' import { parentsToChildren } from './derivations/parentsToChildren' import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage' @@ -379,19 +380,21 @@ export class Editor extends EventEmitter { this.sideEffects.register({ shape: { afterChange: (shapeBefore, shapeAfter) => { - for (const binding of this.getAllBindingsFromShape(shapeAfter)) { - this.getBindingUtil(binding).onAfterChangeFromShape?.({ - binding, - shapeBefore, - shapeAfter, - }) - } - for (const binding of this.getAllBindingsToShape(shapeAfter)) { - this.getBindingUtil(binding).onAfterChangeToShape?.({ - binding, - shapeBefore, - shapeAfter, - }) + for (const binding of this.getBindingsInvolvingShape(shapeAfter)) { + if (binding.fromId === shapeAfter.id) { + this.getBindingUtil(binding).onAfterChangeFromShape?.({ + binding, + shapeBefore, + shapeAfter, + }) + } + if (binding.toId === shapeAfter.id) { + this.getBindingUtil(binding).onAfterChangeToShape?.({ + binding, + shapeBefore, + shapeAfter, + }) + } } // if the shape's parent changed and it has a binding, update the binding @@ -400,19 +403,21 @@ export class Editor extends EventEmitter { const descendantShape = this.getShape(id) if (!descendantShape) return - for (const binding of this.getAllBindingsFromShape(descendantShape)) { - this.getBindingUtil(binding).onAfterChangeFromShape?.({ - binding, - shapeBefore: descendantShape, - shapeAfter: descendantShape, - }) - } - for (const binding of this.getAllBindingsToShape(descendantShape)) { - this.getBindingUtil(binding).onAfterChangeToShape?.({ - binding, - shapeBefore: descendantShape, - shapeAfter: descendantShape, - }) + for (const binding of this.getBindingsInvolvingShape(descendantShape)) { + if (binding.fromId === descendantShape.id) { + this.getBindingUtil(binding).onAfterChangeFromShape?.({ + binding, + shapeBefore: descendantShape, + shapeAfter: descendantShape, + }) + } + if (binding.toId === descendantShape.id) { + this.getBindingUtil(binding).onAfterChangeToShape?.({ + binding, + shapeBefore: descendantShape, + shapeAfter: descendantShape, + }) + } } } notifyBindingAncestryChange(shapeAfter.id) @@ -451,13 +456,15 @@ export class Editor extends EventEmitter { } const deleteBindingIds: TLBindingId[] = [] - for (const binding of this.getAllBindingsFromShape(shape)) { - this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) - deleteBindingIds.push(binding.id) - } - for (const binding of this.getAllBindingsToShape(shape)) { - this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape }) - deleteBindingIds.push(binding.id) + for (const binding of this.getBindingsInvolvingShape(shape)) { + if (binding.fromId === shape.id) { + this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) + deleteBindingIds.push(binding.id) + } + if (binding.toId === shape.id) { + this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape }) + deleteBindingIds.push(binding.id) + } } this.deleteBindings(deleteBindingIds) @@ -5032,6 +5039,15 @@ export class Editor extends EventEmitter { /* -------------------- Bindings -------------------- */ + @computed + private _getBindingsIndex() { + return bindingsIndex(this) + } + + private getBindingsIndex() { + return this._getBindingsIndex().get() + } + getBinding(id: TLBindingId): TLBinding | undefined { return this.store.get(id) as TLBinding | undefined } @@ -5039,35 +5055,26 @@ export class Editor extends EventEmitter { // TODO(alex) #bindings - cache `allBindings` getters and derive type-specific ones from them getBindingsFromShape( shape: TLShape | TLShapeId, - type: Binding['type'] + type?: Binding['type'] ): Binding[] { const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - fromId: { eq: id }, - type: { eq: type }, - }) as Binding[] + return this.getBindingsInvolvingShape(id, type).filter((b) => b.fromId === id) as Binding[] } getBindingsToShape( shape: TLShape | TLShapeId, - type: Binding['type'] + type?: Binding['type'] ): Binding[] { const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - toId: { eq: id }, - type: { eq: type }, - }) as Binding[] + return this.getBindingsInvolvingShape(id, type).filter((b) => b.toId === id) as Binding[] } - getAllBindingsFromShape(shape: TLShape | TLShapeId): TLBinding[] { + getBindingsInvolvingShape( + shape: TLShape | TLShapeId, + type?: Binding['type'] + ): Binding[] { const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - fromId: { eq: id }, - }) - } - getAllBindingsToShape(shape: TLShape | TLShapeId): TLBinding[] { - const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - toId: { eq: id }, - }) + const result = this.getBindingsIndex()[id] ?? EMPTY_ARRAY + if (!type) return result as Binding[] + return result.filter((b) => b.type === type) as Binding[] } createBindings(partials: RequiredKeys[]) { @@ -8744,22 +8751,19 @@ function withoutBindingsToUnrelatedShapes( const shape = editor.getShape(shapeId) if (!shape) continue - for (const binding of editor.getAllBindingsFromShape(shapeId)) { - if (shapeIds.has(binding.toId)) { - // if we have both sides of the binding, we want to recreate it + for (const binding of editor.getBindingsInvolvingShape(shapeId)) { + const hasFrom = shapeIds.has(binding.fromId) + const hasTo = shapeIds.has(binding.toId) + if (hasFrom && hasTo) { bindingsWithBoth.add(binding.id) - } else { - // otherwise, if we only have one side, we need to record that and duplicate - // the shape as if the one it's bound to has been deleted - bindingsWithoutTo.add(binding.id) + continue } - } - for (const binding of editor.getAllBindingsToShape(shapeId)) { - if (shapeIds.has(binding.fromId)) { - bindingsWithBoth.add(binding.id) - } else { + if (!hasFrom) { bindingsWithoutFrom.add(binding.id) } + if (!hasTo) { + bindingsWithoutTo.add(binding.id) + } } } diff --git a/packages/editor/src/lib/editor/derivations/bindingsIndex.ts b/packages/editor/src/lib/editor/derivations/bindingsIndex.ts new file mode 100644 index 000000000..eabf85796 --- /dev/null +++ b/packages/editor/src/lib/editor/derivations/bindingsIndex.ts @@ -0,0 +1,89 @@ +import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state' +import { TLBinding, TLShapeId } from '@tldraw/tlschema' +import { objectMapValues } from '@tldraw/utils' +import { Editor } from '../Editor' + +type TLBindingsIndex = Record + +export const bindingsIndex = (editor: Editor): Computed => { + const { store } = editor + const bindingsHistory = store.query.filterHistory('binding') + const bindingsQuery = store.query.records('binding') + function fromScratch() { + const allBindings = bindingsQuery.get() as TLBinding[] + + const shape2Binding: TLBindingsIndex = {} + + for (const binding of allBindings) { + const { fromId, toId } = binding + const bindingsForFromShape = (shape2Binding[fromId] ??= []) + bindingsForFromShape.push(binding) + const bindingsForToShape = (shape2Binding[toId] ??= []) + bindingsForToShape.push(binding) + } + + return shape2Binding + } + + return computed('arrowBindingsIndex', (_lastValue, lastComputedEpoch) => { + if (isUninitialized(_lastValue)) { + return fromScratch() + } + + const lastValue = _lastValue + + const diff = bindingsHistory.getDiffSince(lastComputedEpoch) + + if (diff === RESET_VALUE) { + return fromScratch() + } + + let nextValue: TLBindingsIndex | undefined = undefined + + function removingBinding(binding: TLBinding) { + nextValue ??= { ...lastValue } + nextValue[binding.fromId] = nextValue[binding.fromId]?.filter((b) => b.id !== binding.id) + if (!nextValue[binding.fromId]?.length) { + delete nextValue[binding.fromId] + } + nextValue[binding.toId] = nextValue[binding.toId]?.filter((b) => b.id !== binding.id) + if (!nextValue[binding.toId]?.length) { + delete nextValue[binding.toId] + } + } + + function ensureNewArray(shapeId: TLShapeId) { + nextValue ??= { ...lastValue } + if (!nextValue[shapeId]) { + nextValue[shapeId] = [] + } else if (nextValue[shapeId] === lastValue[shapeId]) { + nextValue[shapeId] = nextValue[shapeId]!.slice(0) + } + } + + function addBinding(binding: TLBinding) { + ensureNewArray(binding.fromId) + ensureNewArray(binding.toId) + nextValue![binding.fromId]!.push(binding) + nextValue![binding.toId]!.push(binding) + } + + for (const changes of diff) { + for (const newBinding of objectMapValues(changes.added)) { + addBinding(newBinding) + } + + for (const [prev, next] of objectMapValues(changes.updated)) { + removingBinding(prev) + addBinding(next) + } + + for (const prev of objectMapValues(changes.removed)) { + removingBinding(prev) + } + } + + // TODO: add diff entries if we need them + return nextValue ?? lastValue + }) +} diff --git a/packages/tldraw/src/test/bindingsIndex.test.tsx b/packages/tldraw/src/test/bindingsIndex.test.tsx new file mode 100644 index 000000000..4b6fbfea2 --- /dev/null +++ b/packages/tldraw/src/test/bindingsIndex.test.tsx @@ -0,0 +1,260 @@ +import { TLArrowBinding, TLGeoShape, TLShapeId, createShapeId } from '@tldraw/editor' +import { TestEditor } from './TestEditor' +import { TL } from './test-jsx' + +let editor: TestEditor + +beforeEach(() => { + editor = new TestEditor() +}) + +describe('bindingsIndex', () => { + it('keeps a mapping from bound shapes to their bindings', () => { + const ids = editor.createShapesFromJsx([ + , + , + ]) + + editor.selectNone() + editor.setCurrentTool('arrow') + editor.pointerDown(50, 50) + expect(editor.getOnlySelectedShape()).toBe(null) + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([]) + + editor.pointerMove(50, 55) + expect(editor.getOnlySelectedShape()).not.toBe(null) + const arrow = editor.getOnlySelectedShape()! + expect(arrow.type).toBe('arrow') + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow]) + + editor.pointerMove(250, 50) + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([editor.getShape(arrow.id)]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([editor.getShape(arrow.id)]) + }) + + it('works if there are many arrows', () => { + const ids = { + box1: createShapeId('box1'), + box2: createShapeId('box2'), + } + + editor.createShapes([ + { type: 'geo', id: ids.box1, x: 0, y: 0, props: { w: 100, h: 100 } }, + { type: 'geo', id: ids.box2, x: 200, y: 0, props: { w: 100, h: 100 } }, + ]) + + editor.setCurrentTool('arrow') + // start at box 1 and end on box 2 + editor.pointerDown(50, 50) + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([]) + + editor.pointerMove(250, 50) + const arrow1 = editor.getOnlySelectedShape()! + expect(arrow1.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1]) + + editor.pointerUp() + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1]) + + // start at box 1 and end on the page + editor.setCurrentTool('arrow') + editor.pointerMove(50, 50).pointerDown().pointerMove(50, -50).pointerUp() + const arrow2 = editor.getOnlySelectedShape()! + expect(arrow2.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2]) + + // start outside box 1 and end in box 1 + editor.setCurrentTool('arrow') + editor.pointerDown(0, -50).pointerMove(50, 50).pointerUp(50, 50) + const arrow3 = editor.getOnlySelectedShape()! + expect(arrow3.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2, arrow3]) + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1]) + + // start at box 2 and end on the page + editor.selectNone() + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50) + editor.expectToBeIn('arrow.pointing') + editor.pointerMove(250, -50) + editor.expectToBeIn('select.dragging_handle') + const arrow4 = editor.getOnlySelectedShape()! + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4]) + + editor.pointerUp(250, -50) + editor.expectToBeIn('select.idle') + expect(arrow4.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4]) + + // start outside box 2 and enter in box 2 + editor.setCurrentTool('arrow') + editor.pointerDown(250, -50).pointerMove(250, 50).pointerUp(250, 50) + const arrow5 = editor.getOnlySelectedShape()! + expect(arrow5.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2, arrow3]) + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4, arrow5]) + }) + + describe('updating shapes', () => { + // ▲ │ │ ▲ + // │ │ │ │ + // b c e d + // ┌───┼─┴─┐ ┌──┴──┼─┐ + // │ │ ▼ │ │ ▼ │ │ + // │ └───┼─────a───┼───► │ │ + // │ 1 │ │ 2 │ + // └───────┘ └───────┘ + let arrowAId: TLShapeId + let arrowBId: TLShapeId + let arrowCId: TLShapeId + let arrowDId: TLShapeId + let arrowEId: TLShapeId + let ids: Record + beforeEach(() => { + ids = editor.createShapesFromJsx([ + , + , + ]) + + // span both boxes + editor.setCurrentTool('arrow') + editor.pointerDown(50, 50).pointerMove(250, 50).pointerUp(250, 50) + arrowAId = editor.getOnlySelectedShape()!.id + // start at box 1 and leave + editor.setCurrentTool('arrow') + editor.pointerDown(50, 50).pointerMove(50, -50).pointerUp(50, -50) + arrowBId = editor.getOnlySelectedShape()!.id + // start outside box 1 and enter + editor.setCurrentTool('arrow') + editor.pointerDown(50, -50).pointerMove(50, 50).pointerUp(50, 50) + arrowCId = editor.getOnlySelectedShape()!.id + // start at box 2 and leave + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50).pointerMove(250, -50).pointerUp(250, -50) + arrowDId = editor.getOnlySelectedShape()!.id + // start outside box 2 and enter + editor.setCurrentTool('arrow') + editor.pointerDown(250, -50).pointerMove(250, 50).pointerUp(250, 50) + arrowEId = editor.getOnlySelectedShape()!.id + }) + it('deletes the entry if you delete the bound shapes', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + editor.deleteShapes([ids.box2]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([]) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + }) + it('deletes the entry if you delete an arrow', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + editor.deleteShapes([arrowEId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(2) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.deleteShapes([arrowDId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.deleteShapes([arrowCId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(2) + + editor.deleteShapes([arrowBId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(1) + + editor.deleteShapes([arrowAId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(0) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(0) + }) + + it('deletes the entries in a batch too', () => { + editor.deleteShapes([arrowAId, arrowBId, arrowCId, arrowDId, arrowEId]) + + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(0) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(0) + }) + + it('adds new entries after initial creation', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + // draw from box 2 to box 1 + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50).pointerMove(50, 50).pointerUp(50, 50) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(4) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(4) + + // create a new box + + const { box3 } = editor.createShapesFromJsx( + + ) + + // draw from box 2 to box 3 + + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50).pointerMove(450, 50).pointerUp(450, 50) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(5) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(4) + expect(editor.getArrowsBoundTo(box3)).toHaveLength(1) + }) + + it('works when copy pasting', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.selectAll() + editor.duplicateShapes(editor.getSelectedShapeIds()) + + const [box1Clone, box2Clone] = editor + .getSelectedShapes() + .filter((shape) => editor.isShapeOfType(shape, 'geo')) + .sort((a, b) => a.x - b.x) + + expect(editor.getArrowsBoundTo(box2Clone.id)).toHaveLength(3) + expect(editor.getArrowsBoundTo(box1Clone.id)).toHaveLength(3) + }) + + it('allows bound shapes to be moved', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.nudgeShapes([ids.box2], { x: 0, y: -1 }) + + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + }) + + it('allows the arrows bound shape to change', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + // create another box + + const { box3 } = editor.createShapesFromJsx( + + ) + + // move arrowA end from box2 to box3 + const binding = editor + .getBindingsInvolvingShape(ids.box2, 'arrow') + .find((b) => b.props.terminal === 'end')! + editor.updateBinding({ ...binding, toId: box3 } satisfies TLArrowBinding) + + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(2) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + expect(editor.getArrowsBoundTo(box3)).toHaveLength(1) + }) + }) +})