kopia lustrzana https://github.com/Yakifo/amqtt
Handle subscribe and unsubscribe with futures
rodzic
5d2126ff63
commit
5eacd2959d
|
@ -18,7 +18,9 @@ from hbmqtt.mqtt.puback import PubackPacket
|
||||||
from hbmqtt.mqtt.pubrec import PubrecPacket
|
from hbmqtt.mqtt.pubrec import PubrecPacket
|
||||||
from hbmqtt.mqtt.pubcomp import PubcompPacket
|
from hbmqtt.mqtt.pubcomp import PubcompPacket
|
||||||
from hbmqtt.mqtt.subscribe import SubscribePacket
|
from hbmqtt.mqtt.subscribe import SubscribePacket
|
||||||
|
from hbmqtt.mqtt.suback import SubackPacket
|
||||||
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
|
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
|
||||||
|
from hbmqtt.mqtt.unsuback import UnsubackPacket
|
||||||
from hbmqtt.session import Session
|
from hbmqtt.session import Session
|
||||||
from transitions import Machine, MachineError
|
from transitions import Machine, MachineError
|
||||||
|
|
||||||
|
@ -153,6 +155,10 @@ class ProtocolHandler:
|
||||||
|
|
||||||
if packet.fixed_header.packet_type == PacketType.CONNACK:
|
if packet.fixed_header.packet_type == PacketType.CONNACK:
|
||||||
yield from self.handle_connack(packet)
|
yield from self.handle_connack(packet)
|
||||||
|
if packet.fixed_header.packet_type == PacketType.SUBACK:
|
||||||
|
yield from self.handle_suback(packet)
|
||||||
|
if packet.fixed_header.packet_type == PacketType.UNSUBACK:
|
||||||
|
yield from self.handle_unsuback(packet)
|
||||||
else:
|
else:
|
||||||
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
||||||
else:
|
else:
|
||||||
|
@ -273,53 +279,30 @@ class ProtocolHandler:
|
||||||
def handle_connack(self, connack: ConnackPacket):
|
def handle_connack(self, connack: ConnackPacket):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class Subscription:
|
@asyncio.coroutine
|
||||||
states = ['new', 'subscribed', 'acknowledged']
|
def handle_suback(self, suback: SubackPacket):
|
||||||
|
pass
|
||||||
|
|
||||||
def __init__(self, packet_id, topics):
|
@asyncio.coroutine
|
||||||
self.topics = topics
|
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||||
self.packet_id = packet_id
|
pass
|
||||||
self._init_states()
|
|
||||||
|
|
||||||
def _init_states(self):
|
|
||||||
self.machine = Machine(model=self, states=Subscription.states, initial='new')
|
|
||||||
self.machine.add_transition(trigger='subscribe', source='new', dest='subscribed')
|
|
||||||
self.machine.add_transition(trigger='acknowledge', source='subscribed', dest='acknowledged')
|
|
||||||
|
|
||||||
class UnSubscription:
|
|
||||||
states = ['new', 'unsubscribed', 'acknowledged']
|
|
||||||
|
|
||||||
def __init__(self, packet_id, topics):
|
|
||||||
self.topics = topics
|
|
||||||
self.packet_id = packet_id
|
|
||||||
self._init_states()
|
|
||||||
|
|
||||||
def _init_states(self):
|
|
||||||
self.machine = Machine(model=self, states=UnSubscription.states, initial='new')
|
|
||||||
self.machine.add_transition(trigger='unsubscribe', source='new', dest='unsubscribed')
|
|
||||||
self.machine.add_transition(trigger='acknowledge', source='unsubscribed', dest='acknowledged')
|
|
||||||
|
|
||||||
|
|
||||||
class ClientProtocolHandler(ProtocolHandler):
|
class ClientProtocolHandler(ProtocolHandler):
|
||||||
def __init__(self, session: Session, config, loop=None):
|
def __init__(self, session: Session, config, loop=None):
|
||||||
super().__init__(session, config, loop)
|
super().__init__(session, config, loop)
|
||||||
self._ping_task = None
|
self._ping_task = None
|
||||||
self.subscriptions = dict()
|
|
||||||
self._subscription_task = None
|
|
||||||
self._subscriptions_changed = asyncio.Condition(loop=self._loop)
|
|
||||||
self._subscriptions_ready = asyncio.Event(loop=self._loop)
|
|
||||||
self._connack_waiter = None
|
self._connack_waiter = None
|
||||||
|
self._subscriptions_waiter = dict()
|
||||||
|
self._unsubscriptions_waiter = dict()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def start(self):
|
def start(self):
|
||||||
yield from super().start()
|
yield from super().start()
|
||||||
self._subscription_task = asyncio.async(self._subscriptions_coro(), loop=self._loop)
|
|
||||||
yield from asyncio.wait([self._subscriptions_ready.wait()], loop=self._loop)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def stop(self):
|
def stop(self):
|
||||||
yield from super().stop()
|
yield from super().stop()
|
||||||
yield from asyncio.wait([self._subscription_task], loop=self._loop)
|
|
||||||
if self._ping_task:
|
if self._ping_task:
|
||||||
try:
|
try:
|
||||||
self._ping_task.cancel()
|
self._ping_task.cancel()
|
||||||
|
@ -329,61 +312,28 @@ class ClientProtocolHandler(ProtocolHandler):
|
||||||
def handle_keepalive(self):
|
def handle_keepalive(self):
|
||||||
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
|
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
|
||||||
|
|
||||||
def _subscriptions_coro(self):
|
|
||||||
self.logger.debug("Starting subscriptions polling coro")
|
|
||||||
while self._running:
|
|
||||||
self._subscriptions_ready.set()
|
|
||||||
yield from asyncio.sleep(self.config['subscriptions-polling-interval'])
|
|
||||||
self.logger.debug("Subscriptions polling coro wake-up")
|
|
||||||
try:
|
|
||||||
while not self.incoming_queues[PacketType.SUBACK].empty():
|
|
||||||
packet = self.incoming_queues[PacketType.SUBACK].get_nowait()
|
|
||||||
packet_id = packet.variable_header.packet_id
|
|
||||||
subscription = self.subscriptions.get(packet_id)
|
|
||||||
for i in range(len(subscription.topics)):
|
|
||||||
subscription.topics[i]['return_code'] = packet.payload.return_codes[i]
|
|
||||||
subscription.acknowledge()
|
|
||||||
self.logger.debug("Subscription with packet Id=%s acknowledged" % packet_id)
|
|
||||||
|
|
||||||
while not self.incoming_queues[PacketType.UNSUBACK].empty():
|
|
||||||
packet = self.incoming_queues[PacketType.UNSUBACK].get_nowait()
|
|
||||||
packet_id = packet.variable_header.packet_id
|
|
||||||
subscription = self.subscriptions.get(packet_id)
|
|
||||||
subscription.acknowledge()
|
|
||||||
self.logger.debug("Unsubscription with packet Id=%s acknowledged" % packet_id)
|
|
||||||
|
|
||||||
yield from self._subscriptions_changed.acquire()
|
|
||||||
self._subscriptions_changed.notify_all()
|
|
||||||
self._subscriptions_changed.release()
|
|
||||||
except KeyError:
|
|
||||||
self.logger.warn("Received %s for unknown subscription message Id %d" % (packet.fixed_header.packet_type, packet_id))
|
|
||||||
except MachineError as me:
|
|
||||||
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
|
|
||||||
self.logger.debug("Subscriptions polling coro stopped")
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def mqtt_subscribe(self, topics, packet_id):
|
def mqtt_subscribe(self, topics, packet_id):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
|
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
def acknowledged_predicate():
|
|
||||||
if self.subscriptions[subscribe.variable_header.packet_id].state == 'acknowledged':
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
subscribe = SubscribePacket.build(topics, packet_id)
|
subscribe = SubscribePacket.build(topics, packet_id)
|
||||||
yield from self.outgoing_queue.put(subscribe)
|
yield from self.outgoing_queue.put(subscribe)
|
||||||
subscription = Subscription(subscribe.variable_header.packet_id, topics)
|
waiter = futures.Future(loop=self._loop)
|
||||||
subscription.subscribe()
|
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
|
||||||
self.subscriptions[subscribe.variable_header.packet_id] = subscription
|
return_codes = yield from waiter
|
||||||
yield from self._subscriptions_changed.acquire()
|
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
|
||||||
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
|
return return_codes
|
||||||
subscription = self.subscriptions.pop(subscribe.variable_header.packet_id)
|
|
||||||
self._subscriptions_changed.release()
|
@asyncio.coroutine
|
||||||
return subscription
|
def handle_suback(self, suback: SubackPacket):
|
||||||
|
packet_id = suback.variable_header.packet_id
|
||||||
|
try:
|
||||||
|
waiter = self._subscriptions_waiter.get(packet_id)
|
||||||
|
waiter.set_result(suback.payload.return_codes)
|
||||||
|
except KeyError as ke:
|
||||||
|
self.logger.warn("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def mqtt_unsubscribe(self, topics, packet_id):
|
def mqtt_unsubscribe(self, topics, packet_id):
|
||||||
|
@ -392,23 +342,21 @@ class ClientProtocolHandler(ProtocolHandler):
|
||||||
:param topics: array of topics ['/a/b', ...]
|
:param topics: array of topics ['/a/b', ...]
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
def acknowledged_predicate():
|
|
||||||
if self.subscriptions[unsubscribe.variable_header.packet_id].state == 'acknowledged':
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
unsubscribe = UnsubscribePacket.build(topics, packet_id)
|
unsubscribe = UnsubscribePacket.build(topics, packet_id)
|
||||||
yield from self.outgoing_queue.put(unsubscribe)
|
yield from self.outgoing_queue.put(unsubscribe)
|
||||||
subscription = UnSubscription(unsubscribe.variable_header.packet_id, topics)
|
waiter = futures.Future(loop=self._loop)
|
||||||
subscription.unsubscribe()
|
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
|
||||||
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
|
yield from waiter
|
||||||
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
|
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
|
||||||
yield from self._subscriptions_changed.acquire()
|
|
||||||
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
|
@asyncio.coroutine
|
||||||
subscription = self.subscriptions.pop(unsubscribe.variable_header.packet_id)
|
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||||
self._subscriptions_changed.release()
|
packet_id = unsuback.variable_header.packet_id
|
||||||
return subscription
|
try:
|
||||||
|
waiter = self._unsubscriptions_waiter.get(packet_id)
|
||||||
|
waiter.set_result(None)
|
||||||
|
except KeyError as ke:
|
||||||
|
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def mqtt_connect(self):
|
def mqtt_connect(self):
|
||||||
|
|
Ładowanie…
Reference in New Issue