From a7a5f2e8c6e0cb7e55f646f219f37cb7058b31f3 Mon Sep 17 00:00:00 2001 From: Cody Henthorne Date: Tue, 26 Jul 2022 16:31:55 -0400 Subject: [PATCH] Add batch identity checks to stories and share/forward flows. --- .../contacts/paged/ContactSearchMediator.kt | 5 +- .../contacts/paged/ContactSearchViewModel.kt | 16 +- .../contacts/paged/SafetyNumberRepository.kt | 114 ++++++++ .../database/identity/IdentityRecordList.java | 16 +- .../mediasend/v2/UntrustedRecords.kt | 3 +- .../v2/stories/ChooseGroupStoryBottomSheet.kt | 13 +- .../paged/SafetyNumberRepositoryTest.kt | 246 ++++++++++++++++++ .../api/services/ProfileService.java | 6 +- .../internal/push/IdentityCheckResponse.java | 23 +- 9 files changed, 418 insertions(+), 24 deletions(-) create mode 100644 app/src/main/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepository.kt create mode 100644 app/src/test/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepositoryTest.kt diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchMediator.kt b/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchMediator.kt index 4cb29c6b8..3ee2a31d3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchMediator.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchMediator.kt @@ -23,10 +23,11 @@ class ContactSearchMediator( selectionLimits: SelectionLimits, displayCheckBox: Boolean, mapStateToConfiguration: (ContactSearchState) -> ContactSearchConfiguration, - private val contactSelectionPreFilter: (View?, Set) -> Set = { _, s -> s } + private val contactSelectionPreFilter: (View?, Set) -> Set = { _, s -> s }, + performSafetyNumberChecks: Boolean = true ) { - private val viewModel: ContactSearchViewModel = ViewModelProvider(fragment, ContactSearchViewModel.Factory(selectionLimits, ContactSearchRepository())).get(ContactSearchViewModel::class.java) + private val viewModel: ContactSearchViewModel = ViewModelProvider(fragment, ContactSearchViewModel.Factory(selectionLimits, ContactSearchRepository(), performSafetyNumberChecks)).get(ContactSearchViewModel::class.java) init { diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchViewModel.kt b/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchViewModel.kt index e42c7a578..feb2e0c84 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchViewModel.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/ContactSearchViewModel.kt @@ -21,7 +21,9 @@ import org.whispersystems.signalservice.api.util.Preconditions */ class ContactSearchViewModel( private val selectionLimits: SelectionLimits, - private val contactSearchRepository: ContactSearchRepository + private val contactSearchRepository: ContactSearchRepository, + private val performSafetyNumberChecks: Boolean, + private val safetyNumberRepository: SafetyNumberRepository = SafetyNumberRepository(), ) : ViewModel() { private val disposables = CompositeDisposable() @@ -75,6 +77,10 @@ class ContactSearchViewModel( return@subscribe } + if (performSafetyNumberChecks) { + safetyNumberRepository.batchSafetyNumberCheck(newSelectionEntries) + } + selectionStore.update { state -> state + newSelectionEntries } } } @@ -123,9 +129,13 @@ class ContactSearchViewModel( controller.value?.onDataInvalidated() } - class Factory(private val selectionLimits: SelectionLimits, private val repository: ContactSearchRepository) : ViewModelProvider.Factory { + class Factory( + private val selectionLimits: SelectionLimits, + private val repository: ContactSearchRepository, + private val performSafetyNumberChecks: Boolean + ) : ViewModelProvider.Factory { override fun create(modelClass: Class): T { - return modelClass.cast(ContactSearchViewModel(selectionLimits, repository)) as T + return modelClass.cast(ContactSearchViewModel(selectionLimits, repository, performSafetyNumberChecks)) as T } } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepository.kt new file mode 100644 index 000000000..bb813d3d6 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepository.kt @@ -0,0 +1,114 @@ +package org.thoughtcrime.securesms.contacts.paged + +import androidx.annotation.VisibleForTesting +import io.reactivex.rxjava3.core.Single +import org.signal.core.util.concurrent.SignalExecutors +import org.signal.core.util.logging.Log +import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore +import org.thoughtcrime.securesms.dependencies.ApplicationDependencies +import org.thoughtcrime.securesms.recipients.Recipient +import org.thoughtcrime.securesms.recipients.RecipientId +import org.thoughtcrime.securesms.util.IdentityUtil +import org.thoughtcrime.securesms.util.Stopwatch +import org.whispersystems.signalservice.api.services.ProfileService +import org.whispersystems.signalservice.internal.ServiceResponseProcessor +import org.whispersystems.signalservice.internal.push.IdentityCheckResponse +import java.util.concurrent.TimeUnit + +/** + * Generic repository for interacting with safety numbers and fetch new ones. + */ +class SafetyNumberRepository( + private val profileService: ProfileService = ApplicationDependencies.getProfileService(), + private val aciIdentityStore: SignalIdentityKeyStore = ApplicationDependencies.getProtocolStore().aci().identities() +) { + + private val recentlyFetched: MutableMap = HashMap() + + fun batchSafetyNumberCheck(newSelectionEntries: List) { + SignalExecutors.UNBOUNDED.execute { batchSafetyNumberCheckSync(newSelectionEntries) } + } + + @Suppress("UNCHECKED_CAST") + @VisibleForTesting + fun batchSafetyNumberCheckSync(newSelectionEntries: List, now: Long = System.currentTimeMillis(), batchSize: Int = MAX_BATCH_SIZE) { + val stopwatch = Stopwatch("batch-snc") + val recipientIds: Set = newSelectionEntries.flattenToRecipientIds() + stopwatch.split("recipient-ids") + + val recentIds = recentlyFetched.filter { (_, timestamp) -> (now - timestamp) < RECENT_TIME_WINDOW }.keys + val recipients = Recipient.resolvedList(recipientIds - recentIds).filter { it.hasServiceId() } + stopwatch.split("recipient-resolve") + + if (recipients.isNotEmpty()) { + Log.i(TAG, "Checking on ${recipients.size} identities...") + val requests: List>> = recipients.chunked(batchSize) { it.createBatchRequestSingle() } + stopwatch.split("requests") + + val aciKeyPairs: List = Single.zip(requests) { responses -> + responses + .map { it as List } + .flatten() + }.blockingGet() + + stopwatch.split("batch-fetches") + + if (aciKeyPairs.isEmpty()) { + Log.d(TAG, "No identity key mismatches") + } else { + aciKeyPairs + .filter { it.aci != null && it.identityKey != null } + .forEach { IdentityUtil.saveIdentity(it.aci.toString(), it.identityKey) } + } + recentlyFetched += recipients.associate { it.id to now } + stopwatch.split("saving-identities") + } + stopwatch.stop(TAG) + } + + private fun List.flattenToRecipientIds(): Set { + return this + .map { + when (it) { + is ContactSearchKey.RecipientSearchKey.KnownRecipient -> { + val recipient = Recipient.resolved(it.recipientId) + if (recipient.isGroup) { + recipient.participantIds + } else { + listOf(it.recipientId) + } + } + is ContactSearchKey.RecipientSearchKey.Story -> Recipient.resolved(it.recipientId).participantIds + else -> throw AssertionError("Invalid contact selection $it") + } + } + .flatten() + .toMutableSet() + .apply { remove(Recipient.self().id) } + } + + private fun List.createBatchRequestSingle(): Single> { + return profileService + .performIdentityCheck( + mapNotNull { r -> + val identityRecord = aciIdentityStore.getIdentityRecord(r.id) + if (identityRecord.isPresent) { + r.requireServiceId() to identityRecord.get().identityKey + } else { + null + } + }.associate { it } + ) + .map { ServiceResponseProcessor.DefaultProcessor(it).resultOrThrow.aciKeyPairs ?: emptyList() } + .onErrorReturn { t -> + Log.w(TAG, "Unable to fetch identities", t) + emptyList() + } + } + + companion object { + private val TAG = Log.tag(SafetyNumberRepository::class.java) + private val RECENT_TIME_WINDOW = TimeUnit.SECONDS.toMillis(30) + private const val MAX_BATCH_SIZE = 1000 + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/identity/IdentityRecordList.java b/app/src/main/java/org/thoughtcrime/securesms/database/identity/IdentityRecordList.java index 6534ae0e2..9b80620b5 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/identity/IdentityRecordList.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/identity/IdentityRecordList.java @@ -16,6 +16,8 @@ public final class IdentityRecordList { public static final IdentityRecordList EMPTY = new IdentityRecordList(Collections.emptyList()); + private static final long DEFAULT_UNTRUSTED_WINDOW = TimeUnit.SECONDS.toMillis(5); + private final List identityRecords; private final boolean isVerified; private final boolean isUnverified; @@ -78,7 +80,7 @@ public final class IdentityRecordList { continue; } - if (isUntrusted(identityRecord)) { + if (isUntrusted(identityRecord, DEFAULT_UNTRUSTED_WINDOW)) { return true; } } @@ -87,10 +89,14 @@ public final class IdentityRecordList { } public @NonNull List getUntrustedRecords() { + return getUntrustedRecords(DEFAULT_UNTRUSTED_WINDOW); + } + + public @NonNull List getUntrustedRecords(long untrustedWindowMillis) { List results = new ArrayList<>(identityRecords.size()); for (IdentityRecord identityRecord : identityRecords) { - if (isUntrusted(identityRecord)) { + if (isUntrusted(identityRecord, untrustedWindowMillis)) { results.add(identityRecord); } } @@ -102,7 +108,7 @@ public final class IdentityRecordList { List untrusted = new ArrayList<>(identityRecords.size()); for (IdentityRecord identityRecord : identityRecords) { - if (isUntrusted(identityRecord)) { + if (isUntrusted(identityRecord, DEFAULT_UNTRUSTED_WINDOW)) { untrusted.add(Recipient.resolved(identityRecord.getRecipientId())); } } @@ -134,9 +140,9 @@ public final class IdentityRecordList { return unverified; } - private static boolean isUntrusted(@NonNull IdentityRecord identityRecord) { + private static boolean isUntrusted(@NonNull IdentityRecord identityRecord, long untrustedWindowMillis) { return !identityRecord.isApprovedNonBlocking() && - System.currentTimeMillis() - identityRecord.getTimestamp() < TimeUnit.SECONDS.toMillis(5); + System.currentTimeMillis() - identityRecord.getTimestamp() < untrustedWindowMillis; } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/UntrustedRecords.kt b/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/UntrustedRecords.kt index 0941721ae..0800bf720 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/UntrustedRecords.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/UntrustedRecords.kt @@ -10,6 +10,7 @@ import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.model.IdentityRecord import org.thoughtcrime.securesms.dependencies.ApplicationDependencies import org.thoughtcrime.securesms.recipients.Recipient +import java.util.concurrent.TimeUnit object UntrustedRecords { @@ -41,7 +42,7 @@ object UntrustedRecords { } .flatten() - return ApplicationDependencies.getProtocolStore().aci().identities().getIdentityRecords(recipients).untrustedRecords + return ApplicationDependencies.getProtocolStore().aci().identities().getIdentityRecords(recipients).getUntrustedRecords(TimeUnit.SECONDS.toMillis(30)) } class UntrustedRecordsException(val untrustedRecords: List, val destinations: Set) : Throwable() diff --git a/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/stories/ChooseGroupStoryBottomSheet.kt b/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/stories/ChooseGroupStoryBottomSheet.kt index 7cec89245..039ea2c01 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/stories/ChooseGroupStoryBottomSheet.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/mediasend/v2/stories/ChooseGroupStoryBottomSheet.kt @@ -62,11 +62,11 @@ class ChooseGroupStoryBottomSheet : FixedRoundedCornerBottomSheetDialogFragment( val contactRecycler: RecyclerView = view.findViewById(R.id.contact_recycler) mediator = ContactSearchMediator( - this, - contactRecycler, - FeatureFlags.shareSelectionLimit(), - true, - { state -> + fragment = this, + recyclerView = contactRecycler, + selectionLimits = FeatureFlags.shareSelectionLimit(), + displayCheckBox = true, + mapStateToConfiguration = { state -> ContactSearchConfiguration.build { query = state.query @@ -77,7 +77,8 @@ class ChooseGroupStoryBottomSheet : FixedRoundedCornerBottomSheetDialogFragment( ) ) } - } + }, + performSafetyNumberChecks = false ) mediator.getSelectionState().observe(viewLifecycleOwner) { state -> diff --git a/app/src/test/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepositoryTest.kt b/app/src/test/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepositoryTest.kt new file mode 100644 index 000000000..4747805a8 --- /dev/null +++ b/app/src/test/java/org/thoughtcrime/securesms/contacts/paged/SafetyNumberRepositoryTest.kt @@ -0,0 +1,246 @@ +package org.thoughtcrime.securesms.contacts.paged + +import android.app.Application +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.BeforeClass +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Mock +import org.mockito.MockedStatic +import org.mockito.internal.configuration.plugins.Plugins +import org.mockito.internal.junit.JUnitRule +import org.mockito.junit.MockitoRule +import org.mockito.kotlin.any +import org.mockito.kotlin.argThat +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever +import org.mockito.quality.Strictness +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import org.signal.core.util.logging.Log +import org.thoughtcrime.securesms.crypto.IdentityKeyUtil +import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore +import org.thoughtcrime.securesms.database.IdentityDatabase +import org.thoughtcrime.securesms.database.RecipientDatabaseTestUtils +import org.thoughtcrime.securesms.database.model.IdentityRecord +import org.thoughtcrime.securesms.recipients.Recipient +import org.thoughtcrime.securesms.recipients.RecipientId +import org.thoughtcrime.securesms.testutil.SystemOutLogger +import org.thoughtcrime.securesms.util.IdentityUtil +import org.whispersystems.signalservice.api.push.ACI +import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException +import org.whispersystems.signalservice.api.services.ProfileService +import org.whispersystems.signalservice.internal.ServiceResponse +import org.whispersystems.signalservice.internal.push.IdentityCheckResponse +import java.io.IOException +import java.util.Optional +import java.util.concurrent.TimeUnit + +@RunWith(RobolectricTestRunner::class) +@Config(application = Application::class) +class SafetyNumberRepositoryTest { + + @Rule + @JvmField + val mockitoRule: MockitoRule = JUnitRule(Plugins.getMockitoLogger(), Strictness.STRICT_STUBS) + + @Mock + lateinit var profileService: ProfileService + + @Mock(lenient = true) + lateinit var aciIdentityStore: SignalIdentityKeyStore + + @Mock + lateinit var staticIdentityUtil: MockedStatic + + @Mock + lateinit var staticRecipient: MockedStatic + + private var now: Long = System.currentTimeMillis() + + private lateinit var recipientPool: MutableList + private lateinit var identityPool: MutableMap + + private lateinit var repository: SafetyNumberRepository + + companion object { + @BeforeClass + @JvmStatic + fun setUpClass() { + Log.initialize(SystemOutLogger()) + } + } + + @Before + fun setUp() { + now = System.currentTimeMillis() + repository = SafetyNumberRepository(profileService, aciIdentityStore) + + recipientPool = mutableListOf() + identityPool = mutableMapOf() + + for (id in 1L until 12) { + val recipient = RecipientDatabaseTestUtils.createRecipient(resolved = true, recipientId = RecipientId.from(id)) + staticRecipient.`when` { Recipient.resolved(RecipientId.from(id)) }.thenReturn(recipient) + recipientPool.add(recipient) + + val record = IdentityRecord( + recipientId = recipient.id, + identityKey = IdentityKeyUtil.generateIdentityKeyPair().publicKey, + verifiedStatus = IdentityDatabase.VerifiedStatus.DEFAULT, + firstUse = false, + timestamp = 0, + nonblockingApproval = false + ) + whenever(aciIdentityStore.getIdentityRecord(recipient.id)).thenReturn(Optional.of(record)) + identityPool[recipient] = record + } + + staticRecipient.`when` { Recipient.self() }.thenReturn(recipientPool[0]) + } + + /** + * Batch request for a current identity key should return an empty list and not perform any identity key updates. + */ + @Test + fun batchSafetyNumberCheckSync_batchOf1_noChanges() { + val other = recipientPool[1] + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf()), 200, ""))) + + repository.batchSafetyNumberCheckSync(keys, now) + + staticIdentityUtil.verifyNoInteractions() + } + + /** + * Batch request for an out-of-date identity key should return the new identity key and update the store. + */ + @Test + fun batchSafetyNumberCheckSync_batchOf1_oneChange() { + val other = recipientPool[1] + val otherAci = ACI.from(other.requireServiceId()) + val otherNewIdentityKey = IdentityKeyUtil.generateIdentityKeyPair().publicKey + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf(IdentityCheckResponse.AciIdentityPair(otherAci, otherNewIdentityKey))), 200, ""))) + + repository.batchSafetyNumberCheckSync(keys, now) + + staticIdentityUtil.verify { IdentityUtil.saveIdentity(otherAci.toString(), otherNewIdentityKey) } + staticIdentityUtil.verifyNoMoreInteractions() + } + + /** + * Batch request for an out-of-date identity key should return the new identity key and update the store. + */ + @Test + fun batchSafetyNumberCheckSync_batchOf2_oneChange() { + val other = recipientPool[1] + val secondOther = recipientPool[2] + val otherAci = ACI.from(other.requireServiceId()) + val otherNewIdentityKey = IdentityKeyUtil.generateIdentityKeyPair().publicKey + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id), ContactSearchKey.RecipientSearchKey.KnownRecipient(secondOther.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other, secondOther)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey, secondOther.requireServiceId() to identityPool[secondOther]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf(IdentityCheckResponse.AciIdentityPair(otherAci, otherNewIdentityKey))), 200, ""))) + + repository.batchSafetyNumberCheckSync(keys, now) + + staticIdentityUtil.verify { IdentityUtil.saveIdentity(otherAci.toString(), otherNewIdentityKey) } + staticIdentityUtil.verifyNoMoreInteractions() + } + + /** + * Batch request for a current identity key should previously checked should abort checking. + */ + @Test + fun batchSafetyNumberCheckSync_batchOf1_abortOnPriorRecentCheck() { + val other = recipientPool[1] + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf()), 200, ""))) + + repository.batchSafetyNumberCheckSync(keys, now) + verify(profileService, times(1)).performIdentityCheck(any()) + repository.batchSafetyNumberCheckSync(keys, now + TimeUnit.SECONDS.toMillis(10)) + verify(profileService, times(1)).performIdentityCheck(any()) + repository.batchSafetyNumberCheckSync(keys, now + TimeUnit.SECONDS.toMillis(31)) + verify(profileService, times(2)).performIdentityCheck(any()) + + staticIdentityUtil.verifyNoInteractions() + } + + /** + * Batch request for a current identity keys should return an empty list and not perform any identity key updates. + */ + @Test + fun batchSafetyNumberCheckSync_batchOf10WithSmallBatchSize_noChanges() { + val keys = recipientPool.map { ContactSearchKey.RecipientSearchKey.KnownRecipient(it.id) } + val others = recipientPool.subList(1, recipientPool.lastIndex) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(others.map { it.id }) }) }.thenReturn(others) + + for (chunk in others.chunked(2)) { + whenever(profileService.performIdentityCheck(chunk.associate { it.requireServiceId() to identityPool[it]!!.identityKey })) + .thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf()), 200, ""))) + } + + repository.batchSafetyNumberCheckSync(keys, now, 2) + + staticIdentityUtil.verifyNoInteractions() + } + + @Test + fun batchSafetyNumberCheckSync_serverError() { + val other = recipientPool[1] + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forApplicationError(NonSuccessfulResponseCodeException(400), 400, ""))) + + repository.batchSafetyNumberCheckSync(keys, now) + + staticIdentityUtil.verifyNoInteractions() + } + + @Test + fun batchSafetyNumberCheckSync_networkError() { + val other = recipientPool[1] + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forUnknownError(IOException()))) + + repository.batchSafetyNumberCheckSync(keys, now) + + staticIdentityUtil.verifyNoInteractions() + } + + @Test + fun batchSafetyNumberCheckSync_badJson() { + val other = recipientPool[1] + val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id)) + + staticRecipient.`when`> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other)) + whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey))) + .thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(), 200, ""))) + + repository.batchSafetyNumberCheckSync(keys, now) + + staticIdentityUtil.verifyNoInteractions() + } +} diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java index dace6a7e5..6ab6fcc30 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java @@ -112,7 +112,7 @@ public final class ProfileService { .onErrorReturn(ServiceResponse::forUnknownError); } - public @NonNull Single> performIdentityCheck(@Nonnull Map aciIdentityKeyMap, @Nonnull Optional unidentifiedAccess) { + public @NonNull Single> performIdentityCheck(@Nonnull Map aciIdentityKeyMap) { List aciKeyPairs = aciIdentityKeyMap.entrySet() .stream() .map(e -> new AciFingerprintPair(e.getKey(), e.getValue())) @@ -129,9 +129,9 @@ public final class ProfileService { ResponseMapper responseMapper = DefaultResponseMapper.getDefault(IdentityCheckResponse.class); - return signalWebSocket.request(builder.build(), unidentifiedAccess) + return signalWebSocket.request(builder.build(), Optional.empty()) .map(responseMapper::map) - .onErrorResumeNext(t -> performIdentityCheckRestFallback(request, unidentifiedAccess, responseMapper)) + .onErrorResumeNext(t -> performIdentityCheckRestFallback(request, Optional.empty(), responseMapper)) .onErrorReturn(ServiceResponse::forUnknownError); } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/IdentityCheckResponse.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/IdentityCheckResponse.java index c755f841a..42515e767 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/IdentityCheckResponse.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/push/IdentityCheckResponse.java @@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import org.signal.libsignal.protocol.IdentityKey; -import org.whispersystems.signalservice.api.push.ServiceId; +import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.internal.util.JsonUtil; import java.util.List; @@ -16,6 +16,13 @@ public class IdentityCheckResponse { @JsonProperty("elements") private List aciKeyPairs; + public IdentityCheckResponse() {} + + // Visible for testing + public IdentityCheckResponse(List aciKeyPairs) { + this.aciKeyPairs = aciKeyPairs; + } + public @Nullable List getAciKeyPairs() { return aciKeyPairs; } @@ -23,14 +30,22 @@ public class IdentityCheckResponse { public static final class AciIdentityPair { @JsonProperty - @JsonDeserialize(using = JsonUtil.ServiceIdDeserializer.class) - private ServiceId aci; + @JsonDeserialize(using = JsonUtil.AciDeserializer.class) + private ACI aci; @JsonProperty @JsonDeserialize(using = JsonUtil.IdentityKeyDeserializer.class) private IdentityKey identityKey; - public @Nullable ServiceId getAci() { + public AciIdentityPair() {} + + // Visible for testing + public AciIdentityPair(ACI aci, IdentityKey identityKey) { + this.aci = aci; + this.identityKey = identityKey; + } + + public @Nullable ACI getAci() { return aci; }