From 0bbba69ffb7455af36331ab1a34ecbcd44a3085b Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Sun, 26 Jul 2015 21:21:35 +0200 Subject: [PATCH] Refactor incoming / outgoing message management --- hbmqtt/broker.py | 5 +- hbmqtt/client.py | 4 + hbmqtt/mqtt/protocol/handler.py | 226 ++++++++++--------------------- hbmqtt/mqtt/protocol/inflight.py | 192 ++++++++++++++++++++++++++ hbmqtt/session.py | 8 ++ samples/client_subscribe.py | 9 +- 6 files changed, 284 insertions(+), 160 deletions(-) create mode 100644 hbmqtt/mqtt/protocol/inflight.py diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index 9cba33f..23494ff 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -271,12 +271,15 @@ class Broker: self.logger.debug(repr(self._subscriptions)) if wait_deliver in done: self.logger.debug("%s handling message delivery" % client_session.client_id) - publish_packet = wait_deliver.result().publish_packet + publish_packet = wait_deliver.result() + packet_id = publish_packet.variable_header.packet_id topic_name = publish_packet.variable_header.topic_name data = publish_packet.payload.data yield from self.broadcast_application_message(client_session, topic_name, data) if publish_packet.retain_flag: self.retain_message(client_session, topic_name, data) + # Acknowledge message delivery + yield from handler.mqtt_acknowledge_delivery(packet_id) wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message()) wait_subscription.cancel() wait_unsubscription.cancel() diff --git a/hbmqtt/client.py b/hbmqtt/client.py index c88501c..04d2918 100644 --- a/hbmqtt/client.py +++ b/hbmqtt/client.py @@ -182,6 +182,10 @@ class MQTTClient: def deliver_message(self): return (yield from self._handler.mqtt_deliver_next_message()) + @asyncio.coroutine + def acknowledge_delivery(self, packet_id): + yield from self._handler.mqtt_acknowledge_delivery(packet_id) + @asyncio.coroutine def _connect_coro(self): try: diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index ca778b5..2ab548f 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -4,7 +4,6 @@ import logging import asyncio from datetime import datetime -from asyncio import futures from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket from hbmqtt.mqtt import packet_class from hbmqtt.errors import NoDataException, HBMQTTException @@ -25,131 +24,7 @@ from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.session import Session from hbmqtt.specs import * -from transitions import Machine, MachineError - - -class InFlightMessage: - states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed'] - - def __init__(self, packet, qos, ack_timeout=0, loop=None): - if loop is None: - self._loop = asyncio.get_event_loop() - else: - self._loop = loop - self.publish_packet = packet - self.qos = qos - self.publish_ts = None - self.puback_ts = None - self.pubrec_ts = None - self.pubrel_ts = None - self.pubcomp_ts = None - self.nb_retries = 0 - self._ack_waiter = asyncio.Future(loop=self._loop) - self._ack_timeout = ack_timeout - self._ack_timeout_handle = None - self._init_states() - - def _init_states(self): - self.machine = Machine(model=self, states=InFlightMessage.states, initial='new') - self.machine.add_transition(trigger='publish', source='new', dest='published') - self.machine.add_transition(trigger='publish', source='published', dest='published') - self.machine.add_transition(trigger='publish', source='received', dest='published') - self.machine.add_transition(trigger='publish', source='released', dest='published') - if self.qos == 0x01: - self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged') - if self.qos == 0x02: - self.machine.add_transition(trigger='receive', source='published', dest='received') - self.machine.add_transition(trigger='release', source='received', dest='released') - self.machine.add_transition(trigger='complete', source='released', dest='completed') - - @asyncio.coroutine - def wait_acknowledge(self): - return (yield from self._ack_waiter) - - def received_puback(self): - try: - self.acknowledge() - self.puback_ts = datetime.now() - self.cancel_ack_timeout() - self._ack_waiter.set_result(True) - except MachineError: - raise HBMQTTException( - 'Invalid call to method received_puback on inflight messages with QOS=%d, state=%s' % - (self.qos, self.state)) - - def received_pubrec(self): - try: - self.receive() - self.pubrec_ts = datetime.now() - self.publish_packet = None # Discard message - self.reset_ack_timeout() - except MachineError: - raise HBMQTTException( - 'Invalid call to method received_pubrec on inflight messages with QOS=%d, state=%s' % - (self.qos, self.state)) - - def received_pubcomp(self): - try: - self.complete() - self.pubcomp_ts = datetime.now() - self.cancel_ack_timeout() - self._ack_waiter.set_result(True) - except MachineError: - raise HBMQTTException( - 'Invalid call to method received_pubcomp on inflight messages with QOS=%d, state=%s' % - (self.qos, self.state)) - - def sent_pubrel(self): - try: - self.release() - self.pubrel_ts = datetime.now() - except MachineError: - raise HBMQTTException( - 'Invalid call to method sent_pubrel on inflight messages with QOS=%d, state=%s' % - (self.qos, self.state)) - - def retry_publish(self): - try: - self.publish() - self.nb_retries += 1 - self.publish_ts = datetime.now() - self.start_ack_timeout() - except MachineError: - raise HBMQTTException( - 'Invalid call to method retry_publish on inflight messages with QOS=%d, state=%s' % - (self.qos, self.state)) - - def sent_publish(self): - try: - self.publish() - self.publish_ts = datetime.now() - self.start_ack_timeout() - except MachineError: - raise HBMQTTException( - 'Invalid call to method sent_publish on inflight messages with QOS=%d, state=%s' % - (self.qos, self.state)) - - def start_ack_timeout(self): - def cb_timeout(): - self._ack_waiter.set_result(False) - if self._ack_timeout: - self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout) - - def cancel_ack_timeout(self): - if self._ack_timeout_handle: - self._ack_timeout_handle.cancel() - - def reset_ack_timeout(self): - self.cancel_ack_timeout() - self.start_ack_timeout() - - -class IncomingInFlightMessage(InFlightMessage): - pass - - -class OutgoingInFlightMessage(InFlightMessage): - pass +from hbmqtt.mqtt.protocol.inflight import * class ProtocolHandler: @@ -173,7 +48,6 @@ class ProtocolHandler: self.outgoing_queue = asyncio.Queue() self._pubrel_waiters = dict() - self.delivered_message = asyncio.Queue() def attach_to_session(self, session: Session): self.session = session @@ -247,6 +121,11 @@ class ProtocolHandler: self.session.reader.feed_eof() yield from self.outgoing_queue.put("STOP") yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop) + # Stop incoming messages flow waiter + for packet_id in self.session.incoming_msg: + self.session.incoming_msg[packet_id].cancel() + for packet_id in self.session.outgoing_msg: + self.session.outgoing_msg[packet_id].cancel() @asyncio.coroutine def _reader_coro(self): @@ -352,8 +231,21 @@ class ProtocolHandler: @asyncio.coroutine def mqtt_deliver_next_message(self): - inflight_message = yield from self.delivered_message.get() - return inflight_message + packet_id = yield from self.session.delivered_message_queue.get() + message = self.session.incoming_msg[packet_id] + if message.qos == QOS_0: + del self.session.incoming_msg[packet_id] + self.logger.debug("Discarded incoming message %s" % packet_id) + return message.publish_packet + + @asyncio.coroutine + def mqtt_acknowledge_delivery(self, packet_id): + try: + message = self.session.incoming_msg[packet_id] + message.acknowledge_delivery() + self.logger.debug('Message delivery acknowledged, packed_id=%d' % packet_id) + except KeyError: + pass def handle_write_timeout(self): self.logger.warn('%s write timeout unhandled' % self.session.client_id) @@ -435,44 +327,68 @@ class ProtocolHandler: def handle_pubrel(self, pubrel: PubrecPacket): packet_id = pubrel.variable_header.packet_id try: - waiter = self._pubrel_waiters[packet_id] - waiter.set_result(pubrel) + inflight_message = self.session.incoming_msg[packet_id] + inflight_message.received_pubrel() except KeyError as ke: self.logger.warn("Received PUBREL for unknown pending subscription with Id: %s" % packet_id) @asyncio.coroutine def handle_publish(self, publish_packet: PublishPacket): - inflight_message = None + incoming_message = None packet_id = publish_packet.variable_header.packet_id qos = publish_packet.qos if qos == 0: - inflight_message = IncomingInFlightMessage(publish_packet, qos) - yield from self.delivered_message.put(inflight_message) - else: - if packet_id in self.session.incoming_msg: - inflight_message = self.session.incoming_msg[packet_id] + if publish_packet.dup_flag: + self.logger.warn("[MQTT-3.3.1-2] DUP flag must set to 0 for QOS 0 message. Message ignored: %s" % + repr(publish_packet)) else: - inflight_message = InFlightMessage(publish_packet, qos) - self.session.incoming_msg[packet_id] = inflight_message - inflight_message.publish() + incoming_message = IncomingInFlightMessage(publish_packet, qos) + incoming_message.received_publish() + self.session.incoming_msg[packet_id] = incoming_message + yield from self.session.delivered_message_queue.put(packet_id) + else: + # Check if publish is a retry + if packet_id in self.session.incoming_msg: + incoming_message = self.session.incoming_msg[packet_id] + else: + incoming_message = IncomingInFlightMessage(publish_packet, qos) + self.session.incoming_msg[packet_id] = incoming_message + incoming_message.publish() if qos == 1: - puback = PubackPacket.build(packet_id) - yield from self.outgoing_queue.put(puback) - inflight_message.acknowledge() + # Initiate delivery + yield from self.session.delivered_message_queue.put(packet_id) + ack = yield from incoming_message.wait_acknowledge() + if ack: + # Send PUBACK + puback = PubackPacket.build(packet_id) + yield from self.outgoing_queue.put(puback) + #Discard message + del self.session.incoming_msg[packet_id] + self.logger.debug("Discarded incoming message %d" % packet_id) + else: + raise HBMQTTException("Something wrong, ack is False") if qos == 2: + # Send PUBREC pubrec = PubrecPacket.build(packet_id) yield from self.outgoing_queue.put(pubrec) - inflight_message.receive() - waiter = futures.Future(loop=self._loop) - self._pubrel_waiters[packet_id] = waiter - yield from waiter - inflight_message.pubrel = waiter.result() - del self._pubrel_waiters[packet_id] - inflight_message.release() - pubcomp = PubcompPacket.build(packet_id) - yield from self.outgoing_queue.put(pubcomp) - inflight_message.complete() - yield from self.delivered_message.put(inflight_message) - del self.session.incoming_msg[packet_id] + incoming_message.sent_pubrec() + # Wait for pubrel + ack = yield from incoming_message.wait_pubrel() + if ack: + # Initiate delivery + yield from self.session.delivered_message_queue.put(packet_id) + else: + raise HBMQTTException("Something wrong, ack is False") + ack = yield from incoming_message.wait_acknowledge() + if ack: + # Send PUBCOMP + pubcomp = PubcompPacket.build(packet_id) + yield from self.outgoing_queue.put(pubcomp) + incoming_message.sent_pubcomp() + #Discard message + del self.session.incoming_msg[packet_id] + self.logger.debug("Discarded incoming message %d" % packet_id) + else: + raise HBMQTTException("Something wrong, ack is False") diff --git a/hbmqtt/mqtt/protocol/inflight.py b/hbmqtt/mqtt/protocol/inflight.py new file mode 100644 index 0000000..557c400 --- /dev/null +++ b/hbmqtt/mqtt/protocol/inflight.py @@ -0,0 +1,192 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +import asyncio +from transitions import Machine, MachineError +from datetime import datetime +from hbmqtt.errors import HBMQTTException + + +class InFlightMessage: + states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed'] + + def __init__(self, packet, qos, ack_timeout=0, loop=None): + if loop is None: + self._loop = asyncio.get_event_loop() + else: + self._loop = loop + self.publish_packet = packet + self.qos = qos + self.publish_ts = None + self.puback_ts = None + self.pubrec_ts = None + self.pubrel_ts = None + self.pubcomp_ts = None + self.nb_retries = 0 + self._ack_waiter = asyncio.Future(loop=self._loop) + self._ack_timeout = ack_timeout + self._ack_timeout_handle = None + self._init_states() + + def _init_states(self): + self.machine = Machine(model=self, states=InFlightMessage.states, initial='new') + self.machine.add_transition(trigger='publish', source='new', dest='published') + self.machine.add_transition(trigger='publish', source='published', dest='published') + self.machine.add_transition(trigger='publish', source='received', dest='published') + self.machine.add_transition(trigger='publish', source='released', dest='published') + if self.qos == 0x01: + self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged') + if self.qos == 0x02: + self.machine.add_transition(trigger='receive', source='published', dest='received') + self.machine.add_transition(trigger='release', source='received', dest='released') + self.machine.add_transition(trigger='complete', source='released', dest='completed') + self.machine.add_transition(trigger='acknowledge', source='completed', dest='acknowledged') + + @asyncio.coroutine + def wait_acknowledge(self): + return (yield from self._ack_waiter) + + def start_ack_timeout(self): + def cb_timeout(): + self._ack_waiter.set_result(False) + if self._ack_timeout: + self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout) + + def cancel_ack_timeout(self): + if self._ack_timeout_handle: + self._ack_timeout_handle.cancel() + + def reset_ack_timeout(self): + self.cancel_ack_timeout() + self.start_ack_timeout() + + def cancel(self): + if self._ack_waiter and not self._ack_waiter.done(): + self._ack_waiter.cancel() + self.cancel_ack_timeout() + + +class OutgoingInFlightMessage(InFlightMessage): + def received_puback(self): + try: + self.acknowledge() + self.puback_ts = datetime.now() + self.cancel_ack_timeout() + self._ack_waiter.set_result(True) + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_puback on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def received_pubrec(self): + try: + self.receive() + self.pubrec_ts = datetime.now() + self.publish_packet = None # Discard message + self.reset_ack_timeout() + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_pubrec on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def received_pubcomp(self): + try: + self.complete() + self.pubcomp_ts = datetime.now() + self.cancel_ack_timeout() + self._ack_waiter.set_result(True) + self.acknowledge() + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_pubcomp on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def sent_pubrel(self): + try: + self.release() + self.pubrel_ts = datetime.now() + except MachineError: + raise HBMQTTException( + 'Invalid call to method sent_pubrel on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def retry_publish(self): + try: + self.publish() + self.nb_retries += 1 + self.publish_ts = datetime.now() + self.start_ack_timeout() + except MachineError: + raise HBMQTTException( + 'Invalid call to method retry_publish on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def sent_publish(self): + try: + self.publish() + self.publish_ts = datetime.now() + self.start_ack_timeout() + except MachineError: + raise HBMQTTException( + 'Invalid call to method sent_publish on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + +class IncomingInFlightMessage(InFlightMessage): + def __init__(self, packet, qos, ack_timeout=0, loop=None): + super().__init__(packet, qos, ack_timeout, loop) + self._pubrel_waiter = asyncio.Future(loop=self._loop) + + def received_publish(self): + try: + self.publish() + self.publish_ts = datetime.now() + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_publish on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def sent_pubrec(self): + try: + self.receive() + self.pubrec_ts = datetime.now() + except MachineError: + raise HBMQTTException( + 'Invalid call to method sent_pubrec on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def sent_pubcomp(self): + try: + self.complete() + self.pubcomp_ts = datetime.now() + except MachineError: + raise HBMQTTException( + 'Invalid call to method sent_pubrec on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + @asyncio.coroutine + def wait_pubrel(self): + return (yield from self._pubrel_waiter) + + def received_pubrel(self): + try: + self.release() + self.pubrel_ts = datetime.now() + self._pubrel_waiter.set_result(True) + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_pubcomp on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def acknowledge_delivery(self): + try: + self._ack_waiter.set_result(True) + except MachineError: + raise HBMQTTException( + 'Invalid call to method acknowledge_delivery on in-flight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def cancel(self): + super().cancel() + if self._pubrel_waiter and not self._pubrel_waiter.done(): + self._pubrel_waiter.cancel() diff --git a/hbmqtt/session.py b/hbmqtt/session.py index ba240a0..2f26158 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -31,10 +31,18 @@ class Session: self.parent = 0 self.handler = None + # Used to store outgoing InflightMessage while publish protocol flows self.outgoing_msg = dict() + + # Used to store incoming InflightMessage while publish protocol flows self.incoming_msg = dict() + + # Stores messages retained for this session self.retained_messages = Queue() + # Stores PUBLISH messages ID received in order and ready for application process + self.delivered_message_queue = Queue() + def _init_states(self): self.machine = Machine(states=Session.states, initial='new') self.machine.add_transition(trigger='connect', source='new', dest='connected') diff --git a/samples/client_subscribe.py b/samples/client_subscribe.py index 856795b..e2c5077 100644 --- a/samples/client_subscribe.py +++ b/samples/client_subscribe.py @@ -21,12 +21,13 @@ def uptime_coro(): # Subscribe to '$SYS/broker/uptime' with QOS=1 yield from C.subscribe([ {'filter': '$SYS/broker/uptime', 'qos': 0x01}, - {'filter': '$SYS/broker/load/#', 'qos': 0x00}, + {'filter': '$SYS/broker/load/#', 'qos': 0x02}, ]) logger.info("Subscribed") - for i in range(1, 10): - inflight = yield from C.deliver_message() - print(inflight.packet.payload.data) + for i in range(1, 100): + packet = yield from C.deliver_message() + print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data))) + yield from C.acknowledge_delivery(packet.variable_header.packet_id) yield from C.unsubscribe(['$SYS/broker/uptime']) logger.info("UnSubscribed") yield from C.disconnect()