diff --git a/hbmqtt/client.py b/hbmqtt/client.py index 12e9c29..1d918c2 100644 --- a/hbmqtt/client.py +++ b/hbmqtt/client.py @@ -14,7 +14,7 @@ from hbmqtt.mqtt.connect import * from hbmqtt.mqtt.protocol.client_handler import ClientProtocolHandler from hbmqtt.adapters import StreamReaderAdapter, StreamWriterAdapter, WebSocketsReader, WebSocketsWriter from hbmqtt.plugins.manager import PluginManager, BaseContext -from hbmqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_SENT, EVENT_MQTT_PACKET_RECEIVED +from hbmqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_SENT, EVENT_MQTT_PACKET_RECEIVED, ProtocolHandlerException from hbmqtt.mqtt.constants import * import websockets from websockets.uri import InvalidURI @@ -25,6 +25,7 @@ _defaults = { 'ping_delay': 1, 'default_qos': 0, 'default_retain': False, + 'auto_reconnect': True } @@ -90,7 +91,6 @@ class MQTTClient: self.session = None self._handler = None self._disconnect_task = None - self._connection_closed_future = None # Init plugins manager context = ClientContext() @@ -115,10 +115,7 @@ class MQTTClient: self.session = self._initsession(uri, cleansession, cafile, capath, cadata) self.logger.debug("Connect to: %s" % uri) - return_code = yield from self._connect_coro() - self._connection_closed_future = asyncio.Future(loop=self._loop) - self._disconnect_task = asyncio.Task(self.handle_connection_close(), loop=self._loop) - return self._connection_closed_future + return (yield from self._do_connect()) @asyncio.coroutine def disconnect(self): @@ -128,23 +125,25 @@ class MQTTClient: yield from self._handler.mqtt_disconnect() yield from self._handler.stop() self.session.transitions.disconnect() - self._connection_closed_future.set_result(None) else: self.logger.warn("Client session is not currently connected, ignoring call") @asyncio.coroutine - def reconnect(self, cleansession=False): + def reconnect(self, cleansession=None): if self.session.transitions.is_connected(): self.logger.warn("Client already connected") return CONNECTION_ACCEPTED - self.session.clean_session = cleansession + if cleansession: + self.session.clean_session = cleansession self.logger.debug("Reconnecting with session parameters: %s" % self.session) + return (yield from self._do_connect()) + @asyncio.coroutine + def _do_connect(self): return_code = yield from self._connect_coro() - self._connection_closed_future = asyncio.Future(loop=self._loop) self._disconnect_task = asyncio.Task(self.handle_connection_close(), loop=self._loop) - return self._connection_closed_future + return return_code @asyncio.coroutine def ping(self): @@ -198,9 +197,6 @@ class MQTTClient: @asyncio.coroutine def _connect_coro(self): - sc = None - reader = None - writer = None kwargs = dict() # Decode URI attributes @@ -220,6 +216,9 @@ class MQTTClient: uri = (scheme, self.session.remote_address + ":" + str(self.session.remote_port), uri_attributes[2], uri_attributes[3], uri_attributes[4], uri_attributes[5]) self.session.broker_uri = urlunparse(uri) + # Init protocol handler + if not self._handler: + self._handler = ClientProtocolHandler(self.session, self.plugins_manager, loop=self._loop) if secure: if self.session.cafile is None or self.session.cafile == '': @@ -234,8 +233,8 @@ class MQTTClient: sc.load_cert_chain(self.config['certfile'], self.config['keyfile']) kwargs['ssl'] = sc - # Open connection try: + # Open connection if scheme in ('mqtt', 'mqtts'): conn_reader, conn_writer = \ yield from asyncio.open_connection( @@ -251,6 +250,20 @@ class MQTTClient: **kwargs) reader = WebSocketsReader(websocket) writer = WebSocketsWriter(websocket) + # Start MQTT protocol + self._handler.attach_stream(reader, writer) + return_code = yield from self._handler.mqtt_connect() + if return_code is not CONNECTION_ACCEPTED: + self.session.transitions.disconnect() + self.logger.warning("Connection rejected with code '%s'" % return_code) + exc = ConnectException("Connection rejected by broker") + exc.return_code = return_code + raise exc + else: + # Handle MQTT protocol + yield from self._handler.start() + self.session.transitions.connect() + self.logger.debug("connected to %s:%s" % (self.session.remote_address, self.session.remote_port)) except InvalidURI as iuri: self.logger.warn("connection failed: invalid URI '%s'" % self.session.broker_uri) self.session.transitions.disconnect() @@ -259,80 +272,23 @@ class MQTTClient: self.logger.warn("connection failed: invalid websocket handshake") self.session.transitions.disconnect() raise ConnectException("connection failed: invalid websocket handshake", ihs) - - return_code = None - try : - connect_packet = self.build_connect_packet() - yield from connect_packet.to_stream(writer) - yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, - packet=connect_packet, - session=self.session) - - connack = yield from ConnackPacket.from_stream(reader) - yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, - packet=connack, - session=self.session) - return_code = connack.variable_header.return_code - except Exception as e: - self.logger.warn("connection failed: %s" % e) + except ProtocolHandlerException as e: + self.logger.warn("MQTT connection failed: %s" % e) self.session.transitions.disconnect() raise ClientException("connection Failed: %s" % e) - if return_code is not CONNECTION_ACCEPTED: - self.session.transitions.disconnect() - self.logger.warn("Connection rejected with code '%s'" % return_code) - exc = ConnectException("Connection rejected by broker") - exc.return_code = return_code - raise exc - else: - # Handle MQTT protocol - self.session.reader = reader - self.session.writer = writer - self._handler = ClientProtocolHandler(self.session, self.plugins_manager, loop=self._loop) - yield from self._handler.start() - self.session.transitions.connect() - self.logger.debug("connected to %s:%s" % (self.session.remote_address, self.session.remote_port)) - - def build_connect_packet(self): - vh = ConnectVariableHeader() - payload = ConnectPayload() - - vh.keep_alive = self.session.keep_alive - vh.clean_session_flag = self.session.clean_session - vh.will_retain_flag = self.session.will_retain - payload.client_id = self.session.client_id - - if self.session.username: - vh.username_flag = True - payload.username = self.session.username - else: - vh.username_flag = False - - if self.session.password: - vh.password_flag = True - payload.password = self.session.password - else: - vh.password_flag = False - if self.session.will_flag: - vh.will_flag = True - vh.will_qos = self.session.will_qos - payload.will_message = self.session.will_message - payload.will_topic = self.session.will_topic - else: - vh.will_flag = False - - header = MQTTFixedHeader(CONNECT, 0x00) - packet = ConnectPacket(header, vh, payload) - return packet - @asyncio.coroutine def handle_connection_close(self): + self.logger.warning("Disconnectd from broker") self.logger.debug("Watch broker disconnection") yield from self._handler.wait_disconnect() self.logger.debug("Handle broker disconnection") yield from self._handler.stop() + self._handler.detach_stream() self.session.transitions.disconnect() - self._connection_closed_future.set_result(None) + if self.config.get('auto_reconnect', False): + self.logger.debug("Auto-reconnecting") + yield from self.reconnect() def _initsession( self, diff --git a/hbmqtt/mqtt/protocol/client_handler.py b/hbmqtt/mqtt/protocol/client_handler.py index 11790cc..5684ad4 100644 --- a/hbmqtt/mqtt/protocol/client_handler.py +++ b/hbmqtt/mqtt/protocol/client_handler.py @@ -26,7 +26,6 @@ class ClientProtocolHandler(ProtocolHandler): self._unsubscriptions_waiter = dict() self._disconnect_waiter = None self._pingresp_waiter = None - self._connack_waiter = None @asyncio.coroutine def start(self): @@ -78,22 +77,11 @@ class ClientProtocolHandler(ProtocolHandler): @asyncio.coroutine def mqtt_connect(self): - if self._connack_waiter and not self._connack_waiter.done(): - raise ProtocolHandlerException("A CONNECT request is already pending") connect_packet = self._build_connect_packet() yield from self._send_packet(connect_packet) - self._connack_waiter = futures.Future(loop=self._loop) - yield from self._connack_waiter - connack = self._connack_waiter.result() + connack = yield from ConnackPacket.from_stream(self.reader) return connack.return_code - @asyncio.coroutine - def handle_connack(self, connack: ConnackPacket): - if not self._connack_waiter or self._connack_waiter.done(): - self.logger.warning("Unexpected CONNACK received") - else: - self._connack_waiter.set_result(connack) - def handle_write_timeout(self): self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping()) diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 69b260d..254d41b 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -53,8 +53,8 @@ class ProtocolHandler: log = logging.getLogger(__name__) self.logger = logging.LoggerAdapter(log, {'client_id': session.client_id}) self.session = session - self.reader = session.reader - self.writer = session.writer + self.reader = None + self.writer = None self.plugins_manager = plugins_manager self.keepalive_timeout = self.session.keep_alive @@ -75,23 +75,26 @@ class ProtocolHandler: self._pubrel_waiters = dict() self._pubcomp_waiters = dict() - def attach_session(self, session: Session, reader:ReaderAdapter, writer:WriterAdapter): - if self.session: - raise ProtocolHandlerException("Handler already attached to session '%s'" % self.session.client_id) - self.session = session + def attach_stream(self, reader: ReaderAdapter, writer: WriterAdapter): + if self.reader or self.writer: + raise ProtocolHandlerException("Handler is already attached to an opened stream") self.reader = reader self.writer = writer - def detach_session(self): - if not self.session: - self.logger.warning("detach_session() called while no session attached to handler") + def detach_stream(self): + self.reader = None + self.writer = None + + def _is_attached(self): + if self.reader and self.writer: + return True else: - self.session = None - self.reader = None - self.writer = None + return False @asyncio.coroutine def start(self): + if not self._is_attached(): + raise ProtocolHandlerException("Handler is not attached to a stream") self._reader_ready = asyncio.Event(loop=self._loop) self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop) yield from asyncio.wait([self._reader_ready.wait()], loop=self._loop) diff --git a/hbmqtt/session.py b/hbmqtt/session.py index 22fa516..88fcc11 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -41,8 +41,6 @@ class Session: def __init__(self): self._init_states() - self.reader = None - self.writer = None self.remote_address = None self.remote_port = None self.client_id = None