From cfbfba654e30103cc80140f975589a03f2e97dd2 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Thu, 26 Oct 2023 16:00:03 -0700 Subject: [PATCH] move User.get_for_copy/ies to module level, add Object results --- ids.py | 3 +- models.py | 73 +++++++++++++++++++++++++------------------- protocol.py | 9 +++--- tests/test_models.py | 19 +++++++----- 4 files changed, 60 insertions(+), 44 deletions(-) diff --git a/ids.py b/ids.py index 7384641..9cc70a6 100644 --- a/ids.py +++ b/ids.py @@ -5,6 +5,7 @@ https://fed.brid.gy/docs#translate import re from common import subdomain_wrap, SUPERDOMAIN +import models def translate_user_id(*, id, from_proto, to_proto): @@ -29,7 +30,7 @@ def translate_user_id(*, id, from_proto, to_proto): user = from_proto.get_by_id(id) return user.atproto_did if user else None case ('atproto', _): - user = from_proto.get_for_copy(id) + user = models.get_for_copy(id) return user.key.id() if user else None case (_, 'activitypub'): return subdomain_wrap(from_proto, f'/ap/{id}') diff --git a/models.py b/models.py index 91e805c..fa65b62 100644 --- a/models.py +++ b/models.py @@ -197,35 +197,6 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): return user - @staticmethod - def get_for_copy(copy_id): - """Fetches a user with a given id in copies. - - Thin wrapper around :meth:`User.get_copies` that returns the first - matching :class:`User`. - """ - users = User.get_for_copies([copy_id]) - if users: - return users[0] - - @staticmethod - def get_for_copies(copy_ids): - """Fetches users (across all protocols) for a given set of copies. - - Args: - copy_ids (sequence of str) - - Returns: - sequence of :class:`User` subclass instances - """ - assert copy_ids - return list(itertools.chain(*( - cls.query(cls.copies.uri.IN(copy_ids)) - for cls in set(PROTOCOLS.values()) if cls))) - - # TODO: default to looking up copy_ids as key ids, across protocols? is - # that useful anywhere? - @classmethod @ndb.transactional() def get_or_create(cls, id, propagate=False, **kwargs): @@ -960,9 +931,8 @@ class Object(StringIdModel): if not ids: return - origs = (User.get_for_copies(ids) - + Object.query(Object.copies.uri.IN(ids)).fetch()) + origs = get_for_copies(ids) replaced = False def replace(obj, field): @@ -1165,3 +1135,44 @@ def fetch_page(query, model_class, by=None): new_before = new_before.isoformat() return results, new_before, new_after + + +def get_for_copy(copy_id, keys_only=None): + """Fetches a user or object with a given id in copies. + + Thin wrapper around :func:`get_copies` that returns the first + matching result. + + Args: + copy_id (str) + keys_only (bool): passed through to :class:`google.cloud.ndb.Query` + + Returns: + User or Object: + """ + got = get_for_copies([copy_id], keys_only=keys_only) + if got: + return got[0] + + +def get_for_copies(copy_ids, keys_only=None): + """Fetches users (across all protocols) for a given set of copies. + + Args: + copy_ids (sequence of str) + keys_only (bool): passed through to :class:`google.cloud.ndb.Query` + + Returns: + sequence of User and/or Object + """ + assert copy_ids + + classes = set(cls for cls in PROTOCOLS.values() if cls) + classes.add(Object) + + return list(itertools.chain(*( + cls.query(cls.copies.uri.IN(copy_ids)).iter(keys_only=keys_only) + for cls in classes))) + + # TODO: default to looking up copy_ids as key ids, across protocols? is + # that useful anywhere? diff --git a/protocol.py b/protocol.py index 71fa143..9ddcd7b 100644 --- a/protocol.py +++ b/protocol.py @@ -12,14 +12,14 @@ from google.cloud.ndb import OR from granary import as1 from oauth_dropins.webutil.appengine_config import ndb_client from oauth_dropins.webutil.flask_util import cloud_tasks_only +from oauth_dropins.webutil import util +from oauth_dropins.webutil.util import json_dumps, json_loads import werkzeug.exceptions import common from common import add, DOMAIN_BLOCKLIST, DOMAINS, error, subdomain_wrap from flask_app import app -from models import Follower, Object, PROTOCOLS, Target, User -from oauth_dropins.webutil import util -from oauth_dropins.webutil.util import json_dumps, json_loads +from models import Follower, get_for_copies, Object, PROTOCOLS, Target, User SUPPORTED_TYPES = ( 'accept', @@ -988,8 +988,7 @@ class Protocol: logger.info(f'Raw targets: {target_uris}') if target_uris: - origs = {u.key.id() for u in User.get_for_copies(target_uris)} | \ - {o.key.id() for o in Object.query(Object.copies.uri.IN(target_uris))} + origs = {key.id() for key in get_for_copies(target_uris, keys_only=True)} if origs: target_uris |= origs logger.info(f'Added originals: {origs}') diff --git a/tests/test_models.py b/tests/test_models.py index 182ed0a..f00a30e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -21,6 +21,7 @@ from oauth_dropins.webutil import util from .testutil import Fake, OtherFake, TestCase from atproto import ATProto +import models from models import Follower, Object, OBJECT_EXPIRE_AGE, Target, User import protocol from protocol import Protocol @@ -99,13 +100,6 @@ class UserTest(TestCase): user.atproto_did = 'did:plc:123' user.atproto_did = None - def test_get_for_copies(self): - self.assertEqual([], User.get_for_copies(['did:plc:foo'])) - - target = Target(uri='did:plc:foo', protocol='atproto') - fake_user = self.make_user('fake:user', cls=Fake, copies=[target]) - self.assertEqual([fake_user], User.get_for_copies(['did:plc:foo'])) - def test_get_or_create_use_instead(self): user = Fake.get_or_create('a.b') user.use_instead = g.user.key @@ -794,6 +788,17 @@ class ObjectTest(TestCase): }, }, obj.our_as1) + def test_get_for_copies(self): + self.assertEqual([], models.get_for_copies(['foo', 'did:plc:bar'])) + + obj = self.store_object(id='fake:post', + copies=[Target(uri='other:foo', protocol='other')]) + user = self.make_user('other:user', cls=OtherFake, + copies=[Target(uri='fake:bar', protocol='fake')]) + + self.assert_entities_equal( + [obj, user], models.get_for_copies(['other:foo', 'fake:bar', 'baz'])) + class FollowerTest(TestCase):