kopia lustrzana https://github.com/Yakifo/amqtt
instead of filtering upon call of the plugin's coro, filter upon plugin loading
rodzic
7b936d785c
commit
43efa4c829
|
@ -689,11 +689,6 @@ class Broker:
|
||||||
:param listener:
|
:param listener:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
auth_plugins = None
|
|
||||||
auth_config = self.config.get("auth", None)
|
|
||||||
if isinstance(auth_config, dict):
|
|
||||||
auth_plugins = auth_config.get("plugins", None)
|
|
||||||
# returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
|
|
||||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||||
auth_result = True
|
auth_result = True
|
||||||
if returns:
|
if returns:
|
||||||
|
@ -765,22 +760,14 @@ class Broker:
|
||||||
"""
|
"""
|
||||||
topic_config = self.config.get("topic-check", {})
|
topic_config = self.config.get("topic-check", {})
|
||||||
enabled = False
|
enabled = False
|
||||||
topic_plugins: list[str] | None = None
|
|
||||||
if isinstance(topic_config, dict):
|
if isinstance(topic_config, dict):
|
||||||
enabled = topic_config.get("enabled", False)
|
enabled = topic_config.get("enabled", False)
|
||||||
topic_plugins = topic_config.get("plugins")
|
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
||||||
# results = await self.plugins_manager.map_plugin_coro(
|
|
||||||
# "topic_filtering",
|
|
||||||
# session=session,
|
|
||||||
# topic=topic,
|
|
||||||
# action=action,
|
|
||||||
# filter_plugins=topic_plugins,
|
|
||||||
# )
|
|
||||||
return all(result for result in results.values())
|
return all(result for result in results.values())
|
||||||
|
|
||||||
async def _delete_session(self, client_id: str) -> None:
|
async def _delete_session(self, client_id: str) -> None:
|
||||||
|
|
|
@ -34,4 +34,3 @@ class ProtocolHandlerError(Exception):
|
||||||
|
|
||||||
class PluginLoadError(Exception):
|
class PluginLoadError(Exception):
|
||||||
"""Exception thrown when loading a plugin."""
|
"""Exception thrown when loading a plugin."""
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from amqtt.session import Session
|
||||||
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
||||||
|
|
||||||
|
|
||||||
class BaseAuthPlugin(BasePlugin):
|
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||||
"""Base class for authentication plugins."""
|
"""Base class for authentication plugins."""
|
||||||
|
|
||||||
def __init__(self, context: BaseContext) -> None:
|
def __init__(self, context: BaseContext) -> None:
|
||||||
|
|
|
@ -1,18 +1,19 @@
|
||||||
from dataclasses import dataclass
|
from typing import Any, Generic, TypeVar
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from amqtt.plugins.manager import BaseContext
|
from amqtt.plugins.manager import BaseContext
|
||||||
|
|
||||||
|
C = TypeVar("C", bound=BaseContext)
|
||||||
|
|
||||||
class BasePlugin:
|
|
||||||
|
class BasePlugin(Generic[C]):
|
||||||
"""The base from which all plugins should inherit."""
|
"""The base from which all plugins should inherit."""
|
||||||
|
|
||||||
def __init__(self, context: BaseContext) -> None:
|
def __init__(self, context: C) -> None:
|
||||||
self.context = context
|
self.context: C = context
|
||||||
|
|
||||||
def _get_config_section(self, name: str) -> dict[str, Any] | None:
|
def _get_config_section(self, name: str) -> dict[str, Any] | None:
|
||||||
|
|
||||||
if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None):
|
if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
|
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
|
||||||
|
@ -21,6 +22,5 @@ class BasePlugin:
|
||||||
return None
|
return None
|
||||||
return section_config
|
return section_config
|
||||||
|
|
||||||
@dataclass
|
async def close(self) -> None:
|
||||||
class Config:
|
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,25 +1,22 @@
|
||||||
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
|
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, NamedTuple, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
|
||||||
|
|
||||||
from amqtt.errors import MQTTError, PluginLoadError
|
|
||||||
from amqtt.session import Session
|
from amqtt.session import Session
|
||||||
from amqtt.utils import import_string
|
|
||||||
from dacite import from_dict, Config, DaciteError
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from amqtt.plugins.base import BasePlugin
|
|
||||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
|
||||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
|
||||||
from amqtt.broker import Action
|
from amqtt.broker import Action
|
||||||
|
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||||
|
from amqtt.plugins.base import BasePlugin
|
||||||
|
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||||
|
|
||||||
class Plugin(NamedTuple):
|
class Plugin(NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
|
@ -75,52 +72,31 @@ class PluginManager:
|
||||||
return self.context
|
return self.context
|
||||||
|
|
||||||
def _load_plugins(self, namespace: str) -> None:
|
def _load_plugins(self, namespace: str) -> None:
|
||||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
|
||||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
|
||||||
|
|
||||||
if 'plugins' in self.app_context.config:
|
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||||
self.logger.info("Loading plugins from config file")
|
|
||||||
for plugin_info in self.app_context.config['plugins']:
|
|
||||||
|
|
||||||
if isinstance(plugin_info, dict):
|
auth_filter_list = []
|
||||||
assert len(plugin_info.keys()) == 1
|
topic_filter_list = []
|
||||||
plugin_path = list(plugin_info.keys())[0]
|
if self.app_context.config and "auth" in self.app_context.config:
|
||||||
plugin_cfg = plugin_info[plugin_path]
|
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
|
||||||
plugin = self._load_str_plugin(plugin_path, plugin_cfg)
|
if self.app_context.config and "topic" in self.app_context.config:
|
||||||
elif isinstance(plugin_info, str):
|
topic_filter_list = self.app_context.config["topic"].get("plugins", [])
|
||||||
plugin = self._load_str_plugin(plugin_info, {})
|
|
||||||
else:
|
|
||||||
msg = 'Unexpected entry in plugins config'
|
|
||||||
raise PluginLoadError(msg)
|
|
||||||
|
|
||||||
self._plugins.append(plugin)
|
ep: EntryPoints | list[EntryPoint] = []
|
||||||
if isinstance(plugin, BaseAuthPlugin):
|
if hasattr(entry_points(), "select"):
|
||||||
self._auth_plugins.append(plugin)
|
ep = entry_points().select(group=namespace)
|
||||||
if isinstance(plugin, BaseTopicPlugin):
|
elif namespace in entry_points():
|
||||||
self._topic_plugins.append(plugin)
|
ep = [entry_points()[namespace]]
|
||||||
|
|
||||||
|
for item in ep:
|
||||||
|
plugin = self._load_ep_plugin(item)
|
||||||
else:
|
if plugin is not None:
|
||||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
self._plugins.append(plugin.object)
|
||||||
|
if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"):
|
||||||
auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else []
|
self._auth_plugins.append(plugin.object)
|
||||||
topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else []
|
if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"):
|
||||||
ep: EntryPoints | list[EntryPoint] = []
|
self._topic_plugins.append(plugin.object)
|
||||||
if hasattr(entry_points(), "select"):
|
self.logger.debug(f" Plugin {item.name} ready")
|
||||||
ep = entry_points().select(group=namespace)
|
|
||||||
elif namespace in entry_points():
|
|
||||||
ep = [entry_points()[namespace]]
|
|
||||||
|
|
||||||
for item in ep:
|
|
||||||
plugin = self._load_ep_plugin(item)
|
|
||||||
if plugin is not None:
|
|
||||||
self._plugins.append(plugin.object)
|
|
||||||
if plugin.name in auth_filter_list:
|
|
||||||
self._auth_plugins.append(plugin.object)
|
|
||||||
elif plugin.name in topic_filter_list:
|
|
||||||
self._topic_plugins.append(plugin.object)
|
|
||||||
self.logger.debug(f" Plugin {item.name} ready")
|
|
||||||
|
|
||||||
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
|
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||||
try:
|
try:
|
||||||
|
@ -138,53 +114,27 @@ class PluginManager:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin':
|
def get_plugin(self, name: str) -> Optional["BasePlugin"]:
|
||||||
from amqtt.plugins.base import BasePlugin
|
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
|
||||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
|
||||||
|
|
||||||
try:
|
:param name:
|
||||||
plugin_class = import_string(plugin_path)
|
:return:
|
||||||
except ModuleNotFoundError as ep:
|
"""
|
||||||
self.logger.error(f"Plugin import failed: {plugin_path}")
|
for p in self._plugins:
|
||||||
raise MQTTError() from ep
|
self.logger.debug(f"plugin name >>>> {p.__class__.__name__}")
|
||||||
|
if p.__class__.__name__ == name:
|
||||||
if not issubclass(plugin_class, BasePlugin):
|
return p
|
||||||
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
|
return None
|
||||||
raise PluginLoadError(msg)
|
|
||||||
|
|
||||||
plugin_context = copy.copy(self.app_context)
|
|
||||||
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
|
|
||||||
try:
|
|
||||||
plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True))
|
|
||||||
except DaciteError as e:
|
|
||||||
raise PluginLoadError from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
return plugin_class(plugin_context)
|
|
||||||
except ImportError as e:
|
|
||||||
raise PluginLoadError from e
|
|
||||||
|
|
||||||
# def get_plugin(self, name: str) -> Plugin | None:
|
|
||||||
# """Get a plugin by its name from the plugins loaded for the current namespace.
|
|
||||||
#
|
|
||||||
# :param name:
|
|
||||||
# :return:
|
|
||||||
# """
|
|
||||||
# for p in self._plugins:
|
|
||||||
# if p.name == name:
|
|
||||||
# return p
|
|
||||||
# return None
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Free PluginManager resources and cancel pending event methods."""
|
"""Free PluginManager resources and cancel pending event methods."""
|
||||||
await self.map_plugin_coro("close")
|
await self.map_plugin_close()
|
||||||
for task in self._fired_events:
|
for task in self._fired_events:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
self._fired_events.clear()
|
self._fired_events.clear()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def plugins(self) -> list['BasePlugin']:
|
def plugins(self) -> list["BasePlugin"]:
|
||||||
"""Get the loaded plugins list.
|
"""Get the loaded plugins list.
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
|
@ -229,100 +179,83 @@ class PluginManager:
|
||||||
await asyncio.wait(tasks)
|
await asyncio.wait(tasks)
|
||||||
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
|
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
|
||||||
|
|
||||||
async def map(
|
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]:
|
||||||
self,
|
"""Schedule a coroutine for plugin 'authenticate' calls.
|
||||||
coro: Callable[[Plugin, Any], Awaitable[str | bool | None]],
|
|
||||||
*args: Any,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> dict[Plugin, str | bool | None]:
|
|
||||||
"""Schedule a given coroutine call for each plugin.
|
|
||||||
|
|
||||||
The coro called gets the Plugin instance as the first argument of its method call.
|
:param session: the client session associated with the authentication check
|
||||||
:param coro: coro to call on each plugin
|
|
||||||
:param filter_plugins: list of plugin names to filter (only plugin whose name is
|
|
||||||
in the filter are called). None will call all plugins. [] will call None.
|
|
||||||
:param args: arguments to pass to coro
|
|
||||||
:param kwargs: arguments to pass to coro
|
|
||||||
:return: dict containing return from coro call for each plugin.
|
:return: dict containing return from coro call for each plugin.
|
||||||
"""
|
"""
|
||||||
p_list = kwargs.pop("filter_plugins", None)
|
|
||||||
if p_list is None:
|
|
||||||
p_list = [p.name for p in self.plugins]
|
|
||||||
tasks: list[asyncio.Future[Any]] = []
|
|
||||||
plugins_list: list[Plugin] = []
|
|
||||||
for plugin in self._plugins:
|
|
||||||
if plugin.name in p_list:
|
|
||||||
coro_instance = coro(plugin, *args, **kwargs)
|
|
||||||
if coro_instance:
|
|
||||||
try:
|
|
||||||
tasks.append(self._schedule_coro(coro_instance))
|
|
||||||
plugins_list.append(plugin)
|
|
||||||
except AssertionError:
|
|
||||||
self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine")
|
|
||||||
if tasks:
|
|
||||||
ret_list = await asyncio.gather(*tasks)
|
|
||||||
# Create result map plugin => ret
|
|
||||||
ret_dict = dict(zip(plugins_list, ret_list, strict=False))
|
|
||||||
else:
|
|
||||||
ret_dict = {}
|
|
||||||
return ret_dict
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None:
|
|
||||||
if not hasattr(plugin.object, coro_name):
|
|
||||||
_LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs)
|
|
||||||
return await coro
|
|
||||||
|
|
||||||
async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]:
|
|
||||||
"""Call a plugin declared by plugin by its name.
|
|
||||||
|
|
||||||
:param coro_name:
|
|
||||||
:param args:
|
|
||||||
:param kwargs:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return await self.map(self._call_coro, coro_name, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]:
|
|
||||||
|
|
||||||
tasks: list[asyncio.Future[Any]] = []
|
tasks: list[asyncio.Future[Any]] = []
|
||||||
|
|
||||||
for plugin in self._auth_plugins:
|
for plugin in self._auth_plugins:
|
||||||
|
|
||||||
async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None:
|
async def auth_coro(p: "BaseAuthPlugin", s: Session) -> str | bool | None:
|
||||||
return await p.authenticate(session=s)
|
return await p.authenticate(session=s)
|
||||||
|
|
||||||
|
if not hasattr(plugin, "authenticate"):
|
||||||
|
continue
|
||||||
|
|
||||||
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
|
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
|
||||||
tasks.append(asyncio.ensure_future(coro_instance))
|
tasks.append(asyncio.ensure_future(coro_instance))
|
||||||
|
|
||||||
|
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||||
if tasks:
|
if tasks:
|
||||||
ret_list = await asyncio.gather(*tasks)
|
ret_list = await asyncio.gather(*tasks)
|
||||||
# Create result map plugin => ret
|
# Create result map plugin => ret
|
||||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||||
else:
|
|
||||||
ret_dict = {}
|
|
||||||
return ret_dict
|
return ret_dict
|
||||||
|
|
||||||
async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]:
|
async def map_plugin_topic(self,
|
||||||
|
session: Session, topic: str, action: "Action"
|
||||||
|
) -> dict["BasePlugin", str | bool | None]:
|
||||||
|
"""Schedule a coroutine for plugin 'topic_filtering' calls.
|
||||||
|
|
||||||
|
:param session: the client session associated with the topic_filtering check
|
||||||
|
:param topic: the topic that needs to be filtered
|
||||||
|
:param action: the action being executed
|
||||||
|
:return: dict containing return from coro call for each plugin.
|
||||||
|
"""
|
||||||
tasks: list[asyncio.Future[Any]] = []
|
tasks: list[asyncio.Future[Any]] = []
|
||||||
|
|
||||||
for plugin in self._topic_plugins:
|
for plugin in self._topic_plugins:
|
||||||
|
|
||||||
async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None:
|
async def topic_coro(p: "BaseTopicPlugin", s: Session, t: str, a: "Action") -> str | bool | None:
|
||||||
return await p.topic_filtering(session=s, topic=t, action=a)
|
return await p.topic_filtering(session=s, topic=t, action=a)
|
||||||
|
|
||||||
|
if not hasattr(plugin, "topic_filtering"):
|
||||||
|
continue
|
||||||
|
|
||||||
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
|
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
|
||||||
tasks.append(asyncio.ensure_future(coro_instance))
|
tasks.append(asyncio.ensure_future(coro_instance))
|
||||||
|
|
||||||
|
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||||
if tasks:
|
if tasks:
|
||||||
ret_list = await asyncio.gather(*tasks)
|
ret_list = await asyncio.gather(*tasks)
|
||||||
# Create result map plugin => ret
|
# Create result map plugin => ret
|
||||||
ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))}
|
ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||||
else:
|
|
||||||
ret_dict = {}
|
return ret_dict
|
||||||
|
|
||||||
|
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]:
|
||||||
|
|
||||||
|
tasks: list[asyncio.Future[Any]] = []
|
||||||
|
|
||||||
|
for plugin in self._plugins:
|
||||||
|
|
||||||
|
async def close_coro(p: "BasePlugin") -> None:
|
||||||
|
await p.close()
|
||||||
|
|
||||||
|
if not hasattr(plugin, "close"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
|
||||||
|
tasks.append(asyncio.ensure_future(coro_instance))
|
||||||
|
|
||||||
|
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||||
|
if tasks:
|
||||||
|
ret_list = await asyncio.gather(*tasks)
|
||||||
|
# Create result map plugin => ret
|
||||||
|
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||||
|
|
||||||
return ret_dict
|
return ret_dict
|
||||||
|
|
|
@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected"
|
||||||
STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
|
STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
|
||||||
|
|
||||||
|
|
||||||
class BrokerSysPlugin(BasePlugin):
|
class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||||
def __init__(self, context: BrokerContext) -> None:
|
def __init__(self, context: BrokerContext) -> None:
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
# Broker statistics initialization
|
# Broker statistics initialization
|
||||||
|
|
|
@ -6,7 +6,7 @@ from amqtt.plugins.manager import BaseContext
|
||||||
from amqtt.session import Session
|
from amqtt.session import Session
|
||||||
|
|
||||||
|
|
||||||
class BaseTopicPlugin(BasePlugin):
|
class BaseTopicPlugin(BasePlugin[BaseContext]):
|
||||||
"""Base class for topic plugins."""
|
"""Base class for topic plugins."""
|
||||||
|
|
||||||
def __init__(self, context: BaseContext) -> None:
|
def __init__(self, context: BaseContext) -> None:
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
from importlib import import_module
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
import typing
|
import typing
|
||||||
from types import ModuleType
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -51,32 +48,3 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
|
||||||
except yaml.YAMLError:
|
except yaml.YAMLError:
|
||||||
logger.exception(f"Invalid config_file {config_file}")
|
logger.exception(f"Invalid config_file {config_file}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def cached_import(module_path: str, class_name: str=None) -> ModuleType:
|
|
||||||
# Check whether module is loaded and fully initialized.
|
|
||||||
if not ((module := sys.modules.get(module_path))
|
|
||||||
and (spec := getattr(module, "__spec__", None)) # noqa
|
|
||||||
and getattr(spec, "_initializing", False) is False): # noqa
|
|
||||||
module = import_module(module_path)
|
|
||||||
if class_name:
|
|
||||||
return getattr(module, class_name)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
# TODO : figure out proper return type
|
|
||||||
def import_string(dotted_path) -> Any:
|
|
||||||
"""
|
|
||||||
Import a dotted module path and return the attribute/class designated by the
|
|
||||||
last name in the path. Raise ImportError if the import failed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
module_path, class_name = dotted_path.rsplit(".", 1)
|
|
||||||
except ValueError as err:
|
|
||||||
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
|
|
||||||
|
|
||||||
try:
|
|
||||||
return cached_import(module_path, class_name)
|
|
||||||
except AttributeError as err:
|
|
||||||
raise ImportError(
|
|
||||||
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
|
||||||
) from err
|
|
||||||
|
|
|
@ -185,13 +185,13 @@ max-returns = 10
|
||||||
|
|
||||||
# ----------------------------------- PYTEST -----------------------------------
|
# ----------------------------------- PYTEST -----------------------------------
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
timeout = 10
|
timeout = 10
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
addopts = ["--tb=short", "--capture=tee-sys"]
|
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||||
log_cli = true
|
#log_cli = true
|
||||||
log_level = "DEBUG"
|
log_level = "DEBUG"
|
||||||
|
|
||||||
# ------------------------------------ MYPY ------------------------------------
|
# ------------------------------------ MYPY ------------------------------------
|
||||||
|
|
|
@ -1,117 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from amqtt.broker import Broker
|
|
||||||
from yaml import CLoader as Loader
|
|
||||||
from dacite import from_dict, Config, UnexpectedDataError
|
|
||||||
|
|
||||||
from amqtt.client import MQTTClient
|
|
||||||
from amqtt.errors import PluginLoadError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
plugin_config = """---
|
|
||||||
listeners:
|
|
||||||
default:
|
|
||||||
type: tcp
|
|
||||||
bind: 0.0.0.0:1883
|
|
||||||
plugins:
|
|
||||||
- tests.plugins.mocks.TestSimplePlugin:
|
|
||||||
- tests.plugins.mocks.TestConfigPlugin:
|
|
||||||
option1: 1
|
|
||||||
option2: bar
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
plugin_invalid_config_one = """---
|
|
||||||
listeners:
|
|
||||||
default:
|
|
||||||
type: tcp
|
|
||||||
bind: 0.0.0.0:1883
|
|
||||||
plugins:
|
|
||||||
- tests.plugins.mocks.TestSimplePlugin:
|
|
||||||
option1: 1
|
|
||||||
option2: bar
|
|
||||||
"""
|
|
||||||
|
|
||||||
plugin_invalid_config_two = """---
|
|
||||||
listeners:
|
|
||||||
default:
|
|
||||||
type: tcp
|
|
||||||
bind: 0.0.0.0:1883
|
|
||||||
plugins:
|
|
||||||
- tests.plugins.mocks.TestConfigPlugin:
|
|
||||||
"""
|
|
||||||
|
|
||||||
plugin_config_auth = """---
|
|
||||||
listeners:
|
|
||||||
default:
|
|
||||||
type: tcp
|
|
||||||
bind: 0.0.0.0:1883
|
|
||||||
plugins:
|
|
||||||
- tests.plugins.mocks.TestAuthPlugin:
|
|
||||||
"""
|
|
||||||
|
|
||||||
plugin_config_topic = """---
|
|
||||||
listeners:
|
|
||||||
default:
|
|
||||||
type: tcp
|
|
||||||
bind: 0.0.0.0:1883
|
|
||||||
plugins:
|
|
||||||
- tests.plugins.mocks.TestTopicPlugin:
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plugin_config_extra_fields():
|
|
||||||
|
|
||||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
|
||||||
|
|
||||||
with pytest.raises(PluginLoadError):
|
|
||||||
_ = Broker(config=cfg)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plugin_config_missing_fields():
|
|
||||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
|
||||||
|
|
||||||
with pytest.raises(PluginLoadError):
|
|
||||||
_ = Broker(config=cfg)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_alternate_plugin_load():
|
|
||||||
|
|
||||||
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
|
|
||||||
|
|
||||||
broker = Broker(config=cfg)
|
|
||||||
await broker.start()
|
|
||||||
await broker.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_plugin_load():
|
|
||||||
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
|
|
||||||
broker = Broker(config=cfg)
|
|
||||||
await broker.start()
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
client1 = MQTTClient()
|
|
||||||
await client1.connect()
|
|
||||||
await client1.publish('my/topic', b'my message')
|
|
||||||
await client1.disconnect()
|
|
||||||
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
await broker.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_topic_plugin_load():
|
|
||||||
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
|
|
||||||
broker = Broker(config=cfg)
|
|
||||||
await broker.start()
|
|
||||||
await broker.shutdown()
|
|
|
@ -1,9 +1,12 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from amqtt.plugins.manager import BaseContext, Plugin, PluginManager
|
from amqtt.broker import Action
|
||||||
|
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||||
|
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||||
|
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||||
|
from amqtt.session import Session
|
||||||
|
|
||||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||||
|
@ -14,21 +17,29 @@ class EmptyTestPlugin:
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
|
|
||||||
class EventTestPlugin:
|
class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
|
||||||
def __init__(self, context: BaseContext) -> None:
|
def __init__(self, context: BaseContext) -> None:
|
||||||
self.context = context
|
super().__init__(context)
|
||||||
self.test_flag = False
|
self.test_close_flag = False
|
||||||
self.coro_flag = False
|
self.test_auth_flag = False
|
||||||
|
self.test_topic_flag = False
|
||||||
|
self.test_event_flag = False
|
||||||
|
|
||||||
async def on_test(self) -> None:
|
async def on_test(self) -> None:
|
||||||
self.test_flag = True
|
self.test_event_flag = True
|
||||||
self.context.logger.info("on_test")
|
|
||||||
|
|
||||||
async def test_coro(self) -> None:
|
async def authenticate(self, *, session: Session) -> bool | None:
|
||||||
self.coro_flag = True
|
self.test_auth_flag = True
|
||||||
|
return None
|
||||||
|
|
||||||
async def ret_coro(self) -> str:
|
async def topic_filtering(
|
||||||
return "TEST"
|
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||||
|
) -> bool:
|
||||||
|
self.test_topic_flag = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self.test_close_flag = True
|
||||||
|
|
||||||
|
|
||||||
class TestPluginManager(unittest.TestCase):
|
class TestPluginManager(unittest.TestCase):
|
||||||
|
@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase):
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||||
self.loop.run_until_complete(fire_event())
|
self.loop.run_until_complete(fire_event())
|
||||||
plugin = manager.get_plugin("event_plugin")
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
assert plugin is not None
|
assert plugin is not None
|
||||||
assert plugin.object.test_flag
|
assert plugin.test_event_flag
|
||||||
|
|
||||||
def test_fire_event_wait(self) -> None:
|
def test_fire_event_wait(self) -> None:
|
||||||
async def fire_event() -> None:
|
async def fire_event() -> None:
|
||||||
|
@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase):
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||||
self.loop.run_until_complete(fire_event())
|
self.loop.run_until_complete(fire_event())
|
||||||
plugin = manager.get_plugin("event_plugin")
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
assert plugin is not None
|
assert plugin is not None
|
||||||
assert plugin.object.test_flag
|
assert plugin.test_event_flag
|
||||||
|
|
||||||
def test_map_coro(self) -> None:
|
def test_plugin_close_coro(self) -> None:
|
||||||
async def call_coro() -> None:
|
|
||||||
await manager.map_plugin_coro("test_coro")
|
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||||
self.loop.run_until_complete(call_coro())
|
self.loop.run_until_complete(manager.map_plugin_close())
|
||||||
plugin = manager.get_plugin("event_plugin")
|
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||||
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
assert plugin is not None
|
assert plugin is not None
|
||||||
assert plugin.object.test_coro
|
assert plugin.test_close_flag
|
||||||
|
|
||||||
def test_map_coro_return(self) -> None:
|
def test_plugin_auth_coro(self) -> None:
|
||||||
async def call_coro() -> dict[Plugin, str]:
|
|
||||||
return await manager.map_plugin_coro("ret_coro")
|
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||||
ret = self.loop.run_until_complete(call_coro())
|
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
|
||||||
plugin = manager.get_plugin("event_plugin")
|
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||||
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
assert plugin is not None
|
assert plugin is not None
|
||||||
assert ret[plugin] == "TEST"
|
assert plugin.test_auth_flag
|
||||||
|
|
||||||
def test_map_coro_filter(self) -> None:
|
def test_plugin_topic_coro(self) -> None:
|
||||||
"""Run plugin coro but expect no return as an empty filter is given."""
|
|
||||||
|
|
||||||
async def call_coro() -> dict[Plugin, Any]:
|
|
||||||
return await manager.map_plugin_coro("ret_coro", filter_plugins=[])
|
|
||||||
|
|
||||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||||
ret = self.loop.run_until_complete(call_coro())
|
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
|
||||||
assert len(ret) == 0
|
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||||
|
plugin = manager.get_plugin("EventTestPlugin")
|
||||||
|
assert plugin is not None
|
||||||
|
assert plugin.test_topic_flag
|
||||||
|
|
Ładowanie…
Reference in New Issue