kopia lustrzana https://github.com/snarfed/bridgy-fed
				
				
				
			move creating a new ATProto user from ATProto.send to User.get_or_create
in progress, still need to load user profile object and write it to ATProto repopull/660/head
							rodzic
							
								
									03a9295224
								
							
						
					
					
						commit
						bfabfabea7
					
				
							
								
								
									
										77
									
								
								atproto.py
								
								
								
								
							
							
						
						
									
										77
									
								
								atproto.py
								
								
								
								
							| 
						 | 
				
			
			@ -10,9 +10,10 @@ import re
 | 
			
		|||
from arroba import did
 | 
			
		||||
from arroba.datastore_storage import AtpRepo, DatastoreStorage
 | 
			
		||||
from arroba.repo import Repo, Write
 | 
			
		||||
import arroba.server
 | 
			
		||||
from arroba.storage import Action, CommitData
 | 
			
		||||
from arroba.util import next_tid, parse_at_uri, service_jwt
 | 
			
		||||
from flask import abort, g, request
 | 
			
		||||
from flask import abort, request
 | 
			
		||||
from google.cloud import ndb
 | 
			
		||||
from granary import as1, bluesky
 | 
			
		||||
from lexrpc import Client
 | 
			
		||||
| 
						 | 
				
			
			@ -34,7 +35,7 @@ from protocol import Protocol
 | 
			
		|||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
storage = DatastoreStorage()
 | 
			
		||||
arroba.server.storage = DatastoreStorage()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ATProto(User, Protocol):
 | 
			
		||||
| 
						 | 
				
			
			@ -204,72 +205,40 @@ class ATProto(User, Protocol):
 | 
			
		|||
            type = as1.object_type(as1.get_object(obj.as1))
 | 
			
		||||
        assert type in ('note', 'article')
 | 
			
		||||
 | 
			
		||||
        user_key = PROTOCOLS[obj.source_protocol].actor_key(obj)
 | 
			
		||||
        if not user_key:
 | 
			
		||||
        from_cls = PROTOCOLS[obj.source_protocol]
 | 
			
		||||
        from_key = from_cls.actor_key(obj)
 | 
			
		||||
        if not from_key:
 | 
			
		||||
            logger.info(f"Couldn't find {obj.source_protocol} user for {obj.key}")
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        def create_atproto_commit_task(commit_data):
 | 
			
		||||
            common.create_task(queue='atproto-commit')
 | 
			
		||||
 | 
			
		||||
        writes = []
 | 
			
		||||
        user = user_key.get()
 | 
			
		||||
        repo = None
 | 
			
		||||
        if user.atproto_did:
 | 
			
		||||
            # existing DID and repo
 | 
			
		||||
            did_doc = to_cls.load(user.atproto_did)
 | 
			
		||||
            pds = to_cls._pds_for(did_doc)
 | 
			
		||||
            if not pds or pds.rstrip('/') != url.rstrip('/'):
 | 
			
		||||
                logger.warning(f'{user_key} {user.atproto_did} PDS {pds} is not us')
 | 
			
		||||
                return False
 | 
			
		||||
            repo = storage.load_repo(user.atproto_did)
 | 
			
		||||
            repo.callback = create_atproto_commit_task
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            # create new DID, repo
 | 
			
		||||
            logger.info(f'Creating new did:plc for {user.key}')
 | 
			
		||||
            did_plc = did.create_plc(user.handle_as('atproto'),
 | 
			
		||||
                                     pds_url=common.host_url(),
 | 
			
		||||
                                     post_fn=util.requests_post)
 | 
			
		||||
 | 
			
		||||
            ndb.transactional()
 | 
			
		||||
            def update_user_create_repo():
 | 
			
		||||
                Object.get_or_create(did_plc.did, raw=did_plc.doc)
 | 
			
		||||
                user.atproto_did = did_plc.did
 | 
			
		||||
                add(user.copies, Target(uri=did_plc.did, protocol=to_cls.LABEL))
 | 
			
		||||
                user.put()
 | 
			
		||||
 | 
			
		||||
                assert not storage.load_repo(user.atproto_did)
 | 
			
		||||
                nonlocal repo
 | 
			
		||||
                repo = Repo.create(storage, user.atproto_did,
 | 
			
		||||
                                   handle=user.handle_as('atproto'),
 | 
			
		||||
                                   callback=create_atproto_commit_task,
 | 
			
		||||
                                   signing_key=did_plc.signing_key,
 | 
			
		||||
                                   rotation_key=did_plc.rotation_key)
 | 
			
		||||
                if user.obj and user.obj.as1:
 | 
			
		||||
                    # create user profile
 | 
			
		||||
                    writes.append(Write(action=Action.CREATE,
 | 
			
		||||
                                        collection='app.bsky.actor.profile',
 | 
			
		||||
                                        rkey='self', record=user.obj.as_bsky()))
 | 
			
		||||
            update_user_create_repo()
 | 
			
		||||
 | 
			
		||||
        # load user
 | 
			
		||||
        user = from_cls.get_or_create(from_key.id(), propagate=True)
 | 
			
		||||
        assert user.atproto_did
 | 
			
		||||
        logger.info(f'{user.key} is {user.atproto_did}')
 | 
			
		||||
        assert repo
 | 
			
		||||
        did_doc = to_cls.load(user.atproto_did)
 | 
			
		||||
        pds = to_cls._pds_for(did_doc)
 | 
			
		||||
        if not pds or pds.rstrip('/') != url.rstrip('/'):
 | 
			
		||||
            logger.warning(f'{from_key} {user.atproto_did} PDS {pds} is not us')
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # create record and commit in ATProto repo
 | 
			
		||||
        # load repo
 | 
			
		||||
        repo = arroba.server.storage.load_repo(user.atproto_did)
 | 
			
		||||
        assert repo
 | 
			
		||||
        repo.callback = lambda _: common.create_task(queue='atproto-commit')
 | 
			
		||||
 | 
			
		||||
        # create record and commit
 | 
			
		||||
        ndb.transactional()
 | 
			
		||||
        def write():
 | 
			
		||||
            tid = next_tid()
 | 
			
		||||
            writes.append(Write(action=Action.CREATE, collection='app.bsky.feed.post',
 | 
			
		||||
                                rkey=tid, record=obj.as_bsky()))
 | 
			
		||||
            repo.apply_writes(writes)
 | 
			
		||||
            repo.apply_writes(
 | 
			
		||||
                [Write(action=Action.CREATE, collection='app.bsky.feed.post',
 | 
			
		||||
                       rkey=tid, record=obj.as_bsky())])
 | 
			
		||||
 | 
			
		||||
            at_uri = f'at://{user.atproto_did}/app.bsky.feed.post/{tid}'
 | 
			
		||||
            add(obj.copies, Target(uri=at_uri, protocol=to_cls.ABBREV))
 | 
			
		||||
            obj.put()
 | 
			
		||||
 | 
			
		||||
        write()
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										1
									
								
								hub.py
								
								
								
								
							
							
						
						
									
										1
									
								
								hub.py
								
								
								
								
							| 
						 | 
				
			
			@ -61,7 +61,6 @@ def health_check():
 | 
			
		|||
#
 | 
			
		||||
# XRPC server
 | 
			
		||||
#
 | 
			
		||||
arroba.server.storage = DatastoreStorage()
 | 
			
		||||
lexrpc.flask_server.init_flask(arroba.server.server, app)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										53
									
								
								models.py
								
								
								
								
							
							
						
						
									
										53
									
								
								models.py
								
								
								
								
							| 
						 | 
				
			
			@ -6,6 +6,9 @@ import logging
 | 
			
		|||
import random
 | 
			
		||||
from urllib.parse import quote, urlparse
 | 
			
		||||
 | 
			
		||||
from arroba import did
 | 
			
		||||
from arroba.repo import Repo, Write
 | 
			
		||||
from arroba.storage import Action
 | 
			
		||||
import arroba.util
 | 
			
		||||
from Crypto.PublicKey import RSA
 | 
			
		||||
from cryptography.hazmat.primitives import serialization
 | 
			
		||||
| 
						 | 
				
			
			@ -215,8 +218,13 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @ndb.transactional()
 | 
			
		||||
    def get_or_create(cls, id, **kwargs):
 | 
			
		||||
        """Loads and returns a User. Creates it if necessary."""
 | 
			
		||||
    def get_or_create(cls, id, propagate=False, **kwargs):
 | 
			
		||||
        """Loads and returns a User. Creates it if necessary.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          propagate (bool): whether to create copies of this user in push-based
 | 
			
		||||
            protocols, eg ATProto and Nostr.
 | 
			
		||||
        """
 | 
			
		||||
        assert cls != User
 | 
			
		||||
        user = cls.get_by_id(id)
 | 
			
		||||
        if user:
 | 
			
		||||
| 
						 | 
				
			
			@ -226,7 +234,37 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
 | 
			
		|||
                logger.info(f'Setting {user.key} direct={direct}')
 | 
			
		||||
                user.direct = direct
 | 
			
		||||
                user.put()
 | 
			
		||||
            return user
 | 
			
		||||
            if not propagate:
 | 
			
		||||
                return user
 | 
			
		||||
        else:
 | 
			
		||||
            user = cls(id=id, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # TODO: fetch and store profile
 | 
			
		||||
        # self.obj = self.load(self.profile_id())
 | 
			
		||||
 | 
			
		||||
        if propagate and cls.LABEL != 'atproto' and not user.atproto_did:
 | 
			
		||||
            # create new DID, repo
 | 
			
		||||
            logger.info(f'Creating new did:plc for {user.key}')
 | 
			
		||||
            did_plc = did.create_plc(user.handle_as('atproto'),
 | 
			
		||||
                                     pds_url=common.host_url(),
 | 
			
		||||
                                     post_fn=util.requests_post)
 | 
			
		||||
 | 
			
		||||
            Object.get_or_create(did_plc.did, raw=did_plc.doc)
 | 
			
		||||
            user.atproto_did = did_plc.did
 | 
			
		||||
            add(user.copies, Target(uri=did_plc.did, protocol='atproto'))
 | 
			
		||||
 | 
			
		||||
            repo = Repo.create(
 | 
			
		||||
                arroba.server.storage, user.atproto_did,
 | 
			
		||||
                handle=user.handle_as('atproto'),
 | 
			
		||||
                callback=lambda _: common.create_task(queue='atproto-commit'),
 | 
			
		||||
                signing_key=did_plc.signing_key,
 | 
			
		||||
                rotation_key=did_plc.rotation_key)
 | 
			
		||||
 | 
			
		||||
            if user.obj and user.obj.as1:
 | 
			
		||||
                # create user profile
 | 
			
		||||
                repo.apply_writes([Write(action=Action.CREATE,
 | 
			
		||||
                                         collection='app.bsky.actor.profile',
 | 
			
		||||
                                         rkey='self', record=user.obj.as_bsky())])
 | 
			
		||||
 | 
			
		||||
        # generate keys for all protocols _except_ our own
 | 
			
		||||
        #
 | 
			
		||||
| 
						 | 
				
			
			@ -235,13 +273,10 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
 | 
			
		|||
        if cls.LABEL != 'activitypub':
 | 
			
		||||
            # originally from django_salmon.magicsigs
 | 
			
		||||
            key = RSA.generate(KEY_BITS, randfunc=random.randbytes if DEBUG else None)
 | 
			
		||||
            kwargs.update({
 | 
			
		||||
                    'mod': long_to_base64(key.n),
 | 
			
		||||
                    'public_exponent': long_to_base64(key.e),
 | 
			
		||||
                    'private_exponent': long_to_base64(key.d),
 | 
			
		||||
            })
 | 
			
		||||
            user.mod = long_to_base64(key.n)
 | 
			
		||||
            user.public_exponent = long_to_base64(key.e)
 | 
			
		||||
            user.private_exponent = long_to_base64(key.d)
 | 
			
		||||
 | 
			
		||||
        user = cls(id=id, **kwargs)
 | 
			
		||||
        try:
 | 
			
		||||
            user.put()
 | 
			
		||||
        except AssertionError as e:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,10 +3,13 @@
 | 
			
		|||
from unittest.mock import patch
 | 
			
		||||
 | 
			
		||||
from arroba.mst import dag_cbor_cid
 | 
			
		||||
import arroba.server
 | 
			
		||||
from Crypto.PublicKey import ECC
 | 
			
		||||
from flask import g
 | 
			
		||||
from google.cloud import ndb
 | 
			
		||||
from granary.tests.test_bluesky import ACTOR_PROFILE_BSKY
 | 
			
		||||
from google.cloud.tasks_v2.types import Task
 | 
			
		||||
from granary.tests.test_bluesky import ACTOR_AS, ACTOR_PROFILE_VIEW_BSKY
 | 
			
		||||
from oauth_dropins.webutil.appengine_config import tasks_client
 | 
			
		||||
from oauth_dropins.webutil.testutil import NOW, requests_response
 | 
			
		||||
 | 
			
		||||
# import first so that Fake is defined before URL routes are registered
 | 
			
		||||
| 
						 | 
				
			
			@ -31,7 +34,7 @@ class UserTest(TestCase):
 | 
			
		|||
        g.user = self.make_user('y.z', cls=Web)
 | 
			
		||||
 | 
			
		||||
    def test_get_or_create(self):
 | 
			
		||||
        user = Fake.get_or_create('a.b')
 | 
			
		||||
        user = Fake.get_or_create('fake:user')
 | 
			
		||||
 | 
			
		||||
        assert not user.direct
 | 
			
		||||
        assert user.mod
 | 
			
		||||
| 
						 | 
				
			
			@ -43,10 +46,29 @@ class UserTest(TestCase):
 | 
			
		|||
        assert user.private_pem()
 | 
			
		||||
 | 
			
		||||
        # direct should get set even if the user exists
 | 
			
		||||
        same = Fake.get_or_create('a.b', direct=True)
 | 
			
		||||
        same = Fake.get_or_create('fake:user', direct=True)
 | 
			
		||||
        user.direct = True
 | 
			
		||||
        self.assert_entities_equal(same, user, ignore=['updated'])
 | 
			
		||||
 | 
			
		||||
    @patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
 | 
			
		||||
    @patch('requests.post',
 | 
			
		||||
           return_value=requests_response('OK'))  # create DID on PLC
 | 
			
		||||
    def test_get_or_create_propagate(self, mock_post, mock_create_task):
 | 
			
		||||
        Fake.fetchable = {'fake:user': ACTOR_AS}
 | 
			
		||||
 | 
			
		||||
        user = Fake.get_or_create('fake:user', propagate=True)
 | 
			
		||||
 | 
			
		||||
        # check user, record
 | 
			
		||||
        # TODO: check profile
 | 
			
		||||
        user = Fake.get_by_id('fake:user')
 | 
			
		||||
        self.assertEqual('fake:handle:user', user.handle)
 | 
			
		||||
        self.assertEqual([Target(uri=user.atproto_did, protocol='atproto')],
 | 
			
		||||
                         user.copies)
 | 
			
		||||
        # check that the repo exists
 | 
			
		||||
        repo = arroba.server.storage.load_repo(user.atproto_did)
 | 
			
		||||
 | 
			
		||||
        mock_create_task.assert_called()
 | 
			
		||||
 | 
			
		||||
    def test_validate_atproto_did(self):
 | 
			
		||||
        user = Fake()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Ładowanie…
	
		Reference in New Issue