Listeners accept max_connections parameters

Server manage connection count through semaphore
HBMQTT-23
pull/8/head
Nicolas Jouanin 2015-08-06 22:44:37 +02:00
rodzic 635eea30b3
commit a87d989553
3 zmienionych plików z 72 dodań i 31 usunięć

Wyświetl plik

@ -4,7 +4,9 @@
import logging
import ssl
import websockets
import asyncio
from functools import partial
from transitions import Machine, MachineError
from hbmqtt.session import Session
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
@ -48,6 +50,47 @@ class RetainedApplicationMessage:
self.qos = qos
class Server:
def __init__(self, listener_name, server_instance, max_connections=-1, loop=None):
self.logger = logging.getLogger(__name__)
self.instance = server_instance
self.conn_count = 0
self.listener_name = listener_name
if loop is not None:
self._loop = loop
else:
self._loop = asyncio.get_event_loop()
self.max_connections = max_connections
if self.max_connections > 0:
self.semaphore = asyncio.Semaphore(self.max_connections, loop=self._loop)
else:
self.semaphore = None
@asyncio.coroutine
def acquire_connection(self):
if self.semaphore:
yield from self.semaphore.acquire()
self.conn_count += 1
if self.max_connections > 0:
self.logger.debug("Listener '%s': %d/%d connections acquired" %
(self.listener_name, self.conn_count, self.max_connections))
def release_connection(self):
if self.semaphore:
self.semaphore.release()
self.conn_count -= 1
if self.max_connections > 0:
self.logger.debug("Listener '%s': %d/%d connections acquired" %
(self.listener_name, self.conn_count, self.max_connections))
@asyncio.coroutine
def close_instance(self):
if self.instance:
self.instance.close()
yield self.instance.wait_closed()
class Broker:
states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped']
@ -92,7 +135,7 @@ class Broker:
else:
self._loop = asyncio.get_event_loop()
self._servers = []
self._servers = dict()
self._init_states()
self._sessions = dict()
self._subscriptions = dict()
@ -135,6 +178,12 @@ class Broker:
listener = self.listeners_config[listener_name]
self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind']))
#Max connections
try:
max_connections = listener['max_connections']
except KeyError:
max_connections = -1
# SSL Context
sc = None
if 'ssl' in listener and listener['ssl'].upper() == 'ON':
@ -150,16 +199,18 @@ class Broker:
if listener['type'] == 'tcp':
address, port = listener['bind'].split(':')
server = yield from asyncio.start_server(self.stream_connected,
cb_partial = partial(self.stream_connected, listener_name=listener_name)
instance = yield from asyncio.start_server(cb_partial,
address,
port,
ssl=sc,
loop=self._loop)
self._servers.append(server)
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
elif listener['type'] == 'ws':
address, port = listener['bind'].split(':')
server = yield from websockets.serve(self.ws_connected, address, port, ssl=sc, loop=self._loop)
self._servers.append(server)
cb_partial = partial(self.ws_connected, listener_name=listener_name)
instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
self.machine.starting_success()
self.logger.debug("Broker started")
except Exception as e:
@ -174,27 +225,29 @@ class Broker:
except MachineError as me:
self.logger.debug("Invalid method call at this moment: %s" % me)
raise BrokerException("Broker instance can't be stopped: %s" % me)
for server in self._servers:
server.close()
yield from server.wait_closed()
for listener_name in self._servers:
server = self._servers[listener_name]
yield from server.close_instance()
self.logger.debug("Broker closing")
self.logger.info("Broker closed")
self.machine.stopping_success()
@asyncio.coroutine
def ws_connected(self, websocket, uri):
self.logger.debug("ws_connected")
yield from self.client_connected(WebSocketsReader(websocket), WebSocketsWriter(websocket))
def ws_connected(self, websocket, uri, listener_name):
yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
@asyncio.coroutine
def stream_connected(self, reader, writer):
self.logger.debug("stream_connected")
yield from self.client_connected(StreamReaderAdapter(reader), StreamWriterAdapter(writer))
def stream_connected(self, reader, writer, listener_name):
yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
@asyncio.coroutine
def client_connected(self, reader: ReaderAdapter, writer: WriterAdapter):
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
# Wait for connection available
server = self._servers[listener_name]
yield from server.acquire_connection()
remote_address, remote_port = writer.get_peer_info()
self.logger.debug("Connection from %s:%d" % (remote_address, remote_port))
self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
# Wait for first packet and expect a CONNECT
connect = None
@ -380,6 +433,7 @@ class Broker:
client_session.machine.disconnect()
yield from writer.close()
self.logger.debug("%s Session disconnected" % client_session.client_id)
server.release_connection()
@asyncio.coroutine
def check_connect(self, connect: ConnectPacket):

Wyświetl plik

@ -11,25 +11,12 @@ config = {
},
'tcp-mqtt': {
'bind': '0.0.0.0:1883',
'max_connections': 10
},
'ws-mqtt': {
'bind': '127.0.0.1:8080',
'type': 'ws'
},
'wss-mqtt': {
'bind': '127.0.0.1:8081',
'type': 'ws',
'ssl': 'on',
'certfile': 'localhost.server.crt',
'keyfile': 'server.key',
},
'tcp-ssl': {
'bind': '127.0.0.1:8883',
'ssl': 'on',
'certfile': 'localhost.server.crt',
'keyfile': 'server.key',
'type': 'tcp'
}
}
}

Wyświetl plik

@ -25,7 +25,7 @@ C = MQTTClient()
@asyncio.coroutine
def test_coro():
yield from C.connect('mqtt://test.mosquitto.org:1883/')
yield from C.connect('mqtt://localhost:1883/')
tasks = [
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),