implement ATProto.target_for, .fetch for at:// URIs

pull/631/head
Ryan Barrett 2023-08-31 10:48:28 -07:00
rodzic 96b63487fa
commit e18dabf510
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
6 zmienionych plików z 96 dodań i 91 usunięć

Wyświetl plik

@ -7,15 +7,20 @@ TODO
* use alsoKnownAs as handle? or call getProfile on PDS to get handle? * use alsoKnownAs as handle? or call getProfile on PDS to get handle?
* maybe need getProfile to store profile object? * maybe need getProfile to store profile object?
""" """
import json
import logging import logging
from pathlib import Path
import re import re
from arroba import did from arroba import did
from arroba.util import parse_at_uri
from flask import abort, g, request from flask import abort, g, request
from google.cloud import ndb from google.cloud import ndb
from granary import as1, bluesky from granary import as1, bluesky
from lexrpc import Client
from oauth_dropins.webutil import flask_util, util from oauth_dropins.webutil import flask_util, util
import requests import requests
from urllib.parse import urljoin, urlparse
from flask_app import app, cache from flask_app import app, cache
import common import common
@ -23,12 +28,18 @@ from common import (
add, add,
error, error,
is_blocklisted, is_blocklisted,
USER_AGENT,
) )
from models import Follower, Object, User from models import Follower, Object, User
from protocol import Protocol from protocol import Protocol
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
lexicons = []
for filename in (Path(__file__).parent / 'lexicons').glob('**/*.json'):
with open(filename) as f:
lexicons.append(json.load(f))
class ATProto(User, Protocol): class ATProto(User, Protocol):
"""AT Protocol class. """AT Protocol class.
@ -90,11 +101,25 @@ class ATProto(User, Protocol):
or id.startswith('did:plc:') or id.startswith('did:plc:')
or id.startswith('did:web:')) or id.startswith('did:web:'))
# @classmethod @classmethod
# def target_for(cls, obj, shared=False): def target_for(cls, obj, shared=False):
# """Returns a relay that the receiving user uses.""" """Returns the PDS URL for the given object, or None.
# ...
# return actor.get('publicInbox') or actor.get('inbox') Args:
obj: :class:`Object`
Returns:
str
"""
if not obj.key.id().startswith('at://'):
return None
repo, collection, rkey = parse_at_uri(obj.key.id())
did_obj = ATProto.load(repo)
if not did_obj:
return None
return did_obj.raw.get('services', {}).get('atproto_pds', {}).get('endpoint')
# @classmethod # @classmethod
# def send(cls, obj, url, log_data=True): # def send(cls, obj, url, log_data=True):
@ -133,14 +158,12 @@ class ATProto(User, Protocol):
Raises: Raises:
TODO TODO
""" """
# 1. resolve DID
# 2. call getRecord on PDS
id = obj.key.id() id = obj.key.id()
if not cls.owns_id(id): if not cls.owns_id(id):
logger.info(f"ATProto can't fetch {id}") logger.info(f"ATProto can't fetch {id}")
return False return False
# did:plc, did:web
if id.startswith('did:'): if id.startswith('did:'):
try: try:
obj.raw = did.resolve(id, get_fn=util.requests_get) obj.raw = did.resolve(id, get_fn=util.requests_get)
@ -149,6 +172,17 @@ class ATProto(User, Protocol):
util.interpret_http_exception(e) util.interpret_http_exception(e)
return False return False
# at:// URI
# examples:
# at://did:plc:s2koow7r6t7tozgd4slc3dsg/app.bsky.feed.post/3jqcpv7bv2c2q
# https://bsky.social/xrpc/com.atproto.repo.getRecord?repo=did:plc:s2koow7r6t7tozgd4slc3dsg&collection=app.bsky.feed.post&rkey=3jqcpv7bv2c2q
repo, collection, rkey = parse_at_uri(obj.key.id())
client = Client(cls.target_for(obj), lexicons,
headers={'User-Agent': USER_AGENT})
obj.bsky = client.com.atproto.repo.getRecord(
repo=repo, collection=collection, rkey=rkey)
return True
@classmethod @classmethod
def serve(cls, obj): def serve(cls, obj):
"""Serves an :class:`Object` as AS2. """Serves an :class:`Object` as AS2.

Wyświetl plik

@ -60,6 +60,8 @@ DOMAIN_BLOCKLIST = frozenset((
CACHE_TIME = timedelta(seconds=60) CACHE_TIME = timedelta(seconds=60)
USER_AGENT = 'Bridgy Fed (https://fed.brid.gy/)'
def base64_to_long(x): def base64_to_long(x):
"""Converts x from URL safe base64 encoding to a long integer. """Converts x from URL safe base64 encoding to a long integer.

Wyświetl plik

@ -15,6 +15,8 @@ from oauth_dropins.webutil import (
util, util,
) )
from common import USER_AGENT
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.getLogger('lexrpc').setLevel(logging.INFO) logging.getLogger('lexrpc').setLevel(logging.INFO)
logging.getLogger('negotiator').setLevel(logging.WARNING) logging.getLogger('negotiator').setLevel(logging.WARNING)
@ -56,7 +58,7 @@ app.wsgi_app = flask_util.ndb_context_middleware(
cache = Cache(app) cache = Cache(app)
util.set_user_agent('Bridgy Fed (https://fed.brid.gy/)') util.set_user_agent(USER_AGENT)
# XRPC server # XRPC server
lexicons = [] lexicons = []

Wyświetl plik

@ -26,7 +26,7 @@ from common import add, base64_to_long, long_to_base64, redirect_unwrap
# maps string label to Protocol subclass. populated by ProtocolUserMeta. # maps string label to Protocol subclass. populated by ProtocolUserMeta.
# seed with old and upcoming protocols that don't have their own classes (yet). # seed with old and upcoming protocols that don't have their own classes (yet).
PROTOCOLS = {'bluesky': None, 'ostatus': None} PROTOCOLS = {'atproto': None, 'bluesky': None, 'ostatus': None}
# 2048 bits makes tests slow, so use 1024 for them # 2048 bits makes tests slow, so use 1024 for them
KEY_BITS = 1024 if DEBUG else 2048 KEY_BITS = 1024 if DEBUG else 2048
@ -339,8 +339,8 @@ class Target(ndb.Model):
https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty
""" """
uri = ndb.StringProperty(required=True) uri = ndb.StringProperty(required=True)
# choices is populated in flask_app, after all User subclasses are created, # choices is populated in app via reset_protocol_properties, after all User
# so that PROTOCOLS is fully populated # subclasses are created, so that PROTOCOLS is fully populated
protocol = ndb.StringProperty(choices=[], required=True) protocol = ndb.StringProperty(choices=[], required=True)
def __hash__(self): def __hash__(self):

Wyświetl plik

@ -910,7 +910,7 @@ class Protocol:
obj = orig_as1 = None obj = orig_as1 = None
if local: if local:
obj = Object.get_by_id(id) obj = Object.get_by_id(id)
if obj and (obj.as1 or obj.deleted): if obj and (obj.as1 or obj.raw or obj.deleted):
logger.info(' got from datastore') logger.info(' got from datastore')
obj.new = False obj.new = False
orig_as1 = obj.as1 orig_as1 = obj.as1

Wyświetl plik

@ -11,19 +11,30 @@ from oauth_dropins.webutil.util import json_dumps, json_loads
import requests import requests
from atproto import ATProto from atproto import ATProto
import common from common import USER_AGENT
from models import Object from models import Object
import protocol import protocol
from .testutil import Fake, TestCase from .testutil import Fake, TestCase
DID_DOC = {
'type': 'plc_operation',
'rotationKeys': ['did:key:xyz'],
'verificationMethods': {'atproto': 'did:key:xyz'},
'alsoKnownAs': ['at://han.dull'],
'services': {
'atproto_pds': {
'type': 'AtprotoPersonalDataServer',
'endpoint': 'https://some.pds',
}
},
'prev': None,
'sig': '...'
}
class ATProtoTest(TestCase): class ATProtoTest(TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
# self.request_context.push()
# self.user = self.make_user('user.com', has_hcard=True, has_redirects=True,
# obj_as2={**ACTOR, 'id': 'https://user.com/'})
def test_put_validates_id(self, *_): def test_put_validates_id(self, *_):
for bad in ( for bad in (
@ -52,10 +63,17 @@ class ATProtoTest(TestCase):
self.assertTrue(ATProto.owns_id('did:plc:foo')) self.assertTrue(ATProto.owns_id('did:plc:foo'))
self.assertTrue(ATProto.owns_id('did:web:bar.com')) self.assertTrue(ATProto.owns_id('did:web:bar.com'))
def test_target_for_stored_did(self):
self.assertIsNone(ATProto.target_for(Object(id='did:plc:foo')))
did_obj = self.store_object(id='did:plc:foo', raw=DID_DOC)
got = ATProto.target_for(Object(id='at://did:plc:foo/co.ll/123'))
self.assertEqual('https://some.pds', got)
@patch('requests.get', return_value=requests_response({'foo': 'bar'})) @patch('requests.get', return_value=requests_response({'foo': 'bar'}))
def test_fetch_did_plc(self, mock_get): def test_fetch_did_plc(self, mock_get):
obj = Object(id='did:plc:123') obj = Object(id='did:plc:123')
ATProto.fetch(obj) self.assertTrue(ATProto.fetch(obj))
self.assertEqual({'foo': 'bar'}, obj.raw) self.assertEqual({'foo': 'bar'}, obj.raw)
mock_get.assert_has_calls(( mock_get.assert_has_calls((
@ -65,21 +83,34 @@ class ATProtoTest(TestCase):
@patch('requests.get', return_value=requests_response({'foo': 'bar'})) @patch('requests.get', return_value=requests_response({'foo': 'bar'}))
def test_fetch_did_web(self, mock_get): def test_fetch_did_web(self, mock_get):
obj = Object(id='did:web:user.com') obj = Object(id='did:web:user.com')
ATProto.fetch(obj) self.assertTrue(ATProto.fetch(obj))
self.assertEqual({'foo': 'bar'}, obj.raw) self.assertEqual({'foo': 'bar'}, obj.raw)
mock_get.assert_has_calls(( mock_get.assert_has_calls((
self.req('https://user.com/.well-known/did.json'), self.req('https://user.com/.well-known/did.json'),
)) ))
# @patch('requests.get') @patch('requests.get', return_value=requests_response('not json'))
# def test_fetch_not_json(self, mock_get): def test_fetch_did_plc_not_json(self, mock_get):
# mock_get.return_value = self.as2_resp('XYZ not JSON') obj = Object(id='did:web:user.com')
self.assertFalse(ATProto.fetch(obj))
self.assertIsNone(obj.raw)
# with self.assertRaises(BadGateway): @patch('requests.get', return_value=requests_response({'foo': 'bar'}))
# ATProto.fetch(Object(id='http://the/id')) def test_fetch_at_uri_record(self, mock_get):
self.store_object(id='did:plc:abc', raw=DID_DOC)
# mock_get.assert_has_calls([self.as2_req('http://the/id')]) obj = Object(id='at://did:plc:abc/app.bsky.feed.post/123')
self.assertTrue(ATProto.fetch(obj))
self.assertEqual({'foo': 'bar'}, obj.bsky)
# eg https://bsky.social/xrpc/com.atproto.repo.getRecord?repo=did:plc:s2koow7r6t7tozgd4slc3dsg&collection=app.bsky.feed.post&rkey=3jqcpv7bv2c2q
mock_get.assert_called_with(
'https://some.pds/xrpc/com.atproto.repo.getRecord?repo=did%3Aplc%3Aabc&collection=app.bsky.feed.post&rkey=123',
json=None,
headers={
'Content-Type': 'application/json',
'User-Agent': USER_AGENT,
},
)
def test_serve(self): def test_serve(self):
obj = self.store_object(id='http://orig', our_as1=ACTOR_AS) obj = self.store_object(id='http://orig', our_as1=ACTOR_AS)
@ -117,67 +148,3 @@ class ATProtoTest(TestCase):
# user.obj = Object(id='a', as2=ACTOR) # user.obj = Object(id='a', as2=ACTOR)
# self.assertEqual('@swentel@mas.to', user.readable_id) # self.assertEqual('@swentel@mas.to', user.readable_id)
# self.assertEqual('@swentel@mas.to', user.readable_or_key_id()) # self.assertEqual('@swentel@mas.to', user.readable_or_key_id())
# @skip
# def test_target_for_not_atproto(self):
# with self.assertRaises(AssertionError):
# ATProto.target_for(Object(source_protocol='web'))
# def test_target_for_actor(self):
# self.assertEqual(ACTOR['inbox'], ATProto.target_for(
# Object(source_protocol='ap', as2=ACTOR)))
# actor = copy.deepcopy(ACTOR)
# del actor['inbox']
# self.assertIsNone(ATProto.target_for(
# Object(source_protocol='ap', as2=actor)))
# actor['publicInbox'] = 'so-public'
# self.assertEqual('so-public', ATProto.target_for(
# Object(source_protocol='ap', as2=actor)))
# # sharedInbox
# self.assertEqual('so-public', ATProto.target_for(
# Object(source_protocol='ap', as2=actor), shared=True))
# actor['endpoints'] = {
# 'sharedInbox': 'so-shared',
# }
# self.assertEqual('so-public', ATProto.target_for(
# Object(source_protocol='ap', as2=actor)))
# self.assertEqual('so-shared', ATProto.target_for(
# Object(source_protocol='ap', as2=actor), shared=True))
# def test_target_for_object(self):
# obj = Object(as2=NOTE_OBJECT, source_protocol='ap')
# self.assertIsNone(ATProto.target_for(obj))
# Object(id=ACTOR['id'], as2=ACTOR).put()
# obj.as2 = {
# **NOTE_OBJECT,
# 'author': ACTOR['id'],
# }
# self.assertEqual('http://mas.to/inbox', ATProto.target_for(obj))
# del obj.as2['author']
# obj.as2['actor'] = copy.deepcopy(ACTOR)
# obj.as2['actor']['url'] = [obj.as2['actor'].pop('id')]
# self.assertEqual('http://mas.to/inbox', ATProto.target_for(obj))
# @patch('requests.get')
# def test_target_for_object_fetch(self, mock_get):
# mock_get.return_value = self.as2_resp(ACTOR)
# obj = Object(as2={
# **NOTE_OBJECT,
# 'author': 'http://the/author',
# }, source_protocol='ap')
# self.assertEqual('http://mas.to/inbox', ATProto.target_for(obj))
# mock_get.assert_has_calls([self.as2_req('http://the/author')])
# @patch('requests.get')
# def test_target_for_author_is_object_id(self, mock_get):
# obj = self.store_object(id='http://the/author', our_as1={
# 'author': 'http://the/author',
# })
# # test is that we short circuit out instead of infinite recursion
# self.assertIsNone(ATProto.target_for(obj))