diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java b/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java index 618dcc56a..0a30e256e 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java @@ -3,8 +3,12 @@ package org.thoughtcrime.securesms.database; import android.content.ContentValues; import android.database.Cursor; +import android.os.CancellationSignal; +import android.util.Pair; import androidx.annotation.NonNull; +import androidx.sqlite.db.SupportSQLiteDatabase; +import androidx.sqlite.db.SupportSQLiteQuery; import net.zetetic.database.SQLException; import net.zetetic.database.sqlcipher.SQLiteStatement; @@ -12,9 +16,11 @@ import net.zetetic.database.sqlcipher.SQLiteTransactionListener; import org.signal.core.util.tracing.Tracer; +import java.io.IOException; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -25,7 +31,7 @@ import java.util.Set; * making a subclass, so instead we just match the interface. Callers should just need to change * their import statements. */ -public class SQLiteDatabase { +public class SQLiteDatabase implements SupportSQLiteDatabase { public static final int CONFLICT_ROLLBACK = 1; public static final int CONFLICT_ABORT = 2; @@ -183,6 +189,77 @@ public class SQLiteDatabase { } } + // ======================================================= + // Overrides + // ======================================================= + + @Override + public void beginTransactionWithListener(android.database.sqlite.SQLiteTransactionListener transactionListener) { + beginTransactionWithListener(new ConvertedTransactionListener(transactionListener)); + } + + @Override + public void beginTransactionWithListenerNonExclusive(android.database.sqlite.SQLiteTransactionListener transactionListener) { + beginTransactionWithListenerNonExclusive(new ConvertedTransactionListener(transactionListener)); + } + + @Override + public Cursor query(String query) { + return rawQuery(query, null); + } + + @Override + public Cursor query(String query, Object[] bindArgs) { + return rawQuery(query, bindArgs); + } + + @Override + public Cursor query(SupportSQLiteQuery query) { + DatabaseMonitor.onSql(query.getSql(), null); + return wrapped.query(query); + } + + @Override + public Cursor query(SupportSQLiteQuery query, CancellationSignal cancellationSignal) { + DatabaseMonitor.onSql(query.getSql(), null); + return wrapped.query(query, cancellationSignal); + } + + @Override + public long insert(String table, int conflictAlgorithm, ContentValues values) throws android.database.SQLException { + return insertWithOnConflict(table, null, values, conflictAlgorithm); + } + + @Override + public int delete(String table, String whereClause, Object[] whereArgs) { + return delete(table, whereClause, (String[]) whereArgs); + } + + @Override + public int update(String table, int conflictAlgorithm, ContentValues values, String whereClause, Object[] whereArgs) { + return updateWithOnConflict(table, values, whereClause, (String[]) whereArgs, conflictAlgorithm); + } + + @Override + public void setMaxSqlCacheSize(int cacheSize) { + wrapped.setMaxSqlCacheSize(cacheSize); + } + + @Override + public List> getAttachedDbs() { + return wrapped.getAttachedDbs(); + } + + @Override + public boolean isDatabaseIntegrityOk() { + return wrapped.isDatabaseIntegrityOk(); + } + + @Override + public void close() throws IOException { + wrapped.close(); + } + // ======================================================= // Traced @@ -297,6 +374,7 @@ public class SQLiteDatabase { } public int updateWithOnConflict(String table, ContentValues values, String whereClause, String[] whereArgs, int conflictAlgorithm) { + DatabaseMonitor.onUpdate(table, values, whereClause, whereArgs); return traceSql("updateWithOnConflict()", table, whereClause, true, () -> wrapped.updateWithOnConflict(table, values, whereClause, whereArgs, conflictAlgorithm)); } @@ -415,4 +493,28 @@ public class SQLiteDatabase { public void setLocale(Locale locale) { wrapped.setLocale(locale); } + + private static class ConvertedTransactionListener implements SQLiteTransactionListener { + + private final android.database.sqlite.SQLiteTransactionListener listener; + + ConvertedTransactionListener(android.database.sqlite.SQLiteTransactionListener listener) { + this.listener = listener; + } + + @Override + public void onBegin() { + listener.onBegin(); + } + + @Override + public void onCommit() { + listener.onCommit(); + } + + @Override + public void onRollback() { + listener.onRollback(); + } + } } diff --git a/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt b/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt index c7176cf58..18bea539d 100644 --- a/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt +++ b/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt @@ -1,6 +1,11 @@ package org.signal.core.util +import android.content.ContentValues +import android.database.Cursor +import android.database.sqlite.SQLiteDatabase +import androidx.core.content.contentValuesOf import androidx.sqlite.db.SupportSQLiteDatabase +import androidx.sqlite.db.SupportSQLiteQueryBuilder fun SupportSQLiteDatabase.getTableRowCount(table: String): Int { return this.query("SELECT COUNT(*) FROM $table").use { @@ -10,4 +15,212 @@ fun SupportSQLiteDatabase.getTableRowCount(table: String): Int { 0 } } -} \ No newline at end of file +} + +/** + * Begins a SELECT statement with a helpful builder pattern. + */ +fun SupportSQLiteDatabase.select(vararg columns: String): SelectBuilderPart1 { + return SelectBuilderPart1(this, arrayOf(*columns)) +} + +/** + * Begins an UPDATE statement with a helpful builder pattern. + */ +fun SupportSQLiteDatabase.update(tableName: String): UpdateBuilderPart1 { + return UpdateBuilderPart1(this, tableName) +} + +/** + * Begins a DELETE statement with a helpful builder pattern. + */ +fun SupportSQLiteDatabase.delete(tableName: String): DeleteBuilderPart1 { + return DeleteBuilderPart1(this, tableName) +} + +class SelectBuilderPart1( + private val db: SupportSQLiteDatabase, + private val columns: Array +) { + fun from(tableName: String): SelectBuilderPart2 { + return SelectBuilderPart2(db, columns, tableName) + } +} + +class SelectBuilderPart2( + private val db: SupportSQLiteDatabase, + private val columns: Array, + private val tableName: String +) { + fun where(where: String, vararg whereArgs: Any): SelectBuilderPart3 { + return SelectBuilderPart3(db, columns, tableName, where, SqlUtil.buildArgs(whereArgs)) + } + + fun run(): Cursor { + return db.query( + SupportSQLiteQueryBuilder + .builder(tableName) + .columns(columns) + .create() + ) + } +} + +class SelectBuilderPart3( + private val db: SupportSQLiteDatabase, + private val columns: Array, + private val tableName: String, + private val where: String, + private val whereArgs: Array +) { + fun orderBy(orderBy: String): SelectBuilderPart4a { + return SelectBuilderPart4a(db, columns, tableName, where, whereArgs, orderBy) + } + + fun limit(limit: Int): SelectBuilderPart4b { + return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, limit.toString()) + } + + fun run(): Cursor { + return db.query( + SupportSQLiteQueryBuilder + .builder(tableName) + .columns(columns) + .selection(where, whereArgs) + .create() + ) + } +} + +class SelectBuilderPart4a( + private val db: SupportSQLiteDatabase, + private val columns: Array, + private val tableName: String, + private val where: String, + private val whereArgs: Array, + private val orderBy: String +) { + fun limit(limit: Int): SelectBuilderPart5 { + return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit.toString()) + } + + fun run(): Cursor { + return db.query( + SupportSQLiteQueryBuilder + .builder(tableName) + .columns(columns) + .selection(where, whereArgs) + .orderBy(orderBy) + .create() + ) + } +} + +class SelectBuilderPart4b( + private val db: SupportSQLiteDatabase, + private val columns: Array, + private val tableName: String, + private val where: String, + private val whereArgs: Array, + private val limit: String +) { + fun orderBy(orderBy: String): SelectBuilderPart5 { + return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit) + } + + fun run(): Cursor { + return db.query( + SupportSQLiteQueryBuilder + .builder(tableName) + .columns(columns) + .selection(where, whereArgs) + .limit(limit) + .create() + ) + } +} + +class SelectBuilderPart5( + private val db: SupportSQLiteDatabase, + private val columns: Array, + private val tableName: String, + private val where: String, + private val whereArgs: Array, + private val orderBy: String, + private val limit: String +) { + fun run(): Cursor { + return db.query( + SupportSQLiteQueryBuilder + .builder(tableName) + .columns(columns) + .selection(where, whereArgs) + .orderBy(orderBy) + .limit(limit) + .create() + ) + } +} + +class UpdateBuilderPart1( + private val db: SupportSQLiteDatabase, + private val tableName: String +) { + fun values(values: ContentValues): UpdateBuilderPart2 { + return UpdateBuilderPart2(db, tableName, values) + } + + fun values(vararg values: Pair): UpdateBuilderPart2 { + return UpdateBuilderPart2(db, tableName, contentValuesOf(*values)) + } +} + +class UpdateBuilderPart2( + private val db: SupportSQLiteDatabase, + private val tableName: String, + private val values: ContentValues +) { + fun where(where: String, vararg whereArgs: Any): UpdateBuilderPart3 { + return UpdateBuilderPart3(db, tableName, values, where, SqlUtil.buildArgs(whereArgs)) + } + + fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int { + return db.update(tableName, conflictStrategy, values, null, null) + } +} + +class UpdateBuilderPart3( + private val db: SupportSQLiteDatabase, + private val tableName: String, + private val values: ContentValues, + private val where: String, + private val whereArgs: Array +) { + fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int { + return db.update(tableName, conflictStrategy, values, where, whereArgs) + } +} + +class DeleteBuilderPart1( + private val db: SupportSQLiteDatabase, + private val tableName: String +) { + fun where(where: String, vararg whereArgs: Any): DeleteBuilderPart2 { + return DeleteBuilderPart2(db, tableName, where, SqlUtil.buildArgs(whereArgs)) + } + + fun run(): Int { + return db.delete(tableName, null, null) + } +} + +class DeleteBuilderPart2( + private val db: SupportSQLiteDatabase, + private val tableName: String, + private val where: String, + private val whereArgs: Array +) { + fun run(): Int { + return db.delete(tableName, where, whereArgs) + } +}