Refactor KAT test runner functions, reducing lines of code

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2024-09-01 23:41:37 +04:00
parent 53c0afa644
commit 5cb46afd16
6 changed files with 126 additions and 170 deletions

View File

@@ -1,4 +1,5 @@
#include "ml_kem/internals/poly/compression.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
#include <gtest/gtest.h>
// Decompression error that can happen for some given `d` s.t.
@@ -10,7 +11,7 @@
// See eq. 2 of Ml_kem specification
// https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t d>
static inline constexpr size_t
forceinline constexpr size_t
compute_error()
{
constexpr double t0 = static_cast<double>(ml_kem_field::Q);

View File

@@ -6,11 +6,11 @@
// field operations on randomly sampled field elements.
TEST(ML_KEM, ArithmeticOverZq)
{
static constexpr size_t itr_cnt = 1ul << 20;
constexpr size_t ITERATION_COUNT = 1ul << 20;
ml_kem_prng::prng_t<128> prng{};
for (size_t i = 0; i < itr_cnt; i++) {
for (size_t i = 0; i < ITERATION_COUNT; i++) {
const auto a = ml_kem_field::zq_t::random(prng);
const auto b = ml_kem_field::zq_t::random(prng);
@@ -27,13 +27,13 @@ TEST(ML_KEM, ArithmeticOverZq)
const auto g = f / b;
const auto h = f / a;
if (b != ml_kem_field::zq_t()) {
if (b != ml_kem_field::zq_t::zero()) {
EXPECT_EQ(g, a);
} else {
EXPECT_EQ(g, ml_kem_field::zq_t());
}
if (a != ml_kem_field::zq_t()) {
if (a != ml_kem_field::zq_t::zero()) {
EXPECT_EQ(h, b);
} else {
EXPECT_EQ(h, ml_kem_field::zq_t());

View File

@@ -1,6 +1,7 @@
#pragma once
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/rng/prng.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
#include <array>
#include <cassert>
#include <charconv>
@@ -12,7 +13,7 @@
// Given a hex encoded string of length 2*L, this routine can be used for parsing it as a byte array of length L.
template<size_t L>
static inline std::array<uint8_t, L>
static forceinline std::array<uint8_t, L>
from_hex(std::string_view bytes)
{
const size_t blen = bytes.length();
@@ -35,10 +36,24 @@ from_hex(std::string_view bytes)
return res;
}
// Given a string of following format, this lambda function can extract out the hex string portion
// and then it can parse it, returning a byte array of requested length.
//
// DATA = 010203....0d0e0f
template<size_t byte_len>
static forceinline std::array<uint8_t, byte_len>
extract_and_parse_hex_string(std::string_view in_str)
{
using namespace std::literals;
const auto hex_str = in_str.substr(in_str.find("="sv) + 2, in_str.size());
return from_hex<byte_len>(hex_str);
};
// Given a valid ML-KEM-{512, 768, 1024} public key, this function mutates the last coefficient
// of serialized polynomial vector s.t. it produces a malformed (i.e. non-reduced) polynomial vector.
template<size_t pubkey_byte_len>
static inline constexpr void
static forceinline constexpr void
make_malformed_pubkey(std::span<uint8_t, pubkey_byte_len> pubkey)
{
constexpr auto last_coeff_ends_at = pubkey_byte_len - 32;
@@ -59,7 +74,7 @@ make_malformed_pubkey(std::span<uint8_t, pubkey_byte_len> pubkey)
// Given a ML-KEM-{512, 768, 1024} cipher text, this function flips a random bit of it, while sampling choice of random index from input PRNG.
template<size_t cipher_byte_len, size_t bit_sec_lvl>
static inline constexpr void
static forceinline constexpr void
random_bitflip_in_cipher_text(std::span<uint8_t, cipher_byte_len> cipher, ml_kem_prng::prng_t<bit_sec_lvl>& prng)
{
size_t random_u64 = 0;

View File

@@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_1024_KnownAnswerTests)
std::fstream file(kat_file);
while (true) {
std::string d;
std::string d_line;
if (!std::getline(file, d).eof()) {
std::string z;
std::string pk;
std::string sk;
std::string m;
std::string ct;
std::string ss;
if (!std::getline(file, d_line).eof()) {
std::string z_line;
std::string pk_line;
std::string sk_line;
std::string m_line;
std::string ct_line;
std::string ss_line;
std::getline(file, z);
std::getline(file, pk);
std::getline(file, sk);
std::getline(file, m);
std::getline(file, ct);
std::getline(file, ss);
std::getline(file, z_line);
std::getline(file, pk_line);
std::getline(file, sk_line);
std::getline(file, m_line);
std::getline(file, ct_line);
std::getline(file, ss_line);
auto _d = std::string_view(d);
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
auto ___d = from_hex<32>(__d);
const auto d = extract_and_parse_hex_string<ml_kem_1024::SEED_D_BYTE_LEN>(d_line);
const auto z = extract_and_parse_hex_string<ml_kem_1024::SEED_Z_BYTE_LEN>(z_line);
const auto pk = extract_and_parse_hex_string<ml_kem_1024::PKEY_BYTE_LEN>(pk_line);
const auto sk = extract_and_parse_hex_string<ml_kem_1024::SKEY_BYTE_LEN>(sk_line);
const auto m = extract_and_parse_hex_string<ml_kem_1024::SEED_M_BYTE_LEN>(m_line);
const auto ct = extract_and_parse_hex_string<ml_kem_1024::CIPHER_TEXT_BYTE_LEN>(ct_line);
const auto ss = extract_and_parse_hex_string<ml_kem_1024::SHARED_SECRET_BYTE_LEN>(ss_line);
auto _z = std::string_view(z);
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
auto ___z = from_hex<32>(__z);
std::array<uint8_t, ml_kem_1024::PKEY_BYTE_LEN> computed_pkey{};
std::array<uint8_t, ml_kem_1024::SKEY_BYTE_LEN> computed_skey{};
std::array<uint8_t, ml_kem_1024::CIPHER_TEXT_BYTE_LEN> computed_ctxt{};
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> computed_shared_secret_sender{};
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> computed_shared_secret_receiver{};
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<ml_kem_1024::PKEY_BYTE_LEN>(__pk);
ml_kem_1024::keygen(d, z, computed_pkey, computed_skey);
EXPECT_TRUE(ml_kem_1024::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender));
ml_kem_1024::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<ml_kem_1024::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
auto ___m = from_hex<32>(__m);
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<ml_kem_1024::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, ml_kem_1024::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, ml_kem_1024::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, ml_kem_1024::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
std::array<uint8_t, ml_kem_1024::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
ml_kem_1024::keygen(___d, ___z, pkey, skey);
EXPECT_TRUE(ml_kem_1024::encapsulate(___m, pkey, ctxt, shrd_sec0));
ml_kem_1024::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);
EXPECT_EQ(___ct, ctxt);
EXPECT_EQ(___ss, shrd_sec0);
EXPECT_EQ(shrd_sec0, shrd_sec1);
EXPECT_EQ(pk, computed_pkey);
EXPECT_EQ(sk, computed_skey);
EXPECT_EQ(ct, computed_ctxt);
EXPECT_EQ(ss, computed_shared_secret_sender);
EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver);
std::string empty_line;
std::getline(file, empty_line);

View File

@@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_512_KnownAnswerTests)
std::fstream file(kat_file);
while (true) {
std::string d;
std::string d_line;
if (!std::getline(file, d).eof()) {
std::string z;
std::string pk;
std::string sk;
std::string m;
std::string ct;
std::string ss;
if (!std::getline(file, d_line).eof()) {
std::string z_line;
std::string pk_line;
std::string sk_line;
std::string m_line;
std::string ct_line;
std::string ss_line;
std::getline(file, z);
std::getline(file, pk);
std::getline(file, sk);
std::getline(file, m);
std::getline(file, ct);
std::getline(file, ss);
std::getline(file, z_line);
std::getline(file, pk_line);
std::getline(file, sk_line);
std::getline(file, m_line);
std::getline(file, ct_line);
std::getline(file, ss_line);
auto _d = std::string_view(d);
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
auto ___d = from_hex<32>(__d);
const auto d = extract_and_parse_hex_string<ml_kem_512::SEED_D_BYTE_LEN>(d_line);
const auto z = extract_and_parse_hex_string<ml_kem_512::SEED_Z_BYTE_LEN>(z_line);
const auto pk = extract_and_parse_hex_string<ml_kem_512::PKEY_BYTE_LEN>(pk_line);
const auto sk = extract_and_parse_hex_string<ml_kem_512::SKEY_BYTE_LEN>(sk_line);
const auto m = extract_and_parse_hex_string<ml_kem_512::SEED_M_BYTE_LEN>(m_line);
const auto ct = extract_and_parse_hex_string<ml_kem_512::CIPHER_TEXT_BYTE_LEN>(ct_line);
const auto ss = extract_and_parse_hex_string<ml_kem_512::SHARED_SECRET_BYTE_LEN>(ss_line);
auto _z = std::string_view(z);
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
auto ___z = from_hex<32>(__z);
std::array<uint8_t, ml_kem_512::PKEY_BYTE_LEN> computed_pkey{};
std::array<uint8_t, ml_kem_512::SKEY_BYTE_LEN> computed_skey{};
std::array<uint8_t, ml_kem_512::CIPHER_TEXT_BYTE_LEN> computed_ctxt{};
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> computed_shared_secret_sender{};
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> computed_shared_secret_receiver{};
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<ml_kem_512::PKEY_BYTE_LEN>(__pk);
ml_kem_512::keygen(d, z, computed_pkey, computed_skey);
EXPECT_TRUE(ml_kem_512::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender));
ml_kem_512::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<ml_kem_512::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
auto ___m = from_hex<32>(__m);
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<ml_kem_512::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, ml_kem_512::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, ml_kem_512::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, ml_kem_512::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
std::array<uint8_t, ml_kem_512::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
ml_kem_512::keygen(___d, ___z, pkey, skey);
EXPECT_TRUE(ml_kem_512::encapsulate(___m, pkey, ctxt, shrd_sec0));
ml_kem_512::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);
EXPECT_EQ(___ct, ctxt);
EXPECT_EQ(___ss, shrd_sec0);
EXPECT_EQ(shrd_sec0, shrd_sec1);
EXPECT_EQ(pk, computed_pkey);
EXPECT_EQ(sk, computed_skey);
EXPECT_EQ(ct, computed_ctxt);
EXPECT_EQ(ss, computed_shared_secret_sender);
EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver);
std::string empty_line;
std::getline(file, empty_line);

View File

@@ -18,66 +18,46 @@ TEST(ML_KEM, ML_KEM_768_KnownAnswerTests)
std::fstream file(kat_file);
while (true) {
std::string d;
std::string d_line;
if (!std::getline(file, d).eof()) {
std::string z;
std::string pk;
std::string sk;
std::string m;
std::string ct;
std::string ss;
if (!std::getline(file, d_line).eof()) {
std::string z_line;
std::string pk_line;
std::string sk_line;
std::string m_line;
std::string ct_line;
std::string ss_line;
std::getline(file, z);
std::getline(file, pk);
std::getline(file, sk);
std::getline(file, m);
std::getline(file, ct);
std::getline(file, ss);
std::getline(file, z_line);
std::getline(file, pk_line);
std::getline(file, sk_line);
std::getline(file, m_line);
std::getline(file, ct_line);
std::getline(file, ss_line);
auto _d = std::string_view(d);
auto __d = _d.substr(_d.find("="sv) + 2, _d.size());
auto ___d = from_hex<32>(__d);
const auto d = extract_and_parse_hex_string<ml_kem_768::SEED_D_BYTE_LEN>(d_line);
const auto z = extract_and_parse_hex_string<ml_kem_768::SEED_Z_BYTE_LEN>(z_line);
const auto pk = extract_and_parse_hex_string<ml_kem_768::PKEY_BYTE_LEN>(pk_line);
const auto sk = extract_and_parse_hex_string<ml_kem_768::SKEY_BYTE_LEN>(sk_line);
const auto m = extract_and_parse_hex_string<ml_kem_768::SEED_M_BYTE_LEN>(m_line);
const auto ct = extract_and_parse_hex_string<ml_kem_768::CIPHER_TEXT_BYTE_LEN>(ct_line);
const auto ss = extract_and_parse_hex_string<ml_kem_768::SHARED_SECRET_BYTE_LEN>(ss_line);
auto _z = std::string_view(z);
auto __z = _z.substr(_z.find("="sv) + 2, _z.size());
auto ___z = from_hex<32>(__z);
std::array<uint8_t, ml_kem_768::PKEY_BYTE_LEN> computed_pkey{};
std::array<uint8_t, ml_kem_768::SKEY_BYTE_LEN> computed_skey{};
std::array<uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN> computed_ctxt{};
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> computed_shared_secret_sender{};
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> computed_shared_secret_receiver{};
auto _pk = std::string_view(pk);
auto __pk = _pk.substr(_pk.find("="sv) + 2, _pk.size());
auto ___pk = from_hex<ml_kem_768::PKEY_BYTE_LEN>(__pk);
ml_kem_768::keygen(d, z, computed_pkey, computed_skey);
EXPECT_TRUE(ml_kem_768::encapsulate(m, computed_pkey, computed_ctxt, computed_shared_secret_sender));
ml_kem_768::decapsulate(computed_skey, computed_ctxt, computed_shared_secret_receiver);
auto _sk = std::string_view(sk);
auto __sk = _sk.substr(_sk.find("="sv) + 2, _sk.size());
auto ___sk = from_hex<ml_kem_768::SKEY_BYTE_LEN>(__sk);
auto _m = std::string_view(m);
auto __m = _m.substr(_m.find("="sv) + 2, _m.size());
auto ___m = from_hex<32>(__m);
auto _ct = std::string_view(ct);
auto __ct = _ct.substr(_ct.find("="sv) + 2, _ct.size());
auto ___ct = from_hex<ml_kem_768::CIPHER_TEXT_BYTE_LEN>(__ct);
auto _ss = std::string_view(ss);
auto __ss = _ss.substr(_ss.find("="sv) + 2, _ss.size());
auto ___ss = from_hex<32>(__ss);
std::array<uint8_t, ml_kem_768::PKEY_BYTE_LEN> pkey{};
std::array<uint8_t, ml_kem_768::SKEY_BYTE_LEN> skey{};
std::array<uint8_t, ml_kem_768::CIPHER_TEXT_BYTE_LEN> ctxt{};
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> shrd_sec0{};
std::array<uint8_t, ml_kem_768::SHARED_SECRET_BYTE_LEN> shrd_sec1{};
ml_kem_768::keygen(___d, ___z, pkey, skey);
EXPECT_TRUE(ml_kem_768::encapsulate(___m, pkey, ctxt, shrd_sec0));
ml_kem_768::decapsulate(skey, ctxt, shrd_sec1);
EXPECT_EQ(___pk, pkey);
EXPECT_EQ(___sk, skey);
EXPECT_EQ(___ct, ctxt);
EXPECT_EQ(___ss, shrd_sec0);
EXPECT_EQ(shrd_sec0, shrd_sec1);
EXPECT_EQ(pk, computed_pkey);
EXPECT_EQ(sk, computed_skey);
EXPECT_EQ(ct, computed_ctxt);
EXPECT_EQ(ss, computed_shared_secret_sender);
EXPECT_EQ(computed_shared_secret_sender, computed_shared_secret_receiver);
std::string empty_line;
std::getline(file, empty_line);