Refactored ConnectedDatabase to datasette/database.py

Closes #487
pull/497/head
Simon Willison 2019-05-26 22:07:27 -07:00
rodzic 026c84db30
commit 6569287d90
3 zmienionych plików z 228 dodań i 223 usunięć

Wyświetl plik

@ -1,5 +1,4 @@
import asyncio
import click
import collections
import hashlib
import json
@ -12,6 +11,7 @@ import urllib.parse
from concurrent import futures
from pathlib import Path
import click
from markupsafe import Markup
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
@ -23,23 +23,19 @@ from .views.index import IndexView
from .views.special import JsonDataView
from .views.table import RowView, TableView
from .renderer import json_renderer
from .database import Database
from .utils import (
InterruptedError,
Results,
detect_spatialite,
escape_css_string,
escape_sqlite,
get_all_foreign_keys,
get_outbound_foreign_keys,
get_plugins,
module_from_path,
sqlite3,
sqlite_timelimit,
table_columns,
to_css_class,
)
from .inspect import inspect_hash, inspect_views, inspect_tables
from .tracer import capture_traces, trace
from .plugins import pm, DEFAULT_PLUGINS
from .version import __version__
@ -134,219 +130,6 @@ async def favicon(request):
return response.text("")
class ConnectedDatabase:
def __init__(self, ds, path=None, is_mutable=False, is_memory=False):
self.ds = ds
self.path = path
self.is_mutable = is_mutable
self.is_memory = is_memory
self.hash = None
self.cached_size = None
self.cached_table_counts = None
if not self.is_mutable:
p = Path(path)
self.hash = inspect_hash(p)
self.cached_size = p.stat().st_size
# Maybe use self.ds.inspect_data to populate cached_table_counts
if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
self.cached_table_counts = {
key: value["count"]
for key, value in self.ds.inspect_data[self.name]["tables"].items()
}
@property
def size(self):
if self.is_memory:
return 0
if self.cached_size is not None:
return self.cached_size
else:
return Path(self.path).stat().st_size
async def table_counts(self, limit=10):
if not self.is_mutable and self.cached_table_counts is not None:
return self.cached_table_counts
# Try to get counts for each table, $limit timeout for each count
counts = {}
for table in await self.table_names():
try:
table_count = (
await self.ds.execute(
self.name,
"select count(*) from [{}]".format(table),
custom_time_limit=limit,
)
).rows[0][0]
counts[table] = table_count
# In some cases I saw "SQL Logic Error" here in addition to
# InterruptedError - so we catch that too:
except (InterruptedError, sqlite3.OperationalError):
counts[table] = None
if not self.is_mutable:
self.cached_table_counts = counts
return counts
@property
def mtime_ns(self):
return Path(self.path).stat().st_mtime_ns
@property
def name(self):
if self.is_memory:
return ":memory:"
else:
return Path(self.path).stem
async def table_exists(self, table):
results = await self.ds.execute(
self.name,
"select 1 from sqlite_master where type='table' and name=?",
params=(table,),
)
return bool(results.rows)
async def table_names(self):
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='table'"
)
return [r[0] for r in results.rows]
async def table_columns(self, table):
return await self.ds.execute_against_connection_in_thread(
self.name, lambda conn: table_columns(conn, table)
)
async def label_column_for_table(self, table):
explicit_label_column = self.ds.table_metadata(self.name, table).get(
"label_column"
)
if explicit_label_column:
return explicit_label_column
# If a table has two columns, one of which is ID, then label_column is the other one
column_names = await self.ds.execute_against_connection_in_thread(
self.name, lambda conn: table_columns(conn, table)
)
# Is there a name or title column?
name_or_title = [c for c in column_names if c in ("name", "title")]
if name_or_title:
return name_or_title[0]
if (
column_names
and len(column_names) == 2
and ("id" in column_names or "pk" in column_names)
):
return [c for c in column_names if c not in ("id", "pk")][0]
# Couldn't find a label:
return None
async def foreign_keys_for_table(self, table):
return await self.ds.execute_against_connection_in_thread(
self.name, lambda conn: get_outbound_foreign_keys(conn, table)
)
async def hidden_table_names(self):
# Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [
r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%'
""",
)
).rows
]
has_spatialite = await self.ds.execute_against_connection_in_thread(
self.name, detect_spatialite
)
if has_spatialite:
# Also hide Spatialite internal tables
hidden_tables += [
"ElementaryGeometries",
"SpatialIndex",
"geometry_columns",
"spatial_ref_sys",
"spatialite_history",
"sql_statements_log",
"sqlite_sequence",
"views_geometry_columns",
"virts_geometry_columns",
] + [
r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where name like "idx_%"
and type = "table"
""",
)
).rows
]
# Add any from metadata.json
db_metadata = self.ds.metadata(database=self.name)
if "tables" in db_metadata:
hidden_tables += [
t
for t in db_metadata["tables"]
if db_metadata["tables"][t].get("hidden")
]
# Also mark as hidden any tables which start with the name of a hidden table
# e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
for table_name in await self.table_names():
for hidden_table in hidden_tables[:]:
if table_name.startswith(hidden_table):
hidden_tables.append(table_name)
continue
return hidden_tables
async def view_names(self):
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='view'"
)
return [r[0] for r in results.rows]
async def get_all_foreign_keys(self):
return await self.ds.execute_against_connection_in_thread(
self.name, get_all_foreign_keys
)
async def get_table_definition(self, table, type_="table"):
table_definition_rows = list(
await self.ds.execute(
self.name,
"select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_},
)
)
if not table_definition_rows:
return None
return table_definition_rows[0][0]
async def get_view_definition(self, view):
return await self.get_table_definition(view, "view")
def __repr__(self):
tags = []
if self.is_mutable:
tags.append("mutable")
if self.is_memory:
tags.append("memory")
if self.hash:
tags.append("hash={}".format(self.hash))
if self.size is not None:
tags.append("size={}".format(self.size))
tags_str = ""
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<ConnectedDatabase: {}{}>".format(self.name, tags_str)
class Datasette:
def __init__(
self,
@ -380,9 +163,7 @@ class Datasette:
path = None
is_memory = True
is_mutable = path not in self.immutables
db = ConnectedDatabase(
self, path, is_mutable=is_mutable, is_memory=is_memory
)
db = Database(self, path, is_mutable=is_mutable, is_memory=is_memory)
if db.name in self.databases:
raise Exception("Multiple files with same stem: {}".format(db.name))
self.databases[db.name] = db
@ -751,6 +532,7 @@ class Datasette:
# Hooks
hook_renderers = []
# pylint: disable=no-member
for hook in pm.hook.register_output_renderer(datasette=self):
if type(hook) == list:
hook_renderers += hook

Wyświetl plik

@ -0,0 +1,224 @@
from pathlib import Path
from .utils import (
InterruptedError,
detect_spatialite,
get_all_foreign_keys,
get_outbound_foreign_keys,
sqlite3,
table_columns,
)
from .inspect import inspect_hash
class Database:
def __init__(self, ds, path=None, is_mutable=False, is_memory=False):
self.ds = ds
self.path = path
self.is_mutable = is_mutable
self.is_memory = is_memory
self.hash = None
self.cached_size = None
self.cached_table_counts = None
if not self.is_mutable:
p = Path(path)
self.hash = inspect_hash(p)
self.cached_size = p.stat().st_size
# Maybe use self.ds.inspect_data to populate cached_table_counts
if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
self.cached_table_counts = {
key: value["count"]
for key, value in self.ds.inspect_data[self.name]["tables"].items()
}
@property
def size(self):
if self.is_memory:
return 0
if self.cached_size is not None:
return self.cached_size
else:
return Path(self.path).stat().st_size
async def table_counts(self, limit=10):
if not self.is_mutable and self.cached_table_counts is not None:
return self.cached_table_counts
# Try to get counts for each table, $limit timeout for each count
counts = {}
for table in await self.table_names():
try:
table_count = (
await self.ds.execute(
self.name,
"select count(*) from [{}]".format(table),
custom_time_limit=limit,
)
).rows[0][0]
counts[table] = table_count
# In some cases I saw "SQL Logic Error" here in addition to
# InterruptedError - so we catch that too:
except (InterruptedError, sqlite3.OperationalError):
counts[table] = None
if not self.is_mutable:
self.cached_table_counts = counts
return counts
@property
def mtime_ns(self):
return Path(self.path).stat().st_mtime_ns
@property
def name(self):
if self.is_memory:
return ":memory:"
else:
return Path(self.path).stem
async def table_exists(self, table):
results = await self.ds.execute(
self.name,
"select 1 from sqlite_master where type='table' and name=?",
params=(table,),
)
return bool(results.rows)
async def table_names(self):
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='table'"
)
return [r[0] for r in results.rows]
async def table_columns(self, table):
return await self.ds.execute_against_connection_in_thread(
self.name, lambda conn: table_columns(conn, table)
)
async def label_column_for_table(self, table):
explicit_label_column = self.ds.table_metadata(self.name, table).get(
"label_column"
)
if explicit_label_column:
return explicit_label_column
# If a table has two columns, one of which is ID, then label_column is the other one
column_names = await self.ds.execute_against_connection_in_thread(
self.name, lambda conn: table_columns(conn, table)
)
# Is there a name or title column?
name_or_title = [c for c in column_names if c in ("name", "title")]
if name_or_title:
return name_or_title[0]
if (
column_names
and len(column_names) == 2
and ("id" in column_names or "pk" in column_names)
):
return [c for c in column_names if c not in ("id", "pk")][0]
# Couldn't find a label:
return None
async def foreign_keys_for_table(self, table):
return await self.ds.execute_against_connection_in_thread(
self.name, lambda conn: get_outbound_foreign_keys(conn, table)
)
async def hidden_table_names(self):
# Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [
r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%'
""",
)
).rows
]
has_spatialite = await self.ds.execute_against_connection_in_thread(
self.name, detect_spatialite
)
if has_spatialite:
# Also hide Spatialite internal tables
hidden_tables += [
"ElementaryGeometries",
"SpatialIndex",
"geometry_columns",
"spatial_ref_sys",
"spatialite_history",
"sql_statements_log",
"sqlite_sequence",
"views_geometry_columns",
"virts_geometry_columns",
] + [
r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where name like "idx_%"
and type = "table"
""",
)
).rows
]
# Add any from metadata.json
db_metadata = self.ds.metadata(database=self.name)
if "tables" in db_metadata:
hidden_tables += [
t
for t in db_metadata["tables"]
if db_metadata["tables"][t].get("hidden")
]
# Also mark as hidden any tables which start with the name of a hidden table
# e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
for table_name in await self.table_names():
for hidden_table in hidden_tables[:]:
if table_name.startswith(hidden_table):
hidden_tables.append(table_name)
continue
return hidden_tables
async def view_names(self):
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='view'"
)
return [r[0] for r in results.rows]
async def get_all_foreign_keys(self):
return await self.ds.execute_against_connection_in_thread(
self.name, get_all_foreign_keys
)
async def get_table_definition(self, table, type_="table"):
table_definition_rows = list(
await self.ds.execute(
self.name,
"select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_},
)
)
if not table_definition_rows:
return None
return table_definition_rows[0][0]
async def get_view_definition(self, view):
return await self.get_table_definition(view, "view")
def __repr__(self):
tags = []
if self.is_mutable:
tags.append("mutable")
if self.is_memory:
tags.append("memory")
if self.hash:
tags.append("hash={}".format(self.hash))
if self.size is not None:
tags.append("size={}".format(self.size))
tags_str = ""
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<Database: {}{}>".format(self.name, tags_str)

Wyświetl plik

@ -5,7 +5,6 @@ from sanic import response
from datasette.utils import (
detect_fts,
detect_primary_keys,
get_all_foreign_keys,
to_css_class,
validate_sql_select,
)