diff --git a/hbmqtt/plugins/manager.py b/hbmqtt/plugins/manager.py index af42bda..73d985c 100644 --- a/hbmqtt/plugins/manager.py +++ b/hbmqtt/plugins/manager.py @@ -140,22 +140,30 @@ class PluginManager: """ Schedule a given coroutine call for each plugin. The coro called get the Plugin instance as first argument of its method call - :param coro: - :param args: - :param kwargs: - :return: + :param coro: coro to call on each plugin + :param filter_plugins: list of plugin names to filter (only plugin whose name is in filter are called). + None will call all plugins. [] will call None. + :param args: arguments to pass to coro + :param kwargs: arguments to pass to coro + :return: dict containing return from coro call for each plugin """ + p_list = kwargs.pop('filter_plugins', None) + if p_list is None: + p_list = [p.name for p in self.plugins] tasks = [] plugins_list = [] for plugin in self._plugins: - coro_instance = coro(plugin, *args, **kwargs) - if coro_instance: - tasks.append(self._schedule_coro(coro_instance)) - plugins_list.append(plugin) - ret_list = yield from asyncio.gather(*tasks, loop=self._loop) - - # Create result map plugin=>ret - ret_dict = {k: v for k, v in zip(plugins_list, ret_list)} + if plugin.name in p_list: + coro_instance = coro(plugin, *args, **kwargs) + if coro_instance: + tasks.append(self._schedule_coro(coro_instance)) + plugins_list.append(plugin) + if tasks: + ret_list = yield from asyncio.gather(*tasks, loop=self._loop) + # Create result map plugin=>ret + ret_dict = {k: v for k, v in zip(plugins_list, ret_list)} + else: + ret_dict = {} return ret_dict @staticmethod diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index d18394e..312e2ec 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -22,16 +22,16 @@ class EventTestPlugin: self.coro_flag = False @asyncio.coroutine - def on_test(self): + def on_test(self, *args, **kwargs): self.test_flag = True self.context.logger.info("on_test") @asyncio.coroutine - def test_coro(self): + def test_coro(self, *args, **kwargs): self.coro_flag = True @asyncio.coroutine - def ret_coro(self): + def ret_coro(self, *args, **kwargs): return "TEST" @@ -85,3 +85,16 @@ class TestPluginManager(unittest.TestCase): ret = self.loop.run_until_complete(call_coro()) plugin = manager.get_plugin("event_plugin") self.assertEqual(ret[plugin], "TEST") + + def test_map_coro_filter(self): + """ + Run plugin coro but expect no return as an empty filter is given + :return: + """ + @asyncio.coroutine + def call_coro(): + return (yield from manager.map_plugin_coro('ret_coro', filter_plugins=[])) + + manager = PluginManager("hbmqtt.test.plugins", context=None, loop=self.loop) + ret = self.loop.run_until_complete(call_coro()) + self.assertTrue(len(ret) == 0)