kopia lustrzana https://github.com/Yakifo/amqtt
rodzic
b4d58c9130
commit
8e47ede192
|
@ -325,7 +325,6 @@ class Broker:
|
|||
ssl_context: ssl.SSLContext | None,
|
||||
) -> asyncio.Server | websockets.asyncio.server.Server:
|
||||
"""Create a server instance for a listener."""
|
||||
|
||||
match listener_type:
|
||||
case ListenerType.TCP:
|
||||
return await asyncio.start_server(
|
||||
|
@ -588,8 +587,6 @@ class Broker:
|
|||
for topic in self._subscriptions:
|
||||
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
|
||||
|
||||
|
||||
|
||||
await self._client_message_loop(client_session, handler)
|
||||
|
||||
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
|
||||
|
@ -620,7 +617,6 @@ class Broker:
|
|||
|
||||
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
|
||||
|
||||
|
||||
if subscribe_waiter in done:
|
||||
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
||||
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
||||
|
@ -690,7 +686,6 @@ class Broker:
|
|||
client_id=client_session.client_id,
|
||||
client_session=client_session)
|
||||
|
||||
|
||||
async def _handle_subscription(
|
||||
self,
|
||||
client_session: Session,
|
||||
|
|
|
@ -424,14 +424,10 @@ class MQTTClient:
|
|||
scheme = uri_attributes.scheme
|
||||
secure = scheme in ("mqtts", "wss")
|
||||
self.session.username = (
|
||||
self.session.username
|
||||
if self.session.username
|
||||
else (str(uri_attributes.username) if uri_attributes.username else None)
|
||||
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
|
||||
)
|
||||
self.session.password = (
|
||||
self.session.password
|
||||
if self.session.password
|
||||
else (str(uri_attributes.password) if uri_attributes.password else None)
|
||||
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
|
||||
)
|
||||
self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
|
||||
self.session.remote_port = uri_attributes.port
|
||||
|
|
|
@ -50,6 +50,7 @@ class ListenerType(StrEnum):
|
|||
"""Display the string value, instead of the enum member."""
|
||||
return f'"{self.value!s}"'
|
||||
|
||||
|
||||
class Dictable:
|
||||
"""Add dictionary methods to a dataclass."""
|
||||
|
||||
|
@ -263,6 +264,7 @@ class ConnectionConfig(Dictable):
|
|||
if isinstance(getattr(self, fn), str):
|
||||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopicConfig(Dictable):
|
||||
"""Configuration of how messages to specific topics are published.
|
||||
|
|
|
@ -36,9 +36,11 @@ class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
|
|||
if value is None:
|
||||
return None
|
||||
return [self.dataclass_type(**item) for item in value]
|
||||
|
||||
def process_literal_param(self, value: Any, dialect: Any) -> Any:
|
||||
# Required by SQLAlchemy, typically used for literal SQL rendering.
|
||||
return value
|
||||
|
||||
@property
|
||||
def python_type(self) -> type:
|
||||
# Required by TypeEngine to indicate the expected Python type.
|
||||
|
|
|
@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
matcher = TopicMatcher()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllowedTopic:
|
||||
topic: str
|
||||
|
@ -43,6 +44,7 @@ class AllowedTopic:
|
|||
"""Display topic."""
|
||||
return self.topic
|
||||
|
||||
|
||||
class PasswordHasher:
|
||||
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ class UserAuthDBPlugin(BaseAuthPlugin):
|
|||
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
|
||||
"""list of hash schemes to use for passwords"""
|
||||
|
||||
|
||||
class TopicAuthDBPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
|
|
|
@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
|
|||
topic_app = typer.Typer(no_args_is_help=True)
|
||||
|
||||
|
||||
|
||||
@topic_app.callback()
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
|
|
|
@ -89,6 +89,7 @@ class CertificateAuthPlugin(BaseAuthPlugin):
|
|||
uri_domain: str
|
||||
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
|
||||
|
||||
|
||||
def generate_root_creds(country: str, state: str, locality: str,
|
||||
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
"""Generate CA key and certificate."""
|
||||
|
@ -167,7 +168,6 @@ def generate_server_csr(country:str, org_name: str, cn:str) -> tuple[rsa.RSAPriv
|
|||
return key, csr
|
||||
|
||||
|
||||
|
||||
def generate_device_csr(country: str, org_name: str, common_name: str,
|
||||
uri_san: str, dns_san: str
|
||||
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
|
||||
|
@ -194,6 +194,7 @@ def generate_device_csr(country: str, org_name: str, common_name: str,
|
|||
|
||||
return key, csr
|
||||
|
||||
|
||||
def sign_csr(csr: CertificateSigningRequest,
|
||||
ca_key: rsa.RSAPrivateKey,
|
||||
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
|
||||
|
@ -221,6 +222,7 @@ def sign_csr(csr: CertificateSigningRequest,
|
|||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
|
||||
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
|
||||
"""Load server key and certificate."""
|
||||
with Path(ca_key_fn).open("rb") as f:
|
||||
|
|
|
@ -26,6 +26,7 @@ class ResponseMode(StrEnum):
|
|||
JSON = "json"
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
class RequestMethod(StrEnum):
|
||||
GET = "get"
|
||||
POST = "post"
|
||||
|
|
|
@ -8,7 +8,7 @@ from amqtt.broker import BrokerContext
|
|||
from amqtt.contexts import Action
|
||||
from amqtt.errors import PluginInitError
|
||||
from amqtt.plugins import TopicMatcher
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin, BasePlugin
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -42,6 +42,7 @@ class AuthLdapPlugin(BasePlugin[BrokerContext]):
|
|||
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
|
||||
raise PluginInitError(self.__class__) from e
|
||||
|
||||
|
||||
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
|
||||
"""Plugin to authenticate a user with an LDAP directory server."""
|
||||
|
||||
|
@ -76,6 +77,7 @@ class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
|
|||
class Config(LdapConfig):
|
||||
"""Configuration for the User Auth LDAP Plugin."""
|
||||
|
||||
|
||||
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
||||
"""Plugin to authenticate a user with an LDAP directory server."""
|
||||
|
||||
|
@ -107,7 +109,6 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
|||
]
|
||||
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
|
||||
|
||||
|
||||
if not results:
|
||||
logger.debug(f"user not found: {session.username}")
|
||||
return False
|
||||
|
@ -125,7 +126,6 @@ class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
|
|||
|
||||
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(LdapConfig):
|
||||
"""Configuration for the LDAPAuthPlugin."""
|
||||
|
|
|
@ -92,7 +92,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
await db_session.flush()
|
||||
return stored_session
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
|
||||
|
||||
|
@ -104,7 +103,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
await db_session.flush()
|
||||
return stored_message
|
||||
|
||||
|
||||
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
|
||||
"""Search to see if session already exists."""
|
||||
# if client id doesn't exist, create (can ignore if session is anonymous)
|
||||
|
@ -240,7 +238,6 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
restored_messages += 1
|
||||
logger.info(f"Retained messages restored: {restored_messages}")
|
||||
|
||||
|
||||
logger.info(f"Restored {restored_sessions} sessions.")
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
|
|
|
@ -35,6 +35,7 @@ class ShadowMessage:
|
|||
def to_message(self) -> bytes:
|
||||
return json.dumps(asdict_no_none(self)).encode("utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetAcceptedMessage(ShadowMessage):
|
||||
state: State[dict[str, Any]]
|
||||
|
@ -46,6 +47,7 @@ class GetAcceptedMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetRejectedMessage(ShadowMessage):
|
||||
code: int
|
||||
|
@ -56,6 +58,7 @@ class GetRejectedMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateAcceptedMessage(ShadowMessage):
|
||||
state: State[dict[str, Any]]
|
||||
|
@ -78,6 +81,7 @@ class UpdateRejectedMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateDeltaMessage(ShadowMessage):
|
||||
state: MutableMapping[str, Any]
|
||||
|
@ -89,6 +93,7 @@ class UpdateDeltaMessage(ShadowMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
|
||||
|
||||
|
||||
class UpdateIotaMessage(UpdateDeltaMessage):
|
||||
"""Same format, corollary name."""
|
||||
|
||||
|
@ -96,6 +101,7 @@ class UpdateIotaMessage(UpdateDeltaMessage):
|
|||
def topic(device_id: str, shadow_name: str) -> str:
|
||||
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateDocumentMessage(ShadowMessage):
|
||||
previous: StateDocument
|
||||
|
|
|
@ -127,6 +127,7 @@ def convert_update_to_insert(session: Session, _flush_context: object, _instance
|
|||
|
||||
session.add(obj) # re-add as new object
|
||||
|
||||
|
||||
_listener_example = '''#
|
||||
# @event.listens_for(Shadow, "before_insert")
|
||||
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
|
||||
|
|
|
@ -43,6 +43,7 @@ def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
|
|||
"""Nested defaultdict for shadow cache."""
|
||||
return defaultdict(shadow_dict) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class ShadowPlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
|
@ -52,7 +53,6 @@ class ShadowPlugin(BasePlugin[BrokerContext]):
|
|||
self._engine = create_async_engine(self.config.connection)
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Sync the schema."""
|
||||
async with self._engine.begin() as conn:
|
||||
|
|
|
@ -16,10 +16,12 @@ from mergedeep import merge
|
|||
|
||||
C = TypeVar("C", bound=Any)
|
||||
|
||||
|
||||
class StateError(Exception):
|
||||
def __init__(self, msg: str = "'state' field is required") -> None:
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaTimestamp:
|
||||
timestamp: int = 0
|
||||
|
@ -127,6 +129,7 @@ def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMa
|
|||
|
||||
return delta
|
||||
|
||||
|
||||
@dataclass
|
||||
class State(Generic[C]):
|
||||
desired: MutableMapping[str, C] = field(default_factory=dict)
|
||||
|
|
|
@ -16,10 +16,12 @@ class CodecError(Exception):
|
|||
class NoDataError(Exception):
|
||||
"""Exceptions thrown by packet encode/decode functions."""
|
||||
|
||||
|
||||
class ZeroLengthReadError(NoDataError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Decoding a string of length zero.")
|
||||
|
||||
|
||||
class BrokerError(Exception):
|
||||
"""Exceptions thrown by broker."""
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ class PacketIdVariableHeader(MQTTVariableHeader):
|
|||
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)
|
||||
|
||||
|
||||
class MQTTPayload(Generic[_VH], ABC):
|
||||
class MQTTPayload(ABC, Generic[_VH]):
|
||||
"""Abstract base class for MQTT payloads."""
|
||||
|
||||
async def to_stream(self, writer: asyncio.StreamWriter) -> None:
|
||||
|
|
|
@ -31,6 +31,7 @@ _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
|
|||
if TYPE_CHECKING:
|
||||
from amqtt.broker import BrokerContext
|
||||
|
||||
|
||||
class Subscription:
|
||||
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
|
||||
self.packet_id = packet_id
|
||||
|
|
|
@ -19,6 +19,7 @@ from amqtt.session import Session
|
|||
if TYPE_CHECKING:
|
||||
from amqtt.client import ClientContext
|
||||
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -63,6 +63,7 @@ from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplic
|
|||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class ProtocolHandler(Generic[C]):
|
||||
"""Class implementing the MQTT communication protocol using asyncio features."""
|
||||
|
||||
|
|
|
@ -126,7 +126,6 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
|||
# auth config section not found and Config dataclass not provided
|
||||
self.context.logger.warning("'auth' section not found in context configuration")
|
||||
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
"""Logic for session authentication.
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
|
|||
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class PluginManager(Generic[C]):
|
||||
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
|
||||
|
||||
|
@ -97,7 +98,6 @@ class PluginManager(Generic[C]):
|
|||
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
|
||||
# plugins loaded directly from config dictionary
|
||||
|
||||
|
||||
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||
|
|
|
@ -75,7 +75,6 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
self._sys_interval: int = 0
|
||||
self._current_process = psutil.Process()
|
||||
|
||||
|
||||
def _clear_stats(self) -> None:
|
||||
"""Initialize broker statistics data structures."""
|
||||
for stat in (
|
||||
|
|
|
@ -13,6 +13,7 @@ def main() -> None:
|
|||
"""Run the cli for `ca_creds`."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def ca_creds(
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
|
|
|
@ -59,5 +59,6 @@ def device_creds( # pylint: disable=too-many-locals
|
|||
|
||||
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -32,5 +32,6 @@ def main() -> None:
|
|||
logger.critical(f"could not execute command: {me}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -31,5 +31,6 @@ def main() -> None:
|
|||
logger.critical(f"could not execute command: {me}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -118,6 +118,7 @@ async def do_pub(
|
|||
logger.fatal("Publish canceled due to previous error")
|
||||
raise asyncio.CancelledError from ce
|
||||
|
||||
|
||||
app = typer.Typer(add_completion=False, rich_markup_mode=None)
|
||||
|
||||
|
||||
|
@ -131,8 +132,9 @@ def _version(v: bool) -> None:
|
|||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
|
||||
@app.command()
|
||||
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
def publisher_main( # pylint: disable=R0914,R0917
|
||||
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),
|
||||
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
|
||||
|
|
|
@ -13,6 +13,7 @@ def main() -> None:
|
|||
"""Run the `server_creds` cli."""
|
||||
app()
|
||||
|
||||
|
||||
@app.command()
|
||||
def server_creds(
|
||||
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
|
||||
|
|
|
@ -107,7 +107,7 @@ def _version(v:bool) -> None:
|
|||
|
||||
|
||||
@app.command()
|
||||
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
def subscribe_main( # pylint: disable=R0914,R0917
|
||||
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
|
||||
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
|
|
|
@ -166,11 +166,12 @@ indent-style = "space"
|
|||
docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
preview = true
|
||||
select = ["ALL"]
|
||||
|
||||
extend-select = [
|
||||
"UP", # pyupgrade
|
||||
"D", # pydocstyle
|
||||
"D", # pydocstyle,
|
||||
]
|
||||
|
||||
ignore = [
|
||||
|
@ -189,7 +190,16 @@ ignore = [
|
|||
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
||||
"ARG002", # Unused method argument
|
||||
"PERF203",# try-except penalty within loops (3.10 only),
|
||||
"COM812" # rule causes conflicts when used with the formatter
|
||||
"COM812", # rule causes conflicts when used with the formatter,
|
||||
|
||||
# ignore certain preview rules
|
||||
"DOC",
|
||||
"PLW",
|
||||
"PLR",
|
||||
"CPY",
|
||||
"PLC",
|
||||
"RUF052",
|
||||
"B903"
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
@ -24,13 +23,13 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
|
||||
'amqtt.plugins.authentication.FileAuthPlugin': {
|
||||
'password_file': Path(__file__).parent / 'passwd',
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
|
||||
"amqtt.plugins.authentication.FileAuthPlugin": {
|
||||
"password_file": Path(__file__).parent / "passwd",
|
||||
},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
|
||||
'amqtt.plugins.topic_checking.TopicAccessControlListPlugin': {
|
||||
'acl': {
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
|
||||
"amqtt.plugins.topic_checking.TopicAccessControlListPlugin": {
|
||||
"acl": {
|
||||
# username: [list of allowed topics]
|
||||
"test": ["repositories/+/master", "calendar/#", "data/memes"],
|
||||
"anonymous": [],
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
@ -18,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
class RemoteInfoPlugin(BasePlugin):
|
||||
|
||||
async def on_broker_client_connected(self, *, client_id:str, client_session:Session) -> None:
|
||||
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ''
|
||||
display_port_str = f"on port '{client_session.remote_port}'" if self.config.display_port else ""
|
||||
|
||||
logger.info(f"client '{client_id}' connected from"
|
||||
f" '{client_session.remote_address}' {display_port_str}")
|
||||
|
@ -40,8 +38,8 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
|
||||
'samples.broker_custom_plugin.RemoteInfoPlugin': { 'display_port': True },
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
|
||||
"samples.broker_custom_plugin.RemoteInfoPlugin": { "display_port": True },
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from asyncio import CancelledError
|
||||
import logging
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
@ -24,11 +23,11 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': { 'allow_anonymous': True},
|
||||
'amqtt.plugins.authentication.FileAuthPlugin': {
|
||||
'password_file': Path(__file__).parent / 'passwd',
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": { "allow_anonymous": True},
|
||||
"amqtt.plugins.authentication.FileAuthPlugin": {
|
||||
"password_file": Path(__file__).parent / "passwd",
|
||||
},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': { "sys_interval": 10},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": { "sys_interval": 10},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.broker import Broker
|
||||
|
@ -24,12 +23,12 @@ config = {
|
|||
},
|
||||
},
|
||||
"plugins": {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': True},
|
||||
'amqtt.plugins.authentication.FileAuthPlugin': {
|
||||
'password_file': Path(__file__).parent / 'passwd',
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
|
||||
"amqtt.plugins.authentication.FileAuthPlugin": {
|
||||
"password_file": Path(__file__).parent / "passwd",
|
||||
},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 10},
|
||||
'amqtt.plugins.topic_checking.TopicTabooPlugin': {},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 10},
|
||||
"amqtt.plugins.topic_checking.TopicTabooPlugin": {},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from asyncio import CancelledError
|
||||
import logging
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ async def test_coro1() -> None:
|
|||
|
||||
async def test_coro2() -> None:
|
||||
try:
|
||||
client = MQTTClient(config={'auto_reconnect': False, 'connection_timeout': 1})
|
||||
client = MQTTClient(config={"auto_reconnect": False, "connection_timeout": 1})
|
||||
await client.connect("mqtt://localhost:1884/")
|
||||
await client.publish("a/b", b"TEST MESSAGE WITH QOS_0", qos=0x00)
|
||||
await client.publish("a/b", b"TEST MESSAGE WITH QOS_1", qos=0x01)
|
||||
|
@ -43,7 +43,7 @@ async def test_coro2() -> None:
|
|||
logger.info("test_coro2 messages published")
|
||||
await client.disconnect()
|
||||
except ConnectError:
|
||||
logger.info(f"Connection failed", exc_info=True)
|
||||
logger.info("Connection failed", exc_info=True)
|
||||
|
||||
|
||||
def __main__():
|
||||
|
|
|
@ -27,7 +27,7 @@ async def test_coro() -> None:
|
|||
await client.publish("calendar/amqtt/releases", b"NEW RELEASE", qos=QOS_1)
|
||||
logger.info("messages published")
|
||||
await client.disconnect()
|
||||
except ConnectError as ce:
|
||||
except ConnectError:
|
||||
logger.exception("ERROR: Connection failed")
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||
|
@ -29,7 +28,7 @@ config = {
|
|||
|
||||
|
||||
async def test_coro(certfile: str) -> None:
|
||||
config['certfile'] = certfile
|
||||
config["certfile"] = certfile
|
||||
client = MQTTClient(config=config)
|
||||
|
||||
await client.connect("mqtts://localhost:8883")
|
||||
|
@ -47,7 +46,7 @@ def __main__():
|
|||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity")
|
||||
parser.add_argument("--cert", default="cert.pem", help="path & file to verify server's authenticity")
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(test_coro(args.cert))
|
||||
|
|
|
@ -4,7 +4,6 @@ import logging
|
|||
from amqtt.client import ClientError, MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||
|
||||
|
||||
"""
|
||||
This sample shows how to subscribe to different $SYS topics and how to receive incoming messages
|
||||
"""
|
||||
|
@ -13,7 +12,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def uptime_coro() -> None:
|
||||
client = MQTTClient(config={'auto_reconnect': False})
|
||||
client = MQTTClient(config={"auto_reconnect": False})
|
||||
await client.connect("mqtt://localhost:1883")
|
||||
|
||||
await client.subscribe(
|
||||
|
|
|
@ -35,7 +35,7 @@ async def uptime_coro() -> None:
|
|||
await client.unsubscribe(["$SYS/#", "data/memes"])
|
||||
logger.info("UnSubscribed")
|
||||
await client.disconnect()
|
||||
except ClientError as ce:
|
||||
except ClientError:
|
||||
logger.exception("Client exception")
|
||||
|
||||
|
||||
|
|
|
@ -9,14 +9,13 @@ from aiohttp import web
|
|||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
|
||||
from amqtt.errors import ConnectError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MQTT_LISTENER_NAME = "myMqttListener"
|
||||
|
||||
async def hello(request):
|
||||
"""get request handler"""
|
||||
"""Get request handler"""
|
||||
return web.Response(text="Hello, world")
|
||||
|
||||
class WebSocketResponseReader(ReaderAdapter):
|
||||
|
@ -27,17 +26,17 @@ class WebSocketResponseReader(ReaderAdapter):
|
|||
self.buffer = bytearray()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
read 'n' bytes from the datastream, if < 0 read all available bytes
|
||||
"""Read 'n' bytes from the datastream, if < 0 read all available bytes
|
||||
|
||||
Raises:
|
||||
BrokerPipeError : if reading on a closed websocket connection
|
||||
|
||||
"""
|
||||
# continue until buffer contains at least the amount of data being requested
|
||||
while not self.buffer or len(self.buffer) < n:
|
||||
# if the websocket is closed
|
||||
if self.ws.closed:
|
||||
raise BrokenPipeError()
|
||||
raise BrokenPipeError
|
||||
|
||||
try:
|
||||
# read from stream
|
||||
|
@ -46,10 +45,10 @@ class WebSocketResponseReader(ReaderAdapter):
|
|||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self.buffer.extend(msg.data)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
raise BrokenPipeError()
|
||||
raise BrokenPipeError
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise BrokenPipeError()
|
||||
raise BrokenPipeError
|
||||
|
||||
# return all bytes currently in the buffer
|
||||
if n == -1:
|
||||
|
@ -74,7 +73,7 @@ class WebSocketResponseWriter(WriterAdapter):
|
|||
|
||||
# needed for `get_peer_info`
|
||||
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
|
||||
peer_name = request.transport.get_extra_info('peername')
|
||||
peer_name = request.transport.get_extra_info("peername")
|
||||
if peer_name is not None:
|
||||
self.client_ip, self.port = peer_name[0:2]
|
||||
else:
|
||||
|
@ -110,17 +109,17 @@ class WebSocketResponseWriter(WriterAdapter):
|
|||
async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse:
|
||||
|
||||
# establish connection by responding to the websocket request with the 'mqtt' protocol
|
||||
ws = web.WebSocketResponse(protocols=['mqtt',])
|
||||
ws = web.WebSocketResponse(protocols=["mqtt"])
|
||||
await ws.prepare(request)
|
||||
|
||||
# access the broker created when the server started
|
||||
b: Broker = request.app['broker']
|
||||
b: Broker = request.app["broker"]
|
||||
|
||||
# hand-off the websocket data stream to the broker for handling
|
||||
# `listener_name` is the same name of the externalized listener in the broker config
|
||||
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME)
|
||||
|
||||
logger.debug('websocket connection closed')
|
||||
logger.debug("websocket connection closed")
|
||||
return ws
|
||||
|
||||
|
||||
|
@ -140,9 +139,9 @@ def main():
|
|||
app = web.Application()
|
||||
app.add_routes(
|
||||
[
|
||||
web.get('/', hello), # http get request/response route
|
||||
web.get('/ws', websocket_handler), # standard websocket handler
|
||||
web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections
|
||||
web.get("/", hello), # http get request/response route
|
||||
web.get("/ws", websocket_handler), # standard websocket handler
|
||||
web.get("/mqtt", mqtt_websocket_handler), # websocket handler for mqtt connections
|
||||
])
|
||||
# create background task for running the `amqtt` broker
|
||||
app.cleanup_ctx.append(run_broker)
|
||||
|
@ -154,12 +153,12 @@ def main():
|
|||
|
||||
async def run_broker(_app):
|
||||
"""App init function to start (and then shutdown) the `amqtt` broker.
|
||||
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
|
||||
|
||||
https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks
|
||||
"""
|
||||
# standard TCP connection as well as an externalized-listener
|
||||
cfg = BrokerConfig(
|
||||
listeners={
|
||||
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
|
||||
"default":ListenerConfig(type=ListenerType.TCP, bind="127.0.0.1:1883"),
|
||||
MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL),
|
||||
}
|
||||
)
|
||||
|
@ -169,7 +168,7 @@ async def run_broker(_app):
|
|||
broker = Broker(config=cfg, loop=loop)
|
||||
|
||||
# store broker instance so that incoming requests can hand off processing of a datastream
|
||||
_app['broker'] = broker
|
||||
_app["broker"] = broker
|
||||
# start the broker
|
||||
await broker.start()
|
||||
|
||||
|
@ -180,6 +179,6 @@ async def run_broker(_app):
|
|||
await broker.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import asyncio
|
||||
import ssl
|
||||
from asyncio import StreamWriter, StreamReader, Event
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
|
||||
import typer
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.client import ClientContext
|
||||
from amqtt.contexts import ClientConfig, BrokerConfig, ListenerConfig, ListenerType
|
||||
from amqtt.contexts import BrokerConfig, ClientConfig, ListenerConfig, ListenerType
|
||||
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import Session
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,7 +59,7 @@ class UnixStreamWriterAdapter(WriterAdapter):
|
|||
await self._writer.drain()
|
||||
|
||||
def get_peer_info(self) -> tuple[str, int]:
|
||||
extra_info = self._writer.get_extra_info('socket')
|
||||
extra_info = self._writer.get_extra_info("socket")
|
||||
return extra_info.getsockname(), 0
|
||||
|
||||
async def close(self) -> None:
|
||||
|
@ -86,7 +85,7 @@ async def run_broker(socket_file: Path):
|
|||
# configure the broker with a single, external listener
|
||||
cfg = BrokerConfig(
|
||||
listeners={
|
||||
'default': ListenerConfig(
|
||||
"default": ListenerConfig(
|
||||
type=ListenerType.EXTERNAL
|
||||
)
|
||||
},
|
||||
|
@ -109,7 +108,7 @@ async def run_broker(socket_file: Path):
|
|||
# passes the connection to the broker for protocol communications
|
||||
await b.external_connected(reader=r, writer=w, listener_name=listener_name)
|
||||
|
||||
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name='default'), path=socket_file)
|
||||
await asyncio.start_unix_server(partial(unix_stream_connected, listener_name="default"), path=socket_file)
|
||||
await b.start()
|
||||
|
||||
try:
|
||||
|
@ -163,7 +162,7 @@ async def run_client(socket_file: Path):
|
|||
try:
|
||||
while True:
|
||||
# periodically send a message
|
||||
await cph.mqtt_publish('my/topic', b'my message', 0, False)
|
||||
await cph.mqtt_publish("my/topic", b"my message", 0, False)
|
||||
await asyncio.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
cph.detach()
|
||||
|
|
Ładowanie…
Reference in New Issue