Use sqlite.interrupt() instead of sqlite_timelimit() - refs #1270

pull/1271/head
Simon Willison 2021-03-22 10:33:11 -07:00
parent c4f1ec7f33
commit fb2ad7ada0
1 changed file with 44 additions and 33 deletions
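The diff below drops the sqlite_timelimit() progress-handler approach and enforces the limit on the awaiting side instead: execute_fn() gains a time_limit argument (milliseconds), the executor future is wrapped in asyncio.wait_for(), and when the timeout fires the thread-local connection's interrupt() method is called so SQLite aborts the running statement. A minimal standalone sketch of that pattern, for context only (the run_with_limit() helper, the module-level connection and the demo query are illustrative assumptions, not Datasette code):

    # Sketch: bound a blocking SQLite query with asyncio.wait_for + Connection.interrupt()
    import asyncio
    import sqlite3
    from concurrent.futures import ThreadPoolExecutor


    class QueryInterrupted(Exception):
        pass


    executor = ThreadPoolExecutor(max_workers=1)
    # check_same_thread=False so the executor thread can run queries on this connection
    conn = sqlite3.connect(":memory:", check_same_thread=False)


    def slow_query(connection):
        # Recursive CTE that counts to ten million, slow enough to hit the timeout
        cursor = connection.cursor()
        cursor.execute(
            "WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x + 1 FROM c WHERE x < 10000000) "
            "SELECT count(*) FROM c"
        )
        return cursor.fetchall()


    async def run_with_limit(fn, time_limit=None):
        # time_limit is in milliseconds, matching execute_fn() in the diff below
        future = asyncio.get_event_loop().run_in_executor(executor, fn, conn)
        try:
            return await asyncio.wait_for(
                future,
                timeout=(time_limit / 1000.0) if time_limit is not None else None,
            )
        except asyncio.TimeoutError:
            # interrupt() may be called from another thread; the statement running in
            # the executor thread then fails with sqlite3.OperationalError: interrupted
            conn.interrupt()
            raise QueryInterrupted()


    async def main():
        try:
            await run_with_limit(slow_query, time_limit=100)
        except QueryInterrupted:
            print("query exceeded 100ms and was interrupted")


    asyncio.run(main())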

@@ -143,7 +143,7 @@ class Database:
                     result = e
                 task.reply_queue.sync_q.put(result)

-    async def execute_fn(self, fn):
+    async def execute_fn(self, fn, time_limit=None):
         def in_thread():
             conn = getattr(connections, self.name, None)
             if not conn:
@@ -152,9 +152,17 @@ class Database:
                 setattr(connections, self.name, conn)
             return fn(conn)

-        return await asyncio.get_event_loop().run_in_executor(
-            self.ds.executor, in_thread
-        )
+        executor = asyncio.get_event_loop().run_in_executor(self.ds.executor, in_thread)
+        try:
+            return await asyncio.wait_for(
+                executor,
+                timeout=(time_limit / 1000.0) if time_limit is not None else None,
+            )
+        except asyncio.TimeoutError:
+            conn = getattr(connections, self.name, None)
+            if conn:
+                conn.interrupt()
+            raise QueryInterrupted()

     async def execute(
         self,
@@ -168,36 +176,37 @@ class Database:
         """Executes sql against db_name in a thread"""
         page_size = page_size or self.ds.page_size

-        def sql_operation_in_thread(conn):
-            time_limit_ms = self.ds.sql_time_limit_ms
-            if custom_time_limit and custom_time_limit < time_limit_ms:
-                time_limit_ms = custom_time_limit
+        time_limit_ms = self.ds.sql_time_limit_ms
+        if custom_time_limit and custom_time_limit < time_limit_ms:
+            time_limit_ms = custom_time_limit

-            with sqlite_timelimit(conn, time_limit_ms):
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute(sql, params if params is not None else {})
-                    max_returned_rows = self.ds.max_returned_rows
-                    if max_returned_rows == page_size:
-                        max_returned_rows += 1
-                    if max_returned_rows and truncate:
-                        rows = cursor.fetchmany(max_returned_rows + 1)
-                        truncated = len(rows) > max_returned_rows
-                        rows = rows[:max_returned_rows]
-                    else:
-                        rows = cursor.fetchall()
-                        truncated = False
-                except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
-                    if e.args == ("interrupted",):
-                        raise QueryInterrupted(e, sql, params)
-                    if log_sql_errors:
-                        sys.stderr.write(
-                            "ERROR: conn={}, sql = {}, params = {}: {}\n".format(
-                                conn, repr(sql), params, e
-                            )
+        def sql_operation_in_thread(conn):
+            try:
+                cursor = conn.cursor()
+                with open("/tmp/sql.log", "ab", buffering=0) as fp:
+                    fp.write(("{}: {}\n".format(sql, params)).encode("utf-8"))
+                cursor.execute(sql, params if params is not None else {})
+                max_returned_rows = self.ds.max_returned_rows
+                if max_returned_rows == page_size:
+                    max_returned_rows += 1
+                if max_returned_rows and truncate:
+                    rows = cursor.fetchmany(max_returned_rows + 1)
+                    truncated = len(rows) > max_returned_rows
+                    rows = rows[:max_returned_rows]
+                else:
+                    rows = cursor.fetchall()
+                    truncated = False
+            except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
+                if e.args == ("interrupted",):
+                    raise QueryInterrupted(e, sql, params)
+                if log_sql_errors:
+                    sys.stderr.write(
+                        "ERROR: conn={}, sql = {}, params = {}: {}\n".format(
+                            conn, repr(sql), params, e
                         )
-                        sys.stderr.flush()
-                    raise
+                    )
+                    sys.stderr.flush()
+                raise

             if truncate:
                 return Results(rows, truncated, cursor.description)
@@ -206,7 +215,9 @@ class Database:
                 return Results(rows, False, cursor.description)

         with trace("sql", database=self.name, sql=sql.strip(), params=params):
-            results = await self.execute_fn(sql_operation_in_thread)
+            results = await self.execute_fn(
+                sql_operation_in_thread, time_limit=time_limit_ms
+            )
         return results

     @property
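For reference, the e.args == ("interrupted",) check retained in the except branch above works because SQLite reports an interrupted statement as an OperationalError whose sole argument is the string "interrupted", which the code re-raises as QueryInterrupted. A quick standalone illustration of that behaviour (the threading.Timer and the demo query are assumptions for the demo, not part of the commit):

    import sqlite3
    import threading

    conn = sqlite3.connect(":memory:", check_same_thread=False)

    # Abort whatever is running on this connection after 0.1 seconds
    threading.Timer(0.1, conn.interrupt).start()

    try:
        conn.execute(
            "WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x + 1 FROM c WHERE x < 10000000) "
            "SELECT count(*) FROM c"
        ).fetchall()
    except sqlite3.OperationalError as e:
        print(e.args)  # ('interrupted',) -- what execute() maps to QueryInterrupted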