Moved all SQLite queries to threads

SQLite operations are blocking, but we're running everything in Sanic, an
asyncio web framework, so blocking operations are bad - a long-running DB
operation could hold up the entire server.

Instead, I've moved all SQLite operations into threads. These are managed by a
concurrent.futures ThreadPoolExecutor. This means I can run up to X queries in
parallel, and I can continue to queue up additional incoming HTTP traffic
while the threadpool is busy.

Each thread is responsible for managing its own SQLite connections - one per
database. These are cached in a threadlocal.

Since we are working with immutable, read-only SQLite databases it should be
safe to share SQLite objects across threads. On this assumption I'm using the
check_same_thread=False option. Opening a database connection looks like this:

    conn = sqlite3.connect(
        'file:filename.db?immutable=1',
        uri=True,
        check_same_thread=False,
    )

The following articles were helpful in figuring this out:

* https://pymotw.com/3/asyncio/executors.html
* https://marlinux.wordpress.com/2017/05/19/python-3-6-asyncio-sqlalchemy/

Closes #45. Refs #38.
pull/81/head
Simon Willison 2017-11-04 19:21:44 -07:00
rodzic 1fc75809a6
commit 31b21f5c5e
1 zmienionych plików z 58 dodań i 47 usunięć

Wyświetl plik

@ -8,6 +8,9 @@ import sqlite3
from contextlib import contextmanager
from pathlib import Path
from functools import wraps
from concurrent import futures
import asyncio
import threading
import urllib.parse
import json
import base64
@ -22,19 +25,7 @@ DB_GLOBS = ('*.db', '*.sqlite', '*.sqlite3')
HASH_BLOCK_SIZE = 1024 * 1024
SQL_TIME_LIMIT_MS = 1000
conns = {}
def get_conn(name):
if name not in conns:
info = ensure_build_metadata()[name]
conns[name] = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True
)
conns[name].row_factory = sqlite3.Row
conns[name].text_factory = lambda x: str(x, 'utf-8', 'replace')
return conns[name]
connections = threading.local()
def ensure_build_metadata(regenerate=False):
@ -80,8 +71,9 @@ def ensure_build_metadata(regenerate=False):
class BaseView(HTTPMethodView):
template = None
def __init__(self, jinja):
def __init__(self, jinja, executor):
self.jinja = jinja
self.executor = executor
def redirect(self, request, path):
if request.query_string:
@ -92,6 +84,40 @@ class BaseView(HTTPMethodView):
r.headers['Link'] = '<{}>; rel=preload'.format(path)
return r
async def pks_for_table(self, name, table):
rows = [
row for row in await self.execute(
name,
'PRAGMA table_info("{}")'.format(table)
)
if row[-1]
]
rows.sort(key=lambda row: row[-1])
return [str(r[1]) for r in rows]
async def execute(self, db_name, sql):
"""Executes sql against db_name in a thread"""
def sql_operation_in_thread():
conn = getattr(connections, db_name, None)
if not conn:
info = ensure_build_metadata()[db_name]
conn = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True,
check_same_thread=False,
)
conn.row_factory = sqlite3.Row
conn.text_factory = lambda x: str(x, 'utf-8', 'replace')
setattr(connections, db_name, conn)
with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS):
rows = conn.execute(sql)
return rows
return await asyncio.get_event_loop().run_in_executor(
self.executor, sql_operation_in_thread
)
async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = resolve_db_name(db_name, **kwargs)
if should_redirect:
@ -106,7 +132,7 @@ class BaseView(HTTPMethodView):
extra_template_data = {}
start = time.time()
try:
data, extra_template_data = self.data(
data, extra_template_data = await self.data(
request, name, hash, **kwargs
)
except sqlite3.OperationalError as e:
@ -154,8 +180,9 @@ class BaseView(HTTPMethodView):
class IndexView(HTTPMethodView):
def __init__(self, jinja):
def __init__(self, jinja, executor):
self.jinja = jinja
self.executor = executor
async def get(self, request):
databases = []
@ -188,11 +215,9 @@ async def favicon(request):
class DatabaseView(BaseView):
template = 'database.html'
def data(self, request, name, hash):
conn = get_conn(name)
async def data(self, request, name, hash):
sql = request.args.get('sql') or 'select * from sqlite_master'
with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS):
rows = conn.execute(sql)
rows = await self.execute(name, sql)
columns = [r[0] for r in rows.description]
return {
'database': name,
@ -216,8 +241,7 @@ class DatabaseDownload(BaseView):
class TableView(BaseView):
template = 'table.html'
def data(self, request, name, hash, table):
conn = get_conn(name)
async def data(self, request, name, hash, table):
table = urllib.parse.unquote_plus(table)
if request.args:
where_clause, params = build_where_clause(request.args)
@ -228,12 +252,11 @@ class TableView(BaseView):
sql = 'select * from "{}" limit 50'.format(table)
params = []
with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS):
rows = conn.execute(sql, params)
rows = await self.execute(name, sql)
columns = [r[0] for r in rows.description]
rows = list(rows)
pks = pks_for_table(conn, table)
pks = await self.pks_for_table(name, table)
info = ensure_build_metadata()
total_rows = info[name]['tables'].get(table)
return {
@ -252,11 +275,10 @@ class TableView(BaseView):
class RowView(BaseView):
template = 'row.html'
def data(self, request, name, hash, table, pk_path):
conn = get_conn(name)
async def data(self, request, name, hash, table, pk_path):
table = urllib.parse.unquote_plus(table)
pk_values = compound_pks_from_path(pk_path)
pks = pks_for_table(conn, table)
pks = await self.pks_for_table(name, table)
wheres = [
'"{}"=?'.format(pk)
for pk in pks
@ -264,9 +286,8 @@ class RowView(BaseView):
sql = 'select * from "{}" where {}'.format(
table, ' AND '.join(wheres)
)
rows = conn.execute(sql, pk_values)
rows = await self.execute(name, sql)
columns = [r[0] for r in rows.description]
pks = pks_for_table(conn, table)
rows = list(rows)
if not rows:
raise NotFound('Record not found: {}'.format(pk_values))
@ -322,17 +343,6 @@ def compound_pks_from_path(path):
]
def pks_for_table(conn, table):
rows = [
row for row in conn.execute(
'PRAGMA table_info("{}")'.format(table)
).fetchall()
if row[-1]
]
rows.sort(key=lambda row: row[-1])
return [str(r[1]) for r in rows]
def path_from_row_pks(row, pks):
if not pks:
return ''
@ -410,7 +420,7 @@ def sqlite_timelimit(conn, ms):
conn.set_progress_handler(None, 10000)
def app_factory(files):
def app_factory(files, num_threads=3):
app = Sanic(__name__)
jinja = SanicJinja2(
app,
@ -418,22 +428,23 @@ def app_factory(files):
str(app_root / 'datasite' / 'templates')
])
)
app.add_route(IndexView.as_view(jinja), '/')
executor = futures.ThreadPoolExecutor(max_workers=num_threads)
app.add_route(IndexView.as_view(jinja, executor), '/')
app.add_route(favicon, '/favicon.ico')
app.add_route(
DatabaseView.as_view(jinja),
DatabaseView.as_view(jinja, executor),
'/<db_name:[^/\.]+?><as_json:(.jsono?)?$>'
)
app.add_route(
DatabaseDownload.as_view(jinja),
DatabaseDownload.as_view(jinja, executor),
'/<db_name:[^/]+?><as_db:(\.db)$>'
)
app.add_route(
TableView.as_view(jinja),
TableView.as_view(jinja, executor),
'/<db_name:[^/]+>/<table:[^/]+?><as_json:(.jsono?)?$>'
)
app.add_route(
RowView.as_view(jinja),
RowView.as_view(jinja, executor),
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>'
)
return app