A little refactoring moving sender key fetcher from Protocol.__init__ to Protocol.receive

merge-requests/130/head
Jason Robinson 2015-07-03 13:20:09 +03:00
rodzic 001060a37f
commit 97c4a63c85
4 zmienionych plików z 26 dodań i 22 usunięć

Wyświetl plik

@ -8,8 +8,14 @@ PROTOCOLS = (
)
def handle_receive(payload, user=None):
"""Takes a payload and passes it to the correct protocol."""
def handle_receive(payload, user=None, sender_key_fetcher=None):
"""Takes a payload and passes it to the correct protocol.
Args:
payload (str) - Payload blob
user (optional, obj) - User that will be passed to `protocol.receive`
sender_key_fetcher (optional, func) - Function that accepts sender handle and returns public key
"""
protocol = None
for protocol_name in PROTOCOLS:
protocol = importlib.import_module("federation.protocols.%s.protocol" % protocol_name)
@ -18,6 +24,6 @@ def handle_receive(payload, user=None):
if protocol:
proto_obj = protocol.Protocol()
return proto_obj.receive(payload, user)
return proto_obj.receive(payload, user, sender_key_fetcher)
else:
raise NoSuitableProtocolFoundError()

Wyświetl plik

@ -2,13 +2,13 @@ import logging
def identify_payload(payload):
"""Each protocol module should define an 'identify_payload' method.
"""Each protocol module should define an `identify_payload` method.
Args:
payload (str) - Payload blob
payload (str) - Payload blob
Returns:
True or False - A boolean whether the payload matches this protocol.
True or False - A boolean whether the payload matches this protocol.
"""
raise NotImplementedError("Implement in protocol module")
@ -29,14 +29,15 @@ class BaseProtocol(object):
"""Send a payload."""
raise NotImplementedError("Implement in subclass")
def receive(self, payload, user=None, *args, **kwargs):
def receive(self, payload, user=None, sender_key_fetcher=None, *args, **kwargs):
"""Receive a payload.
Args:
payload (str) - Payload blob
user (object) - Optional target user entry
If given, MUST contain `key` attribute which corresponds to user
decrypted private key
payload (str) - Payload blob
user (optional, obj) - Target user object
If given, MUST contain `key` attribute which corresponds to user
decrypted private key
sender_key_fetcher (optional, func) - Function that accepts sender handle and returns public key
Returns tuple of:
str - Sender handle ie user@domain.tld

Wyświetl plik

@ -29,13 +29,10 @@ class Protocol(BaseProtocol):
protocol_ns = "https://joindiaspora.com/protocol"
user_agent = 'social-federation/diaspora/0.1'
def __init__(self, contact_key_fetcher=None, *args, **kwargs):
super(Protocol, self).__init__()
self.get_contact_key = contact_key_fetcher
def receive(self, payload, user=None, *args, **kwargs):
def receive(self, payload, user=None, sender_key_fetcher=None, *args, **kwargs):
"""Receive a payload."""
self.user = user
self.get_contact_key = sender_key_fetcher
xml = unquote_plus(payload)
xml = xml.lstrip().encode("utf-8")
self.doc = etree.fromstring(xml)

Wyświetl plik

@ -63,10 +63,10 @@ class TestDiasporaProtocol():
assert protocol.encrypted is True
def test_receive_unencrypted_returns_sender_and_content(self):
protocol = self.init_protocol(contact_key_fetcher=mock_get_contact_key)
protocol = self.init_protocol()
user = self.get_mock_user()
protocol.get_message_content = self.mock_get_message_content
sender, content = protocol.receive(UNENCRYPTED_DOCUMENT, user)
sender, content = protocol.receive(UNENCRYPTED_DOCUMENT, user, mock_get_contact_key)
assert sender == "bob@example.com"
assert content == "<content />"
@ -82,10 +82,10 @@ class TestDiasporaProtocol():
protocol.receive(ENCRYPTED_DOCUMENT, user)
def test_receive_raises_if_sender_key_cannot_be_found(self):
protocol = self.init_protocol(contact_key_fetcher=mock_not_found_get_contact_key)
protocol = self.init_protocol()
user = self.get_mock_user()
with pytest.raises(NoSenderKeyFoundError):
protocol.receive(UNENCRYPTED_DOCUMENT, user)
protocol.receive(UNENCRYPTED_DOCUMENT, user, mock_not_found_get_contact_key)
def test_get_message_content(self):
protocol = self.init_protocol()
@ -96,8 +96,8 @@ class TestDiasporaProtocol():
body = protocol.get_message_content()
assert body == urlsafe_b64decode("{data}".encode("ascii"))
def init_protocol(self, contact_key_fetcher=None):
return Protocol(contact_key_fetcher=contact_key_fetcher)
def init_protocol(self):
return Protocol()
def get_unencrypted_doc(self):
return etree.fromstring(UNENCRYPTED_DOCUMENT)