From dda99fc09fb0b5523948f6d481c6c051c1c7b5de Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 25 May 2023 17:18:43 -0700 Subject: [PATCH] New View base class (#2080) * New View base class, closes #2078 * Use new View subclass for PatternPortfolioView --- datasette/app.py | 40 +++++++++++++++++- datasette/views/base.py | 37 +++++++++++++++++ datasette/views/special.py | 19 +++++---- tests/test_base_view.py | 84 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 170 insertions(+), 10 deletions(-) create mode 100644 tests/test_base_view.py diff --git a/datasette/app.py b/datasette/app.py index d7dace67..1f80c5a9 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -17,6 +17,7 @@ import secrets import sys import threading import time +import types import urllib.parse from concurrent import futures from pathlib import Path @@ -1361,7 +1362,7 @@ class Datasette: r"/-/allow-debug$", ) add_route( - PatternPortfolioView.as_view(self), + wrap_view(PatternPortfolioView, self), r"/-/patterns$", ) add_route(DatabaseDownload.as_view(self), r"/(?P[^\/\.]+)\.db$") @@ -1673,7 +1674,42 @@ def _cleaner_task_str(task): return _cleaner_task_str_re.sub("", s) -def wrap_view(view_fn, datasette): +def wrap_view(view_fn_or_class, datasette): + is_function = isinstance(view_fn_or_class, types.FunctionType) + if is_function: + return wrap_view_function(view_fn_or_class, datasette) + else: + if not isinstance(view_fn_or_class, type): + raise ValueError("view_fn_or_class must be a function or a class") + return wrap_view_class(view_fn_or_class, datasette) + + +def wrap_view_class(view_class, datasette): + async def async_view_for_class(request, send): + instance = view_class() + if inspect.iscoroutinefunction(instance.__call__): + return await async_call_with_supported_arguments( + instance.__call__, + scope=request.scope, + receive=request.receive, + send=send, + request=request, + datasette=datasette, + ) + else: + return call_with_supported_arguments( + instance.__call__, + scope=request.scope, + receive=request.receive, + send=send, + request=request, + datasette=datasette, + ) + + return async_view_for_class + + +def wrap_view_function(view_fn, datasette): @functools.wraps(view_fn) async def async_view_fn(request, send): if inspect.iscoroutinefunction(view_fn): diff --git a/datasette/views/base.py b/datasette/views/base.py index 927d1aff..94645cd8 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -53,6 +53,43 @@ class DatasetteError(Exception): self.message_is_html = message_is_html +class View: + async def head(self, request, datasette): + if not hasattr(self, "get"): + return await self.method_not_allowed(request) + response = await self.get(request, datasette) + response.body = "" + return response + + async def method_not_allowed(self, request): + if ( + request.path.endswith(".json") + or request.headers.get("content-type") == "application/json" + ): + response = Response.json( + {"ok": False, "error": "Method not allowed"}, status=405 + ) + else: + response = Response.text("Method not allowed", status=405) + return response + + async def options(self, request, datasette): + response = Response.text("ok") + response.headers["allow"] = ", ".join( + method.upper() + for method in ("head", "get", "post", "put", "patch", "delete") + if hasattr(self, method) + ) + return response + + async def __call__(self, request, datasette): + try: + handler = getattr(self, request.method.lower()) + except AttributeError: + return await self.method_not_allowed(request) + return await handler(request, datasette) + + class BaseView: ds = None has_json_alternate = True diff --git a/datasette/views/special.py b/datasette/views/special.py index 1aeb4be6..03e085d6 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -6,7 +6,7 @@ from datasette.utils import ( tilde_encode, tilde_decode, ) -from .base import BaseView +from .base import BaseView, View import secrets import urllib @@ -57,13 +57,16 @@ class JsonDataView(BaseView): ) -class PatternPortfolioView(BaseView): - name = "patterns" - has_json_alternate = False - - async def get(self, request): - await self.ds.ensure_permissions(request.actor, ["view-instance"]) - return await self.render(["patterns.html"], request=request) +class PatternPortfolioView(View): + async def get(self, request, datasette): + await datasette.ensure_permissions(request.actor, ["view-instance"]) + return Response.html( + await datasette.render_template( + "patterns.html", + request=request, + view_name="patterns", + ) + ) class AuthTokenView(BaseView): diff --git a/tests/test_base_view.py b/tests/test_base_view.py new file mode 100644 index 00000000..2cd4d601 --- /dev/null +++ b/tests/test_base_view.py @@ -0,0 +1,84 @@ +from datasette.views.base import View +from datasette import Request, Response +from datasette.app import Datasette +import json +import pytest + + +class GetView(View): + async def get(self, request, datasette): + return Response.json( + { + "absolute_url": datasette.absolute_url(request, "/"), + "request_path": request.path, + } + ) + + +class GetAndPostView(GetView): + async def post(self, request, datasette): + return Response.json( + { + "method": request.method, + "absolute_url": datasette.absolute_url(request, "/"), + "request_path": request.path, + } + ) + + +@pytest.mark.asyncio +async def test_get_view(): + v = GetView() + datasette = Datasette() + response = await v(Request.fake("/foo"), datasette) + assert json.loads(response.body) == { + "absolute_url": "http://localhost/", + "request_path": "/foo", + } + # Try a HEAD request + head_response = await v(Request.fake("/foo", method="HEAD"), datasette) + assert head_response.body == "" + assert head_response.status == 200 + # And OPTIONS + options_response = await v(Request.fake("/foo", method="OPTIONS"), datasette) + assert options_response.body == "ok" + assert options_response.status == 200 + assert options_response.headers["allow"] == "HEAD, GET" + # And POST + post_response = await v(Request.fake("/foo", method="POST"), datasette) + assert post_response.body == "Method not allowed" + assert post_response.status == 405 + # And POST with .json extension + post_json_response = await v(Request.fake("/foo.json", method="POST"), datasette) + assert json.loads(post_json_response.body) == { + "ok": False, + "error": "Method not allowed", + } + assert post_json_response.status == 405 + + +@pytest.mark.asyncio +async def test_post_view(): + v = GetAndPostView() + datasette = Datasette() + response = await v(Request.fake("/foo"), datasette) + assert json.loads(response.body) == { + "absolute_url": "http://localhost/", + "request_path": "/foo", + } + # Try a HEAD request + head_response = await v(Request.fake("/foo", method="HEAD"), datasette) + assert head_response.body == "" + assert head_response.status == 200 + # And OPTIONS + options_response = await v(Request.fake("/foo", method="OPTIONS"), datasette) + assert options_response.body == "ok" + assert options_response.status == 200 + assert options_response.headers["allow"] == "HEAD, GET, POST" + # And POST + post_response = await v(Request.fake("/foo", method="POST"), datasette) + assert json.loads(post_response.body) == { + "method": "POST", + "absolute_url": "http://localhost/", + "request_path": "/foo", + }