kopia lustrzana https://github.com/Yakifo/amqtt
Listeners accept max_connections parameters
Server manage connection count through semaphore HBMQTT-23pull/8/head
rodzic
635eea30b3
commit
a87d989553
|
@ -4,7 +4,9 @@
|
||||||
import logging
|
import logging
|
||||||
import ssl
|
import ssl
|
||||||
import websockets
|
import websockets
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from transitions import Machine, MachineError
|
from transitions import Machine, MachineError
|
||||||
from hbmqtt.session import Session
|
from hbmqtt.session import Session
|
||||||
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||||
|
@ -48,6 +50,47 @@ class RetainedApplicationMessage:
|
||||||
self.qos = qos
|
self.qos = qos
|
||||||
|
|
||||||
|
|
||||||
|
class Server:
|
||||||
|
def __init__(self, listener_name, server_instance, max_connections=-1, loop=None):
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self.instance = server_instance
|
||||||
|
self.conn_count = 0
|
||||||
|
self.listener_name = listener_name
|
||||||
|
if loop is not None:
|
||||||
|
self._loop = loop
|
||||||
|
else:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
self.max_connections = max_connections
|
||||||
|
if self.max_connections > 0:
|
||||||
|
self.semaphore = asyncio.Semaphore(self.max_connections, loop=self._loop)
|
||||||
|
else:
|
||||||
|
self.semaphore = None
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def acquire_connection(self):
|
||||||
|
if self.semaphore:
|
||||||
|
yield from self.semaphore.acquire()
|
||||||
|
self.conn_count += 1
|
||||||
|
if self.max_connections > 0:
|
||||||
|
self.logger.debug("Listener '%s': %d/%d connections acquired" %
|
||||||
|
(self.listener_name, self.conn_count, self.max_connections))
|
||||||
|
|
||||||
|
def release_connection(self):
|
||||||
|
if self.semaphore:
|
||||||
|
self.semaphore.release()
|
||||||
|
self.conn_count -= 1
|
||||||
|
if self.max_connections > 0:
|
||||||
|
self.logger.debug("Listener '%s': %d/%d connections acquired" %
|
||||||
|
(self.listener_name, self.conn_count, self.max_connections))
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def close_instance(self):
|
||||||
|
if self.instance:
|
||||||
|
self.instance.close()
|
||||||
|
yield self.instance.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
class Broker:
|
class Broker:
|
||||||
states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped']
|
states = ['new', 'starting', 'started', 'not_started', 'stopping', 'stopped', 'not_stopped', 'stopped']
|
||||||
|
|
||||||
|
@ -92,7 +135,7 @@ class Broker:
|
||||||
else:
|
else:
|
||||||
self._loop = asyncio.get_event_loop()
|
self._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
self._servers = []
|
self._servers = dict()
|
||||||
self._init_states()
|
self._init_states()
|
||||||
self._sessions = dict()
|
self._sessions = dict()
|
||||||
self._subscriptions = dict()
|
self._subscriptions = dict()
|
||||||
|
@ -135,6 +178,12 @@ class Broker:
|
||||||
listener = self.listeners_config[listener_name]
|
listener = self.listeners_config[listener_name]
|
||||||
self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind']))
|
self.logger.info("Binding listener '%s' to %s" % (listener_name, listener['bind']))
|
||||||
|
|
||||||
|
#Max connections
|
||||||
|
try:
|
||||||
|
max_connections = listener['max_connections']
|
||||||
|
except KeyError:
|
||||||
|
max_connections = -1
|
||||||
|
|
||||||
# SSL Context
|
# SSL Context
|
||||||
sc = None
|
sc = None
|
||||||
if 'ssl' in listener and listener['ssl'].upper() == 'ON':
|
if 'ssl' in listener and listener['ssl'].upper() == 'ON':
|
||||||
|
@ -150,16 +199,18 @@ class Broker:
|
||||||
|
|
||||||
if listener['type'] == 'tcp':
|
if listener['type'] == 'tcp':
|
||||||
address, port = listener['bind'].split(':')
|
address, port = listener['bind'].split(':')
|
||||||
server = yield from asyncio.start_server(self.stream_connected,
|
cb_partial = partial(self.stream_connected, listener_name=listener_name)
|
||||||
|
instance = yield from asyncio.start_server(cb_partial,
|
||||||
address,
|
address,
|
||||||
port,
|
port,
|
||||||
ssl=sc,
|
ssl=sc,
|
||||||
loop=self._loop)
|
loop=self._loop)
|
||||||
self._servers.append(server)
|
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
|
||||||
elif listener['type'] == 'ws':
|
elif listener['type'] == 'ws':
|
||||||
address, port = listener['bind'].split(':')
|
address, port = listener['bind'].split(':')
|
||||||
server = yield from websockets.serve(self.ws_connected, address, port, ssl=sc, loop=self._loop)
|
cb_partial = partial(self.ws_connected, listener_name=listener_name)
|
||||||
self._servers.append(server)
|
instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop)
|
||||||
|
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
|
||||||
self.machine.starting_success()
|
self.machine.starting_success()
|
||||||
self.logger.debug("Broker started")
|
self.logger.debug("Broker started")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -174,27 +225,29 @@ class Broker:
|
||||||
except MachineError as me:
|
except MachineError as me:
|
||||||
self.logger.debug("Invalid method call at this moment: %s" % me)
|
self.logger.debug("Invalid method call at this moment: %s" % me)
|
||||||
raise BrokerException("Broker instance can't be stopped: %s" % me)
|
raise BrokerException("Broker instance can't be stopped: %s" % me)
|
||||||
for server in self._servers:
|
for listener_name in self._servers:
|
||||||
server.close()
|
server = self._servers[listener_name]
|
||||||
yield from server.wait_closed()
|
yield from server.close_instance()
|
||||||
self.logger.debug("Broker closing")
|
self.logger.debug("Broker closing")
|
||||||
self.logger.info("Broker closed")
|
self.logger.info("Broker closed")
|
||||||
self.machine.stopping_success()
|
self.machine.stopping_success()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def ws_connected(self, websocket, uri):
|
def ws_connected(self, websocket, uri, listener_name):
|
||||||
self.logger.debug("ws_connected")
|
yield from self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
||||||
yield from self.client_connected(WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def stream_connected(self, reader, writer):
|
def stream_connected(self, reader, writer, listener_name):
|
||||||
self.logger.debug("stream_connected")
|
yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||||
yield from self.client_connected(StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def client_connected(self, reader: ReaderAdapter, writer: WriterAdapter):
|
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
|
||||||
|
# Wait for connection available
|
||||||
|
server = self._servers[listener_name]
|
||||||
|
yield from server.acquire_connection()
|
||||||
|
|
||||||
remote_address, remote_port = writer.get_peer_info()
|
remote_address, remote_port = writer.get_peer_info()
|
||||||
self.logger.debug("Connection from %s:%d" % (remote_address, remote_port))
|
self.logger.debug("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))
|
||||||
|
|
||||||
# Wait for first packet and expect a CONNECT
|
# Wait for first packet and expect a CONNECT
|
||||||
connect = None
|
connect = None
|
||||||
|
@ -380,6 +433,7 @@ class Broker:
|
||||||
client_session.machine.disconnect()
|
client_session.machine.disconnect()
|
||||||
yield from writer.close()
|
yield from writer.close()
|
||||||
self.logger.debug("%s Session disconnected" % client_session.client_id)
|
self.logger.debug("%s Session disconnected" % client_session.client_id)
|
||||||
|
server.release_connection()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def check_connect(self, connect: ConnectPacket):
|
def check_connect(self, connect: ConnectPacket):
|
||||||
|
|
|
@ -11,25 +11,12 @@ config = {
|
||||||
},
|
},
|
||||||
'tcp-mqtt': {
|
'tcp-mqtt': {
|
||||||
'bind': '0.0.0.0:1883',
|
'bind': '0.0.0.0:1883',
|
||||||
|
'max_connections': 10
|
||||||
},
|
},
|
||||||
'ws-mqtt': {
|
'ws-mqtt': {
|
||||||
'bind': '127.0.0.1:8080',
|
'bind': '127.0.0.1:8080',
|
||||||
'type': 'ws'
|
'type': 'ws'
|
||||||
},
|
},
|
||||||
'wss-mqtt': {
|
|
||||||
'bind': '127.0.0.1:8081',
|
|
||||||
'type': 'ws',
|
|
||||||
'ssl': 'on',
|
|
||||||
'certfile': 'localhost.server.crt',
|
|
||||||
'keyfile': 'server.key',
|
|
||||||
},
|
|
||||||
'tcp-ssl': {
|
|
||||||
'bind': '127.0.0.1:8883',
|
|
||||||
'ssl': 'on',
|
|
||||||
'certfile': 'localhost.server.crt',
|
|
||||||
'keyfile': 'server.key',
|
|
||||||
'type': 'tcp'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ C = MQTTClient()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_coro():
|
def test_coro():
|
||||||
yield from C.connect('mqtt://test.mosquitto.org:1883/')
|
yield from C.connect('mqtt://localhost:1883/')
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||||
|
|
Ładowanie…
Reference in New Issue