kopia lustrzana https://github.com/Yakifo/amqtt
updating broker sys plugin to use new plugin load and config
rodzic
3228200db1
commit
c86e9f7dc4
|
@ -202,6 +202,9 @@ class PluginManager(Generic[C]):
|
||||||
config=DaciteConfig(strict=True))
|
config=DaciteConfig(strict=True))
|
||||||
except DaciteError as e:
|
except DaciteError as e:
|
||||||
raise PluginLoadError from e
|
raise PluginLoadError from e
|
||||||
|
except TypeError as e:
|
||||||
|
msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass."
|
||||||
|
raise PluginLoadError(msg) from e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.logger.debug(f"Loading plugin {plugin_path}")
|
self.logger.debug(f"Loading plugin {plugin_path}")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque # pylint: disable=C0412
|
from collections import deque # pylint: disable=C0412
|
||||||
|
from dataclasses import is_dataclass, dataclass
|
||||||
from typing import SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412
|
from typing import SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412
|
||||||
|
|
||||||
from amqtt.plugins.base import BasePlugin
|
from amqtt.plugins.base import BasePlugin
|
||||||
|
@ -51,6 +52,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||||
# Broker statistics initialization
|
# Broker statistics initialization
|
||||||
self._stats: dict[str, int] = {}
|
self._stats: dict[str, int] = {}
|
||||||
self._sys_handle: asyncio.Handle | None = None
|
self._sys_handle: asyncio.Handle | None = None
|
||||||
|
self._sys_interval: int = 0
|
||||||
|
|
||||||
def _clear_stats(self) -> None:
|
def _clear_stats(self) -> None:
|
||||||
"""Initialize broker statistics data structures."""
|
"""Initialize broker statistics data structures."""
|
||||||
|
@ -90,20 +92,24 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||||
|
|
||||||
# Start $SYS topics management
|
# Start $SYS topics management
|
||||||
try:
|
try:
|
||||||
sys_interval: int = 0
|
|
||||||
|
if is_dataclass(self.context.config):
|
||||||
|
self._sys_interval = self.context.config.sys_interval
|
||||||
|
else:
|
||||||
x = self.context.config.get("sys_interval") if self.context.config is not None else None
|
x = self.context.config.get("sys_interval") if self.context.config is not None else None
|
||||||
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
|
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
|
||||||
sys_interval = int(x)
|
self._sys_interval = int(x)
|
||||||
if sys_interval > 0:
|
if self._sys_interval > 0:
|
||||||
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
|
self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds")
|
||||||
self._sys_handle = (
|
self._sys_handle = (
|
||||||
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
|
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
|
||||||
if self.context.loop is not None
|
if self.context.loop is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.context.logger.debug("$SYS disabled")
|
self.context.logger.debug("$SYS disabled")
|
||||||
except KeyError:
|
except KeyError as e:
|
||||||
|
self.context.logger.debug("could not find 'sys_interval' key: {e!r}")
|
||||||
pass
|
pass
|
||||||
# 'sys_interval' config parameter not found
|
# 'sys_interval' config parameter not found
|
||||||
|
|
||||||
|
@ -160,15 +166,9 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||||
tasks.popleft()
|
tasks.popleft()
|
||||||
|
|
||||||
# Reschedule
|
# Reschedule
|
||||||
sys_interval: int = 0
|
self.context.logger.debug(f"Broadcast $SYS topics again in {self._sys_interval} seconds.")
|
||||||
x = self.context.config.get("sys_interval") if self.context.config is not None else None
|
|
||||||
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
|
|
||||||
sys_interval = int(x)
|
|
||||||
self.context.logger.debug("Broadcasting $SYS topics")
|
|
||||||
|
|
||||||
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
|
|
||||||
self._sys_handle = (
|
self._sys_handle = (
|
||||||
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
|
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
|
||||||
if self.context.loop is not None
|
if self.context.loop is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
@ -203,3 +203,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||||
"""Handle broker client disconnection."""
|
"""Handle broker client disconnection."""
|
||||||
self._stats[STAT_CLIENTS_CONNECTED] -= 1
|
self._stats[STAT_CLIENTS_CONNECTED] -= 1
|
||||||
self._stats[STAT_CLIENTS_DISCONNECTED] += 1
|
self._stats[STAT_CLIENTS_DISCONNECTED] += 1
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
sys_interval: int = 0
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from importlib.metadata import EntryPoint
|
from importlib.metadata import EntryPoint
|
||||||
|
from logging.config import dictConfig
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -9,8 +10,34 @@ from amqtt.broker import Broker
|
||||||
from amqtt.client import MQTTClient
|
from amqtt.client import MQTTClient
|
||||||
from amqtt.mqtt.constants import QOS_0
|
from amqtt.mqtt.constants import QOS_0
|
||||||
|
|
||||||
|
dictConfig({
|
||||||
|
'version': 1,
|
||||||
|
'disable_existing_loggers': False,
|
||||||
|
'formatters': {
|
||||||
|
'verbose': {
|
||||||
|
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'handlers': {
|
||||||
|
'console': {
|
||||||
|
'class': 'logging.StreamHandler',
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'formatter': 'verbose',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'loggers': {
|
||||||
|
'transitions': {
|
||||||
|
'level': 'WARNING',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
all_sys_topics = [
|
all_sys_topics = [
|
||||||
'$SYS/broker/version',
|
'$SYS/broker/version',
|
||||||
'$SYS/broker/load/bytes/received',
|
'$SYS/broker/load/bytes/received',
|
||||||
|
@ -38,7 +65,7 @@ all_sys_topics = [
|
||||||
|
|
||||||
# test broker sys
|
# test broker sys
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_broker_sys_plugin() -> None:
|
async def test_broker_sys_plugin_deprecated_config() -> None:
|
||||||
|
|
||||||
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
|
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
|
||||||
|
|
||||||
|
@ -88,3 +115,46 @@ async def test_broker_sys_plugin() -> None:
|
||||||
assert sys_msg_count > 1
|
assert sys_msg_count > 1
|
||||||
|
|
||||||
assert all(sys_topic_flags.values()), f'topic not received: {[ topic for topic, flag in sys_topic_flags.items() if not flag ]}'
|
assert all(sys_topic_flags.values()), f'topic not received: {[ topic for topic, flag in sys_topic_flags.items() if not flag ]}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broker_sys_plugin_config() -> None:
|
||||||
|
|
||||||
|
sys_topic_flags = {sys_topic:False for sys_topic in all_sys_topics}
|
||||||
|
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'plugins': [
|
||||||
|
{'amqtt.plugins.sys.broker.BrokerSysPlugin': {'sys_interval': 1}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||||
|
await broker.start()
|
||||||
|
client = MQTTClient()
|
||||||
|
await client.connect("mqtt://127.0.0.1:1883/")
|
||||||
|
await client.subscribe([("$SYS/#", QOS_0), ])
|
||||||
|
await client.publish('test/topic', b'my test message')
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
sys_msg_count = 0
|
||||||
|
try:
|
||||||
|
while sys_msg_count < 30:
|
||||||
|
message = await client.deliver_message(timeout_duration=1)
|
||||||
|
if '$SYS' in message.topic:
|
||||||
|
sys_msg_count += 1
|
||||||
|
assert message.topic in sys_topic_flags
|
||||||
|
sys_topic_flags[message.topic] = True
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.debug(f"TimeoutError after {sys_msg_count} messages")
|
||||||
|
|
||||||
|
await client.disconnect()
|
||||||
|
await broker.shutdown()
|
||||||
|
|
||||||
|
assert sys_msg_count > 1
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
sys_topic_flags.values()), f'topic not received: {[topic for topic, flag in sys_topic_flags.items() if not flag]}'
|
||||||
|
|
Ładowanie…
Reference in New Issue