datasette/datasette/tracer.py

145 wiersze
4.6 KiB
Python

import asyncio
from contextlib import contextmanager
from contextvars import ContextVar
from markupsafe import escape
import time
import json
import traceback
tracers = {}
TRACE_RESERVED_KEYS = {"type", "start", "end", "duration_ms", "traceback"}
trace_task_id = ContextVar("trace_task_id", default=None)
def get_task_id():
current = trace_task_id.get(None)
if current is not None:
return current
try:
loop = asyncio.get_event_loop()
except RuntimeError:
return None
return id(asyncio.current_task(loop=loop))
@contextmanager
def trace_child_tasks():
token = trace_task_id.set(get_task_id())
yield
trace_task_id.reset(token)
@contextmanager
def trace(type, **kwargs):
assert not TRACE_RESERVED_KEYS.intersection(
kwargs.keys()
), f".trace() keyword parameters cannot include {TRACE_RESERVED_KEYS}"
task_id = get_task_id()
if task_id is None:
yield kwargs
return
tracer = tracers.get(task_id)
if tracer is None:
yield kwargs
return
start = time.perf_counter()
yield kwargs
end = time.perf_counter()
trace_info = {
"type": type,
"start": start,
"end": end,
"duration_ms": (end - start) * 1000,
"traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
}
trace_info.update(kwargs)
tracer.append(trace_info)
@contextmanager
def capture_traces(tracer):
# tracer is a list
task_id = get_task_id()
if task_id is None:
yield
return
tracers[task_id] = tracer
yield
del tracers[task_id]
class AsgiTracer:
# If the body is larger than this we don't attempt to append the trace
max_body_bytes = 1024 * 256 # 256 KB
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if b"_trace=1" not in scope.get("query_string", b"").split(b"&"):
await self.app(scope, receive, send)
return
trace_start = time.perf_counter()
traces = []
accumulated_body = b""
size_limit_exceeded = False
response_headers = []
async def wrapped_send(message):
nonlocal accumulated_body, size_limit_exceeded, response_headers
if message["type"] == "http.response.start":
response_headers = message["headers"]
await send(message)
return
if message["type"] != "http.response.body" or size_limit_exceeded:
await send(message)
return
# Accumulate body until the end or until size is exceeded
accumulated_body += message["body"]
if len(accumulated_body) > self.max_body_bytes:
await send(
{
"type": "http.response.body",
"body": accumulated_body,
"more_body": True,
}
)
size_limit_exceeded = True
return
if not message.get("more_body"):
# We have all the body - modify it and send the result
# TODO: What to do about Content-Type or other cases?
trace_info = {
"request_duration_ms": 1000 * (time.perf_counter() - trace_start),
"sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
"num_traces": len(traces),
"traces": traces,
}
try:
content_type = [
v.decode("utf8")
for k, v in response_headers
if k.lower() == b"content-type"
][0]
except IndexError:
content_type = ""
if "text/html" in content_type and b"</body>" in accumulated_body:
extra = escape(json.dumps(trace_info, indent=2))
extra_html = f"<pre>{extra}</pre></body>".encode("utf8")
accumulated_body = accumulated_body.replace(b"</body>", extra_html)
elif "json" in content_type and accumulated_body.startswith(b"{"):
data = json.loads(accumulated_body.decode("utf8"))
if "_trace" not in data:
data["_trace"] = trace_info
accumulated_body = json.dumps(data).encode("utf8")
await send({"type": "http.response.body", "body": accumulated_body})
with capture_traces(traces):
await self.app(scope, receive, wrapped_send)