Rename public header files (along with namespaces) for ML-KEM

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2024-06-18 11:52:54 +04:00
parent 0d0a151a64
commit cd0a3bd25b
31 changed files with 459 additions and 491 deletions

View File

@@ -1,4 +1,4 @@
name: Test Kyber Key Encapsulation Mechanism
name: Test Ml_kem Key Encapsulation Mechanism
on:
push:

View File

@@ -14,7 +14,7 @@ DEP_IFLAGS = -I $(SHA3_INC_DIR) -I $(SUBTLE_INC_DIR)
DUDECT_DEP_IFLAGS = $(DEP_IFLAGS) -I $(DUDECT_INC_DIR)
SRC_DIR = include
KYBER_SOURCES := $(shell find $(SRC_DIR) -name '*.hpp')
ML_KEM_SOURCES := $(shell find $(SRC_DIR) -name '*.hpp')
BUILD_DIR = build
DUDECT_BUILD_DIR = $(BUILD_DIR)/dudect
ASAN_BUILD_DIR = $(BUILD_DIR)/asan
@@ -123,5 +123,5 @@ perf: $(PERF_BINARY)
clean:
rm -rf $(BUILD_DIR)
format: $(KYBER_SOURCES) $(TEST_SOURCES) $(DUDECT_TEST_SOURCES) $(BENCHMARK_SOURCES) $(BENCHMARK_HEADERS)
format: $(ML_KEM_SOURCES) $(TEST_SOURCES) $(DUDECT_TEST_SOURCES) $(BENCHMARK_SOURCES) $(BENCHMARK_HEADERS)
clang-format -i $^

View File

@@ -1,16 +1,16 @@
#include "bench_helper.hpp"
#include "kyber/internals/ml_kem.hpp"
#include "ml_kem/internals/ml_kem.hpp"
#include "x86_64_cpu_ticks.hpp"
#include <benchmark/benchmark.h>
// Benchmarking IND-CCA2-secure Kyber KEM key generation algorithm
// Benchmarking IND-CCA2-secure Ml_kem KEM key generation algorithm
template<size_t k, size_t eta1, size_t bit_security_level>
void
bench_keygen(benchmark::State& state)
{
constexpr size_t slen = 32;
constexpr size_t pklen = kyber_utils::get_kem_public_key_len(k);
constexpr size_t sklen = kyber_utils::get_kem_secret_key_len(k);
constexpr size_t pklen = ml_kem_utils::get_kem_public_key_len(k);
constexpr size_t sklen = ml_kem_utils::get_kem_secret_key_len(k);
std::vector<uint8_t> d(slen);
std::vector<uint8_t> z(slen);
@@ -57,15 +57,15 @@ bench_keygen(benchmark::State& state)
#endif
}
// Benchmarking IND-CCA2-secure Kyber KEM encapsulation algorithm
// Benchmarking IND-CCA2-secure Ml_kem KEM encapsulation algorithm
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv, size_t bit_security_level>
void
bench_encapsulate(benchmark::State& state)
{
constexpr size_t slen = 32;
constexpr size_t pklen = kyber_utils::get_kem_public_key_len(k);
constexpr size_t sklen = kyber_utils::get_kem_secret_key_len(k);
constexpr size_t ctlen = kyber_utils::get_kem_cipher_text_len(k, du, dv);
constexpr size_t pklen = ml_kem_utils::get_kem_public_key_len(k);
constexpr size_t sklen = ml_kem_utils::get_kem_secret_key_len(k);
constexpr size_t ctlen = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
constexpr size_t klen = 32;
std::vector<uint8_t> d(slen);
@@ -123,15 +123,15 @@ bench_encapsulate(benchmark::State& state)
#endif
}
// Benchmarking IND-CCA2-secure Kyber KEM decapsulation algorithm
// Benchmarking IND-CCA2-secure Ml_kem KEM decapsulation algorithm
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv, size_t bit_security_level>
void
bench_decapsulate(benchmark::State& state)
{
constexpr size_t slen = 32;
constexpr size_t pklen = kyber_utils::get_kem_public_key_len(k);
constexpr size_t sklen = kyber_utils::get_kem_secret_key_len(k);
constexpr size_t ctlen = kyber_utils::get_kem_cipher_text_len(k, du, dv);
constexpr size_t pklen = ml_kem_utils::get_kem_public_key_len(k);
constexpr size_t sklen = ml_kem_utils::get_kem_secret_key_len(k);
constexpr size_t ctlen = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
constexpr size_t klen = 32;
std::vector<uint8_t> d(slen);
@@ -193,17 +193,17 @@ bench_decapsulate(benchmark::State& state)
#endif
}
// Kyber512
BENCHMARK(bench_keygen<2, 3, 128>)->Name("kyber512/keygen")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_encapsulate<2, 3, 2, 10, 4, 128>)->Name("kyber512/encap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_decapsulate<2, 3, 2, 10, 4, 128>)->Name("kyber512/decap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
// Ml_kem512
BENCHMARK(bench_keygen<2, 3, 128>)->Name("ml_kem512/keygen")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_encapsulate<2, 3, 2, 10, 4, 128>)->Name("ml_kem512/encap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_decapsulate<2, 3, 2, 10, 4, 128>)->Name("ml_kem512/decap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
// Kyber768
BENCHMARK(bench_keygen<3, 2, 192>)->Name("kyber768/keygen")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_encapsulate<3, 2, 2, 10, 4, 192>)->Name("kyber768/encap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_decapsulate<3, 2, 2, 10, 4, 192>)->Name("kyber768/decap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
// Ml_kem768
BENCHMARK(bench_keygen<3, 2, 192>)->Name("ml_kem768/keygen")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_encapsulate<3, 2, 2, 10, 4, 192>)->Name("ml_kem768/encap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_decapsulate<3, 2, 2, 10, 4, 192>)->Name("ml_kem768/decap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
// Kyber1024
BENCHMARK(bench_keygen<4, 2, 256>)->Name("kyber1024/keygen")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_encapsulate<4, 2, 2, 11, 5, 256>)->Name("kyber1024/encap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_decapsulate<4, 2, 2, 11, 5, 256>)->Name("kyber1024/decap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
// Ml_kem1024
BENCHMARK(bench_keygen<4, 2, 256>)->Name("ml_kem1024/keygen")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_encapsulate<4, 2, 2, 11, 5, 256>)->Name("ml_kem1024/encap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);
BENCHMARK(bench_decapsulate<4, 2, 2, 11, 5, 256>)->Name("ml_kem1024/decap")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max);

View File

@@ -1,88 +0,0 @@
#include "kyber/kyber512_kem.hpp"
#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <sstream>
// Given a bytearray of length N, this function converts it to human readable hex formatted string of length 2*N | N >= 0.
static inline std::string
to_hex(std::span<const uint8_t> bytes)
{
std::stringstream ss;
ss << std::hex;
for (size_t i = 0; i < bytes.size(); i++) {
ss << std::setw(2) << std::setfill('0') << static_cast<uint32_t>(bytes[i]);
}
return ss.str();
}
// Compile it with
//
// g++ -std=c++20 -Wall -Wextra -pedantic -O3 -march=native -I ./include -I ./sha3/include -I ./subtle/include/ examples/kyber512_kem.cpp
int
main()
{
constexpr size_t SEED_LEN = 32;
constexpr size_t KEY_LEN = 32;
// seeds required for keypair generation
std::vector<uint8_t> d(SEED_LEN, 0);
std::vector<uint8_t> z(SEED_LEN, 0);
auto _d = std::span<uint8_t, SEED_LEN>(d);
auto _z = std::span<uint8_t, SEED_LEN>(z);
// public/ private keypair
std::vector<uint8_t> pkey(kyber512_kem::PKEY_BYTE_LEN, 0);
std::vector<uint8_t> skey(kyber512_kem::SKEY_BYTE_LEN, 0);
auto _pkey = std::span<uint8_t, kyber512_kem::PKEY_BYTE_LEN>(pkey);
auto _skey = std::span<uint8_t, kyber512_kem::SKEY_BYTE_LEN>(skey);
// seed required for key encapsulation
std::vector<uint8_t> m(SEED_LEN, 0);
std::vector<uint8_t> cipher(kyber512_kem::CIPHER_TEXT_BYTE_LEN, 0);
auto _m = std::span<uint8_t, SEED_LEN>(m);
auto _cipher = std::span<uint8_t, kyber512_kem::CIPHER_TEXT_BYTE_LEN>(cipher);
// shared secret that sender/ receiver arrives at
std::vector<uint8_t> shrd_key0(KEY_LEN, 0);
std::vector<uint8_t> shrd_key1(KEY_LEN, 0);
auto _shrd_key0 = std::span<uint8_t, KEY_LEN>(shrd_key0);
auto _shrd_key1 = std::span<uint8_t, KEY_LEN>(shrd_key1);
// pseudo-randomness source
prng::prng_t<128> prng{};
// fill up seeds using PRNG
prng.read(_d);
prng.read(_z);
// generate a keypair
kyber512_kem::keygen(_d, _z, _pkey, _skey);
// fill up seed required for key encapsulation, using PRNG
prng.read(_m);
// encapsulate key, compute cipher text and obtain KDF
const bool is_encapsulated = kyber512_kem::encapsulate(_m, _pkey, _cipher, _shrd_key0);
// decapsulate cipher text and obtain KDF
kyber512_kem::decapsulate(_skey, _cipher, _shrd_key1);
// check that both of the communicating parties arrived at same shared key
assert(std::ranges::equal(_shrd_key0, _shrd_key1));
std::cout << "Kyber512 KEM\n";
std::cout << "pubkey : " << to_hex(_pkey) << "\n";
std::cout << "seckey : " << to_hex(_skey) << "\n";
std::cout << "encapsulated ? : " << std::boolalpha << is_encapsulated << "\n";
std::cout << "cipher : " << to_hex(_cipher) << "\n";
std::cout << "shared secret : " << to_hex(_shrd_key0) << "\n";
return EXIT_SUCCESS;
}

88
examples/ml_kem_768.cpp Normal file
View File

@@ -0,0 +1,88 @@
#include "ml_kem/ml_kem_768.hpp"
#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <sstream>
// Given a bytearray of length N, this function converts it to human readable hex formatted string of length 2*N | N >= 0.
static inline std::string
to_hex(std::span<const uint8_t> bytes)
{
std::stringstream ss;
ss << std::hex;
for (size_t i = 0; i < bytes.size(); i++) {
ss << std::setw(2) << std::setfill('0') << static_cast<uint32_t>(bytes[i]);
}
return ss.str();
}
// Compile it with
//
// g++ -std=c++20 -Wall -Wextra -pedantic -O3 -march=native -I ./include -I ./sha3/include -I ./subtle/include/ examples/ml_kem_768.cpp
int
main()
{
constexpr size_t SEED_LEN = 32;
constexpr size_t KEY_LEN = 32;
// Seeds required for keypair generation
std::vector<uint8_t> d(SEED_LEN, 0);
std::vector<uint8_t> z(SEED_LEN, 0);
auto _d = std::span<uint8_t, SEED_LEN>(d);
auto _z = std::span<uint8_t, SEED_LEN>(z);
// Public/ private keypair
std::vector<uint8_t> pkey(ml_kem_768::PKEY_BYTE_LEN, 0);
std::vector<uint8_t> skey(ml_kem_768::SKEY_BYTE_LEN, 0);
auto _pkey = std::span<uint8_t, ml_kem_768::PKEY_BYTE_LEN>(pkey);
auto _skey = std::span<uint8_t, ml_kem_768::SKEY_BYTE_LEN>(skey);
// Seed required for key encapsulation
std::vector<uint8_t> m(SEED_LEN, 0);
std::vector<uint8_t> cipher(ml_kem_768::CIPHER_TEXT_BYTE_LEN, 0);
auto _m = std::span<uint8_t, SEED_LEN>(m);
auto _cipher = std::span<uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN>(cipher);
// Shared secret that sender/ receiver arrives at
std::vector<uint8_t> shrd_key0(KEY_LEN, 0);
std::vector<uint8_t> shrd_key1(KEY_LEN, 0);
auto _shrd_key0 = std::span<uint8_t, KEY_LEN>(shrd_key0);
auto _shrd_key1 = std::span<uint8_t, KEY_LEN>(shrd_key1);
// Pseudo-randomness source
prng::prng_t<128> prng{};
// Fill up seeds using PRNG
prng.read(_d);
prng.read(_z);
// Generate a keypair
ml_kem_768::keygen(_d, _z, _pkey, _skey);
// Fill up seed required for key encapsulation, using PRNG
prng.read(_m);
// Encapsulate key, compute cipher text and obtain KDF
const bool is_encapsulated = ml_kem_768::encapsulate(_m, _pkey, _cipher, _shrd_key0);
// Decapsulate cipher text and obtain KDF
ml_kem_768::decapsulate(_skey, _cipher, _shrd_key1);
// Check that both of the communicating parties arrived at same shared secret key
assert(std::ranges::equal(_shrd_key0, _shrd_key1));
std::cout << "ML-KEM-768\n";
std::cout << "Pubkey : " << to_hex(_pkey) << "\n";
std::cout << "Seckey : " << to_hex(_skey) << "\n";
std::cout << "Encapsulated ? : " << std::boolalpha << is_encapsulated << "\n";
std::cout << "Cipher : " << to_hex(_cipher) << "\n";
std::cout << "Shared secret : " << to_hex(_shrd_key0) << "\n";
return EXIT_SUCCESS;
}

View File

@@ -1,9 +1,9 @@
#pragma once
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/poly_vec.hpp"
#include "kyber/internals/poly/sampling.hpp"
#include "kyber/internals/utility/params.hpp"
#include "kyber/internals/utility/utils.hpp"
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/poly_vec.hpp"
#include "ml_kem/internals/poly/sampling.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include "ml_kem/internals/utility/utils.hpp"
#include "sha3_512.hpp"
// Public Key Encryption Scheme
@@ -14,7 +14,7 @@ namespace k_pke {
template<size_t k, size_t eta1>
static inline void
keygen(std::span<const uint8_t, 32> d, std::span<uint8_t, k * 12 * 32 + 32> pubkey, std::span<uint8_t, k * 12 * 32> seckey)
requires(kyber_params::check_keygen_params(k, eta1))
requires(ml_kem_params::check_keygen_params(k, eta1))
{
std::array<uint8_t, 64> g_out{};
auto _g_out = std::span(g_out);
@@ -28,33 +28,33 @@ keygen(std::span<const uint8_t, 32> d, std::span<uint8_t, k * 12 * 32 + 32> pubk
const auto sigma = _g_out.template subspan<rho.size(), 32>();
std::array<field::zq_t, k * k * ntt::N> A_prime{};
kyber_utils::generate_matrix<k, false>(A_prime, rho);
ml_kem_utils::generate_matrix<k, false>(A_prime, rho);
uint8_t N = 0;
std::array<field::zq_t, k * ntt::N> s{};
kyber_utils::generate_vector<k, eta1>(s, sigma, N);
ml_kem_utils::generate_vector<k, eta1>(s, sigma, N);
N += k;
std::array<field::zq_t, k * ntt::N> e{};
kyber_utils::generate_vector<k, eta1>(e, sigma, N);
ml_kem_utils::generate_vector<k, eta1>(e, sigma, N);
N += k;
kyber_utils::poly_vec_ntt<k>(s);
kyber_utils::poly_vec_ntt<k>(e);
ml_kem_utils::poly_vec_ntt<k>(s);
ml_kem_utils::poly_vec_ntt<k>(e);
std::array<field::zq_t, k * ntt::N> t_prime{};
kyber_utils::matrix_multiply<k, k, k, 1>(A_prime, s, t_prime);
kyber_utils::poly_vec_add_to<k>(e, t_prime);
ml_kem_utils::matrix_multiply<k, k, k, 1>(A_prime, s, t_prime);
ml_kem_utils::poly_vec_add_to<k>(e, t_prime);
constexpr size_t pkoff = k * 12 * 32;
auto _pubkey0 = pubkey.template subspan<0, pkoff>();
auto _pubkey1 = pubkey.template subspan<pkoff, 32>();
kyber_utils::poly_vec_encode<k, 12>(t_prime, _pubkey0);
ml_kem_utils::poly_vec_encode<k, 12>(t_prime, _pubkey0);
std::copy(rho.begin(), rho.end(), _pubkey1.begin());
kyber_utils::poly_vec_encode<k, 12>(s, seckey);
ml_kem_utils::poly_vec_encode<k, 12>(s, seckey);
}
// Given a *valid* K-PKE public key, 32 -bytes message ( to be encrypted ) and 32 -bytes random coin
@@ -70,7 +70,7 @@ encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
std::span<const uint8_t, 32> msg,
std::span<const uint8_t, 32> rcoin,
std::span<uint8_t, k * du * 32 + dv * 32> enc)
requires(kyber_params::check_encrypt_params(k, eta1, eta2, du, dv))
requires(ml_kem_params::check_encrypt_params(k, eta1, eta2, du, dv))
{
constexpr size_t pkoff = k * 12 * 32;
auto _pubkey0 = pubkey.template subspan<0, pkoff>();
@@ -79,60 +79,60 @@ encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
std::array<field::zq_t, k * ntt::N> t_prime{};
std::array<uint8_t, _pubkey0.size()> encoded_tprime{};
kyber_utils::poly_vec_decode<k, 12>(_pubkey0, t_prime);
kyber_utils::poly_vec_encode<k, 12>(t_prime, encoded_tprime);
ml_kem_utils::poly_vec_decode<k, 12>(_pubkey0, t_prime);
ml_kem_utils::poly_vec_encode<k, 12>(t_prime, encoded_tprime);
using encoded_pkey_t = std::span<const uint8_t, _pubkey0.size()>;
const auto are_equal = kyber_utils::ct_memcmp(encoded_pkey_t(_pubkey0), encoded_pkey_t(encoded_tprime));
const auto are_equal = ml_kem_utils::ct_memcmp(encoded_pkey_t(_pubkey0), encoded_pkey_t(encoded_tprime));
if (are_equal == 0u) {
// Got an invalid public key
return false;
}
std::array<field::zq_t, k * k * ntt::N> A_prime{};
kyber_utils::generate_matrix<k, true>(A_prime, rho);
ml_kem_utils::generate_matrix<k, true>(A_prime, rho);
uint8_t N = 0;
std::array<field::zq_t, k * ntt::N> r{};
kyber_utils::generate_vector<k, eta1>(r, rcoin, N);
ml_kem_utils::generate_vector<k, eta1>(r, rcoin, N);
N += k;
std::array<field::zq_t, k * ntt::N> e1{};
kyber_utils::generate_vector<k, eta2>(e1, rcoin, N);
ml_kem_utils::generate_vector<k, eta2>(e1, rcoin, N);
N += k;
std::array<field::zq_t, ntt::N> e2{};
kyber_utils::generate_vector<1, eta2>(e2, rcoin, N);
ml_kem_utils::generate_vector<1, eta2>(e2, rcoin, N);
kyber_utils::poly_vec_ntt<k>(r);
ml_kem_utils::poly_vec_ntt<k>(r);
std::array<field::zq_t, k * ntt::N> u{};
kyber_utils::matrix_multiply<k, k, k, 1>(A_prime, r, u);
kyber_utils::poly_vec_intt<k>(u);
kyber_utils::poly_vec_add_to<k>(e1, u);
ml_kem_utils::matrix_multiply<k, k, k, 1>(A_prime, r, u);
ml_kem_utils::poly_vec_intt<k>(u);
ml_kem_utils::poly_vec_add_to<k>(e1, u);
std::array<field::zq_t, ntt::N> v{};
kyber_utils::matrix_multiply<1, k, k, 1>(t_prime, r, v);
kyber_utils::poly_vec_intt<1>(v);
kyber_utils::poly_vec_add_to<1>(e2, v);
ml_kem_utils::matrix_multiply<1, k, k, 1>(t_prime, r, v);
ml_kem_utils::poly_vec_intt<1>(v);
ml_kem_utils::poly_vec_add_to<1>(e2, v);
std::array<field::zq_t, ntt::N> m{};
kyber_utils::decode<1>(msg, m);
kyber_utils::poly_decompress<1>(m);
kyber_utils::poly_vec_add_to<1>(m, v);
ml_kem_utils::decode<1>(msg, m);
ml_kem_utils::poly_decompress<1>(m);
ml_kem_utils::poly_vec_add_to<1>(m, v);
constexpr size_t encoff = k * du * 32;
auto _enc0 = enc.template subspan<0, encoff>();
auto _enc1 = enc.template subspan<encoff, dv * 32>();
kyber_utils::poly_vec_compress<k, du>(u);
kyber_utils::poly_vec_encode<k, du>(u, _enc0);
ml_kem_utils::poly_vec_compress<k, du>(u);
ml_kem_utils::poly_vec_encode<k, du>(u, _enc0);
kyber_utils::poly_compress<dv>(v);
kyber_utils::encode<dv>(v, _enc1);
ml_kem_utils::poly_compress<dv>(v);
ml_kem_utils::encode<dv>(v, _enc1);
return true;
}
@@ -144,7 +144,7 @@ encrypt(std::span<const uint8_t, k * 12 * 32 + 32> pubkey,
template<size_t k, size_t du, size_t dv>
static inline void
decrypt(std::span<const uint8_t, k * 12 * 32> seckey, std::span<const uint8_t, k * du * 32 + dv * 32> enc, std::span<uint8_t, 32> dec)
requires(kyber_params::check_decrypt_params(k, du, dv))
requires(ml_kem_params::check_decrypt_params(k, du, dv))
{
constexpr size_t encoff = k * du * 32;
auto _enc0 = enc.template subspan<0, encoff>();
@@ -152,27 +152,27 @@ decrypt(std::span<const uint8_t, k * 12 * 32> seckey, std::span<const uint8_t, k
std::array<field::zq_t, k * ntt::N> u{};
kyber_utils::poly_vec_decode<k, du>(_enc0, u);
kyber_utils::poly_vec_decompress<k, du>(u);
ml_kem_utils::poly_vec_decode<k, du>(_enc0, u);
ml_kem_utils::poly_vec_decompress<k, du>(u);
std::array<field::zq_t, ntt::N> v{};
kyber_utils::decode<dv>(_enc1, v);
kyber_utils::poly_decompress<dv>(v);
ml_kem_utils::decode<dv>(_enc1, v);
ml_kem_utils::poly_decompress<dv>(v);
std::array<field::zq_t, k * ntt::N> s_prime{};
kyber_utils::poly_vec_decode<k, 12>(seckey, s_prime);
ml_kem_utils::poly_vec_decode<k, 12>(seckey, s_prime);
kyber_utils::poly_vec_ntt<k>(u);
ml_kem_utils::poly_vec_ntt<k>(u);
std::array<field::zq_t, ntt::N> t{};
kyber_utils::matrix_multiply<1, k, k, 1>(s_prime, u, t);
kyber_utils::poly_vec_intt<1>(t);
kyber_utils::poly_vec_sub_from<1>(t, v);
ml_kem_utils::matrix_multiply<1, k, k, 1>(s_prime, u, t);
ml_kem_utils::poly_vec_intt<1>(t);
ml_kem_utils::poly_vec_sub_from<1>(t, v);
kyber_utils::poly_compress<1>(v);
kyber_utils::encode<1>(v, dec);
ml_kem_utils::poly_compress<1>(v);
ml_kem_utils::encode<1>(v, dec);
}
}

View File

@@ -1,14 +1,14 @@
#pragma once
#include "kyber/internals/rng/prng.hpp"
#include "ml_kem/internals/rng/prng.hpp"
#include <bit>
#include <cstdint>
namespace field {
// Kyber Prime Field Modulus ( = 3329 )
// Ml_kem Prime Field Modulus ( = 3329 )
static constexpr uint32_t Q = (1u << 8) * 13 + 1;
// Bit width of Kyber Prime Field Modulus ( = 12 )
// Bit width of Ml_kem Prime Field Modulus ( = 12 )
static constexpr size_t Q_BIT_WIDTH = std::bit_width(Q);
// Precomputed Barrett Reduction Constant

View File

@@ -1,6 +1,6 @@
#pragma once
#include "k_pke.hpp"
#include "kyber/internals/utility/utils.hpp"
#include "ml_kem/internals/utility/utils.hpp"
#include "sha3_256.hpp"
#include "sha3_512.hpp"
#include "shake256.hpp"
@@ -15,9 +15,9 @@ template<size_t k, size_t eta1>
static inline void
keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
std::span<const uint8_t, 32> z, // used in CCA-KEM
std::span<uint8_t, kyber_utils::get_kem_public_key_len(k)> pubkey,
std::span<uint8_t, kyber_utils::get_kem_secret_key_len(k)> seckey)
requires(kyber_params::check_keygen_params(k, eta1))
std::span<uint8_t, ml_kem_utils::get_kem_public_key_len(k)> pubkey,
std::span<uint8_t, ml_kem_utils::get_kem_secret_key_len(k)> seckey)
requires(ml_kem_params::check_keygen_params(k, eta1))
{
static constexpr size_t skoff0 = k * 12 * 32;
static constexpr size_t skoff1 = skoff0 + pubkey.size();
@@ -52,10 +52,10 @@ keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result, it might fail because of malformed input public key")]] static inline bool
encapsulate(std::span<const uint8_t, 32> m,
std::span<const uint8_t, kyber_utils::get_kem_public_key_len(k)> pubkey,
std::span<uint8_t, kyber_utils::get_kem_cipher_text_len(k, du, dv)> cipher,
std::span<const uint8_t, ml_kem_utils::get_kem_public_key_len(k)> pubkey,
std::span<uint8_t, ml_kem_utils::get_kem_cipher_text_len(k, du, dv)> cipher,
std::span<uint8_t, 32> shared_secret)
requires(kyber_params::check_encap_params(k, eta1, eta2, du, dv))
requires(ml_kem_params::check_encap_params(k, eta1, eta2, du, dv))
{
std::array<uint8_t, m.size() + sha3_256::DIGEST_LEN> g_in{};
std::array<uint8_t, sha3_512::DIGEST_LEN> g_out{};
@@ -99,10 +99,10 @@ encapsulate(std::span<const uint8_t, 32> m,
// See algorithm 17 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
static inline void
decapsulate(std::span<const uint8_t, kyber_utils::get_kem_secret_key_len(k)> seckey,
std::span<const uint8_t, kyber_utils::get_kem_cipher_text_len(k, du, dv)> cipher,
decapsulate(std::span<const uint8_t, ml_kem_utils::get_kem_secret_key_len(k)> seckey,
std::span<const uint8_t, ml_kem_utils::get_kem_cipher_text_len(k, du, dv)> cipher,
std::span<uint8_t, 32> shared_secret)
requires(kyber_params::check_decap_params(k, eta1, eta2, du, dv))
requires(ml_kem_params::check_decap_params(k, eta1, eta2, du, dv))
{
constexpr size_t sklen = k * 12 * 32;
constexpr size_t pklen = k * 12 * 32 + 32;
@@ -149,8 +149,8 @@ decapsulate(std::span<const uint8_t, kyber_utils::get_kem_secret_key_len(k)> sec
// line 9-12 of algorithm 17, in constant-time
using kdf_t = std::span<const uint8_t, shared_secret.size()>;
const uint32_t cond = kyber_utils::ct_memcmp(cipher, std::span<const uint8_t, ctlen>(c_prime));
kyber_utils::ct_cond_memcpy(cond, shared_secret, kdf_t(_g_out0), kdf_t(z));
const uint32_t cond = ml_kem_utils::ct_memcmp(cipher, std::span<const uint8_t, ctlen>(c_prime));
ml_kem_utils::ct_cond_memcpy(cond, shared_secret, kdf_t(_g_out0), kdf_t(z));
}
}

View File

@@ -1,23 +1,23 @@
#pragma once
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/ntt.hpp"
#include "kyber/internals/utility/params.hpp"
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include <span>
// IND-CPA-secure Public Key Encryption Scheme Utilities
namespace kyber_utils {
namespace ml_kem_utils {
// Given an element x ∈ Z_q | q = 3329, this routine compresses it by discarding
// some low-order bits, computing y ∈ [0, 2^d) | d < round(log2(q))
//
// See top of page 5 of Kyber specification
// See top of page 5 of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
//
// Following implementation collects inspiration from https://github.com/FiloSottile/mlkem768/blob/cffbfb96c407b3cfc9f6e1749475b673794402c1/mlkem768.go#L395-L425.
template<size_t d>
static inline constexpr field::zq_t
compress(const field::zq_t x)
requires(kyber_params::check_d(d))
requires(ml_kem_params::check_d(d))
{
constexpr uint16_t mask = (1u << d) - 1;
@@ -35,14 +35,14 @@ compress(const field::zq_t x)
// it back to y ∈ Z_q | q = 3329
//
// This routine recovers the compressed element with error probability as
// defined in eq. 2 of Kyber specification.
// defined in eq. 2 of Ml_kem specification.
//
// See top of page 5 of Kyber specification
// See top of page 5 of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t d>
static inline constexpr field::zq_t
decompress(const field::zq_t x)
requires(kyber_params::check_d(d))
requires(ml_kem_params::check_d(d))
{
constexpr uint32_t t0 = 1u << d;
constexpr uint32_t t1 = t0 >> 1;
@@ -59,7 +59,7 @@ decompress(const field::zq_t x)
template<size_t d>
static inline constexpr void
poly_compress(std::span<field::zq_t, ntt::N> poly)
requires(kyber_params::check_d(d))
requires(ml_kem_params::check_d(d))
{
for (size_t i = 0; i < poly.size(); i++) {
poly[i] = compress<d>(poly[i]);
@@ -71,7 +71,7 @@ poly_compress(std::span<field::zq_t, ntt::N> poly)
template<size_t d>
static inline constexpr void
poly_decompress(std::span<field::zq_t, ntt::N> poly)
requires(kyber_params::check_d(d))
requires(ml_kem_params::check_d(d))
{
for (size_t i = 0; i < poly.size(); i++) {
poly[i] = decompress<d>(poly[i]);

View File

@@ -1,9 +1,9 @@
#pragma once
#include "kyber/internals/math/field.hpp"
#include "ml_kem/internals/math/field.hpp"
#include <array>
#include <cstring>
// (inverse) Number Theoretic Transform for degree-255 polynomial, over Kyber
// (inverse) Number Theoretic Transform for degree-255 polynomial, over Ml_kem
// Prime Field Zq | q = 3329
namespace ntt {
@@ -182,7 +182,7 @@ intt(std::span<field::zq_t, N> poly)
//
// h = f * g mod X ^ 2 ζ ^ (2 * br<7>(i) + 1) | i ∈ [0, 128)
//
// See page 6 of Kyber specification
// See page 6 of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
static inline constexpr void
basemul(std::span<const field::zq_t, 2> f, // degree-1 polynomial

View File

@@ -1,14 +1,14 @@
#pragma once
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/compression.hpp"
#include "kyber/internals/poly/ntt.hpp"
#include "kyber/internals/poly/serialize.hpp"
#include "kyber/internals/utility/params.hpp"
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/compression.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/poly/serialize.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include <array>
#include <cstdint>
// IND-CPA-secure Public Key Encryption Scheme Utilities
namespace kyber_utils {
namespace ml_kem_utils {
// Given two matrices ( in NTT domain ) of compatible dimension, where each
// matrix element is a degree-255 polynomial over Z_q | q = 3329, this routine
@@ -18,7 +18,7 @@ static inline constexpr void
matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
std::span<const field::zq_t, b_rows * b_cols * ntt::N> b,
std::span<field::zq_t, a_rows * b_cols * ntt::N> c)
requires(kyber_params::check_matrix_dim(a_cols, b_rows))
requires(ml_kem_params::check_matrix_dim(a_cols, b_rows))
{
using poly_t = std::span<const field::zq_t, ntt::N>;
@@ -49,7 +49,7 @@ matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
template<size_t k>
static inline constexpr void
poly_vec_ntt(std::span<field::zq_t, k * ntt::N> vec)
requires((k == 1) || kyber_params::check_k(k))
requires((k == 1) || ml_kem_params::check_k(k))
{
using poly_t = std::span<field::zq_t, ntt::N>;
@@ -66,7 +66,7 @@ poly_vec_ntt(std::span<field::zq_t, k * ntt::N> vec)
template<size_t k>
static inline constexpr void
poly_vec_intt(std::span<field::zq_t, k * ntt::N> vec)
requires((k == 1) || kyber_params::check_k(k))
requires((k == 1) || ml_kem_params::check_k(k))
{
using poly_t = std::span<field::zq_t, ntt::N>;
@@ -81,7 +81,7 @@ poly_vec_intt(std::span<field::zq_t, k * ntt::N> vec)
template<size_t k>
static inline constexpr void
poly_vec_add_to(std::span<const field::zq_t, k * ntt::N> src, std::span<field::zq_t, k * ntt::N> dst)
requires((k == 1) || kyber_params::check_k(k))
requires((k == 1) || ml_kem_params::check_k(k))
{
constexpr size_t cnt = k * ntt::N;
@@ -95,7 +95,7 @@ poly_vec_add_to(std::span<const field::zq_t, k * ntt::N> src, std::span<field::z
template<size_t k>
static inline constexpr void
poly_vec_sub_from(std::span<const field::zq_t, k * ntt::N> src, std::span<field::zq_t, k * ntt::N> dst)
requires((k == 1) || kyber_params::check_k(k))
requires((k == 1) || ml_kem_params::check_k(k))
{
constexpr size_t cnt = k * ntt::N;
@@ -110,7 +110,7 @@ poly_vec_sub_from(std::span<const field::zq_t, k * ntt::N> src, std::span<field:
template<size_t k, size_t l>
static inline void
poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src, std::span<uint8_t, k * 32 * l> dst)
requires(kyber_params::check_k(k))
requires(ml_kem_params::check_k(k))
{
using poly_t = std::span<const field::zq_t, src.size() / k>;
using serialized_t = std::span<uint8_t, dst.size() / k>;
@@ -119,7 +119,7 @@ poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src, std::span<uint8_t,
const size_t off0 = i * ntt::N;
const size_t off1 = i * l * 32;
kyber_utils::encode<l>(poly_t(src.subspan(off0, ntt::N)), serialized_t(dst.subspan(off1, 32 * l)));
ml_kem_utils::encode<l>(poly_t(src.subspan(off0, ntt::N)), serialized_t(dst.subspan(off1, 32 * l)));
}
}
@@ -129,7 +129,7 @@ poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src, std::span<uint8_t,
template<size_t k, size_t l>
static inline void
poly_vec_decode(std::span<const uint8_t, k * 32 * l> src, std::span<field::zq_t, k * ntt::N> dst)
requires(kyber_params::check_k(k))
requires(ml_kem_params::check_k(k))
{
using serialized_t = std::span<const uint8_t, src.size() / k>;
using poly_t = std::span<field::zq_t, dst.size() / k>;
@@ -138,7 +138,7 @@ poly_vec_decode(std::span<const uint8_t, k * 32 * l> src, std::span<field::zq_t,
const size_t off0 = i * l * 32;
const size_t off1 = i * ntt::N;
kyber_utils::decode<l>(serialized_t(src.subspan(off0, 32 * l)), poly_t(dst.subspan(off1, ntt::N)));
ml_kem_utils::decode<l>(serialized_t(src.subspan(off0, 32 * l)), poly_t(dst.subspan(off1, ntt::N)));
}
}
@@ -147,13 +147,13 @@ poly_vec_decode(std::span<const uint8_t, k * 32 * l> src, std::span<field::zq_t,
template<size_t k, size_t d>
static inline constexpr void
poly_vec_compress(std::span<field::zq_t, k * ntt::N> vec)
requires(kyber_params::check_k(k))
requires(ml_kem_params::check_k(k))
{
using poly_t = std::span<field::zq_t, vec.size() / k>;
for (size_t i = 0; i < k; i++) {
const size_t off = i * ntt::N;
kyber_utils::poly_compress<d>(poly_t(vec.subspan(off, ntt::N)));
ml_kem_utils::poly_compress<d>(poly_t(vec.subspan(off, ntt::N)));
}
}
@@ -162,13 +162,13 @@ poly_vec_compress(std::span<field::zq_t, k * ntt::N> vec)
template<size_t k, size_t d>
static inline constexpr void
poly_vec_decompress(std::span<field::zq_t, k * ntt::N> vec)
requires(kyber_params::check_k(k))
requires(ml_kem_params::check_k(k))
{
using poly_t = std::span<field::zq_t, vec.size() / k>;
for (size_t i = 0; i < k; i++) {
const size_t off = i * ntt::N;
kyber_utils::poly_decompress<d>(poly_t(vec.subspan(off, ntt::N)));
ml_kem_utils::poly_decompress<d>(poly_t(vec.subspan(off, ntt::N)));
}
}

View File

@@ -1,14 +1,14 @@
#pragma once
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/ntt.hpp"
#include "kyber/internals/utility/params.hpp"
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include "shake128.hpp"
#include "shake256.hpp"
#include <array>
#include <cstdint>
// IND-CPA-secure Public Key Encryption Scheme Utilities
namespace kyber_utils {
namespace ml_kem_utils {
// Uniform sampling in R_q | q = 3329
//
@@ -17,7 +17,7 @@ namespace kyber_utils {
// to uniform random byte stream, produced polynomial coefficients are also
// statiscally close to randomly sampled elements of R_q.
//
// See algorithm 1, defined in Kyber specification
// See algorithm 1, defined in Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
inline void
parse(shake128::shake128_t& hasher, std::span<field::zq_t, ntt::N> poly)
@@ -51,12 +51,12 @@ parse(shake128::shake128_t& hasher, std::span<field::zq_t, ntt::N> poly)
// domain, by sampling from a XOF ( read SHAKE128 ), which is seeded with 32
// -bytes key and two nonces ( each of 1 -byte )
//
// See step (4-8) of algorithm 4/ 5, defined in Kyber specification
// See step (4-8) of algorithm 4/ 5, defined in Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t k, bool transpose>
static inline void
generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat, std::span<const uint8_t, 32> rho)
requires(kyber_params::check_k(k))
requires(ml_kem_params::check_k(k))
{
std::array<uint8_t, rho.size() + 2> xof_in{};
std::copy(rho.begin(), rho.end(), xof_in.begin());
@@ -88,12 +88,12 @@ generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat, std::span<const uint
// A degree 255 polynomial deterministically sampled from 64 * eta -bytes output
// of a pseudorandom function ( PRF )
//
// See algorithm 2, defined in Kyber specification
// See algorithm 2, defined in Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t eta>
static inline void
cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
requires(kyber_params::check_eta(eta))
requires(ml_kem_params::check_eta(eta))
{
if constexpr (eta == 2) {
static_assert(eta == 2, "η must be 2 !");
@@ -140,12 +140,12 @@ cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
}
// Sample a polynomial vector from Bη, following step (9-12) of algorithm 4,
// defined in Kyber specification
// defined in Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t k, size_t eta>
static inline void
generate_vector(std::span<field::zq_t, k * ntt::N> vec, std::span<const uint8_t, 32> sigma, const uint8_t nonce)
requires((k == 1) || kyber_params::check_k(k))
requires((k == 1) || ml_kem_params::check_k(k))
{
std::array<uint8_t, 64 * eta> prf_out{};
std::array<uint8_t, sigma.size() + 1> prf_in{};
@@ -162,7 +162,7 @@ generate_vector(std::span<field::zq_t, k * ntt::N> vec, std::span<const uint8_t,
hasher.squeeze(prf_out);
using poly_t = std::span<field::zq_t, vec.size() / k>;
kyber_utils::cbd<eta>(prf_out, poly_t(vec.subspan(off, ntt::N)));
ml_kem_utils::cbd<eta>(prf_out, poly_t(vec.subspan(off, ntt::N)));
}
}

View File

@@ -1,22 +1,22 @@
#pragma once
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/ntt.hpp"
#include "kyber/internals/utility/params.hpp"
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include <cstring>
// IND-CPA-secure Public Key Encryption Scheme Utilities
namespace kyber_utils {
namespace ml_kem_utils {
// Given a degree-255 polynomial, where significant portion of each ( total 256
// of them ) coefficient ∈ [0, 2^l), this routine serializes the polynomial to a
// byte array of length 32 * l -bytes
//
// See algorithm 3 described in section 1.1 ( page 7 ) of Kyber specification
// See algorithm 3 described in section 1.1 ( page 7 ) of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t l>
static inline void
encode(std::span<const field::zq_t, ntt::N> poly, std::span<uint8_t, 32 * l> arr)
requires(kyber_params::check_l(l))
requires(ml_kem_params::check_l(l))
{
std::fill(arr.begin(), arr.end(), 0);
@@ -147,12 +147,12 @@ encode(std::span<const field::zq_t, ntt::N> poly, std::span<uint8_t, 32 * l> arr
// polynomial of degree 255 s.t. significant portion of each ( total 256 of them
// ) coefficient ∈ [0, 2^l)
//
// See algorithm 3 described in section 1.1 ( page 7 ) of Kyber specification
// See algorithm 3 described in section 1.1 ( page 7 ) of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t l>
static inline void
decode(std::span<const uint8_t, 32 * l> arr, std::span<field::zq_t, ntt::N> poly)
requires(kyber_params::check_l(l))
requires(ml_kem_params::check_l(l))
{
if constexpr (l == 1) {
constexpr size_t itr_cnt = ntt::N >> 3;

View File

@@ -2,7 +2,7 @@
#include <cstddef>
// Holds compile-time executable functions, ensuring that functions are invoked with proper arguments.
namespace kyber_params {
namespace ml_kem_params {
// Compile-time check to ensure that number of bits ( read `d` ) to consider during
// polynomial coefficient compression/ decompression is within tolerable bounds.

View File

@@ -2,7 +2,7 @@
#include "subtle.hpp"
#include <span>
namespace kyber_utils {
namespace ml_kem_utils {
// Given two byte arrays of equal length, this routine can be used for comparing them in constant-time,
// producing truth value (0xffffffff) in case of equality, otherwise it returns false value (0x00000000).

View File

@@ -1,7 +1,7 @@
#pragma once
#include "kyber/internals/ml_kem.hpp"
#include "ml_kem/internals/ml_kem.hpp"
namespace kyber1024_kem {
namespace ml_kem_1024 {
// ML-KEM Key Encapsulation Mechanism instantiated with ML-KEM-1024 parameters
// See row 3 of table 2 of ML-KEM specification @ https://doi.org/10.6028/NIST.FIPS.203.ipd
@@ -19,16 +19,16 @@ static constexpr size_t SEED_D_BYTE_LEN = 32;
static constexpr size_t SEED_Z_BYTE_LEN = 32;
// 1568 -bytes ML-KEM-1024 public key
static constexpr size_t PKEY_BYTE_LEN = kyber_utils::get_kem_public_key_len(k);
static constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
// 3168 -bytes ML-KEM-1024 secret key
static constexpr size_t SKEY_BYTE_LEN = kyber_utils::get_kem_secret_key_len(k);
static constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
// 32 -bytes seed `m`, used in ML-KEM encapsulation
static constexpr size_t SEED_M_BYTE_LEN = 32;
// 1568 -bytes ML-KEM-1024 cipher text
static constexpr size_t CIPHER_TEXT_BYTE_LEN = kyber_utils::get_kem_cipher_text_len(k, du, dv);
static constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
// 32 -bytes ML-KEM-1024 shared secret
static constexpr size_t SHARED_SECRET_BYTE_LEN = 32;

View File

@@ -1,7 +1,7 @@
#pragma once
#include "kyber/internals/ml_kem.hpp"
#include "ml_kem/internals/ml_kem.hpp"
namespace kyber512_kem {
namespace ml_kem_512 {
// ML-KEM Key Encapsulation Mechanism instantiated with ML-KEM-512 parameters
// See row 1 of table 2 of ML-KEM specification @ https://doi.org/10.6028/NIST.FIPS.203.ipd
@@ -19,16 +19,16 @@ static constexpr size_t SEED_D_BYTE_LEN = 32;
static constexpr size_t SEED_Z_BYTE_LEN = 32;
// 800 -bytes ML-KEM-512 public key
static constexpr size_t PKEY_BYTE_LEN = kyber_utils::get_kem_public_key_len(k);
static constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
// 1632 -bytes ML-KEM-512 secret key
static constexpr size_t SKEY_BYTE_LEN = kyber_utils::get_kem_secret_key_len(k);
static constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
// 32 -bytes seed `m`, used in ML-KEM encapsulation
static constexpr size_t SEED_M_BYTE_LEN = 32;
// 768 -bytes ML-KEM-512 cipher text
static constexpr size_t CIPHER_TEXT_BYTE_LEN = kyber_utils::get_kem_cipher_text_len(k, du, dv);
static constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
// 32 -bytes ML-KEM-512 shared secret
static constexpr size_t SHARED_SECRET_BYTE_LEN = 32;

View File

@@ -1,7 +1,7 @@
#pragma once
#include "kyber/internals/ml_kem.hpp"
#include "ml_kem/internals/ml_kem.hpp"
namespace kyber768_kem {
namespace ml_kem_768 {
// ML-KEM Key Encapsulation Mechanism instantiated with ML-KEM-768 parameters
// See row 2 of table 2 of ML-KEM specification @ https://doi.org/10.6028/NIST.FIPS.203.ipd
@@ -19,16 +19,16 @@ static constexpr size_t SEED_D_BYTE_LEN = 32;
static constexpr size_t SEED_Z_BYTE_LEN = 32;
// 1184 -bytes ML-KEM-768 public key
static constexpr size_t PKEY_BYTE_LEN = kyber_utils::get_kem_public_key_len(k);
static constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
// 2400 -bytes ML-KEM-768 secret key
static constexpr size_t SKEY_BYTE_LEN = kyber_utils::get_kem_secret_key_len(k);
static constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
// 32 -bytes seed `m`, used in ML-KEM encapsulation
static constexpr size_t SEED_M_BYTE_LEN = 32;
// 1088 -bytes ML-KEM-768 cipher text
static constexpr size_t CIPHER_TEXT_BYTE_LEN = kyber_utils::get_kem_cipher_text_len(k, du, dv);
static constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
// 32 -bytes ML-KEM-768 shared secret
static constexpr size_t SHARED_SECRET_BYTE_LEN = 32;

View File

@@ -1,4 +1,4 @@
#include "kyber/kyber768_kem.hpp"
#include "ml_kem/ml_kem_1024.hpp"
#define DUDECT_IMPLEMENTATION
#define DUDECT_VISIBLITY_STATIC
@@ -12,41 +12,41 @@ do_one_computation(uint8_t* const data)
constexpr size_t doff0 = 0;
constexpr size_t doff1 = doff0 + SEED_LEN;
constexpr size_t doff2 = doff1 + 1;
constexpr size_t doff3 = doff2 + kyber768_kem::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff4 = doff3 + kyber768_kem::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff3 = doff2 + ml_kem_1024::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff4 = doff3 + ml_kem_1024::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff5 = doff4 + SEED_LEN;
constexpr size_t doff6 = doff5 + SEED_LEN;
std::array<field::zq_t, kyber768_kem::k * ntt::N> poly_vec{};
std::array<uint8_t, kyber768_kem::k * 32 * kyber768_kem::du> byte_arr{};
std::array<field::zq_t, ml_kem_1024::k * ntt::N> poly_vec{};
std::array<uint8_t, ml_kem_1024::k * 32 * ml_kem_1024::du> byte_arr{};
auto sigma = std::span<const uint8_t, SEED_LEN>(data + doff0, doff1 - doff0);
const auto nonce = data[doff1];
// Generate new secret polynomial vector
kyber_utils::generate_vector<kyber768_kem::k, kyber768_kem::η1>(poly_vec, sigma, nonce);
ml_kem_utils::generate_vector<ml_kem_1024::k, ml_kem_1024::η1>(poly_vec, sigma, nonce);
// Apply NTT on that secret vector
kyber_utils::poly_vec_ntt<kyber768_kem::k>(poly_vec);
ml_kem_utils::poly_vec_ntt<ml_kem_1024::k>(poly_vec);
// Apply iNTT on bit-reversed NTT form of secret polynomial vector
kyber_utils::poly_vec_intt<kyber768_kem::k>(poly_vec);
ml_kem_utils::poly_vec_intt<ml_kem_1024::k>(poly_vec);
// Compress coefficients of polynomial vector
kyber_utils::poly_vec_compress<kyber768_kem::k, kyber768_kem::du>(poly_vec);
ml_kem_utils::poly_vec_compress<ml_kem_1024::k, ml_kem_1024::du>(poly_vec);
// Serialize polynomial vector into byte array
kyber_utils::poly_vec_encode<kyber768_kem::k, kyber768_kem::du>(poly_vec, byte_arr);
ml_kem_utils::poly_vec_encode<ml_kem_1024::k, ml_kem_1024::du>(poly_vec, byte_arr);
// Recover coefficients of polynomial vector from byte array
kyber_utils::poly_vec_decode<kyber768_kem::k, kyber768_kem::du>(byte_arr, poly_vec);
ml_kem_utils::poly_vec_decode<ml_kem_1024::k, ml_kem_1024::du>(byte_arr, poly_vec);
// Decompress coefficients of polynomial vector
kyber_utils::poly_vec_decompress<kyber768_kem::k, kyber768_kem::du>(poly_vec);
ml_kem_utils::poly_vec_decompress<ml_kem_1024::k, ml_kem_1024::du>(poly_vec);
std::array<uint8_t, SEED_LEN> sink{};
auto _sink = std::span(sink);
using ctxt_t = std::span<const uint8_t, kyber768_kem::CIPHER_TEXT_BYTE_LEN>;
using ctxt_t = std::span<const uint8_t, ml_kem_1024::CIPHER_TEXT_BYTE_LEN>;
using seed_t = std::span<const uint8_t, SEED_LEN>;
// Ensure Fujisaki-Okamoto transform, used during decapsulation, is constant-time
const uint32_t cond = kyber_utils::ct_memcmp(ctxt_t(data + doff2, doff3 - doff2), ctxt_t(data + doff3, doff4 - doff3));
kyber_utils::ct_cond_memcpy(cond, _sink, seed_t(data + doff4, doff5 - doff4), seed_t(data + doff5, doff6 - doff5));
const uint32_t cond = ml_kem_utils::ct_memcmp(ctxt_t(data + doff2, doff3 - doff2), ctxt_t(data + doff3, doff4 - doff3));
ml_kem_utils::ct_cond_memcpy(cond, _sink, seed_t(data + doff4, doff5 - doff4), seed_t(data + doff5, doff6 - doff5));
// Just so that optimizer doesn't remove above function calls !
return static_cast<uint8_t>(poly_vec[0].raw() ^ poly_vec[poly_vec.size() - 1].raw()) ^ // result of generating vector of polynomials
@@ -69,14 +69,14 @@ prepare_inputs(dudect_config_t* const c, uint8_t* const input_data, uint8_t* con
}
dudect_state_t
test_kyber768_kem()
test_ml_kem_1024()
{
constexpr size_t chunk_size = SEED_LEN + // bytes holding seed `sigma`
1 + // single byte nonce
kyber768_kem::CIPHER_TEXT_BYTE_LEN + // bytes holding received cipher text
kyber768_kem::CIPHER_TEXT_BYTE_LEN + // bytes for locally computed cipher text
SEED_LEN + // bytes for first source buffer to copy from
SEED_LEN; // bytes for second source buffer to copy from
constexpr size_t chunk_size = SEED_LEN + // bytes holding seed `sigma`
1 + // single byte nonce
ml_kem_1024::CIPHER_TEXT_BYTE_LEN + // bytes holding received cipher text
ml_kem_1024::CIPHER_TEXT_BYTE_LEN + // bytes for locally computed cipher text
SEED_LEN + // bytes for first source buffer to copy from
SEED_LEN; // bytes for second source buffer to copy from
constexpr size_t number_measurements = 1e5;
dudect_config_t config = {
@@ -100,7 +100,7 @@ test_kyber768_kem()
int
main()
{
if (test_kyber768_kem() != DUDECT_NO_LEAKAGE_EVIDENCE_YET) {
if (test_ml_kem_1024() != DUDECT_NO_LEAKAGE_EVIDENCE_YET) {
return EXIT_FAILURE;
}

View File

@@ -1,4 +1,4 @@
#include "kyber/kyber512_kem.hpp"
#include "ml_kem/ml_kem_512.hpp"
#include <cstdio>
#define DUDECT_IMPLEMENTATION
@@ -13,41 +13,41 @@ do_one_computation(uint8_t* const data)
constexpr size_t doff0 = 0;
constexpr size_t doff1 = doff0 + SEED_LEN;
constexpr size_t doff2 = doff1 + 1;
constexpr size_t doff3 = doff2 + kyber512_kem::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff4 = doff3 + kyber512_kem::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff3 = doff2 + ml_kem_512::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff4 = doff3 + ml_kem_512::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff5 = doff4 + SEED_LEN;
constexpr size_t doff6 = doff5 + SEED_LEN;
std::array<field::zq_t, kyber512_kem::k * ntt::N> poly_vec{};
std::array<uint8_t, kyber512_kem::k * 32 * kyber512_kem::du> byte_arr{};
std::array<field::zq_t, ml_kem_512::k * ntt::N> poly_vec{};
std::array<uint8_t, ml_kem_512::k * 32 * ml_kem_512::du> byte_arr{};
auto sigma = std::span<const uint8_t, SEED_LEN>(data + doff0, doff1 - doff0);
const auto nonce = data[doff1];
// Generate new secret polynomial vector
kyber_utils::generate_vector<kyber512_kem::k, kyber512_kem::η1>(poly_vec, sigma, nonce);
ml_kem_utils::generate_vector<ml_kem_512::k, ml_kem_512::η1>(poly_vec, sigma, nonce);
// Apply NTT on that secret vector
kyber_utils::poly_vec_ntt<kyber512_kem::k>(poly_vec);
ml_kem_utils::poly_vec_ntt<ml_kem_512::k>(poly_vec);
// Apply iNTT on bit-reversed NTT form of secret polynomial vector
kyber_utils::poly_vec_intt<kyber512_kem::k>(poly_vec);
ml_kem_utils::poly_vec_intt<ml_kem_512::k>(poly_vec);
// Compress coefficients of polynomial vector
kyber_utils::poly_vec_compress<kyber512_kem::k, kyber512_kem::du>(poly_vec);
ml_kem_utils::poly_vec_compress<ml_kem_512::k, ml_kem_512::du>(poly_vec);
// Serialize polynomial vector into byte array
kyber_utils::poly_vec_encode<kyber512_kem::k, kyber512_kem::du>(poly_vec, byte_arr);
ml_kem_utils::poly_vec_encode<ml_kem_512::k, ml_kem_512::du>(poly_vec, byte_arr);
// Recover coefficients of polynomial vector from byte array
kyber_utils::poly_vec_decode<kyber512_kem::k, kyber512_kem::du>(byte_arr, poly_vec);
ml_kem_utils::poly_vec_decode<ml_kem_512::k, ml_kem_512::du>(byte_arr, poly_vec);
// Decompress coefficients of polynomial vector
kyber_utils::poly_vec_decompress<kyber512_kem::k, kyber512_kem::du>(poly_vec);
ml_kem_utils::poly_vec_decompress<ml_kem_512::k, ml_kem_512::du>(poly_vec);
std::array<uint8_t, SEED_LEN> sink{};
auto _sink = std::span(sink);
using ctxt_t = std::span<const uint8_t, kyber512_kem::CIPHER_TEXT_BYTE_LEN>;
using ctxt_t = std::span<const uint8_t, ml_kem_512::CIPHER_TEXT_BYTE_LEN>;
using seed_t = std::span<const uint8_t, SEED_LEN>;
// Ensure Fujisaki-Okamoto transform, used during decapsulation, is constant-time
const uint32_t cond = kyber_utils::ct_memcmp(ctxt_t(data + doff2, doff3 - doff2), ctxt_t(data + doff3, doff4 - doff3));
kyber_utils::ct_cond_memcpy(cond, _sink, seed_t(data + doff4, doff5 - doff4), seed_t(data + doff5, doff6 - doff5));
const uint32_t cond = ml_kem_utils::ct_memcmp(ctxt_t(data + doff2, doff3 - doff2), ctxt_t(data + doff3, doff4 - doff3));
ml_kem_utils::ct_cond_memcpy(cond, _sink, seed_t(data + doff4, doff5 - doff4), seed_t(data + doff5, doff6 - doff5));
// Just so that optimizer doesn't remove above function calls !
return static_cast<uint8_t>(poly_vec[0].raw() ^ poly_vec[poly_vec.size() - 1].raw()) ^ // result of generating vector of polynomials
@@ -70,14 +70,14 @@ prepare_inputs(dudect_config_t* const c, uint8_t* const input_data, uint8_t* con
}
dudect_state_t
test_kyber512_kem()
test_ml_kem_512()
{
constexpr size_t chunk_size = SEED_LEN + // bytes holding seed `sigma`
1 + // single byte nonce
kyber512_kem::CIPHER_TEXT_BYTE_LEN + // bytes holding received cipher text
kyber512_kem::CIPHER_TEXT_BYTE_LEN + // bytes for locally computed cipher text
SEED_LEN + // bytes for first source buffer to copy from
SEED_LEN; // bytes for second source buffer to copy from
constexpr size_t chunk_size = SEED_LEN + // bytes holding seed `sigma`
1 + // single byte nonce
ml_kem_512::CIPHER_TEXT_BYTE_LEN + // bytes holding received cipher text
ml_kem_512::CIPHER_TEXT_BYTE_LEN + // bytes for locally computed cipher text
SEED_LEN + // bytes for first source buffer to copy from
SEED_LEN; // bytes for second source buffer to copy from
constexpr size_t number_measurements = 1e5;
dudect_config_t config = {
@@ -101,7 +101,7 @@ test_kyber512_kem()
int
main()
{
if (test_kyber512_kem() != DUDECT_NO_LEAKAGE_EVIDENCE_YET) {
if (test_ml_kem_512() != DUDECT_NO_LEAKAGE_EVIDENCE_YET) {
return EXIT_FAILURE;
}

View File

@@ -1,4 +1,4 @@
#include "kyber/kyber1024_kem.hpp"
#include "ml_kem/ml_kem_768.hpp"
#define DUDECT_IMPLEMENTATION
#define DUDECT_VISIBLITY_STATIC
@@ -12,41 +12,41 @@ do_one_computation(uint8_t* const data)
constexpr size_t doff0 = 0;
constexpr size_t doff1 = doff0 + SEED_LEN;
constexpr size_t doff2 = doff1 + 1;
constexpr size_t doff3 = doff2 + kyber1024_kem::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff4 = doff3 + kyber1024_kem::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff3 = doff2 + ml_kem_768::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff4 = doff3 + ml_kem_768::CIPHER_TEXT_BYTE_LEN;
constexpr size_t doff5 = doff4 + SEED_LEN;
constexpr size_t doff6 = doff5 + SEED_LEN;
std::array<field::zq_t, kyber1024_kem::k * ntt::N> poly_vec{};
std::array<uint8_t, kyber1024_kem::k * 32 * kyber1024_kem::du> byte_arr{};
std::array<field::zq_t, ml_kem_768::k * ntt::N> poly_vec{};
std::array<uint8_t, ml_kem_768::k * 32 * ml_kem_768::du> byte_arr{};
auto sigma = std::span<const uint8_t, SEED_LEN>(data + doff0, doff1 - doff0);
const auto nonce = data[doff1];
// Generate new secret polynomial vector
kyber_utils::generate_vector<kyber1024_kem::k, kyber1024_kem::η1>(poly_vec, sigma, nonce);
ml_kem_utils::generate_vector<ml_kem_768::k, ml_kem_768::η1>(poly_vec, sigma, nonce);
// Apply NTT on that secret vector
kyber_utils::poly_vec_ntt<kyber1024_kem::k>(poly_vec);
ml_kem_utils::poly_vec_ntt<ml_kem_768::k>(poly_vec);
// Apply iNTT on bit-reversed NTT form of secret polynomial vector
kyber_utils::poly_vec_intt<kyber1024_kem::k>(poly_vec);
ml_kem_utils::poly_vec_intt<ml_kem_768::k>(poly_vec);
// Compress coefficients of polynomial vector
kyber_utils::poly_vec_compress<kyber1024_kem::k, kyber1024_kem::du>(poly_vec);
ml_kem_utils::poly_vec_compress<ml_kem_768::k, ml_kem_768::du>(poly_vec);
// Serialize polynomial vector into byte array
kyber_utils::poly_vec_encode<kyber1024_kem::k, kyber1024_kem::du>(poly_vec, byte_arr);
ml_kem_utils::poly_vec_encode<ml_kem_768::k, ml_kem_768::du>(poly_vec, byte_arr);
// Recover coefficients of polynomial vector from byte array
kyber_utils::poly_vec_decode<kyber1024_kem::k, kyber1024_kem::du>(byte_arr, poly_vec);
ml_kem_utils::poly_vec_decode<ml_kem_768::k, ml_kem_768::du>(byte_arr, poly_vec);
// Decompress coefficients of polynomial vector
kyber_utils::poly_vec_decompress<kyber1024_kem::k, kyber1024_kem::du>(poly_vec);
ml_kem_utils::poly_vec_decompress<ml_kem_768::k, ml_kem_768::du>(poly_vec);
std::array<uint8_t, SEED_LEN> sink{};
auto _sink = std::span(sink);
using ctxt_t = std::span<const uint8_t, kyber1024_kem::CIPHER_TEXT_BYTE_LEN>;
using ctxt_t = std::span<const uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN>;
using seed_t = std::span<const uint8_t, SEED_LEN>;
// Ensure Fujisaki-Okamoto transform, used during decapsulation, is constant-time
const uint32_t cond = kyber_utils::ct_memcmp(ctxt_t(data + doff2, doff3 - doff2), ctxt_t(data + doff3, doff4 - doff3));
kyber_utils::ct_cond_memcpy(cond, _sink, seed_t(data + doff4, doff5 - doff4), seed_t(data + doff5, doff6 - doff5));
const uint32_t cond = ml_kem_utils::ct_memcmp(ctxt_t(data + doff2, doff3 - doff2), ctxt_t(data + doff3, doff4 - doff3));
ml_kem_utils::ct_cond_memcpy(cond, _sink, seed_t(data + doff4, doff5 - doff4), seed_t(data + doff5, doff6 - doff5));
// Just so that optimizer doesn't remove above function calls !
return static_cast<uint8_t>(poly_vec[0].raw() ^ poly_vec[poly_vec.size() - 1].raw()) ^ // result of generating vector of polynomials
@@ -69,14 +69,14 @@ prepare_inputs(dudect_config_t* const c, uint8_t* const input_data, uint8_t* con
}
dudect_state_t
test_kyber1024_kem()
test_ml_kem_768()
{
constexpr size_t chunk_size = SEED_LEN + // bytes holding seed `sigma`
1 + // single byte nonce
kyber1024_kem::CIPHER_TEXT_BYTE_LEN + // bytes holding received cipher text
kyber1024_kem::CIPHER_TEXT_BYTE_LEN + // bytes for locally computed cipher text
SEED_LEN + // bytes for first source buffer to copy from
SEED_LEN; // bytes for second source buffer to copy from
constexpr size_t chunk_size = SEED_LEN + // bytes holding seed `sigma`
1 + // single byte nonce
ml_kem_768::CIPHER_TEXT_BYTE_LEN + // bytes holding received cipher text
ml_kem_768::CIPHER_TEXT_BYTE_LEN + // bytes for locally computed cipher text
SEED_LEN + // bytes for first source buffer to copy from
SEED_LEN; // bytes for second source buffer to copy from
constexpr size_t number_measurements = 1e5;
dudect_config_t config = {
@@ -100,7 +100,7 @@ test_kyber1024_kem()
int
main()
{
if (test_kyber1024_kem() != DUDECT_NO_LEAKAGE_EVIDENCE_YET) {
if (test_ml_kem_768() != DUDECT_NO_LEAKAGE_EVIDENCE_YET) {
return EXIT_FAILURE;
}

View File

@@ -1,4 +1,4 @@
#include "kyber/internals/poly/compression.hpp"
#include "ml_kem/internals/poly/compression.hpp"
#include <gtest/gtest.h>
// Decompression error that can happen for some given `d` s.t.
@@ -7,7 +7,7 @@
//
// |(x' - x) mod q| <= round(q / 2 ^ (d + 1))
//
// See eq. 2 of Kyber specification
// See eq. 2 of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t d>
static inline constexpr size_t
@@ -41,8 +41,8 @@ test_zq_compression()
for (size_t i = 0; i < itr_cnt; i++) {
const auto a = field::zq_t::random(prng);
const auto b = kyber_utils::compress<d>(a);
const auto c = kyber_utils::decompress<d>(b);
const auto b = ml_kem_utils::compress<d>(a);
const auto c = ml_kem_utils::decompress<d>(b);
const auto a_canon = a.raw();
const auto c_canon = c.raw();
@@ -64,7 +64,7 @@ test_zq_compression()
return res;
}
TEST(KyberKEM, CompressDecompressZq)
TEST(Ml_kemKEM, CompressDecompressZq)
{
EXPECT_TRUE((test_zq_compression<11, 1ul << 20>()));
EXPECT_TRUE((test_zq_compression<10, 1ul << 20>()));

View File

@@ -1,10 +1,10 @@
#include "kyber/internals/math/field.hpp"
#include "ml_kem/internals/math/field.hpp"
#include <gtest/gtest.h>
// Test functional correctness of Kyber prime field operations ( using
// Test functional correctness of Ml_kem prime field operations ( using
// Montgomery Arithmetic ), by running through multiple rounds of execution of
// field operations on randomly sampled field elements.
TEST(KyberKEM, ArithmeticOverZq)
TEST(Ml_kemKEM, ArithmeticOverZq)
{
constexpr size_t itr_cnt = 1ul << 20;
prng::prng_t<128> prng{};

View File

@@ -1,8 +1,7 @@
#include "kyber/internals/ml_kem.hpp"
#include "kyber/internals/utility/utils.hpp"
#include "ml_kem/internals/ml_kem.hpp"
#include <gtest/gtest.h>
// Given k, η1, η2, du, dv - Kyber parameters, this routine checks whether
// Given k, η1, η2, du, dv - ML-KEM parameters, this routine checks whether
//
// - A new key pair can be generated for key establishment over insecure channel
// - Key pair is for receiving party, its public key will be used by sender.
@@ -17,12 +16,12 @@
// works as expected.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv, size_t bit_security_level>
void
test_kyber_kem()
test_ml_kem_kem()
{
constexpr size_t slen = 32;
constexpr size_t pklen = kyber_utils::get_kem_public_key_len(k);
constexpr size_t sklen = kyber_utils::get_kem_secret_key_len(k);
constexpr size_t ctlen = kyber_utils::get_kem_cipher_text_len(k, du, dv);
constexpr size_t pklen = ml_kem_utils::get_kem_public_key_len(k);
constexpr size_t sklen = ml_kem_utils::get_kem_secret_key_len(k);
constexpr size_t ctlen = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
constexpr size_t sslen = 32;
std::vector<uint8_t> d(slen);
@@ -55,17 +54,17 @@ test_kyber_kem()
EXPECT_EQ(sender_key, receiver_key);
}
TEST(KyberKEM, Kyber512KeygenEncapsDecaps)
TEST(Ml_kemKEM, Ml_kem512KeygenEncapsDecaps)
{
test_kyber_kem<2, 3, 2, 10, 4, 128>();
test_ml_kem_kem<2, 3, 2, 10, 4, 128>();
}
TEST(KyberKEM, Kyber768KeygenEncapsDecaps)
TEST(Ml_kemKEM, Ml_kem768KeygenEncapsDecaps)
{
test_kyber_kem<3, 2, 2, 10, 4, 192>();
test_ml_kem_kem<3, 2, 2, 10, 4, 192>();
}
TEST(KyberKEM, Kyber1024KeygenEncapsDecaps)
TEST(Ml_kemKEM, Ml_kem1024KeygenEncapsDecaps)
{
test_kyber_kem<4, 2, 2, 11, 5, 256>();
test_ml_kem_kem<4, 2, 2, 11, 5, 256>();
}

View File

@@ -1,13 +1,11 @@
#include "kyber/kyber1024_kem.hpp"
#include "kyber/kyber512_kem.hpp"
#include "kyber/kyber768_kem.hpp"
#include <array>
#include "ml_kem/ml_kem_1024.hpp"
#include "ml_kem/ml_kem_512.hpp"
#include "ml_kem/ml_kem_768.hpp"
#include <charconv>
#include <fstream>
#include <gtest/gtest.h>
// Given a hex encoded string of length 2*L, this routine can be used for
// parsing it as a byte array of length L.
// Given a hex encoded string of length 2*L, this routine can be used for parsing it as a byte array of length L.
template<size_t L>
static inline std::array<uint8_t, L>
from_hex(std::string_view bytes)
@@ -34,17 +32,16 @@ from_hex(std::string_view bytes)
// Test if
//
// - Is Kyber512 KEM implemented correctly ?
// - Is ML-KEM-512 implemented correctly ?
// - Is it conformant with the specification ?
//
// using Known Answer Tests, generated following
// https://gist.github.com/itzmeanjan/c8f5bc9640d0f0bdd2437dfe364d7710.
TEST(KyberKEM, Kyber512KnownAnswerTests)
TEST(ML_KEM, ML_KEM_512_KnownAnswerTests)
{
using namespace std::literals;
namespace kyber512 = kyber512_kem;
const std::string kat_file = "./kats/kyber512.kat";
const std::string kat_file = "./kats/ml_kem_512.kat";
std::fstream file(kat_file);
while (true) {
@@ -75,11 +72,11 @@ TEST(KyberKEM, Kyber512KnownAnswerTests)
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<kyber512::PKEY_BYTE_LEN>(__pk);
auto ___pk = from_hex<ml_kem_512::PKEY_BYTE_LEN>(__pk);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<kyber512::SKEY_BYTE_LEN>(__sk);
auto ___sk = from_hex<ml_kem_512::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
@@ -87,101 +84,21 @@ TEST(KyberKEM, Kyber512KnownAnswerTests)
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<kyber512::CIPHER_TEXT_BYTE_LEN>(__ct);
auto ___ct = from_hex<ml_kem_512::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, kyber512::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, kyber512::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, kyber512::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, kyber512_kem::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
std::array<uint8_t, kyber512_kem::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
std::array<uint8_t, ml_kem_512::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, ml_kem_512::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, ml_kem_512::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
kyber512::keygen(___d, ___z, pkey, skey);
(void)kyber512::encapsulate(___m, pkey, ctxt, shrd_sec0);
kyber512::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);
EXPECT_EQ(___ct, ctxt);
EXPECT_EQ(___ss, shrd_sec0);
EXPECT_EQ(shrd_sec0, shrd_sec1);
std::string empty_line;
std::getline(file, empty_line);
} else {
break;
}
}
file.close();
}
TEST(KyberKEM, Kyber768KnownAnswerTests)
{
using namespace std::literals;
namespace kyber768 = kyber768_kem;
const std::string kat_file = "./kats/kyber768.kat";
std::fstream file(kat_file);
while (true) {
std::string d;
if (!std::getline(file, d).eof()) {
std::string z;
std::string pk;
std::string sk;
std::string m;
std::string ct;
std::string ss;
std::getline(file, z);
std::getline(file, pk);
std::getline(file, sk);
std::getline(file, m);
std::getline(file, ct);
std::getline(file, ss);
auto _d = std::string_view(d);
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
auto ___d = from_hex<32>(__d);
auto _z = std::string_view(z);
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
auto ___z = from_hex<32>(__z);
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<kyber768::PKEY_BYTE_LEN>(__pk);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<kyber768::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
auto ___m = from_hex<32>(__m);
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<kyber768::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, kyber768::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, kyber768::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, kyber768::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, 32> shrd_sec0{};
std::array<uint8_t, 32> shrd_sec1{};
kyber768::keygen(___d, ___z, pkey, skey);
(void)kyber768::encapsulate(___m, pkey, ctxt, shrd_sec0);
kyber768::decapsulate(skey, ctxt, shrd_sec1);
ml_kem_512::keygen(___d, ___z, pkey, skey);
(void)ml_kem_512::encapsulate(___m, pkey, ctxt, shrd_sec0);
ml_kem_512::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);
@@ -201,17 +118,16 @@ TEST(KyberKEM, Kyber768KnownAnswerTests)
// Test if
//
// - Is Kyber1024 KEM implemented correctly ?
// - Is ML-KEM-768 implemented correctly ?
// - Is it conformant with the specification ?
//
// using Known Answer Tests, generated following
// https://gist.github.com/itzmeanjan/c8f5bc9640d0f0bdd2437dfe364d7710.
TEST(KyberKEM, Kyber1024KnownAnswerTests)
TEST(ML_KEM, ML_KEM_768_KnownAnswerTests)
{
using namespace std::literals;
namespace kyber1024 = kyber1024_kem;
const std::string kat_file = "./kats/kyber1024.kat";
const std::string kat_file = "./kats/ml_kem_768.kat";
std::fstream file(kat_file);
while (true) {
@@ -242,11 +158,11 @@ TEST(KyberKEM, Kyber1024KnownAnswerTests)
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<kyber1024::PKEY_BYTE_LEN>(__pk);
auto ___pk = from_hex<ml_kem_768::PKEY_BYTE_LEN>(__pk);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<kyber1024::SKEY_BYTE_LEN>(__sk);
auto ___sk = from_hex<ml_kem_768::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
@@ -254,21 +170,107 @@ TEST(KyberKEM, Kyber1024KnownAnswerTests)
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<kyber1024::CIPHER_TEXT_BYTE_LEN>(__ct);
auto ___ct = from_hex<ml_kem_768::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, kyber1024::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, kyber1024::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, kyber1024::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, ml_kem_768::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, ml_kem_768::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, 32> shrd_sec0{};
std::array<uint8_t, 32> shrd_sec1{};
kyber1024::keygen(___d, ___z, pkey, skey);
(void)kyber1024::encapsulate(___m, pkey, ctxt, shrd_sec0);
kyber1024::decapsulate(skey, ctxt, shrd_sec1);
ml_kem_768::keygen(___d, ___z, pkey, skey);
(void)ml_kem_768::encapsulate(___m, pkey, ctxt, shrd_sec0);
ml_kem_768::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);
EXPECT_EQ(___ct, ctxt);
EXPECT_EQ(___ss, shrd_sec0);
EXPECT_EQ(shrd_sec0, shrd_sec1);
std::string empty_line;
std::getline(file, empty_line);
} else {
break;
}
}
file.close();
}
// Test if
//
// - Is ML-KEM-1024 implemented correctly ?
// - Is it conformant with the specification ?
//
// using Known Answer Tests, generated following
// https://gist.github.com/itzmeanjan/c8f5bc9640d0f0bdd2437dfe364d7710.
TEST(Ml_kemKEM, ML_KEM_1024_KnownAnswerTests)
{
using namespace std::literals;
const std::string kat_file = "./kats/ml_kem_1024.kat";
std::fstream file(kat_file);
while (true) {
std::string d;
if (!std::getline(file, d).eof()) {
std::string z;
std::string pk;
std::string sk;
std::string m;
std::string ct;
std::string ss;
std::getline(file, z);
std::getline(file, pk);
std::getline(file, sk);
std::getline(file, m);
std::getline(file, ct);
std::getline(file, ss);
auto _d = std::string_view(d);
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
auto ___d = from_hex<32>(__d);
auto _z = std::string_view(z);
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
auto ___z = from_hex<32>(__z);
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<ml_kem_1024::PKEY_BYTE_LEN>(__pk);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<ml_kem_1024::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
auto ___m = from_hex<32>(__m);
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<ml_kem_1024::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, ml_kem_1024::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, ml_kem_1024::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, ml_kem_1024::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, 32> shrd_sec0{};
std::array<uint8_t, 32> shrd_sec1{};
ml_kem_1024::keygen(___d, ___z, pkey, skey);
(void)ml_kem_1024::encapsulate(___m, pkey, ctxt, shrd_sec0);
ml_kem_1024::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);

View File

@@ -1,33 +0,0 @@
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/ntt.hpp"
#include <gtest/gtest.h>
#include <vector>
// Ensure functional correctness of (inverse) NTT implementation for degree-255
// polynomial over F_q | q = 3329, using following rule
//
// f <- random polynomial
// f' <- ntt(f)
// f'' <- intt(f')
//
// assert(f == f'')
TEST(KyberKEM, NumberTheoreticTransform)
{
std::vector<field::zq_t> poly_a(ntt::N);
std::vector<field::zq_t> poly_b(ntt::N);
auto _poly_a = std::span<field::zq_t, ntt::N>(poly_a);
auto _poly_b = std::span<field::zq_t, ntt::N>(poly_b);
prng::prng_t<128> prng{};
for (size_t i = 0; i < ntt::N; i++) {
_poly_a[i] = field::zq_t::random(prng);
}
std::copy(_poly_a.begin(), _poly_a.end(), _poly_b.begin());
ntt::ntt(_poly_b);
ntt::intt(_poly_b);
EXPECT_EQ(poly_a, poly_b);
}

View File

@@ -1,5 +1,5 @@
#include "kyber/internals/math/field.hpp"
#include "kyber/internals/poly/serialize.hpp"
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/serialize.hpp"
#include <cstdint>
#include <gtest/gtest.h>
#include <vector>
@@ -30,15 +30,15 @@ test_serialize_deserialize()
using poly_t = std::span<field::zq_t, ntt::N>;
using serialized_t = std::span<uint8_t, blen>;
kyber_utils::encode<l>(poly_t(src), serialized_t(bytes));
kyber_utils::decode<l>(serialized_t(bytes), poly_t(dst));
ml_kem_utils::encode<l>(poly_t(src), serialized_t(bytes));
ml_kem_utils::decode<l>(serialized_t(bytes), poly_t(dst));
for (size_t i = 0; i < ntt::N; i++) {
EXPECT_EQ((src[i].raw() & mask), (dst[i].raw() & mask));
}
}
TEST(KyberKEM, PolynomialSerialization)
TEST(Ml_kemKEM, PolynomialSerialization)
{
test_serialize_deserialize<12>();
test_serialize_deserialize<11>();