From 2761c855640723f1a8f8c5c5cb27c3c6ffda3a73 Mon Sep 17 00:00:00 2001 From: Kevin Hester Date: Tue, 23 Feb 2021 10:10:35 +0800 Subject: [PATCH] clean up the crypto api --- src/esp32/ESP32CryptoEngine.cpp | 15 +-- src/mesh/Channels.cpp | 142 ++++++++++++-------- src/mesh/Channels.h | 29 ++-- src/mesh/CryptoEngine.cpp | 5 +- src/mesh/CryptoEngine.h | 11 +- src/nrf52/NRF52CryptoEngine.cpp | 27 +--- src/portduino/CrossPlatformCryptoEngine.cpp | 11 +- 7 files changed, 135 insertions(+), 105 deletions(-) diff --git a/src/esp32/ESP32CryptoEngine.cpp b/src/esp32/ESP32CryptoEngine.cpp index 613d5cc1..9d86ffeb 100644 --- a/src/esp32/ESP32CryptoEngine.cpp +++ b/src/esp32/ESP32CryptoEngine.cpp @@ -18,9 +18,6 @@ class ESP32CryptoEngine : public CryptoEngine mbedtls_aes_context aes; - /// How many bytes in our key - uint8_t keySize = 0; - public: ESP32CryptoEngine() { mbedtls_aes_init(&aes); } @@ -35,12 +32,12 @@ class ESP32CryptoEngine : public CryptoEngine * @param bytes a _static_ buffer that will remain valid for the life of this crypto instance (i.e. this class will cache the * provided pointer) */ - virtual void setKey(size_t numBytes, uint8_t *bytes) + virtual void setKey(const CryptoKey &k) { - keySize = numBytes; - DEBUG_MSG("Installing AES%d key!\n", numBytes * 8); - if (numBytes != 0) { - auto res = mbedtls_aes_setkey_enc(&aes, bytes, numBytes * 8); + CryptoEngine::setKey(k); + + if (key.length != 0) { + auto res = mbedtls_aes_setkey_enc(&aes, key.bytes, key.length * 8); assert(!res); } } @@ -52,7 +49,7 @@ class ESP32CryptoEngine : public CryptoEngine */ virtual void encrypt(uint32_t fromNode, uint64_t packetNum, size_t numBytes, uint8_t *bytes) { - if (keySize != 0) { + if (key.length > 0) { uint8_t stream_block[16]; static uint8_t scratch[MAX_BLOCKSIZE]; size_t nc_off = 0; diff --git a/src/mesh/Channels.cpp b/src/mesh/Channels.cpp index 696f78f4..f3238b9c 100644 --- a/src/mesh/Channels.cpp +++ b/src/mesh/Channels.cpp @@ -10,7 +10,7 @@ static const uint8_t defaultpsk[] = {0xd4, 0xf1, 0xbb, 0x3a, 0x20, 0x29, 0x07, 0 Channels channels; -uint8_t xorHash(uint8_t *p, size_t len) +uint8_t xorHash(const uint8_t *p, size_t len) { uint8_t code = 0; for (int i = 0; i < len; i++) @@ -18,6 +18,26 @@ uint8_t xorHash(uint8_t *p, size_t len) return code; } +/** Given a channel number, return the (0 to 255) hash for that channel. + * The hash is just an xor of the channel name followed by the channel PSK being used for encryption + * If no suitable channel could be found, return -1 + */ +int16_t Channels::generateHash(ChannelIndex channelNum) +{ + auto k = getKey(channelNum); + if (k.length < 0) + return -1; // invalid + else { + Channel &c = getByIndex(channelNum); + + uint8_t h = xorHash((const uint8_t *)c.settings.name, strlen(c.settings.name)); + + h ^= xorHash(k.bytes, k.length); + + return h; + } +} + /** * Validate a channel, fixing any errors as needed */ @@ -75,51 +95,69 @@ void Channels::initDefaultChannel(ChannelIndex chIndex) ch.role = Channel_Role_PRIMARY; } -/** Given a channel index, change to use the crypto key specified by that index - */ -void Channels::setCrypto(ChannelIndex chIndex) +CryptoKey Channels::getKey(ChannelIndex chIndex) { Channel &ch = getByIndex(chIndex); ChannelSettings &channelSettings = ch.settings; assert(ch.has_settings); - memset(activePSK, 0, sizeof(activePSK)); // In case the user provided a short key, we want to pad the rest with zeros - memcpy(activePSK, channelSettings.psk.bytes, channelSettings.psk.size); - activePSKSize = channelSettings.psk.size; - if (activePSKSize == 0) { - if (ch.role == Channel_Role_SECONDARY) { - DEBUG_MSG("Unset PSK for secondary channel %s. using primary key\n", ch.settings.name); - setCrypto(primaryIndex); - } else - DEBUG_MSG("Warning: User disabled encryption\n"); - } else if (activePSKSize == 1) { - // Convert the short single byte variants of psk into variant that can be used more generally + CryptoKey k; + memset(k.bytes, 0, sizeof(k.bytes)); // In case the user provided a short key, we want to pad the rest with zeros - uint8_t pskIndex = activePSK[0]; - DEBUG_MSG("Expanding short PSK #%d\n", pskIndex); - if (pskIndex == 0) - activePSKSize = 0; // Turn off encryption - else { - memcpy(activePSK, defaultpsk, sizeof(defaultpsk)); - activePSKSize = sizeof(defaultpsk); - // Bump up the last byte of PSK as needed - uint8_t *last = activePSK + sizeof(defaultpsk) - 1; - *last = *last + pskIndex - 1; // index of 1 means no change vs defaultPSK + if (ch.role == Channel_Role_DISABLED) { + k.length = -1; // invalid + } else { + memcpy(k.bytes, channelSettings.psk.bytes, channelSettings.psk.size); + k.length = channelSettings.psk.size; + if (k.length == 0) { + if (ch.role == Channel_Role_SECONDARY) { + DEBUG_MSG("Unset PSK for secondary channel %s. using primary key\n", ch.settings.name); + k = getKey(primaryIndex); + } else + DEBUG_MSG("Warning: User disabled encryption\n"); + } else if (k.length == 1) { + // Convert the short single byte variants of psk into variant that can be used more generally + + uint8_t pskIndex = k.bytes[0]; + DEBUG_MSG("Expanding short PSK #%d\n", pskIndex); + if (pskIndex == 0) + k.length = 0; // Turn off encryption + else { + memcpy(k.bytes, defaultpsk, sizeof(defaultpsk)); + k.length = sizeof(defaultpsk); + // Bump up the last byte of PSK as needed + uint8_t *last = k.bytes + sizeof(defaultpsk) - 1; + *last = *last + pskIndex - 1; // index of 1 means no change vs defaultPSK + } + } else if (k.length < 16) { + // Error! The user specified only the first few bits of an AES128 key. So by convention we just pad the rest of the + // key with zeros + DEBUG_MSG("Warning: User provided a too short AES128 key - padding\n"); + k.length = 16; + } else if (k.length < 32 && k.length != 16) { + // Error! The user specified only the first few bits of an AES256 key. So by convention we just pad the rest of the + // key with zeros + DEBUG_MSG("Warning: User provided a too short AES256 key - padding\n"); + k.length = 32; } - } else if (activePSKSize < 16) { - // Error! The user specified only the first few bits of an AES128 key. So by convention we just pad the rest of the key - // with zeros - DEBUG_MSG("Warning: User provided a too short AES128 key - padding\n"); - activePSKSize = 16; - } else if (activePSKSize < 32 && activePSKSize != 16) { - // Error! The user specified only the first few bits of an AES256 key. So by convention we just pad the rest of the key - // with zeros - DEBUG_MSG("Warning: User provided a too short AES256 key - padding\n"); - activePSKSize = 32; } - // Tell our crypto engine about the psk - crypto->setKey(activePSKSize, activePSK); + return k; +} + +/** Given a channel index, change to use the crypto key specified by that index + */ +int16_t Channels::setCrypto(ChannelIndex chIndex) +{ + CryptoKey k = getKey(chIndex); + + if (k.length < 0) + return -1; + else { + // Tell our crypto engine about the psk + crypto->setKey(k); + return getHash(chIndex); + } } void Channels::initDefaults() @@ -139,8 +177,6 @@ void Channels::onConfigChanged() if (ch.role == Channel_Role_PRIMARY) primaryIndex = i; } - - setCrypto(primaryIndex); // FIXME: for the time being (still single channel - just use our only channel as the crypto key) } Channel &Channels::getByIndex(ChannelIndex chIndex) @@ -207,7 +243,6 @@ their nodes * * Where X is either: * (for custom PSKS) a letter from A to Z (base26), and formed by xoring all the bytes of the PSK together, -* OR (for the standard minimially secure PSKs) a number from 0 to 9. * * This function will also need to be implemented in GUI apps that talk to the radio. * @@ -219,14 +254,14 @@ const char *Channels::getPrimaryName() char suffix; auto channelSettings = getPrimary(); - if (channelSettings.psk.size != 1) { - // We have a standard PSK, so generate a letter based hash. - uint8_t code = xorHash(activePSK, activePSKSize); + // if (channelSettings.psk.size != 1) { + // We have a standard PSK, so generate a letter based hash. + uint8_t code = getHash(primaryIndex); - suffix = 'A' + (code % 26); - } else { + suffix = 'A' + (code % 26); + /* } else { suffix = '0' + channelSettings.psk.bytes[0]; - } + } */ snprintf(buf, sizeof(buf), "#%s-%c", channelSettings.name, suffix); return buf; @@ -238,7 +273,10 @@ const char *Channels::getPrimaryName() * * @return -1 if no suitable channel could be found, otherwise returns the channel index */ -int16_t Channels::setActiveByHash(ChannelHash channelHash) {} +int16_t Channels::setActiveByHash(ChannelHash channelHash) +{ + // fixme cant work; +} /** Given a channel index setup crypto for encoding that channel (or the primary channel if that channel is unsecured) * @@ -246,9 +284,7 @@ int16_t Channels::setActiveByHash(ChannelHash channelHash) {} * * @eturn the (0 to 255) hash for that channel - if no suitable channel could be found, return -1 */ -int16_t Channels::setActiveByIndex(ChannelIndex channelIndex) {} - -/** Given a channel number, return the (0 to 255) hash for that channel - * If no suitable channel could be found, return -1 - */ -ChannelHash Channels::generateHash(ChannelIndex channelNum) {} \ No newline at end of file +int16_t Channels::setActiveByIndex(ChannelIndex channelIndex) +{ + return setCrypto(channelIndex); +} diff --git a/src/mesh/Channels.h b/src/mesh/Channels.h index a86b7ff6..4e70e230 100644 --- a/src/mesh/Channels.h +++ b/src/mesh/Channels.h @@ -2,6 +2,7 @@ #include "mesh-pb-constants.h" #include +#include "CryptoEngine.h" /** A channel number (index into the channel table) */ @@ -23,12 +24,8 @@ class Channels no sending or receiving will be allowed */ ChannelIndex activeChannelIndex = 0; - /// The in-use psk - which has been constructed based on the (possibly short psk) in channelSettings - uint8_t activePSK[32]; - uint8_t activePSKSize = 0; - - /// the precomputed hashes for each of our channels - ChannelHash hashes[MAX_NUM_CHANNELS]; + /// the precomputed hashes for each of our channels, or -1 for invalid + int16_t hashes[MAX_NUM_CHANNELS]; public: const ChannelSettings &getPrimary() { return getByIndex(getPrimaryIndex()).settings; } @@ -87,21 +84,24 @@ class Channels */ int16_t setActiveByIndex(ChannelIndex channelIndex); - /** return the channel hash we are currently using for sending */ - ChannelHash getActiveHash(); - private: /** Given a channel index, change to use the crypto key specified by that index + * + * @eturn the (0 to 255) hash for that channel - if no suitable channel could be found, return -1 */ - void setCrypto(ChannelIndex chIndex); + int16_t setCrypto(ChannelIndex chIndex); /** Return the channel index for the specified channel hash, or -1 for not found */ int8_t getIndexByHash(ChannelHash channelHash); /** Given a channel number, return the (0 to 255) hash for that channel * If no suitable channel could be found, return -1 + * + * called by fixupChannel when a new channel is set */ - ChannelHash generateHash(ChannelIndex channelNum); + int16_t generateHash(ChannelIndex channelNum); + + int16_t getHash(ChannelIndex i) { return hashes[i]; } /** * Validate a channel, fixing any errors as needed @@ -112,6 +112,13 @@ class Channels * Write a default channel to the specified channel index */ void initDefaultChannel(ChannelIndex chIndex); + + /** + * Return the key used for encrypting this channel (if channel is secondary and no key provided, use the primary channel's PSK) + */ + CryptoKey getKey(ChannelIndex chIndex); + + }; /// Singleton channel table diff --git a/src/mesh/CryptoEngine.cpp b/src/mesh/CryptoEngine.cpp index d72be111..74f4b783 100644 --- a/src/mesh/CryptoEngine.cpp +++ b/src/mesh/CryptoEngine.cpp @@ -1,9 +1,10 @@ #include "CryptoEngine.h" #include "configuration.h" -void CryptoEngine::setKey(size_t numBytes, uint8_t *bytes) +void CryptoEngine::setKey(const CryptoKey &k) { - DEBUG_MSG("WARNING: Using stub crypto - all crypto is sent in plaintext!\n"); + DEBUG_MSG("Installing AES%d key!\n", k.length * 8); + key = k; } /** diff --git a/src/mesh/CryptoEngine.h b/src/mesh/CryptoEngine.h index b97abed5..9853f564 100644 --- a/src/mesh/CryptoEngine.h +++ b/src/mesh/CryptoEngine.h @@ -2,6 +2,13 @@ #include +struct CryptoKey { + uint8_t bytes[32]; + + /// # of bytes, or -1 to mean "invalid key - do not use" + int8_t length; +}; + /** * see docs/software/crypto.md for details. * @@ -15,6 +22,8 @@ class CryptoEngine /** Our per packet nonce */ uint8_t nonce[16]; + CryptoKey key; + public: virtual ~CryptoEngine() {} @@ -27,7 +36,7 @@ class CryptoEngine * @param bytes a _static_ buffer that will remain valid for the life of this crypto instance (i.e. this class will cache the * provided pointer) */ - virtual void setKey(size_t numBytes, uint8_t *bytes); + virtual void setKey(const CryptoKey &k); /** * Encrypt a packet diff --git a/src/nrf52/NRF52CryptoEngine.cpp b/src/nrf52/NRF52CryptoEngine.cpp index 2bf16f23..431fa2e9 100644 --- a/src/nrf52/NRF52CryptoEngine.cpp +++ b/src/nrf52/NRF52CryptoEngine.cpp @@ -6,30 +6,13 @@ class NRF52CryptoEngine : public CryptoEngine { - /// How many bytes in our key - uint8_t keySize = 0; - const uint8_t *keyBytes; + public: NRF52CryptoEngine() {} ~NRF52CryptoEngine() {} - /** - * Set the key used for encrypt, decrypt. - * - * As a special case: If all bytes are zero, we assume _no encryption_ and send all data in cleartext. - * - * @param numBytes must be 16 (AES128), 32 (AES256) or 0 (no crypt) - * @param bytes a _static_ buffer that will remain valid for the life of this crypto instance (i.e. this class will cache the - * provided pointer) - */ - virtual void setKey(size_t numBytes, uint8_t *bytes) - { - keySize = numBytes; - keyBytes = bytes; - } - /** * Encrypt a packet * @@ -39,11 +22,11 @@ class NRF52CryptoEngine : public CryptoEngine { // DEBUG_MSG("NRF52 encrypt!\n"); - if (keySize != 0) { + if (key.length > 0) { ocrypto_aes_ctr_ctx ctx; initNonce(fromNode, packetNum); - ocrypto_aes_ctr_init(&ctx, keyBytes, keySize, nonce); + ocrypto_aes_ctr_init(&ctx, key.bytes, key.length, nonce); ocrypto_aes_ctr_encrypt(&ctx, bytes, bytes, numBytes); } @@ -53,11 +36,11 @@ class NRF52CryptoEngine : public CryptoEngine { // DEBUG_MSG("NRF52 decrypt!\n"); - if (keySize != 0) { + if (key.length > 0) { ocrypto_aes_ctr_ctx ctx; initNonce(fromNode, packetNum); - ocrypto_aes_ctr_init(&ctx, keyBytes, keySize, nonce); + ocrypto_aes_ctr_init(&ctx, key.bytes, key.length, nonce); ocrypto_aes_ctr_decrypt(&ctx, bytes, bytes, numBytes); } diff --git a/src/portduino/CrossPlatformCryptoEngine.cpp b/src/portduino/CrossPlatformCryptoEngine.cpp index b9e818c0..06225aa1 100644 --- a/src/portduino/CrossPlatformCryptoEngine.cpp +++ b/src/portduino/CrossPlatformCryptoEngine.cpp @@ -10,9 +10,6 @@ class CrossPlatformCryptoEngine : public CryptoEngine CTRCommon *ctr = NULL; - /// How many bytes in our key - uint8_t keySize = 0; - public: CrossPlatformCryptoEngine() {} @@ -27,9 +24,9 @@ class CrossPlatformCryptoEngine : public CryptoEngine * @param bytes a _static_ buffer that will remain valid for the life of this crypto instance (i.e. this class will cache the * provided pointer) */ - virtual void setKey(size_t numBytes, uint8_t *bytes) + virtual void setKey(const CryptoKey &k) { - keySize = numBytes; + CryptoEngine::setKey(k); DEBUG_MSG("Installing AES%d key!\n", numBytes * 8); if (ctr) { delete ctr; @@ -41,7 +38,7 @@ class CrossPlatformCryptoEngine : public CryptoEngine else ctr = new CTR(); - ctr->setKey(bytes, numBytes); + ctr->setKey(key.bytes, key.length); } } @@ -52,7 +49,7 @@ class CrossPlatformCryptoEngine : public CryptoEngine */ virtual void encrypt(uint32_t fromNode, uint64_t packetNum, size_t numBytes, uint8_t *bytes) { - if (keySize != 0) { + if (key.length > 0) { uint8_t stream_block[16]; static uint8_t scratch[MAX_BLOCKSIZE]; size_t nc_off = 0;