kopia lustrzana https://github.com/Yakifo/amqtt
Yakifo/amqtt#73 : refinement of external websocket server
rodzic
c1625e264e
commit
0141cddeeb
|
@ -403,6 +403,9 @@ class Broker:
|
|||
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
||||
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||
|
||||
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
|
||||
await self._client_connected(listener_name, reader, writer)
|
||||
|
||||
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
||||
"""Handle a new client connection."""
|
||||
server = self._servers.get(listener_name)
|
||||
|
|
|
@ -8,25 +8,38 @@ from aiohttp import web
|
|||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
|
||||
from amqtt.errors import BrokerError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def hello(request):
|
||||
return web.Response(text="Hello, world")
|
||||
|
||||
class AIOWebSocketsReader(ReaderAdapter):
|
||||
class WebSocketResponseReader(ReaderAdapter):
|
||||
"""Interface to allow mqtt broker to read from an aiohttp websocket connection."""
|
||||
|
||||
def __init__(self, ws: web.WebSocketResponse):
|
||||
self.ws = ws
|
||||
self.buffer = bytearray()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
Raises:
|
||||
BrokerPipeError : if reading on a closed websocket connection
|
||||
"""
|
||||
# continue until buffer contains at least the amount of data being requested
|
||||
while not self.buffer or len(self.buffer) < n:
|
||||
if self.ws.closed:
|
||||
raise BrokenPipeError()
|
||||
msg = await self.ws.receive()
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self.buffer.extend(msg.data)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
try:
|
||||
async with asyncio.timeout(0.5):
|
||||
msg = await self.ws.receive()
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self.buffer.extend(msg.data)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
raise BrokenPipeError()
|
||||
except asyncio.TimeoutError:
|
||||
raise BrokenPipeError()
|
||||
|
||||
if n == -1:
|
||||
|
@ -41,11 +54,24 @@ class AIOWebSocketsReader(ReaderAdapter):
|
|||
def feed_eof(self) -> None:
|
||||
pass
|
||||
|
||||
class AIOWebSocketsWriter(WriterAdapter):
|
||||
class WebSocketResponseWriter(WriterAdapter):
|
||||
"""Interface to allow mqtt broker to write to an aiohttp websocket connection."""
|
||||
|
||||
def __init__(self, ws: web.WebSocketResponse):
|
||||
def __init__(self, ws: web.WebSocketResponse, request: web.Request):
|
||||
super().__init__()
|
||||
self.ws = ws
|
||||
|
||||
# needed for `get_peer_info`
|
||||
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
|
||||
peer_name = request.transport.get_extra_info('peername')
|
||||
if peer_name is not None:
|
||||
self.client_ip, self.port = peer_name[0:2]
|
||||
else:
|
||||
self.client_ip, self.port = request.remote, 0
|
||||
|
||||
# interpret AF_INET6
|
||||
self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip
|
||||
|
||||
self._stream = io.BytesIO(b"")
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
|
@ -58,21 +84,25 @@ class AIOWebSocketsWriter(WriterAdapter):
|
|||
self._stream = io.BytesIO(b"")
|
||||
|
||||
def get_peer_info(self) -> tuple[str, int] | None:
|
||||
return "external", 0
|
||||
return self.client_ip, self.port
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
async def websocket_handler(request):
|
||||
print()
|
||||
async def websocket_handler(request: web.Request) -> web.StreamResponse:
|
||||
|
||||
# respond to the websocket request with the 'mqtt' protocol
|
||||
ws = web.WebSocketResponse(protocols=['mqtt',])
|
||||
await ws.prepare(request)
|
||||
|
||||
# access the broker created when the server started and notify the broker of this new connection
|
||||
b: Broker = request.app['broker']
|
||||
await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws))
|
||||
print('websocket connection closed')
|
||||
|
||||
# send/receive data to the websocket. must pass the name of the externalized listener in the broker config
|
||||
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), 'myAIOHttp')
|
||||
|
||||
logger.debug('websocket connection closed')
|
||||
return ws
|
||||
|
||||
|
||||
|
@ -89,17 +119,16 @@ def main():
|
|||
|
||||
|
||||
async def run_broker(_app):
|
||||
"""https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
cfg = BrokerConfig(
|
||||
listeners={
|
||||
'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'),
|
||||
'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL),
|
||||
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
|
||||
'myAIOHttp': ListenerConfig(type=ListenerType.EXTERNAL),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
broker = Broker(config=cfg, loop=loop)
|
||||
_app['broker'] = broker
|
||||
await broker.start()
|
||||
|
@ -109,8 +138,6 @@ async def run_broker(_app):
|
|||
await broker.shutdown()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from paho.mqtt import client as paho_client
|
||||
from yaml import Loader
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.events import BrokerEvents
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||
|
@ -125,20 +133,51 @@ async def test_paho_qos2(broker, mock_plugin_manager):
|
|||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def test_paho_ws():
|
||||
|
||||
|
||||
def run_paho_client(flag):
|
||||
client_id = 'websocket_client_1'
|
||||
logging.info("creating paho client")
|
||||
test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2,
|
||||
transport='websockets',
|
||||
client_id=client_id)
|
||||
|
||||
test_client.ws_set_options('/ws')
|
||||
|
||||
test_client.connect('localhost', 8080)
|
||||
test_client.ws_set_options('')
|
||||
logging.info("client connecting...")
|
||||
test_client.connect('127.0.0.1', 8080)
|
||||
logging.info("starting loop")
|
||||
test_client.loop_start()
|
||||
await asyncio.sleep(0.1)
|
||||
logging.info("client connected")
|
||||
time.sleep(1)
|
||||
logging.info("sending messages")
|
||||
test_client.publish("/qos2", "test message", qos=2)
|
||||
test_client.publish("/qos2", "test message", qos=2)
|
||||
test_client.publish("/qos2", "test message", qos=2)
|
||||
test_client.publish("/qos2", "test message", qos=2)
|
||||
await asyncio.sleep(0.1)
|
||||
test_client.loop_stop()
|
||||
time.sleep(1)
|
||||
test_client.loop_stop()
|
||||
test_client.disconnect()
|
||||
flag.set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paho_ws():
|
||||
path = Path('docs_test/test.amqtt.local.yaml')
|
||||
with path.open() as f:
|
||||
cfg: dict[str, Any] = yaml.load(f, Loader=Loader)
|
||||
logger.warning(cfg)
|
||||
broker = Broker(config=cfg)
|
||||
await broker.start()
|
||||
|
||||
# python websockets and paho mqtt don't play well with each other in the same thread
|
||||
flag = threading.Event()
|
||||
thread = Thread(target=run_paho_client, args=(flag,))
|
||||
thread.start()
|
||||
|
||||
await asyncio.sleep(5)
|
||||
thread.join(1)
|
||||
|
||||
assert flag.is_set(), "paho thread didn't execute completely"
|
||||
|
||||
logging.info("client disconnected")
|
||||
await broker.shutdown()
|
||||
|
|
Ładowanie…
Reference in New Issue