Use sqlite.interrupt() instead of sqlite_timelimit() - refs #1270

pull/1271/head
Simon Willison 2021-03-22 10:33:11 -07:00
rodzic c4f1ec7f33
commit fb2ad7ada0
1 zmienionych plików z 44 dodań i 33 usunięć

Wyświetl plik

@ -143,7 +143,7 @@ class Database:
result = e result = e
task.reply_queue.sync_q.put(result) task.reply_queue.sync_q.put(result)
async def execute_fn(self, fn): async def execute_fn(self, fn, time_limit=None):
def in_thread(): def in_thread():
conn = getattr(connections, self.name, None) conn = getattr(connections, self.name, None)
if not conn: if not conn:
@ -152,9 +152,17 @@ class Database:
setattr(connections, self.name, conn) setattr(connections, self.name, conn)
return fn(conn) return fn(conn)
return await asyncio.get_event_loop().run_in_executor( executor = asyncio.get_event_loop().run_in_executor(self.ds.executor, in_thread)
self.ds.executor, in_thread try:
) return await asyncio.wait_for(
executor,
timeout=(time_limit / 1000.0) if time_limit is not None else None,
)
except asyncio.TimeoutError:
conn = getattr(connections, self.name, None)
if conn:
conn.interrupt()
raise QueryInterrupted()
async def execute( async def execute(
self, self,
@ -168,36 +176,37 @@ class Database:
"""Executes sql against db_name in a thread""" """Executes sql against db_name in a thread"""
page_size = page_size or self.ds.page_size page_size = page_size or self.ds.page_size
def sql_operation_in_thread(conn): time_limit_ms = self.ds.sql_time_limit_ms
time_limit_ms = self.ds.sql_time_limit_ms if custom_time_limit and custom_time_limit < time_limit_ms:
if custom_time_limit and custom_time_limit < time_limit_ms: time_limit_ms = custom_time_limit
time_limit_ms = custom_time_limit
with sqlite_timelimit(conn, time_limit_ms): def sql_operation_in_thread(conn):
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(sql, params if params is not None else {}) with open("/tmp/sql.log", "ab", buffering=0) as fp:
max_returned_rows = self.ds.max_returned_rows fp.write(("{}: {}\n".format(sql, params)).encode("utf-8"))
if max_returned_rows == page_size: cursor.execute(sql, params if params is not None else {})
max_returned_rows += 1 max_returned_rows = self.ds.max_returned_rows
if max_returned_rows and truncate: if max_returned_rows == page_size:
rows = cursor.fetchmany(max_returned_rows + 1) max_returned_rows += 1
truncated = len(rows) > max_returned_rows if max_returned_rows and truncate:
rows = rows[:max_returned_rows] rows = cursor.fetchmany(max_returned_rows + 1)
else: truncated = len(rows) > max_returned_rows
rows = cursor.fetchall() rows = rows[:max_returned_rows]
truncated = False else:
except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: rows = cursor.fetchall()
if e.args == ("interrupted",): truncated = False
raise QueryInterrupted(e, sql, params) except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
if log_sql_errors: if e.args == ("interrupted",):
sys.stderr.write( raise QueryInterrupted(e, sql, params)
"ERROR: conn={}, sql = {}, params = {}: {}\n".format( if log_sql_errors:
conn, repr(sql), params, e sys.stderr.write(
) "ERROR: conn={}, sql = {}, params = {}: {}\n".format(
conn, repr(sql), params, e
) )
sys.stderr.flush() )
raise sys.stderr.flush()
raise
if truncate: if truncate:
return Results(rows, truncated, cursor.description) return Results(rows, truncated, cursor.description)
@ -206,7 +215,9 @@ class Database:
return Results(rows, False, cursor.description) return Results(rows, False, cursor.description)
with trace("sql", database=self.name, sql=sql.strip(), params=params): with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_fn(sql_operation_in_thread) results = await self.execute_fn(
sql_operation_in_thread, time_limit=time_limit_ms
)
return results return results
@property @property