add common.ENABLED_BRIDGES, check before conversion and /bridge-user

pull/906/head
Ryan Barrett 2024-02-28 10:57:30 -08:00
rodzic 3ef64948e5
commit d2865fdb86
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
11 zmienionych plików z 86 dodań i 7 usunięć

Wyświetl plik

@ -34,7 +34,7 @@ from common import (
subdomain_wrap, subdomain_wrap,
unwrap, unwrap,
) )
from models import fetch_objects, Follower, Object, User from models import fetch_objects, Follower, Object, PROTOCOLS, User
from protocol import Protocol from protocol import Protocol
import webfinger import webfinger
@ -332,6 +332,11 @@ class ActivityPub(User, Protocol):
""" """
if not obj or not obj.as1: if not obj or not obj.as1:
return {} return {}
from_proto = PROTOCOLS.get(obj.source_protocol)
if from_proto and not common.is_enabled(cls, from_proto):
error(f'{cls.LABEL} <=> {from_proto.LABEL} not enabled')
if obj.as2: if obj.as2:
return { return {
# add back @context since we strip it when we store Objects # add back @context since we strip it when we store Objects

Wyświetl plik

@ -389,6 +389,10 @@ class ATProto(User, Protocol):
Returns: Returns:
dict: JSON object dict: JSON object
""" """
from_proto = PROTOCOLS.get(obj.source_protocol)
if from_proto and not common.is_enabled(cls, from_proto):
error(f'{cls.LABEL} <=> {from_proto.LABEL} not enabled')
if obj.bsky: if obj.bsky:
return obj.bsky return obj.bsky

Wyświetl plik

@ -31,6 +31,13 @@ TLD_BLOCKLIST = ('7z', 'asp', 'aspx', 'gif', 'html', 'ico', 'jpg', 'jpeg', 'js',
CONTENT_TYPE_HTML = 'text/html; charset=utf-8' CONTENT_TYPE_HTML = 'text/html; charset=utf-8'
# Protocol pairs that we currently support bridging between. Values must be
# Protocol LABELs. Each pair must be lexicographically sorted!
ENABLED_BRIDGES = frozenset((
('activitypub', 'web'),
('atproto', 'web'),
))
PRIMARY_DOMAIN = 'fed.brid.gy' PRIMARY_DOMAIN = 'fed.brid.gy'
# protocol-specific subdomains are under this "super"domain # protocol-specific subdomains are under this "super"domain
SUPERDOMAIN = '.brid.gy' SUPERDOMAIN = '.brid.gy'
@ -248,6 +255,27 @@ def add(seq, val):
seq.append(val) seq.append(val)
def is_enabled(proto_a, proto_b):
"""Returns True if bridging the two input protocols is enabled, False otherwise.
Args:
proto_a (Protocol subclass)
proto_b (Protocol subclass)
Returns:
bool:
"""
if proto_a == proto_b:
return True
labels = tuple(sorted((proto_a.LABEL, proto_b.LABEL)))
if DEBUG and ('fake' in labels or 'other' in labels):
return True
return labels in ENABLED_BRIDGES
def create_task(queue, delay=None, **params): def create_task(queue, delay=None, **params):
"""Adds a Cloud Tasks task. """Adds a Cloud Tasks task.

Wyświetl plik

@ -240,7 +240,10 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
ATProto = PROTOCOLS['atproto'] ATProto = PROTOCOLS['atproto']
if propagate and cls.LABEL != 'atproto' and not user.get_copy(ATProto): if propagate and cls.LABEL != 'atproto' and not user.get_copy(ATProto):
if common.is_enabled(cls, ATProto):
ATProto.create_for(user) ATProto.create_for(user)
else:
logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping')
# generate keys for all protocols _except_ our own # generate keys for all protocols _except_ our own
# #

Wyświetl plik

@ -16,7 +16,7 @@ from oauth_dropins.webutil.util import domain_from_link, json_dumps, json_loads
from oauth_dropins.webutil import util from oauth_dropins.webutil import util
import requests import requests
from urllib3.exceptions import ReadTimeoutError from urllib3.exceptions import ReadTimeoutError
from werkzeug.exceptions import BadGateway from werkzeug.exceptions import BadGateway, BadRequest
# import first so that Fake is defined before URL routes are registered # import first so that Fake is defined before URL routes are registered
from . import testutil from . import testutil
@ -2198,6 +2198,11 @@ class ActivityPubUtilsTest(TestCase):
'object': ACTOR, 'object': ACTOR,
}, ActivityPub.convert(obj)) }, ActivityPub.convert(obj))
def test_convert_protocols_not_enabled(self):
obj = Object(our_as1={'foo': 'bar'}, source_protocol='atproto')
with self.assertRaises(BadRequest):
ActivityPub.convert(obj)
def test_postprocess_as2_idempotent(self): def test_postprocess_as2_idempotent(self):
for obj in (ACTOR, REPLY_OBJECT, REPLY_OBJECT_WRAPPED, REPLY, for obj in (ACTOR, REPLY_OBJECT, REPLY_OBJECT_WRAPPED, REPLY,
NOTE_OBJECT, NOTE, MENTION_OBJECT, MENTION, LIKE, NOTE_OBJECT, NOTE, MENTION_OBJECT, MENTION, LIKE,

Wyświetl plik

@ -24,12 +24,13 @@ from multiformats import CID
from oauth_dropins.webutil.appengine_config import tasks_client from oauth_dropins.webutil.appengine_config import tasks_client
from oauth_dropins.webutil.testutil import requests_response from oauth_dropins.webutil.testutil import requests_response
from oauth_dropins.webutil.util import json_dumps, json_loads, trim_nulls from oauth_dropins.webutil.util import json_dumps, json_loads, trim_nulls
from werkzeug.exceptions import BadRequest
import atproto import atproto
from atproto import ATProto from atproto import ATProto
import common import common
import hub import hub
from models import Object, Target from models import Object, PROTOCOLS, Target
import protocol import protocol
from .testutil import ATPROTO_KEY, Fake, TestCase from .testutil import ATPROTO_KEY, Fake, TestCase
from . import test_activitypub from . import test_activitypub
@ -381,6 +382,11 @@ class ATProtoTest(TestCase):
'image': [{'url': 'http://my/pic'}], 'image': [{'url': 'http://my/pic'}],
}), fetch_blobs=True)) }), fetch_blobs=True))
def test_convert_protocols_not_enabled(self):
obj = Object(our_as1={'foo': 'bar'}, source_protocol='activitypub')
with self.assertRaises(BadRequest):
ATProto.convert(obj)
@patch('requests.get', return_value=requests_response('', status=404)) @patch('requests.get', return_value=requests_response('', status=404))
def test_web_url(self, mock_get): def test_web_url(self, mock_get):
user = self.make_user('did:plc:user', cls=ATProto) user = self.make_user('did:plc:user', cls=ATProto)

Wyświetl plik

@ -2,8 +2,10 @@
from flask import g from flask import g
# import first so that Fake is defined before URL routes are registered # import first so that Fake is defined before URL routes are registered
from .testutil import Fake, TestCase from .testutil import Fake, OtherFake, TestCase
from activitypub import ActivityPub
from atproto import ATProto
import common import common
from flask_app import app from flask_app import app
from ui import UIProtocol from ui import UIProtocol
@ -99,3 +101,11 @@ class CommonTest(TestCase):
with app.test_request_context(base_url='https://atproto.brid.gy', path='/foo'): with app.test_request_context(base_url='https://atproto.brid.gy', path='/foo'):
self.assertEqual('https://atproto.brid.gy/asdf', common.host_url('asdf')) self.assertEqual('https://atproto.brid.gy/asdf', common.host_url('asdf'))
def test_is_enabled(self):
self.assertTrue(common.is_enabled(Web, ActivityPub))
self.assertTrue(common.is_enabled(ActivityPub, Web))
self.assertTrue(common.is_enabled(ActivityPub, ActivityPub))
self.assertTrue(common.is_enabled(ATProto, Web))
self.assertTrue(common.is_enabled(Fake, OtherFake))
self.assertFalse(common.is_enabled(ATProto, ActivityPub))

Wyświetl plik

@ -26,6 +26,7 @@ class IntegrationTests(TestCase):
@patch('requests.post') @patch('requests.post')
@patch('requests.get') @patch('requests.get')
@patch('common.ENABLED_BRIDGES', new=[('activitypub', 'atproto')])
def test_atproto_notify_reply_to_activitypub(self, mock_get, mock_post): def test_atproto_notify_reply_to_activitypub(self, mock_get, mock_post):
"""ATProto poll notifications, deliver reply to ActivityPub. """ATProto poll notifications, deliver reply to ActivityPub.

Wyświetl plik

@ -1,7 +1,7 @@
"""Unit tests for models.py.""" """Unit tests for models.py."""
from unittest.mock import patch from unittest.mock import patch
from arroba.datastore_storage import AtpRemoteBlob from arroba.datastore_storage import AtpRemoteBlob, AtpRepo
from arroba.mst import dag_cbor_cid from arroba.mst import dag_cbor_cid
import arroba.server import arroba.server
from arroba.util import at_uri from arroba.util import at_uri
@ -20,6 +20,7 @@ from oauth_dropins.webutil import util
# import first so that Fake is defined before URL routes are registered # import first so that Fake is defined before URL routes are registered
from .testutil import Fake, OtherFake, TestCase from .testutil import Fake, OtherFake, TestCase
from activitypub import ActivityPub
from atproto import ATProto from atproto import ATProto
import common import common
import models import models
@ -102,6 +103,20 @@ class UserTest(TestCase):
mock_create_task.assert_called() mock_create_task.assert_called()
@patch.object(tasks_client, 'create_task')
@patch('requests.post')
def test_get_or_create_propagate_not_enabled(self, mock_post, mock_create_task):
user = ActivityPub.get_or_create('https://mas.to/actor', propagate=True)
# self.assertEqual([], Fake.fetched)
mock_post.assert_not_called()
mock_create_task.assert_not_called()
user = ActivityPub.get_by_id('https://mas.to/actor')
self.assertEqual([], user.copies)
self.assertEqual(0, AtpRepo.query().count())
def test_get_or_create_use_instead(self): def test_get_or_create_use_instead(self):
user = Fake.get_or_create('a.b') user = Fake.get_or_create('a.b')
user.use_instead = self.user.key user.use_instead = self.user.key

Wyświetl plik

@ -174,7 +174,6 @@ import common
from web import Web from web import Web
from flask_app import app, cache from flask_app import app, cache
# used in TestCase.make_user() to reuse keys across Users since they're # used in TestCase.make_user() to reuse keys across Users since they're
# expensive to generate. # expensive to generate.
requests.post(f'http://{ndb_client.host}/reset') requests.post(f'http://{ndb_client.host}/reset')

3
web.py
Wyświetl plik

@ -523,6 +523,9 @@ class Web(User, Protocol):
obj_as1 = obj.as1 obj_as1 = obj.as1
from_proto = PROTOCOLS.get(obj.source_protocol) from_proto = PROTOCOLS.get(obj.source_protocol)
if from_proto: if from_proto:
if not common.is_enabled(cls, from_proto):
error(f'{cls.LABEL} <=> {from_proto.LABEL} not enabled')
# fill in author/actor if available # fill in author/actor if available
for field in 'author', 'actor': for field in 'author', 'actor':
val = as1.get_object(obj.as1, field) val = as1.get_object(obj.as1, field)