diff --git a/datasette/app.py b/datasette/app.py index 09a47bc5..09f281e3 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -69,6 +69,7 @@ from .utils import ( row_sql_params_pks, ) from .utils.asgi import ( + AsgiLifespan, Forbidden, NotFound, DatabaseNotFound, @@ -1431,6 +1432,7 @@ class Datasette: ) if self.setting("trace_debug"): asgi = AsgiTracer(asgi) + asgi = AsgiLifespan(asgi) asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) for wrapper in pm.hook.asgi_wrapper(datasette=self): asgi = wrapper(asgi) diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py index 56690251..b2c6f3ab 100644 --- a/datasette/utils/asgi.py +++ b/datasette/utils/asgi.py @@ -156,6 +156,35 @@ class Request: return cls(scope, None) +class AsgiLifespan: + def __init__(self, app, on_startup=None, on_shutdown=None): + self.app = app + on_startup = on_startup or [] + on_shutdown = on_shutdown or [] + if not isinstance(on_startup or [], list): + on_startup = [on_startup] + if not isinstance(on_shutdown or [], list): + on_shutdown = [on_shutdown] + self.on_startup = on_startup + self.on_shutdown = on_shutdown + + async def __call__(self, scope, receive, send): + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + for fn in self.on_startup: + await fn() + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + for fn in self.on_shutdown: + await fn() + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await self.app(scope, receive, send) + + class AsgiStream: def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): self.stream_fn = stream_fn diff --git a/tests/conftest.py b/tests/conftest.py index 44c44f87..69dee68b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,17 @@ UNDOCUMENTED_PERMISSIONS = { _ds_client = None +def wait_until_responds(url, timeout=5.0, client=httpx, **kwargs): + start = time.time() + while time.time() - start < timeout: + try: + client.get(url, **kwargs) + return + except httpx.ConnectError: + time.sleep(0.1) + raise AssertionError("Timed out waiting for {} to respond".format(url)) + + @pytest_asyncio.fixture async def ds_client(): from datasette.app import Datasette @@ -161,13 +172,7 @@ def ds_localhost_http_server(): # Avoid FileNotFoundError: [Errno 2] No such file or directory: cwd=tempfile.gettempdir(), ) - # Loop until port 8041 serves traffic - while True: - try: - httpx.get("http://localhost:8041/") - break - except httpx.ConnectError: - time.sleep(0.1) + wait_until_responds("http://localhost:8041/") # Check it started successfully assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") yield ds_proc @@ -202,12 +207,7 @@ def ds_localhost_https_server(tmp_path_factory): stderr=subprocess.STDOUT, cwd=tempfile.gettempdir(), ) - while True: - try: - httpx.get("https://localhost:8042/", verify=client_cert) - break - except httpx.ConnectError: - time.sleep(0.1) + wait_until_responds("http://localhost:8042/", verify=client_cert) # Check it started successfully assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") yield ds_proc, client_cert @@ -231,12 +231,7 @@ def ds_unix_domain_socket_server(tmp_path_factory): # Poll until available transport = httpx.HTTPTransport(uds=uds) client = httpx.Client(transport=transport) - while True: - try: - client.get("http://localhost/_memory.json") - break - except httpx.ConnectError: - time.sleep(0.1) + wait_until_responds("http://localhost/_memory.json", client=client) # Check it started successfully assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") yield ds_proc, uds