moving base classes for auth and topic plugins into common file

pull/226/head
Andrew Mirsky 2025-06-17 12:53:37 -04:00
rodzic 352678c87e
commit b7ccc458e9
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 71 dodań i 74 usunięć

Wyświetl plik

@ -1,44 +1,14 @@
from pathlib import Path from pathlib import Path
from typing import Any
from passlib.apps import custom_app_context as pwd_context from passlib.apps import custom_app_context as pwd_context
from amqtt.broker import BrokerContext from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin from amqtt.plugins.base import BaseAuthPlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session from amqtt.session import Session
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
if not self.auth_config:
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.
Args:
session: amqtt.session.Session
Returns:
- `True` if user is authentication succeed, `False` if user authentication fails
- `None` if authentication can't be achieved (then plugin result is then ignored)
"""
if not self.auth_config:
# auth config section not found
self.context.logger.warning("'auth' section not found in context configuration")
return False
return True
class AnonymousAuthPlugin(BaseAuthPlugin): class AnonymousAuthPlugin(BaseAuthPlugin):
"""Authentication plugin allowing anonymous access.""" """Authentication plugin allowing anonymous access."""

Wyświetl plik

@ -1,6 +1,8 @@
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
from amqtt.broker import Action
from amqtt.plugins.manager import BaseContext from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
C = TypeVar("C", bound=BaseContext) C = TypeVar("C", bound=BaseContext)
@ -24,3 +26,62 @@ class BasePlugin(Generic[C]):
async def close(self) -> None: async def close(self) -> None:
"""Override if plugin needs to clean up resources upon shutdown.""" """Override if plugin needs to clean up resources upon shutdown."""
class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
if self.topic_config is None:
self.context.logger.warning("'topic-check' section not found in context configuration")
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
"""Logic for filtering out topics.
Args:
session: amqtt.session.Session
topic: str
action: amqtt.broker.Action
Returns:
bool: `True` if topic is allowed, `False` otherwise
"""
if not self.topic_config:
# auth config section not found
self.context.logger.warning("'topic-check' section not found in context configuration")
return False
return True
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
if not self.auth_config:
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.
Args:
session: amqtt.session.Session
Returns:
- `True` if user is authentication succeed, `False` if user authentication fails
- `None` if authentication can't be achieved (then plugin result is then ignored)
"""
if not self.auth_config:
# auth config section not found
self.context.logger.warning("'auth' section not found in context configuration")
return False
return True

Wyświetl plik

@ -18,9 +18,8 @@ _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from amqtt.broker import Action from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
class Plugin(NamedTuple): class Plugin(NamedTuple):
name: str name: str

Wyświetl plik

@ -1,42 +1,11 @@
from typing import Any from typing import Any
from amqtt.broker import Action from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin from amqtt.plugins.base import BaseTopicPlugin
from amqtt.plugins.manager import BaseContext from amqtt.plugins.manager import BaseContext
from amqtt.session import Session from amqtt.session import Session
class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
if self.topic_config is None:
self.context.logger.warning("'topic-check' section not found in context configuration")
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
"""Logic for filtering out topics.
Args:
session: amqtt.session.Session
topic: str
action: amqtt.broker.Action
Returns:
bool: `True` if topic is allowed, `False` otherwise
"""
if not self.topic_config:
# auth config section not found
self.context.logger.warning("'topic-check' section not found in context configuration")
return False
return True
class TopicTabooPlugin(BaseTopicPlugin): class TopicTabooPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None: def __init__(self, context: BaseContext) -> None:
super().__init__(context) super().__init__(context)

Wyświetl plik

@ -58,7 +58,7 @@ auth:
These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method. These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method.
::: amqtt.plugins.authentication.BaseAuthPlugin ::: amqtt.plugins.base.BaseAuthPlugin
## Topic Filter Plugins ## Topic Filter Plugins
@ -75,4 +75,4 @@ topic-check:
These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method. These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method.
::: amqtt.plugins.topic_checking.BaseTopicPlugin ::: amqtt.plugins.base.BaseTopicPlugin

Wyświetl plik

@ -4,10 +4,8 @@ from dataclasses import dataclass
from amqtt.broker import Action from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin from amqtt.plugins.base import BasePlugin, BaseTopicPlugin, BaseAuthPlugin
from amqtt.plugins.manager import BaseContext from amqtt.plugins.manager import BaseContext
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.session import Session from amqtt.session import Session

Wyświetl plik

@ -4,9 +4,8 @@ import unittest
from amqtt.broker import Action from amqtt.broker import Action
from amqtt.events import BrokerEvents from amqtt.events import BrokerEvents
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.manager import BaseContext, PluginManager from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.topic_checking import BaseTopicPlugin from amqtt.plugins.base import BaseTopicPlugin, BaseAuthPlugin
from amqtt.session import Session from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"

Wyświetl plik

@ -4,7 +4,8 @@ import pytest
from amqtt.broker import Action, BrokerContext, Broker from amqtt.broker import Action, BrokerContext, Broker
from amqtt.plugins.manager import BaseContext from amqtt.plugins.manager import BaseContext
from amqtt.plugins.topic_checking import BaseTopicPlugin, TopicAccessControlListPlugin, TopicTabooPlugin from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabooPlugin
from amqtt.plugins.base import BaseTopicPlugin
from amqtt.session import Session from amqtt.session import Session
# Base plug-in object # Base plug-in object