Move .execute() from Datasette to Database

Refs #569 - I split this change out from #579
pull/546/head^2
Simon Willison 2019-11-15 14:49:45 -08:00
rodzic 8fc9a5d877
commit a9909c29cc
2 zmienionych plików z 121 dodań i 106 usunięć

Wyświetl plik

@ -24,13 +24,11 @@ from .database import Database
from .utils import ( from .utils import (
QueryInterrupted, QueryInterrupted,
Results,
escape_css_string, escape_css_string,
escape_sqlite, escape_sqlite,
get_plugins, get_plugins,
module_from_path, module_from_path,
sqlite3, sqlite3,
sqlite_timelimit,
to_css_class, to_css_class,
) )
from .utils.asgi import ( from .utils.asgi import (
@ -42,13 +40,12 @@ from .utils.asgi import (
asgi_send_json, asgi_send_json,
asgi_send_redirect, asgi_send_redirect,
) )
from .tracer import trace, AsgiTracer from .tracer import AsgiTracer
from .plugins import pm, DEFAULT_PLUGINS from .plugins import pm, DEFAULT_PLUGINS
from .version import __version__ from .version import __version__
app_root = Path(__file__).parent.parent app_root = Path(__file__).parent.parent
connections = threading.local()
MEMORY = object() MEMORY = object()
ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help")) ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help"))
@ -336,6 +333,25 @@ class Datasette:
# pylint: disable=no-member # pylint: disable=no-member
pm.hook.prepare_connection(conn=conn) pm.hook.prepare_connection(conn=conn)
async def execute(
self,
db_name,
sql,
params=None,
truncate=False,
custom_time_limit=None,
page_size=None,
log_sql_errors=True,
):
return await self.databases[db_name].execute(
sql,
params=params,
truncate=truncate,
custom_time_limit=custom_time_limit,
page_size=page_size,
log_sql_errors=log_sql_errors,
)
async def expand_foreign_keys(self, database, table, column, values): async def expand_foreign_keys(self, database, table, column, values):
"Returns dict mapping (column, value) -> label" "Returns dict mapping (column, value) -> label"
labeled_fks = {} labeled_fks = {}
@ -477,72 +493,6 @@ class Datasette:
.get(table, {}) .get(table, {})
) )
async def execute_against_connection_in_thread(self, db_name, fn):
def in_thread():
conn = getattr(connections, db_name, None)
if not conn:
conn = self.databases[db_name].connect()
self.prepare_connection(conn)
setattr(connections, db_name, conn)
return fn(conn)
return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread)
async def execute(
self,
db_name,
sql,
params=None,
truncate=False,
custom_time_limit=None,
page_size=None,
log_sql_errors=True,
):
"""Executes sql against db_name in a thread"""
page_size = page_size or self.page_size
def sql_operation_in_thread(conn):
time_limit_ms = self.sql_time_limit_ms
if custom_time_limit and custom_time_limit < time_limit_ms:
time_limit_ms = custom_time_limit
with sqlite_timelimit(conn, time_limit_ms):
try:
cursor = conn.cursor()
cursor.execute(sql, params or {})
max_returned_rows = self.max_returned_rows
if max_returned_rows == page_size:
max_returned_rows += 1
if max_returned_rows and truncate:
rows = cursor.fetchmany(max_returned_rows + 1)
truncated = len(rows) > max_returned_rows
rows = rows[:max_returned_rows]
else:
rows = cursor.fetchall()
truncated = False
except sqlite3.OperationalError as e:
if e.args == ("interrupted",):
raise QueryInterrupted(e, sql, params)
if log_sql_errors:
print(
"ERROR: conn={}, sql = {}, params = {}: {}".format(
conn, repr(sql), params, e
)
)
raise
if truncate:
return Results(rows, truncated, cursor.description)
else:
return Results(rows, False, cursor.description)
with trace("sql", database=db_name, sql=sql.strip(), params=params):
results = await self.execute_against_connection_in_thread(
db_name, sql_operation_in_thread
)
return results
def register_renderers(self): def register_renderers(self):
""" Register output renderers which output data in custom formats. """ """ Register output renderers which output data in custom formats. """
# Built-in renderers # Built-in renderers

Wyświetl plik

@ -1,17 +1,25 @@
import asyncio
import contextlib
from pathlib import Path from pathlib import Path
import threading
from .tracer import trace
from .utils import ( from .utils import (
QueryInterrupted, QueryInterrupted,
Results,
detect_fts, detect_fts,
detect_primary_keys, detect_primary_keys,
detect_spatialite, detect_spatialite,
get_all_foreign_keys, get_all_foreign_keys,
get_outbound_foreign_keys, get_outbound_foreign_keys,
sqlite_timelimit,
sqlite3, sqlite3,
table_columns, table_columns,
) )
from .inspect import inspect_hash from .inspect import inspect_hash
connections = threading.local()
class Database: class Database:
def __init__(self, ds, path=None, is_mutable=False, is_memory=False): def __init__(self, ds, path=None, is_mutable=False, is_memory=False):
@ -45,6 +53,73 @@ class Database:
"file:{}?{}".format(self.path, qs), uri=True, check_same_thread=False "file:{}?{}".format(self.path, qs), uri=True, check_same_thread=False
) )
async def execute_against_connection_in_thread(self, fn):
def in_thread():
conn = getattr(connections, self.name, None)
if not conn:
conn = self.connect()
self.ds.prepare_connection(conn)
setattr(connections, self.name, conn)
return fn(conn)
return await asyncio.get_event_loop().run_in_executor(
self.ds.executor, in_thread
)
async def execute(
self,
sql,
params=None,
truncate=False,
custom_time_limit=None,
page_size=None,
log_sql_errors=True,
):
"""Executes sql against db_name in a thread"""
page_size = page_size or self.ds.page_size
def sql_operation_in_thread(conn):
time_limit_ms = self.ds.sql_time_limit_ms
if custom_time_limit and custom_time_limit < time_limit_ms:
time_limit_ms = custom_time_limit
with sqlite_timelimit(conn, time_limit_ms):
try:
cursor = conn.cursor()
cursor.execute(sql, params or {})
max_returned_rows = self.ds.max_returned_rows
if max_returned_rows == page_size:
max_returned_rows += 1
if max_returned_rows and truncate:
rows = cursor.fetchmany(max_returned_rows + 1)
truncated = len(rows) > max_returned_rows
rows = rows[:max_returned_rows]
else:
rows = cursor.fetchall()
truncated = False
except sqlite3.OperationalError as e:
if e.args == ("interrupted",):
raise QueryInterrupted(e, sql, params)
if log_sql_errors:
print(
"ERROR: conn={}, sql = {}, params = {}: {}".format(
conn, repr(sql), params, e
)
)
raise
if truncate:
return Results(rows, truncated, cursor.description)
else:
return Results(rows, False, cursor.description)
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_against_connection_in_thread(
sql_operation_in_thread
)
return results
@property @property
def size(self): def size(self):
if self.is_memory: if self.is_memory:
@ -62,8 +137,7 @@ class Database:
for table in await self.table_names(): for table in await self.table_names():
try: try:
table_count = ( table_count = (
await self.ds.execute( await self.execute(
self.name,
"select count(*) from [{}]".format(table), "select count(*) from [{}]".format(table),
custom_time_limit=limit, custom_time_limit=limit,
) )
@ -89,32 +163,30 @@ class Database:
return Path(self.path).stem return Path(self.path).stem
async def table_exists(self, table): async def table_exists(self, table):
results = await self.ds.execute( results = await self.execute(
self.name, "select 1 from sqlite_master where type='table' and name=?", params=(table,)
"select 1 from sqlite_master where type='table' and name=?",
params=(table,),
) )
return bool(results.rows) return bool(results.rows)
async def table_names(self): async def table_names(self):
results = await self.ds.execute( results = await self.execute(
self.name, "select name from sqlite_master where type='table'" "select name from sqlite_master where type='table'"
) )
return [r[0] for r in results.rows] return [r[0] for r in results.rows]
async def table_columns(self, table): async def table_columns(self, table):
return await self.ds.execute_against_connection_in_thread( return await self.execute_against_connection_in_thread(
self.name, lambda conn: table_columns(conn, table) lambda conn: table_columns(conn, table)
) )
async def primary_keys(self, table): async def primary_keys(self, table):
return await self.ds.execute_against_connection_in_thread( return await self.execute_against_connection_in_thread(
self.name, lambda conn: detect_primary_keys(conn, table) lambda conn: detect_primary_keys(conn, table)
) )
async def fts_table(self, table): async def fts_table(self, table):
return await self.ds.execute_against_connection_in_thread( return await self.execute_against_connection_in_thread(
self.name, lambda conn: detect_fts(conn, table) lambda conn: detect_fts(conn, table)
) )
async def label_column_for_table(self, table): async def label_column_for_table(self, table):
@ -124,8 +196,8 @@ class Database:
if explicit_label_column: if explicit_label_column:
return explicit_label_column return explicit_label_column
# If a table has two columns, one of which is ID, then label_column is the other one # If a table has two columns, one of which is ID, then label_column is the other one
column_names = await self.ds.execute_against_connection_in_thread( column_names = await self.execute_against_connection_in_thread(
self.name, lambda conn: table_columns(conn, table) lambda conn: table_columns(conn, table)
) )
# Is there a name or title column? # Is there a name or title column?
name_or_title = [c for c in column_names if c in ("name", "title")] name_or_title = [c for c in column_names if c in ("name", "title")]
@ -141,8 +213,8 @@ class Database:
return None return None
async def foreign_keys_for_table(self, table): async def foreign_keys_for_table(self, table):
return await self.ds.execute_against_connection_in_thread( return await self.execute_against_connection_in_thread(
self.name, lambda conn: get_outbound_foreign_keys(conn, table) lambda conn: get_outbound_foreign_keys(conn, table)
) )
async def hidden_table_names(self): async def hidden_table_names(self):
@ -150,18 +222,17 @@ class Database:
hidden_tables = [ hidden_tables = [
r[0] r[0]
for r in ( for r in (
await self.ds.execute( await self.execute(
self.name,
""" """
select name from sqlite_master select name from sqlite_master
where rootpage = 0 where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%' and sql like '%VIRTUAL TABLE%USING FTS%'
""", """
) )
).rows ).rows
] ]
has_spatialite = await self.ds.execute_against_connection_in_thread( has_spatialite = await self.execute_against_connection_in_thread(
self.name, detect_spatialite detect_spatialite
) )
if has_spatialite: if has_spatialite:
# Also hide Spatialite internal tables # Also hide Spatialite internal tables
@ -178,13 +249,12 @@ class Database:
] + [ ] + [
r[0] r[0]
for r in ( for r in (
await self.ds.execute( await self.execute(
self.name,
""" """
select name from sqlite_master select name from sqlite_master
where name like "idx_%" where name like "idx_%"
and type = "table" and type = "table"
""", """
) )
).rows ).rows
] ]
@ -207,25 +277,20 @@ class Database:
return hidden_tables return hidden_tables
async def view_names(self): async def view_names(self):
results = await self.ds.execute( results = await self.execute("select name from sqlite_master where type='view'")
self.name, "select name from sqlite_master where type='view'"
)
return [r[0] for r in results.rows] return [r[0] for r in results.rows]
async def get_all_foreign_keys(self): async def get_all_foreign_keys(self):
return await self.ds.execute_against_connection_in_thread( return await self.execute_against_connection_in_thread(get_all_foreign_keys)
self.name, get_all_foreign_keys
)
async def get_outbound_foreign_keys(self, table): async def get_outbound_foreign_keys(self, table):
return await self.ds.execute_against_connection_in_thread( return await self.execute_against_connection_in_thread(
self.name, lambda conn: get_outbound_foreign_keys(conn, table) lambda conn: get_outbound_foreign_keys(conn, table)
) )
async def get_table_definition(self, table, type_="table"): async def get_table_definition(self, table, type_="table"):
table_definition_rows = list( table_definition_rows = list(
await self.ds.execute( await self.execute(
self.name,
"select sql from sqlite_master where name = :n and type=:t", "select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_}, {"n": table, "t": type_},
) )