Handle subscribe and unsubscribe with futures

pull/8/head
Nicolas Jouanin 2015-07-05 22:00:49 +02:00
rodzic 5d2126ff63
commit 5eacd2959d
1 zmienionych plików z 41 dodań i 93 usunięć

Wyświetl plik

@ -18,7 +18,9 @@ from hbmqtt.mqtt.puback import PubackPacket
from hbmqtt.mqtt.pubrec import PubrecPacket from hbmqtt.mqtt.pubrec import PubrecPacket
from hbmqtt.mqtt.pubcomp import PubcompPacket from hbmqtt.mqtt.pubcomp import PubcompPacket
from hbmqtt.mqtt.subscribe import SubscribePacket from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.session import Session from hbmqtt.session import Session
from transitions import Machine, MachineError from transitions import Machine, MachineError
@ -153,6 +155,10 @@ class ProtocolHandler:
if packet.fixed_header.packet_type == PacketType.CONNACK: if packet.fixed_header.packet_type == PacketType.CONNACK:
yield from self.handle_connack(packet) yield from self.handle_connack(packet)
if packet.fixed_header.packet_type == PacketType.SUBACK:
yield from self.handle_suback(packet)
if packet.fixed_header.packet_type == PacketType.UNSUBACK:
yield from self.handle_unsuback(packet)
else: else:
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet) yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
else: else:
@ -273,53 +279,30 @@ class ProtocolHandler:
def handle_connack(self, connack: ConnackPacket): def handle_connack(self, connack: ConnackPacket):
pass pass
class Subscription: @asyncio.coroutine
states = ['new', 'subscribed', 'acknowledged'] def handle_suback(self, suback: SubackPacket):
pass
def __init__(self, packet_id, topics): @asyncio.coroutine
self.topics = topics def handle_unsuback(self, unsuback: UnsubackPacket):
self.packet_id = packet_id pass
self._init_states()
def _init_states(self):
self.machine = Machine(model=self, states=Subscription.states, initial='new')
self.machine.add_transition(trigger='subscribe', source='new', dest='subscribed')
self.machine.add_transition(trigger='acknowledge', source='subscribed', dest='acknowledged')
class UnSubscription:
states = ['new', 'unsubscribed', 'acknowledged']
def __init__(self, packet_id, topics):
self.topics = topics
self.packet_id = packet_id
self._init_states()
def _init_states(self):
self.machine = Machine(model=self, states=UnSubscription.states, initial='new')
self.machine.add_transition(trigger='unsubscribe', source='new', dest='unsubscribed')
self.machine.add_transition(trigger='acknowledge', source='unsubscribed', dest='acknowledged')
class ClientProtocolHandler(ProtocolHandler): class ClientProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, config, loop=None): def __init__(self, session: Session, config, loop=None):
super().__init__(session, config, loop) super().__init__(session, config, loop)
self._ping_task = None self._ping_task = None
self.subscriptions = dict()
self._subscription_task = None
self._subscriptions_changed = asyncio.Condition(loop=self._loop)
self._subscriptions_ready = asyncio.Event(loop=self._loop)
self._connack_waiter = None self._connack_waiter = None
self._subscriptions_waiter = dict()
self._unsubscriptions_waiter = dict()
@asyncio.coroutine @asyncio.coroutine
def start(self): def start(self):
yield from super().start() yield from super().start()
self._subscription_task = asyncio.async(self._subscriptions_coro(), loop=self._loop)
yield from asyncio.wait([self._subscriptions_ready.wait()], loop=self._loop)
@asyncio.coroutine @asyncio.coroutine
def stop(self): def stop(self):
yield from super().stop() yield from super().stop()
yield from asyncio.wait([self._subscription_task], loop=self._loop)
if self._ping_task: if self._ping_task:
try: try:
self._ping_task.cancel() self._ping_task.cancel()
@ -329,61 +312,28 @@ class ClientProtocolHandler(ProtocolHandler):
def handle_keepalive(self): def handle_keepalive(self):
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping()) self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
def _subscriptions_coro(self):
self.logger.debug("Starting subscriptions polling coro")
while self._running:
self._subscriptions_ready.set()
yield from asyncio.sleep(self.config['subscriptions-polling-interval'])
self.logger.debug("Subscriptions polling coro wake-up")
try:
while not self.incoming_queues[PacketType.SUBACK].empty():
packet = self.incoming_queues[PacketType.SUBACK].get_nowait()
packet_id = packet.variable_header.packet_id
subscription = self.subscriptions.get(packet_id)
for i in range(len(subscription.topics)):
subscription.topics[i]['return_code'] = packet.payload.return_codes[i]
subscription.acknowledge()
self.logger.debug("Subscription with packet Id=%s acknowledged" % packet_id)
while not self.incoming_queues[PacketType.UNSUBACK].empty():
packet = self.incoming_queues[PacketType.UNSUBACK].get_nowait()
packet_id = packet.variable_header.packet_id
subscription = self.subscriptions.get(packet_id)
subscription.acknowledge()
self.logger.debug("Unsubscription with packet Id=%s acknowledged" % packet_id)
yield from self._subscriptions_changed.acquire()
self._subscriptions_changed.notify_all()
self._subscriptions_changed.release()
except KeyError:
self.logger.warn("Received %s for unknown subscription message Id %d" % (packet.fixed_header.packet_type, packet_id))
except MachineError as me:
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
self.logger.debug("Subscriptions polling coro stopped")
@asyncio.coroutine @asyncio.coroutine
def mqtt_subscribe(self, topics, packet_id): def mqtt_subscribe(self, topics, packet_id):
""" """
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...] :param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
:return: :return:
""" """
def acknowledged_predicate():
if self.subscriptions[subscribe.variable_header.packet_id].state == 'acknowledged':
return True
else:
return False
subscribe = SubscribePacket.build(topics, packet_id) subscribe = SubscribePacket.build(topics, packet_id)
yield from self.outgoing_queue.put(subscribe) yield from self.outgoing_queue.put(subscribe)
subscription = Subscription(subscribe.variable_header.packet_id, topics) waiter = futures.Future(loop=self._loop)
subscription.subscribe() self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
self.subscriptions[subscribe.variable_header.packet_id] = subscription return_codes = yield from waiter
yield from self._subscriptions_changed.acquire() del self._subscriptions_waiter[subscribe.variable_header.packet_id]
yield from self._subscriptions_changed.wait_for(acknowledged_predicate) return return_codes
subscription = self.subscriptions.pop(subscribe.variable_header.packet_id)
self._subscriptions_changed.release() @asyncio.coroutine
return subscription def handle_suback(self, suback: SubackPacket):
packet_id = suback.variable_header.packet_id
try:
waiter = self._subscriptions_waiter.get(packet_id)
waiter.set_result(suback.payload.return_codes)
except KeyError as ke:
self.logger.warn("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine @asyncio.coroutine
def mqtt_unsubscribe(self, topics, packet_id): def mqtt_unsubscribe(self, topics, packet_id):
@ -392,23 +342,21 @@ class ClientProtocolHandler(ProtocolHandler):
:param topics: array of topics ['/a/b', ...] :param topics: array of topics ['/a/b', ...]
:return: :return:
""" """
def acknowledged_predicate():
if self.subscriptions[unsubscribe.variable_header.packet_id].state == 'acknowledged':
return True
else:
return False
unsubscribe = UnsubscribePacket.build(topics, packet_id) unsubscribe = UnsubscribePacket.build(topics, packet_id)
yield from self.outgoing_queue.put(unsubscribe) yield from self.outgoing_queue.put(unsubscribe)
subscription = UnSubscription(unsubscribe.variable_header.packet_id, topics) waiter = futures.Future(loop=self._loop)
subscription.unsubscribe() self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription yield from waiter
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
yield from self._subscriptions_changed.acquire()
yield from self._subscriptions_changed.wait_for(acknowledged_predicate) @asyncio.coroutine
subscription = self.subscriptions.pop(unsubscribe.variable_header.packet_id) def handle_unsuback(self, unsuback: UnsubackPacket):
self._subscriptions_changed.release() packet_id = unsuback.variable_header.packet_id
return subscription try:
waiter = self._unsubscriptions_waiter.get(packet_id)
waiter.set_result(None)
except KeyError as ke:
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine @asyncio.coroutine
def mqtt_connect(self): def mqtt_connect(self):