Yakifo/amqtt#258 (in progress): a plugin which requests topic acl via http

pull/262/head
Andrew Mirsky 2025-07-08 22:05:58 -04:00
rodzic 62470645b8
commit 0575f0e041
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 188 dodań i 35 usunięć

Wyświetl plik

@ -681,7 +681,7 @@ class Broker:
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted: if not permitted:
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
else: else:
await self.plugins_manager.fire_event( await self.plugins_manager.fire_event(
BrokerEvents.MESSAGE_RECEIVED, BrokerEvents.MESSAGE_RECEIVED,
@ -892,6 +892,12 @@ class Broker:
for target_session, sub_qos in subscriptions: for target_session, sub_qos in subscriptions:
qos = broadcast.get("qos", sub_qos) qos = broadcast.get("qos", sub_qos)
sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
if not sendable:
self.logger.info(
f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
continue
# Retain all messages which cannot be broadcasted, due to the session not being connected # Retain all messages which cannot be broadcasted, due to the session not being connected
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4] # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
# and, if a client used anonymous authentication, there is no expectation that messages should be retained # and, if a client used anonymous authentication, there is no expectation that messages should be retained

Wyświetl plik

@ -20,3 +20,4 @@ class Action(Enum):
SUBSCRIBE = "subscribe" SUBSCRIBE = "subscribe"
PUBLISH = "publish" PUBLISH = "publish"
RECEIVE = "receive"

Wyświetl plik

@ -97,8 +97,20 @@ class HttpAuthACL(BaseAuthPlugin, BaseTopicPlugin):
async def topic_filtering(self, *, async def topic_filtering(self, *,
session: Session | None = None, session: Session | None = None,
topic: str | None = None, topic: str | None = None,
action: Action | None = None) -> bool: action: Action | None = None) -> bool | None:
return False if not session:
return None
acc = 0
match action:
case Action.PUBLISH:
acc = 2
case Action.SUBSCRIBE:
acc = 4
case Action.RECEIVE:
acc = 1
d = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc}
return await self._send_request(self.get_url(self.config.acl_uri), d)
@dataclass @dataclass
class Config: class Config:
@ -128,7 +140,7 @@ class HttpAuthACL(BaseAuthPlugin, BaseTopicPlugin):
- username *(str)* - username *(str)*
- client_id *(str)* - client_id *(str)*
- topic *(str)* - topic *(str)*
- acc *(int)* read only = 1, write only = 2, read & write = 3 and subscribe = 4 - acc *(int)* client can receive (1), can publish(2), can receive & publish (3) and can subscribe (4)
""" """
host: str host: str

Wyświetl plik

@ -88,7 +88,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
async def topic_filtering( async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool: ) -> bool | None:
"""Logic for filtering out topics. """Logic for filtering out topics.
Args: Args:
@ -97,7 +97,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
action: amqtt.broker.Action action: amqtt.broker.Action
Returns: Returns:
bool: `True` if topic is allowed, `False` otherwise bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined
""" """
return bool(self.topic_config) or is_dataclass(self.context.config) return bool(self.topic_config) or is_dataclass(self.context.config)

Wyświetl plik

@ -1,7 +1,9 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
import warnings
from amqtt.contexts import Action, BaseContext from amqtt.contexts import Action, BaseContext
from amqtt.errors import PluginInitError
from amqtt.plugins.base import BaseTopicPlugin from amqtt.plugins.base import BaseTopicPlugin
from amqtt.session import Session from amqtt.session import Session
@ -13,7 +15,7 @@ class TopicTabooPlugin(BaseTopicPlugin):
async def topic_filtering( async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool: ) -> bool | None:
filter_result = await super().topic_filtering(session=session, topic=topic, action=action) filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
if filter_result: if filter_result:
if session and session.username == "admin": if session and session.username == "admin":
@ -24,6 +26,16 @@ class TopicTabooPlugin(BaseTopicPlugin):
class TopicAccessControlListPlugin(BaseTopicPlugin): class TopicAccessControlListPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
if self._get_config_option("acl", None):
warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1)
if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None):
msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included"
raise PluginInitError(msg)
@staticmethod @staticmethod
def topic_ac(topic_requested: str, topic_allowed: str) -> bool: def topic_ac(topic_requested: str, topic_allowed: str) -> bool:
req_split = topic_requested.split("/") req_split = topic_requested.split("/")
@ -46,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
async def topic_filtering( async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool: ) -> bool | None:
filter_result = await super().topic_filtering(session=session, topic=topic, action=action) filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
if not filter_result: if not filter_result:
return False return False
@ -58,18 +70,26 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
req_topic = topic req_topic = topic
if not req_topic: if not req_topic:
return False\ return False
username = session.username if session else None username = session.username if session else None
if username is None: if username is None:
username = "anonymous" username = "anonymous"
acl: dict[str, Any] = {} acl: dict[str, Any] | None = None
match action: match action:
case Action.PUBLISH: case Action.PUBLISH:
acl = self._get_config_option("publish-acl", {}) acl = self._get_config_option("publish-acl", None)
case Action.SUBSCRIBE: case Action.SUBSCRIBE:
acl = self._get_config_option("acl", {}) acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None))
case Action.RECEIVE:
acl = self._get_config_option("receive-acl", None)
case _:
msg = "Received an invalid action type."
raise ValueError(msg)
if acl is None:
return True
allowed_topics = acl.get(username, []) allowed_topics = acl.get(username, [])
if not allowed_topics: if not allowed_topics:

Wyświetl plik

@ -96,16 +96,31 @@ none
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin` In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`
is used by the aMQTT `Broker` to determine if a connection from a client is allowed by is used by the aMQTT `Broker` to determine if a connection from a client is allowed by
implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise. implementing the `authenticate` method and returning:
- `True` if the session is allowed
- `False` if not allowed
- `None` if plugin can't determine authentication
If there are multiple authentication plugins:
- at least one plugin must return `True` to allow access
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
- `None` gets ignored from the determination
::: amqtt.plugins.base.BaseAuthPlugin ::: amqtt.plugins.base.BaseAuthPlugin
## Topic Filter Plugins ## Topic Filter Plugins
In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin` In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin`
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE) is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH), receive (RECEIVE)
messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or and/or subscribe (SUBSCRIBE) messages to a particular topic by implementing the `topic_filtering` method and returning:
`False` otherwise. - `True` if topic is allowed
- `False` if not allowed
- `None` will be ignored
If there are multiple topic plugins:
- at least one plugin must return `True` to allow access
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
- `None` will be ignored
::: amqtt.plugins.base.BaseTopicPlugin ::: amqtt.plugins.base.BaseTopicPlugin

Wyświetl plik

@ -134,13 +134,18 @@ plugins:
**Configuration** **Configuration**
- `acl` *(mapping)*: determines subscription access - `subscribe-acl` *(mapping)*: determines subscription access. If absent, no restrictions are placed on client subscriptions.
The list should be a key-value pair, where: The list should be a key-value pair, where:
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). `<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
- `acl` *(mapping)*: Deprecated and replaced by `subscribe-acl`.
- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing. - `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing.
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). `<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
- `receive-acl` *(mapping)*: determines if a message can be sent. If absent, no restrictions are placed on client's receiving messages.
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
!!! info "Reserved usernames" !!! info "Reserved usernames"

Wyświetl plik

@ -9,13 +9,14 @@ from aiohttp import web
from aiohttp.web import Response, Request from aiohttp.web import Response, Request
from amqtt.broker import BrokerContext, Broker from amqtt.broker import BrokerContext, Broker
from amqtt.contexts import Action
from amqtt.contrib.http_acl import HttpAuthACL, ParamsMode, ResponseMode, RequestMethod from amqtt.contrib.http_acl import HttpAuthACL, ParamsMode, ResponseMode, RequestMethod
from amqtt.session import Session from amqtt.session import Session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def determine_response_mode(d) -> Response: def determine_auth_response_mode(d) -> Response:
assert 'username' in d assert 'username' in d
assert 'password' in d assert 'password' in d
assert 'client_id' in d assert 'client_id' in d
@ -31,29 +32,29 @@ class JsonAuthView(web.View):
async def get(self) -> Response: async def get(self) -> Response:
d = await self.request.json() d = await self.request.json()
return determine_response_mode(d) return determine_auth_response_mode(d)
async def post(self) -> Response: async def post(self) -> Response:
d = dict(await self.request.json()) d = dict(await self.request.json())
return determine_response_mode(d) return determine_auth_response_mode(d)
async def put(self) -> Response: async def put(self) -> Response:
d = dict(await self.request.json()) d = dict(await self.request.json())
return determine_response_mode(d) return determine_auth_response_mode(d)
class FormAuthView(web.View): class FormAuthView(web.View):
async def get(self) -> Response: async def get(self) -> Response:
d = self.request.query d = self.request.query
return determine_response_mode(d) return determine_auth_response_mode(d)
async def post(self) -> Response: async def post(self) -> Response:
d = dict(await self.request.post()) d = dict(await self.request.post())
return determine_response_mode(d) return determine_auth_response_mode(d)
async def put(self) -> Response: async def put(self) -> Response:
d = dict(await self.request.post()) d = dict(await self.request.post())
return determine_response_mode(d) return determine_auth_response_mode(d)
@pytest.fixture @pytest.fixture
@ -64,7 +65,7 @@ async def empty_broker():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def http_acl_server(): async def http_auth_server():
app = web.Application() app = web.Application()
app.add_routes([ app.add_routes([
web.view('/user/json', JsonAuthView), web.view('/user/json', JsonAuthView),
@ -80,10 +81,10 @@ async def http_acl_server():
await runner.cleanup() await runner.cleanup()
def test_server_up_and_down(http_acl_server): def test_server_up_and_down(http_auth_server):
pass pass
def generate_cases(): def generate_use_cases(root_url):
# generate all variations of: # generate all variations of:
# ('/user/json', RequestMethod.GET, ParamsMode.JSON, ResponseMode.JSON, 'json', 'json', True), # ('/user/json', RequestMethod.GET, ParamsMode.JSON, ResponseMode.JSON, 'json', 'json', True),
@ -91,22 +92,22 @@ def generate_cases():
for request in RequestMethod: for request in RequestMethod:
for params in ParamsMode: for params in ParamsMode:
for response in ResponseMode: for response in ResponseMode:
url = '/user/json' if params == ParamsMode.JSON else '/user/form' url = f'/{root_url}/json' if params == ParamsMode.JSON else f'/{root_url}/form'
for is_authenticated in [True, False]: for is_authenticated in [True, False]:
prefix = '' if is_authenticated else 'not' prefix = '' if is_authenticated else 'not'
case = (url, request, params, response, str(response), f"{prefix}{str(response)}", is_authenticated) case = (url, request, params, response, str(response), f"{prefix}{str(response)}", is_authenticated)
cases.append(case) cases.append(case)
return cases return cases
def test_generated_cases(): def test_generated_use_cases():
cases = generate_cases() cases = generate_use_cases('user')
assert len(cases) == 36 assert len(cases) == 36
@pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,password,is_authenticated", @pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,password,is_authenticated",
generate_cases()) generate_use_cases('user'))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_params_response(empty_broker, http_acl_server, url, async def test_request_auth_response(empty_broker, http_auth_server, url,
request_method, params_mode, response_mode, request_method, params_mode, response_mode,
username, password, is_authenticated): username, password, is_authenticated):
@ -129,3 +130,96 @@ async def test_request_params_response(empty_broker, http_acl_server, url,
assert await http_acl.authenticate(session=session) == is_authenticated assert await http_acl.authenticate(session=session) == is_authenticated
await http_acl.on_broker_pre_shutdown() await http_acl.on_broker_pre_shutdown()
def determine_acl_response(d) -> Response:
assert 'username' in d
assert 'client_id' in d
assert 'topic' in d
assert 'acc' in d
if d['username'] == 'json':
return web.json_response({'Ok': d['username'] == d['client_id']})
elif d['username'] == 'status':
return web.Response(status=200) if d['username'] == d['client_id'] else web.Response(status=400)
else: # text
return web.Response(text='ok' if d['username'] == d['client_id'] else 'error')
class JsonACLView(web.View):
async def get(self) -> Response:
d = await self.request.json()
return determine_acl_response(d)
async def post(self) -> Response:
d = dict(await self.request.json())
return determine_acl_response(d)
async def put(self) -> Response:
d = dict(await self.request.json())
return determine_acl_response(d)
class FormACLView(web.View):
async def get(self) -> Response:
d = self.request.query
return determine_acl_response(d)
async def post(self) -> Response:
d = dict(await self.request.post())
return determine_acl_response(d)
async def put(self) -> Response:
d = dict(await self.request.post())
return determine_acl_response(d)
@pytest_asyncio.fixture
async def http_acl_server():
app = web.Application()
app.add_routes([
web.view('/acl/json', JsonACLView),
web.view('/acl/form', FormACLView),
])
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "localhost", 8080)
await site.start()
yield f"http://localhost:8080"
await runner.cleanup()
@pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,client_id,is_authenticated",
generate_use_cases('acl'))
@pytest.mark.asyncio
async def test_request_acl_response(empty_broker, http_acl_server, url,
request_method, params_mode, response_mode,
username, client_id, is_authenticated):
# url = '/acl/json'
# request_method = RequestMethod.GET
# params_mode = ParamsMode.JSON
# response_mode = ResponseMode.JSON
context = BrokerContext(broker=empty_broker)
context.config = HttpAuthACL.Config(
host="localhost",
port=8080,
user_uri='/user',
acl_uri=url,
request_method=request_method,
params_mode=params_mode,
response_mode=response_mode,
)
http_acl = HttpAuthACL(context)
s = Session()
s.username = username
s.client_id = client_id
t = 'my/topic'
a = Action.PUBLISH
assert await http_acl.topic_filtering(session=s, topic=t, action=a) == is_authenticated