Refactored run_sanity_checks to check_connection(conn), refs #674

prepare-connection-datasette
Simon Willison 2020-02-15 09:56:48 -08:00
rodzic f1442a8151
commit d3f2fade88
4 zmienionych plików z 60 dodań i 21 usunięć

Wyświetl plik

@ -216,26 +216,6 @@ class Datasette:
def remove_database(self, name): def remove_database(self, name):
self.databases.pop(name) self.databases.pop(name)
async def run_sanity_checks(self):
# Only one check right now, for Spatialite
for database_name, database in self.databases.items():
# Run pragma_info on every table
for table in await database.table_names():
try:
await self.execute(
database_name,
"PRAGMA table_info({});".format(escape_sqlite(table)),
)
except sqlite3.OperationalError as e:
if e.args[0] == "no such module: VirtualSpatialIndex":
raise click.UsageError(
"It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module."
"\n\nRead more: https://datasette.readthedocs.io/en/latest/spatialite.html"
)
else:
raise
def config(self, key): def config(self, key):
return self._config.get(key, None) return self._config.get(key, None)

Wyświetl plik

@ -10,6 +10,9 @@ from subprocess import call
import sys import sys
from .app import Datasette, DEFAULT_CONFIG, CONFIG_OPTIONS, pm from .app import Datasette, DEFAULT_CONFIG, CONFIG_OPTIONS, pm
from .utils import ( from .utils import (
check_connection,
ConnectionProblem,
SpatialiteConnectionProblem,
temporary_docker_directory, temporary_docker_directory,
value_as_boolean, value_as_boolean,
StaticMount, StaticMount,
@ -369,7 +372,25 @@ def serve(
version_note=version_note, version_note=version_note,
) )
# Run async sanity checks - but only if we're not under pytest # Run async sanity checks - but only if we're not under pytest
asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks()) asyncio.get_event_loop().run_until_complete(check_databases(ds))
# Start the server # Start the server
uvicorn.run(ds.app(), host=host, port=port, log_level="info") uvicorn.run(ds.app(), host=host, port=port, log_level="info")
async def check_databases(ds):
# Run check_connection against every connected database
# to confirm they are all usable
for database in list(ds.databases.values()):
try:
await database.execute_against_connection_in_thread(check_connection)
except SpatialiteConnectionProblem:
raise click.UsageError(
"It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module."
"\n\nRead more: https://datasette.readthedocs.io/en/latest/spatialite.html"
)
except ConnectionProblem as e:
raise click.UsageError(
"Connection to {} failed check: {}".format(database.path, str(e.args[0]))
)

Wyświetl plik

@ -790,3 +790,28 @@ class RequestParameters(dict):
def getlist(self, name, default=None): def getlist(self, name, default=None):
"Return full list" "Return full list"
return super().get(name, default) return super().get(name, default)
class ConnectionProblem(Exception):
pass
class SpatialiteConnectionProblem(ConnectionProblem):
pass
def check_connection(conn):
tables = [
r[0]
for r in conn.execute(
"select name from sqlite_master where type='table'"
).fetchall()
]
for table in tables:
try:
conn.execute("PRAGMA table_info({});".format(escape_sqlite(table)),)
except sqlite3.OperationalError as e:
if e.args[0] == "no such module: VirtualSpatialIndex":
raise SpatialiteConnectionProblem(e)
else:
raise ConnectionProblem(e)

Wyświetl plik

@ -7,6 +7,7 @@ from datasette.utils.asgi import Request
from datasette.filters import Filters from datasette.filters import Filters
import json import json
import os import os
import pathlib
import pytest import pytest
import sqlite3 import sqlite3
import tempfile import tempfile
@ -410,3 +411,15 @@ def test_format_bytes(bytes, expected):
) )
def test_escape_fts(query, expected): def test_escape_fts(query, expected):
assert expected == utils.escape_fts(query) assert expected == utils.escape_fts(query)
def test_check_connection_spatialite_raises():
path = str(pathlib.Path(__file__).parent / "spatialite.db")
conn = sqlite3.connect(path)
with pytest.raises(utils.SpatialiteConnectionProblem):
utils.check_connection(conn)
def test_check_connection_passes():
conn = sqlite3.connect(":memory:")
utils.check_connection(conn)