diff --git a/common.py b/common.py index b498128..70ee045 100644 --- a/common.py +++ b/common.py @@ -1,6 +1,7 @@ """Misc common utilities.""" import base64 from datetime import timedelta +import functools import logging from pathlib import Path import re @@ -9,6 +10,7 @@ import urllib.parse from urllib.parse import urljoin, urlparse import cachetools +from cachetools.keys import hashkey from Crypto.Util import number from flask import abort, g, has_request_context, make_response, request from google.cloud.error_reporting.util import build_flask_context @@ -413,3 +415,29 @@ def memcache_key(key): pymemcache Client's allow_unicode_keys constructor kwarg. """ return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode() + + +def memcache_memoize(expire=None): + """Memoize function decorator that stores the cached value in memcache. + + Only caches non-null/empty values. + + Args: + expire (int): optional, expiration in seconds + """ + def decorator(fn): + @functools.wraps(fn) + def wrapped(*args, **kwargs): + key = memcache_key(f'{fn.__name__}-{str(hashkey(*args, **kwargs))}') + if val := memcache.get(key): + logger.debug(f'cache hit {key}') + return val + + logger.debug(f'cache miss {key}') + val = fn(*args, **kwargs) + memcache.set(key, val) + return val + + return wrapped + + return decorator diff --git a/config.py b/config.py index 44bf8fb..3a15778 100644 --- a/config.py +++ b/config.py @@ -29,7 +29,7 @@ else: if logging_client := getattr(appengine_config, 'logging_client'): logging_client.setup_logging(log_level=logging.INFO) - for logger in ('oauth_dropins.webutil.webmention', 'lexrpc'): + for logger in ('common', 'oauth_dropins.webutil.webmention', 'lexrpc'): logging.getLogger(logger).setLevel(logging.DEBUG) os.environ.setdefault('APPVIEW_HOST', 'api.bsky.local') diff --git a/models.py b/models.py index eec1416..38ce86c 100644 --- a/models.py +++ b/models.py @@ -30,6 +30,7 @@ from common import ( base64_to_long, DOMAIN_RE, long_to_base64, + memcache_memoize, OLD_ACCOUNT_AGE, remove, report_error, @@ -1563,6 +1564,7 @@ def get_original(copy_id, keys_only=None): return got[0] +@memcache_memoize(expire=60 * 60 * 24) # 1d def get_originals(copy_ids, keys_only=None): """Fetches users (across all protocols) for a given set of copies. @@ -1577,7 +1579,7 @@ def get_originals(copy_ids, keys_only=None): """ assert copy_ids - classes = set(cls for cls in PROTOCOLS.values() if cls) + classes = set(cls for cls in PROTOCOLS.values() if cls and cls.LABEL != 'ui') classes.add(Object) return list(itertools.chain(*( diff --git a/tests/test_models.py b/tests/test_models.py index 1801cbb..f255998 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1102,6 +1102,13 @@ class ObjectTest(TestCase): user = self.make_user('other:user', cls=OtherFake, copies=[Target(uri='fake:bar', protocol='fake')]) + memcache_key = "get_originals-(['other:foo',%20'fake:bar',%20'baz'],)" + self.assertIsNone(common.memcache.get(memcache_key)) + + self.assert_entities_equal( + [obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz'])) + + self.assertIsNotNone(common.memcache.get(memcache_key)) self.assert_entities_equal( [obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz']))