move User.get_for_copy/ies to module level, add Object results

pull/701/head
Ryan Barrett 2023-10-26 16:00:03 -07:00
rodzic 5843235fd1
commit cfbfba654e
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
4 zmienionych plików z 60 dodań i 44 usunięć

3
ids.py
Wyświetl plik

@ -5,6 +5,7 @@ https://fed.brid.gy/docs#translate
import re
from common import subdomain_wrap, SUPERDOMAIN
import models
def translate_user_id(*, id, from_proto, to_proto):
@ -29,7 +30,7 @@ def translate_user_id(*, id, from_proto, to_proto):
user = from_proto.get_by_id(id)
return user.atproto_did if user else None
case ('atproto', _):
user = from_proto.get_for_copy(id)
user = models.get_for_copy(id)
return user.key.id() if user else None
case (_, 'activitypub'):
return subdomain_wrap(from_proto, f'/ap/{id}')

Wyświetl plik

@ -197,35 +197,6 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
return user
@staticmethod
def get_for_copy(copy_id):
"""Fetches a user with a given id in copies.
Thin wrapper around :meth:`User.get_copies` that returns the first
matching :class:`User`.
"""
users = User.get_for_copies([copy_id])
if users:
return users[0]
@staticmethod
def get_for_copies(copy_ids):
"""Fetches users (across all protocols) for a given set of copies.
Args:
copy_ids (sequence of str)
Returns:
sequence of :class:`User` subclass instances
"""
assert copy_ids
return list(itertools.chain(*(
cls.query(cls.copies.uri.IN(copy_ids))
for cls in set(PROTOCOLS.values()) if cls)))
# TODO: default to looking up copy_ids as key ids, across protocols? is
# that useful anywhere?
@classmethod
@ndb.transactional()
def get_or_create(cls, id, propagate=False, **kwargs):
@ -960,9 +931,8 @@ class Object(StringIdModel):
if not ids:
return
origs = (User.get_for_copies(ids)
+ Object.query(Object.copies.uri.IN(ids)).fetch())
origs = get_for_copies(ids)
replaced = False
def replace(obj, field):
@ -1165,3 +1135,44 @@ def fetch_page(query, model_class, by=None):
new_before = new_before.isoformat()
return results, new_before, new_after
def get_for_copy(copy_id, keys_only=None):
"""Fetches a user or object with a given id in copies.
Thin wrapper around :func:`get_copies` that returns the first
matching result.
Args:
copy_id (str)
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
Returns:
User or Object:
"""
got = get_for_copies([copy_id], keys_only=keys_only)
if got:
return got[0]
def get_for_copies(copy_ids, keys_only=None):
"""Fetches users (across all protocols) for a given set of copies.
Args:
copy_ids (sequence of str)
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
Returns:
sequence of User and/or Object
"""
assert copy_ids
classes = set(cls for cls in PROTOCOLS.values() if cls)
classes.add(Object)
return list(itertools.chain(*(
cls.query(cls.copies.uri.IN(copy_ids)).iter(keys_only=keys_only)
for cls in classes)))
# TODO: default to looking up copy_ids as key ids, across protocols? is
# that useful anywhere?

Wyświetl plik

@ -12,14 +12,14 @@ from google.cloud.ndb import OR
from granary import as1
from oauth_dropins.webutil.appengine_config import ndb_client
from oauth_dropins.webutil.flask_util import cloud_tasks_only
from oauth_dropins.webutil import util
from oauth_dropins.webutil.util import json_dumps, json_loads
import werkzeug.exceptions
import common
from common import add, DOMAIN_BLOCKLIST, DOMAINS, error, subdomain_wrap
from flask_app import app
from models import Follower, Object, PROTOCOLS, Target, User
from oauth_dropins.webutil import util
from oauth_dropins.webutil.util import json_dumps, json_loads
from models import Follower, get_for_copies, Object, PROTOCOLS, Target, User
SUPPORTED_TYPES = (
'accept',
@ -988,8 +988,7 @@ class Protocol:
logger.info(f'Raw targets: {target_uris}')
if target_uris:
origs = {u.key.id() for u in User.get_for_copies(target_uris)} | \
{o.key.id() for o in Object.query(Object.copies.uri.IN(target_uris))}
origs = {key.id() for key in get_for_copies(target_uris, keys_only=True)}
if origs:
target_uris |= origs
logger.info(f'Added originals: {origs}')

Wyświetl plik

@ -21,6 +21,7 @@ from oauth_dropins.webutil import util
from .testutil import Fake, OtherFake, TestCase
from atproto import ATProto
import models
from models import Follower, Object, OBJECT_EXPIRE_AGE, Target, User
import protocol
from protocol import Protocol
@ -99,13 +100,6 @@ class UserTest(TestCase):
user.atproto_did = 'did:plc:123'
user.atproto_did = None
def test_get_for_copies(self):
self.assertEqual([], User.get_for_copies(['did:plc:foo']))
target = Target(uri='did:plc:foo', protocol='atproto')
fake_user = self.make_user('fake:user', cls=Fake, copies=[target])
self.assertEqual([fake_user], User.get_for_copies(['did:plc:foo']))
def test_get_or_create_use_instead(self):
user = Fake.get_or_create('a.b')
user.use_instead = g.user.key
@ -794,6 +788,17 @@ class ObjectTest(TestCase):
},
}, obj.our_as1)
def test_get_for_copies(self):
self.assertEqual([], models.get_for_copies(['foo', 'did:plc:bar']))
obj = self.store_object(id='fake:post',
copies=[Target(uri='other:foo', protocol='other')])
user = self.make_user('other:user', cls=OtherFake,
copies=[Target(uri='fake:bar', protocol='fake')])
self.assert_entities_equal(
[obj, user], models.get_for_copies(['other:foo', 'fake:bar', 'baz']))
class FollowerTest(TestCase):