Merge pull request #244 from ajmirsky/require_at_least_one_auth

Require at least one auth
pull/248/head
Andrew Mirsky 2025-07-03 11:21:22 -04:00 zatwierdzone przez GitHub
commit 6b606f04d3
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
11 zmienionych plików z 154 dodań i 28 usunięć

Wyświetl plik

@ -700,17 +700,20 @@ class Broker:
:return:
"""
returns = await self.plugins_manager.map_plugin_auth(session=session)
auth_result = True
if returns:
for plugin in returns:
res = returns[plugin]
if res is False:
auth_result = False
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
else:
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
# If all plugins returned True, authentication is success
return auth_result
results = [ result for _, result in returns.items() if result is not None] if returns else []
if len(results) < 1:
self.logger.debug("Authentication failed: no plugin responded with a boolean")
return False
if all(results):
self.logger.debug("Authentication succeeded")
return True
for plugin, result in returns.items():
self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}")
return False
def retain_message(
self,

Wyświetl plik

@ -119,9 +119,9 @@ class PluginManager(Generic[C]):
auth_filter_list = []
topic_filter_list = []
if self.app_context.config and "auth" in self.app_context.config:
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
auth_filter_list = self.app_context.config["auth"].get("plugins", None)
if self.app_context.config and "topic-check" in self.app_context.config:
topic_filter_list = self.app_context.config["topic-check"].get("plugins", [])
topic_filter_list = self.app_context.config["topic-check"].get("plugins", None)
ep: EntryPoints | list[EntryPoint] = []
if hasattr(entry_points(), "select"):
@ -133,10 +133,12 @@ class PluginManager(Generic[C]):
ep_plugin = self._load_ep_plugin(item)
if ep_plugin is not None:
self._plugins.append(ep_plugin.object)
if ((not auth_filter_list or ep_plugin.name in auth_filter_list)
# maintain legacy behavior that if there is no list, use all auth plugins
if ((auth_filter_list is None or ep_plugin.name in auth_filter_list)
and hasattr(ep_plugin.object, "authenticate")):
self._auth_plugins.append(ep_plugin.object)
if ((not topic_filter_list or ep_plugin.name in topic_filter_list)
# maintain legacy behavior that if there is no list, use all topic plugins
if ((topic_filter_list is None or ep_plugin.name in topic_filter_list)
and hasattr(ep_plugin.object, "topic_filtering")):
self._topic_plugins.append(ep_plugin.object)
self.logger.debug(f" Plugin {item.name} ready")

Wyświetl plik

@ -19,7 +19,7 @@ async def main() -> None:
client = MQTTClient(config=config)
try:
await client.connect("mqtt://test.mosquitto.org:1883/")
await client.connect("mqtt://localhost:1883/")
logger.info("client connected")
await asyncio.sleep(15)
except CancelledError:

Wyświetl plik

@ -13,8 +13,8 @@ logger = logging.getLogger(__name__)
async def uptime_coro() -> None:
client = MQTTClient()
await client.connect("mqtt://test.mosquitto.org/")
client = MQTTClient(config={'auto_reconnect': False})
await client.connect("mqtt://localhost:1883")
await client.subscribe(
[

Wyświetl plik

@ -8,6 +8,8 @@ import urllib.request
import pytest
from amqtt.broker import Broker
from amqtt.contexts import BaseContext
from amqtt.plugins.base import BasePlugin
log = logging.getLogger(__name__)
@ -22,7 +24,7 @@ test_config = {
"sys_interval": 0,
"auth": {
"allow-anonymous": True,
},
}
}
@ -49,12 +51,15 @@ test_config_acl: dict[str, int | dict[str, Any]] = {
@pytest.fixture
def mock_plugin_manager():
with unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager:
with (unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager):
plugin_manager_instance = plugin_manager.return_value
# disable topic filtering when using the mock manager
plugin_manager_instance.is_topic_filtering_enabled.return_value = False
# allow any connection when using the mock manager
plugin_manager_instance.map_plugin_auth = unittest.mock.AsyncMock(return_value={ BasePlugin(BaseContext()): True })
yield plugin_manager

Wyświetl plik

@ -83,6 +83,7 @@ listeners:
type: tcp
bind: 0.0.0.0:1883
plugins:
- amqtt.plugins.authentication.AnonymousAuthPlugin
- tests.plugins.mocks.TestAllowTopicPlugin:
"""
@ -93,6 +94,7 @@ listeners:
type: tcp
bind: 0.0.0.0:1883
plugins:
- amqtt.plugins.authentication.AnonymousAuthPlugin
- tests.plugins.mocks.TestBlockTopicPlugin:
"""

Wyświetl plik

@ -31,7 +31,7 @@ class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
self.test_auth_flag = True
return None
return True
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
@ -84,8 +84,11 @@ class TestPluginManager(unittest.TestCase):
assert plugin.test_close_flag
def test_plugin_auth_coro(self) -> None:
# provide context that activates auth plugins
context = BaseContext()
context.config = {'auth':{}}
manager = PluginManager("amqtt.test.plugins", context=None)
manager = PluginManager("amqtt.test.plugins", context=context)
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
@ -93,8 +96,11 @@ class TestPluginManager(unittest.TestCase):
assert plugin.test_auth_flag
def test_plugin_topic_coro(self) -> None:
# provide context that activates topic check plugins
context = BaseContext()
context.config = {'topic-check':{}}
manager = PluginManager("amqtt.test.plugins", context=None)
manager = PluginManager("amqtt.test.plugins", context=context)
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")

Wyświetl plik

@ -145,6 +145,7 @@ async def test_all_plugin_events():
},
'sys_interval': 1,
'plugins':{
'amqtt.plugins.authentication.AnonymousAuthPlugin': {},
'tests.plugins.test_plugins.AllEventsPlugin': {}
}
}

Wyświetl plik

@ -79,7 +79,8 @@ async def test_broker_sys_plugin_deprecated_config() -> None:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
EntryPoint(name='broker_sys', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
EntryPoint(name='auth_anonymous', group='test.mock_plugins', value='amqtt.plugins.authentication:AnonymousAuthPlugin'),
]
case _:
return list()
@ -92,7 +93,9 @@ async def test_broker_sys_plugin_deprecated_config() -> None:
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'auth': {}
'auth': {
'allow_anonymous': True
}
}
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
@ -132,6 +135,7 @@ async def test_broker_sys_plugin_config() -> None:
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'plugins': [
{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True}},
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
]
}

Wyświetl plik

@ -76,7 +76,8 @@ async def test_start_stop(broker, mock_plugin_manager):
@pytest.mark.asyncio
async def test_client_connect(broker, mock_plugin_manager):
client = MQTTClient()
client = MQTTClient(config={'auto_reconnect':False})
ret = await client.connect("mqtt://127.0.0.1/")
assert ret == 0
assert client.session is not None
@ -733,3 +734,69 @@ async def test_broker_socket_open_close(broker):
s.send(static_connect_packet)
await asyncio.sleep(0.1)
s.close()
legacy_config_empty_auth_plugin_list = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 0,
'auth':{
'plugins':[] # explicitly declare no auth plugins
}
}
class_path_config_no_auth = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'plugins':{
'tests.plugins.test_plugins.AllEventsPlugin': {}
}
}
@pytest.mark.parametrize("test_config", [
legacy_config_empty_auth_plugin_list,
class_path_config_no_auth,
])
@pytest.mark.asyncio
async def test_broker_without_auth_plugin(test_config):
broker = Broker(config=test_config)
await broker.start()
await asyncio.sleep(2)
# make sure all expected events get triggered
with pytest.raises(ConnectError):
mqtt_client = MQTTClient(config={'auto_reconnect': False})
await mqtt_client.connect()
await broker.shutdown()
legacy_config_with_absent_auth_plugin_filter = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 0,
'auth':{
'allow-anonymous': True
}
}
@pytest.mark.asyncio
async def test_broker_with_absent_auth_plugin_filter():
# maintain legacy behavior that if a config is missing the 'auth' > 'plugins' filter, all plugins are active
broker = Broker(config=legacy_config_with_absent_auth_plugin_filter)
await broker.start()
await asyncio.sleep(2)
mqtt_client = MQTTClient(config={'auto_reconnect': False})
await mqtt_client.connect()
await broker.shutdown()

Wyświetl plik

@ -200,10 +200,46 @@ async def test_client_publish_ws():
await broker.shutdown()
def test_client_subscribe():
client_subscribe_main()
broker_std_config = {
"listeners": {
"default": {
"type": "tcp",
"bind": "0.0.0.0:1883", }
},
'sys_interval':2,
"auth": {
"allow-anonymous": True,
"plugins": ["auth_anonymous"]
}
}
@pytest.mark.asyncio
async def test_client_subscribe():
# start a standard broker
broker = Broker(config=broker_std_config)
await broker.start()
await asyncio.sleep(1)
# run the sample
client_subscribe_script = Path(__file__).parent.parent / "samples/client_subscribe.py"
process = await asyncio.create_subprocess_shell(
" ".join(["python", str(client_subscribe_script)]),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
assert "ERROR" not in stdout.decode("utf-8")
assert "Exception" not in stdout.decode("utf-8")
assert "ERROR" not in stderr.decode("utf-8")
assert "Exception" not in stderr.decode("utf-8")
await broker.shutdown()
@pytest.mark.asyncio
async def test_client_subscribe_plugin_acl():
broker = Broker(config=broker_acl_config)