diff --git a/datasette/database.py b/datasette/database.py index 707d8f85..d34aac73 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -179,17 +179,25 @@ class Database: # Threaded mode - send to write thread return await self._send_to_write_thread(fn, isolated_connection=True) - async def execute_write_fn(self, fn, block=True): + async def execute_write_fn(self, fn, block=True, transaction=True): if self.ds.executor is None: # non-threaded mode if self._write_connection is None: self._write_connection = self.connect(write=True) self.ds._prepare_connection(self._write_connection, self.name) - return fn(self._write_connection) + if transaction: + with self._write_connection: + return fn(self._write_connection) + else: + return fn(self._write_connection) else: - return await self._send_to_write_thread(fn, block) + return await self._send_to_write_thread( + fn, block=block, transaction=transaction + ) - async def _send_to_write_thread(self, fn, block=True, isolated_connection=False): + async def _send_to_write_thread( + self, fn, block=True, isolated_connection=False, transaction=True + ): if self._write_queue is None: self._write_queue = queue.Queue() if self._write_thread is None: @@ -202,7 +210,9 @@ class Database: self._write_thread.start() task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") reply_queue = janus.Queue() - self._write_queue.put(WriteTask(fn, task_id, reply_queue, isolated_connection)) + self._write_queue.put( + WriteTask(fn, task_id, reply_queue, isolated_connection, transaction) + ) if block: result = await reply_queue.async_q.get() if isinstance(result, Exception): @@ -244,7 +254,11 @@ class Database: pass else: try: - result = task.fn(conn) + if task.transaction: + with conn: + result = task.fn(conn) + else: + result = task.fn(conn) except Exception as e: sys.stderr.write("{}\n".format(e)) sys.stderr.flush() @@ -554,13 +568,14 @@ class Database: class WriteTask: - __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection") + __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction") - def __init__(self, fn, task_id, reply_queue, isolated_connection): + def __init__(self, fn, task_id, reply_queue, isolated_connection, transaction): self.fn = fn self.task_id = task_id self.reply_queue = reply_queue self.isolated_connection = isolated_connection + self.transaction = transaction class QueryInterrupted(Exception): diff --git a/docs/internals.rst b/docs/internals.rst index bd7a70b5..6ca62423 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -1010,7 +1010,9 @@ You can pass additional SQL parameters as a tuple or dictionary. The method will block until the operation is completed, and the return value will be the return from calling ``conn.execute(...)`` using the underlying ``sqlite3`` Python library. -If you pass ``block=False`` this behaviour changes to "fire and forget" - queries will be added to the write queue and executed in a separate thread while your code can continue to do other things. The method will return a UUID representing the queued task. +If you pass ``block=False`` this behavior changes to "fire and forget" - queries will be added to the write queue and executed in a separate thread while your code can continue to do other things. The method will return a UUID representing the queued task. + +Each call to ``execute_write()`` will be executed inside a transaction. .. _database_execute_write_script: @@ -1019,6 +1021,8 @@ await db.execute_write_script(sql, block=True) Like ``execute_write()`` but can be used to send multiple SQL statements in a single string separated by semicolons, using the ``sqlite3`` `conn.executescript() `__ method. +Each call to ``execute_write_script()`` will be executed inside a transaction. + .. _database_execute_write_many: await db.execute_write_many(sql, params_seq, block=True) @@ -1033,10 +1037,12 @@ Like ``execute_write()`` but uses the ``sqlite3`` `conn.executemany() 5") - conn.commit() return conn.execute( "select count(*) from some_table" ).fetchone()[0] @@ -1069,7 +1074,7 @@ The value returned from ``await database.execute_write_fn(...)`` will be the ret If your function raises an exception that exception will be propagated up to the ``await`` line. -If you see ``OperationalError: database table is locked`` errors you should check that you remembered to explicitly call ``conn.commit()`` in your write function. +By default your function will be executed inside a transaction. You can pass ``transaction=False`` to disable this behavior, though if you do that you should be careful to manually apply transactions - ideally using the ``with conn:`` pattern, or you may see ``OperationalError: database table is locked`` errors. If you specify ``block=False`` the method becomes fire-and-forget, queueing your function to be executed and then allowing your code after the call to ``.execute_write_fn()`` to continue running while the underlying thread waits for an opportunity to run your function. A UUID representing the queued task will be returned. Any exceptions in your code will be silently swallowed. diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index dd68a6cb..57e75046 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -66,6 +66,33 @@ async def test_execute_fn(db): assert 2 == await db.execute_fn(get_1_plus_1) +@pytest.mark.asyncio +async def test_execute_fn_transaction_false(): + datasette = Datasette(memory=True) + db = datasette.add_memory_database("test_execute_fn_transaction_false") + + def run(conn): + try: + with conn: + conn.execute("create table foo (id integer primary key)") + conn.execute("insert into foo (id) values (44)") + # Table should exist + assert ( + conn.execute( + 'select count(*) from sqlite_master where name = "foo"' + ).fetchone()[0] + == 1 + ) + assert conn.execute("select id from foo").fetchall()[0][0] == 44 + raise ValueError("Cancel commit") + except ValueError: + pass + # Row should NOT exist + assert conn.execute("select count(*) from foo").fetchone()[0] == 0 + + await db.execute_write_fn(run, transaction=False) + + @pytest.mark.parametrize( "tables,exists", (