diff --git a/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx b/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx index 77f521369..c42582172 100644 --- a/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx +++ b/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx @@ -2,8 +2,8 @@ import { BaseBoxShapeUtil, BoundsSnapGeometry, HTMLContainer, + RecordProps, Rectangle2d, - ShapeProps, T, TLBaseShape, } from 'tldraw' @@ -23,7 +23,7 @@ type IPlayingCard = TLBaseShape< export class PlayingCardUtil extends BaseBoxShapeUtil { // [2] static override type = 'PlayingCard' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, suit: T.string, diff --git a/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts b/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts index 88d55c551..9f078653b 100644 --- a/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts +++ b/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts @@ -1,8 +1,8 @@ -import { DefaultColorStyle, ShapeProps, T } from 'tldraw' +import { DefaultColorStyle, RecordProps, T } from 'tldraw' import { ICardShape } from './card-shape-types' // Validation for our custom card shape's props, using one of tldraw's default styles -export const cardShapeProps: ShapeProps = { +export const cardShapeProps: RecordProps = { w: T.number, h: T.number, color: DefaultColorStyle, diff --git a/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx b/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx index 7660a8282..d3e9e247a 100644 --- a/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx +++ b/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx @@ -1,8 +1,8 @@ import { Geometry2d, HTMLContainer, + RecordProps, Rectangle2d, - ShapeProps, ShapeUtil, T, TLBaseShape, @@ -28,7 +28,7 @@ type ICustomShape = TLBaseShape< export class MyShapeUtil extends ShapeUtil { // [a] static override type = 'my-custom-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, text: T.string, diff --git a/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx b/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx index 8e1f85652..792419915 100644 --- a/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx +++ b/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx @@ -1,7 +1,7 @@ import { BaseBoxShapeUtil, HTMLContainer, - ShapeProps, + RecordProps, T, TLBaseShape, TLOnEditEndHandler, @@ -23,7 +23,7 @@ type IMyEditableShape = TLBaseShape< export class EditableShapeUtil extends BaseBoxShapeUtil { static override type = 'my-editable-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, animal: T.number, diff --git a/apps/examples/src/examples/interactive-shape/my-interactive-shape-util.tsx b/apps/examples/src/examples/interactive-shape/my-interactive-shape-util.tsx index 8eebe1d41..6b35e18cb 100644 --- a/apps/examples/src/examples/interactive-shape/my-interactive-shape-util.tsx +++ b/apps/examples/src/examples/interactive-shape/my-interactive-shape-util.tsx @@ -1,4 +1,4 @@ -import { BaseBoxShapeUtil, HTMLContainer, ShapeProps, T, TLBaseShape } from 'tldraw' +import { BaseBoxShapeUtil, HTMLContainer, RecordProps, T, TLBaseShape } from 'tldraw' // There's a guide at the bottom of this file! @@ -14,7 +14,7 @@ type IMyInteractiveShape = TLBaseShape< export class myInteractiveShape extends BaseBoxShapeUtil { static override type = 'my-interactive-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, checked: T.boolean, diff --git a/apps/examples/src/examples/slides/SlideShapeUtil.tsx b/apps/examples/src/examples/slides/SlideShapeUtil.tsx index e2c01542b..58aa9c5ca 100644 --- a/apps/examples/src/examples/slides/SlideShapeUtil.tsx +++ b/apps/examples/src/examples/slides/SlideShapeUtil.tsx @@ -1,9 +1,9 @@ import { useCallback } from 'react' import { Geometry2d, + RecordProps, Rectangle2d, SVGContainer, - ShapeProps, ShapeUtil, T, TLBaseShape, @@ -24,7 +24,7 @@ export type SlideShape = TLBaseShape< export class SlideShapeUtil extends ShapeUtil { static override type = 'slide' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, } diff --git a/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx b/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx index 1c0d3e1f0..45f9183ac 100644 --- a/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx +++ b/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx @@ -8,7 +8,7 @@ import { Geometry2d, LABEL_FONT_SIZES, Polygon2d, - ShapePropsType, + RecordPropsType, ShapeUtil, T, TEXT_PROPS, @@ -52,7 +52,7 @@ export const speechBubbleShapeProps = { tail: vecModelValidator, } -export type SpeechBubbleShapeProps = ShapePropsType +export type SpeechBubbleShapeProps = RecordPropsType export type SpeechBubbleShape = TLBaseShape<'speech-bubble', SpeechBubbleShapeProps> export class SpeechBubbleUtil extends ShapeUtil { diff --git a/packages/editor/api-report.md b/packages/editor/api-report.md index 99c6779e9..0605f4a71 100644 --- a/packages/editor/api-report.md +++ b/packages/editor/api-report.md @@ -29,22 +29,27 @@ import { default as React_2 } from 'react'; import * as React_3 from 'react'; import { ReactElement } from 'react'; import { ReactNode } from 'react'; +import { RecordProps } from '@tldraw/tlschema'; import { RecordsDiff } from '@tldraw/store'; import { SerializedSchema } from '@tldraw/store'; import { SerializedStore } from '@tldraw/store'; -import { ShapeProps } from '@tldraw/tlschema'; import { Signal } from '@tldraw/state'; import { Store } from '@tldraw/store'; import { StoreSchema } from '@tldraw/store'; import { StoreSnapshot } from '@tldraw/store'; import { StyleProp } from '@tldraw/tlschema'; import { StylePropValue } from '@tldraw/tlschema'; +import { TLArrowBinding } from '@tldraw/tlschema'; +import { TLArrowBindingProps } from '@tldraw/tlschema'; import { TLArrowShape } from '@tldraw/tlschema'; import { TLArrowShapeArrowheadStyle } from '@tldraw/tlschema'; import { TLAsset } from '@tldraw/tlschema'; import { TLAssetId } from '@tldraw/tlschema'; import { TLAssetPartial } from '@tldraw/tlschema'; import { TLBaseShape } from '@tldraw/tlschema'; +import { TLBinding } from '@tldraw/tlschema'; +import { TLBindingId } from '@tldraw/tlschema'; +import { TLBindingPartial } from '@tldraw/tlschema'; import { TLBookmarkAsset } from '@tldraw/tlschema'; import { TLCamera } from '@tldraw/tlschema'; import { TLCursor } from '@tldraw/tlschema'; @@ -60,14 +65,15 @@ import { TLInstancePresence } from '@tldraw/tlschema'; import { TLPage } from '@tldraw/tlschema'; import { TLPageId } from '@tldraw/tlschema'; import { TLParentId } from '@tldraw/tlschema'; +import { TLPropsMigrations } from '@tldraw/tlschema'; import { TLRecord } from '@tldraw/tlschema'; import { TLScribble } from '@tldraw/tlschema'; import { TLShape } from '@tldraw/tlschema'; import { TLShapeId } from '@tldraw/tlschema'; import { TLShapePartial } from '@tldraw/tlschema'; -import { TLShapePropsMigrations } from '@tldraw/tlschema'; import { TLStore } from '@tldraw/tlschema'; import { TLStoreProps } from '@tldraw/tlschema'; +import { TLUnknownBinding } from '@tldraw/tlschema'; import { TLUnknownShape } from '@tldraw/tlschema'; import { TLVideoAsset } from '@tldraw/tlschema'; import { track } from '@tldraw/state'; @@ -138,6 +144,12 @@ export class Arc2d extends Geometry2d { // @public export function areAnglesCompatible(a: number, b: number): boolean; +// @internal +export function arrowBindingMakeItNotSo(editor: Editor, arrow: TLArrowShape, terminal: 'end' | 'start'): void; + +// @internal +export function arrowBindingMakeItSo(editor: Editor, arrow: TLArrowShape | TLShapeId, target: TLShape | TLShapeId, props: TLArrowBindingProps): void; + export { Atom } export { atom } @@ -169,6 +181,23 @@ export abstract class BaseBoxShapeUtil extends Sha onResize: TLOnResizeHandler; } +// @public (undocumented) +export abstract class BindingUtil { + constructor(editor: Editor); + // (undocumented) + editor: Editor; + abstract getDefaultProps(): Binding['props']; + // (undocumented) + static migrations?: TLPropsMigrations; + // (undocumented) + onAfterShapeChange?(binding: Binding, direction: 'from' | 'to', prev: TLShape, next: TLShape): void; + // (undocumented) + onBeforeShapeDelete?(binding: Binding, direction: 'from' | 'to', shape: TLShape): void; + // (undocumented) + static props?: RecordProps; + static type: string; +} + // @public export interface BoundsSnapGeometry { points?: VecModel[]; @@ -585,7 +614,7 @@ export class Edge2d extends Geometry2d { // @public (undocumented) export class Editor extends EventEmitter { - constructor({ store, user, shapeUtils, tools, getContainer, initialState, inferDarkMode, }: TLEditorOptions); + constructor({ store, user, shapeUtils, bindingUtils, tools, getContainer, initialState, inferDarkMode, }: TLEditorOptions); addOpenMenu(id: string): this; alignShapes(shapes: TLShape[] | TLShapeId[], operation: 'bottom' | 'center-horizontal' | 'center-vertical' | 'left' | 'right' | 'top'): this; animateShape(partial: null | TLShapePartial | undefined, animationOptions?: TLAnimationOptions): this; @@ -605,6 +634,9 @@ export class Editor extends EventEmitter { bail(): this; bailToMark(id: string): this; batch(fn: () => void, opts?: TLHistoryBatchOptions): this; + bindingUtils: { + readonly [K in string]?: BindingUtil; + }; bringForward(shapes: TLShape[] | TLShapeId[]): this; bringToFront(shapes: TLShape[] | TLShapeId[]): this; cancel(): this; @@ -619,6 +651,10 @@ export class Editor extends EventEmitter { // @internal (undocumented) crash(error: unknown): this; createAssets(assets: TLAsset[]): this; + // (undocumented) + createBinding(partial: RequiredKeys): void; + // (undocumented) + createBindings(partials: RequiredKeys[]): void; // @internal (undocumented) createErrorAnnotations(origin: string, willCrashApp: 'unknown' | boolean): { extras: { @@ -636,6 +672,10 @@ export class Editor extends EventEmitter { createShape(shape: OptionalKeys, 'id'>): this; createShapes(shapes: OptionalKeys, 'id'>[]): this; deleteAssets(assets: TLAsset[] | TLAssetId[]): this; + // (undocumented) + deleteBinding(binding: TLBinding | TLBindingId): void; + // (undocumented) + deleteBindings(bindings: (TLBinding | TLBindingId)[]): void; deleteOpenMenu(id: string): this; deletePage(page: TLPage | TLPageId): this; deleteShape(id: TLShapeId): this; @@ -671,15 +711,25 @@ export class Editor extends EventEmitter { findCommonAncestor(shapes: TLShape[] | TLShapeId[], predicate?: (shape: TLShape) => boolean): TLShapeId | undefined; findShapeAncestor(shape: TLShape | TLShapeId, predicate: (parent: TLShape) => boolean): TLShape | undefined; flipShapes(shapes: TLShape[] | TLShapeId[], operation: 'horizontal' | 'vertical'): this; + // (undocumented) + getAllBindingsForShape(shape: TLShape | TLShapeId): TLBinding[]; getAncestorPageId(shape?: TLShape | TLShapeId): TLPageId | undefined; getArrowInfo(shape: TLArrowShape | TLShapeId): TLArrowInfo | undefined; - getArrowsBoundTo(shapeId: TLShapeId): { - arrowId: TLShapeId; - handleId: "end" | "start"; - }[]; + getArrowsBoundTo(shapeId: TLShapeId): TLArrowShape[]; getAsset(asset: TLAsset | TLAssetId): TLAsset | undefined; getAssetForExternalContent(info: TLExternalAssetContent): Promise; getAssets(): (TLBookmarkAsset | TLImageAsset | TLVideoAsset)[]; + // (undocumented) + getBinding(id: TLBindingId): TLBinding | undefined; + // (undocumented) + getBindingsFromShape(shape: TLShape | TLShapeId, type: Binding['type']): Binding[]; + // (undocumented) + getBindingsToShape(shape: TLShape | TLShapeId, type: Binding['type']): Binding[]; + getBindingUtil(binding: S | TLBindingPartial): BindingUtil; + // (undocumented) + getBindingUtil(type: S['type']): BindingUtil; + // (undocumented) + getBindingUtil(type: T extends BindingUtil ? R['type'] : string): T; getCamera(): TLCamera; getCameraState(): "idle" | "moving"; getCanRedo(): boolean; @@ -940,6 +990,10 @@ export class Editor extends EventEmitter { // (undocumented) ungroupShapes(ids: TLShape[]): this; updateAssets(assets: TLAssetPartial[]): this; + // (undocumented) + updateBinding(partial: TLBindingPartial): void; + // (undocumented) + updateBindings(partials: (null | TLBindingPartial | undefined)[]): void; updateCurrentPageState(partial: Partial>, historyOptions?: TLHistoryBatchOptions): this; // (undocumented) _updateCurrentPageState: (partial: Partial>, historyOptions?: TLHistoryBatchOptions) => void; @@ -1085,8 +1139,11 @@ export abstract class Geometry2d { // @public export function getArcMeasure(A: number, B: number, sweepFlag: number, largeArcFlag: number): number; +// @internal (undocumented) +export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBindings; + // @public (undocumented) -export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape): { +export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape, bindings: TLArrowBindings): { end: Vec; start: Vec; }; @@ -1182,11 +1239,11 @@ export class GroupShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLGroupShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onChildrenChange: TLOnChildrenChangeHandler; // (undocumented) - static props: ShapeProps; + static props: RecordProps; // (undocumented) static type: "group"; } @@ -1715,7 +1772,7 @@ export abstract class ShapeUtil { abstract indicator(shape: Shape): any; isAspectRatioLocked: TLShapeUtilFlag; // (undocumented) - static migrations?: LegacyMigrations | TLShapePropsMigrations; + static migrations?: LegacyMigrations | MigrationSequence | TLPropsMigrations; onBeforeCreate?: TLOnBeforeCreateHandler; onBeforeUpdate?: TLOnBeforeUpdateHandler; // @internal @@ -1740,7 +1797,7 @@ export abstract class ShapeUtil { onTranslateEnd?: TLOnTranslateEndHandler; onTranslateStart?: TLOnTranslateStartHandler; // (undocumented) - static props?: ShapeProps; + static props?: RecordProps; // @internal providesBackgroundForChildren(shape: Shape): boolean; toBackgroundSvg?(shape: Shape, ctx: SvgExportContext): null | Promise | ReactElement; @@ -1985,6 +2042,9 @@ export type TLAnimationOptions = Partial<{ easing: (t: number) => number; }>; +// @public (undocumented) +export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor; + // @public (undocumented) export type TLAnyShapeUtilConstructor = TLShapeUtilConstructor; @@ -2006,6 +2066,7 @@ export interface TLArcInfo { // @public (undocumented) export type TLArrowInfo = { + bindings: TLArrowBindings; bodyArc: TLArcInfo; end: TLArrowPoint; handleArc: TLArcInfo; @@ -2014,6 +2075,7 @@ export type TLArrowInfo = { middle: VecLike; start: TLArrowPoint; } | { + bindings: TLArrowBindings; end: TLArrowPoint; isStraight: true; isValid: boolean; @@ -2059,6 +2121,18 @@ export type TLBeforeCreateHandler = (record: R, source: 'rem // @public (undocumented) export type TLBeforeDeleteHandler = (record: R, source: 'remote' | 'user') => false | void; +// @public (undocumented) +export interface TLBindingUtilConstructor = BindingUtil> { + // (undocumented) + new (editor: Editor): U; + // (undocumented) + migrations?: TLPropsMigrations; + // (undocumented) + props?: RecordProps; + // (undocumented) + type: T['type']; +} + // @public (undocumented) export type TLBrushProps = { brush: BoxModel; @@ -2139,6 +2213,7 @@ export const TldrawEditor: React_2.NamedExoticComponent; // @public export interface TldrawEditorBaseProps { autoFocus?: boolean; + bindingUtils?: readonly TLAnyBindingUtilConstructor[]; children?: ReactNode; className?: string; components?: TLEditorComponents; @@ -2170,6 +2245,7 @@ export type TLEditorComponents = Partial<{ // @public (undocumented) export interface TLEditorOptions { + bindingUtils: readonly TLBindingUtilConstructor[]; getContainer: () => HTMLElement; inferDarkMode?: boolean; initialState?: string; @@ -2595,9 +2671,9 @@ export interface TLShapeUtilConstructor; + props?: RecordProps; // (undocumented) type: T['type']; } @@ -2633,6 +2709,7 @@ export type TLStoreOptions = { id?: string; initialData?: SerializedStore; } & ({ + bindingUtils?: readonly TLAnyBindingUtilConstructor[]; migrations?: readonly MigrationSequence[]; shapeUtils?: readonly TLAnyShapeUtilConstructor[]; } | { diff --git a/packages/editor/src/index.ts b/packages/editor/src/index.ts index d89708f3e..a6f0d75b6 100644 --- a/packages/editor/src/index.ts +++ b/packages/editor/src/index.ts @@ -104,6 +104,7 @@ export { type TLStoreOptions, } from './lib/config/createTLStore' export { createTLUser } from './lib/config/createTLUser' +export { type TLAnyBindingUtilConstructor } from './lib/config/defaultBindings' export { coreShapes, type TLAnyShapeUtilConstructor } from './lib/config/defaultShapes' export { ANIMATION_MEDIUM_MS, @@ -130,6 +131,7 @@ export { type TLEditorOptions, type TLResizeShapeOptions, } from './lib/editor/Editor' +export { BindingUtil, type TLBindingUtilConstructor } from './lib/editor/bindings/BindingUtil' export { HistoryManager } from './lib/editor/managers/HistoryManager' export type { SideEffectManager, @@ -186,7 +188,12 @@ export { type TLArrowInfo, type TLArrowPoint, } from './lib/editor/shapes/shared/arrow/arrow-types' -export { getArrowTerminalsInArrowSpace } from './lib/editor/shapes/shared/arrow/shared' +export { + arrowBindingMakeItNotSo, + arrowBindingMakeItSo, + getArrowBindings, + getArrowTerminalsInArrowSpace, +} from './lib/editor/shapes/shared/arrow/shared' export { resizeBox, type ResizeBoxOptions } from './lib/editor/shapes/shared/resizeBox' export { BaseBoxShapeTool } from './lib/editor/tools/BaseBoxShapeTool/BaseBoxShapeTool' export { StateNode, type TLStateNodeConstructor } from './lib/editor/tools/StateNode' diff --git a/packages/editor/src/lib/TldrawEditor.tsx b/packages/editor/src/lib/TldrawEditor.tsx index 87ea9b1fe..19e092778 100644 --- a/packages/editor/src/lib/TldrawEditor.tsx +++ b/packages/editor/src/lib/TldrawEditor.tsx @@ -15,6 +15,7 @@ import classNames from 'classnames' import { OptionalErrorBoundary } from './components/ErrorBoundary' import { DefaultErrorFallback } from './components/default-components/DefaultErrorFallback' import { TLUser, createTLUser } from './config/createTLUser' +import { TLAnyBindingUtilConstructor } from './config/defaultBindings' import { TLAnyShapeUtilConstructor } from './config/defaultShapes' import { Editor } from './editor/Editor' import { TLStateNodeConstructor } from './editor/tools/StateNode' @@ -75,6 +76,11 @@ export interface TldrawEditorBaseProps { */ shapeUtils?: readonly TLAnyShapeUtilConstructor[] + /** + * An array of binding utils to use in the editor. + */ + bindingUtils?: readonly TLAnyBindingUtilConstructor[] + /** * An array of tools to add to the editor's state chart. */ @@ -135,6 +141,7 @@ declare global { } const EMPTY_SHAPE_UTILS_ARRAY = [] as const +const EMPTY_BINDING_UTILS_ARRAY = [] as const const EMPTY_TOOLS_ARRAY = [] as const /** @public */ @@ -157,6 +164,7 @@ export const TldrawEditor = memo(function TldrawEditor({ const withDefaults = { ...rest, shapeUtils: rest.shapeUtils ?? EMPTY_SHAPE_UTILS_ARRAY, + bindingUtils: rest.bindingUtils ?? EMPTY_BINDING_UTILS_ARRAY, tools: rest.tools ?? EMPTY_TOOLS_ARRAY, components, } @@ -197,12 +205,25 @@ export const TldrawEditor = memo(function TldrawEditor({ }) function TldrawEditorWithOwnStore( - props: Required + props: Required< + TldrawEditorProps & { store: undefined; user: TLUser }, + 'shapeUtils' | 'bindingUtils' | 'tools' + > ) { - const { defaultName, snapshot, initialData, shapeUtils, persistenceKey, sessionId, user } = props + const { + defaultName, + snapshot, + initialData, + shapeUtils, + bindingUtils, + persistenceKey, + sessionId, + user, + } = props const syncedStore = useLocalStore({ shapeUtils, + bindingUtils, initialData, persistenceKey, sessionId, @@ -219,7 +240,7 @@ const TldrawEditorWithLoadingStore = memo(function TldrawEditorBeforeLoading({ ...rest }: Required< TldrawEditorProps & { store: TLStoreWithStatus; user: TLUser }, - 'shapeUtils' | 'tools' + 'shapeUtils' | 'bindingUtils' | 'tools' >) { const container = useContainer() @@ -262,6 +283,7 @@ function TldrawEditorWithReadyStore({ store, tools, shapeUtils, + bindingUtils, user, initialState, autoFocus = true, @@ -271,7 +293,7 @@ function TldrawEditorWithReadyStore({ store: TLStore user: TLUser }, - 'shapeUtils' | 'tools' + 'shapeUtils' | 'bindingUtils' | 'tools' >) { const { ErrorFallback } = useEditorComponents() const container = useContainer() @@ -281,6 +303,7 @@ function TldrawEditorWithReadyStore({ const editor = new Editor({ store, shapeUtils, + bindingUtils, tools, getContainer: () => container, user, @@ -292,7 +315,7 @@ function TldrawEditorWithReadyStore({ return () => { editor.dispose() } - }, [container, shapeUtils, tools, store, user, initialState, inferDarkMode]) + }, [container, shapeUtils, bindingUtils, tools, store, user, initialState, inferDarkMode]) const crashingError = useSyncExternalStore( useCallback( diff --git a/packages/editor/src/lib/config/createTLStore.ts b/packages/editor/src/lib/config/createTLStore.ts index 3361a18a5..f71511e35 100644 --- a/packages/editor/src/lib/config/createTLStore.ts +++ b/packages/editor/src/lib/config/createTLStore.ts @@ -1,13 +1,6 @@ import { HistoryEntry, MigrationSequence, SerializedStore, Store, StoreSchema } from '@tldraw/store' -import { - SchemaShapeInfo, - TLRecord, - TLStore, - TLStoreProps, - TLUnknownShape, - createTLSchema, -} from '@tldraw/tlschema' -import { TLShapeUtilConstructor } from '../editor/shapes/ShapeUtil' +import { SchemaPropsInfo, TLRecord, TLStore, TLStoreProps, createTLSchema } from '@tldraw/tlschema' +import { TLAnyBindingUtilConstructor, checkBindings } from './defaultBindings' import { TLAnyShapeUtilConstructor, checkShapesAndAddCore } from './defaultShapes' /** @public */ @@ -16,7 +9,11 @@ export type TLStoreOptions = { defaultName?: string id?: string } & ( - | { shapeUtils?: readonly TLAnyShapeUtilConstructor[]; migrations?: readonly MigrationSequence[] } + | { + shapeUtils?: readonly TLAnyShapeUtilConstructor[] + migrations?: readonly MigrationSequence[] + bindingUtils?: readonly TLAnyBindingUtilConstructor[] + } | { schema?: StoreSchema } ) @@ -41,9 +38,12 @@ export function createTLStore({ rest.schema : // we need a schema createTLSchema({ - shapes: currentPageShapesToShapeMap( + shapes: utilsToMap( checkShapesAndAddCore('shapeUtils' in rest && rest.shapeUtils ? rest.shapeUtils : []) ), + bindings: utilsToMap( + checkBindings('bindingUtils' in rest && rest.bindingUtils ? rest.bindingUtils : []) + ), migrations: 'migrations' in rest ? rest.migrations : [], }) @@ -57,9 +57,9 @@ export function createTLStore({ }) } -function currentPageShapesToShapeMap(shapeUtils: TLShapeUtilConstructor[]) { +function utilsToMap(utils: T[]) { return Object.fromEntries( - shapeUtils.map((s): [string, SchemaShapeInfo] => [ + utils.map((s): [string, SchemaPropsInfo] => [ s.type, { props: s.props, diff --git a/packages/editor/src/lib/config/defaultBindings.ts b/packages/editor/src/lib/config/defaultBindings.ts new file mode 100644 index 000000000..19d2f7e09 --- /dev/null +++ b/packages/editor/src/lib/config/defaultBindings.ts @@ -0,0 +1,19 @@ +import { TLBindingUtilConstructor } from '../editor/bindings/BindingUtil' + +/** @public */ +export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor + +export function checkBindings(customBindings: readonly TLAnyBindingUtilConstructor[]) { + const bindings = [] as TLAnyBindingUtilConstructor[] + + const addedCustomBindingTypes = new Set() + for (const customBinding of customBindings) { + if (addedCustomBindingTypes.has(customBinding.type)) { + throw new Error(`Binding type "${customBinding.type}" is defined more than once`) + } + bindings.push(customBinding) + addedCustomBindingTypes.add(customBinding.type) + } + + return bindings +} diff --git a/packages/editor/src/lib/editor/Editor.ts b/packages/editor/src/lib/editor/Editor.ts index 88e44208f..afa685930 100644 --- a/packages/editor/src/lib/editor/Editor.ts +++ b/packages/editor/src/lib/editor/Editor.ts @@ -6,10 +6,14 @@ import { PageRecordType, StyleProp, StylePropValue, + TLArrowBinding, TLArrowShape, TLAsset, TLAssetId, TLAssetPartial, + TLBinding, + TLBindingId, + TLBindingPartial, TLCursor, TLCursorType, TLDOCUMENT_ID, @@ -31,8 +35,10 @@ import { TLShapeId, TLShapePartial, TLStore, + TLUnknownBinding, TLUnknownShape, TLVideoAsset, + createBindingId, createShapeId, getShapePropKeysByStyle, isPageId, @@ -60,6 +66,7 @@ import { EventEmitter } from 'eventemitter3' import { flushSync } from 'react-dom' import { createRoot } from 'react-dom/client' import { TLUser, createTLUser } from '../config/createTLUser' +import { checkBindings } from '../config/defaultBindings' import { checkShapesAndAddCore } from '../config/defaultShapes' import { ANIMATION_MEDIUM_MS, @@ -98,7 +105,7 @@ import { getIncrementedName } from '../utils/getIncrementedName' import { getReorderingShapesChanges } from '../utils/reorderShapes' import { applyRotationToSnapshotShapes, getRotationSnapshot } from '../utils/rotation' import { uniqueId } from '../utils/uniqueId' -import { arrowBindingsIndex } from './derivations/arrowBindingsIndex' +import { BindingUtil, TLBindingUtilConstructor } from './bindings/BindingUtil' import { notVisibleShapes } from './derivations/notVisibleShapes' import { parentsToChildren } from './derivations/parentsToChildren' import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage' @@ -115,7 +122,7 @@ import { UserPreferencesManager } from './managers/UserPreferencesManager' import { ShapeUtil, TLResizeMode, TLShapeUtilConstructor } from './shapes/ShapeUtil' import { TLArrowInfo } from './shapes/shared/arrow/arrow-types' import { getCurvedArrowInfo } from './shapes/shared/arrow/curved-arrow' -import { getArrowTerminalsInArrowSpace, getIsArrowStraight } from './shapes/shared/arrow/shared' +import { getArrowBindings, getIsArrowStraight } from './shapes/shared/arrow/shared' import { getStraightArrowInfo } from './shapes/shared/arrow/straight-arrow' import { RootState } from './tools/RootState' import { StateNode, TLStateNodeConstructor } from './tools/StateNode' @@ -160,6 +167,10 @@ export interface TLEditorOptions { * An array of shapes to use in the editor. These will be used to create and manage shapes in the editor. */ shapeUtils: readonly TLShapeUtilConstructor[] + /** + * An array of bindings to use in the editor. These will be used to create and manage bindings in the editor. + */ + bindingUtils: readonly TLBindingUtilConstructor[] /** * An array of tools to use in the editor. These will be used to handle events and manage user interactions in the editor. */ @@ -189,6 +200,7 @@ export class Editor extends EventEmitter { store, user, shapeUtils, + bindingUtils, tools, getContainer, initialState, @@ -248,6 +260,14 @@ export class Editor extends EventEmitter { this.shapeUtils = _shapeUtils this.styleProps = _styleProps + const allBindingUtils = checkBindings(bindingUtils) + const _bindingUtils = {} as Record> + for (const Util of allBindingUtils) { + const util = new Util(this) + _bindingUtils[Util.type] = util + } + this.bindingUtils = _bindingUtils + // Tools. // Accept tools from constructor parameters which may not conflict with the root note's default or // "baked in" tools, select and zoom. @@ -265,118 +285,118 @@ export class Editor extends EventEmitter { const invalidParents = new Set() - const reparentArrow = (arrowId: TLArrowShape['id']) => { - const arrow = this.getShape(arrowId) - if (!arrow) return - const { start, end } = arrow.props - const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined - const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined + // const reparentArrow = (arrowId: TLArrowShape['id']) => { + // const arrow = this.getShape(arrowId) + // if (!arrow) return + // const { start, end } = arrow.props + // const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined + // const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined - const parentPageId = this.getAncestorPageId(arrow) - if (!parentPageId) return + // const parentPageId = this.getAncestorPageId(arrow) + // if (!parentPageId) return - let nextParentId: TLParentId - if (startShape && endShape) { - // if arrow has two bindings, always parent arrow to closest common ancestor of the bindings - nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId - } else if (startShape || endShape) { - const bindingParentId = (startShape || endShape)?.parentId - // If the arrow and the shape that it is bound to have the same parent, then keep that parent - if (bindingParentId && bindingParentId === arrow.parentId) { - nextParentId = arrow.parentId - } else { - // if arrow has one binding, keep arrow on its own page - nextParentId = parentPageId - } - } else { - return - } + // let nextParentId: TLParentId + // if (startShape && endShape) { + // // if arrow has two bindings, always parent arrow to closest common ancestor of the bindings + // nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId + // } else if (startShape || endShape) { + // const bindingParentId = (startShape || endShape)?.parentId + // // If the arrow and the shape that it is bound to have the same parent, then keep that parent + // if (bindingParentId && bindingParentId === arrow.parentId) { + // nextParentId = arrow.parentId + // } else { + // // if arrow has one binding, keep arrow on its own page + // nextParentId = parentPageId + // } + // } else { + // return + // } - if (nextParentId && nextParentId !== arrow.parentId) { - this.reparentShapes([arrowId], nextParentId) - } + // if (nextParentId && nextParentId !== arrow.parentId) { + // this.reparentShapes([arrowId], nextParentId) + // } - const reparentedArrow = this.getShape(arrowId) - if (!reparentedArrow) throw Error('no reparented arrow') + // const reparentedArrow = this.getShape(arrowId) + // if (!reparentedArrow) throw Error('no reparented arrow') - const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape) - const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape) + // const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape) + // const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape) - let highestSibling: TLShape | undefined + // let highestSibling: TLShape | undefined - if (startSibling && endSibling) { - highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling - } else if (startSibling && !endSibling) { - highestSibling = startSibling - } else if (endSibling && !startSibling) { - highestSibling = endSibling - } else { - return - } + // if (startSibling && endSibling) { + // highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling + // } else if (startSibling && !endSibling) { + // highestSibling = startSibling + // } else if (endSibling && !startSibling) { + // highestSibling = endSibling + // } else { + // return + // } - let finalIndex: IndexKey + // let finalIndex: IndexKey - const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId) - .map((id) => this.getShape(id)!) - .filter((sibling) => sibling.index > highestSibling!.index) + // const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId) + // .map((id) => this.getShape(id)!) + // .filter((sibling) => sibling.index > highestSibling!.index) - if (higherSiblings.length) { - // there are siblings above the highest bound sibling, we need to - // insert between them. + // if (higherSiblings.length) { + // // there are siblings above the highest bound sibling, we need to + // // insert between them. - // if the next sibling is also a bound arrow though, we can end up - // all fighting for the same indexes. so lets find the next - // non-arrow sibling... - const nextHighestNonArrowSibling = higherSiblings.find( - (sibling) => sibling.type !== 'arrow' - ) + // // if the next sibling is also a bound arrow though, we can end up + // // all fighting for the same indexes. so lets find the next + // // non-arrow sibling... + // const nextHighestNonArrowSibling = higherSiblings.find( + // (sibling) => sibling.type !== 'arrow' + // ) - if ( - // ...then, if we're above the last shape we want to be above... - reparentedArrow.index > highestSibling.index && - // ...but below the next non-arrow sibling... - (!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index) - ) { - // ...then we're already in the right place. no need to update! - return - } + // if ( + // // ...then, if we're above the last shape we want to be above... + // reparentedArrow.index > highestSibling.index && + // // ...but below the next non-arrow sibling... + // (!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index) + // ) { + // // ...then we're already in the right place. no need to update! + // return + // } - // otherwise, we need to find the index between the highest sibling - // we want to be above, and the next highest sibling we want to be - // below: - finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index) - } else { - // if there are no siblings above us, we can just get the next index: - finalIndex = getIndexAbove(highestSibling.index) - } + // // otherwise, we need to find the index between the highest sibling + // // we want to be above, and the next highest sibling we want to be + // // below: + // finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index) + // } else { + // // if there are no siblings above us, we can just get the next index: + // finalIndex = getIndexAbove(highestSibling.index) + // } - if (finalIndex !== reparentedArrow.index) { - this.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) - } - } + // if (finalIndex !== reparentedArrow.index) { + // this.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) + // } + // } - const unbindArrowTerminal = (arrow: TLArrowShape, handleId: 'start' | 'end') => { - const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId] - this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }]) - } + // const unbindArrowTerminal = (arrow: TLArrowShape, handleId: 'start' | 'end') => { + // const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId] + // this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }]) + // } - const arrowDidUpdate = (arrow: TLArrowShape) => { - // if the shape is an arrow and its bound shape is on another page - // or was deleted, unbind it - for (const handle of ['start', 'end'] as const) { - const terminal = arrow.props[handle] - if (terminal.type !== 'binding') continue - const boundShape = this.getShape(terminal.boundShapeId) - const isShapeInSamePageAsArrow = - this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape) - if (!boundShape || !isShapeInSamePageAsArrow) { - unbindArrowTerminal(arrow, handle) - } - } + // const arrowDidUpdate = (arrow: TLArrowShape) => { + // // if the shape is an arrow and its bound shape is on another page + // // or was deleted, unbind it + // for (const handle of ['start', 'end'] as const) { + // const terminal = arrow.props[handle] + // if (terminal.type !== 'binding') continue + // const boundShape = this.getShape(terminal.boundShapeId) + // const isShapeInSamePageAsArrow = + // this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape) + // if (!boundShape || !isShapeInSamePageAsArrow) { + // unbindArrowTerminal(arrow, handle) + // } + // } - // always check the arrow parents - reparentArrow(arrow.id) - } + // // always check the arrow parents + // reparentArrow(arrow.id) + // } const cleanupInstancePageState = ( prevPageState: TLInstancePageState, @@ -450,28 +470,28 @@ export class Editor extends EventEmitter { this.sideEffects.register({ shape: { afterCreate: (record) => { - if (this.isShapeOfType(record, 'arrow')) { - arrowDidUpdate(record) - } + // if (this.isShapeOfType(record, 'arrow')) { + // arrowDidUpdate(record) + // } }, afterChange: (prev, next) => { - if (this.isShapeOfType(next, 'arrow')) { - arrowDidUpdate(next) - } + // if (this.isShapeOfType(next, 'arrow')) { + // arrowDidUpdate(next) + // } // if the shape's parent changed and it is bound to an arrow, update the arrow's parent - if (prev.parentId !== next.parentId) { - const reparentBoundArrows = (id: TLShapeId) => { - const boundArrows = this._getArrowBindingsIndex().get()[id] - if (boundArrows?.length) { - for (const arrow of boundArrows) { - reparentArrow(arrow.arrowId) - } - } - } - reparentBoundArrows(next.id) - this.visitDescendants(next.id, reparentBoundArrows) - } + // if (prev.parentId !== next.parentId) { + // const reparentBoundArrows = (id: TLShapeId) => { + // const boundArrows = this._getArrowBindingsIndex().get()[id] + // if (boundArrows?.length) { + // for (const arrow of boundArrows) { + // reparentArrow(arrow.arrowId) + // } + // } + // } + // reparentBoundArrows(next.id) + // this.visitDescendants(next.id, reparentBoundArrows) + // } // if this shape moved to a new page, clean up any previous page's instance state if (prev.parentId !== next.parentId && isPageId(next.parentId)) { @@ -504,14 +524,14 @@ export class Editor extends EventEmitter { invalidParents.add(record.parentId) } // clean up any arrows bound to this shape - const bindings = this._getArrowBindingsIndex().get()[record.id] - if (bindings?.length) { - for (const { arrowId, handleId } of bindings) { - const arrow = this.getShape(arrowId) - if (!arrow) continue - unbindArrowTerminal(arrow, handleId) - } - } + // const bindings = this._getArrowBindingsIndex().get()[record.id] + // if (bindings?.length) { + // for (const { arrowId, handleId } of bindings) { + // const arrow = this.getShape(arrowId) + // if (!arrow) continue + // unbindArrowTerminal(arrow, handleId) + // } + // } const deletedIds = new Set([record.id]) const updates = compact( this.getPageStates().map((pageState) => { @@ -792,6 +812,41 @@ export class Editor extends EventEmitter { return shapeUtil } + /* ------------------- Binding Utils ------------------ */ + /** + * A map of shape utility classes (TLShapeUtils) by shape type. + * + * @public + */ + bindingUtils: { readonly [K in string]?: BindingUtil } + + /** + * Get a binding util from a binding itself. + * + * @example + * ```ts + * const util = editor.getBindingUtil(myArrowBinding) + * const util = editor.getBindingUtil('arrow') + * const util = editor.getBindingUtil(myArrowBinding) + * const util = editor.getBindingUtil(TLArrowBinding)('arrow') + * ``` + * + * @param binding - A binding, binding partial, or binding type. + * + * @public + */ + getBindingUtil(binding: S | TLBindingPartial): BindingUtil + getBindingUtil(type: S['type']): BindingUtil + getBindingUtil( + type: T extends BindingUtil ? R['type'] : string + ): T + getBindingUtil(arg: string | { type: string }) { + const type = typeof arg === 'string' ? arg : arg.type + const bindingUtil = getOwnProperty(this.bindingUtils, type) + assert(bindingUtil, `No binding util found for type "${type}"`) + return bindingUtil + } + /* --------------------- History -------------------- */ /** @@ -913,12 +968,6 @@ export class Editor extends EventEmitter { /* --------------------- Arrows --------------------- */ // todo: move these to tldraw or replace with a bindings API - /** @internal */ - @computed - private _getArrowBindingsIndex() { - return arrowBindingsIndex(this) - } - /** * Get all arrows bound to a shape. * @@ -927,15 +976,19 @@ export class Editor extends EventEmitter { * @public */ getArrowsBoundTo(shapeId: TLShapeId) { - return this._getArrowBindingsIndex().get()[shapeId] || EMPTY_ARRAY + const ids = new Set( + this.getBindingsToShape(shapeId, 'arrow').map((b) => b.fromId) + ) + return compact(Array.from(ids, (id) => this.getShape(id))) } @computed private getArrowInfoCache() { return this.store.createComputedCache('arrow infoCache', (shape) => { + const bindings = getArrowBindings(this, shape) return getIsArrowStraight(shape) - ? getStraightArrowInfo(this, shape) - : getCurvedArrowInfo(this, shape) + ? getStraightArrowInfo(this, shape, bindings) + : getCurvedArrowInfo(this, shape, bindings) }) } @@ -4836,6 +4889,99 @@ export class Editor extends EventEmitter { return match } + /* -------------------- Bindings -------------------- */ + + getBinding(id: TLBindingId): TLBinding | undefined { + return this.store.get(id) as TLBinding | undefined + } + + // TODO: maintain these indexes more pro-actively + getBindingsFromShape( + shape: TLShape | TLShapeId, + type: Binding['type'] + ): Binding[] { + const id = typeof shape === 'string' ? shape : shape.id + return this.store.query.exec('binding', { + fromId: { eq: id }, + type: { eq: type }, + }) as Binding[] + } + getBindingsToShape( + shape: TLShape | TLShapeId, + type: Binding['type'] + ): Binding[] { + const id = typeof shape === 'string' ? shape : shape.id + return this.store.query.exec('binding', { + toId: { eq: id }, + type: { eq: type }, + }) as Binding[] + } + getAllBindingsForShape(shape: TLShape | TLShapeId): TLBinding[] { + const id = typeof shape === 'string' ? shape : shape.id + const from = this.store.query.index('binding', 'fromId').get().get(id) ?? new Set() + const to = this.store.query.index('binding', 'toId').get().get(id) + const shapes = [] + for (const id of from) { + shapes.push(this.store.get(id) as TLBinding) + } + if (to) { + for (const id of to) { + if (from.has(id)) continue + shapes.push(this.store.get(id) as TLBinding) + } + } + return shapes + } + + createBindings(partials: RequiredKeys[]) { + const bindings = partials.map((partial) => { + const util = this.getBindingUtil(partial.type) + const defaultProps = util.getDefaultProps() + return this.store.schema.types.binding.create({ + ...partial, + id: partial.id ?? createBindingId(), + props: { + ...defaultProps, + ...partial.props, + }, + }) + }) + this.store.put(bindings) + } + createBinding(partial: RequiredKeys) { + return this.createBindings([partial]) + } + + updateBindings(partials: (TLBindingPartial | null | undefined)[]) { + const updated: TLBinding[] = [] + + for (const partial of partials) { + if (!partial) continue + + const current = this.getBinding(partial.id) + if (!current) continue + + const updatedBinding = applyPartialToBinding(current, partial) + if (updatedBinding === current) continue + + updated.push(updatedBinding) + } + + this.store.put(updated) + } + + updateBinding(partial: TLBindingPartial) { + return this.updateBindings([partial]) + } + + deleteBindings(bindings: (TLBinding | TLBindingId)[]) { + const ids = bindings.map((binding) => (typeof binding === 'string' ? binding : binding.id)) + this.store.remove(ids) + } + deleteBinding(binding: TLBinding | TLBindingId) { + return this.deleteBindings([binding]) + } + /* -------------------- Commands -------------------- */ /** @@ -4999,80 +5145,80 @@ export class Editor extends EventEmitter { let newShape: TLShape = structuredClone(shape) - if ( - this.isShapeOfType(shape, 'arrow') && - this.isShapeOfType(newShape, 'arrow') - ) { - const info = this.getArrowInfo(shape) - let newStartShapeId: TLShapeId | undefined = undefined - let newEndShapeId: TLShapeId | undefined = undefined + // if ( + // this.isShapeOfType(shape, 'arrow') && + // this.isShapeOfType(newShape, 'arrow') + // ) { + // const info = this.getArrowInfo(shape) + // let newStartShapeId: TLShapeId | undefined = undefined + // let newEndShapeId: TLShapeId | undefined = undefined - if (shape.props.start.type === 'binding') { - newStartShapeId = idsMap.get(shape.props.start.boundShapeId) + // if (shape.props.start.type === 'binding') { + // newStartShapeId = idsMap.get(shape.props.start.boundShapeId) - if (!newStartShapeId) { - if (info?.isValid) { - const { x, y } = info.start.point - newShape.props.start = { - type: 'point', - x, - y, - } - } else { - const { start } = getArrowTerminalsInArrowSpace(this, shape) - newShape.props.start = { - type: 'point', - x: start.x, - y: start.y, - } - } - } - } + // if (!newStartShapeId) { + // if (info?.isValid) { + // const { x, y } = info.start.point + // newShape.props.start = { + // type: 'point', + // x, + // y, + // } + // } else { + // const { start } = getArrowTerminalsInArrowSpace(this, shape) + // newShape.props.start = { + // type: 'point', + // x: start.x, + // y: start.y, + // } + // } + // } + // } - if (shape.props.end.type === 'binding') { - newEndShapeId = idsMap.get(shape.props.end.boundShapeId) - if (!newEndShapeId) { - if (info?.isValid) { - const { x, y } = info.end.point - newShape.props.end = { - type: 'point', - x, - y, - } - } else { - const { end } = getArrowTerminalsInArrowSpace(this, shape) - newShape.props.start = { - type: 'point', - x: end.x, - y: end.y, - } - } - } - } + // if (shape.props.end.type === 'binding') { + // newEndShapeId = idsMap.get(shape.props.end.boundShapeId) + // if (!newEndShapeId) { + // if (info?.isValid) { + // const { x, y } = info.end.point + // newShape.props.end = { + // type: 'point', + // x, + // y, + // } + // } else { + // const { end } = getArrowTerminalsInArrowSpace(this, shape) + // newShape.props.start = { + // type: 'point', + // x: end.x, + // y: end.y, + // } + // } + // } + // } - const infoAfter = getIsArrowStraight(newShape) - ? getStraightArrowInfo(this, newShape) - : getCurvedArrowInfo(this, newShape) + // const infoAfter = getIsArrowStraight(newShape) + // ? getStraightArrowInfo(this, newShape) + // : getCurvedArrowInfo(this, newShape) - if (info?.isValid && infoAfter?.isValid && !getIsArrowStraight(shape)) { - const mpA = Vec.Med(info.start.handle, info.end.handle) - const distA = Vec.Dist(info.middle, mpA) - const distB = Vec.Dist(infoAfter.middle, mpA) - if (newShape.props.bend < 0) { - newShape.props.bend += distB - distA - } else { - newShape.props.bend -= distB - distA - } - } + // if (info?.isValid && infoAfter?.isValid && !getIsArrowStraight(shape)) { + // const mpA = Vec.Med(info.start.handle, info.end.handle) + // const distA = Vec.Dist(info.middle, mpA) + // const distB = Vec.Dist(infoAfter.middle, mpA) + // if (newShape.props.bend < 0) { + // newShape.props.bend += distB - distA + // } else { + // newShape.props.bend -= distB - distA + // } + // } - if (newShape.props.start.type === 'binding' && newStartShapeId) { - newShape.props.start.boundShapeId = newStartShapeId - } + // if (newShape.props.start.type === 'binding' && newStartShapeId) { + // newShape.props.start.boundShapeId = newStartShapeId + // } - if (newShape.props.end.type === 'binding' && newEndShapeId) { - newShape.props.end.boundShapeId = newEndShapeId - } - } + // if (newShape.props.end.type === 'binding' && newEndShapeId) { + // newShape.props.end.boundShapeId = newEndShapeId + // } + // } newShape = { ...newShape, id: createId, x: shape.x + ox, y: shape.y + oy, index } @@ -5430,11 +5576,11 @@ export class Editor extends EventEmitter { .filter((shape) => { if (!shape) return false - if (this.isShapeOfType(shape, 'arrow')) { - if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { - return false - } - } + // if (this.isShapeOfType(shape, 'arrow')) { + // if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { + // return false + // } + // } return true }) @@ -5576,11 +5722,11 @@ export class Editor extends EventEmitter { .filter((shape) => { if (!shape) return false - if (this.isShapeOfType(shape, 'arrow')) { - if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { - return false - } - } + // if (this.isShapeOfType(shape, 'arrow')) { + // if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { + // return false + // } + // } return true }) @@ -7270,75 +7416,75 @@ export class Editor extends EventEmitter { shape = structuredClone(shape) as typeof shape - if (this.isShapeOfType(shape, 'arrow')) { - const startBindingId = - shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : undefined + // if (this.isShapeOfType(shape, 'arrow')) { + // const startBindingId = + // shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : undefined - const endBindingId = - shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : undefined + // const endBindingId = + // shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : undefined - const info = this.getArrowInfo(shape) + // const info = this.getArrowInfo(shape) - if (shape.props.start.type === 'binding') { - if (!shapesForContent.some((s) => s.id === startBindingId)) { - // Uh oh, the arrow's bound-to shape isn't among the shapes - // that we're getting the content for. We should try to adjust - // the arrow so that it appears in the place it would be - if (info?.isValid) { - const { x, y } = info.start.point - shape.props.start = { - type: 'point', - x, - y, - } - } else { - const { start } = getArrowTerminalsInArrowSpace(this, shape) - shape.props.start = { - type: 'point', - x: start.x, - y: start.y, - } - } - } - } + // if (shape.props.start.type === 'binding') { + // if (!shapesForContent.some((s) => s.id === startBindingId)) { + // // Uh oh, the arrow's bound-to shape isn't among the shapes + // // that we're getting the content for. We should try to adjust + // // the arrow so that it appears in the place it would be + // if (info?.isValid) { + // const { x, y } = info.start.point + // shape.props.start = { + // type: 'point', + // x, + // y, + // } + // } else { + // const { start } = getArrowTerminalsInArrowSpace(this, shape) + // shape.props.start = { + // type: 'point', + // x: start.x, + // y: start.y, + // } + // } + // } + // } - if (shape.props.end.type === 'binding') { - if (!shapesForContent.some((s) => s.id === endBindingId)) { - if (info?.isValid) { - const { x, y } = info.end.point - shape.props.end = { - type: 'point', - x, - y, - } - } else { - const { end } = getArrowTerminalsInArrowSpace(this, shape) - shape.props.end = { - type: 'point', - x: end.x, - y: end.y, - } - } - } - } + // if (shape.props.end.type === 'binding') { + // if (!shapesForContent.some((s) => s.id === endBindingId)) { + // if (info?.isValid) { + // const { x, y } = info.end.point + // shape.props.end = { + // type: 'point', + // x, + // y, + // } + // } else { + // const { end } = getArrowTerminalsInArrowSpace(this, shape) + // shape.props.end = { + // type: 'point', + // x: end.x, + // y: end.y, + // } + // } + // } + // } - const infoAfter = getIsArrowStraight(shape) - ? getStraightArrowInfo(this, shape) - : getCurvedArrowInfo(this, shape) + // const infoAfter = getIsArrowStraight(shape) + // ? getStraightArrowInfo(this, shape) + // : getCurvedArrowInfo(this, shape) - if (info?.isValid && infoAfter?.isValid && !getIsArrowStraight(shape)) { - const mpA = Vec.Med(info.start.handle, info.end.handle) - const distA = Vec.Dist(info.middle, mpA) - const distB = Vec.Dist(infoAfter.middle, mpA) - if (shape.props.bend < 0) { - shape.props.bend += distB - distA - } else { - shape.props.bend -= distB - distA - } - } + // if (info?.isValid && infoAfter?.isValid && !getIsArrowStraight(shape)) { + // const mpA = Vec.Med(info.start.handle, info.end.handle) + // const distA = Vec.Dist(info.middle, mpA) + // const distB = Vec.Dist(infoAfter.middle, mpA) + // if (shape.props.bend < 0) { + // shape.props.bend += distB - distA + // } else { + // shape.props.bend -= distB - distA + // } + // } - return shape - } + // return shape + // } return shape }) @@ -7550,24 +7696,24 @@ export class Editor extends EventEmitter { index = getIndexAbove(index) } - if (this.isShapeOfType(newShape, 'arrow')) { - if (newShape.props.start.type === 'binding') { - const mappedId = idMap.get(newShape.props.start.boundShapeId) - newShape.props.start = mappedId - ? { ...newShape.props.start, boundShapeId: mappedId } - : // this shouldn't happen, if you copy an arrow but not it's bound shape it should - // convert the binding to a point at the time of copying - { type: 'point', x: 0, y: 0 } - } - if (newShape.props.end.type === 'binding') { - const mappedId = idMap.get(newShape.props.end.boundShapeId) - newShape.props.end = mappedId - ? { ...newShape.props.end, boundShapeId: mappedId } - : // this shouldn't happen, if you copy an arrow but not it's bound shape it should - // convert the binding to a point at the time of copying - { type: 'point', x: 0, y: 0 } - } - } + // if (this.isShapeOfType(newShape, 'arrow')) { + // if (newShape.props.start.type === 'binding') { + // const mappedId = idMap.get(newShape.props.start.boundShapeId) + // newShape.props.start = mappedId + // ? { ...newShape.props.start, boundShapeId: mappedId } + // : // this shouldn't happen, if you copy an arrow but not it's bound shape it should + // // convert the binding to a point at the time of copying + // { type: 'point', x: 0, y: 0 } + // } + // if (newShape.props.end.type === 'binding') { + // const mappedId = idMap.get(newShape.props.end.boundShapeId) + // newShape.props.end = mappedId + // ? { ...newShape.props.end, boundShapeId: mappedId } + // : // this shouldn't happen, if you copy an arrow but not it's bound shape it should + // // convert the binding to a point at the time of copying + // { type: 'point', x: 0, y: 0 } + // } + // } return newShape }) @@ -8547,6 +8693,41 @@ function applyPartialToShape(prev: T, partial?: TLShapePartia return next } +function applyPartialToBinding(prev: T, partial?: TLBindingPartial): T { + if (!partial) return prev + let next = null as null | T + const entries = Object.entries(partial) + for (let i = 0, n = entries.length; i < n; i++) { + const [k, v] = entries[i] + if (v === undefined) continue + + // Is the key a special key? We don't update those + if (k === 'id' || k === 'type' || k === 'typeName') continue + + // Is the value the same as it was before? + if (v === (prev as any)[k]) continue + + // There's a new value, so create the new shape if we haven't already (should we be cloning this?) + if (!next) next = { ...prev } + + // for props / meta properties, we support updates with partials of this object + if (k === 'props' || k === 'meta') { + next[k] = { ...prev[k] } as JsonObject + for (const [nextKey, nextValue] of Object.entries(v as object)) { + if (nextValue !== undefined) { + ;(next[k] as JsonObject)[nextKey] = nextValue + } + } + continue + } + + // base property + ;(next as any)[k] = v + } + if (!next) return prev + return next +} + function pushShapeWithDescendants(editor: Editor, id: TLShapeId, result: TLShape[]): void { const shape = editor.getShape(id) if (!shape) return diff --git a/packages/editor/src/lib/editor/bindings/BindingUtil.ts b/packages/editor/src/lib/editor/bindings/BindingUtil.ts new file mode 100644 index 000000000..563a9f4f8 --- /dev/null +++ b/packages/editor/src/lib/editor/bindings/BindingUtil.ts @@ -0,0 +1,43 @@ +import { RecordProps, TLPropsMigrations, TLShape, TLUnknownBinding } from '@tldraw/tlschema' +import { Editor } from '../Editor' + +/** @public */ +export interface TLBindingUtilConstructor< + T extends TLUnknownBinding, + U extends BindingUtil = BindingUtil, +> { + new (editor: Editor): U + type: T['type'] + props?: RecordProps + migrations?: TLPropsMigrations +} + +/** @public */ +export abstract class BindingUtil { + constructor(public editor: Editor) {} + static props?: RecordProps + static migrations?: TLPropsMigrations + + /** + * The type of the binding util, which should match the binding's type. + * + * @public + */ + static type: string + + /** + * Get the default props for a binding. + * + * @public + */ + abstract getDefaultProps(): Binding['props'] + + onAfterShapeChange?( + binding: Binding, + direction: 'from' | 'to', + prev: TLShape, + next: TLShape + ): void + + onBeforeShapeDelete?(binding: Binding, direction: 'from' | 'to', shape: TLShape): void +} diff --git a/packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts b/packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts deleted file mode 100644 index 553c4a88e..000000000 --- a/packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts +++ /dev/null @@ -1,141 +0,0 @@ -import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state' -import { TLArrowShape, TLShape, TLShapeId } from '@tldraw/tlschema' -import { Editor } from '../Editor' - -type TLArrowBindingsIndex = Record< - TLShapeId, - undefined | { arrowId: TLShapeId; handleId: 'start' | 'end' }[] -> - -export const arrowBindingsIndex = (editor: Editor): Computed => { - const { store } = editor - const shapeHistory = store.query.filterHistory('shape') - const arrowQuery = store.query.records('shape', () => ({ type: { eq: 'arrow' as const } })) - function fromScratch() { - const allArrows = arrowQuery.get() as TLArrowShape[] - - const bindings2Arrows: TLArrowBindingsIndex = {} - - for (const arrow of allArrows) { - const { start, end } = arrow.props - if (start.type === 'binding') { - const arrows = bindings2Arrows[start.boundShapeId] - if (arrows) arrows.push({ arrowId: arrow.id, handleId: 'start' }) - else bindings2Arrows[start.boundShapeId] = [{ arrowId: arrow.id, handleId: 'start' }] - } - - if (end.type === 'binding') { - const arrows = bindings2Arrows[end.boundShapeId] - if (arrows) arrows.push({ arrowId: arrow.id, handleId: 'end' }) - else bindings2Arrows[end.boundShapeId] = [{ arrowId: arrow.id, handleId: 'end' }] - } - } - - return bindings2Arrows - } - - return computed('arrowBindingsIndex', (_lastValue, lastComputedEpoch) => { - if (isUninitialized(_lastValue)) { - return fromScratch() - } - - const lastValue = _lastValue - - const diff = shapeHistory.getDiffSince(lastComputedEpoch) - - if (diff === RESET_VALUE) { - return fromScratch() - } - - let nextValue: TLArrowBindingsIndex | undefined = undefined - - function ensureNewArray(boundShapeId: TLShapeId) { - // this will never happen - if (!nextValue) { - nextValue = { ...lastValue } - } - if (!nextValue[boundShapeId]) { - nextValue[boundShapeId] = [] - } else if (nextValue[boundShapeId] === lastValue[boundShapeId]) { - nextValue[boundShapeId] = [...nextValue[boundShapeId]!] - } - } - - function removingBinding( - boundShapeId: TLShapeId, - arrowId: TLShapeId, - handleId: 'start' | 'end' - ) { - ensureNewArray(boundShapeId) - nextValue![boundShapeId] = nextValue![boundShapeId]!.filter( - (binding) => binding.arrowId !== arrowId || binding.handleId !== handleId - ) - if (nextValue![boundShapeId]!.length === 0) { - delete nextValue![boundShapeId] - } - } - - function addBinding(boundShapeId: TLShapeId, arrowId: TLShapeId, handleId: 'start' | 'end') { - ensureNewArray(boundShapeId) - nextValue![boundShapeId]!.push({ arrowId, handleId }) - } - - for (const changes of diff) { - for (const newShape of Object.values(changes.added)) { - if (editor.isShapeOfType(newShape, 'arrow')) { - const { start, end } = newShape.props - if (start.type === 'binding') { - addBinding(start.boundShapeId, newShape.id, 'start') - } - if (end.type === 'binding') { - addBinding(end.boundShapeId, newShape.id, 'end') - } - } - } - - for (const [prev, next] of Object.values(changes.updated) as [TLShape, TLShape][]) { - if ( - !editor.isShapeOfType(prev, 'arrow') || - !editor.isShapeOfType(next, 'arrow') - ) - continue - - for (const handle of ['start', 'end'] as const) { - const prevTerminal = prev.props[handle] - const nextTerminal = next.props[handle] - - if (prevTerminal.type === 'binding' && nextTerminal.type === 'point') { - // if the binding was removed - removingBinding(prevTerminal.boundShapeId, prev.id, handle) - } else if (prevTerminal.type === 'point' && nextTerminal.type === 'binding') { - // if the binding was added - addBinding(nextTerminal.boundShapeId, next.id, handle) - } else if ( - prevTerminal.type === 'binding' && - nextTerminal.type === 'binding' && - prevTerminal.boundShapeId !== nextTerminal.boundShapeId - ) { - // if the binding was changed - removingBinding(prevTerminal.boundShapeId, prev.id, handle) - addBinding(nextTerminal.boundShapeId, next.id, handle) - } - } - } - - for (const prev of Object.values(changes.removed)) { - if (editor.isShapeOfType(prev, 'arrow')) { - const { start, end } = prev.props - if (start.type === 'binding') { - removingBinding(start.boundShapeId, prev.id, 'start') - } - if (end.type === 'binding') { - removingBinding(end.boundShapeId, prev.id, 'end') - } - } - } - } - - // TODO: add diff entries if we need them - return nextValue ?? lastValue - }) -} diff --git a/packages/editor/src/lib/editor/shapes/ShapeUtil.ts b/packages/editor/src/lib/editor/shapes/ShapeUtil.ts index e596d90f5..d0f175feb 100644 --- a/packages/editor/src/lib/editor/shapes/ShapeUtil.ts +++ b/packages/editor/src/lib/editor/shapes/ShapeUtil.ts @@ -1,11 +1,11 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ import { LegacyMigrations, MigrationSequence } from '@tldraw/store' import { - ShapeProps, + RecordProps, TLHandle, + TLPropsMigrations, TLShape, TLShapePartial, - TLShapePropsMigrations, TLUnknownShape, } from '@tldraw/tlschema' import { ReactElement } from 'react' @@ -25,8 +25,8 @@ export interface TLShapeUtilConstructor< > { new (editor: Editor): U type: T['type'] - props?: ShapeProps - migrations?: LegacyMigrations | TLShapePropsMigrations | MigrationSequence + props?: RecordProps + migrations?: LegacyMigrations | TLPropsMigrations | MigrationSequence } /** @public */ @@ -41,8 +41,8 @@ export interface TLShapeUtilCanvasSvgDef { /** @public */ export abstract class ShapeUtil { constructor(public editor: Editor) {} - static props?: ShapeProps - static migrations?: LegacyMigrations | TLShapePropsMigrations + static props?: RecordProps + static migrations?: LegacyMigrations | TLPropsMigrations | MigrationSequence /** * The type of the shape util, which should match the shape's type. diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts index dbcba62a8..e0412ed69 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts @@ -1,5 +1,6 @@ import { TLArrowShapeArrowheadStyle } from '@tldraw/tlschema' import { VecLike } from '../../../../primitives/Vec' +import { TLArrowBindings } from './shared' /** @public */ export type TLArrowPoint = { @@ -21,6 +22,7 @@ export interface TLArcInfo { /** @public */ export type TLArrowInfo = | { + bindings: TLArrowBindings isStraight: false start: TLArrowPoint end: TLArrowPoint @@ -30,6 +32,7 @@ export type TLArrowInfo = isValid: boolean } | { + bindings: TLArrowBindings isStraight: true start: TLArrowPoint end: TLArrowPoint diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts index a8a755df5..1bff1f53f 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts @@ -15,6 +15,7 @@ import { BOUND_ARROW_OFFSET, MIN_ARROW_LENGTH, STROKE_SIZES, + TLArrowBindings, WAY_TOO_BIG_ARROW_BEND_FACTOR, getArrowTerminalsInArrowSpace, getBoundShapeInfoForTerminal, @@ -25,16 +26,16 @@ import { getStraightArrowInfo } from './straight-arrow' export function getCurvedArrowInfo( editor: Editor, shape: TLArrowShape, - extraBend = 0 + bindings: TLArrowBindings ): TLArrowInfo { const { arrowheadEnd, arrowheadStart } = shape.props - const bend = shape.props.bend + extraBend + const bend = shape.props.bend if (Math.abs(bend) > Math.abs(shape.props.bend * WAY_TOO_BIG_ARROW_BEND_FACTOR)) { - return getStraightArrowInfo(editor, shape) + return getStraightArrowInfo(editor, shape, bindings) } - const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape) + const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape, bindings) const med = Vec.Med(terminalsInArrowSpace.start, terminalsInArrowSpace.end) // point between start and end const distance = Vec.Sub(terminalsInArrowSpace.end, terminalsInArrowSpace.start) @@ -42,8 +43,8 @@ export function getCurvedArrowInfo( const u = Vec.Len(distance) ? distance.uni() : Vec.From(distance) // unit vector between start and end const middle = Vec.Add(med, u.per().mul(-bend)) // middle handle - const startShapeInfo = getBoundShapeInfoForTerminal(editor, shape.props.start) - const endShapeInfo = getBoundShapeInfoForTerminal(editor, shape.props.end) + const startShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'start') + const endShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'end') // The positions of the body of the arrow, which may be different // than the arrow's start / end points if the arrow is bound to shapes @@ -53,6 +54,7 @@ export function getCurvedArrowInfo( if (Vec.Equals(a, b)) { return { + bindings, isStraight: true, start: { handle: a, @@ -84,7 +86,7 @@ export function getCurvedArrowInfo( !isSafeFloat(handleArc.length) || !isSafeFloat(handleArc.size) ) { - return getStraightArrowInfo(editor, shape) + return getStraightArrowInfo(editor, shape, bindings) } const tempA = a.clone() @@ -341,6 +343,7 @@ export function getCurvedArrowInfo( const bodyArc = getArcInfo(a, b, c) return { + bindings, isStraight: false, start: { point: a, diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts index a42559ddb..19e923b14 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts @@ -1,4 +1,10 @@ -import { TLArrowShape, TLArrowShapeTerminal, TLShape, TLShapeId } from '@tldraw/tlschema' +import { + TLArrowBinding, + TLArrowBindingProps, + TLArrowShape, + TLShape, + TLShapeId, +} from '@tldraw/tlschema' import { Mat } from '../../../../primitives/Mat' import { Vec } from '../../../../primitives/Vec' import { Group2d } from '../../../../primitives/geometry/Group2d' @@ -19,15 +25,17 @@ export type BoundShapeInfo = { export function getBoundShapeInfoForTerminal( editor: Editor, - terminal: TLArrowShapeTerminal + arrow: TLArrowShape, + terminalName: 'start' | 'end' ): BoundShapeInfo | undefined { - if (terminal.type === 'point') { - return - } + const binding = editor + .getBindingsFromShape(arrow, 'arrow') + .find((b) => b.props.terminal === terminalName) + if (!binding) return - const shape = editor.getShape(terminal.boundShapeId)! - const transform = editor.getShapePageTransform(shape)! - const geometry = editor.getShapeGeometry(shape) + const boundShape = editor.getShape(binding.toId)! + const transform = editor.getShapePageTransform(boundShape)! + const geometry = editor.getShapeGeometry(boundShape) // This is hacky: we're only looking at the first child in the group. Really the arrow should // consider all items in the group which are marked as snappable as separate polygons with which @@ -36,10 +44,10 @@ export function getBoundShapeInfoForTerminal( const outline = geometry instanceof Group2d ? geometry.children[0].vertices : geometry.vertices return { - shape, + shape: boundShape, transform, isClosed: geometry.isClosed, - isExact: terminal.isExact, + isExact: binding.props.isExact, didIntersect: false, outline, } @@ -48,14 +56,10 @@ export function getBoundShapeInfoForTerminal( function getArrowTerminalInArrowSpace( editor: Editor, arrowPageTransform: Mat, - terminal: TLArrowShapeTerminal, + binding: TLArrowBinding, forceImprecise: boolean ) { - if (terminal.type === 'point') { - return Vec.From(terminal) - } - - const boundShape = editor.getShape(terminal.boundShapeId) + const boundShape = editor.getShape(binding.toId) if (!boundShape) { // this can happen in multiplayer contexts where the shape is being deleted @@ -69,7 +73,9 @@ function getArrowTerminalInArrowSpace( point, Vec.MulV( // if the parent is the bound shape, then it's ALWAYS precise - terminal.isPrecise || forceImprecise ? terminal.normalizedAnchor : { x: 0.5, y: 0.5 }, + binding.props.isPrecise || forceImprecise + ? binding.props.normalizedAnchor + : { x: 0.5, y: 0.5 }, size ) ) @@ -79,41 +85,113 @@ function getArrowTerminalInArrowSpace( } } -/** @public */ -export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape) { - const arrowPageTransform = editor.getShapePageTransform(shape)! +export interface TLArrowBindings { + start: TLArrowBinding | undefined + end: TLArrowBinding | undefined +} - let startBoundShapeId: TLShapeId | undefined - let endBoundShapeId: TLShapeId | undefined - - if (shape.props.start.type === 'binding' && shape.props.end.type === 'binding') { - startBoundShapeId = shape.props.start.boundShapeId - endBoundShapeId = shape.props.end.boundShapeId +/** @internal */ +export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBindings { + const bindings = editor.getBindingsFromShape(shape, 'arrow') + return { + start: bindings.find((b) => b.props.terminal === 'start'), + end: bindings.find((b) => b.props.terminal === 'end'), } +} + +/** @public */ +export function getArrowTerminalsInArrowSpace( + editor: Editor, + shape: TLArrowShape, + bindings: TLArrowBindings +) { + const arrowPageTransform = editor.getShapePageTransform(shape)! const boundShapeRelationships = getBoundShapeRelationships( editor, - startBoundShapeId, - endBoundShapeId + bindings.start?.toId, + bindings.end?.toId ) - const start = getArrowTerminalInArrowSpace( - editor, - arrowPageTransform, - shape.props.start, - boundShapeRelationships === 'double-bound' || boundShapeRelationships === 'start-contains-end' - ) + const start = bindings.start + ? getArrowTerminalInArrowSpace( + editor, + arrowPageTransform, + bindings.start, + boundShapeRelationships === 'double-bound' || + boundShapeRelationships === 'start-contains-end' + ) + : Vec.From(shape.props.start) - const end = getArrowTerminalInArrowSpace( - editor, - arrowPageTransform, - shape.props.end, - boundShapeRelationships === 'double-bound' || boundShapeRelationships === 'end-contains-start' - ) + const end = bindings.end + ? getArrowTerminalInArrowSpace( + editor, + arrowPageTransform, + bindings.end, + boundShapeRelationships === 'double-bound' || + boundShapeRelationships === 'end-contains-start' + ) + : Vec.From(shape.props.end) return { start, end } } +/** + * Create or update the arrow binding for a particular arrow terminal. Will clear up if needed. + * TODO(alex): find a better name for this + * @internal + */ +export function arrowBindingMakeItSo( + editor: Editor, + arrow: TLArrowShape | TLShapeId, + target: TLShape | TLShapeId, + props: TLArrowBindingProps +) { + const arrowId = typeof arrow === 'string' ? arrow : arrow.id + const targetId = typeof target === 'string' ? target : target.id + + const existingMany = editor + .getBindingsFromShape(arrowId, 'arrow') + .filter((b) => b.props.terminal === props.terminal) + + // if we've somehow ended up with too many bindings, delete the extras + if (existingMany.length > 1) { + editor.deleteBindings(existingMany.slice(1)) + } + + const existing = existingMany[0] + if (existing) { + editor.updateBinding({ + ...existing, + toId: targetId, + props, + }) + } else { + editor.createBinding({ + type: 'arrow', + fromId: arrowId, + toId: targetId, + props, + }) + } +} + +/** + * Remove any arrow bindings for a particular terminal. + * @internal + */ +export function arrowBindingMakeItNotSo( + editor: Editor, + arrow: TLArrowShape, + terminal: 'start' | 'end' +) { + const existing = editor + .getBindingsFromShape(arrow, 'arrow') + .filter((b) => b.props.terminal === terminal) + + editor.deleteBindings(existing) +} + /** @internal */ export const MIN_ARROW_LENGTH = 10 /** @internal */ diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts index db2373c56..4ecea943f 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts @@ -12,15 +12,20 @@ import { BoundShapeInfo, MIN_ARROW_LENGTH, STROKE_SIZES, + TLArrowBindings, getArrowTerminalsInArrowSpace, getBoundShapeInfoForTerminal, getBoundShapeRelationships, } from './shared' -export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArrowInfo { - const { start, end, arrowheadStart, arrowheadEnd } = shape.props +export function getStraightArrowInfo( + editor: Editor, + shape: TLArrowShape, + bindings: TLArrowBindings +): TLArrowInfo { + const { arrowheadStart, arrowheadEnd } = shape.props - const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape) + const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape, bindings) const a = terminalsInArrowSpace.start.clone() const b = terminalsInArrowSpace.end.clone() @@ -28,6 +33,7 @@ export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArr if (Vec.Equals(a, b)) { return { + bindings, isStraight: true, start: { handle: a, @@ -49,8 +55,8 @@ export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArr // Update the arrowhead points using intersections with the bound shapes, if any. - const startShapeInfo = getBoundShapeInfoForTerminal(editor, start) - const endShapeInfo = getBoundShapeInfoForTerminal(editor, end) + const startShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'start') + const endShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'end') const arrowPageTransform = editor.getShapePageTransform(shape)! @@ -189,6 +195,7 @@ export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArr const length = Vec.Dist(a, b) return { + bindings, isStraight: true, start: { handle: terminalsInArrowSpace.start, diff --git a/packages/editor/src/lib/test/currentToolIdMask.test.ts b/packages/editor/src/lib/test/currentToolIdMask.test.ts index ea60924ae..b76905c46 100644 --- a/packages/editor/src/lib/test/currentToolIdMask.test.ts +++ b/packages/editor/src/lib/test/currentToolIdMask.test.ts @@ -24,6 +24,7 @@ beforeEach(() => { editor = new Editor({ initialState: 'A', shapeUtils: [], + bindingUtils: [], tools: [A, B, C], store: createTLStore({ shapeUtils: [] }), getContainer: () => document.body, diff --git a/packages/editor/src/lib/test/user.test.ts b/packages/editor/src/lib/test/user.test.ts index 256ef3472..09c46d861 100644 --- a/packages/editor/src/lib/test/user.test.ts +++ b/packages/editor/src/lib/test/user.test.ts @@ -6,6 +6,7 @@ let editor: Editor beforeEach(() => { editor = new Editor({ shapeUtils: [], + bindingUtils: [], tools: [], store: createTLStore({ shapeUtils: [] }), getContainer: () => document.body, diff --git a/packages/store/api-report.md b/packages/store/api-report.md index cd2ba6928..e30dc7247 100644 --- a/packages/store/api-report.md +++ b/packages/store/api-report.md @@ -37,7 +37,7 @@ export type ComputedCache = { export function createEmptyRecordsDiff(): RecordsDiff; // @public -export function createMigrationIds>(sequenceId: ID, versions: Versions): { +export function createMigrationIds>(sequenceId: ID, versions: Versions): { [K in keyof Versions]: `${ID}/${Versions[K]}`; }; @@ -265,6 +265,11 @@ export function squashRecordDiffs(diffs: RecordsDiff // @internal export function squashRecordDiffsMutable(target: RecordsDiff, diffs: RecordsDiff[]): void; +// @public (undocumented) +export type StandaloneDependsOn = { + readonly dependsOn: readonly MigrationId[]; +}; + // @public export class Store { constructor(config: { diff --git a/packages/store/src/index.ts b/packages/store/src/index.ts index 2216ec4d7..fe45f3f04 100644 --- a/packages/store/src/index.ts +++ b/packages/store/src/index.ts @@ -43,5 +43,6 @@ export { type MigrationId, type MigrationResult, type MigrationSequence, + type StandaloneDependsOn, } from './lib/migrate' export type { AllRecords } from './lib/type-utils' diff --git a/packages/store/src/lib/StoreSchema.ts b/packages/store/src/lib/StoreSchema.ts index 1f6fcdccd..f04aa3bd9 100644 --- a/packages/store/src/lib/StoreSchema.ts +++ b/packages/store/src/lib/StoreSchema.ts @@ -121,7 +121,12 @@ export class StoreSchema { if (!migration.dependsOn?.length) continue for (const dep of migration.dependsOn) { const depMigration = allMigrations.find((m) => m.id === dep) - assert(depMigration, `Migration '${migration.id}' depends on missing migration '${dep}'`) + // TODO: we can't assert here because the store migrations depend on the arrow + // migrations, but the arrow migrations might not be present if we're using the + // editor without arrows :/ + if (!depMigration) { + console.warn(`Migration '${migration.id}' depends on missing migration '${dep}'`) + } } } } diff --git a/packages/store/src/lib/migrate.ts b/packages/store/src/lib/migrate.ts index 95dc30c7e..7782eb9cf 100644 --- a/packages/store/src/lib/migrate.ts +++ b/packages/store/src/lib/migrate.ts @@ -91,10 +91,10 @@ export function createMigrationSequence({ * @public * @public */ -export function createMigrationIds>( - sequenceId: ID, - versions: Versions -): { [K in keyof Versions]: `${ID}/${Versions[K]}` } { +export function createMigrationIds< + const ID extends string, + const Versions extends Record, +>(sequenceId: ID, versions: Versions): { [K in keyof Versions]: `${ID}/${Versions[K]}` } { return Object.fromEntries( objectMapEntries(versions).map(([key, version]) => [key, `${sequenceId}/${version}`] as const) ) as any @@ -136,6 +136,7 @@ export type LegacyMigration = { /** @public */ export type MigrationId = `${string}/${number}` +/** @public */ export type StandaloneDependsOn = { readonly dependsOn: readonly MigrationId[] } diff --git a/packages/tldraw/api-report.md b/packages/tldraw/api-report.md index 85fbf14d2..9a329cd46 100644 --- a/packages/tldraw/api-report.md +++ b/packages/tldraw/api-report.md @@ -33,7 +33,6 @@ import { MemoExoticComponent } from 'react'; import { MigrationFailureReason } from '@tldraw/editor'; import { MigrationSequence } from '@tldraw/editor'; import { NamedExoticComponent } from 'react'; -import { ObjectValidator } from '@tldraw/editor'; import { Polygon2d } from '@tldraw/editor'; import { Polyline2d } from '@tldraw/editor'; import { default as React_2 } from 'react'; @@ -54,6 +53,7 @@ import { StoreSnapshot } from '@tldraw/editor'; import { StyleProp } from '@tldraw/editor'; import { SvgExportContext } from '@tldraw/editor'; import { T } from '@tldraw/editor'; +import { TLAnyBindingUtilConstructor } from '@tldraw/editor'; import { TLAnyShapeUtilConstructor } from '@tldraw/editor'; import { TLArrowShape } from '@tldraw/editor'; import { TLAssetId } from '@tldraw/editor'; @@ -102,6 +102,7 @@ import { TLParentId } from '@tldraw/editor'; import { TLPointerEvent } from '@tldraw/editor'; import { TLPointerEventInfo } from '@tldraw/editor'; import { TLPointerEventName } from '@tldraw/editor'; +import { TLPropsMigrations } from '@tldraw/editor'; import { TLRecord } from '@tldraw/editor'; import { TLRotationSnapshot } from '@tldraw/editor'; import { TLSchema } from '@tldraw/editor'; @@ -112,7 +113,6 @@ import { TLSelectionHandle } from '@tldraw/editor'; import { TLShape } from '@tldraw/editor'; import { TLShapeId } from '@tldraw/editor'; import { TLShapePartial } from '@tldraw/editor'; -import { TLShapePropsMigrations } from '@tldraw/editor'; import { TLShapeUtilCanvasSvgDef } from '@tldraw/editor'; import { TLShapeUtilFlag } from '@tldraw/editor'; import { TLStore } from '@tldraw/editor'; @@ -121,7 +121,6 @@ import { TLSvgOptions } from '@tldraw/editor'; import { TLTextShape } from '@tldraw/editor'; import { TLUnknownShape } from '@tldraw/editor'; import { TLVideoShape } from '@tldraw/editor'; -import { UnionValidator } from '@tldraw/editor'; import { UnknownRecord } from '@tldraw/editor'; import { Validator } from '@tldraw/editor'; import { Vec } from '@tldraw/editor'; @@ -192,7 +191,7 @@ export class ArrowShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLArrowShape): JSX_2.Element | null; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: MigrationSequence; // (undocumented) onDoubleClickHandle: (shape: TLArrowShape, handle: TLHandle) => TLShapePartial | void; // (undocumented) @@ -212,39 +211,13 @@ export class ArrowShapeUtil extends ShapeUtil { bend: Validator; color: EnumStyleProp<"black" | "blue" | "green" | "grey" | "light-blue" | "light-green" | "light-red" | "light-violet" | "orange" | "red" | "violet" | "white" | "yellow">; dash: EnumStyleProp<"dashed" | "dotted" | "draw" | "solid">; - end: UnionValidator<"type", { - binding: ObjectValidator< { - boundShapeId: TLShapeId; - isExact: boolean; - isPrecise: boolean; - normalizedAnchor: VecModel; - type: "binding"; - }>; - point: ObjectValidator< { - type: "point"; - x: number; - y: number; - }>; - }, never>; + end: Validator; fill: EnumStyleProp<"none" | "pattern" | "semi" | "solid">; font: EnumStyleProp<"draw" | "mono" | "sans" | "serif">; labelColor: EnumStyleProp<"black" | "blue" | "green" | "grey" | "light-blue" | "light-green" | "light-red" | "light-violet" | "orange" | "red" | "violet" | "white" | "yellow">; labelPosition: Validator; size: EnumStyleProp<"l" | "m" | "s" | "xl">; - start: UnionValidator<"type", { - binding: ObjectValidator< { - boundShapeId: TLShapeId; - isExact: boolean; - isPrecise: boolean; - normalizedAnchor: VecModel; - type: "binding"; - }>; - point: ObjectValidator< { - type: "point"; - x: number; - y: number; - }>; - }, never>; + start: Validator; text: Validator; }; // (undocumented) @@ -281,7 +254,7 @@ export class BookmarkShapeUtil extends BaseBoxShapeUtil { // (undocumented) indicator(shape: TLBookmarkShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate?: TLOnBeforeCreateHandler; // (undocumented) @@ -492,7 +465,7 @@ export class DrawShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLDrawShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onResize: TLOnResizeHandler; // (undocumented) @@ -549,7 +522,7 @@ export class EmbedShapeUtil extends BaseBoxShapeUtil { // (undocumented) isAspectRatioLocked: TLShapeUtilFlag; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onResize: TLOnResizeHandler; // (undocumented) @@ -656,7 +629,7 @@ export class FrameShapeUtil extends BaseBoxShapeUtil { // (undocumented) indicator(shape: TLFrameShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onDragShapesOut: (_shape: TLFrameShape, shapes: TLShape[]) => void; // (undocumented) @@ -709,7 +682,7 @@ export class GeoShapeUtil extends BaseBoxShapeUtil { // (undocumented) indicator(shape: TLGeoShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate: (shape: TLGeoShape) => { id: TLShapeId; @@ -923,7 +896,7 @@ export class HighlightShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLHighlightShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onResize: TLOnResizeHandler; // (undocumented) @@ -961,7 +934,7 @@ export class ImageShapeUtil extends BaseBoxShapeUtil { // (undocumented) isAspectRatioLocked: () => boolean; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onDoubleClick: (shape: TLImageShape) => void; // (undocumented) @@ -1065,7 +1038,7 @@ export class LineShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLLineShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onHandleDrag: TLOnHandleDragHandler; // (undocumented) @@ -1129,7 +1102,7 @@ export class NoteShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLNoteShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate: (next: TLNoteShape) => { id: TLShapeId; @@ -1376,7 +1349,7 @@ export class TextShapeUtil extends ShapeUtil { // (undocumented) isAspectRatioLocked: TLShapeUtilFlag; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate: (shape: TLTextShape) => { id: TLShapeId; @@ -1496,6 +1469,7 @@ export function TldrawHandles({ children }: TLHandlesProps): JSX_2.Element | nul // @public export const TldrawImage: NamedExoticComponent< { background?: boolean | undefined; +bindingUtils?: readonly TLAnyBindingUtilConstructor[] | undefined; bounds?: Box | undefined; darkMode?: boolean | undefined; format?: "png" | "svg" | undefined; @@ -1509,6 +1483,7 @@ snapshot: StoreSnapshot; // @public export type TldrawImageProps = Expand<{ + bindingUtils?: readonly TLAnyBindingUtilConstructor[]; shapeUtils?: readonly TLAnyShapeUtilConstructor[]; format?: 'png' | 'svg'; pageId?: TLPageId; @@ -2668,7 +2643,7 @@ export class VideoShapeUtil extends BaseBoxShapeUtil { // (undocumented) isAspectRatioLocked: () => boolean; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) static props: { assetId: Validator; diff --git a/packages/tldraw/src/lib/Tldraw.tsx b/packages/tldraw/src/lib/Tldraw.tsx index 159cf2023..86c2d3f0c 100644 --- a/packages/tldraw/src/lib/Tldraw.tsx +++ b/packages/tldraw/src/lib/Tldraw.tsx @@ -23,6 +23,7 @@ import { TldrawHandles } from './canvas/TldrawHandles' import { TldrawScribble } from './canvas/TldrawScribble' import { TldrawSelectionBackground } from './canvas/TldrawSelectionBackground' import { TldrawSelectionForeground } from './canvas/TldrawSelectionForeground' +import { defaultBindingUtils } from './defaultBindingUtils' import { TLExternalContentProps, registerDefaultExternalContentHandlers, @@ -79,6 +80,7 @@ export function Tldraw(props: TldrawProps) { onMount, components = {}, shapeUtils = [], + bindingUtils = [], tools = [], ...rest } = props @@ -102,6 +104,12 @@ export function Tldraw(props: TldrawProps) { [_shapeUtils] ) + const _bindingUtils = useShallowArrayIdentity(bindingUtils) + const bindingUtilsWithDefaults = useMemo( + () => [...defaultBindingUtils, ..._bindingUtils], + [_bindingUtils] + ) + const _tools = useShallowArrayIdentity(tools) const toolsWithDefaults = useMemo( () => [...defaultTools, ...defaultShapeTools, ..._tools], @@ -123,6 +131,7 @@ export function Tldraw(props: TldrawProps) { {...rest} components={componentsWithDefault} shapeUtils={shapeUtilsWithDefaults} + bindingUtils={bindingUtilsWithDefaults} tools={toolsWithDefaults} > diff --git a/packages/tldraw/src/lib/TldrawImage.tsx b/packages/tldraw/src/lib/TldrawImage.tsx index f3ce3cd95..8c345ee89 100644 --- a/packages/tldraw/src/lib/TldrawImage.tsx +++ b/packages/tldraw/src/lib/TldrawImage.tsx @@ -4,6 +4,7 @@ import { Expand, LoadingScreen, StoreSnapshot, + TLAnyBindingUtilConstructor, TLAnyShapeUtilConstructor, TLPageId, TLRecord, @@ -12,6 +13,7 @@ import { useTLStore, } from '@tldraw/editor' import { memo, useLayoutEffect, useMemo, useState } from 'react' +import { defaultBindingUtils } from './defaultBindingUtils' import { defaultShapeUtils } from './defaultShapeUtils' import { usePreloadAssets } from './ui/hooks/usePreloadAssets' import { getSvgAsImage } from './utils/export/export' @@ -43,6 +45,10 @@ export type TldrawImageProps = Expand< * Additional shape utils to use. */ shapeUtils?: readonly TLAnyShapeUtilConstructor[] + /** + * Additional binding utils to use. + */ + bindingUtils?: readonly TLAnyBindingUtilConstructor[] } & Partial > @@ -69,6 +75,11 @@ export const TldrawImage = memo(function TldrawImage(props: TldrawImageProps) { const shapeUtils = useShallowArrayIdentity(props.shapeUtils ?? []) const shapeUtilsWithDefaults = useMemo(() => [...defaultShapeUtils, ...shapeUtils], [shapeUtils]) + const bindingUtils = useShallowArrayIdentity(props.bindingUtils ?? []) + const bindingUtilsWithDefaults = useMemo( + () => [...defaultBindingUtils, ...bindingUtils], + [bindingUtils] + ) const store = useTLStore({ snapshot: props.snapshot, shapeUtils: shapeUtilsWithDefaults }) const assets = useDefaultEditorAssetsWithOverrides() @@ -98,7 +109,8 @@ export const TldrawImage = memo(function TldrawImage(props: TldrawImageProps) { const editor = new Editor({ store, - shapeUtils: shapeUtilsWithDefaults ?? [], + shapeUtils: shapeUtilsWithDefaults, + bindingUtils: bindingUtilsWithDefaults, tools: [], getContainer: () => tempElm, }) @@ -152,6 +164,7 @@ export const TldrawImage = memo(function TldrawImage(props: TldrawImageProps) { container, store, shapeUtilsWithDefaults, + bindingUtilsWithDefaults, pageId, bounds, scale, diff --git a/packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts b/packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts new file mode 100644 index 000000000..166262838 --- /dev/null +++ b/packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts @@ -0,0 +1,21 @@ +import { + BindingUtil, + TLArrowBindingProps, + arrowBindingMigrations, + arrowBindingProps, +} from '@tldraw/editor' + +export class ArrowBindingUtil extends BindingUtil { + static override type = 'arrow' + + static override props = arrowBindingProps + static override migrations = arrowBindingMigrations + + override getDefaultProps(): Partial { + return { + isPrecise: false, + isExact: false, + normalizedAnchor: { x: 0.5, y: 0.5 }, + } + } +} diff --git a/packages/tldraw/src/lib/defaultBindingUtils.ts b/packages/tldraw/src/lib/defaultBindingUtils.ts new file mode 100644 index 000000000..d0fc6f0cf --- /dev/null +++ b/packages/tldraw/src/lib/defaultBindingUtils.ts @@ -0,0 +1,4 @@ +import { TLAnyBindingUtilConstructor } from '@tldraw/editor' +import { ArrowBindingUtil } from './bindings/arrow/ArrowBindingUtil' + +export const defaultBindingUtils: TLAnyBindingUtilConstructor[] = [ArrowBindingUtil] diff --git a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts index 50e698626..5e5fe7818 100644 --- a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts +++ b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts @@ -1,4 +1,4 @@ -import { IndexKey, TLArrowShape, Vec, createShapeId } from '@tldraw/editor' +import { IndexKey, TLArrowShape, Vec, createShapeId, getArrowBindings } from '@tldraw/editor' import { TestEditor } from '../../../test/TestEditor' let editor: TestEditor @@ -530,8 +530,8 @@ describe('line bug', () => { .keyUp('Shift') expect(editor.getCurrentPageShapes().length).toBe(2) - const arrow = editor.getCurrentPageShapes()[1] as TLArrowShape - expect(arrow.props.end.type).toBe('binding') + const bindings = getArrowBindings(editor, editor.getCurrentPageShapes()[1] as TLArrowShape) + expect(bindings.end).toBeDefined() }) it('works as expected when binding to a straight horizontal line', () => { @@ -552,7 +552,7 @@ describe('line bug', () => { .pointerUp() expect(editor.getCurrentPageShapes().length).toBe(2) - const arrow = editor.getCurrentPageShapes()[1] as TLArrowShape - expect(arrow.props.end.type).toBe('binding') + const bindings = getArrowBindings(editor, editor.getCurrentPageShapes()[1] as TLArrowShape) + expect(bindings.end).toBeDefined() }) }) diff --git a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts index 2105d6d4a..fe9f8084b 100644 --- a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts +++ b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts @@ -1,11 +1,13 @@ import { - assert, - createShapeId, HALF_PI, TLArrowShape, - TLArrowShapeTerminal, TLShapeId, + arrowBindingMakeItSo, + assert, + createShapeId, + getArrowBindings, } from '@tldraw/editor' +import { describe } from 'node:test' import { TestEditor } from '../../../test/TestEditor' let editor: TestEditor @@ -42,23 +44,25 @@ beforeEach(() => { x: 150, y: 150, props: { - start: { - type: 'binding', - isExact: false, - boundShapeId: ids.box1, - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - }, - end: { - type: 'binding', - isExact: false, - boundShapeId: ids.box2, - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - }, + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, }, }, ]) + + arrowBindingMakeItSo(editor, ids.arrow1, ids.box1, { + terminal: 'start', + isExact: false, + isPrecise: false, + normalizedAnchor: { x: 0.5, y: 0.5 }, + }) + + arrowBindingMakeItSo(editor, ids.arrow1, ids.box2, { + terminal: 'end', + isExact: false, + isPrecise: false, + normalizedAnchor: { x: 0.5, y: 0.5 }, + }) }) describe('When translating a bound shape', () => { @@ -93,6 +97,11 @@ describe('When translating a bound shape', () => { }, }, }) + expect(getArrowBindings(editor, editor.getShape(ids.arrow1)!)).toMatchObject({ + start: { + toId: ids.box1, + }, + }) }) it('updates the arrow when curved', () => { @@ -300,8 +309,9 @@ describe('Other cases when arrow are moved', () => { editor.setCurrentTool('arrow').pointerDown(1000, 1000).pointerMove(50, 350).pointerUp(50, 350) let arrow = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1] assert(editor.isShapeOfType(arrow, 'arrow')) - assert(arrow.props.end.type === 'binding') - expect(arrow.props.end.boundShapeId).toBe(ids.box3) + let bindings = getArrowBindings(editor, arrow) + assert(bindings.end) + expect(bindings.end.toId).toBe(ids.box3) // translate: editor.selectAll().nudgeShapes(editor.getSelectedShapeIds(), { x: 0, y: 1 }) @@ -309,8 +319,9 @@ describe('Other cases when arrow are moved', () => { // arrow should still be bound to box3 arrow = editor.getShape(arrow.id)! assert(editor.isShapeOfType(arrow, 'arrow')) - assert(arrow.props.end.type === 'binding') - expect(arrow.props.end.boundShapeId).toBe(ids.box3) + bindings = getArrowBindings(editor, arrow) + assert(bindings.end) + expect(bindings.end.toId).toBe(ids.box3) }) }) @@ -342,11 +353,7 @@ describe('When a shape is rotated', () => { }, }) - const anchor = ( - editor.getShape(arrow.id)!.props.end as TLArrowShapeTerminal & { - type: 'binding' - } - ).normalizedAnchor + const anchor = getArrowBindings(editor, editor.getShape(arrow.id)!).end!.props.normalizedAnchor expect(anchor.x).toBeCloseTo(0.5) expect(anchor.y).toBeCloseTo(0.75) }) diff --git a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx index 7036ee1b6..959717f69 100644 --- a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx +++ b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx @@ -9,8 +9,8 @@ import { SVGContainer, ShapeUtil, SvgExportContext, + TLArrowBinding, TLArrowShape, - TLArrowShapeProps, TLHandle, TLOnEditEndHandler, TLOnHandleDragHandler, @@ -21,12 +21,14 @@ import { TLShapeUtilCanvasSvgDef, TLShapeUtilFlag, Vec, + arrowBindingMakeItNotSo, + arrowBindingMakeItSo, arrowShapeMigrations, arrowShapeProps, + getArrowBindings, getArrowTerminalsInArrowSpace, getDefaultColorTheme, mapObjectMapValues, - objectMapEntries, structuredClone, toDomPrecision, track, @@ -83,8 +85,8 @@ export class ArrowShapeUtil extends ShapeUtil { color: 'black', labelColor: 'black', bend: 0, - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 2, y: 0 }, + start: { x: 0, y: 0 }, + end: { x: 2, y: 0 }, arrowheadStart: 'none', arrowheadEnd: 'arrow', text: '', @@ -164,10 +166,11 @@ export class ArrowShapeUtil extends ShapeUtil { override onHandleDrag: TLOnHandleDragHandler = (shape, { handle, isPrecise }) => { const handleId = handle.id as ARROW_HANDLES + const bindings = getArrowBindings(this.editor, shape) if (handleId === ARROW_HANDLES.MIDDLE) { // Bending the arrow... - const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape) + const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape, bindings) const delta = Vec.Sub(end, start) const v = Vec.Per(delta) @@ -186,11 +189,17 @@ export class ArrowShapeUtil extends ShapeUtil { const next = structuredClone(shape) as TLArrowShape + const currentBinding = bindings[handleId] + + const otherHandleId = handleId === ARROW_HANDLES.START ? ARROW_HANDLES.END : ARROW_HANDLES.START + const otherBinding = bindings[otherHandleId] + if (this.editor.inputs.ctrlKey) { // todo: maybe double check that this isn't equal to the other handle too? // Skip binding + arrowBindingMakeItNotSo(this.editor, shape, handleId) + next.props[handleId] = { - type: 'point', x: handle.x, y: handle.y, } @@ -210,8 +219,9 @@ export class ArrowShapeUtil extends ShapeUtil { if (!target) { // todo: maybe double check that this isn't equal to the other handle too? + arrowBindingMakeItNotSo(this.editor, shape, handleId) + next.props[handleId] = { - type: 'point', x: handle.x, y: handle.y, } @@ -230,11 +240,7 @@ export class ArrowShapeUtil extends ShapeUtil { if (!precise) { // If we're switching to a new bound shape, then precise only if moving slowly - const prevHandle = next.props[handleId] - if ( - prevHandle.type === 'point' || - (prevHandle.type === 'binding' && target.id !== prevHandle.boundShapeId) - ) { + if (!currentBinding || (currentBinding && target.id !== currentBinding.toId)) { precise = this.editor.inputs.pointerVelocity.len() < 0.5 } } @@ -246,13 +252,7 @@ export class ArrowShapeUtil extends ShapeUtil { // Double check that we're not going to be doing an imprecise snap on // the same shape twice, as this would result in a zero length line - const otherHandle = - next.props[handleId === ARROW_HANDLES.START ? ARROW_HANDLES.END : ARROW_HANDLES.START] - if ( - otherHandle.type === 'binding' && - target.id === otherHandle.boundShapeId && - otherHandle.isPrecise - ) { + if (otherBinding && target.id === otherBinding.toId && otherBinding.props.isPrecise) { precise = true } } @@ -276,64 +276,58 @@ export class ArrowShapeUtil extends ShapeUtil { } } - next.props[handleId] = { - type: 'binding', - boundShapeId: target.id, - normalizedAnchor: normalizedAnchor, + arrowBindingMakeItSo(this.editor, shape, target.id, { + terminal: handleId, + normalizedAnchor, isPrecise: precise, isExact: this.editor.inputs.altKey, - } + }) - if (next.props.start.type === 'binding' && next.props.end.type === 'binding') { - if (next.props.start.boundShapeId === next.props.end.boundShapeId) { - if (Vec.Equals(next.props.start.normalizedAnchor, next.props.end.normalizedAnchor)) { - next.props.end.normalizedAnchor.x += 0.05 - } - } - } + this.editor.setHintingShapes([target.id]) + + // TODO(alex): restore this if we can + // if (next.props.start.type === 'binding' && next.props.end.type === 'binding') { + // if (next.props.start.boundShapeId === next.props.end.boundShapeId) { + // if (Vec.Equals(next.props.start.normalizedAnchor, next.props.end.normalizedAnchor)) { + // next.props.end.normalizedAnchor.x += 0.05 + // } + // } + // } return next } override onTranslateStart: TLOnTranslateStartHandler = (shape) => { - const startBindingId = - shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : null - const endBindingId = shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : null + const bindings = getArrowBindings(this.editor, shape) - const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(this.editor, shape) + const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(this.editor, shape, bindings) const shapePageTransform = this.editor.getShapePageTransform(shape.id)! // If at least one bound shape is in the selection, do nothing; // If no bound shapes are in the selection, unbind any bound shapes const selectedShapeIds = this.editor.getSelectedShapeIds() - const shapesToCheck = new Set() - if (startBindingId) { - // Add shape and all ancestors to set - shapesToCheck.add(startBindingId) - this.editor.getShapeAncestors(startBindingId).forEach((a) => shapesToCheck.add(a.id)) - } - if (endBindingId) { - // Add shape and all ancestors to set - shapesToCheck.add(endBindingId) - this.editor.getShapeAncestors(endBindingId).forEach((a) => shapesToCheck.add(a.id)) - } - // If any of the shapes are selected, return - for (const id of selectedShapeIds) { - if (shapesToCheck.has(id)) return - } - let result = shape + if ( + (bindings.start && + (selectedShapeIds.includes(bindings.start.toId) || + this.editor.isAncestorSelected(bindings.start.toId))) || + (bindings.end && + (selectedShapeIds.includes(bindings.end.toId) || + this.editor.isAncestorSelected(bindings.end.toId))) + ) { + return + } // When we start translating shapes, record where their bindings were in page space so we // can maintain them as we translate the arrow shapeAtTranslationStart.set(shape, { pagePosition: shapePageTransform.applyToPoint(shape), terminalBindings: mapObjectMapValues(terminalsInArrowSpace, (terminalName, point) => { - const terminal = shape.props[terminalName] - if (terminal.type !== 'binding') return null + const binding = bindings[terminalName] + if (!binding) return null return { - binding: terminal, + binding, shapePosition: point, pagePosition: shapePageTransform.applyToPoint(point), } @@ -341,15 +335,16 @@ export class ArrowShapeUtil extends ShapeUtil { }) for (const handleName of [ARROW_HANDLES.START, ARROW_HANDLES.END] as const) { - const terminal = shape.props[handleName] - if (terminal.type !== 'binding') continue - result = { - ...shape, - props: { ...shape.props, [handleName]: { ...terminal, isPrecise: true } }, - } + const binding = bindings[handleName] + if (!binding) continue + + this.editor.updateBinding({ + ...binding, + props: { ...binding.props, isPrecise: true }, + }) } - return result + return } override onTranslate?: TLOnTranslateHandler = (initialShape, shape) => { @@ -362,10 +357,7 @@ export class ArrowShapeUtil extends ShapeUtil { atTranslationStart.pagePosition ) - let result = shape - for (const [terminalName, terminalBinding] of objectMapEntries( - atTranslationStart.terminalBindings - )) { + for (const terminalBinding of Object.values(atTranslationStart.terminalBindings)) { if (!terminalBinding) continue const newPagePoint = Vec.Add(terminalBinding.pagePosition, Vec.Mul(pageDelta, 0.5)) @@ -378,54 +370,41 @@ export class ArrowShapeUtil extends ShapeUtil { }, }) - if (newTarget?.id === terminalBinding.binding.boundShapeId) { + if (newTarget?.id === terminalBinding.binding.toId) { const targetBounds = Box.ZeroFix(this.editor.getShapeGeometry(newTarget).bounds) const pointInTargetSpace = this.editor.getPointInShapeSpace(newTarget, newPagePoint) const normalizedAnchor = { x: (pointInTargetSpace.x - targetBounds.minX) / targetBounds.width, y: (pointInTargetSpace.y - targetBounds.minY) / targetBounds.height, } - result = { - ...result, - props: { - ...result.props, - [terminalName]: { ...terminalBinding.binding, isPrecise: true, normalizedAnchor }, - }, - } + arrowBindingMakeItSo(this.editor, shape, newTarget.id, { + ...terminalBinding.binding.props, + normalizedAnchor, + isPrecise: true, + }) } else { - result = { - ...result, - props: { - ...result.props, - [terminalName]: { - type: 'point', - x: terminalBinding.shapePosition.x, - y: terminalBinding.shapePosition.y, - }, - }, - } + arrowBindingMakeItNotSo(this.editor, shape, terminalBinding.binding.props.terminal) } } - - return result } override onResize: TLOnResizeHandler = (shape, info) => { const { scaleX, scaleY } = info - const terminals = getArrowTerminalsInArrowSpace(this.editor, shape) + const bindings = getArrowBindings(this.editor, shape) + const terminals = getArrowTerminalsInArrowSpace(this.editor, shape, bindings) const { start, end } = structuredClone(shape.props) let { bend } = shape.props // Rescale start handle if it's not bound to a shape - if (start.type === 'point') { + if (!bindings.start) { start.x = terminals.start.x * scaleX start.y = terminals.start.y * scaleY } // Rescale end handle if it's not bound to a shape - if (end.type === 'point') { + if (!bindings.end) { end.x = terminals.end.x * scaleX end.y = terminals.end.y * scaleY } @@ -436,18 +415,23 @@ export class ArrowShapeUtil extends ShapeUtil { const mx = Math.abs(scaleX) const my = Math.abs(scaleY) + const startNormalizedAnchor = bindings?.start + ? Vec.From(bindings.start.props.normalizedAnchor) + : null + const endNormalizedAnchor = bindings?.end ? Vec.From(bindings.end.props.normalizedAnchor) : null + if (scaleX < 0 && scaleY >= 0) { if (bend !== 0) { bend *= -1 bend *= Math.max(mx, my) } - if (start.type === 'binding') { - start.normalizedAnchor.x = 1 - start.normalizedAnchor.x + if (startNormalizedAnchor) { + startNormalizedAnchor.x = 1 - startNormalizedAnchor.x } - if (end.type === 'binding') { - end.normalizedAnchor.x = 1 - end.normalizedAnchor.x + if (endNormalizedAnchor) { + endNormalizedAnchor.x = 1 - endNormalizedAnchor.x } } else if (scaleX >= 0 && scaleY < 0) { if (bend !== 0) { @@ -455,12 +439,12 @@ export class ArrowShapeUtil extends ShapeUtil { bend *= Math.max(mx, my) } - if (start.type === 'binding') { - start.normalizedAnchor.y = 1 - start.normalizedAnchor.y + if (startNormalizedAnchor) { + startNormalizedAnchor.y = 1 - startNormalizedAnchor.y } - if (end.type === 'binding') { - end.normalizedAnchor.y = 1 - end.normalizedAnchor.y + if (endNormalizedAnchor) { + endNormalizedAnchor.y = 1 - endNormalizedAnchor.y } } else if (scaleX >= 0 && scaleY >= 0) { if (bend !== 0) { @@ -471,17 +455,30 @@ export class ArrowShapeUtil extends ShapeUtil { bend *= Math.max(mx, my) } - if (start.type === 'binding') { - start.normalizedAnchor.x = 1 - start.normalizedAnchor.x - start.normalizedAnchor.y = 1 - start.normalizedAnchor.y + if (startNormalizedAnchor) { + startNormalizedAnchor.x = 1 - startNormalizedAnchor.x + startNormalizedAnchor.y = 1 - startNormalizedAnchor.y } - if (end.type === 'binding') { - end.normalizedAnchor.x = 1 - end.normalizedAnchor.x - end.normalizedAnchor.y = 1 - end.normalizedAnchor.y + if (endNormalizedAnchor) { + endNormalizedAnchor.x = 1 - endNormalizedAnchor.x + endNormalizedAnchor.y = 1 - endNormalizedAnchor.y } } + if (bindings.start && startNormalizedAnchor) { + arrowBindingMakeItSo(this.editor, shape, bindings.start.toId, { + ...bindings.start.props, + normalizedAnchor: startNormalizedAnchor, + }) + } + if (bindings.end && endNormalizedAnchor) { + arrowBindingMakeItSo(this.editor, shape, bindings.end.toId, { + ...bindings.end.props, + normalizedAnchor: endNormalizedAnchor, + }) + } + const next = { props: { start, @@ -565,18 +562,18 @@ export class ArrowShapeUtil extends ShapeUtil { } indicator(shape: TLArrowShape) { - const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape) + // eslint-disable-next-line react-hooks/rules-of-hooks + const isEditing = useIsEditing(shape.id) const info = this.editor.getArrowInfo(shape) + if (!info) return null + + const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape, info?.bindings) const geometry = this.editor.getShapeGeometry(shape) const bounds = geometry.bounds const labelGeometry = shape.props.text.trim() ? (geometry.children[1] as Rectangle2d) : null - // eslint-disable-next-line react-hooks/rules-of-hooks - const isEditing = useIsEditing(shape.id) - - if (!info) return null if (Vec.Equals(start, end)) return null const strokeWidth = STROKE_SIZES[shape.props.size] @@ -753,6 +750,7 @@ const ArrowSvg = track(function ArrowSvg({ const theme = useDefaultColorTheme() const info = editor.getArrowInfo(shape) const bounds = Box.ZeroFix(editor.getShapeGeometry(shape).bounds) + const bindings = getArrowBindings(editor, shape) const changeIndex = React.useMemo(() => { return editor.environment.isSafari ? (globalRenderIndex += 1) : 0 @@ -783,7 +781,7 @@ const ArrowSvg = track(function ArrowSvg({ ) handlePath = - shape.props.start.type === 'binding' || shape.props.end.type === 'binding' ? ( + bindings.start || bindings.end ? ( + binding: TLArrowBinding } | null > } diff --git a/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts b/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts index edef80479..a43d9847b 100644 --- a/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts +++ b/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts @@ -268,8 +268,8 @@ export function getArrowLabelPosition(editor: Editor, shape: TLArrowShape) { const debugGeom: Geometry2d[] = [] const info = editor.getArrowInfo(shape)! - const hasStartBinding = shape.props.start.type === 'binding' - const hasEndBinding = shape.props.end.type === 'binding' + const hasStartBinding = !!info.bindings.start + const hasEndBinding = !!info.bindings.end const hasStartArrowhead = info.start.arrowhead !== 'none' const hasEndArrowhead = info.end.arrowhead !== 'none' if (info.isStraight) { diff --git a/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts b/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts index 9e0b9414d..8b58bb3e8 100644 --- a/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts +++ b/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts @@ -111,10 +111,6 @@ export class Pointing extends StateNode { }) if (change) { - const startTerminal = change.props?.start - if (startTerminal?.type === 'binding') { - this.editor.setHintingShapes([startTerminal.boundShapeId]) - } this.editor.updateShapes([change]) } @@ -148,10 +144,6 @@ export class Pointing extends StateNode { }) if (change) { - const endTerminal = change.props?.end - if (endTerminal?.type === 'binding') { - this.editor.setHintingShapes([endTerminal.boundShapeId]) - } this.editor.updateShapes([change]) } } diff --git a/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx b/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx index aab15b79d..3268743b1 100644 --- a/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx +++ b/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx @@ -1,7 +1,6 @@ import { StateNode, TLArrowShape, - TLArrowShapeTerminal, TLCancelEvent, TLEnterEventHandler, TLEventHandlers, @@ -12,6 +11,7 @@ import { TLShapeId, TLShapePartial, Vec, + getArrowBindings, snapAngle, sortByIndex, structuredClone, @@ -112,16 +112,16 @@ export class DraggingHandle extends StateNode { //