kopia lustrzana https://github.com/Yakifo/amqtt
commit
2a7aa11524
|
@ -22,7 +22,7 @@ from amqtt.adapters import (
|
||||||
WebSocketsWriter,
|
WebSocketsWriter,
|
||||||
WriterAdapter,
|
WriterAdapter,
|
||||||
)
|
)
|
||||||
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig
|
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
|
||||||
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
|
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
|
||||||
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||||
|
@ -52,6 +52,8 @@ class RetainedApplicationMessage(ApplicationMessage):
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
|
"""Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
listener_name: str,
|
listener_name: str,
|
||||||
|
@ -89,11 +91,24 @@ class Server:
|
||||||
await self.instance.wait_closed()
|
await self.instance.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
class BrokerContext(BaseContext):
|
class ExternalServer(Server):
|
||||||
"""BrokerContext is used as the context passed to plugins interacting with the broker.
|
"""For external listeners, the connection lifecycle is handled by that implementation so these are no-ops."""
|
||||||
|
|
||||||
It act as an adapter to broker services from plugins developed for HBMQTT broker.
|
def __init__(self) -> None:
|
||||||
"""
|
super().__init__("aiohttp", None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
async def acquire_connection(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def release_connection(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def close_instance(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BrokerContext(BaseContext):
|
||||||
|
"""BrokerContext is used as the context passed to plugins interacting with the broker."""
|
||||||
|
|
||||||
def __init__(self, broker: "Broker") -> None:
|
def __init__(self, broker: "Broker") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -243,6 +258,14 @@ class Broker:
|
||||||
max_connections = listener.get("max_connections", -1)
|
max_connections = listener.get("max_connections", -1)
|
||||||
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
|
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
|
||||||
|
|
||||||
|
# for listeners which are external, don't need to create a server
|
||||||
|
if listener.type == ListenerType.EXTERNAL:
|
||||||
|
|
||||||
|
# broker still needs to associate a new connection to the listener
|
||||||
|
self.logger.info(f"External listener exists for '{listener_name}' ")
|
||||||
|
self._servers[listener_name] = ExternalServer()
|
||||||
|
else:
|
||||||
|
# for tcp and websockets, start servers to listen for inbound connections
|
||||||
try:
|
try:
|
||||||
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
@ -385,6 +408,10 @@ class Broker:
|
||||||
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
||||||
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||||
|
|
||||||
|
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
|
||||||
|
"""Engage the broker in handling the data stream to/from an established connection."""
|
||||||
|
await self._client_connected(listener_name, reader, writer)
|
||||||
|
|
||||||
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
||||||
"""Handle a new client connection."""
|
"""Handle a new client connection."""
|
||||||
server = self._servers.get(listener_name)
|
server = self._servers.get(listener_name)
|
||||||
|
|
|
@ -44,6 +44,7 @@ class ListenerType(StrEnum):
|
||||||
|
|
||||||
TCP = "tcp"
|
TCP = "tcp"
|
||||||
WS = "ws"
|
WS = "ws"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Display the string value, instead of the enum member."""
|
"""Display the string value, instead of the enum member."""
|
||||||
|
@ -114,6 +115,8 @@ class ListenerConfig(Dictable):
|
||||||
certificates needed to establish the certificate's authenticity.)"""
|
certificates needed to establish the certificate's authenticity.)"""
|
||||||
keyfile: str | Path | None = None
|
keyfile: str | Path | None = None
|
||||||
"""Full path to file in PEM format containing the server's private key."""
|
"""Full path to file in PEM format containing the server's private key."""
|
||||||
|
reader: str | None = None
|
||||||
|
writer: str | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Check config for errors and transform fields for easier use."""
|
"""Check config for errors and transform fields for easier use."""
|
||||||
|
|
|
@ -0,0 +1,180 @@
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||||
|
from amqtt.broker import Broker
|
||||||
|
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
|
||||||
|
from amqtt.errors import ConnectError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MQTT_LISTENER_NAME = "myMqttListener"
|
||||||
|
|
||||||
|
async def hello(request):
|
||||||
|
"""get request handler"""
|
||||||
|
return web.Response(text="Hello, world")
|
||||||
|
|
||||||
|
class WebSocketResponseReader(ReaderAdapter):
|
||||||
|
"""Interface to allow mqtt broker to read from an aiohttp websocket connection."""
|
||||||
|
|
||||||
|
def __init__(self, ws: web.WebSocketResponse):
|
||||||
|
self.ws = ws
|
||||||
|
self.buffer = bytearray()
|
||||||
|
|
||||||
|
async def read(self, n: int = -1) -> bytes:
|
||||||
|
"""
|
||||||
|
read 'n' bytes from the datastream, if < 0 read all available bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BrokerPipeError : if reading on a closed websocket connection
|
||||||
|
"""
|
||||||
|
# continue until buffer contains at least the amount of data being requested
|
||||||
|
while not self.buffer or len(self.buffer) < n:
|
||||||
|
# if the websocket is closed
|
||||||
|
if self.ws.closed:
|
||||||
|
raise BrokenPipeError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# read from stream
|
||||||
|
msg = await asyncio.wait_for(self.ws.receive(), timeout=0.5)
|
||||||
|
# mqtt streams should always be binary...
|
||||||
|
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
self.buffer.extend(msg.data)
|
||||||
|
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||||
|
raise BrokenPipeError()
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise BrokenPipeError()
|
||||||
|
|
||||||
|
# return all bytes currently in the buffer
|
||||||
|
if n == -1:
|
||||||
|
result = bytes(self.buffer)
|
||||||
|
self.buffer.clear()
|
||||||
|
# return the requested number of bytes from the buffer
|
||||||
|
else:
|
||||||
|
result = self.buffer[:n]
|
||||||
|
del self.buffer[:n]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def feed_eof(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class WebSocketResponseWriter(WriterAdapter):
|
||||||
|
"""Interface to allow mqtt broker to write to an aiohttp websocket connection."""
|
||||||
|
|
||||||
|
def __init__(self, ws: web.WebSocketResponse, request: web.Request):
|
||||||
|
super().__init__()
|
||||||
|
self.ws = ws
|
||||||
|
|
||||||
|
# needed for `get_peer_info`
|
||||||
|
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
|
||||||
|
peer_name = request.transport.get_extra_info('peername')
|
||||||
|
if peer_name is not None:
|
||||||
|
self.client_ip, self.port = peer_name[0:2]
|
||||||
|
else:
|
||||||
|
self.client_ip, self.port = request.remote, 0
|
||||||
|
|
||||||
|
# interpret AF_INET6
|
||||||
|
self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip
|
||||||
|
|
||||||
|
self._stream = io.BytesIO(b"")
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> None:
|
||||||
|
"""Add bytes to stream buffer."""
|
||||||
|
self._stream.write(data)
|
||||||
|
|
||||||
|
async def drain(self) -> None:
|
||||||
|
"""Send the collected bytes in the buffer to the websocket connection."""
|
||||||
|
data = self._stream.getvalue()
|
||||||
|
if data and len(data):
|
||||||
|
await self.ws.send_bytes(data)
|
||||||
|
self._stream = io.BytesIO(b"")
|
||||||
|
|
||||||
|
def get_peer_info(self) -> tuple[str, int] | None:
|
||||||
|
return self.client_ip, self.port
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
# no clean up needed, stream will be gc along with instance
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse:
|
||||||
|
|
||||||
|
# establish connection by responding to the websocket request with the 'mqtt' protocol
|
||||||
|
ws = web.WebSocketResponse(protocols=['mqtt',])
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
# access the broker created when the server started
|
||||||
|
b: Broker = request.app['broker']
|
||||||
|
|
||||||
|
# hand-off the websocket data stream to the broker for handling
|
||||||
|
# `listener_name` is the same name of the externalized listener in the broker config
|
||||||
|
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME)
|
||||||
|
|
||||||
|
logger.debug('websocket connection closed')
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.StreamResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for msg in ws:
|
||||||
|
logging.info(msg)
|
||||||
|
|
||||||
|
logging.info("websocket connection closed")
|
||||||
|
return ws
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# create an `aiohttp` server
|
||||||
|
lp = asyncio.get_event_loop()
|
||||||
|
app = web.Application()
|
||||||
|
app.add_routes(
|
||||||
|
[
|
||||||
|
web.get('/', hello), # http get request/response route
|
||||||
|
web.get('/ws', websocket_handler), # standard websocket handler
|
||||||
|
web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections
|
||||||
|
])
|
||||||
|
# create background task for running the `amqtt` broker
|
||||||
|
app.cleanup_ctx.append(run_broker)
|
||||||
|
|
||||||
|
# make sure that both `aiohttp` server and `amqtt` broker run in the same loop
|
||||||
|
# so the server can hand off the connection to the broker (prevents attached-to-a-different-loop `RuntimeError`)
|
||||||
|
web.run_app(app, loop=lp)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_broker(_app):
|
||||||
|
"""App init function to start (and then shutdown) the `amqtt` broker.
|
||||||
|
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
|
||||||
|
|
||||||
|
# standard TCP connection as well as an externalized-listener
|
||||||
|
cfg = BrokerConfig(
|
||||||
|
listeners={
|
||||||
|
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
|
||||||
|
MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure the `Broker` runs in the same loop as the aiohttp server
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
broker = Broker(config=cfg, loop=loop)
|
||||||
|
|
||||||
|
# store broker instance so that incoming requests can hand off processing of a datastream
|
||||||
|
_app['broker'] = broker
|
||||||
|
# start the broker
|
||||||
|
await broker.start()
|
||||||
|
|
||||||
|
# pass control back to web app
|
||||||
|
yield
|
||||||
|
|
||||||
|
# closing activities
|
||||||
|
await broker.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
|
@ -1,11 +1,19 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from paho.mqtt import client as mqtt_client
|
import yaml
|
||||||
|
from paho.mqtt import client as paho_client
|
||||||
|
from yaml import Loader
|
||||||
|
|
||||||
|
from amqtt.broker import Broker
|
||||||
from amqtt.events import BrokerEvents
|
from amqtt.events import BrokerEvents
|
||||||
from amqtt.client import MQTTClient
|
from amqtt.client import MQTTClient
|
||||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||||
|
@ -40,7 +48,7 @@ async def test_paho_connect(broker, mock_plugin_manager):
|
||||||
assert rc == 0, f"Disconnect failed with result code {rc}"
|
assert rc == 0, f"Disconnect failed with result code {rc}"
|
||||||
test_complete.set()
|
test_complete.set()
|
||||||
|
|
||||||
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
||||||
test_client.enable_logger(paho_logger)
|
test_client.enable_logger(paho_logger)
|
||||||
|
|
||||||
test_client.on_connect = on_connect
|
test_client.on_connect = on_connect
|
||||||
|
@ -76,7 +84,7 @@ async def test_paho_qos1(broker, mock_plugin_manager):
|
||||||
port = 1883
|
port = 1883
|
||||||
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
||||||
|
|
||||||
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
||||||
test_client.enable_logger(paho_logger)
|
test_client.enable_logger(paho_logger)
|
||||||
|
|
||||||
test_client.connect(host, port)
|
test_client.connect(host, port)
|
||||||
|
@ -107,7 +115,7 @@ async def test_paho_qos2(broker, mock_plugin_manager):
|
||||||
port = 1883
|
port = 1883
|
||||||
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
client_id = f'python-mqtt-{random.randint(0, 1000)}'
|
||||||
|
|
||||||
test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id)
|
||||||
test_client.enable_logger(paho_logger)
|
test_client.enable_logger(paho_logger)
|
||||||
|
|
||||||
test_client.connect(host, port)
|
test_client.connect(host, port)
|
||||||
|
@ -124,3 +132,52 @@ async def test_paho_qos2(broker, mock_plugin_manager):
|
||||||
assert message.data == b"test message"
|
assert message.data == b"test message"
|
||||||
await sub_client.disconnect()
|
await sub_client.disconnect()
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_paho_client(flag):
|
||||||
|
client_id = 'websocket_client_1'
|
||||||
|
logging.info("creating paho client")
|
||||||
|
test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2,
|
||||||
|
transport='websockets',
|
||||||
|
client_id=client_id)
|
||||||
|
|
||||||
|
test_client.ws_set_options('')
|
||||||
|
logging.info("client connecting...")
|
||||||
|
test_client.connect('127.0.0.1', 8080)
|
||||||
|
logging.info("starting loop")
|
||||||
|
test_client.loop_start()
|
||||||
|
logging.info("client connected")
|
||||||
|
time.sleep(1)
|
||||||
|
logging.info("sending messages")
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
test_client.publish("/qos2", "test message", qos=2)
|
||||||
|
time.sleep(1)
|
||||||
|
test_client.loop_stop()
|
||||||
|
test_client.disconnect()
|
||||||
|
flag.set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_paho_ws():
|
||||||
|
path = Path('docs_test/test.amqtt.local.yaml')
|
||||||
|
with path.open() as f:
|
||||||
|
cfg: dict[str, Any] = yaml.load(f, Loader=Loader)
|
||||||
|
logger.warning(cfg)
|
||||||
|
broker = Broker(config=cfg)
|
||||||
|
await broker.start()
|
||||||
|
|
||||||
|
# python websockets and paho mqtt don't play well with each other in the same thread
|
||||||
|
flag = threading.Event()
|
||||||
|
thread = Thread(target=run_paho_client, args=(flag,))
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
thread.join(1)
|
||||||
|
|
||||||
|
assert flag.is_set(), "paho thread didn't execute completely"
|
||||||
|
|
||||||
|
logging.info("client disconnected")
|
||||||
|
await broker.shutdown()
|
||||||
|
|
|
@ -2,14 +2,15 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
from multiprocessing import Process
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from samples.http_server_integration import main as http_server_main
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from amqtt.broker import Broker
|
from amqtt.broker import Broker
|
||||||
from samples.client_publish import __main__ as client_publish_main
|
from amqtt.client import MQTTClient
|
||||||
from samples.client_subscribe import __main__ as client_subscribe_main
|
|
||||||
from samples.client_keepalive import __main__ as client_keepalive_main
|
|
||||||
from samples.broker_acl import config as broker_acl_config
|
from samples.broker_acl import config as broker_acl_config
|
||||||
from samples.broker_taboo import config as broker_taboo_config
|
from samples.broker_taboo import config as broker_taboo_config
|
||||||
|
|
||||||
|
@ -276,3 +277,24 @@ async def test_client_subscribe_plugin_taboo():
|
||||||
assert "Exception" not in stderr.decode("utf-8")
|
assert "Exception" not in stderr.decode("utf-8")
|
||||||
|
|
||||||
await broker.shutdown()
|
await broker.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def external_http_server():
|
||||||
|
p = Process(target=http_server_main)
|
||||||
|
p.start()
|
||||||
|
yield p
|
||||||
|
p.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_external_http_server(external_http_server):
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
client = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
await client.connect("ws://127.0.0.1:8080/mqtt")
|
||||||
|
assert client.session is not None
|
||||||
|
await client.publish("my/topic", b'test message')
|
||||||
|
await client.disconnect()
|
||||||
|
# Send the interrupt signal
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
Ładowanie…
Reference in New Issue