Add broker test class + fixes in both client and broker connection management

pull/8/head
Nico 2015-10-07 22:42:04 +02:00
rodzic 58e6069656
commit 3acec1d606
10 zmienionych plików z 327 dodań i 191 usunięć

Wyświetl plik

@ -14,7 +14,7 @@ from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from hbmqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.connack import *
from hbmqtt.errors import HBMQTTException
from hbmqtt.errors import HBMQTTException, MQTTException
from hbmqtt.utils import format_client_message, gen_client_id
from hbmqtt.mqtt.packet import PUBLISH
from hbmqtt.codecs import int_to_bytes_str
@ -27,6 +27,10 @@ from hbmqtt.adapters import (
WebSocketsWriter)
from .plugins.manager import PluginManager, BaseContext
import sys
if sys.version_info < (3, 5):
from asyncio import async as ensure_future
_defaults = {
'timeout-disconnect-delay': 2,
@ -51,6 +55,8 @@ EVENT_BROKER_PRE_START = 'broker_pre_start'
EVENT_BROKER_POST_START = 'broker_post_start'
EVENT_BROKER_PRE_SHUTDOWN = 'broker_pre_shutdown'
EVENT_BROKER_POST_SHUTDOWN = 'broker_post_shutdown'
EVENT_BROKER_CLIENT_CONNECTED = 'broker_client_connected'
EVENT_BROKER_CLIENT_DISCONNECTED = 'broker_client_disconnected'
class BrokerException(BaseException):
@ -407,131 +413,74 @@ class Broker:
@asyncio.coroutine
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
# Wait for connection available
server = self._servers[listener_name]
# Wait for connection available on listener
server = self._servers.get(listener_name, None)
if not server:
raise BrokerException("Invalid listener name '%s'" % listener_name)
yield from server.acquire_connection()
remote_address, remote_port = writer.get_peer_info()
self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
# Wait for first packet and expect a CONNECT
connect = None
try:
connect = yield from ConnectPacket.from_stream(reader)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
self.check_connect(connect)
handler, client_session = yield from BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
except HBMQTTException as exc:
self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" %
(format_client_message(address=remote_address, port=remote_port), exc))
yield from writer.close()
self.logger.debug("Connection closed")
return
except BrokerException as be:
except MQTTException as me:
self.logger.error('Invalid connection from %s : %s' %
(format_client_message(address=remote_address, port=remote_port), be))
yield from writer.close()
self.logger.debug("Connection closed")
return
if connect.proto_name != "MQTT":
self.logger.warn('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.variable_header.protocol_name)
(format_client_message(address=remote_address, port=remote_port), me))
yield from writer.close()
self.logger.debug("Connection closed")
return
connack = None
if connect.proto_level != 4:
# only MQTT 3.1.1 supported
self.logger.error('Invalid protocol from %s: %d' %
(format_client_message(address=remote_address, port=remote_port),
connect.variable_header.protocol_level))
connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION) # [MQTT-3.2.2-4] session_parent=0
elif connect.username_flag and connect.username is None:
self.logger.error('Invalid username from %s' %
(format_client_message(address=remote_address, port=remote_port)))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.password_flag and connect.password is None:
self.logger.error('Invalid password %s' % (format_client_message(address=remote_address, port=remote_port)))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.clean_session_flag is False and connect.payload.client_id is None:
self.logger.error('[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' %
format_client_message(address=remote_address, port=remote_port))
connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
if connack is not None:
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
yield from connack.to_stream(writer)
yield from writer.close()
return
client_session = None
self.logger.debug("Clean session={0}".format(connect.clean_session_flag))
self.logger.debug("known sessions={0}".format(self._sessions))
client_id = connect.client_id
if connect.clean_session_flag:
if client_session.clean_session:
# Delete existing session and create a new one
if client_id is not None:
self.delete_session(client_id)
if client_session.client_id is not None:
self.delete_session(client_session.client_id)
else:
client_id = gen_client_id()
client_session = Session()
client_session.client_id = gen_client_id()
client_session.parent = 0
client_session.client_id = client_id
else:
# Get session from cache
if client_id in self._sessions:
self.logger.debug("Found old session %s" % repr(self._sessions[client_id]))
client_session = self._sessions[client_id]
if client_session.client_id in self._sessions:
self.logger.debug("Found old session %s" % repr(self._sessions[client_session.client_id]))
(client_session,) = self._sessions[client_session.client_id]
client_session.parent = 1
else:
client_session = Session()
client_session.client_id = client_id
client_session.parent = 0
client_session.remote_address = remote_address
client_session.remote_port = remote_port
client_session.clean_session = connect.clean_session_flag
client_session.will_flag = connect.will_flag
client_session.will_retain = connect.will_retain_flag
client_session.will_qos = connect.will_qos
client_session.will_topic = connect.will_topic
client_session.will_message = connect.will_message
client_session.username = connect.username
client_session.password = connect.password
if connect.keep_alive > 0:
client_session.keep_alive = connect.keep_alive + self.config['timeout-disconnect-delay']
else:
client_session.keep_alive = 0
if client_session.keep_alive > 0:
client_session.keep_alive += self.config['timeout-disconnect-delay']
self.logger.debug("Keep-alive timeout=%d" % client_session.keep_alive)
client_session.publish_retry_delay = self.config['publish-retry-delay']
handler.attach(client_session, reader, writer)
self._sessions[client_session.client_id] = (client_session, handler)
authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
if authenticated:
connack = ConnackPacket.build(client_session.parent, CONNECTION_ACCEPTED)
self.logger.info('%s : connection accepted' % format_client_message(session=client_session))
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack, session=client_session)
yield from connack.to_stream(writer)
else:
connack = ConnackPacket.build(client_session.parent, NOT_AUTHORIZED)
self.logger.info('%s : connection refused' % format_client_message(session=client_session))
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack, session=client_session)
yield from connack.to_stream(writer)
yield from handler.mqtt_connack_authorize(authenticated)
if not authenticated:
yield from writer.close()
return
client_session.transitions.connect()
handler = self._init_handler(client_session, reader, writer)
self._sessions[client_id] = (client_session, handler)
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, session=client_session)
self.logger.debug("%s Start messages handling" % client_session.client_id)
yield from handler.start()
self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize())
yield from self.publish_session_retained_messages(client_session)
self.logger.debug("%s Wait for disconnect" % client_session.client_id)
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
connected = True
disconnect_waiter = asyncio.Task(handler.wait_disconnect(), loop=self._loop)
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription(), loop=self._loop)
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop)
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop)
while connected:
done, pending = yield from asyncio.wait(
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
@ -586,6 +535,7 @@ class Broker:
# Acknowledge message delivery
yield from handler.mqtt_acknowledge_delivery(packet_id)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
disconnect_waiter.cancel()
subscribe_waiter.cancel()
unsubscribe_waiter.cancel()
wait_deliver.cancel()
@ -593,6 +543,7 @@ class Broker:
self.logger.debug("%s Client disconnecting" % client_session.client_id)
yield from self._stop_handler(handler)
client_session.transitions.disconnect()
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, session=client_session)
yield from writer.close()
self.logger.debug("%s Session disconnected" % client_session.client_id)
server.release_connection()
@ -602,8 +553,8 @@ class Broker:
Create a BrokerProtocolHandler and attach to a session
:return:
"""
handler = BrokerProtocolHandler(session, self.plugins_manager, self._loop)
handler.attach_stream(reader, writer)
handler = BrokerProtocolHandler(self.plugins_manager, self._loop)
handler.attach(session, reader, writer)
handler.on_packet_received.connect(self.sys_handle_packet_received)
handler.on_packet_sent.connect(self.sys_handle_packet_sent)
return handler
@ -620,17 +571,6 @@ class Broker:
except Exception as e:
self.logger.error(e)
def check_connect(self, connect: ConnectPacket):
if connect.payload.client_id is None:
raise BrokerException('[[MQTT-3.1.3-3]] : Client identifier must be present' )
if connect.variable_header.will_flag:
if connect.payload.will_topic is None or connect.payload.will_message is None:
raise BrokerException('will flag set, but will topic/message not present in payload')
if connect.variable_header.reserved_flag:
raise BrokerException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
@asyncio.coroutine
def authenticate(self, session: Session, listener):
"""
@ -653,13 +593,14 @@ class Broker:
session=session,
filter_plugins=auth_plugins)
auth_result = True
for plugin in returns:
res = returns[plugin]
if res is False:
auth_result = False
self.logger.debug("Authentication failed due to '%s' plugin result: %s" % (plugin.name, res))
else:
self.logger.debug("'%s' plugin result: %s" % (plugin.name, res))
if returns:
for plugin in returns:
res = returns[plugin]
if res is False:
auth_result = False
self.logger.debug("Authentication failed due to '%s' plugin result: %s" % (plugin.name, res))
else:
self.logger.debug("'%s' plugin result: %s" % (plugin.name, res))
# If all plugins returned True, authentication is success
return auth_result

Wyświetl plik

@ -19,13 +19,16 @@ from hbmqtt.mqtt.constants import *
import websockets
from websockets.uri import InvalidURI
from websockets.handshake import InvalidHandshake
from collections import deque
_defaults = {
'keep_alive': 10,
'ping_delay': 1,
'default_qos': 0,
'default_retain': False,
'auto_reconnect': True
'auto_reconnect': True,
'reconnect_max_interval': 10,
'reconnect_retries': 2,
}
@ -114,6 +117,7 @@ class MQTTClient:
context = ClientContext()
context.config = self.config
self.plugins_manager = PluginManager('hbmqtt.client.plugins', context)
self.client_tasks = deque()
@asyncio.coroutine
@ -133,7 +137,15 @@ class MQTTClient:
self.session = self._initsession(uri, cleansession, cafile, capath, cadata)
self.logger.debug("Connect to: %s" % uri)
return (yield from self._do_connect())
try:
return (yield from self._do_connect())
except BaseException as be:
self.logger.warning("Connection failed: %r" % be)
auto_reconnect = self.config.get('auto_reconnect', False)
if not auto_reconnect:
raise
else:
return (yield from self.reconnect())
@asyncio.coroutine
@mqtt_connected
@ -157,12 +169,30 @@ class MQTTClient:
if cleansession:
self.session.clean_session = cleansession
self.logger.debug("Reconnecting with session parameters: %s" % self.session)
return (yield from self._do_connect())
reconnect_max_interval = self.config.get('reconnect_max_interval', 10)
reconnect_retries = self.config.get('reconnect_retries', 5)
nb_attempt = 1
yield from asyncio.sleep(1, loop=self._loop)
while True:
try:
self.logger.debug("Reconnect attempt %d ..." % nb_attempt)
return (yield from self._do_connect())
except BaseException as e:
self.logger.warning("Reconnection attempt failed: %r" % e)
if nb_attempt > reconnect_retries:
self.logger.error("Maximum number of connection attempts reached. Reconnection aborted")
raise ConnectException("Too many connection attempts failed")
exp = 2 ** nb_attempt
delay = exp if exp < reconnect_max_interval else reconnect_max_interval
self.logger.debug("Waiting %d second before next attempt" % delay)
yield from asyncio.sleep(delay, loop=self._loop)
nb_attempt += 1
@asyncio.coroutine
def _do_connect(self):
return_code = yield from self._connect_coro()
self._disconnect_task = asyncio.Task(self.handle_connection_close(), loop=self._loop)
self._disconnect_task = asyncio.ensure_future(self.handle_connection_close(), loop=self._loop)
return return_code
@asyncio.coroutine
@ -213,6 +243,17 @@ class MQTTClient:
def unsubscribe(self, topics):
yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
@asyncio.coroutine
def deliver_message(self, timeout=None):
deliver_task = asyncio.ensure_future(self._handler.mqtt_deliver_next_message(), loop=self._loop)
self.client_tasks.append(deliver_task)
self.logger.debug("Waiting message delivery")
message = yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
if deliver_task.exception():
raise deliver_task.exception()
self.client_tasks.pop()
return message
@asyncio.coroutine
def _connect_coro(self):
kwargs = dict()
@ -235,8 +276,8 @@ class MQTTClient:
uri_attributes[3], uri_attributes[4], uri_attributes[5])
self.session.broker_uri = urlunparse(uri)
# Init protocol handler
if not self._handler:
self._handler = ClientProtocolHandler(self.session, self.plugins_manager, loop=self._loop)
#if not self._handler:
self._handler = ClientProtocolHandler(self.plugins_manager, loop=self._loop)
if secure:
if self.session.cafile is None or self.session.cafile == '':
@ -252,6 +293,8 @@ class MQTTClient:
kwargs['ssl'] = sc
try:
reader = None
writer = None
self._connected_state.clear()
# Open connection
if scheme in ('mqtt', 'mqtts'):
@ -270,7 +313,7 @@ class MQTTClient:
reader = WebSocketsReader(websocket)
writer = WebSocketsWriter(websocket)
# Start MQTT protocol
self._handler.attach_stream(reader, writer)
self._handler.attach(self.session, reader, writer)
return_code = yield from self._handler.mqtt_connect()
if return_code is not CONNECTION_ACCEPTED:
self.session.transitions.disconnect()
@ -293,23 +336,39 @@ class MQTTClient:
self.logger.warn("connection failed: invalid websocket handshake")
self.session.transitions.disconnect()
raise ConnectException("connection failed: invalid websocket handshake", ihs)
except ProtocolHandlerException as e:
self.logger.warn("MQTT connection failed: %s" % e)
except (ProtocolHandlerException, ConnectionError, OSError) as e:
self.logger.warn("MQTT connection failed: %r" % e)
self.session.transitions.disconnect()
raise ClientException("connection Failed: %s" % e)
raise ConnectException(e)
@asyncio.coroutine
def handle_connection_close(self):
self.logger.debug("Watch broker disconnection")
# Wait for disconnection from broker (like connection lost)
yield from self._handler.wait_disconnect()
self._connected_state.clear()
self.logger.warning("Disconnected from broker")
# Block client API
self._connected_state.clear()
# stop an clean handler
yield from self._handler.stop()
self._handler.detach_stream()
self._handler.detach()
self.session.transitions.disconnect()
if self.config.get('auto_reconnect', False):
# Try reconnection
self.logger.debug("Auto-reconnecting")
yield from self.reconnect()
try:
yield from self.reconnect()
except ConnectException:
# Cancel client pending tasks
while self.client_tasks:
self.client_tasks.popleft().set_exception(ClientException("Connection lost"))
else:
# Cancel client pending tasks
while self.client_tasks:
self.client_tasks.popleft().set_exception(ClientException("Connection lost"))
def _initsession(
self,

Wyświetl plik

@ -2,9 +2,11 @@
#
# See the file license.txt for copying permission.
import asyncio
import logging
from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler
from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.connack import *
from hbmqtt.mqtt.pingreq import PingReqPacket
from hbmqtt.mqtt.pingresp import PingRespPacket
from hbmqtt.mqtt.subscribe import SubscribePacket
@ -14,11 +16,14 @@ from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.utils import format_client_message
from hbmqtt.session import Session
from hbmqtt.plugins.manager import PluginManager
from hbmqtt.adapters import ReaderAdapter, WriterAdapter
from hbmqtt.errors import MQTTException
from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
class BrokerProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, plugins_manager: PluginManager, loop=None):
super().__init__(session, plugins_manager, loop)
def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
super().__init__(plugins_manager, session, loop)
self._disconnect_waiter = None
self._pending_subscriptions = asyncio.Queue(loop=self._loop)
self._pending_unsubscriptions = asyncio.Queue(loop=self._loop)
@ -97,3 +102,81 @@ class BrokerProtocolHandler(ProtocolHandler):
def mqtt_acknowledge_unsubscription(self, packet_id):
unsuback = UnsubackPacket.build(packet_id)
yield from self._send_packet(unsuback)
@asyncio.coroutine
def mqtt_connack_authorize(self, authorize: bool):
if authorize:
connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED)
else:
connack = ConnackPacket.build(self.session.parent, NOT_AUTHORIZED)
yield from self._send_packet(connack)
@classmethod
@asyncio.coroutine
def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
"""
:param reader:
:param writer:
:param plugins_manager:
:param loop:
:return:
"""
log = logging.getLogger(__name__)
remote_address, remote_port = writer.get_peer_info()
connect = yield from ConnectPacket.from_stream(reader)
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
if connect.payload.client_id is None:
raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present' )
if connect.variable_header.will_flag:
if connect.payload.will_topic is None or connect.payload.will_message is None:
raise MQTTException('will flag set, but will topic/message not present in payload')
if connect.variable_header.reserved_flag:
raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
if connect.proto_name != "MQTT":
raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.variable_header.protocol_name)
connack = None
error_msg = None
if connect.proto_level != 4:
# only MQTT 3.1.1 supported
error_msg = 'Invalid protocol from %s: %d' % \
(format_client_message(address=remote_address, port=remote_port),
connect.variable_header.protocol_level)
connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION) # [MQTT-3.2.2-4] session_parent=0
elif connect.username_flag and connect.username is None:
error_msg = 'Invalid username from %s' % \
(format_client_message(address=remote_address, port=remote_port))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.password_flag and connect.password is None:
error_msg = 'Invalid password %s' % (format_client_message(address=remote_address, port=remote_port))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.clean_session_flag is False and connect.payload.client_id is None:
error_msg = '[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' % \
format_client_message(address=remote_address, port=remote_port)
connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
if connack is not None:
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
yield from connack.to_stream(writer)
yield from writer.close()
raise MQTTException(error_msg)
incoming_session = Session()
incoming_session.client_id = connect.client_id
incoming_session.clean_session = connect.clean_session_flag
incoming_session.will_flag = connect.will_flag
incoming_session.will_retain = connect.will_retain_flag
incoming_session.will_qos = connect.will_qos
incoming_session.will_topic = connect.will_topic
incoming_session.will_message = connect.will_message
incoming_session.username = connect.username
incoming_session.password = connect.password
if connect.keep_alive > 0:
incoming_session.keep_alive = connect.keep_alive
else:
incoming_session.keep_alive = 0
handler = cls(plugins_manager, loop)
return handler, incoming_session

Wyświetl plik

@ -2,7 +2,7 @@
#
# See the file license.txt for copying permission.
from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler, ProtocolHandlerException
from hbmqtt.mqtt.protocol.handler import ProtocolHandler, EVENT_MQTT_PACKET_RECEIVED
from hbmqtt.mqtt.packet import *
from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.mqtt.pingreq import PingReqPacket
@ -18,8 +18,8 @@ from hbmqtt.plugins.manager import PluginManager
class ClientProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, plugins_manager: PluginManager, loop=None):
super().__init__(session, plugins_manager, loop=loop)
def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
super().__init__(plugins_manager, session, loop=loop)
self._ping_task = None
self._pingresp_queue = asyncio.Queue(loop=self._loop)
self._subscriptions_waiter = dict()
@ -38,11 +38,15 @@ class ClientProtocolHandler(ProtocolHandler):
yield from super().stop()
if self._ping_task:
try:
self.logger.debug("Cancel ping task")
self._ping_task.cancel()
except Exception:
except BaseException:
pass
if self._pingresp_waiter:
self._pingresp_waiter.cancel()
if not self._disconnect_waiter.done():
self._disconnect_waiter.cancel()
self._disconnect_waiter = None
def _build_connect_packet(self):
vh = ConnectVariableHeader()
@ -80,10 +84,16 @@ class ClientProtocolHandler(ProtocolHandler):
connect_packet = self._build_connect_packet()
yield from self._send_packet(connect_packet)
connack = yield from ConnackPacket.from_stream(self.reader)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
return connack.return_code
def handle_write_timeout(self):
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
try:
self.logger.debug("Scheduling Ping")
if not self._ping_task:
self._ping_task = asyncio.ensure_future(self.mqtt_ping())
except BaseException as be:
self.logger.debug("Exception ignored in ping task: %r" % be)
def handle_read_timeout(self):
pass
@ -143,7 +153,6 @@ class ClientProtocolHandler(ProtocolHandler):
def mqtt_disconnect(self):
disconnect_packet = DisconnectPacket()
yield from self._send_packet(disconnect_packet)
self._connack_waiter = None
@asyncio.coroutine
def mqtt_ping(self):

Wyświetl plik

@ -50,18 +50,16 @@ class ProtocolHandler:
on_packet_sent = Signal()
on_packet_received = Signal()
def __init__(self, session: Session, plugins_manager: PluginManager, loop=None):
log = logging.getLogger(__name__)
self.logger = logging.LoggerAdapter(log, {'client_id': session.client_id})
self.session = session
def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
self.logger = logging.getLogger(__name__)
if session:
self._init_session(session)
else:
self.session = None
self.reader = None
self.writer = None
self.plugins_manager = plugins_manager
self.keepalive_timeout = self.session.keep_alive
if self.keepalive_timeout <= 0:
self.keepalive_timeout = None
if loop is None:
self._loop = asyncio.get_event_loop()
else:
@ -76,18 +74,29 @@ class ProtocolHandler:
self._pubrel_waiters = dict()
self._pubcomp_waiters = dict()
def attach_stream(self, reader: ReaderAdapter, writer: WriterAdapter):
if self.reader or self.writer:
raise ProtocolHandlerException("Handler is already attached to an opened stream")
def _init_session(self, session: Session):
assert session
log = logging.getLogger(__name__)
self.session = session
self.logger = logging.LoggerAdapter(log, {'client_id': self.session.client_id})
self.keepalive_timeout = self.session.keep_alive
if self.keepalive_timeout <= 0:
self.keepalive_timeout = None
def attach(self, session, reader: ReaderAdapter, writer: WriterAdapter):
if self.session:
raise ProtocolHandlerException("Handler is already attached to a session")
self._init_session(session)
self.reader = reader
self.writer = writer
def detach_stream(self):
def detach(self):
self.session = None
self.reader = None
self.writer = None
def _is_attached(self):
if self.reader and self.writer:
if self.session:
return True
else:
return False
@ -109,13 +118,14 @@ class ProtocolHandler:
@asyncio.coroutine
def stop(self):
# Stop messages flow waiter
self._reader_task.cancel()
self._stop_waiters()
if self._keepalive_task:
self._keepalive_task.cancel()
self.logger.debug("waiting for tasks to be stopped")
yield from asyncio.wait(
[self._reader_stopped.wait()], loop=self._loop)
if not self._reader_task.done():
self._reader_task.cancel()
yield from asyncio.wait(
[self._reader_stopped.wait()], loop=self._loop)
self.logger.debug("closing writer")
try:
yield from self.writer.close()
@ -392,8 +402,8 @@ class ProtocolHandler:
self.handle_read_timeout()
except NoDataException:
self.logger.debug("%s No data available" % self.session.client_id)
except Exception as e:
self.logger.warning("%s Unhandled exception in reader coro: %s" % (self.session.client_id, e))
except BaseException as e:
self.logger.warning("%s Unhandled exception in reader coro: %s" % (type(self).__name__, e))
break
yield from self.handle_connection_closed()
self._reader_stopped.set()
@ -412,7 +422,7 @@ class ProtocolHandler:
except ConnectionResetError as cre:
yield from self.handle_connection_closed()
raise
except Exception as e:
except BaseException as e:
self.logger.warning("Unhandled exception: %s" % e)
raise

Wyświetl plik

@ -113,3 +113,6 @@ class Session:
self.__dict__.update(state)
self.retained_messages = Queue()
self.delivered_message_queue = Queue()
def __eq__(self, other):
return self.client_id == other.client_id

Wyświetl plik

@ -1,7 +1,7 @@
import logging
import asyncio
from hbmqtt.client import MQTTClient
from hbmqtt.client import MQTTClient, ClientException
from hbmqtt.mqtt.constants import QOS_1, QOS_2
@ -17,22 +17,25 @@ C = MQTTClient()
@asyncio.coroutine
def uptime_coro():
yield from C.connect('mqtt://test.mosquitto.org:1883/')
yield from C.connect('mqtt://localhost/')
# Subscribe to '$SYS/broker/uptime' with QOS=1
yield from C.subscribe([
('$SYS/broker/uptime', QOS_1),
('$SYS/broker/load/#', QOS_2),
])
logger.info("Subscribed")
for i in range(1, 100):
packet = yield from C.deliver_message()
print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data)))
yield from C.unsubscribe(['$SYS/broker/uptime'])
logger.info("UnSubscribed")
yield from C.disconnect()
try:
for i in range(1, 100):
packet = yield from C.deliver_message()
print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data)))
yield from C.unsubscribe(['$SYS/broker/uptime'])
logger.info("UnSubscribed")
yield from C.disconnect()
except ClientException as ce:
logger.error("Client exception: %s" % ce)
if __name__ == '__main__':
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
logging.basicConfig(level=logging.DEBUG, format=formatter)
asyncio.get_event_loop().run_until_complete(uptime_coro())

Wyświetl plik

@ -35,8 +35,8 @@ class ProtocolHandlerTest(unittest.TestCase):
def test_init_handler(self):
s = Session()
handler = ProtocolHandler(s, self.plugin_manager, loop=self.loop)
self.assertIs(handler.session, s)
handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.assertIsNone(handler.session)
self.assertIs(handler._loop, self.loop)
self.check_empty_waiters(handler)
@ -51,8 +51,8 @@ class ProtocolHandlerTest(unittest.TestCase):
s = Session()
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888)
reader_adapted, writer_adapted = adapt(reader, writer)
handler = ProtocolHandler(s, self.plugin_manager)
handler.attach_stream(reader_adapted, writer_adapted)
handler = ProtocolHandler(self.plugin_manager)
handler.attach(s, reader_adapted, writer_adapted)
yield from self.start_handler(handler, s)
yield from self.stop_handler(handler, s)
future.set_result(True)
@ -79,15 +79,14 @@ class ProtocolHandlerTest(unittest.TestCase):
except Exception as ae:
future.set_exception(ae)
@asyncio.coroutine
def test_coro():
try:
s = Session()
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
handler = ProtocolHandler(s, self.plugin_manager, loop=self.loop)
handler.attach_stream(reader_adapted, writer_adapted)
handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
handler.attach(s, reader_adapted, writer_adapted)
yield from self.start_handler(handler, s)
message = yield from handler.mqtt_publish('/topic', b'test_data', QOS_0, False)
self.assertIsInstance(message, OutgoingApplicationMessage)
@ -130,8 +129,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False)
self.assertIsInstance(message, OutgoingApplicationMessage)
@ -182,8 +181,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False)
self.assertIsInstance(message, OutgoingApplicationMessage)
@ -220,8 +219,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage)
@ -264,8 +263,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage)
@ -385,8 +384,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.handler.start()
yield from self.stop_handler(self.handler, self.session)
if not future.done():
@ -433,8 +432,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.handler.start()
yield from self.stop_handler(self.handler, self.session)
if not future.done():

Wyświetl plik

@ -7,11 +7,26 @@ import asyncio
import logging
from hbmqtt.broker import *
from hbmqtt.mqtt.constants import *
from hbmqtt.client import MQTTClient
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
log = logging.getLogger(__name__)
test_config = {
'listeners': {
'default': {
'type': 'tcp',
'bind': 'localhost:1883',
'max_connections': 10
},
},
'sys_interval': 0,
'auth': {
'allow-anonymous': True,
}
}
class BrokerTest(unittest.TestCase):
def setUp(self):
@ -23,25 +38,12 @@ class BrokerTest(unittest.TestCase):
@patch('hbmqtt.broker.PluginManager')
def test_start_stop(self, MockPluginManager):
config = {
'listeners': {
'default': {
'type': 'tcp',
'bind': '0.0.0.0:1883',
'max_connections': 10
},
},
'sys_interval': 0,
'auth': {
'allow-anonymous': True,
}
}
def test_coro():
try:
broker = Broker(config, plugin_namespace="hbmqtt.test.plugins")
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
self.assertDictEqual(broker._sessions, {})
self.assertIn('default', broker._servers)
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_PRE_START),
@ -60,3 +62,29 @@ class BrokerTest(unittest.TestCase):
self.loop.run_until_complete(test_coro())
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager')
def test_client_connect(self, MockPluginManager):
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
client = MQTTClient()
ret = yield from client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
yield from client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_CLIENT_CONNECTED, session=client.session),
call().fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, session=client.session)], any_order=True)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
future = asyncio.Future(loop=self.loop)
self.loop.run_until_complete(test_coro())
if future.exception():
raise future.exception()

Wyświetl plik

@ -5,7 +5,7 @@ import unittest
import asyncio
import os
import logging
from hbmqtt.client import MQTTClient
from hbmqtt.client import MQTTClient, ConnectException
from hbmqtt.mqtt.constants import *
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -60,9 +60,10 @@ class MQTTClientTest(unittest.TestCase):
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
config = {'auto_reconnect': False}
client = MQTTClient(config=config)
ret = yield from client.connect('mqtt://localhost/')
except Exception as e:
except ConnectException as e:
future.set_result(True)
future = asyncio.Future(loop=self.loop)