kopia lustrzana https://github.com/snarfed/bridgy-fed
Bluesky: move most ATProto code to separate arroba library
https://github.com/snarfed/arrobapull/505/head
rodzic
cc2ed9dd81
commit
0ab3698db7
|
@ -5,6 +5,8 @@ import logging
|
|||
import random
|
||||
import re
|
||||
|
||||
from arroba.mst import MST, serialize_node_data
|
||||
from arroba.util import dag_cbor_cid, sign_commit
|
||||
from Crypto.PublicKey import ECC
|
||||
import dag_cbor.encoding
|
||||
from flask import g
|
||||
|
@ -13,11 +15,6 @@ from granary import bluesky
|
|||
from multiformats import CID, multibase, multicodec, multihash
|
||||
from oauth_dropins.webutil import util
|
||||
|
||||
from atproto_mst import (
|
||||
MST,
|
||||
serialize_node_data,
|
||||
)
|
||||
from atproto_util import dag_cbor_cid, sign_commit
|
||||
from flask_app import xrpc_server
|
||||
from models import Follower, Object, PAGE_SIZE, User
|
||||
|
||||
|
|
283
atproto_diff.py
283
atproto_diff.py
|
@ -1,283 +0,0 @@
|
|||
"""AT Protocol utility for diffing two MSTs.
|
||||
|
||||
Heavily based on:
|
||||
https://github.com/bluesky/atproto/blob/main/packages/repo/src/mst/diff.ts
|
||||
|
||||
Huge thanks to the Bluesky team for working in the public, in open source, and to
|
||||
Daniel Holmgren and Devin Ivy for this code specifically!
|
||||
"""
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from atproto_mst import Leaf, MST, Walker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def mst_diff(cur, prev=None):
|
||||
"""Generates a diff between two MSTs.
|
||||
|
||||
Args:
|
||||
cur: :class:`MST`
|
||||
prev: :class:`MST`, optional
|
||||
|
||||
Returns:
|
||||
:class:`Diff`
|
||||
"""
|
||||
cur.get_pointer()
|
||||
if not prev:
|
||||
return null_diff(cur)
|
||||
|
||||
prev.get_pointer()
|
||||
diff = Diff()
|
||||
|
||||
left_walker = Walker(prev)
|
||||
right_walker = Walker(cur)
|
||||
while not left_walker.status.done or not right_walker.status.done:
|
||||
# print(left_walker.status, right_walker.status)
|
||||
# if one walker is finished, continue walking the other & logging all nodes
|
||||
if left_walker.status.done and not right_walker.status.done:
|
||||
node = right_walker.status.cur
|
||||
if isinstance(node, Leaf):
|
||||
diff.record_add(node.key, node.value)
|
||||
else:
|
||||
diff.record_new_cid(node.pointer)
|
||||
right_walker.advance()
|
||||
continue
|
||||
|
||||
elif not left_walker.status.done and right_walker.status.done:
|
||||
node = left_walker.status.cur
|
||||
if isinstance(node, Leaf):
|
||||
diff.record_delete(node.key, node.value)
|
||||
else:
|
||||
diff.record_removed_cid(node.pointer)
|
||||
left_walker.advance()
|
||||
continue
|
||||
|
||||
if left_walker.status.done or right_walker.status.done:
|
||||
break
|
||||
|
||||
left = left_walker.status.cur
|
||||
right = right_walker.status.cur
|
||||
if not left or not right:
|
||||
break
|
||||
|
||||
# if both pointers are leaves, record an update & advance both or record
|
||||
# the lowest key and advance that pointer
|
||||
if isinstance(left, Leaf) and isinstance(right, Leaf):
|
||||
if left.key == right.key:
|
||||
if left.value != right.value:
|
||||
diff.record_update(left.key, left.value, right.value)
|
||||
left_walker.advance()
|
||||
right_walker.advance()
|
||||
elif left.key < right.key:
|
||||
diff.record_delete(left.key, left.value)
|
||||
left_walker.advance()
|
||||
else:
|
||||
diff.record_add(right.key, right.value)
|
||||
right_walker.advance()
|
||||
|
||||
continue
|
||||
|
||||
# next, ensure that we're on the same layer
|
||||
#
|
||||
# if one walker is at a higher layer than the other, we need to do one
|
||||
# of two things if the higher walker is pointed at a tree, step into
|
||||
# that tree to try to catch up with the lower if the higher walker is
|
||||
# pointed at a leaf, then advance the lower walker to try to catch up
|
||||
# the higher
|
||||
if left_walker.layer() > right_walker.layer():
|
||||
if isinstance(left, Leaf):
|
||||
if isinstance(right, Leaf):
|
||||
diff.record_add(right.key, right.value)
|
||||
else:
|
||||
diff.record_new_cid(right.pointer)
|
||||
|
||||
right_walker.advance()
|
||||
else:
|
||||
diff.record_removed_cid(left.pointer)
|
||||
left_walker.step_into()
|
||||
|
||||
continue
|
||||
|
||||
elif left_walker.layer() < right_walker.layer():
|
||||
if isinstance(right, Leaf):
|
||||
if isinstance(left, Leaf):
|
||||
diff.record_delete(left.key, left.value)
|
||||
else:
|
||||
diff.record_removed_cid(left.pointer)
|
||||
|
||||
left_walker.advance()
|
||||
|
||||
else:
|
||||
diff.record_new_cid(right.pointer)
|
||||
right_walker.step_into()
|
||||
|
||||
continue
|
||||
|
||||
# if we're on the same level, and both pointers are trees, do a
|
||||
# comparison. if they're the same, step over. if they're different, step
|
||||
# in to find the subdiff
|
||||
if isinstance(left, MST) and isinstance(right, MST):
|
||||
if left.pointer == right.pointer:
|
||||
left_walker.step_over()
|
||||
right_walker.step_over()
|
||||
else:
|
||||
diff.record_new_cid(right.pointer)
|
||||
diff.record_removed_cid(left.pointer)
|
||||
left_walker.step_into()
|
||||
right_walker.step_into()
|
||||
|
||||
continue
|
||||
|
||||
# finally, if one pointer is a tree and the other is a leaf, simply step
|
||||
# into the tree
|
||||
if isinstance(left, Leaf) and isinstance(right, MST):
|
||||
diff.record_new_cid(right.pointer)
|
||||
right_walker.step_into()
|
||||
continue
|
||||
|
||||
elif isinstance(left, MST) and isinstance(right, Leaf):
|
||||
diff.record_removed_cid(left.pointer)
|
||||
left_walker.step_into()
|
||||
continue
|
||||
|
||||
raise RuntimeError('Unidentifiable case in diff walk')
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def null_diff(tree):
|
||||
"""Generates a "null" diff for a single MST with all adds and new CIDs.
|
||||
|
||||
Args:
|
||||
tree: :class:`MST`
|
||||
|
||||
Returns:
|
||||
:class:`Diff`
|
||||
"""
|
||||
diff = Diff()
|
||||
|
||||
for entry in tree.walk():
|
||||
if isinstance(entry, Leaf):
|
||||
diff.record_add(entry.key, entry.value)
|
||||
else:
|
||||
diff.record_new_cid(entry.pointer)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
Change = namedtuple('Change', [
|
||||
'key', # str
|
||||
'cid', # :class:`CID`
|
||||
'prev', # :class:`CID`
|
||||
], defaults=[None])
|
||||
|
||||
|
||||
class Diff:
|
||||
"""A diff between two MSTs.
|
||||
|
||||
Attributes:
|
||||
adds: {str key: :class:`Change`}
|
||||
updates: {str key: :class:`Change`}
|
||||
deletes: {str key: :class:`Change`}
|
||||
new_cids: set of :class:`CID`
|
||||
removed_cids: set of :class:`CID`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.adds = {}
|
||||
self.updates = {}
|
||||
self.deletes = {}
|
||||
self.new_cids = set()
|
||||
self.removed_cids = set()
|
||||
|
||||
@staticmethod
|
||||
def of(cur, prev=None):
|
||||
"""
|
||||
Args:
|
||||
cur: :class:`MST`
|
||||
prev: :class:`MST`, optional
|
||||
|
||||
Returns:
|
||||
:class:`Diff`
|
||||
"""
|
||||
return mst_diff(cur, prev)
|
||||
|
||||
def record_add(self, key, cid):
|
||||
"""
|
||||
Args:
|
||||
key: str
|
||||
cid: :class:`CID`
|
||||
"""
|
||||
self.adds[key] = Change(key=key, cid=cid)
|
||||
self.new_cids.add(cid)
|
||||
|
||||
def record_update(self, key, prev, cid):
|
||||
"""
|
||||
Args:
|
||||
key: str
|
||||
prev: :class:`CID`
|
||||
cid: :class:`CID`
|
||||
"""
|
||||
self.updates[key] = Change(key=key, cid=cid, prev=prev)
|
||||
self.new_cids.add(cid)
|
||||
|
||||
def record_delete(self, key, cid):
|
||||
"""
|
||||
Args:
|
||||
key: str
|
||||
cid: :class:`CID`
|
||||
"""
|
||||
self.deletes[key] = Change(key=key, cid=cid)
|
||||
|
||||
def record_new_cid(self, cid):
|
||||
"""
|
||||
Args:
|
||||
cid: :class:`CID`
|
||||
"""
|
||||
if cid in self.removed_cids:
|
||||
self.removed_cids.remove(cid)
|
||||
else:
|
||||
self.new_cids.add(cid)
|
||||
|
||||
def record_removed_cid(self, cid):
|
||||
"""
|
||||
Args:
|
||||
cid: :class:`CID`
|
||||
"""
|
||||
if cid in self.new_cids:
|
||||
self.new_cids.remove(cid)
|
||||
else:
|
||||
self.removed_cids.add(cid)
|
||||
|
||||
def add_diff(self, diff):
|
||||
"""
|
||||
Args:
|
||||
diff: :class:`Diff`
|
||||
"""
|
||||
for add in diff.adds.values():
|
||||
if self.deletes[add.key]:
|
||||
deleted = self.deletes[add.key]
|
||||
if deleted.cid != add.cid:
|
||||
self.record_update(add.key, deleted.cid, add.cid)
|
||||
del self.deletes[add.key]
|
||||
else:
|
||||
self.record_add(add.key, add.cid)
|
||||
|
||||
for update in diff.updates.values():
|
||||
self.record_update(update.key, update.prev, update.cid)
|
||||
del self.adds[update.key]
|
||||
del self.deletes[update.key]
|
||||
|
||||
for deleted in diff.deletes.values():
|
||||
if self.adds[deleted.key]:
|
||||
del self.adds[deleted.key]
|
||||
else:
|
||||
del self.updates[deleted.key]
|
||||
self.record_delete(deleted.key, deleted.cid)
|
||||
|
||||
self.new_cids |= diff.new_cids
|
||||
|
||||
def updated_keys(self):
|
||||
return self.adds | self.updates | self.deletes
|
1119
atproto_mst.py
1119
atproto_mst.py
Plik diff jest za duży
Load Diff
182
atproto_util.py
182
atproto_util.py
|
@ -1,182 +0,0 @@
|
|||
"""Misc AT Protocol utils. TIDs, CIDs, etc."""
|
||||
import copy
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
from numbers import Integral
|
||||
import random
|
||||
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto.Signature import DSS
|
||||
import dag_cbor.encoding
|
||||
from multiformats import CID, multicodec, multihash
|
||||
from oauth_dropins.webutil.appengine_info import DEBUG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# the bottom 32 clock ids can be randomized & are not guaranteed to be collision
|
||||
# resistant. we use the same clockid for all TIDs coming from this runtime.
|
||||
_clockid = random.randint(0, 31)
|
||||
# _tid_last = time.time_ns() // 1000 # microseconds
|
||||
|
||||
S32_CHARS = '234567abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
|
||||
def dag_cbor_cid(obj):
|
||||
"""Returns the DAG-CBOR CID for a given object.
|
||||
|
||||
Args:
|
||||
obj: CBOR-compatible native object or value
|
||||
|
||||
Returns:
|
||||
:class:`CID`
|
||||
"""
|
||||
encoded = dag_cbor.encoding.encode(obj)
|
||||
digest = multihash.digest(encoded, 'sha2-256')
|
||||
return CID('base58btc', 1, multicodec.get('dag-cbor'), digest)
|
||||
|
||||
|
||||
def s32encode(num):
|
||||
"""Base32 encode with encoding variant sort.
|
||||
|
||||
Based on https://github.com/bluesky-social/atproto/blob/main/packages/common-web/src/tid.ts
|
||||
|
||||
Args:
|
||||
num: int or Integral
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
assert isinstance(num, Integral)
|
||||
|
||||
encoded = []
|
||||
while num > 0:
|
||||
c = num % 32
|
||||
num = num // 32
|
||||
encoded.insert(0, S32_CHARS[c])
|
||||
|
||||
return ''.join(encoded)
|
||||
|
||||
|
||||
def s32decode(val):
|
||||
"""Base32 decode with encoding variant sort.
|
||||
|
||||
Based on https://github.com/bluesky-social/atproto/blob/main/packages/common-web/src/tid.ts
|
||||
|
||||
Args:
|
||||
val: str
|
||||
|
||||
Returns:
|
||||
int or Integral
|
||||
"""
|
||||
i = 0
|
||||
for c in val:
|
||||
i = i * 32 + S32_CHARS.index(c)
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def datetime_to_tid(dt):
|
||||
"""Converts a datetime to an ATProto TID.
|
||||
|
||||
https://atproto.com/guides/data-repos#identifier-types
|
||||
|
||||
Args:
|
||||
dt: :class:`datetime.datetime`
|
||||
|
||||
Returns:
|
||||
str, base32-encoded TID
|
||||
"""
|
||||
tid = (s32encode(int(dt.timestamp() * 1000 * 1000)) +
|
||||
s32encode(_clockid).ljust(2, '2'))
|
||||
assert len(tid) == 13
|
||||
return tid
|
||||
|
||||
|
||||
def tid_to_datetime(tid):
|
||||
"""Converts an ATProto TID to a datetime.
|
||||
|
||||
https://atproto.com/guides/data-repos#identifier-types
|
||||
|
||||
Args:
|
||||
tid: bytes, base32-encoded TID
|
||||
|
||||
Returns:
|
||||
:class:`datetime.datetime`
|
||||
|
||||
Raises:
|
||||
ValueError if tid is not bytes or not 13 characters long
|
||||
"""
|
||||
if not isinstance(tid, (str, bytes)) or len(tid) != 13:
|
||||
raise ValueError(f'Expected 13-character str or bytes; got {tid}')
|
||||
|
||||
encoded = tid.replace('-', '')[:-2] # strip clock id
|
||||
return datetime.fromtimestamp(s32decode(encoded) / 1000 / 1000, timezone.utc)
|
||||
|
||||
|
||||
# TODO
|
||||
# def next_tid():
|
||||
# global _tid_last
|
||||
|
||||
# # enforce that we're at least 1us after the last TID to prevent TIDs moving
|
||||
# # backwards if system clock drifts backwards
|
||||
# now = time.time_ns() // 1000
|
||||
# if now > _tid_last:
|
||||
# _tid_last = now
|
||||
# else:
|
||||
# _tid_last += 1
|
||||
# now = _tid_last
|
||||
|
||||
|
||||
def sign_commit(commit, key):
|
||||
"""Signs a repo commit.
|
||||
|
||||
Adds the signature in the `sig` field.
|
||||
|
||||
Signing isn't yet in the atproto.com docs, this setup is taken from the TS
|
||||
code and conversations with @why on #bluesky-dev:matrix.org.
|
||||
|
||||
* https://matrix.to/#/!vpdMrhHjzaPbBUSgOs:matrix.org/$Xaf4ugYks-iYg7Pguh3dN8hlsvVMUOuCQo3fMiYPXTY?via=matrix.org&via=minds.com&via=envs.net
|
||||
* https://github.com/bluesky-social/atproto/blob/384e739a3b7d34f7a95d6ba6f08e7223a7398995/packages/repo/src/util.ts#L238-L248
|
||||
* https://github.com/bluesky-social/atproto/blob/384e739a3b7d34f7a95d6ba6f08e7223a7398995/packages/crypto/src/p256/keypair.ts#L66-L73
|
||||
* https://github.com/bluesky-social/indigo/blob/f1f2480888ab5d0ac1e03bd9b7de090a3d26cd13/repo/repo.go#L64-L70
|
||||
* https://github.com/whyrusleeping/go-did/blob/2146016fc220aa1e08ccf26aaa762f5a11a81404/key.go#L67-L91
|
||||
|
||||
The signature is ECDSA around SHA-256 of the input. We currently use P-256
|
||||
keypairs. Context:
|
||||
* Go supports P-256, ED25519, SECP256K1 keys
|
||||
* TS supports P-256, SECP256K1 keys
|
||||
* this recommends ED25519, then P-256:
|
||||
https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
|
||||
|
||||
Args:
|
||||
commit: dict repo commit
|
||||
key: :class:`Crypto.PublicKey.ECC.EccKey`
|
||||
"""
|
||||
signer = DSS.new(key, 'fips-186-3', randfunc=random.randbytes if DEBUG else None)
|
||||
commit['sig'] = signer.sign(SHA256.new(dag_cbor.encoding.encode(commit)))
|
||||
|
||||
|
||||
def verify_commit_sig(commit, key):
|
||||
"""Returns true if the commit's signature is valid, False otherwise.
|
||||
|
||||
See :func:`sign_commit` for more background.
|
||||
|
||||
Args:
|
||||
commit: dict repo commit
|
||||
key: :class:`Crypto.PublicKey.ECC.EccKey`
|
||||
|
||||
Raises:
|
||||
KeyError if the commit isn't signed, ie doesn't have a `sig` field
|
||||
"""
|
||||
commit = copy.copy(commit)
|
||||
sig = commit.pop('sig')
|
||||
|
||||
verifier = DSS.new(key.public_key(), 'fips-186-3',
|
||||
randfunc=random.randbytes if DEBUG else None)
|
||||
try:
|
||||
verifier.verify(SHA256.new(dag_cbor.encoding.encode(commit)), sig)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.debug("Couldn't verify signature", exc_info=True)
|
||||
return False
|
||||
|
|
@ -341,6 +341,7 @@ texinfo_documents = [
|
|||
|
||||
# Example configuration for intersphinx: refer to the Python standard library.
|
||||
intersphinx_mapping = {
|
||||
'arroba': ('https://arroba.readthedocs.io/en/latest', None),
|
||||
'dag_cbor': ('https://dag-cbor.readthedocs.io/en/latest', None),
|
||||
'flask': ('https://flask.palletsprojects.com/en/latest', None),
|
||||
'flask_caching': ('https://flask-caching.readthedocs.io/en/latest', None),
|
||||
|
|
|
@ -9,10 +9,6 @@ activitypub
|
|||
-----------
|
||||
.. automodule:: activitypub
|
||||
|
||||
atproto_mst
|
||||
-----------
|
||||
.. automodule:: atproto
|
||||
|
||||
common
|
||||
------
|
||||
.. automodule:: common
|
||||
|
|
|
@ -8,6 +8,7 @@ import logging
|
|||
import random
|
||||
import urllib.parse
|
||||
|
||||
from arroba.mst import dag_cbor_cid
|
||||
from Crypto import Random
|
||||
from Crypto.PublicKey import ECC, RSA
|
||||
from Crypto.Util import number
|
||||
|
@ -22,7 +23,6 @@ from oauth_dropins.webutil.util import json_dumps, json_loads
|
|||
import requests
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from atproto_mst import dag_cbor_cid
|
||||
import common
|
||||
|
||||
# https://github.com/snarfed/bridgy-fed/issues/314
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
git+https://github.com/snarfed/arroba.git#egg=arroba
|
||||
git+https://github.com/dvska/gdata-python3.git#egg=gdata
|
||||
git+https://github.com/snarfed/granary.git#egg=granary
|
||||
git+https://github.com/snarfed/dag-json.git#egg=dag_json
|
||||
|
|
|
@ -5,6 +5,9 @@ import random
|
|||
from unittest import skip
|
||||
from unittest.mock import patch
|
||||
|
||||
import atproto
|
||||
import arroba.util
|
||||
from arroba.util import dag_cbor_cid, verify_commit_sig
|
||||
from Crypto.PublicKey import ECC
|
||||
import dag_cbor.decoding, dag_cbor.encoding
|
||||
from granary import as2, bluesky
|
||||
|
@ -22,14 +25,11 @@ from multiformats import CID
|
|||
from oauth_dropins.webutil import util
|
||||
from oauth_dropins.webutil.testutil import NOW
|
||||
|
||||
import atproto
|
||||
import atproto_util
|
||||
from atproto_util import dag_cbor_cid, verify_commit_sig
|
||||
from flask_app import app
|
||||
from models import Follower, Object, User
|
||||
from . import testutil
|
||||
|
||||
# # atproto_mst.Data entry for MST with POST_AS, REPLY_AS, and REPOST_AS
|
||||
# # arroba.mst.Data entry for MST with POST_AS, REPLY_AS, and REPOST_AS
|
||||
# POST_CID = 'bafyreic5xwex7jxqvliumozkoli3qy2hzxrmui6odl7ujrcybqaypacfiy'
|
||||
# REPLY_CID = 'bafyreib55ro37wasiflouvlfenhzllorcthm7flr2nj4fnk7yjo54cagvm'
|
||||
# REPOST_CID = 'bafyreiek3jnp6e4sussy4c7pwtbkkf3kepekzycylowwuepmnvq7aeng44'
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
"""Unit tests for atproto_diff.py.
|
||||
|
||||
Heavily based on:
|
||||
https://github.com/bluesky/atproto/blob/main/packages/repo/tests/sync/diff.test.ts
|
||||
|
||||
Huge thanks to the Bluesky team for working in the public, in open source, and to
|
||||
Daniel Holmgren and Devin Ivy for this code specifically!
|
||||
"""
|
||||
import dag_cbor.random
|
||||
|
||||
from atproto_diff import Change, Diff
|
||||
from atproto_mst import MST
|
||||
from . import testutil
|
||||
|
||||
|
||||
class AtProtoDiffTest(testutil.TestCase):
|
||||
|
||||
def test_diffs(self):
|
||||
mst = MST()
|
||||
data = self.random_keys_and_cids(3)#1000)
|
||||
for key, cid in data:
|
||||
mst = mst.add(key, cid)
|
||||
|
||||
before = after = mst
|
||||
|
||||
to_add = self.random_keys_and_cids(1)#100)
|
||||
to_edit = data[1:2]
|
||||
to_del = data[0:1]
|
||||
|
||||
# these are all {str key: Change}
|
||||
expected_adds = {}
|
||||
expected_updates = {}
|
||||
expected_deletes = {}
|
||||
|
||||
for key, cid in to_add:
|
||||
after = after.add(key, cid)
|
||||
expected_adds[key] = Change(key=key, cid=cid)
|
||||
|
||||
for (key, prev), new in zip(to_edit, dag_cbor.random.rand_cid()):
|
||||
after = after.update(key, new)
|
||||
expected_updates[key] = Change(key=key, prev=prev, cid=new)
|
||||
|
||||
for key, cid in to_del:
|
||||
after = after.delete(key)
|
||||
expected_deletes[key] = Change(key=key, cid=cid)
|
||||
|
||||
diff = Diff.of(after, before)
|
||||
|
||||
self.assertEqual(1, len(diff.adds))
|
||||
self.assertEqual(1, len(diff.updates))
|
||||
self.assertEqual(1, len(diff.deletes))
|
||||
|
||||
self.assertEqual(expected_adds, diff.adds)
|
||||
self.assertEqual(expected_updates, diff.updates)
|
||||
self.assertEqual(expected_deletes, diff.deletes)
|
||||
|
||||
# ensure we correctly report all added CIDs
|
||||
for entry in after.walk():
|
||||
cid = entry.get_pointer() if isinstance(entry, MST) else entry.value
|
||||
# TODO
|
||||
# assert cid in blockstore or cid in diff.new_cids
|
||||
|
|
@ -1,286 +0,0 @@
|
|||
"""Unit tests for atproto_mst.py.
|
||||
|
||||
Heavily based on:
|
||||
https://github.com/bluesky/atproto/blob/main/packages/repo/tests/mst.test.ts
|
||||
|
||||
Huge thanks to the Bluesky team for working in the public, in open source, and to
|
||||
Daniel Holmgren and Devin Ivy for this code specifically!
|
||||
"""
|
||||
from base64 import b32encode
|
||||
from datetime import datetime
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
from unittest import skip
|
||||
|
||||
import dag_cbor.random
|
||||
from multiformats import CID
|
||||
|
||||
from atproto_mst import common_prefix_len, ensure_valid_key, MST
|
||||
import atproto_util
|
||||
from . import testutil
|
||||
|
||||
CID1 = CID.decode('bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454')
|
||||
|
||||
|
||||
class MstTest(testutil.TestCase):
|
||||
|
||||
def test_add(self):
|
||||
mst = MST()
|
||||
data = self.random_keys_and_cids(1000)
|
||||
for key, cid in data:
|
||||
mst = mst.add(key, cid)
|
||||
|
||||
for key, cid in data:
|
||||
got = mst.get(key)
|
||||
self.assertEqual(cid, got)
|
||||
|
||||
self.assertEqual(1000, mst.leaf_count())
|
||||
|
||||
def test_edits_records(self):
|
||||
mst = MST()
|
||||
data = self.random_keys_and_cids(100)
|
||||
|
||||
for key, cid in data:
|
||||
mst = mst.add(key, cid)
|
||||
|
||||
edited = []
|
||||
for (key, _), cid in zip(data, dag_cbor.random.rand_cid()):
|
||||
mst = mst.update(key, cid)
|
||||
edited.append([key, cid])
|
||||
|
||||
for key, cid in edited:
|
||||
self.assertEqual(cid, mst.get(key))
|
||||
|
||||
self.assertEqual(100, mst.leaf_count())
|
||||
|
||||
def test_deletes_records(self):
|
||||
mst = MST()
|
||||
data = self.random_keys_and_cids(1000)
|
||||
for key, cid in data:
|
||||
mst = mst.add(key, cid)
|
||||
|
||||
to_delete = data[:100]
|
||||
the_rest = data[100:]
|
||||
for key, _ in to_delete:
|
||||
mst = mst.delete(key)
|
||||
|
||||
self.assertEqual(900, mst.leaf_count())
|
||||
|
||||
for key, _ in to_delete:
|
||||
self.assertIsNone(mst.get(key))
|
||||
|
||||
for key, cid in the_rest:
|
||||
self.assertEqual(cid, mst.get(key))
|
||||
|
||||
def test_is_order_independent(self):
|
||||
mst = MST()
|
||||
data = self.random_keys_and_cids(1000)
|
||||
for key, cid in data:
|
||||
mst = mst.add(key, cid)
|
||||
|
||||
all_nodes = mst.all_nodes()
|
||||
|
||||
recreated = MST()
|
||||
random.shuffle(data)
|
||||
for key, cid in data:
|
||||
recreated = recreated.add(key, cid)
|
||||
|
||||
self.assertEqual(all_nodes, recreated.all_nodes())
|
||||
|
||||
def test_common_prefix_length(self):
|
||||
self.assertEqual(3, common_prefix_len('abc', 'abc'))
|
||||
self.assertEqual(0, common_prefix_len('', 'abc'))
|
||||
self.assertEqual(0, common_prefix_len('abc', ''))
|
||||
self.assertEqual(2, common_prefix_len('ab', 'abc'))
|
||||
self.assertEqual(2, common_prefix_len('abc', 'ab'))
|
||||
self.assertEqual(3, common_prefix_len('abcde', 'abc'))
|
||||
self.assertEqual(3, common_prefix_len('abc', 'abcde'))
|
||||
self.assertEqual(3, common_prefix_len('abcde', 'abc1'))
|
||||
self.assertEqual(2, common_prefix_len('abcde', 'abb'))
|
||||
self.assertEqual(0, common_prefix_len('abcde', 'qbb'))
|
||||
self.assertEqual(0, common_prefix_len('', 'asdf'))
|
||||
self.assertEqual(3, common_prefix_len('abc', 'abc\x00'))
|
||||
self.assertEqual(3, common_prefix_len('abc\x00', 'abc'))
|
||||
|
||||
def test_rejects_the_empty_key(self):
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add('')
|
||||
|
||||
def test_rejects_a_key_with_no_collection(self):
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add('asdf')
|
||||
|
||||
def test_rejects_a_key_with_a_nested_collection(self):
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add('nested/collection/asdf')
|
||||
|
||||
def test_rejects_on_empty_coll_or_rkey(self):
|
||||
for key in 'coll/', '/rkey':
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add(key)
|
||||
|
||||
def test_rejects_non_ascii_chars(self):
|
||||
for key in 'coll/jalapeñoA', 'coll/coöperative', 'coll/abc💩':
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add(key)
|
||||
|
||||
def test_rejects_ascii_that_we_dont_support(self):
|
||||
for key in ('coll/key$', 'coll/key%', 'coll/key(', 'coll/key)',
|
||||
'coll/key+', 'coll/key='):
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add(key)
|
||||
|
||||
def test_rejects_keys_over_256_chars(self):
|
||||
with self.assertRaises(ValueError):
|
||||
MST().add(
|
||||
'coll/asdofiupoiwqeurfpaosidfuapsodirupasoirupasoeiruaspeoriuaspeoriu2p3o4iu1pqw3oiuaspdfoiuaspdfoiuasdfpoiasdufpwoieruapsdofiuaspdfoiuasdpfoiausdfpoasidfupasodifuaspdofiuasdpfoiasudfpoasidfuapsodfiuasdpfoiausdfpoasidufpasodifuapsdofiuasdpofiuasdfpoaisdufpao',
|
||||
)
|
||||
|
||||
def test_computes_empty_tree_root_CID(self):
|
||||
self.assertEqual(0, MST().leaf_count())
|
||||
self.assertEqual(
|
||||
'bafyreie5737gdxlw5i64vzichcalba3z2v5n6icifvx5xytvske7mr3hpm',
|
||||
MST().get_pointer().encode('base32'))
|
||||
|
||||
def test_computes_trivial_tree_root_CID(self):
|
||||
mst = MST().add('com.example.record/3jqfcqzm3fo2j', CID1)
|
||||
self.assertEqual(1, mst.leaf_count())
|
||||
self.assertEqual(
|
||||
'bafyreibj4lsc3aqnrvphp5xmrnfoorvru4wynt6lwidqbm2623a6tatzdu',
|
||||
mst.get_pointer().encode('base32'))
|
||||
|
||||
def test_computes_single_layer_2_tree_root_CID(self):
|
||||
mst = MST().add('com.example.record/3jqfcqzm3fx2j', CID1)
|
||||
self.assertEqual(1, mst.leaf_count())
|
||||
self.assertEqual(2, mst.layer)
|
||||
self.assertEqual(
|
||||
'bafyreih7wfei65pxzhauoibu3ls7jgmkju4bspy4t2ha2qdjnzqvoy33ai',
|
||||
mst.get_pointer().encode('base32'))
|
||||
|
||||
def test_computes_simple_tree_root_CID(self):
|
||||
mst = MST()
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fr2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # level 1
|
||||
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm4fc2j', CID1) # level 0
|
||||
self.assertEqual(5, mst.leaf_count())
|
||||
self.assertEqual(
|
||||
'bafyreicmahysq4n6wfuxo522m6dpiy7z7qzym3dzs756t5n7nfdgccwq7m',
|
||||
mst.get_pointer().encode('base32'))
|
||||
|
||||
def test_trims_top_of_tree_on_delete(self):
|
||||
l1root = 'bafyreifnqrwbk6ffmyaz5qtujqrzf5qmxf7cbxvgzktl4e3gabuxbtatv4'
|
||||
l0root = 'bafyreie4kjuxbwkhzg2i5dljaswcroeih4dgiqq6pazcmunwt2byd725vi'
|
||||
|
||||
mst = MST()
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fn2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fo2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # level 1
|
||||
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fu2j', CID1) # level 0
|
||||
|
||||
self.assertEqual(6, mst.leaf_count())
|
||||
self.assertEqual(1, mst.layer)
|
||||
self.assertEqual(l1root, mst.get_pointer().encode('base32'))
|
||||
|
||||
mst = mst.delete('com.example.record/3jqfcqzm3fs2j') # level 1
|
||||
self.assertEqual(5, mst.leaf_count())
|
||||
self.assertEqual(0, mst.layer)
|
||||
self.assertEqual(l0root, mst.get_pointer().encode('base32'))
|
||||
|
||||
def test_handles_insertion_that_splits_two_layers_down(self):
|
||||
"""
|
||||
* *
|
||||
_________|________ ____|_____
|
||||
| | | | | | | |
|
||||
* d * i * -> * f *
|
||||
__|__ __|__ __|__ __|__ __|___
|
||||
| | | | | | | | | | | | | | |
|
||||
a b c e g h j k l * d * * i *
|
||||
__|__ | _|_ __|__
|
||||
| | | | | | | | |
|
||||
a b c e g h j k l
|
||||
"""
|
||||
l1root = 'bafyreiettyludka6fpgp33stwxfuwhkzlur6chs4d2v4nkmq2j3ogpdjem'
|
||||
l2root = 'bafyreid2x5eqs4w4qxvc5jiwda4cien3gw2q6cshofxwnvv7iucrmfohpm'
|
||||
|
||||
mst = MST()
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fo2j', CID1) # A; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # B; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fr2j', CID1) # C; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # D; level 1
|
||||
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # E; level 0
|
||||
# GAP for F
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fz2j', CID1) # G; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm4fc2j', CID1) # H; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm4fd2j', CID1) # I; level 1
|
||||
mst = mst.add('com.example.record/3jqfcqzm4ff2j', CID1) # J; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm4fg2j', CID1) # K; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm4fh2j', CID1) # L; level 0
|
||||
|
||||
self.assertEqual(11, mst.leaf_count())
|
||||
self.assertEqual(1, mst.layer)
|
||||
self.assertEqual(l1root, mst.get_pointer().encode('base32'))
|
||||
|
||||
# insert F, which will push E out in the node with G+H to a new node under D
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # F; level 2
|
||||
self.assertEqual(12, mst.leaf_count())
|
||||
self.assertEqual(2, mst.layer)
|
||||
self.assertEqual(l2root, mst.get_pointer().encode('base32'))
|
||||
|
||||
# remove F, which should push E back over with G+H
|
||||
mst = mst.delete('com.example.record/3jqfcqzm3fx2j') # F; level 2
|
||||
self.assertEqual(11, mst.leaf_count())
|
||||
self.assertEqual(1, mst.layer)
|
||||
self.assertEqual(l1root, mst.get_pointer().encode('base32'))
|
||||
|
||||
def test_handles_new_layers_that_are_two_higher_than_existing(self):
|
||||
"""
|
||||
* -> *
|
||||
__|__ __|__
|
||||
| | | | |
|
||||
a c * b *
|
||||
| |
|
||||
* *
|
||||
| |
|
||||
a c
|
||||
"""
|
||||
|
||||
l0root = 'bafyreidfcktqnfmykz2ps3dbul35pepleq7kvv526g47xahuz3rqtptmky'
|
||||
l2root = 'bafyreiavxaxdz7o7rbvr3zg2liox2yww46t7g6hkehx4i4h3lwudly7dhy'
|
||||
l2root2 = 'bafyreig4jv3vuajbsybhyvb7gggvpwh2zszwfyttjrj6qwvcsp24h6popu'
|
||||
|
||||
mst = MST()
|
||||
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # A; level 0
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fz2j', CID1) # C; level 0
|
||||
self.assertEqual(2, mst.leaf_count())
|
||||
self.assertEqual(0, mst.layer)
|
||||
self.assertEqual(l0root, mst.get_pointer().encode('base32'))
|
||||
|
||||
# insert B, which is two levels above
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # B; level 2
|
||||
self.assertEqual(3, mst.leaf_count())
|
||||
self.assertEqual(2, mst.layer)
|
||||
self.assertEqual(l2root, mst.get_pointer().encode('base32'))
|
||||
|
||||
# remove B
|
||||
mst = mst.delete('com.example.record/3jqfcqzm3fx2j') # B; level 2
|
||||
self.assertEqual(2, mst.leaf_count())
|
||||
self.assertEqual(0, mst.layer)
|
||||
self.assertEqual(l0root, mst.get_pointer().encode('base32'))
|
||||
|
||||
# insert B (level=2) and D (level=1)
|
||||
mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # B; level 2
|
||||
mst = mst.add('com.example.record/3jqfcqzm4fd2j', CID1) # D; level 1
|
||||
self.assertEqual(4, mst.leaf_count())
|
||||
self.assertEqual(2, mst.layer)
|
||||
self.assertEqual(l2root2, mst.get_pointer().encode('base32'))
|
||||
|
||||
# remove D
|
||||
mst = mst.delete('com.example.record/3jqfcqzm4fd2j') # D; level 1
|
||||
self.assertEqual(3, mst.leaf_count())
|
||||
self.assertEqual(2, mst.layer)
|
||||
self.assertEqual(l2root, mst.get_pointer().encode('base32'))
|
|
@ -1,46 +0,0 @@
|
|||
"""Unit tests for atproto_util.py."""
|
||||
|
||||
from Crypto.PublicKey import ECC
|
||||
from multiformats import CID
|
||||
from oauth_dropins.webutil.testutil import NOW
|
||||
|
||||
import atproto_util
|
||||
from atproto_util import (
|
||||
dag_cbor_cid,
|
||||
datetime_to_tid,
|
||||
sign_commit,
|
||||
tid_to_datetime,
|
||||
verify_commit_sig,
|
||||
)
|
||||
from . import testutil
|
||||
|
||||
|
||||
class AtProtoUtilTest(testutil.TestCase):
|
||||
|
||||
def test_dag_cbor_cid(self):
|
||||
self.assertEqual(
|
||||
CID.decode('bafyreiblaotetvwobe7cu2uqvnddr6ew2q3cu75qsoweulzku2egca4dxq'),
|
||||
dag_cbor_cid({'foo': 'bar'}))
|
||||
|
||||
def test_datetime_to_tid(self):
|
||||
self.assertEqual('3iom4o4g6u2l2', datetime_to_tid(NOW))
|
||||
|
||||
def test_tid_to_datetime(self):
|
||||
self.assertEqual(NOW, tid_to_datetime('3iom4o4g6u2l2'))
|
||||
|
||||
def test_sign_commit_and_verify(self):
|
||||
user = self.make_user('user.com')
|
||||
|
||||
commit = {'foo': 'bar'}
|
||||
key = ECC.import_key(user.p256_key)
|
||||
sign_commit(commit, key)
|
||||
assert verify_commit_sig(commit, key)
|
||||
|
||||
def test_verify_commit_error(self):
|
||||
key = ECC.import_key(self.make_user('user.com').p256_key)
|
||||
with self.assertRaises(KeyError):
|
||||
self.assertFalse(verify_commit_sig({'foo': 'bar'}, key))
|
||||
|
||||
def test_verify_commit_fail(self):
|
||||
key = ECC.import_key(self.make_user('user.com').p256_key)
|
||||
self.assertFalse(verify_commit_sig({'foo': 'bar', 'sig': 'nope'}, key))
|
|
@ -2,6 +2,7 @@
|
|||
"""Unit tests for models.py."""
|
||||
from unittest import mock
|
||||
|
||||
from arroba.mst import dag_cbor_cid
|
||||
from Crypto.PublicKey import ECC
|
||||
from flask import g, get_flashed_messages
|
||||
from granary import as2
|
||||
|
@ -10,7 +11,6 @@ from multiformats import CID
|
|||
from oauth_dropins.webutil.testutil import NOW, requests_response
|
||||
|
||||
from app import app
|
||||
from atproto_mst import dag_cbor_cid
|
||||
import common
|
||||
from models import AtpNode, Follower, Object, OBJECT_EXPIRE_AGE, User
|
||||
import protocol
|
||||
|
|
|
@ -6,8 +6,8 @@ import random
|
|||
import unittest
|
||||
from unittest.mock import ANY, call
|
||||
|
||||
import atproto_util
|
||||
from atproto_util import datetime_to_tid
|
||||
import arroba.util
|
||||
from arroba.util import datetime_to_tid
|
||||
import dag_cbor.random
|
||||
from flask import g
|
||||
from google.cloud import ndb
|
||||
|
@ -90,7 +90,7 @@ class TestCase(unittest.TestCase, testutil.Asserts):
|
|||
FakeProtocol.fetched = []
|
||||
|
||||
# make random test data deterministic
|
||||
atproto_util._clockid = 17
|
||||
arroba.util._clockid = 17
|
||||
random.seed(1234567890)
|
||||
dag_cbor.random.set_options(seed=1234567890)
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue