diff --git a/hbmqtt/broker/_broker.py b/hbmqtt/broker/_broker.py index aa820e4..1961f8f 100644 --- a/hbmqtt/broker/_broker.py +++ b/hbmqtt/broker/_broker.py @@ -30,6 +30,7 @@ class Broker: self._loop_thread = None self._message_handlers = None self._codecs = None + self._sessions = dict() def _init_states(self): self.machine = Machine(states=Broker.states, initial='new') @@ -100,6 +101,32 @@ class Broker: self.logger.debug("Broker started, ready to serve") loop.run_forever() + def discard_session(self, client_id): + if client_id in self._sessions: + del self._sessions[client_id] + else: + self.logger.warn("Asked to discard an unknown client ID session") + + def create_session(self, remote_address, remote_port, client_id, clean_session): + session = Session(remote_address, remote_port, client_id, clean_session) + if client_id in self._sessions: + raise BrokerException("Session already exists for client ID: %s", client_id) + self.save_session(session) + return session + + def get_session(self, client_id): + if client_id not in self._sessions: + raise BrokerException("Unknown session for client ID: %s", client_id) + else: + return self._sessions[client_id] + + def resume_session(self, session: Session): + # TBD + pass + + def save_session(self, session: Session): + self._sessions[session.client_id] = session + @asyncio.coroutine def _handle_message(self, message: MQTTMessage) -> MQTTMessage: handler = self._message_handlers[message.mqtt_header.message_type] @@ -130,7 +157,8 @@ class Broker: request = yield from self._decode_message(header.message_type, reader) (remote_address, remote_port) = writer.get_extra_info('peername') - session = Session(remote_address, remote_port, request.client_id) + request.remote_address = remote_address + request.remote_address = remote_port response = self._handle_message(request) encoded_response = yield from self._encode_message(response) diff --git a/hbmqtt/broker/handlers/connect.py b/hbmqtt/broker/handlers/connect.py index e5cf57e..cd85dea 100644 --- a/hbmqtt/broker/handlers/connect.py +++ b/hbmqtt/broker/handlers/connect.py @@ -3,13 +3,69 @@ # See the file license.txt for copying permission. import asyncio -from hbmqtt.message import MQTTMessage, ConnectMessage +from hbmqtt.message import MQTTMessage, ConnectMessage, ConnackMessage +from hbmqtt.broker.session import Session, ClientState +from hbmqtt.errors import BrokerException class ConnectHandler: def __init__(self, broker): - self._broker = broker + self.broker = broker @asyncio.coroutine - def handle(message: ConnectMessage) -> MQTTMessage: - # TBD - pass \ No newline at end of file + def handle(self, message: ConnectMessage) -> MQTTMessage: + session = None + response = None + + # Check Protocol + # protocol level (only MQTT 3.1.1 supported) + if message.proto_level != 4: + return ConnackMessage(False, ConnackMessage.ReturnCode.UNACCEPTABLE_PROTOCOL_VERSION) + + # No client ID provided + if message.client_id is None or message.client_id == "": + if message.is_clean_session(): + # [MQTT-3.1.3-6] and [MQTT-3.1.3-7] + message.client_id = self.gen_client_id() + else: + # [MQTT-3.1.3-8] : Identifier rejected + return ConnackMessage(False, ConnackMessage.ReturnCode.IDENTIFIER_REJECTED) + + if message.is_clean_session(): + try: + self.broker.discard_session(message.client_id) + except BrokerException: + pass + session = self.broker.create_session(message._remote_address, message._remote_port, message.client_id, message.is_clean_session()) + # [MQTT-3.2.2-1] + response = ConnackMessage(False, 0) + else: + try: + session = self.broker.get_session(message.client_id) + if session.client_state == ClientState.CONNECTED: + # [MQTT-3.1.4-2] + # TODO : Add logging + return ConnackMessage(False, ConnackMessage.ReturnCode.IDENTIFIER_REJECTED) + else: + # [MQTT-3.2.2-2] + response = ConnackMessage(True, 0) + except BrokerException: + session = self.broker.create_session(message._remote_address, message._remote_port, message.client_id, message.is_clean_session()) + response = ConnackMessage(False, 0) + if session.client_state == ClientState.DISCONNECTED: + self.broker.resume_session(session) + session.client_state = ClientState.CONNECTED + + if message.is_will_flag(): + session.will_flag = True + session.will_message = message.will_message + session.will_qos = message.will_qos() + session.will_retain = message.is_will_retain() + + session.keep_alive = message.keep_alive + + self.broker.save_session(session) + return response + + def gen_client_id(self): + import uuid + return uuid.uuid4() \ No newline at end of file diff --git a/hbmqtt/broker/session.py b/hbmqtt/broker/session.py index 2ac1318..e296ede 100644 --- a/hbmqtt/broker/session.py +++ b/hbmqtt/broker/session.py @@ -1,9 +1,21 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. +import enum + +class ClientState(enum): + CONNECTED = 1 + DISCONNECTED = 2 class Session: - def __init__(self, remote_address, remote_port, client_id): + def __init__(self, remote_address, remote_port, client_id, clean_session): self.remote_address = remote_address self.remote_port = remote_port self.client_id = client_id + self.clean_session = clean_session + self.client_state = ClientState.CONNECTED + self.will_flag = False + self.will_message = None + self.will_qos = None + self.will_retain = None + self.keep_alive = 0 diff --git a/hbmqtt/codecs/connect.py b/hbmqtt/codecs/connect.py index 584a286..80bbaad 100644 --- a/hbmqtt/codecs/connect.py +++ b/hbmqtt/codecs/connect.py @@ -36,9 +36,6 @@ class ConnectCodec: # protocol level (only MQTT 3.1.1 supported) protocol_level_byte = yield from read_or_raise(reader, 1) protocol_level = bytes_to_int(protocol_level_byte) - if protocol_level != 4: - raise ConnectException( - '[MQTT-3.1.2-2] Unsupported protocol level %s' % bytes_to_hex_str(protocol_level_byte)) # flags flags_byte = yield from read_or_raise(reader, 1) diff --git a/hbmqtt/message.py b/hbmqtt/message.py index be9946f..bb6c9ad 100644 --- a/hbmqtt/message.py +++ b/hbmqtt/message.py @@ -26,7 +26,7 @@ def get_message_type(byte): return MessageType(byte) class MQTTHeader: - def __init__(self, msg_type, flags, length): + def __init__(self, msg_type, flags=0, length=0): if isinstance(msg_type, int): enum_type = msg_type else: @@ -36,7 +36,7 @@ class MQTTHeader: self.flags = flags class MQTTMessage: - def __init__(self, header: MQTTHeader): + def __init__(self, header): # MQTT header self.mqtt_header = header @@ -79,3 +79,19 @@ class ConnectMessage(MQTTMessage): def will_qos(self): return (self.flags & 0x18) >> 3 + + +class ConnackMessage(MQTTMessage): + def __init__(self, session_parent, return_code): + header = MQTTHeader(MessageType.CONNACK) + super().__init__(header) + self.session_parent = session_parent + self.return_code = return_code + + class ReturnCode(enum): + CONNECTION_ACCEPTED = 0 + UNACCEPTABLE_PROTOCOL_VERSION = 1 + IDENTIFIER_REJECTED = 2 + SERVER_UNAVAILABLE = 3 + BAD_USERNAME_PASSWORD = 4 + NOT_AUTHORIZED = 5 \ No newline at end of file