From 9b2bb9ef3709ac7a62456bc6aad66680f1a4ed93 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Wed, 13 Mar 2024 16:08:08 -0700 Subject: [PATCH] make ATProto.load([DID]) return the profile record by default add did_doc kwarg to make it return the DID doc instead --- atproto.py | 26 ++++++++++++++++++++++---- protocol.py | 2 +- tests/test_atproto.py | 12 +++++++++++- tests/test_integrations.py | 16 ++++++++++++++-- 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/atproto.py b/atproto.py index cd8f41e..cb38c39 100644 --- a/atproto.py +++ b/atproto.py @@ -81,7 +81,7 @@ class ATProto(User, Protocol): @ndb.ComputedProperty def handle(self): """Returns handle if the DID document includes one, otherwise None.""" - if did_obj := ATProto.load(self.key.id()): + if did_obj := ATProto.load(self.key.id(), did_doc=True): if aka := util.get_first(did_obj.raw, 'alsoKnownAs', ''): handle, _, _ = parse_at_uri(aka) if handle: @@ -114,8 +114,13 @@ class ATProto(User, Protocol): return did.resolve_handle(handle, get_fn=util.requests_get) + @staticmethod + def profile_at_uri(id): + assert id.startswith('did:') + return f'at://{id}/app.bsky.actor.profile/self' + def profile_id(self): - return f'at://{self.key.id()}/app.bsky.actor.profile/self' + return self.profile_at_uri(self.key.id()) @classmethod def target_for(cls, obj, shared=False): @@ -166,7 +171,7 @@ class ATProto(User, Protocol): else: return None - did_obj = ATProto.load(repo) + did_obj = ATProto.load(repo, did_doc=True) if did_obj: return cls.pds_for(did_obj) # TODO: what should we do if the DID doesn't exist? should we return @@ -308,7 +313,7 @@ class ATProto(User, Protocol): did = user.get_copy(ATProto) assert did logger.info(f'{user.key} is {did}') - did_doc = to_cls.load(did) + did_doc = to_cls.load(did, did_doc=True) pds = to_cls.pds_for(did_doc) if not pds or util.domain_from_link(pds) not in DOMAINS: logger.warning(f'{from_key} {did} PDS {pds} is not us') @@ -355,6 +360,19 @@ class ATProto(User, Protocol): write() return True + @classmethod + def load(cls, id, did_doc=False, **kwargs): + """Thin wrapper that converts DIDs to profile URIs. + + Args: + did_doc (bool): if True, loads and returns a DID document object + instead of an ``app.bsky.actor.profile/self``. + """ + if not did_doc and id.startswith('did:'): + id = cls.profile_at_uri(id) + + return super().load(id, **kwargs) + @classmethod def fetch(cls, obj, **kwargs): """Tries to fetch a ATProto object. diff --git a/protocol.py b/protocol.py index 2a64a29..f9d73de 100644 --- a/protocol.py +++ b/protocol.py @@ -716,7 +716,7 @@ class Protocol: # fall through to deliver to followers - # fetch actor if necessary so we have name, profile photo, etc + # fetch actor if necessary if actor and actor.keys() == set(['id']): logger.info('Fetching actor so we have name, profile photo, etc') actor_obj = from_cls.load(actor['id']) diff --git a/tests/test_atproto.py b/tests/test_atproto.py index e7a8b42..ebd3a89 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -316,6 +316,16 @@ class ATProtoTest(TestCase): obj = Object(id='https://bsky.app/profile/bad.com/post/789') self.assertFalse(ATProto.fetch(obj)) + def test_load_did_doc(self): + obj = self.store_object(id='did:plc:user', raw=DID_DOC) + self.assert_entities_equal(obj, ATProto.load('did:plc:user', did_doc=True)) + + def test_load_did_doc_false_loads_profile(self): + did_doc = self.store_object(id='did:plc:user', raw=DID_DOC) + profile = self.store_object(id='at://did:plc:user/app.bsky.actor.profile/self', + bsky=ACTOR_PROFILE_BSKY) + self.assert_entities_equal(profile, ATProto.load('did:plc:user')) + def test_convert_bsky_pass_through(self): self.assertEqual({ 'foo': 'bar', @@ -593,7 +603,7 @@ class ATProtoTest(TestCase): did = user.get_copy(ATProto) assert did self.assertEqual([Target(uri=did, protocol='atproto')], user.copies) - did_obj = ATProto.load(did) + did_obj = ATProto.load(did, did_doc=True) self.assertEqual('https://atproto.brid.gy/', did_obj.raw['service'][0]['serviceEndpoint']) diff --git a/tests/test_integrations.py b/tests/test_integrations.py index d4282a7..ad8cb19 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -9,6 +9,7 @@ from oauth_dropins.webutil.testutil import requests_response from activitypub import ActivityPub import app from atproto import ATProto +from granary.tests.test_bluesky import ACTOR_PROFILE_BSKY import hub from models import Target from web import Web @@ -61,8 +62,8 @@ class IntegrationTests(TestCase): Target(uri='at://did:plc:bob/app.bsky.feed.post/123', protocol='atproto'), ]) - # ATProto listNotifications => receive mock_get.side_effect = [ + # ATProto listNotifications requests_response({ 'cursor': '...', 'notifications': [{ @@ -90,6 +91,12 @@ class IntegrationTests(TestCase): }, }], }), + # ATProto getRecord of alice's profile + requests_response({ + 'uri': 'at://did:plc:alice/app.bsky.actor.profile/self', + 'cid': 'alice sidd', + 'value': test_atproto.ACTOR_PROFILE_BSKY, + }), ] resp = self.post('/queue/atproto-poll-notifs', client=hub.app.test_client()) @@ -136,7 +143,6 @@ class IntegrationTests(TestCase): bob = self.make_user(id='bob.com', cls=Web, copies=[Target(uri='did:plc:bob', protocol='atproto')]) - # ATProto listNotifications => receive mock_get.side_effect = [ # ATProto listNotifications requests_response({ @@ -157,6 +163,12 @@ class IntegrationTests(TestCase): }, }], }), + # ATProto getRecord of alice's profile + requests_response({ + 'uri': 'at://did:plc:alice/app.bsky.actor.profile/self', + 'cid': 'alice sidd', + 'value': test_atproto.ACTOR_PROFILE_BSKY, + }), # webmention discovery test_web.WEBMENTION_REL_LINK, ]