From e18dabf5106a15e035b91537fc2096c682f906e1 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Thu, 31 Aug 2023 10:48:28 -0700 Subject: [PATCH] implement ATProto.target_for, .fetch for at:// URIs --- atproto.py | 50 ++++++++++++++--- common.py | 2 + flask_app.py | 4 +- models.py | 6 +-- protocol.py | 2 +- tests/test_atproto.py | 123 ++++++++++++++++-------------------------- 6 files changed, 96 insertions(+), 91 deletions(-) diff --git a/atproto.py b/atproto.py index 407c020..55e7f1b 100644 --- a/atproto.py +++ b/atproto.py @@ -7,15 +7,20 @@ TODO * use alsoKnownAs as handle? or call getProfile on PDS to get handle? * maybe need getProfile to store profile object? """ +import json import logging +from pathlib import Path import re from arroba import did +from arroba.util import parse_at_uri from flask import abort, g, request from google.cloud import ndb from granary import as1, bluesky +from lexrpc import Client from oauth_dropins.webutil import flask_util, util import requests +from urllib.parse import urljoin, urlparse from flask_app import app, cache import common @@ -23,12 +28,18 @@ from common import ( add, error, is_blocklisted, + USER_AGENT, ) from models import Follower, Object, User from protocol import Protocol logger = logging.getLogger(__name__) +lexicons = [] +for filename in (Path(__file__).parent / 'lexicons').glob('**/*.json'): + with open(filename) as f: + lexicons.append(json.load(f)) + class ATProto(User, Protocol): """AT Protocol class. @@ -90,11 +101,25 @@ class ATProto(User, Protocol): or id.startswith('did:plc:') or id.startswith('did:web:')) - # @classmethod - # def target_for(cls, obj, shared=False): - # """Returns a relay that the receiving user uses.""" - # ... - # return actor.get('publicInbox') or actor.get('inbox') + @classmethod + def target_for(cls, obj, shared=False): + """Returns the PDS URL for the given object, or None. + + Args: + obj: :class:`Object` + + Returns: + str + """ + if not obj.key.id().startswith('at://'): + return None + + repo, collection, rkey = parse_at_uri(obj.key.id()) + did_obj = ATProto.load(repo) + if not did_obj: + return None + + return did_obj.raw.get('services', {}).get('atproto_pds', {}).get('endpoint') # @classmethod # def send(cls, obj, url, log_data=True): @@ -133,14 +158,12 @@ class ATProto(User, Protocol): Raises: TODO """ - # 1. resolve DID - # 2. call getRecord on PDS - id = obj.key.id() if not cls.owns_id(id): logger.info(f"ATProto can't fetch {id}") return False + # did:plc, did:web if id.startswith('did:'): try: obj.raw = did.resolve(id, get_fn=util.requests_get) @@ -149,6 +172,17 @@ class ATProto(User, Protocol): util.interpret_http_exception(e) return False + # at:// URI + # examples: + # at://did:plc:s2koow7r6t7tozgd4slc3dsg/app.bsky.feed.post/3jqcpv7bv2c2q + # https://bsky.social/xrpc/com.atproto.repo.getRecord?repo=did:plc:s2koow7r6t7tozgd4slc3dsg&collection=app.bsky.feed.post&rkey=3jqcpv7bv2c2q + repo, collection, rkey = parse_at_uri(obj.key.id()) + client = Client(cls.target_for(obj), lexicons, + headers={'User-Agent': USER_AGENT}) + obj.bsky = client.com.atproto.repo.getRecord( + repo=repo, collection=collection, rkey=rkey) + return True + @classmethod def serve(cls, obj): """Serves an :class:`Object` as AS2. diff --git a/common.py b/common.py index faaa665..38ae291 100644 --- a/common.py +++ b/common.py @@ -60,6 +60,8 @@ DOMAIN_BLOCKLIST = frozenset(( CACHE_TIME = timedelta(seconds=60) +USER_AGENT = 'Bridgy Fed (https://fed.brid.gy/)' + def base64_to_long(x): """Converts x from URL safe base64 encoding to a long integer. diff --git a/flask_app.py b/flask_app.py index 8956c5e..32ccc40 100644 --- a/flask_app.py +++ b/flask_app.py @@ -15,6 +15,8 @@ from oauth_dropins.webutil import ( util, ) +from common import USER_AGENT + logger = logging.getLogger(__name__) logging.getLogger('lexrpc').setLevel(logging.INFO) logging.getLogger('negotiator').setLevel(logging.WARNING) @@ -56,7 +58,7 @@ app.wsgi_app = flask_util.ndb_context_middleware( cache = Cache(app) -util.set_user_agent('Bridgy Fed (https://fed.brid.gy/)') +util.set_user_agent(USER_AGENT) # XRPC server lexicons = [] diff --git a/models.py b/models.py index bee33e5..8c4ee54 100644 --- a/models.py +++ b/models.py @@ -26,7 +26,7 @@ from common import add, base64_to_long, long_to_base64, redirect_unwrap # maps string label to Protocol subclass. populated by ProtocolUserMeta. # seed with old and upcoming protocols that don't have their own classes (yet). -PROTOCOLS = {'bluesky': None, 'ostatus': None} +PROTOCOLS = {'atproto': None, 'bluesky': None, 'ostatus': None} # 2048 bits makes tests slow, so use 1024 for them KEY_BITS = 1024 if DEBUG else 2048 @@ -339,8 +339,8 @@ class Target(ndb.Model): https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty """ uri = ndb.StringProperty(required=True) - # choices is populated in flask_app, after all User subclasses are created, - # so that PROTOCOLS is fully populated + # choices is populated in app via reset_protocol_properties, after all User + # subclasses are created, so that PROTOCOLS is fully populated protocol = ndb.StringProperty(choices=[], required=True) def __hash__(self): diff --git a/protocol.py b/protocol.py index bdeeb3e..f1d1655 100644 --- a/protocol.py +++ b/protocol.py @@ -910,7 +910,7 @@ class Protocol: obj = orig_as1 = None if local: obj = Object.get_by_id(id) - if obj and (obj.as1 or obj.deleted): + if obj and (obj.as1 or obj.raw or obj.deleted): logger.info(' got from datastore') obj.new = False orig_as1 = obj.as1 diff --git a/tests/test_atproto.py b/tests/test_atproto.py index e6b85b0..d86a126 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -11,19 +11,30 @@ from oauth_dropins.webutil.util import json_dumps, json_loads import requests from atproto import ATProto -import common +from common import USER_AGENT from models import Object import protocol from .testutil import Fake, TestCase +DID_DOC = { + 'type': 'plc_operation', + 'rotationKeys': ['did:key:xyz'], + 'verificationMethods': {'atproto': 'did:key:xyz'}, + 'alsoKnownAs': ['at://han.dull'], + 'services': { + 'atproto_pds': { + 'type': 'AtprotoPersonalDataServer', + 'endpoint': 'https://some.pds', + } + }, + 'prev': None, + 'sig': '...' +} + class ATProtoTest(TestCase): def setUp(self): super().setUp() - # self.request_context.push() - - # self.user = self.make_user('user.com', has_hcard=True, has_redirects=True, - # obj_as2={**ACTOR, 'id': 'https://user.com/'}) def test_put_validates_id(self, *_): for bad in ( @@ -52,10 +63,17 @@ class ATProtoTest(TestCase): self.assertTrue(ATProto.owns_id('did:plc:foo')) self.assertTrue(ATProto.owns_id('did:web:bar.com')) + def test_target_for_stored_did(self): + self.assertIsNone(ATProto.target_for(Object(id='did:plc:foo'))) + + did_obj = self.store_object(id='did:plc:foo', raw=DID_DOC) + got = ATProto.target_for(Object(id='at://did:plc:foo/co.ll/123')) + self.assertEqual('https://some.pds', got) + @patch('requests.get', return_value=requests_response({'foo': 'bar'})) def test_fetch_did_plc(self, mock_get): obj = Object(id='did:plc:123') - ATProto.fetch(obj) + self.assertTrue(ATProto.fetch(obj)) self.assertEqual({'foo': 'bar'}, obj.raw) mock_get.assert_has_calls(( @@ -65,21 +83,34 @@ class ATProtoTest(TestCase): @patch('requests.get', return_value=requests_response({'foo': 'bar'})) def test_fetch_did_web(self, mock_get): obj = Object(id='did:web:user.com') - ATProto.fetch(obj) + self.assertTrue(ATProto.fetch(obj)) self.assertEqual({'foo': 'bar'}, obj.raw) mock_get.assert_has_calls(( self.req('https://user.com/.well-known/did.json'), )) - # @patch('requests.get') - # def test_fetch_not_json(self, mock_get): - # mock_get.return_value = self.as2_resp('XYZ not JSON') + @patch('requests.get', return_value=requests_response('not json')) + def test_fetch_did_plc_not_json(self, mock_get): + obj = Object(id='did:web:user.com') + self.assertFalse(ATProto.fetch(obj)) + self.assertIsNone(obj.raw) - # with self.assertRaises(BadGateway): - # ATProto.fetch(Object(id='http://the/id')) - - # mock_get.assert_has_calls([self.as2_req('http://the/id')]) + @patch('requests.get', return_value=requests_response({'foo': 'bar'})) + def test_fetch_at_uri_record(self, mock_get): + self.store_object(id='did:plc:abc', raw=DID_DOC) + obj = Object(id='at://did:plc:abc/app.bsky.feed.post/123') + self.assertTrue(ATProto.fetch(obj)) + self.assertEqual({'foo': 'bar'}, obj.bsky) + # eg https://bsky.social/xrpc/com.atproto.repo.getRecord?repo=did:plc:s2koow7r6t7tozgd4slc3dsg&collection=app.bsky.feed.post&rkey=3jqcpv7bv2c2q + mock_get.assert_called_with( + 'https://some.pds/xrpc/com.atproto.repo.getRecord?repo=did%3Aplc%3Aabc&collection=app.bsky.feed.post&rkey=123', + json=None, + headers={ + 'Content-Type': 'application/json', + 'User-Agent': USER_AGENT, + }, + ) def test_serve(self): obj = self.store_object(id='http://orig', our_as1=ACTOR_AS) @@ -117,67 +148,3 @@ class ATProtoTest(TestCase): # user.obj = Object(id='a', as2=ACTOR) # self.assertEqual('@swentel@mas.to', user.readable_id) # self.assertEqual('@swentel@mas.to', user.readable_or_key_id()) - - # @skip - # def test_target_for_not_atproto(self): - # with self.assertRaises(AssertionError): - # ATProto.target_for(Object(source_protocol='web')) - - # def test_target_for_actor(self): - # self.assertEqual(ACTOR['inbox'], ATProto.target_for( - # Object(source_protocol='ap', as2=ACTOR))) - - # actor = copy.deepcopy(ACTOR) - # del actor['inbox'] - # self.assertIsNone(ATProto.target_for( - # Object(source_protocol='ap', as2=actor))) - - # actor['publicInbox'] = 'so-public' - # self.assertEqual('so-public', ATProto.target_for( - # Object(source_protocol='ap', as2=actor))) - - # # sharedInbox - # self.assertEqual('so-public', ATProto.target_for( - # Object(source_protocol='ap', as2=actor), shared=True)) - # actor['endpoints'] = { - # 'sharedInbox': 'so-shared', - # } - # self.assertEqual('so-public', ATProto.target_for( - # Object(source_protocol='ap', as2=actor))) - # self.assertEqual('so-shared', ATProto.target_for( - # Object(source_protocol='ap', as2=actor), shared=True)) - - # def test_target_for_object(self): - # obj = Object(as2=NOTE_OBJECT, source_protocol='ap') - # self.assertIsNone(ATProto.target_for(obj)) - - # Object(id=ACTOR['id'], as2=ACTOR).put() - # obj.as2 = { - # **NOTE_OBJECT, - # 'author': ACTOR['id'], - # } - # self.assertEqual('http://mas.to/inbox', ATProto.target_for(obj)) - - # del obj.as2['author'] - # obj.as2['actor'] = copy.deepcopy(ACTOR) - # obj.as2['actor']['url'] = [obj.as2['actor'].pop('id')] - # self.assertEqual('http://mas.to/inbox', ATProto.target_for(obj)) - - # @patch('requests.get') - # def test_target_for_object_fetch(self, mock_get): - # mock_get.return_value = self.as2_resp(ACTOR) - - # obj = Object(as2={ - # **NOTE_OBJECT, - # 'author': 'http://the/author', - # }, source_protocol='ap') - # self.assertEqual('http://mas.to/inbox', ATProto.target_for(obj)) - # mock_get.assert_has_calls([self.as2_req('http://the/author')]) - - # @patch('requests.get') - # def test_target_for_author_is_object_id(self, mock_get): - # obj = self.store_object(id='http://the/author', our_as1={ - # 'author': 'http://the/author', - # }) - # # test is that we short circuit out instead of infinite recursion - # self.assertIsNone(ATProto.target_for(obj))