Revert to 3.4 coroutine syntax

pull/8/head
Nico 2015-11-01 15:58:20 +01:00
rodzic 8d7fbaaff4
commit 4dcf8eb477
33 zmienionych plików z 764 dodań i 567 usunięć

Wyświetl plik

@ -15,7 +15,8 @@ class ReaderAdapter:
Reader adapters are used to adapt read operations on the network depending on the protocol used
"""
async def read(self, n=-1) -> bytes:
@asyncio.coroutine
def read(self, n=-1) -> bytes:
"""
Read up to n bytes. If n is not provided, or set to -1, read until EOF and return all read bytes.
If the EOF was received and the internal buffer is empty, return an empty bytes object.
@ -40,7 +41,8 @@ class WriterAdapter:
write some data to the protocol layer
"""
async def drain(self):
@asyncio.coroutine
def drain(self):
"""
Let the write buffer of the underlying transport a chance to be flushed.
"""
@ -50,7 +52,8 @@ class WriterAdapter:
Return peer socket info (remote address and remote port as tuple
"""
async def close(self):
@asyncio.coroutine
def close(self):
"""
Close the protocol connection
"""
@ -65,19 +68,21 @@ class WebSocketsReader(ReaderAdapter):
self._protocol = protocol
self._stream = io.BytesIO(b'')
async def read(self, n=-1) -> bytes:
await self._feed_buffer(n)
@asyncio.coroutine
def read(self, n=-1) -> bytes:
yield from self._feed_buffer(n)
data = self._stream.read(n)
return data
async def _feed_buffer(self, n=1):
@asyncio.coroutine
def _feed_buffer(self, n=1):
"""
Feed the data buffer by reading a Websocket message.
:param n: if given, feed buffer until it contains at least n bytes
"""
buffer = bytearray(self._stream.read())
while len(buffer) < n:
message = await self._protocol.recv()
message = yield from self._protocol.recv()
if message is None:
break
if not isinstance(message, bytes):
@ -101,21 +106,23 @@ class WebSocketsWriter(WriterAdapter):
"""
self._stream.write(data)
async def drain(self):
@asyncio.coroutine
def drain(self):
"""
Let the write buffer of the underlying transport a chance to be flushed.
"""
data = self._stream.getvalue()
if len(data):
await self._protocol.send(data)
yield from self._protocol.send(data)
self._stream = io.BytesIO(b'')
def get_peer_info(self):
extra_info = self._protocol.writer.get_extra_info('peername')
return extra_info[0], extra_info[1]
async def close(self):
await self._protocol.close()
@asyncio.coroutine
def close(self):
yield from self._protocol.close()
class StreamReaderAdapter(ReaderAdapter):
@ -127,8 +134,9 @@ class StreamReaderAdapter(ReaderAdapter):
def __init__(self, reader: StreamReader):
self._reader = reader
async def read(self, n=-1) -> bytes:
return await self._reader.read(n)
@asyncio.coroutine
def read(self, n=-1) -> bytes:
return (yield from self._reader.read(n))
def feed_eof(self):
return self._reader.feed_eof()
@ -147,15 +155,17 @@ class StreamWriterAdapter(WriterAdapter):
def write(self, data):
self._writer.write(data)
async def drain(self):
await self._writer.drain()
@asyncio.coroutine
def drain(self):
yield from self._writer.drain()
def get_peer_info(self):
extra_info = self._writer.get_extra_info('peername')
return extra_info[0], extra_info[1]
async def close(self):
await self._writer.drain()
@asyncio.coroutine
def close(self):
yield from self._writer.drain()
if self._writer.can_write_eof():
self._writer.write_eof()
self._writer.close()
@ -169,7 +179,8 @@ class BufferReader(ReaderAdapter):
def __init__(self, buffer: bytes):
self._stream = io.BytesIO(buffer)
async def read(self, n=-1) -> bytes:
@asyncio.coroutine
def read(self, n=-1) -> bytes:
return self._stream.read(n)
@ -187,7 +198,8 @@ class BufferWriter(WriterAdapter):
"""
self._stream.write(data)
async def drain(self):
@asyncio.coroutine
def drain(self):
pass
def get_buffer(self):
@ -196,5 +208,6 @@ class BufferWriter(WriterAdapter):
def get_peer_info(self):
return "BufferWriter", 0
async def close(self):
@asyncio.coroutine
def close(self):
self._stream.close()

Wyświetl plik

@ -72,9 +72,10 @@ class Server:
else:
self.semaphore = None
async def acquire_connection(self):
@asyncio.coroutine
def acquire_connection(self):
if self.semaphore:
await self.semaphore.acquire()
yield from self.semaphore.acquire()
self.conn_count += 1
if self.max_connections > 0:
self.logger.info("Listener '%s': %d/%d connections acquired" %
@ -94,10 +95,11 @@ class Server:
self.logger.info("Listener '%s': %d connections acquired" %
(self.listener_name, self.conn_count))
async def close_instance(self):
@asyncio.coroutine
def close_instance(self):
if self.instance:
self.instance.close()
await self.instance.wait_closed()
yield from self.instance.wait_closed()
class BrokerContext(BaseContext):
@ -110,8 +112,9 @@ class BrokerContext(BaseContext):
self.config = None
self._broker_instance = broker
async def broadcast_message(self, topic, data, qos=None):
await self._broker_instance.internal_message_broadcast(topic, data, qos)
@asyncio.coroutine
def broadcast_message(self, topic, data, qos=None):
yield from self._broker_instance.internal_message_broadcast(topic, data, qos)
def retain_message(self, topic_name, data, qos=None):
self._broker_instance.retain_message(None, topic_name, data, qos)
@ -220,7 +223,8 @@ class Broker:
self.transitions.add_transition(trigger='stopping_failure', source='stopping', dest='not_stopped')
self.transitions.add_transition(trigger='start', source='stopped', dest='starting')
async def start(self):
@asyncio.coroutine
def start(self):
try:
self._sessions = dict()
self._subscriptions = dict()
@ -231,7 +235,7 @@ class Broker:
self.logger.warn("[WARN-0001] Invalid method call at this moment: %s" % me)
raise BrokerException("Broker instance can't be started: %s" % me)
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
yield from self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
try:
# Start network listeners
for listener_name in self.listeners_config:
@ -259,7 +263,7 @@ class Broker:
if listener['type'] == 'tcp':
address, port = listener['bind'].split(':')
cb_partial = partial(self.stream_connected, listener_name=listener_name)
instance = await asyncio.start_server(cb_partial,
instance = yield from asyncio.start_server(cb_partial,
address,
port,
ssl=sc,
@ -268,14 +272,14 @@ class Broker:
elif listener['type'] == 'ws':
address, port = listener['bind'].split(':')
cb_partial = partial(self.ws_connected, listener_name=listener_name)
instance = await websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
self.logger.info("Listener '%s' bind to %s (max_connecionts=%d)" %
(listener_name, listener['bind'], max_connections))
self.transitions.starting_success()
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
yield from self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
#Start broadcast loop
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop(), loop=self._loop)
@ -286,7 +290,8 @@ class Broker:
self.transitions.starting_fail()
raise BrokerException("Broker instance can't be started: %s" % e)
async def shutdown(self):
@asyncio.coroutine
def shutdown(self):
try:
self._sessions = dict()
self._subscriptions = dict()
@ -297,7 +302,7 @@ class Broker:
raise BrokerException("Broker instance can't be stopped: %s" % me)
# Fire broker_shutdown event to plugins
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
yield from self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
# Stop broadcast loop
if self._broadcast_task:
@ -307,44 +312,48 @@ class Broker:
for listener_name in self._servers:
server = self._servers[listener_name]
await server.close_instance()
yield from server.close_instance()
self.logger.debug("Broker closing")
self.logger.info("Broker closed")
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
yield from self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
self.transitions.stopping_success()
async def internal_message_broadcast(self, topic, data, qos=None):
return await self._broadcast_message(None, topic, data)
@asyncio.coroutine
def internal_message_broadcast(self, topic, data, qos=None):
return (yield from self._broadcast_message(None, topic, data))
async def ws_connected(self, websocket, uri, listener_name):
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
@asyncio.coroutine
def ws_connected(self, websocket, uri, listener_name):
yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
async def stream_connected(self, reader, writer, listener_name):
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
@asyncio.coroutine
def stream_connected(self, reader, writer, listener_name):
yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
@asyncio.coroutine
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
# Wait for connection available on listener
server = self._servers.get(listener_name, None)
if not server:
raise BrokerException("Invalid listener name '%s'" % listener_name)
await server.acquire_connection()
yield from server.acquire_connection()
remote_address, remote_port = writer.get_peer_info()
self.logger.info("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
# Wait for first packet and expect a CONNECT
try:
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
handler, client_session = yield from BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
except HBMQTTException as exc:
self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" %
(format_client_message(address=remote_address, port=remote_port), exc))
await writer.close()
yield from writer.close()
self.logger.debug("Connection closed")
return
except MQTTException as me:
self.logger.error('Invalid connection from %s : %s' %
(format_client_message(address=remote_address, port=remote_port), me))
await writer.close()
yield from writer.close()
self.logger.debug("Connection closed")
return
@ -371,9 +380,9 @@ class Broker:
handler.attach(client_session, reader, writer)
self._sessions[client_session.client_id] = (client_session, handler)
authenticated = await self.authenticate(client_session, self.listeners_config[listener_name])
authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
if not authenticated:
await writer.close()
yield from writer.close()
return
while True:
@ -383,15 +392,15 @@ class Broker:
except MachineError:
self.logger.warning("Client %s is reconnecting too quickly, make it wait" % client_session.client_id)
# Wait a bit may be client is reconnecting too fast
await asyncio.sleep(1, loop=self._loop)
await handler.mqtt_connack_authorize(authenticated)
yield from asyncio.sleep(1, loop=self._loop)
yield from handler.mqtt_connack_authorize(authenticated)
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
self.logger.debug("%s Start messages handling" % client_session.client_id)
await handler.start()
yield from handler.start()
self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize())
await self.publish_session_retained_messages(client_session)
yield from self.publish_session_retained_messages(client_session)
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop)
@ -401,7 +410,7 @@ class Broker:
connected = True
while connected:
try:
done, pending = await asyncio.wait(
done, pending = yield from asyncio.wait(
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
return_when=asyncio.FIRST_COMPLETED, loop=self._loop)
if disconnect_waiter in done:
@ -413,7 +422,7 @@ class Broker:
if client_session.will_flag:
self.logger.debug("Client %s disconnected abnormally, sending will message" %
format_client_message(client_session))
await self._broadcast_message(
yield from self._broadcast_message(
client_session,
client_session.will_topic,
client_session.will_message,
@ -424,21 +433,21 @@ class Broker:
client_session.will_message,
client_session.will_qos)
self.logger.debug("%s Disconnecting session" % client_session.client_id)
await self._stop_handler(handler)
yield from self._stop_handler(handler)
client_session.transitions.disconnect()
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
await writer.close()
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
yield from writer.close()
connected = False
if unsubscribe_waiter in done:
self.logger.debug("%s handling unsubscription" % client_session.client_id)
unsubscription = unsubscribe_waiter.result()
for topic in unsubscription['topics']:
self._del_subscription(topic, client_session)
await self.plugins_manager.fire_event(
yield from self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
client_id=client_session.client_id,
topic=topic)
await handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
if subscribe_waiter in done:
self.logger.debug("%s handling subscription" % client_session.client_id)
@ -446,25 +455,25 @@ class Broker:
return_codes = []
for subscription in subscriptions['topics']:
return_codes.append(self.add_subscription(subscription, client_session))
await handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
for index, subscription in enumerate(subscriptions['topics']):
if return_codes[index] != 0x80:
await self.plugins_manager.fire_event(
yield from self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client_session.client_id,
topic=subscription[0],
qos=subscription[1])
await self.publish_retained_messages_for_subscription(subscription, client_session)
yield from self.publish_retained_messages_for_subscription(subscription, client_session)
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
self.logger.debug(repr(self._subscriptions))
if wait_deliver in done:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("%s handling message delivery" % client_session.client_id)
app_message = wait_deliver.result()
await self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
client_id=client_session.client_id,
message=app_message)
await self._broadcast_message(client_session, app_message.topic, app_message.data)
yield from self._broadcast_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
@ -489,18 +498,20 @@ class Broker:
handler.attach(session, reader, writer)
return handler
async def _stop_handler(self, handler):
@asyncio.coroutine
def _stop_handler(self, handler):
"""
Stop a running handler and detach if from the session
:param handler:
:return:
"""
try:
await handler.stop()
yield from handler.stop()
except Exception as e:
self.logger.error(e)
async def authenticate(self, session: Session, listener):
@asyncio.coroutine
def authenticate(self, session: Session, listener):
"""
This method call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True.
@ -516,7 +527,7 @@ class Broker:
auth_config = self.config.get('auth', None)
if auth_config:
auth_plugins = auth_config.get('plugins', None)
returns = await self.plugins_manager.map_plugin_coro(
returns = yield from self.plugins_manager.map_plugin_coro(
"authenticate",
session=session,
filter_plugins=auth_plugins)
@ -615,13 +626,14 @@ class Broker:
else:
return False
async def _broadcast_loop(self):
@asyncio.coroutine
def _broadcast_loop(self):
running_tasks = deque()
try:
while True:
while running_tasks and running_tasks[0].done():
running_tasks.popleft()
broadcast = await self._broadcast_queue.get()
broadcast = yield from self._broadcast_queue.get()
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("broadcasting %r" % broadcast)
for k_filter in self._subscriptions:
@ -645,13 +657,14 @@ class Broker:
broadcast['topic'], format_client_message(session=target_session)))
retained_message = RetainedApplicationMessage(
broadcast['session'], broadcast['topic'], broadcast['data'], qos)
await target_session.retained_messages.put(retained_message)
yield from target_session.retained_messages.put(retained_message)
except CancelledError:
# Wait until current broadcasting tasks end
if running_tasks:
await asyncio.wait(running_tasks, loop=self._loop)
yield from asyncio.wait(running_tasks, loop=self._loop)
async def _broadcast_message(self, session, topic, data, force_qos=None):
@asyncio.coroutine
def _broadcast_message(self, session, topic, data, force_qos=None):
broadcast = {
'session': session,
'topic': topic,
@ -659,23 +672,25 @@ class Broker:
}
if force_qos:
broadcast['qos'] = force_qos
await self._broadcast_queue.put(broadcast)
yield from self._broadcast_queue.put(broadcast)
async def publish_session_retained_messages(self, session):
@asyncio.coroutine
def publish_session_retained_messages(self, session):
self.logger.debug("Publishing %d messages retained for session %s" %
(session.retained_messages.qsize(), format_client_message(session=session))
)
publish_tasks = []
handler = self._get_handler(session)
while not session.retained_messages.empty():
retained = await session.retained_messages.get()
retained = yield from session.retained_messages.get()
publish_tasks.append(asyncio.ensure_future(
handler.mqtt_publish(
retained.topic, retained.data, retained.qos, True), loop=self._loop))
if publish_tasks:
await asyncio.wait(publish_tasks, loop=self._loop)
yield from asyncio.wait(publish_tasks, loop=self._loop)
async def publish_retained_messages_for_subscription(self, subscription, session):
@asyncio.coroutine
def publish_retained_messages_for_subscription(self, subscription, session):
self.logger.debug("Begin broadcasting messages retained due to subscription on '%s' from %s" %
(subscription[0], format_client_message(session=session)))
publish_tasks = []
@ -689,7 +704,7 @@ class Broker:
handler.mqtt_publish(
retained.topic, retained.data, subscription[1], True), loop=self._loop))
if publish_tasks:
await asyncio.wait(publish_tasks, loop=self._loop)
yield from asyncio.wait(publish_tasks, loop=self._loop)
self.logger.debug("End broadcasting messages retained due to subscription on '%s' from %s" %
(subscription[0], format_client_message(session=session)))

Wyświetl plik

@ -57,11 +57,12 @@ def mqtt_connected(func):
:param func: coroutine to be called once connected
:return: coroutine result
"""
async def wrapper(self, *args, **kwargs):
@asyncio.coroutine
def wrapper(self, *args, **kwargs):
if not self._connected_state.is_set():
base_logger.warning("Client not connected, waiting for it")
await self._connected_state.wait()
return await func(self, *args, **kwargs)
yield from self._connected_state.wait()
return (yield from func(self, *args, **kwargs))
return wrapper
@ -118,7 +119,8 @@ class MQTTClient:
self.client_tasks = deque()
async def connect(self,
@asyncio.coroutine
def connect(self,
uri=None,
cleansession=None,
cafile=None,
@ -135,28 +137,30 @@ class MQTTClient:
self.logger.debug("Connect to: %s" % uri)
try:
return await self._do_connect()
return (yield from self._do_connect())
except BaseException as be:
self.logger.warning("Connection failed: %r" % be)
auto_reconnect = self.config.get('auto_reconnect', False)
if not auto_reconnect:
raise
else:
return await self.reconnect()
return (yield from self.reconnect())
@mqtt_connected
async def disconnect(self):
@asyncio.coroutine
def disconnect(self):
if self.session.transitions.is_connected():
if not self._disconnect_task.done():
self._disconnect_task.cancel()
await self._handler.mqtt_disconnect()
yield from self._handler.mqtt_disconnect()
self._connected_state.clear()
await self._handler.stop()
yield from self._handler.stop()
self.session.transitions.disconnect()
else:
self.logger.warn("Client session is not currently connected, ignoring call")
async def reconnect(self, cleansession=None):
@asyncio.coroutine
def reconnect(self, cleansession=None):
if self.session.transitions.is_connected():
self.logger.warn("Client already connected")
return CONNECTION_ACCEPTED
@ -167,11 +171,11 @@ class MQTTClient:
reconnect_max_interval = self.config.get('reconnect_max_interval', 10)
reconnect_retries = self.config.get('reconnect_retries', 5)
nb_attempt = 1
await asyncio.sleep(1, loop=self._loop)
yield from asyncio.sleep(1, loop=self._loop)
while True:
try:
self.logger.debug("Reconnect attempt %d ..." % nb_attempt)
return (await self._do_connect())
return (yield from self._do_connect())
except BaseException as e:
self.logger.warning("Reconnection attempt failed: %r" % e)
if nb_attempt > reconnect_retries:
@ -180,29 +184,32 @@ class MQTTClient:
exp = 2 ** nb_attempt
delay = exp if exp < reconnect_max_interval else reconnect_max_interval
self.logger.debug("Waiting %d second before next attempt" % delay)
await asyncio.sleep(delay, loop=self._loop)
yield from asyncio.sleep(delay, loop=self._loop)
nb_attempt += 1
async def _do_connect(self):
return_code = await self._connect_coro()
@asyncio.coroutine
def _do_connect(self):
return_code = yield from self._connect_coro()
self._disconnect_task = asyncio.ensure_future(self.handle_connection_close(), loop=self._loop)
return return_code
@mqtt_connected
async def ping(self):
@asyncio.coroutine
def ping(self):
"""
Send a MQTT ping request and wait for response
:return: None
"""
if self.session.transitions.is_connected():
await self._handler.mqtt_ping()
yield from self._handler.mqtt_ping()
else:
self.logger.warn("MQTT PING request incompatible with current session state '%s'" %
self.session.transitions.state)
@mqtt_connected
async def publish(self, topic, message, qos=None, retain=None):
@asyncio.coroutine
def publish(self, topic, message, qos=None, retain=None):
def get_retain_and_qos():
if qos:
assert qos in (QOS_0, QOS_1, QOS_2)
@ -223,27 +230,31 @@ class MQTTClient:
pass
return _qos, _retain
(app_qos, app_retain) = get_retain_and_qos()
return await self._handler.mqtt_publish(topic, message, app_qos, app_retain)
return (yield from self._handler.mqtt_publish(topic, message, app_qos, app_retain))
@mqtt_connected
async def subscribe(self, topics):
return await self._handler.mqtt_subscribe(topics, self.session.next_packet_id)
@asyncio.coroutine
def subscribe(self, topics):
return (yield from self._handler.mqtt_subscribe(topics, self.session.next_packet_id))
@mqtt_connected
async def unsubscribe(self, topics):
await self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
@asyncio.coroutine
def unsubscribe(self, topics):
yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
async def deliver_message(self, timeout=None):
@asyncio.coroutine
def deliver_message(self, timeout=None):
deliver_task = asyncio.ensure_future(self._handler.mqtt_deliver_next_message(), loop=self._loop)
self.client_tasks.append(deliver_task)
self.logger.debug("Waiting message delivery")
await asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
if deliver_task.exception():
raise deliver_task.exception()
self.client_tasks.pop()
return deliver_task.result()
async def _connect_coro(self):
@asyncio.coroutine
def _connect_coro(self):
kwargs = dict()
# Decode URI attributes
@ -287,13 +298,13 @@ class MQTTClient:
# Open connection
if scheme in ('mqtt', 'mqtts'):
conn_reader, conn_writer = \
await asyncio.open_connection(
yield from asyncio.open_connection(
self.session.remote_address,
self.session.remote_port, loop=self._loop, **kwargs)
reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)
elif scheme in ('ws', 'wss'):
websocket = await websockets.connect(
websocket = yield from websockets.connect(
self.session.broker_uri,
subprotocols=['mqtt'],
loop=self._loop,
@ -302,7 +313,7 @@ class MQTTClient:
writer = WebSocketsWriter(websocket)
# Start MQTT protocol
self._handler.attach(self.session, reader, writer)
return_code = await self._handler.mqtt_connect()
return_code = yield from self._handler.mqtt_connect()
if return_code is not CONNECTION_ACCEPTED:
self.session.transitions.disconnect()
self.logger.warning("Connection rejected with code '%s'" % return_code)
@ -311,7 +322,7 @@ class MQTTClient:
raise exc
else:
# Handle MQTT protocol
await self._handler.start()
yield from self._handler.start()
self.session.transitions.connect()
self._connected_state.set()
self.logger.debug("connected to %s:%s" % (self.session.remote_address, self.session.remote_port))
@ -329,17 +340,18 @@ class MQTTClient:
self.session.transitions.disconnect()
raise ConnectException(e)
async def handle_connection_close(self):
@asyncio.coroutine
def handle_connection_close(self):
self.logger.debug("Watch broker disconnection")
# Wait for disconnection from broker (like connection lost)
await self._handler.wait_disconnect()
yield from self._handler.wait_disconnect()
self.logger.warning("Disconnected from broker")
# Block client API
self._connected_state.clear()
# stop an clean handler
await self._handler.stop()
yield from self._handler.stop()
self._handler.detach()
self.session.transitions.disconnect()
@ -347,7 +359,7 @@ class MQTTClient:
# Try reconnection
self.logger.debug("Auto-reconnecting")
try:
await self.reconnect()
yield from self.reconnect()
except ConnectException:
# Cancel client pending tasks
while self.client_tasks:

Wyświetl plik

@ -14,6 +14,7 @@ def bytes_to_hex_str(data):
"""
return '0x' + ''.join(format(b, '02x') for b in data)
def bytes_to_int(data):
"""
convert a sequence of bytes to an integer using big endian byte ordering
@ -25,6 +26,7 @@ def bytes_to_int(data):
except:
return data
def int_to_bytes(int_value: int, length: int) -> bytes:
"""
convert an integer to a sequence of bytes using big endian byte ordering
@ -38,28 +40,32 @@ def int_to_bytes(int_value: int, length: int) -> bytes:
fmt = "!H"
return pack(fmt, int_value)
async def read_or_raise(reader, n=-1):
@asyncio.coroutine
def read_or_raise(reader, n=-1):
"""
Read a given byte number from Stream. NoDataException is raised if read gives no data
:param reader: reader adapter
:param n: number of bytes to read
:return: bytes read
"""
data = await reader.read(n)
data = yield from reader.read(n)
if not data:
raise NoDataException("No more data")
return data
async def decode_string(reader) -> bytes:
@asyncio.coroutine
def decode_string(reader) -> bytes:
"""
Read a string from a reader and decode it according to MQTT string specification
:param reader: Stream reader
:return: UTF-8 string read from stream
"""
length_bytes = await read_or_raise(reader, 2)
length_bytes = yield from read_or_raise(reader, 2)
str_length = unpack("!H", length_bytes)
if str_length[0]:
byte_str = await read_or_raise(reader, str_length[0])
byte_str = yield from read_or_raise(reader, str_length[0])
try:
return byte_str.decode(encoding='utf-8')
except:
@ -67,15 +73,17 @@ async def decode_string(reader) -> bytes:
else:
return ''
async def decode_data_with_length(reader) -> bytes:
@asyncio.coroutine
def decode_data_with_length(reader) -> bytes:
"""
Read data from a reader. Data is prefixed with 2 bytes length
:param reader: Stream reader
:return: bytes read from stream (without length)
"""
length_bytes = await read_or_raise(reader, 2)
length_bytes = yield from read_or_raise(reader, 2)
bytes_length = unpack("!H", length_bytes)
data = await read_or_raise(reader, bytes_length[0])
data = yield from read_or_raise(reader, bytes_length[0])
return data
@ -89,13 +97,15 @@ def encode_data_with_length(data: bytes) -> bytes:
data_length = len(data)
return int_to_bytes(data_length, 2) + data
async def decode_packet_id(reader) -> int:
@asyncio.coroutine
def decode_packet_id(reader) -> int:
"""
Read a packet ID as 2-bytes int from stream according to MQTT specification (2.3.1)
:param reader: Stream reader
:return: Packet ID
"""
packet_id_bytes = await read_or_raise(reader, 2)
packet_id_bytes = yield from read_or_raise(reader, 2)
packet_id = unpack("!H", packet_id_bytes)
return packet_id[0]

Wyświetl plik

@ -1,6 +1,7 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import asyncio
from hbmqtt.mqtt.packet import CONNACK, MQTTPacket, MQTTFixedHeader, MQTTVariableHeader
from hbmqtt.codecs import int_to_bytes, read_or_raise, bytes_to_int
from hbmqtt.errors import HBMQTTException
@ -21,8 +22,9 @@ class ConnackVariableHeader(MQTTVariableHeader):
self.return_code = return_code
@classmethod
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
data = await read_or_raise(reader, 2)
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
data = yield from read_or_raise(reader, 2)
session_parent = data[0] & 0x01
return_code = bytes_to_int(data[1])
return cls(session_parent, return_code)

Wyświetl plik

@ -93,20 +93,21 @@ class ConnectVariableHeader(MQTTVariableHeader):
self.flags |= (val << 3)
@classmethod
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
# protocol name
protocol_name = await decode_string(reader)
protocol_name = yield from decode_string(reader)
# protocol level
protocol_level_byte = await read_or_raise(reader, 1)
protocol_level_byte = yield from read_or_raise(reader, 1)
protocol_level = bytes_to_int(protocol_level_byte)
# flags
flags_byte = await read_or_raise(reader, 1)
flags_byte = yield from read_or_raise(reader, 1)
flags = bytes_to_int(flags_byte)
# keep-alive
keep_alive_byte = await read_or_raise(reader, 2)
keep_alive_byte = yield from read_or_raise(reader, 2)
keep_alive = bytes_to_int(keep_alive_byte)
return cls(flags, keep_alive, protocol_name, protocol_level)
@ -140,33 +141,34 @@ class ConnectPayload(MQTTPayload):
format(self.client_id, self.will_topic, self.will_message, self.username, self.password)
@classmethod
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
variable_header: ConnectVariableHeader):
payload = cls()
# Client identifier
try:
payload.client_id = await decode_string(reader)
payload.client_id = yield from decode_string(reader)
except NoDataException:
payload.client_id = None
# Read will topic, username and password
if variable_header.will_flag:
try:
payload.will_topic = await decode_string(reader)
payload.will_message = await decode_data_with_length(reader)
payload.will_topic = yield from decode_string(reader)
payload.will_message = yield from decode_data_with_length(reader)
except NoDataException:
payload.will_topic = None
payload.will_message = None
if variable_header.username_flag:
try:
payload.username = await decode_string(reader)
payload.username = yield from decode_string(reader)
except NoDataException:
payload.username = None
if variable_header.password_flag:
try:
payload.password = await decode_string(reader)
payload.password = yield from decode_string(reader)
except NoDataException:
payload.password = None

Wyświetl plik

@ -58,7 +58,8 @@ class MQTTFixedHeader:
return out
async def to_stream(self, writer: WriterAdapter):
@asyncio.coroutine
def to_stream(self, writer: WriterAdapter):
writer.write(self.to_bytes())
@property
@ -66,12 +67,14 @@ class MQTTFixedHeader:
return len(self.to_bytes())
@classmethod
async def from_stream(cls, reader: ReaderAdapter):
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter):
"""
Read and decode MQTT message fixed header from stream
:return: FixedHeader instance
"""
async def decode_remaining_length():
@asyncio.coroutine
def decode_remaining_length():
"""
Decode message length according to MQTT specifications
:return:
@ -80,7 +83,7 @@ class MQTTFixedHeader:
value = 0
buffer = bytearray()
while True:
encoded_byte = await reader.read(1)
encoded_byte = yield from reader.read(1)
int_byte = unpack('!B', encoded_byte)
buffer.append(int_byte[0])
value += (int_byte[0] & 0x7f) * multiplier
@ -93,11 +96,11 @@ class MQTTFixedHeader:
return value
try:
byte1 = await read_or_raise(reader, 1)
byte1 = yield from read_or_raise(reader, 1)
int1 = unpack('!B', byte1)
msg_type = (int1[0] & 0xf0) >> 4
flags = int1[0] & 0x0f
remain_length = await decode_remaining_length()
remain_length = yield from decode_remaining_length()
return cls(msg_type, flags, remain_length)
except NoDataException:
@ -112,9 +115,10 @@ class MQTTVariableHeader:
def __init__(self):
pass
async def to_stream(self, writer: asyncio.StreamWriter):
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
await writer.drain()
yield from writer.drain()
def to_bytes(self) -> bytes:
"""
@ -127,7 +131,8 @@ class MQTTVariableHeader:
return len(self.to_bytes())
@classmethod
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
pass
@ -142,8 +147,9 @@ class PacketIdVariableHeader(MQTTVariableHeader):
return out
@classmethod
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
packet_id = await decode_packet_id(reader)
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
packet_id = yield from decode_packet_id(reader)
return cls(packet_id)
def __repr__(self):
@ -154,15 +160,17 @@ class MQTTPayload:
def __init__(self):
pass
async def to_stream(self, writer: asyncio.StreamWriter):
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
await writer.drain()
yield from writer.drain()
def to_bytes(self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader):
pass
@classmethod
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
variable_header: MQTTVariableHeader):
pass
@ -178,9 +186,10 @@ class MQTTPacket:
self.payload = payload
self.protocol_ts = None
async def to_stream(self, writer: asyncio.StreamWriter):
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
await writer.drain()
yield from writer.drain()
self.protocol_ts = datetime.now()
def to_bytes(self) -> bytes:
@ -199,16 +208,17 @@ class MQTTPacket:
return fixed_header_bytes + variable_header_bytes + payload_bytes
@classmethod
async def from_stream(cls, reader: ReaderAdapter, fixed_header=None, variable_header=None):
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter, fixed_header=None, variable_header=None):
if fixed_header is None:
fixed_header = await cls.FIXED_HEADER.from_stream(reader)
fixed_header = yield from cls.FIXED_HEADER.from_stream(reader)
if cls.VARIABLE_HEADER:
if variable_header is None:
variable_header = await cls.VARIABLE_HEADER.from_stream(reader, fixed_header)
variable_header = yield from cls.VARIABLE_HEADER.from_stream(reader, fixed_header)
else:
variable_header = None
if cls.PAYLOAD:
payload = await cls.PAYLOAD.from_stream(reader, fixed_header, variable_header)
payload = yield from cls.PAYLOAD.from_stream(reader, fixed_header, variable_header)
else:
payload = None

Wyświetl plik

@ -1,7 +1,7 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
from asyncio import futures, Queue
from hbmqtt.mqtt.protocol.handler import ProtocolHandler
from hbmqtt.mqtt.connect import ConnectPacket
@ -27,18 +27,21 @@ class BrokerProtocolHandler(ProtocolHandler):
self._pending_subscriptions = Queue(loop=self._loop)
self._pending_unsubscriptions = Queue(loop=self._loop)
async def start(self):
await super().start()
@asyncio.coroutine
def start(self):
yield from super().start()
if self._disconnect_waiter is None:
self._disconnect_waiter = futures.Future(loop=self._loop)
async def stop(self):
await super().stop()
@asyncio.coroutine
def stop(self):
yield from super().stop()
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None)
async def wait_disconnect(self):
return await self._disconnect_waiter
@asyncio.coroutine
def wait_disconnect(self):
return (yield from self._disconnect_waiter)
def handle_write_timeout(self):
pass
@ -47,16 +50,19 @@ class BrokerProtocolHandler(ProtocolHandler):
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None)
async def handle_disconnect(self, disconnect):
@asyncio.coroutine
def handle_disconnect(self, disconnect):
self.logger.debug("Client disconnecting")
if self._disconnect_waiter and not self._disconnect_waiter.done():
self.logger.debug("Setting waiter result to %r" % disconnect)
self._disconnect_waiter.set_result(disconnect)
async def handle_connection_closed(self):
await self.handle_disconnect(None)
@asyncio.coroutine
def handle_connection_closed(self):
yield from self.handle_disconnect(None)
async def handle_connect(self, connect: ConnectPacket):
@asyncio.coroutine
def handle_connect(self, connect: ConnectPacket):
# Broker handler shouldn't received CONNECT message during messages handling
# as CONNECT messages are managed by the broker on client connection
self.logger.error('%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling' %
@ -64,42 +70,51 @@ class BrokerProtocolHandler(ProtocolHandler):
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None)
async def handle_pingreq(self, pingreq: PingReqPacket):
await self._send_packet(PingRespPacket.build())
@asyncio.coroutine
def handle_pingreq(self, pingreq: PingReqPacket):
yield from self._send_packet(PingRespPacket.build())
async def handle_subscribe(self, subscribe: SubscribePacket):
@asyncio.coroutine
def handle_subscribe(self, subscribe: SubscribePacket):
subscription = {'packet_id': subscribe.variable_header.packet_id, 'topics': subscribe.payload.topics}
await self._pending_subscriptions.put(subscription)
yield from self._pending_subscriptions.put(subscription)
async def handle_unsubscribe(self, unsubscribe: UnsubscribePacket):
@asyncio.coroutine
def handle_unsubscribe(self, unsubscribe: UnsubscribePacket):
unsubscription = {'packet_id': unsubscribe.variable_header.packet_id, 'topics': unsubscribe.payload.topics}
await self._pending_unsubscriptions.put(unsubscription)
yield from self._pending_unsubscriptions.put(unsubscription)
async def get_next_pending_subscription(self):
subscription = await self._pending_subscriptions.get()
@asyncio.coroutine
def get_next_pending_subscription(self):
subscription = yield from self._pending_subscriptions.get()
return subscription
async def get_next_pending_unsubscription(self):
unsubscription = await self._pending_unsubscriptions.get()
@asyncio.coroutine
def get_next_pending_unsubscription(self):
unsubscription = yield from self._pending_unsubscriptions.get()
return unsubscription
async def mqtt_acknowledge_subscription(self, packet_id, return_codes):
@asyncio.coroutine
def mqtt_acknowledge_subscription(self, packet_id, return_codes):
suback = SubackPacket.build(packet_id, return_codes)
await self._send_packet(suback)
yield from self._send_packet(suback)
async def mqtt_acknowledge_unsubscription(self, packet_id):
@asyncio.coroutine
def mqtt_acknowledge_unsubscription(self, packet_id):
unsuback = UnsubackPacket.build(packet_id)
await self._send_packet(unsuback)
yield from self._send_packet(unsuback)
async def mqtt_connack_authorize(self, authorize: bool):
@asyncio.coroutine
def mqtt_connack_authorize(self, authorize: bool):
if authorize:
connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED)
else:
connack = ConnackPacket.build(self.session.parent, NOT_AUTHORIZED)
await self._send_packet(connack)
yield from self._send_packet(connack)
@classmethod
async def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
@asyncio.coroutine
def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
"""
:param reader:
@ -108,10 +123,9 @@ class BrokerProtocolHandler(ProtocolHandler):
:param loop:
:return:
"""
log = logging.getLogger(__name__)
remote_address, remote_port = writer.get_peer_info()
connect = await ConnectPacket.from_stream(reader)
await plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
connect = yield from ConnectPacket.from_stream(reader)
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
if connect.payload.client_id is None:
raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present' )
@ -144,9 +158,9 @@ class BrokerProtocolHandler(ProtocolHandler):
format_client_message(address=remote_address, port=remote_port)
connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
if connack is not None:
await plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
await connack.to_stream(writer)
await writer.close()
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
yield from connack.to_stream(writer)
yield from writer.close()
raise MQTTException(error_msg)
incoming_session = Session()

Wyświetl plik

@ -1,6 +1,7 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import asyncio
from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler, EVENT_MQTT_PACKET_RECEIVED
from hbmqtt.mqtt.packet import *
@ -27,13 +28,15 @@ class ClientProtocolHandler(ProtocolHandler):
self._disconnect_waiter = None
self._pingresp_waiter = None
async def start(self):
await super().start()
@asyncio.coroutine
def start(self):
yield from super().start()
if self._disconnect_waiter is None:
self._disconnect_waiter = futures.Future(loop=self._loop)
async def stop(self):
await super().stop()
@asyncio.coroutine
def stop(self):
yield from super().stop()
if self._ping_task:
try:
self.logger.debug("Cancel ping task")
@ -77,11 +80,12 @@ class ClientProtocolHandler(ProtocolHandler):
packet = ConnectPacket(vh=vh, payload=payload)
return packet
async def mqtt_connect(self):
@asyncio.coroutine
def mqtt_connect(self):
connect_packet = self._build_connect_packet()
await self._send_packet(connect_packet)
connack = await ConnackPacket.from_stream(self.reader)
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
yield from self._send_packet(connect_packet)
connack = yield from ConnackPacket.from_stream(self.reader)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
return connack.return_code
def handle_write_timeout(self):
@ -95,7 +99,8 @@ class ClientProtocolHandler(ProtocolHandler):
def handle_read_timeout(self):
pass
async def mqtt_subscribe(self, topics, packet_id):
@asyncio.coroutine
def mqtt_subscribe(self, topics, packet_id):
"""
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
:return:
@ -103,17 +108,18 @@ class ClientProtocolHandler(ProtocolHandler):
# Build and send SUBSCRIBE message
subscribe = SubscribePacket.build(topics, packet_id)
await self._send_packet(subscribe)
yield from self._send_packet(subscribe)
# Wait for SUBACK is received
waiter = futures.Future(loop=self._loop)
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
return_codes = await waiter
return_codes = yield from waiter
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
return return_codes
async def handle_suback(self, suback: SubackPacket):
@asyncio.coroutine
def handle_suback(self, suback: SubackPacket):
packet_id = suback.variable_header.packet_id
try:
waiter = self._subscriptions_waiter.get(packet_id)
@ -121,20 +127,22 @@ class ClientProtocolHandler(ProtocolHandler):
except KeyError as ke:
self.logger.warning("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
async def mqtt_unsubscribe(self, topics, packet_id):
@asyncio.coroutine
def mqtt_unsubscribe(self, topics, packet_id):
"""
:param topics: array of topics ['/a/b', ...]
:return:
"""
unsubscribe = UnsubscribePacket.build(topics, packet_id)
await self._send_packet(unsubscribe)
yield from self._send_packet(unsubscribe)
waiter = futures.Future(loop=self._loop)
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
await waiter
yield from waiter
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
async def handle_unsuback(self, unsuback: UnsubackPacket):
@asyncio.coroutine
def handle_unsuback(self, unsuback: UnsubackPacket):
packet_id = unsuback.variable_header.packet_id
try:
waiter = self._unsubscriptions_waiter.get(packet_id)
@ -142,25 +150,30 @@ class ClientProtocolHandler(ProtocolHandler):
except KeyError as ke:
self.logger.warning("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
async def mqtt_disconnect(self):
@asyncio.coroutine
def mqtt_disconnect(self):
disconnect_packet = DisconnectPacket()
await self._send_packet(disconnect_packet)
yield from self._send_packet(disconnect_packet)
async def mqtt_ping(self):
@asyncio.coroutine
def mqtt_ping(self):
ping_packet = PingReqPacket()
await self._send_packet(ping_packet)
yield from self._send_packet(ping_packet)
self._pingresp_waiter = futures.Future(loop=self._loop)
resp = await self._pingresp_queue.get()
resp = yield from self._pingresp_queue.get()
self._pingresp_waiter = None
return resp
async def handle_pingresp(self, pingresp: PingRespPacket):
await self._pingresp_queue.put(pingresp)
@asyncio.coroutine
def handle_pingresp(self, pingresp: PingRespPacket):
yield from self._pingresp_queue.put(pingresp)
async def handle_connection_closed(self):
@asyncio.coroutine
def handle_connection_closed(self):
self.logger.debug("Broker closed connection")
if not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None)
async def wait_disconnect(self):
await self._disconnect_waiter
@asyncio.coroutine
def wait_disconnect(self):
yield from self._disconnect_waiter

Wyświetl plik

@ -97,20 +97,22 @@ class ProtocolHandler:
else:
return False
async def start(self):
@asyncio.coroutine
def start(self):
if not self._is_attached():
raise ProtocolHandlerException("Handler is not attached to a stream")
self._reader_ready = asyncio.Event(loop=self._loop)
self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
await asyncio.wait([self._reader_ready.wait()], loop=self._loop)
yield from asyncio.wait([self._reader_ready.wait()], loop=self._loop)
if self.keepalive_timeout:
self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout)
self.logger.debug("Handler tasks started")
await self._retry_deliveries()
yield from self._retry_deliveries()
self.logger.debug("Handler ready")
async def stop(self):
@asyncio.coroutine
def stop(self):
# Stop messages flow waiter
self._stop_waiters()
if self._keepalive_task:
@ -118,11 +120,11 @@ class ProtocolHandler:
self.logger.debug("waiting for tasks to be stopped")
if not self._reader_task.done():
self._reader_task.cancel()
await asyncio.wait(
yield from asyncio.wait(
[self._reader_stopped.wait()], loop=self._loop)
self.logger.debug("closing writer")
try:
await self.writer.close()
yield from self.writer.close()
except Exception as e:
self.logger.debug("Handler writer close failed: %s" % e)
@ -138,7 +140,8 @@ class ProtocolHandler:
self._pubrel_waiters.values()):
waiter.cancel()
async def _retry_deliveries(self):
@asyncio.coroutine
def _retry_deliveries(self):
"""
Handle [MQTT-4.4.0-1] by resending PUBLISH and PUBREL messages for pending out messages
:return:
@ -148,12 +151,13 @@ class ProtocolHandler:
for message in itertools.chain(self.session.inflight_in.values(), self.session.inflight_out.values()):
tasks.append(asyncio.wait_for(self._handle_message_flow(message), 10, loop=self._loop))
if tasks:
done, pending = await asyncio.wait(tasks)
done, pending = yield from asyncio.wait(tasks)
self.logger.debug("%d messages redelivered" % len(done))
self.logger.debug("%d messages not redelivered due to timeout" % len(pending))
self.logger.debug("End messages delivery retries")
async def mqtt_publish(self, topic, data, qos, retain, ack_timeout=None):
@asyncio.coroutine
def mqtt_publish(self, topic, data, qos, retain, ack_timeout=None):
"""
Sends a MQTT publish message and manages messages flows.
This methods doesn't return until the message has been acknowledged by receiver or timeout occur
@ -175,13 +179,14 @@ class ProtocolHandler:
message = OutgoingApplicationMessage(packet_id, topic, qos, data, retain)
# Handle message flow
if ack_timeout is not None and ack_timeout > 0:
await asyncio.wait_for(self._handle_message_flow(message), ack_timeout, loop=self._loop)
yield from asyncio.wait_for(self._handle_message_flow(message), ack_timeout, loop=self._loop)
else:
await self._handle_message_flow(message)
yield from self._handle_message_flow(message)
return message
async def _handle_message_flow(self, app_message):
@asyncio.coroutine
def _handle_message_flow(self, app_message):
"""
Handle protocol flow for incoming and outgoing messages, depending on service level and according to MQTT
spec. paragraph 4.3-Quality of Service levels and protocol flows
@ -189,15 +194,16 @@ class ProtocolHandler:
:return: nothing.
"""
if app_message.qos == QOS_0:
await self._handle_qos0_message_flow(app_message)
yield from self._handle_qos0_message_flow(app_message)
elif app_message.qos == QOS_1:
await self._handle_qos1_message_flow(app_message)
yield from self._handle_qos1_message_flow(app_message)
elif app_message.qos == QOS_2:
await self._handle_qos2_message_flow(app_message)
yield from self._handle_qos2_message_flow(app_message)
else:
raise HBMQTTException("Unexcepted QOS value '%d" % str(app_message.qos))
async def _handle_qos0_message_flow(self, app_message):
@asyncio.coroutine
def _handle_qos0_message_flow(self, app_message):
"""
Handle QOS_0 application message acknowledgment
For incoming messages, this method stores the message
@ -209,7 +215,7 @@ class ProtocolHandler:
if app_message.direction == OUTGOING:
packet = app_message.build_publish_packet()
# Send PUBLISH packet
await self._send_packet(packet)
yield from self._send_packet(packet)
app_message.publish_packet = packet
elif app_message.direction == INCOMING:
if app_message.publish_packet.dup_flag:
@ -221,7 +227,8 @@ class ProtocolHandler:
except:
self.logger.warning("delivered messages queue full. QOS_0 message discarded")
async def _handle_qos1_message_flow(self, app_message):
@asyncio.coroutine
def _handle_qos1_message_flow(self, app_message):
"""
Handle QOS_1 application message acknowledgment
For incoming messages, this method stores the message and reply with PUBACK
@ -242,13 +249,13 @@ class ProtocolHandler:
else:
publish_packet = app_message.build_publish_packet()
# Send PUBLISH packet
await self._send_packet(publish_packet)
yield from self._send_packet(publish_packet)
app_message.publish_packet = publish_packet
# Wait for puback
waiter = asyncio.Future(loop=self._loop)
self._puback_waiters[app_message.packet_id] = waiter
await waiter
yield from waiter
del self._puback_waiters[app_message.packet_id]
app_message.puback_packet = waiter.result()
@ -257,13 +264,14 @@ class ProtocolHandler:
elif app_message.direction == INCOMING:
# Initiate delivery
self.logger.debug("Add message to delivery")
await self.session.delivered_message_queue.put(app_message)
yield from self.session.delivered_message_queue.put(app_message)
# Send PUBACK
puback = PubackPacket.build(app_message.packet_id)
await self._send_packet(puback)
yield from self._send_packet(puback)
app_message.puback_packet = puback
async def _handle_qos2_message_flow(self, app_message):
@asyncio.coroutine
def _handle_qos2_message_flow(self, app_message):
"""
Handle QOS_2 application message acknowledgment
For incoming messages, this method stores the message, sends PUBREC, waits for PUBREL, initiate delivery
@ -288,7 +296,7 @@ class ProtocolHandler:
self.session.inflight_out[app_message.packet_id] = app_message
publish_packet = app_message.build_publish_packet()
# Send PUBLISH packet
await self._send_packet(publish_packet)
yield from self._send_packet(publish_packet)
app_message.publish_packet = publish_packet
# Wait PUBREC
if app_message.packet_id in self._pubrec_waiters:
@ -299,17 +307,17 @@ class ProtocolHandler:
raise HBMQTTException(message)
waiter = asyncio.Future(loop=self._loop)
self._pubrec_waiters[app_message.packet_id] = waiter
await waiter
yield from waiter
del self._pubrec_waiters[app_message.packet_id]
app_message.pubrec_packet = waiter.result()
if not app_message.pubcomp_packet:
# Send pubrel
app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id)
await self._send_packet(app_message.pubrel_packet)
yield from self._send_packet(app_message.pubrel_packet)
# Wait for PUBCOMP
waiter = asyncio.Future(loop=self._loop)
self._pubcomp_waiters[app_message.packet_id] = waiter
await waiter
yield from waiter
del self._pubcomp_waiters[app_message.packet_id]
app_message.pubcomp_packet = waiter.result()
# Discard inflight message
@ -318,7 +326,7 @@ class ProtocolHandler:
self.session.inflight_in[app_message.packet_id] = app_message
# Send pubrec
pubrec_packet = PubrecPacket.build(app_message.packet_id)
await self._send_packet(pubrec_packet)
yield from self._send_packet(pubrec_packet)
app_message.pubrec_packet = pubrec_packet
# Wait PUBREL
if app_message.packet_id in self._pubrel_waiters:
@ -329,18 +337,19 @@ class ProtocolHandler:
raise HBMQTTException(message)
waiter = asyncio.Future(loop=self._loop)
self._pubrel_waiters[app_message.packet_id] = waiter
await waiter
yield from waiter
del self._pubrel_waiters[app_message.packet_id]
app_message.pubrel_packet = waiter.result()
# Initiate delivery and discard message
await self.session.delivered_message_queue.put(app_message)
yield from self.session.delivered_message_queue.put(app_message)
del self.session.inflight_in[app_message.packet_id]
# Send pubcomp
pubcomp_packet = PubcompPacket.build(app_message.packet_id)
await self._send_packet(pubcomp_packet)
yield from self._send_packet(pubcomp_packet)
app_message.pubcomp_packet = pubcomp_packet
async def _reader_loop(self):
@asyncio.coroutine
def _reader_loop(self):
self.logger.debug("%s Starting reader coro" % self.session.client_id)
running_tasks = collections.deque()
keepalive_timeout = self.session.keep_alive
@ -354,17 +363,17 @@ class ProtocolHandler:
if len(running_tasks) > 1:
self.logger.debug("handler running tasks: %d" % len(running_tasks))
fixed_header = await asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader),
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader),
keepalive_timeout, loop=self._loop)
if fixed_header:
if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
self.logger.warning("%s Received reserved packet, which is forbidden: closing connection" %
(self.session.client_id))
await self.handle_connection_closed()
yield from self.handle_connection_closed()
else:
cls = packet_class(fixed_header)
packet = await cls.from_stream(self.reader, fixed_header=fixed_header)
await self.plugins_manager.fire_event(
packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header)
yield from self.plugins_manager.fire_event(
EVENT_MQTT_PACKET_RECEIVED, packet=packet, session=self.session)
task = None
if packet.fixed_header.packet_type == CONNACK:
@ -418,30 +427,32 @@ class ProtocolHandler:
break
while running_tasks:
running_tasks.popleft().cancel()
await self.handle_connection_closed()
yield from self.handle_connection_closed()
self._reader_stopped.set()
self.logger.debug("%s Reader coro stopped" % self.session.client_id)
await self.stop()
yield from self.stop()
async def _send_packet(self, packet):
@asyncio.coroutine
def _send_packet(self, packet):
try:
await packet.to_stream(self.writer)
yield from packet.to_stream(self.writer)
if self._keepalive_task:
self._keepalive_task.cancel()
self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout)
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
except ConnectionResetError as cre:
await self.handle_connection_closed()
yield from self.handle_connection_closed()
raise
except BaseException as e:
self.logger.warning("Unhandled exception: %s" % e)
raise
async def mqtt_deliver_next_message(self):
@asyncio.coroutine
def mqtt_deliver_next_message(self):
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("%d message(s) available for delivery" % self.session.delivered_message_queue.qsize())
message = await self.session.delivered_message_queue.get()
message = yield from self.session.delivered_message_queue.get()
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("Delivering message %s" % message)
return message
@ -452,37 +463,48 @@ class ProtocolHandler:
def handle_read_timeout(self):
self.logger.debug('%s read timeout unhandled' % self.session.client_id)
async def handle_connack(self, connack: ConnackPacket):
@asyncio.coroutine
def handle_connack(self, connack: ConnackPacket):
self.logger.debug('%s CONNACK unhandled' % self.session.client_id)
async def handle_connect(self, connect: ConnectPacket):
@asyncio.coroutine
def handle_connect(self, connect: ConnectPacket):
self.logger.debug('%s CONNECT unhandled' % self.session.client_id)
async def handle_subscribe(self, subscribe: SubscribePacket):
@asyncio.coroutine
def handle_subscribe(self, subscribe: SubscribePacket):
self.logger.debug('%s SUBSCRIBE unhandled' % self.session.client_id)
async def handle_unsubscribe(self, subscribe: UnsubscribePacket):
@asyncio.coroutine
def handle_unsubscribe(self, subscribe: UnsubscribePacket):
self.logger.debug('%s UNSUBSCRIBE unhandled' % self.session.client_id)
async def handle_suback(self, suback: SubackPacket):
@asyncio.coroutine
def handle_suback(self, suback: SubackPacket):
self.logger.debug('%s SUBACK unhandled' % self.session.client_id)
async def handle_unsuback(self, unsuback: UnsubackPacket):
@asyncio.coroutine
def handle_unsuback(self, unsuback: UnsubackPacket):
self.logger.debug('%s UNSUBACK unhandled' % self.session.client_id)
async def handle_pingresp(self, pingresp: PingRespPacket):
@asyncio.coroutine
def handle_pingresp(self, pingresp: PingRespPacket):
self.logger.debug('%s PINGRESP unhandled' % self.session.client_id)
async def handle_pingreq(self, pingreq: PingReqPacket):
@asyncio.coroutine
def handle_pingreq(self, pingreq: PingReqPacket):
self.logger.debug('%s PINGREQ unhandled' % self.session.client_id)
async def handle_disconnect(self, disconnect: DisconnectPacket):
@asyncio.coroutine
def handle_disconnect(self, disconnect: DisconnectPacket):
self.logger.debug('%s DISCONNECT unhandled' % self.session.client_id)
async def handle_connection_closed(self):
@asyncio.coroutine
def handle_connection_closed(self):
self.logger.debug('%s Connection closed unhandled' % self.session.client_id)
async def handle_puback(self, puback: PubackPacket):
@asyncio.coroutine
def handle_puback(self, puback: PubackPacket):
packet_id = puback.variable_header.packet_id
try:
waiter = self._puback_waiters[packet_id]
@ -492,7 +514,8 @@ class ProtocolHandler:
except InvalidStateError:
self.logger.warning("PUBACK waiter with Id '%d' already done" % packet_id)
async def handle_pubrec(self, pubrec: PubrecPacket):
@asyncio.coroutine
def handle_pubrec(self, pubrec: PubrecPacket):
packet_id = pubrec.packet_id
try:
waiter = self._pubrec_waiters[packet_id]
@ -502,7 +525,8 @@ class ProtocolHandler:
except InvalidStateError:
self.logger.warning("PUBREC waiter with Id '%d' already done" % packet_id)
async def handle_pubcomp(self, pubcomp: PubcompPacket):
@asyncio.coroutine
def handle_pubcomp(self, pubcomp: PubcompPacket):
packet_id = pubcomp.packet_id
try:
waiter = self._pubcomp_waiters[packet_id]
@ -512,7 +536,8 @@ class ProtocolHandler:
except InvalidStateError:
self.logger.warning("PUBCOMP waiter with Id '%d' already done" % packet_id)
async def handle_pubrel(self, pubrel: PubrelPacket):
@asyncio.coroutine
def handle_pubrel(self, pubrel: PubrelPacket):
packet_id = pubrel.packet_id
try:
waiter = self._pubrel_waiters[packet_id]
@ -522,11 +547,12 @@ class ProtocolHandler:
except InvalidStateError:
self.logger.warning("PUBREL waiter with Id '%d' already done" % packet_id)
async def handle_publish(self, publish_packet: PublishPacket):
@asyncio.coroutine
def handle_publish(self, publish_packet: PublishPacket):
packet_id = publish_packet.variable_header.packet_id
qos = publish_packet.qos
incoming_message = IncomingApplicationMessage(packet_id, publish_packet.topic_name, qos, publish_packet.data, publish_packet.retain_flag)
incoming_message.publish_packet = publish_packet
await self._handle_message_flow(incoming_message)
yield from self._handle_message_flow(incoming_message)
self.logger.debug("Message queue size: %d" % self.session.delivered_message_queue.qsize())

Wyświetl plik

@ -25,11 +25,12 @@ class PublishVariableHeader(MQTTVariableHeader):
return out
@classmethod
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
topic_name = await decode_string(reader)
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
topic_name = yield from decode_string(reader)
has_qos = (fixed_header.flags >> 1) & 0x03
if has_qos:
packet_id = await decode_packet_id(reader)
packet_id = yield from decode_packet_id(reader)
else:
packet_id = None
return cls(topic_name, packet_id)
@ -44,9 +45,10 @@ class PublishPayload(MQTTPayload):
return self.data
@classmethod
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
variable_header: MQTTVariableHeader):
data = await reader.read(fixed_header.remaining_length-variable_header.bytes_length)
data = yield from reader.read(fixed_header.remaining_length-variable_header.bytes_length)
return cls(data)
def __repr__(self):

Wyświetl plik

@ -27,13 +27,14 @@ class SubackPayload(MQTTPayload):
return out
@classmethod
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
@asyncio.coroutine
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
variable_header: MQTTVariableHeader):
return_codes = []
bytes_to_read = fixed_header.remaining_length - variable_header.bytes_length
for i in range(0, bytes_to_read):
try:
return_code_byte = await read_or_raise(reader, 1)
return_code_byte = yield from read_or_raise(reader, 1)
return_code = bytes_to_int(return_code_byte)
return_codes.append(return_code)
except NoDataException:

Wyświetl plik

@ -19,15 +19,16 @@ class SubscribePayload(MQTTPayload):
return out
@classmethod
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
variable_header: MQTTVariableHeader):
topics = []
payload_length = fixed_header.remaining_length - variable_header.bytes_length
read_bytes = 0
while read_bytes < payload_length:
try:
topic = await decode_string(reader)
qos_byte = await read_or_raise(reader, 1)
topic = yield from decode_string(reader)
qos_byte = yield from read_or_raise(reader, 1)
qos = bytes_to_int(qos_byte)
topics.append((topic, qos))
read_bytes += 2 + len(topic.encode('utf-8')) + 1

Wyświetl plik

@ -18,14 +18,15 @@ class UnubscribePayload(MQTTPayload):
return out
@classmethod
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
variable_header: MQTTVariableHeader):
topics = []
payload_length = fixed_header.remaining_length - variable_header.bytes_length
read_bytes = 0
while read_bytes < payload_length:
try:
topic = await decode_string(reader)
topic = yield from decode_string(reader)
topics.append(topic)
read_bytes += 2 + len(topic.encode('utf-8'))
except NoDataException:

Wyświetl plik

@ -26,7 +26,8 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
def __init__(self, context):
super().__init__(context)
async def authenticate(self, *args, **kwargs):
@asyncio.coroutine
def authenticate(self, *args, **kwargs):
authenticated = super().authenticate(*args, **kwargs)
if authenticated:
allow_anonymous = self.auth_config.get('allow-anonymous', True) # allow anonymous by default
@ -73,7 +74,8 @@ class FileAuthPlugin(BaseAuthPlugin):
else:
self.context.logger.debug("Configuration parameter 'password_file' not found")
async def authenticate(self, *args, **kwargs):
@asyncio.coroutine
def authenticate(self, *args, **kwargs):
authenticated = super().authenticate(*args, **kwargs)
if authenticated:
session = kwargs.get('session', None)

Wyświetl plik

@ -12,7 +12,8 @@ class EventLoggerPlugin:
def __init__(self, context):
self.context = context
async def log_event(self, *args, **kwargs):
@asyncio.coroutine
def log_event(self, *args, **kwargs):
self.context.logger.info("### '%s' EVENT FIRED ###" % kwargs['event_name'].replace('old', ''))
def __getattr__(self, name):
@ -24,7 +25,8 @@ class PacketLoggerPlugin:
def __init__(self, context):
self.context = context
async def on_mqtt_packet_received(self, *args, **kwargs):
@asyncio.coroutine
def on_mqtt_packet_received(self, *args, **kwargs):
packet = kwargs.get('packet')
session = kwargs.get('session', None)
if self.context.logger.isEnabledFor(logging.DEBUG):
@ -33,7 +35,8 @@ class PacketLoggerPlugin:
else:
self.context.logger.debug("<-in-- %s" % repr(packet))
async def on_mqtt_packet_sent(self, *args, **kwargs):
@asyncio.coroutine
def on_mqtt_packet_sent(self, *args, **kwargs):
packet = kwargs.get('packet')
session = kwargs.get('session', None)
if self.context.logger.isEnabledFor(logging.DEBUG):

Wyświetl plik

@ -87,13 +87,14 @@ class PluginManager:
return p
return None
async def close(self):
@asyncio.coroutine
def close(self):
"""
Free PluginManager resources and cancel pending event methods
This method call a close() coroutine for each plugin, allowing plugins to close and free resources
:return:
"""
await self.map_plugin_coro("close")
yield from self.map_plugin_coro("close")
for task in self._fired_events:
task.cancel()
@ -108,10 +109,11 @@ class PluginManager:
def _schedule_coro(self, coro):
return asyncio.ensure_future(coro, loop=self._loop)
async def fire_event(self, event_name, wait=False, *args, **kwargs):
@asyncio.coroutine
def fire_event(self, event_name, wait=False, *args, **kwargs):
"""
Fire an event to plugins.
PluginManager schedule async calls for each plugin on method called "on_" + event_name
PluginManager schedule @asyncio.coroutinecalls for each plugin on method called "on_" + event_name
For example, on_connect will be called on event 'connect'
Method calls are schedule in the asyn loop. wait parameter must be set to true to wait until all
mehtods are completed.
@ -133,11 +135,12 @@ class PluginManager:
(event_method_name, plugin.name))
if wait:
if tasks:
await asyncio.wait(tasks, loop=self._loop)
yield from asyncio.wait(tasks, loop=self._loop)
else:
self._fired_events.extend(tasks)
async def map(self, coro, *args, **kwargs):
@asyncio.coroutine
def map(self, coro, *args, **kwargs):
"""
Schedule a given coroutine call for each plugin.
The coro called get the Plugin instance as first argument of its method call
@ -164,7 +167,7 @@ class PluginManager:
self.logger.error("Method '%r' on plugin '%s' is not a coroutine" %
(coro, plugin.name))
if tasks:
ret_list = await asyncio.gather(*tasks, loop=self._loop)
ret_list = yield from asyncio.gather(*tasks, loop=self._loop)
# Create result map plugin=>ret
ret_dict = {k: v for k, v in zip(plugins_list, ret_list)}
else:
@ -172,15 +175,17 @@ class PluginManager:
return ret_dict
@staticmethod
async def _call_coro(plugin, coro_name, *args, **kwargs):
@asyncio.coroutine
def _call_coro(plugin, coro_name, *args, **kwargs):
try:
coro = getattr(plugin.object, coro_name, None)(*args, **kwargs)
return await coro
return (yield from coro)
except TypeError:
# Plugin doesn't implement coro_name
return None
async def map_plugin_coro(self, coro_name, *args, **kwargs):
@asyncio.coroutine
def map_plugin_coro(self, coro_name, *args, **kwargs):
"""
Call a plugin declared by plugin by its name
:param coro_name:
@ -188,4 +193,4 @@ class PluginManager:
:param kwargs:
:return:
"""
return await self.map(self._call_coro, coro_name, *args, **kwargs)
return (yield from self.map(self._call_coro, coro_name, *args, **kwargs))

Wyświetl plik

@ -32,7 +32,8 @@ class SQLitePlugin:
if self.cursor:
self.cursor.execute("CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)")
async def save_session(self, session):
@asyncio.coroutine
def save_session(self, session):
if self.cursor:
dump = pickle.dumps(session)
try:
@ -42,7 +43,8 @@ class SQLitePlugin:
except Exception as e:
self.context.logger.error("Failed saving session '%s': %s" % (session, e))
async def find_session(self, client_id):
@asyncio.coroutine
def find_session(self, client_id):
if self.cursor:
row = self.cursor.execute("SELECT data FROM session where client_id=?", (client_id,)).fetchone()
if row:
@ -50,12 +52,14 @@ class SQLitePlugin:
else:
return None
async def del_session(self, client_id):
@asyncio.coroutine
def del_session(self, client_id):
if self.cursor:
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
self.conn.commit()
async def on_broker_post_shutdown(self):
@asyncio.coroutine
def on_broker_post_shutdown(self):
if self.conn:
self.conn.close()
self.context.logger.info("Database file '%s' closed" % self.db_file)

Wyświetl plik

@ -42,17 +42,20 @@ class BrokerSysPlugin:
STAT_PUBLISH_SENT):
self._stats[stat] = 0
async def _broadcast_sys_topic(self, topic_basename, data):
return await self.context.broadcast_message(DOLLAR_SYS_ROOT + topic_basename, data)
@asyncio.coroutine
def _broadcast_sys_topic(self, topic_basename, data):
return (yield from self.context.broadcast_message(DOLLAR_SYS_ROOT + topic_basename, data))
def schedule_broadcast_sys_topic(self, topic_basename, data):
return asyncio.ensure_future(self._broadcast_sys_topic(DOLLAR_SYS_ROOT + topic_basename, data),
loop=self.context.loop)
async def on_broker_pre_start(self, *args, **kwargs):
@asyncio.coroutine
def on_broker_pre_start(self, *args, **kwargs):
self._clear_stats()
async def on_broker_post_start(self, *args, **kwargs):
@asyncio.coroutine
def on_broker_post_start(self, *args, **kwargs):
self._stats[STAT_START_TIME] = datetime.now()
from hbmqtt.version import get_version
version = 'HBMQTT version ' + get_version()
@ -70,7 +73,8 @@ class BrokerSysPlugin:
pass
# 'sys_internal' config parameter not found
async def on_broker_pre_stop(self, *args, **kwargs):
@asyncio.coroutine
def on_broker_pre_stop(self, *args, **kwargs):
# Stop $SYS topics broadcasting
if self.sys_handle:
self.sys_handle.cancel()
@ -127,7 +131,8 @@ class BrokerSysPlugin:
self.context.logger.debug("Broadcasting $SYS topics")
self.sys_handle = self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
async def on_mqtt_packet_received(self, *args, **kwargs):
@asyncio.coroutine
def on_mqtt_packet_received(self, *args, **kwargs):
packet = kwargs.get('packet')
if packet:
packet_size = packet.bytes_length
@ -136,7 +141,8 @@ class BrokerSysPlugin:
if packet.fixed_header.packet_type == PUBLISH:
self._stats[STAT_PUBLISH_RECEIVED] += 1
async def on_mqtt_packet_sent(self, *args, **kwargs):
@asyncio.coroutine
def on_mqtt_packet_sent(self, *args, **kwargs):
packet = kwargs.get('packet')
if packet:
packet_size = packet.bytes_length
@ -145,10 +151,12 @@ class BrokerSysPlugin:
if packet.fixed_header.packet_type == PUBLISH:
self._stats[STAT_PUBLISH_SENT] += 1
async def on_broker_client_connected(self, *args, **kwargs):
@asyncio.coroutine
def on_broker_client_connected(self, *args, **kwargs):
self._stats[STAT_CLIENTS_CONNECTED] += 1
self._stats[STAT_CLIENTS_MAXIMUM] = max(self._stats[STAT_CLIENTS_MAXIMUM], self._stats[STAT_CLIENTS_CONNECTED])
async def on_broker_client_disconnected(self, *args, **kwargs):
@asyncio.coroutine
def on_broker_client_disconnected(self, *args, **kwargs):
self._stats[STAT_CLIENTS_CONNECTED] -= 1
self._stats[STAT_CLIENTS_DISCONNECTED] += 1

Wyświetl plik

@ -30,10 +30,11 @@ config = {
broker = Broker(config)
async def test_coro():
await broker.start()
#await asyncio.sleep(5)
#await broker.shutdown()
@asyncio.coroutine
def test_coro():
yield from broker.start()
#yield from asyncio.sleep(5)
#yield from broker.shutdown()
if __name__ == '__main__':

Wyświetl plik

@ -18,11 +18,12 @@ config = {
}
C = MQTTClient(config=config)
async def test_coro():
await C.connect('mqtt://test.mosquitto.org:1883/')
await asyncio.sleep(18)
@asyncio.coroutine
def test_coro():
yield from C.connect('mqtt://test.mosquitto.org:1883/')
yield from asyncio.sleep(18)
await C.disconnect()
yield from C.disconnect()
if __name__ == '__main__':

Wyświetl plik

@ -29,28 +29,30 @@ def disconnected(future):
asyncio.get_event_loop().stop()
async def test_coro():
await C.connect('mqtt://test:test@localhost:1883/')
@asyncio.coroutine
def test_coro():
yield from C.connect('mqtt://test:test@localhost:1883/')
tasks = [
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
]
await asyncio.wait(tasks)
yield from asyncio.wait(tasks)
logger.info("messages published")
await C.disconnect()
yield from C.disconnect()
async def test_coro2():
@asyncio.coroutine
def test_coro2():
try:
future = await C.connect('mqtt://test.mosquitto.org:1883/')
future = yield from C.connect('mqtt://test.mosquitto.org:1883/')
future.add_done_callback(disconnected)
message = await C.publish('a/b', b'TEST MESSAGE WITH QOS_0', qos=0x00)
message = await C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)
message = await C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)
message = yield from C.publish('a/b', b'TEST MESSAGE WITH QOS_0', qos=0x00)
message = yield from C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)
message = yield from C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)
#print(message)
logger.info("messages published")
await C.disconnect()
yield from C.disconnect()
except ConnectException as ce:
logger.error("Connection failed: %s" % ce)
asyncio.get_event_loop().stop()

Wyświetl plik

@ -23,16 +23,17 @@ config = {
C = MQTTClient(config=config)
#C = MQTTClient()
async def test_coro():
await C.connect('mqtts://test.mosquitto.org/', cafile='mosquitto.org.crt')
@asyncio.coroutine
def test_coro():
yield from C.connect('mqtts://test.mosquitto.org/', cafile='mosquitto.org.crt')
tasks = [
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
]
await asyncio.wait(tasks)
yield from asyncio.wait(tasks)
logger.info("messages published")
await C.disconnect()
yield from C.disconnect()
if __name__ == '__main__':

Wyświetl plik

@ -24,16 +24,17 @@ config = {
C = MQTTClient(config=config)
#C = MQTTClient()
async def test_coro():
await C.connect('wss://test.mosquitto.org:8081/', cafile='mosquitto.org.crt')
@asyncio.coroutine
def test_coro():
yield from C.connect('wss://test.mosquitto.org:8081/', cafile='mosquitto.org.crt')
tasks = [
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
]
await asyncio.wait(tasks)
yield from asyncio.wait(tasks)
logger.info("messages published")
await C.disconnect()
yield from C.disconnect()
if __name__ == '__main__':

Wyświetl plik

@ -15,22 +15,23 @@ logger = logging.getLogger(__name__)
C = MQTTClient()
async def uptime_coro():
await C.connect('mqtt://localhost/')
@asyncio.coroutine
def uptime_coro():
yield from C.connect('mqtt://localhost/')
# Subscribe to '$SYS/broker/uptime' with QOS=1
await C.subscribe([
yield from C.subscribe([
('$SYS/broker/uptime', QOS_1),
('$SYS/broker/load/#', QOS_2),
])
logger.info("Subscribed")
try:
for i in range(1, 100):
message = await C.deliver_message()
message = yield from C.deliver_message()
packet = message.publish_packet
print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data)))
await C.unsubscribe(['$SYS/broker/uptime'])
yield from C.unsubscribe(['$SYS/broker/uptime'])
logger.info("UnSubscribed")
await C.disconnect()
yield from C.disconnect()
except ClientException as ce:
logger.error("Client exception: %s" % ce)

Wyświetl plik

@ -47,8 +47,8 @@ logger = logging.getLogger(__name__)
def main(*args, **kwargs):
if sys.version_info[:2] < (3, 5):
logger.fatal("Error: Python 3.5 is required")
if sys.version_info[:2] < (3, 4):
logger.fatal("Error: Python 3.4+ is required")
sys.exit(-1)
arguments = docopt(__doc__, version=get_version())

Wyświetl plik

@ -2,6 +2,6 @@ keep_alive: 10
ping_delay: 1
default_qos: 0
default_retain: false
auto_reconnect: true
auto_reconnect: false
reconnect_max_interval: 10
reconnect_retries: 2

Wyświetl plik

@ -107,8 +107,8 @@ def do_pub(client, arguments):
def main(*args, **kwargs):
if sys.version_info[:2] < (3, 5):
logger.fatal("Error: Python 3.5 is required")
if sys.version_info[:2] < (3, 4):
logger.fatal("Error: Python 3.4+ is required")
sys.exit(-1)
arguments = docopt(__doc__, version=get_version())

Wyświetl plik

@ -88,8 +88,8 @@ def do_sub(client, arguments):
def main(*args, **kwargs):
if sys.version_info[:2] < (3, 5):
logger.fatal("Error: Python 3.5 is required")
if sys.version_info[:2] < (3, 4):
logger.fatal("Error: Python 3.4+ is required")
sys.exit(-1)
arguments = docopt(__doc__, version=get_version())

Wyświetl plik

@ -41,18 +41,20 @@ class ProtocolHandlerTest(unittest.TestCase):
self.check_empty_waiters(handler)
def test_start_stop(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
pass
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
s = Session()
reader, writer = await asyncio.open_connection('127.0.0.1', 8888)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888)
reader_adapted, writer_adapted = adapt(reader, writer)
handler = ProtocolHandler(self.plugin_manager)
handler.attach(s, reader_adapted, writer_adapted)
await self.start_handler(handler, s)
await self.stop_handler(handler, s)
yield from self.start_handler(handler, s)
yield from self.stop_handler(handler, s)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -67,31 +69,33 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_publish_qos0(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
try:
packet = await PublishPacket.from_stream(reader)
packet = yield from PublishPacket.from_stream(reader)
self.assertEquals(packet.variable_header.topic_name, '/topic')
self.assertEquals(packet.qos, QOS_0)
self.assertIsNone(packet.packet_id)
except Exception as ae:
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
s = Session()
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
handler.attach(s, reader_adapted, writer_adapted)
await self.start_handler(handler, s)
message = await handler.mqtt_publish('/topic', b'test_data', QOS_0, False)
yield from self.start_handler(handler, s)
message = yield from handler.mqtt_publish('/topic', b'test_data', QOS_0, False)
self.assertIsInstance(message, OutgoingApplicationMessage)
self.assertIsNotNone(message.publish_packet)
self.assertIsNone(message.puback_packet)
self.assertIsNone(message.pubrec_packet)
self.assertIsNone(message.pubrel_packet)
self.assertIsNone(message.pubcomp_packet)
await self.stop_handler(handler, s)
yield from self.stop_handler(handler, s)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -106,8 +110,9 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_publish_qos1(self):
async def server_mock(reader, writer):
packet = await PublishPacket.from_stream(reader)
@asyncio.coroutine
def server_mock(reader, writer):
packet = yield from PublishPacket.from_stream(reader)
try:
self.assertEquals(packet.variable_header.topic_name, '/topic')
self.assertEquals(packet.qos, QOS_1)
@ -115,25 +120,26 @@ class ProtocolHandlerTest(unittest.TestCase):
self.assertIn(packet.packet_id, self.session.inflight_out)
self.assertIn(packet.packet_id, self.handler._puback_waiters)
puback = PubackPacket.build(packet.packet_id)
await puback.to_stream(writer)
yield from puback.to_stream(writer)
except Exception as ae:
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.start_handler(self.handler, self.session)
message = await self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False)
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False)
self.assertIsInstance(message, OutgoingApplicationMessage)
self.assertIsNotNone(message.publish_packet)
self.assertIsNotNone(message.puback_packet)
self.assertIsNone(message.pubrec_packet)
self.assertIsNone(message.pubrel_packet)
self.assertIsNone(message.pubcomp_packet)
await self.stop_handler(self.handler, self.session)
yield from self.stop_handler(self.handler, self.session)
if not future.done():
future.set_result(True)
except Exception as ae:
@ -151,39 +157,41 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_publish_qos2(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
try:
packet = await PublishPacket.from_stream(reader)
packet = yield from PublishPacket.from_stream(reader)
self.assertEquals(packet.topic_name, '/topic')
self.assertEquals(packet.qos, QOS_2)
self.assertIsNotNone(packet.packet_id)
self.assertIn(packet.packet_id, self.session.inflight_out)
self.assertIn(packet.packet_id, self.handler._pubrec_waiters)
pubrec = PubrecPacket.build(packet.packet_id)
await pubrec.to_stream(writer)
yield from pubrec.to_stream(writer)
pubrel = await PubrelPacket.from_stream(reader)
pubrel = yield from PubrelPacket.from_stream(reader)
self.assertIn(packet.packet_id, self.handler._pubcomp_waiters)
pubcomp = PubcompPacket.build(packet.packet_id)
await pubcomp.to_stream(writer)
yield from pubcomp.to_stream(writer)
except Exception as ae:
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.start_handler(self.handler, self.session)
message = await self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False)
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False)
self.assertIsInstance(message, OutgoingApplicationMessage)
self.assertIsNotNone(message.publish_packet)
self.assertIsNone(message.puback_packet)
self.assertIsNotNone(message.pubrec_packet)
self.assertIsNotNone(message.pubrel_packet)
self.assertIsNotNone(message.pubcomp_packet)
await self.stop_handler(self.handler, self.session)
yield from self.stop_handler(self.handler, self.session)
if not future.done():
future.set_result(True)
except Exception as ae:
@ -201,25 +209,27 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_receive_qos0(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
packet = PublishPacket.build('/topic', b'test_data', 1, False, QOS_0, False)
await packet.to_stream(writer)
yield from packet.to_stream(writer)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.start_handler(self.handler, self.session)
message = await self.handler.mqtt_deliver_next_message()
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage)
self.assertIsNotNone(message.publish_packet)
self.assertIsNone(message.puback_packet)
self.assertIsNone(message.pubrec_packet)
self.assertIsNone(message.pubrel_packet)
self.assertIsNone(message.pubcomp_packet)
await self.stop_handler(self.handler, self.session)
yield from self.stop_handler(self.handler, self.session)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -236,32 +246,34 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_receive_qos1(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
try:
packet = PublishPacket.build('/topic', b'test_data', 1, False, QOS_1, False)
await packet.to_stream(writer)
puback = await PubackPacket.from_stream(reader)
yield from packet.to_stream(writer)
puback = yield from PubackPacket.from_stream(reader)
self.assertIsNotNone(puback)
self.assertEqual(packet.packet_id, puback.packet_id)
except Exception as ae:
print(ae)
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.start_handler(self.handler, self.session)
message = await self.handler.mqtt_deliver_next_message()
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage)
self.assertIsNotNone(message.publish_packet)
self.assertIsNotNone(message.puback_packet)
self.assertIsNone(message.pubrec_packet)
self.assertIsNone(message.pubrel_packet)
self.assertIsNone(message.pubcomp_packet)
await self.stop_handler(self.handler, self.session)
yield from self.stop_handler(self.handler, self.session)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -279,37 +291,39 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_receive_qos2(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
try:
packet = PublishPacket.build('/topic', b'test_data', 2, False, QOS_2, False)
await packet.to_stream(writer)
pubrec = await PubrecPacket.from_stream(reader)
yield from packet.to_stream(writer)
pubrec = yield from PubrecPacket.from_stream(reader)
self.assertIsNotNone(pubrec)
self.assertEqual(packet.packet_id, pubrec.packet_id)
self.assertIn(packet.packet_id, self.handler._pubrel_waiters)
pubrel = PubrelPacket.build(packet.packet_id)
await pubrel.to_stream(writer)
pubcomp = await PubcompPacket.from_stream(reader)
yield from pubrel.to_stream(writer)
pubcomp = yield from PubcompPacket.from_stream(reader)
self.assertIsNotNone(pubcomp)
self.assertEqual(packet.packet_id, pubcomp.packet_id)
except Exception as ae:
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.start_handler(self.handler, self.session)
message = await self.handler.mqtt_deliver_next_message()
yield from self.start_handler(self.handler, self.session)
message = yield from self.handler.mqtt_deliver_next_message()
self.assertIsInstance(message, IncomingApplicationMessage)
self.assertIsNotNone(message.publish_packet)
self.assertIsNone(message.puback_packet)
self.assertIsNotNone(message.pubrec_packet)
self.assertIsNotNone(message.pubrel_packet)
self.assertIsNotNone(message.pubcomp_packet)
await self.stop_handler(self.handler, self.session)
yield from self.stop_handler(self.handler, self.session)
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -325,14 +339,16 @@ class ProtocolHandlerTest(unittest.TestCase):
if future.exception():
raise future.exception()
async def start_handler(self, handler, session):
@asyncio.coroutine
def start_handler(self, handler, session):
self.check_empty_waiters(handler)
self.check_no_message(session)
await handler.start()
yield from handler.start()
self.assertTrue(handler._reader_ready)
async def stop_handler(self, handler, session):
await handler.stop()
@asyncio.coroutine
def stop_handler(self, handler, session):
yield from handler.stop()
self.assertTrue(handler._reader_stopped)
self.check_empty_waiters(handler)
self.check_no_message(session)
@ -348,8 +364,9 @@ class ProtocolHandlerTest(unittest.TestCase):
self.assertFalse(session.inflight_in)
def test_publish_qos1_retry(self):
async def server_mock(reader, writer):
packet = await PublishPacket.from_stream(reader)
@asyncio.coroutine
def server_mock(reader, writer):
packet = yield from PublishPacket.from_stream(reader)
try:
self.assertEquals(packet.topic_name, '/topic')
self.assertEquals(packet.qos, QOS_1)
@ -357,18 +374,19 @@ class ProtocolHandlerTest(unittest.TestCase):
self.assertIn(packet.packet_id, self.session.inflight_out)
self.assertIn(packet.packet_id, self.handler._puback_waiters)
puback = PubackPacket.build(packet.packet_id)
await puback.to_stream(writer)
yield from puback.to_stream(writer)
except Exception as ae:
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.handler.start()
await self.stop_handler(self.handler, self.session)
yield from self.handler.start()
yield from self.stop_handler(self.handler, self.session)
if not future.done():
future.set_result(True)
except Exception as ae:
@ -389,32 +407,34 @@ class ProtocolHandlerTest(unittest.TestCase):
raise future.exception()
def test_publish_qos2_retry(self):
async def server_mock(reader, writer):
@asyncio.coroutine
def server_mock(reader, writer):
try:
packet = await PublishPacket.from_stream(reader)
packet = yield from PublishPacket.from_stream(reader)
self.assertEquals(packet.topic_name, '/topic')
self.assertEquals(packet.qos, QOS_2)
self.assertIsNotNone(packet.packet_id)
self.assertIn(packet.packet_id, self.session.inflight_out)
self.assertIn(packet.packet_id, self.handler._pubrec_waiters)
pubrec = PubrecPacket.build(packet.packet_id)
await pubrec.to_stream(writer)
yield from pubrec.to_stream(writer)
pubrel = await PubrelPacket.from_stream(reader)
pubrel = yield from PubrelPacket.from_stream(reader)
self.assertIn(packet.packet_id, self.handler._pubcomp_waiters)
pubcomp = PubcompPacket.build(packet.packet_id)
await pubcomp.to_stream(writer)
yield from pubcomp.to_stream(writer)
except Exception as ae:
future.set_exception(ae)
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
reader_adapted, writer_adapted = adapt(reader, writer)
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
self.handler.attach(self.session, reader_adapted, writer_adapted)
await self.handler.start()
await self.stop_handler(self.handler, self.session)
yield from self.handler.start()
yield from self.stop_handler(self.handler, self.session)
if not future.done():
future.set_result(True)
except Exception as ae:

Wyświetl plik

@ -21,14 +21,17 @@ class EventTestPlugin:
self.test_flag = False
self.coro_flag = False
async def on_test(self, *args, **kwargs):
@asyncio.coroutine
def on_test(self, *args, **kwargs):
self.test_flag = True
self.context.logger.info("on_test")
async def test_coro(self, *args, **kwargs):
@asyncio.coroutine
def test_coro(self, *args, **kwargs):
self.coro_flag = True
async def ret_coro(self, *args, **kwargs):
@asyncio.coroutine
def ret_coro(self, *args, **kwargs):
return "TEST"
@ -41,10 +44,11 @@ class TestPluginManager(unittest.TestCase):
self.assertTrue(len(manager._plugins) > 0)
def test_fire_event(self):
async def fire_event():
await manager.fire_event("test")
await asyncio.sleep(1, loop=self.loop)
await manager.close()
@asyncio.coroutine
def fire_event():
yield from manager.fire_event("test")
yield from asyncio.sleep(1, loop=self.loop)
yield from manager.close()
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
self.loop.run_until_complete(fire_event())
@ -52,9 +56,10 @@ class TestPluginManager(unittest.TestCase):
self.assertTrue(plugin.object.test_flag)
def test_fire_event_wait(self):
async def fire_event():
await manager.fire_event("test", wait=True)
await manager.close()
@asyncio.coroutine
def fire_event():
yield from manager.fire_event("test", wait=True)
yield from manager.close()
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
self.loop.run_until_complete(fire_event())
@ -62,8 +67,9 @@ class TestPluginManager(unittest.TestCase):
self.assertTrue(plugin.object.test_flag)
def test_map_coro(self):
async def call_coro():
await manager.map_plugin_coro('test_coro')
@asyncio.coroutine
def call_coro():
yield from manager.map_plugin_coro('test_coro')
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
self.loop.run_until_complete(call_coro())
@ -71,8 +77,9 @@ class TestPluginManager(unittest.TestCase):
self.assertTrue(plugin.object.test_coro)
def test_map_coro_return(self):
async def call_coro():
return await manager.map_plugin_coro('ret_coro')
@asyncio.coroutine
def call_coro():
return (yield from manager.map_plugin_coro('ret_coro'))
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
ret = self.loop.run_until_complete(call_coro())
@ -84,8 +91,9 @@ class TestPluginManager(unittest.TestCase):
Run plugin coro but expect no return as an empty filter is given
:return:
"""
async def call_coro():
return await manager.map_plugin_coro('ret_coro', filter_plugins=[])
@asyncio.coroutine
def call_coro():
return (yield from manager.map_plugin_coro('ret_coro', filter_plugins=[]))
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
ret = self.loop.run_until_complete(call_coro())

Wyświetl plik

@ -26,12 +26,12 @@ test_config = {
}
class AsyncMock(MagicMock):
def __await__(self, *args, **kwargs):
future = asyncio.Future()
future.set_result(self)
result = yield from future
return result
#class AsyncMock(MagicMock):
# def __yield from__(self, *args, **kwargs):
# future = asyncio.Future()
# future.set_result(self)
# result = yield from future
# return result
class BrokerTest(unittest.TestCase):
def setUp(self):
@ -41,12 +41,13 @@ class BrokerTest(unittest.TestCase):
def tearDown(self):
self.loop.close()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_start_stop(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
self.assertDictEqual(broker._sessions, {})
self.assertIn('default', broker._servers)
@ -54,7 +55,7 @@ class BrokerTest(unittest.TestCase):
[call().fire_event(EVENT_BROKER_PRE_START),
call().fire_event(EVENT_BROKER_POST_START)], any_order=True)
MockPluginManager.reset_mock()
await broker.shutdown()
yield from broker.shutdown()
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_PRE_SHUTDOWN),
call().fire_event(EVENT_BROKER_POST_SHUTDOWN)], any_order=True)
@ -68,20 +69,21 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_connect(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
client = MQTTClient()
ret = await client.connect('mqtt://localhost/')
ret = yield from client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
self.assertIn(client.session.client_id, broker._sessions)
await client.disconnect()
await asyncio.sleep(0.1)
await broker.shutdown()
yield from client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
self.assertDictEqual(broker._sessions, {})
MockPluginManager.assert_has_calls(
@ -97,17 +99,18 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_subscribe(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
client = MQTTClient()
ret = await client.connect('mqtt://localhost/')
ret = yield from client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
await client.subscribe([('/topic', QOS_0)])
yield from client.subscribe([('/topic', QOS_0)])
# Test if the client test client subscription is registered
self.assertIn('/topic', broker._subscriptions)
@ -117,9 +120,9 @@ class BrokerTest(unittest.TestCase):
self.assertEquals(s, client.session)
self.assertEquals(qos, QOS_0)
await client.disconnect()
await asyncio.sleep(0.1)
await broker.shutdown()
yield from client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_CLIENT_SUBSCRIBED,
@ -134,17 +137,18 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_subscribe_twice(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
client = MQTTClient()
ret = await client.connect('mqtt://localhost/')
ret = yield from client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
await client.subscribe([('/topic', QOS_0)])
yield from client.subscribe([('/topic', QOS_0)])
# Test if the client test client subscription is registered
self.assertIn('/topic', broker._subscriptions)
@ -154,15 +158,15 @@ class BrokerTest(unittest.TestCase):
self.assertEquals(s, client.session)
self.assertEquals(qos, QOS_0)
await client.subscribe([('/topic', QOS_0)])
yield from client.subscribe([('/topic', QOS_0)])
self.assertEquals(len(subs), 1)
(s, qos) = subs[0]
self.assertEquals(s, client.session)
self.assertEquals(qos, QOS_0)
await client.disconnect()
await asyncio.sleep(0.1)
await broker.shutdown()
yield from client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_CLIENT_SUBSCRIBED,
@ -177,17 +181,18 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_unsubscribe(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
client = MQTTClient()
ret = await client.connect('mqtt://localhost/')
ret = yield from client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
await client.subscribe([('/topic', QOS_0)])
yield from client.subscribe([('/topic', QOS_0)])
# Test if the client test client subscription is registered
self.assertIn('/topic', broker._subscriptions)
@ -197,12 +202,12 @@ class BrokerTest(unittest.TestCase):
self.assertEquals(s, client.session)
self.assertEquals(qos, QOS_0)
await client.unsubscribe(['/topic'])
await asyncio.sleep(0.1)
yield from client.unsubscribe(['/topic'])
yield from asyncio.sleep(0.1)
self.assertEquals(broker._subscriptions['/topic'], [])
await client.disconnect()
await asyncio.sleep(0.1)
await broker.shutdown()
yield from client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_CLIENT_SUBSCRIBED,
@ -221,23 +226,24 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_publish(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
pub_client = MQTTClient()
ret = await pub_client.connect('mqtt://localhost/')
ret = yield from pub_client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
ret_message = await pub_client.publish('/topic', b'data', QOS_0)
await pub_client.disconnect()
ret_message = yield from pub_client.publish('/topic', b'data', QOS_0)
yield from pub_client.disconnect()
self.assertEquals(broker._retained_messages, {})
await asyncio.sleep(0.1)
await broker.shutdown()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
MockPluginManager.assert_has_calls(
[call().fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
@ -253,27 +259,28 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_publish_retain(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
pub_client = MQTTClient()
ret = await pub_client.connect('mqtt://localhost/')
ret = yield from pub_client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
ret_message = await pub_client.publish('/topic', b'data', QOS_0, retain=True)
await pub_client.disconnect()
await asyncio.sleep(0.1)
ret_message = yield from pub_client.publish('/topic', b'data', QOS_0, retain=True)
yield from pub_client.disconnect()
yield from asyncio.sleep(0.1)
self.assertIn('/topic', broker._retained_messages)
retained_message = broker._retained_messages['/topic']
self.assertEquals(retained_message.source_session, pub_client.session)
self.assertEquals(retained_message.topic, '/topic')
self.assertEquals(retained_message.data, b'data')
self.assertEquals(retained_message.qos, QOS_0)
await broker.shutdown()
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
future.set_result(True)
except Exception as ae:
@ -284,31 +291,32 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_subscribe_publish(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
sub_client = MQTTClient()
await sub_client.connect('mqtt://localhost')
ret = await sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
yield from sub_client.connect('mqtt://localhost')
ret = yield from sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
self.assertEquals(ret, [QOS_0, QOS_1, QOS_2])
await self._client_publish('/qos0', b'data', QOS_0)
await self._client_publish('/qos1', b'data', QOS_1)
await self._client_publish('/qos2', b'data', QOS_2)
await asyncio.sleep(0.1)
yield from self._client_publish('/qos0', b'data', QOS_0)
yield from self._client_publish('/qos1', b'data', QOS_1)
yield from self._client_publish('/qos2', b'data', QOS_2)
yield from asyncio.sleep(0.1)
for qos in [QOS_0, QOS_1, QOS_2]:
message = await sub_client.deliver_message()
message = yield from sub_client.deliver_message()
self.assertIsNotNone(message)
self.assertEquals(message.topic, '/qos%s' % qos)
self.assertEquals(message.data, b'data')
self.assertEquals(message.qos, qos)
await sub_client.disconnect()
await asyncio.sleep(0.1)
await broker.shutdown()
yield from sub_client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
future.set_result(True)
except Exception as ae:
@ -319,35 +327,36 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
@patch('hbmqtt.broker.PluginManager')
def test_client_publish_retain_subscribe(self, MockPluginManager):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
await broker.start()
yield from broker.start()
self.assertTrue(broker.transitions.is_started())
sub_client = MQTTClient()
await sub_client.connect('mqtt://localhost', cleansession=False)
ret = await sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
yield from sub_client.connect('mqtt://localhost', cleansession=False)
ret = yield from sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
self.assertEquals(ret, [QOS_0, QOS_1, QOS_2])
await sub_client.disconnect()
await asyncio.sleep(0.1)
yield from sub_client.disconnect()
yield from asyncio.sleep(0.1)
await self._client_publish('/qos0', b'data', QOS_0, retain=True)
await self._client_publish('/qos1', b'data', QOS_1, retain=True)
await self._client_publish('/qos2', b'data', QOS_2, retain=True)
await sub_client.reconnect()
yield from self._client_publish('/qos0', b'data', QOS_0, retain=True)
yield from self._client_publish('/qos1', b'data', QOS_1, retain=True)
yield from self._client_publish('/qos2', b'data', QOS_2, retain=True)
yield from sub_client.reconnect()
for qos in [QOS_0, QOS_1, QOS_2]:
log.debug("TEST QOS: %d" % qos)
message = await sub_client.deliver_message()
message = yield from sub_client.deliver_message()
log.debug("Message: " + repr(message.publish_packet))
self.assertIsNotNone(message)
self.assertEquals(message.topic, '/qos%s' % qos)
self.assertEquals(message.data, b'data')
self.assertEquals(message.qos, qos)
await sub_client.disconnect()
await asyncio.sleep(0.1)
await broker.shutdown()
yield from sub_client.disconnect()
yield from asyncio.sleep(0.1)
yield from broker.shutdown()
self.assertTrue(broker.transitions.is_stopped())
future.set_result(True)
except Exception as ae:
@ -358,10 +367,11 @@ class BrokerTest(unittest.TestCase):
if future.exception():
raise future.exception()
async def _client_publish(self, topic, data, qos, retain=False):
@asyncio.coroutine
def _client_publish(self, topic, data, qos, retain=False):
pub_client = MQTTClient()
ret = await pub_client.connect('mqtt://localhost/')
ret = yield from pub_client.connect('mqtt://localhost/')
self.assertEqual(ret, 0)
ret = await pub_client.publish(topic, data, qos, retain)
await pub_client.disconnect()
ret = yield from pub_client.publish(topic, data, qos, retain)
yield from pub_client.disconnect()
return ret

Wyświetl plik

@ -22,12 +22,13 @@ class MQTTClientTest(unittest.TestCase):
self.loop.close()
def test_connect_tcp(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
ret = await client.connect('mqtt://test.mosquitto.org/')
ret = yield from client.connect('mqtt://test.mosquitto.org/')
self.assertIsNotNone(client.session)
await client.disconnect()
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -38,13 +39,14 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_connect_tcp_secure(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
ca = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mosquitto.org.crt')
ret = await client.connect('mqtts://test.mosquitto.org/', cafile=ca)
ret = yield from client.connect('mqtts://test.mosquitto.org/', cafile=ca)
self.assertIsNotNone(client.session)
await client.disconnect()
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -55,11 +57,12 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_connect_tcp_failure(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
config = {'auto_reconnect': False}
client = MQTTClient(config=config)
ret = await client.connect('mqtt://localhost/')
ret = yield from client.connect('mqtt://localhost/')
except ConnectException as e:
future.set_result(True)
@ -69,12 +72,13 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_connect_ws(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
await client.connect('ws://test.mosquitto.org:8080/')
yield from client.connect('ws://test.mosquitto.org:8080/')
self.assertIsNotNone(client.session)
await client.disconnect()
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -85,13 +89,14 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_connect_ws_secure(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
ca = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mosquitto.org.crt')
await client.connect('wss://test.mosquitto.org:8081/', cafile=ca)
yield from client.connect('wss://test.mosquitto.org:8081/', cafile=ca)
self.assertIsNotNone(client.session)
await client.disconnect()
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -102,13 +107,14 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_ping(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
ret = await client.connect('mqtt://test.mosquitto.org/')
ret = yield from client.connect('mqtt://test.mosquitto.org/')
self.assertIsNotNone(client.session)
await client.ping()
await client.disconnect()
yield from client.ping()
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -119,12 +125,13 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_subscribe(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
await client.connect('mqtt://test.mosquitto.org/')
yield from client.connect('mqtt://test.mosquitto.org/')
self.assertIsNotNone(client.session)
ret = await client.subscribe([
ret = yield from client.subscribe([
('$SYS/broker/uptime', QOS_0),
('$SYS/broker/uptime', QOS_1),
('$SYS/broker/uptime', QOS_2),
@ -132,7 +139,7 @@ class MQTTClientTest(unittest.TestCase):
self.assertEquals(ret[0], QOS_0)
self.assertEquals(ret[1], QOS_1)
self.assertEquals(ret[2], QOS_2)
await client.disconnect()
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)
@ -143,17 +150,18 @@ class MQTTClientTest(unittest.TestCase):
raise future.exception()
def test_unsubscribe(self):
async def test_coro():
@asyncio.coroutine
def test_coro():
try:
client = MQTTClient()
await client.connect('mqtt://test.mosquitto.org/')
yield from client.connect('mqtt://test.mosquitto.org/')
self.assertIsNotNone(client.session)
ret = await client.subscribe([
ret = yield from client.subscribe([
('$SYS/broker/uptime', QOS_0),
])
self.assertEquals(ret[0], QOS_0)
await client.unsubscribe(['$SYS/broker/uptime'])
await client.disconnect()
yield from client.unsubscribe(['$SYS/broker/uptime'])
yield from client.disconnect()
future.set_result(True)
except Exception as ae:
future.set_exception(ae)