Table page partially works on PostgreSQL, refs #670

postgresql-prototype
Simon Willison 2020-02-13 12:43:06 -08:00
rodzic 32a2f5793a
commit b87130a036
4 zmienionych plików z 115 dodań i 64 usunięć

Wyświetl plik

@ -78,6 +78,12 @@ class Database:
"""Executes sql against db_name in a thread"""
page_size = page_size or self.ds.page_size
# Where are we?
import io, traceback
stored_stack = io.StringIO()
traceback.print_stack(file=stored_stack)
def sql_operation_in_thread(conn):
time_limit_ms = self.ds.sql_time_limit_ms
if custom_time_limit and custom_time_limit < time_limit_ms:
@ -114,10 +120,15 @@ class Database:
else:
return Results(rows, False, cursor.description)
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_against_connection_in_thread(
sql_operation_in_thread
)
try:
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_against_connection_in_thread(
sql_operation_in_thread
)
except Exception as e:
print(e)
print(stored_stack.getvalue())
raise
return results
@property

Wyświetl plik

@ -73,7 +73,7 @@ class Facet:
self,
ds,
request,
database,
db,
sql=None,
table=None,
params=None,
@ -83,7 +83,7 @@ class Facet:
assert table or sql, "Must provide either table= or sql="
self.ds = ds
self.request = request
self.database = database
self.db = db
# For foreign key expansion. Can be None for e.g. canned SQL queries:
self.table = table
self.sql = sql or "select * from [{}]".format(table)
@ -113,17 +113,16 @@ class Facet:
async def get_columns(self, sql, params=None):
# Detect column names using the "limit 0" trick
return (
await self.ds.execute(
self.database, "select * from ({}) limit 0".format(sql), params or []
await self.db.execute(
"select * from ({}) as derived limit 0".format(sql), params or []
)
).columns
async def get_row_count(self):
if self.row_count is None:
self.row_count = (
await self.ds.execute(
self.database,
"select count(*) from ({})".format(self.sql),
await self.db.execute(
"select count(*) from ({}) as derived".format(self.sql),
self.params,
)
).rows[0][0]
@ -153,8 +152,7 @@ class ColumnFacet(Facet):
)
distinct_values = None
try:
distinct_values = await self.ds.execute(
self.database,
distinct_values = await self.db.execute(
suggested_facet_sql,
self.params,
truncate=False,
@ -203,8 +201,7 @@ class ColumnFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database,
facet_rows_results = await self.db.execute(
facet_sql,
self.params,
truncate=False,
@ -225,8 +222,8 @@ class ColumnFacet(Facet):
if self.table:
# Attempt to expand foreign keys into labels
values = [row["value"] for row in facet_rows]
expanded = await self.ds.expand_foreign_keys(
self.database, self.table, column, values
expanded = await self.db.expand_foreign_keys(
self.table, column, values
)
else:
expanded = {}
@ -285,8 +282,7 @@ class ArrayFacet(Facet):
column=escape_sqlite(column), sql=self.sql
)
try:
results = await self.ds.execute(
self.database,
results = await self.db.execute(
suggested_facet_sql,
self.params,
truncate=False,
@ -298,8 +294,7 @@ class ArrayFacet(Facet):
# Now sanity check that first 100 arrays contain only strings
first_100 = [
v[0]
for v in await self.ds.execute(
self.database,
for v in await self.db.execute(
"select {column} from ({sql}) where {column} is not null and json_array_length({column}) > 0 limit 100".format(
column=escape_sqlite(column), sql=self.sql
),
@ -349,8 +344,7 @@ class ArrayFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database,
facet_rows_results = await self.db.execute(
facet_sql,
self.params,
truncate=False,
@ -416,8 +410,7 @@ class DateFacet(Facet):
column=escape_sqlite(column), sql=self.sql
)
try:
results = await self.ds.execute(
self.database,
results = await self.db.execute(
suggested_facet_sql,
self.params,
truncate=False,
@ -462,8 +455,7 @@ class DateFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database,
facet_rows_results = await self.db.execute(
facet_sql,
self.params,
truncate=False,

Wyświetl plik

@ -7,6 +7,10 @@ class PostgresqlResults:
self.rows = rows
self.truncated = truncated
@property
def description(self):
return [[c] for c in self.columns]
@property
def columns(self):
try:
@ -24,6 +28,8 @@ class PostgresqlResults:
class PostgresqlDatabase:
size = 0
is_mutable = False
is_memory = False
hash = None
def __init__(self, ds, name, dsn):
self.ds = ds
@ -65,7 +71,7 @@ class PostgresqlDatabase:
return counts
async def table_exists(self, table):
raise NotImplementedError
return table in await self.table_names()
async def table_names(self):
results = await self.execute(
@ -159,29 +165,41 @@ class PostgresqlDatabase:
return []
async def get_table_definition(self, table, type_="table"):
table_definition_rows = list(
await self.execute(
"select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_},
sql = """
SELECT
'CREATE TABLE ' || relname || E'\n(\n' ||
array_to_string(
array_agg(
' ' || column_name || ' ' || type || ' '|| not_null
)
)
if not table_definition_rows:
return None
bits = [table_definition_rows[0][0] + ";"]
# Add on any indexes
index_rows = list(
await self.ds.execute(
self.name,
"select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
{"n": table},
)
)
for index_row in index_rows:
bits.append(index_row[0] + ";")
return "\n".join(bits)
, E',\n'
) || E'\n);\n'
from
(
SELECT
c.relname, a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as type,
case
when a.attnotnull
then 'NOT NULL'
else 'NULL'
END as not_null
FROM pg_class c,
pg_attribute a,
pg_type t
WHERE c.relname = $1
AND a.attnum > 0
AND a.attrelid = c.oid
AND a.atttypid = t.oid
ORDER BY a.attnum
) as tabledefinition
group by relname;
"""
return await (await self.connection()).fetchval(sql, table)
async def get_view_definition(self, view):
return await self.get_table_definition(view, "view")
# return await self.get_table_definition(view, "view")
return []
def __repr__(self):
tags = []

Wyświetl plik

@ -5,6 +5,7 @@ import json
import jinja2
from datasette.plugins import pm
from datasette.postgresql_database import PostgresqlDatabase
from datasette.utils import (
CustomRow,
QueryInterrupted,
@ -64,7 +65,12 @@ class Row:
class RowTableShared(DataView):
async def sortable_columns_for_table(self, database, table, use_rowid):
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
table_metadata = self.ds.table_metadata(database, table)
if "sortable_columns" in table_metadata:
sortable_columns = set(table_metadata["sortable_columns"])
@ -77,7 +83,12 @@ class RowTableShared(DataView):
async def expandable_columns(self, database, table):
# Returns list of (fk_dict, label_column-or-None) pairs for that table
expandables = []
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
for fk in await db.foreign_keys_for_table(table):
label_column = await db.label_column_for_table(fk["other_table"])
expandables.append((fk, label_column))
@ -87,7 +98,12 @@ class RowTableShared(DataView):
self, database, table, description, rows, link_column=False, truncate_cells=0
):
"Returns columns, rows for specified table - including fancy foreign key treatment"
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
table_metadata = self.ds.table_metadata(database, table)
sortable_columns = await self.sortable_columns_for_table(database, table, True)
columns = [
@ -228,7 +244,15 @@ class TableView(RowTableShared):
editable=False,
canned_query=table,
)
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
print("Here we go, db = ", db)
is_view = bool(await db.get_view_definition(table))
table_exists = bool(await db.table_exists(table))
if not is_view and not table_exists:
@ -533,17 +557,13 @@ class TableView(RowTableShared):
if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
results = await self.ds.execute(
database, sql, params, truncate=True, **extra_args
)
results = await db.execute(sql, params, truncate=True, **extra_args)
# Number of filtered rows in whole set:
filtered_table_rows_count = None
if count_sql:
try:
count_rows = list(
await self.ds.execute(database, count_sql, from_sql_params)
)
count_rows = list(await db.execute(count_sql, from_sql_params))
filtered_table_rows_count = count_rows[0][0]
except QueryInterrupted:
pass
@ -566,7 +586,7 @@ class TableView(RowTableShared):
klass(
self.ds,
request,
database,
db,
sql=sql_no_limit,
params=params,
table=table,
@ -584,7 +604,7 @@ class TableView(RowTableShared):
facets_timed_out.extend(instance_facets_timed_out)
# Figure out columns and rows for the query
columns = [r[0] for r in results.description]
columns = list(results.rows[0].keys())
rows = list(results.rows)
# Expand labeled columns if requested
@ -781,7 +801,12 @@ class RowView(RowTableShared):
async def data(self, request, database, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
pks = await db.primary_keys(table)
use_rowid = not pks
select = "*"
@ -795,7 +820,7 @@ class RowView(RowTableShared):
params = {}
for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value
results = await self.ds.execute(database, sql, params, truncate=True)
results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
@ -860,7 +885,12 @@ class RowView(RowTableShared):
async def foreign_key_tables(self, database, table, pk_values):
if len(pk_values) != 1:
return []
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
all_foreign_keys = await db.get_all_foreign_keys()
foreign_keys = all_foreign_keys[table]["incoming"]
if len(foreign_keys) == 0:
@ -876,7 +906,7 @@ class RowView(RowTableShared):
]
)
try:
rows = list(await self.ds.execute(database, sql, {"id": pk_values[0]}))
rows = list(await db.execute(sql, {"id": pk_values[0]}))
except sqlite3.OperationalError:
# Almost certainly hit the timeout
return []