Refactor incoming / outgoing message management

pull/8/head
Nicolas Jouanin 2015-07-26 21:21:35 +02:00
rodzic 238069e5d8
commit 0bbba69ffb
6 zmienionych plików z 284 dodań i 160 usunięć

Wyświetl plik

@ -271,12 +271,15 @@ class Broker:
self.logger.debug(repr(self._subscriptions)) self.logger.debug(repr(self._subscriptions))
if wait_deliver in done: if wait_deliver in done:
self.logger.debug("%s handling message delivery" % client_session.client_id) self.logger.debug("%s handling message delivery" % client_session.client_id)
publish_packet = wait_deliver.result().publish_packet publish_packet = wait_deliver.result()
packet_id = publish_packet.variable_header.packet_id
topic_name = publish_packet.variable_header.topic_name topic_name = publish_packet.variable_header.topic_name
data = publish_packet.payload.data data = publish_packet.payload.data
yield from self.broadcast_application_message(client_session, topic_name, data) yield from self.broadcast_application_message(client_session, topic_name, data)
if publish_packet.retain_flag: if publish_packet.retain_flag:
self.retain_message(client_session, topic_name, data) self.retain_message(client_session, topic_name, data)
# Acknowledge message delivery
yield from handler.mqtt_acknowledge_delivery(packet_id)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message()) wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message())
wait_subscription.cancel() wait_subscription.cancel()
wait_unsubscription.cancel() wait_unsubscription.cancel()

Wyświetl plik

@ -182,6 +182,10 @@ class MQTTClient:
def deliver_message(self): def deliver_message(self):
return (yield from self._handler.mqtt_deliver_next_message()) return (yield from self._handler.mqtt_deliver_next_message())
@asyncio.coroutine
def acknowledge_delivery(self, packet_id):
yield from self._handler.mqtt_acknowledge_delivery(packet_id)
@asyncio.coroutine @asyncio.coroutine
def _connect_coro(self): def _connect_coro(self):
try: try:

Wyświetl plik

@ -4,7 +4,6 @@
import logging import logging
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from asyncio import futures
from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket
from hbmqtt.mqtt import packet_class from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException, HBMQTTException from hbmqtt.errors import NoDataException, HBMQTTException
@ -25,131 +24,7 @@ from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.session import Session from hbmqtt.session import Session
from hbmqtt.specs import * from hbmqtt.specs import *
from transitions import Machine, MachineError from hbmqtt.mqtt.protocol.inflight import *
class InFlightMessage:
states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']
def __init__(self, packet, qos, ack_timeout=0, loop=None):
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self.publish_packet = packet
self.qos = qos
self.publish_ts = None
self.puback_ts = None
self.pubrec_ts = None
self.pubrel_ts = None
self.pubcomp_ts = None
self.nb_retries = 0
self._ack_waiter = asyncio.Future(loop=self._loop)
self._ack_timeout = ack_timeout
self._ack_timeout_handle = None
self._init_states()
def _init_states(self):
self.machine = Machine(model=self, states=InFlightMessage.states, initial='new')
self.machine.add_transition(trigger='publish', source='new', dest='published')
self.machine.add_transition(trigger='publish', source='published', dest='published')
self.machine.add_transition(trigger='publish', source='received', dest='published')
self.machine.add_transition(trigger='publish', source='released', dest='published')
if self.qos == 0x01:
self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged')
if self.qos == 0x02:
self.machine.add_transition(trigger='receive', source='published', dest='received')
self.machine.add_transition(trigger='release', source='received', dest='released')
self.machine.add_transition(trigger='complete', source='released', dest='completed')
@asyncio.coroutine
def wait_acknowledge(self):
return (yield from self._ack_waiter)
def received_puback(self):
try:
self.acknowledge()
self.puback_ts = datetime.now()
self.cancel_ack_timeout()
self._ack_waiter.set_result(True)
except MachineError:
raise HBMQTTException(
'Invalid call to method received_puback on inflight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def received_pubrec(self):
try:
self.receive()
self.pubrec_ts = datetime.now()
self.publish_packet = None # Discard message
self.reset_ack_timeout()
except MachineError:
raise HBMQTTException(
'Invalid call to method received_pubrec on inflight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def received_pubcomp(self):
try:
self.complete()
self.pubcomp_ts = datetime.now()
self.cancel_ack_timeout()
self._ack_waiter.set_result(True)
except MachineError:
raise HBMQTTException(
'Invalid call to method received_pubcomp on inflight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def sent_pubrel(self):
try:
self.release()
self.pubrel_ts = datetime.now()
except MachineError:
raise HBMQTTException(
'Invalid call to method sent_pubrel on inflight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def retry_publish(self):
try:
self.publish()
self.nb_retries += 1
self.publish_ts = datetime.now()
self.start_ack_timeout()
except MachineError:
raise HBMQTTException(
'Invalid call to method retry_publish on inflight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def sent_publish(self):
try:
self.publish()
self.publish_ts = datetime.now()
self.start_ack_timeout()
except MachineError:
raise HBMQTTException(
'Invalid call to method sent_publish on inflight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def start_ack_timeout(self):
def cb_timeout():
self._ack_waiter.set_result(False)
if self._ack_timeout:
self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout)
def cancel_ack_timeout(self):
if self._ack_timeout_handle:
self._ack_timeout_handle.cancel()
def reset_ack_timeout(self):
self.cancel_ack_timeout()
self.start_ack_timeout()
class IncomingInFlightMessage(InFlightMessage):
pass
class OutgoingInFlightMessage(InFlightMessage):
pass
class ProtocolHandler: class ProtocolHandler:
@ -173,7 +48,6 @@ class ProtocolHandler:
self.outgoing_queue = asyncio.Queue() self.outgoing_queue = asyncio.Queue()
self._pubrel_waiters = dict() self._pubrel_waiters = dict()
self.delivered_message = asyncio.Queue()
def attach_to_session(self, session: Session): def attach_to_session(self, session: Session):
self.session = session self.session = session
@ -247,6 +121,11 @@ class ProtocolHandler:
self.session.reader.feed_eof() self.session.reader.feed_eof()
yield from self.outgoing_queue.put("STOP") yield from self.outgoing_queue.put("STOP")
yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop) yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
# Stop incoming messages flow waiter
for packet_id in self.session.incoming_msg:
self.session.incoming_msg[packet_id].cancel()
for packet_id in self.session.outgoing_msg:
self.session.outgoing_msg[packet_id].cancel()
@asyncio.coroutine @asyncio.coroutine
def _reader_coro(self): def _reader_coro(self):
@ -352,8 +231,21 @@ class ProtocolHandler:
@asyncio.coroutine @asyncio.coroutine
def mqtt_deliver_next_message(self): def mqtt_deliver_next_message(self):
inflight_message = yield from self.delivered_message.get() packet_id = yield from self.session.delivered_message_queue.get()
return inflight_message message = self.session.incoming_msg[packet_id]
if message.qos == QOS_0:
del self.session.incoming_msg[packet_id]
self.logger.debug("Discarded incoming message %s" % packet_id)
return message.publish_packet
@asyncio.coroutine
def mqtt_acknowledge_delivery(self, packet_id):
try:
message = self.session.incoming_msg[packet_id]
message.acknowledge_delivery()
self.logger.debug('Message delivery acknowledged, packed_id=%d' % packet_id)
except KeyError:
pass
def handle_write_timeout(self): def handle_write_timeout(self):
self.logger.warn('%s write timeout unhandled' % self.session.client_id) self.logger.warn('%s write timeout unhandled' % self.session.client_id)
@ -435,44 +327,68 @@ class ProtocolHandler:
def handle_pubrel(self, pubrel: PubrecPacket): def handle_pubrel(self, pubrel: PubrecPacket):
packet_id = pubrel.variable_header.packet_id packet_id = pubrel.variable_header.packet_id
try: try:
waiter = self._pubrel_waiters[packet_id] inflight_message = self.session.incoming_msg[packet_id]
waiter.set_result(pubrel) inflight_message.received_pubrel()
except KeyError as ke: except KeyError as ke:
self.logger.warn("Received PUBREL for unknown pending subscription with Id: %s" % packet_id) self.logger.warn("Received PUBREL for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine @asyncio.coroutine
def handle_publish(self, publish_packet: PublishPacket): def handle_publish(self, publish_packet: PublishPacket):
inflight_message = None incoming_message = None
packet_id = publish_packet.variable_header.packet_id packet_id = publish_packet.variable_header.packet_id
qos = publish_packet.qos qos = publish_packet.qos
if qos == 0: if qos == 0:
inflight_message = IncomingInFlightMessage(publish_packet, qos) if publish_packet.dup_flag:
yield from self.delivered_message.put(inflight_message) self.logger.warn("[MQTT-3.3.1-2] DUP flag must set to 0 for QOS 0 message. Message ignored: %s" %
else: repr(publish_packet))
if packet_id in self.session.incoming_msg:
inflight_message = self.session.incoming_msg[packet_id]
else: else:
inflight_message = InFlightMessage(publish_packet, qos) incoming_message = IncomingInFlightMessage(publish_packet, qos)
self.session.incoming_msg[packet_id] = inflight_message incoming_message.received_publish()
inflight_message.publish() self.session.incoming_msg[packet_id] = incoming_message
yield from self.session.delivered_message_queue.put(packet_id)
else:
# Check if publish is a retry
if packet_id in self.session.incoming_msg:
incoming_message = self.session.incoming_msg[packet_id]
else:
incoming_message = IncomingInFlightMessage(publish_packet, qos)
self.session.incoming_msg[packet_id] = incoming_message
incoming_message.publish()
if qos == 1: if qos == 1:
puback = PubackPacket.build(packet_id) # Initiate delivery
yield from self.outgoing_queue.put(puback) yield from self.session.delivered_message_queue.put(packet_id)
inflight_message.acknowledge() ack = yield from incoming_message.wait_acknowledge()
if ack:
# Send PUBACK
puback = PubackPacket.build(packet_id)
yield from self.outgoing_queue.put(puback)
#Discard message
del self.session.incoming_msg[packet_id]
self.logger.debug("Discarded incoming message %d" % packet_id)
else:
raise HBMQTTException("Something wrong, ack is False")
if qos == 2: if qos == 2:
# Send PUBREC
pubrec = PubrecPacket.build(packet_id) pubrec = PubrecPacket.build(packet_id)
yield from self.outgoing_queue.put(pubrec) yield from self.outgoing_queue.put(pubrec)
inflight_message.receive() incoming_message.sent_pubrec()
waiter = futures.Future(loop=self._loop) # Wait for pubrel
self._pubrel_waiters[packet_id] = waiter ack = yield from incoming_message.wait_pubrel()
yield from waiter if ack:
inflight_message.pubrel = waiter.result() # Initiate delivery
del self._pubrel_waiters[packet_id] yield from self.session.delivered_message_queue.put(packet_id)
inflight_message.release() else:
pubcomp = PubcompPacket.build(packet_id) raise HBMQTTException("Something wrong, ack is False")
yield from self.outgoing_queue.put(pubcomp) ack = yield from incoming_message.wait_acknowledge()
inflight_message.complete() if ack:
yield from self.delivered_message.put(inflight_message) # Send PUBCOMP
del self.session.incoming_msg[packet_id] pubcomp = PubcompPacket.build(packet_id)
yield from self.outgoing_queue.put(pubcomp)
incoming_message.sent_pubcomp()
#Discard message
del self.session.incoming_msg[packet_id]
self.logger.debug("Discarded incoming message %d" % packet_id)
else:
raise HBMQTTException("Something wrong, ack is False")

Wyświetl plik

@ -0,0 +1,192 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import asyncio
from transitions import Machine, MachineError
from datetime import datetime
from hbmqtt.errors import HBMQTTException
class InFlightMessage:
states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']
def __init__(self, packet, qos, ack_timeout=0, loop=None):
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self.publish_packet = packet
self.qos = qos
self.publish_ts = None
self.puback_ts = None
self.pubrec_ts = None
self.pubrel_ts = None
self.pubcomp_ts = None
self.nb_retries = 0
self._ack_waiter = asyncio.Future(loop=self._loop)
self._ack_timeout = ack_timeout
self._ack_timeout_handle = None
self._init_states()
def _init_states(self):
self.machine = Machine(model=self, states=InFlightMessage.states, initial='new')
self.machine.add_transition(trigger='publish', source='new', dest='published')
self.machine.add_transition(trigger='publish', source='published', dest='published')
self.machine.add_transition(trigger='publish', source='received', dest='published')
self.machine.add_transition(trigger='publish', source='released', dest='published')
if self.qos == 0x01:
self.machine.add_transition(trigger='acknowledge', source='published', dest='acknowledged')
if self.qos == 0x02:
self.machine.add_transition(trigger='receive', source='published', dest='received')
self.machine.add_transition(trigger='release', source='received', dest='released')
self.machine.add_transition(trigger='complete', source='released', dest='completed')
self.machine.add_transition(trigger='acknowledge', source='completed', dest='acknowledged')
@asyncio.coroutine
def wait_acknowledge(self):
return (yield from self._ack_waiter)
def start_ack_timeout(self):
def cb_timeout():
self._ack_waiter.set_result(False)
if self._ack_timeout:
self._ack_timeout_handle = self._loop.call_later(self._ack_timeout, cb_timeout)
def cancel_ack_timeout(self):
if self._ack_timeout_handle:
self._ack_timeout_handle.cancel()
def reset_ack_timeout(self):
self.cancel_ack_timeout()
self.start_ack_timeout()
def cancel(self):
if self._ack_waiter and not self._ack_waiter.done():
self._ack_waiter.cancel()
self.cancel_ack_timeout()
class OutgoingInFlightMessage(InFlightMessage):
def received_puback(self):
try:
self.acknowledge()
self.puback_ts = datetime.now()
self.cancel_ack_timeout()
self._ack_waiter.set_result(True)
except MachineError:
raise HBMQTTException(
'Invalid call to method received_puback on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def received_pubrec(self):
try:
self.receive()
self.pubrec_ts = datetime.now()
self.publish_packet = None # Discard message
self.reset_ack_timeout()
except MachineError:
raise HBMQTTException(
'Invalid call to method received_pubrec on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def received_pubcomp(self):
try:
self.complete()
self.pubcomp_ts = datetime.now()
self.cancel_ack_timeout()
self._ack_waiter.set_result(True)
self.acknowledge()
except MachineError:
raise HBMQTTException(
'Invalid call to method received_pubcomp on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def sent_pubrel(self):
try:
self.release()
self.pubrel_ts = datetime.now()
except MachineError:
raise HBMQTTException(
'Invalid call to method sent_pubrel on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def retry_publish(self):
try:
self.publish()
self.nb_retries += 1
self.publish_ts = datetime.now()
self.start_ack_timeout()
except MachineError:
raise HBMQTTException(
'Invalid call to method retry_publish on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def sent_publish(self):
try:
self.publish()
self.publish_ts = datetime.now()
self.start_ack_timeout()
except MachineError:
raise HBMQTTException(
'Invalid call to method sent_publish on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
class IncomingInFlightMessage(InFlightMessage):
def __init__(self, packet, qos, ack_timeout=0, loop=None):
super().__init__(packet, qos, ack_timeout, loop)
self._pubrel_waiter = asyncio.Future(loop=self._loop)
def received_publish(self):
try:
self.publish()
self.publish_ts = datetime.now()
except MachineError:
raise HBMQTTException(
'Invalid call to method received_publish on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def sent_pubrec(self):
try:
self.receive()
self.pubrec_ts = datetime.now()
except MachineError:
raise HBMQTTException(
'Invalid call to method sent_pubrec on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def sent_pubcomp(self):
try:
self.complete()
self.pubcomp_ts = datetime.now()
except MachineError:
raise HBMQTTException(
'Invalid call to method sent_pubrec on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
@asyncio.coroutine
def wait_pubrel(self):
return (yield from self._pubrel_waiter)
def received_pubrel(self):
try:
self.release()
self.pubrel_ts = datetime.now()
self._pubrel_waiter.set_result(True)
except MachineError:
raise HBMQTTException(
'Invalid call to method received_pubcomp on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def acknowledge_delivery(self):
try:
self._ack_waiter.set_result(True)
except MachineError:
raise HBMQTTException(
'Invalid call to method acknowledge_delivery on in-flight messages with QOS=%d, state=%s' %
(self.qos, self.state))
def cancel(self):
super().cancel()
if self._pubrel_waiter and not self._pubrel_waiter.done():
self._pubrel_waiter.cancel()

Wyświetl plik

@ -31,10 +31,18 @@ class Session:
self.parent = 0 self.parent = 0
self.handler = None self.handler = None
# Used to store outgoing InflightMessage while publish protocol flows
self.outgoing_msg = dict() self.outgoing_msg = dict()
# Used to store incoming InflightMessage while publish protocol flows
self.incoming_msg = dict() self.incoming_msg = dict()
# Stores messages retained for this session
self.retained_messages = Queue() self.retained_messages = Queue()
# Stores PUBLISH messages ID received in order and ready for application process
self.delivered_message_queue = Queue()
def _init_states(self): def _init_states(self):
self.machine = Machine(states=Session.states, initial='new') self.machine = Machine(states=Session.states, initial='new')
self.machine.add_transition(trigger='connect', source='new', dest='connected') self.machine.add_transition(trigger='connect', source='new', dest='connected')

Wyświetl plik

@ -21,12 +21,13 @@ def uptime_coro():
# Subscribe to '$SYS/broker/uptime' with QOS=1 # Subscribe to '$SYS/broker/uptime' with QOS=1
yield from C.subscribe([ yield from C.subscribe([
{'filter': '$SYS/broker/uptime', 'qos': 0x01}, {'filter': '$SYS/broker/uptime', 'qos': 0x01},
{'filter': '$SYS/broker/load/#', 'qos': 0x00}, {'filter': '$SYS/broker/load/#', 'qos': 0x02},
]) ])
logger.info("Subscribed") logger.info("Subscribed")
for i in range(1, 10): for i in range(1, 100):
inflight = yield from C.deliver_message() packet = yield from C.deliver_message()
print(inflight.packet.payload.data) print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data)))
yield from C.acknowledge_delivery(packet.variable_header.packet_id)
yield from C.unsubscribe(['$SYS/broker/uptime']) yield from C.unsubscribe(['$SYS/broker/uptime'])
logger.info("UnSubscribed") logger.info("UnSubscribed")
yield from C.disconnect() yield from C.disconnect()