From 9f6b761d9888a2cd62c7bc0e8154c883871a62d5 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Sun, 18 Feb 2018 16:43:18 -0800 Subject: [PATCH] Migrate sessions into database --- .../securesms/crypto/SessionUtil.java | 25 +-- .../storage/TextSecureSessionStore.java | 203 +++--------------- .../securesms/database/DatabaseFactory.java | 6 + .../database/OneTimePreKeyDatabase.java | 7 +- .../securesms/database/SessionDatabase.java | 170 +++++++++++++++ .../database/SignedPreKeyDatabase.java | 4 +- .../database/helpers/SQLCipherOpenHelper.java | 11 +- .../helpers/SessionStoreMigrationHelper.java | 109 ++++++++++ 8 files changed, 336 insertions(+), 199 deletions(-) create mode 100644 src/org/thoughtcrime/securesms/database/SessionDatabase.java create mode 100644 src/org/thoughtcrime/securesms/database/helpers/SessionStoreMigrationHelper.java diff --git a/src/org/thoughtcrime/securesms/crypto/SessionUtil.java b/src/org/thoughtcrime/securesms/crypto/SessionUtil.java index 827470954..36f19b15e 100644 --- a/src/org/thoughtcrime/securesms/crypto/SessionUtil.java +++ b/src/org/thoughtcrime/securesms/crypto/SessionUtil.java @@ -5,7 +5,6 @@ import android.support.annotation.NonNull; import org.thoughtcrime.securesms.crypto.storage.TextSecureSessionStore; import org.thoughtcrime.securesms.database.Address; -import org.thoughtcrime.securesms.recipients.Recipient; import org.whispersystems.libsignal.SignalProtocolAddress; import org.whispersystems.libsignal.state.SessionRecord; import org.whispersystems.libsignal.state.SessionStore; @@ -15,36 +14,20 @@ import java.util.List; public class SessionUtil { - public static boolean hasSession(Context context, Recipient recipient) { - return hasSession(context, recipient.getAddress()); - } - public static boolean hasSession(Context context, @NonNull Address address) { - SessionStore sessionStore = new TextSecureSessionStore(context, null); + SessionStore sessionStore = new TextSecureSessionStore(context); SignalProtocolAddress axolotlAddress = new SignalProtocolAddress(address.serialize(), SignalServiceAddress.DEFAULT_DEVICE_ID); return sessionStore.containsSession(axolotlAddress); } public static void archiveSiblingSessions(Context context, SignalProtocolAddress address) { - SessionStore sessionStore = new TextSecureSessionStore(context); - List devices = sessionStore.getSubDeviceSessions(address.getName()); - devices.add(1); - - for (int device : devices) { - if (device != address.getDeviceId()) { - SignalProtocolAddress sibling = new SignalProtocolAddress(address.getName(), device); - - if (sessionStore.containsSession(sibling)) { - SessionRecord sessionRecord = sessionStore.loadSession(sibling); - sessionRecord.archiveCurrentState(); - sessionStore.storeSession(sibling, sessionRecord); - } - } - } + TextSecureSessionStore sessionStore = new TextSecureSessionStore(context); + sessionStore.archiveSiblingSessions(address); } public static void archiveAllSessions(Context context) { new TextSecureSessionStore(context).archiveAllSessions(); } + } diff --git a/src/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java b/src/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java index 389869e3c..2f953806a 100644 --- a/src/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java +++ b/src/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java @@ -7,6 +7,9 @@ import android.util.Log; import org.thoughtcrime.securesms.crypto.MasterCipher; import org.thoughtcrime.securesms.crypto.MasterSecret; +import org.thoughtcrime.securesms.database.Address; +import org.thoughtcrime.securesms.database.DatabaseFactory; +import org.thoughtcrime.securesms.database.SessionDatabase; import org.thoughtcrime.securesms.util.Conversions; import org.whispersystems.libsignal.InvalidMessageException; import org.whispersystems.libsignal.SignalProtocolAddress; @@ -29,144 +32,77 @@ import static org.whispersystems.libsignal.state.StorageProtos.SessionStructure; public class TextSecureSessionStore implements SessionStore { - private static final String TAG = TextSecureSessionStore.class.getSimpleName(); - private static final String SESSIONS_DIRECTORY_V2 = "sessions-v2"; - private static final Object FILE_LOCK = new Object(); + private static final String TAG = TextSecureSessionStore.class.getSimpleName(); - private static final int SINGLE_STATE_VERSION = 1; - private static final int ARCHIVE_STATES_VERSION = 2; - private static final int PLAINTEXT_VERSION = 3; - private static final int CURRENT_VERSION = 3; + private static final Object FILE_LOCK = new Object(); - @NonNull private final Context context; - @Nullable private final MasterSecret masterSecret; + @NonNull private final Context context; public TextSecureSessionStore(@NonNull Context context) { - this(context, null); - } - - public TextSecureSessionStore(@NonNull Context context, @Nullable MasterSecret masterSecret) { - this.context = context.getApplicationContext(); - this.masterSecret = masterSecret; + this.context = context; } @Override public SessionRecord loadSession(@NonNull SignalProtocolAddress address) { synchronized (FILE_LOCK) { - try { - FileInputStream in = new FileInputStream(getSessionFile(address)); - int versionMarker = readInteger(in); + SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(Address.fromSerialized(address.getName()), address.getDeviceId()); - if (versionMarker > CURRENT_VERSION) { - throw new AssertionError("Unknown version: " + versionMarker); - } - - byte[] serialized = readBlob(in); - in.close(); - - if (versionMarker < PLAINTEXT_VERSION && masterSecret != null) { - serialized = new MasterCipher(masterSecret).decryptBytes(serialized); - } else if (versionMarker < PLAINTEXT_VERSION) { - throw new AssertionError("Session didn't get migrated: (" + versionMarker + "," + address + ")"); - } - - if (versionMarker == SINGLE_STATE_VERSION) { - SessionStructure sessionStructure = SessionStructure.parseFrom(serialized); - SessionState sessionState = new SessionState(sessionStructure); - return new SessionRecord(sessionState); - } else if (versionMarker >= ARCHIVE_STATES_VERSION) { - return new SessionRecord(serialized); - } else { - throw new AssertionError("Unknown version: " + versionMarker); - } - } catch (InvalidMessageException | IOException e) { + if (sessionRecord == null) { Log.w(TAG, "No existing session information found."); return new SessionRecord(); } + + return sessionRecord; } } @Override public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) { synchronized (FILE_LOCK) { - try { - RandomAccessFile sessionFile = new RandomAccessFile(getSessionFile(address), "rw"); - FileChannel out = sessionFile.getChannel(); - - out.position(0); - writeInteger(CURRENT_VERSION, out); - writeBlob(record.serialize(), out); - out.truncate(out.position()); - - sessionFile.close(); - } catch (IOException e) { - throw new AssertionError(e); - } + DatabaseFactory.getSessionDatabase(context).store(Address.fromSerialized(address.getName()), address.getDeviceId(), record); } } @Override public boolean containsSession(SignalProtocolAddress address) { - if (!getSessionFile(address).exists()) return false; + synchronized (FILE_LOCK) { + SessionRecord sessionRecord = DatabaseFactory.getSessionDatabase(context).load(Address.fromSerialized(address.getName()), address.getDeviceId()); - SessionRecord sessionRecord = loadSession(address); - - return sessionRecord.getSessionState().hasSenderChain() && - sessionRecord.getSessionState().getSessionVersion() == CiphertextMessage.CURRENT_VERSION; + return sessionRecord != null && + sessionRecord.getSessionState().hasSenderChain() && + sessionRecord.getSessionState().getSessionVersion() == CiphertextMessage.CURRENT_VERSION; + } } @Override public void deleteSession(SignalProtocolAddress address) { - getSessionFile(address).delete(); + synchronized (FILE_LOCK) { + DatabaseFactory.getSessionDatabase(context).delete(Address.fromSerialized(address.getName()), address.getDeviceId()); + } } @Override public void deleteAllSessions(String name) { - List devices = getSubDeviceSessions(name); - - deleteSession(new SignalProtocolAddress(name, SignalServiceAddress.DEFAULT_DEVICE_ID)); - - for (int device : devices) { - deleteSession(new SignalProtocolAddress(name, device)); + synchronized (FILE_LOCK) { + DatabaseFactory.getSessionDatabase(context).deleteAllFor(Address.fromSerialized(name)); } } @Override public List getSubDeviceSessions(String name) { - List results = new LinkedList<>(); - File parent = getSessionDirectory(); - String[] children = parent.list(); - - if (children == null) return results; - - for (String child : children) { - try { - String[] parts = child.split("[.]", 2); - String sessionName = parts[0]; - - if (sessionName.equals(name) && parts.length > 1) { - results.add(Integer.parseInt(parts[1])); - } - } catch (NumberFormatException e) { - Log.w(TAG, e); - } + synchronized (FILE_LOCK) { + return DatabaseFactory.getSessionDatabase(context).getSubDevices(Address.fromSerialized(name)); } - - return results; } - public void migrateSessions() { + public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) { synchronized (FILE_LOCK) { - File directory = getSessionDirectory(); + List sessions = DatabaseFactory.getSessionDatabase(context).getAllFor(Address.fromSerialized(address.getName())); - for (File session : directory.listFiles()) { - if (session.isFile()) { - SignalProtocolAddress address = getAddressName(session); - - if (address != null) { - SessionRecord sessionRecord = loadSession(address); - storeSession(address, sessionRecord); - } + for (SessionDatabase.SessionRow row : sessions) { + if (row.getDeviceId() != address.getDeviceId()) { + row.getRecord().archiveCurrentState(); + storeSession(new SignalProtocolAddress(row.getAddress().serialize(), row.getDeviceId()), row.getRecord()); } } } @@ -174,81 +110,12 @@ public class TextSecureSessionStore implements SessionStore { public void archiveAllSessions() { synchronized (FILE_LOCK) { - File directory = getSessionDirectory(); + List sessions = DatabaseFactory.getSessionDatabase(context).getAll(); - for (File session : directory.listFiles()) { - if (session.isFile()) { - SignalProtocolAddress address = getAddressName(session); - - if (address != null) { - SessionRecord sessionRecord = loadSession(address); - sessionRecord.archiveCurrentState(); - storeSession(address, sessionRecord); - } - } + for (SessionDatabase.SessionRow row : sessions) { + row.getRecord().archiveCurrentState(); + storeSession(new SignalProtocolAddress(row.getAddress().serialize(), row.getDeviceId()), row.getRecord()); } } } - - private File getSessionFile(SignalProtocolAddress address) { - return new File(getSessionDirectory(), getSessionName(address)); - } - - private File getSessionDirectory() { - File directory = new File(context.getFilesDir(), SESSIONS_DIRECTORY_V2); - - if (!directory.exists()) { - if (!directory.mkdirs()) { - Log.w(TAG, "Session directory creation failed!"); - } - } - - return directory; - } - - private String getSessionName(SignalProtocolAddress address) { - int deviceId = address.getDeviceId(); - return address.getName() + (deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID ? "" : "." + deviceId); - } - - private @Nullable SignalProtocolAddress getAddressName(File sessionFile) { - try { - String[] parts = sessionFile.getName().split("[.]"); - - int deviceId; - - if (parts.length > 1) deviceId = Integer.parseInt(parts[1]); - else deviceId = SignalServiceAddress.DEFAULT_DEVICE_ID; - - return new SignalProtocolAddress(parts[0], deviceId); - } catch (NumberFormatException e) { - Log.w(TAG, e); - return null; - } - } - - private byte[] readBlob(FileInputStream in) throws IOException { - int length = readInteger(in); - byte[] blobBytes = new byte[length]; - - in.read(blobBytes, 0, blobBytes.length); - return blobBytes; - } - - private void writeBlob(byte[] blobBytes, FileChannel out) throws IOException { - writeInteger(blobBytes.length, out); - out.write(ByteBuffer.wrap(blobBytes)); - } - - private int readInteger(FileInputStream in) throws IOException { - byte[] integer = new byte[4]; - in.read(integer, 0, integer.length); - return Conversions.byteArrayToInt(integer); - } - - private void writeInteger(int value, FileChannel out) throws IOException { - byte[] valueBytes = Conversions.intToByteArray(value); - out.write(ByteBuffer.wrap(valueBytes)); - } - } diff --git a/src/org/thoughtcrime/securesms/database/DatabaseFactory.java b/src/org/thoughtcrime/securesms/database/DatabaseFactory.java index 47cd06b1d..0937e7244 100644 --- a/src/org/thoughtcrime/securesms/database/DatabaseFactory.java +++ b/src/org/thoughtcrime/securesms/database/DatabaseFactory.java @@ -55,6 +55,7 @@ public class DatabaseFactory { private final GroupReceiptDatabase groupReceiptDatabase; private final OneTimePreKeyDatabase preKeyDatabase; private final SignedPreKeyDatabase signedPreKeyDatabase; + private final SessionDatabase sessionDatabase; public static DatabaseFactory getInstance(Context context) { synchronized (lock) { @@ -125,6 +126,10 @@ public class DatabaseFactory { return getInstance(context).signedPreKeyDatabase; } + public static SessionDatabase getSessionDatabase(Context context) { + return getInstance(context).sessionDatabase; + } + private DatabaseFactory(@NonNull Context context) { SQLiteDatabase.loadLibs(context); @@ -147,6 +152,7 @@ public class DatabaseFactory { this.contactsDatabase = new ContactsDatabase(context); this.preKeyDatabase = new OneTimePreKeyDatabase(context, databaseHelper); this.signedPreKeyDatabase = new SignedPreKeyDatabase(context, databaseHelper); + this.sessionDatabase = new SessionDatabase(context, databaseHelper); } public void onApplicationLevelUpgrade(@NonNull Context context, @NonNull MasterSecret masterSecret, diff --git a/src/org/thoughtcrime/securesms/database/OneTimePreKeyDatabase.java b/src/org/thoughtcrime/securesms/database/OneTimePreKeyDatabase.java index 82f620d0a..6129f8eae 100644 --- a/src/org/thoughtcrime/securesms/database/OneTimePreKeyDatabase.java +++ b/src/org/thoughtcrime/securesms/database/OneTimePreKeyDatabase.java @@ -62,7 +62,6 @@ public class OneTimePreKeyDatabase extends Database { return null; } - public void insertPreKey(int keyId, PreKeyRecord record) { SQLiteDatabase database = databaseHelper.getWritableDatabase(); @@ -71,7 +70,7 @@ public class OneTimePreKeyDatabase extends Database { contentValues.put(PUBLIC_KEY, Base64.encodeBytes(record.getKeyPair().getPublicKey().serialize())); contentValues.put(PRIVATE_KEY, Base64.encodeBytes(record.getKeyPair().getPrivateKey().serialize())); - database.insert(TABLE_NAME, null, contentValues); + database.replace(TABLE_NAME, null, contentValues); } public void removePreKey(int keyId) { @@ -79,8 +78,4 @@ public class OneTimePreKeyDatabase extends Database { database.delete(TABLE_NAME, KEY_ID + " = ?", new String[] {String.valueOf(keyId)}); } - - - - } diff --git a/src/org/thoughtcrime/securesms/database/SessionDatabase.java b/src/org/thoughtcrime/securesms/database/SessionDatabase.java new file mode 100644 index 000000000..c269ff343 --- /dev/null +++ b/src/org/thoughtcrime/securesms/database/SessionDatabase.java @@ -0,0 +1,170 @@ +package org.thoughtcrime.securesms.database; + + +import android.content.ContentValues; +import android.content.Context; +import android.database.Cursor; +import android.support.annotation.NonNull; +import android.support.annotation.Nullable; +import android.util.Log; + +import net.sqlcipher.database.SQLiteDatabase; + +import org.thoughtcrime.securesms.database.helpers.SQLCipherOpenHelper; +import org.whispersystems.libsignal.state.SessionRecord; +import org.whispersystems.signalservice.api.push.SignalServiceAddress; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; + +public class SessionDatabase extends Database { + + private static final String TAG = SessionDatabase.class.getSimpleName(); + + public static final String TABLE_NAME = "sessions"; + + private static final String ID = "_id"; + public static final String ADDRESS = "address"; + public static final String DEVICE = "device"; + public static final String RECORD = "record"; + + public static final String CREATE_TABLE = "CREATE TABLE " + TABLE_NAME + + "(" + ID + " INTEGER PRIMARY KEY, " + ADDRESS + " TEXT NOT NULL, " + + DEVICE + " INTEGER NOT NULL, " + RECORD + " BLOB NOT NULL, " + + "UNIQUE(" + ADDRESS + "," + DEVICE + ") ON CONFLICT REPLACE);"; + + SessionDatabase(Context context, SQLCipherOpenHelper databaseHelper) { + super(context, databaseHelper); + } + + public void store(@NonNull Address address, int deviceId, @NonNull SessionRecord record) { + SQLiteDatabase database = databaseHelper.getWritableDatabase(); + + ContentValues values = new ContentValues(); + values.put(ADDRESS, address.serialize()); + values.put(DEVICE, deviceId); + values.put(RECORD, record.serialize()); + + database.insertWithOnConflict(TABLE_NAME, null, values, SQLiteDatabase.CONFLICT_REPLACE); + } + + public @Nullable SessionRecord load(@NonNull Address address, int deviceId) { + SQLiteDatabase database = databaseHelper.getReadableDatabase(); + + try (Cursor cursor = database.query(TABLE_NAME, new String[]{RECORD}, + ADDRESS + " = ? AND " + DEVICE + " = ?", + new String[] {address.serialize(), String.valueOf(deviceId)}, + null, null, null)) + { + if (cursor != null && cursor.moveToFirst()) { + try { + return new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))); + } catch (IOException e) { + Log.w(TAG, e); + } + } + } + + return null; + } + + public @NonNull List getAllFor(@NonNull Address address) { + SQLiteDatabase database = databaseHelper.getReadableDatabase(); + List results = new LinkedList<>(); + + try (Cursor cursor = database.query(TABLE_NAME, null, + ADDRESS + " = ?", + new String[] {address.serialize()}, + null, null, null)) + { + while (cursor != null && cursor.moveToNext()) { + try { + results.add(new SessionRow(address, + cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE)), + new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))))); + } catch (IOException e) { + Log.w(TAG, e); + } + } + } + + return results; + } + + public @NonNull List getAll() { + SQLiteDatabase database = databaseHelper.getReadableDatabase(); + List results = new LinkedList<>(); + + try (Cursor cursor = database.query(TABLE_NAME, null, null, null, null, null, null)) { + while (cursor != null && cursor.moveToNext()) { + try { + results.add(new SessionRow(Address.fromSerialized(cursor.getString(cursor.getColumnIndexOrThrow(ADDRESS))), + cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE)), + new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))))); + } catch (IOException e) { + Log.w(TAG, e); + } + } + } + + return results; + } + + public @NonNull List getSubDevices(@NonNull Address address) { + SQLiteDatabase database = databaseHelper.getReadableDatabase(); + List results = new LinkedList<>(); + + try (Cursor cursor = database.query(TABLE_NAME, new String[] {DEVICE}, + ADDRESS + " = ?", + new String[] {address.serialize()}, + null, null, null)) + { + while (cursor != null && cursor.moveToNext()) { + int device = cursor.getInt(cursor.getColumnIndexOrThrow(DEVICE)); + + if (device != SignalServiceAddress.DEFAULT_DEVICE_ID) { + results.add(device); + } + } + } + + return results; + } + + public void delete(@NonNull Address address, int deviceId) { + SQLiteDatabase database = databaseHelper.getWritableDatabase(); + + database.delete(TABLE_NAME, ADDRESS + " = ? AND " + DEVICE + " = ?", + new String[] {address.serialize(), String.valueOf(deviceId)}); + } + + public void deleteAllFor(@NonNull Address address) { + SQLiteDatabase database = databaseHelper.getWritableDatabase(); + database.delete(TABLE_NAME, ADDRESS + " = ?", new String[] {address.serialize()}); + } + + public static final class SessionRow { + private final Address address; + private final int deviceId; + private final SessionRecord record; + + public SessionRow(Address address, int deviceId, SessionRecord record) { + this.address = address; + this.deviceId = deviceId; + this.record = record; + } + + public Address getAddress() { + return address; + } + + public int getDeviceId() { + return deviceId; + } + + public SessionRecord getRecord() { + return record; + } + } +} diff --git a/src/org/thoughtcrime/securesms/database/SignedPreKeyDatabase.java b/src/org/thoughtcrime/securesms/database/SignedPreKeyDatabase.java index c8508ab32..6f5e9625b 100644 --- a/src/org/thoughtcrime/securesms/database/SignedPreKeyDatabase.java +++ b/src/org/thoughtcrime/securesms/database/SignedPreKeyDatabase.java @@ -44,7 +44,7 @@ public class SignedPreKeyDatabase extends Database { SIGNATURE + " TEXT NOT NULL, " + TIMESTAMP + " INTEGER DEFAULT 0);"; - public SignedPreKeyDatabase(Context context, SQLCipherOpenHelper databaseHelper) { + SignedPreKeyDatabase(Context context, SQLCipherOpenHelper databaseHelper) { super(context, databaseHelper); } @@ -105,7 +105,7 @@ public class SignedPreKeyDatabase extends Database { contentValues.put(SIGNATURE, Base64.encodeBytes(record.getSignature())); contentValues.put(TIMESTAMP, record.getTimestamp()); - database.insert(TABLE_NAME, null, contentValues); + database.replace(TABLE_NAME, null, contentValues); } diff --git a/src/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java b/src/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java index 052e224a8..4d605ed53 100644 --- a/src/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java +++ b/src/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java @@ -21,6 +21,7 @@ import org.thoughtcrime.securesms.database.MmsDatabase; import org.thoughtcrime.securesms.database.OneTimePreKeyDatabase; import org.thoughtcrime.securesms.database.PushDatabase; import org.thoughtcrime.securesms.database.RecipientDatabase; +import org.thoughtcrime.securesms.database.SessionDatabase; import org.thoughtcrime.securesms.database.SignedPreKeyDatabase; import org.thoughtcrime.securesms.database.SmsDatabase; import org.thoughtcrime.securesms.database.ThreadDatabase; @@ -35,8 +36,9 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper { private static final int RECIPIENT_CALL_RINGTONE_VERSION = 2; private static final int MIGRATE_PREKEYS_VERSION = 3; + private static final int MIGRATE_SESSIONS_VERSION = 4; - private static final int DATABASE_VERSION = 3; + private static final int DATABASE_VERSION = 4; private static final String DATABASE_NAME = "signal.db"; private final Context context; @@ -75,6 +77,7 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper { db.execSQL(GroupReceiptDatabase.CREATE_TABLE); db.execSQL(OneTimePreKeyDatabase.CREATE_TABLE); db.execSQL(SignedPreKeyDatabase.CREATE_TABLE); + db.execSQL(SessionDatabase.CREATE_TABLE); executeStatements(db, SmsDatabase.CREATE_INDEXS); executeStatements(db, MmsDatabase.CREATE_INDEXS); @@ -99,6 +102,7 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper { ApplicationContext.getInstance(context).getJobManager().add(new RefreshPreKeysJob(context)); } + SessionStoreMigrationHelper.migrateSessions(context, db); PreKeyMigrationHelper.cleanUpPreKeys(context); } } @@ -123,8 +127,11 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper { if (!PreKeyMigrationHelper.migratePreKeys(context, db)) { ApplicationContext.getInstance(context).getJobManager().add(new RefreshPreKeysJob(context)); } + } - PreKeyMigrationHelper.cleanUpPreKeys(context); + if (oldVersion < MIGRATE_SESSIONS_VERSION) { + db.execSQL("CREATE TABLE sessions (_id INTEGER PRIMARY KEY, address TEXT NOT NULL, device INTEGER NOT NULL, record BLOB NOT NULL, UNIQUE(address, device) ON CONFLICT REPLACE)"); + SessionStoreMigrationHelper.migrateSessions(context, db); } db.setTransactionSuccessful(); diff --git a/src/org/thoughtcrime/securesms/database/helpers/SessionStoreMigrationHelper.java b/src/org/thoughtcrime/securesms/database/helpers/SessionStoreMigrationHelper.java new file mode 100644 index 000000000..6c725d5c9 --- /dev/null +++ b/src/org/thoughtcrime/securesms/database/helpers/SessionStoreMigrationHelper.java @@ -0,0 +1,109 @@ +package org.thoughtcrime.securesms.database.helpers; + + +import android.content.ContentValues; +import android.content.Context; +import android.util.Log; + +import net.sqlcipher.database.SQLiteDatabase; + +import org.thoughtcrime.securesms.database.Address; +import org.thoughtcrime.securesms.database.SessionDatabase; +import org.thoughtcrime.securesms.util.Conversions; +import org.whispersystems.libsignal.state.SessionRecord; +import org.whispersystems.libsignal.state.SessionState; +import org.whispersystems.libsignal.state.StorageProtos.SessionStructure; +import org.whispersystems.signalservice.api.push.SignalServiceAddress; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; + +class SessionStoreMigrationHelper { + + private static final String TAG = SessionStoreMigrationHelper.class.getSimpleName(); + + private static final String SESSIONS_DIRECTORY_V2 = "sessions-v2"; + private static final Object FILE_LOCK = new Object(); + + private static final int SINGLE_STATE_VERSION = 1; + private static final int ARCHIVE_STATES_VERSION = 2; + private static final int PLAINTEXT_VERSION = 3; + private static final int CURRENT_VERSION = 3; + + static void migrateSessions(Context context, SQLiteDatabase database) { + File directory = new File(context.getFilesDir(), SESSIONS_DIRECTORY_V2); + + if (directory.exists()) { + File[] sessionFiles = directory.listFiles(); + + if (sessionFiles != null) { + for (File sessionFile : sessionFiles) { + try { + String[] parts = sessionFile.getName().split("[.]"); + Address address = Address.fromSerialized(parts[0]); + + int deviceId; + + if (parts.length > 1) deviceId = Integer.parseInt(parts[1]); + else deviceId = SignalServiceAddress.DEFAULT_DEVICE_ID; + + FileInputStream in = new FileInputStream(sessionFile); + int versionMarker = readInteger(in); + + if (versionMarker > CURRENT_VERSION) { + throw new AssertionError("Unknown version: " + versionMarker + ", " + sessionFile.getAbsolutePath()); + } + + byte[] serialized = readBlob(in); + in.close(); + + if (versionMarker < PLAINTEXT_VERSION) { + throw new AssertionError("Not plaintext: " + versionMarker + ", " + sessionFile.getAbsolutePath()); + } + + SessionRecord sessionRecord; + + if (versionMarker == SINGLE_STATE_VERSION) { + Log.w(TAG, "Migrating single state version: " + sessionFile.getAbsolutePath()); + SessionStructure sessionStructure = SessionStructure.parseFrom(serialized); + SessionState sessionState = new SessionState(sessionStructure); + + sessionRecord = new SessionRecord(sessionState); + } else if (versionMarker >= ARCHIVE_STATES_VERSION) { + Log.w(TAG, "Migrating session: " + sessionFile.getAbsolutePath()); + sessionRecord = new SessionRecord(serialized); + } else { + throw new AssertionError("Unknown version: " + versionMarker + ", " + sessionFile.getAbsolutePath()); + } + + + ContentValues contentValues = new ContentValues(); + contentValues.put(SessionDatabase.ADDRESS, address.serialize()); + contentValues.put(SessionDatabase.DEVICE, deviceId); + contentValues.put(SessionDatabase.RECORD, sessionRecord.serialize()); + + database.insert(SessionDatabase.TABLE_NAME, null, contentValues); + } catch (NumberFormatException | IOException e) { + Log.w(TAG, e); + } + } + } + } + } + + private static byte[] readBlob(FileInputStream in) throws IOException { + int length = readInteger(in); + byte[] blobBytes = new byte[length]; + + in.read(blobBytes, 0, blobBytes.length); + return blobBytes; + } + + private static int readInteger(FileInputStream in) throws IOException { + byte[] integer = new byte[4]; + in.read(integer, 0, integer.length); + return Conversions.byteArrayToInt(integer); + } + +}