From d2d843fec4401d892c7a1adad71da13979128cba Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Sat, 11 Jul 2015 22:42:50 +0200 Subject: [PATCH] HBMQTT-17 Handle timeout on broker side (disconnect client if not message received until timeout specified on CONNECT) --- hbmqtt/broker.py | 5 +++-- hbmqtt/mqtt/protocol/broker_handler.py | 8 ++++++++ hbmqtt/mqtt/protocol/client_handler.py | 5 ++++- hbmqtt/mqtt/protocol/handler.py | 23 ++++++++++++++++++----- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index 72c929d..2e2b846 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -15,7 +15,8 @@ from hbmqtt.utils import format_client_message, gen_client_id _defaults = { 'bind-address': 'localhost', - 'bind-port': 1883 + 'bind-port': 1883, + 'timeout-disconnect-delay': 1 } @@ -169,7 +170,7 @@ class Broker: new_session.username = connect.payload.username new_session.password = connect.payload.password new_session.client_id = connect.payload.client_id - new_session.keep_alive = connect.variable_header.keep_alive + new_session.keep_alive = connect.variable_header.keep_alive + self.config['timeout-disconnect-delay'] new_session.reader = reader new_session.writer = writer diff --git a/hbmqtt/mqtt/protocol/broker_handler.py b/hbmqtt/mqtt/protocol/broker_handler.py index 66adaaa..52928e3 100644 --- a/hbmqtt/mqtt/protocol/broker_handler.py +++ b/hbmqtt/mqtt/protocol/broker_handler.py @@ -34,6 +34,8 @@ class BrokerProtocolHandler(ProtocolHandler): @asyncio.coroutine def stop(self): yield from super().stop() + if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): + self._disconnect_waiter.set_result(None) @asyncio.coroutine def wait_disconnect(self): @@ -41,6 +43,12 @@ class BrokerProtocolHandler(ProtocolHandler): self._disconnect_waiter = futures.Future(loop=self._loop) yield from self._disconnect_waiter + def handle_write_timeout(self): + pass + + def handle_read_timeout(self): + asyncio.Task(self.stop()) + @asyncio.coroutine def handle_disconnect(self, disconnect: DisconnectPacket): if self._disconnect_waiter is not None: diff --git a/hbmqtt/mqtt/protocol/client_handler.py b/hbmqtt/mqtt/protocol/client_handler.py index 09fb443..e46eaff 100644 --- a/hbmqtt/mqtt/protocol/client_handler.py +++ b/hbmqtt/mqtt/protocol/client_handler.py @@ -40,9 +40,12 @@ class ClientProtocolHandler(ProtocolHandler): except Exception: pass - def handle_keepalive(self): + def handle_write_timeout(self): self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping()) + def handle_read_timeout(self): + pass + @asyncio.coroutine def mqtt_subscribe(self, topics, packet_id): """ diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index b2cc247..310d836 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -19,6 +19,7 @@ from hbmqtt.mqtt.pubrec import PubrecPacket from hbmqtt.mqtt.pubcomp import PubcompPacket from hbmqtt.mqtt.suback import SubackPacket from hbmqtt.mqtt.subscribe import SubscribePacket +from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.session import Session @@ -146,7 +147,10 @@ class ProtocolHandler: while self._running: try: self._reader_ready.set() - fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), 5) + keepalive_timeout = self.session.keep_alive + if keepalive_timeout <= 0: + keepalive_timeout = None + fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), keepalive_timeout) if fixed_header: cls = packet_class(fixed_header) packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) @@ -156,6 +160,8 @@ class ProtocolHandler: asyncio.Task(self.handle_connack(packet)) elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE: asyncio.Task(self.handle_subscribe(packet)) + elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE: + asyncio.Task(self.handle_unsubscribe(packet)) elif packet.fixed_header.packet_type == PacketType.SUBACK: asyncio.Task(self.handle_suback(packet)) elif packet.fixed_header.packet_type == PacketType.UNSUBACK: @@ -185,6 +191,7 @@ class ProtocolHandler: break except asyncio.TimeoutError: self.logger.debug("Input stream read timeout") + self.handle_read_timeout() except NoDataException as nde: self.logger.debug("No data available") except Exception as e: @@ -209,8 +216,7 @@ class ProtocolHandler: except asyncio.TimeoutError as ce: self.logger.debug("Output queue get timeout") if self._running: - self.logger.debug("PING for keepalive") - self.handle_keepalive() + self.handle_write_timeout() except Exception as e: self.logger.warn("Unhandled exception in writer coro: %s" % e) break @@ -233,8 +239,11 @@ class ProtocolHandler: inflight_message = yield from self.delivered_message.get() return inflight_message - def handle_keepalive(self): - self.logger.warn('keepalive unhandled') + def handle_write_timeout(self): + self.logger.warn('write timeout unhandled') + + def handle_read_timeout(self): + self.logger.warn('read timeout unhandled') @asyncio.coroutine def handle_connack(self, connack: ConnackPacket): @@ -248,6 +257,10 @@ class ProtocolHandler: def handle_subscribe(self, subscribe: SubscribePacket): self.logger.warn('SUBSCRIBE unhandled') + @asyncio.coroutine + def handle_unsubscribe(self, subscribe: UnsubscribePacket): + self.logger.warn('UNSUBSCRIBE unhandled') + @asyncio.coroutine def handle_suback(self, suback: SubackPacket): self.logger.warn('SUBACK unhandled')