Mirror of https://github.com/simonw/datasette
Got CSV working again
parent e772eb6429
commit 921faae104
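The diff below pulls the CSV streaming logic out of the DataView.as_csv method and into a standalone stream_csv(datasette, fetch_data, request, database) function, which the new table_view_traced code path then calls with its own fetch_data coroutine. As a reading aid, here is a minimal sketch of the fetch_data contract that stream_csv expects, using placeholder values; the dict keys are the ones read by the code in this commit, everything else is illustrative only.

# Sketch of the fetch_data contract (placeholder values; keys taken from this diff):
# an async callable taking the request plus optional keyword arguments such as
# _next, and returning a 3-tuple whose first element is the data dict.
async def fetch_data(request, _next=None):
    data = {
        "rows": [],              # row tuples to be written out as CSV
        "columns": [],           # column names used for the CSV header row
        "expanded_columns": [],  # columns whose cells are {"value": ..., "label": ...} dicts
        "next": None,            # pagination token; stream_csv keeps fetching pages while set
    }
    return data, None, None      # trailing values are ignored by stream_csv

# response = await stream_csv(datasette, fetch_data, request, "database_name")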
@@ -175,172 +175,7 @@ class DataView(BaseView):
         raise NotImplementedError

     async def as_csv(self, request, database):
-        kwargs = {}
-        stream = request.args.get("_stream")
-        # Do not calculate facets or counts:
-        extra_parameters = [
-            "{}=1".format(key)
-            for key in ("_nofacet", "_nocount")
-            if not request.args.get(key)
-        ]
-        if extra_parameters:
-            # Replace request object with a new one with modified scope
-            if not request.query_string:
-                new_query_string = "&".join(extra_parameters)
-            else:
-                new_query_string = (
-                    request.query_string + "&" + "&".join(extra_parameters)
-                )
-            new_scope = dict(
-                request.scope, query_string=new_query_string.encode("latin-1")
-            )
-            receive = request.receive
-            request = Request(new_scope, receive)
-        if stream:
-            # Some quick soundness checks
-            if not self.ds.setting("allow_csv_stream"):
-                raise BadRequest("CSV streaming is disabled")
-            if request.args.get("_next"):
-                raise BadRequest("_next not allowed for CSV streaming")
-            kwargs["_size"] = "max"
-        # Fetch the first page
-        try:
-            response_or_template_contexts = await self.data(request)
-            if isinstance(response_or_template_contexts, Response):
-                return response_or_template_contexts
-            elif len(response_or_template_contexts) == 4:
-                data, _, _, _ = response_or_template_contexts
-            else:
-                data, _, _ = response_or_template_contexts
-        except (sqlite3.OperationalError, InvalidSql) as e:
-            raise DatasetteError(str(e), title="Invalid SQL", status=400)
-
-        except sqlite3.OperationalError as e:
-            raise DatasetteError(str(e))
-
-        except DatasetteError:
-            raise
-
-        # Convert rows and columns to CSV
-        headings = data["columns"]
-        # if there are expanded_columns we need to add additional headings
-        expanded_columns = set(data.get("expanded_columns") or [])
-        if expanded_columns:
-            headings = []
-            for column in data["columns"]:
-                headings.append(column)
-                if column in expanded_columns:
-                    headings.append(f"{column}_label")
-
-        content_type = "text/plain; charset=utf-8"
-        preamble = ""
-        postamble = ""
-
-        trace = request.args.get("_trace")
-        if trace:
-            content_type = "text/html; charset=utf-8"
-            preamble = (
-                "<html><head><title>CSV debug</title></head>"
-                '<body><textarea style="width: 90%; height: 70vh">'
-            )
-            postamble = "</textarea></body></html>"
-
-        async def stream_fn(r):
-            nonlocal data, trace
-            limited_writer = LimitedWriter(r, self.ds.setting("max_csv_mb"))
-            if trace:
-                await limited_writer.write(preamble)
-                writer = csv.writer(EscapeHtmlWriter(limited_writer))
-            else:
-                writer = csv.writer(limited_writer)
-            first = True
-            next = None
-            while first or (next and stream):
-                try:
-                    kwargs = {}
-                    if next:
-                        kwargs["_next"] = next
-                    if not first:
-                        data, _, _ = await self.data(request, **kwargs)
-                    if first:
-                        if request.args.get("_header") != "off":
-                            await writer.writerow(headings)
-                        first = False
-                    next = data.get("next")
-                    for row in data["rows"]:
-                        if any(isinstance(r, bytes) for r in row):
-                            new_row = []
-                            for column, cell in zip(headings, row):
-                                if isinstance(cell, bytes):
-                                    # If this is a table page, use .urls.row_blob()
-                                    if data.get("table"):
-                                        pks = data.get("primary_keys") or []
-                                        cell = self.ds.absolute_url(
-                                            request,
-                                            self.ds.urls.row_blob(
-                                                database,
-                                                data["table"],
-                                                path_from_row_pks(row, pks, not pks),
-                                                column,
-                                            ),
-                                        )
-                                    else:
-                                        # Otherwise generate URL for this query
-                                        url = self.ds.absolute_url(
-                                            request,
-                                            path_with_format(
-                                                request=request,
-                                                format="blob",
-                                                extra_qs={
-                                                    "_blob_column": column,
-                                                    "_blob_hash": hashlib.sha256(
-                                                        cell
-                                                    ).hexdigest(),
-                                                },
-                                                replace_format="csv",
-                                            ),
-                                        )
-                                        cell = url.replace("&_nocount=1", "").replace(
-                                            "&_nofacet=1", ""
-                                        )
-                                new_row.append(cell)
-                            row = new_row
-                        if not expanded_columns:
-                            # Simple path
-                            await writer.writerow(row)
-                        else:
-                            # Look for {"value": "label": } dicts and expand
-                            new_row = []
-                            for heading, cell in zip(data["columns"], row):
-                                if heading in expanded_columns:
-                                    if cell is None:
-                                        new_row.extend(("", ""))
-                                    else:
-                                        assert isinstance(cell, dict)
-                                        new_row.append(cell["value"])
-                                        new_row.append(cell["label"])
-                                else:
-                                    new_row.append(cell)
-                            await writer.writerow(new_row)
-                except Exception as e:
-                    sys.stderr.write("Caught this error: {}\n".format(e))
-                    sys.stderr.flush()
-                    await r.write(str(e))
-                    return
-            await limited_writer.write(postamble)
-
-        headers = {}
-        if self.ds.cors:
-            add_cors_headers(headers)
-        if request.args.get("_dl", None):
-            if not trace:
-                content_type = "text/csv; charset=utf-8"
-            disposition = 'attachment; filename="{}.csv"'.format(
-                request.url_vars.get("table", database)
-            )
-            headers["content-disposition"] = disposition
-
-        return AsgiStream(stream_fn, headers=headers, content_type=content_type)
+        return await stream_csv(self.ds, self.data, request, database)

     async def get(self, request):
         db = await self.ds.resolve_database(request)
@@ -543,3 +378,169 @@ class DataView(BaseView):

 def _error(messages, status=400):
     return Response.json({"ok": False, "errors": messages}, status=status)
+
+
+async def stream_csv(datasette, fetch_data, request, database):
+    kwargs = {}
+    stream = request.args.get("_stream")
+    # Do not calculate facets or counts:
+    extra_parameters = [
+        "{}=1".format(key)
+        for key in ("_nofacet", "_nocount")
+        if not request.args.get(key)
+    ]
+    if extra_parameters:
+        # Replace request object with a new one with modified scope
+        if not request.query_string:
+            new_query_string = "&".join(extra_parameters)
+        else:
+            new_query_string = request.query_string + "&" + "&".join(extra_parameters)
+        new_scope = dict(request.scope, query_string=new_query_string.encode("latin-1"))
+        receive = request.receive
+        request = Request(new_scope, receive)
+    if stream:
+        # Some quick soundness checks
+        if not datasette.setting("allow_csv_stream"):
+            raise BadRequest("CSV streaming is disabled")
+        if request.args.get("_next"):
+            raise BadRequest("_next not allowed for CSV streaming")
+        kwargs["_size"] = "max"
+    # Fetch the first page
+    try:
+        response_or_template_contexts = await fetch_data(request)
+        if isinstance(response_or_template_contexts, Response):
+            return response_or_template_contexts
+        elif len(response_or_template_contexts) == 4:
+            data, _, _, _ = response_or_template_contexts
+        else:
+            data, _, _ = response_or_template_contexts
+    except (sqlite3.OperationalError, InvalidSql) as e:
+        raise DatasetteError(str(e), title="Invalid SQL", status=400)
+
+    except sqlite3.OperationalError as e:
+        raise DatasetteError(str(e))
+
+    except DatasetteError:
+        raise
+
+    # Convert rows and columns to CSV
+    headings = data["columns"]
+    # if there are expanded_columns we need to add additional headings
+    expanded_columns = set(data.get("expanded_columns") or [])
+    if expanded_columns:
+        headings = []
+        for column in data["columns"]:
+            headings.append(column)
+            if column in expanded_columns:
+                headings.append(f"{column}_label")
+
+    content_type = "text/plain; charset=utf-8"
+    preamble = ""
+    postamble = ""
+
+    trace = request.args.get("_trace")
+    if trace:
+        content_type = "text/html; charset=utf-8"
+        preamble = (
+            "<html><head><title>CSV debug</title></head>"
+            '<body><textarea style="width: 90%; height: 70vh">'
+        )
+        postamble = "</textarea></body></html>"
+
+    async def stream_fn(r):
+        nonlocal data, trace
+        print("max_csv_mb", datasette.setting("max_csv_mb"))
+        limited_writer = LimitedWriter(r, datasette.setting("max_csv_mb"))
+        if trace:
+            await limited_writer.write(preamble)
+            writer = csv.writer(EscapeHtmlWriter(limited_writer))
+        else:
+            writer = csv.writer(limited_writer)
+        first = True
+        next = None
+        while first or (next and stream):
+            try:
+                kwargs = {}
+                if next:
+                    kwargs["_next"] = next
+                if not first:
+                    data, _, _ = await fetch_data(request, **kwargs)
+                if first:
+                    if request.args.get("_header") != "off":
+                        await writer.writerow(headings)
+                    first = False
+                next = data.get("next")
+                for row in data["rows"]:
+                    if any(isinstance(r, bytes) for r in row):
+                        new_row = []
+                        for column, cell in zip(headings, row):
+                            if isinstance(cell, bytes):
+                                # If this is a table page, use .urls.row_blob()
+                                if data.get("table"):
+                                    pks = data.get("primary_keys") or []
+                                    cell = datasette.absolute_url(
+                                        request,
+                                        datasette.urls.row_blob(
+                                            database,
+                                            data["table"],
+                                            path_from_row_pks(row, pks, not pks),
+                                            column,
+                                        ),
+                                    )
+                                else:
+                                    # Otherwise generate URL for this query
+                                    url = datasette.absolute_url(
+                                        request,
+                                        path_with_format(
+                                            request=request,
+                                            format="blob",
+                                            extra_qs={
+                                                "_blob_column": column,
+                                                "_blob_hash": hashlib.sha256(
+                                                    cell
+                                                ).hexdigest(),
+                                            },
+                                            replace_format="csv",
+                                        ),
+                                    )
+                                    cell = url.replace("&_nocount=1", "").replace(
+                                        "&_nofacet=1", ""
+                                    )
+                            new_row.append(cell)
+                        row = new_row
+                    if not expanded_columns:
+                        # Simple path
+                        await writer.writerow(row)
+                    else:
+                        # Look for {"value": "label": } dicts and expand
+                        new_row = []
+                        for heading, cell in zip(data["columns"], row):
+                            if heading in expanded_columns:
+                                if cell is None:
+                                    new_row.extend(("", ""))
+                                else:
+                                    assert isinstance(cell, dict)
+                                    new_row.append(cell["value"])
+                                    new_row.append(cell["label"])
+                            else:
+                                new_row.append(cell)
+                        await writer.writerow(new_row)
+            except Exception as e:
+                sys.stderr.write("Caught this error: {}\n".format(e))
+                sys.stderr.flush()
+                await r.write(str(e))
+                return
+        await limited_writer.write(postamble)
+
+    headers = {}
+    if datasette.cors:
+        add_cors_headers(headers)
+    if request.args.get("_dl", None):
+        if not trace:
+            content_type = "text/csv; charset=utf-8"
+        disposition = 'attachment; filename="{}.csv"'.format(
+            request.url_vars.get("table", database)
+        )
+        headers["content-disposition"] = disposition
+
+    return AsgiStream(stream_fn, headers=headers, content_type=content_type)
@@ -38,7 +38,7 @@ from datasette.utils import (
 from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response
 from datasette.filters import Filters
 import sqlite_utils
-from .base import BaseView, DataView, DatasetteError, ureg, _error
+from .base import BaseView, DataView, DatasetteError, ureg, _error, stream_csv
 from .database import QueryView

 LINK_WITH_LABEL = (
@@ -1564,11 +1564,35 @@ async def table_view_traced(datasette, request):
     )
     if isinstance(view_data, Response):
         return view_data
-    data, rows, columns, sql, next_url = view_data
+    data, rows, columns, expanded_columns, sql, next_url = view_data

     # Handle formats from plugins
     if format_ == "csv":
-        assert False, "CSV not implemented yet"
+
+        async def fetch_data(request, _next=None):
+            (
+                data,
+                rows,
+                columns,
+                expanded_columns,
+                sql,
+                next_url,
+            ) = await table_view_data(
+                datasette,
+                request,
+                resolved,
+                extra_extras=extra_extras,
+                context_for_html_hack=context_for_html_hack,
+                default_labels=default_labels,
+                _next=_next,
+            )
+            data["rows"] = rows
+            data["table"] = resolved.table
+            data["columns"] = columns
+            data["expanded_columns"] = expanded_columns
+            return data, None, None
+
+        return await stream_csv(datasette, fetch_data, request, resolved.db.name)
     elif format_ in datasette.renderers.keys():
         # Dispatch request to the correct output format renderer
         # (CSV is not handled here due to streaming)
@@ -1666,6 +1690,7 @@ async def table_view_data(
     extra_extras=None,
     context_for_html_hack=False,
     default_labels=False,
+    _next=None,
 ):
     extra_extras = extra_extras or set()
     # We have a table or view
@@ -1779,7 +1804,7 @@ async def table_view_data(
     count_sql = f"select count(*) {from_sql}"

     # Handle pagination driven by ?_next=
-    _next = request.args.get("_next")
+    _next = _next or request.args.get("_next")

     offset = ""
     if _next:
@@ -2430,7 +2455,7 @@ async def table_view_data(
     data["sort"] = sort
     data["sort_desc"] = sort_desc

-    return data, rows[:page_size], columns, sql, next_url
+    return data, rows[:page_size], columns, expanded_columns, sql, next_url


 async def _next_value_and_url(
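For reference, the query-string flags handled by stream_csv above (_stream, _dl, _header, _trace, plus the _nofacet/_nocount parameters it injects) are what a client sends to the .csv endpoint. A hedged usage sketch, assuming a locally running Datasette instance; the database and table names are placeholders.

# Sketch only: stream a full CSV export, assuming Datasette is serving on
# localhost:8001 with the allow_csv_stream setting enabled; "mydb" and
# "mytable" are placeholder names.
import httpx

url = "http://localhost:8001/mydb/mytable.csv"
params = {
    "_stream": "on",  # stream every page rather than just the first one
    "_dl": "1",       # request a content-disposition: attachment header
}
with httpx.stream("GET", url, params=params) as response:
    for chunk in response.iter_text():
        print(chunk, end="")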