Bluesky: move most ATProto code to separate arroba library

https://github.com/snarfed/arroba
pull/505/head
Ryan Barrett 2023-05-06 14:37:23 -07:00
rodzic cc2ed9dd81
commit 0ab3698db7
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
14 zmienionych plików z 13 dodań i 1996 usunięć

Wyświetl plik

@ -5,6 +5,8 @@ import logging
import random
import re
from arroba.mst import MST, serialize_node_data
from arroba.util import dag_cbor_cid, sign_commit
from Crypto.PublicKey import ECC
import dag_cbor.encoding
from flask import g
@ -13,11 +15,6 @@ from granary import bluesky
from multiformats import CID, multibase, multicodec, multihash
from oauth_dropins.webutil import util
from atproto_mst import (
MST,
serialize_node_data,
)
from atproto_util import dag_cbor_cid, sign_commit
from flask_app import xrpc_server
from models import Follower, Object, PAGE_SIZE, User

Wyświetl plik

@ -1,283 +0,0 @@
"""AT Protocol utility for diffing two MSTs.
Heavily based on:
https://github.com/bluesky/atproto/blob/main/packages/repo/src/mst/diff.ts
Huge thanks to the Bluesky team for working in the public, in open source, and to
Daniel Holmgren and Devin Ivy for this code specifically!
"""
from collections import namedtuple
import logging
from atproto_mst import Leaf, MST, Walker
logger = logging.getLogger(__name__)
def mst_diff(cur, prev=None):
"""Generates a diff between two MSTs.
Args:
cur: :class:`MST`
prev: :class:`MST`, optional
Returns:
:class:`Diff`
"""
cur.get_pointer()
if not prev:
return null_diff(cur)
prev.get_pointer()
diff = Diff()
left_walker = Walker(prev)
right_walker = Walker(cur)
while not left_walker.status.done or not right_walker.status.done:
# print(left_walker.status, right_walker.status)
# if one walker is finished, continue walking the other & logging all nodes
if left_walker.status.done and not right_walker.status.done:
node = right_walker.status.cur
if isinstance(node, Leaf):
diff.record_add(node.key, node.value)
else:
diff.record_new_cid(node.pointer)
right_walker.advance()
continue
elif not left_walker.status.done and right_walker.status.done:
node = left_walker.status.cur
if isinstance(node, Leaf):
diff.record_delete(node.key, node.value)
else:
diff.record_removed_cid(node.pointer)
left_walker.advance()
continue
if left_walker.status.done or right_walker.status.done:
break
left = left_walker.status.cur
right = right_walker.status.cur
if not left or not right:
break
# if both pointers are leaves, record an update & advance both or record
# the lowest key and advance that pointer
if isinstance(left, Leaf) and isinstance(right, Leaf):
if left.key == right.key:
if left.value != right.value:
diff.record_update(left.key, left.value, right.value)
left_walker.advance()
right_walker.advance()
elif left.key < right.key:
diff.record_delete(left.key, left.value)
left_walker.advance()
else:
diff.record_add(right.key, right.value)
right_walker.advance()
continue
# next, ensure that we're on the same layer
#
# if one walker is at a higher layer than the other, we need to do one
# of two things if the higher walker is pointed at a tree, step into
# that tree to try to catch up with the lower if the higher walker is
# pointed at a leaf, then advance the lower walker to try to catch up
# the higher
if left_walker.layer() > right_walker.layer():
if isinstance(left, Leaf):
if isinstance(right, Leaf):
diff.record_add(right.key, right.value)
else:
diff.record_new_cid(right.pointer)
right_walker.advance()
else:
diff.record_removed_cid(left.pointer)
left_walker.step_into()
continue
elif left_walker.layer() < right_walker.layer():
if isinstance(right, Leaf):
if isinstance(left, Leaf):
diff.record_delete(left.key, left.value)
else:
diff.record_removed_cid(left.pointer)
left_walker.advance()
else:
diff.record_new_cid(right.pointer)
right_walker.step_into()
continue
# if we're on the same level, and both pointers are trees, do a
# comparison. if they're the same, step over. if they're different, step
# in to find the subdiff
if isinstance(left, MST) and isinstance(right, MST):
if left.pointer == right.pointer:
left_walker.step_over()
right_walker.step_over()
else:
diff.record_new_cid(right.pointer)
diff.record_removed_cid(left.pointer)
left_walker.step_into()
right_walker.step_into()
continue
# finally, if one pointer is a tree and the other is a leaf, simply step
# into the tree
if isinstance(left, Leaf) and isinstance(right, MST):
diff.record_new_cid(right.pointer)
right_walker.step_into()
continue
elif isinstance(left, MST) and isinstance(right, Leaf):
diff.record_removed_cid(left.pointer)
left_walker.step_into()
continue
raise RuntimeError('Unidentifiable case in diff walk')
return diff
def null_diff(tree):
"""Generates a "null" diff for a single MST with all adds and new CIDs.
Args:
tree: :class:`MST`
Returns:
:class:`Diff`
"""
diff = Diff()
for entry in tree.walk():
if isinstance(entry, Leaf):
diff.record_add(entry.key, entry.value)
else:
diff.record_new_cid(entry.pointer)
return diff
Change = namedtuple('Change', [
'key', # str
'cid', # :class:`CID`
'prev', # :class:`CID`
], defaults=[None])
class Diff:
"""A diff between two MSTs.
Attributes:
adds: {str key: :class:`Change`}
updates: {str key: :class:`Change`}
deletes: {str key: :class:`Change`}
new_cids: set of :class:`CID`
removed_cids: set of :class:`CID`
"""
def __init__(self):
self.adds = {}
self.updates = {}
self.deletes = {}
self.new_cids = set()
self.removed_cids = set()
@staticmethod
def of(cur, prev=None):
"""
Args:
cur: :class:`MST`
prev: :class:`MST`, optional
Returns:
:class:`Diff`
"""
return mst_diff(cur, prev)
def record_add(self, key, cid):
"""
Args:
key: str
cid: :class:`CID`
"""
self.adds[key] = Change(key=key, cid=cid)
self.new_cids.add(cid)
def record_update(self, key, prev, cid):
"""
Args:
key: str
prev: :class:`CID`
cid: :class:`CID`
"""
self.updates[key] = Change(key=key, cid=cid, prev=prev)
self.new_cids.add(cid)
def record_delete(self, key, cid):
"""
Args:
key: str
cid: :class:`CID`
"""
self.deletes[key] = Change(key=key, cid=cid)
def record_new_cid(self, cid):
"""
Args:
cid: :class:`CID`
"""
if cid in self.removed_cids:
self.removed_cids.remove(cid)
else:
self.new_cids.add(cid)
def record_removed_cid(self, cid):
"""
Args:
cid: :class:`CID`
"""
if cid in self.new_cids:
self.new_cids.remove(cid)
else:
self.removed_cids.add(cid)
def add_diff(self, diff):
"""
Args:
diff: :class:`Diff`
"""
for add in diff.adds.values():
if self.deletes[add.key]:
deleted = self.deletes[add.key]
if deleted.cid != add.cid:
self.record_update(add.key, deleted.cid, add.cid)
del self.deletes[add.key]
else:
self.record_add(add.key, add.cid)
for update in diff.updates.values():
self.record_update(update.key, update.prev, update.cid)
del self.adds[update.key]
del self.deletes[update.key]
for deleted in diff.deletes.values():
if self.adds[deleted.key]:
del self.adds[deleted.key]
else:
del self.updates[deleted.key]
self.record_delete(deleted.key, deleted.cid)
self.new_cids |= diff.new_cids
def updated_keys(self):
return self.adds | self.updates | self.deletes

Plik diff jest za duży Load Diff

Wyświetl plik

@ -1,182 +0,0 @@
"""Misc AT Protocol utils. TIDs, CIDs, etc."""
import copy
from datetime import datetime, timezone
import logging
from numbers import Integral
import random
from Crypto.Hash import SHA256
from Crypto.Signature import DSS
import dag_cbor.encoding
from multiformats import CID, multicodec, multihash
from oauth_dropins.webutil.appengine_info import DEBUG
logger = logging.getLogger(__name__)
# the bottom 32 clock ids can be randomized & are not guaranteed to be collision
# resistant. we use the same clockid for all TIDs coming from this runtime.
_clockid = random.randint(0, 31)
# _tid_last = time.time_ns() // 1000 # microseconds
S32_CHARS = '234567abcdefghijklmnopqrstuvwxyz'
def dag_cbor_cid(obj):
"""Returns the DAG-CBOR CID for a given object.
Args:
obj: CBOR-compatible native object or value
Returns:
:class:`CID`
"""
encoded = dag_cbor.encoding.encode(obj)
digest = multihash.digest(encoded, 'sha2-256')
return CID('base58btc', 1, multicodec.get('dag-cbor'), digest)
def s32encode(num):
"""Base32 encode with encoding variant sort.
Based on https://github.com/bluesky-social/atproto/blob/main/packages/common-web/src/tid.ts
Args:
num: int or Integral
Returns:
str
"""
assert isinstance(num, Integral)
encoded = []
while num > 0:
c = num % 32
num = num // 32
encoded.insert(0, S32_CHARS[c])
return ''.join(encoded)
def s32decode(val):
"""Base32 decode with encoding variant sort.
Based on https://github.com/bluesky-social/atproto/blob/main/packages/common-web/src/tid.ts
Args:
val: str
Returns:
int or Integral
"""
i = 0
for c in val:
i = i * 32 + S32_CHARS.index(c)
return i
def datetime_to_tid(dt):
"""Converts a datetime to an ATProto TID.
https://atproto.com/guides/data-repos#identifier-types
Args:
dt: :class:`datetime.datetime`
Returns:
str, base32-encoded TID
"""
tid = (s32encode(int(dt.timestamp() * 1000 * 1000)) +
s32encode(_clockid).ljust(2, '2'))
assert len(tid) == 13
return tid
def tid_to_datetime(tid):
"""Converts an ATProto TID to a datetime.
https://atproto.com/guides/data-repos#identifier-types
Args:
tid: bytes, base32-encoded TID
Returns:
:class:`datetime.datetime`
Raises:
ValueError if tid is not bytes or not 13 characters long
"""
if not isinstance(tid, (str, bytes)) or len(tid) != 13:
raise ValueError(f'Expected 13-character str or bytes; got {tid}')
encoded = tid.replace('-', '')[:-2] # strip clock id
return datetime.fromtimestamp(s32decode(encoded) / 1000 / 1000, timezone.utc)
# TODO
# def next_tid():
# global _tid_last
# # enforce that we're at least 1us after the last TID to prevent TIDs moving
# # backwards if system clock drifts backwards
# now = time.time_ns() // 1000
# if now > _tid_last:
# _tid_last = now
# else:
# _tid_last += 1
# now = _tid_last
def sign_commit(commit, key):
"""Signs a repo commit.
Adds the signature in the `sig` field.
Signing isn't yet in the atproto.com docs, this setup is taken from the TS
code and conversations with @why on #bluesky-dev:matrix.org.
* https://matrix.to/#/!vpdMrhHjzaPbBUSgOs:matrix.org/$Xaf4ugYks-iYg7Pguh3dN8hlsvVMUOuCQo3fMiYPXTY?via=matrix.org&via=minds.com&via=envs.net
* https://github.com/bluesky-social/atproto/blob/384e739a3b7d34f7a95d6ba6f08e7223a7398995/packages/repo/src/util.ts#L238-L248
* https://github.com/bluesky-social/atproto/blob/384e739a3b7d34f7a95d6ba6f08e7223a7398995/packages/crypto/src/p256/keypair.ts#L66-L73
* https://github.com/bluesky-social/indigo/blob/f1f2480888ab5d0ac1e03bd9b7de090a3d26cd13/repo/repo.go#L64-L70
* https://github.com/whyrusleeping/go-did/blob/2146016fc220aa1e08ccf26aaa762f5a11a81404/key.go#L67-L91
The signature is ECDSA around SHA-256 of the input. We currently use P-256
keypairs. Context:
* Go supports P-256, ED25519, SECP256K1 keys
* TS supports P-256, SECP256K1 keys
* this recommends ED25519, then P-256:
https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
Args:
commit: dict repo commit
key: :class:`Crypto.PublicKey.ECC.EccKey`
"""
signer = DSS.new(key, 'fips-186-3', randfunc=random.randbytes if DEBUG else None)
commit['sig'] = signer.sign(SHA256.new(dag_cbor.encoding.encode(commit)))
def verify_commit_sig(commit, key):
"""Returns true if the commit's signature is valid, False otherwise.
See :func:`sign_commit` for more background.
Args:
commit: dict repo commit
key: :class:`Crypto.PublicKey.ECC.EccKey`
Raises:
KeyError if the commit isn't signed, ie doesn't have a `sig` field
"""
commit = copy.copy(commit)
sig = commit.pop('sig')
verifier = DSS.new(key.public_key(), 'fips-186-3',
randfunc=random.randbytes if DEBUG else None)
try:
verifier.verify(SHA256.new(dag_cbor.encoding.encode(commit)), sig)
return True
except ValueError:
logger.debug("Couldn't verify signature", exc_info=True)
return False

Wyświetl plik

@ -341,6 +341,7 @@ texinfo_documents = [
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'arroba': ('https://arroba.readthedocs.io/en/latest', None),
'dag_cbor': ('https://dag-cbor.readthedocs.io/en/latest', None),
'flask': ('https://flask.palletsprojects.com/en/latest', None),
'flask_caching': ('https://flask-caching.readthedocs.io/en/latest', None),

Wyświetl plik

@ -9,10 +9,6 @@ activitypub
-----------
.. automodule:: activitypub
atproto_mst
-----------
.. automodule:: atproto
common
------
.. automodule:: common

Wyświetl plik

@ -8,6 +8,7 @@ import logging
import random
import urllib.parse
from arroba.mst import dag_cbor_cid
from Crypto import Random
from Crypto.PublicKey import ECC, RSA
from Crypto.Util import number
@ -22,7 +23,6 @@ from oauth_dropins.webutil.util import json_dumps, json_loads
import requests
from werkzeug.exceptions import BadRequest, NotFound
from atproto_mst import dag_cbor_cid
import common
# https://github.com/snarfed/bridgy-fed/issues/314

Wyświetl plik

@ -1,3 +1,4 @@
git+https://github.com/snarfed/arroba.git#egg=arroba
git+https://github.com/dvska/gdata-python3.git#egg=gdata
git+https://github.com/snarfed/granary.git#egg=granary
git+https://github.com/snarfed/dag-json.git#egg=dag_json

Wyświetl plik

@ -5,6 +5,9 @@ import random
from unittest import skip
from unittest.mock import patch
import atproto
import arroba.util
from arroba.util import dag_cbor_cid, verify_commit_sig
from Crypto.PublicKey import ECC
import dag_cbor.decoding, dag_cbor.encoding
from granary import as2, bluesky
@ -22,14 +25,11 @@ from multiformats import CID
from oauth_dropins.webutil import util
from oauth_dropins.webutil.testutil import NOW
import atproto
import atproto_util
from atproto_util import dag_cbor_cid, verify_commit_sig
from flask_app import app
from models import Follower, Object, User
from . import testutil
# # atproto_mst.Data entry for MST with POST_AS, REPLY_AS, and REPOST_AS
# # arroba.mst.Data entry for MST with POST_AS, REPLY_AS, and REPOST_AS
# POST_CID = 'bafyreic5xwex7jxqvliumozkoli3qy2hzxrmui6odl7ujrcybqaypacfiy'
# REPLY_CID = 'bafyreib55ro37wasiflouvlfenhzllorcthm7flr2nj4fnk7yjo54cagvm'
# REPOST_CID = 'bafyreiek3jnp6e4sussy4c7pwtbkkf3kepekzycylowwuepmnvq7aeng44'

Wyświetl plik

@ -1,62 +0,0 @@
"""Unit tests for atproto_diff.py.
Heavily based on:
https://github.com/bluesky/atproto/blob/main/packages/repo/tests/sync/diff.test.ts
Huge thanks to the Bluesky team for working in the public, in open source, and to
Daniel Holmgren and Devin Ivy for this code specifically!
"""
import dag_cbor.random
from atproto_diff import Change, Diff
from atproto_mst import MST
from . import testutil
class AtProtoDiffTest(testutil.TestCase):
def test_diffs(self):
mst = MST()
data = self.random_keys_and_cids(3)#1000)
for key, cid in data:
mst = mst.add(key, cid)
before = after = mst
to_add = self.random_keys_and_cids(1)#100)
to_edit = data[1:2]
to_del = data[0:1]
# these are all {str key: Change}
expected_adds = {}
expected_updates = {}
expected_deletes = {}
for key, cid in to_add:
after = after.add(key, cid)
expected_adds[key] = Change(key=key, cid=cid)
for (key, prev), new in zip(to_edit, dag_cbor.random.rand_cid()):
after = after.update(key, new)
expected_updates[key] = Change(key=key, prev=prev, cid=new)
for key, cid in to_del:
after = after.delete(key)
expected_deletes[key] = Change(key=key, cid=cid)
diff = Diff.of(after, before)
self.assertEqual(1, len(diff.adds))
self.assertEqual(1, len(diff.updates))
self.assertEqual(1, len(diff.deletes))
self.assertEqual(expected_adds, diff.adds)
self.assertEqual(expected_updates, diff.updates)
self.assertEqual(expected_deletes, diff.deletes)
# ensure we correctly report all added CIDs
for entry in after.walk():
cid = entry.get_pointer() if isinstance(entry, MST) else entry.value
# TODO
# assert cid in blockstore or cid in diff.new_cids

Wyświetl plik

@ -1,286 +0,0 @@
"""Unit tests for atproto_mst.py.
Heavily based on:
https://github.com/bluesky/atproto/blob/main/packages/repo/tests/mst.test.ts
Huge thanks to the Bluesky team for working in the public, in open source, and to
Daniel Holmgren and Devin Ivy for this code specifically!
"""
from base64 import b32encode
from datetime import datetime
import time
import random
import string
from unittest import skip
import dag_cbor.random
from multiformats import CID
from atproto_mst import common_prefix_len, ensure_valid_key, MST
import atproto_util
from . import testutil
CID1 = CID.decode('bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454')
class MstTest(testutil.TestCase):
def test_add(self):
mst = MST()
data = self.random_keys_and_cids(1000)
for key, cid in data:
mst = mst.add(key, cid)
for key, cid in data:
got = mst.get(key)
self.assertEqual(cid, got)
self.assertEqual(1000, mst.leaf_count())
def test_edits_records(self):
mst = MST()
data = self.random_keys_and_cids(100)
for key, cid in data:
mst = mst.add(key, cid)
edited = []
for (key, _), cid in zip(data, dag_cbor.random.rand_cid()):
mst = mst.update(key, cid)
edited.append([key, cid])
for key, cid in edited:
self.assertEqual(cid, mst.get(key))
self.assertEqual(100, mst.leaf_count())
def test_deletes_records(self):
mst = MST()
data = self.random_keys_and_cids(1000)
for key, cid in data:
mst = mst.add(key, cid)
to_delete = data[:100]
the_rest = data[100:]
for key, _ in to_delete:
mst = mst.delete(key)
self.assertEqual(900, mst.leaf_count())
for key, _ in to_delete:
self.assertIsNone(mst.get(key))
for key, cid in the_rest:
self.assertEqual(cid, mst.get(key))
def test_is_order_independent(self):
mst = MST()
data = self.random_keys_and_cids(1000)
for key, cid in data:
mst = mst.add(key, cid)
all_nodes = mst.all_nodes()
recreated = MST()
random.shuffle(data)
for key, cid in data:
recreated = recreated.add(key, cid)
self.assertEqual(all_nodes, recreated.all_nodes())
def test_common_prefix_length(self):
self.assertEqual(3, common_prefix_len('abc', 'abc'))
self.assertEqual(0, common_prefix_len('', 'abc'))
self.assertEqual(0, common_prefix_len('abc', ''))
self.assertEqual(2, common_prefix_len('ab', 'abc'))
self.assertEqual(2, common_prefix_len('abc', 'ab'))
self.assertEqual(3, common_prefix_len('abcde', 'abc'))
self.assertEqual(3, common_prefix_len('abc', 'abcde'))
self.assertEqual(3, common_prefix_len('abcde', 'abc1'))
self.assertEqual(2, common_prefix_len('abcde', 'abb'))
self.assertEqual(0, common_prefix_len('abcde', 'qbb'))
self.assertEqual(0, common_prefix_len('', 'asdf'))
self.assertEqual(3, common_prefix_len('abc', 'abc\x00'))
self.assertEqual(3, common_prefix_len('abc\x00', 'abc'))
def test_rejects_the_empty_key(self):
with self.assertRaises(ValueError):
MST().add('')
def test_rejects_a_key_with_no_collection(self):
with self.assertRaises(ValueError):
MST().add('asdf')
def test_rejects_a_key_with_a_nested_collection(self):
with self.assertRaises(ValueError):
MST().add('nested/collection/asdf')
def test_rejects_on_empty_coll_or_rkey(self):
for key in 'coll/', '/rkey':
with self.assertRaises(ValueError):
MST().add(key)
def test_rejects_non_ascii_chars(self):
for key in 'coll/jalapeñoA', 'coll/coöperative', 'coll/abc💩':
with self.assertRaises(ValueError):
MST().add(key)
def test_rejects_ascii_that_we_dont_support(self):
for key in ('coll/key$', 'coll/key%', 'coll/key(', 'coll/key)',
'coll/key+', 'coll/key='):
with self.assertRaises(ValueError):
MST().add(key)
def test_rejects_keys_over_256_chars(self):
with self.assertRaises(ValueError):
MST().add(
'coll/asdofiupoiwqeurfpaosidfuapsodirupasoirupasoeiruaspeoriuaspeoriu2p3o4iu1pqw3oiuaspdfoiuaspdfoiuasdfpoiasdufpwoieruapsdofiuaspdfoiuasdpfoiausdfpoasidfupasodifuaspdofiuasdpfoiasudfpoasidfuapsodfiuasdpfoiausdfpoasidufpasodifuapsdofiuasdpofiuasdfpoaisdufpao',
)
def test_computes_empty_tree_root_CID(self):
self.assertEqual(0, MST().leaf_count())
self.assertEqual(
'bafyreie5737gdxlw5i64vzichcalba3z2v5n6icifvx5xytvske7mr3hpm',
MST().get_pointer().encode('base32'))
def test_computes_trivial_tree_root_CID(self):
mst = MST().add('com.example.record/3jqfcqzm3fo2j', CID1)
self.assertEqual(1, mst.leaf_count())
self.assertEqual(
'bafyreibj4lsc3aqnrvphp5xmrnfoorvru4wynt6lwidqbm2623a6tatzdu',
mst.get_pointer().encode('base32'))
def test_computes_single_layer_2_tree_root_CID(self):
mst = MST().add('com.example.record/3jqfcqzm3fx2j', CID1)
self.assertEqual(1, mst.leaf_count())
self.assertEqual(2, mst.layer)
self.assertEqual(
'bafyreih7wfei65pxzhauoibu3ls7jgmkju4bspy4t2ha2qdjnzqvoy33ai',
mst.get_pointer().encode('base32'))
def test_computes_simple_tree_root_CID(self):
mst = MST()
mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm3fr2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # level 1
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm4fc2j', CID1) # level 0
self.assertEqual(5, mst.leaf_count())
self.assertEqual(
'bafyreicmahysq4n6wfuxo522m6dpiy7z7qzym3dzs756t5n7nfdgccwq7m',
mst.get_pointer().encode('base32'))
def test_trims_top_of_tree_on_delete(self):
l1root = 'bafyreifnqrwbk6ffmyaz5qtujqrzf5qmxf7cbxvgzktl4e3gabuxbtatv4'
l0root = 'bafyreie4kjuxbwkhzg2i5dljaswcroeih4dgiqq6pazcmunwt2byd725vi'
mst = MST()
mst = mst.add('com.example.record/3jqfcqzm3fn2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm3fo2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # level 1
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # level 0
mst = mst.add('com.example.record/3jqfcqzm3fu2j', CID1) # level 0
self.assertEqual(6, mst.leaf_count())
self.assertEqual(1, mst.layer)
self.assertEqual(l1root, mst.get_pointer().encode('base32'))
mst = mst.delete('com.example.record/3jqfcqzm3fs2j') # level 1
self.assertEqual(5, mst.leaf_count())
self.assertEqual(0, mst.layer)
self.assertEqual(l0root, mst.get_pointer().encode('base32'))
def test_handles_insertion_that_splits_two_layers_down(self):
"""
* *
_________|________ ____|_____
| | | | | | | |
* d * i * -> * f *
__|__ __|__ __|__ __|__ __|___
| | | | | | | | | | | | | | |
a b c e g h j k l * d * * i *
__|__ | _|_ __|__
| | | | | | | | |
a b c e g h j k l
"""
l1root = 'bafyreiettyludka6fpgp33stwxfuwhkzlur6chs4d2v4nkmq2j3ogpdjem'
l2root = 'bafyreid2x5eqs4w4qxvc5jiwda4cien3gw2q6cshofxwnvv7iucrmfohpm'
mst = MST()
mst = mst.add('com.example.record/3jqfcqzm3fo2j', CID1) # A; level 0
mst = mst.add('com.example.record/3jqfcqzm3fp2j', CID1) # B; level 0
mst = mst.add('com.example.record/3jqfcqzm3fr2j', CID1) # C; level 0
mst = mst.add('com.example.record/3jqfcqzm3fs2j', CID1) # D; level 1
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # E; level 0
# GAP for F
mst = mst.add('com.example.record/3jqfcqzm3fz2j', CID1) # G; level 0
mst = mst.add('com.example.record/3jqfcqzm4fc2j', CID1) # H; level 0
mst = mst.add('com.example.record/3jqfcqzm4fd2j', CID1) # I; level 1
mst = mst.add('com.example.record/3jqfcqzm4ff2j', CID1) # J; level 0
mst = mst.add('com.example.record/3jqfcqzm4fg2j', CID1) # K; level 0
mst = mst.add('com.example.record/3jqfcqzm4fh2j', CID1) # L; level 0
self.assertEqual(11, mst.leaf_count())
self.assertEqual(1, mst.layer)
self.assertEqual(l1root, mst.get_pointer().encode('base32'))
# insert F, which will push E out in the node with G+H to a new node under D
mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # F; level 2
self.assertEqual(12, mst.leaf_count())
self.assertEqual(2, mst.layer)
self.assertEqual(l2root, mst.get_pointer().encode('base32'))
# remove F, which should push E back over with G+H
mst = mst.delete('com.example.record/3jqfcqzm3fx2j') # F; level 2
self.assertEqual(11, mst.leaf_count())
self.assertEqual(1, mst.layer)
self.assertEqual(l1root, mst.get_pointer().encode('base32'))
def test_handles_new_layers_that_are_two_higher_than_existing(self):
"""
* -> *
__|__ __|__
| | | | |
a c * b *
| |
* *
| |
a c
"""
l0root = 'bafyreidfcktqnfmykz2ps3dbul35pepleq7kvv526g47xahuz3rqtptmky'
l2root = 'bafyreiavxaxdz7o7rbvr3zg2liox2yww46t7g6hkehx4i4h3lwudly7dhy'
l2root2 = 'bafyreig4jv3vuajbsybhyvb7gggvpwh2zszwfyttjrj6qwvcsp24h6popu'
mst = MST()
mst = mst.add('com.example.record/3jqfcqzm3ft2j', CID1) # A; level 0
mst = mst.add('com.example.record/3jqfcqzm3fz2j', CID1) # C; level 0
self.assertEqual(2, mst.leaf_count())
self.assertEqual(0, mst.layer)
self.assertEqual(l0root, mst.get_pointer().encode('base32'))
# insert B, which is two levels above
mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # B; level 2
self.assertEqual(3, mst.leaf_count())
self.assertEqual(2, mst.layer)
self.assertEqual(l2root, mst.get_pointer().encode('base32'))
# remove B
mst = mst.delete('com.example.record/3jqfcqzm3fx2j') # B; level 2
self.assertEqual(2, mst.leaf_count())
self.assertEqual(0, mst.layer)
self.assertEqual(l0root, mst.get_pointer().encode('base32'))
# insert B (level=2) and D (level=1)
mst = mst.add('com.example.record/3jqfcqzm3fx2j', CID1) # B; level 2
mst = mst.add('com.example.record/3jqfcqzm4fd2j', CID1) # D; level 1
self.assertEqual(4, mst.leaf_count())
self.assertEqual(2, mst.layer)
self.assertEqual(l2root2, mst.get_pointer().encode('base32'))
# remove D
mst = mst.delete('com.example.record/3jqfcqzm4fd2j') # D; level 1
self.assertEqual(3, mst.leaf_count())
self.assertEqual(2, mst.layer)
self.assertEqual(l2root, mst.get_pointer().encode('base32'))

Wyświetl plik

@ -1,46 +0,0 @@
"""Unit tests for atproto_util.py."""
from Crypto.PublicKey import ECC
from multiformats import CID
from oauth_dropins.webutil.testutil import NOW
import atproto_util
from atproto_util import (
dag_cbor_cid,
datetime_to_tid,
sign_commit,
tid_to_datetime,
verify_commit_sig,
)
from . import testutil
class AtProtoUtilTest(testutil.TestCase):
def test_dag_cbor_cid(self):
self.assertEqual(
CID.decode('bafyreiblaotetvwobe7cu2uqvnddr6ew2q3cu75qsoweulzku2egca4dxq'),
dag_cbor_cid({'foo': 'bar'}))
def test_datetime_to_tid(self):
self.assertEqual('3iom4o4g6u2l2', datetime_to_tid(NOW))
def test_tid_to_datetime(self):
self.assertEqual(NOW, tid_to_datetime('3iom4o4g6u2l2'))
def test_sign_commit_and_verify(self):
user = self.make_user('user.com')
commit = {'foo': 'bar'}
key = ECC.import_key(user.p256_key)
sign_commit(commit, key)
assert verify_commit_sig(commit, key)
def test_verify_commit_error(self):
key = ECC.import_key(self.make_user('user.com').p256_key)
with self.assertRaises(KeyError):
self.assertFalse(verify_commit_sig({'foo': 'bar'}, key))
def test_verify_commit_fail(self):
key = ECC.import_key(self.make_user('user.com').p256_key)
self.assertFalse(verify_commit_sig({'foo': 'bar', 'sig': 'nope'}, key))

Wyświetl plik

@ -2,6 +2,7 @@
"""Unit tests for models.py."""
from unittest import mock
from arroba.mst import dag_cbor_cid
from Crypto.PublicKey import ECC
from flask import g, get_flashed_messages
from granary import as2
@ -10,7 +11,6 @@ from multiformats import CID
from oauth_dropins.webutil.testutil import NOW, requests_response
from app import app
from atproto_mst import dag_cbor_cid
import common
from models import AtpNode, Follower, Object, OBJECT_EXPIRE_AGE, User
import protocol

Wyświetl plik

@ -6,8 +6,8 @@ import random
import unittest
from unittest.mock import ANY, call
import atproto_util
from atproto_util import datetime_to_tid
import arroba.util
from arroba.util import datetime_to_tid
import dag_cbor.random
from flask import g
from google.cloud import ndb
@ -90,7 +90,7 @@ class TestCase(unittest.TestCase, testutil.Asserts):
FakeProtocol.fetched = []
# make random test data deterministic
atproto_util._clockid = 17
arroba.util._clockid = 17
random.seed(1234567890)
dag_cbor.random.set_options(seed=1234567890)