diff --git a/datasite/app.py b/datasite/app.py index 44ae9c51..cdbc305f 100644 --- a/datasite/app.py +++ b/datasite/app.py @@ -95,7 +95,7 @@ class BaseView(HTTPMethodView): rows.sort(key=lambda row: row[-1]) return [str(r[1]) for r in rows] - async def execute(self, db_name, sql): + async def execute(self, db_name, sql, params=None): """Executes sql against db_name in a thread""" def sql_operation_in_thread(): conn = getattr(connections, db_name, None) @@ -111,7 +111,8 @@ class BaseView(HTTPMethodView): setattr(connections, db_name, conn) with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS): - rows = conn.execute(sql) + print('execute: ', sql, 'params=', params) + rows = conn.execute(sql, params or {}) return rows return await asyncio.get_event_loop().run_in_executor( @@ -135,7 +136,7 @@ class BaseView(HTTPMethodView): data, extra_template_data = await self.data( request, name, hash, **kwargs ) - except sqlite3.OperationalError as e: + except (sqlite3.OperationalError, InvalidSql) as e: data = { 'ok': False, 'error': str(e), @@ -216,8 +217,13 @@ class DatabaseView(BaseView): template = 'database.html' async def data(self, request, name, hash): - sql = request.args.get('sql') or 'select * from sqlite_master' - rows = await self.execute(name, sql) + sql = 'select * from sqlite_master' + params = {} + if request.args.get('sql'): + params = request.raw_args + sql = params.pop('sql') + validate_sql_select(sql) + rows = await self.execute(name, sql, params) columns = [r[0] for r in rows.description] return { 'database': name, @@ -448,3 +454,16 @@ def app_factory(files, num_threads=3): '///' ) return app + + +class InvalidSql(Exception): + """Raised when user-supplied SQL is rejected by validate_sql_select.""" + + +def validate_sql_select(sql): + """Raise InvalidSql unless sql begins with SELECT and contains no PRAGMA.""" + sql = sql.strip().lower() + if not (sql.startswith('select') and sql[6:7].isspace()): + raise InvalidSql('Statement must begin with SELECT') + if 'pragma' in sql: + raise InvalidSql('Statement may not contain PRAGMA') diff --git a/test_helpers.py b/test_helpers.py index b37fd767..f8885f7a 100644 --- a/test_helpers.py +++ 
b/test_helpers.py @@ -1,6 +1,5 @@ from datasite import app import pytest -import sqlite3 import json @@ -15,29 +14,6 @@ def test_compound_pks_from_path(path, expected): assert expected == app.compound_pks_from_path(path) -@pytest.mark.parametrize('sql,table,expected_keys', [ - (''' - CREATE TABLE `Compound` ( - A varchar(5) NOT NULL, - B varchar(10) NOT NULL, - PRIMARY KEY (A, B) - ); - ''', 'Compound', ['A', 'B']), - (''' - CREATE TABLE `Compound2` ( - A varchar(5) NOT NULL, - B varchar(10) NOT NULL, - PRIMARY KEY (B, A) - ); - ''', 'Compound2', ['B', 'A']), -]) -def test_pks_for_table(sql, table, expected_keys): - conn = sqlite3.connect(':memory:') - conn.execute(sql) - actual = app.pks_for_table(conn, table) - assert expected_keys == actual - - @pytest.mark.parametrize('row,pks,expected_path', [ ({'A': 'foo', 'B': 'bar'}, ['A', 'B'], 'foo,bar'), ({'A': 'f,o', 'B': 'bar'}, ['A', 'B'], 'f%2Co,bar'), @@ -113,3 +89,22 @@ def test_build_where(args, expected_where, expected_params): actual_where, actual_params = app.build_where_clause(args) assert expected_where == actual_where assert expected_params == actual_params + + +@pytest.mark.parametrize('bad_sql', [ + 'update blah;', + 'PRAGMA case_sensitive_like = true', + "SELECT * FROM pragma_index_info('idx52')", +]) +def test_validate_sql_select_bad(bad_sql): + with pytest.raises(app.InvalidSql): + app.validate_sql_select(bad_sql) + + +@pytest.mark.parametrize('good_sql', [ + 'select count(*) from airports', + 'select foo from bar', + 'select 1 + 1', +]) +def test_validate_sql_select_good(good_sql): + app.validate_sql_select(good_sql)