Implemented cursor-based pagination for table view

Closes #5
pull/81/head
Simon Willison 2017-11-10 12:41:14 -08:00
rodzic e9fce44195
commit 21c9c04310
4 zmienionych plików z 72 dodań i 18 usunięć

Wyświetl plik

@ -14,10 +14,11 @@ import json
import hashlib import hashlib
import time import time
from .utils import ( from .utils import (
build_where_clause, build_where_clauses,
CustomJSONEncoder, CustomJSONEncoder,
InvalidSql, InvalidSql,
path_from_row_pks, path_from_row_pks,
path_with_added_args,
compound_pks_from_path, compound_pks_from_path,
sqlite_timelimit, sqlite_timelimit,
validate_sql_select, validate_sql_select,
@ -78,6 +79,7 @@ class BaseView(HTTPMethodView):
self.files = datasette.files self.files = datasette.files
self.jinja = datasette.jinja self.jinja = datasette.jinja
self.executor = datasette.executor self.executor = datasette.executor
self.page_size = datasette.page_size
self.cache_headers = datasette.cache_headers self.cache_headers = datasette.cache_headers
def redirect(self, request, path): def redirect(self, request, path):
@ -270,16 +272,52 @@ class TableView(BaseView):
select = '*' select = '*'
order_by = ', '.join(pks) order_by = ', '.join(pks)
if request.args: # Special args start with _ and do not contain a __
where_clause, params = build_where_clause(request.args) # That's so if there is a column that starts with _
sql = 'select {} from "{}" where {} order by {} limit 50'.format( # it can still be queried using ?_col__exact=blah
select, table, where_clause, order_by special_args = {}
) other_args = {}
for key, value in request.args.items():
if key.startswith('_') and '__' not in key:
special_args[key] = value[0]
else: else:
sql = 'select {} from "{}" order by {} limit 50'.format( other_args[key] = value[0]
select, table, order_by
if other_args:
where_clauses, params = build_where_clauses(other_args)
else:
where_clauses = []
params = {}
after = special_args.get('_after')
if after:
if use_rowid:
where_clauses.append(
'rowid > :p{}'.format(
len(params),
)
)
params['p{}'.format(len(params))] = after
else:
pk_values = compound_pks_from_path(after)
if len(pk_values) == len(pks):
param_counter = len(params)
for pk, value in zip(pks, pk_values):
where_clauses.append(
'"{}" > :p{}'.format(
pk, param_counter,
)
)
params['p{}'.format(param_counter)] = value
param_counter += 1
where_clause = ''
if where_clauses:
where_clause = 'where {}'.format(' and '.join(where_clauses))
sql = 'select {} from "{}" {} order by {} limit {}'.format(
select, table, where_clause, order_by, self.page_size + 1,
) )
params = []
rows = await self.execute(name, sql, params) rows = await self.execute(name, sql, params)
@ -290,20 +328,27 @@ class TableView(BaseView):
rows = list(rows) rows = list(rows)
info = ensure_build_metadata(self.files) info = ensure_build_metadata(self.files)
total_rows = info[name]['tables'].get(table) total_rows = info[name]['tables'].get(table)
after = None
after_link = None
if len(rows) > self.page_size:
after = path_from_row_pks(rows[-2], pks, use_rowid)
after_link = path_with_added_args(request, {'_after': after})
return { return {
'database': name, 'database': name,
'table': table, 'table': table,
'rows': rows, 'rows': rows[:self.page_size],
'total_rows': total_rows, 'total_rows': total_rows,
'columns': columns, 'columns': columns,
'primary_keys': pks, 'primary_keys': pks,
'sql': sql, 'sql': sql,
'sql_params': params, 'sql_params': params,
'after': after,
}, lambda: { }, lambda: {
'database_hash': hash, 'database_hash': hash,
'use_rowid': use_rowid, 'use_rowid': use_rowid,
'row_link': lambda row: path_from_row_pks(row, pks, use_rowid), 'row_link': lambda row: path_from_row_pks(row, pks, use_rowid),
'display_columns': display_columns, 'display_columns': display_columns,
'after_link': after_link,
} }
@ -381,13 +426,14 @@ def resolve_db_name(files, db_name, **kwargs):
class Datasette: class Datasette:
def __init__(self, files, num_threads=3, cache_headers=True): def __init__(self, files, num_threads=3, cache_headers=True, page_size=50):
self.files = files self.files = files
self.num_threads = num_threads self.num_threads = num_threads
self.executor = futures.ThreadPoolExecutor( self.executor = futures.ThreadPoolExecutor(
max_workers=num_threads max_workers=num_threads
) )
self.cache_headers = cache_headers self.cache_headers = cache_headers
self.page_size = page_size
def app(self): def app(self):
app = Sanic(__name__) app = Sanic(__name__)

Wyświetl plik

@ -34,5 +34,8 @@ td {
</tr> </tr>
{% endfor %} {% endfor %}
</table> </table>
{% if after_link %}
<p><a href="{{ after_link }}">Next page</a></p>
{% endif %}
{% if took_ms %}<small>Took {{ took_ms }}</small>{% endif %} {% if took_ms %}<small>Took {{ took_ms }}</small>{% endif %}
{% endblock %} {% endblock %}

Wyświetl plik

@ -23,10 +23,10 @@ def path_from_row_pks(row, pks, use_rowid):
return ','.join(bits) return ','.join(bits)
def build_where_clause(args): def build_where_clauses(args):
sql_bits = [] sql_bits = []
params = {} params = {}
for i, (key, values) in enumerate(sorted(args.items())): for i, (key, value) in enumerate(sorted(args.items())):
if '__' in key: if '__' in key:
column, lookup = key.rsplit('__', 1) column, lookup = key.rsplit('__', 1)
else: else:
@ -45,7 +45,6 @@ def build_where_clause(args):
'like': '"{}" like :{}', 'like': '"{}" like :{}',
}[lookup] }[lookup]
numeric_operators = {'gt', 'gte', 'lt', 'lte'} numeric_operators = {'gt', 'gte', 'lt', 'lte'}
value = values[0]
value_convert = { value_convert = {
'contains': lambda s: '%{}%'.format(s), 'contains': lambda s: '%{}%'.format(s),
'endswith': lambda s: '%{}'.format(s), 'endswith': lambda s: '%{}'.format(s),
@ -59,8 +58,7 @@ def build_where_clause(args):
template.format(column, param_id) template.format(column, param_id)
) )
params[param_id] = converted params[param_id] = converted
where_clause = ' and '.join(sql_bits) return sql_bits, params
return where_clause, params
class CustomJSONEncoder(json.JSONEncoder): class CustomJSONEncoder(json.JSONEncoder):
@ -103,3 +101,9 @@ def validate_sql_select(sql):
raise InvalidSql('Statement must begin with SELECT') raise InvalidSql('Statement must begin with SELECT')
if 'pragma' in sql: if 'pragma' in sql:
raise InvalidSql('Statement may not contain PRAGMA') raise InvalidSql('Statement may not contain PRAGMA')
def path_with_added_args(request, args):
current = request.raw_args.copy()
current.update(args)
return request.path + '?' + urllib.parse.urlencode(current)

Wyświetl plik

@ -90,7 +90,8 @@ def test_custom_json_encoder(obj, expected):
), ),
]) ])
def test_build_where(args, expected_where, expected_params): def test_build_where(args, expected_where, expected_params):
actual_where, actual_params = utils.build_where_clause(args) sql_bits, actual_params = utils.build_where_clauses(args)
actual_where = ' and '.join(sql_bits)
assert expected_where == actual_where assert expected_where == actual_where
assert { assert {
'p{}'.format(i): param 'p{}'.format(i): param