refactor: more cleanup/linting especial Pylint

pull/165/head
MVladislav 2025-01-12 22:44:19 +01:00
rodzic 40d1214c79
commit 57bfec0ea8
13 zmienionych plików z 108 dodań i 83 usunięć

Wyświetl plik

@ -106,10 +106,10 @@ repos:
types: [python] types: [python]
require_serial: true require_serial: true
exclude: ^tests/.+|^docs/.+|^samples/.+ exclude: ^tests/.+|^docs/.+|^samples/.+
# - id: pylint - id: pylint
# name: Run Pylint in Virtualenv name: Run Pylint in Virtualenv
# entry: scripts/run-in-env.sh pylint entry: scripts/run-in-env.sh pylint
# language: script language: script
# types: [python] types: [python]
# require_serial: true require_serial: true
# exclude: ^tests/.+|^docs/.+|^samples/.+ exclude: ^tests/.+|^docs/.+|^samples/.+

Wyświetl plik

@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from contextlib import suppress from contextlib import suppress
import io import io
@ -7,13 +8,14 @@ from websockets import ConnectionClosed
from websockets.asyncio.connection import Connection from websockets.asyncio.connection import Connection
class ReaderAdapter: class ReaderAdapter(ABC):
"""Base class for all network protocol reader adapters. """Base class for all network protocol reader adapters.
Reader adapters are used to adapt read operations on the network depending on the Reader adapters are used to adapt read operations on the network depending on the
protocol used. protocol used.
""" """
@abstractmethod
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
"""Read up to n bytes. If n is not provided, or set to -1, read until EOF and return all read bytes. """Read up to n bytes. If n is not provided, or set to -1, read until EOF and return all read bytes.
@ -22,30 +24,35 @@ class ReaderAdapter:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def feed_eof(self) -> None: def feed_eof(self) -> None:
"""Acknowledge EOF.""" """Acknowledge EOF."""
raise NotImplementedError raise NotImplementedError
class WriterAdapter: class WriterAdapter(ABC):
"""Base class for all network protocol writer adapters. """Base class for all network protocol writer adapters.
Writer adapters are used to adapt write operations on the network depending on Writer adapters are used to adapt write operations on the network depending on
the protocol used. the protocol used.
""" """
@abstractmethod
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
"""Write some data to the protocol layer.""" """Write some data to the protocol layer."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def drain(self) -> None: async def drain(self) -> None:
"""Let the write buffer of the underlying transport a chance to be flushed.""" """Let the write buffer of the underlying transport a chance to be flushed."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_peer_info(self) -> tuple[str, int] | None: def get_peer_info(self) -> tuple[str, int] | None:
"""Return peer socket info (remote address and remote port as tuple).""" """Return peer socket info (remote address and remote port as tuple)."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def close(self) -> None: async def close(self) -> None:
"""Close the protocol connection.""" """Close the protocol connection."""
raise NotImplementedError raise NotImplementedError
@ -81,6 +88,10 @@ class WebSocketsReader(ReaderAdapter):
buffer.extend(message) buffer.extend(message)
self._stream = io.BytesIO(buffer) self._stream = io.BytesIO(buffer)
def feed_eof(self) -> None:
# NOTE: not implemented?!
pass
class WebSocketsWriter(WriterAdapter): class WebSocketsWriter(WriterAdapter):
"""WebSockets API writer adapter. """WebSockets API writer adapter.
@ -182,6 +193,10 @@ class BufferReader(ReaderAdapter):
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
return self._stream.read(n) return self._stream.read(n)
def feed_eof(self) -> None:
# NOTE: not implemented?!
pass
class BufferWriter(WriterAdapter): class BufferWriter(WriterAdapter):
"""ByteBuffer writer adapter. """ByteBuffer writer adapter.

Wyświetl plik

@ -61,6 +61,7 @@ class RetainedApplicationMessage(ApplicationMessage):
__slots__ = ("data", "qos", "source_session", "topic") __slots__ = ("data", "qos", "source_session", "topic")
def __init__(self, source_session: Session | None, topic: str, data: bytes, qos: int | None = None) -> None: def __init__(self, source_session: Session | None, topic: str, data: bytes, qos: int | None = None) -> None:
super().__init__(None, topic, qos, data, True) # noqa: FBT003
self.source_session = source_session self.source_session = source_session
self.topic = topic self.topic = topic
self.data = data self.data = data
@ -267,13 +268,13 @@ class Broker:
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}" msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
raise BrokerError(msg) from ke raise BrokerError(msg) from ke
except FileNotFoundError as fnfe: except FileNotFoundError as fnfe:
msg = "Can't read cert files '{}' or '{}' : {}".format(listener["certfile"], listener["keyfile"], fnfe) msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
raise BrokerError(msg) from fnfe raise BrokerError(msg) from fnfe
try: try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e: except ValueError as e:
msg = "Invalid port value in bind value: {}".format(listener["bind"]) msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e raise BrokerError(msg) from e
instance: asyncio.Server | websockets.asyncio.server.Server | None = None instance: asyncio.Server | websockets.asyncio.server.Server | None = None
@ -358,34 +359,37 @@ class Broker:
await server.acquire_connection() await server.acquire_connection()
remote_info = writer.get_peer_info() remote_info = writer.get_peer_info()
if remote_info is not None: if remote_info is None:
remote_address, remote_port = remote_info self.logger.warning("remote info could not get from peer info")
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'") return
# Wait for first packet and expect a CONNECT remote_address, remote_port = remote_info
try: self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
except AMQTTError as exc: # Wait for first packet and expect a CONNECT
self.logger.warning( try:
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:" handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
f"Can't read first packet an CONNECT: {exc}", except AMQTTError as exc:
) self.logger.warning(
# await writer.close() f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
self.logger.debug("Connection closed") f"Can't read first packet an CONNECT: {exc}",
server.release_connection() )
return # await writer.close()
except MQTTError: self.logger.debug("Connection closed")
self.logger.exception( server.release_connection()
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}", return
) except MQTTError:
await writer.close() self.logger.exception(
server.release_connection() f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
self.logger.debug("Connection closed") )
return await writer.close()
except NoDataError as ne: server.release_connection()
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails self.logger.debug("Connection closed")
server.release_connection() return
return except NoDataError as ne:
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails
server.release_connection()
return
if client_session.clean_session: if client_session.clean_session:
# Delete existing session and create a new one # Delete existing session and create a new one
@ -397,7 +401,7 @@ class Broker:
# Get session from cache # Get session from cache
elif client_session.client_id in self._sessions: elif client_session.client_id in self._sessions:
self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}") self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
(client_session, h) = self._sessions[client_session.client_id] client_session, _ = self._sessions[client_session.client_id]
client_session.parent = 1 client_session.parent = 1
else: else:
client_session.parent = 0 client_session.parent = 0
@ -453,7 +457,7 @@ class Broker:
connected = True connected = True
while connected: while connected:
try: try:
done, pending = await asyncio.wait( done, _ = await asyncio.wait(
[ [
disconnect_waiter, disconnect_waiter,
subscribe_waiter, subscribe_waiter,
@ -751,9 +755,9 @@ class Broker:
try: try:
task.result() task.result()
except CancelledError: except CancelledError:
self.logger.info("Task has been cancelled: %s", task) self.logger.info(f"Task has been cancelled: {task}")
except Exception: except Exception:
self.logger.exception("Task failed and will be skipped: %s", task) self.logger.exception(f"Task failed and will be skipped: {task}")
run_broadcast_task = asyncio.ensure_future(self._run_broadcast(running_tasks)) run_broadcast_task = asyncio.ensure_future(self._run_broadcast(running_tasks))

Wyświetl plik

@ -230,7 +230,7 @@ class MQTTClient:
raise ConnectError(msg) from e raise ConnectError(msg) from e
except Exception as e: except Exception as e:
self.logger.warning(f"Reconnection attempt failed: {e}") self.logger.warning(f"Reconnection attempt failed: {e}")
if reconnect_retries >= 0 and nb_attempt > reconnect_retries: if reconnect_retries < nb_attempt: # reconnect_retries >= 0 and
self.logger.exception("Maximum connection attempts reached. Reconnection aborted.") self.logger.exception("Maximum connection attempts reached. Reconnection aborted.")
msg = "Too many failed attempts" msg = "Too many failed attempts"
raise ConnectError(msg) from e raise ConnectError(msg) from e

Wyświetl plik

@ -135,10 +135,10 @@ class ProtocolHandler:
await self.writer.close() await self.writer.close()
except asyncio.CancelledError: except asyncio.CancelledError:
self.logger.debug("Writer close was cancelled.") self.logger.debug("Writer close was cancelled.")
except TimeoutError as e:
self.logger.debug(f"Writer close operation timed out: {e}.")
except OSError as e: except OSError as e:
self.logger.debug(f"Writer close failed due to I/O error: {e}") self.logger.debug(f"Writer close failed due to I/O error: {e}")
except TimeoutError:
self.logger.debug("Writer close operation timed out.")
def _stop_waiters(self) -> None: def _stop_waiters(self) -> None:
self.logger.debug(f"Stopping {len(self._puback_waiters)} puback waiters") self.logger.debug(f"Stopping {len(self._puback_waiters)} puback waiters")
@ -339,8 +339,10 @@ class ProtocolHandler:
if app_message.pubrel_packet and app_message.pubcomp_packet: if app_message.pubrel_packet and app_message.pubcomp_packet:
msg = f"Message '{app_message.packet_id}' has already been acknowledged" msg = f"Message '{app_message.packet_id}' has already been acknowledged"
raise AMQTTError(msg) raise AMQTTError(msg)
if not app_message.pubrel_packet: if not app_message.pubrel_packet:
# Store message # Store message
publish_packet: PublishPacket
if app_message.publish_packet is not None: if app_message.publish_packet is not None:
# This is a retry flow, no need to store just check the message exists in session # This is a retry flow, no need to store just check the message exists in session
if app_message.packet_id not in self.session.inflight_out: if app_message.packet_id not in self.session.inflight_out:
@ -353,9 +355,11 @@ class ProtocolHandler:
publish_packet = app_message.build_publish_packet() publish_packet = app_message.build_publish_packet()
else: else:
self.logger.debug("Message can not be stored, to be checked!") self.logger.debug("Message can not be stored, to be checked!")
# Send PUBLISH packet # Send PUBLISH packet
await self._send_packet(publish_packet) await self._send_packet(publish_packet)
app_message.publish_packet = publish_packet app_message.publish_packet = publish_packet
# Wait PUBREC # Wait PUBREC
if app_message.packet_id in self._pubrec_waiters: if app_message.packet_id in self._pubrec_waiters:
# PUBREC waiter already exists for this packet ID # PUBREC waiter already exists for this packet ID
@ -369,6 +373,7 @@ class ProtocolHandler:
finally: finally:
self._pubrec_waiters.pop(app_message.packet_id, None) self._pubrec_waiters.pop(app_message.packet_id, None)
self.session.inflight_out.pop(app_message.packet_id, None) self.session.inflight_out.pop(app_message.packet_id, None)
if not app_message.pubcomp_packet: if not app_message.pubcomp_packet:
# Send pubrel # Send pubrel
app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id) app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id)
@ -552,8 +557,9 @@ class ProtocolHandler:
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session) await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
except (ConnectionResetError, BrokenPipeError): except (ConnectionResetError, BrokenPipeError):
await self.handle_connection_closed() await self.handle_connection_closed()
except asyncio.CancelledError: except asyncio.CancelledError as e:
raise msg = "Packet handling was cancelled"
raise ProtocolHandlerError(msg) from e
except Exception as e: except Exception as e:
self.logger.warning(f"Unhandled exception: {e}") self.logger.warning(f"Unhandled exception: {e}")
raise raise
@ -606,7 +612,7 @@ class ProtocolHandler:
raise AMQTTError(msg) raise AMQTTError(msg)
self.logger.debug(f"{self.session.client_id} SUBSCRIBE unhandled") self.logger.debug(f"{self.session.client_id} SUBSCRIBE unhandled")
async def handle_unsubscribe(self, subscribe: UnsubscribePacket) -> None: async def handle_unsubscribe(self, unsubscribe: UnsubscribePacket) -> None:
if self.session is None: if self.session is None:
msg = "Session is not initialized." msg = "Session is not initialized."
raise AMQTTError(msg) raise AMQTTError(msg)

Wyświetl plik

@ -60,7 +60,7 @@ class FileAuthPlugin(BaseAuthPlugin):
return return
try: try:
with Path(password_file).open("r") as file: with Path(password_file).open(mode="r", encoding="utf-8") as file:
self.context.logger.debug(f"Reading user database from {password_file}") self.context.logger.debug(f"Reading user database from {password_file}")
for _line in file: for _line in file:
line = _line.strip() line = _line.strip()

Wyświetl plik

@ -122,7 +122,7 @@ class PluginManager:
def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]: def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]:
return asyncio.ensure_future(coro) return asyncio.ensure_future(coro)
async def fire_event(self, event_name: str, wait: bool = False, *args: Any, **kwargs: Any) -> None: async def fire_event(self, event_name: str, *args: Any, wait: bool = False, **kwargs: Any) -> None:
"""Fire an event to plugins. """Fire an event to plugins.
PluginManager schedules async calls for each plugin on method called "on_" + event_name. PluginManager schedules async calls for each plugin on method called "on_" + event_name.

Wyświetl plik

@ -36,8 +36,8 @@ class TopicTabooPlugin(BaseTopicPlugin):
class TopicAccessControlListPlugin(BaseTopicPlugin): class TopicAccessControlListPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None: # def __init__(self, context: BaseContext) -> None:
super().__init__(context) # super().__init__(context)
@staticmethod @staticmethod
def topic_ac(topic_requested: str, topic_allowed: str) -> bool: def topic_ac(topic_requested: str, topic_allowed: str) -> bool:
@ -79,11 +79,10 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
if username is None: if username is None:
username = "anonymous" username = "anonymous"
if self.topic_config is None: acl: dict[str, Any] = {}
acl: dict[str, Any] = {} if self.topic_config is not None and action == Action.PUBLISH:
elif action == Action.PUBLISH:
acl = self.topic_config.get("publish-acl", {}) acl = self.topic_config.get("publish-acl", {})
elif action == Action.SUBSCRIBE: elif self.topic_config is not None and action == Action.SUBSCRIBE:
acl = self.topic_config.get("acl", {}) acl = self.topic_config.get("acl", {})
allowed_topics = acl.get(username, None) allowed_topics = acl.get(username, None)

Wyświetl plik

@ -35,7 +35,10 @@ from collections.abc import Generator
import contextlib import contextlib
import json import json
import logging import logging
import os
from pathlib import Path from pathlib import Path
import socket
import sys
from typing import Any from typing import Any
from docopt import docopt from docopt import docopt
@ -49,9 +52,6 @@ logger = logging.getLogger(__name__)
def _gen_client_id() -> str: def _gen_client_id() -> str:
import os
import socket
pid = os.getpid() pid = os.getpid()
hostname = socket.gethostname() hostname = socket.gethostname()
return f"amqtt_pub/{pid}-{hostname}" return f"amqtt_pub/{pid}-{hostname}"
@ -79,20 +79,16 @@ def _get_message(arguments: dict[str, Any]) -> Generator[bytes | bytearray]:
yield arguments["-m"].encode(encoding="utf-8") yield arguments["-m"].encode(encoding="utf-8")
if arguments["-f"]: if arguments["-f"]:
try: try:
with Path(arguments["-f"]).open() as f: with Path(arguments["-f"]).open(encoding="utf-8") as f:
for line in f: for line in f:
yield line.encode(encoding="utf-8") yield line.encode(encoding="utf-8")
except Exception: except Exception:
logger.exception(f"Failed to read file '{arguments['-f']}'") logger.exception(f"Failed to read file '{arguments['-f']}'")
if arguments["-l"]: if arguments["-l"]:
import sys
for line in sys.stdin: for line in sys.stdin:
if line: if line:
yield line.encode(encoding="utf-8") yield line.encode(encoding="utf-8")
if arguments["-s"]: if arguments["-s"]:
import sys
message = bytearray() message = bytearray()
for line in sys.stdin: for line in sys.stdin:
message.extend(line.encode(encoding="utf-8")) message.extend(line.encode(encoding="utf-8"))

Wyświetl plik

@ -31,7 +31,9 @@ import asyncio
import contextlib import contextlib
import json import json
import logging import logging
import os
from pathlib import Path from pathlib import Path
import socket
import sys import sys
from typing import Any from typing import Any
@ -47,9 +49,6 @@ logger = logging.getLogger(__name__)
def _gen_client_id() -> str: def _gen_client_id() -> str:
import os
import socket
pid = os.getpid() pid = os.getpid()
hostname = socket.gethostname() hostname = socket.gethostname()
return f"amqtt_sub/{pid}-{hostname}" return f"amqtt_sub/{pid}-{hostname}"

Wyświetl plik

@ -43,25 +43,23 @@ def get_git_changeset() -> str | None:
# Call git log to get the latest changeset timestamp # Call git log to get the latest changeset timestamp
try: try:
git_log = subprocess.Popen( # noqa: S603 with subprocess.Popen( # noqa: S603
[git_path, "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"], [git_path, "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
cwd=repo_dir, cwd=repo_dir,
universal_newlines=True, universal_newlines=True,
) ) as git_log:
timestamp_str, stderr = git_log.communicate()
# Capture the output if git_log.returncode != 0:
timestamp_str, stderr = git_log.communicate() logger.error(f"Git command failed with error: {stderr}")
return None
if git_log.returncode != 0: # Convert the timestamp to a datetime object
logger.error(f"Git command failed with error: {stderr}") timestamp = datetime.datetime.fromtimestamp(int(timestamp_str), tz=datetime.UTC)
return None return timestamp.strftime("%Y%m%d%H%M%S")
# Convert the timestamp to a datetime object
timestamp = datetime.datetime.fromtimestamp(int(timestamp_str), tz=datetime.UTC)
return timestamp.strftime("%Y%m%d%H%M%S")
except Exception: except Exception:
logger.exception("An error occurred while retrieving the git changeset.") logger.exception("An error occurred while retrieving the git changeset.")
return None return None

Wyświetl plik

@ -169,7 +169,7 @@ max-complexity = 42
[tool.ruff.lint.pylint] [tool.ruff.lint.pylint]
max-args = 12 max-args = 12
max-branches = 42 max-branches = 42
max-statements = 142 max-statements = 143
max-returns = 10 max-returns = 10
# ----------------------------------- PYTEST ----------------------------------- # ----------------------------------- PYTEST -----------------------------------
@ -239,6 +239,14 @@ disable = [
"missing-function-docstring", "missing-function-docstring",
"missing-class-docstring", "missing-class-docstring",
"unused-argument", "unused-argument",
"protected-access",
"line-too-long",
"too-many-branches",
"too-many-statements",
"too-many-nested-blocks",
"too-many-public-methods",
"invalid-name",
"redefined-slots-in-subclass",
] ]
# enable useless-suppression temporarily every now and then to clean them up # enable useless-suppression temporarily every now and then to clean them up
enable = [ enable = [
@ -265,4 +273,4 @@ max-branches = 20 # too-many-branches
max-parents = 10 max-parents = 10
max-positional-arguments = 10 # too-many-positional-arguments max-positional-arguments = 10 # too-many-positional-arguments
max-returns = 7 max-returns = 7
max-statements = 60 # too-many-statements max-statements = 61 # too-many-statements

Wyświetl plik

@ -39,11 +39,11 @@ class TestVersionFunctions(unittest.TestCase):
# Mock git executable check # Mock git executable check
mock_which.return_value = True mock_which.return_value = True
# Mock subprocess.Popen for git log # Mock subprocess.Popen for git log with context manager behavior
mock_process = MagicMock() mock_process = MagicMock()
mock_process.communicate.return_value = ("1638352940", "") mock_process.communicate.return_value = ("1638352940", "")
mock_process.returncode = 0 mock_process.returncode = 0
mock_popen.return_value = mock_process mock_popen.return_value.__enter__.return_value = mock_process
# Call the function # Call the function
changeset = get_git_changeset() changeset = get_git_changeset()