kopia lustrzana https://github.com/Yakifo/amqtt
all test cases working. most fixes were to make the dataclass config's backwards compatible with how the rest of the code was accessing config dictionaries
rodzic
5c59248b4f
commit
81866d0238
|
@ -37,10 +37,6 @@ from .plugins.manager import PluginManager
|
||||||
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
||||||
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
||||||
|
|
||||||
|
|
||||||
_default_broker = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
|
|
||||||
_defaults = dict_to_dataclass(BrokerConfig, _default_broker, config=DaciteConfig(cast=[StrEnum]))
|
|
||||||
|
|
||||||
# Default port numbers
|
# Default port numbers
|
||||||
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
||||||
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
||||||
|
@ -103,7 +99,7 @@ class BrokerContext(BaseContext):
|
||||||
|
|
||||||
def __init__(self, broker: "Broker") -> None:
|
def __init__(self, broker: "Broker") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config: _CONFIG_LISTENER | None = None
|
self.config: BrokerConfig | None = None
|
||||||
self._broker_instance = broker
|
self._broker_instance = broker
|
||||||
|
|
||||||
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
||||||
|
@ -158,20 +154,10 @@ class Broker:
|
||||||
"""Initialize the broker."""
|
"""Initialize the broker."""
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
self.config = dict_to_dataclass(BrokerConfig, config, config=DaciteConfig(cast=[StrEnum]))
|
self.config = BrokerConfig.from_dict(config)
|
||||||
|
|
||||||
self.config |= _defaults
|
# listeners are populated from default within BrokerConfig
|
||||||
|
self.listeners_config = self.config.listeners
|
||||||
# if config is not None:
|
|
||||||
# # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
|
|
||||||
# if ("auth" in config or "topic-check" in config) and "plugins" not in config:
|
|
||||||
# # set to None so that the config isn't updated with the new-style default plugin list
|
|
||||||
# config["plugins"] = None # type: ignore[assignment]
|
|
||||||
# self.config.update(config)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self._build_listeners_config(self.config)
|
|
||||||
|
|
||||||
self._loop = loop or asyncio.get_running_loop()
|
self._loop = loop or asyncio.get_running_loop()
|
||||||
self._servers: dict[str, Server] = {}
|
self._servers: dict[str, Server] = {}
|
||||||
|
@ -197,25 +183,6 @@ class Broker:
|
||||||
namespace = plugin_namespace or "amqtt.broker.plugins"
|
namespace = plugin_namespace or "amqtt.broker.plugins"
|
||||||
self.plugins_manager = PluginManager(namespace, context, self._loop)
|
self.plugins_manager = PluginManager(namespace, context, self._loop)
|
||||||
|
|
||||||
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
|
|
||||||
self.listeners_config = {}
|
|
||||||
try:
|
|
||||||
listeners_config = broker_config.get("listeners")
|
|
||||||
if not isinstance(listeners_config, dict):
|
|
||||||
msg = "Listener config not found or invalid"
|
|
||||||
raise BrokerError(msg)
|
|
||||||
defaults = listeners_config.get("default")
|
|
||||||
if defaults is None:
|
|
||||||
msg = "Listener config has not default included or is invalid"
|
|
||||||
raise BrokerError(msg)
|
|
||||||
|
|
||||||
for listener_name, listener_conf in listeners_config.items():
|
|
||||||
listener_conf |= defaults
|
|
||||||
self.listeners_config[listener_name] = listener_conf
|
|
||||||
except KeyError as ke:
|
|
||||||
msg = f"Listener config not found or invalid: {ke}"
|
|
||||||
raise BrokerError(msg) from ke
|
|
||||||
|
|
||||||
def _init_states(self) -> None:
|
def _init_states(self) -> None:
|
||||||
self.transitions = Machine(states=Broker.states, initial="new")
|
self.transitions = Machine(states=Broker.states, initial="new")
|
||||||
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
|
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
|
||||||
|
|
|
@ -33,9 +33,6 @@ from amqtt.utils import gen_client_id, read_yaml_config
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from websockets.asyncio.client import ClientConnection
|
from websockets.asyncio.client import ClientConnection
|
||||||
|
|
||||||
_default_client = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
|
|
||||||
_defaults = dict_to_dataclass(ClientConfig, _default_client, config=DaciteConfig(cast=[StrEnum]))
|
|
||||||
|
|
||||||
|
|
||||||
class ClientContext(BaseContext):
|
class ClientContext(BaseContext):
|
||||||
"""ClientContext is used as the context passed to plugins interacting with the client.
|
"""ClientContext is used as the context passed to plugins interacting with the client.
|
||||||
|
@ -99,9 +96,8 @@ class MQTTClient:
|
||||||
|
|
||||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.config = dict_to_dataclass(ClientConfig, config or {}, config=DaciteConfig(cast=[StrEnum]))
|
self.config = ClientConfig.from_dict(config)
|
||||||
|
|
||||||
self.config |= _defaults
|
|
||||||
self.client_id = client_id if client_id is not None else gen_client_id()
|
self.client_id = client_id if client_id is not None else gen_client_id()
|
||||||
|
|
||||||
self.session: Session | None = None
|
self.session: Session | None = None
|
||||||
|
@ -585,7 +581,17 @@ class MQTTClient:
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Initialize the MQTT session."""
|
"""Initialize the MQTT session."""
|
||||||
broker_conf = self.config.get("broker", {}).copy()
|
broker_conf = self.config.get("broker", {}).copy()
|
||||||
broker_conf.update(ConnectionConfig(uri=uri, cafile=cafile, capath=capath, cadata=cadata))
|
|
||||||
|
if uri is not None:
|
||||||
|
broker_conf.uri = uri
|
||||||
|
if cleansession is not None:
|
||||||
|
self.config.cleansession = cleansession
|
||||||
|
if cafile is not None:
|
||||||
|
broker_conf.cafile = cafile
|
||||||
|
if capath is not None:
|
||||||
|
broker_conf.capath = capath
|
||||||
|
if cadata is not None:
|
||||||
|
broker_conf.cadata = cadata
|
||||||
|
|
||||||
if not broker_conf.get("uri"):
|
if not broker_conf.get("uri"):
|
||||||
msg = "Missing connection parameter 'uri'"
|
msg = "Missing connection parameter 'uri'"
|
||||||
|
@ -594,14 +600,11 @@ class MQTTClient:
|
||||||
session = Session()
|
session = Session()
|
||||||
session.broker_uri = broker_conf["uri"]
|
session.broker_uri = broker_conf["uri"]
|
||||||
session.client_id = self.client_id
|
session.client_id = self.client_id
|
||||||
|
|
||||||
session.cafile = broker_conf.get("cafile")
|
session.cafile = broker_conf.get("cafile")
|
||||||
session.capath = broker_conf.get("capath")
|
session.capath = broker_conf.get("capath")
|
||||||
session.cadata = broker_conf.get("cadata")
|
session.cadata = broker_conf.get("cadata")
|
||||||
|
|
||||||
if cleansession is not None:
|
|
||||||
broker_conf["cleansession"] = cleansession # noop?
|
|
||||||
session.clean_session = cleansession
|
|
||||||
else:
|
|
||||||
session.clean_session = self.config.get("cleansession", True)
|
session.clean_session = self.config.get("cleansession", True)
|
||||||
|
|
||||||
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]
|
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]
|
||||||
|
|
|
@ -12,6 +12,10 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
|
||||||
|
|
||||||
|
|
||||||
class BaseContext:
|
class BaseContext:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -47,15 +51,18 @@ class Dictable:
|
||||||
raise ValueError(f"'{name}' is not defined")
|
raise ValueError(f"'{name}' is not defined")
|
||||||
|
|
||||||
def __contains__(self, name):
|
def __contains__(self, name):
|
||||||
return getattr(self, name, None) is not None
|
return getattr(self, name.replace('-', '_'), None) is not None
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for field in fields(self):
|
for field in fields(self):
|
||||||
yield getattr(self, field.name)
|
yield getattr(self, field.name)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return replace(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ListenerConfig:
|
class ListenerConfig(Dictable):
|
||||||
"""Structured configuration for a broker's listeners."""
|
"""Structured configuration for a broker's listeners."""
|
||||||
|
|
||||||
type: ListenerType = ListenerType.TCP
|
type: ListenerType = ListenerType.TCP
|
||||||
|
@ -132,19 +139,17 @@ class BrokerConfig(Dictable):
|
||||||
"""Deprecated field used to config EntryPoint-loaded plugins. See
|
"""Deprecated field used to config EntryPoint-loaded plugins. See
|
||||||
[`TopicTabooPlugin`](#taboo-topic-plugin) and
|
[`TopicTabooPlugin`](#taboo-topic-plugin) and
|
||||||
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
|
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
|
||||||
plugins: dict | list[dict] | None = field(default_factory=default_broker_plugins)
|
plugins: dict | list | None = field(default_factory=default_broker_plugins)
|
||||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
|
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
|
||||||
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
||||||
[Plugins](http://localhost:8000/custom_plugins/) for more information."""
|
[Plugins](http://localhost:8000/custom_plugins/) for more information."""
|
||||||
|
|
||||||
def __post__init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.sys_interval is not None:
|
if self.sys_interval is not None:
|
||||||
warnings.warn("sys_interval is deprecated, use 'plugins' to define configuration",
|
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
|
||||||
DeprecationWarning, stacklevel=1)
|
|
||||||
|
|
||||||
if self.auth is not None or self.topic_check is not None:
|
if self.auth is not None or self.topic_check is not None:
|
||||||
warnings.warn("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration",
|
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
|
||||||
DeprecationWarning, stacklevel=1)
|
|
||||||
|
|
||||||
default_listener = self.listeners['default']
|
default_listener = self.listeners['default']
|
||||||
for listener_name, listener in self.listeners.items():
|
for listener_name, listener in self.listeners.items():
|
||||||
|
@ -155,13 +160,35 @@ class BrokerConfig(Dictable):
|
||||||
if isinstance(self.plugins, list):
|
if isinstance(self.plugins, list):
|
||||||
_plugins = {}
|
_plugins = {}
|
||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
|
if isinstance(plugin, str):
|
||||||
|
_plugins |= {plugin:{}}
|
||||||
|
continue
|
||||||
_plugins |= plugin
|
_plugins |= plugin
|
||||||
self.plugins = _plugins
|
self.plugins = _plugins
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig':
|
||||||
|
if d is None:
|
||||||
|
return BrokerConfig()
|
||||||
|
|
||||||
|
if 'topic-check' in d:
|
||||||
|
d['topic_check'] = d['topic-check']
|
||||||
|
del d['topic-check']
|
||||||
|
|
||||||
|
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d:
|
||||||
|
d['plugins'] = None
|
||||||
|
|
||||||
|
return dict_to_dataclass(data_class=BrokerConfig,
|
||||||
|
data=d,
|
||||||
|
config=DaciteConfig(
|
||||||
|
cast=[StrEnum],
|
||||||
|
strict=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionConfig(Dictable):
|
class ConnectionConfig(Dictable):
|
||||||
uri: str | None = "mqtt://127.0.0.1"
|
uri: str | None = "mqtt://127.0.0.1:1883"
|
||||||
"""URI of the broker"""
|
"""URI of the broker"""
|
||||||
cafile: str | Path | None = None
|
cafile: str | Path | None = None
|
||||||
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
|
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
|
||||||
|
@ -188,7 +215,7 @@ class ConnectionConfig(Dictable):
|
||||||
setattr(self, fn, Path(getattr(self, fn)))
|
setattr(self, fn, Path(getattr(self, fn)))
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TopicConfig:
|
class TopicConfig(Dictable):
|
||||||
"""Configuration of how messages to specific topics are published. The topic name is
|
"""Configuration of how messages to specific topics are published. The topic name is
|
||||||
specified as the key in the dictionary of the `ClientConfig.topics."""
|
specified as the key in the dictionary of the `ClientConfig.topics."""
|
||||||
qos: int = 0
|
qos: int = 0
|
||||||
|
@ -203,7 +230,7 @@ class TopicConfig:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WillConfig:
|
class WillConfig(Dictable):
|
||||||
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
|
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
|
||||||
|
|
||||||
topic: str
|
topic: str
|
||||||
|
@ -267,3 +294,16 @@ class ClientConfig(Dictable):
|
||||||
msg = "Client config: default QoS must be 0, 1 or 2."
|
msg = "Client config: default QoS must be 0, 1 or 2."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig':
|
||||||
|
if d is None:
|
||||||
|
return ClientConfig()
|
||||||
|
|
||||||
|
return dict_to_dataclass(data_class=ClientConfig,
|
||||||
|
data=d,
|
||||||
|
config=DaciteConfig(
|
||||||
|
cast=[StrEnum],
|
||||||
|
strict=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from dataclasses import dataclass, is_dataclass
|
from dataclasses import dataclass, is_dataclass
|
||||||
from typing import Any, Generic, TypeVar, cast
|
from typing import Any, Generic, TypeVar, cast
|
||||||
|
|
||||||
from amqtt.contexts import Action, BaseContext
|
from amqtt.contexts import Action, BaseContext, BrokerConfig
|
||||||
from amqtt.session import Session
|
from amqtt.session import Session
|
||||||
|
|
||||||
C = TypeVar("C", bound=BaseContext)
|
C = TypeVar("C", bound=BaseContext)
|
||||||
|
@ -79,7 +79,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
||||||
if not self.context.config:
|
if not self.context.config:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
if is_dataclass(self.context.config):
|
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||||
if self.topic_config and option_name in self.topic_config:
|
if self.topic_config and option_name in self.topic_config:
|
||||||
|
@ -110,7 +110,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||||
if not self.context.config:
|
if not self.context.config:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
if is_dataclass(self.context.config):
|
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||||
if self.auth_config and option_name in self.auth_config:
|
if self.auth_config and option_name in self.auth_config:
|
||||||
|
|
|
@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
|
||||||
# plugins loaded directly from config dictionary
|
# plugins loaded directly from config dictionary
|
||||||
|
|
||||||
|
|
||||||
if "auth" in self.app_context.config:
|
if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||||
if "topic-check" in self.app_context.config:
|
if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||||
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
|
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
|
||||||
|
|
||||||
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
|
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
|
||||||
|
|
|
@ -58,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
||||||
|
|
||||||
req_topic = topic
|
req_topic = topic
|
||||||
if not req_topic:
|
if not req_topic:
|
||||||
return False\
|
return False
|
||||||
|
|
||||||
username = session.username if session else None
|
username = session.username if session else None
|
||||||
if username is None:
|
if username is None:
|
||||||
|
|
|
@ -236,7 +236,7 @@ async def test_client_connect_clean_session_false(broker):
|
||||||
client = MQTTClient(client_id="", config={"auto_reconnect": False})
|
client = MQTTClient(client_id="", config={"auto_reconnect": False})
|
||||||
return_code = None
|
return_code = None
|
||||||
try:
|
try:
|
||||||
await client.connect("mqtt://127.0.0.1/", cleansession=False)
|
await client.connect("mqtt://127.0.0.1", cleansession=False)
|
||||||
except ConnectError as ce:
|
except ConnectError as ce:
|
||||||
return_code = ce.return_code
|
return_code = ce.return_code
|
||||||
assert return_code == 0x02
|
assert return_code == 0x02
|
||||||
|
@ -431,11 +431,11 @@ async def test_client_publish_acl_forbidden(acl_broker):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
|
async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
|
||||||
sub_client1 = MQTTClient()
|
sub_client1 = MQTTClient(client_id="sub_client1")
|
||||||
ret_conn = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/")
|
ret_conn = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/")
|
||||||
assert ret_conn == 0
|
assert ret_conn == 0
|
||||||
|
|
||||||
sub_client2 = MQTTClient()
|
sub_client2 = MQTTClient(client_id="sub_client2")
|
||||||
ret_conn = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/")
|
ret_conn = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/")
|
||||||
assert ret_conn == 0
|
assert ret_conn == 0
|
||||||
|
|
||||||
|
@ -445,7 +445,7 @@ async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
|
||||||
ret_sub = await sub_client2.subscribe([("public/subtopic/test", QOS_0)])
|
ret_sub = await sub_client2.subscribe([("public/subtopic/test", QOS_0)])
|
||||||
assert ret_sub == [128]
|
assert ret_sub == [128]
|
||||||
|
|
||||||
pub_client = MQTTClient()
|
pub_client = MQTTClient(client_id="pub_client")
|
||||||
ret_conn = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
|
ret_conn = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
|
||||||
assert ret_conn == 0
|
assert ret_conn == 0
|
||||||
|
|
||||||
|
|
|
@ -1,39 +1,40 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from yaml import CLoader as Loader
|
||||||
|
|
||||||
|
|
||||||
|
import yaml
|
||||||
from dacite import from_dict, Config, UnexpectedDataError
|
from dacite import from_dict, Config, UnexpectedDataError
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from amqtt.broker import BrokerContext
|
||||||
from amqtt.contexts import BrokerConfig, ListenerConfig, Dictable
|
from amqtt.contexts import BrokerConfig, ListenerConfig, Dictable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _test_broker_config():
|
def test_entrypoint_broker_config(caplog):
|
||||||
# Parse with dacite
|
test_cfg: dict[str, Any] = {
|
||||||
config = from_dict(data_class=BrokerConfig, data=data, config=Config(cast=[StrEnum]))
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
assert isinstance(config, BrokerConfig)
|
},
|
||||||
assert isinstance(config.listeners['default'], ListenerConfig)
|
'sys_interval': 1,
|
||||||
assert isinstance(config.listeners['secure'], ListenerConfig)
|
'auth': {
|
||||||
|
'allow_anonymous': True
|
||||||
default = config.listeners['default']
|
}
|
||||||
secure = config.listeners['secure']
|
}
|
||||||
secure_one = secure.copy()
|
if 'plugins' not in test_cfg:
|
||||||
secure_two = secure.copy()
|
test_cfg['plugins'] = None
|
||||||
|
# cfg: dict[str, Any] = yaml.load(config, Loader=Loader)
|
||||||
secure_one |= default
|
|
||||||
|
|
||||||
assert secure_one.max_connections == 50
|
|
||||||
assert secure_one.bind == '0.0.0.0:8883'
|
|
||||||
assert secure_one.cafile == Path('ca.key')
|
|
||||||
|
|
||||||
secure_two.update(default)
|
|
||||||
assert secure_two.max_connections == 50
|
|
||||||
assert secure_two.bind == '0.0.0.0:8883'
|
|
||||||
assert secure_two.cafile == Path('ca.key')
|
|
||||||
|
|
||||||
|
|
||||||
|
broker_config = from_dict(data_class=BrokerConfig, data=test_cfg, config=Config(cast=[StrEnum]))
|
||||||
|
assert isinstance(broker_config, BrokerConfig)
|
||||||
|
|
||||||
|
assert broker_config.plugins is None
|
||||||
|
|
||||||
|
|
||||||
|
|
Ładowanie…
Reference in New Issue