diff --git a/include/kem.hpp b/include/kem.hpp index fd5e51b..846e139 100644 --- a/include/kem.hpp +++ b/include/kem.hpp @@ -60,12 +60,15 @@ keygen(std::span d, // used in CPA-PKE // cipher text of length (k * du * 32 + dv * 32) -bytes which can be shared with // recipient party ( having respective secret key ) over insecure channel. // -// It also returns a SHAKE256 object which acts as a KDF ( key derivation -// function ), used for generating arbitrary length shared secret key, to be -// used for symmetric key encryption between these two participating entities. +// It also computes a fixed length 32 -bytes shared secret, which can be used for +// symmetric key encryption between these two participating entities. Alternatively +// they might choose to derive longer keys from this shared secret. // -// Other side of communication should also be able to generate same arbitrary -// length key stream ( using KDF ), after successful decryption of cipher text. +// Other side of communication should also be able to generate same 32 -byte shared secret, +// after successful decryption of cipher text. +// +// If invalid public key is input, this function execution will fail, returning false, +// otherwise it will return true, while producing both cipher text and shared secret. // // See algorithm 8 defined in Kyber specification // https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf @@ -76,7 +79,7 @@ keygen(std::span d, // used in CPA-PKE // https://github.com/pq-crystals/kyber.git. It also helps in properly // benchmarking underlying KEM's encapsulation implementation. template -static inline void +[[nodiscard("Use result, it might fail because of malformed input public key")]] static inline bool encapsulate(std::span m, std::span pubkey, std::span cipher, @@ -106,8 +109,13 @@ encapsulate(std::span m, h512.finalize(); h512.digest(_g_out); - pke::encrypt(pubkey, m, _g_out1, cipher); + const auto has_mod_check_passed = pke::encrypt(pubkey, m, _g_out1, cipher); + if (!has_mod_check_passed) { + return has_mod_check_passed; + } + std::copy(_g_out0.begin(), _g_out0.end(), shared_secret.begin()); + return true; } // Given (k * 24 * 32 + 96) -bytes secret key and (k * du * 32 + dv * 32) -bytes @@ -171,7 +179,8 @@ decapsulate(std::span sec xof256.finalize(); xof256.squeeze(j_out); - pke::encrypt(pubkey, _g_in0, _g_out1, c_prime); + // Explicitly ignore return value, because public key, held as part of secret key is *assumed* to be valid. + (void)pke::encrypt(pubkey, _g_in0, _g_out1, c_prime); // line 7-11 of algorithm 9, in constant-time using kdf_t = std::span; diff --git a/include/kyber1024_kem.hpp b/include/kyber1024_kem.hpp index f819153..acb97da 100644 --- a/include/kyber1024_kem.hpp +++ b/include/kyber1024_kem.hpp @@ -43,13 +43,13 @@ keygen(std::span d, std::span z, std::span // at same SHAKE256 XOF backed KDF. // // Returned KDF can be used for deriving shared key of arbitrary bytes length. -inline void +[[nodiscard("If public key is malformed, encapsulation fails")]] inline bool encapsulate(std::span m, std::span pubkey, std::span cipher, std::span shared_secret) { - kem::encapsulate(m, pubkey, cipher, shared_secret); + return kem::encapsulate(m, pubkey, cipher, shared_secret); } // Given a Kyber1024 KEM secret key ( of 3168 -bytes ) and a cipher text of 1568 diff --git a/include/kyber512_kem.hpp b/include/kyber512_kem.hpp index 17eafc1..1abc9c4 100644 --- a/include/kyber512_kem.hpp +++ b/include/kyber512_kem.hpp @@ -43,13 +43,13 @@ keygen(std::span d, std::span z, std::span // SHAKE256 XOF backed KDF. // // Returned KDF can be used for deriving shared key of arbitrary bytes length. -inline void +[[nodiscard("If public key is malformed, encapsulation fails")]] inline bool encapsulate(std::span m, std::span pubkey, std::span cipher, std::span shared_secret) { - kem::encapsulate(m, pubkey, cipher, shared_secret); + return kem::encapsulate(m, pubkey, cipher, shared_secret); } // Given a Kyber512 KEM secret key ( of 1632 -bytes ) and a cipher text of 768 diff --git a/include/kyber768_kem.hpp b/include/kyber768_kem.hpp index 4ab239a..4e8ab34 100644 --- a/include/kyber768_kem.hpp +++ b/include/kyber768_kem.hpp @@ -42,13 +42,13 @@ keygen(std::span d, std::span z, std::span // at same SHAKE256 XOF backed KDF. // // Returned KDF can be used for deriving shared key of arbitrary bytes length. -inline void +[[nodiscard("If public key is malformed, encapsulation fails")]] inline bool encapsulate(std::span m, std::span pubkey, std::span cipher, std::span shared_secret) { - kem::encapsulate(m, pubkey, cipher, shared_secret); + return kem::encapsulate(m, pubkey, cipher, shared_secret); } // Given a Kyber768 KEM secret key ( of 2400 -bytes ) and a cipher text of 1088 diff --git a/include/pke.hpp b/include/pke.hpp index 20e1aad..84983b3 100644 --- a/include/pke.hpp +++ b/include/pke.hpp @@ -80,16 +80,19 @@ keygen(std::span d, std::span pubk kyber_utils::poly_vec_encode(s, seckey); } -// Given (k * 12 * 32 + 32) -bytes public key, 32 -bytes message ( to be +// Given (k * 12 * 32 + 32) -bytes *valid* public key, 32 -bytes message ( to be // encrypted ) and 32 -bytes random coin ( from where all randomness is // deterministically sampled ), this routine encrypts message using // INDCPA-secure Kyber encryption algorithm, computing compressed cipher text of // (k * du * 32 + dv * 32) -bytes. // +// If modulus check, as described in point (2) of section 6.2 of ML-KEM draft standard, +// fails, it returns false, otherwise it returns true. +// // See algorithm 5 defined in Kyber specification // https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf template -static inline void +[[nodiscard("Use result of modulus check on public key")]] static inline bool encrypt(std::span pubkey, std::span msg, std::span rcoin, @@ -102,7 +105,17 @@ encrypt(std::span pubkey, auto rho = pubkey.template subspan(); std::array t_prime{}; + std::array encoded_tprime{}; + kyber_utils::poly_vec_decode(_pubkey0, t_prime); + kyber_utils::poly_vec_encode(t_prime, encoded_tprime); + + using encoded_pkey_t = std::span; + const auto are_equal = kyber_utils::ct_memcmp(encoded_pkey_t(_pubkey0), encoded_pkey_t(encoded_tprime)); + if (are_equal == 0u) { + // Got an invalid public key + return false; + } // step 4, 5, 6, 7, 8 std::array A_prime{}; @@ -158,6 +171,8 @@ encrypt(std::span pubkey, // step 22 kyber_utils::poly_compress(v); kyber_utils::encode(v, _enc1); + + return true; } // Given (k * 12 * 32) -bytes secret key and (k * du * 32 + dv * 32) -bytes