# datasette/tests/test_utils.py
#
# 721 lines
# 23 KiB
# Python
# Raw / Plain view / History

"""
Tests for various datasette helper functions.
"""
# 2024-01-31 03:55:26 +00:00
from datasette.app import Datasette
from datasette import utils
from datasette.utils.asgi import Request
from datasette.utils.sqlite import sqlite3
import json
import os
import pathlib
import pytest
import tempfile
from unittest.mock import patch
@pytest.mark.parametrize(
    "path,expected",
    [
        ("foo", ["foo"]),
        ("foo,bar", ["foo", "bar"]),
        ("123,433,112", ["123", "433", "112"]),
        ("123~2C433,112", ["123,433", "112"]),
        ("123~2F433~2F112", ["123/433/112"]),
    ],
)
def test_urlsafe_components(path, expected):
    """Compare ``utils.urlsafe_components(path)`` with the expected list.

    The cases cover plain comma-separated values as well as values that
    contain ``~2C`` and ``~2F`` sequences.
    """
    components = utils.urlsafe_components(path)
    assert expected == components
@pytest.mark.parametrize(
    "path,added_args,expected",
    [
        ("/foo", {"bar": 1}, "/foo?bar=1"),
        ("/foo?bar=1", {"baz": 2}, "/foo?bar=1&baz=2"),
        ("/foo?bar=1&bar=2", {"baz": 3}, "/foo?bar=1&bar=2&baz=3"),
        ("/foo?bar=1", {"bar": None}, "/foo"),
        # Test order is preserved
        (
            "/?_facet=prim_state&_facet=area_name",
            (("prim_state", "GA"),),
            "/?_facet=prim_state&_facet=area_name&prim_state=GA",
        ),
        (
            "/?_facet=state&_facet=city&state=MI",
            (("city", "Detroit"),),
            "/?_facet=state&_facet=city&state=MI&city=Detroit",
        ),
        (
            "/?_facet=state&_facet=city",
            (("_facet", "planet_int"),),
            "/?_facet=state&_facet=city&_facet=planet_int",
        ),
    ],
)
def test_path_with_added_args(path, added_args, expected):
    """Check the path produced by ``utils.path_with_added_args``.

    ``added_args`` is given both as a dict and as a tuple of pairs; a
    ``None`` value is expected to drop the argument from the query string.
    """
    fake_request = Request.fake(path)
    assert expected == utils.path_with_added_args(fake_request, added_args)
@pytest.mark.parametrize(
    "path,args,expected",
    [
        ("/foo?bar=1", {"bar"}, "/foo"),
        ("/foo?bar=1&baz=2", {"bar"}, "/foo?baz=2"),
        ("/foo?bar=1&bar=2&bar=3", {"bar": "2"}, "/foo?bar=1&bar=3"),
    ],
)
def test_path_with_removed_args(path, args, expected):
    """Check ``utils.path_with_removed_args`` via the request and via ``path=``.

    ``args`` is either a set of keys to remove or a dict selecting a
    specific key/value pair to remove.
    """
    # First pass: the path comes from the request itself.
    from_request = utils.path_with_removed_args(Request.fake(path), args)
    assert expected == from_request
    # Second pass: the request points at "/" and the path is passed via path=.
    root_request = Request.fake("/")
    from_path_argument = utils.path_with_removed_args(root_request, args, path=path)
    assert expected == from_path_argument
@pytest.mark.parametrize(
    "path,args,expected",
    [
        ("/foo?bar=1", {"bar": 2}, "/foo?bar=2"),
        ("/foo?bar=1&baz=2", {"bar": None}, "/foo?baz=2"),
    ],
)
def test_path_with_replaced_args(path, args, expected):
    """Check the path produced by ``utils.path_with_replaced_args``.

    One case replaces an existing value; the other passes ``None`` and
    expects the argument to disappear.
    """
    fake_request = Request.fake(path)
    replaced = utils.path_with_replaced_args(fake_request, args)
    assert expected == replaced
@pytest.mark.parametrize(
    "row,pks,expected_path",
    [
        ({"A": "foo", "B": "bar"}, ["A", "B"], "foo,bar"),
        ({"A": "f,o", "B": "bar"}, ["A", "B"], "f~2Co,bar"),
        ({"A": 123}, ["A"], "123"),
        (
            utils.CustomRow(
                ["searchable_id", "tag"],
                [
                    ("searchable_id", {"value": 1, "label": "1"}),
                    ("tag", {"value": "feline", "label": "feline"}),
                ],
            ),
            ["searchable_id", "tag"],
            "1,feline",
        ),
    ],
)
def test_path_from_row_pks(row, pks, expected_path):
    """Check ``utils.path_from_row_pks`` for dict rows and a ``CustomRow``.

    The third positional argument is always ``False`` in these cases.
    """
    assert expected_path == utils.path_from_row_pks(row, pks, False)
@pytest.mark.parametrize(
    "obj,expected",
    [
        (
            {
                "Description": "Soft drinks",
                "Picture": b"\x15\x1c\x02\xc7\xad\x05\xfe",
                "CategoryID": 1,
            },
            """
{"CategoryID": 1, "Description": "Soft drinks", "Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}}
""".strip(),
        )
    ],
)
def test_custom_json_encoder(obj, expected):
    """Serialise ``obj`` with ``utils.CustomJSONEncoder`` and compare the text.

    Keys are sorted so the output is stable; the bytes value is expected to
    come out as a ``$base64`` object.
    """
    encoded = json.dumps(obj, cls=utils.CustomJSONEncoder, sort_keys=True)
    assert expected == encoded
@pytest.mark.parametrize(
    "bad_sql",
    [
        "update blah;",
        "-- sql comment to skip\nupdate blah;",
        "update blah set some_column='# Hello there\n\n* This is a list\n* of items\n--\n[And a link](https://github.com/simonw/datasette-render-markdown).'\nas demo_markdown",
        "PRAGMA case_sensitive_like = true",
        "SELECT * FROM pragma_not_on_allow_list('idx52')",
        "/* This comment is not valid. select 1",
        "/**/\nupdate foo set bar = 1\n/* test */ select 1",
    ],
)
def test_validate_sql_select_bad(bad_sql):
    """Each statement must make ``utils.validate_sql_select`` raise ``InvalidSql``."""
    rejected = pytest.raises(utils.InvalidSql)
    with rejected:
        utils.validate_sql_select(bad_sql)
@pytest.mark.parametrize(
    "good_sql",
    [
        "select count(*) from airports",
        "select foo from bar",
        "--sql comment to skip\nselect foo from bar",
        "select '# Hello there\n\n* This is a list\n* of items\n--\n[And a link](https://github.com/simonw/datasette-render-markdown).'\nas demo_markdown",
        "select 1 + 1",
        "explain select 1 + 1",
        "explain\nselect 1 + 1",
        "explain query plan select 1 + 1",
        "explain query plan\nselect 1 + 1",
        "SELECT\nblah FROM foo",
        "WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;",
        "explain WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;",
        "explain query plan WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;",
        "SELECT * FROM pragma_index_info('idx52')",
        "select * from pragma_table_xinfo('table')",
        # Various types of comment
        "-- comment\nselect 1",
        "-- one line\n -- two line\nselect 1",
        " /* comment */\nselect 1",
        " /* comment */select 1",
        "/* comment */\n -- another\n /* one more */ select 1",
        "/* This comment \n has multiple lines */\nselect 1",
    ],
)
def test_validate_sql_select_good(good_sql):
    """Each statement must pass ``utils.validate_sql_select`` without raising.

    The cases cover plain selects, ``explain`` / ``explain query plan``
    prefixes, ``WITH RECURSIVE`` queries, two ``pragma_*`` table functions
    and several leading-comment layouts.
    """
    utils.validate_sql_select(good_sql)
@pytest.mark.parametrize("open_quote,close_quote", [('"', '"'), ("[", "]")])
def test_detect_fts(open_quote, close_quote):
    """Check ``utils.detect_fts`` finds the FTS table for ``Street_Tree_List`` only.

    The schema is built in an in-memory database with the FTS table name and
    its ``content=`` value quoted using ``open_quote`` / ``close_quote``. A
    plain table, a view and an rtree virtual table must all report ``None``.
    """
    sql = """
CREATE TABLE "Dumb_Table" (
"TreeID" INTEGER,
"qSpecies" TEXT
);
CREATE TABLE "Street_Tree_List" (
"TreeID" INTEGER,
"qSpecies" TEXT,
"qAddress" TEXT,
"SiteOrder" INTEGER,
"qSiteInfo" TEXT,
"PlantType" TEXT,
"qCaretaker" TEXT
);
CREATE VIEW Test_View AS SELECT * FROM Dumb_Table;
CREATE VIRTUAL TABLE {open}Street_Tree_List_fts{close} USING FTS4 ("qAddress", "qCaretaker", "qSpecies", content={open}Street_Tree_List{close});
CREATE VIRTUAL TABLE r USING rtree(a, b, c);
""".format(
        open=open_quote, close=close_quote
    )
    conn = utils.sqlite3.connect(":memory:")
    try:
        conn.executescript(sql)
        assert utils.detect_fts(conn, "Dumb_Table") is None
        assert utils.detect_fts(conn, "Test_View") is None
        assert utils.detect_fts(conn, "r") is None
        assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List")
    finally:
        # Close explicitly so the connection is released even if an
        # assertion above fails.
        conn.close()
@pytest.mark.parametrize("table", ("regular", "has'single quote"))
def test_detect_fts_different_table_names(table):
    """Check ``utils.detect_fts`` with a table name containing a single quote.

    For each ``table`` an FTS4 table named ``<table>_fts`` is created with
    ``content`` pointing at it, and ``detect_fts`` must return that name.
    """
    sql = """
CREATE TABLE [{table}] (
"TreeID" INTEGER,
"qSpecies" TEXT
);
CREATE VIRTUAL TABLE [{table}_fts] USING FTS4 ("qSpecies", content="{table}");
""".format(
        table=table
    )
    conn = utils.sqlite3.connect(":memory:")
    try:
        conn.executescript(sql)
        assert "{table}_fts".format(table=table) == utils.detect_fts(conn, table)
    finally:
        # Close explicitly so the connection is released even if the
        # assertion above fails.
        conn.close()
@pytest.mark.parametrize(
    "url,expected",
    [
        ("http://www.google.com/", True),
        ("https://example.com/", True),
        ("www.google.com", False),
        ("http://www.google.com/ is a search engine", False),
    ],
)
def test_is_url(url, expected):
    """Check the boolean returned by ``utils.is_url`` for each candidate string."""
    verdict = utils.is_url(url)
    assert expected == verdict
@pytest.mark.parametrize(
    "s,expected",
    [
        ("simple", "simple"),
        ("MixedCase", "MixedCase"),
        ("-no-leading-hyphens", "no-leading-hyphens-65bea6"),
        ("_no-leading-underscores", "no-leading-underscores-b921bc"),
        ("no spaces", "no-spaces-7088d7"),
        ("-", "336d5e"),
        ("no $ characters", "no--characters-59e024"),
    ],
)
def test_to_css_class(s, expected):
    """Check the class name produced by ``utils.to_css_class`` for each input.

    Inputs that need rewriting are expected to end in a six-character
    suffix; inputs that are already acceptable come back unchanged.
    """
    css_class = utils.to_css_class(s)
    assert expected == css_class
def test_temporary_docker_directory_uses_hard_link():
    """Files handed to ``utils.temporary_docker_directory`` are hard-linked.

    Creates a file named "hello" in a scratch directory, opens the docker
    directory context manager on it and checks that the file inside the
    yielded directory has the same content and a link count of 2.
    """
    try:
        previous_cwd = os.getcwd()
    except OSError:
        # The working directory may already be gone (e.g. removed by an
        # earlier test); in that case there is nothing to go back to.
        previous_cwd = None
    with tempfile.TemporaryDirectory() as td:
        os.chdir(td)
        try:
            with open("hello", "w") as fp:
                fp.write("world")
            # Default usage of this should use a hard link
            with utils.temporary_docker_directory(
                files=["hello"],
                name="t",
                metadata=None,
                extra_options=None,
                branch=None,
                template_dir=None,
                plugins_dir=None,
                static=[],
                install=[],
                spatialite=False,
                version_note=None,
                secret="secret",
            ) as temp_docker:
                hello = os.path.join(temp_docker, "hello")
                with open(hello) as fp:
                    assert "world" == fp.read()
                # It should be a hard link
                assert 2 == os.stat(hello).st_nlink
        finally:
            # Step out of the scratch directory before TemporaryDirectory
            # deletes it, so later tests do not inherit a removed cwd.
            if previous_cwd is not None:
                os.chdir(previous_cwd)
@patch("os.link")
def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
    """Files are copied when ``os.link`` raises ``OSError``.

    ``os.link`` is patched to always raise, then the same scenario as the
    hard-link test is run; the resulting file must have the same content
    and a link count of 1.
    """
    # Copy instead if os.link raises OSError (normally due to different device)
    mock_link.side_effect = OSError
    try:
        previous_cwd = os.getcwd()
    except OSError:
        # The working directory may already be gone (e.g. removed by an
        # earlier test); in that case there is nothing to go back to.
        previous_cwd = None
    with tempfile.TemporaryDirectory() as td:
        os.chdir(td)
        try:
            with open("hello", "w") as fp:
                fp.write("world")
            # os.link is patched to fail here, so a copy is expected
            with utils.temporary_docker_directory(
                files=["hello"],
                name="t",
                metadata=None,
                extra_options=None,
                branch=None,
                template_dir=None,
                plugins_dir=None,
                static=[],
                install=[],
                spatialite=False,
                version_note=None,
                secret=None,
            ) as temp_docker:
                hello = os.path.join(temp_docker, "hello")
                with open(hello) as fp:
                    assert "world" == fp.read()
                # It should be a copy, not a hard link
                assert 1 == os.stat(hello).st_nlink
        finally:
            # Step out of the scratch directory before TemporaryDirectory
            # deletes it, so later tests do not inherit a removed cwd.
            if previous_cwd is not None:
                os.chdir(previous_cwd)
def test_temporary_docker_directory_quotes_args():
    """Shell-sensitive option values appear single-quoted in the Dockerfile.

    Passes ``--$HOME`` as ``extra_options``, ``$PWD`` as ``version_note``
    and ``secret`` as the secret, then reads the generated ``Dockerfile``
    from the yielded directory and checks for the quoted forms.
    """
    try:
        previous_cwd = os.getcwd()
    except OSError:
        # The working directory may already be gone (e.g. removed by an
        # earlier test); in that case there is nothing to go back to.
        previous_cwd = None
    with tempfile.TemporaryDirectory() as td:
        os.chdir(td)
        try:
            with open("hello", "w") as fp:
                fp.write("world")
            with utils.temporary_docker_directory(
                files=["hello"],
                name="t",
                metadata=None,
                extra_options="--$HOME",
                branch=None,
                template_dir=None,
                plugins_dir=None,
                static=[],
                install=[],
                spatialite=False,
                version_note="$PWD",
                secret="secret",
            ) as temp_docker:
                df = os.path.join(temp_docker, "Dockerfile")
                with open(df) as fp:
                    df_contents = fp.read()
                assert "'$PWD'" in df_contents
                assert "'--$HOME'" in df_contents
                assert "ENV DATASETTE_SECRET 'secret'" in df_contents
        finally:
            # Step out of the scratch directory before TemporaryDirectory
            # deletes it, so later tests do not inherit a removed cwd.
            if previous_cwd is not None:
                os.chdir(previous_cwd)
def test_compound_keys_after_sql():
    """Check the SQL fragment from ``utils.compound_keys_after_sql`` for 1-3 keys."""
    # NOTE(review): the whitespace inside the multi-line expected strings is
    # reproduced exactly as it appears in this view of the file - confirm it
    # matches the helper's real output.
    single_expected = "((a > :p0))"
    assert single_expected == utils.compound_keys_after_sql(["a"])

    pair_expected = """
((a > :p0)
or
(a = :p0 and b > :p1))
""".strip()
    assert pair_expected == utils.compound_keys_after_sql(["a", "b"])

    triple_expected = """
((a > :p0)
or
(a = :p0 and b > :p1)
or
(a = :p0 and b = :p1 and c > :p2))
""".strip()
    assert triple_expected == utils.compound_keys_after_sql(["a", "b", "c"])
def test_table_columns():
    """Check ``utils.table_columns`` lists a table's columns in declared order.

    A three-column ``places`` table is created in an in-memory database.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(
            """
create table places (id integer primary key, name text, bob integer)
"""
        )
        assert ["id", "name", "bob"] == utils.table_columns(conn, "places")
    finally:
        # Close explicitly so the connection is released even if the
        # assertion above fails.
        conn.close()
@pytest.mark.parametrize(
    "path,format,extra_qs,expected",
    [
        ("/foo?sql=select+1", "csv", {}, "/foo.csv?sql=select+1"),
        ("/foo?sql=select+1", "json", {}, "/foo.json?sql=select+1"),
        ("/foo/bar", "json", {}, "/foo/bar.json"),
        ("/foo/bar", "csv", {}, "/foo/bar.csv"),
        ("/foo/bar", "csv", {"_dl": 1}, "/foo/bar.csv?_dl=1"),
        (
            "/sf-trees/Street_Tree_List?_search=cherry&_size=1000",
            "csv",
            {"_dl": 1},
            "/sf-trees/Street_Tree_List.csv?_search=cherry&_size=1000&_dl=1",
        ),
    ],
)
def test_path_with_format(path, format, extra_qs, expected):
    """Check the path produced by ``utils.path_with_format``.

    The cases expect the format as a file extension on the path, with the
    existing query string kept and ``extra_qs`` appended after it.
    """
    fake_request = Request.fake(path)
    formatted = utils.path_with_format(
        request=fake_request, format=format, extra_qs=extra_qs
    )
    assert expected == formatted
@pytest.mark.parametrize(
    "bytes,expected",
    [
        (120, "120 bytes"),
        (1024, "1.0 KB"),
        (1024 * 1024, "1.0 MB"),
        (1024 * 1024 * 1024, "1.0 GB"),
        (1024 * 1024 * 1024 * 1.3, "1.3 GB"),
        (1024 * 1024 * 1024 * 1024, "1.0 TB"),
    ],
)
def test_format_bytes(bytes, expected):
    """Check the human-readable size from ``utils.format_bytes`` (bytes to TB)."""
    formatted = utils.format_bytes(bytes)
    assert expected == formatted
@pytest.mark.parametrize(
    "query,expected",
    [
        ("dog", '"dog"'),
        ("cat,", '"cat,"'),
        ("cat dog", '"cat" "dog"'),
        # If a phrase is already double quoted, leave it so
        ('"cat dog"', '"cat dog"'),
        ('"cat dog" fish', '"cat dog" "fish"'),
        # Sensibly handle unbalanced double quotes
        ('cat"', '"cat"'),
        ('"cat dog" "fish', '"cat dog" "fish"'),
    ],
)
def test_escape_fts(query, expected):
    """Check the quoted search string produced by ``utils.escape_fts``."""
    escaped = utils.escape_fts(query)
    assert expected == escaped
@pytest.mark.parametrize(
    "input,expected",
    [
        ("dog", "dog"),
        ('dateutil_parse("1/2/2020")', r"dateutil_parse(\0000221/2/2020\000022)"),
        ("this\r\nand\r\nthat", r"this\00000Aand\00000Athat"),
    ],
)
def test_escape_css_string(input, expected):
    """Check ``utils.escape_css_string`` on double quotes and CRLF line breaks."""
    escaped = utils.escape_css_string(input)
    assert expected == escaped
def test_check_connection_spatialite_raises():
    """``utils.check_connection`` must raise ``SpatialiteConnectionProblem``.

    Uses the ``spatialite.db`` file that sits next to this test module.
    """
    path = str(pathlib.Path(__file__).parent / "spatialite.db")
    conn = sqlite3.connect(path)
    try:
        with pytest.raises(utils.SpatialiteConnectionProblem):
            utils.check_connection(conn)
    finally:
        # Release the handle on the fixture database file.
        conn.close()
def test_check_connection_passes():
    """``utils.check_connection`` must not raise for an empty in-memory database."""
    conn = sqlite3.connect(":memory:")
    try:
        utils.check_connection(conn)
    finally:
        # Close explicitly so the connection is released even on failure.
        conn.close()
def test_call_with_supported_arguments():
    """Check ``utils.call_with_supported_arguments`` against a two-argument callable.

    Exact and surplus keyword arguments must both produce the callable's
    result; a missing argument must raise ``TypeError``.
    """

    def join_with_plus(a, b):
        return f"{a}+{b}"

    exact = utils.call_with_supported_arguments(join_with_plus, a=1, b=2)
    assert "1+2" == exact

    # The extra keyword c=3 is not a parameter of the callable.
    surplus = utils.call_with_supported_arguments(join_with_plus, a=1, b=2, c=3)
    assert "1+2" == surplus

    with pytest.raises(TypeError):
        utils.call_with_supported_arguments(join_with_plus, a=1)
@pytest.mark.parametrize(
    "data,should_raise",
    [
        ([["foo", "bar"], ["foo", "baz"]], False),
        ([("foo", "bar"), ("foo", "baz")], False),
        ((["foo", "bar"], ["foo", "baz"]), False),
        ([["foo", "bar"], ["foo", "baz", "bax"]], True),
        ({"foo": ["bar", "baz"]}, False),
        ({"foo": ("bar", "baz")}, False),
        ({"foo": "bar"}, True),
    ],
)
def test_multi_params(data, should_raise):
    """Check which input shapes ``utils.MultiParams`` accepts.

    Rejected shapes must raise ``AssertionError``. Accepted shapes must
    give ``"bar"`` for item access on ``"foo"`` and ``["bar", "baz"]``
    from ``getlist("foo")``.
    """
    if not should_raise:
        params = utils.MultiParams(data)
        assert "bar" == params["foo"]
        assert ["bar", "baz"] == list(params.getlist("foo"))
    else:
        with pytest.raises(AssertionError):
            utils.MultiParams(data)
@pytest.mark.parametrize(
    "actor,allow,expected",
    [
        # Default is to allow:
        (None, None, True),
        # {} means deny-all:
        (None, {}, False),
        ({"id": "root"}, {}, False),
        # true means allow-all
        ({"id": "root"}, True, True),
        (None, True, True),
        # false means deny-all
        ({"id": "root"}, False, False),
        (None, False, False),
        # Special case for "unauthenticated": true
        (None, {"unauthenticated": True}, True),
        (None, {"unauthenticated": False}, False),
        # Match on just one property:
        (None, {"id": "root"}, False),
        ({"id": "root"}, None, True),
        ({"id": "simon", "staff": True}, {"staff": True}, True),
        ({"id": "simon", "staff": False}, {"staff": True}, False),
        # Special "*" value for any key:
        ({"id": "root"}, {"id": "*"}, True),
        ({}, {"id": "*"}, False),
        ({"name": "root"}, {"id": "*"}, False),
        # Supports single strings or list of values:
        ({"id": "root"}, {"id": "bob"}, False),
        ({"id": "root"}, {"id": ["bob"]}, False),
        ({"id": "root"}, {"id": "root"}, True),
        ({"id": "root"}, {"id": ["root"]}, True),
        # Any matching role will work:
        ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["staff"]}, True),
        ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["dev"]}, True),
        ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["otter"]}, False),
        ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["dev", "otter"]}, True),
        ({"id": "garry", "roles": []}, {"roles": ["staff"]}, False),
        ({"id": "garry"}, {"roles": ["staff"]}, False),
        # Any single matching key works:
        ({"id": "root"}, {"bot_id": "my-bot", "id": ["root"]}, True),
    ],
)
def test_actor_matches_allow(actor, allow, expected):
    """Check the verdict of ``utils.actor_matches_allow`` for each actor/allow pair."""
    verdict = utils.actor_matches_allow(actor, allow)
    assert expected == verdict
@pytest.mark.parametrize(
    "config,expected",
    [
        ({"foo": "bar"}, {"foo": "bar"}),
        ({"$env": "FOO"}, "x"),
        ({"k": {"$env": "FOO"}}, {"k": "x"}),
        ([{"k": {"$env": "FOO"}}, {"z": {"$env": "FOO"}}], [{"k": "x"}, {"z": "x"}]),
        ({"k": [{"in_a_list": {"$env": "FOO"}}]}, {"k": [{"in_a_list": "x"}]}),
    ],
)
def test_resolve_env_secrets(config, expected):
    """Check ``utils.resolve_env_secrets`` replaces ``{"$env": "FOO"}`` with "x".

    The environment mapping passed in is ``{"FOO": "x"}``; the cases place
    the ``$env`` marker at the top level, in dicts and inside lists.
    """
    environment = {"FOO": "x"}
    resolved = utils.resolve_env_secrets(config, environment)
    assert expected == resolved
# 2020-06-29 18:40:40 +00:00
@pytest.mark.parametrize(
    "actor,expected",
    [
        ({"id": "blah"}, "blah"),
        ({"id": "blah", "login": "l"}, "l"),
        # NOTE(review): this case is identical to the one above - confirm
        # whether a different actor shape was intended.
        ({"id": "blah", "login": "l"}, "l"),
        ({"id": "blah", "login": "l", "username": "u"}, "u"),
        ({"login": "l", "name": "n"}, "n"),
        (
            {"id": "blah", "login": "l", "username": "u", "name": "n", "display": "d"},
            "d",
        ),
        ({"weird": "shape"}, "{'weird': 'shape'}"),
    ],
)
def test_display_actor(actor, expected):
    """Check which actor key ``utils.display_actor`` uses as the display value.

    An actor with none of the keys used in the other cases is expected to
    come back as the string form of the dict.
    """
    shown = utils.display_actor(actor)
    assert expected == shown
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "dbs,expected_path",
    [
        (["one_table"], "/one/one"),
        (["two_tables"], "/two"),
        (["one_table", "two_tables"], "/"),
    ],
)
async def test_initial_path_for_datasette(tmp_path_factory, dbs, expected_path):
    """Check the path from ``utils.initial_path_for_datasette``.

    Two database files are created: ``one.db`` with one table and
    ``two.db`` with two. ``dbs`` selects which are handed to ``Datasette``.
    """

    def run_statement(db_path, statement):
        # One short-lived connection per statement, closed explicitly so no
        # file handle stays open on the temporary database files.
        conn = sqlite3.connect(db_path)
        try:
            conn.execute(statement)
        finally:
            conn.close()

    db_dir = tmp_path_factory.mktemp("dbs")
    one_table = str(db_dir / "one.db")
    run_statement(one_table, "create table one (id integer primary key)")
    two_tables = str(db_dir / "two.db")
    run_statement(two_tables, "create table two (id integer primary key)")
    run_statement(two_tables, "create table three (id integer primary key)")
    datasette = Datasette(
        [{"one_table": one_table, "two_tables": two_tables}[db] for db in dbs]
    )
    path = await utils.initial_path_for_datasette(datasette)
    assert path == expected_path
@pytest.mark.parametrize(
    "content,expected",
    (
        ("title: Hello", {"title": "Hello"}),
        ('{"title": "Hello"}', {"title": "Hello"}),
        ("{{ this }} is {{ bad }}", None),
    ),
)
def test_parse_metadata(content, expected):
    """Check ``utils.parse_metadata`` on two valid inputs and one invalid one.

    An ``expected`` of ``None`` marks content that must raise
    ``utils.BadMetadataError``.
    """
    if expected is not None:
        assert utils.parse_metadata(content) == expected
        return
    with pytest.raises(utils.BadMetadataError):
        utils.parse_metadata(content)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "sql,expected",
    (
        ("select 1", []),
        ("select 1 + :one", ["one"]),
        ("select 1 + :one + :two", ["one", "two"]),
        ("select 'bob' || '0:00' || :cat", ["cat"]),
        ("select this is invalid :one, :two, :three", ["one", "two", "three"]),
    ),
)
async def test_derive_named_parameters(sql, expected):
    """Check the parameter names from ``utils.derive_named_parameters``.

    Runs against the ``_memory`` database of a ``Datasette`` created with
    ``memory=True``. The cases include a colon inside a string literal and
    a statement that is not valid SQL.
    """
    memory_db = Datasette([], memory=True).get_database("_memory")
    derived = await utils.derive_named_parameters(memory_db, sql)
    assert derived == expected
@pytest.mark.parametrize(
    "original,expected",
    (
        ("abc", "abc"),
        ("/foo/bar", "~2Ffoo~2Fbar"),
        ("/-/bar", "~2F-~2Fbar"),
        ("-/db-/table.csv", "-~2Fdb-~2Ftable~2Ecsv"),
        (r"%~-/", "~25~7E-~2F"),
        ("~25~7E~2D~2F", "~7E25~7E7E~7E2D~7E2F"),
        ("with space", "with+space"),
    ),
)
def test_tilde_encoding(original, expected):
    """Check ``utils.tilde_encode`` and that ``utils.tilde_decode`` reverses it."""
    encoded = utils.tilde_encode(original)
    assert encoded == expected
    # And test round-trip
    decoded = utils.tilde_decode(encoded)
    assert original == decoded
@pytest.mark.parametrize(
    "url,length,expected",
    (
        ("https://example.com/", 5, "http…"),
        ("https://example.com/foo/bar", 15, "https://exampl…"),
        ("https://example.com/foo/bar/baz.jpg", 30, "https://example.com/foo/ba….jpg"),
        # Extensions longer than 4 characters are not treated specially:
        ("https://example.com/foo/bar/baz.jpeg2", 30, "https://example.com/foo/bar/b…"),
        (
            "https://example.com/foo/bar/baz.jpeg2",
            None,
            "https://example.com/foo/bar/baz.jpeg2",
        ),
    ),
)
def test_truncate_url(url, length, expected):
    """Check the shortened URL produced by ``utils.truncate_url``.

    A ``length`` of ``None`` is expected to return the URL unchanged.
    """
    shortened = utils.truncate_url(url, length)
    assert shortened == expected
@pytest.mark.parametrize(
    "pairs,expected",
    (
        # Simple nested objects
        ([("a", "b")], {"a": "b"}),
        ([("a.b", "c")], {"a": {"b": "c"}}),
        # JSON literals
        ([("a.b", "true")], {"a": {"b": True}}),
        ([("a.b", "false")], {"a": {"b": False}}),
        ([("a.b", "null")], {"a": {"b": None}}),
        ([("a.b", "1")], {"a": {"b": 1}}),
        ([("a.b", "1.1")], {"a": {"b": 1.1}}),
        # Nested JSON literals
        ([("a.b", '{"foo": "bar"}')], {"a": {"b": {"foo": "bar"}}}),
        ([("a.b", "[1, 2, 3]")], {"a": {"b": [1, 2, 3]}}),
        # JSON strings are preserved
        ([("a.b", '"true"')], {"a": {"b": "true"}}),
        ([("a.b", '"[1, 2, 3]"')], {"a": {"b": "[1, 2, 3]"}}),
        # Later keys over-ride the previous
        (
            [
                ("a", "b"),
                ("a.b", "c"),
            ],
            {"a": {"b": "c"}},
        ),
        (
            [
                ("settings.trace_debug", "true"),
                ("plugins.datasette-ripgrep.path", "/etc"),
                ("settings.trace_debug", "false"),
            ],
            {
                "settings": {
                    "trace_debug": False,
                },
                "plugins": {
                    "datasette-ripgrep": {
                        "path": "/etc",
                    }
                },
            },
        ),
    ),
)
def test_pairs_to_nested_config(pairs, expected):
    """Check the nested dict built by ``utils.pairs_to_nested_config``.

    Keys are dotted paths and values are strings; the cases expect JSON
    literals in the values to be parsed.
    """
    nested = utils.pairs_to_nested_config(pairs)
    assert nested == expected
@pytest.mark.asyncio
async def test_calculate_etag(tmp_path):
    """Check ``utils.calculate_etag`` and its use of ``utils._etag_cache``.

    The first call must return the expected quoted hex digest for a file
    containing "hello" and store it in the cache under the path. After the
    cache entry is overwritten, the second call must return the overwritten
    value.
    """
    path = tmp_path / "test.txt"
    path.write_text("hello")
    etag = '"5d41402abc4b2a76b9719d911017c592"'
    try:
        assert etag == await utils.calculate_etag(path)
        assert utils._etag_cache[path] == etag
        # Plant a sentinel to prove the next call is served from the cache.
        utils._etag_cache[path] = "hash"
        assert "hash" == await utils.calculate_etag(path)
    finally:
        # Always reset the module-level cache so a failing assertion does not
        # leave entries behind for other tests.
        utils._etag_cache.clear()