diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index f819aa82..33decbfc 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -678,10 +678,12 @@ async def resolve_table_and_format( return table_and_format, None -def path_with_format(request, format, extra_qs=None): +def path_with_format(request, format, extra_qs=None, replace_format=None): qs = extra_qs or {} path = request.path - if "." in request.path: + if replace_format and path.endswith(".{}".format(replace_format)): + path = path[: -(1 + len(replace_format))] + if "." in path: qs["_format"] = format else: path = "{}.{}".format(path, format) diff --git a/datasette/views/base.py b/datasette/views/base.py index 4432ddca..6ca78934 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -1,5 +1,6 @@ import asyncio import csv +import hashlib import re import time import urllib @@ -14,6 +15,7 @@ from datasette.utils import ( InvalidSql, LimitedWriter, call_with_supported_arguments, + path_from_row_pks, path_with_added_args, path_with_removed_args, path_with_format, @@ -310,6 +312,40 @@ class DataView(BaseView): first = False next = data.get("next") for row in data["rows"]: + if any(isinstance(r, bytes) for r in row): + new_row = [] + for column, cell in zip(headings, row): + if isinstance(cell, bytes): + # If this is a table page, use .urls.row_blob() + if data.get("table"): + pks = data.get("primary_keys") or [] + cell = self.ds.absolute_url( + request, + self.ds.urls.row_blob( + database, + data["table"], + path_from_row_pks(row, pks, not pks), + column, + ), + ) + else: + # Otherwise generate URL for this query + cell = self.ds.absolute_url( + request, + path_with_format( + request, + "blob", + extra_qs={ + "_blob_column": column, + "_blob_hash": hashlib.sha256( + cell + ).hexdigest(), + }, + replace_format="csv", + ), + ) + new_row.append(cell) + row = new_row if not expanded_columns: # Simple path await writer.writerow(row) diff --git a/tests/test_csv.py b/tests/test_csv.py index 863659f7..1a701828 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -1,5 +1,3 @@ -import textwrap -import pytest from .fixtures import ( # noqa app_client, app_client_csv_max_mb_one, @@ -80,19 +78,27 @@ def test_table_csv_with_nullable_labels(app_client): assert EXPECTED_TABLE_WITH_NULLABLE_LABELS_CSV == response.text -@pytest.mark.xfail def test_table_csv_blob_columns(app_client): response = app_client.get("/fixtures/binary_data.csv") assert response.status == 200 assert "text/plain; charset=utf-8" == response.headers["content-type"] - assert EXPECTED_TABLE_CSV == textwrap.dedent( - """ - rowid,data - 1,/fixtures/binary_data/-/blob/1/data.blob - 2,/fixtures/binary_data/-/blob/1/data.blob - """.strip().replace( - "\n", "\r\n" - ) + assert response.text == ( + "rowid,data\r\n" + "1,http://localhost/fixtures/binary_data/1.blob?_blob_column=data\r\n" + "2,http://localhost/fixtures/binary_data/2.blob?_blob_column=data\r\n" + "3,\r\n" + ) + + +def test_custom_sql_csv_blob_columns(app_client): + response = app_client.get("/fixtures.csv?sql=select+rowid,+data+from+binary_data") + assert response.status == 200 + assert "text/plain; charset=utf-8" == response.headers["content-type"] + assert response.text == ( + "rowid,data\r\n" + '1,"http://localhost/fixtures.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=f3088978da8f9aea479ffc7f631370b968d2e855eeb172bea7f6c7a04262bb6d"\r\n' + '2,"http://localhost/fixtures.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=b835b0483cedb86130b9a2c280880bf5fadc5318ddf8c18d0df5204d40df1724"\r\n' + "3,\r\n" ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0e2af098..bae3b685 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -386,6 +386,14 @@ def test_path_with_format(path, format, extra_qs, expected): assert expected == actual +def test_path_with_format_replace_format(): + request = Request.fake("/foo/bar.csv") + assert utils.path_with_format(request, "blob") == "/foo/bar.csv?_format=blob" + assert ( + utils.path_with_format(request, "blob", replace_format="csv") == "/foo/bar.blob" + ) + + @pytest.mark.parametrize( "bytes,expected", [