database.execute_write_fn(transaction=True) parameter, closes #2277

pull/2096/merge
Simon Willison 2024-02-17 20:28:15 -08:00
rodzic e1c80efff8
commit 5e0e440f2c
3 zmienionych plików z 60 dodań i 13 usunięć

Wyświetl plik

@ -179,17 +179,25 @@ class Database:
# Threaded mode - send to write thread # Threaded mode - send to write thread
return await self._send_to_write_thread(fn, isolated_connection=True) return await self._send_to_write_thread(fn, isolated_connection=True)
async def execute_write_fn(self, fn, block=True): async def execute_write_fn(self, fn, block=True, transaction=True):
if self.ds.executor is None: if self.ds.executor is None:
# non-threaded mode # non-threaded mode
if self._write_connection is None: if self._write_connection is None:
self._write_connection = self.connect(write=True) self._write_connection = self.connect(write=True)
self.ds._prepare_connection(self._write_connection, self.name) self.ds._prepare_connection(self._write_connection, self.name)
return fn(self._write_connection) if transaction:
with self._write_connection:
return fn(self._write_connection)
else:
return fn(self._write_connection)
else: else:
return await self._send_to_write_thread(fn, block) return await self._send_to_write_thread(
fn, block=block, transaction=transaction
)
async def _send_to_write_thread(self, fn, block=True, isolated_connection=False): async def _send_to_write_thread(
self, fn, block=True, isolated_connection=False, transaction=True
):
if self._write_queue is None: if self._write_queue is None:
self._write_queue = queue.Queue() self._write_queue = queue.Queue()
if self._write_thread is None: if self._write_thread is None:
@ -202,7 +210,9 @@ class Database:
self._write_thread.start() self._write_thread.start()
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
reply_queue = janus.Queue() reply_queue = janus.Queue()
self._write_queue.put(WriteTask(fn, task_id, reply_queue, isolated_connection)) self._write_queue.put(
WriteTask(fn, task_id, reply_queue, isolated_connection, transaction)
)
if block: if block:
result = await reply_queue.async_q.get() result = await reply_queue.async_q.get()
if isinstance(result, Exception): if isinstance(result, Exception):
@ -244,7 +254,11 @@ class Database:
pass pass
else: else:
try: try:
result = task.fn(conn) if task.transaction:
with conn:
result = task.fn(conn)
else:
result = task.fn(conn)
except Exception as e: except Exception as e:
sys.stderr.write("{}\n".format(e)) sys.stderr.write("{}\n".format(e))
sys.stderr.flush() sys.stderr.flush()
@ -554,13 +568,14 @@ class Database:
class WriteTask: class WriteTask:
__slots__ = ("fn", "task_id", "reply_queue", "isolated_connection") __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction")
def __init__(self, fn, task_id, reply_queue, isolated_connection): def __init__(self, fn, task_id, reply_queue, isolated_connection, transaction):
self.fn = fn self.fn = fn
self.task_id = task_id self.task_id = task_id
self.reply_queue = reply_queue self.reply_queue = reply_queue
self.isolated_connection = isolated_connection self.isolated_connection = isolated_connection
self.transaction = transaction
class QueryInterrupted(Exception): class QueryInterrupted(Exception):

Wyświetl plik

@ -1010,7 +1010,9 @@ You can pass additional SQL parameters as a tuple or dictionary.
The method will block until the operation is completed, and the return value will be the return from calling ``conn.execute(...)`` using the underlying ``sqlite3`` Python library. The method will block until the operation is completed, and the return value will be the return from calling ``conn.execute(...)`` using the underlying ``sqlite3`` Python library.
If you pass ``block=False`` this behaviour changes to "fire and forget" - queries will be added to the write queue and executed in a separate thread while your code can continue to do other things. The method will return a UUID representing the queued task. If you pass ``block=False`` this behavior changes to "fire and forget" - queries will be added to the write queue and executed in a separate thread while your code can continue to do other things. The method will return a UUID representing the queued task.
Each call to ``execute_write()`` will be executed inside a transaction.
.. _database_execute_write_script: .. _database_execute_write_script:
@ -1019,6 +1021,8 @@ await db.execute_write_script(sql, block=True)
Like ``execute_write()`` but can be used to send multiple SQL statements in a single string separated by semicolons, using the ``sqlite3`` `conn.executescript() <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.executescript>`__ method. Like ``execute_write()`` but can be used to send multiple SQL statements in a single string separated by semicolons, using the ``sqlite3`` `conn.executescript() <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.executescript>`__ method.
Each call to ``execute_write_script()`` will be executed inside a transaction.
.. _database_execute_write_many: .. _database_execute_write_many:
await db.execute_write_many(sql, params_seq, block=True) await db.execute_write_many(sql, params_seq, block=True)
@ -1033,10 +1037,12 @@ Like ``execute_write()`` but uses the ``sqlite3`` `conn.executemany() <https://d
[(1, "Melanie"), (2, "Selma"), (2, "Viktor")], [(1, "Melanie"), (2, "Selma"), (2, "Viktor")],
) )
Each call to ``execute_write_many()`` will be executed inside a transaction.
.. _database_execute_write_fn: .. _database_execute_write_fn:
await db.execute_write_fn(fn, block=True) await db.execute_write_fn(fn, block=True, transaction=True)
----------------------------------------- -----------------------------------------------------------
This method works like ``.execute_write()``, but instead of a SQL statement you give it a callable Python function. Your function will be queued up and then called when the write connection is available, passing that connection as the argument to the function. This method works like ``.execute_write()``, but instead of a SQL statement you give it a callable Python function. Your function will be queued up and then called when the write connection is available, passing that connection as the argument to the function.
@ -1052,7 +1058,6 @@ For example:
def delete_and_return_count(conn): def delete_and_return_count(conn):
conn.execute("delete from some_table where id > 5") conn.execute("delete from some_table where id > 5")
conn.commit()
return conn.execute( return conn.execute(
"select count(*) from some_table" "select count(*) from some_table"
).fetchone()[0] ).fetchone()[0]
@ -1069,7 +1074,7 @@ The value returned from ``await database.execute_write_fn(...)`` will be the ret
If your function raises an exception that exception will be propagated up to the ``await`` line. If your function raises an exception that exception will be propagated up to the ``await`` line.
If you see ``OperationalError: database table is locked`` errors you should check that you remembered to explicitly call ``conn.commit()`` in your write function. By default your function will be executed inside a transaction. You can pass ``transaction=False`` to disable this behavior, though if you do that you should be careful to manually apply transactions - ideally using the ``with conn:`` pattern, or you may see ``OperationalError: database table is locked`` errors.
If you specify ``block=False`` the method becomes fire-and-forget, queueing your function to be executed and then allowing your code after the call to ``.execute_write_fn()`` to continue running while the underlying thread waits for an opportunity to run your function. A UUID representing the queued task will be returned. Any exceptions in your code will be silently swallowed. If you specify ``block=False`` the method becomes fire-and-forget, queueing your function to be executed and then allowing your code after the call to ``.execute_write_fn()`` to continue running while the underlying thread waits for an opportunity to run your function. A UUID representing the queued task will be returned. Any exceptions in your code will be silently swallowed.

Wyświetl plik

@ -66,6 +66,33 @@ async def test_execute_fn(db):
assert 2 == await db.execute_fn(get_1_plus_1) assert 2 == await db.execute_fn(get_1_plus_1)
@pytest.mark.asyncio
async def test_execute_fn_transaction_false():
datasette = Datasette(memory=True)
db = datasette.add_memory_database("test_execute_fn_transaction_false")
def run(conn):
try:
with conn:
conn.execute("create table foo (id integer primary key)")
conn.execute("insert into foo (id) values (44)")
# Table should exist
assert (
conn.execute(
'select count(*) from sqlite_master where name = "foo"'
).fetchone()[0]
== 1
)
assert conn.execute("select id from foo").fetchall()[0][0] == 44
raise ValueError("Cancel commit")
except ValueError:
pass
# Row should NOT exist
assert conn.execute("select count(*) from foo").fetchone()[0] == 0
await db.execute_write_fn(run, transaction=False)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"tables,exists", "tables,exists",
( (