Add broker test class + fixes in both client and broker connection management

pull/8/head
Nico 2015-10-07 22:42:04 +02:00
rodzic 58e6069656
commit 3acec1d606
10 zmienionych plików z 327 dodań i 191 usunięć

Wyświetl plik

@ -14,7 +14,7 @@ from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from hbmqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT from hbmqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
from hbmqtt.mqtt.connect import ConnectPacket from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.connack import * from hbmqtt.mqtt.connack import *
from hbmqtt.errors import HBMQTTException from hbmqtt.errors import HBMQTTException, MQTTException
from hbmqtt.utils import format_client_message, gen_client_id from hbmqtt.utils import format_client_message, gen_client_id
from hbmqtt.mqtt.packet import PUBLISH from hbmqtt.mqtt.packet import PUBLISH
from hbmqtt.codecs import int_to_bytes_str from hbmqtt.codecs import int_to_bytes_str
@ -27,6 +27,10 @@ from hbmqtt.adapters import (
WebSocketsWriter) WebSocketsWriter)
from .plugins.manager import PluginManager, BaseContext from .plugins.manager import PluginManager, BaseContext
import sys
if sys.version_info < (3, 5):
from asyncio import async as ensure_future
_defaults = { _defaults = {
'timeout-disconnect-delay': 2, 'timeout-disconnect-delay': 2,
@ -51,6 +55,8 @@ EVENT_BROKER_PRE_START = 'broker_pre_start'
EVENT_BROKER_POST_START = 'broker_post_start' EVENT_BROKER_POST_START = 'broker_post_start'
EVENT_BROKER_PRE_SHUTDOWN = 'broker_pre_shutdown' EVENT_BROKER_PRE_SHUTDOWN = 'broker_pre_shutdown'
EVENT_BROKER_POST_SHUTDOWN = 'broker_post_shutdown' EVENT_BROKER_POST_SHUTDOWN = 'broker_post_shutdown'
EVENT_BROKER_CLIENT_CONNECTED = 'broker_client_connected'
EVENT_BROKER_CLIENT_DISCONNECTED = 'broker_client_disconnected'
class BrokerException(BaseException): class BrokerException(BaseException):
@ -407,131 +413,74 @@ class Broker:
@asyncio.coroutine @asyncio.coroutine
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter): def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
# Wait for connection available # Wait for connection available on listener
server = self._servers[listener_name] server = self._servers.get(listener_name, None)
if not server:
raise BrokerException("Invalid listener name '%s'" % listener_name)
yield from server.acquire_connection() yield from server.acquire_connection()
remote_address, remote_port = writer.get_peer_info() remote_address, remote_port = writer.get_peer_info()
self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name)) self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
# Wait for first packet and expect a CONNECT # Wait for first packet and expect a CONNECT
connect = None
try: try:
connect = yield from ConnectPacket.from_stream(reader) handler, client_session = yield from BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
self.check_connect(connect)
except HBMQTTException as exc: except HBMQTTException as exc:
self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" % self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" %
(format_client_message(address=remote_address, port=remote_port), exc)) (format_client_message(address=remote_address, port=remote_port), exc))
yield from writer.close() yield from writer.close()
self.logger.debug("Connection closed") self.logger.debug("Connection closed")
return return
except BrokerException as be: except MQTTException as me:
self.logger.error('Invalid connection from %s : %s' % self.logger.error('Invalid connection from %s : %s' %
(format_client_message(address=remote_address, port=remote_port), be)) (format_client_message(address=remote_address, port=remote_port), me))
yield from writer.close()
self.logger.debug("Connection closed")
return
if connect.proto_name != "MQTT":
self.logger.warn('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.variable_header.protocol_name)
yield from writer.close() yield from writer.close()
self.logger.debug("Connection closed") self.logger.debug("Connection closed")
return return
connack = None if client_session.clean_session:
if connect.proto_level != 4:
# only MQTT 3.1.1 supported
self.logger.error('Invalid protocol from %s: %d' %
(format_client_message(address=remote_address, port=remote_port),
connect.variable_header.protocol_level))
connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION) # [MQTT-3.2.2-4] session_parent=0
elif connect.username_flag and connect.username is None:
self.logger.error('Invalid username from %s' %
(format_client_message(address=remote_address, port=remote_port)))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.password_flag and connect.password is None:
self.logger.error('Invalid password %s' % (format_client_message(address=remote_address, port=remote_port)))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.clean_session_flag is False and connect.payload.client_id is None:
self.logger.error('[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' %
format_client_message(address=remote_address, port=remote_port))
connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
if connack is not None:
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
yield from connack.to_stream(writer)
yield from writer.close()
return
client_session = None
self.logger.debug("Clean session={0}".format(connect.clean_session_flag))
self.logger.debug("known sessions={0}".format(self._sessions))
client_id = connect.client_id
if connect.clean_session_flag:
# Delete existing session and create a new one # Delete existing session and create a new one
if client_id is not None: if client_session.client_id is not None:
self.delete_session(client_id) self.delete_session(client_session.client_id)
else: else:
client_id = gen_client_id() client_session.client_id = gen_client_id()
client_session = Session()
client_session.parent = 0 client_session.parent = 0
client_session.client_id = client_id
else: else:
# Get session from cache # Get session from cache
if client_id in self._sessions: if client_session.client_id in self._sessions:
self.logger.debug("Found old session %s" % repr(self._sessions[client_id])) self.logger.debug("Found old session %s" % repr(self._sessions[client_session.client_id]))
client_session = self._sessions[client_id] (client_session,) = self._sessions[client_session.client_id]
client_session.parent = 1 client_session.parent = 1
else: else:
client_session = Session()
client_session.client_id = client_id
client_session.parent = 0 client_session.parent = 0
if client_session.keep_alive > 0:
client_session.remote_address = remote_address client_session.keep_alive += self.config['timeout-disconnect-delay']
client_session.remote_port = remote_port self.logger.debug("Keep-alive timeout=%d" % client_session.keep_alive)
client_session.clean_session = connect.clean_session_flag
client_session.will_flag = connect.will_flag
client_session.will_retain = connect.will_retain_flag
client_session.will_qos = connect.will_qos
client_session.will_topic = connect.will_topic
client_session.will_message = connect.will_message
client_session.username = connect.username
client_session.password = connect.password
if connect.keep_alive > 0:
client_session.keep_alive = connect.keep_alive + self.config['timeout-disconnect-delay']
else:
client_session.keep_alive = 0
client_session.publish_retry_delay = self.config['publish-retry-delay'] client_session.publish_retry_delay = self.config['publish-retry-delay']
handler.attach(client_session, reader, writer)
self._sessions[client_session.client_id] = (client_session, handler)
authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name]) authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
if authenticated: yield from handler.mqtt_connack_authorize(authenticated)
connack = ConnackPacket.build(client_session.parent, CONNECTION_ACCEPTED) if not authenticated:
self.logger.info('%s : connection accepted' % format_client_message(session=client_session))
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack, session=client_session)
yield from connack.to_stream(writer)
else:
connack = ConnackPacket.build(client_session.parent, NOT_AUTHORIZED)
self.logger.info('%s : connection refused' % format_client_message(session=client_session))
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack, session=client_session)
yield from connack.to_stream(writer)
yield from writer.close() yield from writer.close()
return return
client_session.transitions.connect() client_session.transitions.connect()
handler = self._init_handler(client_session, reader, writer) yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, session=client_session)
self._sessions[client_id] = (client_session, handler)
self.logger.debug("%s Start messages handling" % client_session.client_id) self.logger.debug("%s Start messages handling" % client_session.client_id)
yield from handler.start() yield from handler.start()
self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize()) self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize())
yield from self.publish_session_retained_messages(client_session) yield from self.publish_session_retained_messages(client_session)
self.logger.debug("%s Wait for disconnect" % client_session.client_id)
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect) # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
connected = True connected = True
disconnect_waiter = asyncio.Task(handler.wait_disconnect(), loop=self._loop) disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop)
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop) subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription(), loop=self._loop)
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop) unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop) wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop)
while connected: while connected:
done, pending = yield from asyncio.wait( done, pending = yield from asyncio.wait(
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver], [disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
@ -586,6 +535,7 @@ class Broker:
# Acknowledge message delivery # Acknowledge message delivery
yield from handler.mqtt_acknowledge_delivery(packet_id) yield from handler.mqtt_acknowledge_delivery(packet_id)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop) wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
disconnect_waiter.cancel()
subscribe_waiter.cancel() subscribe_waiter.cancel()
unsubscribe_waiter.cancel() unsubscribe_waiter.cancel()
wait_deliver.cancel() wait_deliver.cancel()
@ -593,6 +543,7 @@ class Broker:
self.logger.debug("%s Client disconnecting" % client_session.client_id) self.logger.debug("%s Client disconnecting" % client_session.client_id)
yield from self._stop_handler(handler) yield from self._stop_handler(handler)
client_session.transitions.disconnect() client_session.transitions.disconnect()
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, session=client_session)
yield from writer.close() yield from writer.close()
self.logger.debug("%s Session disconnected" % client_session.client_id) self.logger.debug("%s Session disconnected" % client_session.client_id)
server.release_connection() server.release_connection()
@ -602,8 +553,8 @@ class Broker:
Create a BrokerProtocolHandler and attach to a session Create a BrokerProtocolHandler and attach to a session
:return: :return:
""" """
handler = BrokerProtocolHandler(session, self.plugins_manager, self._loop) handler = BrokerProtocolHandler(self.plugins_manager, self._loop)
handler.attach_stream(reader, writer) handler.attach(session, reader, writer)
handler.on_packet_received.connect(self.sys_handle_packet_received) handler.on_packet_received.connect(self.sys_handle_packet_received)
handler.on_packet_sent.connect(self.sys_handle_packet_sent) handler.on_packet_sent.connect(self.sys_handle_packet_sent)
return handler return handler
@ -620,17 +571,6 @@ class Broker:
except Exception as e: except Exception as e:
self.logger.error(e) self.logger.error(e)
def check_connect(self, connect: ConnectPacket):
if connect.payload.client_id is None:
raise BrokerException('[[MQTT-3.1.3-3]] : Client identifier must be present' )
if connect.variable_header.will_flag:
if connect.payload.will_topic is None or connect.payload.will_message is None:
raise BrokerException('will flag set, but will topic/message not present in payload')
if connect.variable_header.reserved_flag:
raise BrokerException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
@asyncio.coroutine @asyncio.coroutine
def authenticate(self, session: Session, listener): def authenticate(self, session: Session, listener):
""" """
@ -653,13 +593,14 @@ class Broker:
session=session, session=session,
filter_plugins=auth_plugins) filter_plugins=auth_plugins)
auth_result = True auth_result = True
for plugin in returns: if returns:
res = returns[plugin] for plugin in returns:
if res is False: res = returns[plugin]
auth_result = False if res is False:
self.logger.debug("Authentication failed due to '%s' plugin result: %s" % (plugin.name, res)) auth_result = False
else: self.logger.debug("Authentication failed due to '%s' plugin result: %s" % (plugin.name, res))
self.logger.debug("'%s' plugin result: %s" % (plugin.name, res)) else:
self.logger.debug("'%s' plugin result: %s" % (plugin.name, res))
# If all plugins returned True, authentication is success # If all plugins returned True, authentication is success
return auth_result return auth_result

Wyświetl plik

@ -19,13 +19,16 @@ from hbmqtt.mqtt.constants import *
import websockets import websockets
from websockets.uri import InvalidURI from websockets.uri import InvalidURI
from websockets.handshake import InvalidHandshake from websockets.handshake import InvalidHandshake
from collections import deque
_defaults = { _defaults = {
'keep_alive': 10, 'keep_alive': 10,
'ping_delay': 1, 'ping_delay': 1,
'default_qos': 0, 'default_qos': 0,
'default_retain': False, 'default_retain': False,
'auto_reconnect': True 'auto_reconnect': True,
'reconnect_max_interval': 10,
'reconnect_retries': 2,
} }
@ -114,6 +117,7 @@ class MQTTClient:
context = ClientContext() context = ClientContext()
context.config = self.config context.config = self.config
self.plugins_manager = PluginManager('hbmqtt.client.plugins', context) self.plugins_manager = PluginManager('hbmqtt.client.plugins', context)
self.client_tasks = deque()
@asyncio.coroutine @asyncio.coroutine
@ -133,7 +137,15 @@ class MQTTClient:
self.session = self._initsession(uri, cleansession, cafile, capath, cadata) self.session = self._initsession(uri, cleansession, cafile, capath, cadata)
self.logger.debug("Connect to: %s" % uri) self.logger.debug("Connect to: %s" % uri)
return (yield from self._do_connect()) try:
return (yield from self._do_connect())
except BaseException as be:
self.logger.warning("Connection failed: %r" % be)
auto_reconnect = self.config.get('auto_reconnect', False)
if not auto_reconnect:
raise
else:
return (yield from self.reconnect())
@asyncio.coroutine @asyncio.coroutine
@mqtt_connected @mqtt_connected
@ -157,12 +169,30 @@ class MQTTClient:
if cleansession: if cleansession:
self.session.clean_session = cleansession self.session.clean_session = cleansession
self.logger.debug("Reconnecting with session parameters: %s" % self.session) self.logger.debug("Reconnecting with session parameters: %s" % self.session)
return (yield from self._do_connect()) reconnect_max_interval = self.config.get('reconnect_max_interval', 10)
reconnect_retries = self.config.get('reconnect_retries', 5)
nb_attempt = 1
yield from asyncio.sleep(1, loop=self._loop)
while True:
try:
self.logger.debug("Reconnect attempt %d ..." % nb_attempt)
return (yield from self._do_connect())
except BaseException as e:
self.logger.warning("Reconnection attempt failed: %r" % e)
if nb_attempt > reconnect_retries:
self.logger.error("Maximum number of connection attempts reached. Reconnection aborted")
raise ConnectException("Too many connection attempts failed")
exp = 2 ** nb_attempt
delay = exp if exp < reconnect_max_interval else reconnect_max_interval
self.logger.debug("Waiting %d second before next attempt" % delay)
yield from asyncio.sleep(delay, loop=self._loop)
nb_attempt += 1
@asyncio.coroutine @asyncio.coroutine
def _do_connect(self): def _do_connect(self):
return_code = yield from self._connect_coro() return_code = yield from self._connect_coro()
self._disconnect_task = asyncio.Task(self.handle_connection_close(), loop=self._loop) self._disconnect_task = asyncio.ensure_future(self.handle_connection_close(), loop=self._loop)
return return_code return return_code
@asyncio.coroutine @asyncio.coroutine
@ -213,6 +243,17 @@ class MQTTClient:
def unsubscribe(self, topics): def unsubscribe(self, topics):
yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id) yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
@asyncio.coroutine
def deliver_message(self, timeout=None):
deliver_task = asyncio.ensure_future(self._handler.mqtt_deliver_next_message(), loop=self._loop)
self.client_tasks.append(deliver_task)
self.logger.debug("Waiting message delivery")
message = yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
if deliver_task.exception():
raise deliver_task.exception()
self.client_tasks.pop()
return message
@asyncio.coroutine @asyncio.coroutine
def _connect_coro(self): def _connect_coro(self):
kwargs = dict() kwargs = dict()
@ -235,8 +276,8 @@ class MQTTClient:
uri_attributes[3], uri_attributes[4], uri_attributes[5]) uri_attributes[3], uri_attributes[4], uri_attributes[5])
self.session.broker_uri = urlunparse(uri) self.session.broker_uri = urlunparse(uri)
# Init protocol handler # Init protocol handler
if not self._handler: #if not self._handler:
self._handler = ClientProtocolHandler(self.session, self.plugins_manager, loop=self._loop) self._handler = ClientProtocolHandler(self.plugins_manager, loop=self._loop)
if secure: if secure:
if self.session.cafile is None or self.session.cafile == '': if self.session.cafile is None or self.session.cafile == '':
@ -252,6 +293,8 @@ class MQTTClient:
kwargs['ssl'] = sc kwargs['ssl'] = sc
try: try:
reader = None
writer = None
self._connected_state.clear() self._connected_state.clear()
# Open connection # Open connection
if scheme in ('mqtt', 'mqtts'): if scheme in ('mqtt', 'mqtts'):
@ -270,7 +313,7 @@ class MQTTClient:
reader = WebSocketsReader(websocket) reader = WebSocketsReader(websocket)
writer = WebSocketsWriter(websocket) writer = WebSocketsWriter(websocket)
# Start MQTT protocol # Start MQTT protocol
self._handler.attach_stream(reader, writer) self._handler.attach(self.session, reader, writer)
return_code = yield from self._handler.mqtt_connect() return_code = yield from self._handler.mqtt_connect()
if return_code is not CONNECTION_ACCEPTED: if return_code is not CONNECTION_ACCEPTED:
self.session.transitions.disconnect() self.session.transitions.disconnect()
@ -293,23 +336,39 @@ class MQTTClient:
self.logger.warn("connection failed: invalid websocket handshake") self.logger.warn("connection failed: invalid websocket handshake")
self.session.transitions.disconnect() self.session.transitions.disconnect()
raise ConnectException("connection failed: invalid websocket handshake", ihs) raise ConnectException("connection failed: invalid websocket handshake", ihs)
except ProtocolHandlerException as e: except (ProtocolHandlerException, ConnectionError, OSError) as e:
self.logger.warn("MQTT connection failed: %s" % e) self.logger.warn("MQTT connection failed: %r" % e)
self.session.transitions.disconnect() self.session.transitions.disconnect()
raise ClientException("connection Failed: %s" % e) raise ConnectException(e)
@asyncio.coroutine @asyncio.coroutine
def handle_connection_close(self): def handle_connection_close(self):
self.logger.debug("Watch broker disconnection") self.logger.debug("Watch broker disconnection")
# Wait for disconnection from broker (like connection lost)
yield from self._handler.wait_disconnect() yield from self._handler.wait_disconnect()
self._connected_state.clear()
self.logger.warning("Disconnected from broker") self.logger.warning("Disconnected from broker")
# Block client API
self._connected_state.clear()
# stop an clean handler
yield from self._handler.stop() yield from self._handler.stop()
self._handler.detach_stream() self._handler.detach()
self.session.transitions.disconnect() self.session.transitions.disconnect()
if self.config.get('auto_reconnect', False): if self.config.get('auto_reconnect', False):
# Try reconnection
self.logger.debug("Auto-reconnecting") self.logger.debug("Auto-reconnecting")
yield from self.reconnect() try:
yield from self.reconnect()
except ConnectException:
# Cancel client pending tasks
while self.client_tasks:
self.client_tasks.popleft().set_exception(ClientException("Connection lost"))
else:
# Cancel client pending tasks
while self.client_tasks:
self.client_tasks.popleft().set_exception(ClientException("Connection lost"))
def _initsession( def _initsession(
self, self,

Wyświetl plik

@ -2,9 +2,11 @@
# #
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
import asyncio import asyncio
import logging
from asyncio import futures from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler from hbmqtt.mqtt.protocol.handler import ProtocolHandler
from hbmqtt.mqtt.connect import ConnectPacket from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.connack import *
from hbmqtt.mqtt.pingreq import PingReqPacket from hbmqtt.mqtt.pingreq import PingReqPacket
from hbmqtt.mqtt.pingresp import PingRespPacket from hbmqtt.mqtt.pingresp import PingRespPacket
from hbmqtt.mqtt.subscribe import SubscribePacket from hbmqtt.mqtt.subscribe import SubscribePacket
@ -14,11 +16,14 @@ from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.utils import format_client_message from hbmqtt.utils import format_client_message
from hbmqtt.session import Session from hbmqtt.session import Session
from hbmqtt.plugins.manager import PluginManager from hbmqtt.plugins.manager import PluginManager
from hbmqtt.adapters import ReaderAdapter, WriterAdapter
from hbmqtt.errors import MQTTException
from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
class BrokerProtocolHandler(ProtocolHandler): class BrokerProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, plugins_manager: PluginManager, loop=None): def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
super().__init__(session, plugins_manager, loop) super().__init__(plugins_manager, session, loop)
self._disconnect_waiter = None self._disconnect_waiter = None
self._pending_subscriptions = asyncio.Queue(loop=self._loop) self._pending_subscriptions = asyncio.Queue(loop=self._loop)
self._pending_unsubscriptions = asyncio.Queue(loop=self._loop) self._pending_unsubscriptions = asyncio.Queue(loop=self._loop)
@ -97,3 +102,81 @@ class BrokerProtocolHandler(ProtocolHandler):
def mqtt_acknowledge_unsubscription(self, packet_id): def mqtt_acknowledge_unsubscription(self, packet_id):
unsuback = UnsubackPacket.build(packet_id) unsuback = UnsubackPacket.build(packet_id)
yield from self._send_packet(unsuback) yield from self._send_packet(unsuback)
@asyncio.coroutine
def mqtt_connack_authorize(self, authorize: bool):
if authorize:
connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED)
else:
connack = ConnackPacket.build(self.session.parent, NOT_AUTHORIZED)
yield from self._send_packet(connack)
@classmethod
@asyncio.coroutine
def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
"""
:param reader:
:param writer:
:param plugins_manager:
:param loop:
:return:
"""
log = logging.getLogger(__name__)
remote_address, remote_port = writer.get_peer_info()
connect = yield from ConnectPacket.from_stream(reader)
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
if connect.payload.client_id is None:
raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present' )
if connect.variable_header.will_flag:
if connect.payload.will_topic is None or connect.payload.will_message is None:
raise MQTTException('will flag set, but will topic/message not present in payload')
if connect.variable_header.reserved_flag:
raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
if connect.proto_name != "MQTT":
raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.variable_header.protocol_name)
connack = None
error_msg = None
if connect.proto_level != 4:
# only MQTT 3.1.1 supported
error_msg = 'Invalid protocol from %s: %d' % \
(format_client_message(address=remote_address, port=remote_port),
connect.variable_header.protocol_level)
connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION) # [MQTT-3.2.2-4] session_parent=0
elif connect.username_flag and connect.username is None:
error_msg = 'Invalid username from %s' % \
(format_client_message(address=remote_address, port=remote_port))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.password_flag and connect.password is None:
error_msg = 'Invalid password %s' % (format_client_message(address=remote_address, port=remote_port))
connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD) # [MQTT-3.2.2-4] session_parent=0
elif connect.clean_session_flag is False and connect.payload.client_id is None:
error_msg = '[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' % \
format_client_message(address=remote_address, port=remote_port)
connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
if connack is not None:
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
yield from connack.to_stream(writer)
yield from writer.close()
raise MQTTException(error_msg)
incoming_session = Session()
incoming_session.client_id = connect.client_id
incoming_session.clean_session = connect.clean_session_flag
incoming_session.will_flag = connect.will_flag
incoming_session.will_retain = connect.will_retain_flag
incoming_session.will_qos = connect.will_qos
incoming_session.will_topic = connect.will_topic
incoming_session.will_message = connect.will_message
incoming_session.username = connect.username
incoming_session.password = connect.password
if connect.keep_alive > 0:
incoming_session.keep_alive = connect.keep_alive
else:
incoming_session.keep_alive = 0
handler = cls(plugins_manager, loop)
return handler, incoming_session

Wyświetl plik

@ -2,7 +2,7 @@
# #
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
from asyncio import futures from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler, ProtocolHandlerException from hbmqtt.mqtt.protocol.handler import ProtocolHandler, EVENT_MQTT_PACKET_RECEIVED
from hbmqtt.mqtt.packet import * from hbmqtt.mqtt.packet import *
from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.mqtt.pingreq import PingReqPacket from hbmqtt.mqtt.pingreq import PingReqPacket
@ -18,8 +18,8 @@ from hbmqtt.plugins.manager import PluginManager
class ClientProtocolHandler(ProtocolHandler): class ClientProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, plugins_manager: PluginManager, loop=None): def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
super().__init__(session, plugins_manager, loop=loop) super().__init__(plugins_manager, session, loop=loop)
self._ping_task = None self._ping_task = None
self._pingresp_queue = asyncio.Queue(loop=self._loop) self._pingresp_queue = asyncio.Queue(loop=self._loop)
self._subscriptions_waiter = dict() self._subscriptions_waiter = dict()
@ -38,11 +38,15 @@ class ClientProtocolHandler(ProtocolHandler):
yield from super().stop() yield from super().stop()
if self._ping_task: if self._ping_task:
try: try:
self.logger.debug("Cancel ping task")
self._ping_task.cancel() self._ping_task.cancel()
except Exception: except BaseException:
pass pass
if self._pingresp_waiter: if self._pingresp_waiter:
self._pingresp_waiter.cancel() self._pingresp_waiter.cancel()
if not self._disconnect_waiter.done():
self._disconnect_waiter.cancel()
self._disconnect_waiter = None
def _build_connect_packet(self): def _build_connect_packet(self):
vh = ConnectVariableHeader() vh = ConnectVariableHeader()
@ -80,10 +84,16 @@ class ClientProtocolHandler(ProtocolHandler):
connect_packet = self._build_connect_packet() connect_packet = self._build_connect_packet()
yield from self._send_packet(connect_packet) yield from self._send_packet(connect_packet)
connack = yield from ConnackPacket.from_stream(self.reader) connack = yield from ConnackPacket.from_stream(self.reader)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
return connack.return_code return connack.return_code
def handle_write_timeout(self): def handle_write_timeout(self):
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping()) try:
self.logger.debug("Scheduling Ping")
if not self._ping_task:
self._ping_task = asyncio.ensure_future(self.mqtt_ping())
except BaseException as be:
self.logger.debug("Exception ignored in ping task: %r" % be)
def handle_read_timeout(self): def handle_read_timeout(self):
pass pass
@ -143,7 +153,6 @@ class ClientProtocolHandler(ProtocolHandler):
def mqtt_disconnect(self): def mqtt_disconnect(self):
disconnect_packet = DisconnectPacket() disconnect_packet = DisconnectPacket()
yield from self._send_packet(disconnect_packet) yield from self._send_packet(disconnect_packet)
self._connack_waiter = None
@asyncio.coroutine @asyncio.coroutine
def mqtt_ping(self): def mqtt_ping(self):

Wyświetl plik

@ -50,18 +50,16 @@ class ProtocolHandler:
on_packet_sent = Signal() on_packet_sent = Signal()
on_packet_received = Signal() on_packet_received = Signal()
def __init__(self, session: Session, plugins_manager: PluginManager, loop=None): def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
log = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.logger = logging.LoggerAdapter(log, {'client_id': session.client_id}) if session:
self.session = session self._init_session(session)
else:
self.session = None
self.reader = None self.reader = None
self.writer = None self.writer = None
self.plugins_manager = plugins_manager self.plugins_manager = plugins_manager
self.keepalive_timeout = self.session.keep_alive
if self.keepalive_timeout <= 0:
self.keepalive_timeout = None
if loop is None: if loop is None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
else: else:
@ -76,18 +74,29 @@ class ProtocolHandler:
self._pubrel_waiters = dict() self._pubrel_waiters = dict()
self._pubcomp_waiters = dict() self._pubcomp_waiters = dict()
def attach_stream(self, reader: ReaderAdapter, writer: WriterAdapter): def _init_session(self, session: Session):
if self.reader or self.writer: assert session
raise ProtocolHandlerException("Handler is already attached to an opened stream") log = logging.getLogger(__name__)
self.session = session
self.logger = logging.LoggerAdapter(log, {'client_id': self.session.client_id})
self.keepalive_timeout = self.session.keep_alive
if self.keepalive_timeout <= 0:
self.keepalive_timeout = None
def attach(self, session, reader: ReaderAdapter, writer: WriterAdapter):
if self.session:
raise ProtocolHandlerException("Handler is already attached to a session")
self._init_session(session)
self.reader = reader self.reader = reader
self.writer = writer self.writer = writer
def detach_stream(self): def detach(self):
self.session = None
self.reader = None self.reader = None
self.writer = None self.writer = None
def _is_attached(self): def _is_attached(self):
if self.reader and self.writer: if self.session:
return True return True
else: else:
return False return False
@ -109,13 +118,14 @@ class ProtocolHandler:
@asyncio.coroutine @asyncio.coroutine
def stop(self): def stop(self):
# Stop messages flow waiter # Stop messages flow waiter
self._reader_task.cancel()
self._stop_waiters() self._stop_waiters()
if self._keepalive_task: if self._keepalive_task:
self._keepalive_task.cancel() self._keepalive_task.cancel()
self.logger.debug("waiting for tasks to be stopped") self.logger.debug("waiting for tasks to be stopped")
yield from asyncio.wait( if not self._reader_task.done():
[self._reader_stopped.wait()], loop=self._loop) self._reader_task.cancel()
yield from asyncio.wait(
[self._reader_stopped.wait()], loop=self._loop)
self.logger.debug("closing writer") self.logger.debug("closing writer")
try: try:
yield from self.writer.close() yield from self.writer.close()
@ -392,8 +402,8 @@ class ProtocolHandler:
self.handle_read_timeout() self.handle_read_timeout()
except NoDataException: except NoDataException:
self.logger.debug("%s No data available" % self.session.client_id) self.logger.debug("%s No data available" % self.session.client_id)
except Exception as e: except BaseException as e:
self.logger.warning("%s Unhandled exception in reader coro: %s" % (self.session.client_id, e)) self.logger.warning("%s Unhandled exception in reader coro: %s" % (type(self).__name__, e))
break break
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
self._reader_stopped.set() self._reader_stopped.set()
@ -412,7 +422,7 @@ class ProtocolHandler:
except ConnectionResetError as cre: except ConnectionResetError as cre:
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
raise raise
except Exception as e: except BaseException as e:
self.logger.warning("Unhandled exception: %s" % e) self.logger.warning("Unhandled exception: %s" % e)
raise raise

Wyświetl plik

@ -113,3 +113,6 @@ class Session:
self.__dict__.update(state) self.__dict__.update(state)
self.retained_messages = Queue() self.retained_messages = Queue()
self.delivered_message_queue = Queue() self.delivered_message_queue = Queue()
def __eq__(self, other):
return self.client_id == other.client_id

Wyświetl plik

@ -1,7 +1,7 @@
import logging import logging
import asyncio import asyncio
from hbmqtt.client import MQTTClient from hbmqtt.client import MQTTClient, ClientException
from hbmqtt.mqtt.constants import QOS_1, QOS_2 from hbmqtt.mqtt.constants import QOS_1, QOS_2
@ -17,22 +17,25 @@ C = MQTTClient()
@asyncio.coroutine @asyncio.coroutine
def uptime_coro(): def uptime_coro():
yield from C.connect('mqtt://test.mosquitto.org:1883/') yield from C.connect('mqtt://localhost/')
# Subscribe to '$SYS/broker/uptime' with QOS=1 # Subscribe to '$SYS/broker/uptime' with QOS=1
yield from C.subscribe([ yield from C.subscribe([
('$SYS/broker/uptime', QOS_1), ('$SYS/broker/uptime', QOS_1),
('$SYS/broker/load/#', QOS_2), ('$SYS/broker/load/#', QOS_2),
]) ])
logger.info("Subscribed") logger.info("Subscribed")
for i in range(1, 100): try:
packet = yield from C.deliver_message() for i in range(1, 100):
print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data))) packet = yield from C.deliver_message()
yield from C.unsubscribe(['$SYS/broker/uptime']) print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data)))
logger.info("UnSubscribed") yield from C.unsubscribe(['$SYS/broker/uptime'])
yield from C.disconnect() logger.info("UnSubscribed")
yield from C.disconnect()
except ClientException as ce:
logger.error("Client exception: %s" % ce)
if __name__ == '__main__': if __name__ == '__main__':
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter) logging.basicConfig(level=logging.DEBUG, format=formatter)
asyncio.get_event_loop().run_until_complete(uptime_coro()) asyncio.get_event_loop().run_until_complete(uptime_coro())

Wyświetl plik

@ -35,8 +35,8 @@ class ProtocolHandlerTest(unittest.TestCase):
def test_init_handler(self): def test_init_handler(self):
s = Session() s = Session()
handler = ProtocolHandler(s, self.plugin_manager, loop=self.loop) handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.assertIs(handler.session, s) self.assertIsNone(handler.session)
self.assertIs(handler._loop, self.loop) self.assertIs(handler._loop, self.loop)
self.check_empty_waiters(handler) self.check_empty_waiters(handler)
@ -51,8 +51,8 @@ class ProtocolHandlerTest(unittest.TestCase):
s = Session() s = Session()
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
handler = ProtocolHandler(s, self.plugin_manager) handler = ProtocolHandler(self.plugin_manager)
handler.attach_stream(reader_adapted, writer_adapted) handler.attach(s, reader_adapted, writer_adapted)
yield from self.start_handler(handler, s) yield from self.start_handler(handler, s)
yield from self.stop_handler(handler, s) yield from self.stop_handler(handler, s)
future.set_result(True) future.set_result(True)
@ -79,15 +79,14 @@ class ProtocolHandlerTest(unittest.TestCase):
except Exception as ae: except Exception as ae:
future.set_exception(ae) future.set_exception(ae)
@asyncio.coroutine @asyncio.coroutine
def test_coro(): def test_coro():
try: try:
s = Session() s = Session()
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
handler = ProtocolHandler(s, self.plugin_manager, loop=self.loop) handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
handler.attach_stream(reader_adapted, writer_adapted) handler.attach(s, reader_adapted, writer_adapted)
yield from self.start_handler(handler, s) yield from self.start_handler(handler, s)
message = yield from handler.mqtt_publish('/topic', b'test_data', QOS_0, False) message = yield from handler.mqtt_publish('/topic', b'test_data', QOS_0, False)
self.assertIsInstance(message, OutgoingApplicationMessage) self.assertIsInstance(message, OutgoingApplicationMessage)
@ -130,8 +129,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try: try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted) self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session) yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False) message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False)
self.assertIsInstance(message, OutgoingApplicationMessage) self.assertIsInstance(message, OutgoingApplicationMessage)
@ -182,8 +181,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try: try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted) self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session) yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False) message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False)
self.assertIsInstance(message, OutgoingApplicationMessage) self.assertIsInstance(message, OutgoingApplicationMessage)
@ -220,8 +219,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try: try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted) self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session) yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message() message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage) self.assertIsInstance(message, IncomingApplicationMessage)
@ -264,8 +263,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try: try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted) self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.start_handler(self.handler, self.session) yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message() message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage) self.assertIsInstance(message, IncomingApplicationMessage)
@ -385,8 +384,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try: try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted) self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.handler.start() yield from self.handler.start()
yield from self.stop_handler(self.handler, self.session) yield from self.stop_handler(self.handler, self.session)
if not future.done(): if not future.done():
@ -433,8 +432,8 @@ class ProtocolHandlerTest(unittest.TestCase):
try: try:
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop) reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer) reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.session, self.plugin_manager, loop=self.loop) self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach_stream(reader_adapted, writer_adapted) self.handler.attach(self.session, reader_adapted, writer_adapted)
yield from self.handler.start() yield from self.handler.start()
yield from self.stop_handler(self.handler, self.session) yield from self.stop_handler(self.handler, self.session)
if not future.done(): if not future.done():

Wyświetl plik

@ -7,11 +7,26 @@ import asyncio
import logging import logging
from hbmqtt.broker import * from hbmqtt.broker import *
from hbmqtt.mqtt.constants import * from hbmqtt.mqtt.constants import *
from hbmqtt.client import MQTTClient
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter) logging.basicConfig(level=logging.DEBUG, format=formatter)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
test_config = {
'listeners': {
'default': {
'type': 'tcp',
'bind': 'localhost:1883',
'max_connections': 10
},
},
'sys_interval': 0,
'auth': {
'allow-anonymous': True,
}
}
class BrokerTest(unittest.TestCase): class BrokerTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -23,25 +38,12 @@ class BrokerTest(unittest.TestCase):
@patch('hbmqtt.broker.PluginManager') @patch('hbmqtt.broker.PluginManager')
def test_start_stop(self, MockPluginManager): def test_start_stop(self, MockPluginManager):
config = {
'listeners': {
'default': {
'type': 'tcp',
'bind': '0.0.0.0:1883',
'max_connections': 10
},
},
'sys_interval': 0,
'auth': {
'allow-anonymous': True,
}
}
def test_coro(): def test_coro():
try: try:
broker = Broker(config, plugin_namespace="hbmqtt.test.plugins") broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
yield from broker.start() yield from broker.start()
self.assertTrue(broker.transitions.is_started()) self.assertTrue(broker.transitions.is_started())
self.assertDictEqual(broker._sessions, {})
self.assertIn('default', broker._servers) self.assertIn('default', broker._servers)
MockPluginManager.assert_has_calls( MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_PRE_START), [call().fire_event(EVENT_BROKER_PRE_START),
@ -60,3 +62,29 @@ class BrokerTest(unittest.TestCase):
self.loop.run_until_complete(test_coro()) self.loop.run_until_complete(test_coro())
if future.exception(): if future.exception():
raise future.exception() raise future.exception()
@patch('hbmqtt.broker.PluginManager')
def test_client_connect(self, MockPluginManager):
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
client = MQTTClient()
ret = yield from client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
yield from client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_CLIENT_CONNECTED, session=client.session),
call().fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, session=client.session)], any_order=True)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
future = asyncio.Future(loop=self.loop)
self.loop.run_until_complete(test_coro())
if future.exception():
raise future.exception()

Wyświetl plik

@ -5,7 +5,7 @@ import unittest
import asyncio import asyncio
import os import os
import logging import logging
from hbmqtt.client import MQTTClient from hbmqtt.client import MQTTClient, ConnectException
from hbmqtt.mqtt.constants import * from hbmqtt.mqtt.constants import *
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -60,9 +60,10 @@ class MQTTClientTest(unittest.TestCase):
@asyncio.coroutine @asyncio.coroutine
def test_coro(): def test_coro():
try: try:
client = MQTTClient() config = {'auto_reconnect': False}
client = MQTTClient(config=config)
ret = yield from client.connect('mqtt://localhost/') ret = yield from client.connect('mqtt://localhost/')
except Exception as e: except ConnectException as e:
future.set_result(True) future.set_result(True)
future = asyncio.Future(loop=self.loop) future = asyncio.Future(loop=self.loop)