diff --git a/datasette/plugins.py b/datasette/plugins.py index bf3735dc..6e7cfccd 100644 --- a/datasette/plugins.py +++ b/datasette/plugins.py @@ -8,6 +8,7 @@ DEFAULT_PLUGINS = ( "datasette.publish.now", "datasette.publish.cloudrun", "datasette.facets", + "datasette.sql_functions", ) pm = pluggy.PluginManager("datasette") diff --git a/datasette/sql_functions.py b/datasette/sql_functions.py new file mode 100644 index 00000000..312294c1 --- /dev/null +++ b/datasette/sql_functions.py @@ -0,0 +1,7 @@ +from datasette import hookimpl +from datasette.utils import escape_fts + + +@hookimpl +def prepare_connection(conn): + conn.create_function("escape_fts", 1, escape_fts) diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index 91ab0d76..ab5f995b 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -758,6 +758,20 @@ def format_bytes(bytes): return "{:.1f} {}".format(current, unit) +_escape_fts_re = re.compile(r'\s+|(".*?")') + + +def escape_fts(query): + # If query has unbalanced ", add one at end + if query.count('"') % 2: + query += '"' + bits = _escape_fts_re.split(query) + bits = [b for b in bits if b and b != '""'] + return " ".join( + '"{}"'.format(bit) if not bit.startswith('"') else bit for bit in bits + ) + + class RequestParameters(dict): def get(self, name, default=None): "Return first value in the list, if available" diff --git a/datasette/views/table.py b/datasette/views/table.py index 516b474d..54839344 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -361,7 +361,7 @@ class TableView(RowTableShared): # Simple ?_search=xxx search = search_args["_search"] where_clauses.append( - "{fts_pk} in (select rowid from {fts_table} where {fts_table} match :search)".format( + "{fts_pk} in (select rowid from {fts_table} where {fts_table} match escape_fts(:search))".format( fts_table=escape_sqlite(fts_table), fts_pk=escape_sqlite(fts_pk) ) ) @@ -375,7 +375,7 @@ class TableView(RowTableShared): raise DatasetteError("Cannot search by that column", status=400) where_clauses.append( - "rowid in (select rowid from {fts_table} where {search_col} match :search_{i})".format( + "rowid in (select rowid from {fts_table} where {search_col} match escape_fts(:search_{i}))".format( fts_table=escape_sqlite(fts_table), search_col=escape_sqlite(search_col), i=i, diff --git a/tests/test_api.py b/tests/test_api.py index 979ee9bf..27fa26bb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -947,6 +947,11 @@ def test_sortable_columns_metadata(app_client): [2, "terry dog", "sara weasel", "puma"], ], ), + ( + # Special keyword shouldn't break FTS query + "/fixtures/searchable.json?_search=AND", + [], + ), ( "/fixtures/searchable.json?_search=weasel", [[2, "terry dog", "sara weasel", "puma"]], diff --git a/tests/test_utils.py b/tests/test_utils.py index 28b0d0e1..f448ad22 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -388,3 +388,21 @@ def test_path_with_format(path, format, extra_qs, expected): ) def test_format_bytes(bytes, expected): assert expected == utils.format_bytes(bytes) + + +@pytest.mark.parametrize( + "query,expected", + [ + ("dog", '"dog"'), + ("cat,", '"cat,"'), + ("cat dog", '"cat" "dog"'), + # If a phrase is already double quoted, leave it so + ('"cat dog"', '"cat dog"'), + ('"cat dog" fish', '"cat dog" "fish"'), + # Sensibly handle unbalanced double quotes + ('cat"', '"cat"'), + ('"cat dog" "fish', '"cat dog" "fish"'), + ], +) +def test_escape_fts(query, expected): + assert expected == utils.escape_fts(query)