Mirrored from https://github.com/simonw/datasette
492 lines
16 KiB
Python
import asyncio
|
|
from collections import namedtuple
|
|
from pathlib import Path
|
|
import janus
|
|
import queue
|
|
import sys
|
|
import threading
|
|
import uuid
|
|
|
|
from .tracer import trace
|
|
from .utils import (
|
|
detect_fts,
|
|
detect_primary_keys,
|
|
detect_spatialite,
|
|
get_all_foreign_keys,
|
|
get_outbound_foreign_keys,
|
|
sqlite_timelimit,
|
|
sqlite3,
|
|
table_columns,
|
|
table_column_details,
|
|
)
|
|
from .inspect import inspect_hash
|
|
|
|
# Thread-local storage: execute_fn() lazily opens one reader connection per
# database in each executor thread and caches it here, stored as an attribute
# named after the database.
connections = threading.local()

# One row of "PRAGMA database_list" output describing a database that has been
# ATTACHed to a connection (seq 0 is always "main", so attached ones use seq > 0).
AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file"))
class Database:
    def __init__(
        self, ds, path=None, is_mutable=False, is_memory=False, memory_name=None
    ):
        """One SQLite database served by a Datasette instance.

        :param ds: the parent Datasette application object
        :param path: filesystem path to the SQLite file, or None for in-memory
        :param is_mutable: True if the file may change while being served
        :param is_memory: True for an in-memory database
        :param memory_name: name of a shared in-memory database; setting this
            forces both is_memory and is_mutable to True
        """
        # Starts unset; suggest_name() offers a candidate for the caller to use
        self.name = None
        self.ds = ds
        self.path = path
        self.is_mutable = is_mutable
        self.is_memory = is_memory
        self.memory_name = memory_name
        if memory_name is not None:
            # Named shared-memory databases are always in-memory and writable
            self.is_memory = True
            self.is_mutable = True
        self.hash = None
        self.cached_size = None
        self._cached_table_counts = None
        # Created lazily by execute_write_fn(): one writer thread plus queue
        self._write_thread = None
        self._write_queue = None
        if not self.is_mutable and not self.is_memory:
            # Immutable file: the content hash and size can be computed once
            # up front and cached forever
            p = Path(path)
            self.hash = inspect_hash(p)
            self.cached_size = p.stat().st_size
@property
|
|
def cached_table_counts(self):
|
|
if self._cached_table_counts is not None:
|
|
return self._cached_table_counts
|
|
# Maybe use self.ds.inspect_data to populate cached_table_counts
|
|
if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
|
|
self._cached_table_counts = {
|
|
key: value["count"]
|
|
for key, value in self.ds.inspect_data[self.name]["tables"].items()
|
|
}
|
|
return self._cached_table_counts
|
|
|
|
def suggest_name(self):
|
|
if self.path:
|
|
return Path(self.path).stem
|
|
elif self.memory_name:
|
|
return self.memory_name
|
|
else:
|
|
return "db"
|
|
|
|
def connect(self, write=False):
|
|
if self.memory_name:
|
|
uri = "file:{}?mode=memory&cache=shared".format(self.memory_name)
|
|
conn = sqlite3.connect(
|
|
uri,
|
|
uri=True,
|
|
check_same_thread=False,
|
|
)
|
|
if not write:
|
|
conn.execute("PRAGMA query_only=1")
|
|
return conn
|
|
if self.is_memory:
|
|
return sqlite3.connect(":memory:", uri=True)
|
|
# mode=ro or immutable=1?
|
|
if self.is_mutable:
|
|
qs = "?mode=ro"
|
|
else:
|
|
qs = "?immutable=1"
|
|
assert not (write and not self.is_mutable)
|
|
if write:
|
|
qs = ""
|
|
return sqlite3.connect(
|
|
f"file:{self.path}{qs}", uri=True, check_same_thread=False
|
|
)
|
|
|
|
async def execute_write(self, sql, params=None, block=True):
|
|
def _inner(conn):
|
|
with conn:
|
|
return conn.execute(sql, params or [])
|
|
|
|
with trace("sql", database=self.name, sql=sql.strip(), params=params):
|
|
results = await self.execute_write_fn(_inner, block=block)
|
|
return results
|
|
|
|
async def execute_write_script(self, sql, block=True):
|
|
def _inner(conn):
|
|
with conn:
|
|
return conn.executescript(sql)
|
|
|
|
with trace("sql", database=self.name, sql=sql.strip(), executescript=True):
|
|
results = await self.execute_write_fn(_inner, block=block)
|
|
return results
|
|
|
|
async def execute_write_many(self, sql, params_seq, block=True):
|
|
def _inner(conn):
|
|
count = 0
|
|
|
|
def count_params(params):
|
|
nonlocal count
|
|
for param in params:
|
|
count += 1
|
|
yield param
|
|
|
|
with conn:
|
|
return conn.executemany(sql, count_params(params_seq)), count
|
|
|
|
with trace(
|
|
"sql", database=self.name, sql=sql.strip(), executemany=True
|
|
) as kwargs:
|
|
results, count = await self.execute_write_fn(_inner, block=block)
|
|
kwargs["count"] = count
|
|
return results
|
|
|
|
async def execute_write_fn(self, fn, block=True):
|
|
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
|
|
if self._write_queue is None:
|
|
self._write_queue = queue.Queue()
|
|
if self._write_thread is None:
|
|
self._write_thread = threading.Thread(
|
|
target=self._execute_writes, daemon=True
|
|
)
|
|
self._write_thread.start()
|
|
reply_queue = janus.Queue()
|
|
self._write_queue.put(WriteTask(fn, task_id, reply_queue))
|
|
if block:
|
|
result = await reply_queue.async_q.get()
|
|
if isinstance(result, Exception):
|
|
raise result
|
|
else:
|
|
return result
|
|
else:
|
|
return task_id
|
|
|
|
    def _execute_writes(self):
        # Infinite looping thread that protects the single write connection
        # to this database
        conn_exception = None
        conn = None
        try:
            conn = self.connect(write=True)
            self.ds._prepare_connection(conn, self.name)
        except Exception as e:
            # Remember the connection failure: every queued task will receive
            # it as its result instead of running
            conn_exception = e
        while True:
            # Blocks until execute_write_fn() queues a WriteTask
            task = self._write_queue.get()
            if conn_exception is not None:
                result = conn_exception
            else:
                try:
                    result = task.fn(conn)
                except Exception as e:
                    # Log, then hand the exception back as the result;
                    # execute_write_fn re-raises it on the awaiting side
                    sys.stderr.write("{}\n".format(e))
                    sys.stderr.flush()
                    result = e
            task.reply_queue.sync_q.put(result)
async def execute_fn(self, fn):
|
|
def in_thread():
|
|
conn = getattr(connections, self.name, None)
|
|
if not conn:
|
|
conn = self.connect()
|
|
self.ds._prepare_connection(conn, self.name)
|
|
setattr(connections, self.name, conn)
|
|
return fn(conn)
|
|
|
|
return await asyncio.get_event_loop().run_in_executor(
|
|
self.ds.executor, in_thread
|
|
)
|
|
|
|
    async def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Executes sql against db_name in a thread"""
        page_size = page_size or self.ds.page_size

        def sql_operation_in_thread(conn):
            time_limit_ms = self.ds.sql_time_limit_ms
            if custom_time_limit and custom_time_limit < time_limit_ms:
                # Callers may only tighten the configured limit, never relax it
                time_limit_ms = custom_time_limit

            with sqlite_timelimit(conn, time_limit_ms):
                try:
                    cursor = conn.cursor()
                    cursor.execute(sql, params if params is not None else {})
                    max_returned_rows = self.ds.max_returned_rows
                    # Presumably allows one extra row so pagination can tell a
                    # full page from a truncated result - TODO confirm
                    if max_returned_rows == page_size:
                        max_returned_rows += 1
                    if max_returned_rows and truncate:
                        # Fetch one row beyond the cap to detect truncation,
                        # then trim back down to the cap
                        rows = cursor.fetchmany(max_returned_rows + 1)
                        truncated = len(rows) > max_returned_rows
                        rows = rows[:max_returned_rows]
                    else:
                        rows = cursor.fetchall()
                        truncated = False
                except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
                    if e.args == ("interrupted",):
                        # The time limit's progress handler interrupted the
                        # query - surface that as QueryInterrupted
                        raise QueryInterrupted(e, sql, params)
                    if log_sql_errors:
                        sys.stderr.write(
                            "ERROR: conn={}, sql = {}, params = {}: {}\n".format(
                                conn, repr(sql), params, e
                            )
                        )
                        sys.stderr.flush()
                    raise

            if truncate:
                return Results(rows, truncated, cursor.description)

            else:
                return Results(rows, False, cursor.description)

        with trace("sql", database=self.name, sql=sql.strip(), params=params):
            results = await self.execute_fn(sql_operation_in_thread)
        return results
@property
|
|
def size(self):
|
|
if self.is_memory:
|
|
return 0
|
|
if self.cached_size is not None:
|
|
return self.cached_size
|
|
else:
|
|
return Path(self.path).stat().st_size
|
|
|
|
    async def table_counts(self, limit=10):
        """Return {table_name: row_count}; count is None on timeout or error.

        limit is the per-table time budget passed as custom_time_limit.
        """
        if not self.is_mutable and self.cached_table_counts is not None:
            return self.cached_table_counts
        # Try to get counts for each table, $limit timeout for each count
        counts = {}
        for table in await self.table_names():
            try:
                table_count = (
                    await self.execute(
                        f"select count(*) from [{table}]",
                        custom_time_limit=limit,
                    )
                ).rows[0][0]
                counts[table] = table_count
            # In some cases I saw "SQL Logic Error" here in addition to
            # QueryInterrupted - so we catch that too:
            except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError):
                counts[table] = None
        if not self.is_mutable:
            # Immutable databases never change, so the counts are cached
            self._cached_table_counts = counts
        return counts
@property
|
|
def mtime_ns(self):
|
|
if self.is_memory:
|
|
return None
|
|
return Path(self.path).stat().st_mtime_ns
|
|
|
|
async def attached_databases(self):
|
|
# This used to be:
|
|
# select seq, name, file from pragma_database_list() where seq > 0
|
|
# But SQLite prior to 3.16.0 doesn't support pragma functions
|
|
results = await self.execute("PRAGMA database_list;")
|
|
# {'seq': 0, 'name': 'main', 'file': ''}
|
|
return [AttachedDatabase(*row) for row in results.rows if row["seq"] > 0]
|
|
|
|
async def table_exists(self, table):
|
|
results = await self.execute(
|
|
"select 1 from sqlite_master where type='table' and name=?", params=(table,)
|
|
)
|
|
return bool(results.rows)
|
|
|
|
async def table_names(self):
|
|
results = await self.execute(
|
|
"select name from sqlite_master where type='table'"
|
|
)
|
|
return [r[0] for r in results.rows]
|
|
|
|
async def table_columns(self, table):
|
|
return await self.execute_fn(lambda conn: table_columns(conn, table))
|
|
|
|
async def table_column_details(self, table):
|
|
return await self.execute_fn(lambda conn: table_column_details(conn, table))
|
|
|
|
async def primary_keys(self, table):
|
|
return await self.execute_fn(lambda conn: detect_primary_keys(conn, table))
|
|
|
|
async def fts_table(self, table):
|
|
return await self.execute_fn(lambda conn: detect_fts(conn, table))
|
|
|
|
async def label_column_for_table(self, table):
|
|
explicit_label_column = self.ds.table_metadata(self.name, table).get(
|
|
"label_column"
|
|
)
|
|
if explicit_label_column:
|
|
return explicit_label_column
|
|
column_names = await self.execute_fn(lambda conn: table_columns(conn, table))
|
|
# Is there a name or title column?
|
|
name_or_title = [c for c in column_names if c.lower() in ("name", "title")]
|
|
if name_or_title:
|
|
return name_or_title[0]
|
|
# If a table has two columns, one of which is ID, then label_column is the other one
|
|
if (
|
|
column_names
|
|
and len(column_names) == 2
|
|
and ("id" in column_names or "pk" in column_names)
|
|
):
|
|
return [c for c in column_names if c not in ("id", "pk")][0]
|
|
# Couldn't find a label:
|
|
return None
|
|
|
|
async def foreign_keys_for_table(self, table):
|
|
return await self.execute_fn(
|
|
lambda conn: get_outbound_foreign_keys(conn, table)
|
|
)
|
|
|
|
async def hidden_table_names(self):
|
|
# Mark tables 'hidden' if they relate to FTS virtual tables
|
|
hidden_tables = [
|
|
r[0]
|
|
for r in (
|
|
await self.execute(
|
|
"""
|
|
select name from sqlite_master
|
|
where rootpage = 0
|
|
and (
|
|
sql like '%VIRTUAL TABLE%USING FTS%'
|
|
) or name in ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4')
|
|
"""
|
|
)
|
|
).rows
|
|
]
|
|
has_spatialite = await self.execute_fn(detect_spatialite)
|
|
if has_spatialite:
|
|
# Also hide Spatialite internal tables
|
|
hidden_tables += [
|
|
"ElementaryGeometries",
|
|
"SpatialIndex",
|
|
"geometry_columns",
|
|
"spatial_ref_sys",
|
|
"spatialite_history",
|
|
"sql_statements_log",
|
|
"sqlite_sequence",
|
|
"views_geometry_columns",
|
|
"virts_geometry_columns",
|
|
"data_licenses",
|
|
"KNN",
|
|
"KNN2",
|
|
] + [
|
|
r[0]
|
|
for r in (
|
|
await self.execute(
|
|
"""
|
|
select name from sqlite_master
|
|
where name like "idx_%"
|
|
and type = "table"
|
|
"""
|
|
)
|
|
).rows
|
|
]
|
|
# Add any from metadata.json
|
|
db_metadata = self.ds.metadata(database=self.name)
|
|
if "tables" in db_metadata:
|
|
hidden_tables += [
|
|
t
|
|
for t in db_metadata["tables"]
|
|
if db_metadata["tables"][t].get("hidden")
|
|
]
|
|
# Also mark as hidden any tables which start with the name of a hidden table
|
|
# e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
|
|
for table_name in await self.table_names():
|
|
for hidden_table in hidden_tables[:]:
|
|
if table_name.startswith(hidden_table):
|
|
hidden_tables.append(table_name)
|
|
continue
|
|
|
|
return hidden_tables
|
|
|
|
async def view_names(self):
|
|
results = await self.execute("select name from sqlite_master where type='view'")
|
|
return [r[0] for r in results.rows]
|
|
|
|
    async def get_all_foreign_keys(self):
        """Foreign key information for every table (see utils.get_all_foreign_keys)."""
        return await self.execute_fn(get_all_foreign_keys)
async def get_table_definition(self, table, type_="table"):
|
|
table_definition_rows = list(
|
|
await self.execute(
|
|
"select sql from sqlite_master where name = :n and type=:t",
|
|
{"n": table, "t": type_},
|
|
)
|
|
)
|
|
if not table_definition_rows:
|
|
return None
|
|
bits = [table_definition_rows[0][0] + ";"]
|
|
# Add on any indexes
|
|
index_rows = list(
|
|
await self.execute(
|
|
"select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
|
|
{"n": table},
|
|
)
|
|
)
|
|
for index_row in index_rows:
|
|
bits.append(index_row[0] + ";")
|
|
return "\n".join(bits)
|
|
|
|
async def get_view_definition(self, view):
|
|
return await self.get_table_definition(view, "view")
|
|
|
|
def __repr__(self):
|
|
tags = []
|
|
if self.is_mutable:
|
|
tags.append("mutable")
|
|
if self.is_memory:
|
|
tags.append("memory")
|
|
if self.hash:
|
|
tags.append(f"hash={self.hash}")
|
|
if self.size is not None:
|
|
tags.append(f"size={self.size}")
|
|
tags_str = ""
|
|
if tags:
|
|
tags_str = f" ({', '.join(tags)})"
|
|
return f"<Database: {self.name}{tags_str}>"
|
|
|
|
|
|
class WriteTask:
    """One unit of work queued for the dedicated write thread.

    fn is called with the write connection by _execute_writes(); the outcome
    (return value, or the exception raised) is pushed onto reply_queue.
    """

    __slots__ = ("fn", "task_id", "reply_queue")

    def __init__(self, fn, task_id, reply_queue):
        self.fn = fn
        self.task_id = task_id
        self.reply_queue = reply_queue
class QueryInterrupted(Exception):
    """Raised when a query is cut off by the SQL time limit.

    Raised as QueryInterrupted(original_exception, sql, params) - see
    Database.execute().
    """

    pass
class MultipleValues(Exception):
    """Raised by Results.single_value() when the result is not exactly one value."""

    pass
class Results:
    """Rows fetched from a cursor, plus truncation state and column metadata."""

    def __init__(self, rows, truncated, description):
        self.rows = rows  # fetched rows
        self.truncated = truncated  # True if rows were cut off at the limit
        self.description = description  # cursor.description for the query

    @property
    def columns(self):
        """Column names, in cursor order."""
        return [col[0] for col in self.description]

    def first(self):
        """The first row, or None when there are no rows."""
        return self.rows[0] if self.rows else None

    def single_value(self):
        """The sole value; requires exactly one row with exactly one column."""
        if self.rows and len(self.rows) == 1 and len(self.rows[0]) == 1:
            return self.rows[0][0]
        raise MultipleValues

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)