diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index 707ac76..afb5643 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -4,7 +4,9 @@ import logging import ssl import websockets +import asyncio +from functools import partial from transitions import Machine, MachineError from hbmqtt.session import Session from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler @@ -48,6 +50,47 @@ class RetainedApplicationMessage: self.qos = qos +class Server: + def __init__(self, listener_name, server_instance, max_connections=-1, loop=None): + self.logger = logging.getLogger(__name__) + self.instance = server_instance + self.conn_count = 0 + self.listener_name = listener_name + if loop is not None: + self._loop = loop + else: + self._loop = asyncio.get_event_loop() + + self.max_connections = max_connections + if self.max_connections > 0: + self.semaphore = asyncio.Semaphore(self.max_connections, loop=self._loop) + else: + self.semaphore = None + + @asyncio.coroutine + def acquire_connection(self): + if self.semaphore: + yield from self.semaphore.acquire() + self.conn_count += 1 + if self.max_connections > 0: + self.logger.debug("Listener '%s': %d/%d connections acquired" % + (self.listener_name, self.conn_count, self.max_connections)) + + def release_connection(self): + if self.semaphore: + self.semaphore.release() + self.conn_count -= 1 + if self.max_connections > 0: + self.logger.debug("Listener '%s': %d/%d connections acquired" % + (self.listener_name, self.conn_count, self.max_connections)) + + @asyncio.coroutine + def close_instance(self): + if self.instance: + self.instance.close() + yield self.instance.wait_closed() + + class Broker: states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped'] @@ -92,7 +135,7 @@ class Broker: else: self._loop = asyncio.get_event_loop() - self._servers = [] + self._servers = dict() self._init_states() self._sessions = dict() self._subscriptions = dict() @@ -135,6 +178,12 @@ class Broker: listener = self.listeners_config[listener_name] self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind'])) + #Max connections + try: + max_connections = listener['max_connections'] + except KeyError: + max_connections = -1 + # SSL Context sc = None if 'ssl' in listener and listener['ssl'].upper() == 'ON': @@ -150,16 +199,18 @@ class Broker: if listener['type'] == 'tcp': address, port = listener['bind'].split(':') - server = yield from asyncio.start_server(self.stream_connected, + cb_partial = partial(self.stream_connected, listener_name=listener_name) + instance = yield from asyncio.start_server(cb_partial, address, port, ssl=sc, loop=self._loop) - self._servers.append(server) + self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop) elif listener['type'] == 'ws': address, port = listener['bind'].split(':') - server = yield from websockets.serve(self.ws_connected, address, port, ssl=sc, loop=self._loop) - self._servers.append(server) + cb_partial = partial(self.ws_connected, listener_name=listener_name) + instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop) + self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop) self.machine.starting_success() self.logger.debug("Broker started") except Exception as e: @@ -174,27 +225,29 @@ class Broker: except MachineError as me: self.logger.debug("Invalid method call at this moment: %s" % me) raise BrokerException("Broker instance can't be stopped: %s" % me) - for server in self._servers: - server.close() - yield from server.wait_closed() + for listener_name in self._servers: + server = self._servers[listener_name] + yield from server.close_instance() self.logger.debug("Broker closing") self.logger.info("Broker closed") self.machine.stopping_success() @asyncio.coroutine - def ws_connected(self, websocket, uri): - self.logger.debug("ws_connected") - yield from self.client_connected(WebSocketsReader(websocket), WebSocketsWriter(websocket)) + def ws_connected(self, websocket, uri, listener_name): + yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket)) @asyncio.coroutine - def stream_connected(self, reader, writer): - self.logger.debug("stream_connected") - yield from self.client_connected(StreamReaderAdapter(reader), StreamWriterAdapter(writer)) + def stream_connected(self, reader, writer, listener_name): + yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) @asyncio.coroutine - def client_connected(self, reader: ReaderAdapter, writer: WriterAdapter): + def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter): + # Wait for connection available + server = self._servers[listener_name] + yield from server.acquire_connection() + remote_address, remote_port = writer.get_peer_info() - self.logger.debug("Connection from %s:%d" % (remote_address, remote_port)) + self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name)) # Wait for first packet and expect a CONNECT connect = None @@ -380,6 +433,7 @@ class Broker: client_session.machine.disconnect() yield from writer.close() self.logger.debug("%s Session disconnected" % client_session.client_id) + server.release_connection() @asyncio.coroutine def check_connect(self, connect: ConnectPacket): diff --git a/samples/broker_start.py b/samples/broker_start.py index eead04d..1b3151e 100644 --- a/samples/broker_start.py +++ b/samples/broker_start.py @@ -11,25 +11,12 @@ config = { }, 'tcp-mqtt': { 'bind': '0.0.0.0:1883', + 'max_connections': 10 }, 'ws-mqtt': { 'bind': '127.0.0.1:8080', 'type': 'ws' }, - 'wss-mqtt': { - 'bind': '127.0.0.1:8081', - 'type': 'ws', - 'ssl': 'on', - 'certfile': 'localhost.server.crt', - 'keyfile': 'server.key', - }, - 'tcp-ssl': { - 'bind': '127.0.0.1:8883', - 'ssl': 'on', - 'certfile': 'localhost.server.crt', - 'keyfile': 'server.key', - 'type': 'tcp' - } } } diff --git a/samples/client_publish.py b/samples/client_publish.py index 8873aa6..53bcab5 100644 --- a/samples/client_publish.py +++ b/samples/client_publish.py @@ -25,7 +25,7 @@ C = MQTTClient() @asyncio.coroutine def test_coro(): - yield from C.connect('mqtt://test.mosquitto.org:1883/') + yield from C.connect('mqtt://localhost:1883/') tasks = [ asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')), asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),