diff --git a/tests/test_compression.cpp b/tests/test_compression.cpp index 29b8fca..3c6b70e 100644 --- a/tests/test_compression.cpp +++ b/tests/test_compression.cpp @@ -1,4 +1,5 @@ #include "ml_kem/internals/poly/compression.hpp" +#include "ml_kem/internals/utility/force_inline.hpp" #include // Decompression error that can happen for some given `d` s.t. @@ -10,7 +11,7 @@ // See eq. 2 of Ml_kem specification // https://doi.org/10.6028/NIST.FIPS.203.ipd template -static inline constexpr size_t +forceinline constexpr size_t compute_error() { constexpr double t0 = static_cast(ml_kem_field::Q); diff --git a/tests/test_field.cpp b/tests/test_field.cpp index bd66c23..30fd436 100644 --- a/tests/test_field.cpp +++ b/tests/test_field.cpp @@ -6,11 +6,11 @@ // field operations on randomly sampled field elements. TEST(ML_KEM, ArithmeticOverZq) { - static constexpr size_t itr_cnt = 1ul << 20; + constexpr size_t ITERATION_COUNT = 1ul << 20; ml_kem_prng::prng_t<128> prng{}; - for (size_t i = 0; i < itr_cnt; i++) { + for (size_t i = 0; i < ITERATION_COUNT; i++) { const auto a = ml_kem_field::zq_t::random(prng); const auto b = ml_kem_field::zq_t::random(prng); @@ -27,13 +27,13 @@ TEST(ML_KEM, ArithmeticOverZq) const auto g = f / b; const auto h = f / a; - if (b != ml_kem_field::zq_t()) { + if (b != ml_kem_field::zq_t::zero()) { EXPECT_EQ(g, a); } else { EXPECT_EQ(g, ml_kem_field::zq_t()); } - if (a != ml_kem_field::zq_t()) { + if (a != ml_kem_field::zq_t::zero()) { EXPECT_EQ(h, b); } else { EXPECT_EQ(h, ml_kem_field::zq_t()); diff --git a/tests/test_helper.hpp b/tests/test_helper.hpp index cc90018..0589a9a 100644 --- a/tests/test_helper.hpp +++ b/tests/test_helper.hpp @@ -1,6 +1,7 @@ #pragma once #include "ml_kem/internals/math/field.hpp" #include "ml_kem/internals/rng/prng.hpp" +#include "ml_kem/internals/utility/force_inline.hpp" #include #include #include @@ -12,7 +13,7 @@ // Given a hex encoded string of length 2*L, this routine can be used for parsing it as a byte array of length L. template -static inline std::array +static forceinline std::array from_hex(std::string_view bytes) { const size_t blen = bytes.length(); @@ -35,10 +36,24 @@ from_hex(std::string_view bytes) return res; } +// Given a string of following format, this lambda function can extract out the hex string portion +// and then it can parse it, returning a byte array of requested length. +// +// DATA = 010203....0d0e0f +template +static forceinline std::array +extract_and_parse_hex_string(std::string_view in_str) +{ + using namespace std::literals; + + const auto hex_str = in_str.substr(in_str.find("="sv) + 2, in_str.size()); + return from_hex(hex_str); +}; + // Given a valid ML-KEM-{512, 768, 1024} public key, this function mutates the last coefficient // of serialized polynomial vector s.t. it produces a malformed (i.e. non-reduced) polynomial vector. template -static inline constexpr void +static forceinline constexpr void make_malformed_pubkey(std::span pubkey) { constexpr auto last_coeff_ends_at = pubkey_byte_len - 32; @@ -59,7 +74,7 @@ make_malformed_pubkey(std::span pubkey) // Given a ML-KEM-{512, 768, 1024} cipher text, this function flips a random bit of it, while sampling choice of random index from input PRNG. template -static inline constexpr void +static forceinline constexpr void random_bitflip_in_cipher_text(std::span cipher, ml_kem_prng::prng_t& prng) { size_t random_u64 = 0; diff --git a/tests/test_ml_kem_1024_kat.cpp b/tests/test_ml_kem_1024_kat.cpp index 95ef0c8..3642d4b 100644 --- a/tests/test_ml_kem_1024_kat.cpp +++ b/tests/test_ml_kem_1024_kat.cpp @@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_1024_KnownAnswerTests) std::fstream file(kat_file); while (true) { - std::string d; + std::string d_line; - if (!std::getline(file, d).eof()) { - std::string z; - std::string pk; - std::string sk; - std::string m; - std::string ct; - std::string ss; + if (!std::getline(file, d_line).eof()) { + std::string z_line; + std::string pk_line; + std::string sk_line; + std::string m_line; + std::string ct_line; + std::string ss_line; - std::getline(file, z); - std::getline(file, pk); - std::getline(file, sk); - std::getline(file, m); - std::getline(file, ct); - std::getline(file, ss); + std::getline(file, z_line); + std::getline(file, pk_line); + std::getline(file, sk_line); + std::getline(file, m_line); + std::getline(file, ct_line); + std::getline(file, ss_line); - auto _d = std::string_view(d); - auto __d = _d.substr(_d.find("="sv) + 2, _d.size()); - auto ___d = from_hex<32>(__d); + const auto d = extract_and_parse_hex_string(d_line); + const auto z = extract_and_parse_hex_string(z_line); + const auto pk = extract_and_parse_hex_string(pk_line); + const auto sk = extract_and_parse_hex_string(sk_line); + const auto m = extract_and_parse_hex_string(m_line); + const auto ct = extract_and_parse_hex_string(ct_line); + const auto ss = extract_and_parse_hex_string(ss_line); - auto _z = std::string_view(z); - auto __z = _z.substr(_z.find("="sv) + 2, _z.size()); - auto ___z = from_hex<32>(__z); + std::array computed_pkey{}; + std::array computed_skey{}; + std::array computed_ctxt{}; + std::array computed_shared_secret_sender{}; + std::array computed_shared_secret_receiver{}; - auto _pk = std::string_view(pk); - auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size()); - auto ___pk = from_hex(__pk); + ml_kem_1024::keygen(d, z, computed_pkey, computed_skey); + EXPECT_TRUE(ml_kem_1024::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender)); + ml_kem_1024::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver); - auto _sk = std::string_view(sk); - auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size()); - auto ___sk = from_hex(__sk); - - auto _m = std::string_view(m); - auto __m = _m.substr(_m.find("="sv) + 2, _m.size()); - auto ___m = from_hex<32>(__m); - - auto _ct = std::string_view(ct); - auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size()); - auto ___ct = from_hex(__ct); - - auto _ss = std::string_view(ss); - auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size()); - auto ___ss = from_hex<32>(__ss); - - std::array pkey{}; - std::array skey{}; - std::array ctxt{}; - std::array shrd_sec0{}; - std::array shrd_sec1{}; - - ml_kem_1024::keygen(___d, ___z, pkey, skey); - EXPECT_TRUE(ml_kem_1024::encapsulate(___m, pkey, ctxt, shrd_sec0)); - ml_kem_1024::decapsulate(skey, ctxt, shrd_sec1); - - EXPECT_EQ(___pk, pkey); - EXPECT_EQ(___sk, skey); - EXPECT_EQ(___ct, ctxt); - EXPECT_EQ(___ss, shrd_sec0); - EXPECT_EQ(shrd_sec0, shrd_sec1); + EXPECT_EQ(pk, computed_pkey); + EXPECT_EQ(sk, computed_skey); + EXPECT_EQ(ct, computed_ctxt); + EXPECT_EQ(ss, computed_shared_secret_sender); + EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver); std::string empty_line; std::getline(file, empty_line); diff --git a/tests/test_ml_kem_512_kat.cpp b/tests/test_ml_kem_512_kat.cpp index d23ce3f..78aa34e 100644 --- a/tests/test_ml_kem_512_kat.cpp +++ b/tests/test_ml_kem_512_kat.cpp @@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_512_KnownAnswerTests) std::fstream file(kat_file); while (true) { - std::string d; + std::string d_line; - if (!std::getline(file, d).eof()) { - std::string z; - std::string pk; - std::string sk; - std::string m; - std::string ct; - std::string ss; + if (!std::getline(file, d_line).eof()) { + std::string z_line; + std::string pk_line; + std::string sk_line; + std::string m_line; + std::string ct_line; + std::string ss_line; - std::getline(file, z); - std::getline(file, pk); - std::getline(file, sk); - std::getline(file, m); - std::getline(file, ct); - std::getline(file, ss); + std::getline(file, z_line); + std::getline(file, pk_line); + std::getline(file, sk_line); + std::getline(file, m_line); + std::getline(file, ct_line); + std::getline(file, ss_line); - auto _d = std::string_view(d); - auto __d = _d.substr(_d.find("="sv) + 2, _d.size()); - auto ___d = from_hex<32>(__d); + const auto d = extract_and_parse_hex_string(d_line); + const auto z = extract_and_parse_hex_string(z_line); + const auto pk = extract_and_parse_hex_string(pk_line); + const auto sk = extract_and_parse_hex_string(sk_line); + const auto m = extract_and_parse_hex_string(m_line); + const auto ct = extract_and_parse_hex_string(ct_line); + const auto ss = extract_and_parse_hex_string(ss_line); - auto _z = std::string_view(z); - auto __z = _z.substr(_z.find("="sv) + 2, _z.size()); - auto ___z = from_hex<32>(__z); + std::array computed_pkey{}; + std::array computed_skey{}; + std::array computed_ctxt{}; + std::array computed_shared_secret_sender{}; + std::array computed_shared_secret_receiver{}; - auto _pk = std::string_view(pk); - auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size()); - auto ___pk = from_hex(__pk); + ml_kem_512::keygen(d, z, computed_pkey, computed_skey); + EXPECT_TRUE(ml_kem_512::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender)); + ml_kem_512::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver); - auto _sk = std::string_view(sk); - auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size()); - auto ___sk = from_hex(__sk); - - auto _m = std::string_view(m); - auto __m = _m.substr(_m.find("="sv) + 2, _m.size()); - auto ___m = from_hex<32>(__m); - - auto _ct = std::string_view(ct); - auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size()); - auto ___ct = from_hex(__ct); - - auto _ss = std::string_view(ss); - auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size()); - auto ___ss = from_hex<32>(__ss); - - std::array pkey{}; - std::array skey{}; - std::array ctxt{}; - std::array shrd_sec0{}; - std::array shrd_sec1{}; - - ml_kem_512::keygen(___d, ___z, pkey, skey); - EXPECT_TRUE(ml_kem_512::encapsulate(___m, pkey, ctxt, shrd_sec0)); - ml_kem_512::decapsulate(skey, ctxt, shrd_sec1); - - EXPECT_EQ(___pk, pkey); - EXPECT_EQ(___sk, skey); - EXPECT_EQ(___ct, ctxt); - EXPECT_EQ(___ss, shrd_sec0); - EXPECT_EQ(shrd_sec0, shrd_sec1); + EXPECT_EQ(pk, computed_pkey); + EXPECT_EQ(sk, computed_skey); + EXPECT_EQ(ct, computed_ctxt); + EXPECT_EQ(ss, computed_shared_secret_sender); + EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver); std::string empty_line; std::getline(file, empty_line); diff --git a/tests/test_ml_kem_768_kat.cpp b/tests/test_ml_kem_768_kat.cpp index ca7dce7..051cdfa 100644 --- a/tests/test_ml_kem_768_kat.cpp +++ b/tests/test_ml_kem_768_kat.cpp @@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_768_KnownAnswerTests) std::fstream file(kat_file); while (true) { - std::string d; + std::string d_line; - if (!std::getline(file, d).eof()) { - std::string z; - std::string pk; - std::string sk; - std::string m; - std::string ct; - std::string ss; + if (!std::getline(file, d_line).eof()) { + std::string z_line; + std::string pk_line; + std::string sk_line; + std::string m_line; + std::string ct_line; + std::string ss_line; - std::getline(file, z); - std::getline(file, pk); - std::getline(file, sk); - std::getline(file, m); - std::getline(file, ct); - std::getline(file, ss); + std::getline(file, z_line); + std::getline(file, pk_line); + std::getline(file, sk_line); + std::getline(file, m_line); + std::getline(file, ct_line); + std::getline(file, ss_line); - auto _d = std::string_view(d); - auto __d = _d.substr(_d.find("="sv) + 2, _d.size()); - auto ___d = from_hex<32>(__d); + const auto d = extract_and_parse_hex_string(d_line); + const auto z = extract_and_parse_hex_string(z_line); + const auto pk = extract_and_parse_hex_string(pk_line); + const auto sk = extract_and_parse_hex_string(sk_line); + const auto m = extract_and_parse_hex_string(m_line); + const auto ct = extract_and_parse_hex_string(ct_line); + const auto ss = extract_and_parse_hex_string(ss_line); - auto _z = std::string_view(z); - auto __z = _z.substr(_z.find("="sv) + 2, _z.size()); - auto ___z = from_hex<32>(__z); + std::array computed_pkey{}; + std::array computed_skey{}; + std::array computed_ctxt{}; + std::array computed_shared_secret_sender{}; + std::array computed_shared_secret_receiver{}; - auto _pk = std::string_view(pk); - auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size()); - auto ___pk = from_hex(__pk); + ml_kem_768::keygen(d, z, computed_pkey, computed_skey); + EXPECT_TRUE(ml_kem_768::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender)); + ml_kem_768::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver); - auto _sk = std::string_view(sk); - auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size()); - auto ___sk = from_hex(__sk); - - auto _m = std::string_view(m); - auto __m = _m.substr(_m.find("="sv) + 2, _m.size()); - auto ___m = from_hex<32>(__m); - - auto _ct = std::string_view(ct); - auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size()); - auto ___ct = from_hex(__ct); - - auto _ss = std::string_view(ss); - auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size()); - auto ___ss = from_hex<32>(__ss); - - std::array pkey{}; - std::array skey{}; - std::array ctxt{}; - std::array shrd_sec0{}; - std::array shrd_sec1{}; - - ml_kem_768::keygen(___d, ___z, pkey, skey); - EXPECT_TRUE(ml_kem_768::encapsulate(___m, pkey, ctxt, shrd_sec0)); - ml_kem_768::decapsulate(skey, ctxt, shrd_sec1); - - EXPECT_EQ(___pk, pkey); - EXPECT_EQ(___sk, skey); - EXPECT_EQ(___ct, ctxt); - EXPECT_EQ(___ss, shrd_sec0); - EXPECT_EQ(shrd_sec0, shrd_sec1); + EXPECT_EQ(pk, computed_pkey); + EXPECT_EQ(sk, computed_skey); + EXPECT_EQ(ct, computed_ctxt); + EXPECT_EQ(ss, computed_shared_secret_sender); + EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver); std::string empty_line; std::getline(file, empty_line);