Refactor PKE and KEM implementation

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2024-06-18 11:41:43 +04:00
parent bb1a5ace51
commit 0d0a151a64
7 changed files with 62 additions and 142 deletions

View File

@@ -1,5 +1,5 @@
#include "bench_helper.hpp"
#include "kyber/internals/kem.hpp"
#include "kyber/internals/ml_kem.hpp"
#include "x86_64_cpu_ticks.hpp"
#include <benchmark/benchmark.h>
@@ -35,7 +35,7 @@ bench_keygen(benchmark::State& state)
const uint64_t start = cpu_ticks();
#endif
kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
ml_kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
benchmark::DoNotOptimize(_d);
benchmark::DoNotOptimize(_z);
@@ -88,7 +88,7 @@ bench_encapsulate(benchmark::State& state)
prng.read(_d);
prng.read(_z);
kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
ml_kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
prng.read(_m);
@@ -101,7 +101,7 @@ bench_encapsulate(benchmark::State& state)
const uint64_t start = cpu_ticks();
#endif
(void)kem::encapsulate<k, eta1, eta2, du, dv>(_m, _pkey, _cipher, _sender_key);
(void)ml_kem::encapsulate<k, eta1, eta2, du, dv>(_m, _pkey, _cipher, _sender_key);
benchmark::DoNotOptimize(_m);
benchmark::DoNotOptimize(_pkey);
@@ -156,11 +156,11 @@ bench_decapsulate(benchmark::State& state)
prng.read(_d);
prng.read(_z);
kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
ml_kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
prng.read(_m);
(void)kem::encapsulate<k, eta1, eta2, du, dv>(_m, _pkey, _cipher, _sender_key);
(void)ml_kem::encapsulate<k, eta1, eta2, du, dv>(_m, _pkey, _cipher, _sender_key);
#ifdef __x86_64__
uint64_t total_ticks = 0ul;
@@ -171,7 +171,7 @@ bench_decapsulate(benchmark::State& state)
const uint64_t start = cpu_ticks();
#endif
kem::decapsulate<k, eta1, eta2, du, dv>(_skey, _cipher, _receiver_key);
ml_kem::decapsulate<k, eta1, eta2, du, dv>(_skey, _cipher, _receiver_key);
benchmark::DoNotOptimize(_skey);
benchmark::DoNotOptimize(_cipher);

View File

@@ -5,33 +5,17 @@
#include "kyber/internals/utility/params.hpp"
#include "kyber/internals/utility/utils.hpp"
#include "sha3_512.hpp"
#include <array>
#include <span>
// IND-CPA-secure Public Key Encryption Scheme
namespace pke {
// Public Key Encryption Scheme
namespace k_pke {
// Kyber CPAPKE key generation algorithm, which takes two parameters `k` & `η1`
// ( read eta1 ) and generates byte serialized public key and secret key of
// following length
//
// public key: (k * 12 * 32 + 32) -bytes wide
// secret key: (k * 12 * 32) -bytes wide
//
// See algorithm 4 defined in Kyber specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
//
// Note, this routine allows you to pass 32 -bytes seed ( see first parameter ),
// which is designed this way for ease of writing test cases against known
// answer tests, obtained from Kyber reference implementation
// https://github.com/pq-crystals/kyber.git. It also helps in properly
// benchmarking underlying PKE's key generation implementation.
// K-PKE key generation algorithm, generating byte serialized public key and secret keym given a 32 -bytes input seed `d`.
// See algorithm 12 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1>
static inline void
keygen(std::span<const uint8_t, 32> d, std::span<uint8_t, k * 12 * 32 + 32> pubkey, std::span<uint8_t, k * 12 * 32> seckey)
requires(kyber_params::check_keygen_params(k, eta1))
{
// step 2
std::array<uint8_t, 64> g_out{};
auto _g_out = std::span(g_out);
@@ -43,34 +27,27 @@ keygen(std::span<const uint8_t, 32> d, std::span<uint8_t, k * 12 * 32 + 32> pubk
const auto rho = _g_out.template subspan<0, 32>();
const auto sigma = _g_out.template subspan<rho.size(), 32>();
// step 4, 5, 6, 7, 8
std::array<field::zq_t, k * k * ntt::N> A_prime{};
kyber_utils::generate_matrix<k, false>(A_prime, rho);
// step 3
uint8_t N = 0;
// step 9, 10, 11, 12
std::array<field::zq_t, k * ntt::N> s{};
kyber_utils::generate_vector<k, eta1>(s, sigma, N);
N += k;
// step 13, 14, 15, 16
std::array<field::zq_t, k * ntt::N> e{};
kyber_utils::generate_vector<k, eta1>(e, sigma, N);
N += k;
// step 17, 18
kyber_utils::poly_vec_ntt<k>(s);
kyber_utils::poly_vec_ntt<k>(e);
// step 19
std::array<field::zq_t, k * ntt::N> t_prime{};
kyber_utils::matrix_multiply<k, k, k, 1>(A_prime, s, t_prime);
kyber_utils::poly_vec_add_to<k>(e, t_prime);
// step 20, 21, 22
constexpr size_t pkoff = k * 12 * 32;
auto _pubkey0 = pubkey.template subspan<0, pkoff>();
auto _pubkey1 = pubkey.template subspan<pkoff, 32>();
@@ -80,17 +57,13 @@ keygen(std::span<const uint8_t, 32> d, std::span<uint8_t, k * 12 * 32 + 32> pubk
kyber_utils::poly_vec_encode<k, 12>(s, seckey);
}
// Given (k * 12 * 32 + 32) -bytes *valid* public key, 32 -bytes message ( to be
// encrypted ) and 32 -bytes random coin ( from where all randomness is
// deterministically sampled ), this routine encrypts message using
// INDCPA-secure Kyber encryption algorithm, computing compressed cipher text of
// (k * du * 32 + dv * 32) -bytes.
// Given a *valid* K-PKE public key, 32 -bytes message ( to be encrypted ) and 32 -bytes random coin
// ( from where all randomness is deterministically sampled ), this routine encrypts message using
// K-PKE encryption algorithm, computing compressed cipher text.
//
// If modulus check, as described in point (2) of section 6.2 of ML-KEM draft standard,
// fails, it returns false, otherwise it returns true.
// If modulus check, as described in point (2) of section 6.2 of ML-KEM draft standard, fails, it returns false.
//
// See algorithm 5 defined in Kyber specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
// See algorithm 13 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result of modulus check on public key")]] static inline bool
encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
@@ -99,7 +72,6 @@ encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
std::span<uint8_t, k * du * 32 + dv * 32> enc)
requires(kyber_params::check_encrypt_params(k, eta1, eta2, du, dv))
{
// step 2, 3
constexpr size_t pkoff = k * 12 * 32;
auto _pubkey0 = pubkey.template subspan<0, pkoff>();
auto rho = pubkey.template subspan<pkoff, 32>();
@@ -117,38 +89,30 @@ encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
return false;
}
// step 4, 5, 6, 7, 8
std::array<field::zq_t, k * k * ntt::N> A_prime{};
kyber_utils::generate_matrix<k, true>(A_prime, rho);
// step 1
uint8_t N = 0;
// step 9, 10, 11, 12
std::array<field::zq_t, k * ntt::N> r{};
kyber_utils::generate_vector<k, eta1>(r, rcoin, N);
N += k;
// step 13, 14, 15, 16
std::array<field::zq_t, k * ntt::N> e1{};
kyber_utils::generate_vector<k, eta2>(e1, rcoin, N);
N += k;
// step 17
std::array<field::zq_t, ntt::N> e2{};
kyber_utils::generate_vector<1, eta2>(e2, rcoin, N);
// step 18
kyber_utils::poly_vec_ntt<k>(r);
// step 19
std::array<field::zq_t, k * ntt::N> u{};
kyber_utils::matrix_multiply<k, k, k, 1>(A_prime, r, u);
kyber_utils::poly_vec_intt<k>(u);
kyber_utils::poly_vec_add_to<k>(e1, u);
// step 20
std::array<field::zq_t, ntt::N> v{};
kyber_utils::matrix_multiply<1, k, k, 1>(t_prime, r, v);
@@ -164,24 +128,19 @@ encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
auto _enc0 = enc.template subspan<0, encoff>();
auto _enc1 = enc.template subspan<encoff, dv * 32>();
// step 21
kyber_utils::poly_vec_compress<k, du>(u);
kyber_utils::poly_vec_encode<k, du>(u, _enc0);
// step 22
kyber_utils::poly_compress<dv>(v);
kyber_utils::encode<dv>(v, _enc1);
return true;
}
// Given (k * 12 * 32) -bytes secret key and (k * du * 32 + dv * 32) -bytes
// encrypted ( cipher ) text, this routine recovers 32 -bytes plain text which
// was encrypted using respective public key, which is associated with this
// secret key.
// Given K-PKE secret key and cipher text, this routine recovers 32 -bytes plain text which
// was encrypted using K-PKE public key i.e. associated with this secret key.
//
// See algorithm 6 defined in Kyber specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
// See algorithm 14 defined in K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t du, size_t dv>
static inline void
decrypt(std::span<const uint8_t, k * 12 * 32> seckey, std::span<const uint8_t, k * du * 32 + dv * 32> enc, std::span<uint8_t, 32> dec)
@@ -191,23 +150,19 @@ decrypt(std::span<const uint8_t, k * 12 * 32> seckey, std::span<const uint8_t, k
auto _enc0 = enc.template subspan<0, encoff>();
auto _enc1 = enc.template subspan<encoff, dv * 32>();
// step 1
std::array<field::zq_t, k * ntt::N> u{};
kyber_utils::poly_vec_decode<k, du>(_enc0, u);
kyber_utils::poly_vec_decompress<k, du>(u);
// step 2
std::array<field::zq_t, ntt::N> v{};
kyber_utils::decode<dv>(_enc1, v);
kyber_utils::poly_decompress<dv>(v);
// step 3
std::array<field::zq_t, k * ntt::N> s_prime{};
kyber_utils::poly_vec_decode<k, 12>(seckey, s_prime);
// step 4
kyber_utils::poly_vec_ntt<k>(u);
std::array<field::zq_t, ntt::N> t{};

View File

@@ -1,31 +1,16 @@
#pragma once
#include "k_pke.hpp"
#include "kyber/internals/utility/utils.hpp"
#include "pke.hpp"
#include "sha3_256.hpp"
#include "sha3_512.hpp"
#include "shake256.hpp"
#include <algorithm>
#include <array>
#include <cstdint>
// IND-CCA2-secure Key Encapsulation Mechanism
namespace kem {
// Key Encapsulation Mechanism
namespace ml_kem {
// Kyber CCAKEM key generation algorithm, which takes two parameters `k` & `η1`
// ( read eta1 ) and generates byte serialized public key and secret key of
// following length
//
// public key: (k * 12 * 32 + 32) -bytes wide
// secret key: (k * 24 * 32 + 96) -bytes wide [ includes public key ]
//
// See algorithm 7 defined in Kyber specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
//
// Note, this routine allows you to pass two 32 -bytes seeds ( see first &
// second parameter ), which is designed this way for ease of writing test cases
// against known answer tests, obtained from Kyber reference implementation
// https://github.com/pq-crystals/kyber.git. It also helps in properly
// benchmarking underlying KEM's key generation implementation.
// ML-KEM key generation algorithm, generating byte serialized public key and secret key, given 32 -bytes seed `d` and `z`.
// See algorithm 15 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t k, size_t eta1>
static inline void
keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
@@ -34,20 +19,19 @@ keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
std::span<uint8_t, kyber_utils::get_kem_secret_key_len(k)> seckey)
requires(kyber_params::check_keygen_params(k, eta1))
{
constexpr size_t skoff0 = k * 12 * 32;
constexpr size_t skoff1 = skoff0 + pubkey.size();
constexpr size_t skoff2 = skoff1 + 32;
static constexpr size_t skoff0 = k * 12 * 32;
static constexpr size_t skoff1 = skoff0 + pubkey.size();
static constexpr size_t skoff2 = skoff1 + 32;
auto _seckey0 = seckey.template subspan<0, skoff0>();
auto _seckey1 = seckey.template subspan<skoff0, skoff1 - skoff0>();
auto _seckey2 = seckey.template subspan<skoff1, skoff2 - skoff1>();
auto _seckey3 = seckey.template subspan<skoff2, seckey.size() - skoff2>();
pke::keygen<k, eta1>(d, pubkey, _seckey0); // CPAPKE key generation
k_pke::keygen<k, eta1>(d, pubkey, _seckey0);
std::copy(pubkey.begin(), pubkey.end(), _seckey1.begin());
std::copy(z.begin(), z.end(), _seckey3.begin());
// hash public key
sha3_256::sha3_256_t hasher{};
hasher.absorb(pubkey);
hasher.finalize();
@@ -55,29 +39,16 @@ keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
hasher.reset();
}
// Given (k * 12 * 32 + 32) -bytes public key and 32 -bytes seed ( used for
// deriving 32 -bytes message & 32 -bytes random coin ), this routine computes
// cipher text of length (k * du * 32 + dv * 32) -bytes which can be shared with
// recipient party ( having respective secret key ) over insecure channel.
// Given ML-KEM public key and 32 -bytes seed ( used for deriving 32 -bytes message & 32 -bytes random coin ), this routine computes
// ML-KEM cipher text which can be shared with recipient party ( owning corresponding secret key ) over insecure channel.
//
// It also computes a fixed length 32 -bytes shared secret, which can be used for
// symmetric key encryption between these two participating entities. Alternatively
// they might choose to derive longer keys from this shared secret.
// It also computes a fixed length 32 -bytes shared secret, which can be used for fast symmetric key encryption between these
// two participating entities. Alternatively they might choose to derive longer keys from this shared secret. Other side of
// communication should also be able to generate same 32 -byte shared secret, after successful decryption of cipher text.
//
// Other side of communication should also be able to generate same 32 -byte shared secret,
// after successful decryption of cipher text.
// If invalid ML-KEM public key is input, this function execution will fail, returning false.
//
// If invalid public key is input, this function execution will fail, returning false,
// otherwise it will return true, while producing both cipher text and shared secret.
//
// See algorithm 8 defined in Kyber specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
//
// Note, this routine allows you to pass 32 -bytes seed ( see first parameter ),
// which is designed this way for ease of writing test cases against known
// answer tests, obtained from Kyber reference implementation
// https://github.com/pq-crystals/kyber.git. It also helps in properly
// benchmarking underlying KEM's encapsulation implementation.
// See algorithm 16 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result, it might fail because of malformed input public key")]] static inline bool
encapsulate(std::span<const uint8_t, 32> m,
@@ -109,8 +80,9 @@ encapsulate(std::span<const uint8_t, 32> m,
h512.finalize();
h512.digest(_g_out);
const auto has_mod_check_passed = pke::encrypt<k, eta1, eta2, du, dv>(pubkey, m, _g_out1, cipher);
const auto has_mod_check_passed = k_pke::encrypt<k, eta1, eta2, du, dv>(pubkey, m, _g_out1, cipher);
if (!has_mod_check_passed) {
// Got an invalid public key
return has_mod_check_passed;
}
@@ -118,20 +90,13 @@ encapsulate(std::span<const uint8_t, 32> m,
return true;
}
// Given (k * 24 * 32 + 96) -bytes secret key and (k * du * 32 + dv * 32) -bytes
// encrypted ( cipher ) text, this routine recovers 32 -bytes plain text which
// was encrypted by sender, using respective public key, associated with this
// secret key.
// Recovered 32 -bytes plain text is used for deriving same key stream ( using
// SHAKE256 key derivation function ), which is the shared secret key between
// two communicating parties, over insecure channel. Using returned KDF (
// SHAKE256 object ) both parties can reach to same shared secret key ( of
// arbitrary length ), which will be used for encrypting traffic using symmetric
// key primitives.
// Given ML-KEM secret key and cipher text, this routine recovers 32 -bytes plain text which was encrypted by sender,
// using ML-KEM public key, associated with this secret key.
//
// See algorithm 9 defined in Kyber specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
// Recovered 32 -bytes plain text is used for deriving a 32 -bytes shared secret key, which can now be
// used for encrypting communication between two participating parties, using fast symmetric key algorithms.
//
// See algorithm 17 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
static inline void
decapsulate(std::span<const uint8_t, kyber_utils::get_kem_secret_key_len(k)> seckey,
@@ -165,7 +130,7 @@ decapsulate(std::span<const uint8_t, kyber_utils::get_kem_secret_key_len(k)> sec
auto _g_out0 = _g_out.template first<shared_secret.size()>();
auto _g_out1 = _g_out.template last<32>();
pke::decrypt<k, du, dv>(pke_sk, cipher, _g_in0);
k_pke::decrypt<k, du, dv>(pke_sk, cipher, _g_in0);
std::copy(h.begin(), h.end(), _g_in1.begin());
sha3_512::sha3_512_t h512{};
@@ -180,9 +145,9 @@ decapsulate(std::span<const uint8_t, kyber_utils::get_kem_secret_key_len(k)> sec
xof256.squeeze(j_out);
// Explicitly ignore return value, because public key, held as part of secret key is *assumed* to be valid.
(void)pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, c_prime);
(void)k_pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, c_prime);
// line 7-11 of algorithm 9, in constant-time
// line 9-12 of algorithm 17, in constant-time
using kdf_t = std::span<const uint8_t, shared_secret.size()>;
const uint32_t cond = kyber_utils::ct_memcmp(cipher, std::span<const uint8_t, ctlen>(c_prime));
kyber_utils::ct_cond_memcpy(cond, shared_secret, kdf_t(_g_out0), kdf_t(z));

View File

@@ -1,5 +1,5 @@
#pragma once
#include "kyber/internals/kem.hpp"
#include "kyber/internals/ml_kem.hpp"
namespace kyber1024_kem {
@@ -40,7 +40,7 @@ keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
std::span<uint8_t, PKEY_BYTE_LEN> pubkey,
std::span<uint8_t, SKEY_BYTE_LEN> seckey)
{
kem::keygen<k, η1>(d, z, pubkey, seckey);
ml_kem::keygen<k, η1>(d, z, pubkey, seckey);
}
// Given seed `m` and a ML-KEM-1024 public key, this routine computes a ML-KEM-1024 cipher text and a fixed size shared secret.
@@ -51,14 +51,14 @@ encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
std::span<uint8_t, CIPHER_TEXT_BYTE_LEN> cipher,
std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
return kem::encapsulate<k, η1, η2, du, dv>(m, pubkey, cipher, shared_secret);
return ml_kem::encapsulate<k, η1, η2, du, dv>(m, pubkey, cipher, shared_secret);
}
// Given a ML-KEM-1024 secret key and a cipher text, this routine computes a fixed size shared secret.
inline void
decapsulate(std::span<const uint8_t, SKEY_BYTE_LEN> seckey, std::span<const uint8_t, CIPHER_TEXT_BYTE_LEN> cipher, std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);
ml_kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);
}
}

View File

@@ -1,5 +1,5 @@
#pragma once
#include "kyber/internals/kem.hpp"
#include "kyber/internals/ml_kem.hpp"
namespace kyber512_kem {
@@ -40,7 +40,7 @@ keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
std::span<uint8_t, PKEY_BYTE_LEN> pubkey,
std::span<uint8_t, SKEY_BYTE_LEN> seckey)
{
kem::keygen<k, η1>(d, z, pubkey, seckey);
ml_kem::keygen<k, η1>(d, z, pubkey, seckey);
}
// Given seed `m` and a ML-KEM-512 public key, this routine computes a ML-KEM-512 cipher text and a fixed size shared secret.
@@ -51,14 +51,14 @@ encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
std::span<uint8_t, CIPHER_TEXT_BYTE_LEN> cipher,
std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
return kem::encapsulate<k, η1, η2, du, dv>(m, pubkey, cipher, shared_secret);
return ml_kem::encapsulate<k, η1, η2, du, dv>(m, pubkey, cipher, shared_secret);
}
// Given a ML-KEM-512 secret key and a cipher text, this routine computes a fixed size shared secret.
inline void
decapsulate(std::span<const uint8_t, SKEY_BYTE_LEN> seckey, std::span<const uint8_t, CIPHER_TEXT_BYTE_LEN> cipher, std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);
ml_kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);
}
}

View File

@@ -1,5 +1,5 @@
#pragma once
#include "kyber/internals/kem.hpp"
#include "kyber/internals/ml_kem.hpp"
namespace kyber768_kem {
@@ -40,7 +40,7 @@ keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
std::span<uint8_t, PKEY_BYTE_LEN> pubkey,
std::span<uint8_t, SKEY_BYTE_LEN> seckey)
{
kem::keygen<k, η1>(d, z, pubkey, seckey);
ml_kem::keygen<k, η1>(d, z, pubkey, seckey);
}
// Given seed `m` and a ML-KEM-768 public key, this routine computes a ML-KEM-768 cipher text and a fixed size shared secret.
@@ -51,14 +51,14 @@ encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
std::span<uint8_t, CIPHER_TEXT_BYTE_LEN> cipher,
std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
return kem::encapsulate<k, η1, η2, du, dv>(m, pubkey, cipher, shared_secret);
return ml_kem::encapsulate<k, η1, η2, du, dv>(m, pubkey, cipher, shared_secret);
}
// Given a ML-KEM-768 secret key and a cipher text, this routine computes a fixed size shared secret.
inline void
decapsulate(std::span<const uint8_t, SKEY_BYTE_LEN> seckey, std::span<const uint8_t, CIPHER_TEXT_BYTE_LEN> cipher, std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);
ml_kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);
}
}

View File

@@ -1,4 +1,4 @@
#include "kyber/internals/kem.hpp"
#include "kyber/internals/ml_kem.hpp"
#include "kyber/internals/utility/utils.hpp"
#include <gtest/gtest.h>
@@ -48,9 +48,9 @@ test_kyber_kem()
prng.read(z);
prng.read(m);
kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
(void)kem::encapsulate<k, eta1, eta2, du, dv>(_m, _pkey, _cipher, _sender_key);
kem::decapsulate<k, eta1, eta2, du, dv>(_skey, _cipher, _receiver_key);
ml_kem::keygen<k, eta1>(_d, _z, _pkey, _skey);
(void)ml_kem::encapsulate<k, eta1, eta2, du, dv>(_m, _pkey, _cipher, _sender_key);
ml_kem::decapsulate<k, eta1, eta2, du, dv>(_skey, _cipher, _receiver_key);
EXPECT_EQ(sender_key, receiver_key);
}