kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #240 from ajmirsky/migrate_existing_plugins
config-file based plugin loadingpull/241/head^2
commit
4e9a43cdcf
|
@ -3,7 +3,6 @@ from asyncio import CancelledError, futures
|
|||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
import copy
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
@ -23,6 +22,7 @@ from amqtt.adapters import (
|
|||
WebSocketsWriter,
|
||||
WriterAdapter,
|
||||
)
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
|
||||
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
|
@ -30,7 +30,7 @@ from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
|
|||
|
||||
from .events import BrokerEvents
|
||||
from .mqtt.disconnect import DisconnectPacket
|
||||
from .plugins.manager import BaseContext, PluginManager
|
||||
from .plugins.manager import PluginManager
|
||||
|
||||
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
||||
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
||||
|
@ -44,13 +44,6 @@ DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
|||
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
"""Actions issued by the broker."""
|
||||
|
||||
SUBSCRIBE = "subscribe"
|
||||
PUBLISH = "publish"
|
||||
|
||||
|
||||
class RetainedApplicationMessage(ApplicationMessage):
|
||||
__slots__ = ("data", "qos", "source_session", "topic")
|
||||
|
||||
|
@ -164,6 +157,10 @@ class Broker:
|
|||
self.logger = logging.getLogger(__name__)
|
||||
self.config = copy.deepcopy(_defaults or {})
|
||||
if config is not None:
|
||||
# if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
|
||||
if ("auth" in config or "topic-check" in config) and "plugins" not in config:
|
||||
# set to None so that the config isn't updated with the new-style default plugin list
|
||||
config["plugins"] = None # type: ignore[assignment]
|
||||
self.config.update(config)
|
||||
self._build_listeners_config(self.config)
|
||||
|
||||
|
@ -766,13 +763,7 @@ class Broker:
|
|||
:param action: What is being done with the topic? subscribe or publish
|
||||
:return:
|
||||
"""
|
||||
topic_config = self.config.get("topic-check", {})
|
||||
enabled = False
|
||||
|
||||
if isinstance(topic_config, dict):
|
||||
enabled = topic_config.get("enabled", False)
|
||||
|
||||
if not enabled:
|
||||
if not self.plugins_manager.is_topic_filtering_enabled():
|
||||
return True
|
||||
|
||||
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
||||
|
|
|
@ -19,11 +19,12 @@ from amqtt.adapters import (
|
|||
WebSocketsReader,
|
||||
WebSocketsWriter,
|
||||
)
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
|
||||
from amqtt.mqtt.connack import CONNECTION_ACCEPTED
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
from amqtt.utils import gen_client_id, read_yaml_config
|
||||
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
from enum import Enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
|
||||
class BaseContext:
|
||||
def __init__(self) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
self.logger: logging.Logger = _LOGGER
|
||||
self.config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
"""Actions issued by the broker."""
|
||||
|
||||
SUBSCRIBE = "subscribe"
|
||||
PUBLISH = "publish"
|
|
@ -29,11 +29,16 @@ class PluginError(Exception):
|
|||
|
||||
|
||||
class PluginImportError(PluginError):
|
||||
def __init__(self, plugin: Any) -> None:
|
||||
super().__init__(f"Plugin import failed: {plugin!r}")
|
||||
"""Exceptions thrown when loading plugin."""
|
||||
|
||||
|
||||
class PluginCoroError(PluginError):
|
||||
"""Exceptions thrown when loading a plugin with a non-async call method."""
|
||||
|
||||
|
||||
class PluginInitError(PluginError):
|
||||
"""Exceptions thrown when initializing plugin."""
|
||||
|
||||
def __init__(self, plugin: Any) -> None:
|
||||
super().__init__(f"Plugin init failed: {plugin!r}")
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import logging
|
|||
from typing import Generic, TypeVar, cast
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError
|
||||
from amqtt.events import MQTTEvents
|
||||
from amqtt.mqtt import packet_class
|
||||
|
@ -57,7 +58,7 @@ from amqtt.mqtt.suback import SubackPacket
|
|||
from amqtt.mqtt.subscribe import SubscribePacket
|
||||
from amqtt.mqtt.unsuback import UnsubackPacket
|
||||
from amqtt.mqtt.unsubscribe import UnsubscribePacket
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from passlib.apps import custom_app_context as pwd_context
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.plugins.base import BaseAuthPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
|
@ -12,12 +14,17 @@ _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
|||
class AnonymousAuthPlugin(BaseAuthPlugin):
|
||||
"""Authentication plugin allowing anonymous access."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
# Default to allowing anonymous
|
||||
self._allow_anonymous = self._get_config_option("allow-anonymous", True) # noqa: FBT003
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool:
|
||||
authenticated = await super().authenticate(session=session)
|
||||
if authenticated:
|
||||
# Default to allowing anonymous
|
||||
allow_anonymous = self.auth_config.get("allow-anonymous", True) if isinstance(self.auth_config, dict) else True
|
||||
if allow_anonymous:
|
||||
|
||||
if self._allow_anonymous:
|
||||
self.context.logger.debug("Authentication success: config allows anonymous")
|
||||
return True
|
||||
|
||||
|
@ -27,6 +34,12 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
|
|||
self.context.logger.debug("Authentication failure: session has no username")
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Allow empty username."""
|
||||
|
||||
allow_anonymous: bool = field(default=True)
|
||||
|
||||
|
||||
class FileAuthPlugin(BaseAuthPlugin):
|
||||
"""Authentication plugin based on a file-stored user database."""
|
||||
|
@ -38,7 +51,7 @@ class FileAuthPlugin(BaseAuthPlugin):
|
|||
|
||||
def _read_password_file(self) -> None:
|
||||
"""Read the password file and populates the user dictionary."""
|
||||
password_file = self.auth_config.get("password-file") if isinstance(self.auth_config, dict) else None
|
||||
password_file = self._get_config_option("password-file", None)
|
||||
if not password_file:
|
||||
self.context.logger.warning("Configuration parameter 'password-file' not found")
|
||||
return
|
||||
|
@ -87,3 +100,9 @@ class FileAuthPlugin(BaseAuthPlugin):
|
|||
|
||||
self.context.logger.debug(f"Authentication failure: password mismatch for user '{session.username}'")
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Path to the properly encoded password file."""
|
||||
|
||||
password_file: str | None = None
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
@ -24,6 +24,20 @@ class BasePlugin(Generic[C]):
|
|||
return None
|
||||
return section_config
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
if option_name in self.context.config:
|
||||
return self.context.config[option_name]
|
||||
return default
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Override to define the configuration and defaults for plugin."""
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Override if plugin needs to clean up resources upon shutdown."""
|
||||
|
||||
|
@ -35,9 +49,19 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
super().__init__(context)
|
||||
|
||||
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
|
||||
if self.topic_config is None:
|
||||
if not bool(self.topic_config) and not is_dataclass(self.context.config):
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
if self.topic_config and option_name in self.topic_config:
|
||||
return self.topic_config[option_name]
|
||||
return default
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
|
@ -52,23 +76,31 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
bool: `True` if topic is allowed, `False` otherwise
|
||||
|
||||
"""
|
||||
if not self.topic_config:
|
||||
# auth config section not found
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
return False
|
||||
return True
|
||||
return bool(self.topic_config) or is_dataclass(self.context.config)
|
||||
|
||||
|
||||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
if self.auth_config and option_name in self.auth_config:
|
||||
return self.auth_config[option_name]
|
||||
return default
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
|
||||
if not self.auth_config:
|
||||
if not bool(self.auth_config) and not is_dataclass(self.context.config):
|
||||
# auth config section not found and Config dataclass not provided
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Logic for session authentication.
|
||||
|
||||
|
@ -80,8 +112,4 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
|||
- `None` if authentication can't be achieved (then plugin result is then ignored)
|
||||
|
||||
"""
|
||||
if not self.auth_config:
|
||||
# auth config section not found
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
return False
|
||||
return True
|
||||
return bool(self.auth_config) or is_dataclass(self.context.config)
|
||||
|
|
|
@ -3,11 +3,11 @@ from functools import partial
|
|||
import logging
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.events import BrokerEvents
|
||||
from amqtt.mqtt import MQTTPacket
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, MQTTPayload, MQTTVariableHeader
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
PACKET: TypeAlias = MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
|
||||
__all__ = ["PluginManager", "get_plugin_manager"]
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
@ -8,17 +8,17 @@ import copy
|
|||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
from inspect import iscoroutinefunction
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar
|
||||
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
|
||||
import warnings
|
||||
|
||||
from amqtt.errors import PluginImportError, PluginInitError
|
||||
from dacite import Config as DaciteConfig, DaciteError, from_dict
|
||||
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.errors import PluginCoroError, PluginImportError, PluginInitError, PluginLoadError
|
||||
from amqtt.events import BrokerEvents, Events, MQTTEvents
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
from amqtt.utils import import_string
|
||||
|
||||
|
||||
class Plugin(NamedTuple):
|
||||
|
@ -39,11 +39,11 @@ def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
|
|||
return plugins_manager.get(namespace)
|
||||
|
||||
|
||||
class BaseContext:
|
||||
def __init__(self) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
self.logger: logging.Logger = _LOGGER
|
||||
self.config: dict[str, Any] | None = None
|
||||
def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
|
||||
try:
|
||||
return issubclass(sub_class, super_class)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
|
||||
|
@ -70,6 +70,9 @@ class PluginManager(Generic[C]):
|
|||
self._auth_plugins: list[BaseAuthPlugin] = []
|
||||
self._topic_plugins: list[BaseTopicPlugin] = []
|
||||
self._event_plugin_callbacks: dict[str, list[AsyncFunc]] = defaultdict(list)
|
||||
self._is_topic_filtering_enabled = False
|
||||
self._is_auth_filtering_enabled = False
|
||||
|
||||
self._load_plugins(namespace)
|
||||
self._fired_events: list[asyncio.Future[Any]] = []
|
||||
plugins_manager[namespace] = self
|
||||
|
@ -78,10 +81,41 @@ class PluginManager(Generic[C]):
|
|||
def app_context(self) -> BaseContext:
|
||||
return self.context
|
||||
|
||||
def _load_plugins(self, namespace: str) -> None:
|
||||
def _load_plugins(self, namespace: str | None = None) -> None:
|
||||
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
|
||||
if "auth" in self.app_context.config:
|
||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||
if "topic-check" in self.app_context.config:
|
||||
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
|
||||
|
||||
plugin_list: list[Any] = self.app_context.config.get("plugins", [])
|
||||
self._load_str_plugins(plugin_list)
|
||||
else:
|
||||
if not namespace:
|
||||
msg = "Namespace needs to be provided for EntryPoint plugin definitions"
|
||||
raise PluginLoadError(msg)
|
||||
|
||||
warnings.warn(
|
||||
"Loading plugins from EntryPoints is deprecated and will be removed in a future version."
|
||||
" Use `plugins` section of config instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
self._load_ep_plugins(namespace)
|
||||
|
||||
for plugin in self._plugins:
|
||||
for event in list(BrokerEvents) + list(MQTTEvents):
|
||||
if awaitable := getattr(plugin, f"on_{event}", None):
|
||||
if not iscoroutinefunction(awaitable):
|
||||
msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'"
|
||||
raise PluginImportError(msg)
|
||||
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
|
||||
self._event_plugin_callbacks[event].append(awaitable)
|
||||
|
||||
def _load_ep_plugins(self, namespace:str) -> None:
|
||||
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
|
||||
auth_filter_list = []
|
||||
topic_filter_list = []
|
||||
if self.app_context.config and "auth" in self.app_context.config:
|
||||
|
@ -107,15 +141,6 @@ class PluginManager(Generic[C]):
|
|||
self._topic_plugins.append(ep_plugin.object)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
|
||||
for plugin in self._plugins:
|
||||
for event in list(BrokerEvents) + list(MQTTEvents):
|
||||
if awaitable := getattr(plugin, f"on_{event}", None):
|
||||
if not iscoroutinefunction(awaitable):
|
||||
msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'"
|
||||
raise PluginImportError(msg)
|
||||
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
|
||||
self._event_plugin_callbacks[event].append(awaitable)
|
||||
|
||||
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||
try:
|
||||
self.logger.debug(f" Loading plugin {ep!s}")
|
||||
|
@ -136,6 +161,70 @@ class PluginManager(Generic[C]):
|
|||
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
|
||||
raise PluginInitError(ep) from e
|
||||
|
||||
def _load_str_plugins(self, plugin_list: list[Any]) -> None:
|
||||
|
||||
self.logger.info("Loading plugins from config")
|
||||
self._is_topic_filtering_enabled = True
|
||||
self._is_auth_filtering_enabled = True
|
||||
for plugin_info in plugin_list:
|
||||
|
||||
if isinstance(plugin_info, dict):
|
||||
if len(plugin_info.keys()) > 1:
|
||||
msg = f"config file should have only one key: {plugin_info.keys()}"
|
||||
raise ValueError(msg)
|
||||
plugin_path = next(iter(plugin_info.keys()))
|
||||
plugin_cfg = plugin_info[plugin_path]
|
||||
plugin = self._load_str_plugin(plugin_path, plugin_cfg)
|
||||
elif isinstance(plugin_info, str):
|
||||
plugin = self._load_str_plugin(plugin_info, {})
|
||||
else:
|
||||
msg = "Unexpected entry in plugins config"
|
||||
raise PluginLoadError(msg)
|
||||
|
||||
self._plugins.append(plugin)
|
||||
if isinstance(plugin, BaseAuthPlugin):
|
||||
if not iscoroutinefunction(plugin.authenticate):
|
||||
msg = f"Auth plugin {plugin_info} has non-async authenticate method."
|
||||
raise PluginCoroError(msg)
|
||||
self._auth_plugins.append(plugin)
|
||||
if isinstance(plugin, BaseTopicPlugin):
|
||||
if not iscoroutinefunction(plugin.topic_filtering):
|
||||
msg = f"Topic plugin {plugin_info} has non-async topic_filtering method."
|
||||
raise PluginCoroError(msg)
|
||||
self._topic_plugins.append(plugin)
|
||||
|
||||
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
|
||||
|
||||
try:
|
||||
plugin_class: Any = import_string(plugin_path)
|
||||
except ImportError as ep:
|
||||
msg = f"Plugin import failed: {plugin_path}"
|
||||
raise PluginImportError(msg) from ep
|
||||
|
||||
if not safe_issubclass(plugin_class, BasePlugin):
|
||||
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
|
||||
raise PluginLoadError(msg)
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
|
||||
try:
|
||||
plugin_context.config = from_dict(data_class=plugin_class.Config,
|
||||
data=plugin_cfg or {},
|
||||
config=DaciteConfig(strict=True))
|
||||
except DaciteError as e:
|
||||
raise PluginLoadError from e
|
||||
except TypeError as e:
|
||||
msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass."
|
||||
raise PluginLoadError(msg) from e
|
||||
|
||||
try:
|
||||
pc = plugin_class(plugin_context)
|
||||
self.logger.debug(f"Loading plugin {plugin_path}")
|
||||
return cast("BasePlugin[C]", pc)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True)
|
||||
raise PluginInitError(plugin_class) from e
|
||||
|
||||
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
|
||||
|
@ -147,6 +236,12 @@ class PluginManager(Generic[C]):
|
|||
return p
|
||||
return None
|
||||
|
||||
def is_topic_filtering_enabled(self) -> bool:
|
||||
topic_config = self.app_context.config.get("topic-check", {}) if self.app_context.config else {}
|
||||
if isinstance(topic_config, dict):
|
||||
return topic_config.get("enabled", False) or self._is_topic_filtering_enabled
|
||||
return False or self._is_topic_filtering_enabled
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Free PluginManager resources and cancel pending event methods."""
|
||||
await self.map_plugin_close()
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from collections import deque # pylint: disable=C0412
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412
|
||||
|
||||
import psutil
|
||||
|
@ -70,8 +71,11 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
# Broker statistics initialization
|
||||
self._stats: dict[str, int] = {}
|
||||
self._sys_handle: asyncio.Handle | None = None
|
||||
|
||||
self._sys_interval: int = 0
|
||||
self._current_process = psutil.Process()
|
||||
|
||||
|
||||
def _clear_stats(self) -> None:
|
||||
"""Initialize broker statistics data structures."""
|
||||
for stat in (
|
||||
|
@ -112,21 +116,21 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
|
||||
# Start $SYS topics management
|
||||
try:
|
||||
sys_interval: int = 0
|
||||
x = self.context.config.get("sys_interval") if self.context.config is not None else None
|
||||
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
|
||||
sys_interval = int(x)
|
||||
if sys_interval > 0:
|
||||
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
|
||||
self._sys_interval = self._get_config_option("sys_interval", None)
|
||||
if isinstance(self._sys_interval, str | Buffer | SupportsInt | SupportsIndex):
|
||||
self._sys_interval = int(self._sys_interval)
|
||||
|
||||
if self._sys_interval > 0:
|
||||
self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds")
|
||||
self._sys_handle = (
|
||||
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
|
||||
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
|
||||
if self.context.loop is not None
|
||||
else None
|
||||
)
|
||||
else:
|
||||
self.context.logger.debug("$SYS disabled")
|
||||
except KeyError:
|
||||
pass
|
||||
self.context.logger.debug("could not find 'sys_interval' key: {e!r}")
|
||||
# 'sys_interval' config parameter not found
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
|
@ -194,15 +198,9 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
tasks.popleft()
|
||||
|
||||
# Reschedule
|
||||
sys_interval: int = 0
|
||||
x = self.context.config.get("sys_interval") if self.context.config is not None else None
|
||||
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
|
||||
sys_interval = int(x)
|
||||
self.context.logger.debug("Broadcasting $SYS topics")
|
||||
|
||||
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
|
||||
self.context.logger.debug(f"Broadcast $SYS topics again in {self._sys_interval} seconds.")
|
||||
self._sys_handle = (
|
||||
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
|
||||
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
|
||||
if self.context.loop is not None
|
||||
else None
|
||||
)
|
||||
|
@ -237,3 +235,9 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
"""Handle broker client disconnection."""
|
||||
self._stats[STAT_CLIENTS_CONNECTED] -= 1
|
||||
self._stats[STAT_CLIENTS_DISCONNECTED] += 1
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration struct for plugin."""
|
||||
|
||||
sys_interval: int = 0
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.plugins.base import BaseTopicPlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
|
@ -52,26 +52,34 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
|||
return False
|
||||
|
||||
# hbmqtt and older amqtt do not support publish filtering
|
||||
if action == Action.PUBLISH and self.topic_config is not None and "publish-acl" not in self.topic_config:
|
||||
if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}):
|
||||
# maintain backward compatibility, assume permitted
|
||||
return True
|
||||
|
||||
req_topic = topic
|
||||
if not req_topic:
|
||||
return False
|
||||
return False\
|
||||
|
||||
username = session.username if session else None
|
||||
if username is None:
|
||||
username = "anonymous"
|
||||
|
||||
acl: dict[str, Any] = {}
|
||||
if self.topic_config is not None and action == Action.PUBLISH:
|
||||
acl = self.topic_config.get("publish-acl", {})
|
||||
elif self.topic_config is not None and action == Action.SUBSCRIBE:
|
||||
acl = self.topic_config.get("acl", {})
|
||||
match action:
|
||||
case Action.PUBLISH:
|
||||
acl = self._get_config_option("publish-acl", {})
|
||||
case Action.SUBSCRIBE:
|
||||
acl = self._get_config_option("acl", {})
|
||||
|
||||
allowed_topics = acl.get(username, [])
|
||||
if not allowed_topics:
|
||||
return False
|
||||
|
||||
return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics)
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Mappings of username and list of approved topics."""
|
||||
|
||||
publish_acl: dict[str, list[str]] = field(default_factory=dict)
|
||||
acl: dict[str, list[str]] = field(default_factory=dict)
|
||||
|
|
|
@ -3,10 +3,10 @@ listeners:
|
|||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
sys_interval: 20
|
||||
auth:
|
||||
plugins:
|
||||
- auth_anonymous
|
||||
allow-anonymous: true
|
||||
topic-check:
|
||||
enabled: False
|
||||
plugins:
|
||||
- amqtt.plugins.logging_amqtt.EventLoggerPlugin:
|
||||
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
- amqtt.plugins.authentication.AnonymousAuthPlugin:
|
||||
allow_anonymous: true
|
||||
- amqtt.plugins.sys.broker.BrokerSysPlugin:
|
||||
sys_interval: 20
|
||||
|
|
|
@ -7,4 +7,6 @@ auto_reconnect: true
|
|||
reconnect_max_interval: 10
|
||||
reconnect_retries: 2
|
||||
broker:
|
||||
uri: "mqtt://127.0.0.1"
|
||||
uri: "mqtt://127.0.0.1"
|
||||
plugins:
|
||||
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import string
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
|
@ -48,3 +50,39 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
|
|||
except yaml.YAMLError:
|
||||
logger.exception(f"Invalid config_file {config_file}")
|
||||
return None
|
||||
|
||||
|
||||
def cached_import(module_path: str, class_name: str | None = None) -> Any:
|
||||
"""Return cached import of a class from a module path (or retrieve, cache and then return)."""
|
||||
# Check whether module is loaded and fully initialized.
|
||||
if not ((module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None))
|
||||
and getattr(spec, "_initializing", False) is False):
|
||||
module = import_module(module_path)
|
||||
if class_name:
|
||||
return getattr(module, class_name)
|
||||
return module
|
||||
|
||||
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""Import a dotted module path.
|
||||
|
||||
Returns:
|
||||
attribute/class designated by the last name in the path
|
||||
|
||||
Raises:
|
||||
ImportError (if the import failed)
|
||||
|
||||
"""
|
||||
try:
|
||||
module_path, class_name = dotted_path.rsplit(".", 1)
|
||||
except ValueError as err:
|
||||
msg = f"{dotted_path} doesn't look like a module path"
|
||||
raise ImportError(msg) from err
|
||||
|
||||
try:
|
||||
return cached_import(module_path, class_name)
|
||||
except AttributeError as err:
|
||||
msg = f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
||||
|
||||
raise ImportError(msg) from err
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
try:
|
||||
from datetime import UTC, datetime
|
||||
except ImportError:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
UTC = timezone.utc
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import warnings
|
||||
|
||||
import amqtt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
"""Return the version of the amqtt package.
|
||||
|
||||
This function is deprecated. Use amqtt.__version__ instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"amqtt.version.get_version() is deprecated, use amqtt.__version__ instead",
|
||||
stacklevel=3, # Adjusted stack level to better reflect the caller
|
||||
)
|
||||
return amqtt.__version__
|
||||
|
||||
|
||||
def get_git_changeset() -> str | None:
|
||||
"""Return a numeric identifier of the latest git changeset.
|
||||
|
||||
The result is the UTC timestamp of the changeset in YYYYMMDDHHMMSS format.
|
||||
This value isn't guaranteed to be unique, but collisions are very unlikely,
|
||||
so it's sufficient for generating the development version numbers.
|
||||
"""
|
||||
# Define the repository directory (two levels above the current script)
|
||||
repo_dir = Path(__file__).resolve().parent.parent
|
||||
|
||||
# Ensure the directory exists and is valid
|
||||
if not repo_dir.is_dir():
|
||||
logger.error(f"Invalid directory: {repo_dir} is not a valid directory")
|
||||
return None
|
||||
|
||||
# Use the system's PATH to locate 'git', or define the full path if necessary
|
||||
git_path = "git" # Assuming git is available in the system PATH
|
||||
|
||||
# Ensure 'git' is executable and available
|
||||
if not shutil.which(git_path):
|
||||
logger.error(f"{git_path} is not found in the system PATH.")
|
||||
return None
|
||||
|
||||
# Call git log to get the latest changeset timestamp
|
||||
try:
|
||||
with subprocess.Popen( # noqa: S603
|
||||
[git_path, "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=repo_dir,
|
||||
universal_newlines=True,
|
||||
) as git_log:
|
||||
timestamp_str, stderr = git_log.communicate()
|
||||
|
||||
if git_log.returncode != 0:
|
||||
logger.error(f"Git command failed with error: {stderr}")
|
||||
return None
|
||||
|
||||
# Convert the timestamp to a datetime object
|
||||
timestamp = datetime.fromtimestamp(int(timestamp_str), tz=UTC)
|
||||
return timestamp.strftime("%Y%m%d%H%M%S")
|
||||
|
||||
except Exception:
|
||||
logger.exception("An error occurred while retrieving the git changeset.")
|
||||
return None
|
|
@ -1,35 +1,79 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
# Custom Plugins
|
||||
|
||||
With the aMQTT Broker plugins framework, one can add additional functionality to the broker without
|
||||
having to subclass or rewrite any of the core broker logic. To define a custom list of plugins to be loaded,
|
||||
add this section to your `pyproject.toml`"
|
||||
With the aMQTT plugins framework, one can add additional functionality to the client or broker without
|
||||
having to rewrite any of the core logic.
|
||||
|
||||
```toml
|
||||
[project.entry-points."mypackage.mymodule.plugins"]
|
||||
plugin_alias = "module.submodule.file:ClassName"
|
||||
```
|
||||
|
||||
and specify the namespace when instantiating the broker:
|
||||
To create a custom plugin, subclass from `BasePlugin` (client or broker) or `BaseAuthPlugin` (broker only)
|
||||
or `BaseTopicPlugin` (broker only). Each custom plugin may define settings specific to itself by creating
|
||||
a nested (or inner) `dataclass` named `Config` which declares each option and a default value (if applicable). A
|
||||
plugin's configuration dataclass will be type-checked and made available from within the `self.context` instance variable.
|
||||
|
||||
```python
|
||||
from amqtt.broker import Broker
|
||||
from dataclasses import dataclass, field
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.contexts import BaseContext
|
||||
|
||||
broker = Broker(plugin_namespace='mypackage.mymodule.plugins')
|
||||
|
||||
class OneClassName(BasePlugin[BaseContext]):
|
||||
"""This is a plugin with no functionality"""
|
||||
|
||||
|
||||
class TwoClassName(BasePlugin[BaseContext]):
|
||||
"""This is a plugin with configuration options."""
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
my_option_one: str = self.context.config.option1
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
option1: int
|
||||
option3: str = field(default="my_default_value")
|
||||
|
||||
```
|
||||
|
||||
Each plugin has access to the full configuration file through the provided `BaseContext` and can define
|
||||
its own variables to configure its behavior.
|
||||
This plugin class then should be added to the configuration file of the broker or client (or to the `config`
|
||||
dictionary passed to the `Broker` or `MQTTClient`).
|
||||
|
||||
```yaml
|
||||
...
|
||||
...
|
||||
plugins:
|
||||
- module.submodule.file.OneClassName:
|
||||
- module.submodule.file.TwoClassName:
|
||||
option1: 123
|
||||
```
|
||||
|
||||
??? warning "Deprecated: activating plugins using `EntryPoints`"
|
||||
With the aMQTT plugins framework, one can add additional functionality to the client or broker without
|
||||
having to rewrite any of the core logic. To define a custom list of plugins to be loaded, add this section
|
||||
to your `pyproject.toml`"
|
||||
|
||||
```toml
|
||||
[project.entry-points."mypackage.mymodule.plugins"]
|
||||
plugin_alias = "module.submodule.file:ClassName"
|
||||
```
|
||||
|
||||
Each plugin has access to the full configuration file through the provided `BaseContext` and can define its own
|
||||
variables to configure its behavior.
|
||||
|
||||
::: amqtt.plugins.base.BasePlugin
|
||||
|
||||
## Events
|
||||
|
||||
Plugins that are defined in the`project.entry-points` are notified of events if the subclass
|
||||
implements one or more of these methods:
|
||||
All plugins are notified of events if the `BasePlugin` subclass implements one or more of these methods:
|
||||
|
||||
- `async def on_mqtt_packet_sent(self, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
- `async def on_mqtt_packet_received(self, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
### Client and Broker
|
||||
|
||||
- `async def on_mqtt_packet_sent(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
|
||||
### Client Only
|
||||
|
||||
none
|
||||
|
||||
### Broker Only
|
||||
|
||||
- `async def on_broker_pre_start() -> None`
|
||||
- `async def on_broker_post_start() -> None`
|
||||
|
@ -47,32 +91,22 @@ implements one or more of these methods:
|
|||
|
||||
## Authentication Plugins
|
||||
|
||||
Of the plugins listed in `project.entry-points`, one or more can be used to validate client sessions
|
||||
by specifying their alias in `auth` > `plugins` section of the config:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
plugins:
|
||||
- plugin_alias_name
|
||||
```
|
||||
|
||||
These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method.
|
||||
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`
|
||||
is used by the aMQTT `Broker` to determine if a connection from a client is allowed by
|
||||
implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise.
|
||||
|
||||
::: amqtt.plugins.base.BaseAuthPlugin
|
||||
|
||||
## Topic Filter Plugins
|
||||
|
||||
Of the plugins listed in `project.entry-points`, one or more can be used to determine topic access
|
||||
by specifying their alias in `topic-check` > `plugins` section of the config:
|
||||
|
||||
```yaml
|
||||
topic-check:
|
||||
enable: True
|
||||
plugins:
|
||||
- plugin_alias_name
|
||||
```
|
||||
|
||||
These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method.
|
||||
|
||||
In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin`
|
||||
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE)
|
||||
messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or
|
||||
`False` otherwise.
|
||||
|
||||
::: amqtt.plugins.base.BaseTopicPlugin
|
||||
|
||||
|
||||
!!! note
|
||||
A custom plugin class can subclass from both `BaseAuthPlugin` and `BaseTopicPlugin` as long it defines
|
||||
both the `authenticate` and `topic_filtering` method.
|
||||
|
|
|
@ -1,48 +1,99 @@
|
|||
# Existing Plugins
|
||||
|
||||
With the aMQTT Broker plugins framework, one can add additional functionality without
|
||||
having to rewrite core logic. Plugins loaded by default are specified in `pyproject.toml`:
|
||||
With the aMQTT plugins framework, one can add additional functionality without
|
||||
having to rewrite core logic in the broker or client. Plugins can be loaded and configured using
|
||||
the `plugins` section of the config file (or parameter passed to the class).
|
||||
|
||||
|
||||
## Broker
|
||||
|
||||
By default, `EventLoggerPlugin`, `PacketLoggerPlugin`, `AnonymousAuthPlugin` and `BrokerSysPlugin` are activated
|
||||
and configured for the broker:
|
||||
|
||||
```yaml
|
||||
--8<-- "pyproject.toml:included"
|
||||
--8<-- "amqtt/scripts/default_broker.yaml"
|
||||
```
|
||||
|
||||
## auth_anonymous (Auth Plugin)
|
||||
|
||||
`amqtt.plugins.authentication:AnonymousAuthPlugin`
|
||||
??? warning "Loading plugins from EntryPoints in `pyproject.toml` has been deprecated"
|
||||
|
||||
Previously, all plugins were loaded from EntryPoints:
|
||||
|
||||
```toml
|
||||
--8<-- "pyproject.toml:included"
|
||||
```
|
||||
|
||||
But the same 4 plugins were activated in the previous default config:
|
||||
|
||||
```yaml
|
||||
--8<-- "samples/legacy.yaml"
|
||||
```
|
||||
|
||||
## Client
|
||||
|
||||
By default, the `PacketLoggerPlugin` is activated and configured for the client:
|
||||
|
||||
```yaml
|
||||
--8<-- "amqtt/scripts/default_client.yaml"
|
||||
```
|
||||
|
||||
## Plugins
|
||||
|
||||
### Anonymous (Auth Plugin)
|
||||
|
||||
`amqtt.plugins.authentication.AnonymousAuthPlugin`
|
||||
|
||||
**Configuration**
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
plugins:
|
||||
- auth_anonymous
|
||||
allow-anonymous: true # if false, providing a username will allow access
|
||||
plugins:
|
||||
- ...
|
||||
- amqtt.plugins.authentication.AnonymousAuthPlugin:
|
||||
allow_anonymous: false
|
||||
- ...
|
||||
|
||||
```
|
||||
|
||||
!!! danger
|
||||
even if `allow-anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
|
||||
even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
|
||||
|
||||
|
||||
## auth_file (Auth Plugin)
|
||||
??? warning "EntryPoint-style configuration is deprecated"
|
||||
|
||||
`amqtt.plugins.authentication:FileAuthPlugin`
|
||||
```yaml
|
||||
auth:
|
||||
plugins:
|
||||
- auth_anonymous
|
||||
allow-anonymous: true # if false, providing a username will allow access
|
||||
|
||||
```
|
||||
|
||||
### Password File (Auth Plugin)
|
||||
|
||||
`amqtt.plugins.authentication.FileAuthPlugin`
|
||||
|
||||
clients are authorized by providing username and password, compared against file
|
||||
|
||||
**Configuration**
|
||||
|
||||
```yaml
|
||||
|
||||
auth:
|
||||
plugins:
|
||||
- auth_file
|
||||
password-file: /path/to/password_file
|
||||
|
||||
plugins:
|
||||
- ...
|
||||
- amqtt.plugins.authentication.FileAuthPlugin:
|
||||
password_file: /path/to/password_file
|
||||
- ...
|
||||
```
|
||||
|
||||
??? warning "EntryPoint-style configuration is deprecated"
|
||||
```yaml
|
||||
|
||||
auth:
|
||||
plugins:
|
||||
- auth_file
|
||||
password-file: /path/to/password_file
|
||||
|
||||
```
|
||||
|
||||
**File Format**
|
||||
|
||||
The file includes `username:password` pairs, one per line.
|
||||
|
@ -58,33 +109,42 @@ passwd = input() if not sys.stdin.isatty() else getpass()
|
|||
print(sha512_crypt.hash(passwd))
|
||||
```
|
||||
|
||||
## Taboo (Topic Plugin)
|
||||
### Taboo (Topic Plugin)
|
||||
|
||||
`amqtt.plugins.topic_checking:TopicTabooPlugin`
|
||||
`amqtt.plugins.topic_checking.TopicTabooPlugin`
|
||||
|
||||
Prevents using topics named: `prohibited`, `top-secret`, and `data/classified`
|
||||
|
||||
**Configuration**
|
||||
|
||||
```yaml
|
||||
topic-check:
|
||||
enabled: true
|
||||
plugins:
|
||||
- topic_taboo
|
||||
plugins:
|
||||
- ...
|
||||
- amqtt.plugins.topic_checking.TopicTabooPlugin:
|
||||
- ...
|
||||
```
|
||||
|
||||
## ACL (Topic Plugin)
|
||||
??? warning "EntryPoint-style configuration is deprecated"
|
||||
|
||||
`amqtt.plugins.topic_checking:TopicAccessControlListPlugin`
|
||||
```yaml
|
||||
topic-check:
|
||||
enabled: true
|
||||
plugins:
|
||||
- topic_taboo
|
||||
```
|
||||
|
||||
### ACL (Topic Plugin)
|
||||
|
||||
`amqtt.plugins.topic_checking.TopicAccessControlListPlugin`
|
||||
|
||||
**Configuration**
|
||||
|
||||
- `acl` *(list)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access.
|
||||
- `acl` *(mapping)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access.
|
||||
The list should be a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
|
||||
- `publish-acl` *(list)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where:
|
||||
- `publish-acl` *(mapping)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
!!! info "Reserved usernames"
|
||||
|
@ -93,47 +153,89 @@ topic-check:
|
|||
- The username `anonymous` will control allowed topics, if using the `auth_anonymous` plugin.
|
||||
|
||||
```yaml
|
||||
topic-check:
|
||||
enabled: true
|
||||
plugins:
|
||||
- topic_acl
|
||||
publish-acl:
|
||||
- username: ["list", "of", "allowed", "topics", "for", "publishing"]
|
||||
- .
|
||||
acl:
|
||||
- username: ["list", "of", "allowed", "topics", "for", "subscribing"]
|
||||
- .
|
||||
plugins:
|
||||
- ...
|
||||
- amqtt.plugins.topic_checking.TopicAccessControlListPlugin:
|
||||
publish_acl:
|
||||
- username: ["list", "of", "allowed", "topics", "for", "publishing"]
|
||||
acl:
|
||||
- username: ["list", "of", "allowed", "topics", "for", "subscribing"]
|
||||
- ...
|
||||
```
|
||||
|
||||
## Plugin: $SYS
|
||||
??? warning "EntryPoint-style configuration is deprecated"
|
||||
```yaml
|
||||
topic-check:
|
||||
enabled: true
|
||||
plugins:
|
||||
- topic_acl
|
||||
publish-acl:
|
||||
- username: ["list", "of", "allowed", "topics", "for", "publishing"]
|
||||
- .
|
||||
acl:
|
||||
- username: ["list", "of", "allowed", "topics", "for", "subscribing"]
|
||||
- .
|
||||
```
|
||||
|
||||
`amqtt.plugins.sys.broker:BrokerSysPlugin`
|
||||
### $SYS topics
|
||||
|
||||
`amqtt.plugins.sys.broker.BrokerSysPlugin`
|
||||
|
||||
Publishes, on a periodic basis, statistics about the broker
|
||||
|
||||
**Configuration**
|
||||
|
||||
- `sys_interval` - int, seconds between updates
|
||||
```yaml
|
||||
plugins:
|
||||
- ...
|
||||
- amqtt.plugins.sys.broker.BrokerSysPlugin:
|
||||
sys_interval: 20 # int, seconds between updates
|
||||
- ...
|
||||
```
|
||||
|
||||
**Supported Topics**
|
||||
|
||||
- `$SYS/broker/version` - payload: `str`
|
||||
- `$SYS/broker/load/bytes/received` - payload: `int`
|
||||
- `$SYS/broker/load/bytes/sent` - payload: `int`
|
||||
- `$SYS/broker/messages/received` - payload: `int`
|
||||
- `$SYS/broker/messages/sent` - payload: `int`
|
||||
- `$SYS/broker/time` - payload: `int` (current time, epoch seconds)
|
||||
- `$SYS/broker/uptime` - payload: `int` (seconds since broker start)
|
||||
- `$SYS/broker/uptime/formatted` - payload: `str` (start time of broker in UTC)
|
||||
- `$SYS/broker/clients/connected` - payload: `int` (current number of connected clients)
|
||||
- `$SYS/broker/clients/disconnected` - payload: `int` (number of clients that have disconnected)
|
||||
- `$SYS/broker/clients/maximum` - payload: `int`
|
||||
- `$SYS/broker/clients/total` - payload: `int`
|
||||
- `$SYS/broker/messages/inflight` - payload: `int`
|
||||
- `$SYS/broker/messages/inflight/in` - payload: `int`
|
||||
- `$SYS/broker/messages/inflight/out` - payload: `int`
|
||||
- `$SYS/broker/messages/inflight/stored` - payload: `int`
|
||||
- `$SYS/broker/messages/publish/received` - payload: `int`
|
||||
- `$SYS/broker/messages/publish/sent` - payload: `int`
|
||||
- `$SYS/broker/messages/retained/count` - payload: `int`
|
||||
- `$SYS/broker/messages/subscriptions/count` - payload: `int`
|
||||
- `$SYS/broker/version` *(string)*
|
||||
- `$SYS/broker/load/bytes/received` *(int)*
|
||||
- `$SYS/broker/load/bytes/sent` *(int)*
|
||||
- `$SYS/broker/messages/received` *(int)*
|
||||
- `$SYS/broker/messages/sent` *(int)*
|
||||
- `$SYS/broker/time` *(int, current time in epoch seconds)*
|
||||
- `$SYS/broker/uptime` *(int, seconds since broker start)*
|
||||
- `$SYS/broker/uptime/formatted` *(string, start time of broker in UTC)*
|
||||
- `$SYS/broker/clients/connected` *(int, number of currently connected clients)*
|
||||
- `$SYS/broker/clients/disconnected` *(int, number of clients that have disconnected)*
|
||||
- `$SYS/broker/clients/maximum` *(int, maximum number of clients connected)*
|
||||
- `$SYS/broker/clients/total` *(int)*
|
||||
- `$SYS/broker/messages/inflight` *(int)*
|
||||
- `$SYS/broker/messages/inflight/in` *(int)*
|
||||
- `$SYS/broker/messages/inflight/out` *(int)*
|
||||
- `$SYS/broker/messages/inflight/stored` *(int)*
|
||||
- `$SYS/broker/messages/publish/received` *(int)*
|
||||
- `$SYS/broker/messages/publish/sent` *(int)*
|
||||
- `$SYS/broker/messages/retained/count` *(int)*
|
||||
- `$SYS/broker/messages/subscriptions/count` *(int)*
|
||||
- `$SYS/broker/heap/size` *(float, MB)*
|
||||
- `$SYS/broker/heap/maximum` *(float, MB)*
|
||||
- `$SYS/broker/cpu/percent` *(float, %)*
|
||||
- `$SYS/broker/cpu/maximum` *(float, %)*
|
||||
|
||||
|
||||
### Event Logger
|
||||
|
||||
`amqtt.plugins.logging_amqtt.EventLoggerPlugin`
|
||||
|
||||
This plugin issues log messages when [broker and mqtt events](custom_plugins.md#events) are triggered:
|
||||
|
||||
- info level messages for `client connected` and `client disconnected`
|
||||
- debug level for all others
|
||||
|
||||
|
||||
### Packet Logger
|
||||
|
||||
`amqtt.plugins.logging_amqtt.PacketLoggerPlugin`
|
||||
|
||||
This plugin issues debug-level messages for [mqtt events](custom_plugins.md#client-and-broker): `on_mqtt_packet_sent`
|
||||
and `on_mqtt_packet_received`.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
This configuration structure is valid as a python dictionary passed to the `amqtt.broker.Broker` class's `__init__` method or
|
||||
as a yaml formatted file passed to the `amqtt` script.
|
||||
|
||||
### `listeners` *(list[mapping])*
|
||||
### `listeners` *(list[dict[str, Any]])*
|
||||
|
||||
Defines the network listeners used by the service. Items defined in the `default` listener will be
|
||||
applied to all other listeners, unless they are overridden by the configuration for the specific
|
||||
|
@ -20,77 +20,77 @@ listener.
|
|||
- `certfile` *(string)*: Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the certificate's authenticity.
|
||||
- `keyfile` *(string): A file containing the private key. Otherwise the private key will be taken from `certfile` as well.
|
||||
|
||||
### `sys_interval` *(int)*
|
||||
|
||||
System status report interval in seconds (`broker_sys` plugin)
|
||||
|
||||
### `timeout-disconnect-delay` *(int)*
|
||||
|
||||
Client disconnect timeout without a keep-alive
|
||||
|
||||
|
||||
### `auth` *(mapping)*
|
||||
### `plugins` *(mapping)*
|
||||
|
||||
Configuration for authentication behaviour:
|
||||
|
||||
- `plugins` *(list[string])*: defines the list of plugins which are activated as authentication plugins.
|
||||
|
||||
!!! note "Entry points"
|
||||
Plugins used here must first be defined in the `amqtt.broker.plugins` [entry point](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata).
|
||||
A list of strings representing the modules and class name of `BasePlugin`, `BaseAuthPlugin` and `BaseTopicPlugins`. Each
|
||||
entry may have one or more configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md)
|
||||
|
||||
|
||||
!!! danger "Legacy behavior"
|
||||
if `plugins` is omitted from the `auth` section, all plugins listed in the `amqtt.broker.plugins` entrypoint will be enabled
|
||||
for authentication, *including allowing anonymous login.*
|
||||
|
||||
`plugins: []` will deny connections from all clients.
|
||||
|
||||
- `allow-anonymous` *(bool)*: `True` will allow anonymous connections.
|
||||
|
||||
*Used by the internal `amqtt.plugins.authentication.AnonymousAuthPlugin` plugin*
|
||||
??? warning "Deprecated: `sys_interval` "
|
||||
**`sys_interval`** *(int)*
|
||||
|
||||
!!! danger "Username only connections"
|
||||
`False` does not disable the `auth_anonymous` plugin; connections will still be allowed as long as a username is provided.
|
||||
|
||||
If security is required, do not include `auth_anonymous` in the `plugins` list.
|
||||
System status report interval in seconds, used by the `amqtt.plugins.sys.broker.BrokerSysPlugin`.
|
||||
|
||||
|
||||
|
||||
- `password-file` *(string)*: Path to file which includes `username:password` pair, one per line. The password should be encoded using sha-512 with `mkpasswd -m sha-512` or:
|
||||
```python
|
||||
import sys
|
||||
from getpass import getpass
|
||||
from passlib.hash import sha512_crypt
|
||||
|
||||
passwd = input() if not sys.stdin.isatty() else getpass()
|
||||
print(sha512_crypt.hash(passwd))
|
||||
```
|
||||
|
||||
*Used by the internal `amqtt.plugins.authentication.FileAuthPlugin` plugin.*
|
||||
|
||||
### `topic-check` *(mapping)*
|
||||
??? warning "Deprecated: `auth` configuration settings"
|
||||
|
||||
Configuration for access control policies for publishing and subscribing to topics:
|
||||
**`auth`**
|
||||
|
||||
Configuration for authentication behaviour:
|
||||
|
||||
- `plugins` *(list[string])*: defines the list of plugins which are activated as authentication plugins.
|
||||
|
||||
!!! note
|
||||
Plugins used here must first be defined in the `amqtt.broker.plugins` [entry point](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata).
|
||||
|
||||
|
||||
- `enabled` *(bool)*: Enable access control policies (`true`). `false` will allow clients to publish and subscribe to any topic.
|
||||
- `plugins` *(list[string])*: defines the list of plugins which are activated as access control plugins. Note the plugins must be defined in the `amqtt.broker.plugins` [entry point](https://pythonhosted.org/setuptools/setuptools.html#dynamic-discovery-of-services-and-plugins).
|
||||
!!! warning
|
||||
If `plugins` is omitted from the `auth` section, all plugins listed in the `amqtt.broker.plugins` entrypoint will be enabled
|
||||
for authentication, including _allowing anonymous login._
|
||||
|
||||
`plugins: []` will deny connections from all clients.
|
||||
|
||||
- `allow-anonymous` *(bool)*: `True` will allow anonymous connections, used by `amqtt.plugins.authentication.AnonymousAuthPlugin`.
|
||||
|
||||
!!! danger
|
||||
`False` does not disable the `auth_anonymous` plugin; connections will still be allowed as long as a username is provided. If security is required, do not include `auth_anonymous` in the `plugins` list.
|
||||
|
||||
|
||||
- `acl` *(list)*: plugin to determine subscription access; if `publish-acl` is not specified, determine both publish and subscription access.
|
||||
The list should be a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
- `password-file` *(string)*. Path to sha-512 encoded password file, used by `amqtt.plugins.authentication.FileAuthPlugin`.
|
||||
|
||||
*used by the `amqtt.plugins.topic_acl.TopicAclPlugin`*
|
||||
|
||||
- `publish-acl` *(list)*: plugin to determine publish access. This parameter defines the list of access control rules; each item is a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
!!! info "Reserved usernames"
|
||||
|
||||
- The username `admin` is allowed access to all topic.
|
||||
- The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin.
|
||||
??? warning "Deprecated: `topic-check` configuration settings"
|
||||
|
||||
|
||||
*used by the `amqtt.plugins.topic_acl.TopicAclPlugin`*
|
||||
**`topic-check`**
|
||||
|
||||
Configuration for access control policies for publishing and subscribing to topics:
|
||||
|
||||
- `enabled` *(bool)*: Enable access control policies (`true`). `false` will allow clients to publish and subscribe to any topic.
|
||||
- `plugins` *(list[string])*: defines the list of plugins which are activated as access control plugins. Note the plugins must be defined in the `amqtt.broker.plugins` [entry point](https://pythonhosted.org/setuptools/setuptools.html#dynamic-discovery-of-services-and-plugins).
|
||||
|
||||
- `acl` *(list)*: plugin to determine subscription access; if `publish-acl` is not specified, determine both publish and subscription access.
|
||||
The list should be a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
*used by the `amqtt.plugins.topic_acl.TopicAclPlugin`*
|
||||
|
||||
- `publish-acl` *(list)*: plugin to determine publish access. This parameter defines the list of access control rules; each item is a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
_Reserved usernames (used by the `amqtt.plugins.topic_acl.TopicAclPlugin`)_
|
||||
|
||||
- The username `admin` is allowed access to all topic.
|
||||
- The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -130,14 +130,13 @@ listeners:
|
|||
certfile: /some/certfile
|
||||
keyfile: /some/key
|
||||
timeout-disconnect-delay: 2
|
||||
auth:
|
||||
plugins: ['auth_anonymous', 'auth_file']
|
||||
allow-anonymous: true
|
||||
password-file: /some/password-file
|
||||
topic-check:
|
||||
enabled: true
|
||||
plugins: ['topic_acl']
|
||||
acl:
|
||||
plugins:
|
||||
- amqtt.plugins.authentication.AnonymousAuthPlugin:
|
||||
allow-anonymous: true
|
||||
- amqtt.plugin.authentication.FileAuthPlugin:
|
||||
password-file: /some/password-file
|
||||
- amqtt.plugins.topic_checking.TopicAccessControlListPlugin:
|
||||
acl:
|
||||
username1: ['repositories/+/master', 'calendar/#', 'data/memes']
|
||||
username2: [ 'calendar/2025/#', 'data/memes']
|
||||
anonymous: ['calendar/2025/#']
|
||||
|
|
|
@ -68,8 +68,8 @@ TLS certificates used to verify the broker's authenticity.
|
|||
- `cafile` *(string)*: Path to a file of concatenated CA certificates in PEM format. See [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info.
|
||||
- `capath` *(string)*: Path to a directory containing several CA certificates in PEM format, following an [OpenSSL specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/).
|
||||
- `cadata` *(string)*: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
|
||||
-
|
||||
-
|
||||
|
||||
|
||||
### `certfile` *(string)*
|
||||
|
||||
Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the server certificate's authenticity.
|
||||
|
@ -78,6 +78,11 @@ Path to a single file in PEM format containing the certificate as well as any nu
|
|||
|
||||
Bypass ssl host certificate verification, allowing self-signed certificates
|
||||
|
||||
### `plugins` *(mapping)*
|
||||
|
||||
A list of strings representing the modules and class name of any `BasePlugin`s. Each entry may have one or more
|
||||
configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md)
|
||||
|
||||
|
||||
## Default Configuration
|
||||
|
||||
|
@ -110,4 +115,7 @@ will:
|
|||
broker:
|
||||
uri: mqtt://localhost:1883
|
||||
cafile: /path/to/ca/file
|
||||
plugins:
|
||||
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
|
||||
```
|
||||
|
|
|
@ -33,6 +33,7 @@ dependencies = [
|
|||
"passlib==1.7.4", # https://pypi.org/project/passlib
|
||||
"PyYAML==6.0.2", # https://pypi.org/project/PyYAML
|
||||
"typer==0.15.4",
|
||||
"dacite>=1.9.2",
|
||||
"psutil>=7.0.0",
|
||||
]
|
||||
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
sys_interval: 20
|
||||
auth:
|
||||
plugins:
|
||||
- auth_anonymous
|
||||
allow-anonymous: true
|
||||
topic-check:
|
||||
enabled: False
|
||||
|
|
@ -50,6 +50,11 @@ test_config_acl: dict[str, int | dict[str, Any]] = {
|
|||
@pytest.fixture
|
||||
def mock_plugin_manager():
|
||||
with unittest.mock.patch("amqtt.broker.PluginManager") as plugin_manager:
|
||||
plugin_manager_instance = plugin_manager.return_value
|
||||
|
||||
# disable topic filtering when using the mock manager
|
||||
plugin_manager_instance.is_topic_filtering_enabled.return_value = False
|
||||
|
||||
yield plugin_manager
|
||||
|
||||
|
||||
|
|
|
@ -2,10 +2,8 @@ import logging
|
|||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from amqtt.broker import Action
|
||||
|
||||
from amqtt.plugins.base import BasePlugin, BaseTopicPlugin, BaseAuthPlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.plugins.base import BasePlugin, BaseAuthPlugin, BaseTopicPlugin
|
||||
from amqtt.contexts import BaseContext, Action
|
||||
|
||||
from amqtt.session import Session
|
||||
|
||||
|
@ -29,25 +27,43 @@ class TestConfigPlugin(BasePlugin):
|
|||
option2: str
|
||||
|
||||
|
||||
class AuthPlugin(BaseAuthPlugin):
|
||||
class TestCoroErrorPlugin(BaseAuthPlugin):
|
||||
|
||||
def authenticate(self, *, session: Session) -> bool | None:
|
||||
return True
|
||||
|
||||
|
||||
class TestAuthPlugin(BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return True
|
||||
|
||||
|
||||
class NoAuthPlugin(BaseAuthPlugin):
|
||||
class TestNoAuthPlugin(BaseAuthPlugin):
|
||||
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return False
|
||||
|
||||
|
||||
class TestTopicPlugin(BaseTopicPlugin):
|
||||
class TestAllowTopicPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
def topic_filtering(
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TestBlockTopicPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
logger.debug("topic filtering plugin is returning false")
|
||||
return False
|
||||
|
|
|
@ -3,19 +3,42 @@ import logging
|
|||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from amqtt.plugins.authentication import AnonymousAuthPlugin, FileAuthPlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.plugins.base import BaseAuthPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_no_config(logdog):
|
||||
"""Check BaseTopicPlugin returns false if no topic-check is present."""
|
||||
with logdog() as pile:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {}
|
||||
|
||||
plugin = BaseAuthPlugin(context)
|
||||
s = Session()
|
||||
authorised = await plugin.authenticate(session=s)
|
||||
assert authorised is False
|
||||
|
||||
# Warning messages are only generated if using deprecated plugin configuration on initial load
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 1
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'auth' section not found in context configuration"
|
||||
|
||||
|
||||
class TestAnonymousAuthPlugin(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
|
||||
def test_allow_anonymous(self) -> None:
|
||||
def test_allow_anonymous_dict_config(self) -> None:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
context.config = {"auth": {"allow-anonymous": True}}
|
||||
|
@ -25,6 +48,16 @@ class TestAnonymousAuthPlugin(unittest.TestCase):
|
|||
ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s))
|
||||
assert ret
|
||||
|
||||
def test_allow_anonymous_dataclass_config(self) -> None:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
context.config = AnonymousAuthPlugin.Config(allow_anonymous=True)
|
||||
s = Session()
|
||||
s.username = ""
|
||||
auth_plugin = AnonymousAuthPlugin(context)
|
||||
ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s))
|
||||
assert ret
|
||||
|
||||
def test_disallow_anonymous(self) -> None:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -0,0 +1,219 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from yaml import CLoader as Loader
|
||||
from dacite import from_dict, Config, UnexpectedDataError
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import PluginLoadError, ConnectError, PluginCoroError
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
plugin_config = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestSimplePlugin:
|
||||
- tests.plugins.mocks.TestConfigPlugin:
|
||||
option1: 1
|
||||
option2: bar
|
||||
"""
|
||||
|
||||
|
||||
plugin_invalid_config_one = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestSimplePlugin:
|
||||
option1: 1
|
||||
option2: bar
|
||||
"""
|
||||
|
||||
plugin_invalid_config_two = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestConfigPlugin:
|
||||
"""
|
||||
|
||||
plugin_coro_error_config = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestCoroErrorPlugin:
|
||||
"""
|
||||
|
||||
plugin_config_auth = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestAuthPlugin:
|
||||
"""
|
||||
|
||||
plugin_config_no_auth = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestNoAuthPlugin:
|
||||
"""
|
||||
|
||||
|
||||
plugin_config_topic = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestAllowTopicPlugin:
|
||||
"""
|
||||
|
||||
|
||||
plugin_config_topic_block = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestBlockTopicPlugin:
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_config_extra_fields():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginLoadError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_config_missing_fields():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginLoadError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternate_plugin_load():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
|
||||
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coro_error_plugin_load():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_coro_error_config, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginCoroError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
client1 = MQTTClient()
|
||||
await client1.connect()
|
||||
await client1.publish('my/topic', b'my message')
|
||||
await client1.disconnect()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_auth_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_no_auth, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
client1 = MQTTClient(config={'auto_reconnect': False})
|
||||
with pytest.raises(ConnectError):
|
||||
await client1.connect()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_topic_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
client2 = MQTTClient(config={'auto_reconnect': False})
|
||||
await client2.connect()
|
||||
await client2.subscribe([
|
||||
('my/topic', QOS_0)
|
||||
])
|
||||
|
||||
client1 = MQTTClient(config={'auto_reconnect': True})
|
||||
await client1.connect()
|
||||
await client1.publish('my/topic', b'my message')
|
||||
|
||||
message = await client2.deliver_message(timeout_duration=1)
|
||||
assert message.topic == 'my/topic'
|
||||
assert message.data == b'my message'
|
||||
|
||||
await client2.disconnect()
|
||||
await client1.disconnect()
|
||||
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_topic_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_topic_block, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
client2 = MQTTClient(config={'auto_reconnect': False})
|
||||
await client2.connect()
|
||||
await client2.subscribe([
|
||||
('my/topic', QOS_0)
|
||||
])
|
||||
|
||||
client1 = MQTTClient(config={'auto_reconnect': True})
|
||||
await client1.connect()
|
||||
await client1.publish('my/topic', b'my message')
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
message = await client2.deliver_message(timeout_duration=1)
|
||||
logger.debug(f"msg received: {message.topic} >> {message.data}")
|
||||
|
||||
await client2.disconnect()
|
||||
await client1.disconnect()
|
||||
|
||||
await broker.shutdown()
|
|
@ -2,10 +2,11 @@ import asyncio
|
|||
import logging
|
||||
import unittest
|
||||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.events import BrokerEvents
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.plugins.base import BaseTopicPlugin, BaseAuthPlugin
|
||||
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.contexts import BaseContext, Action
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
|
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
import sqlite3
|
||||
import unittest
|
||||
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.plugins.persistence import SQLitePlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
|
|
|
@ -13,11 +13,11 @@ import pytest
|
|||
import amqtt.plugins
|
||||
from amqtt.broker import Broker, BrokerContext
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import PluginError, PluginInitError, PluginImportError
|
||||
from amqtt.errors import PluginInitError, PluginImportError
|
||||
from amqtt.events import MQTTEvents, BrokerEvents
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.contexts import BaseContext
|
||||
|
||||
_INVALID_METHOD: str = "invalid_foo"
|
||||
_PLUGIN: str = "Plugin"
|
||||
|
@ -82,54 +82,36 @@ class MockInitErrorPlugin(BasePlugin):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_while_init() -> None:
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'plugins':{
|
||||
'tests.plugins.test_plugins.MockInitErrorPlugin':{}
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(PluginInitError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
with pytest.raises(PluginInitError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_while_loading() -> None:
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'plugins':{
|
||||
'tests.plugins.mock_plugins.MockImportErrorPlugin':{}
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(PluginImportError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
with pytest.raises(PluginImportError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
||||
|
||||
class AllEventsPlugin(BasePlugin[BaseContext]):
|
||||
|
@ -153,47 +135,37 @@ class AllEventsPlugin(BasePlugin[BaseContext]):
|
|||
if name not in ('authenticate', 'topic_filtering'):
|
||||
pytest.fail(f'unexpected method called: {name}')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_plugin_events():
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='AllEventsPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:AllEventsPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
# patch the entry points so we can load our test plugin
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'plugins':{
|
||||
'tests.plugins.test_plugins.AllEventsPlugin': {}
|
||||
}
|
||||
}
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
await broker.start()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await broker.start()
|
||||
await asyncio.sleep(2)
|
||||
# make sure all expected events get triggered
|
||||
client = MQTTClient()
|
||||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
await client.subscribe([('my/test/topic', QOS_0),])
|
||||
await client.publish('test/topic', b'my test message')
|
||||
await client.unsubscribe(['my/test/topic',])
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# make sure all expected events get triggered
|
||||
client = MQTTClient()
|
||||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
await client.subscribe([('my/test/topic', QOS_0),])
|
||||
await client.publish('test/topic', b'my test message')
|
||||
await client.unsubscribe(['my/test/topic',])
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(1)
|
||||
# get the plugin so it doesn't get gc on shutdown
|
||||
test_plugin = broker.plugins_manager.get_plugin('AllEventsPlugin')
|
||||
await broker.shutdown()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# get the plugin so it doesn't get gc on shutdown
|
||||
test_plugin = broker.plugins_manager.get_plugin('AllEventsPlugin')
|
||||
await broker.shutdown()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
|
||||
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from importlib.metadata import EntryPoint
|
||||
from logging.config import dictConfig
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
@ -9,8 +10,34 @@ from amqtt.broker import Broker
|
|||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
|
||||
dictConfig({
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'verbose': {
|
||||
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
|
||||
}
|
||||
},
|
||||
'handlers': {
|
||||
'console': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': 'DEBUG',
|
||||
'formatter': 'verbose',
|
||||
}
|
||||
},
|
||||
'loggers': {
|
||||
'transitions': {
|
||||
'level': 'WARNING',
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
# logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
all_sys_topics = [
|
||||
'$SYS/broker/version',
|
||||
'$SYS/broker/load/bytes/received',
|
||||
|
@ -42,7 +69,7 @@ all_sys_topics = [
|
|||
|
||||
# test broker sys
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_sys_plugin() -> None:
|
||||
async def test_broker_sys_plugin_deprecated_config() -> None:
|
||||
|
||||
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
|
||||
|
||||
|
@ -64,7 +91,8 @@ async def test_broker_sys_plugin() -> None:
|
|||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
'sys_interval': 1,
|
||||
'auth': {}
|
||||
}
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
@ -92,3 +120,45 @@ async def test_broker_sys_plugin() -> None:
|
|||
assert sys_msg_count > 1
|
||||
|
||||
assert all(sys_topic_flags.values()), f'topic not received: {[ topic for topic, flag in sys_topic_flags.items() if not flag ]}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_sys_plugin_config() -> None:
|
||||
|
||||
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'plugins': [
|
||||
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
|
||||
]
|
||||
}
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
await broker.start()
|
||||
client = MQTTClient()
|
||||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
await client.subscribe([("$SYS/#", QOS_0), ])
|
||||
await client.publish('test/topic', b'my test message')
|
||||
await asyncio.sleep(2)
|
||||
sys_msg_count = 0
|
||||
try:
|
||||
while sys_msg_count < 30:
|
||||
message = await client.deliver_message(timeout_duration=1)
|
||||
if '$SYS' in message.topic:
|
||||
sys_msg_count += 1
|
||||
assert message.topic in sys_topic_flags
|
||||
sys_topic_flags[message.topic] = True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug(f"TimeoutError after {sys_msg_count} messages")
|
||||
|
||||
await client.disconnect()
|
||||
await broker.shutdown()
|
||||
|
||||
assert sys_msg_count > 1
|
||||
|
||||
assert all(
|
||||
sys_topic_flags.values()), f'topic not received: {[topic for topic, flag in sys_topic_flags.items() if not flag]}'
|
||||
|
|
|
@ -2,12 +2,15 @@ import logging
|
|||
|
||||
import pytest
|
||||
|
||||
from amqtt.broker import Action, BrokerContext, Broker
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.broker import BrokerContext, Broker
|
||||
|
||||
from amqtt.contexts import BaseContext, Action
|
||||
from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabooPlugin
|
||||
from amqtt.plugins.base import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base plug-in object
|
||||
|
||||
|
||||
|
@ -23,15 +26,11 @@ async def test_base_no_config(logdog):
|
|||
authorised = await plugin.topic_filtering()
|
||||
assert authorised is False
|
||||
|
||||
# Should have printed a couple of warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 2
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
|
||||
assert log_records[1].levelno == logging.WARNING
|
||||
assert log_records[1].message == "'topic-check' section not found in context configuration"
|
||||
assert pile.is_empty()
|
||||
# Warning messages are only generated if using deprecated plugin configuration on initial load
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 1
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -47,14 +46,11 @@ async def test_base_empty_config(logdog):
|
|||
authorised = await plugin.topic_filtering()
|
||||
assert authorised is False
|
||||
|
||||
# Should have printed just one warning
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 2
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
|
||||
assert log_records[1].levelno == logging.WARNING
|
||||
assert log_records[1].message == "'topic-check' section not found in context configuration"
|
||||
# Warning messages are only generated if using deprecated plugin configuration on initial load
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 1
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -69,9 +65,9 @@ async def test_base_disabled_config(logdog):
|
|||
authorised = await plugin.topic_filtering()
|
||||
assert authorised is True
|
||||
|
||||
# Should NOT have printed warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 0
|
||||
# Should NOT have printed warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -86,9 +82,9 @@ async def test_base_enabled_config(logdog):
|
|||
authorised = await plugin.topic_filtering()
|
||||
assert authorised is True
|
||||
|
||||
# Should NOT have printed warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 0
|
||||
# Should NOT have printed warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 0
|
||||
|
||||
|
||||
# Taboo plug-in
|
||||
|
@ -105,13 +101,11 @@ async def test_taboo_empty_config(logdog):
|
|||
plugin = TopicTabooPlugin(context)
|
||||
assert (await plugin.topic_filtering()) is False
|
||||
|
||||
# Should have printed a couple of warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 2
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
assert log_records[1].levelno == logging.WARNING
|
||||
assert log_records[1].message == "'topic-check' section not found in context configuration"
|
||||
# Warning messages are only generated if using deprecated plugin configuration on initial load
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 1
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -133,13 +127,17 @@ async def test_taboo_disabled(logdog):
|
|||
assert len(log_records) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({"topic-check": {"enabled": True}}),
|
||||
(TopicTabooPlugin.Config())
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taboo_not_taboo_topic(logdog):
|
||||
async def test_taboo_not_taboo_topic(logdog, test_config):
|
||||
"""Check TopicTabooPlugin returns true if topic not taboo."""
|
||||
with logdog() as pile:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {"topic-check": {"enabled": True}}
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "anybody"
|
||||
|
@ -152,13 +150,17 @@ async def test_taboo_not_taboo_topic(logdog):
|
|||
assert len(log_records) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({"topic-check": {"enabled": True}}),
|
||||
(TopicTabooPlugin.Config())
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taboo_anon_taboo_topic(logdog):
|
||||
async def test_taboo_anon_taboo_topic(logdog, test_config):
|
||||
"""Check TopicTabooPlugin returns false if topic is taboo and session is anonymous."""
|
||||
with logdog() as pile:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {"topic-check": {"enabled": True}}
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = ""
|
||||
|
@ -171,13 +173,17 @@ async def test_taboo_anon_taboo_topic(logdog):
|
|||
assert len(log_records) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({"topic-check": {"enabled": True}}),
|
||||
(TopicTabooPlugin.Config())
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taboo_notadmin_taboo_topic(logdog):
|
||||
async def test_taboo_notadmin_taboo_topic(logdog, test_config):
|
||||
"""Check TopicTabooPlugin returns false if topic is taboo and user is not "admin"."""
|
||||
with logdog() as pile:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {"topic-check": {"enabled": True}}
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "notadmin"
|
||||
|
@ -189,14 +195,17 @@ async def test_taboo_notadmin_taboo_topic(logdog):
|
|||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({"topic-check": {"enabled": True}}),
|
||||
(TopicTabooPlugin.Config())
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taboo_admin_taboo_topic(logdog):
|
||||
async def test_taboo_admin_taboo_topic(logdog, test_config):
|
||||
"""Check TopicTabooPlugin returns true if topic is taboo and user is "admin"."""
|
||||
with logdog() as pile:
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {"topic-check": {"enabled": True}}
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "admin"
|
||||
|
@ -265,11 +274,11 @@ async def test_taclp_empty_config(logdog):
|
|||
plugin = TopicAccessControlListPlugin(context)
|
||||
assert (await plugin.topic_filtering()) is False
|
||||
|
||||
# Should have printed a couple of warnings
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 2
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
assert log_records[1].message == "'topic-check' section not found in context configuration"
|
||||
# Warning messages are only generated if using deprecated plugin configuration on initial load
|
||||
log_records = list(pile.drain(name="testlog"))
|
||||
assert len(log_records) == 1
|
||||
assert log_records[0].levelno == logging.WARNING
|
||||
assert log_records[0].message == "'topic-check' section not found in context configuration"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -291,15 +300,19 @@ async def test_taclp_true_disabled(logdog):
|
|||
assert authorised is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({"topic-check": {"enabled": True}}),
|
||||
(TopicAccessControlListPlugin.Config())
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_no_pub_acl(logdog):
|
||||
async def test_taclp_true_no_pub_acl(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin returns true if action=publish and no publish-acl given.
|
||||
|
||||
(This is for backward-compatibility with existing installations.).
|
||||
"""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {"topic-check": {"enabled": True}}
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "user"
|
||||
|
@ -313,17 +326,23 @@ async def test_taclp_true_no_pub_acl(logdog):
|
|||
assert authorised is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_false_sub_no_topic(logdog):
|
||||
"""Check TopicAccessControlListPlugin returns false user there is no topic."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({
|
||||
"topic-check": {
|
||||
"enabled": True,
|
||||
"acl": {"anotheruser": ["allowed/topic", "another/allowed/topic/#"]},
|
||||
},
|
||||
}
|
||||
}),
|
||||
(TopicAccessControlListPlugin.Config(
|
||||
acl={"anotheruser": ["allowed/topic", "another/allowed/topic/#"]}
|
||||
))
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_false_sub_no_topic(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin returns false user there is no topic."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "user"
|
||||
|
@ -337,17 +356,23 @@ async def test_taclp_false_sub_no_topic(logdog):
|
|||
assert authorised is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_false_sub_unknown_user(logdog):
|
||||
"""Check TopicAccessControlListPlugin returns false user is not listed in ACL."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({
|
||||
"topic-check": {
|
||||
"enabled": True,
|
||||
"acl": {"anotheruser": ["allowed/topic", "another/allowed/topic/#"]},
|
||||
},
|
||||
}
|
||||
}),
|
||||
(TopicAccessControlListPlugin.Config(
|
||||
acl={"anotheruser": ["allowed/topic", "another/allowed/topic/#"]}
|
||||
))
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_false_sub_unknown_user(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin returns false user is not listed in ACL."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "user"
|
||||
|
@ -361,17 +386,23 @@ async def test_taclp_false_sub_unknown_user(logdog):
|
|||
assert authorised is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_false_sub_no_permission(logdog):
|
||||
"""Check TopicAccessControlListPlugin returns false if "acl" does not list allowed topic."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({
|
||||
"topic-check": {
|
||||
"enabled": True,
|
||||
"acl": {"user": ["allowed/topic", "another/allowed/topic/#"]},
|
||||
},
|
||||
}
|
||||
}),
|
||||
(TopicAccessControlListPlugin.Config(
|
||||
acl={"user": ["allowed/topic", "another/allowed/topic/#"]}
|
||||
))
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_false_sub_no_permission(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin returns false if "acl" does not list allowed topic."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "user"
|
||||
|
@ -384,18 +415,23 @@ async def test_taclp_false_sub_no_permission(logdog):
|
|||
)
|
||||
assert authorised is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_sub_permission(logdog):
|
||||
"""Check TopicAccessControlListPlugin returns true if "acl" lists allowed topic."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({
|
||||
"topic-check": {
|
||||
"enabled": True,
|
||||
"acl": {"user": ["allowed/topic", "another/allowed/topic/#"]},
|
||||
},
|
||||
}
|
||||
}),
|
||||
(TopicAccessControlListPlugin.Config(
|
||||
acl={"user": ["allowed/topic", "another/allowed/topic/#"]}
|
||||
))
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_sub_permission(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin returns true if "acl" lists allowed topic."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "user"
|
||||
|
@ -409,17 +445,23 @@ async def test_taclp_true_sub_permission(logdog):
|
|||
assert authorised is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_pub_permission(logdog):
|
||||
"""Check TopicAccessControlListPlugin returns true if "publish-acl" lists allowed topic for publish action."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({
|
||||
"topic-check": {
|
||||
"enabled": True,
|
||||
"publish-acl": {"user": ["allowed/topic", "another/allowed/topic/#"]},
|
||||
},
|
||||
}
|
||||
}),
|
||||
(TopicAccessControlListPlugin.Config(
|
||||
publish_acl={"user": ["allowed/topic", "another/allowed/topic/#"]}
|
||||
))
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_pub_permission(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin returns true if "publish-acl" lists allowed topic for publish action."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = "user"
|
||||
|
@ -433,17 +475,23 @@ async def test_taclp_true_pub_permission(logdog):
|
|||
assert authorised is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_anon_sub_permission(logdog):
|
||||
"""Check TopicAccessControlListPlugin handles anonymous users."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = {
|
||||
@pytest.mark.parametrize("test_config", [
|
||||
({
|
||||
"topic-check": {
|
||||
"enabled": True,
|
||||
"acl": {"anonymous": ["allowed/topic", "another/allowed/topic/#"]},
|
||||
},
|
||||
}
|
||||
}),
|
||||
(TopicAccessControlListPlugin.Config(
|
||||
acl={"anonymous": ["allowed/topic", "another/allowed/topic/#"]}
|
||||
))
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_taclp_true_anon_sub_permission(logdog, test_config):
|
||||
"""Check TopicAccessControlListPlugin handles anonymous users."""
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger("testlog")
|
||||
context.config = test_config
|
||||
|
||||
session = Session()
|
||||
session.username = None
|
||||
|
|
|
@ -463,7 +463,7 @@ async def test_client_no_auth():
|
|||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:NoAuthPlugin'),
|
||||
EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:TestNoAuthPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
import subprocess
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import warnings
|
||||
|
||||
from amqtt.version import get_git_changeset, get_version
|
||||
|
||||
|
||||
class TestVersionFunctions(unittest.TestCase):
|
||||
@patch("amqtt.version.warnings.warn")
|
||||
def test_get_version(self, mock_warn):
|
||||
"""Test get_version returns amqtt.__version__ and raises a deprecation warning."""
|
||||
with patch("amqtt.__version__", "1.2.3"):
|
||||
version = get_version()
|
||||
assert version == "1.2.3"
|
||||
mock_warn.assert_called_once_with(
|
||||
"amqtt.version.get_version() is deprecated, use amqtt.__version__ instead",
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
def test_get_version_no_warning(self):
|
||||
"""Test get_version does not trigger a warning when explicitly suppressed."""
|
||||
with patch("amqtt.__version__", "1.2.3"), warnings.catch_warnings(record=True) as captured_warnings:
|
||||
warnings.simplefilter("ignore")
|
||||
version = get_version()
|
||||
assert version == "1.2.3"
|
||||
assert len(captured_warnings) == 0 # No warnings should be captured
|
||||
|
||||
@patch("amqtt.version.Path")
|
||||
@patch("amqtt.version.shutil.which")
|
||||
@patch("amqtt.version.subprocess.Popen")
|
||||
def test_get_git_changeset(self, mock_popen, mock_which, mock_path):
|
||||
"""Test get_git_changeset returns the correct timestamp or None on failure."""
|
||||
# Mock the repo directory
|
||||
mock_repo_dir = MagicMock()
|
||||
mock_repo_dir.is_dir.return_value = True
|
||||
mock_path.return_value.resolve.return_value.parent.parent = mock_repo_dir
|
||||
|
||||
# Mock git executable check
|
||||
mock_which.return_value = True
|
||||
|
||||
# Mock subprocess.Popen for git log with context manager behavior
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = ("1638352940", "")
|
||||
mock_process.returncode = 0
|
||||
mock_popen.return_value.__enter__.return_value = mock_process
|
||||
|
||||
# Call the function
|
||||
changeset = get_git_changeset()
|
||||
|
||||
# Verify the results
|
||||
assert changeset == "20211201100220" # Matches timestamp conversion
|
||||
mock_which.assert_called_once_with("git")
|
||||
mock_popen.assert_called_once_with(
|
||||
["git", "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=mock_repo_dir,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
# Test invalid directory
|
||||
mock_repo_dir.is_dir.return_value = False
|
||||
changeset = get_git_changeset()
|
||||
assert changeset is None
|
||||
|
||||
# Test missing git
|
||||
mock_repo_dir.is_dir.return_value = True
|
||||
mock_which.return_value = False
|
||||
changeset = get_git_changeset()
|
||||
assert changeset is None
|
||||
|
||||
# Test git command failure
|
||||
mock_which.return_value = True
|
||||
mock_process.returncode = 1
|
||||
mock_process.communicate.return_value = ("", "Some error")
|
||||
changeset = get_git_changeset()
|
||||
assert changeset is None
|
11
uv.lock
11
uv.lock
|
@ -12,6 +12,7 @@ name = "amqtt"
|
|||
version = "0.11.1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "dacite" },
|
||||
{ name = "passlib" },
|
||||
{ name = "psutil" },
|
||||
{ name = "pyyaml" },
|
||||
|
@ -67,6 +68,7 @@ docs = [
|
|||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
|
||||
{ name = "dacite", specifier = ">=1.9.2" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
{ name = "pyyaml", specifier = "==6.0.2" },
|
||||
|
@ -487,6 +489,15 @@ version = "0.9.5"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" }
|
||||
|
||||
[[package]]
|
||||
name = "dacite"
|
||||
version = "1.9.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.4.0"
|
||||
|
|
Ładowanie…
Reference in New Issue