Backported experimental #2058 fix to 0.64.x

schema-version-fix-0.64.x
Simon Willison 2023-04-12 17:56:10 -07:00
parent 2a0a94fe97
commit 42bf9e2aab
7 changed files with 131 additions and 41 deletions

View file

@@ -376,23 +376,50 @@ class Datasette:
             await init_internal_db(internal_db)
             self.internal_db_created = True
 
-        current_schema_versions = {
-            row["database_name"]: row["schema_version"]
+        current_schema_versions_and_hashes = {
+            row["database_name"]: (row["schema_version"], row["schema_hash"])
             for row in await internal_db.execute(
-                "select database_name, schema_version from databases"
+                "select database_name, schema_version, schema_hash from databases"
             )
         }
         for database_name, db in self.databases.items():
-            schema_version = (await db.execute("PRAGMA schema_version")).first()[0]
-            # Compare schema versions to see if we should skip it
-            if schema_version == current_schema_versions.get(database_name):
-                continue
+            schema_version = await db.schema_version()
+            current_version_and_hash = current_schema_versions_and_hashes.get(
+                database_name
+            )
+            if current_version_and_hash:
+                # We might get to skip this database
+                if schema_version is not None and current_version_and_hash:
+                    # Use this to decide if the schema has changed
+                    if schema_version == current_version_and_hash[0]:
+                        continue
+                else:
+                    # Use the schema hash instead
+                    schema_hash = await db.schema_hash()
+                    if schema_hash == current_version_and_hash[1]:
+                        continue
+            # Calculate new schema hash
+            schema_hash = await db.schema_hash()
+            placeholders = "(?, ?, ?, ?, ?)"
+            values = [
+                database_name,
+                str(db.path),
+                db.is_memory,
+                schema_version,
+                schema_hash,
+            ]
+            if db.path is None:
+                placeholders = "(?, null, ?, ?, ?)"
+                values = [database_name, db.is_memory, schema_version, schema_hash]
             await internal_db.execute_write(
                 """
-                INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version)
-                VALUES (?, ?, ?, ?)
-            """,
-                [database_name, str(db.path), db.is_memory, schema_version],
+                INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version, schema_hash)
+                VALUES {}
+            """.format(
+                    placeholders
+                ),
+                values,
             )
             await populate_schema_tables(internal_db, db)
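The heart of the fix: `PRAGMA schema_version` can raise `sqlite3.OperationalError` on some hosting setups (the bug tracked in #2058), so the refresh loop no longer depends on it alone. A minimal sketch of the skip decision above, with illustrative names — `stored` stands for the recorded `(schema_version, schema_hash)` tuple, and the helper itself is not part of the commit:

```python
# Sketch only: condenses the branching above into a single helper.
# stored is the (schema_version, schema_hash) tuple previously recorded
# in _internal, or None for a database seen for the first time.
async def schema_unchanged(db, stored):
    if stored is None:
        return False  # new database: schema tables must be populated
    version = await db.schema_version()
    if version is not None:
        # Fast path: the PRAGMA is readable, compare integer versions
        return version == stored[0]
    # Fallback: the PRAGMA failed, compare MD5 hashes of the schema SQL
    return (await db.schema_hash()) == stored[1]
```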

View file

@@ -136,6 +136,7 @@ def sqlite_extensions(fn):
         multiple=True,
         help="Path to a SQLite extension to load, and optional entrypoint",
     )(fn)
+
     # Wrap it in a custom error handler
     @functools.wraps(fn)
     def wrapped(*args, **kwargs):

View file

@@ -1,6 +1,7 @@
 import asyncio
 from collections import namedtuple
 from pathlib import Path
+import hashlib
 import janus
 import queue
 import sys
@@ -50,6 +51,24 @@ class Database:
         # This is used to track all file connections so they can be closed
         self._all_file_connections = []
 
+    async def schema_version(self):
+        # This can return 'None' if the schema_version cannot be read
+        # See https://github.com/simonw/datasette/issues/2058
+        try:
+            return (await self.execute("PRAGMA schema_version")).first()[0]
+        except sqlite3.OperationalError:
+            return None
+
+    async def schema_hash(self):
+        return hashlib.md5(
+            (
+                (
+                    await self.execute("SELECT group_concat(sql) FROM sqlite_master")
+                ).first()[0]
+                or ""
+            ).encode("utf8")
+        ).hexdigest()
+
     @property
     def cached_table_counts(self):
         if self._cached_table_counts is not None:
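The two new `Database` methods can be tried directly. A minimal sketch, assuming a scratch in-memory database — the `demo` name and the assertion are illustrative, not part of the commit:

```python
import asyncio
from datasette.app import Datasette

async def demo():
    ds = Datasette()
    db = ds.add_memory_database("demo")  # illustrative name
    # Integer from PRAGMA schema_version, or None if the PRAGMA fails
    print(await db.schema_version())
    # MD5 hex digest of group_concat(sql) from sqlite_master ("" when empty)
    before = await db.schema_hash()
    await db.execute_write("CREATE TABLE t (id INTEGER PRIMARY KEY)")
    after = await db.schema_hash()
    assert before != after  # the hash moves whenever the schema changes

asyncio.run(demo())
```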

View file

@@ -9,7 +9,8 @@ async def init_internal_db(db):
         database_name TEXT PRIMARY KEY,
         path TEXT,
         is_memory INTEGER,
-        schema_version INTEGER
+        schema_version INTEGER,
+        schema_hash TEXT
     );
     CREATE TABLE IF NOT EXISTS tables (
         database_name TEXT,
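Since `_internal` is an in-memory database recreated on startup, the new `schema_hash` column needs no migration step; `CREATE TABLE IF NOT EXISTS` simply includes it on the next launch. A sketch of inspecting what gets recorded, reusing the query from the app.py hunk above (the script itself is illustrative):

```python
import asyncio
from datasette.app import Datasette

async def show_recorded_schemas():
    ds = Datasette()
    await ds.refresh_schemas()
    internal = ds.get_database("_internal")
    rows = await internal.execute(
        "select database_name, schema_version, schema_hash from databases"
    )
    for row in rows:
        print(dict(row))  # now includes a schema_hash value per database

asyncio.run(show_recorded_schemas())
```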

View file

@@ -83,13 +83,11 @@ async def test_through_filters_from_request(app_client):
     request = Request.fake(
         '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
     )
-    filter_args = await (
-        through_filters(
-            request=request,
-            datasette=app_client.ds,
-            table="roadside_attractions",
-            database="fixtures",
-        )
+    filter_args = await through_filters(
+        request=request,
+        datasette=app_client.ds,
+        table="roadside_attractions",
+        database="fixtures",
     )()
     assert filter_args.where_clauses == [
         "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
@@ -106,13 +104,11 @@ async def test_through_filters_from_request(app_client):
     request = Request.fake(
         '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
     )
-    filter_args = await (
-        through_filters(
-            request=request,
-            datasette=app_client.ds,
-            table="roadside_attractions",
-            database="fixtures",
-        )
+    filter_args = await through_filters(
+        request=request,
+        datasette=app_client.ds,
+        table="roadside_attractions",
+        database="fixtures",
     )()
     assert filter_args.where_clauses == [
         "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
@@ -127,12 +123,10 @@ async def test_through_filters_from_request(app_client):
 @pytest.mark.asyncio
 async def test_where_filters_from_request(app_client):
     request = Request.fake("/?_where=pk+>+3")
-    filter_args = await (
-        where_filters(
-            request=request,
-            datasette=app_client.ds,
-            database="fixtures",
-        )
+    filter_args = await where_filters(
+        request=request,
+        datasette=app_client.ds,
+        database="fixtures",
     )()
     assert filter_args.where_clauses == ["pk > 3"]
     assert filter_args.params == {}
@@ -145,13 +139,11 @@ async def test_where_filters_from_request(app_client):
 @pytest.mark.asyncio
 async def test_search_filters_from_request(app_client):
     request = Request.fake("/?_search=bobcat")
-    filter_args = await (
-        search_filters(
-            request=request,
-            datasette=app_client.ds,
-            database="fixtures",
-            table="searchable",
-        )
+    filter_args = await search_filters(
+        request=request,
+        datasette=app_client.ds,
+        database="fixtures",
+        table="searchable",
     )()
     assert filter_args.where_clauses == [
         "rowid in (select rowid from searchable_fts where searchable_fts match escape_fts(:search))"

View file

@@ -1,5 +1,8 @@
 from .fixtures import app_client
 import pytest
+from unittest.mock import patch
+from datasette.app import Datasette
+from datasette.database import Database
 
 
 def test_internal_only_available_to_root(app_client):
@@ -65,3 +68,51 @@ def test_internal_foreign_keys(app_client):
         "table_name",
         "from",
     }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("schema_version_returns_none", (True, False))
+async def test_detects_schema_changes(schema_version_returns_none):
+    ds = Datasette()
+    db_name = "test_detects_schema_changes_{}".format(schema_version_returns_none)
+    db = ds.add_memory_database(db_name)
+    # Test if Datasette correctly detects schema changes, whether or not
+    # the schema_version method is working.
+    # https://github.com/simonw/datasette/issues/2058
+    _internal = ds.get_database("_internal")
+
+    async def get_tables():
+        return [
+            dict(r)
+            for r in await _internal.execute(
+                "select table_name from tables where database_name = ?", [db_name]
+            )
+        ]
+
+    async def test_it():
+        await ds.refresh_schemas()
+        initial_hash = await db.schema_hash()
+        # _internal should list zero tables
+        tables = await get_tables()
+        assert tables == []
+        # Create a new table
+        await db.execute_write("CREATE TABLE test (id INTEGER PRIMARY KEY)")
+        await ds.refresh_schemas()
+        assert await db.schema_hash() != initial_hash
+        # _internal should list one table
+        tables = await get_tables()
+        assert tables == [
+            {"table_name": "test"},
+        ]
+
+    async def schema_version_none(self):
+        return None
+
+    if schema_version_returns_none:
+        with patch(
+            "datasette.database.Database.schema_version", new=schema_version_none
+        ):
+            await test_it()
+    else:
+        await test_it()
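The `True` variant patches `Database.schema_version` with a coroutine returning None, simulating the unreadable PRAGMA and forcing `refresh_schemas()` onto the hash fallback. The same trick works for ad-hoc verification outside pytest; a minimal sketch (the database name is illustrative):

```python
import asyncio
from unittest.mock import patch
from datasette.app import Datasette

async def force_hash_fallback():
    ds = Datasette()
    db = ds.add_memory_database("fallback_demo")  # illustrative name

    async def schema_version_none(self):
        return None  # simulate PRAGMA schema_version being unreadable

    with patch("datasette.database.Database.schema_version", new=schema_version_none):
        await ds.refresh_schemas()  # change detection must use schema_hash
        await db.execute_write("CREATE TABLE t (id INTEGER PRIMARY KEY)")
        await ds.refresh_schemas()  # schema change still picked up via MD5 hash

asyncio.run(force_hash_fallback())
```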

View file

@@ -8,6 +8,7 @@ from pathlib import Path
 # this resolves to "./ext", which is enough for SQLite to calculate the rest
 COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
 
+
 # See if ext.c has been compiled, based off the different possible suffixes.
 def has_compiled_ext():
     for ext in ["dylib", "so", "dll"]:
@@ -20,7 +21,6 @@ def has_compiled_ext():
 @pytest.mark.asyncio
 @pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
 async def test_load_extension_default_entrypoint():
-
     # The default entrypoint only loads a() and NOT b() or c(), so those
     # should fail.
     ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
@@ -41,7 +41,6 @@ async def test_load_extension_default_entrypoint():
 @pytest.mark.asyncio
 @pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
 async def test_load_extension_multiple_entrypoints():
-
     # Load in the default entrypoint and the other 2 custom entrypoints, now
     # all a(), b(), and c() should run successfully.
     ds = Datasette(