kopia lustrzana https://github.com/Yakifo/amqtt
all test cases working. most fixes were to make the dataclass config's backwards compatible with how the rest of the code was accessing config dictionaries
rodzic
5c59248b4f
commit
81866d0238
|
@ -37,10 +37,6 @@ from .plugins.manager import PluginManager
|
|||
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
||||
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
||||
|
||||
|
||||
_default_broker = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
|
||||
_defaults = dict_to_dataclass(BrokerConfig, _default_broker, config=DaciteConfig(cast=[StrEnum]))
|
||||
|
||||
# Default port numbers
|
||||
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
||||
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
||||
|
@ -103,7 +99,7 @@ class BrokerContext(BaseContext):
|
|||
|
||||
def __init__(self, broker: "Broker") -> None:
|
||||
super().__init__()
|
||||
self.config: _CONFIG_LISTENER | None = None
|
||||
self.config: BrokerConfig | None = None
|
||||
self._broker_instance = broker
|
||||
|
||||
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
||||
|
@ -158,20 +154,10 @@ class Broker:
|
|||
"""Initialize the broker."""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.config = dict_to_dataclass(BrokerConfig, config, config=DaciteConfig(cast=[StrEnum]))
|
||||
self.config = BrokerConfig.from_dict(config)
|
||||
|
||||
self.config |= _defaults
|
||||
|
||||
# if config is not None:
|
||||
# # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
|
||||
# if ("auth" in config or "topic-check" in config) and "plugins" not in config:
|
||||
# # set to None so that the config isn't updated with the new-style default plugin list
|
||||
# config["plugins"] = None # type: ignore[assignment]
|
||||
# self.config.update(config)
|
||||
|
||||
|
||||
|
||||
self._build_listeners_config(self.config)
|
||||
# listeners are populated from default within BrokerConfig
|
||||
self.listeners_config = self.config.listeners
|
||||
|
||||
self._loop = loop or asyncio.get_running_loop()
|
||||
self._servers: dict[str, Server] = {}
|
||||
|
@ -197,25 +183,6 @@ class Broker:
|
|||
namespace = plugin_namespace or "amqtt.broker.plugins"
|
||||
self.plugins_manager = PluginManager(namespace, context, self._loop)
|
||||
|
||||
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
|
||||
self.listeners_config = {}
|
||||
try:
|
||||
listeners_config = broker_config.get("listeners")
|
||||
if not isinstance(listeners_config, dict):
|
||||
msg = "Listener config not found or invalid"
|
||||
raise BrokerError(msg)
|
||||
defaults = listeners_config.get("default")
|
||||
if defaults is None:
|
||||
msg = "Listener config has not default included or is invalid"
|
||||
raise BrokerError(msg)
|
||||
|
||||
for listener_name, listener_conf in listeners_config.items():
|
||||
listener_conf |= defaults
|
||||
self.listeners_config[listener_name] = listener_conf
|
||||
except KeyError as ke:
|
||||
msg = f"Listener config not found or invalid: {ke}"
|
||||
raise BrokerError(msg) from ke
|
||||
|
||||
def _init_states(self) -> None:
|
||||
self.transitions = Machine(states=Broker.states, initial="new")
|
||||
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
|
||||
|
|
|
@ -33,9 +33,6 @@ from amqtt.utils import gen_client_id, read_yaml_config
|
|||
if TYPE_CHECKING:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
_default_client = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
|
||||
_defaults = dict_to_dataclass(ClientConfig, _default_client, config=DaciteConfig(cast=[StrEnum]))
|
||||
|
||||
|
||||
class ClientContext(BaseContext):
|
||||
"""ClientContext is used as the context passed to plugins interacting with the client.
|
||||
|
@ -99,9 +96,8 @@ class MQTTClient:
|
|||
|
||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config = dict_to_dataclass(ClientConfig, config or {}, config=DaciteConfig(cast=[StrEnum]))
|
||||
self.config = ClientConfig.from_dict(config)
|
||||
|
||||
self.config |= _defaults
|
||||
self.client_id = client_id if client_id is not None else gen_client_id()
|
||||
|
||||
self.session: Session | None = None
|
||||
|
@ -585,7 +581,17 @@ class MQTTClient:
|
|||
) -> Session:
|
||||
"""Initialize the MQTT session."""
|
||||
broker_conf = self.config.get("broker", {}).copy()
|
||||
broker_conf.update(ConnectionConfig(uri=uri, cafile=cafile, capath=capath, cadata=cadata))
|
||||
|
||||
if uri is not None:
|
||||
broker_conf.uri = uri
|
||||
if cleansession is not None:
|
||||
self.config.cleansession = cleansession
|
||||
if cafile is not None:
|
||||
broker_conf.cafile = cafile
|
||||
if capath is not None:
|
||||
broker_conf.capath = capath
|
||||
if cadata is not None:
|
||||
broker_conf.cadata = cadata
|
||||
|
||||
if not broker_conf.get("uri"):
|
||||
msg = "Missing connection parameter 'uri'"
|
||||
|
@ -594,15 +600,12 @@ class MQTTClient:
|
|||
session = Session()
|
||||
session.broker_uri = broker_conf["uri"]
|
||||
session.client_id = self.client_id
|
||||
|
||||
session.cafile = broker_conf.get("cafile")
|
||||
session.capath = broker_conf.get("capath")
|
||||
session.cadata = broker_conf.get("cadata")
|
||||
|
||||
if cleansession is not None:
|
||||
broker_conf["cleansession"] = cleansession # noop?
|
||||
session.clean_session = cleansession
|
||||
else:
|
||||
session.clean_session = self.config.get("cleansession", True)
|
||||
session.clean_session = self.config.get("cleansession", True)
|
||||
|
||||
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]
|
||||
|
||||
|
|
|
@ -12,6 +12,10 @@ _LOGGER = logging.getLogger(__name__)
|
|||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
|
||||
|
||||
|
||||
class BaseContext:
|
||||
def __init__(self) -> None:
|
||||
|
@ -47,15 +51,18 @@ class Dictable:
|
|||
raise ValueError(f"'{name}' is not defined")
|
||||
|
||||
def __contains__(self, name):
|
||||
return getattr(self, name, None) is not None
|
||||
return getattr(self, name.replace('-', '_'), None) is not None
|
||||
|
||||
def __iter__(self):
|
||||
for field in fields(self):
|
||||
yield getattr(self, field.name)
|
||||
|
||||
def copy(self):
|
||||
return replace(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListenerConfig:
|
||||
class ListenerConfig(Dictable):
|
||||
"""Structured configuration for a broker's listeners."""
|
||||
|
||||
type: ListenerType = ListenerType.TCP
|
||||
|
@ -132,19 +139,17 @@ class BrokerConfig(Dictable):
|
|||
"""Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`TopicTabooPlugin`](#taboo-topic-plugin) and
|
||||
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
|
||||
plugins: dict | list[dict] | None = field(default_factory=default_broker_plugins)
|
||||
plugins: dict | list | None = field(default_factory=default_broker_plugins)
|
||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
|
||||
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
||||
[Plugins](http://localhost:8000/custom_plugins/) for more information."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
def __post_init__(self) -> None:
|
||||
if self.sys_interval is not None:
|
||||
warnings.warn("sys_interval is deprecated, use 'plugins' to define configuration",
|
||||
DeprecationWarning, stacklevel=1)
|
||||
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
|
||||
|
||||
if self.auth is not None or self.topic_check is not None:
|
||||
warnings.warn("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration",
|
||||
DeprecationWarning, stacklevel=1)
|
||||
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
|
||||
|
||||
default_listener = self.listeners['default']
|
||||
for listener_name, listener in self.listeners.items():
|
||||
|
@ -155,13 +160,35 @@ class BrokerConfig(Dictable):
|
|||
if isinstance(self.plugins, list):
|
||||
_plugins = {}
|
||||
for plugin in self.plugins:
|
||||
if isinstance(plugin, str):
|
||||
_plugins |= {plugin:{}}
|
||||
continue
|
||||
_plugins |= plugin
|
||||
self.plugins = _plugins
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig':
|
||||
if d is None:
|
||||
return BrokerConfig()
|
||||
|
||||
if 'topic-check' in d:
|
||||
d['topic_check'] = d['topic-check']
|
||||
del d['topic-check']
|
||||
|
||||
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d:
|
||||
d['plugins'] = None
|
||||
|
||||
return dict_to_dataclass(data_class=BrokerConfig,
|
||||
data=d,
|
||||
config=DaciteConfig(
|
||||
cast=[StrEnum],
|
||||
strict=True)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig(Dictable):
|
||||
uri: str | None = "mqtt://127.0.0.1"
|
||||
uri: str | None = "mqtt://127.0.0.1:1883"
|
||||
"""URI of the broker"""
|
||||
cafile: str | Path | None = None
|
||||
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
|
||||
|
@ -188,7 +215,7 @@ class ConnectionConfig(Dictable):
|
|||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
|
||||
@dataclass
|
||||
class TopicConfig:
|
||||
class TopicConfig(Dictable):
|
||||
"""Configuration of how messages to specific topics are published. The topic name is
|
||||
specified as the key in the dictionary of the `ClientConfig.topics."""
|
||||
qos: int = 0
|
||||
|
@ -203,7 +230,7 @@ class TopicConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
class WillConfig:
|
||||
class WillConfig(Dictable):
|
||||
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
|
||||
|
||||
topic: str
|
||||
|
@ -267,3 +294,16 @@ class ClientConfig(Dictable):
|
|||
msg = "Client config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig':
|
||||
if d is None:
|
||||
return ClientConfig()
|
||||
|
||||
return dict_to_dataclass(data_class=ClientConfig,
|
||||
data=d,
|
||||
config=DaciteConfig(
|
||||
cast=[StrEnum],
|
||||
strict=True)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import Any, Generic, TypeVar, cast
|
||||
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.contexts import Action, BaseContext, BrokerConfig
|
||||
from amqtt.session import Session
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
@ -79,7 +79,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
if self.topic_config and option_name in self.topic_config:
|
||||
|
@ -110,7 +110,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
|||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
if self.auth_config and option_name in self.auth_config:
|
||||
|
|
|
@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
|
|||
# plugins loaded directly from config dictionary
|
||||
|
||||
|
||||
if "auth" in self.app_context.config:
|
||||
if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||
if "topic-check" in self.app_context.config:
|
||||
if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
|
||||
|
||||
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
|
||||
|
|
|
@ -58,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
|||
|
||||
req_topic = topic
|
||||
if not req_topic:
|
||||
return False\
|
||||
return False
|
||||
|
||||
username = session.username if session else None
|
||||
if username is None:
|
||||
|
|
|
@ -236,7 +236,7 @@ async def test_client_connect_clean_session_false(broker):
|
|||
client = MQTTClient(client_id="", config={"auto_reconnect": False})
|
||||
return_code = None
|
||||
try:
|
||||
await client.connect("mqtt://127.0.0.1/", cleansession=False)
|
||||
await client.connect("mqtt://127.0.0.1", cleansession=False)
|
||||
except ConnectError as ce:
|
||||
return_code = ce.return_code
|
||||
assert return_code == 0x02
|
||||
|
@ -431,11 +431,11 @@ async def test_client_publish_acl_forbidden(acl_broker):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
|
||||
sub_client1 = MQTTClient()
|
||||
sub_client1 = MQTTClient(client_id="sub_client1")
|
||||
ret_conn = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/")
|
||||
assert ret_conn == 0
|
||||
|
||||
sub_client2 = MQTTClient()
|
||||
sub_client2 = MQTTClient(client_id="sub_client2")
|
||||
ret_conn = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/")
|
||||
assert ret_conn == 0
|
||||
|
||||
|
@ -445,7 +445,7 @@ async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
|
|||
ret_sub = await sub_client2.subscribe([("public/subtopic/test", QOS_0)])
|
||||
assert ret_sub == [128]
|
||||
|
||||
pub_client = MQTTClient()
|
||||
pub_client = MQTTClient(client_id="pub_client")
|
||||
ret_conn = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
|
||||
assert ret_conn == 0
|
||||
|
||||
|
|
|
@ -1,39 +1,40 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from yaml import CLoader as Loader
|
||||
|
||||
|
||||
import yaml
|
||||
from dacite import from_dict, Config, UnexpectedDataError
|
||||
from enum import StrEnum
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import BrokerConfig, ListenerConfig, Dictable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _test_broker_config():
|
||||
# Parse with dacite
|
||||
config = from_dict(data_class=BrokerConfig, data=data, config=Config(cast=[StrEnum]))
|
||||
|
||||
assert isinstance(config, BrokerConfig)
|
||||
assert isinstance(config.listeners['default'], ListenerConfig)
|
||||
assert isinstance(config.listeners['secure'], ListenerConfig)
|
||||
|
||||
default = config.listeners['default']
|
||||
secure = config.listeners['secure']
|
||||
secure_one = secure.copy()
|
||||
secure_two = secure.copy()
|
||||
|
||||
secure_one |= default
|
||||
|
||||
assert secure_one.max_connections == 50
|
||||
assert secure_one.bind == '0.0.0.0:8883'
|
||||
assert secure_one.cafile == Path('ca.key')
|
||||
|
||||
secure_two.update(default)
|
||||
assert secure_two.max_connections == 50
|
||||
assert secure_two.bind == '0.0.0.0:8883'
|
||||
assert secure_two.cafile == Path('ca.key')
|
||||
def test_entrypoint_broker_config(caplog):
|
||||
test_cfg: dict[str, Any] = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'auth': {
|
||||
'allow_anonymous': True
|
||||
}
|
||||
}
|
||||
if 'plugins' not in test_cfg:
|
||||
test_cfg['plugins'] = None
|
||||
# cfg: dict[str, Any] = yaml.load(config, Loader=Loader)
|
||||
|
||||
|
||||
broker_config = from_dict(data_class=BrokerConfig, data=test_cfg, config=Config(cast=[StrEnum]))
|
||||
assert isinstance(broker_config, BrokerConfig)
|
||||
|
||||
assert broker_config.plugins is None
|
||||
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue