From df317319d48a14322ffe90ed61c13b225e9fefb8 Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Sun, 26 Jul 2015 22:53:11 +0200 Subject: [PATCH] Fix disconnection management --- hbmqtt/broker.py | 42 ++++++++++++-------------- hbmqtt/mqtt/protocol/broker_handler.py | 4 +-- hbmqtt/mqtt/protocol/handler.py | 32 +++++++++++--------- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index 23494ff..2f186fb 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -223,33 +223,29 @@ class Broker: wait_subscription = asyncio.Task(handler.get_next_pending_subscription()) wait_unsubscription = asyncio.Task(handler.get_next_pending_unsubscription()) wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message()) - disconnect_event = False while connected: done, pending = yield from asyncio.wait( [wait_disconnect, wait_subscription, wait_unsubscription, wait_deliver], return_when=asyncio.FIRST_COMPLETED) if wait_disconnect in done: - if not disconnect_event: - result = wait_disconnect.result() - self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result)) - if result is None: - self.logger.debug("Will flag: %s" % client_session.will_flag) - #Connection closed anormally, send will message - if client_session.will_flag: - self.logger.debug("Client %s disconnected abnormally, sending will message" % - format_client_message(client_session)) - yield from self.broadcast_application_message( - client_session, client_session.will_topic, - client_session.will_message, - client_session.will_qos) - if client_session.will_retain: - self.retain_message(client_session, - client_session.will_topic, - client_session.will_message, - client_session.will_qos) - disconnect_event = True - if not (wait_unsubscription.done() or wait_subscription.done() or wait_deliver.done): - connected = False + result = wait_disconnect.result() + self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result)) + if result is None: + self.logger.debug("Will flag: %s" % client_session.will_flag) + #Connection closed anormally, send will message + if client_session.will_flag: + self.logger.debug("Client %s disconnected abnormally, sending will message" % + format_client_message(client_session)) + yield from self.broadcast_application_message( + client_session, client_session.will_topic, + client_session.will_message, + client_session.will_qos) + if client_session.will_retain: + self.retain_message(client_session, + client_session.will_topic, + client_session.will_message, + client_session.will_qos) + connected = False if wait_unsubscription in done: self.logger.debug("%s handling unsubscription" % client_session.client_id) unsubscription = wait_unsubscription.result() @@ -422,7 +418,7 @@ class Broker: retained = yield from session.retained_messages.get() publish_tasks.append(asyncio.Task( session.handler.mqtt_publish( - retained.topic, retained.data, False, retained.qos, True))) + retained.topic, retained.data, retained.qos, True))) if len(publish_tasks) > 0: asyncio.wait(publish_tasks) diff --git a/hbmqtt/mqtt/protocol/broker_handler.py b/hbmqtt/mqtt/protocol/broker_handler.py index e3a89e1..989bbbe 100644 --- a/hbmqtt/mqtt/protocol/broker_handler.py +++ b/hbmqtt/mqtt/protocol/broker_handler.py @@ -37,7 +37,7 @@ class BrokerProtocolHandler(ProtocolHandler): @asyncio.coroutine def wait_disconnect(self): - yield from self._disconnect_waiter + return (yield from self._disconnect_waiter) def handle_write_timeout(self): pass @@ -48,7 +48,7 @@ class BrokerProtocolHandler(ProtocolHandler): @asyncio.coroutine def handle_disconnect(self, disconnect): - if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): + if self._disconnect_waiter and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(disconnect) @asyncio.coroutine diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 2ab548f..b55b5ae 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -142,37 +142,41 @@ class ProtocolHandler: packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet))) + task = None if packet.fixed_header.packet_type == PacketType.CONNACK: - asyncio.Task(self.handle_connack(packet)) + task = asyncio.Task(self.handle_connack(packet)) elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE: - asyncio.Task(self.handle_subscribe(packet)) + task = asyncio.Task(self.handle_subscribe(packet)) elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE: - asyncio.Task(self.handle_unsubscribe(packet)) + task = asyncio.Task(self.handle_unsubscribe(packet)) elif packet.fixed_header.packet_type == PacketType.SUBACK: - asyncio.Task(self.handle_suback(packet)) + task = asyncio.Task(self.handle_suback(packet)) elif packet.fixed_header.packet_type == PacketType.UNSUBACK: - asyncio.Task(self.handle_unsuback(packet)) + task = asyncio.Task(self.handle_unsuback(packet)) elif packet.fixed_header.packet_type == PacketType.PUBACK: - asyncio.Task(self.handle_puback(packet)) + task = asyncio.Task(self.handle_puback(packet)) elif packet.fixed_header.packet_type == PacketType.PUBREC: - asyncio.Task(self.handle_pubrec(packet)) + task = asyncio.Task(self.handle_pubrec(packet)) elif packet.fixed_header.packet_type == PacketType.PUBREL: - asyncio.Task(self.handle_pubrel(packet)) + task = asyncio.Task(self.handle_pubrel(packet)) elif packet.fixed_header.packet_type == PacketType.PUBCOMP: - asyncio.Task(self.handle_pubcomp(packet)) + task = asyncio.Task(self.handle_pubcomp(packet)) elif packet.fixed_header.packet_type == PacketType.PINGREQ: - asyncio.Task(self.handle_pingreq(packet)) + task = asyncio.Task(self.handle_pingreq(packet)) elif packet.fixed_header.packet_type == PacketType.PINGRESP: - asyncio.Task(self.handle_pingresp(packet)) + task = asyncio.Task(self.handle_pingresp(packet)) elif packet.fixed_header.packet_type == PacketType.PUBLISH: - asyncio.Task(self.handle_publish(packet)) + task = asyncio.Task(self.handle_publish(packet)) elif packet.fixed_header.packet_type == PacketType.DISCONNECT: - asyncio.Task(self.handle_disconnect(packet)) + task = asyncio.Task(self.handle_disconnect(packet)) elif packet.fixed_header.packet_type == PacketType.CONNECT: - asyncio.Task(self.handle_connect(packet)) + task = asyncio.Task(self.handle_connect(packet)) else: self.logger.warn("%s Unhandled packet type: %s" % (self.session.client_id, packet.fixed_header.packet_type)) + if task: + # Wait for message handling ends + asyncio.wait([task]) else: self.logger.debug("%s No more data, stopping reader coro" % self.session.client_id) yield from self.handle_connection_closed()