all test cases working. most fixes were to make the dataclass config's backwards compatible with how the rest of the code was accessing config dictionaries

config_dataclasses
Andrew Mirsky 2025-07-12 19:53:15 -04:00
rodzic 5c59248b4f
commit 81866d0238
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 103 dodań i 92 usunięć

Wyświetl plik

@ -37,10 +37,6 @@ from .plugins.manager import PluginManager
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
_default_broker = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
_defaults = dict_to_dataclass(BrokerConfig, _default_broker, config=DaciteConfig(cast=[StrEnum]))
# Default port numbers
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
@ -103,7 +99,7 @@ class BrokerContext(BaseContext):
def __init__(self, broker: "Broker") -> None:
super().__init__()
self.config: _CONFIG_LISTENER | None = None
self.config: BrokerConfig | None = None
self._broker_instance = broker
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
@ -158,20 +154,10 @@ class Broker:
"""Initialize the broker."""
self.logger = logging.getLogger(__name__)
self.config = dict_to_dataclass(BrokerConfig, config, config=DaciteConfig(cast=[StrEnum]))
self.config = BrokerConfig.from_dict(config)
self.config |= _defaults
# if config is not None:
# # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
# if ("auth" in config or "topic-check" in config) and "plugins" not in config:
# # set to None so that the config isn't updated with the new-style default plugin list
# config["plugins"] = None # type: ignore[assignment]
# self.config.update(config)
self._build_listeners_config(self.config)
# listeners are populated from default within BrokerConfig
self.listeners_config = self.config.listeners
self._loop = loop or asyncio.get_running_loop()
self._servers: dict[str, Server] = {}
@ -197,25 +183,6 @@ class Broker:
namespace = plugin_namespace or "amqtt.broker.plugins"
self.plugins_manager = PluginManager(namespace, context, self._loop)
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
self.listeners_config = {}
try:
listeners_config = broker_config.get("listeners")
if not isinstance(listeners_config, dict):
msg = "Listener config not found or invalid"
raise BrokerError(msg)
defaults = listeners_config.get("default")
if defaults is None:
msg = "Listener config has not default included or is invalid"
raise BrokerError(msg)
for listener_name, listener_conf in listeners_config.items():
listener_conf |= defaults
self.listeners_config[listener_name] = listener_conf
except KeyError as ke:
msg = f"Listener config not found or invalid: {ke}"
raise BrokerError(msg) from ke
def _init_states(self) -> None:
self.transitions = Machine(states=Broker.states, initial="new")
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)

Wyświetl plik

@ -33,9 +33,6 @@ from amqtt.utils import gen_client_id, read_yaml_config
if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
_default_client = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
_defaults = dict_to_dataclass(ClientConfig, _default_client, config=DaciteConfig(cast=[StrEnum]))
class ClientContext(BaseContext):
"""ClientContext is used as the context passed to plugins interacting with the client.
@ -99,9 +96,8 @@ class MQTTClient:
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
self.logger = logging.getLogger(__name__)
self.config = dict_to_dataclass(ClientConfig, config or {}, config=DaciteConfig(cast=[StrEnum]))
self.config = ClientConfig.from_dict(config)
self.config |= _defaults
self.client_id = client_id if client_id is not None else gen_client_id()
self.session: Session | None = None
@ -585,7 +581,17 @@ class MQTTClient:
) -> Session:
"""Initialize the MQTT session."""
broker_conf = self.config.get("broker", {}).copy()
broker_conf.update(ConnectionConfig(uri=uri, cafile=cafile, capath=capath, cadata=cadata))
if uri is not None:
broker_conf.uri = uri
if cleansession is not None:
self.config.cleansession = cleansession
if cafile is not None:
broker_conf.cafile = cafile
if capath is not None:
broker_conf.capath = capath
if cadata is not None:
broker_conf.cadata = cadata
if not broker_conf.get("uri"):
msg = "Missing connection parameter 'uri'"
@ -594,15 +600,12 @@ class MQTTClient:
session = Session()
session.broker_uri = broker_conf["uri"]
session.client_id = self.client_id
session.cafile = broker_conf.get("cafile")
session.capath = broker_conf.get("capath")
session.cadata = broker_conf.get("cadata")
if cleansession is not None:
broker_conf["cleansession"] = cleansession # noop?
session.clean_session = cleansession
else:
session.clean_session = self.config.get("cleansession", True)
session.clean_session = self.config.get("cleansession", True)
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]

Wyświetl plik

@ -12,6 +12,10 @@ _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
import asyncio
logger = logging.getLogger(__name__)
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
class BaseContext:
def __init__(self) -> None:
@ -47,15 +51,18 @@ class Dictable:
raise ValueError(f"'{name}' is not defined")
def __contains__(self, name):
return getattr(self, name, None) is not None
return getattr(self, name.replace('-', '_'), None) is not None
def __iter__(self):
for field in fields(self):
yield getattr(self, field.name)
def copy(self):
return replace(self)
@dataclass
class ListenerConfig:
class ListenerConfig(Dictable):
"""Structured configuration for a broker's listeners."""
type: ListenerType = ListenerType.TCP
@ -132,19 +139,17 @@ class BrokerConfig(Dictable):
"""Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](#taboo-topic-plugin) and
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
plugins: dict | list[dict] | None = field(default_factory=default_broker_plugins)
plugins: dict | list | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[Plugins](http://localhost:8000/custom_plugins/) for more information."""
def __post__init__(self) -> None:
def __post_init__(self) -> None:
if self.sys_interval is not None:
warnings.warn("sys_interval is deprecated, use 'plugins' to define configuration",
DeprecationWarning, stacklevel=1)
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
if self.auth is not None or self.topic_check is not None:
warnings.warn("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration",
DeprecationWarning, stacklevel=1)
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
default_listener = self.listeners['default']
for listener_name, listener in self.listeners.items():
@ -155,13 +160,35 @@ class BrokerConfig(Dictable):
if isinstance(self.plugins, list):
_plugins = {}
for plugin in self.plugins:
if isinstance(plugin, str):
_plugins |= {plugin:{}}
continue
_plugins |= plugin
self.plugins = _plugins
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig':
if d is None:
return BrokerConfig()
if 'topic-check' in d:
d['topic_check'] = d['topic-check']
del d['topic-check']
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d:
d['plugins'] = None
return dict_to_dataclass(data_class=BrokerConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum],
strict=True)
)
@dataclass
class ConnectionConfig(Dictable):
uri: str | None = "mqtt://127.0.0.1"
uri: str | None = "mqtt://127.0.0.1:1883"
"""URI of the broker"""
cafile: str | Path | None = None
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
@ -188,7 +215,7 @@ class ConnectionConfig(Dictable):
setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig:
class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published. The topic name is
specified as the key in the dictionary of the `ClientConfig.topics."""
qos: int = 0
@ -203,7 +230,7 @@ class TopicConfig:
@dataclass
class WillConfig:
class WillConfig(Dictable):
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
topic: str
@ -267,3 +294,16 @@ class ClientConfig(Dictable):
msg = "Client config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig':
if d is None:
return ClientConfig()
return dict_to_dataclass(data_class=ClientConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum],
strict=True)
)

Wyświetl plik

@ -1,7 +1,7 @@
from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar, cast
from amqtt.contexts import Action, BaseContext
from amqtt.contexts import Action, BaseContext, BrokerConfig
from amqtt.session import Session
C = TypeVar("C", bound=BaseContext)
@ -79,7 +79,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
if not self.context.config:
return default
if is_dataclass(self.context.config):
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.topic_config and option_name in self.topic_config:
@ -110,7 +110,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
if not self.context.config:
return default
if is_dataclass(self.context.config):
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.auth_config and option_name in self.auth_config:

Wyświetl plik

@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
# plugins loaded directly from config dictionary
if "auth" in self.app_context.config:
if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config:
if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])

Wyświetl plik

@ -58,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
req_topic = topic
if not req_topic:
return False\
return False
username = session.username if session else None
if username is None:

Wyświetl plik

@ -236,7 +236,7 @@ async def test_client_connect_clean_session_false(broker):
client = MQTTClient(client_id="", config={"auto_reconnect": False})
return_code = None
try:
await client.connect("mqtt://127.0.0.1/", cleansession=False)
await client.connect("mqtt://127.0.0.1", cleansession=False)
except ConnectError as ce:
return_code = ce.return_code
assert return_code == 0x02
@ -431,11 +431,11 @@ async def test_client_publish_acl_forbidden(acl_broker):
@pytest.mark.asyncio
async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
sub_client1 = MQTTClient()
sub_client1 = MQTTClient(client_id="sub_client1")
ret_conn = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/")
assert ret_conn == 0
sub_client2 = MQTTClient()
sub_client2 = MQTTClient(client_id="sub_client2")
ret_conn = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/")
assert ret_conn == 0
@ -445,7 +445,7 @@ async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
ret_sub = await sub_client2.subscribe([("public/subtopic/test", QOS_0)])
assert ret_sub == [128]
pub_client = MQTTClient()
pub_client = MQTTClient(client_id="pub_client")
ret_conn = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
assert ret_conn == 0

Wyświetl plik

@ -1,39 +1,40 @@
import logging
from typing import Any
import pytest
from dataclasses import dataclass, field
from pathlib import Path
from yaml import CLoader as Loader
import yaml
from dacite import from_dict, Config, UnexpectedDataError
from enum import StrEnum
from amqtt.broker import BrokerContext
from amqtt.contexts import BrokerConfig, ListenerConfig, Dictable
logger = logging.getLogger(__name__)
def _test_broker_config():
# Parse with dacite
config = from_dict(data_class=BrokerConfig, data=data, config=Config(cast=[StrEnum]))
assert isinstance(config, BrokerConfig)
assert isinstance(config.listeners['default'], ListenerConfig)
assert isinstance(config.listeners['secure'], ListenerConfig)
default = config.listeners['default']
secure = config.listeners['secure']
secure_one = secure.copy()
secure_two = secure.copy()
secure_one |= default
assert secure_one.max_connections == 50
assert secure_one.bind == '0.0.0.0:8883'
assert secure_one.cafile == Path('ca.key')
secure_two.update(default)
assert secure_two.max_connections == 50
assert secure_two.bind == '0.0.0.0:8883'
assert secure_two.cafile == Path('ca.key')
def test_entrypoint_broker_config(caplog):
test_cfg: dict[str, Any] = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'auth': {
'allow_anonymous': True
}
}
if 'plugins' not in test_cfg:
test_cfg['plugins'] = None
# cfg: dict[str, Any] = yaml.load(config, Loader=Loader)
broker_config = from_dict(data_class=BrokerConfig, data=test_cfg, config=Config(cast=[StrEnum]))
assert isinstance(broker_config, BrokerConfig)
assert broker_config.plugins is None