From 6b597c90c3418aca8f39e4bf98d08ad549f12e72 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Sun, 21 Apr 2024 11:27:23 -0700 Subject: [PATCH] User.get_or_create: abstract propagate and create_for across protocols --- ids.py | 2 +- models.py | 14 ++++++++------ protocol.py | 10 ++++++++++ tests/test_atproto.py | 2 ++ tests/test_models.py | 7 ++++++- tests/testutil.py | 14 ++++++++++++-- 6 files changed, 39 insertions(+), 10 deletions(-) diff --git a/ids.py b/ids.py index b3c34d6..a0f628a 100644 --- a/ids.py +++ b/ids.py @@ -17,7 +17,7 @@ import models logger = logging.getLogger(__name__) # Protocols to check User.copies and Object.copies before translating -COPIES_PROTOCOLS = ('atproto', 'fake', 'other', 'nostr') +COPIES_PROTOCOLS = ('atproto', 'fake', 'other') # Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for # historical compatibility. Loaded on first call to web_ap_subdomain(). diff --git a/models.py b/models.py index 201c24b..9d3cf32 100644 --- a/models.py +++ b/models.py @@ -249,12 +249,14 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): if not user.obj_key: user.obj = cls.load(user.profile_id()) - ATProto = PROTOCOLS['atproto'] - if propagate and cls.LABEL != 'atproto' and not user.get_copy(ATProto): - if cls.is_enabled_to(ATProto, user=id): - ATProto.create_for(user) - else: - logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping') + if propagate: + for label in ids.COPIES_PROTOCOLS: + proto = PROTOCOLS[label] + if proto != cls and not user.get_copy(proto): + if cls.is_enabled_to(proto, user=id): + proto.create_for(user) + else: + logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping') # generate keys for all protocols _except_ our own # diff --git a/protocol.py b/protocol.py index 0902925..78f0178 100644 --- a/protocol.py +++ b/protocol.py @@ -441,6 +441,16 @@ class Protocol: if owner: return cls.key_for(owner) + @classmethod + def create_for(cls, user): + """Creates a copy user in this protocol. + + Args: + user (models.User): original source user. Shouldn't already have a + copy user for this protocol in ``copies``. + """ + raise NotImplementedError() + @classmethod def send(to_cls, obj, url, from_user=None, orig_obj=None): """Sends an outgoing activity. diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 34a47fb..0a771f0 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -700,6 +700,7 @@ class ATProtoTest(TestCase): mock_create_task.assert_called() + @patch('ids.COPIES_PROTOCOLS', ['atproto']) @patch('google.cloud.dns.client.ManagedZone', autospec=True) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch('requests.post', @@ -764,6 +765,7 @@ class ATProtoTest(TestCase): self.assert_task(mock_create_task, 'atproto-commit', '/queue/atproto-commit') + @patch('ids.COPIES_PROTOCOLS', ['atproto']) @patch('requests.get', return_value=requests_response( 'blob contents', content_type='image/png')) # image blob fetch @patch('google.cloud.dns.client.ManagedZone', autospec=True) diff --git a/tests/test_models.py b/tests/test_models.py index a211e90..3e4cc8d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -66,10 +66,15 @@ class UserTest(TestCase): user.direct = True self.assert_entities_equal(same, user, ignore=['updated']) + def test_get_or_create_propagate_fake_other(self): + user = Fake.get_or_create('fake:user', propagate=True) + self.assertEqual(['fake:user'], OtherFake.created_for) + + @patch('ids.COPIES_PROTOCOLS', ['fake', 'other', 'atproto']) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch('requests.post', return_value=requests_response('OK')) # create DID on PLC - def test_get_or_create_propagate(self, mock_post, mock_create_task): + def test_get_or_create_propagate_atproto(self, mock_post, mock_create_task): common.RUN_TASKS_INLINE = False Fake.fetchable = { diff --git a/tests/testutil.py b/tests/testutil.py index 092775b..0e08bac 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -35,6 +35,7 @@ import requests # other modules are imported _after_ Fake etc classes is defined so that it's in # PROTOCOLS when URL routes are registered. from common import long_to_base64, TASKS_LOCATION +import ids import models from models import KEY_BITS, Object, PROTOCOLS, Target, User import protocol @@ -76,8 +77,9 @@ class Fake(User, protocol.Protocol): # in-order list of (Object, str URL) sent = [] - # in-order list of ids + # in-order lists of ids fetched = [] + created_for = [] @ndb.ComputedProperty def handle(self): @@ -86,6 +88,10 @@ class Fake(User, protocol.Protocol): def web_url(self): return self.key.id() + @classmethod + def create_for(cls, user): + cls.created_for.append(user.key.id()) + @classmethod def owns_id(cls, id): if id.startswith('nope') or id == f'{cls.LABEL}:nope': @@ -157,6 +163,7 @@ class OtherFake(Fake): fetchable = {} sent = [] fetched = [] + created_for = [] @classmethod def target_for(cls, obj, shared=False): @@ -171,6 +178,7 @@ class ExplicitEnableFake(Fake): fetchable = {} sent = [] fetched = [] + created_for = [] # import other modules that register Flask handlers *after* Fake is defined @@ -211,13 +219,15 @@ class TestCase(unittest.TestCase, testutil.Asserts): did.resolve_plc.cache.clear() did.resolve_web.cache.clear() - for cls in Fake, OtherFake: + for cls in ExplicitEnableFake, Fake, OtherFake: cls.fetchable = {} cls.sent = [] cls.fetched = [] + cls.created_for = [] common.OTHER_DOMAINS += ('fake.brid.gy',) common.DOMAINS += ('fake.brid.gy',) + ids.COPIES_PROTOCOLS = ['fake', 'other'] # make random test data deterministic arroba.util._clockid = 17