diff --git a/datasette/app.py b/datasette/app.py index 690d7d5a..9d7c2a02 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -14,10 +14,11 @@ import json import hashlib import time from .utils import ( - build_where_clause, + build_where_clauses, CustomJSONEncoder, InvalidSql, path_from_row_pks, + path_with_added_args, compound_pks_from_path, sqlite_timelimit, validate_sql_select, @@ -78,6 +79,7 @@ class BaseView(HTTPMethodView): self.files = datasette.files self.jinja = datasette.jinja self.executor = datasette.executor + self.page_size = datasette.page_size self.cache_headers = datasette.cache_headers def redirect(self, request, path): @@ -270,16 +272,52 @@ class TableView(BaseView): select = '*' order_by = ', '.join(pks) - if request.args: - where_clause, params = build_where_clause(request.args) - sql = 'select {} from "{}" where {} order by {} limit 50'.format( - select, table, where_clause, order_by - ) + # Special args start with _ and do not contain a __ + # That's so if there is a column that starts with _ + # it can still be queried using ?_col__exact=blah + special_args = {} + other_args = {} + for key, value in request.args.items(): + if key.startswith('_') and '__' not in key: + special_args[key] = value[0] + else: + other_args[key] = value[0] + + if other_args: + where_clauses, params = build_where_clauses(other_args) else: - sql = 'select {} from "{}" order by {} limit 50'.format( - select, table, order_by - ) - params = [] + where_clauses = [] + params = {} + + after = special_args.get('_after') + if after: + if use_rowid: + where_clauses.append( + 'rowid > :p{}'.format( + len(params), + ) + ) + params['p{}'.format(len(params))] = after + else: + pk_values = compound_pks_from_path(after) + if len(pk_values) == len(pks): + param_counter = len(params) + for pk, value in zip(pks, pk_values): + where_clauses.append( + '"{}" > :p{}'.format( + pk, param_counter, + ) + ) + params['p{}'.format(param_counter)] = value + param_counter += 1 + + where_clause = '' + if where_clauses: + where_clause = 'where {}'.format(' and '.join(where_clauses)) + + sql = 'select {} from "{}" {} order by {} limit {}'.format( + select, table, where_clause, order_by, self.page_size + 1, + ) rows = await self.execute(name, sql, params) @@ -290,20 +328,27 @@ class TableView(BaseView): rows = list(rows) info = ensure_build_metadata(self.files) total_rows = info[name]['tables'].get(table) + after = None + after_link = None + if len(rows) > self.page_size: + after = path_from_row_pks(rows[-2], pks, use_rowid) + after_link = path_with_added_args(request, {'_after': after}) return { 'database': name, 'table': table, - 'rows': rows, + 'rows': rows[:self.page_size], 'total_rows': total_rows, 'columns': columns, 'primary_keys': pks, 'sql': sql, 'sql_params': params, + 'after': after, }, lambda: { 'database_hash': hash, 'use_rowid': use_rowid, 'row_link': lambda row: path_from_row_pks(row, pks, use_rowid), 'display_columns': display_columns, + 'after_link': after_link, } @@ -381,13 +426,14 @@ def resolve_db_name(files, db_name, **kwargs): class Datasette: - def __init__(self, files, num_threads=3, cache_headers=True): + def __init__(self, files, num_threads=3, cache_headers=True, page_size=50): self.files = files self.num_threads = num_threads self.executor = futures.ThreadPoolExecutor( max_workers=num_threads ) self.cache_headers = cache_headers + self.page_size = page_size def app(self): app = Sanic(__name__) diff --git a/datasette/templates/table.html b/datasette/templates/table.html index 2f167b42..c8f08d5f 100644 --- a/datasette/templates/table.html +++ b/datasette/templates/table.html @@ -34,5 +34,8 @@ td { {% endfor %} +{% if after_link %} +

Next page

+{% endif %} {% if took_ms %}Took {{ took_ms }}{% endif %} {% endblock %} diff --git a/datasette/utils.py b/datasette/utils.py index 000f86d2..d9ae3b6b 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -23,10 +23,10 @@ def path_from_row_pks(row, pks, use_rowid): return ','.join(bits) -def build_where_clause(args): +def build_where_clauses(args): sql_bits = [] params = {} - for i, (key, values) in enumerate(sorted(args.items())): + for i, (key, value) in enumerate(sorted(args.items())): if '__' in key: column, lookup = key.rsplit('__', 1) else: @@ -45,7 +45,6 @@ def build_where_clause(args): 'like': '"{}" like :{}', }[lookup] numeric_operators = {'gt', 'gte', 'lt', 'lte'} - value = values[0] value_convert = { 'contains': lambda s: '%{}%'.format(s), 'endswith': lambda s: '%{}'.format(s), @@ -59,8 +58,7 @@ def build_where_clause(args): template.format(column, param_id) ) params[param_id] = converted - where_clause = ' and '.join(sql_bits) - return where_clause, params + return sql_bits, params class CustomJSONEncoder(json.JSONEncoder): @@ -103,3 +101,9 @@ def validate_sql_select(sql): raise InvalidSql('Statement must begin with SELECT') if 'pragma' in sql: raise InvalidSql('Statement may not contain PRAGMA') + + +def path_with_added_args(request, args): + current = request.raw_args.copy() + current.update(args) + return request.path + '?' + urllib.parse.urlencode(current) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5a3f26a5..e37fd325 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -90,7 +90,8 @@ def test_custom_json_encoder(obj, expected): ), ]) def test_build_where(args, expected_where, expected_params): - actual_where, actual_params = utils.build_where_clause(args) + sql_bits, actual_params = utils.build_where_clauses(args) + actual_where = ' and '.join(sql_bits) assert expected_where == actual_where assert { 'p{}'.format(i): param