use google-test library for writing/ running tests

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2023-07-16 16:53:28 +04:00
parent 8760317253
commit a8512938f1
11 changed files with 210 additions and 284 deletions

View File

@@ -1,78 +0,0 @@
#pragma once
#include "kem.hpp"
#include "utils.hpp"
#include <cassert>
// Test functional correctness of Kyber PQC suite implementation
namespace test_kyber {
// Given k, η1, η2, du, dv - Kyber parameters, this routine checks whether
//
// - A new key pair can be generated for key establishment over insecure channel
// - Key pair is for receiving party, its public key will be used by sender.
// - Sender can produce a cipher text and a key derivation function ( KDF )
// - Sender uses receiver's public key.
// - Cipher text is sent over insecure channel to receiver
// - Receiver can decrypt message ( using secret key ) and arrives at same KDF
// - Both parties use KDF ( SHAKE256 hasher object ) to generate arbitrary
// length shared secret key.
// - This shared secret key can now be used with any symmetric key primitive.
//
// works as expected.
template<const size_t k,
const size_t eta1,
const size_t eta2,
const size_t du,
const size_t dv,
const size_t klen>
void
test_kyber_kem()
{
constexpr size_t slen = 32;
constexpr size_t pklen = kyber_utils::get_kem_public_key_len<k>();
constexpr size_t sklen = kyber_utils::get_kem_secret_key_len<k>();
constexpr size_t ctlen = kyber_utils::get_kem_cipher_len<k, du, dv>();
uint8_t* d = static_cast<uint8_t*>(std::malloc(slen));
uint8_t* z = static_cast<uint8_t*>(std::malloc(slen));
uint8_t* m = static_cast<uint8_t*>(std::malloc(slen));
uint8_t* pkey = static_cast<uint8_t*>(std::malloc(pklen));
uint8_t* skey = static_cast<uint8_t*>(std::malloc(sklen));
uint8_t* cipher = static_cast<uint8_t*>(std::malloc(ctlen));
uint8_t* sender_key = static_cast<uint8_t*>(std::malloc(klen));
uint8_t* receiver_key = static_cast<uint8_t*>(std::malloc(klen));
std::memset(pkey, 0, pklen);
std::memset(skey, 0, sklen);
std::memset(cipher, 0, ctlen);
prng::prng_t prng;
prng.read(d, slen);
prng.read(z, slen);
prng.read(m, slen);
kem::keygen<k, eta1>(d, z, pkey, skey);
auto skdf = kem::encapsulate<k, eta1, eta2, du, dv>(m, pkey, cipher);
auto rkdf = kem::decapsulate<k, eta1, eta2, du, dv>(skey, cipher);
skdf.squeeze(sender_key, klen);
rkdf.squeeze(receiver_key, klen);
bool flg = false;
for (size_t i = 0; i < klen; i++) {
flg |= static_cast<bool>(sender_key[i] ^ receiver_key[i]);
}
std::free(d);
std::free(z);
std::free(m);
std::free(pkey);
std::free(skey);
std::free(cipher);
std::free(sender_key);
std::free(receiver_key);
assert(!flg);
}
}

View File

@@ -1,8 +0,0 @@
#pragma once
#include "test_compression.hpp"
#include "test_field.hpp"
#include "test_kem.hpp"
#include "test_kem_kat.hpp"
#include "test_ntt.hpp"
#include "test_serialize.hpp"

View File

@@ -1,43 +0,0 @@
#pragma once
#include "ntt.hpp"
#include <cassert>
// Test functional correctness of Kyber PQC suite implementation
namespace test_kyber {
// Ensure functional correctness of (inverse) NTT implementation for degree-255
// polynomial over F_q | q = 3329, using following rule
//
// f -> random polynomial
// f' = ntt(f)
// f'' = intt(f')
//
// assert(f == f'')
inline void
test_ntt_intt()
{
constexpr size_t poly_len = sizeof(field::zq_t) * ntt::N;
field::zq_t* poly_a = static_cast<field::zq_t*>(std::malloc(poly_len));
field::zq_t* poly_b = static_cast<field::zq_t*>(std::malloc(poly_len));
prng::prng_t prng;
for (size_t i = 0; i < ntt::N; i++) {
poly_a[i] = field::zq_t::random(prng);
}
std::memcpy(poly_b, poly_a, poly_len);
ntt::ntt(poly_b);
ntt::intt(poly_b);
for (size_t i = 0; i < ntt::N; i++) {
assert(poly_a[i] == poly_b[i]);
}
std::free(poly_a);
std::free(poly_b);
}
}

View File

@@ -1,44 +0,0 @@
#pragma once
#include "serialize.hpp"
#include <cassert>
// Test functional correctness of Kyber PQC suite implementation
namespace test_kyber {
// Ensure that degree-255 polynomial serialization to byte array ( of length
// 32*l -bytes ) and deserialization of that byte array to degree-255 polynomial
// works as expected for parameterizable values of l | l ∈ [1, 12]
//
// l denotes significant bit width ( from LSB side ) for each coefficient of
// polynomial.
template<const size_t l>
void
test_serialization()
{
constexpr size_t plen = sizeof(field::zq_t) * ntt::N;
constexpr size_t blen = 32 * l;
constexpr uint32_t mask = (1u << l) - 1u;
field::zq_t* src = static_cast<field::zq_t*>(std::malloc(plen));
uint8_t* arr = static_cast<uint8_t*>(std::malloc(blen));
field::zq_t* dst = static_cast<field::zq_t*>(std::malloc(plen));
prng::prng_t prng;
for (size_t i = 0; i < ntt::N; i++) {
src[i] = field::zq_t::random(prng);
}
kyber_utils::encode<l>(src, arr);
kyber_utils::decode<l>(arr, dst);
for (size_t i = 0; i < ntt::N; i++) {
assert((src[i].to_canonical() & mask) == (dst[i].to_canonical() & mask));
}
std::free(src);
std::free(arr);
std::free(dst);
}
}

View File

@@ -1,39 +0,0 @@
#include "tests/test_kyber.hpp"
#include <iostream>
int
main()
{
test_kyber::test_field_ops();
std::cout << "[test] Kyber prime field operations\n";
test_kyber::test_ntt_intt();
std::cout << "[test] (i)NTT over degree-255 polynomial\n";
test_kyber::test_serialization<12>();
test_kyber::test_serialization<11>();
test_kyber::test_serialization<10>();
test_kyber::test_serialization<5>();
test_kyber::test_serialization<4>();
test_kyber::test_serialization<1>();
std::cout << "[test] Polynomial serialization/ deserialization\n";
test_kyber::test_compression<11>();
test_kyber::test_compression<10>();
test_kyber::test_compression<5>();
test_kyber::test_compression<4>();
test_kyber::test_compression<1>();
std::cout << "[test] Coefficient compression/ decompression\n";
test_kyber::test_kyber_kem<2, 3, 2, 10, 4, 32>(); // kyber-512, 32B -key
test_kyber::test_kyber_kem<3, 2, 2, 10, 4, 32>(); // kyber-768, 32B -key
test_kyber::test_kyber_kem<4, 2, 2, 11, 5, 32>(); // kyber-1024, 32B -key
std::cout << "[test] INDCCA2-secure Kyber KEM\n";
test_kyber::test_kyber512_kem_kat();
test_kyber::test_kyber768_kem_kat();
test_kyber::test_kyber1024_kem_kat();
std::cout << "[test] Kyber KEM Known Answer Tests\n";
return EXIT_SUCCESS;
}

View File

@@ -1,9 +1,5 @@
#pragma once
#include "compression.hpp"
#include <cassert>
// Test functional correctness of Kyber PQC suite implementation
namespace test_kyber {
#include <gtest/gtest.h>
// Test functional correctness of compression/ decompression logic s.t. given an
// element x ∈ Z_q following is satisfied
@@ -12,16 +8,18 @@ namespace test_kyber {
//
// |(x' - x) mod q| <= round(q / 2 ^ (d + 1))
//
// This test is executed a few times on some random Z_q elements, for some
// specified `d`.
template<const size_t d>
void
test_compression()
// Returned boolean accumulates result of all compression/ decompression
// execution iterations. It must hold truth value for function caller to believe
// that compression/ decompression logic is implemented correctly.
template<const size_t d, const size_t itr_cnt>
bool
test_zq_compression()
requires(itr_cnt > 0)
{
constexpr size_t cnt = 1024;
bool res = true;
prng::prng_t prng;
for (size_t i = 0; i < cnt; i++) {
for (size_t i = 0; i < itr_cnt; i++) {
const auto a = field::zq_t::random(prng);
const auto b = kyber_utils::compress<d>(a);
@@ -41,8 +39,17 @@ test_compression()
const size_t err = static_cast<size_t>(std::abs(c_prime - a_prime));
const size_t terr = kyber_utils::compute_error<d>();
assert(err <= terr);
res &= (err <= terr);
}
return res;
}
TEST(KyberKEM, CompressDecompressZq)
{
ASSERT_TRUE((test_zq_compression<11, 1ul << 20>()));
ASSERT_TRUE((test_zq_compression<10, 1ul << 20>()));
ASSERT_TRUE((test_zq_compression<5, 1ul << 20>()));
ASSERT_TRUE((test_zq_compression<4, 1ul << 20>()));
ASSERT_TRUE((test_zq_compression<1, 1ul << 20>()));
}

View File

@@ -1,17 +1,12 @@
#pragma once
#include "field.hpp"
#include <cassert>
// Test functional correctness of Kyber PQC suite implementation
namespace test_kyber {
#include <gtest/gtest.h>
// Test functional correctness of Kyber prime field operations ( using
// Montgomery Arithmetic ), by running through multiple rounds of execution of
// field operations on randomly sampled field elements
inline void
test_field_ops()
// field operations on randomly sampled field elements.
TEST(KyberKEM, ArithmeticOverZq)
{
constexpr size_t itr_cnt = 1ul << 10;
constexpr size_t itr_cnt = 1ul << 20;
prng::prng_t prng;
for (size_t i = 0; i < itr_cnt; i++) {
@@ -23,27 +18,24 @@ test_field_ops()
const auto d = c - b;
const auto e = c - a;
assert(d == a);
assert(e == b);
ASSERT_EQ(d, a);
ASSERT_EQ(e, b);
// Multiplication, Exponentiation, Inversion and Division
const auto f = a * b;
const auto g = f / b;
const auto h = f / a;
if (b != field::zq_t()) {
assert(g == a);
ASSERT_EQ(g, a);
} else {
assert(g == field::zq_t());
ASSERT_EQ(g, field::zq_t());
}
if (a != field::zq_t()) {
assert(h == b);
ASSERT_EQ(h, b);
} else {
assert(h == field::zq_t());
ASSERT_EQ(h, field::zq_t());
}
}
}
}

72
tests/test_kem.cpp Normal file
View File

@@ -0,0 +1,72 @@
#include "kem.hpp"
#include "utils.hpp"
#include <gtest/gtest.h>
// Given k, η1, η2, du, dv - Kyber parameters, this routine checks whether
//
// - A new key pair can be generated for key establishment over insecure channel
// - Key pair is for receiving party, its public key will be used by sender.
// - Sender can produce a cipher text and a key derivation function ( KDF )
// - Sender uses receiver's public key.
// - Cipher text is sent over insecure channel to receiver
// - Receiver can decrypt message ( using secret key ) and arrives at same KDF
// - Both parties use KDF ( SHAKE256 hasher object ) to generate arbitrary
// length shared secret key.
// - This shared secret key can now be used with any symmetric key primitive.
//
// works as expected.
template<const size_t k,
const size_t eta1,
const size_t eta2,
const size_t du,
const size_t dv,
const size_t klen>
void
test_kyber_kem()
requires(klen > 0)
{
constexpr size_t slen = 32;
constexpr size_t pklen = kyber_utils::get_kem_public_key_len<k>();
constexpr size_t sklen = kyber_utils::get_kem_secret_key_len<k>();
constexpr size_t ctlen = kyber_utils::get_kem_cipher_len<k, du, dv>();
std::vector<uint8_t> d(slen);
std::vector<uint8_t> z(slen);
std::vector<uint8_t> m(slen);
std::vector<uint8_t> pkey(pklen);
std::vector<uint8_t> skey(sklen);
std::vector<uint8_t> cipher(ctlen);
std::vector<uint8_t> sender_key(klen);
std::vector<uint8_t> receiver_key(klen);
prng::prng_t prng;
prng.read(d.data(), d.size());
prng.read(z.data(), z.size());
prng.read(m.data(), m.size());
kem::keygen<k, eta1>(d.data(), z.data(), pkey.data(), skey.data());
auto skdf = kem::encapsulate<k, eta1, eta2, du, dv>(
m.data(), pkey.data(), cipher.data());
auto rkdf =
kem::decapsulate<k, eta1, eta2, du, dv>(skey.data(), cipher.data());
skdf.squeeze(sender_key.data(), sender_key.size());
rkdf.squeeze(receiver_key.data(), receiver_key.size());
ASSERT_EQ(sender_key, receiver_key);
}
TEST(KyberKEM, Kyber512KeygenEncapsDecaps)
{
test_kyber_kem<2, 3, 2, 10, 4, 32>();
}
TEST(KyberKEM, Kyber768KeygenEncapsDecaps)
{
test_kyber_kem<3, 2, 2, 10, 4, 32>();
}
TEST(KyberKEM, Kyber1024KeygenEncapsDecaps)
{
test_kyber_kem<4, 2, 2, 11, 5, 32>();
}

View File

@@ -2,27 +2,22 @@
#include "kyber512_kem.hpp"
#include "kyber768_kem.hpp"
#include "utils.hpp"
#include <cassert>
#include <fstream>
// Test functional correctness of Kyber PQC suite implementation
namespace test_kyber {
using namespace std::literals;
namespace utils = kyber_utils;
namespace kyber512 = kyber512_kem;
namespace kyber768 = kyber768_kem;
namespace kyber1024 = kyber1024_kem;
#include <gtest/gtest.h>
// Test if
//
// - Is Kyber512 KEM implemented correctly ?
// - Is it conformant with the specification ?
//
// using Known Answer Tests.
inline void
test_kyber512_kem_kat()
// using Known Answer Tests, generated following
// https://gist.github.com/itzmeanjan/c8f5bc9640d0f0bdd2437dfe364d7710.
TEST(KyberKEM, Kyber512KnownAnswerTests)
{
using namespace std::literals;
namespace utils = kyber_utils;
namespace kyber512 = kyber512_kem;
const std::string kat_file = "./kats/kyber512.kat";
std::fstream file(kat_file);
@@ -85,11 +80,11 @@ test_kyber512_kem_kat()
skdf.squeeze(shrd_sec0.data(), shrd_sec0.size());
rkdf.squeeze(shrd_sec1.data(), shrd_sec1.size());
assert(std::ranges::equal(___pk, pkey));
assert(std::ranges::equal(___sk, skey));
assert(std::ranges::equal(___ct, ctxt));
assert(std::ranges::equal(___ss, shrd_sec0));
assert(std::ranges::equal(shrd_sec0, shrd_sec1));
ASSERT_EQ(___pk, pkey);
ASSERT_EQ(___sk, skey);
ASSERT_EQ(___ct, ctxt);
ASSERT_EQ(___ss, shrd_sec0);
ASSERT_EQ(shrd_sec0, shrd_sec1);
std::string empty_line;
std::getline(file, empty_line);
@@ -101,15 +96,12 @@ test_kyber512_kem_kat()
file.close();
}
// Test if
//
// - Is Kyber768 KEM implemented correctly ?
// - Is it conformant with the specification ?
//
// using Known Answer Tests.
inline void
test_kyber768_kem_kat()
TEST(KyberKEM, Kyber768KnownAnswerTests)
{
using namespace std::literals;
namespace utils = kyber_utils;
namespace kyber768 = kyber768_kem;
const std::string kat_file = "./kats/kyber768.kat";
std::fstream file(kat_file);
@@ -172,11 +164,11 @@ test_kyber768_kem_kat()
skdf.squeeze(shrd_sec0.data(), shrd_sec0.size());
rkdf.squeeze(shrd_sec1.data(), shrd_sec1.size());
assert(std::ranges::equal(___pk, pkey));
assert(std::ranges::equal(___sk, skey));
assert(std::ranges::equal(___ct, ctxt));
assert(std::ranges::equal(___ss, shrd_sec0));
assert(std::ranges::equal(shrd_sec0, shrd_sec1));
ASSERT_EQ(___pk, pkey);
ASSERT_EQ(___sk, skey);
ASSERT_EQ(___ct, ctxt);
ASSERT_EQ(___ss, shrd_sec0);
ASSERT_EQ(shrd_sec0, shrd_sec1);
std::string empty_line;
std::getline(file, empty_line);
@@ -193,10 +185,14 @@ test_kyber768_kem_kat()
// - Is Kyber1024 KEM implemented correctly ?
// - Is it conformant with the specification ?
//
// using Known Answer Tests.
inline void
test_kyber1024_kem_kat()
// using Known Answer Tests, generated following
// https://gist.github.com/itzmeanjan/c8f5bc9640d0f0bdd2437dfe364d7710.
TEST(KyberKEM, Kyber1024KnownAnswerTests)
{
using namespace std::literals;
namespace utils = kyber_utils;
namespace kyber1024 = kyber1024_kem;
const std::string kat_file = "./kats/kyber1024.kat";
std::fstream file(kat_file);
@@ -259,11 +255,11 @@ test_kyber1024_kem_kat()
skdf.squeeze(shrd_sec0.data(), shrd_sec0.size());
rkdf.squeeze(shrd_sec1.data(), shrd_sec1.size());
assert(std::ranges::equal(___pk, pkey));
assert(std::ranges::equal(___sk, skey));
assert(std::ranges::equal(___ct, ctxt));
assert(std::ranges::equal(___ss, shrd_sec0));
assert(std::ranges::equal(shrd_sec0, shrd_sec1));
ASSERT_EQ(___pk, pkey);
ASSERT_EQ(___sk, skey);
ASSERT_EQ(___ct, ctxt);
ASSERT_EQ(___ss, shrd_sec0);
ASSERT_EQ(shrd_sec0, shrd_sec1);
std::string empty_line;
std::getline(file, empty_line);
@@ -274,5 +270,3 @@ test_kyber1024_kem_kat()
file.close();
}
}

29
tests/test_ntt.cpp Normal file
View File

@@ -0,0 +1,29 @@
#include "ntt.hpp"
#include <gtest/gtest.h>
#include <vector>
// Ensure functional correctness of (inverse) NTT implementation for degree-255
// polynomial over F_q | q = 3329, using following rule
//
// f <- random polynomial
// f' <- ntt(f)
// f'' <- intt(f')
//
// assert(f == f'')
TEST(KyberKEM, NumberTheoreticTransform)
{
std::vector<field::zq_t> poly_a(ntt::N);
std::vector<field::zq_t> poly_b(ntt::N);
prng::prng_t prng;
for (size_t i = 0; i < ntt::N; i++) {
poly_a[i] = field::zq_t::random(prng);
poly_b[i] = poly_a[i];
}
ntt::ntt(poly_b.data());
ntt::intt(poly_b.data());
ASSERT_EQ(poly_a, poly_b);
}

44
tests/test_serialize.cpp Normal file
View File

@@ -0,0 +1,44 @@
#include "serialize.hpp"
#include <gtest/gtest.h>
#include <vector>
// Ensure that degree-255 polynomial serialization to byte array ( of length
// 32*l -bytes ) and deserialization of that byte array to degree-255 polynomial
// works as expected for parameterizable values of l | l ∈ [1, 12].
//
// l denotes significant bit width ( from LSB side ) for each coefficient of
// polynomial.
template<const size_t l>
void
test_serialize_deserialize()
{
constexpr size_t blen = (ntt::N * l) / 8;
constexpr uint32_t mask = (1u << l) - 1u;
std::vector<field::zq_t> src(ntt::N);
std::vector<field::zq_t> dst(ntt::N);
std::vector<uint8_t> bytes(blen);
prng::prng_t prng;
for (size_t i = 0; i < ntt::N; i++) {
src[i] = field::zq_t::random(prng);
}
kyber_utils::encode<l>(src.data(), bytes.data());
kyber_utils::decode<l>(bytes.data(), dst.data());
for (size_t i = 0; i < ntt::N; i++) {
ASSERT_EQ((src[i].to_canonical() & mask), (dst[i].to_canonical() & mask));
}
}
TEST(KyberKEM, PolynomialSerialization)
{
test_serialize_deserialize<12>();
test_serialize_deserialize<11>();
test_serialize_deserialize<10>();
test_serialize_deserialize<5>();
test_serialize_deserialize<4>();
test_serialize_deserialize<1>();
}