code quality improvements (#293)

* add additional linting rules
pull/276/head
Andrew Mirsky 2025-08-10 22:00:33 -04:00 zatwierdzone przez GitHub
rodzic b4d58c9130
commit 8e47ede192
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
50 zmienionych plików z 209 dodań i 191 usunięć

Wyświetl plik

@ -325,7 +325,6 @@ class Broker:
ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener."""
match listener_type:
case ListenerType.TCP:
return await asyncio.start_server(
@ -588,8 +587,6 @@ class Broker:
for topic in self._subscriptions:
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler)
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
@ -620,7 +617,6 @@ class Broker:
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
if subscribe_waiter in done:
await self._handle_subscription(client_session, handler, subscribe_waiter)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
@ -690,7 +686,6 @@ class Broker:
client_id=client_session.client_id,
client_session=client_session)
async def _handle_subscription(
self,
client_session: Session,

Wyświetl plik

@ -424,14 +424,10 @@ class MQTTClient:
scheme = uri_attributes.scheme
secure = scheme in ("mqtts", "wss")
self.session.username = (
self.session.username
if self.session.username
else (str(uri_attributes.username) if uri_attributes.username else None)
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
)
self.session.password = (
self.session.password
if self.session.password
else (str(uri_attributes.password) if uri_attributes.password else None)
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
)
self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
self.session.remote_port = uri_attributes.port

Wyświetl plik

@ -50,6 +50,7 @@ class ListenerType(StrEnum):
"""Display the string value, instead of the enum member."""
return f'"{self.value!s}"'
class Dictable:
"""Add dictionary methods to a dataclass."""
@ -263,6 +264,7 @@ class ConnectionConfig(Dictable):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published.

Wyświetl plik

@ -36,9 +36,11 @@ class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
if value is None:
return None
return [self.dataclass_type(**item) for item in value]
def process_literal_param(self, value: Any, dialect: Any) -> Any:
# Required by SQLAlchemy, typically used for literal SQL rendering.
return value
@property
def python_type(self) -> type:
# Required by TypeEngine to indicate the expected Python type.

Wyświetl plik

@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
matcher = TopicMatcher()
@dataclass
class AllowedTopic:
topic: str
@ -43,6 +44,7 @@ class AllowedTopic:
"""Display topic."""
return self.topic
class PasswordHasher:
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""

Wyświetl plik

@ -66,6 +66,7 @@ class UserAuthDBPlugin(BaseAuthPlugin):
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
"""list of hash schemes to use for passwords"""
class TopicAuthDBPlugin(BaseTopicPlugin):
def __init__(self, context: BrokerContext) -> None:

Wyświetl plik

@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
topic_app = typer.Typer(no_args_is_help=True)
@topic_app.callback()
def main(
ctx: typer.Context,

Wyświetl plik

@ -89,6 +89,7 @@ class CertificateAuthPlugin(BaseAuthPlugin):
uri_domain: str
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
def generate_root_creds(country: str, state: str, locality: str,
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Generate CA key and certificate."""
@ -167,7 +168,6 @@ def generate_server_csr(country:str, org_name: str, cn:str) -> tuple[rsa.RSAPriv
return key, csr
def generate_device_csr(country: str, org_name: str, common_name: str,
uri_san: str, dns_san: str
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
@ -194,6 +194,7 @@ def generate_device_csr(country: str, org_name: str, common_name: str,
return key, csr
def sign_csr(csr: CertificateSigningRequest,
ca_key: rsa.RSAPrivateKey,
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
@ -221,6 +222,7 @@ def sign_csr(csr: CertificateSigningRequest,
.sign(ca_key, hashes.SHA256())
)
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Load server key and certificate."""
with Path(ca_key_fn).open("rb") as f:

Wyświetl plik

@ -26,6 +26,7 @@ class ResponseMode(StrEnum):
JSON = "json"
TEXT = "text"
class RequestMethod(StrEnum):
GET = "get"
POST = "post"

Wyświetl plik

@ -8,7 +8,7 @@ from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.errors import PluginInitError
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin, BasePlugin
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
@ -42,6 +42,7 @@ class AuthLdapPlugin(BasePlugin[BrokerContext]):
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
raise PluginInitError(self.__class__) from e
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
@ -76,6 +77,7 @@ class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
class Config(LdapConfig):
"""Configuration for the User Auth LDAP Plugin."""
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
@ -107,7 +109,6 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
]
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
if not results:
logger.debug(f"user not found: {session.username}")
return False
@ -125,7 +126,6 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
@dataclass
class Config(LdapConfig):
"""Configuration for the LDAPAuthPlugin."""

Wyświetl plik

@ -92,7 +92,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
await db_session.flush()
return stored_session
@staticmethod
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
@ -104,7 +103,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
await db_session.flush()
return stored_message
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
"""Search to see if session already exists."""
# if client id doesn't exist, create (can ignore if session is anonymous)
@ -240,7 +238,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
restored_messages += 1
logger.info(f"Retained messages restored: {restored_messages}")
logger.info(f"Restored {restored_sessions} sessions.")
async def on_broker_pre_shutdown(self) -> None:

Wyświetl plik

@ -35,6 +35,7 @@ class ShadowMessage:
def to_message(self) -> bytes:
return json.dumps(asdict_no_none(self)).encode("utf-8")
@dataclass
class GetAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
@ -46,6 +47,7 @@ class GetAcceptedMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
@dataclass
class GetRejectedMessage(ShadowMessage):
code: int
@ -56,6 +58,7 @@ class GetRejectedMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
@dataclass
class UpdateAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
@ -78,6 +81,7 @@ class UpdateRejectedMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
@dataclass
class UpdateDeltaMessage(ShadowMessage):
state: MutableMapping[str, Any]
@ -89,6 +93,7 @@ class UpdateDeltaMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
class UpdateIotaMessage(UpdateDeltaMessage):
"""Same format, corollary name."""
@ -96,6 +101,7 @@ class UpdateIotaMessage(UpdateDeltaMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
@dataclass
class UpdateDocumentMessage(ShadowMessage):
previous: StateDocument

Wyświetl plik

@ -127,6 +127,7 @@ def convert_update_to_insert(session: Session, _flush_context: object, _instance
session.add(obj) # re-add as new object
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:

Wyświetl plik

@ -43,6 +43,7 @@ def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
"""Nested defaultdict for shadow cache."""
return defaultdict(shadow_dict) # type: ignore[arg-type]
class ShadowPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
@ -52,7 +53,6 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
self._engine = create_async_engine(self.config.connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def on_broker_pre_start(self) -> None:
"""Sync the schema."""
async with self._engine.begin() as conn:

Wyświetl plik

@ -16,10 +16,12 @@ from mergedeep import merge
C = TypeVar("C", bound=Any)
class StateError(Exception):
def __init__(self, msg: str = "'state' field is required") -> None:
super().__init__(msg)
@dataclass
class MetaTimestamp:
timestamp: int = 0
@ -127,6 +129,7 @@ def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMa
return delta
@dataclass
class State(Generic[C]):
desired: MutableMapping[str, C] = field(default_factory=dict)

Wyświetl plik

@ -16,10 +16,12 @@ class CodecError(Exception):
class NoDataError(Exception):
"""Exceptions thrown by packet encode/decode functions."""
class ZeroLengthReadError(NoDataError):
def __init__(self) -> None:
super().__init__("Decoding a string of length zero.")
class BrokerError(Exception):
"""Exceptions thrown by broker."""

Wyświetl plik

@ -167,7 +167,7 @@ class PacketIdVariableHeader(MQTTVariableHeader):
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)
class MQTTPayload(Generic[_VH], ABC):
class MQTTPayload(ABC, Generic[_VH]):
"""Abstract base class for MQTT payloads."""
async def to_stream(self, writer: asyncio.StreamWriter) -> None:

Wyświetl plik

@ -31,6 +31,7 @@ _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
if TYPE_CHECKING:
from amqtt.broker import BrokerContext
class Subscription:
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
self.packet_id = packet_id

Wyświetl plik

@ -19,6 +19,7 @@ from amqtt.session import Session
if TYPE_CHECKING:
from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
def __init__(
self,

Wyświetl plik

@ -63,6 +63,7 @@ from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplic
C = TypeVar("C", bound=BaseContext)
class ProtocolHandler(Generic[C]):
"""Class implementing the MQTT communication protocol using asyncio features."""

Wyświetl plik

@ -126,7 +126,6 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
# auth config section not found and Config dataclass not provided
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.

Wyświetl plik

@ -51,6 +51,7 @@ def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
C = TypeVar("C", bound=BaseContext)
class PluginManager(Generic[C]):
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
@ -97,7 +98,6 @@ class PluginManager(Generic[C]):
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
# plugins loaded directly from config dictionary
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:

Wyświetl plik

@ -75,7 +75,6 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
self._sys_interval: int = 0
self._current_process = psutil.Process()
def _clear_stats(self) -> None:
"""Initialize broker statistics data structures."""
for stat in (

Wyświetl plik

@ -13,6 +13,7 @@ def main() -> None:
"""Run the cli for `ca_creds`."""
app()
@app.command()
def ca_creds(
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),

Wyświetl plik

@ -59,5 +59,6 @@ def device_creds( # pylint: disable=too-many-locals
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
if __name__ == "__main__":
main()

Wyświetl plik

@ -32,5 +32,6 @@ def main() -> None:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()

Wyświetl plik

@ -31,5 +31,6 @@ def main() -> None:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()

Wyświetl plik

@ -118,6 +118,7 @@ async def do_pub(
logger.fatal("Publish canceled due to previous error")
raise asyncio.CancelledError from ce
app = typer.Typer(add_completion=False, rich_markup_mode=None)
@ -131,8 +132,9 @@ def _version(v: bool) -> None:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@app.command()
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def publisher_main( # pylint: disable=R0914,R0917
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),

Wyświetl plik

@ -13,6 +13,7 @@ def main() -> None:
"""Run the `server_creds` cli."""
app()
@app.command()
def server_creds(
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),

Wyświetl plik

@ -107,7 +107,7 @@ def _version(v:bool) -> None:
@app.command()
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def subscribe_main( # pylint: disable=R0914,R0917
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),

Wyświetl plik

@ -166,11 +166,12 @@ indent-style = "space"
docstring-code-format = true
[tool.ruff.lint]
preview = true
select = ["ALL"]
extend-select = [
"UP", # pyupgrade
"D", # pydocstyle
"D", # pydocstyle,
]
ignore = [
@ -189,7 +190,16 @@ ignore = [
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument
"PERF203",# try-except penalty within loops (3.10 only),
"COM812" # rule causes conflicts when used with the formatter
"COM812", # rule causes conflicts when used with the formatter,
# ignore certain preview rules
"DOC",
"PLW",
"PLR",
"CPY",
"PLC",
"RUF052",
"B903"
]
[tool.ruff.lint.per-file-ignores]

Wyświetl plik

@ -1,6 +1,5 @@
import asyncio
import logging
import os
from pathlib import Path
from amqtt.broker import Broker
@ -24,13 +23,13 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
'amqtt.plugins.authentication.FileAuthPlugin': {
'password_file': Path(__file__).parent / 'passwd',
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
"amqtt.plugins.authentication.FileAuthPlugin": {
"password_file": Path(__file__).parent / "passwd",
},
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
'amqtt.plugins.topic_checking.TopicAccessControlListPlugin': {
'acl': {
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
"amqtt.plugins.topic_checking.TopicAccessControlListPlugin": {
"acl": {
# username: [list of allowed topics]
"test": ["repositories/+/master", "calendar/#", "data/memes"],
"anonymous": [],

Wyświetl plik

@ -1,8 +1,6 @@
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
import logging
from amqtt.broker import Broker
from amqtt.plugins.base import BasePlugin
@ -18,7 +16,7 @@ logger = logging.getLogger(__name__)
class RemoteInfoPlugin(BasePlugin):
async def on_broker_client_connected(self, *, client_id:str, client_session:Session) -> None:
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ''
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ""
logger.info(f"client '{client_id}' connected from"
f" '{client_session.remote_address}' {display_port_str}")
@ -40,8 +38,8 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
'samples.broker_custom_plugin.RemoteInfoPlugin': { 'display_port': True },
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
"samples.broker_custom_plugin.RemoteInfoPlugin": { "display_port": True },
}
}

Wyświetl plik

@ -1,6 +1,6 @@
import asyncio
import logging
from asyncio import CancelledError
import logging
from amqtt.broker import Broker

Wyświetl plik

@ -1,6 +1,5 @@
import asyncio
import logging
import os
from pathlib import Path
from amqtt.broker import Broker
@ -24,11 +23,11 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
'amqtt.plugins.authentication.FileAuthPlugin': {
'password_file': Path(__file__).parent / 'passwd',
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
"amqtt.plugins.authentication.FileAuthPlugin": {
"password_file": Path(__file__).parent / "passwd",
},
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
}
}

Wyświetl plik

@ -1,6 +1,5 @@
import asyncio
import logging
import os
from pathlib import Path
from amqtt.broker import Broker
@ -24,12 +23,12 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True},
'amqtt.plugins.authentication.FileAuthPlugin': {
'password_file': Path(__file__).parent / 'passwd',
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
"amqtt.plugins.authentication.FileAuthPlugin": {
"password_file": Path(__file__).parent / "passwd",
},
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 10},
'amqtt.plugins.topic_checking.TopicTabooPlugin': {},
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 10},
"amqtt.plugins.topic_checking.TopicTabooPlugin": {},
}
}

Wyświetl plik

@ -1,6 +1,6 @@
import asyncio
import logging
from asyncio import CancelledError
import logging
from amqtt.client import MQTTClient

Wyświetl plik

@ -35,7 +35,7 @@ async def test_coro1() -> None:
async def test_coro2() -> None:
try:
client = MQTTClient(config={'auto_reconnect': False, 'connection_timeout': 1})
client = MQTTClient(config={"auto_reconnect": False, "connection_timeout": 1})
await client.connect("mqtt://localhost:1884/")
await client.publish("a/b", b"TEST MESSAGE WITH QOS_0", qos=0x00)
await client.publish("a/b", b"TEST MESSAGE WITH QOS_1", qos=0x01)
@ -43,7 +43,7 @@ async def test_coro2() -> None:
logger.info("test_coro2 messages published")
await client.disconnect()
except ConnectError:
logger.info(f"Connection failed", exc_info=True)
logger.info("Connection failed", exc_info=True)
def __main__():

Wyświetl plik

@ -27,7 +27,7 @@ async def test_coro() -> None:
await client.publish("calendar/amqtt/releases", b"NEW RELEASE", qos=QOS_1)
logger.info("messages published")
await client.disconnect()
except ConnectError as ce:
except ConnectError:
logger.exception("ERROR: Connection failed")

Wyświetl plik

@ -1,7 +1,6 @@
import argparse
import asyncio
import logging
from pathlib import Path
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
@ -29,7 +28,7 @@ config = {
async def test_coro(certfile: str) -> None:
config['certfile'] = certfile
config["certfile"] = certfile
client = MQTTClient(config=config)
await client.connect("mqtts://localhost:8883")
@ -47,7 +46,7 @@ def __main__():
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
parser = argparse.ArgumentParser()
parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity")
parser.add_argument("--cert", default="cert.pem", help="path & file to verify server's authenticity")
args = parser.parse_args()
asyncio.run(test_coro(args.cert))

Wyświetl plik

@ -4,7 +4,6 @@ import logging
from amqtt.client import ClientError, MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
"""
This sample shows how to subscribe to different $SYS topics and how to receive incoming messages
"""
@ -13,7 +12,7 @@ logger = logging.getLogger(__name__)
async def uptime_coro() -> None:
client = MQTTClient(config={'auto_reconnect': False})
client = MQTTClient(config={"auto_reconnect": False})
await client.connect("mqtt://localhost:1883")
await client.subscribe(

Wyświetl plik

@ -35,7 +35,7 @@ async def uptime_coro() -> None:
await client.unsubscribe(["$SYS/#", "data/memes"])
logger.info("UnSubscribed")
await client.disconnect()
except ClientError as ce:
except ClientError:
logger.exception("Client exception")

Wyświetl plik

@ -9,14 +9,13 @@ from aiohttp import web
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import ConnectError
logger = logging.getLogger(__name__)
MQTT_LISTENER_NAME = "myMqttListener"
async def hello(request):
"""get request handler"""
"""Get request handler"""
return web.Response(text="Hello, world")
class WebSocketResponseReader(ReaderAdapter):
@ -27,17 +26,17 @@ class WebSocketResponseReader(ReaderAdapter):
self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes:
"""
read 'n' bytes from the datastream, if < 0 read all available bytes
"""Read 'n' bytes from the datastream, if < 0 read all available bytes
Raises:
BrokerPipeError : if reading on a closed websocket connection
"""
# continue until buffer contains at least the amount of data being requested
while not self.buffer or len(self.buffer) < n:
# if the websocket is closed
if self.ws.closed:
raise BrokenPipeError()
raise BrokenPipeError
try:
# read from stream
@ -46,10 +45,10 @@ class WebSocketResponseReader(ReaderAdapter):
if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
raise BrokenPipeError()
raise BrokenPipeError
except asyncio.TimeoutError:
raise BrokenPipeError()
raise BrokenPipeError
# return all bytes currently in the buffer
if n == -1:
@ -74,7 +73,7 @@ class WebSocketResponseWriter(WriterAdapter):
# needed for `get_peer_info`
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
peer_name = request.transport.get_extra_info('peername')
peer_name = request.transport.get_extra_info("peername")
if peer_name is not None:
self.client_ip, self.port = peer_name[0:2]
else:
@ -110,17 +109,17 @@ class WebSocketResponseWriter(WriterAdapter):
async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse:
# establish connection by responding to the websocket request with the 'mqtt' protocol
ws = web.WebSocketResponse(protocols=['mqtt',])
ws = web.WebSocketResponse(protocols=["mqtt"])
await ws.prepare(request)
# access the broker created when the server started
b: Broker = request.app['broker']
b: Broker = request.app["broker"]
# hand-off the websocket data stream to the broker for handling
# `listener_name` is the same name of the externalized listener in the broker config
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME)
logger.debug('websocket connection closed')
logger.debug("websocket connection closed")
return ws
@ -140,9 +139,9 @@ def main():
app = web.Application()
app.add_routes(
[
web.get('/', hello), # http get request/response route
web.get('/ws', websocket_handler), # standard websocket handler
web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections
web.get("/", hello), # http get request/response route
web.get("/ws", websocket_handler), # standard websocket handler
web.get("/mqtt", mqtt_websocket_handler), # websocket handler for mqtt connections
])
# create background task for running the `amqtt` broker
app.cleanup_ctx.append(run_broker)
@ -154,12 +153,12 @@ def main():
async def run_broker(_app):
"""App init function to start (and then shutdown) the `amqtt` broker.
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks
"""
# standard TCP connection as well as an externalized-listener
cfg = BrokerConfig(
listeners={
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
"default":ListenerConfig(type=ListenerType.TCP, bind="127.0.0.1:1883"),
MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL),
}
)
@ -169,7 +168,7 @@ async def run_broker(_app):
broker = Broker(config=cfg, loop=loop)
# store broker instance so that incoming requests can hand off processing of a datastream
_app['broker'] = broker
_app["broker"] = broker
# start the broker
await broker.start()
@ -180,6 +179,6 @@ async def run_broker(_app):
await broker.shutdown()
if __name__ == '__main__':
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

Wyświetl plik

@ -1,20 +1,19 @@
import contextlib
import logging
import asyncio
import ssl
from asyncio import StreamWriter, StreamReader, Event
from functools import partial
import logging
from pathlib import Path
import ssl
import typer
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker
from amqtt.client import ClientContext
from amqtt.contexts import ClientConfig, BrokerConfig, ListenerConfig, ListenerType
from amqtt.contexts import BrokerConfig, ClientConfig, ListenerConfig, ListenerType
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from amqtt.plugins.manager import PluginManager
from amqtt.session import Session
from amqtt.adapters import ReaderAdapter, WriterAdapter
logger = logging.getLogger(__name__)
@ -60,7 +59,7 @@ class UnixStreamWriterAdapter(WriterAdapter):
await self._writer.drain()
def get_peer_info(self) -> tuple[str, int]:
extra_info = self._writer.get_extra_info('socket')
extra_info = self._writer.get_extra_info("socket")
return extra_info.getsockname(), 0
async def close(self) -> None:
@ -86,7 +85,7 @@ async def run_broker(socket_file: Path):
# configure the broker with a single, external listener
cfg = BrokerConfig(
listeners={
'default': ListenerConfig(
"default": ListenerConfig(
type=ListenerType.EXTERNAL
)
},
@ -109,7 +108,7 @@ async def run_broker(socket_file: Path):
# passes the connection to the broker for protocol communications
await b.external_connected(reader=r, writer=w, listener_name=listener_name)
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name='default'), path=socket_file)
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name="default"), path=socket_file)
await b.start()
try:
@ -163,7 +162,7 @@ async def run_client(socket_file: Path):
try:
while True:
# periodically send a message
await cph.mqtt_publish('my/topic', b'my message', 0, False)
await cph.mqtt_publish("my/topic", b"my message", 0, False)
await asyncio.sleep(1)
except KeyboardInterrupt:
cph.detach()