package org.thoughtcrime.securesms.database import android.app.Application import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.platform.app.InstrumentationRegistry import org.hamcrest.MatcherAssert.assertThat import org.hamcrest.Matchers import org.junit.Assert.assertEquals import org.junit.Assert.assertNotEquals import org.junit.Assert.assertNotNull import org.junit.Assert.assertNull import org.junit.Before import org.junit.Test import org.junit.runner.RunWith import org.signal.libsignal.protocol.IdentityKey import org.signal.libsignal.protocol.SignalProtocolAddress import org.signal.libsignal.protocol.state.SessionRecord import org.signal.libsignal.zkgroup.groups.GroupMasterKey import org.signal.storageservice.protos.groups.local.DecryptedGroup import org.signal.storageservice.protos.groups.local.DecryptedMember import org.thoughtcrime.securesms.conversation.colors.AvatarColor import org.thoughtcrime.securesms.database.model.DistributionListId import org.thoughtcrime.securesms.database.model.DistributionListRecord import org.thoughtcrime.securesms.database.model.Mention import org.thoughtcrime.securesms.database.model.MessageId import org.thoughtcrime.securesms.database.model.MessageRecord import org.thoughtcrime.securesms.database.model.ReactionRecord import org.thoughtcrime.securesms.groups.GroupId import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.mms.IncomingMediaMessage import org.thoughtcrime.securesms.notifications.profiles.NotificationProfile import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.sms.IncomingTextMessage import org.thoughtcrime.securesms.util.CursorUtil import org.whispersystems.signalservice.api.push.ACI import org.whispersystems.signalservice.api.push.PNI import org.whispersystems.signalservice.api.util.UuidUtil import java.util.Optional import java.util.UUID @RunWith(AndroidJUnit4::class) 
class RecipientDatabaseTest_merges {

  private lateinit var recipientDatabase: RecipientDatabase
  private lateinit var identityDatabase: IdentityDatabase
  private lateinit var groupReceiptDatabase: GroupReceiptDatabase
  private lateinit var groupDatabase: GroupDatabase
  private lateinit var threadDatabase: ThreadDatabase
  private lateinit var smsDatabase: MessageDatabase
  private lateinit var mmsDatabase: MessageDatabase
  private lateinit var sessionDatabase: SessionDatabase
  private lateinit var mentionDatabase: MentionDatabase
  private lateinit var reactionDatabase: ReactionDatabase
  private lateinit var notificationProfileDatabase: NotificationProfileDatabase
  private lateinit var distributionListDatabase: DistributionListDatabase

  private val localAci = ACI.from(UUID.randomUUID())
  private val localPni = PNI.from(UUID.randomUUID())

  /**
   * Binds each table under test to its [SignalDatabase] accessor and registers a freshly generated
   * local ACI/PNI with [SignalStore]. The local ACI is also the identity the session assertions use.
   */
  @Before
  fun setup() {
    recipientDatabase = SignalDatabase.recipients
    identityDatabase = SignalDatabase.identities
    groupReceiptDatabase = SignalDatabase.groupReceipts
    groupDatabase = SignalDatabase.groups
    threadDatabase = SignalDatabase.threads
    smsDatabase = SignalDatabase.sms
    mmsDatabase = SignalDatabase.mms
    sessionDatabase = SignalDatabase.sessions
    mentionDatabase = SignalDatabase.mentions
    reactionDatabase = SignalDatabase.reactions
    notificationProfileDatabase = SignalDatabase.notificationProfiles
    distributionListDatabase = SignalDatabase.distributionLists

    SignalStore.account().setAci(localAci)
    SignalStore.account().setPni(localPni)
  }

  /**
   * High trust lets you merge two different users into one. You should prefer the ACI user.
   *
   * After the merge, everything that belonged to the e164-only recipient must be re-pointed at the ACI recipient:
   * - SMS and MMS messages
   * - the thread
   * - mentions
   * - group receipts
   * - reactions
   * - notification profile allow-lists
   * - distribution list membership
   *
   * The e164 identity record must also be gone.
   *
   * Not shown: dropping e164 sessions, etc.
   */
  @Test
  fun getAndPossiblyMerge_general() {
    // Setup
    val recipientIdAci: RecipientId = recipientDatabase.getOrInsertFromServiceId(ACI_A)
    val recipientIdE164: RecipientId = recipientDatabase.getOrInsertFromE164(E164_A)
    val recipientIdAciB: RecipientId = recipientDatabase.getOrInsertFromServiceId(ACI_B)

    val smsId1: Long = smsDatabase.insertMessageInbox(smsMessage(sender = recipientIdAci, time = 0, body = "0")).get().messageId
    val smsId2: Long = smsDatabase.insertMessageInbox(smsMessage(sender = recipientIdE164, time = 1, body = "1")).get().messageId
    val smsId3: Long = smsDatabase.insertMessageInbox(smsMessage(sender = recipientIdAci, time = 2, body = "2")).get().messageId

    val mmsId1: Long = mmsDatabase.insertSecureDecryptedMessageInbox(mmsMessage(sender = recipientIdAci, time = 3, body = "3"), -1).get().messageId
    val mmsId2: Long = mmsDatabase.insertSecureDecryptedMessageInbox(mmsMessage(sender = recipientIdE164, time = 4, body = "4"), -1).get().messageId
    val mmsId3: Long = mmsDatabase.insertSecureDecryptedMessageInbox(mmsMessage(sender = recipientIdAci, time = 5, body = "5"), -1).get().messageId

    val threadIdAci: Long = threadDatabase.getThreadIdFor(recipientIdAci)!!
    val threadIdE164: Long = threadDatabase.getThreadIdFor(recipientIdE164)!!
    // Sanity check: before the merge the two recipients must live in separate threads.
    assertNotEquals(threadIdAci, threadIdE164)

    mentionDatabase.insert(threadIdAci, mmsId1, listOf(Mention(recipientIdE164, 0, 1)))
    mentionDatabase.insert(threadIdE164, mmsId2, listOf(Mention(recipientIdAci, 0, 1)))

    groupReceiptDatabase.insert(listOf(recipientIdAci, recipientIdE164), mmsId1, 0, 3)

    val identityKeyAci: IdentityKey = identityKey(1)
    val identityKeyE164: IdentityKey = identityKey(2)

    identityDatabase.saveIdentity(ACI_A.toString(), recipientIdAci, identityKeyAci, IdentityDatabase.VerifiedStatus.VERIFIED, false, 0, false)
    identityDatabase.saveIdentity(E164_A, recipientIdE164, identityKeyE164, IdentityDatabase.VerifiedStatus.VERIFIED, false, 0, false)

    sessionDatabase.store(localAci, SignalProtocolAddress(ACI_A.toString(), 1), SessionRecord())

    reactionDatabase.addReaction(MessageId(smsId1, false), ReactionRecord("a", recipientIdAci, 1, 1))
    reactionDatabase.addReaction(MessageId(mmsId1, true), ReactionRecord("b", recipientIdE164, 1, 1))

    val profile1: NotificationProfile = notificationProfile(name = "Test")
    val profile2: NotificationProfile = notificationProfile(name = "Test2")

    notificationProfileDatabase.addAllowedRecipient(profileId = profile1.id, recipientId = recipientIdAci)
    notificationProfileDatabase.addAllowedRecipient(profileId = profile1.id, recipientId = recipientIdE164)
    notificationProfileDatabase.addAllowedRecipient(profileId = profile2.id, recipientId = recipientIdE164)
    notificationProfileDatabase.addAllowedRecipient(profileId = profile2.id, recipientId = recipientIdAciB)

    val distributionListId: DistributionListId = distributionListDatabase.createList("testlist", listOf(recipientIdE164, recipientIdAciB))!!

    // Merge
    val retrievedId: RecipientId = recipientDatabase.getAndPossiblyMerge(ACI_A, E164_A, true)
    val retrievedThreadId: Long = threadDatabase.getThreadIdFor(retrievedId)!!
    assertEquals(recipientIdAci, retrievedId)

    // Recipient validation
    val retrievedRecipient = Recipient.resolved(retrievedId)
    assertEquals(ACI_A, retrievedRecipient.requireServiceId())
    assertEquals(E164_A, retrievedRecipient.requireE164())

    // Resolving the old e164 id must now land on the merged recipient.
    val existingE164Recipient = Recipient.resolved(recipientIdE164)
    assertEquals(retrievedId, existingE164Recipient.id)

    // Thread validation
    assertEquals(threadIdAci, retrievedThreadId)
    assertNull(threadDatabase.getThreadIdFor(recipientIdE164))
    assertNull(threadDatabase.getThreadRecord(threadIdE164))

    // SMS validation
    val sms1: MessageRecord = smsDatabase.getMessageRecord(smsId1)!!
    val sms2: MessageRecord = smsDatabase.getMessageRecord(smsId2)!!
    val sms3: MessageRecord = smsDatabase.getMessageRecord(smsId3)!!

    assertEquals(retrievedId, sms1.recipient.id)
    assertEquals(retrievedId, sms2.recipient.id)
    assertEquals(retrievedId, sms3.recipient.id)

    assertEquals(retrievedThreadId, sms1.threadId)
    assertEquals(retrievedThreadId, sms2.threadId)
    assertEquals(retrievedThreadId, sms3.threadId)

    // MMS validation
    val mms1: MessageRecord = mmsDatabase.getMessageRecord(mmsId1)!!
    val mms2: MessageRecord = mmsDatabase.getMessageRecord(mmsId2)!!
    val mms3: MessageRecord = mmsDatabase.getMessageRecord(mmsId3)!!

    assertEquals(retrievedId, mms1.recipient.id)
    assertEquals(retrievedId, mms2.recipient.id)
    assertEquals(retrievedId, mms3.recipient.id)

    assertEquals(retrievedThreadId, mms1.threadId)
    assertEquals(retrievedThreadId, mms2.threadId)
    assertEquals(retrievedThreadId, mms3.threadId)

    // Mention validation
    val mention1: MentionModel = getMention(mmsId1)
    assertEquals(retrievedId, mention1.recipientId)
    assertEquals(retrievedThreadId, mention1.threadId)

    val mention2: MentionModel = getMention(mmsId2)
    assertEquals(retrievedId, mention2.recipientId)
    assertEquals(retrievedThreadId, mention2.threadId)

    // Group receipt validation: both receipts inserted for mmsId1 must now belong to the merged recipient.
    val groupReceipts = groupReceiptDatabase.getGroupReceiptInfo(mmsId1)
    assertEquals(2, groupReceipts.size)
    assertEquals(retrievedId, groupReceipts[0].recipientId)
    assertEquals(retrievedId, groupReceipts[1].recipientId)

    // Identity validation
    assertEquals(identityKeyAci, identityDatabase.getIdentityStoreRecord(ACI_A.toString())!!.identityKey)
    assertNull(identityDatabase.getIdentityStoreRecord(E164_A))

    // Session validation
    assertNotNull(sessionDatabase.load(localAci, SignalProtocolAddress(ACI_A.toString(), 1)))

    // Reaction validation
    val reactionsSms: List<ReactionRecord> = reactionDatabase.getReactions(MessageId(smsId1, false))
    val reactionsMms: List<ReactionRecord> = reactionDatabase.getReactions(MessageId(mmsId1, true))

    assertEquals(1, reactionsSms.size)
    assertEquals(ReactionRecord("a", recipientIdAci, 1, 1), reactionsSms[0])

    assertEquals(1, reactionsMms.size)
    assertEquals(ReactionRecord("b", recipientIdAci, 1, 1), reactionsMms[0])

    // Notification Profile validation
    val updatedProfile1: NotificationProfile = notificationProfileDatabase.getProfile(profile1.id)!!
    val updatedProfile2: NotificationProfile = notificationProfileDatabase.getProfile(profile2.id)!!

    assertThat("Notification Profile 1 should now only contain ACI $recipientIdAci", updatedProfile1.allowedMembers, Matchers.containsInAnyOrder(recipientIdAci))
    assertThat("Notification Profile 2 should now contain ACI A ($recipientIdAci) and ACI B ($recipientIdAciB)", updatedProfile2.allowedMembers, Matchers.containsInAnyOrder(recipientIdAci, recipientIdAciB))

    // Distribution List validation
    val updatedList: DistributionListRecord = distributionListDatabase.getList(distributionListId)!!

    assertThat("Distribution list should have updated $recipientIdE164 to $recipientIdAci", updatedList.members, Matchers.containsInAnyOrder(recipientIdAci, recipientIdAciB))
  }

  private val context: Application
    get() = InstrumentationRegistry.getInstrumentation().targetContext.applicationContext as Application

  /** Builds an [IncomingTextMessage] from [sender] whose sent/received timestamps all equal [time]. */
  private fun smsMessage(sender: RecipientId, time: Long = 0, body: String = "", groupId: Optional<GroupId> = Optional.empty()): IncomingTextMessage {
    return IncomingTextMessage(sender, 1, time, time, time, body, groupId, 0, true, null)
  }

  /** Builds an [IncomingMediaMessage] from [sender] with no attachments, whose timestamps all equal [time]. */
  private fun mmsMessage(sender: RecipientId, time: Long = 0, body: String = "", groupId: Optional<GroupId> = Optional.empty()): IncomingMediaMessage {
    return IncomingMediaMessage(sender, groupId, body, time, time, time, emptyList(), 0, 0, false, false, true, Optional.empty())
  }

  /**
   * Builds a deterministic 33-byte [IdentityKey]. Byte 0 is 0x05 (presumably the libsignal key-type prefix), byte 1
   * is [value], and every other byte is zero, so different [value]s yield distinct keys.
   */
  private fun identityKey(value: Byte): IdentityKey {
    val bytes = ByteArray(33)
    bytes[0] = 0x05
    bytes[1] = value
    return IdentityKey(bytes)
  }

  /** Builds a deterministic 32-byte [GroupMasterKey] whose first byte is [value] and the rest zero. */
  private fun groupMasterKey(value: Byte): GroupMasterKey {
    val bytes = ByteArray(32)
    bytes[0] = value
    return GroupMasterKey(bytes)
  }

  /** Builds a [DecryptedGroup] that contains one [DecryptedMember] per UUID in [members] and nothing else. */
  private fun decryptedGroup(members: Collection<UUID>): DecryptedGroup {
    return DecryptedGroup.newBuilder()
      .addAllMembers(members.map { DecryptedMember.newBuilder().setUuid(UuidUtil.toByteString(it)).build() })
      .build()
  }

  /**
   * Reads the first mention row stored for [messageId] directly from the mention table.
   *
   * @throws IllegalStateException if no mention row exists for [messageId]. Without this check the failure would
   *   surface as an unhelpful cursor/column error.
   */
  private fun getMention(messageId: Long): MentionModel {
    SignalDatabase.rawDatabase.rawQuery("SELECT * FROM ${MentionDatabase.TABLE_NAME} WHERE ${MentionDatabase.MESSAGE_ID} = $messageId").use { cursor ->
      check(cursor.moveToFirst()) { "No mention found for message $messageId" }
      return MentionModel(
        recipientId = RecipientId.from(CursorUtil.requireLong(cursor, MentionDatabase.RECIPIENT_ID)),
        threadId = CursorUtil.requireLong(cursor, MentionDatabase.THREAD_ID)
      )
    }
  }

  /**
   * Creates a notification profile called [name] and returns it.
   * Throws [ClassCastException] if creation does not return a `Success` result.
   */
  private fun notificationProfile(name: String): NotificationProfile {
    return (notificationProfileDatabase.createProfile(name = name, emoji = "", color = AvatarColor.A210, System.currentTimeMillis()) as NotificationProfileDatabase.NotificationProfileChangeResult.Success).notificationProfile
  }

  /** The normal mention model doesn't have a threadId, so we need to do it ourselves for the test */
  data class MentionModel(
    val recipientId: RecipientId,
    val threadId: Long
  )

  companion object {
    val ACI_A = ACI.from(UUID.fromString("3436efbe-5a76-47fa-a98a-7e72c948a82e"))
    val ACI_B = ACI.from(UUID.fromString("8de7f691-0b60-4a68-9cd9-ed2f8453f9ed"))

    const val E164_A = "+12221234567"
    const val E164_B = "+13331234567"
  }
}