amqtt/hbmqtt/mqtt/protocol.py

394 wiersze
16 KiB
Python
Czysty Zwykły widok Historia

2015-06-27 15:42:19 +00:00
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
2015-07-05 19:30:52 +00:00
from asyncio import futures
2015-06-27 15:42:19 +00:00
from hbmqtt.mqtt.packet import MQTTFixedHeader
from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload
2015-07-05 19:12:18 +00:00
from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
2015-06-27 20:26:50 +00:00
from hbmqtt.mqtt.pingreq import PingReqPacket
2015-06-28 20:48:07 +00:00
from hbmqtt.mqtt.publish import PublishPacket
from hbmqtt.mqtt.pubrel import PubrelPacket
from hbmqtt.mqtt.puback import PubackPacket
from hbmqtt.mqtt.pubrec import PubrecPacket
from hbmqtt.mqtt.pubcomp import PubcompPacket
2015-06-29 20:38:36 +00:00
from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.suback import SubackPacket
2015-06-29 20:38:36 +00:00
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.session import Session
2015-06-28 20:48:07 +00:00
from transitions import Machine, MachineError
class InFlightMessage:
    """
    State machine tracking one outgoing PUBLISH through its QoS flow.

    Every message gets the ``new -> published`` edge. QoS 1 adds
    ``published -> acknowledged``. QoS 2 adds
    ``published -> received -> released -> completed``.
    The ``Machine`` adds a trigger method per edge (``publish()``,
    ``acknowledge()``, ``receive()``, ...) to this instance.
    """
    states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']

    def __init__(self, packet, qos):
        """
        :param packet: the packet whose delivery is being tracked
        :param qos: QoS level (0x00, 0x01 or 0x02); selects which edges exist
        """
        self.packet = packet
        self.qos = qos
        self._init_states()

    def _init_states(self):
        """Create ``self.machine`` with the edges valid for ``self.qos``."""
        edges = [('publish', 'new', 'published')]
        if self.qos == 0x01:
            edges.append(('acknowledge', 'published', 'acknowledged'))
        elif self.qos == 0x02:
            edges.extend([
                ('receive', 'published', 'received'),
                ('release', 'received', 'released'),
                ('complete', 'released', 'completed'),
            ])
        self.machine = Machine(model=self, states=InFlightMessage.states, initial='new')
        for trigger_name, src, dst in edges:
            self.machine.add_transition(trigger=trigger_name, source=src, dest=dst)
2015-06-27 15:42:19 +00:00
2015-06-29 20:46:05 +00:00
2015-06-27 15:42:19 +00:00
class ProtocolHandler:
"""
Class implementing the MQTT communication protocol using asyncio features
"""
2015-06-27 20:26:50 +00:00
2015-06-28 20:48:07 +00:00
def __init__(self, session: Session, config, loop=None):
2015-06-27 15:42:19 +00:00
self.logger = logging.getLogger(__name__)
self.session = session
2015-06-28 20:48:07 +00:00
self.config = config
2015-06-27 15:42:19 +00:00
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self._reader_task = None
self._writer_task = None
self._reader_ready = asyncio.Event(loop=self._loop)
self._writer_ready = asyncio.Event(loop=self._loop)
2015-06-28 20:48:07 +00:00
2015-06-27 15:42:19 +00:00
self._running = False
self.session.local_address, self.session.local_port = self.session.writer.get_extra_info('sockname')
2015-06-28 20:48:07 +00:00
self.incoming_queues = dict()
self.application_messages = asyncio.Queue()
for p in PacketType:
self.incoming_queues[p] = asyncio.Queue()
self.outgoing_queue = asyncio.Queue()
self._puback_waiters = dict()
self._pubrec_waiters = dict()
self._pubrec_waiters = dict()
self._pubcomp_waiters = dict()
2015-06-28 20:48:07 +00:00
self.inflight_messages = dict()
2015-06-27 15:42:19 +00:00
@asyncio.coroutine
def start(self):
self._running = True
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
2015-06-29 20:46:05 +00:00
yield from asyncio.wait(
[self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
2015-06-27 15:42:19 +00:00
self.logger.debug("Handler tasks started")
2015-06-28 20:48:07 +00:00
@asyncio.coroutine
def mqtt_publish(self, topic, message, packet_id, dup, qos, retain):
if packet_id in self.inflight_messages:
self.logger.warn("A message with the same packet ID is already in flight")
packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain)
yield from self.outgoing_queue.put(packet)
inflight_message = InFlightMessage(packet, qos)
2015-06-28 20:48:07 +00:00
self.inflight_messages[packet.variable_header.packet_id] = inflight_message
inflight_message.publish()
2015-06-28 20:48:07 +00:00
if qos == 0x01:
waiter = futures.Future(loop=self._loop)
self._puback_waiters[packet_id] = waiter
yield from waiter
inflight_message.acknowledge()
del self._puback_waiters[packet_id]
2015-06-28 20:48:07 +00:00
if qos == 0x02:
# Wait for PUBREC
waiter = futures.Future(loop=self._loop)
self._pubrec_waiters[packet_id] = waiter
yield from waiter
del self._pubrec_waiters[packet_id]
inflight_message.receive()
# Send pubrel
pubrel = PubrelPacket.build(packet_id)
yield from self.outgoing_queue.put(pubrel)
inflight_message.release()
# Wait for pubcomp
waiter = futures.Future(loop=self._loop)
self._pubcomp_waiters[packet_id] = waiter
yield from waiter
del self._pubcomp_waiters[packet_id]
2015-07-05 20:29:46 +00:00
inflight_message.complete()
del self.inflight_messages[packet_id]
2015-07-05 20:29:46 +00:00
return inflight_message
2015-06-27 15:42:19 +00:00
@asyncio.coroutine
def stop(self):
self._running = False
self.session.reader.feed_eof()
yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
2015-06-27 15:42:19 +00:00
@asyncio.coroutine
def _reader_coro(self):
self.logger.debug("Starting reader coro")
while self._running:
try:
self._reader_ready.set()
2015-06-27 15:55:18 +00:00
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), 5)
2015-06-27 15:42:19 +00:00
if fixed_header:
cls = packet_class(fixed_header)
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
self.logger.debug(" <-in-- " + repr(packet))
2015-07-05 19:12:18 +00:00
if packet.fixed_header.packet_type == PacketType.CONNACK:
yield from self.handle_connack(packet)
if packet.fixed_header.packet_type == PacketType.SUBACK:
yield from self.handle_suback(packet)
if packet.fixed_header.packet_type == PacketType.UNSUBACK:
yield from self.handle_unsuback(packet)
if packet.fixed_header.packet_type == PacketType.PUBACK:
yield from self.handle_puback(packet)
if packet.fixed_header.packet_type == PacketType.PUBREC:
yield from self.handle_pubrec(packet)
if packet.fixed_header.packet_type == PacketType.PUBCOMP:
yield from self.handle_pubcomp(packet)
2015-07-05 19:12:18 +00:00
else:
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
2015-06-27 15:42:19 +00:00
else:
2015-06-29 20:38:36 +00:00
self.logger.debug("No more data, stopping reader coro")
break
2015-06-27 15:42:19 +00:00
except asyncio.TimeoutError:
2015-06-27 20:26:50 +00:00
self.logger.debug("Input stream read timeout")
2015-06-27 15:42:19 +00:00
except NoDataException as nde:
self.logger.debug("No data available")
except Exception as e:
self.logger.warn("Unhandled exception in reader coro: %s" % e)
2015-06-27 15:42:19 +00:00
break
self.logger.debug("Reader coro stopped")
@asyncio.coroutine
def _writer_coro(self):
self.logger.debug("Starting writer coro")
2015-07-05 13:53:52 +00:00
keepalive_timeout = self.session.keep_alive - self.config['ping_delay']
2015-06-27 15:42:19 +00:00
while self._running:
try:
self._writer_ready.set()
2015-07-05 13:53:52 +00:00
packet = yield from asyncio.wait_for(self.outgoing_queue.get(), keepalive_timeout)
2015-06-27 15:42:19 +00:00
yield from packet.to_stream(self.session.writer)
2015-06-28 20:48:07 +00:00
self.logger.debug(" -out-> " + repr(packet))
2015-06-27 15:42:19 +00:00
yield from self.session.writer.drain()
2015-07-05 13:53:52 +00:00
#self.outgoing_queue.task_done() # to be used with Python 3.5
2015-06-27 15:42:19 +00:00
except asyncio.TimeoutError as ce:
2015-06-27 20:26:50 +00:00
self.logger.debug("Output queue get timeout")
2015-07-05 13:53:52 +00:00
if self._running:
self.logger.debug("PING for keepalive")
self.handle_keepalive()
2015-06-27 15:42:19 +00:00
except Exception as e:
self.logger.warn("Unhandled exception in writer coro: %s" % e)
2015-06-27 15:42:19 +00:00
break
self.logger.debug("Writer coro stopping")
# Flush queue before stopping
if not self.outgoing_queue.empty():
2015-06-27 15:42:19 +00:00
while True:
try:
packet = self.outgoing_queue.get_nowait()
2015-06-27 15:42:19 +00:00
yield from packet.to_stream(self.session.writer)
2015-06-28 20:48:07 +00:00
self.logger.debug(" -out-> " + repr(packet))
2015-06-27 15:42:19 +00:00
except asyncio.QueueEmpty:
break
except Exception as e:
self.logger.warn("Unhandled exception in writer coro: %s" % e)
2015-06-27 15:42:19 +00:00
self.logger.debug("Writer coro stopped")
@asyncio.coroutine
def _receive_publish_coro(self):
while self._running:
message = yield from self.incoming_queues[PacketType.PUBLISH].get()
yield self.application_messages.put(message)
message_id = message.fixed_header.packet_id
if (message.fixed_header.flags >> 1) & 0x01:
# QOS 1
yield from self.outgoing_queue.put(PubackPacket.build(message_id))
if (message.fixed_header.flags >> 1) & 0x02:
# QOS 2
yield from self.outgoing_queue.put(PubrecPacket.build(message_id))
@asyncio.coroutine
def mqtt_deliver_next_message(self):
message = yield from self.application_messages.get()
message_id = message.fixed_header.packet_id
if (message.fixed_header.flags >> 1) & 0x02:
# QOS 2
yield from self.outgoing_queue.put(PubrecPacket.build(message_id))
return message
2015-06-28 20:48:07 +00:00
2015-07-05 13:53:52 +00:00
def handle_keepalive(self):
pass
2015-07-05 19:12:18 +00:00
@asyncio.coroutine
def handle_connack(self, connack: ConnackPacket):
pass
2015-07-05 13:53:52 +00:00
@asyncio.coroutine
def handle_suback(self, suback: SubackPacket):
pass
2015-06-29 20:38:36 +00:00
@asyncio.coroutine
def handle_unsuback(self, unsuback: UnsubackPacket):
pass
2015-06-29 20:38:36 +00:00
@asyncio.coroutine
def handle_puback(self, puback: PubackPacket):
packet_id = puback.variable_header.packet_id
try:
waiter = self._puback_waiters[packet_id]
waiter.set_result(None)
except KeyError as ke:
self.logger.warn("Received PUBACK for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine
def handle_pubrec(self, pubrec: PubrecPacket):
packet_id = pubrec.variable_header.packet_id
try:
waiter = self._pubrec_waiters[packet_id]
waiter.set_result(None)
except KeyError as ke:
self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id)
@asyncio.coroutine
def handle_pubcomp(self, pubcomp: PubcompPacket):
packet_id = pubcomp.variable_header.packet_id
try:
waiter = self._pubcomp_waiters[packet_id]
waiter.set_result(None)
except KeyError as ke:
self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id)
class ClientProtocolHandler(ProtocolHandler):
2015-06-28 20:48:07 +00:00
def __init__(self, session: Session, config, loop=None):
super().__init__(session, config, loop)
2015-06-27 20:26:50 +00:00
self._ping_task = None
2015-07-05 19:30:52 +00:00
self._connack_waiter = None
self._subscriptions_waiter = dict()
self._unsubscriptions_waiter = dict()
2015-06-27 20:26:50 +00:00
@asyncio.coroutine
def start(self):
yield from super().start()
2015-06-29 20:38:36 +00:00
2015-06-27 20:26:50 +00:00
@asyncio.coroutine
def stop(self):
2015-06-29 20:38:36 +00:00
yield from super().stop()
2015-06-27 20:26:50 +00:00
if self._ping_task:
try:
self._ping_task.cancel()
except Exception:
pass
2015-07-05 13:53:52 +00:00
def handle_keepalive(self):
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
2015-06-29 20:38:36 +00:00
@asyncio.coroutine
def mqtt_subscribe(self, topics, packet_id):
"""
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
:return:
"""
subscribe = SubscribePacket.build(topics, packet_id)
yield from self.outgoing_queue.put(subscribe)
waiter = futures.Future(loop=self._loop)
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
return_codes = yield from waiter
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
return return_codes
@asyncio.coroutine
def handle_suback(self, suback: SubackPacket):
packet_id = suback.variable_header.packet_id
try:
waiter = self._subscriptions_waiter.get(packet_id)
waiter.set_result(suback.payload.return_codes)
except KeyError as ke:
self.logger.warn("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
2015-06-29 20:38:36 +00:00
@asyncio.coroutine
def mqtt_unsubscribe(self, topics, packet_id):
"""
:param topics: array of topics ['/a/b', ...]
:return:
"""
unsubscribe = UnsubscribePacket.build(topics, packet_id)
yield from self.outgoing_queue.put(unsubscribe)
waiter = futures.Future(loop=self._loop)
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
yield from waiter
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
@asyncio.coroutine
def handle_unsuback(self, unsuback: UnsubackPacket):
packet_id = unsuback.variable_header.packet_id
try:
waiter = self._unsubscriptions_waiter.get(packet_id)
waiter.set_result(None)
except KeyError as ke:
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
2015-06-29 20:38:36 +00:00
@asyncio.coroutine
def mqtt_connect(self):
def build_connect_packet(session):
vh = ConnectVariableHeader()
payload = ConnectPayload()
vh.keep_alive = session.keep_alive
vh.clean_session_flag = session.clean_session
vh.will_retain_flag = session.will_retain
payload.client_id = session.client_id
if session.username:
vh.username_flag = True
payload.username = session.username
else:
vh.username_flag = False
if session.password:
vh.password_flag = True
payload.password = session.password
else:
vh.password_flag = False
if session.will_flag:
vh.will_flag = True
vh.will_qos = session.will_qos
payload.will_message = session.will_message
payload.will_topic = session.will_topic
else:
vh.will_flag = False
header = MQTTFixedHeader(PacketType.CONNECT, 0x00)
packet = ConnectPacket(header, vh, payload)
return packet
packet = build_connect_packet(self.session)
yield from self.outgoing_queue.put(packet)
2015-07-05 19:30:52 +00:00
self._connack_waiter = futures.Future(loop=self._loop)
return (yield from self._connack_waiter)
2015-07-05 19:12:18 +00:00
@asyncio.coroutine
def handle_connack(self, connack: ConnackPacket):
2015-07-05 19:30:52 +00:00
self._connack_waiter.set_result(connack.variable_header.return_code)
2015-07-05 19:12:18 +00:00
@asyncio.coroutine
def mqtt_disconnect(self):
2015-07-05 13:53:52 +00:00
# yield from self.outgoing_queue.join() To be used in Python 3.5
disconnect_packet = DisconnectPacket()
yield from self.outgoing_queue.put(disconnect_packet)
2015-07-05 19:30:52 +00:00
self._connack_waiter = None
2015-06-27 20:26:50 +00:00
@asyncio.coroutine
def mqtt_ping(self):
ping_packet = PingReqPacket()
yield from self.outgoing_queue.put(ping_packet)
yield from self.incoming_queues[PacketType.PINGRESP].get()