extract memcache logic out into new memcache.py file

pull/1675/head
Ryan Barrett 2025-01-09 16:57:01 -08:00
rodzic c7f4a39b0d
commit e8b201dc33
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
17 zmienionych plików z 277 dodań i 253 usunięć

Wyświetl plik

@ -33,7 +33,6 @@ from common import (
error,
host_url,
LOCAL_DOMAINS,
memcache,
PRIMARY_DOMAIN,
PROTOCOL_DOMAINS,
redirect_wrap,
@ -42,6 +41,7 @@ from common import (
unwrap,
)
from ids import BOT_ACTOR_AP_IDS
import memcache
from models import fetch_objects, Follower, Object, PROTOCOLS, User
from protocol import activity_id_memcache_key, DELETE_TASK_DELAY, Protocol
import webfinger
@ -1113,7 +1113,7 @@ def inbox(protocol=None, id=None):
logger.info(f'{domain} is opted out')
return '', 204
if memcache.get(activity_id_memcache_key(id)):
if memcache.memcache.get(activity_id_memcache_key(id)):
logger.info(f'Already seen {id}')
return '', 204

Wyświetl plik

@ -26,11 +26,7 @@ from oauth_dropins.webutil.util import json_dumps, json_loads
from atproto import ATProto, Cursor
from common import (
cache_policy,
create_task,
global_cache,
global_cache_policy,
global_cache_timeout_policy,
NDB_CONTEXT_KWARGS,
PROTOCOL_DOMAINS,
report_error,

109
common.py
Wyświetl plik

@ -10,12 +10,10 @@ import threading
import urllib.parse
from urllib.parse import urljoin, urlparse
import cachetools
from Crypto.Util import number
from flask import abort, g, has_request_context, make_response, request
from google.cloud.error_reporting.util import build_flask_context
from google.cloud import ndb
from google.cloud.ndb.global_cache import _InProcessGlobalCache, MemcacheCache
from google.cloud.ndb.key import Key
from google.protobuf.timestamp_pb2 import Timestamp
from granary import as2
@ -26,9 +24,8 @@ from oauth_dropins.webutil.appengine_info import DEBUG
from oauth_dropins.webutil import flask_util
from oauth_dropins.webutil.util import json_dumps
from negotiator import ContentNegotiator, AcceptParameters, ContentType
import pymemcache.client.base
from pymemcache.serde import PickleSerde
from pymemcache.test.utils import MockMemcacheClient
import memcache
logger = logging.getLogger(__name__)
@ -102,24 +99,6 @@ OLD_ACCOUNT_AGE = timedelta(days=14)
# populated later in this file
NDB_CONTEXT_KWARGS = None
# https://github.com/memcached/memcached/wiki/Commands#standard-protocol
MEMCACHE_KEY_MAX_LEN = 250
if appengine_info.DEBUG or appengine_info.LOCAL_SERVER:
logger.info('Using in memory mock memcache')
memcache = MockMemcacheClient(allow_unicode_keys=True)
pickle_memcache = MockMemcacheClient(allow_unicode_keys=True, serde=PickleSerde())
global_cache = _InProcessGlobalCache()
else:
logger.info('Using production Memorystore memcache')
memcache = pymemcache.client.base.PooledClient(
os.environ['MEMCACHE_HOST'], timeout=10, connect_timeout=10, # seconds
allow_unicode_keys=True)
pickle_memcache = pymemcache.client.base.PooledClient(
os.environ['MEMCACHE_HOST'], timeout=10, connect_timeout=10, # seconds
serde=PickleSerde(), allow_unicode_keys=True)
global_cache = MemcacheCache(memcache)
_negotiator = ContentNegotiator(acceptable=[
AcceptParameters(ContentType(CONTENT_TYPE_HTML)),
AcceptParameters(ContentType(as2.CONTENT_TYPE)),
@ -289,37 +268,6 @@ def unwrap(val, field=None):
return val
def webmention_endpoint_cache_key(url):
"""Returns cache key for a cached webmention endpoint for a given URL.
Just the domain by default. If the URL is the home page, ie path is ``/``,
the key includes a ``/`` at the end, so that we cache webmention endpoints
for home pages separate from other pages.
https://github.com/snarfed/bridgy/issues/701
Example: ``snarfed.org /``
https://github.com/snarfed/bridgy-fed/issues/423
Adapted from ``bridgy/util.py``.
"""
parsed = urllib.parse.urlparse(url)
key = parsed.netloc
if parsed.path in ('', '/'):
key += ' /'
logger.debug(f'wm cache key {key}')
return key
@cachetools.cached(cachetools.TTLCache(50000, 60 * 60 * 2), # 2h expiration
key=webmention_endpoint_cache_key,
lock=threading.Lock())
def webmention_discover(url, **kwargs):
"""Thin caching wrapper around :func:`oauth_dropins.webutil.webmention.discover`."""
return webmention.discover(url, **kwargs)
def create_task(queue, delay=None, **params):
"""Adds a Cloud Tasks task.
@ -482,63 +430,12 @@ NDB_CONTEXT_KWARGS = {
# limited context-local cache. avoid full one due to this bug:
# https://github.com/googleapis/python-ndb/issues/888
'cache_policy': cache_policy,
'global_cache': global_cache,
'global_cache': memcache.global_cache,
'global_cache_policy': global_cache_policy,
'global_cache_timeout_policy': global_cache_timeout_policy,
}
def memcache_key(key):
"""Preprocesses a memcache key. Right now just truncates it to 250 chars.
https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html
https://github.com/memcached/memcached/wiki/Commands#standard-protocol
TODO: truncate to 250 *UTF-8* chars, to handle Unicode chars in URLs. Related:
pymemcache Client's allow_unicode_keys constructor kwarg.
"""
return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode()
def memcache_memoize_key(fn, *args, **kwargs):
return memcache_key(f'{fn.__name__}-2-{repr(args)}-{repr(kwargs)}')
NONE = () # empty tuple
def memcache_memoize(expire=None, key=None):
"""Memoize function decorator that stores the cached value in memcache.
Args:
expire (timedelta): optional, expiration
key (callable): function that takes the function's (*args, **kwargs) and
returns the cache key to use
"""
if expire:
expire = int(expire.total_seconds())
def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
if key:
cache_key = memcache_memoize_key(fn, key(*args, **kwargs))
else:
cache_key = memcache_memoize_key(fn, *args, **kwargs)
val = pickle_memcache.get(cache_key)
if val is not None:
# logger.debug(f'cache hit {cache_key}')
return None if val == NONE else val
# logger.debug(f'cache miss {cache_key}')
val = fn(*args, **kwargs)
pickle_memcache.set(cache_key, NONE if val is None else val, expire=expire)
return val
return wrapped
return decorator
def as2_request_type():
"""If this request has conneg (ie the ``Accept`` header) for AS2, returns its type.

10
dms.py
Wyświetl plik

@ -6,8 +6,9 @@ from granary import as1
from oauth_dropins.webutil.flask_util import error
from oauth_dropins.webutil import util
from common import create_task, DOMAINS, memcache, memcache_key
from common import create_task, DOMAINS
import ids
import memcache
import models
from models import PROTOCOLS
import protocol
@ -221,10 +222,11 @@ def receive(*, from_user, obj):
attempts_key = f'dm-user-requests-{from_user.LABEL}-{from_user.key.id()}'
# incr leaves existing expiration as is, doesn't change it
# https://stackoverflow.com/a/4084043/186123
attempts = memcache.incr(attempts_key, 1)
attempts = memcache.memcache.incr(attempts_key, 1)
if not attempts:
memcache.add(attempts_key, 1,
expire=int(REQUESTS_LIMIT_EXPIRE.total_seconds()))
memcache.memcache.add(
attempts_key, 1,
expire=int(REQUESTS_LIMIT_EXPIRE.total_seconds()))
elif attempts > REQUESTS_LIMIT_USER:
return reply(f"Sorry, you've hit your limit of {REQUESTS_LIMIT_USER} requests per day. Try again tomorrow!")

Wyświetl plik

@ -1,4 +1,4 @@
"""Main Flask application."""
"""Flask application for frontend ("default") service."""
import json
import logging
from pathlib import Path

83
memcache.py 100644
Wyświetl plik

@ -0,0 +1,83 @@
"""Utilities for caching data in memcache."""
import functools
import logging
from google.cloud.ndb.global_cache import _InProcessGlobalCache, MemcacheCache
from oauth_dropins.webutil import appengine_info
import pymemcache.client.base
from pymemcache.serde import PickleSerde
from pymemcache.test.utils import MockMemcacheClient
logger = logging.getLogger(__name__)
# https://github.com/memcached/memcached/wiki/Commands#standard-protocol
KEY_MAX_LEN = 250
if appengine_info.DEBUG or appengine_info.LOCAL_SERVER:
logger.info('Using in memory mock memcache')
memcache = MockMemcacheClient(allow_unicode_keys=True)
pickle_memcache = MockMemcacheClient(allow_unicode_keys=True, serde=PickleSerde())
global_cache = _InProcessGlobalCache()
else:
logger.info('Using production Memorystore memcache')
memcache = pymemcache.client.base.PooledClient(
os.environ['MEMCACHE_HOST'], timeout=10, connect_timeout=10, # seconds
allow_unicode_keys=True)
pickle_memcache = pymemcache.client.base.PooledClient(
os.environ['MEMCACHE_HOST'], timeout=10, connect_timeout=10, # seconds
serde=PickleSerde(), allow_unicode_keys=True)
global_cache = MemcacheCache(memcache)
def key(key):
"""Preprocesses a memcache key. Right now just truncates it to 250 chars.
https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html
https://github.com/memcached/memcached/wiki/Commands#standard-protocol
TODO: truncate to 250 *UTF-8* chars, to handle Unicode chars in URLs. Related:
pymemcache Client's allow_unicode_keys constructor kwarg.
"""
return key[:KEY_MAX_LEN].replace(' ', '%20').encode()
def memoize_key(fn, *args, **kwargs):
return key(f'{fn.__name__}-2-{repr(args)}-{repr(kwargs)}')
NONE = () # empty tuple
def memoize(expire=None, key=None):
"""Memoize function decorator that stores the cached value in memcache.
Args:
expire (timedelta): optional, expiration
key (callable): function that takes the function's (*args, **kwargs) and
returns the cache key to use
"""
if expire:
expire = int(expire.total_seconds())
def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
if key:
cache_key = memoize_key(fn, key(*args, **kwargs))
else:
cache_key = memoize_key(fn, *args, **kwargs)
val = pickle_memcache.get(cache_key)
if val is not None:
# logger.debug(f'cache hit {cache_key}')
return None if val == NONE else val
# logger.debug(f'cache miss {cache_key}')
val = fn(*args, **kwargs)
pickle_memcache.set(cache_key, NONE if val is None else val, expire=expire)
return val
return wrapped
return decorator

Wyświetl plik

@ -30,12 +30,12 @@ from common import (
base64_to_long,
DOMAIN_RE,
long_to_base64,
memcache_memoize,
OLD_ACCOUNT_AGE,
report_error,
unwrap,
)
import ids
import memcache
# maps string label to Protocol subclass. values are populated by ProtocolUserMeta.
# (we used to wait for ProtocolUserMeta to populate the keys as well, but that was
@ -253,7 +253,7 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
util.add(getattr(self, prop), val)
if prop == 'copies':
common.pickle_memcache.set(common.memcache_memoize_key(
memcache.pickle_memcache.set(memcache.memoize_key(
get_original_user_key, val.uri), self.key)
@classmethod
@ -848,7 +848,7 @@ Welcome to Bridgy Fed! Your account will soon be bridged to {to_proto.PHRASE} at
@cachetools.cached(
cachetools.TTLCache(50000, FOLLOWERS_CACHE_EXPIRATION.total_seconds()),
key=lambda user: user.key.id(), lock=Lock())
@memcache_memoize(key=lambda self: self.key.id(),
@memcache.memoize(key=lambda self: self.key.id(),
expire=FOLLOWERS_CACHE_EXPIRATION)
def count_followers(self):
"""Counts this user's followers and followings.
@ -1120,7 +1120,7 @@ class Object(StringIdModel):
util.add(getattr(self, prop), val)
if prop == 'copies':
common.pickle_memcache.set(common.memcache_memoize_key(
memcache.pickle_memcache.set(memcache.memoize_key(
get_original_object_key, val.uri), self.key)
def remove(self, prop, val):
@ -1684,12 +1684,12 @@ def fetch_page(query, model_class, by=None):
@lru_cache(maxsize=100000)
@memcache_memoize(expire=GET_ORIGINALS_CACHE_EXPIRATION)
@memcache.memoize(expire=GET_ORIGINALS_CACHE_EXPIRATION)
def get_original_object_key(copy_id):
"""Finds the :class:`Object` with a given copy id, if any.
Note that :meth:`Object.add` also updates this function's
:func:`memcache_memoize` cache.
:func:`memcache.memoize` cache.
Args:
copy_id (str)
@ -1703,12 +1703,12 @@ def get_original_object_key(copy_id):
@lru_cache(maxsize=100000)
@memcache_memoize(expire=GET_ORIGINALS_CACHE_EXPIRATION)
@memcache.memoize(expire=GET_ORIGINALS_CACHE_EXPIRATION)
def get_original_user_key(copy_id):
"""Finds the user with a given copy id, if any.
Note that :meth:`User.add` also updates this function's
:func:`memcache_memoize` cache.
:func:`memcache.memoize` cache.
Args:
copy_id (str)
@ -1723,4 +1723,3 @@ def get_original_user_key(copy_id):
if proto and proto.LABEL != 'ui' and not proto.owns_id(copy_id):
if orig := proto.query(proto.copies.uri == copy_id).get(keys_only=True):
return orig

Wyświetl plik

@ -28,6 +28,7 @@ from common import CACHE_CONTROL, DOMAIN_RE
from flask_app import app
from flask import redirect
import ids
import memcache
from models import fetch_objects, fetch_page, Follower, Object, PAGE_SIZE, PROTOCOLS
from protocol import Protocol
@ -445,6 +446,6 @@ def memcache_command():
if request.headers.get('Authorization') != app.config['SECRET_KEY']:
return '', 401
resp = common.memcache.raw_command(request.get_data(as_text=True),
end_tokens='END\r\n')
resp = memcache.memcache.raw_command(request.get_data(as_text=True),
end_tokens='END\r\n')
return resp.decode(), {'Content-Type': 'text/plain'}

Wyświetl plik

@ -41,6 +41,7 @@ from ids import (
translate_object_id,
translate_user_id,
)
import memcache
from models import (
DM,
Follower,
@ -77,7 +78,7 @@ werkzeug.exceptions._aborter.mapping.setdefault(299, ErrorButDoNotRetryTask)
def activity_id_memcache_key(id):
return common.memcache_key(f'receive-{id}')
return memcache.key(f'receive-{id}')
class Protocol:
@ -839,7 +840,7 @@ class Protocol:
# lease this object, atomically
memcache_key = activity_id_memcache_key(id)
leased = common.memcache.add(memcache_key, 'leased', noreply=False,
leased = memcache.memcache.add(memcache_key, 'leased', noreply=False,
expire=5 * 60) # 5 min
# short circuit if we've already seen this activity id.
# (don't do this for bare objects since we need to check further down
@ -1037,7 +1038,7 @@ class Protocol:
f.status = 'inactive'
ndb.put_multi(followers)
common.memcache.set(memcache_key, 'done', expire=7 * 24 * 60 * 60) # 1w
memcache.memcache.set(memcache_key, 'done', expire=7 * 24 * 60 * 60) # 1w
return resp
@classmethod

Wyświetl plik

@ -34,6 +34,7 @@ from activitypub import (
from atproto import ATProto
import common
from flask_app import app
import memcache
from models import Follower, Object, Target
import protocol
from protocol import DELETE_TASK_DELAY
@ -1612,7 +1613,7 @@ class ActivityPubTest(TestCase):
self.assertEqual(202, got.status_code, got.text)
self.assertIn('Ignoring LD Signature', got.text)
self.assertIsNone(Object.get_by_id('http://inst/post'))
self.assertIsNone(common.memcache.get('receive-http://inst/post'))
self.assertIsNone(memcache.memcache.get('receive-http://inst/post'))
def test_inbox_http_sig_is_not_actor_author(self, mock_head, mock_get, mock_post):

Wyświetl plik

@ -2,7 +2,6 @@
from unittest.mock import Mock, patch
import flask
from google.cloud.ndb import Key
from granary import as2
from oauth_dropins.webutil.appengine_config import error_reporting_client
@ -165,100 +164,6 @@ class CommonTest(TestCase):
mock_client.report.assert_called_with('foo', http_context=None, bar='baz')
@patch('common.MEMCACHE_KEY_MAX_LEN', new=10)
def test_memcache_key(self):
for input, expected in (
('foo', b'foo'),
('foo-bar-baz', b'foo-bar-ba'),
('foo bar', b'foo%20bar'),
('☃.net', b'\xe2\x98\x83.net'),
):
self.assertEqual(expected, common.memcache_key(input))
def test_memcache_memoize_int(self):
calls = []
@common.memcache_memoize()
def foo(x, y, z=None):
calls.append((x, y, z))
return len(calls)
self.assertEqual(1, foo(1, 'a', z=1))
self.assertEqual([(1, 'a', 1)], calls)
self.assertEqual(1, foo(1, 'a', z=1))
self.assertEqual([(1, 'a', 1)], calls)
self.assertEqual(2, foo(2, 'b', z=2))
self.assertEqual([(1, 'a', 1), (2, 'b', 2)], calls)
self.assertEqual(1, foo(1, 'a', z=1))
self.assertEqual(2, foo(2, 'b', z=2))
self.assertEqual([(1, 'a', 1), (2, 'b', 2)], calls)
def test_memcache_memoize_str(self):
calls = []
@common.memcache_memoize()
def foo(x):
calls.append(x)
return str(x)
self.assertEqual('1', foo(1))
self.assertEqual([1], calls)
self.assertEqual('1', foo(1))
self.assertEqual([1], calls)
def test_memcache_memoize_Key(self):
calls = []
@common.memcache_memoize()
def foo(x):
calls.append(x)
return Key(Object, x)
a = Key(Object, 'a')
self.assertEqual(a, foo('a'))
self.assertEqual(['a'], calls)
self.assertEqual(a, foo('a'))
self.assertEqual(['a'], calls)
b = Key(Object, 'b')
self.assertEqual(b, foo('b'))
self.assertEqual(['a', 'b'], calls)
self.assertEqual(a, foo('a'))
self.assertEqual(['a', 'b'], calls)
self.assertEqual(b, foo('b'))
self.assertEqual(['a', 'b'], calls)
def test_memcache_memoize_None(self):
calls = []
@common.memcache_memoize()
def foo(x):
calls.append(x)
return None
self.assertIsNone(foo('a'))
self.assertEqual(['a'], calls)
self.assertIsNone(foo('a'))
self.assertEqual(['a'], calls)
def test_memcache_memoize_key_fn(self):
calls = []
@common.memcache_memoize(key=lambda x: x + 1)
def foo(x):
calls.append(x)
return str(x)
self.assertEqual('5', foo(5))
self.assertEqual([5], calls)
self.assertIsNone(common.pickle_memcache.get(b'foo-2-(5,)-{}'))
self.assertEqual('5', common.pickle_memcache.get('foo-2-(6,)-{}'))
self.assertEqual('5', foo(5))
self.assertEqual([5], calls)
def test_as2_request_type(self):
for accept, expected in (
(as2.CONTENT_TYPE_LD_PROFILE, as2.CONTENT_TYPE_LD_PROFILE),

Wyświetl plik

@ -2,10 +2,10 @@
from unittest import mock
from atproto import ATProto
from common import memcache
import dms
from dms import maybe_send, receive
import ids
from common import memcache
from models import DM, Follower, Object, Target
from web import Web
@ -336,7 +336,8 @@ class DmsTest(TestCase):
self.assert_sent(ExplicitFake, [bob, eve], 'request_bridging',
ALICE_REQUEST_CONTENT)
self.assertEqual(2, memcache.get('dm-user-requests-efake-efake:alice'))
self.assertEqual(2, memcache.memcache.get(
'dm-user-requests-efake-efake:alice'))
# over the limit
OtherFake.sent = []
@ -348,7 +349,8 @@ class DmsTest(TestCase):
self.assertEqual(('OK', 200), receive(from_user=alice, obj=obj))
self.assertEqual([], OtherFake.sent)
self.assert_replied(OtherFake, alice, '?', "Sorry, you've hit your limit of 2 requests per day. Try again tomorrow!")
self.assertEqual(3, memcache.get('dm-user-requests-efake-efake:alice'))
self.assertEqual(3, memcache.memcache.get(
'dm-user-requests-efake-efake:alice'))
def test_receive_prompt_wrong_protocol(self):
self.make_user(id='other.brid.gy', cls=Web)

Wyświetl plik

@ -0,0 +1,105 @@
"""Unit tests for memcache.py."""
from unittest.mock import patch
from google.cloud.ndb import Key
import memcache
from memcache import memoize, pickle_memcache
from models import Object
from .testutil import TestCase
class MemcacheTest(TestCase):
def test_memoize_int(self):
calls = []
@memoize()
def foo(x, y, z=None):
calls.append((x, y, z))
return len(calls)
self.assertEqual(1, foo(1, 'a', z=1))
self.assertEqual([(1, 'a', 1)], calls)
self.assertEqual(1, foo(1, 'a', z=1))
self.assertEqual([(1, 'a', 1)], calls)
self.assertEqual(2, foo(2, 'b', z=2))
self.assertEqual([(1, 'a', 1), (2, 'b', 2)], calls)
self.assertEqual(1, foo(1, 'a', z=1))
self.assertEqual(2, foo(2, 'b', z=2))
self.assertEqual([(1, 'a', 1), (2, 'b', 2)], calls)
def test_memoize_str(self):
calls = []
@memoize()
def foo(x):
calls.append(x)
return str(x)
self.assertEqual('1', foo(1))
self.assertEqual([1], calls)
self.assertEqual('1', foo(1))
self.assertEqual([1], calls)
def test_memoize_Key(self):
calls = []
@memoize()
def foo(x):
calls.append(x)
return Key(Object, x)
a = Key(Object, 'a')
self.assertEqual(a, foo('a'))
self.assertEqual(['a'], calls)
self.assertEqual(a, foo('a'))
self.assertEqual(['a'], calls)
b = Key(Object, 'b')
self.assertEqual(b, foo('b'))
self.assertEqual(['a', 'b'], calls)
self.assertEqual(a, foo('a'))
self.assertEqual(['a', 'b'], calls)
self.assertEqual(b, foo('b'))
self.assertEqual(['a', 'b'], calls)
def test_memoize_None(self):
calls = []
@memoize()
def foo(x):
calls.append(x)
return None
self.assertIsNone(foo('a'))
self.assertEqual(['a'], calls)
self.assertIsNone(foo('a'))
self.assertEqual(['a'], calls)
def test_memoize_key_fn(self):
calls = []
@memoize(key=lambda x: x + 1)
def foo(x):
calls.append(x)
return str(x)
self.assertEqual('5', foo(5))
self.assertEqual([5], calls)
self.assertIsNone(pickle_memcache.get(b'foo-2-(5,)-{}'))
self.assertEqual('5', pickle_memcache.get('foo-2-(6,)-{}'))
self.assertEqual('5', foo(5))
self.assertEqual([5], calls)
@patch('memcache.KEY_MAX_LEN', new=10)
def test_key(self):
for input, expected in (
('foo', b'foo'),
('foo-bar-baz', b'foo-bar-ba'),
('foo bar', b'foo%20bar'),
('☃.net', b'\xe2\x98\x83.net'),
):
self.assertEqual(expected, memcache.key(input))

Wyświetl plik

@ -26,6 +26,7 @@ from .testutil import ExplicitFake, Fake, OtherFake, TestCase
from activitypub import ActivityPub
from atproto import ATProto
import common
import memcache
import models
from models import Follower, Object, OBJECT_EXPIRE_AGE, PROTOCOLS, Target, User
import protocol
@ -433,7 +434,7 @@ class UserTest(TestCase):
self.assertEqual((0, 0), user.count_followers())
# clear both
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
user.count_followers.cache.clear()
self.assertEqual((1, 2), user.count_followers())
@ -531,15 +532,15 @@ class UserTest(TestCase):
self.assertFalse(Web(id='bsky.brid.gy').is_enabled(ATProto))
def test_add_to_copies_updates_memcache(self):
cache_key = common.memcache_memoize_key(
cache_key = memcache.memoize_key(
models.get_original_user_key, 'other:x')
self.assertIsNone(common.pickle_memcache.get(cache_key))
self.assertIsNone(memcache.pickle_memcache.get(cache_key))
user = Fake(id='fake:x')
copy = Target(protocol='other', uri='other:x')
user.add('copies', copy)
self.assertEqual(user.key, common.pickle_memcache.get(cache_key))
self.assertEqual(user.key, memcache.pickle_memcache.get(cache_key))
class ObjectTest(TestCase):
@ -1019,7 +1020,7 @@ class ObjectTest(TestCase):
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
# matching copy users
self.make_user('other:alice', cls=OtherFake,
@ -1058,7 +1059,7 @@ class ObjectTest(TestCase):
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
# matching copies
self.make_user('other:alice', cls=OtherFake,
@ -1100,7 +1101,7 @@ class ObjectTest(TestCase):
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
# matching copies
self.store_object(id='other:a',
@ -1214,7 +1215,7 @@ class ObjectTest(TestCase):
def test_get_original_user_key(self):
self.assertIsNone(models.get_original_user_key('other:user'))
models.get_original_user_key.cache_clear()
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
user = self.make_user('fake:user', cls=Fake,
copies=[Target(uri='other:user', protocol='other')])
self.assertEqual(user.key, models.get_original_user_key('other:user'))
@ -1222,7 +1223,7 @@ class ObjectTest(TestCase):
def test_get_original_object_key(self):
self.assertIsNone(models.get_original_object_key('other:post'))
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
obj = self.store_object(id='fake:post',
copies=[Target(uri='other:post', protocol='other')])
self.assertEqual(obj.key, models.get_original_object_key('other:post'))
@ -1241,15 +1242,15 @@ class ObjectTest(TestCase):
self.assertEqual('fake:foo', obj.get_copy(Fake))
def test_add_to_copies_updates_memcache(self):
cache_key = common.memcache_memoize_key(
cache_key = memcache.memoize_key(
models.get_original_object_key, 'other:x')
self.assertIsNone(common.pickle_memcache.get(cache_key))
self.assertIsNone(memcache.pickle_memcache.get(cache_key))
obj = Object(id='x')
copy = Target(protocol='other', uri='other:x')
obj.add('copies', copy)
self.assertEqual(obj.key, common.pickle_memcache.get(cache_key))
self.assertEqual(obj.key, memcache.pickle_memcache.get(cache_key))
class FollowerTest(TestCase):

Wyświetl plik

@ -26,6 +26,7 @@ from activitypub import ActivityPub
from app import app
from atproto import ATProto
import common
import memcache
import models
from models import DM, Follower, Object, PROTOCOLS, Target, User
import protocol
@ -2757,10 +2758,10 @@ class ProtocolReceiveTest(TestCase):
self.alice.copies = [Target(uri='fake:alice', protocol='fake')]
self.alice.put()
common.memcache.clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.memcache.clear()
memcache.pickle_memcache.clear()
obj.new = True
Fake.fetchable = {
@ -2792,10 +2793,10 @@ class ProtocolReceiveTest(TestCase):
self.store_object(id='other:post',
copies=[Target(uri='fake:post', protocol='fake')])
common.memcache.clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.memcache.clear()
memcache.pickle_memcache.clear()
obj.new = True
_, code = Fake.receive(obj, authed_as='fake:user')
@ -2847,7 +2848,7 @@ class ProtocolReceiveTest(TestCase):
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
memcache.pickle_memcache.clear()
obj.new = True
self.assertEqual(('OK', 202), Fake.receive(obj, authed_as='fake:user'))

Wyświetl plik

@ -259,14 +259,16 @@ import atproto
from atproto import ATProto
import common
from common import (
global_cache,
LOCAL_DOMAINS,
memcache,
pickle_memcache,
OTHER_DOMAINS,
PRIMARY_DOMAIN,
PROTOCOL_DOMAINS,
)
from memcache import (
global_cache,
memcache,
pickle_memcache,
)
from web import Web
from flask_app import app
@ -293,7 +295,6 @@ class TestCase(unittest.TestCase, testutil.Asserts):
common.RUN_TASKS_INLINE = True
app.testing = True
common.webmention_discover.cache.clear()
did.resolve_handle.cache.clear()
did.resolve_plc.cache.clear()
did.resolve_web.cache.clear()
@ -325,7 +326,6 @@ class TestCase(unittest.TestCase, testutil.Asserts):
global_cache.clear()
models.get_original_object_key.cache_clear()
models.get_original_user_key.cache_clear()
common.pickle_memcache.clear()
activitypub.WEB_OPT_OUT_DOMAINS = set()
# clear datastore

32
web.py
Wyświetl plik

@ -35,6 +35,7 @@ from common import (
)
from flask_app import app
from ids import normalize_user_id, translate_object_id, translate_user_id
import memcache
from models import Follower, Object, PROTOCOLS, Target, User
from protocol import Protocol
@ -485,7 +486,7 @@ class Web(User, Protocol):
# we only send webmentions for responses. for sending normal posts etc
# to followers, we just update our stored objects (elsewhere) and web
# users consume them via feeds.
endpoint = common.webmention_discover(url).endpoint
endpoint = webmention_discover(url).endpoint
if not endpoint:
return False
@ -1005,3 +1006,32 @@ def webmention_task():
except ValueError as e:
logger.warning(e, exc_info=True)
error(e, status=304)
def webmention_endpoint_cache_key(url):
"""Returns cache key for a cached webmention endpoint for a given URL.
Just the domain by default. If the URL is the home page, ie path is ``/``,
the key includes a ``/`` at the end, so that we cache webmention endpoints
for home pages separate from other pages.
https://github.com/snarfed/bridgy/issues/701
Example: ``snarfed.org /``
https://github.com/snarfed/bridgy-fed/issues/423
Adapted from ``bridgy/util.py``.
"""
parsed = urllib.parse.urlparse(url)
key = parsed.netloc
if parsed.path in ('', '/'):
key += ' /'
logger.debug(f'wm cache key {key}')
return key
@memcache.memoize(expire=timedelta(hours=2), key=webmention_endpoint_cache_key)
def webmention_discover(url, **kwargs):
"""Thin caching wrapper around :func:`oauth_dropins.webutil.webmention.discover`."""
return webmention.discover(url, **kwargs)