kopia lustrzana https://github.com/Yakifo/amqtt
All tests passing
rodzic
b6bd91e3fb
commit
c3a144c6a3
187
hbmqtt/broker.py
187
hbmqtt/broker.py
|
@ -6,6 +6,7 @@ import ssl
|
|||
import websockets
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
from functools import partial
|
||||
from transitions import Machine, MachineError
|
||||
|
@ -460,12 +461,20 @@ class Broker:
|
|||
self._sessions[client_session.client_id] = (client_session, handler)
|
||||
|
||||
authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
|
||||
yield from handler.mqtt_connack_authorize(authenticated)
|
||||
if not authenticated:
|
||||
yield from writer.close()
|
||||
return
|
||||
|
||||
client_session.transitions.connect()
|
||||
while True:
|
||||
try:
|
||||
client_session.transitions.connect()
|
||||
break
|
||||
except MachineError:
|
||||
self.logger.warning("Client %s is reconnecting too quickly, make it wait" % client_session.client_id)
|
||||
# Wait a bit may be client is reconnecting too fast
|
||||
yield from asyncio.sleep(1, loop=self._loop)
|
||||
yield from handler.mqtt_connack_authorize(authenticated)
|
||||
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
||||
|
||||
self.logger.debug("%s Start messages handling" % client_session.client_id)
|
||||
|
@ -474,85 +483,90 @@ class Broker:
|
|||
yield from self.publish_session_retained_messages(client_session)
|
||||
|
||||
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
|
||||
connected = True
|
||||
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect(), loop=self._loop)
|
||||
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription(), loop=self._loop)
|
||||
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop)
|
||||
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop)
|
||||
connected = True
|
||||
while connected:
|
||||
done, pending = yield from asyncio.wait(
|
||||
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
|
||||
return_when=asyncio.FIRST_COMPLETED, loop=self._loop)
|
||||
if disconnect_waiter in done:
|
||||
result = disconnect_waiter.result()
|
||||
self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result))
|
||||
if result is None:
|
||||
self.logger.debug("Will flag: %s" % client_session.will_flag)
|
||||
# Connection closed anormally, send will message
|
||||
if client_session.will_flag:
|
||||
self.logger.debug("Client %s disconnected abnormally, sending will message" %
|
||||
format_client_message(client_session))
|
||||
yield from self.broadcast_application_message(
|
||||
client_session, client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
if client_session.will_retain:
|
||||
self.retain_message(client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
connected = False
|
||||
if unsubscribe_waiter in done:
|
||||
self.logger.debug("%s handling unsubscription" % client_session.client_id)
|
||||
unsubscription = unsubscribe_waiter.result()
|
||||
for topic in unsubscription['topics']:
|
||||
self.del_subscription(topic, client_session)
|
||||
yield from self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
||||
client_id=client_session.client_id,
|
||||
topic=topic)
|
||||
yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
|
||||
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
|
||||
if subscribe_waiter in done:
|
||||
self.logger.debug("%s handling subscription" % client_session.client_id)
|
||||
subscriptions = subscribe_waiter.result()
|
||||
return_codes = []
|
||||
for subscription in subscriptions['topics']:
|
||||
return_codes.append(self.add_subscription(subscription, client_session))
|
||||
yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
|
||||
for index, subscription in enumerate(subscriptions['topics']):
|
||||
if return_codes[index] != 0x80:
|
||||
try:
|
||||
done, pending = yield from asyncio.wait(
|
||||
[disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
|
||||
return_when=asyncio.FIRST_COMPLETED, loop=self._loop)
|
||||
if disconnect_waiter in done:
|
||||
result = disconnect_waiter.result()
|
||||
self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result))
|
||||
if result is None:
|
||||
self.logger.debug("Will flag: %s" % client_session.will_flag)
|
||||
# Connection closed anormally, send will message
|
||||
if client_session.will_flag:
|
||||
self.logger.debug("Client %s disconnected abnormally, sending will message" %
|
||||
format_client_message(client_session))
|
||||
yield from self.broadcast_application_message(
|
||||
client_session, client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
if client_session.will_retain:
|
||||
self.retain_message(client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
self.logger.debug("%s Disconnecting session" % client_session.client_id)
|
||||
yield from self._stop_handler(handler)
|
||||
client_session.transitions.disconnect()
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
||||
yield from writer.close()
|
||||
connected = False
|
||||
if unsubscribe_waiter in done:
|
||||
self.logger.debug("%s handling unsubscription" % client_session.client_id)
|
||||
unsubscription = unsubscribe_waiter.result()
|
||||
for topic in unsubscription['topics']:
|
||||
self._del_subscription(topic, client_session)
|
||||
yield from self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
||||
client_id=client_session.client_id,
|
||||
topic=subscription[0],
|
||||
qos=subscription[1])
|
||||
yield from self.publish_retained_messages_for_subscription(subscription, client_session)
|
||||
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
|
||||
self.logger.debug(repr(self._subscriptions))
|
||||
if wait_deliver in done:
|
||||
self.logger.debug("%s handling message delivery" % client_session.client_id)
|
||||
app_message = wait_deliver.result()
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
client_id=client_session.client_id,
|
||||
message=app_message)
|
||||
yield from self.broadcast_application_message(client_session, app_message.topic, app_message.data)
|
||||
if app_message.publish_packet.retain_flag:
|
||||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
|
||||
topic=topic)
|
||||
yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
|
||||
unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
|
||||
if subscribe_waiter in done:
|
||||
self.logger.debug("%s handling subscription" % client_session.client_id)
|
||||
subscriptions = subscribe_waiter.result()
|
||||
return_codes = []
|
||||
for subscription in subscriptions['topics']:
|
||||
return_codes.append(self.add_subscription(subscription, client_session))
|
||||
yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
|
||||
for index, subscription in enumerate(subscriptions['topics']):
|
||||
if return_codes[index] != 0x80:
|
||||
yield from self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
client_id=client_session.client_id,
|
||||
topic=subscription[0],
|
||||
qos=subscription[1])
|
||||
yield from self.publish_retained_messages_for_subscription(subscription, client_session)
|
||||
subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
|
||||
self.logger.debug(repr(self._subscriptions))
|
||||
if wait_deliver in done:
|
||||
self.logger.debug("%s handling message delivery" % client_session.client_id)
|
||||
app_message = wait_deliver.result()
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
client_id=client_session.client_id,
|
||||
message=app_message)
|
||||
yield from self.broadcast_application_message(client_session, app_message.topic, app_message.data)
|
||||
if app_message.publish_packet.retain_flag:
|
||||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Client loop cancelled")
|
||||
break
|
||||
disconnect_waiter.cancel()
|
||||
subscribe_waiter.cancel()
|
||||
unsubscribe_waiter.cancel()
|
||||
wait_deliver.cancel()
|
||||
|
||||
self.logger.debug("%s Client disconnecting" % client_session.client_id)
|
||||
yield from self._stop_handler(handler)
|
||||
client_session.transitions.disconnect()
|
||||
yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
||||
yield from writer.close()
|
||||
self.logger.debug("%s Session disconnected" % client_session.client_id)
|
||||
self.logger.debug("%s Client disconnected" % client_session.client_id)
|
||||
server.release_connection()
|
||||
|
||||
|
||||
def _init_handler(self, session, reader, writer):
|
||||
"""
|
||||
Create a BrokerProtocolHandler and attach to a session
|
||||
|
@ -647,7 +661,14 @@ class Broker:
|
|||
except KeyError:
|
||||
return 0x80
|
||||
|
||||
def del_subscription(self, a_filter, session):
|
||||
def _del_subscription(self, a_filter, session):
|
||||
"""
|
||||
Delete a session subscription on a given topic
|
||||
:param a_filter:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
deleted = 0
|
||||
try:
|
||||
subscriptions = self._subscriptions[a_filter]
|
||||
for index, (sub_session, qos) in enumerate(subscriptions):
|
||||
|
@ -655,12 +676,27 @@ class Broker:
|
|||
self.logger.debug("Removing subscription on topic '%s' for client %s" %
|
||||
(a_filter, format_client_message(session=session)))
|
||||
subscriptions.pop(index)
|
||||
# Remove filter for subsriptions list if there are no more subscribers
|
||||
if not self._subscriptions[a_filter]:
|
||||
del self._subscriptions[a_filter]
|
||||
deleted += 1
|
||||
break
|
||||
except KeyError:
|
||||
# Unsubscribe topic not found in current subscribed topics
|
||||
pass
|
||||
finally:
|
||||
return deleted
|
||||
|
||||
def _del_all_subscriptions(self, session):
|
||||
"""
|
||||
Delete all topic subscriptions for a given session
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
filter_queue = deque()
|
||||
for topic in self._subscriptions:
|
||||
if self._del_subscription(topic, session):
|
||||
filter_queue.append(topic)
|
||||
for topic in filter_queue:
|
||||
if not self._subscriptions[topic]:
|
||||
del self._subscriptions[topic]
|
||||
|
||||
def matches(self, topic, a_filter):
|
||||
import re
|
||||
|
@ -726,13 +762,14 @@ class Broker:
|
|||
self.logger.debug("Begin broadcasting messages retained due to subscription on '%s' from %s" %
|
||||
(subscription[0], format_client_message(session=session)))
|
||||
publish_tasks = []
|
||||
handler = self._get_handler(session)
|
||||
for d_topic in self._retained_messages:
|
||||
self.logger.debug("matching : %s %s" % (d_topic, subscription[0]))
|
||||
if self.matches(d_topic, subscription[0]):
|
||||
self.logger.debug("%s and %s match" % (d_topic, subscription[0]))
|
||||
retained = self._retained_messages[d_topic]
|
||||
publish_tasks.append(asyncio.Task(
|
||||
session.handler.mqtt_publish(
|
||||
handler.mqtt_publish(
|
||||
retained.topic, retained.data, subscription[1], True), loop=self._loop))
|
||||
if publish_tasks:
|
||||
yield from asyncio.wait(publish_tasks, loop=self._loop)
|
||||
|
@ -755,11 +792,7 @@ class Broker:
|
|||
|
||||
# Delete subscriptions
|
||||
self.logger.debug("deleting session %s subscriptions" % repr(session))
|
||||
nb_sub = 0
|
||||
for a_filter in self._subscriptions:
|
||||
self.del_subscription(a_filter, session)
|
||||
nb_sub += 1
|
||||
self.logger.debug("%d subscriptions deleted" % nb_sub)
|
||||
self._del_all_subscriptions(session)
|
||||
|
||||
self.logger.debug("deleting existing session %s" % repr(self._sessions[client_id]))
|
||||
del self._sessions[client_id]
|
||||
|
|
|
@ -53,7 +53,9 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
|
||||
@asyncio.coroutine
|
||||
def handle_disconnect(self, disconnect):
|
||||
self.logger.debug("Client disconnecting")
|
||||
if self._disconnect_waiter and not self._disconnect_waiter.done():
|
||||
self.logger.debug("Setting waiter result to %r" % disconnect)
|
||||
self._disconnect_waiter.set_result(disconnect)
|
||||
|
||||
@asyncio.coroutine
|
||||
|
|
|
@ -133,6 +133,10 @@ class ProtocolHandler:
|
|||
self.logger.debug("Handler writer close failed: %s" % e)
|
||||
|
||||
def _stop_waiters(self):
|
||||
self.logger.debug("Stopping %d puback waiters" % len(self._puback_waiters))
|
||||
self.logger.debug("Stopping %d pucomp waiters" % len(self._pubcomp_waiters))
|
||||
self.logger.debug("Stopping %d purec waiters" % len(self._pubrec_waiters))
|
||||
self.logger.debug("Stopping %d purel waiters" % len(self._pubrel_waiters))
|
||||
for waiter in itertools.chain(
|
||||
self._puback_waiters.values(),
|
||||
self._pubcomp_waiters.values(),
|
||||
|
@ -167,7 +171,7 @@ class ProtocolHandler:
|
|||
|
||||
message = OutgoingApplicationMessage(packet_id, topic, qos, data, retain)
|
||||
# Handle message flow
|
||||
yield from asyncio.wait_for(self._handle_message_flow(message), 10, loop=self._loop)
|
||||
yield from asyncio.wait_for(self._handle_message_flow(message), 60, loop=self._loop)
|
||||
return message
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -394,8 +398,6 @@ class ProtocolHandler:
|
|||
break
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Task cancelled, reader loop ending")
|
||||
while running_tasks:
|
||||
running_tasks.popleft().cancel()
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.debug("%s Input stream read timeout" % self.session.client_id)
|
||||
|
@ -405,9 +407,12 @@ class ProtocolHandler:
|
|||
except BaseException as e:
|
||||
self.logger.warning("%s Unhandled exception in reader coro: %s" % (type(self).__name__, e))
|
||||
break
|
||||
while running_tasks:
|
||||
running_tasks.popleft().cancel()
|
||||
yield from self.handle_connection_closed()
|
||||
self._reader_stopped.set()
|
||||
self.logger.debug("%s Reader coro stopped" % self.session.client_id)
|
||||
yield from self.stop()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _send_packet(self, packet):
|
||||
|
|
|
@ -193,7 +193,8 @@ class BrokerTest(unittest.TestCase):
|
|||
self.assertEquals(qos, QOS_0)
|
||||
|
||||
yield from client.unsubscribe(['/topic'])
|
||||
self.assertNotIn('/topic', broker._subscriptions)
|
||||
yield from asyncio.sleep(0.1)
|
||||
self.assertEquals(broker._subscriptions['/topic'], [])
|
||||
yield from client.disconnect()
|
||||
yield from asyncio.sleep(0.1)
|
||||
yield from broker.shutdown()
|
||||
|
|
Ładowanie…
Reference in New Issue