kopia lustrzana https://github.com/Yakifo/amqtt
type correction and linting fixes
rodzic
81866d0238
commit
4a39124545
|
@ -2,14 +2,11 @@ import asyncio
|
|||
from asyncio import CancelledError, futures
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
from enum import StrEnum
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
import ssl
|
||||
from typing import Any, ClassVar, TypeAlias
|
||||
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig
|
||||
|
||||
from transitions import Machine, MachineError
|
||||
import websockets.asyncio.server
|
||||
|
@ -23,18 +20,17 @@ from amqtt.adapters import (
|
|||
WebSocketsWriter,
|
||||
WriterAdapter,
|
||||
)
|
||||
from amqtt.contexts import Action, BaseContext, BrokerConfig
|
||||
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig
|
||||
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
|
||||
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
|
||||
from amqtt.utils import format_client_message, gen_client_id
|
||||
|
||||
from .events import BrokerEvents
|
||||
from .mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from .mqtt.disconnect import DisconnectPacket
|
||||
from .plugins.manager import PluginManager
|
||||
|
||||
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
||||
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
||||
|
||||
# Default port numbers
|
||||
|
@ -147,14 +143,17 @@ class Broker:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: _CONFIG_LISTENER | None = None,
|
||||
config: BrokerConfig | dict[str, Any] | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
plugin_namespace: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the broker."""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.config = BrokerConfig.from_dict(config)
|
||||
if isinstance(config, dict):
|
||||
self.config = BrokerConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config or BrokerConfig()
|
||||
|
||||
# listeners are populated from default within BrokerConfig
|
||||
self.listeners_config = self.config.listeners
|
||||
|
@ -247,7 +246,8 @@ class Broker:
|
|||
|
||||
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
||||
|
||||
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
|
||||
@staticmethod
|
||||
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
|
||||
"""Create an SSL context for a listener."""
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(
|
||||
|
@ -680,7 +680,7 @@ class Broker:
|
|||
except Exception:
|
||||
self.logger.exception("Failed to stop handler")
|
||||
|
||||
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
|
||||
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
|
||||
"""Call the authenticate method on registered plugins to test user authentication.
|
||||
|
||||
User is considered authenticated if all plugins called returns True.
|
||||
|
|
|
@ -2,15 +2,11 @@ import asyncio
|
|||
from collections import deque
|
||||
from collections.abc import Callable, Coroutine
|
||||
import contextlib
|
||||
import copy
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig
|
||||
|
||||
import websockets
|
||||
from websockets import HeadersLike, InvalidHandshake, InvalidURI
|
||||
|
@ -21,14 +17,14 @@ from amqtt.adapters import (
|
|||
WebSocketsReader,
|
||||
WebSocketsWriter,
|
||||
)
|
||||
from amqtt.contexts import BaseContext, ClientConfig, ConnectionConfig
|
||||
from amqtt.contexts import BaseContext, ClientConfig
|
||||
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
|
||||
from amqtt.mqtt.connack import CONNECTION_ACCEPTED
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
from amqtt.utils import gen_client_id, read_yaml_config
|
||||
from amqtt.utils import gen_client_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
@ -42,7 +38,7 @@ class ClientContext(BaseContext):
|
|||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.config: ClientConfig | None = None
|
||||
|
||||
|
||||
base_logger = logging.getLogger(__name__)
|
||||
|
@ -94,9 +90,13 @@ class MQTTClient:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config = ClientConfig.from_dict(config)
|
||||
|
||||
if isinstance(config, dict):
|
||||
self.config = ClientConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config or ClientConfig()
|
||||
|
||||
self.client_id = client_id if client_id is not None else gen_client_id()
|
||||
|
||||
|
|
|
@ -1,9 +1,19 @@
|
|||
import warnings
|
||||
from dataclasses import dataclass, field, fields, replace, asdict
|
||||
from enum import Enum, StrEnum
|
||||
from dataclasses import dataclass, field, fields, replace
|
||||
|
||||
try:
|
||||
from enum import Enum, StrEnum
|
||||
except ImportError:
|
||||
# support for python 3.10
|
||||
from enum import Enum
|
||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from collections.abc import Iterator
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, LiteralString, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
|
||||
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_2
|
||||
|
||||
|
@ -14,14 +24,13 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dacite import from_dict as dict_to_dataclass, Config as DaciteConfig, UnexpectedDataError
|
||||
|
||||
|
||||
class BaseContext:
|
||||
def __init__(self) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
self.logger: logging.Logger = _LOGGER
|
||||
self.config: dict[str, Any] | None = None
|
||||
# cleanup with a `Generic` type
|
||||
self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
|
@ -32,33 +41,44 @@ class Action(Enum):
|
|||
|
||||
|
||||
class ListenerType(StrEnum):
|
||||
TCP = 'tcp'
|
||||
WS = 'ws'
|
||||
"""Types of mqtt listeners."""
|
||||
|
||||
TCP = "tcp"
|
||||
WS = "ws"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'"{str(self.value)}"'
|
||||
"""Display the string value, instead of the enum member."""
|
||||
return f'"{self.value!s}"'
|
||||
|
||||
class Dictable:
|
||||
def __getitem__(self, key):
|
||||
"""Add dictionary methods to a dataclass."""
|
||||
|
||||
def __getitem__(self, key:str) -> Any:
|
||||
"""Allow dict-style `[]` access to a dataclass."""
|
||||
return self.get(key)
|
||||
|
||||
def get(self, name, default=None):
|
||||
def get(self, name:str, default:Any=None) -> Any:
|
||||
"""Allow dict-style access to a dataclass."""
|
||||
name = name.replace("-", "_")
|
||||
if hasattr(self, name):
|
||||
return getattr(self, name)
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"'{name}' is not defined")
|
||||
msg = f"'{name}' is not defined"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __contains__(self, name):
|
||||
return getattr(self, name.replace('-', '_'), None) is not None
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Provide dict-style 'in' check."""
|
||||
return getattr(self, name.replace("-", "_"), None) is not None
|
||||
|
||||
def __iter__(self):
|
||||
for field in fields(self):
|
||||
yield getattr(self, field.name)
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Provide dict-style iteration."""
|
||||
for f in fields(self): # type: ignore[arg-type]
|
||||
yield getattr(self, f.name)
|
||||
|
||||
def copy(self):
|
||||
return replace(self)
|
||||
def copy(self) -> dataclass: # type: ignore[valid-type]
|
||||
"""Return a copy of the dataclass."""
|
||||
return replace(self) # type: ignore[type-var]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -87,96 +107,102 @@ class ListenerConfig(Dictable):
|
|||
keyfile: str | Path | None = None
|
||||
"""Full path to file in PEM format containing the server's private key."""
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if (self.certfile is None) ^ (self.keyfile is None):
|
||||
msg = "If specifying the 'certfile' or 'keyfile', both are required."
|
||||
raise ValueError(msg)
|
||||
|
||||
for fn in ('cafile', 'capath', 'certfile', 'keyfile'):
|
||||
for fn in ("cafile", "capath", "certfile", "keyfile"):
|
||||
if isinstance(getattr(self, fn), str):
|
||||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
|
||||
def apply(self, other):
|
||||
def apply(self, other: "ListenerConfig") -> None:
|
||||
"""Apply the field from 'other', if 'self' field is default."""
|
||||
if not isinstance(other, ListenerConfig):
|
||||
msg = f'cannot apply {self.__class__.__name__} to {other.__class__.__name__}'
|
||||
raise TypeError(msg)
|
||||
for f in fields(self):
|
||||
if getattr(self, f.name) == f.default:
|
||||
setattr(self, f.name, other[f.name])
|
||||
|
||||
def default_listeners():
|
||||
def default_listeners() -> dict[str, Any]:
|
||||
"""Create defaults for BrokerConfig.listeners."""
|
||||
return {
|
||||
'default': ListenerConfig()
|
||||
"default": ListenerConfig()
|
||||
}
|
||||
|
||||
def default_broker_plugins():
|
||||
def default_broker_plugins() -> dict[str, Any]:
|
||||
"""Create defaults for BrokerConfig.plugins."""
|
||||
return {
|
||||
'amqtt.plugins.logging_amqtt.EventLoggerPlugin':{},
|
||||
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{},
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin':{'allow_anonymous':True},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin':{'sys_interval':20}
|
||||
"amqtt.plugins.logging_amqtt.EventLoggerPlugin":{},
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{},
|
||||
"amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True},
|
||||
"amqtt.plugins.sys.broker.BrokerSysPlugin":{"sys_interval":20}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrokerConfig(Dictable):
|
||||
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
|
||||
listeners: dict[Literal['default'] | str, ListenerConfig] = field(default_factory=default_listeners)
|
||||
|
||||
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
|
||||
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
|
||||
does not set a value, the 'default' settings are applied. See
|
||||
[ListenerConfig](#amqtt.contexts.ListenerConfig) for more information."""
|
||||
[ListenerConfig](./#amqtt.contexts.ListenerConfig) for more information."""
|
||||
sys_interval: int | None = None
|
||||
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../packaged_plugins.md/#sys-topics)
|
||||
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../packaged_plugins.md/#sys-topics)
|
||||
configuration instead.*"""
|
||||
timeout_disconnect_delay: int | None = 0
|
||||
"""Client disconnect timeout without a keep-alive."""
|
||||
auth: dict[str, Any] | None = None
|
||||
"""*Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`AnonymousAuthPlugin`](#anonymous-auth-plugin) and
|
||||
"""*Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`AnonymousAuthPlugin`](./#anonymous-auth-plugin) and
|
||||
[`FileAuthPlugin`](/packaged_plugins/#password-file-auth-plugin) for more information.*"""
|
||||
topic_check: dict[str, Any] | None = None
|
||||
"""Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
"""Deprecated field used to config EntryPoint-loaded plugins. See
|
||||
[`TopicTabooPlugin`](#taboo-topic-plugin) and
|
||||
[`TopicACLPlugin`](#acl-topic-plugin) for more information.*"""
|
||||
plugins: dict | list | None = field(default_factory=default_broker_plugins)
|
||||
plugins: dict[str, Any] | list[dict[str,Any]] | None = field(default_factory=default_broker_plugins)
|
||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
|
||||
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
||||
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
|
||||
[Plugins](http://localhost:8000/custom_plugins/) for more information."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.sys_interval is not None:
|
||||
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
|
||||
|
||||
if self.auth is not None or self.topic_check is not None:
|
||||
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
|
||||
|
||||
default_listener = self.listeners['default']
|
||||
default_listener = self.listeners["default"]
|
||||
for listener_name, listener in self.listeners.items():
|
||||
if listener_name == 'default':
|
||||
if listener_name == "default":
|
||||
continue
|
||||
listener.apply(default_listener)
|
||||
|
||||
if isinstance(self.plugins, list):
|
||||
_plugins = {}
|
||||
_plugins: dict[str, Any] = {}
|
||||
for plugin in self.plugins:
|
||||
if isinstance(plugin, str):
|
||||
_plugins |= {plugin:{}}
|
||||
# in case a plugin in a yaml file is listed without config
|
||||
if isinstance(plugin, str): # type: ignore[unreachable]
|
||||
_plugins |= {plugin:{}} # type: ignore[unreachable]
|
||||
continue
|
||||
_plugins |= plugin
|
||||
self.plugins = _plugins
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> 'BrokerConfig':
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
|
||||
"""Create a broker config from a dictionary."""
|
||||
if d is None:
|
||||
return BrokerConfig()
|
||||
|
||||
if 'topic-check' in d:
|
||||
d['topic_check'] = d['topic-check']
|
||||
del d['topic-check']
|
||||
# patch the incoming dictionary so it can be loaded correctly
|
||||
if "topic-check" in d:
|
||||
d["topic_check"] = d["topic-check"]
|
||||
del d["topic-check"]
|
||||
|
||||
if ('auth' in d or 'topic-check' in d) and 'plugins' not in d:
|
||||
d['plugins'] = None
|
||||
# identify EntryPoint plugin loading and prevent 'plugins' from getting defaults
|
||||
if ("auth" in d or "topic-check" in d) and "plugins" not in d:
|
||||
d["plugins"] = None
|
||||
|
||||
return dict_to_dataclass(data_class=BrokerConfig,
|
||||
data=d,
|
||||
|
@ -188,6 +214,8 @@ class BrokerConfig(Dictable):
|
|||
|
||||
@dataclass
|
||||
class ConnectionConfig(Dictable):
|
||||
"""Properties for connecting to the broker."""
|
||||
|
||||
uri: str | None = "mqtt://127.0.0.1:1883"
|
||||
"""URI of the broker"""
|
||||
cafile: str | Path | None = None
|
||||
|
@ -206,24 +234,29 @@ class ConnectionConfig(Dictable):
|
|||
"""Full path to file in PEM format containing the client's private key associated with the certfile."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if (self.certfile is None) ^ (self.keyfile is None):
|
||||
msg = "If specifying the 'certfile' or 'keyfile', both are required."
|
||||
raise ValueError(msg)
|
||||
|
||||
for fn in ('cafile', 'capath', 'certfile', 'keyfile'):
|
||||
for fn in ("cafile", "capath", "certfile", "keyfile"):
|
||||
if isinstance(getattr(self, fn), str):
|
||||
setattr(self, fn, Path(getattr(self, fn)))
|
||||
|
||||
@dataclass
|
||||
class TopicConfig(Dictable):
|
||||
"""Configuration of how messages to specific topics are published. The topic name is
|
||||
specified as the key in the dictionary of the `ClientConfig.topics."""
|
||||
"""Configuration of how messages to specific topics are published.
|
||||
|
||||
The topic name is specified as the key in the dictionary of the `ClientConfig.topics.
|
||||
"""
|
||||
|
||||
qos: int = 0
|
||||
"""The quality of service associated with the publishing to this topic."""
|
||||
retain: bool = False
|
||||
"""Determines if the message should be retained by the topic it was published."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
|
||||
msg = "Topic config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
@ -243,19 +276,23 @@ class WillConfig(Dictable):
|
|||
"""Determines if the message should be retained by the topic it was published."""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
|
||||
msg = "Will config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def default_client_plugins():
|
||||
def default_client_plugins() -> dict[str, Any]:
|
||||
"""Create defaults for `ClientConfig.plugins`."""
|
||||
return {
|
||||
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{}
|
||||
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig(Dictable):
|
||||
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
|
||||
|
||||
keep_alive: int | None = 10
|
||||
"""Keep-alive timeout sent to the broker."""
|
||||
ping_delay: int | None = 1
|
||||
|
@ -278,8 +315,8 @@ class ClientConfig(Dictable):
|
|||
"""Specify the topics and what flags should be set for messages published to them."""
|
||||
broker: ConnectionConfig | None = field(default_factory=ConnectionConfig)
|
||||
"""Configuration for connecting to the broker. See
|
||||
[ConnectionConfig](#amqtt.contexts.ConnectionConfig) for more information."""
|
||||
plugins: dict | list | None = field(default_factory=default_client_plugins)
|
||||
[ConnectionConfig](./#amqtt.contexts.ConnectionConfig) for more information."""
|
||||
plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
|
||||
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
|
||||
a dictionary of configuration options for that plugin. See [Plugins](http://localhost:8000/custom_plugins/)
|
||||
for more information."""
|
||||
|
@ -287,15 +324,17 @@ class ClientConfig(Dictable):
|
|||
"""If establishing a secure connection, should the hostname of the certificate be verified."""
|
||||
will: WillConfig | None = None
|
||||
"""Message, topic and flags that should be sent to if the client disconnects. See
|
||||
[WillConfig](#amqtt.contexts.WillConfig)"""
|
||||
[WillConfig](./#amqtt.contexts.WillConfig)"""
|
||||
|
||||
def __post__init__(self) -> None:
|
||||
"""Check config for errors and transform fields for easier use."""
|
||||
if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
|
||||
msg = "Client config: default QoS must be 0, 1 or 2."
|
||||
raise ValueError(msg)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> 'ClientConfig':
|
||||
def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
|
||||
"""Create a client config from a dictionary."""
|
||||
if d is None:
|
||||
return ClientConfig()
|
||||
|
||||
|
@ -305,5 +344,3 @@ class ClientConfig(Dictable):
|
|||
cast=[StrEnum],
|
||||
strict=True)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class BasePlugin(Generic[C]):
|
|||
|
||||
if is_dataclass(self.context.config):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default)
|
||||
if option_name in self.context.config:
|
||||
return self.context.config[option_name]
|
||||
return default
|
||||
|
@ -79,9 +79,10 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
if not self.context.config:
|
||||
return default
|
||||
|
||||
# overloaded context.config with either BrokerConfig or plugin's Config
|
||||
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default)
|
||||
if self.topic_config and option_name in self.topic_config:
|
||||
return self.topic_config[option_name]
|
||||
return default
|
||||
|
@ -112,7 +113,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
|
|||
|
||||
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
|
||||
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
|
||||
return getattr(self.context.config, option_name.replace("-", "_"), default)
|
||||
if self.auth_config and option_name in self.auth_config:
|
||||
return self.auth_config[option_name]
|
||||
return default
|
||||
|
|
|
@ -96,9 +96,9 @@ class PluginManager(Generic[C]):
|
|||
# plugins loaded directly from config dictionary
|
||||
|
||||
|
||||
if 'auth' in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
|
||||
if 'topic-check' in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
|
||||
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
|
||||
|
||||
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
|
||||
|
|
|
@ -1 +1,5 @@
|
|||
{% if 'default_factory' in expression.__str__() %}{{ obj.extra.dataclass_ext.default_factory | safe }}{% else %}{{ expression | safe }}{% endif %}
|
||||
{% if 'default_factory' in expression.__str__() %}
|
||||
{{ obj.extra.dataclass_ext.default_factory | safe }}
|
||||
{% else %}
|
||||
{% extends "_base/expression.html.jinja" %}
|
||||
{% endif %}
|
Ładowanie…
Reference in New Issue