kopia lustrzana https://github.com/Yakifo/amqtt
Listeners accept max_connections parameters
Server manage connection count through semaphore HBMQTT-23pull/8/head
rodzic
635eea30b3
commit
a87d989553
|
@ -4,7 +4,9 @@
|
|||
import logging
|
||||
import ssl
|
||||
import websockets
|
||||
import asyncio
|
||||
|
||||
from functools import partial
|
||||
from transitions import Machine, MachineError
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||
|
@ -48,6 +50,47 @@ class RetainedApplicationMessage:
|
|||
self.qos = qos
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, listener_name, server_instance, max_connections=-1, loop=None):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.instance = server_instance
|
||||
self.conn_count = 0
|
||||
self.listener_name = listener_name
|
||||
if loop is not None:
|
||||
self._loop = loop
|
||||
else:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self.max_connections = max_connections
|
||||
if self.max_connections > 0:
|
||||
self.semaphore = asyncio.Semaphore(self.max_connections, loop=self._loop)
|
||||
else:
|
||||
self.semaphore = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def acquire_connection(self):
|
||||
if self.semaphore:
|
||||
yield from self.semaphore.acquire()
|
||||
self.conn_count += 1
|
||||
if self.max_connections > 0:
|
||||
self.logger.debug("Listener '%s': %d/%d connections acquired" %
|
||||
(self.listener_name, self.conn_count, self.max_connections))
|
||||
|
||||
def release_connection(self):
|
||||
if self.semaphore:
|
||||
self.semaphore.release()
|
||||
self.conn_count -= 1
|
||||
if self.max_connections > 0:
|
||||
self.logger.debug("Listener '%s': %d/%d connections acquired" %
|
||||
(self.listener_name, self.conn_count, self.max_connections))
|
||||
|
||||
@asyncio.coroutine
|
||||
def close_instance(self):
|
||||
if self.instance:
|
||||
self.instance.close()
|
||||
yield self.instance.wait_closed()
|
||||
|
||||
|
||||
class Broker:
|
||||
states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped']
|
||||
|
||||
|
@ -92,7 +135,7 @@ class Broker:
|
|||
else:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self._servers = []
|
||||
self._servers = dict()
|
||||
self._init_states()
|
||||
self._sessions = dict()
|
||||
self._subscriptions = dict()
|
||||
|
@ -135,6 +178,12 @@ class Broker:
|
|||
listener = self.listeners_config[listener_name]
|
||||
self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind']))
|
||||
|
||||
#Max connections
|
||||
try:
|
||||
max_connections = listener['max_connections']
|
||||
except KeyError:
|
||||
max_connections = -1
|
||||
|
||||
# SSL Context
|
||||
sc = None
|
||||
if 'ssl' in listener and listener['ssl'].upper() == 'ON':
|
||||
|
@ -150,16 +199,18 @@ class Broker:
|
|||
|
||||
if listener['type'] == 'tcp':
|
||||
address, port = listener['bind'].split(':')
|
||||
server = yield from asyncio.start_server(self.stream_connected,
|
||||
cb_partial = partial(self.stream_connected, listener_name=listener_name)
|
||||
instance = yield from asyncio.start_server(cb_partial,
|
||||
address,
|
||||
port,
|
||||
ssl=sc,
|
||||
loop=self._loop)
|
||||
self._servers.append(server)
|
||||
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
|
||||
elif listener['type'] == 'ws':
|
||||
address, port = listener['bind'].split(':')
|
||||
server = yield from websockets.serve(self.ws_connected, address, port, ssl=sc, loop=self._loop)
|
||||
self._servers.append(server)
|
||||
cb_partial = partial(self.ws_connected, listener_name=listener_name)
|
||||
instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
|
||||
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
|
||||
self.machine.starting_success()
|
||||
self.logger.debug("Broker started")
|
||||
except Exception as e:
|
||||
|
@ -174,27 +225,29 @@ class Broker:
|
|||
except MachineError as me:
|
||||
self.logger.debug("Invalid method call at this moment: %s" % me)
|
||||
raise BrokerException("Broker instance can't be stopped: %s" % me)
|
||||
for server in self._servers:
|
||||
server.close()
|
||||
yield from server.wait_closed()
|
||||
for listener_name in self._servers:
|
||||
server = self._servers[listener_name]
|
||||
yield from server.close_instance()
|
||||
self.logger.debug("Broker closing")
|
||||
self.logger.info("Broker closed")
|
||||
self.machine.stopping_success()
|
||||
|
||||
@asyncio.coroutine
|
||||
def ws_connected(self, websocket, uri):
|
||||
self.logger.debug("ws_connected")
|
||||
yield from self.client_connected(WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
||||
def ws_connected(self, websocket, uri, listener_name):
|
||||
yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
||||
|
||||
@asyncio.coroutine
|
||||
def stream_connected(self, reader, writer):
|
||||
self.logger.debug("stream_connected")
|
||||
yield from self.client_connected(StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||
def stream_connected(self, reader, writer, listener_name):
|
||||
yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||
|
||||
@asyncio.coroutine
|
||||
def client_connected(self, reader: ReaderAdapter, writer: WriterAdapter):
|
||||
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
|
||||
# Wait for connection available
|
||||
server = self._servers[listener_name]
|
||||
yield from server.acquire_connection()
|
||||
|
||||
remote_address, remote_port = writer.get_peer_info()
|
||||
self.logger.debug("Connection from %s:%d" % (remote_address, remote_port))
|
||||
self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
|
||||
|
||||
# Wait for first packet and expect a CONNECT
|
||||
connect = None
|
||||
|
@ -380,6 +433,7 @@ class Broker:
|
|||
client_session.machine.disconnect()
|
||||
yield from writer.close()
|
||||
self.logger.debug("%s Session disconnected" % client_session.client_id)
|
||||
server.release_connection()
|
||||
|
||||
@asyncio.coroutine
|
||||
def check_connect(self, connect: ConnectPacket):
|
||||
|
|
|
@ -11,25 +11,12 @@ config = {
|
|||
},
|
||||
'tcp-mqtt': {
|
||||
'bind': '0.0.0.0:1883',
|
||||
'max_connections': 10
|
||||
},
|
||||
'ws-mqtt': {
|
||||
'bind': '127.0.0.1:8080',
|
||||
'type': 'ws'
|
||||
},
|
||||
'wss-mqtt': {
|
||||
'bind': '127.0.0.1:8081',
|
||||
'type': 'ws',
|
||||
'ssl': 'on',
|
||||
'certfile': 'localhost.server.crt',
|
||||
'keyfile': 'server.key',
|
||||
},
|
||||
'tcp-ssl': {
|
||||
'bind': '127.0.0.1:8883',
|
||||
'ssl': 'on',
|
||||
'certfile': 'localhost.server.crt',
|
||||
'keyfile': 'server.key',
|
||||
'type': 'tcp'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ C = MQTTClient()
|
|||
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect('mqtt://test.mosquitto.org:1883/')
|
||||
yield from C.connect('mqtt://localhost:1883/')
|
||||
tasks = [
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||
|
|
Ładowanie…
Reference in New Issue