kopia lustrzana https://github.com/Yakifo/amqtt
rodzic
b4d58c9130
commit
8e47ede192
|
@ -139,7 +139,7 @@ class BrokerContext(BaseContext):
|
|||
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
|
||||
return self._broker_instance.subscriptions
|
||||
|
||||
async def add_subscription(self, client_id: str, topic: str|None, qos: int|None) -> None:
|
||||
async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
|
||||
"""Create a topic subscription for the given `client_id`.
|
||||
|
||||
If a client session doesn't exist for `client_id`, create a disconnected session.
|
||||
|
@ -325,7 +325,6 @@ class Broker:
|
|||
ssl_context: ssl.SSLContext | None,
|
||||
) -> asyncio.Server | websockets.asyncio.server.Server:
|
||||
"""Create a server instance for a listener."""
|
||||
|
||||
match listener_type:
|
||||
case ListenerType.TCP:
|
||||
return await asyncio.start_server(
|
||||
|
@ -356,17 +355,17 @@ class Broker:
|
|||
session_count_before = len(self._sessions)
|
||||
|
||||
# clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
|
||||
sessions_to_remove = [ client_id for client_id, (session, _) in self._sessions.items()
|
||||
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session) ]
|
||||
sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
|
||||
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
|
||||
|
||||
# if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
|
||||
if self.config.session_expiry_interval is not None:
|
||||
retain_after = floor(time.time() - self.config.session_expiry_interval)
|
||||
|
||||
sessions_to_remove += [ client_id for client_id, (session, _) in self._sessions.items()
|
||||
sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
|
||||
if session.transitions.state == "disconnected" and
|
||||
session.last_disconnect_time and
|
||||
session.last_disconnect_time < retain_after ]
|
||||
session.last_disconnect_time < retain_after]
|
||||
|
||||
for client_id in sessions_to_remove:
|
||||
await self._cleanup_session(client_id)
|
||||
|
@ -586,9 +585,7 @@ class Broker:
|
|||
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
|
||||
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
|
||||
for topic in self._subscriptions:
|
||||
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
|
||||
|
||||
|
||||
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
|
||||
|
||||
await self._client_message_loop(client_session, handler)
|
||||
|
||||
|
@ -620,7 +617,6 @@ class Broker:
|
|||
|
||||
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
|
||||
|
||||
|
||||
if subscribe_waiter in done:
|
||||
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
||||
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
||||
|
@ -690,7 +686,6 @@ class Broker:
|
|||
client_id=client_session.client_id,
|
||||
client_session=client_session)
|
||||
|
||||
|
||||
async def _handle_subscription(
|
||||
self,
|
||||
client_session: Session,
|
||||
|
@ -808,7 +803,7 @@ class Broker:
|
|||
"""
|
||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||
|
||||
results = [ result for _, result in returns.items() if result is not None] if returns else []
|
||||
results = [result for _, result in returns.items() if result is not None] if returns else []
|
||||
if len(results) < 1:
|
||||
self.logger.debug("Authentication failed: no plugin responded with a boolean")
|
||||
return False
|
||||
|
|
|
@ -424,14 +424,10 @@ class MQTTClient:
|
|||
scheme = uri_attributes.scheme
|
||||
secure = scheme in ("mqtts", "wss")
|
||||
self.session.username = (
|
||||
self.session.username
|
||||
if self.session.username
|
||||
else (str(uri_attributes.username) if uri_attributes.username else None)
|
||||
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
|
||||
)
|
||||
self.session.password = (
|
||||
self.session.password
|
||||
if self.session.password
|
||||
else (str(uri_attributes.password) if uri_attributes.password else None)
|
||||
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
|
||||
)
|
||||
self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
|
||||
self.session.remote_port = uri_attributes.port
|
||||
|
|
|
@ -142,8 +142,8 @@ def int_to_bytes_str(value: int) -> bytes:
|
|||
return str(value).encode("utf-8")
|
||||
|
||||
|
||||
def float_to_bytes_str(value: float, places:int=3) -> bytes:
|
||||
def float_to_bytes_str(value: float, places: int = 3) -> bytes:
|
||||
"""Convert an float value to a bytes array containing the numeric character."""
|
||||
quant = Decimal(f"0.{''.join(['0' for i in range(places-1)])}1")
|
||||
quant = Decimal(f"0.{''.join(['0' for i in range(places - 1)])}1")
|
||||
rounded = Decimal(value).quantize(quant, rounding=ROUND_HALF_UP)
|
||||
return str(rounded).encode("utf-8")
|
||||
|
|
|
@ -6,7 +6,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
@ -50,14 +50,15 @@ class ListenerType(StrEnum):
|
|||
"""Display the string value, instead of the enum member."""
|
||||
return f'"{self.value!s}"'
|
||||
|
||||
|
||||
class Dictable:
|
||||
"""Add dictionary methods to a dataclass."""
|
||||
|
||||
def __getitem__(self, key:str) -> Any:
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Allow dict-style `[]` access to a dataclass."""
|
||||
return self.get(key)
|
||||
|
||||
def get(self, name:str, default:Any=None) -> Any:
|
||||
def get(self, name: str, default: Any = None) -> Any:
|
||||
"""Allow dict-style access to a dataclass."""
|
||||
name = name.replace("-", "_")
|
||||
if hasattr(self, name):
|
||||
|
@ -148,10 +149,10 @@ def default_listeners() -> dict[str, Any]:
|
|||
def default_broker_plugins() -> dict[str, Any]:
|
||||
"""Create defaults for BrokerConfig.plugins."""
|
||||
return {
|
||||
"amqtt.plugins.logging_amqtt.EventLoggerPlugin":{},
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{},
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin":{"sys_interval":20}
|
||||
"amqtt.plugins.logging_amqtt.EventLoggerPlugin": {},
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {},
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 20}
|
||||
}
|
||||
|
||||
|
||||
|
@ -159,7 +160,7 @@ def default_broker_plugins() -> dict[str, Any]:
|
|||
class BrokerConfig(Dictable):
|
||||
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
|
||||
|
||||
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
|
||||
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
|
||||
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
|
||||
does not set a value, the 'default' settings are applied. See
|
||||
[`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
|
||||
|
@ -178,7 +179,7 @@ class BrokerConfig(Dictable):
|
|||
"""*Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
|
||||
[`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
|
||||
plugins: dict[str, Any] | list[str | dict[str,Any]] | None = field(default_factory=default_broker_plugins)
|
||||
plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
|
||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
|
||||
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
||||
[custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
|
||||
|
@ -203,7 +204,7 @@ class BrokerConfig(Dictable):
|
|||
for plugin in self.plugins:
|
||||
# in case a plugin in a yaml file is listed without config map
|
||||
if isinstance(plugin, str):
|
||||
_plugins |= {plugin:{}}
|
||||
_plugins |= {plugin: {}}
|
||||
continue
|
||||
_plugins |= plugin
|
||||
self.plugins = _plugins
|
||||
|
@ -263,6 +264,7 @@ class ConnectionConfig(Dictable):
|
|||
if isinstance(getattr(self, fn), str):
|
||||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopicConfig(Dictable):
|
||||
"""Configuration of how messages to specific topics are published.
|
||||
|
@ -305,7 +307,7 @@ class WillConfig(Dictable):
|
|||
def default_client_plugins() -> dict[str, Any]:
|
||||
"""Create defaults for `ClientConfig.plugins`."""
|
||||
return {
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -36,9 +36,11 @@ class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
|
|||
if value is None:
|
||||
return None
|
||||
return [self.dataclass_type(**item) for item in value]
|
||||
|
||||
def process_literal_param(self, value: Any, dialect: Any) -> Any:
|
||||
# Required by SQLAlchemy, typically used for literal SQL rendering.
|
||||
return value
|
||||
|
||||
@property
|
||||
def python_type(self) -> type:
|
||||
# Required by TypeEngine to indicate the expected Python type.
|
||||
|
|
|
@ -8,7 +8,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin
|
||||
|
@ -39,7 +39,7 @@ _db_map = {
|
|||
}
|
||||
|
||||
|
||||
def db_connection_str(db_type: DBType, db_username: str, db_host:str, db_port: int|None, db_filename: str) -> str:
|
||||
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
|
||||
"""Create sqlalchemy database connection string."""
|
||||
db_info = _db_map[db_type]
|
||||
if db_type == DBType.SQLITE:
|
||||
|
|
|
@ -52,7 +52,7 @@ class UserManager:
|
|||
raise MQTTError(msg)
|
||||
return users
|
||||
|
||||
async def create_user_auth(self, username: str, plain_password:str) -> UserAuth | None:
|
||||
async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
|
||||
"""Create a new user."""
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(UserAuth).filter(UserAuth.username == username)
|
||||
|
|
|
@ -18,11 +18,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
matcher = TopicMatcher()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllowedTopic:
|
||||
topic: str
|
||||
|
||||
def __contains__(self, item: Union[str,"AllowedTopic"]) -> bool:
|
||||
def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
|
||||
"""Determine `in`."""
|
||||
return self.__eq__(item)
|
||||
|
||||
|
@ -43,6 +44,7 @@ class AllowedTopic:
|
|||
"""Display topic."""
|
||||
return self.topic
|
||||
|
||||
|
||||
class PasswordHasher:
|
||||
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ class UserAuthDBPlugin(BaseAuthPlugin):
|
|||
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
|
||||
"""list of hash schemes to use for passwords"""
|
||||
|
||||
|
||||
class TopicAuthDBPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
|
@ -83,7 +84,7 @@ class TopicAuthDBPlugin(BaseTopicPlugin):
|
|||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool|None:
|
||||
) -> bool | None:
|
||||
if not session or not session.username or not topic:
|
||||
return None
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
|
|||
topic_app = typer.Typer(no_args_is_help=True)
|
||||
|
||||
|
||||
|
||||
@topic_app.callback()
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
|
|
|
@ -89,8 +89,9 @@ class CertificateAuthPlugin(BaseAuthPlugin):
|
|||
uri_domain: str
|
||||
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
|
||||
|
||||
def generate_root_creds(country:str, state:str, locality:str,
|
||||
org_name:str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
|
||||
def generate_root_creds(country: str, state: str, locality: str,
|
||||
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
"""Generate CA key and certificate."""
|
||||
# generate private key for the server
|
||||
ca_key = rsa.generate_private_key(
|
||||
|
@ -143,7 +144,7 @@ def generate_root_creds(country:str, state:str, locality:str,
|
|||
return ca_key, cert
|
||||
|
||||
|
||||
def generate_server_csr(country:str, org_name: str, cn:str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
|
||||
def generate_server_csr(country: str, org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
|
||||
"""Generate server private key and server certificate-signing-request."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
|
@ -167,7 +168,6 @@ def generate_server_csr(country:str, org_name: str, cn:str) -> tuple[rsa.RSAPriv
|
|||
return key, csr
|
||||
|
||||
|
||||
|
||||
def generate_device_csr(country: str, org_name: str, common_name: str,
|
||||
uri_san: str, dns_san: str
|
||||
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
|
||||
|
@ -194,9 +194,10 @@ def generate_device_csr(country: str, org_name: str, common_name: str,
|
|||
|
||||
return key, csr
|
||||
|
||||
|
||||
def sign_csr(csr: CertificateSigningRequest,
|
||||
ca_key: rsa.RSAPrivateKey,
|
||||
ca_cert: Certificate, validity_days: int=365) -> Certificate:
|
||||
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
|
||||
"""Sign a csr with CA credentials."""
|
||||
return (
|
||||
x509.CertificateBuilder()
|
||||
|
@ -221,7 +222,8 @@ def sign_csr(csr: CertificateSigningRequest,
|
|||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
def load_ca(ca_key_fn:str, ca_crt_fn:str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
|
||||
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
"""Load server key and certificate."""
|
||||
with Path(ca_key_fn).open("rb") as f:
|
||||
ca_key: rsa.RSAPrivateKey = serialization.load_pem_private_key(f.read(), password=None) # type: ignore[assignment]
|
||||
|
@ -230,8 +232,8 @@ def load_ca(ca_key_fn:str, ca_crt_fn:str) -> tuple[rsa.RSAPrivateKey, Certificat
|
|||
return ca_key, ca_cert
|
||||
|
||||
|
||||
def write_key_and_crt(key:rsa.RSAPrivateKey, crt:Certificate,
|
||||
prefix:str, path: Path | None = None) -> None:
|
||||
def write_key_and_crt(key: rsa.RSAPrivateKey, crt: Certificate,
|
||||
prefix: str, path: Path | None = None) -> None:
|
||||
"""Create pem-encoded files for key and certificate."""
|
||||
path = path or Path()
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
import logging
|
||||
|
@ -26,6 +26,7 @@ class ResponseMode(StrEnum):
|
|||
JSON = "json"
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
class RequestMethod(StrEnum):
|
||||
GET = "get"
|
||||
POST = "post"
|
||||
|
@ -76,7 +77,7 @@ class AuthHttpPlugin(BasePlugin[BrokerContext]):
|
|||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
self.http = ClientSession(headers = {"User-Agent": self.config.user_agent})
|
||||
self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
|
||||
|
||||
match self.config.request_method:
|
||||
case RequestMethod.GET:
|
||||
|
@ -102,15 +103,15 @@ class AuthHttpPlugin(BasePlugin[BrokerContext]):
|
|||
case ParamsMode.FORM:
|
||||
match self.config.request_method:
|
||||
case RequestMethod.GET:
|
||||
kwargs = { "params": payload }
|
||||
case _: # POST, PUT
|
||||
kwargs = {"params": payload}
|
||||
case _: # POST, PUT
|
||||
d: Any = FormData(payload)
|
||||
kwargs = {"data": d}
|
||||
case _: # JSON
|
||||
kwargs = { "json": payload}
|
||||
kwargs = {"json": payload}
|
||||
return kwargs
|
||||
|
||||
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool|None: # pylint: disable=R0911
|
||||
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911
|
||||
|
||||
kwargs = self._get_params(payload)
|
||||
|
||||
|
@ -131,7 +132,7 @@ class AuthHttpPlugin(BasePlugin[BrokerContext]):
|
|||
if not self._is_2xx(r):
|
||||
return False
|
||||
data: dict[str, Any] = await r.json()
|
||||
data = {k.lower():v for k,v in data.items()}
|
||||
data = {k.lower(): v for k, v in data.items()}
|
||||
return data.get("ok", None)
|
||||
|
||||
def get_url(self, uri: str) -> str:
|
||||
|
|
|
@ -9,7 +9,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
|
|
|
@ -8,7 +8,7 @@ from amqtt.broker import BrokerContext
|
|||
from amqtt.contexts import Action
|
||||
from amqtt.errors import PluginInitError
|
||||
from amqtt.plugins import TopicMatcher
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin, BasePlugin
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -42,6 +42,7 @@ class AuthLdapPlugin(BasePlugin[BrokerContext]):
|
|||
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
|
||||
raise PluginInitError(self.__class__) from e
|
||||
|
||||
|
||||
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
|
||||
"""Plugin to authenticate a user with an LDAP directory server."""
|
||||
|
||||
|
@ -76,6 +77,7 @@ class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
|
|||
class Config(LdapConfig):
|
||||
"""Configuration for the User Auth LDAP Plugin."""
|
||||
|
||||
|
||||
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
||||
"""Plugin to authenticate a user with an LDAP directory server."""
|
||||
|
||||
|
@ -107,7 +109,6 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
|||
]
|
||||
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
|
||||
|
||||
|
||||
if not results:
|
||||
logger.debug(f"user not found: {session.username}")
|
||||
return False
|
||||
|
@ -120,12 +121,11 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
|||
dn, entry = results[0]
|
||||
|
||||
ldap_attribute = getattr(self.config, self._action_attr_map[action])
|
||||
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
|
||||
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
|
||||
logger.debug(f"DN: {dn} - {ldap_attribute}={topic_filters}")
|
||||
|
||||
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(LdapConfig):
|
||||
"""Configuration for the LDAPAuthPlugin."""
|
||||
|
|
|
@ -82,7 +82,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_session(db_session: AsyncSession, client_id:str) -> StoredSession:
|
||||
async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
|
||||
|
||||
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
|
||||
stored_session = await db_session.scalar(stmt)
|
||||
|
@ -92,9 +92,8 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
await db_session.flush()
|
||||
return stored_session
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_message(db_session: AsyncSession, topic:str) -> StoredMessage:
|
||||
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
|
||||
|
||||
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
|
||||
stored_message = await db_session.scalar(stmt)
|
||||
|
@ -104,8 +103,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
await db_session.flush()
|
||||
return stored_message
|
||||
|
||||
|
||||
async def on_broker_client_connected(self, client_id:str, client_session:Session) -> None:
|
||||
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
|
||||
"""Search to see if session already exists."""
|
||||
# if client id doesn't exist, create (can ignore if session is anonymous)
|
||||
# update session information (will, clean_session, etc)
|
||||
|
@ -240,7 +238,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
restored_messages += 1
|
||||
logger.info(f"Retained messages restored: {restored_messages}")
|
||||
|
||||
|
||||
logger.info(f"Restored {restored_sessions} sessions.")
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
|
|
|
@ -35,6 +35,7 @@ class ShadowMessage:
|
|||
def to_message(self) -> bytes:
|
||||
return json.dumps(asdict_no_none(self)).encode("utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetAcceptedMessage(ShadowMessage):
|
||||
state: State[dict[str, Any]]
|
||||
|
@ -46,6 +47,7 @@ class GetAcceptedMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetRejectedMessage(ShadowMessage):
|
||||
code: int
|
||||
|
@ -56,6 +58,7 @@ class GetRejectedMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateAcceptedMessage(ShadowMessage):
|
||||
state: State[dict[str, Any]]
|
||||
|
@ -78,6 +81,7 @@ class UpdateRejectedMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateDeltaMessage(ShadowMessage):
|
||||
state: MutableMapping[str, Any]
|
||||
|
@ -89,6 +93,7 @@ class UpdateDeltaMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
|
||||
|
||||
|
||||
class UpdateIotaMessage(UpdateDeltaMessage):
|
||||
"""Same format, corollary name."""
|
||||
|
||||
|
@ -96,6 +101,7 @@ class UpdateIotaMessage(UpdateDeltaMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateDocumentMessage(ShadowMessage):
|
||||
previous: StateDocument
|
||||
|
|
|
@ -40,7 +40,7 @@ class Shadow(ShadowBase):
|
|||
|
||||
device_id: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
version: Mapped[int] =mapped_column(Integer, nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
_state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)
|
||||
|
||||
|
@ -106,7 +106,7 @@ def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow")
|
|||
|
||||
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def convert_update_to_insert(session: Session, _flush_context: object, _instances:object | None) -> None:
|
||||
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
|
||||
"""Force a shadow to insert a new version, instead of updating an existing."""
|
||||
# Make a copy of the dirty set so we can safely mutate the session
|
||||
dirty = list(session.dirty)
|
||||
|
@ -127,6 +127,7 @@ def convert_update_to_insert(session: Session, _flush_context: object, _instance
|
|||
|
||||
session.add(obj) # re-add as new object
|
||||
|
||||
|
||||
_listener_example = '''#
|
||||
# @event.listens_for(Shadow, "before_insert")
|
||||
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
|
||||
|
|
|
@ -28,7 +28,7 @@ from amqtt.session import ApplicationMessage, Session
|
|||
|
||||
shadow_topic_re = re.compile(r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)")
|
||||
|
||||
DeviceID= str
|
||||
DeviceID = str
|
||||
ShadowName = str
|
||||
|
||||
|
||||
|
@ -43,6 +43,7 @@ def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
|
|||
"""Nested defaultdict for shadow cache."""
|
||||
return defaultdict(shadow_dict) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class ShadowPlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
|
@ -52,7 +53,6 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
|
|||
self._engine = create_async_engine(self.config.connection)
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Sync the schema."""
|
||||
async with self._engine.begin() as conn:
|
||||
|
@ -83,8 +83,8 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
|
|||
accept_msg = GetAcceptedMessage(
|
||||
state=shadow.state.state,
|
||||
metadata=shadow.state.metadata,
|
||||
timestamp= shadow.created_at,
|
||||
version= shadow.version
|
||||
timestamp=shadow.created_at,
|
||||
version=shadow.version
|
||||
)
|
||||
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
|
||||
|
||||
|
@ -98,7 +98,7 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
|
|||
|
||||
prev_state = shadow.state or StateDocument()
|
||||
prev_state.version = shadow.version or 0 # only required when generating shadow messages
|
||||
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
|
||||
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
|
||||
|
||||
next_state = prev_state + state_update
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
import time
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
@ -16,10 +16,12 @@ from mergedeep import merge
|
|||
|
||||
C = TypeVar("C", bound=Any)
|
||||
|
||||
|
||||
class StateError(Exception):
|
||||
def __init__(self, msg: str = "'state' field is required") -> None:
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaTimestamp:
|
||||
timestamp: int = 0
|
||||
|
@ -58,7 +60,7 @@ class MetaTimestamp:
|
|||
"""Convert timestamp to int."""
|
||||
return int(self.timestamp)
|
||||
|
||||
def __lt__(self, other:int ) -> bool:
|
||||
def __lt__(self, other: int) -> bool:
|
||||
"""Compare timestamp."""
|
||||
return self.timestamp < other
|
||||
|
||||
|
@ -127,6 +129,7 @@ def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMa
|
|||
|
||||
return delta
|
||||
|
||||
|
||||
@dataclass
|
||||
class State(Generic[C]):
|
||||
desired: MutableMapping[str, C] = field(default_factory=dict)
|
||||
|
@ -156,8 +159,8 @@ class State(Generic[C]):
|
|||
class StateDocument:
|
||||
state: State[dict[str, Any]] = field(default_factory=State)
|
||||
metadata: State[MetaTimestamp] = field(default_factory=State)
|
||||
version: int | None = None # only required when generating shadow messages
|
||||
timestamp: int | None = None # only required when generating shadow messages
|
||||
version: int | None = None # only required when generating shadow messages
|
||||
timestamp: int | None = None # only required when generating shadow messages
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "StateDocument":
|
||||
|
|
|
@ -16,10 +16,12 @@ class CodecError(Exception):
|
|||
class NoDataError(Exception):
|
||||
"""Exceptions thrown by packet encode/decode functions."""
|
||||
|
||||
|
||||
class ZeroLengthReadError(NoDataError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Decoding a string of length zero.")
|
||||
|
||||
|
||||
class BrokerError(Exception):
|
||||
"""Exceptions thrown by broker."""
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ try:
|
|||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
class StrEnum(str, Enum): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ class PacketIdVariableHeader(MQTTVariableHeader):
|
|||
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)
|
||||
|
||||
|
||||
class MQTTPayload(Generic[_VH], ABC):
|
||||
class MQTTPayload(ABC, Generic[_VH]):
|
||||
"""Abstract base class for MQTT payloads."""
|
||||
|
||||
async def to_stream(self, writer: asyncio.StreamWriter) -> None:
|
||||
|
|
|
@ -31,6 +31,7 @@ _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
|
|||
if TYPE_CHECKING:
|
||||
from amqtt.broker import BrokerContext
|
||||
|
||||
|
||||
class Subscription:
|
||||
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
|
||||
self.packet_id = packet_id
|
||||
|
|
|
@ -19,6 +19,7 @@ from amqtt.session import Session
|
|||
if TYPE_CHECKING:
|
||||
from amqtt.client import ClientContext
|
||||
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -4,13 +4,13 @@ try:
|
|||
from asyncio import InvalidStateError, QueueFull, QueueShutDown
|
||||
except ImportError:
|
||||
# Fallback for Python < 3.12
|
||||
class InvalidStateError(Exception): # type: ignore[no-redef]
|
||||
class InvalidStateError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
pass
|
||||
|
||||
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
|
||||
pass
|
||||
|
||||
|
||||
|
@ -63,6 +63,7 @@ from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplic
|
|||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class ProtocolHandler(Generic[C]):
|
||||
"""Class implementing the MQTT communication protocol using asyncio features."""
|
||||
|
||||
|
@ -199,7 +200,7 @@ class ProtocolHandler(Generic[C]):
|
|||
async def mqtt_publish(
|
||||
self,
|
||||
topic: str,
|
||||
data: bytes | bytearray ,
|
||||
data: bytes | bytearray,
|
||||
qos: int | None,
|
||||
retain: bool,
|
||||
ack_timeout: int | None = None,
|
||||
|
|
|
@ -46,7 +46,7 @@ class BasePlugin(Generic[C]):
|
|||
return section_config
|
||||
|
||||
# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
|
@ -75,7 +75,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
if not bool(self.topic_config) and not is_dataclass(self.context.config):
|
||||
self.context.logger.warning("'topic-check' section not found in context configuration")
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
|
@ -107,7 +107,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
|
||||
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
|
||||
if not self.context.config:
|
||||
return default
|
||||
|
||||
|
@ -126,7 +126,6 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
|||
# auth config section not found and Config dataclass not provided
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Logic for session authentication.
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
|
|||
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class PluginManager(Generic[C]):
|
||||
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
|
||||
|
||||
|
@ -97,7 +98,6 @@ class PluginManager(Generic[C]):
|
|||
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
|
||||
# plugins loaded directly from config dictionary
|
||||
|
||||
|
||||
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||
|
@ -147,7 +147,7 @@ class PluginManager(Generic[C]):
|
|||
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
|
||||
self._event_plugin_callbacks[event].append(awaitable)
|
||||
|
||||
def _load_ep_plugins(self, namespace:str) -> None:
|
||||
def _load_ep_plugins(self, namespace: str) -> None:
|
||||
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
auth_filter_list = []
|
||||
|
@ -224,7 +224,7 @@ class PluginManager(Generic[C]):
|
|||
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
|
||||
"""Load plugin from string dotted path: mymodule.myfile.MyPlugin."""
|
||||
try:
|
||||
plugin_class: Any = import_string(plugin_path)
|
||||
plugin_class: Any = import_string(plugin_path)
|
||||
except ImportError as ep:
|
||||
msg = f"Plugin import failed: {plugin_path}"
|
||||
raise PluginImportError(msg) from ep
|
||||
|
@ -377,7 +377,7 @@ class PluginManager(Generic[C]):
|
|||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
return await self._map_plugin_method(
|
||||
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
|
||||
self._auth_plugins, "authenticate", {"session": session}) # type: ignore[arg-type]
|
||||
|
||||
async def map_plugin_topic(
|
||||
self, *, session: Session, topic: str, action: "Action"
|
||||
|
|
|
@ -14,7 +14,7 @@ except ImportError:
|
|||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class Buffer(Protocol): # type: ignore[no-redef]
|
||||
class Buffer(Protocol): # type: ignore[no-redef]
|
||||
def __buffer__(self, flags: int = ...) -> memoryview:
|
||||
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
|
||||
|
||||
|
@ -75,7 +75,6 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
self._sys_interval: int = 0
|
||||
self._current_process = psutil.Process()
|
||||
|
||||
|
||||
def _clear_stats(self) -> None:
|
||||
"""Initialize broker statistics data structures."""
|
||||
for stat in (
|
||||
|
|
|
@ -21,7 +21,7 @@ def main() -> None:
|
|||
app()
|
||||
|
||||
|
||||
def _version(v:bool) -> None:
|
||||
def _version(v: bool) -> None:
|
||||
if v:
|
||||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
@ -65,7 +65,7 @@ def broker_main(
|
|||
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
_ = loop.create_task(broker.start()) #noqa : RUF006
|
||||
_ = loop.create_task(broker.start()) # noqa : RUF006
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -13,12 +13,13 @@ def main() -> None:
|
|||
"""Run the cli for `ca_creds`."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def ca_creds(
|
||||
country:str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
state:str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
|
||||
locality:str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
|
||||
org_name:str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
state: str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
|
||||
locality: str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
|
||||
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
|
||||
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
|
||||
) -> None:
|
||||
|
|
|
@ -16,7 +16,7 @@ def main() -> None:
|
|||
|
||||
|
||||
@app.command()
|
||||
def device_creds( # pylint: disable=too-many-locals
|
||||
def device_creds( # pylint: disable=too-many-locals
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
device_id: str = typer.Option(..., "--device-id", help="device id for the SAN"),
|
||||
|
@ -59,5 +59,6 @@ def device_creds( # pylint: disable=too-many-locals
|
|||
|
||||
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -9,12 +9,12 @@ logger = logging.getLogger(__name__)
|
|||
def main() -> None:
|
||||
"""Run the auth db cli."""
|
||||
try:
|
||||
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
|
||||
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
|
||||
except ImportError:
|
||||
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
|
||||
sys.exit(1)
|
||||
|
||||
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
|
||||
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
|
||||
|
||||
try:
|
||||
topic_app()
|
||||
|
@ -32,5 +32,6 @@ def main() -> None:
|
|||
logger.critical(f"could not execute command: {me}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -9,12 +9,12 @@ logger = logging.getLogger(__name__)
|
|||
def main() -> None:
|
||||
"""Run the auth db cli."""
|
||||
try:
|
||||
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
|
||||
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
|
||||
except ImportError:
|
||||
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
|
||||
sys.exit(1)
|
||||
|
||||
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
|
||||
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
|
||||
try:
|
||||
user_app()
|
||||
except ModuleNotFoundError as mnfe:
|
||||
|
@ -31,5 +31,6 @@ def main() -> None:
|
|||
logger.critical(f"could not execute command: {me}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -118,6 +118,7 @@ async def do_pub(
|
|||
logger.fatal("Publish canceled due to previous error")
|
||||
raise asyncio.CancelledError from ce
|
||||
|
||||
|
||||
app = typer.Typer(add_completion=False, rich_markup_mode=None)
|
||||
|
||||
|
||||
|
@ -131,8 +132,9 @@ def _version(v: bool) -> None:
|
|||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
|
||||
@app.command()
|
||||
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
def publisher_main( # pylint: disable=R0914,R0917
|
||||
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),
|
||||
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
|
||||
|
@ -155,7 +157,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
will_retain: bool = typer.Option(False, "--will-retain", help="If the client disconnects unexpectedly the message sent out will be treated as a retained message. *only valid, if `--will-topic` is specified*"),
|
||||
extra_headers_json: str | None = typer.Option(None, "--extra-headers", help="Specify a JSON object string with key-value pairs representing additional headers that are transmitted on the initial connection. *websocket connections only*."),
|
||||
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
|
||||
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
|
||||
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
|
||||
) -> None:
|
||||
"""Command-line MQTT client for publishing simple messages."""
|
||||
provided = [bool(message), bool(file), stdin, lines, no_message]
|
||||
|
|
|
@ -13,14 +13,15 @@ def main() -> None:
|
|||
"""Run the `server_creds` cli."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def server_creds(
|
||||
country:str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
org_name:str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
|
||||
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
|
||||
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
|
||||
ca_key_fn:str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
|
||||
ca_crt_fn:str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
|
||||
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
|
||||
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
|
||||
) -> None:
|
||||
"""Generate a key and certificate for the broker in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa : E501
|
||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||
|
|
|
@ -100,17 +100,17 @@ def main() -> None:
|
|||
app()
|
||||
|
||||
|
||||
def _version(v:bool) -> None:
|
||||
def _version(v: bool) -> None:
|
||||
if v:
|
||||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
|
||||
@app.command()
|
||||
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
def subscribe_main( # pylint: disable=R0914,R0917
|
||||
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
|
||||
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
|
||||
topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008
|
||||
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
|
||||
|
|
|
@ -166,30 +166,40 @@ indent-style = "space"
|
|||
docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
preview = true
|
||||
select = ["ALL"]
|
||||
|
||||
extend-select = [
|
||||
"UP", # pyupgrade
|
||||
"D", # pydocstyle
|
||||
"UP", # pyupgrade
|
||||
"D", # pydocstyle,
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"FBT001", # Checks for the use of boolean positional arguments in function definitions.
|
||||
"FBT002", # Checks for the use of boolean positional arguments in function definitions.
|
||||
"G004", # Logging statement uses f-string
|
||||
"D100", # Missing docstring in public module
|
||||
"D101", # Missing docstring in public class
|
||||
"D102", # Missing docstring in public method
|
||||
"D107", # Missing docstring in `__init__`
|
||||
"D203", # Incorrect blank line before class (mutually exclusive D211)
|
||||
"D213", # Multi-line summary second line (mutually exclusive D212)
|
||||
"FIX002", # Checks for "TODO" comments.
|
||||
"TD002", # TODO Missing author.
|
||||
"TD003", # TODO Missing issue link for this TODO.
|
||||
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
||||
"ARG002", # Unused method argument
|
||||
"PERF203",# try-except penalty within loops (3.10 only),
|
||||
"COM812" # rule causes conflicts when used with the formatter
|
||||
"FBT001", # Checks for the use of boolean positional arguments in function definitions.
|
||||
"FBT002", # Checks for the use of boolean positional arguments in function definitions.
|
||||
"G004", # Logging statement uses f-string
|
||||
"D100", # Missing docstring in public module
|
||||
"D101", # Missing docstring in public class
|
||||
"D102", # Missing docstring in public method
|
||||
"D107", # Missing docstring in `__init__`
|
||||
"D203", # Incorrect blank line before class (mutually exclusive D211)
|
||||
"D213", # Multi-line summary second line (mutually exclusive D212)
|
||||
"FIX002", # Checks for "TODO" comments.
|
||||
"TD002", # TODO Missing author.
|
||||
"TD003", # TODO Missing issue link for this TODO.
|
||||
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
||||
"ARG002", # Unused method argument
|
||||
"PERF203",# try-except penalty within loops (3.10 only),
|
||||
"COM812", # rule causes conflicts when used with the formatter,
|
||||
|
||||
# ignore certain preview rules
|
||||
"DOC",
|
||||
"PLW",
|
||||
"PLR",
|
||||
"CPY",
|
||||
"PLC",
|
||||
"RUF052",
|
||||
"B903"
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
@ -24,13 +23,13 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
|
||||
'amqtt.plugins.authentication.FileAuthPlugin': {
|
||||
'password_file': Path(__file__).parent / 'passwd',
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
|
||||
"amqtt.plugins.authentication.FileAuthPlugin": {
|
||||
"password_file": Path(__file__).parent / "passwd",
|
||||
},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
|
||||
'amqtt.plugins.topic_checking.TopicAccessControlListPlugin': {
|
||||
'acl': {
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
|
||||
"amqtt.plugins.topic_checking.TopicAccessControlListPlugin": {
|
||||
"acl": {
|
||||
# username: [list of allowed topics]
|
||||
"test": ["repositories/+/master", "calendar/#", "data/memes"],
|
||||
"anonymous": [],
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
@ -18,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
class RemoteInfoPlugin(BasePlugin):
|
||||
|
||||
async def on_broker_client_connected(self, *, client_id:str, client_session:Session) -> None:
|
||||
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ''
|
||||
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ""
|
||||
|
||||
logger.info(f"client '{client_id}' connected from"
|
||||
f" '{client_session.remote_address}' {display_port_str}")
|
||||
|
@ -40,8 +38,8 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
|
||||
'samples.broker_custom_plugin.RemoteInfoPlugin': { 'display_port': True },
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
|
||||
"samples.broker_custom_plugin.RemoteInfoPlugin": { "display_port": True },
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from asyncio import CancelledError
|
||||
import logging
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
@ -24,11 +23,11 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
|
||||
'amqtt.plugins.authentication.FileAuthPlugin': {
|
||||
'password_file': Path(__file__).parent / 'passwd',
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
|
||||
"amqtt.plugins.authentication.FileAuthPlugin": {
|
||||
"password_file": Path(__file__).parent / "passwd",
|
||||
},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
@ -24,12 +23,12 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True},
|
||||
'amqtt.plugins.authentication.FileAuthPlugin': {
|
||||
'password_file': Path(__file__).parent / 'passwd',
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
|
||||
"amqtt.plugins.authentication.FileAuthPlugin": {
|
||||
"password_file": Path(__file__).parent / "passwd",
|
||||
},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 10},
|
||||
'amqtt.plugins.topic_checking.TopicTabooPlugin': {},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 10},
|
||||
"amqtt.plugins.topic_checking.TopicTabooPlugin": {},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from asyncio import CancelledError
|
||||
import logging
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ async def test_coro1() -> None:
|
|||
|
||||
async def test_coro2() -> None:
|
||||
try:
|
||||
client = MQTTClient(config={'auto_reconnect': False, 'connection_timeout': 1})
|
||||
client = MQTTClient(config={"auto_reconnect": False, "connection_timeout": 1})
|
||||
await client.connect("mqtt://localhost:1884/")
|
||||
await client.publish("a/b", b"TEST MESSAGE WITH QOS_0", qos=0x00)
|
||||
await client.publish("a/b", b"TEST MESSAGE WITH QOS_1", qos=0x01)
|
||||
|
@ -43,7 +43,7 @@ async def test_coro2() -> None:
|
|||
logger.info("test_coro2 messages published")
|
||||
await client.disconnect()
|
||||
except ConnectError:
|
||||
logger.info(f"Connection failed", exc_info=True)
|
||||
logger.info("Connection failed", exc_info=True)
|
||||
|
||||
|
||||
def __main__():
|
||||
|
|
|
@ -27,7 +27,7 @@ async def test_coro() -> None:
|
|||
await client.publish("calendar/amqtt/releases", b"NEW RELEASE", qos=QOS_1)
|
||||
logger.info("messages published")
|
||||
await client.disconnect()
|
||||
except ConnectError as ce:
|
||||
except ConnectError:
|
||||
logger.exception("ERROR: Connection failed")
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||
|
@ -29,7 +28,7 @@ config = {
|
|||
|
||||
|
||||
async def test_coro(certfile: str) -> None:
|
||||
config['certfile'] = certfile
|
||||
config["certfile"] = certfile
|
||||
client = MQTTClient(config=config)
|
||||
|
||||
await client.connect("mqtts://localhost:8883")
|
||||
|
@ -47,7 +46,7 @@ def __main__():
|
|||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity")
|
||||
parser.add_argument("--cert", default="cert.pem", help="path & file to verify server's authenticity")
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(test_coro(args.cert))
|
||||
|
|
|
@ -4,7 +4,6 @@ import logging
|
|||
from amqtt.client import ClientError, MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||
|
||||
|
||||
"""
|
||||
This sample shows how to subscribe to different $SYS topics and how to receive incoming messages
|
||||
"""
|
||||
|
@ -13,7 +12,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def uptime_coro() -> None:
|
||||
client = MQTTClient(config={'auto_reconnect': False})
|
||||
client = MQTTClient(config={"auto_reconnect": False})
|
||||
await client.connect("mqtt://localhost:1883")
|
||||
|
||||
await client.subscribe(
|
||||
|
|
|
@ -35,7 +35,7 @@ async def uptime_coro() -> None:
|
|||
await client.unsubscribe(["$SYS/#", "data/memes"])
|
||||
logger.info("UnSubscribed")
|
||||
await client.disconnect()
|
||||
except ClientError as ce:
|
||||
except ClientError:
|
||||
logger.exception("Client exception")
|
||||
|
||||
|
||||
|
|
|
@ -9,14 +9,13 @@ from aiohttp import web
|
|||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
|
||||
from amqtt.errors import ConnectError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MQTT_LISTENER_NAME = "myMqttListener"
|
||||
|
||||
async def hello(request):
|
||||
"""get request handler"""
|
||||
"""Get request handler"""
|
||||
return web.Response(text="Hello, world")
|
||||
|
||||
class WebSocketResponseReader(ReaderAdapter):
|
||||
|
@ -27,17 +26,17 @@ class WebSocketResponseReader(ReaderAdapter):
|
|||
self.buffer = bytearray()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
read 'n' bytes from the datastream, if < 0 read all available bytes
|
||||
"""Read 'n' bytes from the datastream, if < 0 read all available bytes
|
||||
|
||||
Raises:
|
||||
BrokerPipeError : if reading on a closed websocket connection
|
||||
|
||||
"""
|
||||
# continue until buffer contains at least the amount of data being requested
|
||||
while not self.buffer or len(self.buffer) < n:
|
||||
# if the websocket is closed
|
||||
if self.ws.closed:
|
||||
raise BrokenPipeError()
|
||||
raise BrokenPipeError
|
||||
|
||||
try:
|
||||
# read from stream
|
||||
|
@ -46,10 +45,10 @@ class WebSocketResponseReader(ReaderAdapter):
|
|||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self.buffer.extend(msg.data)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
raise BrokenPipeError()
|
||||
raise BrokenPipeError
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise BrokenPipeError()
|
||||
raise BrokenPipeError
|
||||
|
||||
# return all bytes currently in the buffer
|
||||
if n == -1:
|
||||
|
@ -74,7 +73,7 @@ class WebSocketResponseWriter(WriterAdapter):
|
|||
|
||||
# needed for `get_peer_info`
|
||||
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
|
||||
peer_name = request.transport.get_extra_info('peername')
|
||||
peer_name = request.transport.get_extra_info("peername")
|
||||
if peer_name is not None:
|
||||
self.client_ip, self.port = peer_name[0:2]
|
||||
else:
|
||||
|
@ -110,17 +109,17 @@ class WebSocketResponseWriter(WriterAdapter):
|
|||
async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse:
|
||||
|
||||
# establish connection by responding to the websocket request with the 'mqtt' protocol
|
||||
ws = web.WebSocketResponse(protocols=['mqtt',])
|
||||
ws = web.WebSocketResponse(protocols=["mqtt"])
|
||||
await ws.prepare(request)
|
||||
|
||||
# access the broker created when the server started
|
||||
b: Broker = request.app['broker']
|
||||
b: Broker = request.app["broker"]
|
||||
|
||||
# hand-off the websocket data stream to the broker for handling
|
||||
# `listener_name` is the same name of the externalized listener in the broker config
|
||||
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME)
|
||||
|
||||
logger.debug('websocket connection closed')
|
||||
logger.debug("websocket connection closed")
|
||||
return ws
|
||||
|
||||
|
||||
|
@ -140,9 +139,9 @@ def main():
|
|||
app = web.Application()
|
||||
app.add_routes(
|
||||
[
|
||||
web.get('/', hello), # http get request/response route
|
||||
web.get('/ws', websocket_handler), # standard websocket handler
|
||||
web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections
|
||||
web.get("/", hello), # http get request/response route
|
||||
web.get("/ws", websocket_handler), # standard websocket handler
|
||||
web.get("/mqtt", mqtt_websocket_handler), # websocket handler for mqtt connections
|
||||
])
|
||||
# create background task for running the `amqtt` broker
|
||||
app.cleanup_ctx.append(run_broker)
|
||||
|
@ -154,12 +153,12 @@ def main():
|
|||
|
||||
async def run_broker(_app):
|
||||
"""App init function to start (and then shutdown) the `amqtt` broker.
|
||||
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
|
||||
|
||||
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks
|
||||
"""
|
||||
# standard TCP connection as well as an externalized-listener
|
||||
cfg = BrokerConfig(
|
||||
listeners={
|
||||
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
|
||||
"default":ListenerConfig(type=ListenerType.TCP, bind="127.0.0.1:1883"),
|
||||
MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL),
|
||||
}
|
||||
)
|
||||
|
@ -169,7 +168,7 @@ async def run_broker(_app):
|
|||
broker = Broker(config=cfg, loop=loop)
|
||||
|
||||
# store broker instance so that incoming requests can hand off processing of a datastream
|
||||
_app['broker'] = broker
|
||||
_app["broker"] = broker
|
||||
# start the broker
|
||||
await broker.start()
|
||||
|
||||
|
@ -180,6 +179,6 @@ async def run_broker(_app):
|
|||
await broker.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import asyncio
|
||||
import ssl
|
||||
from asyncio import StreamWriter, StreamReader, Event
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
|
||||
import typer
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.client import ClientContext
|
||||
from amqtt.contexts import ClientConfig, BrokerConfig, ListenerConfig, ListenerType
|
||||
from amqtt.contexts import BrokerConfig, ClientConfig, ListenerConfig, ListenerType
|
||||
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import Session
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,7 +59,7 @@ class UnixStreamWriterAdapter(WriterAdapter):
|
|||
await self._writer.drain()
|
||||
|
||||
def get_peer_info(self) -> tuple[str, int]:
|
||||
extra_info = self._writer.get_extra_info('socket')
|
||||
extra_info = self._writer.get_extra_info("socket")
|
||||
return extra_info.getsockname(), 0
|
||||
|
||||
async def close(self) -> None:
|
||||
|
@ -86,7 +85,7 @@ async def run_broker(socket_file: Path):
|
|||
# configure the broker with a single, external listener
|
||||
cfg = BrokerConfig(
|
||||
listeners={
|
||||
'default': ListenerConfig(
|
||||
"default": ListenerConfig(
|
||||
type=ListenerType.EXTERNAL
|
||||
)
|
||||
},
|
||||
|
@ -109,7 +108,7 @@ async def run_broker(socket_file: Path):
|
|||
# passes the connection to the broker for protocol communications
|
||||
await b.external_connected(reader=r, writer=w, listener_name=listener_name)
|
||||
|
||||
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name='default'), path=socket_file)
|
||||
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name="default"), path=socket_file)
|
||||
await b.start()
|
||||
|
||||
try:
|
||||
|
@ -163,7 +162,7 @@ async def run_client(socket_file: Path):
|
|||
try:
|
||||
while True:
|
||||
# periodically send a message
|
||||
await cph.mqtt_publish('my/topic', b'my message', 0, False)
|
||||
await cph.mqtt_publish("my/topic", b"my message", 0, False)
|
||||
await asyncio.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
cph.detach()
|
||||
|
|
Ładowanie…
Reference in New Issue