diff --git a/activitypub.py b/activitypub.py index 873d142..703da80 100644 --- a/activitypub.py +++ b/activitypub.py @@ -34,7 +34,7 @@ from common import ( subdomain_wrap, unwrap, ) -from models import fetch_objects, Follower, Object, User +from models import fetch_objects, Follower, Object, PROTOCOLS, User from protocol import Protocol import webfinger @@ -332,6 +332,11 @@ class ActivityPub(User, Protocol): """ if not obj or not obj.as1: return {} + + from_proto = PROTOCOLS.get(obj.source_protocol) + if from_proto and not common.is_enabled(cls, from_proto): + error(f'{cls.LABEL} <=> {from_proto.LABEL} not enabled') + if obj.as2: return { # add back @context since we strip it when we store Objects diff --git a/atproto.py b/atproto.py index dc97f82..d771483 100644 --- a/atproto.py +++ b/atproto.py @@ -389,6 +389,10 @@ class ATProto(User, Protocol): Returns: dict: JSON object """ + from_proto = PROTOCOLS.get(obj.source_protocol) + if from_proto and not common.is_enabled(cls, from_proto): + error(f'{cls.LABEL} <=> {from_proto.LABEL} not enabled') + if obj.bsky: return obj.bsky diff --git a/common.py b/common.py index 76bb0cf..12cfeee 100644 --- a/common.py +++ b/common.py @@ -31,6 +31,13 @@ TLD_BLOCKLIST = ('7z', 'asp', 'aspx', 'gif', 'html', 'ico', 'jpg', 'jpeg', 'js', CONTENT_TYPE_HTML = 'text/html; charset=utf-8' +# Protocol pairs that we currently support bridging between. Values must be +# Protocol LABELs. Each pair must be lexicographically sorted! +ENABLED_BRIDGES = frozenset(( + ('activitypub', 'web'), + ('atproto', 'web'), +)) + PRIMARY_DOMAIN = 'fed.brid.gy' # protocol-specific subdomains are under this "super"domain SUPERDOMAIN = '.brid.gy' @@ -248,6 +255,27 @@ def add(seq, val): seq.append(val) +def is_enabled(proto_a, proto_b): + """Returns True if bridging the two input protocols is enabled, False otherwise. + + Args: + proto_a (Protocol subclass) + proto_b (Protocol subclass) + + Returns: + bool: + """ + if proto_a == proto_b: + return True + + labels = tuple(sorted((proto_a.LABEL, proto_b.LABEL))) + + if DEBUG and ('fake' in labels or 'other' in labels): + return True + + return labels in ENABLED_BRIDGES + + def create_task(queue, delay=None, **params): """Adds a Cloud Tasks task. diff --git a/models.py b/models.py index 3cd5ad0..bb3cd97 100644 --- a/models.py +++ b/models.py @@ -240,7 +240,10 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): ATProto = PROTOCOLS['atproto'] if propagate and cls.LABEL != 'atproto' and not user.get_copy(ATProto): - ATProto.create_for(user) + if common.is_enabled(cls, ATProto): + ATProto.create_for(user) + else: + logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping') # generate keys for all protocols _except_ our own # diff --git a/tests/test_activitypub.py b/tests/test_activitypub.py index 45e2253..a947a0b 100644 --- a/tests/test_activitypub.py +++ b/tests/test_activitypub.py @@ -16,7 +16,7 @@ from oauth_dropins.webutil.util import domain_from_link, json_dumps, json_loads from oauth_dropins.webutil import util import requests from urllib3.exceptions import ReadTimeoutError -from werkzeug.exceptions import BadGateway +from werkzeug.exceptions import BadGateway, BadRequest # import first so that Fake is defined before URL routes are registered from . import testutil @@ -2198,6 +2198,11 @@ class ActivityPubUtilsTest(TestCase): 'object': ACTOR, }, ActivityPub.convert(obj)) + def test_convert_protocols_not_enabled(self): + obj = Object(our_as1={'foo': 'bar'}, source_protocol='atproto') + with self.assertRaises(BadRequest): + ActivityPub.convert(obj) + def test_postprocess_as2_idempotent(self): for obj in (ACTOR, REPLY_OBJECT, REPLY_OBJECT_WRAPPED, REPLY, NOTE_OBJECT, NOTE, MENTION_OBJECT, MENTION, LIKE, diff --git a/tests/test_atproto.py b/tests/test_atproto.py index e522741..29acfca 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -24,12 +24,13 @@ from multiformats import CID from oauth_dropins.webutil.appengine_config import tasks_client from oauth_dropins.webutil.testutil import requests_response from oauth_dropins.webutil.util import json_dumps, json_loads, trim_nulls +from werkzeug.exceptions import BadRequest import atproto from atproto import ATProto import common import hub -from models import Object, Target +from models import Object, PROTOCOLS, Target import protocol from .testutil import ATPROTO_KEY, Fake, TestCase from . import test_activitypub @@ -381,6 +382,11 @@ class ATProtoTest(TestCase): 'image': [{'url': 'http://my/pic'}], }), fetch_blobs=True)) + def test_convert_protocols_not_enabled(self): + obj = Object(our_as1={'foo': 'bar'}, source_protocol='activitypub') + with self.assertRaises(BadRequest): + ATProto.convert(obj) + @patch('requests.get', return_value=requests_response('', status=404)) def test_web_url(self, mock_get): user = self.make_user('did:plc:user', cls=ATProto) diff --git a/tests/test_common.py b/tests/test_common.py index 8b48320..b92e2cb 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2,8 +2,10 @@ from flask import g # import first so that Fake is defined before URL routes are registered -from .testutil import Fake, TestCase +from .testutil import Fake, OtherFake, TestCase +from activitypub import ActivityPub +from atproto import ATProto import common from flask_app import app from ui import UIProtocol @@ -99,3 +101,11 @@ class CommonTest(TestCase): with app.test_request_context(base_url='https://atproto.brid.gy', path='/foo'): self.assertEqual('https://atproto.brid.gy/asdf', common.host_url('asdf')) + + def test_is_enabled(self): + self.assertTrue(common.is_enabled(Web, ActivityPub)) + self.assertTrue(common.is_enabled(ActivityPub, Web)) + self.assertTrue(common.is_enabled(ActivityPub, ActivityPub)) + self.assertTrue(common.is_enabled(ATProto, Web)) + self.assertTrue(common.is_enabled(Fake, OtherFake)) + self.assertFalse(common.is_enabled(ATProto, ActivityPub)) diff --git a/tests/test_integrations.py b/tests/test_integrations.py index 0d20080..e9510af 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -26,6 +26,7 @@ class IntegrationTests(TestCase): @patch('requests.post') @patch('requests.get') + @patch('common.ENABLED_BRIDGES', new=[('activitypub', 'atproto')]) def test_atproto_notify_reply_to_activitypub(self, mock_get, mock_post): """ATProto poll notifications, deliver reply to ActivityPub. diff --git a/tests/test_models.py b/tests/test_models.py index a2350e7..c5794f5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,7 @@ """Unit tests for models.py.""" from unittest.mock import patch -from arroba.datastore_storage import AtpRemoteBlob +from arroba.datastore_storage import AtpRemoteBlob, AtpRepo from arroba.mst import dag_cbor_cid import arroba.server from arroba.util import at_uri @@ -20,6 +20,7 @@ from oauth_dropins.webutil import util # import first so that Fake is defined before URL routes are registered from .testutil import Fake, OtherFake, TestCase +from activitypub import ActivityPub from atproto import ATProto import common import models @@ -102,6 +103,20 @@ class UserTest(TestCase): mock_create_task.assert_called() + @patch.object(tasks_client, 'create_task') + @patch('requests.post') + def test_get_or_create_propagate_not_enabled(self, mock_post, mock_create_task): + user = ActivityPub.get_or_create('https://mas.to/actor', propagate=True) + + # self.assertEqual([], Fake.fetched) + mock_post.assert_not_called() + mock_create_task.assert_not_called() + + user = ActivityPub.get_by_id('https://mas.to/actor') + self.assertEqual([], user.copies) + self.assertEqual(0, AtpRepo.query().count()) + + def test_get_or_create_use_instead(self): user = Fake.get_or_create('a.b') user.use_instead = self.user.key diff --git a/tests/testutil.py b/tests/testutil.py index b73396c..48a6735 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -174,7 +174,6 @@ import common from web import Web from flask_app import app, cache - # used in TestCase.make_user() to reuse keys across Users since they're # expensive to generate. requests.post(f'http://{ndb_client.host}/reset') diff --git a/web.py b/web.py index c482f5a..974ce34 100644 --- a/web.py +++ b/web.py @@ -523,6 +523,9 @@ class Web(User, Protocol): obj_as1 = obj.as1 from_proto = PROTOCOLS.get(obj.source_protocol) if from_proto: + if not common.is_enabled(cls, from_proto): + error(f'{cls.LABEL} <=> {from_proto.LABEL} not enabled') + # fill in author/actor if available for field in 'author', 'actor': val = as1.get_object(obj.as1, field)