diff --git a/datasette/app.py b/datasette/app.py index cf0b6ab7..06543761 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -176,10 +176,13 @@ class BaseView(RenderMixin): try: cursor = conn.cursor() cursor.execute(sql, params or {}) - if self.max_returned_rows and truncate: - rows = cursor.fetchmany(self.max_returned_rows + 1) - truncated = len(rows) > self.max_returned_rows - rows = rows[:self.max_returned_rows] + max_returned_rows = self.max_returned_rows + if max_returned_rows == self.page_size: + max_returned_rows += 1 + if max_returned_rows and truncate: + rows = cursor.fetchmany(max_returned_rows + 1) + truncated = len(rows) > max_returned_rows + rows = rows[:max_returned_rows] else: rows = cursor.fetchall() truncated = False diff --git a/tests/fixtures.py b/tests/fixtures.py index 564306b1..74be52b7 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -9,7 +9,7 @@ import tempfile import time -def app_client(sql_time_limit_ms=None): +def app_client(sql_time_limit_ms=None, max_returned_rows=None): with tempfile.TemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, 'test_tables.db') conn = sqlite3.connect(filepath) @@ -21,7 +21,7 @@ def app_client(sql_time_limit_ms=None): ds = Datasette( [filepath], page_size=50, - max_returned_rows=100, + max_returned_rows=max_returned_rows or 100, sql_time_limit_ms=sql_time_limit_ms or 20, metadata=METADATA, plugins_dir=plugins_dir, @@ -38,6 +38,10 @@ def app_client_longer_time_limit(): yield from app_client(200) +def app_client_returend_rows_matches_page_size(): + yield from app_client(max_returned_rows=50) + + def generate_compound_rows(num): for a, b, c in itertools.islice( itertools.product(string.ascii_lowercase, repeat=3), num diff --git a/tests/test_api.py b/tests/test_api.py index 39d8f132..0a804645 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ from .fixtures import ( app_client, app_client_longer_time_limit, + app_client_returend_rows_matches_page_size, generate_compound_rows, generate_sortable_rows, METADATA, @@ -9,6 +10,7 @@ import pytest pytest.fixture(scope='module')(app_client) pytest.fixture(scope='module')(app_client_longer_time_limit) +pytest.fixture(scope='module')(app_client_returend_rows_matches_page_size) def test_homepage(app_client): @@ -691,3 +693,16 @@ def test_plugins_json(app_client): 'static': False, 'templates': False } in response.json + + +def test_page_size_matching_max_returned_rows(app_client_returend_rows_matches_page_size): + fetched = [] + path = '/test_tables/no_primary_key.json' + while path: + response = app_client_returend_rows_matches_page_size.get( + path, gather_request=False + ) + fetched.extend(response.json['rows']) + assert len(response.json['rows']) in (1, 50) + path = response.json['next_url'] + assert 201 == len(fetched)