diff --git a/.travis.yml b/.travis.yml
index 9e92eee3..d32df307 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -13,6 +13,7 @@ script:
 jobs:
   include:
   - stage: deploy latest.datasette.io
+    if: branch = master AND type = push
     script:
     - pip install .
     - npm install -g now
@@ -23,7 +24,6 @@ jobs:
     - now alias --token=$NOW_TOKEN
     - echo "{\"name\":\"datasette-latest-$ALIAS\",\"alias\":\"$ALIAS.datasette.io\"}" > now.json
     - now alias --token=$NOW_TOKEN
-    on: master
   - stage: release tagged version
     if: tag IS present
     python: 3.6
diff --git a/datasette/app.py b/datasette/app.py
index 70f2a93f..fb389d73 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -94,6 +94,12 @@ CONFIG_OPTIONS = (
     ConfigOption("cache_size_kb", 0, """
         SQLite cache size in KB (0 == use SQLite default)
     """.strip()),
+    ConfigOption("allow_csv_stream", True, """
+        Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
+    """.strip()),
+    ConfigOption("max_csv_mb", 100, """
+        Maximum size allowed for CSV export in MB. Set 0 to disable this limit.
+    """.strip()),
 )
 DEFAULT_CONFIG = {
     option.name: option.default
diff --git a/datasette/utils.py b/datasette/utils.py
index a179eddf..005db87f 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -832,3 +832,22 @@ def value_as_boolean(value):
 
 class ValueAsBooleanError(ValueError):
     pass
+
+
+class WriteLimitExceeded(Exception):
+    pass
+
+
+class LimitedWriter:
+    def __init__(self, writer, limit_mb):
+        self.writer = writer
+        self.limit_bytes = limit_mb * 1024 * 1024
+        self.bytes_count = 0
+
+    def write(self, bytes):
+        self.bytes_count += len(bytes)
+        if self.limit_bytes and (self.bytes_count > self.limit_bytes):
+            raise WriteLimitExceeded("CSV contains more than {} bytes".format(
+                self.limit_bytes
+            ))
+        self.writer.write(bytes)
diff --git a/datasette/views/base.py b/datasette/views/base.py
index 53ae08bd..0ca52e61 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -16,6 +16,7 @@ from datasette.utils import (
     CustomJSONEncoder,
     InterruptedError,
     InvalidSql,
+    LimitedWriter,
     path_from_row_pks,
     path_with_added_args,
     path_with_format,
@@ -150,13 +151,23 @@ class BaseView(RenderMixin):
         return await self.view_get(request, name, hash, **kwargs)
 
     async def as_csv(self, request, name, hash, **kwargs):
+        stream = request.args.get("_stream")
+        if stream:
+            # Some quick sanity checks
+            if not self.ds.config["allow_csv_stream"]:
+                raise DatasetteError("CSV streaming is disabled", status=400)
+            if request.args.get("_next"):
+                raise DatasetteError(
+                    "_next not allowed for CSV streaming", status=400
+                )
+            kwargs["_size"] = "max"
+        # Fetch the first page
         try:
             response_or_template_contexts = await self.data(
                 request, name, hash, **kwargs
             )
             if isinstance(response_or_template_contexts, response.HTTPResponse):
                 return response_or_template_contexts
-
             else:
                 data, extra_template_data, templates = response_or_template_contexts
         except (sqlite3.OperationalError, InvalidSql) as e:
@@ -167,6 +178,7 @@ class BaseView(RenderMixin):
 
         except DatasetteError:
             raise
 
+        # Convert rows and columns to CSV
         headings = data["columns"]
         # if there are expanded_columns we need to add additional headings
@@ -179,22 +191,40 @@ class BaseView(RenderMixin):
                     headings.append("{}_label".format(column))
 
         async def stream_fn(r):
-            writer = csv.writer(r)
-            writer.writerow(headings)
-            for row in data["rows"]:
-                if not expanded_columns:
-                    # Simple path
-                    writer.writerow(row)
-                else:
-                    # Look for {"value": "label": } dicts and expand
-                    new_row = []
-                    for cell in row:
-                        if isinstance(cell, dict):
-                            new_row.append(cell["value"])
-                            new_row.append(cell["label"])
-                        else:
-                            new_row.append(cell)
-                    writer.writerow(new_row)
+            nonlocal data
+            writer = csv.writer(LimitedWriter(r, self.ds.config["max_csv_mb"]))
+            first = True
+            next = None
+            while first or (next and stream):
+                try:
+                    if next:
+                        kwargs["_next"] = next
+                    if not first:
+                        data, extra_template_data, templates = await self.data(
+                            request, name, hash, **kwargs
+                        )
+                    if first:
+                        writer.writerow(headings)
+                        first = False
+                    next = data.get("next")
+                    for row in data["rows"]:
+                        if not expanded_columns:
+                            # Simple path
+                            writer.writerow(row)
+                        else:
+                            # Look for {"value": "label": } dicts and expand
+                            new_row = []
+                            for cell in row:
+                                if isinstance(cell, dict):
+                                    new_row.append(cell["value"])
+                                    new_row.append(cell["label"])
+                                else:
+                                    new_row.append(cell)
+                            writer.writerow(new_row)
+                except Exception as e:
+                    print('caught this', e)
+                    r.write(str(e))
+                    return
 
         content_type = "text/plain; charset=utf-8"
         headers = {}
@@ -393,7 +423,8 @@ class BaseView(RenderMixin):
         return r
 
     async def custom_sql(
-        self, request, name, hash, sql, editable=True, canned_query=None
+        self, request, name, hash, sql, editable=True, canned_query=None,
+        _size=None
     ):
         params = request.raw_args
         if "sql" in params:
@@ -415,6 +446,8 @@ class BaseView(RenderMixin):
         extra_args = {}
         if params.get("_timelimit"):
             extra_args["custom_time_limit"] = int(params["_timelimit"])
+        if _size:
+            extra_args["page_size"] = _size
         results = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
diff --git a/datasette/views/database.py b/datasette/views/database.py
index 2f3f41d3..a7df485b 100644
--- a/datasette/views/database.py
+++ b/datasette/views/database.py
@@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
 
 
 class DatabaseView(BaseView):
-    async def data(self, request, name, hash, default_labels=False):
+    async def data(self, request, name, hash, default_labels=False, _size=None):
         if request.args.get("sql"):
             if not self.ds.config["allow_sql"]:
                 raise DatasetteError("sql= is not allowed", status=400)
             sql = request.raw_args.pop("sql")
             validate_sql_select(sql)
-            return await self.custom_sql(request, name, hash, sql)
+            return await self.custom_sql(request, name, hash, sql, _size=_size)
 
         info = self.ds.inspect()[name]
         metadata = self.ds.metadata.get("databases", {}).get(name, {})
diff --git a/datasette/views/table.py b/datasette/views/table.py
index c57fd954..cb2c9ae5 100644
--- a/datasette/views/table.py
+++ b/datasette/views/table.py
@@ -220,7 +220,7 @@ class RowTableShared(BaseView):
 
 
 class TableView(RowTableShared):
-    async def data(self, request, name, hash, table, default_labels=False):
+    async def data(self, request, name, hash, table, default_labels=False, _next=None, _size=None):
         canned_query = self.ds.get_canned_query(name, table)
         if canned_query is not None:
             return await self.custom_sql(
@@ -375,7 +375,7 @@ class TableView(RowTableShared):
 
         count_sql = "select count(*) {}".format(from_sql)
 
-        _next = special_args.get("_next")
+        _next = _next or special_args.get("_next")
         offset = ""
         if _next:
             if is_view:
@@ -462,7 +462,7 @@ class TableView(RowTableShared):
 
         extra_args = {}
         # Handle ?_size=500
-        page_size = request.raw_args.get("_size")
+        page_size = _size or request.raw_args.get("_size")
         if page_size:
             if page_size == "max":
                 page_size = self.max_returned_rows
@@ -512,6 +512,8 @@ class TableView(RowTableShared):
         facet_results = {}
         facets_timed_out = []
         for column in facets:
+            if _next:
+                continue
             facet_sql = """
                 select {col} as value, count(*) as count {from_sql}
                 {and_or_where} {col} is not null
@@ -665,6 +667,8 @@ class TableView(RowTableShared):
             for facet_column in columns:
                 if facet_column in facets:
                     continue
+                if _next:
+                    continue
                 if not self.ds.config["suggest_facets"]:
                     continue
                 suggested_facet_sql = '''
diff --git a/docs/config.rst b/docs/config.rst
index 8f0cd246..e0013bf0 100644
--- a/docs/config.rst
+++ b/docs/config.rst
@@ -125,3 +125,24 @@ Sets the amount of memory SQLite uses for its `per-connection cache