From 42bf9e2aab0cbe7eac4f55320168ed4a750fec5b Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Wed, 12 Apr 2023 17:56:10 -0700
Subject: [PATCH] Backported experimental #2058 fix to 0.64.x

---
 datasette/app.py               | 49 ++++++++++++++++++++++++--------
 datasette/cli.py               |  1 +
 datasette/database.py          | 19 +++++++++++++
 datasette/utils/internal_db.py |  3 +-
 tests/test_filters.py          | 46 +++++++++++++-----------------
 tests/test_internal_db.py      | 51 ++++++++++++++++++++++++++++++++++
 tests/test_load_extensions.py  |  3 +-
 7 files changed, 131 insertions(+), 41 deletions(-)

diff --git a/datasette/app.py b/datasette/app.py
index 6b889f08..41c73acd 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -376,23 +376,50 @@ class Datasette:
             await init_internal_db(internal_db)
             self.internal_db_created = True
 
-        current_schema_versions = {
-            row["database_name"]: row["schema_version"]
+        current_schema_versions_and_hashes = {
+            row["database_name"]: (row["schema_version"], row["schema_hash"])
             for row in await internal_db.execute(
-                "select database_name, schema_version from databases"
+                "select database_name, schema_version, schema_hash from databases"
             )
         }
         for database_name, db in self.databases.items():
-            schema_version = (await db.execute("PRAGMA schema_version")).first()[0]
-            # Compare schema versions to see if we should skip it
-            if schema_version == current_schema_versions.get(database_name):
-                continue
+            schema_version = await db.schema_version()
+            current_version_and_hash = current_schema_versions_and_hashes.get(
+                database_name
+            )
+            if current_version_and_hash:
+                # We might get to skip this database
+                if schema_version is not None and current_version_and_hash:
+                    # Use this to decide if the schema has changed
+                    if schema_version == current_version_and_hash[0]:
+                        continue
+                else:
+                    # Use the schema hash instead
+                    schema_hash = await db.schema_hash()
+                    if schema_hash == current_version_and_hash[1]:
+                        continue
+
+            # Calculate new schema hash
+            schema_hash = await db.schema_hash()
+            placeholders = "(?, ?, ?, ?, ?)"
+            values = [
+                database_name,
+                str(db.path),
+                db.is_memory,
+                schema_version,
+                schema_hash,
+            ]
+            if db.path is None:
+                placeholders = "(?, null, ?, ?, ?)"
+                values = [database_name, db.is_memory, schema_version, schema_hash]
             await internal_db.execute_write(
                 """
-                INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version)
-                VALUES (?, ?, ?, ?)
-            """,
-                [database_name, str(db.path), db.is_memory, schema_version],
+                INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version, schema_hash)
+                VALUES {}
+            """.format(
+                    placeholders
+                ),
+                values,
             )
             await populate_schema_tables(internal_db, db)
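
The rewritten loop prefers the cheap PRAGMA schema_version comparison and
falls back to hashing the schema SQL only when the pragma cannot be read,
which is the #2058 failure mode. A minimal standalone sketch of that
decision, where the helper name and the `stored` tuple argument are
illustrative rather than part of the patch:

    async def schema_unchanged(db, stored):
        # stored: the (schema_version, schema_hash) pair previously saved
        # in the internal databases table, or None if never recorded
        if not stored:
            return False
        version = await db.schema_version()
        if version is not None:
            # Cheap path: the pragma worked, compare integer versions
            return version == stored[0]
        # Fallback path (#2058): pragma failed, compare schema hashes
        return await db.schema_hash() == stored[1]
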
- """, - [database_name, str(db.path), db.is_memory, schema_version], + INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version, schema_hash) + VALUES {} + """.format( + placeholders + ), + values, ) await populate_schema_tables(internal_db, db) diff --git a/datasette/cli.py b/datasette/cli.py index 89ee12b6..fd65ea94 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -136,6 +136,7 @@ def sqlite_extensions(fn): multiple=True, help="Path to a SQLite extension to load, and optional entrypoint", )(fn) + # Wrap it in a custom error handler @functools.wraps(fn) def wrapped(*args, **kwargs): diff --git a/datasette/database.py b/datasette/database.py index dfca179c..1de7b393 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -1,6 +1,7 @@ import asyncio from collections import namedtuple from pathlib import Path +import hashlib import janus import queue import sys @@ -50,6 +51,24 @@ class Database: # This is used to track all file connections so they can be closed self._all_file_connections = [] + async def schema_version(self): + # This can return 'None' if the schema_version cannot be read + # See https://github.com/simonw/datasette/issues/2058 + try: + return (await self.execute("PRAGMA schema_version")).first()[0] + except sqlite3.OperationalError: + return None + + async def schema_hash(self): + return hashlib.md5( + ( + ( + await self.execute("SELECT group_concat(sql) FROM sqlite_master") + ).first()[0] + or "" + ).encode("utf8") + ).hexdigest() + @property def cached_table_counts(self): if self._cached_table_counts is not None: diff --git a/datasette/utils/internal_db.py b/datasette/utils/internal_db.py index e4b49e80..08868f3f 100644 --- a/datasette/utils/internal_db.py +++ b/datasette/utils/internal_db.py @@ -9,7 +9,8 @@ async def init_internal_db(db): database_name TEXT PRIMARY KEY, path TEXT, is_memory INTEGER, - schema_version INTEGER + schema_version INTEGER, + schema_hash TEXT ); CREATE TABLE IF NOT EXISTS tables ( database_name TEXT, diff --git a/tests/test_filters.py b/tests/test_filters.py index 2ff57489..08407612 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -83,13 +83,11 @@ async def test_through_filters_from_request(app_client): request = Request.fake( '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' ) - filter_args = await ( - through_filters( - request=request, - datasette=app_client.ds, - table="roadside_attractions", - database="fixtures", - ) + filter_args = await through_filters( + request=request, + datasette=app_client.ds, + table="roadside_attractions", + database="fixtures", )() assert filter_args.where_clauses == [ "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)" @@ -106,13 +104,11 @@ async def test_through_filters_from_request(app_client): request = Request.fake( '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' ) - filter_args = await ( - through_filters( - request=request, - datasette=app_client.ds, - table="roadside_attractions", - database="fixtures", - ) + filter_args = await through_filters( + request=request, + datasette=app_client.ds, + table="roadside_attractions", + database="fixtures", )() assert filter_args.where_clauses == [ "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)" @@ -127,12 +123,10 @@ async def test_through_filters_from_request(app_client): @pytest.mark.asyncio async def 
diff --git a/tests/test_filters.py b/tests/test_filters.py
index 2ff57489..08407612 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -83,13 +83,11 @@ async def test_through_filters_from_request(app_client):
     request = Request.fake(
         '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
     )
-    filter_args = await (
-        through_filters(
-            request=request,
-            datasette=app_client.ds,
-            table="roadside_attractions",
-            database="fixtures",
-        )
+    filter_args = await through_filters(
+        request=request,
+        datasette=app_client.ds,
+        table="roadside_attractions",
+        database="fixtures",
     )()
     assert filter_args.where_clauses == [
         "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
@@ -106,13 +104,11 @@ async def test_through_filters_from_request(app_client):
     request = Request.fake(
         '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}'
     )
-    filter_args = await (
-        through_filters(
-            request=request,
-            datasette=app_client.ds,
-            table="roadside_attractions",
-            database="fixtures",
-        )
+    filter_args = await through_filters(
+        request=request,
+        datasette=app_client.ds,
+        table="roadside_attractions",
+        database="fixtures",
     )()
     assert filter_args.where_clauses == [
         "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)"
@@ -127,12 +123,10 @@
 @pytest.mark.asyncio
 async def test_where_filters_from_request(app_client):
     request = Request.fake("/?_where=pk+>+3")
-    filter_args = await (
-        where_filters(
-            request=request,
-            datasette=app_client.ds,
-            database="fixtures",
-        )
+    filter_args = await where_filters(
+        request=request,
+        datasette=app_client.ds,
+        database="fixtures",
     )()
     assert filter_args.where_clauses == ["pk > 3"]
     assert filter_args.params == {}
@@ -145,13 +139,11 @@
 @pytest.mark.asyncio
 async def test_search_filters_from_request(app_client):
     request = Request.fake("/?_search=bobcat")
-    filter_args = await (
-        search_filters(
-            request=request,
-            datasette=app_client.ds,
-            database="fixtures",
-            table="searchable",
-        )
+    filter_args = await search_filters(
+        request=request,
+        datasette=app_client.ds,
+        database="fixtures",
+        table="searchable",
     )()
     assert filter_args.where_clauses == [
         "rowid in (select rowid from searchable_fts where searchable_fts match escape_fts(:search))"
diff --git a/tests/test_internal_db.py b/tests/test_internal_db.py
index 755ddae5..b1aa4763 100644
--- a/tests/test_internal_db.py
+++ b/tests/test_internal_db.py
@@ -1,5 +1,8 @@
 from .fixtures import app_client
 import pytest
+from unittest.mock import patch
+from datasette.app import Datasette
+from datasette.database import Database
 
 
 def test_internal_only_available_to_root(app_client):
@@ -65,3 +68,51 @@ def test_internal_foreign_keys(app_client):
         "table_name",
         "from",
     }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("schema_version_returns_none", (True, False))
+async def test_detects_schema_changes(schema_version_returns_none):
+    ds = Datasette()
+    db_name = "test_detects_schema_changes_{}".format(schema_version_returns_none)
+    db = ds.add_memory_database(db_name)
+    # Test if Datasette correctly detects schema changes, whether or not
+    # the schema_version method is working.
+    # https://github.com/simonw/datasette/issues/2058
+
+    _internal = ds.get_database("_internal")
+
+    async def get_tables():
+        return [
+            dict(r)
+            for r in await _internal.execute(
+                "select table_name from tables where database_name = ?", [db_name]
+            )
+        ]
+
+    async def test_it():
+        await ds.refresh_schemas()
+        initial_hash = await db.schema_hash()
+        # _internal should list zero tables
+        tables = await get_tables()
+        assert tables == []
+        # Create a new table
+        await db.execute_write("CREATE TABLE test (id INTEGER PRIMARY KEY)")
+        await ds.refresh_schemas()
+        assert await db.schema_hash() != initial_hash
+        # _internal should list one table
+        tables = await get_tables()
+        assert tables == [
+            {"table_name": "test"},
+        ]
+
+    async def schema_version_none(self):
+        return None
+
+    if schema_version_returns_none:
+        with patch(
+            "datasette.database.Database.schema_version", new=schema_version_none
+        ):
+            await test_it()
+    else:
+        await test_it()
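
The parametrized test above simulates the #2058 condition by swapping
Database.schema_version for a coroutine that returns None, so both the
pragma path and the hash fallback get exercised. The same
unittest.mock.patch(..., new=...) technique works for any async method that
needs a canned result; a self-contained illustration with hypothetical names:

    import asyncio
    from unittest.mock import patch

    class Db:
        async def schema_version(self):
            return 42

    async def broken_schema_version(self):
        # Stands in for PRAGMA schema_version failing with OperationalError
        return None

    async def main():
        db = Db()
        with patch.object(Db, "schema_version", new=broken_schema_version):
            assert await db.schema_version() is None
        assert await db.schema_version() == 42

    asyncio.run(main())
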
diff --git a/tests/test_load_extensions.py b/tests/test_load_extensions.py
index 360bc8f3..8b14fd32 100644
--- a/tests/test_load_extensions.py
+++ b/tests/test_load_extensions.py
@@ -8,6 +8,7 @@ from pathlib import Path
 # this resolves to "./ext", which is enough for SQLite to calculate the rest
 COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
 
+
 # See if ext.c has been compiled, based off the different possible suffixes.
 def has_compiled_ext():
     for ext in ["dylib", "so", "dll"]:
@@ -20,7 +21,6 @@
 @pytest.mark.asyncio
 @pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
 async def test_load_extension_default_entrypoint():
-
     # The default entrypoint only loads a() and NOT b() or c(), so those
     # should fail.
     ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
@@ -41,7 +41,6 @@
 @pytest.mark.asyncio
 @pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
 async def test_load_extension_multiple_entrypoints():
-
     # Load in the default entrypoint and the other 2 custom entrypoints, now
     # all a(), b(), and c() should run successfully.
     ds = Datasette(
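
Taken together, the patched flow can be exercised the way the new test does:
register an in-memory database, refresh schemas, change the schema, refresh
again. A condensed usage sketch built only from APIs that appear in this
diff (add_memory_database, refresh_schemas, execute_write, schema_hash):

    import asyncio
    from datasette.app import Datasette

    async def main():
        ds = Datasette()
        db = ds.add_memory_database("demo")
        await ds.refresh_schemas()
        before = await db.schema_hash()
        await db.execute_write("CREATE TABLE t (id INTEGER PRIMARY KEY)")
        await ds.refresh_schemas()
        # A schema change must change the hash, even when PRAGMA
        # schema_version is unreadable
        assert await db.schema_hash() != before

    asyncio.run(main())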