From a4e002de2a5233e95110cf5b836159d3d39501b1 Mon Sep 17 00:00:00 2001 From: Nico Date: Wed, 30 Sep 2015 21:22:46 +0200 Subject: [PATCH] Add publish message retry on start() --- hbmqtt/mqtt/protocol/handler.py | 45 +++++++------- tests/mqtt/protocol/test_handler.py | 91 +++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 25 deletions(-) diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 254d41b..2f49e8a 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -3,6 +3,7 @@ # See the file license.txt for copying permission. import logging import collections +import itertools from asyncio import InvalidStateError from blinker import Signal @@ -102,15 +103,14 @@ class ProtocolHandler: self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout) self.logger.debug("Handler tasks started") - yield from self.retry_deliveries() + yield from self._retry_deliveries() self.logger.debug("Handler ready") @asyncio.coroutine def stop(self): - # Stop incoming messages flow waiter - #for packet_id in self.session.inflight_in: - # self.session.inflight_in[packet_id].cancel() + # Stop messages flow waiter self._reader_task.cancel() + self._stop_waiters() if self._keepalive_task: self._keepalive_task.cancel() self.logger.debug("waiting for tasks to be stopped") @@ -122,33 +122,28 @@ class ProtocolHandler: except Exception as e: self.logger.debug("Handler writer close failed: %s" % e) + def _stop_waiters(self): + for waiter in itertools.chain( + self._puback_waiters.values(), + self._pubcomp_waiters.values(), + self._pubrec_waiters.values(), + self._pubrel_waiters.values()): + waiter.cancel() + @asyncio.coroutine - def retry_deliveries(self): + def _retry_deliveries(self): """ Handle [MQTT-4.4.0-1] by resending PUBLISH and PUBREL messages for pending out messages :return: """ self.logger.debug("Begin messages delivery retries") - ack_packets = [] - for packet_id in self.session.inflight_out: - message = self.session.inflight_out[packet_id] - if message.is_acknowledged(): - ack_packets.append(packet_id) - else: - if not message.pubrec_packet: - self.logger.debug("Retrying publish message Id=%d acknowledgment", packet_id) - message.publish_packet = PublishPacket.build( - message.topic, - message.data, - message.packet_id, - True, - message.qos, - message.retain) - yield from self._send_packet(message.publish_packet) - yield from self._handle_message_flow(message) - for packet_id in ack_packets: - del self.session.inflight_out[packet_id] - self.logger.debug("%d messages redelivered" % len(ack_packets)) + tasks = [] + for message in itertools.chain(self.session.inflight_in.values(), self.session.inflight_out.values()): + tasks.append(asyncio.wait_for(self._handle_message_flow(message), 10, loop=self._loop)) + if tasks: + done, pending = yield from asyncio.wait(tasks) + self.logger.debug("%d messages redelivered" % len(done)) + self.logger.debug("%d messages not redelivered due to timeout" % len(pending)) self.logger.debug("End messages delivery retries") @asyncio.coroutine diff --git a/tests/mqtt/protocol/test_handler.py b/tests/mqtt/protocol/test_handler.py index 3b9a280..3d6bf63 100644 --- a/tests/mqtt/protocol/test_handler.py +++ b/tests/mqtt/protocol/test_handler.py @@ -364,3 +364,94 @@ class ProtocolHandlerTest(unittest.TestCase): self.assertFalse(session.inflight_out) self.assertFalse(session.inflight_in) # self.assertEquals(session.delivered_message_queue.qsize(), 0) + + def test_publish_qos1_retry(self): + @asyncio.coroutine + def server_mock(reader, writer): + packet = yield from PublishPacket.from_stream(reader) + try: + self.assertEquals(packet.topic_name, '/topic') + self.assertEquals(packet.qos, QOS_1) + self.assertIsNotNone(packet.packet_id) + self.assertIn(packet.packet_id, self.session.inflight_out) + self.assertIn(packet.packet_id, self.handler._puback_waiters) + puback = PubackPacket.build(packet.packet_id) + yield from puback.to_stream(writer) + except Exception as ae: + future.set_exception(ae) + + @asyncio.coroutine + def test_coro(): + try: + reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) + reader_adapted, writer_adapted = adapt(reader, writer) + self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) + self.handler.attach_stream(reader_adapted, writer_adapted) + yield from self.handler.start() + yield from self.stop_handler(self.handler, self.session) + if not future.done(): + future.set_result(True) + except Exception as ae: + future.set_exception(ae) + self.handler = None + self.session = Session() + message = OutgoingApplicationMessage(1, '/topic', QOS_1, b'test_data', False) + message.publish_packet = PublishPacket.build('/topic', b'test_data', 1, False, QOS_1, False) + self.session.inflight_out[1] = message + future = asyncio.Future(loop=self.loop) + + coro = asyncio.start_server(server_mock, '127.0.0.1', 8888, loop=self.loop) + server = self.loop.run_until_complete(coro) + self.loop.run_until_complete(test_coro()) + server.close() + self.loop.run_until_complete(server.wait_closed()) + if future.exception(): + raise future.exception() + + def test_publish_qos2_retry(self): + @asyncio.coroutine + def server_mock(reader, writer): + try: + packet = yield from PublishPacket.from_stream(reader) + self.assertEquals(packet.topic_name, '/topic') + self.assertEquals(packet.qos, QOS_2) + self.assertIsNotNone(packet.packet_id) + self.assertIn(packet.packet_id, self.session.inflight_out) + self.assertIn(packet.packet_id, self.handler._pubrec_waiters) + pubrec = PubrecPacket.build(packet.packet_id) + yield from pubrec.to_stream(writer) + + pubrel = yield from PubrelPacket.from_stream(reader) + self.assertIn(packet.packet_id, self.handler._pubcomp_waiters) + pubcomp = PubcompPacket.build(packet.packet_id) + yield from pubcomp.to_stream(writer) + except Exception as ae: + future.set_exception(ae) + + @asyncio.coroutine + def test_coro(): + try: + reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) + reader_adapted, writer_adapted = adapt(reader, writer) + self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) + self.handler.attach_stream(reader_adapted, writer_adapted) + yield from self.handler.start() + yield from self.stop_handler(self.handler, self.session) + if not future.done(): + future.set_result(True) + except Exception as ae: + future.set_exception(ae) + self.handler = None + self.session = Session() + message = OutgoingApplicationMessage(1, '/topic', QOS_2, b'test_data', False) + message.publish_packet = PublishPacket.build('/topic', b'test_data', 1, False, QOS_2, False) + self.session.inflight_out[1] = message + future = asyncio.Future(loop=self.loop) + + coro = asyncio.start_server(server_mock, '127.0.0.1', 8888, loop=self.loop) + server = self.loop.run_until_complete(coro) + self.loop.run_until_complete(test_coro()) + server.close() + self.loop.run_until_complete(server.wait_closed()) + if future.exception(): + raise future.exception()