Replace AsgiLifespan with AsgiRunOnFirstRequest, refs #1955

pull/1965/head
Simon Willison 2022-12-15 09:34:07 -08:00
rodzic 89cffcf14c
commit 63fb750f39
4 zmienionych plików z 22 dodań i 49 usunięć

Wyświetl plik

@ -69,8 +69,6 @@ from .utils import (
row_sql_params_pks, row_sql_params_pks,
) )
from .utils.asgi import ( from .utils.asgi import (
AsgiLifespan,
Base400,
Forbidden, Forbidden,
NotFound, NotFound,
DatabaseNotFound, DatabaseNotFound,
@ -78,11 +76,10 @@ from .utils.asgi import (
RowNotFound, RowNotFound,
Request, Request,
Response, Response,
AsgiRunOnFirstRequest,
asgi_static, asgi_static,
asgi_send, asgi_send,
asgi_send_file, asgi_send_file,
asgi_send_html,
asgi_send_json,
asgi_send_redirect, asgi_send_redirect,
) )
from .utils.internal_db import init_internal_db, populate_schema_tables from .utils.internal_db import init_internal_db, populate_schema_tables
@ -1420,7 +1417,7 @@ class Datasette:
async def setup_db(): async def setup_db():
# First time server starts up, calculate table counts for immutable databases # First time server starts up, calculate table counts for immutable databases
for dbname, database in self.databases.items(): for database in self.databases.values():
if not database.is_mutable: if not database.is_mutable:
await database.table_counts(limit=60 * 60 * 1000) await database.table_counts(limit=60 * 60 * 1000)
@ -1434,10 +1431,7 @@ class Datasette:
) )
if self.setting("trace_debug"): if self.setting("trace_debug"):
asgi = AsgiTracer(asgi) asgi = AsgiTracer(asgi)
asgi = AsgiLifespan( asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup])
asgi,
on_startup=setup_db,
)
for wrapper in pm.hook.asgi_wrapper(datasette=self): for wrapper in pm.hook.asgi_wrapper(datasette=self):
asgi = wrapper(asgi) asgi = wrapper(asgi)
return asgi return asgi
@ -1730,42 +1724,34 @@ class DatasetteClient:
return path return path
async def get(self, path, **kwargs): async def get(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.get(self._fix(path), **kwargs) return await client.get(self._fix(path), **kwargs)
async def options(self, path, **kwargs): async def options(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.options(self._fix(path), **kwargs) return await client.options(self._fix(path), **kwargs)
async def head(self, path, **kwargs): async def head(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.head(self._fix(path), **kwargs) return await client.head(self._fix(path), **kwargs)
async def post(self, path, **kwargs): async def post(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.post(self._fix(path), **kwargs) return await client.post(self._fix(path), **kwargs)
async def put(self, path, **kwargs): async def put(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.put(self._fix(path), **kwargs) return await client.put(self._fix(path), **kwargs)
async def patch(self, path, **kwargs): async def patch(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.patch(self._fix(path), **kwargs) return await client.patch(self._fix(path), **kwargs)
async def delete(self, path, **kwargs): async def delete(self, path, **kwargs):
await self.ds.invoke_startup()
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.delete(self._fix(path), **kwargs) return await client.delete(self._fix(path), **kwargs)
async def request(self, method, path, **kwargs): async def request(self, method, path, **kwargs):
await self.ds.invoke_startup()
avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None) avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None)
async with httpx.AsyncClient(app=self.app) as client: async with httpx.AsyncClient(app=self.app) as client:
return await client.request( return await client.request(

Wyświetl plik

@ -156,35 +156,6 @@ class Request:
return cls(scope, None) return cls(scope, None)
class AsgiLifespan:
def __init__(self, app, on_startup=None, on_shutdown=None):
self.app = app
on_startup = on_startup or []
on_shutdown = on_shutdown or []
if not isinstance(on_startup or [], list):
on_startup = [on_startup]
if not isinstance(on_shutdown or [], list):
on_shutdown = [on_shutdown]
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
for fn in self.on_startup:
await fn()
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.on_shutdown:
await fn()
await send({"type": "lifespan.shutdown.complete"})
return
else:
await self.app(scope, receive, send)
class AsgiStream: class AsgiStream:
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
self.stream_fn = stream_fn self.stream_fn = stream_fn
@ -449,3 +420,18 @@ class AsgiFileDownload:
content_type=self.content_type, content_type=self.content_type,
headers=self.headers, headers=self.headers,
) )
class AsgiRunOnFirstRequest:
def __init__(self, asgi, on_startup):
assert isinstance(on_startup, list)
self.asgi = asgi
self.on_startup = on_startup
self._started = False
async def __call__(self, scope, receive, send):
if not self._started:
self._started = True
for hook in self.on_startup:
await hook()
return await self.asgi(scope, receive, send)

Wyświetl plik

@ -902,13 +902,14 @@ Potential use-cases:
.. note:: .. note::
If you are writing :ref:`unit tests <testing_plugins>` for a plugin that uses this hook you will need to explicitly call ``await ds.invoke_startup()`` in your tests. An example: If you are writing :ref:`unit tests <testing_plugins>` for a plugin that uses this hook and doesn't exercise Datasette by sending
any simulated requests through it you will need to explicitly call ``await ds.invoke_startup()`` in your tests. An example:
.. code-block:: python .. code-block:: python
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_my_plugin(): async def test_my_plugin():
ds = Datasette([], metadata={}) ds = Datasette()
await ds.invoke_startup() await ds.invoke_startup()
# Rest of test goes here # Rest of test goes here

Wyświetl plik

@ -80,7 +80,7 @@ Creating a ``Datasette()`` instance like this as useful shortcut in tests, but t
This method registers any :ref:`plugin_hook_startup` or :ref:`plugin_hook_prepare_jinja2_environment` plugins that might themselves need to make async calls. This method registers any :ref:`plugin_hook_startup` or :ref:`plugin_hook_prepare_jinja2_environment` plugins that might themselves need to make async calls.
If you are using ``await datasette.client.get()`` and similar methods then you don't need to worry about this - those method calls ensure that ``.invoke_startup()`` has been called for you. If you are using ``await datasette.client.get()`` and similar methods then you don't need to worry about this - Datasette automatically calls ``invoke_startup()`` the first time it handles a request.
.. _testing_plugins_pdb: .. _testing_plugins_pdb: