kopia lustrzana https://github.com/Yakifo/amqtt
Refactor mqtt_publish and inflight messages management
rodzic
832446acda
commit
d74cdf2665
|
@ -390,9 +390,8 @@ class Broker:
|
|||
(format_client_message(session=source_session),
|
||||
topic, format_client_message(session=target_session)))
|
||||
handler = subscription.session.handler
|
||||
packet_id = handler.session.next_packet_id
|
||||
publish_tasks.append(
|
||||
asyncio.Task(handler.mqtt_publish(topic, data, packet_id, False, qos, retain=False))
|
||||
asyncio.Task(handler.mqtt_publish(topic, data, qos, retain=False))
|
||||
)
|
||||
else:
|
||||
self.logger.debug("retaining application message from %s on topic '%s' to client '%s'" %
|
||||
|
@ -418,10 +417,9 @@ class Broker:
|
|||
publish_tasks = []
|
||||
while not session.retained_messages.empty():
|
||||
retained = yield from session.retained_messages.get()
|
||||
packet_id = session.next_packet_id
|
||||
publish_tasks.append(asyncio.Task(
|
||||
session.handler.mqtt_publish(
|
||||
retained.topic, retained.data, packet_id, False, retained.qos, True)))
|
||||
retained.topic, retained.data, False, retained.qos, True)))
|
||||
if len(publish_tasks) > 0:
|
||||
asyncio.wait(publish_tasks)
|
||||
|
||||
|
@ -435,10 +433,9 @@ class Broker:
|
|||
if self.matches(d_topic, subscription['filter']):
|
||||
self.logger.debug("%s and %s match" % (d_topic, subscription['filter']))
|
||||
retained = self._global_retained_messages[d_topic]
|
||||
packet_id = session.next_packet_id
|
||||
publish_tasks.append(asyncio.Task(
|
||||
session.handler.mqtt_publish(
|
||||
retained.topic, retained.data, packet_id, False, subscription['qos'], True)))
|
||||
retained.topic, retained.data, subscription['qos'], True)))
|
||||
if len(publish_tasks) > 0:
|
||||
asyncio.wait(publish_tasks)
|
||||
self.logger.debug("End broadcasting messages retained due to subscription on '%s' from %s" %
|
||||
|
|
|
@ -143,7 +143,7 @@ class MQTTClient:
|
|||
self._handler.mqtt_ping()
|
||||
|
||||
@asyncio.coroutine
|
||||
def publish(self, topic, message, dup=False, qos=None, retain=None):
|
||||
def publish(self, topic, message, qos=None, retain=None):
|
||||
def get_retain_and_qos():
|
||||
if qos:
|
||||
_qos = qos
|
||||
|
@ -164,11 +164,11 @@ class MQTTClient:
|
|||
return _qos, _retain
|
||||
(app_qos, app_retain) = get_retain_and_qos()
|
||||
if app_qos == 0:
|
||||
yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x00, app_retain)
|
||||
yield from self._handler.mqtt_publish(topic, message, 0x00, app_retain)
|
||||
if app_qos == 1:
|
||||
yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x01, app_retain)
|
||||
yield from self._handler.mqtt_publish(topic, message, 0x01, app_retain)
|
||||
if app_qos == 2:
|
||||
yield from self._handler.mqtt_publish(topic, message, self.session.next_packet_id, dup, 0x02, app_retain)
|
||||
yield from self._handler.mqtt_publish(topic, message, 0x02, app_retain)
|
||||
|
||||
@asyncio.coroutine
|
||||
def subscribe(self, topics):
|
||||
|
|
|
@ -3,10 +3,11 @@
|
|||
# See the file license.txt for copying permission.
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from asyncio import futures
|
||||
from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket
|
||||
from hbmqtt.mqtt import packet_class
|
||||
from hbmqtt.errors import NoDataException
|
||||
from hbmqtt.errors import NoDataException, HBMQTTException
|
||||
from hbmqtt.mqtt.packet import PacketType
|
||||
from hbmqtt.mqtt.connack import ConnackPacket
|
||||
from hbmqtt.mqtt.connect import ConnectPacket
|
||||
|
@ -23,30 +24,77 @@ from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
|
|||
from hbmqtt.mqtt.unsuback import UnsubackPacket
|
||||
from hbmqtt.mqtt.disconnect import DisconnectPacket
|
||||
from hbmqtt.session import Session
|
||||
from transitions import Machine
|
||||
from hbmqtt.specs import *
|
||||
|
||||
|
||||
class InFlightMessage:
|
||||
states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']
|
||||
|
||||
def __init__(self, packet, qos):
|
||||
def __init__(self, packet, qos, ack_timeout=0, loop=None):
|
||||
if loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self._loop = loop
|
||||
self.packet = packet
|
||||
self.qos = qos
|
||||
self.puback = None
|
||||
self.pubrec = None
|
||||
self.pubcomp = None
|
||||
self.pubrel = None
|
||||
self._init_states()
|
||||
self.publish_ts = None
|
||||
self.puback_ts = None
|
||||
self.pubrec_ts = None
|
||||
self.pubrel_ts = None
|
||||
self.pubcomp_ts = None
|
||||
self._ack_waiter = asyncio.Future(loop=self._loop)
|
||||
self._ack_timeout = ack_timeout
|
||||
self._ack_timeout_handle = None
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(model=self, states=InFlightMessage.states, initial='new')
|
||||
self.machine.add_transition(trigger='publish', source='new', dest='published')
|
||||
if self.qos == 0x01:
|
||||
self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged')
|
||||
if self.qos == 0x02:
|
||||
self.machine.add_transition(trigger='receive', source='published', dest='received')
|
||||
self.machine.add_transition(trigger='release', source='received', dest='released')
|
||||
self.machine.add_transition(trigger='complete', source='released', dest='completed')
|
||||
@asyncio.coroutine
|
||||
def wait_acknowledge(self):
|
||||
return (yield from self._ack_waiter)
|
||||
|
||||
def received_puback(self):
|
||||
if self.qos == QOS_1:
|
||||
self.puback_ts = datetime.now()
|
||||
self.cancel_ack_timeout()
|
||||
self._ack_waiter.set_result(True)
|
||||
else:
|
||||
raise HBMQTTException('Invalid call to method received_puback on inflight messages with QOS=%d' % self.qos)
|
||||
|
||||
def received_pubrec(self):
|
||||
if self.qos == QOS_2:
|
||||
self.pubrec_ts = datetime.now()
|
||||
self.packet = None # Discard message
|
||||
self.reset_ack_timeout()
|
||||
else:
|
||||
raise HBMQTTException('Invalid call to method received_pubrec on inflight messages with QOS=%d' % self.qos)
|
||||
|
||||
def received_pubcomp(self):
|
||||
if self.qos == QOS_2:
|
||||
self.pubcomp_ts = datetime.now()
|
||||
self.cancel_ack_timeout()
|
||||
self._ack_waiter.set_result(True)
|
||||
else:
|
||||
raise HBMQTTException('Invalid call to method received_pubcomp on inflight messages with QOS=%d' % self.qos)
|
||||
|
||||
def sent_pubrel(self):
|
||||
if self.qos == QOS_2:
|
||||
self.pubrel_ts = datetime.now()
|
||||
else:
|
||||
raise HBMQTTException('Invalid call to method sent_pubrel on inflight messages with QOS=%d' % self.qos)
|
||||
|
||||
def start_ack_timeout(self):
|
||||
def cb_timeout():
|
||||
self._ack_waiter.set_result(False)
|
||||
if self._ack_timeout:
|
||||
self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout)
|
||||
|
||||
def cancel_ack_timeout(self):
|
||||
if self._ack_timeout_handle:
|
||||
self._ack_timeout_handle.cancel()
|
||||
|
||||
def reset_ack_timeout(self):
|
||||
self.cancel_ack_timeout()
|
||||
self.start_ack_timeout()
|
||||
|
||||
def sent_publish(self):
|
||||
self.publish_ts = datetime.now()
|
||||
self.start_ack_timeout()
|
||||
|
||||
|
||||
class ProtocolHandler:
|
||||
|
@ -68,10 +116,6 @@ class ProtocolHandler:
|
|||
|
||||
self._running = False
|
||||
|
||||
self.incoming_queues = dict()
|
||||
self.application_messages = asyncio.Queue()
|
||||
for p in PacketType:
|
||||
self.incoming_queues[p] = asyncio.Queue()
|
||||
self.outgoing_queue = asyncio.Queue()
|
||||
self._puback_waiters = dict()
|
||||
self._pubrec_waiters = dict()
|
||||
|
@ -100,47 +144,25 @@ class ProtocolHandler:
|
|||
self.logger.debug("%s Handler tasks started" % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_publish(self, topic, message, packet_id, dup, qos, retain):
|
||||
def mqtt_publish(self, topic, message, qos, retain):
|
||||
packet_id = self.session.next_packet_id
|
||||
if packet_id in self.session.inflight_out:
|
||||
self.logger.warn("%s A message with the same packet ID is already in flight" % self.session.client_id)
|
||||
packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain)
|
||||
packet = PublishPacket.build(topic, message, packet_id, False, qos, retain)
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
inflight_message = InFlightMessage(packet, qos)
|
||||
self.session.inflight_out[packet.variable_header.packet_id] = inflight_message
|
||||
|
||||
inflight_message.publish()
|
||||
if qos == 0x01:
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._puback_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.puback = waiter.result()
|
||||
inflight_message.acknowledge()
|
||||
del self._puback_waiters[packet_id]
|
||||
if qos == 0x02:
|
||||
# Wait for PUBREC
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubrec_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.pubrec = waiter.result()
|
||||
del self._pubrec_waiters[packet_id]
|
||||
inflight_message.receive()
|
||||
|
||||
# Send pubrel
|
||||
pubrel = PubrelPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(pubrel)
|
||||
inflight_message.pubrel = pubrel
|
||||
inflight_message.release()
|
||||
|
||||
# Wait for pubcomp
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubcomp_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.pubcomp = waiter.result()
|
||||
del self._pubcomp_waiters[packet_id]
|
||||
inflight_message.complete()
|
||||
|
||||
del self.session.inflight_out[packet_id]
|
||||
return inflight_message
|
||||
if qos != QOS_0:
|
||||
inflight_message = InFlightMessage(packet, qos, loop=self._loop)
|
||||
inflight_message.sent_publish()
|
||||
self.session.inflight_out[packet_id] = inflight_message
|
||||
ack = yield from inflight_message.wait_acknowledge()
|
||||
while not ack:
|
||||
#Retry publish
|
||||
packet = PublishPacket.build(topic, message, packet_id, True, qos, retain)
|
||||
inflight_message.packet = packet
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
inflight_message.sent_publish()
|
||||
ack = yield from inflight_message.wait_acknowledge()
|
||||
del self.session.inflight_out[packet_id]
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
|
@ -306,8 +328,8 @@ class ProtocolHandler:
|
|||
def handle_puback(self, puback: PubackPacket):
|
||||
packet_id = puback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._puback_waiters[packet_id]
|
||||
waiter.set_result(puback)
|
||||
inflight_message = self.session.inflight_out[packet_id]
|
||||
inflight_message.received_puback()
|
||||
except KeyError as ke:
|
||||
self.logger.warn("%s Received PUBACK for unknown pending subscription with Id: %s" %
|
||||
(self.session.client_id, packet_id))
|
||||
|
@ -316,8 +338,9 @@ class ProtocolHandler:
|
|||
def handle_pubrec(self, pubrec: PubrecPacket):
|
||||
packet_id = pubrec.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubrec_waiters[packet_id]
|
||||
waiter.set_result(pubrec)
|
||||
inflight_message = self.session.inflight_out[packet_id]
|
||||
inflight_message.received_pubrec()
|
||||
yield from self.outgoing_queue.put(PubrelPacket.build(packet_id))
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
|
@ -325,8 +348,8 @@ class ProtocolHandler:
|
|||
def handle_pubcomp(self, pubcomp: PubcompPacket):
|
||||
packet_id = pubcomp.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubcomp_waiters[packet_id]
|
||||
waiter.set_result(pubcomp)
|
||||
inflight_message = self.session.inflight_out[packet_id]
|
||||
inflight_message.received_pubcomp()
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
__author__ = 'nico'
|
||||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class Qos(Enum):
|
||||
Qos_0 = 0x00
|
||||
Qos_1 = 0x01
|
||||
Qos_2 = 0x02
|
||||
QOS_0 = 0x00
|
||||
QOS_1 = 0x01
|
||||
QOS_2 = 0x02
|
||||
|
|
|
@ -38,5 +38,5 @@ def test_coro():
|
|||
|
||||
if __name__ == '__main__':
|
||||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
asyncio.get_event_loop().run_until_complete(test_coro())
|
Ładowanie…
Reference in New Issue