unix-ffi/sqlite3: Add commit and rollback functionality like CPython.

To increase the similarity between this module and CPythons sqlite3 module
the commit() and rollback() as defined in CPythons version have been
added, along with the different (auto)commit behaviors present there.
The defaults are also set to the same as in CPython, and can be changed
with the same parameters in connect(), as is showcased in the new test.

Signed-off-by: Robert Klink <rhermanklink@ripe.net>
pull/905/head
Robert Klink 2024-08-06 15:24:39 +02:00 zatwierdzone przez Damien George
rodzic 83598cdb3c
commit b77f67bd7c
2 zmienionych plików z 137 dodań i 38 usunięć

Wyświetl plik

@ -12,6 +12,8 @@ sq3 = ffilib.open("libsqlite3")
sqlite3_open = sq3.func("i", "sqlite3_open", "sp") sqlite3_open = sq3.func("i", "sqlite3_open", "sp")
# int sqlite3_config(int, ...); # int sqlite3_config(int, ...);
sqlite3_config = sq3.func("i", "sqlite3_config", "ii") sqlite3_config = sq3.func("i", "sqlite3_config", "ii")
# int sqlite3_get_autocommit(sqlite3*);
sqlite3_get_autocommit = sq3.func("i", "sqlite3_get_autocommit", "p")
# int sqlite3_close_v2(sqlite3*); # int sqlite3_close_v2(sqlite3*);
sqlite3_close = sq3.func("i", "sqlite3_close_v2", "p") sqlite3_close = sq3.func("i", "sqlite3_close_v2", "p")
# int sqlite3_prepare( # int sqlite3_prepare(
@ -57,6 +59,9 @@ SQLITE_NULL = 5
SQLITE_CONFIG_URI = 17 SQLITE_CONFIG_URI = 17
# For compatibility with CPython sqlite3 driver
LEGACY_TRANSACTION_CONTROL = -1
class Error(Exception): class Error(Exception):
pass pass
@ -71,86 +76,138 @@ def get_ptr_size():
return uctypes.sizeof({"ptr": (0 | uctypes.PTR, uctypes.PTR)}) return uctypes.sizeof({"ptr": (0 | uctypes.PTR, uctypes.PTR)})
def __prepare_stmt(db, sql):
# Prepares a statement
stmt_ptr = bytes(get_ptr_size())
res = sqlite3_prepare(db, sql, -1, stmt_ptr, None)
check_error(db, res)
return int.from_bytes(stmt_ptr, sys.byteorder)
def __exec_stmt(db, sql):
# Prepares, executes, and finalizes a statement
stmt = __prepare_stmt(db, sql)
sqlite3_step(stmt)
res = sqlite3_finalize(stmt)
check_error(db, res)
def __is_dml(sql):
# Checks if a sql query is a DML, as these get a BEGIN in LEGACY_TRANSACTION_CONTROL
for dml in ["INSERT", "DELETE", "UPDATE", "MERGE"]:
if dml in sql.upper():
return True
return False
class Connections: class Connections:
def __init__(self, h): def __init__(self, db, isolation_level, autocommit):
self.h = h self.db = db
self.isolation_level = isolation_level
self.autocommit = autocommit
def commit(self):
if self.autocommit == LEGACY_TRANSACTION_CONTROL and not sqlite3_get_autocommit(self.db):
__exec_stmt(self.db, "COMMIT")
elif self.autocommit == False:
__exec_stmt(self.db, "COMMIT")
__exec_stmt(self.db, "BEGIN")
def rollback(self):
if self.autocommit == LEGACY_TRANSACTION_CONTROL and not sqlite3_get_autocommit(self.db):
__exec_stmt(self.db, "ROLLBACK")
elif self.autocommit == False:
__exec_stmt(self.db, "ROLLBACK")
__exec_stmt(self.db, "BEGIN")
def cursor(self): def cursor(self):
return Cursor(self.h) return Cursor(self.db, self.isolation_level, self.autocommit)
def close(self): def close(self):
if self.h: if self.db:
s = sqlite3_close(self.h) if self.autocommit == False and not sqlite3_get_autocommit(self.db):
check_error(self.h, s) __exec_stmt(self.db, "ROLLBACK")
self.h = None
res = sqlite3_close(self.db)
check_error(self.db, res)
self.db = None
class Cursor: class Cursor:
def __init__(self, h): def __init__(self, db, isolation_level, autocommit):
self.h = h self.db = db
self.stmnt = None self.isolation_level = isolation_level
self.autocommit = autocommit
self.stmt = None
def __quote(val):
if isinstance(val, str):
return "'%s'" % val
return str(val)
def execute(self, sql, params=None): def execute(self, sql, params=None):
if self.stmnt: if self.stmt:
# If there is an existing statement, finalize that to free it # If there is an existing statement, finalize that to free it
res = sqlite3_finalize(self.stmnt) res = sqlite3_finalize(self.stmt)
check_error(self.h, res) check_error(self.db, res)
if params: if params:
params = [quote(v) for v in params] params = [self.__quote(v) for v in params]
sql = sql % tuple(params) sql = sql % tuple(params)
stmnt_ptr = bytes(get_ptr_size()) if __is_dml(sql) and self.autocommit == LEGACY_TRANSACTION_CONTROL and sqlite3_get_autocommit(self.db):
res = sqlite3_prepare(self.h, sql, -1, stmnt_ptr, None) # For compatibility with CPython, add functionality for their default transaction
check_error(self.h, res) # behavior. Changing autocommit from LEGACY_TRANSACTION_CONTROL will remove this
self.stmnt = int.from_bytes(stmnt_ptr, sys.byteorder) __exec_stmt(self.db, "BEGIN " + self.isolation_level)
self.num_cols = sqlite3_column_count(self.stmnt)
self.stmt = __prepare_stmt(self.db, sql)
self.num_cols = sqlite3_column_count(self.stmt)
if not self.num_cols: if not self.num_cols:
v = self.fetchone() v = self.fetchone()
# If it's not select, actually execute it here # If it's not select, actually execute it here
# num_cols == 0 for statements which don't return data (=> modify it) # num_cols == 0 for statements which don't return data (=> modify it)
assert v is None assert v is None
self.lastrowid = sqlite3_last_insert_rowid(self.h) self.lastrowid = sqlite3_last_insert_rowid(self.db)
def close(self): def close(self):
if self.stmnt: if self.stmt:
s = sqlite3_finalize(self.stmnt) res = sqlite3_finalize(self.stmt)
check_error(self.h, s) check_error(self.db, res)
self.stmnt = None self.stmt = None
def make_row(self): def __make_row(self):
res = [] res = []
for i in range(self.num_cols): for i in range(self.num_cols):
t = sqlite3_column_type(self.stmnt, i) t = sqlite3_column_type(self.stmt, i)
if t == SQLITE_INTEGER: if t == SQLITE_INTEGER:
res.append(sqlite3_column_int(self.stmnt, i)) res.append(sqlite3_column_int(self.stmt, i))
elif t == SQLITE_FLOAT: elif t == SQLITE_FLOAT:
res.append(sqlite3_column_double(self.stmnt, i)) res.append(sqlite3_column_double(self.stmt, i))
elif t == SQLITE_TEXT: elif t == SQLITE_TEXT:
res.append(sqlite3_column_text(self.stmnt, i)) res.append(sqlite3_column_text(self.stmt, i))
else: else:
raise NotImplementedError raise NotImplementedError
return tuple(res) return tuple(res)
def fetchone(self): def fetchone(self):
res = sqlite3_step(self.stmnt) res = sqlite3_step(self.stmt)
if res == SQLITE_DONE: if res == SQLITE_DONE:
return None return None
if res == SQLITE_ROW: if res == SQLITE_ROW:
return self.make_row() return self.__make_row()
check_error(self.h, res) check_error(self.db, res)
def connect(fname, uri=False): def connect(fname, uri=False, isolation_level="", autocommit=LEGACY_TRANSACTION_CONTROL):
if isolation_level not in [None, "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE"]:
raise Error("Invalid option for isolation level")
sqlite3_config(SQLITE_CONFIG_URI, int(uri)) sqlite3_config(SQLITE_CONFIG_URI, int(uri))
sqlite_ptr = bytes(get_ptr_size()) sqlite_ptr = bytes(get_ptr_size())
sqlite3_open(fname, sqlite_ptr) sqlite3_open(fname, sqlite_ptr)
return Connections(int.from_bytes(sqlite_ptr, sys.byteorder)) db = int.from_bytes(sqlite_ptr, sys.byteorder)
if autocommit == False:
__exec_stmt(db, "BEGIN")
def quote(val): return Connections(db, isolation_level, autocommit)
if isinstance(val, str):
return "'%s'" % val
return str(val)

Wyświetl plik

@ -0,0 +1,42 @@
import sqlite3
def test_autocommit():
conn = sqlite3.connect(":memory:", autocommit=True)
# First cursor creates table and inserts value (DML)
cur = conn.cursor()
cur.execute("CREATE TABLE foo(a int)")
cur.execute("INSERT INTO foo VALUES (42)")
cur.close()
# Second cursor fetches 42 due to the autocommit
cur = conn.cursor()
cur.execute("SELECT * FROM foo")
assert cur.fetchone() == (42,)
assert cur.fetchone() is None
cur.close()
conn.close()
def test_manual():
conn = sqlite3.connect(":memory:", autocommit=False)
# First cursor creates table, insert rolls back
cur = conn.cursor()
cur.execute("CREATE TABLE foo(a int)")
conn.commit()
cur.execute("INSERT INTO foo VALUES (42)")
cur.close()
conn.rollback()
# Second connection fetches nothing due to the rollback
cur = conn.cursor()
cur.execute("SELECT * FROM foo")
assert cur.fetchone() is None
cur.close()
conn.close()
test_autocommit()
test_manual()