User.get_or_create: load user profile object, fetch if it doesn't exist

pull/691/head
Ryan Barrett 2023-10-19 15:01:19 -07:00
rodzic 4faf551f8f
commit fe3a9b693c
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
7 zmienionych plików z 61 dodań i 38 usunięć

Wyświetl plik

@ -384,7 +384,8 @@ class ActivityPub(User, Protocol):
if (activity.get('type') == 'Delete' and obj_id if (activity.get('type') == 'Delete' and obj_id
and keyId == fragmentless(obj_id)): and keyId == fragmentless(obj_id)):
logger.info('Object/actor being deleted is also keyId') logger.info('Object/actor being deleted is also keyId')
key_actor = Object(id=keyId, source_protocol='activitypub', deleted=True) key_actor = Object.get_or_create(
id=keyId, source_protocol='activitypub', deleted=True)
key_actor.put() key_actor.put()
else: else:
raise raise

Wyświetl plik

@ -224,7 +224,7 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
@classmethod @classmethod
@ndb.transactional() @ndb.transactional()
def get_or_create(cls, id, propagate=False, **kwargs): def get_or_create(cls, id, propagate=False, **kwargs):
"""Loads and returns a :class:`User`\. Creates it if necessary. """Loads and returns a :class:`User`. Creates it if necessary.
Args: Args:
propagate (bool): whether to create copies of this user in push-based propagate (bool): whether to create copies of this user in push-based
@ -245,9 +245,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
else: else:
user = cls(id=id, **kwargs) user = cls(id=id, **kwargs)
if propagate: # load user profile object, refreshing if necessary
# force refresh user profile user.obj = cls.load(user.profile_id(), remote=True if propagate else None)
user.obj = cls.load(user.profile_id(), remote=True)
if propagate and cls.LABEL != 'atproto' and not user.atproto_did: if propagate and cls.LABEL != 'atproto' and not user.atproto_did:
PROTOCOLS['atproto'].create_for(user) PROTOCOLS['atproto'].create_for(user)
@ -750,7 +749,7 @@ class Object(StringIdModel):
authorized = (as1.get_ids(orig_as1, 'author') + authorized = (as1.get_ids(orig_as1, 'author') +
as1.get_ids(orig_as1, 'actor')) as1.get_ids(orig_as1, 'actor'))
if not actor: if not actor:
logger.warning(f'Cowardly refusing to overwrite {id} without checking actor') logger.warning(f'would cowardly refuse to overwrite {id} without checking actor')
elif actor not in authorized + [id]: elif actor not in authorized + [id]:
logger.warning(f"actor {actor} isn't {id}'s author or actor {authorized}") logger.warning(f"actor {actor} isn't {id}'s author or actor {authorized}")
else: else:

Wyświetl plik

@ -449,7 +449,9 @@ class ActivityPubTest(TestCase):
resp = self.post('/ap/sharedInbox', json=note) resp = self.post('/ap/sharedInbox', json=note)
self.assertEqual(400, resp.status_code) self.assertEqual(400, resp.status_code)
def test_inbox_no_matching_protocol(self, *_): def test_inbox_no_matching_protocol(self, mock_head, mock_get, mock_post):
# TODO: remove
mock_get.return_value = self.as2_resp(ACTOR)
resp = self.post('/foo.json/inbox', json=NOTE) resp = self.post('/foo.json/inbox', json=NOTE)
self.assertEqual(400, resp.status_code) self.assertEqual(400, resp.status_code)
@ -687,13 +689,12 @@ class ActivityPubTest(TestCase):
object_ids=['https://user.com/orig']) object_ids=['https://user.com/orig'])
def test_shared_inbox_repost_of_fediverse(self, mock_head, mock_get, mock_post): def test_shared_inbox_repost_of_fediverse(self, mock_head, mock_get, mock_post):
Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), to = self.make_user(ACTOR['id'], cls=ActivityPub)
from_=self.user) Follower.get_or_create(to=to, from_=self.user)
baz = self.make_user('fake:baz', cls=Fake, obj_id='fake:baz') baz = self.make_user('fake:baz', cls=Fake, obj_id='fake:baz')
Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), from_=baz) Follower.get_or_create(to=to, from_=baz)
baj = self.make_user('fake:baj', cls=Fake, obj_id='fake:baj') baj = self.make_user('fake:baj', cls=Fake, obj_id='fake:baj')
Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), Follower.get_or_create(to=to, from_=baj, status='inactive')
from_=baj, status='inactive')
mock_head.return_value = requests_response(url='http://target') mock_head.return_value = requests_response(url='http://target')
mock_get.side_effect = [ mock_get.side_effect = [
@ -752,7 +753,7 @@ class ActivityPubTest(TestCase):
object_ids=['http://nope.com/post']) object_ids=['http://nope.com/post'])
def test_inbox_not_public(self, mock_head, mock_get, mock_post): def test_inbox_not_public(self, mock_head, mock_get, mock_post):
Follower.get_or_create(to=ActivityPub.get_or_create(ACTOR['id']), Follower.get_or_create(to=self.make_user(ACTOR['id'], cls=ActivityPub),
from_=self.user) from_=self.user)
mock_head.return_value = requests_response(url='http://target') mock_head.return_value = requests_response(url='http://target')
@ -983,7 +984,7 @@ class ActivityPubTest(TestCase):
def test_inbox_undo_follow(self, mock_head, mock_get, mock_post): def test_inbox_undo_follow(self, mock_head, mock_get, mock_post):
follower = Follower(to=self.user.key, follower = Follower(to=self.user.key,
from_=ActivityPub.get_or_create(ACTOR['id']).key, from_=ActivityPub(id=ACTOR['id']).key,
status='active') status='active')
follower.put() follower.put()
@ -1000,9 +1001,10 @@ class ActivityPubTest(TestCase):
self.assertEqual('inactive', follower.key.get().status) self.assertEqual('inactive', follower.key.get().status)
def test_inbox_follow_inactive(self, mock_head, mock_get, mock_post): def test_inbox_follow_inactive(self, mock_head, mock_get, mock_post):
follower = Follower.get_or_create(to=self.user, follower = Follower.get_or_create(
from_=ActivityPub.get_or_create(ACTOR['id']), to=self.user,
status='inactive') from_=self.make_user(ACTOR['id'], cls=ActivityPub),
status='inactive')
mock_head.return_value = requests_response(url='https://user.com/') mock_head.return_value = requests_response(url='https://user.com/')
mock_get.side_effect = [ mock_get.side_effect = [
@ -1188,14 +1190,14 @@ class ActivityPubTest(TestCase):
mock_common_log.assert_any_call('Returning 401: No HTTP Signature', exc_info=None) mock_common_log.assert_any_call('Returning 401: No HTTP Signature', exc_info=None)
def test_delete_actor(self, *mocks): def test_delete_actor(self, *mocks):
follower = Follower.get_or_create( deleted = self.make_user(DELETE['actor'], cls=ActivityPub)
to=self.user, from_=ActivityPub.get_or_create(DELETE['actor'])) follower = Follower.get_or_create(to=self.user, from_=deleted)
followee = Follower.get_or_create( followee = Follower.get_or_create(to=deleted, from_=Fake(id='fake:user'))
to=ActivityPub.get_or_create(DELETE['actor']),
from_=Fake.get_or_create('snarfed.org'))
# other unrelated follower # other unrelated follower
other = Follower.get_or_create( other = self.make_user('https://mas.to/users/other', cls=ActivityPub)
to=self.user, from_=ActivityPub.get_or_create('https://mas.to/users/other')) other = Follower.get_or_create(to=self.user, from_=other)
self.assertEqual(3, Follower.query().count()) self.assertEqual(3, Follower.query().count())
got = self.post('/ap/sharedInbox', json=DELETE) got = self.post('/ap/sharedInbox', json=DELETE)

Wyświetl plik

@ -84,7 +84,7 @@ class ATProtoTest(TestCase):
@patch('requests.get', return_value=requests_response(DID_DOC)) @patch('requests.get', return_value=requests_response(DID_DOC))
def test_get_or_create(self, _): def test_get_or_create(self, _):
user = ATProto.get_or_create('did:plc:foo') user = self.make_user('did:plc:foo', cls=ATProto)
self.assertEqual('han.dull', user.key.get().handle) self.assertEqual('han.dull', user.key.get().handle)
def test_put_blocks_atproto_did(self): def test_put_blocks_atproto_did(self):

Wyświetl plik

@ -154,10 +154,10 @@ class FollowTest(TestCase):
def test_callback_address(self, mock_get, mock_post): def test_callback_address(self, mock_get, mock_post):
mock_get.side_effect = ( mock_get.side_effect = (
# oauth-dropins indieauth https://alice.com fetch for user json requests_response(''), # indieauth https://alice.com fetch for user json
requests_response(''),
WEBFINGER, WEBFINGER,
self.as2_resp(FOLLOWEE), self.as2_resp(FOLLOWEE),
self.as2_resp(FOLLOWEE),
) )
mock_post.side_effect = ( mock_post.side_effect = (
requests_response('me=https://alice.com'), requests_response('me=https://alice.com'),
@ -173,7 +173,8 @@ class FollowTest(TestCase):
def test_callback_url(self, mock_get, mock_post): def test_callback_url(self, mock_get, mock_post):
mock_get.side_effect = ( mock_get.side_effect = (
requests_response(''), requests_response(''), # indieauth https://alice.com fetch for user json
self.as2_resp(FOLLOWEE),
self.as2_resp(FOLLOWEE), self.as2_resp(FOLLOWEE),
) )
mock_post.side_effect = ( mock_post.side_effect = (
@ -234,6 +235,7 @@ class FollowTest(TestCase):
mock_get.side_effect = ( mock_get.side_effect = (
requests_response(''), requests_response(''),
self.as2_resp(followee), self.as2_resp(followee),
self.as2_resp(followee),
) )
mock_post.side_effect = ( mock_post.side_effect = (
requests_response('me=https://alice.com'), requests_response('me=https://alice.com'),
@ -304,6 +306,7 @@ class FollowTest(TestCase):
mock_get.side_effect = ( mock_get.side_effect = (
requests_response(''), requests_response(''),
self.as2_resp(FOLLOWEE), self.as2_resp(FOLLOWEE),
self.as2_resp(FOLLOWEE),
) )
mock_post.side_effect = ( mock_post.side_effect = (
requests_response('me=https://alice.com'), requests_response('me=https://alice.com'),
@ -349,6 +352,7 @@ class FollowTest(TestCase):
mock_get.side_effect = ( mock_get.side_effect = (
requests_response(''), requests_response(''),
self.as2_resp(followee), self.as2_resp(followee),
self.as2_resp(followee),
) )
mock_post.side_effect = ( mock_post.side_effect = (
requests_response('me=https://alice.com'), requests_response('me=https://alice.com'),
@ -371,6 +375,7 @@ class FollowTest(TestCase):
def test_indieauthed_session(self, mock_get, mock_post): def test_indieauthed_session(self, mock_get, mock_post):
mock_get.side_effect = ( mock_get.side_effect = (
self.as2_resp(FOLLOWEE), self.as2_resp(FOLLOWEE),
self.as2_resp(FOLLOWEE),
) )
mock_post.side_effect = ( mock_post.side_effect = (
requests_response('OK'), # AP Follow to inbox requests_response('OK'), # AP Follow to inbox

Wyświetl plik

@ -465,18 +465,24 @@ class WebTest(TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
Web(id=bad).put() Web(id=bad).put()
def test_get_or_create_lower_cases_domain(self, *_): def test_get_or_create_lower_cases_domain(self, mock_get, mock_post):
mock_get.return_value = requests_response('')
user = Web.get_or_create('AbC.oRg') user = Web.get_or_create('AbC.oRg')
self.assertEqual('abc.org', user.key.id()) self.assertEqual('abc.org', user.key.id())
self.assert_entities_equal(user, Web.get_by_id('abc.org')) self.assert_entities_equal(user, Web.get_by_id('abc.org'))
self.assertIsNone(Web.get_by_id('AbC.oRg')) self.assertIsNone(Web.get_by_id('AbC.oRg'))
def test_get_or_create_unicode_domain(self, *_): def test_get_or_create_unicode_domain(self, mock_get, mock_post):
mock_get.return_value = requests_response('')
user = Web.get_or_create('☃.net') user = Web.get_or_create('☃.net')
self.assertEqual('☃.net', user.key.id()) self.assertEqual('☃.net', user.key.id())
self.assert_entities_equal(user, Web.get_by_id('☃.net')) self.assert_entities_equal(user, Web.get_by_id('☃.net'))
def test_get_or_create_scripts_leading_trailing_dots(self, *_): def test_get_or_create_scripts_leading_trailing_dots(self, mock_get, mock_post):
mock_get.return_value = requests_response('')
user = Web.get_or_create('..foo.bar.') user = Web.get_or_create('..foo.bar.')
self.assertEqual('foo.bar', user.key.id()) self.assertEqual('foo.bar', user.key.id())
self.assert_entities_equal(user, Web.get_by_id('foo.bar')) self.assert_entities_equal(user, Web.get_by_id('foo.bar'))
@ -1861,8 +1867,7 @@ http://this/404s
redir = 'http://localhost/.well-known/webfinger?resource=acct:user.com@user.com' redir = 'http://localhost/.well-known/webfinger?resource=acct:user.com@user.com'
mock_get.side_effect = ( mock_get.side_effect = (
requests_response('', status=302, redirected_url=redir), requests_response('', status=302, redirected_url=redir),
requests_response(ACTOR_HTML, url='https://user.com/', ACTOR_HTML_RESP,
content_type=CONTENT_TYPE_HTML),
) )
got = self.post('/web-site', data={'url': 'https://user.com/'}) got = self.post('/web-site', data={'url': 'https://user.com/'})
@ -1877,6 +1882,7 @@ http://this/404s
mock_get.side_effect = ( mock_get.side_effect = (
requests_response(''), requests_response(''),
requests_response(''), requests_response(''),
requests_response(''),
) )
got = self.post('/web-site', data={'url': 'https://☃.net/'}) got = self.post('/web-site', data={'url': 'https://☃.net/'})
@ -1888,6 +1894,7 @@ http://this/404s
mock_get.side_effect = ( mock_get.side_effect = (
requests_response(''), requests_response(''),
requests_response(''), requests_response(''),
requests_response(''),
) )
got = self.post('/web-site', data={'url': 'https://AbC.oRg/'}) got = self.post('/web-site', data={'url': 'https://AbC.oRg/'})
@ -1918,9 +1925,10 @@ http://this/404s
get_flashed_messages()) get_flashed_messages())
self.assertEqual(1, Web.query().count()) self.assertEqual(1, Web.query().count())
def test_check_web_site_fetch_fails(self, mock_get, _): def test_check_webfinger_redirects_then_fails(self, mock_get, _):
redir = 'http://localhost/.well-known/webfinger?resource=acct:orig@orig' redir = 'http://localhost/.well-known/webfinger?resource=acct:orig@orig'
mock_get.side_effect = ( mock_get.side_effect = (
ACTOR_HTML_RESP,
requests_response('', status=302, redirected_url=redir), requests_response('', status=302, redirected_url=redir),
requests_response('', status=503), requests_response('', status=503),
) )
@ -1930,6 +1938,14 @@ http://this/404s
self.assertTrue(get_flashed_messages()[0].startswith( self.assertTrue(get_flashed_messages()[0].startswith(
"Couldn't connect to https://orig.co/: ")) "Couldn't connect to https://orig.co/: "))
def test_check_web_site_fetch_fails(self, mock_get, _):
mock_get.return_value = requests_response('', status=503)
got = self.post('/web-site', data={'url': 'https://orig.co/'})
self.assert_equals(200, got.status_code, got.headers)
self.assertTrue(get_flashed_messages()[0].startswith(
"Couldn't connect to https://orig.co/: "))
@patch('requests.post') @patch('requests.post')
@patch('requests.get') @patch('requests.get')

8
web.py
Wyświetl plik

@ -181,14 +181,14 @@ class Web(User, Protocol):
logger.info(f'Verifying {domain}') logger.info(f'Verifying {domain}')
if domain.startswith('www.') and domain not in WWW_DOMAINS: if domain.startswith('www.') and domain not in WWW_DOMAINS:
# if root domain redirects to www, use root domain instead # if root domain serves ok, use it instead
# https://github.com/snarfed/bridgy-fed/issues/314 # https://github.com/snarfed/bridgy-fed/issues/314
root = domain.removeprefix("www.") root = domain.removeprefix("www.")
root_site = f'https://{root}/' root_site = f'https://{root}/'
try: try:
resp = util.requests_get(root_site, gateway=False) resp = util.requests_get(root_site, gateway=False)
if resp.ok and self.is_web_url(resp.url): if resp.ok and self.is_web_url(resp.url):
logger.info(f'{root_site} redirects to {resp.url} ; using {root} instead') logger.info(f'{root_site} serves ok ; using {root} instead')
root_user = Web.get_or_create(root) root_user = Web.get_or_create(root)
self.use_instead = root_user.key self.use_instead = root_user.key
self.put() self.put()
@ -321,7 +321,7 @@ class Web(User, Protocol):
logger.info(f'Skipping sending {verb} (not supported in webmention/mf2) to {url}') logger.info(f'Skipping sending {verb} (not supported in webmention/mf2) to {url}')
return False return False
elif url not in as1.targets(obj.as1): elif url not in as1.targets(obj.as1):
logger.info(f'Skipping sending to {url} , not a target in the object') # logger.info(f'Skipping sending to {url} , not a target in the object')
return False return False
elif to_cls.is_blocklisted(url): elif to_cls.is_blocklisted(url):
logger.info(f'Skipping sending to blocklisted {url}') logger.info(f'Skipping sending to blocklisted {url}')
@ -478,8 +478,8 @@ def check_web_site():
flash(f'{url} is not a valid or supported web site') flash(f'{url} is not a valid or supported web site')
return render_template('enter_web_site.html'), 400 return render_template('enter_web_site.html'), 400
g.user = Web.get_or_create(domain, direct=True)
try: try:
g.user = Web.get_or_create(domain, direct=True)
g.user = g.user.verify() g.user = g.user.verify()
except BaseException as e: except BaseException as e:
code, body = util.interpret_http_exception(e) code, body = util.interpret_http_exception(e)