From 98b2de65c6d97a3070a670a1a5db856e7af7ac52 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 12 Apr 2023 17:51:34 -0700 Subject: [PATCH] Experimental fix addressing #2058 --- datasette/app.py | 44 +++++++++++++++++++++-------- datasette/database.py | 19 +++++++++++++ datasette/utils/internal_db.py | 3 +- tests/test_internal_db.py | 51 ++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 13 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index d7dace67..e65cc1bf 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -424,25 +424,45 @@ class Datasette: await init_internal_db(internal_db) self.internal_db_created = True - current_schema_versions = { - row["database_name"]: row["schema_version"] + current_schema_versions_and_hashes = { + row["database_name"]: (row["schema_version"], row["schema_hash"]) for row in await internal_db.execute( - "select database_name, schema_version from databases" + "select database_name, schema_version, schema_hash from databases" ) } for database_name, db in self.databases.items(): - schema_version = (await db.execute("PRAGMA schema_version")).first()[0] - # Compare schema versions to see if we should skip it - if schema_version == current_schema_versions.get(database_name): - continue - placeholders = "(?, ?, ?, ?)" - values = [database_name, str(db.path), db.is_memory, schema_version] + schema_version = await db.schema_version() + current_version_and_hash = current_schema_versions_and_hashes.get( + database_name + ) + if current_version_and_hash: + # We might get to skip this database + if schema_version is not None and current_version_and_hash: + # Use this to decide if the schema has changed + if schema_version == current_version_and_hash[0]: + continue + else: + # Use the schema hash instead + schema_hash = await db.schema_hash() + if schema_hash == current_version_and_hash[1]: + continue + + # Calculate new schema hash + schema_hash = await db.schema_hash() + placeholders = "(?, ?, ?, ?, ?)" + values = [ + database_name, + str(db.path), + db.is_memory, + schema_version, + schema_hash, + ] if db.path is None: - placeholders = "(?, null, ?, ?)" - values = [database_name, db.is_memory, schema_version] + placeholders = "(?, null, ?, ?, ?)" + values = [database_name, db.is_memory, schema_version, schema_hash] await internal_db.execute_write( """ - INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version) + INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version, schema_hash) VALUES {} """.format( placeholders diff --git a/datasette/database.py b/datasette/database.py index d8043c24..f782a515 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -1,6 +1,7 @@ import asyncio from collections import namedtuple from pathlib import Path +import hashlib import janus import queue import sys @@ -50,6 +51,24 @@ class Database: # This is used to track all file connections so they can be closed self._all_file_connections = [] + async def schema_version(self): + # This can return 'None' if the schema_version cannot be read + # See https://github.com/simonw/datasette/issues/2058 + try: + return (await self.execute("PRAGMA schema_version")).first()[0] + except sqlite3.OperationalError: + return None + + async def schema_hash(self): + return hashlib.md5( + ( + ( + await self.execute("SELECT group_concat(sql) FROM sqlite_master") + ).first()[0] + or "" + ).encode("utf8") + ).hexdigest() + @property def cached_table_counts(self): if self._cached_table_counts is not None: diff --git a/datasette/utils/internal_db.py b/datasette/utils/internal_db.py index e4b49e80..08868f3f 100644 --- a/datasette/utils/internal_db.py +++ b/datasette/utils/internal_db.py @@ -9,7 +9,8 @@ async def init_internal_db(db): database_name TEXT PRIMARY KEY, path TEXT, is_memory INTEGER, - schema_version INTEGER + schema_version INTEGER, + schema_hash TEXT ); CREATE TABLE IF NOT EXISTS tables ( database_name TEXT, diff --git a/tests/test_internal_db.py b/tests/test_internal_db.py index a666dd72..22c83623 100644 --- a/tests/test_internal_db.py +++ b/tests/test_internal_db.py @@ -1,4 +1,7 @@ import pytest +from unittest.mock import patch +from datasette.app import Datasette +from datasette.database import Database @pytest.mark.asyncio @@ -83,3 +86,51 @@ async def test_internal_foreign_keys(ds_client): "table_name", "from", } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("schema_version_returns_none", (True, False)) +async def test_detects_schema_changes(schema_version_returns_none): + ds = Datasette() + db_name = "test_detects_schema_changes_{}".format(schema_version_returns_none) + db = ds.add_memory_database(db_name) + # Test if Datasette correctly detects schema changes, whether or not + # the schema_version method is working. + # https://github.com/simonw/datasette/issues/2058 + + _internal = ds.get_database("_internal") + + async def get_tables(): + return [ + dict(r) + for r in await _internal.execute( + "select table_name from tables where database_name = ?", [db_name] + ) + ] + + async def test_it(): + await ds.refresh_schemas() + initial_hash = await db.schema_hash() + # _internal should list zero tables + tables = await get_tables() + assert tables == [] + # Create a new table + await db.execute_write("CREATE TABLE test (id INTEGER PRIMARY KEY)") + await ds.refresh_schemas() + assert await db.schema_hash() != initial_hash + # _internal should list one table + tables = await get_tables() + assert tables == [ + {"table_name": "test"}, + ] + + async def schema_version_none(self): + return None + + if schema_version_returns_none: + with patch( + "datasette.database.Database.schema_version", new=schema_version_none + ): + await test_it() + else: + await test_it()