Port Datasette from Sanic to ASGI + Uvicorn (#518)

Datasette now uses ASGI internally, and no longer depends on Sanic.

It now uses Uvicorn as the underlying HTTP server.

This was thirteen months in the making... for full details see the issue:

https://github.com/simonw/datasette/issues/272

And for a full sequence of commits plus commentary, see the pull request:

https://github.com/simonw/datasette/pull/518
Simon Willison 2019-06-23 20:13:09 -07:00 committed by GitHub
parent 35429f9089
commit ba8db9679f
No key found in the database for this signature
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 770 additions and 207 deletions
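
For background: an ASGI application is just an async callable that accepts scope, receive and send, and Uvicorn can serve any such callable. A minimal sketch of the interface this port adopts (illustrative only, not Datasette code):

    import uvicorn

    async def app(scope, receive, send):
        # scope describes the connection; HTTP requests arrive with scope["type"] == "http"
        assert scope["type"] == "http"
        await send({
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"text/plain"]],
        })
        await send({"type": "http.response.body", "body": b"Hello from ASGI"})

    if __name__ == "__main__":
        # This mirrors what "datasette serve" now does via uvicorn.run(ds.app(), ...)
        uvicorn.run(app, host="127.0.0.1", port=8001)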

View file

@ -1,11 +1,9 @@
import asyncio
import collections
import hashlib
import json
import os
import sys
import threading
import time
import traceback
import urllib.parse
from concurrent import futures
@ -14,10 +12,8 @@ from pathlib import Path
import click
from markupsafe import Markup
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound
from .views.base import DatasetteError, ureg
from .views.base import DatasetteError, ureg, AsgiRouter
from .views.database import DatabaseDownload, DatabaseView
from .views.index import IndexView
from .views.special import JsonDataView
@ -36,7 +32,16 @@ from .utils import (
sqlite_timelimit,
to_css_class,
)
from .tracer import capture_traces, trace
from .utils.asgi import (
AsgiLifespan,
NotFound,
asgi_static,
asgi_send,
asgi_send_html,
asgi_send_json,
asgi_send_redirect,
)
from .tracer import trace, AsgiTracer
from .plugins import pm, DEFAULT_PLUGINS
from .version import __version__
@ -126,8 +131,8 @@ CONFIG_OPTIONS = (
DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS}
async def favicon(request):
return response.text("")
async def favicon(scope, receive, send):
await asgi_send(send, "", 200)
class Datasette:
@ -413,6 +418,7 @@ class Datasette:
"full": sys.version,
},
"datasette": datasette_version,
"asgi": "3.0",
"sqlite": {
"version": sqlite_version,
"fts_versions": fts_versions,
@ -543,21 +549,7 @@ class Datasette:
self.renderers[renderer["extension"]] = renderer["callback"]
def app(self):
class TracingSanic(Sanic):
async def handle_request(self, request, write_callback, stream_callback):
if request.args.get("_trace"):
request["traces"] = []
request["trace_start"] = time.time()
with capture_traces(request["traces"]):
await super().handle_request(
request, write_callback, stream_callback
)
else:
await super().handle_request(
request, write_callback, stream_callback
)
app = TracingSanic(__name__)
"Returns an ASGI app function that serves the whole of Datasette"
default_templates = str(app_root / "datasette" / "templates")
template_paths = []
if self.template_dir:
@ -588,134 +580,127 @@ class Datasette:
pm.hook.prepare_jinja2_environment(env=self.jinja_env)
self.register_renderers()
routes = []
def add_route(view, regex):
routes.append((regex, view))
# Generate a regex snippet to match all registered renderer file extensions
renderer_regex = "|".join(r"\." + key for key in self.renderers.keys())
app.add_route(IndexView.as_view(self), r"/<as_format:(\.jsono?)?$>")
add_route(IndexView.as_asgi(self), r"/(?P<as_format>(\.jsono?)?$)")
# TODO: /favicon.ico and /-/static/ deserve far-future cache expires
app.add_route(favicon, "/favicon.ico")
app.static("/-/static/", str(app_root / "datasette" / "static"))
add_route(favicon, "/favicon.ico")
add_route(
asgi_static(app_root / "datasette" / "static"), r"/-/static/(?P<path>.*)$"
)
for path, dirname in self.static_mounts:
app.static(path, dirname)
add_route(asgi_static(dirname), r"/" + path + "/(?P<path>.*)$")
# Mount any plugin static/ directories
for plugin in get_plugins(pm):
if plugin["static_path"]:
modpath = "/-/static-plugins/{}/".format(plugin["name"])
app.static(modpath, plugin["static_path"])
app.add_route(
JsonDataView.as_view(self, "metadata.json", lambda: self._metadata),
r"/-/metadata<as_format:(\.json)?$>",
modpath = "/-/static-plugins/{}/(?P<path>.*)$".format(plugin["name"])
add_route(asgi_static(plugin["static_path"]), modpath)
add_route(
JsonDataView.as_asgi(self, "metadata.json", lambda: self._metadata),
r"/-/metadata(?P<as_format>(\.json)?)$",
)
app.add_route(
JsonDataView.as_view(self, "versions.json", self.versions),
r"/-/versions<as_format:(\.json)?$>",
add_route(
JsonDataView.as_asgi(self, "versions.json", self.versions),
r"/-/versions(?P<as_format>(\.json)?)$",
)
app.add_route(
JsonDataView.as_view(self, "plugins.json", self.plugins),
r"/-/plugins<as_format:(\.json)?$>",
add_route(
JsonDataView.as_asgi(self, "plugins.json", self.plugins),
r"/-/plugins(?P<as_format>(\.json)?)$",
)
app.add_route(
JsonDataView.as_view(self, "config.json", lambda: self._config),
r"/-/config<as_format:(\.json)?$>",
add_route(
JsonDataView.as_asgi(self, "config.json", lambda: self._config),
r"/-/config(?P<as_format>(\.json)?)$",
)
app.add_route(
JsonDataView.as_view(self, "databases.json", self.connected_databases),
r"/-/databases<as_format:(\.json)?$>",
add_route(
JsonDataView.as_asgi(self, "databases.json", self.connected_databases),
r"/-/databases(?P<as_format>(\.json)?)$",
)
app.add_route(
DatabaseDownload.as_view(self), r"/<db_name:[^/]+?><as_db:(\.db)$>"
add_route(
DatabaseDownload.as_asgi(self), r"/(?P<db_name>[^/]+?)(?P<as_db>\.db)$"
)
app.add_route(
DatabaseView.as_view(self),
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>",
)
app.add_route(
TableView.as_view(self), r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>"
)
app.add_route(
RowView.as_view(self),
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:("
add_route(
DatabaseView.as_asgi(self),
r"/(?P<db_name>[^/]+?)(?P<as_format>"
+ renderer_regex
+ r")?$>",
+ r"|.jsono|\.csv)?$",
)
add_route(
TableView.as_asgi(self),
r"/(?P<db_name>[^/]+)/(?P<table_and_format>[^/]+?$)",
)
add_route(
RowView.as_asgi(self),
r"/(?P<db_name>[^/]+)/(?P<table>[^/]+?)/(?P<pk_path>[^/]+?)(?P<as_format>"
+ renderer_regex
+ r")?$",
)
self.register_custom_units()
# On 404 with a trailing slash redirect to path without that slash:
# pylint: disable=unused-variable
@app.middleware("response")
def redirect_on_404_with_trailing_slash(request, original_response):
if original_response.status == 404 and request.path.endswith("/"):
path = request.path.rstrip("/")
if request.query_string:
path = "{}?{}".format(path, request.query_string)
return response.redirect(path)
@app.middleware("response")
async def add_traces_to_response(request, response):
if request.get("traces") is None:
return
traces = request["traces"]
trace_info = {
"request_duration_ms": 1000 * (time.time() - request["trace_start"]),
"sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
"num_traces": len(traces),
"traces": traces,
}
if "text/html" in response.content_type and b"</body>" in response.body:
extra = json.dumps(trace_info, indent=2)
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
response.body = response.body.replace(b"</body>", extra_html)
elif "json" in response.content_type and response.body.startswith(b"{"):
data = json.loads(response.body.decode("utf8"))
if "_trace" not in data:
data["_trace"] = trace_info
response.body = json.dumps(data).encode("utf8")
@app.exception(Exception)
def on_exception(request, exception):
title = None
help = None
if isinstance(exception, NotFound):
status = 404
info = {}
message = exception.args[0]
elif isinstance(exception, InvalidUsage):
status = 405
info = {}
message = exception.args[0]
elif isinstance(exception, DatasetteError):
status = exception.status
info = exception.error_dict
message = exception.message
if exception.messagge_is_html:
message = Markup(message)
title = exception.title
else:
status = 500
info = {}
message = str(exception)
traceback.print_exc()
templates = ["500.html"]
if status != 500:
templates = ["{}.html".format(status)] + templates
info.update(
{"ok": False, "error": message, "status": status, "title": title}
)
if request is not None and request.path.split("?")[0].endswith(".json"):
r = response.json(info, status=status)
else:
template = self.jinja_env.select_template(templates)
r = response.html(template.render(info), status=status)
if self.cors:
r.headers["Access-Control-Allow-Origin"] = "*"
return r
# First time server starts up, calculate table counts for immutable databases
@app.listener("before_server_start")
async def setup_db(app, loop):
async def setup_db():
# First time server starts up, calculate table counts for immutable databases
for dbname, database in self.databases.items():
if not database.is_mutable:
await database.table_counts(limit=60 * 60 * 1000)
return app
return AsgiLifespan(
AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db
)
class DatasetteRouter(AsgiRouter):
def __init__(self, datasette, routes):
self.ds = datasette
super().__init__(routes)
async def handle_404(self, scope, receive, send):
# If URL has a trailing slash, redirect to URL without it
path = scope.get("raw_path", scope["path"].encode("utf8"))
if path.endswith(b"/"):
path = path.rstrip(b"/")
if scope["query_string"]:
path += b"?" + scope["query_string"]
await asgi_send_redirect(send, path.decode("latin1"))
else:
await super().handle_404(scope, receive, send)
async def handle_500(self, scope, receive, send, exception):
title = None
if isinstance(exception, NotFound):
status = 404
info = {}
message = exception.args[0]
elif isinstance(exception, DatasetteError):
status = exception.status
info = exception.error_dict
message = exception.message
if exception.messagge_is_html:
message = Markup(message)
title = exception.title
else:
status = 500
info = {}
message = str(exception)
traceback.print_exc()
templates = ["500.html"]
if status != 500:
templates = ["{}.html".format(status)] + templates
info.update({"ok": False, "error": message, "status": status, "title": title})
headers = {}
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
if scope["path"].split("?")[0].endswith(".json"):
await asgi_send_json(send, info, status=status, headers=headers)
else:
template = self.ds.jinja_env.select_template(templates)
await asgi_send_html(
send, template.render(info), status=status, headers=headers
)
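
The Sanic route syntax used previously (for example /<db_name:[^/]+?>) is replaced by ordinary regular expressions with named capture groups. AsgiRouter matches the request path against each pattern and the named groups become the view's keyword arguments. A rough sketch of that flow, reusing one of the patterns registered above (the view name here is just a stand-in string):

    import re

    # One (pattern, view) pair as built by add_route() above
    routes = [(re.compile(r"/(?P<db_name>[^/]+?)(?P<as_db>\.db)$"), "DatabaseDownload")]

    path = "/fixtures.db"
    for regex, view in routes:
        match = regex.match(path)
        if match is not None:
            # AsgiRouter stores these in scope["url_route"]["kwargs"];
            # AsgiView.as_asgi() then calls self.get(request, **kwargs)
            print(view, match.groupdict())  # DatabaseDownload {'db_name': 'fixtures', 'as_db': '.db'}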

View file

@ -1,4 +1,5 @@
import asyncio
import uvicorn
import click
from click import formatting
from click_default_group import DefaultGroup
@ -354,4 +355,4 @@ def serve(
asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks())
# Start the server
ds.app().run(host=host, port=port, debug=debug)
uvicorn.run(ds.app(), host=host, port=port, log_level="info")

View file

@ -88,5 +88,5 @@ def json_renderer(args, data, view_name):
content_type = "text/plain"
else:
body = json.dumps(data, cls=CustomJSONEncoder)
content_type = "application/json"
content_type = "application/json; charset=utf-8"
return {"body": body, "status_code": status_code, "content_type": content_type}

View file

@ -1,6 +1,7 @@
import asyncio
from contextlib import contextmanager
import time
import json
import traceback
tracers = {}
@ -32,15 +33,15 @@ def trace(type, **kwargs):
start = time.time()
yield
end = time.time()
trace = {
trace_info = {
"type": type,
"start": start,
"end": end,
"duration_ms": (end - start) * 1000,
"traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
}
trace.update(kwargs)
tracer.append(trace)
trace_info.update(kwargs)
tracer.append(trace_info)
@contextmanager
@ -53,3 +54,77 @@ def capture_traces(tracer):
tracers[task_id] = tracer
yield
del tracers[task_id]
class AsgiTracer:
# If the body is larger than this we don't attempt to append the trace
max_body_bytes = 1024 * 256 # 256 KB
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if b"_trace=1" not in scope.get("query_string", b"").split(b"&"):
await self.app(scope, receive, send)
return
trace_start = time.time()
traces = []
accumulated_body = b""
size_limit_exceeded = False
response_headers = []
async def wrapped_send(message):
nonlocal accumulated_body, size_limit_exceeded, response_headers
if message["type"] == "http.response.start":
response_headers = message["headers"]
await send(message)
return
if message["type"] != "http.response.body" or size_limit_exceeded:
await send(message)
return
# Accumulate body until the end or until size is exceeded
accumulated_body += message["body"]
if len(accumulated_body) > self.max_body_bytes:
await send(
{
"type": "http.response.body",
"body": accumulated_body,
"more_body": True,
}
)
size_limit_exceeded = True
return
if not message.get("more_body"):
# We have all the body - modify it and send the result
# TODO: What to do about Content-Type or other cases?
trace_info = {
"request_duration_ms": 1000 * (time.time() - trace_start),
"sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
"num_traces": len(traces),
"traces": traces,
}
try:
content_type = [
v.decode("utf8")
for k, v in response_headers
if k.lower() == b"content-type"
][0]
except IndexError:
content_type = ""
if "text/html" in content_type and b"</body>" in accumulated_body:
extra = json.dumps(trace_info, indent=2)
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
accumulated_body = accumulated_body.replace(b"</body>", extra_html)
elif "json" in content_type and accumulated_body.startswith(b"{"):
data = json.loads(accumulated_body.decode("utf8"))
if "_trace" not in data:
data["_trace"] = trace_info
accumulated_body = json.dumps(data).encode("utf8")
await send({"type": "http.response.body", "body": accumulated_body})
with capture_traces(traces):
await self.app(scope, receive, wrapped_send)

View file

@ -697,13 +697,13 @@ class LimitedWriter:
self.limit_bytes = limit_mb * 1024 * 1024
self.bytes_count = 0
def write(self, bytes):
async def write(self, bytes):
self.bytes_count += len(bytes)
if self.limit_bytes and (self.bytes_count > self.limit_bytes):
raise WriteLimitExceeded(
"CSV contains more than {} bytes".format(self.limit_bytes)
)
self.writer.write(bytes)
await self.writer.write(bytes)
_infinities = {float("inf"), float("-inf")}
@ -741,3 +741,16 @@ def format_bytes(bytes):
return "{} {}".format(int(current), unit)
else:
return "{:.1f} {}".format(current, unit)
class RequestParameters(dict):
def get(self, name, default=None):
"Return first value in the list, if available"
try:
return super().get(name)[0]
except (KeyError, TypeError):
return default
def getlist(self, name, default=None):
"Return full list"
return super().get(name, default)
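
RequestParameters replaces the class previously imported from sanic.request: it wraps the lists-of-values mapping produced by urllib.parse.parse_qs, with .get() returning the first value and .getlist() the full list. For example (values assumed):

    from urllib.parse import parse_qs
    from datasette.utils import RequestParameters

    args = RequestParameters(parse_qs("tag=python&tag=asgi"))
    args.get("tag")                  # "python" -- first value only
    args.getlist("tag")              # ["python", "asgi"]
    args.get("missing", "default")   # "default"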

View file

@ -0,0 +1,377 @@
import json
from datasette.utils import RequestParameters
from mimetypes import guess_type
from urllib.parse import parse_qs, urlunparse
from pathlib import Path
from html import escape
import re
import aiofiles
class NotFound(Exception):
pass
class Request:
def __init__(self, scope):
self.scope = scope
@property
def method(self):
return self.scope["method"]
@property
def url(self):
return urlunparse(
(self.scheme, self.host, self.path, None, self.query_string, None)
)
@property
def scheme(self):
return self.scope.get("scheme") or "http"
@property
def headers(self):
return dict(
[
(k.decode("latin-1").lower(), v.decode("latin-1"))
for k, v in self.scope.get("headers") or []
]
)
@property
def host(self):
return self.headers.get("host") or "localhost"
@property
def path(self):
return (
self.scope.get("raw_path", self.scope["path"].encode("latin-1"))
).decode("latin-1")
@property
def query_string(self):
return (self.scope.get("query_string") or b"").decode("latin-1")
@property
def args(self):
return RequestParameters(parse_qs(qs=self.query_string))
@property
def raw_args(self):
return {key: value[0] for key, value in self.args.items()}
@classmethod
def fake(cls, path_with_query_string, method="GET", scheme="http"):
"Useful for constructing Request objects for tests"
path, _, query_string = path_with_query_string.partition("?")
scope = {
"http_version": "1.1",
"method": method,
"path": path,
"raw_path": path.encode("latin-1"),
"query_string": query_string.encode("latin-1"),
"scheme": scheme,
"type": "http",
}
return cls(scope)
class AsgiRouter:
def __init__(self, routes=None):
routes = routes or []
self.routes = [
# Compile any strings to regular expressions
((re.compile(pattern) if isinstance(pattern, str) else pattern), view)
for pattern, view in routes
]
async def __call__(self, scope, receive, send):
# Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves
path = scope["raw_path"].decode("ascii")
for regex, view in self.routes:
match = regex.match(path)
if match is not None:
new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
try:
return await view(new_scope, receive, send)
except Exception as exception:
return await self.handle_500(scope, receive, send, exception)
return await self.handle_404(scope, receive, send)
async def handle_404(self, scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 404,
"headers": [[b"content-type", b"text/html"]],
}
)
await send({"type": "http.response.body", "body": b"<h1>404</h1>"})
async def handle_500(self, scope, receive, send, exception):
await send(
{
"type": "http.response.start",
"status": 404,
"headers": [[b"content-type", b"text/html"]],
}
)
html = "<h1>500</h1><pre{}></pre>".format(escape(repr(exception)))
await send({"type": "http.response.body", "body": html.encode("latin-1")})
class AsgiLifespan:
def __init__(self, app, on_startup=None, on_shutdown=None):
self.app = app
on_startup = on_startup or []
on_shutdown = on_shutdown or []
if not isinstance(on_startup or [], list):
on_startup = [on_startup]
if not isinstance(on_shutdown or [], list):
on_shutdown = [on_shutdown]
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
for fn in self.on_startup:
await fn()
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.on_shutdown:
await fn()
await send({"type": "lifespan.shutdown.complete"})
return
else:
await self.app(scope, receive, send)
class AsgiView:
def dispatch_request(self, request, *args, **kwargs):
handler = getattr(self, request.method.lower(), None)
return handler(request, *args, **kwargs)
@classmethod
def as_asgi(cls, *class_args, **class_kwargs):
async def view(scope, receive, send):
# Uses scope to create a request object, then dispatches that to
# self.get(...) or self.options(...) along with keyword arguments
# that were already tucked into scope["url_route"]["kwargs"] by
# the router, similar to how Django Channels works:
# https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter
request = Request(scope)
self = view.view_class(*class_args, **class_kwargs)
response = await self.dispatch_request(
request, **scope["url_route"]["kwargs"]
)
await response.asgi_send(send)
view.view_class = cls
view.__doc__ = cls.__doc__
view.__module__ = cls.__module__
view.__name__ = cls.__name__
return view
class AsgiStream:
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
self.stream_fn = stream_fn
self.status = status
self.headers = headers or {}
self.content_type = content_type
async def asgi_send(self, send):
# Remove any existing content-type header
headers = dict(
[(k, v) for k, v in self.headers.items() if k.lower() != "content-type"]
)
headers["content-type"] = self.content_type
await send(
{
"type": "http.response.start",
"status": self.status,
"headers": [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in headers.items()
],
}
)
w = AsgiWriter(send)
await self.stream_fn(w)
await send({"type": "http.response.body", "body": b""})
class AsgiWriter:
def __init__(self, send):
self.send = send
async def write(self, chunk):
await self.send(
{
"type": "http.response.body",
"body": chunk.encode("latin-1"),
"more_body": True,
}
)
async def asgi_send_json(send, info, status=200, headers=None):
headers = headers or {}
await asgi_send(
send,
json.dumps(info),
status=status,
headers=headers,
content_type="application/json; charset=utf-8",
)
async def asgi_send_html(send, html, status=200, headers=None):
headers = headers or {}
await asgi_send(
send, html, status=status, headers=headers, content_type="text/html"
)
async def asgi_send_redirect(send, location, status=302):
await asgi_send(
send,
"",
status=status,
headers={"Location": location},
content_type="text/html",
)
async def asgi_send(send, content, status, headers=None, content_type="text/plain"):
await asgi_start(send, status, headers, content_type)
await send({"type": "http.response.body", "body": content.encode("latin-1")})
async def asgi_start(send, status, headers=None, content_type="text/plain"):
headers = headers or {}
# Remove any existing content-type header
headers = dict([(k, v) for k, v in headers.items() if k.lower() != "content-type"])
headers["content-type"] = content_type
await send(
{
"type": "http.response.start",
"status": status,
"headers": [
[key.encode("latin1"), value.encode("latin1")]
for key, value in headers.items()
],
}
)
async def asgi_send_file(
send, filepath, filename=None, content_type=None, chunk_size=4096
):
headers = {}
if filename:
headers["Content-Disposition"] = 'attachment; filename="{}"'.format(filename)
first = True
async with aiofiles.open(str(filepath), mode="rb") as fp:
if first:
await asgi_start(
send,
200,
headers,
content_type or guess_type(str(filepath))[0] or "text/plain",
)
first = False
more_body = True
while more_body:
chunk = await fp.read(chunk_size)
more_body = len(chunk) == chunk_size
await send(
{"type": "http.response.body", "body": chunk, "more_body": more_body}
)
def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None):
async def inner_static(scope, receive, send):
path = scope["url_route"]["kwargs"]["path"]
full_path = (Path(root_path) / path).absolute()
# Ensure full_path is within root_path to avoid weird "../" tricks
try:
full_path.relative_to(root_path)
except ValueError:
await asgi_send_html(send, "404", 404)
return
first = True
try:
await asgi_send_file(send, full_path, chunk_size=chunk_size)
except FileNotFoundError:
await asgi_send_html(send, "404", 404)
return
return inner_static
class Response:
def __init__(self, body=None, status=200, headers=None, content_type="text/plain"):
self.body = body
self.status = status
self.headers = headers or {}
self.content_type = content_type
async def asgi_send(self, send):
headers = {}
headers.update(self.headers)
headers["content-type"] = self.content_type
await send(
{
"type": "http.response.start",
"status": self.status,
"headers": [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in headers.items()
],
}
)
body = self.body
if not isinstance(body, bytes):
body = body.encode("utf-8")
await send({"type": "http.response.body", "body": body})
@classmethod
def html(cls, body, status=200, headers=None):
return cls(
body,
status=status,
headers=headers,
content_type="text/html; charset=utf-8",
)
@classmethod
def text(cls, body, status=200, headers=None):
return cls(
body,
status=status,
headers=headers,
content_type="text/plain; charset=utf-8",
)
@classmethod
def redirect(cls, path, status=302, headers=None):
headers = headers or {}
headers["Location"] = path
return cls("", status=status, headers=headers)
class AsgiFileDownload:
def __init__(
self, filepath, filename=None, content_type="application/octet-stream"
):
self.filepath = filepath
self.filename = filename
self.content_type = content_type
async def asgi_send(self, send):
return await asgi_send_file(send, self.filepath, content_type=self.content_type)
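
Taken together, these classes form the stack that Datasette.app() now returns: class-based views exposed through as_asgi(), routed by AsgiRouter and wrapped in AsgiLifespan for startup hooks. A minimal composition sketch using only the classes above (HelloView and on_startup are invented for illustration, not Datasette's exact wiring):

    from html import escape

    import uvicorn
    from datasette.utils.asgi import AsgiLifespan, AsgiRouter, AsgiView, Response

    class HelloView(AsgiView):
        async def get(self, request, name):
            return Response.html("<h1>Hello {}</h1>".format(escape(name)))

    async def on_startup():
        # e.g. warm caches or pre-compute table counts, as setup_db() does above
        pass

    app = AsgiLifespan(
        AsgiRouter([(r"/hello/(?P<name>[^/]+)$", HelloView.as_asgi())]),
        on_startup=on_startup,
    )

    # uvicorn.run(app, host="127.0.0.1", port=8001)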

View file

@ -7,9 +7,8 @@ import urllib
import jinja2
import pint
from sanic import response
from sanic.exceptions import NotFound
from sanic.views import HTTPMethodView
from html import escape
from datasette import __version__
from datasette.plugins import pm
@ -26,6 +25,14 @@ from datasette.utils import (
sqlite3,
to_css_class,
)
from datasette.utils.asgi import (
AsgiStream,
AsgiWriter,
AsgiRouter,
AsgiView,
NotFound,
Response,
)
ureg = pint.UnitRegistry()
@ -49,7 +56,14 @@ class DatasetteError(Exception):
self.messagge_is_html = messagge_is_html
class BaseView(HTTPMethodView):
class BaseView(AsgiView):
ds = None
async def head(self, *args, **kwargs):
response = await self.get(*args, **kwargs)
response.body = b""
return response
def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins:
seen_urls = set()
@ -104,7 +118,7 @@ class BaseView(HTTPMethodView):
datasette=self.ds,
):
body_scripts.append(jinja2.Markup(script))
return response.html(
return Response.html(
template.render(
{
**context,
@ -136,7 +150,7 @@ class DataView(BaseView):
self.ds = datasette
def options(self, request, *args, **kwargs):
r = response.text("ok")
r = Response.text("ok")
if self.ds.cors:
r.headers["Access-Control-Allow-Origin"] = "*"
return r
@ -146,7 +160,7 @@ class DataView(BaseView):
path = "{}?{}".format(path, request.query_string)
if remove_args:
path = path_with_removed_args(request, remove_args, path=path)
r = response.redirect(path)
r = Response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors:
r.headers["Access-Control-Allow-Origin"] = "*"
@ -195,17 +209,17 @@ class DataView(BaseView):
kwargs["table"] = table
if _format:
kwargs["as_format"] = ".{}".format(_format)
elif "table" in kwargs:
elif kwargs.get("table"):
kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
should_redirect = "/{}-{}".format(name, expected)
if "table" in kwargs:
if kwargs.get("table"):
should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
if "pk_path" in kwargs:
if kwargs.get("pk_path"):
should_redirect += "/" + kwargs["pk_path"]
if "as_format" in kwargs:
if kwargs.get("as_format"):
should_redirect += kwargs["as_format"]
if "as_db" in kwargs:
if kwargs.get("as_db"):
should_redirect += kwargs["as_db"]
if (
@ -246,7 +260,7 @@ class DataView(BaseView):
response_or_template_contexts = await self.data(
request, database, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
if isinstance(response_or_template_contexts, Response):
return response_or_template_contexts
else:
data, _, _ = response_or_template_contexts
@ -282,13 +296,13 @@ class DataView(BaseView):
if not first:
data, _, _ = await self.data(request, database, hash, **kwargs)
if first:
writer.writerow(headings)
await writer.writerow(headings)
first = False
next = data.get("next")
for row in data["rows"]:
if not expanded_columns:
# Simple path
writer.writerow(row)
await writer.writerow(row)
else:
# Look for {"value": "label": } dicts and expand
new_row = []
@ -298,10 +312,10 @@ class DataView(BaseView):
new_row.append(cell["label"])
else:
new_row.append(cell)
writer.writerow(new_row)
await writer.writerow(new_row)
except Exception as e:
print("caught this", e)
r.write(str(e))
await r.write(str(e))
return
content_type = "text/plain; charset=utf-8"
@ -315,7 +329,7 @@ class DataView(BaseView):
)
headers["Content-Disposition"] = disposition
return response.stream(stream_fn, headers=headers, content_type=content_type)
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL
@ -363,7 +377,7 @@ class DataView(BaseView):
response_or_template_contexts = await self.data(
request, database, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
if isinstance(response_or_template_contexts, Response):
return response_or_template_contexts
else:
@ -414,17 +428,11 @@ class DataView(BaseView):
if result is None:
raise NotFound("No data")
response_args = {
"content_type": result.get("content_type", "text/plain"),
"status": result.get("status_code", 200),
}
if type(result.get("body")) == bytes:
response_args["body_bytes"] = result.get("body")
else:
response_args["body"] = result.get("body")
r = response.HTTPResponse(**response_args)
r = Response(
body=result.get("body"),
status=result.get("status_code", 200),
content_type=result.get("content_type", "text/plain"),
)
else:
extras = {}
if callable(extra_template_data):

View file

@ -1,10 +1,9 @@
import os
from sanic import response
from datasette.utils import to_css_class, validate_sql_select
from datasette.utils.asgi import AsgiFileDownload
from .base import DataView, DatasetteError
from .base import DatasetteError, DataView
class DatabaseView(DataView):
@ -79,8 +78,8 @@ class DatabaseDownload(DataView):
if not db.path:
raise DatasetteError("Cannot download database", status=404)
filepath = db.path
return await response.file_stream(
return AsgiFileDownload(
filepath,
filename=os.path.basename(filepath),
mime_type="application/octet-stream",
content_type="application/octet-stream",
)

View file

@ -1,9 +1,8 @@
import hashlib
import json
from sanic import response
from datasette.utils import CustomJSONEncoder
from datasette.utils.asgi import Response
from datasette.version import __version__
from .base import BaseView
@ -104,9 +103,9 @@ class IndexView(BaseView):
headers = {}
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
return response.HTTPResponse(
return Response(
json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder),
content_type="application/json",
content_type="application/json; charset=utf-8",
headers=headers,
)
else:

View file

@ -1,5 +1,5 @@
import json
from sanic import response
from datasette.utils.asgi import Response
from .base import BaseView
@ -17,8 +17,10 @@ class JsonDataView(BaseView):
headers = {}
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
return response.HTTPResponse(
json.dumps(data), content_type="application/json", headers=headers
return Response(
json.dumps(data),
content_type="application/json; charset=utf-8",
headers=headers,
)
else:

View file

@ -3,13 +3,12 @@ import itertools
import json
import jinja2
from sanic.exceptions import NotFound
from sanic.request import RequestParameters
from datasette.plugins import pm
from datasette.utils import (
CustomRow,
QueryInterrupted,
RequestParameters,
append_querystring,
compound_keys_after_sql,
escape_sqlite,
@ -24,6 +23,7 @@ from datasette.utils import (
urlsafe_components,
value_as_boolean,
)
from datasette.utils.asgi import NotFound
from datasette.filters import Filters
from .base import DataView, DatasetteError, ureg
@ -219,8 +219,7 @@ class TableView(RowTableShared):
if is_view:
order_by = ""
# We roll our own query_string decoder because by default Sanic
# drops anything with an empty value e.g. ?name__exact=
# Ensure we don't drop anything with an empty value e.g. ?name__exact=
args = RequestParameters(
urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
)

View file

@ -4,7 +4,5 @@ filterwarnings=
ignore:Using or importing the ABCs::jinja2
# https://bugs.launchpad.net/beautifulsoup/+bug/1778909
ignore:Using or importing the ABCs::bs4.element
# Sanic verify_ssl=True
ignore:verify_ssl is deprecated::sanic
# Python 3.7 PendingDeprecationWarning: Task.current_task()
ignore:.*current_task.*:PendingDeprecationWarning

View file

@ -37,17 +37,18 @@ setup(
author="Simon Willison",
license="Apache License, Version 2.0",
url="https://github.com/simonw/datasette",
packages=find_packages(exclude='tests'),
packages=find_packages(exclude="tests"),
package_data={"datasette": ["templates/*.html"]},
include_package_data=True,
install_requires=[
"click>=6.7",
"click-default-group==1.2",
"Sanic==0.7.0",
"Jinja2==2.10.1",
"hupper==1.0",
"pint==0.8.1",
"pluggy>=0.12.0",
"uvicorn>=0.8.1",
"aiofiles==0.4.0",
],
entry_points="""
[console_scripts]
@ -60,6 +61,7 @@ setup(
"pytest-asyncio==0.10.0",
"aiohttp==3.5.3",
"beautifulsoup4==4.6.1",
"asgiref==3.1.2",
]
+ maybe_black
},

View file

@ -1,5 +1,7 @@
from datasette.app import Datasette
from datasette.utils import sqlite3
from asgiref.testing import ApplicationCommunicator
from asgiref.sync import async_to_sync
import itertools
import json
import os
@ -10,16 +12,82 @@ import sys
import string
import tempfile
import time
from urllib.parse import unquote
class TestResponse:
def __init__(self, status, headers, body):
self.status = status
self.headers = headers
self.body = body
@property
def json(self):
return json.loads(self.text)
@property
def text(self):
return self.body.decode("utf8")
class TestClient:
def __init__(self, sanic_test_client):
self.sanic_test_client = sanic_test_client
max_redirects = 5
def get(self, path, allow_redirects=True):
return self.sanic_test_client.get(
path, allow_redirects=allow_redirects, gather_request=False
def __init__(self, asgi_app):
self.asgi_app = asgi_app
@async_to_sync
async def get(self, path, allow_redirects=True, redirect_count=0, method="GET"):
return await self._get(path, allow_redirects, redirect_count, method)
async def _get(self, path, allow_redirects=True, redirect_count=0, method="GET"):
query_string = b""
if "?" in path:
path, _, query_string = path.partition("?")
query_string = query_string.encode("utf8")
instance = ApplicationCommunicator(
self.asgi_app,
{
"type": "http",
"http_version": "1.0",
"method": method,
"path": unquote(path),
"raw_path": path.encode("ascii"),
"query_string": query_string,
"headers": [[b"host", b"localhost"]],
},
)
await instance.send_input({"type": "http.request"})
# First message back should be response.start with headers and status
messages = []
start = await instance.receive_output(2)
messages.append(start)
assert start["type"] == "http.response.start"
headers = dict(
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
)
status = start["status"]
# Now loop until we run out of response.body
body = b""
while True:
message = await instance.receive_output(2)
messages.append(message)
assert message["type"] == "http.response.body"
body += message["body"]
if not message.get("more_body"):
break
response = TestResponse(status, headers, body)
if allow_redirects and response.status in (301, 302):
assert (
redirect_count < self.max_redirects
), "Redirected {} times, max_redirects={}".format(
redirect_count, self.max_redirects
)
location = response.headers["Location"]
return await self._get(
location, allow_redirects=True, redirect_count=redirect_count + 1
)
return response
def make_app_client(
@ -32,6 +100,7 @@ def make_app_client(
is_immutable=False,
extra_databases=None,
inspect_data=None,
static_mounts=None,
):
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, filename)
@ -73,9 +142,10 @@ def make_app_client(
plugins_dir=plugins_dir,
config=config,
inspect_data=inspect_data,
static_mounts=static_mounts,
)
ds.sqlite_functions.append(("sleep", 1, lambda n: time.sleep(float(n))))
client = TestClient(ds.app().test_client)
client = TestClient(ds.app())
client.ds = ds
yield client
@ -88,7 +158,7 @@ def app_client():
@pytest.fixture(scope="session")
def app_client_no_files():
ds = Datasette([])
client = TestClient(ds.app().test_client)
client = TestClient(ds.app())
client.ds = ds
yield client
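
The rewritten TestClient drives the ASGI app in-process through asgiref's ApplicationCommunicator, so the test suite no longer needs Sanic's HTTP test client. A hypothetical test using the fixture above:

    def test_versions_reports_asgi(app_client):
        # app_client is the fixture defined above; no HTTP server is started
        response = app_client.get("/-/versions.json")
        assert response.status == 200
        assert response.json["asgi"] == "3.0"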

View file

@ -22,6 +22,7 @@ import urllib
def test_homepage(app_client):
response = app_client.get("/.json")
assert response.status == 200
assert "application/json; charset=utf-8" == response.headers["content-type"]
assert response.json.keys() == {"fixtures": 0}.keys()
d = response.json["fixtures"]
assert d["name"] == "fixtures"
@ -771,8 +772,8 @@ def test_paginate_tables_and_views(app_client, path, expected_rows, expected_pag
fetched.extend(response.json["rows"])
path = response.json["next_url"]
if path:
assert response.json["next"]
assert urllib.parse.urlencode({"_next": response.json["next"]}) in path
path = path.replace("http://localhost", "")
assert count < 30, "Possible infinite loop detected"
assert expected_rows == len(fetched)
@ -812,6 +813,8 @@ def test_paginate_compound_keys(app_client):
response = app_client.get(path)
fetched.extend(response.json["rows"])
path = response.json["next_url"]
if path:
path = path.replace("http://localhost", "")
assert page < 100
assert 1001 == len(fetched)
assert 21 == page
@ -833,6 +836,8 @@ def test_paginate_compound_keys_with_extra_filters(app_client):
response = app_client.get(path)
fetched.extend(response.json["rows"])
path = response.json["next_url"]
if path:
path = path.replace("http://localhost", "")
assert 2 == page
expected = [r[3] for r in generate_compound_rows(1001) if "d" in r[3]]
assert expected == [f["content"] for f in fetched]
@ -881,6 +886,8 @@ def test_sortable(app_client, query_string, sort_key, human_description_en):
assert human_description_en == response.json["human_description_en"]
fetched.extend(response.json["rows"])
path = response.json["next_url"]
if path:
path = path.replace("http://localhost", "")
assert 5 == page
expected = list(generate_sortable_rows(201))
expected.sort(key=sort_key)
@ -1191,6 +1198,7 @@ def test_plugins_json(app_client):
def test_versions_json(app_client):
response = app_client.get("/-/versions.json")
assert "python" in response.json
assert "3.0" == response.json.get("asgi")
assert "version" in response.json["python"]
assert "full" in response.json["python"]
assert "datasette" in response.json
@ -1236,6 +1244,8 @@ def test_page_size_matching_max_returned_rows(
fetched.extend(response.json["rows"])
assert len(response.json["rows"]) in (1, 50)
path = response.json["next_url"]
if path:
path = path.replace("http://localhost", "")
assert 201 == len(fetched)

View file

@ -46,7 +46,7 @@ def test_table_csv(app_client):
response = app_client.get("/fixtures/simple_primary_key.csv")
assert response.status == 200
assert not response.headers.get("Access-Control-Allow-Origin")
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert "text/plain; charset=utf-8" == response.headers["content-type"]
assert EXPECTED_TABLE_CSV == response.text
@ -59,7 +59,7 @@ def test_table_csv_cors_headers(app_client_with_cors):
def test_table_csv_with_labels(app_client):
response = app_client.get("/fixtures/facetable.csv?_labels=1")
assert response.status == 200
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert "text/plain; charset=utf-8" == response.headers["content-type"]
assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text
@ -68,14 +68,14 @@ def test_custom_sql_csv(app_client):
"/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
)
assert response.status == 200
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert "text/plain; charset=utf-8" == response.headers["content-type"]
assert EXPECTED_CUSTOM_CSV == response.text
def test_table_csv_download(app_client):
response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1")
assert response.status == 200
assert "text/csv; charset=utf-8" == response.headers["Content-Type"]
assert "text/csv; charset=utf-8" == response.headers["content-type"]
expected_disposition = 'attachment; filename="simple_primary_key.csv"'
assert expected_disposition == response.headers["Content-Disposition"]

View file

@ -8,6 +8,7 @@ from .fixtures import ( # noqa
METADATA,
)
import json
import pathlib
import pytest
import re
import urllib.parse
@ -16,6 +17,7 @@ import urllib.parse
def test_homepage(app_client_two_attached_databases):
response = app_client_two_attached_databases.get("/")
assert response.status == 200
assert "text/html; charset=utf-8" == response.headers["content-type"]
soup = Soup(response.body, "html.parser")
assert "Datasette Fixtures" == soup.find("h1").text
assert (
@ -44,6 +46,29 @@ def test_homepage(app_client_two_attached_databases):
] == table_links
def test_http_head(app_client):
response = app_client.get("/", method="HEAD")
assert response.status == 200
def test_static(app_client):
response = app_client.get("/-/static/app2.css")
assert response.status == 404
response = app_client.get("/-/static/app.css")
assert response.status == 200
assert "text/css" == response.headers["content-type"]
def test_static_mounts():
for client in make_app_client(
static_mounts=[("custom-static", str(pathlib.Path(__file__).parent))]
):
response = client.get("/custom-static/test_html.py")
assert response.status == 200
response = client.get("/custom-static/not_exists.py")
assert response.status == 404
def test_memory_database_page():
for client in make_app_client(memory=True):
response = client.get("/:memory:")

View file

@ -3,11 +3,11 @@ Tests for various datasette helper functions.
"""
from datasette import utils
from datasette.utils.asgi import Request
from datasette.filters import Filters
import json
import os
import pytest
from sanic.request import Request
import sqlite3
import tempfile
from unittest.mock import patch
@ -53,7 +53,7 @@ def test_urlsafe_components(path, expected):
],
)
def test_path_with_added_args(path, added_args, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
request = Request.fake(path)
actual = utils.path_with_added_args(request, added_args)
assert expected == actual
@ -67,11 +67,11 @@ def test_path_with_added_args(path, added_args, expected):
],
)
def test_path_with_removed_args(path, args, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
request = Request.fake(path)
actual = utils.path_with_removed_args(request, args)
assert expected == actual
# Run the test again but this time use the path= argument
request = Request("/".encode("utf8"), {}, "1.1", "GET", None)
request = Request.fake("/")
actual = utils.path_with_removed_args(request, args, path=path)
assert expected == actual
@ -84,7 +84,7 @@ def test_path_with_removed_args(path, args, expected):
],
)
def test_path_with_replaced_args(path, args, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
request = Request.fake(path)
actual = utils.path_with_replaced_args(request, args)
assert expected == actual
@ -363,7 +363,7 @@ def test_table_columns():
],
)
def test_path_with_format(path, format, extra_qs, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
request = Request.fake(path)
actual = utils.path_with_format(request, format, extra_qs)
assert expected == actual