kopia lustrzana https://github.com/Yakifo/amqtt
loading plugins from config file, bypassing entry points. authenticate and subscribe/publish
rodzic
06053ce7ee
commit
7b936d785c
|
|
@ -693,16 +693,17 @@ class Broker:
|
|||
auth_config = self.config.get("auth", None)
|
||||
if isinstance(auth_config, dict):
|
||||
auth_plugins = auth_config.get("plugins", None)
|
||||
returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
|
||||
# returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
|
||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||
auth_result = True
|
||||
if returns:
|
||||
for plugin in returns:
|
||||
res = returns[plugin]
|
||||
if res is False:
|
||||
auth_result = False
|
||||
self.logger.debug(f"Authentication failed due to '{plugin.name}' plugin result: {res}")
|
||||
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
|
||||
else:
|
||||
self.logger.debug(f"'{plugin.name}' plugin result: {res}")
|
||||
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
|
||||
# If all plugins returned True, authentication is success
|
||||
return auth_result
|
||||
|
||||
|
|
@ -771,13 +772,15 @@ class Broker:
|
|||
|
||||
if not enabled:
|
||||
return True
|
||||
results = await self.plugins_manager.map_plugin_coro(
|
||||
"topic_filtering",
|
||||
session=session,
|
||||
topic=topic,
|
||||
action=action,
|
||||
filter_plugins=topic_plugins,
|
||||
)
|
||||
|
||||
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
||||
# results = await self.plugins_manager.map_plugin_coro(
|
||||
# "topic_filtering",
|
||||
# session=session,
|
||||
# topic=topic,
|
||||
# action=action,
|
||||
# filter_plugins=topic_plugins,
|
||||
# )
|
||||
return all(result for result in results.values())
|
||||
|
||||
async def _delete_session(self, client_id: str) -> None:
|
||||
|
|
|
|||
|
|
@ -30,3 +30,8 @@ class ConnectError(ClientError):
|
|||
|
||||
class ProtocolHandlerError(Exception):
|
||||
"""Exceptions thrown by protocol handle."""
|
||||
|
||||
|
||||
class PluginLoadError(Exception):
|
||||
"""Exception thrown when loading a plugin."""
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from passlib.apps import custom_app_context as pwd_context
|
|||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
||||
|
|
@ -13,7 +14,7 @@ _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
|||
class BaseAuthPlugin(BasePlugin):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
|
||||
|
|
|
|||
|
|
@ -1,19 +1,26 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""The base from which all plugins should inherit."""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
self.context = context
|
||||
|
||||
def _get_config_section(self, name: str) -> dict[str, Any] | None:
|
||||
if not self.context.config or not self.context.config.get(name, None):
|
||||
|
||||
if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None):
|
||||
return None
|
||||
|
||||
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
|
||||
# mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check
|
||||
if isinstance(section_config, int):
|
||||
return None
|
||||
return section_config
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,10 +6,20 @@ import contextlib
|
|||
import copy
|
||||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
import logging
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any, NamedTuple, TYPE_CHECKING
|
||||
|
||||
from amqtt.errors import MQTTError, PluginLoadError
|
||||
from amqtt.session import Session
|
||||
from amqtt.utils import import_string
|
||||
from dacite import from_dict, Config, DaciteError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.broker import Action
|
||||
|
||||
class Plugin(NamedTuple):
|
||||
name: str
|
||||
|
|
@ -53,7 +63,9 @@ class PluginManager:
|
|||
self.logger = logging.getLogger(namespace)
|
||||
self.context = context if context is not None else BaseContext()
|
||||
self.context.loop = self._loop
|
||||
self._plugins: list[Plugin] = []
|
||||
self._plugins: list[BasePlugin] = []
|
||||
self._auth_plugins: list[BaseAuthPlugin] = []
|
||||
self._topic_plugins: list[BaseTopicPlugin] = []
|
||||
self._load_plugins(namespace)
|
||||
self._fired_events: list[asyncio.Future[Any]] = []
|
||||
plugins_manager[namespace] = self
|
||||
|
|
@ -63,20 +75,54 @@ class PluginManager:
|
|||
return self.context
|
||||
|
||||
def _load_plugins(self, namespace: str) -> None:
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
ep: EntryPoints | list[EntryPoint] = []
|
||||
if hasattr(entry_points(), "select"):
|
||||
ep = entry_points().select(group=namespace)
|
||||
elif namespace in entry_points():
|
||||
ep = [entry_points()[namespace]]
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
|
||||
if 'plugins' in self.app_context.config:
|
||||
self.logger.info("Loading plugins from config file")
|
||||
for plugin_info in self.app_context.config['plugins']:
|
||||
|
||||
if isinstance(plugin_info, dict):
|
||||
assert len(plugin_info.keys()) == 1
|
||||
plugin_path = list(plugin_info.keys())[0]
|
||||
plugin_cfg = plugin_info[plugin_path]
|
||||
plugin = self._load_str_plugin(plugin_path, plugin_cfg)
|
||||
elif isinstance(plugin_info, str):
|
||||
plugin = self._load_str_plugin(plugin_info, {})
|
||||
else:
|
||||
msg = 'Unexpected entry in plugins config'
|
||||
raise PluginLoadError(msg)
|
||||
|
||||
for item in ep:
|
||||
plugin = self._load_plugin(item)
|
||||
if plugin is not None:
|
||||
self._plugins.append(plugin)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
if isinstance(plugin, BaseAuthPlugin):
|
||||
self._auth_plugins.append(plugin)
|
||||
if isinstance(plugin, BaseTopicPlugin):
|
||||
self._topic_plugins.append(plugin)
|
||||
|
||||
def _load_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||
|
||||
|
||||
else:
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
|
||||
auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else []
|
||||
topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else []
|
||||
ep: EntryPoints | list[EntryPoint] = []
|
||||
if hasattr(entry_points(), "select"):
|
||||
ep = entry_points().select(group=namespace)
|
||||
elif namespace in entry_points():
|
||||
ep = [entry_points()[namespace]]
|
||||
|
||||
for item in ep:
|
||||
plugin = self._load_ep_plugin(item)
|
||||
if plugin is not None:
|
||||
self._plugins.append(plugin.object)
|
||||
if plugin.name in auth_filter_list:
|
||||
self._auth_plugins.append(plugin.object)
|
||||
elif plugin.name in topic_filter_list:
|
||||
self._topic_plugins.append(plugin.object)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
|
||||
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||
try:
|
||||
self.logger.debug(f" Loading plugin {ep!s}")
|
||||
plugin = ep.load()
|
||||
|
|
@ -92,16 +138,43 @@ class PluginManager:
|
|||
|
||||
return None
|
||||
|
||||
def get_plugin(self, name: str) -> Plugin | None:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin':
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
|
||||
:param name:
|
||||
:return:
|
||||
"""
|
||||
for p in self._plugins:
|
||||
if p.name == name:
|
||||
return p
|
||||
return None
|
||||
try:
|
||||
plugin_class = import_string(plugin_path)
|
||||
except ModuleNotFoundError as ep:
|
||||
self.logger.error(f"Plugin import failed: {plugin_path}")
|
||||
raise MQTTError() from ep
|
||||
|
||||
if not issubclass(plugin_class, BasePlugin):
|
||||
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
|
||||
raise PluginLoadError(msg)
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
|
||||
try:
|
||||
plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True))
|
||||
except DaciteError as e:
|
||||
raise PluginLoadError from e
|
||||
|
||||
try:
|
||||
return plugin_class(plugin_context)
|
||||
except ImportError as e:
|
||||
raise PluginLoadError from e
|
||||
|
||||
# def get_plugin(self, name: str) -> Plugin | None:
|
||||
# """Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
#
|
||||
# :param name:
|
||||
# :return:
|
||||
# """
|
||||
# for p in self._plugins:
|
||||
# if p.name == name:
|
||||
# return p
|
||||
# return None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Free PluginManager resources and cancel pending event methods."""
|
||||
|
|
@ -111,7 +184,7 @@ class PluginManager:
|
|||
self._fired_events.clear()
|
||||
|
||||
@property
|
||||
def plugins(self) -> list[Plugin]:
|
||||
def plugins(self) -> list['BasePlugin']:
|
||||
"""Get the loaded plugins list.
|
||||
|
||||
:return:
|
||||
|
|
@ -137,7 +210,7 @@ class PluginManager:
|
|||
tasks: list[asyncio.Future[Any]] = []
|
||||
event_method_name = "on_" + event_name
|
||||
for plugin in self._plugins:
|
||||
event_method = getattr(plugin.object, event_method_name, None)
|
||||
event_method = getattr(plugin, event_method_name, None)
|
||||
if event_method:
|
||||
try:
|
||||
task = self._schedule_coro(event_method(*args, **kwargs))
|
||||
|
|
@ -149,7 +222,7 @@ class PluginManager:
|
|||
|
||||
task.add_done_callback(clean_fired_events)
|
||||
except AssertionError:
|
||||
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.name}' is not a coroutine")
|
||||
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.__class__}' is not a coroutine")
|
||||
|
||||
self._fired_events.extend(tasks)
|
||||
if wait and tasks:
|
||||
|
|
@ -212,3 +285,44 @@ class PluginManager:
|
|||
:return:
|
||||
"""
|
||||
return await self.map(self._call_coro, coro_name, *args, **kwargs)
|
||||
|
||||
|
||||
async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]:
|
||||
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
for plugin in self._auth_plugins:
|
||||
|
||||
async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None:
|
||||
return await p.authenticate(session=s)
|
||||
|
||||
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
else:
|
||||
ret_dict = {}
|
||||
return ret_dict
|
||||
|
||||
async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]:
|
||||
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
for plugin in self._topic_plugins:
|
||||
|
||||
async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None:
|
||||
return await p.topic_filtering(session=s, topic=t, action=a)
|
||||
|
||||
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))}
|
||||
else:
|
||||
ret_dict = {}
|
||||
return ret_dict
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
from amqtt.broker import Action, BrokerContext
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
class BaseTopicPlugin(BasePlugin):
|
||||
"""Base class for topic plugins."""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
|
||||
|
|
@ -37,7 +38,7 @@ class BaseTopicPlugin(BasePlugin):
|
|||
|
||||
|
||||
class TopicTabooPlugin(BaseTopicPlugin):
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,3 +6,5 @@ default_retain: false
|
|||
auto_reconnect: true
|
||||
reconnect_max_interval: 10
|
||||
reconnect_retries: 2
|
||||
broker:
|
||||
uri: "mqtt://127.0.0.1"
|
||||
|
|
@ -1,10 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import string
|
||||
import typing
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
|
@ -48,3 +51,32 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
|
|||
except yaml.YAMLError:
|
||||
logger.exception(f"Invalid config_file {config_file}")
|
||||
return None
|
||||
|
||||
def cached_import(module_path: str, class_name: str=None) -> ModuleType:
|
||||
# Check whether module is loaded and fully initialized.
|
||||
if not ((module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None)) # noqa
|
||||
and getattr(spec, "_initializing", False) is False): # noqa
|
||||
module = import_module(module_path)
|
||||
if class_name:
|
||||
return getattr(module, class_name)
|
||||
return module
|
||||
|
||||
|
||||
# TODO : figure out proper return type
|
||||
def import_string(dotted_path) -> Any:
|
||||
"""
|
||||
Import a dotted module path and return the attribute/class designated by the
|
||||
last name in the path. Raise ImportError if the import failed.
|
||||
"""
|
||||
try:
|
||||
module_path, class_name = dotted_path.rsplit(".", 1)
|
||||
except ValueError as err:
|
||||
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
|
||||
|
||||
try:
|
||||
return cached_import(module_path, class_name)
|
||||
except AttributeError as err:
|
||||
raise ImportError(
|
||||
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
||||
) from err
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ dependencies = [
|
|||
"websockets==15.0.1", # https://pypi.org/project/websockets
|
||||
"passlib==1.7.4", # https://pypi.org/project/passlib
|
||||
"PyYAML==6.0.2", # https://pypi.org/project/PyYAML
|
||||
"typer==0.15.4"
|
||||
"typer==0.15.4",
|
||||
"dacite>=1.9.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
@ -184,14 +185,14 @@ max-returns = 10
|
|||
|
||||
# ----------------------------------- PYTEST -----------------------------------
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
||||
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
timeout = 10
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||
#log_cli = true
|
||||
#log_level = "DEBUG"
|
||||
addopts = ["--tb=short", "--capture=tee-sys"]
|
||||
log_cli = true
|
||||
log_level = "DEBUG"
|
||||
|
||||
# ------------------------------------ MYPY ------------------------------------
|
||||
[tool.mypy]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- test.plugins.plugins.TestSimplePlugin
|
||||
- test.plugins.plugins.TestConfigPlugin:
|
||||
option1: foo
|
||||
option2: bar
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestSimplePlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
|
||||
class TestConfigPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
option1: int
|
||||
option2: str
|
||||
|
||||
|
||||
class TestAuthPlugin(BaseAuthPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return False
|
||||
|
||||
|
||||
class TestTopicPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
return True
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from yaml import CLoader as Loader
|
||||
from dacite import from_dict, Config, UnexpectedDataError
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import PluginLoadError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
plugin_config = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestSimplePlugin:
|
||||
- tests.plugins.mocks.TestConfigPlugin:
|
||||
option1: 1
|
||||
option2: bar
|
||||
"""
|
||||
|
||||
|
||||
plugin_invalid_config_one = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestSimplePlugin:
|
||||
option1: 1
|
||||
option2: bar
|
||||
"""
|
||||
|
||||
plugin_invalid_config_two = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestConfigPlugin:
|
||||
"""
|
||||
|
||||
plugin_config_auth = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestAuthPlugin:
|
||||
"""
|
||||
|
||||
plugin_config_topic = """---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- tests.plugins.mocks.TestTopicPlugin:
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_config_extra_fields():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginLoadError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_config_missing_fields():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
|
||||
|
||||
with pytest.raises(PluginLoadError):
|
||||
_ = Broker(config=cfg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternate_plugin_load():
|
||||
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
|
||||
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
client1 = MQTTClient()
|
||||
await client1.connect()
|
||||
await client1.publish('my/topic', b'my message')
|
||||
await client1.disconnect()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_topic_plugin_load():
|
||||
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
await broker.shutdown()
|
||||
11
uv.lock
11
uv.lock
|
|
@ -12,6 +12,7 @@ name = "amqtt"
|
|||
version = "0.11.0rc1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "dacite" },
|
||||
{ name = "passlib" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "transitions" },
|
||||
|
|
@ -66,6 +67,7 @@ docs = [
|
|||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
|
||||
{ name = "dacite", specifier = ">=1.9.2" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "pyyaml", specifier = "==6.0.2" },
|
||||
{ name = "transitions", specifier = "==0.9.2" },
|
||||
|
|
@ -485,6 +487,15 @@ version = "0.9.5"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" }
|
||||
|
||||
[[package]]
|
||||
name = "dacite"
|
||||
version = "1.9.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.4.0"
|
||||
|
|
|
|||
Ładowanie…
Reference in New Issue