Handle subscribe and unsubscribe with futures

pull/8/head
Nicolas Jouanin 2015-07-05 22:00:49 +02:00
rodzic 5d2126ff63
commit 5eacd2959d
1 zmienionych plików z 41 dodań i 93 usunięć

Wyświetl plik

@ -18,7 +18,9 @@ from hbmqtt.mqtt.puback import PubackPacket
from hbmqtt.mqtt.pubrec import PubrecPacket
from hbmqtt.mqtt.pubcomp import PubcompPacket
from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.session import Session
from transitions import Machine, MachineError
@ -153,6 +155,10 @@ class ProtocolHandler:
if packet.fixed_header.packet_type == PacketType.CONNACK:
yield from self.handle_connack(packet)
if packet.fixed_header.packet_type == PacketType.SUBACK:
yield from self.handle_suback(packet)
if packet.fixed_header.packet_type == PacketType.UNSUBACK:
yield from self.handle_unsuback(packet)
else:
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
else:
@ -273,53 +279,30 @@ class ProtocolHandler:
def handle_connack(self, connack: ConnackPacket):
pass
class Subscription:
states = ['new', 'subscribed', 'acknowledged']
@asyncio.coroutine
def handle_suback(self, suback: SubackPacket):
pass
def __init__(self, packet_id, topics):
self.topics = topics
self.packet_id = packet_id
self._init_states()
def _init_states(self):
self.machine = Machine(model=self, states=Subscription.states, initial='new')
self.machine.add_transition(trigger='subscribe', source='new', dest='subscribed')
self.machine.add_transition(trigger='acknowledge', source='subscribed', dest='acknowledged')
class UnSubscription:
states = ['new', 'unsubscribed', 'acknowledged']
def __init__(self, packet_id, topics):
self.topics = topics
self.packet_id = packet_id
self._init_states()
def _init_states(self):
self.machine = Machine(model=self, states=UnSubscription.states, initial='new')
self.machine.add_transition(trigger='unsubscribe', source='new', dest='unsubscribed')
self.machine.add_transition(trigger='acknowledge', source='unsubscribed', dest='acknowledged')
@asyncio.coroutine
def handle_unsuback(self, unsuback: UnsubackPacket):
pass
class ClientProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, config, loop=None):
super().__init__(session, config, loop)
self._ping_task = None
self.subscriptions = dict()
self._subscription_task = None
self._subscriptions_changed = asyncio.Condition(loop=self._loop)
self._subscriptions_ready = asyncio.Event(loop=self._loop)
self._connack_waiter = None
self._subscriptions_waiter = dict()
self._unsubscriptions_waiter = dict()
@asyncio.coroutine
def start(self):
yield from super().start()
self._subscription_task = asyncio.async(self._subscriptions_coro(), loop=self._loop)
yield from asyncio.wait([self._subscriptions_ready.wait()], loop=self._loop)
@asyncio.coroutine
def stop(self):
yield from super().stop()
yield from asyncio.wait([self._subscription_task], loop=self._loop)
if self._ping_task:
try:
self._ping_task.cancel()
@ -329,61 +312,28 @@ class ClientProtocolHandler(ProtocolHandler):
def handle_keepalive(self):
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
def _subscriptions_coro(self):
self.logger.debug("Starting subscriptions polling coro")
while self._running:
self._subscriptions_ready.set()
yield from asyncio.sleep(self.config['subscriptions-polling-interval'])
self.logger.debug("Subscriptions polling coro wake-up")
try:
while not self.incoming_queues[PacketType.SUBACK].empty():
packet = self.incoming_queues[PacketType.SUBACK].get_nowait()
packet_id = packet.variable_header.packet_id
subscription = self.subscriptions.get(packet_id)
for i in range(len(subscription.topics)):
subscription.topics[i]['return_code'] = packet.payload.return_codes[i]
subscription.acknowledge()
self.logger.debug("Subscription with packet Id=%s acknowledged" % packet_id)
while not self.incoming_queues[PacketType.UNSUBACK].empty():
packet = self.incoming_queues[PacketType.UNSUBACK].get_nowait()
packet_id = packet.variable_header.packet_id
subscription = self.subscriptions.get(packet_id)
subscription.acknowledge()
self.logger.debug("Unsubscription with packet Id=%s acknowledged" % packet_id)
yield from self._subscriptions_changed.acquire()
self._subscriptions_changed.notify_all()
self._subscriptions_changed.release()
except KeyError:
self.logger.warn("Received %s for unknown subscription message Id %d" % (packet.fixed_header.packet_type, packet_id))
except MachineError as me:
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
self.logger.debug("Subscriptions polling coro stopped")
@asyncio.coroutine
def mqtt_subscribe(self, topics, packet_id):
"""
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
:return:
"""
def acknowledged_predicate():
if self.subscriptions[subscribe.variable_header.packet_id].state == 'acknowledged':
return True
else:
return False
subscribe = SubscribePacket.build(topics, packet_id)
yield from self.outgoing_queue.put(subscribe)
subscription = Subscription(subscribe.variable_header.packet_id, topics)
subscription.subscribe()
self.subscriptions[subscribe.variable_header.packet_id] = subscription
yield from self._subscriptions_changed.acquire()
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
subscription = self.subscriptions.pop(subscribe.variable_header.packet_id)
self._subscriptions_changed.release()
return subscription
waiter = futures.Future(loop=self._loop)
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
return_codes = yield from waiter
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
return return_codes
@asyncio.coroutine
def handle_suback(self, suback: SubackPacket):
packet_id = suback.variable_header.packet_id
try:
waiter = self._subscriptions_waiter.get(packet_id)
waiter.set_result(suback.payload.return_codes)
except KeyError as ke:
self.logger.warn("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine
def mqtt_unsubscribe(self, topics, packet_id):
@ -392,23 +342,21 @@ class ClientProtocolHandler(ProtocolHandler):
:param topics: array of topics ['/a/b', ...]
:return:
"""
def acknowledged_predicate():
if self.subscriptions[unsubscribe.variable_header.packet_id].state == 'acknowledged':
return True
else:
return False
unsubscribe = UnsubscribePacket.build(topics, packet_id)
yield from self.outgoing_queue.put(unsubscribe)
subscription = UnSubscription(unsubscribe.variable_header.packet_id, topics)
subscription.unsubscribe()
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
yield from self._subscriptions_changed.acquire()
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
subscription = self.subscriptions.pop(unsubscribe.variable_header.packet_id)
self._subscriptions_changed.release()
return subscription
waiter = futures.Future(loop=self._loop)
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
yield from waiter
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
@asyncio.coroutine
def handle_unsuback(self, unsuback: UnsubackPacket):
packet_id = unsuback.variable_header.packet_id
try:
waiter = self._unsubscriptions_waiter.get(packet_id)
waiter.set_result(None)
except KeyError as ke:
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine
def mqtt_connect(self):