diff --git a/models.py b/models.py index 283468c..0e67032 100644 --- a/models.py +++ b/models.py @@ -230,6 +230,7 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): user = cls.get_by_id(id) if user: # override direct from False => True if set + # TODO: propagate more props into user? direct = kwargs.get('direct') if direct and not user.direct: logger.info(f'Setting {user.key} direct={direct}') diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 6be6a4a..45c1b1d 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -81,6 +81,11 @@ class ATProtoTest(TestCase): self.store_object(id='did:plc:foo', raw=DID_DOC) self.assertEqual('han.dull', ATProto(id='did:plc:foo').handle) + @patch('requests.get', return_value=requests_response(DID_DOC)) + def test_get_or_create(self, _): + user = ATProto.get_or_create('did:plc:foo') + self.assertEqual('han.dull', user.key.get().handle) + def test_put_blocks_atproto_did(self): with self.assertRaises(AssertionError): ATProto(id='did:plc:123', atproto_did='did:plc:456').put() diff --git a/tests/test_webfinger.py b/tests/test_webfinger.py index 7bf24e2..983bbcb 100644 --- a/tests/test_webfinger.py +++ b/tests/test_webfinger.py @@ -1,4 +1,3 @@ -# coding=utf-8 """Unit tests for webfinger.py.""" import copy from unittest.mock import patch @@ -14,6 +13,7 @@ from webfinger import fetch, fetch_actor_url from . import test_web + WEBFINGER = { 'subject': 'acct:user.com@user.com', 'aliases': [ @@ -189,7 +189,9 @@ class WebfingerTest(TestCase): self.assertEqual('application/jrd+json', got.headers['Content-Type']) self.assert_equals(WEBFINGER_FAKE_FA_BRID_GY, got.json) - def test_handle(self): + def test_handle_new_user(self): + self.assertIsNone(Fake.get_by_id('fake:user')) + got = self.client.get( '/.well-known/webfinger?resource=acct:fake:handle:user@fake.brid.gy', base_url='https://fed.brid.gy/', diff --git a/webfinger.py b/webfinger.py index fed7f63..b30e962 100644 --- a/webfinger.py +++ b/webfinger.py @@ -59,6 +59,7 @@ class Webfinger(flask_util.XrdOrJrd): if not cls: cls = Protocol.for_request(fed='web') + # is this a handle? if cls.owns_id(id) is False: logger.info(f'{id} is not a {cls.LABEL} id') handle = id