# bridgy-fed/tests/testutil.py
"""Common test utility code."""
import contextlib
import copy
from datetime import datetime
import logging
import os
import random
import re
import unittest
from unittest.mock import ANY, call
from urllib.parse import parse_qs, urlencode
import warnings
from arroba import did
import arroba.util
from arroba.util import datetime_to_tid
from bs4 import MarkupResemblesLocatorWarning
import dag_cbor.random
from google.cloud import ndb
from google.protobuf.timestamp_pb2 import Timestamp
from granary import as1, as2
from granary.tests.test_as1 import (
ACTOR,
COMMENT,
MENTION,
NOTE,
)
from oauth_dropins.webutil import flask_util, testutil, util
from oauth_dropins.webutil.appengine_config import ndb_client
from oauth_dropins.webutil import appengine_info
from oauth_dropins.webutil.testutil import requests_response
from oauth_dropins.webutil.util import json_dumps, json_loads
import requests
# other modules are imported _after_ the Fake etc classes are defined so that
# they're in PROTOCOLS when URL routes are registered.
from common import long_to_base64, NDB_CONTEXT_KWARGS, TASKS_LOCATION
import ids
import models
from models import KEY_BITS, Object, PROTOCOLS, Target, User
import protocol
import router
logger = logging.getLogger(__name__)
logging.getLogger('memcache').setLevel(logging.INFO)
ATPROTO_KEY = arroba.util.new_key(2349823483510) # deterministic seed
NOTE = {
**NOTE,
# bare string author id
'author': ACTOR['id'],
}
MENTION = {
**MENTION,
# author object with just id
'author': {'id': ACTOR['id']},
'content': 'a mention',
}
COMMENT = {
**COMMENT,
# full author object
'author': {
**ACTOR,
'displayName': 'Dr. Eve',
},
'content': 'a comment',
}
class Fake(User, protocol.Protocol):
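    """Fake protocol and user class for testing.

    Keeps all state in class-level attributes instead of making network calls:
    ``fetchable`` maps ids to AS1 objects, ``sent`` and ``fetched`` record
    outgoing deliveries and fetches, ``created_for`` and ``usernames`` record
    per-user bookkeeping. :meth:`TestCase.setUp` resets them between tests.
    """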
ABBREV = 'fa'
PHRASE = 'fake-phrase'
CONTENT_TYPE = 'fa/ke'
DEFAULT_ENABLED_PROTOCOLS = ('other',)
HAS_COPIES = True
LOGO_HTML = '<img src="fake-logo">'
DEFAULT_SERVE_USER_PAGES = True
SUPPORTED_AS1_TYPES = frozenset(
tuple(as1.ACTOR_TYPES)
+ tuple(as1.POST_TYPES)
+ tuple(as1.CRUD_VERBS)
+ tuple(as1.VERBS_WITH_OBJECT)
)
# maps string ids to dict AS1 objects that can be fetched
fetchable = {}
# in-order list of (Object, str URL)
sent = []
# in-order lists of ids
fetched = []
created_for = []
# maps str user id to str username
usernames = {}
@ndb.ComputedProperty
def handle(self):
return self.key.id().replace(f'{self.LABEL}:', f'{self.LABEL}:handle:')
def web_url(self):
return f'web:{self.key.id()}'
def id_uri(self):
return f'uri:{self.key.id()}'
@classmethod
def bridged_web_url_for(cls, user, fallback=False):
return f'web:{cls.LABEL}:{user.key.id()}'
@classmethod
def create_for(cls, user):
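        """Records the user id and adds user and profile object copies for this protocol."""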
id = user.key.id()
logger.info(f'{cls.__name__}.create_for {id}')
cls.created_for.append(id)
if not user.get_copy(cls):
copy = Target(uri=ids.translate_user_id(id=id, from_=user, to=cls),
protocol=cls.LABEL)
util.add(user.copies, copy)
user.put()
if user.obj and not user.obj.get_copy(cls):
profile_copy_id = ids.translate_object_id(
id=user.profile_id(), from_=user, to=cls)
user.obj.add('copies', Target(uri=profile_copy_id, protocol=cls.LABEL))
user.obj.put()
@classmethod
def owns_id(cls, id):
if id.startswith('nope') or id == f'{cls.LABEL}:nope':
return False
return ((id.startswith(f'{cls.LABEL}:')
and not id.startswith(f'{cls.LABEL}:handle:'))
or id in cls.fetchable)
@classmethod
def owns_handle(cls, handle, allow_internal=False):
return handle.startswith(f'{cls.LABEL}:handle:')
@classmethod
def handle_to_id(cls, handle):
if handle == f'{cls.LABEL}:handle:nope':
return None
return handle.replace(f'{cls.LABEL}:handle:', f'{cls.LABEL}:')
@classmethod
def is_blocklisted(cls, url, allow_internal=False):
return url.startswith(f'{cls.LABEL}:blocklisted')
@classmethod
def send(to, obj, url, from_user=None, orig_obj_id=None):
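        """Records the delivery in ``sent``; if this protocol has ``HAS_COPIES``,
        also adds a copy id to the delivered object.

        Note: this is a classmethod, so the first argument ``to`` is the
        destination protocol class, ie what would normally be ``cls``.
        """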
logger.info(f'{to.__name__}.send {url} {obj.as1}')
to.sent.append((url, obj.as1))
from_ = PROTOCOLS.get(obj.source_protocol)
if (from_ and from_ != to and to.HAS_COPIES
and obj.type not in ('update', 'delete')):
obj_id = as1.get_object(obj.as1).get('id')
if not obj_id:
return False
if obj.type == 'post':
obj = Object.get_by_id(obj_id)
copy_id = ids.translate_object_id(
id=obj.key.id(), from_=from_, to=to)
util.add(obj.copies, Target(uri=copy_id, protocol=to.LABEL))
obj.put()
return True
@classmethod
def fetch(cls, obj, **kwargs):
id = obj.key.id()
logger.info(f'{cls.__name__}.fetch {id}')
cls.fetched.append(id)
if id in cls.fetchable:
obj.our_as1 = cls.fetchable[id]
return True
return False
@classmethod
def _convert(cls, obj, from_user=None):
logger.info(f'{cls.__name__}.convert {obj.key.id()} {from_user}')
return cls.translate_ids(obj.as1)
@classmethod
def target_for(cls, obj, shared=False):
assert obj.source_protocol in (cls.LABEL, cls.ABBREV, 'ui', None), \
(obj.source_protocol, cls.LABEL)
if shared:
return f'{cls.LABEL}:shared:target'
user_id = ids.normalize_user_id(id=obj.key.id(), proto=cls)
return f'{user_id}:target'
@classmethod
def receive(cls, obj, authed_as=None, **kwargs):
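        """Calls the parent :meth:`receive`, defaulting ``authed_as`` to the activity's owner or id."""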
assert isinstance(obj, Object)
if not authed_as:
authed_as = as1.get_owner(obj.as1) or obj.as1['id']
return super().receive(obj=obj, authed_as=authed_as, **kwargs)
@classmethod
def receive_as1(cls, our_as1, **kwargs):
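        """Builds an :class:`models.Object` from an AS1 dict and passes it to :meth:`receive`."""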
assert isinstance(our_as1, dict)
obj = Object(id=our_as1['id'], our_as1=our_as1, source_protocol=cls.LABEL)
obj.new = kwargs.pop('new', True)
obj.changed = kwargs.pop('changed', None)
return cls.receive(obj, **kwargs)
class OtherFake(Fake):
"""Different class because the same-protocol check special cases Fake."""
LABEL = ABBREV = 'other'
PHRASE = 'other-phrase'
CONTENT_TYPE = 'ot/her'
DEFAULT_ENABLED_PROTOCOLS = ('fake',)
DEFAULT_SERVE_USER_PAGES = False
SUPPORTED_AS1_TYPES = Fake.SUPPORTED_AS1_TYPES - set(('accept',))
SUPPORTS_DMS = True
fetchable = {}
sent = []
fetched = []
created_for = []
usernames = {}
@classmethod
def target_for(cls, obj, shared=False):
"""No shared target."""
return f'{obj.key.id()}:target'
@classmethod
def set_username(cls, user, username):
logger.info(f'Setting custom username for {user.key.id()} in {cls.LABEL} to {username}')
cls.usernames[user.key.id()] = username
class ExplicitFake(Fake):
"""Fake protocol class that isn't default enabled for any other protocols."""
LABEL = ABBREV = 'efake'
PHRASE = 'efake-phrase'
CONTENT_TYPE = 'un/known'
DEFAULT_ENABLED_PROTOCOLS = ()
DEFAULT_SERVE_USER_PAGES = False
SUPPORTS_DMS = True
fetchable = {}
sent = []
fetched = []
created_for = []
usernames = {}
# import other modules that register Flask handlers *after* Fake is defined
models.reset_protocol_properties()
import app
import activitypub
from activitypub import ActivityPub, CONNEG_HEADERS_AS2_HTML
import atproto
from atproto import ATProto
import common
from common import (
LOCAL_DOMAINS,
OTHER_DOMAINS,
PRIMARY_DOMAIN,
PROTOCOL_DOMAINS,
)
from memcache import (
global_cache,
memcache,
pickle_memcache,
)
from web import Web
from flask_app import app
ActivityPub.DEFAULT_ENABLED_PROTOCOLS += ('fake', 'other')
ATProto.DEFAULT_ENABLED_PROTOCOLS += ('fake', 'other')
Web.DEFAULT_ENABLED_PROTOCOLS += ('fake', 'other')
# clear the datastore, then create a global user whose key is reused in
# TestCase.make_user(), since keys are expensive to generate.
requests.post(f'http://{ndb_client.host}/reset')
with ndb_client.context():
global_user = activitypub._INSTANCE_ACTOR = Fake.get_or_create('fake:user')
class TestCase(unittest.TestCase, testutil.Asserts):
maxDiff = None
def setUp(self):
super().setUp()
testutil.suppress_warnings()
appengine_info.APP_ID = 'my-app'
appengine_info.LOCAL_SERVER = False
common.RUN_TASKS_INLINE = True
app.testing = True
did.resolve_handle.cache.clear()
did.resolve_plc.cache.clear()
did.resolve_web.cache.clear()
ids.web_ap_base_domain.cache.clear()
protocol.Protocol.for_id.cache.clear()
protocol.Protocol.for_handle.cache.clear()
User.count_followers.cache.clear()
common.protocol_user_copy_ids.cache_clear()
for cls in ExplicitFake, Fake, OtherFake:
cls.fetchable = {}
cls.sent = []
cls.fetched = []
cls.created_for = []
cls.usernames = {}
# make random test data deterministic
arroba.util._clockid = 17
random.seed(1234567890)
dag_cbor.random.set_options(seed=1234567890)
self.client = app.test_client()
self.client.__enter__()
self.router_client = router.app.test_client()
memcache.clear()
pickle_memcache.clear()
global_cache.clear()
models.get_original_object_key.cache_clear()
models.get_original_user_key.cache_clear()
# clear datastore
requests.post(f'http://{ndb_client.host}/reset')
self.ndb_context = ndb_client.context(**NDB_CONTEXT_KWARGS)
self.ndb_context.__enter__()
util.now = lambda **kwargs: testutil.NOW
# used in make_user()
self.last_make_user_id = 1
self.app_context = app.app_context()
self.app_context.push()
self.request_context = app.test_request_context('/')
self.request_context.push()
# arroba config
os.environ.update({
'APPVIEW_HOST': 'appview.local',
'BGS_HOST': 'bgs.local',
'PDS_HOST': 'pds.local',
'PLC_HOST': 'plc.local',
'MOD_SERVICE_HOST': 'mod.service.local',
'MOD_SERVICE_DID': 'did:mod-service',
'CHAT_HOST': 'chat.local',
'CHAT_DID': 'did:chat',
})
atproto.appview.address = 'https://appview.local'
def tearDown(self):
self.app_context.pop()
self.ndb_context.__exit__(None, None, None)
self.client.__exit__(None, None, None)
super().tearDown()
# this breaks if it's before super().tearDown(). why?!
try:
self.request_context.pop()
except LookupError as e:
# the test probably popped this already
logger.warning('', exc_info=e)
def run(self, result=None):
"""Override to hide stdlib and virtualenv lines in tracebacks.
https://docs.python.org/3.9/library/unittest.html#unittest.TestCase.run
https://docs.python.org/3.9/library/unittest.html#unittest.TestResult
"""
result = super().run(result=result)
def prune(results):
return [
(tc, re.sub(r'\n File ".+/(local|.venv|oauth-dropins|Python.framework)/.+\n.+\n',
'\n', tb))
for tc, tb in results]
result.errors = prune(result.errors)
result.failures = prune(result.failures)
return result
def get(self, url, **kwargs):
"""Adds Cloud tasks header to ``self.client.get``."""
return self._request('get', url, **kwargs)
def post(self, url, **kwargs):
"""Adds Cloud tasks header to ``self.client.post``."""
return self._request('post', url, **kwargs)
def _request(self, fn, url, client=None, **kwargs):
if client is None:
client = (self.router_client
if url.startswith('/queue/') or url.startswith('/cron/')
else self.client)
kwargs.setdefault('headers', {})[flask_util.CLOUD_TASKS_TASK_HEADER] = 'x'
return getattr(client, fn)(url, **kwargs)
def make_user(self, id, cls, **kwargs):
"""Reuse RSA key across Users because generating it is expensive."""
obj_as1 = copy.deepcopy(kwargs.pop('obj_as1', None))
obj_as2 = copy.deepcopy(kwargs.pop('obj_as2', None))
obj_bsky = copy.deepcopy(kwargs.pop('obj_bsky', None))
obj_mf2 = copy.deepcopy(kwargs.pop('obj_mf2', None))
obj_id = copy.deepcopy(kwargs.pop('obj_id', None))
if cls == Web and not obj_mf2:
kwargs.setdefault('last_webmention_in', testutil.NOW)
user = cls(id=id,
mod=global_user.mod,
public_exponent=global_user.public_exponent,
private_exponent=global_user.private_exponent,
**kwargs)
user.obj_key = kwargs.pop('obj_key', None)
if user.obj_key:
assert not (obj_as1 or obj_as2 or obj_bsky or obj_mf2 or obj_id)
elif cls != ATProto or obj_bsky:
if not obj_id:
obj_id = ((obj_as2 or {}).get('id')
or util.get_url((obj_mf2 or {}), 'properties')
or user.profile_id())
user.obj_key = Object.get_or_create(
id=obj_id, authed_as=obj_id, our_as1=obj_as1, as2=obj_as2,
bsky=obj_bsky, mf2=obj_mf2, source_protocol=cls.LABEL).key
user.put()
return user
def add_objects(self):
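        """Stores a standard set of test :class:`models.Object` entities for the ``user.com`` web user."""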
user = ndb.Key(Web, 'user.com')
# post
self.store_object(id='a',
users=[user],
notify=[user],
feed=[user],
our_as1=NOTE)
# post with mention
self.store_object(id='b',
notify=[user],
feed=[user],
our_as1=MENTION)
# reply
self.store_object(id='d',
notify=[user],
feed=[user],
our_as1=COMMENT)
# not feed/notif
self.store_object(id='e',
users=[user],
our_as1=NOTE)
# deleted
self.store_object(id='f',
notify=[user],
feed=[user],
our_as1={**NOTE, 'content': 'deleted!'},
deleted=True)
# different domain
nope = ndb.Key(Web, 'nope.org')
self.store_object(id='g',
notify=[nope],
feed=[nope],
our_as1=MENTION)
# actor whose id is in NOTE.author
self.store_object(id=ACTOR['id'], our_as1=ACTOR)
@staticmethod
def store_object(**kwargs):
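        """Creates, stores, and returns an :class:`models.Object` with the given properties."""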
obj = Object(**kwargs)
obj.put()
return obj
@staticmethod
def random_keys_and_cids(num):
def tid():
ms = random.randint(datetime(2020, 1, 1).timestamp() * 1000,
datetime(2024, 1, 1).timestamp() * 1000)
return datetime_to_tid(datetime.fromtimestamp(float(ms) / 1000))
return [(f'com.example.record/{tid()}', cid)
for cid in dag_cbor.random.rand_cid(num)]
    def random_tid(self):
        ms = random.randint(datetime(2020, 1, 1).timestamp() * 1000,
                            datetime(2024, 1, 1).timestamp() * 1000)
        tid = datetime_to_tid(datetime.fromtimestamp(float(ms) / 1000))
        return f'com.example.record/{tid}'
def get_as2(self, *args, **kwargs):
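        """Like ``self.client.get``, but with the AS2 conneg ``Accept`` header."""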
kwargs.setdefault('headers', {})['Accept'] = CONNEG_HEADERS_AS2_HTML
return self.client.get(*args, **kwargs)
@classmethod
def req(cls, url, **kwargs):
"""Returns a mock requests call."""
kwargs.setdefault('headers', {}).update({
'User-Agent': util.user_agent,
})
kwargs.setdefault('timeout', util.HTTP_TIMEOUT)
kwargs.setdefault('stream', True)
return call(url, **kwargs)
@classmethod
def as2_req(cls, url, **kwargs):
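        """Returns a mock requests call for a signed AS2 request to ``url``."""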
kwargs.setdefault('data', None)
headers = {
'Date': 'Sun, 02 Jan 2022 03:04:05 GMT',
'Host': util.domain_from_link(url, minimize=False),
'Content-Type': as2.CONTENT_TYPE_LD_PROFILE,
'Digest': ANY,
**CONNEG_HEADERS_AS2_HTML,
**kwargs.pop('headers', {}),
}
return cls.req(url, auth=ANY, headers=headers, allow_redirects=False, **kwargs)
@classmethod
def as2_resp(cls, obj):
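        """Returns a ``requests_response`` serving ``obj`` with the AS2 content type."""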
return requests_response(obj, content_type=as2.CONTENT_TYPE)
def assert_req(self, mock, url, **kwargs):
"""Checks a mock requests call."""
kwargs.setdefault('headers', {}).setdefault(
'User-Agent', 'Bridgy Fed (https://fed.brid.gy/)')
kwargs.setdefault('stream', True)
kwargs.setdefault('timeout', util.HTTP_TIMEOUT)
mock.assert_any_call(url, **kwargs)
def assert_object(self, id, **props):
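        """Asserts that an :class:`models.Object` with ``id`` exists and its properties match ``props``.

        ``props['ignore']`` lists property names to skip in the comparison.
        """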
ignore = props.pop('ignore', [])
got = Object.get_by_id(id)
assert got, id
if 'our_as1' in props:
assert 'as2' not in props
assert 'bsky' not in props
assert 'mf2' not in props
ignore.extend(['as2', 'bsky', 'mf2'])
mf2 = props.get('mf2')
if mf2 and 'items' in mf2:
props['mf2'] = mf2['items'][0]
# strip @context
if 'as2' in props:
props['as2'].pop('@context', None)
for field in 'actor', 'object':
val = props['as2'].get(field)
if isinstance(val, dict):
val.pop('@context', None)
type = props.pop('type', None)
if type is not None:
self.assertEqual(type, got.type)
if expected_as1 := props.pop('as1', None):
self.assert_equals(expected_as1, got.as1)
if got.mf2:
got.mf2.pop('url', None)
self.assert_entities_equal(Object(id=id, **props), got,
ignore=['as1', 'created', 'expire',
'type', 'updated'] + ignore)
return got
def assert_user(self, cls, id, ignore=(), **props):
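        """Asserts that a ``cls`` user with ``id`` exists and its properties match ``props``."""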
got = cls.get_by_id(id)
assert got, id
obj_as2 = props.pop('obj_as2', None)
if obj_as2:
self.assert_equals(obj_as2, as2.from_as1(got.obj.as1))
# generated, computed, etc
ignore = ['created', 'mod', 'handle', 'obj_key', 'private_exponent',
'public_exponent', 'status', 'updated'] + list(ignore)
for prop in ignore:
assert prop not in props
self.assert_entities_equal(cls(id=id, **props), got, ignore=ignore)
if cls != ActivityPub:
assert got.mod
assert got.private_exponent
assert got.public_exponent
return got
def assert_task(self, mock_create_task, queue, eta_seconds=None, **params):
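        """Asserts that a Cloud Tasks task was created for ``queue`` with ``params`` as its form-encoded body."""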
# if only one task was created in this queue, extract and decode HTTP
# body so we can pretty-print a comparison
bodies = [body for task_queue, body in self.parse_tasks(mock_create_task)
if task_queue == queue]
if len(bodies) == 1:
self.assertEqual({k: v.decode() if isinstance(v, bytes) else v
for k, v in params.items()},
bodies[0])
# check for exact task
params = [(k, json_dumps(v, sort_keys=True) if isinstance(v, dict) else v)
for k, v in params.items()]
expected = {
'app_engine_http_request': {
'http_method': 'POST',
'relative_uri': f'/queue/{queue}',
'body': urlencode(sorted(params)).encode(),
'headers': {
'Content-Type': 'application/x-www-form-urlencoded',
'traceparent': '',
},
},
}
if eta_seconds:
expected['schedule_time'] = Timestamp(seconds=int(eta_seconds))
mock_create_task.assert_any_call(
parent=f'projects/{appengine_info.APP_ID}/locations/{TASKS_LOCATION}/queues/{queue}',
task=expected,
)
def assert_ap_deliveries(self, mock_post, inboxes, data, from_user=None,
ignore=()):
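        """Asserts that signed ActivityPub deliveries were POSTed to exactly ``inboxes``, each with ``data`` as its JSON body."""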
self.assertEqual(len(inboxes), len(mock_post.call_args_list),
mock_post.call_args_list)
calls = {} # maps inbox URL to JSON data
for args, kwargs in mock_post.call_args_list:
self.assertEqual(as2.CONTENT_TYPE_LD_PROFILE,
kwargs['headers']['Content-Type'])
rsa_key = kwargs['auth'].header_signer._rsa._key
self.assertEqual((from_user or self.user).private_pem(),
rsa_key.exportKey())
calls[args[0]] = json_loads(kwargs['data'])
for inbox in inboxes:
got = calls[inbox]
as1.get_object(got).pop('publicKey', None)
self.assert_equals(data, got, inbox, ignore=ignore)
def parse_tasks(self, mock_create_task):
"""Returns (queue, {param name: value}) tuples, where JSON param values
are parsed."""
tasks = []
for call in mock_create_task.call_args_list:
task = call.kwargs['task']['app_engine_http_request']
queue = task['relative_uri'].removeprefix('/queue/')
body = {name: json_loads(val[0]) if val and val[0][0] == '{' else val[0]
for name, val in parse_qs(task['body'].decode()).items()}
tasks.append((queue, body))
return tasks
def assert_equals(self, expected, actual, msg=None, ignore=(), **kwargs):
return super().assert_equals(
expected, actual, msg=msg, ignore=tuple(ignore) + ('@context',), **kwargs)
@contextlib.contextmanager
def assertLogs(self):
"""Wraps :meth:`unittest.TestCase.assertLogs` and enables/disables logs.
Works around ``oauth_dropins.webutil.tests.__init__``.
"""
orig_disable_level = logging.root.manager.disable
logging.disable(logging.NOTSET)
try:
with super().assertLogs() as logs:
yield logs
finally:
# emit logs that were captured
for record in logs.records:
if record.levelno >= orig_disable_level:
logging.root.handle(record)
logging.disable(orig_disable_level)