kopia lustrzana https://github.com/snarfed/bridgy-fed
User.get_or_create: abstract propagate and create_for across protocols
rodzic
f357ea1698
commit
6b597c90c3
2
ids.py
2
ids.py
|
@ -17,7 +17,7 @@ import models
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Protocols to check User.copies and Object.copies before translating
|
# Protocols to check User.copies and Object.copies before translating
|
||||||
COPIES_PROTOCOLS = ('atproto', 'fake', 'other', 'nostr')
|
COPIES_PROTOCOLS = ('atproto', 'fake', 'other')
|
||||||
|
|
||||||
# Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for
|
# Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for
|
||||||
# historical compatibility. Loaded on first call to web_ap_subdomain().
|
# historical compatibility. Loaded on first call to web_ap_subdomain().
|
||||||
|
|
14
models.py
14
models.py
|
@ -249,12 +249,14 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
||||||
if not user.obj_key:
|
if not user.obj_key:
|
||||||
user.obj = cls.load(user.profile_id())
|
user.obj = cls.load(user.profile_id())
|
||||||
|
|
||||||
ATProto = PROTOCOLS['atproto']
|
if propagate:
|
||||||
if propagate and cls.LABEL != 'atproto' and not user.get_copy(ATProto):
|
for label in ids.COPIES_PROTOCOLS:
|
||||||
if cls.is_enabled_to(ATProto, user=id):
|
proto = PROTOCOLS[label]
|
||||||
ATProto.create_for(user)
|
if proto != cls and not user.get_copy(proto):
|
||||||
else:
|
if cls.is_enabled_to(proto, user=id):
|
||||||
logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping')
|
proto.create_for(user)
|
||||||
|
else:
|
||||||
|
logger.info(f'{cls.LABEL} <=> atproto not enabled, skipping')
|
||||||
|
|
||||||
# generate keys for all protocols _except_ our own
|
# generate keys for all protocols _except_ our own
|
||||||
#
|
#
|
||||||
|
|
10
protocol.py
10
protocol.py
|
@ -441,6 +441,16 @@ class Protocol:
|
||||||
if owner:
|
if owner:
|
||||||
return cls.key_for(owner)
|
return cls.key_for(owner)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_for(cls, user):
|
||||||
|
"""Creates a copy user in this protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user (models.User): original source user. Shouldn't already have a
|
||||||
|
copy user for this protocol in ``copies``.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def send(to_cls, obj, url, from_user=None, orig_obj=None):
|
def send(to_cls, obj, url, from_user=None, orig_obj=None):
|
||||||
"""Sends an outgoing activity.
|
"""Sends an outgoing activity.
|
||||||
|
|
|
@ -700,6 +700,7 @@ class ATProtoTest(TestCase):
|
||||||
|
|
||||||
mock_create_task.assert_called()
|
mock_create_task.assert_called()
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
|
||||||
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
||||||
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
||||||
@patch('requests.post',
|
@patch('requests.post',
|
||||||
|
@ -764,6 +765,7 @@ class ATProtoTest(TestCase):
|
||||||
self.assert_task(mock_create_task, 'atproto-commit',
|
self.assert_task(mock_create_task, 'atproto-commit',
|
||||||
'/queue/atproto-commit')
|
'/queue/atproto-commit')
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
|
||||||
@patch('requests.get', return_value=requests_response(
|
@patch('requests.get', return_value=requests_response(
|
||||||
'blob contents', content_type='image/png')) # image blob fetch
|
'blob contents', content_type='image/png')) # image blob fetch
|
||||||
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
||||||
|
|
|
@ -66,10 +66,15 @@ class UserTest(TestCase):
|
||||||
user.direct = True
|
user.direct = True
|
||||||
self.assert_entities_equal(same, user, ignore=['updated'])
|
self.assert_entities_equal(same, user, ignore=['updated'])
|
||||||
|
|
||||||
|
def test_get_or_create_propagate_fake_other(self):
|
||||||
|
user = Fake.get_or_create('fake:user', propagate=True)
|
||||||
|
self.assertEqual(['fake:user'], OtherFake.created_for)
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['fake', 'other', 'atproto'])
|
||||||
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
||||||
@patch('requests.post',
|
@patch('requests.post',
|
||||||
return_value=requests_response('OK')) # create DID on PLC
|
return_value=requests_response('OK')) # create DID on PLC
|
||||||
def test_get_or_create_propagate(self, mock_post, mock_create_task):
|
def test_get_or_create_propagate_atproto(self, mock_post, mock_create_task):
|
||||||
common.RUN_TASKS_INLINE = False
|
common.RUN_TASKS_INLINE = False
|
||||||
|
|
||||||
Fake.fetchable = {
|
Fake.fetchable = {
|
||||||
|
|
|
@ -35,6 +35,7 @@ import requests
|
||||||
# other modules are imported _after_ Fake etc classes is defined so that it's in
|
# other modules are imported _after_ Fake etc classes is defined so that it's in
|
||||||
# PROTOCOLS when URL routes are registered.
|
# PROTOCOLS when URL routes are registered.
|
||||||
from common import long_to_base64, TASKS_LOCATION
|
from common import long_to_base64, TASKS_LOCATION
|
||||||
|
import ids
|
||||||
import models
|
import models
|
||||||
from models import KEY_BITS, Object, PROTOCOLS, Target, User
|
from models import KEY_BITS, Object, PROTOCOLS, Target, User
|
||||||
import protocol
|
import protocol
|
||||||
|
@ -76,8 +77,9 @@ class Fake(User, protocol.Protocol):
|
||||||
# in-order list of (Object, str URL)
|
# in-order list of (Object, str URL)
|
||||||
sent = []
|
sent = []
|
||||||
|
|
||||||
# in-order list of ids
|
# in-order lists of ids
|
||||||
fetched = []
|
fetched = []
|
||||||
|
created_for = []
|
||||||
|
|
||||||
@ndb.ComputedProperty
|
@ndb.ComputedProperty
|
||||||
def handle(self):
|
def handle(self):
|
||||||
|
@ -86,6 +88,10 @@ class Fake(User, protocol.Protocol):
|
||||||
def web_url(self):
|
def web_url(self):
|
||||||
return self.key.id()
|
return self.key.id()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_for(cls, user):
|
||||||
|
cls.created_for.append(user.key.id())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_id(cls, id):
|
def owns_id(cls, id):
|
||||||
if id.startswith('nope') or id == f'{cls.LABEL}:nope':
|
if id.startswith('nope') or id == f'{cls.LABEL}:nope':
|
||||||
|
@ -157,6 +163,7 @@ class OtherFake(Fake):
|
||||||
fetchable = {}
|
fetchable = {}
|
||||||
sent = []
|
sent = []
|
||||||
fetched = []
|
fetched = []
|
||||||
|
created_for = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def target_for(cls, obj, shared=False):
|
def target_for(cls, obj, shared=False):
|
||||||
|
@ -171,6 +178,7 @@ class ExplicitEnableFake(Fake):
|
||||||
fetchable = {}
|
fetchable = {}
|
||||||
sent = []
|
sent = []
|
||||||
fetched = []
|
fetched = []
|
||||||
|
created_for = []
|
||||||
|
|
||||||
|
|
||||||
# import other modules that register Flask handlers *after* Fake is defined
|
# import other modules that register Flask handlers *after* Fake is defined
|
||||||
|
@ -211,13 +219,15 @@ class TestCase(unittest.TestCase, testutil.Asserts):
|
||||||
did.resolve_plc.cache.clear()
|
did.resolve_plc.cache.clear()
|
||||||
did.resolve_web.cache.clear()
|
did.resolve_web.cache.clear()
|
||||||
|
|
||||||
for cls in Fake, OtherFake:
|
for cls in ExplicitEnableFake, Fake, OtherFake:
|
||||||
cls.fetchable = {}
|
cls.fetchable = {}
|
||||||
cls.sent = []
|
cls.sent = []
|
||||||
cls.fetched = []
|
cls.fetched = []
|
||||||
|
cls.created_for = []
|
||||||
|
|
||||||
common.OTHER_DOMAINS += ('fake.brid.gy',)
|
common.OTHER_DOMAINS += ('fake.brid.gy',)
|
||||||
common.DOMAINS += ('fake.brid.gy',)
|
common.DOMAINS += ('fake.brid.gy',)
|
||||||
|
ids.COPIES_PROTOCOLS = ['fake', 'other']
|
||||||
|
|
||||||
# make random test data deterministic
|
# make random test data deterministic
|
||||||
arroba.util._clockid = 17
|
arroba.util._clockid = 17
|
||||||
|
|
Ładowanie…
Reference in New Issue