From 15468b784953f6cc82ad5c674f95a4b8b5852efc Mon Sep 17 00:00:00 2001 From: nico Date: Tue, 8 Sep 2015 15:01:37 +0200 Subject: [PATCH] Improve handler stopping (remove wait delay when disconnecting) --- hbmqtt/mqtt/protocol/handler.py | 45 ++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 3b38158..03c8509 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -53,17 +53,16 @@ class ProtocolHandler: self._writer_task = None self._reader_ready = asyncio.Event(loop=self._loop) self._writer_ready = asyncio.Event(loop=self._loop) - - self._running = False + self._reader_stopped = asyncio.Event(loop=self._loop) + self._writer_stopped = asyncio.Event(loop=self._loop) self.outgoing_queue = asyncio.Queue(loop=self._loop) self._pubrel_waiters = dict() @asyncio.coroutine def start(self): - self._running = True - self._reader_task = asyncio.Task(self._reader_coro(), loop=self._loop) - self._writer_task = asyncio.Task(self._writer_coro(), loop=self._loop) + self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop) + self._writer_task = asyncio.Task(self._writer_loop(), loop=self._loop) yield from asyncio.wait( [self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) self.logger.debug("%s Handler tasks started" % self.session.client_id) @@ -124,21 +123,23 @@ class ProtocolHandler: @asyncio.coroutine def stop(self): - self._running = False - yield from self.outgoing_queue.put("STOP") - self.reader.feed_eof() - yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop) - yield from self.writer.close() # Stop incoming messages flow waiter for packet_id in self.session.incoming_msg: self.session.incoming_msg[packet_id].cancel() for packet_id in self.session.outgoing_msg: self.session.outgoing_msg[packet_id].cancel() + self._reader_task.cancel() + self._writer_task.cancel() + self.logger.debug("waiting for loops to be stopped") + yield from asyncio.wait( + [self._reader_stopped.wait(), self._writer_stopped.wait()], loop=self._loop) + self.logger.debug("closing writer") + yield from self.writer.close() @asyncio.coroutine - def _reader_coro(self): + def _reader_loop(self): self.logger.debug("%s Starting reader coro" % self.session.client_id) - while self._running: + while True: try: self._reader_ready.set() keepalive_timeout = self.session.keep_alive @@ -192,8 +193,10 @@ class ProtocolHandler: (self.session.client_id, packet.fixed_header.packet_type)) else: self.logger.debug("%s No more data (EOF received), stopping reader coro" % self.session.client_id) - yield from self.handle_connection_closed() break + except asyncio.CancelledError: + self.logger.debug("Task cancelled, reader loop ending") + break except asyncio.TimeoutError: self.logger.debug("%s Input stream read timeout" % self.session.client_id) self.handle_read_timeout() @@ -202,28 +205,29 @@ class ProtocolHandler: except Exception as e: self.logger.warn("%s Unhandled exception in reader coro: %s" % (self.session.client_id, e)) break + yield from self.handle_connection_closed() + self._reader_stopped.set() self.logger.debug("%s Reader coro stopped" % self.session.client_id) @asyncio.coroutine - def _writer_coro(self): + def _writer_loop(self): self.logger.debug("%s Starting writer coro" % self.session.client_id) - while self._running: + while True: try: self._writer_ready.set() keepalive_timeout = self.session.keep_alive if keepalive_timeout <= 0: keepalive_timeout = None packet = yield from asyncio.wait_for(self.outgoing_queue.get(), keepalive_timeout, loop=self._loop) - if not isinstance(packet, MQTTPacket): - self.logger.debug("%s Writer interruption" % self.session.client_id) - break yield from packet.to_stream(self.writer) yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session) self._loop.call_soon(self.on_packet_sent.send, packet) + except asyncio.CancelledError: + self.logger.debug("Task cancelled, writer loop ending") + break except asyncio.TimeoutError as ce: self.logger.debug("%s Output queue get timeout" % self.session.client_id) - if self._running: - self.handle_write_timeout() + self.handle_write_timeout() except ConnectionResetError as cre: yield from self.handle_connection_closed() break @@ -245,6 +249,7 @@ class ProtocolHandler: break except Exception as e: self.logger.warn("%s Unhandled exception in writer coro: %s" % (self.session.client_id, e)) + self._writer_stopped.set() self.logger.debug("%s Writer coro stopped" % self.session.client_id) @asyncio.coroutine