amqtt/hbmqtt/broker.py

247 lines
11 KiB
Python

# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
from transitions import Machine, MachineError
from hbmqtt.session import Session
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler, Subscription
from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.connack import ConnackPacket, ReturnCode
from hbmqtt.errors import HBMQTTException
from hbmqtt.utils import format_client_message, gen_client_id
_defaults = {
'bind-address': 'localhost',
'bind-port': 1883,
'timeout-disconnect-delay': 1
}
class BrokerException(Exception):
    """Raised when a broker operation can't be performed in the current state.

    Derives from ``Exception`` rather than ``BaseException``: user-defined
    errors should not bypass ordinary ``except Exception`` handlers, which are
    reserved-for-system exits (KeyboardInterrupt, SystemExit) otherwise.
    """
class Broker:
    """MQTT broker accepting client connections over TCP (MQTT 3.1.1 only).

    The broker lifecycle (new -> starting -> started -> stopping -> stopped)
    is tracked by a ``transitions`` state machine; calling a lifecycle
    coroutine while in the wrong state raises :class:`BrokerException`.
    """

    states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped']

    def __init__(self, config=None, loop=None):
        """Create a broker.

        :param config: optional dict whose entries override the module defaults
        :param loop: event loop to serve on; ``asyncio.get_event_loop()`` if None
        """
        self.logger = logging.getLogger(__name__)
        # Work on a copy: updating the module-level defaults in place would
        # leak this broker's config into every broker created afterwards.
        self.config = dict(_defaults)
        if config is not None:
            self.config.update(config)
        self._loop = loop if loop is not None else asyncio.get_event_loop()
        self._server = None
        self._handlers = []
        self._init_states()
        self._sessions = dict()  # client_id -> Session
        self._topics = dict()    # topic filter -> [{'handler': ..., 'qos': ...}]

    def _init_states(self):
        """Build the lifecycle state machine; triggers are called on ``self.machine``."""
        self.machine = Machine(states=Broker.states, initial='new')
        self.machine.add_transition(trigger='start', source='new', dest='starting')
        self.machine.add_transition(trigger='starting_fail', source='starting', dest='not_started')
        self.machine.add_transition(trigger='starting_success', source='starting', dest='started')
        self.machine.add_transition(trigger='shutdown', source='started', dest='stopping')
        self.machine.add_transition(trigger='stopping_success', source='stopping', dest='stopped')
        self.machine.add_transition(trigger='stopping_failure', source='stopping', dest='not_stopped')
        self.machine.add_transition(trigger='start', source='stopped', dest='starting')

    @asyncio.coroutine
    def start(self):
        """Open the listening server and begin accepting clients.

        :raises BrokerException: if the broker is not in a startable state
            ('new' or 'stopped') or the listening server can't be created.
        """
        try:
            self.machine.start()
            self.logger.debug("Broker starting")
        except MachineError as me:
            self.logger.debug("Invalid method call at this moment: %s", me)
            raise BrokerException("Broker instance can't be started: %s" % me) from me
        try:
            self._server = yield from asyncio.start_server(self.client_connected,
                                                           self.config['bind-address'],
                                                           self.config['bind-port'],
                                                           loop=self._loop)
            self.logger.info("Broker listening on %s:%d",
                             self.config['bind-address'], self.config['bind-port'])
            self.machine.starting_success()
        except Exception as e:
            self.logger.error("Broker startup failed: %s", e)
            self.machine.starting_fail()
            raise BrokerException("Broker instance can't be started: %s" % e) from e

    @asyncio.coroutine
    def shutdown(self):
        """Close the listening server and wait until it is closed.

        :raises BrokerException: if the broker is not started, or if closing
            the server fails (the state machine then moves to 'not_stopped'
            instead of staying stuck in 'stopping').
        """
        try:
            self.machine.shutdown()
        except MachineError as me:
            self.logger.debug("Invalid method call at this moment: %s", me)
            raise BrokerException("Broker instance can't be stopped: %s" % me) from me
        try:
            self._server.close()
            self.logger.debug("Broker closing")
            yield from self._server.wait_closed()
        except Exception as e:
            self.logger.error("Broker shutdown failed: %s", e)
            self.machine.stopping_failure()
            raise BrokerException("Broker instance can't be stopped: %s" % e) from e
        self.logger.info("Broker closed")
        self.machine.stopping_success()

    @asyncio.coroutine
    def client_connected(self, reader, writer):
        """Serve one client connection (``asyncio.start_server`` callback).

        Reads the first packet, which must be CONNECT, validates it and
        answers with a CONNACK. On acceptance it runs a
        :class:`BrokerProtocolHandler` for the session until the client
        disconnects, registering and acknowledging SUBSCRIBE requests.

        :param reader: stream reader of the client connection
        :param writer: stream writer of the client connection
        """
        extra_info = writer.get_extra_info('peername')
        self.logger.info(repr(extra_info))
        remote_address = extra_info[0]
        remote_port = extra_info[1]
        self.logger.debug("Connection from %s:%d", remote_address, remote_port)

        # Wait for first packet and expect a CONNECT
        try:
            connect = yield from ConnectPacket.from_stream(reader)
            # check_connect is a coroutine: without ``yield from`` it would only
            # create a generator object and none of its checks would run.
            yield from self.check_connect(connect)
        except HBMQTTException as exc:
            self.logger.warning("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s",
                                format_client_message(address=remote_address, port=remote_port), exc)
            writer.close()
            return
        except BrokerException as be:
            self.logger.error('Invalid connection from %s : %s',
                              format_client_message(address=remote_address, port=remote_port), be)
            writer.close()
            return

        # Reject malformed/unsupported CONNECT requests; error CONNACKs always
        # carry session-present = 0 ([MQTT-3.2.2-4]).
        connack = None
        if connect.variable_header.proto_level != 4:
            # only MQTT 3.1.1 supported
            self.logger.error('Invalid protocol from %s: %d',
                              format_client_message(address=remote_address, port=remote_port),
                              connect.variable_header.proto_level)
            connack = ConnackPacket.build(0, ReturnCode.UNACCEPTABLE_PROTOCOL_VERSION)
        elif connect.variable_header.username_flag and connect.payload.username is None:
            self.logger.error('Invalid username from %s',
                              format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, ReturnCode.BAD_USERNAME_PASSWORD)
        elif connect.variable_header.password_flag and connect.payload.password is None:
            self.logger.error('Invalid password %s',
                              format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, ReturnCode.BAD_USERNAME_PASSWORD)
        elif not connect.variable_header.clean_session_flag and not connect.payload.client_id:
            # A missing or zero-length client ID is only allowed with a clean session.
            self.logger.error('[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)',
                              format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, ReturnCode.IDENTIFIER_REJECTED)
        if connack is not None:
            self.logger.debug(" -out-> %r", connack)
            yield from connack.to_stream(writer)
            writer.close()
            return

        # Pick the session: a fresh one for clean sessions, otherwise resume the
        # stored persistent session for this client ID if there is one.
        client_id = connect.payload.client_id
        if connect.variable_header.clean_session_flag:
            new_session = Session()
            new_session.parent = 0
        elif client_id in self._sessions:
            new_session = self._sessions[client_id]
            new_session.parent = 1
        else:
            new_session = Session()
            new_session.parent = 0
        if not client_id:
            # Assign an ID to clients that sent none (clean sessions only, see above).
            client_id = gen_client_id()
        new_session.client_id = client_id
        new_session.remote_address = remote_address
        new_session.remote_port = remote_port
        new_session.clean_session = connect.variable_header.clean_session_flag
        new_session.will_flag = connect.variable_header.will_flag
        new_session.will_retain = connect.variable_header.will_retain_flag
        new_session.will_qos = connect.variable_header.will_qos
        new_session.will_topic = connect.payload.will_topic
        new_session.will_message = connect.payload.will_message
        new_session.username = connect.payload.username
        new_session.password = connect.payload.password
        new_session.keep_alive = connect.variable_header.keep_alive + self.config['timeout-disconnect-delay']
        new_session.reader = reader
        new_session.writer = writer

        if not self.authenticate(new_session):
            # Non-zero return code: session-present must be 0 ([MQTT-3.2.2-4]).
            connack = ConnackPacket.build(0, ReturnCode.NOT_AUTHORIZED)
            self.logger.info('%s : connection refused', format_client_message(session=new_session))
            self.logger.debug(" -out-> %r", connack)
            yield from connack.to_stream(writer)
            writer.close()
            return

        # Store (or replace) the session only once the client is authenticated,
        # so persistent sessions can be resumed on a later reconnect.
        self._sessions[client_id] = new_session
        connack = ConnackPacket.build(new_session.parent, ReturnCode.CONNECTION_ACCEPTED)
        self.logger.info('%s : connection accepted', format_client_message(session=new_session))
        self.logger.debug(" -out-> %r", connack)
        yield from connack.to_stream(writer)

        new_session.machine.connect()
        handler = BrokerProtocolHandler(new_session, self._loop)
        self._handlers.append(handler)
        self.logger.debug("Start messages handling")
        yield from handler.start()
        self.logger.debug("Wait for disconnect")
        wait_disconnect = asyncio.Task(handler.wait_disconnect())
        wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
        while True:
            done, pending = yield from asyncio.wait([wait_disconnect, wait_subscription],
                                                    return_when=asyncio.FIRST_COMPLETED)
            if wait_disconnect in done:
                wait_subscription.cancel()
                break
            # Otherwise a SUBSCRIBE is pending: register each requested filter
            # and acknowledge with the per-filter return codes.
            subscription = wait_subscription.result()
            return_codes = [self.add_subscription(topic, handler) for topic in subscription.topics]
            yield from handler.mqtt_acknowledge_subscription(subscription.packet_id, return_codes)
            wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
        self.logger.debug("Client disconnected")
        yield from handler.stop()
        # Drop the finished handler so the list doesn't grow with every connection.
        # NOTE(review): entries for this handler in self._topics are kept; confirm
        # whether clean-session subscriptions should be purged here.
        self._handlers.remove(handler)
        new_session.machine.disconnect()

    @asyncio.coroutine
    def check_connect(self, connect: ConnectPacket):
        """Validate mandatory CONNECT fields; callers must use ``yield from``.

        :param connect: CONNECT packet received from the client
        :raises BrokerException: if the client identifier is absent, the will
            flag is set without both will topic and will message, or the
            reserved flag is set
        """
        if connect.payload.client_id is None:
            raise BrokerException('[[MQTT-3.1.3-3]] : Client identifier must be present')
        if connect.variable_header.will_flag:
            if connect.payload.will_topic is None or connect.payload.will_message is None:
                raise BrokerException('will flag set, but will topic/message not present in payload')
        if connect.variable_header.reserved_flag:
            raise BrokerException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')

    def authenticate(self, session: Session):
        """Decide whether *session* may connect; currently accepts every client."""
        # TODO : Handle client authentication here
        return True

    def add_subscription(self, topic, handler):
        """Subscribe *handler* to one requested topic filter.

        :param topic: mapping with 'filter' and 'qos' keys
        :param handler: protocol handler that owns the subscription
        :return: the granted QoS (capped by config 'max-qos' when present), or
            0x80 (SUBACK failure) if *topic* lacks 'filter' or 'qos'
        """
        try:
            topic_filter = topic['filter']
            qos = topic['qos']
        except KeyError:
            return 0x80
        if 'max-qos' in self.config and qos > self.config['max-qos']:
            qos = self.config['max-qos']
        subscriptions = self._topics.setdefault(topic_filter, [])
        # A repeated SUBSCRIBE for an identical filter replaces the existing
        # subscription rather than adding a duplicate ([MQTT-3.8.4-3]).
        for subscription in subscriptions:
            if subscription['handler'] is handler:
                subscription['qos'] = qos
                break
        else:
            subscriptions.append({'handler': handler, 'qos': qos})
        return qos