From 0d0a151a64c80b91235dc331a69345c0cd20928f Mon Sep 17 00:00:00 2001 From: Anjan Roy Date: Tue, 18 Jun 2024 11:41:43 +0400 Subject: [PATCH] Refactor PKE and KEM implementation Signed-off-by: Anjan Roy --- benchmarks/bench_kem.cpp | 14 +-- .../kyber/internals/{pke.hpp => k_pke.hpp} | 69 +++----------- .../kyber/internals/{kem.hpp => ml_kem.hpp} | 89 ++++++------------- include/kyber/kyber1024_kem.hpp | 8 +- include/kyber/kyber512_kem.hpp | 8 +- include/kyber/kyber768_kem.hpp | 8 +- tests/test_kem.cpp | 8 +- 7 files changed, 62 insertions(+), 142 deletions(-) rename include/kyber/internals/{pke.hpp => k_pke.hpp} (71%) rename include/kyber/internals/{kem.hpp => ml_kem.hpp} (56%) diff --git a/benchmarks/bench_kem.cpp b/benchmarks/bench_kem.cpp index e0faf74..bcd205e 100644 --- a/benchmarks/bench_kem.cpp +++ b/benchmarks/bench_kem.cpp @@ -1,5 +1,5 @@ #include "bench_helper.hpp" -#include "kyber/internals/kem.hpp" +#include "kyber/internals/ml_kem.hpp" #include "x86_64_cpu_ticks.hpp" #include @@ -35,7 +35,7 @@ bench_keygen(benchmark::State& state) const uint64_t start = cpu_ticks(); #endif - kem::keygen(_d, _z, _pkey, _skey); + ml_kem::keygen(_d, _z, _pkey, _skey); benchmark::DoNotOptimize(_d); benchmark::DoNotOptimize(_z); @@ -88,7 +88,7 @@ bench_encapsulate(benchmark::State& state) prng.read(_d); prng.read(_z); - kem::keygen(_d, _z, _pkey, _skey); + ml_kem::keygen(_d, _z, _pkey, _skey); prng.read(_m); @@ -101,7 +101,7 @@ bench_encapsulate(benchmark::State& state) const uint64_t start = cpu_ticks(); #endif - (void)kem::encapsulate(_m, _pkey, _cipher, _sender_key); + (void)ml_kem::encapsulate(_m, _pkey, _cipher, _sender_key); benchmark::DoNotOptimize(_m); benchmark::DoNotOptimize(_pkey); @@ -156,11 +156,11 @@ bench_decapsulate(benchmark::State& state) prng.read(_d); prng.read(_z); - kem::keygen(_d, _z, _pkey, _skey); + ml_kem::keygen(_d, _z, _pkey, _skey); prng.read(_m); - (void)kem::encapsulate(_m, _pkey, _cipher, _sender_key); + (void)ml_kem::encapsulate(_m, _pkey, _cipher, _sender_key); #ifdef __x86_64__ uint64_t total_ticks = 0ul; @@ -171,7 +171,7 @@ bench_decapsulate(benchmark::State& state) const uint64_t start = cpu_ticks(); #endif - kem::decapsulate(_skey, _cipher, _receiver_key); + ml_kem::decapsulate(_skey, _cipher, _receiver_key); benchmark::DoNotOptimize(_skey); benchmark::DoNotOptimize(_cipher); diff --git a/include/kyber/internals/pke.hpp b/include/kyber/internals/k_pke.hpp similarity index 71% rename from include/kyber/internals/pke.hpp rename to include/kyber/internals/k_pke.hpp index fa0005d..e71d8ee 100644 --- a/include/kyber/internals/pke.hpp +++ b/include/kyber/internals/k_pke.hpp @@ -5,33 +5,17 @@ #include "kyber/internals/utility/params.hpp" #include "kyber/internals/utility/utils.hpp" #include "sha3_512.hpp" -#include -#include -// IND-CPA-secure Public Key Encryption Scheme -namespace pke { +// Public Key Encryption Scheme +namespace k_pke { -// Kyber CPAPKE key generation algorithm, which takes two parameters `k` & `η1` -// ( read eta1 ) and generates byte serialized public key and secret key of -// following length -// -// public key: (k * 12 * 32 + 32) -bytes wide -// secret key: (k * 12 * 32) -bytes wide -// -// See algorithm 4 defined in Kyber specification -// https://doi.org/10.6028/NIST.FIPS.203.ipd -// -// Note, this routine allows you to pass 32 -bytes seed ( see first parameter ), -// which is designed this way for ease of writing test cases against known -// answer tests, obtained from Kyber reference implementation -// https://github.com/pq-crystals/kyber.git. It also helps in properly -// benchmarking underlying PKE's key generation implementation. +// K-PKE key generation algorithm, generating byte serialized public key and secret keym given a 32 -bytes input seed `d`. +// See algorithm 12 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd. template static inline void keygen(std::span d, std::span pubkey, std::span seckey) requires(kyber_params::check_keygen_params(k, eta1)) { - // step 2 std::array g_out{}; auto _g_out = std::span(g_out); @@ -43,34 +27,27 @@ keygen(std::span d, std::span pubk const auto rho = _g_out.template subspan<0, 32>(); const auto sigma = _g_out.template subspan(); - // step 4, 5, 6, 7, 8 std::array A_prime{}; kyber_utils::generate_matrix(A_prime, rho); - // step 3 uint8_t N = 0; - // step 9, 10, 11, 12 std::array s{}; kyber_utils::generate_vector(s, sigma, N); N += k; - // step 13, 14, 15, 16 std::array e{}; kyber_utils::generate_vector(e, sigma, N); N += k; - // step 17, 18 kyber_utils::poly_vec_ntt(s); kyber_utils::poly_vec_ntt(e); - // step 19 std::array t_prime{}; kyber_utils::matrix_multiply(A_prime, s, t_prime); kyber_utils::poly_vec_add_to(e, t_prime); - // step 20, 21, 22 constexpr size_t pkoff = k * 12 * 32; auto _pubkey0 = pubkey.template subspan<0, pkoff>(); auto _pubkey1 = pubkey.template subspan(); @@ -80,17 +57,13 @@ keygen(std::span d, std::span pubk kyber_utils::poly_vec_encode(s, seckey); } -// Given (k * 12 * 32 + 32) -bytes *valid* public key, 32 -bytes message ( to be -// encrypted ) and 32 -bytes random coin ( from where all randomness is -// deterministically sampled ), this routine encrypts message using -// INDCPA-secure Kyber encryption algorithm, computing compressed cipher text of -// (k * du * 32 + dv * 32) -bytes. +// Given a *valid* K-PKE public key, 32 -bytes message ( to be encrypted ) and 32 -bytes random coin +// ( from where all randomness is deterministically sampled ), this routine encrypts message using +// K-PKE encryption algorithm, computing compressed cipher text. // -// If modulus check, as described in point (2) of section 6.2 of ML-KEM draft standard, -// fails, it returns false, otherwise it returns true. +// If modulus check, as described in point (2) of section 6.2 of ML-KEM draft standard, fails, it returns false. // -// See algorithm 5 defined in Kyber specification -// https://doi.org/10.6028/NIST.FIPS.203.ipd +// See algorithm 13 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd. template [[nodiscard("Use result of modulus check on public key")]] static inline bool encrypt(std::span pubkey, @@ -99,7 +72,6 @@ encrypt(std::span pubkey, std::span enc) requires(kyber_params::check_encrypt_params(k, eta1, eta2, du, dv)) { - // step 2, 3 constexpr size_t pkoff = k * 12 * 32; auto _pubkey0 = pubkey.template subspan<0, pkoff>(); auto rho = pubkey.template subspan(); @@ -117,38 +89,30 @@ encrypt(std::span pubkey, return false; } - // step 4, 5, 6, 7, 8 std::array A_prime{}; kyber_utils::generate_matrix(A_prime, rho); - // step 1 uint8_t N = 0; - // step 9, 10, 11, 12 std::array r{}; kyber_utils::generate_vector(r, rcoin, N); N += k; - // step 13, 14, 15, 16 std::array e1{}; kyber_utils::generate_vector(e1, rcoin, N); N += k; - // step 17 std::array e2{}; kyber_utils::generate_vector<1, eta2>(e2, rcoin, N); - // step 18 kyber_utils::poly_vec_ntt(r); - // step 19 std::array u{}; kyber_utils::matrix_multiply(A_prime, r, u); kyber_utils::poly_vec_intt(u); kyber_utils::poly_vec_add_to(e1, u); - // step 20 std::array v{}; kyber_utils::matrix_multiply<1, k, k, 1>(t_prime, r, v); @@ -164,24 +128,19 @@ encrypt(std::span pubkey, auto _enc0 = enc.template subspan<0, encoff>(); auto _enc1 = enc.template subspan(); - // step 21 kyber_utils::poly_vec_compress(u); kyber_utils::poly_vec_encode(u, _enc0); - // step 22 kyber_utils::poly_compress(v); kyber_utils::encode(v, _enc1); return true; } -// Given (k * 12 * 32) -bytes secret key and (k * du * 32 + dv * 32) -bytes -// encrypted ( cipher ) text, this routine recovers 32 -bytes plain text which -// was encrypted using respective public key, which is associated with this -// secret key. +// Given K-PKE secret key and cipher text, this routine recovers 32 -bytes plain text which +// was encrypted using K-PKE public key i.e. associated with this secret key. // -// See algorithm 6 defined in Kyber specification -// https://doi.org/10.6028/NIST.FIPS.203.ipd +// See algorithm 14 defined in K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd. template static inline void decrypt(std::span seckey, std::span enc, std::span dec) @@ -191,23 +150,19 @@ decrypt(std::span seckey, std::span(); auto _enc1 = enc.template subspan(); - // step 1 std::array u{}; kyber_utils::poly_vec_decode(_enc0, u); kyber_utils::poly_vec_decompress(u); - // step 2 std::array v{}; kyber_utils::decode(_enc1, v); kyber_utils::poly_decompress(v); - // step 3 std::array s_prime{}; kyber_utils::poly_vec_decode(seckey, s_prime); - // step 4 kyber_utils::poly_vec_ntt(u); std::array t{}; diff --git a/include/kyber/internals/kem.hpp b/include/kyber/internals/ml_kem.hpp similarity index 56% rename from include/kyber/internals/kem.hpp rename to include/kyber/internals/ml_kem.hpp index 81d390a..ff9bae1 100644 --- a/include/kyber/internals/kem.hpp +++ b/include/kyber/internals/ml_kem.hpp @@ -1,31 +1,16 @@ #pragma once +#include "k_pke.hpp" #include "kyber/internals/utility/utils.hpp" -#include "pke.hpp" #include "sha3_256.hpp" #include "sha3_512.hpp" #include "shake256.hpp" #include -#include -#include -// IND-CCA2-secure Key Encapsulation Mechanism -namespace kem { +// Key Encapsulation Mechanism +namespace ml_kem { -// Kyber CCAKEM key generation algorithm, which takes two parameters `k` & `η1` -// ( read eta1 ) and generates byte serialized public key and secret key of -// following length -// -// public key: (k * 12 * 32 + 32) -bytes wide -// secret key: (k * 24 * 32 + 96) -bytes wide [ includes public key ] -// -// See algorithm 7 defined in Kyber specification -// https://doi.org/10.6028/NIST.FIPS.203.ipd -// -// Note, this routine allows you to pass two 32 -bytes seeds ( see first & -// second parameter ), which is designed this way for ease of writing test cases -// against known answer tests, obtained from Kyber reference implementation -// https://github.com/pq-crystals/kyber.git. It also helps in properly -// benchmarking underlying KEM's key generation implementation. +// ML-KEM key generation algorithm, generating byte serialized public key and secret key, given 32 -bytes seed `d` and `z`. +// See algorithm 15 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd template static inline void keygen(std::span d, // used in CPA-PKE @@ -34,20 +19,19 @@ keygen(std::span d, // used in CPA-PKE std::span seckey) requires(kyber_params::check_keygen_params(k, eta1)) { - constexpr size_t skoff0 = k * 12 * 32; - constexpr size_t skoff1 = skoff0 + pubkey.size(); - constexpr size_t skoff2 = skoff1 + 32; + static constexpr size_t skoff0 = k * 12 * 32; + static constexpr size_t skoff1 = skoff0 + pubkey.size(); + static constexpr size_t skoff2 = skoff1 + 32; auto _seckey0 = seckey.template subspan<0, skoff0>(); auto _seckey1 = seckey.template subspan(); auto _seckey2 = seckey.template subspan(); auto _seckey3 = seckey.template subspan(); - pke::keygen(d, pubkey, _seckey0); // CPAPKE key generation + k_pke::keygen(d, pubkey, _seckey0); std::copy(pubkey.begin(), pubkey.end(), _seckey1.begin()); std::copy(z.begin(), z.end(), _seckey3.begin()); - // hash public key sha3_256::sha3_256_t hasher{}; hasher.absorb(pubkey); hasher.finalize(); @@ -55,29 +39,16 @@ keygen(std::span d, // used in CPA-PKE hasher.reset(); } -// Given (k * 12 * 32 + 32) -bytes public key and 32 -bytes seed ( used for -// deriving 32 -bytes message & 32 -bytes random coin ), this routine computes -// cipher text of length (k * du * 32 + dv * 32) -bytes which can be shared with -// recipient party ( having respective secret key ) over insecure channel. +// Given ML-KEM public key and 32 -bytes seed ( used for deriving 32 -bytes message & 32 -bytes random coin ), this routine computes +// ML-KEM cipher text which can be shared with recipient party ( owning corresponding secret key ) over insecure channel. // -// It also computes a fixed length 32 -bytes shared secret, which can be used for -// symmetric key encryption between these two participating entities. Alternatively -// they might choose to derive longer keys from this shared secret. +// It also computes a fixed length 32 -bytes shared secret, which can be used for fast symmetric key encryption between these +// two participating entities. Alternatively they might choose to derive longer keys from this shared secret. Other side of +// communication should also be able to generate same 32 -byte shared secret, after successful decryption of cipher text. // -// Other side of communication should also be able to generate same 32 -byte shared secret, -// after successful decryption of cipher text. +// If invalid ML-KEM public key is input, this function execution will fail, returning false. // -// If invalid public key is input, this function execution will fail, returning false, -// otherwise it will return true, while producing both cipher text and shared secret. -// -// See algorithm 8 defined in Kyber specification -// https://doi.org/10.6028/NIST.FIPS.203.ipd -// -// Note, this routine allows you to pass 32 -bytes seed ( see first parameter ), -// which is designed this way for ease of writing test cases against known -// answer tests, obtained from Kyber reference implementation -// https://github.com/pq-crystals/kyber.git. It also helps in properly -// benchmarking underlying KEM's encapsulation implementation. +// See algorithm 16 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd template [[nodiscard("Use result, it might fail because of malformed input public key")]] static inline bool encapsulate(std::span m, @@ -109,8 +80,9 @@ encapsulate(std::span m, h512.finalize(); h512.digest(_g_out); - const auto has_mod_check_passed = pke::encrypt(pubkey, m, _g_out1, cipher); + const auto has_mod_check_passed = k_pke::encrypt(pubkey, m, _g_out1, cipher); if (!has_mod_check_passed) { + // Got an invalid public key return has_mod_check_passed; } @@ -118,20 +90,13 @@ encapsulate(std::span m, return true; } -// Given (k * 24 * 32 + 96) -bytes secret key and (k * du * 32 + dv * 32) -bytes -// encrypted ( cipher ) text, this routine recovers 32 -bytes plain text which -// was encrypted by sender, using respective public key, associated with this -// secret key. - -// Recovered 32 -bytes plain text is used for deriving same key stream ( using -// SHAKE256 key derivation function ), which is the shared secret key between -// two communicating parties, over insecure channel. Using returned KDF ( -// SHAKE256 object ) both parties can reach to same shared secret key ( of -// arbitrary length ), which will be used for encrypting traffic using symmetric -// key primitives. +// Given ML-KEM secret key and cipher text, this routine recovers 32 -bytes plain text which was encrypted by sender, +// using ML-KEM public key, associated with this secret key. // -// See algorithm 9 defined in Kyber specification -// https://doi.org/10.6028/NIST.FIPS.203.ipd +// Recovered 32 -bytes plain text is used for deriving a 32 -bytes shared secret key, which can now be +// used for encrypting communication between two participating parties, using fast symmetric key algorithms. +// +// See algorithm 17 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd. template static inline void decapsulate(std::span seckey, @@ -165,7 +130,7 @@ decapsulate(std::span sec auto _g_out0 = _g_out.template first(); auto _g_out1 = _g_out.template last<32>(); - pke::decrypt(pke_sk, cipher, _g_in0); + k_pke::decrypt(pke_sk, cipher, _g_in0); std::copy(h.begin(), h.end(), _g_in1.begin()); sha3_512::sha3_512_t h512{}; @@ -180,9 +145,9 @@ decapsulate(std::span sec xof256.squeeze(j_out); // Explicitly ignore return value, because public key, held as part of secret key is *assumed* to be valid. - (void)pke::encrypt(pubkey, _g_in0, _g_out1, c_prime); + (void)k_pke::encrypt(pubkey, _g_in0, _g_out1, c_prime); - // line 7-11 of algorithm 9, in constant-time + // line 9-12 of algorithm 17, in constant-time using kdf_t = std::span; const uint32_t cond = kyber_utils::ct_memcmp(cipher, std::span(c_prime)); kyber_utils::ct_cond_memcpy(cond, shared_secret, kdf_t(_g_out0), kdf_t(z)); diff --git a/include/kyber/kyber1024_kem.hpp b/include/kyber/kyber1024_kem.hpp index 54b1e2f..e8b13c1 100644 --- a/include/kyber/kyber1024_kem.hpp +++ b/include/kyber/kyber1024_kem.hpp @@ -1,5 +1,5 @@ #pragma once -#include "kyber/internals/kem.hpp" +#include "kyber/internals/ml_kem.hpp" namespace kyber1024_kem { @@ -40,7 +40,7 @@ keygen(std::span d, std::span pubkey, std::span seckey) { - kem::keygen(d, z, pubkey, seckey); + ml_kem::keygen(d, z, pubkey, seckey); } // Given seed `m` and a ML-KEM-1024 public key, this routine computes a ML-KEM-1024 cipher text and a fixed size shared secret. @@ -51,14 +51,14 @@ encapsulate(std::span m, std::span cipher, std::span shared_secret) { - return kem::encapsulate(m, pubkey, cipher, shared_secret); + return ml_kem::encapsulate(m, pubkey, cipher, shared_secret); } // Given a ML-KEM-1024 secret key and a cipher text, this routine computes a fixed size shared secret. inline void decapsulate(std::span seckey, std::span cipher, std::span shared_secret) { - kem::decapsulate(seckey, cipher, shared_secret); + ml_kem::decapsulate(seckey, cipher, shared_secret); } } diff --git a/include/kyber/kyber512_kem.hpp b/include/kyber/kyber512_kem.hpp index 25c11b9..0936020 100644 --- a/include/kyber/kyber512_kem.hpp +++ b/include/kyber/kyber512_kem.hpp @@ -1,5 +1,5 @@ #pragma once -#include "kyber/internals/kem.hpp" +#include "kyber/internals/ml_kem.hpp" namespace kyber512_kem { @@ -40,7 +40,7 @@ keygen(std::span d, std::span pubkey, std::span seckey) { - kem::keygen(d, z, pubkey, seckey); + ml_kem::keygen(d, z, pubkey, seckey); } // Given seed `m` and a ML-KEM-512 public key, this routine computes a ML-KEM-512 cipher text and a fixed size shared secret. @@ -51,14 +51,14 @@ encapsulate(std::span m, std::span cipher, std::span shared_secret) { - return kem::encapsulate(m, pubkey, cipher, shared_secret); + return ml_kem::encapsulate(m, pubkey, cipher, shared_secret); } // Given a ML-KEM-512 secret key and a cipher text, this routine computes a fixed size shared secret. inline void decapsulate(std::span seckey, std::span cipher, std::span shared_secret) { - kem::decapsulate(seckey, cipher, shared_secret); + ml_kem::decapsulate(seckey, cipher, shared_secret); } } diff --git a/include/kyber/kyber768_kem.hpp b/include/kyber/kyber768_kem.hpp index ce4263f..607d7d9 100644 --- a/include/kyber/kyber768_kem.hpp +++ b/include/kyber/kyber768_kem.hpp @@ -1,5 +1,5 @@ #pragma once -#include "kyber/internals/kem.hpp" +#include "kyber/internals/ml_kem.hpp" namespace kyber768_kem { @@ -40,7 +40,7 @@ keygen(std::span d, std::span pubkey, std::span seckey) { - kem::keygen(d, z, pubkey, seckey); + ml_kem::keygen(d, z, pubkey, seckey); } // Given seed `m` and a ML-KEM-768 public key, this routine computes a ML-KEM-768 cipher text and a fixed size shared secret. @@ -51,14 +51,14 @@ encapsulate(std::span m, std::span cipher, std::span shared_secret) { - return kem::encapsulate(m, pubkey, cipher, shared_secret); + return ml_kem::encapsulate(m, pubkey, cipher, shared_secret); } // Given a ML-KEM-768 secret key and a cipher text, this routine computes a fixed size shared secret. inline void decapsulate(std::span seckey, std::span cipher, std::span shared_secret) { - kem::decapsulate(seckey, cipher, shared_secret); + ml_kem::decapsulate(seckey, cipher, shared_secret); } } diff --git a/tests/test_kem.cpp b/tests/test_kem.cpp index 4a8bccf..352e405 100644 --- a/tests/test_kem.cpp +++ b/tests/test_kem.cpp @@ -1,4 +1,4 @@ -#include "kyber/internals/kem.hpp" +#include "kyber/internals/ml_kem.hpp" #include "kyber/internals/utility/utils.hpp" #include @@ -48,9 +48,9 @@ test_kyber_kem() prng.read(z); prng.read(m); - kem::keygen(_d, _z, _pkey, _skey); - (void)kem::encapsulate(_m, _pkey, _cipher, _sender_key); - kem::decapsulate(_skey, _cipher, _receiver_key); + ml_kem::keygen(_d, _z, _pkey, _skey); + (void)ml_kem::encapsulate(_m, _pkey, _cipher, _sender_key); + ml_kem::decapsulate(_skey, _cipher, _receiver_key); EXPECT_EQ(sender_key, receiver_key); }