Custom error on CSRF failures, closes #2390

Uses https://github.com/simonw/asgi-csrf/issues/28
pull/2395/head
Simon Willison 2024-08-14 21:29:16 -07:00
rodzic 93067668fe
commit 06d4ffb92e
3 zmienionych plików z 27 dodań i 1 usunięć

Wyświetl plik

@ -1,3 +1,4 @@
from asgi_csrf import Errors
import asyncio import asyncio
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import asgi_csrf import asgi_csrf
@ -1657,6 +1658,16 @@ class Datasette:
if not database.is_mutable: if not database.is_mutable:
await database.table_counts(limit=60 * 60 * 1000) await database.table_counts(limit=60 * 60 * 1000)
async def custom_csrf_error(scope, send, message_id):
await asgi_send(
send,
await self.render_template(
"csrf_error.html",
{"message_id": message_id, "message_name": Errors(message_id).name},
),
403,
)
asgi = asgi_csrf.asgi_csrf( asgi = asgi_csrf.asgi_csrf(
DatasetteRouter(self, routes), DatasetteRouter(self, routes),
signing_secret=self._secret, signing_secret=self._secret,
@ -1664,6 +1675,7 @@ class Datasette:
skip_if_scope=lambda scope: any( skip_if_scope=lambda scope: any(
pm.hook.skip_csrf(datasette=self, scope=scope) pm.hook.skip_csrf(datasette=self, scope=scope)
), ),
send_csrf_failed=custom_csrf_error,
) )
if self.setting("trace_debug"): if self.setting("trace_debug"):
asgi = AsgiTracer(asgi) asgi = AsgiTracer(asgi)

Wyświetl plik

@ -55,7 +55,7 @@ setup(
"uvicorn>=0.11", "uvicorn>=0.11",
"aiofiles>=0.4", "aiofiles>=0.4",
"janus>=0.6.2", "janus>=0.6.2",
"asgi-csrf>=0.9", "asgi-csrf>=0.10",
"PyYAML>=5.3", "PyYAML>=5.3",
"mergedeep>=1.1.1", "mergedeep>=1.1.1",
"itsdangerous>=1.1", "itsdangerous>=1.1",

Wyświetl plik

@ -1,3 +1,4 @@
from asgi_csrf import Errors
from bs4 import BeautifulSoup as Soup from bs4 import BeautifulSoup as Soup
from datasette.app import Datasette from datasette.app import Datasette
from datasette.utils import allowed_pragmas from datasette.utils import allowed_pragmas
@ -1158,3 +1159,16 @@ async def test_database_color(ds_client):
pdb.set_trace() pdb.set_trace()
assert any(fragment in response.text for fragment in expected_fragments) assert any(fragment in response.text for fragment in expected_fragments)
@pytest.mark.asyncio
async def test_custom_csrf_error(ds_client):
response = await ds_client.post(
"/-/messages",
data={
"message": "A message",
},
cookies={"csrftoken": "x"},
)
assert response.status_code == 403
assert "Error code is FORM_URLENCODED_MISMATCH." in response.text