kopia lustrzana https://github.com/Yakifo/amqtt
Manage publish sending message flows using futures
rodzic
5eacd2959d
commit
9b81cf2277
|
@ -19,7 +19,6 @@ _defaults = {
|
|||
'ping_delay': 1,
|
||||
'default_qos': 0,
|
||||
'default_retain': False,
|
||||
'inflight-polling-interval': 1,
|
||||
'subscriptions-polling-interval': 1,
|
||||
}
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ from transitions import Machine, MachineError
|
|||
class InFlightMessage:
|
||||
states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']
|
||||
|
||||
def __init__(self, packet_id, qos):
|
||||
self.packet_id = packet_id
|
||||
def __init__(self, packet, qos):
|
||||
self.packet = packet
|
||||
self.qos = qos
|
||||
self._init_states()
|
||||
|
||||
|
@ -58,11 +58,8 @@ class ProtocolHandler:
|
|||
self._loop = loop
|
||||
self._reader_task = None
|
||||
self._writer_task = None
|
||||
self._inflight_task = None
|
||||
self._reader_ready = asyncio.Event(loop=self._loop)
|
||||
self._writer_ready = asyncio.Event(loop=self._loop)
|
||||
self._inflight_ready = asyncio.Event(loop=self._loop)
|
||||
self._inflight_changed = asyncio.Condition(loop=self._loop)
|
||||
|
||||
self._running = False
|
||||
|
||||
|
@ -73,6 +70,10 @@ class ProtocolHandler:
|
|||
for p in PacketType:
|
||||
self.incoming_queues[p] = asyncio.Queue()
|
||||
self.outgoing_queue = asyncio.Queue()
|
||||
self._puback_waiters = dict()
|
||||
self._pubrec_waiters = dict()
|
||||
self._pubrec_waiters = dict()
|
||||
self._pubcomp_waiters = dict()
|
||||
self.inflight_messages = dict()
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -80,66 +81,51 @@ class ProtocolHandler:
|
|||
self._running = True
|
||||
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
|
||||
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
|
||||
self._inflight_task = asyncio.async(self._inflight_coro(), loop=self._loop)
|
||||
yield from asyncio.wait(
|
||||
[self._reader_ready.wait(), self._writer_ready.wait(), self._inflight_ready.wait()], loop=self._loop)
|
||||
[self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
|
||||
self.logger.debug("Handler tasks started")
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_publish(self, topic, message, packet_id, dup, qos, retain):
|
||||
def qos_0_predicate():
|
||||
ret = False
|
||||
try:
|
||||
if self.inflight_messages.get(packet_id).state == 'published':
|
||||
ret = True
|
||||
#self.logger.debug("qos_0 predicate return %s" % ret)
|
||||
return ret
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def qos_1_predicate():
|
||||
ret = False
|
||||
try:
|
||||
if self.inflight_messages.get(packet_id).state == 'acknowledged':
|
||||
ret = True
|
||||
#self.logger.debug("qos_1 predicate return %s" % ret)
|
||||
return ret
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def qos_2_predicate():
|
||||
ret = False
|
||||
try:
|
||||
if self.inflight_messages.get(packet_id).state == 'completed':
|
||||
ret = True
|
||||
#self.logger.debug("qos_1 predicate return %s" % ret)
|
||||
return ret
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
if packet_id in self.inflight_messages:
|
||||
self.logger.warn("A message with the same packet ID is already in flight")
|
||||
packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain)
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
inflight_message = InFlightMessage(packet.variable_header.packet_id, qos)
|
||||
inflight_message.publish()
|
||||
inflight_message = InFlightMessage(packet, qos)
|
||||
self.inflight_messages[packet.variable_header.packet_id] = inflight_message
|
||||
yield from self._inflight_changed.acquire()
|
||||
if qos == 0x00:
|
||||
yield from self._inflight_changed.wait_for(qos_0_predicate)
|
||||
inflight_message.publish()
|
||||
if qos == 0x01:
|
||||
yield from self._inflight_changed.wait_for(qos_1_predicate)
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._puback_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.acknowledge()
|
||||
del self._puback_waiters[packet_id]
|
||||
if qos == 0x02:
|
||||
yield from self._inflight_changed.wait_for(qos_2_predicate)
|
||||
self.inflight_messages.pop(packet.variable_header.packet_id)
|
||||
self._inflight_changed.release()
|
||||
return packet
|
||||
# Wait for PUBREC
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubrec_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
del self._pubrec_waiters[packet_id]
|
||||
inflight_message.receive()
|
||||
|
||||
# Send pubrel
|
||||
pubrel = PubrelPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(pubrel)
|
||||
inflight_message.release()
|
||||
|
||||
# Wait for pubcomp
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubcomp_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
del self._pubcomp_waiters[packet_id]
|
||||
|
||||
del self.inflight_messages[packet_id]
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
self._running = False
|
||||
self.session.reader.feed_eof()
|
||||
yield from asyncio.wait([self._inflight_task, self._writer_task, self._reader_task], loop=self._loop)
|
||||
yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _reader_coro(self):
|
||||
|
@ -159,6 +145,12 @@ class ProtocolHandler:
|
|||
yield from self.handle_suback(packet)
|
||||
if packet.fixed_header.packet_type == PacketType.UNSUBACK:
|
||||
yield from self.handle_unsuback(packet)
|
||||
if packet.fixed_header.packet_type == PacketType.PUBACK:
|
||||
yield from self.handle_puback(packet)
|
||||
if packet.fixed_header.packet_type == PacketType.PUBREC:
|
||||
yield from self.handle_pubrec(packet)
|
||||
if packet.fixed_header.packet_type == PacketType.PUBCOMP:
|
||||
yield from self.handle_pubcomp(packet)
|
||||
else:
|
||||
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
||||
else:
|
||||
|
@ -207,49 +199,6 @@ class ProtocolHandler:
|
|||
self.logger.warn("Unhandled exception in writer coro: %s" % e)
|
||||
self.logger.debug("Writer coro stopped")
|
||||
|
||||
@asyncio.coroutine
|
||||
def _inflight_coro(self):
|
||||
self.logger.debug("Starting in-flight messages polling coro")
|
||||
while self._running:
|
||||
self._inflight_ready.set()
|
||||
yield from asyncio.sleep(self.config['inflight-polling-interval'])
|
||||
self.logger.debug("in-flight polling coro wake-up")
|
||||
try:
|
||||
while not self.incoming_queues[PacketType.PUBACK].empty():
|
||||
packet = self.incoming_queues[PacketType.PUBACK].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
inflight_message = self.inflight_messages.get(packet_id)
|
||||
inflight_message.acknowledge()
|
||||
self.logger.debug("Message with packet Id=%s acknowledged" % packet_id)
|
||||
|
||||
while not self.incoming_queues[PacketType.PUBREC].empty():
|
||||
packet = self.incoming_queues[PacketType.PUBREC].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
inflight_message = self.inflight_messages.get(packet_id)
|
||||
inflight_message.receive()
|
||||
self.logger.debug("Message with packet Id=%s received" % packet_id)
|
||||
|
||||
rel_packet = PubrelPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(rel_packet)
|
||||
inflight_message.release()
|
||||
self.logger.debug("Message with packet Id=%s released" % packet_id)
|
||||
|
||||
while not self.incoming_queues[PacketType.PUBCOMP].empty():
|
||||
packet = self.incoming_queues[PacketType.PUBCOMP].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
inflight_message = self.inflight_messages.get(packet_id)
|
||||
inflight_message.complete()
|
||||
self.logger.debug("Message with packet Id=%s completed" % packet_id)
|
||||
|
||||
yield from self._inflight_changed.acquire()
|
||||
self._inflight_changed.notify_all()
|
||||
self._inflight_changed.release()
|
||||
except KeyError:
|
||||
self.logger.warn("Received %s for unknown inflight message Id %d" % (packet.fixed_header.packet_type, packet_id))
|
||||
except MachineError as me:
|
||||
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
|
||||
self.logger.debug("In-flight messages polling coro stopped")
|
||||
|
||||
@asyncio.coroutine
|
||||
def _receive_publish_coro(self):
|
||||
while self._running:
|
||||
|
@ -287,6 +236,33 @@ class ProtocolHandler:
|
|||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_puback(self, puback: PubackPacket):
|
||||
packet_id = puback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._puback_waiters[packet_id]
|
||||
waiter.set_result(None)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pubrec(self, pubrec: PubrecPacket):
|
||||
packet_id = pubrec.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubrec_waiters[packet_id]
|
||||
waiter.set_result(None)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pubcomp(self, pubcomp: PubcompPacket):
|
||||
packet_id = pubcomp.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubcomp_waiters[packet_id]
|
||||
waiter.set_result(None)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler):
|
||||
def __init__(self, session: Session, config, loop=None):
|
||||
|
|
|
@ -10,7 +10,7 @@ C = MQTTClient()
|
|||
def test_coro():
|
||||
yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None)
|
||||
ret = yield from C.subscribe([
|
||||
{'filter': '$SYS/broker/uptime', 'qos': 0x00},
|
||||
{'filter': '$SYS/broker/uptime', 'qos': 0x01},
|
||||
])
|
||||
logger.info("Subscribed")
|
||||
logger.info(repr(ret))
|
||||
|
|
|
@ -10,37 +10,12 @@ C = MQTTClient()
|
|||
def test_coro():
|
||||
yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None)
|
||||
tasks = [
|
||||
asyncio.async(C.publish('a/b', b'0123456789')),
|
||||
asyncio.async(C.publish('a/b', b'0', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'1', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'2', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'3', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'4', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'5', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'6', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'7', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'8', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'9', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'0', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'1', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'2', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'3', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'4', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'5', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'6', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'7', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'8', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'9', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
|
||||
]
|
||||
yield from asyncio.wait(tasks)
|
||||
logger.info("messages published")
|
||||
yield from C.subscribe([
|
||||
{'filter': '$SYS/broker/connections/*', 'qos': 0x01},
|
||||
])
|
||||
logger.info("Subscribed")
|
||||
#yield from C.unsubscribe(['a/b', 'c/d'])
|
||||
#logger.info("Unsubscribed")
|
||||
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue