diff --git a/datasette/utils.py b/datasette/utils.py
index 92104125..6d7ac147 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -70,12 +70,22 @@ class InvalidSql(Exception):
     pass
 
 
+allowed_sql_res = [
+    re.compile(r'^select\b'),
+    re.compile(r'^with\b'),
+]
+disallowed_sql_res = [
+    (re.compile(r'pragma'), 'Statement may not contain PRAGMA'),
+]
+
+
 def validate_sql_select(sql):
     sql = sql.strip().lower()
-    if not sql.startswith('select '):
-        raise InvalidSql('Statement must begin with SELECT')
-    if 'pragma' in sql:
-        raise InvalidSql('Statement may not contain PRAGMA')
+    if not any(r.match(sql) for r in allowed_sql_res):
+        raise InvalidSql('Statement must be a SELECT')
+    for r, msg in disallowed_sql_res:
+        if r.search(sql):
+            raise InvalidSql(msg)
 
 
 def path_with_added_args(request, args):
diff --git a/tests/test_app.py b/tests/test_app.py
index 99cb9ffb..a2a6d9cc 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -181,14 +181,14 @@ def test_invalid_custom_sql(app_client):
         gather_request=False
     )
     assert response.status == 400
-    assert 'Statement must begin with SELECT' in response.text
+    assert 'Statement must be a SELECT' in response.text
     response = app_client.get(
         '/test_tables.json?sql=.schema',
         gather_request=False
     )
     assert response.status == 400
     assert response.json['ok'] is False
-    assert 'Statement must begin with SELECT' == response.json['error']
+    assert 'Statement must be a SELECT' == response.json['error']
 
 
 def test_table_page(app_client):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9ff6acf7..8fe1855e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -123,6 +123,8 @@ def test_validate_sql_select_bad(bad_sql):
     'select count(*) from airports',
     'select foo from bar',
     'select 1 + 1',
+    'SELECT\nblah FROM foo',
+    'WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;',
 ])
 def test_validate_sql_select_good(good_sql):
     utils.validate_sql_select(good_sql)