kopia lustrzana https://github.com/Yakifo/amqtt
instead of filtering upon call of the plugin's coro, filter upon plugin loading
rodzic
7b936d785c
commit
43efa4c829
|
@ -689,11 +689,6 @@ class Broker:
|
|||
:param listener:
|
||||
:return:
|
||||
"""
|
||||
auth_plugins = None
|
||||
auth_config = self.config.get("auth", None)
|
||||
if isinstance(auth_config, dict):
|
||||
auth_plugins = auth_config.get("plugins", None)
|
||||
# returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
|
||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||
auth_result = True
|
||||
if returns:
|
||||
|
@ -765,22 +760,14 @@ class Broker:
|
|||
"""
|
||||
topic_config = self.config.get("topic-check", {})
|
||||
enabled = False
|
||||
topic_plugins: list[str] | None = None
|
||||
|
||||
if isinstance(topic_config, dict):
|
||||
enabled = topic_config.get("enabled", False)
|
||||
topic_plugins = topic_config.get("plugins")
|
||||
|
||||
if not enabled:
|
||||
return True
|
||||
|
||||
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
||||
# results = await self.plugins_manager.map_plugin_coro(
|
||||
# "topic_filtering",
|
||||
# session=session,
|
||||
# topic=topic,
|
||||
# action=action,
|
||||
# filter_plugins=topic_plugins,
|
||||
# )
|
||||
return all(result for result in results.values())
|
||||
|
||||
async def _delete_session(self, client_id: str) -> None:
|
||||
|
|
|
@ -34,4 +34,3 @@ class ProtocolHandlerError(Exception):
|
|||
|
||||
class PluginLoadError(Exception):
|
||||
"""Exception thrown when loading a plugin."""
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from amqtt.session import Session
|
|||
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
||||
|
||||
|
||||
class BaseAuthPlugin(BasePlugin):
|
||||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
class BasePlugin:
|
||||
|
||||
class BasePlugin(Generic[C]):
|
||||
"""The base from which all plugins should inherit."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
self.context = context
|
||||
def __init__(self, context: C) -> None:
|
||||
self.context: C = context
|
||||
|
||||
def _get_config_section(self, name: str) -> dict[str, Any] | None:
|
||||
|
||||
if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None):
|
||||
if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
|
||||
return None
|
||||
|
||||
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
|
||||
|
@ -21,6 +22,5 @@ class BasePlugin:
|
|||
return None
|
||||
return section_config
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -1,25 +1,22 @@
|
|||
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable
|
||||
import contextlib
|
||||
import copy
|
||||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
import logging
|
||||
from typing import Any, NamedTuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
|
||||
|
||||
from amqtt.errors import MQTTError, PluginLoadError
|
||||
from amqtt.session import Session
|
||||
from amqtt.utils import import_string
|
||||
from dacite import from_dict, Config, DaciteError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
|
||||
class Plugin(NamedTuple):
|
||||
name: str
|
||||
|
@ -75,52 +72,31 @@ class PluginManager:
|
|||
return self.context
|
||||
|
||||
def _load_plugins(self, namespace: str) -> None:
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
|
||||
if 'plugins' in self.app_context.config:
|
||||
self.logger.info("Loading plugins from config file")
|
||||
for plugin_info in self.app_context.config['plugins']:
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
|
||||
if isinstance(plugin_info, dict):
|
||||
assert len(plugin_info.keys()) == 1
|
||||
plugin_path = list(plugin_info.keys())[0]
|
||||
plugin_cfg = plugin_info[plugin_path]
|
||||
plugin = self._load_str_plugin(plugin_path, plugin_cfg)
|
||||
elif isinstance(plugin_info, str):
|
||||
plugin = self._load_str_plugin(plugin_info, {})
|
||||
else:
|
||||
msg = 'Unexpected entry in plugins config'
|
||||
raise PluginLoadError(msg)
|
||||
auth_filter_list = []
|
||||
topic_filter_list = []
|
||||
if self.app_context.config and "auth" in self.app_context.config:
|
||||
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
|
||||
if self.app_context.config and "topic" in self.app_context.config:
|
||||
topic_filter_list = self.app_context.config["topic"].get("plugins", [])
|
||||
|
||||
self._plugins.append(plugin)
|
||||
if isinstance(plugin, BaseAuthPlugin):
|
||||
self._auth_plugins.append(plugin)
|
||||
if isinstance(plugin, BaseTopicPlugin):
|
||||
self._topic_plugins.append(plugin)
|
||||
ep: EntryPoints | list[EntryPoint] = []
|
||||
if hasattr(entry_points(), "select"):
|
||||
ep = entry_points().select(group=namespace)
|
||||
elif namespace in entry_points():
|
||||
ep = [entry_points()[namespace]]
|
||||
|
||||
|
||||
|
||||
else:
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
|
||||
auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else []
|
||||
topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else []
|
||||
ep: EntryPoints | list[EntryPoint] = []
|
||||
if hasattr(entry_points(), "select"):
|
||||
ep = entry_points().select(group=namespace)
|
||||
elif namespace in entry_points():
|
||||
ep = [entry_points()[namespace]]
|
||||
|
||||
for item in ep:
|
||||
plugin = self._load_ep_plugin(item)
|
||||
if plugin is not None:
|
||||
self._plugins.append(plugin.object)
|
||||
if plugin.name in auth_filter_list:
|
||||
self._auth_plugins.append(plugin.object)
|
||||
elif plugin.name in topic_filter_list:
|
||||
self._topic_plugins.append(plugin.object)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
for item in ep:
|
||||
plugin = self._load_ep_plugin(item)
|
||||
if plugin is not None:
|
||||
self._plugins.append(plugin.object)
|
||||
if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"):
|
||||
self._auth_plugins.append(plugin.object)
|
||||
if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"):
|
||||
self._topic_plugins.append(plugin.object)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
|
||||
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||
try:
|
||||
|
@ -138,53 +114,27 @@ class PluginManager:
|
|||
|
||||
return None
|
||||
|
||||
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin':
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
def get_plugin(self, name: str) -> Optional["BasePlugin"]:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
|
||||
try:
|
||||
plugin_class = import_string(plugin_path)
|
||||
except ModuleNotFoundError as ep:
|
||||
self.logger.error(f"Plugin import failed: {plugin_path}")
|
||||
raise MQTTError() from ep
|
||||
|
||||
if not issubclass(plugin_class, BasePlugin):
|
||||
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
|
||||
raise PluginLoadError(msg)
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
|
||||
try:
|
||||
plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True))
|
||||
except DaciteError as e:
|
||||
raise PluginLoadError from e
|
||||
|
||||
try:
|
||||
return plugin_class(plugin_context)
|
||||
except ImportError as e:
|
||||
raise PluginLoadError from e
|
||||
|
||||
# def get_plugin(self, name: str) -> Plugin | None:
|
||||
# """Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
#
|
||||
# :param name:
|
||||
# :return:
|
||||
# """
|
||||
# for p in self._plugins:
|
||||
# if p.name == name:
|
||||
# return p
|
||||
# return None
|
||||
:param name:
|
||||
:return:
|
||||
"""
|
||||
for p in self._plugins:
|
||||
self.logger.debug(f"plugin name >>>> {p.__class__.__name__}")
|
||||
if p.__class__.__name__ == name:
|
||||
return p
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Free PluginManager resources and cancel pending event methods."""
|
||||
await self.map_plugin_coro("close")
|
||||
await self.map_plugin_close()
|
||||
for task in self._fired_events:
|
||||
task.cancel()
|
||||
self._fired_events.clear()
|
||||
|
||||
@property
|
||||
def plugins(self) -> list['BasePlugin']:
|
||||
def plugins(self) -> list["BasePlugin"]:
|
||||
"""Get the loaded plugins list.
|
||||
|
||||
:return:
|
||||
|
@ -229,100 +179,83 @@ class PluginManager:
|
|||
await asyncio.wait(tasks)
|
||||
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
|
||||
|
||||
async def map(
|
||||
self,
|
||||
coro: Callable[[Plugin, Any], Awaitable[str | bool | None]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> dict[Plugin, str | bool | None]:
|
||||
"""Schedule a given coroutine call for each plugin.
|
||||
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]:
|
||||
"""Schedule a coroutine for plugin 'authenticate' calls.
|
||||
|
||||
The coro called gets the Plugin instance as the first argument of its method call.
|
||||
:param coro: coro to call on each plugin
|
||||
:param filter_plugins: list of plugin names to filter (only plugin whose name is
|
||||
in the filter are called). None will call all plugins. [] will call None.
|
||||
:param args: arguments to pass to coro
|
||||
:param kwargs: arguments to pass to coro
|
||||
:param session: the client session associated with the authentication check
|
||||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
p_list = kwargs.pop("filter_plugins", None)
|
||||
if p_list is None:
|
||||
p_list = [p.name for p in self.plugins]
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
plugins_list: list[Plugin] = []
|
||||
for plugin in self._plugins:
|
||||
if plugin.name in p_list:
|
||||
coro_instance = coro(plugin, *args, **kwargs)
|
||||
if coro_instance:
|
||||
try:
|
||||
tasks.append(self._schedule_coro(coro_instance))
|
||||
plugins_list.append(plugin)
|
||||
except AssertionError:
|
||||
self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine")
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(plugins_list, ret_list, strict=False))
|
||||
else:
|
||||
ret_dict = {}
|
||||
return ret_dict
|
||||
|
||||
@staticmethod
|
||||
async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None:
|
||||
if not hasattr(plugin.object, coro_name):
|
||||
_LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}")
|
||||
return None
|
||||
|
||||
coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs)
|
||||
return await coro
|
||||
|
||||
async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]:
|
||||
"""Call a plugin declared by plugin by its name.
|
||||
|
||||
:param coro_name:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
return await self.map(self._call_coro, coro_name, *args, **kwargs)
|
||||
|
||||
|
||||
async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]:
|
||||
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
for plugin in self._auth_plugins:
|
||||
|
||||
async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None:
|
||||
async def auth_coro(p: "BaseAuthPlugin", s: Session) -> str | bool | None:
|
||||
return await p.authenticate(session=s)
|
||||
|
||||
if not hasattr(plugin, "authenticate"):
|
||||
continue
|
||||
|
||||
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
else:
|
||||
ret_dict = {}
|
||||
|
||||
return ret_dict
|
||||
|
||||
async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]:
|
||||
async def map_plugin_topic(self,
|
||||
session: Session, topic: str, action: "Action"
|
||||
) -> dict["BasePlugin", str | bool | None]:
|
||||
"""Schedule a coroutine for plugin 'topic_filtering' calls.
|
||||
|
||||
:param session: the client session associated with the topic_filtering check
|
||||
:param topic: the topic that needs to be filtered
|
||||
:param action: the action being executed
|
||||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
for plugin in self._topic_plugins:
|
||||
|
||||
async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None:
|
||||
async def topic_coro(p: "BaseTopicPlugin", s: Session, t: str, a: "Action") -> str | bool | None:
|
||||
return await p.topic_filtering(session=s, topic=t, action=a)
|
||||
|
||||
if not hasattr(plugin, "topic_filtering"):
|
||||
continue
|
||||
|
||||
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))}
|
||||
else:
|
||||
ret_dict = {}
|
||||
ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
|
||||
return ret_dict
|
||||
|
||||
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]:
|
||||
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
for plugin in self._plugins:
|
||||
|
||||
async def close_coro(p: "BasePlugin") -> None:
|
||||
await p.close()
|
||||
|
||||
if not hasattr(plugin, "close"):
|
||||
continue
|
||||
|
||||
coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
|
||||
return ret_dict
|
||||
|
|
|
@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected"
|
|||
STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
|
||||
|
||||
|
||||
class BrokerSysPlugin(BasePlugin):
|
||||
class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
# Broker statistics initialization
|
||||
|
|
|
@ -6,7 +6,7 @@ from amqtt.plugins.manager import BaseContext
|
|||
from amqtt.session import Session
|
||||
|
||||
|
||||
class BaseTopicPlugin(BasePlugin):
|
||||
class BaseTopicPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for topic plugins."""
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import string
|
||||
import typing
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
@ -51,32 +48,3 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
|
|||
except yaml.YAMLError:
|
||||
logger.exception(f"Invalid config_file {config_file}")
|
||||
return None
|
||||
|
||||
def cached_import(module_path: str, class_name: str=None) -> ModuleType:
|
||||
# Check whether module is loaded and fully initialized.
|
||||
if not ((module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None)) # noqa
|
||||
and getattr(spec, "_initializing", False) is False): # noqa
|
||||
module = import_module(module_path)
|
||||
if class_name:
|
||||
return getattr(module, class_name)
|
||||
return module
|
||||
|
||||
|
||||
# TODO : figure out proper return type
|
||||
def import_string(dotted_path) -> Any:
|
||||
"""
|
||||
Import a dotted module path and return the attribute/class designated by the
|
||||
last name in the path. Raise ImportError if the import failed.
|
||||
"""
|
||||
try:
|
||||
module_path, class_name = dotted_path.rsplit(".", 1)
|
||||
except ValueError as err:
|
||||
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
|
||||
|
||||
try:
|
||||
return cached_import(module_path, class_name)
|
||||
except AttributeError as err:
|
||||
raise ImportError(
|
||||
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
||||
) from err
|
||||
|
|
|
@ -185,13 +185,13 @@ max-returns = 10
|
|||
|
||||
# ----------------------------------- PYTEST -----------------------------------
|
||||
[tool.pytest.ini_options]
|
||||
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
||||
addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
timeout = 10
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
addopts = ["--tb=short", "--capture=tee-sys"]
|
||||
log_cli = true
|
||||
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||
#log_cli = true
|
||||
log_level = "DEBUG"
|
||||
|
||||
# ------------------------------------ MYPY ------------------------------------
|
||||
|
|
|
@ -1,117 +0,0 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from yaml import CLoader as Loader
|
||||
from dacite import from_dict, Config, UnexpectedDataError
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import PluginLoadError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
plugin_config = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestSimplePlugin:
|
||||
- tests.plugins.mocks.TestConfigPlugin:
|
||||
option1: 1
|
||||
option2: bar
|
||||
"""
|
||||
|
||||
|
||||
plugin_invalid_config_one = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestSimplePlugin:
|
||||
option1: 1
|
||||
option2: bar
|
||||
"""
|
||||
|
||||
plugin_invalid_config_two = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestConfigPlugin:
|
||||
"""
|
||||
|
||||
plugin_config_auth = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestAuthPlugin:
|
||||
"""
|
||||
|
||||
plugin_config_topic = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestTopicPlugin:
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_config_extra_fields():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginLoadError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_config_missing_fields():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginLoadError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternate_plugin_load():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
|
||||
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
client1 = MQTTClient()
|
||||
await client1.connect()
|
||||
await client1.publish('my/topic', b'my message')
|
||||
await client1.disconnect()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_topic_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await broker.shutdown()
|
|
@ -1,9 +1,12 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
from amqtt.plugins.manager import BaseContext, Plugin, PluginManager
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
|
@ -14,21 +17,29 @@ class EmptyTestPlugin:
|
|||
self.context = context
|
||||
|
||||
|
||||
class EventTestPlugin:
|
||||
class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
self.context = context
|
||||
self.test_flag = False
|
||||
self.coro_flag = False
|
||||
super().__init__(context)
|
||||
self.test_close_flag = False
|
||||
self.test_auth_flag = False
|
||||
self.test_topic_flag = False
|
||||
self.test_event_flag = False
|
||||
|
||||
async def on_test(self) -> None:
|
||||
self.test_flag = True
|
||||
self.context.logger.info("on_test")
|
||||
self.test_event_flag = True
|
||||
|
||||
async def test_coro(self) -> None:
|
||||
self.coro_flag = True
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
self.test_auth_flag = True
|
||||
return None
|
||||
|
||||
async def ret_coro(self) -> str:
|
||||
return "TEST"
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
self.test_topic_flag = True
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
self.test_close_flag = True
|
||||
|
||||
|
||||
class TestPluginManager(unittest.TestCase):
|
||||
|
@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase):
|
|||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
self.loop.run_until_complete(fire_event())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.object.test_flag
|
||||
assert plugin.test_event_flag
|
||||
|
||||
def test_fire_event_wait(self) -> None:
|
||||
async def fire_event() -> None:
|
||||
|
@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase):
|
|||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
self.loop.run_until_complete(fire_event())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.object.test_flag
|
||||
assert plugin.test_event_flag
|
||||
|
||||
def test_map_coro(self) -> None:
|
||||
async def call_coro() -> None:
|
||||
await manager.map_plugin_coro("test_coro")
|
||||
def test_plugin_close_coro(self) -> None:
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
self.loop.run_until_complete(call_coro())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
self.loop.run_until_complete(manager.map_plugin_close())
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.object.test_coro
|
||||
assert plugin.test_close_flag
|
||||
|
||||
def test_map_coro_return(self) -> None:
|
||||
async def call_coro() -> dict[Plugin, str]:
|
||||
return await manager.map_plugin_coro("ret_coro")
|
||||
def test_plugin_auth_coro(self) -> None:
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
ret = self.loop.run_until_complete(call_coro())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert ret[plugin] == "TEST"
|
||||
assert plugin.test_auth_flag
|
||||
|
||||
def test_map_coro_filter(self) -> None:
|
||||
"""Run plugin coro but expect no return as an empty filter is given."""
|
||||
|
||||
async def call_coro() -> dict[Plugin, Any]:
|
||||
return await manager.map_plugin_coro("ret_coro", filter_plugins=[])
|
||||
def test_plugin_topic_coro(self) -> None:
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
ret = self.loop.run_until_complete(call_coro())
|
||||
assert len(ret) == 0
|
||||
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.test_topic_flag
|
||||
|
|
Ładowanie…
Reference in New Issue