From 26b2922f177caa4e147aaee28be0cff37a457802 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 2 Sep 2020 15:21:12 -0700 Subject: [PATCH] await_me_maybe utility function --- datasette/app.py | 41 ++++++++----------------------------- datasette/utils/__init__.py | 9 ++++++++ datasette/views/base.py | 4 ++-- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 2185a3ab..bb47f411 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -45,6 +45,7 @@ from .database import Database, QueryInterrupted from .utils import ( async_call_with_supported_arguments, + await_me_maybe, call_with_supported_arguments, display_actor, escape_css_string, @@ -312,10 +313,7 @@ class Datasette: async def invoke_startup(self): for hook in pm.hook.startup(datasette=self): - if callable(hook): - hook = hook() - if asyncio.iscoroutine(hook): - hook = await hook + await await_me_maybe(hook) def sign(self, value, namespace="default"): return URLSafeSerializer(self._secret, namespace).dumps(value) @@ -400,10 +398,7 @@ class Datasette: for more_queries in pm.hook.canned_queries( datasette=self, database=database_name, actor=actor, ): - if callable(more_queries): - more_queries = more_queries() - if asyncio.iscoroutine(more_queries): - more_queries = await more_queries + more_queries = await await_me_maybe(more_queries) queries.update(more_queries or {}) # Fix any {"name": "select ..."} queries to be {"name": {"sql": "select ..."}} for key in queries: @@ -475,10 +470,7 @@ class Datasette: for check in pm.hook.permission_allowed( datasette=self, actor=actor, action=action, resource=resource, ): - if callable(check): - check = check() - if asyncio.iscoroutine(check): - check = await check + check = await await_me_maybe(check) if check is not None: result = check used_default = False @@ -718,10 +710,7 @@ class Datasette: request=request, datasette=self, ): - if callable(extra_script): - extra_script = extra_script() - if asyncio.iscoroutine(extra_script): - extra_script = await extra_script + extra_script = await await_me_maybe(extra_script) body_scripts.append(Markup(extra_script)) extra_template_vars = {} @@ -735,10 +724,7 @@ class Datasette: request=request, datasette=self, ): - if callable(extra_vars): - extra_vars = extra_vars() - if asyncio.iscoroutine(extra_vars): - extra_vars = await extra_vars + extra_vars = await await_me_maybe(extra_vars) assert isinstance(extra_vars, dict), "extra_vars is of type {}".format( type(extra_vars) ) @@ -786,10 +772,7 @@ class Datasette: request=request, datasette=self, ): - if callable(hook): - hook = hook() - if asyncio.iscoroutine(hook): - hook = await hook + hook = await await_me_maybe(hook) collected.extend(hook) collected.extend(self.metadata(key) or []) output = [] @@ -981,10 +964,7 @@ class DatasetteRouter: default_actor = scope.get("actor") or None actor = None for actor in pm.hook.actor_from_request(datasette=self.ds, request=request): - if callable(actor): - actor = actor() - if asyncio.iscoroutine(actor): - actor = await actor + actor = await await_me_maybe(actor) if actor: break scope_modifications["actor"] = actor or default_actor @@ -1079,10 +1059,7 @@ class DatasetteRouter: for custom_response in pm.hook.forbidden( datasette=self.ds, request=request, message=message ): - if callable(custom_response): - custom_response = custom_response() - if asyncio.iscoroutine(custom_response): - custom_response = await custom_response + custom_response = await await_me_maybe(custom_response) if custom_response is not None: await custom_response.asgi_send(send) return diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index 60922957..caa6920d 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import contextmanager from collections import OrderedDict import base64 @@ -51,6 +52,14 @@ ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so """ +async def await_me_maybe(value): + if callable(value): + value = value() + if asyncio.iscoroutine(value): + value = await value + return value + + def urlsafe_components(token): "Splits token on commas and URL decodes each component" return [urllib.parse.unquote_plus(b) for b in token.split(",")] diff --git a/datasette/views/base.py b/datasette/views/base.py index fa730af8..3b9885db 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -12,6 +12,7 @@ from datasette import __version__ from datasette.plugins import pm from datasette.database import QueryInterrupted from datasette.utils import ( + await_me_maybe, InvalidSql, LimitedWriter, call_with_supported_arguments, @@ -492,8 +493,7 @@ class DataView(BaseView): request=request, view_name=self.name, ) - if asyncio.iscoroutine(it_can_render): - it_can_render = await it_can_render + it_can_render = await await_me_maybe(it_can_render) if it_can_render: renderers[key] = path_with_format( request, key, {**url_labels_extra}