improve domain validation for Web key ids, normalize to lower case

pull/542/head
Ryan Barrett 2023-06-09 10:58:28 -07:00
rodzic 0f19654eb2
commit 7f6cc61683
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
6 zmienionych plików z 90 dodań i 20 usunięć

Wyświetl plik

@ -20,7 +20,15 @@ from oauth_dropins.webutil.util import json_dumps, json_loads
logger = logging.getLogger(__name__)
DOMAIN_RE = r'[^/:]+\.[^/:]+'
# allow hostname chars (a-z, 0-9, -), allow arbitrary unicode (eg ☃.net), don't
# allow specific chars that we'll often see in webfinger, AP handles, etc. (@, :)
# https://stackoverflow.com/questions/10306690/what-is-a-regular-expression-which-will-match-a-valid-domain-name-without-a-subd
#
# uses $ at end but not ^ at the beginning so that it can be used to match just
# part of a URL path segment, eg for /acct:user.com in webfinger.py. code that
# needs to validate a whole string has to anchor the start itself, eg by using
# re.match instead of re.search.
#
# NOTE(review): the character class before the dot excludes ';' but the one
# after the dot does not. confirm whether that asymmetry is intentional.
#
# TODO: preprocess with domain2idna, then narrow this to just [a-z0-9-]
DOMAIN_RE = r'[^/:;@_?!\']+\.[^/:@_?!\']+$'
# file extension style suffixes. presumably used to reject "domains" that are
# really file names, eg foo.html; the code that reads this isn't visible here.
TLD_BLOCKLIST = ('7z', 'asp', 'aspx', 'gif', 'html', 'ico', 'jpg', 'jpeg', 'js',
'json', 'php', 'png', 'rar', 'txt', 'yaml', 'yml', 'zip')

Wyświetl plik

@ -21,12 +21,12 @@ class CommonTest(TestCase):
def setUpClass(cls):
with appengine_config.ndb_client.context():
# do this in setUpClass since generating RSA keys is slow
cls.user = cls.make_user('site')
cls.user = cls.make_user('user.com')
def setUp(self):
super().setUp()
self.request_context.push()
g.user = Fake(id='site')
g.user = Fake(id='user.com')
def tearDown(self):
self.request_context.pop()
@ -48,8 +48,8 @@ class CommonTest(TestCase):
common.pretty_link('http://foo'))
self.assertEqual(
'<a class="h-card u-author" href="/fake/site"><img src="" class="profile"> site</a>',
common.pretty_link('https://site/'))
'<a class="h-card u-author" href="/fake/user.com"><img src="" class="profile"> user.com</a>',
common.pretty_link('https://user.com/'))
def test_redirect_wrap_empty(self):
self.assertIsNone(common.redirect_wrap(None))

Wyświetl plik

@ -63,7 +63,7 @@ class RemoteFollowTest(TestCase):
def setUp(self):
super().setUp()
self.make_user('me')
self.make_user('user.com')
def test_no_domain(self, _):
got = self.client.post('/remote-follow?address=@foo@bar&protocol=web')
@ -74,11 +74,11 @@ class RemoteFollowTest(TestCase):
self.assertEqual(400, got.status_code)
def test_no_protocol(self, _):
got = self.client.post('/remote-follow?address=@foo@bar&domain=me')
got = self.client.post('/remote-follow?address=@foo@bar&domain=user.com')
self.assertEqual(400, got.status_code)
def test_unknown_protocol(self, _):
got = self.client.post('/remote-follow?address=@foo@bar&domain=me&protocol=foo')
got = self.client.post('/remote-follow?address=@foo@bar&domain=user.com&protocol=foo')
self.assertEqual(400, got.status_code)
def test_no_user(self, _):
@ -87,9 +87,9 @@ class RemoteFollowTest(TestCase):
def test(self, mock_get):
mock_get.return_value = WEBFINGER
got = self.client.post('/remote-follow?address=@foo@bar&domain=me&protocol=web')
got = self.client.post('/remote-follow?address=@foo@bar&domain=user.com&protocol=web')
self.assertEqual(302, got.status_code)
self.assertEqual('https://bar/follow?uri=@me@me',
self.assertEqual('https://bar/follow?uri=@user.com@user.com',
got.headers['Location'])
mock_get.assert_has_calls((
@ -98,9 +98,9 @@ class RemoteFollowTest(TestCase):
def test_url(self, mock_get):
mock_get.return_value = WEBFINGER
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
self.assertEqual(302, got.status_code)
self.assertEqual('https://bar/follow?uri=@me@me', got.headers['Location'])
self.assertEqual('https://bar/follow?uri=@user.com@user.com', got.headers['Location'])
mock_get.assert_has_calls((
self.req('https://bar/.well-known/webfinger?resource=https://bar/foo'),
@ -112,23 +112,23 @@ class RemoteFollowTest(TestCase):
'links': [{'rel': 'other', 'template': 'meh'}],
})
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
self.assertEqual(302, got.status_code)
self.assertEqual('/web/me', got.headers['Location'])
self.assertEqual('/web/user.com', got.headers['Location'])
def test_webfinger_error(self, mock_get):
mock_get.return_value = requests_response(status=500)
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
self.assertEqual(302, got.status_code)
self.assertEqual('/web/me', got.headers['Location'])
self.assertEqual('/web/user.com', got.headers['Location'])
def test_webfinger_returns_not_json(self, mock_get):
mock_get.return_value = requests_response('<html>not json</html>')
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
self.assertEqual(302, got.status_code)
self.assertEqual('/web/me', got.headers['Location'])
self.assertEqual('/web/user.com', got.headers['Location'])
@patch('requests.post')

Wyświetl plik

@ -407,6 +407,30 @@ class WebTest(TestCase):
def assert_object(self, id, **props):
    """Delegate to the superclass check with delivered_protocol='activitypub' added."""
    # dict() raises TypeError on a duplicate delivered_protocol keyword, just
    # like passing it twice in a call would.
    expected = dict(delivered_protocol='activitypub', **props)
    return super().assert_object(id, **expected)
def test_put_validates_domain_id(self, *_):
    """Storing a Web entity with a malformed key id raises AssertionError."""
    invalid_ids = (
        'AbC.cOm',                  # upper case
        'foo',                      # no dot
        '@user.com',                # the rest contain @ and/or :
        '@user.com@user.com',
        'acct:user.com',
        'acct:@user.com@user.com',
        'acc:me@user.com',
    )
    for invalid in invalid_ids:
        with self.assertRaises(AssertionError):
            Web(id=invalid).put()
def test_get_or_create_lower_cases_domain(self, *_):
    """get_or_create stores a mixed case domain under its lower cased key id."""
    created = Web.get_or_create('AbC.oRg')
    self.assertEqual('abc.org', created.key.id())

    # only the lower cased id is stored; the original casing finds nothing
    stored = Web.get_by_id('abc.org')
    self.assert_entities_equal(created, stored)
    self.assertIsNone(Web.get_by_id('AbC.oRg'))
def test_get_or_create_unicode_domain(self, *_):
    """get_or_create accepts a non-ASCII domain and keeps it as the key id."""
    created = Web.get_or_create('☃.net')
    self.assertEqual('☃.net', created.key.id())

    stored = Web.get_by_id('☃.net')
    self.assert_entities_equal(created, stored)
def test_bad_source_url(self, mock_get, mock_post):
for data in b'', {'source': 'bad'}, {'source': 'https://'}:
got = self.client.post('/webmention', data=data)
@ -1581,6 +1605,29 @@ http://this/404s
self.assertEqual('Person', user.actor_as2['type'])
self.assertEqual('http://localhost/user.com', user.actor_as2['id'])
def test_check_web_site_unicode_domain(self, mock_get, _):
    """POSTing a unicode domain URL to /web-site redirects to its user page."""
    # two separate empty responses, returned by successive GETs
    first, second = requests_response(''), requests_response('')
    mock_get.side_effect = (first, second)

    resp = self.client.post('/web-site', data={'url': 'https://☃.net/'})

    self.assert_equals(302, resp.status_code)
    # the redirect target percent-encodes the snowman
    self.assert_equals('/web/%E2%98%83.net', resp.headers['Location'])
    self.assertIsNotNone(Web.get_by_id('☃.net'))
def test_check_web_site_lower_cases_domain(self, mock_get, _):
    """POSTing a mixed case domain URL to /web-site stores the lower cased domain."""
    # two separate empty responses, returned by successive GETs
    first, second = requests_response(''), requests_response('')
    mock_get.side_effect = (first, second)

    resp = self.client.post('/web-site', data={'url': 'https://AbC.oRg/'})

    self.assert_equals(302, resp.status_code)
    self.assert_equals('/web/abc.org', resp.headers['Location'])
    # only the lower cased id exists afterward
    self.assertIsNotNone(Web.get_by_id('abc.org'))
    self.assertIsNone(Web.get_by_id('AbC.oRg'))
def test_check_web_site_bad_url(self, _, __):
got = self.client.post('/web-site', data={'url': '!!!'})
self.assert_equals(200, got.status_code)
@ -1594,10 +1641,10 @@ http://this/404s
requests_response('', status=503),
)
got = self.client.post('/web-site', data={'url': 'https://orig/'})
got = self.client.post('/web-site', data={'url': 'https://orig.co/'})
self.assert_equals(200, got.status_code, got.headers)
self.assertTrue(get_flashed_messages()[0].startswith(
"Couldn't connect to https://orig/: "))
"Couldn't connect to https://orig.co/: "))
@patch('requests.post')

14
web.py
Wyświetl plik

@ -2,6 +2,7 @@
import datetime
import difflib
import logging
import re
import urllib.parse
from urllib.parse import urlencode, urljoin, urlparse
@ -62,6 +63,18 @@ class Web(User, Protocol):
if username != self.key.id():
return util.domain_from_link(username, minimize=False)
def put(self, *args, **kwargs):
    """Validate the domain key id, then store this entity.

    The key id must match ``common.DOMAIN_RE`` starting at its first
    character, and must not contain upper case characters.

    Args:
      args, kwargs: passed through unchanged to the superclass ``put``

    Returns:
      whatever the superclass ``put`` returns

    Raises:
      AssertionError: if the key id doesn't match ``common.DOMAIN_RE`` or
        isn't entirely lower case
    """
    domain = self.key.id()
    # these are deliberately asserts, since the tests expect AssertionError.
    # note that Python skips assert statements when run with -O.
    assert re.match(common.DOMAIN_RE, domain), \
        f'invalid domain in Web key id: {domain}'
    assert domain.lower() == domain, \
        f'upper case is not allowed in Web key id: {domain}'
    return super().put(*args, **kwargs)
@classmethod
def get_or_create(cls, id, **kwargs):
    """Lower case the domain id, then defer to :meth:`User.get_or_create`.

    Args:
      id: str, domain
      kwargs: passed through unchanged

    Returns:
      whatever the superclass ``get_or_create`` returns
    """
    domain = id.lower()
    return super().get_or_create(domain, **kwargs)
def web_url(self):
    """Build this user's home page URL from the key id, eg 'https://foo.com/'."""
    domain = self.key.id()
    return f'https://{domain}/'
@ -325,6 +338,7 @@ def enter_web_site():
@app.post('/web-site')
def check_web_site():
url = request.values['url']
# this normalizes and lower cases domain
domain = util.domain_from_link(url, minimize=False)
if not domain:
flash(f'No domain found in {url}')

Wyświetl plik

@ -228,6 +228,7 @@ def fetch(addr):
return data
# TODO: why do we serve this URL? should we drop it?
app.add_url_rule(f'/acct:<regex("{common.DOMAIN_RE}"):domain>',
view_func=Actor.as_view('actor_acct'))
app.add_url_rule('/.well-known/webfinger', view_func=Webfinger.as_view('webfinger'))