Listeners accept max_connections parameters

Server manage connection count through semaphore
HBMQTT-23
pull/8/head
Nicolas Jouanin 2015-08-06 22:44:37 +02:00
rodzic 635eea30b3
commit a87d989553
3 zmienionych plików z 72 dodań i 31 usunięć

Wyświetl plik

@ -4,7 +4,9 @@
import logging import logging
import ssl import ssl
import websockets import websockets
import asyncio
from functools import partial
from transitions import Machine, MachineError from transitions import Machine, MachineError
from hbmqtt.session import Session from hbmqtt.session import Session
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
@ -48,6 +50,47 @@ class RetainedApplicationMessage:
self.qos = qos self.qos = qos
class Server:
def __init__(self, listener_name, server_instance, max_connections=-1, loop=None):
self.logger = logging.getLogger(__name__)
self.instance = server_instance
self.conn_count = 0
self.listener_name = listener_name
if loop is not None:
self._loop = loop
else:
self._loop = asyncio.get_event_loop()
self.max_connections = max_connections
if self.max_connections > 0:
self.semaphore = asyncio.Semaphore(self.max_connections, loop=self._loop)
else:
self.semaphore = None
@asyncio.coroutine
def acquire_connection(self):
if self.semaphore:
yield from self.semaphore.acquire()
self.conn_count += 1
if self.max_connections > 0:
self.logger.debug("Listener '%s': %d/%d connections acquired" %
(self.listener_name, self.conn_count, self.max_connections))
def release_connection(self):
if self.semaphore:
self.semaphore.release()
self.conn_count -= 1
if self.max_connections > 0:
self.logger.debug("Listener '%s': %d/%d connections acquired" %
(self.listener_name, self.conn_count, self.max_connections))
@asyncio.coroutine
def close_instance(self):
if self.instance:
self.instance.close()
yield self.instance.wait_closed()
class Broker: class Broker:
states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped'] states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped']
@ -92,7 +135,7 @@ class Broker:
else: else:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._servers = [] self._servers = dict()
self._init_states() self._init_states()
self._sessions = dict() self._sessions = dict()
self._subscriptions = dict() self._subscriptions = dict()
@ -135,6 +178,12 @@ class Broker:
listener = self.listeners_config[listener_name] listener = self.listeners_config[listener_name]
self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind'])) self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind']))
#Max connections
try:
max_connections = listener['max_connections']
except KeyError:
max_connections = -1
# SSL Context # SSL Context
sc = None sc = None
if 'ssl' in listener and listener['ssl'].upper() == 'ON': if 'ssl' in listener and listener['ssl'].upper() == 'ON':
@ -150,16 +199,18 @@ class Broker:
if listener['type'] == 'tcp': if listener['type'] == 'tcp':
address, port = listener['bind'].split(':') address, port = listener['bind'].split(':')
server = yield from asyncio.start_server(self.stream_connected, cb_partial = partial(self.stream_connected, listener_name=listener_name)
instance = yield from asyncio.start_server(cb_partial,
address, address,
port, port,
ssl=sc, ssl=sc,
loop=self._loop) loop=self._loop)
self._servers.append(server) self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
elif listener['type'] == 'ws': elif listener['type'] == 'ws':
address, port = listener['bind'].split(':') address, port = listener['bind'].split(':')
server = yield from websockets.serve(self.ws_connected, address, port, ssl=sc, loop=self._loop) cb_partial = partial(self.ws_connected, listener_name=listener_name)
self._servers.append(server) instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
self.machine.starting_success() self.machine.starting_success()
self.logger.debug("Broker started") self.logger.debug("Broker started")
except Exception as e: except Exception as e:
@ -174,27 +225,29 @@ class Broker:
except MachineError as me: except MachineError as me:
self.logger.debug("Invalid method call at this moment: %s" % me) self.logger.debug("Invalid method call at this moment: %s" % me)
raise BrokerException("Broker instance can't be stopped: %s" % me) raise BrokerException("Broker instance can't be stopped: %s" % me)
for server in self._servers: for listener_name in self._servers:
server.close() server = self._servers[listener_name]
yield from server.wait_closed() yield from server.close_instance()
self.logger.debug("Broker closing") self.logger.debug("Broker closing")
self.logger.info("Broker closed") self.logger.info("Broker closed")
self.machine.stopping_success() self.machine.stopping_success()
@asyncio.coroutine @asyncio.coroutine
def ws_connected(self, websocket, uri): def ws_connected(self, websocket, uri, listener_name):
self.logger.debug("ws_connected") yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
yield from self.client_connected(WebSocketsReader(websocket), WebSocketsWriter(websocket))
@asyncio.coroutine @asyncio.coroutine
def stream_connected(self, reader, writer): def stream_connected(self, reader, writer, listener_name):
self.logger.debug("stream_connected") yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
yield from self.client_connected(StreamReaderAdapter(reader), StreamWriterAdapter(writer))
@asyncio.coroutine @asyncio.coroutine
def client_connected(self, reader: ReaderAdapter, writer: WriterAdapter): def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
# Wait for connection available
server = self._servers[listener_name]
yield from server.acquire_connection()
remote_address, remote_port = writer.get_peer_info() remote_address, remote_port = writer.get_peer_info()
self.logger.debug("Connection from %s:%d" % (remote_address, remote_port)) self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
# Wait for first packet and expect a CONNECT # Wait for first packet and expect a CONNECT
connect = None connect = None
@ -380,6 +433,7 @@ class Broker:
client_session.machine.disconnect() client_session.machine.disconnect()
yield from writer.close() yield from writer.close()
self.logger.debug("%s Session disconnected" % client_session.client_id) self.logger.debug("%s Session disconnected" % client_session.client_id)
server.release_connection()
@asyncio.coroutine @asyncio.coroutine
def check_connect(self, connect: ConnectPacket): def check_connect(self, connect: ConnectPacket):

Wyświetl plik

@ -11,25 +11,12 @@ config = {
}, },
'tcp-mqtt': { 'tcp-mqtt': {
'bind': '0.0.0.0:1883', 'bind': '0.0.0.0:1883',
'max_connections': 10
}, },
'ws-mqtt': { 'ws-mqtt': {
'bind': '127.0.0.1:8080', 'bind': '127.0.0.1:8080',
'type': 'ws' 'type': 'ws'
}, },
'wss-mqtt': {
'bind': '127.0.0.1:8081',
'type': 'ws',
'ssl': 'on',
'certfile': 'localhost.server.crt',
'keyfile': 'server.key',
},
'tcp-ssl': {
'bind': '127.0.0.1:8883',
'ssl': 'on',
'certfile': 'localhost.server.crt',
'keyfile': 'server.key',
'type': 'tcp'
}
} }
} }

Wyświetl plik

@ -25,7 +25,7 @@ C = MQTTClient()
@asyncio.coroutine @asyncio.coroutine
def test_coro(): def test_coro():
yield from C.connect('mqtt://test.mosquitto.org:1883/') yield from C.connect('mqtt://localhost:1883/')
tasks = [ tasks = [
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')), asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)), asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),