diff --git a/protocol.py b/protocol.py index fc271b8..6a5ba5e 100644 --- a/protocol.py +++ b/protocol.py @@ -79,26 +79,30 @@ class Protocol: :class:`Protocol` subclass, or None if the provided domain or request hostname domain is not a subdomain of brid.gy or isn't a known protocol """ - request_cls = Protocol.for_domain(request.host) - if request_cls: - return request_cls - elif (request.host == common.PRIMARY_DOMAIN - or request.host in common.LOCAL_DOMAINS): - return fed + return Protocol.for_domain(request.host, fed=fed) @staticmethod - def for_domain(domain): + def for_domain(domain_or_url, fed=None): """Returns the protocol for a brid.gy subdomain. + Args: + domain_or_url: str + fed: :class:`Protocol` subclass to return if the domain_or_url is on + fed.brid.gy + Returns: :class:`Protocol` subclass, or None if the request hostname is not a subdomain of brid.gy or isn't a known protocol """ - if not domain or not domain.endswith(common.SUPERDOMAIN): - return None + domain = (util.domain_from_link(domain_or_url, minimize=False) + if util.is_web(domain_or_url) + else domain_or_url) - label = domain.removesuffix(common.SUPERDOMAIN) - return PROTOCOLS.get(label) + if domain == common.PRIMARY_DOMAIN or domain in common.LOCAL_DOMAINS: + return fed + elif domain and domain.endswith(common.SUPERDOMAIN): + label = domain.removesuffix(common.SUPERDOMAIN) + return PROTOCOLS.get(label) @classmethod def send(cls, obj, url, log_data=True): @@ -223,7 +227,7 @@ class Protocol: followee_domain = util.domain_from_link(inner_obj_id, minimize=False) # TODO: avoid import? from web import Web - to_cls = Protocol.for_domain(followee_domain) or Protocol.for_request(fed=Web) + to_cls = Protocol.for_domain(followee_domain) or Protocol.for_request() or Web follower = Follower.query( Follower.to == to_cls(id=followee_domain).key, Follower.from_ == from_cls(id=actor_id).key, @@ -383,20 +387,23 @@ class Protocol: logger.info(f'targets: {targets}') - # send webmentions and update Object errors = [] # stores (code, body) tuples + + # TODO: avoid import? + from web import Web + targets = [Target(uri=uri, protocol=(Protocol.for_domain(uri) or Web).LABEL) + for uri in targets] + no_user_domains = set() + obj.undelivered = [] obj.status = 'in progress' - for uri in targets: - # TODO: avoid import? - from web import Web - domain = util.domain_from_link(uri, minimize=False) - protocol = Protocol.for_domain(domain) or Protocol.for_request() or Web - obj.undelivered.append(Target(uri=uri, protocol=protocol.LABEL)) - - no_user_domains = set() + obj.populate( + undelivered=targets, + status='in progress', + ) + # send webmentions and update Object while obj.undelivered: target = obj.undelivered.pop() domain = util.domain_from_link(target.uri, minimize=False) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f92c66e..0ff3e0d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -44,7 +44,7 @@ class ProtocolTest(TestCase): self.assertEqual(Web, PROTOCOLS['webmention']) def test_for_domain_for_request(self): - for domain, protocol in [ + for domain, expected in [ ('fake.brid.gy', Fake), ('ap.brid.gy', ActivityPub), ('activitypub.brid.gy', ActivityPub), @@ -58,19 +58,22 @@ class ProtocolTest(TestCase): ('fake', None), ('fake.com', None), ]: - with self.subTest(domain=domain, protocol=protocol): - self.assertEqual(protocol, Protocol.for_domain(domain)) + with self.subTest(domain=domain, expected=expected): + self.assertEqual(expected, Protocol.for_domain(domain)) with app.test_request_context('/foo', base_url=f'https://{domain}/'): - self.assertEqual(protocol, Protocol.for_request()) - - def test_for_request_fed(self): - for base_url in 'https://fed.brid.gy/', 'http://localhost/': - with app.test_request_context('/foo', base_url=base_url): - self.assertEqual(Fake, Protocol.for_request(fed=Fake)) - - with app.test_request_context('/foo', base_url='https://ap.brid.gy/'): - self.assertEqual(ActivityPub, Protocol.for_request(fed=Fake)) + self.assertEqual(expected, Protocol.for_request()) + def test_for_domain_for_request_fed(self): + for url, expected in [ + ('https://fed.brid.gy/', Fake), + ('http://localhost/foo', Fake), + ('https://ap.brid.gy/bar', ActivityPub), + ('https://baz/biff', None), + ]: + with self.subTest(url=url, expected=expected): + self.assertEqual(expected, Protocol.for_domain(url, fed=Fake)) + with app.test_request_context('/foo', base_url=url): + self.assertEqual(expected, Protocol.for_request(fed=Fake)) @patch('requests.get') def test_receive_reply_not_feed_not_notification(self, mock_get):