diff --git a/protocol.py b/protocol.py index 76346b4..fc16cfc 100644 --- a/protocol.py +++ b/protocol.py @@ -226,19 +226,18 @@ class Protocol: @cached(LRUCache(20000), lock=Lock()) @staticmethod - def for_id(id): + def for_id(id, remote=True): """Returns the protocol for a given id. - May incur expensive side effects like fetching the id itself over the - network or other discovery. - Args: id (str) + remote (bool): whether to perform expensive side effects like fetching + the id itself over the network, or other discovery. Returns: - Protocol subclass: matching protocol, or None if no known protocol - owns this id - """ + Protocol subclass: matching protocol, or None if no single known + protocol definitively owns this id + """ logger.info(f'Determining protocol for id {id}') if not id: return None @@ -273,7 +272,10 @@ class Protocol: logger.info(f' {obj.key} owned by source_protocol {obj.source_protocol}') return PROTOCOLS[obj.source_protocol] - # step 4: fetch over the network + # step 4: fetch over the network, if necessary + if not remote: + return None + for protocol in candidates: logger.info(f'Trying {protocol.LABEL}') try: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 04114f7..4eb4933 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -93,7 +93,8 @@ class ProtocolTest(TestCase): ('https://ap.brid.gy/foo/bar', ActivityPub), ('https://web.brid.gy/foo/bar', Web), ]: - self.assertEqual(expected, Protocol.for_id(id)) + self.assertEqual(expected, Protocol.for_id(id, remote=False)) + self.assertEqual(expected, Protocol.for_id(id, remote=True)) def test_for_id_true_overrides_none(self): class Greedy(Protocol, User): @@ -137,6 +138,11 @@ class ProtocolTest(TestCase): self.assertIsNone(Protocol.for_id('http://web.site/')) self.assertIn(self.req('http://web.site/'), mock_get.mock_calls) + @patch('requests.get') + def test_for_id_web_remote_false(self, mock_get): + self.assertIsNone(Protocol.for_id('http://web.site/', remote=False)) + mock_get.assert_not_called() + def test_for_handle_deterministic(self): for handle, expected in [ (None, (None, None)),