cache models.get_originals in memcache with new memcache_memoize decorator

for #1149
pull/1221/head
Ryan Barrett 2024-07-30 14:50:33 -07:00
rodzic 33e0d0b14a
commit 88cbe3b7b4
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
4 zmienionych plików z 39 dodań i 2 usunięć

Wyświetl plik

@ -1,6 +1,7 @@
"""Misc common utilities.""" """Misc common utilities."""
import base64 import base64
from datetime import timedelta from datetime import timedelta
import functools
import logging import logging
from pathlib import Path from pathlib import Path
import re import re
@ -9,6 +10,7 @@ import urllib.parse
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import cachetools import cachetools
from cachetools.keys import hashkey
from Crypto.Util import number from Crypto.Util import number
from flask import abort, g, has_request_context, make_response, request from flask import abort, g, has_request_context, make_response, request
from google.cloud.error_reporting.util import build_flask_context from google.cloud.error_reporting.util import build_flask_context
@ -413,3 +415,29 @@ def memcache_key(key):
pymemcache Client's allow_unicode_keys constructor kwarg. pymemcache Client's allow_unicode_keys constructor kwarg.
""" """
return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode() return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode()
def memcache_memoize(expire=None):
"""Memoize function decorator that stores the cached value in memcache.
Only caches non-null/empty values.
Args:
expire (int): optional, expiration in seconds
"""
def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
key = memcache_key(f'{fn.__name__}-{str(hashkey(*args, **kwargs))}')
if val := memcache.get(key):
logger.debug(f'cache hit {key}')
return val
logger.debug(f'cache miss {key}')
val = fn(*args, **kwargs)
memcache.set(key, val)
return val
return wrapped
return decorator

Wyświetl plik

@ -29,7 +29,7 @@ else:
if logging_client := getattr(appengine_config, 'logging_client'): if logging_client := getattr(appengine_config, 'logging_client'):
logging_client.setup_logging(log_level=logging.INFO) logging_client.setup_logging(log_level=logging.INFO)
for logger in ('oauth_dropins.webutil.webmention', 'lexrpc'): for logger in ('common', 'oauth_dropins.webutil.webmention', 'lexrpc'):
logging.getLogger(logger).setLevel(logging.DEBUG) logging.getLogger(logger).setLevel(logging.DEBUG)
os.environ.setdefault('APPVIEW_HOST', 'api.bsky.local') os.environ.setdefault('APPVIEW_HOST', 'api.bsky.local')

Wyświetl plik

@ -30,6 +30,7 @@ from common import (
base64_to_long, base64_to_long,
DOMAIN_RE, DOMAIN_RE,
long_to_base64, long_to_base64,
memcache_memoize,
OLD_ACCOUNT_AGE, OLD_ACCOUNT_AGE,
remove, remove,
report_error, report_error,
@ -1563,6 +1564,7 @@ def get_original(copy_id, keys_only=None):
return got[0] return got[0]
@memcache_memoize(expire=60 * 60 * 24) # 1d
def get_originals(copy_ids, keys_only=None): def get_originals(copy_ids, keys_only=None):
"""Fetches users (across all protocols) for a given set of copies. """Fetches users (across all protocols) for a given set of copies.
@ -1577,7 +1579,7 @@ def get_originals(copy_ids, keys_only=None):
""" """
assert copy_ids assert copy_ids
classes = set(cls for cls in PROTOCOLS.values() if cls) classes = set(cls for cls in PROTOCOLS.values() if cls and cls.LABEL != 'ui')
classes.add(Object) classes.add(Object)
return list(itertools.chain(*( return list(itertools.chain(*(

Wyświetl plik

@ -1102,6 +1102,13 @@ class ObjectTest(TestCase):
user = self.make_user('other:user', cls=OtherFake, user = self.make_user('other:user', cls=OtherFake,
copies=[Target(uri='fake:bar', protocol='fake')]) copies=[Target(uri='fake:bar', protocol='fake')])
memcache_key = "get_originals-(['other:foo',%20'fake:bar',%20'baz'],)"
self.assertIsNone(common.memcache.get(memcache_key))
self.assert_entities_equal(
[obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz']))
self.assertIsNotNone(common.memcache.get(memcache_key))
self.assert_entities_equal( self.assert_entities_equal(
[obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz'])) [obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz']))