From b373e8bd56cd7d9eadb152e5edaa37c293364835 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Thu, 12 Jun 2025 20:12:08 -0400 Subject: [PATCH] abstract common pattern for calling plugin methods --- amqtt/plugins/manager.py | 93 ++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 56 deletions(-) diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index 177c539..9aa8f9c 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -182,36 +182,50 @@ class PluginManager(Generic[C]): await asyncio.wait(tasks) self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") - async def map_plugin_auth(self, session: Session) -> dict["BasePlugin[C]", str | bool | None]: - """Schedule a coroutine for plugin 'authenticate' calls. + @staticmethod + async def _map_plugin_method( + plugins: list["BasePlugin[C]"], + method_name: str, + method_kwargs: dict[str, Any], + ) -> dict["BasePlugin[C]", str | bool | None]: + """Generic helper to map a method call across plugins. - :param session: the client session associated with the authentication check + :param plugins: List of plugins to execute the method on + :param method_name: Name of the method to call on each plugin + :param method_kwargs: Keyword arguments to pass to the method :return: dict containing return from coro call for each plugin. """ tasks: list[asyncio.Future[Any]] = [] - for plugin in self._auth_plugins: - - async def auth_coro(p: "BaseAuthPlugin", s: Session) -> str | bool | None: - return await p.authenticate(session=s) - - if not hasattr(plugin, "authenticate"): + for plugin in plugins: + if not hasattr(plugin, method_name): continue - coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session) + async def call_method(p: "BasePlugin[C]", kwargs: dict[str, Any]) -> Any: + method = getattr(p, method_name) + return await method(**kwargs) + + coro_instance: Awaitable[Any] = call_method(plugin, method_kwargs) tasks.append(asyncio.ensure_future(coro_instance)) ret_dict: dict[BasePlugin[C], str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) - # Create result map plugin => ret - ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) # type: ignore[arg-type] + ret_dict = dict(zip(plugins, ret_list, strict=False)) # type: ignore[arg-type] return ret_dict - async def map_plugin_topic(self, - session: Session, topic: str, action: "Action" - ) -> dict["BasePlugin[C]", str | bool | None]: + async def map_plugin_auth(self, *, session: Session) -> dict["BasePlugin[C]", str | bool | None]: + """Schedule a coroutine for plugin 'authenticate' calls. + + :param session: the client session associated with the authentication check + :return: dict containing return from coro call for each plugin. + """ + return await self._map_plugin_method(self._auth_plugins, "authenticate", {'session': session }) + + async def map_plugin_topic( + self, *, session: Session, topic: str, action: "Action" + ) -> dict["BasePlugin[C]", str | bool | None]: """Schedule a coroutine for plugin 'topic_filtering' calls. :param session: the client session associated with the topic_filtering check @@ -219,46 +233,13 @@ class PluginManager(Generic[C]): :param action: the action being executed :return: dict containing return from coro call for each plugin. """ - tasks: list[asyncio.Future[Any]] = [] + return await self._map_plugin_method( + self._topic_plugins, "topic_filtering", {'session': session, 'topic': topic, 'action': action} + ) - for plugin in self._topic_plugins: + async def map_plugin_close(self) -> None: + """Schedule a coroutine for plugin 'close' calls. - async def topic_coro(p: "BaseTopicPlugin", s: Session, t: str, a: "Action") -> str | bool | None: - return await p.topic_filtering(session=s, topic=t, action=a) - - if not hasattr(plugin, "topic_filtering"): - continue - - coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action) - tasks.append(asyncio.ensure_future(coro_instance)) - - ret_dict: dict[BasePlugin[C], str | bool | None] = {} - if tasks: - ret_list = await asyncio.gather(*tasks) - # Create result map plugin => ret - ret_dict= dict(zip(self._topic_plugins, ret_list, strict=False)) # type: ignore[arg-type] - - return ret_dict - - async def map_plugin_close(self) -> dict["BasePlugin[C]", str | bool | None]: - - tasks: list[asyncio.Future[Any]] = [] - - for plugin in self._plugins: - - async def close_coro(p: "BasePlugin[C]") -> None: - await p.close() - - if not hasattr(plugin, "close"): - continue - - coro_instance: Awaitable[str | bool | None] = close_coro(plugin) - tasks.append(asyncio.ensure_future(coro_instance)) - - ret_dict: dict[BasePlugin[C], str | bool | None] = {} - if tasks: - ret_list = await asyncio.gather(*tasks) - # Create result map plugin => ret - ret_dict = dict(zip(self._plugins, ret_list, strict=False)) - - return ret_dict + :return: dict containing return from coro call for each plugin. + """ + await self._map_plugin_method(self._plugins, "close", {})