Merge pull request #283 from ajmirsky/issue_73

embed amqtt into an existing server
pull/256/head
Andrew Mirsky 2025-08-08 21:14:12 -04:00 zatwierdzone przez GitHub
commit 2a7aa11524
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
5 zmienionych plików z 310 dodań i 21 usunięć

Wyświetl plik

@ -22,7 +22,7 @@ from amqtt.adapters import (
WebSocketsWriter,
WriterAdapter,
)
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
@ -52,6 +52,8 @@ class RetainedApplicationMessage(ApplicationMessage):
class Server:
"""Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle."""
def __init__(
self,
listener_name: str,
@ -89,11 +91,24 @@ class Server:
await self.instance.wait_closed()
class BrokerContext(BaseContext):
"""BrokerContext is used as the context passed to plugins interacting with the broker.
class ExternalServer(Server):
"""For external listeners, the connection lifecycle is handled by that implementation so these are no-ops."""
It act as an adapter to broker services from plugins developed for HBMQTT broker.
"""
def __init__(self) -> None:
super().__init__("aiohttp", None) # type: ignore[arg-type]
async def acquire_connection(self) -> None:
pass
def release_connection(self) -> None:
pass
async def close_instance(self) -> None:
pass
class BrokerContext(BaseContext):
"""BrokerContext is used as the context passed to plugins interacting with the broker."""
def __init__(self, broker: "Broker") -> None:
super().__init__()
@ -243,16 +258,24 @@ class Broker:
max_connections = listener.get("max_connections", -1)
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
# for listeners which are external, don't need to create a server
if listener.type == ListenerType.EXTERNAL:
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
# broker still needs to associate a new connection to the listener
self.logger.info(f"External listener exists for '{listener_name}' ")
self._servers[listener_name] = ExternalServer()
else:
# for tcp and websockets, start servers to listen for inbound connections
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
@ -385,6 +408,10 @@ class Broker:
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
"""Engage the broker in handling the data stream to/from an established connection."""
await self._client_connected(listener_name, reader, writer)
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
"""Handle a new client connection."""
server = self._servers.get(listener_name)

Wyświetl plik

@ -44,6 +44,7 @@ class ListenerType(StrEnum):
TCP = "tcp"
WS = "ws"
EXTERNAL = "external"
def __repr__(self) -> str:
"""Display the string value, instead of the enum member."""
@ -114,6 +115,8 @@ class ListenerConfig(Dictable):
certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key."""
reader: str | None = None
writer: str | None = None
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""

Wyświetl plik

@ -0,0 +1,180 @@
import asyncio
import io
import logging
import aiohttp
from aiohttp import web
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import ConnectError
logger = logging.getLogger(__name__)
MQTT_LISTENER_NAME = "myMqttListener"
async def hello(request):
"""get request handler"""
return web.Response(text="Hello, world")
class WebSocketResponseReader(ReaderAdapter):
"""Interface to allow mqtt broker to read from an aiohttp websocket connection."""
def __init__(self, ws: web.WebSocketResponse):
self.ws = ws
self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes:
"""
read 'n' bytes from the datastream, if < 0 read all available bytes
Raises:
BrokerPipeError : if reading on a closed websocket connection
"""
# continue until buffer contains at least the amount of data being requested
while not self.buffer or len(self.buffer) < n:
# if the websocket is closed
if self.ws.closed:
raise BrokenPipeError()
try:
# read from stream
msg = await asyncio.wait_for(self.ws.receive(), timeout=0.5)
# mqtt streams should always be binary...
if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
raise BrokenPipeError()
except asyncio.TimeoutError:
raise BrokenPipeError()
# return all bytes currently in the buffer
if n == -1:
result = bytes(self.buffer)
self.buffer.clear()
# return the requested number of bytes from the buffer
else:
result = self.buffer[:n]
del self.buffer[:n]
return result
def feed_eof(self) -> None:
pass
class WebSocketResponseWriter(WriterAdapter):
"""Interface to allow mqtt broker to write to an aiohttp websocket connection."""
def __init__(self, ws: web.WebSocketResponse, request: web.Request):
super().__init__()
self.ws = ws
# needed for `get_peer_info`
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
peer_name = request.transport.get_extra_info('peername')
if peer_name is not None:
self.client_ip, self.port = peer_name[0:2]
else:
self.client_ip, self.port = request.remote, 0
# interpret AF_INET6
self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip
self._stream = io.BytesIO(b"")
def write(self, data: bytes) -> None:
"""Add bytes to stream buffer."""
self._stream.write(data)
async def drain(self) -> None:
"""Send the collected bytes in the buffer to the websocket connection."""
data = self._stream.getvalue()
if data and len(data):
await self.ws.send_bytes(data)
self._stream = io.BytesIO(b"")
def get_peer_info(self) -> tuple[str, int] | None:
return self.client_ip, self.port
async def close(self) -> None:
# no clean up needed, stream will be gc along with instance
pass
async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse:
# establish connection by responding to the websocket request with the 'mqtt' protocol
ws = web.WebSocketResponse(protocols=['mqtt',])
await ws.prepare(request)
# access the broker created when the server started
b: Broker = request.app['broker']
# hand-off the websocket data stream to the broker for handling
# `listener_name` is the same name of the externalized listener in the broker config
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME)
logger.debug('websocket connection closed')
return ws
async def websocket_handler(request: web.Request) -> web.StreamResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
logging.info(msg)
logging.info("websocket connection closed")
return ws
def main():
# create an `aiohttp` server
lp = asyncio.get_event_loop()
app = web.Application()
app.add_routes(
[
web.get('/', hello), # http get request/response route
web.get('/ws', websocket_handler), # standard websocket handler
web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections
])
# create background task for running the `amqtt` broker
app.cleanup_ctx.append(run_broker)
# make sure that both `aiohttp` server and `amqtt` broker run in the same loop
# so the server can hand off the connection to the broker (prevents attached-to-a-different-loop `RuntimeError`)
web.run_app(app, loop=lp)
async def run_broker(_app):
"""App init function to start (and then shutdown) the `amqtt` broker.
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
# standard TCP connection as well as an externalized-listener
cfg = BrokerConfig(
listeners={
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL),
}
)
# make sure the `Broker` runs in the same loop as the aiohttp server
loop = asyncio.get_event_loop()
broker = Broker(config=cfg, loop=loop)
# store broker instance so that incoming requests can hand off processing of a datastream
_app['broker'] = broker
# start the broker
await broker.start()
# pass control back to web app
yield
# closing activities
await broker.shutdown()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
main()

Wyświetl plik

@ -1,11 +1,19 @@
import asyncio
import logging
import random
import threading
import time
from pathlib import Path
from threading import Thread
from typing import Any
from unittest.mock import MagicMock, call, patch
import pytest
from paho.mqtt import client as mqtt_client
import yaml
from paho.mqtt import client as paho_client
from yaml import Loader
from amqtt.broker import Broker
from amqtt.events import BrokerEvents
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
@ -40,7 +48,7 @@ async def test_paho_connect(broker, mock_plugin_manager):
assert rc == 0, f"Disconnect failed with result code {rc}"
test_complete.set()
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
test_client.enable_logger(paho_logger)
test_client.on_connect = on_connect
@ -76,7 +84,7 @@ async def test_paho_qos1(broker, mock_plugin_manager):
port = 1883
client_id = f'python-mqtt-{random.randint(0, 1000)}'
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
test_client.enable_logger(paho_logger)
test_client.connect(host, port)
@ -107,7 +115,7 @@ async def test_paho_qos2(broker, mock_plugin_manager):
port = 1883
client_id = f'python-mqtt-{random.randint(0, 1000)}'
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
test_client.enable_logger(paho_logger)
test_client.connect(host, port)
@ -124,3 +132,52 @@ async def test_paho_qos2(broker, mock_plugin_manager):
assert message.data == b"test message"
await sub_client.disconnect()
await asyncio.sleep(0.1)
def run_paho_client(flag):
client_id = 'websocket_client_1'
logging.info("creating paho client")
test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2,
transport='websockets',
client_id=client_id)
test_client.ws_set_options('')
logging.info("client connecting...")
test_client.connect('127.0.0.1', 8080)
logging.info("starting loop")
test_client.loop_start()
logging.info("client connected")
time.sleep(1)
logging.info("sending messages")
test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2)
time.sleep(1)
test_client.loop_stop()
test_client.disconnect()
flag.set()
@pytest.mark.asyncio
async def test_paho_ws():
path = Path('docs_test/test.amqtt.local.yaml')
with path.open() as f:
cfg: dict[str, Any] = yaml.load(f, Loader=Loader)
logger.warning(cfg)
broker = Broker(config=cfg)
await broker.start()
# python websockets and paho mqtt don't play well with each other in the same thread
flag = threading.Event()
thread = Thread(target=run_paho_client, args=(flag,))
thread.start()
await asyncio.sleep(5)
thread.join(1)
assert flag.is_set(), "paho thread didn't execute completely"
logging.info("client disconnected")
await broker.shutdown()

Wyświetl plik

@ -2,14 +2,15 @@ import asyncio
import logging
import signal
import subprocess
from multiprocessing import Process
from pathlib import Path
from samples.http_server_integration import main as http_server_main
import pytest
from amqtt.broker import Broker
from samples.client_publish import __main__ as client_publish_main
from samples.client_subscribe import __main__ as client_subscribe_main
from samples.client_keepalive import __main__ as client_keepalive_main
from amqtt.client import MQTTClient
from samples.broker_acl import config as broker_acl_config
from samples.broker_taboo import config as broker_taboo_config
@ -275,4 +276,25 @@ async def test_client_subscribe_plugin_taboo():
assert "ERROR" not in stderr.decode("utf-8")
assert "Exception" not in stderr.decode("utf-8")
await broker.shutdown()
await broker.shutdown()
@pytest.fixture
def external_http_server():
p = Process(target=http_server_main)
p.start()
yield p
p.terminate()
@pytest.mark.asyncio
async def test_external_http_server(external_http_server):
await asyncio.sleep(1)
client = MQTTClient(config={'auto_reconnect': False})
await client.connect("ws://127.0.0.1:8080/mqtt")
assert client.session is not None
await client.publish("my/topic", b'test message')
await client.disconnect()
# Send the interrupt signal
await asyncio.sleep(1)