kopia lustrzana https://github.com/ryukoposting/Signal-Android
Add batch identity checks to stories and share/forward flows.
rodzic
87cb2d6bf8
commit
a7a5f2e8c6
|
@ -23,10 +23,11 @@ class ContactSearchMediator(
|
||||||
selectionLimits: SelectionLimits,
|
selectionLimits: SelectionLimits,
|
||||||
displayCheckBox: Boolean,
|
displayCheckBox: Boolean,
|
||||||
mapStateToConfiguration: (ContactSearchState) -> ContactSearchConfiguration,
|
mapStateToConfiguration: (ContactSearchState) -> ContactSearchConfiguration,
|
||||||
private val contactSelectionPreFilter: (View?, Set<ContactSearchKey>) -> Set<ContactSearchKey> = { _, s -> s }
|
private val contactSelectionPreFilter: (View?, Set<ContactSearchKey>) -> Set<ContactSearchKey> = { _, s -> s },
|
||||||
|
performSafetyNumberChecks: Boolean = true
|
||||||
) {
|
) {
|
||||||
|
|
||||||
private val viewModel: ContactSearchViewModel = ViewModelProvider(fragment, ContactSearchViewModel.Factory(selectionLimits, ContactSearchRepository())).get(ContactSearchViewModel::class.java)
|
private val viewModel: ContactSearchViewModel = ViewModelProvider(fragment, ContactSearchViewModel.Factory(selectionLimits, ContactSearchRepository(), performSafetyNumberChecks)).get(ContactSearchViewModel::class.java)
|
||||||
|
|
||||||
init {
|
init {
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,9 @@ import org.whispersystems.signalservice.api.util.Preconditions
|
||||||
*/
|
*/
|
||||||
class ContactSearchViewModel(
|
class ContactSearchViewModel(
|
||||||
private val selectionLimits: SelectionLimits,
|
private val selectionLimits: SelectionLimits,
|
||||||
private val contactSearchRepository: ContactSearchRepository
|
private val contactSearchRepository: ContactSearchRepository,
|
||||||
|
private val performSafetyNumberChecks: Boolean,
|
||||||
|
private val safetyNumberRepository: SafetyNumberRepository = SafetyNumberRepository(),
|
||||||
) : ViewModel() {
|
) : ViewModel() {
|
||||||
|
|
||||||
private val disposables = CompositeDisposable()
|
private val disposables = CompositeDisposable()
|
||||||
|
@ -75,6 +77,10 @@ class ContactSearchViewModel(
|
||||||
return@subscribe
|
return@subscribe
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (performSafetyNumberChecks) {
|
||||||
|
safetyNumberRepository.batchSafetyNumberCheck(newSelectionEntries)
|
||||||
|
}
|
||||||
|
|
||||||
selectionStore.update { state -> state + newSelectionEntries }
|
selectionStore.update { state -> state + newSelectionEntries }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,9 +129,13 @@ class ContactSearchViewModel(
|
||||||
controller.value?.onDataInvalidated()
|
controller.value?.onDataInvalidated()
|
||||||
}
|
}
|
||||||
|
|
||||||
class Factory(private val selectionLimits: SelectionLimits, private val repository: ContactSearchRepository) : ViewModelProvider.Factory {
|
class Factory(
|
||||||
|
private val selectionLimits: SelectionLimits,
|
||||||
|
private val repository: ContactSearchRepository,
|
||||||
|
private val performSafetyNumberChecks: Boolean
|
||||||
|
) : ViewModelProvider.Factory {
|
||||||
override fun <T : ViewModel> create(modelClass: Class<T>): T {
|
override fun <T : ViewModel> create(modelClass: Class<T>): T {
|
||||||
return modelClass.cast(ContactSearchViewModel(selectionLimits, repository)) as T
|
return modelClass.cast(ContactSearchViewModel(selectionLimits, repository, performSafetyNumberChecks)) as T
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
package org.thoughtcrime.securesms.contacts.paged
|
||||||
|
|
||||||
|
import androidx.annotation.VisibleForTesting
|
||||||
|
import io.reactivex.rxjava3.core.Single
|
||||||
|
import org.signal.core.util.concurrent.SignalExecutors
|
||||||
|
import org.signal.core.util.logging.Log
|
||||||
|
import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore
|
||||||
|
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
|
||||||
|
import org.thoughtcrime.securesms.recipients.Recipient
|
||||||
|
import org.thoughtcrime.securesms.recipients.RecipientId
|
||||||
|
import org.thoughtcrime.securesms.util.IdentityUtil
|
||||||
|
import org.thoughtcrime.securesms.util.Stopwatch
|
||||||
|
import org.whispersystems.signalservice.api.services.ProfileService
|
||||||
|
import org.whispersystems.signalservice.internal.ServiceResponseProcessor
|
||||||
|
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse
|
||||||
|
import java.util.concurrent.TimeUnit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generic repository for interacting with safety numbers and fetch new ones.
|
||||||
|
*/
|
||||||
|
class SafetyNumberRepository(
|
||||||
|
private val profileService: ProfileService = ApplicationDependencies.getProfileService(),
|
||||||
|
private val aciIdentityStore: SignalIdentityKeyStore = ApplicationDependencies.getProtocolStore().aci().identities()
|
||||||
|
) {
|
||||||
|
|
||||||
|
private val recentlyFetched: MutableMap<RecipientId, Long> = HashMap()
|
||||||
|
|
||||||
|
fun batchSafetyNumberCheck(newSelectionEntries: List<ContactSearchKey>) {
|
||||||
|
SignalExecutors.UNBOUNDED.execute { batchSafetyNumberCheckSync(newSelectionEntries) }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
@VisibleForTesting
|
||||||
|
fun batchSafetyNumberCheckSync(newSelectionEntries: List<ContactSearchKey>, now: Long = System.currentTimeMillis(), batchSize: Int = MAX_BATCH_SIZE) {
|
||||||
|
val stopwatch = Stopwatch("batch-snc")
|
||||||
|
val recipientIds: Set<RecipientId> = newSelectionEntries.flattenToRecipientIds()
|
||||||
|
stopwatch.split("recipient-ids")
|
||||||
|
|
||||||
|
val recentIds = recentlyFetched.filter { (_, timestamp) -> (now - timestamp) < RECENT_TIME_WINDOW }.keys
|
||||||
|
val recipients = Recipient.resolvedList(recipientIds - recentIds).filter { it.hasServiceId() }
|
||||||
|
stopwatch.split("recipient-resolve")
|
||||||
|
|
||||||
|
if (recipients.isNotEmpty()) {
|
||||||
|
Log.i(TAG, "Checking on ${recipients.size} identities...")
|
||||||
|
val requests: List<Single<List<IdentityCheckResponse.AciIdentityPair>>> = recipients.chunked(batchSize) { it.createBatchRequestSingle() }
|
||||||
|
stopwatch.split("requests")
|
||||||
|
|
||||||
|
val aciKeyPairs: List<IdentityCheckResponse.AciIdentityPair> = Single.zip(requests) { responses ->
|
||||||
|
responses
|
||||||
|
.map { it as List<IdentityCheckResponse.AciIdentityPair> }
|
||||||
|
.flatten()
|
||||||
|
}.blockingGet()
|
||||||
|
|
||||||
|
stopwatch.split("batch-fetches")
|
||||||
|
|
||||||
|
if (aciKeyPairs.isEmpty()) {
|
||||||
|
Log.d(TAG, "No identity key mismatches")
|
||||||
|
} else {
|
||||||
|
aciKeyPairs
|
||||||
|
.filter { it.aci != null && it.identityKey != null }
|
||||||
|
.forEach { IdentityUtil.saveIdentity(it.aci.toString(), it.identityKey) }
|
||||||
|
}
|
||||||
|
recentlyFetched += recipients.associate { it.id to now }
|
||||||
|
stopwatch.split("saving-identities")
|
||||||
|
}
|
||||||
|
stopwatch.stop(TAG)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun List<ContactSearchKey>.flattenToRecipientIds(): Set<RecipientId> {
|
||||||
|
return this
|
||||||
|
.map {
|
||||||
|
when (it) {
|
||||||
|
is ContactSearchKey.RecipientSearchKey.KnownRecipient -> {
|
||||||
|
val recipient = Recipient.resolved(it.recipientId)
|
||||||
|
if (recipient.isGroup) {
|
||||||
|
recipient.participantIds
|
||||||
|
} else {
|
||||||
|
listOf(it.recipientId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is ContactSearchKey.RecipientSearchKey.Story -> Recipient.resolved(it.recipientId).participantIds
|
||||||
|
else -> throw AssertionError("Invalid contact selection $it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.flatten()
|
||||||
|
.toMutableSet()
|
||||||
|
.apply { remove(Recipient.self().id) }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun List<Recipient>.createBatchRequestSingle(): Single<List<IdentityCheckResponse.AciIdentityPair>> {
|
||||||
|
return profileService
|
||||||
|
.performIdentityCheck(
|
||||||
|
mapNotNull { r ->
|
||||||
|
val identityRecord = aciIdentityStore.getIdentityRecord(r.id)
|
||||||
|
if (identityRecord.isPresent) {
|
||||||
|
r.requireServiceId() to identityRecord.get().identityKey
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}.associate { it }
|
||||||
|
)
|
||||||
|
.map { ServiceResponseProcessor.DefaultProcessor(it).resultOrThrow.aciKeyPairs ?: emptyList() }
|
||||||
|
.onErrorReturn { t ->
|
||||||
|
Log.w(TAG, "Unable to fetch identities", t)
|
||||||
|
emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val TAG = Log.tag(SafetyNumberRepository::class.java)
|
||||||
|
private val RECENT_TIME_WINDOW = TimeUnit.SECONDS.toMillis(30)
|
||||||
|
private const val MAX_BATCH_SIZE = 1000
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,8 @@ public final class IdentityRecordList {
|
||||||
|
|
||||||
public static final IdentityRecordList EMPTY = new IdentityRecordList(Collections.emptyList());
|
public static final IdentityRecordList EMPTY = new IdentityRecordList(Collections.emptyList());
|
||||||
|
|
||||||
|
private static final long DEFAULT_UNTRUSTED_WINDOW = TimeUnit.SECONDS.toMillis(5);
|
||||||
|
|
||||||
private final List<IdentityRecord> identityRecords;
|
private final List<IdentityRecord> identityRecords;
|
||||||
private final boolean isVerified;
|
private final boolean isVerified;
|
||||||
private final boolean isUnverified;
|
private final boolean isUnverified;
|
||||||
|
@ -78,7 +80,7 @@ public final class IdentityRecordList {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isUntrusted(identityRecord)) {
|
if (isUntrusted(identityRecord, DEFAULT_UNTRUSTED_WINDOW)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -87,10 +89,14 @@ public final class IdentityRecordList {
|
||||||
}
|
}
|
||||||
|
|
||||||
public @NonNull List<IdentityRecord> getUntrustedRecords() {
|
public @NonNull List<IdentityRecord> getUntrustedRecords() {
|
||||||
|
return getUntrustedRecords(DEFAULT_UNTRUSTED_WINDOW);
|
||||||
|
}
|
||||||
|
|
||||||
|
public @NonNull List<IdentityRecord> getUntrustedRecords(long untrustedWindowMillis) {
|
||||||
List<IdentityRecord> results = new ArrayList<>(identityRecords.size());
|
List<IdentityRecord> results = new ArrayList<>(identityRecords.size());
|
||||||
|
|
||||||
for (IdentityRecord identityRecord : identityRecords) {
|
for (IdentityRecord identityRecord : identityRecords) {
|
||||||
if (isUntrusted(identityRecord)) {
|
if (isUntrusted(identityRecord, untrustedWindowMillis)) {
|
||||||
results.add(identityRecord);
|
results.add(identityRecord);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,7 +108,7 @@ public final class IdentityRecordList {
|
||||||
List<Recipient> untrusted = new ArrayList<>(identityRecords.size());
|
List<Recipient> untrusted = new ArrayList<>(identityRecords.size());
|
||||||
|
|
||||||
for (IdentityRecord identityRecord : identityRecords) {
|
for (IdentityRecord identityRecord : identityRecords) {
|
||||||
if (isUntrusted(identityRecord)) {
|
if (isUntrusted(identityRecord, DEFAULT_UNTRUSTED_WINDOW)) {
|
||||||
untrusted.add(Recipient.resolved(identityRecord.getRecipientId()));
|
untrusted.add(Recipient.resolved(identityRecord.getRecipientId()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -134,9 +140,9 @@ public final class IdentityRecordList {
|
||||||
return unverified;
|
return unverified;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean isUntrusted(@NonNull IdentityRecord identityRecord) {
|
private static boolean isUntrusted(@NonNull IdentityRecord identityRecord, long untrustedWindowMillis) {
|
||||||
return !identityRecord.isApprovedNonBlocking() &&
|
return !identityRecord.isApprovedNonBlocking() &&
|
||||||
System.currentTimeMillis() - identityRecord.getTimestamp() < TimeUnit.SECONDS.toMillis(5);
|
System.currentTimeMillis() - identityRecord.getTimestamp() < untrustedWindowMillis;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import org.thoughtcrime.securesms.database.SignalDatabase
|
||||||
import org.thoughtcrime.securesms.database.model.IdentityRecord
|
import org.thoughtcrime.securesms.database.model.IdentityRecord
|
||||||
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
|
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
|
||||||
import org.thoughtcrime.securesms.recipients.Recipient
|
import org.thoughtcrime.securesms.recipients.Recipient
|
||||||
|
import java.util.concurrent.TimeUnit
|
||||||
|
|
||||||
object UntrustedRecords {
|
object UntrustedRecords {
|
||||||
|
|
||||||
|
@ -41,7 +42,7 @@ object UntrustedRecords {
|
||||||
}
|
}
|
||||||
.flatten()
|
.flatten()
|
||||||
|
|
||||||
return ApplicationDependencies.getProtocolStore().aci().identities().getIdentityRecords(recipients).untrustedRecords
|
return ApplicationDependencies.getProtocolStore().aci().identities().getIdentityRecords(recipients).getUntrustedRecords(TimeUnit.SECONDS.toMillis(30))
|
||||||
}
|
}
|
||||||
|
|
||||||
class UntrustedRecordsException(val untrustedRecords: List<IdentityRecord>, val destinations: Set<ContactSearchKey.RecipientSearchKey>) : Throwable()
|
class UntrustedRecordsException(val untrustedRecords: List<IdentityRecord>, val destinations: Set<ContactSearchKey.RecipientSearchKey>) : Throwable()
|
||||||
|
|
|
@ -62,11 +62,11 @@ class ChooseGroupStoryBottomSheet : FixedRoundedCornerBottomSheetDialogFragment(
|
||||||
|
|
||||||
val contactRecycler: RecyclerView = view.findViewById(R.id.contact_recycler)
|
val contactRecycler: RecyclerView = view.findViewById(R.id.contact_recycler)
|
||||||
mediator = ContactSearchMediator(
|
mediator = ContactSearchMediator(
|
||||||
this,
|
fragment = this,
|
||||||
contactRecycler,
|
recyclerView = contactRecycler,
|
||||||
FeatureFlags.shareSelectionLimit(),
|
selectionLimits = FeatureFlags.shareSelectionLimit(),
|
||||||
true,
|
displayCheckBox = true,
|
||||||
{ state ->
|
mapStateToConfiguration = { state ->
|
||||||
ContactSearchConfiguration.build {
|
ContactSearchConfiguration.build {
|
||||||
query = state.query
|
query = state.query
|
||||||
|
|
||||||
|
@ -77,7 +77,8 @@ class ChooseGroupStoryBottomSheet : FixedRoundedCornerBottomSheetDialogFragment(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
performSafetyNumberChecks = false
|
||||||
)
|
)
|
||||||
|
|
||||||
mediator.getSelectionState().observe(viewLifecycleOwner) { state ->
|
mediator.getSelectionState().observe(viewLifecycleOwner) { state ->
|
||||||
|
|
|
@ -0,0 +1,246 @@
|
||||||
|
package org.thoughtcrime.securesms.contacts.paged
|
||||||
|
|
||||||
|
import android.app.Application
|
||||||
|
import io.reactivex.rxjava3.core.Single
|
||||||
|
import org.junit.Before
|
||||||
|
import org.junit.BeforeClass
|
||||||
|
import org.junit.Rule
|
||||||
|
import org.junit.Test
|
||||||
|
import org.junit.runner.RunWith
|
||||||
|
import org.mockito.Mock
|
||||||
|
import org.mockito.MockedStatic
|
||||||
|
import org.mockito.internal.configuration.plugins.Plugins
|
||||||
|
import org.mockito.internal.junit.JUnitRule
|
||||||
|
import org.mockito.junit.MockitoRule
|
||||||
|
import org.mockito.kotlin.any
|
||||||
|
import org.mockito.kotlin.argThat
|
||||||
|
import org.mockito.kotlin.times
|
||||||
|
import org.mockito.kotlin.verify
|
||||||
|
import org.mockito.kotlin.whenever
|
||||||
|
import org.mockito.quality.Strictness
|
||||||
|
import org.robolectric.RobolectricTestRunner
|
||||||
|
import org.robolectric.annotation.Config
|
||||||
|
import org.signal.core.util.logging.Log
|
||||||
|
import org.thoughtcrime.securesms.crypto.IdentityKeyUtil
|
||||||
|
import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore
|
||||||
|
import org.thoughtcrime.securesms.database.IdentityDatabase
|
||||||
|
import org.thoughtcrime.securesms.database.RecipientDatabaseTestUtils
|
||||||
|
import org.thoughtcrime.securesms.database.model.IdentityRecord
|
||||||
|
import org.thoughtcrime.securesms.recipients.Recipient
|
||||||
|
import org.thoughtcrime.securesms.recipients.RecipientId
|
||||||
|
import org.thoughtcrime.securesms.testutil.SystemOutLogger
|
||||||
|
import org.thoughtcrime.securesms.util.IdentityUtil
|
||||||
|
import org.whispersystems.signalservice.api.push.ACI
|
||||||
|
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
|
||||||
|
import org.whispersystems.signalservice.api.services.ProfileService
|
||||||
|
import org.whispersystems.signalservice.internal.ServiceResponse
|
||||||
|
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse
|
||||||
|
import java.io.IOException
|
||||||
|
import java.util.Optional
|
||||||
|
import java.util.concurrent.TimeUnit
|
||||||
|
|
||||||
|
@RunWith(RobolectricTestRunner::class)
|
||||||
|
@Config(application = Application::class)
|
||||||
|
class SafetyNumberRepositoryTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
@JvmField
|
||||||
|
val mockitoRule: MockitoRule = JUnitRule(Plugins.getMockitoLogger(), Strictness.STRICT_STUBS)
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
lateinit var profileService: ProfileService
|
||||||
|
|
||||||
|
@Mock(lenient = true)
|
||||||
|
lateinit var aciIdentityStore: SignalIdentityKeyStore
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
lateinit var staticIdentityUtil: MockedStatic<IdentityUtil>
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
lateinit var staticRecipient: MockedStatic<Recipient>
|
||||||
|
|
||||||
|
private var now: Long = System.currentTimeMillis()
|
||||||
|
|
||||||
|
private lateinit var recipientPool: MutableList<Recipient>
|
||||||
|
private lateinit var identityPool: MutableMap<Recipient, IdentityRecord>
|
||||||
|
|
||||||
|
private lateinit var repository: SafetyNumberRepository
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
@BeforeClass
|
||||||
|
@JvmStatic
|
||||||
|
fun setUpClass() {
|
||||||
|
Log.initialize(SystemOutLogger())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Before
|
||||||
|
fun setUp() {
|
||||||
|
now = System.currentTimeMillis()
|
||||||
|
repository = SafetyNumberRepository(profileService, aciIdentityStore)
|
||||||
|
|
||||||
|
recipientPool = mutableListOf()
|
||||||
|
identityPool = mutableMapOf()
|
||||||
|
|
||||||
|
for (id in 1L until 12) {
|
||||||
|
val recipient = RecipientDatabaseTestUtils.createRecipient(resolved = true, recipientId = RecipientId.from(id))
|
||||||
|
staticRecipient.`when`<Recipient> { Recipient.resolved(RecipientId.from(id)) }.thenReturn(recipient)
|
||||||
|
recipientPool.add(recipient)
|
||||||
|
|
||||||
|
val record = IdentityRecord(
|
||||||
|
recipientId = recipient.id,
|
||||||
|
identityKey = IdentityKeyUtil.generateIdentityKeyPair().publicKey,
|
||||||
|
verifiedStatus = IdentityDatabase.VerifiedStatus.DEFAULT,
|
||||||
|
firstUse = false,
|
||||||
|
timestamp = 0,
|
||||||
|
nonblockingApproval = false
|
||||||
|
)
|
||||||
|
whenever(aciIdentityStore.getIdentityRecord(recipient.id)).thenReturn(Optional.of(record))
|
||||||
|
identityPool[recipient] = record
|
||||||
|
}
|
||||||
|
|
||||||
|
staticRecipient.`when`<Recipient> { Recipient.self() }.thenReturn(recipientPool[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch request for a current identity key should return an empty list and not perform any identity key updates.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_batchOf1_noChanges() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf()), 200, "")))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
|
||||||
|
staticIdentityUtil.verifyNoInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch request for an out-of-date identity key should return the new identity key and update the store.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_batchOf1_oneChange() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val otherAci = ACI.from(other.requireServiceId())
|
||||||
|
val otherNewIdentityKey = IdentityKeyUtil.generateIdentityKeyPair().publicKey
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf(IdentityCheckResponse.AciIdentityPair(otherAci, otherNewIdentityKey))), 200, "")))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
|
||||||
|
staticIdentityUtil.verify { IdentityUtil.saveIdentity(otherAci.toString(), otherNewIdentityKey) }
|
||||||
|
staticIdentityUtil.verifyNoMoreInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch request for an out-of-date identity key should return the new identity key and update the store.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_batchOf2_oneChange() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val secondOther = recipientPool[2]
|
||||||
|
val otherAci = ACI.from(other.requireServiceId())
|
||||||
|
val otherNewIdentityKey = IdentityKeyUtil.generateIdentityKeyPair().publicKey
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id), ContactSearchKey.RecipientSearchKey.KnownRecipient(secondOther.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other, secondOther))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey, secondOther.requireServiceId() to identityPool[secondOther]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf(IdentityCheckResponse.AciIdentityPair(otherAci, otherNewIdentityKey))), 200, "")))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
|
||||||
|
staticIdentityUtil.verify { IdentityUtil.saveIdentity(otherAci.toString(), otherNewIdentityKey) }
|
||||||
|
staticIdentityUtil.verifyNoMoreInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch request for a current identity key should previously checked should abort checking.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_batchOf1_abortOnPriorRecentCheck() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf()), 200, "")))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
verify(profileService, times(1)).performIdentityCheck(any())
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now + TimeUnit.SECONDS.toMillis(10))
|
||||||
|
verify(profileService, times(1)).performIdentityCheck(any())
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now + TimeUnit.SECONDS.toMillis(31))
|
||||||
|
verify(profileService, times(2)).performIdentityCheck(any())
|
||||||
|
|
||||||
|
staticIdentityUtil.verifyNoInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch request for a current identity keys should return an empty list and not perform any identity key updates.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_batchOf10WithSmallBatchSize_noChanges() {
|
||||||
|
val keys = recipientPool.map { ContactSearchKey.RecipientSearchKey.KnownRecipient(it.id) }
|
||||||
|
val others = recipientPool.subList(1, recipientPool.lastIndex)
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(others.map { it.id }) }) }.thenReturn(others)
|
||||||
|
|
||||||
|
for (chunk in others.chunked(2)) {
|
||||||
|
whenever(profileService.performIdentityCheck(chunk.associate { it.requireServiceId() to identityPool[it]!!.identityKey }))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(listOf()), 200, "")))
|
||||||
|
}
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now, 2)
|
||||||
|
|
||||||
|
staticIdentityUtil.verifyNoInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_serverError() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forApplicationError(NonSuccessfulResponseCodeException(400), 400, "")))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
|
||||||
|
staticIdentityUtil.verifyNoInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_networkError() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forUnknownError(IOException())))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
|
||||||
|
staticIdentityUtil.verifyNoInteractions()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun batchSafetyNumberCheckSync_badJson() {
|
||||||
|
val other = recipientPool[1]
|
||||||
|
val keys = listOf(ContactSearchKey.RecipientSearchKey.KnownRecipient(other.id))
|
||||||
|
|
||||||
|
staticRecipient.`when`<List<Recipient>> { Recipient.resolvedList(argThat { containsAll(keys.map { it.recipientId }) }) }.thenReturn(listOf(other))
|
||||||
|
whenever(profileService.performIdentityCheck(mapOf(other.requireServiceId() to identityPool[other]!!.identityKey)))
|
||||||
|
.thenReturn(Single.just(ServiceResponse.forResult(IdentityCheckResponse(), 200, "")))
|
||||||
|
|
||||||
|
repository.batchSafetyNumberCheckSync(keys, now)
|
||||||
|
|
||||||
|
staticIdentityUtil.verifyNoInteractions()
|
||||||
|
}
|
||||||
|
}
|
|
@ -112,7 +112,7 @@ public final class ProfileService {
|
||||||
.onErrorReturn(ServiceResponse::forUnknownError);
|
.onErrorReturn(ServiceResponse::forUnknownError);
|
||||||
}
|
}
|
||||||
|
|
||||||
public @NonNull Single<ServiceResponse<IdentityCheckResponse>> performIdentityCheck(@Nonnull Map<ServiceId, IdentityKey> aciIdentityKeyMap, @Nonnull Optional<UnidentifiedAccess> unidentifiedAccess) {
|
public @NonNull Single<ServiceResponse<IdentityCheckResponse>> performIdentityCheck(@Nonnull Map<ServiceId, IdentityKey> aciIdentityKeyMap) {
|
||||||
List<AciFingerprintPair> aciKeyPairs = aciIdentityKeyMap.entrySet()
|
List<AciFingerprintPair> aciKeyPairs = aciIdentityKeyMap.entrySet()
|
||||||
.stream()
|
.stream()
|
||||||
.map(e -> new AciFingerprintPair(e.getKey(), e.getValue()))
|
.map(e -> new AciFingerprintPair(e.getKey(), e.getValue()))
|
||||||
|
@ -129,9 +129,9 @@ public final class ProfileService {
|
||||||
|
|
||||||
ResponseMapper<IdentityCheckResponse> responseMapper = DefaultResponseMapper.getDefault(IdentityCheckResponse.class);
|
ResponseMapper<IdentityCheckResponse> responseMapper = DefaultResponseMapper.getDefault(IdentityCheckResponse.class);
|
||||||
|
|
||||||
return signalWebSocket.request(builder.build(), unidentifiedAccess)
|
return signalWebSocket.request(builder.build(), Optional.empty())
|
||||||
.map(responseMapper::map)
|
.map(responseMapper::map)
|
||||||
.onErrorResumeNext(t -> performIdentityCheckRestFallback(request, unidentifiedAccess, responseMapper))
|
.onErrorResumeNext(t -> performIdentityCheckRestFallback(request, Optional.empty(), responseMapper))
|
||||||
.onErrorReturn(ServiceResponse::forUnknownError);
|
.onErrorReturn(ServiceResponse::forUnknownError);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
|
|
||||||
import org.signal.libsignal.protocol.IdentityKey;
|
import org.signal.libsignal.protocol.IdentityKey;
|
||||||
import org.whispersystems.signalservice.api.push.ServiceId;
|
import org.whispersystems.signalservice.api.push.ACI;
|
||||||
import org.whispersystems.signalservice.internal.util.JsonUtil;
|
import org.whispersystems.signalservice.internal.util.JsonUtil;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -16,6 +16,13 @@ public class IdentityCheckResponse {
|
||||||
@JsonProperty("elements")
|
@JsonProperty("elements")
|
||||||
private List<AciIdentityPair> aciKeyPairs;
|
private List<AciIdentityPair> aciKeyPairs;
|
||||||
|
|
||||||
|
public IdentityCheckResponse() {}
|
||||||
|
|
||||||
|
// Visible for testing
|
||||||
|
public IdentityCheckResponse(List<AciIdentityPair> aciKeyPairs) {
|
||||||
|
this.aciKeyPairs = aciKeyPairs;
|
||||||
|
}
|
||||||
|
|
||||||
public @Nullable List<AciIdentityPair> getAciKeyPairs() {
|
public @Nullable List<AciIdentityPair> getAciKeyPairs() {
|
||||||
return aciKeyPairs;
|
return aciKeyPairs;
|
||||||
}
|
}
|
||||||
|
@ -23,14 +30,22 @@ public class IdentityCheckResponse {
|
||||||
public static final class AciIdentityPair {
|
public static final class AciIdentityPair {
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
@JsonDeserialize(using = JsonUtil.ServiceIdDeserializer.class)
|
@JsonDeserialize(using = JsonUtil.AciDeserializer.class)
|
||||||
private ServiceId aci;
|
private ACI aci;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
@JsonDeserialize(using = JsonUtil.IdentityKeyDeserializer.class)
|
@JsonDeserialize(using = JsonUtil.IdentityKeyDeserializer.class)
|
||||||
private IdentityKey identityKey;
|
private IdentityKey identityKey;
|
||||||
|
|
||||||
public @Nullable ServiceId getAci() {
|
public AciIdentityPair() {}
|
||||||
|
|
||||||
|
// Visible for testing
|
||||||
|
public AciIdentityPair(ACI aci, IdentityKey identityKey) {
|
||||||
|
this.aci = aci;
|
||||||
|
this.identityKey = identityKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
public @Nullable ACI getAci() {
|
||||||
return aci;
|
return aci;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Ładowanie…
Reference in New Issue