diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9cdc91..8301d3d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,75 +10,6 @@ ci: - pylint repos: - # Codespell for spelling corrections - - repo: https://github.com/codespell-project/codespell - rev: v2.4.1 - hooks: - - id: codespell - args: - - --ignore-words-list=ihs,ro,fo,assertIn,astroid,formated - - --skip="./.*,*.csv,*.json" - - --quiet-level=2 - exclude_types: - - csv - - json - - # General pre-commit hooks - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 - hooks: - - id: detect-private-key - exclude: tests/_test_files/certs/ - - id: check-merge-conflict - - id: check-added-large-files - - id: check-case-conflict - # - id: no-commit-to-branch - # args: [--branch, main] - - id: check-executables-have-shebangs - - id: trailing-whitespace - name: Trim Trailing Whitespace - description: This hook trims trailing whitespace. - entry: trailing-whitespace-fixer - language: python - types: [text] - args: [--markdown-linebreak-ext=md] - - id: check-toml - - id: check-json - - id: check-yaml - args: [--allow-multiple-documents] - - id: mixed-line-ending - - # Prettier for code formatting - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - additional_dependencies: - - prettier@3.2.5 - - prettier-plugin-sort-json@3.1.0 - exclude_types: - - python - - # Secret detection - - repo: https://github.com/Yelp/detect-secrets - rev: v1.5.0 - hooks: - - id: detect-secrets - args: - - --exclude-files=tests/* - - --exclude-files=samples/client_subscribe_acl.py - - --exclude-files=docs/quickstart.rst - - repo: https://github.com/gitleaks/gitleaks - rev: v8.26.0 - hooks: - - id: gitleaks - - # YAML Linting - - repo: https://github.com/adrienverge/yamllint.git - rev: v1.37.1 - hooks: - - id: yamllint - # Python-specific hooks ###################################################### - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.10 @@ -89,12 +20,6 @@ repos: - --unsafe-fixes - --line-length=130 - --exit-non-zero-on-fix - - id: ruff-format - - repo: https://github.com/asottile/pyupgrade - rev: v3.19.1 - hooks: - - id: pyupgrade - args: [--py313-plus] # Local hooks for mypy and pylint - repo: local diff --git a/.prettierrc.yml b/.prettierrc.yml deleted file mode 100644 index 14fd7f3..0000000 --- a/.prettierrc.yml +++ /dev/null @@ -1,2 +0,0 @@ ---- -jsonRecursiveSort: true diff --git a/.yamllint b/.yamllint deleted file mode 100644 index 9cdab86..0000000 --- a/.yamllint +++ /dev/null @@ -1,70 +0,0 @@ ---- -extends: default - -yaml-files: - - "*.yaml" - - "*.yml" - - ".yamllint" - -ignore-from-file: .gitignore - -rules: - braces: - level: error - min-spaces-inside: 0 - max-spaces-inside: 1 - min-spaces-inside-empty: -1 - max-spaces-inside-empty: -1 - brackets: - level: error - min-spaces-inside: 0 - max-spaces-inside: 1 - min-spaces-inside-empty: -1 - max-spaces-inside-empty: -1 - colons: - level: error - max-spaces-before: 0 - max-spaces-after: 1 - commas: - level: error - max-spaces-before: 0 - min-spaces-after: 1 - max-spaces-after: 1 - comments: - level: error - require-starting-space: true - min-spaces-from-content: 1 - comments-indentation: false - document-end: - level: error - present: false - document-start: - level: warning - present: true - empty-lines: - level: error - max: 1 - max-start: 0 - max-end: 1 - hyphens: - level: error - max-spaces-after: 1 - indentation: - level: error - spaces: 2 - indent-sequences: consistent - check-multi-line-strings: false - key-duplicates: - level: error - line-length: disable - new-line-at-end-of-file: - level: error - new-lines: - level: error - type: unix - trailing-spaces: - level: error - truthy: disable - octal-values: - forbid-implicit-octal: true - forbid-explicit-octal: true diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index c23549e..8b85917 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -1,31 +1,30 @@ from pathlib import Path +from typing import Any from passlib.apps import custom_app_context as pwd_context from amqtt.broker import BrokerContext +from amqtt.plugins.base import BasePlugin from amqtt.session import Session _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line -class BaseAuthPlugin: +class BaseAuthPlugin(BasePlugin): """Base class for authentication plugins.""" def __init__(self, context: BrokerContext) -> None: - self.context = context - self.auth_config = self.context.config.get("auth", None) if self.context.config else None + super().__init__(context) + + self.auth_config: dict[str, Any] | None = self._get_config_section("auth") if not self.auth_config: self.context.logger.warning("'auth' section not found in context configuration") - async def authenticate(self, *args: None, **kwargs: Session) -> bool | None: + async def authenticate(self, *, session: Session) -> bool | None: """Logic for session authentication. Args: - *args: positional arguments (not used) - **kwargs: payload from broker - ``` - session: amqtt.session.Session - ``` + session: amqtt.session.Session Returns: - `True` if user is authentication succeed, `False` if user authentication fails @@ -42,8 +41,8 @@ class BaseAuthPlugin: class AnonymousAuthPlugin(BaseAuthPlugin): """Authentication plugin allowing anonymous access.""" - async def authenticate(self, *args: None, **kwargs: Session) -> bool: - authenticated = await super().authenticate(*args, **kwargs) + async def authenticate(self, *, session: Session) -> bool: + authenticated = await super().authenticate(session=session) if authenticated: # Default to allowing anonymous allow_anonymous = self.auth_config.get("allow-anonymous", True) if isinstance(self.auth_config, dict) else True @@ -51,7 +50,6 @@ class AnonymousAuthPlugin(BaseAuthPlugin): self.context.logger.debug("Authentication success: config allows anonymous") return True - session: Session | None = kwargs.get("session") if session and session.username: self.context.logger.debug(f"Authentication success: session has username '{session.username}'") return True @@ -95,11 +93,10 @@ class FileAuthPlugin(BaseAuthPlugin): except Exception: self.context.logger.exception(f"Unexpected error reading password file '{password_file}'") - async def authenticate(self, *args: None, **kwargs: Session) -> bool | None: + async def authenticate(self, *, session: Session) -> bool | None: """Authenticate users based on the file-stored user database.""" - authenticated = await super().authenticate(*args, **kwargs) + authenticated = await super().authenticate(session=session) if authenticated: - session = kwargs.get("session") if not session: self.context.logger.debug("Authentication failure: no session provided") return False diff --git a/amqtt/plugins/base.py b/amqtt/plugins/base.py new file mode 100644 index 0000000..89f4443 --- /dev/null +++ b/amqtt/plugins/base.py @@ -0,0 +1,19 @@ +from typing import Any + +from amqtt.broker import BrokerContext + + +class BasePlugin: + """The base from which all plugins should inherit.""" + + def __init__(self, context: BrokerContext) -> None: + self.context = context + + def _get_config_section(self, name: str) -> dict[str, Any] | None: + if not self.context.config or not self.context.config.get(name, None): + return None + section_config: int | dict[str, Any] | None = self.context.config.get(name, None) + # mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check + if isinstance(section_config, int): + return None + return section_config diff --git a/amqtt/plugins/logging_amqtt.py b/amqtt/plugins/logging_amqtt.py index 42124f8..84cf6d7 100644 --- a/amqtt/plugins/logging_amqtt.py +++ b/amqtt/plugins/logging_amqtt.py @@ -3,18 +3,15 @@ from functools import partial import logging from typing import TYPE_CHECKING, Any -from amqtt.plugins.manager import BaseContext +from amqtt.plugins.base import BasePlugin if TYPE_CHECKING: from amqtt.session import Session -class EventLoggerPlugin: +class EventLoggerPlugin(BasePlugin): """A plugin to log events dynamically based on method names.""" - def __init__(self, context: BaseContext) -> None: - self.context = context - async def log_event(self, *args: Any, **kwargs: Any) -> None: """Log the occurrence of an event.""" event_name = kwargs["event_name"].replace("old", "") @@ -28,12 +25,9 @@ class EventLoggerPlugin: raise AttributeError(msg) -class PacketLoggerPlugin: +class PacketLoggerPlugin(BasePlugin): """A plugin to log MQTT packets sent and received.""" - def __init__(self, context: BaseContext) -> None: - self.context = context - async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: """Log an MQTT packet when it is received.""" packet = kwargs.get("packet") diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index 1ecd1ee..f7ede3f 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -2,6 +2,8 @@ import asyncio from collections import deque # pylint: disable=C0412 from typing import SupportsIndex, SupportsInt # pylint: disable=C0412 +from amqtt.plugins.base import BasePlugin + try: from collections.abc import Buffer except ImportError: @@ -40,9 +42,9 @@ STAT_CLIENTS_CONNECTED = "clients_connected" STAT_CLIENTS_DISCONNECTED = "clients_disconnected" -class BrokerSysPlugin: +class BrokerSysPlugin(BasePlugin): def __init__(self, context: BrokerContext) -> None: - self.context = context + super().__init__(context) # Broker statistics initialization self._stats: dict[str, int] = {} self._sys_handle: asyncio.Handle | None = None diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index e520203..b9f91ae 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -1,30 +1,29 @@ from typing import Any -from amqtt.broker import Action -from amqtt.plugins.manager import BaseContext +from amqtt.broker import Action, BrokerContext +from amqtt.plugins.base import BasePlugin +from amqtt.session import Session -class BaseTopicPlugin: +class BaseTopicPlugin(BasePlugin): """Base class for topic plugins.""" - def __init__(self, context: BaseContext) -> None: - self.context = context - self.topic_config: dict[str, Any] | None = self.context.config.get("topic-check", None) if self.context.config else None + def __init__(self, context: BrokerContext) -> None: + super().__init__(context) + + self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check") if self.topic_config is None: self.context.logger.warning("'topic-check' section not found in context configuration") - async def topic_filtering(self, *args: Any, **kwargs: Any) -> bool: + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: """Logic for filtering out topics. Args: - *args: positional arguments (not used) - - **kwargs: payload from broker - ``` - session: amqtt.session.Session - topic: str - action: amqtt.broker.Action - ``` + session: amqtt.session.Session + topic: str + action: amqtt.broker.Action Returns: bool: `True` if topic is allowed, `False` otherwise @@ -32,21 +31,21 @@ class BaseTopicPlugin: """ if not self.topic_config: # auth config section not found - self.context.logger.warning("'auth' section not found in context configuration") + self.context.logger.warning("'topic-check' section not found in context configuration") return False return True class TopicTabooPlugin(BaseTopicPlugin): - def __init__(self, context: BaseContext) -> None: + def __init__(self, context: BrokerContext) -> None: super().__init__(context) self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"] - async def topic_filtering(self, *args: Any, **kwargs: Any) -> bool: - filter_result = await super().topic_filtering(*args, **kwargs) + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + filter_result = await super().topic_filtering(session=session, topic=topic, action=action) if filter_result: - session = kwargs.get("session") - topic = kwargs.get("topic") if session and session.username == "admin": return True return not (topic and topic in self._taboo) @@ -54,6 +53,7 @@ class TopicTabooPlugin(BaseTopicPlugin): class TopicAccessControlListPlugin(BaseTopicPlugin): + @staticmethod def topic_ac(topic_requested: str, topic_allowed: str) -> bool: req_split = topic_requested.split("/") @@ -74,22 +74,22 @@ class TopicAccessControlListPlugin(BaseTopicPlugin): break return ret - async def topic_filtering(self, *args: Any, **kwargs: Any) -> bool: - filter_result = await super().topic_filtering(*args, **kwargs) + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + filter_result = await super().topic_filtering(session=session, topic=topic, action=action) if not filter_result: return False # hbmqtt and older amqtt do not support publish filtering - action = kwargs.get("action") if action == Action.PUBLISH and self.topic_config is not None and "publish-acl" not in self.topic_config: # maintain backward compatibility, assume permitted return True - req_topic = kwargs.get("topic") + req_topic = topic if not req_topic: return False - session = kwargs.get("session") username = session.username if session else None if username is None: username = "anonymous" @@ -100,7 +100,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin): elif self.topic_config is not None and action == Action.SUBSCRIBE: acl = self.topic_config.get("acl", {}) - allowed_topics = acl.get(username, None) + allowed_topics = acl.get(username, []) if not allowed_topics: return False diff --git a/docs/custom_plugins.md b/docs/custom_plugins.md index fe534d5..218270c 100644 --- a/docs/custom_plugins.md +++ b/docs/custom_plugins.md @@ -1,7 +1,30 @@ # Custom Plugins -Every plugin listed in the `project.entry-points` is loaded and notified of events -by defining any of the following methods: +With the aMQTT Broker plugins framework, one can add additional functionality to the broker without +having to subclass or rewrite any of the core broker logic. To define a custom list of plugins to be loaded, +add this section to your `pyproject.toml`" + +```toml +[project.entry-points."mypackage.mymodule.plugins"] +plugin_alias = "module.submodule.file:ClassName" +``` + +and specify the namespace when instantiating the broker: + +```python +from amqtt.broker import Broker + +broker = Broker(plugin_namespace='mypackage.mymodule.plugins') + +``` + +Each plugin has access to the full configuration file through the provided `BaseContext` and can define +its own variables to configure its behavior. + +::: amqtt.plugins.base.BasePlugin + +Plugins that are defined in the`project.entry-points` are loaded and notified of events by when the subclass +implements one or more of these methods: - `on_mqtt_packet_sent` - `on_mqtt_packet_received` @@ -18,7 +41,7 @@ by defining any of the following methods: ## Authentication Plugins -Of the plugins listed in `project.entry-points`, plugins can be used to validate client sessions +Of the plugins listed in `project.entry-points`, one or more can be used to validate client sessions by specifying their alias in `auth` > `plugins` section of the config: ```yaml @@ -27,22 +50,23 @@ auth: - plugin_alias_name ``` -These plugins should sub-class from `BaseAuthPlugin` and implement the `authenticate` method. +These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method. ::: amqtt.plugins.authentication.BaseAuthPlugin ## Topic Filter Plugins -Of the plugins listed in `project.entry-points`, plugins can be used to validate client sessions +Of the plugins listed in `project.entry-points`, one or more can be used to determine topic access by specifying their alias in `topic-check` > `plugins` section of the config: ```yaml topic-check: + enable: True plugins: - plugin_alias_name ``` -These plugins should sub-class from `BaseTopicPlugin` and implement the `topic_filtering` method. +These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method. ::: amqtt.plugins.topic_checking.BaseTopicPlugin diff --git a/docs/packaged_plugins.md b/docs/packaged_plugins.md index 5d2237c..c2502e4 100644 --- a/docs/packaged_plugins.md +++ b/docs/packaged_plugins.md @@ -1,12 +1,10 @@ # Existing Plugins With the aMQTT Broker plugins framework, one can add additional functionality without -having to rewrite core logic. The list of plugins that get loaded are specified in `pyproject.toml`; -each plugin can then check the configuration to determine how to behave (including disabling). +having to rewrite core logic. Plugins loaded by default are specified in `pyproject.toml`: -```toml -[project.entry-points."amqtt.broker.plugins"] -plugin_alias = "module.submodule.file:ClassName" +```yaml +--8<-- "pyproject.toml:included" ``` ## auth_anonymous (Auth Plugin) @@ -14,7 +12,7 @@ plugin_alias = "module.submodule.file:ClassName" `amqtt.plugins.authentication:AnonymousAuthPlugin` -**Config Options** +**Configuration** ```yaml auth: @@ -34,7 +32,7 @@ auth: clients are authorized by providing username and password, compared against file -**Config Options** +**Configuration** ```yaml @@ -64,7 +62,6 @@ print(sha512_crypt.hash(passwd)) `amqtt.plugins.topic_checking:TopicTabooPlugin` - Prevents using topics named: `prohibited`, `top-secret`, and `data/classified` **Configuration** @@ -82,6 +79,19 @@ topic-check: **Configuration** +- `acl` *(list)*: determines subscription access; if `publish-acl` is not specified, determine both publish and subscription access. + The list should be a key-value pair, where: +`:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + + +- `publish-acl` *(list)*: determines publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: +`:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + + !!! info "Reserved usernames" + + - The username `admin` is allowed access to all topics. + - The username `anonymous` will control allowed topics, if using the `auth_anonymous` plugin. + ```yaml topic-check: enabled: true @@ -95,20 +105,17 @@ topic-check: - . ``` - - - - - ## Plugin: $SYS +`amqtt.plugins.sys.broker:BrokerSysPlugin` + Publishes, on a periodic basis, statistics about the broker -**Config Options** +**Configuration** - `sys_interval` - int, seconds between updates -### Supported Topics +**Supported Topics** - `$SYS/broker/load/bytes/received` - payload: `data`, int - `$SYS/broker/load/bytes/sent` - payload: `data`, int diff --git a/docs/references/broker_config.md b/docs/references/broker_config.md index 157679d..a2fd26c 100644 --- a/docs/references/broker_config.md +++ b/docs/references/broker_config.md @@ -33,9 +33,30 @@ Client disconnect timeout without a keep-alive Configuration for authentication behaviour: -- `plugins` *(list[string])*: defines the list of plugins which are activated as authentication plugins. Note the plugins must be defined in the `amqtt.broker.plugins` [entry point](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata). -- `allow-anonymous` *(bool)*: used by the internal `amqtt.plugins.authentication.AnonymousAuthPlugin` plugin. This parameter enables (`on`) or disable anonymous connection, i.e. connection without username. -- `password-file` *(string)*: used by the internal `amqtt.plugins.authentication.FileAuthPlugin` plugin. Path to file which includes `username:password` pair, one per line. The password should be encoded using sha-512 with `mkpasswd -m sha-512` or: +- `plugins` *(list[string])*: defines the list of plugins which are activated as authentication plugins. + + !!! note "Entry points" + Plugins used here must first be defined in the `amqtt.broker.plugins` [entry point](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata). + + + !!! danger "Legacy behavior" + if `plugins` is omitted from the `auth` section, all plugins listed in the `amqtt.broker.plugins` entrypoint will be enabled + for authentication, *including allowing anonymous login.* + + `plugins: []` will deny connections from all clients. + +- `allow-anonymous` *(bool)*: `True` will allow anonymous connections. + + *Used by the internal `amqtt.plugins.authentication.AnonymousAuthPlugin` plugin* + + !!! danger "Username only connections" + `False` does not disable the `auth_anonymous` plugin; connections will still be allowed as long as a username is provided. + + If security is required, do not include `auth_anonymous` in the `plugins` list. + + + +- `password-file` *(string)*: Path to file which includes `username:password` pair, one per line. The password should be encoded using sha-512 with `mkpasswd -m sha-512` or: ```python import sys from getpass import getpass @@ -44,6 +65,8 @@ Configuration for authentication behaviour: passwd = input() if not sys.stdin.isatty() else getpass() print(sha512_crypt.hash(passwd)) ``` + + *Used by the internal `amqtt.plugins.authentication.FileAuthPlugin` plugin.* ### `topic-check` *(mapping)* @@ -51,12 +74,23 @@ Configuration for access control policies for publishing and subscribing to topi - `enabled` *(bool)*: Enable access control policies (`true`). `false` will allow clients to publish and subscribe to any topic. - `plugins` *(list[string])*: defines the list of plugins which are activated as access control plugins. Note the plugins must be defined in the `amqtt.broker.plugins` [entry point](https://pythonhosted.org/setuptools/setuptools.html#dynamic-discovery-of-services-and-plugins). -- `acl` *(list)*: used by the internal `amqtt.plugins.topic_acl.TopicAclPlugin` plugin to determine subscription access. This parameter defines the list of access control rules; each item is a key-value pair, where: + +- `acl` *(list)*: plugin to determine subscription access; if `publish-acl` is not specified, determine both publish and subscription access. + The list should be a key-value pair, where: `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). -use `anonymous` username for the list of allowed topics if using the `auth_anonymous` plugin. -- `publish-acl` *(list)*: used by the internal `amqtt.plugins.topic_acl.TopicAclPlugin` plugin to determine publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: + + *used by the `amqtt.plugins.topic_acl.TopicAclPlugin`* + +- `publish-acl` *(list)*: plugin to determine publish access. This parameter defines the list of access control rules; each item is a key-value pair, where: `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). -use `anonymous` username for the list of allowed topics if using the `auth_anonymous` plugin. + + !!! info "Reserved usernames" + + - The username `admin` is allowed access to all topic. + - The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin. + + + *used by the `amqtt.plugins.topic_acl.TopicAclPlugin`* diff --git a/mkdocs.rtd.yml b/mkdocs.rtd.yml index 3396033..67aa860 100644 --- a/mkdocs.rtd.yml +++ b/mkdocs.rtd.yml @@ -88,7 +88,6 @@ theme: extra_css: - assets/extra.css - #extra_javascript: #- assets/extra.js @@ -117,7 +116,6 @@ markdown_extensions: - toc: permalink: "ยค" - plugins: - search - autorefs diff --git a/pyproject.toml b/pyproject.toml index 34710fa..1fdb3c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ test_plugin = "tests.plugins.test_manager:EmptyTestPlugin" event_plugin = "tests.plugins.test_manager:EventTestPlugin" packet_logger_plugin = "amqtt.plugins.logging_amqtt:PacketLoggerPlugin" +# --8<-- [start:included] [project.entry-points."amqtt.broker.plugins"] event_logger_plugin = "amqtt.plugins.logging_amqtt:EventLoggerPlugin" packet_logger_plugin = "amqtt.plugins.logging_amqtt:PacketLoggerPlugin" @@ -104,6 +105,8 @@ auth_file = "amqtt.plugins.authentication:FileAuthPlugin" topic_taboo = "amqtt.plugins.topic_checking:TopicTabooPlugin" topic_acl = "amqtt.plugins.topic_checking:TopicAccessControlListPlugin" broker_sys = "amqtt.plugins.sys.broker:BrokerSysPlugin" +# --8<-- [end:included] + [project.entry-points."amqtt.client.plugins"] packet_logger_plugin = "amqtt.plugins.logging_amqtt:PacketLoggerPlugin" diff --git a/tests/plugins/test_topic_checking.py b/tests/plugins/test_topic_checking.py index 6237be4..e90f365 100644 --- a/tests/plugins/test_topic_checking.py +++ b/tests/plugins/test_topic_checking.py @@ -2,7 +2,7 @@ import logging import pytest -from amqtt.broker import Action +from amqtt.broker import Action, BrokerContext, Broker from amqtt.plugins.manager import BaseContext from amqtt.plugins.topic_checking import BaseTopicPlugin, TopicAccessControlListPlugin, TopicTabooPlugin from amqtt.session import Session @@ -29,7 +29,7 @@ async def test_base_no_config(logdog): assert log_records[0].message == "'topic-check' section not found in context configuration" assert log_records[1].levelno == logging.WARNING - assert log_records[1].message == "'auth' section not found in context configuration" + assert log_records[1].message == "'topic-check' section not found in context configuration" assert pile.is_empty() @@ -37,7 +37,8 @@ async def test_base_no_config(logdog): async def test_base_empty_config(logdog): """Check BaseTopicPlugin returns false if topic-check is empty.""" with logdog() as pile: - context = BaseContext() + broker = Broker() + context = BrokerContext(broker) context.logger = logging.getLogger("testlog") context.config = {"topic-check": {}} @@ -47,9 +48,12 @@ async def test_base_empty_config(logdog): # Should have printed just one warning log_records = list(pile.drain(name="testlog")) - assert len(log_records) == 1 + assert len(log_records) == 2 assert log_records[0].levelno == logging.WARNING - assert log_records[0].message == "'auth' section not found in context configuration" + assert log_records[0].message == "'topic-check' section not found in context configuration" + + assert log_records[1].levelno == logging.WARNING + assert log_records[1].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio @@ -106,7 +110,7 @@ async def test_taboo_empty_config(logdog): assert log_records[0].levelno == logging.WARNING assert log_records[0].message == "'topic-check' section not found in context configuration" assert log_records[1].levelno == logging.WARNING - assert log_records[1].message == "'auth' section not found in context configuration" + assert log_records[1].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio @@ -264,7 +268,7 @@ async def test_taclp_empty_config(logdog): log_records = list(pile.drain(name="testlog")) assert len(log_records) == 2 assert log_records[0].message == "'topic-check' section not found in context configuration" - assert log_records[1].message == "'auth' section not found in context configuration" + assert log_records[1].message == "'topic-check' section not found in context configuration" @pytest.mark.asyncio