diff --git a/datasette/app.py b/datasette/app.py
index beefa108..2d352097 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -1,3 +1,4 @@
+import asyncio
 import collections
 import hashlib
 import itertools
@@ -5,6 +6,7 @@ import json
 import os
 import sqlite3
 import sys
+import threading
 import traceback
 import urllib.parse
 from concurrent import futures
@@ -26,10 +28,13 @@
 from .views.table import RowView, TableView
 from . import hookspecs
 from .utils import (
+    InterruptedError,
+    Results,
     escape_css_string,
     escape_sqlite,
     get_plugins,
     module_from_path,
+    sqlite_timelimit,
     to_css_class
 )
 from .inspect import inspect_hash, inspect_views, inspect_tables
@@ -37,6 +42,7 @@ from .version import __version__
 
 app_root = Path(__file__).parent.parent
 
+connections = threading.local()
 pm = pluggy.PluginManager("datasette")
 pm.add_hookspecs(hookspecs)
 
@@ -285,6 +291,68 @@ class Datasette:
             for p in get_plugins(pm)
         ]
 
+    async def execute(
+        self,
+        db_name,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+    ):
+        """Executes sql against db_name in a thread"""
+        page_size = page_size or self.page_size
+
+        def sql_operation_in_thread():
+            conn = getattr(connections, db_name, None)
+            if not conn:
+                info = self.inspect()[db_name]
+                conn = sqlite3.connect(
+                    "file:{}?immutable=1".format(info["file"]),
+                    uri=True,
+                    check_same_thread=False,
+                )
+                self.prepare_connection(conn)
+                setattr(connections, db_name, conn)
+
+            time_limit_ms = self.sql_time_limit_ms
+            if custom_time_limit and custom_time_limit < time_limit_ms:
+                time_limit_ms = custom_time_limit
+
+            with sqlite_timelimit(conn, time_limit_ms):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(sql, params or {})
+                    max_returned_rows = self.max_returned_rows
+                    if max_returned_rows == page_size:
+                        max_returned_rows += 1
+                    if max_returned_rows and truncate:
+                        rows = cursor.fetchmany(max_returned_rows + 1)
+                        truncated = len(rows) > max_returned_rows
+                        rows = rows[:max_returned_rows]
+                    else:
+                        rows = cursor.fetchall()
+                        truncated = False
+                except sqlite3.OperationalError as e:
+                    if e.args == ('interrupted',):
+                        raise InterruptedError(e)
+                    print(
+                        "ERROR: conn={}, sql = {}, params = {}: {}".format(
+                            conn, repr(sql), params, e
+                        )
+                    )
+                    raise
+
+            if truncate:
+                return Results(rows, truncated, cursor.description)
+
+            else:
+                return Results(rows, False, cursor.description)
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self.executor, sql_operation_in_thread
+        )
+
     def app(self):
         app = Sanic(__name__)
         default_templates = str(app_root / "datasette" / "templates")
diff --git a/datasette/utils.py b/datasette/utils.py
index 9c5ee433..c92311b5 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -36,6 +36,19 @@ class InterruptedError(Exception):
     pass
 
 
+class Results:
+    def __init__(self, rows, truncated, description):
+        self.rows = rows
+        self.truncated = truncated
+        self.description = description
+
+    def __iter__(self):
+        return iter(self.rows)
+
+    def __len__(self):
+        return len(self.rows)
+
+
 def urlsafe_components(token):
     "Splits token on commas and URL decodes each component"
     return [
diff --git a/datasette/views/base.py b/datasette/views/base.py
index 997350dd..7777dda1 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -2,7 +2,6 @@ import asyncio
 import json
 import re
 import sqlite3
-import threading
 import time
 
 import pint
@@ -18,11 +17,9 @@ from datasette.utils import (
     path_from_row_pks,
     path_with_added_args,
     path_with_ext,
-    sqlite_timelimit,
     to_css_class
 )
 
-connections = threading.local()
 ureg = pint.UnitRegistry()
 
 HASH_LENGTH = 7
@@ -128,68 +125,6 @@ class BaseView(RenderMixin):
 
         return name, expected, None
 
-    async def execute(
-        self,
-        db_name,
-        sql,
-        params=None,
-        truncate=False,
-        custom_time_limit=None,
-        page_size=None,
-    ):
-        """Executes sql against db_name in a thread"""
-        page_size = page_size or self.page_size
-
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.ds.inspect()[db_name]
-                conn = sqlite3.connect(
-                    "file:{}?immutable=1".format(info["file"]),
-                    uri=True,
-                    check_same_thread=False,
-                )
-                self.ds.prepare_connection(conn)
-                setattr(connections, db_name, conn)
-
-            time_limit_ms = self.ds.sql_time_limit_ms
-            if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
-                time_limit_ms = custom_time_limit
-
-            with sqlite_timelimit(conn, time_limit_ms):
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute(sql, params or {})
-                    max_returned_rows = self.max_returned_rows
-                    if max_returned_rows == page_size:
-                        max_returned_rows += 1
-                    if max_returned_rows and truncate:
-                        rows = cursor.fetchmany(max_returned_rows + 1)
-                        truncated = len(rows) > max_returned_rows
-                        rows = rows[:max_returned_rows]
-                    else:
-                        rows = cursor.fetchall()
-                        truncated = False
-                except sqlite3.OperationalError as e:
-                    if e.args == ('interrupted',):
-                        raise InterruptedError(e)
-                    print(
-                        "ERROR: conn={}, sql = {}, params = {}: {}".format(
-                            conn, repr(sql), params, e
-                        )
-                    )
-                    raise
-
-            if truncate:
-                return rows, truncated, cursor.description
-
-            else:
-                return rows
-
-        return await asyncio.get_event_loop().run_in_executor(
-            self.executor, sql_operation_in_thread
-        )
-
     def get_templates(self, database, table=None):
         assert NotImplemented
 
@@ -348,10 +283,10 @@ class BaseView(RenderMixin):
             extra_args = {}
             if params.get("_timelimit"):
                 extra_args["custom_time_limit"] = int(params["_timelimit"])
-            rows, truncated, description = await self.execute(
+            results = await self.ds.execute(
                 name, sql, params, truncate=True, **extra_args
            )
-            columns = [r[0] for r in description]
+            columns = [r[0] for r in results.description]
 
         templates = ["query-{}.html".format(to_css_class(name)), "query.html"]
         if canned_query:
@@ -364,8 +299,8 @@
 
         return {
             "database": name,
-            "rows": rows,
-            "truncated": truncated,
+            "rows": results.rows,
+            "truncated": results.truncated,
             "columns": columns,
             "query": {"sql": sql, "params": params},
         }, {
diff --git a/datasette/views/table.py b/datasette/views/table.py
index cc552eb5..36f79f40 100644
--- a/datasette/views/table.py
+++ b/datasette/views/table.py
@@ -73,7 +73,7 @@ class RowTableShared(BaseView):
                 placeholders=", ".join(["?"] * len(set(values))),
             )
             try:
-                results = await self.execute(
+                results = await self.ds.execute(
                     database, sql, list(set(values))
                 )
             except InterruptedError:
@@ -132,7 +132,7 @@ class RowTableShared(BaseView):
                 placeholders=", ".join(["?"] * len(ids_to_lookup)),
             )
             try:
-                results = await self.execute(
+                results = await self.ds.execute(
                     database, sql, list(set(ids_to_lookup))
                 )
             except InterruptedError:
@@ -246,7 +246,7 @@ class TableView(RowTableShared):
 
         is_view = bool(
             list(
-                await self.execute(
+                await self.ds.execute(
                     name,
                     "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n",
                     {"n": table},
@@ -257,7 +257,7 @@ class TableView(RowTableShared):
         table_definition = None
         if is_view:
             view_definition = list(
-                await self.execute(
+                await self.ds.execute(
                     name,
                     'select sql from sqlite_master where name = :n and type="view"',
                     {"n": table},
@@ -265,7 +265,7 @@ class TableView(RowTableShared):
                 )[0][0]
         else:
             table_definition_rows = list(
-                await self.execute(
+                await self.ds.execute(
                     name,
                     'select sql from sqlite_master where name = :n and type="table"',
                     {"n": table},
@@ -534,7 +534,7 @@ class TableView(RowTableShared):
         if request.raw_args.get("_timelimit"):
             extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
 
-        rows, truncated, description = await self.execute(
+        results = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
 
@@ -560,7 +560,7 @@ class TableView(RowTableShared):
                 limit=facet_size+1,
             )
             try:
-                facet_rows = await self.execute(
+                facet_rows_results = await self.ds.execute(
                     name, facet_sql, params,
                     truncate=False,
                     custom_time_limit=self.ds.config["facet_time_limit_ms"],
@@ -569,9 +569,9 @@ class TableView(RowTableShared):
                 facet_results[column] = {
                     "name": column,
                     "results": facet_results_values,
-                    "truncated": len(facet_rows) > facet_size,
+                    "truncated": len(facet_rows_results) > facet_size,
                 }
-                facet_rows = facet_rows[:facet_size]
+                facet_rows = facet_rows_results.rows[:facet_size]
                 # Attempt to expand foreign keys into labels
                 values = [row["value"] for row in facet_rows]
                 expanded = (await self.expand_foreign_keys(
@@ -602,8 +602,8 @@ class TableView(RowTableShared):
             except InterruptedError:
                 facets_timed_out.append(column)
 
-        columns = [r[0] for r in description]
-        rows = list(rows)
+        columns = [r[0] for r in results.description]
+        rows = list(results.rows)
 
         filter_columns = columns[:]
         if use_rowid and filter_columns[0] == "rowid":
@@ -641,7 +641,7 @@ class TableView(RowTableShared):
         filtered_table_rows_count = None
         if count_sql:
             try:
-                count_rows = list(await self.execute(
+                count_rows = list(await self.ds.execute(
                     name, count_sql, from_sql_params
                 ))
                 filtered_table_rows_count = count_rows[0][0]
@@ -665,7 +665,7 @@ class TableView(RowTableShared):
                 )
                 distinct_values = None
                 try:
-                    distinct_values = await self.execute(
+                    distinct_values = await self.ds.execute(
                         name, suggested_facet_sql, from_sql_params,
                         truncate=False,
                         custom_time_limit=self.ds.config["facet_suggest_time_limit_ms"],
@@ -701,7 +701,7 @@ class TableView(RowTableShared):
         display_columns, display_rows = await self.display_columns_and_rows(
             name,
             table,
-            description,
+            results.description,
             rows,
             link_column=not is_view,
             expand_foreign_keys=True,
@@ -755,7 +755,7 @@ class TableView(RowTableShared):
             "table_definition": table_definition,
             "human_description_en": human_description_en,
             "rows": rows[:page_size],
-            "truncated": truncated,
+            "truncated": results.truncated,
             "table_rows_count": table_rows_count,
             "filtered_table_rows_count": filtered_table_rows_count,
             "columns": columns,
@@ -790,12 +790,11 @@ class RowView(RowTableShared):
         params = {}
         for i, pk_value in enumerate(pk_values):
             params["p{}".format(i)] = pk_value
-        # rows, truncated, description = await self.execute(name, sql, params, truncate=True)
-        rows, truncated, description = await self.execute(
+        results = await self.ds.execute(
             name, sql, params, truncate=True
         )
-        columns = [r[0] for r in description]
-        rows = list(rows)
+        columns = [r[0] for r in results.description]
+        rows = list(results.rows)
 
         if not rows:
             raise NotFound("Record not found: {}".format(pk_values))
@@ -803,7 +802,7 @@ class RowView(RowTableShared):
         display_columns, display_rows = await self.display_columns_and_rows(
             name,
             table,
-            description,
+            results.description,
             rows,
             link_column=False,
             expand_foreign_keys=True,
@@ -874,7 +873,7 @@ class RowView(RowTableShared):
             ]
         )
         try:
-            rows = list(await self.execute(name, sql, {"id": pk_values[0]}))
+            rows = list(await self.ds.execute(name, sql, {"id": pk_values[0]}))
         except sqlite3.OperationalError:
             # Almost certainly hit the timeout
             return []
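Note on the new return type (an editor's illustration, not part of the patch): every call site above now goes through `self.ds.execute()` and receives the `Results` wrapper added to `datasette/utils.py`, reading `.description` for column names, `.rows` for the data and `.truncated` in place of the old `(rows, truncated, description)` tuple; because `__iter__` and `__len__` delegate to `.rows`, existing checks such as `len(facet_rows_results) > facet_size` keep working. A minimal sketch of that interface, using the `Results` class exactly as added above but with invented rows and an invented cursor description:

```python
# Results as added to datasette/utils.py in this patch; the sample rows and
# description below are invented purely to illustrate the interface.
class Results:
    def __init__(self, rows, truncated, description):
        self.rows = rows
        self.truncated = truncated
        self.description = description

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)


# Shaped like sqlite3's cursor.description: 7-tuples with only the column name set.
description = (
    ("id", None, None, None, None, None, None),
    ("title", None, None, None, None, None, None),
)
results = Results(rows=[(1, "hello"), (2, "world")], truncated=False, description=description)

columns = [r[0] for r in results.description]  # ["id", "title"], as in base.py / table.py above
rows = list(results.rows)                      # explicit .rows access, as in the views
assert len(results) == 2                       # __len__ delegates to rows (facet truncation check)
assert list(results) == rows                   # __iter__ delegates too
print(columns, rows, results.truncated)
```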