Mirror of https://github.com/simonw/datasette

commit a9909c29cc
parent 8fc9a5d877

Move .execute() from Datasette to Database

Refs #569 - I split this change out from #579
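For context, the commit keeps Datasette.execute() as a thin wrapper that delegates to the per-database object, so existing callers keep working while new code can call Database.execute() directly. A minimal sketch of the two call styles follows; the "fixtures.db" file, the ds/count_tables names and the driver code are illustrative assumptions, not part of this commit.

# Illustrative sketch only: assumes a "fixtures.db" SQLite file on disk.
import asyncio

from datasette.app import Datasette

ds = Datasette(["fixtures.db"])


async def count_tables():
    # Old call style: Datasette.execute() still exists, but now simply
    # delegates to the Database object registered under that name.
    results = await ds.execute(
        "fixtures", "select count(*) from sqlite_master where type='table'"
    )
    # New call style: go straight to the Database object.
    db = ds.databases["fixtures"]
    results = await db.execute(
        "select count(*) from sqlite_master where type='table'"
    )
    return results.rows[0][0]


if __name__ == "__main__":
    print(asyncio.get_event_loop().run_until_complete(count_tables()))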
datasette/app.py

@@ -24,13 +24,11 @@ from .database import Database
 from .utils import (
     QueryInterrupted,
-    Results,
     escape_css_string,
     escape_sqlite,
     get_plugins,
     module_from_path,
     sqlite3,
-    sqlite_timelimit,
     to_css_class,
 )
 from .utils.asgi import (
@@ -42,13 +40,12 @@ from .utils.asgi import (
     asgi_send_json,
     asgi_send_redirect,
 )
-from .tracer import trace, AsgiTracer
+from .tracer import AsgiTracer
 from .plugins import pm, DEFAULT_PLUGINS
 from .version import __version__

 app_root = Path(__file__).parent.parent

-connections = threading.local()
 MEMORY = object()

 ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help"))
@@ -336,6 +333,25 @@ class Datasette:
         # pylint: disable=no-member
         pm.hook.prepare_connection(conn=conn)

+    async def execute(
+        self,
+        db_name,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        return await self.databases[db_name].execute(
+            sql,
+            params=params,
+            truncate=truncate,
+            custom_time_limit=custom_time_limit,
+            page_size=page_size,
+            log_sql_errors=log_sql_errors,
+        )
+
     async def expand_foreign_keys(self, database, table, column, values):
         "Returns dict mapping (column, value) -> label"
         labeled_fks = {}
@@ -477,72 +493,6 @@ class Datasette:
             .get(table, {})
         )

-    async def execute_against_connection_in_thread(self, db_name, fn):
-        def in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                conn = self.databases[db_name].connect()
-                self.prepare_connection(conn)
-                setattr(connections, db_name, conn)
-            return fn(conn)
-
-        return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread)
-
-    async def execute(
-        self,
-        db_name,
-        sql,
-        params=None,
-        truncate=False,
-        custom_time_limit=None,
-        page_size=None,
-        log_sql_errors=True,
-    ):
-        """Executes sql against db_name in a thread"""
-        page_size = page_size or self.page_size
-
-        def sql_operation_in_thread(conn):
-            time_limit_ms = self.sql_time_limit_ms
-            if custom_time_limit and custom_time_limit < time_limit_ms:
-                time_limit_ms = custom_time_limit
-
-            with sqlite_timelimit(conn, time_limit_ms):
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute(sql, params or {})
-                    max_returned_rows = self.max_returned_rows
-                    if max_returned_rows == page_size:
-                        max_returned_rows += 1
-                    if max_returned_rows and truncate:
-                        rows = cursor.fetchmany(max_returned_rows + 1)
-                        truncated = len(rows) > max_returned_rows
-                        rows = rows[:max_returned_rows]
-                    else:
-                        rows = cursor.fetchall()
-                        truncated = False
-                except sqlite3.OperationalError as e:
-                    if e.args == ("interrupted",):
-                        raise QueryInterrupted(e, sql, params)
-                    if log_sql_errors:
-                        print(
-                            "ERROR: conn={}, sql = {}, params = {}: {}".format(
-                                conn, repr(sql), params, e
-                            )
-                        )
-                    raise
-
-            if truncate:
-                return Results(rows, truncated, cursor.description)
-
-            else:
-                return Results(rows, False, cursor.description)
-
-        with trace("sql", database=db_name, sql=sql.strip(), params=params):
-            results = await self.execute_against_connection_in_thread(
-                db_name, sql_operation_in_thread
-            )
-        return results
-
     def register_renderers(self):
         """ Register output renderers which output data in custom formats. """
         # Built-in renderers
datasette/database.py

@@ -1,17 +1,25 @@
+import asyncio
+import contextlib
 from pathlib import Path
+import threading

+from .tracer import trace
 from .utils import (
     QueryInterrupted,
+    Results,
     detect_fts,
     detect_primary_keys,
     detect_spatialite,
     get_all_foreign_keys,
     get_outbound_foreign_keys,
+    sqlite_timelimit,
     sqlite3,
     table_columns,
 )
 from .inspect import inspect_hash

+connections = threading.local()
+

 class Database:
     def __init__(self, ds, path=None, is_mutable=False, is_memory=False):
@@ -45,6 +53,73 @@ class Database:
             "file:{}?{}".format(self.path, qs), uri=True, check_same_thread=False
         )

+    async def execute_against_connection_in_thread(self, fn):
+        def in_thread():
+            conn = getattr(connections, self.name, None)
+            if not conn:
+                conn = self.connect()
+                self.ds.prepare_connection(conn)
+                setattr(connections, self.name, conn)
+            return fn(conn)
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self.ds.executor, in_thread
+        )
+
+    async def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        """Executes sql against db_name in a thread"""
+        page_size = page_size or self.ds.page_size
+
+        def sql_operation_in_thread(conn):
+            time_limit_ms = self.ds.sql_time_limit_ms
+            if custom_time_limit and custom_time_limit < time_limit_ms:
+                time_limit_ms = custom_time_limit
+
+            with sqlite_timelimit(conn, time_limit_ms):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(sql, params or {})
+                    max_returned_rows = self.ds.max_returned_rows
+                    if max_returned_rows == page_size:
+                        max_returned_rows += 1
+                    if max_returned_rows and truncate:
+                        rows = cursor.fetchmany(max_returned_rows + 1)
+                        truncated = len(rows) > max_returned_rows
+                        rows = rows[:max_returned_rows]
+                    else:
+                        rows = cursor.fetchall()
+                        truncated = False
+                except sqlite3.OperationalError as e:
+                    if e.args == ("interrupted",):
+                        raise QueryInterrupted(e, sql, params)
+                    if log_sql_errors:
+                        print(
+                            "ERROR: conn={}, sql = {}, params = {}: {}".format(
+                                conn, repr(sql), params, e
+                            )
+                        )
+                    raise
+
+            if truncate:
+                return Results(rows, truncated, cursor.description)
+
+            else:
+                return Results(rows, False, cursor.description)
+
+        with trace("sql", database=self.name, sql=sql.strip(), params=params):
+            results = await self.execute_against_connection_in_thread(
+                sql_operation_in_thread
+            )
+        return results
+
     @property
     def size(self):
         if self.is_memory:
@@ -62,8 +137,7 @@ class Database:
         for table in await self.table_names():
             try:
                 table_count = (
-                    await self.ds.execute(
-                        self.name,
+                    await self.execute(
                         "select count(*) from [{}]".format(table),
                         custom_time_limit=limit,
                     )
@@ -89,32 +163,30 @@ class Database:
         return Path(self.path).stem

     async def table_exists(self, table):
-        results = await self.ds.execute(
-            self.name,
-            "select 1 from sqlite_master where type='table' and name=?",
-            params=(table,),
+        results = await self.execute(
+            "select 1 from sqlite_master where type='table' and name=?", params=(table,)
         )
         return bool(results.rows)

     async def table_names(self):
-        results = await self.ds.execute(
-            self.name, "select name from sqlite_master where type='table'"
+        results = await self.execute(
+            "select name from sqlite_master where type='table'"
         )
         return [r[0] for r in results.rows]

     async def table_columns(self, table):
-        return await self.ds.execute_against_connection_in_thread(
-            self.name, lambda conn: table_columns(conn, table)
+        return await self.execute_against_connection_in_thread(
+            lambda conn: table_columns(conn, table)
         )

     async def primary_keys(self, table):
-        return await self.ds.execute_against_connection_in_thread(
-            self.name, lambda conn: detect_primary_keys(conn, table)
+        return await self.execute_against_connection_in_thread(
+            lambda conn: detect_primary_keys(conn, table)
         )

     async def fts_table(self, table):
-        return await self.ds.execute_against_connection_in_thread(
-            self.name, lambda conn: detect_fts(conn, table)
+        return await self.execute_against_connection_in_thread(
+            lambda conn: detect_fts(conn, table)
         )

     async def label_column_for_table(self, table):
@@ -124,8 +196,8 @@ class Database:
         if explicit_label_column:
             return explicit_label_column
         # If a table has two columns, one of which is ID, then label_column is the other one
-        column_names = await self.ds.execute_against_connection_in_thread(
-            self.name, lambda conn: table_columns(conn, table)
+        column_names = await self.execute_against_connection_in_thread(
+            lambda conn: table_columns(conn, table)
         )
         # Is there a name or title column?
         name_or_title = [c for c in column_names if c in ("name", "title")]
@@ -141,8 +213,8 @@ class Database:
         return None

     async def foreign_keys_for_table(self, table):
-        return await self.ds.execute_against_connection_in_thread(
-            self.name, lambda conn: get_outbound_foreign_keys(conn, table)
+        return await self.execute_against_connection_in_thread(
+            lambda conn: get_outbound_foreign_keys(conn, table)
         )

     async def hidden_table_names(self):
@@ -150,18 +222,17 @@ class Database:
         hidden_tables = [
             r[0]
             for r in (
-                await self.ds.execute(
-                    self.name,
+                await self.execute(
                     """
                 select name from sqlite_master
                 where rootpage = 0
                 and sql like '%VIRTUAL TABLE%USING FTS%'
-            """,
+            """
                 )
             ).rows
         ]
-        has_spatialite = await self.ds.execute_against_connection_in_thread(
-            self.name, detect_spatialite
+        has_spatialite = await self.execute_against_connection_in_thread(
+            detect_spatialite
         )
         if has_spatialite:
             # Also hide Spatialite internal tables
@@ -178,13 +249,12 @@ class Database:
         ] + [
             r[0]
             for r in (
-                await self.ds.execute(
-                    self.name,
+                await self.execute(
                     """
                 select name from sqlite_master
                 where name like "idx_%"
                 and type = "table"
-            """,
+            """
                 )
             ).rows
         ]
@@ -207,25 +277,20 @@ class Database:
         return hidden_tables

     async def view_names(self):
-        results = await self.ds.execute(
-            self.name, "select name from sqlite_master where type='view'"
-        )
+        results = await self.execute("select name from sqlite_master where type='view'")
         return [r[0] for r in results.rows]

     async def get_all_foreign_keys(self):
-        return await self.ds.execute_against_connection_in_thread(
-            self.name, get_all_foreign_keys
-        )
+        return await self.execute_against_connection_in_thread(get_all_foreign_keys)

     async def get_outbound_foreign_keys(self, table):
-        return await self.ds.execute_against_connection_in_thread(
-            self.name, lambda conn: get_outbound_foreign_keys(conn, table)
+        return await self.execute_against_connection_in_thread(
+            lambda conn: get_outbound_foreign_keys(conn, table)
         )

     async def get_table_definition(self, table, type_="table"):
         table_definition_rows = list(
-            await self.ds.execute(
-                self.name,
+            await self.execute(
                 "select sql from sqlite_master where name = :n and type=:t",
                 {"n": table, "t": type_},
             )
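The per-thread connection handling that moves into Database.execute_against_connection_in_thread() above is a common asyncio + sqlite3 pattern: one connection per worker thread, cached on a threading.local(), with the blocking call pushed onto an executor. A standalone sketch of that pattern, using plain sqlite3 rather than the Datasette classes (the run_in_connection_thread/demo names and the in-memory database are illustrative assumptions):

# Standalone sketch, not code from this commit.
import asyncio
import sqlite3
import threading
from concurrent import futures

connections = threading.local()
executor = futures.ThreadPoolExecutor(max_workers=3)


async def run_in_connection_thread(db_name, fn):
    def in_thread():
        # Re-use this thread's connection if it already has one
        conn = getattr(connections, db_name, None)
        if not conn:
            conn = sqlite3.connect(":memory:", check_same_thread=False)
            setattr(connections, db_name, conn)
        return fn(conn)

    # Hand the blocking SQLite work to the thread pool so the event loop stays free
    return await asyncio.get_event_loop().run_in_executor(executor, in_thread)


async def demo():
    return await run_in_connection_thread(
        "demo_db", lambda conn: conn.execute("select sqlite_version()").fetchall()
    )


if __name__ == "__main__":
    print(asyncio.get_event_loop().run_until_complete(demo()))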