2015-07-07 20:48:53 +00:00
|
|
|
# Copyright (c) 2015 Nicolas JOUANIN
|
|
|
|
#
|
|
|
|
# See the file license.txt for copying permission.
|
|
|
|
import logging
|
|
|
|
import asyncio
|
|
|
|
|
|
|
|
from transitions import Machine, MachineError
|
2015-07-08 19:54:10 +00:00
|
|
|
from hbmqtt.session import Session
|
2015-07-12 20:35:56 +00:00
|
|
|
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
2015-07-10 20:55:22 +00:00
|
|
|
from hbmqtt.mqtt.connect import ConnectPacket
|
|
|
|
from hbmqtt.mqtt.connack import ConnackPacket, ReturnCode
|
|
|
|
from hbmqtt.errors import HBMQTTException
|
|
|
|
from hbmqtt.utils import format_client_message, gen_client_id
|
2015-07-07 20:48:53 +00:00
|
|
|
|
|
|
|
|
|
|
|
_defaults = {
|
2015-07-11 18:52:34 +00:00
|
|
|
'bind-address': 'localhost',
|
2015-07-11 20:42:50 +00:00
|
|
|
'bind-port': 1883,
|
2015-07-12 20:35:56 +00:00
|
|
|
'timeout-disconnect-delay': 10
|
2015-07-07 20:48:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
class BrokerException(Exception):
    """Error raised by the broker itself (startup, shutdown, CONNECT checks).

    Derives from ``Exception`` rather than ``BaseException``: the latter is
    reserved for interpreter-level signals such as ``KeyboardInterrupt``, and
    subclasses of it bypass ordinary ``except Exception`` handlers.
    """
|
|
|
|
|
|
|
|
|
2015-07-12 20:35:56 +00:00
|
|
|
class RetainedApplicationMessage:
    """Application message kept by the broker for later delivery.

    Simple record holding the publishing session together with the topic
    and the payload of the message.
    """

    def __init__(self, source_session, topic, data):
        """Store the publishing session, the topic and the payload as-is."""
        self.source_session, self.topic, self.data = source_session, topic, data
|
|
|
|
|
|
|
|
|
2015-07-07 20:48:53 +00:00
|
|
|
class Broker:
    """MQTT broker serving clients over a TCP listening socket.

    The broker keeps per-client sessions, the subscriptions registered for
    each topic filter and the retained messages, and tracks its own
    lifecycle with a state machine built from ``states``.
    """

    # Lifecycle states of the broker state machine. Each name appears once
    # ('stopped' used to be listed twice).
    states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped']
|
|
|
|
|
|
|
|
def __init__(self, config=None, loop=None):
    """Create a broker in the 'new' state.

    :param config: optional mapping whose entries override the module defaults
    :param loop: event loop to use; ``asyncio.get_event_loop()`` when omitted
    """
    self.logger = logging.getLogger(__name__)
    # Work on a copy: updating the module-level defaults in place would leak
    # this instance's settings into every broker created afterwards.
    self.config = dict(_defaults)
    if config is not None:
        self.config.update(config)

    self._loop = loop if loop is not None else asyncio.get_event_loop()

    self._server = None
    self._init_states()
    # client id -> session
    self._sessions = dict()
    # topic filter -> list of {'session': ..., 'qos': ...} subscriptions
    self._topics = dict()
    # topic name -> RetainedApplicationMessage
    self._global_retained_messages = dict()
|
2015-07-07 20:48:53 +00:00
|
|
|
|
|
|
|
def _init_states(self):
    """Create the lifecycle state machine and register its transitions."""
    self.machine = Machine(states=Broker.states, initial='new')
    # (trigger, source state, destination state), registered in this order
    transition_table = (
        ('start', 'new', 'starting'),
        ('starting_fail', 'starting', 'not_started'),
        ('starting_success', 'starting', 'started'),
        ('shutdown', 'started', 'stopping'),
        ('stopping_success', 'stopping', 'stopped'),
        ('stopping_failure', 'stopping', 'not_stopped'),
        ('start', 'stopped', 'starting'),
    )
    for trigger_name, from_state, to_state in transition_table:
        self.machine.add_transition(trigger=trigger_name, source=from_state, dest=to_state)
|
|
|
|
|
|
|
|
@asyncio.coroutine
def start(self):
    """Open the listening socket and move the broker to the 'started' state.

    :raises BrokerException: if the state machine refuses the 'start'
        trigger, or if the server cannot be created (the machine is then
        moved to 'not_started')
    """
    # Step 1: the state machine must accept the 'start' trigger
    try:
        self.machine.start()
    except MachineError as machine_error:
        self.logger.debug("Invalid method call at this moment: %s" % machine_error)
        raise BrokerException("Broker instance can't be started: %s" % machine_error)
    else:
        self.logger.debug("Broker starting")

    # Step 2: create the server; any failure is reported as BrokerException
    try:
        server = yield from asyncio.start_server(
            self.client_connected,
            self.config['bind-address'],
            self.config['bind-port'],
            loop=self._loop)
        self._server = server
        listening_on = (self.config['bind-address'], self.config['bind-port'])
        self.logger.info("Broker listening on %s:%d" % listening_on)
        self.machine.starting_success()
    except Exception as startup_error:
        self.logger.error("Broker startup failed: %s" % startup_error)
        self.machine.starting_fail()
        raise BrokerException("Broker instance can't be started: %s" % startup_error)
|
|
|
|
|
|
|
|
@asyncio.coroutine
def shutdown(self):
    """Close the listening socket and move the broker to the 'stopped' state.

    :raises BrokerException: if the state machine refuses the 'shutdown'
        trigger in the current state

    Any error raised while closing the server moves the machine to
    'not_stopped' and is then re-raised unchanged.
    """
    try:
        self.machine.shutdown()
    except MachineError as me:
        self.logger.debug("Invalid method call at this moment: %s" % me)
        raise BrokerException("Broker instance can't be stopped: %s" % me)

    try:
        self._server.close()
        self.logger.debug("Broker closing")
        yield from self._server.wait_closed()
    except Exception as e:
        # Without this the machine would stay in 'stopping' forever and the
        # 'stopping_failure' transition would never be used.
        self.logger.error("Broker shutdown failed: %s" % e)
        self.machine.stopping_failure()
        raise
    self.logger.info("Broker closed")
    self.machine.stopping_success()
|
|
|
|
|
|
|
|
@asyncio.coroutine
def client_connected(self, reader, writer):
    """Serve one client connection, from its CONNECT packet to disconnection.

    Passed to ``asyncio.start_server`` as the connection callback. The
    coroutine reads and validates the CONNECT packet, answers with a
    CONNACK, binds a session to a protocol handler and then loops on three
    events until the client disconnects: disconnection, a pending
    subscription, or an incoming published message.

    :param reader: stream the client's packets are read from
    :param writer: stream replies are written to; closed before returning
    """
    extra_info = writer.get_extra_info('peername')
    remote_address = extra_info[0]
    remote_port = extra_info[1]
    self.logger.debug("Connection from %s:%d" % (remote_address, remote_port))

    # Wait for first packet and expect a CONNECT
    try:
        connect = yield from ConnectPacket.from_stream(reader)
        self.logger.debug(" <-in-- " + repr(connect))
        # check_connect is a coroutine: calling it without ``yield from``
        # only creates the coroutine object and never runs the checks.
        yield from self.check_connect(connect)
    except HBMQTTException as exc:
        self.logger.warning("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" %
                            (format_client_message(address=remote_address, port=remote_port), exc))
        writer.close()
        self.logger.debug("Connection closed")
        return
    except BrokerException as be:
        self.logger.error('Invalid connection from %s : %s' %
                          (format_client_message(address=remote_address, port=remote_port), be))
        writer.close()
        self.logger.debug("Connection closed")
        return

    # Refusals that are answered with a CONNACK before closing
    connack = None
    if connect.variable_header.proto_level != 4:
        # only MQTT 3.1.1 supported
        # NOTE(review): the original logged `protocol_level` here while the
        # test above reads `proto_level`; the tested name is assumed correct.
        self.logger.error('Invalid protocol from %s: %d' %
                          (format_client_message(address=remote_address, port=remote_port),
                           connect.variable_header.proto_level))
        connack = ConnackPacket.build(0, ReturnCode.UNACCEPTABLE_PROTOCOL_VERSION)  # [MQTT-3.2.2-4] session_parent=0
    elif connect.variable_header.username_flag and connect.payload.username is None:
        self.logger.error('Invalid username from %s' %
                          (format_client_message(address=remote_address, port=remote_port)))
        connack = ConnackPacket.build(0, ReturnCode.BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
    elif connect.variable_header.password_flag and connect.payload.password is None:
        self.logger.error('Invalid password %s' % (format_client_message(address=remote_address, port=remote_port)))
        connack = ConnackPacket.build(0, ReturnCode.BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
    elif not connect.variable_header.clean_session_flag and connect.payload.client_id is None:
        self.logger.error('[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' %
                          format_client_message(address=remote_address, port=remote_port))
        connack = ConnackPacket.build(0, ReturnCode.IDENTIFIER_REJECTED)
    if connack is not None:
        self.logger.debug(" -out-> " + repr(connack))
        yield from connack.to_stream(writer)
        writer.close()
        return

    # Resolve the session: a clean session always starts afresh, otherwise
    # the cached session for this client id is reused when there is one.
    client_id = connect.payload.client_id
    if connect.variable_header.clean_session_flag:
        # Replaces (and thereby deletes) any existing session for this id
        client_session = Session()
        client_session.parent = 0
        self._sessions[client_id] = client_session
    elif client_id in self._sessions:
        # Get session from cache
        client_session = self._sessions[client_id]
        client_session.parent = 1
    else:
        client_session = Session()
        client_session.parent = 0
        # Cache the new persistent session, otherwise the lookup above can
        # never succeed on a later connection.
        self._sessions[client_id] = client_session

    # Use the client-provided id; generate one only when none was provided
    # (previously the generated id was overwritten by the payload value).
    client_session.client_id = client_id
    if client_session.client_id is None:
        client_session.client_id = gen_client_id()
    client_session.remote_address = remote_address
    client_session.remote_port = remote_port
    client_session.clean_session = connect.variable_header.clean_session_flag
    client_session.will_flag = connect.variable_header.will_flag
    client_session.will_retain = connect.variable_header.will_retain_flag
    client_session.will_qos = connect.variable_header.will_qos
    client_session.will_topic = connect.payload.will_topic
    client_session.will_message = connect.payload.will_message
    client_session.username = connect.payload.username
    client_session.password = connect.payload.password
    if connect.variable_header.keep_alive > 0:
        client_session.keep_alive = connect.variable_header.keep_alive + self.config['timeout-disconnect-delay']
    else:
        client_session.keep_alive = 0

    client_session.reader = reader
    client_session.writer = writer

    if self.authenticate(client_session):
        connack = ConnackPacket.build(client_session.parent, ReturnCode.CONNECTION_ACCEPTED)
        self.logger.info('%s : connection accepted' % format_client_message(session=client_session))
        self.logger.debug(" -out-> " + repr(connack))
        yield from connack.to_stream(writer)
    else:
        connack = ConnackPacket.build(client_session.parent, ReturnCode.NOT_AUTHORIZED)
        self.logger.info('%s : connection refused' % format_client_message(session=client_session))
        self.logger.debug(" -out-> " + repr(connack))
        yield from connack.to_stream(writer)
        writer.close()
        return

    client_session.machine.connect()
    handler = BrokerProtocolHandler(self._loop)
    handler.attach_to_session(client_session)
    self.logger.debug("Start messages handling")
    yield from handler.start()
    yield from self.publish_session_retained_messages(client_session)
    self.logger.debug("Wait for disconnect")

    connected = True
    wait_disconnect = asyncio.Task(handler.wait_disconnect())
    wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
    wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message())
    while connected:
        done, _ = yield from asyncio.wait([wait_disconnect, wait_subscription, wait_deliver],
                                          return_when=asyncio.FIRST_COMPLETED)
        if wait_disconnect in done:
            connected = False
            wait_subscription.cancel()
            wait_deliver.cancel()
        elif wait_subscription in done:
            subscription = wait_subscription.result()
            return_codes = [self.add_subscription(topic, client_session) for topic in subscription.topics]
            yield from handler.mqtt_acknowledge_subscription(subscription.packet_id, return_codes)
            # Pair each topic with its own return code (the previous index
            # was never incremented, so only the first code was checked).
            for topic, return_code in zip(subscription.topics, return_codes):
                if return_code != 0x80:
                    yield from self.publish_retained_messages_for_subscription(topic, client_session)
            wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
        elif wait_deliver in done:
            publish_packet = wait_deliver.result().packet
            topic_name = publish_packet.variable_header.topic_name
            data = publish_packet.payload.data
            asyncio.Task(self.broadcast_application_message(client_session, topic_name, data, retained=False))
            if publish_packet.retain_flag:
                if data is not None and data != b'':
                    # If retained flag set, store the message for further subscriptions
                    self.logger.debug("Retaining message from packet %s" % repr(publish_packet))
                    retained_message = RetainedApplicationMessage(client_session, topic_name, data)
                    self._global_retained_messages[topic_name] = retained_message
                else:
                    # [MQTT-3.3.1-10]
                    self.logger.debug("Clear retained messages for topic '%s'" % topic_name)
                    # pop(): there may be nothing retained for this topic
                    self._global_retained_messages.pop(topic_name, None)
            wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message())

    self.logger.debug("Client disconnecting")
    try:
        yield from handler.stop()
    except Exception as e:
        self.logger.error(e)
    finally:
        handler.detach_from_session()
        handler = None
        client_session.machine.disconnect()
        writer.close()
        self.logger.debug("Session disconnected")
|
2015-07-10 20:55:22 +00:00
|
|
|
|
|
|
|
@asyncio.coroutine
def check_connect(self, connect: ConnectPacket):
    """Validate a CONNECT packet, raising on the first violation found.

    Declared as a coroutine, so the checks only run once the returned
    object is driven (for instance with ``yield from``).

    :param connect: the CONNECT packet received from the client
    :raises BrokerException: if the client identifier is missing, if the
        will flag is set without a will topic and message, or if the
        reserved flag is set
    """
    payload = connect.payload
    if payload.client_id is None:
        raise BrokerException('[[MQTT-3.1.3-3]] : Client identifier must be present' )

    header = connect.variable_header
    if header.will_flag and (payload.will_topic is None or payload.will_message is None):
        raise BrokerException('will flag set, but will topic/message not present in payload')

    if header.reserved_flag:
        raise BrokerException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
|
|
|
|
|
|
|
|
def authenticate(self, session: Session):
    """Tell whether the client behind *session* may connect.

    :param session: session populated from the client's CONNECT packet
    :return: ``True``; no real check is performed yet, every client is accepted
    """
    # TODO : Handle client authentication here
    accepted = True
    return accepted
|
2015-07-11 20:22:33 +00:00
|
|
|
|
2015-07-12 20:35:56 +00:00
|
|
|
def add_subscription(self, topic, session):
    """Register *session* as a subscriber of the filter described by *topic*.

    :param topic: mapping with a ``'filter'`` entry (the topic filter string)
        and a ``'qos'`` entry (the requested QoS)
    :param session: the subscribing session
    :return: the granted QoS (capped by the ``'max-qos'`` setting when that
        key is configured), or ``0x80`` when the filter is invalid or one of
        the two entries is missing
    """
    try:
        topic_filter = topic['filter']
        qos = topic['qos']
    except KeyError:
        return 0x80

    # Validate wildcards level by level instead of with a regular
    # expression: the former pattern rejected valid filters such as '/a/+'
    # and accepted invalid ones such as 'a+/b' or 'sport#'.
    levels = topic_filter.split('/')
    last_position = len(levels) - 1
    for position, level in enumerate(levels):
        if '#' in level and (level != '#' or position != last_position):
            # [MQTT-4.7.1-2] '#' must stand alone in the last level
            return 0x80
        if '+' in level and level != '+':
            # [MQTT-4.7.1-3] + wildcard character must occupy entire level
            return 0x80

    if 'max-qos' in self.config and qos > self.config['max-qos']:
        qos = self.config['max-qos']

    subscribers = self._topics.setdefault(topic_filter, [])
    for subscriber in subscribers:
        if subscriber['session'] is session:
            # [MQTT-3.8.4-3] subscribing again with the same filter replaces
            # the existing subscription instead of adding a duplicate
            subscriber['qos'] = qos
            break
    else:
        subscribers.append({'session': session, 'qos': qos})
    return qos
|
2015-07-12 20:35:56 +00:00
|
|
|
|
|
|
|
def matches(self, topic, filter):
    """Tell whether the topic name *topic* is selected by the topic *filter*.

    Both strings are compared level by level (levels are separated by '/'):
    '+' selects exactly one level, a '#' level selects everything from that
    level on (including the parent level itself), and any other level must
    be equal to the topic's level. Comparing levels rather than building a
    regular expression from the filter keeps characters such as '.' literal
    and prevents a filter from matching a mere prefix of the topic.

    :param topic: topic name of a published message
    :param filter: topic filter of a subscription
    :return: ``True`` if the filter selects the topic, ``False`` otherwise
    """
    filter_levels = filter.split('/')
    topic_levels = topic.split('/')

    # [MQTT-4.7.2-1] filters starting with a wildcard must not select topic
    # names beginning with '$'
    if topic.startswith('$') and filter_levels[0] in ('#', '+'):
        return False

    for position, level in enumerate(filter_levels):
        if level == '#':
            return True
        if position >= len(topic_levels):
            return False
        if level != '+' and level != topic_levels[position]:
            return False
    # Every filter level matched; the topic must not have extra levels
    return len(filter_levels) == len(topic_levels)
|
|
|
|
|
|
|
|
@asyncio.coroutine
def broadcast_application_message(self, source_session, topic, data, retained):
    """Deliver a message to every subscription whose filter matches *topic*.

    Subscribers whose session state is 'connected' get the message
    published through their handler, at the QoS of their subscription.
    For the other subscribers the message is queued on the session's
    ``retained_messages``. The coroutine returns once all the publications
    it started are done.

    :param source_session: session the message originates from
    :param topic: topic name the message was published on
    :param data: message payload
    :param retained: retain flag passed on to ``mqtt_publish``
    """
    publish_tasks = []
    for k_filter in self._topics:
        if not self.matches(topic, k_filter):
            continue
        for d in self._topics[k_filter]:
            target_session = d['session']
            qos = d['qos']
            if target_session.machine.state == 'connected':
                self.logger.debug("broadcasting application message from %s on topic '%s' to %s" %
                                  (format_client_message(session=source_session),
                                   topic, format_client_message(session=target_session)))
                handler = target_session.handler
                packet_id = handler.session.next_packet_id
                publish_tasks.append(
                    asyncio.Task(handler.mqtt_publish(topic, data, packet_id, False, qos, retained)))
            else:
                self.logger.debug("retaining application message from %s on topic '%s' to client '%s'" %
                                  (format_client_message(session=source_session),
                                   topic, format_client_message(session=target_session)))
                retained_message = RetainedApplicationMessage(source_session, topic, data)
                # retained_messages is presumably an asyncio.Queue (its get()
                # is consumed with ``yield from`` elsewhere in this class):
                # a bare put() call only creates a coroutine and stores
                # nothing, so use the synchronous put_nowait() instead.
                target_session.retained_messages.put_nowait(retained_message)
    if publish_tasks:
        # asyncio.wait() is a coroutine; without ``yield from`` it never runs
        yield from asyncio.wait(publish_tasks)
|
|
|
|
|
|
|
|
@asyncio.coroutine
def publish_session_retained_messages(self, session):
    """Drain the messages queued on *session* and broadcast each of them.

    Every queued message goes through ``broadcast_application_message`` with
    its original source session, topic and payload, and a retain flag of
    ``False``.

    :param session: session whose ``retained_messages`` queue is emptied
    """
    self.logger.debug("Begin broadcasting messages retained for session %s" % format_client_message(session=session))
    while True:
        if session.retained_messages.empty():
            break
        queued = yield from session.retained_messages.get()
        yield from self.broadcast_application_message(
            queued.source_session, queued.topic, queued.data, False)
    self.logger.debug("End broadcasting messages retained for session %s" % format_client_message(session=session))
|
|
|
|
|
|
|
|
@asyncio.coroutine
def publish_retained_messages_for_subscription(self, subscription, session):
    """Send *session* the globally retained messages matching a subscription.

    Each retained message whose topic is selected by the subscription's
    filter is published through the session's handler with the retain flag
    set, at the QoS found in *subscription*. The coroutine returns once all
    those publications are done.

    :param subscription: mapping with the ``'filter'`` and ``'qos'`` of the
        subscription
    :param session: the subscribing session
    """
    self.logger.debug("Begin broadcasting messages retained due to subscription on '%s' from %s" %
                      (subscription['filter'], format_client_message(session=session)))
    publish_tasks = []
    for d_topic in self._global_retained_messages:
        self.logger.debug("matching : %s %s" % (d_topic, subscription['filter']))
        if self.matches(d_topic, subscription['filter']):
            self.logger.debug("%s and %s match" % (d_topic, subscription['filter']))
            retained = self._global_retained_messages[d_topic]
            packet_id = session.next_packet_id
            publish_tasks.append(asyncio.Task(
                session.handler.mqtt_publish(
                    retained.topic, retained.data, packet_id, False, subscription['qos'], True)))
    if publish_tasks:
        # asyncio.wait() is a coroutine; without ``yield from`` it never runs
        yield from asyncio.wait(publish_tasks)
    self.logger.debug("End broadcasting messages retained due to subscription on '%s' from %s" %
                      (subscription['filter'], format_client_message(session=session)))
|