type correction and linting fixes

config_dataclasses
Andrew Mirsky 2025-07-13 00:04:24 -04:00
rodzic 81866d0238
commit 4a39124545
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
6 zmienionych plików z 130 dodań i 88 usunięć

Wyświetl plik

@ -2,14 +2,11 @@ import asyncio
from asyncio import CancelledError, futures
from collections import deque
from collections.abc import Generator
from enum import StrEnum
from functools import partial
import logging
from pathlib import Path
import re
import ssl
from typing import Any, ClassVar, TypeAlias
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig
from transitions import Machine, MachineError
import websockets.asyncio.server
@ -23,18 +20,17 @@ from amqtt.adapters import (
WebSocketsWriter,
WriterAdapter,
)
from amqtt.contexts import Action, BaseContext, BrokerConfig
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
from amqtt.utils import format_client_message, gen_client_id
from .events import BrokerEvents
from .mqtt.constants import QOS_0, QOS_1, QOS_2
from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import PluginManager
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
# Default port numbers
@ -147,14 +143,17 @@ class Broker:
def __init__(
self,
config: _CONFIG_LISTENER | None = None,
config: BrokerConfig | dict[str, Any] | None = None,
loop: asyncio.AbstractEventLoop | None = None,
plugin_namespace: str | None = None,
) -> None:
"""Initialize the broker."""
self.logger = logging.getLogger(__name__)
self.config = BrokerConfig.from_dict(config)
if isinstance(config, dict):
self.config = BrokerConfig.from_dict(config)
else:
self.config = config or BrokerConfig()
# listeners are populated from default within BrokerConfig
self.listeners_config = self.config.listeners
@ -247,7 +246,8 @@ class Broker:
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
@staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
"""Create an SSL context for a listener."""
try:
ssl_context = ssl.create_default_context(
@ -680,7 +680,7 @@ class Broker:
except Exception:
self.logger.exception("Failed to stop handler")
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
"""Call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True.

Wyświetl plik

@ -2,15 +2,11 @@ import asyncio
from collections import deque
from collections.abc import Callable, Coroutine
import contextlib
import copy
from enum import StrEnum
from functools import wraps
import logging
from pathlib import Path
import ssl
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from urllib.parse import urlparse, urlunparse
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig
import websockets
from websockets import HeadersLike, InvalidHandshake, InvalidURI
@ -21,14 +17,14 @@ from amqtt.adapters import (
WebSocketsReader,
WebSocketsWriter,
)
from amqtt.contexts import BaseContext, ClientConfig, ConnectionConfig
from amqtt.contexts import BaseContext, ClientConfig
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
from amqtt.mqtt.connack import CONNECTION_ACCEPTED
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from amqtt.plugins.manager import PluginManager
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import gen_client_id, read_yaml_config
from amqtt.utils import gen_client_id
if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
@ -42,7 +38,7 @@ class ClientContext(BaseContext):
def __init__(self) -> None:
super().__init__()
self.config = None
self.config: ClientConfig | None = None
base_logger = logging.getLogger(__name__)
@ -94,9 +90,13 @@ class MQTTClient:
"""
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
self.logger = logging.getLogger(__name__)
self.config = ClientConfig.from_dict(config)
if isinstance(config, dict):
self.config = ClientConfig.from_dict(config)
else:
self.config = config or ClientConfig()
self.client_id = client_id if client_id is not None else gen_client_id()

Wyświetl plik

@ -1,9 +1,19 @@
import warnings
from dataclasses import dataclass, field, fields, replace, asdict
from enum import Enum, StrEnum
from dataclasses import dataclass, field, fields, replace
try:
from enum import Enum, StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
pass
from collections.abc import Iterator
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, LiteralString, Literal
from typing import TYPE_CHECKING, Any, Literal
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
from amqtt.mqtt.constants import QOS_0, QOS_2
@ -14,14 +24,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
class BaseContext:
def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None
self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None
# cleanup with a `Generic` type
self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
class Action(Enum):
@ -32,33 +41,44 @@ class Action(Enum):
class ListenerType(StrEnum):
TCP = 'tcp'
WS = 'ws'
"""Types of mqtt listeners."""
TCP = "tcp"
WS = "ws"
def __repr__(self) -> str:
return f'"{str(self.value)}"'
"""Display the string value, instead of the enum member."""
return f'"{self.value!s}"'
class Dictable:
def __getitem__(self, key):
"""Add dictionary methods to a dataclass."""
def __getitem__(self, key:str) -> Any:
"""Allow dict-style `[]` access to a dataclass."""
return self.get(key)
def get(self, name, default=None):
def get(self, name:str, default:Any=None) -> Any:
"""Allow dict-style access to a dataclass."""
name = name.replace("-", "_")
if hasattr(self, name):
return getattr(self, name)
if default is not None:
return default
raise ValueError(f"'{name}' is not defined")
msg = f"'{name}' is not defined"
raise ValueError(msg)
def __contains__(self, name):
return getattr(self, name.replace('-', '_'), None) is not None
def __contains__(self, name: str) -> bool:
"""Provide dict-style 'in' check."""
return getattr(self, name.replace("-", "_"), None) is not None
def __iter__(self):
for field in fields(self):
yield getattr(self, field.name)
def __iter__(self) -> Iterator[Any]:
"""Provide dict-style iteration."""
for f in fields(self): # type: ignore[arg-type]
yield getattr(self, f.name)
def copy(self):
return replace(self)
def copy(self) -> dataclass: # type: ignore[valid-type]
"""Return a copy of the dataclass."""
return replace(self) # type: ignore[type-var]
@dataclass
@ -87,96 +107,102 @@ class ListenerConfig(Dictable):
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key."""
def __post_init__(self):
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg)
for fn in ('cafile', 'capath', 'certfile', 'keyfile'):
for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
def apply(self, other):
def apply(self, other: "ListenerConfig") -> None:
"""Apply the field from 'other', if 'self' field is default."""
if not isinstance(other, ListenerConfig):
msg = f'cannot apply {self.__class__.__name__} to {other.__class__.__name__}'
raise TypeError(msg)
for f in fields(self):
if getattr(self, f.name) == f.default:
setattr(self, f.name, other[f.name])
def default_listeners():
def default_listeners() -> dict[str, Any]:
"""Create defaults for BrokerConfig.listeners."""
return {
'default': ListenerConfig()
"default": ListenerConfig()
}
def default_broker_plugins():
def default_broker_plugins() -> dict[str, Any]:
"""Create defaults for BrokerConfig.plugins."""
return {
'amqtt.plugins.logging_amqtt.EventLoggerPlugin':{},
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{},
'amqtt.plugins.authentication.AnonymousAuthPlugin':{'allow_anonymous':True},
'amqtt.plugins.sys.broker.BrokerSysPlugin':{'sys_interval':20}
"amqtt.plugins.logging_amqtt.EventLoggerPlugin":{},
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{},
"amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True},
"amqtt.plugins.sys.broker.BrokerSysPlugin":{"sys_interval":20}
}
@dataclass
class BrokerConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
listeners: dict[Literal['default'] | str, ListenerConfig] = field(default_factory=default_listeners)
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
does not set a value, the 'default' settings are applied. See
[ListenerConfig](#amqtt.contexts.ListenerConfig) for more information."""
[ListenerConfig](./#amqtt.contexts.ListenerConfig) for more information."""
sys_interval: int | None = None
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../packaged_plugins.md/#sys-topics)
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../packaged_plugins.md/#sys-topics)
configuration instead.*"""
timeout_disconnect_delay: int | None = 0
"""Client disconnect timeout without a keep-alive."""
auth: dict[str, Any] | None = None
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`AnonymousAuthPlugin`](#anonymous-auth-plugin) and
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`AnonymousAuthPlugin`](./#anonymous-auth-plugin) and
[`FileAuthPlugin`](/packaged_plugins/#password-file-auth-plugin) for more information.*"""
topic_check: dict[str, Any] | None = None
"""Deprecated field used to config EntryPoint-loaded plugins. See
"""Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](#taboo-topic-plugin) and
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
plugins: dict | list | None = field(default_factory=default_broker_plugins)
plugins: dict[str, Any] | list[dict[str,Any]] | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[Plugins](http://localhost:8000/custom_plugins/) for more information."""
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.sys_interval is not None:
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
if self.auth is not None or self.topic_check is not None:
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
default_listener = self.listeners['default']
default_listener = self.listeners["default"]
for listener_name, listener in self.listeners.items():
if listener_name == 'default':
if listener_name == "default":
continue
listener.apply(default_listener)
if isinstance(self.plugins, list):
_plugins = {}
_plugins: dict[str, Any] = {}
for plugin in self.plugins:
if isinstance(plugin, str):
_plugins |= {plugin:{}}
# in case a plugin in a yaml file is listed without config
if isinstance(plugin, str): # type: ignore[unreachable]
_plugins |= {plugin:{}} # type: ignore[unreachable]
continue
_plugins |= plugin
self.plugins = _plugins
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig':
def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
"""Create a broker config from a dictionary."""
if d is None:
return BrokerConfig()
if 'topic-check' in d:
d['topic_check'] = d['topic-check']
del d['topic-check']
# patch the incoming dictionary so it can be loaded correctly
if "topic-check" in d:
d["topic_check"] = d["topic-check"]
del d["topic-check"]
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d:
d['plugins'] = None
# identify EntryPoint plugin loading and prevent 'plugins' from getting defaults
if ("auth" in d or "topic-check" in d) and "plugins" not in d:
d["plugins"] = None
return dict_to_dataclass(data_class=BrokerConfig,
data=d,
@ -188,6 +214,8 @@ class BrokerConfig(Dictable):
@dataclass
class ConnectionConfig(Dictable):
"""Properties for connecting to the broker."""
uri: str | None = "mqtt://127.0.0.1:1883"
"""URI of the broker"""
cafile: str | Path | None = None
@ -206,24 +234,29 @@ class ConnectionConfig(Dictable):
"""Full path to file in PEM format containing the client's private key associated with the certfile."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg)
for fn in ('cafile', 'capath', 'certfile', 'keyfile'):
for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published. The topic name is
specified as the key in the dictionary of the `ClientConfig.topics."""
"""Configuration of how messages to specific topics are published.
The topic name is specified as the key in the dictionary of the `ClientConfig.topics.
"""
qos: int = 0
"""The quality of service associated with the publishing to this topic."""
retain: bool = False
"""Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Topic config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
@ -243,19 +276,23 @@ class WillConfig(Dictable):
"""Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Will config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
def default_client_plugins():
def default_client_plugins() -> dict[str, Any]:
"""Create defaults for `ClientConfig.plugins`."""
return {
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{}
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}
}
@dataclass
class ClientConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
keep_alive: int | None = 10
"""Keep-alive timeout sent to the broker."""
ping_delay: int | None = 1
@ -278,8 +315,8 @@ class ClientConfig(Dictable):
"""Specify the topics and what flags should be set for messages published to them."""
broker: ConnectionConfig | None = field(default_factory=ConnectionConfig)
"""Configuration for connecting to the broker. See
[ConnectionConfig](#amqtt.contexts.ConnectionConfig) for more information."""
plugins: dict | list | None = field(default_factory=default_client_plugins)
[ConnectionConfig](./#amqtt.contexts.ConnectionConfig) for more information."""
plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
a dictionary of configuration options for that plugin. See [Plugins](http://localhost:8000/custom_plugins/)
for more information."""
@ -287,15 +324,17 @@ class ClientConfig(Dictable):
"""If establishing a secure connection, should the hostname of the certificate be verified."""
will: WillConfig | None = None
"""Message, topic and flags that should be sent to if the client disconnects. See
[WillConfig](#amqtt.contexts.WillConfig)"""
[WillConfig](./#amqtt.contexts.WillConfig)"""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
msg = "Client config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig':
def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
"""Create a client config from a dictionary."""
if d is None:
return ClientConfig()
@ -305,5 +344,3 @@ class ClientConfig(Dictable):
cast=[StrEnum],
strict=True)
)

Wyświetl plik

@ -52,7 +52,7 @@ class BasePlugin(Generic[C]):
if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if option_name in self.context.config:
return self.context.config[option_name]
return default
@ -79,9 +79,10 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
if not self.context.config:
return default
# overloaded context.config with either BrokerConfig or plugin's Config
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.topic_config and option_name in self.topic_config:
return self.topic_config[option_name]
return default
@ -112,7 +113,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.auth_config and option_name in self.auth_config:
return self.auth_config[option_name]
return default

Wyświetl plik

@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
# plugins loaded directly from config dictionary
if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None:
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None:
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])

Wyświetl plik

@ -1 +1,5 @@
{% if 'default_factory' in expression.__str__() %}{{ obj.extra.dataclass_ext.default_factory | safe }}{% else %}{{ expression | safe }}{% endif %}
{% if 'default_factory' in expression.__str__() %}
{{ obj.extra.dataclass_ext.default_factory | safe }}
{% else %}
{% extends "_base/expression.html.jinja" %}
{% endif %}