updating broker sys plugin to use new plugin load and config

pull/240/head
Andrew Mirsky 2025-06-25 14:48:39 -04:00
rodzic 3228200db1
commit c86e9f7dc4
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
3 zmienionych plików z 94 dodań i 17 usunięć

Wyświetl plik

@ -202,6 +202,9 @@ class PluginManager(Generic[C]):
config=DaciteConfig(strict=True))
except DaciteError as e:
raise PluginLoadError from e
except TypeError as e:
msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass."
raise PluginLoadError(msg) from e
try:
self.logger.debug(f"Loading plugin {plugin_path}")

Wyświetl plik

@ -1,5 +1,6 @@
import asyncio
from collections import deque # pylint: disable=C0412
from dataclasses import is_dataclass, dataclass
from typing import SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412
from amqtt.plugins.base import BasePlugin
@ -51,6 +52,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
# Broker statistics initialization
self._stats: dict[str, int] = {}
self._sys_handle: asyncio.Handle | None = None
self._sys_interval: int = 0
def _clear_stats(self) -> None:
"""Initialize broker statistics data structures."""
@ -90,20 +92,24 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
# Start $SYS topics management
try:
sys_interval: int = 0
x = self.context.config.get("sys_interval") if self.context.config is not None else None
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
sys_interval = int(x)
if sys_interval > 0:
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
if is_dataclass(self.context.config):
self._sys_interval = self.context.config.sys_interval
else:
x = self.context.config.get("sys_interval") if self.context.config is not None else None
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
self._sys_interval = int(x)
if self._sys_interval > 0:
self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds")
self._sys_handle = (
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
if self.context.loop is not None
else None
)
else:
self.context.logger.debug("$SYS disabled")
except KeyError:
except KeyError as e:
self.context.logger.debug("could not find 'sys_interval' key: {e!r}")
pass
# 'sys_interval' config parameter not found
@ -160,15 +166,9 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
tasks.popleft()
# Reschedule
sys_interval: int = 0
x = self.context.config.get("sys_interval") if self.context.config is not None else None
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
sys_interval = int(x)
self.context.logger.debug("Broadcasting $SYS topics")
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
self.context.logger.debug(f"Broadcast $SYS topics again in {self._sys_interval} seconds.")
self._sys_handle = (
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
if self.context.loop is not None
else None
)
@ -203,3 +203,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
"""Handle broker client disconnection."""
self._stats[STAT_CLIENTS_CONNECTED] -= 1
self._stats[STAT_CLIENTS_DISCONNECTED] += 1
@dataclass
class Config:
sys_interval: int = 0

Wyświetl plik

@ -1,6 +1,7 @@
import asyncio
import logging
from importlib.metadata import EntryPoint
from logging.config import dictConfig
from unittest.mock import patch
import pytest
@ -9,8 +10,34 @@ from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_0
dictConfig({
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'verbose': {
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
}
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
'formatter': 'verbose',
}
},
'loggers': {
'transitions': {
'level': 'WARNING',
}
}
})
# logging.basicConfig(level=logging.DEBUG, format=formatter)
logger = logging.getLogger(__name__)
all_sys_topics = [
'$SYS/broker/version',
'$SYS/broker/load/bytes/received',
@ -38,7 +65,7 @@ all_sys_topics = [
# test broker sys
@pytest.mark.asyncio
async def test_broker_sys_plugin() -> None:
async def test_broker_sys_plugin_deprecated_config() -> None:
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
@ -88,3 +115,46 @@ async def test_broker_sys_plugin() -> None:
assert sys_msg_count > 1
assert all(sys_topic_flags.values()), f'topic not received: {[ topic for topic, flag in sys_topic_flags.items() if not flag ]}'
@pytest.mark.asyncio
async def test_broker_sys_plugin_config() -> None:
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'plugins': [
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
]
}
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
await broker.start()
client = MQTTClient()
await client.connect("mqtt://127.0.0.1:1883/")
await client.subscribe([("$SYS/#", QOS_0), ])
await client.publish('test/topic', b'my test message')
await asyncio.sleep(2)
sys_msg_count = 0
try:
while sys_msg_count < 30:
message = await client.deliver_message(timeout_duration=1)
if '$SYS' in message.topic:
sys_msg_count += 1
assert message.topic in sys_topic_flags
sys_topic_flags[message.topic] = True
except asyncio.TimeoutError:
logger.debug(f"TimeoutError after {sys_msg_count} messages")
await client.disconnect()
await broker.shutdown()
assert sys_msg_count > 1
assert all(
sys_topic_flags.values()), f'topic not received: {[topic for topic, flag in sys_topic_flags.items() if not flag]}'