diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 22e557c..6d690c5 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -25,58 +25,109 @@ from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.session import Session from hbmqtt.specs import * +from transitions import Machine, MachineError class InFlightMessage: + states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed'] + def __init__(self, packet, qos, ack_timeout=0, loop=None): if loop is None: self._loop = asyncio.get_event_loop() else: self._loop = loop - self.packet = packet + self.publish_packet = packet self.qos = qos self.publish_ts = None self.puback_ts = None self.pubrec_ts = None self.pubrel_ts = None self.pubcomp_ts = None + self.nb_retries = 0 self._ack_waiter = asyncio.Future(loop=self._loop) self._ack_timeout = ack_timeout self._ack_timeout_handle = None + self._init_states() + + def _init_states(self): + self.machine = Machine(model=self, states=InFlightMessage.states, initial='new') + self.machine.add_transition(trigger='publish', source='new', dest='published') + self.machine.add_transition(trigger='publish', source='published', dest='published') + self.machine.add_transition(trigger='publish', source='received', dest='published') + self.machine.add_transition(trigger='publish', source='released', dest='published') + if self.qos == 0x01: + self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged') + if self.qos == 0x02: + self.machine.add_transition(trigger='receive', source='published', dest='received') + self.machine.add_transition(trigger='release', source='received', dest='released') + self.machine.add_transition(trigger='complete', source='released', dest='completed') @asyncio.coroutine def wait_acknowledge(self): return (yield from self._ack_waiter) def received_puback(self): - if self.qos == QOS_1: + try: + self.acknowledge() self.puback_ts = datetime.now() self.cancel_ack_timeout() self._ack_waiter.set_result(True) - else: - raise HBMQTTException('Invalid call to method received_puback on inflight messages with QOS=%d' % self.qos) + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_puback on inflight messages with QOS=%d, state=%s' % + (self.qos, self.state)) def received_pubrec(self): - if self.qos == QOS_2: + try: + self.receive() self.pubrec_ts = datetime.now() - self.packet = None # Discard message + self.publish_packet = None # Discard message self.reset_ack_timeout() - else: - raise HBMQTTException('Invalid call to method received_pubrec on inflight messages with QOS=%d' % self.qos) + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_pubrec on inflight messages with QOS=%d, state=%s' % + (self.qos, self.state)) def received_pubcomp(self): - if self.qos == QOS_2: + try: + self.complete() self.pubcomp_ts = datetime.now() self.cancel_ack_timeout() self._ack_waiter.set_result(True) - else: - raise HBMQTTException('Invalid call to method received_pubcomp on inflight messages with QOS=%d' % self.qos) + except MachineError: + raise HBMQTTException( + 'Invalid call to method received_pubcomp on inflight messages with QOS=%d, state=%s' % + (self.qos, self.state)) def sent_pubrel(self): - if self.qos == QOS_2: + try: + self.release() self.pubrel_ts = datetime.now() - else: - raise HBMQTTException('Invalid call to method sent_pubrel on inflight messages with QOS=%d' % self.qos) + except MachineError: + raise HBMQTTException( + 'Invalid call to method sent_pubrel on inflight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def retry_publish(self): + try: + self.publish() + self.nb_retries += 1 + self.publish_ts = datetime.now() + self.start_ack_timeout() + except MachineError: + raise HBMQTTException( + 'Invalid call to method retry_publish on inflight messages with QOS=%d, state=%s' % + (self.qos, self.state)) + + def sent_publish(self): + try: + self.publish() + self.publish_ts = datetime.now() + self.start_ack_timeout() + except MachineError: + raise HBMQTTException( + 'Invalid call to method sent_publish on inflight messages with QOS=%d, state=%s' % + (self.qos, self.state)) def start_ack_timeout(self): def cb_timeout(): @@ -92,10 +143,6 @@ class InFlightMessage: self.cancel_ack_timeout() self.start_ack_timeout() - def sent_publish(self): - self.publish_ts = datetime.now() - self.start_ack_timeout() - class ProtocolHandler: """ @@ -142,6 +189,31 @@ class ProtocolHandler: yield from asyncio.wait( [self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) self.logger.debug("%s Handler tasks started" % self.session.client_id) + yield from self.retry_deliveries() + + @asyncio.coroutine + def retry_deliveries(self): + """ + Handle [MQTT-4.4.0-1] by resending PUBLISH and PUBREL messages for pending out messages + :return: + """ + self.logger.debug("Begin messages delivery retries") + for packet_id in self.session.inflight_out: + message = self.session.inflight_out[packet_id] + if message.is_new(): + self.logger.debug("Retrying publish message Id=%d", packet_id) + message.publish_packet.dup_flag = True + ack = False + while not ack: + yield from self.outgoing_queue.put(message.publish_packet) + message.retry_publish() + ack = yield from message.wait_acknowledge() + del self.session.inflight_out[packet_id] + if message.is_received(): + self.logger.debug("Retrying pubrel message Id=%d", packet_id) + yield from self.outgoing_queue.put(PubrelPacket.build(packet_id)) + message.sent_pubrel() + self.logger.debug("End messages delivery retries") @asyncio.coroutine def mqtt_publish(self, topic, message, qos, retain): @@ -158,9 +230,9 @@ class ProtocolHandler: while not ack: #Retry publish packet = PublishPacket.build(topic, message, packet_id, True, qos, retain) - inflight_message.packet = packet + inflight_message.publish_packet = packet yield from self.outgoing_queue.put(packet) - inflight_message.sent_publish() + inflight_message.retry_publish() ack = yield from inflight_message.wait_acknowledge() del self.session.inflight_out[packet_id] @@ -341,6 +413,7 @@ class ProtocolHandler: inflight_message = self.session.inflight_out[packet_id] inflight_message.received_pubrec() yield from self.outgoing_queue.put(PubrelPacket.build(packet_id)) + inflight_message.sent_pubrel() except KeyError as ke: self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id)