From 48c40c10a81ca276c241ef1b7eda8d2410cc8ae9 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Tue, 13 Jun 2023 13:17:11 -0700 Subject: [PATCH] add Protocol.for_id and .owns_id fixes #548 --- activitypub.py | 10 +++++ models.py | 3 +- protocol.py | 81 +++++++++++++++++++++++++++++++++++++-- tests/test_activitypub.py | 6 +++ tests/test_protocol.py | 42 ++++++++++++++++++++ tests/test_web.py | 6 +++ tests/testutil.py | 4 ++ web.py | 8 ++++ 8 files changed, 155 insertions(+), 5 deletions(-) diff --git a/activitypub.py b/activitypub.py index 1a77795..14ffc01 100644 --- a/activitypub.py +++ b/activitypub.py @@ -80,6 +80,16 @@ class ActivityPub(User, Protocol): """ return self.key.id() + @classmethod + def owns_id(cls, id): + """Returns None if id is an http(s) URL, False otherwise. + + All AP ids are http(s) URLs, but not all http(s) URLs are AP ids. + + https://www.w3.org/TR/activitypub/#obj-id + """ + return None if util.is_web(id) else False + @classmethod def send(cls, obj, url, log_data=True): """Delivers an activity to an inbox URL.""" diff --git a/models.py b/models.py index ea907eb..dc5c8a0 100644 --- a/models.py +++ b/models.py @@ -53,7 +53,8 @@ class ProtocolUserMeta(type(ndb.Model)): cls = super().__new__(meta, name, bases, class_dict) if hasattr(cls, 'LABEL') and cls.LABEL not in ('protocol', 'user'): for label in (cls.LABEL, cls.ABBREV) + cls.OTHER_LABELS: - PROTOCOLS[label] = cls + if label: + PROTOCOLS[label] = cls return cls diff --git a/protocol.py b/protocol.py index 6a5ba5e..4e79b05 100644 --- a/protocol.py +++ b/protocol.py @@ -7,6 +7,8 @@ from flask import g, request from google.cloud import ndb from google.cloud.ndb import OR from granary import as1, as2 +import requests +import werkzeug.exceptions import common from common import error @@ -76,8 +78,8 @@ class Protocol: fed.brid.gy Returns: - :class:`Protocol` subclass, or None if the provided domain or request - hostname domain is not a subdomain of brid.gy or isn't a known protocol + :class:`Protocol` subclass, or None if the provided domain or request + hostname domain is not a subdomain of brid.gy or isn't a known protocol """ return Protocol.for_domain(request.host, fed=fed) @@ -91,8 +93,8 @@ class Protocol: fed.brid.gy Returns: - :class:`Protocol` subclass, or None if the request hostname is not a - subdomain of brid.gy or isn't a known protocol + :class:`Protocol` subclass, or None if the request hostname is not a + subdomain of brid.gy or isn't a known protocol """ domain = (util.domain_from_link(domain_or_url, minimize=False) if util.is_web(domain_or_url) @@ -104,6 +106,77 @@ class Protocol: label = domain.removesuffix(common.SUPERDOMAIN) return PROTOCOLS.get(label) + @classmethod + def owns_id(cls, id): + """Returns whether this protocol owns the id, or None if it's unclear. + + To be implemented by subclasses. + + Some protocols' ids are more or less deterministic based on the id + format, eg AT Protocol owns at:// URIs. Others, like http(s) URLs, could + be owned by eg Web or ActivityPub. + + This should be a quick guess without expensive side effects, eg no + external HTTP fetches to fetch the id itself or otherwise perform + discovery. + + Args: + id: str + + Returns: + boolean or None + """ + return False + + @staticmethod + def for_id(id): + """Returns the protocol for a given id. + + May incur expensive side effects like fetching the id itself over the + network or other discovery. + + Args: + id: str + + Returns: + :class:`Protocol` subclass, or None if no known protocol owns this id + """ + logger.info(f'Determining protocol for id {id}') + if not id: + return None + + candidates = [] + for protocol in set(PROTOCOLS.values()): + if not protocol: + continue + owns = protocol.owns_id(id) + if owns: + return protocol + elif owns is not False: + candidates.append(protocol) + + if len(candidates) == 1: + return candidates[0] + + for protocol in candidates: + logger.info(f'Trying {protocol.__name__}') + try: + obj = protocol.load(id) + logger.info(f"Looks like it's {obj.source_protocol}") + return PROTOCOLS[obj.source_protocol] + except werkzeug.exceptions.HTTPException: + # internal error we generated ourselves; try next protocol + pass + except Exception as e: + code, _ = util.interpret_http_exception(e) + if code: + # we tried and failed fetching the id over the network + return None + logger.info(e) + + logger.info(f'No matching protocol found for {id} !') + return None + @classmethod def send(cls, obj, url, log_data=True): """Sends an outgoing activity. diff --git a/tests/test_activitypub.py b/tests/test_activitypub.py index 6ffb705..6c35b7b 100644 --- a/tests/test_activitypub.py +++ b/tests/test_activitypub.py @@ -1391,6 +1391,12 @@ class ActivityPubUtilsTest(TestCase): self.request_context.pop() super().tearDown() + def test_owns_id(self): + self.assertIsNone(ActivityPub.owns_id('http://foo')) + self.assertIsNone(ActivityPub.owns_id('https://bar/baz')) + self.assertFalse(ActivityPub.owns_id('at://did:plc:foo/bar/123')) + self.assertFalse(ActivityPub.owns_id('e45fab982')) + def test_postprocess_as2_multiple_in_reply_tos(self): self.assert_equals({ 'id': 'http://localhost/r/xyz', diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0ff3e0d..f635165 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -12,9 +12,11 @@ from activitypub import ActivityPub from app import app from models import Follower, Object, PROTOCOLS, User from protocol import Protocol +from ui import UIProtocol from web import Web from .test_activitypub import ACTOR, REPLY +from .test_web import ACTOR_HTML REPLY = { **REPLY, @@ -35,6 +37,7 @@ class ProtocolTest(TestCase): g.user = None def tearDown(self): + PROTOCOLS.pop('greedy', None) self.request_context.pop() super().tearDown() @@ -102,6 +105,45 @@ class ProtocolTest(TestCase): source_protocol='fake', ) + def test_for_id(self): + self.assertIsNone(Protocol.for_id(None)) + self.assertIsNone(Protocol.for_id('')) + self.assertIsNone(Protocol.for_id('foo://bar')) + self.assertEqual(Fake, Protocol.for_id('fake://foo')) + # TODO + # self.assertEqual(ATProto, Protocol.for_id('at://foo')) + + def test_for_id_true_overrides_none(self): + class Greedy(Protocol, User): + @classmethod + def owns_id(cls, id): + return True + + self.assertEqual(Greedy, Protocol.for_id('http://foo')) + self.assertEqual(Greedy, Protocol.for_id('https://bar/baz')) + + def test_for_id_object(self): + Object(id='http://ui/obj', source_protocol='ui').put() + self.assertEqual(UIProtocol, Protocol.for_id('http://ui/obj')) + + @patch('requests.get') + def test_for_id_activitypub_fetch(self, mock_get): + mock_get.return_value = self.as2_resp(ACTOR) + self.assertEqual(ActivityPub, Protocol.for_id('http://ap/actor')) + self.assertIn(self.as2_req('http://ap/actor'), mock_get.mock_calls) + + @patch('requests.get') + def test_for_id_web_fetch(self, mock_get): + mock_get.return_value = requests_response(ACTOR_HTML) + self.assertEqual(Web, Protocol.for_id('http://web.site/')) + self.assertIn(self.req('http://web.site/'), mock_get.mock_calls) + + @patch('requests.get') + def test_for_id_web_fetch_no_mf2(self, mock_get): + mock_get.return_value = requests_response('') + self.assertIsNone(Protocol.for_id('http://web.site/')) + self.assertIn(self.req('http://web.site/'), mock_get.mock_calls) + def test_load(self): Fake.objects['foo'] = {'x': 'y'} diff --git a/tests/test_web.py b/tests/test_web.py index b7fa563..38d2215 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -1673,6 +1673,12 @@ class WebProtocolTest(TestCase): self.request_context.__enter__() super().tearDown() + def test_owns_id(self, *_): + self.assertIsNone(Web.owns_id('http://foo')) + self.assertIsNone(Web.owns_id('https://bar/baz')) + self.assertFalse(Web.owns_id('at://did:plc:foo/bar/123')) + self.assertFalse(Web.owns_id('e45fab982')) + def test_fetch(self, mock_get, __): mock_get.return_value = REPOST diff --git a/tests/testutil.py b/tests/testutil.py index cfeb6bc..969f2c9 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -56,6 +56,10 @@ class Fake(User, protocol.Protocol): def ap_actor(self, rest=None): return f'http://bf/fake/{self.key.id()}/ap' + (f'/{rest}' if rest else '') + @classmethod + def owns_id(cls, id): + return id.startswith('fake://') + @classmethod def send(cls, obj, url, log_data=True): logger.info(f'Fake.send {url}') diff --git a/web.py b/web.py index 864a085..29b5b71 100644 --- a/web.py +++ b/web.py @@ -193,6 +193,14 @@ class Web(User, Protocol): return self + @classmethod + def owns_id(cls, id): + """Returns None if id is an http(s) URL, False otherwise. + + All web pages are http(s) URLs, but not all http(s) URLs are web pages. + """ + return None if util.is_web(id) else False + @classmethod def send(cls, obj, url): """Sends a webmention to a given target URL.