diff --git a/activitypub.py b/activitypub.py index e469d37..20ddee8 100644 --- a/activitypub.py +++ b/activitypub.py @@ -78,7 +78,7 @@ class ActivityPub(User, Protocol): assert not self.is_blocklisted(domain), f'{id} is a blocked domain' def web_url(self): - """Returns this user's web URL aka web_url, eg 'https://foo.com/'.""" + """Returns this user's web URL aka web_url, eg ``https://foo.com/``.""" if self.obj and self.obj.as1: url = util.get_url(self.obj.as1) if url: diff --git a/protocol.py b/protocol.py index 385c298..5cd1f5f 100644 --- a/protocol.py +++ b/protocol.py @@ -221,10 +221,10 @@ class Protocol: network or other discovery. Args: - id: str + id (str) Returns: - :class:`Protocol` subclass, or None if no known protocol owns this id + Protocol subclass: ...or None if no known protocol owns this id """ logger.info(f'Determining protocol for id {id}') if not id: @@ -234,7 +234,7 @@ class Protocol: if util.is_web(id): by_subdomain = Protocol.for_bridgy_subdomain(id) if by_subdomain: - logger.info(f' {by_subdomain.__name__} owns {id}') + logger.info(f' {by_subdomain.__name__} owns id {id}') return by_subdomain # step 2: check if any Protocols say conclusively that they own it @@ -245,13 +245,13 @@ class Protocol: for protocol in protocols: owns = protocol.owns_id(id) if owns: - logger.info(f' {protocol.__name__} owns {id}') + logger.info(f' {protocol.__name__} owns id {id}') return protocol elif owns is not False: candidates.append(protocol) if len(candidates) == 1: - logger.info(f' {candidates[0].__name__} owns {id}') + logger.info(f' {candidates[0].__name__} owns id {id}') return candidates[0] # step 3: look for existing Objects in the datastore @@ -265,7 +265,7 @@ class Protocol: logger.info(f'Trying {protocol.__name__}') try: if protocol.load(id, local=False, remote=True): - logger.info(f' {protocol.__name__} owns {id}') + logger.info(f' {protocol.__name__} owns id {id}') return protocol except werkzeug.exceptions.BadGateway: # we tried and failed fetching the id over the network. @@ -284,6 +284,57 @@ class Protocol: logger.info(f'No matching protocol found for {id} !') return None + @staticmethod + def for_handle(handle): + """Returns the protocol for a given handle. + + May incur expensive side effects like resolving the handle itself over + the network or other discovery. + + Args: + handle (str) + + Returns: + (Protocol subclass, str) tuple: matching protocol and optional id (if + resolved), or ``(None, None)`` if no known protocol owns this handle + """ + logger.info(f'Determining protocol for handle {handle}') + if not handle: + return (None, None) + + # step 1: check if any Protocols say conclusively that they own it. + # sort to be deterministic. + protocols = sorted(set(p for p in PROTOCOLS.values() if p), + key=lambda p: p.__name__) + candidates = [] + for proto in protocols: + owns = proto.owns_handle(handle) + if owns: + logger.info(f' {proto.__name__} owns handle {handle}') + return (proto, None) + elif owns is not False: + candidates.append(proto) + + if len(candidates) == 1: + logger.info(f' {candidates[0].__name__} owns handle {handle}') + return (candidates[0], None) + + # step 2: look for matching User in the datastore + for proto in candidates: + user = proto.query(proto.handle == handle).get(keys_only=True) + if user: + logger.info(f' user {user} owns handle {handle}') + return (proto, user.id()) + + # step 3: resolve handle to id + for proto in candidates: + id = proto.handle_to_id(handle) + if id: + logger.info(f' {proto.__name__} resolved handle {handle} to id {id}') + return (proto, id) + + return (None, None) + @classmethod def actor_key(cls, obj, default_g_user=True): """Returns the :class:`User`: key for a given object's author or actor. diff --git a/tests/test_atproto.py b/tests/test_atproto.py index da93193..6be6a4a 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -1,7 +1,6 @@ """Unit tests for atproto.py.""" import base64 import copy -from google.cloud.tasks_v2.types import Task import logging from unittest import skip from unittest.mock import call, patch @@ -13,6 +12,7 @@ import arroba.util import dns.resolver from dns.resolver import NXDOMAIN from flask import g +from google.cloud.tasks_v2.types import Task from granary.tests.test_bluesky import ( ACTOR_AS, ACTOR_PROFILE_VIEW_BSKY, @@ -112,7 +112,7 @@ class ATProtoTest(TestCase): self.assertEqual('did:plc:foo', ATProto.handle_to_id('han.dull')) @patch('dns.resolver.resolve', side_effect=dns.resolver.NXDOMAIN()) - # resolving handle, HTTPS method, not founud + # resolving handle, HTTPS method, not found @patch('requests.get', return_value=requests_response('', status=404)) def test_handle_to_id_not_found(self, *_): self.assertIsNone(ATProto.handle_to_id('han.dull')) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 882273f..4d81861 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -3,6 +3,7 @@ import copy from unittest import skip from unittest.mock import patch +from arroba.tests.testutil import dns_answer from flask import g from google.cloud import ndb from granary import as2 @@ -90,8 +91,7 @@ class ProtocolTest(TestCase): ('', None), ('foo://bar', None), ('fake:foo', Fake), - # TODO - # ('at://foo', ATProto), + ('at://foo', ATProto), ('https://ap.brid.gy/foo/bar', ActivityPub), ('https://web.brid.gy/foo/bar', Web), ]: @@ -139,6 +139,27 @@ class ProtocolTest(TestCase): self.assertIsNone(Protocol.for_id('http://web.site/')) self.assertIn(self.req('http://web.site/'), mock_get.mock_calls) + def test_for_handle_deterministic(self): + for handle, expected in [ + (None, (None, None)), + ('', (None, None)), + ('foo://bar', (None, None)), + ('fake:foo', (None, None)), + ('fake:handle:foo', (Fake, None)), + ('@me@foo', (ActivityPub, None)), + ]: + self.assertEqual(expected, Protocol.for_handle(handle)) + + def test_for_handle_stored_user(self): + user = self.make_user(id='user.com', cls=Web) + self.assertEqual('user.com', user.handle) + self.assertEqual((Web, 'user.com'), Protocol.for_handle('user.com')) + + @patch('dns.resolver.resolve', return_value = dns_answer( + '_atproto.han.dull.', '"did=did:plc:123abc"')) + def test_for_handle_atproto_resolve(self, _): + self.assertEqual((ATProto, 'did:plc:123abc'), Protocol.for_handle('han.dull')) + def test_load(self): Fake.fetchable['foo'] = {'x': 'y'}