kopia lustrzana https://github.com/Yakifo/amqtt
moving base classes for auth and topic plugins into common file
rodzic
352678c87e
commit
b7ccc458e9
|
@ -1,44 +1,14 @@
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from passlib.apps import custom_app_context as pwd_context
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.plugins.base import BaseAuthPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
||||
|
||||
|
||||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
|
||||
if not self.auth_config:
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Logic for session authentication.
|
||||
|
||||
Args:
|
||||
session: amqtt.session.Session
|
||||
|
||||
Returns:
|
||||
- `True` if user is authentication succeed, `False` if user authentication fails
|
||||
- `None` if authentication can't be achieved (then plugin result is then ignored)
|
||||
|
||||
"""
|
||||
if not self.auth_config:
|
||||
# auth config section not found
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class AnonymousAuthPlugin(BaseAuthPlugin):
|
||||
"""Authentication plugin allowing anonymous access."""
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
@ -24,3 +26,62 @@ class BasePlugin(Generic[C]):
|
|||
|
||||
async def close(self) -> None:
|
||||
"""Override if plugin needs to clean up resources upon shutdown."""
|
||||
|
||||
|
||||
class BaseTopicPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for topic plugins."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
|
||||
if self.topic_config is None:
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
"""Logic for filtering out topics.
|
||||
|
||||
Args:
|
||||
session: amqtt.session.Session
|
||||
topic: str
|
||||
action: amqtt.broker.Action
|
||||
|
||||
Returns:
|
||||
bool: `True` if topic is allowed, `False` otherwise
|
||||
|
||||
"""
|
||||
if not self.topic_config:
|
||||
# auth config section not found
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
|
||||
if not self.auth_config:
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Logic for session authentication.
|
||||
|
||||
Args:
|
||||
session: amqtt.session.Session
|
||||
|
||||
Returns:
|
||||
- `True` if user is authentication succeed, `False` if user authentication fails
|
||||
- `None` if authentication can't be achieved (then plugin result is then ignored)
|
||||
|
||||
"""
|
||||
if not self.auth_config:
|
||||
# auth config section not found
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
return False
|
||||
return True
|
||||
|
|
|
@ -18,9 +18,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
|
||||
|
||||
class Plugin(NamedTuple):
|
||||
name: str
|
||||
|
|
|
@ -1,42 +1,11 @@
|
|||
from typing import Any
|
||||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.base import BaseTopicPlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
class BaseTopicPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for topic plugins."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
|
||||
if self.topic_config is None:
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
"""Logic for filtering out topics.
|
||||
|
||||
Args:
|
||||
session: amqtt.session.Session
|
||||
topic: str
|
||||
action: amqtt.broker.Action
|
||||
|
||||
Returns:
|
||||
bool: `True` if topic is allowed, `False` otherwise
|
||||
|
||||
"""
|
||||
if not self.topic_config:
|
||||
# auth config section not found
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class TopicTabooPlugin(BaseTopicPlugin):
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
|
|
@ -58,7 +58,7 @@ auth:
|
|||
|
||||
These plugins should subclass from `BaseAuthPlugin` and implement the `authenticate` method.
|
||||
|
||||
::: amqtt.plugins.authentication.BaseAuthPlugin
|
||||
::: amqtt.plugins.base.BaseAuthPlugin
|
||||
|
||||
## Topic Filter Plugins
|
||||
|
||||
|
@ -75,4 +75,4 @@ topic-check:
|
|||
These plugins should subclass from `BaseTopicPlugin` and implement the `topic_filtering` method.
|
||||
|
||||
|
||||
::: amqtt.plugins.topic_checking.BaseTopicPlugin
|
||||
::: amqtt.plugins.base.BaseTopicPlugin
|
||||
|
|
|
@ -4,10 +4,8 @@ from dataclasses import dataclass
|
|||
|
||||
from amqtt.broker import Action
|
||||
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.base import BasePlugin, BaseTopicPlugin, BaseAuthPlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
|
||||
from amqtt.session import Session
|
||||
|
||||
|
|
|
@ -4,9 +4,8 @@ import unittest
|
|||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.events import BrokerEvents
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.plugins.base import BaseTopicPlugin, BaseAuthPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
|
|
|
@ -4,7 +4,8 @@ import pytest
|
|||
|
||||
from amqtt.broker import Action, BrokerContext, Broker
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin, TopicAccessControlListPlugin, TopicTabooPlugin
|
||||
from amqtt.plugins.topic_checking import TopicAccessControlListPlugin, TopicTabooPlugin
|
||||
from amqtt.plugins.base import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
# Base plug-in object
|
||||
|
|
Ładowanie…
Reference in New Issue