diff --git a/packages/core/src/renderer/components/binding/binding.tsx b/packages/core/src/renderer/components/binding/binding.tsx new file mode 100644 index 000000000..debfb8182 --- /dev/null +++ b/packages/core/src/renderer/components/binding/binding.tsx @@ -0,0 +1,15 @@ +import type { TLBinding } from '@tldraw/core/src/types' + +interface BindingProps { + point: number[] + type: TLBinding['type'] +} + +export function Binding({ point: [x, y], type }: BindingProps): JSX.Element { + return ( + + {type === 'center' && } + {type !== 'pin' && } + + ) +} diff --git a/packages/core/src/renderer/components/binding/index.ts b/packages/core/src/renderer/components/binding/index.ts new file mode 100644 index 000000000..c29c543cd --- /dev/null +++ b/packages/core/src/renderer/components/binding/index.ts @@ -0,0 +1 @@ +export * from './binding' diff --git a/packages/core/src/renderer/hooks/useStyle.tsx b/packages/core/src/renderer/hooks/useStyle.tsx index 4288d2419..cfc5decce 100644 --- a/packages/core/src/renderer/hooks/useStyle.tsx +++ b/packages/core/src/renderer/hooks/useStyle.tsx @@ -218,6 +218,12 @@ const tlcss = css` .tl-current-parent > *[data-shy='true'] { opacity: 1; } + + .tl-binding { + fill: none; + stroke: var(--tl-selectStroke); + stroke-width: calc(2px * var(--tl-scale)); + } ` export function useTLTheme(theme?: Partial) { diff --git a/packages/core/src/types.tsx b/packages/core/src/types.tsx index f4da084e2..89efe2c67 100644 --- a/packages/core/src/types.tsx +++ b/packages/core/src/types.tsx @@ -15,7 +15,7 @@ export interface TLPageState { pointedId?: string hoveredId?: string editingId?: string - editingBindingId?: string + bindingId?: string boundsRotation?: number currentParentId?: string selectedIds: string[] @@ -29,6 +29,8 @@ export interface TLHandle { id: string index: number point: number[] + canBind?: boolean + bindingId?: string } export interface TLShape { @@ -258,6 +260,7 @@ export abstract class TLShapeUtil { isEditableText = false isAspectRatioLocked = false canEdit = false + canBind = false abstract type: T['type'] @@ -294,6 +297,17 @@ export abstract class TLShapeUtil { return [bounds.width / 2, bounds.height / 2] } + getBindingPoint( + shape: T, + point: number[], + origin: number[], + direction: number[], + padding: number, + anywhere: boolean + ): { point: number[]; distance: number } | undefined { + return undefined + } + create(props: Partial): T { return { ...this.defaultProps, ...props } } @@ -317,7 +331,7 @@ export abstract class TLShapeUtil { _targetBounds: TLBounds, _center: number[] ): Partial | void { - return + return undefined } onHandleChange( diff --git a/packages/tldraw/src/shape/shapes/arrow/arrow.tsx b/packages/tldraw/src/shape/shapes/arrow/arrow.tsx index 56fcd4db2..feb4f90da 100644 --- a/packages/tldraw/src/shape/shapes/arrow/arrow.tsx +++ b/packages/tldraw/src/shape/shapes/arrow/arrow.tsx @@ -238,8 +238,6 @@ export class Arrow extends TLDrawShapeUtil { style, } = shape - const circle = getCtp(shape) - const path = Utils.getFromCache(this.simplePathCache, shape, () => getArrowArcPath(start, end, getCtp(shape), bend) ) @@ -251,16 +249,31 @@ export class Arrow extends TLDrawShapeUtil { const arrowHeadlength = Math.min(arrowDist / 3, strokeWidth * 8) - const arcLength = Utils.getArcLength([circle[0], circle[1]], circle[2], start.point, end.point) + let insetStart: number[] + let insetEnd: number[] - const center = [circle[0], circle[1]] - const radius = circle[2] - const sa = Vec.angle(center, start.point) - const ea = Vec.angle(center, end.point) - const t = arrowHeadlength / Math.abs(arcLength) + if (bend === 0) { + insetStart = Vec.nudge(start.point, end.point, arrowHeadlength) + insetEnd = Vec.nudge(end.point, start.point, arrowHeadlength) + } else { + const circle = getCtp(shape) - const insetStart = Vec.nudgeAtAngle(center, Utils.lerpAngles(sa, ea, t), radius) - const insetEnd = Vec.nudgeAtAngle(center, Utils.lerpAngles(ea, sa, t), radius) + const arcLength = Utils.getArcLength( + [circle[0], circle[1]], + circle[2], + start.point, + end.point + ) + + const center = [circle[0], circle[1]] + const radius = circle[2] + const sa = Vec.angle(center, start.point) + const ea = Vec.angle(center, end.point) + const t = arrowHeadlength / Math.abs(arcLength) + + insetStart = Vec.nudgeAtAngle(center, Utils.lerpAngles(sa, ea, t), radius) + insetEnd = Vec.nudgeAtAngle(center, Utils.lerpAngles(ea, sa, t), radius) + } return ( @@ -413,9 +426,10 @@ export class Arrow extends TLDrawShapeUtil { center: number[] ): void | Partial => { const handle = shape.handles[binding.handleId] - const bounds = this.getBounds(shape) - const expandedBounds = Utils.expandBounds(bounds, binding.distance) + const expandedBounds = Utils.expandBounds(targetBounds, 32) + // The anchor is the "actual" point in the target shape + // (Remember that the binding.point is normalized) const anchor = Vec.sub( Vec.add( [expandedBounds.minX, expandedBounds.minY], @@ -424,68 +438,68 @@ export class Arrow extends TLDrawShapeUtil { shape.point ) - let handlePoint: number[] + // We're looking for the point to put the dragging handle + let handlePoint = anchor - const origin = Vec.add( - shape.point, - shape.handles[binding.handleId === 'start' ? 'end' : 'start'].point - ) - - const direction = Vec.uni(Vec.sub(Vec.add(anchor, shape.point), origin)) - - // TODO: Abstract this part onto individual shape utils? - - if ([TLDrawShapeType.Rectangle, TLDrawShapeType.Text].includes(target.type)) { + if (binding.distance) { const intersectBounds = Utils.expandBounds(targetBounds, binding.distance) - let hits = Intersect.ray - .bounds(origin, direction, intersectBounds) - .filter((int) => int.didIntersect) - .map((int) => int.points[0]) - .sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin)) + // The direction vector starts from the arrow's opposite handle + const origin = Vec.add( + shape.point, + shape.handles[handle.id === 'start' ? 'end' : 'start'].point + ) - if (hits.length < 2) { - hits = Intersect.ray - .bounds(origin, Vec.neg(direction), intersectBounds) + // And passes through the dragging handle + const direction = Vec.uni(Vec.sub(Vec.add(anchor, shape.point), origin)) + + if ([TLDrawShapeType.Rectangle, TLDrawShapeType.Text].includes(target.type)) { + let hits = Intersect.ray + .bounds(origin, direction, intersectBounds) .filter((int) => int.didIntersect) .map((int) => int.points[0]) .sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin)) + + if (hits.length < 2) { + hits = Intersect.ray + .bounds(origin, Vec.neg(direction), intersectBounds) + .filter((int) => int.didIntersect) + .map((int) => int.points[0]) + .sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin)) + } + + if (!hits[0]) { + console.warn('No intersection.') + return + } + + handlePoint = Vec.sub(hits[0], shape.point) + } else if (target.type === TLDrawShapeType.Ellipse) { + // const center = getShapeUtils(target).getCenter(target) + + handlePoint = Vec.nudge( + Vec.sub( + Intersect.ray + .ellipse( + origin, + direction, + center, + target.radius[0], + target.radius[1], + target.rotation || 0 + ) + .points.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))[0], + shape.point + ), + origin, + binding.distance + ) } - - if (!hits[0]) { - console.warn('No intersection.') - return - } - - handlePoint = Vec.sub(hits[0], shape.point) - } else if (target.type === TLDrawShapeType.Ellipse) { - // const center = getShapeUtils(target).getCenter(target) - - handlePoint = Vec.nudge( - Vec.sub( - Intersect.ray - .ellipse( - origin, - direction, - center, - target.radius[0], - target.radius[1], - target.rotation || 0 - ) - .points.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))[0], - shape.point - ), - origin, - binding.distance - ) - } else { - handlePoint = anchor } return this.onHandleChange( shape, { - ...shape.handles, [handle.id]: { ...handle, point: Vec.round(handlePoint), @@ -497,30 +511,23 @@ export class Arrow extends TLDrawShapeUtil { onHandleChange = ( shape: ArrowShape, - handles: ArrowShape['handles'], + handles: Partial, { shiftKey }: Partial ) => { - let nextHandles = Utils.deepMerge(shape.handles, handles) + let nextHandles = Utils.deepMerge(shape.handles, handles) let nextBend = shape.bend // If the user is holding shift, we want to snap the handles to angles - for (const id in handles) { - if ((id === 'start' || id === 'end') && shiftKey) { - const point = handles[id].point - const other = id === 'start' ? shape.handles.end : shape.handles.start + Object.values(handles).forEach((handle) => { + if ((handle.id === 'start' || handle.id === 'end') && shiftKey) { + const point = handle.point + const other = handle.id === 'start' ? shape.handles.end : shape.handles.start const angle = Vec.angle(other.point, point) const distance = Vec.dist(other.point, point) const newAngle = Utils.clampToRotationToSegments(angle, 24) - - nextHandles = { - ...nextHandles, - [id]: { - ...nextHandles[id], - point: Vec.nudgeAtAngle(other.point, newAngle, distance), - }, - } + handle.point = Vec.nudgeAtAngle(other.point, newAngle, distance) } - } + }) // If the user is moving the bend handle, we want to move the bend point if ('bend' in handles) { diff --git a/packages/tldraw/src/shape/shapes/ellipse/ellipse.tsx b/packages/tldraw/src/shape/shapes/ellipse/ellipse.tsx index 9265b7c58..b15e87f23 100644 --- a/packages/tldraw/src/shape/shapes/ellipse/ellipse.tsx +++ b/packages/tldraw/src/shape/shapes/ellipse/ellipse.tsx @@ -14,6 +14,7 @@ export class Ellipse extends TLDrawShapeUtil { type = TLDrawShapeType.Ellipse as const toolType = TLDrawToolType.Bounds pathCache = new WeakMap([]) + canBind = true defaultProps = { id: 'id', diff --git a/packages/tldraw/src/shape/shapes/rectangle/rectangle.tsx b/packages/tldraw/src/shape/shapes/rectangle/rectangle.tsx index f56535dbc..ffb716395 100644 --- a/packages/tldraw/src/shape/shapes/rectangle/rectangle.tsx +++ b/packages/tldraw/src/shape/shapes/rectangle/rectangle.tsx @@ -13,6 +13,7 @@ import { export class Rectangle extends TLDrawShapeUtil { type = TLDrawShapeType.Rectangle as const toolType = TLDrawToolType.Bounds + canBind = true pathCache = new WeakMap([]) @@ -179,6 +180,82 @@ export class Rectangle extends TLDrawShapeUtil { return Utils.getBoundsCenter(this.getBounds(shape)) } + getBindingPoint( + shape: RectangleShape, + point: number[], + origin: number[], + direction: number[], + padding: number, + anywhere: boolean + ) { + const bounds = this.getBounds(shape) + + const expandedBounds = Utils.expandBounds(bounds, padding) + + let bindingPoint: number[] + let distance: number + + // The point must be inside of the expanded bounding box + if (!Utils.pointInBounds(point, expandedBounds)) return + + // The point is inside of the shape, so we'll assume the user is + // indicating a specific point inside of the shape. + if (anywhere) { + if (Vec.dist(point, this.getCenter(shape)) < 12) { + bindingPoint = [0.5, 0.5] + } else { + bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [ + expandedBounds.width, + expandedBounds.height, + ]) + } + + distance = 0 + } else { + // Find furthest intersection between ray from + // origin through point and expanded bounds. + + // TODO: Make this a ray vs rounded rect intersection + const intersection = Intersect.ray + .bounds(origin, direction, expandedBounds) + .filter((int) => int.didIntersect) + .map((int) => int.points[0]) + .sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0] + + // The anchor is a point between the handle and the intersection + const anchor = Vec.med(point, intersection) + + // If we're close to the center, snap to the center + if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) { + bindingPoint = [0.5, 0.5] + } else { + // Or else calculate a normalized point + bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [ + expandedBounds.width, + expandedBounds.height, + ]) + } + + if (Utils.pointInBounds(point, bounds)) { + distance = 16 + } else { + // If the binding point was close to the shape's center, snap to the center + // Find the distance between the point and the real bounds of the shape + distance = Math.max( + 16, + Utils.getBoundsSides(bounds) + .map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point)) + .sort((a, b) => a - b)[0] + ) + } + } + + return { + point: bindingPoint, + distance, + } + } + hitTest(shape: RectangleShape, point: number[]) { return Utils.pointInBounds(point, this.getBounds(shape)) } diff --git a/packages/tldraw/src/shape/shapes/text/text.tsx b/packages/tldraw/src/shape/shapes/text/text.tsx index 979ff2b52..253978957 100644 --- a/packages/tldraw/src/shape/shapes/text/text.tsx +++ b/packages/tldraw/src/shape/shapes/text/text.tsx @@ -50,8 +50,8 @@ export class Text extends TLDrawShapeUtil { type = TLDrawShapeType.Text as const toolType = TLDrawToolType.Text canChangeAspectRatio = false - canBind = true isEditableText = true + canBind = true pathCache = new WeakMap([]) diff --git a/packages/tldraw/src/state/command/delete/delete.command.ts b/packages/tldraw/src/state/command/delete/delete.command.ts index 3ea84c361..434739c83 100644 --- a/packages/tldraw/src/state/command/delete/delete.command.ts +++ b/packages/tldraw/src/state/command/delete/delete.command.ts @@ -1,24 +1,75 @@ import type { Data, Command } from '../../state-types' +// - [x] Delete shapes +// - [ ] Delete bindings too +// - [ ] Update parents and possibly delete parents + export function deleteShapes(data: Data, ids: string[]): Command { + // We also need to delete any bindings that reference the deleted shapes + const bindingIdsToDelete = Object.values(data.page.bindings) + .filter((binding) => ids.includes(binding.fromId) || ids.includes(binding.toId)) + .map((binding) => binding.id) + + // We also need to update any shapes that reference the deleted bindings + const shapesWithBindingsToUpdate = Object.values(data.page.shapes).filter( + (shape) => + shape.handles && + Object.values(shape.handles).some( + (handle) => handle.bindingId && bindingIdsToDelete.includes(handle.bindingId) + ) + ) + return { - id: 'toggle_shapes', + id: 'delete_shapes', before: { page: { - shapes: Object.fromEntries(ids.map((id) => [id, data.page.shapes[id]])), + shapes: { + ...Object.fromEntries(ids.map((id) => [id, data.page.shapes[id]])), + ...Object.fromEntries( + shapesWithBindingsToUpdate.map((shape) => { + let handle = Object.values(shape.handles!).find((handle) => { + const bindingId = handle.bindingId + + if (bindingId && bindingIdsToDelete.includes(bindingId)) { + return handle + } + + return false + })! + + return [shape.id, { handles: { [handle.id]: { bindingId: handle } } }] + }) + ), + }, + bindings: Object.fromEntries(bindingIdsToDelete.map((id) => [id, data.page.bindings[id]])), }, pageState: { selectedIds: [...data.pageState.selectedIds], - hoveredId: undefined + hoveredId: undefined, }, }, after: { page: { - shapes: Object.fromEntries(ids.map((id) => [id, undefined])), + shapes: { + ...Object.fromEntries(ids.map((id) => [id, undefined])), + ...Object.fromEntries( + shapesWithBindingsToUpdate.map((shape) => { + for (const id in shape.handles) { + const handle = shape.handles[id as keyof typeof shape.handles] + const bindingId = handle.bindingId + if (bindingId && bindingIdsToDelete.includes(bindingId)) { + handle.bindingId = undefined + } + } + return [shape.id, shape] + }) + ), + }, + bindings: Object.fromEntries(bindingIdsToDelete.map((id) => [id, undefined])), }, pageState: { selectedIds: [], - hoveredId: undefined + hoveredId: undefined, }, }, } diff --git a/packages/tldraw/src/state/command/duplicate/duplicate.command.spec.ts b/packages/tldraw/src/state/command/duplicate/duplicate.command.spec.ts index 3691552d5..514701b5b 100644 --- a/packages/tldraw/src/state/command/duplicate/duplicate.command.spec.ts +++ b/packages/tldraw/src/state/command/duplicate/duplicate.command.spec.ts @@ -4,7 +4,6 @@ import { mockDocument } from '../../test-helpers' describe('Duplicate command', () => { const tlstate = new TLDrawState() tlstate.loadDocument(mockDocument) - tlstate.reset() tlstate.select('rect1') it('does, undoes and redoes command', () => { diff --git a/packages/tldraw/src/state/command/style/style.command.spec.ts b/packages/tldraw/src/state/command/style/style.command.spec.ts index 06bd60ccc..8381eb230 100644 --- a/packages/tldraw/src/state/command/style/style.command.spec.ts +++ b/packages/tldraw/src/state/command/style/style.command.spec.ts @@ -5,8 +5,7 @@ import { mockDocument } from '../../test-helpers' describe('Style command', () => { const tlstate = new TLDrawState() tlstate.loadDocument(mockDocument) - tlstate.reset() - tlstate.setSelectedIds(['rect1']) + tlstate.select('rect1') it('does, undoes and redoes command', () => { expect(tlstate.getShape('rect1').style.size).toEqual(SizeStyle.Medium) diff --git a/packages/tldraw/src/state/session/sessions/arrow/arrow.session.spec.ts b/packages/tldraw/src/state/session/sessions/arrow/arrow.session.spec.ts new file mode 100644 index 000000000..099c27ce0 --- /dev/null +++ b/packages/tldraw/src/state/session/sessions/arrow/arrow.session.spec.ts @@ -0,0 +1,41 @@ +import { TLDrawState } from '../../../tlstate' +import { mockDocument } from '../../../test-helpers' +import { TLDR } from '../../../tldr' +import type { TLDrawShape } from '../../../../shape' + +describe('Handle session', () => { + const tlstate = new TLDrawState() + + it('begins, updates and completes session', () => { + tlstate + .loadDocument(mockDocument) + .create( + TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).create({ + id: 'arrow1', + parentId: 'page1', + }) + ) + .select('arrow1') + .startHandleSession([-10, -10], 'end') + .updateHandleSession([10, 10]) + .completeSession() + .undo() + .redo() + }) + + it('cancels session', () => { + tlstate + .loadDocument(mockDocument) + .create({ + ...TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).defaultProps, + id: 'arrow1', + parentId: 'page1', + }) + .select('arrow1') + .startHandleSession([-10, -10], 'end') + .updateHandleSession([10, 10]) + .cancelSession() + + expect(tlstate.getShape('rect1').point).toStrictEqual([0, 0]) + }) +}) diff --git a/packages/tldraw/src/state/session/sessions/arrow/arrow.session.ts b/packages/tldraw/src/state/session/sessions/arrow/arrow.session.ts new file mode 100644 index 000000000..b060c2656 --- /dev/null +++ b/packages/tldraw/src/state/session/sessions/arrow/arrow.session.ts @@ -0,0 +1,251 @@ +import type { ArrowBinding, ArrowShape } from '../../../../shape' +import type { TLDrawShape } from '../../../../shape' +import type { Session } from '../../../state-types' +import type { Data } from '../../../state-types' +import { Vec, Utils, TLBinding } from '@tldraw/core' +import { TLDR } from '../../../tldr' + +export class ArrowSession implements Session { + id = 'transform_single' + newBindingId = Utils.uniqueId() + delta = [0, 0] + origin: number[] + shiftKey = false + initialShape: ArrowShape + handleId: 'start' | 'end' + bindableShapeIds: string[] + initialBinding: TLBinding | undefined + didBind = false + + constructor(data: Data, handleId: 'start' | 'end', point: number[]) { + const shapeId = data.pageState.selectedIds[0] + this.origin = point + this.handleId = handleId + this.initialShape = TLDR.getShape(data, shapeId) + this.bindableShapeIds = TLDR.getBindableShapeIds(data) + + const initialBindingId = this.initialShape.handles[this.handleId].bindingId + + if (initialBindingId) { + this.initialBinding = data.page.bindings[initialBindingId] + } + } + + start = (data: Data) => data + + update = ( + data: Data, + point: number[], + shiftKey: boolean, + altKey: boolean, + metaKey: boolean + ): Partial => { + const { initialShape, origin } = this + + const shape = TLDR.getShape(data, initialShape.id) + + TLDR.assertShapeHasProperty(shape, 'handles') + + this.shiftKey = shiftKey + + const delta = Vec.sub(point, origin) + + const handles = shape.handles + + const handleId = this.handleId as keyof typeof handles + + const handle = handles[handleId] + + let nextPoint = Vec.round(Vec.add(this.initialShape.handles[handleId].point, delta)) + + // First update the handle's next point + const change = TLDR.getShapeUtils(shape).onHandleChange( + shape, + { + [handleId]: { + ...shape.handles[handleId], + point: nextPoint, // Vec.rot(delta, shape.rotation)), + }, + }, + { delta, shiftKey, altKey, metaKey } + ) + + if (!change) return data + + let nextBindings: Record = { ...data.page.bindings } + let nextShape: ArrowShape = { ...shape, ...change } + let nextBinding: ArrowBinding | undefined = undefined + let nextTarget: TLDrawShape | undefined = undefined + + if (handle.canBind) { + const oppositeHandle = handles[handle.id === 'start' ? 'end' : 'start'] + + // Find the origin and direction of the handle + const rayOrigin = Vec.add(oppositeHandle.point, shape.point) + const rayPoint = Vec.add(nextPoint, shape.point) + const rayDirection = Vec.uni(Vec.sub(rayPoint, rayOrigin)) + + // From all bindable shapes on the page... + for (const id of this.bindableShapeIds) { + if (id === initialShape.id) continue + + const target = TLDR.getShape(data, id) + + const util = TLDR.getShapeUtils(target) + + const bindingPoint = util.getBindingPoint( + target, + rayPoint, + rayOrigin, + rayDirection, + 32, + metaKey + ) + + // Not all shapes will produce a binding point + if (!bindingPoint) continue + + // Stop at the first shape that will produce a binding point + nextTarget = target + + nextBinding = { + id: this.newBindingId, + type: 'arrow', + fromId: initialShape.id, + handleId: this.handleId, + toId: target.id, + point: Vec.round(bindingPoint.point), + distance: bindingPoint.distance, + } + + break + } + // If we didn't find a target... + if (nextBinding === undefined) { + this.didBind = false + if (handle.bindingId) { + delete nextBindings[handle.bindingId] + } + nextShape.handles[handleId].bindingId = undefined + } else if (nextTarget) { + this.didBind = true + + if (handle.bindingId && handle.bindingId !== this.newBindingId) { + delete nextBindings[handle.bindingId] + nextShape.handles[handleId].bindingId = undefined + } + + // If we found a new binding, add its id to the handle... + nextShape = { + ...nextShape, + handles: { + ...nextShape.handles, + [handleId]: { + ...nextShape.handles[handleId], + bindingId: nextBinding.id, + }, + }, + } + + // and add it to the page's bindings + nextBindings = { + ...nextBindings, + [nextBinding.id]: nextBinding, + } + + // Now update the arrow in response to the new binding + const arrowChange = TLDR.getShapeUtils(nextShape).onBindingChange( + nextShape, + nextBinding, + nextTarget, + TLDR.getShapeUtils(nextTarget).getBounds(nextTarget), + TLDR.getShapeUtils(nextTarget).getCenter(nextTarget) + ) + + if (arrowChange) { + nextShape = { + ...nextShape, + ...arrowChange, + } + } + } + } + + return { + page: { + ...data.page, + shapes: { + ...data.page.shapes, + [shape.id]: nextShape, + }, + bindings: nextBindings, + }, + pageState: { + ...data.pageState, + bindingId: nextShape.handles[handleId].bindingId, + }, + } + } + + cancel = (data: Data) => { + const { initialShape, newBindingId } = this + + const nextBindings = { ...data.page.bindings } + + if (this.didBind) { + delete nextBindings[newBindingId] + } + + return { + page: { + ...data.page, + shapes: { + ...data.page.shapes, + [initialShape.id]: initialShape, + }, + bindings: nextBindings, + }, + } + } + + complete(data: Data) { + let beforeBindings: Partial> = {} + let afterBindings: Partial> = {} + + const currentShape = TLDR.getShape(data, this.initialShape.id) + const currentBindingId = currentShape.handles[this.handleId].bindingId + + if (this.initialBinding) { + beforeBindings[this.initialBinding.id] = this.initialBinding + afterBindings[this.initialBinding.id] = undefined + } + + if (currentBindingId) { + beforeBindings[currentBindingId] = undefined + afterBindings[currentBindingId] = data.page.bindings[currentBindingId] + } + + return { + id: 'arrow', + before: { + page: { + shapes: { + [this.initialShape.id]: this.initialShape, + }, + bindings: beforeBindings, + }, + }, + after: { + page: { + shapes: { + [this.initialShape.id]: TLDR.onSessionComplete( + data, + data.page.shapes[this.initialShape.id] + ), + }, + bindings: afterBindings, + }, + }, + } + } +} diff --git a/packages/tldraw/src/state/session/sessions/arrow/index.ts b/packages/tldraw/src/state/session/sessions/arrow/index.ts new file mode 100644 index 000000000..31fb65b5b --- /dev/null +++ b/packages/tldraw/src/state/session/sessions/arrow/index.ts @@ -0,0 +1 @@ +export * from './arrow.session' diff --git a/packages/tldraw/src/state/session/sessions/brush/brush.session.ts b/packages/tldraw/src/state/session/sessions/brush/brush.session.ts index e96ba0a15..6fd9cd650 100644 --- a/packages/tldraw/src/state/session/sessions/brush/brush.session.ts +++ b/packages/tldraw/src/state/session/sessions/brush/brush.session.ts @@ -57,11 +57,10 @@ export class BrushSession implements Session { selectedIds.size === data.pageState.selectedIds.length && data.pageState.selectedIds.every((id) => selectedIds.has(id)) ) { - return data + return {} } return { - ...data, pageState: { ...data.pageState, selectedIds: Array.from(selectedIds.values()), diff --git a/packages/tldraw/src/state/session/sessions/draw/draw.session.ts b/packages/tldraw/src/state/session/sessions/draw/draw.session.ts index b4201f255..e070383d6 100644 --- a/packages/tldraw/src/state/session/sessions/draw/draw.session.ts +++ b/packages/tldraw/src/state/session/sessions/draw/draw.session.ts @@ -29,12 +29,7 @@ export class DrawSession implements Session { start = (data: Data) => data - update = ( - data: Data, - point: number[], - pressure: number, - isLocked = false - ) => { + update = (data: Data, point: number[], pressure: number, isLocked = false) => { const { snapshot } = this // Drawing while holding shift will "lock" the pen to either the @@ -82,10 +77,7 @@ export class DrawSession implements Session { // Don't add duplicate points. It's important to test against the // adjusted (low-passed) point rather than the input point. - const newPoint = Vec.round([ - ...Vec.sub(this.previous, this.origin), - pressure, - ]) + const newPoint = Vec.round([...Vec.sub(this.previous, this.origin), pressure]) if (Vec.isEqual(this.last, newPoint)) return data @@ -98,7 +90,6 @@ export class DrawSession implements Session { if (this.points.length <= 2) return data return { - ...data, page: { ...data.page, shapes: { @@ -119,7 +110,6 @@ export class DrawSession implements Session { cancel = (data: Data): Data => { const { snapshot } = this return { - ...data, page: { ...data.page, // @ts-ignore @@ -152,10 +142,7 @@ export class DrawSession implements Session { after: { page: { shapes: { - [snapshot.id]: TLDR.onSessionComplete( - data, - data.page.shapes[snapshot.id] - ), + [snapshot.id]: TLDR.onSessionComplete(data, data.page.shapes[snapshot.id]), }, }, pageState: { diff --git a/packages/tldraw/src/state/session/sessions/handle/handle.session.ts b/packages/tldraw/src/state/session/sessions/handle/handle.session.ts index 4176a402d..4015fbaae 100644 --- a/packages/tldraw/src/state/session/sessions/handle/handle.session.ts +++ b/packages/tldraw/src/state/session/sessions/handle/handle.session.ts @@ -1,3 +1,4 @@ +import { ArrowBinding } from './../../../../shape/shape-types' import { Vec } from '@tldraw/core' import type { TLDrawShape } from '../../../../shape' import type { Session } from '../../../state-types' @@ -13,12 +14,7 @@ export class HandleSession implements Session { initialShape: TLDrawShape handleId: string - constructor( - data: Data, - handleId: string, - point: number[], - commandId = 'move_handle' - ) { + constructor(data: Data, handleId: string, point: number[], commandId = 'move_handle') { const shapeId = data.pageState.selectedIds[0] this.origin = point this.handleId = handleId @@ -49,12 +45,17 @@ export class HandleSession implements Session { const handleId = this.handleId as keyof typeof handles + const handle = handles[handleId] + + let nextPoint = Vec.round(Vec.add(handle.point, delta)) + + // Now update the handle's next point const change = TLDR.getShapeUtils(shape).onHandleChange( shape, { [handleId]: { ...shape.handles[handleId], - point: Vec.round(Vec.add(handles[handleId].point, delta)), // Vec.rot(delta, shape.rotation)), + point: nextPoint, // Vec.rot(delta, shape.rotation)), }, }, { delta, shiftKey, altKey, metaKey } diff --git a/packages/tldraw/src/state/session/sessions/index.ts b/packages/tldraw/src/state/session/sessions/index.ts index 2390891e9..3ba22bfd2 100644 --- a/packages/tldraw/src/state/session/sessions/index.ts +++ b/packages/tldraw/src/state/session/sessions/index.ts @@ -6,3 +6,4 @@ export * from './draw' export * from './rotate' export * from './handle' export * from './text' +export * from './arrow' diff --git a/packages/tldraw/src/state/session/sessions/rotate/rotate.session.ts b/packages/tldraw/src/state/session/sessions/rotate/rotate.session.ts index 3f5830f20..041c2d768 100644 --- a/packages/tldraw/src/state/session/sessions/rotate/rotate.session.ts +++ b/packages/tldraw/src/state/session/sessions/rotate/rotate.session.ts @@ -19,11 +19,10 @@ export class RotateSession implements Session { start = (data: Data) => data - update = (data: Data, point: number[], isLocked = false): Data => { + update = (data: Data, point: number[], isLocked = false) => { const { commonBoundsCenter, initialShapes } = this.snapshot const next = { - ...data, page: { ...data.page, }, @@ -45,8 +44,7 @@ export class RotateSession implements Session { rot = Utils.clampToRotationToSegments(rot, 24) } - pageState.boundsRotation = - (PI2 + (this.snapshot.boundsRotation + rot)) % PI2 + pageState.boundsRotation = (PI2 + (this.snapshot.boundsRotation + rot)) % PI2 next.page.shapes = { ...next.page.shapes, @@ -58,10 +56,7 @@ export class RotateSession implements Session { ? Utils.clampToRotationToSegments(rotation + rot, 24) : rotation + rot - const nextPoint = Vec.sub( - Vec.rotWith(center, commonBoundsCenter, rot), - offset - ) + const nextPoint = Vec.sub(Vec.rotWith(center, commonBoundsCenter, rot), offset) return [ id, @@ -77,7 +72,9 @@ export class RotateSession implements Session { ), } - return next + return { + page: next.page, + } } cancel = (data: Data) => { @@ -88,16 +85,12 @@ export class RotateSession implements Session { } return { - ...data, page: { ...data.page, shapes: { ...data.page.shapes, ...Object.fromEntries( - initialShapes.map(({ id, shape }) => [ - id, - TLDR.onSessionComplete(data, shape), - ]) + initialShapes.map(({ id, shape }) => [id, TLDR.onSessionComplete(data, shape)]) ), }, }, @@ -114,11 +107,9 @@ export class RotateSession implements Session { before: { page: { shapes: Object.fromEntries( - initialShapes.map( - ({ shape: { id, point, rotation = undefined } }) => { - return [id, { point, rotation }] - } - ) + initialShapes.map(({ shape: { id, point, rotation = undefined } }) => { + return [id, { point, rotation }] + }) ), }, }, @@ -169,10 +160,7 @@ export function getRotateSnapshot(data: Data) { const center = Utils.getBoundsCenter(bounds) const offset = Vec.sub(center, shape.point) - const rotationOffset = Vec.sub( - center, - Utils.getBoundsCenter(rotatedBounds[shape.id]) - ) + const rotationOffset = Vec.sub(center, Utils.getBoundsCenter(rotatedBounds[shape.id])) return { id: shape.id, diff --git a/packages/tldraw/src/state/session/sessions/transform-single/transform-single.session.ts b/packages/tldraw/src/state/session/sessions/transform-single/transform-single.session.ts index 29b700b5a..1d0c6d0b8 100644 --- a/packages/tldraw/src/state/session/sessions/transform-single/transform-single.session.ts +++ b/packages/tldraw/src/state/session/sessions/transform-single/transform-single.session.ts @@ -27,7 +27,7 @@ export class TransformSingleSession implements Session { start = (data: Data) => data - update = (data: Data, point: number[], isAspectRatioLocked = false): Data => { + update = (data: Data, point: number[], isAspectRatioLocked = false): Partial => { const { transformType } = this const { initialShapeBounds, initialShape, id } = this.snapshot @@ -41,13 +41,10 @@ export class TransformSingleSession implements Session { transformType, Vec.sub(point, this.origin), shape.rotation, - isAspectRatioLocked || - shape.isAspectRatioLocked || - utils.isAspectRatioLocked + isAspectRatioLocked || shape.isAspectRatioLocked || utils.isAspectRatioLocked ) return { - ...data, page: { ...data.page, shapes: { @@ -72,7 +69,6 @@ export class TransformSingleSession implements Session { data.page.shapes[id] = initialShape return { - ...data, page: { ...data.page, shapes: { @@ -98,10 +94,7 @@ export class TransformSingleSession implements Session { after: { page: { shapes: { - [this.snapshot.id]: TLDR.onSessionComplete( - data, - data.page.shapes[this.snapshot.id] - ), + [this.snapshot.id]: TLDR.onSessionComplete(data, data.page.shapes[this.snapshot.id]), }, }, }, @@ -130,6 +123,4 @@ export function getTransformSingleSnapshot( } } -export type TransformSingleSnapshot = ReturnType< - typeof getTransformSingleSnapshot -> +export type TransformSingleSnapshot = ReturnType diff --git a/packages/tldraw/src/state/session/sessions/transform/transform.session.ts b/packages/tldraw/src/state/session/sessions/transform/transform.session.ts index 645f58673..17ef3f3fe 100644 --- a/packages/tldraw/src/state/session/sessions/transform/transform.session.ts +++ b/packages/tldraw/src/state/session/sessions/transform/transform.session.ts @@ -28,13 +28,13 @@ export class TransformSession implements Session { point: number[], isAspectRatioLocked = false, _altKey = false - ): Data => { + ): Partial => { const { transformType, snapshot: { shapeBounds, initialBounds, isAllAspectRatioLocked }, } = this - const next = { + const next: Data = { ...data, page: { ...data.page, @@ -89,23 +89,21 @@ export class TransformSession implements Session { ), } - return next + return { + page: next.page, + } } cancel = (data: Data) => { const { shapeBounds } = this.snapshot return { - ...data, page: { ...data.page, shapes: { ...data.page.shapes, ...Object.fromEntries( - Object.entries(shapeBounds).map(([id, { initialShape }]) => [ - id, - initialShape, - ]) + Object.entries(shapeBounds).map(([id, { initialShape }]) => [id, initialShape]) ), }, }, @@ -122,10 +120,7 @@ export class TransformSession implements Session { before: { page: { shapes: Object.fromEntries( - Object.entries(shapeBounds).map(([id, { initialShape }]) => [ - id, - initialShape, - ]) + Object.entries(shapeBounds).map(([id, { initialShape }]) => [id, initialShape]) ), }, }, @@ -143,17 +138,13 @@ export class TransformSession implements Session { } } -export function getTransformSnapshot( - data: Data, - transformType: TLBoundsEdge | TLBoundsCorner -) { +export function getTransformSnapshot(data: Data, transformType: TLBoundsEdge | TLBoundsCorner) { const initialShapes = TLDR.getSelectedBranchSnapshot(data) const hasUnlockedShapes = initialShapes.length > 0 const isAllAspectRatioLocked = initialShapes.every( - (shape) => - shape.isAspectRatioLocked || TLDR.getShapeUtils(shape).isAspectRatioLocked + (shape) => shape.isAspectRatioLocked || TLDR.getShapeUtils(shape).isAspectRatioLocked ) const shapesBounds = Object.fromEntries( @@ -164,9 +155,7 @@ export function getTransformSnapshot( const commonBounds = Utils.getCommonBounds(boundsArr) - const initialInnerBounds = Utils.getBoundsFromPoints( - boundsArr.map(Utils.getBoundsCenter) - ) + const initialInnerBounds = Utils.getBoundsFromPoints(boundsArr.map(Utils.getBoundsCenter)) // Return a mapping of shapes to bounds together with the relative // positions of the shape's bounds within the common bounds shape. diff --git a/packages/tldraw/src/state/session/sessions/translate/translate.session.ts b/packages/tldraw/src/state/session/sessions/translate/translate.session.ts index f639b3858..39b5c6d17 100644 --- a/packages/tldraw/src/state/session/sessions/translate/translate.session.ts +++ b/packages/tldraw/src/state/session/sessions/translate/translate.session.ts @@ -21,12 +21,7 @@ export class TranslateSession implements Session { return data } - update = ( - data: Data, - point: number[], - isAligned = false, - isCloning = false - ) => { + update = (data: Data, point: number[], isAligned = false, isCloning = false) => { const { clones, initialShapes } = this.snapshot const next = { @@ -91,15 +86,13 @@ export class TranslateSession implements Session { clone.id, { ...clone, - point: Vec.round( - Vec.add(next.page.shapes[clone.id].point, trueDelta) - ), + point: Vec.round(Vec.add(next.page.shapes[clone.id].point, trueDelta)), }, ]) ), } - return next + return { page: { ...next.page }, pageState: { ...next.pageState } } } // If not cloning... @@ -137,28 +130,22 @@ export class TranslateSession implements Session { shape.id, { ...next.page.shapes[shape.id], - point: Vec.round( - Vec.add(next.page.shapes[shape.id].point, trueDelta) - ), + point: Vec.round(Vec.add(next.page.shapes[shape.id].point, trueDelta)), }, ]) ), } - return next + return { page: { ...next.page }, pageState: { ...next.pageState } } } cancel = (data: Data): Data => { return { - ...data, page: { - ...data.page, // @ts-ignore - We need to set deleted shapes to undefined in order to correctly deep merge them away. shapes: { ...data.page.shapes, - ...Object.fromEntries( - this.snapshot.clones.map((clone) => [clone.id, undefined]) - ), + ...Object.fromEntries(this.snapshot.clones.map((clone) => [clone.id, undefined])), ...Object.fromEntries( this.snapshot.initialShapes.map((shape) => [ shape.id, @@ -178,38 +165,23 @@ export class TranslateSession implements Session { return { id: 'translate', before: { - ...data, page: { - ...data.page, shapes: { - ...data.page.shapes, + ...Object.fromEntries(this.snapshot.clones.map((clone) => [clone.id, undefined])), ...Object.fromEntries( - this.snapshot.clones.map((clone) => [clone.id, undefined]) - ), - ...Object.fromEntries( - this.snapshot.initialShapes.map((shape) => [ - shape.id, - { point: shape.point }, - ]) + this.snapshot.initialShapes.map((shape) => [shape.id, { point: shape.point }]) ), }, }, pageState: { - ...data.pageState, selectedIds: this.snapshot.selectedIds, }, }, after: { - ...data, page: { - ...data.page, shapes: { - ...data.page.shapes, ...Object.fromEntries( - this.snapshot.clones.map((clone) => [ - clone.id, - data.page.shapes[clone.id], - ]) + this.snapshot.clones.map((clone) => [clone.id, data.page.shapes[clone.id]]) ), ...Object.fromEntries( this.snapshot.initialShapes.map((shape) => [ @@ -220,7 +192,6 @@ export class TranslateSession implements Session { }, }, pageState: { - ...data.pageState, selectedIds: [...data.pageState.selectedIds], }, }, @@ -234,9 +205,7 @@ export function getTranslateSnapshot(data: Data) { const hasUnlockedShapes = selectedShapes.length > 0 - const initialParents = Array.from( - new Set(selectedShapes.map((s) => s.parentId)).values() - ) + const initialParents = Array.from(new Set(selectedShapes.map((s) => s.parentId)).values()) .filter((id) => id !== data.page.id) .map((id) => { const shape = TLDR.getShape(data, id) diff --git a/packages/tldraw/src/state/state-types.ts b/packages/tldraw/src/state/state-types.ts index 7a8e37be5..a8c9d4440 100644 --- a/packages/tldraw/src/state/state-types.ts +++ b/packages/tldraw/src/state/state-types.ts @@ -1,11 +1,6 @@ /* eslint-disable @typescript-eslint/ban-types */ import type { TLPage, TLPageState } from '@tldraw/core' -import type { - ShapeStyles, - TLDrawShape, - TLDrawShapeType, - TLDrawToolType, -} from '../shape' +import type { ShapeStyles, TLDrawShape, TLDrawShapeType, TLDrawToolType } from '../shape' import type { TLDrawSettings } from '../types' import type { StoreApi } from 'zustand' @@ -51,10 +46,10 @@ export interface History { export interface Session { id: string - start: (data: Readonly, ...args: any[]) => Data - update: (data: Readonly, ...args: any[]) => Data - complete: (data: Readonly, ...args: any[]) => Data | Command - cancel: (data: Readonly, ...args: any[]) => Data + start: (data: Readonly, ...args: any[]) => Partial + update: (data: Readonly, ...args: any[]) => Partial + complete: (data: Readonly, ...args: any[]) => Partial | Command + cancel: (data: Readonly, ...args: any[]) => Partial } export type TLDrawStatus = @@ -72,11 +67,6 @@ export type TLDrawStatus = | 'editing-text' // eslint-disable-next-line @typescript-eslint/no-explicit-any -export type ParametersExceptFirst = F extends ( - arg0: any, - ...rest: infer R -) => any - ? R - : never +export type ParametersExceptFirst = F extends (arg0: any, ...rest: infer R) => any ? R : never export {} diff --git a/packages/tldraw/src/state/tldr.ts b/packages/tldraw/src/state/tldr.ts index 30bf91c9c..dad0de6da 100644 --- a/packages/tldraw/src/state/tldr.ts +++ b/packages/tldraw/src/state/tldr.ts @@ -1,6 +1,6 @@ import { TLBinding, TLBounds, TLTransformInfo, Vec, Utils } from '@tldraw/core' import { getShapeUtils, ShapeStyles, ShapesWithProp, TLDrawShape, TLDrawShapeUtil } from '../shape' -import type { Data } from './state-types' +import type { Data, DeepPartial } from './state-types' export class TLDR { static getShapeUtils(shape: T | T['type']): TLDrawShapeUtil { @@ -389,31 +389,132 @@ export class TLDR { } } - static createShapes(data: Data, shapes: TLDrawShape[]): void { + static createShapes( + data: Data, + shapes: TLDrawShape[] + ): { before: DeepPartial; after: DeepPartial } { const page = this.getPage(data) - const shapeIds = shapes.map((shape) => shape.id) - // Update selected ids - this.setSelectedIds(data, shapeIds) + const before: DeepPartial = { + page: { + shapes: { + ...Object.fromEntries( + shapes.flatMap((shape) => { + const results: [string, Partial | undefined][] = [[shape.id, undefined]] - // Restore deleted shapes - shapes.forEach((shape) => { - const newShape = { ...shape } - page.shapes[shape.id] = newShape - }) + // If the shape is a child of another shape, also add that shape + if (shape.parentId !== data.page.id) { + const parent = page.shapes[shape.parentId] + results.push([parent.id, { children: parent.children! }]) + } - // Update parents - shapes.forEach((shape) => { - if (shape.parentId === data.page.id) return + return results + }) + ), + }, + }, + } - const parent = page.shapes[shape.parentId] + const after: DeepPartial = { + page: { + shapes: { + ...Object.fromEntries( + shapes.flatMap((shape) => { + const results: [string, Partial | undefined][] = [[shape.id, shape]] - this.mutate(data, parent, { - children: parent.children!.includes(shape.id) - ? parent.children - : [...parent.children!, shape.id], - }) - }) + // If the shape is a child of a different shape, update its parent + if (shape.parentId !== data.page.id) { + const parent = page.shapes[shape.parentId] + results.push([parent.id, { children: [...parent.children!, shape.id] }]) + } + + return results + }) + ), + }, + }, + } + + return { + before, + after, + } + } + + static deleteShapes( + data: Data, + shapes: TLDrawShape[] | string[] + ): { before: DeepPartial; after: DeepPartial } { + const page = this.getPage(data) + + const shapeIds = + typeof shapes[0] === 'string' + ? (shapes as string[]) + : (shapes as TLDrawShape[]).map((shape) => shape.id) + + const before: DeepPartial = { + page: { + shapes: { + // These are the shapes that we're going to delete + ...Object.fromEntries( + shapeIds.flatMap((id) => { + const shape = page.shapes[id] + const results: [string, Partial | undefined][] = [[shape.id, shape]] + + // If the shape is a child of another shape, also add that shape + if (shape.parentId !== data.page.id) { + const parent = page.shapes[shape.parentId] + results.push([parent.id, { children: parent.children! }]) + } + + return results + }) + ), + }, + bindings: { + // These are the bindings that we're going to delete + ...Object.fromEntries( + Object.values(page.bindings) + .filter((binding) => { + return shapeIds.includes(binding.fromId) || shapeIds.includes(binding.toId) + }) + .map((binding) => { + return [binding.id, binding] + }) + ), + }, + }, + } + + const after: DeepPartial = { + page: { + shapes: { + ...Object.fromEntries( + shapeIds.flatMap((id) => { + const shape = page.shapes[id] + const results: [string, Partial | undefined][] = [[shape.id, undefined]] + + // If the shape is a child of a different shape, update its parent + if (shape.parentId !== data.page.id) { + const parent = page.shapes[shape.parentId] + + results.push([ + parent.id, + { children: parent.children!.filter((id) => id !== shape.id) }, + ]) + } + + return results + }) + ), + }, + }, + } + + return { + before, + after, + } } static onSessionComplete(data: Data, shape: T) { @@ -515,7 +616,9 @@ export class TLDR { return currentStyle } - const shapeStyles = data.pageState.selectedIds.map((id) => page.shapes[id].style) + const shapeStyles = data.pageState.selectedIds.map((id) => { + return page.shapes[id].style + }) const commonStyle: ShapeStyles = {} as ShapeStyles @@ -552,6 +655,13 @@ export class TLDR { return Object.values(page.bindings) } + static getBindableShapeIds(data: Data) { + return Object.values(data.page.shapes) + .filter((shape) => TLDR.getShapeUtils(shape).canBind) + .sort((a, b) => b.childIndex - a.childIndex) + .map((shape) => shape.id) + } + static getBindingsWithShapeIds(data: Data, ids: string[]): TLBinding[] { return Array.from( new Set( @@ -567,13 +677,11 @@ export class TLDR { bindings.forEach((binding) => (page.bindings[binding.id] = binding)) } - static deleteBindings(data: Data, ids: string[]): void { - if (ids.length === 0) return - - const page = this.getPage(data) - - ids.forEach((id) => delete page.bindings[id]) - } + // static deleteBindings(data: Data, ids: string[]): void { + // if (ids.length === 0) return + // const page = this.getPage(data) + // ids.forEach((id) => delete page.bindings[id]) + // } /* -------------------------------------------------- */ /* Assertions */ diff --git a/packages/tldraw/src/state/tlstate.ts b/packages/tldraw/src/state/tlstate.ts index a7fc08b8e..9c20f4a9b 100644 --- a/packages/tldraw/src/state/tlstate.ts +++ b/packages/tldraw/src/state/tlstate.ts @@ -1,3 +1,4 @@ +import { ArrowSession } from './session/sessions/arrow/arrow.session' import type { TextShape } from './../shape/shape-types' import { FlipType } from './../types' import createReact, { PartialState } from 'zustand' @@ -111,14 +112,81 @@ export class TLDrawState implements TLCallbacks { let next = { ...current, ...result } - if ('page' in result) { + if (result.page) { + const shapes = { ...next.page.shapes } + + for (let id in shapes) { + if (!shapes[id]) delete shapes[id] + } + + const bindings = { ...next.page.bindings } + + for (let id in bindings) { + if (!bindings[id]) delete bindings[id] + } + + const changedShapeIds = new Set( + Object.values(shapes) + .filter((shape) => current.page.shapes[shape.id] !== shape) + .map((shape) => shape.id) + ) + + // Find all shapes that we need to update due to bindings + const bindingsArr = Object.values(bindings) + + const bindingsToUpdate = new Set( + bindingsArr.filter( + (binding) => changedShapeIds.has(binding.toId) || changedShapeIds.has(binding.fromId) + ) + ) + + let prevSize = bindingsToUpdate.size + + while (true) { + bindingsToUpdate.forEach((binding) => { + const fromId = binding.fromId + + for (const otherBinding of bindingsArr) { + if (otherBinding.fromId === fromId) { + bindingsToUpdate.add(otherBinding) + } + + if (otherBinding.toId === fromId) { + bindingsToUpdate.add(otherBinding) + } + } + }) + + if (bindingsToUpdate.size === prevSize) break + prevSize = bindingsToUpdate.size + } + + bindingsToUpdate.forEach((binding) => { + // Update the binding + const toShape = shapes[binding.toId] + const fromShape = shapes[binding.fromId] + const toUtils = TLDR.getShapeUtils(toShape) + + const fromDelta = TLDR.getShapeUtils(fromShape).onBindingChange( + fromShape, + binding, + toShape, + toUtils.getBounds(toShape), + toUtils.getCenter(toShape) + ) + + if (fromDelta) { + shapes[fromShape.id] = { + ...fromShape, + ...fromDelta, + } as TLDrawShape + } + }) + next.page = { ...next.page, - shapes: Object.fromEntries( - Object.entries(next.page.shapes).filter(([_, shape]) => { - return shape && (shape.parentId === next.page.id || next.page.shapes[shape.parentId]) - }) - ), + shapes, + bindings, } } @@ -126,13 +194,10 @@ export class TLDrawState implements TLCallbacks { const newSelectedStyle = TLDR.getSelectedStyle(next as Data) if (newSelectedStyle) { - next = { - ...next, - appState: { - ...current.appState, - ...next.appState, - selectedStyle: newSelectedStyle, - }, + next.appState = { + ...current.appState, + ...next.appState, + selectedStyle: newSelectedStyle, } } @@ -212,6 +277,18 @@ export class TLDrawState implements TLCallbacks { ...data.appState, ...initialData.settings, }, + page: { + ...data.page, + shapes: {}, + bindings: {}, + }, + pageState: { + ...data.pageState, + editingId: undefined, + bindingId: undefined, + hoveredId: undefined, + selectedIds: [], + }, })) this._onChange?.(this, `reset`) return this @@ -519,7 +596,13 @@ export class TLDrawState implements TLCallbacks { history.pointer = history.stack.length - 1 - this.setState((data) => Utils.deepMerge(data, history.stack[history.pointer].after)) + this.setState((data) => + Object.fromEntries( + Object.entries(command.after).map(([key, partial]) => { + return [key, Utils.deepMerge(data[key as keyof Data], partial)] + }) + ) + ) this._onChange?.(this, `command:${command.id}`) @@ -533,7 +616,13 @@ export class TLDrawState implements TLCallbacks { const command = history.stack[history.pointer] - this.setState((data) => Utils.deepMerge(data, command.before)) + this.setState((data) => + Object.fromEntries( + Object.entries(command.before).map(([key, partial]) => { + return [key, Utils.deepMerge(data[key as keyof Data], partial)] + }) + ) + ) history.pointer-- @@ -551,8 +640,13 @@ export class TLDrawState implements TLCallbacks { const command = history.stack[history.pointer] - this.setState((data) => Utils.deepMerge(data, command.after)) - + this.setState((data) => + Object.fromEntries( + Object.entries(command.after).map(([key, partial]) => { + return [key, Utils.deepMerge(data[key as keyof Data], partial)] + }) + ) + ) this._onChange?.(this, `redo:${command.id}`) return this @@ -956,14 +1050,21 @@ export class TLDrawState implements TLCallbacks { } startHandleSession = (point: number[], handleId: string, commandId?: string) => { - this.startSession( - new HandleSession(this.store.getState(), handleId, point, commandId) - ) + const selectedShape = this.page.shapes[this.selectedIds[0]] + if (selectedShape.type === TLDrawShapeType.Arrow) { + this.startSession( + new ArrowSession(this.store.getState(), handleId as 'start' | 'end', point) + ) + } else { + this.startSession( + new HandleSession(this.store.getState(), handleId, point, commandId) + ) + } return this } updateHandleSession = (point: number[], shiftKey = false, altKey = false, metaKey = false) => { - this.updateSession(point, shiftKey, altKey, metaKey) + this.updateSession(point, shiftKey, altKey, metaKey) return this } @@ -1003,7 +1104,12 @@ export class TLDrawState implements TLCallbacks { break } case 'translatingHandle': { - this.updateHandleSession(this.getPagePoint(info.point), info.shiftKey, info.altKey) + this.updateHandleSession( + this.getPagePoint(info.point), + info.shiftKey, + info.altKey, + info.metaKey + ) break } case 'creating': { @@ -1017,7 +1123,12 @@ export class TLDrawState implements TLCallbacks { break } case 'handle': { - this.updateHandleSession(this.getPagePoint(info.point), info.shiftKey, info.altKey) + this.updateHandleSession( + this.getPagePoint(info.point), + info.shiftKey, + info.altKey, + info.metaKey + ) break } case 'point': {