all test cases working. most fixes were to make the dataclass config's backwards compatible with how the rest of the code was accessing config dictionaries

config_dataclasses
Andrew Mirsky 2025-07-12 19:53:15 -04:00
rodzic 5c59248b4f
commit 81866d0238
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 103 dodań i 92 usunięć

Wyświetl plik

@ -37,10 +37,6 @@ from .plugins.manager import PluginManager
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]] _CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
_default_broker = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
_defaults = dict_to_dataclass(BrokerConfig, _default_broker, config=DaciteConfig(cast=[StrEnum]))
# Default port numbers # Default port numbers
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
@ -103,7 +99,7 @@ class BrokerContext(BaseContext):
def __init__(self, broker: "Broker") -> None: def __init__(self, broker: "Broker") -> None:
super().__init__() super().__init__()
self.config: _CONFIG_LISTENER | None = None self.config: BrokerConfig | None = None
self._broker_instance = broker self._broker_instance = broker
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None: async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
@ -158,20 +154,10 @@ class Broker:
"""Initialize the broker.""" """Initialize the broker."""
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.config = dict_to_dataclass(BrokerConfig, config, config=DaciteConfig(cast=[StrEnum])) self.config = BrokerConfig.from_dict(config)
self.config |= _defaults # listeners are populated from default within BrokerConfig
self.listeners_config = self.config.listeners
# if config is not None:
# # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
# if ("auth" in config or "topic-check" in config) and "plugins" not in config:
# # set to None so that the config isn't updated with the new-style default plugin list
# config["plugins"] = None # type: ignore[assignment]
# self.config.update(config)
self._build_listeners_config(self.config)
self._loop = loop or asyncio.get_running_loop() self._loop = loop or asyncio.get_running_loop()
self._servers: dict[str, Server] = {} self._servers: dict[str, Server] = {}
@ -197,25 +183,6 @@ class Broker:
namespace = plugin_namespace or "amqtt.broker.plugins" namespace = plugin_namespace or "amqtt.broker.plugins"
self.plugins_manager = PluginManager(namespace, context, self._loop) self.plugins_manager = PluginManager(namespace, context, self._loop)
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
self.listeners_config = {}
try:
listeners_config = broker_config.get("listeners")
if not isinstance(listeners_config, dict):
msg = "Listener config not found or invalid"
raise BrokerError(msg)
defaults = listeners_config.get("default")
if defaults is None:
msg = "Listener config has not default included or is invalid"
raise BrokerError(msg)
for listener_name, listener_conf in listeners_config.items():
listener_conf |= defaults
self.listeners_config[listener_name] = listener_conf
except KeyError as ke:
msg = f"Listener config not found or invalid: {ke}"
raise BrokerError(msg) from ke
def _init_states(self) -> None: def _init_states(self) -> None:
self.transitions = Machine(states=Broker.states, initial="new") self.transitions = Machine(states=Broker.states, initial="new")
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change) self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)

Wyświetl plik

@ -33,9 +33,6 @@ from amqtt.utils import gen_client_id, read_yaml_config
if TYPE_CHECKING: if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection from websockets.asyncio.client import ClientConnection
_default_client = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
_defaults = dict_to_dataclass(ClientConfig, _default_client, config=DaciteConfig(cast=[StrEnum]))
class ClientContext(BaseContext): class ClientContext(BaseContext):
"""ClientContext is used as the context passed to plugins interacting with the client. """ClientContext is used as the context passed to plugins interacting with the client.
@ -99,9 +96,8 @@ class MQTTClient:
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None: def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.config = dict_to_dataclass(ClientConfig, config or {}, config=DaciteConfig(cast=[StrEnum])) self.config = ClientConfig.from_dict(config)
self.config |= _defaults
self.client_id = client_id if client_id is not None else gen_client_id() self.client_id = client_id if client_id is not None else gen_client_id()
self.session: Session | None = None self.session: Session | None = None
@ -585,7 +581,17 @@ class MQTTClient:
) -> Session: ) -> Session:
"""Initialize the MQTT session.""" """Initialize the MQTT session."""
broker_conf = self.config.get("broker", {}).copy() broker_conf = self.config.get("broker", {}).copy()
broker_conf.update(ConnectionConfig(uri=uri, cafile=cafile, capath=capath, cadata=cadata))
if uri is not None:
broker_conf.uri = uri
if cleansession is not None:
self.config.cleansession = cleansession
if cafile is not None:
broker_conf.cafile = cafile
if capath is not None:
broker_conf.capath = capath
if cadata is not None:
broker_conf.cadata = cadata
if not broker_conf.get("uri"): if not broker_conf.get("uri"):
msg = "Missing connection parameter 'uri'" msg = "Missing connection parameter 'uri'"
@ -594,14 +600,11 @@ class MQTTClient:
session = Session() session = Session()
session.broker_uri = broker_conf["uri"] session.broker_uri = broker_conf["uri"]
session.client_id = self.client_id session.client_id = self.client_id
session.cafile = broker_conf.get("cafile") session.cafile = broker_conf.get("cafile")
session.capath = broker_conf.get("capath") session.capath = broker_conf.get("capath")
session.cadata = broker_conf.get("cadata") session.cadata = broker_conf.get("cadata")
if cleansession is not None:
broker_conf["cleansession"] = cleansession # noop?
session.clean_session = cleansession
else:
session.clean_session = self.config.get("cleansession", True) session.clean_session = self.config.get("cleansession", True)
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"] session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]

Wyświetl plik

@ -12,6 +12,10 @@ _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio import asyncio
logger = logging.getLogger(__name__)
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
class BaseContext: class BaseContext:
def __init__(self) -> None: def __init__(self) -> None:
@ -47,15 +51,18 @@ class Dictable:
raise ValueError(f"'{name}' is not defined") raise ValueError(f"'{name}' is not defined")
def __contains__(self, name): def __contains__(self, name):
return getattr(self, name, None) is not None return getattr(self, name.replace('-', '_'), None) is not None
def __iter__(self): def __iter__(self):
for field in fields(self): for field in fields(self):
yield getattr(self, field.name) yield getattr(self, field.name)
def copy(self):
return replace(self)
@dataclass @dataclass
class ListenerConfig: class ListenerConfig(Dictable):
"""Structured configuration for a broker's listeners.""" """Structured configuration for a broker's listeners."""
type: ListenerType = ListenerType.TCP type: ListenerType = ListenerType.TCP
@ -132,19 +139,17 @@ class BrokerConfig(Dictable):
"""Deprecated field used to config EntryPoint-loaded plugins. See """Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](#taboo-topic-plugin) and [`TopicTabooPlugin`](#taboo-topic-plugin) and
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*""" [`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
plugins: dict | list[dict] | None = field(default_factory=default_broker_plugins) plugins: dict | list | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin` """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[Plugins](http://localhost:8000/custom_plugins/) for more information.""" [Plugins](http://localhost:8000/custom_plugins/) for more information."""
def __post__init__(self) -> None: def __post_init__(self) -> None:
if self.sys_interval is not None: if self.sys_interval is not None:
warnings.warn("sys_interval is deprecated, use 'plugins' to define configuration", logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
DeprecationWarning, stacklevel=1)
if self.auth is not None or self.topic_check is not None: if self.auth is not None or self.topic_check is not None:
warnings.warn("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration", logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
DeprecationWarning, stacklevel=1)
default_listener = self.listeners['default'] default_listener = self.listeners['default']
for listener_name, listener in self.listeners.items(): for listener_name, listener in self.listeners.items():
@ -155,13 +160,35 @@ class BrokerConfig(Dictable):
if isinstance(self.plugins, list): if isinstance(self.plugins, list):
_plugins = {} _plugins = {}
for plugin in self.plugins: for plugin in self.plugins:
if isinstance(plugin, str):
_plugins |= {plugin:{}}
continue
_plugins |= plugin _plugins |= plugin
self.plugins = _plugins self.plugins = _plugins
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig':
if d is None:
return BrokerConfig()
if 'topic-check' in d:
d['topic_check'] = d['topic-check']
del d['topic-check']
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d:
d['plugins'] = None
return dict_to_dataclass(data_class=BrokerConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum],
strict=True)
)
@dataclass @dataclass
class ConnectionConfig(Dictable): class ConnectionConfig(Dictable):
uri: str | None = "mqtt://127.0.0.1" uri: str | None = "mqtt://127.0.0.1:1883"
"""URI of the broker""" """URI of the broker"""
cafile: str | Path | None = None cafile: str | Path | None = None
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See """Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
@ -188,7 +215,7 @@ class ConnectionConfig(Dictable):
setattr(self, fn, Path(getattr(self, fn))) setattr(self, fn, Path(getattr(self, fn)))
@dataclass @dataclass
class TopicConfig: class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published. The topic name is """Configuration of how messages to specific topics are published. The topic name is
specified as the key in the dictionary of the `ClientConfig.topics.""" specified as the key in the dictionary of the `ClientConfig.topics."""
qos: int = 0 qos: int = 0
@ -203,7 +230,7 @@ class TopicConfig:
@dataclass @dataclass
class WillConfig: class WillConfig(Dictable):
"""Configuration of the 'last will & testament' of the client upon improper disconnection.""" """Configuration of the 'last will & testament' of the client upon improper disconnection."""
topic: str topic: str
@ -267,3 +294,16 @@ class ClientConfig(Dictable):
msg = "Client config: default QoS must be 0, 1 or 2." msg = "Client config: default QoS must be 0, 1 or 2."
raise ValueError(msg) raise ValueError(msg)
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig':
if d is None:
return ClientConfig()
return dict_to_dataclass(data_class=ClientConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum],
strict=True)
)

Wyświetl plik

@ -1,7 +1,7 @@
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar, cast from typing import Any, Generic, TypeVar, cast
from amqtt.contexts import Action, BaseContext from amqtt.contexts import Action, BaseContext, BrokerConfig
from amqtt.session import Session from amqtt.session import Session
C = TypeVar("C", bound=BaseContext) C = TypeVar("C", bound=BaseContext)
@ -79,7 +79,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
if not self.context.config: if not self.context.config:
return default return default
if is_dataclass(self.context.config): if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.topic_config and option_name in self.topic_config: if self.topic_config and option_name in self.topic_config:
@ -110,7 +110,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
if not self.context.config: if not self.context.config:
return default return default
if is_dataclass(self.context.config): if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.auth_config and option_name in self.auth_config: if self.auth_config and option_name in self.auth_config:

Wyświetl plik

@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
# plugins loaded directly from config dictionary # plugins loaded directly from config dictionary
if "auth" in self.app_context.config: if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config") self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config: if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config") self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", []) plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])

Wyświetl plik

@ -58,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
req_topic = topic req_topic = topic
if not req_topic: if not req_topic:
return False\ return False
username = session.username if session else None username = session.username if session else None
if username is None: if username is None:

Wyświetl plik

@ -236,7 +236,7 @@ async def test_client_connect_clean_session_false(broker):
client = MQTTClient(client_id="", config={"auto_reconnect": False}) client = MQTTClient(client_id="", config={"auto_reconnect": False})
return_code = None return_code = None
try: try:
await client.connect("mqtt://127.0.0.1/", cleansession=False) await client.connect("mqtt://127.0.0.1", cleansession=False)
except ConnectError as ce: except ConnectError as ce:
return_code = ce.return_code return_code = ce.return_code
assert return_code == 0x02 assert return_code == 0x02
@ -431,11 +431,11 @@ async def test_client_publish_acl_forbidden(acl_broker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_publish_acl_permitted_sub_forbidden(acl_broker): async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
sub_client1 = MQTTClient() sub_client1 = MQTTClient(client_id="sub_client1")
ret_conn = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/") ret_conn = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/")
assert ret_conn == 0 assert ret_conn == 0
sub_client2 = MQTTClient() sub_client2 = MQTTClient(client_id="sub_client2")
ret_conn = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/") ret_conn = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/")
assert ret_conn == 0 assert ret_conn == 0
@ -445,7 +445,7 @@ async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
ret_sub = await sub_client2.subscribe([("public/subtopic/test", QOS_0)]) ret_sub = await sub_client2.subscribe([("public/subtopic/test", QOS_0)])
assert ret_sub == [128] assert ret_sub == [128]
pub_client = MQTTClient() pub_client = MQTTClient(client_id="pub_client")
ret_conn = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/") ret_conn = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
assert ret_conn == 0 assert ret_conn == 0

Wyświetl plik

@ -1,39 +1,40 @@
import logging import logging
from typing import Any
import pytest import pytest
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from yaml import CLoader as Loader
import yaml
from dacite import from_dict, Config, UnexpectedDataError from dacite import from_dict, Config, UnexpectedDataError
from enum import StrEnum from enum import StrEnum
from amqtt.broker import BrokerContext
from amqtt.contexts import BrokerConfig, ListenerConfig, Dictable from amqtt.contexts import BrokerConfig, ListenerConfig, Dictable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _test_broker_config(): def test_entrypoint_broker_config(caplog):
# Parse with dacite test_cfg: dict[str, Any] = {
config = from_dict(data_class=BrokerConfig, data=data, config=Config(cast=[StrEnum])) "listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
assert isinstance(config, BrokerConfig) },
assert isinstance(config.listeners['default'], ListenerConfig) 'sys_interval': 1,
assert isinstance(config.listeners['secure'], ListenerConfig) 'auth': {
'allow_anonymous': True
default = config.listeners['default'] }
secure = config.listeners['secure'] }
secure_one = secure.copy() if 'plugins' not in test_cfg:
secure_two = secure.copy() test_cfg['plugins'] = None
# cfg: dict[str, Any] = yaml.load(config, Loader=Loader)
secure_one |= default
assert secure_one.max_connections == 50
assert secure_one.bind == '0.0.0.0:8883'
assert secure_one.cafile == Path('ca.key')
secure_two.update(default)
assert secure_two.max_connections == 50
assert secure_two.bind == '0.0.0.0:8883'
assert secure_two.cafile == Path('ca.key')
broker_config = from_dict(data_class=BrokerConfig, data=test_cfg, config=Config(cast=[StrEnum]))
assert isinstance(broker_config, BrokerConfig)
assert broker_config.plugins is None