Guarantee table export order is valid.

main
Greyson Parrelli 2022-11-16 13:21:33 -05:00 zatwierdzone przez Alex Hart
rodzic 7c60c32918
commit cf00995b6f
3 zmienionych plików z 262 dodań i 18 usunięć

Wyświetl plik

@ -8,6 +8,7 @@ import android.text.TextUtils;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;
import androidx.annotation.VisibleForTesting;
import androidx.documentfile.provider.DocumentFile;
import com.annimon.stream.function.Predicate;
@ -19,6 +20,7 @@ import org.greenrobot.eventbus.EventBus;
import org.signal.core.util.Conversions;
import org.signal.core.util.CursorUtil;
import org.signal.core.util.SetUtil;
import org.signal.core.util.SqlUtil;
import org.signal.core.util.Stopwatch;
import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.kdf.HKDF;
@ -62,10 +64,15 @@ import java.io.OutputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
@ -84,7 +91,11 @@ public class FullBackupExporter extends FullBackupBase {
private static final long IDENTITY_KEY_BACKUP_RECORD_COUNT = 2L;
private static final long FINAL_MESSAGE_COUNT = 1L;
private static final Set<String> BLACKLISTED_TABLES = SetUtil.newHashSet(
/**
* Tables in list will still have their *schema* exported (so the tables will be created),
* but we will not export the actual contents.
*/
private static final Set<String> TABLE_CONTENT_BLOCKLIST = SetUtil.newHashSet(
SignedPreKeyDatabase.TABLE_NAME,
OneTimePreKeyDatabase.TABLE_NAME,
SessionDatabase.TABLE_NAME,
@ -175,7 +186,7 @@ public class FullBackupExporter extends FullBackupBase {
count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMmsMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.MMS_ID))), (cursor, innerCount) -> exportAttachment(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal);
} else if (table.equals(StickerDatabase.TABLE_NAME)) {
count = exportTable(table, input, outputStream, cursor -> true, (cursor, innerCount) -> exportSticker(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal);
} else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) {
} else if (!TABLE_CONTENT_BLOCKLIST.contains(table)) {
count = exportTable(table, input, outputStream, null, null, count, estimatedCount, cancellationSignal);
}
stopwatch.split("table::" + table);
@ -229,7 +240,7 @@ public class FullBackupExporter extends FullBackupBase {
count += getCount(input, BackupCountQueries.getAttachmentCount());
} else if (table.equals(StickerDatabase.TABLE_NAME)) {
count += getCount(input, "SELECT COUNT(*) FROM " + table);
} else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) {
} else if (!TABLE_CONTENT_BLOCKLIST.contains(table)) {
count += getCount(input, "SELECT COUNT(*) FROM " + table);
}
}
@ -266,31 +277,112 @@ public class FullBackupExporter extends FullBackupBase {
private static List<String> exportSchema(@NonNull SQLiteDatabase input, @NonNull BackupFrameOutputStream outputStream)
throws IOException
{
List<String> tables = new LinkedList<>();
List<String> tablesInOrder = getTablesToExportInOrder(input);
try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master", null)) {
Map<String, String> createStatementsByTable = new HashMap<>();
try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master WHERE type = 'table' AND sql NOT NULL", null)) {
while (cursor != null && cursor.moveToNext()) {
String sql = cursor.getString(0);
String name = cursor.getString(1);
String type = cursor.getString(2);
if (sql != null) {
boolean isSmsFtsSecretTable = name != null && !name.equals(SearchDatabase.SMS_FTS_TABLE_NAME) && name.startsWith(SearchDatabase.SMS_FTS_TABLE_NAME);
boolean isMmsFtsSecretTable = name != null && !name.equals(SearchDatabase.MMS_FTS_TABLE_NAME) && name.startsWith(SearchDatabase.MMS_FTS_TABLE_NAME);
boolean isEmojiFtsSecretTable = name != null && !name.equals(EmojiSearchDatabase.TABLE_NAME) && name.startsWith(EmojiSearchDatabase.TABLE_NAME);
createStatementsByTable.put(name, sql);
}
}
if (!isSmsFtsSecretTable && !isMmsFtsSecretTable && !isEmojiFtsSecretTable) {
if ("table".equals(type)) {
tables.add(name);
}
for (String table : tablesInOrder) {
String statement = createStatementsByTable.get(table);
outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(cursor.getString(0)).build());
}
if (statement != null) {
outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(statement).build());
} else {
throw new IOException("Failed to find a create statement for table: " + table);
}
}
try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master where type != 'table' AND sql NOT NULL", null)) {
while (cursor != null && cursor.moveToNext()) {
String sql = cursor.getString(0);
String name = cursor.getString(1);
if (isTableAllowed(name)) {
outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(sql).build());
}
}
}
return tables;
return tablesInOrder;
}
/**
* Returns the list of tables we should export, in the order they should be exported in.
* The order is chosen to ensure we won't violate any foreign key constraints when we import them.
*/
private static List<String> getTablesToExportInOrder(@NonNull SQLiteDatabase input) {
List<String> tables = SqlUtil.getAllTables(input)
.stream()
.filter(FullBackupExporter::isTableAllowed)
.sorted()
.collect(Collectors.toList());
Map<String, Set<String>> dependsOn = new LinkedHashMap<>();
for (String table : tables) {
dependsOn.put(table, SqlUtil.getForeignKeyDependencies(input, table));
}
return computeTableOrder(dependsOn);
}
@VisibleForTesting
static List<String> computeTableOrder(@NonNull Map<String, Set<String>> dependsOn) {
List<String> rootNodes = dependsOn.keySet()
.stream()
.filter(table -> {
boolean nothingDependsOnIt = dependsOn.values().stream().noneMatch(it -> it.contains(table));
return nothingDependsOnIt;
})
.sorted()
.collect(Collectors.toList());
LinkedHashSet<String> outputOrder = new LinkedHashSet<>();
for (String root : rootNodes) {
postOrderTraversal(root, dependsOn, outputOrder);
}
return new ArrayList<>(outputOrder);
}
private static void postOrderTraversal(String current, Map<String, Set<String>> dependsOn, LinkedHashSet<String> outputOrder) {
Set<String> dependencies = dependsOn.get(current);
if (dependencies == null || dependencies.isEmpty()) {
outputOrder.add(current);
return;
}
for (String dependency : dependencies) {
postOrderTraversal(dependency, dependsOn, outputOrder);
}
outputOrder.add(current);
}
private static boolean isTableAllowed(@Nullable String table) {
if (table == null) {
return true;
}
boolean isReservedTable = table.startsWith("sqlite_");
boolean isSmsFtsSecretTable = !table.equals(SearchDatabase.SMS_FTS_TABLE_NAME) && table.startsWith(SearchDatabase.SMS_FTS_TABLE_NAME);
boolean isMmsFtsSecretTable = !table.equals(SearchDatabase.MMS_FTS_TABLE_NAME) && table.startsWith(SearchDatabase.MMS_FTS_TABLE_NAME);
boolean isEmojiFtsSecretTable = !table.equals(EmojiSearchDatabase.TABLE_NAME) && table.startsWith(EmojiSearchDatabase.TABLE_NAME);
return !isReservedTable &&
!isSmsFtsSecretTable &&
!isMmsFtsSecretTable &&
!isEmojiFtsSecretTable;
}
private static int exportTable(@NonNull String table,

Wyświetl plik

@ -0,0 +1,141 @@
package org.thoughtcrime.securesms.backup
import org.junit.Assert.assertEquals
import org.junit.Test
class FullBackupExporterTest {
@Test
fun `computeTableOrder - empty`() {
val order = FullBackupExporter.computeTableOrder(mapOf())
assertEquals(listOf<String>(), order)
}
/**
* A B C
*/
@Test
fun `computeTableOrder - no dependencies`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"A" to setOf(),
"B" to setOf(),
"C" to setOf(),
)
)
assertEquals(listOf("A", "B", "C"), order)
}
/**
* C
* |
* B
* |
* A
*/
@Test
fun `computeTableOrder - single chain`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"C" to setOf("B"),
"B" to setOf("A"),
)
)
assertEquals(listOf("A", "B", "C"), order)
}
/**
* F G H
* B E
* A C D
*/
@Test
fun `computeTableOrder - complex 1`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"F" to setOf("B", "E"),
"B" to setOf("A"),
"E" to setOf("C", "D"),
"G" to setOf(),
"H" to setOf(),
"A" to setOf(),
"C" to setOf(),
"D" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H"), order)
}
/**
* I
* |
* C E H
* |
* A B D F G
*/
@Test
fun `computeTableOrder - complex 2`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"I" to setOf("C", "E", "H"),
"C" to setOf("A", "B"),
"E" to setOf("D"),
"H" to setOf("F", "G"),
"A" to setOf(),
"B" to setOf(),
"D" to setOf(),
"F" to setOf(),
"G" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H", "I"), order)
}
/**
* C E H
* |
* A B D F G
*/
@Test
fun `computeTableOrder - multiple roots`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"C" to setOf("A", "B"),
"E" to setOf("D"),
"H" to setOf("F", "G"),
"A" to setOf(),
"B" to setOf(),
"D" to setOf(),
"F" to setOf(),
"G" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H"), order)
}
/**
* C D E
* |
* A B A A B
*/
@Test
fun `computeTableOrder - multiple roots, dupes across graphs`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"C" to setOf("A", "B"),
"D" to setOf("A"),
"E" to setOf("A", "B"),
"A" to setOf(),
"B" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E"), order)
}
}

Wyświetl plik

@ -33,6 +33,17 @@ object SqlUtil {
return tables
}
/**
* Given a table, this will return a set of tables that it has a foreign key dependency on.
*/
@JvmStatic
fun getForeignKeyDependencies(db: SupportSQLiteDatabase, table: String): Set<String> {
return db.query("PRAGMA foreign_key_list($table)")
.readToSet{ cursor ->
cursor.requireNonNullString("table")
}
}
@JvmStatic
fun isEmpty(db: SupportSQLiteDatabase, table: String): Boolean {
db.query("SELECT COUNT(*) FROM $table", null).use { cursor ->