Refactor to use new resolve_database/table/row methods, refs #1896

pull/1912/head
Simon Willison 2022-11-18 14:46:25 -08:00
rodzic c588a89f26
commit ee64130fa8
8 zmienionych plików z 256 dodań i 135 usunięć

Wyświetl plik

@ -62,13 +62,19 @@ from .utils import (
parse_metadata,
resolve_env_secrets,
resolve_routes,
tilde_decode,
to_css_class,
urlsafe_components,
row_sql_params_pks,
)
from .utils.asgi import (
AsgiLifespan,
Base400,
Forbidden,
NotFound,
DatabaseNotFound,
TableNotFound,
RowNotFound,
Request,
Response,
asgi_static,
@ -198,6 +204,12 @@ async def favicon(request, send):
)
ResolvedTable = collections.namedtuple("ResolvedTable", ("db", "table", "is_view"))
ResolvedRow = collections.namedtuple(
"ResolvedRow", ("db", "table", "sql", "params", "pks", "pk_values", "row")
)
class Datasette:
# Message constants:
INFO = 1
@ -1292,6 +1304,41 @@ class Datasette:
for pattern, view in routes
]
async def resolve_database(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
return self.get_database(route=database_route)
except KeyError:
raise DatabaseNotFound(
"Database not found: {}".format(database_route), database_route
)
async def resolve_table(self, request):
db = await self.resolve_database(request)
table_name = tilde_decode(request.url_vars["table"])
# Table must exist
is_view = False
table_exists = await db.table_exists(table_name)
if not table_exists:
is_view = await db.view_exists(table_name)
if not (table_exists or is_view):
raise TableNotFound(
"Table not found: {}".format(table_name), db.name, table_name
)
return ResolvedTable(db, table_name, is_view)
async def resolve_row(self, request):
db, table_name, _ = await self.resolve_table(request)
pk_values = urlsafe_components(request.url_vars["pks"])
sql, params, pks = await row_sql_params_pks(db, table_name, pk_values)
results = await db.execute(sql, params, truncate=True)
row = results.first()
if row is None:
raise RowNotFound(
"Row not found: {}".format(pk_values), db.name, table_name, pk_values
)
return ResolvedRow(db, table_name, sql, params, pks, pk_values, results.first())
def app(self):
"""Returns an ASGI app function that serves the whole of Datasette"""
routes = self._routes()

Wyświetl plik

@ -1193,3 +1193,18 @@ def truncate_url(url, length):
rest, ext = bits
return rest[: length - 1 - len(ext)] + "…." + ext
return url[: length - 1] + ""
async def row_sql_params_pks(db, table, pk_values):
pks = await db.primary_keys(table)
use_rowid = not pks
select = "*"
if use_rowid:
select = "rowid, *"
pks = ["rowid"]
wheres = [f'"{pk}"=:p{i}' for i, pk in enumerate(pks)]
sql = f"select {select} from {escape_sqlite(table)} where {' AND '.join(wheres)}"
params = {}
for i, pk_value in enumerate(pk_values):
params[f"p{i}"] = pk_value
return sql, params, pks

Wyświetl plik

@ -21,6 +21,27 @@ class NotFound(Base400):
status = 404
class DatabaseNotFound(NotFound):
def __init__(self, message, database_name):
super().__init__(message)
self.database_name = database_name
class TableNotFound(NotFound):
def __init__(self, message, database_name, table):
super().__init__(message)
self.database_name = database_name
self.table = table
class RowNotFound(NotFound):
def __init__(self, message, database_name, table, pk_values):
super().__init__(message)
self.database_name = database_name
self.table_name = table
self.pk_values = pk_values
class Forbidden(Base400):
status = 403

Wyświetl plik

@ -20,7 +20,6 @@ from datasette.utils import (
InvalidSql,
LimitedWriter,
call_with_supported_arguments,
tilde_decode,
path_from_row_pks,
path_with_added_args,
path_with_removed_args,
@ -346,13 +345,9 @@ class DataView(BaseView):
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
async def get(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
db = await self.ds.resolve_database(request)
database = db.name
database_route = db.route
_format = request.url_vars["format"]
data_kwargs = {}

Wyświetl plik

@ -35,11 +35,7 @@ class DatabaseView(DataView):
name = "database"
async def data(self, request, default_labels=False, _size=None):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
db = await self.ds.resolve_database(request)
database = db.name
visible, private = await self.ds.check_visibility(
@ -228,11 +224,7 @@ class QueryView(DataView):
named_parameters=None,
write=False,
):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
db = await self.ds.resolve_database(request)
database = db.name
params = {key: request.args.get(key) for key in request.args}
if "sql" in params:
@ -582,11 +574,7 @@ class TableCreateView(BaseView):
self.ds = datasette
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
return _error(["Database not found: {}".format(database_route)], 404)
db = await self.ds.resolve_database(request)
database_name = db.name
# Must have create-table permission
@ -727,11 +715,7 @@ class TableCreateView(BaseView):
self.ds = datasette
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
return _error(["Database not found: {}".format(database_route)], 404)
db = await self.ds.resolve_database(request)
database_name = db.name
# Must have create-table permission

Wyświetl plik

@ -6,22 +6,21 @@ from datasette.utils import (
urlsafe_components,
to_css_class,
escape_sqlite,
row_sql_params_pks,
)
import json
import sqlite_utils
from .table import _sql_params_pks, display_columns_and_rows
from .table import display_columns_and_rows
class RowView(DataView):
name = "row"
async def data(self, request, default_labels=False):
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
resolved = await self.ds.resolve_row(request)
database = resolved.db.name
table = resolved.table
pk_values = resolved.pk_values
# Ensure user has permission to view this row
visible, private = await self.ds.check_visibility(
@ -35,14 +34,9 @@ class RowView(DataView):
if not visible:
raise Forbidden("You do not have permission to view this table")
pk_values = urlsafe_components(request.url_vars["pks"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
sql, params, pks = await _sql_params_pks(db, table, pk_values)
results = await db.execute(sql, params, truncate=True)
results = await resolved.db.execute(
resolved.sql, resolved.params, truncate=True
)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
@ -83,7 +77,7 @@ class RowView(DataView):
"table": table,
"rows": rows,
"columns": columns,
"primary_keys": pks,
"primary_keys": resolved.pks,
"primary_key_values": pk_values,
"units": self.ds.table_metadata(database, table).get("units", {}),
}
@ -149,6 +143,11 @@ class RowView(DataView):
return foreign_key_tables
class RowError(Exception):
def __init__(self, error):
self.error = error
class RowDeleteView(BaseView):
name = "row-delete"
@ -156,24 +155,20 @@ class RowDeleteView(BaseView):
self.ds = datasette
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
from datasette.app import DatabaseNotFound, TableNotFound, RowNotFound
try:
db = self.ds.get_database(route=database_route)
except KeyError:
return _error(["Database not found: {}".format(database_route)], 404)
resolved = await self.ds.resolve_row(request)
except DatabaseNotFound as e:
return _error(["Database not found: {}".format(e.database_name)], 404)
except TableNotFound as e:
return _error(["Table not found: {}".format(e.table)], 404)
except RowNotFound as e:
return _error(["Record not found: {}".format(e.pk_values)], 404)
db = resolved.db
database_name = db.name
if not await db.table_exists(table):
return _error(["Table not found: {}".format(table)], 404)
pk_values = urlsafe_components(request.url_vars["pks"])
sql, params, pks = await _sql_params_pks(db, table, pk_values)
results = await db.execute(sql, params, truncate=True)
rows = list(results.rows)
if not rows:
return _error([f"Record not found: {pk_values}"], 404)
table = resolved.table
pk_values = resolved.pk_values
# Ensure user has permission to delete this row
if not await self.ds.permission_allowed(

Wyświetl plik

@ -93,36 +93,33 @@ class TableView(DataView):
return expandables
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database_name = db.name
table_name = tilde_decode(request.url_vars["table"])
# Handle POST to a canned query
canned_query = await self.ds.get_canned_query(
database_name, table_name, request.actor
)
if canned_query:
return await QueryView(self.ds).data(
request,
canned_query["sql"],
metadata=canned_query,
editable=False,
canned_query=table_name,
named_parameters=canned_query.get("params"),
write=bool(canned_query.get("write")),
)
else:
# Handle POST to a table
return await self.table_post(request, database_name, table_name)
from datasette.app import TableNotFound
async def table_post(self, request, database_name, table_name):
# Table must exist (may handle table creation in the future)
db = self.ds.get_database(database_name)
if not await db.table_exists(table_name):
raise NotFound("Table not found: {}".format(table_name))
try:
resolved = await self.ds.resolve_table(request)
except TableNotFound as e:
# Was this actually a canned query?
canned_query = await self.ds.get_canned_query(
e.database_name, e.table, request.actor
)
if canned_query:
# Handle POST to a canned query
return await QueryView(self.ds).data(
request,
canned_query["sql"],
metadata=canned_query,
editable=False,
canned_query=e.table,
named_parameters=canned_query.get("params"),
write=bool(canned_query.get("write")),
)
# Handle POST to a table
return await self.table_post(
request, resolved.db, resolved.db.name, resolved.table
)
async def table_post(self, request, db, database_name, table_name):
# Must have insert-row permission
if not await self.ds.permission_allowed(
request.actor, "insert-row", resource=(database_name, table_name)
@ -221,12 +218,31 @@ class TableView(DataView):
_next=None,
_size=None,
):
database_route = tilde_decode(request.url_vars["database"])
table_name = tilde_decode(request.url_vars["table"])
from datasette.app import TableNotFound
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
resolved = await self.ds.resolve_table(request)
except TableNotFound as e:
# Was this actually a canned query?
canned_query = await self.ds.get_canned_query(
e.database_name, e.table, request.actor
)
# If this is a canned query, not a table, then dispatch to QueryView instead
if canned_query:
return await QueryView(self.ds).data(
request,
canned_query["sql"],
metadata=canned_query,
editable=False,
canned_query=e.table,
named_parameters=canned_query.get("params"),
write=bool(canned_query.get("write")),
)
else:
raise
table_name = resolved.table
db = resolved.db
database_name = db.name
# For performance profiling purposes, ?_noparallel=1 turns off asyncio.gather
@ -243,21 +259,6 @@ class TableView(DataView):
_gather_sequential if request.args.get("_noparallel") else _gather_parallel
)
# If this is a canned query, not a table, then dispatch to QueryView instead
canned_query = await self.ds.get_canned_query(
database_name, table_name, request.actor
)
if canned_query:
return await QueryView(self.ds).data(
request,
canned_query["sql"],
metadata=canned_query,
editable=False,
canned_query=table_name,
named_parameters=canned_query.get("params"),
write=bool(canned_query.get("write")),
)
is_view, table_exists = map(
bool,
await gather(
@ -874,21 +875,6 @@ class TableView(DataView):
)
async def _sql_params_pks(db, table, pk_values):
pks = await db.primary_keys(table)
use_rowid = not pks
select = "*"
if use_rowid:
select = "rowid, *"
pks = ["rowid"]
wheres = [f'"{pk}"=:p{i}' for i, pk in enumerate(pks)]
sql = f"select {select} from {escape_sqlite(table)} where {' AND '.join(wheres)}"
params = {}
for i, pk_value in enumerate(pk_values):
params[f"p{i}"] = pk_value
return sql, params, pks
async def display_columns_and_rows(
datasette,
database_name,
@ -1161,13 +1147,13 @@ class TableInsertView(BaseView):
return rows, errors, extras
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
return _error(["Database not found: {}".format(database_route)], 404)
resolved = await self.ds.resolve_table(request)
except NotFound as e:
return _error([e.args[0]], 404)
db = resolved.db
database_name = db.name
table_name = tilde_decode(request.url_vars["table"])
table_name = resolved.table
# Table must exist (may handle table creation in the future)
db = self.ds.get_database(database_name)
@ -1221,13 +1207,13 @@ class TableDropView(BaseView):
self.ds = datasette
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
return _error(["Database not found: {}".format(database_route)], 404)
resolved = await self.ds.resolve_table(request)
except NotFound as e:
return _error([e.args[0]], 404)
db = resolved.db
database_name = db.name
table_name = tilde_decode(request.url_vars["table"])
table_name = resolved.table
# Table must exist
db = self.ds.get_database(database_name)
if not await db.table_exists(table_name):

Wyświetl plik

@ -579,6 +579,84 @@ For example:
downloads_are_allowed = datasette.setting("allow_download")
.. _datasette_resolve_database:
.resolve_database(request)
--------------------------
``request`` - :ref:`internals_request`
A request object
If you are implementing your own custom views, you may need to resolve the database that the user is requesting based on a URL path. If the regular expression for your route declares a ``database`` named group, you can use this method to resolve the database object.
This returns a :ref:`Database <internals_database>` instance.
If the database cannot be found, it raises a ``datasette.utils.asgi.DatabaseNotFound`` exception - which is a subclass of ``datasette.utils.asgi.NotFound`` with a ``.database_name`` attribute set to the name of the database that was requested.
.. _datasette_resolve_table:
.resolve_table(request)
-----------------------
``request`` - :ref:`internals_request`
A request object
This assumes that the regular expression for your route declares both a ``database`` and a ``table`` named group.
It returns a ``ResolvedTable`` named tuple instance with the following fields:
``db`` - :ref:`Database <internals_database>`
The database object
``table`` - string
The name of the table (or view)
``is_view`` - boolean
``True`` if this is a view, ``False`` if it is a table
If the database or table cannot be found it raises a ``datasette.utils.asgi.DatabaseNotFound`` exception.
If the table does not exist it raises a ``datasette.utils.asgi.TableNotFound`` exception - a subclass of ``datasette.utils.asgi.NotFound`` with ``.database_name`` and ``.table`` attributes.
.. _datasette_resolve_row:
.resolve_row(request)
---------------------
``request`` - :ref:`internals_request`
A request object
This method assumes your route declares named groups for ``database``, ``table`` and ``pks``.
It returns a ``ResolvedRow`` named tuple instance with the following fields:
``db`` - :ref:`Database <internals_database>`
The database object
``table`` - string
The name of the table
``sql`` - string
SQL snippet that can be used in a ``WHERE`` clause to select the row
``params`` - dict
Parameters that should be passed to the SQL query
``pks`` - list
List of primary key column names
``pk_values`` - list
List of primary key values decoded from the URL
``row`` - ``sqlite3.Row``
The row itself
If the database or table cannot be found it raises a ``datasette.utils.asgi.DatabaseNotFound`` exception.
If the table does not exist it raises a ``datasette.utils.asgi.TableNotFound`` exception.
If the row cannot be found it raises a ``datasette.utils.asgi.RowNotFound`` exception. This has ``.database_name``, ``.table`` and ``.pk_values`` attributes, extracted from the request path.
.. _internals_datasette_client:
datasette.client
@ -770,7 +848,7 @@ The ``Results`` object also has the following properties and methods:
``.columns`` - list of strings
A list of column names returned by the query.
``.rows`` - list of sqlite3.Row
``.rows`` - list of ``sqlite3.Row``
This property provides direct access to the list of rows returned by the database. You can access specific rows by index using ``results.rows[0]``.
``.first()`` - row or None