move User.get_for_copy/ies to module level, add Object results

pull/701/head
Ryan Barrett 2023-10-26 16:00:03 -07:00
rodzic 5843235fd1
commit cfbfba654e
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
4 zmienionych plików z 60 dodań i 44 usunięć

3
ids.py
Wyświetl plik

@ -5,6 +5,7 @@ https://fed.brid.gy/docs#translate
import re import re
from common import subdomain_wrap, SUPERDOMAIN from common import subdomain_wrap, SUPERDOMAIN
import models
def translate_user_id(*, id, from_proto, to_proto): def translate_user_id(*, id, from_proto, to_proto):
@ -29,7 +30,7 @@ def translate_user_id(*, id, from_proto, to_proto):
user = from_proto.get_by_id(id) user = from_proto.get_by_id(id)
return user.atproto_did if user else None return user.atproto_did if user else None
case ('atproto', _): case ('atproto', _):
user = from_proto.get_for_copy(id) user = models.get_for_copy(id)
return user.key.id() if user else None return user.key.id() if user else None
case (_, 'activitypub'): case (_, 'activitypub'):
return subdomain_wrap(from_proto, f'/ap/{id}') return subdomain_wrap(from_proto, f'/ap/{id}')

Wyświetl plik

@ -197,35 +197,6 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
return user return user
@staticmethod
def get_for_copy(copy_id):
"""Fetches a user with a given id in copies.
Thin wrapper around :meth:`User.get_copies` that returns the first
matching :class:`User`.
"""
users = User.get_for_copies([copy_id])
if users:
return users[0]
@staticmethod
def get_for_copies(copy_ids):
"""Fetches users (across all protocols) for a given set of copies.
Args:
copy_ids (sequence of str)
Returns:
sequence of :class:`User` subclass instances
"""
assert copy_ids
return list(itertools.chain(*(
cls.query(cls.copies.uri.IN(copy_ids))
for cls in set(PROTOCOLS.values()) if cls)))
# TODO: default to looking up copy_ids as key ids, across protocols? is
# that useful anywhere?
@classmethod @classmethod
@ndb.transactional() @ndb.transactional()
def get_or_create(cls, id, propagate=False, **kwargs): def get_or_create(cls, id, propagate=False, **kwargs):
@ -960,9 +931,8 @@ class Object(StringIdModel):
if not ids: if not ids:
return return
origs = (User.get_for_copies(ids)
+ Object.query(Object.copies.uri.IN(ids)).fetch())
origs = get_for_copies(ids)
replaced = False replaced = False
def replace(obj, field): def replace(obj, field):
@ -1165,3 +1135,44 @@ def fetch_page(query, model_class, by=None):
new_before = new_before.isoformat() new_before = new_before.isoformat()
return results, new_before, new_after return results, new_before, new_after
def get_for_copy(copy_id, keys_only=None):
"""Fetches a user or object with a given id in copies.
Thin wrapper around :func:`get_copies` that returns the first
matching result.
Args:
copy_id (str)
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
Returns:
User or Object:
"""
got = get_for_copies([copy_id], keys_only=keys_only)
if got:
return got[0]
def get_for_copies(copy_ids, keys_only=None):
"""Fetches users (across all protocols) for a given set of copies.
Args:
copy_ids (sequence of str)
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
Returns:
sequence of User and/or Object
"""
assert copy_ids
classes = set(cls for cls in PROTOCOLS.values() if cls)
classes.add(Object)
return list(itertools.chain(*(
cls.query(cls.copies.uri.IN(copy_ids)).iter(keys_only=keys_only)
for cls in classes)))
# TODO: default to looking up copy_ids as key ids, across protocols? is
# that useful anywhere?

Wyświetl plik

@ -12,14 +12,14 @@ from google.cloud.ndb import OR
from granary import as1 from granary import as1
from oauth_dropins.webutil.appengine_config import ndb_client from oauth_dropins.webutil.appengine_config import ndb_client
from oauth_dropins.webutil.flask_util import cloud_tasks_only from oauth_dropins.webutil.flask_util import cloud_tasks_only
from oauth_dropins.webutil import util
from oauth_dropins.webutil.util import json_dumps, json_loads
import werkzeug.exceptions import werkzeug.exceptions
import common import common
from common import add, DOMAIN_BLOCKLIST, DOMAINS, error, subdomain_wrap from common import add, DOMAIN_BLOCKLIST, DOMAINS, error, subdomain_wrap
from flask_app import app from flask_app import app
from models import Follower, Object, PROTOCOLS, Target, User from models import Follower, get_for_copies, Object, PROTOCOLS, Target, User
from oauth_dropins.webutil import util
from oauth_dropins.webutil.util import json_dumps, json_loads
SUPPORTED_TYPES = ( SUPPORTED_TYPES = (
'accept', 'accept',
@ -988,8 +988,7 @@ class Protocol:
logger.info(f'Raw targets: {target_uris}') logger.info(f'Raw targets: {target_uris}')
if target_uris: if target_uris:
origs = {u.key.id() for u in User.get_for_copies(target_uris)} | \ origs = {key.id() for key in get_for_copies(target_uris, keys_only=True)}
{o.key.id() for o in Object.query(Object.copies.uri.IN(target_uris))}
if origs: if origs:
target_uris |= origs target_uris |= origs
logger.info(f'Added originals: {origs}') logger.info(f'Added originals: {origs}')

Wyświetl plik

@ -21,6 +21,7 @@ from oauth_dropins.webutil import util
from .testutil import Fake, OtherFake, TestCase from .testutil import Fake, OtherFake, TestCase
from atproto import ATProto from atproto import ATProto
import models
from models import Follower, Object, OBJECT_EXPIRE_AGE, Target, User from models import Follower, Object, OBJECT_EXPIRE_AGE, Target, User
import protocol import protocol
from protocol import Protocol from protocol import Protocol
@ -99,13 +100,6 @@ class UserTest(TestCase):
user.atproto_did = 'did:plc:123' user.atproto_did = 'did:plc:123'
user.atproto_did = None user.atproto_did = None
def test_get_for_copies(self):
self.assertEqual([], User.get_for_copies(['did:plc:foo']))
target = Target(uri='did:plc:foo', protocol='atproto')
fake_user = self.make_user('fake:user', cls=Fake, copies=[target])
self.assertEqual([fake_user], User.get_for_copies(['did:plc:foo']))
def test_get_or_create_use_instead(self): def test_get_or_create_use_instead(self):
user = Fake.get_or_create('a.b') user = Fake.get_or_create('a.b')
user.use_instead = g.user.key user.use_instead = g.user.key
@ -794,6 +788,17 @@ class ObjectTest(TestCase):
}, },
}, obj.our_as1) }, obj.our_as1)
def test_get_for_copies(self):
self.assertEqual([], models.get_for_copies(['foo', 'did:plc:bar']))
obj = self.store_object(id='fake:post',
copies=[Target(uri='other:foo', protocol='other')])
user = self.make_user('other:user', cls=OtherFake,
copies=[Target(uri='fake:bar', protocol='fake')])
self.assert_entities_equal(
[obj, user], models.get_for_copies(['other:foo', 'fake:bar', 'baz']))
class FollowerTest(TestCase): class FollowerTest(TestCase):