Yakifo/amqtt#73 : refinement of external websocket server

pull/283/head
Andrew Mirsky 2025-08-04 19:19:51 -04:00
rodzic c1625e264e
commit 0141cddeeb
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
3 zmienionych plików z 95 dodań i 26 usunięć

Wyświetl plik

@ -403,6 +403,9 @@ class Broker:
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
await self._client_connected(listener_name, reader, writer)
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
"""Handle a new client connection.""" """Handle a new client connection."""
server = self._servers.get(listener_name) server = self._servers.get(listener_name)

Wyświetl plik

@ -8,25 +8,38 @@ from aiohttp import web
from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker from amqtt.broker import Broker
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import BrokerError
logger = logging.getLogger(__name__)
async def hello(request): async def hello(request):
return web.Response(text="Hello, world") return web.Response(text="Hello, world")
class AIOWebSocketsReader(ReaderAdapter): class WebSocketResponseReader(ReaderAdapter):
"""Interface to allow mqtt broker to read from an aiohttp websocket connection."""
def __init__(self, ws: web.WebSocketResponse): def __init__(self, ws: web.WebSocketResponse):
self.ws = ws self.ws = ws
self.buffer = bytearray() self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
"""
Raises:
BrokerPipeError : if reading on a closed websocket connection
"""
# continue until buffer contains at least the amount of data being requested
while not self.buffer or len(self.buffer) < n: while not self.buffer or len(self.buffer) < n:
if self.ws.closed: if self.ws.closed:
raise BrokenPipeError() raise BrokenPipeError()
msg = await self.ws.receive() try:
if msg.type == aiohttp.WSMsgType.BINARY: async with asyncio.timeout(0.5):
self.buffer.extend(msg.data) msg = await self.ws.receive()
elif msg.type == aiohttp.WSMsgType.CLOSE: if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
raise BrokenPipeError()
except asyncio.TimeoutError:
raise BrokenPipeError() raise BrokenPipeError()
if n == -1: if n == -1:
@ -41,11 +54,24 @@ class AIOWebSocketsReader(ReaderAdapter):
def feed_eof(self) -> None: def feed_eof(self) -> None:
pass pass
class AIOWebSocketsWriter(WriterAdapter): class WebSocketResponseWriter(WriterAdapter):
"""Interface to allow mqtt broker to write to an aiohttp websocket connection."""
def __init__(self, ws: web.WebSocketResponse): def __init__(self, ws: web.WebSocketResponse, request: web.Request):
super().__init__() super().__init__()
self.ws = ws self.ws = ws
# needed for `get_peer_info`
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
peer_name = request.transport.get_extra_info('peername')
if peer_name is not None:
self.client_ip, self.port = peer_name[0:2]
else:
self.client_ip, self.port = request.remote, 0
# interpret AF_INET6
self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip
self._stream = io.BytesIO(b"") self._stream = io.BytesIO(b"")
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
@ -58,21 +84,25 @@ class AIOWebSocketsWriter(WriterAdapter):
self._stream = io.BytesIO(b"") self._stream = io.BytesIO(b"")
def get_peer_info(self) -> tuple[str, int] | None: def get_peer_info(self) -> tuple[str, int] | None:
return "external", 0 return self.client_ip, self.port
async def close(self) -> None: async def close(self) -> None:
pass pass
async def websocket_handler(request): async def websocket_handler(request: web.Request) -> web.StreamResponse:
print()
# respond to the websocket request with the 'mqtt' protocol
ws = web.WebSocketResponse(protocols=['mqtt',]) ws = web.WebSocketResponse(protocols=['mqtt',])
await ws.prepare(request) await ws.prepare(request)
# access the broker created when the server started and notify the broker of this new connection
b: Broker = request.app['broker'] b: Broker = request.app['broker']
await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws))
print('websocket connection closed')
# send/receive data to the websocket. must pass the name of the externalized listener in the broker config
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), 'myAIOHttp')
logger.debug('websocket connection closed')
return ws return ws
@ -89,17 +119,16 @@ def main():
async def run_broker(_app): async def run_broker(_app):
"""https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
cfg = BrokerConfig( cfg = BrokerConfig(
listeners={ listeners={
'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), 'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL), 'myAIOHttp': ListenerConfig(type=ListenerType.EXTERNAL),
} }
) )
broker = Broker(config=cfg, loop=loop) broker = Broker(config=cfg, loop=loop)
_app['broker'] = broker _app['broker'] = broker
await broker.start() await broker.start()
@ -109,8 +138,6 @@ async def run_broker(_app):
await broker.shutdown() await broker.shutdown()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
main() main()

Wyświetl plik

@ -1,11 +1,19 @@
import asyncio import asyncio
import logging import logging
import random import random
import threading
import time
from pathlib import Path
from threading import Thread
from typing import Any
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch
import pytest import pytest
import yaml
from paho.mqtt import client as paho_client from paho.mqtt import client as paho_client
from yaml import Loader
from amqtt.broker import Broker
from amqtt.events import BrokerEvents from amqtt.events import BrokerEvents
from amqtt.client import MQTTClient from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2 from amqtt.mqtt.constants import QOS_1, QOS_2
@ -125,20 +133,51 @@ async def test_paho_qos2(broker, mock_plugin_manager):
await sub_client.disconnect() await sub_client.disconnect()
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
async def test_paho_ws():
def run_paho_client(flag):
client_id = 'websocket_client_1' client_id = 'websocket_client_1'
logging.info("creating paho client")
test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2, test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2,
transport='websockets', transport='websockets',
client_id=client_id) client_id=client_id)
test_client.ws_set_options('/ws') test_client.ws_set_options('')
logging.info("client connecting...")
test_client.connect('localhost', 8080) test_client.connect('127.0.0.1', 8080)
logging.info("starting loop")
test_client.loop_start() test_client.loop_start()
await asyncio.sleep(0.1) logging.info("client connected")
time.sleep(1)
logging.info("sending messages")
test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2)
await asyncio.sleep(0.1) time.sleep(1)
test_client.loop_stop() test_client.loop_stop()
test_client.disconnect()
flag.set()
@pytest.mark.asyncio
async def test_paho_ws():
path = Path('docs_test/test.amqtt.local.yaml')
with path.open() as f:
cfg: dict[str, Any] = yaml.load(f, Loader=Loader)
logger.warning(cfg)
broker = Broker(config=cfg)
await broker.start()
# python websockets and paho mqtt don't play well with each other in the same thread
flag = threading.Event()
thread = Thread(target=run_paho_client, args=(flag,))
thread.start()
await asyncio.sleep(5)
thread.join(1)
assert flag.is_set(), "paho thread didn't execute completely"
logging.info("client disconnected")
await broker.shutdown()