All tests passing

pull/8/head
Nico 2015-10-12 21:33:14 +02:00
rodzic b6bd91e3fb
commit c3a144c6a3
4 zmienionych plików z 122 dodań i 81 usunięć

Wyświetl plik

@ -6,6 +6,7 @@ import ssl
import websockets import websockets
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from collections import deque
from functools import partial from functools import partial
from transitions import Machine, MachineError from transitions import Machine, MachineError
@ -460,12 +461,20 @@ class Broker:
self._sessions[client_session.client_id] = (client_session, handler) self._sessions[client_session.client_id] = (client_session, handler)
authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name]) authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
yield from handler.mqtt_connack_authorize(authenticated)
if not authenticated: if not authenticated:
yield from writer.close() yield from writer.close()
return return
client_session.transitions.connect() while True:
try:
client_session.transitions.connect()
break
except MachineError:
self.logger.warning("Client %s is reconnecting too quickly, make it wait" % client_session.client_id)
# Wait a bit may be client is reconnecting too fast
yield from asyncio.sleep(1, loop=self._loop)
yield from handler.mqtt_connack_authorize(authenticated)
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id) yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
self.logger.debug("%s Start messages handling" % client_session.client_id) self.logger.debug("%s Start messages handling" % client_session.client_id)
@ -474,85 +483,90 @@ class Broker:
yield from self.publish_session_retained_messages(client_session) yield from self.publish_session_retained_messages(client_session)
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect) # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
connected = True
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop) disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription(), loop=self._loop) subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription(), loop=self._loop)
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop) unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop)
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop) wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop)
connected = True
while connected: while connected:
done, pending = yield from asyncio.wait( try:
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver], done, pending = yield from asyncio.wait(
return_when=asyncio.FIRST_COMPLETED, loop=self._loop) [disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
if disconnect_waiter in done: return_when=asyncio.FIRST_COMPLETED, loop=self._loop)
result = disconnect_waiter.result() if disconnect_waiter in done:
self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result)) result = disconnect_waiter.result()
if result is None: self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result))
self.logger.debug("Will flag: %s" % client_session.will_flag) if result is None:
# Connection closed anormally, send will message self.logger.debug("Will flag: %s" % client_session.will_flag)
if client_session.will_flag: # Connection closed anormally, send will message
self.logger.debug("Client %s disconnected abnormally, sending will message" % if client_session.will_flag:
format_client_message(client_session)) self.logger.debug("Client %s disconnected abnormally, sending will message" %
yield from self.broadcast_application_message( format_client_message(client_session))
client_session, client_session.will_topic, yield from self.broadcast_application_message(
client_session.will_message, client_session, client_session.will_topic,
client_session.will_qos) client_session.will_message,
if client_session.will_retain: client_session.will_qos)
self.retain_message(client_session, if client_session.will_retain:
client_session.will_topic, self.retain_message(client_session,
client_session.will_message, client_session.will_topic,
client_session.will_qos) client_session.will_message,
connected = False client_session.will_qos)
if unsubscribe_waiter in done: self.logger.debug("%s Disconnecting session" % client_session.client_id)
self.logger.debug("%s handling unsubscription" % client_session.client_id) yield from self._stop_handler(handler)
unsubscription = unsubscribe_waiter.result() client_session.transitions.disconnect()
for topic in unsubscription['topics']: yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
self.del_subscription(topic, client_session) yield from writer.close()
yield from self.plugins_manager.fire_event( connected = False
EVENT_BROKER_CLIENT_UNSUBSCRIBED, if unsubscribe_waiter in done:
client_id=client_session.client_id, self.logger.debug("%s handling unsubscription" % client_session.client_id)
topic=topic) unsubscription = unsubscribe_waiter.result()
yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id']) for topic in unsubscription['topics']:
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop) self._del_subscription(topic, client_session)
if subscribe_waiter in done:
self.logger.debug("%s handling subscription" % client_session.client_id)
subscriptions = subscribe_waiter.result()
return_codes = []
for subscription in subscriptions['topics']:
return_codes.append(self.add_subscription(subscription, client_session))
yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
for index, subscription in enumerate(subscriptions['topics']):
if return_codes[index] != 0x80:
yield from self.plugins_manager.fire_event( yield from self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED, EVENT_BROKER_CLIENT_UNSUBSCRIBED,
client_id=client_session.client_id, client_id=client_session.client_id,
topic=subscription[0], topic=topic)
qos=subscription[1]) yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
yield from self.publish_retained_messages_for_subscription(subscription, client_session) unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop) if subscribe_waiter in done:
self.logger.debug(repr(self._subscriptions)) self.logger.debug("%s handling subscription" % client_session.client_id)
if wait_deliver in done: subscriptions = subscribe_waiter.result()
self.logger.debug("%s handling message delivery" % client_session.client_id) return_codes = []
app_message = wait_deliver.result() for subscription in subscriptions['topics']:
yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED, return_codes.append(self.add_subscription(subscription, client_session))
client_id=client_session.client_id, yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
message=app_message) for index, subscription in enumerate(subscriptions['topics']):
yield from self.broadcast_application_message(client_session, app_message.topic, app_message.data) if return_codes[index] != 0x80:
if app_message.publish_packet.retain_flag: yield from self.plugins_manager.fire_event(
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) EVENT_BROKER_CLIENT_SUBSCRIBED,
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop) client_id=client_session.client_id,
topic=subscription[0],
qos=subscription[1])
yield from self.publish_retained_messages_for_subscription(subscription, client_session)
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
self.logger.debug(repr(self._subscriptions))
if wait_deliver in done:
self.logger.debug("%s handling message delivery" % client_session.client_id)
app_message = wait_deliver.result()
yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
client_id=client_session.client_id,
message=app_message)
yield from self.broadcast_application_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
except asyncio.CancelledError:
self.logger.debug("Client loop cancelled")
break
disconnect_waiter.cancel() disconnect_waiter.cancel()
subscribe_waiter.cancel() subscribe_waiter.cancel()
unsubscribe_waiter.cancel() unsubscribe_waiter.cancel()
wait_deliver.cancel() wait_deliver.cancel()
self.logger.debug("%s Client disconnecting" % client_session.client_id) self.logger.debug("%s Client disconnected" % client_session.client_id)
yield from self._stop_handler(handler)
client_session.transitions.disconnect()
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
yield from writer.close()
self.logger.debug("%s Session disconnected" % client_session.client_id)
server.release_connection() server.release_connection()
def _init_handler(self, session, reader, writer): def _init_handler(self, session, reader, writer):
""" """
Create a BrokerProtocolHandler and attach to a session Create a BrokerProtocolHandler and attach to a session
@ -647,7 +661,14 @@ class Broker:
except KeyError: except KeyError:
return 0x80 return 0x80
def del_subscription(self, a_filter, session): def _del_subscription(self, a_filter, session):
"""
Delete a session subscription on a given topic
:param a_filter:
:param session:
:return:
"""
deleted = 0
try: try:
subscriptions = self._subscriptions[a_filter] subscriptions = self._subscriptions[a_filter]
for index, (sub_session, qos) in enumerate(subscriptions): for index, (sub_session, qos) in enumerate(subscriptions):
@ -655,12 +676,27 @@ class Broker:
self.logger.debug("Removing subscription on topic '%s' for client %s" % self.logger.debug("Removing subscription on topic '%s' for client %s" %
(a_filter, format_client_message(session=session))) (a_filter, format_client_message(session=session)))
subscriptions.pop(index) subscriptions.pop(index)
# Remove filter for subsriptions list if there are no more subscribers deleted += 1
if not self._subscriptions[a_filter]: break
del self._subscriptions[a_filter]
except KeyError: except KeyError:
# Unsubscribe topic not found in current subscribed topics # Unsubscribe topic not found in current subscribed topics
pass pass
finally:
return deleted
def _del_all_subscriptions(self, session):
"""
Delete all topic subscriptions for a given session
:param session:
:return:
"""
filter_queue = deque()
for topic in self._subscriptions:
if self._del_subscription(topic, session):
filter_queue.append(topic)
for topic in filter_queue:
if not self._subscriptions[topic]:
del self._subscriptions[topic]
def matches(self, topic, a_filter): def matches(self, topic, a_filter):
import re import re
@ -726,13 +762,14 @@ class Broker:
self.logger.debug("Begin broadcasting messages retained due to subscription on '%s' from %s" % self.logger.debug("Begin broadcasting messages retained due to subscription on '%s' from %s" %
(subscription[0], format_client_message(session=session))) (subscription[0], format_client_message(session=session)))
publish_tasks = [] publish_tasks = []
handler = self._get_handler(session)
for d_topic in self._retained_messages: for d_topic in self._retained_messages:
self.logger.debug("matching : %s %s" % (d_topic, subscription[0])) self.logger.debug("matching : %s %s" % (d_topic, subscription[0]))
if self.matches(d_topic, subscription[0]): if self.matches(d_topic, subscription[0]):
self.logger.debug("%s and %s match" % (d_topic, subscription[0])) self.logger.debug("%s and %s match" % (d_topic, subscription[0]))
retained = self._retained_messages[d_topic] retained = self._retained_messages[d_topic]
publish_tasks.append(asyncio.Task( publish_tasks.append(asyncio.Task(
session.handler.mqtt_publish( handler.mqtt_publish(
retained.topic, retained.data, subscription[1], True), loop=self._loop)) retained.topic, retained.data, subscription[1], True), loop=self._loop))
if publish_tasks: if publish_tasks:
yield from asyncio.wait(publish_tasks, loop=self._loop) yield from asyncio.wait(publish_tasks, loop=self._loop)
@ -755,11 +792,7 @@ class Broker:
# Delete subscriptions # Delete subscriptions
self.logger.debug("deleting session %s subscriptions" % repr(session)) self.logger.debug("deleting session %s subscriptions" % repr(session))
nb_sub = 0 self._del_all_subscriptions(session)
for a_filter in self._subscriptions:
self.del_subscription(a_filter, session)
nb_sub += 1
self.logger.debug("%d subscriptions deleted" % nb_sub)
self.logger.debug("deleting existing session %s" % repr(self._sessions[client_id])) self.logger.debug("deleting existing session %s" % repr(self._sessions[client_id]))
del self._sessions[client_id] del self._sessions[client_id]

Wyświetl plik

@ -53,7 +53,9 @@ class BrokerProtocolHandler(ProtocolHandler):
@asyncio.coroutine @asyncio.coroutine
def handle_disconnect(self, disconnect): def handle_disconnect(self, disconnect):
self.logger.debug("Client disconnecting")
if self._disconnect_waiter and not self._disconnect_waiter.done(): if self._disconnect_waiter and not self._disconnect_waiter.done():
self.logger.debug("Setting waiter result to %r" % disconnect)
self._disconnect_waiter.set_result(disconnect) self._disconnect_waiter.set_result(disconnect)
@asyncio.coroutine @asyncio.coroutine

Wyświetl plik

@ -133,6 +133,10 @@ class ProtocolHandler:
self.logger.debug("Handler writer close failed: %s" % e) self.logger.debug("Handler writer close failed: %s" % e)
def _stop_waiters(self): def _stop_waiters(self):
self.logger.debug("Stopping %d puback waiters" % len(self._puback_waiters))
self.logger.debug("Stopping %d pucomp waiters" % len(self._pubcomp_waiters))
self.logger.debug("Stopping %d purec waiters" % len(self._pubrec_waiters))
self.logger.debug("Stopping %d purel waiters" % len(self._pubrel_waiters))
for waiter in itertools.chain( for waiter in itertools.chain(
self._puback_waiters.values(), self._puback_waiters.values(),
self._pubcomp_waiters.values(), self._pubcomp_waiters.values(),
@ -167,7 +171,7 @@ class ProtocolHandler:
message = OutgoingApplicationMessage(packet_id, topic, qos, data, retain) message = OutgoingApplicationMessage(packet_id, topic, qos, data, retain)
# Handle message flow # Handle message flow
yield from asyncio.wait_for(self._handle_message_flow(message), 10, loop=self._loop) yield from asyncio.wait_for(self._handle_message_flow(message), 60, loop=self._loop)
return message return message
@asyncio.coroutine @asyncio.coroutine
@ -394,8 +398,6 @@ class ProtocolHandler:
break break
except asyncio.CancelledError: except asyncio.CancelledError:
self.logger.debug("Task cancelled, reader loop ending") self.logger.debug("Task cancelled, reader loop ending")
while running_tasks:
running_tasks.popleft().cancel()
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.logger.debug("%s Input stream read timeout" % self.session.client_id) self.logger.debug("%s Input stream read timeout" % self.session.client_id)
@ -405,9 +407,12 @@ class ProtocolHandler:
except BaseException as e: except BaseException as e:
self.logger.warning("%s Unhandled exception in reader coro: %s" % (type(self).__name__, e)) self.logger.warning("%s Unhandled exception in reader coro: %s" % (type(self).__name__, e))
break break
while running_tasks:
running_tasks.popleft().cancel()
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
self._reader_stopped.set() self._reader_stopped.set()
self.logger.debug("%s Reader coro stopped" % self.session.client_id) self.logger.debug("%s Reader coro stopped" % self.session.client_id)
yield from self.stop()
@asyncio.coroutine @asyncio.coroutine
def _send_packet(self, packet): def _send_packet(self, packet):

Wyświetl plik

@ -193,7 +193,8 @@ class BrokerTest(unittest.TestCase):
self.assertEquals(qos, QOS_0) self.assertEquals(qos, QOS_0)
yield from client.unsubscribe(['/topic']) yield from client.unsubscribe(['/topic'])
self.assertNotIn('/topic', broker._subscriptions) yield from asyncio.sleep(0.1)
self.assertEquals(broker._subscriptions['/topic'], [])
yield from client.disconnect() yield from client.disconnect()
yield from asyncio.sleep(0.1) yield from asyncio.sleep(0.1)
yield from broker.shutdown() yield from broker.shutdown()