Moved .execute() method from BaseView to Datasette class

Also introduced new Results() class with results.truncated, results.description and results.rows attributes.
Simon Willison 2018-05-24 17:15:37 -07:00
rodzic 28a52fcffb
commit 81df47e8d9
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 17E2DEA2588B7F52
4 zmienionych plików z 105 dodań i 90 usunięć

Wyświetl plik

@ -1,3 +1,4 @@
import asyncio
import collections
import hashlib
import itertools
@ -5,6 +6,7 @@ import json
import os
import sqlite3
import sys
import threading
import traceback
import urllib.parse
from concurrent import futures
@ -26,10 +28,13 @@ from .views.table import RowView, TableView
from . import hookspecs
from .utils import (
InterruptedError,
Results,
escape_css_string,
escape_sqlite,
get_plugins,
module_from_path,
sqlite_timelimit,
to_css_class
)
from .inspect import inspect_hash, inspect_views, inspect_tables
@ -37,6 +42,7 @@ from .version import __version__
app_root = Path(__file__).parent.parent
connections = threading.local()
pm = pluggy.PluginManager("datasette")
pm.add_hookspecs(hookspecs)
@ -285,6 +291,68 @@ class Datasette:
for p in get_plugins(pm)
]
async def execute(
    self,
    db_name,
    sql,
    params=None,
    truncate=False,
    custom_time_limit=None,
    page_size=None,
):
    """Executes sql against db_name in a thread"""
    page_size = page_size or self.page_size

    def in_thread():
        # One sqlite connection per (thread, database), cached on the
        # module-level threading.local so executor threads reuse them
        conn = getattr(connections, db_name, None)
        if conn is None:
            info = self.inspect()[db_name]
            # immutable=1 opens the database file read-only
            conn = sqlite3.connect(
                "file:{}?immutable=1".format(info["file"]),
                uri=True,
                check_same_thread=False,
            )
            self.prepare_connection(conn)
            setattr(connections, db_name, conn)
        # A custom limit may only tighten, never extend, the default
        time_limit_ms = self.sql_time_limit_ms
        if custom_time_limit and custom_time_limit < time_limit_ms:
            time_limit_ms = custom_time_limit
        with sqlite_timelimit(conn, time_limit_ms):
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params or {})
                max_returned_rows = self.max_returned_rows
                if max_returned_rows == page_size:
                    max_returned_rows += 1
                if truncate and max_returned_rows:
                    # Fetch one row beyond the cap so we can tell
                    # whether the result set was cut off
                    rows = cursor.fetchmany(max_returned_rows + 1)
                    truncated = len(rows) > max_returned_rows
                    rows = rows[:max_returned_rows]
                else:
                    rows = cursor.fetchall()
                    truncated = False
            except sqlite3.OperationalError as e:
                # An exceeded time limit surfaces as 'interrupted'
                if e.args == ('interrupted',):
                    raise InterruptedError(e)
                print(
                    "ERROR: conn={}, sql = {}, params = {}: {}".format(
                        conn, repr(sql), params, e
                    )
                )
                raise
        if truncate:
            return Results(rows, truncated, cursor.description)
        return Results(rows, False, cursor.description)

    return await asyncio.get_event_loop().run_in_executor(
        self.executor, in_thread
    )
def app(self):
app = Sanic(__name__)
default_templates = str(app_root / "datasette" / "templates")

Wyświetl plik

@ -36,6 +36,19 @@ class InterruptedError(Exception):
pass
class Results:
    """Container for the rows returned by a SQL query.

    Attributes:
        rows: the row tuples fetched from the cursor
        truncated: whether the result set was cut off (True only when the
            query was executed with truncate=True and exceeded the cap)
        description: the DB-API cursor.description for the executed query
    """

    def __init__(self, rows, truncated, description):
        self.rows = rows
        self.truncated = truncated
        self.description = description

    def __iter__(self):
        # Iterating a Results iterates its rows directly
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)

    def __repr__(self):
        # Added for debuggability: shows row count and truncation state
        return "<Results: {} row{}{}>".format(
            len(self.rows),
            "" if len(self.rows) == 1 else "s",
            ", truncated" if self.truncated else "",
        )
def urlsafe_components(token):
"Splits token on commas and URL decodes each component"
return [

Wyświetl plik

@ -2,7 +2,6 @@ import asyncio
import json
import re
import sqlite3
import threading
import time
import pint
@ -18,11 +17,9 @@ from datasette.utils import (
path_from_row_pks,
path_with_added_args,
path_with_ext,
sqlite_timelimit,
to_css_class
)
connections = threading.local()
ureg = pint.UnitRegistry()
HASH_LENGTH = 7
@ -128,68 +125,6 @@ class BaseView(RenderMixin):
return name, expected, None
async def execute(
    self,
    db_name,
    sql,
    params=None,
    truncate=False,
    custom_time_limit=None,
    page_size=None,
):
    """Executes sql against db_name in a thread"""
    # Fall back to the view's default page size when none is given
    page_size = page_size or self.page_size
    def sql_operation_in_thread():
        # One sqlite connection per (thread, database), cached on the
        # module-level threading.local so executor threads reuse them
        conn = getattr(connections, db_name, None)
        if not conn:
            info = self.ds.inspect()[db_name]
            # immutable=1 opens the database file read-only
            conn = sqlite3.connect(
                "file:{}?immutable=1".format(info["file"]),
                uri=True,
                check_same_thread=False,
            )
            self.ds.prepare_connection(conn)
            setattr(connections, db_name, conn)
        # A custom limit may only tighten, never extend, the default
        time_limit_ms = self.ds.sql_time_limit_ms
        if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
            time_limit_ms = custom_time_limit
        with sqlite_timelimit(conn, time_limit_ms):
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params or {})
                max_returned_rows = self.max_returned_rows
                if max_returned_rows == page_size:
                    max_returned_rows += 1
                if max_returned_rows and truncate:
                    # Fetch one row beyond the cap so we can tell
                    # whether the result set was cut off
                    rows = cursor.fetchmany(max_returned_rows + 1)
                    truncated = len(rows) > max_returned_rows
                    rows = rows[:max_returned_rows]
                else:
                    rows = cursor.fetchall()
                    truncated = False
            except sqlite3.OperationalError as e:
                # An exceeded time limit surfaces as 'interrupted'
                if e.args == ('interrupted',):
                    raise InterruptedError(e)
                print(
                    "ERROR: conn={}, sql = {}, params = {}: {}".format(
                        conn, repr(sql), params, e
                    )
                )
                raise
        # NOTE(review): the return shape is inconsistent by design —
        # a (rows, truncated, description) tuple when truncate=True,
        # bare rows otherwise. Callers must unpack accordingly.
        if truncate:
            return rows, truncated, cursor.description
        else:
            return rows
    return await asyncio.get_event_loop().run_in_executor(
        self.executor, sql_operation_in_thread
    )
def get_templates(self, database, table=None):
assert NotImplemented
@ -348,10 +283,10 @@ class BaseView(RenderMixin):
extra_args = {}
if params.get("_timelimit"):
extra_args["custom_time_limit"] = int(params["_timelimit"])
rows, truncated, description = await self.execute(
results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
)
columns = [r[0] for r in description]
columns = [r[0] for r in results.description]
templates = ["query-{}.html".format(to_css_class(name)), "query.html"]
if canned_query:
@ -364,8 +299,8 @@ class BaseView(RenderMixin):
return {
"database": name,
"rows": rows,
"truncated": truncated,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
}, {

Wyświetl plik

@ -73,7 +73,7 @@ class RowTableShared(BaseView):
placeholders=", ".join(["?"] * len(set(values))),
)
try:
results = await self.execute(
results = await self.ds.execute(
database, sql, list(set(values))
)
except InterruptedError:
@ -132,7 +132,7 @@ class RowTableShared(BaseView):
placeholders=", ".join(["?"] * len(ids_to_lookup)),
)
try:
results = await self.execute(
results = await self.ds.execute(
database, sql, list(set(ids_to_lookup))
)
except InterruptedError:
@ -246,7 +246,7 @@ class TableView(RowTableShared):
is_view = bool(
list(
await self.execute(
await self.ds.execute(
name,
"SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n",
{"n": table},
@ -257,7 +257,7 @@ class TableView(RowTableShared):
table_definition = None
if is_view:
view_definition = list(
await self.execute(
await self.ds.execute(
name,
'select sql from sqlite_master where name = :n and type="view"',
{"n": table},
@ -265,7 +265,7 @@ class TableView(RowTableShared):
)[0][0]
else:
table_definition_rows = list(
await self.execute(
await self.ds.execute(
name,
'select sql from sqlite_master where name = :n and type="table"',
{"n": table},
@ -534,7 +534,7 @@ class TableView(RowTableShared):
if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
rows, truncated, description = await self.execute(
results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
)
@ -560,7 +560,7 @@ class TableView(RowTableShared):
limit=facet_size+1,
)
try:
facet_rows = await self.execute(
facet_rows_results = await self.ds.execute(
name, facet_sql, params,
truncate=False,
custom_time_limit=self.ds.config["facet_time_limit_ms"],
@ -569,9 +569,9 @@ class TableView(RowTableShared):
facet_results[column] = {
"name": column,
"results": facet_results_values,
"truncated": len(facet_rows) > facet_size,
"truncated": len(facet_rows_results) > facet_size,
}
facet_rows = facet_rows[:facet_size]
facet_rows = facet_rows_results.rows[:facet_size]
# Attempt to expand foreign keys into labels
values = [row["value"] for row in facet_rows]
expanded = (await self.expand_foreign_keys(
@ -602,8 +602,8 @@ class TableView(RowTableShared):
except InterruptedError:
facets_timed_out.append(column)
columns = [r[0] for r in description]
rows = list(rows)
columns = [r[0] for r in results.description]
rows = list(results.rows)
filter_columns = columns[:]
if use_rowid and filter_columns[0] == "rowid":
@ -641,7 +641,7 @@ class TableView(RowTableShared):
filtered_table_rows_count = None
if count_sql:
try:
count_rows = list(await self.execute(
count_rows = list(await self.ds.execute(
name, count_sql, from_sql_params
))
filtered_table_rows_count = count_rows[0][0]
@ -665,7 +665,7 @@ class TableView(RowTableShared):
)
distinct_values = None
try:
distinct_values = await self.execute(
distinct_values = await self.ds.execute(
name, suggested_facet_sql, from_sql_params,
truncate=False,
custom_time_limit=self.ds.config["facet_suggest_time_limit_ms"],
@ -701,7 +701,7 @@ class TableView(RowTableShared):
display_columns, display_rows = await self.display_columns_and_rows(
name,
table,
description,
results.description,
rows,
link_column=not is_view,
expand_foreign_keys=True,
@ -755,7 +755,7 @@ class TableView(RowTableShared):
"table_definition": table_definition,
"human_description_en": human_description_en,
"rows": rows[:page_size],
"truncated": truncated,
"truncated": results.truncated,
"table_rows_count": table_rows_count,
"filtered_table_rows_count": filtered_table_rows_count,
"columns": columns,
@ -790,12 +790,11 @@ class RowView(RowTableShared):
params = {}
for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value
# rows, truncated, description = await self.execute(name, sql, params, truncate=True)
rows, truncated, description = await self.execute(
results = await self.ds.execute(
name, sql, params, truncate=True
)
columns = [r[0] for r in description]
rows = list(rows)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
raise NotFound("Record not found: {}".format(pk_values))
@ -803,7 +802,7 @@ class RowView(RowTableShared):
display_columns, display_rows = await self.display_columns_and_rows(
name,
table,
description,
results.description,
rows,
link_column=False,
expand_foreign_keys=True,
@ -874,7 +873,7 @@ class RowView(RowTableShared):
]
)
try:
rows = list(await self.execute(name, sql, {"id": pk_values[0]}))
rows = list(await self.ds.execute(name, sql, {"id": pk_values[0]}))
except sqlite3.OperationalError:
# Almost certainly hit the timeout
return []