Refactored util functions into new utils module

pull/81/head
Simon Willison 2017-11-10 11:25:54 -08:00
rodzic 1c57bd202f
commit a8a293cd71
3 zmienionych plików z 122 dodań i 111 usunięć

Wyświetl plik

@ -5,18 +5,23 @@ from sanic.views import HTTPMethodView
from sanic_jinja2 import SanicJinja2
from jinja2 import FileSystemLoader
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from functools import wraps
from concurrent import futures
import asyncio
import threading
import urllib.parse
import json
import base64
import hashlib
import sys
import time
from .utils import (
build_where_clause,
CustomJSONEncoder,
InvalidSql,
path_from_row_pks,
compound_pks_from_path,
sqlite_timelimit,
validate_sql_select,
)
app_root = Path(__file__).parent.parent
@ -373,93 +378,6 @@ def resolve_db_name(files, db_name, **kwargs):
return name, expected, None
def compound_pks_from_path(path):
return [
urllib.parse.unquote_plus(b) for b in path.split(',')
]
def path_from_row_pks(row, pks, use_rowid):
if use_rowid:
return urllib.parse.quote_plus(str(row['rowid']))
bits = []
for pk in pks:
bits.append(
urllib.parse.quote_plus(str(row[pk]))
)
return ','.join(bits)
def build_where_clause(args):
sql_bits = []
params = {}
for i, (key, values) in enumerate(sorted(args.items())):
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
template = {
'exact': '"{}" = :{}',
'contains': '"{}" like :{}',
'endswith': '"{}" like :{}',
'startswith': '"{}" like :{}',
'gt': '"{}" > :{}',
'gte': '"{}" >= :{}',
'lt': '"{}" < :{}',
'lte': '"{}" <= :{}',
'glob': '"{}" glob :{}',
'like': '"{}" like :{}',
}[lookup]
numeric_operators = {'gt', 'gte', 'lt', 'lte'}
value = values[0]
value_convert = {
'contains': lambda s: '%{}%'.format(s),
'endswith': lambda s: '%{}'.format(s),
'startswith': lambda s: '{}%'.format(s),
}.get(lookup, lambda s: s)
converted = value_convert(value)
if lookup in numeric_operators and converted.isdigit():
converted = int(converted)
param_id = 'p{}'.format(i)
sql_bits.append(
template.format(column, param_id)
)
params[param_id] = converted
where_clause = ' and '.join(sql_bits)
return where_clause, params
class CustomJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, sqlite3.Row):
return tuple(obj)
if isinstance(obj, sqlite3.Cursor):
return list(obj)
if isinstance(obj, bytes):
# Does it encode to utf8?
try:
return obj.decode('utf8')
except UnicodeDecodeError:
return {
'$base64': True,
'encoded': base64.b64encode(obj).decode('latin1'),
}
return json.JSONEncoder.default(self, obj)
@contextmanager
def sqlite_timelimit(conn, ms):
deadline = time.time() + (ms / 1000)
def handler():
if time.time() >= deadline:
return 1
conn.set_progress_handler(handler, 10000)
yield
conn.set_progress_handler(None, 10000)
class Datasette:
def __init__(self, files, num_threads=3):
self.files = files
@ -497,15 +415,3 @@ class Datasette:
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>'
)
return app
class InvalidSql(Exception):
pass
def validate_sql_select(sql):
sql = sql.strip().lower()
if not sql.startswith('select '):
raise InvalidSql('Statement must begin with SELECT')
if 'pragma' in sql:
raise InvalidSql('Statement may not contain PRAGMA')

105
datasette/utils.py 100644
Wyświetl plik

@ -0,0 +1,105 @@
from contextlib import contextmanager
import base64
import json
import sqlite3
import time
import urllib
def compound_pks_from_path(path):
return [
urllib.parse.unquote_plus(b) for b in path.split(',')
]
def path_from_row_pks(row, pks, use_rowid):
if use_rowid:
return urllib.parse.quote_plus(str(row['rowid']))
bits = []
for pk in pks:
bits.append(
urllib.parse.quote_plus(str(row[pk]))
)
return ','.join(bits)
def build_where_clause(args):
sql_bits = []
params = {}
for i, (key, values) in enumerate(sorted(args.items())):
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
template = {
'exact': '"{}" = :{}',
'contains': '"{}" like :{}',
'endswith': '"{}" like :{}',
'startswith': '"{}" like :{}',
'gt': '"{}" > :{}',
'gte': '"{}" >= :{}',
'lt': '"{}" < :{}',
'lte': '"{}" <= :{}',
'glob': '"{}" glob :{}',
'like': '"{}" like :{}',
}[lookup]
numeric_operators = {'gt', 'gte', 'lt', 'lte'}
value = values[0]
value_convert = {
'contains': lambda s: '%{}%'.format(s),
'endswith': lambda s: '%{}'.format(s),
'startswith': lambda s: '{}%'.format(s),
}.get(lookup, lambda s: s)
converted = value_convert(value)
if lookup in numeric_operators and converted.isdigit():
converted = int(converted)
param_id = 'p{}'.format(i)
sql_bits.append(
template.format(column, param_id)
)
params[param_id] = converted
where_clause = ' and '.join(sql_bits)
return where_clause, params
class CustomJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, sqlite3.Row):
return tuple(obj)
if isinstance(obj, sqlite3.Cursor):
return list(obj)
if isinstance(obj, bytes):
# Does it encode to utf8?
try:
return obj.decode('utf8')
except UnicodeDecodeError:
return {
'$base64': True,
'encoded': base64.b64encode(obj).decode('latin1'),
}
return json.JSONEncoder.default(self, obj)
@contextmanager
def sqlite_timelimit(conn, ms):
deadline = time.time() + (ms / 1000)
def handler():
if time.time() >= deadline:
return 1
conn.set_progress_handler(handler, 10000)
yield
conn.set_progress_handler(None, 10000)
class InvalidSql(Exception):
pass
def validate_sql_select(sql):
sql = sql.strip().lower()
if not sql.startswith('select '):
raise InvalidSql('Statement must begin with SELECT')
if 'pragma' in sql:
raise InvalidSql('Statement may not contain PRAGMA')

Wyświetl plik

@ -2,7 +2,7 @@
Tests for various datasette helper functions.
"""
from datasette import app
from datasette import utils
import pytest
import json
@ -15,7 +15,7 @@ import json
('123%2F433%2F112', ['123/433/112']),
])
def test_compound_pks_from_path(path, expected):
assert expected == app.compound_pks_from_path(path)
assert expected == utils.compound_pks_from_path(path)
@pytest.mark.parametrize('row,pks,expected_path', [
@ -24,7 +24,7 @@ def test_compound_pks_from_path(path, expected):
({'A': 123}, ['A'], '123'),
])
def test_path_from_row_pks(row, pks, expected_path):
actual_path = app.path_from_row_pks(row, pks, False)
actual_path = utils.path_from_row_pks(row, pks, False)
assert expected_path == actual_path
@ -40,7 +40,7 @@ def test_path_from_row_pks(row, pks, expected_path):
def test_custom_json_encoder(obj, expected):
actual = json.dumps(
obj,
cls=app.CustomJSONEncoder,
cls=utils.CustomJSONEncoder,
sort_keys=True
)
assert expected == actual
@ -90,7 +90,7 @@ def test_custom_json_encoder(obj, expected):
),
])
def test_build_where(args, expected_where, expected_params):
actual_where, actual_params = app.build_where_clause(args)
actual_where, actual_params = utils.build_where_clause(args)
assert expected_where == actual_where
assert {
'p{}'.format(i): param
@ -104,8 +104,8 @@ def test_build_where(args, expected_where, expected_params):
"SELECT * FROM pragma_index_info('idx52')",
])
def test_validate_sql_select_bad(bad_sql):
with pytest.raises(app.InvalidSql):
app.validate_sql_select(bad_sql)
with pytest.raises(utils.InvalidSql):
utils.validate_sql_select(bad_sql)
@pytest.mark.parametrize('good_sql', [
@ -114,4 +114,4 @@ def test_validate_sql_select_bad(bad_sql):
'select 1 + 1',
])
def test_validate_sql_select_good(good_sql):
app.validate_sql_select(good_sql)
utils.validate_sql_select(good_sql)