Migrate the session table to be keyed off of libsignal IDs.

fork-5.53.8
Greyson Parrelli 2021-08-19 14:11:14 -04:00 zatwierdzone przez Alex Hart
rodzic c24dfdce34
commit 6618d696e4
8 zmienionych plików z 142 dodań i 156 usunięć

Wyświetl plik

@ -220,7 +220,14 @@ class ConversationSettingsRepository(
Preconditions.checkArgument(FeatureFlags.internalUser(), "Internal users only!") Preconditions.checkArgument(FeatureFlags.internalUser(), "Internal users only!")
SignalExecutors.BOUNDED.execute { SignalExecutors.BOUNDED.execute {
DatabaseFactory.getSessionDatabase(context).deleteAllFor(recipientId) val recipient = Recipient.resolved(recipientId)
if (recipient.hasUuid()) {
DatabaseFactory.getSessionDatabase(context).deleteAllFor(recipient.requireUuid().toString())
}
if (recipient.hasE164()) {
DatabaseFactory.getSessionDatabase(context).deleteAllFor(recipient.requireE164())
}
} }
} }

Wyświetl plik

@ -543,8 +543,9 @@ public class DirectoryHelper {
} }
private static boolean hasCommunicatedWith(@NonNull Context context, @NonNull Recipient recipient) { private static boolean hasCommunicatedWith(@NonNull Context context, @NonNull Recipient recipient) {
return DatabaseFactory.getThreadDatabase(context).hasThread(recipient.getId()) || return DatabaseFactory.getThreadDatabase(context).hasThread(recipient.getId()) ||
DatabaseFactory.getSessionDatabase(context).hasSessionFor(recipient.getId()); (recipient.hasUuid() && DatabaseFactory.getSessionDatabase(context).hasSessionFor(recipient.requireUuid().toString())) ||
(recipient.hasE164() && DatabaseFactory.getSessionDatabase(context).hasSessionFor(recipient.requireE164()));
} }
static class DirectoryResult { static class DirectoryResult {

Wyświetl plik

@ -7,7 +7,6 @@ import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log; import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.database.DatabaseFactory; import org.thoughtcrime.securesms.database.DatabaseFactory;
import org.thoughtcrime.securesms.database.SessionDatabase; import org.thoughtcrime.securesms.database.SessionDatabase;
import org.thoughtcrime.securesms.database.SessionDatabase.RecipientDevice;
import org.thoughtcrime.securesms.recipients.Recipient; import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId; import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.libsignal.NoSessionException; import org.whispersystems.libsignal.NoSessionException;
@ -16,9 +15,7 @@ import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord; import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore; import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class TextSecureSessionStore implements SignalServiceSessionStore { public class TextSecureSessionStore implements SignalServiceSessionStore {
@ -35,8 +32,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override @Override
public SessionRecord loadSession(@NonNull SignalProtocolAddress address) { public SessionRecord loadSession(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) { synchronized (LOCK) {
RecipientId recipientId = RecipientId.fromExternalPush(address.getName()); SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(address);
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(recipientId, address.getDeviceId());
if (sessionRecord == null) { if (sessionRecord == null) {
Log.w(TAG, "No existing session information found."); Log.w(TAG, "No existing session information found.");
@ -50,11 +46,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override @Override
public List<SessionRecord> loadExistingSessions(List<SignalProtocolAddress> addresses) throws NoSessionException { public List<SessionRecord> loadExistingSessions(List<SignalProtocolAddress> addresses) throws NoSessionException {
synchronized (LOCK) { synchronized (LOCK) {
List<RecipientDevice> ids = addresses.stream() List<SessionRecord> sessionRecords = DatabaseFactory.getSessionDatabase(context).load(addresses);
.map(address -> new RecipientDevice(RecipientId.fromExternalPush(address.getName()), address.getDeviceId()))
.collect(Collectors.toList());
List<SessionRecord> sessionRecords = DatabaseFactory.getSessionDatabase(context).load(ids);
if (sessionRecords.size() != addresses.size()) { if (sessionRecords.size() != addresses.size()) {
String message = "Mismatch! Asked for " + addresses.size() + " sessions, but only found " + sessionRecords.size() + "!"; String message = "Mismatch! Asked for " + addresses.size() + " sessions, but only found " + sessionRecords.size() + "!";
@ -69,96 +61,76 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override @Override
public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) { public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) {
synchronized (LOCK) { synchronized (LOCK) {
RecipientId id = RecipientId.fromExternalPush(address.getName()); DatabaseFactory.getSessionDatabase(context).store(address, record);
DatabaseFactory.getSessionDatabase(context).store(id, address.getDeviceId(), record);
} }
} }
@Override @Override
public boolean containsSession(SignalProtocolAddress address) { public boolean containsSession(SignalProtocolAddress address) {
synchronized (LOCK) { synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) { SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(address);
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(recipientId, address.getDeviceId());
return sessionRecord != null && return sessionRecord != null &&
sessionRecord.hasSenderChain() && sessionRecord.hasSenderChain() &&
sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION; sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
} else {
return false;
}
} }
} }
@Override @Override
public void deleteSession(SignalProtocolAddress address) { public void deleteSession(SignalProtocolAddress address) {
synchronized (LOCK) { synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) { DatabaseFactory.getSessionDatabase(context).delete(address);
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
DatabaseFactory.getSessionDatabase(context).delete(recipientId, address.getDeviceId());
} else {
Log.w(TAG, "Tried to delete session for " + address.toString() + ", but none existed!");
}
} }
} }
@Override @Override
public void deleteAllSessions(String name) { public void deleteAllSessions(String name) {
synchronized (LOCK) { synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(name)) { DatabaseFactory.getSessionDatabase(context).deleteAllFor(name);
RecipientId recipientId = RecipientId.fromExternalPush(name);
DatabaseFactory.getSessionDatabase(context).deleteAllFor(recipientId);
}
} }
} }
@Override @Override
public List<Integer> getSubDeviceSessions(String name) { public List<Integer> getSubDeviceSessions(String name) {
synchronized (LOCK) { synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(name)) { return DatabaseFactory.getSessionDatabase(context).getSubDevices(name);
RecipientId recipientId = RecipientId.fromExternalPush(name);
return DatabaseFactory.getSessionDatabase(context).getSubDevices(recipientId);
} else {
Log.w(TAG, "Tried to get sub device sessions for " + name + ", but none existed!");
return Collections.emptyList();
}
} }
} }
@Override @Override
public void archiveSession(SignalProtocolAddress address) { public void archiveSession(SignalProtocolAddress address) {
synchronized (LOCK) { synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) { SessionRecord session = DatabaseFactory.getSessionDatabase(context).load(address);
RecipientId recipientId = RecipientId.fromExternalPush(address.getName()); if (session != null) {
archiveSession(recipientId, address.getDeviceId()); session.archiveCurrentState();
DatabaseFactory.getSessionDatabase(context).store(address, session);
} }
} }
} }
public void archiveSession(@NonNull RecipientId recipientId, int deviceId) { public void archiveSession(@NonNull RecipientId recipientId, int deviceId) {
synchronized (LOCK) { synchronized (LOCK) {
SessionRecord session = DatabaseFactory.getSessionDatabase(context).load(recipientId, deviceId); Recipient recipient = Recipient.resolved(recipientId);
if (session != null) {
session.archiveCurrentState(); if (recipient.hasUuid()) {
DatabaseFactory.getSessionDatabase(context).store(recipientId, deviceId, session); archiveSession(new SignalProtocolAddress(recipient.requireUuid().toString(), deviceId));
}
if (recipient.hasE164()) {
archiveSession(new SignalProtocolAddress(recipient.requireE164(), deviceId));
} }
} }
} }
public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) { public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) { synchronized (LOCK) {
if (DatabaseFactory.getRecipientDatabase(context).containsPhoneOrUuid(address.getName())) { List<SessionDatabase.SessionRow> sessions = DatabaseFactory.getSessionDatabase(context).getAllFor(address.getName());
RecipientId recipientId = RecipientId.fromExternalPush(address.getName());
List<SessionDatabase.SessionRow> sessions = DatabaseFactory.getSessionDatabase(context).getAllFor(recipientId);
for (SessionDatabase.SessionRow row : sessions) { for (SessionDatabase.SessionRow row : sessions) {
if (row.getDeviceId() != address.getDeviceId()) { if (row.getDeviceId() != address.getDeviceId()) {
row.getRecord().archiveCurrentState(); row.getRecord().archiveCurrentState();
storeSession(new SignalProtocolAddress(Recipient.resolved(row.getRecipientId()).requireServiceId(), row.getDeviceId()), row.getRecord()); storeSession(new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), row.getRecord());
}
} }
} else {
Log.w(TAG, "Tried to archive sibling sessions for " + address.toString() + ", but none existed!");
} }
} }
} }
@ -169,7 +141,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
for (SessionDatabase.SessionRow row : sessions) { for (SessionDatabase.SessionRow row : sessions) {
row.getRecord().archiveCurrentState(); row.getRecord().archiveCurrentState();
storeSession(new SignalProtocolAddress(Recipient.resolved(row.getRecipientId()).requireServiceId(), row.getDeviceId()), row.getRecord()); storeSession(new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), row.getRecord());
} }
} }
} }

Wyświetl plik

@ -2961,17 +2961,19 @@ public class RecipientDatabase extends Database {
} }
// Sessions // Sessions
boolean hasE164Session = DatabaseFactory.getSessionDatabase(context).getAllFor(byE164).size() > 0; SessionDatabase sessionDatabase = DatabaseFactory.getSessionDatabase(context);
boolean hasUuidSession = DatabaseFactory.getSessionDatabase(context).getAllFor(byUuid).size() > 0;
boolean hasE164Session = sessionDatabase.getAllFor(e164Settings.e164).size() > 0;
boolean hasUuidSession = sessionDatabase.getAllFor(uuidSettings.uuid.toString()).size() > 0;
if (hasE164Session && hasUuidSession) { if (hasE164Session && hasUuidSession) {
Log.w(TAG, "Had a session for both users. Deleting the E164.", true); Log.w(TAG, "Had a session for both users. Deleting the E164.", true);
db.delete(SessionDatabase.TABLE_NAME, SessionDatabase.RECIPIENT_ID + " = ?", SqlUtil.buildArgs(byE164)); sessionDatabase.deleteAllFor(e164Settings.e164);
} else if (hasE164Session && !hasUuidSession) { } else if (hasE164Session && !hasUuidSession) {
Log.w(TAG, "Had a session for E164, but not UUID. Re-assigning to the UUID.", true); Log.w(TAG, "Had a session for E164, but not UUID. Re-assigning to the UUID.", true);
ContentValues values = new ContentValues(); ContentValues values = new ContentValues();
values.put(SessionDatabase.RECIPIENT_ID, byUuid.serialize()); values.put(SessionDatabase.ADDRESS, uuidSettings.uuid.toString());
db.update(SessionDatabase.TABLE_NAME, values, SessionDatabase.RECIPIENT_ID + " = ?", SqlUtil.buildArgs(byE164)); db.update(SessionDatabase.TABLE_NAME, values, SessionDatabase.ADDRESS + " = ?", SqlUtil.buildArgs(e164Settings.e164));
} else if (!hasE164Session && hasUuidSession) { } else if (!hasE164Session && hasUuidSession) {
Log.w(TAG, "Had a session for UUID, but not E164. No action necessary.", true); Log.w(TAG, "Had a session for UUID, but not E164. No action necessary.", true);
} else { } else {

Wyświetl plik

@ -1,18 +1,20 @@
package org.thoughtcrime.securesms.database; package org.thoughtcrime.securesms.database;
import android.content.ContentValues;
import android.content.Context; import android.content.Context;
import android.database.Cursor; import android.database.Cursor;
import androidx.annotation.NonNull; import androidx.annotation.NonNull;
import androidx.annotation.Nullable; import androidx.annotation.Nullable;
import net.zetetic.database.sqlcipher.SQLiteStatement;
import org.thoughtcrime.securesms.database.helpers.SQLCipherOpenHelper; import org.thoughtcrime.securesms.database.helpers.SQLCipherOpenHelper;
import org.signal.core.util.logging.Log; import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.recipients.RecipientId; import org.thoughtcrime.securesms.util.CursorUtil;
import org.thoughtcrime.securesms.util.SqlUtil; import org.thoughtcrime.securesms.util.SqlUtil;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.state.SessionRecord; import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
@ -27,40 +29,42 @@ public class SessionDatabase extends Database {
public static final String TABLE_NAME = "sessions"; public static final String TABLE_NAME = "sessions";
private static final String ID = "_id"; private static final String ID = "_id";
public static final String RECIPIENT_ID = "address"; public static final String ADDRESS = "address";
public static final String DEVICE = "device"; public static final String DEVICE = "device";
public static final String RECORD = "record"; public static final String RECORD = "record";
public static final String CREATE_TABLE = "CREATE TABLE " + TABLE_NAME + public static final String CREATE_TABLE = "CREATE TABLE " + TABLE_NAME + "(" + ID + " INTEGER PRIMARY KEY AUTOINCREMENT, " +
"(" + ID + " INTEGER PRIMARY KEY, " + RECIPIENT_ID + " INTEGER NOT NULL, " + ADDRESS + " TEXT NOT NULL, " +
DEVICE + " INTEGER NOT NULL, " + RECORD + " BLOB NOT NULL, " + DEVICE + " INTEGER NOT NULL, " +
"UNIQUE(" + RECIPIENT_ID + "," + DEVICE + ") ON CONFLICT REPLACE);"; RECORD + " BLOB NOT NULL, " +
"UNIQUE(" + ADDRESS + "," + DEVICE + "));";
SessionDatabase(Context context, SQLCipherOpenHelper databaseHelper) { SessionDatabase(Context context, SQLCipherOpenHelper databaseHelper) {
super(context, databaseHelper); super(context, databaseHelper);
} }
public void store(@NonNull RecipientId recipientId, int deviceId, @NonNull SessionRecord record) { public void store(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) {
SQLiteDatabase database = databaseHelper.getSignalWritableDatabase(); SQLiteDatabase db = databaseHelper.getSignalWritableDatabase();
ContentValues values = new ContentValues(); try (SQLiteStatement statement = db.compileStatement("INSERT INTO " + TABLE_NAME + " (" + ADDRESS + ", " + DEVICE + ", " + RECORD + ") VALUES (?, ?, ?) " +
values.put(RECIPIENT_ID, recipientId.serialize()); "ON CONFLICT (" + ADDRESS + ", " + DEVICE + ") DO UPDATE SET " + RECORD + " = excluded." + RECORD))
values.put(DEVICE, deviceId); {
values.put(RECORD, record.serialize()); statement.bindString(1, address.getName());
statement.bindLong(2, address.getDeviceId());
database.insertWithOnConflict(TABLE_NAME, null, values, SQLiteDatabase.CONFLICT_REPLACE); statement.bindBlob(3, record.serialize());
statement.execute();
}
} }
public @Nullable SessionRecord load(@NonNull RecipientId recipientId, int deviceId) { public @Nullable SessionRecord load(@NonNull SignalProtocolAddress address) {
SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); SQLiteDatabase database = databaseHelper.getSignalReadableDatabase();
String[] projection = new String[] { RECORD };
String selection = ADDRESS + " = ? AND " + DEVICE + " = ?";
String[] args = SqlUtil.buildArgs(address.getName(), address.getDeviceId());
try (Cursor cursor = database.query(TABLE_NAME, new String[]{RECORD}, try (Cursor cursor = database.query(TABLE_NAME, projection, selection, args, null, null, null)) {
RECIPIENT_ID + " = ? AND " + DEVICE + " = ?", if (cursor.moveToFirst()) {
new String[] {recipientId.serialize(), String.valueOf(deviceId)},
null, null, null))
{
if (cursor != null && cursor.moveToFirst()) {
try { try {
return new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))); return new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD)));
} catch (IOException e) { } catch (IOException e) {
@ -72,17 +76,17 @@ public class SessionDatabase extends Database {
return null; return null;
} }
public @NonNull List<SessionRecord> load(@NonNull List<RecipientDevice> ids) { public @NonNull List<SessionRecord> load(@NonNull List<SignalProtocolAddress> addresses) {
SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); SQLiteDatabase database = databaseHelper.getSignalReadableDatabase();
List<SessionRecord> sessions = new ArrayList<>(ids.size()); List<SessionRecord> sessions = new ArrayList<>(addresses.size());
database.beginTransaction(); database.beginTransaction();
try { try {
String[] projection = new String[]{RECORD}; String[] projection = new String[] { RECORD };
String query = RECIPIENT_ID + " = ? AND " + DEVICE + " = ?"; String query = ADDRESS + " = ? AND " + DEVICE + " = ?";
for (RecipientDevice id : ids) { for (SignalProtocolAddress address : addresses) {
String[] args = SqlUtil.buildArgs(id.getRecipientId(), id.getDevice()); String[] args = SqlUtil.buildArgs(address.getName(), address.getDeviceId());
try (Cursor cursor = database.query(TABLE_NAME, projection, query, args, null, null, null)) { try (Cursor cursor = database.query(TABLE_NAME, projection, query, args, null, null, null)) {
if (cursor.moveToFirst()) { if (cursor.moveToFirst()) {
@ -102,19 +106,15 @@ public class SessionDatabase extends Database {
return sessions; return sessions;
} }
public @NonNull List<SessionRow> getAllFor(@NonNull RecipientId recipientId) { public @NonNull List<SessionRow> getAllFor(@NonNull String addressName) {
SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); SQLiteDatabase database = databaseHelper.getSignalReadableDatabase();
List<SessionRow> results = new LinkedList<>(); List<SessionRow> results = new LinkedList<>();
try (Cursor cursor = database.query(TABLE_NAME, null, try (Cursor cursor = database.query(TABLE_NAME, null, ADDRESS + " = ?", SqlUtil.buildArgs(addressName), null, null, null)) {
RECIPIENT_ID + " = ?", while (cursor.moveToNext()) {
new String[] {recipientId.serialize()},
null, null, null))
{
while (cursor != null && cursor.moveToNext()) {
try { try {
results.add(new SessionRow(recipientId, results.add(new SessionRow(CursorUtil.requireString(cursor, ADDRESS),
cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE)), CursorUtil.requireInt(cursor, DEVICE),
new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))))); new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD)))));
} catch (IOException e) { } catch (IOException e) {
Log.w(TAG, e); Log.w(TAG, e);
@ -130,10 +130,10 @@ public class SessionDatabase extends Database {
List<SessionRow> results = new LinkedList<>(); List<SessionRow> results = new LinkedList<>();
try (Cursor cursor = database.query(TABLE_NAME, null, null, null, null, null, null)) { try (Cursor cursor = database.query(TABLE_NAME, null, null, null, null, null, null)) {
while (cursor != null && cursor.moveToNext()) { while (cursor.moveToNext()) {
try { try {
results.add(new SessionRow(RecipientId.from(cursor.getLong(cursor.getColumnIndexOrThrow(RECIPIENT_ID))), results.add(new SessionRow(CursorUtil.requireString(cursor, ADDRESS),
cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE)), CursorUtil.requireInt(cursor, DEVICE),
new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))))); new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD)))));
} catch (IOException e) { } catch (IOException e) {
Log.w(TAG, e); Log.w(TAG, e);
@ -144,16 +144,15 @@ public class SessionDatabase extends Database {
return results; return results;
} }
public @NonNull List<Integer> getSubDevices(@NonNull RecipientId recipientId) { public @NonNull List<Integer> getSubDevices(@NonNull String addressName) {
SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); SQLiteDatabase database = databaseHelper.getSignalReadableDatabase();
List<Integer> results = new LinkedList<>(); List<Integer> results = new LinkedList<>();
String[] projection = new String[] { DEVICE };
String selection = ADDRESS + " = ?";
String[] args = SqlUtil.buildArgs(addressName);
try (Cursor cursor = database.query(TABLE_NAME, new String[] {DEVICE}, try (Cursor cursor = database.query(TABLE_NAME, projection, selection, args, null, null, null)) {
RECIPIENT_ID + " = ?", while (cursor.moveToNext()) {
new String[] {recipientId.serialize()},
null, null, null))
{
while (cursor != null && cursor.moveToNext()) {
int device = cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE)); int device = cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE));
if (device != SignalServiceAddress.DEFAULT_DEVICE_ID) { if (device != SignalServiceAddress.DEFAULT_DEVICE_ID) {
@ -165,41 +164,43 @@ public class SessionDatabase extends Database {
return results; return results;
} }
public void delete(@NonNull RecipientId recipientId, int deviceId) { public void delete(@NonNull SignalProtocolAddress address) {
SQLiteDatabase database = databaseHelper.getSignalWritableDatabase(); SQLiteDatabase database = databaseHelper.getSignalWritableDatabase();
String selection = ADDRESS + " = ? AND " + DEVICE + " = ?";
String[] args = SqlUtil.buildArgs(address.getName(), address.getDeviceId());
database.delete(TABLE_NAME, RECIPIENT_ID + " = ? AND " + DEVICE + " = ?",
new String[] {recipientId.serialize(), String.valueOf(deviceId)}); database.delete(TABLE_NAME, selection, args);
} }
public void deleteAllFor(@NonNull RecipientId recipientId) { public void deleteAllFor(@NonNull String addressName) {
SQLiteDatabase database = databaseHelper.getSignalWritableDatabase(); SQLiteDatabase database = databaseHelper.getSignalWritableDatabase();
database.delete(TABLE_NAME, RECIPIENT_ID + " = ?", new String[] {recipientId.serialize()}); database.delete(TABLE_NAME, ADDRESS + " = ?", SqlUtil.buildArgs(addressName));
} }
public boolean hasSessionFor(@NonNull RecipientId recipientId) { public boolean hasSessionFor(@NonNull String addressName) {
SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); SQLiteDatabase database = databaseHelper.getSignalReadableDatabase();
String query = RECIPIENT_ID + " = ?"; String query = ADDRESS + " = ?";
String[] args = SqlUtil.buildArgs(recipientId); String[] args = SqlUtil.buildArgs(addressName);
try (Cursor cursor = database.query(TABLE_NAME, new String[] { ID }, query, args, null, null, null, "1")) { try (Cursor cursor = database.query(TABLE_NAME, new String[] { "1" }, query, args, null, null, null, "1")) {
return cursor != null && cursor.moveToFirst(); return cursor.moveToFirst();
} }
} }
public static final class SessionRow { public static final class SessionRow {
private final RecipientId recipientId; private final String address;
private final int deviceId; private final int deviceId;
private final SessionRecord record; private final SessionRecord record;
public SessionRow(@NonNull RecipientId recipientId, int deviceId, SessionRecord record) { public SessionRow(@NonNull String address, int deviceId, SessionRecord record) {
this.recipientId = recipientId; this.address = address;
this.deviceId = deviceId; this.deviceId = deviceId;
this.record = record; this.record = record;
} }
public RecipientId getRecipientId() { public @NonNull String getAddress() {
return recipientId; return address;
} }
public int getDeviceId() { public int getDeviceId() {
@ -210,22 +211,4 @@ public class SessionDatabase extends Database {
return record; return record;
} }
} }
public static final class RecipientDevice {
private final RecipientId recipientId;
private final int device;
public RecipientDevice(@NonNull RecipientId recipientId, int device) {
this.recipientId = recipientId;
this.device = device;
}
public @NonNull RecipientId getRecipientId() {
return recipientId;
}
public int getDevice() {
return device;
}
}
} }

Wyświetl plik

@ -1,7 +1,5 @@
package org.thoughtcrime.securesms.database; package org.thoughtcrime.securesms.database;
import androidx.annotation.NonNull;
import net.zetetic.database.sqlcipher.SQLiteDatabase; import net.zetetic.database.sqlcipher.SQLiteDatabase;
/** /**

Wyświetl plik

@ -210,8 +210,9 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper implements SignalDatab
private static final int ABANDONED_ATTACHMENT_CLEANUP = 110; private static final int ABANDONED_ATTACHMENT_CLEANUP = 110;
private static final int AVATAR_PICKER = 111; private static final int AVATAR_PICKER = 111;
private static final int THREAD_CLEANUP = 112; private static final int THREAD_CLEANUP = 112;
private static final int SESSION_MIGRATION = 113;
private static final int DATABASE_VERSION = 112; private static final int DATABASE_VERSION = 113;
private static final String DATABASE_NAME = "signal.db"; private static final String DATABASE_NAME = "signal.db";
private final Context context; private final Context context;
@ -1965,6 +1966,28 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper implements SignalDatab
db.delete("part", "mid != -8675309 AND mid NOT IN (SELECT _id FROM mms)", null); db.delete("part", "mid != -8675309 AND mid NOT IN (SELECT _id FROM mms)", null);
} }
if (oldVersion < SESSION_MIGRATION) {
long start = System.currentTimeMillis();
db.execSQL("CREATE TABLE sessions_tmp (_id INTEGER PRIMARY KEY AUTOINCREMENT, " +
"address TEXT NOT NULL, " +
"device INTEGER NOT NULL, " +
"record BLOB NOT NULL, " +
"UNIQUE(address, device))");
db.execSQL("INSERT INTO sessions_tmp (address, device, record) " +
"SELECT COALESCE(recipient.uuid, recipient.phone) AS new_address, " +
"sessions.device, " +
"sessions.record " +
"FROM sessions INNER JOIN recipient ON sessions.address = recipient._id " +
"WHERE new_address NOT NULL");
db.execSQL("DROP TABLE sessions");
db.execSQL("ALTER TABLE sessions_tmp RENAME TO sessions");
Log.d(TAG, "Session migration took " + (System.currentTimeMillis() - start) + " ms");
}
db.setTransactionSuccessful(); db.setTransactionSuccessful();
} finally { } finally {
db.endTransaction(); db.endTransaction();

Wyświetl plik

@ -73,7 +73,7 @@ class SessionStoreMigrationHelper {
ContentValues contentValues = new ContentValues(); ContentValues contentValues = new ContentValues();
contentValues.put(SessionDatabase.RECIPIENT_ID, address); contentValues.put(SessionDatabase.ADDRESS, address);
contentValues.put(SessionDatabase.DEVICE, deviceId); contentValues.put(SessionDatabase.DEVICE, deviceId);
contentValues.put(SessionDatabase.RECORD, sessionRecord.serialize()); contentValues.put(SessionDatabase.RECORD, sessionRecord.serialize());