moving base classes for auth and topic plugins into common file

pull/226/head
Andrew Mirsky 2025-06-17 12:53:37 -04:00
rodzic 352678c87e
commit b7ccc458e9
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 71 dodań i 74 usunięć

Wyświetl plik

@ -1,44 +1,14 @@
from pathlib import Path
from typing import Any
from passlib.apps import custom_app_context as pwd_context
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.plugins.base import BaseAuthPlugin
from amqtt.session import Session
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
if not self.auth_config:
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.
Args:
session: amqtt.session.Session
Returns:
- `True` if user is authentication succeed, `False` if user authentication fails
- `None` if authentication can't be achieved (then plugin result is then ignored)
"""
if not self.auth_config:
# auth config section not found
self.context.logger.warning("'auth' section not found in context configuration")
return False
return True
class AnonymousAuthPlugin(BaseAuthPlugin):
"""Authentication plugin allowing anonymous access."""

Wyświetl plik

@ -1,6 +1,8 @@
from typing import Any, Generic, TypeVar
from amqtt.broker import Action
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
C = TypeVar("C", bound=BaseContext)
@ -24,3 +26,62 @@ class BasePlugin(Generic[C]):
async def close(self) -> None:
"""Override if plugin needs to clean up resources upon shutdown."""
class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
if self.topic_config is None:
self.context.logger.warning("'topic-check' section not found in context configuration")
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
"""Logic for filtering out topics.
Args:
session: amqtt.session.Session
topic: str
action: amqtt.broker.Action
Returns:
bool: `True` if topic is allowed, `False` otherwise
"""
if not self.topic_config:
# auth config section not found
self.context.logger.warning("'topic-check' section not found in context configuration")
return False
return True
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
if not self.auth_config:
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.
Args:
session: amqtt.session.Session
Returns:
- `True` if user is authentication succeed, `False` if user authentication fails
- `None` if authentication can't be achieved (then plugin result is then ignored)
"""
if not self.auth_config:
# auth config section not found
self.context.logger.warning("'auth' section not found in context configuration")
return False
return True

Wyświetl plik

@ -18,9 +18,8 @@ _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
class Plugin(NamedTuple):
name: str

Wyświetl plik

@ -1,42 +1,11 @@
from typing import Any
from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.base import BaseTopicPlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
if self.topic_config is None:
self.context.logger.warning("'topic-check' section not found in context configuration")
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
"""Logic for filtering out topics.
Args:
session: amqtt.session.Session
topic: str
action: amqtt.broker.Action
Returns:
bool: `True` if topic is allowed, `False` otherwise
"""
if not self.topic_config:
# auth config section not found
self.context.logger.warning("'topic-check' section not found in context configuration")
return False
return True
class TopicTabooPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None:
super().__init__(context)

Wyświetl plik

@ -58,7 +58,7 @@ auth:
These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method.
::: amqtt.plugins.authentication.BaseAuthPlugin
::: amqtt.plugins.base.BaseAuthPlugin
## Topic Filter Plugins
@ -75,4 +75,4 @@ topic-check:
These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method.
::: amqtt.plugins.topic_checking.BaseTopicPlugin
::: amqtt.plugins.base.BaseTopicPlugin

Wyświetl plik

@ -4,10 +4,8 @@ from dataclasses import dataclass
from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.base import BasePlugin, BaseTopicPlugin, BaseAuthPlugin
from amqtt.plugins.manager import BaseContext
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.session import Session

Wyświetl plik

@ -4,9 +4,8 @@ import unittest
from amqtt.broker import Action
from amqtt.events import BrokerEvents
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.plugins.base import BaseTopicPlugin, BaseAuthPlugin
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"

Wyświetl plik

@ -4,7 +4,8 @@ import pytest
from amqtt.broker import Action, BrokerContext, Broker
from amqtt.plugins.manager import BaseContext
from amqtt.plugins.topic_checking import BaseTopicPlugin, TopicAccessControlListPlugin, TopicTabooPlugin
from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabooPlugin
from amqtt.plugins.base import BaseTopicPlugin
from amqtt.session import Session
# Base plug-in object