kopia lustrzana https://github.com/micropython/micropython-lib
unix-ffi/sqlite3: Add commit and rollback functionality like CPython.
To increase the similarity between this module and CPythons sqlite3 module the commit() and rollback() as defined in CPythons version have been added, along with the different (auto)commit behaviors present there. The defaults are also set to the same as in CPython, and can be changed with the same parameters in connect(), as is showcased in the new test. Signed-off-by: Robert Klink <rhermanklink@ripe.net>pull/905/head
rodzic
83598cdb3c
commit
b77f67bd7c
|
@ -12,6 +12,8 @@ sq3 = ffilib.open("libsqlite3")
|
|||
sqlite3_open = sq3.func("i", "sqlite3_open", "sp")
|
||||
# int sqlite3_config(int, ...);
|
||||
sqlite3_config = sq3.func("i", "sqlite3_config", "ii")
|
||||
# int sqlite3_get_autocommit(sqlite3*);
|
||||
sqlite3_get_autocommit = sq3.func("i", "sqlite3_get_autocommit", "p")
|
||||
# int sqlite3_close_v2(sqlite3*);
|
||||
sqlite3_close = sq3.func("i", "sqlite3_close_v2", "p")
|
||||
# int sqlite3_prepare(
|
||||
|
@ -57,6 +59,9 @@ SQLITE_NULL = 5
|
|||
|
||||
SQLITE_CONFIG_URI = 17
|
||||
|
||||
# For compatibility with CPython sqlite3 driver
|
||||
LEGACY_TRANSACTION_CONTROL = -1
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
pass
|
||||
|
@ -71,86 +76,138 @@ def get_ptr_size():
|
|||
return uctypes.sizeof({"ptr": (0 | uctypes.PTR, uctypes.PTR)})
|
||||
|
||||
|
||||
def __prepare_stmt(db, sql):
|
||||
# Prepares a statement
|
||||
stmt_ptr = bytes(get_ptr_size())
|
||||
res = sqlite3_prepare(db, sql, -1, stmt_ptr, None)
|
||||
check_error(db, res)
|
||||
return int.from_bytes(stmt_ptr, sys.byteorder)
|
||||
|
||||
def __exec_stmt(db, sql):
|
||||
# Prepares, executes, and finalizes a statement
|
||||
stmt = __prepare_stmt(db, sql)
|
||||
sqlite3_step(stmt)
|
||||
res = sqlite3_finalize(stmt)
|
||||
check_error(db, res)
|
||||
|
||||
def __is_dml(sql):
|
||||
# Checks if a sql query is a DML, as these get a BEGIN in LEGACY_TRANSACTION_CONTROL
|
||||
for dml in ["INSERT", "DELETE", "UPDATE", "MERGE"]:
|
||||
if dml in sql.upper():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Connections:
|
||||
def __init__(self, h):
|
||||
self.h = h
|
||||
def __init__(self, db, isolation_level, autocommit):
|
||||
self.db = db
|
||||
self.isolation_level = isolation_level
|
||||
self.autocommit = autocommit
|
||||
|
||||
def commit(self):
|
||||
if self.autocommit == LEGACY_TRANSACTION_CONTROL and not sqlite3_get_autocommit(self.db):
|
||||
__exec_stmt(self.db, "COMMIT")
|
||||
elif self.autocommit == False:
|
||||
__exec_stmt(self.db, "COMMIT")
|
||||
__exec_stmt(self.db, "BEGIN")
|
||||
|
||||
def rollback(self):
|
||||
if self.autocommit == LEGACY_TRANSACTION_CONTROL and not sqlite3_get_autocommit(self.db):
|
||||
__exec_stmt(self.db, "ROLLBACK")
|
||||
elif self.autocommit == False:
|
||||
__exec_stmt(self.db, "ROLLBACK")
|
||||
__exec_stmt(self.db, "BEGIN")
|
||||
|
||||
def cursor(self):
|
||||
return Cursor(self.h)
|
||||
return Cursor(self.db, self.isolation_level, self.autocommit)
|
||||
|
||||
def close(self):
|
||||
if self.h:
|
||||
s = sqlite3_close(self.h)
|
||||
check_error(self.h, s)
|
||||
self.h = None
|
||||
if self.db:
|
||||
if self.autocommit == False and not sqlite3_get_autocommit(self.db):
|
||||
__exec_stmt(self.db, "ROLLBACK")
|
||||
|
||||
res = sqlite3_close(self.db)
|
||||
check_error(self.db, res)
|
||||
self.db = None
|
||||
|
||||
|
||||
class Cursor:
|
||||
def __init__(self, h):
|
||||
self.h = h
|
||||
self.stmnt = None
|
||||
def __init__(self, db, isolation_level, autocommit):
|
||||
self.db = db
|
||||
self.isolation_level = isolation_level
|
||||
self.autocommit = autocommit
|
||||
self.stmt = None
|
||||
|
||||
def __quote(val):
|
||||
if isinstance(val, str):
|
||||
return "'%s'" % val
|
||||
return str(val)
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
if self.stmnt:
|
||||
if self.stmt:
|
||||
# If there is an existing statement, finalize that to free it
|
||||
res = sqlite3_finalize(self.stmnt)
|
||||
check_error(self.h, res)
|
||||
res = sqlite3_finalize(self.stmt)
|
||||
check_error(self.db, res)
|
||||
|
||||
if params:
|
||||
params = [quote(v) for v in params]
|
||||
params = [self.__quote(v) for v in params]
|
||||
sql = sql % tuple(params)
|
||||
|
||||
stmnt_ptr = bytes(get_ptr_size())
|
||||
res = sqlite3_prepare(self.h, sql, -1, stmnt_ptr, None)
|
||||
check_error(self.h, res)
|
||||
self.stmnt = int.from_bytes(stmnt_ptr, sys.byteorder)
|
||||
self.num_cols = sqlite3_column_count(self.stmnt)
|
||||
if __is_dml(sql) and self.autocommit == LEGACY_TRANSACTION_CONTROL and sqlite3_get_autocommit(self.db):
|
||||
# For compatibility with CPython, add functionality for their default transaction
|
||||
# behavior. Changing autocommit from LEGACY_TRANSACTION_CONTROL will remove this
|
||||
__exec_stmt(self.db, "BEGIN " + self.isolation_level)
|
||||
|
||||
self.stmt = __prepare_stmt(self.db, sql)
|
||||
self.num_cols = sqlite3_column_count(self.stmt)
|
||||
|
||||
if not self.num_cols:
|
||||
v = self.fetchone()
|
||||
# If it's not select, actually execute it here
|
||||
# num_cols == 0 for statements which don't return data (=> modify it)
|
||||
assert v is None
|
||||
self.lastrowid = sqlite3_last_insert_rowid(self.h)
|
||||
self.lastrowid = sqlite3_last_insert_rowid(self.db)
|
||||
|
||||
def close(self):
|
||||
if self.stmnt:
|
||||
s = sqlite3_finalize(self.stmnt)
|
||||
check_error(self.h, s)
|
||||
self.stmnt = None
|
||||
if self.stmt:
|
||||
res = sqlite3_finalize(self.stmt)
|
||||
check_error(self.db, res)
|
||||
self.stmt = None
|
||||
|
||||
def make_row(self):
|
||||
def __make_row(self):
|
||||
res = []
|
||||
for i in range(self.num_cols):
|
||||
t = sqlite3_column_type(self.stmnt, i)
|
||||
t = sqlite3_column_type(self.stmt, i)
|
||||
if t == SQLITE_INTEGER:
|
||||
res.append(sqlite3_column_int(self.stmnt, i))
|
||||
res.append(sqlite3_column_int(self.stmt, i))
|
||||
elif t == SQLITE_FLOAT:
|
||||
res.append(sqlite3_column_double(self.stmnt, i))
|
||||
res.append(sqlite3_column_double(self.stmt, i))
|
||||
elif t == SQLITE_TEXT:
|
||||
res.append(sqlite3_column_text(self.stmnt, i))
|
||||
res.append(sqlite3_column_text(self.stmt, i))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return tuple(res)
|
||||
|
||||
def fetchone(self):
|
||||
res = sqlite3_step(self.stmnt)
|
||||
res = sqlite3_step(self.stmt)
|
||||
if res == SQLITE_DONE:
|
||||
return None
|
||||
if res == SQLITE_ROW:
|
||||
return self.make_row()
|
||||
check_error(self.h, res)
|
||||
return self.__make_row()
|
||||
check_error(self.db, res)
|
||||
|
||||
|
||||
def connect(fname, uri=False):
|
||||
def connect(fname, uri=False, isolation_level="", autocommit=LEGACY_TRANSACTION_CONTROL):
|
||||
if isolation_level not in [None, "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE"]:
|
||||
raise Error("Invalid option for isolation level")
|
||||
|
||||
sqlite3_config(SQLITE_CONFIG_URI, int(uri))
|
||||
|
||||
sqlite_ptr = bytes(get_ptr_size())
|
||||
sqlite3_open(fname, sqlite_ptr)
|
||||
return Connections(int.from_bytes(sqlite_ptr, sys.byteorder))
|
||||
db = int.from_bytes(sqlite_ptr, sys.byteorder)
|
||||
|
||||
if autocommit == False:
|
||||
__exec_stmt(db, "BEGIN")
|
||||
|
||||
def quote(val):
|
||||
if isinstance(val, str):
|
||||
return "'%s'" % val
|
||||
return str(val)
|
||||
return Connections(db, isolation_level, autocommit)
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
import sqlite3
|
||||
|
||||
|
||||
def test_autocommit():
|
||||
conn = sqlite3.connect(":memory:", autocommit=True)
|
||||
|
||||
# First cursor creates table and inserts value (DML)
|
||||
cur = conn.cursor()
|
||||
cur.execute("CREATE TABLE foo(a int)")
|
||||
cur.execute("INSERT INTO foo VALUES (42)")
|
||||
cur.close()
|
||||
|
||||
# Second cursor fetches 42 due to the autocommit
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT * FROM foo")
|
||||
assert cur.fetchone() == (42,)
|
||||
assert cur.fetchone() is None
|
||||
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
def test_manual():
|
||||
conn = sqlite3.connect(":memory:", autocommit=False)
|
||||
|
||||
# First cursor creates table, insert rolls back
|
||||
cur = conn.cursor()
|
||||
cur.execute("CREATE TABLE foo(a int)")
|
||||
conn.commit()
|
||||
cur.execute("INSERT INTO foo VALUES (42)")
|
||||
cur.close()
|
||||
conn.rollback()
|
||||
|
||||
# Second connection fetches nothing due to the rollback
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT * FROM foo")
|
||||
assert cur.fetchone() is None
|
||||
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
test_autocommit()
|
||||
test_manual()
|
Ładowanie…
Reference in New Issue