diff --git a/activitypub.py b/activitypub.py index c2137f4..f7bed4a 100644 --- a/activitypub.py +++ b/activitypub.py @@ -153,10 +153,11 @@ def inbox(domain=None): logger.info(f'updating Object {obj_id}') obj = Object.get_by_id(obj_id) or Object(id=obj_id) - obj.as2 = json_dumps(obj_as2) - obj_as1 = as2.to_as1(obj_as2) - obj.as1 = json_dumps(obj_as1) - obj.source_protocol = 'activitypub' + obj.populate( + as2=json_dumps(obj_as2), + as1=json_dumps(as2.to_as1(obj_as2)), + source_protocol='activitypub', + ) obj.put() return 'OK' @@ -188,7 +189,8 @@ def inbox(domain=None): # fetch object if necessary so we can render it in feeds if type in FETCH_OBJECT_TYPES and isinstance(activity.get('object'), str): - obj_as2 = activity['object'] = common.get_as2(activity['object'], user=user).json() + obj_as2 = activity['object'] = \ + common.get_as2(activity['object'], user=user).json() activity_unwrapped = redirect_unwrap(activity) if type == 'Follow': diff --git a/common.py b/common.py index a8313d3..8ec1dca 100644 --- a/common.py +++ b/common.py @@ -9,8 +9,10 @@ import itertools import logging import os import re +import threading import urllib.parse +from cachetools import cached, LRUCache from flask import request from granary import as1, as2, microformats2 from httpsig.requests_auth import HTTPSignatureAuth @@ -117,6 +119,30 @@ def pretty_link(url, text=None, user=None): return util.pretty_link(url, text=text) +@cached(LRUCache(1000), lock=threading.Lock()) +def get_object(id, user=None): + """Loads and returns an Object from memory cache, datastore, or HTTP fetch. + + Note that :meth:`Object._post_put_hook` updates the cache. + + Args: + id: str + user: optional, :class:`User` used to sign HTTP request, if necessary + + Returns: Object, or None if it can't be fetched + """ + if obj := Object.get_by_id(id): + return obj + + obj_as2 = get_as2(id, user=user).json() + obj = Object(id=id, + as2=json_dumps(obj_as2), + as1=json_dumps(as2.to_as1(obj_as2)), + source_protocol='activitypub') + obj.put() + return obj + + def signed_get(url, user, **kwargs): return signed_request(util.requests_get, url, user, **kwargs) diff --git a/models.py b/models.py index abc5d2d..78afb1f 100644 --- a/models.py +++ b/models.py @@ -302,6 +302,11 @@ class Object(StringIdModel): created = ndb.DateTimeProperty(auto_now_add=True) updated = ndb.DateTimeProperty(auto_now=True) + def _post_put_hook(self, future): + """Update :func:`common.get_object` cache.""" + if self.type != 'activity': + common.get_object.cache[self.key.id()] = self + def proxy_url(self): """Returns the Bridgy Fed proxy URL to render this post as HTML.""" return common.host_url('render?' + diff --git a/tests/test_activitypub.py b/tests/test_activitypub.py index 9a0ca5c..84d2595 100644 --- a/tests/test_activitypub.py +++ b/tests/test_activitypub.py @@ -697,7 +697,8 @@ class ActivityPubTest(testutil.TestCase): self.assertEqual('active', other.key.get().status) def test_delete_note(self, _, __, ___): - key = Object(id='http://an/obj', as1='{}').put() + obj = Object(id='http://an/obj', as1='{}') + obj.put() delete = { **DELETE, @@ -705,11 +706,14 @@ class ActivityPubTest(testutil.TestCase): } resp = self.client.post('/inbox', json=delete) self.assertEqual(200, resp.status_code) - self.assertTrue(key.get().deleted) + self.assertTrue(obj.key.get().deleted) self.assert_object(delete['id'], as2=delete, as1=as2.to_as1(delete), type='delete', source_protocol='activitypub', status='complete') + obj.deleted = True + self.assert_entities_equal(obj, common.get_object.cache['http://an/obj']) + def test_update_note(self, *_): Object(id='https://a/note', as1='{}').put() self._test_update() @@ -728,6 +732,9 @@ class ActivityPubTest(testutil.TestCase): type='update', status='complete', as2=UPDATE_NOTE, as1=as2.to_as1(UPDATE_NOTE)) + self.assert_entities_equal(Object.get_by_id('https://a/note'), + common.get_object.cache['https://a/note']) + def test_inbox_webmention_discovery_connection_fails(self, mock_head, mock_get, mock_post): mock_get.side_effect = [ diff --git a/tests/test_common.py b/tests/test_common.py index 31d46da..da9f372 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -4,13 +4,14 @@ from unittest import mock from granary import as2 from oauth_dropins.webutil import appengine_config, util +from oauth_dropins.webutil.util import json_dumps, json_loads from oauth_dropins.webutil.testutil import requests_response import requests from werkzeug.exceptions import BadGateway from app import app import common -from models import User +from models import Object, User from . import testutil HTML = requests_response('', headers={ @@ -226,3 +227,39 @@ class CommonTest(testutil.TestCase): resp = common.signed_post('https://first', user=self.user) mock_post.assert_called_once() self.assertEqual(302, resp.status_code) + + @mock.patch('requests.get', return_value=AS2) + def test_get_object_http(self, mock_get): + self.assertEqual(0, Object.query().count()) + + # first time fetches over HTTP + id = 'http://the/id' + got = common.get_object(id) + self.assert_equals(id, got.key.id()) + self.assert_equals(AS2_OBJ, json_loads(got.as2)) + mock_get.assert_has_calls([self.as2_req(id)]) + + # second time is in cache + got.key.delete() + mock_get.reset_mock() + got = common.get_object(id) + self.assert_equals(id, got.key.id()) + self.assert_equals(AS2_OBJ, json_loads(got.as2)) + mock_get.assert_not_called() + + @mock.patch('requests.get') + def test_get_object_datastore(self, mock_get): + id = 'http://the/id' + stored = Object(id=id, as2=json_dumps(AS2_OBJ), as1='{}') + stored.put() + + # first time loads from datastore + got = common.get_object(id) + self.assert_entities_equal(stored, got) + mock_get.assert_not_called() + + # second time is in cache + stored.key.delete() + got = common.get_object(id) + self.assert_entities_equal(stored, got) + mock_get.assert_not_called() diff --git a/tests/testutil.py b/tests/testutil.py index 9234c64..4536e1c 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -29,6 +29,7 @@ class TestCase(unittest.TestCase, testutil.Asserts): super().setUp() app.testing = True cache.clear() + common.get_object.cache.clear() self.client = app.test_client() self.client.__enter__()