updating existing plugins to use new plugin configuraiton format. adding deprecation warnings. updating existing tests and adding additional cases to check that both old and new config formats work correctly

pull/240/head
Andrew Mirsky 2025-06-27 23:08:09 -04:00
rodzic 2485351600
commit a8e1692631
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
13 zmienionych plików z 510 dodań i 243 usunięć

Wyświetl plik

@ -1,8 +1,10 @@
from dataclasses import dataclass, field
from pathlib import Path
from passlib.apps import custom_app_context as pwd_context
from amqtt.broker import BrokerContext
from amqtt.contexts import BaseContext
from amqtt.plugins.base import BaseAuthPlugin
from amqtt.session import Session
@ -12,12 +14,17 @@ _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class AnonymousAuthPlugin(BaseAuthPlugin):
"""Authentication plugin allowing anonymous access."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
# Default to allowing anonymous
self._allow_anonymous = self._get_config_option("allow-anonymous", True) # noqa: FBT003
async def authenticate(self, *, session: Session) -> bool:
authenticated = await super().authenticate(session=session)
if authenticated:
# Default to allowing anonymous
allow_anonymous = self.auth_config.get("allow-anonymous", True) if isinstance(self.auth_config, dict) else True
if allow_anonymous:
if self._allow_anonymous:
self.context.logger.debug("Authentication success: config allows anonymous")
return True
@ -27,6 +34,12 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
self.context.logger.debug("Authentication failure: session has no username")
return False
@dataclass
class Config:
"""Allow empty username."""
allow_anonymous: bool = field(default=True)
class FileAuthPlugin(BaseAuthPlugin):
"""Authentication plugin based on a file-stored user database."""
@ -38,7 +51,7 @@ class FileAuthPlugin(BaseAuthPlugin):
def _read_password_file(self) -> None:
"""Read the password file and populates the user dictionary."""
password_file = self.auth_config.get("password-file") if isinstance(self.auth_config, dict) else None
password_file = self._get_config_option("password-file", None)
if not password_file:
self.context.logger.warning("Configuration parameter 'password-file' not found")
return
@ -87,3 +100,9 @@ class FileAuthPlugin(BaseAuthPlugin):
self.context.logger.debug(f"Authentication failure: password mismatch for user '{session.username}'")
return False
@dataclass
class Config:
"""Path to the properly encoded password file."""
password_file: str | None = None

Wyświetl plik

@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar
from amqtt.contexts import Action, BaseContext
@ -24,6 +24,16 @@ class BasePlugin(Generic[C]):
return None
return section_config
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if option_name in self.context.config:
return self.context.config[option_name]
return default
@dataclass
class Config:
"""Override to define the configuration and defaults for plugin."""
@ -39,6 +49,18 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
if not bool(self.topic_config) and not is_dataclass(self.context.config):
self.context.logger.warning("'topic-check' section not found in context configuration")
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.topic_config and option_name in self.topic_config:
return self.topic_config[option_name]
return default
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
@ -54,18 +76,31 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
bool: `True` if topic is allowed, `False` otherwise
"""
return bool(self.topic_config)
return bool(self.topic_config) or is_dataclass(self.context.config)
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.auth_config and option_name in self.auth_config:
return self.auth_config[option_name]
return default
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
if not self.auth_config:
if not bool(self.auth_config) and not is_dataclass(self.context.config):
# auth config section not found and Config dataclass not provided
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.
@ -77,8 +112,4 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
- `None` if authentication can't be achieved (then plugin result is then ignored)
"""
if not self.auth_config:
# auth config section not found
self.context.logger.warning("'auth' section not found in context configuration")
return False
return True
return bool(self.auth_config) or is_dataclass(self.context.config)

Wyświetl plik

@ -9,6 +9,7 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
from inspect import iscoroutinefunction
import logging
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
import warnings
from dacite import Config as DaciteConfig, DaciteError, from_dict
@ -82,12 +83,25 @@ class PluginManager(Generic[C]):
def _load_plugins(self, namespace: str | None = None) -> None:
if self.app_context.config and "plugins" in self.app_context.config:
if "auth" in self.app_context.config:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugin_list: list[Any] = self.app_context.config.get("plugins", [])
self._load_str_plugins(plugin_list)
else:
if not namespace:
msg = "Namespace needs to be provided for EntryPoint plugin definitions"
raise PluginLoadError(msg)
warnings.warn(
"Loading plugins from EntryPoints is deprecated and will be removed in a future version."
" Use `plugins` section of config instead.",
DeprecationWarning,
stacklevel=2
)
self._load_ep_plugins(namespace)
for plugin in self._plugins:
@ -183,7 +197,7 @@ class PluginManager(Generic[C]):
try:
plugin_class: Any = import_string(plugin_path)
except ModuleNotFoundError as ep:
except ImportError as ep:
msg = f"Plugin import failed: {plugin_path}"
raise PluginImportError(msg) from ep
@ -204,10 +218,12 @@ class PluginManager(Generic[C]):
raise PluginLoadError(msg) from e
try:
pc = plugin_class(plugin_context)
self.logger.debug(f"Loading plugin {plugin_path}")
return cast("BasePlugin[C]", plugin_class(plugin_context))
except ImportError as e:
raise PluginLoadError from e
return cast("BasePlugin[C]", pc)
except Exception as e:
self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True)
raise PluginInitError(plugin_class) from e
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
"""Get a plugin by its name from the plugins loaded for the current namespace.

Wyświetl plik

@ -1,6 +1,6 @@
import asyncio
from collections import deque # pylint: disable=C0412
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass
from typing import Any, SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412
import psutil
@ -116,12 +116,10 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
# Start $SYS topics management
try:
if is_dataclass(self.context.config):
self._sys_interval = self.context.config.sys_interval # type: ignore[unreachable]
else:
x = self.context.config.get("sys_interval") if self.context.config is not None else None
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
self._sys_interval = int(x)
self._sys_interval = self._get_config_option("sys_interval", None)
if isinstance(self._sys_interval, str | Buffer | SupportsInt | SupportsIndex):
self._sys_interval = int(self._sys_interval)
if self._sys_interval > 0:
self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds")
self._sys_handle = (

Wyświetl plik

@ -1,3 +1,4 @@
from dataclasses import dataclass, field
from typing import Any
from amqtt.contexts import Action, BaseContext
@ -51,26 +52,34 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
return False
# hbmqtt and older amqtt do not support publish filtering
if action == Action.PUBLISH and self.topic_config is not None and "publish-acl" not in self.topic_config:
if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}):
# maintain backward compatibility, assume permitted
return True
req_topic = topic
if not req_topic:
return False
return False\
username = session.username if session else None
if username is None:
username = "anonymous"
acl: dict[str, Any] = {}
if self.topic_config is not None and action == Action.PUBLISH:
acl = self.topic_config.get("publish-acl", {})
elif self.topic_config is not None and action == Action.SUBSCRIBE:
acl = self.topic_config.get("acl", {})
match action:
case Action.PUBLISH:
acl = self._get_config_option("publish-acl", {})
case Action.SUBSCRIBE:
acl = self._get_config_option("acl", {})
allowed_topics = acl.get(username, [])
if not allowed_topics:
return False
return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics)
@dataclass
class Config:
"""Mappings of username and list of approved topics."""
publish_acl: dict[str, list[str]] = field(default_factory=dict)
acl: dict[str, list[str]] = field(default_factory=dict)

Wyświetl plik

@ -3,10 +3,10 @@ listeners:
default:
type: tcp
bind: 0.0.0.0:1883
sys_interval: 20
auth:
plugins:
- auth_anonymous
allow-anonymous: true
topic-check:
enabled: False
plugins:
- amqtt.plugins.logging_amqtt.EventLoggerPlugin:
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
- amqtt.plugins.authentication.AnonymousAuthPlugin:
allow_anonymous: true
- amqtt.plugins.sys.broker.BrokerSysPlugin:
sys_interval: 20

Wyświetl plik

@ -7,4 +7,6 @@ auto_reconnect: true
reconnect_max_interval: 10
reconnect_retries: 2
broker:
uri: "mqtt://127.0.0.1"
uri: "mqtt://127.0.0.1"
plugins:
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:

Wyświetl plik

@ -28,9 +28,17 @@ its own variables to configure its behavior.
Plugins that are defined in the`project.entry-points` are notified of events if the subclass
implements one or more of these methods:
### Client and Broker
- `async def on_mqtt_packet_sent(self, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
- `async def on_mqtt_packet_received(self, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
### Client Only
none
### Broker Only
- `async def on_broker_pre_start() -> None`
- `async def on_broker_post_start() -> None`
- `async def on_broker_pre_shutdown() -> None`

Wyświetl plik

@ -1,48 +1,99 @@
# Existing Plugins
With the aMQTT Broker plugins framework, one can add additional functionality without
having to rewrite core logic. Plugins loaded by default are specified in `pyproject.toml`:
With the aMQTT plugins framework, one can add additional functionality without
having to rewrite core logic in the broker or client. Plugins can be loaded and configured using
the `plugins` section of the config file (or parameter passed to the class).
## Broker
By default, `EventLoggerPlugin`, `PacketLoggerPlugin`, `AnonymousAuthPlugin` and `BrokerSysPlugin` are activated
and configured for the broker:
```yaml
--8<-- "pyproject.toml:included"
--8<-- "amqtt/scripts/default_broker.yaml"
```
## auth_anonymous (Auth Plugin)
`amqtt.plugins.authentication:AnonymousAuthPlugin`
??? warning "Loading plugins from EntryPoints in `pyproject.toml` has been deprecated"
Previously, all plugins were loaded from EntryPoints:
```toml
--8<-- "pyproject.toml:included"
```
But the same 4 plugins were activated in the previous default config:
```yaml
--8<-- "samples/legacy.yaml"
```
## Client
By default, the `PacketLoggerPlugin` is activated and configured for the client:
```yaml
--8<-- "amqtt/scripts/default_client.yaml"
```
## Plugins
### Anonymous (Auth Plugin)
`amqtt.plugins.authentication.AnonymousAuthPlugin`
**Configuration**
```yaml
auth:
plugins:
- auth_anonymous
allow-anonymous: true # if false, providing a username will allow access
plugins:
- ...
- amqtt.plugins.authentication.AnonymousAuthPlugin:
allow_anonymous: false
- ...
```
!!! danger
even if `allow-anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
## auth_file (Auth Plugin)
??? warning "EntryPoint-style configuration is deprecated"
`amqtt.plugins.authentication:FileAuthPlugin`
```yaml
auth:
plugins:
- auth_anonymous
allow-anonymous: true # if false, providing a username will allow access
```
### Password File (Auth Plugin)
`amqtt.plugins.authentication.FileAuthPlugin`
clients are authorized by providing username and password, compared against file
**Configuration**
```yaml
auth:
plugins:
- auth_file
password-file: /path/to/password_file
plugins:
- ...
- amqtt.plugins.authentication.FileAuthPlugin:
password_file: /path/to/password_file
- ...
```
??? warning "EntryPoint-style configuration is deprecated"
```yaml
auth:
plugins:
- auth_file
password-file: /path/to/password_file
```
**File Format**
The file includes `username:password` pairs, one per line.
@ -58,33 +109,42 @@ passwd = input() if not sys.stdin.isatty() else getpass()
print(sha512_crypt.hash(passwd))
```
## Taboo (Topic Plugin)
### Taboo (Topic Plugin)
`amqtt.plugins.topic_checking:TopicTabooPlugin`
`amqtt.plugins.topic_checking.TopicTabooPlugin`
Prevents using topics named: `prohibited`, `top-secret`, and `data/classified`
**Configuration**
```yaml
topic-check:
enabled: true
plugins:
- topic_taboo
plugins:
- ...
- amqtt.plugins.topic_checking.TopicTabooPlugin:
- ...
```
## ACL (Topic Plugin)
??? warning "EntryPoint-style configuration is deprecated"
`amqtt.plugins.topic_checking:TopicAccessControlListPlugin`
```yaml
topic-check:
enabled: true
plugins:
- topic_taboo
```
### ACL (Topic Plugin)
`amqtt.plugins.topic_checking.TopicAccessControlListPlugin`
**Configuration**
- `acl` *(list)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access.
- `acl` *(mapping)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access.
The list should be a key-value pair, where:
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
- `publish-acl` *(list)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where:
- `publish-acl` *(mapping)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where:
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
!!! info "Reserved usernames"
@ -93,47 +153,89 @@ topic-check:
- The username `anonymous` will control allowed topics, if using the `auth_anonymous` plugin.
```yaml
topic-check:
enabled: true
plugins:
- topic_acl
publish-acl:
- username: ["list", "of", "allowed", "topics", "for", "publishing"]
- .
acl:
- username: ["list", "of", "allowed", "topics", "for", "subscribing"]
- .
plugins:
- ...
- amqtt.plugins.topic_checking.TopicAccessControlListPlugin:
publish_acl:
- username: ["list", "of", "allowed", "topics", "for", "publishing"]
acl:
- username: ["list", "of", "allowed", "topics", "for", "subscribing"]
- ...
```
## Plugin: $SYS
??? warning "EntryPoint-style configuration is deprecated"
```yaml
topic-check:
enabled: true
plugins:
- topic_acl
publish-acl:
- username: ["list", "of", "allowed", "topics", "for", "publishing"]
- .
acl:
- username: ["list", "of", "allowed", "topics", "for", "subscribing"]
- .
```
`amqtt.plugins.sys.broker:BrokerSysPlugin`
### $SYS topics
`amqtt.plugins.sys.broker.BrokerSysPlugin`
Publishes, on a periodic basis, statistics about the broker
**Configuration**
- `sys_interval` - int, seconds between updates
```yaml
plugins:
- ...
- amqtt.plugins.sys.broker.BrokerSysPlugin:
sys_interval: 20 # int, seconds between updates
- ...
```
**Supported Topics**
- `$SYS/broker/version` - payload: `str`
- `$SYS/broker/load/bytes/received` - payload: `int`
- `$SYS/broker/load/bytes/sent` - payload: `int`
- `$SYS/broker/messages/received` - payload: `int`
- `$SYS/broker/messages/sent` - payload: `int`
- `$SYS/broker/time` - payload: `int` (current time, epoch seconds)
- `$SYS/broker/uptime` - payload: `int` (seconds since broker start)
- `$SYS/broker/uptime/formatted` - payload: `str` (start time of broker in UTC)
- `$SYS/broker/clients/connected` - payload: `int` (current number of connected clients)
- `$SYS/broker/clients/disconnected` - payload: `int` (number of clients that have disconnected)
- `$SYS/broker/clients/maximum` - payload: `int`
- `$SYS/broker/clients/total` - payload: `int`
- `$SYS/broker/messages/inflight` - payload: `int`
- `$SYS/broker/messages/inflight/in` - payload: `int`
- `$SYS/broker/messages/inflight/out` - payload: `int`
- `$SYS/broker/messages/inflight/stored` - payload: `int`
- `$SYS/broker/messages/publish/received` - payload: `int`
- `$SYS/broker/messages/publish/sent` - payload: `int`
- `$SYS/broker/messages/retained/count` - payload: `int`
- `$SYS/broker/messages/subscriptions/count` - payload: `int`
- `$SYS/broker/version` *(string)*
- `$SYS/broker/load/bytes/received` *(int)*
- `$SYS/broker/load/bytes/sent` *(int)*
- `$SYS/broker/messages/received` *(int)*
- `$SYS/broker/messages/sent` *(int)*
- `$SYS/broker/time` *(int, current time in epoch seconds)*
- `$SYS/broker/uptime` *(int, seconds since broker start)*
- `$SYS/broker/uptime/formatted` *(string, start time of broker in UTC)*
- `$SYS/broker/clients/connected` *(int, number of currently connected clients)*
- `$SYS/broker/clients/disconnected` *(int, number of clients that have disconnected)*
- `$SYS/broker/clients/maximum` *(int, maximum number of clients connected)*
- `$SYS/broker/clients/total` *(int)*
- `$SYS/broker/messages/inflight` *(int)*
- `$SYS/broker/messages/inflight/in` *(int)*
- `$SYS/broker/messages/inflight/out` *(int)*
- `$SYS/broker/messages/inflight/stored` *(int)*
- `$SYS/broker/messages/publish/received` *(int)*
- `$SYS/broker/messages/publish/sent` *(int)*
- `$SYS/broker/messages/retained/count` *(int)*
- `$SYS/broker/messages/subscriptions/count` *(int)*
- `$SYS/broker/heap/size` *(float, MB)*
- `$SYS/broker/heap/maximum` *(float, MB)*
- `$SYS/broker/cpu/percent` *(float, %)*
- `$SYS/broker/cpu/maximum` *(float, %)*
### Event Logger
`amqtt.plugins.logging_amqtt.EventLoggerPlugin`
This plugin issues log messages when [broker and mqtt events](custom_plugins.md#events) are triggered:
- info level messages for `client connected` and `client disconnected`
- debug level for all others
### Packet Logger
`amqtt.plugins.logging_amqtt.PacketLoggerPlugin`
This plugin issues debug-level messages for [mqtt events](custom_plugins.md#client-and-broker): `on_mqtt_packet_sent`
and `on_mqtt_packet_received`.

Wyświetl plik

@ -0,0 +1,13 @@
---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
sys_interval: 20
auth:
plugins:
- auth_anonymous
allow-anonymous: true
topic-check:
enabled: False

Wyświetl plik

@ -3,19 +3,42 @@ import logging
from pathlib import Path
import unittest
import pytest
from amqtt.plugins.authentication import AnonymousAuthPlugin, FileAuthPlugin
from amqtt.contexts import BaseContext
from amqtt.plugins.base import BaseAuthPlugin
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
@pytest.mark.asyncio
async def test_base_no_config(logdog):
"""Check BaseTopicPlugin returns false if no topic-check is present."""
with logdog() as pile:
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {}
plugin = BaseAuthPlugin(context)
s = Session()
authorised = await plugin.authenticate(session=s)
assert authorised is False
# Warning messages are only generated if using deprecated plugin configuration on initial load
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 1
assert log_records[0].levelno == logging.WARNING
assert log_records[0].message == "'auth' section not found in context configuration"
class TestAnonymousAuthPlugin(unittest.TestCase):
def setUp(self) -> None:
self.loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
def test_allow_anonymous(self) -> None:
def test_allow_anonymous_dict_config(self) -> None:
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = {"auth": {"allow-anonymous": True}}
@ -25,6 +48,16 @@ class TestAnonymousAuthPlugin(unittest.TestCase):
ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s))
assert ret
def test_allow_anonymous_dataclass_config(self) -> None:
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = AnonymousAuthPlugin.Config(allow_anonymous=True)
s = Session()
s.username = ""
auth_plugin = AnonymousAuthPlugin(context)
ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s))
assert ret
def test_disallow_anonymous(self) -> None:
context = BaseContext()
context.logger = logging.getLogger(__name__)

Wyświetl plik

@ -82,54 +82,36 @@ class MockInitErrorPlugin(BasePlugin):
@pytest.mark.asyncio
async def test_plugin_exception_while_init() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'plugins':{
'tests.plugins.test_plugins.MockInitErrorPlugin':{}
}
}
with pytest.raises(PluginInitError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
with pytest.raises(PluginInitError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
@pytest.mark.asyncio
async def test_plugin_exception_while_loading() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'plugins':{
'tests.plugins.mock_plugins.MockImportErrorPlugin':{}
}
}
with pytest.raises(PluginImportError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
with pytest.raises(PluginImportError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
class AllEventsPlugin(BasePlugin[BaseContext]):
@ -153,47 +135,37 @@ class AllEventsPlugin(BasePlugin[BaseContext]):
if name not in ('authenticate', 'topic_filtering'):
pytest.fail(f'unexpected method called: {name}')
@pytest.mark.asyncio
async def test_all_plugin_events():
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='AllEventsPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:AllEventsPlugin'),
]
case _:
return list()
# patch the entry points so we can load our test plugin
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'plugins':{
'tests.plugins.test_plugins.AllEventsPlugin': {}
}
}
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
await broker.start()
await asyncio.sleep(2)
await broker.start()
await asyncio.sleep(2)
# make sure all expected events get triggered
client = MQTTClient()
await client.connect("mqtt://127.0.0.1:1883/")
await client.subscribe([('my/test/topic', QOS_0),])
await client.publish('test/topic', b'my test message')
await client.unsubscribe(['my/test/topic',])
await client.disconnect()
await asyncio.sleep(1)
# make sure all expected events get triggered
client = MQTTClient()
await client.connect("mqtt://127.0.0.1:1883/")
await client.subscribe([('my/test/topic', QOS_0),])
await client.publish('test/topic', b'my test message')
await client.unsubscribe(['my/test/topic',])
await client.disconnect()
await asyncio.sleep(1)
# get the plugin so it doesn't get gc on shutdown
test_plugin = broker.plugins_manager.get_plugin('AllEventsPlugin')
await broker.shutdown()
await asyncio.sleep(1)
# get the plugin so it doesn't get gc on shutdown
test_plugin = broker.plugins_manager.get_plugin('AllEventsPlugin')
await broker.shutdown()
await asyncio.sleep(1)
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'

Wyświetl plik

@ -9,6 +9,8 @@ from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabo
from amqtt.plugins.base import BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
# Base plug-in object
@ -24,9 +26,11 @@ async def test_base_no_config(logdog):
authorised = await plugin.topic_filtering()
assert authorised is False
# Warning messages are no longer generated
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Warning messages are only generated if using deprecated plugin configuration on initial load
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 1
assert log_records[0].levelno == logging.WARNING
assert log_records[0].message == "'topic-check' section not found in context configuration"
@pytest.mark.asyncio
@ -42,9 +46,11 @@ async def test_base_empty_config(logdog):
authorised = await plugin.topic_filtering()
assert authorised is False
# Warning messages are no longer generated
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Warning messages are only generated if using deprecated plugin configuration on initial load
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 1
assert log_records[0].levelno == logging.WARNING
assert log_records[0].message == "'topic-check' section not found in context configuration"
@pytest.mark.asyncio
@ -59,9 +65,9 @@ async def test_base_disabled_config(logdog):
authorised = await plugin.topic_filtering()
assert authorised is True
# Should NOT have printed warnings
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Should NOT have printed warnings
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
@pytest.mark.asyncio
@ -76,9 +82,9 @@ async def test_base_enabled_config(logdog):
authorised = await plugin.topic_filtering()
assert authorised is True
# Should NOT have printed warnings
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Should NOT have printed warnings
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Taboo plug-in
@ -95,9 +101,11 @@ async def test_taboo_empty_config(logdog):
plugin = TopicTabooPlugin(context)
assert (await plugin.topic_filtering()) is False
# Warning messages are no longer generated
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Warning messages are only generated if using deprecated plugin configuration on initial load
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 1
assert log_records[0].levelno == logging.WARNING
assert log_records[0].message == "'topic-check' section not found in context configuration"
@pytest.mark.asyncio
@ -119,13 +127,17 @@ async def test_taboo_disabled(logdog):
assert len(log_records) == 0
@pytest.mark.parametrize("test_config", [
({"topic-check": {"enabled": True}}),
(TopicTabooPlugin.Config())
])
@pytest.mark.asyncio
async def test_taboo_not_taboo_topic(logdog):
async def test_taboo_not_taboo_topic(logdog, test_config):
"""Check TopicTabooPlugin returns true if topic not taboo."""
with logdog() as pile:
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {"topic-check": {"enabled": True}}
context.config = test_config
session = Session()
session.username = "anybody"
@ -138,13 +150,17 @@ async def test_taboo_not_taboo_topic(logdog):
assert len(log_records) == 0
@pytest.mark.parametrize("test_config", [
({"topic-check": {"enabled": True}}),
(TopicTabooPlugin.Config())
])
@pytest.mark.asyncio
async def test_taboo_anon_taboo_topic(logdog):
async def test_taboo_anon_taboo_topic(logdog, test_config):
"""Check TopicTabooPlugin returns false if topic is taboo and session is anonymous."""
with logdog() as pile:
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {"topic-check": {"enabled": True}}
context.config = test_config
session = Session()
session.username = ""
@ -157,13 +173,17 @@ async def test_taboo_anon_taboo_topic(logdog):
assert len(log_records) == 0
@pytest.mark.parametrize("test_config", [
({"topic-check": {"enabled": True}}),
(TopicTabooPlugin.Config())
])
@pytest.mark.asyncio
async def test_taboo_notadmin_taboo_topic(logdog):
async def test_taboo_notadmin_taboo_topic(logdog, test_config):
"""Check TopicTabooPlugin returns false if topic is taboo and user is not "admin"."""
with logdog() as pile:
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {"topic-check": {"enabled": True}}
context.config = test_config
session = Session()
session.username = "notadmin"
@ -175,14 +195,17 @@ async def test_taboo_notadmin_taboo_topic(logdog):
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
@pytest.mark.parametrize("test_config", [
({"topic-check": {"enabled": True}}),
(TopicTabooPlugin.Config())
])
@pytest.mark.asyncio
async def test_taboo_admin_taboo_topic(logdog):
async def test_taboo_admin_taboo_topic(logdog, test_config):
"""Check TopicTabooPlugin returns true if topic is taboo and user is "admin"."""
with logdog() as pile:
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {"topic-check": {"enabled": True}}
context.config = test_config
session = Session()
session.username = "admin"
@ -251,9 +274,11 @@ async def test_taclp_empty_config(logdog):
plugin = TopicAccessControlListPlugin(context)
assert (await plugin.topic_filtering()) is False
# Warning messages are no longer generated
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 0
# Warning messages are only generated if using deprecated plugin configuration on initial load
log_records = list(pile.drain(name="testlog"))
assert len(log_records) == 1
assert log_records[0].levelno == logging.WARNING
assert log_records[0].message == "'topic-check' section not found in context configuration"
@pytest.mark.asyncio
@ -275,15 +300,19 @@ async def test_taclp_true_disabled(logdog):
assert authorised is True
@pytest.mark.parametrize("test_config", [
({"topic-check": {"enabled": True}}),
(TopicAccessControlListPlugin.Config())
])
@pytest.mark.asyncio
async def test_taclp_true_no_pub_acl(logdog):
async def test_taclp_true_no_pub_acl(logdog, test_config):
"""Check TopicAccessControlListPlugin returns true if action=publish and no publish-acl given.
(This is for backward-compatibility with existing installations.).
"""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {"topic-check": {"enabled": True}}
context.config = test_config
session = Session()
session.username = "user"
@ -297,17 +326,23 @@ async def test_taclp_true_no_pub_acl(logdog):
assert authorised is True
@pytest.mark.asyncio
async def test_taclp_false_sub_no_topic(logdog):
"""Check TopicAccessControlListPlugin returns false user there is no topic."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {
@pytest.mark.parametrize("test_config", [
({
"topic-check": {
"enabled": True,
"acl": {"anotheruser": ["allowed/topic", "another/allowed/topic/#"]},
},
}
}),
(TopicAccessControlListPlugin.Config(
acl={"anotheruser": ["allowed/topic", "another/allowed/topic/#"]}
))
])
@pytest.mark.asyncio
async def test_taclp_false_sub_no_topic(logdog, test_config):
"""Check TopicAccessControlListPlugin returns false user there is no topic."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = test_config
session = Session()
session.username = "user"
@ -321,17 +356,23 @@ async def test_taclp_false_sub_no_topic(logdog):
assert authorised is False
@pytest.mark.asyncio
async def test_taclp_false_sub_unknown_user(logdog):
"""Check TopicAccessControlListPlugin returns false user is not listed in ACL."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {
@pytest.mark.parametrize("test_config", [
({
"topic-check": {
"enabled": True,
"acl": {"anotheruser": ["allowed/topic", "another/allowed/topic/#"]},
},
}
}),
(TopicAccessControlListPlugin.Config(
acl={"anotheruser": ["allowed/topic", "another/allowed/topic/#"]}
))
])
@pytest.mark.asyncio
async def test_taclp_false_sub_unknown_user(logdog, test_config):
"""Check TopicAccessControlListPlugin returns false user is not listed in ACL."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = test_config
session = Session()
session.username = "user"
@ -345,17 +386,23 @@ async def test_taclp_false_sub_unknown_user(logdog):
assert authorised is False
@pytest.mark.asyncio
async def test_taclp_false_sub_no_permission(logdog):
"""Check TopicAccessControlListPlugin returns false if "acl" does not list allowed topic."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {
@pytest.mark.parametrize("test_config", [
({
"topic-check": {
"enabled": True,
"acl": {"user": ["allowed/topic", "another/allowed/topic/#"]},
},
}
}),
(TopicAccessControlListPlugin.Config(
acl={"user": ["allowed/topic", "another/allowed/topic/#"]}
))
])
@pytest.mark.asyncio
async def test_taclp_false_sub_no_permission(logdog, test_config):
"""Check TopicAccessControlListPlugin returns false if "acl" does not list allowed topic."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = test_config
session = Session()
session.username = "user"
@ -368,18 +415,23 @@ async def test_taclp_false_sub_no_permission(logdog):
)
assert authorised is False
@pytest.mark.asyncio
async def test_taclp_true_sub_permission(logdog):
"""Check TopicAccessControlListPlugin returns true if "acl" lists allowed topic."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {
@pytest.mark.parametrize("test_config", [
({
"topic-check": {
"enabled": True,
"acl": {"user": ["allowed/topic", "another/allowed/topic/#"]},
},
}
}),
(TopicAccessControlListPlugin.Config(
acl={"user": ["allowed/topic", "another/allowed/topic/#"]}
))
])
@pytest.mark.asyncio
async def test_taclp_true_sub_permission(logdog, test_config):
"""Check TopicAccessControlListPlugin returns true if "acl" lists allowed topic."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = test_config
session = Session()
session.username = "user"
@ -393,17 +445,23 @@ async def test_taclp_true_sub_permission(logdog):
assert authorised is True
@pytest.mark.asyncio
async def test_taclp_true_pub_permission(logdog):
"""Check TopicAccessControlListPlugin returns true if "publish-acl" lists allowed topic for publish action."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {
@pytest.mark.parametrize("test_config", [
({
"topic-check": {
"enabled": True,
"publish-acl": {"user": ["allowed/topic", "another/allowed/topic/#"]},
},
}
}),
(TopicAccessControlListPlugin.Config(
publish_acl={"user": ["allowed/topic", "another/allowed/topic/#"]}
))
])
@pytest.mark.asyncio
async def test_taclp_true_pub_permission(logdog, test_config):
"""Check TopicAccessControlListPlugin returns true if "publish-acl" lists allowed topic for publish action."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = test_config
session = Session()
session.username = "user"
@ -417,17 +475,23 @@ async def test_taclp_true_pub_permission(logdog):
assert authorised is True
@pytest.mark.asyncio
async def test_taclp_true_anon_sub_permission(logdog):
"""Check TopicAccessControlListPlugin handles anonymous users."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = {
@pytest.mark.parametrize("test_config", [
({
"topic-check": {
"enabled": True,
"acl": {"anonymous": ["allowed/topic", "another/allowed/topic/#"]},
},
}
}),
(TopicAccessControlListPlugin.Config(
acl={"anonymous": ["allowed/topic", "another/allowed/topic/#"]}
))
])
@pytest.mark.asyncio
async def test_taclp_true_anon_sub_permission(logdog, test_config):
"""Check TopicAccessControlListPlugin handles anonymous users."""
context = BaseContext()
context.logger = logging.getLogger("testlog")
context.config = test_config
session = Session()
session.username = None