diff --git a/amqtt/broker.py b/amqtt/broker.py index 5eb1ddb..f096417 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -3,7 +3,6 @@ from asyncio import CancelledError, futures from collections import deque from collections.abc import Generator import copy -from enum import Enum from functools import partial import logging from pathlib import Path @@ -23,6 +22,7 @@ from amqtt.adapters import ( WebSocketsWriter, WriterAdapter, ) +from amqtt.contexts import Action, BaseContext from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session @@ -30,7 +30,7 @@ from amqtt.utils import format_client_message, gen_client_id, read_yaml_config from .events import BrokerEvents from .mqtt.disconnect import DisconnectPacket -from .plugins.manager import BaseContext, PluginManager +from .plugins.manager import PluginManager _CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]] _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] @@ -44,13 +44,6 @@ DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 -class Action(Enum): - """Actions issued by the broker.""" - - SUBSCRIBE = "subscribe" - PUBLISH = "publish" - - class RetainedApplicationMessage(ApplicationMessage): __slots__ = ("data", "qos", "source_session", "topic") @@ -164,6 +157,10 @@ class Broker: self.logger = logging.getLogger(__name__) self.config = copy.deepcopy(_defaults or {}) if config is not None: + # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config + if ("auth" in config or "topic-check" in config) and "plugins" not in config: + # set to None so that the config isn't updated with the new-style default plugin list + config["plugins"] = None # type: ignore[assignment] self.config.update(config) self._build_listeners_config(self.config) @@ -766,13 +763,7 @@ class Broker: :param action: What is being done with the topic? subscribe or publish :return: """ - topic_config = self.config.get("topic-check", {}) - enabled = False - - if isinstance(topic_config, dict): - enabled = topic_config.get("enabled", False) - - if not enabled: + if not self.plugins_manager.is_topic_filtering_enabled(): return True results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action) diff --git a/amqtt/client.py b/amqtt/client.py index 33ad26a..70781a1 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -19,11 +19,12 @@ from amqtt.adapters import ( WebSocketsReader, WebSocketsWriter, ) +from amqtt.contexts import BaseContext from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError from amqtt.mqtt.connack import CONNECTION_ACCEPTED from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler -from amqtt.plugins.manager import BaseContext, PluginManager +from amqtt.plugins.manager import PluginManager from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session from amqtt.utils import gen_client_id, read_yaml_config diff --git a/amqtt/contexts.py b/amqtt/contexts.py new file mode 100644 index 0000000..3a78b44 --- /dev/null +++ b/amqtt/contexts.py @@ -0,0 +1,22 @@ +from enum import Enum +import logging +from typing import TYPE_CHECKING, Any + +_LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: + import asyncio + + +class BaseContext: + def __init__(self) -> None: + self.loop: asyncio.AbstractEventLoop | None = None + self.logger: logging.Logger = _LOGGER + self.config: dict[str, Any] | None = None + + +class Action(Enum): + """Actions issued by the broker.""" + + SUBSCRIBE = "subscribe" + PUBLISH = "publish" diff --git a/amqtt/errors.py b/amqtt/errors.py index 4c653a8..5d15c88 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -29,11 +29,16 @@ class PluginError(Exception): class PluginImportError(PluginError): - def __init__(self, plugin: Any) -> None: - super().__init__(f"Plugin import failed: {plugin!r}") + """Exceptions thrown when loading plugin.""" + + +class PluginCoroError(PluginError): + """Exceptions thrown when loading a plugin with a non-async call method.""" class PluginInitError(PluginError): + """Exceptions thrown when initializing plugin.""" + def __init__(self, plugin: Any) -> None: super().__init__(f"Plugin init failed: {plugin!r}") diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 666e7f5..eff191c 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -20,6 +20,7 @@ import logging from typing import Generic, TypeVar, cast from amqtt.adapters import ReaderAdapter, WriterAdapter +from amqtt.contexts import BaseContext from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError from amqtt.events import MQTTEvents from amqtt.mqtt import packet_class @@ -57,7 +58,7 @@ from amqtt.mqtt.suback import SubackPacket from amqtt.mqtt.subscribe import SubscribePacket from amqtt.mqtt.unsuback import UnsubackPacket from amqtt.mqtt.unsubscribe import UnsubscribePacket -from amqtt.plugins.manager import BaseContext, PluginManager +from amqtt.plugins.manager import PluginManager from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session C = TypeVar("C", bound=BaseContext) diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index 954c403..aef38e5 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -1,8 +1,10 @@ +from dataclasses import dataclass, field from pathlib import Path from passlib.apps import custom_app_context as pwd_context from amqtt.broker import BrokerContext +from amqtt.contexts import BaseContext from amqtt.plugins.base import BaseAuthPlugin from amqtt.session import Session @@ -12,12 +14,17 @@ _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line class AnonymousAuthPlugin(BaseAuthPlugin): """Authentication plugin allowing anonymous access.""" + def __init__(self, context: BaseContext) -> None: + super().__init__(context) + + # Default to allowing anonymous + self._allow_anonymous = self._get_config_option("allow-anonymous", True) # noqa: FBT003 + async def authenticate(self, *, session: Session) -> bool: authenticated = await super().authenticate(session=session) if authenticated: - # Default to allowing anonymous - allow_anonymous = self.auth_config.get("allow-anonymous", True) if isinstance(self.auth_config, dict) else True - if allow_anonymous: + + if self._allow_anonymous: self.context.logger.debug("Authentication success: config allows anonymous") return True @@ -27,6 +34,12 @@ class AnonymousAuthPlugin(BaseAuthPlugin): self.context.logger.debug("Authentication failure: session has no username") return False + @dataclass + class Config: + """Allow empty username.""" + + allow_anonymous: bool = field(default=True) + class FileAuthPlugin(BaseAuthPlugin): """Authentication plugin based on a file-stored user database.""" @@ -38,7 +51,7 @@ class FileAuthPlugin(BaseAuthPlugin): def _read_password_file(self) -> None: """Read the password file and populates the user dictionary.""" - password_file = self.auth_config.get("password-file") if isinstance(self.auth_config, dict) else None + password_file = self._get_config_option("password-file", None) if not password_file: self.context.logger.warning("Configuration parameter 'password-file' not found") return @@ -87,3 +100,9 @@ class FileAuthPlugin(BaseAuthPlugin): self.context.logger.debug(f"Authentication failure: password mismatch for user '{session.username}'") return False + + @dataclass + class Config: + """Path to the properly encoded password file.""" + + password_file: str | None = None diff --git a/amqtt/plugins/base.py b/amqtt/plugins/base.py index 90ee1b3..d9abcfc 100644 --- a/amqtt/plugins/base.py +++ b/amqtt/plugins/base.py @@ -1,7 +1,7 @@ +from dataclasses import dataclass, is_dataclass from typing import Any, Generic, TypeVar -from amqtt.broker import Action -from amqtt.plugins.manager import BaseContext +from amqtt.contexts import Action, BaseContext from amqtt.session import Session C = TypeVar("C", bound=BaseContext) @@ -24,6 +24,20 @@ class BasePlugin(Generic[C]): return None return section_config + def _get_config_option(self, option_name: str, default: Any=None) -> Any: + if not self.context.config: + return default + + if is_dataclass(self.context.config): + return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] + if option_name in self.context.config: + return self.context.config[option_name] + return default + + @dataclass + class Config: + """Override to define the configuration and defaults for plugin.""" + async def close(self) -> None: """Override if plugin needs to clean up resources upon shutdown.""" @@ -35,9 +49,19 @@ class BaseTopicPlugin(BasePlugin[BaseContext]): super().__init__(context) self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check") - if self.topic_config is None: + if not bool(self.topic_config) and not is_dataclass(self.context.config): self.context.logger.warning("'topic-check' section not found in context configuration") + def _get_config_option(self, option_name: str, default: Any=None) -> Any: + if not self.context.config: + return default + + if is_dataclass(self.context.config): + return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] + if self.topic_config and option_name in self.topic_config: + return self.topic_config[option_name] + return default + async def topic_filtering( self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None ) -> bool: @@ -52,23 +76,31 @@ class BaseTopicPlugin(BasePlugin[BaseContext]): bool: `True` if topic is allowed, `False` otherwise """ - if not self.topic_config: - # auth config section not found - self.context.logger.warning("'topic-check' section not found in context configuration") - return False - return True + return bool(self.topic_config) or is_dataclass(self.context.config) class BaseAuthPlugin(BasePlugin[BaseContext]): """Base class for authentication plugins.""" + def _get_config_option(self, option_name: str, default: Any=None) -> Any: + if not self.context.config: + return default + + if is_dataclass(self.context.config): + return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] + if self.auth_config and option_name in self.auth_config: + return self.auth_config[option_name] + return default + def __init__(self, context: BaseContext) -> None: super().__init__(context) self.auth_config: dict[str, Any] | None = self._get_config_section("auth") - if not self.auth_config: + if not bool(self.auth_config) and not is_dataclass(self.context.config): + # auth config section not found and Config dataclass not provided self.context.logger.warning("'auth' section not found in context configuration") + async def authenticate(self, *, session: Session) -> bool | None: """Logic for session authentication. @@ -80,8 +112,4 @@ class BaseAuthPlugin(BasePlugin[BaseContext]): - `None` if authentication can't be achieved (then plugin result is then ignored) """ - if not self.auth_config: - # auth config section not found - self.context.logger.warning("'auth' section not found in context configuration") - return False - return True + return bool(self.auth_config) or is_dataclass(self.context.config) diff --git a/amqtt/plugins/logging_amqtt.py b/amqtt/plugins/logging_amqtt.py index 595684d..e2d3433 100644 --- a/amqtt/plugins/logging_amqtt.py +++ b/amqtt/plugins/logging_amqtt.py @@ -3,11 +3,11 @@ from functools import partial import logging from typing import Any, TypeAlias +from amqtt.contexts import BaseContext from amqtt.events import BrokerEvents from amqtt.mqtt import MQTTPacket from amqtt.mqtt.packet import MQTTFixedHeader, MQTTPayload, MQTTVariableHeader from amqtt.plugins.base import BasePlugin -from amqtt.plugins.manager import BaseContext from amqtt.session import Session PACKET: TypeAlias = MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader] diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index c36ecc5..c98ef8a 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -1,4 +1,4 @@ -__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"] +__all__ = ["PluginManager", "get_plugin_manager"] import asyncio from collections import defaultdict @@ -8,17 +8,17 @@ import copy from importlib.metadata import EntryPoint, EntryPoints, entry_points from inspect import iscoroutinefunction import logging -from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar +from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast +import warnings -from amqtt.errors import PluginImportError, PluginInitError +from dacite import Config as DaciteConfig, DaciteError, from_dict + +from amqtt.contexts import Action, BaseContext +from amqtt.errors import PluginCoroError, PluginImportError, PluginInitError, PluginLoadError from amqtt.events import BrokerEvents, Events, MQTTEvents +from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin from amqtt.session import Session - -_LOGGER = logging.getLogger(__name__) - -if TYPE_CHECKING: - from amqtt.broker import Action - from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin +from amqtt.utils import import_string class Plugin(NamedTuple): @@ -39,11 +39,11 @@ def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None": return plugins_manager.get(namespace) -class BaseContext: - def __init__(self) -> None: - self.loop: asyncio.AbstractEventLoop | None = None - self.logger: logging.Logger = _LOGGER - self.config: dict[str, Any] | None = None +def safe_issubclass(sub_class: Any, super_class: Any) -> bool: + try: + return issubclass(sub_class, super_class) + except TypeError: + return False AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]] @@ -70,6 +70,9 @@ class PluginManager(Generic[C]): self._auth_plugins: list[BaseAuthPlugin] = [] self._topic_plugins: list[BaseTopicPlugin] = [] self._event_plugin_callbacks: dict[str, list[AsyncFunc]] = defaultdict(list) + self._is_topic_filtering_enabled = False + self._is_auth_filtering_enabled = False + self._load_plugins(namespace) self._fired_events: list[asyncio.Future[Any]] = [] plugins_manager[namespace] = self @@ -78,10 +81,41 @@ class PluginManager(Generic[C]): def app_context(self) -> BaseContext: return self.context - def _load_plugins(self, namespace: str) -> None: + def _load_plugins(self, namespace: str | None = None) -> None: + if self.app_context.config and self.app_context.config.get("plugins", None) is not None: + if "auth" in self.app_context.config: + self.logger.warning("Loading plugins from config will ignore 'auth' section of config") + if "topic-check" in self.app_context.config: + self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config") + + plugin_list: list[Any] = self.app_context.config.get("plugins", []) + self._load_str_plugins(plugin_list) + else: + if not namespace: + msg = "Namespace needs to be provided for EntryPoint plugin definitions" + raise PluginLoadError(msg) + + warnings.warn( + "Loading plugins from EntryPoints is deprecated and will be removed in a future version." + " Use `plugins` section of config instead.", + DeprecationWarning, + stacklevel=2 + ) + + self._load_ep_plugins(namespace) + + for plugin in self._plugins: + for event in list(BrokerEvents) + list(MQTTEvents): + if awaitable := getattr(plugin, f"on_{event}", None): + if not iscoroutinefunction(awaitable): + msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'" + raise PluginImportError(msg) + self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'") + self._event_plugin_callbacks[event].append(awaitable) + + def _load_ep_plugins(self, namespace:str) -> None: self.logger.debug(f"Loading plugins for namespace {namespace}") - auth_filter_list = [] topic_filter_list = [] if self.app_context.config and "auth" in self.app_context.config: @@ -107,15 +141,6 @@ class PluginManager(Generic[C]): self._topic_plugins.append(ep_plugin.object) self.logger.debug(f" Plugin {item.name} ready") - for plugin in self._plugins: - for event in list(BrokerEvents) + list(MQTTEvents): - if awaitable := getattr(plugin, f"on_{event}", None): - if not iscoroutinefunction(awaitable): - msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'" - raise PluginImportError(msg) - self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'") - self._event_plugin_callbacks[event].append(awaitable) - def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None: try: self.logger.debug(f" Loading plugin {ep!s}") @@ -136,6 +161,70 @@ class PluginManager(Generic[C]): self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True) raise PluginInitError(ep) from e + def _load_str_plugins(self, plugin_list: list[Any]) -> None: + + self.logger.info("Loading plugins from config") + self._is_topic_filtering_enabled = True + self._is_auth_filtering_enabled = True + for plugin_info in plugin_list: + + if isinstance(plugin_info, dict): + if len(plugin_info.keys()) > 1: + msg = f"config file should have only one key: {plugin_info.keys()}" + raise ValueError(msg) + plugin_path = next(iter(plugin_info.keys())) + plugin_cfg = plugin_info[plugin_path] + plugin = self._load_str_plugin(plugin_path, plugin_cfg) + elif isinstance(plugin_info, str): + plugin = self._load_str_plugin(plugin_info, {}) + else: + msg = "Unexpected entry in plugins config" + raise PluginLoadError(msg) + + self._plugins.append(plugin) + if isinstance(plugin, BaseAuthPlugin): + if not iscoroutinefunction(plugin.authenticate): + msg = f"Auth plugin {plugin_info} has non-async authenticate method." + raise PluginCoroError(msg) + self._auth_plugins.append(plugin) + if isinstance(plugin, BaseTopicPlugin): + if not iscoroutinefunction(plugin.topic_filtering): + msg = f"Topic plugin {plugin_info} has non-async topic_filtering method." + raise PluginCoroError(msg) + self._topic_plugins.append(plugin) + + def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]": + + try: + plugin_class: Any = import_string(plugin_path) + except ImportError as ep: + msg = f"Plugin import failed: {plugin_path}" + raise PluginImportError(msg) from ep + + if not safe_issubclass(plugin_class, BasePlugin): + msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'" + raise PluginLoadError(msg) + + plugin_context = copy.copy(self.app_context) + plugin_context.logger = self.logger.getChild(plugin_class.__name__) + try: + plugin_context.config = from_dict(data_class=plugin_class.Config, + data=plugin_cfg or {}, + config=DaciteConfig(strict=True)) + except DaciteError as e: + raise PluginLoadError from e + except TypeError as e: + msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass." + raise PluginLoadError(msg) from e + + try: + pc = plugin_class(plugin_context) + self.logger.debug(f"Loading plugin {plugin_path}") + return cast("BasePlugin[C]", pc) + except Exception as e: + self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True) + raise PluginInitError(plugin_class) from e + def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]: """Get a plugin by its name from the plugins loaded for the current namespace. @@ -147,6 +236,12 @@ class PluginManager(Generic[C]): return p return None + def is_topic_filtering_enabled(self) -> bool: + topic_config = self.app_context.config.get("topic-check", {}) if self.app_context.config else {} + if isinstance(topic_config, dict): + return topic_config.get("enabled", False) or self._is_topic_filtering_enabled + return False or self._is_topic_filtering_enabled + async def close(self) -> None: """Free PluginManager resources and cancel pending event methods.""" await self.map_plugin_close() diff --git a/amqtt/plugins/persistence.py b/amqtt/plugins/persistence.py index 8baf2d9..ee79d33 100644 --- a/amqtt/plugins/persistence.py +++ b/amqtt/plugins/persistence.py @@ -2,7 +2,7 @@ import json import sqlite3 from typing import Any -from amqtt.plugins.manager import BaseContext +from amqtt.contexts import BaseContext from amqtt.session import Session diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index b0ba07f..ea022f5 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -1,5 +1,6 @@ import asyncio from collections import deque # pylint: disable=C0412 +from dataclasses import dataclass from typing import Any, SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412 import psutil @@ -70,8 +71,11 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]): # Broker statistics initialization self._stats: dict[str, int] = {} self._sys_handle: asyncio.Handle | None = None + + self._sys_interval: int = 0 self._current_process = psutil.Process() + def _clear_stats(self) -> None: """Initialize broker statistics data structures.""" for stat in ( @@ -112,21 +116,21 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]): # Start $SYS topics management try: - sys_interval: int = 0 - x = self.context.config.get("sys_interval") if self.context.config is not None else None - if isinstance(x, str | Buffer | SupportsInt | SupportsIndex): - sys_interval = int(x) - if sys_interval > 0: - self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds") + self._sys_interval = self._get_config_option("sys_interval", None) + if isinstance(self._sys_interval, str | Buffer | SupportsInt | SupportsIndex): + self._sys_interval = int(self._sys_interval) + + if self._sys_interval > 0: + self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds") self._sys_handle = ( - self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics) + self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics) if self.context.loop is not None else None ) else: self.context.logger.debug("$SYS disabled") except KeyError: - pass + self.context.logger.debug("could not find 'sys_interval' key: {e!r}") # 'sys_interval' config parameter not found async def on_broker_pre_shutdown(self) -> None: @@ -194,15 +198,9 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]): tasks.popleft() # Reschedule - sys_interval: int = 0 - x = self.context.config.get("sys_interval") if self.context.config is not None else None - if isinstance(x, str | Buffer | SupportsInt | SupportsIndex): - sys_interval = int(x) - self.context.logger.debug("Broadcasting $SYS topics") - - self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds") + self.context.logger.debug(f"Broadcast $SYS topics again in {self._sys_interval} seconds.") self._sys_handle = ( - self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics) + self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics) if self.context.loop is not None else None ) @@ -237,3 +235,9 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]): """Handle broker client disconnection.""" self._stats[STAT_CLIENTS_CONNECTED] -= 1 self._stats[STAT_CLIENTS_DISCONNECTED] += 1 + + @dataclass + class Config: + """Configuration struct for plugin.""" + + sys_interval: int = 0 diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index c61e313..c5183ea 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -1,8 +1,8 @@ +from dataclasses import dataclass, field from typing import Any -from amqtt.broker import Action +from amqtt.contexts import Action, BaseContext from amqtt.plugins.base import BaseTopicPlugin -from amqtt.plugins.manager import BaseContext from amqtt.session import Session @@ -52,26 +52,34 @@ class TopicAccessControlListPlugin(BaseTopicPlugin): return False # hbmqtt and older amqtt do not support publish filtering - if action == Action.PUBLISH and self.topic_config is not None and "publish-acl" not in self.topic_config: + if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}): # maintain backward compatibility, assume permitted return True req_topic = topic if not req_topic: - return False + return False\ username = session.username if session else None if username is None: username = "anonymous" acl: dict[str, Any] = {} - if self.topic_config is not None and action == Action.PUBLISH: - acl = self.topic_config.get("publish-acl", {}) - elif self.topic_config is not None and action == Action.SUBSCRIBE: - acl = self.topic_config.get("acl", {}) + match action: + case Action.PUBLISH: + acl = self._get_config_option("publish-acl", {}) + case Action.SUBSCRIBE: + acl = self._get_config_option("acl", {}) allowed_topics = acl.get(username, []) if not allowed_topics: return False return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics) + + @dataclass + class Config: + """Mappings of username and list of approved topics.""" + + publish_acl: dict[str, list[str]] = field(default_factory=dict) + acl: dict[str, list[str]] = field(default_factory=dict) diff --git a/amqtt/scripts/default_broker.yaml b/amqtt/scripts/default_broker.yaml index 1ec1016..d05b7f8 100644 --- a/amqtt/scripts/default_broker.yaml +++ b/amqtt/scripts/default_broker.yaml @@ -3,10 +3,10 @@ listeners: default: type: tcp bind: 0.0.0.0:1883 -sys_interval: 20 -auth: - plugins: - - auth_anonymous - allow-anonymous: true -topic-check: - enabled: False +plugins: + - amqtt.plugins.logging_amqtt.EventLoggerPlugin: + - amqtt.plugins.logging_amqtt.PacketLoggerPlugin: + - amqtt.plugins.authentication.AnonymousAuthPlugin: + allow_anonymous: true + - amqtt.plugins.sys.broker.BrokerSysPlugin: + sys_interval: 20 diff --git a/amqtt/scripts/default_client.yaml b/amqtt/scripts/default_client.yaml index 2feb944..3921e49 100644 --- a/amqtt/scripts/default_client.yaml +++ b/amqtt/scripts/default_client.yaml @@ -7,4 +7,6 @@ auto_reconnect: true reconnect_max_interval: 10 reconnect_retries: 2 broker: - uri: "mqtt://127.0.0.1" \ No newline at end of file + uri: "mqtt://127.0.0.1" +plugins: + - amqtt.plugins.logging_amqtt.PacketLoggerPlugin: diff --git a/amqtt/utils.py b/amqtt/utils.py index ca14ad2..af5c0b9 100644 --- a/amqtt/utils.py +++ b/amqtt/utils.py @@ -1,9 +1,11 @@ from __future__ import annotations +from importlib import import_module import logging from pathlib import Path import secrets import string +import sys import typing from typing import Any @@ -48,3 +50,39 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None: except yaml.YAMLError: logger.exception(f"Invalid config_file {config_file}") return None + + +def cached_import(module_path: str, class_name: str | None = None) -> Any: + """Return cached import of a class from a module path (or retrieve, cache and then return).""" + # Check whether module is loaded and fully initialized. + if not ((module := sys.modules.get(module_path)) + and (spec := getattr(module, "__spec__", None)) + and getattr(spec, "_initializing", False) is False): + module = import_module(module_path) + if class_name: + return getattr(module, class_name) + return module + + +def import_string(dotted_path: str) -> Any: + """Import a dotted module path. + + Returns: + attribute/class designated by the last name in the path + + Raises: + ImportError (if the import failed) + + """ + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as err: + msg = f"{dotted_path} doesn't look like a module path" + raise ImportError(msg) from err + + try: + return cached_import(module_path, class_name) + except AttributeError as err: + msg = f'Module "{module_path}" does not define a "{class_name}" attribute/class' + + raise ImportError(msg) from err diff --git a/amqtt/version.py b/amqtt/version.py deleted file mode 100644 index 2b41d19..0000000 --- a/amqtt/version.py +++ /dev/null @@ -1,75 +0,0 @@ -try: - from datetime import UTC, datetime -except ImportError: - from datetime import datetime, timezone - - UTC = timezone.utc - -import logging -from pathlib import Path -import shutil -import subprocess -import warnings - -import amqtt - -logger = logging.getLogger(__name__) - - -def get_version() -> str: - """Return the version of the amqtt package. - - This function is deprecated. Use amqtt.__version__ instead. - """ - warnings.warn( - "amqtt.version.get_version() is deprecated, use amqtt.__version__ instead", - stacklevel=3, # Adjusted stack level to better reflect the caller - ) - return amqtt.__version__ - - -def get_git_changeset() -> str | None: - """Return a numeric identifier of the latest git changeset. - - The result is the UTC timestamp of the changeset in YYYYMMDDHHMMSS format. - This value isn't guaranteed to be unique, but collisions are very unlikely, - so it's sufficient for generating the development version numbers. - """ - # Define the repository directory (two levels above the current script) - repo_dir = Path(__file__).resolve().parent.parent - - # Ensure the directory exists and is valid - if not repo_dir.is_dir(): - logger.error(f"Invalid directory: {repo_dir} is not a valid directory") - return None - - # Use the system's PATH to locate 'git', or define the full path if necessary - git_path = "git" # Assuming git is available in the system PATH - - # Ensure 'git' is executable and available - if not shutil.which(git_path): - logger.error(f"{git_path} is not found in the system PATH.") - return None - - # Call git log to get the latest changeset timestamp - try: - with subprocess.Popen( # noqa: S603 - [git_path, "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=repo_dir, - universal_newlines=True, - ) as git_log: - timestamp_str, stderr = git_log.communicate() - - if git_log.returncode != 0: - logger.error(f"Git command failed with error: {stderr}") - return None - - # Convert the timestamp to a datetime object - timestamp = datetime.fromtimestamp(int(timestamp_str), tz=UTC) - return timestamp.strftime("%Y%m%d%H%M%S") - - except Exception: - logger.exception("An error occurred while retrieving the git changeset.") - return None diff --git a/docs/custom_plugins.md b/docs/custom_plugins.md index 34eb4ec..da2d56b 100644 --- a/docs/custom_plugins.md +++ b/docs/custom_plugins.md @@ -1,35 +1,79 @@ +from dataclasses import dataclass + # Custom Plugins -With the aMQTT Broker plugins framework, one can add additional functionality to the broker without -having to subclass or rewrite any of the core broker logic. To define a custom list of plugins to be loaded, -add this section to your `pyproject.toml`" +With the aMQTT plugins framework, one can add additional functionality to the client or broker without +having to rewrite any of the core logic. -```toml -[project.entry-points."mypackage.mymodule.plugins"] -plugin_alias = "module.submodule.file:ClassName" -``` - -and specify the namespace when instantiating the broker: +To create a custom plugin, subclass from `BasePlugin` (client or broker) or `BaseAuthPlugin` (broker only) +or `BaseTopicPlugin` (broker only). Each custom plugin may define settings specific to itself by creating +a nested (or inner) `dataclass` named `Config` which declares each option and a default value (if applicable). A +plugin's configuration dataclass will be type-checked and made available from within the `self.context` instance variable. ```python -from amqtt.broker import Broker +from dataclasses import dataclass, field +from amqtt.plugins.base import BasePlugin +from amqtt.contexts import BaseContext -broker = Broker(plugin_namespace='mypackage.mymodule.plugins') + +class OneClassName(BasePlugin[BaseContext]): + """This is a plugin with no functionality""" + + +class TwoClassName(BasePlugin[BaseContext]): + """This is a plugin with configuration options.""" + def __init__(self, context: BaseContext): + super().__init__(context) + my_option_one: str = self.context.config.option1 + + @dataclass + class Config: + option1: int + option3: str = field(default="my_default_value") ``` -Each plugin has access to the full configuration file through the provided `BaseContext` and can define -its own variables to configure its behavior. +This plugin class then should be added to the configuration file of the broker or client (or to the `config` +dictionary passed to the `Broker` or `MQTTClient`). + +```yaml +... +... +plugins: + - module.submodule.file.OneClassName: + - module.submodule.file.TwoClassName: + option1: 123 +``` + +??? warning "Deprecated: activating plugins using `EntryPoints`" + With the aMQTT plugins framework, one can add additional functionality to the client or broker without + having to rewrite any of the core logic. To define a custom list of plugins to be loaded, add this section + to your `pyproject.toml`" + + ```toml + [project.entry-points."mypackage.mymodule.plugins"] + plugin_alias = "module.submodule.file:ClassName" + ``` + + Each plugin has access to the full configuration file through the provided `BaseContext` and can define its own + variables to configure its behavior. ::: amqtt.plugins.base.BasePlugin ## Events -Plugins that are defined in the`project.entry-points` are notified of events if the subclass -implements one or more of these methods: +All plugins are notified of events if the `BasePlugin` subclass implements one or more of these methods: -- `async def on_mqtt_packet_sent(self, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None` -- `async def on_mqtt_packet_received(self, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None` +### Client and Broker + +- `async def on_mqtt_packet_sent(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None` +- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None` + +### Client Only + +none + +### Broker Only - `async def on_broker_pre_start() -> None` - `async def on_broker_post_start() -> None` @@ -47,32 +91,22 @@ implements one or more of these methods: ## Authentication Plugins -Of the plugins listed in `project.entry-points`, one or more can be used to validate client sessions -by specifying their alias in `auth` > `plugins` section of the config: - -```yaml -auth: - plugins: - - plugin_alias_name -``` - -These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method. +In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin` +is used by the aMQTT `Broker` to determine if a connection from a client is allowed by +implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise. ::: amqtt.plugins.base.BaseAuthPlugin ## Topic Filter Plugins -Of the plugins listed in `project.entry-points`, one or more can be used to determine topic access -by specifying their alias in `topic-check` > `plugins` section of the config: - -```yaml -topic-check: - enable: True - plugins: - - plugin_alias_name -``` - -These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method. - +In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin` +is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE) +messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or +`False` otherwise. ::: amqtt.plugins.base.BaseTopicPlugin + + +!!! note + A custom plugin class can subclass from both `BaseAuthPlugin` and `BaseTopicPlugin` as long it defines + both the `authenticate` and `topic_filtering` method. diff --git a/docs/packaged_plugins.md b/docs/packaged_plugins.md index d3dca4d..5adf827 100644 --- a/docs/packaged_plugins.md +++ b/docs/packaged_plugins.md @@ -1,48 +1,99 @@ # Existing Plugins -With the aMQTT Broker plugins framework, one can add additional functionality without -having to rewrite core logic. Plugins loaded by default are specified in `pyproject.toml`: +With the aMQTT plugins framework, one can add additional functionality without +having to rewrite core logic in the broker or client. Plugins can be loaded and configured using +the `plugins` section of the config file (or parameter passed to the class). + + +## Broker + +By default, `EventLoggerPlugin`, `PacketLoggerPlugin`, `AnonymousAuthPlugin` and `BrokerSysPlugin` are activated +and configured for the broker: ```yaml ---8<-- "pyproject.toml:included" +--8<-- "amqtt/scripts/default_broker.yaml" ``` -## auth_anonymous (Auth Plugin) -`amqtt.plugins.authentication:AnonymousAuthPlugin` +??? warning "Loading plugins from EntryPoints in `pyproject.toml` has been deprecated" + Previously, all plugins were loaded from EntryPoints: + + ```toml + --8<-- "pyproject.toml:included" + ``` + + But the same 4 plugins were activated in the previous default config: + + ```yaml + --8<-- "samples/legacy.yaml" + ``` + +## Client + +By default, the `PacketLoggerPlugin` is activated and configured for the client: + +```yaml +--8<-- "amqtt/scripts/default_client.yaml" +``` + +## Plugins + +### Anonymous (Auth Plugin) + +`amqtt.plugins.authentication.AnonymousAuthPlugin` **Configuration** ```yaml -auth: - plugins: - - auth_anonymous - allow-anonymous: true # if false, providing a username will allow access +plugins: + - ... + - amqtt.plugins.authentication.AnonymousAuthPlugin: + allow_anonymous: false + - ... ``` !!! danger - even if `allow-anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client + even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client -## auth_file (Auth Plugin) +??? warning "EntryPoint-style configuration is deprecated" -`amqtt.plugins.authentication:FileAuthPlugin` + ```yaml + auth: + plugins: + - auth_anonymous + allow-anonymous: true # if false, providing a username will allow access + + ``` + +### Password File (Auth Plugin) + +`amqtt.plugins.authentication.FileAuthPlugin` clients are authorized by providing username and password, compared against file **Configuration** ```yaml - -auth: - plugins: - - auth_file - password-file: /path/to/password_file - +plugins: + - ... + - amqtt.plugins.authentication.FileAuthPlugin: + password_file: /path/to/password_file + - ... ``` +??? warning "EntryPoint-style configuration is deprecated" + ```yaml + + auth: + plugins: + - auth_file + password-file: /path/to/password_file + + ``` + **File Format** The file includes `username:password` pairs, one per line. @@ -58,33 +109,42 @@ passwd = input() if not sys.stdin.isatty() else getpass() print(sha512_crypt.hash(passwd)) ``` -## Taboo (Topic Plugin) +### Taboo (Topic Plugin) -`amqtt.plugins.topic_checking:TopicTabooPlugin` +`amqtt.plugins.topic_checking.TopicTabooPlugin` Prevents using topics named: `prohibited`, `top-secret`, and `data/classified` **Configuration** ```yaml -topic-check: - enabled: true - plugins: - - topic_taboo +plugins: + - ... + - amqtt.plugins.topic_checking.TopicTabooPlugin: + - ... ``` -## ACL (Topic Plugin) +??? warning "EntryPoint-style configuration is deprecated" -`amqtt.plugins.topic_checking:TopicAccessControlListPlugin` + ```yaml + topic-check: + enabled: true + plugins: + - topic_taboo + ``` + +### ACL (Topic Plugin) + +`amqtt.plugins.topic_checking.TopicAccessControlListPlugin` **Configuration** -- `acl` *(list)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access. +- `acl` *(mapping)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access. The list should be a key-value pair, where: `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). -- `publish-acl` *(list)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: +- `publish-acl` *(mapping)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). !!! info "Reserved usernames" @@ -93,47 +153,89 @@ topic-check: - The username `anonymous` will control allowed topics, if using the `auth_anonymous` plugin. ```yaml -topic-check: - enabled: true - plugins: - - topic_acl - publish-acl: - - username: ["list", "of", "allowed", "topics", "for", "publishing"] - - . - acl: - - username: ["list", "of", "allowed", "topics", "for", "subscribing"] - - . +plugins: + - ... + - amqtt.plugins.topic_checking.TopicAccessControlListPlugin: + publish_acl: + - username: ["list", "of", "allowed", "topics", "for", "publishing"] + acl: + - username: ["list", "of", "allowed", "topics", "for", "subscribing"] + - ... ``` -## Plugin: $SYS +??? warning "EntryPoint-style configuration is deprecated" + ```yaml + topic-check: + enabled: true + plugins: + - topic_acl + publish-acl: + - username: ["list", "of", "allowed", "topics", "for", "publishing"] + - . + acl: + - username: ["list", "of", "allowed", "topics", "for", "subscribing"] + - . + ``` -`amqtt.plugins.sys.broker:BrokerSysPlugin` +### $SYS topics + +`amqtt.plugins.sys.broker.BrokerSysPlugin` Publishes, on a periodic basis, statistics about the broker **Configuration** - `sys_interval` - int, seconds between updates +```yaml +plugins: + - ... + - amqtt.plugins.sys.broker.BrokerSysPlugin: + sys_interval: 20 # int, seconds between updates + - ... +``` **Supported Topics** -- `$SYS/broker/version` - payload: `str` -- `$SYS/broker/load/bytes/received` - payload: `int` -- `$SYS/broker/load/bytes/sent` - payload: `int` -- `$SYS/broker/messages/received` - payload: `int` -- `$SYS/broker/messages/sent` - payload: `int` -- `$SYS/broker/time` - payload: `int` (current time, epoch seconds) -- `$SYS/broker/uptime` - payload: `int` (seconds since broker start) -- `$SYS/broker/uptime/formatted` - payload: `str` (start time of broker in UTC) -- `$SYS/broker/clients/connected` - payload: `int` (current number of connected clients) -- `$SYS/broker/clients/disconnected` - payload: `int` (number of clients that have disconnected) -- `$SYS/broker/clients/maximum` - payload: `int` -- `$SYS/broker/clients/total` - payload: `int` -- `$SYS/broker/messages/inflight` - payload: `int` -- `$SYS/broker/messages/inflight/in` - payload: `int` -- `$SYS/broker/messages/inflight/out` - payload: `int` -- `$SYS/broker/messages/inflight/stored` - payload: `int` -- `$SYS/broker/messages/publish/received` - payload: `int` -- `$SYS/broker/messages/publish/sent` - payload: `int` -- `$SYS/broker/messages/retained/count` - payload: `int` -- `$SYS/broker/messages/subscriptions/count` - payload: `int` +- `$SYS/broker/version` *(string)* +- `$SYS/broker/load/bytes/received` *(int)* +- `$SYS/broker/load/bytes/sent` *(int)* +- `$SYS/broker/messages/received` *(int)* +- `$SYS/broker/messages/sent` *(int)* +- `$SYS/broker/time` *(int, current time in epoch seconds)* +- `$SYS/broker/uptime` *(int, seconds since broker start)* +- `$SYS/broker/uptime/formatted` *(string, start time of broker in UTC)* +- `$SYS/broker/clients/connected` *(int, number of currently connected clients)* +- `$SYS/broker/clients/disconnected` *(int, number of clients that have disconnected)* +- `$SYS/broker/clients/maximum` *(int, maximum number of clients connected)* +- `$SYS/broker/clients/total` *(int)* +- `$SYS/broker/messages/inflight` *(int)* +- `$SYS/broker/messages/inflight/in` *(int)* +- `$SYS/broker/messages/inflight/out` *(int)* +- `$SYS/broker/messages/inflight/stored` *(int)* +- `$SYS/broker/messages/publish/received` *(int)* +- `$SYS/broker/messages/publish/sent` *(int)* +- `$SYS/broker/messages/retained/count` *(int)* +- `$SYS/broker/messages/subscriptions/count` *(int)* +- `$SYS/broker/heap/size` *(float, MB)* +- `$SYS/broker/heap/maximum` *(float, MB)* +- `$SYS/broker/cpu/percent` *(float, %)* +- `$SYS/broker/cpu/maximum` *(float, %)* + + +### Event Logger + +`amqtt.plugins.logging_amqtt.EventLoggerPlugin` + +This plugin issues log messages when [broker and mqtt events](custom_plugins.md#events) are triggered: + +- info level messages for `client connected` and `client disconnected` +- debug level for all others + + +### Packet Logger + +`amqtt.plugins.logging_amqtt.PacketLoggerPlugin` + +This plugin issues debug-level messages for [mqtt events](custom_plugins.md#client-and-broker): `on_mqtt_packet_sent` +and `on_mqtt_packet_received`. + diff --git a/docs/references/broker_config.md b/docs/references/broker_config.md index 30ec0d1..7afbc9d 100644 --- a/docs/references/broker_config.md +++ b/docs/references/broker_config.md @@ -3,7 +3,7 @@ This configuration structure is valid as a python dictionary passed to the `amqtt.broker.Broker` class's `__init__` method or as a yaml formatted file passed to the `amqtt` script. -### `listeners` *(list[mapping])* +### `listeners` *(list[dict[str, Any]])* Defines the network listeners used by the service. Items defined in the `default` listener will be applied to all other listeners, unless they are overridden by the configuration for the specific @@ -20,77 +20,77 @@ listener. - `certfile` *(string)*: Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the certificate's authenticity. - `keyfile` *(string): A file containing the private key. Otherwise the private key will be taken from `certfile` as well. -### `sys_interval` *(int)* - -System status report interval in seconds (`broker_sys` plugin) - ### `timeout-disconnect-delay` *(int)* Client disconnect timeout without a keep-alive -### `auth` *(mapping)* +### `plugins` *(mapping)* -Configuration for authentication behaviour: - -- `plugins` *(list[string])*: defines the list of plugins which are activated as authentication plugins. - - !!! note "Entry points" - Plugins used here must first be defined in the `amqtt.broker.plugins` [entry point](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata). +A list of strings representing the modules and class name of `BasePlugin`, `BaseAuthPlugin` and `BaseTopicPlugins`. Each +entry may have one or more configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md) - !!! danger "Legacy behavior" - if `plugins` is omitted from the `auth` section, all plugins listed in the `amqtt.broker.plugins` entrypoint will be enabled - for authentication, *including allowing anonymous login.* - - `plugins: []` will deny connections from all clients. - -- `allow-anonymous` *(bool)*: `True` will allow anonymous connections. - - *Used by the internal `amqtt.plugins.authentication.AnonymousAuthPlugin` plugin* +??? warning "Deprecated: `sys_interval` " + **`sys_interval`** *(int)* - !!! danger "Username only connections" - `False` does not disable the `auth_anonymous` plugin; connections will still be allowed as long as a username is provided. - - If security is required, do not include `auth_anonymous` in the `plugins` list. + System status report interval in seconds, used by the `amqtt.plugins.sys.broker.BrokerSysPlugin`. -- `password-file` *(string)*: Path to file which includes `username:password` pair, one per line. The password should be encoded using sha-512 with `mkpasswd -m sha-512` or: - ```python - import sys - from getpass import getpass - from passlib.hash import sha512_crypt - - passwd = input() if not sys.stdin.isatty() else getpass() - print(sha512_crypt.hash(passwd)) - ``` - - *Used by the internal `amqtt.plugins.authentication.FileAuthPlugin` plugin.* -### `topic-check` *(mapping)* +??? warning "Deprecated: `auth` configuration settings" -Configuration for access control policies for publishing and subscribing to topics: + **`auth`** + + Configuration for authentication behaviour: + + - `plugins` *(list[string])*: defines the list of plugins which are activated as authentication plugins. + + !!! note + Plugins used here must first be defined in the `amqtt.broker.plugins` [entry point](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata). + -- `enabled` *(bool)*: Enable access control policies (`true`). `false` will allow clients to publish and subscribe to any topic. -- `plugins` *(list[string])*: defines the list of plugins which are activated as access control plugins. Note the plugins must be defined in the `amqtt.broker.plugins` [entry point](https://pythonhosted.org/setuptools/setuptools.html#dynamic-discovery-of-services-and-plugins). + !!! warning + If `plugins` is omitted from the `auth` section, all plugins listed in the `amqtt.broker.plugins` entrypoint will be enabled + for authentication, including _allowing anonymous login._ + + `plugins: []` will deny connections from all clients. + + - `allow-anonymous` *(bool)*: `True` will allow anonymous connections, used by `amqtt.plugins.authentication.AnonymousAuthPlugin`. + + !!! danger + `False` does not disable the `auth_anonymous` plugin; connections will still be allowed as long as a username is provided. If security is required, do not include `auth_anonymous` in the `plugins` list. + -- `acl` *(list)*: plugin to determine subscription access; if `publish-acl` is not specified, determine both publish and subscription access. - The list should be a key-value pair, where: -`:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + - `password-file` *(string)*. Path to sha-512 encoded password file, used by `amqtt.plugins.authentication.FileAuthPlugin`. - *used by the `amqtt.plugins.topic_acl.TopicAclPlugin`* - -- `publish-acl` *(list)*: plugin to determine publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: -`:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). - - !!! info "Reserved usernames" - - - The username `admin` is allowed access to all topic. - - The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin. +??? warning "Deprecated: `topic-check` configuration settings" - *used by the `amqtt.plugins.topic_acl.TopicAclPlugin`* + **`topic-check`** + + Configuration for access control policies for publishing and subscribing to topics: + + - `enabled` *(bool)*: Enable access control policies (`true`). `false` will allow clients to publish and subscribe to any topic. + - `plugins` *(list[string])*: defines the list of plugins which are activated as access control plugins. Note the plugins must be defined in the `amqtt.broker.plugins` [entry point](https://pythonhosted.org/setuptools/setuptools.html#dynamic-discovery-of-services-and-plugins). + + - `acl` *(list)*: plugin to determine subscription access; if `publish-acl` is not specified, determine both publish and subscription access. + The list should be a key-value pair, where: + `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + + *used by the `amqtt.plugins.topic_acl.TopicAclPlugin`* + + - `publish-acl` *(list)*: plugin to determine publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: + `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + + _Reserved usernames (used by the `amqtt.plugins.topic_acl.TopicAclPlugin`)_ + + - The username `admin` is allowed access to all topic. + - The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin. + + + @@ -130,14 +130,13 @@ listeners: certfile: /some/certfile keyfile: /some/key timeout-disconnect-delay: 2 -auth: - plugins: ['auth_anonymous', 'auth_file'] - allow-anonymous: true - password-file: /some/password-file -topic-check: - enabled: true - plugins: ['topic_acl'] - acl: +plugins: + - amqtt.plugins.authentication.AnonymousAuthPlugin: + allow-anonymous: true + - amqtt.plugin.authentication.FileAuthPlugin: + password-file: /some/password-file + - amqtt.plugins.topic_checking.TopicAccessControlListPlugin: + acl: username1: ['repositories/+/master', 'calendar/#', 'data/memes'] username2: [ 'calendar/2025/#', 'data/memes'] anonymous: ['calendar/2025/#'] diff --git a/docs/references/client_config.md b/docs/references/client_config.md index 67f875f..9581364 100644 --- a/docs/references/client_config.md +++ b/docs/references/client_config.md @@ -68,8 +68,8 @@ TLS certificates used to verify the broker's authenticity. - `cafile` *(string)*: Path to a file of concatenated CA certificates in PEM format. See [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info. - `capath` *(string)*: Path to a directory containing several CA certificates in PEM format, following an [OpenSSL specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/). - `cadata` *(string)*: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. -- -- + + ### `certfile` *(string)* Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the server certificate's authenticity. @@ -78,6 +78,11 @@ Path to a single file in PEM format containing the certificate as well as any nu Bypass ssl host certificate verification, allowing self-signed certificates +### `plugins` *(mapping)* + +A list of strings representing the modules and class name of any `BasePlugin`s. Each entry may have one or more +configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md) + ## Default Configuration @@ -110,4 +115,7 @@ will: broker: uri: mqtt://localhost:1883 cafile: /path/to/ca/file +plugins: + - amqtt.plugins.logging_amqtt.PacketLoggerPlugin: + ``` diff --git a/pyproject.toml b/pyproject.toml index 6856f53..4a2b35f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "passlib==1.7.4", # https://pypi.org/project/passlib "PyYAML==6.0.2", # https://pypi.org/project/PyYAML "typer==0.15.4", + "dacite>=1.9.2", "psutil>=7.0.0", ] diff --git a/samples/legacy.yaml b/samples/legacy.yaml new file mode 100644 index 0000000..6f11d92 --- /dev/null +++ b/samples/legacy.yaml @@ -0,0 +1,13 @@ +--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +sys_interval: 20 +auth: + plugins: + - auth_anonymous + allow-anonymous: true +topic-check: + enabled: False + diff --git a/tests/conftest.py b/tests/conftest.py index 65128f3..45d1f68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,11 @@ test_config_acl: dict[str, int | dict[str, Any]] = { @pytest.fixture def mock_plugin_manager(): with unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager: + plugin_manager_instance = plugin_manager.return_value + + # disable topic filtering when using the mock manager + plugin_manager_instance.is_topic_filtering_enabled.return_value = False + yield plugin_manager diff --git a/tests/plugins/mocks.py b/tests/plugins/mocks.py index dce94a6..def9920 100644 --- a/tests/plugins/mocks.py +++ b/tests/plugins/mocks.py @@ -2,10 +2,8 @@ import logging from dataclasses import dataclass -from amqtt.broker import Action - -from amqtt.plugins.base import BasePlugin, BaseTopicPlugin, BaseAuthPlugin -from amqtt.plugins.manager import BaseContext +from amqtt.plugins.base import BasePlugin, BaseAuthPlugin, BaseTopicPlugin +from amqtt.contexts import BaseContext, Action from amqtt.session import Session @@ -29,25 +27,43 @@ class TestConfigPlugin(BasePlugin): option2: str -class AuthPlugin(BaseAuthPlugin): +class TestCoroErrorPlugin(BaseAuthPlugin): + + def authenticate(self, *, session: Session) -> bool | None: + return True + + +class TestAuthPlugin(BaseAuthPlugin): async def authenticate(self, *, session: Session) -> bool | None: return True -class NoAuthPlugin(BaseAuthPlugin): +class TestNoAuthPlugin(BaseAuthPlugin): async def authenticate(self, *, session: Session) -> bool | None: return False -class TestTopicPlugin(BaseTopicPlugin): +class TestAllowTopicPlugin(BaseTopicPlugin): def __init__(self, context: BaseContext): super().__init__(context) - def topic_filtering( + async def topic_filtering( self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None ) -> bool: return True + + +class TestBlockTopicPlugin(BaseTopicPlugin): + + def __init__(self, context: BaseContext): + super().__init__(context) + + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + logger.debug("topic filtering plugin is returning false") + return False diff --git a/tests/plugins/test_authentication.py b/tests/plugins/test_authentication.py index 92d79c9..a500a73 100644 --- a/tests/plugins/test_authentication.py +++ b/tests/plugins/test_authentication.py @@ -3,19 +3,42 @@ import logging from pathlib import Path import unittest +import pytest + from amqtt.plugins.authentication import AnonymousAuthPlugin, FileAuthPlugin -from amqtt.plugins.manager import BaseContext +from amqtt.contexts import BaseContext +from amqtt.plugins.base import BaseAuthPlugin from amqtt.session import Session formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" logging.basicConfig(level=logging.DEBUG, format=formatter) +@pytest.mark.asyncio +async def test_base_no_config(logdog): + """Check BaseTopicPlugin returns false if no topic-check is present.""" + with logdog() as pile: + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = {} + + plugin = BaseAuthPlugin(context) + s = Session() + authorised = await plugin.authenticate(session=s) + assert authorised is False + + # Warning messages are only generated if using deprecated plugin configuration on initial load + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 1 + assert log_records[0].levelno == logging.WARNING + assert log_records[0].message == "'auth' section not found in context configuration" + + class TestAnonymousAuthPlugin(unittest.TestCase): def setUp(self) -> None: self.loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - def test_allow_anonymous(self) -> None: + def test_allow_anonymous_dict_config(self) -> None: context = BaseContext() context.logger = logging.getLogger(__name__) context.config = {"auth": {"allow-anonymous": True}} @@ -25,6 +48,16 @@ class TestAnonymousAuthPlugin(unittest.TestCase): ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s)) assert ret + def test_allow_anonymous_dataclass_config(self) -> None: + context = BaseContext() + context.logger = logging.getLogger(__name__) + context.config = AnonymousAuthPlugin.Config(allow_anonymous=True) + s = Session() + s.username = "" + auth_plugin = AnonymousAuthPlugin(context) + ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s)) + assert ret + def test_disallow_anonymous(self) -> None: context = BaseContext() context.logger = logging.getLogger(__name__) diff --git a/tests/plugins/test_config.py b/tests/plugins/test_config.py new file mode 100644 index 0000000..0e9e8e4 --- /dev/null +++ b/tests/plugins/test_config.py @@ -0,0 +1,219 @@ +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Any + +import pytest +import yaml + +from amqtt.broker import Broker +from yaml import CLoader as Loader +from dacite import from_dict, Config, UnexpectedDataError + +from amqtt.client import MQTTClient +from amqtt.errors import PluginLoadError, ConnectError, PluginCoroError +from amqtt.mqtt.constants import QOS_0 + +logger = logging.getLogger(__name__) + +plugin_config = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestSimplePlugin: + - tests.plugins.mocks.TestConfigPlugin: + option1: 1 + option2: bar +""" + + +plugin_invalid_config_one = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestSimplePlugin: + option1: 1 + option2: bar +""" + +plugin_invalid_config_two = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestConfigPlugin: +""" + +plugin_coro_error_config = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestCoroErrorPlugin: +""" + +plugin_config_auth = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestAuthPlugin: +""" + +plugin_config_no_auth = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestNoAuthPlugin: +""" + + +plugin_config_topic = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestAllowTopicPlugin: +""" + + +plugin_config_topic_block = """--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - tests.plugins.mocks.TestBlockTopicPlugin: +""" + + + +@pytest.mark.asyncio +async def test_plugin_config_extra_fields(): + + cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader) + + with pytest.raises(PluginLoadError): + _ = Broker(config=cfg) + + +@pytest.mark.asyncio +async def test_plugin_config_missing_fields(): + cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader) + + with pytest.raises(PluginLoadError): + _ = Broker(config=cfg) + + +@pytest.mark.asyncio +async def test_alternate_plugin_load(): + + cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader) + + broker = Broker(config=cfg) + await broker.start() + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_coro_error_plugin_load(): + + cfg: dict[str, Any] = yaml.load(plugin_coro_error_config, Loader=Loader) + + with pytest.raises(PluginCoroError): + _ = Broker(config=cfg) + + +@pytest.mark.asyncio +async def test_auth_plugin_load(): + cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader) + broker = Broker(config=cfg) + await broker.start() + await asyncio.sleep(0.5) + + client1 = MQTTClient() + await client1.connect() + await client1.publish('my/topic', b'my message') + await client1.disconnect() + + await asyncio.sleep(0.5) + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_no_auth_plugin_load(): + cfg: dict[str, Any] = yaml.load(plugin_config_no_auth, Loader=Loader) + broker = Broker(config=cfg) + await broker.start() + await asyncio.sleep(0.5) + + client1 = MQTTClient(config={'auto_reconnect': False}) + with pytest.raises(ConnectError): + await client1.connect() + + await asyncio.sleep(0.5) + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_allow_topic_plugin_load(): + cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader) + broker = Broker(config=cfg) + await broker.start() + await asyncio.sleep(0.5) + + client2 = MQTTClient(config={'auto_reconnect': False}) + await client2.connect() + await client2.subscribe([ + ('my/topic', QOS_0) + ]) + + client1 = MQTTClient(config={'auto_reconnect': True}) + await client1.connect() + await client1.publish('my/topic', b'my message') + + message = await client2.deliver_message(timeout_duration=1) + assert message.topic == 'my/topic' + assert message.data == b'my message' + + await client2.disconnect() + await client1.disconnect() + + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_block_topic_plugin_load(): + cfg: dict[str, Any] = yaml.load(plugin_config_topic_block, Loader=Loader) + broker = Broker(config=cfg) + await broker.start() + await asyncio.sleep(0.5) + + client2 = MQTTClient(config={'auto_reconnect': False}) + await client2.connect() + await client2.subscribe([ + ('my/topic', QOS_0) + ]) + + client1 = MQTTClient(config={'auto_reconnect': True}) + await client1.connect() + await client1.publish('my/topic', b'my message') + + with pytest.raises(asyncio.TimeoutError): + message = await client2.deliver_message(timeout_duration=1) + logger.debug(f"msg received: {message.topic} >> {message.data}") + + await client2.disconnect() + await client1.disconnect() + + await broker.shutdown() diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index 5164534..a081e6f 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -2,10 +2,11 @@ import asyncio import logging import unittest -from amqtt.broker import Action from amqtt.events import BrokerEvents -from amqtt.plugins.manager import BaseContext, PluginManager -from amqtt.plugins.base import BaseTopicPlugin, BaseAuthPlugin + +from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin +from amqtt.plugins.manager import PluginManager +from amqtt.contexts import BaseContext, Action from amqtt.session import Session formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" diff --git a/tests/plugins/test_persistence.py b/tests/plugins/test_persistence.py index 75f1939..35aebd3 100644 --- a/tests/plugins/test_persistence.py +++ b/tests/plugins/test_persistence.py @@ -4,7 +4,7 @@ from pathlib import Path import sqlite3 import unittest -from amqtt.plugins.manager import BaseContext +from amqtt.contexts import BaseContext from amqtt.plugins.persistence import SQLitePlugin from amqtt.session import Session diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index bda597d..61ce0e1 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -13,11 +13,11 @@ import pytest import amqtt.plugins from amqtt.broker import Broker, BrokerContext from amqtt.client import MQTTClient -from amqtt.errors import PluginError, PluginInitError, PluginImportError +from amqtt.errors import PluginInitError, PluginImportError from amqtt.events import MQTTEvents, BrokerEvents from amqtt.mqtt.constants import QOS_0 from amqtt.plugins.base import BasePlugin -from amqtt.plugins.manager import BaseContext +from amqtt.contexts import BaseContext _INVALID_METHOD: str = "invalid_foo" _PLUGIN: str = "Plugin" @@ -82,54 +82,36 @@ class MockInitErrorPlugin(BasePlugin): @pytest.mark.asyncio async def test_plugin_exception_while_init() -> None: - class MockEntryPoints: - def select(self, group) -> list[EntryPoint]: - match group: - case 'tests.mock_plugins': - return [ - EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'), - ] - case _: - return list() - - with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: - - config = { - "listeners": { - "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, - }, - 'sys_interval': 1 + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1, + 'plugins':{ + 'tests.plugins.test_plugins.MockInitErrorPlugin':{} } + } - with pytest.raises(PluginInitError): - _ = Broker(plugin_namespace='tests.mock_plugins', config=config) + with pytest.raises(PluginInitError): + _ = Broker(plugin_namespace='tests.mock_plugins', config=config) @pytest.mark.asyncio async def test_plugin_exception_while_loading() -> None: - class MockEntryPoints: - def select(self, group) -> list[EntryPoint]: - match group: - case 'tests.mock_plugins': - return [ - EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'), - ] - case _: - return list() - - with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: - - config = { - "listeners": { - "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, - }, - 'sys_interval': 1 + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1, + 'plugins':{ + 'tests.plugins.mock_plugins.MockImportErrorPlugin':{} } + } - with pytest.raises(PluginImportError): - _ = Broker(plugin_namespace='tests.mock_plugins', config=config) + with pytest.raises(PluginImportError): + _ = Broker(plugin_namespace='tests.mock_plugins', config=config) class AllEventsPlugin(BasePlugin[BaseContext]): @@ -153,47 +135,37 @@ class AllEventsPlugin(BasePlugin[BaseContext]): if name not in ('authenticate', 'topic_filtering'): pytest.fail(f'unexpected method called: {name}') + @pytest.mark.asyncio async def test_all_plugin_events(): - class MockEntryPoints: - def select(self, group) -> list[EntryPoint]: - match group: - case 'tests.mock_plugins': - return [ - EntryPoint(name='AllEventsPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:AllEventsPlugin'), - ] - case _: - return list() - - # patch the entry points so we can load our test plugin - with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: - - config = { - "listeners": { - "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, - }, - 'sys_interval': 1 + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1, + 'plugins':{ + 'tests.plugins.test_plugins.AllEventsPlugin': {} } + } + broker = Broker(plugin_namespace='tests.mock_plugins', config=config) - broker = Broker(plugin_namespace='tests.mock_plugins', config=config) + await broker.start() + await asyncio.sleep(2) - await broker.start() - await asyncio.sleep(2) + # make sure all expected events get triggered + client = MQTTClient() + await client.connect("mqtt://127.0.0.1:1883/") + await client.subscribe([('my/test/topic', QOS_0),]) + await client.publish('test/topic', b'my test message') + await client.unsubscribe(['my/test/topic',]) + await client.disconnect() + await asyncio.sleep(1) - # make sure all expected events get triggered - client = MQTTClient() - await client.connect("mqtt://127.0.0.1:1883/") - await client.subscribe([('my/test/topic', QOS_0),]) - await client.publish('test/topic', b'my test message') - await client.unsubscribe(['my/test/topic',]) - await client.disconnect() - await asyncio.sleep(1) + # get the plugin so it doesn't get gc on shutdown + test_plugin = broker.plugins_manager.get_plugin('AllEventsPlugin') + await broker.shutdown() + await asyncio.sleep(1) - # get the plugin so it doesn't get gc on shutdown - test_plugin = broker.plugins_manager.get_plugin('AllEventsPlugin') - await broker.shutdown() - await asyncio.sleep(1) - - assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}' + assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}' diff --git a/tests/plugins/test_sys.py b/tests/plugins/test_sys.py index 563a313..c9b3bdd 100644 --- a/tests/plugins/test_sys.py +++ b/tests/plugins/test_sys.py @@ -1,6 +1,7 @@ import asyncio import logging from importlib.metadata import EntryPoint +from logging.config import dictConfig from unittest.mock import patch import pytest @@ -9,8 +10,34 @@ from amqtt.broker import Broker from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_0 +dictConfig({ + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'verbose': { + 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s' + } + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'DEBUG', + 'formatter': 'verbose', + } + }, + 'loggers': { + 'transitions': { + 'level': 'WARNING', + } + } +}) + +# logging.basicConfig(level=logging.DEBUG, format=formatter) + logger = logging.getLogger(__name__) + + all_sys_topics = [ '$SYS/broker/version', '$SYS/broker/load/bytes/received', @@ -42,7 +69,7 @@ all_sys_topics = [ # test broker sys @pytest.mark.asyncio -async def test_broker_sys_plugin() -> None: +async def test_broker_sys_plugin_deprecated_config() -> None: sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics} @@ -64,7 +91,8 @@ async def test_broker_sys_plugin() -> None: "listeners": { "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, }, - 'sys_interval': 1 + 'sys_interval': 1, + 'auth': {} } broker = Broker(plugin_namespace='tests.mock_plugins', config=config) @@ -92,3 +120,45 @@ async def test_broker_sys_plugin() -> None: assert sys_msg_count > 1 assert all(sys_topic_flags.values()), f'topic not received: {[ topic for topic, flag in sys_topic_flags.items() if not flag ]}' + + +@pytest.mark.asyncio +async def test_broker_sys_plugin_config() -> None: + + sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics} + + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'plugins': [ + {'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}}, + ] + } + + broker = Broker(plugin_namespace='tests.mock_plugins', config=config) + await broker.start() + client = MQTTClient() + await client.connect("mqtt://127.0.0.1:1883/") + await client.subscribe([("$SYS/#", QOS_0), ]) + await client.publish('test/topic', b'my test message') + await asyncio.sleep(2) + sys_msg_count = 0 + try: + while sys_msg_count < 30: + message = await client.deliver_message(timeout_duration=1) + if '$SYS' in message.topic: + sys_msg_count += 1 + assert message.topic in sys_topic_flags + sys_topic_flags[message.topic] = True + + except asyncio.TimeoutError: + logger.debug(f"TimeoutError after {sys_msg_count} messages") + + await client.disconnect() + await broker.shutdown() + + assert sys_msg_count > 1 + + assert all( + sys_topic_flags.values()), f'topic not received: {[topic for topic, flag in sys_topic_flags.items() if not flag]}' diff --git a/tests/plugins/test_topic_checking.py b/tests/plugins/test_topic_checking.py index ef5c69a..4d63a11 100644 --- a/tests/plugins/test_topic_checking.py +++ b/tests/plugins/test_topic_checking.py @@ -2,12 +2,15 @@ import logging import pytest -from amqtt.broker import Action, BrokerContext, Broker -from amqtt.plugins.manager import BaseContext +from amqtt.broker import BrokerContext, Broker + +from amqtt.contexts import BaseContext, Action from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabooPlugin from amqtt.plugins.base import BaseTopicPlugin from amqtt.session import Session +logger = logging.getLogger(__name__) + # Base plug-in object @@ -23,15 +26,11 @@ async def test_base_no_config(logdog): authorised = await plugin.topic_filtering() assert authorised is False - # Should have printed a couple of warnings - log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 2 - assert log_records[0].levelno == logging.WARNING - assert log_records[0].message == "'topic-check' section not found in context configuration" - - assert log_records[1].levelno == logging.WARNING - assert log_records[1].message == "'topic-check' section not found in context configuration" - assert pile.is_empty() + # Warning messages are only generated if using deprecated plugin configuration on initial load + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 1 + assert log_records[0].levelno == logging.WARNING + assert log_records[0].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio @@ -47,14 +46,11 @@ async def test_base_empty_config(logdog): authorised = await plugin.topic_filtering() assert authorised is False - # Should have printed just one warning - log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 2 - assert log_records[0].levelno == logging.WARNING - assert log_records[0].message == "'topic-check' section not found in context configuration" - - assert log_records[1].levelno == logging.WARNING - assert log_records[1].message == "'topic-check' section not found in context configuration" + # Warning messages are only generated if using deprecated plugin configuration on initial load + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 1 + assert log_records[0].levelno == logging.WARNING + assert log_records[0].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio @@ -69,9 +65,9 @@ async def test_base_disabled_config(logdog): authorised = await plugin.topic_filtering() assert authorised is True - # Should NOT have printed warnings - log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 0 + # Should NOT have printed warnings + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 0 @pytest.mark.asyncio @@ -86,9 +82,9 @@ async def test_base_enabled_config(logdog): authorised = await plugin.topic_filtering() assert authorised is True - # Should NOT have printed warnings - log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 0 + # Should NOT have printed warnings + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 0 # Taboo plug-in @@ -105,13 +101,11 @@ async def test_taboo_empty_config(logdog): plugin = TopicTabooPlugin(context) assert (await plugin.topic_filtering()) is False - # Should have printed a couple of warnings - log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 2 - assert log_records[0].levelno == logging.WARNING - assert log_records[0].message == "'topic-check' section not found in context configuration" - assert log_records[1].levelno == logging.WARNING - assert log_records[1].message == "'topic-check' section not found in context configuration" + # Warning messages are only generated if using deprecated plugin configuration on initial load + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 1 + assert log_records[0].levelno == logging.WARNING + assert log_records[0].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio @@ -133,13 +127,17 @@ async def test_taboo_disabled(logdog): assert len(log_records) == 0 +@pytest.mark.parametrize("test_config", [ + ({"topic-check": {"enabled": True}}), + (TopicTabooPlugin.Config()) +]) @pytest.mark.asyncio -async def test_taboo_not_taboo_topic(logdog): +async def test_taboo_not_taboo_topic(logdog, test_config): """Check TopicTabooPlugin returns true if topic not taboo.""" with logdog() as pile: context = BaseContext() context.logger = logging.getLogger("testlog") - context.config = {"topic-check": {"enabled": True}} + context.config = test_config session = Session() session.username = "anybody" @@ -152,13 +150,17 @@ async def test_taboo_not_taboo_topic(logdog): assert len(log_records) == 0 +@pytest.mark.parametrize("test_config", [ + ({"topic-check": {"enabled": True}}), + (TopicTabooPlugin.Config()) +]) @pytest.mark.asyncio -async def test_taboo_anon_taboo_topic(logdog): +async def test_taboo_anon_taboo_topic(logdog, test_config): """Check TopicTabooPlugin returns false if topic is taboo and session is anonymous.""" with logdog() as pile: context = BaseContext() context.logger = logging.getLogger("testlog") - context.config = {"topic-check": {"enabled": True}} + context.config = test_config session = Session() session.username = "" @@ -171,13 +173,17 @@ async def test_taboo_anon_taboo_topic(logdog): assert len(log_records) == 0 +@pytest.mark.parametrize("test_config", [ + ({"topic-check": {"enabled": True}}), + (TopicTabooPlugin.Config()) +]) @pytest.mark.asyncio -async def test_taboo_notadmin_taboo_topic(logdog): +async def test_taboo_notadmin_taboo_topic(logdog, test_config): """Check TopicTabooPlugin returns false if topic is taboo and user is not "admin".""" with logdog() as pile: context = BaseContext() context.logger = logging.getLogger("testlog") - context.config = {"topic-check": {"enabled": True}} + context.config = test_config session = Session() session.username = "notadmin" @@ -189,14 +195,17 @@ async def test_taboo_notadmin_taboo_topic(logdog): log_records = list(pile.drain(name="testlog")) assert len(log_records) == 0 - +@pytest.mark.parametrize("test_config", [ + ({"topic-check": {"enabled": True}}), + (TopicTabooPlugin.Config()) +]) @pytest.mark.asyncio -async def test_taboo_admin_taboo_topic(logdog): +async def test_taboo_admin_taboo_topic(logdog, test_config): """Check TopicTabooPlugin returns true if topic is taboo and user is "admin".""" with logdog() as pile: context = BaseContext() context.logger = logging.getLogger("testlog") - context.config = {"topic-check": {"enabled": True}} + context.config = test_config session = Session() session.username = "admin" @@ -265,11 +274,11 @@ async def test_taclp_empty_config(logdog): plugin = TopicAccessControlListPlugin(context) assert (await plugin.topic_filtering()) is False - # Should have printed a couple of warnings - log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 2 - assert log_records[0].message == "'topic-check' section not found in context configuration" - assert log_records[1].message == "'topic-check' section not found in context configuration" + # Warning messages are only generated if using deprecated plugin configuration on initial load + log_records = list(pile.drain(name="testlog")) + assert len(log_records) == 1 + assert log_records[0].levelno == logging.WARNING + assert log_records[0].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio @@ -291,15 +300,19 @@ async def test_taclp_true_disabled(logdog): assert authorised is True +@pytest.mark.parametrize("test_config", [ + ({"topic-check": {"enabled": True}}), + (TopicAccessControlListPlugin.Config()) +]) @pytest.mark.asyncio -async def test_taclp_true_no_pub_acl(logdog): +async def test_taclp_true_no_pub_acl(logdog, test_config): """Check TopicAccessControlListPlugin returns true if action=publish and no publish-acl given. (This is for backward-compatibility with existing installations.). """ context = BaseContext() context.logger = logging.getLogger("testlog") - context.config = {"topic-check": {"enabled": True}} + context.config = test_config session = Session() session.username = "user" @@ -313,17 +326,23 @@ async def test_taclp_true_no_pub_acl(logdog): assert authorised is True -@pytest.mark.asyncio -async def test_taclp_false_sub_no_topic(logdog): - """Check TopicAccessControlListPlugin returns false user there is no topic.""" - context = BaseContext() - context.logger = logging.getLogger("testlog") - context.config = { +@pytest.mark.parametrize("test_config", [ + ({ "topic-check": { "enabled": True, "acl": {"anotheruser": ["allowed/topic", "another/allowed/topic/#"]}, }, - } + }), + (TopicAccessControlListPlugin.Config( + acl={"anotheruser": ["allowed/topic", "another/allowed/topic/#"]} + )) +]) +@pytest.mark.asyncio +async def test_taclp_false_sub_no_topic(logdog, test_config): + """Check TopicAccessControlListPlugin returns false user there is no topic.""" + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = test_config session = Session() session.username = "user" @@ -337,17 +356,23 @@ async def test_taclp_false_sub_no_topic(logdog): assert authorised is False -@pytest.mark.asyncio -async def test_taclp_false_sub_unknown_user(logdog): - """Check TopicAccessControlListPlugin returns false user is not listed in ACL.""" - context = BaseContext() - context.logger = logging.getLogger("testlog") - context.config = { +@pytest.mark.parametrize("test_config", [ + ({ "topic-check": { "enabled": True, "acl": {"anotheruser": ["allowed/topic", "another/allowed/topic/#"]}, }, - } + }), + (TopicAccessControlListPlugin.Config( + acl={"anotheruser": ["allowed/topic", "another/allowed/topic/#"]} + )) +]) +@pytest.mark.asyncio +async def test_taclp_false_sub_unknown_user(logdog, test_config): + """Check TopicAccessControlListPlugin returns false user is not listed in ACL.""" + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = test_config session = Session() session.username = "user" @@ -361,17 +386,23 @@ async def test_taclp_false_sub_unknown_user(logdog): assert authorised is False -@pytest.mark.asyncio -async def test_taclp_false_sub_no_permission(logdog): - """Check TopicAccessControlListPlugin returns false if "acl" does not list allowed topic.""" - context = BaseContext() - context.logger = logging.getLogger("testlog") - context.config = { +@pytest.mark.parametrize("test_config", [ + ({ "topic-check": { "enabled": True, "acl": {"user": ["allowed/topic", "another/allowed/topic/#"]}, }, - } + }), + (TopicAccessControlListPlugin.Config( + acl={"user": ["allowed/topic", "another/allowed/topic/#"]} + )) +]) +@pytest.mark.asyncio +async def test_taclp_false_sub_no_permission(logdog, test_config): + """Check TopicAccessControlListPlugin returns false if "acl" does not list allowed topic.""" + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = test_config session = Session() session.username = "user" @@ -384,18 +415,23 @@ async def test_taclp_false_sub_no_permission(logdog): ) assert authorised is False - -@pytest.mark.asyncio -async def test_taclp_true_sub_permission(logdog): - """Check TopicAccessControlListPlugin returns true if "acl" lists allowed topic.""" - context = BaseContext() - context.logger = logging.getLogger("testlog") - context.config = { +@pytest.mark.parametrize("test_config", [ + ({ "topic-check": { "enabled": True, "acl": {"user": ["allowed/topic", "another/allowed/topic/#"]}, }, - } + }), + (TopicAccessControlListPlugin.Config( + acl={"user": ["allowed/topic", "another/allowed/topic/#"]} + )) +]) +@pytest.mark.asyncio +async def test_taclp_true_sub_permission(logdog, test_config): + """Check TopicAccessControlListPlugin returns true if "acl" lists allowed topic.""" + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = test_config session = Session() session.username = "user" @@ -409,17 +445,23 @@ async def test_taclp_true_sub_permission(logdog): assert authorised is True -@pytest.mark.asyncio -async def test_taclp_true_pub_permission(logdog): - """Check TopicAccessControlListPlugin returns true if "publish-acl" lists allowed topic for publish action.""" - context = BaseContext() - context.logger = logging.getLogger("testlog") - context.config = { +@pytest.mark.parametrize("test_config", [ + ({ "topic-check": { "enabled": True, "publish-acl": {"user": ["allowed/topic", "another/allowed/topic/#"]}, }, - } + }), + (TopicAccessControlListPlugin.Config( + publish_acl={"user": ["allowed/topic", "another/allowed/topic/#"]} + )) +]) +@pytest.mark.asyncio +async def test_taclp_true_pub_permission(logdog, test_config): + """Check TopicAccessControlListPlugin returns true if "publish-acl" lists allowed topic for publish action.""" + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = test_config session = Session() session.username = "user" @@ -433,17 +475,23 @@ async def test_taclp_true_pub_permission(logdog): assert authorised is True -@pytest.mark.asyncio -async def test_taclp_true_anon_sub_permission(logdog): - """Check TopicAccessControlListPlugin handles anonymous users.""" - context = BaseContext() - context.logger = logging.getLogger("testlog") - context.config = { +@pytest.mark.parametrize("test_config", [ + ({ "topic-check": { "enabled": True, "acl": {"anonymous": ["allowed/topic", "another/allowed/topic/#"]}, }, - } + }), + (TopicAccessControlListPlugin.Config( + acl={"anonymous": ["allowed/topic", "another/allowed/topic/#"]} + )) +]) +@pytest.mark.asyncio +async def test_taclp_true_anon_sub_permission(logdog, test_config): + """Check TopicAccessControlListPlugin handles anonymous users.""" + context = BaseContext() + context.logger = logging.getLogger("testlog") + context.config = test_config session = Session() session.username = None diff --git a/tests/test_client.py b/tests/test_client.py index 28f5c49..4859c4e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -463,7 +463,7 @@ async def test_client_no_auth(): match group: case 'tests.mock_plugins': return [ - EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:NoAuthPlugin'), + EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:TestNoAuthPlugin'), ] case _: return list() diff --git a/tests/test_version.py b/tests/test_version.py deleted file mode 100644 index d19fe1c..0000000 --- a/tests/test_version.py +++ /dev/null @@ -1,78 +0,0 @@ -import subprocess -import unittest -from unittest.mock import MagicMock, patch -import warnings - -from amqtt.version import get_git_changeset, get_version - - -class TestVersionFunctions(unittest.TestCase): - @patch("amqtt.version.warnings.warn") - def test_get_version(self, mock_warn): - """Test get_version returns amqtt.__version__ and raises a deprecation warning.""" - with patch("amqtt.__version__", "1.2.3"): - version = get_version() - assert version == "1.2.3" - mock_warn.assert_called_once_with( - "amqtt.version.get_version() is deprecated, use amqtt.__version__ instead", - stacklevel=3, - ) - - def test_get_version_no_warning(self): - """Test get_version does not trigger a warning when explicitly suppressed.""" - with patch("amqtt.__version__", "1.2.3"), warnings.catch_warnings(record=True) as captured_warnings: - warnings.simplefilter("ignore") - version = get_version() - assert version == "1.2.3" - assert len(captured_warnings) == 0 # No warnings should be captured - - @patch("amqtt.version.Path") - @patch("amqtt.version.shutil.which") - @patch("amqtt.version.subprocess.Popen") - def test_get_git_changeset(self, mock_popen, mock_which, mock_path): - """Test get_git_changeset returns the correct timestamp or None on failure.""" - # Mock the repo directory - mock_repo_dir = MagicMock() - mock_repo_dir.is_dir.return_value = True - mock_path.return_value.resolve.return_value.parent.parent = mock_repo_dir - - # Mock git executable check - mock_which.return_value = True - - # Mock subprocess.Popen for git log with context manager behavior - mock_process = MagicMock() - mock_process.communicate.return_value = ("1638352940", "") - mock_process.returncode = 0 - mock_popen.return_value.__enter__.return_value = mock_process - - # Call the function - changeset = get_git_changeset() - - # Verify the results - assert changeset == "20211201100220" # Matches timestamp conversion - mock_which.assert_called_once_with("git") - mock_popen.assert_called_once_with( - ["git", "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=mock_repo_dir, - universal_newlines=True, - ) - - # Test invalid directory - mock_repo_dir.is_dir.return_value = False - changeset = get_git_changeset() - assert changeset is None - - # Test missing git - mock_repo_dir.is_dir.return_value = True - mock_which.return_value = False - changeset = get_git_changeset() - assert changeset is None - - # Test git command failure - mock_which.return_value = True - mock_process.returncode = 1 - mock_process.communicate.return_value = ("", "Some error") - changeset = get_git_changeset() - assert changeset is None diff --git a/uv.lock b/uv.lock index 18071eb..fd26736 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,7 @@ name = "amqtt" version = "0.11.1" source = { editable = "." } dependencies = [ + { name = "dacite" }, { name = "passlib" }, { name = "psutil" }, { name = "pyyaml" }, @@ -67,6 +68,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" }, + { name = "dacite", specifier = ">=1.9.2" }, { name = "passlib", specifier = "==1.7.4" }, { name = "psutil", specifier = ">=7.0.0" }, { name = "pyyaml", specifier = "==6.0.2" }, @@ -487,6 +489,15 @@ version = "0.9.5" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" } +[[package]] +name = "dacite" +version = "1.9.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" }, +] + [[package]] name = "dill" version = "0.4.0"