diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index 90f7eb9..954c403 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -1,44 +1,14 @@ from pathlib import Path -from typing import Any from passlib.apps import custom_app_context as pwd_context from amqtt.broker import BrokerContext -from amqtt.plugins.base import BasePlugin -from amqtt.plugins.manager import BaseContext +from amqtt.plugins.base import BaseAuthPlugin from amqtt.session import Session _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line -class BaseAuthPlugin(BasePlugin[BaseContext]): - """Base class for authentication plugins.""" - - def __init__(self, context: BaseContext) -> None: - super().__init__(context) - - self.auth_config: dict[str, Any] | None = self._get_config_section("auth") - if not self.auth_config: - self.context.logger.warning("'auth' section not found in context configuration") - - async def authenticate(self, *, session: Session) -> bool | None: - """Logic for session authentication. - - Args: - session: amqtt.session.Session - - Returns: - - `True` if user is authentication succeed, `False` if user authentication fails - - `None` if authentication can't be achieved (then plugin result is then ignored) - - """ - if not self.auth_config: - # auth config section not found - self.context.logger.warning("'auth' section not found in context configuration") - return False - return True - - class AnonymousAuthPlugin(BaseAuthPlugin): """Authentication plugin allowing anonymous access.""" diff --git a/amqtt/plugins/base.py b/amqtt/plugins/base.py index 2d5f644..90ee1b3 100644 --- a/amqtt/plugins/base.py +++ b/amqtt/plugins/base.py @@ -1,6 +1,8 @@ from typing import Any, Generic, TypeVar +from amqtt.broker import Action from amqtt.plugins.manager import BaseContext +from amqtt.session import Session C = TypeVar("C", bound=BaseContext) @@ -24,3 +26,62 @@ class BasePlugin(Generic[C]): async def close(self) -> None: """Override if plugin needs to clean up resources upon shutdown.""" + + +class BaseTopicPlugin(BasePlugin[BaseContext]): + """Base class for topic plugins.""" + + def __init__(self, context: BaseContext) -> None: + super().__init__(context) + + self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check") + if self.topic_config is None: + self.context.logger.warning("'topic-check' section not found in context configuration") + + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + """Logic for filtering out topics. + + Args: + session: amqtt.session.Session + topic: str + action: amqtt.broker.Action + + Returns: + bool: `True` if topic is allowed, `False` otherwise + + """ + if not self.topic_config: + # auth config section not found + self.context.logger.warning("'topic-check' section not found in context configuration") + return False + return True + + +class BaseAuthPlugin(BasePlugin[BaseContext]): + """Base class for authentication plugins.""" + + def __init__(self, context: BaseContext) -> None: + super().__init__(context) + + self.auth_config: dict[str, Any] | None = self._get_config_section("auth") + if not self.auth_config: + self.context.logger.warning("'auth' section not found in context configuration") + + async def authenticate(self, *, session: Session) -> bool | None: + """Logic for session authentication. + + Args: + session: amqtt.session.Session + + Returns: + - `True` if user is authentication succeed, `False` if user authentication fails + - `None` if authentication can't be achieved (then plugin result is then ignored) + + """ + if not self.auth_config: + # auth config section not found + self.context.logger.warning("'auth' section not found in context configuration") + return False + return True diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index d1beb60..dd177e7 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -18,9 +18,8 @@ _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: from amqtt.broker import Action - from amqtt.plugins.authentication import BaseAuthPlugin - from amqtt.plugins.base import BasePlugin - from amqtt.plugins.topic_checking import BaseTopicPlugin + from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin + class Plugin(NamedTuple): name: str diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index d92d672..c61e313 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -1,42 +1,11 @@ from typing import Any from amqtt.broker import Action -from amqtt.plugins.base import BasePlugin +from amqtt.plugins.base import BaseTopicPlugin from amqtt.plugins.manager import BaseContext from amqtt.session import Session -class BaseTopicPlugin(BasePlugin[BaseContext]): - """Base class for topic plugins.""" - - def __init__(self, context: BaseContext) -> None: - super().__init__(context) - - self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check") - if self.topic_config is None: - self.context.logger.warning("'topic-check' section not found in context configuration") - - async def topic_filtering( - self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None - ) -> bool: - """Logic for filtering out topics. - - Args: - session: amqtt.session.Session - topic: str - action: amqtt.broker.Action - - Returns: - bool: `True` if topic is allowed, `False` otherwise - - """ - if not self.topic_config: - # auth config section not found - self.context.logger.warning("'topic-check' section not found in context configuration") - return False - return True - - class TopicTabooPlugin(BaseTopicPlugin): def __init__(self, context: BaseContext) -> None: super().__init__(context) diff --git a/docs/custom_plugins.md b/docs/custom_plugins.md index 08ccbdc..34eb4ec 100644 --- a/docs/custom_plugins.md +++ b/docs/custom_plugins.md @@ -58,7 +58,7 @@ auth: These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method. -::: amqtt.plugins.authentication.BaseAuthPlugin +::: amqtt.plugins.base.BaseAuthPlugin ## Topic Filter Plugins @@ -75,4 +75,4 @@ topic-check: These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method. -::: amqtt.plugins.topic_checking.BaseTopicPlugin +::: amqtt.plugins.base.BaseTopicPlugin diff --git a/tests/plugins/mocks.py b/tests/plugins/mocks.py index dab1229..dce94a6 100644 --- a/tests/plugins/mocks.py +++ b/tests/plugins/mocks.py @@ -4,10 +4,8 @@ from dataclasses import dataclass from amqtt.broker import Action -from amqtt.plugins.base import BasePlugin +from amqtt.plugins.base import BasePlugin, BaseTopicPlugin, BaseAuthPlugin from amqtt.plugins.manager import BaseContext -from amqtt.plugins.topic_checking import BaseTopicPlugin -from amqtt.plugins.authentication import BaseAuthPlugin from amqtt.session import Session diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index b9fa575..5164534 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -4,9 +4,8 @@ import unittest from amqtt.broker import Action from amqtt.events import BrokerEvents -from amqtt.plugins.authentication import BaseAuthPlugin from amqtt.plugins.manager import BaseContext, PluginManager -from amqtt.plugins.topic_checking import BaseTopicPlugin +from amqtt.plugins.base import BaseTopicPlugin, BaseAuthPlugin from amqtt.session import Session formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" diff --git a/tests/plugins/test_topic_checking.py b/tests/plugins/test_topic_checking.py index e90f365..ef5c69a 100644 --- a/tests/plugins/test_topic_checking.py +++ b/tests/plugins/test_topic_checking.py @@ -4,7 +4,8 @@ import pytest from amqtt.broker import Action, BrokerContext, Broker from amqtt.plugins.manager import BaseContext -from amqtt.plugins.topic_checking import BaseTopicPlugin, TopicAccessControlListPlugin, TopicTabooPlugin +from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabooPlugin +from amqtt.plugins.base import BaseTopicPlugin from amqtt.session import Session # Base plug-in object