kopia lustrzana https://github.com/Yakifo/amqtt
Revert to 3.4 coroutine syntax
rodzic
8d7fbaaff4
commit
4dcf8eb477
|
@ -15,7 +15,8 @@ class ReaderAdapter:
|
|||
Reader adapters are used to adapt read operations on the network depending on the protocol used
|
||||
"""
|
||||
|
||||
async def read(self, n=-1) -> bytes:
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1) -> bytes:
|
||||
"""
|
||||
Read up to n bytes. If n is not provided, or set to -1, read until EOF and return all read bytes.
|
||||
If the EOF was received and the internal buffer is empty, return an empty bytes object.
|
||||
|
@ -40,7 +41,8 @@ class WriterAdapter:
|
|||
write some data to the protocol layer
|
||||
"""
|
||||
|
||||
async def drain(self):
|
||||
@asyncio.coroutine
|
||||
def drain(self):
|
||||
"""
|
||||
Let the write buffer of the underlying transport a chance to be flushed.
|
||||
"""
|
||||
|
@ -50,7 +52,8 @@ class WriterAdapter:
|
|||
Return peer socket info (remote address and remote port as tuple
|
||||
"""
|
||||
|
||||
async def close(self):
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
"""
|
||||
Close the protocol connection
|
||||
"""
|
||||
|
@ -65,19 +68,21 @@ class WebSocketsReader(ReaderAdapter):
|
|||
self._protocol = protocol
|
||||
self._stream = io.BytesIO(b'')
|
||||
|
||||
async def read(self, n=-1) -> bytes:
|
||||
await self._feed_buffer(n)
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1) -> bytes:
|
||||
yield from self._feed_buffer(n)
|
||||
data = self._stream.read(n)
|
||||
return data
|
||||
|
||||
async def _feed_buffer(self, n=1):
|
||||
@asyncio.coroutine
|
||||
def _feed_buffer(self, n=1):
|
||||
"""
|
||||
Feed the data buffer by reading a Websocket message.
|
||||
:param n: if given, feed buffer until it contains at least n bytes
|
||||
"""
|
||||
buffer = bytearray(self._stream.read())
|
||||
while len(buffer) < n:
|
||||
message = await self._protocol.recv()
|
||||
message = yield from self._protocol.recv()
|
||||
if message is None:
|
||||
break
|
||||
if not isinstance(message, bytes):
|
||||
|
@ -101,21 +106,23 @@ class WebSocketsWriter(WriterAdapter):
|
|||
"""
|
||||
self._stream.write(data)
|
||||
|
||||
async def drain(self):
|
||||
@asyncio.coroutine
|
||||
def drain(self):
|
||||
"""
|
||||
Let the write buffer of the underlying transport a chance to be flushed.
|
||||
"""
|
||||
data = self._stream.getvalue()
|
||||
if len(data):
|
||||
await self._protocol.send(data)
|
||||
yield from self._protocol.send(data)
|
||||
self._stream = io.BytesIO(b'')
|
||||
|
||||
def get_peer_info(self):
|
||||
extra_info = self._protocol.writer.get_extra_info('peername')
|
||||
return extra_info[0], extra_info[1]
|
||||
|
||||
async def close(self):
|
||||
await self._protocol.close()
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
yield from self._protocol.close()
|
||||
|
||||
|
||||
class StreamReaderAdapter(ReaderAdapter):
|
||||
|
@ -127,8 +134,9 @@ class StreamReaderAdapter(ReaderAdapter):
|
|||
def __init__(self, reader: StreamReader):
|
||||
self._reader = reader
|
||||
|
||||
async def read(self, n=-1) -> bytes:
|
||||
return await self._reader.read(n)
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1) -> bytes:
|
||||
return (yield from self._reader.read(n))
|
||||
|
||||
def feed_eof(self):
|
||||
return self._reader.feed_eof()
|
||||
|
@ -147,15 +155,17 @@ class StreamWriterAdapter(WriterAdapter):
|
|||
def write(self, data):
|
||||
self._writer.write(data)
|
||||
|
||||
async def drain(self):
|
||||
await self._writer.drain()
|
||||
@asyncio.coroutine
|
||||
def drain(self):
|
||||
yield from self._writer.drain()
|
||||
|
||||
def get_peer_info(self):
|
||||
extra_info = self._writer.get_extra_info('peername')
|
||||
return extra_info[0], extra_info[1]
|
||||
|
||||
async def close(self):
|
||||
await self._writer.drain()
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
yield from self._writer.drain()
|
||||
if self._writer.can_write_eof():
|
||||
self._writer.write_eof()
|
||||
self._writer.close()
|
||||
|
@ -169,7 +179,8 @@ class BufferReader(ReaderAdapter):
|
|||
def __init__(self, buffer: bytes):
|
||||
self._stream = io.BytesIO(buffer)
|
||||
|
||||
async def read(self, n=-1) -> bytes:
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1) -> bytes:
|
||||
return self._stream.read(n)
|
||||
|
||||
|
||||
|
@ -187,7 +198,8 @@ class BufferWriter(WriterAdapter):
|
|||
"""
|
||||
self._stream.write(data)
|
||||
|
||||
async def drain(self):
|
||||
@asyncio.coroutine
|
||||
def drain(self):
|
||||
pass
|
||||
|
||||
def get_buffer(self):
|
||||
|
@ -196,5 +208,6 @@ class BufferWriter(WriterAdapter):
|
|||
def get_peer_info(self):
|
||||
return "BufferWriter", 0
|
||||
|
||||
async def close(self):
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
self._stream.close()
|
||||
|
|
135
hbmqtt/broker.py
135
hbmqtt/broker.py
|
@ -72,9 +72,10 @@ class Server:
|
|||
else:
|
||||
self.semaphore = None
|
||||
|
||||
async def acquire_connection(self):
|
||||
@asyncio.coroutine
|
||||
def acquire_connection(self):
|
||||
if self.semaphore:
|
||||
await self.semaphore.acquire()
|
||||
yield from self.semaphore.acquire()
|
||||
self.conn_count += 1
|
||||
if self.max_connections > 0:
|
||||
self.logger.info("Listener '%s': %d/%d connections acquired" %
|
||||
|
@ -94,10 +95,11 @@ class Server:
|
|||
self.logger.info("Listener '%s': %d connections acquired" %
|
||||
(self.listener_name, self.conn_count))
|
||||
|
||||
async def close_instance(self):
|
||||
@asyncio.coroutine
|
||||
def close_instance(self):
|
||||
if self.instance:
|
||||
self.instance.close()
|
||||
await self.instance.wait_closed()
|
||||
yield from self.instance.wait_closed()
|
||||
|
||||
|
||||
class BrokerContext(BaseContext):
|
||||
|
@ -110,8 +112,9 @@ class BrokerContext(BaseContext):
|
|||
self.config = None
|
||||
self._broker_instance = broker
|
||||
|
||||
async def broadcast_message(self, topic, data, qos=None):
|
||||
await self._broker_instance.internal_message_broadcast(topic, data, qos)
|
||||
@asyncio.coroutine
|
||||
def broadcast_message(self, topic, data, qos=None):
|
||||
yield from self._broker_instance.internal_message_broadcast(topic, data, qos)
|
||||
|
||||
def retain_message(self, topic_name, data, qos=None):
|
||||
self._broker_instance.retain_message(None, topic_name, data, qos)
|
||||
|
@ -220,7 +223,8 @@ class Broker:
|
|||
self.transitions.add_transition(trigger='stopping_failure', source='stopping', dest='not_stopped')
|
||||
self.transitions.add_transition(trigger='start', source='stopped', dest='starting')
|
||||
|
||||
async def start(self):
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
try:
|
||||
self._sessions = dict()
|
||||
self._subscriptions = dict()
|
||||
|
@ -231,7 +235,7 @@ class Broker:
|
|||
self.logger.warn("[WARN-0001] Invalid method call at this moment: %s" % me)
|
||||
raise BrokerException("Broker instance can't be started: %s" % me)
|
||||
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
|
||||
try:
|
||||
# Start network listeners
|
||||
for listener_name in self.listeners_config:
|
||||
|
@ -259,7 +263,7 @@ class Broker:
|
|||
if listener['type'] == 'tcp':
|
||||
address, port = listener['bind'].split(':')
|
||||
cb_partial = partial(self.stream_connected, listener_name=listener_name)
|
||||
instance = await asyncio.start_server(cb_partial,
|
||||
instance = yield from asyncio.start_server(cb_partial,
|
||||
address,
|
||||
port,
|
||||
ssl=sc,
|
||||
|
@ -268,14 +272,14 @@ class Broker:
|
|||
elif listener['type'] == 'ws':
|
||||
address, port = listener['bind'].split(':')
|
||||
cb_partial = partial(self.ws_connected, listener_name=listener_name)
|
||||
instance = await websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
|
||||
instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
|
||||
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
|
||||
|
||||
self.logger.info("Listener '%s' bind to %s (max_connecionts=%d)" %
|
||||
(listener_name, listener['bind'], max_connections))
|
||||
|
||||
self.transitions.starting_success()
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
|
||||
|
||||
#Start broadcast loop
|
||||
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop(), loop=self._loop)
|
||||
|
@ -286,7 +290,8 @@ class Broker:
|
|||
self.transitions.starting_fail()
|
||||
raise BrokerException("Broker instance can't be started: %s" % e)
|
||||
|
||||
async def shutdown(self):
|
||||
@asyncio.coroutine
|
||||
def shutdown(self):
|
||||
try:
|
||||
self._sessions = dict()
|
||||
self._subscriptions = dict()
|
||||
|
@ -297,7 +302,7 @@ class Broker:
|
|||
raise BrokerException("Broker instance can't be stopped: %s" % me)
|
||||
|
||||
# Fire broker_shutdown event to plugins
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
|
||||
|
||||
# Stop broadcast loop
|
||||
if self._broadcast_task:
|
||||
|
@ -307,44 +312,48 @@ class Broker:
|
|||
|
||||
for listener_name in self._servers:
|
||||
server = self._servers[listener_name]
|
||||
await server.close_instance()
|
||||
yield from server.close_instance()
|
||||
self.logger.debug("Broker closing")
|
||||
self.logger.info("Broker closed")
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
|
||||
self.transitions.stopping_success()
|
||||
|
||||
async def internal_message_broadcast(self, topic, data, qos=None):
|
||||
return await self._broadcast_message(None, topic, data)
|
||||
@asyncio.coroutine
|
||||
def internal_message_broadcast(self, topic, data, qos=None):
|
||||
return (yield from self._broadcast_message(None, topic, data))
|
||||
|
||||
async def ws_connected(self, websocket, uri, listener_name):
|
||||
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
||||
@asyncio.coroutine
|
||||
def ws_connected(self, websocket, uri, listener_name):
|
||||
yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
||||
|
||||
async def stream_connected(self, reader, writer, listener_name):
|
||||
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||
@asyncio.coroutine
|
||||
def stream_connected(self, reader, writer, listener_name):
|
||||
yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||
|
||||
async def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
|
||||
@asyncio.coroutine
|
||||
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
|
||||
# Wait for connection available on listener
|
||||
server = self._servers.get(listener_name, None)
|
||||
if not server:
|
||||
raise BrokerException("Invalid listener name '%s'" % listener_name)
|
||||
await server.acquire_connection()
|
||||
yield from server.acquire_connection()
|
||||
|
||||
remote_address, remote_port = writer.get_peer_info()
|
||||
self.logger.info("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
|
||||
|
||||
# Wait for first packet and expect a CONNECT
|
||||
try:
|
||||
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
|
||||
handler, client_session = yield from BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
|
||||
except HBMQTTException as exc:
|
||||
self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" %
|
||||
(format_client_message(address=remote_address, port=remote_port), exc))
|
||||
await writer.close()
|
||||
yield from writer.close()
|
||||
self.logger.debug("Connection closed")
|
||||
return
|
||||
except MQTTException as me:
|
||||
self.logger.error('Invalid connection from %s : %s' %
|
||||
(format_client_message(address=remote_address, port=remote_port), me))
|
||||
await writer.close()
|
||||
yield from writer.close()
|
||||
self.logger.debug("Connection closed")
|
||||
return
|
||||
|
||||
|
@ -371,9 +380,9 @@ class Broker:
|
|||
handler.attach(client_session, reader, writer)
|
||||
self._sessions[client_session.client_id] = (client_session, handler)
|
||||
|
||||
authenticated = await self.authenticate(client_session, self.listeners_config[listener_name])
|
||||
authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
|
||||
if not authenticated:
|
||||
await writer.close()
|
||||
yield from writer.close()
|
||||
return
|
||||
|
||||
while True:
|
||||
|
@ -383,15 +392,15 @@ class Broker:
|
|||
except MachineError:
|
||||
self.logger.warning("Client %s is reconnecting too quickly, make it wait" % client_session.client_id)
|
||||
# Wait a bit may be client is reconnecting too fast
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
await handler.mqtt_connack_authorize(authenticated)
|
||||
yield from asyncio.sleep(1, loop=self._loop)
|
||||
yield from handler.mqtt_connack_authorize(authenticated)
|
||||
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
||||
|
||||
self.logger.debug("%s Start messages handling" % client_session.client_id)
|
||||
await handler.start()
|
||||
yield from handler.start()
|
||||
self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize())
|
||||
await self.publish_session_retained_messages(client_session)
|
||||
yield from self.publish_session_retained_messages(client_session)
|
||||
|
||||
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
|
||||
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop)
|
||||
|
@ -401,7 +410,7 @@ class Broker:
|
|||
connected = True
|
||||
while connected:
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
done, pending = yield from asyncio.wait(
|
||||
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
|
||||
return_when=asyncio.FIRST_COMPLETED, loop=self._loop)
|
||||
if disconnect_waiter in done:
|
||||
|
@ -413,7 +422,7 @@ class Broker:
|
|||
if client_session.will_flag:
|
||||
self.logger.debug("Client %s disconnected abnormally, sending will message" %
|
||||
format_client_message(client_session))
|
||||
await self._broadcast_message(
|
||||
yield from self._broadcast_message(
|
||||
client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
|
@ -424,21 +433,21 @@ class Broker:
|
|||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
self.logger.debug("%s Disconnecting session" % client_session.client_id)
|
||||
await self._stop_handler(handler)
|
||||
yield from self._stop_handler(handler)
|
||||
client_session.transitions.disconnect()
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
||||
await writer.close()
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
||||
yield from writer.close()
|
||||
connected = False
|
||||
if unsubscribe_waiter in done:
|
||||
self.logger.debug("%s handling unsubscription" % client_session.client_id)
|
||||
unsubscription = unsubscribe_waiter.result()
|
||||
for topic in unsubscription['topics']:
|
||||
self._del_subscription(topic, client_session)
|
||||
await self.plugins_manager.fire_event(
|
||||
yield from self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
||||
client_id=client_session.client_id,
|
||||
topic=topic)
|
||||
await handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
|
||||
yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
|
||||
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
|
||||
if subscribe_waiter in done:
|
||||
self.logger.debug("%s handling subscription" % client_session.client_id)
|
||||
|
@ -446,25 +455,25 @@ class Broker:
|
|||
return_codes = []
|
||||
for subscription in subscriptions['topics']:
|
||||
return_codes.append(self.add_subscription(subscription, client_session))
|
||||
await handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
|
||||
yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
|
||||
for index, subscription in enumerate(subscriptions['topics']):
|
||||
if return_codes[index] != 0x80:
|
||||
await self.plugins_manager.fire_event(
|
||||
yield from self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
client_id=client_session.client_id,
|
||||
topic=subscription[0],
|
||||
qos=subscription[1])
|
||||
await self.publish_retained_messages_for_subscription(subscription, client_session)
|
||||
yield from self.publish_retained_messages_for_subscription(subscription, client_session)
|
||||
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
|
||||
self.logger.debug(repr(self._subscriptions))
|
||||
if wait_deliver in done:
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug("%s handling message delivery" % client_session.client_id)
|
||||
app_message = wait_deliver.result()
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
client_id=client_session.client_id,
|
||||
message=app_message)
|
||||
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
||||
yield from self._broadcast_message(client_session, app_message.topic, app_message.data)
|
||||
if app_message.publish_packet.retain_flag:
|
||||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
|
||||
|
@ -489,18 +498,20 @@ class Broker:
|
|||
handler.attach(session, reader, writer)
|
||||
return handler
|
||||
|
||||
async def _stop_handler(self, handler):
|
||||
@asyncio.coroutine
|
||||
def _stop_handler(self, handler):
|
||||
"""
|
||||
Stop a running handler and detach if from the session
|
||||
:param handler:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
await handler.stop()
|
||||
yield from handler.stop()
|
||||
except Exception as e:
|
||||
self.logger.error(e)
|
||||
|
||||
async def authenticate(self, session: Session, listener):
|
||||
@asyncio.coroutine
|
||||
def authenticate(self, session: Session, listener):
|
||||
"""
|
||||
This method call the authenticate method on registered plugins to test user authentication.
|
||||
User is considered authenticated if all plugins called returns True.
|
||||
|
@ -516,7 +527,7 @@ class Broker:
|
|||
auth_config = self.config.get('auth', None)
|
||||
if auth_config:
|
||||
auth_plugins = auth_config.get('plugins', None)
|
||||
returns = await self.plugins_manager.map_plugin_coro(
|
||||
returns = yield from self.plugins_manager.map_plugin_coro(
|
||||
"authenticate",
|
||||
session=session,
|
||||
filter_plugins=auth_plugins)
|
||||
|
@ -615,13 +626,14 @@ class Broker:
|
|||
else:
|
||||
return False
|
||||
|
||||
async def _broadcast_loop(self):
|
||||
@asyncio.coroutine
|
||||
def _broadcast_loop(self):
|
||||
running_tasks = deque()
|
||||
try:
|
||||
while True:
|
||||
while running_tasks and running_tasks[0].done():
|
||||
running_tasks.popleft()
|
||||
broadcast = await self._broadcast_queue.get()
|
||||
broadcast = yield from self._broadcast_queue.get()
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug("broadcasting %r" % broadcast)
|
||||
for k_filter in self._subscriptions:
|
||||
|
@ -645,13 +657,14 @@ class Broker:
|
|||
broadcast['topic'], format_client_message(session=target_session)))
|
||||
retained_message = RetainedApplicationMessage(
|
||||
broadcast['session'], broadcast['topic'], broadcast['data'], qos)
|
||||
await target_session.retained_messages.put(retained_message)
|
||||
yield from target_session.retained_messages.put(retained_message)
|
||||
except CancelledError:
|
||||
# Wait until current broadcasting tasks end
|
||||
if running_tasks:
|
||||
await asyncio.wait(running_tasks, loop=self._loop)
|
||||
yield from asyncio.wait(running_tasks, loop=self._loop)
|
||||
|
||||
async def _broadcast_message(self, session, topic, data, force_qos=None):
|
||||
@asyncio.coroutine
|
||||
def _broadcast_message(self, session, topic, data, force_qos=None):
|
||||
broadcast = {
|
||||
'session': session,
|
||||
'topic': topic,
|
||||
|
@ -659,23 +672,25 @@ class Broker:
|
|||
}
|
||||
if force_qos:
|
||||
broadcast['qos'] = force_qos
|
||||
await self._broadcast_queue.put(broadcast)
|
||||
yield from self._broadcast_queue.put(broadcast)
|
||||
|
||||
async def publish_session_retained_messages(self, session):
|
||||
@asyncio.coroutine
|
||||
def publish_session_retained_messages(self, session):
|
||||
self.logger.debug("Publishing %d messages retained for session %s" %
|
||||
(session.retained_messages.qsize(), format_client_message(session=session))
|
||||
)
|
||||
publish_tasks = []
|
||||
handler = self._get_handler(session)
|
||||
while not session.retained_messages.empty():
|
||||
retained = await session.retained_messages.get()
|
||||
retained = yield from session.retained_messages.get()
|
||||
publish_tasks.append(asyncio.ensure_future(
|
||||
handler.mqtt_publish(
|
||||
retained.topic, retained.data, retained.qos, True), loop=self._loop))
|
||||
if publish_tasks:
|
||||
await asyncio.wait(publish_tasks, loop=self._loop)
|
||||
yield from asyncio.wait(publish_tasks, loop=self._loop)
|
||||
|
||||
async def publish_retained_messages_for_subscription(self, subscription, session):
|
||||
@asyncio.coroutine
|
||||
def publish_retained_messages_for_subscription(self, subscription, session):
|
||||
self.logger.debug("Begin broadcasting messages retained due to subscription on '%s' from %s" %
|
||||
(subscription[0], format_client_message(session=session)))
|
||||
publish_tasks = []
|
||||
|
@ -689,7 +704,7 @@ class Broker:
|
|||
handler.mqtt_publish(
|
||||
retained.topic, retained.data, subscription[1], True), loop=self._loop))
|
||||
if publish_tasks:
|
||||
await asyncio.wait(publish_tasks, loop=self._loop)
|
||||
yield from asyncio.wait(publish_tasks, loop=self._loop)
|
||||
self.logger.debug("End broadcasting messages retained due to subscription on '%s' from %s" %
|
||||
(subscription[0], format_client_message(session=session)))
|
||||
|
||||
|
|
|
@ -57,11 +57,12 @@ def mqtt_connected(func):
|
|||
:param func: coroutine to be called once connected
|
||||
:return: coroutine result
|
||||
"""
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self._connected_state.is_set():
|
||||
base_logger.warning("Client not connected, waiting for it")
|
||||
await self._connected_state.wait()
|
||||
return await func(self, *args, **kwargs)
|
||||
yield from self._connected_state.wait()
|
||||
return (yield from func(self, *args, **kwargs))
|
||||
return wrapper
|
||||
|
||||
|
||||
|
@ -118,7 +119,8 @@ class MQTTClient:
|
|||
self.client_tasks = deque()
|
||||
|
||||
|
||||
async def connect(self,
|
||||
@asyncio.coroutine
|
||||
def connect(self,
|
||||
uri=None,
|
||||
cleansession=None,
|
||||
cafile=None,
|
||||
|
@ -135,28 +137,30 @@ class MQTTClient:
|
|||
self.logger.debug("Connect to: %s" % uri)
|
||||
|
||||
try:
|
||||
return await self._do_connect()
|
||||
return (yield from self._do_connect())
|
||||
except BaseException as be:
|
||||
self.logger.warning("Connection failed: %r" % be)
|
||||
auto_reconnect = self.config.get('auto_reconnect', False)
|
||||
if not auto_reconnect:
|
||||
raise
|
||||
else:
|
||||
return await self.reconnect()
|
||||
return (yield from self.reconnect())
|
||||
|
||||
@mqtt_connected
|
||||
async def disconnect(self):
|
||||
@asyncio.coroutine
|
||||
def disconnect(self):
|
||||
if self.session.transitions.is_connected():
|
||||
if not self._disconnect_task.done():
|
||||
self._disconnect_task.cancel()
|
||||
await self._handler.mqtt_disconnect()
|
||||
yield from self._handler.mqtt_disconnect()
|
||||
self._connected_state.clear()
|
||||
await self._handler.stop()
|
||||
yield from self._handler.stop()
|
||||
self.session.transitions.disconnect()
|
||||
else:
|
||||
self.logger.warn("Client session is not currently connected, ignoring call")
|
||||
|
||||
async def reconnect(self, cleansession=None):
|
||||
@asyncio.coroutine
|
||||
def reconnect(self, cleansession=None):
|
||||
if self.session.transitions.is_connected():
|
||||
self.logger.warn("Client already connected")
|
||||
return CONNECTION_ACCEPTED
|
||||
|
@ -167,11 +171,11 @@ class MQTTClient:
|
|||
reconnect_max_interval = self.config.get('reconnect_max_interval', 10)
|
||||
reconnect_retries = self.config.get('reconnect_retries', 5)
|
||||
nb_attempt = 1
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
yield from asyncio.sleep(1, loop=self._loop)
|
||||
while True:
|
||||
try:
|
||||
self.logger.debug("Reconnect attempt %d ..." % nb_attempt)
|
||||
return (await self._do_connect())
|
||||
return (yield from self._do_connect())
|
||||
except BaseException as e:
|
||||
self.logger.warning("Reconnection attempt failed: %r" % e)
|
||||
if nb_attempt > reconnect_retries:
|
||||
|
@ -180,29 +184,32 @@ class MQTTClient:
|
|||
exp = 2 ** nb_attempt
|
||||
delay = exp if exp < reconnect_max_interval else reconnect_max_interval
|
||||
self.logger.debug("Waiting %d second before next attempt" % delay)
|
||||
await asyncio.sleep(delay, loop=self._loop)
|
||||
yield from asyncio.sleep(delay, loop=self._loop)
|
||||
nb_attempt += 1
|
||||
|
||||
|
||||
async def _do_connect(self):
|
||||
return_code = await self._connect_coro()
|
||||
@asyncio.coroutine
|
||||
def _do_connect(self):
|
||||
return_code = yield from self._connect_coro()
|
||||
self._disconnect_task = asyncio.ensure_future(self.handle_connection_close(), loop=self._loop)
|
||||
return return_code
|
||||
|
||||
@mqtt_connected
|
||||
async def ping(self):
|
||||
@asyncio.coroutine
|
||||
def ping(self):
|
||||
"""
|
||||
Send a MQTT ping request and wait for response
|
||||
:return: None
|
||||
"""
|
||||
if self.session.transitions.is_connected():
|
||||
await self._handler.mqtt_ping()
|
||||
yield from self._handler.mqtt_ping()
|
||||
else:
|
||||
self.logger.warn("MQTT PING request incompatible with current session state '%s'" %
|
||||
self.session.transitions.state)
|
||||
|
||||
@mqtt_connected
|
||||
async def publish(self, topic, message, qos=None, retain=None):
|
||||
@asyncio.coroutine
|
||||
def publish(self, topic, message, qos=None, retain=None):
|
||||
def get_retain_and_qos():
|
||||
if qos:
|
||||
assert qos in (QOS_0, QOS_1, QOS_2)
|
||||
|
@ -223,27 +230,31 @@ class MQTTClient:
|
|||
pass
|
||||
return _qos, _retain
|
||||
(app_qos, app_retain) = get_retain_and_qos()
|
||||
return await self._handler.mqtt_publish(topic, message, app_qos, app_retain)
|
||||
return (yield from self._handler.mqtt_publish(topic, message, app_qos, app_retain))
|
||||
|
||||
@mqtt_connected
|
||||
async def subscribe(self, topics):
|
||||
return await self._handler.mqtt_subscribe(topics, self.session.next_packet_id)
|
||||
@asyncio.coroutine
|
||||
def subscribe(self, topics):
|
||||
return (yield from self._handler.mqtt_subscribe(topics, self.session.next_packet_id))
|
||||
|
||||
@mqtt_connected
|
||||
async def unsubscribe(self, topics):
|
||||
await self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
|
||||
@asyncio.coroutine
|
||||
def unsubscribe(self, topics):
|
||||
yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
|
||||
|
||||
async def deliver_message(self, timeout=None):
|
||||
@asyncio.coroutine
|
||||
def deliver_message(self, timeout=None):
|
||||
deliver_task = asyncio.ensure_future(self._handler.mqtt_deliver_next_message(), loop=self._loop)
|
||||
self.client_tasks.append(deliver_task)
|
||||
self.logger.debug("Waiting message delivery")
|
||||
await asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
|
||||
yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
|
||||
if deliver_task.exception():
|
||||
raise deliver_task.exception()
|
||||
self.client_tasks.pop()
|
||||
return deliver_task.result()
|
||||
|
||||
async def _connect_coro(self):
|
||||
@asyncio.coroutine
|
||||
def _connect_coro(self):
|
||||
kwargs = dict()
|
||||
|
||||
# Decode URI attributes
|
||||
|
@ -287,13 +298,13 @@ class MQTTClient:
|
|||
# Open connection
|
||||
if scheme in ('mqtt', 'mqtts'):
|
||||
conn_reader, conn_writer = \
|
||||
await asyncio.open_connection(
|
||||
yield from asyncio.open_connection(
|
||||
self.session.remote_address,
|
||||
self.session.remote_port, loop=self._loop, **kwargs)
|
||||
reader = StreamReaderAdapter(conn_reader)
|
||||
writer = StreamWriterAdapter(conn_writer)
|
||||
elif scheme in ('ws', 'wss'):
|
||||
websocket = await websockets.connect(
|
||||
websocket = yield from websockets.connect(
|
||||
self.session.broker_uri,
|
||||
subprotocols=['mqtt'],
|
||||
loop=self._loop,
|
||||
|
@ -302,7 +313,7 @@ class MQTTClient:
|
|||
writer = WebSocketsWriter(websocket)
|
||||
# Start MQTT protocol
|
||||
self._handler.attach(self.session, reader, writer)
|
||||
return_code = await self._handler.mqtt_connect()
|
||||
return_code = yield from self._handler.mqtt_connect()
|
||||
if return_code is not CONNECTION_ACCEPTED:
|
||||
self.session.transitions.disconnect()
|
||||
self.logger.warning("Connection rejected with code '%s'" % return_code)
|
||||
|
@ -311,7 +322,7 @@ class MQTTClient:
|
|||
raise exc
|
||||
else:
|
||||
# Handle MQTT protocol
|
||||
await self._handler.start()
|
||||
yield from self._handler.start()
|
||||
self.session.transitions.connect()
|
||||
self._connected_state.set()
|
||||
self.logger.debug("connected to %s:%s" % (self.session.remote_address, self.session.remote_port))
|
||||
|
@ -329,17 +340,18 @@ class MQTTClient:
|
|||
self.session.transitions.disconnect()
|
||||
raise ConnectException(e)
|
||||
|
||||
async def handle_connection_close(self):
|
||||
@asyncio.coroutine
|
||||
def handle_connection_close(self):
|
||||
self.logger.debug("Watch broker disconnection")
|
||||
# Wait for disconnection from broker (like connection lost)
|
||||
await self._handler.wait_disconnect()
|
||||
yield from self._handler.wait_disconnect()
|
||||
self.logger.warning("Disconnected from broker")
|
||||
|
||||
# Block client API
|
||||
self._connected_state.clear()
|
||||
|
||||
# stop an clean handler
|
||||
await self._handler.stop()
|
||||
yield from self._handler.stop()
|
||||
self._handler.detach()
|
||||
self.session.transitions.disconnect()
|
||||
|
||||
|
@ -347,7 +359,7 @@ class MQTTClient:
|
|||
# Try reconnection
|
||||
self.logger.debug("Auto-reconnecting")
|
||||
try:
|
||||
await self.reconnect()
|
||||
yield from self.reconnect()
|
||||
except ConnectException:
|
||||
# Cancel client pending tasks
|
||||
while self.client_tasks:
|
||||
|
|
|
@ -14,6 +14,7 @@ def bytes_to_hex_str(data):
|
|||
"""
|
||||
return '0x' + ''.join(format(b, '02x') for b in data)
|
||||
|
||||
|
||||
def bytes_to_int(data):
|
||||
"""
|
||||
convert a sequence of bytes to an integer using big endian byte ordering
|
||||
|
@ -25,6 +26,7 @@ def bytes_to_int(data):
|
|||
except:
|
||||
return data
|
||||
|
||||
|
||||
def int_to_bytes(int_value: int, length: int) -> bytes:
|
||||
"""
|
||||
convert an integer to a sequence of bytes using big endian byte ordering
|
||||
|
@ -38,28 +40,32 @@ def int_to_bytes(int_value: int, length: int) -> bytes:
|
|||
fmt = "!H"
|
||||
return pack(fmt, int_value)
|
||||
|
||||
async def read_or_raise(reader, n=-1):
|
||||
|
||||
@asyncio.coroutine
|
||||
def read_or_raise(reader, n=-1):
|
||||
"""
|
||||
Read a given byte number from Stream. NoDataException is raised if read gives no data
|
||||
:param reader: reader adapter
|
||||
:param n: number of bytes to read
|
||||
:return: bytes read
|
||||
"""
|
||||
data = await reader.read(n)
|
||||
data = yield from reader.read(n)
|
||||
if not data:
|
||||
raise NoDataException("No more data")
|
||||
return data
|
||||
|
||||
async def decode_string(reader) -> bytes:
|
||||
|
||||
@asyncio.coroutine
|
||||
def decode_string(reader) -> bytes:
|
||||
"""
|
||||
Read a string from a reader and decode it according to MQTT string specification
|
||||
:param reader: Stream reader
|
||||
:return: UTF-8 string read from stream
|
||||
"""
|
||||
length_bytes = await read_or_raise(reader, 2)
|
||||
length_bytes = yield from read_or_raise(reader, 2)
|
||||
str_length = unpack("!H", length_bytes)
|
||||
if str_length[0]:
|
||||
byte_str = await read_or_raise(reader, str_length[0])
|
||||
byte_str = yield from read_or_raise(reader, str_length[0])
|
||||
try:
|
||||
return byte_str.decode(encoding='utf-8')
|
||||
except:
|
||||
|
@ -67,15 +73,17 @@ async def decode_string(reader) -> bytes:
|
|||
else:
|
||||
return ''
|
||||
|
||||
async def decode_data_with_length(reader) -> bytes:
|
||||
|
||||
@asyncio.coroutine
|
||||
def decode_data_with_length(reader) -> bytes:
|
||||
"""
|
||||
Read data from a reader. Data is prefixed with 2 bytes length
|
||||
:param reader: Stream reader
|
||||
:return: bytes read from stream (without length)
|
||||
"""
|
||||
length_bytes = await read_or_raise(reader, 2)
|
||||
length_bytes = yield from read_or_raise(reader, 2)
|
||||
bytes_length = unpack("!H", length_bytes)
|
||||
data = await read_or_raise(reader, bytes_length[0])
|
||||
data = yield from read_or_raise(reader, bytes_length[0])
|
||||
return data
|
||||
|
||||
|
||||
|
@ -89,13 +97,15 @@ def encode_data_with_length(data: bytes) -> bytes:
|
|||
data_length = len(data)
|
||||
return int_to_bytes(data_length, 2) + data
|
||||
|
||||
async def decode_packet_id(reader) -> int:
|
||||
|
||||
@asyncio.coroutine
|
||||
def decode_packet_id(reader) -> int:
|
||||
"""
|
||||
Read a packet ID as 2-bytes int from stream according to MQTT specification (2.3.1)
|
||||
:param reader: Stream reader
|
||||
:return: Packet ID
|
||||
"""
|
||||
packet_id_bytes = await read_or_raise(reader, 2)
|
||||
packet_id_bytes = yield from read_or_raise(reader, 2)
|
||||
packet_id = unpack("!H", packet_id_bytes)
|
||||
return packet_id[0]
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
import asyncio
|
||||
from hbmqtt.mqtt.packet import CONNACK, MQTTPacket, MQTTFixedHeader, MQTTVariableHeader
|
||||
from hbmqtt.codecs import int_to_bytes, read_or_raise, bytes_to_int
|
||||
from hbmqtt.errors import HBMQTTException
|
||||
|
@ -21,8 +22,9 @@ class ConnackVariableHeader(MQTTVariableHeader):
|
|||
self.return_code = return_code
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
|
||||
data = await read_or_raise(reader, 2)
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
|
||||
data = yield from read_or_raise(reader, 2)
|
||||
session_parent = data[0] & 0x01
|
||||
return_code = bytes_to_int(data[1])
|
||||
return cls(session_parent, return_code)
|
||||
|
|
|
@ -93,20 +93,21 @@ class ConnectVariableHeader(MQTTVariableHeader):
|
|||
self.flags |= (val << 3)
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
|
||||
# protocol name
|
||||
protocol_name = await decode_string(reader)
|
||||
protocol_name = yield from decode_string(reader)
|
||||
|
||||
# protocol level
|
||||
protocol_level_byte = await read_or_raise(reader, 1)
|
||||
protocol_level_byte = yield from read_or_raise(reader, 1)
|
||||
protocol_level = bytes_to_int(protocol_level_byte)
|
||||
|
||||
# flags
|
||||
flags_byte = await read_or_raise(reader, 1)
|
||||
flags_byte = yield from read_or_raise(reader, 1)
|
||||
flags = bytes_to_int(flags_byte)
|
||||
|
||||
# keep-alive
|
||||
keep_alive_byte = await read_or_raise(reader, 2)
|
||||
keep_alive_byte = yield from read_or_raise(reader, 2)
|
||||
keep_alive = bytes_to_int(keep_alive_byte)
|
||||
|
||||
return cls(flags, keep_alive, protocol_name, protocol_level)
|
||||
|
@ -140,33 +141,34 @@ class ConnectPayload(MQTTPayload):
|
|||
format(self.client_id, self.will_topic, self.will_message, self.username, self.password)
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
|
||||
variable_header: ConnectVariableHeader):
|
||||
payload = cls()
|
||||
# Client identifier
|
||||
try:
|
||||
payload.client_id = await decode_string(reader)
|
||||
payload.client_id = yield from decode_string(reader)
|
||||
except NoDataException:
|
||||
payload.client_id = None
|
||||
|
||||
# Read will topic, username and password
|
||||
if variable_header.will_flag:
|
||||
try:
|
||||
payload.will_topic = await decode_string(reader)
|
||||
payload.will_message = await decode_data_with_length(reader)
|
||||
payload.will_topic = yield from decode_string(reader)
|
||||
payload.will_message = yield from decode_data_with_length(reader)
|
||||
except NoDataException:
|
||||
payload.will_topic = None
|
||||
payload.will_message = None
|
||||
|
||||
if variable_header.username_flag:
|
||||
try:
|
||||
payload.username = await decode_string(reader)
|
||||
payload.username = yield from decode_string(reader)
|
||||
except NoDataException:
|
||||
payload.username = None
|
||||
|
||||
if variable_header.password_flag:
|
||||
try:
|
||||
payload.password = await decode_string(reader)
|
||||
payload.password = yield from decode_string(reader)
|
||||
except NoDataException:
|
||||
payload.password = None
|
||||
|
||||
|
|
|
@ -58,7 +58,8 @@ class MQTTFixedHeader:
|
|||
|
||||
return out
|
||||
|
||||
async def to_stream(self, writer: WriterAdapter):
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: WriterAdapter):
|
||||
writer.write(self.to_bytes())
|
||||
|
||||
@property
|
||||
|
@ -66,12 +67,14 @@ class MQTTFixedHeader:
|
|||
return len(self.to_bytes())
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter):
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter):
|
||||
"""
|
||||
Read and decode MQTT message fixed header from stream
|
||||
:return: FixedHeader instance
|
||||
"""
|
||||
async def decode_remaining_length():
|
||||
@asyncio.coroutine
|
||||
def decode_remaining_length():
|
||||
"""
|
||||
Decode message length according to MQTT specifications
|
||||
:return:
|
||||
|
@ -80,7 +83,7 @@ class MQTTFixedHeader:
|
|||
value = 0
|
||||
buffer = bytearray()
|
||||
while True:
|
||||
encoded_byte = await reader.read(1)
|
||||
encoded_byte = yield from reader.read(1)
|
||||
int_byte = unpack('!B', encoded_byte)
|
||||
buffer.append(int_byte[0])
|
||||
value += (int_byte[0] & 0x7f) * multiplier
|
||||
|
@ -93,11 +96,11 @@ class MQTTFixedHeader:
|
|||
return value
|
||||
|
||||
try:
|
||||
byte1 = await read_or_raise(reader, 1)
|
||||
byte1 = yield from read_or_raise(reader, 1)
|
||||
int1 = unpack('!B', byte1)
|
||||
msg_type = (int1[0] & 0xf0) >> 4
|
||||
flags = int1[0] & 0x0f
|
||||
remain_length = await decode_remaining_length()
|
||||
remain_length = yield from decode_remaining_length()
|
||||
|
||||
return cls(msg_type, flags, remain_length)
|
||||
except NoDataException:
|
||||
|
@ -112,9 +115,10 @@ class MQTTVariableHeader:
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def to_stream(self, writer: asyncio.StreamWriter):
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
await writer.drain()
|
||||
yield from writer.drain()
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
"""
|
||||
|
@ -127,7 +131,8 @@ class MQTTVariableHeader:
|
|||
return len(self.to_bytes())
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -142,8 +147,9 @@ class PacketIdVariableHeader(MQTTVariableHeader):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
|
||||
packet_id = await decode_packet_id(reader)
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader):
|
||||
packet_id = yield from decode_packet_id(reader)
|
||||
return cls(packet_id)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -154,15 +160,17 @@ class MQTTPayload:
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def to_stream(self, writer: asyncio.StreamWriter):
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
await writer.drain()
|
||||
yield from writer.drain()
|
||||
|
||||
def to_bytes(self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
variable_header: MQTTVariableHeader):
|
||||
pass
|
||||
|
||||
|
@ -178,9 +186,10 @@ class MQTTPacket:
|
|||
self.payload = payload
|
||||
self.protocol_ts = None
|
||||
|
||||
async def to_stream(self, writer: asyncio.StreamWriter):
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
await writer.drain()
|
||||
yield from writer.drain()
|
||||
self.protocol_ts = datetime.now()
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
|
@ -199,16 +208,17 @@ class MQTTPacket:
|
|||
return fixed_header_bytes + variable_header_bytes + payload_bytes
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter, fixed_header=None, variable_header=None):
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter, fixed_header=None, variable_header=None):
|
||||
if fixed_header is None:
|
||||
fixed_header = await cls.FIXED_HEADER.from_stream(reader)
|
||||
fixed_header = yield from cls.FIXED_HEADER.from_stream(reader)
|
||||
if cls.VARIABLE_HEADER:
|
||||
if variable_header is None:
|
||||
variable_header = await cls.VARIABLE_HEADER.from_stream(reader, fixed_header)
|
||||
variable_header = yield from cls.VARIABLE_HEADER.from_stream(reader, fixed_header)
|
||||
else:
|
||||
variable_header = None
|
||||
if cls.PAYLOAD:
|
||||
payload = await cls.PAYLOAD.from_stream(reader, fixed_header, variable_header)
|
||||
payload = yield from cls.PAYLOAD.from_stream(reader, fixed_header, variable_header)
|
||||
else:
|
||||
payload = None
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
import logging
|
||||
import asyncio
|
||||
from asyncio import futures, Queue
|
||||
from hbmqtt.mqtt.protocol.handler import ProtocolHandler
|
||||
from hbmqtt.mqtt.connect import ConnectPacket
|
||||
|
@ -27,18 +27,21 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
self._pending_subscriptions = Queue(loop=self._loop)
|
||||
self._pending_unsubscriptions = Queue(loop=self._loop)
|
||||
|
||||
async def start(self):
|
||||
await super().start()
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
yield from super().start()
|
||||
if self._disconnect_waiter is None:
|
||||
self._disconnect_waiter = futures.Future(loop=self._loop)
|
||||
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
yield from super().stop()
|
||||
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
||||
self._disconnect_waiter.set_result(None)
|
||||
|
||||
async def wait_disconnect(self):
|
||||
return await self._disconnect_waiter
|
||||
@asyncio.coroutine
|
||||
def wait_disconnect(self):
|
||||
return (yield from self._disconnect_waiter)
|
||||
|
||||
def handle_write_timeout(self):
|
||||
pass
|
||||
|
@ -47,16 +50,19 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
||||
self._disconnect_waiter.set_result(None)
|
||||
|
||||
async def handle_disconnect(self, disconnect):
|
||||
@asyncio.coroutine
|
||||
def handle_disconnect(self, disconnect):
|
||||
self.logger.debug("Client disconnecting")
|
||||
if self._disconnect_waiter and not self._disconnect_waiter.done():
|
||||
self.logger.debug("Setting waiter result to %r" % disconnect)
|
||||
self._disconnect_waiter.set_result(disconnect)
|
||||
|
||||
async def handle_connection_closed(self):
|
||||
await self.handle_disconnect(None)
|
||||
@asyncio.coroutine
|
||||
def handle_connection_closed(self):
|
||||
yield from self.handle_disconnect(None)
|
||||
|
||||
async def handle_connect(self, connect: ConnectPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_connect(self, connect: ConnectPacket):
|
||||
# Broker handler shouldn't received CONNECT message during messages handling
|
||||
# as CONNECT messages are managed by the broker on client connection
|
||||
self.logger.error('%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling' %
|
||||
|
@ -64,42 +70,51 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
||||
self._disconnect_waiter.set_result(None)
|
||||
|
||||
async def handle_pingreq(self, pingreq: PingReqPacket):
|
||||
await self._send_packet(PingRespPacket.build())
|
||||
@asyncio.coroutine
|
||||
def handle_pingreq(self, pingreq: PingReqPacket):
|
||||
yield from self._send_packet(PingRespPacket.build())
|
||||
|
||||
async def handle_subscribe(self, subscribe: SubscribePacket):
|
||||
@asyncio.coroutine
|
||||
def handle_subscribe(self, subscribe: SubscribePacket):
|
||||
subscription = {'packet_id': subscribe.variable_header.packet_id, 'topics': subscribe.payload.topics}
|
||||
await self._pending_subscriptions.put(subscription)
|
||||
yield from self._pending_subscriptions.put(subscription)
|
||||
|
||||
async def handle_unsubscribe(self, unsubscribe: UnsubscribePacket):
|
||||
@asyncio.coroutine
|
||||
def handle_unsubscribe(self, unsubscribe: UnsubscribePacket):
|
||||
unsubscription = {'packet_id': unsubscribe.variable_header.packet_id, 'topics': unsubscribe.payload.topics}
|
||||
await self._pending_unsubscriptions.put(unsubscription)
|
||||
yield from self._pending_unsubscriptions.put(unsubscription)
|
||||
|
||||
async def get_next_pending_subscription(self):
|
||||
subscription = await self._pending_subscriptions.get()
|
||||
@asyncio.coroutine
|
||||
def get_next_pending_subscription(self):
|
||||
subscription = yield from self._pending_subscriptions.get()
|
||||
return subscription
|
||||
|
||||
async def get_next_pending_unsubscription(self):
|
||||
unsubscription = await self._pending_unsubscriptions.get()
|
||||
@asyncio.coroutine
|
||||
def get_next_pending_unsubscription(self):
|
||||
unsubscription = yield from self._pending_unsubscriptions.get()
|
||||
return unsubscription
|
||||
|
||||
async def mqtt_acknowledge_subscription(self, packet_id, return_codes):
|
||||
@asyncio.coroutine
|
||||
def mqtt_acknowledge_subscription(self, packet_id, return_codes):
|
||||
suback = SubackPacket.build(packet_id, return_codes)
|
||||
await self._send_packet(suback)
|
||||
yield from self._send_packet(suback)
|
||||
|
||||
async def mqtt_acknowledge_unsubscription(self, packet_id):
|
||||
@asyncio.coroutine
|
||||
def mqtt_acknowledge_unsubscription(self, packet_id):
|
||||
unsuback = UnsubackPacket.build(packet_id)
|
||||
await self._send_packet(unsuback)
|
||||
yield from self._send_packet(unsuback)
|
||||
|
||||
async def mqtt_connack_authorize(self, authorize: bool):
|
||||
@asyncio.coroutine
|
||||
def mqtt_connack_authorize(self, authorize: bool):
|
||||
if authorize:
|
||||
connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED)
|
||||
else:
|
||||
connack = ConnackPacket.build(self.session.parent, NOT_AUTHORIZED)
|
||||
await self._send_packet(connack)
|
||||
yield from self._send_packet(connack)
|
||||
|
||||
@classmethod
|
||||
async def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
|
||||
@asyncio.coroutine
|
||||
def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
|
||||
"""
|
||||
|
||||
:param reader:
|
||||
|
@ -108,10 +123,9 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
:param loop:
|
||||
:return:
|
||||
"""
|
||||
log = logging.getLogger(__name__)
|
||||
remote_address, remote_port = writer.get_peer_info()
|
||||
connect = await ConnectPacket.from_stream(reader)
|
||||
await plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
|
||||
connect = yield from ConnectPacket.from_stream(reader)
|
||||
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
|
||||
if connect.payload.client_id is None:
|
||||
raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present' )
|
||||
|
||||
|
@ -144,9 +158,9 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
format_client_message(address=remote_address, port=remote_port)
|
||||
connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
|
||||
if connack is not None:
|
||||
await plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
|
||||
await connack.to_stream(writer)
|
||||
await writer.close()
|
||||
yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
|
||||
yield from connack.to_stream(writer)
|
||||
yield from writer.close()
|
||||
raise MQTTException(error_msg)
|
||||
|
||||
incoming_session = Session()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
import asyncio
|
||||
from asyncio import futures
|
||||
from hbmqtt.mqtt.protocol.handler import ProtocolHandler, EVENT_MQTT_PACKET_RECEIVED
|
||||
from hbmqtt.mqtt.packet import *
|
||||
|
@ -27,13 +28,15 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
self._disconnect_waiter = None
|
||||
self._pingresp_waiter = None
|
||||
|
||||
async def start(self):
|
||||
await super().start()
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
yield from super().start()
|
||||
if self._disconnect_waiter is None:
|
||||
self._disconnect_waiter = futures.Future(loop=self._loop)
|
||||
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
yield from super().stop()
|
||||
if self._ping_task:
|
||||
try:
|
||||
self.logger.debug("Cancel ping task")
|
||||
|
@ -77,11 +80,12 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
packet = ConnectPacket(vh=vh, payload=payload)
|
||||
return packet
|
||||
|
||||
async def mqtt_connect(self):
|
||||
@asyncio.coroutine
|
||||
def mqtt_connect(self):
|
||||
connect_packet = self._build_connect_packet()
|
||||
await self._send_packet(connect_packet)
|
||||
connack = await ConnackPacket.from_stream(self.reader)
|
||||
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
|
||||
yield from self._send_packet(connect_packet)
|
||||
connack = yield from ConnackPacket.from_stream(self.reader)
|
||||
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
|
||||
return connack.return_code
|
||||
|
||||
def handle_write_timeout(self):
|
||||
|
@ -95,7 +99,8 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
def handle_read_timeout(self):
|
||||
pass
|
||||
|
||||
async def mqtt_subscribe(self, topics, packet_id):
|
||||
@asyncio.coroutine
|
||||
def mqtt_subscribe(self, topics, packet_id):
|
||||
"""
|
||||
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
|
||||
:return:
|
||||
|
@ -103,17 +108,18 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
|
||||
# Build and send SUBSCRIBE message
|
||||
subscribe = SubscribePacket.build(topics, packet_id)
|
||||
await self._send_packet(subscribe)
|
||||
yield from self._send_packet(subscribe)
|
||||
|
||||
# Wait for SUBACK is received
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
|
||||
return_codes = await waiter
|
||||
return_codes = yield from waiter
|
||||
|
||||
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
|
||||
return return_codes
|
||||
|
||||
async def handle_suback(self, suback: SubackPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
packet_id = suback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._subscriptions_waiter.get(packet_id)
|
||||
|
@ -121,20 +127,22 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
except KeyError as ke:
|
||||
self.logger.warning("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
async def mqtt_unsubscribe(self, topics, packet_id):
|
||||
@asyncio.coroutine
|
||||
def mqtt_unsubscribe(self, topics, packet_id):
|
||||
"""
|
||||
|
||||
:param topics: array of topics ['/a/b', ...]
|
||||
:return:
|
||||
"""
|
||||
unsubscribe = UnsubscribePacket.build(topics, packet_id)
|
||||
await self._send_packet(unsubscribe)
|
||||
yield from self._send_packet(unsubscribe)
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
|
||||
await waiter
|
||||
yield from waiter
|
||||
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
|
||||
|
||||
async def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
packet_id = unsuback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._unsubscriptions_waiter.get(packet_id)
|
||||
|
@ -142,25 +150,30 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
except KeyError as ke:
|
||||
self.logger.warning("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
async def mqtt_disconnect(self):
|
||||
@asyncio.coroutine
|
||||
def mqtt_disconnect(self):
|
||||
disconnect_packet = DisconnectPacket()
|
||||
await self._send_packet(disconnect_packet)
|
||||
yield from self._send_packet(disconnect_packet)
|
||||
|
||||
async def mqtt_ping(self):
|
||||
@asyncio.coroutine
|
||||
def mqtt_ping(self):
|
||||
ping_packet = PingReqPacket()
|
||||
await self._send_packet(ping_packet)
|
||||
yield from self._send_packet(ping_packet)
|
||||
self._pingresp_waiter = futures.Future(loop=self._loop)
|
||||
resp = await self._pingresp_queue.get()
|
||||
resp = yield from self._pingresp_queue.get()
|
||||
self._pingresp_waiter = None
|
||||
return resp
|
||||
|
||||
async def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
await self._pingresp_queue.put(pingresp)
|
||||
@asyncio.coroutine
|
||||
def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
yield from self._pingresp_queue.put(pingresp)
|
||||
|
||||
async def handle_connection_closed(self):
|
||||
@asyncio.coroutine
|
||||
def handle_connection_closed(self):
|
||||
self.logger.debug("Broker closed connection")
|
||||
if not self._disconnect_waiter.done():
|
||||
self._disconnect_waiter.set_result(None)
|
||||
|
||||
async def wait_disconnect(self):
|
||||
await self._disconnect_waiter
|
||||
@asyncio.coroutine
|
||||
def wait_disconnect(self):
|
||||
yield from self._disconnect_waiter
|
||||
|
|
|
@ -97,20 +97,22 @@ class ProtocolHandler:
|
|||
else:
|
||||
return False
|
||||
|
||||
async def start(self):
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
if not self._is_attached():
|
||||
raise ProtocolHandlerException("Handler is not attached to a stream")
|
||||
self._reader_ready = asyncio.Event(loop=self._loop)
|
||||
self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
|
||||
await asyncio.wait([self._reader_ready.wait()], loop=self._loop)
|
||||
yield from asyncio.wait([self._reader_ready.wait()], loop=self._loop)
|
||||
if self.keepalive_timeout:
|
||||
self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout)
|
||||
|
||||
self.logger.debug("Handler tasks started")
|
||||
await self._retry_deliveries()
|
||||
yield from self._retry_deliveries()
|
||||
self.logger.debug("Handler ready")
|
||||
|
||||
async def stop(self):
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
# Stop messages flow waiter
|
||||
self._stop_waiters()
|
||||
if self._keepalive_task:
|
||||
|
@ -118,11 +120,11 @@ class ProtocolHandler:
|
|||
self.logger.debug("waiting for tasks to be stopped")
|
||||
if not self._reader_task.done():
|
||||
self._reader_task.cancel()
|
||||
await asyncio.wait(
|
||||
yield from asyncio.wait(
|
||||
[self._reader_stopped.wait()], loop=self._loop)
|
||||
self.logger.debug("closing writer")
|
||||
try:
|
||||
await self.writer.close()
|
||||
yield from self.writer.close()
|
||||
except Exception as e:
|
||||
self.logger.debug("Handler writer close failed: %s" % e)
|
||||
|
||||
|
@ -138,7 +140,8 @@ class ProtocolHandler:
|
|||
self._pubrel_waiters.values()):
|
||||
waiter.cancel()
|
||||
|
||||
async def _retry_deliveries(self):
|
||||
@asyncio.coroutine
|
||||
def _retry_deliveries(self):
|
||||
"""
|
||||
Handle [MQTT-4.4.0-1] by resending PUBLISH and PUBREL messages for pending out messages
|
||||
:return:
|
||||
|
@ -148,12 +151,13 @@ class ProtocolHandler:
|
|||
for message in itertools.chain(self.session.inflight_in.values(), self.session.inflight_out.values()):
|
||||
tasks.append(asyncio.wait_for(self._handle_message_flow(message), 10, loop=self._loop))
|
||||
if tasks:
|
||||
done, pending = await asyncio.wait(tasks)
|
||||
done, pending = yield from asyncio.wait(tasks)
|
||||
self.logger.debug("%d messages redelivered" % len(done))
|
||||
self.logger.debug("%d messages not redelivered due to timeout" % len(pending))
|
||||
self.logger.debug("End messages delivery retries")
|
||||
|
||||
async def mqtt_publish(self, topic, data, qos, retain, ack_timeout=None):
|
||||
@asyncio.coroutine
|
||||
def mqtt_publish(self, topic, data, qos, retain, ack_timeout=None):
|
||||
"""
|
||||
Sends a MQTT publish message and manages messages flows.
|
||||
This methods doesn't return until the message has been acknowledged by receiver or timeout occur
|
||||
|
@ -175,13 +179,14 @@ class ProtocolHandler:
|
|||
message = OutgoingApplicationMessage(packet_id, topic, qos, data, retain)
|
||||
# Handle message flow
|
||||
if ack_timeout is not None and ack_timeout > 0:
|
||||
await asyncio.wait_for(self._handle_message_flow(message), ack_timeout, loop=self._loop)
|
||||
yield from asyncio.wait_for(self._handle_message_flow(message), ack_timeout, loop=self._loop)
|
||||
else:
|
||||
await self._handle_message_flow(message)
|
||||
yield from self._handle_message_flow(message)
|
||||
|
||||
return message
|
||||
|
||||
async def _handle_message_flow(self, app_message):
|
||||
@asyncio.coroutine
|
||||
def _handle_message_flow(self, app_message):
|
||||
"""
|
||||
Handle protocol flow for incoming and outgoing messages, depending on service level and according to MQTT
|
||||
spec. paragraph 4.3-Quality of Service levels and protocol flows
|
||||
|
@ -189,15 +194,16 @@ class ProtocolHandler:
|
|||
:return: nothing.
|
||||
"""
|
||||
if app_message.qos == QOS_0:
|
||||
await self._handle_qos0_message_flow(app_message)
|
||||
yield from self._handle_qos0_message_flow(app_message)
|
||||
elif app_message.qos == QOS_1:
|
||||
await self._handle_qos1_message_flow(app_message)
|
||||
yield from self._handle_qos1_message_flow(app_message)
|
||||
elif app_message.qos == QOS_2:
|
||||
await self._handle_qos2_message_flow(app_message)
|
||||
yield from self._handle_qos2_message_flow(app_message)
|
||||
else:
|
||||
raise HBMQTTException("Unexcepted QOS value '%d" % str(app_message.qos))
|
||||
|
||||
async def _handle_qos0_message_flow(self, app_message):
|
||||
@asyncio.coroutine
|
||||
def _handle_qos0_message_flow(self, app_message):
|
||||
"""
|
||||
Handle QOS_0 application message acknowledgment
|
||||
For incoming messages, this method stores the message
|
||||
|
@ -209,7 +215,7 @@ class ProtocolHandler:
|
|||
if app_message.direction == OUTGOING:
|
||||
packet = app_message.build_publish_packet()
|
||||
# Send PUBLISH packet
|
||||
await self._send_packet(packet)
|
||||
yield from self._send_packet(packet)
|
||||
app_message.publish_packet = packet
|
||||
elif app_message.direction == INCOMING:
|
||||
if app_message.publish_packet.dup_flag:
|
||||
|
@ -221,7 +227,8 @@ class ProtocolHandler:
|
|||
except:
|
||||
self.logger.warning("delivered messages queue full. QOS_0 message discarded")
|
||||
|
||||
async def _handle_qos1_message_flow(self, app_message):
|
||||
@asyncio.coroutine
|
||||
def _handle_qos1_message_flow(self, app_message):
|
||||
"""
|
||||
Handle QOS_1 application message acknowledgment
|
||||
For incoming messages, this method stores the message and reply with PUBACK
|
||||
|
@ -242,13 +249,13 @@ class ProtocolHandler:
|
|||
else:
|
||||
publish_packet = app_message.build_publish_packet()
|
||||
# Send PUBLISH packet
|
||||
await self._send_packet(publish_packet)
|
||||
yield from self._send_packet(publish_packet)
|
||||
app_message.publish_packet = publish_packet
|
||||
|
||||
# Wait for puback
|
||||
waiter = asyncio.Future(loop=self._loop)
|
||||
self._puback_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
yield from waiter
|
||||
del self._puback_waiters[app_message.packet_id]
|
||||
app_message.puback_packet = waiter.result()
|
||||
|
||||
|
@ -257,13 +264,14 @@ class ProtocolHandler:
|
|||
elif app_message.direction == INCOMING:
|
||||
# Initiate delivery
|
||||
self.logger.debug("Add message to delivery")
|
||||
await self.session.delivered_message_queue.put(app_message)
|
||||
yield from self.session.delivered_message_queue.put(app_message)
|
||||
# Send PUBACK
|
||||
puback = PubackPacket.build(app_message.packet_id)
|
||||
await self._send_packet(puback)
|
||||
yield from self._send_packet(puback)
|
||||
app_message.puback_packet = puback
|
||||
|
||||
async def _handle_qos2_message_flow(self, app_message):
|
||||
@asyncio.coroutine
|
||||
def _handle_qos2_message_flow(self, app_message):
|
||||
"""
|
||||
Handle QOS_2 application message acknowledgment
|
||||
For incoming messages, this method stores the message, sends PUBREC, waits for PUBREL, initiate delivery
|
||||
|
@ -288,7 +296,7 @@ class ProtocolHandler:
|
|||
self.session.inflight_out[app_message.packet_id] = app_message
|
||||
publish_packet = app_message.build_publish_packet()
|
||||
# Send PUBLISH packet
|
||||
await self._send_packet(publish_packet)
|
||||
yield from self._send_packet(publish_packet)
|
||||
app_message.publish_packet = publish_packet
|
||||
# Wait PUBREC
|
||||
if app_message.packet_id in self._pubrec_waiters:
|
||||
|
@ -299,17 +307,17 @@ class ProtocolHandler:
|
|||
raise HBMQTTException(message)
|
||||
waiter = asyncio.Future(loop=self._loop)
|
||||
self._pubrec_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
yield from waiter
|
||||
del self._pubrec_waiters[app_message.packet_id]
|
||||
app_message.pubrec_packet = waiter.result()
|
||||
if not app_message.pubcomp_packet:
|
||||
# Send pubrel
|
||||
app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id)
|
||||
await self._send_packet(app_message.pubrel_packet)
|
||||
yield from self._send_packet(app_message.pubrel_packet)
|
||||
# Wait for PUBCOMP
|
||||
waiter = asyncio.Future(loop=self._loop)
|
||||
self._pubcomp_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
yield from waiter
|
||||
del self._pubcomp_waiters[app_message.packet_id]
|
||||
app_message.pubcomp_packet = waiter.result()
|
||||
# Discard inflight message
|
||||
|
@ -318,7 +326,7 @@ class ProtocolHandler:
|
|||
self.session.inflight_in[app_message.packet_id] = app_message
|
||||
# Send pubrec
|
||||
pubrec_packet = PubrecPacket.build(app_message.packet_id)
|
||||
await self._send_packet(pubrec_packet)
|
||||
yield from self._send_packet(pubrec_packet)
|
||||
app_message.pubrec_packet = pubrec_packet
|
||||
# Wait PUBREL
|
||||
if app_message.packet_id in self._pubrel_waiters:
|
||||
|
@ -329,18 +337,19 @@ class ProtocolHandler:
|
|||
raise HBMQTTException(message)
|
||||
waiter = asyncio.Future(loop=self._loop)
|
||||
self._pubrel_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
yield from waiter
|
||||
del self._pubrel_waiters[app_message.packet_id]
|
||||
app_message.pubrel_packet = waiter.result()
|
||||
# Initiate delivery and discard message
|
||||
await self.session.delivered_message_queue.put(app_message)
|
||||
yield from self.session.delivered_message_queue.put(app_message)
|
||||
del self.session.inflight_in[app_message.packet_id]
|
||||
# Send pubcomp
|
||||
pubcomp_packet = PubcompPacket.build(app_message.packet_id)
|
||||
await self._send_packet(pubcomp_packet)
|
||||
yield from self._send_packet(pubcomp_packet)
|
||||
app_message.pubcomp_packet = pubcomp_packet
|
||||
|
||||
async def _reader_loop(self):
|
||||
@asyncio.coroutine
|
||||
def _reader_loop(self):
|
||||
self.logger.debug("%s Starting reader coro" % self.session.client_id)
|
||||
running_tasks = collections.deque()
|
||||
keepalive_timeout = self.session.keep_alive
|
||||
|
@ -354,17 +363,17 @@ class ProtocolHandler:
|
|||
if len(running_tasks) > 1:
|
||||
self.logger.debug("handler running tasks: %d" % len(running_tasks))
|
||||
|
||||
fixed_header = await asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader),
|
||||
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader),
|
||||
keepalive_timeout, loop=self._loop)
|
||||
if fixed_header:
|
||||
if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
|
||||
self.logger.warning("%s Received reserved packet, which is forbidden: closing connection" %
|
||||
(self.session.client_id))
|
||||
await self.handle_connection_closed()
|
||||
yield from self.handle_connection_closed()
|
||||
else:
|
||||
cls = packet_class(fixed_header)
|
||||
packet = await cls.from_stream(self.reader, fixed_header=fixed_header)
|
||||
await self.plugins_manager.fire_event(
|
||||
packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header)
|
||||
yield from self.plugins_manager.fire_event(
|
||||
EVENT_MQTT_PACKET_RECEIVED, packet=packet, session=self.session)
|
||||
task = None
|
||||
if packet.fixed_header.packet_type == CONNACK:
|
||||
|
@ -418,30 +427,32 @@ class ProtocolHandler:
|
|||
break
|
||||
while running_tasks:
|
||||
running_tasks.popleft().cancel()
|
||||
await self.handle_connection_closed()
|
||||
yield from self.handle_connection_closed()
|
||||
self._reader_stopped.set()
|
||||
self.logger.debug("%s Reader coro stopped" % self.session.client_id)
|
||||
await self.stop()
|
||||
yield from self.stop()
|
||||
|
||||
async def _send_packet(self, packet):
|
||||
@asyncio.coroutine
|
||||
def _send_packet(self, packet):
|
||||
try:
|
||||
await packet.to_stream(self.writer)
|
||||
yield from packet.to_stream(self.writer)
|
||||
if self._keepalive_task:
|
||||
self._keepalive_task.cancel()
|
||||
self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout)
|
||||
|
||||
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
|
||||
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
|
||||
except ConnectionResetError as cre:
|
||||
await self.handle_connection_closed()
|
||||
yield from self.handle_connection_closed()
|
||||
raise
|
||||
except BaseException as e:
|
||||
self.logger.warning("Unhandled exception: %s" % e)
|
||||
raise
|
||||
|
||||
async def mqtt_deliver_next_message(self):
|
||||
@asyncio.coroutine
|
||||
def mqtt_deliver_next_message(self):
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug("%d message(s) available for delivery" % self.session.delivered_message_queue.qsize())
|
||||
message = await self.session.delivered_message_queue.get()
|
||||
message = yield from self.session.delivered_message_queue.get()
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug("Delivering message %s" % message)
|
||||
return message
|
||||
|
@ -452,37 +463,48 @@ class ProtocolHandler:
|
|||
def handle_read_timeout(self):
|
||||
self.logger.debug('%s read timeout unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_connack(self, connack: ConnackPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_connack(self, connack: ConnackPacket):
|
||||
self.logger.debug('%s CONNACK unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_connect(self, connect: ConnectPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_connect(self, connect: ConnectPacket):
|
||||
self.logger.debug('%s CONNECT unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_subscribe(self, subscribe: SubscribePacket):
|
||||
@asyncio.coroutine
|
||||
def handle_subscribe(self, subscribe: SubscribePacket):
|
||||
self.logger.debug('%s SUBSCRIBE unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_unsubscribe(self, subscribe: UnsubscribePacket):
|
||||
@asyncio.coroutine
|
||||
def handle_unsubscribe(self, subscribe: UnsubscribePacket):
|
||||
self.logger.debug('%s UNSUBSCRIBE unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_suback(self, suback: SubackPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
self.logger.debug('%s SUBACK unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
self.logger.debug('%s UNSUBACK unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
self.logger.debug('%s PINGRESP unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_pingreq(self, pingreq: PingReqPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_pingreq(self, pingreq: PingReqPacket):
|
||||
self.logger.debug('%s PINGREQ unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_disconnect(self, disconnect: DisconnectPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_disconnect(self, disconnect: DisconnectPacket):
|
||||
self.logger.debug('%s DISCONNECT unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_connection_closed(self):
|
||||
@asyncio.coroutine
|
||||
def handle_connection_closed(self):
|
||||
self.logger.debug('%s Connection closed unhandled' % self.session.client_id)
|
||||
|
||||
async def handle_puback(self, puback: PubackPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_puback(self, puback: PubackPacket):
|
||||
packet_id = puback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._puback_waiters[packet_id]
|
||||
|
@ -492,7 +514,8 @@ class ProtocolHandler:
|
|||
except InvalidStateError:
|
||||
self.logger.warning("PUBACK waiter with Id '%d' already done" % packet_id)
|
||||
|
||||
async def handle_pubrec(self, pubrec: PubrecPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_pubrec(self, pubrec: PubrecPacket):
|
||||
packet_id = pubrec.packet_id
|
||||
try:
|
||||
waiter = self._pubrec_waiters[packet_id]
|
||||
|
@ -502,7 +525,8 @@ class ProtocolHandler:
|
|||
except InvalidStateError:
|
||||
self.logger.warning("PUBREC waiter with Id '%d' already done" % packet_id)
|
||||
|
||||
async def handle_pubcomp(self, pubcomp: PubcompPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_pubcomp(self, pubcomp: PubcompPacket):
|
||||
packet_id = pubcomp.packet_id
|
||||
try:
|
||||
waiter = self._pubcomp_waiters[packet_id]
|
||||
|
@ -512,7 +536,8 @@ class ProtocolHandler:
|
|||
except InvalidStateError:
|
||||
self.logger.warning("PUBCOMP waiter with Id '%d' already done" % packet_id)
|
||||
|
||||
async def handle_pubrel(self, pubrel: PubrelPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_pubrel(self, pubrel: PubrelPacket):
|
||||
packet_id = pubrel.packet_id
|
||||
try:
|
||||
waiter = self._pubrel_waiters[packet_id]
|
||||
|
@ -522,11 +547,12 @@ class ProtocolHandler:
|
|||
except InvalidStateError:
|
||||
self.logger.warning("PUBREL waiter with Id '%d' already done" % packet_id)
|
||||
|
||||
async def handle_publish(self, publish_packet: PublishPacket):
|
||||
@asyncio.coroutine
|
||||
def handle_publish(self, publish_packet: PublishPacket):
|
||||
packet_id = publish_packet.variable_header.packet_id
|
||||
qos = publish_packet.qos
|
||||
|
||||
incoming_message = IncomingApplicationMessage(packet_id, publish_packet.topic_name, qos, publish_packet.data, publish_packet.retain_flag)
|
||||
incoming_message.publish_packet = publish_packet
|
||||
await self._handle_message_flow(incoming_message)
|
||||
yield from self._handle_message_flow(incoming_message)
|
||||
self.logger.debug("Message queue size: %d" % self.session.delivered_message_queue.qsize())
|
||||
|
|
|
@ -25,11 +25,12 @@ class PublishVariableHeader(MQTTVariableHeader):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
|
||||
topic_name = await decode_string(reader)
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
|
||||
topic_name = yield from decode_string(reader)
|
||||
has_qos = (fixed_header.flags >> 1) & 0x03
|
||||
if has_qos:
|
||||
packet_id = await decode_packet_id(reader)
|
||||
packet_id = yield from decode_packet_id(reader)
|
||||
else:
|
||||
packet_id = None
|
||||
return cls(topic_name, packet_id)
|
||||
|
@ -44,9 +45,10 @@ class PublishPayload(MQTTPayload):
|
|||
return self.data
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
variable_header: MQTTVariableHeader):
|
||||
data = await reader.read(fixed_header.remaining_length-variable_header.bytes_length)
|
||||
data = yield from reader.read(fixed_header.remaining_length-variable_header.bytes_length)
|
||||
return cls(data)
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -27,13 +27,14 @@ class SubackPayload(MQTTPayload):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader,
|
||||
variable_header: MQTTVariableHeader):
|
||||
return_codes = []
|
||||
bytes_to_read = fixed_header.remaining_length - variable_header.bytes_length
|
||||
for i in range(0, bytes_to_read):
|
||||
try:
|
||||
return_code_byte = await read_or_raise(reader, 1)
|
||||
return_code_byte = yield from read_or_raise(reader, 1)
|
||||
return_code = bytes_to_int(return_code_byte)
|
||||
return_codes.append(return_code)
|
||||
except NoDataException:
|
||||
|
|
|
@ -19,15 +19,16 @@ class SubscribePayload(MQTTPayload):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
variable_header: MQTTVariableHeader):
|
||||
topics = []
|
||||
payload_length = fixed_header.remaining_length - variable_header.bytes_length
|
||||
read_bytes = 0
|
||||
while read_bytes < payload_length:
|
||||
try:
|
||||
topic = await decode_string(reader)
|
||||
qos_byte = await read_or_raise(reader, 1)
|
||||
topic = yield from decode_string(reader)
|
||||
qos_byte = yield from read_or_raise(reader, 1)
|
||||
qos = bytes_to_int(qos_byte)
|
||||
topics.append((topic, qos))
|
||||
read_bytes += 2 + len(topic.encode('utf-8')) + 1
|
||||
|
|
|
@ -18,14 +18,15 @@ class UnubscribePayload(MQTTPayload):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
async def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
variable_header: MQTTVariableHeader):
|
||||
topics = []
|
||||
payload_length = fixed_header.remaining_length - variable_header.bytes_length
|
||||
read_bytes = 0
|
||||
while read_bytes < payload_length:
|
||||
try:
|
||||
topic = await decode_string(reader)
|
||||
topic = yield from decode_string(reader)
|
||||
topics.append(topic)
|
||||
read_bytes += 2 + len(topic.encode('utf-8'))
|
||||
except NoDataException:
|
||||
|
|
|
@ -26,7 +26,8 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
|
|||
def __init__(self, context):
|
||||
super().__init__(context)
|
||||
|
||||
async def authenticate(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def authenticate(self, *args, **kwargs):
|
||||
authenticated = super().authenticate(*args, **kwargs)
|
||||
if authenticated:
|
||||
allow_anonymous = self.auth_config.get('allow-anonymous', True) # allow anonymous by default
|
||||
|
@ -73,7 +74,8 @@ class FileAuthPlugin(BaseAuthPlugin):
|
|||
else:
|
||||
self.context.logger.debug("Configuration parameter 'password_file' not found")
|
||||
|
||||
async def authenticate(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def authenticate(self, *args, **kwargs):
|
||||
authenticated = super().authenticate(*args, **kwargs)
|
||||
if authenticated:
|
||||
session = kwargs.get('session', None)
|
||||
|
|
|
@ -12,7 +12,8 @@ class EventLoggerPlugin:
|
|||
def __init__(self, context):
|
||||
self.context = context
|
||||
|
||||
async def log_event(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def log_event(self, *args, **kwargs):
|
||||
self.context.logger.info("### '%s' EVENT FIRED ###" % kwargs['event_name'].replace('old', ''))
|
||||
|
||||
def __getattr__(self, name):
|
||||
|
@ -24,7 +25,8 @@ class PacketLoggerPlugin:
|
|||
def __init__(self, context):
|
||||
self.context = context
|
||||
|
||||
async def on_mqtt_packet_received(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_mqtt_packet_received(self, *args, **kwargs):
|
||||
packet = kwargs.get('packet')
|
||||
session = kwargs.get('session', None)
|
||||
if self.context.logger.isEnabledFor(logging.DEBUG):
|
||||
|
@ -33,7 +35,8 @@ class PacketLoggerPlugin:
|
|||
else:
|
||||
self.context.logger.debug("<-in-- %s" % repr(packet))
|
||||
|
||||
async def on_mqtt_packet_sent(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_mqtt_packet_sent(self, *args, **kwargs):
|
||||
packet = kwargs.get('packet')
|
||||
session = kwargs.get('session', None)
|
||||
if self.context.logger.isEnabledFor(logging.DEBUG):
|
||||
|
|
|
@ -87,13 +87,14 @@ class PluginManager:
|
|||
return p
|
||||
return None
|
||||
|
||||
async def close(self):
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
"""
|
||||
Free PluginManager resources and cancel pending event methods
|
||||
This method call a close() coroutine for each plugin, allowing plugins to close and free resources
|
||||
:return:
|
||||
"""
|
||||
await self.map_plugin_coro("close")
|
||||
yield from self.map_plugin_coro("close")
|
||||
for task in self._fired_events:
|
||||
task.cancel()
|
||||
|
||||
|
@ -108,10 +109,11 @@ class PluginManager:
|
|||
def _schedule_coro(self, coro):
|
||||
return asyncio.ensure_future(coro, loop=self._loop)
|
||||
|
||||
async def fire_event(self, event_name, wait=False, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def fire_event(self, event_name, wait=False, *args, **kwargs):
|
||||
"""
|
||||
Fire an event to plugins.
|
||||
PluginManager schedule async calls for each plugin on method called "on_" + event_name
|
||||
PluginManager schedule @asyncio.coroutinecalls for each plugin on method called "on_" + event_name
|
||||
For example, on_connect will be called on event 'connect'
|
||||
Method calls are schedule in the asyn loop. wait parameter must be set to true to wait until all
|
||||
mehtods are completed.
|
||||
|
@ -133,11 +135,12 @@ class PluginManager:
|
|||
(event_method_name, plugin.name))
|
||||
if wait:
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, loop=self._loop)
|
||||
yield from asyncio.wait(tasks, loop=self._loop)
|
||||
else:
|
||||
self._fired_events.extend(tasks)
|
||||
|
||||
async def map(self, coro, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def map(self, coro, *args, **kwargs):
|
||||
"""
|
||||
Schedule a given coroutine call for each plugin.
|
||||
The coro called get the Plugin instance as first argument of its method call
|
||||
|
@ -164,7 +167,7 @@ class PluginManager:
|
|||
self.logger.error("Method '%r' on plugin '%s' is not a coroutine" %
|
||||
(coro, plugin.name))
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks, loop=self._loop)
|
||||
ret_list = yield from asyncio.gather(*tasks, loop=self._loop)
|
||||
# Create result map plugin=>ret
|
||||
ret_dict = {k: v for k, v in zip(plugins_list, ret_list)}
|
||||
else:
|
||||
|
@ -172,15 +175,17 @@ class PluginManager:
|
|||
return ret_dict
|
||||
|
||||
@staticmethod
|
||||
async def _call_coro(plugin, coro_name, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def _call_coro(plugin, coro_name, *args, **kwargs):
|
||||
try:
|
||||
coro = getattr(plugin.object, coro_name, None)(*args, **kwargs)
|
||||
return await coro
|
||||
return (yield from coro)
|
||||
except TypeError:
|
||||
# Plugin doesn't implement coro_name
|
||||
return None
|
||||
|
||||
async def map_plugin_coro(self, coro_name, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def map_plugin_coro(self, coro_name, *args, **kwargs):
|
||||
"""
|
||||
Call a plugin declared by plugin by its name
|
||||
:param coro_name:
|
||||
|
@ -188,4 +193,4 @@ class PluginManager:
|
|||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
return await self.map(self._call_coro, coro_name, *args, **kwargs)
|
||||
return (yield from self.map(self._call_coro, coro_name, *args, **kwargs))
|
||||
|
|
|
@ -32,7 +32,8 @@ class SQLitePlugin:
|
|||
if self.cursor:
|
||||
self.cursor.execute("CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)")
|
||||
|
||||
async def save_session(self, session):
|
||||
@asyncio.coroutine
|
||||
def save_session(self, session):
|
||||
if self.cursor:
|
||||
dump = pickle.dumps(session)
|
||||
try:
|
||||
|
@ -42,7 +43,8 @@ class SQLitePlugin:
|
|||
except Exception as e:
|
||||
self.context.logger.error("Failed saving session '%s': %s" % (session, e))
|
||||
|
||||
async def find_session(self, client_id):
|
||||
@asyncio.coroutine
|
||||
def find_session(self, client_id):
|
||||
if self.cursor:
|
||||
row = self.cursor.execute("SELECT data FROM session where client_id=?", (client_id,)).fetchone()
|
||||
if row:
|
||||
|
@ -50,12 +52,14 @@ class SQLitePlugin:
|
|||
else:
|
||||
return None
|
||||
|
||||
async def del_session(self, client_id):
|
||||
@asyncio.coroutine
|
||||
def del_session(self, client_id):
|
||||
if self.cursor:
|
||||
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
|
||||
self.conn.commit()
|
||||
|
||||
async def on_broker_post_shutdown(self):
|
||||
@asyncio.coroutine
|
||||
def on_broker_post_shutdown(self):
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.context.logger.info("Database file '%s' closed" % self.db_file)
|
||||
|
|
|
@ -42,17 +42,20 @@ class BrokerSysPlugin:
|
|||
STAT_PUBLISH_SENT):
|
||||
self._stats[stat] = 0
|
||||
|
||||
async def _broadcast_sys_topic(self, topic_basename, data):
|
||||
return await self.context.broadcast_message(DOLLAR_SYS_ROOT + topic_basename, data)
|
||||
@asyncio.coroutine
|
||||
def _broadcast_sys_topic(self, topic_basename, data):
|
||||
return (yield from self.context.broadcast_message(DOLLAR_SYS_ROOT + topic_basename, data))
|
||||
|
||||
def schedule_broadcast_sys_topic(self, topic_basename, data):
|
||||
return asyncio.ensure_future(self._broadcast_sys_topic(DOLLAR_SYS_ROOT + topic_basename, data),
|
||||
loop=self.context.loop)
|
||||
|
||||
async def on_broker_pre_start(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_broker_pre_start(self, *args, **kwargs):
|
||||
self._clear_stats()
|
||||
|
||||
async def on_broker_post_start(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_broker_post_start(self, *args, **kwargs):
|
||||
self._stats[STAT_START_TIME] = datetime.now()
|
||||
from hbmqtt.version import get_version
|
||||
version = 'HBMQTT version ' + get_version()
|
||||
|
@ -70,7 +73,8 @@ class BrokerSysPlugin:
|
|||
pass
|
||||
# 'sys_internal' config parameter not found
|
||||
|
||||
async def on_broker_pre_stop(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_broker_pre_stop(self, *args, **kwargs):
|
||||
# Stop $SYS topics broadcasting
|
||||
if self.sys_handle:
|
||||
self.sys_handle.cancel()
|
||||
|
@ -127,7 +131,8 @@ class BrokerSysPlugin:
|
|||
self.context.logger.debug("Broadcasting $SYS topics")
|
||||
self.sys_handle = self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
|
||||
|
||||
async def on_mqtt_packet_received(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_mqtt_packet_received(self, *args, **kwargs):
|
||||
packet = kwargs.get('packet')
|
||||
if packet:
|
||||
packet_size = packet.bytes_length
|
||||
|
@ -136,7 +141,8 @@ class BrokerSysPlugin:
|
|||
if packet.fixed_header.packet_type == PUBLISH:
|
||||
self._stats[STAT_PUBLISH_RECEIVED] += 1
|
||||
|
||||
async def on_mqtt_packet_sent(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_mqtt_packet_sent(self, *args, **kwargs):
|
||||
packet = kwargs.get('packet')
|
||||
if packet:
|
||||
packet_size = packet.bytes_length
|
||||
|
@ -145,10 +151,12 @@ class BrokerSysPlugin:
|
|||
if packet.fixed_header.packet_type == PUBLISH:
|
||||
self._stats[STAT_PUBLISH_SENT] += 1
|
||||
|
||||
async def on_broker_client_connected(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_broker_client_connected(self, *args, **kwargs):
|
||||
self._stats[STAT_CLIENTS_CONNECTED] += 1
|
||||
self._stats[STAT_CLIENTS_MAXIMUM] = max(self._stats[STAT_CLIENTS_MAXIMUM], self._stats[STAT_CLIENTS_CONNECTED])
|
||||
|
||||
async def on_broker_client_disconnected(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_broker_client_disconnected(self, *args, **kwargs):
|
||||
self._stats[STAT_CLIENTS_CONNECTED] -= 1
|
||||
self._stats[STAT_CLIENTS_DISCONNECTED] += 1
|
||||
|
|
|
@ -30,10 +30,11 @@ config = {
|
|||
|
||||
broker = Broker(config)
|
||||
|
||||
async def test_coro():
|
||||
await broker.start()
|
||||
#await asyncio.sleep(5)
|
||||
#await broker.shutdown()
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from broker.start()
|
||||
#yield from asyncio.sleep(5)
|
||||
#yield from broker.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -18,11 +18,12 @@ config = {
|
|||
}
|
||||
C = MQTTClient(config=config)
|
||||
|
||||
async def test_coro():
|
||||
await C.connect('mqtt://test.mosquitto.org:1883/')
|
||||
await asyncio.sleep(18)
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect('mqtt://test.mosquitto.org:1883/')
|
||||
yield from asyncio.sleep(18)
|
||||
|
||||
await C.disconnect()
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -29,28 +29,30 @@ def disconnected(future):
|
|||
asyncio.get_event_loop().stop()
|
||||
|
||||
|
||||
async def test_coro():
|
||||
await C.connect('mqtt://test:test@localhost:1883/')
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect('mqtt://test:test@localhost:1883/')
|
||||
tasks = [
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
|
||||
]
|
||||
await asyncio.wait(tasks)
|
||||
yield from asyncio.wait(tasks)
|
||||
logger.info("messages published")
|
||||
await C.disconnect()
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
async def test_coro2():
|
||||
@asyncio.coroutine
|
||||
def test_coro2():
|
||||
try:
|
||||
future = await C.connect('mqtt://test.mosquitto.org:1883/')
|
||||
future = yield from C.connect('mqtt://test.mosquitto.org:1883/')
|
||||
future.add_done_callback(disconnected)
|
||||
message = await C.publish('a/b', b'TEST MESSAGE WITH QOS_0', qos=0x00)
|
||||
message = await C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)
|
||||
message = await C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)
|
||||
message = yield from C.publish('a/b', b'TEST MESSAGE WITH QOS_0', qos=0x00)
|
||||
message = yield from C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)
|
||||
message = yield from C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)
|
||||
#print(message)
|
||||
logger.info("messages published")
|
||||
await C.disconnect()
|
||||
yield from C.disconnect()
|
||||
except ConnectException as ce:
|
||||
logger.error("Connection failed: %s" % ce)
|
||||
asyncio.get_event_loop().stop()
|
||||
|
|
|
@ -23,16 +23,17 @@ config = {
|
|||
C = MQTTClient(config=config)
|
||||
#C = MQTTClient()
|
||||
|
||||
async def test_coro():
|
||||
await C.connect('mqtts://test.mosquitto.org/', cafile='mosquitto.org.crt')
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect('mqtts://test.mosquitto.org/', cafile='mosquitto.org.crt')
|
||||
tasks = [
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
|
||||
]
|
||||
await asyncio.wait(tasks)
|
||||
yield from asyncio.wait(tasks)
|
||||
logger.info("messages published")
|
||||
await C.disconnect()
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -24,16 +24,17 @@ config = {
|
|||
C = MQTTClient(config=config)
|
||||
#C = MQTTClient()
|
||||
|
||||
async def test_coro():
|
||||
await C.connect('wss://test.mosquitto.org:8081/', cafile='mosquitto.org.crt')
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect('wss://test.mosquitto.org:8081/', cafile='mosquitto.org.crt')
|
||||
tasks = [
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||
asyncio.ensure_future(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
|
||||
]
|
||||
await asyncio.wait(tasks)
|
||||
yield from asyncio.wait(tasks)
|
||||
logger.info("messages published")
|
||||
await C.disconnect()
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -15,22 +15,23 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
C = MQTTClient()
|
||||
|
||||
async def uptime_coro():
|
||||
await C.connect('mqtt://localhost/')
|
||||
@asyncio.coroutine
|
||||
def uptime_coro():
|
||||
yield from C.connect('mqtt://localhost/')
|
||||
# Subscribe to '$SYS/broker/uptime' with QOS=1
|
||||
await C.subscribe([
|
||||
yield from C.subscribe([
|
||||
('$SYS/broker/uptime', QOS_1),
|
||||
('$SYS/broker/load/#', QOS_2),
|
||||
])
|
||||
logger.info("Subscribed")
|
||||
try:
|
||||
for i in range(1, 100):
|
||||
message = await C.deliver_message()
|
||||
message = yield from C.deliver_message()
|
||||
packet = message.publish_packet
|
||||
print("%d %s : %s" % (i, packet.variable_header.topic_name, str(packet.payload.data)))
|
||||
await C.unsubscribe(['$SYS/broker/uptime'])
|
||||
yield from C.unsubscribe(['$SYS/broker/uptime'])
|
||||
logger.info("UnSubscribed")
|
||||
await C.disconnect()
|
||||
yield from C.disconnect()
|
||||
except ClientException as ce:
|
||||
logger.error("Client exception: %s" % ce)
|
||||
|
||||
|
|
|
@ -47,8 +47,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def main(*args, **kwargs):
|
||||
if sys.version_info[:2] < (3, 5):
|
||||
logger.fatal("Error: Python 3.5 is required")
|
||||
if sys.version_info[:2] < (3, 4):
|
||||
logger.fatal("Error: Python 3.4+ is required")
|
||||
sys.exit(-1)
|
||||
|
||||
arguments = docopt(__doc__, version=get_version())
|
||||
|
|
|
@ -2,6 +2,6 @@ keep_alive: 10
|
|||
ping_delay: 1
|
||||
default_qos: 0
|
||||
default_retain: false
|
||||
auto_reconnect: true
|
||||
auto_reconnect: false
|
||||
reconnect_max_interval: 10
|
||||
reconnect_retries: 2
|
||||
|
|
|
@ -107,8 +107,8 @@ def do_pub(client, arguments):
|
|||
|
||||
|
||||
def main(*args, **kwargs):
|
||||
if sys.version_info[:2] < (3, 5):
|
||||
logger.fatal("Error: Python 3.5 is required")
|
||||
if sys.version_info[:2] < (3, 4):
|
||||
logger.fatal("Error: Python 3.4+ is required")
|
||||
sys.exit(-1)
|
||||
|
||||
arguments = docopt(__doc__, version=get_version())
|
||||
|
|
|
@ -88,8 +88,8 @@ def do_sub(client, arguments):
|
|||
|
||||
|
||||
def main(*args, **kwargs):
|
||||
if sys.version_info[:2] < (3, 5):
|
||||
logger.fatal("Error: Python 3.5 is required")
|
||||
if sys.version_info[:2] < (3, 4):
|
||||
logger.fatal("Error: Python 3.4+ is required")
|
||||
sys.exit(-1)
|
||||
|
||||
arguments = docopt(__doc__, version=get_version())
|
||||
|
|
|
@ -41,18 +41,20 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
self.check_empty_waiters(handler)
|
||||
|
||||
def test_start_stop(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
pass
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
s = Session()
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
handler = ProtocolHandler(self.plugin_manager)
|
||||
handler.attach(s, reader_adapted, writer_adapted)
|
||||
await self.start_handler(handler, s)
|
||||
await self.stop_handler(handler, s)
|
||||
yield from self.start_handler(handler, s)
|
||||
yield from self.stop_handler(handler, s)
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -67,31 +69,33 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_publish_qos0(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
try:
|
||||
packet = await PublishPacket.from_stream(reader)
|
||||
packet = yield from PublishPacket.from_stream(reader)
|
||||
self.assertEquals(packet.variable_header.topic_name, '/topic')
|
||||
self.assertEquals(packet.qos, QOS_0)
|
||||
self.assertIsNone(packet.packet_id)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
s = Session()
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
handler.attach(s, reader_adapted, writer_adapted)
|
||||
await self.start_handler(handler, s)
|
||||
message = await handler.mqtt_publish('/topic', b'test_data', QOS_0, False)
|
||||
yield from self.start_handler(handler, s)
|
||||
message = yield from handler.mqtt_publish('/topic', b'test_data', QOS_0, False)
|
||||
self.assertIsInstance(message, OutgoingApplicationMessage)
|
||||
self.assertIsNotNone(message.publish_packet)
|
||||
self.assertIsNone(message.puback_packet)
|
||||
self.assertIsNone(message.pubrec_packet)
|
||||
self.assertIsNone(message.pubrel_packet)
|
||||
self.assertIsNone(message.pubcomp_packet)
|
||||
await self.stop_handler(handler, s)
|
||||
yield from self.stop_handler(handler, s)
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -106,8 +110,9 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_publish_qos1(self):
|
||||
async def server_mock(reader, writer):
|
||||
packet = await PublishPacket.from_stream(reader)
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
packet = yield from PublishPacket.from_stream(reader)
|
||||
try:
|
||||
self.assertEquals(packet.variable_header.topic_name, '/topic')
|
||||
self.assertEquals(packet.qos, QOS_1)
|
||||
|
@ -115,25 +120,26 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
self.assertIn(packet.packet_id, self.session.inflight_out)
|
||||
self.assertIn(packet.packet_id, self.handler._puback_waiters)
|
||||
puback = PubackPacket.build(packet.packet_id)
|
||||
await puback.to_stream(writer)
|
||||
yield from puback.to_stream(writer)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.start_handler(self.handler, self.session)
|
||||
message = await self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False)
|
||||
yield from self.start_handler(self.handler, self.session)
|
||||
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_1, False)
|
||||
self.assertIsInstance(message, OutgoingApplicationMessage)
|
||||
self.assertIsNotNone(message.publish_packet)
|
||||
self.assertIsNotNone(message.puback_packet)
|
||||
self.assertIsNone(message.pubrec_packet)
|
||||
self.assertIsNone(message.pubrel_packet)
|
||||
self.assertIsNone(message.pubcomp_packet)
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
if not future.done():
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
@ -151,39 +157,41 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_publish_qos2(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
try:
|
||||
packet = await PublishPacket.from_stream(reader)
|
||||
packet = yield from PublishPacket.from_stream(reader)
|
||||
self.assertEquals(packet.topic_name, '/topic')
|
||||
self.assertEquals(packet.qos, QOS_2)
|
||||
self.assertIsNotNone(packet.packet_id)
|
||||
self.assertIn(packet.packet_id, self.session.inflight_out)
|
||||
self.assertIn(packet.packet_id, self.handler._pubrec_waiters)
|
||||
pubrec = PubrecPacket.build(packet.packet_id)
|
||||
await pubrec.to_stream(writer)
|
||||
yield from pubrec.to_stream(writer)
|
||||
|
||||
pubrel = await PubrelPacket.from_stream(reader)
|
||||
pubrel = yield from PubrelPacket.from_stream(reader)
|
||||
self.assertIn(packet.packet_id, self.handler._pubcomp_waiters)
|
||||
pubcomp = PubcompPacket.build(packet.packet_id)
|
||||
await pubcomp.to_stream(writer)
|
||||
yield from pubcomp.to_stream(writer)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.start_handler(self.handler, self.session)
|
||||
message = await self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False)
|
||||
yield from self.start_handler(self.handler, self.session)
|
||||
message = yield from self.handler.mqtt_publish('/topic', b'test_data', QOS_2, False)
|
||||
self.assertIsInstance(message, OutgoingApplicationMessage)
|
||||
self.assertIsNotNone(message.publish_packet)
|
||||
self.assertIsNone(message.puback_packet)
|
||||
self.assertIsNotNone(message.pubrec_packet)
|
||||
self.assertIsNotNone(message.pubrel_packet)
|
||||
self.assertIsNotNone(message.pubcomp_packet)
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
if not future.done():
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
@ -201,25 +209,27 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_receive_qos0(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
packet = PublishPacket.build('/topic', b'test_data', 1, False, QOS_0, False)
|
||||
await packet.to_stream(writer)
|
||||
yield from packet.to_stream(writer)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.start_handler(self.handler, self.session)
|
||||
message = await self.handler.mqtt_deliver_next_message()
|
||||
yield from self.start_handler(self.handler, self.session)
|
||||
message = yield from self.handler.mqtt_deliver_next_message()
|
||||
self.assertIsInstance(message, IncomingApplicationMessage)
|
||||
self.assertIsNotNone(message.publish_packet)
|
||||
self.assertIsNone(message.puback_packet)
|
||||
self.assertIsNone(message.pubrec_packet)
|
||||
self.assertIsNone(message.pubrel_packet)
|
||||
self.assertIsNone(message.pubcomp_packet)
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -236,32 +246,34 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_receive_qos1(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
try:
|
||||
packet = PublishPacket.build('/topic', b'test_data', 1, False, QOS_1, False)
|
||||
await packet.to_stream(writer)
|
||||
puback = await PubackPacket.from_stream(reader)
|
||||
yield from packet.to_stream(writer)
|
||||
puback = yield from PubackPacket.from_stream(reader)
|
||||
self.assertIsNotNone(puback)
|
||||
self.assertEqual(packet.packet_id, puback.packet_id)
|
||||
except Exception as ae:
|
||||
print(ae)
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.start_handler(self.handler, self.session)
|
||||
message = await self.handler.mqtt_deliver_next_message()
|
||||
yield from self.start_handler(self.handler, self.session)
|
||||
message = yield from self.handler.mqtt_deliver_next_message()
|
||||
self.assertIsInstance(message, IncomingApplicationMessage)
|
||||
self.assertIsNotNone(message.publish_packet)
|
||||
self.assertIsNotNone(message.puback_packet)
|
||||
self.assertIsNone(message.pubrec_packet)
|
||||
self.assertIsNone(message.pubrel_packet)
|
||||
self.assertIsNone(message.pubcomp_packet)
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -279,37 +291,39 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_receive_qos2(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
try:
|
||||
packet = PublishPacket.build('/topic', b'test_data', 2, False, QOS_2, False)
|
||||
await packet.to_stream(writer)
|
||||
pubrec = await PubrecPacket.from_stream(reader)
|
||||
yield from packet.to_stream(writer)
|
||||
pubrec = yield from PubrecPacket.from_stream(reader)
|
||||
self.assertIsNotNone(pubrec)
|
||||
self.assertEqual(packet.packet_id, pubrec.packet_id)
|
||||
self.assertIn(packet.packet_id, self.handler._pubrel_waiters)
|
||||
pubrel = PubrelPacket.build(packet.packet_id)
|
||||
await pubrel.to_stream(writer)
|
||||
pubcomp = await PubcompPacket.from_stream(reader)
|
||||
yield from pubrel.to_stream(writer)
|
||||
pubcomp = yield from PubcompPacket.from_stream(reader)
|
||||
self.assertIsNotNone(pubcomp)
|
||||
self.assertEqual(packet.packet_id, pubcomp.packet_id)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.start_handler(self.handler, self.session)
|
||||
message = await self.handler.mqtt_deliver_next_message()
|
||||
yield from self.start_handler(self.handler, self.session)
|
||||
message = yield from self.handler.mqtt_deliver_next_message()
|
||||
self.assertIsInstance(message, IncomingApplicationMessage)
|
||||
self.assertIsNotNone(message.publish_packet)
|
||||
self.assertIsNone(message.puback_packet)
|
||||
self.assertIsNotNone(message.pubrec_packet)
|
||||
self.assertIsNotNone(message.pubrel_packet)
|
||||
self.assertIsNotNone(message.pubcomp_packet)
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -325,14 +339,16 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
async def start_handler(self, handler, session):
|
||||
@asyncio.coroutine
|
||||
def start_handler(self, handler, session):
|
||||
self.check_empty_waiters(handler)
|
||||
self.check_no_message(session)
|
||||
await handler.start()
|
||||
yield from handler.start()
|
||||
self.assertTrue(handler._reader_ready)
|
||||
|
||||
async def stop_handler(self, handler, session):
|
||||
await handler.stop()
|
||||
@asyncio.coroutine
|
||||
def stop_handler(self, handler, session):
|
||||
yield from handler.stop()
|
||||
self.assertTrue(handler._reader_stopped)
|
||||
self.check_empty_waiters(handler)
|
||||
self.check_no_message(session)
|
||||
|
@ -348,8 +364,9 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
self.assertFalse(session.inflight_in)
|
||||
|
||||
def test_publish_qos1_retry(self):
|
||||
async def server_mock(reader, writer):
|
||||
packet = await PublishPacket.from_stream(reader)
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
packet = yield from PublishPacket.from_stream(reader)
|
||||
try:
|
||||
self.assertEquals(packet.topic_name, '/topic')
|
||||
self.assertEquals(packet.qos, QOS_1)
|
||||
|
@ -357,18 +374,19 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
self.assertIn(packet.packet_id, self.session.inflight_out)
|
||||
self.assertIn(packet.packet_id, self.handler._puback_waiters)
|
||||
puback = PubackPacket.build(packet.packet_id)
|
||||
await puback.to_stream(writer)
|
||||
yield from puback.to_stream(writer)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.handler.start()
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.handler.start()
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
if not future.done():
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
@ -389,32 +407,34 @@ class ProtocolHandlerTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_publish_qos2_retry(self):
|
||||
async def server_mock(reader, writer):
|
||||
@asyncio.coroutine
|
||||
def server_mock(reader, writer):
|
||||
try:
|
||||
packet = await PublishPacket.from_stream(reader)
|
||||
packet = yield from PublishPacket.from_stream(reader)
|
||||
self.assertEquals(packet.topic_name, '/topic')
|
||||
self.assertEquals(packet.qos, QOS_2)
|
||||
self.assertIsNotNone(packet.packet_id)
|
||||
self.assertIn(packet.packet_id, self.session.inflight_out)
|
||||
self.assertIn(packet.packet_id, self.handler._pubrec_waiters)
|
||||
pubrec = PubrecPacket.build(packet.packet_id)
|
||||
await pubrec.to_stream(writer)
|
||||
yield from pubrec.to_stream(writer)
|
||||
|
||||
pubrel = await PubrelPacket.from_stream(reader)
|
||||
pubrel = yield from PubrelPacket.from_stream(reader)
|
||||
self.assertIn(packet.packet_id, self.handler._pubcomp_waiters)
|
||||
pubcomp = PubcompPacket.build(packet.packet_id)
|
||||
await pubcomp.to_stream(writer)
|
||||
yield from pubcomp.to_stream(writer)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader, writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=self.loop)
|
||||
reader_adapted, writer_adapted = adapt(reader, writer)
|
||||
self.handler = ProtocolHandler(self.plugin_manager, loop=self.loop)
|
||||
self.handler.attach(self.session, reader_adapted, writer_adapted)
|
||||
await self.handler.start()
|
||||
await self.stop_handler(self.handler, self.session)
|
||||
yield from self.handler.start()
|
||||
yield from self.stop_handler(self.handler, self.session)
|
||||
if not future.done():
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
|
|
@ -21,14 +21,17 @@ class EventTestPlugin:
|
|||
self.test_flag = False
|
||||
self.coro_flag = False
|
||||
|
||||
async def on_test(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def on_test(self, *args, **kwargs):
|
||||
self.test_flag = True
|
||||
self.context.logger.info("on_test")
|
||||
|
||||
async def test_coro(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def test_coro(self, *args, **kwargs):
|
||||
self.coro_flag = True
|
||||
|
||||
async def ret_coro(self, *args, **kwargs):
|
||||
@asyncio.coroutine
|
||||
def ret_coro(self, *args, **kwargs):
|
||||
return "TEST"
|
||||
|
||||
|
||||
|
@ -41,10 +44,11 @@ class TestPluginManager(unittest.TestCase):
|
|||
self.assertTrue(len(manager._plugins) > 0)
|
||||
|
||||
def test_fire_event(self):
|
||||
async def fire_event():
|
||||
await manager.fire_event("test")
|
||||
await asyncio.sleep(1, loop=self.loop)
|
||||
await manager.close()
|
||||
@asyncio.coroutine
|
||||
def fire_event():
|
||||
yield from manager.fire_event("test")
|
||||
yield from asyncio.sleep(1, loop=self.loop)
|
||||
yield from manager.close()
|
||||
|
||||
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
|
||||
self.loop.run_until_complete(fire_event())
|
||||
|
@ -52,9 +56,10 @@ class TestPluginManager(unittest.TestCase):
|
|||
self.assertTrue(plugin.object.test_flag)
|
||||
|
||||
def test_fire_event_wait(self):
|
||||
async def fire_event():
|
||||
await manager.fire_event("test", wait=True)
|
||||
await manager.close()
|
||||
@asyncio.coroutine
|
||||
def fire_event():
|
||||
yield from manager.fire_event("test", wait=True)
|
||||
yield from manager.close()
|
||||
|
||||
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
|
||||
self.loop.run_until_complete(fire_event())
|
||||
|
@ -62,8 +67,9 @@ class TestPluginManager(unittest.TestCase):
|
|||
self.assertTrue(plugin.object.test_flag)
|
||||
|
||||
def test_map_coro(self):
|
||||
async def call_coro():
|
||||
await manager.map_plugin_coro('test_coro')
|
||||
@asyncio.coroutine
|
||||
def call_coro():
|
||||
yield from manager.map_plugin_coro('test_coro')
|
||||
|
||||
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
|
||||
self.loop.run_until_complete(call_coro())
|
||||
|
@ -71,8 +77,9 @@ class TestPluginManager(unittest.TestCase):
|
|||
self.assertTrue(plugin.object.test_coro)
|
||||
|
||||
def test_map_coro_return(self):
|
||||
async def call_coro():
|
||||
return await manager.map_plugin_coro('ret_coro')
|
||||
@asyncio.coroutine
|
||||
def call_coro():
|
||||
return (yield from manager.map_plugin_coro('ret_coro'))
|
||||
|
||||
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
|
||||
ret = self.loop.run_until_complete(call_coro())
|
||||
|
@ -84,8 +91,9 @@ class TestPluginManager(unittest.TestCase):
|
|||
Run plugin coro but expect no return as an empty filter is given
|
||||
:return:
|
||||
"""
|
||||
async def call_coro():
|
||||
return await manager.map_plugin_coro('ret_coro', filter_plugins=[])
|
||||
@asyncio.coroutine
|
||||
def call_coro():
|
||||
return (yield from manager.map_plugin_coro('ret_coro', filter_plugins=[]))
|
||||
|
||||
manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop)
|
||||
ret = self.loop.run_until_complete(call_coro())
|
||||
|
|
|
@ -26,12 +26,12 @@ test_config = {
|
|||
}
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
def __await__(self, *args, **kwargs):
|
||||
future = asyncio.Future()
|
||||
future.set_result(self)
|
||||
result = yield from future
|
||||
return result
|
||||
#class AsyncMock(MagicMock):
|
||||
# def __yield from__(self, *args, **kwargs):
|
||||
# future = asyncio.Future()
|
||||
# future.set_result(self)
|
||||
# result = yield from future
|
||||
# return result
|
||||
|
||||
class BrokerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
@ -41,12 +41,13 @@ class BrokerTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_start_stop(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
self.assertDictEqual(broker._sessions, {})
|
||||
self.assertIn('default', broker._servers)
|
||||
|
@ -54,7 +55,7 @@ class BrokerTest(unittest.TestCase):
|
|||
[call().fire_event(EVENT_BROKER_PRE_START),
|
||||
call().fire_event(EVENT_BROKER_POST_START)], any_order=True)
|
||||
MockPluginManager.reset_mock()
|
||||
await broker.shutdown()
|
||||
yield from broker.shutdown()
|
||||
MockPluginManager.assert_has_calls(
|
||||
[call().fire_event(EVENT_BROKER_PRE_SHUTDOWN),
|
||||
call().fire_event(EVENT_BROKER_POST_SHUTDOWN)], any_order=True)
|
||||
|
@ -68,20 +69,21 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_connect(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
client = MQTTClient()
|
||||
ret = await client.connect('mqtt://localhost/')
|
||||
ret = yield from client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
self.assertIn(client.session.client_id, broker._sessions)
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
self.assertDictEqual(broker._sessions, {})
|
||||
MockPluginManager.assert_has_calls(
|
||||
|
@ -97,17 +99,18 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_subscribe(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
client = MQTTClient()
|
||||
ret = await client.connect('mqtt://localhost/')
|
||||
ret = yield from client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
await client.subscribe([('/topic', QOS_0)])
|
||||
yield from client.subscribe([('/topic', QOS_0)])
|
||||
|
||||
# Test if the client test client subscription is registered
|
||||
self.assertIn('/topic', broker._subscriptions)
|
||||
|
@ -117,9 +120,9 @@ class BrokerTest(unittest.TestCase):
|
|||
self.assertEquals(s, client.session)
|
||||
self.assertEquals(qos, QOS_0)
|
||||
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
MockPluginManager.assert_has_calls(
|
||||
[call().fire_event(EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
|
@ -134,17 +137,18 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_subscribe_twice(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
client = MQTTClient()
|
||||
ret = await client.connect('mqtt://localhost/')
|
||||
ret = yield from client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
await client.subscribe([('/topic', QOS_0)])
|
||||
yield from client.subscribe([('/topic', QOS_0)])
|
||||
|
||||
# Test if the client test client subscription is registered
|
||||
self.assertIn('/topic', broker._subscriptions)
|
||||
|
@ -154,15 +158,15 @@ class BrokerTest(unittest.TestCase):
|
|||
self.assertEquals(s, client.session)
|
||||
self.assertEquals(qos, QOS_0)
|
||||
|
||||
await client.subscribe([('/topic', QOS_0)])
|
||||
yield from client.subscribe([('/topic', QOS_0)])
|
||||
self.assertEquals(len(subs), 1)
|
||||
(s, qos) = subs[0]
|
||||
self.assertEquals(s, client.session)
|
||||
self.assertEquals(qos, QOS_0)
|
||||
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
MockPluginManager.assert_has_calls(
|
||||
[call().fire_event(EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
|
@ -177,17 +181,18 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_unsubscribe(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
client = MQTTClient()
|
||||
ret = await client.connect('mqtt://localhost/')
|
||||
ret = yield from client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
await client.subscribe([('/topic', QOS_0)])
|
||||
yield from client.subscribe([('/topic', QOS_0)])
|
||||
|
||||
# Test if the client test client subscription is registered
|
||||
self.assertIn('/topic', broker._subscriptions)
|
||||
|
@ -197,12 +202,12 @@ class BrokerTest(unittest.TestCase):
|
|||
self.assertEquals(s, client.session)
|
||||
self.assertEquals(qos, QOS_0)
|
||||
|
||||
await client.unsubscribe(['/topic'])
|
||||
await asyncio.sleep(0.1)
|
||||
yield from client.unsubscribe(['/topic'])
|
||||
yield from asyncio.sleep(0.1)
|
||||
self.assertEquals(broker._subscriptions['/topic'], [])
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
MockPluginManager.assert_has_calls(
|
||||
[call().fire_event(EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
|
@ -221,23 +226,24 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_publish(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
pub_client = MQTTClient()
|
||||
ret = await pub_client.connect('mqtt://localhost/')
|
||||
ret = yield from pub_client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
|
||||
ret_message = await pub_client.publish('/topic', b'data', QOS_0)
|
||||
await pub_client.disconnect()
|
||||
ret_message = yield from pub_client.publish('/topic', b'data', QOS_0)
|
||||
yield from pub_client.disconnect()
|
||||
self.assertEquals(broker._retained_messages, {})
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
MockPluginManager.assert_has_calls(
|
||||
[call().fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
|
@ -253,27 +259,28 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_publish_retain(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
|
||||
pub_client = MQTTClient()
|
||||
ret = await pub_client.connect('mqtt://localhost/')
|
||||
ret = yield from pub_client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
ret_message = await pub_client.publish('/topic', b'data', QOS_0, retain=True)
|
||||
await pub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
ret_message = yield from pub_client.publish('/topic', b'data', QOS_0, retain=True)
|
||||
yield from pub_client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
self.assertIn('/topic', broker._retained_messages)
|
||||
retained_message = broker._retained_messages['/topic']
|
||||
self.assertEquals(retained_message.source_session, pub_client.session)
|
||||
self.assertEquals(retained_message.topic, '/topic')
|
||||
self.assertEquals(retained_message.data, b'data')
|
||||
self.assertEquals(retained_message.qos, QOS_0)
|
||||
await broker.shutdown()
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
@ -284,31 +291,32 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_subscribe_publish(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
sub_client = MQTTClient()
|
||||
await sub_client.connect('mqtt://localhost')
|
||||
ret = await sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
|
||||
yield from sub_client.connect('mqtt://localhost')
|
||||
ret = yield from sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
|
||||
self.assertEquals(ret, [QOS_0, QOS_1, QOS_2])
|
||||
|
||||
await self._client_publish('/qos0', b'data', QOS_0)
|
||||
await self._client_publish('/qos1', b'data', QOS_1)
|
||||
await self._client_publish('/qos2', b'data', QOS_2)
|
||||
await asyncio.sleep(0.1)
|
||||
yield from self._client_publish('/qos0', b'data', QOS_0)
|
||||
yield from self._client_publish('/qos1', b'data', QOS_1)
|
||||
yield from self._client_publish('/qos2', b'data', QOS_2)
|
||||
yield from asyncio.sleep(0.1)
|
||||
for qos in [QOS_0, QOS_1, QOS_2]:
|
||||
message = await sub_client.deliver_message()
|
||||
message = yield from sub_client.deliver_message()
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEquals(message.topic, '/qos%s' % qos)
|
||||
self.assertEquals(message.data, b'data')
|
||||
self.assertEquals(message.qos, qos)
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from sub_client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
@ -319,35 +327,36 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
@patch('hbmqtt.broker.PluginManager', new_callable=AsyncMock)
|
||||
@patch('hbmqtt.broker.PluginManager')
|
||||
def test_client_publish_retain_subscribe(self, MockPluginManager):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
broker = Broker(test_config, plugin_namespace="hbmqtt.test.plugins")
|
||||
await broker.start()
|
||||
yield from broker.start()
|
||||
self.assertTrue(broker.transitions.is_started())
|
||||
sub_client = MQTTClient()
|
||||
await sub_client.connect('mqtt://localhost', cleansession=False)
|
||||
ret = await sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
|
||||
yield from sub_client.connect('mqtt://localhost', cleansession=False)
|
||||
ret = yield from sub_client.subscribe([('/qos0', QOS_0), ('/qos1', QOS_1), ('/qos2', QOS_2)])
|
||||
self.assertEquals(ret, [QOS_0, QOS_1, QOS_2])
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
yield from sub_client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
|
||||
await self._client_publish('/qos0', b'data', QOS_0, retain=True)
|
||||
await self._client_publish('/qos1', b'data', QOS_1, retain=True)
|
||||
await self._client_publish('/qos2', b'data', QOS_2, retain=True)
|
||||
await sub_client.reconnect()
|
||||
yield from self._client_publish('/qos0', b'data', QOS_0, retain=True)
|
||||
yield from self._client_publish('/qos1', b'data', QOS_1, retain=True)
|
||||
yield from self._client_publish('/qos2', b'data', QOS_2, retain=True)
|
||||
yield from sub_client.reconnect()
|
||||
for qos in [QOS_0, QOS_1, QOS_2]:
|
||||
log.debug("TEST QOS: %d" % qos)
|
||||
message = await sub_client.deliver_message()
|
||||
message = yield from sub_client.deliver_message()
|
||||
log.debug("Message: " + repr(message.publish_packet))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEquals(message.topic, '/qos%s' % qos)
|
||||
self.assertEquals(message.data, b'data')
|
||||
self.assertEquals(message.qos, qos)
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await broker.shutdown()
|
||||
yield from sub_client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
self.assertTrue(broker.transitions.is_stopped())
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
|
@ -358,10 +367,11 @@ class BrokerTest(unittest.TestCase):
|
|||
if future.exception():
|
||||
raise future.exception()
|
||||
|
||||
async def _client_publish(self, topic, data, qos, retain=False):
|
||||
@asyncio.coroutine
|
||||
def _client_publish(self, topic, data, qos, retain=False):
|
||||
pub_client = MQTTClient()
|
||||
ret = await pub_client.connect('mqtt://localhost/')
|
||||
ret = yield from pub_client.connect('mqtt://localhost/')
|
||||
self.assertEqual(ret, 0)
|
||||
ret = await pub_client.publish(topic, data, qos, retain)
|
||||
await pub_client.disconnect()
|
||||
ret = yield from pub_client.publish(topic, data, qos, retain)
|
||||
yield from pub_client.disconnect()
|
||||
return ret
|
||||
|
|
|
@ -22,12 +22,13 @@ class MQTTClientTest(unittest.TestCase):
|
|||
self.loop.close()
|
||||
|
||||
def test_connect_tcp(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
ret = await client.connect('mqtt://test.mosquitto.org/')
|
||||
ret = yield from client.connect('mqtt://test.mosquitto.org/')
|
||||
self.assertIsNotNone(client.session)
|
||||
await client.disconnect()
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -38,13 +39,14 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_connect_tcp_secure(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
ca = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mosquitto.org.crt')
|
||||
ret = await client.connect('mqtts://test.mosquitto.org/', cafile=ca)
|
||||
ret = yield from client.connect('mqtts://test.mosquitto.org/', cafile=ca)
|
||||
self.assertIsNotNone(client.session)
|
||||
await client.disconnect()
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -55,11 +57,12 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_connect_tcp_failure(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
config = {'auto_reconnect': False}
|
||||
client = MQTTClient(config=config)
|
||||
ret = await client.connect('mqtt://localhost/')
|
||||
ret = yield from client.connect('mqtt://localhost/')
|
||||
except ConnectException as e:
|
||||
future.set_result(True)
|
||||
|
||||
|
@ -69,12 +72,13 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_connect_ws(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
await client.connect('ws://test.mosquitto.org:8080/')
|
||||
yield from client.connect('ws://test.mosquitto.org:8080/')
|
||||
self.assertIsNotNone(client.session)
|
||||
await client.disconnect()
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -85,13 +89,14 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_connect_ws_secure(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
ca = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mosquitto.org.crt')
|
||||
await client.connect('wss://test.mosquitto.org:8081/', cafile=ca)
|
||||
yield from client.connect('wss://test.mosquitto.org:8081/', cafile=ca)
|
||||
self.assertIsNotNone(client.session)
|
||||
await client.disconnect()
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -102,13 +107,14 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_ping(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
ret = await client.connect('mqtt://test.mosquitto.org/')
|
||||
ret = yield from client.connect('mqtt://test.mosquitto.org/')
|
||||
self.assertIsNotNone(client.session)
|
||||
await client.ping()
|
||||
await client.disconnect()
|
||||
yield from client.ping()
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -119,12 +125,13 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_subscribe(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
await client.connect('mqtt://test.mosquitto.org/')
|
||||
yield from client.connect('mqtt://test.mosquitto.org/')
|
||||
self.assertIsNotNone(client.session)
|
||||
ret = await client.subscribe([
|
||||
ret = yield from client.subscribe([
|
||||
('$SYS/broker/uptime', QOS_0),
|
||||
('$SYS/broker/uptime', QOS_1),
|
||||
('$SYS/broker/uptime', QOS_2),
|
||||
|
@ -132,7 +139,7 @@ class MQTTClientTest(unittest.TestCase):
|
|||
self.assertEquals(ret[0], QOS_0)
|
||||
self.assertEquals(ret[1], QOS_1)
|
||||
self.assertEquals(ret[2], QOS_2)
|
||||
await client.disconnect()
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
@ -143,17 +150,18 @@ class MQTTClientTest(unittest.TestCase):
|
|||
raise future.exception()
|
||||
|
||||
def test_unsubscribe(self):
|
||||
async def test_coro():
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
try:
|
||||
client = MQTTClient()
|
||||
await client.connect('mqtt://test.mosquitto.org/')
|
||||
yield from client.connect('mqtt://test.mosquitto.org/')
|
||||
self.assertIsNotNone(client.session)
|
||||
ret = await client.subscribe([
|
||||
ret = yield from client.subscribe([
|
||||
('$SYS/broker/uptime', QOS_0),
|
||||
])
|
||||
self.assertEquals(ret[0], QOS_0)
|
||||
await client.unsubscribe(['$SYS/broker/uptime'])
|
||||
await client.disconnect()
|
||||
yield from client.unsubscribe(['$SYS/broker/uptime'])
|
||||
yield from client.disconnect()
|
||||
future.set_result(True)
|
||||
except Exception as ae:
|
||||
future.set_exception(ae)
|
||||
|
|
Ładowanie…
Reference in New Issue