diff --git a/datasette/app.py b/datasette/app.py index 2ef7da41..4a8ead1d 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,11 +1,9 @@ import asyncio import collections import hashlib -import json import os import sys import threading -import time import traceback import urllib.parse from concurrent import futures @@ -14,10 +12,8 @@ from pathlib import Path import click from markupsafe import Markup from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader -from sanic import Sanic, response -from sanic.exceptions import InvalidUsage, NotFound -from .views.base import DatasetteError, ureg +from .views.base import DatasetteError, ureg, AsgiRouter from .views.database import DatabaseDownload, DatabaseView from .views.index import IndexView from .views.special import JsonDataView @@ -36,7 +32,16 @@ from .utils import ( sqlite_timelimit, to_css_class, ) -from .tracer import capture_traces, trace +from .utils.asgi import ( + AsgiLifespan, + NotFound, + asgi_static, + asgi_send, + asgi_send_html, + asgi_send_json, + asgi_send_redirect, +) +from .tracer import trace, AsgiTracer from .plugins import pm, DEFAULT_PLUGINS from .version import __version__ @@ -126,8 +131,8 @@ CONFIG_OPTIONS = ( DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS} -async def favicon(request): - return response.text("") +async def favicon(scope, receive, send): + await asgi_send(send, "", 200) class Datasette: @@ -413,6 +418,7 @@ class Datasette: "full": sys.version, }, "datasette": datasette_version, + "asgi": "3.0", "sqlite": { "version": sqlite_version, "fts_versions": fts_versions, @@ -543,21 +549,7 @@ class Datasette: self.renderers[renderer["extension"]] = renderer["callback"] def app(self): - class TracingSanic(Sanic): - async def handle_request(self, request, write_callback, stream_callback): - if request.args.get("_trace"): - request["traces"] = [] - request["trace_start"] = time.time() - with capture_traces(request["traces"]): - await super().handle_request( - request, write_callback, stream_callback - ) - else: - await super().handle_request( - request, write_callback, stream_callback - ) - - app = TracingSanic(__name__) + "Returns an ASGI app function that serves the whole of Datasette" default_templates = str(app_root / "datasette" / "templates") template_paths = [] if self.template_dir: @@ -588,134 +580,127 @@ class Datasette: pm.hook.prepare_jinja2_environment(env=self.jinja_env) self.register_renderers() + + routes = [] + + def add_route(view, regex): + routes.append((regex, view)) + # Generate a regex snippet to match all registered renderer file extensions renderer_regex = "|".join(r"\." 
+ key for key in self.renderers.keys()) - app.add_route(IndexView.as_view(self), r"/") + add_route(IndexView.as_asgi(self), r"/(?P(\.jsono?)?$)") # TODO: /favicon.ico and /-/static/ deserve far-future cache expires - app.add_route(favicon, "/favicon.ico") - app.static("/-/static/", str(app_root / "datasette" / "static")) + add_route(favicon, "/favicon.ico") + + add_route( + asgi_static(app_root / "datasette" / "static"), r"/-/static/(?P.*)$" + ) for path, dirname in self.static_mounts: - app.static(path, dirname) + add_route(asgi_static(dirname), r"/" + path + "/(?P.*)$") + # Mount any plugin static/ directories for plugin in get_plugins(pm): if plugin["static_path"]: - modpath = "/-/static-plugins/{}/".format(plugin["name"]) - app.static(modpath, plugin["static_path"]) - app.add_route( - JsonDataView.as_view(self, "metadata.json", lambda: self._metadata), - r"/-/metadata", + modpath = "/-/static-plugins/{}/(?P.*)$".format(plugin["name"]) + add_route(asgi_static(plugin["static_path"]), modpath) + add_route( + JsonDataView.as_asgi(self, "metadata.json", lambda: self._metadata), + r"/-/metadata(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "versions.json", self.versions), - r"/-/versions", + add_route( + JsonDataView.as_asgi(self, "versions.json", self.versions), + r"/-/versions(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "plugins.json", self.plugins), - r"/-/plugins", + add_route( + JsonDataView.as_asgi(self, "plugins.json", self.plugins), + r"/-/plugins(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "config.json", lambda: self._config), - r"/-/config", + add_route( + JsonDataView.as_asgi(self, "config.json", lambda: self._config), + r"/-/config(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "databases.json", self.connected_databases), - r"/-/databases", + add_route( + JsonDataView.as_asgi(self, "databases.json", self.connected_databases), + r"/-/databases(?P(\.json)?)$", ) - app.add_route( - DatabaseDownload.as_view(self), r"/" + add_route( + DatabaseDownload.as_asgi(self), r"/(?P[^/]+?)(?P\.db)$" ) - app.add_route( - DatabaseView.as_view(self), - r"/", - ) - app.add_route( - TableView.as_view(self), r"//" - ) - app.add_route( - RowView.as_view(self), - r"///[^/]+?)(?P" + renderer_regex - + r")?$>", + + r"|.jsono|\.csv)?$", + ) + add_route( + TableView.as_asgi(self), + r"/(?P[^/]+)/(?P[^/]+?$)", + ) + add_route( + RowView.as_asgi(self), + r"/(?P[^/]+)/(?P[^/]+?)/(?P[^/]+?)(?P" + + renderer_regex + + r")?$", ) self.register_custom_units() - # On 404 with a trailing slash redirect to path without that slash: - # pylint: disable=unused-variable - @app.middleware("response") - def redirect_on_404_with_trailing_slash(request, original_response): - if original_response.status == 404 and request.path.endswith("/"): - path = request.path.rstrip("/") - if request.query_string: - path = "{}?{}".format(path, request.query_string) - return response.redirect(path) - - @app.middleware("response") - async def add_traces_to_response(request, response): - if request.get("traces") is None: - return - traces = request["traces"] - trace_info = { - "request_duration_ms": 1000 * (time.time() - request["trace_start"]), - "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces), - "num_traces": len(traces), - "traces": traces, - } - if "text/html" in response.content_type and b"" in response.body: - extra = json.dumps(trace_info, indent=2) - extra_html = "
<pre>{}</pre></body>
".format(extra).encode("utf8") - response.body = response.body.replace(b"", extra_html) - elif "json" in response.content_type and response.body.startswith(b"{"): - data = json.loads(response.body.decode("utf8")) - if "_trace" not in data: - data["_trace"] = trace_info - response.body = json.dumps(data).encode("utf8") - - @app.exception(Exception) - def on_exception(request, exception): - title = None - help = None - if isinstance(exception, NotFound): - status = 404 - info = {} - message = exception.args[0] - elif isinstance(exception, InvalidUsage): - status = 405 - info = {} - message = exception.args[0] - elif isinstance(exception, DatasetteError): - status = exception.status - info = exception.error_dict - message = exception.message - if exception.messagge_is_html: - message = Markup(message) - title = exception.title - else: - status = 500 - info = {} - message = str(exception) - traceback.print_exc() - templates = ["500.html"] - if status != 500: - templates = ["{}.html".format(status)] + templates - info.update( - {"ok": False, "error": message, "status": status, "title": title} - ) - if request is not None and request.path.split("?")[0].endswith(".json"): - r = response.json(info, status=status) - - else: - template = self.jinja_env.select_template(templates) - r = response.html(template.render(info), status=status) - if self.cors: - r.headers["Access-Control-Allow-Origin"] = "*" - return r - - # First time server starts up, calculate table counts for immutable databases - @app.listener("before_server_start") - async def setup_db(app, loop): + async def setup_db(): + # First time server starts up, calculate table counts for immutable databases for dbname, database in self.databases.items(): if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) - return app + return AsgiLifespan( + AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db + ) + + +class DatasetteRouter(AsgiRouter): + def __init__(self, datasette, routes): + self.ds = datasette + super().__init__(routes) + + async def handle_404(self, scope, receive, send): + # If URL has a trailing slash, redirect to URL without it + path = scope.get("raw_path", scope["path"].encode("utf8")) + if path.endswith(b"/"): + path = path.rstrip(b"/") + if scope["query_string"]: + path += b"?" 
+ scope["query_string"] + await asgi_send_redirect(send, path.decode("latin1")) + else: + await super().handle_404(scope, receive, send) + + async def handle_500(self, scope, receive, send, exception): + title = None + if isinstance(exception, NotFound): + status = 404 + info = {} + message = exception.args[0] + elif isinstance(exception, DatasetteError): + status = exception.status + info = exception.error_dict + message = exception.message + if exception.messagge_is_html: + message = Markup(message) + title = exception.title + else: + status = 500 + info = {} + message = str(exception) + traceback.print_exc() + templates = ["500.html"] + if status != 500: + templates = ["{}.html".format(status)] + templates + info.update({"ok": False, "error": message, "status": status, "title": title}) + headers = {} + if self.ds.cors: + headers["Access-Control-Allow-Origin"] = "*" + if scope["path"].split("?")[0].endswith(".json"): + await asgi_send_json(send, info, status=status, headers=headers) + else: + template = self.ds.jinja_env.select_template(templates) + await asgi_send_html( + send, template.render(info), status=status, headers=headers + ) diff --git a/datasette/cli.py b/datasette/cli.py index 0d47f47a..181b281c 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -1,4 +1,5 @@ import asyncio +import uvicorn import click from click import formatting from click_default_group import DefaultGroup @@ -354,4 +355,4 @@ def serve( asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks()) # Start the server - ds.app().run(host=host, port=port, debug=debug) + uvicorn.run(ds.app(), host=host, port=port, log_level="info") diff --git a/datasette/renderer.py b/datasette/renderer.py index 417fecb5..349c2922 100644 --- a/datasette/renderer.py +++ b/datasette/renderer.py @@ -88,5 +88,5 @@ def json_renderer(args, data, view_name): content_type = "text/plain" else: body = json.dumps(data, cls=CustomJSONEncoder) - content_type = "application/json" + content_type = "application/json; charset=utf-8" return {"body": body, "status_code": status_code, "content_type": content_type} diff --git a/datasette/tracer.py b/datasette/tracer.py index c6fe0a00..e46a6fda 100644 --- a/datasette/tracer.py +++ b/datasette/tracer.py @@ -1,6 +1,7 @@ import asyncio from contextlib import contextmanager import time +import json import traceback tracers = {} @@ -32,15 +33,15 @@ def trace(type, **kwargs): start = time.time() yield end = time.time() - trace = { + trace_info = { "type": type, "start": start, "end": end, "duration_ms": (end - start) * 1000, "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]), } - trace.update(kwargs) - tracer.append(trace) + trace_info.update(kwargs) + tracer.append(trace_info) @contextmanager @@ -53,3 +54,77 @@ def capture_traces(tracer): tracers[task_id] = tracer yield del tracers[task_id] + + +class AsgiTracer: + # If the body is larger than this we don't attempt to append the trace + max_body_bytes = 1024 * 256 # 256 KB + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if b"_trace=1" not in scope.get("query_string", b"").split(b"&"): + await self.app(scope, receive, send) + return + trace_start = time.time() + traces = [] + + accumulated_body = b"" + size_limit_exceeded = False + response_headers = [] + + async def wrapped_send(message): + nonlocal accumulated_body, size_limit_exceeded, response_headers + if message["type"] == "http.response.start": + response_headers = message["headers"] + await send(message) + return 
+ + if message["type"] != "http.response.body" or size_limit_exceeded: + await send(message) + return + + # Accumulate body until the end or until size is exceeded + accumulated_body += message["body"] + if len(accumulated_body) > self.max_body_bytes: + await send( + { + "type": "http.response.body", + "body": accumulated_body, + "more_body": True, + } + ) + size_limit_exceeded = True + return + + if not message.get("more_body"): + # We have all the body - modify it and send the result + # TODO: What to do about Content-Type or other cases? + trace_info = { + "request_duration_ms": 1000 * (time.time() - trace_start), + "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces), + "num_traces": len(traces), + "traces": traces, + } + try: + content_type = [ + v.decode("utf8") + for k, v in response_headers + if k.lower() == b"content-type" + ][0] + except IndexError: + content_type = "" + if "text/html" in content_type and b"" in accumulated_body: + extra = json.dumps(trace_info, indent=2) + extra_html = "
<pre>{}</pre></body>
".format(extra).encode("utf8") + accumulated_body = accumulated_body.replace(b"", extra_html) + elif "json" in content_type and accumulated_body.startswith(b"{"): + data = json.loads(accumulated_body.decode("utf8")) + if "_trace" not in data: + data["_trace"] = trace_info + accumulated_body = json.dumps(data).encode("utf8") + await send({"type": "http.response.body", "body": accumulated_body}) + + with capture_traces(traces): + await self.app(scope, receive, wrapped_send) diff --git a/datasette/utils.py b/datasette/utils/__init__.py similarity index 98% rename from datasette/utils.py rename to datasette/utils/__init__.py index 58746be4..94ccc23e 100644 --- a/datasette/utils.py +++ b/datasette/utils/__init__.py @@ -697,13 +697,13 @@ class LimitedWriter: self.limit_bytes = limit_mb * 1024 * 1024 self.bytes_count = 0 - def write(self, bytes): + async def write(self, bytes): self.bytes_count += len(bytes) if self.limit_bytes and (self.bytes_count > self.limit_bytes): raise WriteLimitExceeded( "CSV contains more than {} bytes".format(self.limit_bytes) ) - self.writer.write(bytes) + await self.writer.write(bytes) _infinities = {float("inf"), float("-inf")} @@ -741,3 +741,16 @@ def format_bytes(bytes): return "{} {}".format(int(current), unit) else: return "{:.1f} {}".format(current, unit) + + +class RequestParameters(dict): + def get(self, name, default=None): + "Return first value in the list, if available" + try: + return super().get(name)[0] + except (KeyError, TypeError): + return default + + def getlist(self, name, default=None): + "Return full list" + return super().get(name, default) diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py new file mode 100644 index 00000000..fdf330ae --- /dev/null +++ b/datasette/utils/asgi.py @@ -0,0 +1,377 @@ +import json +from datasette.utils import RequestParameters +from mimetypes import guess_type +from urllib.parse import parse_qs, urlunparse +from pathlib import Path +from html import escape +import re +import aiofiles + + +class NotFound(Exception): + pass + + +class Request: + def __init__(self, scope): + self.scope = scope + + @property + def method(self): + return self.scope["method"] + + @property + def url(self): + return urlunparse( + (self.scheme, self.host, self.path, None, self.query_string, None) + ) + + @property + def scheme(self): + return self.scope.get("scheme") or "http" + + @property + def headers(self): + return dict( + [ + (k.decode("latin-1").lower(), v.decode("latin-1")) + for k, v in self.scope.get("headers") or [] + ] + ) + + @property + def host(self): + return self.headers.get("host") or "localhost" + + @property + def path(self): + return ( + self.scope.get("raw_path", self.scope["path"].encode("latin-1")) + ).decode("latin-1") + + @property + def query_string(self): + return (self.scope.get("query_string") or b"").decode("latin-1") + + @property + def args(self): + return RequestParameters(parse_qs(qs=self.query_string)) + + @property + def raw_args(self): + return {key: value[0] for key, value in self.args.items()} + + @classmethod + def fake(cls, path_with_query_string, method="GET", scheme="http"): + "Useful for constructing Request objects for tests" + path, _, query_string = path_with_query_string.partition("?") + scope = { + "http_version": "1.1", + "method": method, + "path": path, + "raw_path": path.encode("latin-1"), + "query_string": query_string.encode("latin-1"), + "scheme": scheme, + "type": "http", + } + return cls(scope) + + +class AsgiRouter: + def __init__(self, routes=None): + routes = routes or 
[] + self.routes = [ + # Compile any strings to regular expressions + ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) + for pattern, view in routes + ] + + async def __call__(self, scope, receive, send): + # Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves + path = scope["raw_path"].decode("ascii") + for regex, view in self.routes: + match = regex.match(path) + if match is not None: + new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) + try: + return await view(new_scope, receive, send) + except Exception as exception: + return await self.handle_500(scope, receive, send, exception) + return await self.handle_404(scope, receive, send) + + async def handle_404(self, scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/html"]], + } + ) + await send({"type": "http.response.body", "body": b"
<h1>404</h1>
"}) + + async def handle_500(self, scope, receive, send, exception): + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/html"]], + } + ) + html = "
<h1>500</h1><pre>{}</pre>
".format(escape(repr(exception))) + await send({"type": "http.response.body", "body": html.encode("latin-1")}) + + +class AsgiLifespan: + def __init__(self, app, on_startup=None, on_shutdown=None): + self.app = app + on_startup = on_startup or [] + on_shutdown = on_shutdown or [] + if not isinstance(on_startup or [], list): + on_startup = [on_startup] + if not isinstance(on_shutdown or [], list): + on_shutdown = [on_shutdown] + self.on_startup = on_startup + self.on_shutdown = on_shutdown + + async def __call__(self, scope, receive, send): + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + for fn in self.on_startup: + await fn() + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + for fn in self.on_shutdown: + await fn() + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await self.app(scope, receive, send) + + +class AsgiView: + def dispatch_request(self, request, *args, **kwargs): + handler = getattr(self, request.method.lower(), None) + return handler(request, *args, **kwargs) + + @classmethod + def as_asgi(cls, *class_args, **class_kwargs): + async def view(scope, receive, send): + # Uses scope to create a request object, then dispatches that to + # self.get(...) or self.options(...) along with keyword arguments + # that were already tucked into scope["url_route"]["kwargs"] by + # the router, similar to how Django Channels works: + # https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter + request = Request(scope) + self = view.view_class(*class_args, **class_kwargs) + response = await self.dispatch_request( + request, **scope["url_route"]["kwargs"] + ) + await response.asgi_send(send) + + view.view_class = cls + view.__doc__ = cls.__doc__ + view.__module__ = cls.__module__ + view.__name__ = cls.__name__ + return view + + +class AsgiStream: + def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): + self.stream_fn = stream_fn + self.status = status + self.headers = headers or {} + self.content_type = content_type + + async def asgi_send(self, send): + # Remove any existing content-type header + headers = dict( + [(k, v) for k, v in self.headers.items() if k.lower() != "content-type"] + ) + headers["content-type"] = self.content_type + await send( + { + "type": "http.response.start", + "status": self.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in headers.items() + ], + } + ) + w = AsgiWriter(send) + await self.stream_fn(w) + await send({"type": "http.response.body", "body": b""}) + + +class AsgiWriter: + def __init__(self, send): + self.send = send + + async def write(self, chunk): + await self.send( + { + "type": "http.response.body", + "body": chunk.encode("latin-1"), + "more_body": True, + } + ) + + +async def asgi_send_json(send, info, status=200, headers=None): + headers = headers or {} + await asgi_send( + send, + json.dumps(info), + status=status, + headers=headers, + content_type="application/json; charset=utf-8", + ) + + +async def asgi_send_html(send, html, status=200, headers=None): + headers = headers or {} + await asgi_send( + send, html, status=status, headers=headers, content_type="text/html" + ) + + +async def asgi_send_redirect(send, location, status=302): + await asgi_send( + send, + "", + status=status, + headers={"Location": location}, + content_type="text/html", + ) + + +async def asgi_send(send, content, status, headers=None, 
content_type="text/plain"): + await asgi_start(send, status, headers, content_type) + await send({"type": "http.response.body", "body": content.encode("latin-1")}) + + +async def asgi_start(send, status, headers=None, content_type="text/plain"): + headers = headers or {} + # Remove any existing content-type header + headers = dict([(k, v) for k, v in headers.items() if k.lower() != "content-type"]) + headers["content-type"] = content_type + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + [key.encode("latin1"), value.encode("latin1")] + for key, value in headers.items() + ], + } + ) + + +async def asgi_send_file( + send, filepath, filename=None, content_type=None, chunk_size=4096 +): + headers = {} + if filename: + headers["Content-Disposition"] = 'attachment; filename="{}"'.format(filename) + first = True + async with aiofiles.open(str(filepath), mode="rb") as fp: + if first: + await asgi_start( + send, + 200, + headers, + content_type or guess_type(str(filepath))[0] or "text/plain", + ) + first = False + more_body = True + while more_body: + chunk = await fp.read(chunk_size) + more_body = len(chunk) == chunk_size + await send( + {"type": "http.response.body", "body": chunk, "more_body": more_body} + ) + + +def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None): + async def inner_static(scope, receive, send): + path = scope["url_route"]["kwargs"]["path"] + full_path = (Path(root_path) / path).absolute() + # Ensure full_path is within root_path to avoid weird "../" tricks + try: + full_path.relative_to(root_path) + except ValueError: + await asgi_send_html(send, "404", 404) + return + first = True + try: + await asgi_send_file(send, full_path, chunk_size=chunk_size) + except FileNotFoundError: + await asgi_send_html(send, "404", 404) + return + + return inner_static + + +class Response: + def __init__(self, body=None, status=200, headers=None, content_type="text/plain"): + self.body = body + self.status = status + self.headers = headers or {} + self.content_type = content_type + + async def asgi_send(self, send): + headers = {} + headers.update(self.headers) + headers["content-type"] = self.content_type + await send( + { + "type": "http.response.start", + "status": self.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in headers.items() + ], + } + ) + body = self.body + if not isinstance(body, bytes): + body = body.encode("utf-8") + await send({"type": "http.response.body", "body": body}) + + @classmethod + def html(cls, body, status=200, headers=None): + return cls( + body, + status=status, + headers=headers, + content_type="text/html; charset=utf-8", + ) + + @classmethod + def text(cls, body, status=200, headers=None): + return cls( + body, + status=status, + headers=headers, + content_type="text/plain; charset=utf-8", + ) + + @classmethod + def redirect(cls, path, status=302, headers=None): + headers = headers or {} + headers["Location"] = path + return cls("", status=status, headers=headers) + + +class AsgiFileDownload: + def __init__( + self, filepath, filename=None, content_type="application/octet-stream" + ): + self.filepath = filepath + self.filename = filename + self.content_type = content_type + + async def asgi_send(self, send): + return await asgi_send_file(send, self.filepath, content_type=self.content_type) diff --git a/datasette/views/base.py b/datasette/views/base.py index 9db8cc76..7acb7304 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -7,9 +7,8 @@ 
import urllib import jinja2 import pint -from sanic import response -from sanic.exceptions import NotFound -from sanic.views import HTTPMethodView + +from html import escape from datasette import __version__ from datasette.plugins import pm @@ -26,6 +25,14 @@ from datasette.utils import ( sqlite3, to_css_class, ) +from datasette.utils.asgi import ( + AsgiStream, + AsgiWriter, + AsgiRouter, + AsgiView, + NotFound, + Response, +) ureg = pint.UnitRegistry() @@ -49,7 +56,14 @@ class DatasetteError(Exception): self.messagge_is_html = messagge_is_html -class BaseView(HTTPMethodView): +class BaseView(AsgiView): + ds = None + + async def head(self, *args, **kwargs): + response = await self.get(*args, **kwargs) + response.body = b"" + return response + def _asset_urls(self, key, template, context): # Flatten list-of-lists from plugins: seen_urls = set() @@ -104,7 +118,7 @@ class BaseView(HTTPMethodView): datasette=self.ds, ): body_scripts.append(jinja2.Markup(script)) - return response.html( + return Response.html( template.render( { **context, @@ -136,7 +150,7 @@ class DataView(BaseView): self.ds = datasette def options(self, request, *args, **kwargs): - r = response.text("ok") + r = Response.text("ok") if self.ds.cors: r.headers["Access-Control-Allow-Origin"] = "*" return r @@ -146,7 +160,7 @@ class DataView(BaseView): path = "{}?{}".format(path, request.query_string) if remove_args: path = path_with_removed_args(request, remove_args, path=path) - r = response.redirect(path) + r = Response.redirect(path) r.headers["Link"] = "<{}>; rel=preload".format(path) if self.ds.cors: r.headers["Access-Control-Allow-Origin"] = "*" @@ -195,17 +209,17 @@ class DataView(BaseView): kwargs["table"] = table if _format: kwargs["as_format"] = ".{}".format(_format) - elif "table" in kwargs: + elif kwargs.get("table"): kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"]) should_redirect = "/{}-{}".format(name, expected) - if "table" in kwargs: + if kwargs.get("table"): should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"]) - if "pk_path" in kwargs: + if kwargs.get("pk_path"): should_redirect += "/" + kwargs["pk_path"] - if "as_format" in kwargs: + if kwargs.get("as_format"): should_redirect += kwargs["as_format"] - if "as_db" in kwargs: + if kwargs.get("as_db"): should_redirect += kwargs["as_db"] if ( @@ -246,7 +260,7 @@ class DataView(BaseView): response_or_template_contexts = await self.data( request, database, hash, **kwargs ) - if isinstance(response_or_template_contexts, response.HTTPResponse): + if isinstance(response_or_template_contexts, Response): return response_or_template_contexts else: data, _, _ = response_or_template_contexts @@ -282,13 +296,13 @@ class DataView(BaseView): if not first: data, _, _ = await self.data(request, database, hash, **kwargs) if first: - writer.writerow(headings) + await writer.writerow(headings) first = False next = data.get("next") for row in data["rows"]: if not expanded_columns: # Simple path - writer.writerow(row) + await writer.writerow(row) else: # Look for {"value": "label": } dicts and expand new_row = [] @@ -298,10 +312,10 @@ class DataView(BaseView): new_row.append(cell["label"]) else: new_row.append(cell) - writer.writerow(new_row) + await writer.writerow(new_row) except Exception as e: print("caught this", e) - r.write(str(e)) + await r.write(str(e)) return content_type = "text/plain; charset=utf-8" @@ -315,7 +329,7 @@ class DataView(BaseView): ) headers["Content-Disposition"] = disposition - return response.stream(stream_fn, headers=headers, 
content_type=content_type) + return AsgiStream(stream_fn, headers=headers, content_type=content_type) async def get_format(self, request, database, args): """ Determine the format of the response from the request, from URL @@ -363,7 +377,7 @@ class DataView(BaseView): response_or_template_contexts = await self.data( request, database, hash, **kwargs ) - if isinstance(response_or_template_contexts, response.HTTPResponse): + if isinstance(response_or_template_contexts, Response): return response_or_template_contexts else: @@ -414,17 +428,11 @@ class DataView(BaseView): if result is None: raise NotFound("No data") - response_args = { - "content_type": result.get("content_type", "text/plain"), - "status": result.get("status_code", 200), - } - - if type(result.get("body")) == bytes: - response_args["body_bytes"] = result.get("body") - else: - response_args["body"] = result.get("body") - - r = response.HTTPResponse(**response_args) + r = Response( + body=result.get("body"), + status=result.get("status_code", 200), + content_type=result.get("content_type", "text/plain"), + ) else: extras = {} if callable(extra_template_data): diff --git a/datasette/views/database.py b/datasette/views/database.py index a5b606f1..78af19c5 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,10 +1,9 @@ import os -from sanic import response - from datasette.utils import to_css_class, validate_sql_select +from datasette.utils.asgi import AsgiFileDownload -from .base import DataView, DatasetteError +from .base import DatasetteError, DataView class DatabaseView(DataView): @@ -79,8 +78,8 @@ class DatabaseDownload(DataView): if not db.path: raise DatasetteError("Cannot download database", status=404) filepath = db.path - return await response.file_stream( + return AsgiFileDownload( filepath, filename=os.path.basename(filepath), - mime_type="application/octet-stream", + content_type="application/octet-stream", ) diff --git a/datasette/views/index.py b/datasette/views/index.py index c9d15c36..2c1c017a 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -1,9 +1,8 @@ import hashlib import json -from sanic import response - from datasette.utils import CustomJSONEncoder +from datasette.utils.asgi import Response from datasette.version import __version__ from .base import BaseView @@ -104,9 +103,9 @@ class IndexView(BaseView): headers = {} if self.ds.cors: headers["Access-Control-Allow-Origin"] = "*" - return response.HTTPResponse( + return Response( json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder), - content_type="application/json", + content_type="application/json; charset=utf-8", headers=headers, ) else: diff --git a/datasette/views/special.py b/datasette/views/special.py index 91b577fc..c4976bb2 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -1,5 +1,5 @@ import json -from sanic import response +from datasette.utils.asgi import Response from .base import BaseView @@ -17,8 +17,10 @@ class JsonDataView(BaseView): headers = {} if self.ds.cors: headers["Access-Control-Allow-Origin"] = "*" - return response.HTTPResponse( - json.dumps(data), content_type="application/json", headers=headers + return Response( + json.dumps(data), + content_type="application/json; charset=utf-8", + headers=headers, ) else: diff --git a/datasette/views/table.py b/datasette/views/table.py index 14b8743a..06be5671 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -3,13 +3,12 @@ import itertools import json import jinja2 -from 
sanic.exceptions import NotFound -from sanic.request import RequestParameters from datasette.plugins import pm from datasette.utils import ( CustomRow, QueryInterrupted, + RequestParameters, append_querystring, compound_keys_after_sql, escape_sqlite, @@ -24,6 +23,7 @@ from datasette.utils import ( urlsafe_components, value_as_boolean, ) +from datasette.utils.asgi import NotFound from datasette.filters import Filters from .base import DataView, DatasetteError, ureg @@ -219,8 +219,7 @@ class TableView(RowTableShared): if is_view: order_by = "" - # We roll our own query_string decoder because by default Sanic - # drops anything with an empty value e.g. ?name__exact= + # Ensure we don't drop anything with an empty value e.g. ?name__exact= args = RequestParameters( urllib.parse.parse_qs(request.query_string, keep_blank_values=True) ) diff --git a/pytest.ini b/pytest.ini index f2c8a6d2..aa292efc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,7 +4,5 @@ filterwarnings= ignore:Using or importing the ABCs::jinja2 # https://bugs.launchpad.net/beautifulsoup/+bug/1778909 ignore:Using or importing the ABCs::bs4.element - # Sanic verify_ssl=True - ignore:verify_ssl is deprecated::sanic # Python 3.7 PendingDeprecationWarning: Task.current_task() ignore:.*current_task.*:PendingDeprecationWarning diff --git a/setup.py b/setup.py index 60c1bcc5..f66d03da 100644 --- a/setup.py +++ b/setup.py @@ -37,17 +37,18 @@ setup( author="Simon Willison", license="Apache License, Version 2.0", url="https://github.com/simonw/datasette", - packages=find_packages(exclude='tests'), + packages=find_packages(exclude="tests"), package_data={"datasette": ["templates/*.html"]}, include_package_data=True, install_requires=[ "click>=6.7", "click-default-group==1.2", - "Sanic==0.7.0", "Jinja2==2.10.1", "hupper==1.0", "pint==0.8.1", "pluggy>=0.12.0", + "uvicorn>=0.8.1", + "aiofiles==0.4.0", ], entry_points=""" [console_scripts] @@ -60,6 +61,7 @@ setup( "pytest-asyncio==0.10.0", "aiohttp==3.5.3", "beautifulsoup4==4.6.1", + "asgiref==3.1.2", ] + maybe_black }, diff --git a/tests/fixtures.py b/tests/fixtures.py index 04ac3c68..00140f50 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,5 +1,7 @@ from datasette.app import Datasette from datasette.utils import sqlite3 +from asgiref.testing import ApplicationCommunicator +from asgiref.sync import async_to_sync import itertools import json import os @@ -10,16 +12,82 @@ import sys import string import tempfile import time +from urllib.parse import unquote + + +class TestResponse: + def __init__(self, status, headers, body): + self.status = status + self.headers = headers + self.body = body + + @property + def json(self): + return json.loads(self.text) + + @property + def text(self): + return self.body.decode("utf8") class TestClient: - def __init__(self, sanic_test_client): - self.sanic_test_client = sanic_test_client + max_redirects = 5 - def get(self, path, allow_redirects=True): - return self.sanic_test_client.get( - path, allow_redirects=allow_redirects, gather_request=False + def __init__(self, asgi_app): + self.asgi_app = asgi_app + + @async_to_sync + async def get(self, path, allow_redirects=True, redirect_count=0, method="GET"): + return await self._get(path, allow_redirects, redirect_count, method) + + async def _get(self, path, allow_redirects=True, redirect_count=0, method="GET"): + query_string = b"" + if "?" 
in path: + path, _, query_string = path.partition("?") + query_string = query_string.encode("utf8") + instance = ApplicationCommunicator( + self.asgi_app, + { + "type": "http", + "http_version": "1.0", + "method": method, + "path": unquote(path), + "raw_path": path.encode("ascii"), + "query_string": query_string, + "headers": [[b"host", b"localhost"]], + }, ) + await instance.send_input({"type": "http.request"}) + # First message back should be response.start with headers and status + messages = [] + start = await instance.receive_output(2) + messages.append(start) + assert start["type"] == "http.response.start" + headers = dict( + [(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]] + ) + status = start["status"] + # Now loop until we run out of response.body + body = b"" + while True: + message = await instance.receive_output(2) + messages.append(message) + assert message["type"] == "http.response.body" + body += message["body"] + if not message.get("more_body"): + break + response = TestResponse(status, headers, body) + if allow_redirects and response.status in (301, 302): + assert ( + redirect_count < self.max_redirects + ), "Redirected {} times, max_redirects={}".format( + redirect_count, self.max_redirects + ) + location = response.headers["Location"] + return await self._get( + location, allow_redirects=True, redirect_count=redirect_count + 1 + ) + return response def make_app_client( @@ -32,6 +100,7 @@ def make_app_client( is_immutable=False, extra_databases=None, inspect_data=None, + static_mounts=None, ): with tempfile.TemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, filename) @@ -73,9 +142,10 @@ def make_app_client( plugins_dir=plugins_dir, config=config, inspect_data=inspect_data, + static_mounts=static_mounts, ) ds.sqlite_functions.append(("sleep", 1, lambda n: time.sleep(float(n)))) - client = TestClient(ds.app().test_client) + client = TestClient(ds.app()) client.ds = ds yield client @@ -88,7 +158,7 @@ def app_client(): @pytest.fixture(scope="session") def app_client_no_files(): ds = Datasette([]) - client = TestClient(ds.app().test_client) + client = TestClient(ds.app()) client.ds = ds yield client diff --git a/tests/test_api.py b/tests/test_api.py index 5c1bff15..a32ed5e3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -22,6 +22,7 @@ import urllib def test_homepage(app_client): response = app_client.get("/.json") assert response.status == 200 + assert "application/json; charset=utf-8" == response.headers["content-type"] assert response.json.keys() == {"fixtures": 0}.keys() d = response.json["fixtures"] assert d["name"] == "fixtures" @@ -771,8 +772,8 @@ def test_paginate_tables_and_views(app_client, path, expected_rows, expected_pag fetched.extend(response.json["rows"]) path = response.json["next_url"] if path: - assert response.json["next"] assert urllib.parse.urlencode({"_next": response.json["next"]}) in path + path = path.replace("http://localhost", "") assert count < 30, "Possible infinite loop detected" assert expected_rows == len(fetched) @@ -812,6 +813,8 @@ def test_paginate_compound_keys(app_client): response = app_client.get(path) fetched.extend(response.json["rows"]) path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert page < 100 assert 1001 == len(fetched) assert 21 == page @@ -833,6 +836,8 @@ def test_paginate_compound_keys_with_extra_filters(app_client): response = app_client.get(path) fetched.extend(response.json["rows"]) path = response.json["next_url"] + if path: + path = 
path.replace("http://localhost", "") assert 2 == page expected = [r[3] for r in generate_compound_rows(1001) if "d" in r[3]] assert expected == [f["content"] for f in fetched] @@ -881,6 +886,8 @@ def test_sortable(app_client, query_string, sort_key, human_description_en): assert human_description_en == response.json["human_description_en"] fetched.extend(response.json["rows"]) path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert 5 == page expected = list(generate_sortable_rows(201)) expected.sort(key=sort_key) @@ -1191,6 +1198,7 @@ def test_plugins_json(app_client): def test_versions_json(app_client): response = app_client.get("/-/versions.json") assert "python" in response.json + assert "3.0" == response.json.get("asgi") assert "version" in response.json["python"] assert "full" in response.json["python"] assert "datasette" in response.json @@ -1236,6 +1244,8 @@ def test_page_size_matching_max_returned_rows( fetched.extend(response.json["rows"]) assert len(response.json["rows"]) in (1, 50) path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert 201 == len(fetched) diff --git a/tests/test_csv.py b/tests/test_csv.py index cf0e6732..c3cdc241 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -46,7 +46,7 @@ def test_table_csv(app_client): response = app_client.get("/fixtures/simple_primary_key.csv") assert response.status == 200 assert not response.headers.get("Access-Control-Allow-Origin") - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_TABLE_CSV == response.text @@ -59,7 +59,7 @@ def test_table_csv_cors_headers(app_client_with_cors): def test_table_csv_with_labels(app_client): response = app_client.get("/fixtures/facetable.csv?_labels=1") assert response.status == 200 - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text @@ -68,14 +68,14 @@ def test_custom_sql_csv(app_client): "/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2" ) assert response.status == 200 - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_CUSTOM_CSV == response.text def test_table_csv_download(app_client): response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1") assert response.status == 200 - assert "text/csv; charset=utf-8" == response.headers["Content-Type"] + assert "text/csv; charset=utf-8" == response.headers["content-type"] expected_disposition = 'attachment; filename="simple_primary_key.csv"' assert expected_disposition == response.headers["Content-Disposition"] diff --git a/tests/test_html.py b/tests/test_html.py index 6b673c13..32fa2fe3 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -8,6 +8,7 @@ from .fixtures import ( # noqa METADATA, ) import json +import pathlib import pytest import re import urllib.parse @@ -16,6 +17,7 @@ import urllib.parse def test_homepage(app_client_two_attached_databases): response = app_client_two_attached_databases.get("/") assert response.status == 200 + assert "text/html; charset=utf-8" == response.headers["content-type"] soup = Soup(response.body, "html.parser") assert "Datasette Fixtures" == soup.find("h1").text assert ( @@ -44,6 +46,29 @@ def 
test_homepage(app_client_two_attached_databases): ] == table_links +def test_http_head(app_client): + response = app_client.get("/", method="HEAD") + assert response.status == 200 + + +def test_static(app_client): + response = app_client.get("/-/static/app2.css") + assert response.status == 404 + response = app_client.get("/-/static/app.css") + assert response.status == 200 + assert "text/css" == response.headers["content-type"] + + +def test_static_mounts(): + for client in make_app_client( + static_mounts=[("custom-static", str(pathlib.Path(__file__).parent))] + ): + response = client.get("/custom-static/test_html.py") + assert response.status == 200 + response = client.get("/custom-static/not_exists.py") + assert response.status == 404 + + def test_memory_database_page(): for client in make_app_client(memory=True): response = client.get("/:memory:") diff --git a/tests/test_utils.py b/tests/test_utils.py index a5f603e6..e9e722b8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,11 +3,11 @@ Tests for various datasette helper functions. """ from datasette import utils +from datasette.utils.asgi import Request from datasette.filters import Filters import json import os import pytest -from sanic.request import Request import sqlite3 import tempfile from unittest.mock import patch @@ -53,7 +53,7 @@ def test_urlsafe_components(path, expected): ], ) def test_path_with_added_args(path, added_args, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_added_args(request, added_args) assert expected == actual @@ -67,11 +67,11 @@ def test_path_with_added_args(path, added_args, expected): ], ) def test_path_with_removed_args(path, args, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_removed_args(request, args) assert expected == actual # Run the test again but this time use the path= argument - request = Request("/".encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake("/") actual = utils.path_with_removed_args(request, args, path=path) assert expected == actual @@ -84,7 +84,7 @@ def test_path_with_removed_args(path, args, expected): ], ) def test_path_with_replaced_args(path, args, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_replaced_args(request, args) assert expected == actual @@ -363,7 +363,7 @@ def test_table_columns(): ], ) def test_path_with_format(path, format, extra_qs, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_format(request, format, extra_qs) assert expected == actual
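
The pieces introduced above all speak plain ASGI 3.0: each view becomes a single async callable taking (scope, receive, send), and the new AsgiRouter matches compiled regexes against the raw request path, stashing named groups in scope["url_route"]["kwargs"] for the view. A minimal illustration of that matching step (not part of the patch; the group names follow the ones Datasette uses for its database download route):

    import re

    # The DatabaseDownload route captures the database name and the ".db" suffix
    # as named groups; the router passes groupdict() to the view as keyword arguments.
    pattern = re.compile(r"/(?P<db_name>[^/]+?)(?P<as_db>\.db)$")
    match = pattern.match("/fixtures.db")
    print(match.groupdict())  # {'db_name': 'fixtures', 'as_db': '.db'}

Because Datasette.app() now returns an ordinary ASGI application, it can be handed to any ASGI server, which is what the updated datasette/cli.py does with uvicorn. A rough sketch of the same thing done directly, assuming a local fixtures.db file exists:

    import uvicorn
    from datasette.app import Datasette

    ds = Datasette(["fixtures.db"])
    uvicorn.run(ds.app(), host="127.0.0.1", port=8001, log_level="info")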