Apply black to everything, enforce via unit tests (#449)

I've run the black code formatting tool against everything:

    black tests datasette setup.py

I also added a new unit test, in tests/test_black.py, which will fail if the code does not
conform to black's exacting standards.

This unit test only runs on Python 3.6 or higher, because black itself doesn't run on 3.5.
pull/450/head
Simon Willison 2019-05-03 22:15:14 -04:00 zatwierdzone przez GitHub
rodzic 66c87cee0c
commit 35d6ee2790
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
31 zmienionych plików z 2744 dodań i 2688 usunięć

Wyświetl plik

@ -1,3 +1,3 @@
from datasette.version import __version_info__, __version__ # noqa
from .hookspecs import hookimpl # noqa
from .hookspecs import hookspec # noqa
from .hookspecs import hookimpl # noqa
from .hookspecs import hookspec # noqa

Wyświetl plik

@ -1,4 +1,3 @@
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@ -58,17 +57,18 @@ HANDLERS = {}
def register_vcs_handler(vcs, method): # decorator
"""Decorator to mark a method as the handler for a particular VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
p = None
@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try:
dispcmd = str([c] + args)
# remember shell=False, so use git.cmd on windows, not just git
p = subprocess.Popen([c] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None))
p = subprocess.Popen(
[c] + args,
cwd=cwd,
env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr else None),
)
break
except EnvironmentError:
e = sys.exc_info()[1]
@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for i in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
return {
"version": dirname[len(parentdir_prefix) :],
"full-revisionid": None,
"dirty": False,
"error": None,
"date": None,
}
else:
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
(str(rootdirs), parentdir_prefix))
print(
"Tried directories %s but none started with prefix %s"
% (str(rootdirs), parentdir_prefix)
)
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
tags = set([r for r in refs if re.search(r"\d", r)])
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
r = ref[len(tag_prefix) :]
if verbose:
print("picking %s" % r)
return {"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None,
"date": date}
return {
"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False,
"error": None,
"date": date,
}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None}
return {
"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False,
"error": "no suitable tags",
"date": None,
}
@register_vcs_handler("git", "pieces_from_vcs")
@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=True)
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
"--always", "--long",
"--match", "%s*" % tag_prefix],
cwd=root)
describe_out, rc = run_command(
GITS,
[
"describe",
"--tags",
"--dirty",
"--always",
"--long",
"--match",
"%s*" % tag_prefix,
],
cwd=root,
)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[:git_describe.rindex("-dirty")]
git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
% (full_tag, tag_prefix))
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
full_tag,
tag_prefix,
)
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix):]
pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
cwd=root)
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
cwd=root)[0].strip()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[
0
].strip()
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
@ -330,8 +355,7 @@ def render_pep440(pieces):
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None}
return {
"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None,
}
if not style or style == "default":
style = "pep440" # the default
@ -469,9 +495,13 @@ def render(pieces, style):
else:
raise ValueError("unknown style '%s'" % style)
return {"version": rendered, "full-revisionid": pieces["long"],
"dirty": pieces["dirty"], "error": None,
"date": pieces.get("date")}
return {
"version": rendered,
"full-revisionid": pieces["long"],
"dirty": pieces["dirty"],
"error": None,
"date": pieces.get("date"),
}
def get_versions():
@ -485,8 +515,7 @@ def get_versions():
verbose = cfg.verbose
try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
verbose)
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
except NotThisMethod:
pass
@ -495,13 +524,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for i in cfg.versionfile_source.split('/'):
for i in cfg.versionfile_source.split("/"):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None}
return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None,
}
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@ -515,6 +547,10 @@ def get_versions():
except NotThisMethod:
pass
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to compute version", "date": None}
return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None,
"error": "unable to compute version",
"date": None,
}

Wyświetl plik

@ -17,10 +17,7 @@ from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound
from .views.base import (
DatasetteError,
ureg
)
from .views.base import DatasetteError, ureg
from .views.database import DatabaseDownload, DatabaseView
from .views.index import IndexView
from .views.special import JsonDataView
@ -39,7 +36,7 @@ from .utils import (
sqlite3,
sqlite_timelimit,
table_columns,
to_css_class
to_css_class,
)
from .inspect import inspect_hash, inspect_views, inspect_tables
from .tracer import capture_traces, trace
@ -51,72 +48,85 @@ app_root = Path(__file__).parent.parent
connections = threading.local()
MEMORY = object()
ConfigOption = collections.namedtuple(
"ConfigOption", ("name", "default", "help")
)
ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help"))
CONFIG_OPTIONS = (
ConfigOption("default_page_size", 100, """
Default page size for the table view
""".strip()),
ConfigOption("max_returned_rows", 1000, """
Maximum rows that can be returned from a table or custom query
""".strip()),
ConfigOption("num_sql_threads", 3, """
Number of threads in the thread pool for executing SQLite queries
""".strip()),
ConfigOption("sql_time_limit_ms", 1000, """
Time limit for a SQL query in milliseconds
""".strip()),
ConfigOption("default_facet_size", 30, """
Number of values to return for requested facets
""".strip()),
ConfigOption("facet_time_limit_ms", 200, """
Time limit for calculating a requested facet
""".strip()),
ConfigOption("facet_suggest_time_limit_ms", 50, """
Time limit for calculating a suggested facet
""".strip()),
ConfigOption("hash_urls", False, """
Include DB file contents hash in URLs, for far-future caching
""".strip()),
ConfigOption("allow_facet", True, """
Allow users to specify columns to facet using ?_facet= parameter
""".strip()),
ConfigOption("allow_download", True, """
Allow users to download the original SQLite database files
""".strip()),
ConfigOption("suggest_facets", True, """
Calculate and display suggested facets
""".strip()),
ConfigOption("allow_sql", True, """
Allow arbitrary SQL queries via ?sql= parameter
""".strip()),
ConfigOption("default_cache_ttl", 5, """
Default HTTP cache TTL (used in Cache-Control: max-age= header)
""".strip()),
ConfigOption("default_cache_ttl_hashed", 365 * 24 * 60 * 60, """
Default HTTP cache TTL for hashed URL pages
""".strip()),
ConfigOption("cache_size_kb", 0, """
SQLite cache size in KB (0 == use SQLite default)
""".strip()),
ConfigOption("allow_csv_stream", True, """
Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
""".strip()),
ConfigOption("max_csv_mb", 100, """
Maximum size allowed for CSV export in MB - set 0 to disable this limit
""".strip()),
ConfigOption("truncate_cells_html", 2048, """
Truncate cells longer than this in HTML table view - set 0 to disable
""".strip()),
ConfigOption("force_https_urls", False, """
Force URLs in API output to always use https:// protocol
""".strip()),
ConfigOption("default_page_size", 100, "Default page size for the table view"),
ConfigOption(
"max_returned_rows",
1000,
"Maximum rows that can be returned from a table or custom query",
),
ConfigOption(
"num_sql_threads",
3,
"Number of threads in the thread pool for executing SQLite queries",
),
ConfigOption(
"sql_time_limit_ms", 1000, "Time limit for a SQL query in milliseconds"
),
ConfigOption(
"default_facet_size", 30, "Number of values to return for requested facets"
),
ConfigOption(
"facet_time_limit_ms", 200, "Time limit for calculating a requested facet"
),
ConfigOption(
"facet_suggest_time_limit_ms",
50,
"Time limit for calculating a suggested facet",
),
ConfigOption(
"hash_urls",
False,
"Include DB file contents hash in URLs, for far-future caching",
),
ConfigOption(
"allow_facet",
True,
"Allow users to specify columns to facet using ?_facet= parameter",
),
ConfigOption(
"allow_download",
True,
"Allow users to download the original SQLite database files",
),
ConfigOption("suggest_facets", True, "Calculate and display suggested facets"),
ConfigOption("allow_sql", True, "Allow arbitrary SQL queries via ?sql= parameter"),
ConfigOption(
"default_cache_ttl",
5,
"Default HTTP cache TTL (used in Cache-Control: max-age= header)",
),
ConfigOption(
"default_cache_ttl_hashed",
365 * 24 * 60 * 60,
"Default HTTP cache TTL for hashed URL pages",
),
ConfigOption(
"cache_size_kb", 0, "SQLite cache size in KB (0 == use SQLite default)"
),
ConfigOption(
"allow_csv_stream",
True,
"Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)",
),
ConfigOption(
"max_csv_mb",
100,
"Maximum size allowed for CSV export in MB - set 0 to disable this limit",
),
ConfigOption(
"truncate_cells_html",
2048,
"Truncate cells longer than this in HTML table view - set 0 to disable",
),
ConfigOption(
"force_https_urls",
False,
"Force URLs in API output to always use https:// protocol",
),
)
DEFAULT_CONFIG = {
option.name: option.default
for option in CONFIG_OPTIONS
}
DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS}
async def favicon(request):
@ -151,11 +161,13 @@ class ConnectedDatabase:
counts = {}
for table in await self.table_names():
try:
table_count = (await self.ds.execute(
self.name,
"select count(*) from [{}]".format(table),
custom_time_limit=limit,
)).rows[0][0]
table_count = (
await self.ds.execute(
self.name,
"select count(*) from [{}]".format(table),
custom_time_limit=limit,
)
).rows[0][0]
counts[table] = table_count
except InterruptedError:
counts[table] = None
@ -175,18 +187,26 @@ class ConnectedDatabase:
return Path(self.path).stem
async def table_names(self):
results = await self.ds.execute(self.name, "select name from sqlite_master where type='table'")
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='table'"
)
return [r[0] for r in results.rows]
async def hidden_table_names(self):
# Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [r[0] for r in (
await self.ds.execute(self.name, """
hidden_tables = [
r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%'
""")
).rows]
""",
)
).rows
]
has_spatialite = await self.ds.execute_against_connection_in_thread(
self.name, detect_spatialite
)
@ -205,18 +225,23 @@ class ConnectedDatabase:
] + [
r[0]
for r in (
await self.ds.execute(self.name, """
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where name like "idx_%"
and type = "table"
""")
""",
)
).rows
]
# Add any from metadata.json
db_metadata = self.ds.metadata(database=self.name)
if "tables" in db_metadata:
hidden_tables += [
t for t in db_metadata["tables"] if db_metadata["tables"][t].get("hidden")
t
for t in db_metadata["tables"]
if db_metadata["tables"][t].get("hidden")
]
# Also mark as hidden any tables which start with the name of a hidden table
# e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
@ -229,7 +254,9 @@ class ConnectedDatabase:
return hidden_tables
async def view_names(self):
results = await self.ds.execute(self.name, "select name from sqlite_master where type='view'")
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='view'"
)
return [r[0] for r in results.rows]
def __repr__(self):
@ -245,13 +272,10 @@ class ConnectedDatabase:
tags_str = ""
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<ConnectedDatabase: {}{}>".format(
self.name, tags_str
)
return "<ConnectedDatabase: {}{}>".format(self.name, tags_str)
class Datasette:
def __init__(
self,
files,
@ -283,7 +307,9 @@ class Datasette:
path = None
is_memory = True
is_mutable = path not in self.immutables
db = ConnectedDatabase(self, path, is_mutable=is_mutable, is_memory=is_memory)
db = ConnectedDatabase(
self, path, is_mutable=is_mutable, is_memory=is_memory
)
if db.name in self.databases:
raise Exception("Multiple files with same stem: {}".format(db.name))
self.databases[db.name] = db
@ -322,26 +348,24 @@ class Datasette:
def config_dict(self):
# Returns a fully resolved config dictionary, useful for templates
return {
option.name: self.config(option.name)
for option in CONFIG_OPTIONS
}
return {option.name: self.config(option.name) for option in CONFIG_OPTIONS}
def metadata(self, key=None, database=None, table=None, fallback=True):
"""
Looks up metadata, cascading backwards from specified level.
Returns None if metadata value is not found.
"""
assert not (database is None and table is not None), \
"Cannot call metadata() with table= specified but not database="
assert not (
database is None and table is not None
), "Cannot call metadata() with table= specified but not database="
databases = self._metadata.get("databases") or {}
search_list = []
if database is not None:
search_list.append(databases.get(database) or {})
if table is not None:
table_metadata = (
(databases.get(database) or {}).get("tables") or {}
).get(table) or {}
table_metadata = ((databases.get(database) or {}).get("tables") or {}).get(
table
) or {}
search_list.insert(0, table_metadata)
search_list.append(self._metadata)
if not fallback:
@ -359,9 +383,7 @@ class Datasette:
m.update(item)
return m
def plugin_config(
self, plugin_name, database=None, table=None, fallback=True
):
def plugin_config(self, plugin_name, database=None, table=None, fallback=True):
"Return config for plugin, falling back from specified database/table"
plugins = self.metadata(
"plugins", database=database, table=table, fallback=fallback
@ -373,29 +395,19 @@ class Datasette:
def app_css_hash(self):
if not hasattr(self, "_app_css_hash"):
self._app_css_hash = hashlib.sha1(
open(
os.path.join(str(app_root), "datasette/static/app.css")
).read().encode(
"utf8"
)
).hexdigest()[
:6
]
open(os.path.join(str(app_root), "datasette/static/app.css"))
.read()
.encode("utf8")
).hexdigest()[:6]
return self._app_css_hash
def get_canned_queries(self, database_name):
queries = self.metadata(
"queries", database=database_name, fallback=False
) or {}
queries = self.metadata("queries", database=database_name, fallback=False) or {}
names = queries.keys()
return [
self.get_canned_query(database_name, name) for name in names
]
return [self.get_canned_query(database_name, name) for name in names]
def get_canned_query(self, database_name, query_name):
queries = self.metadata(
"queries", database=database_name, fallback=False
) or {}
queries = self.metadata("queries", database=database_name, fallback=False) or {}
query = queries.get(query_name)
if query:
if not isinstance(query, dict):
@ -407,7 +419,7 @@ class Datasette:
table_definition_rows = list(
await self.execute(
database_name,
'select sql from sqlite_master where name = :n and type=:t',
"select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_},
)
)
@ -416,21 +428,19 @@ class Datasette:
return table_definition_rows[0][0]
def get_view_definition(self, database_name, view):
return self.get_table_definition(database_name, view, 'view')
return self.get_table_definition(database_name, view, "view")
def update_with_inherited_metadata(self, metadata):
# Fills in source/license with defaults, if available
metadata.update(
{
"source": metadata.get("source") or self.metadata("source"),
"source_url": metadata.get("source_url")
or self.metadata("source_url"),
"source_url": metadata.get("source_url") or self.metadata("source_url"),
"license": metadata.get("license") or self.metadata("license"),
"license_url": metadata.get("license_url")
or self.metadata("license_url"),
"about": metadata.get("about") or self.metadata("about"),
"about_url": metadata.get("about_url")
or self.metadata("about_url"),
"about_url": metadata.get("about_url") or self.metadata("about_url"),
}
)
@ -444,7 +454,7 @@ class Datasette:
for extension in self.sqlite_extensions:
conn.execute("SELECT load_extension('{}')".format(extension))
if self.config("cache_size_kb"):
conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb")))
conn.execute("PRAGMA cache_size=-{}".format(self.config("cache_size_kb")))
# pylint: disable=no-member
pm.hook.prepare_connection(conn=conn)
@ -452,7 +462,7 @@ class Datasette:
results = await self.execute(
database,
"select 1 from sqlite_master where type='table' and name=?",
params=(table,)
params=(table,),
)
return bool(results.rows)
@ -463,32 +473,28 @@ class Datasette:
# Find the foreign_key for this column
try:
fk = [
foreign_key for foreign_key in foreign_keys
foreign_key
for foreign_key in foreign_keys
if foreign_key["column"] == column
][0]
except IndexError:
return {}
label_column = await self.label_column_for_table(database, fk["other_table"])
if not label_column:
return {
(fk["column"], value): str(value)
for value in values
}
return {(fk["column"], value): str(value) for value in values}
labeled_fks = {}
sql = '''
sql = """
select {other_column}, {label_column}
from {other_table}
where {other_column} in ({placeholders})
'''.format(
""".format(
other_column=escape_sqlite(fk["other_column"]),
label_column=escape_sqlite(label_column),
other_table=escape_sqlite(fk["other_table"]),
placeholders=", ".join(["?"] * len(set(values))),
)
try:
results = await self.execute(
database, sql, list(set(values))
)
results = await self.execute(database, sql, list(set(values)))
except InterruptedError:
pass
else:
@ -499,7 +505,7 @@ class Datasette:
def absolute_url(self, request, path):
url = urllib.parse.urljoin(request.url, path)
if url.startswith("http://") and self.config("force_https_urls"):
url = "https://" + url[len("http://"):]
url = "https://" + url[len("http://") :]
return url
def inspect(self):
@ -532,10 +538,12 @@ class Datasette:
"file": str(path),
"size": path.stat().st_size,
"views": inspect_views(conn),
"tables": inspect_tables(conn, (self.metadata("databases") or {}).get(name, {}))
"tables": inspect_tables(
conn, (self.metadata("databases") or {}).get(name, {})
),
}
except sqlite3.OperationalError as e:
if (e.args[0] == 'no such module: VirtualSpatialIndex'):
if e.args[0] == "no such module: VirtualSpatialIndex":
raise click.UsageError(
"It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module."
@ -582,7 +590,8 @@ class Datasette:
datasette_version["note"] = self.version_note
return {
"python": {
"version": ".".join(map(str, sys.version_info[:3])), "full": sys.version
"version": ".".join(map(str, sys.version_info[:3])),
"full": sys.version,
},
"datasette": datasette_version,
"sqlite": {
@ -611,10 +620,11 @@ class Datasette:
def table_metadata(self, database, table):
"Fetch table-specific metadata."
return (self.metadata("databases") or {}).get(database, {}).get(
"tables", {}
).get(
table, {}
return (
(self.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {})
)
async def table_columns(self, db_name, table):
@ -628,16 +638,12 @@ class Datasette:
)
async def label_column_for_table(self, db_name, table):
explicit_label_column = (
self.table_metadata(
db_name, table
).get("label_column")
)
explicit_label_column = self.table_metadata(db_name, table).get("label_column")
if explicit_label_column:
return explicit_label_column
# If a table has two columns, one of which is ID, then label_column is the other one
column_names = await self.table_columns(db_name, table)
if (column_names and len(column_names) == 2 and "id" in column_names):
if column_names and len(column_names) == 2 and "id" in column_names:
return [c for c in column_names if c != "id"][0]
# Couldn't find a label:
return None
@ -664,9 +670,7 @@ class Datasette:
setattr(connections, db_name, conn)
return fn(conn)
return await asyncio.get_event_loop().run_in_executor(
self.executor, in_thread
)
return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread)
async def execute(
self,
@ -701,7 +705,7 @@ class Datasette:
rows = cursor.fetchall()
truncated = False
except sqlite3.OperationalError as e:
if e.args == ('interrupted',):
if e.args == ("interrupted",):
raise InterruptedError(e, sql, params)
if log_sql_errors:
print(
@ -726,7 +730,7 @@ class Datasette:
def register_renderers(self):
""" Register output renderers which output data in custom formats. """
# Built-in renderers
self.renderers['json'] = json_renderer
self.renderers["json"] = json_renderer
# Hooks
hook_renderers = []
@ -737,19 +741,22 @@ class Datasette:
hook_renderers.append(hook)
for renderer in hook_renderers:
self.renderers[renderer['extension']] = renderer['callback']
self.renderers[renderer["extension"]] = renderer["callback"]
def app(self):
class TracingSanic(Sanic):
async def handle_request(self, request, write_callback, stream_callback):
if request.args.get("_trace"):
request["traces"] = []
request["trace_start"] = time.time()
with capture_traces(request["traces"]):
res = await super().handle_request(request, write_callback, stream_callback)
res = await super().handle_request(
request, write_callback, stream_callback
)
else:
res = await super().handle_request(request, write_callback, stream_callback)
res = await super().handle_request(
request, write_callback, stream_callback
)
return res
app = TracingSanic(__name__)
@ -822,15 +829,16 @@ class Datasette:
)
app.add_route(
DatabaseView.as_view(self),
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>"
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>",
)
app.add_route(
TableView.as_view(self),
r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>",
TableView.as_view(self), r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>"
)
app.add_route(
RowView.as_view(self),
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:(" + renderer_regex + r")?$>",
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:("
+ renderer_regex
+ r")?$>",
)
self.register_custom_units()
@ -852,7 +860,7 @@ class Datasette:
"duration": time.time() - request["trace_start"],
"queries": request["traces"],
}
if "text/html" in response.content_type and b'</body>' in response.body:
if "text/html" in response.content_type and b"</body>" in response.body:
extra = json.dumps(traces, indent=2)
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
response.body = response.body.replace(b"</body>", extra_html)
@ -908,6 +916,6 @@ class Datasette:
async def setup_db(app, loop):
for dbname, database in self.databases.items():
if not database.is_mutable:
await database.table_counts(limit=60*60*1000)
await database.table_counts(limit=60 * 60 * 1000)
return app

Wyświetl plik

@ -20,16 +20,14 @@ class Config(click.ParamType):
def convert(self, config, param, ctx):
if ":" not in config:
self.fail(
'"{}" should be name:value'.format(config), param, ctx
)
self.fail('"{}" should be name:value'.format(config), param, ctx)
return
name, value = config.split(":")
if name not in DEFAULT_CONFIG:
self.fail(
"{} is not a valid option (--help-config to see all)".format(
name
), param, ctx
"{} is not a valid option (--help-config to see all)".format(name),
param,
ctx,
)
return
# Type checking
@ -44,14 +42,12 @@ class Config(click.ParamType):
return
elif isinstance(default, int):
if not value.isdigit():
self.fail(
'"{}" should be an integer'.format(name), param, ctx
)
self.fail('"{}" should be an integer'.format(name), param, ctx)
return
return name, int(value)
else:
# Should never happen:
self.fail('Invalid option')
self.fail("Invalid option")
@click.group(cls=DefaultGroup, default="serve", default_if_no_args=True)
@ -204,13 +200,9 @@ def plugins(all, plugins_dir):
multiple=True,
)
@click.option(
"--install",
help="Additional packages (e.g. plugins) to install",
multiple=True,
)
@click.option(
"--spatialite", is_flag=True, help="Enable SpatialLite extension"
"--install", help="Additional packages (e.g. plugins) to install", multiple=True
)
@click.option("--spatialite", is_flag=True, help="Enable SpatialLite extension")
@click.option("--version-note", help="Additional note to show on /-/versions")
@click.option("--title", help="Title for metadata")
@click.option("--license", help="License label for metadata")
@ -322,9 +314,7 @@ def package(
help="mountpoint:path-to-directory for serving static files",
multiple=True,
)
@click.option(
"--memory", is_flag=True, help="Make :memory: database available"
)
@click.option("--memory", is_flag=True, help="Make :memory: database available")
@click.option(
"--config",
type=Config(),
@ -332,11 +322,7 @@ def package(
multiple=True,
)
@click.option("--version-note", help="Additional note to show on /-/versions")
@click.option(
"--help-config",
is_flag=True,
help="Show available config options",
)
@click.option("--help-config", is_flag=True, help="Show available config options")
def serve(
files,
immutable,
@ -360,12 +346,12 @@ def serve(
if help_config:
formatter = formatting.HelpFormatter()
with formatter.section("Config options"):
formatter.write_dl([
(option.name, '{} (default={})'.format(
option.help, option.default
))
for option in CONFIG_OPTIONS
])
formatter.write_dl(
[
(option.name, "{} (default={})".format(option.help, option.default))
for option in CONFIG_OPTIONS
]
)
click.echo(formatter.getvalue())
sys.exit(0)
if reload:
@ -384,7 +370,9 @@ def serve(
if metadata:
metadata_data = json.loads(metadata.read())
click.echo("Serve! files={} (immutables={}) on port {}".format(files, immutable, port))
click.echo(
"Serve! files={} (immutables={}) on port {}".format(files, immutable, port)
)
ds = Datasette(
files,
immutables=immutable,

Wyświetl plik

@ -31,14 +31,15 @@ def load_facet_configs(request, table_metadata):
metadata_config = {"simple": metadata_config}
else:
# This should have a single key and a single value
assert len(metadata_config.values()) == 1, "Metadata config dicts should be {type: config}"
assert (
len(metadata_config.values()) == 1
), "Metadata config dicts should be {type: config}"
type, metadata_config = metadata_config.items()[0]
if isinstance(metadata_config, str):
metadata_config = {"simple": metadata_config}
facet_configs.setdefault(type, []).append({
"source": "metadata",
"config": metadata_config
})
facet_configs.setdefault(type, []).append(
{"source": "metadata", "config": metadata_config}
)
qs_pairs = urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
for key, values in qs_pairs.items():
if key.startswith("_facet"):
@ -53,10 +54,9 @@ def load_facet_configs(request, table_metadata):
config = json.loads(value)
else:
config = {"simple": value}
facet_configs.setdefault(type, []).append({
"source": "request",
"config": config
})
facet_configs.setdefault(type, []).append(
{"source": "request", "config": config}
)
return facet_configs
@ -214,7 +214,9 @@ class ColumnFacet(Facet):
"name": column,
"type": self.type,
"hideable": source != "metadata",
"toggle_url": path_with_removed_args(self.request, {"_facet": column}),
"toggle_url": path_with_removed_args(
self.request, {"_facet": column}
),
"results": facet_results_values,
"truncated": len(facet_rows_results) > facet_size,
}
@ -269,30 +271,31 @@ class ArrayFacet(Facet):
select distinct json_type({column})
from ({sql})
""".format(
column=escape_sqlite(column),
sql=self.sql,
column=escape_sqlite(column), sql=self.sql
)
try:
results = await self.ds.execute(
self.database, suggested_facet_sql, self.params,
self.database,
suggested_facet_sql,
self.params,
truncate=False,
custom_time_limit=self.ds.config("facet_suggest_time_limit_ms"),
log_sql_errors=False,
)
types = tuple(r[0] for r in results.rows)
if types in (
("array",),
("array", None)
):
suggested_facets.append({
"name": column,
"type": "array",
"toggle_url": self.ds.absolute_url(
self.request, path_with_added_args(
self.request, {"_facet_array": column}
)
),
})
if types in (("array",), ("array", None)):
suggested_facets.append(
{
"name": column,
"type": "array",
"toggle_url": self.ds.absolute_url(
self.request,
path_with_added_args(
self.request, {"_facet_array": column}
),
),
}
)
except (InterruptedError, sqlite3.OperationalError):
continue
return suggested_facets
@ -314,13 +317,13 @@ class ArrayFacet(Facet):
) join json_each({col}) j
group by j.value order by count desc limit {limit}
""".format(
col=escape_sqlite(column),
sql=self.sql,
limit=facet_size+1,
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database, facet_sql, self.params,
self.database,
facet_sql,
self.params,
truncate=False,
custom_time_limit=self.ds.config("facet_time_limit_ms"),
)
@ -330,7 +333,9 @@ class ArrayFacet(Facet):
"type": self.type,
"results": facet_results_values,
"hideable": source != "metadata",
"toggle_url": path_with_removed_args(self.request, {"_facet_array": column}),
"toggle_url": path_with_removed_args(
self.request, {"_facet_array": column}
),
"truncated": len(facet_rows_results) > facet_size,
}
facet_rows = facet_rows_results.rows[:facet_size]
@ -346,13 +351,17 @@ class ArrayFacet(Facet):
toggle_path = path_with_added_args(
self.request, {"{}__arraycontains".format(column): value}
)
facet_results_values.append({
"value": value,
"label": value,
"count": row["count"],
"toggle_url": self.ds.absolute_url(self.request, toggle_path),
"selected": selected,
})
facet_results_values.append(
{
"value": value,
"label": value,
"count": row["count"],
"toggle_url": self.ds.absolute_url(
self.request, toggle_path
),
"selected": selected,
}
)
except InterruptedError:
facets_timed_out.append(column)

Wyświetl plik

@ -1,10 +1,7 @@
import json
import numbers
from .utils import (
detect_json1,
escape_sqlite,
)
from .utils import detect_json1, escape_sqlite
class Filter:
@ -20,7 +17,16 @@ class Filter:
class TemplatedFilter(Filter):
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
def __init__(
self,
key,
display,
sql_template,
human_template,
format="{}",
numeric=False,
no_argument=False,
):
self.key = key
self.display = display
self.sql_template = sql_template
@ -34,16 +40,10 @@ class TemplatedFilter(Filter):
if self.numeric and converted.isdigit():
converted = int(converted)
if self.no_argument:
kwargs = {
'c': column,
}
kwargs = {"c": column}
converted = None
else:
kwargs = {
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
kwargs = {"c": column, "p": "p{}".format(param_counter), "t": table}
return self.sql_template.format(**kwargs), converted
def human_clause(self, column, value):
@ -58,8 +58,8 @@ class TemplatedFilter(Filter):
class InFilter(Filter):
key = 'in'
display = 'in'
key = "in"
display = "in"
def __init__(self):
pass
@ -81,34 +81,98 @@ class InFilter(Filter):
class Filters:
_filters = [
# key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
InFilter(),
] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in (
_filters = (
[
# key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter(
"exact",
"=",
'"{c}" = :{p}',
lambda c, v: "{c} = {v}" if v.isdigit() else '{c} = "{v}"',
),
TemplatedFilter(
"not",
"!=",
'"{c}" != :{p}',
lambda c, v: "{c} != {v}" if v.isdigit() else '{c} != "{v}"',
),
TemplatedFilter(
"contains",
"contains",
'"{c}" like :{p}',
'{c} contains "{v}"',
format="%{}%",
),
TemplatedFilter(
"endswith",
"ends with",
'"{c}" like :{p}',
'{c} ends with "{v}"',
format="%{}",
),
TemplatedFilter(
"startswith",
"starts with",
'"{c}" like :{p}',
'{c} starts with "{v}"',
format="{}%",
),
TemplatedFilter("gt", ">", '"{c}" > :{p}', "{c} > {v}", numeric=True),
TemplatedFilter(
"gte", "\u2265", '"{c}" >= :{p}', "{c} \u2265 {v}", numeric=True
),
TemplatedFilter("lt", "<", '"{c}" < :{p}', "{c} < {v}", numeric=True),
TemplatedFilter(
"lte", "\u2264", '"{c}" <= :{p}', "{c} \u2264 {v}", numeric=True
),
TemplatedFilter("like", "like", '"{c}" like :{p}', '{c} like "{v}"'),
TemplatedFilter("glob", "glob", '"{c}" glob :{p}', '{c} glob "{v}"'),
InFilter(),
]
+ (
[
TemplatedFilter(
"arraycontains",
"array contains",
"""rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p}
)""", '{c} contains "{v}"')
] if detect_json1() else []) + [
TemplatedFilter('date', 'date', 'date({c}) = :{p}', '"{c}" is on date {v}'),
TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
]
_filters_by_key = {
f.key: f for f in _filters
}
)""",
'{c} contains "{v}"',
)
]
if detect_json1()
else []
)
+ [
TemplatedFilter("date", "date", "date({c}) = :{p}", '"{c}" is on date {v}'),
TemplatedFilter(
"isnull", "is null", '"{c}" is null', "{c} is null", no_argument=True
),
TemplatedFilter(
"notnull",
"is not null",
'"{c}" is not null',
"{c} is not null",
no_argument=True,
),
TemplatedFilter(
"isblank",
"is blank",
'("{c}" is null or "{c}" = "")',
"{c} is blank",
no_argument=True,
),
TemplatedFilter(
"notblank",
"is not blank",
'("{c}" is not null and "{c}" != "")',
"{c} is not blank",
no_argument=True,
),
]
)
_filters_by_key = {f.key: f for f in _filters}
def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs
@ -132,22 +196,22 @@ class Filters:
and_bits = []
commas, tail = bits[:-1], bits[-1:]
if commas:
and_bits.append(', '.join(commas))
and_bits.append(", ".join(commas))
if tail:
and_bits.append(tail[0])
s = ' and '.join(and_bits)
s = " and ".join(and_bits)
if not s:
return ''
return 'where {}'.format(s)
return ""
return "where {}".format(s)
def selections(self):
"Yields (column, lookup, value) tuples"
for key, value in self.pairs:
if '__' in key:
column, lookup = key.rsplit('__', 1)
if "__" in key:
column, lookup = key.rsplit("__", 1)
else:
column = key
lookup = 'exact'
lookup = "exact"
yield column, lookup, value
def has_selections(self):
@ -174,13 +238,15 @@ class Filters:
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
sql_bit, param = filter.where_clause(
table, column, self.convert_unit(column, value), i
)
sql_bits.append(sql_bit)
if param is not None:
if not isinstance(param, list):
param = [param]
for individual_param in param:
param_id = 'p{}'.format(i)
param_id = "p{}".format(i)
params[param_id] = individual_param
i += 1
return sql_bits, params

Wyświetl plik

@ -7,7 +7,7 @@ from .utils import (
escape_sqlite,
get_all_foreign_keys,
table_columns,
sqlite3
sqlite3,
)
@ -29,7 +29,9 @@ def inspect_hash(path):
def inspect_views(conn):
" List views in a database. "
return [v[0] for v in conn.execute('select name from sqlite_master where type = "view"')]
return [
v[0] for v in conn.execute('select name from sqlite_master where type = "view"')
]
def inspect_tables(conn, database_metadata):
@ -37,15 +39,11 @@ def inspect_tables(conn, database_metadata):
tables = {}
table_names = [
r["name"]
for r in conn.execute(
'select * from sqlite_master where type="table"'
)
for r in conn.execute('select * from sqlite_master where type="table"')
]
for table in table_names:
table_metadata = database_metadata.get("tables", {}).get(
table, {}
)
table_metadata = database_metadata.get("tables", {}).get(table, {})
try:
count = conn.execute(

Wyświetl plik

@ -41,8 +41,12 @@ def publish_subcommand(publish):
name,
spatialite,
):
fail_if_publish_binary_not_installed("gcloud", "Google Cloud", "https://cloud.google.com/sdk/")
project = check_output("gcloud config get-value project", shell=True, universal_newlines=True).strip()
fail_if_publish_binary_not_installed(
"gcloud", "Google Cloud", "https://cloud.google.com/sdk/"
)
project = check_output(
"gcloud config get-value project", shell=True, universal_newlines=True
).strip()
with temporary_docker_directory(
files,
@ -68,4 +72,9 @@ def publish_subcommand(publish):
):
image_id = "gcr.io/{project}/{name}".format(project=project, name=name)
check_call("gcloud builds submit --tag {}".format(image_id), shell=True)
check_call("gcloud beta run deploy --allow-unauthenticated --image {}".format(image_id), shell=True)
check_call(
"gcloud beta run deploy --allow-unauthenticated --image {}".format(
image_id
),
shell=True,
)

Wyświetl plik

@ -5,46 +5,54 @@ import sys
def add_common_publish_arguments_and_options(subcommand):
for decorator in reversed((
click.argument("files", type=click.Path(exists=True), nargs=-1),
click.option(
"-m",
"--metadata",
type=click.File(mode="r"),
help="Path to JSON file containing metadata to publish",
),
click.option("--extra-options", help="Extra options to pass to datasette serve"),
click.option("--branch", help="Install datasette from a GitHub branch e.g. master"),
click.option(
"--template-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom templates",
),
click.option(
"--plugins-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom plugins",
),
click.option(
"--static",
type=StaticMount(),
help="mountpoint:path-to-directory for serving static files",
multiple=True,
),
click.option(
"--install",
help="Additional packages (e.g. plugins) to install",
multiple=True,
),
click.option("--version-note", help="Additional note to show on /-/versions"),
click.option("--title", help="Title for metadata"),
click.option("--license", help="License label for metadata"),
click.option("--license_url", help="License URL for metadata"),
click.option("--source", help="Source label for metadata"),
click.option("--source_url", help="Source URL for metadata"),
click.option("--about", help="About label for metadata"),
click.option("--about_url", help="About URL for metadata"),
)):
for decorator in reversed(
(
click.argument("files", type=click.Path(exists=True), nargs=-1),
click.option(
"-m",
"--metadata",
type=click.File(mode="r"),
help="Path to JSON file containing metadata to publish",
),
click.option(
"--extra-options", help="Extra options to pass to datasette serve"
),
click.option(
"--branch", help="Install datasette from a GitHub branch e.g. master"
),
click.option(
"--template-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom templates",
),
click.option(
"--plugins-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom plugins",
),
click.option(
"--static",
type=StaticMount(),
help="mountpoint:path-to-directory for serving static files",
multiple=True,
),
click.option(
"--install",
help="Additional packages (e.g. plugins) to install",
multiple=True,
),
click.option(
"--version-note", help="Additional note to show on /-/versions"
),
click.option("--title", help="Title for metadata"),
click.option("--license", help="License label for metadata"),
click.option("--license_url", help="License URL for metadata"),
click.option("--source", help="Source label for metadata"),
click.option("--source_url", help="Source URL for metadata"),
click.option("--about", help="About label for metadata"),
click.option("--about_url", help="About URL for metadata"),
)
):
subcommand = decorator(subcommand)
return subcommand

Wyświetl plik

@ -76,9 +76,7 @@ def publish_subcommand(publish):
"about_url": about_url,
},
):
now_json = {
"version": 1
}
now_json = {"version": 1}
if alias:
now_json["alias"] = alias
open("now.json", "w").write(json.dumps(now_json))

Wyświetl plik

@ -89,8 +89,4 @@ def json_renderer(args, data, view_name):
else:
body = json.dumps(data, cls=CustomJSONEncoder)
content_type = "application/json"
return {
"body": body,
"status_code": status_code,
"content_type": content_type
}
return {"body": body, "status_code": status_code, "content_type": content_type}

Wyświetl plik

@ -21,27 +21,29 @@ except ImportError:
import sqlite3
# From https://www.sqlite.org/lang_keywords.html
reserved_words = set((
'abort action add after all alter analyze and as asc attach autoincrement '
'before begin between by cascade case cast check collate column commit '
'conflict constraint create cross current_date current_time '
'current_timestamp database default deferrable deferred delete desc detach '
'distinct drop each else end escape except exclusive exists explain fail '
'for foreign from full glob group having if ignore immediate in index '
'indexed initially inner insert instead intersect into is isnull join key '
'left like limit match natural no not notnull null of offset on or order '
'outer plan pragma primary query raise recursive references regexp reindex '
'release rename replace restrict right rollback row savepoint select set '
'table temp temporary then to transaction trigger union unique update using '
'vacuum values view virtual when where with without'
).split())
reserved_words = set(
(
"abort action add after all alter analyze and as asc attach autoincrement "
"before begin between by cascade case cast check collate column commit "
"conflict constraint create cross current_date current_time "
"current_timestamp database default deferrable deferred delete desc detach "
"distinct drop each else end escape except exclusive exists explain fail "
"for foreign from full glob group having if ignore immediate in index "
"indexed initially inner insert instead intersect into is isnull join key "
"left like limit match natural no not notnull null of offset on or order "
"outer plan pragma primary query raise recursive references regexp reindex "
"release rename replace restrict right rollback row savepoint select set "
"table temp temporary then to transaction trigger union unique update using "
"vacuum values view virtual when where with without"
).split()
)
SPATIALITE_DOCKERFILE_EXTRAS = r'''
SPATIALITE_DOCKERFILE_EXTRAS = r"""
RUN apt-get update && \
apt-get install -y python3-dev gcc libsqlite3-mod-spatialite && \
rm -rf /var/lib/apt/lists/*
ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so
'''
"""
class InterruptedError(Exception):
@ -67,27 +69,24 @@ class Results:
def urlsafe_components(token):
"Splits token on commas and URL decodes each component"
return [
urllib.parse.unquote_plus(b) for b in token.split(',')
]
return [urllib.parse.unquote_plus(b) for b in token.split(",")]
def path_from_row_pks(row, pks, use_rowid, quote=True):
""" Generate an optionally URL-quoted unique identifier
for a row from its primary keys."""
if use_rowid:
bits = [row['rowid']]
bits = [row["rowid"]]
else:
bits = [
row[pk]["value"] if isinstance(row[pk], dict) else row[pk]
for pk in pks
row[pk]["value"] if isinstance(row[pk], dict) else row[pk] for pk in pks
]
if quote:
bits = [urllib.parse.quote_plus(str(bit)) for bit in bits]
else:
bits = [str(bit) for bit in bits]
return ','.join(bits)
return ",".join(bits)
def compound_keys_after_sql(pks, start_index=0):
@ -106,16 +105,17 @@ def compound_keys_after_sql(pks, start_index=0):
and_clauses = []
last = pks_left[-1]
rest = pks_left[:-1]
and_clauses = ['{} = :p{}'.format(
escape_sqlite(pk), (i + start_index)
) for i, pk in enumerate(rest)]
and_clauses.append('{} > :p{}'.format(
escape_sqlite(last), (len(rest) + start_index)
))
or_clauses.append('({})'.format(' and '.join(and_clauses)))
and_clauses = [
"{} = :p{}".format(escape_sqlite(pk), (i + start_index))
for i, pk in enumerate(rest)
]
and_clauses.append(
"{} > :p{}".format(escape_sqlite(last), (len(rest) + start_index))
)
or_clauses.append("({})".format(" and ".join(and_clauses)))
pks_left.pop()
or_clauses.reverse()
return '({})'.format('\n or\n'.join(or_clauses))
return "({})".format("\n or\n".join(or_clauses))
class CustomJSONEncoder(json.JSONEncoder):
@ -127,11 +127,11 @@ class CustomJSONEncoder(json.JSONEncoder):
if isinstance(obj, bytes):
# Does it encode to utf8?
try:
return obj.decode('utf8')
return obj.decode("utf8")
except UnicodeDecodeError:
return {
'$base64': True,
'encoded': base64.b64encode(obj).decode('latin1'),
"$base64": True,
"encoded": base64.b64encode(obj).decode("latin1"),
}
return json.JSONEncoder.default(self, obj)
@ -163,20 +163,18 @@ class InvalidSql(Exception):
allowed_sql_res = [
re.compile(r'^select\b'),
re.compile(r'^explain select\b'),
re.compile(r'^explain query plan select\b'),
re.compile(r'^with\b'),
]
disallawed_sql_res = [
(re.compile('pragma'), 'Statement may not contain PRAGMA'),
re.compile(r"^select\b"),
re.compile(r"^explain select\b"),
re.compile(r"^explain query plan select\b"),
re.compile(r"^with\b"),
]
disallawed_sql_res = [(re.compile("pragma"), "Statement may not contain PRAGMA")]
def validate_sql_select(sql):
sql = sql.strip().lower()
if not any(r.match(sql) for r in allowed_sql_res):
raise InvalidSql('Statement must be a SELECT')
raise InvalidSql("Statement must be a SELECT")
for r, msg in disallawed_sql_res:
if r.search(sql):
raise InvalidSql(msg)
@ -184,9 +182,7 @@ def validate_sql_select(sql):
def append_querystring(url, querystring):
op = "&" if ("?" in url) else "?"
return "{}{}{}".format(
url, op, querystring
)
return "{}{}{}".format(url, op, querystring)
def path_with_added_args(request, args, path=None):
@ -198,14 +194,10 @@ def path_with_added_args(request, args, path=None):
for key, value in urllib.parse.parse_qsl(request.query_string):
if key not in args_to_remove:
current.append((key, value))
current.extend([
(key, value)
for key, value in args
if value is not None
])
current.extend([(key, value) for key, value in args if value is not None])
query_string = urllib.parse.urlencode(current)
if query_string:
query_string = '?{}'.format(query_string)
query_string = "?{}".format(query_string)
return path + query_string
@ -220,18 +212,21 @@ def path_with_removed_args(request, args, path=None):
# args can be a dict or a set
current = []
if isinstance(args, set):
def should_remove(key, value):
return key in args
elif isinstance(args, dict):
# Must match key AND value
def should_remove(key, value):
return args.get(key) == value
for key, value in urllib.parse.parse_qsl(query_string):
if not should_remove(key, value):
current.append((key, value))
query_string = urllib.parse.urlencode(current)
if query_string:
query_string = '?{}'.format(query_string)
query_string = "?{}".format(query_string)
return path + query_string
@ -247,54 +242,66 @@ def path_with_replaced_args(request, args, path=None):
current.extend([p for p in args if p[1] is not None])
query_string = urllib.parse.urlencode(current)
if query_string:
query_string = '?{}'.format(query_string)
query_string = "?{}".format(query_string)
return path + query_string
_css_re = re.compile(r'''['"\n\\]''')
_boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
_css_re = re.compile(r"""['"\n\\]""")
_boring_keyword_re = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def escape_css_string(s):
return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s)
return _css_re.sub(lambda m: "\\{:X}".format(ord(m.group())), s)
def escape_sqlite(s):
if _boring_keyword_re.match(s) and (s.lower() not in reserved_words):
return s
else:
return '[{}]'.format(s)
return "[{}]".format(s)
def make_dockerfile(files, metadata_file, extra_options, branch, template_dir, plugins_dir, static, install, spatialite, version_note):
cmd = ['datasette', 'serve', '--host', '0.0.0.0']
def make_dockerfile(
files,
metadata_file,
extra_options,
branch,
template_dir,
plugins_dir,
static,
install,
spatialite,
version_note,
):
cmd = ["datasette", "serve", "--host", "0.0.0.0"]
cmd.append('", "'.join(files))
cmd.extend(['--cors', '--inspect-file', 'inspect-data.json'])
cmd.extend(["--cors", "--inspect-file", "inspect-data.json"])
if metadata_file:
cmd.extend(['--metadata', '{}'.format(metadata_file)])
cmd.extend(["--metadata", "{}".format(metadata_file)])
if template_dir:
cmd.extend(['--template-dir', 'templates/'])
cmd.extend(["--template-dir", "templates/"])
if plugins_dir:
cmd.extend(['--plugins-dir', 'plugins/'])
cmd.extend(["--plugins-dir", "plugins/"])
if version_note:
cmd.extend(['--version-note', '{}'.format(version_note)])
cmd.extend(["--version-note", "{}".format(version_note)])
if static:
for mount_point, _ in static:
cmd.extend(['--static', '{}:{}'.format(mount_point, mount_point)])
cmd.extend(["--static", "{}:{}".format(mount_point, mount_point)])
if extra_options:
for opt in extra_options.split():
cmd.append('{}'.format(opt))
cmd.append("{}".format(opt))
cmd = [shlex.quote(part) for part in cmd]
# port attribute is a (fixed) env variable and should not be quoted
cmd.extend(['--port', '$PORT'])
cmd = ' '.join(cmd)
cmd.extend(["--port", "$PORT"])
cmd = " ".join(cmd)
if branch:
install = ['https://github.com/simonw/datasette/archive/{}.zip'.format(
branch
)] + list(install)
install = [
"https://github.com/simonw/datasette/archive/{}.zip".format(branch)
] + list(install)
else:
install = ['datasette'] + list(install)
install = ["datasette"] + list(install)
return '''
return """
FROM python:3.6
COPY . /app
WORKDIR /app
@ -303,11 +310,11 @@ RUN pip install -U {install_from}
RUN datasette inspect {files} --inspect-file inspect-data.json
ENV PORT 8001
EXPOSE 8001
CMD {cmd}'''.format(
files=' '.join(files),
CMD {cmd}""".format(
files=" ".join(files),
cmd=cmd,
install_from=' '.join(install),
spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else '',
install_from=" ".join(install),
spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else "",
).strip()
@ -324,7 +331,7 @@ def temporary_docker_directory(
install,
spatialite,
version_note,
extra_metadata=None
extra_metadata=None,
):
extra_metadata = extra_metadata or {}
tmp = tempfile.TemporaryDirectory()
@ -332,10 +339,7 @@ def temporary_docker_directory(
datasette_dir = os.path.join(tmp.name, name)
os.mkdir(datasette_dir)
saved_cwd = os.getcwd()
file_paths = [
os.path.join(saved_cwd, file_path)
for file_path in files
]
file_paths = [os.path.join(saved_cwd, file_path) for file_path in files]
file_names = [os.path.split(f)[-1] for f in files]
if metadata:
metadata_content = json.load(metadata)
@ -347,7 +351,7 @@ def temporary_docker_directory(
try:
dockerfile = make_dockerfile(
file_names,
metadata_content and 'metadata.json',
metadata_content and "metadata.json",
extra_options,
branch,
template_dir,
@ -359,24 +363,23 @@ def temporary_docker_directory(
)
os.chdir(datasette_dir)
if metadata_content:
open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2))
open('Dockerfile', 'w').write(dockerfile)
open("metadata.json", "w").write(json.dumps(metadata_content, indent=2))
open("Dockerfile", "w").write(dockerfile)
for path, filename in zip(file_paths, file_names):
link_or_copy(path, os.path.join(datasette_dir, filename))
if template_dir:
link_or_copy_directory(
os.path.join(saved_cwd, template_dir),
os.path.join(datasette_dir, 'templates')
os.path.join(datasette_dir, "templates"),
)
if plugins_dir:
link_or_copy_directory(
os.path.join(saved_cwd, plugins_dir),
os.path.join(datasette_dir, 'plugins')
os.path.join(datasette_dir, "plugins"),
)
for mount_point, path in static:
link_or_copy_directory(
os.path.join(saved_cwd, path),
os.path.join(datasette_dir, mount_point)
os.path.join(saved_cwd, path), os.path.join(datasette_dir, mount_point)
)
yield datasette_dir
finally:
@ -396,7 +399,7 @@ def temporary_heroku_directory(
static,
install,
version_note,
extra_metadata=None
extra_metadata=None,
):
# FIXME: lots of duplicated code from above
@ -404,10 +407,7 @@ def temporary_heroku_directory(
tmp = tempfile.TemporaryDirectory()
saved_cwd = os.getcwd()
file_paths = [
os.path.join(saved_cwd, file_path)
for file_path in files
]
file_paths = [os.path.join(saved_cwd, file_path) for file_path in files]
file_names = [os.path.split(f)[-1] for f in files]
if metadata:
@ -422,53 +422,54 @@ def temporary_heroku_directory(
os.chdir(tmp.name)
if metadata_content:
open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2))
open("metadata.json", "w").write(json.dumps(metadata_content, indent=2))
open('runtime.txt', 'w').write('python-3.6.7')
open("runtime.txt", "w").write("python-3.6.7")
if branch:
install = ['https://github.com/simonw/datasette/archive/{branch}.zip'.format(
branch=branch
)] + list(install)
install = [
"https://github.com/simonw/datasette/archive/{branch}.zip".format(
branch=branch
)
] + list(install)
else:
install = ['datasette'] + list(install)
install = ["datasette"] + list(install)
open('requirements.txt', 'w').write('\n'.join(install))
os.mkdir('bin')
open('bin/post_compile', 'w').write('datasette inspect --inspect-file inspect-data.json')
open("requirements.txt", "w").write("\n".join(install))
os.mkdir("bin")
open("bin/post_compile", "w").write(
"datasette inspect --inspect-file inspect-data.json"
)
extras = []
if template_dir:
link_or_copy_directory(
os.path.join(saved_cwd, template_dir),
os.path.join(tmp.name, 'templates')
os.path.join(tmp.name, "templates"),
)
extras.extend(['--template-dir', 'templates/'])
extras.extend(["--template-dir", "templates/"])
if plugins_dir:
link_or_copy_directory(
os.path.join(saved_cwd, plugins_dir),
os.path.join(tmp.name, 'plugins')
os.path.join(saved_cwd, plugins_dir), os.path.join(tmp.name, "plugins")
)
extras.extend(['--plugins-dir', 'plugins/'])
extras.extend(["--plugins-dir", "plugins/"])
if version_note:
extras.extend(['--version-note', version_note])
extras.extend(["--version-note", version_note])
if metadata_content:
extras.extend(['--metadata', 'metadata.json'])
extras.extend(["--metadata", "metadata.json"])
if extra_options:
extras.extend(extra_options.split())
for mount_point, path in static:
link_or_copy_directory(
os.path.join(saved_cwd, path),
os.path.join(tmp.name, mount_point)
os.path.join(saved_cwd, path), os.path.join(tmp.name, mount_point)
)
extras.extend(['--static', '{}:{}'.format(mount_point, mount_point)])
extras.extend(["--static", "{}:{}".format(mount_point, mount_point)])
quoted_files = " ".join(map(shlex.quote, file_names))
procfile_cmd = 'web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}'.format(
quoted_files=quoted_files,
extras=' '.join(extras),
procfile_cmd = "web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}".format(
quoted_files=quoted_files, extras=" ".join(extras)
)
open('Procfile', 'w').write(procfile_cmd)
open("Procfile", "w").write(procfile_cmd)
for path, filename in zip(file_paths, file_names):
link_or_copy(path, os.path.join(tmp.name, filename))
@ -484,9 +485,7 @@ def detect_primary_keys(conn, table):
" Figure out primary keys for a table. "
table_info_rows = [
row
for row in conn.execute(
'PRAGMA table_info("{}")'.format(table)
).fetchall()
for row in conn.execute('PRAGMA table_info("{}")'.format(table)).fetchall()
if row[-1]
]
table_info_rows.sort(key=lambda row: row[-1])
@ -494,33 +493,26 @@ def detect_primary_keys(conn, table):
def get_outbound_foreign_keys(conn, table):
infos = conn.execute(
'PRAGMA foreign_key_list([{}])'.format(table)
).fetchall()
infos = conn.execute("PRAGMA foreign_key_list([{}])".format(table)).fetchall()
fks = []
for info in infos:
if info is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = info
fks.append({
'other_table': table_name,
'column': from_,
'other_column': to_
})
fks.append(
{"other_table": table_name, "column": from_, "other_column": to_}
)
return fks
def get_all_foreign_keys(conn):
tables = [r[0] for r in conn.execute('select name from sqlite_master where type="table"')]
tables = [
r[0] for r in conn.execute('select name from sqlite_master where type="table"')
]
table_to_foreign_keys = {}
for table in tables:
table_to_foreign_keys[table] = {
'incoming': [],
'outgoing': [],
}
table_to_foreign_keys[table] = {"incoming": [], "outgoing": []}
for table in tables:
infos = conn.execute(
'PRAGMA foreign_key_list([{}])'.format(table)
).fetchall()
infos = conn.execute("PRAGMA foreign_key_list([{}])".format(table)).fetchall()
for info in infos:
if info is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = info
@ -528,22 +520,20 @@ def get_all_foreign_keys(conn):
# Weird edge case where something refers to a table that does
# not actually exist
continue
table_to_foreign_keys[table_name]['incoming'].append({
'other_table': table,
'column': to_,
'other_column': from_
})
table_to_foreign_keys[table]['outgoing'].append({
'other_table': table_name,
'column': from_,
'other_column': to_
})
table_to_foreign_keys[table_name]["incoming"].append(
{"other_table": table, "column": to_, "other_column": from_}
)
table_to_foreign_keys[table]["outgoing"].append(
{"other_table": table_name, "column": from_, "other_column": to_}
)
return table_to_foreign_keys
def detect_spatialite(conn):
rows = conn.execute('select 1 from sqlite_master where tbl_name = "geometry_columns"').fetchall()
rows = conn.execute(
'select 1 from sqlite_master where tbl_name = "geometry_columns"'
).fetchall()
return len(rows) > 0
@ -557,7 +547,7 @@ def detect_fts(conn, table):
def detect_fts_sql(table):
return r'''
return r"""
select name from sqlite_master
where rootpage = 0
and (
@ -567,7 +557,9 @@ def detect_fts_sql(table):
and sql like '%VIRTUAL TABLE%USING FTS%'
)
)
'''.format(table=table)
""".format(
table=table
)
def detect_json1(conn=None):
@ -589,51 +581,53 @@ def table_columns(conn, table):
]
filter_column_re = re.compile(r'^_filter_column_\d+$')
filter_column_re = re.compile(r"^_filter_column_\d+$")
def filters_should_redirect(special_args):
redirect_params = []
# Handle _filter_column=foo&_filter_op=exact&_filter_value=...
filter_column = special_args.get('_filter_column')
filter_op = special_args.get('_filter_op') or ''
filter_value = special_args.get('_filter_value') or ''
if '__' in filter_op:
filter_op, filter_value = filter_op.split('__', 1)
filter_column = special_args.get("_filter_column")
filter_op = special_args.get("_filter_op") or ""
filter_value = special_args.get("_filter_value") or ""
if "__" in filter_op:
filter_op, filter_value = filter_op.split("__", 1)
if filter_column:
redirect_params.append(
('{}__{}'.format(filter_column, filter_op), filter_value)
("{}__{}".format(filter_column, filter_op), filter_value)
)
for key in ('_filter_column', '_filter_op', '_filter_value'):
for key in ("_filter_column", "_filter_op", "_filter_value"):
if key in special_args:
redirect_params.append((key, None))
# Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello
column_keys = [k for k in special_args if filter_column_re.match(k)]
for column_key in column_keys:
number = column_key.split('_')[-1]
number = column_key.split("_")[-1]
column = special_args[column_key]
op = special_args.get('_filter_op_{}'.format(number)) or 'exact'
value = special_args.get('_filter_value_{}'.format(number)) or ''
if '__' in op:
op, value = op.split('__', 1)
op = special_args.get("_filter_op_{}".format(number)) or "exact"
value = special_args.get("_filter_value_{}".format(number)) or ""
if "__" in op:
op, value = op.split("__", 1)
if column:
redirect_params.append(('{}__{}'.format(column, op), value))
redirect_params.extend([
('_filter_column_{}'.format(number), None),
('_filter_op_{}'.format(number), None),
('_filter_value_{}'.format(number), None),
])
redirect_params.append(("{}__{}".format(column, op), value))
redirect_params.extend(
[
("_filter_column_{}".format(number), None),
("_filter_op_{}".format(number), None),
("_filter_value_{}".format(number), None),
]
)
return redirect_params
whitespace_re = re.compile(r'\s')
whitespace_re = re.compile(r"\s")
def is_url(value):
"Must start with http:// or https:// and contain JUST a URL"
if not isinstance(value, str):
return False
if not value.startswith('http://') and not value.startswith('https://'):
if not value.startswith("http://") and not value.startswith("https://"):
return False
# Any whitespace at all is invalid
if whitespace_re.search(value):
@ -641,8 +635,8 @@ def is_url(value):
return True
css_class_re = re.compile(r'^[a-zA-Z]+[_a-zA-Z0-9-]*$')
css_invalid_chars_re = re.compile(r'[^a-zA-Z0-9_\-]')
css_class_re = re.compile(r"^[a-zA-Z]+[_a-zA-Z0-9-]*$")
css_invalid_chars_re = re.compile(r"[^a-zA-Z0-9_\-]")
def to_css_class(s):
@ -656,16 +650,16 @@ def to_css_class(s):
"""
if css_class_re.match(s):
return s
md5_suffix = hashlib.md5(s.encode('utf8')).hexdigest()[:6]
md5_suffix = hashlib.md5(s.encode("utf8")).hexdigest()[:6]
# Strip leading _, -
s = s.lstrip('_').lstrip('-')
s = s.lstrip("_").lstrip("-")
# Replace any whitespace with hyphens
s = '-'.join(s.split())
s = "-".join(s.split())
# Remove any remaining invalid characters
s = css_invalid_chars_re.sub('', s)
s = css_invalid_chars_re.sub("", s)
# Attach the md5 suffix
bits = [b for b in (s, md5_suffix) if b]
return '-'.join(bits)
return "-".join(bits)
def link_or_copy(src, dst):
@ -689,8 +683,8 @@ def module_from_path(path, name):
# Adapted from http://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html
mod = imp.new_module(name)
mod.__file__ = path
with open(path, 'r') as file:
code = compile(file.read(), path, 'exec', dont_inherit=True)
with open(path, "r") as file:
code = compile(file.read(), path, "exec", dont_inherit=True)
exec(code, mod.__dict__)
return mod
@ -702,37 +696,39 @@ def get_plugins(pm):
static_path = None
templates_path = None
try:
if pkg_resources.resource_isdir(plugin.__name__, 'static'):
static_path = pkg_resources.resource_filename(plugin.__name__, 'static')
if pkg_resources.resource_isdir(plugin.__name__, 'templates'):
templates_path = pkg_resources.resource_filename(plugin.__name__, 'templates')
if pkg_resources.resource_isdir(plugin.__name__, "static"):
static_path = pkg_resources.resource_filename(plugin.__name__, "static")
if pkg_resources.resource_isdir(plugin.__name__, "templates"):
templates_path = pkg_resources.resource_filename(
plugin.__name__, "templates"
)
except (KeyError, ImportError):
# Caused by --plugins_dir= plugins - KeyError/ImportError thrown in Py3.5
pass
plugin_info = {
'name': plugin.__name__,
'static_path': static_path,
'templates_path': templates_path,
"name": plugin.__name__,
"static_path": static_path,
"templates_path": templates_path,
}
distinfo = plugin_to_distinfo.get(plugin)
if distinfo:
plugin_info['version'] = distinfo.version
plugin_info["version"] = distinfo.version
plugins.append(plugin_info)
return plugins
async def resolve_table_and_format(table_and_format, table_exists, allowed_formats=[]):
if '.' in table_and_format:
if "." in table_and_format:
# Check if a table exists with this exact name
it_exists = await table_exists(table_and_format)
if it_exists:
return table_and_format, None
# Check if table ends with a known format
formats = list(allowed_formats) + ['csv', 'jsono']
formats = list(allowed_formats) + ["csv", "jsono"]
for _format in formats:
if table_and_format.endswith(".{}".format(_format)):
table = table_and_format[:-(len(_format) + 1)]
table = table_and_format[: -(len(_format) + 1)]
return table, _format
return table_and_format, None
@ -747,9 +743,7 @@ def path_with_format(request, format, extra_qs=None):
if qs:
extra = urllib.parse.urlencode(sorted(qs.items()))
if request.query_string:
path = "{}?{}&{}".format(
path, request.query_string, extra
)
path = "{}?{}&{}".format(path, request.query_string, extra)
else:
path = "{}?{}".format(path, extra)
elif request.query_string:
@ -777,9 +771,9 @@ class CustomRow(OrderedDict):
def value_as_boolean(value):
if value.lower() not in ('on', 'off', 'true', 'false', '1', '0'):
if value.lower() not in ("on", "off", "true", "false", "1", "0"):
raise ValueAsBooleanError
return value.lower() in ('on', 'true', '1')
return value.lower() in ("on", "true", "1")
class ValueAsBooleanError(ValueError):
@ -799,9 +793,9 @@ class LimitedWriter:
def write(self, bytes):
self.bytes_count += len(bytes)
if self.limit_bytes and (self.bytes_count > self.limit_bytes):
raise WriteLimitExceeded("CSV contains more than {} bytes".format(
self.limit_bytes
))
raise WriteLimitExceeded(
"CSV contains more than {} bytes".format(self.limit_bytes)
)
self.writer.write(bytes)
@ -810,10 +804,7 @@ _infinities = {float("inf"), float("-inf")}
def remove_infinites(row):
if any((c in _infinities) if isinstance(c, float) else 0 for c in row):
return [
None if (isinstance(c, float) and c in _infinities) else c
for c in row
]
return [None if (isinstance(c, float) and c in _infinities) else c for c in row]
return row
@ -824,7 +815,8 @@ class StaticMount(click.ParamType):
if ":" not in value:
self.fail(
'"{}" should be of format mountpoint:directory'.format(value),
param, ctx
param,
ctx,
)
path, dirpath = value.split(":")
if not os.path.exists(dirpath) or not os.path.isdir(dirpath):

Wyświetl plik

@ -1,6 +1,6 @@
from ._version import get_versions
__version__ = get_versions()['version']
__version__ = get_versions()["version"]
del get_versions
__version_info__ = tuple(__version__.split("."))

Wyświetl plik

@ -33,8 +33,15 @@ HASH_LENGTH = 7
class DatasetteError(Exception):
def __init__(self, message, title=None, error_dict=None, status=500, template=None, messagge_is_html=False):
def __init__(
self,
message,
title=None,
error_dict=None,
status=500,
template=None,
messagge_is_html=False,
):
self.message = message
self.title = title
self.error_dict = error_dict or {}
@ -43,18 +50,19 @@ class DatasetteError(Exception):
class RenderMixin(HTTPMethodView):
def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins:
seen_urls = set()
for url_or_dict in itertools.chain(
itertools.chain.from_iterable(getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds
)),
(self.ds.metadata(key) or [])
itertools.chain.from_iterable(
getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds,
)
),
(self.ds.metadata(key) or []),
):
if isinstance(url_or_dict, dict):
url = url_or_dict["url"]
@ -73,14 +81,12 @@ class RenderMixin(HTTPMethodView):
def database_url(self, database):
db = self.ds.databases[database]
if self.ds.config("hash_urls") and db.hash:
return "/{}-{}".format(
database, db.hash[:HASH_LENGTH]
)
return "/{}-{}".format(database, db.hash[:HASH_LENGTH])
else:
return "/{}".format(database)
def database_color(self, database):
return 'ff0000'
return "ff0000"
def render(self, templates, **context):
template = self.ds.jinja_env.select_template(templates)
@ -95,7 +101,7 @@ class RenderMixin(HTTPMethodView):
database=context.get("database"),
table=context.get("table"),
view_name=self.name,
datasette=self.ds
datasette=self.ds,
):
body_scripts.append(jinja2.Markup(script))
return response.html(
@ -116,14 +122,14 @@ class RenderMixin(HTTPMethodView):
"format_bytes": format_bytes,
"database_url": self.database_url,
"database_color": self.database_color,
}
},
}
)
)
class BaseView(RenderMixin):
name = ''
name = ""
re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
def __init__(self, datasette):
@ -171,32 +177,30 @@ class BaseView(RenderMixin):
expected = "000"
if db.hash is not None:
expected = db.hash[:HASH_LENGTH]
correct_hash_provided = (expected == hash)
correct_hash_provided = expected == hash
if not correct_hash_provided:
if "table_and_format" in kwargs:
async def async_table_exists(t):
return await self.ds.table_exists(name, t)
table, _format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"]
),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
allowed_formats=self.ds.renderers.keys(),
)
kwargs["table"] = table
if _format:
kwargs["as_format"] = ".{}".format(_format)
elif "table" in kwargs:
kwargs["table"] = urllib.parse.unquote_plus(
kwargs["table"]
)
kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
should_redirect = "/{}-{}".format(name, expected)
if "table" in kwargs:
should_redirect += "/" + urllib.parse.quote_plus(
kwargs["table"]
)
should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
if "pk_path" in kwargs:
should_redirect += "/" + kwargs["pk_path"]
if "as_format" in kwargs:
@ -219,7 +223,9 @@ class BaseView(RenderMixin):
if should_redirect:
return self.redirect(request, should_redirect, remove_args={"_hash"})
return await self.view_get(request, database, hash, correct_hash_provided, **kwargs)
return await self.view_get(
request, database, hash, correct_hash_provided, **kwargs
)
async def as_csv(self, request, database, hash, **kwargs):
stream = request.args.get("_stream")
@ -228,9 +234,7 @@ class BaseView(RenderMixin):
if not self.ds.config("allow_csv_stream"):
raise DatasetteError("CSV streaming is disabled", status=400)
if request.args.get("_next"):
raise DatasetteError(
"_next not allowed for CSV streaming", status=400
)
raise DatasetteError("_next not allowed for CSV streaming", status=400)
kwargs["_size"] = "max"
# Fetch the first page
try:
@ -271,9 +275,7 @@ class BaseView(RenderMixin):
if next:
kwargs["_next"] = next
if not first:
data, _, _ = await self.data(
request, database, hash, **kwargs
)
data, _, _ = await self.data(request, database, hash, **kwargs)
if first:
writer.writerow(headings)
first = False
@ -293,7 +295,7 @@ class BaseView(RenderMixin):
new_row.append(cell)
writer.writerow(new_row)
except Exception as e:
print('caught this', e)
print("caught this", e)
r.write(str(e))
return
@ -304,15 +306,11 @@ class BaseView(RenderMixin):
if request.args.get("_dl", None):
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', database)
kwargs.get("table", database)
)
headers["Content-Disposition"] = disposition
return response.stream(
stream_fn,
headers=headers,
content_type=content_type
)
return response.stream(stream_fn, headers=headers, content_type=content_type)
async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL
@ -325,22 +323,20 @@ class BaseView(RenderMixin):
if not _format:
_format = (args.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in args:
async def async_table_exists(t):
return await self.ds.table_exists(database, t)
table, _ext_format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
args["table_and_format"]
),
table_and_format=urllib.parse.unquote_plus(args["table_and_format"]),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
allowed_formats=self.ds.renderers.keys(),
)
_format = _format or _ext_format
args["table"] = table
del args["table_and_format"]
elif "table" in args:
args["table"] = urllib.parse.unquote_plus(
args["table"]
)
args["table"] = urllib.parse.unquote_plus(args["table"])
return _format, args
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
@ -351,7 +347,7 @@ class BaseView(RenderMixin):
if _format is None:
# HTML views default to expanding all foriegn key labels
kwargs['default_labels'] = True
kwargs["default_labels"] = True
extra_template_data = {}
start = time.time()
@ -367,11 +363,16 @@ class BaseView(RenderMixin):
else:
data, extra_template_data, templates = response_or_template_contexts
except InterruptedError:
raise DatasetteError("""
raise DatasetteError(
"""
SQL query took too long. The time limit is controlled by the
<a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
configuration option.
""", title="SQL Interrupted", status=400, messagge_is_html=True)
""",
title="SQL Interrupted",
status=400,
messagge_is_html=True,
)
except (sqlite3.OperationalError, InvalidSql) as e:
raise DatasetteError(str(e), title="Invalid SQL", status=400)
@ -408,14 +409,14 @@ class BaseView(RenderMixin):
raise NotFound("No data")
response_args = {
'content_type': result.get('content_type', 'text/plain'),
'status': result.get('status_code', 200)
"content_type": result.get("content_type", "text/plain"),
"status": result.get("status_code", 200),
}
if type(result.get('body')) == bytes:
response_args['body_bytes'] = result.get('body')
if type(result.get("body")) == bytes:
response_args["body_bytes"] = result.get("body")
else:
response_args['body'] = result.get('body')
response_args["body"] = result.get("body")
r = response.HTTPResponse(**response_args)
else:
@ -431,14 +432,12 @@ class BaseView(RenderMixin):
url_labels_extra = {"_labels": "on"}
renderers = {
key: path_with_format(request, key, {**url_labels_extra}) for key in self.ds.renderers.keys()
}
url_csv_args = {
"_size": "max",
**url_labels_extra
key: path_with_format(request, key, {**url_labels_extra})
for key in self.ds.renderers.keys()
}
url_csv_args = {"_size": "max", **url_labels_extra}
url_csv = path_with_format(request, "csv", url_csv_args)
url_csv_path = url_csv.split('?')[0]
url_csv_path = url_csv.split("?")[0]
context = {
**data,
**extras,
@ -450,10 +449,11 @@ class BaseView(RenderMixin):
(key, value)
for key, value in urllib.parse.parse_qsl(request.query_string)
if key not in ("_labels", "_facet", "_size")
] + [("_size", "max")],
]
+ [("_size", "max")],
"datasette_version": __version__,
"config": self.ds.config_dict(),
}
},
}
if "metadata" not in context:
context["metadata"] = self.ds.metadata
@ -474,9 +474,9 @@ class BaseView(RenderMixin):
if self.ds.cache_headers and response.status == 200:
ttl = int(ttl)
if ttl == 0:
ttl_header = 'no-cache'
ttl_header = "no-cache"
else:
ttl_header = 'max-age={}'.format(ttl)
ttl_header = "max-age={}".format(ttl)
response.headers["Cache-Control"] = ttl_header
response.headers["Referrer-Policy"] = "no-referrer"
if self.ds.cors:
@ -484,8 +484,15 @@ class BaseView(RenderMixin):
return response
async def custom_sql(
self, request, database, hash, sql, editable=True, canned_query=None,
metadata=None, _size=None
self,
request,
database,
hash,
sql,
editable=True,
canned_query=None,
metadata=None,
_size=None,
):
params = request.raw_args
if "sql" in params:
@ -565,10 +572,14 @@ class BaseView(RenderMixin):
"hide_sql": "_hide_sql" in params,
}
return {
"database": database,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
}, extra_template, templates
return (
{
"database": database,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
},
extra_template,
templates,
)

Wyświetl plik

@ -15,7 +15,7 @@ from .base import HASH_LENGTH, RenderMixin
class IndexView(RenderMixin):
name = 'index'
name = "index"
def __init__(self, datasette):
self.ds = datasette
@ -43,23 +43,25 @@ class IndexView(RenderMixin):
}
hidden_tables = [t for t in tables.values() if t["hidden"]]
databases.append({
"name": name,
"hash": db.hash,
"color": db.hash[:6] if db.hash else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
"path": self.database_url(name),
"tables_truncated": sorted(
tables.values(), key=lambda t: t["count"] or 0, reverse=True
)[
:5
],
"tables_count": len(tables),
"tables_more": len(tables) > 5,
"table_rows_sum": sum((t["count"] or 0) for t in tables.values()),
"hidden_table_rows_sum": sum(t["count"] for t in hidden_tables),
"hidden_tables_count": len(hidden_tables),
"views_count": len(views),
})
databases.append(
{
"name": name,
"hash": db.hash,
"color": db.hash[:6]
if db.hash
else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
"path": self.database_url(name),
"tables_truncated": sorted(
tables.values(), key=lambda t: t["count"] or 0, reverse=True
)[:5],
"tables_count": len(tables),
"tables_more": len(tables) > 5,
"table_rows_sum": sum((t["count"] or 0) for t in tables.values()),
"hidden_table_rows_sum": sum(t["count"] for t in hidden_tables),
"hidden_tables_count": len(hidden_tables),
"views_count": len(views),
}
)
if as_format:
headers = {}
if self.ds.cors:

Wyświetl plik

@ -18,14 +18,8 @@ class JsonDataView(RenderMixin):
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
return response.HTTPResponse(
json.dumps(data),
content_type="application/json",
headers=headers
json.dumps(data), content_type="application/json", headers=headers
)
else:
return self.render(
["show_json.html"],
filename=self.filename,
data=data
)
return self.render(["show_json.html"], filename=self.filename, data=data)

Wyświetl plik

@ -31,12 +31,13 @@ from datasette.utils import (
from datasette.filters import Filters
from .base import BaseView, DatasetteError, ureg
LINK_WITH_LABEL = '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
LINK_WITH_LABEL = (
'<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
)
LINK_WITH_VALUE = '<a href="/{database}/{table}/{link_id}">{id}</a>'
class RowTableShared(BaseView):
async def sortable_columns_for_table(self, database, table, use_rowid):
table_metadata = self.ds.table_metadata(database, table)
if "sortable_columns" in table_metadata:
@ -51,18 +52,14 @@ class RowTableShared(BaseView):
# Returns list of (fk_dict, label_column-or-None) pairs for that table
expandables = []
for fk in await self.ds.foreign_keys_for_table(database, table):
label_column = await self.ds.label_column_for_table(database, fk["other_table"])
label_column = await self.ds.label_column_for_table(
database, fk["other_table"]
)
expandables.append((fk, label_column))
return expandables
async def display_columns_and_rows(
self,
database,
table,
description,
rows,
link_column=False,
truncate_cells=0,
self, database, table, description, rows, link_column=False, truncate_cells=0
):
"Returns columns, rows for specified table - including fancy foreign key treatment"
table_metadata = self.ds.table_metadata(database, table)
@ -121,8 +118,10 @@ class RowTableShared(BaseView):
if plugin_display_value is not None:
display_value = plugin_display_value
elif isinstance(value, bytes):
display_value = jinja2.Markup("&lt;Binary&nbsp;data:&nbsp;{}&nbsp;byte{}&gt;".format(
len(value), "" if len(value) == 1 else "s")
display_value = jinja2.Markup(
"&lt;Binary&nbsp;data:&nbsp;{}&nbsp;byte{}&gt;".format(
len(value), "" if len(value) == 1 else "s"
)
)
elif isinstance(value, dict):
# It's an expanded foreign key - display link to other row
@ -133,13 +132,15 @@ class RowTableShared(BaseView):
link_template = (
LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
)
display_value = jinja2.Markup(link_template.format(
database=database,
table=urllib.parse.quote_plus(other_table),
link_id=urllib.parse.quote_plus(str(value)),
id=str(jinja2.escape(value)),
label=str(jinja2.escape(label)),
))
display_value = jinja2.Markup(
link_template.format(
database=database,
table=urllib.parse.quote_plus(other_table),
link_id=urllib.parse.quote_plus(str(value)),
id=str(jinja2.escape(value)),
label=str(jinja2.escape(label)),
)
)
elif value in ("", None):
display_value = jinja2.Markup("&nbsp;")
elif is_url(str(value).strip()):
@ -180,9 +181,18 @@ class RowTableShared(BaseView):
class TableView(RowTableShared):
name = 'table'
name = "table"
async def data(self, request, database, hash, table, default_labels=False, _next=None, _size=None):
async def data(
self,
request,
database,
hash,
table,
default_labels=False,
_next=None,
_size=None,
):
canned_query = self.ds.get_canned_query(database, table)
if canned_query is not None:
return await self.custom_sql(
@ -271,12 +281,13 @@ class TableView(RowTableShared):
raise DatasetteError("_where= is not allowed", status=400)
else:
where_clauses.extend(request.args["_where"])
extra_wheres_for_ui = [{
"text": text,
"remove_url": path_with_removed_args(
request, {"_where": text}
)
} for text in request.args["_where"]]
extra_wheres_for_ui = [
{
"text": text,
"remove_url": path_with_removed_args(request, {"_where": text}),
}
for text in request.args["_where"]
]
# _search support:
fts_table = special_args.get("_fts_table")
@ -296,8 +307,7 @@ class TableView(RowTableShared):
search = search_args["_search"]
where_clauses.append(
"{fts_pk} in (select rowid from {fts_table} where {fts_table} match :search)".format(
fts_table=escape_sqlite(fts_table),
fts_pk=escape_sqlite(fts_pk)
fts_table=escape_sqlite(fts_table), fts_pk=escape_sqlite(fts_pk)
)
)
search_descriptions.append('search matches "{}"'.format(search))
@ -306,14 +316,16 @@ class TableView(RowTableShared):
# More complex: search against specific columns
for i, (key, search_text) in enumerate(search_args.items()):
search_col = key.split("_search_", 1)[1]
if search_col not in await self.ds.table_columns(database, fts_table):
if search_col not in await self.ds.table_columns(
database, fts_table
):
raise DatasetteError("Cannot search by that column", status=400)
where_clauses.append(
"rowid in (select rowid from {fts_table} where {search_col} match :search_{i})".format(
fts_table=escape_sqlite(fts_table),
search_col=escape_sqlite(search_col),
i=i
i=i,
)
)
search_descriptions.append(
@ -325,7 +337,9 @@ class TableView(RowTableShared):
sortable_columns = set()
sortable_columns = await self.sortable_columns_for_table(database, table, use_rowid)
sortable_columns = await self.sortable_columns_for_table(
database, table, use_rowid
)
# Allow for custom sort order
sort = special_args.get("_sort")
@ -346,9 +360,9 @@ class TableView(RowTableShared):
from_sql = "from {table_name} {where}".format(
table_name=escape_sqlite(table),
where=(
"where {} ".format(" and ".join(where_clauses))
) if where_clauses else "",
where=("where {} ".format(" and ".join(where_clauses)))
if where_clauses
else "",
)
# Copy of params so we can mutate them later:
from_sql_params = dict(**params)
@ -410,7 +424,9 @@ class TableView(RowTableShared):
column=escape_sqlite(sort or sort_desc),
op=">" if sort else "<",
p=len(params),
extra_desc_only="" if sort else " or {column2} is null".format(
extra_desc_only=""
if sort
else " or {column2} is null".format(
column2=escape_sqlite(sort or sort_desc)
),
next_clauses=" and ".join(next_by_pk_clauses),
@ -470,9 +486,7 @@ class TableView(RowTableShared):
order_by=order_by,
)
sql = "{sql_no_limit} limit {limit}{offset}".format(
sql_no_limit=sql_no_limit.rstrip(),
limit=page_size + 1,
offset=offset,
sql_no_limit=sql_no_limit.rstrip(), limit=page_size + 1, offset=offset
)
if request.raw_args.get("_timelimit"):
@ -486,15 +500,17 @@ class TableView(RowTableShared):
filtered_table_rows_count = None
if count_sql:
try:
count_rows = list(await self.ds.execute(
database, count_sql, from_sql_params
))
count_rows = list(
await self.ds.execute(database, count_sql, from_sql_params)
)
filtered_table_rows_count = count_rows[0][0]
except InterruptedError:
pass
# facets support
if not self.ds.config("allow_facet") and any(arg.startswith("_facet") for arg in request.args):
if not self.ds.config("allow_facet") and any(
arg.startswith("_facet") for arg in request.args
):
raise DatasetteError("_facet= is not allowed", status=400)
# pylint: disable=no-member
@ -505,19 +521,23 @@ class TableView(RowTableShared):
facets_timed_out = []
facet_instances = []
for klass in facet_classes:
facet_instances.append(klass(
self.ds,
request,
database,
sql=sql_no_limit,
params=params,
table=table,
metadata=table_metadata,
row_count=filtered_table_rows_count,
))
facet_instances.append(
klass(
self.ds,
request,
database,
sql=sql_no_limit,
params=params,
table=table,
metadata=table_metadata,
row_count=filtered_table_rows_count,
)
)
for facet in facet_instances:
instance_facet_results, instance_facets_timed_out = await facet.facet_results()
instance_facet_results, instance_facets_timed_out = (
await facet.facet_results()
)
facet_results.update(instance_facet_results)
facets_timed_out.extend(instance_facets_timed_out)
@ -542,9 +562,7 @@ class TableView(RowTableShared):
columns_to_expand = request.args["_label"]
if columns_to_expand is None and all_labels:
# expand all columns with foreign keys
columns_to_expand = [
fk["column"] for fk, _ in expandable_columns
]
columns_to_expand = [fk["column"] for fk, _ in expandable_columns]
if columns_to_expand:
expanded_labels = {}
@ -557,9 +575,9 @@ class TableView(RowTableShared):
column_index = columns.index(column)
values = [row[column_index] for row in rows]
# Expand them
expanded_labels.update(await self.ds.expand_foreign_keys(
database, table, column, values
))
expanded_labels.update(
await self.ds.expand_foreign_keys(database, table, column, values)
)
if expanded_labels:
# Rewrite the rows
new_rows = []
@ -569,8 +587,8 @@ class TableView(RowTableShared):
value = row[column]
if (column, value) in expanded_labels:
new_row[column] = {
'value': value,
'label': expanded_labels[(column, value)]
"value": value,
"label": expanded_labels[(column, value)],
}
else:
new_row[column] = value
@ -608,7 +626,11 @@ class TableView(RowTableShared):
# Detect suggested facets
suggested_facets = []
if self.ds.config("suggest_facets") and self.ds.config("allow_facet") and not _next:
if (
self.ds.config("suggest_facets")
and self.ds.config("allow_facet")
and not _next
):
for facet in facet_instances:
# TODO: ensure facet is not suggested if it is already active
# used to use 'if facet_column in facets' for this
@ -634,10 +656,11 @@ class TableView(RowTableShared):
link_column=not is_view,
truncate_cells=self.ds.config("truncate_cells_html"),
)
metadata = (self.ds.metadata("databases") or {}).get(database, {}).get(
"tables", {}
).get(
table, {}
metadata = (
(self.ds.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {})
)
self.ds.update_with_inherited_metadata(metadata)
form_hidden_args = []
@ -656,7 +679,7 @@ class TableView(RowTableShared):
"sorted_facet_results": sorted(
facet_results.values(),
key=lambda f: (len(f["results"]), f["name"]),
reverse=True
reverse=True,
),
"extra_wheres_for_ui": extra_wheres_for_ui,
"form_hidden_args": form_hidden_args,
@ -682,32 +705,36 @@ class TableView(RowTableShared):
"table_definition": await self.ds.get_table_definition(database, table),
}
return {
"database": database,
"table": table,
"is_view": is_view,
"human_description_en": human_description_en,
"rows": rows[:page_size],
"truncated": results.truncated,
"filtered_table_rows_count": filtered_table_rows_count,
"expanded_columns": expanded_columns,
"expandable_columns": expandable_columns,
"columns": columns,
"primary_keys": pks,
"units": units,
"query": {"sql": sql, "params": params},
"facet_results": facet_results,
"suggested_facets": suggested_facets,
"next": next_value and str(next_value) or None,
"next_url": next_url,
}, extra_template, (
"table-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"table.html",
return (
{
"database": database,
"table": table,
"is_view": is_view,
"human_description_en": human_description_en,
"rows": rows[:page_size],
"truncated": results.truncated,
"filtered_table_rows_count": filtered_table_rows_count,
"expanded_columns": expanded_columns,
"expandable_columns": expandable_columns,
"columns": columns,
"primary_keys": pks,
"units": units,
"query": {"sql": sql, "params": params},
"facet_results": facet_results,
"suggested_facets": suggested_facets,
"next": next_value and str(next_value) or None,
"next_url": next_url,
},
extra_template,
(
"table-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"table.html",
),
)
class RowView(RowTableShared):
name = 'row'
name = "row"
async def data(self, request, database, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
@ -720,15 +747,13 @@ class RowView(RowTableShared):
select = "rowid, *"
pks = ["rowid"]
wheres = ['"{}"=:p{}'.format(pk, i) for i, pk in enumerate(pks)]
sql = 'select {} from {} where {}'.format(
sql = "select {} from {} where {}".format(
select, escape_sqlite(table), " AND ".join(wheres)
)
params = {}
for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value
results = await self.ds.execute(
database, sql, params, truncate=True
)
results = await self.ds.execute(database, sql, params, truncate=True)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
@ -760,13 +785,10 @@ class RowView(RowTableShared):
),
"_rows_and_columns.html",
],
"metadata": (
self.ds.metadata("databases") or {}
).get(database, {}).get(
"tables", {}
).get(
table, {}
),
"metadata": (self.ds.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {}),
}
data = {
@ -784,8 +806,13 @@ class RowView(RowTableShared):
database, table, pk_values
)
return data, template_data, (
"row-{}-{}.html".format(to_css_class(database), to_css_class(table)), "row.html"
return (
data,
template_data,
(
"row-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"row.html",
),
)
async def foreign_key_tables(self, database, table, pk_values):
@ -801,7 +828,7 @@ class RowView(RowTableShared):
sql = "select " + ", ".join(
[
'(select count(*) from {table} where {column}=:id)'.format(
"(select count(*) from {table} where {column}=:id)".format(
table=escape_sqlite(fk["other_table"]),
column=escape_sqlite(fk["other_column"]),
)
@ -822,8 +849,8 @@ class RowView(RowTableShared):
)
foreign_key_tables = []
for fk in foreign_keys:
count = foreign_table_counts.get(
(fk["other_table"], fk["other_column"])
) or 0
count = (
foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0
)
foreign_key_tables.append({**fk, **{"count": count}})
return foreign_key_tables

Wyświetl plik

@ -1,72 +1,78 @@
from setuptools import setup, find_packages
import os
import sys
import versioneer
def get_long_description():
with open(os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'README.md'
), encoding='utf8') as fp:
with open(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
encoding="utf8",
) as fp:
return fp.read()
def get_version():
path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'datasette', 'version.py'
os.path.dirname(os.path.abspath(__file__)), "datasette", "version.py"
)
g = {}
exec(open(path).read(), g)
return g['__version__']
return g["__version__"]
# Only install black on Python 3.6 or higher
maybe_black = []
if sys.version_info > (3, 6):
maybe_black = ["black"]
setup(
name='datasette',
name="datasette",
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(),
description='An instant JSON API for your SQLite databases',
description="An instant JSON API for your SQLite databases",
long_description=get_long_description(),
long_description_content_type='text/markdown',
author='Simon Willison',
license='Apache License, Version 2.0',
url='https://github.com/simonw/datasette',
long_description_content_type="text/markdown",
author="Simon Willison",
license="Apache License, Version 2.0",
url="https://github.com/simonw/datasette",
packages=find_packages(),
package_data={'datasette': ['templates/*.html']},
package_data={"datasette": ["templates/*.html"]},
include_package_data=True,
install_requires=[
'click>=6.7',
'click-default-group==1.2',
'Sanic==0.7.0',
'Jinja2==2.10.1',
'hupper==1.0',
'pint==0.8.1',
'pluggy>=0.7.1',
"click>=6.7",
"click-default-group==1.2",
"Sanic==0.7.0",
"Jinja2==2.10.1",
"hupper==1.0",
"pint==0.8.1",
"pluggy>=0.7.1",
],
entry_points='''
entry_points="""
[console_scripts]
datasette=datasette.cli:cli
''',
setup_requires=['pytest-runner'],
""",
setup_requires=["pytest-runner"],
extras_require={
'test': [
'pytest==4.0.2',
'pytest-asyncio==0.10.0',
'aiohttp==3.5.3',
'beautifulsoup4==4.6.1',
"test": [
"pytest==4.0.2",
"pytest-asyncio==0.10.0",
"aiohttp==3.5.3",
"beautifulsoup4==4.6.1",
]
+ maybe_black
},
tests_require=[
'datasette[test]',
],
tests_require=["datasette[test]"],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'Intended Audience :: End Users/Desktop',
'Topic :: Database',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.5',
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: End Users/Desktop",
"Topic :: Database",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.5",
],
)

Wyświetl plik

@ -8,3 +8,10 @@ def pytest_unconfigure(config):
import sys
del sys._called_from_test
def pytest_collection_modifyitems(items):
# Ensure test_black.py runs first before any asyncio code kicks in
test_black = [fn for fn in items if fn.name == "test_black"]
if test_black:
items.insert(0, items.pop(items.index(test_black[0])))

Wyświetl plik

@ -17,9 +17,7 @@ class TestClient:
def get(self, path, allow_redirects=True):
return self.sanic_test_client.get(
path,
allow_redirects=allow_redirects,
gather_request=False
path, allow_redirects=allow_redirects, gather_request=False
)
@ -79,39 +77,35 @@ def app_client_no_files():
client.ds = ds
yield client
@pytest.fixture(scope="session")
def app_client_with_memory():
yield from make_app_client(memory=True)
@pytest.fixture(scope="session")
def app_client_with_hash():
yield from make_app_client(config={
'hash_urls': True,
}, is_immutable=True)
yield from make_app_client(config={"hash_urls": True}, is_immutable=True)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_shorter_time_limit():
yield from make_app_client(20)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_returned_rows_matches_page_size():
yield from make_app_client(max_returned_rows=50)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_larger_cache_size():
yield from make_app_client(config={
'cache_size_kb': 2500,
})
yield from make_app_client(config={"cache_size_kb": 2500})
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_csv_max_mb_one():
yield from make_app_client(config={
'max_csv_mb': 1,
})
yield from make_app_client(config={"max_csv_mb": 1})
@pytest.fixture(scope="session")
@ -119,7 +113,7 @@ def app_client_with_dot():
yield from make_app_client(filename="fixtures.dot.db")
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_with_cors():
yield from make_app_client(cors=True)
@ -128,7 +122,7 @@ def generate_compound_rows(num):
for a, b, c in itertools.islice(
itertools.product(string.ascii_lowercase, repeat=3), num
):
yield a, b, c, '{}-{}-{}'.format(a, b, c)
yield a, b, c, "{}-{}-{}".format(a, b, c)
def generate_sortable_rows(num):
@ -137,107 +131,81 @@ def generate_sortable_rows(num):
itertools.product(string.ascii_lowercase, repeat=2), num
):
yield {
'pk1': a,
'pk2': b,
'content': '{}-{}'.format(a, b),
'sortable': rand.randint(-100, 100),
'sortable_with_nulls': rand.choice([
None, rand.random(), rand.random()
]),
'sortable_with_nulls_2': rand.choice([
None, rand.random(), rand.random()
]),
'text': rand.choice(['$null', '$blah']),
"pk1": a,
"pk2": b,
"content": "{}-{}".format(a, b),
"sortable": rand.randint(-100, 100),
"sortable_with_nulls": rand.choice([None, rand.random(), rand.random()]),
"sortable_with_nulls_2": rand.choice([None, rand.random(), rand.random()]),
"text": rand.choice(["$null", "$blah"]),
}
METADATA = {
'title': 'Datasette Fixtures',
'description': 'An example SQLite database demonstrating Datasette',
'license': 'Apache License 2.0',
'license_url': 'https://github.com/simonw/datasette/blob/master/LICENSE',
'source': 'tests/fixtures.py',
'source_url': 'https://github.com/simonw/datasette/blob/master/tests/fixtures.py',
'about': 'About Datasette',
'about_url': 'https://github.com/simonw/datasette',
"plugins": {
"name-of-plugin": {
"depth": "root"
}
},
'databases': {
'fixtures': {
'description': 'Test tables description',
"plugins": {
"name-of-plugin": {
"depth": "database"
}
},
'tables': {
'simple_primary_key': {
'description_html': 'Simple <em>primary</em> key',
'title': 'This <em>HTML</em> is escaped',
"title": "Datasette Fixtures",
"description": "An example SQLite database demonstrating Datasette",
"license": "Apache License 2.0",
"license_url": "https://github.com/simonw/datasette/blob/master/LICENSE",
"source": "tests/fixtures.py",
"source_url": "https://github.com/simonw/datasette/blob/master/tests/fixtures.py",
"about": "About Datasette",
"about_url": "https://github.com/simonw/datasette",
"plugins": {"name-of-plugin": {"depth": "root"}},
"databases": {
"fixtures": {
"description": "Test tables description",
"plugins": {"name-of-plugin": {"depth": "database"}},
"tables": {
"simple_primary_key": {
"description_html": "Simple <em>primary</em> key",
"title": "This <em>HTML</em> is escaped",
"plugins": {
"name-of-plugin": {
"depth": "table",
"special": "this-is-simple_primary_key"
"special": "this-is-simple_primary_key",
}
}
},
},
'sortable': {
'sortable_columns': [
'sortable',
'sortable_with_nulls',
'sortable_with_nulls_2',
'text',
"sortable": {
"sortable_columns": [
"sortable",
"sortable_with_nulls",
"sortable_with_nulls_2",
"text",
],
"plugins": {
"name-of-plugin": {
"depth": "table"
}
}
"plugins": {"name-of-plugin": {"depth": "table"}},
},
'no_primary_key': {
'sortable_columns': [],
'hidden': True,
"no_primary_key": {"sortable_columns": [], "hidden": True},
"units": {"units": {"distance": "m", "frequency": "Hz"}},
"primary_key_multiple_columns_explicit_label": {
"label_column": "content2"
},
'units': {
'units': {
'distance': 'm',
'frequency': 'Hz'
}
"simple_view": {"sortable_columns": ["content"]},
"searchable_view_configured_by_metadata": {
"fts_table": "searchable_fts",
"fts_pk": "pk",
},
'primary_key_multiple_columns_explicit_label': {
'label_column': 'content2',
},
'simple_view': {
'sortable_columns': ['content'],
},
'searchable_view_configured_by_metadata': {
'fts_table': 'searchable_fts',
'fts_pk': 'pk'
}
},
'queries': {
'pragma_cache_size': 'PRAGMA cache_size;',
'neighborhood_search': {
'sql': '''
"queries": {
"pragma_cache_size": "PRAGMA cache_size;",
"neighborhood_search": {
"sql": """
select neighborhood, facet_cities.name, state
from facetable
join facet_cities
on facetable.city_id = facet_cities.id
where neighborhood like '%' || :text || '%'
order by neighborhood;
''',
'title': 'Search neighborhoods',
'description_html': '<b>Demonstrating</b> simple like search',
""",
"title": "Search neighborhoods",
"description_html": "<b>Demonstrating</b> simple like search",
},
}
},
}
},
}
},
}
PLUGIN1 = '''
PLUGIN1 = """
from datasette import hookimpl
import base64
import pint
@ -304,9 +272,9 @@ def render_cell(value, column, table, database, datasette):
table=table,
)
})
'''
"""
PLUGIN2 = '''
PLUGIN2 = """
from datasette import hookimpl
import jinja2
import json
@ -349,9 +317,10 @@ def render_cell(value, database):
label=jinja2.escape(data["label"] or "") or "&nbsp;"
)
)
'''
"""
TABLES = '''
TABLES = (
"""
CREATE TABLE simple_primary_key (
id varchar(30) primary key,
content text
@ -581,26 +550,42 @@ CREATE VIEW searchable_view AS
CREATE VIEW searchable_view_configured_by_metadata AS
SELECT * from searchable;
''' + '\n'.join([
'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(i=i + 1)
for i in range(201)
]) + '\n'.join([
'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format(
a=a, b=b, c=c, content=content
) for a, b, c, content in generate_compound_rows(1001)
]) + '\n'.join([
'''INSERT INTO sortable VALUES (
"""
+ "\n".join(
[
'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(
i=i + 1
)
for i in range(201)
]
)
+ "\n".join(
[
'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format(
a=a, b=b, c=c, content=content
)
for a, b, c, content in generate_compound_rows(1001)
]
)
+ "\n".join(
[
"""INSERT INTO sortable VALUES (
"{pk1}", "{pk2}", "{content}", {sortable},
{sortable_with_nulls}, {sortable_with_nulls_2}, "{text}");
'''.format(
**row
).replace('None', 'null') for row in generate_sortable_rows(201)
])
TABLE_PARAMETERIZED_SQL = [(
"insert into binary_data (data) values (?);", [b'this is binary data']
)]
""".format(
**row
).replace(
"None", "null"
)
for row in generate_sortable_rows(201)
]
)
)
TABLE_PARAMETERIZED_SQL = [
("insert into binary_data (data) values (?);", [b"this is binary data"])
]
if __name__ == '__main__':
if __name__ == "__main__":
# Can be called with data.db OR data.db metadata.json
db_filename = sys.argv[-1]
metadata_filename = None
@ -615,9 +600,7 @@ if __name__ == '__main__':
conn.execute(sql, params)
print("Test tables written to {}".format(db_filename))
if metadata_filename:
open(metadata_filename, 'w').write(json.dumps(METADATA))
open(metadata_filename, "w").write(json.dumps(METADATA))
print("- metadata written to {}".format(metadata_filename))
else:
print("Usage: {} db_to_write.db [metadata_to_write.json]".format(
sys.argv[0]
))
print("Usage: {} db_to_write.db [metadata_to_write.json]".format(sys.argv[0]))

Plik diff jest za duży Load Diff

Wyświetl plik

@ -0,0 +1,20 @@
from click.testing import CliRunner
from pathlib import Path
import pytest
import sys
code_root = Path(__file__).parent.parent
@pytest.mark.skipif(
sys.version_info[:2] < (3, 6), reason="Black requires Python 3.6 or later"
)
def test_black():
# Do not import at top of module because Python 3.5 will not have it installed
import black
runner = CliRunner()
result = runner.invoke(
black.main, [str(code_root / "tests"), str(code_root / "datasette"), "--check"]
)
assert result.exit_code == 0, result.output

Wyświetl plik

@ -1,22 +1,26 @@
from .fixtures import ( # noqa
from .fixtures import ( # noqa
app_client,
app_client_csv_max_mb_one,
app_client_with_cors
app_client_with_cors,
)
EXPECTED_TABLE_CSV = '''id,content
EXPECTED_TABLE_CSV = """id,content
1,hello
2,world
3,
4,RENDER_CELL_DEMO
'''.replace('\n', '\r\n')
""".replace(
"\n", "\r\n"
)
EXPECTED_CUSTOM_CSV = '''content
EXPECTED_CUSTOM_CSV = """content
hello
world
'''.replace('\n', '\r\n')
""".replace(
"\n", "\r\n"
)
EXPECTED_TABLE_WITH_LABELS_CSV = '''
EXPECTED_TABLE_WITH_LABELS_CSV = """
pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags
1,1,1,CA,1,San Francisco,Mission,"[""tag1"", ""tag2""]"
2,1,1,CA,1,San Francisco,Dogpatch,"[""tag1"", ""tag3""]"
@ -33,45 +37,47 @@ pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags
13,1,1,MI,3,Detroit,Corktown,[]
14,1,1,MI,3,Detroit,Mexicantown,[]
15,2,0,MC,4,Memnonia,Arcadia Planitia,[]
'''.lstrip().replace('\n', '\r\n')
""".lstrip().replace(
"\n", "\r\n"
)
def test_table_csv(app_client):
response = app_client.get('/fixtures/simple_primary_key.csv')
response = app_client.get("/fixtures/simple_primary_key.csv")
assert response.status == 200
assert not response.headers.get("Access-Control-Allow-Origin")
assert 'text/plain; charset=utf-8' == response.headers['Content-Type']
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_TABLE_CSV == response.text
def test_table_csv_cors_headers(app_client_with_cors):
response = app_client_with_cors.get('/fixtures/simple_primary_key.csv')
response = app_client_with_cors.get("/fixtures/simple_primary_key.csv")
assert response.status == 200
assert "*" == response.headers["Access-Control-Allow-Origin"]
def test_table_csv_with_labels(app_client):
response = app_client.get('/fixtures/facetable.csv?_labels=1')
response = app_client.get("/fixtures/facetable.csv?_labels=1")
assert response.status == 200
assert 'text/plain; charset=utf-8' == response.headers['Content-Type']
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text
def test_custom_sql_csv(app_client):
response = app_client.get(
'/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2'
"/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
)
assert response.status == 200
assert 'text/plain; charset=utf-8' == response.headers['Content-Type']
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_CUSTOM_CSV == response.text
def test_table_csv_download(app_client):
response = app_client.get('/fixtures/simple_primary_key.csv?_dl=1')
response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1")
assert response.status == 200
assert 'text/csv; charset=utf-8' == response.headers['Content-Type']
assert "text/csv; charset=utf-8" == response.headers["Content-Type"]
expected_disposition = 'attachment; filename="simple_primary_key.csv"'
assert expected_disposition == response.headers['Content-Disposition']
assert expected_disposition == response.headers["Content-Disposition"]
def test_max_csv_mb(app_client_csv_max_mb_one):
@ -88,12 +94,8 @@ def test_max_csv_mb(app_client_csv_max_mb_one):
def test_table_csv_stream(app_client):
# Without _stream should return header + 100 rows:
response = app_client.get(
"/fixtures/compound_three_primary_keys.csv?_size=max"
)
response = app_client.get("/fixtures/compound_three_primary_keys.csv?_size=max")
assert 101 == len([b for b in response.body.split(b"\r\n") if b])
# With _stream=1 should return header + 1001 rows
response = app_client.get(
"/fixtures/compound_three_primary_keys.csv?_stream=1"
)
response = app_client.get("/fixtures/compound_three_primary_keys.csv?_stream=1")
assert 1002 == len([b for b in response.body.split(b"\r\n") if b])

Wyświetl plik

@ -9,13 +9,13 @@ from pathlib import Path
import pytest
import re
docs_path = Path(__file__).parent.parent / 'docs'
label_re = re.compile(r'\.\. _([^\s:]+):')
docs_path = Path(__file__).parent.parent / "docs"
label_re = re.compile(r"\.\. _([^\s:]+):")
def get_headings(filename, underline="-"):
content = (docs_path / filename).open().read()
heading_re = re.compile(r'(\w+)(\([^)]*\))?\n\{}+\n'.format(underline))
heading_re = re.compile(r"(\w+)(\([^)]*\))?\n\{}+\n".format(underline))
return set(h[0] for h in heading_re.findall(content))
@ -24,38 +24,37 @@ def get_labels(filename):
return set(label_re.findall(content))
@pytest.mark.parametrize('config', app.CONFIG_OPTIONS)
@pytest.mark.parametrize("config", app.CONFIG_OPTIONS)
def test_config_options_are_documented(config):
assert config.name in get_headings("config.rst")
@pytest.mark.parametrize("name,filename", (
("serve", "datasette-serve-help.txt"),
("package", "datasette-package-help.txt"),
("publish now", "datasette-publish-now-help.txt"),
("publish heroku", "datasette-publish-heroku-help.txt"),
("publish cloudrun", "datasette-publish-cloudrun-help.txt"),
))
@pytest.mark.parametrize(
"name,filename",
(
("serve", "datasette-serve-help.txt"),
("package", "datasette-package-help.txt"),
("publish now", "datasette-publish-now-help.txt"),
("publish heroku", "datasette-publish-heroku-help.txt"),
("publish cloudrun", "datasette-publish-cloudrun-help.txt"),
),
)
def test_help_includes(name, filename):
expected = open(str(docs_path / filename)).read()
runner = CliRunner()
result = runner.invoke(cli, name.split() + ["--help"], terminal_width=88)
actual = "$ datasette {} --help\n\n{}".format(
name, result.output
)
actual = "$ datasette {} --help\n\n{}".format(name, result.output)
# actual has "Usage: cli package [OPTIONS] FILES"
# because it doesn't know that cli will be aliased to datasette
expected = expected.replace("Usage: datasette", "Usage: cli")
assert expected == actual
@pytest.mark.parametrize('plugin', [
name for name in dir(app.pm.hook) if not name.startswith('_')
])
@pytest.mark.parametrize(
"plugin", [name for name in dir(app.pm.hook) if not name.startswith("_")]
)
def test_plugin_hooks_are_documented(plugin):
headings = [
s.split("(")[0] for s in get_headings("plugins.rst", "~")
]
headings = [s.split("(")[0] for s in get_headings("plugins.rst", "~")]
assert plugin in headings

Wyświetl plik

@ -2,102 +2,57 @@ from datasette.filters import Filters
import pytest
@pytest.mark.parametrize('args,expected_where,expected_params', [
(
@pytest.mark.parametrize(
"args,expected_where,expected_params",
[
((("name_english__contains", "foo"),), ['"name_english" like :p0'], ["%foo%"]),
(
('name_english__contains', 'foo'),
(("foo", "bar"), ("bar__contains", "baz")),
['"bar" like :p0', '"foo" = :p1'],
["%baz%", "bar"],
),
['"name_english" like :p0'],
['%foo%']
),
(
(
('foo', 'bar'),
('bar__contains', 'baz'),
(("foo__startswith", "bar"), ("bar__endswith", "baz")),
['"bar" like :p0', '"foo" like :p1'],
["%baz", "bar%"],
),
['"bar" like :p0', '"foo" = :p1'],
['%baz%', 'bar']
),
(
(
('foo__startswith', 'bar'),
('bar__endswith', 'baz'),
(("foo__lt", "1"), ("bar__gt", "2"), ("baz__gte", "3"), ("bax__lte", "4")),
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1],
),
['"bar" like :p0', '"foo" like :p1'],
['%baz', 'bar%']
),
(
(
('foo__lt', '1'),
('bar__gt', '2'),
('baz__gte', '3'),
('bax__lte', '4'),
(("foo__like", "2%2"), ("zax__glob", "3*")),
['"foo" like :p0', '"zax" glob :p1'],
["2%2", "3*"],
),
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1]
),
(
# Multiple like arguments:
(
('foo__like', '2%2'),
('zax__glob', '3*'),
(("foo__like", "2%2"), ("foo__like", "3%3")),
['"foo" like :p0', '"foo" like :p1'],
["2%2", "3%3"],
),
['"foo" like :p0', '"zax" glob :p1'],
['2%2', '3*']
),
# Multiple like arguments:
(
(
('foo__like', '2%2'),
('foo__like', '3%3'),
(("foo__isnull", "1"), ("baz__isnull", "1"), ("bar__gt", "10")),
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10],
),
['"foo" like :p0', '"foo" like :p1'],
['2%2', '3%3']
),
(
((("foo__in", "1,2,3"),), ["foo in (:p0, :p1, :p2)"], ["1", "2", "3"]),
# date
((("foo__date", "1988-01-01"),), ["date(foo) = :p0"], ["1988-01-01"]),
# JSON array variants of __in (useful for unexpected characters)
((("foo__in", "[1,2,3]"),), ["foo in (:p0, :p1, :p2)"], [1, 2, 3]),
(
('foo__isnull', '1'),
('baz__isnull', '1'),
('bar__gt', '10'),
(("foo__in", '["dog,cat", "cat[dog]"]'),),
["foo in (:p0, :p1)"],
["dog,cat", "cat[dog]"],
),
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10]
),
(
(
('foo__in', '1,2,3'),
),
['foo in (:p0, :p1, :p2)'],
["1", "2", "3"]
),
# date
(
(
("foo__date", "1988-01-01"),
),
["date(foo) = :p0"],
["1988-01-01"]
),
# JSON array variants of __in (useful for unexpected characters)
(
(
('foo__in', '[1,2,3]'),
),
['foo in (:p0, :p1, :p2)'],
[1, 2, 3]
),
(
(
('foo__in', '["dog,cat", "cat[dog]"]'),
),
['foo in (:p0, :p1)'],
["dog,cat", "cat[dog]"]
),
])
],
)
def test_build_where(args, expected_where, expected_params):
f = Filters(sorted(args))
sql_bits, actual_params = f.build_where_clauses("table")
assert expected_where == sql_bits
assert {
'p{}'.format(i): param
for i, param in enumerate(expected_params)
"p{}".format(i): param for i, param in enumerate(expected_params)
} == actual_params

Plik diff jest za duży Load Diff

Wyświetl plik

@ -5,7 +5,7 @@ import pytest
import tempfile
TABLES = '''
TABLES = """
CREATE TABLE "election_results" (
"county" INTEGER,
"party" INTEGER,
@ -32,13 +32,13 @@ CREATE TABLE "office" (
"id" INTEGER PRIMARY KEY ,
"name" TEXT
);
'''
"""
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def ds_instance():
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, 'fixtures.db')
filepath = os.path.join(tmpdir, "fixtures.db")
conn = sqlite3.connect(filepath)
conn.executescript(TABLES)
yield Datasette([filepath])
@ -46,58 +46,47 @@ def ds_instance():
def test_inspect_hidden_tables(ds_instance):
info = ds_instance.inspect()
tables = info['fixtures']['tables']
tables = info["fixtures"]["tables"]
expected_hidden = (
'election_results_fts',
'election_results_fts_content',
'election_results_fts_docsize',
'election_results_fts_segdir',
'election_results_fts_segments',
'election_results_fts_stat',
)
expected_visible = (
'election_results',
'county',
'party',
'office',
"election_results_fts",
"election_results_fts_content",
"election_results_fts_docsize",
"election_results_fts_segdir",
"election_results_fts_segments",
"election_results_fts_stat",
)
expected_visible = ("election_results", "county", "party", "office")
assert sorted(expected_hidden) == sorted(
[table for table in tables if tables[table]['hidden']]
[table for table in tables if tables[table]["hidden"]]
)
assert sorted(expected_visible) == sorted(
[table for table in tables if not tables[table]['hidden']]
[table for table in tables if not tables[table]["hidden"]]
)
def test_inspect_foreign_keys(ds_instance):
info = ds_instance.inspect()
tables = info['fixtures']['tables']
for table_name in ('county', 'party', 'office'):
assert 0 == tables[table_name]['count']
foreign_keys = tables[table_name]['foreign_keys']
assert [] == foreign_keys['outgoing']
assert [{
'column': 'id',
'other_column': table_name,
'other_table': 'election_results'
}] == foreign_keys['incoming']
tables = info["fixtures"]["tables"]
for table_name in ("county", "party", "office"):
assert 0 == tables[table_name]["count"]
foreign_keys = tables[table_name]["foreign_keys"]
assert [] == foreign_keys["outgoing"]
assert [
{
"column": "id",
"other_column": table_name,
"other_table": "election_results",
}
] == foreign_keys["incoming"]
election_results = tables['election_results']
assert 0 == election_results['count']
assert sorted([{
'column': 'county',
'other_column': 'id',
'other_table': 'county'
}, {
'column': 'party',
'other_column': 'id',
'other_table': 'party'
}, {
'column': 'office',
'other_column': 'id',
'other_table': 'office'
}], key=lambda d: d['column']) == sorted(
election_results['foreign_keys']['outgoing'],
key=lambda d: d['column']
)
assert [] == election_results['foreign_keys']['incoming']
election_results = tables["election_results"]
assert 0 == election_results["count"]
assert sorted(
[
{"column": "county", "other_column": "id", "other_table": "county"},
{"column": "party", "other_column": "id", "other_table": "party"},
{"column": "office", "other_column": "id", "other_table": "office"},
],
key=lambda d: d["column"],
) == sorted(election_results["foreign_keys"]["outgoing"], key=lambda d: d["column"])
assert [] == election_results["foreign_keys"]["incoming"]

Wyświetl plik

@ -1,7 +1,5 @@
from bs4 import BeautifulSoup as Soup
from .fixtures import ( # noqa
app_client,
)
from .fixtures import app_client # noqa
import base64
import json
import re
@ -13,41 +11,26 @@ def test_plugins_dir_plugin(app_client):
response = app_client.get(
"/fixtures.json?sql=select+convert_units(100%2C+'m'%2C+'ft')"
)
assert pytest.approx(328.0839) == response.json['rows'][0][0]
assert pytest.approx(328.0839) == response.json["rows"][0][0]
@pytest.mark.parametrize(
"path,expected_decoded_object",
[
(
"/",
{
"template": "index.html",
"database": None,
"table": None,
},
),
("/", {"template": "index.html", "database": None, "table": None}),
(
"/fixtures/",
{
"template": "database.html",
"database": "fixtures",
"table": None,
},
{"template": "database.html", "database": "fixtures", "table": None},
),
(
"/fixtures/sortable",
{
"template": "table.html",
"database": "fixtures",
"table": "sortable",
},
{"template": "table.html", "database": "fixtures", "table": "sortable"},
),
],
)
def test_plugin_extra_css_urls(app_client, path, expected_decoded_object):
response = app_client.get(path)
links = Soup(response.body, 'html.parser').findAll('link')
links = Soup(response.body, "html.parser").findAll("link")
special_href = [
l for l in links if l.attrs["href"].endswith("/extra-css-urls-demo.css")
][0]["href"]
@ -59,47 +42,43 @@ def test_plugin_extra_css_urls(app_client, path, expected_decoded_object):
def test_plugin_extra_js_urls(app_client):
response = app_client.get('/')
scripts = Soup(response.body, 'html.parser').findAll('script')
response = app_client.get("/")
scripts = Soup(response.body, "html.parser").findAll("script")
assert [
s for s in scripts
if s.attrs == {
'integrity': 'SRIHASH',
'crossorigin': 'anonymous',
'src': 'https://example.com/jquery.js'
s
for s in scripts
if s.attrs
== {
"integrity": "SRIHASH",
"crossorigin": "anonymous",
"src": "https://example.com/jquery.js",
}
]
def test_plugins_with_duplicate_js_urls(app_client):
# If two plugins both require jQuery, jQuery should be loaded only once
response = app_client.get(
"/fixtures"
)
response = app_client.get("/fixtures")
# This test is a little tricky, as if the user has any other plugins in
# their current virtual environment those may affect what comes back too.
# What matters is that https://example.com/jquery.js is only there once
# and it comes before plugin1.js and plugin2.js which could be in either
# order
scripts = Soup(response.body, 'html.parser').findAll('script')
srcs = [s['src'] for s in scripts if s.get('src')]
scripts = Soup(response.body, "html.parser").findAll("script")
srcs = [s["src"] for s in scripts if s.get("src")]
# No duplicates allowed:
assert len(srcs) == len(set(srcs))
# jquery.js loaded once:
assert 1 == srcs.count('https://example.com/jquery.js')
assert 1 == srcs.count("https://example.com/jquery.js")
# plugin1.js and plugin2.js are both there:
assert 1 == srcs.count('https://example.com/plugin1.js')
assert 1 == srcs.count('https://example.com/plugin2.js')
assert 1 == srcs.count("https://example.com/plugin1.js")
assert 1 == srcs.count("https://example.com/plugin2.js")
# jquery comes before them both
assert srcs.index(
'https://example.com/jquery.js'
) < srcs.index(
'https://example.com/plugin1.js'
assert srcs.index("https://example.com/jquery.js") < srcs.index(
"https://example.com/plugin1.js"
)
assert srcs.index(
'https://example.com/jquery.js'
) < srcs.index(
'https://example.com/plugin2.js'
assert srcs.index("https://example.com/jquery.js") < srcs.index(
"https://example.com/plugin2.js"
)
@ -107,13 +86,9 @@ def test_plugins_render_cell_link_from_json(app_client):
sql = """
select '{"href": "http://example.com/", "label":"Example"}'
""".strip()
path = "/fixtures?" + urllib.parse.urlencode({
"sql": sql,
})
path = "/fixtures?" + urllib.parse.urlencode({"sql": sql})
response = app_client.get(path)
td = Soup(
response.body, "html.parser"
).find("table").find("tbody").find("td")
td = Soup(response.body, "html.parser").find("table").find("tbody").find("td")
a = td.find("a")
assert a is not None, str(a)
assert a.attrs["href"] == "http://example.com/"
@ -129,10 +104,7 @@ def test_plugins_render_cell_demo(app_client):
"column": "content",
"table": "simple_primary_key",
"database": "fixtures",
"config": {
"depth": "table",
"special": "this-is-simple_primary_key"
}
"config": {"depth": "table", "special": "this-is-simple_primary_key"},
} == json.loads(td.string)

Wyświetl plik

@ -35,7 +35,14 @@ def test_publish_cloudrun(mock_call, mock_output, mock_which):
result = runner.invoke(cli.cli, ["publish", "cloudrun", "test.db"])
assert 0 == result.exit_code
tag = "gcr.io/{}/datasette".format(mock_output.return_value)
mock_call.assert_has_calls([
mock.call("gcloud builds submit --tag {}".format(tag), shell=True),
mock.call("gcloud beta run deploy --allow-unauthenticated --image {}".format(tag), shell=True)])
mock_call.assert_has_calls(
[
mock.call("gcloud builds submit --tag {}".format(tag), shell=True),
mock.call(
"gcloud beta run deploy --allow-unauthenticated --image {}".format(
tag
),
shell=True,
),
]
)

Wyświetl plik

@ -57,7 +57,9 @@ def test_publish_heroku(mock_call, mock_check_output, mock_which):
open("test.db", "w").write("data")
result = runner.invoke(cli.cli, ["publish", "heroku", "test.db"])
assert 0 == result.exit_code, result.output
mock_call.assert_called_once_with(["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"])
mock_call.assert_called_once_with(
["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"]
)
@mock.patch("shutil.which")

Wyświetl plik

@ -13,72 +13,78 @@ import tempfile
from unittest.mock import patch
@pytest.mark.parametrize('path,expected', [
('foo', ['foo']),
('foo,bar', ['foo', 'bar']),
('123,433,112', ['123', '433', '112']),
('123%2C433,112', ['123,433', '112']),
('123%2F433%2F112', ['123/433/112']),
])
@pytest.mark.parametrize(
"path,expected",
[
("foo", ["foo"]),
("foo,bar", ["foo", "bar"]),
("123,433,112", ["123", "433", "112"]),
("123%2C433,112", ["123,433", "112"]),
("123%2F433%2F112", ["123/433/112"]),
],
)
def test_urlsafe_components(path, expected):
assert expected == utils.urlsafe_components(path)
@pytest.mark.parametrize('path,added_args,expected', [
('/foo', {'bar': 1}, '/foo?bar=1'),
('/foo?bar=1', {'baz': 2}, '/foo?bar=1&baz=2'),
('/foo?bar=1&bar=2', {'baz': 3}, '/foo?bar=1&bar=2&baz=3'),
('/foo?bar=1', {'bar': None}, '/foo'),
# Test order is preserved
('/?_facet=prim_state&_facet=area_name', (
('prim_state', 'GA'),
), '/?_facet=prim_state&_facet=area_name&prim_state=GA'),
('/?_facet=state&_facet=city&state=MI', (
('city', 'Detroit'),
), '/?_facet=state&_facet=city&state=MI&city=Detroit'),
('/?_facet=state&_facet=city', (
('_facet', 'planet_int'),
), '/?_facet=state&_facet=city&_facet=planet_int'),
])
@pytest.mark.parametrize(
"path,added_args,expected",
[
("/foo", {"bar": 1}, "/foo?bar=1"),
("/foo?bar=1", {"baz": 2}, "/foo?bar=1&baz=2"),
("/foo?bar=1&bar=2", {"baz": 3}, "/foo?bar=1&bar=2&baz=3"),
("/foo?bar=1", {"bar": None}, "/foo"),
# Test order is preserved
(
"/?_facet=prim_state&_facet=area_name",
(("prim_state", "GA"),),
"/?_facet=prim_state&_facet=area_name&prim_state=GA",
),
(
"/?_facet=state&_facet=city&state=MI",
(("city", "Detroit"),),
"/?_facet=state&_facet=city&state=MI&city=Detroit",
),
(
"/?_facet=state&_facet=city",
(("_facet", "planet_int"),),
"/?_facet=state&_facet=city&_facet=planet_int",
),
],
)
def test_path_with_added_args(path, added_args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_added_args(request, added_args)
assert expected == actual
@pytest.mark.parametrize('path,args,expected', [
('/foo?bar=1', {'bar'}, '/foo'),
('/foo?bar=1&baz=2', {'bar'}, '/foo?baz=2'),
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
])
@pytest.mark.parametrize(
"path,args,expected",
[
("/foo?bar=1", {"bar"}, "/foo"),
("/foo?bar=1&baz=2", {"bar"}, "/foo?baz=2"),
("/foo?bar=1&bar=2&bar=3", {"bar": "2"}, "/foo?bar=1&bar=3"),
],
)
def test_path_with_removed_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_removed_args(request, args)
assert expected == actual
# Run the test again but this time use the path= argument
request = Request(
"/".encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request("/".encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_removed_args(request, args, path=path)
assert expected == actual
@pytest.mark.parametrize('path,args,expected', [
('/foo?bar=1', {'bar': 2}, '/foo?bar=2'),
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
])
@pytest.mark.parametrize(
"path,args,expected",
[
("/foo?bar=1", {"bar": 2}, "/foo?bar=2"),
("/foo?bar=1&baz=2", {"bar": None}, "/foo?baz=2"),
],
)
def test_path_with_replaced_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_replaced_args(request, args)
assert expected == actual
@ -93,17 +99,8 @@ def test_path_with_replaced_args(path, args, expected):
utils.CustomRow(
["searchable_id", "tag"],
[
(
"searchable_id",
{"value": 1, "label": "1"},
),
(
"tag",
{
"value": "feline",
"label": "feline",
},
),
("searchable_id", {"value": 1, "label": "1"}),
("tag", {"value": "feline", "label": "feline"}),
],
),
["searchable_id", "tag"],
@ -116,47 +113,54 @@ def test_path_from_row_pks(row, pks, expected_path):
assert expected_path == actual_path
@pytest.mark.parametrize('obj,expected', [
({
'Description': 'Soft drinks',
'Picture': b"\x15\x1c\x02\xc7\xad\x05\xfe",
'CategoryID': 1,
}, """
@pytest.mark.parametrize(
"obj,expected",
[
(
{
"Description": "Soft drinks",
"Picture": b"\x15\x1c\x02\xc7\xad\x05\xfe",
"CategoryID": 1,
},
"""
{"CategoryID": 1, "Description": "Soft drinks", "Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}}
""".strip()),
])
""".strip(),
)
],
)
def test_custom_json_encoder(obj, expected):
actual = json.dumps(
obj,
cls=utils.CustomJSONEncoder,
sort_keys=True
)
actual = json.dumps(obj, cls=utils.CustomJSONEncoder, sort_keys=True)
assert expected == actual
@pytest.mark.parametrize('bad_sql', [
'update blah;',
'PRAGMA case_sensitive_like = true'
"SELECT * FROM pragma_index_info('idx52')",
])
@pytest.mark.parametrize(
"bad_sql",
[
"update blah;",
"PRAGMA case_sensitive_like = true" "SELECT * FROM pragma_index_info('idx52')",
],
)
def test_validate_sql_select_bad(bad_sql):
with pytest.raises(utils.InvalidSql):
utils.validate_sql_select(bad_sql)
@pytest.mark.parametrize('good_sql', [
'select count(*) from airports',
'select foo from bar',
'select 1 + 1',
'SELECT\nblah FROM foo',
'WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;'
])
@pytest.mark.parametrize(
"good_sql",
[
"select count(*) from airports",
"select foo from bar",
"select 1 + 1",
"SELECT\nblah FROM foo",
"WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;",
],
)
def test_validate_sql_select_good(good_sql):
utils.validate_sql_select(good_sql)
def test_detect_fts():
sql = '''
sql = """
CREATE TABLE "Dumb_Table" (
"TreeID" INTEGER,
"qSpecies" TEXT
@ -173,34 +177,40 @@ def test_detect_fts():
CREATE VIEW Test_View AS SELECT * FROM Dumb_Table;
CREATE VIRTUAL TABLE "Street_Tree_List_fts" USING FTS4 ("qAddress", "qCaretaker", "qSpecies", content="Street_Tree_List");
CREATE VIRTUAL TABLE r USING rtree(a, b, c);
'''
conn = utils.sqlite3.connect(':memory:')
"""
conn = utils.sqlite3.connect(":memory:")
conn.executescript(sql)
assert None is utils.detect_fts(conn, 'Dumb_Table')
assert None is utils.detect_fts(conn, 'Test_View')
assert None is utils.detect_fts(conn, 'r')
assert 'Street_Tree_List_fts' == utils.detect_fts(conn, 'Street_Tree_List')
assert None is utils.detect_fts(conn, "Dumb_Table")
assert None is utils.detect_fts(conn, "Test_View")
assert None is utils.detect_fts(conn, "r")
assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List")
@pytest.mark.parametrize('url,expected', [
('http://www.google.com/', True),
('https://example.com/', True),
('www.google.com', False),
('http://www.google.com/ is a search engine', False),
])
@pytest.mark.parametrize(
"url,expected",
[
("http://www.google.com/", True),
("https://example.com/", True),
("www.google.com", False),
("http://www.google.com/ is a search engine", False),
],
)
def test_is_url(url, expected):
assert expected == utils.is_url(url)
@pytest.mark.parametrize('s,expected', [
('simple', 'simple'),
('MixedCase', 'MixedCase'),
('-no-leading-hyphens', 'no-leading-hyphens-65bea6'),
('_no-leading-underscores', 'no-leading-underscores-b921bc'),
('no spaces', 'no-spaces-7088d7'),
('-', '336d5e'),
('no $ characters', 'no--characters-59e024'),
])
@pytest.mark.parametrize(
"s,expected",
[
("simple", "simple"),
("MixedCase", "MixedCase"),
("-no-leading-hyphens", "no-leading-hyphens-65bea6"),
("_no-leading-underscores", "no-leading-underscores-b921bc"),
("no spaces", "no-spaces-7088d7"),
("-", "336d5e"),
("no $ characters", "no--characters-59e024"),
],
)
def test_to_css_class(s, expected):
assert expected == utils.to_css_class(s)
@ -208,11 +218,11 @@ def test_to_css_class(s, expected):
def test_temporary_docker_directory_uses_hard_link():
with tempfile.TemporaryDirectory() as td:
os.chdir(td)
open('hello', 'w').write('world')
open("hello", "w").write("world")
# Default usage of this should use symlink
with utils.temporary_docker_directory(
files=['hello'],
name='t',
files=["hello"],
name="t",
metadata=None,
extra_options=None,
branch=None,
@ -223,23 +233,23 @@ def test_temporary_docker_directory_uses_hard_link():
spatialite=False,
version_note=None,
) as temp_docker:
hello = os.path.join(temp_docker, 'hello')
assert 'world' == open(hello).read()
hello = os.path.join(temp_docker, "hello")
assert "world" == open(hello).read()
# It should be a hard link
assert 2 == os.stat(hello).st_nlink
@patch('os.link')
@patch("os.link")
def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
# Copy instead if os.link raises OSError (normally due to different device)
mock_link.side_effect = OSError
with tempfile.TemporaryDirectory() as td:
os.chdir(td)
open('hello', 'w').write('world')
open("hello", "w").write("world")
# Default usage of this should use symlink
with utils.temporary_docker_directory(
files=['hello'],
name='t',
files=["hello"],
name="t",
metadata=None,
extra_options=None,
branch=None,
@ -250,49 +260,53 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
spatialite=False,
version_note=None,
) as temp_docker:
hello = os.path.join(temp_docker, 'hello')
assert 'world' == open(hello).read()
hello = os.path.join(temp_docker, "hello")
assert "world" == open(hello).read()
# It should be a copy, not a hard link
assert 1 == os.stat(hello).st_nlink
def test_temporary_docker_directory_quotes_args():
with tempfile.TemporaryDirectory() as td:
with tempfile.TemporaryDirectory() as td:
os.chdir(td)
open('hello', 'w').write('world')
open("hello", "w").write("world")
with utils.temporary_docker_directory(
files=['hello'],
name='t',
files=["hello"],
name="t",
metadata=None,
extra_options='--$HOME',
extra_options="--$HOME",
branch=None,
template_dir=None,
plugins_dir=None,
static=[],
install=[],
spatialite=False,
version_note='$PWD',
version_note="$PWD",
) as temp_docker:
df = os.path.join(temp_docker, 'Dockerfile')
df = os.path.join(temp_docker, "Dockerfile")
df_contents = open(df).read()
assert "'$PWD'" in df_contents
assert "'--$HOME'" in df_contents
def test_compound_keys_after_sql():
assert '((a > :p0))' == utils.compound_keys_after_sql(['a'])
assert '''
assert "((a > :p0))" == utils.compound_keys_after_sql(["a"])
assert """
((a > :p0)
or
(a = :p0 and b > :p1))
'''.strip() == utils.compound_keys_after_sql(['a', 'b'])
assert '''
""".strip() == utils.compound_keys_after_sql(
["a", "b"]
)
assert """
((a > :p0)
or
(a = :p0 and b > :p1)
or
(a = :p0 and b = :p1 and c > :p2))
'''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])
""".strip() == utils.compound_keys_after_sql(
["a", "b", "c"]
)
async def table_exists(table):
@ -314,7 +328,7 @@ async def test_resolve_table_and_format(
table_and_format, expected_table, expected_format
):
actual_table, actual_format = await utils.resolve_table_and_format(
table_and_format, table_exists, ['json']
table_and_format, table_exists, ["json"]
)
assert expected_table == actual_table
assert expected_format == actual_format
@ -322,9 +336,11 @@ async def test_resolve_table_and_format(
def test_table_columns():
conn = sqlite3.connect(":memory:")
conn.executescript("""
conn.executescript(
"""
create table places (id integer primary key, name text, bob integer)
""")
"""
)
assert ["id", "name", "bob"] == utils.table_columns(conn, "places")
@ -347,10 +363,7 @@ def test_table_columns():
],
)
def test_path_with_format(path, format, extra_qs, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_format(request, format, extra_qs)
assert expected == actual
@ -358,13 +371,13 @@ def test_path_with_format(path, format, extra_qs, expected):
@pytest.mark.parametrize(
"bytes,expected",
[
(120, '120 bytes'),
(1024, '1.0 KB'),
(1024 * 1024, '1.0 MB'),
(1024 * 1024 * 1024, '1.0 GB'),
(1024 * 1024 * 1024 * 1.3, '1.3 GB'),
(1024 * 1024 * 1024 * 1024, '1.0 TB'),
]
(120, "120 bytes"),
(1024, "1.0 KB"),
(1024 * 1024, "1.0 MB"),
(1024 * 1024 * 1024, "1.0 GB"),
(1024 * 1024 * 1024 * 1.3, "1.3 GB"),
(1024 * 1024 * 1024 * 1024, "1.0 TB"),
],
)
def test_format_bytes(bytes, expected):
assert expected == utils.format_bytes(bytes)