class InFilter(Filter):
    """Filter matching rows whose column equals any one of several values.

    Used for ``?column__in=...`` query string arguments. The argument is
    either a comma-separated string (``1,2,3``) or a JSON array
    (``["a,b", "c"]``); the JSON form allows values that themselves
    contain commas.
    """

    key = 'in'
    display = 'in'

    def __init__(self):
        # Nothing to configure: key and display are class attributes.
        pass

    def split_value(self, value):
        """Return the list of individual values encoded in *value*.

        A string starting with ``[`` is first tried as a JSON array. If it
        is not valid JSON (for example ``[dog],cat``) it is treated as a
        plain comma-separated string instead of raising.
        """
        if value.startswith("["):
            try:
                return json.loads(value)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass; a leading
                # "[" alone does not prove the caller meant JSON, so fall
                # through to comma splitting.
                pass
        return [v.strip() for v in value.split(",")]

    def where_clause(self, table, column, value, param_counter):
        """Return ``(sql, values)`` for an ``in (...)`` test on *column*.

        One ``:pN`` placeholder is generated per value, numbered upwards
        from *param_counter*; ``values`` lists the values in the same
        order. *table* is accepted but not used.
        """
        values = self.split_value(value)
        params = [":p{}".format(param_counter + i) for i in range(len(values))]
        sql = "{} in ({})".format(escape_sqlite(column), ", ".join(params))
        return sql, values

    def human_clause(self, column, value):
        """Return a readable description such as ``foo in ["1", "2"]``."""
        return "{} in {}".format(column, json.dumps(self.split_value(value)))
``?column__glob=value`` Similar to LIKE but uses Unix wildcard syntax and is case sensitive. +``?column__in=value1,value2,value3`` + Rows where column matches any of the provided values. + + You can use a comma separated string, or you can use a JSON array. + + The JSON array option is useful if one of your matching values itself contains a comma: + + ``?column__in=["value","value,with,commas"]`` + ``?column__arraycontains=value`` Works against columns that contain JSON arrays - matches if any of the values in that array match. diff --git a/tests/test_filters.py b/tests/test_filters.py index b0cb3f34..a5d6e3d0 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -53,6 +53,28 @@ import pytest ['"bar" > :p0', '"baz" is null', '"foo" is null'], [10] ), + ( + { + 'foo__in': '1,2,3', + }, + ['foo in (:p0, :p1, :p2)'], + ["1", "2", "3"] + ), + # JSON array variants of __in (useful for unexpected characters) + ( + { + 'foo__in': '[1,2,3]', + }, + ['foo in (:p0, :p1, :p2)'], + [1, 2, 3] + ), + ( + { + 'foo__in': '["dog,cat", "cat[dog]"]', + }, + ['foo in (:p0, :p1)'], + ["dog,cat", "cat[dog]"] + ), ]) def test_build_where(args, expected_where, expected_params): f = Filters(sorted(args.items()))