diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 933d2415..3d1769c7 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -62,8 +62,6 @@ class AtProtoTest(testutil.TestCase): def setUp(self): super().setUp() - atproto._clockid = 17 # need this to be deterministic - # used in now(), injected into Object.created so that TIDs are deterministic self.last_now = NOW.replace(tzinfo=None) diff --git a/tests/test_atproto_mst.py b/tests/test_atproto_mst.py index 89c9b802..79b37b1b 100644 --- a/tests/test_atproto_mst.py +++ b/tests/test_atproto_mst.py @@ -18,35 +18,16 @@ from multiformats import CID from atproto_mst import common_prefix_len, ensure_valid_key, MST import atproto_util -from atproto_util import datetime_to_tid from . import testutil CID1 = CID.decode('bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454') -def make_data(num): - def tid(): - ms = random.randint(datetime(2020, 1, 1).timestamp() * 1000, - datetime(2024, 1, 1).timestamp() * 1000) - return datetime_to_tid(datetime.fromtimestamp(float(ms) / 1000)) - - return [(f'com.example.record/{tid()}', cid) - for cid in dag_cbor.random.rand_cid(num)] - - class MstTest(testutil.TestCase): - def setUp(self): - super().setUp() - - # make random test data deterministic - atproto_util._clockid = 17 - random.seed(1234567890) - dag_cbor.random.set_options(seed=1234567890) - def test_add(self): mst = MST() - data = make_data(1000) + data = self.random_keys_and_cids(1000) for key, cid in data: mst = mst.add(key, cid) @@ -58,7 +39,7 @@ class MstTest(testutil.TestCase): def test_edits_records(self): mst = MST() - data = make_data(100) + data = self.random_keys_and_cids(100) for key, cid in data: mst = mst.add(key, cid) @@ -75,7 +56,7 @@ class MstTest(testutil.TestCase): def test_deletes_records(self): mst = MST() - data = make_data(1000) + data = self.random_keys_and_cids(1000) for key, cid in data: mst = mst.add(key, cid) @@ -94,7 +75,7 @@ class MstTest(testutil.TestCase): def test_is_order_independent(self): mst = MST() - data = make_data(1000) + data = self.random_keys_and_cids(1000) for key, cid in data: mst = mst.add(key, cid) @@ -107,50 +88,6 @@ class MstTest(testutil.TestCase): self.assertEqual(all_nodes, recreated.all_nodes()) - # def test_diffs(self): - # to_diff = MST() - - # to_add = Object.entries(make_data(100)) - # to_edit = self.shuffled[500:600] - # to_del = self.shuffled[400:500] - - # expected_updates = {} - # expected_dels = {} - # expected_adds = {entry[0]: {'key': entry[0], 'cid': entry[1]} - # for entry in to_add.items()} - - # for entry in to_add: - # to_diff.add(entry[0], entry[1]) - # expected_adds[entry[0]] = x - - # for entry, cid in zip(to_edit, dag_cbor.random.rand_cid()): - # updated = random_cid() - # to_diff.update(entry[0], updated) - # expected_updates[entry[0]] = { - # 'key': entry[0], - # 'prev': entry[1], - # 'cid': updated, - # } - - # for entry in to_del: - # to_diff.delete(entry[0]) - # expected_dels[entry[0]] = {'key': entry[0], 'cid': entry[1]} - - # diff = DataDiff.of(to_diff, self.mst) - - # self.assertEqual(100, len(diff.add_list())) - # self.assertEqual(100, len(diff.update_list())) - # self.assertEqual(100, len(diff.delete_list())) - - # self.assertEqual(expected_adds, diff.adds) - # self.assertEqual(expected_updates, diff.updates) - # self.assertEqual(expected_dels, diff.deletes) - - # # ensure we correctly report all added CIDs - # for entry in to_diff.walk(): - # cid = entry.get_pointer() if entry.is_tree() else entry.value - # self.assert_true(blockstore.has(cid) or diff.new_cids.has(cid)) - def test_common_prefix_length(self): self.assertEqual(3, common_prefix_len('abc', 'abc')) self.assertEqual(0, common_prefix_len('', 'abc')) diff --git a/tests/test_atproto_util.py b/tests/test_atproto_util.py index 9bc07c5e..a6ae5f52 100644 --- a/tests/test_atproto_util.py +++ b/tests/test_atproto_util.py @@ -17,10 +17,6 @@ from . import testutil class AtProtoUtilTest(testutil.TestCase): - def setUp(self): - super().setUp() - atproto_util._clockid = 17 - def test_dag_cbor_cid(self): self.assertEqual( CID.decode('bafyreiblaotetvwobe7cu2uqvnddr6ew2q3cu75qsoweulzku2egca4dxq'), diff --git a/tests/testutil.py b/tests/testutil.py index dbe0d41d..41d156ac 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -1,10 +1,13 @@ """Common test utility code.""" import copy -import datetime +from datetime import datetime +import logging import random import unittest from unittest.mock import ANY, call +import atproto_util +from atproto_util import datetime_to_tid import dag_cbor.random from flask import g from google.cloud import ndb @@ -14,16 +17,11 @@ from granary.tests.test_as1 import ( MENTION, NOTE, ) -import logging from oauth_dropins.webutil import testutil, util from oauth_dropins.webutil.appengine_config import ndb_client from oauth_dropins.webutil.testutil import requests_response import requests -# make random test data deterministic -random.seed(1234567890) -dag_cbor.random.set_options(seed=1234567890) - # load all Flask handlers import app from flask_app import app, cache @@ -91,6 +89,11 @@ class TestCase(unittest.TestCase, testutil.Asserts): FakeProtocol.sent = [] FakeProtocol.fetched = [] + # make random test data deterministic + atproto_util._clockid = 17 + random.seed(1234567890) + dag_cbor.random.set_options(seed=1234567890) + self.client = app.test_client() self.client.__enter__() @@ -137,6 +140,22 @@ class TestCase(unittest.TestCase, testutil.Asserts): Object(id='f', domains=['user.com'], labels=['feed', 'notification', 'user'], as2=as2.from_as1(NOTE), deleted=True).put() + @staticmethod + def random_keys_and_cids(num): + def tid(): + ms = random.randint(datetime(2020, 1, 1).timestamp() * 1000, + datetime(2024, 1, 1).timestamp() * 1000) + return datetime_to_tid(datetime.fromtimestamp(float(ms) / 1000)) + + return [(f'com.example.record/{tid()}', cid) + for cid in dag_cbor.random.rand_cid(num)] + + def random_tid(num): + ms = random.randint(datetime(2020, 1, 1).timestamp() * 1000, + datetime(2024, 1, 1).timestamp() * 1000) + tid = datetime_to_tid(datetime.fromtimestamp(float(ms) / 1000)) + return f'com.example.record/{tid}' + def req(self, url, **kwargs): """Returns a mock requests call.""" kwargs.setdefault('headers', {}).update({