Refactor mqtt_publish and inflight messages management

pull/8/head
Nicolas Jouanin 2015-07-25 23:21:25 +02:00
rodzic 832446acda
commit d74cdf2665
5 zmienionych plików z 104 dodań i 85 usunięć

Wyświetl plik

@ -390,9 +390,8 @@ class Broker:
(format_client_message(session=source_session),
topic, format_client_message(session=target_session)))
handler = subscription.session.handler
packet_id = handler.session.next_packet_id
publish_tasks.append(
asyncio.Task(handler.mqtt_publish(topic, data, packet_id, False, qos, retain=False))
asyncio.Task(handler.mqtt_publish(topic, data, qos, retain=False))
)
else:
self.logger.debug("retaining application message from %s on topic '%s' to client '%s'" %
@ -418,10 +417,9 @@ class Broker:
publish_tasks = []
while not session.retained_messages.empty():
retained = yield from session.retained_messages.get()
packet_id = session.next_packet_id
publish_tasks.append(asyncio.Task(
session.handler.mqtt_publish(
retained.topic, retained.data, packet_id, False, retained.qos, True)))
retained.topic, retained.data, False, retained.qos, True)))
if len(publish_tasks) > 0:
asyncio.wait(publish_tasks)
@ -435,10 +433,9 @@ class Broker:
if self.matches(d_topic, subscription['filter']):
self.logger.debug("%s and %s match" % (d_topic, subscription['filter']))
retained = self._global_retained_messages[d_topic]
packet_id = session.next_packet_id
publish_tasks.append(asyncio.Task(
session.handler.mqtt_publish(
retained.topic, retained.data, packet_id, False, subscription['qos'], True)))
retained.topic, retained.data, subscription['qos'], True)))
if len(publish_tasks) > 0:
asyncio.wait(publish_tasks)
self.logger.debug("End broadcasting messages retained due to subscription on '%s' from %s" %

Wyświetl plik

@ -143,7 +143,7 @@ class MQTTClient:
self._handler.mqtt_ping()
@asyncio.coroutine
def publish(self, topic, message, dup=False, qos=None, retain=None):
def publish(self, topic, message, qos=None, retain=None):
def get_retain_and_qos():
if qos:
_qos = qos
@ -164,11 +164,11 @@ class MQTTClient:
return _qos, _retain
(app_qos, app_retain) = get_retain_and_qos()
if app_qos == 0:
yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x00, app_retain)
yield from self._handler.mqtt_publish(topic, message, 0x00, app_retain)
if app_qos == 1:
yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x01, app_retain)
yield from self._handler.mqtt_publish(topic, message, 0x01, app_retain)
if app_qos == 2:
yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x02, app_retain)
yield from self._handler.mqtt_publish(topic, message, 0x02, app_retain)
@asyncio.coroutine
def subscribe(self, topics):

Wyświetl plik

@ -3,10 +3,11 @@
# See the file license.txt for copying permission.
import logging
import asyncio
from datetime import datetime
from asyncio import futures
from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket
from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException
from hbmqtt.errors import NoDataException, HBMQTTException
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.connect import ConnectPacket
@ -23,30 +24,77 @@ from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.session import Session
from transitions import Machine
from hbmqtt.specs import *
class InFlightMessage:
states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']
def __init__(self, packet, qos):
def __init__(self, packet, qos, ack_timeout=0, loop=None):
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self.packet = packet
self.qos = qos
self.puback = None
self.pubrec = None
self.pubcomp = None
self.pubrel = None
self._init_states()
self.publish_ts = None
self.puback_ts = None
self.pubrec_ts = None
self.pubrel_ts = None
self.pubcomp_ts = None
self._ack_waiter = asyncio.Future(loop=self._loop)
self._ack_timeout = ack_timeout
self._ack_timeout_handle = None
def _init_states(self):
self.machine = Machine(model=self, states=InFlightMessage.states, initial='new')
self.machine.add_transition(trigger='publish', source='new', dest='published')
if self.qos == 0x01:
self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged')
if self.qos == 0x02:
self.machine.add_transition(trigger='receive', source='published', dest='received')
self.machine.add_transition(trigger='release', source='received', dest='released')
self.machine.add_transition(trigger='complete', source='released', dest='completed')
@asyncio.coroutine
def wait_acknowledge(self):
return (yield from self._ack_waiter)
def received_puback(self):
if self.qos == QOS_1:
self.puback_ts = datetime.now()
self.cancel_ack_timeout()
self._ack_waiter.set_result(True)
else:
raise HBMQTTException('Invalid call to method received_puback on inflight messages with QOS=%d' % self.qos)
def received_pubrec(self):
if self.qos == QOS_2:
self.pubrec_ts = datetime.now()
self.packet = None # Discard message
self.reset_ack_timeout()
else:
raise HBMQTTException('Invalid call to method received_pubrec on inflight messages with QOS=%d' % self.qos)
def received_pubcomp(self):
if self.qos == QOS_2:
self.pubcomp_ts = datetime.now()
self.cancel_ack_timeout()
self._ack_waiter.set_result(True)
else:
raise HBMQTTException('Invalid call to method received_pubcomp on inflight messages with QOS=%d' % self.qos)
def sent_pubrel(self):
if self.qos == QOS_2:
self.pubrel_ts = datetime.now()
else:
raise HBMQTTException('Invalid call to method sent_pubrel on inflight messages with QOS=%d' % self.qos)
def start_ack_timeout(self):
def cb_timeout():
self._ack_waiter.set_result(False)
if self._ack_timeout:
self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout)
def cancel_ack_timeout(self):
if self._ack_timeout_handle:
self._ack_timeout_handle.cancel()
def reset_ack_timeout(self):
self.cancel_ack_timeout()
self.start_ack_timeout()
def sent_publish(self):
self.publish_ts = datetime.now()
self.start_ack_timeout()
class ProtocolHandler:
@ -68,10 +116,6 @@ class ProtocolHandler:
self._running = False
self.incoming_queues = dict()
self.application_messages = asyncio.Queue()
for p in PacketType:
self.incoming_queues[p] = asyncio.Queue()
self.outgoing_queue = asyncio.Queue()
self._puback_waiters = dict()
self._pubrec_waiters = dict()
@ -100,47 +144,25 @@ class ProtocolHandler:
self.logger.debug("%s Handler tasks started" % self.session.client_id)
@asyncio.coroutine
def mqtt_publish(self, topic, message, packet_id, dup, qos, retain):
def mqtt_publish(self, topic, message, qos, retain):
packet_id = self.session.next_packet_id
if packet_id in self.session.inflight_out:
self.logger.warn("%s A message with the same packet ID is already in flight" % self.session.client_id)
packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain)
packet = PublishPacket.build(topic, message, packet_id, False, qos, retain)
yield from self.outgoing_queue.put(packet)
inflight_message = InFlightMessage(packet, qos)
self.session.inflight_out[packet.variable_header.packet_id] = inflight_message
inflight_message.publish()
if qos == 0x01:
waiter = futures.Future(loop=self._loop)
self._puback_waiters[packet_id] = waiter
yield from waiter
inflight_message.puback = waiter.result()
inflight_message.acknowledge()
del self._puback_waiters[packet_id]
if qos == 0x02:
# Wait for PUBREC
waiter = futures.Future(loop=self._loop)
self._pubrec_waiters[packet_id] = waiter
yield from waiter
inflight_message.pubrec = waiter.result()
del self._pubrec_waiters[packet_id]
inflight_message.receive()
# Send pubrel
pubrel = PubrelPacket.build(packet_id)
yield from self.outgoing_queue.put(pubrel)
inflight_message.pubrel = pubrel
inflight_message.release()
# Wait for pubcomp
waiter = futures.Future(loop=self._loop)
self._pubcomp_waiters[packet_id] = waiter
yield from waiter
inflight_message.pubcomp = waiter.result()
del self._pubcomp_waiters[packet_id]
inflight_message.complete()
del self.session.inflight_out[packet_id]
return inflight_message
if qos != QOS_0:
inflight_message = InFlightMessage(packet, qos, loop=self._loop)
inflight_message.sent_publish()
self.session.inflight_out[packet_id] = inflight_message
ack = yield from inflight_message.wait_acknowledge()
while not ack:
#Retry publish
packet = PublishPacket.build(topic, message, packet_id, True, qos, retain)
inflight_message.packet = packet
yield from self.outgoing_queue.put(packet)
inflight_message.sent_publish()
ack = yield from inflight_message.wait_acknowledge()
del self.session.inflight_out[packet_id]
@asyncio.coroutine
def stop(self):
@ -306,8 +328,8 @@ class ProtocolHandler:
def handle_puback(self, puback: PubackPacket):
packet_id = puback.variable_header.packet_id
try:
waiter = self._puback_waiters[packet_id]
waiter.set_result(puback)
inflight_message = self.session.inflight_out[packet_id]
inflight_message.received_puback()
except KeyError as ke:
self.logger.warn("%s Received PUBACK for unknown pending subscription with Id: %s" %
(self.session.client_id, packet_id))
@ -316,8 +338,9 @@ class ProtocolHandler:
def handle_pubrec(self, pubrec: PubrecPacket):
packet_id = pubrec.variable_header.packet_id
try:
waiter = self._pubrec_waiters[packet_id]
waiter.set_result(pubrec)
inflight_message = self.session.inflight_out[packet_id]
inflight_message.received_pubrec()
yield from self.outgoing_queue.put(PubrelPacket.build(packet_id))
except KeyError as ke:
self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id)
@ -325,8 +348,8 @@ class ProtocolHandler:
def handle_pubcomp(self, pubcomp: PubcompPacket):
packet_id = pubcomp.variable_header.packet_id
try:
waiter = self._pubcomp_waiters[packet_id]
waiter.set_result(pubcomp)
inflight_message = self.session.inflight_out[packet_id]
inflight_message.received_pubcomp()
except KeyError as ke:
self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id)

Wyświetl plik

@ -1,8 +1,7 @@
__author__ = 'nico'
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
from enum import Enum
class Qos(Enum):
Qos_0 = 0x00
Qos_1 = 0x01
Qos_2 = 0x02
QOS_0 = 0x00
QOS_1 = 0x01
QOS_2 = 0x02

Wyświetl plik

@ -38,5 +38,5 @@ def test_coro():
if __name__ == '__main__':
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
logging.basicConfig(level=logging.DEBUG, format=formatter)
asyncio.get_event_loop().run_until_complete(test_coro())