kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #244 from ajmirsky/require_at_least_one_auth
Require at least one authpull/248/head
commit
6b606f04d3
|
@ -700,17 +700,20 @@ class Broker:
|
|||
:return:
|
||||
"""
|
||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||
auth_result = True
|
||||
if returns:
|
||||
for plugin in returns:
|
||||
res = returns[plugin]
|
||||
if res is False:
|
||||
auth_result = False
|
||||
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
|
||||
else:
|
||||
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
|
||||
# If all plugins returned True, authentication is success
|
||||
return auth_result
|
||||
|
||||
results = [ result for _, result in returns.items() if result is not None] if returns else []
|
||||
if len(results) < 1:
|
||||
self.logger.debug("Authentication failed: no plugin responded with a boolean")
|
||||
return False
|
||||
|
||||
if all(results):
|
||||
self.logger.debug("Authentication succeeded")
|
||||
return True
|
||||
|
||||
for plugin, result in returns.items():
|
||||
self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}")
|
||||
|
||||
return False
|
||||
|
||||
def retain_message(
|
||||
self,
|
||||
|
|
|
@ -119,9 +119,9 @@ class PluginManager(Generic[C]):
|
|||
auth_filter_list = []
|
||||
topic_filter_list = []
|
||||
if self.app_context.config and "auth" in self.app_context.config:
|
||||
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
|
||||
auth_filter_list = self.app_context.config["auth"].get("plugins", None)
|
||||
if self.app_context.config and "topic-check" in self.app_context.config:
|
||||
topic_filter_list = self.app_context.config["topic-check"].get("plugins", [])
|
||||
topic_filter_list = self.app_context.config["topic-check"].get("plugins", None)
|
||||
|
||||
ep: EntryPoints | list[EntryPoint] = []
|
||||
if hasattr(entry_points(), "select"):
|
||||
|
@ -133,10 +133,12 @@ class PluginManager(Generic[C]):
|
|||
ep_plugin = self._load_ep_plugin(item)
|
||||
if ep_plugin is not None:
|
||||
self._plugins.append(ep_plugin.object)
|
||||
if ((not auth_filter_list or ep_plugin.name in auth_filter_list)
|
||||
# maintain legacy behavior that if there is no list, use all auth plugins
|
||||
if ((auth_filter_list is None or ep_plugin.name in auth_filter_list)
|
||||
and hasattr(ep_plugin.object, "authenticate")):
|
||||
self._auth_plugins.append(ep_plugin.object)
|
||||
if ((not topic_filter_list or ep_plugin.name in topic_filter_list)
|
||||
# maintain legacy behavior that if there is no list, use all topic plugins
|
||||
if ((topic_filter_list is None or ep_plugin.name in topic_filter_list)
|
||||
and hasattr(ep_plugin.object, "topic_filtering")):
|
||||
self._topic_plugins.append(ep_plugin.object)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
|
|
|
@ -19,7 +19,7 @@ async def main() -> None:
|
|||
client = MQTTClient(config=config)
|
||||
|
||||
try:
|
||||
await client.connect("mqtt://test.mosquitto.org:1883/")
|
||||
await client.connect("mqtt://localhost:1883/")
|
||||
logger.info("client connected")
|
||||
await asyncio.sleep(15)
|
||||
except CancelledError:
|
||||
|
|
|
@ -13,8 +13,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def uptime_coro() -> None:
|
||||
client = MQTTClient()
|
||||
await client.connect("mqtt://test.mosquitto.org/")
|
||||
client = MQTTClient(config={'auto_reconnect': False})
|
||||
await client.connect("mqtt://localhost:1883")
|
||||
|
||||
await client.subscribe(
|
||||
[
|
||||
|
|
|
@ -8,6 +8,8 @@ import urllib.request
|
|||
import pytest
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -22,7 +24,7 @@ test_config = {
|
|||
"sys_interval": 0,
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -49,12 +51,15 @@ test_config_acl: dict[str, int | dict[str, Any]] = {
|
|||
|
||||
@pytest.fixture
|
||||
def mock_plugin_manager():
|
||||
with unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager:
|
||||
with (unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager):
|
||||
plugin_manager_instance = plugin_manager.return_value
|
||||
|
||||
# disable topic filtering when using the mock manager
|
||||
plugin_manager_instance.is_topic_filtering_enabled.return_value = False
|
||||
|
||||
# allow any connection when using the mock manager
|
||||
plugin_manager_instance.map_plugin_auth = unittest.mock.AsyncMock(return_value={ BasePlugin(BaseContext()): True })
|
||||
|
||||
yield plugin_manager
|
||||
|
||||
|
||||
|
|
|
@ -83,6 +83,7 @@ listeners:
|
|||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- amqtt.plugins.authentication.AnonymousAuthPlugin
|
||||
- tests.plugins.mocks.TestAllowTopicPlugin:
|
||||
"""
|
||||
|
||||
|
@ -93,6 +94,7 @@ listeners:
|
|||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- amqtt.plugins.authentication.AnonymousAuthPlugin
|
||||
- tests.plugins.mocks.TestBlockTopicPlugin:
|
||||
"""
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
|
|||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
self.test_auth_flag = True
|
||||
return None
|
||||
return True
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
|
@ -84,8 +84,11 @@ class TestPluginManager(unittest.TestCase):
|
|||
assert plugin.test_close_flag
|
||||
|
||||
def test_plugin_auth_coro(self) -> None:
|
||||
# provide context that activates auth plugins
|
||||
context = BaseContext()
|
||||
context.config = {'auth':{}}
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
manager = PluginManager("amqtt.test.plugins", context=context)
|
||||
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
|
@ -93,8 +96,11 @@ class TestPluginManager(unittest.TestCase):
|
|||
assert plugin.test_auth_flag
|
||||
|
||||
def test_plugin_topic_coro(self) -> None:
|
||||
# provide context that activates topic check plugins
|
||||
context = BaseContext()
|
||||
context.config = {'topic-check':{}}
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
manager = PluginManager("amqtt.test.plugins", context=context)
|
||||
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
|
|
|
@ -145,6 +145,7 @@ async def test_all_plugin_events():
|
|||
},
|
||||
'sys_interval': 1,
|
||||
'plugins':{
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': {},
|
||||
'tests.plugins.test_plugins.AllEventsPlugin': {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,7 +79,8 @@ async def test_broker_sys_plugin_deprecated_config() -> None:
|
|||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
|
||||
EntryPoint(name='broker_sys', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
|
||||
EntryPoint(name='auth_anonymous', group='test.mock_plugins', value='amqtt.plugins.authentication:AnonymousAuthPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
@ -92,7 +93,9 @@ async def test_broker_sys_plugin_deprecated_config() -> None:
|
|||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'auth': {}
|
||||
'auth': {
|
||||
'allow_anonymous': True
|
||||
}
|
||||
}
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
@ -132,6 +135,7 @@ async def test_broker_sys_plugin_config() -> None:
|
|||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'plugins': [
|
||||
{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True}},
|
||||
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
|
||||
]
|
||||
}
|
||||
|
|
|
@ -76,7 +76,8 @@ async def test_start_stop(broker, mock_plugin_manager):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_connect(broker, mock_plugin_manager):
|
||||
client = MQTTClient()
|
||||
client = MQTTClient(config={'auto_reconnect':False})
|
||||
|
||||
ret = await client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
assert client.session is not None
|
||||
|
@ -733,3 +734,69 @@ async def test_broker_socket_open_close(broker):
|
|||
s.send(static_connect_packet)
|
||||
await asyncio.sleep(0.1)
|
||||
s.close()
|
||||
|
||||
|
||||
legacy_config_empty_auth_plugin_list = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 0,
|
||||
'auth':{
|
||||
'plugins':[] # explicitly declare no auth plugins
|
||||
}
|
||||
}
|
||||
class_path_config_no_auth = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'plugins':{
|
||||
'tests.plugins.test_plugins.AllEventsPlugin': {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
legacy_config_empty_auth_plugin_list,
|
||||
class_path_config_no_auth,
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_without_auth_plugin(test_config):
|
||||
|
||||
broker = Broker(config=test_config)
|
||||
|
||||
await broker.start()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# make sure all expected events get triggered
|
||||
with pytest.raises(ConnectError):
|
||||
mqtt_client = MQTTClient(config={'auto_reconnect': False})
|
||||
await mqtt_client.connect()
|
||||
|
||||
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
|
||||
legacy_config_with_absent_auth_plugin_filter = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 0,
|
||||
'auth':{
|
||||
'allow-anonymous': True
|
||||
}
|
||||
}
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_with_absent_auth_plugin_filter():
|
||||
|
||||
# maintain legacy behavior that if a config is missing the 'auth' > 'plugins' filter, all plugins are active
|
||||
broker = Broker(config=legacy_config_with_absent_auth_plugin_filter)
|
||||
|
||||
await broker.start()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
mqtt_client = MQTTClient(config={'auto_reconnect': False})
|
||||
await mqtt_client.connect()
|
||||
|
||||
await broker.shutdown()
|
||||
|
|
|
@ -200,10 +200,46 @@ async def test_client_publish_ws():
|
|||
await broker.shutdown()
|
||||
|
||||
|
||||
def test_client_subscribe():
|
||||
client_subscribe_main()
|
||||
broker_std_config = {
|
||||
"listeners": {
|
||||
"default": {
|
||||
"type": "tcp",
|
||||
"bind": "0.0.0.0:1883", }
|
||||
},
|
||||
'sys_interval':2,
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
"plugins": ["auth_anonymous"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_subscribe():
|
||||
|
||||
# start a standard broker
|
||||
broker = Broker(config=broker_std_config)
|
||||
await broker.start()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# run the sample
|
||||
client_subscribe_script = Path(__file__).parent.parent / "samples/client_subscribe.py"
|
||||
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
" ".join(["python", str(client_subscribe_script)]),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
assert "ERROR" not in stdout.decode("utf-8")
|
||||
assert "Exception" not in stdout.decode("utf-8")
|
||||
assert "ERROR" not in stderr.decode("utf-8")
|
||||
assert "Exception" not in stderr.decode("utf-8")
|
||||
|
||||
await broker.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_subscribe_plugin_acl():
|
||||
broker = Broker(config=broker_acl_config)
|
||||
|
|
Ładowanie…
Reference in New Issue