Signal-Android/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdsiSocket.java

223 wiersze
10 KiB
Java

package org.whispersystems.signalservice.api.services;
import org.signal.cdsi.proto.ClientRequest;
import org.signal.cdsi.proto.ClientResponse;
import org.signal.libsignal.cds2.AttestationDataException;
import org.signal.libsignal.cds2.Cds2Client;
import org.signal.libsignal.cds2.Cds2CommunicationFailureException;
import org.signal.libsignal.protocol.logging.Log;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.signalservice.api.push.TrustStore;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.api.util.Tls12SocketFactory;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager;
import org.whispersystems.signalservice.internal.util.Hex;
import org.whispersystems.signalservice.internal.util.Util;
import org.whispersystems.util.Base64;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import io.reactivex.rxjava3.core.Observable;
import okhttp3.ConnectionSpec;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
/**
 * Handles the websocket and general lifecycle of a CDSI request.
 *
 * <p>Each subscription to {@link #connect} opens its own websocket and walks it through the
 * {@link Stage} state machine:
 * <ol>
 *   <li>receive the attestation message and send the initial handshake request,</li>
 *   <li>complete the handshake and send the encrypted {@link ClientRequest},</li>
 *   <li>receive a token, hand it to the caller and send a token ack,</li>
 *   <li>emit every subsequent {@link ClientResponse}.</li>
 * </ol>
 */
final class CdsiSocket {

  private static final String TAG = CdsiSocket.class.getSimpleName();

  /**
   * PEM-encoded root certificate handed to {@link Cds2Client} for attestation validation
   * (the encoded subject reads "Intel SGX Root CA").
   *
   * <p>Lines deliberately carry no leading whitespace: strict PEM parsers only recognise the
   * BEGIN/END boundary lines when they start at column 0.
   */
  private static final byte[] CERTIFICATE = ("-----BEGIN CERTIFICATE-----\n"
      + "MIICjzCCAjSgAwIBAgIUImUM1lqdNInzg7SVUr9QGzknBqwwCgYIKoZIzj0EAwIw\n"
      + "aDEaMBgGA1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENv\n"
      + "cnBvcmF0aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJ\n"
      + "BgNVBAYTAlVTMB4XDTE4MDUyMTEwNDUxMFoXDTQ5MTIzMTIzNTk1OVowaDEaMBgG\n"
      + "A1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENvcnBvcmF0\n"
      + "aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJBgNVBAYT\n"
      + "AlVTMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEC6nEwMDIYZOj/iPWsCzaEKi7\n"
      + "1OiOSLRFhWGjbnBVJfVnkY4u3IjkDYYL0MxO4mqsyYjlBalTVYxFP2sJBK5zlKOB\n"
      + "uzCBuDAfBgNVHSMEGDAWgBQiZQzWWp00ifODtJVSv1AbOScGrDBSBgNVHR8ESzBJ\n"
      + "MEegRaBDhkFodHRwczovL2NlcnRpZmljYXRlcy50cnVzdGVkc2VydmljZXMuaW50\n"
      + "ZWwuY29tL0ludGVsU0dYUm9vdENBLmRlcjAdBgNVHQ4EFgQUImUM1lqdNInzg7SV\n"
      + "Ur9QGzknBqwwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQEwCgYI\n"
      + "KoZIzj0EAwIDSQAwRgIhAOW/5QkR+S9CiSDcNoowLuPRLsWGf/Yi7GSX94BgwTwg\n"
      + "AiEA4J0lrHoMs+Xo5o/sX6O9QWxHRAvZUGOdRQ7cvqRXaqI=\n"
      + "-----END CERTIFICATE-----").getBytes(StandardCharsets.UTF_8);

  private final OkHttpClient okhttp;
  private final String       baseUrl;
  private final String       mrEnclave;

  /**
   * Builds the HTTP client used for every CDSI connection.
   *
   * <p>Only the first configured CDSI URL is used. Its trust store backs the TLS socket factory,
   * which is wrapped in {@link Tls12SocketFactory} and restricted to
   * {@link ConnectionSpec#RESTRICTED_TLS}. Read and connect timeouts are 30 seconds.
   *
   * @param configuration service configuration supplying the CDSI URL and trust store
   * @param mrEnclave     hex-encoded enclave measurement; used in the request path and decoded
   *                      for attestation
   */
  CdsiSocket(SignalServiceConfiguration configuration, String mrEnclave) {
    this.baseUrl   = configuration.getSignalCdsiUrls()[0].getUrl();
    this.mrEnclave = mrEnclave;

    Pair<SSLSocketFactory, X509TrustManager> socketFactory = createTlsSocketFactory(configuration.getSignalCdsiUrls()[0].getTrustStore());

    this.okhttp = new OkHttpClient.Builder()
        .sslSocketFactory(new Tls12SocketFactory(socketFactory.first()), socketFactory.second())
        .connectionSpecs(Util.immutableList(ConnectionSpec.RESTRICTED_TLS))
        .readTimeout(30, TimeUnit.SECONDS)
        .connectTimeout(30, TimeUnit.SECONDS)
        .build();
  }

  /**
   * Opens a websocket to {@code <baseUrl>/v1/<mrEnclave>/discovery} and runs a CDSI exchange over it.
   *
   * <p>The returned Observable is cold. Each subscription opens its own websocket with its own
   * {@link Cds2Client}, so concurrent subscriptions do not share handshake state.
   *
   * @param username      username for the HTTP Basic {@code Authorization} header
   * @param password      password for the HTTP Basic {@code Authorization} header
   * @param clientRequest request that is encrypted and sent once the handshake completes
   * @param tokenSaver    receives the raw token bytes from the server's token response.
   *                      It is called before the token ack is sent.
   * @return an Observable that:
   *         <ul>
   *           <li>emits every {@link ClientResponse} received after the token ack;</li>
   *           <li>completes when the server closes with code 1000;</li>
   *           <li>errors with {@link NonSuccessfulResponseCodeException} on any other close code;</li>
   *           <li>errors with the underlying exception on a transport failure or a
   *               protocol/attestation failure.</li>
   *         </ul>
   *         Disposing the subscription closes the websocket.
   */
  Observable<ClientResponse> connect(String username, String password, ClientRequest clientRequest, Consumer<byte[]> tokenSaver) {
    return Observable.create(emitter -> {
      AtomicReference<Stage> stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE);

      String  url     = String.format("%s/v1/%s/discovery", baseUrl, mrEnclave);
      Request request = new Request.Builder()
          .url(url)
          .addHeader("Authorization", basicAuth(username, password))
          .build();

      WebSocket webSocket = okhttp.newWebSocket(request, new WebSocketListener() {
        // Enclave session for this connection only. It is created when the attestation message
        // arrives, and kept here rather than on CdsiSocket so parallel connections cannot
        // clobber each other.
        private Cds2Client client;

        @Override
        public void onOpen(WebSocket webSocket, Response response) {
          Log.d(TAG, "onOpen");
          stage.set(Stage.WAITING_FOR_CONNECTION);
        }

        @Override
        public void onMessage(WebSocket webSocket, okio.ByteString bytes) {
          Log.d(TAG, "[onMessage] stage: " + stage.get());

          try {
            switch (stage.get()) {
              case WAITING_TO_INITIALIZE:
                // The initial state: onOpen has not run yet, so no message is expected.
                throw new IOException("Received a message before we were open!");

              case WAITING_FOR_CONNECTION:
                // The first message carries the attestation data used to build the client.
                // NOTE(review): this uses the NOT_FOR_PRODUCTION factory and validates
                // attestation as of 24 hours ago. Presumably this tolerates clock skew or the
                // age of the attestation data; confirm the intent.
                client = Cds2Client.create_NOT_FOR_PRODUCTION(Hex.fromStringCondensed(mrEnclave),
                                                              CERTIFICATE,
                                                              bytes.toByteArray(),
                                                              Instant.now().minus(Duration.ofHours(24)));

                Log.d(TAG, "[onMessage] Sending initial handshake...");
                webSocket.send(okio.ByteString.of(client.initialRequest()));
                stage.set(Stage.WAITING_FOR_HANDSHAKE);
                break;

              case WAITING_FOR_HANDSHAKE:
                client.completeHandshake(bytes.toByteArray());
                Log.d(TAG, "[onMessage] Handshake read success.");

                Log.d(TAG, "[onMessage] Sending data...");
                byte[] ciphertextBytes = client.establishedSend(clientRequest.toByteArray());
                webSocket.send(okio.ByteString.of(ciphertextBytes));
                Log.d(TAG, "[onMessage] Data sent.");

                stage.set(Stage.WAITING_FOR_TOKEN);
                break;

              case WAITING_FOR_TOKEN:
                ClientResponse tokenResponse = ClientResponse.parseFrom(client.establishedRecv(bytes.toByteArray()));

                if (tokenResponse.getToken().isEmpty()) {
                  throw new IOException("No token! Cannot continue!");
                }

                tokenSaver.accept(tokenResponse.getToken().toByteArray());

                Log.d(TAG, "[onMessage] Sending token ack...");
                webSocket.send(okio.ByteString.of(client.establishedSend(ClientRequest.newBuilder()
                                                                                      .setTokenAck(true)
                                                                                      .build()
                                                                                      .toByteArray())));
                stage.set(Stage.WAITING_FOR_RESPONSE);
                break;

              case WAITING_FOR_RESPONSE:
                emitter.onNext(ClientResponse.parseFrom(client.establishedRecv(bytes.toByteArray())));
                break;

              case CLOSED:
                Log.w(TAG, "[onMessage] Received a message after the websocket closed! Ignoring.");
                break;

              case FAILED:
                Log.w(TAG, "[onMessage] Received a message after we entered the failure state! Ignoring.");
                webSocket.close(1000, "OK");
                break;
            }
          } catch (IOException | AttestationDataException | Cds2CommunicationFailureException e) {
            Log.w(TAG, e);
            // Enter FAILED first so later frames are not fed to the half-used client.
            stage.set(Stage.FAILED);
            webSocket.close(1000, "OK");
            // tryOnError avoids an undeliverable-error report if the emitter already terminated.
            emitter.tryOnError(e);
          }
        }

        @Override
        public void onClosing(WebSocket webSocket, int code, String reason) {
          if (code == 1000) {
            stage.set(Stage.CLOSED);
            emitter.onComplete();
          } else {
            Log.w(TAG, "Remote side is closing with non-normal code " + code);
            stage.set(Stage.FAILED);
            webSocket.close(1000, "Remote closed with code " + code);
            emitter.tryOnError(new NonSuccessfulResponseCodeException(code));
          }
        }

        @Override
        public void onFailure(WebSocket webSocket, Throwable t, Response response) {
          stage.set(Stage.FAILED);
          emitter.tryOnError(t);
          webSocket.close(1000, "OK");
        }
      });

      // Disposing the subscription (or its termination) closes the socket.
      emitter.setCancellable(() -> webSocket.close(1000, "OK"));
    });
  }

  /**
   * Builds an HTTP Basic {@code Authorization} header value.
   *
   * @param username the username
   * @param password the password
   * @return {@code "Basic "} followed by the Base64 of {@code username:password} encoded as UTF-8
   */
  private static String basicAuth(String username, String password) {
    return "Basic " + Base64.encodeBytes((username + ":" + password).getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Creates a TLS socket factory backed by {@link BlacklistingTrustManager} for the given trust store.
   *
   * @param trustStore the trust store to validate server certificates against
   * @return the socket factory, paired with the first trust manager cast to
   *         {@link X509TrustManager}. OkHttp needs both when a custom factory is supplied.
   * @throws AssertionError if the TLS context is unavailable or cannot be initialised
   */
  private static Pair<SSLSocketFactory, X509TrustManager> createTlsSocketFactory(TrustStore trustStore) {
    try {
      SSLContext     context       = SSLContext.getInstance("TLS");
      TrustManager[] trustManagers = BlacklistingTrustManager.createFor(trustStore);
      context.init(null, trustManagers, null);

      return new Pair<>(context.getSocketFactory(), (X509TrustManager) trustManagers[0]);
    } catch (NoSuchAlgorithmException | KeyManagementException e) {
      throw new AssertionError(e);
    }
  }

  /** Progress of a single CDSI websocket exchange. */
  private enum Stage {
    /** Socket requested but {@code onOpen} not yet received. */
    WAITING_TO_INITIALIZE,
    /** Socket open; waiting for the attestation message. */
    WAITING_FOR_CONNECTION,
    /** Initial handshake sent; waiting for the handshake reply. */
    WAITING_FOR_HANDSHAKE,
    /** Encrypted request sent; waiting for the token response. */
    WAITING_FOR_TOKEN,
    /** Token acknowledged; every further message is a {@link ClientResponse} to emit. */
    WAITING_FOR_RESPONSE,
    /** Remote closed normally (code 1000). */
    CLOSED,
    /** A protocol error, abnormal close or transport failure occurred. */
    FAILED
  }
}