From cbdb97aefc7dcb5f10f6a64af97238f7ffcf5d58 Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Fri, 24 Jul 2015 21:47:05 +0200 Subject: [PATCH] Disconnection and message handling refactoring --- hbmqtt/broker.py | 55 ++++++++------- hbmqtt/mqtt/protocol/broker_handler.py | 6 +- hbmqtt/mqtt/protocol/handler.py | 95 +++++++++++++------------- 3 files changed, 82 insertions(+), 74 deletions(-) diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index 65ceb60..1a7fd7b 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -213,39 +213,43 @@ class Broker: client_session.machine.connect() handler = BrokerProtocolHandler(self._loop) handler.attach_to_session(client_session) - self.logger.debug("Start messages handling") + self.logger.debug("%s Start messages handling" % client_session.client_id) yield from handler.start() yield from self.publish_session_retained_messages(client_session) - self.logger.debug("Wait for disconnect") + self.logger.debug("%s Wait for disconnect" % client_session.client_id) connected = True wait_disconnect = asyncio.Task(handler.wait_disconnect()) wait_subscription = asyncio.Task(handler.get_next_pending_subscription()) wait_unsubscription = asyncio.Task(handler.get_next_pending_unsubscription()) wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message()) + disconnect_event = False while connected: done, pending = yield from asyncio.wait( [wait_disconnect, wait_subscription, wait_unsubscription, wait_deliver], return_when=asyncio.FIRST_COMPLETED) if wait_disconnect in done: - result = wait_disconnect.result() - self.logger.debug("Result from wait_diconnect: %s" % result) - if result is None: - self.logger.debug("Will flag: %s" % client_session.will_flag) - #Connection closed anormally, send will message - if client_session.will_flag: - self.logger.debug("Client %s disconnected abnormally, sending will message" % - format_client_message(client_session)) - yield from self.broadcast_application_message( - client_session, client_session.will_topic, - client_session.will_message, - client_session.will_qos) - if client_session.will_retain: - self.retain_message(client_session, - client_session.will_topic, - client_session.will_message, - client_session.will_qos) - connected = False + if not disconnect_event: + result = wait_disconnect.result() + self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result)) + if result is None: + self.logger.debug("Will flag: %s" % client_session.will_flag) + #Connection closed anormally, send will message + if client_session.will_flag: + self.logger.debug("Client %s disconnected abnormally, sending will message" % + format_client_message(client_session)) + yield from self.broadcast_application_message( + client_session, client_session.will_topic, + client_session.will_message, + client_session.will_qos) + if client_session.will_retain: + self.retain_message(client_session, + client_session.will_topic, + client_session.will_message, + client_session.will_qos) + disconnect_event = True + if not (wait_unsubscription.done() or wait_subscription.done() or wait_deliver.done): + connected = False if wait_unsubscription in done: unsubscription = wait_unsubscription.result() for topic in unsubscription['topics']: @@ -275,7 +279,7 @@ class Broker: wait_unsubscription.cancel() wait_deliver.cancel() - self.logger.debug("Client disconnecting") + self.logger.debug("%s Client disconnecting" % client_session.client_id) try: yield from handler.stop() except Exception as e: @@ -285,7 +289,7 @@ class Broker: handler = None client_session.machine.disconnect() writer.close() - self.logger.debug("Session disconnected") + self.logger.debug("%s Session disconnected" % client_session.client_id) @asyncio.coroutine def check_connect(self, connect: ConnectPacket): @@ -306,12 +310,12 @@ class Broker: def retain_message(self, source_session, topic_name, data, qos=None): if data is not None and data != b'': # If retained flag set, store the message for further subscriptions - self.logger.debug("Retaining message on topic %s" % topic_name) + self.logger.debug("%s Retaining message on topic %s" % (source_session.client_id, topic_name)) retained_message = RetainedApplicationMessage(source_session, topic_name, data, qos) self._global_retained_messages[topic_name] = retained_message else: # [MQTT-3.3.1-10] - self.logger.debug("Clear retained messages for topic '%s'" % topic_name) + self.logger.debug("%s Clear retained messages for topic '%s'" % (source_session.client_id, topic_name)) del self._global_retained_messages[topic_name] def add_subscription(self, subscription, session): @@ -399,6 +403,9 @@ class Broker: asyncio.wait(publish_tasks) except Exception as e: self.logger.warn("Message broadcasting failed: %s", e) + self.logger.debug("End Broadcasting message from %s on topic %s" % + (format_client_message(session=source_session), topic) + ) @asyncio.coroutine def publish_session_retained_messages(self, session): diff --git a/hbmqtt/mqtt/protocol/broker_handler.py b/hbmqtt/mqtt/protocol/broker_handler.py index bf7293c..e3a89e1 100644 --- a/hbmqtt/mqtt/protocol/broker_handler.py +++ b/hbmqtt/mqtt/protocol/broker_handler.py @@ -48,7 +48,7 @@ class BrokerProtocolHandler(ProtocolHandler): @asyncio.coroutine def handle_disconnect(self, disconnect): - if self._disconnect_waiter is not None: + if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(disconnect) @asyncio.coroutine @@ -59,8 +59,8 @@ class BrokerProtocolHandler(ProtocolHandler): def handle_connect(self, connect: ConnectPacket): # Broker handler shouldn't received CONNECT message during messages handling # as CONNECT messages are managed by the broker on client connection - self.logger.error('[MQTT-3.1.0-2] %s : CONNECT message received during messages handling' % - (format_client_message(self.session))) + self.logger.error('%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling' % + (self.session.client_id, format_client_message(self.session))) if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(None) diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 44e64a9..ad7f227 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -97,12 +97,12 @@ class ProtocolHandler: self._writer_task = asyncio.Task(self._writer_coro(), loop=self._loop) yield from asyncio.wait( [self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) - self.logger.debug("Handler tasks started") + self.logger.debug("%s Handler tasks started" % self.session.client_id) @asyncio.coroutine def mqtt_publish(self, topic, message, packet_id, dup, qos, retain): if packet_id in self.session.inflight_out: - self.logger.warn("A message with the same packet ID is already in flight") + self.logger.warn("%s A message with the same packet ID is already in flight" % self.session.client_id) packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain) yield from self.outgoing_queue.put(packet) inflight_message = InFlightMessage(packet, qos) @@ -151,7 +151,7 @@ class ProtocolHandler: @asyncio.coroutine def _reader_coro(self): - self.logger.debug("Starting reader coro") + self.logger.debug("%s Starting reader coro" % self.session.client_id) while self._running: try: self._reader_ready.set() @@ -162,55 +162,56 @@ class ProtocolHandler: if fixed_header: cls = packet_class(fixed_header) packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) - self.logger.debug(" <-in-- " + repr(packet)) + self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet))) if packet.fixed_header.packet_type == PacketType.CONNACK: - asyncio.Task(self.handle_connack(packet)) + yield from self.handle_connack(packet) elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE: - asyncio.Task(self.handle_subscribe(packet)) + yield from self.handle_subscribe(packet) elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE: - asyncio.Task(self.handle_unsubscribe(packet)) + yield from self.handle_unsubscribe(packet) elif packet.fixed_header.packet_type == PacketType.SUBACK: - asyncio.Task(self.handle_suback(packet)) + yield from self.handle_suback(packet) elif packet.fixed_header.packet_type == PacketType.UNSUBACK: - asyncio.Task(self.handle_unsuback(packet)) + yield from self.handle_unsuback(packet) elif packet.fixed_header.packet_type == PacketType.PUBACK: - asyncio.Task(self.handle_puback(packet)) + yield from self.handle_puback(packet) elif packet.fixed_header.packet_type == PacketType.PUBREC: - asyncio.Task(self.handle_pubrec(packet)) + yield from self.handle_pubrec(packet) elif packet.fixed_header.packet_type == PacketType.PUBREL: - asyncio.Task(self.handle_pubrel(packet)) + yield from self.handle_pubrel(packet) elif packet.fixed_header.packet_type == PacketType.PUBCOMP: - asyncio.Task(self.handle_pubcomp(packet)) + yield from self.handle_pubcomp(packet) elif packet.fixed_header.packet_type == PacketType.PINGREQ: - asyncio.Task(self.handle_pingreq(packet)) + yield from self.handle_pingreq(packet) elif packet.fixed_header.packet_type == PacketType.PINGRESP: - asyncio.Task(self.handle_pingresp(packet)) + yield from self.handle_pingresp(packet) elif packet.fixed_header.packet_type == PacketType.PUBLISH: - asyncio.Task(self.handle_publish(packet)) + yield from self.handle_publish(packet) elif packet.fixed_header.packet_type == PacketType.DISCONNECT: - asyncio.Task(self.handle_disconnect(packet)) + yield from self.handle_disconnect(packet) elif packet.fixed_header.packet_type == PacketType.CONNECT: - asyncio.Task(self.handle_connect(packet)) + yield from self.handle_connect(packet) else: - self.logger.warn("Unhandled packet type: %s" % packet.fixed_header.packet_type) + self.logger.warn("%s Unhandled packet type: %s" % + (self.session.client_id, packet.fixed_header.packet_type)) else: - self.logger.debug("No more data, stopping reader coro") + self.logger.debug("%s No more data, stopping reader coro" % self.session.client_id) yield from self.handle_connection_closed() break except asyncio.TimeoutError: - self.logger.debug("Input stream read timeout") + self.logger.debug("%s Input stream read timeout" % self.session.client_id) self.handle_read_timeout() except NoDataException as nde: - self.logger.debug("No data available") + self.logger.debug("%s No data available" % self.session.client_id) except Exception as e: - self.logger.warn("Unhandled exception in reader coro: %s" % e) + self.logger.warn("%s Unhandled exception in reader coro: %s" % (self.session.client_id, e)) break - self.logger.debug("Reader coro stopped") + self.logger.debug("%s Reader coro stopped" % self.session.client_id) @asyncio.coroutine def _writer_coro(self): - self.logger.debug("Starting writer coro") + self.logger.debug("%s Starting writer coro" % self.session.client_id) while self._running: try: self._writer_ready.set() @@ -219,23 +220,22 @@ class ProtocolHandler: keepalive_timeout = None packet = yield from asyncio.wait_for(self.outgoing_queue.get(), keepalive_timeout) if not isinstance(packet, MQTTPacket): - self.logger.debug("Writer interruption") + self.logger.debug("%s Writer interruption" % self.session.client_id) break yield from packet.to_stream(self.session.writer) - self.logger.debug(" -out-> " + repr(packet)) + self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet))) yield from self.session.writer.drain() - #self.outgoing_queue.task_done() # to be used with Python 3.5 except asyncio.TimeoutError as ce: - self.logger.debug("Output queue get timeout") + self.logger.debug("%s Output queue get timeout" % self.session.client_id) if self._running: self.handle_write_timeout() except ConnectionResetError as cre: yield from self.handle_connection_closed() break except Exception as e: - self.logger.warn("Unhandled exception in writer coro: %s" % e) + self.logger.warn("%sUnhandled exception in writer coro: %s" % (self.session.client_id, e)) break - self.logger.debug("Writer coro stopping") + self.logger.debug("%s Writer coro stopping" % self.session.client_id) # Flush queue before stopping if not self.outgoing_queue.empty(): while True: @@ -244,12 +244,12 @@ class ProtocolHandler: if not isinstance(packet, MQTTPacket): break yield from packet.to_stream(self.session.writer) - self.logger.debug(" -out-> " + repr(packet)) + self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet))) except asyncio.QueueEmpty: break except Exception as e: - self.logger.warn("Unhandled exception in writer coro: %s" % e) - self.logger.debug("Writer coro stopped") + self.logger.warn("%s Unhandled exception in writer coro: %s" % (self.session.client_id, e)) + self.logger.debug("%s Writer coro stopped" % self.session.client_id) @asyncio.coroutine def mqtt_deliver_next_message(self): @@ -257,50 +257,50 @@ class ProtocolHandler: return inflight_message def handle_write_timeout(self): - self.logger.warn('write timeout unhandled') + self.logger.warn('%s write timeout unhandled' % self.session.client_id) def handle_read_timeout(self): - self.logger.warn('read timeout unhandled') + self.logger.warn('%s read timeout unhandled' % self.session.client_id) @asyncio.coroutine def handle_connack(self, connack: ConnackPacket): - self.logger.warn('CONNACK unhandled') + self.logger.warn('%s CONNACK unhandled' % self.session.client_id) @asyncio.coroutine def handle_connect(self, connect: ConnectPacket): - self.logger.warn('CONNECT unhandled') + self.logger.warn('%s CONNECT unhandled' % self.session.client_id) @asyncio.coroutine def handle_subscribe(self, subscribe: SubscribePacket): - self.logger.warn('SUBSCRIBE unhandled') + self.logger.warn('%s SUBSCRIBE unhandled' % self.session.client_id) @asyncio.coroutine def handle_unsubscribe(self, subscribe: UnsubscribePacket): - self.logger.warn('UNSUBSCRIBE unhandled') + self.logger.warn('%s UNSUBSCRIBE unhandled' % self.session.client_id) @asyncio.coroutine def handle_suback(self, suback: SubackPacket): - self.logger.warn('SUBACK unhandled') + self.logger.warn('%s SUBACK unhandled' % self.session.client_id) @asyncio.coroutine def handle_unsuback(self, unsuback: UnsubackPacket): - self.logger.warn('UNSUBACK unhandled') + self.logger.warn('%s UNSUBACK unhandled' % self.session.client_id) @asyncio.coroutine def handle_pingresp(self, pingresp: PingRespPacket): - self.logger.warn('PINGRESP unhandled') + self.logger.warn('%s PINGRESP unhandled' % self.session.client_id) @asyncio.coroutine def handle_pingreq(self, pingreq: PingReqPacket): - self.logger.warn('PINGREQ unhandled') + self.logger.warn('%s PINGREQ unhandled' % self.session.client_id) @asyncio.coroutine def handle_disconnect(self, disconnect: DisconnectPacket): - self.logger.warn('DISCONNECT unhandled') + self.logger.warn('%s DISCONNECT unhandled' % self.session.client_id) @asyncio.coroutine def handle_connection_closed(self): - self.logger.warn('Connection closed unhandled') + self.logger.warn('%s Connection closed unhandled' % self.session.client_id) @asyncio.coroutine def handle_puback(self, puback: PubackPacket): @@ -309,7 +309,8 @@ class ProtocolHandler: waiter = self._puback_waiters[packet_id] waiter.set_result(puback) except KeyError as ke: - self.logger.warn("Received PUBACK for unknown pending subscription with Id: %s" % packet_id) + self.logger.warn("%s Received PUBACK for unknown pending subscription with Id: %s" % + (self.session.client_id, packet_id)) @asyncio.coroutine def handle_pubrec(self, pubrec: PubrecPacket):