kopia lustrzana https://github.com/Yakifo/amqtt
commit
2637127b41
|
@ -4,6 +4,7 @@ source = bumper
|
|||
|
||||
omit =
|
||||
tests/*
|
||||
amqtt/scripts/*.py
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
|
|
|
@ -32,8 +32,11 @@ jobs:
|
|||
cache-local-path: ${{ env.UV_CACHE_DIR }}
|
||||
python-version: "3.13"
|
||||
|
||||
- name: install openldap dependencies
|
||||
run: sudo apt-get install -y libldap2-dev libsasl2-dev
|
||||
|
||||
- name: 🏗 Install the project
|
||||
run: uv sync --locked --dev
|
||||
run: uv sync --locked --dev --all-extras
|
||||
|
||||
- name: Run mypy
|
||||
run: uv run --frozen mypy ${{ env.PROJECT_PATH }}/
|
||||
|
@ -68,8 +71,11 @@ jobs:
|
|||
cache-local-path: ${{ env.UV_CACHE_DIR }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: install openldap dependencies
|
||||
run: sudo apt-get install -y libldap2-dev libsasl2-dev
|
||||
|
||||
- name: 🏗 Install the project
|
||||
run: uv sync --locked --dev
|
||||
run: uv sync --locked --dev --all-extras
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run --frozen pytest tests/ --cov=./ --cov-report=xml --junitxml=pytest-report.xml
|
||||
|
|
|
@ -4,6 +4,9 @@ __pycache__
|
|||
node_modules
|
||||
.vite
|
||||
*.pem
|
||||
*.crt
|
||||
*.key
|
||||
*.patch
|
||||
|
||||
#------- Environment Files -------
|
||||
.python-version
|
||||
|
|
|
@ -1,15 +1,22 @@
|
|||
version: 2
|
||||
version: 2
|
||||
|
||||
build:
|
||||
os: "ubuntu-24.04"
|
||||
tools:
|
||||
python: "3.13"
|
||||
jobs:
|
||||
pre_install:
|
||||
- pip install --upgrade pip
|
||||
- pip install uv
|
||||
- uv pip install --group dev --group docs
|
||||
- uv run pytest
|
||||
build:
|
||||
os: "ubuntu-24.04"
|
||||
tools:
|
||||
python: "3.13"
|
||||
apt_packages:
|
||||
- libldap2-dev
|
||||
- libsasl2-dev
|
||||
jobs:
|
||||
pre_install:
|
||||
- pip install --upgrade pip
|
||||
- pip install uv
|
||||
- uv venv
|
||||
- uv pip install --group dev --group docs ".[contrib]"
|
||||
- uv run pytest --mock-docker=true
|
||||
build:
|
||||
html:
|
||||
- uv run python -m mkdocs build --clean --site-dir $READTHEDOCS_OUTPUT/html --config-file mkdocs.rtd.yml
|
||||
|
||||
mkdocs:
|
||||
configuration: mkdocs.rtd.yml
|
||||
mkdocs:
|
||||
configuration: mkdocs.rtd.yml
|
||||
|
|
2
Makefile
2
Makefile
|
@ -1,7 +1,7 @@
|
|||
# Image name and tag
|
||||
IMAGE_NAME := amqtt
|
||||
IMAGE_TAG := latest
|
||||
VERSION_TAG := 0.11.2
|
||||
VERSION_TAG := 0.11.3
|
||||
REGISTRY := amqtt/$(IMAGE_NAME)
|
||||
|
||||
# Platforms to build for
|
||||
|
|
15
README.md
15
README.md
|
@ -14,10 +14,21 @@
|
|||
## Features
|
||||
|
||||
- Full set of [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) protocol specifications
|
||||
- Communication over TCP and/or websocket, including support for SSL/TLS
|
||||
- Communication over multiple TCP and/or websocket ports, including support for SSL/TLS
|
||||
- Support QoS 0, QoS 1 and QoS 2 messages flow
|
||||
- Client auto-reconnection on network lost
|
||||
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
|
||||
|
||||
- Plugin framework for functionality expansion; included plugins:
|
||||
- `$SYS` topic publishing
|
||||
- AWS IOT-style shadow states
|
||||
- x509 certificate authentication (including cli cert creation)
|
||||
- Secure file-based password authentication
|
||||
- Configuration-based topic authorization
|
||||
- MySQL, Postgres & SQLite user and/or topic auth (including cli manager)
|
||||
- External server (HTTP) user and/or topic auth
|
||||
- LDAP user and/or topic auth
|
||||
- JWT user and/or topic auth
|
||||
- Fail over session persistence
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
"""INIT."""
|
||||
|
||||
__version__ = "0.11.2"
|
||||
__version__ = "0.11.3"
|
||||
|
|
|
@ -3,6 +3,8 @@ from asyncio import StreamReader, StreamWriter
|
|||
from contextlib import suppress
|
||||
import io
|
||||
import logging
|
||||
import ssl
|
||||
from typing import cast
|
||||
|
||||
from websockets import ConnectionClosed
|
||||
from websockets.asyncio.connection import Connection
|
||||
|
@ -52,6 +54,11 @@ class WriterAdapter(ABC):
|
|||
"""Return peer socket info (remote address and remote port as tuple)."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_ssl_info(self) -> ssl.SSLObject | None:
|
||||
"""Return peer certificate information (if available) used to establish a TLS session."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the protocol connection."""
|
||||
|
@ -121,6 +128,9 @@ class WebSocketsWriter(WriterAdapter):
|
|||
remote_address: tuple[str, int] | None = self._protocol.remote_address[:2]
|
||||
return remote_address
|
||||
|
||||
def get_ssl_info(self) -> ssl.SSLObject | None:
|
||||
return cast("ssl.SSLObject", self._protocol.transport.get_extra_info("ssl_object"))
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._protocol.close()
|
||||
|
||||
|
@ -170,6 +180,9 @@ class StreamWriterAdapter(WriterAdapter):
|
|||
extra_info = self._writer.get_extra_info("peername")
|
||||
return extra_info[0], extra_info[1]
|
||||
|
||||
def get_ssl_info(self) -> ssl.SSLObject | None:
|
||||
return cast("ssl.SSLObject", self._writer.get_extra_info("ssl_object"))
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.is_closed:
|
||||
self.is_closed = True # we first mark this closed so yields below don't cause races with waiting writes
|
||||
|
@ -204,6 +217,9 @@ class BufferWriter(WriterAdapter):
|
|||
This adapter simply adapts writing to a byte buffer.
|
||||
"""
|
||||
|
||||
def get_ssl_info(self) -> ssl.SSLObject | None:
|
||||
return None
|
||||
|
||||
def __init__(self, buffer: bytes = b"") -> None:
|
||||
self._stream = io.BytesIO(buffer)
|
||||
|
||||
|
|
292
amqtt/broker.py
292
amqtt/broker.py
|
@ -2,12 +2,12 @@ import asyncio
|
|||
from asyncio import CancelledError, futures
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
import copy
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from math import floor
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
from typing import Any, ClassVar, TypeAlias
|
||||
|
||||
from transitions import Machine, MachineError
|
||||
|
@ -22,24 +22,19 @@ from amqtt.adapters import (
|
|||
WebSocketsWriter,
|
||||
WriterAdapter,
|
||||
)
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
|
||||
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
|
||||
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
|
||||
from amqtt.utils import format_client_message, gen_client_id
|
||||
|
||||
from .events import BrokerEvents
|
||||
from .mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from .mqtt.disconnect import DisconnectPacket
|
||||
from .plugins.manager import PluginManager
|
||||
|
||||
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
||||
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
||||
|
||||
|
||||
_defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
|
||||
|
||||
|
||||
# Default port numbers
|
||||
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
||||
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
||||
|
@ -57,6 +52,8 @@ class RetainedApplicationMessage(ApplicationMessage):
|
|||
|
||||
|
||||
class Server:
|
||||
"""Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
listener_name: str,
|
||||
|
@ -94,28 +91,46 @@ class Server:
|
|||
await self.instance.wait_closed()
|
||||
|
||||
|
||||
class BrokerContext(BaseContext):
|
||||
"""BrokerContext is used as the context passed to plugins interacting with the broker.
|
||||
class ExternalServer(Server):
|
||||
"""For external listeners, the connection lifecycle is handled by that implementation so these are no-ops."""
|
||||
|
||||
It act as an adapter to broker services from plugins developed for HBMQTT broker.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
super().__init__("aiohttp", None) # type: ignore[arg-type]
|
||||
|
||||
async def acquire_connection(self) -> None:
|
||||
pass
|
||||
|
||||
def release_connection(self) -> None:
|
||||
pass
|
||||
|
||||
async def close_instance(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BrokerContext(BaseContext):
|
||||
"""Used to provide the server's context as well as public methods for accessing internal state."""
|
||||
|
||||
def __init__(self, broker: "Broker") -> None:
|
||||
super().__init__()
|
||||
self.config: _CONFIG_LISTENER | None = None
|
||||
self.config: BrokerConfig | None = None
|
||||
self._broker_instance = broker
|
||||
|
||||
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
||||
"""Send message to all client sessions subscribing to `topic`."""
|
||||
await self._broker_instance.internal_message_broadcast(topic, data, qos)
|
||||
|
||||
def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
||||
self._broker_instance.retain_message(None, topic_name, data, qos)
|
||||
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
||||
await self._broker_instance.retain_message(None, topic_name, data, qos)
|
||||
|
||||
@property
|
||||
def sessions(self) -> Generator[Session]:
|
||||
for session in self._broker_instance.sessions.values():
|
||||
yield session[0]
|
||||
|
||||
def get_session(self, client_id: str) -> Session | None:
|
||||
"""Return the session associated with `client_id`, if it exists."""
|
||||
return self._broker_instance.sessions.get(client_id, (None, None))[0]
|
||||
|
||||
@property
|
||||
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
|
||||
return self._broker_instance.retained_messages
|
||||
|
@ -124,17 +139,33 @@ class BrokerContext(BaseContext):
|
|||
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
|
||||
return self._broker_instance.subscriptions
|
||||
|
||||
async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
|
||||
"""Create a topic subscription for the given `client_id`.
|
||||
|
||||
If a client session doesn't exist for `client_id`, create a disconnected session.
|
||||
If `topic` and `qos` are both `None`, only create the client session.
|
||||
"""
|
||||
if client_id not in self._broker_instance.sessions:
|
||||
broker_handler, session = self._broker_instance.create_offline_session(client_id)
|
||||
self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001
|
||||
|
||||
if topic is not None and qos is not None:
|
||||
session, _ = self._broker_instance.sessions[client_id]
|
||||
await self._broker_instance.add_subscription((topic, qos), session)
|
||||
|
||||
|
||||
class Broker:
|
||||
"""MQTT 3.1.1 compliant broker implementation.
|
||||
|
||||
Args:
|
||||
config: dictionary of configuration options (see [broker configuration](broker_config.md)).
|
||||
config: `BrokerConfig` or dictionary of equivalent structure options (see [broker configuration](broker_config.md)).
|
||||
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.
|
||||
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
|
||||
|
||||
Raises:
|
||||
BrokerError, ParserError, PluginError
|
||||
BrokerError: problem with broker configuration
|
||||
PluginImportError: if importing a plugin from configuration
|
||||
PluginInitError: if initialization plugin fails
|
||||
|
||||
"""
|
||||
|
||||
|
@ -150,20 +181,20 @@ class Broker:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: _CONFIG_LISTENER | None = None,
|
||||
config: BrokerConfig | dict[str, Any] | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
plugin_namespace: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the broker."""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config = copy.deepcopy(_defaults or {})
|
||||
if config is not None:
|
||||
# if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
|
||||
if ("auth" in config or "topic-check" in config) and "plugins" not in config:
|
||||
# set to None so that the config isn't updated with the new-style default plugin list
|
||||
config["plugins"] = None # type: ignore[assignment]
|
||||
self.config.update(config)
|
||||
self._build_listeners_config(self.config)
|
||||
|
||||
if isinstance(config, dict):
|
||||
self.config = BrokerConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config or BrokerConfig()
|
||||
|
||||
# listeners are populated from default within BrokerConfig
|
||||
self.listeners_config = self.config.listeners
|
||||
|
||||
self._loop = loop or asyncio.get_running_loop()
|
||||
self._servers: dict[str, Server] = {}
|
||||
|
@ -182,6 +213,9 @@ class Broker:
|
|||
# Tasks queue for managing broadcasting tasks
|
||||
self._tasks_queue: deque[asyncio.Task[OutgoingApplicationMessage]] = deque()
|
||||
|
||||
# Task for session monitor
|
||||
self._session_monitor_task: asyncio.Task[Any] | None = None
|
||||
|
||||
# Initialize plugins manager
|
||||
|
||||
context = BrokerContext(self)
|
||||
|
@ -189,26 +223,6 @@ class Broker:
|
|||
namespace = plugin_namespace or "amqtt.broker.plugins"
|
||||
self.plugins_manager = PluginManager(namespace, context, self._loop)
|
||||
|
||||
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
|
||||
self.listeners_config = {}
|
||||
try:
|
||||
listeners_config = broker_config.get("listeners")
|
||||
if not isinstance(listeners_config, dict):
|
||||
msg = "Listener config not found or invalid"
|
||||
raise BrokerError(msg)
|
||||
defaults = listeners_config.get("default")
|
||||
if defaults is None:
|
||||
msg = "Listener config has not default included or is invalid"
|
||||
raise BrokerError(msg)
|
||||
|
||||
for listener_name, listener_conf in listeners_config.items():
|
||||
config = defaults.copy()
|
||||
config.update(listener_conf)
|
||||
self.listeners_config[listener_name] = config
|
||||
except KeyError as ke:
|
||||
msg = f"Listener config not found or invalid: {ke}"
|
||||
raise BrokerError(msg) from ke
|
||||
|
||||
def _init_states(self) -> None:
|
||||
self.transitions = Machine(states=Broker.states, initial="new")
|
||||
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
|
||||
|
@ -245,6 +259,7 @@ class Broker:
|
|||
self.transitions.starting_success()
|
||||
await self.plugins_manager.fire_event(BrokerEvents.POST_START)
|
||||
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
|
||||
self._session_monitor_task = asyncio.create_task(self._session_monitor())
|
||||
self.logger.debug("Broker started")
|
||||
except Exception as e:
|
||||
self.logger.exception("Broker startup failed")
|
||||
|
@ -262,18 +277,27 @@ class Broker:
|
|||
max_connections = listener.get("max_connections", -1)
|
||||
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
|
||||
|
||||
try:
|
||||
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
||||
except ValueError as e:
|
||||
msg = f"Invalid port value in bind value: {listener['bind']}"
|
||||
raise BrokerError(msg) from e
|
||||
# for listeners which are external, don't need to create a server
|
||||
if listener.type == ListenerType.EXTERNAL:
|
||||
|
||||
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
|
||||
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
||||
# broker still needs to associate a new connection to the listener
|
||||
self.logger.info(f"External listener exists for '{listener_name}' ")
|
||||
self._servers[listener_name] = ExternalServer()
|
||||
else:
|
||||
# for tcp and websockets, start servers to listen for inbound connections
|
||||
try:
|
||||
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
||||
except ValueError as e:
|
||||
msg = f"Invalid port value in bind value: {listener['bind']}"
|
||||
raise BrokerError(msg) from e
|
||||
|
||||
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
||||
instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context)
|
||||
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
||||
|
||||
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
|
||||
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
||||
|
||||
@staticmethod
|
||||
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
|
||||
"""Create an SSL context for a listener."""
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(
|
||||
|
@ -295,30 +319,61 @@ class Broker:
|
|||
async def _create_server_instance(
|
||||
self,
|
||||
listener_name: str,
|
||||
listener_type: str,
|
||||
listener_type: ListenerType,
|
||||
address: str | None,
|
||||
port: int,
|
||||
ssl_context: ssl.SSLContext | None,
|
||||
) -> asyncio.Server | websockets.asyncio.server.Server:
|
||||
"""Create a server instance for a listener."""
|
||||
if listener_type == "tcp":
|
||||
return await asyncio.start_server(
|
||||
partial(self.stream_connected, listener_name=listener_name),
|
||||
address,
|
||||
port,
|
||||
reuse_address=True,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
if listener_type == "ws":
|
||||
return await websockets.serve(
|
||||
partial(self.ws_connected, listener_name=listener_name),
|
||||
address,
|
||||
port,
|
||||
ssl=ssl_context,
|
||||
subprotocols=[websockets.Subprotocol("mqtt")],
|
||||
)
|
||||
msg = f"Unsupported listener type: {listener_type}"
|
||||
raise BrokerError(msg)
|
||||
match listener_type:
|
||||
case ListenerType.TCP:
|
||||
return await asyncio.start_server(
|
||||
partial(self.stream_connected, listener_name=listener_name),
|
||||
address,
|
||||
port,
|
||||
reuse_address=True,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
case ListenerType.WS:
|
||||
return await websockets.serve(
|
||||
partial(self.ws_connected, listener_name=listener_name),
|
||||
address,
|
||||
port,
|
||||
ssl=ssl_context,
|
||||
subprotocols=[websockets.Subprotocol("mqtt")],
|
||||
)
|
||||
case _:
|
||||
msg = f"Unsupported listener type: {listener_type}"
|
||||
raise BrokerError(msg)
|
||||
|
||||
async def _session_monitor(self) -> None:
|
||||
|
||||
self.logger.info("Starting session expiration monitor.")
|
||||
|
||||
while True:
|
||||
|
||||
session_count_before = len(self._sessions)
|
||||
|
||||
# clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
|
||||
sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
|
||||
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
|
||||
|
||||
# if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
|
||||
if self.config.session_expiry_interval is not None:
|
||||
retain_after = floor(time.time() - self.config.session_expiry_interval)
|
||||
|
||||
sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
|
||||
if session.transitions.state == "disconnected" and
|
||||
session.last_disconnect_time and
|
||||
session.last_disconnect_time < retain_after]
|
||||
|
||||
for client_id in sessions_to_remove:
|
||||
await self._cleanup_session(client_id)
|
||||
|
||||
if session_count_before > (session_count_after := len(self._sessions)):
|
||||
self.logger.debug(f"Expired {session_count_before - session_count_after} sessions")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Stop broker instance."""
|
||||
|
@ -337,6 +392,8 @@ class Broker:
|
|||
self.transitions.shutdown()
|
||||
|
||||
await self._shutdown_broadcast_loop()
|
||||
if self._session_monitor_task:
|
||||
self._session_monitor_task.cancel()
|
||||
|
||||
for server in self._servers.values():
|
||||
await server.close_instance()
|
||||
|
@ -372,6 +429,10 @@ class Broker:
|
|||
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
||||
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||
|
||||
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
|
||||
"""Engage the broker in handling the data stream to/from an established connection."""
|
||||
await self._client_connected(listener_name, reader, writer)
|
||||
|
||||
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
||||
"""Handle a new client connection."""
|
||||
server = self._servers.get(listener_name)
|
||||
|
@ -443,7 +504,17 @@ class Broker:
|
|||
# Get session from cache
|
||||
elif client_session.client_id in self._sessions:
|
||||
self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
|
||||
client_session, _ = self._sessions[client_session.client_id]
|
||||
|
||||
# even though the session previously existed, the new connection can bring updated configuration and credentials
|
||||
existing_client_session, _ = self._sessions[client_session.client_id]
|
||||
existing_client_session.will_flag = client_session.will_flag
|
||||
existing_client_session.will_message = client_session.will_message
|
||||
existing_client_session.will_topic = client_session.will_topic
|
||||
existing_client_session.will_qos = client_session.will_qos
|
||||
existing_client_session.keep_alive = client_session.keep_alive
|
||||
existing_client_session.username = client_session.username
|
||||
existing_client_session.password = client_session.password
|
||||
client_session = existing_client_session
|
||||
client_session.parent = 1
|
||||
else:
|
||||
client_session.parent = 0
|
||||
|
@ -455,6 +526,14 @@ class Broker:
|
|||
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
||||
return handler, client_session
|
||||
|
||||
def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
|
||||
session = Session()
|
||||
session.client_id = client_id
|
||||
|
||||
bph = BrokerProtocolHandler(self.plugins_manager, session)
|
||||
session.transitions.disconnect()
|
||||
return bph, session
|
||||
|
||||
async def _handle_client_session(
|
||||
self,
|
||||
reader: ReaderAdapter,
|
||||
|
@ -506,9 +585,7 @@ class Broker:
|
|||
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
|
||||
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
|
||||
for topic in self._subscriptions:
|
||||
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
|
||||
|
||||
|
||||
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
|
||||
|
||||
await self._client_message_loop(client_session, handler)
|
||||
|
||||
|
@ -540,7 +617,6 @@ class Broker:
|
|||
|
||||
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
|
||||
|
||||
|
||||
if subscribe_waiter in done:
|
||||
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
||||
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
||||
|
@ -595,7 +671,7 @@ class Broker:
|
|||
client_session.will_qos,
|
||||
)
|
||||
if client_session.will_retain:
|
||||
self.retain_message(
|
||||
await self.retain_message(
|
||||
client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
|
@ -610,7 +686,6 @@ class Broker:
|
|||
client_id=client_session.client_id,
|
||||
client_session=client_session)
|
||||
|
||||
|
||||
async def _handle_subscription(
|
||||
self,
|
||||
client_session: Session,
|
||||
|
@ -620,7 +695,7 @@ class Broker:
|
|||
"""Handle client subscription."""
|
||||
self.logger.debug(f"{client_session.client_id} handling subscription")
|
||||
subscriptions = subscribe_waiter.result()
|
||||
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
||||
return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
||||
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
||||
for index, subscription in enumerate(subscriptions.topics):
|
||||
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
||||
|
@ -660,6 +735,13 @@ class Broker:
|
|||
self.logger.debug(f"{client_session.client_id} handling message delivery")
|
||||
app_message = wait_deliver.result()
|
||||
|
||||
# notify of a message's receipt, even if a client isn't necessarily allowed to send it
|
||||
await self.plugins_manager.fire_event(
|
||||
BrokerEvents.MESSAGE_RECEIVED,
|
||||
client_id=client_session.client_id,
|
||||
message=app_message,
|
||||
)
|
||||
|
||||
if app_message is None:
|
||||
self.logger.debug("app_message was empty!")
|
||||
return True
|
||||
|
@ -681,16 +763,17 @@ class Broker:
|
|||
|
||||
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
||||
if not permitted:
|
||||
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
|
||||
self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
|
||||
else:
|
||||
# notify that a received message is valid and is allowed to be distributed to other clients
|
||||
await self.plugins_manager.fire_event(
|
||||
BrokerEvents.MESSAGE_RECEIVED,
|
||||
BrokerEvents.MESSAGE_BROADCAST,
|
||||
client_id=client_session.client_id,
|
||||
message=app_message,
|
||||
)
|
||||
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
||||
if app_message.publish_packet and app_message.publish_packet.retain_flag:
|
||||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
return True
|
||||
|
||||
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
||||
|
@ -703,10 +786,11 @@ class Broker:
|
|||
"""Stop a running handler and detach if from the session."""
|
||||
try:
|
||||
await handler.stop()
|
||||
except Exception:
|
||||
# a failure in stopping a handler shouldn't cause the broker to fail
|
||||
except asyncio.QueueEmpty:
|
||||
self.logger.exception("Failed to stop handler")
|
||||
|
||||
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
|
||||
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
|
||||
"""Call the authenticate method on registered plugins to test user authentication.
|
||||
|
||||
User is considered authenticated if all plugins called returns True.
|
||||
|
@ -719,7 +803,7 @@ class Broker:
|
|||
"""
|
||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||
|
||||
results = [ result for _, result in returns.items() if result is not None] if returns else []
|
||||
results = [result for _, result in returns.items() if result is not None] if returns else []
|
||||
if len(results) < 1:
|
||||
self.logger.debug("Authentication failed: no plugin responded with a boolean")
|
||||
return False
|
||||
|
@ -733,7 +817,7 @@ class Broker:
|
|||
|
||||
return False
|
||||
|
||||
def retain_message(
|
||||
async def retain_message(
|
||||
self,
|
||||
source_session: Session | None,
|
||||
topic_name: str | None,
|
||||
|
@ -744,12 +828,25 @@ class Broker:
|
|||
# If retained flag set, store the message for further subscriptions
|
||||
self.logger.debug(f"Retaining message on topic {topic_name}")
|
||||
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
||||
|
||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||
client_id=None,
|
||||
retained_message=self._retained_messages[topic_name])
|
||||
|
||||
# [MQTT-3.3.1-10]
|
||||
elif topic_name in self._retained_messages:
|
||||
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
||||
|
||||
cleared_message = self._retained_messages[topic_name]
|
||||
cleared_message.data = b""
|
||||
|
||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||
client_id=None,
|
||||
retained_message=cleared_message)
|
||||
|
||||
del self._retained_messages[topic_name]
|
||||
|
||||
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
||||
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
||||
topic_filter, qos = subscription
|
||||
if "#" in topic_filter and not topic_filter.endswith("#"):
|
||||
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
||||
|
@ -853,7 +950,8 @@ class Broker:
|
|||
task.result()
|
||||
except CancelledError:
|
||||
self.logger.info(f"Task has been cancelled: {task}")
|
||||
except Exception:
|
||||
# if a task fails, don't want it to cause the broker to fail
|
||||
except Exception: # pylint: disable=W0718
|
||||
self.logger.exception(f"Task failed and will be skipped: {task}")
|
||||
|
||||
run_broadcast_task = asyncio.ensure_future(self._run_broadcast(running_tasks))
|
||||
|
@ -892,6 +990,12 @@ class Broker:
|
|||
for target_session, sub_qos in subscriptions:
|
||||
qos = broadcast.get("qos", sub_qos)
|
||||
|
||||
sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
|
||||
if not sendable:
|
||||
self.logger.info(
|
||||
f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
|
||||
continue
|
||||
|
||||
# Retain all messages which cannot be broadcasted, due to the session not being connected
|
||||
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
|
||||
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
|
||||
|
@ -934,6 +1038,10 @@ class Broker:
|
|||
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
|
||||
await target_session.retained_messages.put(retained_message)
|
||||
|
||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||
client_id=target_session.client_id,
|
||||
retained_message=retained_message)
|
||||
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")
|
||||
|
||||
|
|
|
@ -2,10 +2,8 @@ import asyncio
|
|||
from collections import deque
|
||||
from collections.abc import Callable, Coroutine
|
||||
import contextlib
|
||||
import copy
|
||||
from functools import wraps
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
@ -19,20 +17,18 @@ from amqtt.adapters import (
|
|||
WebSocketsReader,
|
||||
WebSocketsWriter,
|
||||
)
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.contexts import BaseContext, ClientConfig
|
||||
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
|
||||
from amqtt.mqtt.connack import CONNECTION_ACCEPTED
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
from amqtt.utils import gen_client_id, read_yaml_config
|
||||
from amqtt.utils import gen_client_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
_defaults: dict[str, Any] | None = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
|
||||
|
||||
|
||||
class ClientContext(BaseContext):
|
||||
"""ClientContext is used as the context passed to plugins interacting with the client.
|
||||
|
@ -42,7 +38,7 @@ class ClientContext(BaseContext):
|
|||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.config: ClientConfig | None = None
|
||||
|
||||
|
||||
base_logger = logging.getLogger(__name__)
|
||||
|
@ -79,26 +75,27 @@ def mqtt_connected(func: _F) -> _F:
|
|||
|
||||
|
||||
class MQTTClient:
|
||||
"""MQTT client implementation.
|
||||
|
||||
MQTTClient instances provides API for connecting to a broker and send/receive
|
||||
messages using the MQTT protocol.
|
||||
"""MQTT client implementation, providing an API for connecting to a broker and send/receive messages using the MQTT protocol.
|
||||
|
||||
Args:
|
||||
client_id: MQTT client ID to use when connecting to the broker. If none,
|
||||
it will be generated randomly by `amqtt.utils.gen_client_id`
|
||||
config: dictionary of configuration options (see [client configuration](client_config.md)).
|
||||
config: `ClientConfig` or dictionary of equivalent structure options (see [client configuration](client_config.md)).
|
||||
|
||||
Raises:
|
||||
PluginError
|
||||
PluginImportError: if importing a plugin from configuration fails
|
||||
PluginInitError: if initialization plugin fails
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config = copy.deepcopy(_defaults or {})
|
||||
if config is not None:
|
||||
self.config.update(config)
|
||||
|
||||
if isinstance(config, dict):
|
||||
self.config = ClientConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config or ClientConfig()
|
||||
|
||||
self.client_id = client_id if client_id is not None else gen_client_id()
|
||||
|
||||
self.session: Session | None = None
|
||||
|
@ -146,7 +143,7 @@ class MQTTClient:
|
|||
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
|
||||
|
||||
Raises:
|
||||
ClientError, ConnectError
|
||||
ConnectError: could not connect to broker
|
||||
|
||||
"""
|
||||
additional_headers = additional_headers if additional_headers is not None else {}
|
||||
|
@ -159,7 +156,8 @@ class MQTTClient:
|
|||
except asyncio.CancelledError as e:
|
||||
msg = "Future or Task was cancelled"
|
||||
raise ConnectError(msg) from e
|
||||
except Exception as e:
|
||||
# no matter the failure mode, still try to reconnect
|
||||
except Exception as e: # pylint: disable=W0718
|
||||
self.logger.warning(f"Connection failed: {e!r}")
|
||||
if not self.config.get("auto_reconnect", False):
|
||||
raise
|
||||
|
@ -233,7 +231,8 @@ class MQTTClient:
|
|||
except asyncio.CancelledError as e:
|
||||
msg = "Future or Task was cancelled"
|
||||
raise ConnectError(msg) from e
|
||||
except Exception as e:
|
||||
# no matter the failure mode, still try to reconnect
|
||||
except Exception as e: # pylint: disable=W0718
|
||||
self.logger.warning(f"Reconnection attempt failed: {e!r}")
|
||||
self.logger.debug("", exc_info=True)
|
||||
if 0 <= reconnect_retries < nb_attempt:
|
||||
|
@ -381,6 +380,7 @@ class MQTTClient:
|
|||
|
||||
Raises:
|
||||
asyncio.TimeoutError: if timeout occurs before a message is delivered
|
||||
ClientError: if client is not connected
|
||||
|
||||
"""
|
||||
if self._handler is None:
|
||||
|
@ -424,14 +424,10 @@ class MQTTClient:
|
|||
scheme = uri_attributes.scheme
|
||||
secure = scheme in ("mqtts", "wss")
|
||||
self.session.username = (
|
||||
self.session.username
|
||||
if self.session.username
|
||||
else (str(uri_attributes.username) if uri_attributes.username else None)
|
||||
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
|
||||
)
|
||||
self.session.password = (
|
||||
self.session.password
|
||||
if self.session.password
|
||||
else (str(uri_attributes.password) if uri_attributes.password else None)
|
||||
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
|
||||
)
|
||||
self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
|
||||
self.session.remote_port = uri_attributes.port
|
||||
|
@ -462,15 +458,15 @@ class MQTTClient:
|
|||
if secure:
|
||||
sc = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH,
|
||||
cafile=self.session.cafile,
|
||||
capath=self.session.capath,
|
||||
cadata=self.session.cadata,
|
||||
cafile=self.session.cafile
|
||||
)
|
||||
|
||||
if "certfile" in self.config:
|
||||
sc.load_verify_locations(cafile=self.config["certfile"])
|
||||
if "check_hostname" in self.config and isinstance(self.config["check_hostname"], bool):
|
||||
sc.check_hostname = self.config["check_hostname"]
|
||||
if self.config.connection.certfile and self.config.connection.keyfile:
|
||||
sc.load_cert_chain(certfile=self.config.connection.certfile, keyfile=self.config.connection.keyfile)
|
||||
if self.config.connection.cafile:
|
||||
sc.load_verify_locations(cafile=self.config.connection.cafile)
|
||||
if self.config.check_hostname is not None:
|
||||
sc.check_hostname = self.config.check_hostname
|
||||
sc.verify_mode = ssl.CERT_REQUIRED
|
||||
kwargs["ssl"] = sc
|
||||
|
||||
|
@ -525,7 +521,7 @@ class MQTTClient:
|
|||
self._connected_state.set()
|
||||
self.logger.debug(f"Connected to {self.session.remote_address}:{self.session.remote_port}")
|
||||
|
||||
except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError) as e:
|
||||
except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError, asyncio.TimeoutError) as e:
|
||||
self.logger.debug(f"Connection failed : {self.session.broker_uri} [{e!r}]")
|
||||
self.session.transitions.disconnect()
|
||||
raise ConnectError(e) from e
|
||||
|
@ -581,10 +577,18 @@ class MQTTClient:
|
|||
cadata: str | None = None,
|
||||
) -> Session:
|
||||
"""Initialize the MQTT session."""
|
||||
broker_conf = self.config.get("broker", {}).copy()
|
||||
broker_conf.update(
|
||||
{k: v for k, v in {"uri": uri, "cafile": cafile, "capath": capath, "cadata": cadata}.items() if v is not None},
|
||||
)
|
||||
broker_conf = self.config.get("connection", {}).copy()
|
||||
|
||||
if uri is not None:
|
||||
broker_conf.uri = uri
|
||||
if cleansession is not None:
|
||||
self.config.cleansession = cleansession
|
||||
if cafile is not None:
|
||||
broker_conf.cafile = cafile
|
||||
if capath is not None:
|
||||
broker_conf.capath = capath
|
||||
if cadata is not None:
|
||||
broker_conf.cadata = cadata
|
||||
|
||||
if not broker_conf.get("uri"):
|
||||
msg = "Missing connection parameter 'uri'"
|
||||
|
@ -593,15 +597,12 @@ class MQTTClient:
|
|||
session = Session()
|
||||
session.broker_uri = broker_conf["uri"]
|
||||
session.client_id = self.client_id
|
||||
|
||||
session.cafile = broker_conf.get("cafile")
|
||||
session.capath = broker_conf.get("capath")
|
||||
session.cadata = broker_conf.get("cadata")
|
||||
|
||||
if cleansession is not None:
|
||||
broker_conf["cleansession"] = cleansession # noop?
|
||||
session.clean_session = cleansession
|
||||
else:
|
||||
session.clean_session = self.config.get("cleansession", True)
|
||||
session.clean_session = self.config.get("cleansession", True)
|
||||
|
||||
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]
|
||||
|
||||
|
|
|
@ -142,8 +142,8 @@ def int_to_bytes_str(value: int) -> bytes:
|
|||
return str(value).encode("utf-8")
|
||||
|
||||
|
||||
def float_to_bytes_str(value: float, places:int=3) -> bytes:
|
||||
def float_to_bytes_str(value: float, places: int = 3) -> bytes:
|
||||
"""Convert an float value to a bytes array containing the numeric character."""
|
||||
quant = Decimal(f"0.{''.join(['0' for i in range(places-1)])}1")
|
||||
quant = Decimal(f"0.{''.join(['0' for i in range(places - 1)])}1")
|
||||
rounded = Decimal(value).quantize(quant, rounding=ROUND_HALF_UP)
|
||||
return str(rounded).encode("utf-8")
|
||||
|
|
|
@ -1,22 +1,379 @@
|
|||
from enum import Enum
|
||||
from dataclasses import dataclass, field, fields, replace
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import warnings
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
try:
|
||||
from enum import Enum, StrEnum
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
|
||||
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseContext:
|
||||
def __init__(self) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
self.logger: logging.Logger = _LOGGER
|
||||
self.config: dict[str, Any] | None = None
|
||||
self.logger: logging.Logger = logging.getLogger(__name__)
|
||||
# cleanup with a `Generic` type
|
||||
self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
class Action(StrEnum):
|
||||
"""Actions issued by the broker."""
|
||||
|
||||
SUBSCRIBE = "subscribe"
|
||||
PUBLISH = "publish"
|
||||
RECEIVE = "receive"
|
||||
|
||||
|
||||
class ListenerType(StrEnum):
|
||||
"""Types of mqtt listeners."""
|
||||
|
||||
TCP = "tcp"
|
||||
WS = "ws"
|
||||
EXTERNAL = "external"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Display the string value, instead of the enum member."""
|
||||
return f'"{self.value!s}"'
|
||||
|
||||
|
||||
class Dictable:
|
||||
"""Add dictionary methods to a dataclass."""
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Allow dict-style `[]` access to a dataclass."""
|
||||
return self.get(key)
|
||||
|
||||
def get(self, name: str, default: Any = None) -> Any:
|
||||
"""Allow dict-style access to a dataclass."""
|
||||
name = name.replace("-", "_")
|
||||
if hasattr(self, name):
|
||||
return getattr(self, name)
|
||||
if default is not None:
|
||||
return default
|
||||
msg = f"'{name}' is not defined"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Provide dict-style 'in' check."""
|
||||
return getattr(self, name.replace("-", "_"), None) is not None
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Provide dict-style iteration."""
|
||||
for f in fields(self): # type: ignore[arg-type]
|
||||
yield getattr(self, f.name)
|
||||
|
||||
def copy(self) -> dataclass: # type: ignore[valid-type]
|
||||
"""Return a copy of the dataclass."""
|
||||
return replace(self) # type: ignore[type-var]
|
||||
|
||||
@staticmethod
|
||||
def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]:
|
||||
if isinstance(value, list):
|
||||
return value # It's already a list of dicts
|
||||
if isinstance(value, dict):
|
||||
return [value] # Promote single dict to a list
|
||||
msg = "Could not convert 'list' to 'list[dict[str, Any]]'"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListenerConfig(Dictable):
|
||||
"""Structured configuration for a broker's listeners."""
|
||||
|
||||
type: ListenerType = ListenerType.TCP
|
||||
"""Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in dictionary or yaml.'"""
|
||||
bind: str | None = "0.0.0.0:1883"
|
||||
"""address and port for the listener to bind to"""
|
||||
max_connections: int = 0
|
||||
"""max number of connections allowed for this listener"""
|
||||
ssl: bool = False
|
||||
"""secured by ssl"""
|
||||
cafile: str | Path | None = None
|
||||
"""Path to a file of concatenated CA certificates in PEM format. See
|
||||
[Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
|
||||
capath: str | Path | None = None
|
||||
"""Path to a directory containing one or more CA certificates in PEM format, following the
|
||||
[OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
|
||||
cadata: str | Path | None = None
|
||||
"""Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
|
||||
certfile: str | Path | None = None
|
||||
"""Full path to file in PEM format containing the server's certificate (as well as any number of CA
|
||||
certificates needed to establish the certificate's authenticity.)"""
|
||||
keyfile: str | Path | None = None
|
||||
"""Full path to file in PEM format containing the server's private key."""
|
||||
reader: str | None = None
|
||||
writer: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if (self.certfile is None) ^ (self.keyfile is None):
|
||||
msg = "If specifying the 'certfile' or 'keyfile', both are required."
|
||||
raise ValueError(msg)
|
||||
|
||||
for fn in ("cafile", "capath", "certfile", "keyfile"):
|
||||
if isinstance(getattr(self, fn), str):
|
||||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
if getattr(self, fn) and not getattr(self, fn).exists():
|
||||
msg = f"'{fn}' does not exist : {getattr(self, fn)}"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
def apply(self, other: "ListenerConfig") -> None:
|
||||
"""Apply the field from 'other', if 'self' field is default."""
|
||||
for f in fields(self):
|
||||
if getattr(self, f.name) == f.default:
|
||||
setattr(self, f.name, other[f.name])
|
||||
|
||||
|
||||
def default_listeners() -> dict[str, Any]:
|
||||
"""Create defaults for BrokerConfig.listeners."""
|
||||
return {
|
||||
"default": ListenerConfig()
|
||||
}
|
||||
|
||||
|
||||
def default_broker_plugins() -> dict[str, Any]:
|
||||
"""Create defaults for BrokerConfig.plugins."""
|
||||
return {
|
||||
"amqtt.plugins.logging_amqtt.EventLoggerPlugin": {},
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {},
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 20}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrokerConfig(Dictable):
|
||||
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
|
||||
|
||||
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
|
||||
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
|
||||
does not set a value, the 'default' settings are applied. See
|
||||
[`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
|
||||
sys_interval: int | None = None
|
||||
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
|
||||
for recommended configuration.*"""
|
||||
timeout_disconnect_delay: int | None = 0
|
||||
"""Client disconnect timeout without a keep-alive."""
|
||||
session_expiry_interval: int | None = None
|
||||
"""Seconds for an inactive session to be retained."""
|
||||
auth: dict[str, Any] | None = None
|
||||
"""*Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
|
||||
[`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
|
||||
topic_check: dict[str, Any] | None = None
|
||||
"""*Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
|
||||
[`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
|
||||
plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
|
||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
|
||||
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
||||
[custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
|
||||
to support legacy use cases."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.sys_interval is not None:
|
||||
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
|
||||
|
||||
if self.auth is not None or self.topic_check is not None:
|
||||
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
|
||||
|
||||
default_listener = self.listeners["default"]
|
||||
for listener_name, listener in self.listeners.items():
|
||||
if listener_name == "default":
|
||||
continue
|
||||
listener.apply(default_listener)
|
||||
|
||||
if isinstance(self.plugins, list):
|
||||
_plugins: dict[str, Any] = {}
|
||||
for plugin in self.plugins:
|
||||
# in case a plugin in a yaml file is listed without config map
|
||||
if isinstance(plugin, str):
|
||||
_plugins |= {plugin: {}}
|
||||
continue
|
||||
_plugins |= plugin
|
||||
self.plugins = _plugins
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
|
||||
"""Create a broker config from a dictionary."""
|
||||
if d is None:
|
||||
return BrokerConfig()
|
||||
|
||||
# patch the incoming dictionary so it can be loaded correctly
|
||||
if "topic-check" in d:
|
||||
d["topic_check"] = d["topic-check"]
|
||||
del d["topic-check"]
|
||||
|
||||
# identify EntryPoint plugin loading and prevent 'plugins' from getting defaults
|
||||
if ("auth" in d or "topic-check" in d) and "plugins" not in d:
|
||||
d["plugins"] = None
|
||||
|
||||
return dict_to_dataclass(data_class=BrokerConfig,
|
||||
data=d,
|
||||
config=DaciteConfig(
|
||||
cast=[StrEnum, ListenerType],
|
||||
strict=True,
|
||||
type_hooks={list[dict[str, Any]]: cls._coerce_lists}
|
||||
))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig(Dictable):
|
||||
"""Properties for connecting to the broker."""
|
||||
|
||||
uri: str | None = "mqtt://127.0.0.1:1883"
|
||||
"""URI of the broker"""
|
||||
cafile: str | Path | None = None
|
||||
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
|
||||
[Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
|
||||
capath: str | Path | None = None
|
||||
"""Path to a directory containing one or more CA certificates in PEM format, following the
|
||||
[OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
|
||||
cadata: str | None = None
|
||||
"""The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
|
||||
certificates or a bytes-like object of DER-encoded certificates."""
|
||||
certfile: str | Path | None = None
|
||||
"""Full path to file in PEM format containing the client's certificate (as well as any number of CA
|
||||
certificates needed to establish the certificate's authenticity.)"""
|
||||
keyfile: str | Path | None = None
|
||||
"""Full path to file in PEM format containing the client's private key associated with the certfile."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if (self.certfile is None) ^ (self.keyfile is None):
|
||||
msg = "If specifying the 'certfile' or 'keyfile', both are required."
|
||||
raise ValueError(msg)
|
||||
|
||||
for fn in ("cafile", "capath", "certfile", "keyfile"):
|
||||
if isinstance(getattr(self, fn), str):
|
||||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopicConfig(Dictable):
|
||||
"""Configuration of how messages to specific topics are published.
|
||||
|
||||
The topic name is specified as the key in the dictionary of the `ClientConfig.topics.
|
||||
"""
|
||||
|
||||
qos: int = 0
|
||||
"""The quality of service associated with the publishing to this topic."""
|
||||
retain: bool = False
|
||||
"""Determines if the message should be retained by the topic it was published."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
|
||||
msg = "Topic config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WillConfig(Dictable):
|
||||
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
|
||||
|
||||
topic: str
|
||||
"""The will message will be published to this topic."""
|
||||
message: str
|
||||
"""The contents of the message to be published."""
|
||||
qos: int | None = QOS_0
|
||||
"""The quality of service associated with sending this message."""
|
||||
retain: bool | None = False
|
||||
"""Determines if the message should be retained by the topic it was published."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
|
||||
msg = "Will config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def default_client_plugins() -> dict[str, Any]:
|
||||
"""Create defaults for `ClientConfig.plugins`."""
|
||||
return {
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig(Dictable):
|
||||
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
|
||||
|
||||
keep_alive: int | None = 10
|
||||
"""Keep-alive timeout sent to the broker."""
|
||||
ping_delay: int | None = 1
|
||||
"""Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
|
||||
default_qos: int | None = QOS_0
|
||||
"""Default QoS for messages published."""
|
||||
default_retain: bool | None = False
|
||||
"""Default retain value to messages published."""
|
||||
auto_reconnect: bool | None = True
|
||||
"""Enable or disable auto-reconnect if connection with the broker is interrupted."""
|
||||
connection_timeout: int | None = 60
|
||||
"""The number of seconds before a connection times out"""
|
||||
reconnect_retries: int | None = 2
|
||||
"""Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
|
||||
reconnect_max_interval: int | None = 10
|
||||
"""Maximum seconds to wait before retrying a connection."""
|
||||
cleansession: bool | None = True
|
||||
"""Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
|
||||
topics: dict[str, TopicConfig] | None = field(default_factory=dict)
|
||||
"""Specify the topics and what flags should be set for messages published to them."""
|
||||
broker: ConnectionConfig | None = None
|
||||
"""*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
|
||||
connection: ConnectionConfig = field(default_factory=ConnectionConfig)
|
||||
"""Configuration for connecting to the broker. See
|
||||
[`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
|
||||
plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
|
||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
|
||||
a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
|
||||
more information. `list[str | dict[str,Any]]` is deprecated but available to support legacy use cases."""
|
||||
check_hostname: bool | None = True
|
||||
"""If establishing a secure connection, should the hostname of the certificate be verified."""
|
||||
will: WillConfig | None = None
|
||||
"""Message, topic and flags that should be sent to if the client disconnects. See
|
||||
[`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
|
||||
msg = "Client config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.broker is not None:
|
||||
warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
|
||||
self.connection = self.broker
|
||||
|
||||
if bool(not self.connection.keyfile) ^ bool(not self.connection.certfile):
|
||||
msg = "Connection key and certificate files are _both_ required."
|
||||
raise ValueError(msg)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
|
||||
"""Create a client config from a dictionary."""
|
||||
if d is None:
|
||||
return ClientConfig()
|
||||
|
||||
return dict_to_dataclass(data_class=ClientConfig,
|
||||
data=d,
|
||||
config=DaciteConfig(
|
||||
cast=[StrEnum],
|
||||
strict=True)
|
||||
)
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
"""Module for contributed plugins."""
|
||||
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import JSON, TypeDecorator
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def __init__(self, dataclass_type: type[T]) -> None:
|
||||
if not is_dataclass(dataclass_type):
|
||||
msg = f"{dataclass_type} must be a dataclass type"
|
||||
raise TypeError(msg)
|
||||
self.dataclass_type = dataclass_type
|
||||
super().__init__()
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
value: list[Any] | None, # Python -> DB
|
||||
dialect: Any
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if value is None:
|
||||
return None
|
||||
return [asdict(item) for item in value]
|
||||
|
||||
def process_result_value(
|
||||
self,
|
||||
value: list[dict[str, Any]] | None, # DB -> Python
|
||||
dialect: Any
|
||||
) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
return [self.dataclass_type(**item) for item in value]
|
||||
|
||||
def process_literal_param(self, value: Any, dialect: Any) -> Any:
|
||||
# Required by SQLAlchemy, typically used for literal SQL rendering.
|
||||
return value
|
||||
|
||||
@property
|
||||
def python_type(self) -> type:
|
||||
# Required by TypeEngine to indicate the expected Python type.
|
||||
return list
|
|
@ -0,0 +1,52 @@
|
|||
"""Plugin to determine authentication of clients with DB storage."""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import click
|
||||
|
||||
try:
|
||||
from enum import StrEnum
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin
|
||||
|
||||
|
||||
class DBType(StrEnum):
|
||||
"""Enumeration for supported relational databases."""
|
||||
|
||||
MARIA = "mariadb"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
SQLITE = "sqlite"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DBInfo:
|
||||
"""SQLAlchemy database information."""
|
||||
|
||||
connect_str: str
|
||||
connect_port: int | None
|
||||
|
||||
|
||||
_db_map = {
|
||||
DBType.MARIA: DBInfo("mysql+aiomysql", 3306),
|
||||
DBType.MYSQL: DBInfo("mysql+aiomysql", 3306),
|
||||
DBType.POSTGRESQL: DBInfo("postgresql+asyncpg", 5432),
|
||||
DBType.SQLITE: DBInfo("sqlite+aiosqlite", None)
|
||||
}
|
||||
|
||||
|
||||
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
|
||||
"""Create sqlalchemy database connection string."""
|
||||
db_info = _db_map[db_type]
|
||||
if db_type == DBType.SQLITE:
|
||||
return f"{db_info.connect_str}:///{db_filename}"
|
||||
db_password = click.prompt("Enter the db password (press enter for none)", hide_input=True)
|
||||
pwd = f":{db_password}" if db_password else ""
|
||||
return f"{db_info.connect_str}://{db_username}:{pwd}@{db_host}:{db_port or db_info.connect_port}"
|
||||
|
||||
|
||||
__all__ = ["DBType", "TopicAuthDBPlugin", "UserAuthDBPlugin", "db_connection_str"]
|
|
@ -0,0 +1,190 @@
|
|||
from collections.abc import Iterator
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.contrib.auth_db.models import AllowedTopic, Base, TopicAuth, UserAuth
|
||||
from amqtt.errors import MQTTError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserManager:
|
||||
|
||||
def __init__(self, connection: str) -> None:
|
||||
self._engine = create_async_engine(connection)
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
async def db_sync(self) -> None:
|
||||
"""Sync the database schema."""
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
@staticmethod
|
||||
async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> UserAuth:
|
||||
stmt = select(UserAuth).filter(UserAuth.username == username)
|
||||
user_auth = await db_session.scalar(stmt)
|
||||
if not user_auth:
|
||||
msg = f"Username '{username}' doesn't exist."
|
||||
logger.debug(msg)
|
||||
raise MQTTError(msg)
|
||||
|
||||
return user_auth
|
||||
|
||||
async def get_user_auth(self, username: str) -> UserAuth | None:
|
||||
"""Retrieve a user by username."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
try:
|
||||
return await self._get_auth_or_raise(db_session, username)
|
||||
except MQTTError:
|
||||
return None
|
||||
|
||||
async def list_user_auths(self) -> Iterator[UserAuth]:
|
||||
"""Return list of all clients."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(UserAuth).order_by(UserAuth.username)
|
||||
users = await db_session.scalars(stmt)
|
||||
if not users:
|
||||
msg = "No users exist."
|
||||
logger.info(msg)
|
||||
raise MQTTError(msg)
|
||||
return users
|
||||
|
||||
async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
|
||||
"""Create a new user."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(UserAuth).filter(UserAuth.username == username)
|
||||
user_auth = await db_session.scalar(stmt)
|
||||
if user_auth:
|
||||
msg = f"Username '{username}' already exists."
|
||||
logger.info(msg)
|
||||
raise MQTTError(msg)
|
||||
|
||||
user_auth = UserAuth(username=username)
|
||||
user_auth.password = plain_password
|
||||
|
||||
db_session.add(user_auth)
|
||||
await db_session.commit()
|
||||
await db_session.flush()
|
||||
return user_auth
|
||||
|
||||
async def delete_user_auth(self, username: str) -> UserAuth | None:
|
||||
"""Delete a user."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
|
||||
try:
|
||||
user_auth = await self._get_auth_or_raise(db_session, username)
|
||||
except MQTTError:
|
||||
return None
|
||||
|
||||
await db_session.delete(user_auth)
|
||||
await db_session.commit()
|
||||
await db_session.flush()
|
||||
return user_auth
|
||||
|
||||
async def update_user_auth_password(self, username: str, plain_password: str) -> UserAuth | None:
|
||||
"""Change a user's password."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
user_auth = await self._get_auth_or_raise(db_session, username)
|
||||
user_auth.password = plain_password
|
||||
await db_session.commit()
|
||||
await db_session.flush()
|
||||
return user_auth
|
||||
|
||||
|
||||
class TopicManager:
|
||||
|
||||
def __init__(self, connection: str) -> None:
|
||||
self._engine = create_async_engine(connection)
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
async def db_sync(self) -> None:
|
||||
"""Sync the database schema."""
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
@staticmethod
|
||||
async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> TopicAuth:
|
||||
stmt = select(TopicAuth).filter(TopicAuth.username == username)
|
||||
topic_auth = await db_session.scalar(stmt)
|
||||
if not topic_auth:
|
||||
msg = f"Username '{username}' doesn't exist."
|
||||
logger.debug(msg)
|
||||
raise MQTTError(msg)
|
||||
|
||||
return topic_auth
|
||||
|
||||
@staticmethod
|
||||
def _field_name(action: Action) -> str:
|
||||
return f"{action}_acl"
|
||||
|
||||
async def create_topic_auth(self, username: str) -> TopicAuth | None:
|
||||
"""Create a new user."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(TopicAuth).filter(TopicAuth.username == username)
|
||||
topic_auth = await db_session.scalar(stmt)
|
||||
if topic_auth:
|
||||
msg = f"Username '{username}' already exists."
|
||||
raise MQTTError(msg)
|
||||
|
||||
topic_auth = TopicAuth(username=username)
|
||||
|
||||
db_session.add(topic_auth)
|
||||
await db_session.commit()
|
||||
await db_session.flush()
|
||||
return topic_auth
|
||||
|
||||
async def get_topic_auth(self, username: str) -> TopicAuth | None:
|
||||
"""Retrieve a allowed topics by username."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
try:
|
||||
return await self._get_auth_or_raise(db_session, username)
|
||||
except MQTTError:
|
||||
return None
|
||||
|
||||
async def list_topic_auths(self) -> Iterator[TopicAuth]:
|
||||
"""Return list of all authorized clients."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(TopicAuth).order_by(TopicAuth.username)
|
||||
topics = await db_session.scalars(stmt)
|
||||
if not topics:
|
||||
msg = "No topics exist."
|
||||
logger.info(msg)
|
||||
raise MQTTError(msg)
|
||||
return topics
|
||||
|
||||
async def add_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
|
||||
"""Add allowed topic from action for user."""
|
||||
if action == Action.PUBLISH and topic.startswith("$"):
|
||||
msg = "MQTT does not allow clients to publish to $ topics."
|
||||
raise MQTTError(msg)
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
user_auth = await self._get_auth_or_raise(db_session, username)
|
||||
topic_list = getattr(user_auth, self._field_name(action))
|
||||
|
||||
updated_list = [*topic_list, AllowedTopic(topic)]
|
||||
setattr(user_auth, self._field_name(action), updated_list)
|
||||
await db_session.commit()
|
||||
await db_session.flush()
|
||||
return updated_list
|
||||
|
||||
async def remove_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
|
||||
"""Remove topic from action for user."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
topic_auth = await self._get_auth_or_raise(db_session, username)
|
||||
topic_list = topic_auth.get_topic_list(action)
|
||||
|
||||
if AllowedTopic(topic) not in topic_list:
|
||||
msg = f"Client '{username}' doesn't have topic '{topic}' for action '{action}'."
|
||||
logger.debug(msg)
|
||||
raise MQTTError(msg)
|
||||
|
||||
updated_list = [allowed_topic for allowed_topic in topic_list if allowed_topic != AllowedTopic(topic)]
|
||||
|
||||
setattr(topic_auth, f"{action}_acl", updated_list)
|
||||
await db_session.commit()
|
||||
await db_session.flush()
|
||||
return updated_list
|
|
@ -0,0 +1,124 @@
|
|||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.contrib import DataClassListJSON
|
||||
from amqtt.plugins import TopicMatcher
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from passlib.context import CryptContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
matcher = TopicMatcher()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllowedTopic:
|
||||
topic: str
|
||||
|
||||
def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
|
||||
"""Determine `in`."""
|
||||
return self.__eq__(item)
|
||||
|
||||
def __eq__(self, item: object) -> bool:
|
||||
"""Determine `==` or `!=`."""
|
||||
if isinstance(item, str):
|
||||
return matcher.is_topic_allowed(item, self.topic)
|
||||
if isinstance(item, AllowedTopic):
|
||||
return item.topic == self.topic
|
||||
msg = "AllowedTopic can only be compared to another AllowedTopic or string."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Display topic."""
|
||||
return self.topic
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Display topic."""
|
||||
return self.topic
|
||||
|
||||
|
||||
class PasswordHasher:
|
||||
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""
|
||||
|
||||
_instance: Optional["PasswordHasher"] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
if not hasattr(self, "_crypt_context"):
|
||||
self._crypt_context: CryptContext | None = None
|
||||
|
||||
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "PasswordHasher":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
@property
|
||||
def crypt_context(self) -> "CryptContext":
|
||||
if not self._crypt_context:
|
||||
msg = "CryptContext is empty"
|
||||
raise ValueError(msg)
|
||||
return self._crypt_context
|
||||
|
||||
@crypt_context.setter
|
||||
def crypt_context(self, value: "CryptContext") -> None:
|
||||
self._crypt_context = value
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class UserAuth(Base):
|
||||
__tablename__ = "user_auth"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
username: Mapped[str] = mapped_column(String, unique=True)
|
||||
_password_hash: Mapped[str] = mapped_column("password_hash", String(128))
|
||||
|
||||
publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
|
||||
subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
|
||||
receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
|
||||
|
||||
@hybrid_property
|
||||
def password(self) -> None:
|
||||
msg = "Password is write-only"
|
||||
raise AttributeError(msg)
|
||||
|
||||
@password.inplace.setter # type: ignore[arg-type]
|
||||
def _password_setter(self, plain_password: str) -> None:
|
||||
self._password_hash = PasswordHasher().crypt_context.hash(plain_password)
|
||||
|
||||
def verify_password(self, plain_password: str) -> bool:
|
||||
return bool(PasswordHasher().crypt_context.verify(plain_password, self._password_hash))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Display client id and password hash."""
|
||||
return f"'{self.username}' with password hash: {self._password_hash}"
|
||||
|
||||
|
||||
class TopicAuth(Base):
|
||||
__tablename__ = "topic_auth"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
username: Mapped[str] = mapped_column(String, unique=True)
|
||||
|
||||
publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
|
||||
subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
|
||||
receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
|
||||
|
||||
def get_topic_list(self, action: Action) -> list[AllowedTopic]:
|
||||
return cast("list[AllowedTopic]", getattr(self, f"{action}_acl"))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Display client id and password hash."""
|
||||
return f"""'{self.username}':
|
||||
\tpublish: {self.publish_acl}, subscribe: {self.subscribe_acl}, receive: {self.receive_acl}
|
||||
"""
|
|
@ -0,0 +1,111 @@
|
|||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.contrib.auth_db.managers import TopicManager, UserManager
|
||||
from amqtt.contrib.auth_db.models import Base, PasswordHasher
|
||||
from amqtt.errors import MQTTError
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def default_hash_scheme() -> list[str]:
|
||||
"""Create config dataclass defaults."""
|
||||
return ["argon2", "bcrypt", "pbkdf2_sha256", "scrypt"]
|
||||
|
||||
|
||||
class UserAuthDBPlugin(BaseAuthPlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
# access the singleton and set the proper crypt context
|
||||
pwd_hasher = PasswordHasher()
|
||||
pwd_hasher.crypt_context = CryptContext(schemes=self.config.hash_schemes, deprecated="auto")
|
||||
|
||||
self._user_manager = UserManager(self.config.connection)
|
||||
self._engine = create_async_engine(f"{self.config.connection}")
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Sync the schema (if configured)."""
|
||||
if not self.config.sync_schema:
|
||||
return
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Authenticate a client's session."""
|
||||
if not session.username or not session.password:
|
||||
return False
|
||||
|
||||
user_auth = await self._user_manager.get_user_auth(session.username)
|
||||
if not user_auth:
|
||||
return False
|
||||
|
||||
return bool(session.password) and user_auth.verify_password(session.password)
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for DB authentication."""
|
||||
|
||||
connection: str
|
||||
"""SQLAlchemy connection string for the asyncio version of the database connector:
|
||||
|
||||
- `mysql+aiomysql://user:password@host:port/dbname`
|
||||
- `postgresql+asyncpg://user:password@host:port/dbname`
|
||||
- `sqlite+aiosqlite:///dbfilename.db`
|
||||
"""
|
||||
sync_schema: bool = False
|
||||
"""Use SQLAlchemy to create / update the database schema."""
|
||||
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
|
||||
"""list of hash schemes to use for passwords"""
|
||||
|
||||
|
||||
class TopicAuthDBPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self._topic_manager = TopicManager(self.config.connection)
|
||||
self._engine = create_async_engine(f"{self.config.connection}")
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Sync the schema (if configured)."""
|
||||
if not self.config.sync_schema:
|
||||
return
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool | None:
|
||||
if not session or not session.username or not topic:
|
||||
return None
|
||||
|
||||
try:
|
||||
topic_auth = await self._topic_manager.get_topic_auth(session.username)
|
||||
topic_list = getattr(topic_auth, f"{action}_acl")
|
||||
except MQTTError:
|
||||
return False
|
||||
|
||||
return topic in topic_list
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for DB topic filtering."""
|
||||
|
||||
connection: str
|
||||
"""SQLAlchemy connection string for the asyncio version of the database connector:
|
||||
|
||||
- `mysql+aiomysql://user:password@host:port/dbname`
|
||||
- `postgresql+asyncpg://user:password@host:port/dbname`
|
||||
- `sqlite+aiosqlite:///dbfilename.db`
|
||||
"""
|
||||
sync_schema: bool = False
|
||||
"""Use SQLAlchemy to create / update the database schema."""
|
|
@ -0,0 +1,151 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.contrib.auth_db import DBType, db_connection_str
|
||||
from amqtt.contrib.auth_db.managers import TopicManager, UserManager
|
||||
from amqtt.errors import MQTTError
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
topic_app = typer.Typer(no_args_is_help=True)
|
||||
|
||||
|
||||
@topic_app.callback()
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
db_type: Annotated[DBType, typer.Option("--db", "-d", help="db type", count=False)],
|
||||
db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
|
||||
db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
|
||||
db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
|
||||
db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
|
||||
) -> None:
|
||||
"""Command line interface to add / remove topic authorization.
|
||||
|
||||
Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
|
||||
password (if applicable).
|
||||
|
||||
If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.TopicManager` which provides
|
||||
the underlying functionality to this command line interface.
|
||||
"""
|
||||
if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
|
||||
pass
|
||||
elif db_type == DBType.SQLITE and not Path(db_filename).exists():
|
||||
logger.error(f"SQLite option could not find '{db_filename}'")
|
||||
raise typer.Exit(code=1)
|
||||
elif db_type != DBType.SQLITE and not db_username:
|
||||
logger.error("DB access requires a username be provided.")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
|
||||
|
||||
|
||||
@topic_app.command(name="sync")
|
||||
def db_sync(ctx: typer.Context) -> None:
|
||||
"""Create the table and schema for username and topic lists for subscribe, publish or receive.
|
||||
|
||||
Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
|
||||
"""
|
||||
async def run_sync() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
|
||||
mgr = UserManager(connect)
|
||||
try:
|
||||
await mgr.db_sync()
|
||||
except MQTTError as me:
|
||||
logger.critical("Could not sync schema on db.")
|
||||
raise typer.Exit(code=1) from me
|
||||
asyncio.run(run_sync())
|
||||
logger.info("Success: database synced.")
|
||||
|
||||
|
||||
@topic_app.command(name="list")
|
||||
def list_clients(ctx: typer.Context) -> None:
|
||||
"""List all Client IDs (in alphabetical order). Will also display the hashed passwords."""
|
||||
|
||||
async def run_list() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
|
||||
mgr = TopicManager(connect)
|
||||
user_count = 0
|
||||
for user in await mgr.list_topic_auths():
|
||||
user_count += 1
|
||||
logger.info(user)
|
||||
|
||||
if not user_count:
|
||||
logger.info("No client authorizations exist.")
|
||||
|
||||
asyncio.run(run_list())
|
||||
|
||||
|
||||
@topic_app.command(name="add")
|
||||
def add_topic_allowance(
|
||||
ctx: typer.Context,
|
||||
topic: Annotated[str, typer.Argument(help="list of topics", show_default=False)],
|
||||
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client", show_default=False)],
|
||||
action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow", show_default=False)]
|
||||
) -> None:
|
||||
"""Create a new user with a client id and password (prompted)."""
|
||||
async def run_add() -> None:
|
||||
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
|
||||
ctx.obj["filename"])
|
||||
mgr = TopicManager(connect)
|
||||
|
||||
with contextlib.suppress(MQTTError):
|
||||
await mgr.create_topic_auth(client_id)
|
||||
|
||||
topic_auth = await mgr.get_topic_auth(client_id)
|
||||
if not topic_auth:
|
||||
logger.info(f"Topic auth doesn't exist for '{client_id}'")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if topic in [allowed_topic.topic for allowed_topic in topic_auth.get_topic_list(action)]:
|
||||
logger.info(f"Topic '{topic}' already exists for '{action}'.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
await mgr.add_allowed_topic(client_id, topic, action)
|
||||
|
||||
logger.info(f"Success: topic '{topic}' added to {action} for '{client_id}'")
|
||||
|
||||
asyncio.run(run_add())
|
||||
|
||||
|
||||
@topic_app.command(name="rm")
|
||||
def remove_topic_allowance(ctx: typer.Context,
|
||||
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")],
|
||||
action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow")],
|
||||
topic: Annotated[str, typer.Argument(help="list of topics")]
|
||||
) -> None:
|
||||
"""Remove a client from the authentication database."""
|
||||
async def run_remove() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
|
||||
ctx.obj["filename"])
|
||||
mgr = TopicManager(connect)
|
||||
|
||||
topic_auth = await mgr.get_topic_auth(client_id)
|
||||
|
||||
if not topic_auth:
|
||||
logger.info(f"client '{client_id}' doesn't exist.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if topic not in getattr(topic_auth, f"{action}_acl"):
|
||||
logger.info(f"Error: topic '{topic}' not in the {action} allow list for {client_id}.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
await mgr.remove_allowed_topic(client_id, topic, action)
|
||||
except MQTTError as me:
|
||||
logger.info(f"'Error: could not remove '{topic}' for client '{client_id}'.")
|
||||
raise typer.Exit(1) from me
|
||||
|
||||
logger.info(f"Success: removed topic '{topic}' from {action} for '{client_id}'")
|
||||
|
||||
asyncio.run(run_remove())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
topic_app()
|
|
@ -0,0 +1,161 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import click
|
||||
import passlib
|
||||
import typer
|
||||
|
||||
from amqtt.contrib.auth_db import DBType, db_connection_str
|
||||
from amqtt.contrib.auth_db.managers import UserManager
|
||||
from amqtt.errors import MQTTError
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
user_app = typer.Typer(no_args_is_help=True)
|
||||
|
||||
|
||||
@user_app.callback()
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
db_type: Annotated[DBType, typer.Option(..., "--db", "-d", help="db type", show_default=False)],
|
||||
db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
|
||||
db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
|
||||
db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
|
||||
db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
|
||||
) -> None:
|
||||
"""Command line interface to list, create, remove and add clients.
|
||||
|
||||
Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
|
||||
password (if applicable) and the client id's password.
|
||||
|
||||
If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.UserManager` which provides
|
||||
the underlying functionality to this command line interface.
|
||||
"""
|
||||
if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
|
||||
pass
|
||||
elif db_type == DBType.SQLITE and not Path(db_filename).exists():
|
||||
logger.error(f"SQLite option could not find '{db_filename}'")
|
||||
raise typer.Exit(code=1)
|
||||
elif db_type != DBType.SQLITE and not db_username:
|
||||
logger.error("DB access requires a username be provided.")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
|
||||
|
||||
|
||||
@user_app.command(name="sync")
|
||||
def db_sync(ctx: typer.Context) -> None:
|
||||
"""Create the table and schema for username and hashed password.
|
||||
|
||||
Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
|
||||
"""
|
||||
async def run_sync() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
|
||||
mgr = UserManager(connect)
|
||||
try:
|
||||
await mgr.db_sync()
|
||||
except MQTTError as me:
|
||||
logger.critical("Could not sync schema on db.")
|
||||
raise typer.Exit(code=1) from me
|
||||
|
||||
asyncio.run(run_sync())
|
||||
logger.info("Success: database synced.")
|
||||
|
||||
|
||||
@user_app.command(name="list")
|
||||
def list_user_auths(ctx: typer.Context) -> None:
|
||||
"""List all Client IDs (in alphabetical order). Will also display the hashed passwords."""
|
||||
|
||||
async def run_list() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
|
||||
mgr = UserManager(connect)
|
||||
user_count = 0
|
||||
for user in await mgr.list_user_auths():
|
||||
user_count += 1
|
||||
logger.info(user)
|
||||
|
||||
if not user_count:
|
||||
logger.info("No client authentications exist.")
|
||||
|
||||
asyncio.run(run_list())
|
||||
|
||||
|
||||
@user_app.command(name="add")
|
||||
def create_user_auth(
|
||||
ctx: typer.Context,
|
||||
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
|
||||
) -> None:
|
||||
"""Create a new user with a client id and password (prompted)."""
|
||||
async def run_create() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
|
||||
ctx.obj["filename"])
|
||||
mgr = UserManager(connect)
|
||||
client_password = click.prompt("Enter the client's password", hide_input=True)
|
||||
if not client_password.strip():
|
||||
logger.info("Error: client password cannot be empty.")
|
||||
raise typer.Exit(1)
|
||||
try:
|
||||
user = await mgr.create_user_auth(client_id, client_password.strip())
|
||||
except passlib.exc.MissingBackendError as mbe:
|
||||
logger.info(f"Please install backend: {mbe}")
|
||||
raise typer.Exit(code=1) from mbe
|
||||
|
||||
if not user:
|
||||
logger.info(f"Error: could not create user: {client_id}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
logger.info(f"Success: created {user}")
|
||||
|
||||
asyncio.run(run_create())
|
||||
|
||||
|
||||
@user_app.command(name="rm")
|
||||
def remove_user_auth(ctx: typer.Context,
|
||||
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")]) -> None:
|
||||
"""Remove a client from the authentication database."""
|
||||
async def run_remove() -> None:
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
|
||||
ctx.obj["filename"])
|
||||
mgr = UserManager(connect)
|
||||
user = await mgr.get_user_auth(client_id)
|
||||
if not user:
|
||||
logger.info(f"Error: client '{client_id}' does not exist.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not click.confirm(f"Please confirm the removal of '{client_id}'?"):
|
||||
raise typer.Exit(0)
|
||||
|
||||
user = await mgr.delete_user_auth(client_id)
|
||||
if not user:
|
||||
logger.info(f"Error: client '{client_id}' does not exist.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
logger.info(f"Success: '{user.username}' was removed.")
|
||||
|
||||
asyncio.run(run_remove())
|
||||
|
||||
|
||||
@user_app.command(name="pwd")
|
||||
def change_password(
|
||||
ctx: typer.Context,
|
||||
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
|
||||
) -> None:
|
||||
"""Update a user's password (prompted)."""
|
||||
async def run_password() -> None:
|
||||
client_password = click.prompt("Enter the client's new password", hide_input=True)
|
||||
if not client_password.strip():
|
||||
logger.error("Error: client password cannot be empty.")
|
||||
raise typer.Exit(1)
|
||||
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
|
||||
ctx.obj["filename"])
|
||||
mgr = UserManager(connect)
|
||||
await mgr.update_user_auth_password(client_id, client_password.strip())
|
||||
logger.info(f"Success: client '{client_id}' password updated.")
|
||||
|
||||
asyncio.run(run_password())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
user_app()
|
|
@ -0,0 +1,250 @@
|
|||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
from datetime import UTC
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from datetime import timezone
|
||||
UTC = timezone.utc
|
||||
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509 import Certificate, CertificateSigningRequest
|
||||
from cryptography.x509.oid import NameOID
|
||||
|
||||
from amqtt.plugins.base import BaseAuthPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserAuthCertPlugin(BaseAuthPlugin):
|
||||
"""Used a *signed* x509 certificate's `Subject AlternativeName` or `SAN` to verify client authentication.
|
||||
|
||||
Often used for IoT devices, this method provides the most secure form of identification. A root
|
||||
certificate, often referenced as a CA certificate -- either issued by a known authority (such as LetsEncrypt)
|
||||
or a self-signed certificate) is used to sign a private key and certificate for the server. Each device/client
|
||||
also gets a unique private key and certificate signed by the same CA certificate; also included in the device
|
||||
certificate is a 'SAN' or SubjectAlternativeName which is the device's unique identifier.
|
||||
|
||||
Since both server and device certificates are signed by the same CA certificate, the client can
|
||||
verify the server's authenticity; and the server can verify the client's authenticity. And since
|
||||
the device's certificate contains a x509 SAN, the server (with this plugin) can identify the device securely.
|
||||
|
||||
!!! note "URI and Client ID configuration"
|
||||
`uri_domain` configuration must be set to the same uri used to generate the device credentials
|
||||
|
||||
when a device is connecting with private key and certificate, the `client_id` must
|
||||
match the device id used to generate the device credentials.
|
||||
|
||||
Available ore three scripts to help with the key generation and certificate signing: `ca_creds`, `server_creds`
|
||||
and `device_creds`.
|
||||
|
||||
!!! note "Configuring broker & client for using Self-signed root CA"
|
||||
If using self-signed root credentials, the `cafile` configuration for both broker and client need to be
|
||||
configured with `cafile` set to the `ca.crt`.
|
||||
"""
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Verify the client's session using the provided client's x509 certificate."""
|
||||
if not session.ssl_object:
|
||||
return False
|
||||
|
||||
der_cert = session.ssl_object.getpeercert(binary_form=True)
|
||||
if der_cert:
|
||||
cert = x509.load_der_x509_certificate(der_cert, backend=default_backend())
|
||||
|
||||
try:
|
||||
san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
|
||||
uris = san.value.get_values_for_type(x509.UniformResourceIdentifier)
|
||||
|
||||
if self.config.uri_domain not in uris[0]:
|
||||
return False
|
||||
|
||||
pattern = rf"^spiffe://{re.escape(self.config.uri_domain)}/device/([^/]+)$"
|
||||
match = re.match(pattern, uris[0])
|
||||
if not match:
|
||||
return False
|
||||
|
||||
return match.group(1) == session.client_id
|
||||
|
||||
except x509.ExtensionNotFound:
|
||||
logger.warning("No SAN extension found.")
|
||||
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for the CertificateAuthPlugin."""
|
||||
|
||||
uri_domain: str
|
||||
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
|
||||
|
||||
|
||||
def generate_root_creds(country: str, state: str, locality: str,
|
||||
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
"""Generate CA key and certificate."""
|
||||
# generate private key for the server
|
||||
ca_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=4096,
|
||||
)
|
||||
# Create certificate subject and issuer (self-signed)
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
|
||||
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
|
||||
x509.NameAttribute(NameOID.LOCALITY_NAME, locality),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, cn),
|
||||
])
|
||||
|
||||
# 3. Build self-signed certificate
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(ca_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(datetime.now(UTC))
|
||||
.not_valid_after(datetime.now(UTC) + timedelta(days=3650)) # 10 years
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=None),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
digital_signature=False,
|
||||
key_encipherment=False,
|
||||
content_commitment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
return ca_key, cert
|
||||
|
||||
|
||||
def generate_server_csr(country: str, org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
|
||||
"""Generate server private key and server certificate-signing-request."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
csr = (
|
||||
x509.CertificateSigningRequestBuilder()
|
||||
.subject_name(x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, cn),
|
||||
]))
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.DNSName(cn),
|
||||
x509.IPAddress(IPv4Address("127.0.0.1")),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(key, hashes.SHA256())
|
||||
)
|
||||
|
||||
return key, csr
|
||||
|
||||
|
||||
def generate_device_csr(country: str, org_name: str, common_name: str,
|
||||
uri_san: str, dns_san: str
|
||||
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
|
||||
"""Generate a device key and a csr."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
csr = (
|
||||
x509.CertificateSigningRequestBuilder()
|
||||
.subject_name(x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name
|
||||
),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
]))
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.UniformResourceIdentifier(uri_san),
|
||||
x509.DNSName(dns_san),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(key, hashes.SHA256())
|
||||
)
|
||||
|
||||
return key, csr
|
||||
|
||||
|
||||
def sign_csr(csr: CertificateSigningRequest,
|
||||
ca_key: rsa.RSAPrivateKey,
|
||||
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
|
||||
"""Sign a csr with CA credentials."""
|
||||
return (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(csr.subject)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(csr.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(datetime.now(UTC))
|
||||
.not_valid_after(datetime.now(UTC) + timedelta(days=validity_days))
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=False, path_length=None),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value,
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()), # type: ignore[arg-type]
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
|
||||
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
"""Load server key and certificate."""
|
||||
with Path(ca_key_fn).open("rb") as f:
|
||||
ca_key: rsa.RSAPrivateKey = serialization.load_pem_private_key(f.read(), password=None) # type: ignore[assignment]
|
||||
with Path(ca_crt_fn).open("rb") as f:
|
||||
ca_cert = x509.load_pem_x509_certificate(f.read())
|
||||
return ca_key, ca_cert
|
||||
|
||||
|
||||
def write_key_and_crt(key: rsa.RSAPrivateKey, crt: Certificate,
|
||||
prefix: str, path: Path | None = None) -> None:
|
||||
"""Create pem-encoded files for key and certificate."""
|
||||
path = path or Path()
|
||||
|
||||
crt_fn = path / f"{prefix}.crt"
|
||||
key_fn = path / f"{prefix}.key"
|
||||
|
||||
with crt_fn.open("wb") as f:
|
||||
f.write(crt.public_bytes(serialization.Encoding.PEM))
|
||||
with key_fn.open("wb") as f:
|
||||
f.write(key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption()
|
||||
))
|
|
@ -0,0 +1,181 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
try:
|
||||
from enum import StrEnum
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientResponse, ClientSession, FormData
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseMode(StrEnum):
|
||||
STATUS = "status"
|
||||
JSON = "json"
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
class RequestMethod(StrEnum):
|
||||
GET = "get"
|
||||
POST = "post"
|
||||
PUT = "put"
|
||||
|
||||
|
||||
class ParamsMode(StrEnum):
|
||||
JSON = "json"
|
||||
FORM = "form"
|
||||
|
||||
|
||||
class ACLError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
HTTP_2xx_MIN = 200
|
||||
HTTP_2xx_MAX = 299
|
||||
|
||||
HTTP_4xx_MIN = 400
|
||||
HTTP_4xx_MAX = 499
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpConfig:
|
||||
"""Configuration for the HTTP Auth & ACL Plugin."""
|
||||
|
||||
host: str
|
||||
"""hostname of the server for the auth & acl check"""
|
||||
port: int
|
||||
"""port of the server for the auth & acl check"""
|
||||
request_method: RequestMethod = RequestMethod.GET
|
||||
"""send the request as a GET, POST or PUT"""
|
||||
params_mode: ParamsMode = ParamsMode.JSON # see docs/plugins/http.md for additional details
|
||||
"""send the request with `JSON` or `FORM` data. *additional details below*"""
|
||||
response_mode: ResponseMode = ResponseMode.JSON # see docs/plugins/http.md for additional details
|
||||
"""expected response from the auth/acl server. `STATUS` (code), `JSON`, or `TEXT`. *additional details below*"""
|
||||
with_tls: bool = False
|
||||
"""http or https"""
|
||||
user_agent: str = "amqtt"
|
||||
"""the 'User-Agent' header sent along with the request"""
|
||||
superuser_uri: str | None = None
|
||||
"""URI to verify if the user is a superuser (e.g. '/superuser'), `None` if superuser is not supported"""
|
||||
timeout: int = 5
|
||||
"""duration, in seconds, to wait for the HTTP server to respond"""
|
||||
|
||||
|
||||
class AuthHttpPlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
|
||||
|
||||
match self.config.request_method:
|
||||
case RequestMethod.GET:
|
||||
self.method = self.http.get
|
||||
case RequestMethod.PUT:
|
||||
self.method = self.http.put
|
||||
case _:
|
||||
self.method = self.http.post
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
await self.http.close()
|
||||
|
||||
@staticmethod
|
||||
def _is_2xx(r: ClientResponse) -> bool:
|
||||
return HTTP_2xx_MIN <= r.status <= HTTP_2xx_MAX
|
||||
|
||||
@staticmethod
|
||||
def _is_4xx(r: ClientResponse) -> bool:
|
||||
return HTTP_4xx_MIN <= r.status <= HTTP_4xx_MAX
|
||||
|
||||
def _get_params(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
match self.config.params_mode:
|
||||
case ParamsMode.FORM:
|
||||
match self.config.request_method:
|
||||
case RequestMethod.GET:
|
||||
kwargs = {"params": payload}
|
||||
case _: # POST, PUT
|
||||
d: Any = FormData(payload)
|
||||
kwargs = {"data": d}
|
||||
case _: # JSON
|
||||
kwargs = {"json": payload}
|
||||
return kwargs
|
||||
|
||||
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911
|
||||
|
||||
kwargs = self._get_params(payload)
|
||||
|
||||
async with self.method(url, **kwargs) as r:
|
||||
logger.debug(f"http request returned {r.status}")
|
||||
|
||||
match self.config.response_mode:
|
||||
case ResponseMode.TEXT:
|
||||
return self._is_2xx(r) and (await r.text()).lower() == "ok"
|
||||
case ResponseMode.STATUS:
|
||||
if self._is_2xx(r):
|
||||
return True
|
||||
if self._is_4xx(r):
|
||||
return False
|
||||
# any other code
|
||||
return None
|
||||
case _:
|
||||
if not self._is_2xx(r):
|
||||
return False
|
||||
data: dict[str, Any] = await r.json()
|
||||
data = {k.lower(): v for k, v in data.items()}
|
||||
return data.get("ok", None)
|
||||
|
||||
def get_url(self, uri: str) -> str:
|
||||
return f"{'https' if self.config.with_tls else 'http'}://{self.config.host}:{self.config.port}{uri}"
|
||||
|
||||
|
||||
class UserAuthHttpPlugin(AuthHttpPlugin, BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
d = {"username": session.username, "password": session.password, "client_id": session.client_id}
|
||||
return await self._send_request(self.get_url(self.config.user_uri), d)
|
||||
|
||||
@dataclass
|
||||
class Config(HttpConfig):
|
||||
"""Configuration for the HTTP Auth Plugin."""
|
||||
|
||||
user_uri: str = "/user"
|
||||
"""URI of the auth check."""
|
||||
|
||||
|
||||
class TopicAuthHttpPlugin(AuthHttpPlugin, BaseTopicPlugin):
|
||||
|
||||
async def topic_filtering(self, *,
|
||||
session: Session | None = None,
|
||||
topic: str | None = None,
|
||||
action: Action | None = None) -> bool | None:
|
||||
if not session:
|
||||
return None
|
||||
acc = 0
|
||||
match action:
|
||||
case Action.PUBLISH:
|
||||
acc = 2
|
||||
case Action.SUBSCRIBE:
|
||||
acc = 4
|
||||
case Action.RECEIVE:
|
||||
acc = 1
|
||||
|
||||
d = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc}
|
||||
return await self._send_request(self.get_url(self.config.topic_uri), d)
|
||||
|
||||
@dataclass
|
||||
class Config(HttpConfig):
|
||||
"""Configuration for the HTTP Topic Plugin."""
|
||||
|
||||
topic_uri: str = "/acl"
|
||||
"""URI of the topic check."""
|
|
@ -0,0 +1,120 @@
|
|||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
|
||||
import jwt
|
||||
|
||||
try:
|
||||
from enum import StrEnum
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.plugins import TopicMatcher
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Algorithms(StrEnum):
|
||||
ES256 = "ES256"
|
||||
ES256K = "ES256K"
|
||||
ES384 = "ES384"
|
||||
ES512 = "ES512"
|
||||
ES521 = "ES521"
|
||||
EdDSA = "EdDSA"
|
||||
HS256 = "HS256"
|
||||
HS384 = "HS384"
|
||||
HS512 = "HS512"
|
||||
PS256 = "PS256"
|
||||
PS384 = "PS384"
|
||||
PS512 = "PS512"
|
||||
RS256 = "RS256"
|
||||
RS384 = "RS384"
|
||||
RS512 = "RS512"
|
||||
|
||||
|
||||
class UserAuthJwtPlugin(BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
|
||||
if not session.username or not session.password:
|
||||
return None
|
||||
|
||||
try:
|
||||
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
||||
return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(f"jwt for '{session.username}' is expired")
|
||||
return False
|
||||
except jwt.InvalidTokenError:
|
||||
logger.debug(f"jwt for '{session.username}' is invalid")
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for the JWT user authentication."""
|
||||
|
||||
secret_key: str
|
||||
"""Secret key to decrypt the token."""
|
||||
user_claim: str
|
||||
"""Payload key for user name."""
|
||||
algorithm: str = "HS256"
|
||||
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
|
||||
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
|
||||
|
||||
|
||||
class TopicAuthJwtPlugin(BaseTopicPlugin):
|
||||
|
||||
_topic_jwt_claims: ClassVar = {
|
||||
Action.PUBLISH: "publish_claim",
|
||||
Action.SUBSCRIBE: "subscribe_claim",
|
||||
Action.RECEIVE: "receive_claim",
|
||||
}
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.topic_matcher = TopicMatcher()
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool | None:
|
||||
|
||||
if not session or not topic or not action:
|
||||
return None
|
||||
|
||||
if not session.password:
|
||||
return None
|
||||
|
||||
try:
|
||||
decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
|
||||
claim = getattr(self.config, self._topic_jwt_claims[action])
|
||||
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(f"jwt for '{session.username}' is expired")
|
||||
return False
|
||||
except jwt.InvalidTokenError:
|
||||
logger.debug(f"jwt for '{session.username}' is invalid")
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for the JWT topic authorization."""
|
||||
|
||||
secret_key: str
|
||||
"""Secret key to decrypt the token."""
|
||||
publish_claim: str
|
||||
"""Payload key for contains a list of permissible publish topics."""
|
||||
subscribe_claim: str
|
||||
"""Payload key for contains a list of permissible subscribe topics."""
|
||||
receive_claim: str
|
||||
"""Payload key for contains a list of permissible receive topics."""
|
||||
algorithm: str = "HS256"
|
||||
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
|
||||
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
|
|
@ -0,0 +1,138 @@
|
|||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
|
||||
import ldap
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.errors import PluginInitError
|
||||
from amqtt.plugins import TopicMatcher
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LdapConfig:
|
||||
"""Configuration for the LDAP Plugins."""
|
||||
|
||||
server: str
|
||||
"""uri formatted server location. e.g `ldap://localhost:389`"""
|
||||
base_dn: str
|
||||
"""distinguished name (dn) of the ldap server. e.g. `dc=amqtt,dc=io`"""
|
||||
user_attribute: str
|
||||
"""attribute in ldap entry to match the username against"""
|
||||
bind_dn: str
|
||||
"""distinguished name (dn) of known, preferably read-only, user. e.g. `cn=admin,dc=amqtt,dc=io`"""
|
||||
bind_password: str
|
||||
"""password for known, preferably read-only, user"""
|
||||
|
||||
|
||||
class AuthLdapPlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.conn = ldap.initialize(self.config.server)
|
||||
self.conn.protocol_version = ldap.VERSION3 # pylint: disable=E1101
|
||||
try:
|
||||
self.conn.simple_bind_s(self.config.bind_dn, self.config.bind_password)
|
||||
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
|
||||
raise PluginInitError(self.__class__) from e
|
||||
|
||||
|
||||
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
|
||||
"""Plugin to authenticate a user with an LDAP directory server."""
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
|
||||
# use our initial creds to see if the user exists
|
||||
search_filter = f"({self.config.user_attribute}={session.username})"
|
||||
result = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, ["dn"]) # pylint: disable=E1101
|
||||
if not result:
|
||||
logger.debug(f"user not found: {session.username}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# `search_s` responds with list of tuples: (dn, entry); first in list is our match
|
||||
user_dn = result[0][0]
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
try:
|
||||
user_conn = ldap.initialize(self.config.server)
|
||||
user_conn.simple_bind_s(user_dn, session.password)
|
||||
except ldap.INVALID_CREDENTIALS: # pylint: disable=E1101
|
||||
logger.debug(f"invalid credentials for '{session.username}'")
|
||||
return False
|
||||
except ldap.LDAPError as e: # pylint: disable=E1101
|
||||
logger.debug(f"LDAP error during user bind: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@dataclass
|
||||
class Config(LdapConfig):
|
||||
"""Configuration for the User Auth LDAP Plugin."""
|
||||
|
||||
|
||||
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
||||
"""Plugin to authenticate a user with an LDAP directory server."""
|
||||
|
||||
_action_attr_map: ClassVar = {
|
||||
Action.PUBLISH: "publish_attribute",
|
||||
Action.SUBSCRIBE: "subscribe_attribute",
|
||||
Action.RECEIVE: "receive_attribute"
|
||||
}
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.topic_matcher = TopicMatcher()
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool | None:
|
||||
|
||||
# if not provided needed criteria, can't properly evaluate topic filtering
|
||||
if not session or not action or not topic:
|
||||
return None
|
||||
|
||||
search_filter = f"({self.config.user_attribute}={session.username})"
|
||||
attrs = [
|
||||
"cn",
|
||||
self.config.publish_attribute,
|
||||
self.config.subscribe_attribute,
|
||||
self.config.receive_attribute
|
||||
]
|
||||
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
|
||||
|
||||
if not results:
|
||||
logger.debug(f"user not found: {session.username}")
|
||||
return False
|
||||
|
||||
if len(results) > 1:
|
||||
found_users = [dn for dn, _ in results]
|
||||
logger.debug(f"multiple users found: {', '.join(found_users)}")
|
||||
return False
|
||||
|
||||
dn, entry = results[0]
|
||||
|
||||
ldap_attribute = getattr(self.config, self._action_attr_map[action])
|
||||
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
|
||||
logger.debug(f"DN: {dn} - {ldap_attribute}={topic_filters}")
|
||||
|
||||
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
|
||||
|
||||
@dataclass
|
||||
class Config(LdapConfig):
|
||||
"""Configuration for the LDAPAuthPlugin."""
|
||||
|
||||
publish_attribute: str
|
||||
"""LDAP attribute which contains a list of permissible publish topics."""
|
||||
subscribe_attribute: str
|
||||
"""LDAP attribute which contains a list of permissible subscribe topics."""
|
||||
receive_attribute: str
|
||||
"""LDAP attribute which contains a list of permissible receive topics."""
|
|
@ -0,0 +1,264 @@
|
|||
from dataclasses import dataclass
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from amqtt.broker import BrokerContext, RetainedApplicationMessage
|
||||
from amqtt.contrib import DataClassListJSON
|
||||
from amqtt.errors import PluginError
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetainedMessage:
|
||||
topic: str
|
||||
data: str
|
||||
qos: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Subscription:
|
||||
topic: str
|
||||
qos: int
|
||||
|
||||
|
||||
class StoredSession(Base):
|
||||
__tablename__ = "stored_sessions"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
client_id: Mapped[str] = mapped_column(String)
|
||||
|
||||
clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
|
||||
|
||||
will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
|
||||
will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
|
||||
will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
|
||||
will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
|
||||
|
||||
keep_alive: Mapped[int] = mapped_column(Integer, default=0)
|
||||
retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
|
||||
subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)
|
||||
|
||||
|
||||
class StoredMessage(Base):
|
||||
__tablename__ = "stored_messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
topic: Mapped[str] = mapped_column(String)
|
||||
data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
|
||||
qos: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
|
||||
class SessionDBPlugin(BasePlugin[BrokerContext]):
|
||||
"""Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
|
||||
|
||||
Configuration:
|
||||
- file *(string)* path & filename to store the session db. default: `amqtt.db`
|
||||
- clear_on_shutdown *(bool)* if the broker shutdowns down normally, don't retain any information. default: `True`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
# bypass the `test_plugins_correct_has_attr` until it can be updated
|
||||
if not hasattr(self.config, "file"):
|
||||
logger.warning("`Config` is missing a `file` attribute")
|
||||
return
|
||||
|
||||
self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
|
||||
|
||||
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
|
||||
stored_session = await db_session.scalar(stmt)
|
||||
if stored_session is None:
|
||||
stored_session = StoredSession(client_id=client_id)
|
||||
db_session.add(stored_session)
|
||||
await db_session.flush()
|
||||
return stored_session
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
|
||||
|
||||
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
|
||||
stored_message = await db_session.scalar(stmt)
|
||||
if stored_message is None:
|
||||
stored_message = StoredMessage(topic=topic)
|
||||
db_session.add(stored_message)
|
||||
await db_session.flush()
|
||||
return stored_message
|
||||
|
||||
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
|
||||
"""Search to see if session already exists."""
|
||||
# if client id doesn't exist, create (can ignore if session is anonymous)
|
||||
# update session information (will, clean_session, etc)
|
||||
|
||||
# don't store session information for clean or anonymous sessions
|
||||
if client_session.clean_session in (None, True) or client_session.is_anonymous:
|
||||
return
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stored_session = await self._get_or_create_session(db_session, client_id)
|
||||
|
||||
stored_session.clean_session = client_session.clean_session
|
||||
stored_session.will_flag = client_session.will_flag
|
||||
stored_session.will_message = client_session.will_message # type: ignore[assignment]
|
||||
stored_session.will_qos = client_session.will_qos
|
||||
stored_session.will_retain = client_session.will_retain
|
||||
stored_session.will_topic = client_session.will_topic
|
||||
stored_session.keep_alive = client_session.keep_alive
|
||||
|
||||
await db_session.flush()
|
||||
|
||||
async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
|
||||
"""Create/update subscription if clean session = false."""
|
||||
session = self.context.get_session(client_id)
|
||||
if not session:
|
||||
logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
|
||||
return
|
||||
|
||||
if session.clean_session:
|
||||
return
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
# stored sessions shouldn't need to be created here, but we'll use the same helper...
|
||||
stored_session = await self._get_or_create_session(db_session, client_id)
|
||||
stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)]
|
||||
await db_session.flush()
|
||||
|
||||
async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
|
||||
"""Remove subscription if clean session = false."""
|
||||
|
||||
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
|
||||
"""Update to retained messages.
|
||||
|
||||
if retained_message.data is None or '', the message is being cleared
|
||||
"""
|
||||
# if client_id is valid, the retained message is for a disconnected client
|
||||
if client_id is not None:
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
# stored sessions shouldn't need to be created here, but we'll use the same helper...
|
||||
stored_session = await self._get_or_create_session(db_session, client_id)
|
||||
stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic,
|
||||
retained_message.data.decode(),
|
||||
retained_message.qos or 0)]
|
||||
await db_session.flush()
|
||||
return
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
# if the retained message has data, we need to store/update for the topic
|
||||
if retained_message.data:
|
||||
client_message = await self._get_or_create_message(db_session, retained_message.topic)
|
||||
client_message.data = retained_message.data # type: ignore[assignment]
|
||||
client_message.qos = retained_message.qos or 0
|
||||
await db_session.flush()
|
||||
return
|
||||
|
||||
# if there is no data, clear the stored message (if exists) for the topic
|
||||
stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
|
||||
topic_message = await db_session.scalar(stmt)
|
||||
if topic_message is not None:
|
||||
await db_session.delete(topic_message)
|
||||
await db_session.flush()
|
||||
return
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Initialize the database and db connection."""
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def on_broker_post_start(self) -> None:
|
||||
"""Load subscriptions."""
|
||||
if len(self.context.subscriptions) > 0:
|
||||
msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
|
||||
raise PluginError(msg)
|
||||
|
||||
if len(list(self.context.sessions)) > 0:
|
||||
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
|
||||
raise PluginError(msg)
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(StoredSession)
|
||||
stored_sessions = await db_session.execute(stmt)
|
||||
|
||||
restored_sessions = 0
|
||||
for stored_session in stored_sessions.scalars():
|
||||
await self.context.add_subscription(stored_session.client_id, None, None)
|
||||
for subscription in stored_session.subscriptions:
|
||||
await self.context.add_subscription(stored_session.client_id,
|
||||
subscription.topic,
|
||||
subscription.qos)
|
||||
session = self.context.get_session(stored_session.client_id)
|
||||
if not session:
|
||||
continue
|
||||
session.clean_session = stored_session.clean_session
|
||||
session.will_flag = stored_session.will_flag
|
||||
session.will_message = stored_session.will_message
|
||||
session.will_qos = stored_session.will_qos
|
||||
session.will_retain = stored_session.will_retain
|
||||
session.will_topic = stored_session.will_topic
|
||||
session.keep_alive = stored_session.keep_alive
|
||||
|
||||
for message in stored_session.retained:
|
||||
retained_message = RetainedApplicationMessage(
|
||||
source_session=None,
|
||||
topic=message.topic,
|
||||
data=message.data.encode(),
|
||||
qos=message.qos
|
||||
)
|
||||
await session.retained_messages.put(retained_message)
|
||||
restored_sessions += 1
|
||||
|
||||
stmt = select(StoredMessage)
|
||||
stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)
|
||||
|
||||
restored_messages = 0
|
||||
retained_messages = self.context.retained_messages
|
||||
for stored_message in stored_messages.scalars():
|
||||
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
|
||||
source_session=None,
|
||||
topic=stored_message.topic,
|
||||
data=stored_message.data or b"",
|
||||
qos=stored_message.qos
|
||||
))
|
||||
restored_messages += 1
|
||||
logger.info(f"Retained messages restored: {restored_messages}")
|
||||
|
||||
logger.info(f"Restored {restored_sessions} sessions.")
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
"""Clean up the db connection."""
|
||||
await self._engine.dispose()
|
||||
|
||||
async def on_broker_post_shutdown(self) -> None:
|
||||
|
||||
if self.config.clear_on_shutdown and self.config.file.exists():
|
||||
self.config.file.unlink()
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration variables."""
|
||||
|
||||
file: str | Path = "amqtt.db"
|
||||
"""path & filename to store the sqlite session db."""
|
||||
clear_on_shutdown: bool = True
|
||||
"""if the broker shutdowns down normally, don't retain any information."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create `Path` from string path."""
|
||||
if isinstance(self.file, str):
|
||||
self.file = Path(self.file)
|
|
@ -0,0 +1,6 @@
|
|||
"""Module for the shadow state plugin."""
|
||||
|
||||
from .plugin import ShadowPlugin, ShadowTopicAuthPlugin
|
||||
from .states import ShadowOperation
|
||||
|
||||
__all__ = ["ShadowOperation", "ShadowPlugin", "ShadowTopicAuthPlugin"]
|
|
@ -0,0 +1,113 @@
|
|||
from collections.abc import MutableMapping
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from amqtt.contrib.shadows.states import MetaTimestamp, ShadowOperation, State, StateDocument
|
||||
|
||||
|
||||
def asdict_no_none(obj: Any) -> Any:
|
||||
"""Create dictionary from dataclass, but eliminate any key set to `None`."""
|
||||
if is_dataclass(obj):
|
||||
result = {}
|
||||
for f in fields(obj):
|
||||
value = getattr(obj, f.name)
|
||||
if value is not None:
|
||||
result[f.name] = asdict_no_none(value)
|
||||
return result
|
||||
if isinstance(obj, list):
|
||||
return [asdict_no_none(item) for item in obj if item is not None]
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
key: asdict_no_none(value)
|
||||
for key, value in obj.items()
|
||||
if value is not None
|
||||
}
|
||||
return obj
|
||||
|
||||
|
||||
def create_shadow_topic(device_id: str, shadow_name: str, message_op: "ShadowOperation") -> str:
|
||||
"""Create a shadow topic for message type."""
|
||||
return f"$shadow/{device_id}/{shadow_name}/{message_op}"
|
||||
|
||||
|
||||
class ShadowMessage:
|
||||
def to_message(self) -> bytes:
|
||||
return json.dumps(asdict_no_none(self)).encode("utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetAcceptedMessage(ShadowMessage):
|
||||
state: State[dict[str, Any]]
|
||||
metadata: State[MetaTimestamp]
|
||||
timestamp: int
|
||||
version: int
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetRejectedMessage(ShadowMessage):
|
||||
code: int
|
||||
message: str
|
||||
timestamp: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateAcceptedMessage(ShadowMessage):
|
||||
state: State[dict[str, Any]]
|
||||
metadata: State[MetaTimestamp]
|
||||
timestamp: int
|
||||
version: int
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_ACCEPT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateRejectedMessage(ShadowMessage):
|
||||
code: int
|
||||
message: str
|
||||
timestamp: int
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateDeltaMessage(ShadowMessage):
|
||||
state: MutableMapping[str, Any]
|
||||
metadata: MutableMapping[str, Any]
|
||||
timestamp: int
|
||||
version: int
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
|
||||
|
||||
|
||||
class UpdateIotaMessage(UpdateDeltaMessage):
|
||||
"""Same format, corollary name."""
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateDocumentMessage(ShadowMessage):
|
||||
previous: StateDocument
|
||||
current: StateDocument
|
||||
timestamp: int
|
||||
|
||||
@staticmethod
|
||||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DOCUMENTS)
|
|
@ -0,0 +1,140 @@
|
|||
from collections.abc import Sequence
|
||||
from dataclasses import asdict
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, CheckConstraint, Integer, String, UniqueConstraint, desc, event, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, make_transient, mapped_column
|
||||
|
||||
from amqtt.contrib.shadows.states import StateDocument
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShadowUpdateError(Exception):
|
||||
def __init__(self, message: str = "updating an existing Shadow is not allowed") -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ShadowBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
async def sync_shadow_base(connection: AsyncConnection) -> None:
|
||||
"""Create tables and table schemas."""
|
||||
await connection.run_sync(ShadowBase.metadata.create_all)
|
||||
|
||||
|
||||
def default_state_document() -> dict[str, Any]:
|
||||
"""Create a default (empty) state document, factory for model field."""
|
||||
return asdict(StateDocument())
|
||||
|
||||
|
||||
class Shadow(ShadowBase):
|
||||
__tablename__ = "shadows_shadow"
|
||||
|
||||
id: Mapped[str | None] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
|
||||
device_id: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
_state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)
|
||||
|
||||
created_at: Mapped[int] = mapped_column(Integer, default=lambda: int(time.time()), nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("version > 0", name="check_quantity_positive"),
|
||||
UniqueConstraint("device_id", "name", "version", name="uq_device_id_name_version"),
|
||||
)
|
||||
|
||||
@property
|
||||
def state(self) -> StateDocument:
|
||||
if not self._state:
|
||||
return StateDocument()
|
||||
return StateDocument.from_dict(self._state)
|
||||
|
||||
@state.setter
|
||||
def state(self, value: StateDocument) -> None:
|
||||
self._state = asdict(value)
|
||||
|
||||
@classmethod
|
||||
async def latest_version(cls, session: AsyncSession, device_id: str, name: str) -> Optional["Shadow"]:
|
||||
"""Get the latest version of the shadow associated with the device and name."""
|
||||
stmt = (
|
||||
select(cls).where(
|
||||
cls.device_id == device_id,
|
||||
cls.name == name
|
||||
).order_by(desc(cls.version)).limit(1)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
async def all(cls, session: AsyncSession, device_id: str, name: str) -> Sequence["Shadow"]:
|
||||
"""Return a list of all shadows associated with the device and name."""
|
||||
stmt = (
|
||||
select(cls).where(
|
||||
cls.device_id == device_id,
|
||||
cls.name == name
|
||||
).order_by(desc(cls.version)))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@event.listens_for(Shadow, "before_insert")
|
||||
def assign_incremental_version(_: Mapper[Any], connection: Session, target: "Shadow") -> None:
|
||||
"""Get the latest version of the state document."""
|
||||
stmt = (
|
||||
select(func.max(Shadow.version))
|
||||
.where(
|
||||
Shadow.device_id == target.device_id,
|
||||
Shadow.name == target.name
|
||||
)
|
||||
)
|
||||
result = connection.execute(stmt).scalar_one_or_none()
|
||||
target.version = (result or 0) + 1
|
||||
|
||||
|
||||
@event.listens_for(Shadow, "before_update")
|
||||
def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow") -> None:
|
||||
"""Prevent existing shadow from being updated."""
|
||||
raise ShadowUpdateError
|
||||
|
||||
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
|
||||
"""Force a shadow to insert a new version, instead of updating an existing."""
|
||||
# Make a copy of the dirty set so we can safely mutate the session
|
||||
dirty = list(session.dirty)
|
||||
|
||||
for obj in dirty:
|
||||
if not session.is_modified(obj, include_collections=False):
|
||||
continue # skip unchanged
|
||||
|
||||
# You can scope this to a particular class
|
||||
if not isinstance(obj, Shadow):
|
||||
continue
|
||||
|
||||
# Clone logic: convert update into insert
|
||||
session.expunge(obj) # remove from session
|
||||
make_transient(obj) # remove identity and history
|
||||
obj.id = "" # clear primary key
|
||||
obj.version += 1 # bump version or modify fields
|
||||
|
||||
session.add(obj) # re-add as new object
|
||||
|
||||
|
||||
_listener_example = '''#
|
||||
# @event.listens_for(Shadow, "before_insert")
|
||||
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
|
||||
# """Listen for insertion and convert state document to json."""
|
||||
# if not isinstance(target.state, StateDocument):
|
||||
# msg = "'state' field needs to be a StateDocument"
|
||||
# raise TypeError(msg)
|
||||
#
|
||||
# target.state = target.state.to_dict()
|
||||
'''
|
|
@ -0,0 +1,198 @@
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.contrib.shadows.messages import (
|
||||
GetAcceptedMessage,
|
||||
GetRejectedMessage,
|
||||
UpdateAcceptedMessage,
|
||||
UpdateDeltaMessage,
|
||||
UpdateDocumentMessage,
|
||||
UpdateIotaMessage,
|
||||
)
|
||||
from amqtt.contrib.shadows.models import Shadow, sync_shadow_base
|
||||
from amqtt.contrib.shadows.states import (
|
||||
ShadowOperation,
|
||||
StateDocument,
|
||||
calculate_delta_update,
|
||||
calculate_iota_update,
|
||||
)
|
||||
from amqtt.plugins.base import BasePlugin, BaseTopicPlugin
|
||||
from amqtt.session import ApplicationMessage, Session
|
||||
|
||||
shadow_topic_re = re.compile(r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)")
|
||||
|
||||
DeviceID = str
|
||||
ShadowName = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShadowTopic:
|
||||
device_id: DeviceID
|
||||
name: ShadowName
|
||||
message_op: ShadowOperation
|
||||
|
||||
|
||||
def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
|
||||
"""Nested defaultdict for shadow cache."""
|
||||
return defaultdict(shadow_dict) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class ShadowPlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
self._shadows: dict[DeviceID, dict[ShadowName, StateDocument]] = defaultdict(dict)
|
||||
|
||||
self._engine = create_async_engine(self.config.connection)
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Sync the schema."""
|
||||
async with self._engine.begin() as conn:
|
||||
await sync_shadow_base(conn)
|
||||
|
||||
@staticmethod
|
||||
def shadow_topic_match(topic: str) -> ShadowTopic | None:
|
||||
"""Check if topic matches the shadow topic format."""
|
||||
# pattern is "$shadow/<username>/<shadow_name>/get, update, etc
|
||||
match = shadow_topic_re.search(topic)
|
||||
if match:
|
||||
groups = match.groupdict()
|
||||
return ShadowTopic(groups["client_id"], groups["shadow_name"], ShadowOperation(groups["request"]))
|
||||
return None
|
||||
|
||||
async def _handle_get(self, st: ShadowTopic) -> None:
|
||||
"""Send 'accepted."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
|
||||
if not shadow:
|
||||
reject_msg = GetRejectedMessage(
|
||||
code=404,
|
||||
message="shadow not found",
|
||||
)
|
||||
await self.context.broadcast_message(reject_msg.topic(st.device_id, st.name), reject_msg.to_message())
|
||||
return
|
||||
|
||||
accept_msg = GetAcceptedMessage(
|
||||
state=shadow.state.state,
|
||||
metadata=shadow.state.metadata,
|
||||
timestamp=shadow.created_at,
|
||||
version=shadow.version
|
||||
)
|
||||
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
|
||||
|
||||
async def _handle_update(self, st: ShadowTopic, update: dict[str, Any]) -> None:
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
|
||||
if not shadow:
|
||||
shadow = Shadow(device_id=st.device_id, name=st.name)
|
||||
|
||||
state_update = StateDocument.from_dict(update)
|
||||
|
||||
prev_state = shadow.state or StateDocument()
|
||||
prev_state.version = shadow.version or 0 # only required when generating shadow messages
|
||||
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
|
||||
|
||||
next_state = prev_state + state_update
|
||||
|
||||
shadow.state = next_state
|
||||
db_session.add(shadow)
|
||||
await db_session.commit()
|
||||
|
||||
next_state.version = shadow.version
|
||||
next_state.timestamp = shadow.created_at
|
||||
|
||||
accept_msg = UpdateAcceptedMessage(
|
||||
state=next_state.state,
|
||||
metadata=next_state.metadata,
|
||||
timestamp=123,
|
||||
version=1
|
||||
)
|
||||
|
||||
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
|
||||
|
||||
delta_msg = UpdateDeltaMessage(
|
||||
state=calculate_delta_update(next_state.state.desired, next_state.state.reported),
|
||||
metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
|
||||
version=shadow.version,
|
||||
timestamp=shadow.created_at
|
||||
)
|
||||
await self.context.broadcast_message(delta_msg.topic(st.device_id, st.name), delta_msg.to_message())
|
||||
|
||||
iota_msg = UpdateIotaMessage(
|
||||
state=calculate_iota_update(next_state.state.desired, next_state.state.reported),
|
||||
metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
|
||||
version=shadow.version,
|
||||
timestamp=shadow.created_at
|
||||
)
|
||||
await self.context.broadcast_message(iota_msg.topic(st.device_id, st.name), iota_msg.to_message())
|
||||
|
||||
doc_msg = UpdateDocumentMessage(
|
||||
previous=prev_state,
|
||||
current=next_state,
|
||||
timestamp=shadow.created_at
|
||||
)
|
||||
|
||||
await self.context.broadcast_message(doc_msg.topic(st.device_id, st.name), doc_msg.to_message())
|
||||
|
||||
async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None:
|
||||
"""Process a message that was received from a client."""
|
||||
topic = message.topic
|
||||
if not topic.startswith("$shadow"): # this is less overhead than do the full regular expression match
|
||||
return
|
||||
|
||||
if not (shadow_topic := self.shadow_topic_match(topic)):
|
||||
return
|
||||
|
||||
match shadow_topic.message_op:
|
||||
|
||||
case ShadowOperation.GET:
|
||||
await self._handle_get(shadow_topic)
|
||||
case ShadowOperation.UPDATE:
|
||||
await self._handle_update(shadow_topic, json.loads(message.data.decode("utf-8")))
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for shadow plugin."""
|
||||
|
||||
connection: str
|
||||
"""SQLAlchemy connection string for the asyncio version of the database connector:
|
||||
|
||||
- `mysql+aiomysql://user:password@host:port/dbname`
|
||||
- `postgresql+asyncpg://user:password@host:port/dbname`
|
||||
- `sqlite+aiosqlite:///dbfilename.db`
|
||||
"""
|
||||
|
||||
|
||||
class ShadowTopicAuthPlugin(BaseTopicPlugin):
|
||||
|
||||
async def topic_filtering(self, *,
|
||||
session: Session | None = None,
|
||||
topic: str | None = None,
|
||||
action: Action | None = None) -> bool | None:
|
||||
|
||||
session = session or Session()
|
||||
if not topic:
|
||||
return False
|
||||
|
||||
shadow_topic = ShadowPlugin.shadow_topic_match(topic)
|
||||
|
||||
if not shadow_topic:
|
||||
return False
|
||||
|
||||
return shadow_topic.device_id == session.username or session.username in self.config.superusers
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for only allowing devices access to their own shadow topics."""
|
||||
|
||||
superusers: list[str] = field(default_factory=list)
|
||||
"""A list of one or more usernames that can write to any device topic,
|
||||
primarily for the central app sending updates to devices."""
|
|
@ -0,0 +1,206 @@
|
|||
from collections import Counter
|
||||
from collections.abc import MutableMapping
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
try:
|
||||
from enum import StrEnum
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
import time
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from mergedeep import merge
|
||||
|
||||
C = TypeVar("C", bound=Any)
|
||||
|
||||
|
||||
class StateError(Exception):
|
||||
def __init__(self, msg: str = "'state' field is required") -> None:
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaTimestamp:
|
||||
timestamp: int = 0
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare timestamps."""
|
||||
if isinstance(other, int):
|
||||
return self.timestamp == other
|
||||
if isinstance(other, self.__class__):
|
||||
return self.timestamp == other.timestamp
|
||||
msg = "needs to be int or MetaTimestamp"
|
||||
raise ValueError(msg)
|
||||
|
||||
# numeric operations to make this dataclass transparent
|
||||
def __abs__(self) -> int:
|
||||
"""Absolute timestamp."""
|
||||
return self.timestamp
|
||||
|
||||
def __add__(self, other: int) -> int:
|
||||
"""Add to a timestamp."""
|
||||
return self.timestamp + other
|
||||
|
||||
def __sub__(self, other: int) -> int:
|
||||
"""Subtract from a timestamp."""
|
||||
return self.timestamp - other
|
||||
|
||||
def __mul__(self, other: int) -> int:
|
||||
"""Multiply a timestamp."""
|
||||
return self.timestamp * other
|
||||
|
||||
def __float__(self) -> float:
|
||||
"""Convert timestamp to float."""
|
||||
return float(self.timestamp)
|
||||
|
||||
def __int__(self) -> int:
|
||||
"""Convert timestamp to int."""
|
||||
return int(self.timestamp)
|
||||
|
||||
def __lt__(self, other: int) -> bool:
|
||||
"""Compare timestamp."""
|
||||
return self.timestamp < other
|
||||
|
||||
def __le__(self, other: int) -> bool:
|
||||
"""Compare timestamp."""
|
||||
return self.timestamp <= other
|
||||
|
||||
def __gt__(self, other: int) -> bool:
|
||||
"""Compare timestamp."""
|
||||
return self.timestamp > other
|
||||
|
||||
def __ge__(self, other: int) -> bool:
|
||||
"""Compare timestamp."""
|
||||
return self.timestamp >= other
|
||||
|
||||
|
||||
def create_metadata(state: MutableMapping[str, Any], timestamp: int) -> dict[str, Any]:
|
||||
"""Create metadata (timestamps) for each of the keys in 'state'."""
|
||||
metadata: dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
if isinstance(value, dict):
|
||||
metadata[key] = create_metadata(value, timestamp)
|
||||
elif value is None:
|
||||
metadata[key] = None
|
||||
else:
|
||||
metadata[key] = MetaTimestamp(timestamp)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def calculate_delta_update(desired: MutableMapping[str, Any],
|
||||
reported: MutableMapping[str, Any],
|
||||
depth: bool = True,
|
||||
exclude_nones: bool = True,
|
||||
ordered_lists: bool = True) -> dict[str, Any]:
|
||||
"""Calculate state differences between desired and reported."""
|
||||
diff_dict = {}
|
||||
for key, value in desired.items():
|
||||
if value is None and exclude_nones:
|
||||
continue
|
||||
# if the desired has an element that the reported does not...
|
||||
if key not in reported:
|
||||
diff_dict[key] = value
|
||||
# if the desired has an element that's a list, but the list is
|
||||
elif isinstance(value, list) and not ordered_lists:
|
||||
if Counter(value) != Counter(reported[key]):
|
||||
diff_dict[key] = value
|
||||
elif isinstance(value, dict) and depth:
|
||||
# recurse, report when there is a difference
|
||||
obj_diff = calculate_delta_update(value, reported[key])
|
||||
if obj_diff:
|
||||
diff_dict[key] = obj_diff
|
||||
elif value != reported[key]:
|
||||
diff_dict[key] = value
|
||||
|
||||
return diff_dict
|
||||
|
||||
|
||||
def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
|
||||
"""Calculate state differences between desired and reported (including missing keys)."""
|
||||
delta = calculate_delta_update(desired, reported, depth=False, exclude_nones=False)
|
||||
|
||||
for key in reported:
|
||||
if key not in desired:
|
||||
delta[key] = None
|
||||
|
||||
return delta
|
||||
|
||||
|
||||
@dataclass
|
||||
class State(Generic[C]):
|
||||
desired: MutableMapping[str, C] = field(default_factory=dict)
|
||||
reported: MutableMapping[str, C] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, C]) -> "State[C]":
|
||||
"""Create state from dictionary."""
|
||||
return cls(
|
||||
desired=data.get("desired", {}),
|
||||
reported=data.get("reported", {})
|
||||
)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Determine if state is empty."""
|
||||
return bool(self.desired) or bool(self.reported)
|
||||
|
||||
def __add__(self, other: "State[C]") -> "State[C]":
|
||||
"""Merge states together."""
|
||||
return State(
|
||||
desired=merge({}, self.desired, other.desired),
|
||||
reported=merge({}, self.reported, other.reported)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateDocument:
|
||||
state: State[dict[str, Any]] = field(default_factory=State)
|
||||
metadata: State[MetaTimestamp] = field(default_factory=State)
|
||||
version: int | None = None # only required when generating shadow messages
|
||||
timestamp: int | None = None # only required when generating shadow messages
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "StateDocument":
|
||||
"""Create state document from dictionary."""
|
||||
now = int(time.time())
|
||||
if data and "state" not in data:
|
||||
raise StateError
|
||||
|
||||
state = State.from_dict(data.get("state", {}))
|
||||
metadata = State(
|
||||
desired=create_metadata(state.desired, now),
|
||||
reported=create_metadata(state.reported, now)
|
||||
)
|
||||
|
||||
return cls(state=state, metadata=metadata)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize meta data if not provided."""
|
||||
now = int(time.time())
|
||||
if not self.metadata:
|
||||
self.metadata = State(
|
||||
desired=create_metadata(self.state.desired, now),
|
||||
reported=create_metadata(self.state.reported, now),
|
||||
)
|
||||
|
||||
def __add__(self, other: "StateDocument") -> "StateDocument":
|
||||
"""Merge two state documents together."""
|
||||
return StateDocument(
|
||||
state=self.state + other.state,
|
||||
metadata=self.metadata + other.metadata
|
||||
)
|
||||
|
||||
|
||||
class ShadowOperation(StrEnum):
|
||||
GET = "get"
|
||||
UPDATE = "update"
|
||||
GET_ACCEPT = "get/accepted"
|
||||
GET_REJECT = "get/rejected"
|
||||
UPDATE_ACCEPT = "update/accepted"
|
||||
UPDATE_REJECT = "update/rejected"
|
||||
UPDATE_DOCUMENTS = "update/documents"
|
||||
UPDATE_DELTA = "update/delta"
|
||||
UPDATE_IOTA = "update/iota"
|
|
@ -16,10 +16,12 @@ class CodecError(Exception):
|
|||
class NoDataError(Exception):
|
||||
"""Exceptions thrown by packet encode/decode functions."""
|
||||
|
||||
|
||||
class ZeroLengthReadError(NoDataError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Decoding a string of length zero.")
|
||||
|
||||
|
||||
class BrokerError(Exception):
|
||||
"""Exceptions thrown by broker."""
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
|
@ -31,4 +31,6 @@ class BrokerEvents(Events):
|
|||
CLIENT_DISCONNECTED = "broker_client_disconnected"
|
||||
CLIENT_SUBSCRIBED = "broker_client_subscribed"
|
||||
CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
|
||||
RETAINED_MESSAGE = "broker_retained_message"
|
||||
MESSAGE_RECEIVED = "broker_message_received"
|
||||
MESSAGE_BROADCAST = "broker_message_broadcast"
|
||||
|
|
|
@ -167,7 +167,7 @@ class PacketIdVariableHeader(MQTTVariableHeader):
|
|||
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)
|
||||
|
||||
|
||||
class MQTTPayload(Generic[_VH], ABC):
|
||||
class MQTTPayload(ABC, Generic[_VH]):
|
||||
"""Abstract base class for MQTT payloads."""
|
||||
|
||||
async def to_stream(self, writer: asyncio.StreamWriter) -> None:
|
||||
|
|
|
@ -31,6 +31,7 @@ _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
|
|||
if TYPE_CHECKING:
|
||||
from amqtt.broker import BrokerContext
|
||||
|
||||
|
||||
class Subscription:
|
||||
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
|
||||
self.packet_id = packet_id
|
||||
|
@ -235,6 +236,7 @@ class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
|
|||
incoming_session.password = connect.password
|
||||
incoming_session.remote_address = remote_address
|
||||
incoming_session.remote_port = remote_port
|
||||
incoming_session.ssl_object = writer.get_ssl_info()
|
||||
|
||||
incoming_session.keep_alive = max(connect.keep_alive, 0)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from amqtt.session import Session
|
|||
if TYPE_CHECKING:
|
||||
from amqtt.client import ClientContext
|
||||
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -4,13 +4,13 @@ try:
|
|||
from asyncio import InvalidStateError, QueueFull, QueueShutDown
|
||||
except ImportError:
|
||||
# Fallback for Python < 3.12
|
||||
class InvalidStateError(Exception): # type: ignore[no-redef]
|
||||
class InvalidStateError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
pass
|
||||
|
||||
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
pass
|
||||
|
||||
|
||||
|
@ -63,6 +63,7 @@ from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplic
|
|||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class ProtocolHandler(Generic[C]):
|
||||
"""Class implementing the MQTT communication protocol using asyncio features."""
|
||||
|
||||
|
@ -199,7 +200,7 @@ class ProtocolHandler(Generic[C]):
|
|||
async def mqtt_publish(
|
||||
self,
|
||||
topic: str,
|
||||
data: bytes | bytearray ,
|
||||
data: bytes | bytearray,
|
||||
qos: int | None,
|
||||
retain: bool,
|
||||
ack_timeout: int | None = None,
|
||||
|
@ -535,7 +536,7 @@ class ProtocolHandler(Generic[C]):
|
|||
self.handle_read_timeout()
|
||||
except NoDataError:
|
||||
self.logger.debug(f"{self.session.client_id} No data available")
|
||||
except Exception as e: # noqa: BLE001
|
||||
except Exception as e: # noqa: BLE001, pylint: disable=W0718
|
||||
self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}")
|
||||
break
|
||||
while running_tasks:
|
||||
|
|
|
@ -1 +1,38 @@
|
|||
"""INIT."""
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class TopicMatcher:
|
||||
|
||||
_instance: Optional["TopicMatcher"] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
if not hasattr(self, "_topic_filter_matchers"):
|
||||
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
|
||||
|
||||
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "TopicMatcher":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def is_topic_allowed(self, topic: str, a_filter: str) -> bool:
|
||||
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
|
||||
return False
|
||||
|
||||
if "#" not in a_filter and "+" not in a_filter:
|
||||
# if filter doesn't contain wildcard, return exact match
|
||||
return a_filter == topic
|
||||
|
||||
# else use regex (re.compile is an expensive operation, store the matcher for future use)
|
||||
if a_filter not in self._topic_filter_matchers:
|
||||
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
|
||||
.replace("\\#", "?.*")
|
||||
.replace("\\+", "[^/]*")
|
||||
.lstrip("?"))
|
||||
match_pattern = self._topic_filter_matchers[a_filter]
|
||||
return bool(match_pattern.fullmatch(topic))
|
||||
|
||||
def are_topics_allowed(self, topic: str, many_filters: list[str]) -> bool:
|
||||
|
||||
return any(self.is_topic_allowed(topic, a_filter) for a_filter in many_filters)
|
||||
|
|
|
@ -37,9 +37,10 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
|
|||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Allow empty username."""
|
||||
"""Configuration for AnonymousAuthPlugin."""
|
||||
|
||||
allow_anonymous: bool = field(default=True)
|
||||
"""Allow all anonymous authentication (even with _no_ username)."""
|
||||
|
||||
|
||||
class FileAuthPlugin(BaseAuthPlugin):
|
||||
|
@ -78,7 +79,7 @@ class FileAuthPlugin(BaseAuthPlugin):
|
|||
self.context.logger.warning(f"Password file '{password_file}' not found")
|
||||
except ValueError:
|
||||
self.context.logger.exception(f"Malformed password file '{password_file}'")
|
||||
except Exception:
|
||||
except OSError:
|
||||
self.context.logger.exception(f"Unexpected error reading password file '{password_file}'")
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
|
@ -107,6 +108,7 @@ class FileAuthPlugin(BaseAuthPlugin):
|
|||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Path to the properly encoded password file."""
|
||||
"""Configuration for FileAuthPlugin."""
|
||||
|
||||
password_file: str | Path | None = None
|
||||
"""Path to file with `username:password` pairs, one per line. All passwords are encoded using sha-512."""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import Any, Generic, TypeVar, cast
|
||||
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.contexts import Action, BaseContext, BrokerConfig
|
||||
from amqtt.session import Session
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
@ -46,13 +46,13 @@ class BasePlugin(Generic[C]):
|
|||
return section_config
|
||||
|
||||
# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default)
|
||||
if option_name in self.context.config:
|
||||
return self.context.config[option_name]
|
||||
return default
|
||||
|
@ -75,20 +75,21 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
if not bool(self.topic_config) and not is_dataclass(self.context.config):
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
# overloaded context.config with either BrokerConfig or plugin's Config
|
||||
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default)
|
||||
if self.topic_config and option_name in self.topic_config:
|
||||
return self.topic_config[option_name]
|
||||
return default
|
||||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
) -> bool | None:
|
||||
"""Logic for filtering out topics.
|
||||
|
||||
Args:
|
||||
|
@ -97,7 +98,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
action: amqtt.broker.Action
|
||||
|
||||
Returns:
|
||||
bool: `True` if topic is allowed, `False` otherwise
|
||||
bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined
|
||||
|
||||
"""
|
||||
return bool(self.topic_config) or is_dataclass(self.context.config)
|
||||
|
@ -106,13 +107,13 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
if is_dataclass(self.context.config):
|
||||
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default)
|
||||
if self.auth_config and option_name in self.auth_config:
|
||||
return self.auth_config[option_name]
|
||||
return default
|
||||
|
@ -125,7 +126,6 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
|||
# auth config section not found and Config dataclass not provided
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Logic for session authentication.
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@ import copy
|
|||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
from inspect import iscoroutinefunction
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
|
||||
import warnings
|
||||
|
||||
|
@ -49,6 +51,7 @@ def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
|
|||
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class PluginManager(Generic[C]):
|
||||
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
|
||||
|
||||
|
@ -95,10 +98,9 @@ class PluginManager(Generic[C]):
|
|||
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
|
||||
# plugins loaded directly from config dictionary
|
||||
|
||||
|
||||
if "auth" in self.app_context.config:
|
||||
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||
if "topic-check" in self.app_context.config:
|
||||
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
|
||||
|
||||
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
|
||||
|
@ -130,7 +132,7 @@ class PluginManager(Generic[C]):
|
|||
"Loading plugins from EntryPoints is deprecated and will be removed in a future version."
|
||||
" Use `plugins` section of config instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
stacklevel=4
|
||||
)
|
||||
|
||||
self._load_ep_plugins(namespace)
|
||||
|
@ -145,7 +147,7 @@ class PluginManager(Generic[C]):
|
|||
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
|
||||
self._event_plugin_callbacks[event].append(awaitable)
|
||||
|
||||
def _load_ep_plugins(self, namespace:str) -> None:
|
||||
def _load_ep_plugins(self, namespace: str) -> None:
|
||||
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
auth_filter_list = []
|
||||
|
@ -222,7 +224,7 @@ class PluginManager(Generic[C]):
|
|||
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
|
||||
"""Load plugin from string dotted path: mymodule.myfile.MyPlugin."""
|
||||
try:
|
||||
plugin_class: Any = import_string(plugin_path)
|
||||
plugin_class: Any = import_string(plugin_path)
|
||||
except ImportError as ep:
|
||||
msg = f"Plugin import failed: {plugin_path}"
|
||||
raise PluginImportError(msg) from ep
|
||||
|
@ -291,6 +293,15 @@ class PluginManager(Generic[C]):
|
|||
return asyncio.ensure_future(coro)
|
||||
|
||||
def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
|
||||
if self.logger.getEffectiveLevel() <= logging.DEBUG:
|
||||
try:
|
||||
future.result()
|
||||
except asyncio.CancelledError:
|
||||
self.logger.warning("fired event was cancelled")
|
||||
# display plugin fault; don't allow it to cause a broker failure
|
||||
except Exception as exc: # noqa: BLE001, pylint: disable=W0718
|
||||
traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
|
||||
|
||||
with contextlib.suppress(KeyError, ValueError):
|
||||
self._fired_events.remove(future)
|
||||
|
||||
|
@ -366,7 +377,7 @@ class PluginManager(Generic[C]):
|
|||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
return await self._map_plugin_method(
|
||||
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
|
||||
self._auth_plugins, "authenticate", {"session": session}) # type: ignore[arg-type]
|
||||
|
||||
async def map_plugin_topic(
|
||||
self, *, session: Session, topic: str, action: "Action"
|
||||
|
|
|
@ -1,85 +1,11 @@
|
|||
import json
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.session import Session
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
||||
|
||||
class SQLitePlugin:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
self.context: BaseContext = context
|
||||
self.conn: sqlite3.Connection | None = None
|
||||
self.cursor: sqlite3.Cursor | None = None
|
||||
self.db_file: str | None = None
|
||||
self.persistence_config: dict[str, Any]
|
||||
class SQLitePlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
if (
|
||||
persistence_config := self.context.config.get("persistence") if self.context.config is not None else None
|
||||
) is not None:
|
||||
self.persistence_config = persistence_config
|
||||
self.init_db()
|
||||
else:
|
||||
self.context.logger.warning("'persistence' section not found in context configuration")
|
||||
|
||||
def init_db(self) -> None:
|
||||
self.db_file = self.persistence_config.get("file")
|
||||
if not self.db_file:
|
||||
self.context.logger.warning("'file' persistence parameter not found")
|
||||
else:
|
||||
try:
|
||||
self.conn = sqlite3.connect(self.db_file)
|
||||
self.cursor = self.conn.cursor()
|
||||
self.context.logger.info(f"Database file '{self.db_file}' opened")
|
||||
except Exception:
|
||||
self.context.logger.exception(f"Error while initializing database '{self.db_file}'")
|
||||
if self.cursor:
|
||||
self.cursor.execute(
|
||||
"CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)",
|
||||
)
|
||||
self.cursor.execute("PRAGMA table_info(session)")
|
||||
columns = {col[1] for col in self.cursor.fetchall()}
|
||||
required_columns = {"client_id", "data"}
|
||||
if not required_columns.issubset(columns):
|
||||
self.context.logger.error("Database schema for 'session' table is incompatible.")
|
||||
|
||||
async def save_session(self, session: Session) -> None:
|
||||
if self.cursor and self.conn:
|
||||
dump: str = json.dumps(session, default=str)
|
||||
try:
|
||||
self.cursor.execute(
|
||||
"INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)",
|
||||
(session.client_id, dump),
|
||||
)
|
||||
self.conn.commit()
|
||||
except Exception:
|
||||
self.context.logger.exception(f"Failed saving session '{session}'")
|
||||
|
||||
async def find_session(self, client_id: str) -> Session | None:
|
||||
if self.cursor:
|
||||
row = self.cursor.execute(
|
||||
"SELECT data FROM session where client_id=?",
|
||||
(client_id,),
|
||||
).fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
return None
|
||||
|
||||
async def del_session(self, client_id: str) -> None:
|
||||
if self.cursor and self.conn:
|
||||
try:
|
||||
exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone()
|
||||
if exists:
|
||||
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
|
||||
self.conn.commit()
|
||||
except Exception:
|
||||
self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'")
|
||||
|
||||
async def on_broker_post_shutdown(self) -> None:
|
||||
if self.conn:
|
||||
try:
|
||||
self.conn.close()
|
||||
self.context.logger.info(f"Database file '{self.db_file}' closed")
|
||||
except Exception:
|
||||
self.context.logger.exception("Error closing database connection")
|
||||
finally:
|
||||
self.conn = None
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1)
|
||||
|
|
|
@ -14,7 +14,7 @@ except ImportError:
|
|||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class Buffer(Protocol): # type: ignore[no-redef]
|
||||
class Buffer(Protocol): # type: ignore[no-redef]
|
||||
def __buffer__(self, flags: int = ...) -> memoryview:
|
||||
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
|
||||
|
||||
|
@ -75,7 +75,6 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
self._sys_interval: int = 0
|
||||
self._current_process = psutil.Process()
|
||||
|
||||
|
||||
def _clear_stats(self) -> None:
|
||||
"""Initialize broker statistics data structures."""
|
||||
for stat in (
|
||||
|
@ -112,7 +111,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
"""Initialize statistics and start $SYS broadcasting."""
|
||||
self._stats[STAT_START_TIME] = int(datetime.now(tz=UTC).timestamp())
|
||||
version = f"aMQTT version {amqtt.__version__}"
|
||||
self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
|
||||
await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
|
||||
|
||||
# Start $SYS topics management
|
||||
try:
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.errors import PluginInitError
|
||||
from amqtt.plugins.base import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
|
@ -13,17 +15,27 @@ class TopicTabooPlugin(BaseTopicPlugin):
|
|||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
) -> bool | None:
|
||||
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
|
||||
if filter_result:
|
||||
if session and session.username == "admin":
|
||||
return True
|
||||
return not (topic and topic in self._taboo)
|
||||
return filter_result
|
||||
return bool(filter_result)
|
||||
|
||||
|
||||
class TopicAccessControlListPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
if self._get_config_option("acl", None):
|
||||
warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1)
|
||||
|
||||
if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None):
|
||||
msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included"
|
||||
raise PluginInitError(msg)
|
||||
|
||||
@staticmethod
|
||||
def topic_ac(topic_requested: str, topic_allowed: str) -> bool:
|
||||
req_split = topic_requested.split("/")
|
||||
|
@ -46,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
|||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
) -> bool | None:
|
||||
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
|
||||
if not filter_result:
|
||||
return False
|
||||
|
@ -58,18 +70,26 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
|||
|
||||
req_topic = topic
|
||||
if not req_topic:
|
||||
return False\
|
||||
return False
|
||||
|
||||
username = session.username if session else None
|
||||
if username is None:
|
||||
username = "anonymous"
|
||||
|
||||
acl: dict[str, Any] = {}
|
||||
acl: dict[str, Any] | None = None
|
||||
match action:
|
||||
case Action.PUBLISH:
|
||||
acl = self._get_config_option("publish-acl", {})
|
||||
acl = self._get_config_option("publish-acl", None)
|
||||
case Action.SUBSCRIBE:
|
||||
acl = self._get_config_option("acl", {})
|
||||
acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None))
|
||||
case Action.RECEIVE:
|
||||
acl = self._get_config_option("receive-acl", None)
|
||||
case _:
|
||||
msg = "Received an invalid action type."
|
||||
raise ValueError(msg)
|
||||
|
||||
if acl is None:
|
||||
return True
|
||||
|
||||
allowed_topics = acl.get(username, [])
|
||||
if not allowed_topics:
|
||||
|
|
|
@ -21,7 +21,7 @@ def main() -> None:
|
|||
app()
|
||||
|
||||
|
||||
def _version(v:bool) -> None:
|
||||
def _version(v: bool) -> None:
|
||||
if v:
|
||||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
@ -41,9 +41,12 @@ def broker_main(
|
|||
) -> None:
|
||||
"""Command-line script for running a MQTT 3.1.1 broker."""
|
||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||
if debug:
|
||||
formatter = "[%(asctime)s] %(name)s:%(lineno)d :: %(levelname)s - %(message)s"
|
||||
|
||||
level = logging.DEBUG if debug else logging.INFO
|
||||
logging.basicConfig(level=level, format=formatter)
|
||||
logging.getLogger("transitions").setLevel(logging.WARNING)
|
||||
try:
|
||||
if config_file:
|
||||
config = read_yaml_config(config_file)
|
||||
|
@ -62,7 +65,7 @@ def broker_main(
|
|||
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
_ = loop.create_task(broker.start()) #noqa : RUF006
|
||||
_ = loop.create_task(broker.start()) # noqa : RUF006
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import typer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer(add_completion=False, rich_markup_mode=None)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the cli for `ca_creds`."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def ca_creds(
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
state: str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
|
||||
locality: str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
|
||||
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
|
||||
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
|
||||
) -> None:
|
||||
"""Generate a self-signed key and certificate to be used as the root CA, with a key size of 2048 and a 1-year expiration."""
|
||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
try:
|
||||
from amqtt.contrib.cert import generate_root_creds, write_key_and_crt # pylint: disable=import-outside-toplevel
|
||||
except ImportError:
|
||||
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
|
||||
logger.critical(msg)
|
||||
sys.exit(1)
|
||||
|
||||
ca_key, ca_crt = generate_root_creds(country=country, state=state, locality=locality, org_name=org_name, cn=cn)
|
||||
|
||||
write_key_and_crt(ca_key, ca_crt, "ca", Path(output_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -7,7 +7,7 @@ auto_reconnect: true
|
|||
cleansession: true
|
||||
reconnect_max_interval: 10
|
||||
reconnect_retries: 2
|
||||
broker:
|
||||
connection:
|
||||
uri: "mqtt://127.0.0.1"
|
||||
plugins:
|
||||
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import typer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app = typer.Typer(add_completion=False, rich_markup_mode=None)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the `device_creds` cli."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def device_creds( # pylint: disable=too-many-locals
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
device_id: str = typer.Option(..., "--device-id", help="device id for the SAN"),
|
||||
uri: str = typer.Option(..., "--uri", help="domain name for device SAN"),
|
||||
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
|
||||
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="root key filename used for signing."),
|
||||
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="root cert filename used for signing."),
|
||||
) -> None:
|
||||
"""Generate a key and certificate for each device in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa: E501
|
||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
try:
|
||||
from amqtt.contrib.cert import ( # pylint: disable=import-outside-toplevel
|
||||
generate_device_csr,
|
||||
load_ca,
|
||||
sign_csr,
|
||||
write_key_and_crt,
|
||||
)
|
||||
except ImportError:
|
||||
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
|
||||
logger.critical(msg)
|
||||
sys.exit(1)
|
||||
|
||||
ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
|
||||
|
||||
uri_san = f"spiffe://{uri}/device/{device_id}"
|
||||
dns_san = f"{device_id}.local"
|
||||
|
||||
device_key, device_csr = generate_device_csr(
|
||||
country=country,
|
||||
org_name=org_name,
|
||||
common_name=device_id,
|
||||
uri_san=uri_san,
|
||||
dns_san=dns_san
|
||||
)
|
||||
|
||||
device_crt = sign_csr(device_csr, ca_key, ca_crt)
|
||||
|
||||
write_key_and_crt(device_key, device_crt, device_id, Path(output_dir))
|
||||
|
||||
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,37 @@
|
|||
import logging
|
||||
import sys
|
||||
|
||||
from amqtt.errors import MQTTError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the auth db cli."""
|
||||
try:
|
||||
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
|
||||
except ImportError:
|
||||
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
|
||||
sys.exit(1)
|
||||
|
||||
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
|
||||
|
||||
try:
|
||||
topic_app()
|
||||
except ModuleNotFoundError as mnfe:
|
||||
logger.critical(f"Please install database-specific dependencies: {mnfe}")
|
||||
sys.exit(1)
|
||||
except ValueError as ve:
|
||||
if "greenlet" in f"{ve}":
|
||||
logger.critical("Please install database-specific dependencies: 'greenlet'")
|
||||
sys.exit(1)
|
||||
logger.critical(f"Unknown error: {ve}")
|
||||
sys.exit(1)
|
||||
|
||||
except MQTTError as me:
|
||||
logger.critical(f"could not execute command: {me}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,36 @@
|
|||
import logging
|
||||
import sys
|
||||
|
||||
from amqtt.errors import MQTTError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the auth db cli."""
|
||||
try:
|
||||
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
|
||||
except ImportError:
|
||||
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
|
||||
sys.exit(1)
|
||||
|
||||
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
|
||||
try:
|
||||
user_app()
|
||||
except ModuleNotFoundError as mnfe:
|
||||
logger.critical(f"Please install database-specific dependencies: {mnfe}")
|
||||
sys.exit(1)
|
||||
except ValueError as ve:
|
||||
if "greenlet" in f"{ve}":
|
||||
logger.critical("Please install database-specific dependencies: 'greenlet'")
|
||||
sys.exit(1)
|
||||
logger.critical(f"Unknown error: {ve}")
|
||||
sys.exit(1)
|
||||
|
||||
except MQTTError as me:
|
||||
logger.critical(f"could not execute command: {me}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -52,7 +52,7 @@ class MessageInput:
|
|||
with Path(self.file).open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
yield line.encode(encoding="utf-8")
|
||||
except Exception:
|
||||
except (FileNotFoundError, OSError):
|
||||
logger.exception(f"Failed to read file '{self.file}'")
|
||||
if self.lines:
|
||||
for line in sys.stdin:
|
||||
|
@ -118,6 +118,7 @@ async def do_pub(
|
|||
logger.fatal("Publish canceled due to previous error")
|
||||
raise asyncio.CancelledError from ce
|
||||
|
||||
|
||||
app = typer.Typer(add_completion=False, rich_markup_mode=None)
|
||||
|
||||
|
||||
|
@ -131,8 +132,9 @@ def _version(v: bool) -> None:
|
|||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
|
||||
@app.command()
|
||||
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
def publisher_main( # pylint: disable=R0914,R0917
|
||||
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),
|
||||
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
|
||||
|
@ -155,7 +157,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
will_retain: bool = typer.Option(False, "--will-retain", help="If the client disconnects unexpectedly the message sent out will be treated as a retained message. *only valid, if `--will-topic` is specified*"),
|
||||
extra_headers_json: str | None = typer.Option(None, "--extra-headers", help="Specify a JSON object string with key-value pairs representing additional headers that are transmitted on the initial connection. *websocket connections only*."),
|
||||
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
|
||||
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
|
||||
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
|
||||
) -> None:
|
||||
"""Command-line MQTT client for publishing simple messages."""
|
||||
provided = [bool(message), bool(file), stdin, lines, no_message]
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import typer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer(add_completion=False, rich_markup_mode=None)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the `server_creds` cli."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def server_creds(
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
|
||||
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
|
||||
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
|
||||
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
|
||||
) -> None:
|
||||
"""Generate a key and certificate for the broker in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa : E501
|
||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
try:
|
||||
from amqtt.contrib.cert import ( # pylint: disable=import-outside-toplevel
|
||||
generate_server_csr,
|
||||
load_ca,
|
||||
sign_csr,
|
||||
write_key_and_crt,
|
||||
)
|
||||
except ImportError:
|
||||
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
|
||||
logger.critical(msg)
|
||||
sys.exit(1)
|
||||
|
||||
ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
|
||||
|
||||
server_key, server_csr = generate_server_csr(country=country, org_name=org_name, cn=cn)
|
||||
server_crt = sign_csr(server_csr, ca_key, ca_crt)
|
||||
|
||||
write_key_and_crt(server_key, server_crt, "server", Path(output_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -100,17 +100,17 @@ def main() -> None:
|
|||
app()
|
||||
|
||||
|
||||
def _version(v:bool) -> None:
|
||||
def _version(v: bool) -> None:
|
||||
if v:
|
||||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
|
||||
@app.command()
|
||||
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
def subscribe_main( # pylint: disable=R0914,R0917
|
||||
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
|
||||
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
|
||||
topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008
|
||||
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from asyncio import Queue
|
||||
from collections import OrderedDict
|
||||
from typing import Any, ClassVar
|
||||
import logging
|
||||
from math import floor
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from transitions import Machine
|
||||
|
||||
|
@ -10,6 +13,11 @@ from amqtt.mqtt.publish import PublishPacket
|
|||
OUTGOING = 0
|
||||
INCOMING = 1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationMessage:
|
||||
"""ApplicationMessage and subclasses are used to store published message information flow.
|
||||
|
@ -138,6 +146,9 @@ class Session:
|
|||
self.cadata: bytes | None = None
|
||||
self._packet_id: int = 0
|
||||
self.parent: int = 0
|
||||
self.last_connect_time: int | None = None
|
||||
self.ssl_object: ssl.SSLObject | None = None
|
||||
self.last_disconnect_time: int | None = None
|
||||
|
||||
# Used to store outgoing ApplicationMessage while publish protocol flows
|
||||
self.inflight_out: OrderedDict[int, OutgoingApplicationMessage] = OrderedDict()
|
||||
|
@ -161,6 +172,7 @@ class Session:
|
|||
source="new",
|
||||
dest="connected",
|
||||
)
|
||||
self.transitions.on_enter_connected(self._on_enter_connected)
|
||||
self.transitions.add_transition(
|
||||
trigger="connect",
|
||||
source="disconnected",
|
||||
|
@ -171,6 +183,7 @@ class Session:
|
|||
source="connected",
|
||||
dest="disconnected",
|
||||
)
|
||||
self.transitions.on_enter_disconnected(self._on_enter_disconnected)
|
||||
self.transitions.add_transition(
|
||||
trigger="disconnect",
|
||||
source="new",
|
||||
|
@ -182,6 +195,20 @@ class Session:
|
|||
dest="disconnected",
|
||||
)
|
||||
|
||||
def _on_enter_connected(self) -> None:
|
||||
cur_time = floor(time.time())
|
||||
if self.last_disconnect_time is not None:
|
||||
logger.debug(f"Session reconnected after {cur_time - self.last_disconnect_time} seconds.")
|
||||
|
||||
self.last_connect_time = cur_time
|
||||
self.last_disconnect_time = None
|
||||
|
||||
def _on_enter_disconnected(self) -> None:
|
||||
cur_time = floor(time.time())
|
||||
if self.last_connect_time is not None:
|
||||
logger.debug(f"Session disconnected after {cur_time - self.last_connect_time} seconds.")
|
||||
self.last_disconnect_time = cur_time
|
||||
|
||||
@property
|
||||
def next_packet_id(self) -> int:
|
||||
self._packet_id = (self._packet_id % 65535) + 1
|
||||
|
|
|
@ -30,3 +30,7 @@ h2.doc-heading-parameter {
|
|||
.md-nav__link--active {
|
||||
color: #f15581 !important;
|
||||
}
|
||||
|
||||
.admonition {
|
||||
font-size: 16px !important;
|
||||
}
|
|
@ -1,5 +1,43 @@
|
|||
# Changelog
|
||||
|
||||
## 0.11.3
|
||||
|
||||
API changes:
|
||||
|
||||
- broker and client configuration via dataclasses and enums instead of unstructured dictionaries (backwards compatible)
|
||||
- `MESSAGE_RECIEVE` event moved to after topic filtering
|
||||
- `MESSAGE_BROADCAST` event added for prior topic filtering
|
||||
- `RETAINED_MESSAGE` event added for messages with retained flag or offline clients without setting clean session flag
|
||||
- method `retain_message` changed to coroutine (broker)
|
||||
- change `add_subscription` to a public method (broker)
|
||||
- add listener type for external servers and api method for passing new connection via `external_connected` method (broker)
|
||||
- for TLS sessions, properly load the key and cert file (client)
|
||||
- added abstract method `get_ssl_info` to `WriterAdapter`
|
||||
|
||||
Details:
|
||||
|
||||
* Structural elements for the 0.11.3 release https://github.com/Yakifo/amqtt/pull/265
|
||||
* Release Candidate Branch for 0.11.3 https://github.com/Yakifo/amqtt/pull/272
|
||||
* update the configuration for the broker running at test.amqtt.io https://github.com/Yakifo/amqtt/pull/271
|
||||
* Improved broker script logging https://github.com/Yakifo/amqtt/pull/277
|
||||
* test.amqtt.io dashboard cleanup https://github.com/Yakifo/amqtt/pull/278
|
||||
* Structured broker and client configurations https://github.com/Yakifo/amqtt/pull/269
|
||||
* Determine auth & topic access via external http server https://github.com/Yakifo/amqtt/pull/262
|
||||
* Plugin: authentication against a relational database https://github.com/Yakifo/amqtt/pull/280
|
||||
* Fixes #247 : expire disconnected sessions https://github.com/Yakifo/amqtt/pull/279
|
||||
* Expanded structure for plugin documentation https://github.com/Yakifo/amqtt/pull/281
|
||||
* Yakifo/amqtt#120 confirms : validate example is functioning https://github.com/Yakifo/amqtt/pull/284
|
||||
* Yakifo/amqtt#39 : adding W0718 'broad exception caught' https://github.com/Yakifo/amqtt/pull/285
|
||||
* Documentation improvement for 0.11.3 https://github.com/Yakifo/amqtt/pull/286
|
||||
* Plugin naming convention https://github.com/Yakifo/amqtt/pull/288
|
||||
* embed amqtt into an existing server https://github.com/Yakifo/amqtt/pull/283
|
||||
* Plugin: rebuild of session persistence https://github.com/Yakifo/amqtt/pull/256
|
||||
* Plugin: determine authentication based on X509 certificates https://github.com/Yakifo/amqtt/pull/264
|
||||
* Plugin: device 'shadows' to bridge device online/offline states https://github.com/Yakifo/amqtt/pull/282
|
||||
* Plugin: authenticate against LDAP server https://github.com/Yakifo/amqtt/pull/287
|
||||
* Sample: broker and client communicating with mqtt over unix socket https://github.com/Yakifo/amqtt/pull/291
|
||||
* Plugin: jwt authentication and authorization https://github.com/Yakifo/amqtt/pull/289
|
||||
|
||||
## 0.11.2
|
||||
|
||||
- config-file based plugin loading [PR #240](https://github.com/Yakifo/amqtt/pull/240)
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
{% extends "base.html" %}
|
||||
|
||||
{% block outdated %}
|
||||
You're not viewing the latest version.
|
||||
<a href="{{ '../' ~ base_url }}">
|
||||
<strong>Click here to go to latest.</strong>
|
||||
</a>
|
||||
{% endblock %}
|
|
@ -0,0 +1,39 @@
|
|||
# Relational Database for Authentication and Authorization
|
||||
|
||||
- `amqtt.contrib.auth_db.UserAuthDBPlugin` (authentication) verify a client's ability to connect to broker
|
||||
- `amqtt.contrib.auth_db.TopicAuthDBPlugin` (authorization) determine a client's access to topics
|
||||
|
||||
Relational database access is supported using SQLAlchemy so MySQL, MariaDB, Postgres and SQLite support is available.
|
||||
|
||||
For ease of use, the [`user_mgr` command-line utility](auth_db.md/#user_mgr) to add, remove, update and
|
||||
list clients. And the [`topic_mgr` command-line utility](auth_db.md/#topic_mgr) to add client access to
|
||||
subscribe, publish and receive messages on topics.
|
||||
|
||||
# Authentication Configuration
|
||||
|
||||
::: amqtt.contrib.auth_db.UserAuthDBPlugin.Config
|
||||
options:
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
# Authorization Configuration
|
||||
|
||||
::: amqtt.contrib.auth_db.TopicAuthDBPlugin.Config
|
||||
options:
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
## CLI
|
||||
|
||||
|
||||
::: mkdocs-typer2
|
||||
:module: amqtt.contrib.auth_db.user_mgr_cli
|
||||
:name: user_mgr
|
||||
|
||||
|
||||
::: mkdocs-typer2
|
||||
:module: amqtt.contrib.auth_db.topic_mgr_cli
|
||||
:name: topic_mgr
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
# Authentication Using Signed Certificates
|
||||
|
||||
Using client-specific certificates, signed by a common authority (even if self-signed) provides
|
||||
a highly secure way of authenticating mqtt clients. Often used with IoT devices where a unique
|
||||
certificate can be initialized on initial provisioning.
|
||||
|
||||
With so many options, X509 certificates can be daunting to create with `openssl`. Included are
|
||||
command line utilities to generate a root self-signed certificate and then the proper broker and
|
||||
device certificates with the correct X509 attributes to enable authenticity.
|
||||
|
||||
### Quick start
|
||||
|
||||
Generate a self-signed root credentials and server credentials:
|
||||
|
||||
```shell
|
||||
$ ca_creds --country US --state NY --locality NY --org-name "My Org's Name" --cn "my.domain.name"
|
||||
$ server_creds --country US --org-name "My Org's Name" --cn "my.domain.name"
|
||||
```
|
||||
|
||||
!!! warning "Security of private keys"
|
||||
Your root credential private key and your server key should *never* be shared with anyone. The
|
||||
certificates -- specifically the root CA certificate -- is completely safe to share and will need
|
||||
to be shared along with device credentials when using a self-signed CA.
|
||||
|
||||
Include in your server config:
|
||||
|
||||
```yaml
|
||||
listeners:
|
||||
ssl-mqtt:
|
||||
bind: "127.0.0.1:8883"
|
||||
ssl: true
|
||||
certfile: server.crt
|
||||
keyfile: server.key
|
||||
cafile: ca.crt
|
||||
plugins:
|
||||
amqtt.contrib.cert.CertificateAuthPlugin:
|
||||
uri_domain: my.domain.name
|
||||
```
|
||||
|
||||
Generate a device's credentials:
|
||||
|
||||
```shell
|
||||
$ device_creds --country US --org-name "My Org's Name" --device-id myUniqueDeviceId --uri my.domain.name
|
||||
```
|
||||
|
||||
And use these to initialize the `MQTTClient`:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from amqtt.client import MQTTClient
|
||||
|
||||
client_config = {
|
||||
'keyfile': 'myUniqueDeviceId.key',
|
||||
'certfile': 'myUniqueDeviceId.crt',
|
||||
'broker': {
|
||||
'cafile': 'ca.crt'
|
||||
}
|
||||
}
|
||||
|
||||
async def main():
|
||||
client = MQTTClient(config=client_config)
|
||||
await client.connect("mqtts://my.domain.name:8883")
|
||||
# publish messages or subscribe to receive
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## Background
|
||||
|
||||
Often used for IoT devices, this method provides the most secure form of identification. A root
|
||||
certificate, often referenced as a CA certificate -- either issued by a known authority (such as LetsEncrypt)
|
||||
or a self-signed certificate) is used to sign a private key and certificate for the server. Each device/client
|
||||
also gets a unique private key and certificate signed by the same CA certificate; also included in the device
|
||||
certificate is a 'SAN' or SubjectAlternativeName which is the device's unique identifier.
|
||||
|
||||
Since both server and device certificates are signed by the same CA certificate, the client can
|
||||
verify the server's authenticity; and the server can verify the client's authenticity. And since
|
||||
the device's certificate contains a x509 SAN, the server (with this plugin) can identify the device securely.
|
||||
|
||||
!!! note "URI and Client ID configuration"
|
||||
`uri_domain` configuration must be set to the same uri used to generate the device credentials
|
||||
|
||||
when a device is connecting with private key and certificate, the `client_id` must
|
||||
match the device id used to generate the device credentials.
|
||||
|
||||
Available ore three scripts to help with the key generation and certificate signing: `ca_creds`, `server_creds`
|
||||
and `device_creds`.
|
||||
|
||||
!!! note "Configuring broker & client for using Self-signed root CA"
|
||||
If using self-signed root credentials, the `cafile` configuration for both broker and client need to be
|
||||
configured with `cafile` set to the `ca.crt`.
|
||||
|
||||
## Root & Certificate Credentials
|
||||
|
||||
The process for generating a server's private key and certificate is only done once. If you have a private key & certificate --
|
||||
such as one from verifying your webserver's domain with LetsEncrypt -- that you want to use, pass them to the `server_creds` cli.
|
||||
If you'd like to use a self-signed certificate, generate your own CA by running the `ca_creds` cli (make sure your client is
|
||||
configured with `check_hostname` as `False`).
|
||||
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
theme: redux
|
||||
---
|
||||
flowchart LR
|
||||
subgraph ca_cred["ca_cred #40;cli#41; or other CA"]
|
||||
ca["ca key & cert"]
|
||||
end
|
||||
|
||||
subgraph server_cred["server_cred fl°°40¶ßclifl°°41¶ß"]
|
||||
scsr("certificate signing<br>request fl°°40¶ßCSRfl°°41¶ß with<br>SAN of DNS & IP Address")
|
||||
spk["private key"]
|
||||
ssi["sign csr"]
|
||||
end
|
||||
|
||||
|
||||
spk -.-> skc["server key & cert"]
|
||||
ca_cred --> ssi
|
||||
spk --> scsr
|
||||
con["country, org<br>& common name"] --> scsr
|
||||
scsr --> ssi
|
||||
ssi --> skc
|
||||
```
|
||||
|
||||
## Device credentials
|
||||
|
||||
For each device, create a device id to generate a device-specific private key and certificate using the `device_creds` cli.
|
||||
Use the same CA as was used for the server (above) so the client & server recognize each other.
|
||||
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
theme: redux
|
||||
---
|
||||
flowchart LR
|
||||
subgraph ca_cred["ca_cred #40;cli#41; or other CA"]
|
||||
ca["ca key & cert"]
|
||||
end
|
||||
subgraph device_cred["device_cred fl°°40¶ßclifl°°41¶ß"]
|
||||
ccsr("certificate signing<br>request fl°°40¶ßCSRfl°°41¶ß with<br>SAN of URI & DNS")
|
||||
cpk["private key"]
|
||||
csi["sign csr"]
|
||||
end
|
||||
cpk --> ccsr
|
||||
csi --> ckc[device key & cert]
|
||||
cpk -.-> ckc
|
||||
ccsr --> csi
|
||||
ca_cred --> csi
|
||||
|
||||
con["country, org<br/>common name<br/>& device id"] --> ccsr
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
::: amqtt.contrib.cert.UserAuthCertPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
|
||||
## Key and Certificate Generation
|
||||
|
||||
::: mkdocs-typer2
|
||||
:module: amqtt.scripts.ca_creds
|
||||
:name: ca_creds
|
||||
|
||||
::: mkdocs-typer2
|
||||
:module: amqtt.scripts.server_creds
|
||||
:name: server_creds
|
||||
|
||||
::: mkdocs-typer2
|
||||
:module: amqtt.scripts.device_creds
|
||||
:name: device_creds
|
|
@ -0,0 +1,45 @@
|
|||
# Contributed Plugins
|
||||
|
||||
These are fully supported plugins but require additional dependencies to be installed:
|
||||
|
||||
`$ pip install '.[contrib]'`
|
||||
|
||||
|
||||
- [Relational Database Auth](auth_db.md)<br/>
|
||||
Grant or deny access to clients based on entries in a relational db (mysql, postgres, maria, sqlite). _Includes
|
||||
manager script to add, remove and create db entries_<br/>
|
||||
- `amqtt.contrib.auth_db.UserAuthDBPlugin`
|
||||
- `amqtt.contrib.auth_db.TopicAuthDBPlugin`
|
||||
|
||||
- [HTTP Auth](http.md)<br/>
|
||||
Determine client authentication and/or authorization based on response from a separate HTTP server.<br/>
|
||||
- `amqtt.contrib.http.UserAuthHttpPlugin`
|
||||
- `amqtt.contrib.http.TopicAuthHttpPlugin`
|
||||
|
||||
- [LDAP Auth](ldap.md)<br/>
|
||||
Authenticate a user with an LDAP directory server.<br/>
|
||||
- `amqtt.contrib.ldap.UserAuthLdapPlugin`
|
||||
- `amqtt.contrib.ldap.TopicAuthLdapPlugin`
|
||||
|
||||
- [Shadows](shadows.md)<br/>
|
||||
Device shadows provide a persistent, cloud-based representation of the state of a device,
|
||||
even when the device is offline. This plugin tracks the desired and reported state of a client
|
||||
and provides MQTT topic-based communication channels to retrieve and update a shadow.<br/>
|
||||
`amqtt.contrib.shadows.ShadowPlugin`
|
||||
|
||||
- [Certificate Auth](cert.md)<br/>
|
||||
Using client-specific certificates, signed by a common authority (even if self-signed) provides
|
||||
a highly secure way of authenticating mqtt clients. Often used with IoT devices where a unique
|
||||
certificate can be initialized on initial provisioning. _Includes command line utilities to generate
|
||||
root, broker and device certificates with the correct X509
|
||||
attributes to enable authenticity._<br/>
|
||||
`amqtt.contrib.cert.UserAuthCertPlugin`
|
||||
|
||||
- [JWT Auth](jwt.md)<br/>
|
||||
Plugin to determine user authentication and topic authorization based on claims in a JWT.
|
||||
- `amqtt.contrib.jwt.UserAuthJwtPlugin` (client authentication)
|
||||
- `amqtt.contrib.jwt.TopicAuthJwtPlugin` (topic authorization)
|
||||
|
||||
- [Session Persistence](session.md)<br/>
|
||||
Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.<br/>
|
||||
`amqtt.contrib.persistence.SessionDBPlugin`
|
|
@ -77,7 +77,9 @@ and then run via `amqtt -c myBroker.yaml`.
|
|||
variables to configure its behavior.
|
||||
|
||||
::: amqtt.plugins.base.BasePlugin
|
||||
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 3
|
||||
|
||||
|
||||
## Events
|
||||
|
@ -85,16 +87,13 @@ and then run via `amqtt -c myBroker.yaml`.
|
|||
|
||||
All plugins are notified of events if the `BasePlugin` subclass implements one or more of these methods:
|
||||
|
||||
### Client and Broker
|
||||
|
||||
### Client
|
||||
|
||||
- `async def on_mqtt_packet_sent(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
|
||||
### Client Only
|
||||
|
||||
none
|
||||
|
||||
### Broker Only
|
||||
### Broker
|
||||
|
||||
- `async def on_broker_pre_start(self) -> None`
|
||||
- `async def on_broker_post_start(self) -> None`
|
||||
|
@ -107,29 +106,59 @@ none
|
|||
- `async def on_broker_client_connected(self, *, client_id:str) -> None`
|
||||
- `async def on_broker_client_disconnected(self, *, client_id:str) -> None`
|
||||
|
||||
- `async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None`
|
||||
|
||||
- `async def on_broker_client_subscribed(self, *, client_id: str, topic: str, qos: int) -> None`
|
||||
- `async def on_broker_client_unsubscribed(self, *, client_id: str, topic: str) -> None`
|
||||
|
||||
- `async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None`
|
||||
- `async def on_broker_message_broadcast(self, *, client_id: str, message: ApplicationMessage) -> None`
|
||||
- `async def on_mqtt_packet_sent(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
|
||||
|
||||
!!! note retained message event
|
||||
if the `client_id` is `None`, the message is retained for a topic
|
||||
if the `retained_message.data` is `None` or empty (`''`), the topic message is being cleared
|
||||
|
||||
|
||||
## Authentication Plugins
|
||||
|
||||
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`
|
||||
is used by the aMQTT `Broker` to determine if a connection from a client is allowed by
|
||||
implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise.
|
||||
implementing the `authenticate` method and returning:
|
||||
- `True` if the session is allowed
|
||||
- `False` if not allowed
|
||||
- `None` if plugin can't determine authentication
|
||||
|
||||
If there are multiple authentication plugins:
|
||||
- at least one plugin must return `True` to allow access
|
||||
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
|
||||
- `None` gets ignored from the determination
|
||||
|
||||
::: amqtt.plugins.base.BaseAuthPlugin
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 3
|
||||
|
||||
## Topic Filter Plugins
|
||||
|
||||
In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin`
|
||||
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE)
|
||||
messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or
|
||||
`False` otherwise.
|
||||
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH), receive (RECEIVE)
|
||||
and/or subscribe (SUBSCRIBE) messages to a particular topic by implementing the `topic_filtering` method and returning:
|
||||
- `True` if topic is allowed
|
||||
- `False` if not allowed
|
||||
- `None` will be ignored
|
||||
|
||||
If there are multiple topic plugins:
|
||||
- at least one plugin must return `True` to allow access
|
||||
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
|
||||
- `None` will be ignored
|
||||
|
||||
::: amqtt.plugins.base.BaseTopicPlugin
|
||||
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 3
|
||||
|
||||
!!! note
|
||||
A custom plugin class can subclass from both `BaseAuthPlugin` and `BaseTopicPlugin` as long it defines
|
|
@ -0,0 +1,112 @@
|
|||
# Authentication & Authorization via external HTTP server
|
||||
|
||||
If clients accessing the broker are managed by another application, it can implement API endpoints
|
||||
that respond with information about client authentication and/or topic-level authorization.
|
||||
|
||||
- `amqtt.contrib.http.UserAuthHttpPlugin` (client authentication)
|
||||
- `amqtt.contrib.http.TopicAuthHttpPlugin` (topic authorization)
|
||||
|
||||
Configuration of these plugins is identical (except for the uri name) so that they can be used independently, if desired.
|
||||
|
||||
# User Auth
|
||||
|
||||
See the [Request and Response Modes](#request-response-modes) section below for details on `params_mode` and `response_mode`.
|
||||
|
||||
!!! info "browser-based mqtt over websockets"
|
||||
|
||||
One of the primary use cases for this plugin is to enable browser-based applications to communicate with mqtt
|
||||
over websockets.
|
||||
|
||||
!!! warning
|
||||
Care must be taken to make sure the mqtt password is secure (encrypted).
|
||||
|
||||
For more implementation information:
|
||||
|
||||
??? info "recipe for authentication"
|
||||
Provide the client id and username when webpage is initially rendered or passed to the mqtt initialization from stored
|
||||
cookies. If application is secure, the user's password will already be stored as a hashed value and, therefore, cannot
|
||||
be used in this context to authenticate a client. Instead, the application should create its own encrypted key (eg jwt)
|
||||
which the server can then verify when the broker contacts the application.
|
||||
|
||||
??? example "mqtt in javascript"
|
||||
Example initialization of mqtt in javascript:
|
||||
|
||||
import mqtt from 'mqtt';
|
||||
const url = 'https://path.to.amqtt.broker';
|
||||
const options = {
|
||||
'myclientid',
|
||||
connectTimeout: 30000,
|
||||
username: 'myclientid',
|
||||
password: '' // encrypted password
|
||||
};
|
||||
try {
|
||||
const clientMqtt = await mqtt.connect(url, options);
|
||||
|
||||
::: amqtt.contrib.http.UserAuthHttpPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
# Topic ACL
|
||||
|
||||
See the [Request and Response Modes](#request-response-modes) section below for details on `params_mode` and `response_mode`.
|
||||
|
||||
::: amqtt.contrib.http.TopicAuthHttpPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
[//]: # (manually creating the heading so it doesn't show in the sidebar ToC)
|
||||
[](){#request-response-modes}
|
||||
<h2>Request and Response Modes</h2>
|
||||
|
||||
Each URI endpoint will receive different information in order to determine authentication and authorization;
|
||||
format will depend on `params_mode` configuration attribute (`json` or `form`).:
|
||||
|
||||
*For user authentication, the request will contain:*
|
||||
|
||||
- username *(str)*
|
||||
- password *(str)*
|
||||
- client_id *(str)*
|
||||
|
||||
*For acl check, the request will contain:*
|
||||
|
||||
- username *(str)*
|
||||
- client_id *(str)*
|
||||
- topic *(str)*
|
||||
- acc *(int)* : client can receive (1), can publish(2), can receive & publish (3) and can subscribe (4)
|
||||
|
||||
All endpoints should respond with the following, dependent on `response_mode` configuration attribute:
|
||||
|
||||
*In `status` mode:*
|
||||
|
||||
- status code: 2xx (granted) or 4xx(denied) or 5xx (noop)
|
||||
|
||||
!!! note "5xx response"
|
||||
**noop** (no operation): plugin will not participate in the operation and will defer to another
|
||||
plugin to determine access. if there is no other auth/filtering plugin, access will be denied.
|
||||
|
||||
*In `json` mode:*
|
||||
|
||||
- status code: 2xx
|
||||
- content-type: application/json
|
||||
- response: {'ok': True } (granted)
|
||||
or {'ok': False, 'error': 'optional error message' } (denied)
|
||||
or { 'error': 'optional error message' } (noop)
|
||||
|
||||
!!! note "excluded 'ok' key"
|
||||
**noop** (no operation): plugin will not participate in the operation and will defer to another
|
||||
plugin to determine access. if there is no other auth/filtering plugin, access will be denied.
|
||||
|
||||
*In `text` mode:*
|
||||
|
||||
- status code: 2xx
|
||||
- content-type: text/plain
|
||||
- response: 'ok' or 'error'
|
||||
|
||||
!!! note "noop not supported"
|
||||
in text mode, noop (no operation) is not supported
|
|
@ -0,0 +1,50 @@
|
|||
# Authentication & Authorization from JWT
|
||||
|
||||
- `amqtt.contrib.jwt.UserAuthJwtPlugin` (client authentication)
|
||||
- `amqtt.contrib.jwt.TopicAuthJwtPlugin` (topic authorization)
|
||||
|
||||
Plugin to determine user authentication and topic authorization based on claims in a JWT.
|
||||
|
||||
# User Authentication
|
||||
|
||||
For auth, the JWT should include a key as specified in the configuration as `user_clam`:
|
||||
|
||||
```python
|
||||
from datetime import datetime, UTC, timedelta
|
||||
claims = {
|
||||
"username": "example_user",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
}
|
||||
```
|
||||
|
||||
::: amqtt.contrib.jwt.UserAuthJwtPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
|
||||
# Topic Authorization
|
||||
|
||||
For authorizing a client for certain topics, the token should also include claims for publish, subscribe and receive;
|
||||
keys based on how `publish_claim`, `subscribe_claim` and `receive_claim` are specified in the plugin's configuration.
|
||||
|
||||
```python
|
||||
from datetime import datetime, UTC, timedelta
|
||||
|
||||
claims = {
|
||||
"username": "example_user",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"publish_acl": ['my/topic/#', 'my/+/other'],
|
||||
"subscribe_acl": ['my/+/other'],
|
||||
"receive_acl": ['#']
|
||||
}
|
||||
```
|
||||
|
||||
::: amqtt.contrib.jwt.TopicAuthJwtPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
|
@ -0,0 +1,25 @@
|
|||
# Authentication with LDAP Server
|
||||
|
||||
If clients accessing the broker are managed by an LDAP server, this plugin can verify credentials
|
||||
for client authentication and/or topic-level authorization.
|
||||
|
||||
- `amqtt.contrib.ldap.UserAuthLdapPlugin` (client authentication)
|
||||
- `amqtt.contrib.ldap.TopicAuthLdapPlugin` (topic authorization)
|
||||
|
||||
Authenticate a user with an LDAP directory server.
|
||||
|
||||
# User Auth
|
||||
|
||||
::: amqtt.contrib.ldap.UserAuthLdapPlugin.Config
|
||||
options:
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
# Topic Auth (ACL)
|
||||
|
||||
::: amqtt.contrib.ldap.TopicAuthLdapPlugin.Config
|
||||
options:
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
|
@ -1,4 +1,4 @@
|
|||
# Existing Plugins
|
||||
# Packaged Plugins
|
||||
|
||||
With the aMQTT plugins framework, one can add additional functionality without
|
||||
having to rewrite core logic in the broker or client. Plugins can be loaded and configured using
|
||||
|
@ -23,7 +23,7 @@ and configured for the broker:
|
|||
--8<-- "pyproject.toml:included"
|
||||
```
|
||||
|
||||
But the same 4 plugins were activated in the previous default config:
|
||||
But the previous default config only caused 4 plugins to be active:
|
||||
|
||||
```yaml
|
||||
--8<-- "samples/legacy.yaml"
|
||||
|
@ -31,7 +31,7 @@ and configured for the broker:
|
|||
|
||||
## Client
|
||||
|
||||
By default, the `PacketLoggerPlugin` is activated and configured for the client:
|
||||
By default, the `PacketLoggerPlugin` is activated and configured for the client:
|
||||
|
||||
```yaml
|
||||
--8<-- "amqtt/scripts/default_client.yaml"
|
||||
|
@ -43,15 +43,13 @@ By default, the `PacketLoggerPlugin` is activated and configured for the clien
|
|||
|
||||
`amqtt.plugins.authentication.AnonymousAuthPlugin`
|
||||
|
||||
**Configuration**
|
||||
Authentication plugin allowing anonymous access.
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
.
|
||||
.
|
||||
amqtt.plugins.authentication.AnonymousAuthPlugin:
|
||||
allow_anonymous: false
|
||||
```
|
||||
::: amqtt.plugins.authentication.AnonymousAuthPlugin.Config
|
||||
options:
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
!!! danger
|
||||
even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
|
||||
|
@ -71,15 +69,14 @@ plugins:
|
|||
|
||||
`amqtt.plugins.authentication.FileAuthPlugin`
|
||||
|
||||
clients are authorized by providing username and password, compared against file
|
||||
Authentication plugin based on a file-stored user database.
|
||||
|
||||
**Configuration**
|
||||
::: amqtt.plugins.authentication.FileAuthPlugin.Config
|
||||
options:
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
amqtt.plugins.authentication.FileAuthPlugin:
|
||||
password_file: /path/to/password_file
|
||||
```
|
||||
|
||||
??? warning "EntryPoint-style configuration is deprecated"
|
||||
```yaml
|
||||
|
@ -134,13 +131,26 @@ plugins:
|
|||
|
||||
**Configuration**
|
||||
|
||||
- `acl` *(mapping)*: determines subscription access
|
||||
The list should be a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
Each acl category are a list a key-value pair, where:
|
||||
|
||||
> `<username>:["<topic1>", "<topic2>", ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing.
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
!!! info "`#` and `$SYS` topics"
|
||||
|
||||
Per the MQTT 3.1.1 spec 4.7.2, a single `#` will not allow access to `$` broker
|
||||
topics; need to additionally specify `$SYS/#` to allow a client full access subscribe & receive.
|
||||
|
||||
Also MQTT spec prevents clients from publishing to topics starting with `$`; these will be ignored.
|
||||
|
||||
If set to `None`, no restrictions are placed on client subscriptions (legacy behavior). An empty list will block clients from using any topics.
|
||||
|
||||
- `subscribe-acl` *(mapping)*: determines subscription access.
|
||||
|
||||
- `acl` *(mapping)*: Deprecated and replaced by `subscribe-acl`.
|
||||
|
||||
- `publish-acl` *(mapping)*: determines publish access.
|
||||
|
||||
- `receive-acl` *(mapping)*: determines if a message can be sent to a client.
|
||||
|
||||
!!! info "Reserved usernames"
|
||||
|
||||
|
@ -235,7 +245,7 @@ plugins:
|
|||
|
||||
`amqtt.plugins.logging_amqtt.PacketLoggerPlugin`
|
||||
|
||||
This plugin issues debug-level messages for [mqtt events](custom_plugins.md#client-and-broker): `on_mqtt_packet_sent`
|
||||
This plugin issues debug-level messages for [mqtt events](custom_plugins.md#events): `on_mqtt_packet_sent`
|
||||
and `on_mqtt_packet_received`.
|
||||
|
||||
```yaml
|
|
@ -0,0 +1,12 @@
|
|||
# Session Persistence
|
||||
|
||||
`amqtt.contrib.persistence.SessionDBPlugin`
|
||||
|
||||
Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
|
||||
|
||||
::: amqtt.contrib.persistence.SessionDBPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
|
@ -0,0 +1,202 @@
|
|||
# Device Shadows Plugin
|
||||
|
||||
Device shadows provide a persistent, cloud-based representation of the state of a device,
|
||||
even when the device is offline. This plugin tracks the desired and reported state of a client
|
||||
and provides MQTT topic-based communication channels to retrieve and update a shadow.
|
||||
|
||||
Typically, this structure is used for MQTT IoT devices to communicate with a central application.
|
||||
|
||||
This plugin is patterned after [AWS's IoT Shadow](https://docs.aws.amazon.com/iot/latest/developerguide/iot-device-shadows.html) service.
|
||||
|
||||
## How it works
|
||||
|
||||
All shadow states are associated with a `device id` and `name` and have the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"state": {
|
||||
"desired": {
|
||||
"property1": "value1"
|
||||
},
|
||||
"reported": {
|
||||
"property1": "value1"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"desired": {
|
||||
"property1": {
|
||||
"timestamp": 1623855600
|
||||
}
|
||||
},
|
||||
"reported": {
|
||||
"property1": {
|
||||
"timestamp": 1623855602
|
||||
}
|
||||
}
|
||||
},
|
||||
"version": 10,
|
||||
"timestamp": 1623855602
|
||||
}
|
||||
```
|
||||
|
||||
The `state` is updated by messages to shadow topics and includes key/value pairs, where the value can be any valid
|
||||
json object (int, string, dictionary, list, etc). `metadata` is automatically updated by the plugin based on when
|
||||
the key/values were most recently updated. Both `state` and `metadata` are split between:
|
||||
|
||||
- desired: the intended state of a device
|
||||
- reported: the actual state of a device
|
||||
|
||||
A client can update a part or all of the desired or reported state. On any update, the plugin:
|
||||
|
||||
- updates the 'state' portion of the shadow with any key/values provided in the update
|
||||
- stores a version of the update
|
||||
- tracks the timestamp of each key/value pair change
|
||||
- sends messages that the shadow was updated
|
||||
|
||||
## Typical usage
|
||||
|
||||
As mentioned above, this plugin is often used for MQTT IoT devices to communicate with a central application. The
|
||||
app pushes updates to a device's 'desired' shadow state and the device can confirm the change was made by updating
|
||||
the 'reported' state. With this sequence the 'desired' state matches the 'reported' state and the delta message is empty.
|
||||
|
||||
In most situations, the app only updates the 'desired' state and the device only updates the 'reported' state.
|
||||
|
||||
If online, the IoT device will receive and can act on that information immediately. If offline, the app doesn't need
|
||||
to republish or retry a change 'command', waiting for an acknowledgement from the device. If a device is offline, it
|
||||
simply retrieves the configuration changes when it comes back online.
|
||||
|
||||
Once a device receives its desired state, it should either (1) update its reported state to match the change in desired
|
||||
or (2) if the desired state is invalid, clear that key/value from the desired state. The latter is the only case
|
||||
when a device should update its own 'desired' state.
|
||||
|
||||
For example, if the app sends a command to set the brightness of a device to 100 lumens, but the device only supports
|
||||
a maximum of 80, it can send an update `'state': {'desired': {'lumens': null}}` to clear the invalid state.
|
||||
|
||||
The reported state can (and most likely will) include key/values that will never show up in the desired state. For
|
||||
example, the app might set the thermostat to 70 and the device reports both the configuration change of 70 to the
|
||||
thermostat *and* the current temperature of the room.
|
||||
|
||||
```json
|
||||
{
|
||||
"state": {
|
||||
"desired": {
|
||||
"thermostat": 68
|
||||
},
|
||||
"reported": {
|
||||
"thermostat": 68,
|
||||
"temperature": 78
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
!!! note "desired and reported state structure"
|
||||
It is important that both the app and the device have the same understanding of the key/value
|
||||
state structure and units. Creating [JSON schemas](https://json-schema.org/) for desired and
|
||||
reported shadow states are very useful as it can provide a clear way of describing the schema.
|
||||
These schemas can also be used to generate [dataclasses](https://pypi.org/project/datamodel-code-generator/),
|
||||
[pojos](https://github.com/joelittlejohn/jsonschema2pojo) or [many other language constructs](https://json-schema.org/tools?query=&sortBy=name&sortOrder=ascending&groupBy=toolingTypes&licenses=&languages=&drafts=&toolingTypes=&environments=&showObsolete=false&supportsBowtie=false#schema-to-code) that
|
||||
can be easily included by both app and device to make state encoding and decoding consistent.
|
||||
|
||||
## Shadow state access
|
||||
|
||||
All shadows are addressed by using specific topics, all of which have the following base:
|
||||
|
||||
`$shadow/<device_id>/<shadow name>`
|
||||
|
||||
Clients send either `get` or `update` messages:
|
||||
|
||||
| Operation | Topic | Direction | Payload |
|
||||
|-------------------------|-------------------------------------------------------|-----------|-----------------------------------------------------|
|
||||
| **Update** | `$shadow/{device_id}/{shadow_name}/update` | → | `{ "state": { "desired" or "reported": ... } }` |
|
||||
| **Get** | `$shadow/{device_id}/{shadow_name}/get` | → | Empty message triggers get accepted or rejected |
|
||||
|
||||
Then clients can subscribe to any or all of these topics which receive messages issued by the plugin:
|
||||
|
||||
| Operation | Topic | Direction | Payload |
|
||||
|-------------------------|-------------------------------------------------------|-----------|-----------------------------------------------------|
|
||||
| **Update Accepted** | `$shadow/{device_id}/{shadow_name}/update/accepted` | ← | Full updated document |
|
||||
| **Update Rejected** | `$shadow/{device_id}/{shadow_name}/update/rejected` | ← | Error message |
|
||||
| **Update Documents** | `$shadow/{device_id}/{shadow_name}/update/documents` | ← | Full current & previous shadow documents |
|
||||
| **Get Accepted** | `$shadow/{device_id}/{shadow_name}/get/accepted` | ← | Full shadow document |
|
||||
| **Get Rejected** | `$shadow/{device_id}/{shadow_name}/get/rejected` | ← | Error message |
|
||||
| **Delta** | `$shadow/{device_id}/{shadow_name}/update/delta` | ← | Difference between desired and reported |
|
||||
| **Iota** | `$shadow/{device_id}/{shadow_name}/update/iota` | ← | Difference between desired and reported, with nulls |
|
||||
|
||||
## Delta messages
|
||||
|
||||
While the 'accepted' and 'documents' messages carry the full desired and reported states, this plugin also generates
|
||||
a 'delta' message - containing items in the desired state that are different from those items in the reported state. This
|
||||
topic optimizes for IoT devices which typically have lower bandwidth and not as powerful processing by (1) to reducing the
|
||||
amount of data transmitted and (2) simplifying device implementation as it only needs to respond to differences.
|
||||
|
||||
While shadows are stateful since delta messages are only based on the desired and reported state and *not on the previous
|
||||
and current state*. Therefore, it doesn't matter if an IoT device is offline and misses a delta message. When it comes
|
||||
back online, the delta is identical.
|
||||
|
||||
This is also an improvement over a connection without the clean flag and QoS > 0. When an IoT device comes back online, bandwidth
|
||||
isn't consumed and the IoT device does not have to process a backlog of messages to understand how it should behave.
|
||||
For a setting -- such as volume -- that goes from 80 then to 91 and then to 60 while the device is offline, it will
|
||||
only receive a single change that its volume should now be 60.
|
||||
|
||||
|
||||
| Reported Shadow State | Desired Shadow State | Resulting Delta Message (`delta`) |
|
||||
|----------------------------------------|------------------------------------------|---------------------------------------|
|
||||
| `{ "temperature": 70 }` | `{ "temperature": 72 }` | `{ "temperature": 72 }` |
|
||||
| `{ "led": "off", "fan": "low" }` | `{ "led": "on", "fan": "low" }` | `{ "led": "on" }` |
|
||||
| `{ "door": "closed" }` | `{ "door": "closed", "alarm": "armed" }` | `{ "alarm": "armed" }` |
|
||||
| `{ "volume": 10 }` | `{ "volume": 10 }` | *(no delta; states match)* |
|
||||
| `{ "brightness": 100 }` | `{ "brightness": 80, "mode": "eco" }` | `{ "brightness": 80, "mode": "eco" }` |
|
||||
| `{ "levels": [1, 10, 4]}` | `{"levels": [1, 4, 10]}` | `{"levels": [1, 4, 10]}` |
|
||||
| `{ "brightness": 100, "mode": "eco" }` | `{ "brightness": 80 }` | `{ "brightness": 80}` |
|
||||
|
||||
|
||||
## Iota messages
|
||||
|
||||
Typically, null values never show in any received update message as a null signals the removal of a key from the desired
|
||||
or reported state. However, if the app removes a key from the desired state -- such as a piece of state that is no longer
|
||||
needed or applicable -- the device won't receive any notification of this deletion in a delta messages.
|
||||
|
||||
These messages are very similar to 'delta' messages as they also contain items in the desired state that are different from
|
||||
those in the reported state; it *also* contains any items in the reported state that are *missing* from the desired
|
||||
state (last row in table).
|
||||
|
||||
| Reported Shadow State | Desired Shadow State | Resulting Delta Message (`delta`) |
|
||||
|----------------------------------------|------------------------------------------|-----------------------------------------|
|
||||
| `{ "temperature": 70 }` | `{ "temperature": 72 }` | `{ "temperature": 72 }` |
|
||||
| `{ "led": "off", "fan": "low" }` | `{ "led": "on", "fan": "low" }` | `{ "led": "on" }` |
|
||||
| `{ "door": "closed" }` | `{ "door": "closed", "alarm": "armed" }` | `{ "alarm": "armed" }` |
|
||||
| `{ "volume": 10 }` | `{ "volume": 10 }` | *(no delta; states match)* |
|
||||
| `{ "brightness": 100 }` | `{ "brightness": 80, "mode": "eco" }` | `{ "brightness": 80, "mode": "eco" }` |
|
||||
| `{ "levels": [1, 10, 4]}` | `{"levels": [1, 4, 10]}` | `{"levels": [1, 4, 10]}` |
|
||||
| `{ "brightness": 100, "mode": "eco" }` | `{ "brightness": 80 }` | `{ "brightness": 80, "mode": null }` |
|
||||
|
||||
## Configuration
|
||||
|
||||
::: amqtt.contrib.shadows.ShadowPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
|
||||
## Security
|
||||
|
||||
Often a device only needs access to get/update and receive changes in its own shadow state. In addition to the `ShadowPlugin`,
|
||||
included is the `ShadowTopicAuthPlugin`. This allows (authorizes) a device to only subscribe, publish and receive its own topics.
|
||||
|
||||
|
||||
::: amqtt.contrib.shadows.ShadowTopicAuthPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
!!! warning
|
||||
|
||||
`ShadowTopicAuthPlugin` only handles topic authorization. Another plugin should be used to authenticate client device
|
||||
connections to the broker. See [file auth](packaged_plugins.md#password-file-auth-plugin),
|
||||
[http auth](http.md), [db auth](auth_db.md) or [certificate auth](cert.md) plugins. Or create your own:
|
||||
[auth plugins](custom_plugins.md#authentication-plugins):
|
|
@ -1,4 +1,4 @@
|
|||
#
|
||||
# Broker
|
||||
|
||||
::: mkdocs-typer2
|
||||
:module: amqtt.scripts.broker_script
|
||||
|
|
|
@ -21,15 +21,5 @@ The `amqtt.broker` module provides the following key methods in the `Broker` cla
|
|||
- `start()`: Starts the broker and begins serving
|
||||
- `shutdown()`: Gracefully shuts down the broker
|
||||
|
||||
### Broker configuration
|
||||
|
||||
The `Broker` class's `__init__` method accepts a `config` parameter which allows setup of default and custom behaviors.
|
||||
|
||||
Details on the `config` parameter structure is a dictionary whose structure is identical to yaml formatted file[^1]
|
||||
used by the included broker script: [broker configuration](broker_config.md)
|
||||
|
||||
|
||||
|
||||
::: amqtt.broker.Broker
|
||||
|
||||
[^1]: See [PyYAML](http://pyyaml.org/wiki/PyYAMLDocumentation) for loading YAML files as Python dict.
|
||||
|
|
|
@ -1,42 +1,31 @@
|
|||
# Broker Configuration
|
||||
|
||||
This configuration structure is valid as a python dictionary passed to the `amqtt.broker.Broker` class's `__init__` method or
|
||||
as a yaml formatted file passed to the `amqtt` script.
|
||||
|
||||
### `listeners` *(list[dict[str, Any]])*
|
||||
|
||||
Defines the network listeners used by the service. Items defined in the `default` listener will be
|
||||
applied to all other listeners, unless they are overridden by the configuration for the specific
|
||||
listener.
|
||||
|
||||
- `default` | `<listener_name>`: Named listener
|
||||
- `type` *(string)*: Transport type. Can be `tcp` or `ws`.
|
||||
- `bind` *(string)*: IP address and port (e.g., `0.0.0.0:1883`)
|
||||
- `max-connections` *(integer)*: Maximum number of clients that can connect to this interface
|
||||
- `ssl` *(string)*: Enable SSL connection. Can be `on` or `off` (default: off).
|
||||
- `cafile` *(string)*: Path to a file of concatenated CA certificates in PEM format. See [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info.
|
||||
- `capath` *(string)*: Path to a directory containing several CA certificates in PEM format, following an [OpenSSL specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/).
|
||||
- `cadata` *(string)*: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
|
||||
- `certfile` *(string)*: Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the certificate's authenticity.
|
||||
- `keyfile` *(string): A file containing the private key. Otherwise the private key will be taken from `certfile` as well.
|
||||
|
||||
### `timeout-disconnect-delay` *(int)*
|
||||
|
||||
Client disconnect timeout without a keep-alive.
|
||||
|
||||
### `plugins` *(mapping)*
|
||||
|
||||
A list of strings representing the modules and class name of `BasePlugin`, `BaseAuthPlugin` and `BaseTopicPlugins`. Each
|
||||
entry may have one or more configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md)
|
||||
|
||||
|
||||
??? warning "Deprecated: `sys_interval` "
|
||||
**`sys_interval`** *(int)*
|
||||
|
||||
System status report interval in seconds, used by the `amqtt.plugins.sys.broker.BrokerSysPlugin`.
|
||||
This configuration structure is a `amqtt.contexts.BrokerConfig` or a python dictionary with the same structure
|
||||
when instantiating `amqtt.broker.Broker` or as a yaml formatted file passed to the `amqtt` script.
|
||||
|
||||
If not specified, the `Broker()` will be started with the default `BrokerConfig()`, as represented in yaml format:
|
||||
|
||||
```yaml
|
||||
---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
timeout_disconnect_delay: 0
|
||||
plugins:
|
||||
amqtt.plugins.logging_amqtt.EventLoggerPlugin:
|
||||
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
amqtt.plugins.authentication.AnonymousAuthPlugin:
|
||||
allow_anonymous: true
|
||||
amqtt.plugins.sys.broker.BrokerSysPlugin:
|
||||
sys_interval: 20
|
||||
```
|
||||
|
||||
::: amqtt.contexts.BrokerConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
??? warning "Deprecated: `auth` configuration settings"
|
||||
|
||||
|
@ -64,6 +53,13 @@ entry may have one or more configuration settings. For more information, see the
|
|||
|
||||
- `password-file` *(string)*. Path to sha-512 encoded password file, used by `amqtt.plugins.authentication.FileAuthPlugin`.
|
||||
|
||||
|
||||
??? warning "Deprecated: `sys_interval` "
|
||||
**`sys_interval`** *(int)*
|
||||
|
||||
System status report interval in seconds, used by the `amqtt.plugins.sys.broker.BrokerSysPlugin`.
|
||||
|
||||
|
||||
??? warning "Deprecated: `topic-check` configuration settings"
|
||||
|
||||
|
||||
|
@ -88,19 +84,17 @@ entry may have one or more configuration settings. For more information, see the
|
|||
- The username `admin` is allowed access to all topic.
|
||||
- The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Default Configuration
|
||||
|
||||
```yaml
|
||||
--8<-- "amqtt/scripts/default_broker.yaml"
|
||||
```
|
||||
::: amqtt.contexts.ListenerConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
## Example
|
||||
|
||||
When a configuration is passed to the `amqtt` script, here is the equivalent format based on the structures above:
|
||||
|
||||
```yaml
|
||||
listeners:
|
||||
default:
|
||||
|
@ -149,7 +143,7 @@ This configuration file would create the following listeners:
|
|||
- `my-ws-1`: an unsecured websocket listener on port 9001 allowing `500` clients connections simultaneously
|
||||
- `my-wss-1`: a secured websocket listener on port 9003 allowing `500`
|
||||
|
||||
And enable the following access controls:
|
||||
And enable the following topic access:
|
||||
|
||||
- `username1` to login and subscribe/publish to topics `repositories/+/master`, `calendar/#` and `data/memes`
|
||||
- `username2` to login and subscribe/publish to topics `calendar/2025/#` and `data/memes`
|
||||
|
|
|
@ -129,15 +129,4 @@ amqtt/LYRf52W[56SOjW04 <-in-- PubcompPacket(ts=2015-11-11 21:54:48.713107, fixed
|
|||
|
||||
Both coroutines have the same results except that `test_coro2()` manages messages flow in parallel which may be more efficient.
|
||||
|
||||
|
||||
### Client configuration
|
||||
|
||||
The `MQTTClient` class's `__init__` method accepts a `config` parameter which allows setup of default and custom behaviors.
|
||||
|
||||
Details on the `config` parameter structure is a dictionary whose structure is identical to yaml formatted file[^1]
|
||||
used by the included broker script: [client configuration](client_config.md)
|
||||
|
||||
::: amqtt.client.MQTTClient
|
||||
|
||||
[^1]: See [PyYAML](http://pyyaml.org/wiki/PyYAMLDocumentation) for loading YAML files as Python dict.
|
||||
|
||||
|
|
|
@ -1,97 +1,57 @@
|
|||
# Client Configuration
|
||||
|
||||
This configuration structure is valid as a python dictionary passed to the `amqtt.broker.MQTTClient` class's `__init__` method or
|
||||
as a yaml formatted file passed to the `amqtt_pub` script.
|
||||
This configuration structure is either a `amqtt.contexts.ClientConfig` or a python dictionary with identical structure
|
||||
when instantiating `amqtt.broker.MQTTClient` or as a yaml formatted file passed to the `amqtt_pub` script.
|
||||
|
||||
### `keep_alive` *(int)*
|
||||
|
||||
Keep-alive timeout sent to the broker. Defaults to `10` seconds.
|
||||
|
||||
### `ping_delay` *(int)*
|
||||
|
||||
Auto-ping delay before keep-alive timeout. Defaults to 1. Setting to `0` will disable to 0 and may lead to broker disconnection.
|
||||
|
||||
### `default_qos` *(int: 0-2)*
|
||||
|
||||
Default QoS for messages published. Defaults to 0.
|
||||
|
||||
|
||||
### `default_retain` *(bool)*
|
||||
|
||||
Default retain value to messages published. Defaults to `false`.
|
||||
|
||||
|
||||
### `auto_reconnect` *(bool)*
|
||||
|
||||
Enable or disable auto-reconnect if connection with the broker is interrupted. Defaults to `false`.
|
||||
|
||||
|
||||
### `connect_timeout` *(int)*
|
||||
|
||||
If specified, the number of seconds before a connection times out
|
||||
|
||||
### `reconnect_retries` *(int)*
|
||||
|
||||
Maximum reconnection retries. Defaults to `2`. Negative value will cause client to reconnect infinitely.
|
||||
|
||||
### `reconnect_max_interval` *(int)*
|
||||
|
||||
Maximum interval between 2 connection retry. Defaults to `10`.
|
||||
|
||||
### `cleansession` *(bool)*
|
||||
|
||||
Upon reconnect, should subscriptions be cleared. Defaults to `true`.
|
||||
|
||||
### `topics` *(list[mapping])*
|
||||
|
||||
Specify the topics and what flags should be set for messages published to them.
|
||||
|
||||
- `<topic>`: Named listener
|
||||
- `qos` *(int, 0-3)*:
|
||||
- `retain` *(bool)*:
|
||||
|
||||
### `will` *(mapping)*
|
||||
|
||||
If included, the message that should be sent if the client disconnects.
|
||||
|
||||
- `topic` *(string)*:
|
||||
- `message` *(string)*:
|
||||
- `qos` *(int): 0, 1 or 2
|
||||
- `retain`: *(bool)* new clients subscribing to `topic` will receive this message
|
||||
|
||||
### `broker` *(mapping)*
|
||||
|
||||
- `uri` *(string)*: Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*
|
||||
|
||||
TLS certificates used to verify the broker's authenticity.
|
||||
|
||||
- `cafile` *(string)*: Path to a file of concatenated CA certificates in PEM format. See [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info.
|
||||
- `capath` *(string)*: Path to a directory containing several CA certificates in PEM format, following an [OpenSSL specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/).
|
||||
- `cadata` *(string)*: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
|
||||
|
||||
|
||||
### `certfile` *(string)*
|
||||
|
||||
Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the server certificate's authenticity.
|
||||
|
||||
### `check_hostname` *(bool)*
|
||||
|
||||
Bypass ssl host certificate verification, allowing self-signed certificates
|
||||
|
||||
### `plugins` *(mapping)*
|
||||
|
||||
A list of strings representing the modules and class name of any `BasePlugin`s. Each entry may have one or more
|
||||
configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md)
|
||||
|
||||
|
||||
## Default Configuration
|
||||
If not specified, the `MQTTClient()` will be started with the default `ClientConfig()`, as represented in yaml format:
|
||||
|
||||
```yaml
|
||||
--8<-- "amqtt/scripts/default_client.yaml"
|
||||
---
|
||||
keep_alive: 10
|
||||
ping_delay: 1
|
||||
default_qos: 0
|
||||
default_retain: false
|
||||
auto_reconnect: true
|
||||
connection_timeout: 60
|
||||
reconnect_retries: 2
|
||||
reconnect_max_interval: 10
|
||||
cleansession: true
|
||||
broker:
|
||||
uri: "mqtt://127.0.0.1"
|
||||
plugins:
|
||||
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
```
|
||||
|
||||
::: amqtt.contexts.ClientConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
::: amqtt.contexts.TopicConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
::: amqtt.contexts.WillConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
::: amqtt.contexts.ConnectionConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
A more expansive `ClientConfig` in equivalent yaml format:
|
||||
|
||||
```yaml
|
||||
|
||||
keep_alive: 10
|
||||
|
@ -102,9 +62,9 @@ auto_reconnect: true
|
|||
reconnect_max_interval: 5
|
||||
reconnect_retries: 10
|
||||
topics:
|
||||
test:
|
||||
topic/subtopic:
|
||||
qos: 0
|
||||
some_topic:
|
||||
topic/other:
|
||||
qos: 2
|
||||
retain: true
|
||||
will:
|
||||
|
@ -113,9 +73,8 @@ will:
|
|||
qos: 1
|
||||
retain: false
|
||||
broker:
|
||||
uri: mqtt://localhost:1883
|
||||
cafile: /path/to/ca/file
|
||||
uri: 'mqtt://localhost:1883'
|
||||
cafile: '/path/to/ca/file'
|
||||
plugins:
|
||||
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
|
||||
|
||||
```
|
||||
|
|
|
@ -2,20 +2,26 @@
|
|||
|
||||
This document describes `aMQTT` common API both used by [MQTT Client](client.md) and [Broker](broker.md).
|
||||
|
||||
## Reference
|
||||
## ApplicationMessage
|
||||
|
||||
### ApplicationMessage
|
||||
::: amqtt.session.ApplicationMessage
|
||||
options:
|
||||
heading_level: 3
|
||||
|
||||
The `amqtt.session` module provides the following message classes:
|
||||
## IncomingApplicationMessage
|
||||
|
||||
#### ApplicationMessage
|
||||
Represents messages received from MQTT clients.
|
||||
|
||||
Base class for MQTT application messages.
|
||||
::: amqtt.session.IncomingApplicationMessage
|
||||
options:
|
||||
heading_level: 3
|
||||
|
||||
#### IncomingApplicationMessage
|
||||
|
||||
Inherits from ApplicationMessage. Represents messages received from MQTT clients.
|
||||
|
||||
#### OutgoingApplicationMessage
|
||||
## OutgoingApplicationMessage
|
||||
|
||||
Inherits from ApplicationMessage. Represents messages to be sent to MQTT clients.
|
||||
|
||||
::: amqtt.session.OutgoingApplicationMessage
|
||||
options:
|
||||
heading_level: 3
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
import ast
|
||||
import pprint
|
||||
from typing import Any
|
||||
|
||||
import griffe
|
||||
from griffe import Inspector, ObjectNode, Visitor, Attribute
|
||||
|
||||
from amqtt.contexts import default_listeners, default_broker_plugins, default_client_plugins
|
||||
from amqtt.contrib.auth_db.plugin import default_hash_scheme
|
||||
|
||||
default_factory_map = {
|
||||
'default_listeners': default_listeners(),
|
||||
'default_broker_plugins': default_broker_plugins(),
|
||||
'default_client_plugins': default_client_plugins(),
|
||||
'default_hash_scheme': default_hash_scheme()
|
||||
}
|
||||
|
||||
def get_qualified_name(node: ast.AST) -> str | None:
|
||||
"""Recursively build the qualified name from an AST node."""
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Attribute):
|
||||
parent = get_qualified_name(node.value)
|
||||
if parent:
|
||||
return f"{parent}.{node.attr}"
|
||||
return node.attr
|
||||
elif isinstance(node, ast.Call):
|
||||
# e.g., uuid.uuid4()
|
||||
return get_qualified_name(node.func)
|
||||
return None
|
||||
|
||||
def get_fully_qualified_name(call_node):
|
||||
"""
|
||||
Extracts the fully qualified name from an ast.Call node.
|
||||
"""
|
||||
if isinstance(call_node.func, ast.Name):
|
||||
# Direct function call (e.g., "my_function(arg)")
|
||||
return call_node.func.id
|
||||
elif isinstance(call_node.func, ast.Attribute):
|
||||
# Method call or qualified name (e.g., "obj.method(arg)" or "module.submodule.function(arg)")
|
||||
parts = []
|
||||
current = call_node.func
|
||||
while isinstance(current, ast.Attribute):
|
||||
parts.append(current.attr)
|
||||
current = current.value
|
||||
if isinstance(current, ast.Name):
|
||||
parts.append(current.id)
|
||||
return ".".join(reversed(parts))
|
||||
else:
|
||||
# Handle other potential cases (e.g., ast.Subscript) if necessary
|
||||
return None
|
||||
|
||||
def get_callable_name(node):
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Attribute):
|
||||
return f"{get_callable_name(node.value)}.{node.attr}"
|
||||
return None
|
||||
|
||||
def evaluate_callable_node(node):
|
||||
try:
|
||||
# Wrap the node in an Expression so it can be compiled
|
||||
expr = ast.Expression(body=node)
|
||||
compiled = compile(expr, filename="<ast>", mode="eval")
|
||||
return eval(compiled, {"__builtins__": __builtins__, "list": list, "dict": dict})
|
||||
except Exception as e:
|
||||
return f"<unresolvable: {e}>"
|
||||
|
||||
class DataclassDefaultFactoryExtension(griffe.Extension):
|
||||
"""Renders the output of a dataclasses field which uses a default factory.
|
||||
|
||||
def other_field_defaults():
|
||||
return {'item1': 'value1', 'item2': 'value2'}
|
||||
|
||||
@dataclass
|
||||
class MyDataClass:
|
||||
my_field: dict[str, Any] = field(default_factory=dict)
|
||||
my_other_field: dict[str, Any] = field(default_factory=other_field_defaults)
|
||||
|
||||
instead of documentation rendering this as:
|
||||
|
||||
```
|
||||
class MyDataClass:
|
||||
my_field: dict[str, Any] = dict()
|
||||
my_other_field: dict[str, Any] = other_field_defaults()
|
||||
```
|
||||
|
||||
it will be displayed with the output of factory functions for more clarity:
|
||||
|
||||
```
|
||||
class MyDataClass:
|
||||
my_field: dict[str, Any] = {}
|
||||
my_other_field: dict[str, Any] = {'item1': 'value1', 'item2': 'value2'}
|
||||
```
|
||||
|
||||
_note_ : for any custom default factory function, it must be added to the `default_factory_map`
|
||||
in this file as `griffe` doesn't provide a straightforward mechanism with its AST to dynamically
|
||||
import/call the function.
|
||||
"""
|
||||
|
||||
def on_attribute_instance(
|
||||
self,
|
||||
*,
|
||||
node: ast.AST | ObjectNode,
|
||||
attr: Attribute,
|
||||
agent: Visitor | Inspector,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Called for every `node` and/or `attr` on a file's AST."""
|
||||
if not hasattr(node, "value"):
|
||||
return
|
||||
if isinstance(node.value, ast.Call):
|
||||
# Search for all of the `default_factory` fields.
|
||||
default_factory_value: str | None = None
|
||||
for kw in node.value.keywords:
|
||||
if kw.arg == "default_factory":
|
||||
# based on the node type, return the proper function name
|
||||
match get_callable_name(kw.value):
|
||||
# `dict` and `list` are common default factory functions
|
||||
case 'dict':
|
||||
default_factory_value = "{}"
|
||||
case 'list':
|
||||
default_factory_value = "[]"
|
||||
|
||||
case _:
|
||||
# otherwise, see the nodes is in our map for the custom default factory function
|
||||
callable_name = get_callable_name(kw.value)
|
||||
if callable_name in default_factory_map:
|
||||
default_factory_value = pprint.pformat(default_factory_map[callable_name], indent=4, width=80, sort_dicts=False)
|
||||
else:
|
||||
# if not, display as the default
|
||||
default_factory_value = f"{callable_name}()"
|
||||
|
||||
# store the information in the griffe attribute, which is what is passed to the template for rendering
|
||||
if "dataclass_ext" not in attr.extra:
|
||||
attr.extra["dataclass_ext"] = {}
|
||||
attr.extra["dataclass_ext"]["has_default_factory"] = False
|
||||
if default_factory_value is not None:
|
||||
attr.extra["dataclass_ext"]["has_default_factory"] = True
|
||||
attr.extra["dataclass_ext"]["default_factory"] = default_factory_value
|
|
@ -0,0 +1 @@
|
|||
template overrides for mkdocs-materials
|
|
@ -0,0 +1,13 @@
|
|||
{% extends "_base/class.html.jinja" %}
|
||||
|
||||
{% block signature scoped %}
|
||||
{% if config.extra.class_style != 'simple' %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock signature %}
|
||||
|
||||
{% block bases scoped %}
|
||||
{% if config.extra.class_style != 'simple' %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock bases %}
|
|
@ -0,0 +1,7 @@
|
|||
{% extends "_base/backlinks.html.jinja" %}
|
||||
|
||||
{% block logs scoped %}
|
||||
<p style="color:red">backlinks.html.jinja</p>
|
||||
{% endblock logs %}
|
||||
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
{% extends "_base/class.html.jinja" %}
|
||||
|
||||
{% block logs scoped %}
|
||||
{% if config.extra.template_log_display %}<p style="color:red">class.html.jinja</p>{% endif %}
|
||||
{% endblock logs %}
|
||||
|
||||
{% block signature scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock signature %}
|
|
@ -0,0 +1,21 @@
|
|||
{% extends "_base/docstring/attributes.html.jinja" %}
|
||||
|
||||
{% block logs scoped %}
|
||||
{% if config.extra.template_log_display %}<p style="color:red">docstring/attributes.html.jinja</p>{% endif %}
|
||||
{% endblock logs %}
|
||||
|
||||
{% block table_style scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock table_style %}
|
||||
{% block list_style scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock list_style %}
|
||||
{% block spacy_style scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock spacy_style %}
|
|
@ -0,0 +1,22 @@
|
|||
{% extends "_base/docstring/functions.html.jinja" %}
|
||||
|
||||
{% block logs scoped %}
|
||||
{% if config.extra.template_log_display %}<p style="color:red">docstring/functions.html.jinja</p>{% endif %}
|
||||
{% endblock logs %}
|
||||
|
||||
{% block table_style scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
{% block list_style scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
{% block spacy_style scoped %}
|
||||
{% if config.extra.class_style == 'simple' %}{% else %}
|
||||
{{ super() }}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
{% if 'default_factory' in expression.__str__() %}
|
||||
{{ obj.extra.dataclass_ext.default_factory | safe }}
|
||||
{% else %}
|
||||
{% extends "_base/expression.html.jinja" %}
|
||||
{% endif %}
|
|
@ -0,0 +1,6 @@
|
|||
{% if config.extra.class_style == 'simple' %}{% else %}{% extends "_base/function.html.jinja" %}{% endif %}
|
||||
|
||||
|
||||
{% block logs scoped %}
|
||||
{% if config.extra.template_log_display %}<p style="color:red">function.html.jinja</p>{% endif %}
|
||||
{% endblock logs %}
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"name": "amqttio",
|
||||
"private": true,
|
||||
"version": "0.11.2",
|
||||
"version": "0.11.3",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
|
|
|
@ -15,4 +15,43 @@ export type TopicEntry<T> = {
|
|||
// Define the topic_map type
|
||||
export type TopicMap = {
|
||||
[topic: string]: TopicEntry<DataPoint[]>;
|
||||
};
|
||||
};
|
||||
|
||||
// no need for a full uuid, generate a 6-character alphanumeric sequence, in two parts
|
||||
export function getClientID() {
|
||||
const genPart = () => {
|
||||
const rand = (Math.random() * 46656) | 0
|
||||
// convert random number into an ascii sequence of letters, trimmed to 3 characters
|
||||
return ("000" + rand.toString(36)).slice(-3)
|
||||
}
|
||||
return `web-client-${genPart() + genPart()}`
|
||||
}
|
||||
|
||||
export function secondsToDhms(seconds: number) {
|
||||
const days = Math.floor(seconds / (24 * 3600));
|
||||
seconds %= (24 * 3600);
|
||||
const hours = Math.floor(seconds / 3600);
|
||||
seconds %= 3600;
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
seconds = seconds % 60;
|
||||
|
||||
return {
|
||||
days: days,
|
||||
hours: hours,
|
||||
minutes: minutes,
|
||||
seconds: seconds,
|
||||
};
|
||||
}
|
||||
|
||||
export function getMQTTSettings() {
|
||||
return {
|
||||
url: import.meta.env.VITE_MQTT_WS_TYPE + '://' + import.meta.env.VITE_MQTT_WS_HOST + ':' + import.meta.env.VITE_MQTT_WS_PORT,
|
||||
client_id: getClientID(),
|
||||
clean: true,
|
||||
protocol: 'wss',
|
||||
protocolVersion: 4, // MQTT 3.1.1
|
||||
wsOptions: {
|
||||
protocol: 'mqtt'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,17 +7,10 @@ import MainGrid from './components/MainGrid';
|
|||
import AppTheme from '../shared-theme/AppTheme';
|
||||
import AmqttLogo from './amqtt_bw.svg';
|
||||
|
||||
import {
|
||||
chartsCustomizations,
|
||||
treeViewCustomizations,
|
||||
} from './theme/customizations';
|
||||
import AppBar from "@mui/material/AppBar";
|
||||
import {Toolbar} from "@mui/material";
|
||||
|
||||
const xThemeComponents = {
|
||||
...chartsCustomizations,
|
||||
...treeViewCustomizations,
|
||||
};
|
||||
const xThemeComponents = {};
|
||||
|
||||
export default function Dashboard(props: { disableCustomTheme?: boolean }) {
|
||||
return (
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
<Typography component="h2" variant="h6" color="primary" gutterBottom>
|
||||
{props.children}
|
||||
</Typography>
|
|
@ -1,105 +0,0 @@
|
|||
import * as React from 'react';
|
||||
import { styled } from '@mui/material/styles';
|
||||
import AppBar from '@mui/material/AppBar';
|
||||
import Box from '@mui/material/Box';
|
||||
import Stack from '@mui/material/Stack';
|
||||
import MuiToolbar from '@mui/material/Toolbar';
|
||||
import { tabsClasses } from '@mui/material/Tabs';
|
||||
import Typography from '@mui/material/Typography';
|
||||
import MenuRoundedIcon from '@mui/icons-material/MenuRounded';
|
||||
import DashboardRoundedIcon from '@mui/icons-material/DashboardRounded';
|
||||
import SideMenuMobile from './SideMenuMobile';
|
||||
import MenuButton from './MenuButton';
|
||||
import ColorModeIconDropdown from '../../shared-theme/ColorModeIconDropdown';
|
||||
|
||||
const Toolbar = styled(MuiToolbar)({
|
||||
width: '100%',
|
||||
padding: '12px',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'start',
|
||||
justifyContent: 'center',
|
||||
gap: '12px',
|
||||
flexShrink: 0,
|
||||
[`& ${tabsClasses.flexContainer}`]: {
|
||||
gap: '8px',
|
||||
p: '8px',
|
||||
pb: 0,
|
||||
},
|
||||
});
|
||||
|
||||
export default function AppNavbar() {
|
||||
const [open, setOpen] = React.useState(false);
|
||||
|
||||
const toggleDrawer = (newOpen: boolean) => () => {
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
||||
return (
|
||||
<AppBar
|
||||
position="fixed"
|
||||
sx={{
|
||||
display: { xs: 'auto', md: 'none' },
|
||||
boxShadow: 0,
|
||||
bgcolor: 'background.paper',
|
||||
backgroundImage: 'none',
|
||||
borderBottom: '1px solid',
|
||||
borderColor: 'divider',
|
||||
top: 'var(--template-frame-height, 0px)',
|
||||
}}
|
||||
>
|
||||
<Toolbar variant="regular">
|
||||
<Stack
|
||||
direction="row"
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
flexGrow: 1,
|
||||
width: '100%',
|
||||
gap: 1,
|
||||
}}
|
||||
>
|
||||
<Stack
|
||||
direction="row"
|
||||
spacing={1}
|
||||
sx={{ justifyContent: 'center', mr: 'auto' }}
|
||||
>
|
||||
<CustomIcon />
|
||||
<Typography variant="h4" component="h1" sx={{ color: 'text.primary' }}>
|
||||
Dashboard
|
||||
</Typography>
|
||||
</Stack>
|
||||
<ColorModeIconDropdown />
|
||||
<MenuButton aria-label="menu" onClick={toggleDrawer(true)}>
|
||||
<MenuRoundedIcon />
|
||||
</MenuButton>
|
||||
<SideMenuMobile open={open} toggleDrawer={toggleDrawer} />
|
||||
</Stack>
|
||||
</Toolbar>
|
||||
</AppBar>
|
||||
);
|
||||
}
|
||||
|
||||
export function CustomIcon() {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
width: '1.5rem',
|
||||
height: '1.5rem',
|
||||
bgcolor: 'black',
|
||||
borderRadius: '999px',
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
alignSelf: 'center',
|
||||
backgroundImage:
|
||||
'linear-gradient(135deg, hsl(210, 98%, 60%) 0%, hsl(210, 100%, 35%) 100%)',
|
||||
color: 'hsla(210, 100%, 95%, 0.9)',
|
||||
border: '1px solid',
|
||||
borderColor: 'hsl(210, 100%, 55%)',
|
||||
boxShadow: 'inset 0 2px 5px rgba(255, 255, 255, 0.3)',
|
||||
}}
|
||||
>
|
||||
<DashboardRoundedIcon color="inherit" sx={{ fontSize: '1rem' }} />
|
||||
</Box>
|
||||
);
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
import Card from '@mui/material/Card';
|
||||
import CardContent from '@mui/material/CardContent';
|
||||
import Button from '@mui/material/Button';
|
||||
import Typography from '@mui/material/Typography';
|
||||
import AutoAwesomeRoundedIcon from '@mui/icons-material/AutoAwesomeRounded';
|
||||
|
||||
export default function CardAlert() {
|
||||
return (
|
||||
<Card variant="outlined" sx={{ m: 1.5, flexShrink: 0 }}>
|
||||
<CardContent>
|
||||
<AutoAwesomeRoundedIcon fontSize="small" />
|
||||
<Typography gutterBottom sx={{ fontWeight: 600 }}>
|
||||
Plan about to expire
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ mb: 2, color: 'text.secondary' }}>
|
||||
Enjoy 10% off when renewing your plan today.
|
||||
</Typography>
|
||||
<Button variant="contained" size="small" fullWidth>
|
||||
Get the discount
|
||||
</Button>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
import CountUp from "react-countup";
|
||||
|
||||
|
||||
const byte_units = [
|
||||
'Bytes',
|
||||
'KB',
|
||||
'MB',
|
||||
'GB',
|
||||
'TB'
|
||||
]
|
||||
|
||||
const update_time = 5;
|
||||
|
||||
export function ByteCounter(props: any) {
|
||||
|
||||
let start = props.start;
|
||||
let end = props.end;
|
||||
if(end - start < 200){
|
||||
return <CountUp
|
||||
start={start}
|
||||
end={end}
|
||||
duration={update_time}/>
|
||||
}
|
||||
|
||||
let unit = byte_units[0];
|
||||
for (let i = 0; i < byte_units.length; i++) {
|
||||
|
||||
if( start > 1_000) {
|
||||
start = start / 1000;
|
||||
end = end / 1000;
|
||||
unit = byte_units[i+1];
|
||||
}
|
||||
}
|
||||
|
||||
return <CountUp
|
||||
start={start}
|
||||
end={end}
|
||||
suffix={" " + unit}
|
||||
decimals={2}
|
||||
duration={update_time}/>
|
||||
}
|
||||
|
||||
export function StandardCounter(props: any) {
|
||||
|
||||
return <CountUp
|
||||
start={props.start}
|
||||
end={props.end}
|
||||
duration={update_time}/>
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
import { DataGrid } from '@mui/x-data-grid';
|
||||
import { columns, rows } from '../internals/data/gridData';
|
||||
|
||||
export default function CustomizedDataGrid() {
|
||||
return (
|
||||
<DataGrid
|
||||
checkboxSelection
|
||||
rows={rows}
|
||||
columns={columns}
|
||||
getRowClassName={(params) =>
|
||||
params.indexRelativeToCurrentPage % 2 === 0 ? 'even' : 'odd'
|
||||
}
|
||||
initialState={{
|
||||
pagination: { paginationModel: { pageSize: 20 } },
|
||||
}}
|
||||
pageSizeOptions={[10, 20, 50]}
|
||||
disableColumnResize
|
||||
density="compact"
|
||||
slotProps={{
|
||||
filterPanel: {
|
||||
filterFormProps: {
|
||||
logicOperatorInputProps: {
|
||||
variant: 'outlined',
|
||||
size: 'small',
|
||||
},
|
||||
columnInputProps: {
|
||||
variant: 'outlined',
|
||||
size: 'small',
|
||||
sx: { mt: 'auto' },
|
||||
},
|
||||
operatorInputProps: {
|
||||
variant: 'outlined',
|
||||
size: 'small',
|
||||
sx: { mt: 'auto' },
|
||||
},
|
||||
valueInputProps: {
|
||||
InputComponentProps: {
|
||||
variant: 'outlined',
|
||||
size: 'small',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
|
@ -4,9 +4,10 @@
|
|||
import Typography from '@mui/material/Typography';
|
||||
import Stack from '@mui/material/Stack';
|
||||
import { LineChart } from '@mui/x-charts/LineChart';
|
||||
import CountUp from 'react-countup';
|
||||
|
||||
import type { DataPoint } from '../../assets/helpers.jsx';
|
||||
import {CircularProgress} from "@mui/material";
|
||||
import {StandardCounter, ByteCounter} from "./Counter.tsx";
|
||||
import {useRef} from "react";
|
||||
|
||||
const currentTimeZone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
|
@ -116,7 +117,7 @@
|
|||
</LineChart>
|
||||
}
|
||||
|
||||
export default function SessionsChart(props: any) {
|
||||
export default function DashboardChart(props: any) {
|
||||
|
||||
const lastCalc = useRef<number>(0);
|
||||
|
||||
|
@ -153,13 +154,9 @@
|
|||
<Typography variant="h4" component="p">
|
||||
|
||||
{ props.data.length < 2 ? "" :
|
||||
<CountUp
|
||||
start={props.data[props.data.length - 2].value}
|
||||
end={props.data[props.data.length - 1].value}
|
||||
duration={5}
|
||||
decimals={props.decimals}
|
||||
|
||||
/>} {props.label}
|
||||
props.isBytes ? <ByteCounter start={props.data[props.data.length - 2].value} end={props.data[props.data.length - 1].value}/> :
|
||||
<StandardCounter start={props.data[props.data.length - 2].value} end={props.data[props.data.length - 1].value} />
|
||||
} {props.label}
|
||||
</Typography>
|
||||
<p>
|
||||
{ calc_per_second(props.data[props.data.length-1], props.data[props.data.length-2]) }
|
|
@ -0,0 +1,105 @@
|
|||
import Grid from "@mui/material/Grid";
|
||||
import Typography from "@mui/material/Typography";
|
||||
import {FontAwesomeIcon} from "@fortawesome/react-fontawesome";
|
||||
import {faDiscord, faDocker, faGithub, faPython} from "@fortawesome/free-brands-svg-icons";
|
||||
import rtdIcon from "../../assets/readthedocs.svg";
|
||||
import {Paper, Table, TableBody, TableCell, TableContainer, TableHead, TableRow} from "@mui/material";
|
||||
|
||||
|
||||
export default function DescriptionPanel() {
|
||||
return <>
|
||||
<Grid size={{xs: 10, md: 5}}>
|
||||
<Typography component="h2" variant="h6" sx={{mb: 2}}>
|
||||
Overview
|
||||
</Typography>
|
||||
<div>
|
||||
<p style={{textAlign: 'left'}}>This is <b>test.amqtt.io</b>.</p>
|
||||
<p style={{textAlign: 'left'}}>It hosts a publicly available aMQTT server/broker.</p>
|
||||
<p style={{textAlign: 'left'}}><a href="http://www.mqtt.org">MQTT</a> is a very lightweight
|
||||
protocol that uses a publish/subscribe model. This makes it suitable for "machine to machine"
|
||||
messaging such as with low power sensors or mobile devices.
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>For more information: </p>
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td style={{width: 250}}>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faGithub} size="xl"/> github: <a
|
||||
href="https://github.com/Yakifo/amqtt">Yakifo/amqtt</a>
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faPython} size="xl"/> PyPi: <a
|
||||
href="https://pypi.org/project/amqtt/">aMQTT</a>
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faDiscord} size="xl"/> Discord: <a
|
||||
href="https://discord.gg/S3sP6dDaF3">aMQTT</a>
|
||||
</p>
|
||||
</td>
|
||||
<td>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<img
|
||||
src={rtdIcon}
|
||||
style={{width: 20, verticalAlign: -4}}
|
||||
alt="website logo"
|
||||
/>
|
||||
ReadTheDocs: <a href="https://amqtt.readthedocs.io/">aMQTT</a>
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faDocker} size="xl"/> DockerHub: <a
|
||||
href="https://hub.docker.com/repositories/amqtt">aMQTT</a>
|
||||
</p>
|
||||
<p> </p>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
|
||||
</div>
|
||||
</Grid>
|
||||
<Grid size={{xs: 1, md: 1}}></Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<Typography component="h2" variant="h6" sx={{mb: 2}}>
|
||||
Access
|
||||
</Typography>
|
||||
<TableContainer component={Paper}>
|
||||
<Table sx={{maxWidth: 400}} size="small">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell>Host</TableCell>
|
||||
<TableCell>test.amqtt.io</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
<TableCell>TCP</TableCell>
|
||||
<TableCell>1883</TableCell>
|
||||
</TableRow>
|
||||
<TableRow>
|
||||
<TableCell>TLS TCP</TableCell>
|
||||
<TableCell>8883</TableCell>
|
||||
</TableRow>
|
||||
<TableRow>
|
||||
<TableCell>Websocket</TableCell>
|
||||
<TableCell>8080</TableCell>
|
||||
</TableRow>
|
||||
<TableRow>
|
||||
<TableCell>SSL Websocket</TableCell>
|
||||
<TableCell>8443</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
The purpose of this free MQTT broker at <strong>test.amqtt.io</strong> is to learn about and test the MQTT
|
||||
protocol. It
|
||||
should not be used in production, development, staging or uat environments. Do not to use it to send any
|
||||
sensitive information or personal data into the system as all topics are public. Any illegal use of this
|
||||
MQTT broker is strictly forbidden. By using this MQTT broker located at <strong>test.amqtt.io</strong> you
|
||||
warrant that you are neither a sanctioned person nor located in a country that is subject to sanctions.
|
||||
</p>
|
||||
</Grid>
|
||||
</>
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
import Stack from '@mui/material/Stack';
|
||||
import NotificationsRoundedIcon from '@mui/icons-material/NotificationsRounded';
|
||||
import NavbarBreadcrumbs from './NavbarBreadcrumbs';
|
||||
import MenuButton from './MenuButton';
|
||||
import ColorModeIconDropdown from '../../shared-theme/ColorModeIconDropdown';
|
||||
|
||||
import Search from './Search';
|
||||
|
||||
export default function Header() {
|
||||
return (
|
||||
<Stack
|
||||
direction="row"
|
||||
sx={{
|
||||
display: { xs: 'none', md: 'flex' },
|
||||
width: '100%',
|
||||
alignItems: { xs: 'flex-start', md: 'center' },
|
||||
justifyContent: 'space-between',
|
||||
maxWidth: { sm: '100%', md: '1700px' },
|
||||
pt: 1.5,
|
||||
}}
|
||||
spacing={2}
|
||||
>
|
||||
<NavbarBreadcrumbs />
|
||||
<Stack direction="row" sx={{ gap: 1 }}>
|
||||
<Search />
|
||||
<MenuButton showBadge aria-label="Open notifications">
|
||||
<NotificationsRoundedIcon />
|
||||
</MenuButton>
|
||||
<ColorModeIconDropdown />
|
||||
</Stack>
|
||||
</Stack>
|
||||
);
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
import Card from '@mui/material/Card';
|
||||
import CardContent from '@mui/material/CardContent';
|
||||
import Button from '@mui/material/Button';
|
||||
import Typography from '@mui/material/Typography';
|
||||
import ChevronRightRoundedIcon from '@mui/icons-material/ChevronRightRounded';
|
||||
import InsightsRoundedIcon from '@mui/icons-material/InsightsRounded';
|
||||
import useMediaQuery from '@mui/material/useMediaQuery';
|
||||
import { useTheme } from '@mui/material/styles';
|
||||
|
||||
export default function HighlightedCard() {
|
||||
const theme = useTheme();
|
||||
const isSmallScreen = useMediaQuery(theme.breakpoints.down('sm'));
|
||||
|
||||
return (
|
||||
<Card sx={{ height: '100%' }}>
|
||||
<CardContent>
|
||||
<InsightsRoundedIcon />
|
||||
<Typography
|
||||
component="h2"
|
||||
variant="subtitle2"
|
||||
gutterBottom
|
||||
sx={{ fontWeight: '600' }}
|
||||
>
|
||||
Explore your data
|
||||
</Typography>
|
||||
<Typography sx={{ color: 'text.secondary', mb: '8px' }}>
|
||||
Uncover performance and visitor insights with our data wizardry.
|
||||
</Typography>
|
||||
<Button
|
||||
variant="contained"
|
||||
size="small"
|
||||
color="primary"
|
||||
endIcon={<ChevronRightRoundedIcon />}
|
||||
fullWidth={isSmallScreen}
|
||||
>
|
||||
Get insights
|
||||
</Button>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
|
@ -1,18 +1,13 @@
|
|||
import Grid from '@mui/material/Grid';
|
||||
import Box from '@mui/material/Box';
|
||||
import Stack from '@mui/material/Stack';
|
||||
import Typography from '@mui/material/Typography';
|
||||
import Copyright from '../internals/components/Copyright';
|
||||
import SessionsChart from './SessionsChart';
|
||||
import Copyright from '../Copyright';
|
||||
import DashboardChart from './DashboardChart.tsx';
|
||||
import {useEffect, useState} from "react";
|
||||
// @ts-ignore
|
||||
import useMqtt from '../../assets/usemqtt';
|
||||
import type {DataPoint, TopicMap} from '../../assets/helpers';
|
||||
import {Paper, Table, TableBody, TableCell, TableContainer, TableHead, TableRow} from "@mui/material";
|
||||
import {FontAwesomeIcon} from '@fortawesome/react-fontawesome'
|
||||
import {faGithub, faPython, faDocker, faDiscord} from "@fortawesome/free-brands-svg-icons";
|
||||
|
||||
import rtdIcon from "../../assets/readthedocs.svg";
|
||||
import {type DataPoint, getMQTTSettings, secondsToDhms, type TopicMap} from '../../assets/helpers';
|
||||
import DescriptionPanel from "./DescriptionPanel";
|
||||
|
||||
export default function MainGrid() {
|
||||
|
||||
|
@ -27,54 +22,9 @@ export default function MainGrid() {
|
|||
const [memSize, setMemSize] = useState<DataPoint[]>([]);
|
||||
const [version, setVersion] = useState<string>('');
|
||||
|
||||
function getRandomInt(min: number, max: number) {
|
||||
min = Math.ceil(min);
|
||||
max = Math.floor(max);
|
||||
return Math.floor(Math.random() * (max - min + 1)) + min;
|
||||
}
|
||||
const {mqttSubscribe, isConnected, messageQueue, messageTick} = useMqtt(getMQTTSettings());
|
||||
|
||||
function secondsToDhms(seconds: number) {
|
||||
const days = Math.floor(seconds / (24 * 3600));
|
||||
seconds %= (24 * 3600);
|
||||
const hours = Math.floor(seconds / 3600);
|
||||
seconds %= 3600;
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
seconds = seconds % 60;
|
||||
|
||||
return {
|
||||
days: days,
|
||||
hours: hours,
|
||||
minutes: minutes,
|
||||
seconds: seconds,
|
||||
};
|
||||
}
|
||||
|
||||
const mqtt_settings = {
|
||||
url: import.meta.env.VITE_MQTT_WS_TYPE + '://' + import.meta.env.VITE_MQTT_WS_HOST + ':' + import.meta.env.VITE_MQTT_WS_PORT, client_id: `web-client-${getRandomInt(1, 100)}`,
|
||||
clean: true,
|
||||
protocol: 'wss',
|
||||
protocolVersion: 4, // MQTT 3.1.1
|
||||
wsOptions: {
|
||||
protocol: 'mqtt'
|
||||
}
|
||||
};
|
||||
|
||||
const {mqttSubscribe, isConnected, messageQueue, messageTick} = useMqtt(mqtt_settings);
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected) {
|
||||
mqttSubscribe('$SYS/broker/version');
|
||||
mqttSubscribe('$SYS/broker/messages/publish/#');
|
||||
mqttSubscribe('$SYS/broker/load/bytes/#');
|
||||
mqttSubscribe('$SYS/broker/uptime/formatted');
|
||||
mqttSubscribe('$SYS/broker/uptime');
|
||||
mqttSubscribe('$SYS/broker/clients/connected');
|
||||
mqttSubscribe('$SYS/broker/cpu/percent');
|
||||
mqttSubscribe('$SYS/broker/heap/size')
|
||||
}
|
||||
}, [isConnected, mqttSubscribe]);
|
||||
|
||||
const topic_map: TopicMap = {
|
||||
const topicMap: TopicMap = {
|
||||
'$SYS/broker/messages/publish/sent': {current: sent, update: setSent},
|
||||
'$SYS/broker/messages/publish/received': {current: received, update: setReceived},
|
||||
'$SYS/broker/load/bytes/received': {current: bytesIn, update: setBytesIn},
|
||||
|
@ -84,6 +34,17 @@ export default function MainGrid() {
|
|||
'$SYS/broker/heap/size': {current: memSize, update: setMemSize},
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected) {
|
||||
for(const topic in topicMap) {
|
||||
mqttSubscribe(topic);
|
||||
}
|
||||
mqttSubscribe('$SYS/broker/version');
|
||||
mqttSubscribe('$SYS/broker/uptime/formatted');
|
||||
mqttSubscribe('$SYS/broker/uptime');
|
||||
}
|
||||
}, [isConnected, mqttSubscribe]);
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
while (messageQueue.current.length > 0) {
|
||||
|
@ -92,8 +53,8 @@ export default function MainGrid() {
|
|||
|
||||
const d = payload.message;
|
||||
|
||||
if(payload.topic in topic_map) {
|
||||
const { update } = topic_map[payload.topic];
|
||||
if(payload.topic in topicMap) {
|
||||
const { update } = topicMap[payload.topic];
|
||||
const newPoint: DataPoint = {
|
||||
time: new Date().toISOString(),
|
||||
timestamp: Date.now(),
|
||||
|
@ -115,6 +76,13 @@ export default function MainGrid() {
|
|||
}
|
||||
}, [messageTick, messageQueue]);
|
||||
|
||||
const upTime = () => {
|
||||
return isConnected && serverUptime ? <>
|
||||
<strong>aMQTT broker</strong> {version.replace('aMQTT version ', 'v')} <strong>started at </strong> {serverStart}
|
||||
<strong>up for</strong> {serverUptime}
|
||||
</> : <></>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box sx={{width: '100%', maxWidth: {sm: '100%', md: '1700px'}}}>
|
||||
{/* cards */}
|
||||
|
@ -125,131 +93,39 @@ export default function MainGrid() {
|
|||
columns={12}
|
||||
sx={{mb: (theme) => theme.spacing(2)}}
|
||||
>
|
||||
<Grid size={{xs: 10, md: 5}}>
|
||||
<Typography component="h2" variant="h6" sx={{mb: 2}}>
|
||||
Overview
|
||||
</Typography>
|
||||
<div>
|
||||
<p style={{textAlign: 'left'}}>This is <b>test.amqtt.io</b>.</p>
|
||||
<p style={{textAlign: 'left'}}>It hosts a publicly available aMQTT server/broker.</p>
|
||||
<p style={{textAlign: 'left'}}><a href="http://www.mqtt.org">MQTT</a> is a very lightweight
|
||||
protocol that uses a publish/subscribe model. This makes it suitable for "machine to machine"
|
||||
messaging such as with low power sensors or mobile devices.
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>For more information: </p>
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td style={{width: 250}}>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faGithub} size="xl"/> github: <a
|
||||
href="https://github.com/Yakifo/amqtt">Yakifo/amqtt</a>
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faPython} size="xl"/> PyPi: <a
|
||||
href="https://pypi.org/project/amqtt/">aMQTT</a>
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faDiscord} size="xl"/> Discord: <a
|
||||
href="https://discord.gg/S3sP6dDaF3">aMQTT</a>
|
||||
</p>
|
||||
</td>
|
||||
<td>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<img
|
||||
src={rtdIcon}
|
||||
style={{width: 20, verticalAlign: -4}}
|
||||
alt="website logo"
|
||||
/>
|
||||
ReadTheDocs: <a href="https://amqtt.readthedocs.io/">aMQTT</a>
|
||||
</p>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
<FontAwesomeIcon icon={faDocker} size="xl"/> DockerHub: <a
|
||||
href="https://hub.docker.com/repositories/amqtt">aMQTT</a>
|
||||
</p>
|
||||
<p> </p>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
|
||||
</div>
|
||||
</Grid>
|
||||
<Grid size={{xs: 1, md: 1}}></Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<Typography component="h2" variant="h6" sx={{mb: 2}}>
|
||||
Access
|
||||
</Typography>
|
||||
<TableContainer component={Paper}>
|
||||
<Table sx={{maxWidth: 400}} size="small">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell>Host</TableCell>
|
||||
<TableCell>test.amqtt.io</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
<TableCell>TCP</TableCell>
|
||||
<TableCell>1883</TableCell>
|
||||
</TableRow>
|
||||
<TableRow>
|
||||
<TableCell>TLS TCP</TableCell>
|
||||
<TableCell>8883</TableCell>
|
||||
</TableRow>
|
||||
<TableRow>
|
||||
<TableCell>Websocket</TableCell>
|
||||
<TableCell>8080</TableCell>
|
||||
</TableRow>
|
||||
<TableRow>
|
||||
<TableCell>SSL Websocket</TableCell>
|
||||
<TableCell>8443</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
<p style={{textAlign: 'left'}}>
|
||||
The purpose of this free MQTT broker at <strong>test.amqtt.io</strong> is to learn about and test the MQTT
|
||||
protocol. It
|
||||
should not be used in production, development, staging or uat environments. Do not to use it to send any
|
||||
sensitive information or personal data into the system as all topics are public. Any illegal use of this
|
||||
MQTT broker is strictly forbidden. By using this MQTT broker located at <strong>test.amqtt.io</strong> you
|
||||
warrant that you are neither a sanctioned person nor located in a country that is subject to sanctions.
|
||||
</p>
|
||||
</Grid>
|
||||
<DescriptionPanel/>
|
||||
</Grid>
|
||||
<Grid
|
||||
container
|
||||
spacing={2}
|
||||
columns={12}
|
||||
sx={{mb: (theme) => theme.spacing(2)}}
|
||||
><Grid size={{xs: 12, md: 12}}>
|
||||
<strong>broker</strong> ('{version}') <strong>started at </strong> {serverStart}
|
||||
<strong>up for</strong> {serverUptime}
|
||||
>
|
||||
<Grid size={{xs: 12, md: 12}}>
|
||||
{upTime()}
|
||||
</Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<SessionsChart title={'Sent Messages'} label={''} data={sent} isConnected={isConnected} isPerSecond/>
|
||||
<DashboardChart title={'Sent Messages'} label={''} data={sent} isConnected={isConnected} isPerSecond/>
|
||||
</Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<SessionsChart title={'Received Messages'} label={''} data={received} isConnected={isConnected} isPerSecond/>
|
||||
<DashboardChart title={'Received Messages'} label={''} data={received} isConnected={isConnected} isPerSecond/>
|
||||
</Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<SessionsChart title={'Bytes Out'} label={'Bytes'} data={bytesOut} isConnected={isConnected}/>
|
||||
<DashboardChart title={'Bytes Out'} label={''} data={bytesOut} isConnected={isConnected} isBytes/>
|
||||
</Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<SessionsChart title={'Bytes In'} label={'Bytes'} data={bytesIn} isConnected={isConnected}/>
|
||||
<DashboardChart title={'Bytes In'} label={''} data={bytesIn} isConnected={isConnected} isBytes/>
|
||||
</Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<SessionsChart title={'Clients Connected'} label={''} data={clientsConnected} isConnected={isConnected}/>
|
||||
<DashboardChart title={'Clients Connected'} label={''} data={clientsConnected} isConnected={isConnected}/>
|
||||
</Grid>
|
||||
<Grid size={{xs: 12, md: 6}}>
|
||||
<Grid container spacing={2} columns={2}>
|
||||
<Grid size={{lg:1}}>
|
||||
<SessionsChart title={'CPU'} label={'%'} data={cpuPercent} decimals={2} isConnected={isConnected}/>
|
||||
<DashboardChart title={'CPU'} label={'%'} data={cpuPercent} decimals={2} isConnected={isConnected}/>
|
||||
</Grid>
|
||||
<Grid size={{lg:1}}>
|
||||
<SessionsChart title={'Memory'} label={'MB'} data={memSize} decimals={1} isConnected={isConnected}/>
|
||||
<DashboardChart title={'Memory'} label={'MB'} data={memSize} decimals={1} isConnected={isConnected}/>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</Grid>
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
import Badge, { badgeClasses } from '@mui/material/Badge';
|
||||
import IconButton from '@mui/material/IconButton';
|
||||
import type { IconButtonProps } from '@mui/material/IconButton';
|
||||
|
||||
|
||||
export interface MenuButtonProps extends IconButtonProps {
|
||||
showBadge?: boolean;
|
||||
}
|
||||
|
||||
export default function MenuButton({
|
||||
showBadge = false,
|
||||
...props
|
||||
}) {
|
||||
return (
|
||||
<Badge
|
||||
color="error"
|
||||
variant="dot"
|
||||
invisible={!showBadge}
|
||||
sx={{ [`& .${badgeClasses.badge}`]: { right: 2, top: 2 } }}
|
||||
>
|
||||
<IconButton size="small" {...props} />
|
||||
</Badge>
|
||||
);
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
import List from '@mui/material/List';
|
||||
import ListItem from '@mui/material/ListItem';
|
||||
import ListItemButton from '@mui/material/ListItemButton';
|
||||
import ListItemIcon from '@mui/material/ListItemIcon';
|
||||
import ListItemText from '@mui/material/ListItemText';
|
||||
import Stack from '@mui/material/Stack';
|
||||
import HomeRoundedIcon from '@mui/icons-material/HomeRounded';
|
||||
import AnalyticsRoundedIcon from '@mui/icons-material/AnalyticsRounded';
|
||||
import PeopleRoundedIcon from '@mui/icons-material/PeopleRounded';
|
||||
import AssignmentRoundedIcon from '@mui/icons-material/AssignmentRounded';
|
||||
import SettingsRoundedIcon from '@mui/icons-material/SettingsRounded';
|
||||
import InfoRoundedIcon from '@mui/icons-material/InfoRounded';
|
||||
import HelpRoundedIcon from '@mui/icons-material/HelpRounded';
|
||||
|
||||
const mainListItems = [
|
||||
{ text: 'Home', icon: <HomeRoundedIcon /> },
|
||||
{ text: 'Analytics', icon: <AnalyticsRoundedIcon /> },
|
||||
{ text: 'Clients', icon: <PeopleRoundedIcon /> },
|
||||
{ text: 'Tasks', icon: <AssignmentRoundedIcon /> },
|
||||
];
|
||||
|
||||
const secondaryListItems = [
|
||||
{ text: 'Settings', icon: <SettingsRoundedIcon /> },
|
||||
{ text: 'About', icon: <InfoRoundedIcon /> },
|
||||
{ text: 'Feedback', icon: <HelpRoundedIcon /> },
|
||||
];
|
||||
|
||||
export default function MenuContent() {
|
||||
return (
|
||||
<Stack sx={{ flexGrow: 1, p: 1, justifyContent: 'space-between' }}>
|
||||
<List dense>
|
||||
{mainListItems.map((item, index) => (
|
||||
<ListItem key={index} disablePadding sx={{ display: 'block' }}>
|
||||
<ListItemButton selected={index === 0}>
|
||||
<ListItemIcon>{item.icon}</ListItemIcon>
|
||||
<ListItemText primary={item.text} />
|
||||
</ListItemButton>
|
||||
</ListItem>
|
||||
))}
|
||||
</List>
|
||||
<List dense>
|
||||
{secondaryListItems.map((item, index) => (
|
||||
<ListItem key={index} disablePadding sx={{ display: 'block' }}>
|
||||
<ListItemButton>
|
||||
<ListItemIcon>{item.icon}</ListItemIcon>
|
||||
<ListItemText primary={item.text} />
|
||||
</ListItemButton>
|
||||
</ListItem>
|
||||
))}
|
||||
</List>
|
||||
</Stack>
|
||||
);
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
import { styled } from '@mui/material/styles';
|
||||
import Typography from '@mui/material/Typography';
|
||||
import Breadcrumbs, { breadcrumbsClasses } from '@mui/material/Breadcrumbs';
|
||||
import NavigateNextRoundedIcon from '@mui/icons-material/NavigateNextRounded';
|
||||
|
||||
const StyledBreadcrumbs = styled(Breadcrumbs)(({ theme }) => ({
|
||||
margin: theme.spacing(1, 0),
|
||||
[`& .${breadcrumbsClasses.separator}`]: {
|
||||
color: (theme.vars || theme).palette.action.disabled,
|
||||
margin: 1,
|
||||
},
|
||||
[`& .${breadcrumbsClasses.ol}`]: {
|
||||
alignItems: 'center',
|
||||
},
|
||||
}));
|
||||
|
||||
export default function NavbarBreadcrumbs() {
|
||||
return (
|
||||
<StyledBreadcrumbs
|
||||
aria-label="breadcrumb"
|
||||
separator={<NavigateNextRoundedIcon fontSize="small" />}
|
||||
>
|
||||
<Typography variant="body1">Dashboard</Typography>
|
||||
<Typography variant="body1" sx={{ color: 'text.primary', fontWeight: 600 }}>
|
||||
Home
|
||||
</Typography>
|
||||
</StyledBreadcrumbs>
|
||||
);
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
import * as React from 'react';
|
||||
import { styled } from '@mui/material/styles';
|
||||
import Divider, { dividerClasses } from '@mui/material/Divider';
|
||||
import Menu from '@mui/material/Menu';
|
||||
import MuiMenuItem from '@mui/material/MenuItem';
|
||||
import { paperClasses } from '@mui/material/Paper';
|
||||
import { listClasses } from '@mui/material/List';
|
||||
import ListItemText from '@mui/material/ListItemText';
|
||||
import ListItemIcon, { listItemIconClasses } from '@mui/material/ListItemIcon';
|
||||
import LogoutRoundedIcon from '@mui/icons-material/LogoutRounded';
|
||||
import MoreVertRoundedIcon from '@mui/icons-material/MoreVertRounded';
|
||||
import MenuButton from './MenuButton';
|
||||
|
||||
const MenuItem = styled(MuiMenuItem)({
|
||||
margin: '2px 0',
|
||||
});
|
||||
|
||||
export default function OptionsMenu() {
|
||||
const [anchorEl, setAnchorEl] = React.useState<null | HTMLElement>(null);
|
||||
const open = Boolean(anchorEl);
|
||||
const handleClick = (event: React.MouseEvent<HTMLElement>) => {
|
||||
setAnchorEl(event.currentTarget);
|
||||
};
|
||||
const handleClose = () => {
|
||||
setAnchorEl(null);
|
||||
};
|
||||
return (
|
||||
<React.Fragment>
|
||||
<MenuButton
|
||||
aria-label="Open menu"
|
||||
onClick={handleClick}
|
||||
sx={{ borderColor: 'transparent' }}
|
||||
>
|
||||
<MoreVertRoundedIcon />
|
||||
</MenuButton>
|
||||
<Menu
|
||||
anchorEl={anchorEl}
|
||||
id="menu"
|
||||
open={open}
|
||||
onClose={handleClose}
|
||||
onClick={handleClose}
|
||||
transformOrigin={{ horizontal: 'right', vertical: 'top' }}
|
||||
anchorOrigin={{ horizontal: 'right', vertical: 'bottom' }}
|
||||
sx={{
|
||||
[`& .${listClasses.root}`]: {
|
||||
padding: '4px',
|
||||
},
|
||||
[`& .${paperClasses.root}`]: {
|
||||
padding: 0,
|
||||
},
|
||||
[`& .${dividerClasses.root}`]: {
|
||||
margin: '4px -4px',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<MenuItem onClick={handleClose}>Profile</MenuItem>
|
||||
<MenuItem onClick={handleClose}>My account</MenuItem>
|
||||
<Divider />
|
||||
<MenuItem onClick={handleClose}>Add another account</MenuItem>
|
||||
<MenuItem onClick={handleClose}>Settings</MenuItem>
|
||||
<Divider />
|
||||
<MenuItem
|
||||
onClick={handleClose}
|
||||
sx={{
|
||||
[`& .${listItemIconClasses.root}`]: {
|
||||
ml: 'auto',
|
||||
minWidth: 0,
|
||||
},
|
||||
}}
|
||||
>
|
||||
<ListItemText>Logout</ListItemText>
|
||||
<ListItemIcon>
|
||||
<LogoutRoundedIcon fontSize="small" />
|
||||
</ListItemIcon>
|
||||
</MenuItem>
|
||||
</Menu>
|
||||
</React.Fragment>
|
||||
);
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
import Card from '@mui/material/Card';
|
||||
import CardContent from '@mui/material/CardContent';
|
||||
import Chip from '@mui/material/Chip';
|
||||
import Typography from '@mui/material/Typography';
|
||||
import Stack from '@mui/material/Stack';
|
||||
import { BarChart } from '@mui/x-charts/BarChart';
|
||||
import { useTheme } from '@mui/material/styles';
|
||||
|
||||
export default function PageViewsBarChart() {
|
||||
const theme = useTheme();
|
||||
const colorPalette = [
|
||||
(theme.vars || theme).palette.primary.dark,
|
||||
(theme.vars || theme).palette.primary.main,
|
||||
(theme.vars || theme).palette.primary.light,
|
||||
];
|
||||
return (
|
||||
<Card variant="outlined" sx={{ width: '100%' }}>
|
||||
<CardContent>
|
||||
<Typography component="h2" variant="subtitle2" gutterBottom>
|
||||
Page views and downloads
|
||||
</Typography>
|
||||
<Stack sx={{ justifyContent: 'space-between' }}>
|
||||
<Stack
|
||||
direction="row"
|
||||
sx={{
|
||||
alignContent: { xs: 'center', sm: 'flex-start' },
|
||||
alignItems: 'center',
|
||||
gap: 1,
|
||||
}}
|
||||
>
|
||||
<Typography variant="h4" component="p">
|
||||
1.3M
|
||||
</Typography>
|
||||
<Chip size="small" color="error" label="-8%" />
|
||||
</Stack>
|
||||
<Typography variant="caption" sx={{ color: 'text.secondary' }}>
|
||||
Page views and downloads for the last 6 months
|
||||
</Typography>
|
||||
</Stack>
|
||||
<BarChart
|
||||
borderRadius={8}
|
||||
colors={colorPalette}
|
||||
xAxis={
|
||||
[
|
||||
{
|
||||
scaleType: 'band',
|
||||
categoryGapRatio: 0.5,
|
||||
data: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul'],
|
||||
},
|
||||
] as any
|
||||
}
|
||||
series={[
|
||||
{
|
||||
id: 'page-views',
|
||||
label: 'Page views',
|
||||
data: [2234, 3872, 2998, 4125, 3357, 2789, 2998],
|
||||
stack: 'A',
|
||||
},
|
||||
{
|
||||
id: 'downloads',
|
||||
label: 'Downloads',
|
||||
data: [3098, 4215, 2384, 2101, 4752, 3593, 2384],
|
||||
stack: 'A',
|
||||
},
|
||||
{
|
||||
id: 'conversions',
|
||||
label: 'Conversions',
|
||||
data: [4051, 2275, 3129, 4693, 3904, 2038, 2275],
|
||||
stack: 'A',
|
||||
},
|
||||
]}
|
||||
height={250}
|
||||
margin={{ left: 50, right: 0, top: 20, bottom: 20 }}
|
||||
grid={{ horizontal: true }}
|
||||
hideLegend
|
||||
slotProps={{
|
||||
legend: {
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
import FormControl from '@mui/material/FormControl';
|
||||
import InputAdornment from '@mui/material/InputAdornment';
|
||||
import OutlinedInput from '@mui/material/OutlinedInput';
|
||||
import SearchRoundedIcon from '@mui/icons-material/SearchRounded';
|
||||
|
||||
export default function Search() {
|
||||
return (
|
||||
<FormControl sx={{ width: { xs: '100%', md: '25ch' } }} variant="outlined">
|
||||
<OutlinedInput
|
||||
size="small"
|
||||
id="search"
|
||||
placeholder="Search…"
|
||||
sx={{ flexGrow: 1 }}
|
||||
startAdornment={
|
||||
<InputAdornment position="start" sx={{ color: 'text.primary' }}>
|
||||
<SearchRoundedIcon fontSize="small" />
|
||||
</InputAdornment>
|
||||
}
|
||||
inputProps={{
|
||||
'aria-label': 'search',
|
||||
}}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Ładowanie…
Reference in New Issue