kopia lustrzana https://github.com/simonw/datasette
Backported experimental #2058 fix to 0.64.x
rodzic
2a0a94fe97
commit
42bf9e2aab
|
@ -376,23 +376,50 @@ class Datasette:
|
||||||
await init_internal_db(internal_db)
|
await init_internal_db(internal_db)
|
||||||
self.internal_db_created = True
|
self.internal_db_created = True
|
||||||
|
|
||||||
current_schema_versions = {
|
current_schema_versions_and_hashes = {
|
||||||
row["database_name"]: row["schema_version"]
|
row["database_name"]: (row["schema_version"], row["schema_hash"])
|
||||||
for row in await internal_db.execute(
|
for row in await internal_db.execute(
|
||||||
"select database_name, schema_version from databases"
|
"select database_name, schema_version, schema_hash from databases"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
for database_name, db in self.databases.items():
|
for database_name, db in self.databases.items():
|
||||||
schema_version = (await db.execute("PRAGMA schema_version")).first()[0]
|
schema_version = await db.schema_version()
|
||||||
# Compare schema versions to see if we should skip it
|
current_version_and_hash = current_schema_versions_and_hashes.get(
|
||||||
if schema_version == current_schema_versions.get(database_name):
|
database_name
|
||||||
continue
|
)
|
||||||
|
if current_version_and_hash:
|
||||||
|
# We might get to skip this database
|
||||||
|
if schema_version is not None and current_version_and_hash:
|
||||||
|
# Use this to decide if the schema has changed
|
||||||
|
if schema_version == current_version_and_hash[0]:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Use the schema hash instead
|
||||||
|
schema_hash = await db.schema_hash()
|
||||||
|
if schema_hash == current_version_and_hash[1]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate new schema hash
|
||||||
|
schema_hash = await db.schema_hash()
|
||||||
|
placeholders = "(?, ?, ?, ?, ?)"
|
||||||
|
values = [
|
||||||
|
database_name,
|
||||||
|
str(db.path),
|
||||||
|
db.is_memory,
|
||||||
|
schema_version,
|
||||||
|
schema_hash,
|
||||||
|
]
|
||||||
|
if db.path is None:
|
||||||
|
placeholders = "(?, null, ?, ?, ?)"
|
||||||
|
values = [database_name, db.is_memory, schema_version, schema_hash]
|
||||||
await internal_db.execute_write(
|
await internal_db.execute_write(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version)
|
INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version, schema_hash)
|
||||||
VALUES (?, ?, ?, ?)
|
VALUES {}
|
||||||
""",
|
""".format(
|
||||||
[database_name, str(db.path), db.is_memory, schema_version],
|
placeholders
|
||||||
|
),
|
||||||
|
values,
|
||||||
)
|
)
|
||||||
await populate_schema_tables(internal_db, db)
|
await populate_schema_tables(internal_db, db)
|
||||||
|
|
||||||
|
|
|
@ -136,6 +136,7 @@ def sqlite_extensions(fn):
|
||||||
multiple=True,
|
multiple=True,
|
||||||
help="Path to a SQLite extension to load, and optional entrypoint",
|
help="Path to a SQLite extension to load, and optional entrypoint",
|
||||||
)(fn)
|
)(fn)
|
||||||
|
|
||||||
# Wrap it in a custom error handler
|
# Wrap it in a custom error handler
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import hashlib
|
||||||
import janus
|
import janus
|
||||||
import queue
|
import queue
|
||||||
import sys
|
import sys
|
||||||
|
@ -50,6 +51,24 @@ class Database:
|
||||||
# This is used to track all file connections so they can be closed
|
# This is used to track all file connections so they can be closed
|
||||||
self._all_file_connections = []
|
self._all_file_connections = []
|
||||||
|
|
||||||
|
async def schema_version(self):
|
||||||
|
# This can return 'None' if the schema_version cannot be read
|
||||||
|
# See https://github.com/simonw/datasette/issues/2058
|
||||||
|
try:
|
||||||
|
return (await self.execute("PRAGMA schema_version")).first()[0]
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def schema_hash(self):
|
||||||
|
return hashlib.md5(
|
||||||
|
(
|
||||||
|
(
|
||||||
|
await self.execute("SELECT group_concat(sql) FROM sqlite_master")
|
||||||
|
).first()[0]
|
||||||
|
or ""
|
||||||
|
).encode("utf8")
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cached_table_counts(self):
|
def cached_table_counts(self):
|
||||||
if self._cached_table_counts is not None:
|
if self._cached_table_counts is not None:
|
||||||
|
|
|
@ -9,7 +9,8 @@ async def init_internal_db(db):
|
||||||
database_name TEXT PRIMARY KEY,
|
database_name TEXT PRIMARY KEY,
|
||||||
path TEXT,
|
path TEXT,
|
||||||
is_memory INTEGER,
|
is_memory INTEGER,
|
||||||
schema_version INTEGER
|
schema_version INTEGER,
|
||||||
|
schema_hash TEXT
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS tables (
|
CREATE TABLE IF NOT EXISTS tables (
|
||||||
database_name TEXT,
|
database_name TEXT,
|
||||||
|
|
|
@ -83,13 +83,11 @@ async def test_through_filters_from_request(app_client):
|
||||||
request = Request.fake(
|
request = Request.fake(
|
||||||
'/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
|
'/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
|
||||||
)
|
)
|
||||||
filter_args = await (
|
filter_args = await through_filters(
|
||||||
through_filters(
|
request=request,
|
||||||
request=request,
|
datasette=app_client.ds,
|
||||||
datasette=app_client.ds,
|
table="roadside_attractions",
|
||||||
table="roadside_attractions",
|
database="fixtures",
|
||||||
database="fixtures",
|
|
||||||
)
|
|
||||||
)()
|
)()
|
||||||
assert filter_args.where_clauses == [
|
assert filter_args.where_clauses == [
|
||||||
"pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
|
"pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
|
||||||
|
@ -106,13 +104,11 @@ async def test_through_filters_from_request(app_client):
|
||||||
request = Request.fake(
|
request = Request.fake(
|
||||||
'/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
|
'/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
|
||||||
)
|
)
|
||||||
filter_args = await (
|
filter_args = await through_filters(
|
||||||
through_filters(
|
request=request,
|
||||||
request=request,
|
datasette=app_client.ds,
|
||||||
datasette=app_client.ds,
|
table="roadside_attractions",
|
||||||
table="roadside_attractions",
|
database="fixtures",
|
||||||
database="fixtures",
|
|
||||||
)
|
|
||||||
)()
|
)()
|
||||||
assert filter_args.where_clauses == [
|
assert filter_args.where_clauses == [
|
||||||
"pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
|
"pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
|
||||||
|
@ -127,12 +123,10 @@ async def test_through_filters_from_request(app_client):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_where_filters_from_request(app_client):
|
async def test_where_filters_from_request(app_client):
|
||||||
request = Request.fake("/?_where=pk+>+3")
|
request = Request.fake("/?_where=pk+>+3")
|
||||||
filter_args = await (
|
filter_args = await where_filters(
|
||||||
where_filters(
|
request=request,
|
||||||
request=request,
|
datasette=app_client.ds,
|
||||||
datasette=app_client.ds,
|
database="fixtures",
|
||||||
database="fixtures",
|
|
||||||
)
|
|
||||||
)()
|
)()
|
||||||
assert filter_args.where_clauses == ["pk > 3"]
|
assert filter_args.where_clauses == ["pk > 3"]
|
||||||
assert filter_args.params == {}
|
assert filter_args.params == {}
|
||||||
|
@ -145,13 +139,11 @@ async def test_where_filters_from_request(app_client):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_filters_from_request(app_client):
|
async def test_search_filters_from_request(app_client):
|
||||||
request = Request.fake("/?_search=bobcat")
|
request = Request.fake("/?_search=bobcat")
|
||||||
filter_args = await (
|
filter_args = await search_filters(
|
||||||
search_filters(
|
request=request,
|
||||||
request=request,
|
datasette=app_client.ds,
|
||||||
datasette=app_client.ds,
|
database="fixtures",
|
||||||
database="fixtures",
|
table="searchable",
|
||||||
table="searchable",
|
|
||||||
)
|
|
||||||
)()
|
)()
|
||||||
assert filter_args.where_clauses == [
|
assert filter_args.where_clauses == [
|
||||||
"rowid in (select rowid from searchable_fts where searchable_fts match escape_fts(:search))"
|
"rowid in (select rowid from searchable_fts where searchable_fts match escape_fts(:search))"
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
from .fixtures import app_client
|
from .fixtures import app_client
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from datasette.app import Datasette
|
||||||
|
from datasette.database import Database
|
||||||
|
|
||||||
|
|
||||||
def test_internal_only_available_to_root(app_client):
|
def test_internal_only_available_to_root(app_client):
|
||||||
|
@ -65,3 +68,51 @@ def test_internal_foreign_keys(app_client):
|
||||||
"table_name",
|
"table_name",
|
||||||
"from",
|
"from",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("schema_version_returns_none", (True, False))
|
||||||
|
async def test_detects_schema_changes(schema_version_returns_none):
|
||||||
|
ds = Datasette()
|
||||||
|
db_name = "test_detects_schema_changes_{}".format(schema_version_returns_none)
|
||||||
|
db = ds.add_memory_database(db_name)
|
||||||
|
# Test if Datasette correctly detects schema changes, whether or not
|
||||||
|
# the schema_version method is working.
|
||||||
|
# https://github.com/simonw/datasette/issues/2058
|
||||||
|
|
||||||
|
_internal = ds.get_database("_internal")
|
||||||
|
|
||||||
|
async def get_tables():
|
||||||
|
return [
|
||||||
|
dict(r)
|
||||||
|
for r in await _internal.execute(
|
||||||
|
"select table_name from tables where database_name = ?", [db_name]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def test_it():
|
||||||
|
await ds.refresh_schemas()
|
||||||
|
initial_hash = await db.schema_hash()
|
||||||
|
# _internal should list zero tables
|
||||||
|
tables = await get_tables()
|
||||||
|
assert tables == []
|
||||||
|
# Create a new table
|
||||||
|
await db.execute_write("CREATE TABLE test (id INTEGER PRIMARY KEY)")
|
||||||
|
await ds.refresh_schemas()
|
||||||
|
assert await db.schema_hash() != initial_hash
|
||||||
|
# _internal should list one table
|
||||||
|
tables = await get_tables()
|
||||||
|
assert tables == [
|
||||||
|
{"table_name": "test"},
|
||||||
|
]
|
||||||
|
|
||||||
|
async def schema_version_none(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if schema_version_returns_none:
|
||||||
|
with patch(
|
||||||
|
"datasette.database.Database.schema_version", new=schema_version_none
|
||||||
|
):
|
||||||
|
await test_it()
|
||||||
|
else:
|
||||||
|
await test_it()
|
||||||
|
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
||||||
# this resolves to "./ext", which is enough for SQLite to calculate the rest
|
# this resolves to "./ext", which is enough for SQLite to calculate the rest
|
||||||
COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
|
COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
|
||||||
|
|
||||||
|
|
||||||
# See if ext.c has been compiled, based off the different possible suffixes.
|
# See if ext.c has been compiled, based off the different possible suffixes.
|
||||||
def has_compiled_ext():
|
def has_compiled_ext():
|
||||||
for ext in ["dylib", "so", "dll"]:
|
for ext in ["dylib", "so", "dll"]:
|
||||||
|
@ -20,7 +21,6 @@ def has_compiled_ext():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
|
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
|
||||||
async def test_load_extension_default_entrypoint():
|
async def test_load_extension_default_entrypoint():
|
||||||
|
|
||||||
# The default entrypoint only loads a() and NOT b() or c(), so those
|
# The default entrypoint only loads a() and NOT b() or c(), so those
|
||||||
# should fail.
|
# should fail.
|
||||||
ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
|
ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
|
||||||
|
@ -41,7 +41,6 @@ async def test_load_extension_default_entrypoint():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
|
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
|
||||||
async def test_load_extension_multiple_entrypoints():
|
async def test_load_extension_multiple_entrypoints():
|
||||||
|
|
||||||
# Load in the default entrypoint and the other 2 custom entrypoints, now
|
# Load in the default entrypoint and the other 2 custom entrypoints, now
|
||||||
# all a(), b(), and c() should run successfully.
|
# all a(), b(), and c() should run successfully.
|
||||||
ds = Datasette(
|
ds = Datasette(
|
||||||
|
|
Ładowanie…
Reference in New Issue