make ATProto.load([DID]) return the profile record by default

add did_doc kwarg to make it return the DID doc instead
pull/923/head
Ryan Barrett 2024-03-13 16:08:08 -07:00
rodzic 5eac0e06d0
commit 9b2bb9ef37
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
4 zmienionych plików z 48 dodań i 8 usunięć

Wyświetl plik

@ -81,7 +81,7 @@ class ATProto(User, Protocol):
@ndb.ComputedProperty @ndb.ComputedProperty
def handle(self): def handle(self):
"""Returns handle if the DID document includes one, otherwise None.""" """Returns handle if the DID document includes one, otherwise None."""
if did_obj := ATProto.load(self.key.id()): if did_obj := ATProto.load(self.key.id(), did_doc=True):
if aka := util.get_first(did_obj.raw, 'alsoKnownAs', ''): if aka := util.get_first(did_obj.raw, 'alsoKnownAs', ''):
handle, _, _ = parse_at_uri(aka) handle, _, _ = parse_at_uri(aka)
if handle: if handle:
@ -114,8 +114,13 @@ class ATProto(User, Protocol):
return did.resolve_handle(handle, get_fn=util.requests_get) return did.resolve_handle(handle, get_fn=util.requests_get)
@staticmethod
def profile_at_uri(id):
assert id.startswith('did:')
return f'at://{id}/app.bsky.actor.profile/self'
def profile_id(self): def profile_id(self):
return f'at://{self.key.id()}/app.bsky.actor.profile/self' return self.profile_at_uri(self.key.id())
@classmethod @classmethod
def target_for(cls, obj, shared=False): def target_for(cls, obj, shared=False):
@ -166,7 +171,7 @@ class ATProto(User, Protocol):
else: else:
return None return None
did_obj = ATProto.load(repo) did_obj = ATProto.load(repo, did_doc=True)
if did_obj: if did_obj:
return cls.pds_for(did_obj) return cls.pds_for(did_obj)
# TODO: what should we do if the DID doesn't exist? should we return # TODO: what should we do if the DID doesn't exist? should we return
@ -308,7 +313,7 @@ class ATProto(User, Protocol):
did = user.get_copy(ATProto) did = user.get_copy(ATProto)
assert did assert did
logger.info(f'{user.key} is {did}') logger.info(f'{user.key} is {did}')
did_doc = to_cls.load(did) did_doc = to_cls.load(did, did_doc=True)
pds = to_cls.pds_for(did_doc) pds = to_cls.pds_for(did_doc)
if not pds or util.domain_from_link(pds) not in DOMAINS: if not pds or util.domain_from_link(pds) not in DOMAINS:
logger.warning(f'{from_key} {did} PDS {pds} is not us') logger.warning(f'{from_key} {did} PDS {pds} is not us')
@ -355,6 +360,19 @@ class ATProto(User, Protocol):
write() write()
return True return True
@classmethod
def load(cls, id, did_doc=False, **kwargs):
"""Thin wrapper that converts DIDs to profile URIs.
Args:
did_doc (bool): if True, loads and returns a DID document object
instead of an ``app.bsky.actor.profile/self``.
"""
if not did_doc and id.startswith('did:'):
id = cls.profile_at_uri(id)
return super().load(id, **kwargs)
@classmethod @classmethod
def fetch(cls, obj, **kwargs): def fetch(cls, obj, **kwargs):
"""Tries to fetch a ATProto object. """Tries to fetch a ATProto object.

Wyświetl plik

@ -716,7 +716,7 @@ class Protocol:
# fall through to deliver to followers # fall through to deliver to followers
# fetch actor if necessary so we have name, profile photo, etc # fetch actor if necessary
if actor and actor.keys() == set(['id']): if actor and actor.keys() == set(['id']):
logger.info('Fetching actor so we have name, profile photo, etc') logger.info('Fetching actor so we have name, profile photo, etc')
actor_obj = from_cls.load(actor['id']) actor_obj = from_cls.load(actor['id'])

Wyświetl plik

@ -316,6 +316,16 @@ class ATProtoTest(TestCase):
obj = Object(id='https://bsky.app/profile/bad.com/post/789') obj = Object(id='https://bsky.app/profile/bad.com/post/789')
self.assertFalse(ATProto.fetch(obj)) self.assertFalse(ATProto.fetch(obj))
def test_load_did_doc(self):
obj = self.store_object(id='did:plc:user', raw=DID_DOC)
self.assert_entities_equal(obj, ATProto.load('did:plc:user', did_doc=True))
def test_load_did_doc_false_loads_profile(self):
did_doc = self.store_object(id='did:plc:user', raw=DID_DOC)
profile = self.store_object(id='at://did:plc:user/app.bsky.actor.profile/self',
bsky=ACTOR_PROFILE_BSKY)
self.assert_entities_equal(profile, ATProto.load('did:plc:user'))
def test_convert_bsky_pass_through(self): def test_convert_bsky_pass_through(self):
self.assertEqual({ self.assertEqual({
'foo': 'bar', 'foo': 'bar',
@ -593,7 +603,7 @@ class ATProtoTest(TestCase):
did = user.get_copy(ATProto) did = user.get_copy(ATProto)
assert did assert did
self.assertEqual([Target(uri=did, protocol='atproto')], user.copies) self.assertEqual([Target(uri=did, protocol='atproto')], user.copies)
did_obj = ATProto.load(did) did_obj = ATProto.load(did, did_doc=True)
self.assertEqual('https://atproto.brid.gy/', self.assertEqual('https://atproto.brid.gy/',
did_obj.raw['service'][0]['serviceEndpoint']) did_obj.raw['service'][0]['serviceEndpoint'])

Wyświetl plik

@ -9,6 +9,7 @@ from oauth_dropins.webutil.testutil import requests_response
from activitypub import ActivityPub from activitypub import ActivityPub
import app import app
from atproto import ATProto from atproto import ATProto
from granary.tests.test_bluesky import ACTOR_PROFILE_BSKY
import hub import hub
from models import Target from models import Target
from web import Web from web import Web
@ -61,8 +62,8 @@ class IntegrationTests(TestCase):
Target(uri='at://did:plc:bob/app.bsky.feed.post/123', protocol='atproto'), Target(uri='at://did:plc:bob/app.bsky.feed.post/123', protocol='atproto'),
]) ])
# ATProto listNotifications => receive
mock_get.side_effect = [ mock_get.side_effect = [
# ATProto listNotifications
requests_response({ requests_response({
'cursor': '...', 'cursor': '...',
'notifications': [{ 'notifications': [{
@ -90,6 +91,12 @@ class IntegrationTests(TestCase):
}, },
}], }],
}), }),
# ATProto getRecord of alice's profile
requests_response({
'uri': 'at://did:plc:alice/app.bsky.actor.profile/self',
'cid': 'alice sidd',
'value': test_atproto.ACTOR_PROFILE_BSKY,
}),
] ]
resp = self.post('/queue/atproto-poll-notifs', client=hub.app.test_client()) resp = self.post('/queue/atproto-poll-notifs', client=hub.app.test_client())
@ -136,7 +143,6 @@ class IntegrationTests(TestCase):
bob = self.make_user(id='bob.com', cls=Web, bob = self.make_user(id='bob.com', cls=Web,
copies=[Target(uri='did:plc:bob', protocol='atproto')]) copies=[Target(uri='did:plc:bob', protocol='atproto')])
# ATProto listNotifications => receive
mock_get.side_effect = [ mock_get.side_effect = [
# ATProto listNotifications # ATProto listNotifications
requests_response({ requests_response({
@ -157,6 +163,12 @@ class IntegrationTests(TestCase):
}, },
}], }],
}), }),
# ATProto getRecord of alice's profile
requests_response({
'uri': 'at://did:plc:alice/app.bsky.actor.profile/self',
'cid': 'alice sidd',
'value': test_atproto.ACTOR_PROFILE_BSKY,
}),
# webmention discovery # webmention discovery
test_web.WEBMENTION_REL_LINK, test_web.WEBMENTION_REL_LINK,
] ]