kopia lustrzana https://github.com/Yakifo/amqtt
Yakifo/amqtt#73 : working example of using aiohttp server to receive the websocket connection
rodzic
76651e32f3
commit
ad4854df00
|
@ -99,7 +99,7 @@ class OtherServer(Server):
|
||||||
async def acquire_connection(self) -> None:
|
async def acquire_connection(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def release_connection(self) -> None:
|
def release_connection(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def close_instance(self) -> None:
|
async def close_instance(self) -> None:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from asyncio import InvalidStateError, QueueFull, QueueShutDown
|
from asyncio import InvalidStateError, QueueFull, QueueShutDown
|
||||||
|
@ -535,6 +536,10 @@ class ProtocolHandler(Generic[C]):
|
||||||
self.handle_read_timeout()
|
self.handle_read_timeout()
|
||||||
except NoDataError:
|
except NoDataError:
|
||||||
self.logger.debug(f"{self.session.client_id} No data available")
|
self.logger.debug(f"{self.session.client_id} No data available")
|
||||||
|
except RuntimeError:
|
||||||
|
self.logger.debug(f"{self.session.client_id} websocket closed")
|
||||||
|
traceback.print_exc()
|
||||||
|
break
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}")
|
self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}")
|
||||||
break
|
break
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
@ -19,16 +20,18 @@ class AIOWebSocketsReader(ReaderAdapter):
|
||||||
def __init__(self, ws: web.WebSocketResponse):
|
def __init__(self, ws: web.WebSocketResponse):
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
self.buffer = bytearray()
|
self.buffer = bytearray()
|
||||||
self.closed = False
|
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
print(f"attempting to read {n} bytes")
|
print(f"attempting to read {n} bytes")
|
||||||
while not self.buffer and not self.closed or len(self.buffer) < n:
|
while not self.buffer or len(self.buffer) < n:
|
||||||
|
if self.ws.closed:
|
||||||
|
raise BrokenPipeError()
|
||||||
msg = await self.ws.receive()
|
msg = await self.ws.receive()
|
||||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
self.buffer.extend(msg.data)
|
self.buffer.extend(msg.data)
|
||||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||||
self.closed = True
|
print("received a close message!")
|
||||||
|
break
|
||||||
print(f"buffer size: {len(self.buffer)}")
|
print(f"buffer size: {len(self.buffer)}")
|
||||||
if n == -1:
|
if n == -1:
|
||||||
result = bytes(self.buffer)
|
result = bytes(self.buffer)
|
||||||
|
@ -47,14 +50,17 @@ class AIOWebSocketsWriter(WriterAdapter):
|
||||||
def __init__(self, ws: web.WebSocketResponse):
|
def __init__(self, ws: web.WebSocketResponse):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
self._stream = io.BytesIO(b"")
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
print(f"broker wants to write data: {data}")
|
print(f"broker wants to write data: {data}")
|
||||||
self.ws.send_bytes(data)
|
self._stream.write(data)
|
||||||
|
|
||||||
|
|
||||||
async def drain(self) -> None:
|
async def drain(self) -> None:
|
||||||
pass
|
data = self._stream.getvalue()
|
||||||
|
if data and len(data):
|
||||||
|
await self.ws.send_bytes(data)
|
||||||
|
self._stream = io.BytesIO(b"")
|
||||||
|
|
||||||
def get_peer_info(self) -> tuple[str, int] | None:
|
def get_peer_info(self) -> tuple[str, int] | None:
|
||||||
return "aiohttp", 1234567
|
return "aiohttp", 1234567
|
||||||
|
|
|
@ -4,7 +4,7 @@ import random
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from paho.mqtt import client as mqtt_client
|
from paho.mqtt import client as paho_client
|
||||||
|
|
||||||
from amqtt.events import BrokerEvents
|
from amqtt.events import BrokerEvents
|
||||||
from amqtt.client import MQTTClient
|
from amqtt.client import MQTTClient
|
||||||
|
@ -40,7 +40,7 @@ async def test_paho_connect(broker, mock_plugin_manager):
|
||||||
assert rc == 0, f"Disconnect failed with result code {rc}"
|
assert rc == 0, f"Disconnect failed with result code {rc}"
|
||||||
test_complete.set()
|
test_complete.set()
|
||||||
|
|
||||||
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
||||||
test_client.enable_logger(paho_logger)
|
test_client.enable_logger(paho_logger)
|
||||||
|
|
||||||
test_client.on_connect = on_connect
|
test_client.on_connect = on_connect
|
||||||
|
@ -76,7 +76,7 @@ async def test_paho_qos1(broker, mock_plugin_manager):
|
||||||
port = 1883
|
port = 1883
|
||||||
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
||||||
|
|
||||||
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
||||||
test_client.enable_logger(paho_logger)
|
test_client.enable_logger(paho_logger)
|
||||||
|
|
||||||
test_client.connect(host, port)
|
test_client.connect(host, port)
|
||||||
|
@ -107,7 +107,7 @@ async def test_paho_qos2(broker, mock_plugin_manager):
|
||||||
port = 1883
|
port = 1883
|
||||||
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
||||||
|
|
||||||
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
||||||
test_client.enable_logger(paho_logger)
|
test_client.enable_logger(paho_logger)
|
||||||
|
|
||||||
test_client.connect(host, port)
|
test_client.connect(host, port)
|
||||||
|
@ -124,3 +124,21 @@ async def test_paho_qos2(broker, mock_plugin_manager):
|
||||||
assert message.data == b"test message"
|
assert message.data == b"test message"
|
||||||
await sub_client.disconnect()
|
await sub_client.disconnect()
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def test_paho_ws():
|
||||||
|
client_id = 'websocket_client_1'
|
||||||
|
test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2,
|
||||||
|
transport='websockets',
|
||||||
|
client_id=client_id)
|
||||||
|
|
||||||
|
test_client.ws_set_options('/ws')
|
||||||
|
|
||||||
|
test_client.connect('localhost', 8080)
|
||||||
|
test_client.loop_start()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
test_client.loop_stop()
|
Ładowanie…
Reference in New Issue