kopia lustrzana https://github.com/snarfed/bridgy-fed
move creating a new ATProto user from ATProto.send to User.get_or_create
in progress, still need to load user profile object and write it to ATProto repopull/660/head
rodzic
03a9295224
commit
bfabfabea7
77
atproto.py
77
atproto.py
|
@ -10,9 +10,10 @@ import re
|
|||
from arroba import did
|
||||
from arroba.datastore_storage import AtpRepo, DatastoreStorage
|
||||
from arroba.repo import Repo, Write
|
||||
import arroba.server
|
||||
from arroba.storage import Action, CommitData
|
||||
from arroba.util import next_tid, parse_at_uri, service_jwt
|
||||
from flask import abort, g, request
|
||||
from flask import abort, request
|
||||
from google.cloud import ndb
|
||||
from granary import as1, bluesky
|
||||
from lexrpc import Client
|
||||
|
@ -34,7 +35,7 @@ from protocol import Protocol
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
storage = DatastoreStorage()
|
||||
arroba.server.storage = DatastoreStorage()
|
||||
|
||||
|
||||
class ATProto(User, Protocol):
|
||||
|
@ -204,72 +205,40 @@ class ATProto(User, Protocol):
|
|||
type = as1.object_type(as1.get_object(obj.as1))
|
||||
assert type in ('note', 'article')
|
||||
|
||||
user_key = PROTOCOLS[obj.source_protocol].actor_key(obj)
|
||||
if not user_key:
|
||||
from_cls = PROTOCOLS[obj.source_protocol]
|
||||
from_key = from_cls.actor_key(obj)
|
||||
if not from_key:
|
||||
logger.info(f"Couldn't find {obj.source_protocol} user for {obj.key}")
|
||||
return False
|
||||
|
||||
def create_atproto_commit_task(commit_data):
|
||||
common.create_task(queue='atproto-commit')
|
||||
|
||||
writes = []
|
||||
user = user_key.get()
|
||||
repo = None
|
||||
if user.atproto_did:
|
||||
# existing DID and repo
|
||||
did_doc = to_cls.load(user.atproto_did)
|
||||
pds = to_cls._pds_for(did_doc)
|
||||
if not pds or pds.rstrip('/') != url.rstrip('/'):
|
||||
logger.warning(f'{user_key} {user.atproto_did} PDS {pds} is not us')
|
||||
return False
|
||||
repo = storage.load_repo(user.atproto_did)
|
||||
repo.callback = create_atproto_commit_task
|
||||
|
||||
else:
|
||||
# create new DID, repo
|
||||
logger.info(f'Creating new did:plc for {user.key}')
|
||||
did_plc = did.create_plc(user.handle_as('atproto'),
|
||||
pds_url=common.host_url(),
|
||||
post_fn=util.requests_post)
|
||||
|
||||
ndb.transactional()
|
||||
def update_user_create_repo():
|
||||
Object.get_or_create(did_plc.did, raw=did_plc.doc)
|
||||
user.atproto_did = did_plc.did
|
||||
add(user.copies, Target(uri=did_plc.did, protocol=to_cls.LABEL))
|
||||
user.put()
|
||||
|
||||
assert not storage.load_repo(user.atproto_did)
|
||||
nonlocal repo
|
||||
repo = Repo.create(storage, user.atproto_did,
|
||||
handle=user.handle_as('atproto'),
|
||||
callback=create_atproto_commit_task,
|
||||
signing_key=did_plc.signing_key,
|
||||
rotation_key=did_plc.rotation_key)
|
||||
if user.obj and user.obj.as1:
|
||||
# create user profile
|
||||
writes.append(Write(action=Action.CREATE,
|
||||
collection='app.bsky.actor.profile',
|
||||
rkey='self', record=user.obj.as_bsky()))
|
||||
update_user_create_repo()
|
||||
|
||||
# load user
|
||||
user = from_cls.get_or_create(from_key.id(), propagate=True)
|
||||
assert user.atproto_did
|
||||
logger.info(f'{user.key} is {user.atproto_did}')
|
||||
assert repo
|
||||
did_doc = to_cls.load(user.atproto_did)
|
||||
pds = to_cls._pds_for(did_doc)
|
||||
if not pds or pds.rstrip('/') != url.rstrip('/'):
|
||||
logger.warning(f'{from_key} {user.atproto_did} PDS {pds} is not us')
|
||||
return False
|
||||
|
||||
# create record and commit in ATProto repo
|
||||
# load repo
|
||||
repo = arroba.server.storage.load_repo(user.atproto_did)
|
||||
assert repo
|
||||
repo.callback = lambda _: common.create_task(queue='atproto-commit')
|
||||
|
||||
# create record and commit
|
||||
ndb.transactional()
|
||||
def write():
|
||||
tid = next_tid()
|
||||
writes.append(Write(action=Action.CREATE, collection='app.bsky.feed.post',
|
||||
rkey=tid, record=obj.as_bsky()))
|
||||
repo.apply_writes(writes)
|
||||
repo.apply_writes(
|
||||
[Write(action=Action.CREATE, collection='app.bsky.feed.post',
|
||||
rkey=tid, record=obj.as_bsky())])
|
||||
|
||||
at_uri = f'at://{user.atproto_did}/app.bsky.feed.post/{tid}'
|
||||
add(obj.copies, Target(uri=at_uri, protocol=to_cls.ABBREV))
|
||||
obj.put()
|
||||
|
||||
write()
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
|
|
1
hub.py
1
hub.py
|
@ -61,7 +61,6 @@ def health_check():
|
|||
#
|
||||
# XRPC server
|
||||
#
|
||||
arroba.server.storage = DatastoreStorage()
|
||||
lexrpc.flask_server.init_flask(arroba.server.server, app)
|
||||
|
||||
|
||||
|
|
53
models.py
53
models.py
|
@ -6,6 +6,9 @@ import logging
|
|||
import random
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
from arroba import did
|
||||
from arroba.repo import Repo, Write
|
||||
from arroba.storage import Action
|
||||
import arroba.util
|
||||
from Crypto.PublicKey import RSA
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
@ -215,8 +218,13 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
|||
|
||||
@classmethod
|
||||
@ndb.transactional()
|
||||
def get_or_create(cls, id, **kwargs):
|
||||
"""Loads and returns a User. Creates it if necessary."""
|
||||
def get_or_create(cls, id, propagate=False, **kwargs):
|
||||
"""Loads and returns a User. Creates it if necessary.
|
||||
|
||||
Args:
|
||||
propagate (bool): whether to create copies of this user in push-based
|
||||
protocols, eg ATProto and Nostr.
|
||||
"""
|
||||
assert cls != User
|
||||
user = cls.get_by_id(id)
|
||||
if user:
|
||||
|
@ -226,7 +234,37 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
|||
logger.info(f'Setting {user.key} direct={direct}')
|
||||
user.direct = direct
|
||||
user.put()
|
||||
return user
|
||||
if not propagate:
|
||||
return user
|
||||
else:
|
||||
user = cls(id=id, **kwargs)
|
||||
|
||||
# TODO: fetch and store profile
|
||||
# self.obj = self.load(self.profile_id())
|
||||
|
||||
if propagate and cls.LABEL != 'atproto' and not user.atproto_did:
|
||||
# create new DID, repo
|
||||
logger.info(f'Creating new did:plc for {user.key}')
|
||||
did_plc = did.create_plc(user.handle_as('atproto'),
|
||||
pds_url=common.host_url(),
|
||||
post_fn=util.requests_post)
|
||||
|
||||
Object.get_or_create(did_plc.did, raw=did_plc.doc)
|
||||
user.atproto_did = did_plc.did
|
||||
add(user.copies, Target(uri=did_plc.did, protocol='atproto'))
|
||||
|
||||
repo = Repo.create(
|
||||
arroba.server.storage, user.atproto_did,
|
||||
handle=user.handle_as('atproto'),
|
||||
callback=lambda _: common.create_task(queue='atproto-commit'),
|
||||
signing_key=did_plc.signing_key,
|
||||
rotation_key=did_plc.rotation_key)
|
||||
|
||||
if user.obj and user.obj.as1:
|
||||
# create user profile
|
||||
repo.apply_writes([Write(action=Action.CREATE,
|
||||
collection='app.bsky.actor.profile',
|
||||
rkey='self', record=user.obj.as_bsky())])
|
||||
|
||||
# generate keys for all protocols _except_ our own
|
||||
#
|
||||
|
@ -235,13 +273,10 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
|||
if cls.LABEL != 'activitypub':
|
||||
# originally from django_salmon.magicsigs
|
||||
key = RSA.generate(KEY_BITS, randfunc=random.randbytes if DEBUG else None)
|
||||
kwargs.update({
|
||||
'mod': long_to_base64(key.n),
|
||||
'public_exponent': long_to_base64(key.e),
|
||||
'private_exponent': long_to_base64(key.d),
|
||||
})
|
||||
user.mod = long_to_base64(key.n)
|
||||
user.public_exponent = long_to_base64(key.e)
|
||||
user.private_exponent = long_to_base64(key.d)
|
||||
|
||||
user = cls(id=id, **kwargs)
|
||||
try:
|
||||
user.put()
|
||||
except AssertionError as e:
|
||||
|
|
|
@ -3,10 +3,13 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
from arroba.mst import dag_cbor_cid
|
||||
import arroba.server
|
||||
from Crypto.PublicKey import ECC
|
||||
from flask import g
|
||||
from google.cloud import ndb
|
||||
from granary.tests.test_bluesky import ACTOR_PROFILE_BSKY
|
||||
from google.cloud.tasks_v2.types import Task
|
||||
from granary.tests.test_bluesky import ACTOR_AS, ACTOR_PROFILE_VIEW_BSKY
|
||||
from oauth_dropins.webutil.appengine_config import tasks_client
|
||||
from oauth_dropins.webutil.testutil import NOW, requests_response
|
||||
|
||||
# import first so that Fake is defined before URL routes are registered
|
||||
|
@ -31,7 +34,7 @@ class UserTest(TestCase):
|
|||
g.user = self.make_user('y.z', cls=Web)
|
||||
|
||||
def test_get_or_create(self):
|
||||
user = Fake.get_or_create('a.b')
|
||||
user = Fake.get_or_create('fake:user')
|
||||
|
||||
assert not user.direct
|
||||
assert user.mod
|
||||
|
@ -43,10 +46,29 @@ class UserTest(TestCase):
|
|||
assert user.private_pem()
|
||||
|
||||
# direct should get set even if the user exists
|
||||
same = Fake.get_or_create('a.b', direct=True)
|
||||
same = Fake.get_or_create('fake:user', direct=True)
|
||||
user.direct = True
|
||||
self.assert_entities_equal(same, user, ignore=['updated'])
|
||||
|
||||
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
||||
@patch('requests.post',
|
||||
return_value=requests_response('OK')) # create DID on PLC
|
||||
def test_get_or_create_propagate(self, mock_post, mock_create_task):
|
||||
Fake.fetchable = {'fake:user': ACTOR_AS}
|
||||
|
||||
user = Fake.get_or_create('fake:user', propagate=True)
|
||||
|
||||
# check user, record
|
||||
# TODO: check profile
|
||||
user = Fake.get_by_id('fake:user')
|
||||
self.assertEqual('fake:handle:user', user.handle)
|
||||
self.assertEqual([Target(uri=user.atproto_did, protocol='atproto')],
|
||||
user.copies)
|
||||
# check that the repo exists
|
||||
repo = arroba.server.storage.load_repo(user.atproto_did)
|
||||
|
||||
mock_create_task.assert_called()
|
||||
|
||||
def test_validate_atproto_did(self):
|
||||
user = Fake()
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue