diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index d51a620..b44dac4 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -14,7 +14,7 @@ from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from hbmqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT from hbmqtt.mqtt.connect import ConnectPacket from hbmqtt.mqtt.connack import * -from hbmqtt.errors import HBMQTTException +from hbmqtt.errors import HBMQTTException, MQTTException from hbmqtt.utils import format_client_message, gen_client_id from hbmqtt.mqtt.packet import PUBLISH from hbmqtt.codecs import int_to_bytes_str @@ -27,6 +27,10 @@ from hbmqtt.adapters import ( WebSocketsWriter) from .plugins.manager import PluginManager, BaseContext +import sys +if sys.version_info < (3, 5): + from asyncio import async as ensure_future + _defaults = { 'timeout-disconnect-delay': 2, @@ -51,6 +55,8 @@ EVENT_BROKER_PRE_START = 'broker_pre_start' EVENT_BROKER_POST_START = 'broker_post_start' EVENT_BROKER_PRE_SHUTDOWN = 'broker_pre_shutdown' EVENT_BROKER_POST_SHUTDOWN = 'broker_post_shutdown' +EVENT_BROKER_CLIENT_CONNECTED = 'broker_client_connected' +EVENT_BROKER_CLIENT_DISCONNECTED = 'broker_client_disconnected' class BrokerException(BaseException): @@ -407,131 +413,74 @@ class Broker: @asyncio.coroutine def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter): - # Wait for connection available - server = self._servers[listener_name] + # Wait for connection available on listener + server = self._servers.get(listener_name, None) + if not server: + raise BrokerException("Invalid listener name '%s'" % listener_name) yield from server.acquire_connection() remote_address, remote_port = writer.get_peer_info() self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name)) # Wait for first packet and expect a CONNECT - connect = None try: - connect = yield from ConnectPacket.from_stream(reader) - yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect) - self.check_connect(connect) + handler, client_session = yield from BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager) except HBMQTTException as exc: self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" % (format_client_message(address=remote_address, port=remote_port), exc)) yield from writer.close() self.logger.debug("Connection closed") return - except BrokerException as be: + except MQTTException as me: self.logger.error('Invalid connection from %s : %s' % - (format_client_message(address=remote_address, port=remote_port), be)) - yield from writer.close() - self.logger.debug("Connection closed") - return - if connect.proto_name != "MQTT": - self.logger.warn('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.variable_header.protocol_name) + (format_client_message(address=remote_address, port=remote_port), me)) yield from writer.close() self.logger.debug("Connection closed") return - connack = None - if connect.proto_level != 4: - # only MQTT 3.1.1 supported - self.logger.error('Invalid protocol from %s: %d' % - (format_client_message(address=remote_address, port=remote_port), - connect.variable_header.protocol_level)) - connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION) # [MQTT-3.2.2-4] session_parent=0 - elif connect.username_flag and connect.username is None: - self.logger.error('Invalid username from %s' % - (format_client_message(address=remote_address, port=remote_port))) - connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0 - elif connect.password_flag and connect.password is None: - self.logger.error('Invalid password %s' % (format_client_message(address=remote_address, port=remote_port))) - connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0 - elif connect.clean_session_flag is False and connect.payload.client_id is None: - self.logger.error('[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' % - format_client_message(address=remote_address, port=remote_port)) - connack = ConnackPacket.build(0, IDENTIFIER_REJECTED) - if connack is not None: - yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack) - yield from connack.to_stream(writer) - yield from writer.close() - return - - client_session = None - self.logger.debug("Clean session={0}".format(connect.clean_session_flag)) - self.logger.debug("known sessions={0}".format(self._sessions)) - client_id = connect.client_id - if connect.clean_session_flag: + if client_session.clean_session: # Delete existing session and create a new one - if client_id is not None: - self.delete_session(client_id) + if client_session.client_id is not None: + self.delete_session(client_session.client_id) else: - client_id = gen_client_id() - client_session = Session() + client_session.client_id = gen_client_id() client_session.parent = 0 - client_session.client_id = client_id else: # Get session from cache - if client_id in self._sessions: - self.logger.debug("Found old session %s" % repr(self._sessions[client_id])) - client_session = self._sessions[client_id] + if client_session.client_id in self._sessions: + self.logger.debug("Found old session %s" % repr(self._sessions[client_session.client_id])) + (client_session,) = self._sessions[client_session.client_id] client_session.parent = 1 else: - client_session = Session() - client_session.client_id = client_id client_session.parent = 0 - - client_session.remote_address = remote_address - client_session.remote_port = remote_port - client_session.clean_session = connect.clean_session_flag - client_session.will_flag = connect.will_flag - client_session.will_retain = connect.will_retain_flag - client_session.will_qos = connect.will_qos - client_session.will_topic = connect.will_topic - client_session.will_message = connect.will_message - client_session.username = connect.username - client_session.password = connect.password - if connect.keep_alive > 0: - client_session.keep_alive = connect.keep_alive + self.config['timeout-disconnect-delay'] - else: - client_session.keep_alive = 0 + if client_session.keep_alive > 0: + client_session.keep_alive += self.config['timeout-disconnect-delay'] + self.logger.debug("Keep-alive timeout=%d" % client_session.keep_alive) client_session.publish_retry_delay = self.config['publish-retry-delay'] + handler.attach(client_session, reader, writer) + self._sessions[client_session.client_id] = (client_session, handler) + authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name]) - if authenticated: - connack = ConnackPacket.build(client_session.parent, CONNECTION_ACCEPTED) - self.logger.info('%s : connection accepted' % format_client_message(session=client_session)) - yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack, session=client_session) - yield from connack.to_stream(writer) - else: - connack = ConnackPacket.build(client_session.parent, NOT_AUTHORIZED) - self.logger.info('%s : connection refused' % format_client_message(session=client_session)) - yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack, session=client_session) - yield from connack.to_stream(writer) + yield from handler.mqtt_connack_authorize(authenticated) + if not authenticated: yield from writer.close() return client_session.transitions.connect() - handler = self._init_handler(client_session, reader, writer) - self._sessions[client_id] = (client_session, handler) + yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, session=client_session) self.logger.debug("%s Start messages handling" % client_session.client_id) yield from handler.start() self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize()) yield from self.publish_session_retained_messages(client_session) - self.logger.debug("%s Wait for disconnect" % client_session.client_id) # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect) connected = True - disconnect_waiter = asyncio.Task(handler.wait_disconnect(), loop=self._loop) - subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop) - unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop) - wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop) + disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop) + subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription(), loop=self._loop) + unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop) + wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop) while connected: done, pending = yield from asyncio.wait( [disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver], @@ -586,6 +535,7 @@ class Broker: # Acknowledge message delivery yield from handler.mqtt_acknowledge_delivery(packet_id) wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop) + disconnect_waiter.cancel() subscribe_waiter.cancel() unsubscribe_waiter.cancel() wait_deliver.cancel() @@ -593,6 +543,7 @@ class Broker: self.logger.debug("%s Client disconnecting" % client_session.client_id) yield from self._stop_handler(handler) client_session.transitions.disconnect() + yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, session=client_session) yield from writer.close() self.logger.debug("%s Session disconnected" % client_session.client_id) server.release_connection() @@ -602,8 +553,8 @@ class Broker: Create a BrokerProtocolHandler and attach to a session :return: """ - handler = BrokerProtocolHandler(session, self.plugins_manager, self._loop) - handler.attach_stream(reader, writer) + handler = BrokerProtocolHandler(self.plugins_manager, self._loop) + handler.attach(session, reader, writer) handler.on_packet_received.connect(self.sys_handle_packet_received) handler.on_packet_sent.connect(self.sys_handle_packet_sent) return handler @@ -620,17 +571,6 @@ class Broker: except Exception as e: self.logger.error(e) - def check_connect(self, connect: ConnectPacket): - if connect.payload.client_id is None: - raise BrokerException('[[MQTT-3.1.3-3]] : Client identifier must be present' ) - - if connect.variable_header.will_flag: - if connect.payload.will_topic is None or connect.payload.will_message is None: - raise BrokerException('will flag set, but will topic/message not present in payload') - - if connect.variable_header.reserved_flag: - raise BrokerException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0') - @asyncio.coroutine def authenticate(self, session: Session, listener): """ @@ -653,13 +593,14 @@ class Broker: session=session, filter_plugins=auth_plugins) auth_result = True - for plugin in returns: - res = returns[plugin] - if res is False: - auth_result = False - self.logger.debug("Authentication failed due to '%s' plugin result: %s" % (plugin.name, res)) - else: - self.logger.debug("'%s' plugin result: %s" % (plugin.name, res)) + if returns: + for plugin in returns: + res = returns[plugin] + if res is False: + auth_result = False + self.logger.debug("Authentication failed due to '%s' plugin result: %s" % (plugin.name, res)) + else: + self.logger.debug("'%s' plugin result: %s" % (plugin.name, res)) # If all plugins returned True, authentication is success return auth_result diff --git a/hbmqtt/client.py b/hbmqtt/client.py index 83a86d2..cde31dd 100644 --- a/hbmqtt/client.py +++ b/hbmqtt/client.py @@ -19,13 +19,16 @@ from hbmqtt.mqtt.constants import * import websockets from websockets.uri import InvalidURI from websockets.handshake import InvalidHandshake +from collections import deque _defaults = { 'keep_alive': 10, 'ping_delay': 1, 'default_qos': 0, 'default_retain': False, - 'auto_reconnect': True + 'auto_reconnect': True, + 'reconnect_max_interval': 10, + 'reconnect_retries': 2, } @@ -114,6 +117,7 @@ class MQTTClient: context = ClientContext() context.config = self.config self.plugins_manager = PluginManager('hbmqtt.client.plugins', context) + self.client_tasks = deque() @asyncio.coroutine @@ -133,7 +137,15 @@ class MQTTClient: self.session = self._initsession(uri, cleansession, cafile, capath, cadata) self.logger.debug("Connect to: %s" % uri) - return (yield from self._do_connect()) + try: + return (yield from self._do_connect()) + except BaseException as be: + self.logger.warning("Connection failed: %r" % be) + auto_reconnect = self.config.get('auto_reconnect', False) + if not auto_reconnect: + raise + else: + return (yield from self.reconnect()) @asyncio.coroutine @mqtt_connected @@ -157,12 +169,30 @@ class MQTTClient: if cleansession: self.session.clean_session = cleansession self.logger.debug("Reconnecting with session parameters: %s" % self.session) - return (yield from self._do_connect()) + reconnect_max_interval = self.config.get('reconnect_max_interval', 10) + reconnect_retries = self.config.get('reconnect_retries', 5) + nb_attempt = 1 + yield from asyncio.sleep(1, loop=self._loop) + while True: + try: + self.logger.debug("Reconnect attempt %d ..." % nb_attempt) + return (yield from self._do_connect()) + except BaseException as e: + self.logger.warning("Reconnection attempt failed: %r" % e) + if nb_attempt > reconnect_retries: + self.logger.error("Maximum number of connection attempts reached. Reconnection aborted") + raise ConnectException("Too many connection attempts failed") + exp = 2 ** nb_attempt + delay = exp if exp < reconnect_max_interval else reconnect_max_interval + self.logger.debug("Waiting %d second before next attempt" % delay) + yield from asyncio.sleep(delay, loop=self._loop) + nb_attempt += 1 + @asyncio.coroutine def _do_connect(self): return_code = yield from self._connect_coro() - self._disconnect_task = asyncio.Task(self.handle_connection_close(), loop=self._loop) + self._disconnect_task = asyncio.ensure_future(self.handle_connection_close(), loop=self._loop) return return_code @asyncio.coroutine @@ -213,6 +243,17 @@ class MQTTClient: def unsubscribe(self, topics): yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id) + @asyncio.coroutine + def deliver_message(self, timeout=None): + deliver_task = asyncio.ensure_future(self._handler.mqtt_deliver_next_message(), loop=self._loop) + self.client_tasks.append(deliver_task) + self.logger.debug("Waiting message delivery") + message = yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout) + if deliver_task.exception(): + raise deliver_task.exception() + self.client_tasks.pop() + return message + @asyncio.coroutine def _connect_coro(self): kwargs = dict() @@ -235,8 +276,8 @@ class MQTTClient: uri_attributes[3], uri_attributes[4], uri_attributes[5]) self.session.broker_uri = urlunparse(uri) # Init protocol handler - if not self._handler: - self._handler = ClientProtocolHandler(self.session, self.plugins_manager, loop=self._loop) + #if not self._handler: + self._handler = ClientProtocolHandler(self.plugins_manager, loop=self._loop) if secure: if self.session.cafile is None or self.session.cafile == '': @@ -252,6 +293,8 @@ class MQTTClient: kwargs['ssl'] = sc try: + reader = None + writer = None self._connected_state.clear() # Open connection if scheme in ('mqtt', 'mqtts'): @@ -270,7 +313,7 @@ class MQTTClient: reader = WebSocketsReader(websocket) writer = WebSocketsWriter(websocket) # Start MQTT protocol - self._handler.attach_stream(reader, writer) + self._handler.attach(self.session, reader, writer) return_code = yield from self._handler.mqtt_connect() if return_code is not CONNECTION_ACCEPTED: self.session.transitions.disconnect() @@ -293,23 +336,39 @@ class MQTTClient: self.logger.warn("connection failed: invalid websocket handshake") self.session.transitions.disconnect() raise ConnectException("connection failed: invalid websocket handshake", ihs) - except ProtocolHandlerException as e: - self.logger.warn("MQTT connection failed: %s" % e) + except (ProtocolHandlerException, ConnectionError, OSError) as e: + self.logger.warn("MQTT connection failed: %r" % e) self.session.transitions.disconnect() - raise ClientException("connection Failed: %s" % e) + raise ConnectException(e) @asyncio.coroutine def handle_connection_close(self): self.logger.debug("Watch broker disconnection") + # Wait for disconnection from broker (like connection lost) yield from self._handler.wait_disconnect() - self._connected_state.clear() self.logger.warning("Disconnected from broker") + + # Block client API + self._connected_state.clear() + + # stop an clean handler yield from self._handler.stop() - self._handler.detach_stream() + self._handler.detach() self.session.transitions.disconnect() + if self.config.get('auto_reconnect', False): + # Try reconnection self.logger.debug("Auto-reconnecting") - yield from self.reconnect() + try: + yield from self.reconnect() + except ConnectException: + # Cancel client pending tasks + while self.client_tasks: + self.client_tasks.popleft().set_exception(ClientException("Connection lost")) + else: + # Cancel client pending tasks + while self.client_tasks: + self.client_tasks.popleft().set_exception(ClientException("Connection lost")) def _initsession( self, diff --git a/hbmqtt/mqtt/protocol/broker_handler.py b/hbmqtt/mqtt/protocol/broker_handler.py index 6a312c6..2b5f584 100644 --- a/hbmqtt/mqtt/protocol/broker_handler.py +++ b/hbmqtt/mqtt/protocol/broker_handler.py @@ -2,9 +2,11 @@ # # See the file license.txt for copying permission. import asyncio +import logging from asyncio import futures from hbmqtt.mqtt.protocol.handler import ProtocolHandler from hbmqtt.mqtt.connect import ConnectPacket +from hbmqtt.mqtt.connack import * from hbmqtt.mqtt.pingreq import PingReqPacket from hbmqtt.mqtt.pingresp import PingRespPacket from hbmqtt.mqtt.subscribe import SubscribePacket @@ -14,11 +16,14 @@ from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.utils import format_client_message from hbmqtt.session import Session from hbmqtt.plugins.manager import PluginManager +from hbmqtt.adapters import ReaderAdapter, WriterAdapter +from hbmqtt.errors import MQTTException +from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT class BrokerProtocolHandler(ProtocolHandler): - def __init__(self, session: Session, plugins_manager: PluginManager, loop=None): - super().__init__(session, plugins_manager, loop) + def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None): + super().__init__(plugins_manager, session, loop) self._disconnect_waiter = None self._pending_subscriptions = asyncio.Queue(loop=self._loop) self._pending_unsubscriptions = asyncio.Queue(loop=self._loop) @@ -97,3 +102,81 @@ class BrokerProtocolHandler(ProtocolHandler): def mqtt_acknowledge_unsubscription(self, packet_id): unsuback = UnsubackPacket.build(packet_id) yield from self._send_packet(unsuback) + + @asyncio.coroutine + def mqtt_connack_authorize(self, authorize: bool): + if authorize: + connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED) + else: + connack = ConnackPacket.build(self.session.parent, NOT_AUTHORIZED) + yield from self._send_packet(connack) + + @classmethod + @asyncio.coroutine + def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None): + """ + + :param reader: + :param writer: + :param plugins_manager: + :param loop: + :return: + """ + log = logging.getLogger(__name__) + remote_address, remote_port = writer.get_peer_info() + connect = yield from ConnectPacket.from_stream(reader) + yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect) + if connect.payload.client_id is None: + raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present' ) + + if connect.variable_header.will_flag: + if connect.payload.will_topic is None or connect.payload.will_message is None: + raise MQTTException('will flag set, but will topic/message not present in payload') + + if connect.variable_header.reserved_flag: + raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0') + if connect.proto_name != "MQTT": + raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.variable_header.protocol_name) + + connack = None + error_msg = None + if connect.proto_level != 4: + # only MQTT 3.1.1 supported + error_msg = 'Invalid protocol from %s: %d' % \ + (format_client_message(address=remote_address, port=remote_port), + connect.variable_header.protocol_level) + connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION) # [MQTT-3.2.2-4] session_parent=0 + elif connect.username_flag and connect.username is None: + error_msg = 'Invalid username from %s' % \ + (format_client_message(address=remote_address, port=remote_port)) + connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0 + elif connect.password_flag and connect.password is None: + error_msg = 'Invalid password %s' % (format_client_message(address=remote_address, port=remote_port)) + connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0 + elif connect.clean_session_flag is False and connect.payload.client_id is None: + error_msg = '[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' % \ + format_client_message(address=remote_address, port=remote_port) + connack = ConnackPacket.build(0, IDENTIFIER_REJECTED) + if connack is not None: + yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack) + yield from connack.to_stream(writer) + yield from writer.close() + raise MQTTException(error_msg) + + incoming_session = Session() + incoming_session.client_id = connect.client_id + incoming_session.clean_session = connect.clean_session_flag + incoming_session.will_flag = connect.will_flag + incoming_session.will_retain = connect.will_retain_flag + incoming_session.will_qos = connect.will_qos + incoming_session.will_topic = connect.will_topic + incoming_session.will_message = connect.will_message + incoming_session.username = connect.username + incoming_session.password = connect.password + if connect.keep_alive > 0: + incoming_session.keep_alive = connect.keep_alive + else: + incoming_session.keep_alive = 0 + + handler = cls(plugins_manager, loop) + return handler, incoming_session diff --git a/hbmqtt/mqtt/protocol/client_handler.py b/hbmqtt/mqtt/protocol/client_handler.py index 5684ad4..2202557 100644 --- a/hbmqtt/mqtt/protocol/client_handler.py +++ b/hbmqtt/mqtt/protocol/client_handler.py @@ -2,7 +2,7 @@ # # See the file license.txt for copying permission. from asyncio import futures -from hbmqtt.mqtt.protocol.handler import ProtocolHandler, ProtocolHandlerException +from hbmqtt.mqtt.protocol.handler import ProtocolHandler, EVENT_MQTT_PACKET_RECEIVED from hbmqtt.mqtt.packet import * from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.mqtt.pingreq import PingReqPacket @@ -18,8 +18,8 @@ from hbmqtt.plugins.manager import PluginManager class ClientProtocolHandler(ProtocolHandler): - def __init__(self, session: Session, plugins_manager: PluginManager, loop=None): - super().__init__(session, plugins_manager, loop=loop) + def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None): + super().__init__(plugins_manager, session, loop=loop) self._ping_task = None self._pingresp_queue = asyncio.Queue(loop=self._loop) self._subscriptions_waiter = dict() @@ -38,11 +38,15 @@ class ClientProtocolHandler(ProtocolHandler): yield from super().stop() if self._ping_task: try: + self.logger.debug("Cancel ping task") self._ping_task.cancel() - except Exception: + except BaseException: pass if self._pingresp_waiter: self._pingresp_waiter.cancel() + if not self._disconnect_waiter.done(): + self._disconnect_waiter.cancel() + self._disconnect_waiter = None def _build_connect_packet(self): vh = ConnectVariableHeader() @@ -80,10 +84,16 @@ class ClientProtocolHandler(ProtocolHandler): connect_packet = self._build_connect_packet() yield from self._send_packet(connect_packet) connack = yield from ConnackPacket.from_stream(self.reader) + yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session) return connack.return_code def handle_write_timeout(self): - self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping()) + try: + self.logger.debug("Scheduling Ping") + if not self._ping_task: + self._ping_task = asyncio.ensure_future(self.mqtt_ping()) + except BaseException as be: + self.logger.debug("Exception ignored in ping task: %r" % be) def handle_read_timeout(self): pass @@ -143,7 +153,6 @@ class ClientProtocolHandler(ProtocolHandler): def mqtt_disconnect(self): disconnect_packet = DisconnectPacket() yield from self._send_packet(disconnect_packet) - self._connack_waiter = None @asyncio.coroutine def mqtt_ping(self): diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 2f49e8a..fd3c218 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -50,18 +50,16 @@ class ProtocolHandler: on_packet_sent = Signal() on_packet_received = Signal() - def __init__(self, session: Session, plugins_manager: PluginManager, loop=None): - log = logging.getLogger(__name__) - self.logger = logging.LoggerAdapter(log, {'client_id': session.client_id}) - self.session = session + def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None): + self.logger = logging.getLogger(__name__) + if session: + self._init_session(session) + else: + self.session = None self.reader = None self.writer = None self.plugins_manager = plugins_manager - self.keepalive_timeout = self.session.keep_alive - if self.keepalive_timeout <= 0: - self.keepalive_timeout = None - if loop is None: self._loop = asyncio.get_event_loop() else: @@ -76,18 +74,29 @@ class ProtocolHandler: self._pubrel_waiters = dict() self._pubcomp_waiters = dict() - def attach_stream(self, reader: ReaderAdapter, writer: WriterAdapter): - if self.reader or self.writer: - raise ProtocolHandlerException("Handler is already attached to an opened stream") + def _init_session(self, session: Session): + assert session + log = logging.getLogger(__name__) + self.session = session + self.logger = logging.LoggerAdapter(log, {'client_id': self.session.client_id}) + self.keepalive_timeout = self.session.keep_alive + if self.keepalive_timeout <= 0: + self.keepalive_timeout = None + + def attach(self, session, reader: ReaderAdapter, writer: WriterAdapter): + if self.session: + raise ProtocolHandlerException("Handler is already attached to a session") + self._init_session(session) self.reader = reader self.writer = writer - def detach_stream(self): + def detach(self): + self.session = None self.reader = None self.writer = None def _is_attached(self): - if self.reader and self.writer: + if self.session: return True else: return False @@ -109,13 +118,14 @@ class ProtocolHandler: @asyncio.coroutine def stop(self): # Stop messages flow waiter - self._reader_task.cancel() self._stop_waiters() if self._keepalive_task: self._keepalive_task.cancel() self.logger.debug("waiting for tasks to be stopped") - yield from asyncio.wait( - [self._reader_stopped.wait()], loop=self._loop) + if not self._reader_task.done(): + self._reader_task.cancel() + yield from asyncio.wait( + [self._reader_stopped.wait()], loop=self._loop) self.logger.debug("closing writer") try: yield from self.writer.close() @@ -392,8 +402,8 @@ class ProtocolHandler: self.handle_read_timeout() except NoDataException: self.logger.debug("%s No data available" % self.session.client_id) - except Exception as e: - self.logger.warning("%s Unhandled exception in reader coro: %s" % (self.session.client_id, e)) + except BaseException as e: + self.logger.warning("%s Unhandled exception in reader coro: %s" % (type(self).__name__, e)) break yield from self.handle_connection_closed() self._reader_stopped.set() @@ -412,7 +422,7 @@ class ProtocolHandler: except ConnectionResetError as cre: yield from self.handle_connection_closed() raise - except Exception as e: + except BaseException as e: self.logger.warning("Unhandled exception: %s" % e) raise diff --git a/hbmqtt/session.py b/hbmqtt/session.py index 88fcc11..86e6638 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -113,3 +113,6 @@ class Session: self.__dict__.update(state) self.retained_messages = Queue() self.delivered_message_queue = Queue() + + def __eq__(self, other): + return self.client_id == other.client_id \ No newline at end of file diff --git a/samples/client_subscribe.py b/samples/client_subscribe.py index c281955..5ffee9f 100644 --- a/samples/client_subscribe.py +++ b/samples/client_subscribe.py @@ -1,7 +1,7 @@ import logging import asyncio -from hbmqtt.client import MQTTClient +from hbmqtt.client import MQTTClient, ClientException from hbmqtt.mqtt.constants import QOS_1, QOS_2 @@ -17,22 +17,25 @@ C = MQTTClient() @asyncio.coroutine def uptime_coro(): - yield from C.connect('mqtt://test.mosquitto.org:1883/') + yield from C.connect('mqtt://localhost/') # Subscribe to '$SYS/broker/uptime' with QOS=1 yield from C.subscribe([ ('$SYS/broker/uptime', QOS_1), ('$SYS/broker/load/#', QOS_2), ]) logger.info("Subscribed") - for i in range(1, 100): - packet = yield from C.deliver_message() - print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data))) - yield from C.unsubscribe(['$SYS/broker/uptime']) - logger.info("UnSubscribed") - yield from C.disconnect() + try: + for i in range(1, 100): + packet = yield from C.deliver_message() + print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data))) + yield from C.unsubscribe(['$SYS/broker/uptime']) + logger.info("UnSubscribed") + yield from C.disconnect() + except ClientException as ce: + logger.error("Client exception: %s" % ce) if __name__ == '__main__': formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" - logging.basicConfig(level=logging.INFO, format=formatter) + logging.basicConfig(level=logging.DEBUG, format=formatter) asyncio.get_event_loop().run_until_complete(uptime_coro()) \ No newline at end of file diff --git a/tests/mqtt/protocol/test_handler.py b/tests/mqtt/protocol/test_handler.py index 3d6bf63..1ce6339 100644 --- a/tests/mqtt/protocol/test_handler.py +++ b/tests/mqtt/protocol/test_handler.py @@ -35,8 +35,8 @@ class ProtocolHandlerTest(unittest.TestCase): def test_init_handler(self): s = Session() - handler = ProtocolHandler(s, self.plugin_manager, loop=self.loop) - self.assertIs(handler.session, s) + handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.assertIsNone(handler.session) self.assertIs(handler._loop, self.loop) self.check_empty_waiters(handler) @@ -51,8 +51,8 @@ class ProtocolHandlerTest(unittest.TestCase): s = Session() reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888) reader_adapted, writer_adapted = adapt(reader, writer) - handler = ProtocolHandler(s, self.plugin_manager) - handler.attach_stream(reader_adapted, writer_adapted) + handler = ProtocolHandler(self.plugin_manager) + handler.attach(s, reader_adapted, writer_adapted) yield from self.start_handler(handler, s) yield from self.stop_handler(handler, s) future.set_result(True) @@ -79,15 +79,14 @@ class ProtocolHandlerTest(unittest.TestCase): except Exception as ae: future.set_exception(ae) - @asyncio.coroutine def test_coro(): try: s = Session() reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - handler = ProtocolHandler(s, self.plugin_manager, loop=self.loop) - handler.attach_stream(reader_adapted, writer_adapted) + handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + handler.attach(s, reader_adapted, writer_adapted) yield from self.start_handler(handler, s) message = yield from handler.mqtt_publish('/topic', b'test_data', QOS_0, False) self.assertIsInstance(message, OutgoingApplicationMessage) @@ -130,8 +129,8 @@ class ProtocolHandlerTest(unittest.TestCase): try: reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) - self.handler.attach_stream(reader_adapted, writer_adapted) + self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.handler.attach(self.session, reader_adapted, writer_adapted) yield from self.start_handler(self.handler, self.session) message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False) self.assertIsInstance(message, OutgoingApplicationMessage) @@ -182,8 +181,8 @@ class ProtocolHandlerTest(unittest.TestCase): try: reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) - self.handler.attach_stream(reader_adapted, writer_adapted) + self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.handler.attach(self.session, reader_adapted, writer_adapted) yield from self.start_handler(self.handler, self.session) message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False) self.assertIsInstance(message, OutgoingApplicationMessage) @@ -220,8 +219,8 @@ class ProtocolHandlerTest(unittest.TestCase): try: reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) - self.handler.attach_stream(reader_adapted, writer_adapted) + self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.handler.attach(self.session, reader_adapted, writer_adapted) yield from self.start_handler(self.handler, self.session) message = yield from self.handler.mqtt_deliver_next_message() self.assertIsInstance(message, IncomingApplicationMessage) @@ -264,8 +263,8 @@ class ProtocolHandlerTest(unittest.TestCase): try: reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) - self.handler.attach_stream(reader_adapted, writer_adapted) + self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.handler.attach(self.session, reader_adapted, writer_adapted) yield from self.start_handler(self.handler, self.session) message = yield from self.handler.mqtt_deliver_next_message() self.assertIsInstance(message, IncomingApplicationMessage) @@ -385,8 +384,8 @@ class ProtocolHandlerTest(unittest.TestCase): try: reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) - self.handler.attach_stream(reader_adapted, writer_adapted) + self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.handler.attach(self.session, reader_adapted, writer_adapted) yield from self.handler.start() yield from self.stop_handler(self.handler, self.session) if not future.done(): @@ -433,8 +432,8 @@ class ProtocolHandlerTest(unittest.TestCase): try: reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader_adapted, writer_adapted = adapt(reader, writer) - self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) - self.handler.attach_stream(reader_adapted, writer_adapted) + self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop) + self.handler.attach(self.session, reader_adapted, writer_adapted) yield from self.handler.start() yield from self.stop_handler(self.handler, self.session) if not future.done(): diff --git a/tests/test_broker.py b/tests/test_broker.py index c8cb42f..9ad489e 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -7,11 +7,26 @@ import asyncio import logging from hbmqtt.broker import * from hbmqtt.mqtt.constants import * +from hbmqtt.client import MQTTClient formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" logging.basicConfig(level=logging.DEBUG, format=formatter) log = logging.getLogger(__name__) +test_config = { + 'listeners': { + 'default': { + 'type': 'tcp', + 'bind': 'localhost:1883', + 'max_connections': 10 + }, + }, + 'sys_interval': 0, + 'auth': { + 'allow-anonymous': True, + } +} + class BrokerTest(unittest.TestCase): def setUp(self): @@ -23,25 +38,12 @@ class BrokerTest(unittest.TestCase): @patch('hbmqtt.broker.PluginManager') def test_start_stop(self, MockPluginManager): - config = { - 'listeners': { - 'default': { - 'type': 'tcp', - 'bind': '0.0.0.0:1883', - 'max_connections': 10 - }, - }, - 'sys_interval': 0, - 'auth': { - 'allow-anonymous': True, - } - } - def test_coro(): try: - broker = Broker(config, plugin_namespace="hbmqtt.test.plugins") + broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins") yield from broker.start() self.assertTrue(broker.transitions.is_started()) + self.assertDictEqual(broker._sessions, {}) self.assertIn('default', broker._servers) MockPluginManager.assert_has_calls( [call().fire_event(EVENT_BROKER_PRE_START), @@ -60,3 +62,29 @@ class BrokerTest(unittest.TestCase): self.loop.run_until_complete(test_coro()) if future.exception(): raise future.exception() + + @patch('hbmqtt.broker.PluginManager') + def test_client_connect(self, MockPluginManager): + def test_coro(): + try: + broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins") + yield from broker.start() + self.assertTrue(broker.transitions.is_started()) + client = MQTTClient() + ret = yield from client.connect('mqtt://localhost/') + self.assertEqual(ret, 0) + yield from client.disconnect() + yield from asyncio.sleep(0.1) + yield from broker.shutdown() + self.assertTrue(broker.transitions.is_stopped()) + MockPluginManager.assert_has_calls( + [call().fire_event(EVENT_BROKER_CLIENT_CONNECTED, session=client.session), + call().fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, session=client.session)], any_order=True) + future.set_result(True) + except Exception as ae: + future.set_exception(ae) + + future = asyncio.Future(loop=self.loop) + self.loop.run_until_complete(test_coro()) + if future.exception(): + raise future.exception() diff --git a/tests/test_client.py b/tests/test_client.py index b4cbef8..55159a6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,7 +5,7 @@ import unittest import asyncio import os import logging -from hbmqtt.client import MQTTClient +from hbmqtt.client import MQTTClient, ConnectException from hbmqtt.mqtt.constants import * formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" @@ -60,9 +60,10 @@ class MQTTClientTest(unittest.TestCase): @asyncio.coroutine def test_coro(): try: - client = MQTTClient() + config = {'auto_reconnect': False} + client = MQTTClient(config=config) ret = yield from client.connect('mqtt://localhost/') - except Exception as e: + except ConnectException as e: future.set_result(True) future = asyncio.Future(loop=self.loop)