kopia lustrzana https://github.com/Yakifo/amqtt
Handle subscribe and unsubscribe with futures
rodzic
5d2126ff63
commit
5eacd2959d
|
@ -18,7 +18,9 @@ from hbmqtt.mqtt.puback import PubackPacket
|
|||
from hbmqtt.mqtt.pubrec import PubrecPacket
|
||||
from hbmqtt.mqtt.pubcomp import PubcompPacket
|
||||
from hbmqtt.mqtt.subscribe import SubscribePacket
|
||||
from hbmqtt.mqtt.suback import SubackPacket
|
||||
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
|
||||
from hbmqtt.mqtt.unsuback import UnsubackPacket
|
||||
from hbmqtt.session import Session
|
||||
from transitions import Machine, MachineError
|
||||
|
||||
|
@ -153,6 +155,10 @@ class ProtocolHandler:
|
|||
|
||||
if packet.fixed_header.packet_type == PacketType.CONNACK:
|
||||
yield from self.handle_connack(packet)
|
||||
if packet.fixed_header.packet_type == PacketType.SUBACK:
|
||||
yield from self.handle_suback(packet)
|
||||
if packet.fixed_header.packet_type == PacketType.UNSUBACK:
|
||||
yield from self.handle_unsuback(packet)
|
||||
else:
|
||||
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
||||
else:
|
||||
|
@ -273,53 +279,30 @@ class ProtocolHandler:
|
|||
def handle_connack(self, connack: ConnackPacket):
|
||||
pass
|
||||
|
||||
class Subscription:
|
||||
states = ['new', 'subscribed', 'acknowledged']
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
pass
|
||||
|
||||
def __init__(self, packet_id, topics):
|
||||
self.topics = topics
|
||||
self.packet_id = packet_id
|
||||
self._init_states()
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(model=self, states=Subscription.states, initial='new')
|
||||
self.machine.add_transition(trigger='subscribe', source='new', dest='subscribed')
|
||||
self.machine.add_transition(trigger='acknowledge', source='subscribed', dest='acknowledged')
|
||||
|
||||
class UnSubscription:
|
||||
states = ['new', 'unsubscribed', 'acknowledged']
|
||||
|
||||
def __init__(self, packet_id, topics):
|
||||
self.topics = topics
|
||||
self.packet_id = packet_id
|
||||
self._init_states()
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(model=self, states=UnSubscription.states, initial='new')
|
||||
self.machine.add_transition(trigger='unsubscribe', source='new', dest='unsubscribed')
|
||||
self.machine.add_transition(trigger='acknowledge', source='unsubscribed', dest='acknowledged')
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
pass
|
||||
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler):
|
||||
def __init__(self, session: Session, config, loop=None):
|
||||
super().__init__(session, config, loop)
|
||||
self._ping_task = None
|
||||
self.subscriptions = dict()
|
||||
self._subscription_task = None
|
||||
self._subscriptions_changed = asyncio.Condition(loop=self._loop)
|
||||
self._subscriptions_ready = asyncio.Event(loop=self._loop)
|
||||
self._connack_waiter = None
|
||||
self._subscriptions_waiter = dict()
|
||||
self._unsubscriptions_waiter = dict()
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
yield from super().start()
|
||||
self._subscription_task = asyncio.async(self._subscriptions_coro(), loop=self._loop)
|
||||
yield from asyncio.wait([self._subscriptions_ready.wait()], loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
yield from super().stop()
|
||||
yield from asyncio.wait([self._subscription_task], loop=self._loop)
|
||||
if self._ping_task:
|
||||
try:
|
||||
self._ping_task.cancel()
|
||||
|
@ -329,61 +312,28 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
def handle_keepalive(self):
|
||||
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
|
||||
|
||||
def _subscriptions_coro(self):
|
||||
self.logger.debug("Starting subscriptions polling coro")
|
||||
while self._running:
|
||||
self._subscriptions_ready.set()
|
||||
yield from asyncio.sleep(self.config['subscriptions-polling-interval'])
|
||||
self.logger.debug("Subscriptions polling coro wake-up")
|
||||
try:
|
||||
while not self.incoming_queues[PacketType.SUBACK].empty():
|
||||
packet = self.incoming_queues[PacketType.SUBACK].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
subscription = self.subscriptions.get(packet_id)
|
||||
for i in range(len(subscription.topics)):
|
||||
subscription.topics[i]['return_code'] = packet.payload.return_codes[i]
|
||||
subscription.acknowledge()
|
||||
self.logger.debug("Subscription with packet Id=%s acknowledged" % packet_id)
|
||||
|
||||
while not self.incoming_queues[PacketType.UNSUBACK].empty():
|
||||
packet = self.incoming_queues[PacketType.UNSUBACK].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
subscription = self.subscriptions.get(packet_id)
|
||||
subscription.acknowledge()
|
||||
self.logger.debug("Unsubscription with packet Id=%s acknowledged" % packet_id)
|
||||
|
||||
yield from self._subscriptions_changed.acquire()
|
||||
self._subscriptions_changed.notify_all()
|
||||
self._subscriptions_changed.release()
|
||||
except KeyError:
|
||||
self.logger.warn("Received %s for unknown subscription message Id %d" % (packet.fixed_header.packet_type, packet_id))
|
||||
except MachineError as me:
|
||||
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
|
||||
self.logger.debug("Subscriptions polling coro stopped")
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_subscribe(self, topics, packet_id):
|
||||
"""
|
||||
|
||||
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
|
||||
:return:
|
||||
"""
|
||||
def acknowledged_predicate():
|
||||
if self.subscriptions[subscribe.variable_header.packet_id].state == 'acknowledged':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
subscribe = SubscribePacket.build(topics, packet_id)
|
||||
yield from self.outgoing_queue.put(subscribe)
|
||||
subscription = Subscription(subscribe.variable_header.packet_id, topics)
|
||||
subscription.subscribe()
|
||||
self.subscriptions[subscribe.variable_header.packet_id] = subscription
|
||||
yield from self._subscriptions_changed.acquire()
|
||||
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
|
||||
subscription = self.subscriptions.pop(subscribe.variable_header.packet_id)
|
||||
self._subscriptions_changed.release()
|
||||
return subscription
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
|
||||
return_codes = yield from waiter
|
||||
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
|
||||
return return_codes
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
packet_id = suback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._subscriptions_waiter.get(packet_id)
|
||||
waiter.set_result(suback.payload.return_codes)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_unsubscribe(self, topics, packet_id):
|
||||
|
@ -392,23 +342,21 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
:param topics: array of topics ['/a/b', ...]
|
||||
:return:
|
||||
"""
|
||||
def acknowledged_predicate():
|
||||
if self.subscriptions[unsubscribe.variable_header.packet_id].state == 'acknowledged':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
unsubscribe = UnsubscribePacket.build(topics, packet_id)
|
||||
yield from self.outgoing_queue.put(unsubscribe)
|
||||
subscription = UnSubscription(unsubscribe.variable_header.packet_id, topics)
|
||||
subscription.unsubscribe()
|
||||
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
|
||||
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
|
||||
yield from self._subscriptions_changed.acquire()
|
||||
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
|
||||
subscription = self.subscriptions.pop(unsubscribe.variable_header.packet_id)
|
||||
self._subscriptions_changed.release()
|
||||
return subscription
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
|
||||
yield from waiter
|
||||
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
packet_id = unsuback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._unsubscriptions_waiter.get(packet_id)
|
||||
waiter.set_result(None)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_connect(self):
|
||||
|
|
Ładowanie…
Reference in New Issue