mirror of
https://github.com/itzmeanjan/ml-kem.git
synced 2026-01-09 15:47:55 -05:00
Refactor KAT test runner functions, reducing lines of code
Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include "ml_kem/internals/poly/compression.hpp"
|
||||
#include "ml_kem/internals/utility/force_inline.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// Decompression error that can happen for some given `d` s.t.
|
||||
@@ -10,7 +11,7 @@
|
||||
// See eq. 2 of Ml_kem specification
|
||||
// https://doi.org/10.6028/NIST.FIPS.203.ipd
|
||||
template<size_t d>
|
||||
static inline constexpr size_t
|
||||
forceinline constexpr size_t
|
||||
compute_error()
|
||||
{
|
||||
constexpr double t0 = static_cast<double>(ml_kem_field::Q);
|
||||
|
||||
@@ -6,11 +6,11 @@
|
||||
// field operations on randomly sampled field elements.
|
||||
TEST(ML_KEM, ArithmeticOverZq)
|
||||
{
|
||||
static constexpr size_t itr_cnt = 1ul << 20;
|
||||
constexpr size_t ITERATION_COUNT = 1ul << 20;
|
||||
|
||||
ml_kem_prng::prng_t<128> prng{};
|
||||
|
||||
for (size_t i = 0; i < itr_cnt; i++) {
|
||||
for (size_t i = 0; i < ITERATION_COUNT; i++) {
|
||||
const auto a = ml_kem_field::zq_t::random(prng);
|
||||
const auto b = ml_kem_field::zq_t::random(prng);
|
||||
|
||||
@@ -27,13 +27,13 @@ TEST(ML_KEM, ArithmeticOverZq)
|
||||
const auto g = f / b;
|
||||
const auto h = f / a;
|
||||
|
||||
if (b != ml_kem_field::zq_t()) {
|
||||
if (b != ml_kem_field::zq_t::zero()) {
|
||||
EXPECT_EQ(g, a);
|
||||
} else {
|
||||
EXPECT_EQ(g, ml_kem_field::zq_t());
|
||||
}
|
||||
|
||||
if (a != ml_kem_field::zq_t()) {
|
||||
if (a != ml_kem_field::zq_t::zero()) {
|
||||
EXPECT_EQ(h, b);
|
||||
} else {
|
||||
EXPECT_EQ(h, ml_kem_field::zq_t());
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
#include "ml_kem/internals/math/field.hpp"
|
||||
#include "ml_kem/internals/rng/prng.hpp"
|
||||
#include "ml_kem/internals/utility/force_inline.hpp"
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <charconv>
|
||||
@@ -12,7 +13,7 @@
|
||||
|
||||
// Given a hex encoded string of length 2*L, this routine can be used for parsing it as a byte array of length L.
|
||||
template<size_t L>
|
||||
static inline std::array<uint8_t, L>
|
||||
static forceinline std::array<uint8_t, L>
|
||||
from_hex(std::string_view bytes)
|
||||
{
|
||||
const size_t blen = bytes.length();
|
||||
@@ -35,10 +36,24 @@ from_hex(std::string_view bytes)
|
||||
return res;
|
||||
}
|
||||
|
||||
// Given a string of following format, this lambda function can extract out the hex string portion
|
||||
// and then it can parse it, returning a byte array of requested length.
|
||||
//
|
||||
// DATA = 010203....0d0e0f
|
||||
template<size_t byte_len>
|
||||
static forceinline std::array<uint8_t, byte_len>
|
||||
extract_and_parse_hex_string(std::string_view in_str)
|
||||
{
|
||||
using namespace std::literals;
|
||||
|
||||
const auto hex_str = in_str.substr(in_str.find("="sv) + 2, in_str.size());
|
||||
return from_hex<byte_len>(hex_str);
|
||||
};
|
||||
|
||||
// Given a valid ML-KEM-{512, 768, 1024} public key, this function mutates the last coefficient
|
||||
// of serialized polynomial vector s.t. it produces a malformed (i.e. non-reduced) polynomial vector.
|
||||
template<size_t pubkey_byte_len>
|
||||
static inline constexpr void
|
||||
static forceinline constexpr void
|
||||
make_malformed_pubkey(std::span<uint8_t, pubkey_byte_len> pubkey)
|
||||
{
|
||||
constexpr auto last_coeff_ends_at = pubkey_byte_len - 32;
|
||||
@@ -59,7 +74,7 @@ make_malformed_pubkey(std::span<uint8_t, pubkey_byte_len> pubkey)
|
||||
|
||||
// Given a ML-KEM-{512, 768, 1024} cipher text, this function flips a random bit of it, while sampling choice of random index from input PRNG.
|
||||
template<size_t cipher_byte_len, size_t bit_sec_lvl>
|
||||
static inline constexpr void
|
||||
static forceinline constexpr void
|
||||
random_bitflip_in_cipher_text(std::span<uint8_t, cipher_byte_len> cipher, ml_kem_prng::prng_t<bit_sec_lvl>& prng)
|
||||
{
|
||||
size_t random_u64 = 0;
|
||||
|
||||
@@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_1024_KnownAnswerTests)
|
||||
std::fstream file(kat_file);
|
||||
|
||||
while (true) {
|
||||
std::string d;
|
||||
std::string d_line;
|
||||
|
||||
if (!std::getline(file, d).eof()) {
|
||||
std::string z;
|
||||
std::string pk;
|
||||
std::string sk;
|
||||
std::string m;
|
||||
std::string ct;
|
||||
std::string ss;
|
||||
if (!std::getline(file, d_line).eof()) {
|
||||
std::string z_line;
|
||||
std::string pk_line;
|
||||
std::string sk_line;
|
||||
std::string m_line;
|
||||
std::string ct_line;
|
||||
std::string ss_line;
|
||||
|
||||
std::getline(file, z);
|
||||
std::getline(file, pk);
|
||||
std::getline(file, sk);
|
||||
std::getline(file, m);
|
||||
std::getline(file, ct);
|
||||
std::getline(file, ss);
|
||||
std::getline(file, z_line);
|
||||
std::getline(file, pk_line);
|
||||
std::getline(file, sk_line);
|
||||
std::getline(file, m_line);
|
||||
std::getline(file, ct_line);
|
||||
std::getline(file, ss_line);
|
||||
|
||||
auto _d = std::string_view(d);
|
||||
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
|
||||
auto ___d = from_hex<32>(__d);
|
||||
const auto d = extract_and_parse_hex_string<ml_kem_1024::SEED_D_BYTE_LEN>(d_line);
|
||||
const auto z = extract_and_parse_hex_string<ml_kem_1024::SEED_Z_BYTE_LEN>(z_line);
|
||||
const auto pk = extract_and_parse_hex_string<ml_kem_1024::PKEY_BYTE_LEN>(pk_line);
|
||||
const auto sk = extract_and_parse_hex_string<ml_kem_1024::SKEY_BYTE_LEN>(sk_line);
|
||||
const auto m = extract_and_parse_hex_string<ml_kem_1024::SEED_M_BYTE_LEN>(m_line);
|
||||
const auto ct = extract_and_parse_hex_string<ml_kem_1024::CIPHER_TEXT_BYTE_LEN>(ct_line);
|
||||
const auto ss = extract_and_parse_hex_string<ml_kem_1024::SHARED_SECRET_BYTE_LEN>(ss_line);
|
||||
|
||||
auto _z = std::string_view(z);
|
||||
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
|
||||
auto ___z = from_hex<32>(__z);
|
||||
std::array<uint8_t, ml_kem_1024::PKEY_BYTE_LEN> computed_pkey{};
|
||||
std::array<uint8_t, ml_kem_1024::SKEY_BYTE_LEN> computed_skey{};
|
||||
std::array<uint8_t, ml_kem_1024::CIPHER_TEXT_BYTE_LEN> computed_ctxt{};
|
||||
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> computed_shared_secret_sender{};
|
||||
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> computed_shared_secret_receiver{};
|
||||
|
||||
auto _pk = std::string_view(pk);
|
||||
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
|
||||
auto ___pk = from_hex<ml_kem_1024::PKEY_BYTE_LEN>(__pk);
|
||||
ml_kem_1024::keygen(d, z, computed_pkey, computed_skey);
|
||||
EXPECT_TRUE(ml_kem_1024::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender));
|
||||
ml_kem_1024::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver);
|
||||
|
||||
auto _sk = std::string_view(sk);
|
||||
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
|
||||
auto ___sk = from_hex<ml_kem_1024::SKEY_BYTE_LEN>(__sk);
|
||||
|
||||
auto _m = std::string_view(m);
|
||||
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
|
||||
auto ___m = from_hex<32>(__m);
|
||||
|
||||
auto _ct = std::string_view(ct);
|
||||
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
|
||||
auto ___ct = from_hex<ml_kem_1024::CIPHER_TEXT_BYTE_LEN>(__ct);
|
||||
|
||||
auto _ss = std::string_view(ss);
|
||||
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
|
||||
auto ___ss = from_hex<32>(__ss);
|
||||
|
||||
std::array<uint8_t, ml_kem_1024::PKEY_BYTE_LEN> pkey{};
|
||||
std::array<uint8_t, ml_kem_1024::SKEY_BYTE_LEN> skey{};
|
||||
std::array<uint8_t, ml_kem_1024::CIPHER_TEXT_BYTE_LEN> ctxt{};
|
||||
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
|
||||
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
|
||||
|
||||
ml_kem_1024::keygen(___d, ___z, pkey, skey);
|
||||
EXPECT_TRUE(ml_kem_1024::encapsulate(___m, pkey, ctxt, shrd_sec0));
|
||||
ml_kem_1024::decapsulate(skey, ctxt, shrd_sec1);
|
||||
|
||||
EXPECT_EQ(___pk, pkey);
|
||||
EXPECT_EQ(___sk, skey);
|
||||
EXPECT_EQ(___ct, ctxt);
|
||||
EXPECT_EQ(___ss, shrd_sec0);
|
||||
EXPECT_EQ(shrd_sec0, shrd_sec1);
|
||||
EXPECT_EQ(pk, computed_pkey);
|
||||
EXPECT_EQ(sk, computed_skey);
|
||||
EXPECT_EQ(ct, computed_ctxt);
|
||||
EXPECT_EQ(ss, computed_shared_secret_sender);
|
||||
EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver);
|
||||
|
||||
std::string empty_line;
|
||||
std::getline(file, empty_line);
|
||||
|
||||
@@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_512_KnownAnswerTests)
|
||||
std::fstream file(kat_file);
|
||||
|
||||
while (true) {
|
||||
std::string d;
|
||||
std::string d_line;
|
||||
|
||||
if (!std::getline(file, d).eof()) {
|
||||
std::string z;
|
||||
std::string pk;
|
||||
std::string sk;
|
||||
std::string m;
|
||||
std::string ct;
|
||||
std::string ss;
|
||||
if (!std::getline(file, d_line).eof()) {
|
||||
std::string z_line;
|
||||
std::string pk_line;
|
||||
std::string sk_line;
|
||||
std::string m_line;
|
||||
std::string ct_line;
|
||||
std::string ss_line;
|
||||
|
||||
std::getline(file, z);
|
||||
std::getline(file, pk);
|
||||
std::getline(file, sk);
|
||||
std::getline(file, m);
|
||||
std::getline(file, ct);
|
||||
std::getline(file, ss);
|
||||
std::getline(file, z_line);
|
||||
std::getline(file, pk_line);
|
||||
std::getline(file, sk_line);
|
||||
std::getline(file, m_line);
|
||||
std::getline(file, ct_line);
|
||||
std::getline(file, ss_line);
|
||||
|
||||
auto _d = std::string_view(d);
|
||||
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
|
||||
auto ___d = from_hex<32>(__d);
|
||||
const auto d = extract_and_parse_hex_string<ml_kem_512::SEED_D_BYTE_LEN>(d_line);
|
||||
const auto z = extract_and_parse_hex_string<ml_kem_512::SEED_Z_BYTE_LEN>(z_line);
|
||||
const auto pk = extract_and_parse_hex_string<ml_kem_512::PKEY_BYTE_LEN>(pk_line);
|
||||
const auto sk = extract_and_parse_hex_string<ml_kem_512::SKEY_BYTE_LEN>(sk_line);
|
||||
const auto m = extract_and_parse_hex_string<ml_kem_512::SEED_M_BYTE_LEN>(m_line);
|
||||
const auto ct = extract_and_parse_hex_string<ml_kem_512::CIPHER_TEXT_BYTE_LEN>(ct_line);
|
||||
const auto ss = extract_and_parse_hex_string<ml_kem_512::SHARED_SECRET_BYTE_LEN>(ss_line);
|
||||
|
||||
auto _z = std::string_view(z);
|
||||
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
|
||||
auto ___z = from_hex<32>(__z);
|
||||
std::array<uint8_t, ml_kem_512::PKEY_BYTE_LEN> computed_pkey{};
|
||||
std::array<uint8_t, ml_kem_512::SKEY_BYTE_LEN> computed_skey{};
|
||||
std::array<uint8_t, ml_kem_512::CIPHER_TEXT_BYTE_LEN> computed_ctxt{};
|
||||
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> computed_shared_secret_sender{};
|
||||
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> computed_shared_secret_receiver{};
|
||||
|
||||
auto _pk = std::string_view(pk);
|
||||
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
|
||||
auto ___pk = from_hex<ml_kem_512::PKEY_BYTE_LEN>(__pk);
|
||||
ml_kem_512::keygen(d, z, computed_pkey, computed_skey);
|
||||
EXPECT_TRUE(ml_kem_512::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender));
|
||||
ml_kem_512::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver);
|
||||
|
||||
auto _sk = std::string_view(sk);
|
||||
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
|
||||
auto ___sk = from_hex<ml_kem_512::SKEY_BYTE_LEN>(__sk);
|
||||
|
||||
auto _m = std::string_view(m);
|
||||
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
|
||||
auto ___m = from_hex<32>(__m);
|
||||
|
||||
auto _ct = std::string_view(ct);
|
||||
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
|
||||
auto ___ct = from_hex<ml_kem_512::CIPHER_TEXT_BYTE_LEN>(__ct);
|
||||
|
||||
auto _ss = std::string_view(ss);
|
||||
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
|
||||
auto ___ss = from_hex<32>(__ss);
|
||||
|
||||
std::array<uint8_t, ml_kem_512::PKEY_BYTE_LEN> pkey{};
|
||||
std::array<uint8_t, ml_kem_512::SKEY_BYTE_LEN> skey{};
|
||||
std::array<uint8_t, ml_kem_512::CIPHER_TEXT_BYTE_LEN> ctxt{};
|
||||
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
|
||||
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
|
||||
|
||||
ml_kem_512::keygen(___d, ___z, pkey, skey);
|
||||
EXPECT_TRUE(ml_kem_512::encapsulate(___m, pkey, ctxt, shrd_sec0));
|
||||
ml_kem_512::decapsulate(skey, ctxt, shrd_sec1);
|
||||
|
||||
EXPECT_EQ(___pk, pkey);
|
||||
EXPECT_EQ(___sk, skey);
|
||||
EXPECT_EQ(___ct, ctxt);
|
||||
EXPECT_EQ(___ss, shrd_sec0);
|
||||
EXPECT_EQ(shrd_sec0, shrd_sec1);
|
||||
EXPECT_EQ(pk, computed_pkey);
|
||||
EXPECT_EQ(sk, computed_skey);
|
||||
EXPECT_EQ(ct, computed_ctxt);
|
||||
EXPECT_EQ(ss, computed_shared_secret_sender);
|
||||
EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver);
|
||||
|
||||
std::string empty_line;
|
||||
std::getline(file, empty_line);
|
||||
|
||||
@@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_768_KnownAnswerTests)
|
||||
std::fstream file(kat_file);
|
||||
|
||||
while (true) {
|
||||
std::string d;
|
||||
std::string d_line;
|
||||
|
||||
if (!std::getline(file, d).eof()) {
|
||||
std::string z;
|
||||
std::string pk;
|
||||
std::string sk;
|
||||
std::string m;
|
||||
std::string ct;
|
||||
std::string ss;
|
||||
if (!std::getline(file, d_line).eof()) {
|
||||
std::string z_line;
|
||||
std::string pk_line;
|
||||
std::string sk_line;
|
||||
std::string m_line;
|
||||
std::string ct_line;
|
||||
std::string ss_line;
|
||||
|
||||
std::getline(file, z);
|
||||
std::getline(file, pk);
|
||||
std::getline(file, sk);
|
||||
std::getline(file, m);
|
||||
std::getline(file, ct);
|
||||
std::getline(file, ss);
|
||||
std::getline(file, z_line);
|
||||
std::getline(file, pk_line);
|
||||
std::getline(file, sk_line);
|
||||
std::getline(file, m_line);
|
||||
std::getline(file, ct_line);
|
||||
std::getline(file, ss_line);
|
||||
|
||||
auto _d = std::string_view(d);
|
||||
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
|
||||
auto ___d = from_hex<32>(__d);
|
||||
const auto d = extract_and_parse_hex_string<ml_kem_768::SEED_D_BYTE_LEN>(d_line);
|
||||
const auto z = extract_and_parse_hex_string<ml_kem_768::SEED_Z_BYTE_LEN>(z_line);
|
||||
const auto pk = extract_and_parse_hex_string<ml_kem_768::PKEY_BYTE_LEN>(pk_line);
|
||||
const auto sk = extract_and_parse_hex_string<ml_kem_768::SKEY_BYTE_LEN>(sk_line);
|
||||
const auto m = extract_and_parse_hex_string<ml_kem_768::SEED_M_BYTE_LEN>(m_line);
|
||||
const auto ct = extract_and_parse_hex_string<ml_kem_768::CIPHER_TEXT_BYTE_LEN>(ct_line);
|
||||
const auto ss = extract_and_parse_hex_string<ml_kem_768::SHARED_SECRET_BYTE_LEN>(ss_line);
|
||||
|
||||
auto _z = std::string_view(z);
|
||||
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
|
||||
auto ___z = from_hex<32>(__z);
|
||||
std::array<uint8_t, ml_kem_768::PKEY_BYTE_LEN> computed_pkey{};
|
||||
std::array<uint8_t, ml_kem_768::SKEY_BYTE_LEN> computed_skey{};
|
||||
std::array<uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN> computed_ctxt{};
|
||||
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> computed_shared_secret_sender{};
|
||||
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> computed_shared_secret_receiver{};
|
||||
|
||||
auto _pk = std::string_view(pk);
|
||||
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
|
||||
auto ___pk = from_hex<ml_kem_768::PKEY_BYTE_LEN>(__pk);
|
||||
ml_kem_768::keygen(d, z, computed_pkey, computed_skey);
|
||||
EXPECT_TRUE(ml_kem_768::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender));
|
||||
ml_kem_768::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver);
|
||||
|
||||
auto _sk = std::string_view(sk);
|
||||
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
|
||||
auto ___sk = from_hex<ml_kem_768::SKEY_BYTE_LEN>(__sk);
|
||||
|
||||
auto _m = std::string_view(m);
|
||||
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
|
||||
auto ___m = from_hex<32>(__m);
|
||||
|
||||
auto _ct = std::string_view(ct);
|
||||
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
|
||||
auto ___ct = from_hex<ml_kem_768::CIPHER_TEXT_BYTE_LEN>(__ct);
|
||||
|
||||
auto _ss = std::string_view(ss);
|
||||
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
|
||||
auto ___ss = from_hex<32>(__ss);
|
||||
|
||||
std::array<uint8_t, ml_kem_768::PKEY_BYTE_LEN> pkey{};
|
||||
std::array<uint8_t, ml_kem_768::SKEY_BYTE_LEN> skey{};
|
||||
std::array<uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN> ctxt{};
|
||||
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
|
||||
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
|
||||
|
||||
ml_kem_768::keygen(___d, ___z, pkey, skey);
|
||||
EXPECT_TRUE(ml_kem_768::encapsulate(___m, pkey, ctxt, shrd_sec0));
|
||||
ml_kem_768::decapsulate(skey, ctxt, shrd_sec1);
|
||||
|
||||
EXPECT_EQ(___pk, pkey);
|
||||
EXPECT_EQ(___sk, skey);
|
||||
EXPECT_EQ(___ct, ctxt);
|
||||
EXPECT_EQ(___ss, shrd_sec0);
|
||||
EXPECT_EQ(shrd_sec0, shrd_sec1);
|
||||
EXPECT_EQ(pk, computed_pkey);
|
||||
EXPECT_EQ(sk, computed_skey);
|
||||
EXPECT_EQ(ct, computed_ctxt);
|
||||
EXPECT_EQ(ss, computed_shared_secret_sender);
|
||||
EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver);
|
||||
|
||||
std::string empty_line;
|
||||
std::getline(file, empty_line);
|
||||
|
||||
Reference in New Issue
Block a user