diff --git a/federation/controllers.py b/federation/controllers.py index 4397106..95c6d5b 100644 --- a/federation/controllers.py +++ b/federation/controllers.py @@ -8,8 +8,14 @@ PROTOCOLS = ( ) -def handle_receive(payload, user=None): - """Takes a payload and passes it to the correct protocol.""" +def handle_receive(payload, user=None, sender_key_fetcher=None): + """Takes a payload and passes it to the correct protocol. + + Args: + payload (str) - Payload blob + user (optional, obj) - User that will be passed to `protocol.receive` + sender_key_fetcher (optional, func) - Function that accepts sender handle and returns public key + """ protocol = None for protocol_name in PROTOCOLS: protocol = importlib.import_module("federation.protocols.%s.protocol" % protocol_name) @@ -18,6 +24,6 @@ def handle_receive(payload, user=None): if protocol: proto_obj = protocol.Protocol() - return proto_obj.receive(payload, user) + return proto_obj.receive(payload, user, sender_key_fetcher) else: raise NoSuitableProtocolFoundError() diff --git a/federation/protocols/base.py b/federation/protocols/base.py index a65fa8d..2c7cccb 100644 --- a/federation/protocols/base.py +++ b/federation/protocols/base.py @@ -2,13 +2,13 @@ import logging def identify_payload(payload): - """Each protocol module should define an 'identify_payload' method. + """Each protocol module should define an `identify_payload` method. Args: - payload (str) - Payload blob + payload (str) - Payload blob Returns: - True or False - A boolean whether the payload matches this protocol. + True or False - A boolean whether the payload matches this protocol. """ raise NotImplementedError("Implement in protocol module") @@ -29,14 +29,15 @@ class BaseProtocol(object): """Send a payload.""" raise NotImplementedError("Implement in subclass") - def receive(self, payload, user=None, *args, **kwargs): + def receive(self, payload, user=None, sender_key_fetcher=None, *args, **kwargs): """Receive a payload. Args: - payload (str) - Payload blob - user (object) - Optional target user entry - If given, MUST contain `key` attribute which corresponds to user - decrypted private key + payload (str) - Payload blob + user (optional, obj) - Target user object + If given, MUST contain `key` attribute which corresponds to user + decrypted private key + sender_key_fetcher (optional, func) - Function that accepts sender handle and returns public key Returns tuple of: str - Sender handle ie user@domain.tld diff --git a/federation/protocols/diaspora/protocol.py b/federation/protocols/diaspora/protocol.py index 8b2c156..ea7e805 100644 --- a/federation/protocols/diaspora/protocol.py +++ b/federation/protocols/diaspora/protocol.py @@ -29,13 +29,10 @@ class Protocol(BaseProtocol): protocol_ns = "https://joindiaspora.com/protocol" user_agent = 'social-federation/diaspora/0.1' - def __init__(self, contact_key_fetcher=None, *args, **kwargs): - super(Protocol, self).__init__() - self.get_contact_key = contact_key_fetcher - - def receive(self, payload, user=None, *args, **kwargs): + def receive(self, payload, user=None, sender_key_fetcher=None, *args, **kwargs): """Receive a payload.""" self.user = user + self.get_contact_key = sender_key_fetcher xml = unquote_plus(payload) xml = xml.lstrip().encode("utf-8") self.doc = etree.fromstring(xml) diff --git a/federation/tests/protocols/diaspora/test_diaspora.py b/federation/tests/protocols/diaspora/test_diaspora.py index 21450b0..dfbb7b1 100644 --- a/federation/tests/protocols/diaspora/test_diaspora.py +++ b/federation/tests/protocols/diaspora/test_diaspora.py @@ -63,10 +63,10 @@ class TestDiasporaProtocol(): assert protocol.encrypted is True def test_receive_unencrypted_returns_sender_and_content(self): - protocol = self.init_protocol(contact_key_fetcher=mock_get_contact_key) + protocol = self.init_protocol() user = self.get_mock_user() protocol.get_message_content = self.mock_get_message_content - sender, content = protocol.receive(UNENCRYPTED_DOCUMENT, user) + sender, content = protocol.receive(UNENCRYPTED_DOCUMENT, user, mock_get_contact_key) assert sender == "bob@example.com" assert content == "" @@ -82,10 +82,10 @@ class TestDiasporaProtocol(): protocol.receive(ENCRYPTED_DOCUMENT, user) def test_receive_raises_if_sender_key_cannot_be_found(self): - protocol = self.init_protocol(contact_key_fetcher=mock_not_found_get_contact_key) + protocol = self.init_protocol() user = self.get_mock_user() with pytest.raises(NoSenderKeyFoundError): - protocol.receive(UNENCRYPTED_DOCUMENT, user) + protocol.receive(UNENCRYPTED_DOCUMENT, user, mock_not_found_get_contact_key) def test_get_message_content(self): protocol = self.init_protocol() @@ -96,8 +96,8 @@ class TestDiasporaProtocol(): body = protocol.get_message_content() assert body == urlsafe_b64decode("{data}".encode("ascii")) - def init_protocol(self, contact_key_fetcher=None): - return Protocol(contact_key_fetcher=contact_key_fetcher) + def init_protocol(self): + return Protocol() def get_unencrypted_doc(self): return etree.fromstring(UNENCRYPTED_DOCUMENT)