From 0ab3698db7b8a837d6696daefb2136750ebccfe5 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Sat, 6 May 2023 14:37:23 -0700 Subject: [PATCH] Bluesky: move most ATProto code to separate arroba library https://github.com/snarfed/arroba --- atproto.py | 7 +- atproto_diff.py | 283 --------- atproto_mst.py | 1119 ------------------------------------ atproto_util.py | 182 ------ docs/conf.py | 1 + docs/source/modules.rst | 4 - models.py | 2 +- requirements.txt | 1 + tests/test_atproto.py | 8 +- tests/test_atproto_diff.py | 62 -- tests/test_atproto_mst.py | 286 --------- tests/test_atproto_util.py | 46 -- tests/test_models.py | 2 +- tests/testutil.py | 6 +- 14 files changed, 13 insertions(+), 1996 deletions(-) delete mode 100644 atproto_diff.py delete mode 100644 atproto_mst.py delete mode 100644 atproto_util.py delete mode 100644 tests/test_atproto_diff.py delete mode 100644 tests/test_atproto_mst.py delete mode 100644 tests/test_atproto_util.py diff --git a/atproto.py b/atproto.py index 0561099..8d3746a 100644 --- a/atproto.py +++ b/atproto.py @@ -5,6 +5,8 @@ import logging import random import re +from arroba.mst import MST, serialize_node_data +from arroba.util import dag_cbor_cid, sign_commit from Crypto.PublicKey import ECC import dag_cbor.encoding from flask import g @@ -13,11 +15,6 @@ from granary import bluesky from multiformats import CID, multibase, multicodec, multihash from oauth_dropins.webutil import util -from atproto_mst import ( - MST, - serialize_node_data, -) -from atproto_util import dag_cbor_cid, sign_commit from flask_app import xrpc_server from models import Follower, Object, PAGE_SIZE, User diff --git a/atproto_diff.py b/atproto_diff.py deleted file mode 100644 index 4718434..0000000 --- a/atproto_diff.py +++ /dev/null @@ -1,283 +0,0 @@ -"""AT Protocol utility for diffing two MSTs. - -Heavily based on: -https://github.com/bluesky/atproto/blob/main/packages/repo/src/mst/diff.ts - -Huge thanks to the Bluesky team for working in the public, in open source, and to -Daniel Holmgren and Devin Ivy for this code specifically! -""" -from collections import namedtuple -import logging - -from atproto_mst import Leaf, MST, Walker - -logger = logging.getLogger(__name__) - - -def mst_diff(cur, prev=None): - """Generates a diff between two MSTs. - - Args: - cur: :class:`MST` - prev: :class:`MST`, optional - - Returns: - :class:`Diff` - """ - cur.get_pointer() - if not prev: - return null_diff(cur) - - prev.get_pointer() - diff = Diff() - - left_walker = Walker(prev) - right_walker = Walker(cur) - while not left_walker.status.done or not right_walker.status.done: - # print(left_walker.status, right_walker.status) - # if one walker is finished, continue walking the other & logging all nodes - if left_walker.status.done and not right_walker.status.done: - node = right_walker.status.cur - if isinstance(node, Leaf): - diff.record_add(node.key, node.value) - else: - diff.record_new_cid(node.pointer) - right_walker.advance() - continue - - elif not left_walker.status.done and right_walker.status.done: - node = left_walker.status.cur - if isinstance(node, Leaf): - diff.record_delete(node.key, node.value) - else: - diff.record_removed_cid(node.pointer) - left_walker.advance() - continue - - if left_walker.status.done or right_walker.status.done: - break - - left = left_walker.status.cur - right = right_walker.status.cur - if not left or not right: - break - - # if both pointers are leaves, record an update & advance both or record - # the lowest key and advance that pointer - if isinstance(left, Leaf) and isinstance(right, Leaf): - if left.key == right.key: - if left.value != right.value: - diff.record_update(left.key, left.value, right.value) - left_walker.advance() - right_walker.advance() - elif left.key < right.key: - diff.record_delete(left.key, left.value) - left_walker.advance() - else: - diff.record_add(right.key, right.value) - right_walker.advance() - - continue - - # next, ensure that we're on the same layer - # - # if one walker is at a higher layer than the other, we need to do one - # of two things if the higher walker is pointed at a tree, step into - # that tree to try to catch up with the lower if the higher walker is - # pointed at a leaf, then advance the lower walker to try to catch up - # the higher - if left_walker.layer() > right_walker.layer(): - if isinstance(left, Leaf): - if isinstance(right, Leaf): - diff.record_add(right.key, right.value) - else: - diff.record_new_cid(right.pointer) - - right_walker.advance() - else: - diff.record_removed_cid(left.pointer) - left_walker.step_into() - - continue - - elif left_walker.layer() < right_walker.layer(): - if isinstance(right, Leaf): - if isinstance(left, Leaf): - diff.record_delete(left.key, left.value) - else: - diff.record_removed_cid(left.pointer) - - left_walker.advance() - - else: - diff.record_new_cid(right.pointer) - right_walker.step_into() - - continue - - # if we're on the same level, and both pointers are trees, do a - # comparison. if they're the same, step over. if they're different, step - # in to find the subdiff - if isinstance(left, MST) and isinstance(right, MST): - if left.pointer == right.pointer: - left_walker.step_over() - right_walker.step_over() - else: - diff.record_new_cid(right.pointer) - diff.record_removed_cid(left.pointer) - left_walker.step_into() - right_walker.step_into() - - continue - - # finally, if one pointer is a tree and the other is a leaf, simply step - # into the tree - if isinstance(left, Leaf) and isinstance(right, MST): - diff.record_new_cid(right.pointer) - right_walker.step_into() - continue - - elif isinstance(left, MST) and isinstance(right, Leaf): - diff.record_removed_cid(left.pointer) - left_walker.step_into() - continue - - raise RuntimeError('Unidentifiable case in diff walk') - - return diff - - -def null_diff(tree): - """Generates a "null" diff for a single MST with all adds and new CIDs. - - Args: - tree: :class:`MST` - - Returns: - :class:`Diff` - """ - diff = Diff() - - for entry in tree.walk(): - if isinstance(entry, Leaf): - diff.record_add(entry.key, entry.value) - else: - diff.record_new_cid(entry.pointer) - - return diff - - -Change = namedtuple('Change', [ - 'key', # str - 'cid', # :class:`CID` - 'prev', # :class:`CID` -], defaults=[None]) - - -class Diff: - """A diff between two MSTs. - - Attributes: - adds: {str key: :class:`Change`} - updates: {str key: :class:`Change`} - deletes: {str key: :class:`Change`} - new_cids: set of :class:`CID` - removed_cids: set of :class:`CID` - """ - - def __init__(self): - self.adds = {} - self.updates = {} - self.deletes = {} - self.new_cids = set() - self.removed_cids = set() - - @staticmethod - def of(cur, prev=None): - """ - Args: - cur: :class:`MST` - prev: :class:`MST`, optional - - Returns: - :class:`Diff` - """ - return mst_diff(cur, prev) - - def record_add(self, key, cid): - """ - Args: - key: str - cid: :class:`CID` - """ - self.adds[key] = Change(key=key, cid=cid) - self.new_cids.add(cid) - - def record_update(self, key, prev, cid): - """ - Args: - key: str - prev: :class:`CID` - cid: :class:`CID` - """ - self.updates[key] = Change(key=key, cid=cid, prev=prev) - self.new_cids.add(cid) - - def record_delete(self, key, cid): - """ - Args: - key: str - cid: :class:`CID` - """ - self.deletes[key] = Change(key=key, cid=cid) - - def record_new_cid(self, cid): - """ - Args: - cid: :class:`CID` - """ - if cid in self.removed_cids: - self.removed_cids.remove(cid) - else: - self.new_cids.add(cid) - - def record_removed_cid(self, cid): - """ - Args: - cid: :class:`CID` - """ - if cid in self.new_cids: - self.new_cids.remove(cid) - else: - self.removed_cids.add(cid) - - def add_diff(self, diff): - """ - Args: - diff: :class:`Diff` - """ - for add in diff.adds.values(): - if self.deletes[add.key]: - deleted = self.deletes[add.key] - if deleted.cid != add.cid: - self.record_update(add.key, deleted.cid, add.cid) - del self.deletes[add.key] - else: - self.record_add(add.key, add.cid) - - for update in diff.updates.values(): - self.record_update(update.key, update.prev, update.cid) - del self.adds[update.key] - del self.deletes[update.key] - - for deleted in diff.deletes.values(): - if self.adds[deleted.key]: - del self.adds[deleted.key] - else: - del self.updates[deleted.key] - self.record_delete(deleted.key, deleted.cid) - - self.new_cids |= diff.new_cids - - def updated_keys(self): - return self.adds | self.updates | self.deletes diff --git a/atproto_mst.py b/atproto_mst.py deleted file mode 100644 index 8c328b0..0000000 --- a/atproto_mst.py +++ /dev/null @@ -1,1119 +0,0 @@ -"""Bluesky / AT Protocol Merkle search tree implementation. - -* https://atproto.com/guides/data-repos -* https://atproto.com/lexicons/com-atproto-sync -* https://hal.inria.fr/hal-02303490/document - -Heavily based on: -https://github.com/bluesky/atproto/blob/main/packages/repo/src/mst/mst.ts - -Huge thanks to the Bluesky team for working in the public, in open source, and to -Daniel Holmgren and Devin Ivy for this code specifically! - -Notable differences: -* All in memory, no block storage (yet) - -From that file: - -This is an implementation of a Merkle Search Tree (MST) -The data structure is described here: https://hal.inria.fr/hal-02303490/document -The MST is an ordered, insert-order-independent, deterministic tree. -Keys are laid out in alphabetic order. -The key insight of an MST is that each key is hashed and starting 0s are counted -to determine which layer it falls on (5 zeros for ~32 fanout). -This is a merkle tree, so each subtree is referred to by it's hash (CID). -When a leaf is changed, ever tree on the path to that leaf is changed as well, -thereby updating the root hash. - -For atproto, we use SHA-256 as the key hashing algorithm, and ~4 fanout -(2-bits of zero per layer). - -A couple notes on CBOR encoding: - -There are never two neighboring subtrees. -Therefore, we can represent a node as an array of -leaves & pointers to their right neighbor (possibly null), -along with a pointer to the left-most subtree (also possibly null). - -Most keys in a subtree will have overlap. -We do compression on prefixes by describing keys as: -* the length of the prefix that it shares in common with the preceding key -* the rest of the string - -For example: - -If the first leaf in a tree is `bsky/posts/abcdefg` and the second is -`bsky/posts/abcdehi` Then the first will be described as `prefix: 0, key: -'bsky/posts/abcdefg'`, and the second will be described as `prefix: 16, key: -'hi'.` -""" -from collections import namedtuple -import copy -from hashlib import sha256 -from os.path import commonprefix -import re - -from multiformats import CID - -from atproto_util import dag_cbor_cid - -# this is treeEntry in mst.ts -Entry = namedtuple('Entry', [ - 'p', # int, length of prefix that this key shares with the prev key - 'k', # bytes, the rest of the key outside the shared prefix - 'v', # str CID, value - 't', # str CID, next subtree (to the right of leaf), or None -]) - -Data = namedtuple('Data', [ - 'l', # str CID, left-most subtree, or None - 'e', # list of Entry -]) - -Leaf = namedtuple('Leaf', [ - 'key', # str, record key - 'value', # CID -]) - - -class MST: - """Merkle search tree class. - - Attributes: - entries: sequence of :class:`MST` and :class:`Leaf` - layer: int, this MST's layer in the root MST - pointer: :class:`CID` - outdated_pointer: boolean, whether pointer needs to be recalculated - """ - entries = None - layer = None - pointer = None - outdated_pointer = False - - def __init__(self, entries=None, pointer=None, layer=None): - """Constructor. - Args: - entries: sequence of :class:`MST` and :class:`Leaf` - pointer: :class:`CID` - layer: int - - Returns: - :class:`MST` - """ - self.entries = entries or [] - self.pointer = pointer or cid_for_entries(self.entries) - self.layer = layer - -# def from_data(data: NodeData, opts?: Partial): -# """ -# Returns: -# :class:`MST` -# """ -# { layer = None } = opts or {} -# entries = deserialize_node_data(data, opts) -# pointer = cid_for_cbor(data) -# return MST(entries=entries, pointer=pointer) - - def __eq__(self, other): - if isinstance(other, MST): - return self.get_pointer() == other.get_pointer() - - def __unicode__(self): - return f'MST with pointer {self.get_pointer()}' - - def __repr__(self): - return f'MST(entries=..., pointer={self.get_pointer()}, layer={self.get_layer()})' - - # Immutability - # ------------------- - def new_tree(self, entries): - """We never mutate an MST, we just return a new MST with updated values. - - Args: - entries: sequence of :class:`MST` and :class:`Leaf` - - Returns: - :class:`MST` - """ - mst = MST(entries=entries, pointer=self.pointer, layer=self.layer) - mst.outdated_pointer = True - return mst - - -# Getters (lazy load) -# ------------------- - - def get_entries(self): - """ - - We don't want to load entries of every subtree, just the ones we need. - - Returns: - sequence of :class:`MST` and :class:`Leaf` - """ - if self.entries: - return copy.copy(self.entries) - - if self.pointer: - data = self.storage.read_obj(self.pointer, node_data_def) - first_leaf = data.e[0] - layer = leading_zeros_on_hash(first_leaf.k) if first_leaf else None - self.entries = deserialize_node_data(self.storage, data, {layer}) - return self.entries - - raise RuntimeError('No entries or CID provided') - - def get_pointer(self): - """Returns this MST's root CID pointer. Calculates it if necessary. - - We don't hash the node on every mutation for performance reasons. - Instead we keep track of whether the pointer is outdated and only - (recursively) calculate when needed. - - Returns: - :class:`CID` - """ - if not self.outdated_pointer: - return self.pointer - - for e in self.entries: - if isinstance(e, MST) and e.outdated_pointer: - e.get_pointer() - - self.pointer = cid_for_entries(self.entries) - self.outdated_pointer = False - return self.pointer - - def get_layer(self): - """Returns this MST's layer, and sets self.layer. - - In most cases, we get the layer of a node from a hint on creation. In the - case of the topmost node in the tree, we look for a key in the node & - determine the layer. In the case where we don't find one, we recurse down - until we do. If we still can't find one, then we have an empty tree and the - node is layer 0. - - Returns: - int - """ - self.layer = self.attempt_get_layer() - if self.layer is None: - self.layer = 0 - - return self.layer - - def attempt_get_layer(self): - """Returns this MST's layer, and sets self.layer. - - Returns: - int or None - """ - if self.layer is not None: - return self.layer - - layer = layer_for_entries(self.entries) - if layer is None: - for entry in self.entries: - if isinstance(entry, MST): - child_layer = entry.attempt_get_layer() - if child_layer is not None: - layer = child_layer + 1 - break - - if layer is not None: - self.layer = layer - - return layer - - - # Core functionality - # ------------------- - - def add(self, key, value=None, known_zeros=None): - """Adds a new leaf for the given key/value pair. - - Args: - key: str - value: :class:`CID` - known_zeros: int - - Returns: - :class:`MST` - - Raises: - ValueError if a leaf with that key already exists - """ - ensure_valid_key(key) - key_zeros = known_zeros or leading_zeros_on_hash(key) - layer = self.get_layer() - new_leaf = Leaf(key=key, value=value) - - if key_zeros == layer: - # it belongs in self layer - index = self.find_gt_or_equal_leaf_index(key) - found = self.at_index(index) - if isinstance(found, Leaf) and found.key == key: - raise ValueError(f'There is already a value at key: {key}') - prev_node = self.at_index(index - 1) - if not prev_node or isinstance(prev_node, Leaf): - # if entry before is a leaf, (or we're on far left) we can just splice in - return self.splice_in(new_leaf, index) - else: - # else we try to split the subtree around the key - left, right = prev_node.split_around(key) - return self.replace_with_split(index - 1, left, new_leaf, right) - - elif key_zeros < layer: - # it belongs on a lower layer - index = self.find_gt_or_equal_leaf_index(key) - prev_node = self.at_index(index - 1) - if prev_node and isinstance(prev_node, MST): - # if entry before is a tree, we add it to that tree - new_subtree = prev_node.add(key, value, key_zeros) - return self.update_entry(index - 1, new_subtree) - else: - sub_tree = self.create_child() - new_subtree = sub_tree.add(key, value, key_zeros) - return self.splice_in(new_subtree, index) - - else: # key_zeros > layer - # it belongs on a higher layer, push the rest of the tree down - left, right = self.split_around(key) - # if the newly added key has >=2 more leading zeros than the current - # highest layer then we need to add structural nodes between as well - layer = self.get_layer() - extra_layers_to_add = key_zeros - layer - # intentionally starting at 1, first layer is taken care of by split - for i in range(1, extra_layers_to_add): - if left: - left = left.create_parent() - if right: - right = right.create_parent() - - updated = [] - if left: - updated.append(left) - updated.append(Leaf(key=key, value=value)) - if right: - updated.append(right) - - new_root = MST(entries=updated, layer=key_zeros) - new_root.outdated_pointer = True - return new_root - - def get(self, key): - """Gets the value at the given key. - - Args: - key: str - - Returns: - :class:`CID` or None - """ - index = self.find_gt_or_equal_leaf_index(key) - found = self.at_index(index) - if found and isinstance(found, Leaf) and found.key == key: - return found.value - - prev = self.at_index(index - 1) - if prev and isinstance(prev, MST): - return prev.get(key) - - def update(self, key, value): - """Edits the value at the given key - - Args: - key: str - value: :class:`CID` - - Returns: - :class:`MST` - - Raises: - KeyError if key doesn't exist - """ - ensure_valid_key(key) - - index = self.find_gt_or_equal_leaf_index(key) - found = self.at_index(index) - if found and isinstance(found, Leaf) and found.key == key: - return self.update_entry(index, Leaf(key=key, value=value)) - - prev = self.at_index(index - 1) - if prev and isinstance(prev, MST): - updated_tree = prev.update(key, value) - return self.update_entry(index - 1, updated_tree) - - raise KeyError(f'Could not find a record with key: {key}') - - def delete(self, key): - """Deletes the value at the given key. - - Args: - key: str - - Returns: - :class:`MST` - - Raises: - KeyError if key doesn't exist - """ - return self.delete_recurse(key).trim_top() - - def delete_recurse(self, key): - """Deletes the value and subtree, if any, at the given key. - - Args: - key: str - - Returns: - :class:`MST` - """ - index = self.find_gt_or_equal_leaf_index(key) - found = self.at_index(index) - - # if found, remove it on self level - if isinstance(found, Leaf) and found.key == key: - prev = self.at_index(index - 1) - next = self.at_index(index + 1) - if isinstance(prev, MST) and isinstance(next, MST): - merged = prev.append_merge(next) - return self.new_tree( - self.slice(0, index - 1) + [merged] + self.slice(index + 2) - ) - else: - return self.remove_entry(index) - - # else recurse down to find it - prev = self.at_index(index - 1) - if isinstance(prev, MST): - subtree = prev.delete_recurse(key) - if subtree.entries == 0: - return self.remove_entry(index - 1) - else: - return self.update_entry(index - 1, subtree) - - raise KeyError(f'Could not find a record with key: {key}') - - -# Simple Operations -# ------------------- - - def update_entry(self, index, entry): - """Updates an entry in place. - - Args: - index: int - entry: :class:`MST` or :class:`Leaf` - - Returns: - :class:`MST` - """ - return self.new_tree( - entries=self.slice(0, index) + [entry] + self.slice(index + 1)) - - def remove_entry(self, index): - """Removes the entry at a given index. - - Args: - index: int - - Returns: - :class:`MST` - """ - return self.new_tree(entries=self.slice(0, index) + self.slice(index + 1)) - - def append(self, entry): - """Appends an entry to the end of the node. - - Args: - entry: :class:`MST` or :class:`Leaf` - - Returns: - :class:`MST` - """ - return self.new_tree(self.entries + [entry]) - - def prepend(self, entry): - """Prepends an entry to the start of the node. - - Args: - entry: :class:`MST` or :class:`Leaf` - - Returns: - :class:`MST` - """ - return self.new_tree([entry] + self.entries) - - def at_index(self, index): - """Returns the entry at a given index. - - Args: - index: int - - Returns: - :class:`MST` or :class:`Leaf` or None - """ - if 0 <= index < len(self.entries): - return self.entries[index] - - def slice(self, start=None, end=None): - """Returns a slice of this node. - - Args: - start: int, optional, inclusive - end: int, optional, exclusive - - Returns: - sequence of :class:`MST` and :class:`Leaf` - """ - return self.entries[start:end] - - def splice_in(self, entry, index): - """Inserts an entry at a given index. - - Args: - entry: :class:`MST` or :class:`Leaf` - index: int - - Returns: - :class:`MST` - """ - return self.new_tree(self.slice(0, index) + [entry] + self.slice(index)) - - def replace_with_split(self, index, left=None, leaf=None, right=None): - """Replaces an entry with [ Maybe(tree), Leaf, Maybe(tree) ]. - - Args: - index: int - left: :class:`MST` or :class:`Leaf` - leaf: :class:`Leaf` - right: :class:`MST` or :class:`Leaf` - - Returns: - :class:`MST` - """ - updated = self.slice(0, index) - if left: - updated.append(left) - updated.append(leaf) - if right: - updated.append(right) - updated.extend(self.slice(index + 1)) - return self.new_tree(updated) - - def trim_top(self): - """Trims the top and return its subtree, if necessary. - - Only if the topmost node in the tree only points to another tree. - Otherwise, does nothing. - - Returns: - :class:`MST` - """ - if len(self.entries) == 1 and isinstance(self.entries[0], MST): - return self.entries[0].trim_top() - else: - return self - - -# Subtree & Splits -# ------------------- - - def split_around(self, key): - """Recursively splits a subtree around a given key. - - Args: - key: str - - Returns: - tuple, (:class:`MST` or None, :class:`MST or None) - """ - index = self.find_gt_or_equal_leaf_index(key) - # split tree around key - left_data = self.slice(0, index) - right_data = self.slice(index) - left = self.new_tree(left_data) - right = self.new_tree(right_data) - - # if the far right of the left side is a subtree, - # we need to split it on the key as well - last_in_left = left_data[-1] if left_data else None - if isinstance(last_in_left, MST): - left = left.remove_entry(len(left_data) -1) - split = last_in_left.split_around(key) - if split[0]: - left = left.append(split[0]) - if split[1]: - right = right.prepend(split[1]) - - return [ - left if left.entries else None, - right if right.entries else None, - ] - - def append_merge(self, to_merge): - """Merges another tree with this one. - - The simple merge case where every key in the right tree is greater than - every key in the left tree. Used primarily for deletes. - - Args: - to_merge: :class:`MST` - - Returns: - :class:`MST` - """ - assert self.get_layer() == to_merge.get_layer(), \ - 'Trying to merge two nodes from different layers of the MST' - - last_in_left = self.entries[-1] - first_in_right = to_merge.entries[0] - - if isinstance(last_in_left, MST) and isinstance(first_in_right, MST): - merged = last_in_left.append_merge(first_in_right) - return self.new_tree( - list(self.entries[:-1]) + [merged] + to_merge.entries[1:]) - else: - return self.new_tree(self.entries + to_merge.entries) - - - # Create relatives - # ------------------- - - def create_child(self): - """ - Returns: - :class:`MST` - """ - return MST(entries=[], layer=self.get_layer() - 1) - - def create_parent(self): - """ - Returns: - :class:`MST` - """ - parent = MST(entries=[self], layer=self.get_layer() + 1) - parent.outdated_pointer = True - return parent - - -# Finding insertion points -# ------------------- - - def find_gt_or_equal_leaf_index(self, key): - """Finds the index of the first leaf node greater than or equal to value. - - Args: - key: str - - Returns: - int - """ - for i, entry in enumerate(self.entries): - if isinstance(entry, Leaf) and entry.key >= key: - return i - - # if we can't find it, we're on the end - return len(self.entries) - - -# List operations (partial tree traversal) -# ------------------- - -# @TODO write tests for these - -# Walk tree starting at key -# def walk_leaves_from(key: string): AsyncIterable: -# index = self.find_gt_or_equal_leaf_index(key) -# prev = self.entries[index - 1] -# if prev and isinstance(prev, MST): -# for e in prev.walk_leaves_from(key): -# yield e -# for entry in self.entries[index:]: -# if isinstance(entry, Leaf): -# yield entry -# else: -# for e in entry.walk_leaves_from(key): -# yield e - -# def list( -# count = Number.MAX_SAFE_INTEGER, -# after?: string, -# before?: string, -# ): -# """ -# Returns: -# Leaf[] -# """ -# vals: Leaf[] = [] -# for leaf in self.walk_leaves_from(after or ''): -# if leaf.key == after: -# continue -# if len(vals) >= count: -# break -# if before and leaf.key >= before: -# break -# vals.append(leaf) -# return vals - -# def list_with_prefix( -# prefix: string, -# count = Number.MAX_SAFE_INTEGER, -# ): -# """ -# Returns: -# Leaf[] -# """ -# vals: Leaf[] = [] -# for leaf in self.walk_leaves_from(prefix): -# if len(vals) >= count or not leaf.key.startswith(prefix): -# break -# vals.append(leaf) -# return vals - - -# Full tree traversal -# ------------------- - - def walk(self): - """Walk full tree, depth first, and emit nodes. - - Returns: - generator of :class:`MST` and :class:`Leaf` - """ - yield self - - for entry in self.entries: - if isinstance(entry, MST): - for e in entry.walk(): - yield e - else: - yield entry - -# Walk full tree & emit nodes, consumer can bail at any point by returning False -# def paths(): -# """ -# Returns: -# sequence of :class:`MST` and :class:`Leaf` -# """ -# paths: NodeEntry[][] = [] -# for entry in self.entries: -# if isinstance(entry, Leaf): -# paths.append([entry]) -# if isinstance(entry, MST): -# sub_paths = entry.paths() -# paths = paths + sub_paths.map((p) => ([entry] + p)) -# return paths - - def all_nodes(self): - """Walks the tree and returns all nodes. - - Returns: - sequence of :class:`MST` and :class:`Leaf` - """ - return list(self.walk()) - -# Walks tree & returns all cids -# def all_cids(): -# """ -# Returns: -# CidSet -# """ -# cids = CidSet() -# for entry in self.entries: -# if isinstance(entry, Leaf): -# cids.add(entry.value) -# else: -# subtree_cids = entry.all_cids() -# cids.add_set(subtree_cids) -# cids.add(self.get_pointer()) -# return cids - - def leaves(self): - """Walks tree and returns all leaves. - - Returns: - sequence of :class:`Leaf` - """ - return [entry for entry in self.walk() if isinstance(entry, Leaf)] - - def leaf_count(self): - """Returns the total number of leaves in this MST. - - Returns: - int - """ - return len(self.leaves()) - - -# Reachable tree traversal -# ------------------- - - # Walk reachable branches of tree & emit nodes, consumer can bail at any - # point by returning False - -# def walk_reachable(): AsyncIterable: -# yield self -# for entry in self.entries: -# if isinstance(entry, MST): -# try: -# for e in entry.walk_reachable(): -# yield e -# catch (err): -# if err instanceof MissingBlockError: -# continue -# else: -# raise err -# else: -# yield entry - -# def reachable_leaves(): -# """ -# Returns: -# Leaf[] -# """ -# leaves: Leaf[] = [] -# for entry in self.walk_reachable(): -# if isinstance(entry, Leaf): -# leaves.append(entry) -# return leaves - -# Sync Protocol - -# def write_to_car_stream(car: BlockWriter): -# """ -# Returns: -# void -# """ -# leaves = CidSet() -# to_fetch = CidSet() -# to_fetch.add(self.get_pointer()) -# for entry in self.entries: -# if isinstance(entry, Leaf): -# leaves.add(entry.value) -# else: -# to_fetch.add(entry.get_pointer()) -# while (to_fetch.size() > 0): -# next_layer = CidSet() -# fetched = self.storage.get_blocks(to_fetch.to_list()) -# if fetched.missing: -# raise MissingBlocksError('mst node', fetched.missing) -# for cid in to_fetch.to_list(): -# found = parse.get_and_parse_by_def( -# fetched.blocks, -# cid, -# node_data_def, -# ) -# car.put({ cid, bytes: found.bytes }) -# entries = deserialize_node_data(self.storage, found.obj) - -# for entry in entries: -# if isinstance(entry, Leaf): -# leaves.add(entry.value) -# else: -# next_layer.add(entry.get_pointer()) -# to_fetch = next_layer -# leaf_data = self.storage.get_blocks(leaves.to_list()) -# if leaf_data.missing: -# raise MissingBlocksError('mst leaf', leaf_data.missing) - -# for leaf in leaf_data.blocks.entries(): -# car.put(leaf) - -# def cids_for_path(self, key): -# """Returns the CIDs in a given key path. ??? -# -# Args: -# key: str -# -# Returns: -# sequence of :class:`CID` -# """ -# cids: CID[] = [self.get_pointer()] -# index = self.find_gt_or_equal_leaf_index(key) -# found = self.at_index(index) -# if found and isinstance(found, Leaf) and found.key == key: -# return cids + [found.value] -# prev = self.at_index(index - 1) -# if prev and isinstance(prev, MST): -# return cids + prev.cids_for_path(key) -# return cids - - -def leading_zeros_on_hash(key): - """Returns the number of leading zeros in a key's hash. - - Args: - key: str or bytes - - Returns: - int - """ - if not isinstance(key, bytes): - key = key.encode() # ensure_valid_key enforces that this is ASCII only - - leading_zeros = 0 - for byte in sha256(key).digest(): - if byte < 64: - leading_zeros += 1 - if byte < 16: - leading_zeros += 1 - if byte < 4: - leading_zeros += 1 - if byte == 0: - leading_zeros += 1 - else: - break - - return leading_zeros - - -def layer_for_entries(entries): - """ - sequence of :class:`MST` and :class:`Leaf` - Returns: - number | None - """ - for entry in entries: - if isinstance(entry, Leaf): - return leading_zeros_on_hash(first_leaf.key) - - -# def deserialize_node_data = ( -# storage: ReadableBlockstore, -# data: NodeData, -# opts?: Partial, -# ): -# """ -# Returns: -# sequence of :class:`MST` and :class:`Leaf` -# """ -# { layer } = opts or {} -# entries = [] -# if (data.l is not None): -# entries.append( -# MST.load(storage, data.l,: -# layer: layer ? layer - 1 : undefined, -# ) - -# last_key = '' -# for entry in data.e: -# key_str = uint8arrays.to_string(entry.k, 'ascii') -# key = last_key.slice(0, entry.p) + key_str -# ensure_valid_key(key) -# entries.append(Leaf(key, entry.v)) -# last_key = key -# if entry.t is not None: -# entries.append( -# MST.load(storage, entry.t,: -# layer: layer ? layer - 1 : undefined, -# ) - -# return entries - - -def serialize_node_data(entries): - """ - Args: - entries: sequence of :class:`MST` and :class:`Leaf` - - Returns: - :class:`Data` - """ - l = None - i = 0 - if entries and isinstance(entries[0], MST): - i += 1 - l = entries[0].get_pointer() - - data = Data(l=l, e=[]) - last_key = '' - while i < len(entries): - leaf = entries[i] - next = entries[i + 1] if i < len(entries) - 1 else None - - if not isinstance(leaf, Leaf): - raise ValueError('Not a valid node: two subtrees next to each other') - i += 1 - - subtree = None - if next and isinstance(next, MST): - subtree = next.get_pointer() - i += 1 - - ensure_valid_key(leaf.key) - prefix_len = common_prefix_len(last_key, leaf.key) - data.e.append(Entry( - p=prefix_len, - k=leaf.key[prefix_len:].encode('ascii'), - v=leaf.value, - t=subtree, - )._asdict()) - - last_key = leaf.key - - return data - - -def common_prefix_len(a, b): - """ - Args: - a, b: str - - Returns: - int - """ - return len(commonprefix((a, b))) - - -def cid_for_entries(entries): - """ - Args: - entries: sequence of :class:`MST` and :class:`Leaf` - - Returns: - :class:`CID` - """ - return dag_cbor_cid(serialize_node_data(entries)._asdict()) - - -def ensure_valid_key(key): - """ - Args: - key: str - - Raises: - ValueError if key is not a valid MST key. - """ - valid = re.compile('[a-zA-Z0-9_\-:.]*$') - split = key.split('/') - if not (len(key) <= 256 and - len(split) == 2 and - split[0] and - split[1] and - valid.match(split[0]) and - valid.match(split[1]) - ): - raise ValueError(f'Invalid MST key: {key}') - - -WalkStatus = namedtuple('WalkStatus', [ - 'done', # boolean - 'cur', # MST or Leaf - 'walking', # MST or None if cur is the root of the tree - 'index', # int -], defaults=[None, None, None, None]) - - -class Walker: - """Allows walking an MST manually. - - Attributes: - stack: sequence of WalkStatus - status: WalkStatus, current - """ - stack = None - status = None - - def __init__(self, tree): - """Constructor. - - Args: - tree: :class:`MST` - """ - self.stack = [] - self.status = WalkStatus( - done=False, - cur=tree, - walking=None, - index=0, - ) - - def layer(self): - """Returns the curent layer of the node we're on.""" - assert not self.status.done, 'Walk is done' - - if self.status.walking: - return self.status.walking.layer or 0 - - # if cur is the root of the tree, add 1 - if isinstance(self.status.cur, MST): - return (self.status.cur.layer or 0) + 1 - - raise RuntimeError('Could not identify layer of walk') - - - def step_over(self): - """Moves to the next node in the subtree, skipping over the subtree.""" - if self.status.done: - return - - # if stepping over the root of the node, we're done - if not self.status.walking: - self.status = WalkStatus(done=True) - return - - entries = self.status.walking.get_entries() - self.status = self.status._replace(index=self.status.index + 1) - - if self.status.index >= len(entries): - if not self.stack: - self.status = WalkStatus(done=True) - else: - self.status = self.stack.pop() - self.step_over() - else: - self.status = self.status._replace(cur=entries[self.status.index]) - - def step_into(self): - """Steps into a subtree. - - Raises: - RuntimeError, if curently on a leaf - """ - if self.status.done: - return - - # edge case for very start of walk - if not self.status.walking: - assert isinstance(self.status.cur, MST), \ - 'The root of the tree cannot be a leaf' - next = self.status.cur.at_index(0) - if not next: - self.status = WalkStatus(done=True) - else: - self.status = WalkStatus( - done=False, - walking=self.status.cur, - cur=next, - index=0, - ) - return - - if not isinstance(self.status.cur, MST): - raise RuntimeError('No tree at pointer, cannot step into') - - next = self.status.cur.at_index(0) - assert next, 'Tried to step into a node with 0 entries which is invalid' - - self.stack.append(self.status) - self.status = WalkStatus( - walking=self.status.cur, - cur=next, - index=0, - done=False, - ) - - def advance(self): - """Advances to the next node in the tree. - - Steps into the curent node if necessary. - """ - if self.status.done: - return - - if isinstance(self.status.cur, Leaf): - self.step_over() - else: - self.step_into() diff --git a/atproto_util.py b/atproto_util.py deleted file mode 100644 index 15043c3..0000000 --- a/atproto_util.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Misc AT Protocol utils. TIDs, CIDs, etc.""" -import copy -from datetime import datetime, timezone -import logging -from numbers import Integral -import random - -from Crypto.Hash import SHA256 -from Crypto.Signature import DSS -import dag_cbor.encoding -from multiformats import CID, multicodec, multihash -from oauth_dropins.webutil.appengine_info import DEBUG - -logger = logging.getLogger(__name__) - -# the bottom 32 clock ids can be randomized & are not guaranteed to be collision -# resistant. we use the same clockid for all TIDs coming from this runtime. -_clockid = random.randint(0, 31) -# _tid_last = time.time_ns() // 1000 # microseconds - -S32_CHARS = '234567abcdefghijklmnopqrstuvwxyz' - - -def dag_cbor_cid(obj): - """Returns the DAG-CBOR CID for a given object. - - Args: - obj: CBOR-compatible native object or value - - Returns: - :class:`CID` - """ - encoded = dag_cbor.encoding.encode(obj) - digest = multihash.digest(encoded, 'sha2-256') - return CID('base58btc', 1, multicodec.get('dag-cbor'), digest) - - -def s32encode(num): - """Base32 encode with encoding variant sort. - - Based on https://github.com/bluesky-social/atproto/blob/main/packages/common-web/src/tid.ts - - Args: - num: int or Integral - - Returns: - str - """ - assert isinstance(num, Integral) - - encoded = [] - while num > 0: - c = num % 32 - num = num // 32 - encoded.insert(0, S32_CHARS[c]) - - return ''.join(encoded) - - -def s32decode(val): - """Base32 decode with encoding variant sort. - - Based on https://github.com/bluesky-social/atproto/blob/main/packages/common-web/src/tid.ts - - Args: - val: str - - Returns: - int or Integral - """ - i = 0 - for c in val: - i = i * 32 + S32_CHARS.index(c) - - return i - - -def datetime_to_tid(dt): - """Converts a datetime to an ATProto TID. - - https://atproto.com/guides/data-repos#identifier-types - - Args: - dt: :class:`datetime.datetime` - - Returns: - str, base32-encoded TID - """ - tid = (s32encode(int(dt.timestamp() * 1000 * 1000)) + - s32encode(_clockid).ljust(2, '2')) - assert len(tid) == 13 - return tid - - -def tid_to_datetime(tid): - """Converts an ATProto TID to a datetime. - - https://atproto.com/guides/data-repos#identifier-types - - Args: - tid: bytes, base32-encoded TID - - Returns: - :class:`datetime.datetime` - - Raises: - ValueError if tid is not bytes or not 13 characters long - """ - if not isinstance(tid, (str, bytes)) or len(tid) != 13: - raise ValueError(f'Expected 13-character str or bytes; got {tid}') - - encoded = tid.replace('-', '')[:-2] # strip clock id - return datetime.fromtimestamp(s32decode(encoded) / 1000 / 1000, timezone.utc) - - -# TODO -# def next_tid(): -# global _tid_last - -# # enforce that we're at least 1us after the last TID to prevent TIDs moving -# # backwards if system clock drifts backwards -# now = time.time_ns() // 1000 -# if now > _tid_last: -# _tid_last = now -# else: -# _tid_last += 1 -# now = _tid_last - - -def sign_commit(commit, key): - """Signs a repo commit. - - Adds the signature in the `sig` field. - - Signing isn't yet in the atproto.com docs, this setup is taken from the TS - code and conversations with @why on #bluesky-dev:matrix.org. - - * https://matrix.to/#/!vpdMrhHjzaPbBUSgOs:matrix.org/$Xaf4ugYks-iYg7Pguh3dN8hlsvVMUOuCQo3fMiYPXTY?via=matrix.org&via=minds.com&via=envs.net - * https://github.com/bluesky-social/atproto/blob/384e739a3b7d34f7a95d6ba6f08e7223a7398995/packages/repo/src/util.ts#L238-L248 - * https://github.com/bluesky-social/atproto/blob/384e739a3b7d34f7a95d6ba6f08e7223a7398995/packages/crypto/src/p256/keypair.ts#L66-L73 - * https://github.com/bluesky-social/indigo/blob/f1f2480888ab5d0ac1e03bd9b7de090a3d26cd13/repo/repo.go#L64-L70 - * https://github.com/whyrusleeping/go-did/blob/2146016fc220aa1e08ccf26aaa762f5a11a81404/key.go#L67-L91 - - The signature is ECDSA around SHA-256 of the input. We currently use P-256 - keypairs. Context: - * Go supports P-256, ED25519, SECP256K1 keys - * TS supports P-256, SECP256K1 keys - * this recommends ED25519, then P-256: - https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/ - - Args: - commit: dict repo commit - key: :class:`Crypto.PublicKey.ECC.EccKey` - """ - signer = DSS.new(key, 'fips-186-3', randfunc=random.randbytes if DEBUG else None) - commit['sig'] = signer.sign(SHA256.new(dag_cbor.encoding.encode(commit))) - - -def verify_commit_sig(commit, key): - """Returns true if the commit's signature is valid, False otherwise. - - See :func:`sign_commit` for more background. - - Args: - commit: dict repo commit - key: :class:`Crypto.PublicKey.ECC.EccKey` - - Raises: - KeyError if the commit isn't signed, ie doesn't have a `sig` field - """ - commit = copy.copy(commit) - sig = commit.pop('sig') - - verifier = DSS.new(key.public_key(), 'fips-186-3', - randfunc=random.randbytes if DEBUG else None) - try: - verifier.verify(SHA256.new(dag_cbor.encoding.encode(commit)), sig) - return True - except ValueError: - logger.debug("Couldn't verify signature", exc_info=True) - return False - diff --git a/docs/conf.py b/docs/conf.py index 2f16e03..a81dd4a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -341,6 +341,7 @@ texinfo_documents = [ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { + 'arroba': ('https://arroba.readthedocs.io/en/latest', None), 'dag_cbor': ('https://dag-cbor.readthedocs.io/en/latest', None), 'flask': ('https://flask.palletsprojects.com/en/latest', None), 'flask_caching': ('https://flask-caching.readthedocs.io/en/latest', None), diff --git a/docs/source/modules.rst b/docs/source/modules.rst index 5f6ae92..438f1bb 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -9,10 +9,6 @@ activitypub ----------- .. automodule:: activitypub -atproto_mst ------------ -.. automodule:: atproto - common ------ .. automodule:: common diff --git a/models.py b/models.py index 2d81a56..1470fc2 100644 --- a/models.py +++ b/models.py @@ -8,6 +8,7 @@ import logging import random import urllib.parse +from arroba.mst import dag_cbor_cid from Crypto import Random from Crypto.PublicKey import ECC, RSA from Crypto.Util import number @@ -22,7 +23,6 @@ from oauth_dropins.webutil.util import json_dumps, json_loads import requests from werkzeug.exceptions import BadRequest, NotFound -from atproto_mst import dag_cbor_cid import common # https://github.com/snarfed/bridgy-fed/issues/314 diff --git a/requirements.txt b/requirements.txt index c1c7fcb..e08c58e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +git+https://github.com/snarfed/arroba.git#egg=arroba git+https://github.com/dvska/gdata-python3.git#egg=gdata git+https://github.com/snarfed/granary.git#egg=granary git+https://github.com/snarfed/dag-json.git#egg=dag_json diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 3d1769c..599c490 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -5,6 +5,9 @@ import random from unittest import skip from unittest.mock import patch +import atproto +import arroba.util +from arroba.util import dag_cbor_cid, verify_commit_sig from Crypto.PublicKey import ECC import dag_cbor.decoding, dag_cbor.encoding from granary import as2, bluesky @@ -22,14 +25,11 @@ from multiformats import CID from oauth_dropins.webutil import util from oauth_dropins.webutil.testutil import NOW -import atproto -import atproto_util -from atproto_util import dag_cbor_cid, verify_commit_sig from flask_app import app from models import Follower, Object, User from . import testutil -# # atproto_mst.Data entry for MST with POST_AS, REPLY_AS, and REPOST_AS +# # arroba.mst.Data entry for MST with POST_AS, REPLY_AS, and REPOST_AS # POST_CID = 'bafyreic5xwex7jxqvliumozkoli3qy2hzxrmui6odl7ujrcybqaypacfiy' # REPLY_CID = 'bafyreib55ro37wasiflouvlfenhzllorcthm7flr2nj4fnk7yjo54cagvm' # REPOST_CID = 'bafyreiek3jnp6e4sussy4c7pwtbkkf3kepekzycylowwuepmnvq7aeng44' diff --git a/tests/test_atproto_diff.py b/tests/test_atproto_diff.py deleted file mode 100644 index 19b1ee3..0000000 --- a/tests/test_atproto_diff.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Unit tests for atproto_diff.py. - -Heavily based on: -https://github.com/bluesky/atproto/blob/main/packages/repo/tests/sync/diff.test.ts - -Huge thanks to the Bluesky team for working in the public, in open source, and to -Daniel Holmgren and Devin Ivy for this code specifically! -""" -import dag_cbor.random - -from atproto_diff import Change, Diff -from atproto_mst import MST -from . import testutil - - -class AtProtoDiffTest(testutil.TestCase): - - def test_diffs(self): - mst = MST() - data = self.random_keys_and_cids(3)#1000) - for key, cid in data: - mst = mst.add(key, cid) - - before = after = mst - - to_add = self.random_keys_and_cids(1)#100) - to_edit = data[1:2] - to_del = data[0:1] - - # these are all {str key: Change} - expected_adds = {} - expected_updates = {} - expected_deletes = {} - - for key, cid in to_add: - after = after.add(key, cid) - expected_adds[key] = Change(key=key, cid=cid) - - for (key, prev), new in zip(to_edit, dag_cbor.random.rand_cid()): - after = after.update(key, new) - expected_updates[key] = Change(key=key, prev=prev, cid=new) - - for key, cid in to_del: - after = after.delete(key) - expected_deletes[key] = Change(key=key, cid=cid) - - diff = Diff.of(after, before) - - self.assertEqual(1, len(diff.adds)) - self.assertEqual(1, len(diff.updates)) - self.assertEqual(1, len(diff.deletes)) - - self.assertEqual(expected_adds, diff.adds) - self.assertEqual(expected_updates, diff.updates) - self.assertEqual(expected_deletes, diff.deletes) - - # ensure we correctly report all added CIDs - for entry in after.walk(): - cid = entry.get_pointer() if isinstance(entry, MST) else entry.value - # TODO - # assert cid in blockstore or cid in diff.new_cids - diff --git a/tests/test_atproto_mst.py b/tests/test_atproto_mst.py deleted file mode 100644 index 79b37b1..0000000 --- a/tests/test_atproto_mst.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Unit tests for atproto_mst.py. - -Heavily based on: -https://github.com/bluesky/atproto/blob/main/packages/repo/tests/mst.test.ts - -Huge thanks to the Bluesky team for working in the public, in open source, and to -Daniel Holmgren and Devin Ivy for this code specifically! -""" -from base64 import b32encode -from datetime import datetime -import time -import random -import string -from unittest import skip - -import dag_cbor.random -from multiformats import CID - -from atproto_mst import common_prefix_len, ensure_valid_key, MST -import atproto_util -from . import testutil - -CID1 = CID.decode('bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454') - - -class MstTest(testutil.TestCase): - - def test_add(self): - mst = MST() - data = self.random_keys_and_cids(1000) - for key, cid in data: - mst = mst.add(key, cid) - - for key, cid in data: - got = mst.get(key) - self.assertEqual(cid, got) - - self.assertEqual(1000, mst.leaf_count()) - - def test_edits_records(self): - mst = MST() - data = self.random_keys_and_cids(100) - - for key, cid in data: - mst = mst.add(key, cid) - - edited = [] - for (key, _), cid in zip(data, dag_cbor.random.rand_cid()): - mst = mst.update(key, cid) - edited.append([key, cid]) - - for key, cid in edited: - self.assertEqual(cid, mst.get(key)) - - self.assertEqual(100, mst.leaf_count()) - - def test_deletes_records(self): - mst = MST() - data = self.random_keys_and_cids(1000) - for key, cid in data: - mst = mst.add(key, cid) - - to_delete = data[:100] - the_rest = data[100:] - for key, _ in to_delete: - mst = mst.delete(key) - - self.assertEqual(900, mst.leaf_count()) - - for key, _ in to_delete: - self.assertIsNone(mst.get(key)) - - for key, cid in the_rest: - self.assertEqual(cid, mst.get(key)) - - def test_is_order_independent(self): - mst = MST() - data = self.random_keys_and_cids(1000) - for key, cid in data: - mst = mst.add(key, cid) - - all_nodes = mst.all_nodes() - - recreated = MST() - random.shuffle(data) - for key, cid in data: - recreated = recreated.add(key, cid) - - self.assertEqual(all_nodes, recreated.all_nodes()) - - def test_common_prefix_length(self): - self.assertEqual(3, common_prefix_len('abc', 'abc')) - self.assertEqual(0, common_prefix_len('', 'abc')) - self.assertEqual(0, common_prefix_len('abc', '')) - self.assertEqual(2, common_prefix_len('ab', 'abc')) - self.assertEqual(2, common_prefix_len('abc', 'ab')) - self.assertEqual(3, common_prefix_len('abcde', 'abc')) - self.assertEqual(3, common_prefix_len('abc', 'abcde')) - self.assertEqual(3, common_prefix_len('abcde', 'abc1')) - self.assertEqual(2, common_prefix_len('abcde', 'abb')) - self.assertEqual(0, common_prefix_len('abcde', 'qbb')) - self.assertEqual(0, common_prefix_len('', 'asdf')) - self.assertEqual(3, common_prefix_len('abc', 'abc\x00')) - self.assertEqual(3, common_prefix_len('abc\x00', 'abc')) - - def test_rejects_the_empty_key(self): - with self.assertRaises(ValueError): - MST().add('') - - def test_rejects_a_key_with_no_collection(self): - with self.assertRaises(ValueError): - MST().add('asdf') - - def test_rejects_a_key_with_a_nested_collection(self): - with self.assertRaises(ValueError): - MST().add('nested/collection/asdf') - - def test_rejects_on_empty_coll_or_rkey(self): - for key in 'coll/', '/rkey': - with self.assertRaises(ValueError): - MST().add(key) - - def test_rejects_non_ascii_chars(self): - for key in 'coll/jalapeñoA', 'coll/coöperative', 'coll/abc💩': - with self.assertRaises(ValueError): - MST().add(key) - - def test_rejects_ascii_that_we_dont_support(self): - for key in ('coll/key$', 'coll/key%', 'coll/key(', 'coll/key)', - 'coll/key+', 'coll/key='): - with self.assertRaises(ValueError): - MST().add(key) - - def test_rejects_keys_over_256_chars(self): - with self.assertRaises(ValueError): - MST().add( - 'coll/asdofiupoiwqeurfpaosidfuapsodirupasoirupasoeiruaspeoriuaspeoriu2p3o4iu1pqw3oiuaspdfoiuaspdfoiuasdfpoiasdufpwoieruapsdofiuaspdfoiuasdpfoiausdfpoasidfupasodifuaspdofiuasdpfoiasudfpoasidfuapsodfiuasdpfoiausdfpoasidufpasodifuapsdofiuasdpofiuasdfpoaisdufpao', - ) - - def test_computes_empty_tree_root_CID(self): - self.assertEqual(0, MST().leaf_count()) - self.assertEqual( - 'bafyreie5737gdxlw5i64vzichcalba3z2v5n6icifvx5xytvske7mr3hpm', - MST().get_pointer().encode('base32')) - - def test_computes_trivial_tree_root_CID(self): - mst = MST().add('com.example.record/3jqfcqzm3fo2j', CID1) - self.assertEqual(1, mst.leaf_count()) - self.assertEqual( - 'bafyreibj4lsc3aqnrvphp5xmrnfoorvru4wynt6lwidqbm2623a6tatzdu', - mst.get_pointer().encode('base32')) - - def test_computes_single_layer_2_tree_root_CID(self): - mst = MST().add('com.example.record/3jqfcqzm3fx2j', CID1) - self.assertEqual(1, mst.leaf_count()) - self.assertEqual(2, mst.layer) - self.assertEqual( - 'bafyreih7wfei65pxzhauoibu3ls7jgmkju4bspy4t2ha2qdjnzqvoy33ai', - mst.get_pointer().encode('base32')) - - def test_computes_simple_tree_root_CID(self): - mst = MST() - mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm3fr2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # level 1 - mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm4fc2j', CID1) # level 0 - self.assertEqual(5, mst.leaf_count()) - self.assertEqual( - 'bafyreicmahysq4n6wfuxo522m6dpiy7z7qzym3dzs756t5n7nfdgccwq7m', - mst.get_pointer().encode('base32')) - - def test_trims_top_of_tree_on_delete(self): - l1root = 'bafyreifnqrwbk6ffmyaz5qtujqrzf5qmxf7cbxvgzktl4e3gabuxbtatv4' - l0root = 'bafyreie4kjuxbwkhzg2i5dljaswcroeih4dgiqq6pazcmunwt2byd725vi' - - mst = MST() - mst = mst.add('com.example.record/3jqfcqzm3fn2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm3fo2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # level 1 - mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # level 0 - mst = mst.add('com.example.record/3jqfcqzm3fu2j', CID1) # level 0 - - self.assertEqual(6, mst.leaf_count()) - self.assertEqual(1, mst.layer) - self.assertEqual(l1root, mst.get_pointer().encode('base32')) - - mst = mst.delete('com.example.record/3jqfcqzm3fs2j') # level 1 - self.assertEqual(5, mst.leaf_count()) - self.assertEqual(0, mst.layer) - self.assertEqual(l0root, mst.get_pointer().encode('base32')) - - def test_handles_insertion_that_splits_two_layers_down(self): - """ - * * - _________|________ ____|_____ - | | | | | | | | - * d * i * -> * f * - __|__ __|__ __|__ __|__ __|___ - | | | | | | | | | | | | | | | - a b c e g h j k l * d * * i * - __|__ | _|_ __|__ - | | | | | | | | | - a b c e g h j k l - """ - l1root = 'bafyreiettyludka6fpgp33stwxfuwhkzlur6chs4d2v4nkmq2j3ogpdjem' - l2root = 'bafyreid2x5eqs4w4qxvc5jiwda4cien3gw2q6cshofxwnvv7iucrmfohpm' - - mst = MST() - mst = mst.add('com.example.record/3jqfcqzm3fo2j', CID1) # A; level 0 - mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # B; level 0 - mst = mst.add('com.example.record/3jqfcqzm3fr2j', CID1) # C; level 0 - mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # D; level 1 - mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # E; level 0 - # GAP for F - mst = mst.add('com.example.record/3jqfcqzm3fz2j', CID1) # G; level 0 - mst = mst.add('com.example.record/3jqfcqzm4fc2j', CID1) # H; level 0 - mst = mst.add('com.example.record/3jqfcqzm4fd2j', CID1) # I; level 1 - mst = mst.add('com.example.record/3jqfcqzm4ff2j', CID1) # J; level 0 - mst = mst.add('com.example.record/3jqfcqzm4fg2j', CID1) # K; level 0 - mst = mst.add('com.example.record/3jqfcqzm4fh2j', CID1) # L; level 0 - - self.assertEqual(11, mst.leaf_count()) - self.assertEqual(1, mst.layer) - self.assertEqual(l1root, mst.get_pointer().encode('base32')) - - # insert F, which will push E out in the node with G+H to a new node under D - mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # F; level 2 - self.assertEqual(12, mst.leaf_count()) - self.assertEqual(2, mst.layer) - self.assertEqual(l2root, mst.get_pointer().encode('base32')) - - # remove F, which should push E back over with G+H - mst = mst.delete('com.example.record/3jqfcqzm3fx2j') # F; level 2 - self.assertEqual(11, mst.leaf_count()) - self.assertEqual(1, mst.layer) - self.assertEqual(l1root, mst.get_pointer().encode('base32')) - - def test_handles_new_layers_that_are_two_higher_than_existing(self): - """ - * -> * - __|__ __|__ - | | | | | - a c * b * - | | - * * - | | - a c - """ - - l0root = 'bafyreidfcktqnfmykz2ps3dbul35pepleq7kvv526g47xahuz3rqtptmky' - l2root = 'bafyreiavxaxdz7o7rbvr3zg2liox2yww46t7g6hkehx4i4h3lwudly7dhy' - l2root2 = 'bafyreig4jv3vuajbsybhyvb7gggvpwh2zszwfyttjrj6qwvcsp24h6popu' - - mst = MST() - mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # A; level 0 - mst = mst.add('com.example.record/3jqfcqzm3fz2j', CID1) # C; level 0 - self.assertEqual(2, mst.leaf_count()) - self.assertEqual(0, mst.layer) - self.assertEqual(l0root, mst.get_pointer().encode('base32')) - - # insert B, which is two levels above - mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # B; level 2 - self.assertEqual(3, mst.leaf_count()) - self.assertEqual(2, mst.layer) - self.assertEqual(l2root, mst.get_pointer().encode('base32')) - - # remove B - mst = mst.delete('com.example.record/3jqfcqzm3fx2j') # B; level 2 - self.assertEqual(2, mst.leaf_count()) - self.assertEqual(0, mst.layer) - self.assertEqual(l0root, mst.get_pointer().encode('base32')) - - # insert B (level=2) and D (level=1) - mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # B; level 2 - mst = mst.add('com.example.record/3jqfcqzm4fd2j', CID1) # D; level 1 - self.assertEqual(4, mst.leaf_count()) - self.assertEqual(2, mst.layer) - self.assertEqual(l2root2, mst.get_pointer().encode('base32')) - - # remove D - mst = mst.delete('com.example.record/3jqfcqzm4fd2j') # D; level 1 - self.assertEqual(3, mst.leaf_count()) - self.assertEqual(2, mst.layer) - self.assertEqual(l2root, mst.get_pointer().encode('base32')) diff --git a/tests/test_atproto_util.py b/tests/test_atproto_util.py deleted file mode 100644 index a6ae5f5..0000000 --- a/tests/test_atproto_util.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Unit tests for atproto_util.py.""" - -from Crypto.PublicKey import ECC -from multiformats import CID -from oauth_dropins.webutil.testutil import NOW - -import atproto_util -from atproto_util import ( - dag_cbor_cid, - datetime_to_tid, - sign_commit, - tid_to_datetime, - verify_commit_sig, -) -from . import testutil - - -class AtProtoUtilTest(testutil.TestCase): - - def test_dag_cbor_cid(self): - self.assertEqual( - CID.decode('bafyreiblaotetvwobe7cu2uqvnddr6ew2q3cu75qsoweulzku2egca4dxq'), - dag_cbor_cid({'foo': 'bar'})) - - def test_datetime_to_tid(self): - self.assertEqual('3iom4o4g6u2l2', datetime_to_tid(NOW)) - - def test_tid_to_datetime(self): - self.assertEqual(NOW, tid_to_datetime('3iom4o4g6u2l2')) - - def test_sign_commit_and_verify(self): - user = self.make_user('user.com') - - commit = {'foo': 'bar'} - key = ECC.import_key(user.p256_key) - sign_commit(commit, key) - assert verify_commit_sig(commit, key) - - def test_verify_commit_error(self): - key = ECC.import_key(self.make_user('user.com').p256_key) - with self.assertRaises(KeyError): - self.assertFalse(verify_commit_sig({'foo': 'bar'}, key)) - - def test_verify_commit_fail(self): - key = ECC.import_key(self.make_user('user.com').p256_key) - self.assertFalse(verify_commit_sig({'foo': 'bar', 'sig': 'nope'}, key)) diff --git a/tests/test_models.py b/tests/test_models.py index d9acafd..c331d1c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,6 +2,7 @@ """Unit tests for models.py.""" from unittest import mock +from arroba.mst import dag_cbor_cid from Crypto.PublicKey import ECC from flask import g, get_flashed_messages from granary import as2 @@ -10,7 +11,6 @@ from multiformats import CID from oauth_dropins.webutil.testutil import NOW, requests_response from app import app -from atproto_mst import dag_cbor_cid import common from models import AtpNode, Follower, Object, OBJECT_EXPIRE_AGE, User import protocol diff --git a/tests/testutil.py b/tests/testutil.py index 41d156a..d341e3c 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -6,8 +6,8 @@ import random import unittest from unittest.mock import ANY, call -import atproto_util -from atproto_util import datetime_to_tid +import arroba.util +from arroba.util import datetime_to_tid import dag_cbor.random from flask import g from google.cloud import ndb @@ -90,7 +90,7 @@ class TestCase(unittest.TestCase, testutil.Asserts): FakeProtocol.fetched = [] # make random test data deterministic - atproto_util._clockid = 17 + arroba.util._clockid = 17 random.seed(1234567890) dag_cbor.random.set_options(seed=1234567890)