From 0141cddeebb9d55f8290af99d8a5a69a9916fc74 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 19:19:51 -0400 Subject: [PATCH] Yakifo/amqtt#73 : refinement of external websocket server --- amqtt/broker.py | 3 +++ samples/web.py | 65 ++++++++++++++++++++++++++++++++-------------- tests/test_paho.py | 53 ++++++++++++++++++++++++++++++++----- 3 files changed, 95 insertions(+), 26 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 1a797d6..7bcbeec 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -403,6 +403,9 @@ class Broker: async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) + async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None: + await self._client_connected(listener_name, reader, writer) + async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: """Handle a new client connection.""" server = self._servers.get(listener_name) diff --git a/samples/web.py b/samples/web.py index 552953f..787ccf4 100644 --- a/samples/web.py +++ b/samples/web.py @@ -8,25 +8,38 @@ from aiohttp import web from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.broker import Broker from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType -from amqtt.errors import BrokerError + + +logger = logging.getLogger(__name__) async def hello(request): return web.Response(text="Hello, world") -class AIOWebSocketsReader(ReaderAdapter): +class WebSocketResponseReader(ReaderAdapter): + """Interface to allow mqtt broker to read from an aiohttp websocket connection.""" + def __init__(self, ws: web.WebSocketResponse): self.ws = ws self.buffer = bytearray() async def read(self, n: int = -1) -> bytes: + """ + Raises: + BrokerPipeError : if reading on a closed websocket connection + """ + # continue until buffer contains at least the amount of data being requested while not self.buffer or len(self.buffer) < n: if self.ws.closed: raise BrokenPipeError() - msg = await self.ws.receive() - if msg.type == aiohttp.WSMsgType.BINARY: - self.buffer.extend(msg.data) - elif msg.type == aiohttp.WSMsgType.CLOSE: + try: + async with asyncio.timeout(0.5): + msg = await self.ws.receive() + if msg.type == aiohttp.WSMsgType.BINARY: + self.buffer.extend(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSE: + raise BrokenPipeError() + except asyncio.TimeoutError: raise BrokenPipeError() if n == -1: @@ -41,11 +54,24 @@ class AIOWebSocketsReader(ReaderAdapter): def feed_eof(self) -> None: pass -class AIOWebSocketsWriter(WriterAdapter): +class WebSocketResponseWriter(WriterAdapter): + """Interface to allow mqtt broker to write to an aiohttp websocket connection.""" - def __init__(self, ws: web.WebSocketResponse): + def __init__(self, ws: web.WebSocketResponse, request: web.Request): super().__init__() self.ws = ws + + # needed for `get_peer_info` + # https://docs.python.org/3/library/socket.html#socket.socket.getpeername + peer_name = request.transport.get_extra_info('peername') + if peer_name is not None: + self.client_ip, self.port = peer_name[0:2] + else: + self.client_ip, self.port = request.remote, 0 + + # interpret AF_INET6 + self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip + self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: @@ -58,21 +84,25 @@ class AIOWebSocketsWriter(WriterAdapter): self._stream = io.BytesIO(b"") def get_peer_info(self) -> tuple[str, int] | None: - return "external", 0 + return self.client_ip, self.port async def close(self) -> None: pass -async def websocket_handler(request): - print() +async def websocket_handler(request: web.Request) -> web.StreamResponse: + + # respond to the websocket request with the 'mqtt' protocol ws = web.WebSocketResponse(protocols=['mqtt',]) await ws.prepare(request) + # access the broker created when the server started and notify the broker of this new connection b: Broker = request.app['broker'] - await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws)) - print('websocket connection closed') + # send/receive data to the websocket. must pass the name of the externalized listener in the broker config + await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), 'myAIOHttp') + + logger.debug('websocket connection closed') return ws @@ -89,17 +119,16 @@ def main(): async def run_broker(_app): + """https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks""" loop = asyncio.get_event_loop() cfg = BrokerConfig( listeners={ - 'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), - 'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL), + 'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'), + 'myAIOHttp': ListenerConfig(type=ListenerType.EXTERNAL), } ) - - broker = Broker(config=cfg, loop=loop) _app['broker'] = broker await broker.start() @@ -109,8 +138,6 @@ async def run_broker(_app): await broker.shutdown() - - if __name__ == '__main__': logging.basicConfig(level=logging.INFO) main() diff --git a/tests/test_paho.py b/tests/test_paho.py index 0e4fb3c..9a254a0 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -1,11 +1,19 @@ import asyncio import logging import random +import threading +import time +from pathlib import Path +from threading import Thread +from typing import Any from unittest.mock import MagicMock, call, patch import pytest +import yaml from paho.mqtt import client as paho_client +from yaml import Loader +from amqtt.broker import Broker from amqtt.events import BrokerEvents from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 @@ -125,20 +133,51 @@ async def test_paho_qos2(broker, mock_plugin_manager): await sub_client.disconnect() await asyncio.sleep(0.1) -async def test_paho_ws(): + + +def run_paho_client(flag): client_id = 'websocket_client_1' + logging.info("creating paho client") test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2, transport='websockets', client_id=client_id) - test_client.ws_set_options('/ws') - - test_client.connect('localhost', 8080) + test_client.ws_set_options('') + logging.info("client connecting...") + test_client.connect('127.0.0.1', 8080) + logging.info("starting loop") test_client.loop_start() - await asyncio.sleep(0.1) + logging.info("client connected") + time.sleep(1) + logging.info("sending messages") test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2) - await asyncio.sleep(0.1) - test_client.loop_stop() \ No newline at end of file + time.sleep(1) + test_client.loop_stop() + test_client.disconnect() + flag.set() + + +@pytest.mark.asyncio +async def test_paho_ws(): + path = Path('docs_test/test.amqtt.local.yaml') + with path.open() as f: + cfg: dict[str, Any] = yaml.load(f, Loader=Loader) + logger.warning(cfg) + broker = Broker(config=cfg) + await broker.start() + + # python websockets and paho mqtt don't play well with each other in the same thread + flag = threading.Event() + thread = Thread(target=run_paho_client, args=(flag,)) + thread.start() + + await asyncio.sleep(5) + thread.join(1) + + assert flag.is_set(), "paho thread didn't execute completely" + + logging.info("client disconnected") + await broker.shutdown()