Experimental fix addressing #2058

schema-version-fix
Simon Willison 2023-04-12 17:51:34 -07:00
rodzic 5890a20c37
commit 98b2de65c6
4 zmienionych plików z 104 dodań i 13 usunięć

Wyświetl plik

@@ -424,25 +424,45 @@ class Datasette:
await init_internal_db(internal_db) await init_internal_db(internal_db)
self.internal_db_created = True self.internal_db_created = True
current_schema_versions = { current_schema_versions_and_hashes = {
row["database_name"]: row["schema_version"] row["database_name"]: (row["schema_version"], row["schema_hash"])
for row in await internal_db.execute( for row in await internal_db.execute(
"select database_name, schema_version from databases" "select database_name, schema_version, schema_hash from databases"
) )
} }
for database_name, db in self.databases.items(): for database_name, db in self.databases.items():
schema_version = (await db.execute("PRAGMA schema_version")).first()[0] schema_version = await db.schema_version()
# Compare schema versions to see if we should skip it current_version_and_hash = current_schema_versions_and_hashes.get(
if schema_version == current_schema_versions.get(database_name): database_name
)
if current_version_and_hash:
# We might get to skip this database
if schema_version is not None and current_version_and_hash:
# Use this to decide if the schema has changed
if schema_version == current_version_and_hash[0]:
continue continue
placeholders = "(?, ?, ?, ?)" else:
values = [database_name, str(db.path), db.is_memory, schema_version] # Use the schema hash instead
schema_hash = await db.schema_hash()
if schema_hash == current_version_and_hash[1]:
continue
# Calculate new schema hash
schema_hash = await db.schema_hash()
placeholders = "(?, ?, ?, ?, ?)"
values = [
database_name,
str(db.path),
db.is_memory,
schema_version,
schema_hash,
]
if db.path is None: if db.path is None:
placeholders = "(?, null, ?, ?)" placeholders = "(?, null, ?, ?, ?)"
values = [database_name, db.is_memory, schema_version] values = [database_name, db.is_memory, schema_version, schema_hash]
await internal_db.execute_write( await internal_db.execute_write(
""" """
INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version) INSERT OR REPLACE INTO databases (database_name, path, is_memory, schema_version, schema_hash)
VALUES {} VALUES {}
""".format( """.format(
placeholders placeholders

Wyświetl plik

@@ -1,6 +1,7 @@
import asyncio import asyncio
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
import hashlib
import janus import janus
import queue import queue
import sys import sys
@@ -50,6 +51,24 @@ class Database:
# This is used to track all file connections so they can be closed # This is used to track all file connections so they can be closed
self._all_file_connections = [] self._all_file_connections = []
async def schema_version(self):
    """Return the value of ``PRAGMA schema_version`` for this database.

    Returns None when the pragma cannot be read at all - some databases
    raise an operational error here, see
    https://github.com/simonw/datasette/issues/2058
    """
    try:
        result = await self.execute("PRAGMA schema_version")
        return result.first()[0]
    except sqlite3.OperationalError:
        # Signal "unknown" to the caller rather than propagating
        return None
async def schema_hash(self):
    """Return an MD5 hex digest of the database's schema SQL.

    Concatenates the ``sql`` column of ``sqlite_master``; an entirely
    empty schema (group_concat returns NULL) hashes the empty string.
    Used as a change-detection fallback when schema_version is not
    available.
    """
    result = await self.execute("SELECT group_concat(sql) FROM sqlite_master")
    combined_sql = result.first()[0] or ""
    return hashlib.md5(combined_sql.encode("utf8")).hexdigest()
@property @property
def cached_table_counts(self): def cached_table_counts(self):
if self._cached_table_counts is not None: if self._cached_table_counts is not None:

Wyświetl plik

@@ -9,7 +9,8 @@ async def init_internal_db(db):
database_name TEXT PRIMARY KEY, database_name TEXT PRIMARY KEY,
path TEXT, path TEXT,
is_memory INTEGER, is_memory INTEGER,
schema_version INTEGER schema_version INTEGER,
schema_hash TEXT
); );
CREATE TABLE IF NOT EXISTS tables ( CREATE TABLE IF NOT EXISTS tables (
database_name TEXT, database_name TEXT,

Wyświetl plik

@@ -1,4 +1,7 @@
import pytest import pytest
from unittest.mock import patch
from datasette.app import Datasette
from datasette.database import Database
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -83,3 +86,51 @@ async def test_internal_foreign_keys(ds_client):
"table_name", "table_name",
"from", "from",
} }
@pytest.mark.asyncio
@pytest.mark.parametrize("schema_version_returns_none", (True, False))
async def test_detects_schema_changes(schema_version_returns_none):
    """Datasette must detect schema changes whether or not the
    schema_version mechanism works - exercised both ways via the
    parametrized flag. https://github.com/simonw/datasette/issues/2058
    """
    ds = Datasette()
    db_name = "test_detects_schema_changes_{}".format(schema_version_returns_none)
    db = ds.add_memory_database(db_name)
    internal = ds.get_database("_internal")

    async def tables_recorded_in_internal():
        # Rows _internal currently holds for this database
        rows = await internal.execute(
            "select table_name from tables where database_name = ?", [db_name]
        )
        return [dict(row) for row in rows]

    async def run_scenario():
        await ds.refresh_schemas()
        hash_before = await db.schema_hash()
        # Fresh database: _internal should know of no tables yet
        assert await tables_recorded_in_internal() == []
        # Mutate the schema, then refresh again
        await db.execute_write("CREATE TABLE test (id INTEGER PRIMARY KEY)")
        await ds.refresh_schemas()
        # The hash must change when the schema changes
        assert await db.schema_hash() != hash_before
        # ... and the new table must now be reflected in _internal
        assert await tables_recorded_in_internal() == [{"table_name": "test"}]

    if schema_version_returns_none:
        # Simulate a database whose schema_version cannot be read,
        # forcing the schema-hash fallback path
        async def schema_version_none(self):
            return None

        with patch(
            "datasette.database.Database.schema_version", new=schema_version_none
        ):
            await run_scenario()
    else:
        await run_scenario()