kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #244 from ajmirsky/require_at_least_one_auth
Require at least one authpull/248/head
commit
6b606f04d3
|
@ -700,17 +700,20 @@ class Broker:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||||
auth_result = True
|
|
||||||
if returns:
|
results = [ result for _, result in returns.items() if result is not None] if returns else []
|
||||||
for plugin in returns:
|
if len(results) < 1:
|
||||||
res = returns[plugin]
|
self.logger.debug("Authentication failed: no plugin responded with a boolean")
|
||||||
if res is False:
|
return False
|
||||||
auth_result = False
|
|
||||||
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
|
if all(results):
|
||||||
else:
|
self.logger.debug("Authentication succeeded")
|
||||||
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
|
return True
|
||||||
# If all plugins returned True, authentication is success
|
|
||||||
return auth_result
|
for plugin, result in returns.items():
|
||||||
|
self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def retain_message(
|
def retain_message(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -119,9 +119,9 @@ class PluginManager(Generic[C]):
|
||||||
auth_filter_list = []
|
auth_filter_list = []
|
||||||
topic_filter_list = []
|
topic_filter_list = []
|
||||||
if self.app_context.config and "auth" in self.app_context.config:
|
if self.app_context.config and "auth" in self.app_context.config:
|
||||||
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
|
auth_filter_list = self.app_context.config["auth"].get("plugins", None)
|
||||||
if self.app_context.config and "topic-check" in self.app_context.config:
|
if self.app_context.config and "topic-check" in self.app_context.config:
|
||||||
topic_filter_list = self.app_context.config["topic-check"].get("plugins", [])
|
topic_filter_list = self.app_context.config["topic-check"].get("plugins", None)
|
||||||
|
|
||||||
ep: EntryPoints | list[EntryPoint] = []
|
ep: EntryPoints | list[EntryPoint] = []
|
||||||
if hasattr(entry_points(), "select"):
|
if hasattr(entry_points(), "select"):
|
||||||
|
@ -133,10 +133,12 @@ class PluginManager(Generic[C]):
|
||||||
ep_plugin = self._load_ep_plugin(item)
|
ep_plugin = self._load_ep_plugin(item)
|
||||||
if ep_plugin is not None:
|
if ep_plugin is not None:
|
||||||
self._plugins.append(ep_plugin.object)
|
self._plugins.append(ep_plugin.object)
|
||||||
if ((not auth_filter_list or ep_plugin.name in auth_filter_list)
|
# maintain legacy behavior that if there is no list, use all auth plugins
|
||||||
|
if ((auth_filter_list is None or ep_plugin.name in auth_filter_list)
|
||||||
and hasattr(ep_plugin.object, "authenticate")):
|
and hasattr(ep_plugin.object, "authenticate")):
|
||||||
self._auth_plugins.append(ep_plugin.object)
|
self._auth_plugins.append(ep_plugin.object)
|
||||||
if ((not topic_filter_list or ep_plugin.name in topic_filter_list)
|
# maintain legacy behavior that if there is no list, use all topic plugins
|
||||||
|
if ((topic_filter_list is None or ep_plugin.name in topic_filter_list)
|
||||||
and hasattr(ep_plugin.object, "topic_filtering")):
|
and hasattr(ep_plugin.object, "topic_filtering")):
|
||||||
self._topic_plugins.append(ep_plugin.object)
|
self._topic_plugins.append(ep_plugin.object)
|
||||||
self.logger.debug(f" Plugin {item.name} ready")
|
self.logger.debug(f" Plugin {item.name} ready")
|
||||||
|
|
|
@ -19,7 +19,7 @@ async def main() -> None:
|
||||||
client = MQTTClient(config=config)
|
client = MQTTClient(config=config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await client.connect("mqtt://test.mosquitto.org:1883/")
|
await client.connect("mqtt://localhost:1883/")
|
||||||
logger.info("client connected")
|
logger.info("client connected")
|
||||||
await asyncio.sleep(15)
|
await asyncio.sleep(15)
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
|
|
|
@ -13,8 +13,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def uptime_coro() -> None:
|
async def uptime_coro() -> None:
|
||||||
client = MQTTClient()
|
client = MQTTClient(config={'auto_reconnect': False})
|
||||||
await client.connect("mqtt://test.mosquitto.org/")
|
await client.connect("mqtt://localhost:1883")
|
||||||
|
|
||||||
await client.subscribe(
|
await client.subscribe(
|
||||||
[
|
[
|
||||||
|
|
|
@ -8,6 +8,8 @@ import urllib.request
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from amqtt.broker import Broker
|
from amqtt.broker import Broker
|
||||||
|
from amqtt.contexts import BaseContext
|
||||||
|
from amqtt.plugins.base import BasePlugin
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -22,7 +24,7 @@ test_config = {
|
||||||
"sys_interval": 0,
|
"sys_interval": 0,
|
||||||
"auth": {
|
"auth": {
|
||||||
"allow-anonymous": True,
|
"allow-anonymous": True,
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,12 +51,15 @@ test_config_acl: dict[str, int | dict[str, Any]] = {
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_plugin_manager():
|
def mock_plugin_manager():
|
||||||
with unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager:
|
with (unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager):
|
||||||
plugin_manager_instance = plugin_manager.return_value
|
plugin_manager_instance = plugin_manager.return_value
|
||||||
|
|
||||||
# disable topic filtering when using the mock manager
|
# disable topic filtering when using the mock manager
|
||||||
plugin_manager_instance.is_topic_filtering_enabled.return_value = False
|
plugin_manager_instance.is_topic_filtering_enabled.return_value = False
|
||||||
|
|
||||||
|
# allow any connection when using the mock manager
|
||||||
|
plugin_manager_instance.map_plugin_auth = unittest.mock.AsyncMock(return_value={ BasePlugin(BaseContext()): True })
|
||||||
|
|
||||||
yield plugin_manager
|
yield plugin_manager
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -83,6 +83,7 @@ listeners:
|
||||||
type: tcp
|
type: tcp
|
||||||
bind: 0.0.0.0:1883
|
bind: 0.0.0.0:1883
|
||||||
plugins:
|
plugins:
|
||||||
|
- amqtt.plugins.authentication.AnonymousAuthPlugin
|
||||||
- tests.plugins.mocks.TestAllowTopicPlugin:
|
- tests.plugins.mocks.TestAllowTopicPlugin:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -93,6 +94,7 @@ listeners:
|
||||||
type: tcp
|
type: tcp
|
||||||
bind: 0.0.0.0:1883
|
bind: 0.0.0.0:1883
|
||||||
plugins:
|
plugins:
|
||||||
|
- amqtt.plugins.authentication.AnonymousAuthPlugin
|
||||||
- tests.plugins.mocks.TestBlockTopicPlugin:
|
- tests.plugins.mocks.TestBlockTopicPlugin:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
|
||||||
|
|
||||||
async def authenticate(self, *, session: Session) -> bool | None:
|
async def authenticate(self, *, session: Session) -> bool | None:
|
||||||
self.test_auth_flag = True
|
self.test_auth_flag = True
|
||||||
return None
|
return True
|
||||||
|
|
||||||
async def topic_filtering(
|
async def topic_filtering(
|
||||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||||
|
@ -84,8 +84,11 @@ class TestPluginManager(unittest.TestCase):
|
||||||
assert plugin.test_close_flag
|
assert plugin.test_close_flag
|
||||||
|
|
||||||
def test_plugin_auth_coro(self) -> None:
|
def test_plugin_auth_coro(self) -> None:
|
||||||
|
# provide context that activates auth plugins
|
||||||
|
context = BaseContext()
|
||||||
|
context.config = {'auth':{}}
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=context)
|
||||||
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
|
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
|
||||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||||
plugin = manager.get_plugin("EventTestPlugin")
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
|
@ -93,8 +96,11 @@ class TestPluginManager(unittest.TestCase):
|
||||||
assert plugin.test_auth_flag
|
assert plugin.test_auth_flag
|
||||||
|
|
||||||
def test_plugin_topic_coro(self) -> None:
|
def test_plugin_topic_coro(self) -> None:
|
||||||
|
# provide context that activates topic check plugins
|
||||||
|
context = BaseContext()
|
||||||
|
context.config = {'topic-check':{}}
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=context)
|
||||||
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
|
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
|
||||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||||
plugin = manager.get_plugin("EventTestPlugin")
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
|
|
|
@ -145,6 +145,7 @@ async def test_all_plugin_events():
|
||||||
},
|
},
|
||||||
'sys_interval': 1,
|
'sys_interval': 1,
|
||||||
'plugins':{
|
'plugins':{
|
||||||
|
'amqtt.plugins.authentication.AnonymousAuthPlugin': {},
|
||||||
'tests.plugins.test_plugins.AllEventsPlugin': {}
|
'tests.plugins.test_plugins.AllEventsPlugin': {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,8 @@ async def test_broker_sys_plugin_deprecated_config() -> None:
|
||||||
match group:
|
match group:
|
||||||
case 'tests.mock_plugins':
|
case 'tests.mock_plugins':
|
||||||
return [
|
return [
|
||||||
EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
|
EntryPoint(name='broker_sys', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
|
||||||
|
EntryPoint(name='auth_anonymous', group='test.mock_plugins', value='amqtt.plugins.authentication:AnonymousAuthPlugin'),
|
||||||
]
|
]
|
||||||
case _:
|
case _:
|
||||||
return list()
|
return list()
|
||||||
|
@ -92,7 +93,9 @@ async def test_broker_sys_plugin_deprecated_config() -> None:
|
||||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
},
|
},
|
||||||
'sys_interval': 1,
|
'sys_interval': 1,
|
||||||
'auth': {}
|
'auth': {
|
||||||
|
'allow_anonymous': True
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||||
|
@ -132,6 +135,7 @@ async def test_broker_sys_plugin_config() -> None:
|
||||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
},
|
},
|
||||||
'plugins': [
|
'plugins': [
|
||||||
|
{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True}},
|
||||||
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
|
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,8 @@ async def test_start_stop(broker, mock_plugin_manager):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_client_connect(broker, mock_plugin_manager):
|
async def test_client_connect(broker, mock_plugin_manager):
|
||||||
client = MQTTClient()
|
client = MQTTClient(config={'auto_reconnect':False})
|
||||||
|
|
||||||
ret = await client.connect("mqtt://127.0.0.1/")
|
ret = await client.connect("mqtt://127.0.0.1/")
|
||||||
assert ret == 0
|
assert ret == 0
|
||||||
assert client.session is not None
|
assert client.session is not None
|
||||||
|
@ -733,3 +734,69 @@ async def test_broker_socket_open_close(broker):
|
||||||
s.send(static_connect_packet)
|
s.send(static_connect_packet)
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
|
|
||||||
|
legacy_config_empty_auth_plugin_list = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'sys_interval': 0,
|
||||||
|
'auth':{
|
||||||
|
'plugins':[] # explicitly declare no auth plugins
|
||||||
|
}
|
||||||
|
}
|
||||||
|
class_path_config_no_auth = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'plugins':{
|
||||||
|
'tests.plugins.test_plugins.AllEventsPlugin': {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_config", [
|
||||||
|
legacy_config_empty_auth_plugin_list,
|
||||||
|
class_path_config_no_auth,
|
||||||
|
])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broker_without_auth_plugin(test_config):
|
||||||
|
|
||||||
|
broker = Broker(config=test_config)
|
||||||
|
|
||||||
|
await broker.start()
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# make sure all expected events get triggered
|
||||||
|
with pytest.raises(ConnectError):
|
||||||
|
mqtt_client = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
await mqtt_client.connect()
|
||||||
|
|
||||||
|
|
||||||
|
await broker.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
legacy_config_with_absent_auth_plugin_filter = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'sys_interval': 0,
|
||||||
|
'auth':{
|
||||||
|
'allow-anonymous': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broker_with_absent_auth_plugin_filter():
|
||||||
|
|
||||||
|
# maintain legacy behavior that if a config is missing the 'auth' > 'plugins' filter, all plugins are active
|
||||||
|
broker = Broker(config=legacy_config_with_absent_auth_plugin_filter)
|
||||||
|
|
||||||
|
await broker.start()
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
|
mqtt_client = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
await mqtt_client.connect()
|
||||||
|
|
||||||
|
await broker.shutdown()
|
||||||
|
|
|
@ -200,10 +200,46 @@ async def test_client_publish_ws():
|
||||||
await broker.shutdown()
|
await broker.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def test_client_subscribe():
|
broker_std_config = {
|
||||||
client_subscribe_main()
|
"listeners": {
|
||||||
|
"default": {
|
||||||
|
"type": "tcp",
|
||||||
|
"bind": "0.0.0.0:1883", }
|
||||||
|
},
|
||||||
|
'sys_interval':2,
|
||||||
|
"auth": {
|
||||||
|
"allow-anonymous": True,
|
||||||
|
"plugins": ["auth_anonymous"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_subscribe():
|
||||||
|
|
||||||
|
# start a standard broker
|
||||||
|
broker = Broker(config=broker_std_config)
|
||||||
|
await broker.start()
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# run the sample
|
||||||
|
client_subscribe_script = Path(__file__).parent.parent / "samples/client_subscribe.py"
|
||||||
|
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
" ".join(["python", str(client_subscribe_script)]),
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
stdout, stderr = await process.communicate()
|
||||||
|
|
||||||
|
assert "ERROR" not in stdout.decode("utf-8")
|
||||||
|
assert "Exception" not in stdout.decode("utf-8")
|
||||||
|
assert "ERROR" not in stderr.decode("utf-8")
|
||||||
|
assert "Exception" not in stderr.decode("utf-8")
|
||||||
|
|
||||||
|
await broker.shutdown()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_client_subscribe_plugin_acl():
|
async def test_client_subscribe_plugin_acl():
|
||||||
broker = Broker(config=broker_acl_config)
|
broker = Broker(config=broker_acl_config)
|
||||||
|
|
Ładowanie…
Reference in New Issue