kopia lustrzana https://github.com/Yakifo/amqtt
Yakifo/amqtt#258 (in progress): a plugin which requests topic acl via http
rodzic
62470645b8
commit
0575f0e041
|
@ -681,7 +681,7 @@ class Broker:
|
|||
|
||||
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
||||
if not permitted:
|
||||
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
|
||||
self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
|
||||
else:
|
||||
await self.plugins_manager.fire_event(
|
||||
BrokerEvents.MESSAGE_RECEIVED,
|
||||
|
@ -892,6 +892,12 @@ class Broker:
|
|||
for target_session, sub_qos in subscriptions:
|
||||
qos = broadcast.get("qos", sub_qos)
|
||||
|
||||
sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
|
||||
if not sendable:
|
||||
self.logger.info(
|
||||
f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
|
||||
continue
|
||||
|
||||
# Retain all messages which cannot be broadcasted, due to the session not being connected
|
||||
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
|
||||
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
|
||||
|
|
|
@ -20,3 +20,4 @@ class Action(Enum):
|
|||
|
||||
SUBSCRIBE = "subscribe"
|
||||
PUBLISH = "publish"
|
||||
RECEIVE = "receive"
|
||||
|
|
|
@ -97,8 +97,20 @@ class HttpAuthACL(BaseAuthPlugin, BaseTopicPlugin):
|
|||
async def topic_filtering(self, *,
|
||||
session: Session | None = None,
|
||||
topic: str | None = None,
|
||||
action: Action | None = None) -> bool:
|
||||
return False
|
||||
action: Action | None = None) -> bool | None:
|
||||
if not session:
|
||||
return None
|
||||
acc = 0
|
||||
match action:
|
||||
case Action.PUBLISH:
|
||||
acc = 2
|
||||
case Action.SUBSCRIBE:
|
||||
acc = 4
|
||||
case Action.RECEIVE:
|
||||
acc = 1
|
||||
|
||||
d = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc}
|
||||
return await self._send_request(self.get_url(self.config.acl_uri), d)
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
|
@ -128,7 +140,7 @@ class HttpAuthACL(BaseAuthPlugin, BaseTopicPlugin):
|
|||
- username *(str)*
|
||||
- client_id *(str)*
|
||||
- topic *(str)*
|
||||
- acc *(int)* read only = 1, write only = 2, read & write = 3 and subscribe = 4
|
||||
- acc *(int)* client can receive (1), can publish(2), can receive & publish (3) and can subscribe (4)
|
||||
"""
|
||||
|
||||
host: str
|
||||
|
|
|
@ -88,7 +88,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
) -> bool | None:
|
||||
"""Logic for filtering out topics.
|
||||
|
||||
Args:
|
||||
|
@ -97,7 +97,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
|
|||
action: amqtt.broker.Action
|
||||
|
||||
Returns:
|
||||
bool: `True` if topic is allowed, `False` otherwise
|
||||
bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined
|
||||
|
||||
"""
|
||||
return bool(self.topic_config) or is_dataclass(self.context.config)
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from amqtt.contexts import Action, BaseContext
|
||||
from amqtt.errors import PluginInitError
|
||||
from amqtt.plugins.base import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
|
@ -13,7 +15,7 @@ class TopicTabooPlugin(BaseTopicPlugin):
|
|||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
) -> bool | None:
|
||||
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
|
||||
if filter_result:
|
||||
if session and session.username == "admin":
|
||||
|
@ -24,6 +26,16 @@ class TopicTabooPlugin(BaseTopicPlugin):
|
|||
|
||||
class TopicAccessControlListPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
if self._get_config_option("acl", None):
|
||||
warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1)
|
||||
|
||||
if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None):
|
||||
msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included"
|
||||
raise PluginInitError(msg)
|
||||
|
||||
@staticmethod
|
||||
def topic_ac(topic_requested: str, topic_allowed: str) -> bool:
|
||||
req_split = topic_requested.split("/")
|
||||
|
@ -46,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
|||
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
) -> bool | None:
|
||||
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
|
||||
if not filter_result:
|
||||
return False
|
||||
|
@ -58,18 +70,26 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
|
|||
|
||||
req_topic = topic
|
||||
if not req_topic:
|
||||
return False\
|
||||
return False
|
||||
|
||||
username = session.username if session else None
|
||||
if username is None:
|
||||
username = "anonymous"
|
||||
|
||||
acl: dict[str, Any] = {}
|
||||
acl: dict[str, Any] | None = None
|
||||
match action:
|
||||
case Action.PUBLISH:
|
||||
acl = self._get_config_option("publish-acl", {})
|
||||
acl = self._get_config_option("publish-acl", None)
|
||||
case Action.SUBSCRIBE:
|
||||
acl = self._get_config_option("acl", {})
|
||||
acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None))
|
||||
case Action.RECEIVE:
|
||||
acl = self._get_config_option("receive-acl", None)
|
||||
case _:
|
||||
msg = "Received an invalid action type."
|
||||
raise ValueError(msg)
|
||||
|
||||
if acl is None:
|
||||
return True
|
||||
|
||||
allowed_topics = acl.get(username, [])
|
||||
if not allowed_topics:
|
||||
|
|
|
@ -96,16 +96,31 @@ none
|
|||
|
||||
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`
|
||||
is used by the aMQTT `Broker` to determine if a connection from a client is allowed by
|
||||
implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise.
|
||||
implementing the `authenticate` method and returning:
|
||||
- `True` if the session is allowed
|
||||
- `False` if not allowed
|
||||
- `None` if plugin can't determine authentication
|
||||
|
||||
If there are multiple authentication plugins:
|
||||
- at least one plugin must return `True` to allow access
|
||||
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
|
||||
- `None` gets ignored from the determination
|
||||
|
||||
::: amqtt.plugins.base.BaseAuthPlugin
|
||||
|
||||
## Topic Filter Plugins
|
||||
|
||||
In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin`
|
||||
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE)
|
||||
messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or
|
||||
`False` otherwise.
|
||||
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH), receive (RECEIVE)
|
||||
and/or subscribe (SUBSCRIBE) messages to a particular topic by implementing the `topic_filtering` method and returning:
|
||||
- `True` if topic is allowed
|
||||
- `False` if not allowed
|
||||
- `None` will be ignored
|
||||
|
||||
If there are multiple topic plugins:
|
||||
- at least one plugin must return `True` to allow access
|
||||
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
|
||||
- `None` will be ignored
|
||||
|
||||
::: amqtt.plugins.base.BaseTopicPlugin
|
||||
|
||||
|
|
|
@ -134,13 +134,18 @@ plugins:
|
|||
|
||||
**Configuration**
|
||||
|
||||
- `acl` *(mapping)*: determines subscription access
|
||||
- `subscribe-acl` *(mapping)*: determines subscription access. If absent, no restrictions are placed on client subscriptions.
|
||||
The list should be a key-value pair, where:
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
- `acl` *(mapping)*: Deprecated and replaced by `subscribe-acl`.
|
||||
|
||||
- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing.
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing.
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
- `receive-acl` *(mapping)*: determines if a message can be sent. If absent, no restrictions are placed on client's receiving messages.
|
||||
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
|
||||
|
||||
|
||||
!!! info "Reserved usernames"
|
||||
|
||||
|
|
|
@ -9,13 +9,14 @@ from aiohttp import web
|
|||
from aiohttp.web import Response, Request
|
||||
|
||||
from amqtt.broker import BrokerContext, Broker
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.contrib.http_acl import HttpAuthACL, ParamsMode, ResponseMode, RequestMethod
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def determine_response_mode(d) -> Response:
|
||||
def determine_auth_response_mode(d) -> Response:
|
||||
assert 'username' in d
|
||||
assert 'password' in d
|
||||
assert 'client_id' in d
|
||||
|
@ -31,29 +32,29 @@ class JsonAuthView(web.View):
|
|||
|
||||
async def get(self) -> Response:
|
||||
d = await self.request.json()
|
||||
return determine_response_mode(d)
|
||||
return determine_auth_response_mode(d)
|
||||
|
||||
async def post(self) -> Response:
|
||||
d = dict(await self.request.json())
|
||||
return determine_response_mode(d)
|
||||
return determine_auth_response_mode(d)
|
||||
|
||||
async def put(self) -> Response:
|
||||
d = dict(await self.request.json())
|
||||
return determine_response_mode(d)
|
||||
return determine_auth_response_mode(d)
|
||||
|
||||
class FormAuthView(web.View):
|
||||
|
||||
async def get(self) -> Response:
|
||||
d = self.request.query
|
||||
return determine_response_mode(d)
|
||||
return determine_auth_response_mode(d)
|
||||
|
||||
async def post(self) -> Response:
|
||||
d = dict(await self.request.post())
|
||||
return determine_response_mode(d)
|
||||
return determine_auth_response_mode(d)
|
||||
|
||||
async def put(self) -> Response:
|
||||
d = dict(await self.request.post())
|
||||
return determine_response_mode(d)
|
||||
return determine_auth_response_mode(d)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -64,7 +65,7 @@ async def empty_broker():
|
|||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def http_acl_server():
|
||||
async def http_auth_server():
|
||||
app = web.Application()
|
||||
app.add_routes([
|
||||
web.view('/user/json', JsonAuthView),
|
||||
|
@ -80,10 +81,10 @@ async def http_acl_server():
|
|||
await runner.cleanup()
|
||||
|
||||
|
||||
def test_server_up_and_down(http_acl_server):
|
||||
def test_server_up_and_down(http_auth_server):
|
||||
pass
|
||||
|
||||
def generate_cases():
|
||||
def generate_use_cases(root_url):
|
||||
# generate all variations of:
|
||||
# ('/user/json', RequestMethod.GET, ParamsMode.JSON, ResponseMode.JSON, 'json', 'json', True),
|
||||
|
||||
|
@ -91,22 +92,22 @@ def generate_cases():
|
|||
for request in RequestMethod:
|
||||
for params in ParamsMode:
|
||||
for response in ResponseMode:
|
||||
url = '/user/json' if params == ParamsMode.JSON else '/user/form'
|
||||
url = f'/{root_url}/json' if params == ParamsMode.JSON else f'/{root_url}/form'
|
||||
for is_authenticated in [True, False]:
|
||||
prefix = '' if is_authenticated else 'not'
|
||||
case = (url, request, params, response, str(response), f"{prefix}{str(response)}", is_authenticated)
|
||||
cases.append(case)
|
||||
return cases
|
||||
|
||||
def test_generated_cases():
|
||||
cases = generate_cases()
|
||||
def test_generated_use_cases():
|
||||
cases = generate_use_cases('user')
|
||||
assert len(cases) == 36
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,password,is_authenticated",
|
||||
generate_cases())
|
||||
generate_use_cases('user'))
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_params_response(empty_broker, http_acl_server, url,
|
||||
async def test_request_auth_response(empty_broker, http_auth_server, url,
|
||||
request_method, params_mode, response_mode,
|
||||
username, password, is_authenticated):
|
||||
|
||||
|
@ -129,3 +130,96 @@ async def test_request_params_response(empty_broker, http_acl_server, url,
|
|||
assert await http_acl.authenticate(session=session) == is_authenticated
|
||||
|
||||
await http_acl.on_broker_pre_shutdown()
|
||||
|
||||
|
||||
def determine_acl_response(d) -> Response:
|
||||
assert 'username' in d
|
||||
assert 'client_id' in d
|
||||
assert 'topic' in d
|
||||
assert 'acc' in d
|
||||
if d['username'] == 'json':
|
||||
return web.json_response({'Ok': d['username'] == d['client_id']})
|
||||
elif d['username'] == 'status':
|
||||
return web.Response(status=200) if d['username'] == d['client_id'] else web.Response(status=400)
|
||||
else: # text
|
||||
return web.Response(text='ok' if d['username'] == d['client_id'] else 'error')
|
||||
|
||||
|
||||
class JsonACLView(web.View):
|
||||
|
||||
async def get(self) -> Response:
|
||||
d = await self.request.json()
|
||||
return determine_acl_response(d)
|
||||
|
||||
async def post(self) -> Response:
|
||||
d = dict(await self.request.json())
|
||||
return determine_acl_response(d)
|
||||
|
||||
async def put(self) -> Response:
|
||||
d = dict(await self.request.json())
|
||||
return determine_acl_response(d)
|
||||
|
||||
|
||||
class FormACLView(web.View):
|
||||
|
||||
async def get(self) -> Response:
|
||||
d = self.request.query
|
||||
return determine_acl_response(d)
|
||||
|
||||
async def post(self) -> Response:
|
||||
d = dict(await self.request.post())
|
||||
return determine_acl_response(d)
|
||||
|
||||
async def put(self) -> Response:
|
||||
d = dict(await self.request.post())
|
||||
return determine_acl_response(d)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def http_acl_server():
|
||||
app = web.Application()
|
||||
app.add_routes([
|
||||
web.view('/acl/json', JsonACLView),
|
||||
web.view('/acl/form', FormACLView),
|
||||
])
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "localhost", 8080)
|
||||
await site.start()
|
||||
|
||||
yield f"http://localhost:8080"
|
||||
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,client_id,is_authenticated",
|
||||
generate_use_cases('acl'))
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_acl_response(empty_broker, http_acl_server, url,
|
||||
request_method, params_mode, response_mode,
|
||||
username, client_id, is_authenticated):
|
||||
|
||||
# url = '/acl/json'
|
||||
# request_method = RequestMethod.GET
|
||||
# params_mode = ParamsMode.JSON
|
||||
# response_mode = ResponseMode.JSON
|
||||
|
||||
context = BrokerContext(broker=empty_broker)
|
||||
context.config = HttpAuthACL.Config(
|
||||
host="localhost",
|
||||
port=8080,
|
||||
user_uri='/user',
|
||||
acl_uri=url,
|
||||
request_method=request_method,
|
||||
params_mode=params_mode,
|
||||
response_mode=response_mode,
|
||||
)
|
||||
http_acl = HttpAuthACL(context)
|
||||
|
||||
s = Session()
|
||||
s.username = username
|
||||
s.client_id = client_id
|
||||
t = 'my/topic'
|
||||
a = Action.PUBLISH
|
||||
|
||||
assert await http_acl.topic_filtering(session=s, topic=t, action=a) == is_authenticated
|
Ładowanie…
Reference in New Issue