kopia lustrzana https://github.com/Yakifo/amqtt
380 wiersze
16 KiB
Python
380 wiersze
16 KiB
Python
from dataclasses import dataclass, field, fields, replace
|
|
import logging
|
|
import warnings
|
|
|
|
try:
|
|
from enum import Enum, StrEnum
|
|
except ImportError:
|
|
# support for python 3.10
|
|
from enum import Enum
|
|
class StrEnum(str, Enum): # type: ignore[no-redef]
|
|
pass
|
|
|
|
from collections.abc import Iterator
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
|
|
|
|
from amqtt.mqtt3.constants import QOS_0, QOS_2
|
|
|
|
if TYPE_CHECKING:
|
|
import asyncio
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseContext:
    """Hold the event loop, logger and configuration attributes for a context."""

    def __init__(self) -> None:
        """Initialise with no loop and no config; the logger is named after this module."""
        # TODO: narrow `config` with a `Generic` type parameter
        self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
        self.logger: logging.Logger = logging.getLogger(__name__)
        self.loop: asyncio.AbstractEventLoop | None = None
|
|
|
|
|
|
class Action(StrEnum):
|
|
"""Actions issued by the broker."""
|
|
|
|
SUBSCRIBE = "subscribe"
|
|
PUBLISH = "publish"
|
|
RECEIVE = "receive"
|
|
|
|
|
|
class ListenerType(StrEnum):
|
|
"""Types of mqtt listeners."""
|
|
|
|
TCP = "tcp"
|
|
WS = "ws"
|
|
EXTERNAL = "external"
|
|
|
|
def __repr__(self) -> str:
|
|
"""Display the string value, instead of the enum member."""
|
|
return f'"{self.value!s}"'
|
|
|
|
|
|
class Dictable:
|
|
"""Add dictionary methods to a dataclass."""
|
|
|
|
def __getitem__(self, key: str) -> Any:
|
|
"""Allow dict-style `[]` access to a dataclass."""
|
|
return self.get(key)
|
|
|
|
def get(self, name: str, default: Any = None) -> Any:
|
|
"""Allow dict-style access to a dataclass."""
|
|
name = name.replace("-", "_")
|
|
if hasattr(self, name):
|
|
return getattr(self, name)
|
|
if default is not None:
|
|
return default
|
|
msg = f"'{name}' is not defined"
|
|
raise ValueError(msg)
|
|
|
|
def __contains__(self, name: str) -> bool:
|
|
"""Provide dict-style 'in' check."""
|
|
return getattr(self, name.replace("-", "_"), None) is not None
|
|
|
|
def __iter__(self) -> Iterator[Any]:
|
|
"""Provide dict-style iteration."""
|
|
for f in fields(self): # type: ignore[arg-type]
|
|
yield getattr(self, f.name)
|
|
|
|
def copy(self) -> dataclass: # type: ignore[valid-type]
|
|
"""Return a copy of the dataclass."""
|
|
return replace(self) # type: ignore[type-var]
|
|
|
|
@staticmethod
|
|
def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]:
|
|
if isinstance(value, list):
|
|
return value # It's already a list of dicts
|
|
if isinstance(value, dict):
|
|
return [value] # Promote single dict to a list
|
|
msg = "Could not convert 'list' to 'list[dict[str, Any]]'"
|
|
raise ValueError(msg)
|
|
|
|
|
|
@dataclass
class ListenerConfig(Dictable):
    """Structured configuration for one of a broker's listeners."""

    type: ListenerType = ListenerType.TCP
    """Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in dictionary or yaml."""
    bind: str | None = "0.0.0.0:1883"
    """address and port for the listener to bind to"""
    max_connections: int = 0
    """max number of connections allowed for this listener"""
    ssl: bool = False
    """secured by ssl"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | Path | None = None
    """Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the server's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the server's private key."""
    reader: str | None = None
    writer: str | None = None

    def __post_init__(self) -> None:
        """Validate the certificate settings and normalise path fields.

        Raises `ValueError` when only one of `certfile`/`keyfile` is given, and
        `FileNotFoundError` when a configured path does not exist.
        """
        cert_missing = self.certfile is None
        key_missing = self.keyfile is None
        if cert_missing != key_missing:
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)

        for field_name in ("cafile", "capath", "certfile", "keyfile"):
            location = getattr(self, field_name)
            if isinstance(location, str):
                # store string paths as `Path` objects
                location = Path(location)
                setattr(self, field_name, location)
            if location and not location.exists():
                msg = f"'{field_name}' does not exist : {location}"
                raise FileNotFoundError(msg)

    def apply(self, other: "ListenerConfig") -> None:
        """Copy each field from `other` wherever this instance still holds the field's default."""
        for spec in fields(self):
            still_default = getattr(self, spec.name) == spec.default
            if still_default:
                setattr(self, spec.name, other[spec.name])
|
|
|
|
|
|
def default_listeners() -> dict[str, Any]:
    """Build the default for `BrokerConfig.listeners`: one all-defaults listener under the key 'default'."""
    listeners: dict[str, Any] = {}
    listeners["default"] = ListenerConfig()
    return listeners
|
|
|
|
|
|
def default_broker_plugins() -> dict[str, Any]:
    """Build the default for `BrokerConfig.plugins`.

    Returns a fresh mapping of dotted plugin class path to that plugin's options.
    """
    plugins: dict[str, Any] = {}
    plugins["amqtt.plugins.logging_amqtt.EventLoggerPlugin"] = {}
    plugins["amqtt.plugins.logging_amqtt.PacketLoggerPlugin"] = {}
    plugins["amqtt.plugins.authentication.AnonymousAuthPlugin"] = {"allow_anonymous": True}
    plugins["amqtt.plugins.sys.broker.BrokerSysPlugin"] = {"sys_interval": 20}
    return plugins
|
|
|
|
|
|
@dataclass
class BrokerConfig(Dictable):
    """Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""

    listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners)  # noqa: PYI051
    """Network of listeners used by the services. a 'default' named listener is required; if another listener
    does not set a value, the 'default' settings are applied. See
    [`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
    sys_interval: int | None = None
    """*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
    for recommended configuration.*"""
    timeout_disconnect_delay: int | None = 0
    """Client disconnect timeout without a keep-alive."""
    session_expiry_interval: int | None = None
    """Seconds for an inactive session to be retained."""
    auth: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
    [`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
    topic_check: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
    [`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
    plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
    or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
    [custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
    to support legacy use cases."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Logs a warning for each deprecated field in use, fills unset values of every
        non-default listener from the 'default' listener, and converts a legacy
        list-style `plugins` value into a dictionary.

        Raises `KeyError` if `listeners` has no 'default' entry.
        """
        if self.sys_interval is not None:
            logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")

        if self.auth is not None or self.topic_check is not None:
            logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")

        default_listener = self.listeners["default"]
        for listener_name, listener in self.listeners.items():
            if listener_name == "default":
                continue
            listener.apply(default_listener)

        if isinstance(self.plugins, list):
            _plugins: dict[str, Any] = {}
            for plugin in self.plugins:
                # in case a plugin in a yaml file is listed without config map
                if isinstance(plugin, str):
                    _plugins[plugin] = {}
                    continue
                _plugins |= plugin
            self.plugins = _plugins

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
        """Create a broker config from a dictionary.

        `None` yields an all-defaults config. The legacy 'topic-check' key is accepted
        as an alias of 'topic_check'. The input dictionary is not modified.
        """
        if d is None:
            return BrokerConfig()

        # patch a shallow copy so the caller's dictionary is left untouched
        data = dict(d)

        # patch the incoming dictionary so it can be loaded correctly
        if "topic-check" in data:
            data["topic_check"] = data.pop("topic-check")

        # identify EntryPoint plugin loading and prevent 'plugins' from getting defaults;
        # the key has already been normalised above, so test for 'topic_check'
        if ("auth" in data or "topic_check" in data) and "plugins" not in data:
            data["plugins"] = None

        return dict_to_dataclass(data_class=BrokerConfig,
                                 data=data,
                                 config=DaciteConfig(
                                     cast=[StrEnum, ListenerType],
                                     strict=True,
                                     type_hooks={list[dict[str, Any]]: cls._coerce_lists}
                                 ))
|
|
|
|
|
|
@dataclass
class ConnectionConfig(Dictable):
    """Properties for connecting to the broker."""

    uri: str | None = "mqtt://127.0.0.1:1883"
    """URI of the broker"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | None = None
    """The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
    certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the client's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the client's private key associated with the certfile."""

    # NOTE: the hook must be spelled `__post_init__` for the dataclass machinery to call it;
    # the previous `__post__init__` spelling was never invoked.
    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        String values of `cafile`, `capath`, `certfile` and `keyfile` are converted to `Path`.

        Raises `ValueError` when only one of `certfile`/`keyfile` is given.
        """
        if (self.certfile is None) ^ (self.keyfile is None):
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)

        for fn in ("cafile", "capath", "certfile", "keyfile"):
            value = getattr(self, fn)
            if isinstance(value, str):
                setattr(self, fn, Path(value))
|
|
|
|
|
|
@dataclass
class TopicConfig(Dictable):
    """Configuration of how messages to specific topics are published.

    The topic name is specified as the key in the dictionary of the `ClientConfig.topics`.
    """

    qos: int = 0
    """The quality of service associated with the publishing to this topic."""
    retain: bool = False
    """Determines if the message should be retained by the topic it was published."""

    # NOTE: the hook must be spelled `__post_init__` for the dataclass machinery to call it;
    # the previous `__post__init__` spelling was never invoked.
    def __post_init__(self) -> None:
        """Check config for errors.

        Raises `ValueError` when `qos` is set and lies outside the `QOS_0`..`QOS_2` range.
        """
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Topic config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
|
|
|
|
|
|
@dataclass
class WillConfig(Dictable):
    """Configuration of the 'last will & testament' of the client upon improper disconnection."""

    topic: str
    """The will message will be published to this topic."""
    message: str
    """The contents of the message to be published."""
    qos: int | None = QOS_0
    """The quality of service associated with sending this message."""
    retain: bool | None = False
    """Determines if the message should be retained by the topic it was published."""

    # NOTE: the hook must be spelled `__post_init__` for the dataclass machinery to call it;
    # the previous `__post__init__` spelling was never invoked.
    def __post_init__(self) -> None:
        """Check config for errors.

        Raises `ValueError` when `qos` is set and lies outside the `QOS_0`..`QOS_2` range.
        """
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Will config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
|
|
|
|
|
|
def default_client_plugins() -> dict[str, Any]:
    """Build the default for `ClientConfig.plugins`.

    Returns a fresh mapping of dotted plugin class path to that plugin's options.
    """
    plugins: dict[str, Any] = {}
    plugins["amqtt.plugins.logging_amqtt.PacketLoggerPlugin"] = {}
    return plugins
|
|
|
|
|
|
@dataclass
class ClientConfig(Dictable):
    """Structured configuration for a client. Can be built directly or created from a dictionary with `from_dict`."""

    keep_alive: int | None = 10
    """Keep-alive timeout sent to the broker."""
    ping_delay: int | None = 1
    """Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
    default_qos: int | None = QOS_0
    """Default QoS for messages published."""
    default_retain: bool | None = False
    """Default retain value to messages published."""
    auto_reconnect: bool | None = True
    """Enable or disable auto-reconnect if connection with the broker is interrupted."""
    connection_timeout: int | None = 60
    """The number of seconds before a connection times out"""
    reconnect_retries: int | None = 2
    """Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
    reconnect_max_interval: int | None = 10
    """Maximum seconds to wait before retrying a connection."""
    cleansession: bool | None = True
    """Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
    topics: dict[str, TopicConfig] | None = field(default_factory=dict)
    """Specify the topics and what flags should be set for messages published to them."""
    broker: ConnectionConfig | None = None
    """*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
    connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    """Configuration for connecting to the broker. See
    [`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
    plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
    a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
    more information. `list[str | dict[str,Any]]` is deprecated but available to support legacy use cases."""
    check_hostname: bool | None = True
    """If establishing a secure connection, should the hostname of the certificate be verified."""
    will: WillConfig | None = None
    """Message, topic and flags that should be sent to if the client disconnects. See
    [`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""

    def __post_init__(self) -> None:
        """Validate the default QoS, honour the deprecated `broker` field and check the key/cert pair.

        Raises `ValueError` for an out-of-range `default_qos` or when only one of the
        connection's `keyfile`/`certfile` is set.
        """
        qos = self.default_qos
        if qos is not None and (qos > QOS_2 or qos < QOS_0):
            msg = "Client config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)

        if self.broker is not None:
            # the deprecated field, when given, replaces `connection`
            warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
            self.connection = self.broker

        has_key = bool(self.connection.keyfile)
        has_cert = bool(self.connection.certfile)
        if has_key != has_cert:
            msg = "Connection key and certificate files are _both_ required."
            raise ValueError(msg)

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
        """Create a client config from a dictionary; `None` yields an all-defaults config."""
        if d is None:
            return ClientConfig()

        loader_options = DaciteConfig(cast=[StrEnum], strict=True)
        return dict_to_dataclass(data_class=ClientConfig, data=d, config=loader_options)
|