kopia lustrzana https://github.com/simonw/datasette
Put AsgiLifestyle back so server starts up again, refs #1955
rodzic
96b3a86d7f
commit
5649e547ef
|
@ -63,6 +63,7 @@ from .utils import (
|
||||||
to_css_class,
|
to_css_class,
|
||||||
)
|
)
|
||||||
from .utils.asgi import (
|
from .utils.asgi import (
|
||||||
|
AsgiLifespan,
|
||||||
Forbidden,
|
Forbidden,
|
||||||
NotFound,
|
NotFound,
|
||||||
Request,
|
Request,
|
||||||
|
@ -1271,6 +1272,7 @@ class Datasette:
|
||||||
)
|
)
|
||||||
if self.setting("trace_debug"):
|
if self.setting("trace_debug"):
|
||||||
asgi = AsgiTracer(asgi)
|
asgi = AsgiTracer(asgi)
|
||||||
|
asgi = AsgiLifespan(asgi)
|
||||||
asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup])
|
asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup])
|
||||||
for wrapper in pm.hook.asgi_wrapper(datasette=self):
|
for wrapper in pm.hook.asgi_wrapper(datasette=self):
|
||||||
asgi = wrapper(asgi)
|
asgi = wrapper(asgi)
|
||||||
|
|
|
@ -135,6 +135,35 @@ class Request:
|
||||||
return cls(scope, None)
|
return cls(scope, None)
|
||||||
|
|
||||||
|
|
||||||
|
class AsgiLifespan:
|
||||||
|
def __init__(self, app, on_startup=None, on_shutdown=None):
|
||||||
|
self.app = app
|
||||||
|
on_startup = on_startup or []
|
||||||
|
on_shutdown = on_shutdown or []
|
||||||
|
if not isinstance(on_startup or [], list):
|
||||||
|
on_startup = [on_startup]
|
||||||
|
if not isinstance(on_shutdown or [], list):
|
||||||
|
on_shutdown = [on_shutdown]
|
||||||
|
self.on_startup = on_startup
|
||||||
|
self.on_shutdown = on_shutdown
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
if scope["type"] == "lifespan":
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
for fn in self.on_startup:
|
||||||
|
await fn()
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
elif message["type"] == "lifespan.shutdown":
|
||||||
|
for fn in self.on_shutdown:
|
||||||
|
await fn()
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class AsgiStream:
|
class AsgiStream:
|
||||||
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
|
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
|
||||||
self.stream_fn = stream_fn
|
self.stream_fn = stream_fn
|
||||||
|
|
|
@ -23,6 +23,17 @@ UNDOCUMENTED_PERMISSIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def wait_until_responds(url, timeout=5.0, client=httpx, **kwargs):
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
try:
|
||||||
|
client.get(url, **kwargs)
|
||||||
|
return
|
||||||
|
except httpx.ConnectError:
|
||||||
|
time.sleep(0.1)
|
||||||
|
raise AssertionError("Timed out waiting for {} to respond".format(url))
|
||||||
|
|
||||||
|
|
||||||
def pytest_report_header(config):
|
def pytest_report_header(config):
|
||||||
return "SQLite: {}".format(
|
return "SQLite: {}".format(
|
||||||
sqlite3.connect(":memory:").execute("select sqlite_version()").fetchone()[0]
|
sqlite3.connect(":memory:").execute("select sqlite_version()").fetchone()[0]
|
||||||
|
@ -111,13 +122,7 @@ def ds_localhost_http_server():
|
||||||
# Avoid FileNotFoundError: [Errno 2] No such file or directory:
|
# Avoid FileNotFoundError: [Errno 2] No such file or directory:
|
||||||
cwd=tempfile.gettempdir(),
|
cwd=tempfile.gettempdir(),
|
||||||
)
|
)
|
||||||
# Loop until port 8041 serves traffic
|
wait_until_responds("http://localhost:8041/")
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
httpx.get("http://localhost:8041/")
|
|
||||||
break
|
|
||||||
except httpx.ConnectError:
|
|
||||||
time.sleep(0.1)
|
|
||||||
# Check it started successfully
|
# Check it started successfully
|
||||||
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
|
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
|
||||||
yield ds_proc
|
yield ds_proc
|
||||||
|
@ -152,12 +157,7 @@ def ds_localhost_https_server(tmp_path_factory):
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
cwd=tempfile.gettempdir(),
|
cwd=tempfile.gettempdir(),
|
||||||
)
|
)
|
||||||
while True:
|
wait_until_responds("http://localhost:8042/", verify=client_cert)
|
||||||
try:
|
|
||||||
httpx.get("https://localhost:8042/", verify=client_cert)
|
|
||||||
break
|
|
||||||
except httpx.ConnectError:
|
|
||||||
time.sleep(0.1)
|
|
||||||
# Check it started successfully
|
# Check it started successfully
|
||||||
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
|
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
|
||||||
yield ds_proc, client_cert
|
yield ds_proc, client_cert
|
||||||
|
@ -181,12 +181,7 @@ def ds_unix_domain_socket_server(tmp_path_factory):
|
||||||
# Poll until available
|
# Poll until available
|
||||||
transport = httpx.HTTPTransport(uds=uds)
|
transport = httpx.HTTPTransport(uds=uds)
|
||||||
client = httpx.Client(transport=transport)
|
client = httpx.Client(transport=transport)
|
||||||
while True:
|
wait_until_responds("http://localhost/_memory.json", client=client)
|
||||||
try:
|
|
||||||
client.get("http://localhost/_memory.json")
|
|
||||||
break
|
|
||||||
except httpx.ConnectError:
|
|
||||||
time.sleep(0.1)
|
|
||||||
# Check it started successfully
|
# Check it started successfully
|
||||||
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
|
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
|
||||||
yield ds_proc, uds
|
yield ds_proc, uds
|
||||||
|
|
Ładowanie…
Reference in New Issue