From 35f9437413ca32c1381c49f3ec04a8edd234ee4e Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Fri, 14 Jan 2022 14:07:14 -0500 Subject: [PATCH] Delay database notifications until after a transaction has finished. --- .../database/DatabaseObserverTest.kt | 157 ++++++++++++++++++ .../securesms/database/SQLiteDatabaseTest.kt | 125 ++++++++++++++ .../securesms/database/DatabaseObserver.java | 83 +++++---- .../securesms/database/SQLiteDatabase.java | 104 +++++++++++- .../securesms/database/SignalDatabase.kt | 5 + 5 files changed, 444 insertions(+), 30 deletions(-) create mode 100644 app/src/androidTest/java/org/thoughtcrime/securesms/database/DatabaseObserverTest.kt create mode 100644 app/src/androidTest/java/org/thoughtcrime/securesms/database/SQLiteDatabaseTest.kt diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/DatabaseObserverTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/DatabaseObserverTest.kt new file mode 100644 index 000000000..c86e5bdb5 --- /dev/null +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/DatabaseObserverTest.kt @@ -0,0 +1,157 @@ +package org.thoughtcrime.securesms.database + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.signal.core.util.concurrent.SignalExecutors +import org.thoughtcrime.securesms.dependencies.ApplicationDependencies +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +/** + * When writing tests, be very careful to call [DatabaseObserver.flush] before asserting any observer state. Internally, the observer is enqueueing tasks on + * an executor, and failing to flush the executor will lead to incorrect/flaky tests. + */ +@RunWith(AndroidJUnit4::class) +class DatabaseObserverTest { + + private lateinit var db: SQLiteDatabase + private lateinit var observer: DatabaseObserver + + @Before + fun setup() { + db = SignalDatabase.instance!!.signalWritableDatabase + observer = ApplicationDependencies.getDatabaseObserver() + } + + @Test + fun notifyConversationListeners_runsImmediatelyIfNotInTransaction() { + val hasRun = AtomicBoolean(false) + observer.registerConversationObserver(1) { hasRun.set(true) } + observer.notifyConversationListeners(1) + observer.flush() + assertTrue(hasRun.get()) + } + + @Test + fun notifyConversationListeners_runsAfterSuccessIfInTransaction() { + val hasRun = AtomicBoolean(false) + + db.beginTransaction() + + observer.registerConversationObserver(1) { hasRun.set(true) } + observer.notifyConversationListeners(1) + observer.flush() + assertFalse(hasRun.get()) + + db.setTransactionSuccessful() + db.endTransaction() + + observer.flush() + + assertTrue(hasRun.get()) + } + + @Test + fun notifyConversationListeners_doesNotRunAfterFailedTransaction() { + val hasRun = AtomicBoolean(false) + + db.beginTransaction() + + observer.registerConversationObserver(1) { hasRun.set(true) } + observer.notifyConversationListeners(1) + observer.flush() + assertFalse(hasRun.get()) + + db.endTransaction() + observer.flush() + assertFalse(hasRun.get()) + + // Verifying we still don't run it even after a subsequent success + db.beginTransaction() + db.setTransactionSuccessful() + db.endTransaction() + + observer.flush() + + assertFalse(hasRun.get()) + } + + @Test + fun notifyConversationListeners_onlyRunAfterAllTransactionsComplete() { + val hasRun = AtomicBoolean(false) + + db.beginTransaction() + + observer.registerConversationObserver(1) { hasRun.set(true) } + observer.notifyConversationListeners(1) + observer.flush() + assertFalse(hasRun.get()) + + db.beginTransaction() + db.setTransactionSuccessful() + db.endTransaction() + observer.flush() + assertFalse(hasRun.get()) + + db.setTransactionSuccessful() + db.endTransaction() + + observer.flush() + + assertTrue(hasRun.get()) + } + + @Test + fun notifyConversationListeners_runsImmediatelyIfTheTransactionIsOnAnotherThread() { + db.beginTransaction() + + val latch = CountDownLatch(1) + SignalExecutors.BOUNDED.execute { + val hasRun = AtomicBoolean(false) + + observer.registerConversationObserver(1) { hasRun.set(true) } + observer.notifyConversationListeners(1) + observer.flush() + assertTrue(hasRun.get()) + + latch.countDown() + } + + latch.await() + + db.setTransactionSuccessful() + db.endTransaction() + } + + @Test + fun notifyConversationListeners_runsAfterSuccessIfInTransaction_ignoreDuplicateNotifications() { + val thread1Count = AtomicInteger(0) + val thread2Count = AtomicInteger(0) + + db.beginTransaction() + + observer.registerConversationObserver(1) { thread1Count.incrementAndGet() } + observer.registerConversationObserver(2) { thread2Count.incrementAndGet() } + + observer.notifyConversationListeners(1) + observer.notifyConversationListeners(2) + observer.notifyConversationListeners(2) + + observer.flush() + assertEquals(0, thread1Count.get()) + assertEquals(0, thread2Count.get()) + + db.setTransactionSuccessful() + db.endTransaction() + + observer.flush() + assertEquals(1, thread1Count.get()) + assertEquals(1, thread2Count.get()) + } +} diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/SQLiteDatabaseTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/SQLiteDatabaseTest.kt new file mode 100644 index 000000000..b3f28e970 --- /dev/null +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/SQLiteDatabaseTest.kt @@ -0,0 +1,125 @@ +package org.thoughtcrime.securesms.database + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import junit.framework.Assert.assertFalse +import junit.framework.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.signal.core.util.concurrent.SignalExecutors +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicBoolean + +/** + * These are tests for the wrapper we wrote around SQLCipherDatabase, not the stock or SQLCipher one. + */ +@RunWith(AndroidJUnit4::class) +class SQLiteDatabaseTest { + + private lateinit var db: SQLiteDatabase + + @Before + fun setup() { + db = SignalDatabase.instance!!.signalWritableDatabase + } + + @Test + fun runPostSuccessfulTransaction_runsImmediatelyIfNotInTransaction() { + val hasRun = AtomicBoolean(false) + db.runPostSuccessfulTransaction { hasRun.set(true) } + assertTrue(hasRun.get()) + } + + @Test + fun runPostSuccessfulTransaction_runsAfterSuccessIfInTransaction() { + val hasRun = AtomicBoolean(false) + + db.beginTransaction() + + db.runPostSuccessfulTransaction { hasRun.set(true) } + assertFalse(hasRun.get()) + + db.setTransactionSuccessful() + db.endTransaction() + + assertTrue(hasRun.get()) + } + + @Test + fun runPostSuccessfulTransaction_doesNotRunAfterFailedTransaction() { + val hasRun = AtomicBoolean(false) + + db.beginTransaction() + + db.runPostSuccessfulTransaction { hasRun.set(true) } + assertFalse(hasRun.get()) + + db.endTransaction() + + assertFalse(hasRun.get()) + + // Verifying we still don't run it even after a subsequent success + db.beginTransaction() + db.setTransactionSuccessful() + db.endTransaction() + + assertFalse(hasRun.get()) + } + + @Test + fun runPostSuccessfulTransaction_onlyRunAfterAllTransactionsComplete() { + val hasRun = AtomicBoolean(false) + + db.beginTransaction() + + db.runPostSuccessfulTransaction { hasRun.set(true) } + assertFalse(hasRun.get()) + + db.beginTransaction() + db.setTransactionSuccessful() + db.endTransaction() + assertFalse(hasRun.get()) + + db.setTransactionSuccessful() + db.endTransaction() + + assertTrue(hasRun.get()) + } + + @Test + fun runPostSuccessfulTransaction_runsImmediatelyIfTheTransactionIsOnAnotherThread() { + db.beginTransaction() + + val latch = CountDownLatch(1) + SignalExecutors.BOUNDED.execute { + val hasRun = AtomicBoolean(false) + db.runPostSuccessfulTransaction { hasRun.set(true) } + assertTrue(hasRun.get()) + latch.countDown() + } + + latch.await() + + db.setTransactionSuccessful() + db.endTransaction() + } + + @Test + fun runPostSuccessfulTransaction_runsAfterSuccessIfInTransaction_ignoreDuplicates() { + val hasRun1 = AtomicBoolean(false) + val hasRun2 = AtomicBoolean(false) + + db.beginTransaction() + + db.runPostSuccessfulTransaction("key") { hasRun1.set(true) } + db.runPostSuccessfulTransaction("key") { hasRun2.set(true) } + assertFalse(hasRun1.get()) + assertFalse(hasRun2.get()) + + db.setTransactionSuccessful() + db.endTransaction() + + assertTrue(hasRun1.get()) + assertFalse(hasRun2.get()) + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/DatabaseObserver.java b/app/src/main/java/org/thoughtcrime/securesms/database/DatabaseObserver.java index 8570b4976..0a39cbbad 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/DatabaseObserver.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/DatabaseObserver.java @@ -3,6 +3,7 @@ package org.thoughtcrime.securesms.database; import android.app.Application; import androidx.annotation.NonNull; +import androidx.annotation.VisibleForTesting; import org.jetbrains.annotations.NotNull; import org.signal.core.util.concurrent.SignalExecutors; @@ -14,6 +15,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.UUID; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; /** @@ -23,6 +25,19 @@ import java.util.concurrent.Executor; */ public class DatabaseObserver { + private static final String KEY_CONVERSATION = "Conversation:"; + private static final String KEY_VERBOSE_CONVERSATION = "VerboseConversation:"; + private static final String KEY_CONVERSATION_LIST = "ConversationList"; + private static final String KEY_PAYMENT = "Payment:"; + private static final String KEY_ALL_PAYMENTS = "AllPayments"; + private static final String KEY_CHAT_COLORS = "ChatColors"; + private static final String KEY_STICKERS = "Stickers"; + private static final String KEY_STICKER_PACKS = "StickerPacks"; + private static final String KEY_ATTACHMENTS = "Attachments"; + private static final String KEY_MESSAGE_UPDATE = "MessageUpdate:"; + private static final String KEY_MESSAGE_INSERT = "MessageInsert:"; + private static final String KEY_NOTIFICATION_PROFILES = "NotificationProfiles"; + private final Application application; private final Executor executor; @@ -150,37 +165,28 @@ public class DatabaseObserver { } public void notifyConversationListeners(Set threadIds) { - executor.execute(() -> { - for (long threadId : threadIds) { - notifyMapped(conversationObservers, threadId); - notifyMapped(verboseConversationObservers, threadId); - } - }); + for (long threadId : threadIds) { + notifyConversationListeners(threadId); + } } public void notifyConversationListeners(long threadId) { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_CONVERSATION + threadId, () -> { notifyMapped(conversationObservers, threadId); notifyMapped(verboseConversationObservers, threadId); }); } public void notifyVerboseConversationListeners(Set threadIds) { - executor.execute(() -> { - for (long threadId : threadIds) { + for (long threadId : threadIds) { + runPostSuccessfulTransaction(KEY_VERBOSE_CONVERSATION + threadId, () -> { notifyMapped(verboseConversationObservers, threadId); - } - }); - } - - public void notifyVerboseConversationListeners(long threadId) { - executor.execute(() -> { - notifyMapped(verboseConversationObservers, threadId); - }); + }); + } } public void notifyConversationListListeners() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_CONVERSATION_LIST, () -> { for (Observer listener : conversationListObservers) { listener.onChanged(); } @@ -188,19 +194,19 @@ public class DatabaseObserver { } public void notifyPaymentListeners(@NonNull UUID paymentId) { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_PAYMENT + paymentId.toString(), () -> { notifyMapped(paymentObservers, paymentId); }); } public void notifyAllPaymentsListeners() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_ALL_PAYMENTS, () -> { notifySet(allPaymentsObservers); }); } public void notifyChatColorsListeners() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_CHAT_COLORS, () -> { for (Observer chatColorsObserver : chatColorsObservers) { chatColorsObserver.onChanged(); } @@ -208,31 +214,31 @@ public class DatabaseObserver { } public void notifyStickerObservers() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_STICKERS, () -> { notifySet(stickerObservers); }); } public void notifyStickerPackObservers() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_STICKER_PACKS, () -> { notifySet(stickerPackObservers); }); } public void notifyAttachmentObservers() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_ATTACHMENTS, () -> { notifySet(attachmentObservers); }); } public void notifyMessageUpdateObservers(@NonNull MessageId messageId) { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_MESSAGE_UPDATE + messageId.toString(), () -> { messageUpdateObservers.stream().forEach(l -> l.onMessageChanged(messageId)); }); } public void notifyMessageInsertObservers(long threadId, @NonNull MessageId messageId) { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_MESSAGE_INSERT + messageId, () -> { Set listeners = messageInsertObservers.get(threadId); if (listeners != null) { @@ -242,11 +248,17 @@ public class DatabaseObserver { } public void notifyNotificationProfileObservers() { - executor.execute(() -> { + runPostSuccessfulTransaction(KEY_NOTIFICATION_PROFILES, () -> { notifySet(notificationProfileObservers); }); } + private void runPostSuccessfulTransaction(@NonNull String dedupeKey, @NonNull Runnable runnable) { + SignalDatabase.runPostSuccessfulTransaction(dedupeKey, () -> { + executor.execute(runnable); + }); + } + private void registerMapped(@NonNull Map> map, @NonNull K key, @NonNull V listener) { Set listeners = map.get(key); @@ -274,12 +286,27 @@ public class DatabaseObserver { } } - public static void notifySet(@NonNull Set set) { + private static void notifySet(@NonNull Set set) { for (final Observer observer : set) { observer.onChanged(); } } + /** + * Blocks until the executor is empty. Only intended to be used for testing. + */ + @VisibleForTesting + void flush() { + CountDownLatch latch = new CountDownLatch(1); + executor.execute(latch::countDown); + + try { + latch.await(); + } catch (InterruptedException e) { + throw new AssertionError(); + } + } + public interface Observer { /** * Called when the relevant data changes. Executed on a serial executor, so don't do any diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java b/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java index 943e1f510..6fed200de 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/SQLiteDatabase.java @@ -4,6 +4,8 @@ package org.thoughtcrime.securesms.database; import android.content.ContentValues; import android.database.Cursor; +import androidx.annotation.NonNull; + import net.zetetic.database.SQLException; import net.zetetic.database.sqlcipher.SQLiteStatement; import net.zetetic.database.sqlcipher.SQLiteTransactionListener; @@ -11,8 +13,11 @@ import net.zetetic.database.sqlcipher.SQLiteTransactionListener; import org.signal.core.util.tracing.Tracer; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Set; /** * This is a wrapper around {@link net.zetetic.database.sqlcipher.SQLiteDatabase}. There's difficulties @@ -33,10 +38,14 @@ public class SQLiteDatabase { private static final String KEY_THREAD = "thread"; private static final String NAME_LOCK = "LOCK"; - private final net.zetetic.database.sqlcipher.SQLiteDatabase wrapped; private final Tracer tracer; + private static final ThreadLocal> POST_TRANSACTION_TASKS = new ThreadLocal<>(); + static { + POST_TRANSACTION_TASKS.set(new LinkedHashSet<>()); + } + public SQLiteDatabase(net.zetetic.database.sqlcipher.SQLiteDatabase wrapped) { this.wrapped = wrapped; this.tracer = Tracer.getInstance(); @@ -102,10 +111,77 @@ public class SQLiteDatabase { return wrapped; } + /** + * Allows you to enqueue a task to be run after the active transaction is successfully completed. + * If the transaction fails, the task is discarded. + * If there is no current transaction open, the task is run immediately. + */ + public void runPostSuccessfulTransaction(@NonNull Runnable task) { + if (wrapped.inTransaction()) { + getPostTransactionTasks().add(task); + } else { + task.run(); + } + } + + /** + * Does the same as {@link #runPostSuccessfulTransaction(Runnable)}, except that you can pass in a "dedupe key". + * There can only be one task enqueued for a given dedupe key. So, if you enqueue a second task with that key, it will be discarded. + */ + public void runPostSuccessfulTransaction(@NonNull String dedupeKey, @NonNull Runnable task) { + if (wrapped.inTransaction()) { + getPostTransactionTasks().add(new DedupedRunnable(dedupeKey, task)); + } else { + task.run(); + } + } + + private @NonNull Set getPostTransactionTasks() { + Set tasks = POST_TRANSACTION_TASKS.get(); + + if (tasks == null) { + tasks = new LinkedHashSet<>(); + POST_TRANSACTION_TASKS.set(tasks); + } + + return tasks; + } + private interface Returnable { E run(); } + /** + * Runnable whose equals/hashcode is determined by a key you pass in. + */ + private static class DedupedRunnable implements Runnable { + private final String key; + private final Runnable runnable; + + protected DedupedRunnable(@NonNull String key, @NonNull Runnable runnable) { + this.key = key; + this.runnable = runnable; + } + + @Override + public void run() { + runnable.run(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final DedupedRunnable that = (DedupedRunnable) o; + return key.equals(that.key); + } + + @Override + public int hashCode() { + return Objects.hash(key); + } + } + // ======================================================= // Traced @@ -113,7 +189,31 @@ public class SQLiteDatabase { public void beginTransaction() { traceLockStart(); - trace("beginTransaction()", wrapped::beginTransaction); + + if (wrapped.inTransaction()) { + trace("beginTransaction()", wrapped::beginTransaction); + } else { + trace("beginTransaction()", () -> { + wrapped.beginTransactionWithListener(new SQLiteTransactionListener() { + @Override + public void onBegin() { } + + @Override + public void onCommit() { + Set tasks = getPostTransactionTasks(); + for (Runnable r : tasks) { + r.run(); + } + tasks.clear(); + } + + @Override + public void onRollback() { + getPostTransactionTasks().clear(); + } + }); + }); + } } public void endTransaction() { diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/SignalDatabase.kt b/app/src/main/java/org/thoughtcrime/securesms/database/SignalDatabase.kt index ce8713f7d..cb9db2c3c 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/SignalDatabase.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/SignalDatabase.kt @@ -220,6 +220,11 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data val inTransaction: Boolean get() = instance!!.rawWritableDatabase.inTransaction() + @JvmStatic + fun runPostSuccessfulTransaction(dedupeKey: String, task: Runnable) { + instance!!.signalReadableDatabase.runPostSuccessfulTransaction(dedupeKey, task) + } + @JvmStatic fun databaseFileExists(context: Context): Boolean { return context.getDatabasePath(DATABASE_NAME).exists()