amqtt/hbmqtt/mqtt/protocol/handler.py

399 lines
18 KiB
Python

# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
from datetime import datetime
from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket
from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException, HBMQTTException
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.pingresp import PingRespPacket
from hbmqtt.mqtt.pingreq import PingReqPacket
from hbmqtt.mqtt.publish import PublishPacket
from hbmqtt.mqtt.pubrel import PubrelPacket
from hbmqtt.mqtt.puback import PubackPacket
from hbmqtt.mqtt.pubrec import PubrecPacket
from hbmqtt.mqtt.pubcomp import PubcompPacket
from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.session import Session
from hbmqtt.specs import *
from hbmqtt.mqtt.protocol.inflight import *
class ProtocolHandler:
    """
    Class implementing the MQTT communication protocol using asyncio features.

    A handler is attached to a Session, which provides the stream reader/writer
    and the in-flight message dictionaries. It runs two tasks:

    * a reader task that decodes incoming packets and dispatches each one to
      the matching ``handle_*`` coroutine in its own task;
    * a writer task that serialises packets taken from ``outgoing_queue``.

    Most ``handle_*`` coroutines here only log that the packet is unhandled
    (override them to add behaviour). The publish-flow handlers (PUBLISH,
    PUBACK, PUBREC, PUBREL, PUBCOMP) are implemented in this class.
    """

    def __init__(self, loop=None):
        """
        :param loop: event loop used for the handler's tasks, events and queue;
            defaults to ``asyncio.get_event_loop()``.
        """
        self.logger = logging.getLogger(__name__)
        self.session = None
        if loop is None:
            self._loop = asyncio.get_event_loop()
        else:
            self._loop = loop
        self._reader_task = None
        self._writer_task = None
        self._reader_ready = asyncio.Event(loop=self._loop)
        self._writer_ready = asyncio.Event(loop=self._loop)
        self._running = False
        # Bind the queue to the same loop as the events and tasks. Without an
        # explicit loop it would be bound to the default event loop, which
        # differs from self._loop when a custom loop is passed in.
        self.outgoing_queue = asyncio.Queue(loop=self._loop)
        # Not referenced within this class; kept for other users.
        self._pubrel_waiters = dict()

    def attach_to_session(self, session: Session):
        """
        Bind this handler to ``session`` (and the session back to this handler).

        Also records the local address/port of the session's writer socket.
        """
        self.session = session
        self.session.handler = self
        extra_info = self.session.writer.get_extra_info('sockname')
        self.session.local_address = extra_info[0]
        self.session.local_port = extra_info[1]

    def detach_from_session(self):
        """Break the link between this handler and its session."""
        self.session.handler = None
        self.session = None

    @asyncio.coroutine
    def start(self):
        """
        Start the reader and writer tasks and wait until both are running.

        Once both are running, resend the session's pending outgoing messages.
        """
        self._running = True
        self._reader_task = asyncio.Task(self._reader_coro(), loop=self._loop)
        self._writer_task = asyncio.Task(self._writer_coro(), loop=self._loop)
        yield from asyncio.wait(
            [self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
        self.logger.debug("%s Handler tasks started", self.session.client_id)
        yield from self.retry_deliveries()

    @asyncio.coroutine
    def retry_deliveries(self):
        """
        Handle [MQTT-4.4.0-1] by resending PUBLISH and PUBREL messages for pending out messages.

        Messages for which ``is_new()`` is true get their PUBLISH resent with
        the DUP flag set until ``wait_acknowledge()`` succeeds. Messages for
        which ``is_received()`` is true get a PUBREL resent.
        """
        self.logger.debug("Begin messages delivery retries")
        # Iterate over a snapshot. Entries are removed inside the loop, and the
        # loop suspends on ``yield from``, during which other coroutines may
        # add entries. Iterating the live dict would raise RuntimeError.
        for packet_id, message in list(self.session.outgoing_msg.items()):
            if message.is_new():
                self.logger.debug("Retrying publish message Id=%d", packet_id)
                message.publish_packet.dup_flag = True
                ack = False
                while not ack:
                    yield from self.outgoing_queue.put(message.publish_packet)
                    message.retry_publish()
                    ack = yield from message.wait_acknowledge()
                # Tolerate the entry having been removed while we were waiting.
                self.session.outgoing_msg.pop(packet_id, None)
            if message.is_received():
                self.logger.debug("Retrying pubrel message Id=%d", packet_id)
                yield from self.outgoing_queue.put(PubrelPacket.build(packet_id))
                message.sent_pubrel()
        self.logger.debug("End messages delivery retries")

    @asyncio.coroutine
    def mqtt_publish(self, topic, message, qos, retain):
        """
        Publish ``message`` on ``topic``.

        For QoS 0 the PUBLISH packet is only queued. For higher QoS the message
        is tracked in ``session.outgoing_msg``, and the PUBLISH is resent with
        DUP set until ``wait_acknowledge()`` reports success. The tracking entry
        is then removed.
        """
        packet_id = self.session.next_packet_id
        if packet_id in self.session.outgoing_msg:
            self.logger.warning("%s A message with the same packet ID is already in flight",
                                self.session.client_id)
        packet = PublishPacket.build(topic, message, packet_id, False, qos, retain)
        if qos == QOS_0:
            yield from self.outgoing_queue.put(packet)
            return
        # Register the in-flight message before the packet is queued. This way
        # an acknowledgement read by the reader task always finds its entry.
        inflight_message = OutgoingInFlightMessage(packet, qos, loop=self._loop)
        inflight_message.sent_publish()
        self.session.outgoing_msg[packet_id] = inflight_message
        yield from self.outgoing_queue.put(packet)
        ack = yield from inflight_message.wait_acknowledge()
        while not ack:
            # Retry publish with the DUP flag set
            packet = PublishPacket.build(topic, message, packet_id, True, qos, retain)
            inflight_message.publish_packet = packet
            yield from self.outgoing_queue.put(packet)
            inflight_message.retry_publish()
            ack = yield from inflight_message.wait_acknowledge()
        self.session.outgoing_msg.pop(packet_id, None)

    @asyncio.coroutine
    def stop(self):
        """
        Stop the reader and writer tasks, then cancel every in-flight message.

        Both incoming and outgoing in-flight messages of the session are cancelled.
        """
        self._running = False
        self.session.reader.feed_eof()
        # Any non-packet item makes the writer loop break (see _writer_coro).
        yield from self.outgoing_queue.put("STOP")
        yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
        # Stop incoming and outgoing message flow waiters
        for inflight in list(self.session.incoming_msg.values()):
            inflight.cancel()
        for inflight in list(self.session.outgoing_msg.values()):
            inflight.cancel()

    def _log_handler_failure(self, task):
        """Done-callback for packet handler tasks: log any exception they raised."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.logger.warning("Unhandled exception in packet handler task: %r", exc)

    @asyncio.coroutine
    def _reader_coro(self):
        """
        Read packets from the session stream and dispatch them to ``handle_*``.

        Each packet is handled in its own task.
        """
        self.logger.debug("%s Starting reader coro", self.session.client_id)
        handlers = {
            PacketType.CONNACK: self.handle_connack,
            PacketType.SUBSCRIBE: self.handle_subscribe,
            PacketType.UNSUBSCRIBE: self.handle_unsubscribe,
            PacketType.SUBACK: self.handle_suback,
            PacketType.UNSUBACK: self.handle_unsuback,
            PacketType.PUBACK: self.handle_puback,
            PacketType.PUBREC: self.handle_pubrec,
            PacketType.PUBREL: self.handle_pubrel,
            PacketType.PUBCOMP: self.handle_pubcomp,
            PacketType.PINGREQ: self.handle_pingreq,
            PacketType.PINGRESP: self.handle_pingresp,
            PacketType.PUBLISH: self.handle_publish,
            PacketType.DISCONNECT: self.handle_disconnect,
            PacketType.CONNECT: self.handle_connect,
        }
        while self._running:
            try:
                self._reader_ready.set()
                keepalive_timeout = self.session.keep_alive
                if keepalive_timeout <= 0:
                    keepalive_timeout = None
                fixed_header = yield from asyncio.wait_for(
                    MQTTFixedHeader.from_stream(self.session.reader), keepalive_timeout)
                if fixed_header:
                    cls = packet_class(fixed_header)
                    packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
                    self.logger.debug("%s <-in-- %s", self.session.client_id, repr(packet))
                    packet_type = packet.fixed_header.packet_type
                    handler = handlers.get(packet_type)
                    if handler is None:
                        self.logger.warning("%s Unhandled packet type: %s",
                                            self.session.client_id, packet_type)
                    else:
                        # Deliberately not awaited. handle_publish for QoS 2
                        # waits for a PUBREL that only this loop can read, so
                        # waiting here would deadlock. The callback retrieves
                        # and logs any handler exception.
                        task = asyncio.Task(handler(packet), loop=self._loop)
                        task.add_done_callback(self._log_handler_failure)
                else:
                    self.logger.debug("%s No more data, stopping reader coro", self.session.client_id)
                    yield from self.handle_connection_closed()
                    break
            except asyncio.TimeoutError:
                self.logger.debug("%s Input stream read timeout", self.session.client_id)
                self.handle_read_timeout()
            except NoDataException:
                self.logger.debug("%s No data available", self.session.client_id)
            except Exception as e:
                self.logger.warning("%s Unhandled exception in reader coro: %s", self.session.client_id, e)
                break
        self.logger.debug("%s Reader coro stopped", self.session.client_id)

    @asyncio.coroutine
    def _writer_coro(self):
        """
        Write packets from ``outgoing_queue`` to the session stream.

        A non-packet item in the queue stops the loop. Packets still queued at
        that point are flushed before returning.
        """
        self.logger.debug("%s Starting writer coro", self.session.client_id)
        while self._running:
            try:
                self._writer_ready.set()
                keepalive_timeout = self.session.keep_alive
                if keepalive_timeout <= 0:
                    keepalive_timeout = None
                packet = yield from asyncio.wait_for(self.outgoing_queue.get(), keepalive_timeout)
                if not isinstance(packet, MQTTPacket):
                    self.logger.debug("%s Writer interruption", self.session.client_id)
                    break
                yield from packet.to_stream(self.session.writer)
                self.logger.debug("%s -out-> %s", self.session.client_id, repr(packet))
                yield from self.session.writer.drain()
            except asyncio.TimeoutError:
                self.logger.debug("%s Output queue get timeout", self.session.client_id)
                if self._running:
                    self.handle_write_timeout()
            except ConnectionResetError:
                yield from self.handle_connection_closed()
                break
            except Exception as e:
                self.logger.warning("%s Unhandled exception in writer coro: %s", self.session.client_id, e)
                break
        self.logger.debug("%s Writer coro stopping", self.session.client_id)
        # Flush queue before stopping. Stop at the first non-packet item or
        # when the queue is empty.
        while True:
            try:
                packet = self.outgoing_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if not isinstance(packet, MQTTPacket):
                break
            try:
                yield from packet.to_stream(self.session.writer)
                self.logger.debug("%s -out-> %s", self.session.client_id, repr(packet))
            except Exception as e:
                self.logger.warning("%s Unhandled exception in writer coro: %s", self.session.client_id, e)
        self.logger.debug("%s Writer coro stopped", self.session.client_id)

    @asyncio.coroutine
    def mqtt_deliver_next_message(self):
        """
        Wait for the next incoming message queued for delivery.

        :return: the message's PUBLISH packet.

        QoS 0 entries are removed from ``session.incoming_msg`` here. QoS 1/2
        entries are left in place; ``handle_publish`` removes them after
        ``wait_acknowledge()`` succeeds.
        """
        packet_id = yield from self.session.delivered_message_queue.get()
        message = self.session.incoming_msg[packet_id]
        if message.qos == QOS_0:
            del self.session.incoming_msg[packet_id]
            self.logger.debug("Discarded incoming message %s", packet_id)
        return message.publish_packet

    @asyncio.coroutine
    def mqtt_acknowledge_delivery(self, packet_id):
        """
        Acknowledge delivery of the incoming message ``packet_id``.

        Unknown ids are ignored, which includes QoS 0 messages already
        discarded by ``mqtt_deliver_next_message``.
        """
        message = self.session.incoming_msg.get(packet_id)
        if message is None:
            self.logger.debug("Delivery acknowledged for unknown message, packet_id=%s", packet_id)
            return
        message.acknowledge_delivery()
        self.logger.debug("Message delivery acknowledged, packet_id=%s", packet_id)

    def handle_write_timeout(self):
        """Called when no outgoing packet was queued within keep_alive seconds; only logs here."""
        self.logger.warning('%s write timeout unhandled', self.session.client_id)

    def handle_read_timeout(self):
        """Called when no incoming packet header arrived within keep_alive seconds; only logs here."""
        self.logger.warning('%s read timeout unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_connack(self, connack: ConnackPacket):
        """Default CONNACK handler: only logs that it is unhandled."""
        self.logger.warning('%s CONNACK unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_connect(self, connect: ConnectPacket):
        """Default CONNECT handler: only logs that it is unhandled."""
        self.logger.warning('%s CONNECT unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_subscribe(self, subscribe: SubscribePacket):
        """Default SUBSCRIBE handler: only logs that it is unhandled."""
        self.logger.warning('%s SUBSCRIBE unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_unsubscribe(self, subscribe: UnsubscribePacket):
        """Default UNSUBSCRIBE handler: only logs that it is unhandled."""
        self.logger.warning('%s UNSUBSCRIBE unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_suback(self, suback: SubackPacket):
        """Default SUBACK handler: only logs that it is unhandled."""
        self.logger.warning('%s SUBACK unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_unsuback(self, unsuback: UnsubackPacket):
        """Default UNSUBACK handler: only logs that it is unhandled."""
        self.logger.warning('%s UNSUBACK unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_pingresp(self, pingresp: PingRespPacket):
        """Default PINGRESP handler: only logs that it is unhandled."""
        self.logger.warning('%s PINGRESP unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        """Default PINGREQ handler: only logs that it is unhandled."""
        self.logger.warning('%s PINGREQ unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_disconnect(self, disconnect: DisconnectPacket):
        """Default DISCONNECT handler: only logs that it is unhandled."""
        self.logger.warning('%s DISCONNECT unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_connection_closed(self):
        """
        Called when the connection closes.

        The reader calls it when no more data is available; the writer calls it
        on ConnectionResetError. It only logs here.
        """
        self.logger.warning('%s Connection closed unhandled', self.session.client_id)

    @asyncio.coroutine
    def handle_puback(self, puback: PubackPacket):
        """Mark the matching outgoing QoS 1 message as acknowledged."""
        packet_id = puback.variable_header.packet_id
        inflight_message = self.session.outgoing_msg.get(packet_id)
        if inflight_message is None:
            self.logger.warning("%s Received PUBACK for unknown pending message with Id: %s",
                                self.session.client_id, packet_id)
            return
        inflight_message.received_puback()

    @asyncio.coroutine
    def handle_pubrec(self, pubrec: PubrecPacket):
        """Record the PUBREC for an outgoing QoS 2 message and answer with PUBREL."""
        packet_id = pubrec.variable_header.packet_id
        inflight_message = self.session.outgoing_msg.get(packet_id)
        if inflight_message is None:
            self.logger.warning("%s Received PUBREC for unknown pending message with Id: %s",
                                self.session.client_id, packet_id)
            return
        inflight_message.received_pubrec()
        yield from self.outgoing_queue.put(PubrelPacket.build(packet_id))
        inflight_message.sent_pubrel()

    @asyncio.coroutine
    def handle_pubcomp(self, pubcomp: PubcompPacket):
        """Mark the matching outgoing QoS 2 message as completed."""
        packet_id = pubcomp.variable_header.packet_id
        inflight_message = self.session.outgoing_msg.get(packet_id)
        if inflight_message is None:
            self.logger.warning("%s Received PUBCOMP for unknown pending message with Id: %s",
                                self.session.client_id, packet_id)
            return
        inflight_message.received_pubcomp()

    @asyncio.coroutine
    def handle_pubrel(self, pubrel: PubrelPacket):
        """Record the PUBREL for an incoming QoS 2 message (unblocks handle_publish)."""
        packet_id = pubrel.variable_header.packet_id
        inflight_message = self.session.incoming_msg.get(packet_id)
        if inflight_message is None:
            self.logger.warning("%s Received PUBREL for unknown pending message with Id: %s",
                                self.session.client_id, packet_id)
            return
        inflight_message.received_pubrel()

    @asyncio.coroutine
    def handle_publish(self, publish_packet: PublishPacket):
        """
        Run the receiver side of the PUBLISH flow for ``publish_packet``.

        * QoS 0: queue the message for delivery. A packet with DUP set is
          ignored.
        * QoS 1: queue the message for delivery. Once ``wait_acknowledge()``
          succeeds, send PUBACK and discard the message.
        * QoS 2: send PUBREC and wait for PUBREL, then queue the message for
          delivery. Once ``wait_acknowledge()`` succeeds, send PUBCOMP and
          discard the message.

        :raises HBMQTTException: when a wait reports a False acknowledgement.
        """
        packet_id = publish_packet.variable_header.packet_id
        qos = publish_packet.qos
        if qos == 0:
            if publish_packet.dup_flag:
                self.logger.warning("[MQTT-3.3.1-2] DUP flag must set to 0 for QOS 0 message. Message ignored: %s",
                                    repr(publish_packet))
            else:
                # NOTE(review): QoS 0 messages are keyed by packet_id like the
                # others. If QoS 0 PUBLISH packets carry no packet id, several
                # not-yet-delivered QoS 0 messages would share one key. Confirm.
                incoming_message = IncomingInFlightMessage(publish_packet, qos)
                incoming_message.received_publish()
                self.session.incoming_msg[packet_id] = incoming_message
                yield from self.session.delivered_message_queue.put(packet_id)
            return

        # Check if publish is a retry: reuse the existing in-flight entry.
        # NOTE(review): a retried PUBLISH runs in its own task next to the
        # original one, so a duplicate QoS 2 PUBLISH may be queued for delivery
        # twice. Confirm against MQTT 3.1.1 section 4.3.3.
        incoming_message = self.session.incoming_msg.get(packet_id)
        if incoming_message is None:
            incoming_message = IncomingInFlightMessage(publish_packet, qos)
            self.session.incoming_msg[packet_id] = incoming_message
        incoming_message.publish()

        if qos == 1:
            # Initiate delivery, then wait for it to be acknowledged
            yield from self.session.delivered_message_queue.put(packet_id)
            ack = yield from incoming_message.wait_acknowledge()
            if not ack:
                raise HBMQTTException("Something wrong, ack is False")
            yield from self.outgoing_queue.put(PubackPacket.build(packet_id))
            # pop(): a concurrent task handling a duplicate of this PUBLISH
            # may already have discarded the shared entry.
            self.session.incoming_msg.pop(packet_id, None)
            self.logger.debug("Discarded incoming message %d", packet_id)
        elif qos == 2:
            yield from self.outgoing_queue.put(PubrecPacket.build(packet_id))
            incoming_message.sent_pubrec()
            # Wait for PUBREL (read by the reader task) before delivering
            ack = yield from incoming_message.wait_pubrel()
            if not ack:
                raise HBMQTTException("Something wrong, ack is False")
            yield from self.session.delivered_message_queue.put(packet_id)
            ack = yield from incoming_message.wait_acknowledge()
            if not ack:
                raise HBMQTTException("Something wrong, ack is False")
            yield from self.outgoing_queue.put(PubcompPacket.build(packet_id))
            incoming_message.sent_pubcomp()
            # pop(): see the QoS 1 branch
            self.session.incoming_msg.pop(packet_id, None)
            self.logger.debug("Discarded incoming message %d", packet_id)