diff --git a/atproto.py b/atproto.py index 1d38577..e4c1443 100644 --- a/atproto.py +++ b/atproto.py @@ -1,11 +1,6 @@ """ATProto protocol implementation. https://atproto.com/ - -TODO -* signup. resolve DID, fetch DID doc, extract PDS - * use alsoKnownAs as handle? or call getProfile on PDS to get handle? - * maybe need getProfile to store profile object? """ import json import logging @@ -91,6 +86,7 @@ class ATProto(User, Protocol): return f'@{self.readable_id}@{self.ABBREV}{common.SUPERDOMAIN}' @classmethod + # TODO: add bsky.app URLs, translating to/from at:// URIs. (to arroba?) def owns_id(cls, id): return (id.startswith('at://') or id.startswith('did:plc:') @@ -116,9 +112,7 @@ class ATProto(User, Protocol): repo, collection, rkey = parse_at_uri(obj.key.id()) did_obj = ATProto.load(repo) if did_obj: - return did_obj.raw.get('services', {})\ - .get('atproto_pds', {})\ - .get('endpoint') + return cls._pds_for(did_obj) # TODO: what should we do if the DID doesn't exist? should we return # None here? or do we need this path to return BF's URL so that we # then create the DID for non-ATP users on demand? @@ -134,6 +128,25 @@ class ATProto(User, Protocol): return common.host_url() + @classmethod + def _pds_for(cls, did_obj): + """ + Args: + did_obj: :class:`Object` + + Returns: + str, PDS URL, or None + """ + assert did_obj.key.id().startswith('did:') + + for service in did_obj.raw.get('service', []): + if service.get('id') in ('#atproto_pds', + f'{did_obj.key.id()}#atproto_pds'): + return service.get('serviceEndpoint') + + logger.info(f"{did_obj.key.id()}'s DID doc has no ATProto PDS") + return None + def is_blocklisted(url): # don't block common.DOMAINS since we want ourselves, ie our own PDS, to # be a valid domain to send to @@ -175,8 +188,8 @@ class ATProto(User, Protocol): if user.atproto_did: # existing DID and repo did_doc = to_cls.load(user.atproto_did) - pds = did_doc.raw['services']['atproto_pds']['endpoint'] - if pds.rstrip('/') != url.rstrip('/'): + pds = to_cls._pds_for(did_doc) + if not pds or pds.rstrip('/') != url.rstrip('/'): logger.warning(f'{user_key} {user.atproto_did} PDS {pds} is not us') return False repo = storage.load_repo(user.atproto_did) @@ -196,7 +209,6 @@ class ATProto(User, Protocol): user.put() assert not storage.load_repo(user.atproto_did) - # TODO: pass callback into create() so it's called for initial commit nonlocal repo repo = Repo.create(storage, user.atproto_did, handle=user.atproto_handle(), @@ -250,12 +262,16 @@ class ATProto(User, Protocol): util.interpret_http_exception(e) return False + pds = cls.target_for(obj) + if not pds: + return False + # at:// URI # examples: # at://did:plc:s2koow7r6t7tozgd4slc3dsg/app.bsky.feed.post/3jqcpv7bv2c2q # https://bsky.social/xrpc/com.atproto.repo.getRecord?repo=did:plc:s2koow7r6t7tozgd4slc3dsg&collection=app.bsky.feed.post&rkey=3jqcpv7bv2c2q repo, collection, rkey = parse_at_uri(obj.key.id()) - client = Client(cls.target_for(obj), headers={'User-Agent': USER_AGENT}) + client = Client(pds, headers={'User-Agent': USER_AGENT}) obj.bsky = client.com.atproto.repo.getRecord( repo=repo, collection=collection, rkey=rkey) return True diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 11fa458..7fdaa39 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -1,4 +1,5 @@ """Unit tests for atproto.py.""" +import base64 import copy from google.cloud.tasks_v2.types import Task import logging @@ -6,6 +7,7 @@ from unittest import skip from unittest.mock import call, patch from arroba.datastore_storage import AtpBlock, AtpRepo, DatastoreStorage +from arroba.did import encode_did_key from arroba.repo import Repo import arroba.util from flask import g @@ -27,18 +29,19 @@ import protocol from .testutil import Fake, TestCase DID_DOC = { - 'type': 'plc_operation', - 'rotationKeys': ['did:key:xyz'], - 'verificationMethods': {'atproto': 'did:key:xyz'}, - 'alsoKnownAs': ['at://han.dull'], - 'services': { - 'atproto_pds': { - 'type': 'AtprotoPersonalDataServer', - 'endpoint': 'https://some.pds', - } - }, - 'prev': None, - 'sig': '...', + 'id': 'did:plc:foo', + 'alsoKnownAs': ['at://han.dull'], + 'verificationMethod': [{ + 'id': 'did:plc:foo#atproto', + 'type': 'Multikey', + 'controller': 'did:plc:foo', + 'publicKeyMultibase': 'did:key:xyz', + }], + 'service': [{ + 'id': '#atproto_pds', + 'type': 'AtprotoPersonalDataServer', + 'serviceEndpoint': 'https://some.pds', + }], } class ATProtoTest(TestCase): @@ -197,15 +200,40 @@ class ATProtoTest(TestCase): assert user.atproto_did did_obj = ATProto.load(user.atproto_did) self.assertEqual('http://localhost/', - did_obj.raw['services']['atproto_pds']['endpoint']) - mock_post.assert_has_calls( - [self.req(f'https://plc.local/{user.atproto_did}', json=did_obj.raw)]) + did_obj.raw['service'][0]['serviceEndpoint']) # check repo, record repo = self.storage.load_repo(user.atproto_did) record = repo.get_record('app.bsky.feed.post', arroba.util._tid_last) self.assertEqual(POST_BSKY, record) + # check PLC directory call to create did:plc + self.assertEqual((f'https://plc.local/{user.atproto_did}',), + mock_post.call_args.args) + genesis_op = mock_post.call_args.kwargs['json'] + self.assertEqual(user.atproto_did, genesis_op.pop('did')) + genesis_op['sig'] = base64.urlsafe_b64decode(genesis_op['sig']) + assert arroba.util.verify_sig(genesis_op, repo.rotation_key.public_key()) + + del genesis_op['sig'] + self.assertEqual({ + 'type': 'plc_operation', + 'verificationMethods': { + 'atproto': encode_did_key(repo.signing_key.public_key()), + }, + 'rotationKeys': [encode_did_key(repo.rotation_key.public_key())], + 'alsoKnownAs': [ + 'at://user.fake.brid.gy', + ], + 'services': { + 'atproto_pds': { + 'type': 'AtprotoPersonalDataServer', + 'endpoint': 'http://localhost/', + } + }, + 'prev': None, + }, genesis_op) + # check atproto-commit task mock_create_task.assert_has_calls([ call(parent='projects/my-app/locations/us-central1/queues/atproto-commit', @@ -249,7 +277,7 @@ class ATProtoTest(TestCase): user = self.make_user(id='fake:user', cls=Fake, atproto_did='did:plc:foo') did_doc = copy.deepcopy(DID_DOC) - did_doc['services']['atproto_pds']['endpoint'] = 'http://localhost/' + did_doc['service'][0]['serviceEndpoint'] = 'http://localhost/' self.store_object(id='did:plc:foo', raw=did_doc) Repo.create(self.storage, 'did:plc:foo', signing_key=arroba.util.new_key()) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d8f5c68..a485586 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -282,7 +282,7 @@ class ProtocolTest(TestCase): # shouldn't be blocklisted user = self.make_user(id='fake:user', cls=Fake, atproto_did='did:plc:foo') did_doc = copy.deepcopy(DID_DOC) - did_doc['services']['atproto_pds']['endpoint'] = 'http://localhost/' + did_doc['service'][0]['serviceEndpoint'] = 'http://localhost/' self.store_object(id='did:plc:foo', raw=did_doc) # store Objects so we don't try to fetch them remotely