refactor validating handles in ATProto and elsewhere

for https://github.com/snarfed/bridgy-fed/issues/982
pull/1020/head
Ryan Barrett 2024-05-03 15:18:16 -07:00
rodzic b8e67829e3
commit 2bf526ab7c
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
7 zmienionych plików z 62 dodań i 33 usunięć

Wyświetl plik

@ -158,7 +158,7 @@ class ActivityPub(User, Protocol):
return False return False
@classmethod @classmethod
def owns_handle(cls, handle): def owns_handle(cls, handle, allow_internal=False):
"""Returns True if handle is a WebFinger ``@-@`` handle, False otherwise. """Returns True if handle is a WebFinger ``@-@`` handle, False otherwise.
Example: ``@user@instance.com``. The leading ``@`` is optional. Example: ``@user@instance.com``. The leading ``@`` is optional.
@ -171,7 +171,8 @@ class ActivityPub(User, Protocol):
return False return False
user, domain = parts user, domain = parts
return user and domain and not cls.is_blocklisted(domain) return user and domain and not cls.is_blocklisted(
domain, allow_internal=allow_internal)
@classmethod @classmethod
def handle_to_id(cls, handle): def handle_to_id(cls, handle):

Wyświetl plik

@ -128,8 +128,9 @@ class ATProto(User, Protocol):
or id.startswith('https://bsky.app/')) or id.startswith('https://bsky.app/'))
@classmethod @classmethod
def owns_handle(cls, handle): def owns_handle(cls, handle, allow_internal=False):
if not re.match(DOMAIN_RE, handle): # TODO: implement allow_internal
if not did.HANDLE_RE.fullmatch(handle):
return False return False
@classmethod @classmethod
@ -248,6 +249,10 @@ class ATProto(User, Protocol):
Args: Args:
user (models.User) user (models.User)
Raises:
ValueError: if the user's handle is invalid, eg begins or ends with an
underscore or dash
""" """
assert not isinstance(user, ATProto) assert not isinstance(user, ATProto)

38
ids.py
Wyświetl plik

@ -6,7 +6,6 @@ import logging
import re import re
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
from arroba import did
from flask import request from flask import request
from google.cloud.ndb.query import FilterNode, Query from google.cloud.ndb.query import FilterNode, Query
from granary.bluesky import BSKY_APP_URL_RE, web_url_to_at_uri from granary.bluesky import BSKY_APP_URL_RE, web_url_to_at_uri
@ -162,50 +161,57 @@ def translate_handle(*, handle, from_, to, enhanced):
Returns: Returns:
str: the corresponding handle in ``to`` str: the corresponding handle in ``to``
Raises:
ValueError: if ``handle`` is invalid for ``from_``, or the translated handle
is invalid for ``to``, eg begins or ends with an underscore or dash
""" """
assert handle and from_ and to, (handle, from_, to) assert handle and from_ and to, (handle, from_, to)
assert from_.owns_handle(handle) is not False or from_.LABEL == 'ui' if not from_.LABEL == 'ui':
if from_.owns_handle(handle, allow_internal=True) is False:
if from_.LABEL == 'atproto': raise ValueError(f'input handle {handle} is not valid for {from_.LABEL}')
assert did.HANDLE_RE.fullmatch(handle)
if from_ == to: if from_ == to:
return handle return handle
output = None
match from_.LABEL, to.LABEL: match from_.LABEL, to.LABEL:
case _, 'activitypub': case _, 'activitypub':
domain = f'{from_.ABBREV}{SUPERDOMAIN}' domain = f'{from_.ABBREV}{SUPERDOMAIN}'
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS: if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
domain = handle domain = handle
return f'@{handle}@{domain}' output = f'@{handle}@{domain}'
case _, 'atproto': case _, 'atproto':
output = handle.lstrip('@').replace('@', '.')
for from_char in ATPROTO_DASH_CHARS: for from_char in ATPROTO_DASH_CHARS:
handle = handle.replace(from_char, '-') output = output.replace(from_char, '-')
handle = handle.lstrip('@').replace('@', '.')
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS: if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
pass pass
else: else:
handle = f'{handle}.{from_.ABBREV}{SUPERDOMAIN}' output = f'{output}.{from_.ABBREV}{SUPERDOMAIN}'
assert did.HANDLE_RE.fullmatch(handle)
return handle
case 'activitypub', 'web': case 'activitypub', 'web':
user, instance = handle.lstrip('@').split('@') user, instance = handle.lstrip('@').split('@')
# TODO: get this from the actor object's url field? # TODO: get this from the actor object's url field?
return (f'https://{user}' if user == instance output = (f'https://{user}' if user == instance
else f'https://{instance}/@{user}') else f'https://{instance}/@{user}')
case _, 'web': case _, 'web':
return handle output = handle
# only for unit tests # only for unit tests
case _, 'fake' | 'other' | 'eefake': case _, 'fake' | 'other' | 'eefake':
return f'{to.LABEL}:handle:{handle}' output = f'{to.LABEL}:handle:{handle}'
assert False, (handle, from_.LABEL, to.LABEL) assert output, (handle, from_.LABEL, to.LABEL)
# don't check Web handles because they're sometimes URLs, eg
# @user@instance => https://instance/@user
if to.LABEL != 'web' and to.owns_handle(output, allow_internal=True) is False:
raise ValueError(f'translated handle {output} is not valid for {to.LABEL}')
return output
def translate_object_id(*, id, from_, to): def translate_object_id(*, id, from_, to):

Wyświetl plik

@ -172,7 +172,7 @@ class Protocol:
return False return False
@classmethod @classmethod
def owns_handle(cls, handle): def owns_handle(cls, handle, allow_internal=False):
"""Returns whether this protocol owns the handle, or None if it's unclear. """Returns whether this protocol owns the handle, or None if it's unclear.
To be implemented by subclasses. To be implemented by subclasses.
@ -192,6 +192,8 @@ class Protocol:
Args: Args:
handle (str) handle (str)
allow_internal (bool): whether to allow internal domains like
``fed.brid.gy``, ``bsky.brid.gy``, etc, instead of returning False for them
Returns: Returns:
bool or None bool or None
@ -409,6 +411,9 @@ class Protocol:
Args: Args:
user (models.User): original source user. Shouldn't already have a user (models.User): original source user. Shouldn't already have a
copy user for this protocol in :attr:`copies`. copy user for this protocol in :attr:`copies`.
Raises:
ValueError: if we can't create a copy of the given user in this protocol
""" """
raise NotImplementedError() raise NotImplementedError()

Wyświetl plik

@ -109,9 +109,9 @@ class ATProtoTest(TestCase):
self.assertEqual('han.dull', user.key.get().handle) self.assertEqual('han.dull', user.key.get().handle)
def test_owns_id(self): def test_owns_id(self):
self.assertEqual(False, ATProto.owns_id('http://foo')) self.assertFalse(ATProto.owns_id('http://foo'))
self.assertEqual(False, ATProto.owns_id('https://bar.baz/biff')) self.assertFalse(ATProto.owns_id('https://bar.baz/biff'))
self.assertEqual(False, ATProto.owns_id('e45fab982')) self.assertFalse(ATProto.owns_id('e45fab982'))
self.assertTrue(ATProto.owns_id('at://did:plc:user/bar/123')) self.assertTrue(ATProto.owns_id('at://did:plc:user/bar/123'))
self.assertTrue(ATProto.owns_id('did:plc:user')) self.assertTrue(ATProto.owns_id('did:plc:user'))
@ -123,12 +123,18 @@ class ATProtoTest(TestCase):
self.assertIsNone(ATProto.owns_handle('foo.com')) self.assertIsNone(ATProto.owns_handle('foo.com'))
self.assertIsNone(ATProto.owns_handle('foo.bar.com')) self.assertIsNone(ATProto.owns_handle('foo.bar.com'))
self.assertEqual(False, ATProto.owns_handle('foo')) self.assertFalse(ATProto.owns_handle('foo'))
self.assertEqual(False, ATProto.owns_handle('@foo')) self.assertFalse(ATProto.owns_handle('@foo'))
self.assertEqual(False, ATProto.owns_handle('@foo.com')) self.assertFalse(ATProto.owns_handle('@foo.com'))
self.assertEqual(False, ATProto.owns_handle('@foo@bar.com')) self.assertFalse(ATProto.owns_handle('@foo@bar.com'))
self.assertEqual(False, ATProto.owns_handle('foo@bar.com')) self.assertFalse(ATProto.owns_handle('foo@bar.com'))
self.assertEqual(False, ATProto.owns_handle('localhost')) self.assertFalse(ATProto.owns_handle('localhost'))
self.assertFalse(ATProto.owns_handle('_foo.com'))
self.assertFalse(ATProto.owns_handle('-foo.com'))
self.assertFalse(ATProto.owns_handle('foo_.com'))
self.assertFalse(ATProto.owns_handle('foo-.com'))
# TODO: this should be False # TODO: this should be False
self.assertIsNone(ATProto.owns_handle('web.brid.gy')) self.assertIsNone(ATProto.owns_handle('web.brid.gy'))
@ -701,6 +707,12 @@ class ATProtoTest(TestCase):
mock_create_task.assert_called() mock_create_task.assert_called()
def test_create_for_bad_handle(self):
# underscores get translated to dashes, and leading/trailing dashes aren't allowed
for bad in 'fake:user_', '_fake:user':
with self.assertRaises(ValueError):
ATProto.create_for(Fake(id=bad))
@patch('google.cloud.dns.client.ManagedZone', autospec=True) @patch('google.cloud.dns.client.ManagedZone', autospec=True)
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
@patch('requests.post', @patch('requests.post',

Wyświetl plik

@ -108,7 +108,7 @@ class Fake(User, protocol.Protocol):
or id in cls.fetchable) or id in cls.fetchable)
@classmethod @classmethod
def owns_handle(cls, handle): def owns_handle(cls, handle, allow_internal=False):
return handle.startswith(f'{cls.LABEL}:handle:') return handle.startswith(f'{cls.LABEL}:handle:')
@classmethod @classmethod

4
web.py
Wyświetl plik

@ -358,10 +358,10 @@ class Web(User, Protocol):
return False return False
@classmethod @classmethod
def owns_handle(cls, handle): def owns_handle(cls, handle, allow_internal=False):
if handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS: if handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
return True return True
elif not is_valid_domain(handle, allow_internal=False): elif not is_valid_domain(handle, allow_internal=allow_internal):
return False return False
@classmethod @classmethod