Merge pull request #212 from ajmirsky/plugin_call_optimization

reduce call logic for plugin coros: authenticate, topic_filtering, close
pull/216/head
Andrew Mirsky 2025-06-14 10:19:39 -04:00 zatwierdzone przez GitHub
commit aace4d65f7
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
19 zmienionych plików z 259 dodań i 145 usunięć

Wyświetl plik

@ -711,20 +711,16 @@ class Broker:
:param session:
:return:
"""
auth_plugins = None
auth_config = self.config.get("auth", None)
if isinstance(auth_config, dict):
auth_plugins = auth_config.get("plugins", None)
returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
returns = await self.plugins_manager.map_plugin_auth(session=session)
auth_result = True
if returns:
for plugin in returns:
res = returns[plugin]
if res is False:
auth_result = False
self.logger.debug(f"Authentication failed due to '{plugin.name}' plugin result: {res}")
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
else:
self.logger.debug(f"'{plugin.name}' plugin result: {res}")
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
# If all plugins returned True, authentication is success
return auth_result
@ -785,20 +781,14 @@ class Broker:
"""
topic_config = self.config.get("topic-check", {})
enabled = False
topic_plugins: list[str] | None = None
if isinstance(topic_config, dict):
enabled = topic_config.get("enabled", False)
topic_plugins = topic_config.get("plugins")
if not enabled:
return True
results = await self.plugins_manager.map_plugin_coro(
"topic_filtering",
session=session,
topic=topic,
action=action,
filter_plugins=topic_plugins,
)
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
return all(result for result in results.values())
async def _delete_session(self, client_id: str) -> None:

Wyświetl plik

@ -110,7 +110,7 @@ class MQTTClient:
# Init plugins manager
context = ClientContext()
context.config = self.config
self.plugins_manager = PluginManager("amqtt.client.plugins", context)
self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context)
self.client_tasks: deque[asyncio.Task[Any]] = deque()
async def connect(

Wyświetl plik

@ -50,3 +50,7 @@ class ConnectError(ClientError):
class ProtocolHandlerError(Exception):
"""Exceptions thrown by protocol handle."""
class PluginLoadError(Exception):
"""Exception thrown when loading a plugin."""

Wyświetl plik

@ -1,5 +1,6 @@
import asyncio
from asyncio import AbstractEventLoop, Queue
from typing import TYPE_CHECKING
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.errors import MQTTError
@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
if TYPE_CHECKING:
from amqtt.broker import BrokerContext
class Subscription:
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
@ -41,10 +44,10 @@ class UnSubscription:
self.topics = topics
class BrokerProtocolHandler(ProtocolHandler):
class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
def __init__(
self,
plugins_manager: PluginManager,
plugins_manager: PluginManager["BrokerContext"],
session: Session | None = None,
loop: AbstractEventLoop | None = None,
) -> None:
@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler):
cls,
reader: ReaderAdapter,
writer: WriterAdapter,
plugins_manager: PluginManager,
plugins_manager: PluginManager["BrokerContext"],
loop: asyncio.AbstractEventLoop | None = None,
) -> tuple["BrokerProtocolHandler", Session]:
"""Initialize from a CONNECT packet and validates the connection."""

Wyświetl plik

@ -1,5 +1,5 @@
import asyncio
from typing import Any
from typing import TYPE_CHECKING, Any
from amqtt.errors import AMQTTError, NoDataError
from amqtt.mqtt.connack import ConnackPacket
@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import PluginManager
from amqtt.session import Session
if TYPE_CHECKING:
from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler):
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
def __init__(
self,
plugins_manager: PluginManager,
plugins_manager: PluginManager["ClientContext"],
session: Session | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:

Wyświetl plik

@ -17,7 +17,7 @@ except ImportError:
import collections
import itertools
import logging
from typing import cast
from typing import Generic, TypeVar, cast
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError
@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket
from amqtt.mqtt.subscribe import SubscribePacket
from amqtt.mqtt.unsuback import UnsubackPacket
from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import PluginManager
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session
EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent"
EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received"
C = TypeVar("C", bound=BaseContext)
class ProtocolHandler:
class ProtocolHandler(Generic[C]):
"""Class implementing the MQTT communication protocol using asyncio features."""
def __init__(
self,
plugins_manager: PluginManager,
plugins_manager: PluginManager[C],
session: Session | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
@ -79,7 +80,7 @@ class ProtocolHandler:
self.session: Session | None = None
self.reader: ReaderAdapter | None = None
self.writer: WriterAdapter | None = None
self.plugins_manager: PluginManager = plugins_manager
self.plugins_manager: PluginManager[C] = plugins_manager
try:
self._loop = loop if loop is not None else asyncio.get_running_loop()

Wyświetl plik

@ -5,15 +5,16 @@ from passlib.apps import custom_app_context as pwd_context
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class BaseAuthPlugin(BasePlugin):
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")

Wyświetl plik

@ -1,19 +1,26 @@
from typing import Any
from typing import Any, Generic, TypeVar
from amqtt.broker import BrokerContext
from amqtt.plugins.manager import BaseContext
C = TypeVar("C", bound=BaseContext)
class BasePlugin:
class BasePlugin(Generic[C]):
"""The base from which all plugins should inherit."""
def __init__(self, context: BrokerContext) -> None:
self.context = context
def __init__(self, context: C) -> None:
self.context: C = context
def _get_config_section(self, name: str) -> dict[str, Any] | None:
if not self.context.config or not self.context.config.get(name, None):
if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
return None
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
# mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check
if isinstance(section_config, int):
return None
return section_config
async def close(self) -> None:
"""Override if plugin needs to clean up resources upon shutdown."""

Wyświetl plik

@ -4,12 +4,13 @@ import logging
from typing import TYPE_CHECKING, Any
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
if TYPE_CHECKING:
from amqtt.session import Session
class EventLoggerPlugin(BasePlugin):
class EventLoggerPlugin(BasePlugin[BaseContext]):
"""A plugin to log events dynamically based on method names."""
async def log_event(self, *args: Any, **kwargs: Any) -> None:
@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin):
raise AttributeError(msg)
class PacketLoggerPlugin(BasePlugin):
class PacketLoggerPlugin(BasePlugin[BaseContext]):
"""A plugin to log MQTT packets sent and received."""
async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None:

Wyświetl plik

@ -1,17 +1,24 @@
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable
import contextlib
import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging
from typing import Any, NamedTuple
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
from amqtt.session import Session
from amqtt.errors import PluginImportError, PluginInitError
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
class Plugin(NamedTuple):
name: str
@ -19,10 +26,10 @@ class Plugin(NamedTuple):
object: Any
plugins_manager: dict[str, "PluginManager"] = {}
plugins_manager: dict[str, "PluginManager[Any]"] = {}
def get_plugin_manager(namespace: str) -> "PluginManager | None":
def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
"""Get the plugin manager for a given namespace.
:param namespace: The namespace of the plugin manager to retrieve.
@ -38,14 +45,17 @@ class BaseContext:
self.config: dict[str, Any] | None = None
class PluginManager:
C = TypeVar("C", bound=BaseContext)
class PluginManager(Generic[C]):
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to
run plugin calls asynchronously in an event queue.
"""
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
try:
self._loop = loop if loop is not None else asyncio.get_running_loop()
except RuntimeError:
@ -55,7 +65,9 @@ class PluginManager:
self.logger = logging.getLogger(namespace)
self.context = context if context is not None else BaseContext()
self.context.loop = self._loop
self._plugins: list[Plugin] = []
self._plugins: list[BasePlugin[C]] = []
self._auth_plugins: list[BaseAuthPlugin] = []
self._topic_plugins: list[BaseTopicPlugin] = []
self._load_plugins(namespace)
self._fired_events: list[asyncio.Future[Any]] = []
plugins_manager[namespace] = self
@ -65,7 +77,16 @@ class PluginManager:
return self.context
def _load_plugins(self, namespace: str) -> None:
self.logger.debug(f"Loading plugins for namespace {namespace}")
auth_filter_list = []
topic_filter_list = []
if self.app_context.config and "auth" in self.app_context.config:
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
if self.app_context.config and "topic" in self.app_context.config:
topic_filter_list = self.app_context.config["topic"].get("plugins", [])
ep: EntryPoints | list[EntryPoint] = []
if hasattr(entry_points(), "select"):
ep = entry_points().select(group=namespace)
@ -73,12 +94,16 @@ class PluginManager:
ep = [entry_points()[namespace]]
for item in ep:
plugin = self._load_plugin(item)
plugin = self._load_ep_plugin(item)
if plugin is not None:
self._plugins.append(plugin)
self._plugins.append(plugin.object)
if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"):
self._auth_plugins.append(plugin.object)
if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"):
self._topic_plugins.append(plugin.object)
self.logger.debug(f" Plugin {item.name} ready")
def _load_plugin(self, ep: EntryPoint) -> Plugin | None:
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
try:
self.logger.debug(f" Loading plugin {ep!s}")
plugin = ep.load()
@ -98,26 +123,27 @@ class PluginManager:
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
raise PluginInitError(ep) from e
def get_plugin(self, name: str) -> Plugin | None:
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
"""Get a plugin by its name from the plugins loaded for the current namespace.
:param name:
:return:
"""
for p in self._plugins:
if p.name == name:
self.logger.debug(f"plugin name >>>> {p.__class__.__name__}")
if p.__class__.__name__ == name:
return p
return None
async def close(self) -> None:
"""Free PluginManager resources and cancel pending event methods."""
await self.map_plugin_coro("close")
await self.map_plugin_close()
for task in self._fired_events:
task.cancel()
self._fired_events.clear()
@property
def plugins(self) -> list[Plugin]:
def plugins(self) -> list["BasePlugin[C]"]:
"""Get the loaded plugins list.
:return:
@ -143,7 +169,7 @@ class PluginManager:
tasks: list[asyncio.Future[Any]] = []
event_method_name = "on_" + event_name
for plugin in self._plugins:
event_method = getattr(plugin.object, event_method_name, None)
event_method = getattr(plugin, event_method_name, None)
if event_method:
try:
task = self._schedule_coro(event_method(*args, **kwargs))
@ -155,66 +181,73 @@ class PluginManager:
task.add_done_callback(clean_fired_events)
except AssertionError:
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.name}' is not a coroutine")
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.__class__}' is not a coroutine")
self._fired_events.extend(tasks)
if wait and tasks:
await asyncio.wait(tasks)
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
async def map(
self,
coro: Callable[[Plugin, Any], Awaitable[str | bool | None]],
*args: Any,
**kwargs: Any,
) -> dict[Plugin, str | bool | None]:
"""Schedule a given coroutine call for each plugin.
@staticmethod
async def _map_plugin_method(
plugins: list["BasePlugin[C]"],
method_name: str,
method_kwargs: dict[str, Any],
) -> dict["BasePlugin[C]", str | bool | None]:
"""Call plugin coroutines.
The coro called gets the Plugin instance as the first argument of its method call.
:param coro: coro to call on each plugin
:param filter_plugins: list of plugin names to filter (only plugin whose name is
in the filter are called). None will call all plugins. [] will call None.
:param args: arguments to pass to coro
:param kwargs: arguments to pass to coro
:param plugins: List of plugins to execute the method on
:param method_name: Name of the method to call on each plugin
:param method_kwargs: Keyword arguments to pass to the method
:return: dict containing return from coro call for each plugin.
"""
p_list = kwargs.pop("filter_plugins", None)
if p_list is None:
p_list = [p.name for p in self.plugins]
tasks: list[asyncio.Future[Any]] = []
plugins_list: list[Plugin] = []
for plugin in self._plugins:
if plugin.name in p_list:
coro_instance = coro(plugin, *args, **kwargs)
if coro_instance:
try:
tasks.append(self._schedule_coro(coro_instance))
plugins_list.append(plugin)
except AssertionError:
self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine")
for plugin in plugins:
if not hasattr(plugin, method_name):
continue
async def call_method(p: "BasePlugin[C]", kwargs: dict[str, Any]) -> Any:
method = getattr(p, method_name)
return await method(**kwargs)
coro_instance: Awaitable[Any] = call_method(plugin, method_kwargs)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(plugins_list, ret_list, strict=False))
else:
ret_dict = {}
ret_dict = dict(zip(plugins, ret_list, strict=False))
return ret_dict
@staticmethod
async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None:
if not hasattr(plugin.object, coro_name):
_LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}")
return None
async def map_plugin_auth(self, *, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
"""Schedule a coroutine for plugin 'authenticate' calls.
coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs)
return await coro
async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]:
"""Call a plugin declared by plugin by its name.
:param coro_name:
:param args:
:param kwargs:
:return:
:param session: the client session associated with the authentication check
:return: dict containing return from coro call for each plugin.
"""
return await self.map(self._call_coro, coro_name, *args, **kwargs)
return await self._map_plugin_method(
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
async def map_plugin_topic(
self, *, session: Session, topic: str, action: "Action"
) -> dict["BasePlugin[C]", str | bool | None]:
"""Schedule a coroutine for plugin 'topic_filtering' calls.
:param session: the client session associated with the topic_filtering check
:param topic: the topic that needs to be filtered
:param action: the action being executed
:return: dict containing return from coro call for each plugin.
"""
return await self._map_plugin_method(
self._topic_plugins, "topic_filtering", # type: ignore[arg-type]
{"session": session, "topic": topic, "action": action}
)
async def map_plugin_close(self) -> None:
"""Schedule a coroutine for plugin 'close' calls.
:return: dict containing return from coro call for each plugin.
"""
await self._map_plugin_method(self._plugins, "close", {})

Wyświetl plik

@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected"
STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
class BrokerSysPlugin(BasePlugin):
class BrokerSysPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# Broker statistics initialization

Wyświetl plik

@ -1,14 +1,15 @@
from typing import Any
from amqtt.broker import Action, BrokerContext
from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
class BaseTopicPlugin(BasePlugin):
class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins."""
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
@ -37,7 +38,7 @@ class BaseTopicPlugin(BasePlugin):
class TopicTabooPlugin(BaseTopicPlugin):
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"]

Wyświetl plik

@ -6,3 +6,5 @@ default_retain: false
auto_reconnect: true
reconnect_max_interval: 10
reconnect_retries: 2
broker:
uri: "mqtt://127.0.0.1"

Wyświetl plik

@ -23,7 +23,9 @@ its own variables to configure its behavior.
::: amqtt.plugins.base.BasePlugin
Plugins that are defined in the`project.entry-points` are loaded and notified of events by when the subclass
## Events
Plugins that are defined in the`project.entry-points` are notified of events if the subclass
implements one or more of these methods:
- `on_mqtt_packet_sent`

Wyświetl plik

@ -32,7 +32,8 @@ dependencies = [
"websockets==15.0.1", # https://pypi.org/project/websockets
"passlib==1.7.4", # https://pypi.org/project/passlib
"PyYAML==6.0.2", # https://pypi.org/project/PyYAML
"typer==0.15.4"
"typer==0.15.4",
"dacite>=1.9.2",
]
[dependency-groups]

Wyświetl plik

@ -0,0 +1,10 @@
---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- test.plugins.plugins.TestSimplePlugin
- test.plugins.plugins.TestConfigPlugin:
option1: foo
option2: bar

Wyświetl plik

@ -1,18 +1,55 @@
import logging
from dataclasses import dataclass
from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class NoAuthPlugin(BaseAuthPlugin):
class TestSimplePlugin(BasePlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
class TestConfigPlugin(BasePlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
@dataclass
class Config:
option1: int
option2: str
async def authenticate(self, *, session: Session) -> bool | None:
return False
class AuthPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
return True
class NoAuthPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
return False
class TestTopicPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
return True

Wyświetl plik

@ -1,9 +1,12 @@
import asyncio
import logging
from typing import Any
import unittest
from amqtt.plugins.manager import BaseContext, Plugin, PluginManager
from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
@ -14,21 +17,29 @@ class EmptyTestPlugin:
self.context = context
class EventTestPlugin:
class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None:
self.context = context
self.test_flag = False
self.coro_flag = False
super().__init__(context)
self.test_close_flag = False
self.test_auth_flag = False
self.test_topic_flag = False
self.test_event_flag = False
async def on_test(self) -> None:
self.test_flag = True
self.context.logger.info("on_test")
self.test_event_flag = True
async def test_coro(self) -> None:
self.coro_flag = True
async def authenticate(self, *, session: Session) -> bool | None:
self.test_auth_flag = True
return None
async def ret_coro(self) -> str:
return "TEST"
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
self.test_topic_flag = True
return False
async def close(self) -> None:
self.test_close_flag = True
class TestPluginManager(unittest.TestCase):
@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase):
manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(fire_event())
plugin = manager.get_plugin("event_plugin")
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.object.test_flag
assert plugin.test_event_flag
def test_fire_event_wait(self) -> None:
async def fire_event() -> None:
@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase):
manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(fire_event())
plugin = manager.get_plugin("event_plugin")
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.object.test_flag
assert plugin.test_event_flag
def test_map_coro(self) -> None:
async def call_coro() -> None:
await manager.map_plugin_coro("test_coro")
def test_plugin_close_coro(self) -> None:
manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(call_coro())
plugin = manager.get_plugin("event_plugin")
self.loop.run_until_complete(manager.map_plugin_close())
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.object.test_coro
assert plugin.test_close_flag
def test_map_coro_return(self) -> None:
async def call_coro() -> dict[Plugin, str]:
return await manager.map_plugin_coro("ret_coro")
def test_plugin_auth_coro(self) -> None:
manager = PluginManager("amqtt.test.plugins", context=None)
ret = self.loop.run_until_complete(call_coro())
plugin = manager.get_plugin("event_plugin")
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert ret[plugin] == "TEST"
assert plugin.test_auth_flag
def test_map_coro_filter(self) -> None:
"""Run plugin coro but expect no return as an empty filter is given."""
async def call_coro() -> dict[Plugin, Any]:
return await manager.map_plugin_coro("ret_coro", filter_plugins=[])
def test_plugin_topic_coro(self) -> None:
manager = PluginManager("amqtt.test.plugins", context=None)
ret = self.loop.run_until_complete(call_coro())
assert len(ret) == 0
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.test_topic_flag

11
uv.lock
Wyświetl plik

@ -12,6 +12,7 @@ name = "amqtt"
version = "0.11.0rc1"
source = { editable = "." }
dependencies = [
{ name = "dacite" },
{ name = "passlib" },
{ name = "pyyaml" },
{ name = "transitions" },
@ -66,6 +67,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
{ name = "dacite", specifier = ">=1.9.2" },
{ name = "passlib", specifier = "==1.7.4" },
{ name = "pyyaml", specifier = "==6.0.2" },
{ name = "transitions", specifier = "==0.9.2" },
@ -485,6 +487,15 @@ version = "0.9.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" }
[[package]]
name = "dacite"
version = "1.9.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" },
]
[[package]]
name = "dill"
version = "0.4.0"