From 70edf4173ef408f5d30f9c74af69358bc4f96124 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Fri, 1 Sep 2023 14:18:50 -0700 Subject: [PATCH] ATProto: when creating new repo, add user profile record if available --- atproto.py | 17 ++++++++++++++--- tests/test_atproto.py | 20 +++++++++++++++++++- tests/testutil.py | 3 ++- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/atproto.py b/atproto.py index 9721afc..ef1c38d 100644 --- a/atproto.py +++ b/atproto.py @@ -162,12 +162,14 @@ class ATProto(User, Protocol): user = user_key.get() privkey = user.k256_key() if user.atproto_did: + # existing DID did_doc = cls.load(user.atproto_did) pds = did_doc.raw['services']['atproto_pds']['endpoint'] if pds.rstrip('/') != url.rstrip('/'): logger.warning(f'{user_key} {user.atproto_did} PDS {pds} is not us') return False else: + # create new DID # STATE: (unneeded?) new User.atproto_handle() did_plc = did.create_plc(user.atproto_handle(), privkey=privkey, pds_hostname=request.host, @@ -181,14 +183,23 @@ class ATProto(User, Protocol): user.put() update() + repo = storage.load_repo(did=user.atproto_did) + writes = [] if repo is None: + # create repo handle = user.readable_id if user.readable_id != user.atproto_did else None repo = Repo.create(storage, user.atproto_did, privkey, handle=handle) + if user.obj and user.obj.as1: + # create user profile + writes.append(Write(action=Action.CREATE, + collection='app.bsky.actor.profile', + rkey='self', record=user.obj.as_bsky())) - create = Write(action=Action.CREATE, collection='app.bsky.feed.post', - rkey=next_tid(), record=obj.as_bsky()) - repo.apply_writes([create], privkey) + # create record + writes.append(Write(action=Action.CREATE, collection='app.bsky.feed.post', + rkey=next_tid(), record=obj.as_bsky())) + repo.apply_writes(writes, privkey) return True @classmethod diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 56f81ff..79ec07d 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -175,7 +175,8 @@ class ATProtoTest(TestCase): self.store_object(id='did:plc:foo', raw=DID_DOC) self.assertEqual('@han.dull@atproto.brid.gy', user.ap_address()) - @patch('requests.post', return_value=requests_response('OK')) + @patch('requests.post', + return_value=requests_response('OK')) # create DID on PLC def test_send_new_repo(self, mock_post): user = self.make_user(id='fake:user', cls=Fake) obj = self.store_object(id='fake:post', source_protocol='fake', our_as1={ @@ -198,6 +199,23 @@ class ATProtoTest(TestCase): record = repo.get_record('app.bsky.feed.post', arroba.util._tid_last) self.assertEqual(POST_BSKY, record) + @patch('requests.post', + return_value=requests_response('OK')) # create DID on PLC + def test_send_new_repo_includes_user_profile(self, mock_post): + user = self.make_user(id='fake:user', cls=Fake, obj_as1=ACTOR_AS) + obj = self.store_object(id='fake:post', source_protocol='fake', our_as1={ + **POST_AS, + 'actor': 'fake:user', + }) + self.assertTrue(ATProto.send(obj, 'http://localhost/')) + + # check profile, record + repo = DatastoreStorage().load_repo(did=user.key.get().atproto_did) + profile = repo.get_record('app.bsky.actor.profile', 'self') + self.assertEqual(ACTOR_PROFILE_VIEW_BSKY, profile) + record = repo.get_record('app.bsky.feed.post', arroba.util._tid_last) + self.assertEqual(POST_BSKY, record) + def test_send_existing_repo(self): user = self.make_user(id='fake:user', cls=Fake, atproto_did='did:plc:foo') diff --git a/tests/testutil.py b/tests/testutil.py index 15970d6..cf0fd90 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -239,6 +239,7 @@ class TestCase(unittest.TestCase, testutil.Asserts): """Reuse RSA key across Users because generating it is expensive.""" obj_key = None + obj_as1 = kwargs.pop('obj_as1', None) obj_as2 = kwargs.pop('obj_as2', None) obj_mf2 = kwargs.pop('obj_mf2', None) obj_id = kwargs.pop('obj_id', None) @@ -247,7 +248,7 @@ class TestCase(unittest.TestCase, testutil.Asserts): or util.get_url((obj_mf2 or {}), 'properties') or str(self.last_make_user_id)) self.last_make_user_id += 1 - obj_key = Object(id=obj_id, as2=obj_as2, mf2=obj_mf2).put() + obj_key = Object(id=obj_id, our_as1=obj_as1, as2=obj_as2, mf2=obj_mf2).put() user = cls(id=id, direct=True,