common.memcache_memoize: add key callable kwarg

pull/1628/head
Ryan Barrett 2024-12-13 15:04:46 -08:00
rodzic 3a3c2dd557
commit 4f5e868b05
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
3 zmienionych plików z 43 dodań i 32 usunięć

Wyświetl plik

@ -490,11 +490,13 @@ def memcache_memoize_key(fn, *args, **kwargs):
NONE = () # empty tuple
def memcache_memoize(expire=None):
def memcache_memoize(expire=None, key=None):
"""Memoize function decorator that stores the cached value in memcache.
Args:
expire (timedelta): optional, expiration
key (callable): function that takes the function's (*args, **kwargs) and
returns the cache key to use
"""
if expire:
expire = int(expire.total_seconds())
@ -502,15 +504,19 @@ def memcache_memoize(expire=None):
def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
key = memcache_memoize_key(fn, *args, **kwargs)
val = pickle_memcache.get(key)
if key:
cache_key = memcache_memoize_key(fn, key(*args, **kwargs))
else:
cache_key = memcache_memoize_key(fn, *args, **kwargs)
val = pickle_memcache.get(cache_key)
if val is not None:
# logger.debug(f'cache hit {key}')
# logger.debug(f'cache hit {cache_key}')
return None if val == NONE else val
# logger.debug(f'cache miss {key}')
# logger.debug(f'cache miss {cache_key}')
val = fn(*args, **kwargs)
pickle_memcache.set(key, NONE if val is None else val, expire=expire)
pickle_memcache.set(cache_key, NONE if val is None else val, expire=expire)
return val
return wrapped

Wyświetl plik

@ -9,6 +9,9 @@ git+https://github.com/snarfed/lexrpc.git#egg=lexrpc
git+https://github.com/snarfed/mox3.git#egg=mox3
git+https://github.com/snarfed/negotiator.git@py3#egg=negotiator
git+https://github.com/snarfed/oauth-dropins.git#egg=oauth_dropins
# TODO: switch back to pypi as soon as a new release is cut after 4.0.0
# that includes https://github.com/pinterest/pymemcache/pull/471
git+https://github.com/pinterest/pymemcache.git#egg=pymemcache
attrs==24.2.0
bases==0.3.0
@ -86,7 +89,6 @@ pyasn1-modules==0.4.1
pycparser==2.22
pycryptodome==3.21.0
pyjwt==2.10.1
pymemcache==4.0.0
pyparsing==3.2.0
pyrsistent==0.20.0
python-dateutil==2.9.0.post0

Wyświetl plik

@ -175,7 +175,7 @@ class CommonTest(TestCase):
):
self.assertEqual(expected, common.memcache_key(input))
def test_memcache_memoize(self):
def test_memcache_memoize_int(self):
calls = []
@common.memcache_memoize()
@ -194,32 +194,18 @@ class CommonTest(TestCase):
self.assertEqual(2, foo(2, 'b', z=2))
self.assertEqual([(1, 'a', 1), (2, 'b', 2)], calls)
# def test_memcache_memoize_Object(self):
# calls = []
def test_memcache_memoize_str(self):
calls = []
# obj = Object(users=[Key(Object, 'abc')],
# copies=[Target(uri='abc', protocol='web')],
# as2={'foo': 'x ☕ y', 'bar': True, 'baz': 5})
@common.memcache_memoize()
def foo(x):
calls.append(x)
return str(x)
# @common.memcache_memoize()
# def foo(x):
# calls.append(x)
# obj.key = Key(Object, x)
# return obj
# expected_a = Object(id='a', **obj.to_dict(include=['users', 'copies', 'as2']))
# self.assert_entities_equal(expected_a, foo('a'))
# self.assertEqual(['a'], calls)
# self.assert_entities_equal(expected_a, foo('a'))
# self.assertEqual(['a'], calls)
# expected_b = Object(id='b', **obj.to_dict(include=['users', 'copies', 'as2']))
# self.assert_entities_equal(expected_b, foo('b'))
# self.assertEqual(['a', 'b'], calls)
# self.assert_entities_equal(expected_a, foo('a'))
# self.assertEqual(['a', 'b'], calls)
# self.assert_entities_equal(expected_b, foo('b'))
# self.assertEqual(['a', 'b'], calls)
self.assertEqual('1', foo(1))
self.assertEqual([1], calls)
self.assertEqual('1', foo(1))
self.assertEqual([1], calls)
def test_memcache_memoize_Key(self):
calls = []
@ -256,6 +242,23 @@ class CommonTest(TestCase):
self.assertIsNone(foo('a'))
self.assertEqual(['a'], calls)
def test_memcache_memoize_key_fn(self):
calls = []
@common.memcache_memoize(key=lambda x: x + 1)
def foo(x):
calls.append(x)
return str(x)
self.assertEqual('5', foo(5))
self.assertEqual([5], calls)
self.assertIsNone(common.pickle_memcache.get(b'foo-2-(5,)-{}'))
self.assertEqual('5', common.pickle_memcache.get('foo-2-(6,)-{}'))
self.assertEqual('5', foo(5))
self.assertEqual([5], calls)
def test_as2_request_type(self):
for accept, expected in (
(as2.CONTENT_TYPE_LD_PROFILE, as2.CONTENT_TYPE_LD_PROFILE),