kopia lustrzana https://github.com/snarfed/bridgy-fed
refactor validating handles in ATProto and elsewhere
for https://github.com/snarfed/bridgy-fed/issues/982pull/1020/head
rodzic
b8e67829e3
commit
2bf526ab7c
|
@ -158,7 +158,7 @@ class ActivityPub(User, Protocol):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_handle(cls, handle):
|
def owns_handle(cls, handle, allow_internal=False):
|
||||||
"""Returns True if handle is a WebFinger ``@-@`` handle, False otherwise.
|
"""Returns True if handle is a WebFinger ``@-@`` handle, False otherwise.
|
||||||
|
|
||||||
Example: ``@user@instance.com``. The leading ``@`` is optional.
|
Example: ``@user@instance.com``. The leading ``@`` is optional.
|
||||||
|
@ -171,7 +171,8 @@ class ActivityPub(User, Protocol):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
user, domain = parts
|
user, domain = parts
|
||||||
return user and domain and not cls.is_blocklisted(domain)
|
return user and domain and not cls.is_blocklisted(
|
||||||
|
domain, allow_internal=allow_internal)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def handle_to_id(cls, handle):
|
def handle_to_id(cls, handle):
|
||||||
|
|
|
@ -128,8 +128,9 @@ class ATProto(User, Protocol):
|
||||||
or id.startswith('https://bsky.app/'))
|
or id.startswith('https://bsky.app/'))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_handle(cls, handle):
|
def owns_handle(cls, handle, allow_internal=False):
|
||||||
if not re.match(DOMAIN_RE, handle):
|
# TODO: implement allow_internal
|
||||||
|
if not did.HANDLE_RE.fullmatch(handle):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -248,6 +249,10 @@ class ATProto(User, Protocol):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (models.User)
|
user (models.User)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the user's handle is invalid, eg begins or ends with an
|
||||||
|
underscore or dash
|
||||||
"""
|
"""
|
||||||
assert not isinstance(user, ATProto)
|
assert not isinstance(user, ATProto)
|
||||||
|
|
||||||
|
|
38
ids.py
38
ids.py
|
@ -6,7 +6,6 @@ import logging
|
||||||
import re
|
import re
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
from arroba import did
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from google.cloud.ndb.query import FilterNode, Query
|
from google.cloud.ndb.query import FilterNode, Query
|
||||||
from granary.bluesky import BSKY_APP_URL_RE, web_url_to_at_uri
|
from granary.bluesky import BSKY_APP_URL_RE, web_url_to_at_uri
|
||||||
|
@ -162,50 +161,57 @@ def translate_handle(*, handle, from_, to, enhanced):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: the corresponding handle in ``to``
|
str: the corresponding handle in ``to``
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the user's handle is invalid, eg begins or ends with an
|
||||||
|
underscore or dash
|
||||||
"""
|
"""
|
||||||
assert handle and from_ and to, (handle, from_, to)
|
assert handle and from_ and to, (handle, from_, to)
|
||||||
assert from_.owns_handle(handle) is not False or from_.LABEL == 'ui'
|
if not from_.LABEL == 'ui':
|
||||||
|
if from_.owns_handle(handle, allow_internal=True) is False:
|
||||||
if from_.LABEL == 'atproto':
|
raise ValueError(f'input handle {handle} is not valid for {from_.LABEL}')
|
||||||
assert did.HANDLE_RE.fullmatch(handle)
|
|
||||||
|
|
||||||
if from_ == to:
|
if from_ == to:
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
output = None
|
||||||
match from_.LABEL, to.LABEL:
|
match from_.LABEL, to.LABEL:
|
||||||
case _, 'activitypub':
|
case _, 'activitypub':
|
||||||
domain = f'{from_.ABBREV}{SUPERDOMAIN}'
|
domain = f'{from_.ABBREV}{SUPERDOMAIN}'
|
||||||
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
|
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
|
||||||
domain = handle
|
domain = handle
|
||||||
return f'@{handle}@{domain}'
|
output = f'@{handle}@{domain}'
|
||||||
|
|
||||||
case _, 'atproto':
|
case _, 'atproto':
|
||||||
|
output = handle.lstrip('@').replace('@', '.')
|
||||||
for from_char in ATPROTO_DASH_CHARS:
|
for from_char in ATPROTO_DASH_CHARS:
|
||||||
handle = handle.replace(from_char, '-')
|
output = output.replace(from_char, '-')
|
||||||
|
|
||||||
handle = handle.lstrip('@').replace('@', '.')
|
|
||||||
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
|
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
handle = f'{handle}.{from_.ABBREV}{SUPERDOMAIN}'
|
output = f'{output}.{from_.ABBREV}{SUPERDOMAIN}'
|
||||||
|
|
||||||
assert did.HANDLE_RE.fullmatch(handle)
|
|
||||||
return handle
|
|
||||||
|
|
||||||
case 'activitypub', 'web':
|
case 'activitypub', 'web':
|
||||||
user, instance = handle.lstrip('@').split('@')
|
user, instance = handle.lstrip('@').split('@')
|
||||||
# TODO: get this from the actor object's url field?
|
# TODO: get this from the actor object's url field?
|
||||||
return (f'https://{user}' if user == instance
|
output = (f'https://{user}' if user == instance
|
||||||
else f'https://{instance}/@{user}')
|
else f'https://{instance}/@{user}')
|
||||||
|
|
||||||
case _, 'web':
|
case _, 'web':
|
||||||
return handle
|
output = handle
|
||||||
|
|
||||||
# only for unit tests
|
# only for unit tests
|
||||||
case _, 'fake' | 'other' | 'eefake':
|
case _, 'fake' | 'other' | 'eefake':
|
||||||
return f'{to.LABEL}:handle:{handle}'
|
output = f'{to.LABEL}:handle:{handle}'
|
||||||
|
|
||||||
assert False, (handle, from_.LABEL, to.LABEL)
|
assert output, (handle, from_.LABEL, to.LABEL)
|
||||||
|
# don't check Web handles because they're sometimes URLs, eg
|
||||||
|
# @user@instance => https://instance/@user
|
||||||
|
if to.LABEL != 'web' and to.owns_handle(output, allow_internal=True) is False:
|
||||||
|
raise ValueError(f'translated handle {output} is not valid for {to.LABEL}')
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def translate_object_id(*, id, from_, to):
|
def translate_object_id(*, id, from_, to):
|
||||||
|
|
|
@ -172,7 +172,7 @@ class Protocol:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_handle(cls, handle):
|
def owns_handle(cls, handle, allow_internal=False):
|
||||||
"""Returns whether this protocol owns the handle, or None if it's unclear.
|
"""Returns whether this protocol owns the handle, or None if it's unclear.
|
||||||
|
|
||||||
To be implemented by subclasses.
|
To be implemented by subclasses.
|
||||||
|
@ -192,6 +192,8 @@ class Protocol:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
handle (str)
|
handle (str)
|
||||||
|
allow_internal (bool): whether to return False for internal domains
|
||||||
|
like ``fed.brid.gy``, ``bsky.brid.gy``, etc
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool or None
|
bool or None
|
||||||
|
@ -409,6 +411,9 @@ class Protocol:
|
||||||
Args:
|
Args:
|
||||||
user (models.User): original source user. Shouldn't already have a
|
user (models.User): original source user. Shouldn't already have a
|
||||||
copy user for this protocol in :attr:`copies`.
|
copy user for this protocol in :attr:`copies`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if we can't create a copy of the given user in this protocol
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
|
@ -109,9 +109,9 @@ class ATProtoTest(TestCase):
|
||||||
self.assertEqual('han.dull', user.key.get().handle)
|
self.assertEqual('han.dull', user.key.get().handle)
|
||||||
|
|
||||||
def test_owns_id(self):
|
def test_owns_id(self):
|
||||||
self.assertEqual(False, ATProto.owns_id('http://foo'))
|
self.assertFalse(ATProto.owns_id('http://foo'))
|
||||||
self.assertEqual(False, ATProto.owns_id('https://bar.baz/biff'))
|
self.assertFalse(ATProto.owns_id('https://bar.baz/biff'))
|
||||||
self.assertEqual(False, ATProto.owns_id('e45fab982'))
|
self.assertFalse(ATProto.owns_id('e45fab982'))
|
||||||
|
|
||||||
self.assertTrue(ATProto.owns_id('at://did:plc:user/bar/123'))
|
self.assertTrue(ATProto.owns_id('at://did:plc:user/bar/123'))
|
||||||
self.assertTrue(ATProto.owns_id('did:plc:user'))
|
self.assertTrue(ATProto.owns_id('did:plc:user'))
|
||||||
|
@ -123,12 +123,18 @@ class ATProtoTest(TestCase):
|
||||||
self.assertIsNone(ATProto.owns_handle('foo.com'))
|
self.assertIsNone(ATProto.owns_handle('foo.com'))
|
||||||
self.assertIsNone(ATProto.owns_handle('foo.bar.com'))
|
self.assertIsNone(ATProto.owns_handle('foo.bar.com'))
|
||||||
|
|
||||||
self.assertEqual(False, ATProto.owns_handle('foo'))
|
self.assertFalse(ATProto.owns_handle('foo'))
|
||||||
self.assertEqual(False, ATProto.owns_handle('@foo'))
|
self.assertFalse(ATProto.owns_handle('@foo'))
|
||||||
self.assertEqual(False, ATProto.owns_handle('@foo.com'))
|
self.assertFalse(ATProto.owns_handle('@foo.com'))
|
||||||
self.assertEqual(False, ATProto.owns_handle('@foo@bar.com'))
|
self.assertFalse(ATProto.owns_handle('@foo@bar.com'))
|
||||||
self.assertEqual(False, ATProto.owns_handle('foo@bar.com'))
|
self.assertFalse(ATProto.owns_handle('foo@bar.com'))
|
||||||
self.assertEqual(False, ATProto.owns_handle('localhost'))
|
self.assertFalse(ATProto.owns_handle('localhost'))
|
||||||
|
|
||||||
|
self.assertFalse(ATProto.owns_handle('_foo.com'))
|
||||||
|
self.assertFalse(ATProto.owns_handle('-foo.com'))
|
||||||
|
self.assertFalse(ATProto.owns_handle('foo_.com'))
|
||||||
|
self.assertFalse(ATProto.owns_handle('foo-.com'))
|
||||||
|
|
||||||
# TODO: this should be False
|
# TODO: this should be False
|
||||||
self.assertIsNone(ATProto.owns_handle('web.brid.gy'))
|
self.assertIsNone(ATProto.owns_handle('web.brid.gy'))
|
||||||
|
|
||||||
|
@ -701,6 +707,12 @@ class ATProtoTest(TestCase):
|
||||||
|
|
||||||
mock_create_task.assert_called()
|
mock_create_task.assert_called()
|
||||||
|
|
||||||
|
def test_create_for_bad_handle(self):
|
||||||
|
# underscores gets translated to dashes, trailing/leading aren't allowed
|
||||||
|
for bad in 'fake:user_', '_fake:user':
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ATProto.create_for(Fake(id=bad))
|
||||||
|
|
||||||
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
||||||
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
||||||
@patch('requests.post',
|
@patch('requests.post',
|
||||||
|
|
|
@ -108,7 +108,7 @@ class Fake(User, protocol.Protocol):
|
||||||
or id in cls.fetchable)
|
or id in cls.fetchable)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_handle(cls, handle):
|
def owns_handle(cls, handle, allow_internal=False):
|
||||||
return handle.startswith(f'{cls.LABEL}:handle:')
|
return handle.startswith(f'{cls.LABEL}:handle:')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
4
web.py
4
web.py
|
@ -358,10 +358,10 @@ class Web(User, Protocol):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_handle(cls, handle):
|
def owns_handle(cls, handle, allow_internal=False):
|
||||||
if handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
|
if handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
|
||||||
return True
|
return True
|
||||||
elif not is_valid_domain(handle, allow_internal=False):
|
elif not is_valid_domain(handle, allow_internal=allow_internal):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Ładowanie…
Reference in New Issue