diff --git a/datasette/database.py b/datasette/database.py index 3579cce9..20b6fd06 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -143,7 +143,7 @@ class Database: result = e task.reply_queue.sync_q.put(result) - async def execute_fn(self, fn): + async def execute_fn(self, fn, time_limit=None): def in_thread(): conn = getattr(connections, self.name, None) if not conn: @@ -152,9 +152,17 @@ class Database: setattr(connections, self.name, conn) return fn(conn) - return await asyncio.get_event_loop().run_in_executor( - self.ds.executor, in_thread - ) + executor = asyncio.get_event_loop().run_in_executor(self.ds.executor, in_thread) + try: + return await asyncio.wait_for( + executor, + timeout=(time_limit / 1000.0) if time_limit is not None else None, + ) + except asyncio.TimeoutError: + conn = getattr(connections, self.name, None) + if conn: + conn.interrupt() + raise QueryInterrupted() async def execute( self, @@ -168,36 +176,37 @@ class Database: """Executes sql against db_name in a thread""" page_size = page_size or self.ds.page_size - def sql_operation_in_thread(conn): - time_limit_ms = self.ds.sql_time_limit_ms - if custom_time_limit and custom_time_limit < time_limit_ms: - time_limit_ms = custom_time_limit + time_limit_ms = self.ds.sql_time_limit_ms + if custom_time_limit and custom_time_limit < time_limit_ms: + time_limit_ms = custom_time_limit - with sqlite_timelimit(conn, time_limit_ms): - try: - cursor = conn.cursor() - cursor.execute(sql, params if params is not None else {}) - max_returned_rows = self.ds.max_returned_rows - if max_returned_rows == page_size: - max_returned_rows += 1 - if max_returned_rows and truncate: - rows = cursor.fetchmany(max_returned_rows + 1) - truncated = len(rows) > max_returned_rows - rows = rows[:max_returned_rows] - else: - rows = cursor.fetchall() - truncated = False - except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: - if e.args == ("interrupted",): - raise QueryInterrupted(e, sql, params) - if log_sql_errors: - sys.stderr.write( - "ERROR: conn={}, sql = {}, params = {}: {}\n".format( - conn, repr(sql), params, e - ) + def sql_operation_in_thread(conn): + try: + cursor = conn.cursor() + with open("/tmp/sql.log", "ab", buffering=0) as fp: + fp.write(("{}: {}\n".format(sql, params)).encode("utf-8")) + cursor.execute(sql, params if params is not None else {}) + max_returned_rows = self.ds.max_returned_rows + if max_returned_rows == page_size: + max_returned_rows += 1 + if max_returned_rows and truncate: + rows = cursor.fetchmany(max_returned_rows + 1) + truncated = len(rows) > max_returned_rows + rows = rows[:max_returned_rows] + else: + rows = cursor.fetchall() + truncated = False + except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: + if e.args == ("interrupted",): + raise QueryInterrupted(e, sql, params) + if log_sql_errors: + sys.stderr.write( + "ERROR: conn={}, sql = {}, params = {}: {}\n".format( + conn, repr(sql), params, e ) - sys.stderr.flush() - raise + ) + sys.stderr.flush() + raise if truncate: return Results(rows, truncated, cursor.description) @@ -206,7 +215,9 @@ class Database: return Results(rows, False, cursor.description) with trace("sql", database=self.name, sql=sql.strip(), params=params): - results = await self.execute_fn(sql_operation_in_thread) + results = await self.execute_fn( + sql_operation_in_thread, time_limit=time_limit_ms + ) return results @property