type correction and linting fixes

config_dataclasses
Andrew Mirsky 2025-07-13 00:04:24 -04:00
rodzic 81866d0238
commit 4a39124545
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
6 zmienionych plików z 130 dodań i 88 usunięć

Wyświetl plik

@ -2,14 +2,11 @@ import asyncio
from asyncio import CancelledError, futures from asyncio import CancelledError, futures
from collections import deque from collections import deque
from collections.abc import Generator from collections.abc import Generator
from enum import StrEnum
from functools import partial from functools import partial
import logging import logging
from pathlib import Path
import re import re
import ssl import ssl
from typing import Any, ClassVar, TypeAlias from typing import Any, ClassVar, TypeAlias
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig
from transitions import Machine, MachineError from transitions import Machine, MachineError
import websockets.asyncio.server import websockets.asyncio.server
@ -23,18 +20,17 @@ from amqtt.adapters import (
WebSocketsWriter, WebSocketsWriter,
WriterAdapter, WriterAdapter,
) )
from amqtt.contexts import Action, BaseContext, BrokerConfig from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config from amqtt.utils import format_client_message, gen_client_id
from .events import BrokerEvents from .events import BrokerEvents
from .mqtt.constants import QOS_0, QOS_1, QOS_2 from .mqtt.constants import QOS_0, QOS_1, QOS_2
from .mqtt.disconnect import DisconnectPacket from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import PluginManager from .plugins.manager import PluginManager
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
# Default port numbers # Default port numbers
@ -147,14 +143,17 @@ class Broker:
def __init__( def __init__(
self, self,
config: _CONFIG_LISTENER | None = None, config: BrokerConfig | dict[str, Any] | None = None,
loop: asyncio.AbstractEventLoop | None = None, loop: asyncio.AbstractEventLoop | None = None,
plugin_namespace: str | None = None, plugin_namespace: str | None = None,
) -> None: ) -> None:
"""Initialize the broker.""" """Initialize the broker."""
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
if isinstance(config, dict):
self.config = BrokerConfig.from_dict(config) self.config = BrokerConfig.from_dict(config)
else:
self.config = config or BrokerConfig()
# listeners are populated from default within BrokerConfig # listeners are populated from default within BrokerConfig
self.listeners_config = self.config.listeners self.listeners_config = self.config.listeners
@ -247,7 +246,8 @@ class Broker:
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext: @staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
"""Create an SSL context for a listener.""" """Create an SSL context for a listener."""
try: try:
ssl_context = ssl.create_default_context( ssl_context = ssl.create_default_context(
@ -680,7 +680,7 @@ class Broker:
except Exception: except Exception:
self.logger.exception("Failed to stop handler") self.logger.exception("Failed to stop handler")
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool: async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
"""Call the authenticate method on registered plugins to test user authentication. """Call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True. User is considered authenticated if all plugins called returns True.

Wyświetl plik

@ -2,15 +2,11 @@ import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
import contextlib import contextlib
import copy
from enum import StrEnum
from functools import wraps from functools import wraps
import logging import logging
from pathlib import Path
import ssl import ssl
from typing import TYPE_CHECKING, Any, TypeAlias, cast from typing import TYPE_CHECKING, Any, TypeAlias, cast
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig
import websockets import websockets
from websockets import HeadersLike, InvalidHandshake, InvalidURI from websockets import HeadersLike, InvalidHandshake, InvalidURI
@ -21,14 +17,14 @@ from amqtt.adapters import (
WebSocketsReader, WebSocketsReader,
WebSocketsWriter, WebSocketsWriter,
) )
from amqtt.contexts import BaseContext, ClientConfig, ConnectionConfig from amqtt.contexts import BaseContext, ClientConfig
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
from amqtt.mqtt.connack import CONNECTION_ACCEPTED from amqtt.mqtt.connack import CONNECTION_ACCEPTED
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from amqtt.plugins.manager import PluginManager from amqtt.plugins.manager import PluginManager
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import gen_client_id, read_yaml_config from amqtt.utils import gen_client_id
if TYPE_CHECKING: if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection from websockets.asyncio.client import ClientConnection
@ -42,7 +38,7 @@ class ClientContext(BaseContext):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.config = None self.config: ClientConfig | None = None
base_logger = logging.getLogger(__name__) base_logger = logging.getLogger(__name__)
@ -94,9 +90,13 @@ class MQTTClient:
""" """
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None: def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
if isinstance(config, dict):
self.config = ClientConfig.from_dict(config) self.config = ClientConfig.from_dict(config)
else:
self.config = config or ClientConfig()
self.client_id = client_id if client_id is not None else gen_client_id() self.client_id = client_id if client_id is not None else gen_client_id()

Wyświetl plik

@ -1,9 +1,19 @@
import warnings from dataclasses import dataclass, field, fields, replace
from dataclasses import dataclass, field, fields, replace, asdict
try:
from enum import Enum, StrEnum from enum import Enum, StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
pass
from collections.abc import Iterator
import logging import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, LiteralString, Literal from typing import TYPE_CHECKING, Any, Literal
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
from amqtt.mqtt.constants import QOS_0, QOS_2 from amqtt.mqtt.constants import QOS_0, QOS_2
@ -14,14 +24,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
class BaseContext: class BaseContext:
def __init__(self) -> None: def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None self.loop: asyncio.AbstractEventLoop | None = None
self.logger: logging.Logger = _LOGGER self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None # cleanup with a `Generic` type
self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
class Action(Enum): class Action(Enum):
@ -32,33 +41,44 @@ class Action(Enum):
class ListenerType(StrEnum): class ListenerType(StrEnum):
TCP = 'tcp' """Types of mqtt listeners."""
WS = 'ws'
TCP = "tcp"
WS = "ws"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'"{str(self.value)}"' """Display the string value, instead of the enum member."""
return f'"{self.value!s}"'
class Dictable: class Dictable:
def __getitem__(self, key): """Add dictionary methods to a dataclass."""
def __getitem__(self, key:str) -> Any:
"""Allow dict-style `[]` access to a dataclass."""
return self.get(key) return self.get(key)
def get(self, name, default=None): def get(self, name:str, default:Any=None) -> Any:
"""Allow dict-style access to a dataclass."""
name = name.replace("-", "_") name = name.replace("-", "_")
if hasattr(self, name): if hasattr(self, name):
return getattr(self, name) return getattr(self, name)
if default is not None: if default is not None:
return default return default
raise ValueError(f"'{name}' is not defined") msg = f"'{name}' is not defined"
raise ValueError(msg)
def __contains__(self, name): def __contains__(self, name: str) -> bool:
return getattr(self, name.replace('-', '_'), None) is not None """Provide dict-style 'in' check."""
return getattr(self, name.replace("-", "_"), None) is not None
def __iter__(self): def __iter__(self) -> Iterator[Any]:
for field in fields(self): """Provide dict-style iteration."""
yield getattr(self, field.name) for f in fields(self): # type: ignore[arg-type]
yield getattr(self, f.name)
def copy(self): def copy(self) -> dataclass: # type: ignore[valid-type]
return replace(self) """Return a copy of the dataclass."""
return replace(self) # type: ignore[type-var]
@dataclass @dataclass
@ -87,45 +107,46 @@ class ListenerConfig(Dictable):
keyfile: str | Path | None = None keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key.""" """Full path to file in PEM format containing the server's private key."""
def __post_init__(self): def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None): if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required." msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg) raise ValueError(msg)
for fn in ('cafile', 'capath', 'certfile', 'keyfile'): for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str): if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn))) setattr(self, fn, Path(getattr(self, fn)))
def apply(self, other): def apply(self, other: "ListenerConfig") -> None:
"""Apply the field from 'other', if 'self' field is default.""" """Apply the field from 'other', if 'self' field is default."""
if not isinstance(other, ListenerConfig):
msg = f'cannot apply {self.__class__.__name__} to {other.__class__.__name__}'
raise TypeError(msg)
for f in fields(self): for f in fields(self):
if getattr(self, f.name) == f.default: if getattr(self, f.name) == f.default:
setattr(self, f.name, other[f.name]) setattr(self, f.name, other[f.name])
def default_listeners(): def default_listeners() -> dict[str, Any]:
"""Create defaults for BrokerConfig.listeners."""
return { return {
'default': ListenerConfig() "default": ListenerConfig()
} }
def default_broker_plugins(): def default_broker_plugins() -> dict[str, Any]:
"""Create defaults for BrokerConfig.plugins."""
return { return {
'amqtt.plugins.logging_amqtt.EventLoggerPlugin':{}, "amqtt.plugins.logging_amqtt.EventLoggerPlugin":{},
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{}, "amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{},
'amqtt.plugins.authentication.AnonymousAuthPlugin':{'allow_anonymous':True}, "amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True},
'amqtt.plugins.sys.broker.BrokerSysPlugin':{'sys_interval':20} "amqtt.plugins.sys.broker.BrokerSysPlugin":{"sys_interval":20}
} }
@dataclass @dataclass
class BrokerConfig(Dictable): class BrokerConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary.""" """Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
listeners: dict[Literal['default'] | str, ListenerConfig] = field(default_factory=default_listeners)
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
"""Network of listeners used by the services. a 'default' named listener is required; if another listener """Network of listeners used by the services. a 'default' named listener is required; if another listener
does not set a value, the 'default' settings are applied. See does not set a value, the 'default' settings are applied. See
[ListenerConfig](#amqtt.contexts.ListenerConfig) for more information.""" [ListenerConfig](./#amqtt.contexts.ListenerConfig) for more information."""
sys_interval: int | None = None sys_interval: int | None = None
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../packaged_plugins.md/#sys-topics) """*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../packaged_plugins.md/#sys-topics)
configuration instead.*""" configuration instead.*"""
@ -133,50 +154,55 @@ class BrokerConfig(Dictable):
"""Client disconnect timeout without a keep-alive.""" """Client disconnect timeout without a keep-alive."""
auth: dict[str, Any] | None = None auth: dict[str, Any] | None = None
"""*Deprecated field used to config EntryPoint-loaded plugins. See """*Deprecated field used to config EntryPoint-loaded plugins. See
[`AnonymousAuthPlugin`](#anonymous-auth-plugin) and [`AnonymousAuthPlugin`](./#anonymous-auth-plugin) and
[`FileAuthPlugin`](/packaged_plugins/#password-file-auth-plugin) for more information.*""" [`FileAuthPlugin`](/packaged_plugins/#password-file-auth-plugin) for more information.*"""
topic_check: dict[str, Any] | None = None topic_check: dict[str, Any] | None = None
"""Deprecated field used to config EntryPoint-loaded plugins. See """Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](#taboo-topic-plugin) and [`TopicTabooPlugin`](#taboo-topic-plugin) and
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*""" [`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
plugins: dict | list | None = field(default_factory=default_broker_plugins) plugins: dict[str, Any] | list[dict[str,Any]] | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin` """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[Plugins](http://localhost:8000/custom_plugins/) for more information.""" [Plugins](http://localhost:8000/custom_plugins/) for more information."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.sys_interval is not None: if self.sys_interval is not None:
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration") logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
if self.auth is not None or self.topic_check is not None: if self.auth is not None or self.topic_check is not None:
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration") logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
default_listener = self.listeners['default'] default_listener = self.listeners["default"]
for listener_name, listener in self.listeners.items(): for listener_name, listener in self.listeners.items():
if listener_name == 'default': if listener_name == "default":
continue continue
listener.apply(default_listener) listener.apply(default_listener)
if isinstance(self.plugins, list): if isinstance(self.plugins, list):
_plugins = {} _plugins: dict[str, Any] = {}
for plugin in self.plugins: for plugin in self.plugins:
if isinstance(plugin, str): # in case a plugin in a yaml file is listed without config
_plugins |= {plugin:{}} if isinstance(plugin, str): # type: ignore[unreachable]
_plugins |= {plugin:{}} # type: ignore[unreachable]
continue continue
_plugins |= plugin _plugins |= plugin
self.plugins = _plugins self.plugins = _plugins
@classmethod @classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig': def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
"""Create a broker config from a dictionary."""
if d is None: if d is None:
return BrokerConfig() return BrokerConfig()
if 'topic-check' in d: # patch the incoming dictionary so it can be loaded correctly
d['topic_check'] = d['topic-check'] if "topic-check" in d:
del d['topic-check'] d["topic_check"] = d["topic-check"]
del d["topic-check"]
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d: # identify EntryPoint plugin loading and prevent 'plugins' from getting defaults
d['plugins'] = None if ("auth" in d or "topic-check" in d) and "plugins" not in d:
d["plugins"] = None
return dict_to_dataclass(data_class=BrokerConfig, return dict_to_dataclass(data_class=BrokerConfig,
data=d, data=d,
@ -188,6 +214,8 @@ class BrokerConfig(Dictable):
@dataclass @dataclass
class ConnectionConfig(Dictable): class ConnectionConfig(Dictable):
"""Properties for connecting to the broker."""
uri: str | None = "mqtt://127.0.0.1:1883" uri: str | None = "mqtt://127.0.0.1:1883"
"""URI of the broker""" """URI of the broker"""
cafile: str | Path | None = None cafile: str | Path | None = None
@ -206,24 +234,29 @@ class ConnectionConfig(Dictable):
"""Full path to file in PEM format containing the client's private key associated with the certfile.""" """Full path to file in PEM format containing the client's private key associated with the certfile."""
def __post__init__(self) -> None: def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None): if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required." msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg) raise ValueError(msg)
for fn in ('cafile', 'capath', 'certfile', 'keyfile'): for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str): if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn))) setattr(self, fn, Path(getattr(self, fn)))
@dataclass @dataclass
class TopicConfig(Dictable): class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published. The topic name is """Configuration of how messages to specific topics are published.
specified as the key in the dictionary of the `ClientConfig.topics."""
The topic name is specified as the key in the dictionary of the `ClientConfig.topics.
"""
qos: int = 0 qos: int = 0
"""The quality of service associated with the publishing to this topic.""" """The quality of service associated with the publishing to this topic."""
retain: bool = False retain: bool = False
"""Determines if the message should be retained by the topic it was published.""" """Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None: def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2): if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Topic config: default QoS must be 0, 1 or 2." msg = "Topic config: default QoS must be 0, 1 or 2."
raise ValueError(msg) raise ValueError(msg)
@ -243,19 +276,23 @@ class WillConfig(Dictable):
"""Determines if the message should be retained by the topic it was published.""" """Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None: def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2): if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Will config: default QoS must be 0, 1 or 2." msg = "Will config: default QoS must be 0, 1 or 2."
raise ValueError(msg) raise ValueError(msg)
def default_client_plugins(): def default_client_plugins() -> dict[str, Any]:
"""Create defaults for `ClientConfig.plugins`."""
return { return {
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{} "amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}
} }
@dataclass @dataclass
class ClientConfig(Dictable): class ClientConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
keep_alive: int | None = 10 keep_alive: int | None = 10
"""Keep-alive timeout sent to the broker.""" """Keep-alive timeout sent to the broker."""
ping_delay: int | None = 1 ping_delay: int | None = 1
@ -278,8 +315,8 @@ class ClientConfig(Dictable):
"""Specify the topics and what flags should be set for messages published to them.""" """Specify the topics and what flags should be set for messages published to them."""
broker: ConnectionConfig | None = field(default_factory=ConnectionConfig) broker: ConnectionConfig | None = field(default_factory=ConnectionConfig)
"""Configuration for connecting to the broker. See """Configuration for connecting to the broker. See
[ConnectionConfig](#amqtt.contexts.ConnectionConfig) for more information.""" [ConnectionConfig](./#amqtt.contexts.ConnectionConfig) for more information."""
plugins: dict | list | None = field(default_factory=default_client_plugins) plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
a dictionary of configuration options for that plugin. See [Plugins](http://localhost:8000/custom_plugins/) a dictionary of configuration options for that plugin. See [Plugins](http://localhost:8000/custom_plugins/)
for more information.""" for more information."""
@ -287,15 +324,17 @@ class ClientConfig(Dictable):
"""If establishing a secure connection, should the hostname of the certificate be verified.""" """If establishing a secure connection, should the hostname of the certificate be verified."""
will: WillConfig | None = None will: WillConfig | None = None
"""Message, topic and flags that should be sent to if the client disconnects. See """Message, topic and flags that should be sent to if the client disconnects. See
[WillConfig](#amqtt.contexts.WillConfig)""" [WillConfig](./#amqtt.contexts.WillConfig)"""
def __post__init__(self) -> None: def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2): if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
msg = "Client config: default QoS must be 0, 1 or 2." msg = "Client config: default QoS must be 0, 1 or 2."
raise ValueError(msg) raise ValueError(msg)
@classmethod @classmethod
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig': def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
"""Create a client config from a dictionary."""
if d is None: if d is None:
return ClientConfig() return ClientConfig()
@ -305,5 +344,3 @@ class ClientConfig(Dictable):
cast=[StrEnum], cast=[StrEnum],
strict=True) strict=True)
) )

Wyświetl plik

@ -52,7 +52,7 @@ class BasePlugin(Generic[C]):
if is_dataclass(self.context.config): if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] return getattr(self.context.config, option_name.replace("-", "_"), default)
if option_name in self.context.config: if option_name in self.context.config:
return self.context.config[option_name] return self.context.config[option_name]
return default return default
@ -79,9 +79,10 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
if not self.context.config: if not self.context.config:
return default return default
# overloaded context.config with either BrokerConfig or plugin's Config
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig): if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.topic_config and option_name in self.topic_config: if self.topic_config and option_name in self.topic_config:
return self.topic_config[option_name] return self.topic_config[option_name]
return default return default
@ -112,7 +113,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig): if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.auth_config and option_name in self.auth_config: if self.auth_config and option_name in self.auth_config:
return self.auth_config[option_name] return self.auth_config[option_name]
return default return default

Wyświetl plik

@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
# plugins loaded directly from config dictionary # plugins loaded directly from config dictionary
if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None: if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config") self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None: if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config") self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", []) plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])

Wyświetl plik

@ -1 +1,5 @@
{% if 'default_factory' in expression.__str__() %}{{ obj.extra.dataclass_ext.default_factory | safe }}{% else %}{{ expression | safe }}{% endif %} {% if 'default_factory' in expression.__str__() %}
{{ obj.extra.dataclass_ext.default_factory | safe }}
{% else %}
{% extends "_base/expression.html.jinja" %}
{% endif %}