Fix token mismatch issues when using CDSv2.

fork-5.53.8
Greyson Parrelli 2022-09-07 14:18:54 -04:00 zatwierdzone przez Cody Henthorne
rodzic f1bcc756d3
commit 658741be52
4 zmienionych plików z 96 dodań i 55 usunięć

Wyświetl plik

@ -37,7 +37,6 @@ import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.api.util.UuidUtil
import java.io.IOException
import java.lang.Exception
import java.util.Calendar
import java.util.concurrent.Callable
import java.util.concurrent.ExecutionException
@ -80,9 +79,9 @@ object ContactDiscovery {
descriptor = "refresh-all",
refresh = {
if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false)
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false, ignoreResults = false)
} else if (FeatureFlags.cdsV2Compat()) {
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true)
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true, ignoreResults = false)
} else if (FeatureFlags.cdsV2LoadTesting()) {
loadTestRefreshAll(context)
} else {
@ -105,9 +104,9 @@ object ContactDiscovery {
descriptor = "refresh-multiple",
refresh = {
if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false)
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false, ignoreResults = false)
} else if (FeatureFlags.cdsV2Compat()) {
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true)
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true, ignoreResults = false)
} else if (FeatureFlags.cdsV2LoadTesting()) {
loadTestRefresh(context, recipients)
} else {
@ -128,9 +127,9 @@ object ContactDiscovery {
descriptor = "refresh-single",
refresh = {
if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false)
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false, ignoreResults = false)
} else if (FeatureFlags.cdsV2Compat()) {
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true)
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true, ignoreResults = false)
} else if (FeatureFlags.cdsV2LoadTesting()) {
loadTestRefresh(context, listOf(recipient))
} else {
@ -404,7 +403,7 @@ object ContactDiscovery {
try {
v2Future.get()
} catch (e: Exception) {
} catch (e: Throwable) {
Log.w(TAG, "Failed to complete the V2 fetch!", e)
}

Wyświetl plik

@ -18,6 +18,7 @@ import org.thoughtcrime.securesms.phonenumbers.PhoneNumberFormatter
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId
import org.whispersystems.signalservice.api.push.ACI
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.api.services.CdsiV2Service
import java.io.IOException
import java.util.Optional
@ -44,7 +45,7 @@ object ContactDiscoveryRefreshV2 {
@WorkerThread
@Synchronized
@JvmStatic
fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult {
fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult {
val recipientE164s: Set<String> = SignalDatabase.recipients.getAllE164s().sanitize()
val systemE164s: Set<String> = SystemContactsRepository.getAllDisplayNumbers(context).toE164s(context).sanitize()
@ -52,7 +53,7 @@ object ContactDiscoveryRefreshV2 {
recipientE164s = recipientE164s,
systemE164s = systemE164s,
inputPreviousE164s = SignalDatabase.cds.getAllE164s(),
saveToken = true,
isPartialRefresh = false,
useCompat = useCompat,
ignoreResults = ignoreResults
)
@ -62,14 +63,14 @@ object ContactDiscoveryRefreshV2 {
@WorkerThread
@Synchronized
@JvmStatic
fun refresh(context: Context, inputRecipients: List<Recipient>, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult {
fun refresh(context: Context, inputRecipients: List<Recipient>, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult {
val recipients: List<Recipient> = inputRecipients.map { it.resolve() }
val inputE164s: Set<String> = recipients.mapNotNull { it.e164.orElse(null) }.toSet()
return if (inputE164s.size > MAXIMUM_ONE_OFF_REQUEST_SIZE) {
Log.i(TAG, "List of specific recipients to refresh is too large! (Size: ${recipients.size}). Doing a full refresh instead.")
val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, ignoreResults)
val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, useCompat = useCompat, ignoreResults = ignoreResults)
val inputIds: Set<RecipientId> = recipients.map { it.id }.toSet()
ContactDiscovery.RefreshResult(
@ -81,7 +82,7 @@ object ContactDiscoveryRefreshV2 {
recipientE164s = inputE164s,
systemE164s = inputE164s,
inputPreviousE164s = emptySet(),
saveToken = false,
isPartialRefresh = true,
useCompat = useCompat,
ignoreResults = ignoreResults
)
@ -93,13 +94,14 @@ object ContactDiscoveryRefreshV2 {
recipientE164s: Set<String>,
systemE164s: Set<String>,
inputPreviousE164s: Set<String>,
saveToken: Boolean,
isPartialRefresh: Boolean,
useCompat: Boolean,
ignoreResults: Boolean
): ContactDiscovery.RefreshResult {
val stopwatch = Stopwatch("refreshInternal-${if (useCompat) "compat" else "v2"}")
val tag = "refreshInternal-${if (useCompat) "compat" else "v2"}"
val stopwatch = Stopwatch(tag)
val previousE164s: Set<String> = if (SignalStore.misc().cdsToken != null) inputPreviousE164s else emptySet()
val previousE164s: Set<String> = if (SignalStore.misc().cdsToken != null && !isPartialRefresh) inputPreviousE164s else emptySet()
val allE164s: Set<String> = recipientE164s + systemE164s
val newRawE164s: Set<String> = allE164s - previousE164s
@ -107,40 +109,50 @@ object ContactDiscoveryRefreshV2 {
val newE164s: Set<String> = fuzzyInput.numbers
if (newE164s.isEmpty() && previousE164s.isEmpty()) {
Log.w(TAG, "[refreshInternal] No data to send! Ignoring.")
Log.w(TAG, "[$tag] No data to send! Ignoring.")
return ContactDiscovery.RefreshResult(emptySet(), emptyMap())
}
val token: ByteArray? = if (previousE164s.isNotEmpty()) SignalStore.misc().cdsToken else null
val token: ByteArray? = if (previousE164s.isNotEmpty() && !isPartialRefresh) SignalStore.misc().cdsToken else null
stopwatch.split("preamble")
val response: CdsiV2Service.Response = ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi(
previousE164s,
newE164s,
SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(),
useCompat,
Optional.ofNullable(token),
BuildConfig.CDSI_MRENCLAVE
) { tokenToSave ->
if (saveToken) {
SignalStore.misc().cdsToken = tokenToSave
Log.d(TAG, "Token saved!")
} else {
Log.d(TAG, "Ignoring token.")
val response: CdsiV2Service.Response = try {
ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi(
previousE164s,
newE164s,
SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(),
useCompat,
Optional.ofNullable(token),
BuildConfig.CDSI_MRENCLAVE
) { tokenToSave ->
stopwatch.split("network-pre-token")
if (!isPartialRefresh) {
SignalStore.misc().cdsToken = tokenToSave
SignalDatabase.cds.updateAfterFullCdsQuery(previousE164s + newE164s, allE164s + newE164s)
Log.d(TAG, "Token saved!")
} else {
SignalDatabase.cds.updateAfterPartialCdsQuery(newE164s)
Log.d(TAG, "Ignoring token.")
}
stopwatch.split("cds-db")
}
} catch (e: NonSuccessfulResponseCodeException) {
if (e.code == 4101) {
Log.w(TAG, "Our token was invalid! Only thing we can do now is clear our local state :(")
SignalStore.misc().cdsToken = null
SignalDatabase.cds.clearAll()
}
throw e
}
Log.d(TAG, "[refreshInternal] Used ${response.quotaUsedDebugOnly} quota.")
stopwatch.split("network")
SignalDatabase.cds.updateAfterCdsQuery(newE164s, allE164s + newE164s)
stopwatch.split("cds-db")
Log.d(TAG, "[$tag] Used ${response.quotaUsedDebugOnly} quota.")
stopwatch.split("network-post-token")
val registeredIds: MutableSet<RecipientId> = mutableSetOf()
val rewrites: MutableMap<String, String> = mutableMapOf()
if (ignoreResults) {
Log.w(TAG, "[refreshInternal] Ignoring CDSv2 results.")
Log.w(TAG, "[$tag] Ignoring CDSv2 results.")
} else {
if (useCompat) {
val transformed: Map<String, ACI?> = response.results.mapValues { entry -> entry.value.aci.orElse(null) }

Wyświetl plik

@ -8,7 +8,7 @@ import org.signal.core.util.delete
import org.signal.core.util.logging.Log
import org.signal.core.util.requireNonNullString
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.core.util.withinTransaction
/**
* Keeps track of the numbers we've previously queried CDS for.
@ -53,32 +53,58 @@ class CdsDatabase(context: Context, databaseHelper: SignalDatabase) : Database(c
}
/**
* @param newE164s The newly-added E164s that we hadn't previously queried for.
* @param seenE164s The E164s that were seen in either the system contacts or recipients table.
* This should be a superset of [newE164s]
*
* Saves the set of e164s used after a full refresh.
* @param fullE164s All of the e164s used in the last CDS query (previous and new).
* @param seenE164s The E164s that were seen in either the system contacts or recipients table. This is different from [fullE164s] in that [fullE164s]
* includes every number we've ever seen, even if it's not in our contacts anymore.
*/
fun updateAfterCdsQuery(newE164s: Set<String>, seenE164s: Set<String>) {
fun updateAfterFullCdsQuery(fullE164s: Set<String>, seenE164s: Set<String>) {
val lastSeen = System.currentTimeMillis()
writableDatabase.beginTransaction()
try {
val insertValues: List<ContentValues> = newE164s.map { contentValuesOf(E164 to it) }
writableDatabase.withinTransaction { db ->
val existingE164s: Set<String> = getAllE164s()
val removedE164s: Set<String> = existingE164s - fullE164s
val addedE164s: Set<String> = fullE164s - existingE164s
SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues)
.forEach { writableDatabase.execSQL(it.where, it.whereArgs) }
if (removedE164s.isNotEmpty()) {
SqlUtil.buildCollectionQuery(E164, removedE164s)
.forEach { db.delete(TABLE_NAME, it.where, it.whereArgs) }
}
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
if (addedE164s.isNotEmpty()) {
val insertValues: List<ContentValues> = addedE164s.map { contentValuesOf(E164 to it) }
SqlUtil.buildCollectionQuery(E164, seenE164s)
.forEach { query -> writableDatabase.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues)
.forEach { db.execSQL(it.where, it.whereArgs) }
}
writableDatabase.setTransactionSuccessful()
} finally {
writableDatabase.endTransaction()
if (seenE164s.isNotEmpty()) {
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
SqlUtil.buildCollectionQuery(E164, seenE164s)
.forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
}
}
}
/**
* Updates after a partial CDS query. Will not insert new entries. Instead, this will simply update the lastSeen timestamp of any entry we already have.
* @param seenE164s The newly-added E164s that we hadn't previously queried for.
*/
fun updateAfterPartialCdsQuery(seenE164s: Set<String>) {
val lastSeen = System.currentTimeMillis()
writableDatabase.withinTransaction { db ->
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
SqlUtil.buildCollectionQuery(E164, seenE164s)
.forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
}
}
/**
* Wipes the entire table.
*/
fun clearAll() {
writableDatabase
.delete(TABLE_NAME)

Wyświetl plik

@ -549,7 +549,11 @@ public class SignalServiceAccountManager {
if (serviceResponse.getResult().isPresent()) {
return serviceResponse.getResult().get();
} else if (serviceResponse.getApplicationError().isPresent()) {
throw new IOException(serviceResponse.getApplicationError().get());
if (serviceResponse.getApplicationError().get() instanceof IOException) {
throw (IOException) serviceResponse.getApplicationError().get();
} else {
throw new IOException(serviceResponse.getApplicationError().get());
}
} else if (serviceResponse.getExecutionError().isPresent()) {
throw new IOException(serviceResponse.getExecutionError().get());
} else {