replacing amqtt publisher use of 'docopts' with 'typer'

pull/168/head
Andrew Mirsky 2025-05-18 11:12:14 -04:00
rodzic d39bd2ec5c
commit 7255561b09
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
5 zmienionych plików z 161 dodań i 97 usunięć

Wyświetl plik

@ -5,11 +5,11 @@ try:
from datetime import UTC, datetime
except ImportError:
from datetime import datetime, timezone
UTC = timezone.utc
from struct import unpack
from typing import Generic
from typing_extensions import Self, TypeVar
from typing import Generic, Self, TypeVar
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise

Wyświetl plik

@ -1,20 +1,24 @@
import asyncio
from typing import SupportsIndex, SupportsInt # pylint: disable=C0412
from collections import deque
try:
from collections import deque
from collections.abc import Buffer
except ImportError:
from collections import deque
from typing import Protocol, runtime_checkable
@runtime_checkable
class Buffer(Protocol): # type: ignore[no-redef]
def __buffer__(self, flags: int = ...) -> memoryview:
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
try:
from datetime import UTC, datetime
except ImportError:
from datetime import datetime, timezone
UTC = timezone.utc

Wyświetl plik

@ -1,38 +1,7 @@
"""amqtt_pub - MQTT 3.1.1 publisher.
Usage:
amqtt_pub --version
amqtt_pub (-h | --help)
amqtt_pub --url BROKER_URL -t TOPIC (-f FILE | -l | -m MESSAGE | -n | -s) [-c CONFIG_FILE] [-i CLIENT_ID] [-q | --qos QOS] [-d] [-k KEEP_ALIVE] [--clean-session] [--ca-file CAFILE] [--ca-path CAPATH] [--ca-data CADATA] [ --will-topic WILL_TOPIC [--will-message WILL_MESSAGE] [--will-qos WILL_QOS] [--will-retain] ] [--extra-headers HEADER] [-r]
Options:
-h --help Show this screen.
--version Show version.
--url BROKER_URL Broker connection URL (must conform to MQTT URI scheme (see https://github.com/mqtt/mqtt.github.io/wiki/URI-Scheme>)
-c CONFIG_FILE Broker configuration file (YAML format)
-i CLIENT_ID Id to use as client ID.
-q | --qos QOS Quality of service to use for the message, from 0, 1, and 2. Defaults to 0.
-r Set retain flag on connect
-t TOPIC Message topic
-m MESSAGE Message data to send
-f FILE Read file by line and publish message for each line
-s Read from stdin and publish message for each line
-k KEEP_ALIVE Keep alive timeout in seconds
--clean-session Clean session on connect (defaults to False)
--ca-file CAFILE CA file
--ca-path CAPATH CA Path
--ca-data CADATA CA data
--will-topic WILL_TOPIC
--will-message WILL_MESSAGE
--will-qos WILL_QOS
--will-retain
--extra-headers EXTRA_HEADERS JSON object with key-value pairs of additional headers for websocket connections
-d Enable debug messages
""" # noqa: E501
import asyncio
from collections.abc import Generator
import contextlib
from dataclasses import dataclass
import json
import logging
import os
@ -41,11 +10,11 @@ import socket
import sys
from typing import Any
from docopt import docopt
import typer
import amqtt
from amqtt import __version__ as amqtt_version
from amqtt.client import MQTTClient
from amqtt.errors import ConnectError
from amqtt.errors import ClientError, ConnectError
from amqtt.utils import read_yaml_config
logger = logging.getLogger(__name__)
@ -57,64 +26,79 @@ def _gen_client_id() -> str:
return f"amqtt_pub/{pid}-{hostname}"
def _get_qos(arguments: dict[str, Any]) -> int | None:
def _get_extra_headers(extra_headers_json: str | None = None) -> dict[str, Any]:
try:
return int(arguments["--qos"][0])
except (ValueError, IndexError):
return None
def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]:
try:
extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"])
extra_headers: dict[str, Any] = json.loads(extra_headers_json or "{}")
except (json.JSONDecodeError, TypeError):
return {}
return extra_headers
def _get_message(arguments: dict[str, Any]) -> Generator[bytes | bytearray]:
if arguments["-n"]:
yield b""
if arguments["-m"]:
yield arguments["-m"].encode(encoding="utf-8")
if arguments["-f"]:
try:
with Path(arguments["-f"]).open(encoding="utf-8") as f:
for line in f:
@dataclass
class MessageInput:
message_str: str | None = None
file: str | None = None
stdin: bool | None = False
lines: bool | None = False
no_message: bool | None = False
def get_message(self) -> Generator[bytes | bytearray]:
if self.no_message:
yield b""
if self.message_str:
yield self.message_str.encode(encoding="utf-8")
if self.file:
try:
with Path(self.file).open(encoding="utf-8") as f:
for line in f:
yield line.encode(encoding="utf-8")
except Exception:
logger.exception(f"Failed to read file '{self.file}'")
if self.lines:
for line in sys.stdin:
if line:
yield line.encode(encoding="utf-8")
except Exception:
logger.exception(f"Failed to read file '{arguments['-f']}'")
if arguments["-l"]:
for line in sys.stdin:
if line:
yield line.encode(encoding="utf-8")
if arguments["-s"]:
message = bytearray()
for line in sys.stdin:
message.extend(line.encode(encoding="utf-8"))
yield message
if self.stdin:
messages = bytearray()
for line in sys.stdin:
messages.extend(line.encode(encoding="utf-8"))
yield messages
async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
"""Perform the publish."""
@dataclass
class CAInfo:
ca_file: str | None = None
ca_path: str | None = None
ca_data: str | None = None
async def do_pub(
client: MQTTClient,
url: str,
topic: str,
message_input: MessageInput,
ca_info: CAInfo,
clean_session: bool = False,
retain: bool = False,
extra_headers_json: str | None = None,
qos: int | None = None,
) -> None:
"""Publish the message."""
running_tasks = []
try:
logger.info(f"{client.client_id} Connecting to broker")
await client.connect(
uri=arguments["--url"],
cleansession=arguments["--clean-session"],
cafile=arguments["--ca-file"],
capath=arguments["--ca-path"],
cadata=arguments["--ca-data"],
additional_headers=_get_extra_headers(arguments),
uri=url,
cleansession=clean_session,
cafile=ca_info.ca_file,
capath=ca_info.ca_path,
cadata=ca_info.ca_data,
additional_headers=_get_extra_headers(extra_headers_json),
)
qos = _get_qos(arguments)
topic = arguments["-t"]
retain = arguments["-r"]
for message in _get_message(arguments):
for message in message_input.get_message():
logger.info(f"{client.client_id} Publishing to '{topic}'")
task = asyncio.ensure_future(client.publish(topic, message, qos, retain))
running_tasks.append(task)
@ -128,22 +112,68 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
await client.disconnect()
logger.info(f"{client.client_id} Disconnected from broker")
except ConnectError as ce:
logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}")
logger.fatal(f"Connection to '{url}' failed: {ce!r}")
except asyncio.CancelledError:
logger.fatal("Publish canceled due to previous error")
def main() -> None:
"""Entry point for the amqtt publisher."""
typer.run(publisher)
def _version() -> None:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
def publisher( # pylint: disable=R0914,R0917 # noqa : PLR0913
url: str = typer.Option(
..., "--url", help="Broker connection URL (must conform to MQTT URI scheme: mqtt://<username:password>@HOST:port)"
),
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Broker configuration file (YAML format)"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="Client ID to use for the connection"),
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
retain: bool = typer.Option(False, "-r", help="Set retain flag on connect"),
topic: str = typer.Option(..., "-t", help="Message topic"),
message: str | None = typer.Option(None, "-m", help="Message data to send"),
file: str | None = typer.Option(None, "-f", help="Read file by line and publish each line as a message"),
stdin: bool = typer.Option(False, "-s", help="Read from stdin and publish message for first line"),
lines: bool = typer.Option(False, "-l", help="Read from stdin and publish message for each line"),
no_message: bool = typer.Option(False, "-n", help="Publish an empty message"),
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
clean_session: bool = typer.Option(False, "--clean-session", help="Clean session on connect (defaults to False)"),
ca_file: str | None = typer.Option(None, "--ca-file", help="CA file"),
ca_path: str | None = typer.Option(None, "--ca-path", help="CA path"),
ca_data: str | None = typer.Option(None, "--ca-data", help="CA data"),
will_topic: str | None = typer.Option(None, "--will-topic", help="Last will topic"),
will_message: str | None = typer.Option(None, "--will-message", help="Last will message"),
will_qos: int | None = typer.Option(None, "--will-qos", help="Last will QoS"),
will_retain: bool = typer.Option(False, "--will-retain", help="Set retain flag for last will message"),
extra_headers_json: str | None = typer.Option(
None, "--extra-headers", help="JSON object with key-value headers for websocket connections"
),
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
version: bool | None = typer.Option( # noqa : ARG001
None,
"--version",
callback=_version,
is_eager=True,
help="Show version and exit",
),
) -> None:
"""Run the MQTT publisher."""
arguments = docopt(__doc__, version=amqtt.__version__)
provided = [bool(message), bool(file), stdin, lines, no_message]
if sum(provided) != 1:
typer.echo("❌ You must provide exactly one of --config, --file, or --stdin.", err=True)
raise typer.Exit(code=1)
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
level = logging.DEBUG if arguments["-d"] else logging.INFO
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(level=level, format=formatter)
config = None
if arguments["-c"]:
config = read_yaml_config(arguments["-c"])
if config_file:
config = read_yaml_config(config_file)
else:
default_config_path = Path(__file__).parent / "default_client.yaml"
logger.debug(f"Using default configuration from {default_config_path}")
@ -151,7 +181,6 @@ def main() -> None:
loop = asyncio.get_event_loop()
client_id = arguments.get("-i", None)
if not client_id:
client_id = _gen_client_id()
@ -159,22 +188,49 @@ def main() -> None:
logger.debug("Failed to correctly initialize config")
return
if arguments["-k"]:
config["keep_alive"] = int(arguments["-k"])
if keep_alive:
config["keep_alive"] = int(keep_alive)
if arguments["--will-topic"] and arguments["--will-message"] and arguments["--will-qos"]:
if will_topic and will_message and will_qos:
config["will"] = {
"topic": arguments["--will-topic"],
"message": arguments["--will-message"].encode("utf-8"),
"qos": int(arguments["--will-qos"]),
"retain": arguments["--will-retain"],
"topic": will_topic,
"message": will_message.encode("utf-8"),
"qos": int(will_qos),
"retain": will_retain,
}
client = MQTTClient(client_id=client_id, config=config)
message_input = MessageInput(
message_str=message,
file=file,
stdin=stdin,
no_message=no_message,
lines=lines,
)
ca_info = CAInfo(
ca_file=ca_file,
ca_path=ca_path,
ca_data=ca_data,
)
with contextlib.suppress(KeyboardInterrupt):
loop.run_until_complete(do_pub(client, arguments))
try:
loop.run_until_complete(
do_pub(
client=client,
message_input=message_input,
url=url,
topic=topic,
retain=retain,
clean_session=clean_session,
ca_info=ca_info,
extra_headers_json=extra_headers_json,
qos=qos,
)
)
except (ClientError, ConnectError) as ce:
typer.echo(f"❌ Connection failed: {ce}", err=True)
loop.close()
if __name__ == "__main__":
main()
typer.run(main)

Wyświetl plik

@ -2,6 +2,7 @@ try:
from datetime import UTC, datetime
except ImportError:
from datetime import datetime, timezone
UTC = timezone.utc
import logging

Wyświetl plik

@ -1,5 +1,6 @@
[build-system]
requires = ["hatchling", "hatch-vcs"]
requires = ["hatchling", "hatch-vcs", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
[project]
@ -117,11 +118,13 @@ ignore = [
"TD003", # TODO Missing issue link for this TODO.
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument
"PERF203" # try-except penalty within loops (3.10 only)
"PERF203",# try-except penalty within loops (3.10 only),
"COM812" # rule causes conflicts when used with the formatter
]
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["ALL"]
"amqtt/scripts/pub_script.py" = ["FBT003"]
[tool.ruff.lint.flake8-pytest-style]
fixture-parentheses = false