Add basic CDSv2 database writes and unit tests.

fork-5.53.8
Greyson Parrelli 2022-05-16 12:15:11 -04:00 zatwierdzone przez Cody Henthorne
rodzic 307be5c75e
commit dda5ce4809
8 zmienionych plików z 324 dodań i 27 usunięć

Wyświetl plik

@ -0,0 +1,206 @@
package org.thoughtcrime.securesms.database
import androidx.core.content.contentValuesOf
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.signal.core.util.requireLong
import org.signal.core.util.requireString
import org.signal.core.util.select
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.recipients.RecipientId
import org.whispersystems.signalservice.api.push.ACI
import org.whispersystems.signalservice.api.push.PNI
import org.whispersystems.signalservice.api.push.ServiceId
import java.util.UUID
@RunWith(AndroidJUnit4::class)
class RecipientDatabaseTest_processCdsV2Result {
private lateinit var recipientDatabase: RecipientDatabase
private val localAci = ACI.from(UUID.randomUUID())
private val localPni = PNI.from(UUID.randomUUID())
@Before
fun setup() {
recipientDatabase = SignalDatabase.recipients
ensureDbEmpty()
SignalStore.account().setAci(localAci)
SignalStore.account().setPni(localPni)
}
@Test
fun processCdsV2Result_noMatch() {
// Note that we haven't inserted any test data
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(resultId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_fullMatch() {
val inputId: RecipientId = insert(E164_A, PNI_A, ACI_A)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_onlyE164Matches() {
val inputId: RecipientId = insert(E164_A, null, null)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_e164AndPniMatches() {
val inputId: RecipientId = insert(E164_A, PNI_A, null)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_e164AndAciMatches() {
val inputId: RecipientId = insert(E164_A, null, ACI_A)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_onlyPniMatches() {
val inputId: RecipientId = insert(null, PNI_A, null)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_pniAndAciMatches() {
val inputId: RecipientId = insert(null, PNI_A, ACI_A)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
@Test
fun processCdsV2Result_onlyAciMatches() {
val inputId: RecipientId = insert(null, null, ACI_A)
val resultId: RecipientId = recipientDatabase.processCdsV2Result(E164_A, PNI_A, ACI_A)
val record: IdRecord = require(resultId)
assertEquals(inputId, record.id)
assertEquals(E164_A, record.e164)
assertEquals(ACI_A, record.sid)
assertEquals(PNI_A, record.pni)
}
private fun insert(e164: String?, pni: PNI?, aci: ACI?): RecipientId {
val id: Long = SignalDatabase.rawDatabase.insert(
RecipientDatabase.TABLE_NAME,
null,
contentValuesOf(
RecipientDatabase.PHONE to e164,
RecipientDatabase.SERVICE_ID to (aci ?: pni)?.toString(),
RecipientDatabase.PNI_COLUMN to pni?.toString(),
RecipientDatabase.REGISTERED to RecipientDatabase.RegisteredState.REGISTERED.id
)
)
return RecipientId.from(id)
}
private fun require(id: RecipientId): IdRecord {
return get(id)!!
}
private fun get(id: RecipientId): IdRecord? {
SignalDatabase.rawDatabase
.select(RecipientDatabase.ID, RecipientDatabase.PHONE, RecipientDatabase.SERVICE_ID, RecipientDatabase.PNI_COLUMN)
.from(RecipientDatabase.TABLE_NAME)
.where("${RecipientDatabase.ID} = ?", id)
.run()
.use { cursor ->
return if (cursor.moveToFirst()) {
IdRecord(
id = RecipientId.from(cursor.requireLong(RecipientDatabase.ID)),
e164 = cursor.requireString(RecipientDatabase.PHONE),
sid = ServiceId.parseOrNull(cursor.requireString(RecipientDatabase.SERVICE_ID)),
pni = PNI.parseOrNull(cursor.requireString(RecipientDatabase.PNI_COLUMN))
)
} else {
null
}
}
}
private fun ensureDbEmpty() {
SignalDatabase.rawDatabase.rawQuery("SELECT COUNT(*) FROM ${RecipientDatabase.TABLE_NAME} WHERE ${RecipientDatabase.DISTRIBUTION_LIST_ID} IS NULL ", null).use { cursor ->
assertTrue(cursor.moveToFirst())
assertEquals(0, cursor.getLong(0))
}
}
private data class IdRecord(
val id: RecipientId,
val e164: String?,
val sid: ServiceId?,
val pni: PNI?,
)
companion object {
val ACI_A = ACI.from(UUID.fromString("3436efbe-5a76-47fa-a98a-7e72c948a82e"))
val ACI_B = ACI.from(UUID.fromString("8de7f691-0b60-4a68-9cd9-ed2f8453f9ed"))
val PNI_A = PNI.from(UUID.fromString("154b8d92-c960-4f6c-8385-671ad2ffb999"))
val PNI_B = PNI.from(UUID.fromString("ba92b1fb-cd55-40bf-adda-c35a85375533"))
const val E164_A = "+12221234567"
const val E164_B = "+13331234567"
}
}

Wyświetl plik

@ -74,7 +74,7 @@ object ContactDiscovery {
context = context,
descriptor = "refresh-all",
refresh = {
if (FeatureFlags.usePnpCds()) {
if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refreshAll(context)
} else {
ContactDiscoveryRefreshV1.refreshAll(context)
@ -95,7 +95,7 @@ object ContactDiscovery {
context = context,
descriptor = "refresh-multiple",
refresh = {
if (FeatureFlags.usePnpCds()) {
if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refresh(context, recipients)
} else {
ContactDiscoveryRefreshV1.refresh(context, recipients)
@ -114,7 +114,7 @@ object ContactDiscovery {
context = context,
descriptor = "refresh-single",
refresh = {
if (FeatureFlags.usePnpCds()) {
if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient))
} else {
ContactDiscoveryRefreshV1.refresh(context, listOf(recipient))

Wyświetl plik

@ -121,8 +121,8 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
const val TABLE_NAME = "recipient"
const val ID = "_id"
private const val SERVICE_ID = "uuid"
private const val PNI_COLUMN = "pni"
const val SERVICE_ID = "uuid"
const val PNI_COLUMN = "pni"
private const val USERNAME = "username"
const val PHONE = "phone"
const val EMAIL = "email"
@ -403,6 +403,14 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
return getByColumn(SERVICE_ID, serviceId.toString())
}
/**
* Will return a recipient matching the PNI, but only in the explicit [PNI_COLUMN]. This should only be checked in conjunction with [getByServiceId] as a way
* to avoid creating a recipient we already merged.
*/
fun getByPni(pni: PNI): Optional<RecipientId> {
return getByColumn(PNI_COLUMN, pni.toString())
}
fun getByUsername(username: String): Optional<RecipientId> {
return getByColumn(USERNAME, username)
}
@ -2131,7 +2139,11 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
}
/**
* A dumb implementation of processing CDSv2 results. Suitable only for testing and not for actual use.
* Processes CDSv2 results, merging recipients as necessary.
*
* Important: This is under active development and is not suitable for actual use.
*
* @return A set of [RecipientId]s that were updated/inserted.
*/
fun bulkProcessCdsV2Result(mapping: Map<String, CdsV2Result>): Set<RecipientId> {
val ids: MutableSet<RecipientId> = mutableSetOf()
@ -2140,7 +2152,7 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
db.beginTransaction()
try {
for ((e164, result) in mapping) {
ids += getAndPossiblyMerge(result.bestServiceId(), e164, true)
ids += processCdsV2Result(e164, result.pni, result.aci)
}
db.setTransactionSuccessful()
@ -2151,6 +2163,47 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
return ids
}
@VisibleForTesting
fun processCdsV2Result(e164: String, pni: PNI, aci: ACI?): RecipientId {
val byE164: RecipientId? = getByE164(e164).orElse(null)
val byPni: RecipientId? = getByServiceId(pni).orElse(null)
val byPniOnly: RecipientId? = getByPni(pni).orElse(null)
val byAci: RecipientId? = aci?.let { getByServiceId(it).orElse(null) }
val commonId: RecipientId? = listOf(byE164, byPni, byPniOnly, byAci).commonId()
val allRequiredDbFields: List<RecipientId?> = if (aci != null) listOf(byE164, byAci, byPniOnly) else listOf(byE164, byPni, byPniOnly)
val allRequiredDbFieldPopulated: Boolean = allRequiredDbFields.all { it != null }
// All ID's agree and the database is up-to-date
if (commonId != null && allRequiredDbFieldPopulated) {
return commonId
}
// All ID's agree but we need to update the database
if (commonId != null && !allRequiredDbFieldPopulated) {
writableDatabase
.update(TABLE_NAME)
.values(
PHONE to e164,
SERVICE_ID to (aci ?: pni).toString(),
PNI_COLUMN to pni.toString(),
REGISTERED to RegisteredState.REGISTERED.id,
STORAGE_SERVICE_ID to Base64.encodeBytes(StorageSyncHelper.generateKey())
)
.where("$ID = ?", commonId)
.run()
return commonId
}
// Nothing matches
if (byE164 == null && byPni == null && byAci == null) {
val id: Long = writableDatabase.insert(TABLE_NAME, null, buildContentValuesForCdsInsert(e164, pni, aci))
return RecipientId.from(id)
}
throw NotImplementedError("Handle cases where IDs map to different individuals")
}
fun getUninvitedRecipientsForInsights(): List<RecipientId> {
val results: MutableList<RecipientId> = LinkedList()
val args = arrayOf((System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)).toString())
@ -2865,6 +2918,18 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
return values
}
private fun buildContentValuesForCdsInsert(e164: String, pni: PNI, aci: ACI?): ContentValues {
val serviceId: ServiceId = aci ?: pni
return ContentValues().apply {
put(PHONE, e164)
put(SERVICE_ID, serviceId.toString())
put(PNI_COLUMN, pni.toString())
put(REGISTERED, RegisteredState.REGISTERED.id)
put(STORAGE_SERVICE_ID, Base64.encodeBytes(StorageSyncHelper.generateKey()))
put(AVATAR_COLOR, AvatarColor.random().serialize())
}
}
private fun getValuesForStorageContact(contact: SignalContactRecord, isInsert: Boolean): ContentValues {
return ContentValues().apply {
val profileName = ProfileName.fromParts(contact.givenName.orElse(null), contact.familyName.orElse(null))
@ -2940,6 +3005,22 @@ open class RecipientDatabase(context: Context, databaseHelper: SignalDatabase) :
}
}
/**
* @return The common id if all non-null ids are equal, or null if all are null or at least one non-null pair doesn't match.
*/
private fun Collection<RecipientId?>.commonId(): RecipientId? {
val nonNull = this.filterNotNull()
if (nonNull.isEmpty()) {
return null
}
return if (nonNull.all { it.equals(nonNull[0]) }) {
nonNull[0]
} else {
null
}
}
/**
* Should only be used for debugging! A very destructive action that clears all known serviceIds.
*/

Wyświetl plik

@ -307,7 +307,7 @@ public class RetrieveProfileJob extends BaseJob {
recipientDatabase.markProfilesFetched(success, System.currentTimeMillis());
// XXX The service hasn't implemented profiles for PNIs yet, so if using PNP CDS we don't want to mark users without profiles as unregistered.
if ((operationState.unregistered.size() > 0 || newlyRegistered.size() > 0) && !FeatureFlags.usePnpCds()) {
if ((operationState.unregistered.size() > 0 || newlyRegistered.size() > 0) && !FeatureFlags.phoneNumberPrivacy()) {
Log.i(TAG, "Marking " + newlyRegistered.size() + " users as registered and " + operationState.unregistered.size() + " users as unregistered.");
recipientDatabase.bulkUpdatedRegisteredStatus(newlyRegistered, operationState.unregistered);
}

Wyświetl plik

@ -23,6 +23,7 @@ import org.thoughtcrime.securesms.jobs.PushProcessMessageJob;
import org.thoughtcrime.securesms.messages.MessageDecryptionUtil.DecryptionResult;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.thoughtcrime.securesms.util.FeatureFlags;
import org.thoughtcrime.securesms.util.GroupUtil;
import org.signal.core.util.SetUtil;
import org.thoughtcrime.securesms.util.Stopwatch;
@ -30,6 +31,7 @@ import org.thoughtcrime.securesms.util.TextSecurePreferences;
import org.whispersystems.signalservice.api.SignalSessionLock;
import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope;
import org.whispersystems.signalservice.api.messages.SignalServiceGroupContext;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.io.Closeable;
import java.io.IOException;
@ -81,6 +83,11 @@ public class IncomingMessageProcessor {
* one was created. Otherwise null.
*/
public @Nullable String processEnvelope(@NonNull SignalServiceEnvelope envelope) {
if (FeatureFlags.phoneNumberPrivacy() && envelope.hasSourceE164()) {
Log.w(TAG, "PNP enabled -- mimicking PNP by dropping the E164 from the envelope.");
envelope = envelope.withoutE164();
}
if (envelope.hasSourceUuid()) {
Recipient.externalHighTrustPush(context, envelope.getSourceAddress());
}

Wyświetl plik

@ -81,7 +81,7 @@ public final class MessageDecryptionUtil {
ServiceId pni = SignalStore.account().requirePni();
ServiceId destination;
if (!FeatureFlags.usePnpCds()) {
if (!FeatureFlags.phoneNumberPrivacy()) {
destination = aci;
} else if (envelope.hasDestinationUuid()) {
destination = ServiceId.parseOrThrow(envelope.getDestinationUuid());

Wyświetl plik

@ -60,7 +60,6 @@ public final class FeatureFlags {
private static final String GROUP_NAME_MAX_LENGTH = "global.groupsv2.maxNameLength";
private static final String INTERNAL_USER = "android.internalUser";
private static final String VERIFY_V2 = "android.verifyV2";
private static final String PHONE_NUMBER_PRIVACY_VERSION = "android.phoneNumberPrivacyVersion";
private static final String CLIENT_EXPIRATION = "android.clientExpiration";
public static final String DONATE_MEGAPHONE = "android.donate.2";
private static final String CUSTOM_VIDEO_MUXER = "android.customVideoMuxer";
@ -94,7 +93,7 @@ public final class FeatureFlags {
private static final String USE_HARDWARE_AEC_IF_OLD = "android.calling.useHardwareAecIfOlderThanApi29";
private static final String USE_AEC3 = "android.calling.useAec3";
private static final String PAYMENTS_COUNTRY_BLOCKLIST = "android.payments.blocklist";
private static final String PNP_CDS = "android.pnp.cds";
private static final String PHONE_NUMBER_PRIVACY = "android.pnp";
private static final String USE_FCM_FOREGROUND_SERVICE = "android.useFcmForegroundService.3";
private static final String STORIES_AUTO_DOWNLOAD_MAXIMUM = "android.stories.autoDownloadMaximum";
private static final String GIFT_BADGES = "android.giftBadges.2";
@ -152,8 +151,7 @@ public final class FeatureFlags {
@VisibleForTesting
static final Set<String> NOT_REMOTE_CAPABLE = SetUtil.newHashSet(
PHONE_NUMBER_PRIVACY_VERSION,
PNP_CDS
PHONE_NUMBER_PRIVACY
);
/**
@ -327,11 +325,11 @@ public final class FeatureFlags {
}
/**
* Whether the user can choose phone number privacy settings, and;
* Whether to fetch and store the secondary certificate
* Whether phone number privacy is enabled.
* IMPORTANT: This is under active development. Enabling this *will* break your contacts in terrible, irreversible ways.
*/
public static boolean phoneNumberPrivacy() {
return getVersionFlag(PHONE_NUMBER_PRIVACY_VERSION) == VersionFlag.ON;
return getBoolean(PHONE_NUMBER_PRIVACY, false) && Environment.IS_STAGING;
}
/** Whether to use the custom streaming muxer or built in android muxer. */
@ -493,15 +491,6 @@ public final class FeatureFlags {
return getBoolean(USE_AEC3, true);
}
/**
* Whether or not to use the phone number privacy CDS flow. Only currently works in staging.
*
* Note: This feature is in very early stages of development and *will* break your contacts.
*/
public static boolean usePnpCds() {
return Environment.IS_STAGING && getBoolean(PNP_CDS, false);
}
public static boolean useFcmForegroundService() {
return getBoolean(USE_FCM_FOREGROUND_SERVICE, false);
}

Wyświetl plik

@ -119,6 +119,12 @@ public class SignalServiceEnvelope {
this.serverDeliveredTimestamp = serverDeliveredTimestamp;
}
public SignalServiceEnvelope withoutE164() {
return deserialize(serializeToProto().clearSourceE164()
.build()
.toByteArray());
}
public String getServerGuid() {
return envelope.getServerGuid();
}
@ -156,6 +162,10 @@ public class SignalServiceEnvelope {
return envelope.hasSourceDevice();
}
public boolean hasSourceE164() {
return envelope.hasSourceE164();
}
/**
* @return The envelope's sender device ID.
*/
@ -263,7 +273,8 @@ public class SignalServiceEnvelope {
return envelope.getDestinationUuid();
}
public byte[] serialize() {
private SignalServiceEnvelopeProto.Builder serializeToProto() {
SignalServiceEnvelopeProto.Builder builder = SignalServiceEnvelopeProto.newBuilder()
.setType(getType())
.setDeviceId(getSourceDevice())
@ -295,8 +306,11 @@ public class SignalServiceEnvelope {
builder.setDestinationUuid(getDestinationUuid().toString());
}
return builder;
}
return builder.build().toByteArray();
public byte[] serialize() {
return serializeToProto().build().toByteArray();
}
public static SignalServiceEnvelope deserialize(byte[] serialized) {