Put AsgiLifestyle back so server starts up again, refs #1955

pull/1965/head
Simon Willison 2022-12-17 17:22:00 -08:00
rodzic 63fb750f39
commit 8b73fc6b47
3 zmienionych plików z 45 dodań i 19 usunięć

Wyświetl plik

@ -69,6 +69,7 @@ from .utils import (
row_sql_params_pks, row_sql_params_pks,
) )
from .utils.asgi import ( from .utils.asgi import (
AsgiLifespan,
Forbidden, Forbidden,
NotFound, NotFound,
DatabaseNotFound, DatabaseNotFound,
@ -1431,6 +1432,7 @@ class Datasette:
) )
if self.setting("trace_debug"): if self.setting("trace_debug"):
asgi = AsgiTracer(asgi) asgi = AsgiTracer(asgi)
asgi = AsgiLifespan(asgi)
asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup])
for wrapper in pm.hook.asgi_wrapper(datasette=self): for wrapper in pm.hook.asgi_wrapper(datasette=self):
asgi = wrapper(asgi) asgi = wrapper(asgi)

Wyświetl plik

@ -156,6 +156,35 @@ class Request:
return cls(scope, None) return cls(scope, None)
class AsgiLifespan:
def __init__(self, app, on_startup=None, on_shutdown=None):
self.app = app
on_startup = on_startup or []
on_shutdown = on_shutdown or []
if not isinstance(on_startup or [], list):
on_startup = [on_startup]
if not isinstance(on_shutdown or [], list):
on_shutdown = [on_shutdown]
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
for fn in self.on_startup:
await fn()
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.on_shutdown:
await fn()
await send({"type": "lifespan.shutdown.complete"})
return
else:
await self.app(scope, receive, send)
class AsgiStream: class AsgiStream:
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
self.stream_fn = stream_fn self.stream_fn = stream_fn

Wyświetl plik

@ -27,6 +27,17 @@ UNDOCUMENTED_PERMISSIONS = {
_ds_client = None _ds_client = None
def wait_until_responds(url, timeout=5.0, client=httpx, **kwargs):
start = time.time()
while time.time() - start < timeout:
try:
client.get(url, **kwargs)
return
except httpx.ConnectError:
time.sleep(0.1)
raise AssertionError("Timed out waiting for {} to respond".format(url))
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def ds_client(): async def ds_client():
from datasette.app import Datasette from datasette.app import Datasette
@ -161,13 +172,7 @@ def ds_localhost_http_server():
# Avoid FileNotFoundError: [Errno 2] No such file or directory: # Avoid FileNotFoundError: [Errno 2] No such file or directory:
cwd=tempfile.gettempdir(), cwd=tempfile.gettempdir(),
) )
# Loop until port 8041 serves traffic wait_until_responds("http://localhost:8041/")
while True:
try:
httpx.get("http://localhost:8041/")
break
except httpx.ConnectError:
time.sleep(0.1)
# Check it started successfully # Check it started successfully
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
yield ds_proc yield ds_proc
@ -202,12 +207,7 @@ def ds_localhost_https_server(tmp_path_factory):
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
cwd=tempfile.gettempdir(), cwd=tempfile.gettempdir(),
) )
while True: wait_until_responds("http://localhost:8042/", verify=client_cert)
try:
httpx.get("https://localhost:8042/", verify=client_cert)
break
except httpx.ConnectError:
time.sleep(0.1)
# Check it started successfully # Check it started successfully
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
yield ds_proc, client_cert yield ds_proc, client_cert
@ -231,12 +231,7 @@ def ds_unix_domain_socket_server(tmp_path_factory):
# Poll until available # Poll until available
transport = httpx.HTTPTransport(uds=uds) transport = httpx.HTTPTransport(uds=uds)
client = httpx.Client(transport=transport) client = httpx.Client(transport=transport)
while True: wait_until_responds("http://localhost/_memory.json", client=client)
try:
client.get("http://localhost/_memory.json")
break
except httpx.ConnectError:
time.sleep(0.1)
# Check it started successfully # Check it started successfully
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
yield ds_proc, uds yield ds_proc, uds