From ad4854df007baf1d191420c50cd078ae661fc7d2 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 17:23:18 -0400 Subject: [PATCH] Yakifo/amqtt#73 : working example of using aiohttp server to receive the websocket connection --- amqtt/broker.py | 2 +- amqtt/mqtt/protocol/handler.py | 5 +++++ samples/web.py | 18 ++++++++++++------ tests/test_paho.py | 26 ++++++++++++++++++++++---- 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 3bbc8c0..028af12 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -99,7 +99,7 @@ class OtherServer(Server): async def acquire_connection(self) -> None: pass - async def release_connection(self) -> None: + def release_connection(self) -> None: pass async def close_instance(self) -> None: diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 66bc1cb..8c53dc1 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -1,4 +1,5 @@ import asyncio +import traceback try: from asyncio import InvalidStateError, QueueFull, QueueShutDown @@ -535,6 +536,10 @@ class ProtocolHandler(Generic[C]): self.handle_read_timeout() except NoDataError: self.logger.debug(f"{self.session.client_id} No data available") + except RuntimeError: + self.logger.debug(f"{self.session.client_id} websocket closed") + traceback.print_exc() + break except Exception as e: # noqa: BLE001 self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}") break diff --git a/samples/web.py b/samples/web.py index e54d4e5..5f8bc68 100644 --- a/samples/web.py +++ b/samples/web.py @@ -1,4 +1,5 @@ import asyncio +import io import logging import aiohttp @@ -19,16 +20,18 @@ class AIOWebSocketsReader(ReaderAdapter): def __init__(self, ws: web.WebSocketResponse): self.ws = ws self.buffer = bytearray() - self.closed = False async def read(self, n: int = -1) -> bytes: print(f"attempting to read {n} bytes") - while not self.buffer and not self.closed or len(self.buffer) < n: + while not self.buffer or len(self.buffer) < n: + if self.ws.closed: + raise BrokenPipeError() msg = await self.ws.receive() if msg.type == aiohttp.WSMsgType.BINARY: self.buffer.extend(msg.data) elif msg.type == aiohttp.WSMsgType.CLOSE: - self.closed = True + print("received a close message!") + break print(f"buffer size: {len(self.buffer)}") if n == -1: result = bytes(self.buffer) @@ -47,14 +50,17 @@ class AIOWebSocketsWriter(WriterAdapter): def __init__(self, ws: web.WebSocketResponse): super().__init__() self.ws = ws + self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: print(f"broker wants to write data: {data}") - self.ws.send_bytes(data) - + self._stream.write(data) async def drain(self) -> None: - pass + data = self._stream.getvalue() + if data and len(data): + await self.ws.send_bytes(data) + self._stream = io.BytesIO(b"") def get_peer_info(self) -> tuple[str, int] | None: return "aiohttp", 1234567 diff --git a/tests/test_paho.py b/tests/test_paho.py index be53b2a..0e4fb3c 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -4,7 +4,7 @@ import random from unittest.mock import MagicMock, call, patch import pytest -from paho.mqtt import client as mqtt_client +from paho.mqtt import client as paho_client from amqtt.events import BrokerEvents from amqtt.client import MQTTClient @@ -40,7 +40,7 @@ async def test_paho_connect(broker, mock_plugin_manager): assert rc == 0, f"Disconnect failed with result code {rc}" test_complete.set() - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.on_connect = on_connect @@ -76,7 +76,7 @@ async def test_paho_qos1(broker, mock_plugin_manager): port = 1883 client_id = f'python-mqtt-{random.randint(0, 1000)}' - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.connect(host, port) @@ -107,7 +107,7 @@ async def test_paho_qos2(broker, mock_plugin_manager): port = 1883 client_id = f'python-mqtt-{random.randint(0, 1000)}' - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.connect(host, port) @@ -124,3 +124,21 @@ async def test_paho_qos2(broker, mock_plugin_manager): assert message.data == b"test message" await sub_client.disconnect() await asyncio.sleep(0.1) + +async def test_paho_ws(): + client_id = 'websocket_client_1' + test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2, + transport='websockets', + client_id=client_id) + + test_client.ws_set_options('/ws') + + test_client.connect('localhost', 8080) + test_client.loop_start() + await asyncio.sleep(0.1) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + await asyncio.sleep(0.1) + test_client.loop_stop() \ No newline at end of file