From da53e0360da4771ffb56a8e3eb3f7476f3168299 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 4 Feb 2022 21:19:49 -0800 Subject: [PATCH] tracer.trace_child_tasks() for asyncio.gather tracing Also added documentation for datasette.tracer module. Closes #1576 --- datasette/tracer.py | 20 +++++++---- docs/internals.rst | 71 ++++++++++++++++++++++++++++++++++++++ tests/plugins/my_plugin.py | 12 +++++++ tests/test_tracer.py | 15 ++++++++ 4 files changed, 111 insertions(+), 7 deletions(-) diff --git a/datasette/tracer.py b/datasette/tracer.py index 6703f060..fc7338b0 100644 --- a/datasette/tracer.py +++ b/datasette/tracer.py @@ -1,5 +1,6 @@ import asyncio from contextlib import contextmanager +from contextvars import ContextVar from markupsafe import escape import time import json @@ -9,20 +10,25 @@ tracers = {} TRACE_RESERVED_KEYS = {"type", "start", "end", "duration_ms", "traceback"} - -# asyncio.current_task was introduced in Python 3.7: -for obj in (asyncio, asyncio.Task): - current_task = getattr(obj, "current_task", None) - if current_task is not None: - break +trace_task_id = ContextVar("trace_task_id", default=None) def get_task_id(): + current = trace_task_id.get(None) + if current is not None: + return current try: loop = asyncio.get_event_loop() except RuntimeError: return None - return id(current_task(loop=loop)) + return id(asyncio.current_task(loop=loop)) + + +@contextmanager +def trace_child_tasks(): + token = trace_task_id.set(get_task_id()) + yield + trace_task_id.reset(token) @contextmanager diff --git a/docs/internals.rst b/docs/internals.rst index 6a5666fd..a5dbdfb4 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -864,3 +864,74 @@ parse_metadata(content) This function accepts a string containing either JSON or YAML, expected to be of the format described in :ref:`metadata`. It returns a nested Python dictionary representing the parsed data from that string. 
If the metadata cannot be parsed as either JSON or YAML the function will raise a ``utils.BadMetadataError`` exception. + +.. _internals_tracer: + +datasette.tracer +================ + +Running Datasette with ``--setting trace_debug 1`` enables trace debug output, which can then be viewed by adding ``?_trace=1`` to the query string for any page. + +You can see an example of this at the bottom of `latest.datasette.io/fixtures/facetable?_trace=1 <https://latest.datasette.io/fixtures/facetable?_trace=1>`__. The JSON output shows full details of every SQL query that was executed to generate the page. + +The `datasette-pretty-traces <https://datasette.io/plugins/datasette-pretty-traces>`__ plugin can be installed to provide a more readable display of this information. You can see `a demo of that here <https://latest-with-plugins.datasette.io/github/commits?_trace=1>`__. + +You can add your own custom traces to the JSON output using the ``trace()`` context manager. This takes a string that identifies the type of trace being recorded, and records any keyword arguments as additional JSON keys on the resulting trace object. + +The start and end time, duration and a traceback of where the trace was executed will be automatically attached to the JSON object. + +This example uses ``trace()`` to record the start, end and duration of any HTTP GET requests made using the ``fetch_url()`` function: + +.. code-block:: python + + from datasette.tracer import trace + import httpx + + async def fetch_url(url): + with trace("fetch-url", url=url): + async with httpx.AsyncClient() as client: + return await client.get(url) + +.. _internals_tracer_trace_child_tasks: + +Tracing child tasks +------------------- + +If your code uses a mechanism such as ``asyncio.gather()`` to execute code in additional tasks you may find that some of the traces are missing from the display. + +You can use the ``trace_child_tasks()`` context manager to ensure these child tasks are correctly handled. + +.. code-block:: python + + from datasette import tracer + + with tracer.trace_child_tasks(): + results = await asyncio.gather( + # ... 
async tasks here + ) + +This example uses the :ref:`register_routes() <plugin_register_routes>` plugin hook to add a page at ``/parallel-queries`` which executes two SQL queries in parallel using ``asyncio.gather()`` and returns their results. + +.. code-block:: python + + from datasette import hookimpl + from datasette import tracer + + @hookimpl + def register_routes(): + + async def parallel_queries(datasette): + db = datasette.get_database() + with tracer.trace_child_tasks(): + one, two = await asyncio.gather( + db.execute("select 1"), + db.execute("select 2"), + ) + return Response.json({"one": one.single_value(), "two": two.single_value()}) + + return [ + (r"/parallel-queries$", parallel_queries), + ] + + +Adding ``?_trace=1`` to the URL of that page will show that the trace covers both of those child tasks. diff --git a/tests/plugins/my_plugin.py b/tests/plugins/my_plugin.py index 75c76ea8..610cea17 100644 --- a/tests/plugins/my_plugin.py +++ b/tests/plugins/my_plugin.py @@ -1,5 +1,7 @@ +import asyncio from datasette import hookimpl from datasette.facets import Facet +from datasette import tracer from datasette.utils import path_with_added_args from datasette.utils.asgi import asgi_send_json, Response import base64 @@ -270,6 +272,15 @@ def register_routes(): def asgi_scope(scope): return Response.json(scope, default=repr) + async def parallel_queries(datasette): + db = datasette.get_database() + with tracer.trace_child_tasks(): + one, two = await asyncio.gather( + db.execute("select coalesce(sleep(0.1), 1)"), + db.execute("select coalesce(sleep(0.1), 2)"), + ) + return Response.json({"one": one.single_value(), "two": two.single_value()}) + return [ (r"/one/$", one), (r"/two/(?P<name>.*)$", two), @@ -281,6 +292,7 @@ def register_routes(): (r"/add-message/$", add_message), (r"/render-message/$", render_message), (r"/asgi-scope$", asgi_scope), + (r"/parallel-queries$", parallel_queries), ] diff --git a/tests/test_tracer.py b/tests/test_tracer.py index 20a4427e..ceadee50 100644 --- a/tests/test_tracer.py +++ 
b/tests/test_tracer.py @@ -51,3 +51,18 @@ def test_trace(trace_debug): execute_manys = [trace for trace in traces if trace.get("executemany")] assert execute_manys assert all(isinstance(trace["count"], int) for trace in execute_manys) + + +def test_trace_parallel_queries(): + with make_app_client(settings={"trace_debug": True}) as client: + response = client.get("/parallel-queries?_trace=1") + assert response.status == 200 + + data = response.json + assert data["one"] == 1 + assert data["two"] == 2 + trace_info = data["_trace"] + traces = [trace for trace in trace_info["traces"] if "sql" in trace] + one, two = traces + # "two" should have started before "one" ended + assert two["start"] < one["end"]