From a1b3afba7f204da86df2022f82ca5c9376b3728e Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Sun, 28 Oct 2018 21:54:23 +0200 Subject: [PATCH] Move handle_send protocol identification to federation module root Add also "identify by id" function. --- federation/__init__.py | 29 +++++++++++++++++++++++ federation/inbound.py | 28 ++++++---------------- federation/protocols/diaspora/protocol.py | 9 ++++++- federation/tests/test_inbound.py | 2 +- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/federation/__init__.py b/federation/__init__.py index 6209776..9ca8bbf 100644 --- a/federation/__init__.py +++ b/federation/__init__.py @@ -1 +1,30 @@ +import importlib + +from federation.exceptions import NoSuitableProtocolFoundError + __version__ = "0.18.0-dev" + +PROTOCOLS = ( + "activitypub", + "diaspora", +) + + +def identify_protocol(method: str, value: str): + """ + Loop through protocols, import the protocol module and try to identify the id or payload. + """ + for protocol_name in PROTOCOLS: + protocol = importlib.import_module(f"federation.protocols.{protocol_name}.protocol") + if getattr(protocol, f"identify_{method}")(value): + return protocol + else: + raise NoSuitableProtocolFoundError() + + +def identify_protocol_by_id(id: str): + return identify_protocol('id', id) + + +def identify_protocol_by_payload(payload: str): + return identify_protocol('payload', payload) diff --git a/federation/inbound.py b/federation/inbound.py index ee1c064..8ed26c1 100644 --- a/federation/inbound.py +++ b/federation/inbound.py @@ -2,16 +2,11 @@ import importlib import logging from typing import Tuple, List, Callable -from federation.exceptions import NoSuitableProtocolFoundError +from federation import identify_protocol_by_payload from federation.types import UserType logger = logging.getLogger("federation") -PROTOCOLS = ( - "activitypub", - "diaspora", -) - def handle_receive( payload: str, @@ -36,24 +31,15 @@ def handle_receive( :arg sender_key_fetcher: Function that accepts sender handle and returns public key (optional) :arg skip_author_verification: Don't verify sender (test purposes, false default) :returns: Tuple of sender id, protocol name and list of entity objects - :raises NoSuitableProtocolFound: When no protocol was identified to pass message to """ logger.debug("handle_receive: processing payload: %s", payload) - found_protocol = None - for protocol_name in PROTOCOLS: - protocol = importlib.import_module("federation.protocols.%s.protocol" % protocol_name) - if protocol.identify_payload(payload): - found_protocol = protocol - break + found_protocol = identify_protocol_by_payload(payload) - if found_protocol: - logger.debug("handle_receive: using protocol %s", found_protocol.PROTOCOL_NAME) - protocol = found_protocol.Protocol() - sender, message = protocol.receive( - payload, user, sender_key_fetcher, skip_author_verification=skip_author_verification) - logger.debug("handle_receive: sender %s, message %s", sender, message) - else: - raise NoSuitableProtocolFoundError() + logger.debug("handle_receive: using protocol %s", found_protocol.PROTOCOL_NAME) + protocol = found_protocol.Protocol() + sender, message = protocol.receive( + payload, user, sender_key_fetcher, skip_author_verification=skip_author_verification) + logger.debug("handle_receive: sender %s, message %s", sender, message) mappers = importlib.import_module("federation.entities.%s.mappers" % found_protocol.PROTOCOL_NAME) entities = mappers.message_to_objects(message, sender, sender_key_fetcher, user) diff --git a/federation/protocols/diaspora/protocol.py b/federation/protocols/diaspora/protocol.py index 54179d6..7483948 100644 --- a/federation/protocols/diaspora/protocol.py +++ b/federation/protocols/diaspora/protocol.py @@ -13,7 +13,7 @@ from federation.protocols.diaspora.encrypted import EncryptedPayload from federation.protocols.diaspora.magic_envelope import MagicEnvelope from federation.types import UserType from federation.utils.diaspora import fetch_public_key -from federation.utils.text import decode_if_bytes, encode_if_text +from federation.utils.text import decode_if_bytes, encode_if_text, validate_handle logger = logging.getLogger("federation") @@ -22,6 +22,13 @@ PROTOCOL_NS = "https://joindiaspora.com/protocol" MAGIC_ENV_TAG = "{http://salmon-protocol.org/ns/magic-env}env" +def identify_id(id: str) -> bool: + """ + Try to identify if this ID is a Diaspora ID. + """ + return validate_handle(id) + + def identify_payload(payload): """Try to identify whether this is a Diaspora payload. diff --git a/federation/tests/test_inbound.py b/federation/tests/test_inbound.py index 4da8f8b..47f61fa 100644 --- a/federation/tests/test_inbound.py +++ b/federation/tests/test_inbound.py @@ -8,7 +8,7 @@ from federation.protocols.diaspora.protocol import Protocol from federation.tests.fixtures.payloads import DIASPORA_PUBLIC_PAYLOAD -class TestHandleReceiveProtocolIdentification(): +class TestHandleReceiveProtocolIdentification: def test_handle_receive_routes_to_identified_protocol(self): payload = DIASPORA_PUBLIC_PAYLOAD with patch.object(