User.get_or_create: abstract propagate and create_for across protocols

pull/968/head
Ryan Barrett 2024-04-21 11:27:23 -07:00
rodzic f357ea1698
commit 6b597c90c3
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
6 zmienionych plików z 39 dodań i 10 usunięć

2
ids.py
Wyświetl plik

@ -17,7 +17,7 @@ import models
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Protocols to check User.copies and Object.copies before translating # Protocols to check User.copies and Object.copies before translating
COPIES_PROTOCOLS = ('atproto', 'fake', 'other', 'nostr') COPIES_PROTOCOLS = ('atproto', 'fake', 'other')
# Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for # Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for
# historical compatibility. Loaded on first call to web_ap_subdomain(). # historical compatibility. Loaded on first call to web_ap_subdomain().

Wyświetl plik

@ -249,12 +249,14 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
if not user.obj_key: if not user.obj_key:
user.obj = cls.load(user.profile_id()) user.obj = cls.load(user.profile_id())
ATProto = PROTOCOLS['atproto'] if propagate:
if propagate and cls.LABEL != 'atproto' and not user.get_copy(ATProto): for label in ids.COPIES_PROTOCOLS:
if cls.is_enabled_to(ATProto, user=id): proto = PROTOCOLS[label]
ATProto.create_for(user) if proto != cls and not user.get_copy(proto):
else: if cls.is_enabled_to(proto, user=id):
logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping') proto.create_for(user)
else:
logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping')
# generate keys for all protocols _except_ our own # generate keys for all protocols _except_ our own
# #

Wyświetl plik

@ -441,6 +441,16 @@ class Protocol:
if owner: if owner:
return cls.key_for(owner) return cls.key_for(owner)
@classmethod
def create_for(cls, user):
"""Creates a copy user in this protocol.
Args:
user (models.User): original source user. Shouldn't already have a
copy user for this protocol in ``copies``.
"""
raise NotImplementedError()
@classmethod @classmethod
def send(to_cls, obj, url, from_user=None, orig_obj=None): def send(to_cls, obj, url, from_user=None, orig_obj=None):
"""Sends an outgoing activity. """Sends an outgoing activity.

Wyświetl plik

@ -700,6 +700,7 @@ class ATProtoTest(TestCase):
mock_create_task.assert_called() mock_create_task.assert_called()
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
@patch('google.cloud.dns.client.ManagedZone', autospec=True) @patch('google.cloud.dns.client.ManagedZone', autospec=True)
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
@patch('requests.post', @patch('requests.post',
@ -764,6 +765,7 @@ class ATProtoTest(TestCase):
self.assert_task(mock_create_task, 'atproto-commit', self.assert_task(mock_create_task, 'atproto-commit',
'/queue/atproto-commit') '/queue/atproto-commit')
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
@patch('requests.get', return_value=requests_response( @patch('requests.get', return_value=requests_response(
'blob contents', content_type='image/png')) # image blob fetch 'blob contents', content_type='image/png')) # image blob fetch
@patch('google.cloud.dns.client.ManagedZone', autospec=True) @patch('google.cloud.dns.client.ManagedZone', autospec=True)

Wyświetl plik

@ -66,10 +66,15 @@ class UserTest(TestCase):
user.direct = True user.direct = True
self.assert_entities_equal(same, user, ignore=['updated']) self.assert_entities_equal(same, user, ignore=['updated'])
def test_get_or_create_propagate_fake_other(self):
user = Fake.get_or_create('fake:user', propagate=True)
self.assertEqual(['fake:user'], OtherFake.created_for)
@patch('ids.COPIES_PROTOCOLS', ['fake', 'other', 'atproto'])
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
@patch('requests.post', @patch('requests.post',
return_value=requests_response('OK')) # create DID on PLC return_value=requests_response('OK')) # create DID on PLC
def test_get_or_create_propagate(self, mock_post, mock_create_task): def test_get_or_create_propagate_atproto(self, mock_post, mock_create_task):
common.RUN_TASKS_INLINE = False common.RUN_TASKS_INLINE = False
Fake.fetchable = { Fake.fetchable = {

Wyświetl plik

@ -35,6 +35,7 @@ import requests
# other modules are imported _after_ Fake etc classes is defined so that it's in # other modules are imported _after_ Fake etc classes is defined so that it's in
# PROTOCOLS when URL routes are registered. # PROTOCOLS when URL routes are registered.
from common import long_to_base64, TASKS_LOCATION from common import long_to_base64, TASKS_LOCATION
import ids
import models import models
from models import KEY_BITS, Object, PROTOCOLS, Target, User from models import KEY_BITS, Object, PROTOCOLS, Target, User
import protocol import protocol
@ -76,8 +77,9 @@ class Fake(User, protocol.Protocol):
# in-order list of (Object, str URL) # in-order list of (Object, str URL)
sent = [] sent = []
# in-order list of ids # in-order lists of ids
fetched = [] fetched = []
created_for = []
@ndb.ComputedProperty @ndb.ComputedProperty
def handle(self): def handle(self):
@ -86,6 +88,10 @@ class Fake(User, protocol.Protocol):
def web_url(self): def web_url(self):
return self.key.id() return self.key.id()
@classmethod
def create_for(cls, user):
cls.created_for.append(user.key.id())
@classmethod @classmethod
def owns_id(cls, id): def owns_id(cls, id):
if id.startswith('nope') or id == f'{cls.LABEL}:nope': if id.startswith('nope') or id == f'{cls.LABEL}:nope':
@ -157,6 +163,7 @@ class OtherFake(Fake):
fetchable = {} fetchable = {}
sent = [] sent = []
fetched = [] fetched = []
created_for = []
@classmethod @classmethod
def target_for(cls, obj, shared=False): def target_for(cls, obj, shared=False):
@ -171,6 +178,7 @@ class ExplicitEnableFake(Fake):
fetchable = {} fetchable = {}
sent = [] sent = []
fetched = [] fetched = []
created_for = []
# import other modules that register Flask handlers *after* Fake is defined # import other modules that register Flask handlers *after* Fake is defined
@ -211,13 +219,15 @@ class TestCase(unittest.TestCase, testutil.Asserts):
did.resolve_plc.cache.clear() did.resolve_plc.cache.clear()
did.resolve_web.cache.clear() did.resolve_web.cache.clear()
for cls in Fake, OtherFake: for cls in ExplicitEnableFake, Fake, OtherFake:
cls.fetchable = {} cls.fetchable = {}
cls.sent = [] cls.sent = []
cls.fetched = [] cls.fetched = []
cls.created_for = []
common.OTHER_DOMAINS += ('fake.brid.gy',) common.OTHER_DOMAINS += ('fake.brid.gy',)
common.DOMAINS += ('fake.brid.gy',) common.DOMAINS += ('fake.brid.gy',)
ids.COPIES_PROTOCOLS = ['fake', 'other']
# make random test data deterministic # make random test data deterministic
arroba.util._clockid = 17 arroba.util._clockid = 17