From 4b94b4397cab290bbbd79df5b412d393c25936d5 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Thu, 28 Sep 2023 13:56:22 -0700 Subject: [PATCH] User.get_or_create: fetch and propagate user profile object --- models.py | 31 +++++++++++++++++++------------ tests/test_models.py | 13 ++++++++++--- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/models.py b/models.py index 9959ec5..283468c 100644 --- a/models.py +++ b/models.py @@ -8,8 +8,9 @@ from urllib.parse import quote, urlparse from arroba import did from arroba.repo import Repo, Write +import arroba.server from arroba.storage import Action -import arroba.util +from arroba.util import at_uri, parse_at_uri from Crypto.PublicKey import RSA from cryptography.hazmat.primitives import serialization import dag_json @@ -239,9 +240,6 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): else: user = cls(id=id, **kwargs) - # TODO: fetch and store profile - # self.obj = self.load(self.profile_id()) - if propagate and cls.LABEL != 'atproto' and not user.atproto_did: # create new DID, repo logger.info(f'Creating new did:plc for {user.key}') @@ -253,19 +251,28 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): user.atproto_did = did_plc.did add(user.copies, Target(uri=did_plc.did, protocol='atproto')) + # fetch and store profile + if not user.obj: + user.obj = user.load(user.profile_id()) + + initial_writes = None + if user.obj and user.obj.as1: + # create user profile + initial_writes = [Write(action=Action.CREATE, + collection='app.bsky.actor.profile', + rkey='self', record=user.obj.as_bsky())] + uri = at_uri(user.atproto_did, 'app.bsky.actor.profile', 'self') + add(user.obj.copies, Target(uri=uri, protocol='atproto')) + user.obj.put() + repo = Repo.create( arroba.server.storage, user.atproto_did, handle=user.handle_as('atproto'), callback=lambda _: common.create_task(queue='atproto-commit'), + initial_writes=initial_writes, signing_key=did_plc.signing_key, rotation_key=did_plc.rotation_key) - if user.obj and user.obj.as1: - # create user profile - repo.apply_writes([Write(action=Action.CREATE, - collection='app.bsky.actor.profile', - rkey='self', record=user.obj.as_bsky())]) - # generate keys for all protocols _except_ our own # # these can use urandom() and do nontrivial math, so they can take time @@ -574,7 +581,7 @@ class Object(StringIdModel): obj = as2.to_as1(redirect_unwrap(self.as2)) elif self.bsky: - owner, _, _ = arroba.util.parse_at_uri(self.key.id()) + owner, _, _ = parse_at_uri(self.key.id()) ATProto = PROTOCOLS['atproto'] handle = ATProto(id=owner).handle obj = bluesky.to_as1(self.bsky, repo_did=owner, repo_handle=handle, @@ -650,7 +657,7 @@ class Object(StringIdModel): assert '^^' not in self.key.id() if self.key.id().startswith('at://'): - repo, _, _ = arroba.util.parse_at_uri(self.key.id()) + repo, _, _ = parse_at_uri(self.key.id()) if not repo.startswith('did:'): # TODO: if we hit this, that means the AppView gave us an AT URI # with a handle repo/authority instead of DID. that's surprising! diff --git a/tests/test_models.py b/tests/test_models.py index a951b71..5c6bd1d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,7 @@ from unittest.mock import patch from arroba.mst import dag_cbor_cid import arroba.server +from arroba.util import at_uri from Crypto.PublicKey import ECC from flask import g from google.cloud import ndb @@ -58,15 +59,21 @@ class UserTest(TestCase): user = Fake.get_or_create('fake:user', propagate=True) - # check user, record - # TODO: check profile + # check user, repo user = Fake.get_by_id('fake:user') self.assertEqual('fake:handle:user', user.handle) self.assertEqual([Target(uri=user.atproto_did, protocol='atproto')], user.copies) - # check that the repo exists repo = arroba.server.storage.load_repo(user.atproto_did) + # check profile record + profile = repo.get_record('app.bsky.actor.profile', 'self') + self.assertEqual(ACTOR_PROFILE_VIEW_BSKY, profile) + + uri = at_uri(user.atproto_did, 'app.bsky.actor.profile', 'self') + self.assertEqual([Target(uri=uri, protocol='atproto')], + Object.get_by_id(id='fake:user').copies) + mock_create_task.assert_called() def test_validate_atproto_did(self):