kopia lustrzana https://github.com/Yakifo/amqtt
adding event for broker when a message is being retained
rodzic
66e0ea2443
commit
c06e585be5
|
@ -759,8 +759,9 @@ class Broker:
|
||||||
self.logger.debug(f"Retaining message on topic {topic_name}")
|
self.logger.debug(f"Retaining message on topic {topic_name}")
|
||||||
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
||||||
|
|
||||||
kwargs = {'client_id': None, "retained_message": self._retained_messages[topic_name]}
|
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, method_kwargs=kwargs)
|
client_id=None,
|
||||||
|
retained_message=self._retained_messages[topic_name])
|
||||||
|
|
||||||
# [MQTT-3.3.1-10]
|
# [MQTT-3.3.1-10]
|
||||||
elif topic_name in self._retained_messages:
|
elif topic_name in self._retained_messages:
|
||||||
|
@ -769,8 +770,9 @@ class Broker:
|
||||||
cleared_message = self._retained_messages[topic_name]
|
cleared_message = self._retained_messages[topic_name]
|
||||||
cleared_message.data = None
|
cleared_message.data = None
|
||||||
|
|
||||||
kwargs = {'client_id': None, "retained_message": cleared_message}
|
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, method_kwargs=kwargs)
|
client_id=None,
|
||||||
|
retained_message=cleared_message)
|
||||||
|
|
||||||
del self._retained_messages[topic_name]
|
del self._retained_messages[topic_name]
|
||||||
|
|
||||||
|
@ -962,8 +964,9 @@ class Broker:
|
||||||
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
|
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
|
||||||
await target_session.retained_messages.put(retained_message)
|
await target_session.retained_messages.put(retained_message)
|
||||||
|
|
||||||
kwargs = {'client_id': target_session.client_id, "retained_message": retained_message}
|
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, method_kwargs=kwargs)
|
client_id=target_session.client_id,
|
||||||
|
retained_message=retained_message)
|
||||||
|
|
||||||
if self.logger.isEnabledFor(logging.DEBUG):
|
if self.logger.isEnabledFor(logging.DEBUG):
|
||||||
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")
|
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")
|
||||||
|
|
|
@ -199,13 +199,13 @@ max-returns = 10
|
||||||
|
|
||||||
# ----------------------------------- PYTEST -----------------------------------
|
# ----------------------------------- PYTEST -----------------------------------
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
timeout = 10
|
timeout = 10
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
addopts = ["--tb=short", "--capture=tee-sys"]
|
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||||
log_cli = true
|
#log_cli = true
|
||||||
log_level = "DEBUG"
|
log_level = "DEBUG"
|
||||||
|
|
||||||
# ------------------------------------ MYPY ------------------------------------
|
# ------------------------------------ MYPY ------------------------------------
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from importlib.metadata import EntryPoint
|
from importlib.metadata import EntryPoint
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
@ -15,13 +16,15 @@ from amqtt.broker import Broker, BrokerContext
|
||||||
from amqtt.client import MQTTClient
|
from amqtt.client import MQTTClient
|
||||||
from amqtt.errors import PluginInitError, PluginImportError
|
from amqtt.errors import PluginInitError, PluginImportError
|
||||||
from amqtt.events import MQTTEvents, BrokerEvents
|
from amqtt.events import MQTTEvents, BrokerEvents
|
||||||
from amqtt.mqtt.constants import QOS_0
|
from amqtt.mqtt.constants import QOS_0, QOS_2, QOS_1
|
||||||
from amqtt.plugins.base import BasePlugin
|
from amqtt.plugins.base import BasePlugin
|
||||||
from amqtt.contexts import BaseContext
|
from amqtt.contexts import BaseContext
|
||||||
|
from amqtt.plugins.persistence import RetainedMessage
|
||||||
|
|
||||||
_INVALID_METHOD: str = "invalid_foo"
|
_INVALID_METHOD: str = "invalid_foo"
|
||||||
_PLUGIN: str = "Plugin"
|
_PLUGIN: str = "Plugin"
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class _TestContext(BaseContext):
|
class _TestContext(BaseContext):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -159,7 +162,7 @@ async def test_all_plugin_events():
|
||||||
client = MQTTClient()
|
client = MQTTClient()
|
||||||
await client.connect("mqtt://127.0.0.1:1883/")
|
await client.connect("mqtt://127.0.0.1:1883/")
|
||||||
await client.subscribe([('my/test/topic', QOS_0),])
|
await client.subscribe([('my/test/topic', QOS_0),])
|
||||||
await client.publish('test/topic', b'my test message')
|
await client.publish('test/topic', b'my test message', retain=True)
|
||||||
await client.unsubscribe(['my/test/topic',])
|
await client.unsubscribe(['my/test/topic',])
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
@ -170,3 +173,70 @@ async def test_all_plugin_events():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
|
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
|
||||||
|
|
||||||
|
|
||||||
|
class RetainedMessageEventPlugin(BasePlugin[BrokerContext]):
|
||||||
|
"""A plugin to verify all events get sent to plugins."""
|
||||||
|
def __init__(self, context: BaseContext) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
self.topic_retained_message_flag = False
|
||||||
|
self.session_retained_message_flag = False
|
||||||
|
self.topic_clear_retained_message_flag = False
|
||||||
|
|
||||||
|
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedMessage) -> None:
|
||||||
|
"""retaining message event handler."""
|
||||||
|
if client_id:
|
||||||
|
session, _ = self.context.get_session(client_id)
|
||||||
|
assert session.transitions.state != "connected"
|
||||||
|
logger.debug("retained message event fired for offline client")
|
||||||
|
self.session_retained_message_flag = True
|
||||||
|
else:
|
||||||
|
if not retained_message.data:
|
||||||
|
logger.debug("retained message event fired for clearing a topic")
|
||||||
|
self.topic_clear_retained_message_flag = True
|
||||||
|
else:
|
||||||
|
logger.debug("retained message event fired for setting a topic")
|
||||||
|
self.topic_retained_message_flag = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retained_message_plugin_event():
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'sys_interval': 1,
|
||||||
|
'plugins':[{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False}},
|
||||||
|
{'tests.plugins.test_plugins.RetainedMessageEventPlugin': {}}]
|
||||||
|
}
|
||||||
|
|
||||||
|
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||||
|
|
||||||
|
await broker.start()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# make sure all expected events get triggered
|
||||||
|
client1 = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
await client1.connect("mqtt://myUsername@127.0.0.1:1883/", cleansession=False)
|
||||||
|
await client1.subscribe([('test/topic', QOS_1),])
|
||||||
|
await client1.publish('test/retained', b'message should be retained for test/retained', retain=True)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await client1.disconnect()
|
||||||
|
|
||||||
|
client2 = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
await client2.connect("mqtt://myOtherUsername@127.0.0.1:1883/", cleansession=True)
|
||||||
|
await client2.publish('test/topic', b'message should be retained for myUsername since subscription was qos > 0')
|
||||||
|
await client2.publish('test/retained', b'', retain=True) # should clear previously retained message
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await client2.disconnect()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# get the plugin so it doesn't get gc on shutdown
|
||||||
|
test_plugin = broker.plugins_manager.get_plugin('RetainedMessageEventPlugin')
|
||||||
|
await broker.shutdown()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert test_plugin.topic_retained_message_flag, "message to topic wasn't retained"
|
||||||
|
assert test_plugin.session_retained_message_flag, "message to disconnected client wasn't retained"
|
||||||
|
assert test_plugin.topic_clear_retained_message_flag, "message to retained topic wasn't cleared"
|
||||||
|
|
Ładowanie…
Reference in New Issue