From 5b16386fbc85273a33d7acc0382a24dec2dbbb92 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Mon, 18 Sep 2023 11:41:01 -0700 Subject: [PATCH] User.get_by_atproto_did and Object.as1 from bsky: support native ATProto users --- models.py | 13 ++++++++++--- tests/test_models.py | 29 +++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/models.py b/models.py index ddada2d..173883b 100644 --- a/models.py +++ b/models.py @@ -142,9 +142,12 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): @staticmethod def get_by_atproto_did(did): - """Fetches the user across all protocols with the given atproto_did. + """Fetches the user across all protocols with the given ATProto DID. - If more than one user has the given atproto_did, this returns an + Prefers bridged (ie not :class:`ATProto`) users to :class:`ATProto` + users. + + If more than one :class:`ATProto` user exists, this returns an arbitrary one! Args: @@ -156,12 +159,14 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): assert did for cls in set(PROTOCOLS.values()): - if not cls: + if not cls or cls.ABBREV == 'atproto': continue user = cls.query(cls.atproto_did == did).get() if user: return user + return PROTOCOLS['atproto'].get_by_id(did) + @classmethod @ndb.transactional() def get_or_create(cls, id, **kwargs): @@ -489,6 +494,8 @@ class Object(StringIdModel): else None) if field: repo, _, _ = arroba.util.parse_at_uri(self.key.id()) + # load matching user. prefer bridged non-ATProto user + # to ATProto user user = User.get_by_atproto_did(repo) if user: logger.debug(f'Filling in {field} from {user}') diff --git a/tests/test_models.py b/tests/test_models.py index d6dc186..5c4e0cb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -55,8 +55,13 @@ class UserTest(TestCase): def test_get_by_atproto_did(self): self.assertIsNone(User.get_by_atproto_did('did:plc:foo')) - user = self.make_user('fake:user', cls=Fake, atproto_did='did:plc:foo') - self.assertEqual(user, User.get_by_atproto_did('did:plc:foo')) + + atp_user = self.make_user('did:plc:foo', cls=ATProto) + self.assertEqual(atp_user, User.get_by_atproto_did('did:plc:foo')) + + # prefer non-ATProto user, if available + fake_user = self.make_user('fake:user', cls=Fake, atproto_did='did:plc:foo') + self.assertEqual(fake_user, User.get_by_atproto_did('did:plc:foo')) def test_get_or_create_use_instead(self): user = Fake.get_or_create('a.b') @@ -433,15 +438,15 @@ class ObjectTest(TestCase): obj = Object(id='at://did:plc:foo/co.ll/123', bsky=like_bsky) self.assert_equals(like_as1, obj.as1) - # user without Object - user = self.make_user(id='fake:user', cls=Fake, atproto_did=f'did:plc:foo') + # ATProto user without Object + user = self.make_user(id='did:plc:foo', cls=ATProto) obj = Object(id='at://did:plc:foo/co.ll/123', bsky=like_bsky) self.assertEqual({ **like_as1, - 'actor': 'fake:user', + 'actor': 'did:plc:foo', }, obj.as1) - # user with Object + # ATProto user with Object user.obj = self.store_object(id='fake:profile', our_as1={'foo': 'bar'}) user.put() obj = Object(id='at://did:plc:foo/co.ll/123', bsky=like_bsky) @@ -453,6 +458,18 @@ class ObjectTest(TestCase): }, }, obj.as1) + # Fake user, should prefer to ATProto user + user.obj = self.store_object(id='fake:profile', our_as1={'baz': 'biff'}) + user.put() + obj = Object(id='at://did:plc:foo/co.ll/123', bsky=like_bsky) + self.assertEqual({ + **like_as1, + 'actor': { + 'id': 'fake:profile', + 'baz': 'biff', + }, + }, obj.as1) + def test_as1_from_mf2_uses_url_as_id(self): obj = Object(mf2={ 'properties': {