replacing amqtt publisher use of 'docopts' with 'typer'

pull/168/head
Andrew Mirsky 2025-05-18 11:12:14 -04:00
rodzic d39bd2ec5c
commit 7255561b09
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
5 zmienionych plików z 161 dodań i 97 usunięć

Wyświetl plik

@ -5,11 +5,11 @@ try:
from datetime import UTC, datetime from datetime import UTC, datetime
except ImportError: except ImportError:
from datetime import datetime, timezone from datetime import datetime, timezone
UTC = timezone.utc UTC = timezone.utc
from struct import unpack from struct import unpack
from typing import Generic from typing import Generic, Self, TypeVar
from typing_extensions import Self, TypeVar
from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise

Wyświetl plik

@ -1,20 +1,24 @@
import asyncio import asyncio
from typing import SupportsIndex, SupportsInt # pylint: disable=C0412 from typing import SupportsIndex, SupportsInt # pylint: disable=C0412
from collections import deque
try: try:
from collections import deque
from collections.abc import Buffer from collections.abc import Buffer
except ImportError: except ImportError:
from collections import deque
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
@runtime_checkable @runtime_checkable
class Buffer(Protocol): # type: ignore[no-redef] class Buffer(Protocol): # type: ignore[no-redef]
def __buffer__(self, flags: int = ...) -> memoryview: def __buffer__(self, flags: int = ...) -> memoryview:
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12.""" """Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
try: try:
from datetime import UTC, datetime from datetime import UTC, datetime
except ImportError: except ImportError:
from datetime import datetime, timezone from datetime import datetime, timezone
UTC = timezone.utc UTC = timezone.utc

Wyświetl plik

@ -1,38 +1,7 @@
"""amqtt_pub - MQTT 3.1.1 publisher.
Usage:
amqtt_pub --version
amqtt_pub (-h | --help)
amqtt_pub --url BROKER_URL -t TOPIC (-f FILE | -l | -m MESSAGE | -n | -s) [-c CONFIG_FILE] [-i CLIENT_ID] [-q | --qos QOS] [-d] [-k KEEP_ALIVE] [--clean-session] [--ca-file CAFILE] [--ca-path CAPATH] [--ca-data CADATA] [ --will-topic WILL_TOPIC [--will-message WILL_MESSAGE] [--will-qos WILL_QOS] [--will-retain] ] [--extra-headers HEADER] [-r]
Options:
-h --help Show this screen.
--version Show version.
--url BROKER_URL Broker connection URL (must conform to MQTT URI scheme (see https://github.com/mqtt/mqtt.github.io/wiki/URI-Scheme>)
-c CONFIG_FILE Broker configuration file (YAML format)
-i CLIENT_ID Id to use as client ID.
-q | --qos QOS Quality of service to use for the message, from 0, 1, and 2. Defaults to 0.
-r Set retain flag on connect
-t TOPIC Message topic
-m MESSAGE Message data to send
-f FILE Read file by line and publish message for each line
-s Read from stdin and publish message for each line
-k KEEP_ALIVE Keep alive timeout in seconds
--clean-session Clean session on connect (defaults to False)
--ca-file CAFILE CA file
--ca-path CAPATH CA Path
--ca-data CADATA CA data
--will-topic WILL_TOPIC
--will-message WILL_MESSAGE
--will-qos WILL_QOS
--will-retain
--extra-headers EXTRA_HEADERS JSON object with key-value pairs of additional headers for websocket connections
-d Enable debug messages
""" # noqa: E501
import asyncio import asyncio
from collections.abc import Generator from collections.abc import Generator
import contextlib import contextlib
from dataclasses import dataclass
import json import json
import logging import logging
import os import os
@ -41,11 +10,11 @@ import socket
import sys import sys
from typing import Any from typing import Any
from docopt import docopt import typer
import amqtt from amqtt import __version__ as amqtt_version
from amqtt.client import MQTTClient from amqtt.client import MQTTClient
from amqtt.errors import ConnectError from amqtt.errors import ClientError, ConnectError
from amqtt.utils import read_yaml_config from amqtt.utils import read_yaml_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,64 +26,79 @@ def _gen_client_id() -> str:
return f"amqtt_pub/{pid}-{hostname}" return f"amqtt_pub/{pid}-{hostname}"
def _get_qos(arguments: dict[str, Any]) -> int | None: def _get_extra_headers(extra_headers_json: str | None = None) -> dict[str, Any]:
try: try:
return int(arguments["--qos"][0]) extra_headers: dict[str, Any] = json.loads(extra_headers_json or "{}")
except (ValueError, IndexError):
return None
def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]:
try:
extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"])
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return {} return {}
return extra_headers return extra_headers
def _get_message(arguments: dict[str, Any]) -> Generator[bytes | bytearray]: @dataclass
if arguments["-n"]: class MessageInput:
message_str: str | None = None
file: str | None = None
stdin: bool | None = False
lines: bool | None = False
no_message: bool | None = False
def get_message(self) -> Generator[bytes | bytearray]:
if self.no_message:
yield b"" yield b""
if arguments["-m"]: if self.message_str:
yield arguments["-m"].encode(encoding="utf-8") yield self.message_str.encode(encoding="utf-8")
if arguments["-f"]: if self.file:
try: try:
with Path(arguments["-f"]).open(encoding="utf-8") as f: with Path(self.file).open(encoding="utf-8") as f:
for line in f: for line in f:
yield line.encode(encoding="utf-8") yield line.encode(encoding="utf-8")
except Exception: except Exception:
logger.exception(f"Failed to read file '{arguments['-f']}'") logger.exception(f"Failed to read file '{self.file}'")
if arguments["-l"]: if self.lines:
for line in sys.stdin: for line in sys.stdin:
if line: if line:
yield line.encode(encoding="utf-8") yield line.encode(encoding="utf-8")
if arguments["-s"]: if self.stdin:
message = bytearray() messages = bytearray()
for line in sys.stdin: for line in sys.stdin:
message.extend(line.encode(encoding="utf-8")) messages.extend(line.encode(encoding="utf-8"))
yield message yield messages
async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None: @dataclass
"""Perform the publish.""" class CAInfo:
ca_file: str | None = None
ca_path: str | None = None
ca_data: str | None = None
async def do_pub(
client: MQTTClient,
url: str,
topic: str,
message_input: MessageInput,
ca_info: CAInfo,
clean_session: bool = False,
retain: bool = False,
extra_headers_json: str | None = None,
qos: int | None = None,
) -> None:
"""Publish the message."""
running_tasks = [] running_tasks = []
try: try:
logger.info(f"{client.client_id} Connecting to broker") logger.info(f"{client.client_id} Connecting to broker")
await client.connect( await client.connect(
uri=arguments["--url"], uri=url,
cleansession=arguments["--clean-session"], cleansession=clean_session,
cafile=arguments["--ca-file"], cafile=ca_info.ca_file,
capath=arguments["--ca-path"], capath=ca_info.ca_path,
cadata=arguments["--ca-data"], cadata=ca_info.ca_data,
additional_headers=_get_extra_headers(arguments), additional_headers=_get_extra_headers(extra_headers_json),
) )
qos = _get_qos(arguments) for message in message_input.get_message():
topic = arguments["-t"]
retain = arguments["-r"]
for message in _get_message(arguments):
logger.info(f"{client.client_id} Publishing to '{topic}'") logger.info(f"{client.client_id} Publishing to '{topic}'")
task = asyncio.ensure_future(client.publish(topic, message, qos, retain)) task = asyncio.ensure_future(client.publish(topic, message, qos, retain))
running_tasks.append(task) running_tasks.append(task)
@ -128,22 +112,68 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
await client.disconnect() await client.disconnect()
logger.info(f"{client.client_id} Disconnected from broker") logger.info(f"{client.client_id} Disconnected from broker")
except ConnectError as ce: except ConnectError as ce:
logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}") logger.fatal(f"Connection to '{url}' failed: {ce!r}")
except asyncio.CancelledError: except asyncio.CancelledError:
logger.fatal("Publish canceled due to previous error") logger.fatal("Publish canceled due to previous error")
def main() -> None: def main() -> None:
"""Entry point for the amqtt publisher."""
typer.run(publisher)
def _version() -> None:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
def publisher( # pylint: disable=R0914,R0917 # noqa : PLR0913
url: str = typer.Option(
..., "--url", help="Broker connection URL (must conform to MQTT URI scheme: mqtt://<username:password>@HOST:port)"
),
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Broker configuration file (YAML format)"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="Client ID to use for the connection"),
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
retain: bool = typer.Option(False, "-r", help="Set retain flag on connect"),
topic: str = typer.Option(..., "-t", help="Message topic"),
message: str | None = typer.Option(None, "-m", help="Message data to send"),
file: str | None = typer.Option(None, "-f", help="Read file by line and publish each line as a message"),
stdin: bool = typer.Option(False, "-s", help="Read from stdin and publish message for first line"),
lines: bool = typer.Option(False, "-l", help="Read from stdin and publish message for each line"),
no_message: bool = typer.Option(False, "-n", help="Publish an empty message"),
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
clean_session: bool = typer.Option(False, "--clean-session", help="Clean session on connect (defaults to False)"),
ca_file: str | None = typer.Option(None, "--ca-file", help="CA file"),
ca_path: str | None = typer.Option(None, "--ca-path", help="CA path"),
ca_data: str | None = typer.Option(None, "--ca-data", help="CA data"),
will_topic: str | None = typer.Option(None, "--will-topic", help="Last will topic"),
will_message: str | None = typer.Option(None, "--will-message", help="Last will message"),
will_qos: int | None = typer.Option(None, "--will-qos", help="Last will QoS"),
will_retain: bool = typer.Option(False, "--will-retain", help="Set retain flag for last will message"),
extra_headers_json: str | None = typer.Option(
None, "--extra-headers", help="JSON object with key-value headers for websocket connections"
),
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
version: bool | None = typer.Option( # noqa : ARG001
None,
"--version",
callback=_version,
is_eager=True,
help="Show version and exit",
),
) -> None:
"""Run the MQTT publisher.""" """Run the MQTT publisher."""
arguments = docopt(__doc__, version=amqtt.__version__) provided = [bool(message), bool(file), stdin, lines, no_message]
if sum(provided) != 1:
typer.echo("❌ You must provide exactly one of --config, --file, or --stdin.", err=True)
raise typer.Exit(code=1)
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s" formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
level = logging.DEBUG if arguments["-d"] else logging.INFO level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(level=level, format=formatter) logging.basicConfig(level=level, format=formatter)
config = None if config_file:
if arguments["-c"]: config = read_yaml_config(config_file)
config = read_yaml_config(arguments["-c"])
else: else:
default_config_path = Path(__file__).parent / "default_client.yaml" default_config_path = Path(__file__).parent / "default_client.yaml"
logger.debug(f"Using default configuration from {default_config_path}") logger.debug(f"Using default configuration from {default_config_path}")
@ -151,7 +181,6 @@ def main() -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
client_id = arguments.get("-i", None)
if not client_id: if not client_id:
client_id = _gen_client_id() client_id = _gen_client_id()
@ -159,22 +188,49 @@ def main() -> None:
logger.debug("Failed to correctly initialize config") logger.debug("Failed to correctly initialize config")
return return
if arguments["-k"]: if keep_alive:
config["keep_alive"] = int(arguments["-k"]) config["keep_alive"] = int(keep_alive)
if arguments["--will-topic"] and arguments["--will-message"] and arguments["--will-qos"]: if will_topic and will_message and will_qos:
config["will"] = { config["will"] = {
"topic": arguments["--will-topic"], "topic": will_topic,
"message": arguments["--will-message"].encode("utf-8"), "message": will_message.encode("utf-8"),
"qos": int(arguments["--will-qos"]), "qos": int(will_qos),
"retain": arguments["--will-retain"], "retain": will_retain,
} }
client = MQTTClient(client_id=client_id, config=config) client = MQTTClient(client_id=client_id, config=config)
message_input = MessageInput(
message_str=message,
file=file,
stdin=stdin,
no_message=no_message,
lines=lines,
)
ca_info = CAInfo(
ca_file=ca_file,
ca_path=ca_path,
ca_data=ca_data,
)
with contextlib.suppress(KeyboardInterrupt): with contextlib.suppress(KeyboardInterrupt):
loop.run_until_complete(do_pub(client, arguments)) try:
loop.run_until_complete(
do_pub(
client=client,
message_input=message_input,
url=url,
topic=topic,
retain=retain,
clean_session=clean_session,
ca_info=ca_info,
extra_headers_json=extra_headers_json,
qos=qos,
)
)
except (ClientError, ConnectError) as ce:
typer.echo(f"❌ Connection failed: {ce}", err=True)
loop.close() loop.close()
if __name__ == "__main__": if __name__ == "__main__":
main() typer.run(main)

Wyświetl plik

@ -2,6 +2,7 @@ try:
from datetime import UTC, datetime from datetime import UTC, datetime
except ImportError: except ImportError:
from datetime import datetime, timezone from datetime import datetime, timezone
UTC = timezone.utc UTC = timezone.utc
import logging import logging

Wyświetl plik

@ -1,5 +1,6 @@
[build-system] [build-system]
requires = ["hatchling", "hatch-vcs"] requires = ["hatchling", "hatch-vcs", "uv-dynamic-versioning"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[project] [project]
@ -117,11 +118,13 @@ ignore = [
"TD003", # TODO Missing issue link for this TODO. "TD003", # TODO Missing issue link for this TODO.
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed "ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument "ARG002", # Unused method argument
"PERF203" # try-except penalty within loops (3.10 only) "PERF203",# try-except penalty within loops (3.10 only),
"COM812" # rule causes conflicts when used with the formatter
] ]
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"tests/**" = ["ALL"] "tests/**" = ["ALL"]
"amqtt/scripts/pub_script.py" = ["FBT003"]
[tool.ruff.lint.flake8-pytest-style] [tool.ruff.lint.flake8-pytest-style]
fixture-parentheses = false fixture-parentheses = false