diff --git a/hbmqtt/client/_client.py b/hbmqtt/client/_client.py index eef7ae1..fca06fd 100644 --- a/hbmqtt/client/_client.py +++ b/hbmqtt/client/_client.py @@ -19,7 +19,6 @@ _defaults = { 'ping_delay': 1, 'default_qos': 0, 'default_retain': False, - 'inflight-polling-interval': 1, 'subscriptions-polling-interval': 1, } diff --git a/hbmqtt/mqtt/protocol.py b/hbmqtt/mqtt/protocol.py index 7e175da..e73cb76 100644 --- a/hbmqtt/mqtt/protocol.py +++ b/hbmqtt/mqtt/protocol.py @@ -27,8 +27,8 @@ from transitions import Machine, MachineError class InFlightMessage: states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed'] - def __init__(self, packet_id, qos): - self.packet_id = packet_id + def __init__(self, packet, qos): + self.packet = packet self.qos = qos self._init_states() @@ -58,11 +58,8 @@ class ProtocolHandler: self._loop = loop self._reader_task = None self._writer_task = None - self._inflight_task = None self._reader_ready = asyncio.Event(loop=self._loop) self._writer_ready = asyncio.Event(loop=self._loop) - self._inflight_ready = asyncio.Event(loop=self._loop) - self._inflight_changed = asyncio.Condition(loop=self._loop) self._running = False @@ -73,6 +70,10 @@ class ProtocolHandler: for p in PacketType: self.incoming_queues[p] = asyncio.Queue() self.outgoing_queue = asyncio.Queue() + self._puback_waiters = dict() + self._pubrec_waiters = dict() + self._pubrec_waiters = dict() + self._pubcomp_waiters = dict() self.inflight_messages = dict() @asyncio.coroutine @@ -80,66 +81,51 @@ class ProtocolHandler: self._running = True self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop) self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop) - self._inflight_task = asyncio.async(self._inflight_coro(), loop=self._loop) yield from asyncio.wait( - [self._reader_ready.wait(), self._writer_ready.wait(), self._inflight_ready.wait()], loop=self._loop) + [self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) self.logger.debug("Handler tasks started") @asyncio.coroutine def mqtt_publish(self, topic, message, packet_id, dup, qos, retain): - def qos_0_predicate(): - ret = False - try: - if self.inflight_messages.get(packet_id).state == 'published': - ret = True - #self.logger.debug("qos_0 predicate return %s" % ret) - return ret - except KeyError: - return False - - def qos_1_predicate(): - ret = False - try: - if self.inflight_messages.get(packet_id).state == 'acknowledged': - ret = True - #self.logger.debug("qos_1 predicate return %s" % ret) - return ret - except KeyError: - return False - - def qos_2_predicate(): - ret = False - try: - if self.inflight_messages.get(packet_id).state == 'completed': - ret = True - #self.logger.debug("qos_1 predicate return %s" % ret) - return ret - except KeyError: - return False - if packet_id in self.inflight_messages: self.logger.warn("A message with the same packet ID is already in flight") packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain) yield from self.outgoing_queue.put(packet) - inflight_message = InFlightMessage(packet.variable_header.packet_id, qos) - inflight_message.publish() + inflight_message = InFlightMessage(packet, qos) self.inflight_messages[packet.variable_header.packet_id] = inflight_message - yield from self._inflight_changed.acquire() - if qos == 0x00: - yield from self._inflight_changed.wait_for(qos_0_predicate) + inflight_message.publish() if qos == 0x01: - yield from self._inflight_changed.wait_for(qos_1_predicate) + waiter = futures.Future(loop=self._loop) + self._puback_waiters[packet_id] = waiter + yield from waiter + inflight_message.acknowledge() + del self._puback_waiters[packet_id] if qos == 0x02: - yield from self._inflight_changed.wait_for(qos_2_predicate) - self.inflight_messages.pop(packet.variable_header.packet_id) - self._inflight_changed.release() - return packet + # Wait for PUBREC + waiter = futures.Future(loop=self._loop) + self._pubrec_waiters[packet_id] = waiter + yield from waiter + del self._pubrec_waiters[packet_id] + inflight_message.receive() + + # Send pubrel + pubrel = PubrelPacket.build(packet_id) + yield from self.outgoing_queue.put(pubrel) + inflight_message.release() + + # Wait for pubcomp + waiter = futures.Future(loop=self._loop) + self._pubcomp_waiters[packet_id] = waiter + yield from waiter + del self._pubcomp_waiters[packet_id] + + del self.inflight_messages[packet_id] @asyncio.coroutine def stop(self): self._running = False self.session.reader.feed_eof() - yield from asyncio.wait([self._inflight_task, self._writer_task, self._reader_task], loop=self._loop) + yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop) @asyncio.coroutine def _reader_coro(self): @@ -159,6 +145,12 @@ class ProtocolHandler: yield from self.handle_suback(packet) if packet.fixed_header.packet_type == PacketType.UNSUBACK: yield from self.handle_unsuback(packet) + if packet.fixed_header.packet_type == PacketType.PUBACK: + yield from self.handle_puback(packet) + if packet.fixed_header.packet_type == PacketType.PUBREC: + yield from self.handle_pubrec(packet) + if packet.fixed_header.packet_type == PacketType.PUBCOMP: + yield from self.handle_pubcomp(packet) else: yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet) else: @@ -207,49 +199,6 @@ class ProtocolHandler: self.logger.warn("Unhandled exception in writer coro: %s" % e) self.logger.debug("Writer coro stopped") - @asyncio.coroutine - def _inflight_coro(self): - self.logger.debug("Starting in-flight messages polling coro") - while self._running: - self._inflight_ready.set() - yield from asyncio.sleep(self.config['inflight-polling-interval']) - self.logger.debug("in-flight polling coro wake-up") - try: - while not self.incoming_queues[PacketType.PUBACK].empty(): - packet = self.incoming_queues[PacketType.PUBACK].get_nowait() - packet_id = packet.variable_header.packet_id - inflight_message = self.inflight_messages.get(packet_id) - inflight_message.acknowledge() - self.logger.debug("Message with packet Id=%s acknowledged" % packet_id) - - while not self.incoming_queues[PacketType.PUBREC].empty(): - packet = self.incoming_queues[PacketType.PUBREC].get_nowait() - packet_id = packet.variable_header.packet_id - inflight_message = self.inflight_messages.get(packet_id) - inflight_message.receive() - self.logger.debug("Message with packet Id=%s received" % packet_id) - - rel_packet = PubrelPacket.build(packet_id) - yield from self.outgoing_queue.put(rel_packet) - inflight_message.release() - self.logger.debug("Message with packet Id=%s released" % packet_id) - - while not self.incoming_queues[PacketType.PUBCOMP].empty(): - packet = self.incoming_queues[PacketType.PUBCOMP].get_nowait() - packet_id = packet.variable_header.packet_id - inflight_message = self.inflight_messages.get(packet_id) - inflight_message.complete() - self.logger.debug("Message with packet Id=%s completed" % packet_id) - - yield from self._inflight_changed.acquire() - self._inflight_changed.notify_all() - self._inflight_changed.release() - except KeyError: - self.logger.warn("Received %s for unknown inflight message Id %d" % (packet.fixed_header.packet_type, packet_id)) - except MachineError as me: - self.logger.warn("Packet type incompatible with message QOS: %s" % me) - self.logger.debug("In-flight messages polling coro stopped") - @asyncio.coroutine def _receive_publish_coro(self): while self._running: @@ -287,6 +236,33 @@ class ProtocolHandler: def handle_unsuback(self, unsuback: UnsubackPacket): pass + @asyncio.coroutine + def handle_puback(self, puback: PubackPacket): + packet_id = puback.variable_header.packet_id + try: + waiter = self._puback_waiters[packet_id] + waiter.set_result(None) + except KeyError as ke: + self.logger.warn("Received PUBACK for unknown pending subscription with Id: %s" % packet_id) + + @asyncio.coroutine + def handle_pubrec(self, pubrec: PubrecPacket): + packet_id = pubrec.variable_header.packet_id + try: + waiter = self._pubrec_waiters[packet_id] + waiter.set_result(None) + except KeyError as ke: + self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id) + + @asyncio.coroutine + def handle_pubcomp(self, pubcomp: PubcompPacket): + packet_id = pubcomp.variable_header.packet_id + try: + waiter = self._pubcomp_waiters[packet_id] + waiter.set_result(None) + except KeyError as ke: + self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id) + class ClientProtocolHandler(ProtocolHandler): def __init__(self, session: Session, config, loop=None): diff --git a/samples/client_subscribe.py b/samples/client_subscribe.py index e9611b0..bdbb689 100644 --- a/samples/client_subscribe.py +++ b/samples/client_subscribe.py @@ -10,7 +10,7 @@ C = MQTTClient() def test_coro(): yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None) ret = yield from C.subscribe([ - {'filter': '$SYS/broker/uptime', 'qos': 0x00}, + {'filter': '$SYS/broker/uptime', 'qos': 0x01}, ]) logger.info("Subscribed") logger.info(repr(ret)) diff --git a/samples/test_client.py b/samples/test_client.py index 400d50b..997a75b 100644 --- a/samples/test_client.py +++ b/samples/test_client.py @@ -10,37 +10,12 @@ C = MQTTClient() def test_coro(): yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None) tasks = [ - asyncio.async(C.publish('a/b', b'0123456789')), - asyncio.async(C.publish('a/b', b'0', qos=0x01)), - asyncio.async(C.publish('a/b', b'1', qos=0x01)), - asyncio.async(C.publish('a/b', b'2', qos=0x01)), - asyncio.async(C.publish('a/b', b'3', qos=0x01)), - asyncio.async(C.publish('a/b', b'4', qos=0x01)), - asyncio.async(C.publish('a/b', b'5', qos=0x01)), - asyncio.async(C.publish('a/b', b'6', qos=0x01)), - asyncio.async(C.publish('a/b', b'7', qos=0x01)), - asyncio.async(C.publish('a/b', b'8', qos=0x01)), - asyncio.async(C.publish('a/b', b'9', qos=0x01)), - asyncio.async(C.publish('a/b', b'0', qos=0x02)), - asyncio.async(C.publish('a/b', b'1', qos=0x02)), - asyncio.async(C.publish('a/b', b'2', qos=0x02)), - asyncio.async(C.publish('a/b', b'3', qos=0x02)), - asyncio.async(C.publish('a/b', b'4', qos=0x02)), - asyncio.async(C.publish('a/b', b'5', qos=0x02)), - asyncio.async(C.publish('a/b', b'6', qos=0x02)), - asyncio.async(C.publish('a/b', b'7', qos=0x02)), - asyncio.async(C.publish('a/b', b'8', qos=0x02)), - asyncio.async(C.publish('a/b', b'9', qos=0x02)), + asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')), + asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)), + asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)), ] yield from asyncio.wait(tasks) logger.info("messages published") - yield from C.subscribe([ - {'filter': '$SYS/broker/connections/*', 'qos': 0x01}, - ]) - logger.info("Subscribed") - #yield from C.unsubscribe(['a/b', 'c/d']) - #logger.info("Unsubscribed") - yield from C.disconnect()