From d74cdf2665267865fb2527cfe7cd508994446e93 Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Sat, 25 Jul 2015 23:21:25 +0200 Subject: [PATCH] Refactor mqtt_publish and inflight messages management --- hbmqtt/broker.py | 9 +- hbmqtt/client.py | 8 +- hbmqtt/mqtt/protocol/handler.py | 157 ++++++++++++++++++-------------- hbmqtt/specs.py | 13 ++- samples/client_publish.py | 2 +- 5 files changed, 104 insertions(+), 85 deletions(-) diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index fc45090..a687eb1 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -390,9 +390,8 @@ class Broker: (format_client_message(session=source_session), topic, format_client_message(session=target_session))) handler = subscription.session.handler - packet_id = handler.session.next_packet_id publish_tasks.append( - asyncio.Task(handler.mqtt_publish(topic, data, packet_id, False, qos, retain=False)) + asyncio.Task(handler.mqtt_publish(topic, data, qos, retain=False)) ) else: self.logger.debug("retaining application message from %s on topic '%s' to client '%s'" % @@ -418,10 +417,9 @@ class Broker: publish_tasks = [] while not session.retained_messages.empty(): retained = yield from session.retained_messages.get() - packet_id = session.next_packet_id publish_tasks.append(asyncio.Task( session.handler.mqtt_publish( - retained.topic, retained.data, packet_id, False, retained.qos, True))) + retained.topic, retained.data, False, retained.qos, True))) if len(publish_tasks) > 0: asyncio.wait(publish_tasks) @@ -435,10 +433,9 @@ class Broker: if self.matches(d_topic, subscription['filter']): self.logger.debug("%s and %s match" % (d_topic, subscription['filter'])) retained = self._global_retained_messages[d_topic] - packet_id = session.next_packet_id publish_tasks.append(asyncio.Task( session.handler.mqtt_publish( - retained.topic, retained.data, packet_id, False, subscription['qos'], True))) + retained.topic, retained.data, subscription['qos'], True))) if len(publish_tasks) > 0: asyncio.wait(publish_tasks) self.logger.debug("End broadcasting messages retained due to subscription on '%s' from %s" % diff --git a/hbmqtt/client.py b/hbmqtt/client.py index ed29e82..c88501c 100644 --- a/hbmqtt/client.py +++ b/hbmqtt/client.py @@ -143,7 +143,7 @@ class MQTTClient: self._handler.mqtt_ping() @asyncio.coroutine - def publish(self, topic, message, dup=False, qos=None, retain=None): + def publish(self, topic, message, qos=None, retain=None): def get_retain_and_qos(): if qos: _qos = qos @@ -164,11 +164,11 @@ class MQTTClient: return _qos, _retain (app_qos, app_retain) = get_retain_and_qos() if app_qos == 0: - yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x00, app_retain) + yield from self._handler.mqtt_publish(topic, message, 0x00, app_retain) if app_qos == 1: - yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x01, app_retain) + yield from self._handler.mqtt_publish(topic, message, 0x01, app_retain) if app_qos == 2: - yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x02, app_retain) + yield from self._handler.mqtt_publish(topic, message, 0x02, app_retain) @asyncio.coroutine def subscribe(self, topics): diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index ad77b7b..22e557c 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -3,10 +3,11 @@ # See the file license.txt for copying permission. import logging import asyncio +from datetime import datetime from asyncio import futures from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket from hbmqtt.mqtt import packet_class -from hbmqtt.errors import NoDataException +from hbmqtt.errors import NoDataException, HBMQTTException from hbmqtt.mqtt.packet import PacketType from hbmqtt.mqtt.connack import ConnackPacket from hbmqtt.mqtt.connect import ConnectPacket @@ -23,30 +24,77 @@ from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.session import Session -from transitions import Machine +from hbmqtt.specs import * class InFlightMessage: - states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed'] - - def __init__(self, packet, qos): + def __init__(self, packet, qos, ack_timeout=0, loop=None): + if loop is None: + self._loop = asyncio.get_event_loop() + else: + self._loop = loop self.packet = packet self.qos = qos - self.puback = None - self.pubrec = None - self.pubcomp = None - self.pubrel = None - self._init_states() + self.publish_ts = None + self.puback_ts = None + self.pubrec_ts = None + self.pubrel_ts = None + self.pubcomp_ts = None + self._ack_waiter = asyncio.Future(loop=self._loop) + self._ack_timeout = ack_timeout + self._ack_timeout_handle = None - def _init_states(self): - self.machine = Machine(model=self, states=InFlightMessage.states, initial='new') - self.machine.add_transition(trigger='publish', source='new', dest='published') - if self.qos == 0x01: - self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged') - if self.qos == 0x02: - self.machine.add_transition(trigger='receive', source='published', dest='received') - self.machine.add_transition(trigger='release', source='received', dest='released') - self.machine.add_transition(trigger='complete', source='released', dest='completed') + @asyncio.coroutine + def wait_acknowledge(self): + return (yield from self._ack_waiter) + + def received_puback(self): + if self.qos == QOS_1: + self.puback_ts = datetime.now() + self.cancel_ack_timeout() + self._ack_waiter.set_result(True) + else: + raise HBMQTTException('Invalid call to method received_puback on inflight messages with QOS=%d' % self.qos) + + def received_pubrec(self): + if self.qos == QOS_2: + self.pubrec_ts = datetime.now() + self.packet = None # Discard message + self.reset_ack_timeout() + else: + raise HBMQTTException('Invalid call to method received_pubrec on inflight messages with QOS=%d' % self.qos) + + def received_pubcomp(self): + if self.qos == QOS_2: + self.pubcomp_ts = datetime.now() + self.cancel_ack_timeout() + self._ack_waiter.set_result(True) + else: + raise HBMQTTException('Invalid call to method received_pubcomp on inflight messages with QOS=%d' % self.qos) + + def sent_pubrel(self): + if self.qos == QOS_2: + self.pubrel_ts = datetime.now() + else: + raise HBMQTTException('Invalid call to method sent_pubrel on inflight messages with QOS=%d' % self.qos) + + def start_ack_timeout(self): + def cb_timeout(): + self._ack_waiter.set_result(False) + if self._ack_timeout: + self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout) + + def cancel_ack_timeout(self): + if self._ack_timeout_handle: + self._ack_timeout_handle.cancel() + + def reset_ack_timeout(self): + self.cancel_ack_timeout() + self.start_ack_timeout() + + def sent_publish(self): + self.publish_ts = datetime.now() + self.start_ack_timeout() class ProtocolHandler: @@ -68,10 +116,6 @@ class ProtocolHandler: self._running = False - self.incoming_queues = dict() - self.application_messages = asyncio.Queue() - for p in PacketType: - self.incoming_queues[p] = asyncio.Queue() self.outgoing_queue = asyncio.Queue() self._puback_waiters = dict() self._pubrec_waiters = dict() @@ -100,47 +144,25 @@ class ProtocolHandler: self.logger.debug("%s Handler tasks started" % self.session.client_id) @asyncio.coroutine - def mqtt_publish(self, topic, message, packet_id, dup, qos, retain): + def mqtt_publish(self, topic, message, qos, retain): + packet_id = self.session.next_packet_id if packet_id in self.session.inflight_out: self.logger.warn("%s A message with the same packet ID is already in flight" % self.session.client_id) - packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain) + packet = PublishPacket.build(topic, message, packet_id, False, qos, retain) yield from self.outgoing_queue.put(packet) - inflight_message = InFlightMessage(packet, qos) - self.session.inflight_out[packet.variable_header.packet_id] = inflight_message - - inflight_message.publish() - if qos == 0x01: - waiter = futures.Future(loop=self._loop) - self._puback_waiters[packet_id] = waiter - yield from waiter - inflight_message.puback = waiter.result() - inflight_message.acknowledge() - del self._puback_waiters[packet_id] - if qos == 0x02: - # Wait for PUBREC - waiter = futures.Future(loop=self._loop) - self._pubrec_waiters[packet_id] = waiter - yield from waiter - inflight_message.pubrec = waiter.result() - del self._pubrec_waiters[packet_id] - inflight_message.receive() - - # Send pubrel - pubrel = PubrelPacket.build(packet_id) - yield from self.outgoing_queue.put(pubrel) - inflight_message.pubrel = pubrel - inflight_message.release() - - # Wait for pubcomp - waiter = futures.Future(loop=self._loop) - self._pubcomp_waiters[packet_id] = waiter - yield from waiter - inflight_message.pubcomp = waiter.result() - del self._pubcomp_waiters[packet_id] - inflight_message.complete() - - del self.session.inflight_out[packet_id] - return inflight_message + if qos != QOS_0: + inflight_message = InFlightMessage(packet, qos, loop=self._loop) + inflight_message.sent_publish() + self.session.inflight_out[packet_id] = inflight_message + ack = yield from inflight_message.wait_acknowledge() + while not ack: + #Retry publish + packet = PublishPacket.build(topic, message, packet_id, True, qos, retain) + inflight_message.packet = packet + yield from self.outgoing_queue.put(packet) + inflight_message.sent_publish() + ack = yield from inflight_message.wait_acknowledge() + del self.session.inflight_out[packet_id] @asyncio.coroutine def stop(self): @@ -306,8 +328,8 @@ class ProtocolHandler: def handle_puback(self, puback: PubackPacket): packet_id = puback.variable_header.packet_id try: - waiter = self._puback_waiters[packet_id] - waiter.set_result(puback) + inflight_message = self.session.inflight_out[packet_id] + inflight_message.received_puback() except KeyError as ke: self.logger.warn("%s Received PUBACK for unknown pending subscription with Id: %s" % (self.session.client_id, packet_id)) @@ -316,8 +338,9 @@ class ProtocolHandler: def handle_pubrec(self, pubrec: PubrecPacket): packet_id = pubrec.variable_header.packet_id try: - waiter = self._pubrec_waiters[packet_id] - waiter.set_result(pubrec) + inflight_message = self.session.inflight_out[packet_id] + inflight_message.received_pubrec() + yield from self.outgoing_queue.put(PubrelPacket.build(packet_id)) except KeyError as ke: self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id) @@ -325,8 +348,8 @@ class ProtocolHandler: def handle_pubcomp(self, pubcomp: PubcompPacket): packet_id = pubcomp.variable_header.packet_id try: - waiter = self._pubcomp_waiters[packet_id] - waiter.set_result(pubcomp) + inflight_message = self.session.inflight_out[packet_id] + inflight_message.received_pubcomp() except KeyError as ke: self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id) diff --git a/hbmqtt/specs.py b/hbmqtt/specs.py index efa9edf..841d281 100644 --- a/hbmqtt/specs.py +++ b/hbmqtt/specs.py @@ -1,8 +1,7 @@ -__author__ = 'nico' +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. -from enum import Enum - -class Qos(Enum): - Qos_0 = 0x00 - Qos_1 = 0x01 - Qos_2 = 0x02 +QOS_0 = 0x00 +QOS_1 = 0x01 +QOS_2 = 0x02 diff --git a/samples/client_publish.py b/samples/client_publish.py index b59e78a..0b971d2 100644 --- a/samples/client_publish.py +++ b/samples/client_publish.py @@ -38,5 +38,5 @@ def test_coro(): if __name__ == '__main__': formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" - logging.basicConfig(level=logging.INFO, format=formatter) + logging.basicConfig(level=logging.DEBUG, format=formatter) asyncio.get_event_loop().run_until_complete(test_coro()) \ No newline at end of file