diff --git a/datasette/app.py b/datasette/app.py index 8f69ee98..1363bc5c 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,3 +1,4 @@ +from asgi_csrf import Errors import asyncio from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import asgi_csrf @@ -1657,6 +1658,16 @@ class Datasette: if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) + async def custom_csrf_error(scope, send, message_id): + await asgi_send( + send, + await self.render_template( + "csrf_error.html", + {"message_id": message_id, "message_name": Errors(message_id).name}, + ), + 403, + ) + asgi = asgi_csrf.asgi_csrf( DatasetteRouter(self, routes), signing_secret=self._secret, @@ -1664,6 +1675,7 @@ class Datasette: skip_if_scope=lambda scope: any( pm.hook.skip_csrf(datasette=self, scope=scope) ), + send_csrf_failed=custom_csrf_error, ) if self.setting("trace_debug"): asgi = AsgiTracer(asgi) diff --git a/setup.py b/setup.py index c69404f8..923bc826 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ setup( "uvicorn>=0.11", "aiofiles>=0.4", "janus>=0.6.2", - "asgi-csrf>=0.9", + "asgi-csrf>=0.10", "PyYAML>=5.3", "mergedeep>=1.1.1", "itsdangerous>=1.1", diff --git a/tests/test_html.py b/tests/test_html.py index d648bdf0..c559f0c2 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -1,3 +1,4 @@ +from asgi_csrf import Errors from bs4 import BeautifulSoup as Soup from datasette.app import Datasette from datasette.utils import allowed_pragmas @@ -1158,3 +1159,16 @@ async def test_database_color(ds_client): pdb.set_trace() assert any(fragment in response.text for fragment in expected_fragments) + + +@pytest.mark.asyncio +async def test_custom_csrf_error(ds_client): + response = await ds_client.post( + "/-/messages", + data={ + "message": "A message", + }, + cookies={"csrftoken": "x"}, + ) + assert response.status_code == 403 + assert "Error code is FORM_URLENCODED_MISMATCH." in response.text