kopia lustrzana https://github.com/snarfed/bridgy-fed
move User.get_for_copy/ies to module level, add Object results
rodzic
5843235fd1
commit
cfbfba654e
3
ids.py
3
ids.py
|
@ -5,6 +5,7 @@ https://fed.brid.gy/docs#translate
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from common import subdomain_wrap, SUPERDOMAIN
|
from common import subdomain_wrap, SUPERDOMAIN
|
||||||
|
import models
|
||||||
|
|
||||||
|
|
||||||
def translate_user_id(*, id, from_proto, to_proto):
|
def translate_user_id(*, id, from_proto, to_proto):
|
||||||
|
@ -29,7 +30,7 @@ def translate_user_id(*, id, from_proto, to_proto):
|
||||||
user = from_proto.get_by_id(id)
|
user = from_proto.get_by_id(id)
|
||||||
return user.atproto_did if user else None
|
return user.atproto_did if user else None
|
||||||
case ('atproto', _):
|
case ('atproto', _):
|
||||||
user = from_proto.get_for_copy(id)
|
user = models.get_for_copy(id)
|
||||||
return user.key.id() if user else None
|
return user.key.id() if user else None
|
||||||
case (_, 'activitypub'):
|
case (_, 'activitypub'):
|
||||||
return subdomain_wrap(from_proto, f'/ap/{id}')
|
return subdomain_wrap(from_proto, f'/ap/{id}')
|
||||||
|
|
73
models.py
73
models.py
|
@ -197,35 +197,6 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_for_copy(copy_id):
|
|
||||||
"""Fetches a user with a given id in copies.
|
|
||||||
|
|
||||||
Thin wrapper around :meth:`User.get_copies` that returns the first
|
|
||||||
matching :class:`User`.
|
|
||||||
"""
|
|
||||||
users = User.get_for_copies([copy_id])
|
|
||||||
if users:
|
|
||||||
return users[0]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_for_copies(copy_ids):
|
|
||||||
"""Fetches users (across all protocols) for a given set of copies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
copy_ids (sequence of str)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
sequence of :class:`User` subclass instances
|
|
||||||
"""
|
|
||||||
assert copy_ids
|
|
||||||
return list(itertools.chain(*(
|
|
||||||
cls.query(cls.copies.uri.IN(copy_ids))
|
|
||||||
for cls in set(PROTOCOLS.values()) if cls)))
|
|
||||||
|
|
||||||
# TODO: default to looking up copy_ids as key ids, across protocols? is
|
|
||||||
# that useful anywhere?
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ndb.transactional()
|
@ndb.transactional()
|
||||||
def get_or_create(cls, id, propagate=False, **kwargs):
|
def get_or_create(cls, id, propagate=False, **kwargs):
|
||||||
|
@ -960,9 +931,8 @@ class Object(StringIdModel):
|
||||||
if not ids:
|
if not ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
origs = (User.get_for_copies(ids)
|
|
||||||
+ Object.query(Object.copies.uri.IN(ids)).fetch())
|
|
||||||
|
|
||||||
|
origs = get_for_copies(ids)
|
||||||
replaced = False
|
replaced = False
|
||||||
|
|
||||||
def replace(obj, field):
|
def replace(obj, field):
|
||||||
|
@ -1165,3 +1135,44 @@ def fetch_page(query, model_class, by=None):
|
||||||
new_before = new_before.isoformat()
|
new_before = new_before.isoformat()
|
||||||
|
|
||||||
return results, new_before, new_after
|
return results, new_before, new_after
|
||||||
|
|
||||||
|
|
||||||
|
def get_for_copy(copy_id, keys_only=None):
|
||||||
|
"""Fetches a user or object with a given id in copies.
|
||||||
|
|
||||||
|
Thin wrapper around :func:`get_copies` that returns the first
|
||||||
|
matching result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
copy_id (str)
|
||||||
|
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User or Object:
|
||||||
|
"""
|
||||||
|
got = get_for_copies([copy_id], keys_only=keys_only)
|
||||||
|
if got:
|
||||||
|
return got[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_for_copies(copy_ids, keys_only=None):
|
||||||
|
"""Fetches users (across all protocols) for a given set of copies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
copy_ids (sequence of str)
|
||||||
|
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sequence of User and/or Object
|
||||||
|
"""
|
||||||
|
assert copy_ids
|
||||||
|
|
||||||
|
classes = set(cls for cls in PROTOCOLS.values() if cls)
|
||||||
|
classes.add(Object)
|
||||||
|
|
||||||
|
return list(itertools.chain(*(
|
||||||
|
cls.query(cls.copies.uri.IN(copy_ids)).iter(keys_only=keys_only)
|
||||||
|
for cls in classes)))
|
||||||
|
|
||||||
|
# TODO: default to looking up copy_ids as key ids, across protocols? is
|
||||||
|
# that useful anywhere?
|
||||||
|
|
|
@ -12,14 +12,14 @@ from google.cloud.ndb import OR
|
||||||
from granary import as1
|
from granary import as1
|
||||||
from oauth_dropins.webutil.appengine_config import ndb_client
|
from oauth_dropins.webutil.appengine_config import ndb_client
|
||||||
from oauth_dropins.webutil.flask_util import cloud_tasks_only
|
from oauth_dropins.webutil.flask_util import cloud_tasks_only
|
||||||
|
from oauth_dropins.webutil import util
|
||||||
|
from oauth_dropins.webutil.util import json_dumps, json_loads
|
||||||
import werkzeug.exceptions
|
import werkzeug.exceptions
|
||||||
|
|
||||||
import common
|
import common
|
||||||
from common import add, DOMAIN_BLOCKLIST, DOMAINS, error, subdomain_wrap
|
from common import add, DOMAIN_BLOCKLIST, DOMAINS, error, subdomain_wrap
|
||||||
from flask_app import app
|
from flask_app import app
|
||||||
from models import Follower, Object, PROTOCOLS, Target, User
|
from models import Follower, get_for_copies, Object, PROTOCOLS, Target, User
|
||||||
from oauth_dropins.webutil import util
|
|
||||||
from oauth_dropins.webutil.util import json_dumps, json_loads
|
|
||||||
|
|
||||||
SUPPORTED_TYPES = (
|
SUPPORTED_TYPES = (
|
||||||
'accept',
|
'accept',
|
||||||
|
@ -988,8 +988,7 @@ class Protocol:
|
||||||
logger.info(f'Raw targets: {target_uris}')
|
logger.info(f'Raw targets: {target_uris}')
|
||||||
|
|
||||||
if target_uris:
|
if target_uris:
|
||||||
origs = {u.key.id() for u in User.get_for_copies(target_uris)} | \
|
origs = {key.id() for key in get_for_copies(target_uris, keys_only=True)}
|
||||||
{o.key.id() for o in Object.query(Object.copies.uri.IN(target_uris))}
|
|
||||||
if origs:
|
if origs:
|
||||||
target_uris |= origs
|
target_uris |= origs
|
||||||
logger.info(f'Added originals: {origs}')
|
logger.info(f'Added originals: {origs}')
|
||||||
|
|
|
@ -21,6 +21,7 @@ from oauth_dropins.webutil import util
|
||||||
from .testutil import Fake, OtherFake, TestCase
|
from .testutil import Fake, OtherFake, TestCase
|
||||||
|
|
||||||
from atproto import ATProto
|
from atproto import ATProto
|
||||||
|
import models
|
||||||
from models import Follower, Object, OBJECT_EXPIRE_AGE, Target, User
|
from models import Follower, Object, OBJECT_EXPIRE_AGE, Target, User
|
||||||
import protocol
|
import protocol
|
||||||
from protocol import Protocol
|
from protocol import Protocol
|
||||||
|
@ -99,13 +100,6 @@ class UserTest(TestCase):
|
||||||
user.atproto_did = 'did:plc:123'
|
user.atproto_did = 'did:plc:123'
|
||||||
user.atproto_did = None
|
user.atproto_did = None
|
||||||
|
|
||||||
def test_get_for_copies(self):
|
|
||||||
self.assertEqual([], User.get_for_copies(['did:plc:foo']))
|
|
||||||
|
|
||||||
target = Target(uri='did:plc:foo', protocol='atproto')
|
|
||||||
fake_user = self.make_user('fake:user', cls=Fake, copies=[target])
|
|
||||||
self.assertEqual([fake_user], User.get_for_copies(['did:plc:foo']))
|
|
||||||
|
|
||||||
def test_get_or_create_use_instead(self):
|
def test_get_or_create_use_instead(self):
|
||||||
user = Fake.get_or_create('a.b')
|
user = Fake.get_or_create('a.b')
|
||||||
user.use_instead = g.user.key
|
user.use_instead = g.user.key
|
||||||
|
@ -794,6 +788,17 @@ class ObjectTest(TestCase):
|
||||||
},
|
},
|
||||||
}, obj.our_as1)
|
}, obj.our_as1)
|
||||||
|
|
||||||
|
def test_get_for_copies(self):
|
||||||
|
self.assertEqual([], models.get_for_copies(['foo', 'did:plc:bar']))
|
||||||
|
|
||||||
|
obj = self.store_object(id='fake:post',
|
||||||
|
copies=[Target(uri='other:foo', protocol='other')])
|
||||||
|
user = self.make_user('other:user', cls=OtherFake,
|
||||||
|
copies=[Target(uri='fake:bar', protocol='fake')])
|
||||||
|
|
||||||
|
self.assert_entities_equal(
|
||||||
|
[obj, user], models.get_for_copies(['other:foo', 'fake:bar', 'baz']))
|
||||||
|
|
||||||
|
|
||||||
class FollowerTest(TestCase):
|
class FollowerTest(TestCase):
|
||||||
|
|
||||||
|
|
Ładowanie…
Reference in New Issue