diff --git a/hbmqtt/protocol.py b/hbmqtt/protocol.py index 2dea57a..f27d5b9 100644 --- a/hbmqtt/protocol.py +++ b/hbmqtt/protocol.py @@ -20,9 +20,11 @@ class ProtocolHandler: self._writer_task = None self._reader_ready = asyncio.Event(loop=self._loop) self._writer_ready = asyncio.Event(loop=self._loop) + self._running = False @asyncio.coroutine def start(self): + self._running = True self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop) self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop) yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) @@ -30,51 +32,59 @@ class ProtocolHandler: @asyncio.coroutine def stop(self): - self._reader_task.cancel() - self._writer_task.cancel() + self._running = False + yield from asyncio.wait([self._writer_task], loop=self._loop) + @asyncio.coroutine def _reader_coro(self): self.logger.debug("Starting reader coro") - while True: + while self._running: try: self._reader_ready.set() - fixed_header = yield from MQTTFixedHeader.from_stream(self.session.reader) - cls = packet_class(fixed_header) - packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) - yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet) - except asyncio.CancelledError: - self.logger.warn("Reader coro stopped") - break - except NoDataException: - self.logger.debug("No more data to read") - break - except Exception as e: + fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), 60) + if fixed_header: + cls = packet_class(fixed_header) + packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) + yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet) + else: + self.logger.debug("No data") + except asyncio.TimeoutError: + self.logger.warn("Input stream read timeout") + except NoDataException as nde: + self.logger.debug("No data available") + #break + except BaseException as e: self.logger.warn("Exception in reader coro: %s" % e) break + self.logger.debug("Reader coro stopped") + @asyncio.coroutine def _writer_coro(self): self.logger.debug("Starting writer coro") out_queue = self.session.outgoing_queue - while True: + packet = None + while self._running: try: self._writer_ready.set() - packet = yield from out_queue.get() + packet = yield from asyncio.wait_for(out_queue.get(), 60) self.logger.debug(packet) yield from packet.to_stream(self.session.writer) - except asyncio.CancelledError: - self.logger.warn("Writer coro stopping") - # Flush queue - while True: - try: - packet = out_queue.get_nowait() - self.logger.debug(packet) - yield from packet.to_stream(self.session.writer) - except asyncio.QueueEmpty: - break - self.logger.warn("Writer coro stopped") - break + yield from self.session.writer.drain() + except asyncio.TimeoutError as ce: + self.logger.warn("Output queue get timeout") except Exception as e: self.logger.warn("Exception in writer coro: %s" % e) break + self.logger.debug("Writer coro stopping") + # Flush queue before stopping + if not out_queue.empty(): + while True: + try: + packet = out_queue.get_nowait() + self.logger.debug(packet) + yield from packet.to_stream(self.session.writer) + except asyncio.QueueEmpty: + break + self.logger.debug("Writer coro stopped") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 4e137ab..3614445 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -14,6 +14,8 @@ import logging logging.basicConfig(level=logging.DEBUG) +packet = "str" + class ConnectPacketTest(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() @@ -48,13 +50,15 @@ class ConnectPacketTest(unittest.TestCase): def test_write_loop(self): data_ref = b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password' - packet = None + event=asyncio.Event() @asyncio.coroutine def serve_test(reader, writer): global packet packet = yield from ConnectPacket.from_stream(reader) self.logger.info("data=" + repr(packet)) writer.close() + event.set() + return packet loop = asyncio.get_event_loop() coro = asyncio.start_server(serve_test, '127.0.0.1', 8888, loop=loop) @@ -66,17 +70,15 @@ class ConnectPacketTest(unittest.TestCase): S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop) handler = ProtocolHandler(S, loop) - #yield from handler.start() - packet = ConnectPacket(vh=ConnectVariableHeader(), payload=ConnectPayload('Id', 'WillTopic', 'WillMessage', 'user', 'password')) - self.logger.debug(packet) - S.outgoing_queue.put_nowait(packet) - #S.outgoing_queue.put_nowait(packet) - #yield from S.outgoing_queue.put(packet) + yield from handler.start() + conn = ConnectPacket(vh=ConnectVariableHeader(), payload=ConnectPayload('Id', 'WillTopic', 'WillMessage', 'user', 'password')) + yield from S.outgoing_queue.put(conn) self.logger.debug("Messages in queue: %d" % S.outgoing_queue.qsize()) yield from handler.stop() + S.writer.close() loop.run_until_complete(client()) - loop.run_forever() - server.close() - print(packet) + loop.run_until_complete(asyncio.wait([event.wait()])) + ret = server.close() + self.logger.info(packet) #self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT) \ No newline at end of file