Merge pull request #249 from ajmirsky/plugin_config_casting

improve static type checking for plugin's `Config` class

( no functionality change, mypy checking improvement and comments only )
pull/254/head
Andrew Mirsky 2025-07-04 17:10:31 -04:00 zatwierdzone przez GitHub
commit 701b21272c
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
1 zmienionych plików z 30 dodań i 5 usunięć

Wyświetl plik

@ -1,5 +1,5 @@
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar, cast
from amqtt.contexts import Action, BaseContext from amqtt.contexts import Action, BaseContext
from amqtt.session import Session from amqtt.session import Session
@ -8,28 +8,51 @@ C = TypeVar("C", bound=BaseContext)
class BasePlugin(Generic[C]): class BasePlugin(Generic[C]):
"""The base from which all plugins should inherit.""" """The base from which all plugins should inherit.
Type Parameters
---------------
C:
A BaseContext: either BrokerContext or ClientContext, depending on plugin usage
Attributes
----------
context (C):
Information about the environment in which this plugin is executed. Modifying
the broker or client state should happen through methods available here.
config (self.Config):
An instance of the Config dataclass defined by the plugin (or an empty dataclass, if not
defined). If using entrypoint- or mixed-style configuration, use `_get_config_option()`
to access the variable.
"""
def __init__(self, context: C) -> None: def __init__(self, context: C) -> None:
self.context: C = context self.context: C = context
# since the PluginManager will hydrate the config from a plugin's `Config` class, this is a safe cast
self.config = cast("self.Config", context.config) # type: ignore[name-defined]
# Deprecated: included to support entrypoint-style configs. Replaced by dataclass Config class.
def _get_config_section(self, name: str) -> dict[str, Any] | None: def _get_config_section(self, name: str) -> dict[str, Any] | None:
if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None): if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
return None return None
section_config: int | dict[str, Any] | None = self.context.config.get(name, None) section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
# mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check # mypy has difficulty excluding int from `config`'s type, unless there's an explicit check
if isinstance(section_config, int): if isinstance(section_config, int):
return None return None
return section_config return section_config
# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
def _get_config_option(self, option_name: str, default: Any=None) -> Any: def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config: if not self.context.config:
return default return default
if is_dataclass(self.context.config): if is_dataclass(self.context.config):
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if option_name in self.context.config: if option_name in self.context.config:
return self.context.config[option_name] return self.context.config[option_name]
return default return default
@ -57,7 +80,8 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
return default return default
if is_dataclass(self.context.config): if is_dataclass(self.context.config):
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.topic_config and option_name in self.topic_config: if self.topic_config and option_name in self.topic_config:
return self.topic_config[option_name] return self.topic_config[option_name]
return default return default
@ -87,6 +111,7 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
return default return default
if is_dataclass(self.context.config): if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.auth_config and option_name in self.auth_config: if self.auth_config and option_name in self.auth_config:
return self.auth_config[option_name] return self.auth_config[option_name]