From cec093ea6050c96e82124cf6c4a8ce39c0bede08 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Tue, 30 May 2023 16:53:08 -0700 Subject: [PATCH] AP users: parameterize /remote-follow by protocol for #512 --- follow.py | 27 +++++++++++++++------------ tests/test_follow.py | 28 ++++++++++++++++++---------- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/follow.py b/follow.py index 34ace4e..59bfe61 100644 --- a/follow.py +++ b/follow.py @@ -19,7 +19,7 @@ from oauth_dropins.webutil.util import json_dumps, json_loads from activitypub import ActivityPub from flask_app import app import common -import models +from models import Follower, Object, PROTOCOLS from web import Web logger = logging.getLogger(__name__) @@ -80,9 +80,12 @@ def remote_follow(): """Discovers and redirects to a remote follow page for a given user.""" logger.info(f'Got: {request.values}') + cls = PROTOCOLS.get(request.values['protocol']) + if not cls: + error(f'Unknown protocol {request.values["protocol"]}') + domain = request.values['domain'] - # TODO(#512): parameterize by protocol - g.user = Web.get_by_id(domain) + g.user = cls.get_by_id(domain) if not g.user: error(f'No web user found for domain {domain}') @@ -135,7 +138,7 @@ class FollowCallback(indieauth.Callback): session['indieauthed-me'] = me domain = util.domain_from_link(me) - # TODO(#512): parameterize by protocol + # Web is hard-coded here since this is IndieAuth g.user = Web.get_by_id(domain) if not g.user: error(f'No web user for domain {domain}') @@ -179,12 +182,12 @@ class FollowCallback(indieauth.Callback): 'actor': g.user.actor_id(), 'to': [as2.PUBLIC_AUDIENCE], } - obj = models.Object(id=follow_id, domains=[domain], labels=['user'], - source_protocol='ui', status='complete', as2=follow_as2) + obj = Object(id=follow_id, domains=[domain], labels=['user'], + source_protocol='ui', status='complete', as2=follow_as2) ActivityPub.send(obj, inbox) - models.Follower.get_or_create(dest=id, src=domain, status='active', - last_follow=follow_as2) + Follower.get_or_create(dest=id, src=domain, status='active', + last_follow=follow_as2) obj.put() link = common.pretty_link(util.get_url(followee) or id, text=addr) @@ -224,13 +227,13 @@ class UnfollowCallback(indieauth.Callback): session['indieauthed-me'] = me domain = util.domain_from_link(me) - # TODO(#512): parameterize by protocol + # Web is hard-coded here since this is IndieAuth g.user = Web.get_by_id(domain) if not g.user: error(f'No web user for domain {domain}') domain = g.user.key.id() - follower = models.Follower.get_by_id(state) + follower = Follower.get_by_id(state) if not follower: error(f'Bad state {state}') @@ -258,8 +261,8 @@ class UnfollowCallback(indieauth.Callback): 'object': follower.last_follow, } - obj = models.Object(id=unfollow_id, domains=[domain], labels=['user'], - source_protocol='ui', status='complete', as2=unfollow_as2) + obj = Object(id=unfollow_id, domains=[domain], labels=['user'], + source_protocol='ui', status='complete', as2=unfollow_as2) ActivityPub.send(obj, inbox) follower.status = 'inactive' diff --git a/tests/test_follow.py b/tests/test_follow.py index 9a07218..f69b663 100644 --- a/tests/test_follow.py +++ b/tests/test_follow.py @@ -64,21 +64,29 @@ class RemoteFollowTest(testutil.TestCase): super().setUp() self.make_user('me') - def test_follow_no_domain(self, mock_get): - got = self.client.post('/remote-follow?address=@foo@bar') + def test_follow_no_domain(self, _): + got = self.client.post('/remote-follow?address=@foo@bar&protocol=web') self.assertEqual(400, got.status_code) - def test_follow_no_address(self, mock_get): - got = self.client.post('/remote-follow?domain=baz.com') + def test_follow_no_address(self, _): + got = self.client.post('/remote-follow?domain=baz.com&protocol=web') self.assertEqual(400, got.status_code) - def test_follow_no_user(self, mock_get): + def test_follow_no_protocol(self, _): + got = self.client.post('/remote-follow?address=@foo@bar&domain=me') + self.assertEqual(400, got.status_code) + + def test_follow_unknown_protocol(self, _): + got = self.client.post('/remote-follow?address=@foo@bar&domain=me&protocol=foo') + self.assertEqual(400, got.status_code) + + def test_follow_no_user(self, _): got = self.client.post('/remote-follow?address=@foo@bar&domain=baz.com') self.assertEqual(400, got.status_code) def test_follow(self, mock_get): mock_get.return_value = WEBFINGER - got = self.client.post('/remote-follow?address=@foo@bar&domain=me') + got = self.client.post('/remote-follow?address=@foo@bar&domain=me&protocol=web') self.assertEqual(302, got.status_code) self.assertEqual('https://bar/follow?uri=@me@me', got.headers['Location']) @@ -89,7 +97,7 @@ class RemoteFollowTest(testutil.TestCase): def test_follow_url(self, mock_get): mock_get.return_value = WEBFINGER - got = self.client.post('/remote-follow?address=https://bar/foo&domain=me') + got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web') self.assertEqual(302, got.status_code) self.assertEqual('https://bar/follow?uri=@me@me', got.headers['Location']) @@ -103,21 +111,21 @@ class RemoteFollowTest(testutil.TestCase): 'links': [{'rel': 'other', 'template': 'meh'}], }) - got = self.client.post('/remote-follow?address=https://bar/foo&domain=me') + got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web') self.assertEqual(302, got.status_code) self.assertEqual('/web/me', got.headers['Location']) def test_follow_no_webfinger_subscribe_link(self, mock_get): mock_get.return_value = requests_response(status_code=500) - got = self.client.post('/remote-follow?address=https://bar/foo&domain=me') + got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web') self.assertEqual(302, got.status_code) self.assertEqual('/web/me', got.headers['Location']) def test_follow_no_webfinger_subscribe_link(self, mock_get): mock_get.return_value = requests_response('not json') - got = self.client.post('/remote-follow?address=https://bar/foo&domain=me') + got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web') self.assertEqual(302, got.status_code) self.assertEqual('/web/me', got.headers['Location'])