import hashlib
import itertools
import json
import os
import sqlite3
import sys
import traceback
import urllib.parse
from concurrent import futures
from pathlib import Path
import pluggy
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound
from datasette.views.base import (
    HASH_BLOCK_SIZE,
    DatasetteError,
    RenderMixin,
    ureg
)
from datasette.views.database import DatabaseDownload, DatabaseView
from datasette.views.index import IndexView
from datasette.views.table import RowView, TableView
from . import hookspecs
from .utils import (
    detect_fts,
    detect_spatialite,
    escape_css_string,
    escape_sqlite,
    get_all_foreign_keys,
    get_plugins,
    module_from_path,
    to_css_class
)
from .version import __version__

app_root = Path(__file__).parent.parent

pm = pluggy.PluginManager("datasette")
pm.add_hookspecs(hookspecs)
pm.load_setuptools_entrypoints("datasette")
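
# A one-file plugin dropped into the plugins_dir directory is registered with
# this same plugin manager. A minimal sketch of such a plugin (hypothetical,
# not part of this module):
#
#     from datasette import hookimpl
#
#     @hookimpl
#     def prepare_connection(conn):
#         conn.create_function("hello", 1, lambda name: "hello {}".format(name))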


class JsonDataView(RenderMixin):

    def __init__(self, datasette, filename, data_callback):
        self.ds = datasette
        self.jinja_env = datasette.jinja_env
        self.filename = filename
        self.data_callback = data_callback

    async def get(self, request, as_json):
        data = self.data_callback()
        if as_json:
            headers = {}
            if self.ds.cors:
                headers["Access-Control-Allow-Origin"] = "*"
            return response.HTTPResponse(
                json.dumps(data), content_type="application/json", headers=headers
            )
        else:
            return self.render(["show_json.html"], filename=self.filename, data=data)


async def favicon(request):
    return response.text("")


class Datasette:

    def __init__(
        self,
        files,
        num_threads=3,
        cache_headers=True,
        page_size=100,
        max_returned_rows=1000,
        sql_time_limit_ms=1000,
        cors=False,
        inspect_data=None,
        metadata=None,
        sqlite_extensions=None,
        template_dir=None,
        plugins_dir=None,
        static_mounts=None,
    ):
        self.files = files
        self.num_threads = num_threads
        self.executor = futures.ThreadPoolExecutor(max_workers=num_threads)
        self.cache_headers = cache_headers
        self.page_size = page_size
        self.max_returned_rows = max_returned_rows
        self.sql_time_limit_ms = sql_time_limit_ms
        self.cors = cors
        self._inspect = inspect_data
        self.metadata = metadata or {}
        self.sqlite_functions = []
        self.sqlite_extensions = sqlite_extensions or []
        self.template_dir = template_dir
        self.plugins_dir = plugins_dir
        self.static_mounts = static_mounts or []
        # Execute plugins in the constructor, to ensure they are available
        # when the rest of `datasette inspect` executes
        if self.plugins_dir:
            for filename in os.listdir(self.plugins_dir):
                filepath = os.path.join(self.plugins_dir, filename)
                mod = module_from_path(filepath, name=filename)
                try:
                    pm.register(mod)
                except ValueError:
                    # Plugin already registered
                    pass
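
    # Usage sketch: the file name, metadata, and port below are hypothetical,
    # shown only to illustrate how this class is wired together:
    #
    #     ds = Datasette(["fixtures.db"], metadata={"title": "Demo"}, cors=True)
    #     app = ds.app()  # returns a Sanic application
    #     app.run(host="127.0.0.1", port=8001)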

    def app_css_hash(self):
        if not hasattr(self, "_app_css_hash"):
            # Hash the bundled app.css so its URL can be cache-busted
            with open(os.path.join(str(app_root), "datasette/static/app.css")) as fp:
                self._app_css_hash = hashlib.sha1(
                    fp.read().encode("utf8")
                ).hexdigest()[:6]
        return self._app_css_hash

    def get_canned_query(self, database_name, query_name):
        query = self.metadata.get("databases", {}).get(database_name, {}).get(
            "queries", {}
        ).get(query_name)
        if query:
            return {"name": query_name, "sql": query}

    def asset_urls(self, key):
        # Copy the list so that += below does not mutate self.metadata
        urls_or_dicts = list(self.metadata.get(key) or [])
        # Flatten list-of-lists from plugins:
        urls_or_dicts += list(itertools.chain.from_iterable(getattr(pm.hook, key)()))
        for url_or_dict in urls_or_dicts:
            if isinstance(url_or_dict, dict):
                yield {"url": url_or_dict["url"], "sri": url_or_dict.get("sri")}
            else:
                yield {"url": url_or_dict}

    def extra_css_urls(self):
        return self.asset_urls("extra_css_urls")

    def extra_js_urls(self):
        return self.asset_urls("extra_js_urls")

    def update_with_inherited_metadata(self, metadata):
        # Fills in source/license with defaults, if available
        metadata.update(
            {
                "source": metadata.get("source") or self.metadata.get("source"),
                "source_url": metadata.get("source_url")
                or self.metadata.get("source_url"),
                "license": metadata.get("license") or self.metadata.get("license"),
                "license_url": metadata.get("license_url")
                or self.metadata.get("license_url"),
            }
        )

    def prepare_connection(self, conn):
        conn.row_factory = sqlite3.Row
        conn.text_factory = lambda x: str(x, "utf-8", "replace")
        for name, num_args, func in self.sqlite_functions:
            conn.create_function(name, num_args, func)
        if self.sqlite_extensions:
            conn.enable_load_extension(True)
            for extension in self.sqlite_extensions:
                conn.execute("SELECT load_extension('{}')".format(extension))
        pm.hook.prepare_connection(conn=conn)
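
    # Each sqlite_functions entry is a (name, num_args, func) tuple. A minimal
    # sketch of registering a custom SQL function (hypothetical example):
    #
    #     ds.sqlite_functions.append(
    #         ("hello", 1, lambda name: "hello {}".format(name))
    #     )
    #
    # after which "select hello('world')" works on every prepared connection.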

    def inspect(self):
        if not self._inspect:
            self._inspect = {}
            for filename in self.files:
                path = Path(filename)
                name = path.stem
                if name in self._inspect:
                    raise Exception("Multiple files with same stem %s" % name)
                # Calculate hash, efficiently
                m = hashlib.sha256()
                with path.open("rb") as fp:
                    while True:
                        data = fp.read(HASH_BLOCK_SIZE)
                        if not data:
                            break
                        m.update(data)
                # List tables and their row counts
                database_metadata = self.metadata.get("databases", {}).get(name, {})
                tables = {}
                views = []
                with sqlite3.connect(
                    "file:{}?immutable=1".format(path), uri=True
                ) as conn:
                    self.prepare_connection(conn)
                    table_names = [
                        r["name"]
                        for r in conn.execute(
                            'select * from sqlite_master where type="table"'
                        )
                    ]
                    views = [
                        v[0]
                        for v in conn.execute(
                            'select name from sqlite_master where type = "view"'
                        )
                    ]
                    for table in table_names:
                        try:
                            count = conn.execute(
                                "select count(*) from {}".format(escape_sqlite(table))
                            ).fetchone()[0]
                        except sqlite3.OperationalError:
                            # This can happen when running against an FTS virtual
                            # table, e.g. "select count(*) from some_fts;"
                            count = 0
                        # Does this table have an FTS table?
                        fts_table = detect_fts(conn, table)
                        # Figure out primary keys
                        table_info_rows = [
                            row
                            for row in conn.execute(
                                'PRAGMA table_info("{}")'.format(table)
                            ).fetchall()
                            if row[-1]
                        ]
                        table_info_rows.sort(key=lambda row: row[-1])
                        primary_keys = [str(r[1]) for r in table_info_rows]
                        label_column = None
                        # If the table has two columns, one of which is "id",
                        # label_column is the other one
                        column_names = [
                            r[1]
                            for r in conn.execute(
                                "PRAGMA table_info({});".format(escape_sqlite(table))
                            ).fetchall()
                        ]
                        if (
                            column_names
                            and len(column_names) == 2
                            and "id" in column_names
                        ):
                            label_column = [c for c in column_names if c != "id"][0]
                        table_metadata = database_metadata.get("tables", {}).get(
                            table, {}
                        )
                        tables[table] = {
                            "name": table,
                            "columns": column_names,
                            "primary_keys": primary_keys,
                            "count": count,
                            "label_column": label_column,
                            "hidden": table_metadata.get("hidden") or False,
                            "fts_table": fts_table,
                        }
                    foreign_keys = get_all_foreign_keys(conn)
                    for table, info in foreign_keys.items():
                        tables[table]["foreign_keys"] = info
                    # Mark tables 'hidden' if they relate to FTS virtual tables
                    hidden_tables = [
                        r["name"]
                        for r in conn.execute(
                            """
                            select name from sqlite_master
                            where rootpage = 0
                            and sql like '%VIRTUAL TABLE%USING FTS%'
                            """
                        )
                    ]
                    if detect_spatialite(conn):
                        # Also hide SpatiaLite internal tables
                        hidden_tables += [
                            "ElementaryGeometries",
                            "SpatialIndex",
                            "geometry_columns",
                            "spatial_ref_sys",
                            "spatialite_history",
                            "sql_statements_log",
                            "sqlite_sequence",
                            "views_geometry_columns",
                            "virts_geometry_columns",
                        ] + [
                            r["name"]
                            for r in conn.execute(
                                """
                                select name from sqlite_master
                                where name like "idx_%"
                                and type = "table"
                                """
                            )
                        ]
                    for t in tables.keys():
                        for hidden_table in hidden_tables:
                            if t == hidden_table or t.startswith(hidden_table):
                                tables[t]["hidden"] = True
                                # Stop once this table is marked hidden
                                break
                self._inspect[name] = {
                    "hash": m.hexdigest(),
                    "file": str(path),
                    "tables": tables,
                    "views": views,
                }
        return self._inspect
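
    # The result maps each database name (the file stem) to its details,
    # shaped roughly like this (illustrative values):
    #
    #     {"mydb": {
    #         "hash": "<sha256 of the file contents>",
    #         "file": "/path/to/mydb.db",
    #         "tables": {"mytable": {"name": "mytable", "count": 42, ...}},
    #         "views": ["myview"],
    #     }}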

    def register_custom_units(self):
        "Register any custom units defined in the metadata.json with Pint"
        for unit in self.metadata.get("custom_units", []):
            ureg.define(unit)
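
    # Example metadata fragment using Pint's unit definition syntax
    # (illustrative):
    #
    #     {"custom_units": ["decibel = [] = dB"]}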

    def versions(self):
        conn = sqlite3.connect(":memory:")
        self.prepare_connection(conn)
        sqlite_version = conn.execute("select sqlite_version()").fetchone()[0]
        sqlite_extensions = {}
        for extension, testsql, hasversion in (
            ("json1", "SELECT json('{}')", False),
            ("spatialite", "SELECT spatialite_version()", True),
        ):
            try:
                result = conn.execute(testsql)
                if hasversion:
                    sqlite_extensions[extension] = result.fetchone()[0]
                else:
                    sqlite_extensions[extension] = None
            except Exception:
                pass
        # Figure out supported FTS versions
        fts_versions = []
        for fts in ("FTS5", "FTS4", "FTS3"):
            try:
                conn.execute(
                    "CREATE VIRTUAL TABLE v{fts} USING {fts} (t TEXT)".format(fts=fts)
                )
                fts_versions.append(fts)
            except sqlite3.OperationalError:
                continue
        return {
            "python": {
                "version": ".".join(map(str, sys.version_info[:3])),
                "full": sys.version,
            },
            "datasette": {"version": __version__},
            "sqlite": {
                "version": sqlite_version,
                "fts_versions": fts_versions,
                "extensions": sqlite_extensions,
            },
        }

    def plugins(self):
        return [
            {
                "name": p["name"],
                "static": p["static_path"] is not None,
                "templates": p["templates_path"] is not None,
                "version": p.get("version"),
            }
            for p in get_plugins(pm)
        ]

    def app(self):
        app = Sanic(__name__)
        default_templates = str(app_root / "datasette" / "templates")
        template_paths = []
        if self.template_dir:
            template_paths.append(self.template_dir)
        template_paths.extend(
            [
                plugin["templates_path"]
                for plugin in get_plugins(pm)
                if plugin["templates_path"]
            ]
        )
        template_paths.append(default_templates)
        template_loader = ChoiceLoader(
            [
                FileSystemLoader(template_paths),
                # Support {% extends "default:table.html" %}:
                PrefixLoader(
                    {"default": FileSystemLoader(default_templates)}, delimiter=":"
                ),
            ]
        )
        self.jinja_env = Environment(loader=template_loader, autoescape=True)
        self.jinja_env.filters["escape_css_string"] = escape_css_string
        self.jinja_env.filters["quote_plus"] = lambda u: urllib.parse.quote_plus(u)
        self.jinja_env.filters["escape_sqlite"] = escape_sqlite
        self.jinja_env.filters["to_css_class"] = to_css_class
        pm.hook.prepare_jinja2_environment(env=self.jinja_env)
        app.add_route(IndexView.as_view(self), r"/<as_json:(\.jsono?)?$>")
        # TODO: /favicon.ico and /-/static/ deserve far-future cache expires
        app.add_route(favicon, "/favicon.ico")
        app.static("/-/static/", str(app_root / "datasette" / "static"))
        for path, dirname in self.static_mounts:
            app.static(path, dirname)
        # Mount any plugin static/ directories
        for plugin in get_plugins(pm):
            if plugin["static_path"]:
                modpath = "/-/static-plugins/{}/".format(plugin["name"])
                app.static(modpath, plugin["static_path"])
        app.add_route(
            JsonDataView.as_view(self, "inspect.json", self.inspect),
            r"/-/inspect<as_json:(\.json)?$>",
        )
        app.add_route(
            JsonDataView.as_view(self, "metadata.json", lambda: self.metadata),
            r"/-/metadata<as_json:(\.json)?$>",
        )
        app.add_route(
            JsonDataView.as_view(self, "versions.json", self.versions),
            r"/-/versions<as_json:(\.json)?$>",
        )
        app.add_route(
            JsonDataView.as_view(self, "plugins.json", self.plugins),
            r"/-/plugins<as_json:(\.json)?$>",
        )
        app.add_route(
            DatabaseView.as_view(self), r"/<db_name:[^/\.]+?><as_json:(\.jsono?)?$>"
        )
        app.add_route(
            DatabaseDownload.as_view(self), r"/<db_name:[^/]+?><as_db:(\.db)$>"
        )
        app.add_route(
            TableView.as_view(self),
            r"/<db_name:[^/]+>/<table:[^/]+?><as_json:(\.jsono?)?$>",
        )
        app.add_route(
            RowView.as_view(self),
            r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(\.jsono?)?$>",
        )
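        # With a hypothetical database file mydb.db containing a table called
        # "mytable", the routes above resolve URLs like:
        #
        #     /                -> IndexView
        #     /mydb            -> DatabaseView
        #     /mydb.db         -> DatabaseDownload
        #     /mydb/mytable    -> TableView
        #     /mydb/mytable/1  -> RowView (row with primary key 1)
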
        self.register_custom_units()

        @app.exception(Exception)
        def on_exception(request, exception):
            title = None
            if isinstance(exception, NotFound):
                status = 404
                info = {}
                message = exception.args[0]
            elif isinstance(exception, InvalidUsage):
                status = 405
                info = {}
                message = exception.args[0]
            elif isinstance(exception, DatasetteError):
                status = exception.status
                info = exception.error_dict
                message = exception.message
                title = exception.title
            else:
                status = 500
                info = {}
                message = str(exception)
                traceback.print_exc()
            templates = ["500.html"]
            if status != 500:
                templates = ["{}.html".format(status)] + templates
            info.update(
                {"ok": False, "error": message, "status": status, "title": title}
            )
            if request.path.split("?")[0].endswith(".json"):
                return response.json(info, status=status)
            else:
                template = self.jinja_env.select_template(templates)
                return response.html(template.render(info), status=status)

        return app