kopia lustrzana https://github.com/Yakifo/amqtt
replacing amqtt publisher use of 'docopts' with 'typer'
rodzic
d39bd2ec5c
commit
7255561b09
|
@ -5,11 +5,11 @@ try:
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
UTC = timezone.utc
|
UTC = timezone.utc
|
||||||
|
|
||||||
from struct import unpack
|
from struct import unpack
|
||||||
from typing import Generic
|
from typing import Generic, Self, TypeVar
|
||||||
from typing_extensions import Self, TypeVar
|
|
||||||
|
|
||||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||||
from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise
|
from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise
|
||||||
|
|
|
@ -1,20 +1,24 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import SupportsIndex, SupportsInt # pylint: disable=C0412
|
from typing import SupportsIndex, SupportsInt # pylint: disable=C0412
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
try:
|
try:
|
||||||
|
from collections import deque
|
||||||
from collections.abc import Buffer
|
from collections.abc import Buffer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
from collections import deque
|
||||||
from typing import Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Buffer(Protocol): # type: ignore[no-redef]
|
class Buffer(Protocol): # type: ignore[no-redef]
|
||||||
def __buffer__(self, flags: int = ...) -> memoryview:
|
def __buffer__(self, flags: int = ...) -> memoryview:
|
||||||
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
|
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
UTC = timezone.utc
|
UTC = timezone.utc
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,38 +1,7 @@
|
||||||
"""amqtt_pub - MQTT 3.1.1 publisher.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
amqtt_pub --version
|
|
||||||
amqtt_pub (-h | --help)
|
|
||||||
amqtt_pub --url BROKER_URL -t TOPIC (-f FILE | -l | -m MESSAGE | -n | -s) [-c CONFIG_FILE] [-i CLIENT_ID] [-q | --qos QOS] [-d] [-k KEEP_ALIVE] [--clean-session] [--ca-file CAFILE] [--ca-path CAPATH] [--ca-data CADATA] [ --will-topic WILL_TOPIC [--will-message WILL_MESSAGE] [--will-qos WILL_QOS] [--will-retain] ] [--extra-headers HEADER] [-r]
|
|
||||||
|
|
||||||
Options:
|
|
||||||
-h --help Show this screen.
|
|
||||||
--version Show version.
|
|
||||||
--url BROKER_URL Broker connection URL (must conform to MQTT URI scheme (see https://github.com/mqtt/mqtt.github.io/wiki/URI-Scheme>)
|
|
||||||
-c CONFIG_FILE Broker configuration file (YAML format)
|
|
||||||
-i CLIENT_ID Id to use as client ID.
|
|
||||||
-q | --qos QOS Quality of service to use for the message, from 0, 1, and 2. Defaults to 0.
|
|
||||||
-r Set retain flag on connect
|
|
||||||
-t TOPIC Message topic
|
|
||||||
-m MESSAGE Message data to send
|
|
||||||
-f FILE Read file by line and publish message for each line
|
|
||||||
-s Read from stdin and publish message for each line
|
|
||||||
-k KEEP_ALIVE Keep alive timeout in seconds
|
|
||||||
--clean-session Clean session on connect (defaults to False)
|
|
||||||
--ca-file CAFILE CA file
|
|
||||||
--ca-path CAPATH CA Path
|
|
||||||
--ca-data CADATA CA data
|
|
||||||
--will-topic WILL_TOPIC
|
|
||||||
--will-message WILL_MESSAGE
|
|
||||||
--will-qos WILL_QOS
|
|
||||||
--will-retain
|
|
||||||
--extra-headers EXTRA_HEADERS JSON object with key-value pairs of additional headers for websocket connections
|
|
||||||
-d Enable debug messages
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -41,11 +10,11 @@ import socket
|
||||||
import sys
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from docopt import docopt
|
import typer
|
||||||
|
|
||||||
import amqtt
|
from amqtt import __version__ as amqtt_version
|
||||||
from amqtt.client import MQTTClient
|
from amqtt.client import MQTTClient
|
||||||
from amqtt.errors import ConnectError
|
from amqtt.errors import ClientError, ConnectError
|
||||||
from amqtt.utils import read_yaml_config
|
from amqtt.utils import read_yaml_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -57,64 +26,79 @@ def _gen_client_id() -> str:
|
||||||
return f"amqtt_pub/{pid}-{hostname}"
|
return f"amqtt_pub/{pid}-{hostname}"
|
||||||
|
|
||||||
|
|
||||||
def _get_qos(arguments: dict[str, Any]) -> int | None:
|
def _get_extra_headers(extra_headers_json: str | None = None) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
return int(arguments["--qos"][0])
|
extra_headers: dict[str, Any] = json.loads(extra_headers_json or "{}")
|
||||||
except (ValueError, IndexError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
try:
|
|
||||||
extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"])
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
return {}
|
return {}
|
||||||
return extra_headers
|
return extra_headers
|
||||||
|
|
||||||
|
|
||||||
def _get_message(arguments: dict[str, Any]) -> Generator[bytes | bytearray]:
|
@dataclass
|
||||||
if arguments["-n"]:
|
class MessageInput:
|
||||||
|
message_str: str | None = None
|
||||||
|
file: str | None = None
|
||||||
|
stdin: bool | None = False
|
||||||
|
lines: bool | None = False
|
||||||
|
no_message: bool | None = False
|
||||||
|
|
||||||
|
def get_message(self) -> Generator[bytes | bytearray]:
|
||||||
|
if self.no_message:
|
||||||
yield b""
|
yield b""
|
||||||
if arguments["-m"]:
|
if self.message_str:
|
||||||
yield arguments["-m"].encode(encoding="utf-8")
|
yield self.message_str.encode(encoding="utf-8")
|
||||||
if arguments["-f"]:
|
if self.file:
|
||||||
try:
|
try:
|
||||||
with Path(arguments["-f"]).open(encoding="utf-8") as f:
|
with Path(self.file).open(encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
yield line.encode(encoding="utf-8")
|
yield line.encode(encoding="utf-8")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Failed to read file '{arguments['-f']}'")
|
logger.exception(f"Failed to read file '{self.file}'")
|
||||||
if arguments["-l"]:
|
if self.lines:
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
if line:
|
if line:
|
||||||
yield line.encode(encoding="utf-8")
|
yield line.encode(encoding="utf-8")
|
||||||
if arguments["-s"]:
|
if self.stdin:
|
||||||
message = bytearray()
|
messages = bytearray()
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
message.extend(line.encode(encoding="utf-8"))
|
messages.extend(line.encode(encoding="utf-8"))
|
||||||
yield message
|
yield messages
|
||||||
|
|
||||||
|
|
||||||
async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
|
@dataclass
|
||||||
"""Perform the publish."""
|
class CAInfo:
|
||||||
|
ca_file: str | None = None
|
||||||
|
ca_path: str | None = None
|
||||||
|
ca_data: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def do_pub(
|
||||||
|
client: MQTTClient,
|
||||||
|
url: str,
|
||||||
|
topic: str,
|
||||||
|
message_input: MessageInput,
|
||||||
|
ca_info: CAInfo,
|
||||||
|
clean_session: bool = False,
|
||||||
|
retain: bool = False,
|
||||||
|
extra_headers_json: str | None = None,
|
||||||
|
qos: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Publish the message."""
|
||||||
running_tasks = []
|
running_tasks = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"{client.client_id} Connecting to broker")
|
logger.info(f"{client.client_id} Connecting to broker")
|
||||||
|
|
||||||
await client.connect(
|
await client.connect(
|
||||||
uri=arguments["--url"],
|
uri=url,
|
||||||
cleansession=arguments["--clean-session"],
|
cleansession=clean_session,
|
||||||
cafile=arguments["--ca-file"],
|
cafile=ca_info.ca_file,
|
||||||
capath=arguments["--ca-path"],
|
capath=ca_info.ca_path,
|
||||||
cadata=arguments["--ca-data"],
|
cadata=ca_info.ca_data,
|
||||||
additional_headers=_get_extra_headers(arguments),
|
additional_headers=_get_extra_headers(extra_headers_json),
|
||||||
)
|
)
|
||||||
|
|
||||||
qos = _get_qos(arguments)
|
for message in message_input.get_message():
|
||||||
topic = arguments["-t"]
|
|
||||||
retain = arguments["-r"]
|
|
||||||
for message in _get_message(arguments):
|
|
||||||
logger.info(f"{client.client_id} Publishing to '{topic}'")
|
logger.info(f"{client.client_id} Publishing to '{topic}'")
|
||||||
task = asyncio.ensure_future(client.publish(topic, message, qos, retain))
|
task = asyncio.ensure_future(client.publish(topic, message, qos, retain))
|
||||||
running_tasks.append(task)
|
running_tasks.append(task)
|
||||||
|
@ -128,22 +112,68 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
logger.info(f"{client.client_id} Disconnected from broker")
|
logger.info(f"{client.client_id} Disconnected from broker")
|
||||||
except ConnectError as ce:
|
except ConnectError as ce:
|
||||||
logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}")
|
logger.fatal(f"Connection to '{url}' failed: {ce!r}")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.fatal("Publish canceled due to previous error")
|
logger.fatal("Publish canceled due to previous error")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
"""Entry point for the amqtt publisher."""
|
||||||
|
typer.run(publisher)
|
||||||
|
|
||||||
|
|
||||||
|
def _version() -> None:
|
||||||
|
typer.echo(f"{amqtt_version}")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
|
||||||
|
def publisher( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||||
|
url: str = typer.Option(
|
||||||
|
..., "--url", help="Broker connection URL (must conform to MQTT URI scheme: mqtt://<username:password>@HOST:port)"
|
||||||
|
),
|
||||||
|
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Broker configuration file (YAML format)"),
|
||||||
|
client_id: str | None = typer.Option(None, "-i", "--client-id", help="Client ID to use for the connection"),
|
||||||
|
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
|
||||||
|
retain: bool = typer.Option(False, "-r", help="Set retain flag on connect"),
|
||||||
|
topic: str = typer.Option(..., "-t", help="Message topic"),
|
||||||
|
message: str | None = typer.Option(None, "-m", help="Message data to send"),
|
||||||
|
file: str | None = typer.Option(None, "-f", help="Read file by line and publish each line as a message"),
|
||||||
|
stdin: bool = typer.Option(False, "-s", help="Read from stdin and publish message for first line"),
|
||||||
|
lines: bool = typer.Option(False, "-l", help="Read from stdin and publish message for each line"),
|
||||||
|
no_message: bool = typer.Option(False, "-n", help="Publish an empty message"),
|
||||||
|
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
|
||||||
|
clean_session: bool = typer.Option(False, "--clean-session", help="Clean session on connect (defaults to False)"),
|
||||||
|
ca_file: str | None = typer.Option(None, "--ca-file", help="CA file"),
|
||||||
|
ca_path: str | None = typer.Option(None, "--ca-path", help="CA path"),
|
||||||
|
ca_data: str | None = typer.Option(None, "--ca-data", help="CA data"),
|
||||||
|
will_topic: str | None = typer.Option(None, "--will-topic", help="Last will topic"),
|
||||||
|
will_message: str | None = typer.Option(None, "--will-message", help="Last will message"),
|
||||||
|
will_qos: int | None = typer.Option(None, "--will-qos", help="Last will QoS"),
|
||||||
|
will_retain: bool = typer.Option(False, "--will-retain", help="Set retain flag for last will message"),
|
||||||
|
extra_headers_json: str | None = typer.Option(
|
||||||
|
None, "--extra-headers", help="JSON object with key-value headers for websocket connections"
|
||||||
|
),
|
||||||
|
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
|
||||||
|
version: bool | None = typer.Option( # noqa : ARG001
|
||||||
|
None,
|
||||||
|
"--version",
|
||||||
|
callback=_version,
|
||||||
|
is_eager=True,
|
||||||
|
help="Show version and exit",
|
||||||
|
),
|
||||||
|
) -> None:
|
||||||
"""Run the MQTT publisher."""
|
"""Run the MQTT publisher."""
|
||||||
arguments = docopt(__doc__, version=amqtt.__version__)
|
provided = [bool(message), bool(file), stdin, lines, no_message]
|
||||||
|
if sum(provided) != 1:
|
||||||
|
typer.echo("❌ You must provide exactly one of --config, --file, or --stdin.", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||||
level = logging.DEBUG if arguments["-d"] else logging.INFO
|
level = logging.DEBUG if debug else logging.INFO
|
||||||
logging.basicConfig(level=level, format=formatter)
|
logging.basicConfig(level=level, format=formatter)
|
||||||
|
|
||||||
config = None
|
if config_file:
|
||||||
if arguments["-c"]:
|
config = read_yaml_config(config_file)
|
||||||
config = read_yaml_config(arguments["-c"])
|
|
||||||
else:
|
else:
|
||||||
default_config_path = Path(__file__).parent / "default_client.yaml"
|
default_config_path = Path(__file__).parent / "default_client.yaml"
|
||||||
logger.debug(f"Using default configuration from {default_config_path}")
|
logger.debug(f"Using default configuration from {default_config_path}")
|
||||||
|
@ -151,7 +181,6 @@ def main() -> None:
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
client_id = arguments.get("-i", None)
|
|
||||||
if not client_id:
|
if not client_id:
|
||||||
client_id = _gen_client_id()
|
client_id = _gen_client_id()
|
||||||
|
|
||||||
|
@ -159,22 +188,49 @@ def main() -> None:
|
||||||
logger.debug("Failed to correctly initialize config")
|
logger.debug("Failed to correctly initialize config")
|
||||||
return
|
return
|
||||||
|
|
||||||
if arguments["-k"]:
|
if keep_alive:
|
||||||
config["keep_alive"] = int(arguments["-k"])
|
config["keep_alive"] = int(keep_alive)
|
||||||
|
|
||||||
if arguments["--will-topic"] and arguments["--will-message"] and arguments["--will-qos"]:
|
if will_topic and will_message and will_qos:
|
||||||
config["will"] = {
|
config["will"] = {
|
||||||
"topic": arguments["--will-topic"],
|
"topic": will_topic,
|
||||||
"message": arguments["--will-message"].encode("utf-8"),
|
"message": will_message.encode("utf-8"),
|
||||||
"qos": int(arguments["--will-qos"]),
|
"qos": int(will_qos),
|
||||||
"retain": arguments["--will-retain"],
|
"retain": will_retain,
|
||||||
}
|
}
|
||||||
|
|
||||||
client = MQTTClient(client_id=client_id, config=config)
|
client = MQTTClient(client_id=client_id, config=config)
|
||||||
|
message_input = MessageInput(
|
||||||
|
message_str=message,
|
||||||
|
file=file,
|
||||||
|
stdin=stdin,
|
||||||
|
no_message=no_message,
|
||||||
|
lines=lines,
|
||||||
|
)
|
||||||
|
ca_info = CAInfo(
|
||||||
|
ca_file=ca_file,
|
||||||
|
ca_path=ca_path,
|
||||||
|
ca_data=ca_data,
|
||||||
|
)
|
||||||
with contextlib.suppress(KeyboardInterrupt):
|
with contextlib.suppress(KeyboardInterrupt):
|
||||||
loop.run_until_complete(do_pub(client, arguments))
|
try:
|
||||||
|
loop.run_until_complete(
|
||||||
|
do_pub(
|
||||||
|
client=client,
|
||||||
|
message_input=message_input,
|
||||||
|
url=url,
|
||||||
|
topic=topic,
|
||||||
|
retain=retain,
|
||||||
|
clean_session=clean_session,
|
||||||
|
ca_info=ca_info,
|
||||||
|
extra_headers_json=extra_headers_json,
|
||||||
|
qos=qos,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except (ClientError, ConnectError) as ce:
|
||||||
|
typer.echo(f"❌ Connection failed: {ce}", err=True)
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
typer.run(main)
|
||||||
|
|
|
@ -2,6 +2,7 @@ try:
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
UTC = timezone.utc
|
UTC = timezone.utc
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling", "hatch-vcs"]
|
requires = ["hatchling", "hatch-vcs", "uv-dynamic-versioning"]
|
||||||
|
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
|
@ -117,11 +118,13 @@ ignore = [
|
||||||
"TD003", # TODO Missing issue link for this TODO.
|
"TD003", # TODO Missing issue link for this TODO.
|
||||||
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
||||||
"ARG002", # Unused method argument
|
"ARG002", # Unused method argument
|
||||||
"PERF203" # try-except penalty within loops (3.10 only)
|
"PERF203",# try-except penalty within loops (3.10 only),
|
||||||
|
"COM812" # rule causes conflicts when used with the formatter
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"tests/**" = ["ALL"]
|
"tests/**" = ["ALL"]
|
||||||
|
"amqtt/scripts/pub_script.py" = ["FBT003"]
|
||||||
|
|
||||||
[tool.ruff.lint.flake8-pytest-style]
|
[tool.ruff.lint.flake8-pytest-style]
|
||||||
fixture-parentheses = false
|
fixture-parentheses = false
|
||||||
|
|
Ładowanie…
Reference in New Issue