move creating a new ATProto user from ATProto.send to User.get_or_create

in progress, still need to load user profile object and write it to ATProto repo
pull/660/head
Ryan Barrett 2023-09-28 13:42:16 -07:00
rodzic 03a9295224
commit bfabfabea7
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
4 zmienionych plików z 92 dodań i 67 usunięć

Wyświetl plik

@ -10,9 +10,10 @@ import re
from arroba import did
from arroba.datastore_storage import AtpRepo, DatastoreStorage
from arroba.repo import Repo, Write
import arroba.server
from arroba.storage import Action, CommitData
from arroba.util import next_tid, parse_at_uri, service_jwt
from flask import abort, g, request
from flask import abort, request
from google.cloud import ndb
from granary import as1, bluesky
from lexrpc import Client
@ -34,7 +35,7 @@ from protocol import Protocol
logger = logging.getLogger(__name__)
storage = DatastoreStorage()
arroba.server.storage = DatastoreStorage()
class ATProto(User, Protocol):
@ -204,72 +205,40 @@ class ATProto(User, Protocol):
type = as1.object_type(as1.get_object(obj.as1))
assert type in ('note', 'article')
user_key = PROTOCOLS[obj.source_protocol].actor_key(obj)
if not user_key:
from_cls = PROTOCOLS[obj.source_protocol]
from_key = from_cls.actor_key(obj)
if not from_key:
logger.info(f"Couldn't find {obj.source_protocol} user for {obj.key}")
return False
def create_atproto_commit_task(commit_data):
common.create_task(queue='atproto-commit')
writes = []
user = user_key.get()
repo = None
if user.atproto_did:
# existing DID and repo
did_doc = to_cls.load(user.atproto_did)
pds = to_cls._pds_for(did_doc)
if not pds or pds.rstrip('/') != url.rstrip('/'):
logger.warning(f'{user_key} {user.atproto_did} PDS {pds} is not us')
return False
repo = storage.load_repo(user.atproto_did)
repo.callback = create_atproto_commit_task
else:
# create new DID, repo
logger.info(f'Creating new did:plc for {user.key}')
did_plc = did.create_plc(user.handle_as('atproto'),
pds_url=common.host_url(),
post_fn=util.requests_post)
ndb.transactional()
def update_user_create_repo():
Object.get_or_create(did_plc.did, raw=did_plc.doc)
user.atproto_did = did_plc.did
add(user.copies, Target(uri=did_plc.did, protocol=to_cls.LABEL))
user.put()
assert not storage.load_repo(user.atproto_did)
nonlocal repo
repo = Repo.create(storage, user.atproto_did,
handle=user.handle_as('atproto'),
callback=create_atproto_commit_task,
signing_key=did_plc.signing_key,
rotation_key=did_plc.rotation_key)
if user.obj and user.obj.as1:
# create user profile
writes.append(Write(action=Action.CREATE,
collection='app.bsky.actor.profile',
rkey='self', record=user.obj.as_bsky()))
update_user_create_repo()
# load user
user = from_cls.get_or_create(from_key.id(), propagate=True)
assert user.atproto_did
logger.info(f'{user.key} is {user.atproto_did}')
assert repo
did_doc = to_cls.load(user.atproto_did)
pds = to_cls._pds_for(did_doc)
if not pds or pds.rstrip('/') != url.rstrip('/'):
logger.warning(f'{from_key} {user.atproto_did} PDS {pds} is not us')
return False
# create record and commit in ATProto repo
# load repo
repo = arroba.server.storage.load_repo(user.atproto_did)
assert repo
repo.callback = lambda _: common.create_task(queue='atproto-commit')
# create record and commit
ndb.transactional()
def write():
tid = next_tid()
writes.append(Write(action=Action.CREATE, collection='app.bsky.feed.post',
rkey=tid, record=obj.as_bsky()))
repo.apply_writes(writes)
repo.apply_writes(
[Write(action=Action.CREATE, collection='app.bsky.feed.post',
rkey=tid, record=obj.as_bsky())])
at_uri = f'at://{user.atproto_did}/app.bsky.feed.post/{tid}'
add(obj.copies, Target(uri=at_uri, protocol=to_cls.ABBREV))
obj.put()
write()
return True
@classmethod

1
hub.py
Wyświetl plik

@ -61,7 +61,6 @@ def health_check():
#
# XRPC server
#
arroba.server.storage = DatastoreStorage()
lexrpc.flask_server.init_flask(arroba.server.server, app)

Wyświetl plik

@ -6,6 +6,9 @@ import logging
import random
from urllib.parse import quote, urlparse
from arroba import did
from arroba.repo import Repo, Write
from arroba.storage import Action
import arroba.util
from Crypto.PublicKey import RSA
from cryptography.hazmat.primitives import serialization
@ -215,8 +218,13 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
@classmethod
@ndb.transactional()
def get_or_create(cls, id, **kwargs):
"""Loads and returns a User. Creates it if necessary."""
def get_or_create(cls, id, propagate=False, **kwargs):
"""Loads and returns a User. Creates it if necessary.
Args:
propagate (bool): whether to create copies of this user in push-based
protocols, eg ATProto and Nostr.
"""
assert cls != User
user = cls.get_by_id(id)
if user:
@ -226,7 +234,37 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
logger.info(f'Setting {user.key} direct={direct}')
user.direct = direct
user.put()
return user
if not propagate:
return user
else:
user = cls(id=id, **kwargs)
# TODO: fetch and store profile
# self.obj = self.load(self.profile_id())
if propagate and cls.LABEL != 'atproto' and not user.atproto_did:
# create new DID, repo
logger.info(f'Creating new did:plc for {user.key}')
did_plc = did.create_plc(user.handle_as('atproto'),
pds_url=common.host_url(),
post_fn=util.requests_post)
Object.get_or_create(did_plc.did, raw=did_plc.doc)
user.atproto_did = did_plc.did
add(user.copies, Target(uri=did_plc.did, protocol='atproto'))
repo = Repo.create(
arroba.server.storage, user.atproto_did,
handle=user.handle_as('atproto'),
callback=lambda _: common.create_task(queue='atproto-commit'),
signing_key=did_plc.signing_key,
rotation_key=did_plc.rotation_key)
if user.obj and user.obj.as1:
# create user profile
repo.apply_writes([Write(action=Action.CREATE,
collection='app.bsky.actor.profile',
rkey='self', record=user.obj.as_bsky())])
# generate keys for all protocols _except_ our own
#
@ -235,13 +273,10 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
if cls.LABEL != 'activitypub':
# originally from django_salmon.magicsigs
key = RSA.generate(KEY_BITS, randfunc=random.randbytes if DEBUG else None)
kwargs.update({
'mod': long_to_base64(key.n),
'public_exponent': long_to_base64(key.e),
'private_exponent': long_to_base64(key.d),
})
user.mod = long_to_base64(key.n)
user.public_exponent = long_to_base64(key.e)
user.private_exponent = long_to_base64(key.d)
user = cls(id=id, **kwargs)
try:
user.put()
except AssertionError as e:

Wyświetl plik

@ -3,10 +3,13 @@
from unittest.mock import patch
from arroba.mst import dag_cbor_cid
import arroba.server
from Crypto.PublicKey import ECC
from flask import g
from google.cloud import ndb
from granary.tests.test_bluesky import ACTOR_PROFILE_BSKY
from google.cloud.tasks_v2.types import Task
from granary.tests.test_bluesky import ACTOR_AS, ACTOR_PROFILE_VIEW_BSKY
from oauth_dropins.webutil.appengine_config import tasks_client
from oauth_dropins.webutil.testutil import NOW, requests_response
# import first so that Fake is defined before URL routes are registered
@ -31,7 +34,7 @@ class UserTest(TestCase):
g.user = self.make_user('y.z', cls=Web)
def test_get_or_create(self):
user = Fake.get_or_create('a.b')
user = Fake.get_or_create('fake:user')
assert not user.direct
assert user.mod
@ -43,10 +46,29 @@ class UserTest(TestCase):
assert user.private_pem()
# direct should get set even if the user exists
same = Fake.get_or_create('a.b', direct=True)
same = Fake.get_or_create('fake:user', direct=True)
user.direct = True
self.assert_entities_equal(same, user, ignore=['updated'])
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
@patch('requests.post',
return_value=requests_response('OK')) # create DID on PLC
def test_get_or_create_propagate(self, mock_post, mock_create_task):
Fake.fetchable = {'fake:user': ACTOR_AS}
user = Fake.get_or_create('fake:user', propagate=True)
# check user, record
# TODO: check profile
user = Fake.get_by_id('fake:user')
self.assertEqual('fake:handle:user', user.handle)
self.assertEqual([Target(uri=user.atproto_did, protocol='atproto')],
user.copies)
# check that the repo exists
repo = arroba.server.storage.load_repo(user.atproto_did)
mock_create_task.assert_called()
def test_validate_atproto_did(self):
user = Fake()