kopia lustrzana https://github.com/Yakifo/amqtt
Test client reconnection (and messages buffering while connection lost)
rodzic
b14cd1b5a4
commit
0077c61943
171
hbmqtt/client.py
171
hbmqtt/client.py
|
@ -11,7 +11,8 @@ from transitions import MachineError
|
|||
|
||||
from hbmqtt.utils import not_in_dict_or_none
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.mqtt.connack import CONNECTION_ACCEPTED
|
||||
from hbmqtt.mqtt.connack import *
|
||||
from hbmqtt.mqtt.connect import *
|
||||
from hbmqtt.mqtt.protocol.client_handler import ClientProtocolHandler
|
||||
from hbmqtt.adapters import StreamReaderAdapter, StreamWriterAdapter, WebSocketsReader, WebSocketsWriter
|
||||
import websockets
|
||||
|
@ -124,24 +125,25 @@ class MQTTClient:
|
|||
raise ClientException("Unhandled exception: %s" % e)
|
||||
|
||||
@asyncio.coroutine
|
||||
def reconnect(self, cleansession=None):
|
||||
def reconnect(self, cleansession=False):
|
||||
if self.session.machine.state == 'connected':
|
||||
self.logger.warn("Client already connected")
|
||||
return CONNECTION_ACCEPTED
|
||||
|
||||
try:
|
||||
self.session.machine.connect()
|
||||
self.session.clclean_session = cleansession
|
||||
self.session.clean_session = cleansession
|
||||
self.logger.debug("Reconnecting with session parameters: %s" % self.session)
|
||||
|
||||
return_code = yield from self._connect_coro()
|
||||
asyncio.Task(self.handle_connection_close())
|
||||
|
||||
self.session.machine.connect_success()
|
||||
self._disconnect_task = asyncio.Task(self.handle_connection_close())
|
||||
return return_code
|
||||
except MachineError:
|
||||
msg = "Connect call incompatible with client current state '%s'" % self.session.machine.state
|
||||
self.logger.warn(msg)
|
||||
self.session.machine.connect_fail()
|
||||
self.session.machine.disconnect()
|
||||
raise ClientException(msg)
|
||||
except Exception as e:
|
||||
self.session.machine.connect_fail()
|
||||
self.session.machine.disconnect()
|
||||
self.logger.warn("Connection failed: %s " % e)
|
||||
raise ClientException("Connection failed: %s " % e)
|
||||
|
||||
|
@ -151,7 +153,11 @@ class MQTTClient:
|
|||
Send a MQTT ping request and wait for response
|
||||
:return: None
|
||||
"""
|
||||
self._handler.mqtt_ping()
|
||||
if self.session.machine.state == 'connected':
|
||||
yield from self._handler.mqtt_ping()
|
||||
else:
|
||||
self.logger.warn("MQTT PING request incompatible with current session state '%s'" %
|
||||
self.session.machine.state)
|
||||
|
||||
@asyncio.coroutine
|
||||
def publish(self, topic, message, qos=None, retain=None):
|
||||
|
@ -199,67 +205,106 @@ class MQTTClient:
|
|||
|
||||
@asyncio.coroutine
|
||||
def _connect_coro(self):
|
||||
sc = None
|
||||
reader = None
|
||||
writer = None
|
||||
kwargs = dict()
|
||||
|
||||
# Decode URI attributes
|
||||
uri_attributes = urlparse(self.session.broker_uri)
|
||||
scheme = uri_attributes.scheme
|
||||
self.session.username = uri_attributes.username
|
||||
self.session.password = uri_attributes.password
|
||||
self.session.remote_address = uri_attributes.hostname
|
||||
self.session.remote_port = uri_attributes.port
|
||||
if scheme in ('mqtt', 'mqtts') and not self.session.remote_port:
|
||||
self.session.remote_port = 8883 if scheme == 'mqtts' else 1883
|
||||
|
||||
if scheme in ('mqtts', 'wss'):
|
||||
if self.session.cafile is None or self.session.cafile == '':
|
||||
self.logger.warn("TLS connection can't be estabilshed, no certificate file (.cert) given")
|
||||
raise ClientException("TLS connection can't be estabilshed, no certificate file (.cert) given")
|
||||
sc = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH,
|
||||
cafile=self.session.cafile,
|
||||
capath=self.session.capath,
|
||||
cadata=self.session.cadata)
|
||||
if 'certfile' in self.config and 'keyfile' in self.config:
|
||||
sc.load_cert_chain(self.config['certfile'], self.config['keyfile'])
|
||||
kwargs['ssl'] = sc
|
||||
|
||||
# Open connection
|
||||
try:
|
||||
sc = None
|
||||
reader = None
|
||||
writer = None
|
||||
kwargs = dict()
|
||||
if scheme in ('mqtt', 'mqtts'):
|
||||
conn_reader, conn_writer = \
|
||||
yield from asyncio.open_connection(self.session.remote_address, self.session.remote_port, **kwargs)
|
||||
reader = StreamReaderAdapter(conn_reader)
|
||||
writer = StreamWriterAdapter(conn_writer)
|
||||
elif scheme in ('ws', 'wss'):
|
||||
websocket = yield from websockets.connect(self.session.broker_uri, subprotocols=['mqtt'], **kwargs)
|
||||
reader = WebSocketsReader(websocket)
|
||||
writer = WebSocketsWriter(websocket)
|
||||
except Exception as e:
|
||||
self.logger.warn("connection failed: %s" % e)
|
||||
self.session.machine.disconnect()
|
||||
raise ClientException("connection Failed: %s" % e)
|
||||
|
||||
# Decode URI attributes
|
||||
uri_attributes = urlparse(self.session.broker_uri)
|
||||
scheme = uri_attributes.scheme
|
||||
self.session.username = uri_attributes.username
|
||||
self.session.password = uri_attributes.password
|
||||
self.session.remote_address = uri_attributes.hostname
|
||||
self.session.remote_port = uri_attributes.port
|
||||
if scheme in ('mqtt', 'mqtts') and not self.session.remote_port:
|
||||
self.session.remote_port = 8883 if scheme == 'mqtts' else 1883
|
||||
|
||||
if scheme in ('mqtts', 'wss'):
|
||||
if self.session.cafile is None or self.session.cafile == '':
|
||||
self.logger.warn("TLS connection can't be estabilshed, no certificate file (.cert) given")
|
||||
raise ClientException("TLS connection can't be estabilshed, no certificate file (.cert) given")
|
||||
sc = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH,
|
||||
cafile=self.session.cafile,
|
||||
capath=self.session.capath,
|
||||
cadata=self.session.cadata)
|
||||
if 'certfile' in self.config and 'keyfile' in self.config:
|
||||
sc.load_cert_chain(self.config['certfile'], self.config['keyfile'])
|
||||
kwargs['ssl'] = sc
|
||||
|
||||
# Open connection
|
||||
try:
|
||||
if scheme in ('mqtt', 'mqtts'):
|
||||
conn_reader, conn_writer = \
|
||||
yield from asyncio.open_connection(self.session.remote_address, self.session.remote_port, **kwargs)
|
||||
reader = StreamReaderAdapter(conn_reader)
|
||||
writer = StreamWriterAdapter(conn_writer)
|
||||
elif scheme in ('ws', 'wss'):
|
||||
websocket = yield from websockets.connect(self.session.broker_uri, subprotocols=['mqtt'], **kwargs)
|
||||
reader = WebSocketsReader(websocket)
|
||||
writer = WebSocketsWriter(websocket)
|
||||
except Exception as e:
|
||||
self.logger.warn("connection failed: %s" % e)
|
||||
raise ClientException("connection Failed: %s" % e)
|
||||
|
||||
# Handle MQTT protocol
|
||||
self._handler = ClientProtocolHandler(reader, writer, loop=self._loop)
|
||||
self._handler.attach_to_session(self.session)
|
||||
yield from self._handler.start()
|
||||
|
||||
return_code = yield from self._handler.mqtt_connect()
|
||||
connect_packet = self.build_connect_packet()
|
||||
yield from connect_packet.to_stream(writer)
|
||||
self.logger.debug(" -out-> " + repr(connect_packet))
|
||||
try :
|
||||
connack = yield from ConnackPacket.from_stream(reader)
|
||||
self.logger.debug(" <-in-- " + repr(connack))
|
||||
return_code = connack.variable_header.return_code
|
||||
|
||||
if return_code is not CONNECTION_ACCEPTED:
|
||||
yield from self._handler.stop()
|
||||
self.session.machine.disconnect()
|
||||
self.logger.warn("Connection rejected with code '%s'" % return_code)
|
||||
else:
|
||||
# Handle MQTT protocol
|
||||
self._handler = ClientProtocolHandler(reader, writer, loop=self._loop)
|
||||
self._handler.attach_to_session(self.session)
|
||||
yield from self._handler.start()
|
||||
self.session.machine.connect()
|
||||
self.logger.debug("connected to %s:%s" % (self.session.remote_address, self.session.remote_port))
|
||||
return return_code
|
||||
except Exception as e:
|
||||
raise e
|
||||
self.logger.warn("connection failed: %s" % e)
|
||||
self.session.machine.disconnect()
|
||||
raise ClientException("connection Failed: %s" % e)
|
||||
|
||||
def build_connect_packet(self):
|
||||
vh = ConnectVariableHeader()
|
||||
payload = ConnectPayload()
|
||||
|
||||
vh.keep_alive = self.session.keep_alive
|
||||
vh.clean_session_flag = self.session.clean_session
|
||||
vh.will_retain_flag = self.session.will_retain
|
||||
payload.client_id = self.session.client_id
|
||||
|
||||
if self.session.username:
|
||||
vh.username_flag = True
|
||||
payload.username = self.session.username
|
||||
else:
|
||||
vh.username_flag = False
|
||||
|
||||
if self.session.password:
|
||||
vh.password_flag = True
|
||||
payload.password = self.session.password
|
||||
else:
|
||||
vh.password_flag = False
|
||||
if self.session.will_flag:
|
||||
vh.will_flag = True
|
||||
vh.will_qos = self.session.will_qos
|
||||
payload.will_message = self.session.will_message
|
||||
payload.will_topic = self.session.will_topic
|
||||
else:
|
||||
vh.will_flag = False
|
||||
|
||||
header = MQTTFixedHeader(CONNECT, 0x00)
|
||||
packet = ConnectPacket(header, vh, payload)
|
||||
return packet
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_connection_close(self):
|
||||
|
@ -267,8 +312,14 @@ class MQTTClient:
|
|||
yield from self._handler.wait_disconnect()
|
||||
self.logger.debug("Handle broker disconnection")
|
||||
yield from self._handler.stop()
|
||||
self._handler.detach_from_session()
|
||||
self.session.machine.disconnect()
|
||||
# while self.session.machine.state != 'connected':
|
||||
# yield from asyncio.sleep(2)
|
||||
# self.logger.debug("Trying reconnect")
|
||||
# try:
|
||||
# yield from self.reconnect()
|
||||
# except ClientException:
|
||||
# self.logger.warn("Reconnect failed")
|
||||
|
||||
def _initsession(
|
||||
self,
|
||||
|
|
|
@ -20,7 +20,6 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None):
|
||||
super().__init__(reader, writer, loop)
|
||||
self._ping_task = None
|
||||
self._connack_waiter = None
|
||||
self._pingresp_queue = asyncio.Queue()
|
||||
self._subscriptions_waiter = dict()
|
||||
self._unsubscriptions_waiter = dict()
|
||||
|
@ -96,49 +95,6 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
except KeyError as ke:
|
||||
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_connect(self):
|
||||
def build_connect_packet(session):
|
||||
vh = ConnectVariableHeader()
|
||||
payload = ConnectPayload()
|
||||
|
||||
vh.keep_alive = session.keep_alive
|
||||
vh.clean_session_flag = session.clean_session
|
||||
vh.will_retain_flag = session.will_retain
|
||||
payload.client_id = session.client_id
|
||||
|
||||
if session.username:
|
||||
vh.username_flag = True
|
||||
payload.username = session.username
|
||||
else:
|
||||
vh.username_flag = False
|
||||
|
||||
if session.password:
|
||||
vh.password_flag = True
|
||||
payload.password = session.password
|
||||
else:
|
||||
vh.password_flag = False
|
||||
if session.will_flag:
|
||||
vh.will_flag = True
|
||||
vh.will_qos = session.will_qos
|
||||
payload.will_message = session.will_message
|
||||
payload.will_topic = session.will_topic
|
||||
else:
|
||||
vh.will_flag = False
|
||||
|
||||
header = MQTTFixedHeader(CONNECT, 0x00)
|
||||
packet = ConnectPacket(header, vh, payload)
|
||||
return packet
|
||||
|
||||
packet = build_connect_packet(self.session)
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
self._connack_waiter = futures.Future(loop=self._loop)
|
||||
return (yield from self._connack_waiter)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_connack(self, connack: ConnackPacket):
|
||||
self._connack_waiter.set_result(connack.variable_header.return_code)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_disconnect(self):
|
||||
# yield from self.outgoing_queue.join() To be used in Python 3.5
|
||||
|
|
|
@ -80,7 +80,6 @@ class ProtocolHandler:
|
|||
ack_packets = []
|
||||
for packet_id in self.session.outgoing_msg:
|
||||
message = self.session.outgoing_msg[packet_id]
|
||||
self.logger.debug(message.state)
|
||||
if message.is_new() or message.is_published():
|
||||
self.logger.debug("Retrying publish message Id=%d", packet_id)
|
||||
message.publish_packet.dup_flag = True
|
||||
|
|
|
@ -51,6 +51,7 @@ class Session:
|
|||
self.machine.add_transition(trigger='connect', source='disconnected', dest='connected')
|
||||
self.machine.add_transition(trigger='disconnect', source='connected', dest='disconnected')
|
||||
self.machine.add_transition(trigger='disconnect', source='new', dest='disconnected')
|
||||
self.machine.add_transition(trigger='disconnect', source='disconnected', dest='disconnected')
|
||||
|
||||
@property
|
||||
def next_packet_id(self):
|
||||
|
|
Ładowanie…
Reference in New Issue