code quality improvements (#293)

* add additional linting rules
pull/276/head
Andrew Mirsky 2025-08-10 22:00:33 -04:00 zatwierdzone przez GitHub
rodzic b4d58c9130
commit 8e47ede192
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
50 zmienionych plików z 209 dodań i 191 usunięć

Wyświetl plik

@ -139,7 +139,7 @@ class BrokerContext(BaseContext):
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
return self._broker_instance.subscriptions
async def add_subscription(self, client_id: str, topic: str|None, qos: int|None) -> None:
async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
"""Create a topic subscription for the given `client_id`.
If a client session doesn't exist for `client_id`, create a disconnected session.
@ -325,7 +325,6 @@ class Broker:
ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener."""
match listener_type:
case ListenerType.TCP:
return await asyncio.start_server(
@ -356,17 +355,17 @@ class Broker:
session_count_before = len(self._sessions)
# clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
sessions_to_remove = [ client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session) ]
sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
# if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
if self.config.session_expiry_interval is not None:
retain_after = floor(time.time() - self.config.session_expiry_interval)
sessions_to_remove += [ client_id for client_id, (session, _) in self._sessions.items()
sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and
session.last_disconnect_time and
session.last_disconnect_time < retain_after ]
session.last_disconnect_time < retain_after]
for client_id in sessions_to_remove:
await self._cleanup_session(client_id)
@ -586,9 +585,7 @@ class Broker:
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
for topic in self._subscriptions:
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler)
@ -620,7 +617,6 @@ class Broker:
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
if subscribe_waiter in done:
await self._handle_subscription(client_session, handler, subscribe_waiter)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
@ -690,7 +686,6 @@ class Broker:
client_id=client_session.client_id,
client_session=client_session)
async def _handle_subscription(
self,
client_session: Session,
@ -808,7 +803,7 @@ class Broker:
"""
returns = await self.plugins_manager.map_plugin_auth(session=session)
results = [ result for _, result in returns.items() if result is not None] if returns else []
results = [result for _, result in returns.items() if result is not None] if returns else []
if len(results) < 1:
self.logger.debug("Authentication failed: no plugin responded with a boolean")
return False

Wyświetl plik

@ -424,14 +424,10 @@ class MQTTClient:
scheme = uri_attributes.scheme
secure = scheme in ("mqtts", "wss")
self.session.username = (
self.session.username
if self.session.username
else (str(uri_attributes.username) if uri_attributes.username else None)
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
)
self.session.password = (
self.session.password
if self.session.password
else (str(uri_attributes.password) if uri_attributes.password else None)
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
)
self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
self.session.remote_port = uri_attributes.port

Wyświetl plik

@ -142,8 +142,8 @@ def int_to_bytes_str(value: int) -> bytes:
return str(value).encode("utf-8")
def float_to_bytes_str(value: float, places:int=3) -> bytes:
def float_to_bytes_str(value: float, places: int = 3) -> bytes:
"""Convert an float value to a bytes array containing the numeric character."""
quant = Decimal(f"0.{''.join(['0' for i in range(places-1)])}1")
quant = Decimal(f"0.{''.join(['0' for i in range(places - 1)])}1")
rounded = Decimal(value).quantize(quant, rounding=ROUND_HALF_UP)
return str(rounded).encode("utf-8")

Wyświetl plik

@ -6,7 +6,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from collections.abc import Iterator
@ -50,14 +50,15 @@ class ListenerType(StrEnum):
"""Display the string value, instead of the enum member."""
return f'"{self.value!s}"'
class Dictable:
"""Add dictionary methods to a dataclass."""
def __getitem__(self, key:str) -> Any:
def __getitem__(self, key: str) -> Any:
"""Allow dict-style `[]` access to a dataclass."""
return self.get(key)
def get(self, name:str, default:Any=None) -> Any:
def get(self, name: str, default: Any = None) -> Any:
"""Allow dict-style access to a dataclass."""
name = name.replace("-", "_")
if hasattr(self, name):
@ -148,10 +149,10 @@ def default_listeners() -> dict[str, Any]:
def default_broker_plugins() -> dict[str, Any]:
"""Create defaults for BrokerConfig.plugins."""
return {
"amqtt.plugins.logging_amqtt.EventLoggerPlugin":{},
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{},
"amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True},
"amqtt.plugins.sys.broker.BrokerSysPlugin":{"sys_interval":20}
"amqtt.plugins.logging_amqtt.EventLoggerPlugin": {},
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {},
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 20}
}
@ -159,7 +160,7 @@ def default_broker_plugins() -> dict[str, Any]:
class BrokerConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
does not set a value, the 'default' settings are applied. See
[`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
@ -178,7 +179,7 @@ class BrokerConfig(Dictable):
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
[`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
plugins: dict[str, Any] | list[str | dict[str,Any]] | None = field(default_factory=default_broker_plugins)
plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
@ -203,7 +204,7 @@ class BrokerConfig(Dictable):
for plugin in self.plugins:
# in case a plugin in a yaml file is listed without config map
if isinstance(plugin, str):
_plugins |= {plugin:{}}
_plugins |= {plugin: {}}
continue
_plugins |= plugin
self.plugins = _plugins
@ -263,6 +264,7 @@ class ConnectionConfig(Dictable):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published.
@ -305,7 +307,7 @@ class WillConfig(Dictable):
def default_client_plugins() -> dict[str, Any]:
"""Create defaults for `ClientConfig.plugins`."""
return {
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {}
}

Wyświetl plik

@ -36,9 +36,11 @@ class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
if value is None:
return None
return [self.dataclass_type(**item) for item in value]
def process_literal_param(self, value: Any, dialect: Any) -> Any:
# Required by SQLAlchemy, typically used for literal SQL rendering.
return value
@property
def python_type(self) -> type:
# Required by TypeEngine to indicate the expected Python type.

Wyświetl plik

@ -8,7 +8,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin
@ -39,7 +39,7 @@ _db_map = {
}
def db_connection_str(db_type: DBType, db_username: str, db_host:str, db_port: int|None, db_filename: str) -> str:
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
"""Create sqlalchemy database connection string."""
db_info = _db_map[db_type]
if db_type == DBType.SQLITE:

Wyświetl plik

@ -52,7 +52,7 @@ class UserManager:
raise MQTTError(msg)
return users
async def create_user_auth(self, username: str, plain_password:str) -> UserAuth | None:
async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
"""Create a new user."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(UserAuth).filter(UserAuth.username == username)

Wyświetl plik

@ -18,11 +18,12 @@ logger = logging.getLogger(__name__)
matcher = TopicMatcher()
@dataclass
class AllowedTopic:
topic: str
def __contains__(self, item: Union[str,"AllowedTopic"]) -> bool:
def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
"""Determine `in`."""
return self.__eq__(item)
@ -43,6 +44,7 @@ class AllowedTopic:
"""Display topic."""
return self.topic
class PasswordHasher:
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""

Wyświetl plik

@ -66,6 +66,7 @@ class UserAuthDBPlugin(BaseAuthPlugin):
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
"""list of hash schemes to use for passwords"""
class TopicAuthDBPlugin(BaseTopicPlugin):
def __init__(self, context: BrokerContext) -> None:
@ -83,7 +84,7 @@ class TopicAuthDBPlugin(BaseTopicPlugin):
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool|None:
) -> bool | None:
if not session or not session.username or not topic:
return None

Wyświetl plik

@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
topic_app = typer.Typer(no_args_is_help=True)
@topic_app.callback()
def main(
ctx: typer.Context,

Wyświetl plik

@ -89,8 +89,9 @@ class CertificateAuthPlugin(BaseAuthPlugin):
uri_domain: str
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
def generate_root_creds(country:str, state:str, locality:str,
org_name:str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
def generate_root_creds(country: str, state: str, locality: str,
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Generate CA key and certificate."""
# generate private key for the server
ca_key = rsa.generate_private_key(
@ -143,7 +144,7 @@ def generate_root_creds(country:str, state:str, locality:str,
return ca_key, cert
def generate_server_csr(country:str, org_name: str, cn:str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
def generate_server_csr(country: str, org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
"""Generate server private key and server certificate-signing-request."""
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
@ -167,7 +168,6 @@ def generate_server_csr(country:str, org_name: str, cn:str) -> tuple[rsa.RSAPriv
return key, csr
def generate_device_csr(country: str, org_name: str, common_name: str,
uri_san: str, dns_san: str
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
@ -194,9 +194,10 @@ def generate_device_csr(country: str, org_name: str, common_name: str,
return key, csr
def sign_csr(csr: CertificateSigningRequest,
ca_key: rsa.RSAPrivateKey,
ca_cert: Certificate, validity_days: int=365) -> Certificate:
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
"""Sign a csr with CA credentials."""
return (
x509.CertificateBuilder()
@ -221,7 +222,8 @@ def sign_csr(csr: CertificateSigningRequest,
.sign(ca_key, hashes.SHA256())
)
def load_ca(ca_key_fn:str, ca_crt_fn:str) -> tuple[rsa.RSAPrivateKey, Certificate]:
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Load server key and certificate."""
with Path(ca_key_fn).open("rb") as f:
ca_key: rsa.RSAPrivateKey = serialization.load_pem_private_key(f.read(), password=None) # type: ignore[assignment]
@ -230,8 +232,8 @@ def load_ca(ca_key_fn:str, ca_crt_fn:str) -> tuple[rsa.RSAPrivateKey, Certificat
return ca_key, ca_cert
def write_key_and_crt(key:rsa.RSAPrivateKey, crt:Certificate,
prefix:str, path: Path | None = None) -> None:
def write_key_and_crt(key: rsa.RSAPrivateKey, crt: Certificate,
prefix: str, path: Path | None = None) -> None:
"""Create pem-encoded files for key and certificate."""
path = path or Path()

Wyświetl plik

@ -5,7 +5,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
import logging
@ -26,6 +26,7 @@ class ResponseMode(StrEnum):
JSON = "json"
TEXT = "text"
class RequestMethod(StrEnum):
GET = "get"
POST = "post"
@ -76,7 +77,7 @@ class AuthHttpPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.http = ClientSession(headers = {"User-Agent": self.config.user_agent})
self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
match self.config.request_method:
case RequestMethod.GET:
@ -102,15 +103,15 @@ class AuthHttpPlugin(BasePlugin[BrokerContext]):
case ParamsMode.FORM:
match self.config.request_method:
case RequestMethod.GET:
kwargs = { "params": payload }
case _: # POST, PUT
kwargs = {"params": payload}
case _: # POST, PUT
d: Any = FormData(payload)
kwargs = {"data": d}
case _: # JSON
kwargs = { "json": payload}
kwargs = {"json": payload}
return kwargs
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool|None: # pylint: disable=R0911
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911
kwargs = self._get_params(payload)
@ -131,7 +132,7 @@ class AuthHttpPlugin(BasePlugin[BrokerContext]):
if not self._is_2xx(r):
return False
data: dict[str, Any] = await r.json()
data = {k.lower():v for k,v in data.items()}
data = {k.lower(): v for k, v in data.items()}
return data.get("ok", None)
def get_url(self, uri: str) -> str:

Wyświetl plik

@ -9,7 +9,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from amqtt.broker import BrokerContext

Wyświetl plik

@ -8,7 +8,7 @@ from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.errors import PluginInitError
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin, BasePlugin
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
@ -42,6 +42,7 @@ class AuthLdapPlugin(BasePlugin[BrokerContext]):
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
raise PluginInitError(self.__class__) from e
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
@ -76,6 +77,7 @@ class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
class Config(LdapConfig):
"""Configuration for the User Auth LDAP Plugin."""
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
@ -107,7 +109,6 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
]
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
if not results:
logger.debug(f"user not found: {session.username}")
return False
@ -120,12 +121,11 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
dn, entry = results[0]
ldap_attribute = getattr(self.config, self._action_attr_map[action])
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
logger.debug(f"DN: {dn} - {ldap_attribute}={topic_filters}")
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
@dataclass
class Config(LdapConfig):
"""Configuration for the LDAPAuthPlugin."""

Wyświetl plik

@ -82,7 +82,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
@staticmethod
async def _get_or_create_session(db_session: AsyncSession, client_id:str) -> StoredSession:
async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
stored_session = await db_session.scalar(stmt)
@ -92,9 +92,8 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
await db_session.flush()
return stored_session
@staticmethod
async def _get_or_create_message(db_session: AsyncSession, topic:str) -> StoredMessage:
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
stored_message = await db_session.scalar(stmt)
@ -104,8 +103,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
await db_session.flush()
return stored_message
async def on_broker_client_connected(self, client_id:str, client_session:Session) -> None:
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
"""Search to see if session already exists."""
# if client id doesn't exist, create (can ignore if session is anonymous)
# update session information (will, clean_session, etc)
@ -240,7 +238,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
restored_messages += 1
logger.info(f"Retained messages restored: {restored_messages}")
logger.info(f"Restored {restored_sessions} sessions.")
async def on_broker_pre_shutdown(self) -> None:

Wyświetl plik

@ -35,6 +35,7 @@ class ShadowMessage:
def to_message(self) -> bytes:
return json.dumps(asdict_no_none(self)).encode("utf-8")
@dataclass
class GetAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
@ -46,6 +47,7 @@ class GetAcceptedMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
@dataclass
class GetRejectedMessage(ShadowMessage):
code: int
@ -56,6 +58,7 @@ class GetRejectedMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
@dataclass
class UpdateAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
@ -78,6 +81,7 @@ class UpdateRejectedMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
@dataclass
class UpdateDeltaMessage(ShadowMessage):
state: MutableMapping[str, Any]
@ -89,6 +93,7 @@ class UpdateDeltaMessage(ShadowMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
class UpdateIotaMessage(UpdateDeltaMessage):
"""Same format, corollary name."""
@ -96,6 +101,7 @@ class UpdateIotaMessage(UpdateDeltaMessage):
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
@dataclass
class UpdateDocumentMessage(ShadowMessage):
previous: StateDocument

Wyświetl plik

@ -40,7 +40,7 @@ class Shadow(ShadowBase):
device_id: Mapped[str] = mapped_column(String(128), nullable=False)
name: Mapped[str] = mapped_column(String(128), nullable=False)
version: Mapped[int] =mapped_column(Integer, nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
_state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)
@ -106,7 +106,7 @@ def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow")
@event.listens_for(Session, "before_flush")
def convert_update_to_insert(session: Session, _flush_context: object, _instances:object | None) -> None:
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
"""Force a shadow to insert a new version, instead of updating an existing."""
# Make a copy of the dirty set so we can safely mutate the session
dirty = list(session.dirty)
@ -127,6 +127,7 @@ def convert_update_to_insert(session: Session, _flush_context: object, _instance
session.add(obj) # re-add as new object
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:

Wyświetl plik

@ -28,7 +28,7 @@ from amqtt.session import ApplicationMessage, Session
shadow_topic_re = re.compile(r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)")
DeviceID= str
DeviceID = str
ShadowName = str
@ -43,6 +43,7 @@ def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
"""Nested defaultdict for shadow cache."""
return defaultdict(shadow_dict) # type: ignore[arg-type]
class ShadowPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
@ -52,7 +53,6 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
self._engine = create_async_engine(self.config.connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def on_broker_pre_start(self) -> None:
"""Sync the schema."""
async with self._engine.begin() as conn:
@ -83,8 +83,8 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
accept_msg = GetAcceptedMessage(
state=shadow.state.state,
metadata=shadow.state.metadata,
timestamp= shadow.created_at,
version= shadow.version
timestamp=shadow.created_at,
version=shadow.version
)
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
@ -98,7 +98,7 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
prev_state = shadow.state or StateDocument()
prev_state.version = shadow.version or 0 # only required when generating shadow messages
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
next_state = prev_state + state_update

Wyświetl plik

@ -7,7 +7,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
import time
from typing import Any, Generic, TypeVar
@ -16,10 +16,12 @@ from mergedeep import merge
C = TypeVar("C", bound=Any)
class StateError(Exception):
def __init__(self, msg: str = "'state' field is required") -> None:
super().__init__(msg)
@dataclass
class MetaTimestamp:
timestamp: int = 0
@ -58,7 +60,7 @@ class MetaTimestamp:
"""Convert timestamp to int."""
return int(self.timestamp)
def __lt__(self, other:int ) -> bool:
def __lt__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp < other
@ -127,6 +129,7 @@ def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMa
return delta
@dataclass
class State(Generic[C]):
desired: MutableMapping[str, C] = field(default_factory=dict)
@ -156,8 +159,8 @@ class State(Generic[C]):
class StateDocument:
state: State[dict[str, Any]] = field(default_factory=State)
metadata: State[MetaTimestamp] = field(default_factory=State)
version: int | None = None # only required when generating shadow messages
timestamp: int | None = None # only required when generating shadow messages
version: int | None = None # only required when generating shadow messages
timestamp: int | None = None # only required when generating shadow messages
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StateDocument":

Wyświetl plik

@ -16,10 +16,12 @@ class CodecError(Exception):
class NoDataError(Exception):
"""Exceptions thrown by packet encode/decode functions."""
class ZeroLengthReadError(NoDataError):
def __init__(self) -> None:
super().__init__("Decoding a string of length zero.")
class BrokerError(Exception):
"""Exceptions thrown by broker."""

Wyświetl plik

@ -3,7 +3,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass

Wyświetl plik

@ -167,7 +167,7 @@ class PacketIdVariableHeader(MQTTVariableHeader):
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)
class MQTTPayload(Generic[_VH], ABC):
class MQTTPayload(ABC, Generic[_VH]):
"""Abstract base class for MQTT payloads."""
async def to_stream(self, writer: asyncio.StreamWriter) -> None:

Wyświetl plik

@ -31,6 +31,7 @@ _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
if TYPE_CHECKING:
from amqtt.broker import BrokerContext
class Subscription:
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
self.packet_id = packet_id

Wyświetl plik

@ -19,6 +19,7 @@ from amqtt.session import Session
if TYPE_CHECKING:
from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
def __init__(
self,

Wyświetl plik

@ -4,13 +4,13 @@ try:
from asyncio import InvalidStateError, QueueFull, QueueShutDown
except ImportError:
# Fallback for Python < 3.12
class InvalidStateError(Exception): # type: ignore[no-redef]
class InvalidStateError(Exception): # type: ignore[no-redef]
pass
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
pass
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
pass
@ -63,6 +63,7 @@ from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplic
C = TypeVar("C", bound=BaseContext)
class ProtocolHandler(Generic[C]):
"""Class implementing the MQTT communication protocol using asyncio features."""
@ -199,7 +200,7 @@ class ProtocolHandler(Generic[C]):
async def mqtt_publish(
self,
topic: str,
data: bytes | bytearray ,
data: bytes | bytearray,
qos: int | None,
retain: bool,
ack_timeout: int | None = None,

Wyświetl plik

@ -46,7 +46,7 @@ class BasePlugin(Generic[C]):
return section_config
# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
@ -75,7 +75,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
if not bool(self.topic_config) and not is_dataclass(self.context.config):
self.context.logger.warning("'topic-check' section not found in context configuration")
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
@ -107,7 +107,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
@ -126,7 +126,6 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
# auth config section not found and Config dataclass not provided
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.

Wyświetl plik

@ -51,6 +51,7 @@ def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
C = TypeVar("C", bound=BaseContext)
class PluginManager(Generic[C]):
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
@ -97,7 +98,6 @@ class PluginManager(Generic[C]):
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
# plugins loaded directly from config dictionary
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
@ -147,7 +147,7 @@ class PluginManager(Generic[C]):
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
self._event_plugin_callbacks[event].append(awaitable)
def _load_ep_plugins(self, namespace:str) -> None:
def _load_ep_plugins(self, namespace: str) -> None:
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""
self.logger.debug(f"Loading plugins for namespace {namespace}")
auth_filter_list = []
@ -224,7 +224,7 @@ class PluginManager(Generic[C]):
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
"""Load plugin from string dotted path: mymodule.myfile.MyPlugin."""
try:
plugin_class: Any = import_string(plugin_path)
plugin_class: Any = import_string(plugin_path)
except ImportError as ep:
msg = f"Plugin import failed: {plugin_path}"
raise PluginImportError(msg) from ep
@ -377,7 +377,7 @@ class PluginManager(Generic[C]):
:return: dict containing return from coro call for each plugin.
"""
return await self._map_plugin_method(
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
self._auth_plugins, "authenticate", {"session": session}) # type: ignore[arg-type]
async def map_plugin_topic(
self, *, session: Session, topic: str, action: "Action"

Wyświetl plik

@ -14,7 +14,7 @@ except ImportError:
from typing import Protocol, runtime_checkable
@runtime_checkable
class Buffer(Protocol): # type: ignore[no-redef]
class Buffer(Protocol): # type: ignore[no-redef]
def __buffer__(self, flags: int = ...) -> memoryview:
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
@ -75,7 +75,6 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
self._sys_interval: int = 0
self._current_process = psutil.Process()
def _clear_stats(self) -> None:
"""Initialize broker statistics data structures."""
for stat in (

Wyświetl plik

@ -21,7 +21,7 @@ def main() -> None:
app()
def _version(v:bool) -> None:
def _version(v: bool) -> None:
if v:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@ -65,7 +65,7 @@ def broker_main(
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
raise typer.Exit(code=1) from exc
_ = loop.create_task(broker.start()) #noqa : RUF006
_ = loop.create_task(broker.start()) # noqa : RUF006
try:
loop.run_forever()
except KeyboardInterrupt:

Wyświetl plik

@ -13,12 +13,13 @@ def main() -> None:
"""Run the cli for `ca_creds`."""
app()
@app.command()
def ca_creds(
country:str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
state:str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
locality:str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
org_name:str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
state: str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
locality: str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
) -> None:

Wyświetl plik

@ -16,7 +16,7 @@ def main() -> None:
@app.command()
def device_creds( # pylint: disable=too-many-locals
def device_creds( # pylint: disable=too-many-locals
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
device_id: str = typer.Option(..., "--device-id", help="device id for the SAN"),
@ -59,5 +59,6 @@ def device_creds( # pylint: disable=too-many-locals
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
if __name__ == "__main__":
main()

Wyświetl plik

@ -9,12 +9,12 @@ logger = logging.getLogger(__name__)
def main() -> None:
"""Run the auth db cli."""
try:
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
except ImportError:
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
sys.exit(1)
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
try:
topic_app()
@ -32,5 +32,6 @@ def main() -> None:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()

Wyświetl plik

@ -9,12 +9,12 @@ logger = logging.getLogger(__name__)
def main() -> None:
"""Run the auth db cli."""
try:
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
except ImportError:
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
sys.exit(1)
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
try:
user_app()
except ModuleNotFoundError as mnfe:
@ -31,5 +31,6 @@ def main() -> None:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()

Wyświetl plik

@ -118,6 +118,7 @@ async def do_pub(
logger.fatal("Publish canceled due to previous error")
raise asyncio.CancelledError from ce
app = typer.Typer(add_completion=False, rich_markup_mode=None)
@ -131,8 +132,9 @@ def _version(v: bool) -> None:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@app.command()
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def publisher_main( # pylint: disable=R0914,R0917
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
@ -155,7 +157,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
will_retain: bool = typer.Option(False, "--will-retain", help="If the client disconnects unexpectedly the message sent out will be treated as a retained message. *only valid, if `--will-topic` is specified*"),
extra_headers_json: str | None = typer.Option(None, "--extra-headers", help="Specify a JSON object string with key-value pairs representing additional headers that are transmitted on the initial connection. *websocket connections only*."),
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
) -> None:
"""Command-line MQTT client for publishing simple messages."""
provided = [bool(message), bool(file), stdin, lines, no_message]

Wyświetl plik

@ -13,14 +13,15 @@ def main() -> None:
"""Run the `server_creds` cli."""
app()
@app.command()
def server_creds(
country:str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name:str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
ca_key_fn:str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
ca_crt_fn:str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
) -> None:
"""Generate a key and certificate for the broker in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa : E501
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"

Wyświetl plik

@ -100,17 +100,17 @@ def main() -> None:
app()
def _version(v:bool) -> None:
def _version(v: bool) -> None:
if v:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@app.command()
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def subscribe_main( # pylint: disable=R0914,R0917
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),

Wyświetl plik

@ -166,30 +166,40 @@ indent-style = "space"
docstring-code-format = true
[tool.ruff.lint]
preview = true
select = ["ALL"]
extend-select = [
"UP", # pyupgrade
"D", # pydocstyle
"UP", # pyupgrade
"D", # pydocstyle,
]
ignore = [
"FBT001", # Checks for the use of boolean positional arguments in function definitions.
"FBT002", # Checks for the use of boolean positional arguments in function definitions.
"G004", # Logging statement uses f-string
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D107", # Missing docstring in `__init__`
"D203", # Incorrect blank line before class (mutually exclusive D211)
"D213", # Multi-line summary second line (mutually exclusive D212)
"FIX002", # Checks for "TODO" comments.
"TD002", # TODO Missing author.
"TD003", # TODO Missing issue link for this TODO.
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument
"PERF203",# try-except penalty within loops (3.10 only),
"COM812" # rule causes conflicts when used with the formatter
"FBT001", # Checks for the use of boolean positional arguments in function definitions.
"FBT002", # Checks for the use of boolean positional arguments in function definitions.
"G004", # Logging statement uses f-string
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D107", # Missing docstring in `__init__`
"D203", # Incorrect blank line before class (mutually exclusive D211)
"D213", # Multi-line summary second line (mutually exclusive D212)
"FIX002", # Checks for "TODO" comments.
"TD002", # TODO Missing author.
"TD003", # TODO Missing issue link for this TODO.
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument
"PERF203",# try-except penalty within loops (3.10 only),
"COM812", # rule causes conflicts when used with the formatter,
# ignore certain preview rules
"DOC",
"PLW",
"PLR",
"CPY",
"PLC",
"RUF052",
"B903"
]
[tool.ruff.lint.per-file-ignores]

Wyświetl plik

@ -1,6 +1,5 @@
import asyncio
import logging
import os
from pathlib import Path
from amqtt.broker import Broker
@ -24,13 +23,13 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
'amqtt.plugins.authentication.FileAuthPlugin': {
'password_file': Path(__file__).parent / 'passwd',
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
"amqtt.plugins.authentication.FileAuthPlugin": {
"password_file": Path(__file__).parent / "passwd",
},
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
'amqtt.plugins.topic_checking.TopicAccessControlListPlugin': {
'acl': {
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
"amqtt.plugins.topic_checking.TopicAccessControlListPlugin": {
"acl": {
# username: [list of allowed topics]
"test": ["repositories/+/master", "calendar/#", "data/memes"],
"anonymous": [],

Wyświetl plik

@ -1,8 +1,6 @@
import asyncio
import logging
import os
from dataclasses import dataclass
from pathlib import Path
import logging
from amqtt.broker import Broker
from amqtt.plugins.base import BasePlugin
@ -18,7 +16,7 @@ logger = logging.getLogger(__name__)
class RemoteInfoPlugin(BasePlugin):
async def on_broker_client_connected(self, *, client_id:str, client_session:Session) -> None:
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ''
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ""
logger.info(f"client '{client_id}' connected from"
f" '{client_session.remote_address}' {display_port_str}")
@ -40,8 +38,8 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
'samples.broker_custom_plugin.RemoteInfoPlugin': { 'display_port': True },
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
"samples.broker_custom_plugin.RemoteInfoPlugin": { "display_port": True },
}
}

Wyświetl plik

@ -1,6 +1,6 @@
import asyncio
import logging
from asyncio import CancelledError
import logging
from amqtt.broker import Broker

Wyświetl plik

@ -1,6 +1,5 @@
import asyncio
import logging
import os
from pathlib import Path
from amqtt.broker import Broker
@ -24,11 +23,11 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
'amqtt.plugins.authentication.FileAuthPlugin': {
'password_file': Path(__file__).parent / 'passwd',
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
"amqtt.plugins.authentication.FileAuthPlugin": {
"password_file": Path(__file__).parent / "passwd",
},
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
}
}

Wyświetl plik

@ -1,6 +1,5 @@
import asyncio
import logging
import os
from pathlib import Path
from amqtt.broker import Broker
@ -24,12 +23,12 @@ config = {
},
},
"plugins": {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True},
'amqtt.plugins.authentication.FileAuthPlugin': {
'password_file': Path(__file__).parent / 'passwd',
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
"amqtt.plugins.authentication.FileAuthPlugin": {
"password_file": Path(__file__).parent / "passwd",
},
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 10},
'amqtt.plugins.topic_checking.TopicTabooPlugin': {},
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 10},
"amqtt.plugins.topic_checking.TopicTabooPlugin": {},
}
}

Wyświetl plik

@ -1,6 +1,6 @@
import asyncio
import logging
from asyncio import CancelledError
import logging
from amqtt.client import MQTTClient

Wyświetl plik

@ -35,7 +35,7 @@ async def test_coro1() -> None:
async def test_coro2() -> None:
try:
client = MQTTClient(config={'auto_reconnect': False, 'connection_timeout': 1})
client = MQTTClient(config={"auto_reconnect": False, "connection_timeout": 1})
await client.connect("mqtt://localhost:1884/")
await client.publish("a/b", b"TEST MESSAGE WITH QOS_0", qos=0x00)
await client.publish("a/b", b"TEST MESSAGE WITH QOS_1", qos=0x01)
@ -43,7 +43,7 @@ async def test_coro2() -> None:
logger.info("test_coro2 messages published")
await client.disconnect()
except ConnectError:
logger.info(f"Connection failed", exc_info=True)
logger.info("Connection failed", exc_info=True)
def __main__():

Wyświetl plik

@ -27,7 +27,7 @@ async def test_coro() -> None:
await client.publish("calendar/amqtt/releases", b"NEW RELEASE", qos=QOS_1)
logger.info("messages published")
await client.disconnect()
except ConnectError as ce:
except ConnectError:
logger.exception("ERROR: Connection failed")

Wyświetl plik

@ -1,7 +1,6 @@
import argparse
import asyncio
import logging
from pathlib import Path
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
@ -29,7 +28,7 @@ config = {
async def test_coro(certfile: str) -> None:
config['certfile'] = certfile
config["certfile"] = certfile
client = MQTTClient(config=config)
await client.connect("mqtts://localhost:8883")
@ -47,7 +46,7 @@ def __main__():
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
parser = argparse.ArgumentParser()
parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity")
parser.add_argument("--cert", default="cert.pem", help="path & file to verify server's authenticity")
args = parser.parse_args()
asyncio.run(test_coro(args.cert))

Wyświetl plik

@ -4,7 +4,6 @@ import logging
from amqtt.client import ClientError, MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
"""
This sample shows how to subscribe to different $SYS topics and how to receive incoming messages
"""
@ -13,7 +12,7 @@ logger = logging.getLogger(__name__)
async def uptime_coro() -> None:
client = MQTTClient(config={'auto_reconnect': False})
client = MQTTClient(config={"auto_reconnect": False})
await client.connect("mqtt://localhost:1883")
await client.subscribe(

Wyświetl plik

@ -35,7 +35,7 @@ async def uptime_coro() -> None:
await client.unsubscribe(["$SYS/#", "data/memes"])
logger.info("UnSubscribed")
await client.disconnect()
except ClientError as ce:
except ClientError:
logger.exception("Client exception")

Wyświetl plik

@ -9,14 +9,13 @@ from aiohttp import web
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import ConnectError
logger = logging.getLogger(__name__)
MQTT_LISTENER_NAME = "myMqttListener"
async def hello(request):
"""get request handler"""
"""Get request handler"""
return web.Response(text="Hello, world")
class WebSocketResponseReader(ReaderAdapter):
@ -27,17 +26,17 @@ class WebSocketResponseReader(ReaderAdapter):
self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes:
"""
read 'n' bytes from the datastream, if < 0 read all available bytes
"""Read 'n' bytes from the datastream, if < 0 read all available bytes
Raises:
BrokerPipeError : if reading on a closed websocket connection
"""
# continue until buffer contains at least the amount of data being requested
while not self.buffer or len(self.buffer) < n:
# if the websocket is closed
if self.ws.closed:
raise BrokenPipeError()
raise BrokenPipeError
try:
# read from stream
@ -46,10 +45,10 @@ class WebSocketResponseReader(ReaderAdapter):
if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
raise BrokenPipeError()
raise BrokenPipeError
except asyncio.TimeoutError:
raise BrokenPipeError()
raise BrokenPipeError
# return all bytes currently in the buffer
if n == -1:
@ -74,7 +73,7 @@ class WebSocketResponseWriter(WriterAdapter):
# needed for `get_peer_info`
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
peer_name = request.transport.get_extra_info('peername')
peer_name = request.transport.get_extra_info("peername")
if peer_name is not None:
self.client_ip, self.port = peer_name[0:2]
else:
@ -110,17 +109,17 @@ class WebSocketResponseWriter(WriterAdapter):
async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse:
# establish connection by responding to the websocket request with the 'mqtt' protocol
ws = web.WebSocketResponse(protocols=['mqtt',])
ws = web.WebSocketResponse(protocols=["mqtt"])
await ws.prepare(request)
# access the broker created when the server started
b: Broker = request.app['broker']
b: Broker = request.app["broker"]
# hand-off the websocket data stream to the broker for handling
# `listener_name` is the same name of the externalized listener in the broker config
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME)
logger.debug('websocket connection closed')
logger.debug("websocket connection closed")
return ws
@ -140,9 +139,9 @@ def main():
app = web.Application()
app.add_routes(
[
web.get('/', hello), # http get request/response route
web.get('/ws', websocket_handler), # standard websocket handler
web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections
web.get("/", hello), # http get request/response route
web.get("/ws", websocket_handler), # standard websocket handler
web.get("/mqtt", mqtt_websocket_handler), # websocket handler for mqtt connections
])
# create background task for running the `amqtt` broker
app.cleanup_ctx.append(run_broker)
@ -154,12 +153,12 @@ def main():
async def run_broker(_app):
"""App init function to start (and then shutdown) the `amqtt` broker.
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks
"""
# standard TCP connection as well as an externalized-listener
cfg = BrokerConfig(
listeners={
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
"default":ListenerConfig(type=ListenerType.TCP, bind="127.0.0.1:1883"),
MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL),
}
)
@ -169,7 +168,7 @@ async def run_broker(_app):
broker = Broker(config=cfg, loop=loop)
# store broker instance so that incoming requests can hand off processing of a datastream
_app['broker'] = broker
_app["broker"] = broker
# start the broker
await broker.start()
@ -180,6 +179,6 @@ async def run_broker(_app):
await broker.shutdown()
if __name__ == '__main__':
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

Wyświetl plik

@ -1,20 +1,19 @@
import contextlib
import logging
import asyncio
import ssl
from asyncio import StreamWriter, StreamReader, Event
from functools import partial
import logging
from pathlib import Path
import ssl
import typer
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker
from amqtt.client import ClientContext
from amqtt.contexts import ClientConfig, BrokerConfig, ListenerConfig, ListenerType
from amqtt.contexts import BrokerConfig, ClientConfig, ListenerConfig, ListenerType
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from amqtt.plugins.manager import PluginManager
from amqtt.session import Session
from amqtt.adapters import ReaderAdapter, WriterAdapter
logger = logging.getLogger(__name__)
@ -60,7 +59,7 @@ class UnixStreamWriterAdapter(WriterAdapter):
await self._writer.drain()
def get_peer_info(self) -> tuple[str, int]:
extra_info = self._writer.get_extra_info('socket')
extra_info = self._writer.get_extra_info("socket")
return extra_info.getsockname(), 0
async def close(self) -> None:
@ -86,7 +85,7 @@ async def run_broker(socket_file: Path):
# configure the broker with a single, external listener
cfg = BrokerConfig(
listeners={
'default': ListenerConfig(
"default": ListenerConfig(
type=ListenerType.EXTERNAL
)
},
@ -109,7 +108,7 @@ async def run_broker(socket_file: Path):
# passes the connection to the broker for protocol communications
await b.external_connected(reader=r, writer=w, listener_name=listener_name)
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name='default'), path=socket_file)
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name="default"), path=socket_file)
await b.start()
try:
@ -163,7 +162,7 @@ async def run_client(socket_file: Path):
try:
while True:
# periodically send a message
await cph.mqtt_publish('my/topic', b'my message', 0, False)
await cph.mqtt_publish("my/topic", b"my message", 0, False)
await asyncio.sleep(1)
except KeyboardInterrupt:
cph.detach()