diff --git a/datasette/views/base.py b/datasette/views/base.py index 14179824..7772f42d 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -228,6 +228,8 @@ class BaseView(RenderMixin): content_type = "text/plain; charset=utf-8" headers = {} + if self.ds.cors: + headers["Access-Control-Allow-Origin"] = "*" if request.args.get("_dl", None): content_type = "text/csv; charset=utf-8" disposition = 'attachment; filename="{}.csv"'.format( diff --git a/tests/fixtures.py b/tests/fixtures.py index 004a0b03..cc1734f4 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -23,32 +23,39 @@ class TestClient: ) -@pytest.fixture(scope='session') -def app_client(sql_time_limit_ms=None, max_returned_rows=None, config=None, filename="fixtures.db"): +@pytest.fixture(scope="session") +def app_client( + sql_time_limit_ms=None, + max_returned_rows=None, + cors=False, + config=None, + filename="fixtures.db", +): with tempfile.TemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, filename) conn = sqlite3.connect(filepath) conn.executescript(TABLES) os.chdir(os.path.dirname(filepath)) - plugins_dir = os.path.join(tmpdir, 'plugins') + plugins_dir = os.path.join(tmpdir, "plugins") os.mkdir(plugins_dir) - open(os.path.join(plugins_dir, 'my_plugin.py'), 'w').write(PLUGIN1) - open(os.path.join(plugins_dir, 'my_plugin_2.py'), 'w').write(PLUGIN2) + open(os.path.join(plugins_dir, "my_plugin.py"), "w").write(PLUGIN1) + open(os.path.join(plugins_dir, "my_plugin_2.py"), "w").write(PLUGIN2) config = config or {} - config.update({ - 'default_page_size': 50, - 'max_returned_rows': max_returned_rows or 100, - 'sql_time_limit_ms': sql_time_limit_ms or 200, - }) + config.update( + { + "default_page_size": 50, + "max_returned_rows": max_returned_rows or 100, + "sql_time_limit_ms": sql_time_limit_ms or 200, + } + ) ds = Datasette( [filepath], + cors=cors, metadata=METADATA, plugins_dir=plugins_dir, config=config, ) - ds.sqlite_functions.append( - ('sleep', 1, lambda n: time.sleep(float(n))), - ) + ds.sqlite_functions.append(("sleep", 1, lambda n: time.sleep(float(n)))) client = TestClient(ds.app().test_client) client.ds = ds yield client @@ -83,6 +90,11 @@ def app_client_with_dot(): yield from app_client(filename="fixtures.dot.db") +@pytest.fixture(scope='session') +def app_client_with_cors(): + yield from app_client(cors=True) + + def generate_compound_rows(num): for a, b, c in itertools.islice( itertools.product(string.ascii_lowercase, repeat=3), num diff --git a/tests/test_csv.py b/tests/test_csv.py index 194de5b1..398dbd1f 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -1,4 +1,8 @@ -from .fixtures import app_client, app_client_csv_max_mb_one # noqa +from .fixtures import ( # noqa + app_client, + app_client_csv_max_mb_one, + app_client_with_cors +) EXPECTED_TABLE_CSV = '''id,content 1,hello @@ -30,13 +34,21 @@ pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood 15,2,0,MC,4,Memnonia,Arcadia Planitia '''.lstrip().replace('\n', '\r\n') + def test_table_csv(app_client): response = app_client.get('/fixtures/simple_primary_key.csv') assert response.status == 200 + assert not response.headers.get("Access-Control-Allow-Origin") assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] assert EXPECTED_TABLE_CSV == response.text +def test_table_csv_cors_headers(app_client_with_cors): + response = app_client_with_cors.get('/fixtures/simple_primary_key.csv') + assert response.status == 200 + assert "*" == response.headers["Access-Control-Allow-Origin"] + + def test_table_csv_with_labels(app_client): response = app_client.get('/fixtures/facetable.csv?_labels=1') assert response.status == 200