diff --git a/activitypub.py b/activitypub.py index 19c4e95..f76da8a 100644 --- a/activitypub.py +++ b/activitypub.py @@ -384,7 +384,8 @@ class ActivityPub(User, Protocol): if (activity.get('type') == 'Delete' and obj_id and keyId == fragmentless(obj_id)): logger.info('Object/actor being deleted is also keyId') - key_actor = Object(id=keyId, source_protocol='activitypub', deleted=True) + key_actor = Object.get_or_create( + id=keyId, source_protocol='activitypub', deleted=True) key_actor.put() else: raise diff --git a/models.py b/models.py index 557c26f..3fc0f37 100644 --- a/models.py +++ b/models.py @@ -224,7 +224,7 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): @classmethod @ndb.transactional() def get_or_create(cls, id, propagate=False, **kwargs): - """Loads and returns a :class:`User`\. Creates it if necessary. + """Loads and returns a :class:`User`. Creates it if necessary. Args: propagate (bool): whether to create copies of this user in push-based @@ -245,9 +245,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): else: user = cls(id=id, **kwargs) - if propagate: - # force refresh user profile - user.obj = cls.load(user.profile_id(), remote=True) + # load user profile object, refreshing if necessary + user.obj = cls.load(user.profile_id(), remote=True if propagate else None) if propagate and cls.LABEL != 'atproto' and not user.atproto_did: PROTOCOLS['atproto'].create_for(user) @@ -750,7 +749,7 @@ class Object(StringIdModel): authorized = (as1.get_ids(orig_as1, 'author') + as1.get_ids(orig_as1, 'actor')) if not actor: - logger.warning(f'Cowardly refusing to overwrite {id} without checking actor') + logger.warning(f'would cowardly refuse to overwrite {id} without checking actor') elif actor not in authorized + [id]: logger.warning(f"actor {actor} isn't {id}'s author or actor {authorized}") else: diff --git a/tests/test_activitypub.py b/tests/test_activitypub.py index 858f25f..18abb97 100644 --- a/tests/test_activitypub.py +++ b/tests/test_activitypub.py @@ -449,7 +449,9 @@ class ActivityPubTest(TestCase): resp = self.post('/ap/sharedInbox', json=note) self.assertEqual(400, resp.status_code) - def test_inbox_no_matching_protocol(self, *_): + def test_inbox_no_matching_protocol(self, mock_head, mock_get, mock_post): + # TODO: remove + mock_get.return_value = self.as2_resp(ACTOR) resp = self.post('/foo.json/inbox', json=NOTE) self.assertEqual(400, resp.status_code) @@ -687,13 +689,12 @@ class ActivityPubTest(TestCase): object_ids=['https://user.com/orig']) def test_shared_inbox_repost_of_fediverse(self, mock_head, mock_get, mock_post): - Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), - from_=self.user) + to = self.make_user(ACTOR['id'], cls=ActivityPub) + Follower.get_or_create(to=to, from_=self.user) baz = self.make_user('fake:baz', cls=Fake, obj_id='fake:baz') - Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), from_=baz) + Follower.get_or_create(to=to, from_=baz) baj = self.make_user('fake:baj', cls=Fake, obj_id='fake:baj') - Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), - from_=baj, status='inactive') + Follower.get_or_create(to=to, from_=baj, status='inactive') mock_head.return_value = requests_response(url='http://target') mock_get.side_effect = [ @@ -752,7 +753,7 @@ class ActivityPubTest(TestCase): object_ids=['http://nope.com/post']) def test_inbox_not_public(self, mock_head, mock_get, mock_post): - Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), + Follower.get_or_create(to=self.make_user(ACTOR['id'], cls=ActivityPub), from_=self.user) mock_head.return_value = requests_response(url='http://target') @@ -983,7 +984,7 @@ class ActivityPubTest(TestCase): def test_inbox_undo_follow(self, mock_head, mock_get, mock_post): follower = Follower(to=self.user.key, - from_=ActivityPub.get_or_create(ACTOR['id']).key, + from_=ActivityPub(id=ACTOR['id']).key, status='active') follower.put() @@ -1000,9 +1001,10 @@ class ActivityPubTest(TestCase): self.assertEqual('inactive', follower.key.get().status) def test_inbox_follow_inactive(self, mock_head, mock_get, mock_post): - follower = Follower.get_or_create(to=self.user, - from_=ActivityPub.get_or_create(ACTOR['id']), - status='inactive') + follower = Follower.get_or_create( + to=self.user, + from_=self.make_user(ACTOR['id'], cls=ActivityPub), + status='inactive') mock_head.return_value = requests_response(url='https://user.com/') mock_get.side_effect = [ @@ -1188,14 +1190,14 @@ class ActivityPubTest(TestCase): mock_common_log.assert_any_call('Returning 401: No HTTP Signature', exc_info=None) def test_delete_actor(self, *mocks): - follower = Follower.get_or_create( - to=self.user, from_=ActivityPub.get_or_create(DELETE['actor'])) - followee = Follower.get_or_create( - to=ActivityPub.get_or_create(DELETE['actor']), - from_=Fake.get_or_create('snarfed.org')) + deleted = self.make_user(DELETE['actor'], cls=ActivityPub) + follower = Follower.get_or_create(to=self.user, from_=deleted) + followee = Follower.get_or_create(to=deleted, from_=Fake(id='fake:user')) + # other unrelated follower - other = Follower.get_or_create( - to=self.user, from_=ActivityPub.get_or_create('https://mas.to/users/other')) + other = self.make_user('https://mas.to/users/other', cls=ActivityPub) + other = Follower.get_or_create(to=self.user, from_=other) + self.assertEqual(3, Follower.query().count()) got = self.post('/ap/sharedInbox', json=DELETE) diff --git a/tests/test_atproto.py b/tests/test_atproto.py index cc74c15..c5df172 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -84,7 +84,7 @@ class ATProtoTest(TestCase): @patch('requests.get', return_value=requests_response(DID_DOC)) def test_get_or_create(self, _): - user = ATProto.get_or_create('did:plc:foo') + user = self.make_user('did:plc:foo', cls=ATProto) self.assertEqual('han.dull', user.key.get().handle) def test_put_blocks_atproto_did(self): diff --git a/tests/test_follow.py b/tests/test_follow.py index 7e0fb7a..ce9fe32 100644 --- a/tests/test_follow.py +++ b/tests/test_follow.py @@ -154,10 +154,10 @@ class FollowTest(TestCase): def test_callback_address(self, mock_get, mock_post): mock_get.side_effect = ( - # oauth-dropins indieauth https://alice.com fetch for user json - requests_response(''), + requests_response(''), # indieauth https://alice.com fetch for user json WEBFINGER, self.as2_resp(FOLLOWEE), + self.as2_resp(FOLLOWEE), ) mock_post.side_effect = ( requests_response('me=https://alice.com'), @@ -173,7 +173,8 @@ class FollowTest(TestCase): def test_callback_url(self, mock_get, mock_post): mock_get.side_effect = ( - requests_response(''), + requests_response(''), # indieauth https://alice.com fetch for user json + self.as2_resp(FOLLOWEE), self.as2_resp(FOLLOWEE), ) mock_post.side_effect = ( @@ -234,6 +235,7 @@ class FollowTest(TestCase): mock_get.side_effect = ( requests_response(''), self.as2_resp(followee), + self.as2_resp(followee), ) mock_post.side_effect = ( requests_response('me=https://alice.com'), @@ -304,6 +306,7 @@ class FollowTest(TestCase): mock_get.side_effect = ( requests_response(''), self.as2_resp(FOLLOWEE), + self.as2_resp(FOLLOWEE), ) mock_post.side_effect = ( requests_response('me=https://alice.com'), @@ -349,6 +352,7 @@ class FollowTest(TestCase): mock_get.side_effect = ( requests_response(''), self.as2_resp(followee), + self.as2_resp(followee), ) mock_post.side_effect = ( requests_response('me=https://alice.com'), @@ -371,6 +375,7 @@ class FollowTest(TestCase): def test_indieauthed_session(self, mock_get, mock_post): mock_get.side_effect = ( self.as2_resp(FOLLOWEE), + self.as2_resp(FOLLOWEE), ) mock_post.side_effect = ( requests_response('OK'), # AP Follow to inbox diff --git a/tests/test_web.py b/tests/test_web.py index a31ef20..5d3f729 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -465,18 +465,24 @@ class WebTest(TestCase): with self.assertRaises(AssertionError): Web(id=bad).put() - def test_get_or_create_lower_cases_domain(self, *_): + def test_get_or_create_lower_cases_domain(self, mock_get, mock_post): + mock_get.return_value = requests_response('') + user = Web.get_or_create('AbC.oRg') self.assertEqual('abc.org', user.key.id()) self.assert_entities_equal(user, Web.get_by_id('abc.org')) self.assertIsNone(Web.get_by_id('AbC.oRg')) - def test_get_or_create_unicode_domain(self, *_): + def test_get_or_create_unicode_domain(self, mock_get, mock_post): + mock_get.return_value = requests_response('') + user = Web.get_or_create('☃.net') self.assertEqual('☃.net', user.key.id()) self.assert_entities_equal(user, Web.get_by_id('☃.net')) - def test_get_or_create_scripts_leading_trailing_dots(self, *_): + def test_get_or_create_scripts_leading_trailing_dots(self, mock_get, mock_post): + mock_get.return_value = requests_response('') + user = Web.get_or_create('..foo.bar.') self.assertEqual('foo.bar', user.key.id()) self.assert_entities_equal(user, Web.get_by_id('foo.bar')) @@ -1861,8 +1867,7 @@ http://this/404s redir = 'http://localhost/.well-known/webfinger?resource=acct:user.com@user.com' mock_get.side_effect = ( requests_response('', status=302, redirected_url=redir), - requests_response(ACTOR_HTML, url='https://user.com/', - content_type=CONTENT_TYPE_HTML), + ACTOR_HTML_RESP, ) got = self.post('/web-site', data={'url': 'https://user.com/'}) @@ -1877,6 +1882,7 @@ http://this/404s mock_get.side_effect = ( requests_response(''), requests_response(''), + requests_response(''), ) got = self.post('/web-site', data={'url': 'https://☃.net/'}) @@ -1888,6 +1894,7 @@ http://this/404s mock_get.side_effect = ( requests_response(''), requests_response(''), + requests_response(''), ) got = self.post('/web-site', data={'url': 'https://AbC.oRg/'}) @@ -1918,9 +1925,10 @@ http://this/404s get_flashed_messages()) self.assertEqual(1, Web.query().count()) - def test_check_web_site_fetch_fails(self, mock_get, _): + def test_check_webfinger_redirects_then_fails(self, mock_get, _): redir = 'http://localhost/.well-known/webfinger?resource=acct:orig@orig' mock_get.side_effect = ( + ACTOR_HTML_RESP, requests_response('', status=302, redirected_url=redir), requests_response('', status=503), ) @@ -1930,6 +1938,14 @@ http://this/404s self.assertTrue(get_flashed_messages()[0].startswith( "Couldn't connect to https://orig.co/: ")) + def test_check_web_site_fetch_fails(self, mock_get, _): + mock_get.return_value = requests_response('', status=503) + + got = self.post('/web-site', data={'url': 'https://orig.co/'}) + self.assert_equals(200, got.status_code, got.headers) + self.assertTrue(get_flashed_messages()[0].startswith( + "Couldn't connect to https://orig.co/: ")) + @patch('requests.post') @patch('requests.get') diff --git a/web.py b/web.py index b606fc9..57bae46 100644 --- a/web.py +++ b/web.py @@ -181,14 +181,14 @@ class Web(User, Protocol): logger.info(f'Verifying {domain}') if domain.startswith('www.') and domain not in WWW_DOMAINS: - # if root domain redirects to www, use root domain instead + # if root domain serves ok, use it instead # https://github.com/snarfed/bridgy-fed/issues/314 root = domain.removeprefix("www.") root_site = f'https://{root}/' try: resp = util.requests_get(root_site, gateway=False) if resp.ok and self.is_web_url(resp.url): - logger.info(f'{root_site} redirects to {resp.url} ; using {root} instead') + logger.info(f'{root_site} serves ok ; using {root} instead') root_user = Web.get_or_create(root) self.use_instead = root_user.key self.put() @@ -321,7 +321,7 @@ class Web(User, Protocol): logger.info(f'Skipping sending {verb} (not supported in webmention/mf2) to {url}') return False elif url not in as1.targets(obj.as1): - logger.info(f'Skipping sending to {url} , not a target in the object') + # logger.info(f'Skipping sending to {url} , not a target in the object') return False elif to_cls.is_blocklisted(url): logger.info(f'Skipping sending to blocklisted {url}') @@ -478,8 +478,8 @@ def check_web_site(): flash(f'{url} is not a valid or supported web site') return render_template('enter_web_site.html'), 400 - g.user = Web.get_or_create(domain, direct=True) try: + g.user = Web.get_or_create(domain, direct=True) g.user = g.user.verify() except BaseException as e: code, body = util.interpret_http_exception(e)