From 21ab9e34ed6dcab3457cfcbc7a079f2bfb60e6d2 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Sun, 18 Jun 2023 07:29:54 -0700 Subject: [PATCH] Revise Protocol.load shallow and refresh kwargs, rename to local and remote and use in for_id to optimize datastore usage. --- protocol.py | 82 +++++++++++++++++++++++------------------- tests/test_protocol.py | 45 +++++++++++++++++------ web.py | 2 +- 3 files changed, 81 insertions(+), 48 deletions(-) diff --git a/protocol.py b/protocol.py index 9cb6939..48025ad 100644 --- a/protocol.py +++ b/protocol.py @@ -176,14 +176,13 @@ class Protocol: if not id: return None - # check for our per-protocol subdomains + # step 1: check for our per-protocol subdomains if util.is_web(id): by_domain = Protocol.for_domain(id) if by_domain: return by_domain - candidates = [] - + # step 2: check if any Protocols say conclusively that they own it # sort to be deterministic protocols = sorted(set(p for p in PROTOCOLS.values() if p), key=lambda p: p.__name__) @@ -198,13 +197,18 @@ class Protocol: if len(candidates) == 1: return candidates[0] + # step 3: look for existing Objects in the datastore + obj = Protocol.load(id, remote=False) + if obj and obj.source_protocol: + logger.info(f'{obj.key} has source_protocol {obj.source_protocol}') + return PROTOCOLS[obj.source_protocol] + + # step 4: fetch over the network for protocol in candidates: logger.info(f'Trying {protocol.__name__}') try: - obj = protocol.load(id) - if obj.source_protocol: - logger.info(f"{obj.key} has source_protocol {obj.source_protocol}") - return PROTOCOLS[obj.source_protocol] + protocol.load(id, local=False, remote=True) + return protocol except werkzeug.exceptions.HTTPException: # internal error we generated ourselves; try next protocol pass @@ -244,11 +248,9 @@ class Protocol: @classmethod def fetch(cls, obj, **kwargs): - """Fetches a protocol-specific object and returns it in an :class:`Object`. + """Fetches a protocol-specific object and populates it in an :class:`Object`. - To be implemented by subclasses. The returned :class:`Object` is loaded - from the datastore, if it exists there, then updated in memory but not - yet written back to the datastore. + To be implemented by subclasses. Args: obj: :class:`Object` with the id to fetch. Data is filled into one of @@ -595,55 +597,63 @@ class Protocol: error(msg, status=int(errors[0][0] or 502)) @classmethod - def load(cls, id, refresh=False, shallow=True, **kwargs): + def load(cls, id, remote=None, local=True, **kwargs): """Loads and returns an Object from memory cache, datastore, or HTTP fetch. Note that :meth:`Object._post_put_hook` updates the cache. Args: id: str - refresh: boolean, whether to fetch the object remotely even if we have - it stored - shallow: boolean, whether to only fetch from the datastore. If it - isn't there, returns None instead of fetching over the network. + + remote: boolean, whether to fetch the object over the network. If True, + fetches even if we already have the object stored, and updates our + stored copy. If False and we don't have the object stored, returns + None. Default (None) means to fetch over the network only if we + don't already have it stored. + local: boolean, whether to load from the datastore before + fetching over the network. If False, still stores back to the + datastore after a successful remote fetch. kwargs: passed through to :meth:`fetch()` - Returns: :class:`Object` or None if it isn't in the datastore and shallow - is True + Returns: :class:`Object` or None if it isn't in the datastore and remote + is False Raises: :class:`requests.HTTPError`, anything else that :meth:`fetch` raises """ - assert not (refresh and shallow) + assert local or remote is not False - if not refresh: + logger.info(f'Loading Object {id} local={local} remote={remote}') + + if remote is not True: with objects_cache_lock: cached = objects_cache.get(id) if cached: return cached - logger.info(f'Loading Object {id}') - orig_as1 = None - obj = Object.get_by_id(id) - if obj and (obj.as1 or obj.deleted): - logger.info(' got from datastore') - obj.new = False - orig_as1 = obj.as1 - if not refresh: - with objects_cache_lock: - objects_cache[id] = obj - return obj + obj = orig_as1 = None + if local: + obj = Object.get_by_id(id) + if obj and (obj.as1 or obj.deleted): + logger.info(' got from datastore') + obj.new = False + orig_as1 = obj.as1 + if remote is not True: + with objects_cache_lock: + objects_cache[id] = obj + return obj - if refresh: - logger.info(' forced refresh requested') + if remote is True: + logger.info(' remote=True, forced refresh requested') if obj: obj.clear() obj.new = False else: - logger.info(' not in datastore') - if shallow: - logger.info(' shallow load requested, returning None') + if local: + logger.info(' not in datastore') + if remote is False: + logger.info(' remote=False; returning None') return None obj = Object(id=id) obj.new = True diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 548f009..197f031 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -11,7 +11,9 @@ from .testutil import Fake, TestCase from activitypub import ActivityPub from app import app from models import Follower, Object, PROTOCOLS, User +import protocol from protocol import Protocol +import requests from ui import UIProtocol from web import Web @@ -189,52 +191,73 @@ class ProtocolTest(TestCase): self.assertEqual([], Fake.fetched) - def test_load_refresh_existing_empty(self): + def test_load_remote_true_existing_empty(self): Fake.objects['foo'] = {'x': 'y'} Object(id='foo').put() - loaded = Fake.load('foo', refresh=True) + loaded = Fake.load('foo', remote=True) self.assertEqual({'x': 'y'}, loaded.as1) self.assertTrue(loaded.changed) self.assertFalse(loaded.new) self.assertEqual(['foo'], Fake.fetched) - def test_load_refresh_new_empty(self): + def test_load_remote_true_new_empty(self): Fake.objects['foo'] = None Object(id='foo', our_as1={'x': 'y'}).put() - loaded = Fake.load('foo', refresh=True) + loaded = Fake.load('foo', remote=True) self.assertIsNone(loaded.as1) self.assertTrue(loaded.changed) self.assertFalse(loaded.new) self.assertEqual(['foo'], Fake.fetched) - def test_load_refresh_unchanged(self): + def test_load_remote_true_unchanged(self): obj = Object(id='foo', our_as1={'x': 'stored'}) obj.put() Fake.objects['foo'] = {'x': 'stored'} - loaded = Fake.load('foo', refresh=True) + loaded = Fake.load('foo', remote=True) self.assert_entities_equal(obj, loaded) self.assertFalse(obj.changed) self.assertFalse(obj.new) self.assertEqual(['foo'], Fake.fetched) - def test_load_refresh_changed(self): + def test_load_remote_true_changed(self): Object(id='foo', our_as1={'content': 'stored'}).put() Fake.objects['foo'] = {'content': 'new'} - loaded = Fake.load('foo', refresh=True) + loaded = Fake.load('foo', remote=True) self.assert_equals({'content': 'new'}, loaded.our_as1) self.assertTrue(loaded.changed) self.assertFalse(loaded.new) self.assertEqual(['foo'], Fake.fetched) - def test_load_shallow_missing(self): - self.assertIsNone(Fake.load('nope', shallow=True)) + def test_load_remote_false(self): + self.assertIsNone(Fake.load('nope', remote=False)) self.assertEqual([], Fake.fetched) obj = Object(id='foo', our_as1={'content': 'stored'}) obj.put() - self.assert_entities_equal(obj, Fake.load('foo', shallow=True)) + self.assert_entities_equal(obj, Fake.load('foo', remote=False)) self.assertEqual([], Fake.fetched) + + def test_local_false_missing(self): + with self.assertRaises(requests.HTTPError) as e: + Fake.load('foo', local=False) + self.assertEqual(410, e.response.status_code) + + self.assertEqual(['foo'], Fake.fetched) + + def test_local_false_existing(self): + obj = Object(id='foo', our_as1={'content': 'stored'}, source_protocol='ui') + obj.put() + del protocol.objects_cache['foo'] + + Fake.objects['foo'] = {'foo': 'bar'} + Fake.load('foo', local=False) + self.assert_object('foo', source_protocol='fake', our_as1={'foo': 'bar'}) + self.assertEqual(['foo'], Fake.fetched) + + def test_remote_false_local_false_assert(self): + with self.assertRaises(AssertionError): + Fake.load('nope', local=False, remote=False) diff --git a/web.py b/web.py index 3d3d9f6..f2e9a5f 100644 --- a/web.py +++ b/web.py @@ -494,7 +494,7 @@ def webmention_task(): # fetch source page try: - obj = Web.load(source, refresh=True, check_backlink=True) + obj = Web.load(source, remote=True, check_backlink=True) except BadRequest as e: error(str(e.description), status=304) except HTTPError as e: