import contextlib
import hashlib
import itertools
import json
import os
import sqlite3
import sys
import traceback
import urllib.parse
from concurrent import futures
from pathlib import Path

import pluggy
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound

from datasette.views.base import (
    HASH_BLOCK_SIZE,
    DatasetteError,
    RenderMixin,
    ureg,
)
from datasette.views.database import DatabaseDownload, DatabaseView
from datasette.views.index import IndexView
from datasette.views.table import RowView, TableView

from . import hookspecs
from .utils import (
    detect_fts,
    detect_spatialite,
    escape_css_string,
    escape_sqlite,
    get_all_foreign_keys,
    get_plugins,
    module_from_path,
    to_css_class,
)
from .version import __version__

app_root = Path(__file__).parent.parent

# Module-level plugin manager: hook specs come from `hookspecs`, and any
# installed packages advertising a "datasette" setuptools entry point are
# loaded at import time.
pm = pluggy.PluginManager("datasette")
pm.add_hookspecs(hookspecs)
pm.load_setuptools_entrypoints("datasette")


class JsonDataView(RenderMixin):
    """Expose the result of ``data_callback`` as JSON or as an HTML page.

    ``Datasette.app`` uses this for the /-/inspect, /-/metadata,
    /-/versions and /-/plugins endpoints.
    """

    def __init__(self, datasette, filename, data_callback):
        self.ds = datasette
        self.jinja_env = datasette.jinja_env
        # Shown in the HTML rendering (passed to show_json.html).
        self.filename = filename
        # Zero-argument callable producing the JSON-serializable payload.
        self.data_callback = data_callback

    async def get(self, request, as_json):
        """Return the callback's data.

        The data is serialized as JSON when ``as_json`` is truthy. Otherwise
        it is rendered through the ``show_json.html`` template.
        """
        data = self.data_callback()
        if as_json:
            headers = {}
            if self.ds.cors:
                headers["Access-Control-Allow-Origin"] = "*"
            return response.HTTPResponse(
                json.dumps(data), content_type="application/json", headers=headers
            )
        else:
            return self.render(["show_json.html"], filename=self.filename, data=data)


async def favicon(request):
    """Answer /favicon.ico with an empty text body."""
    return response.text("")


class Datasette:
    """Holds configuration for a set of SQLite files and builds the Sanic app."""

    def __init__(
        self,
        files,
        num_threads=3,
        cache_headers=True,
        page_size=100,
        max_returned_rows=1000,
        sql_time_limit_ms=1000,
        cors=False,
        inspect_data=None,
        metadata=None,
        sqlite_extensions=None,
        template_dir=None,
        plugins_dir=None,
        static_mounts=None,
    ):
        self.files = files
        self.num_threads = num_threads
        self.executor = futures.ThreadPoolExecutor(max_workers=num_threads)
        self.cache_headers = cache_headers
        self.page_size = page_size
        self.max_returned_rows = max_returned_rows
        self.sql_time_limit_ms = sql_time_limit_ms
        self.cors = cors
        # Pre-computed inspect() output, if supplied; otherwise computed lazily.
        self._inspect = inspect_data
        self.metadata = metadata or {}
        # (name, num_args, func) triples registered on every connection.
        self.sqlite_functions = []
        self.sqlite_extensions = sqlite_extensions or []
        self.template_dir = template_dir
        self.plugins_dir = plugins_dir
        # Iterable of (url_path, directory) pairs served as static files.
        self.static_mounts = static_mounts or []
        # Execute plugins in constructor, to ensure they are available
        # when the rest of `datasette inspect` executes
        if self.plugins_dir:
            for filename in os.listdir(self.plugins_dir):
                filepath = os.path.join(self.plugins_dir, filename)
                mod = module_from_path(filepath, name=filename)
                try:
                    pm.register(mod)
                except ValueError:
                    # Plugin already registered
                    pass

    def app_css_hash(self):
        """Return the first 6 hex chars of app.css's SHA-1, cached on first use."""
        if not hasattr(self, "_app_css_hash"):
            css_path = os.path.join(str(app_root), "datasette/static/app.css")
            # Context manager so the file handle is closed promptly.
            with open(css_path) as fp:
                css = fp.read()
            self._app_css_hash = hashlib.sha1(css.encode("utf8")).hexdigest()[:6]
        return self._app_css_hash

    def get_canned_query(self, database_name, query_name):
        """Look up a named query under metadata["databases"][db]["queries"].

        Returns ``{"name": ..., "sql": ...}``, or None when no such query exists.
        """
        queries = (
            self.metadata.get("databases", {})
            .get(database_name, {})
            .get("queries", {})
        )
        query = queries.get(query_name)
        if query:
            return {"name": query_name, "sql": query}
        return None

    def asset_urls(self, key):
        """Yield ``{"url": ..., ["sri": ...]}`` dicts for metadata and plugin assets.

        ``key`` names both the metadata entry and the plugin hook to call
        (e.g. "extra_css_urls"). Entries may be plain URL strings or dicts
        with "url" and an optional "sri".
        """
        # Copy the metadata list: extending it in place would permanently
        # append the plugin URLs to self.metadata, duplicating them per call.
        urls_or_dicts = list(self.metadata.get(key) or [])
        # Flatten list-of-lists from plugins:
        urls_or_dicts.extend(itertools.chain.from_iterable(getattr(pm.hook, key)()))
        for url_or_dict in urls_or_dicts:
            if isinstance(url_or_dict, dict):
                yield {"url": url_or_dict["url"], "sri": url_or_dict.get("sri")}
            else:
                yield {"url": url_or_dict}

    def extra_css_urls(self):
        return self.asset_urls("extra_css_urls")

    def extra_js_urls(self):
        return self.asset_urls("extra_js_urls")

    def update_with_inherited_metadata(self, metadata):
        """Fill missing source/license fields in ``metadata`` from top-level metadata.

        Mutates ``metadata`` in place.
        """
        metadata.update(
            {
                "source": metadata.get("source") or self.metadata.get("source"),
                "source_url": metadata.get("source_url")
                or self.metadata.get("source_url"),
                "license": metadata.get("license") or self.metadata.get("license"),
                "license_url": metadata.get("license_url")
                or self.metadata.get("license_url"),
            }
        )

    def prepare_connection(self, conn):
        """Configure a SQLite connection.

        Sets the row factory and a lenient UTF-8 text factory, registers
        custom functions, loads extensions, then calls the
        ``prepare_connection`` plugin hook.
        """
        conn.row_factory = sqlite3.Row
        # Decode invalid UTF-8 with replacement characters instead of raising.
        conn.text_factory = lambda x: str(x, "utf-8", "replace")
        for name, num_args, func in self.sqlite_functions:
            conn.create_function(name, num_args, func)
        if self.sqlite_extensions:
            conn.enable_load_extension(True)
            for extension in self.sqlite_extensions:
                # Bind the path as a parameter so paths containing quotes work.
                conn.execute("SELECT load_extension(?)", [extension])
        pm.hook.prepare_connection(conn=conn)

    def inspect(self):
        """Return (computing once) per-database info keyed by file stem.

        Each value has "hash" (SHA-256 of the file), "file", "tables" and
        "views".

        Raises:
            Exception: if two files share the same stem.
        """
        if not self._inspect:
            self._inspect = {}
            for filename in self.files:
                path = Path(filename)
                name = path.stem
                if name in self._inspect:
                    raise Exception("Multiple files with same stem %s" % name)
                # Calculate hash, efficiently
                m = hashlib.sha256()
                with path.open("rb") as fp:
                    while True:
                        data = fp.read(HASH_BLOCK_SIZE)
                        if not data:
                            break
                        m.update(data)
                database_metadata = self.metadata.get("databases", {}).get(name, {})
                # sqlite3's own context manager only commits/rolls back and
                # never closes the connection; closing() actually releases it.
                with contextlib.closing(
                    sqlite3.connect("file:{}?immutable=1".format(path), uri=True)
                ) as conn:
                    self.prepare_connection(conn)
                    tables, views = self._inspect_tables_and_views(
                        conn, database_metadata
                    )
                self._inspect[name] = {
                    "hash": m.hexdigest(),
                    "file": str(path),
                    "tables": tables,
                    "views": views,
                }
        return self._inspect

    def _inspect_tables_and_views(self, conn, database_metadata):
        """Return ``(tables, views)`` for the database open on ``conn``.

        ``tables`` maps table name to a dict of columns, primary keys, row
        count, label column, hidden flag, FTS table and (when present)
        foreign keys. ``views`` is the list of view names.
        """
        table_names = [
            r["name"]
            for r in conn.execute('select * from sqlite_master where type="table"')
        ]
        views = [
            v[0]
            for v in conn.execute('select name from sqlite_master where type = "view"')
        ]
        tables = {}
        for table in table_names:
            try:
                count = conn.execute(
                    "select count(*) from {}".format(escape_sqlite(table))
                ).fetchone()[0]
            except sqlite3.OperationalError:
                # This can happen when running against a FTS virtual tables
                # e.g. "select count(*) from some_fts;"
                count = 0
            # Does this table have a FTS table?
            fts_table = detect_fts(conn, table)
            # One escaped PRAGMA serves both the column list and primary keys.
            table_info = conn.execute(
                "PRAGMA table_info({});".format(escape_sqlite(table))
            ).fetchall()
            column_names = [r[1] for r in table_info]
            # The last table_info field is the column's 1-based position
            # within the primary key (0 when not part of it).
            pk_rows = sorted((r for r in table_info if r[-1]), key=lambda r: r[-1])
            primary_keys = [str(r[1]) for r in pk_rows]
            # If table has two columns, one of which is ID, then label_column is the other one
            label_column = None
            if column_names and len(column_names) == 2 and "id" in column_names:
                label_column = [c for c in column_names if c != "id"][0]
            table_metadata = database_metadata.get("tables", {}).get(table, {})
            tables[table] = {
                "name": table,
                "columns": column_names,
                "primary_keys": primary_keys,
                "count": count,
                "label_column": label_column,
                "hidden": table_metadata.get("hidden") or False,
                "fts_table": fts_table,
            }

        foreign_keys = get_all_foreign_keys(conn)
        for table, info in foreign_keys.items():
            tables[table]["foreign_keys"] = info

        # Mark tables 'hidden' if they relate to FTS virtual tables
        hidden_tables = [
            r["name"]
            for r in conn.execute(
                """
                select name from sqlite_master
                where rootpage = 0
                and sql like '%VIRTUAL TABLE%USING FTS%'
            """
            )
        ]

        if detect_spatialite(conn):
            # Also hide Spatialite internal tables
            hidden_tables += [
                "ElementaryGeometries",
                "SpatialIndex",
                "geometry_columns",
                "spatial_ref_sys",
                "spatialite_history",
                "sql_statements_log",
                "sqlite_sequence",
                "views_geometry_columns",
                "virts_geometry_columns",
            ] + [
                r["name"]
                for r in conn.execute(
                    """
                    select name from sqlite_master
                    where name like "idx_%"
                    and type = "table"
                """
                )
            ]

        # A table is hidden if its name equals or starts with a hidden name
        # (an exact match is a special case of startswith).
        for table_name, info in tables.items():
            if any(table_name.startswith(hidden) for hidden in hidden_tables):
                info["hidden"] = True
        return tables, views

    def register_custom_units(self):
        """Register any custom units defined in the metadata.json with Pint"""
        for unit in self.metadata.get("custom_units", []):
            ureg.define(unit)

    def versions(self):
        """Report Python, Datasette and SQLite versions.

        The SQLite section also lists which of json1/spatialite and
        FTS5/FTS4/FTS3 are usable.
        """
        with contextlib.closing(sqlite3.connect(":memory:")) as conn:
            self.prepare_connection(conn)
            sqlite_version = conn.execute("select sqlite_version()").fetchone()[0]
            sqlite_extensions = {}
            for extension, testsql, hasversion in (
                ("json1", "SELECT json('{}')", False),
                ("spatialite", "SELECT spatialite_version()", True),
            ):
                try:
                    result = conn.execute(testsql)
                    if hasversion:
                        sqlite_extensions[extension] = result.fetchone()[0]
                    else:
                        sqlite_extensions[extension] = None
                except sqlite3.Error:
                    # Extension not available on this connection; omit it.
                    pass
            # Figure out supported FTS versions
            fts_versions = []
            for fts in ("FTS5", "FTS4", "FTS3"):
                try:
                    conn.execute(
                        "CREATE VIRTUAL TABLE v{fts} USING {fts} (t TEXT)".format(fts=fts)
                    )
                    fts_versions.append(fts)
                except sqlite3.OperationalError:
                    continue
        return {
            "python": {
                "version": ".".join(map(str, sys.version_info[:3])),
                "full": sys.version,
            },
            "datasette": {"version": __version__},
            "sqlite": {
                "version": sqlite_version,
                "fts_versions": fts_versions,
                "extensions": sqlite_extensions,
            },
        }

    def plugins(self):
        """Summarize installed plugins: name, static/template flags and version."""
        return [
            {
                "name": p["name"],
                "static": p["static_path"] is not None,
                "templates": p["templates_path"] is not None,
                "version": p.get("version"),
            }
            for p in get_plugins(pm)
        ]

    def app(self):
        """Build and return the Sanic application.

        Sets up Jinja templates (user dir, plugin dirs, defaults), static
        mounts, routes and the error handler.
        """
        app = Sanic(__name__)
        default_templates = str(app_root / "datasette" / "templates")
        template_paths = []
        if self.template_dir:
            template_paths.append(self.template_dir)
        template_paths.extend(
            [
                plugin["templates_path"]
                for plugin in get_plugins(pm)
                if plugin["templates_path"]
            ]
        )
        template_paths.append(default_templates)
        template_loader = ChoiceLoader(
            [
                FileSystemLoader(template_paths),
                # Support {% extends "default:table.html" %}:
                PrefixLoader(
                    {"default": FileSystemLoader(default_templates)}, delimiter=":"
                ),
            ]
        )
        self.jinja_env = Environment(loader=template_loader, autoescape=True)
        self.jinja_env.filters["escape_css_string"] = escape_css_string
        self.jinja_env.filters["quote_plus"] = lambda u: urllib.parse.quote_plus(u)
        self.jinja_env.filters["escape_sqlite"] = escape_sqlite
        self.jinja_env.filters["to_css_class"] = to_css_class
        pm.hook.prepare_jinja2_environment(env=self.jinja_env)
        # NOTE(review): the URL patterns passed to add_route below look
        # truncated.
        # - JsonDataView.get requires an `as_json` argument that none of the
        #   "/-/..." patterns supply.
        # - IndexView, DatabaseView and DatabaseDownload are all registered
        #   on "/".
        # - TableView/RowView use "//" and "///".
        # The `<name:regex>` parameter segments appear to have been stripped;
        # restore them (matching each view's get() signature) before relying
        # on these routes.
        app.add_route(IndexView.as_view(self), "/")
        # TODO: /favicon.ico and /-/static/ deserve far-future cache expires
        app.add_route(favicon, "/favicon.ico")
        app.static("/-/static/", str(app_root / "datasette" / "static"))
        for path, dirname in self.static_mounts:
            app.static(path, dirname)
        # Mount any plugin static/ directories
        for plugin in get_plugins(pm):
            if plugin["static_path"]:
                modpath = "/-/static-plugins/{}/".format(plugin["name"])
                app.static(modpath, plugin["static_path"])
        app.add_route(
            JsonDataView.as_view(self, "inspect.json", self.inspect),
            "/-/inspect",
        )
        app.add_route(
            JsonDataView.as_view(self, "metadata.json", lambda: self.metadata),
            "/-/metadata",
        )
        app.add_route(
            JsonDataView.as_view(self, "versions.json", self.versions),
            "/-/versions",
        )
        app.add_route(
            JsonDataView.as_view(self, "plugins.json", self.plugins),
            "/-/plugins",
        )
        app.add_route(DatabaseView.as_view(self), "/")
        app.add_route(DatabaseDownload.as_view(self), "/")
        app.add_route(
            TableView.as_view(self),
            "//",
        )
        app.add_route(
            RowView.as_view(self),
            "///",
        )

        self.register_custom_units()

        @app.exception(Exception)
        def on_exception(request, exception):
            """Render errors as JSON (for .json paths) or via an HTML template.

            Tries "<status>.html" first and falls back to "500.html".
            """
            title = None
            if isinstance(exception, NotFound):
                status = 404
                info = {}
                message = exception.args[0]
            elif isinstance(exception, InvalidUsage):
                status = 405
                info = {}
                message = exception.args[0]
            elif isinstance(exception, DatasetteError):
                status = exception.status
                info = exception.error_dict
                message = exception.message
                title = exception.title
            else:
                status = 500
                info = {}
                message = str(exception)
                traceback.print_exc()
            templates = ["500.html"]
            if status != 500:
                templates = ["{}.html".format(status)] + templates
            info.update(
                {"ok": False, "error": message, "status": status, "title": title}
            )
            if request.path.split("?")[0].endswith(".json"):
                return response.json(info, status=status)

            else:
                template = self.jinja_env.select_template(templates)
                return response.html(template.render(info), status=status)

        return app