mirror of
https://github.com/itzmeanjan/ml-kem.git
synced 2026-01-09 15:47:55 -05:00
update all remaining function interfaces to use statically defined std::span
Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
@@ -75,4 +75,4 @@ poly_decompress(std::span<field::zq_t, ntt::N> poly)
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kyber_utils
|
||||
}
|
||||
|
||||
@@ -205,4 +205,4 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace field
|
||||
}
|
||||
|
||||
@@ -115,7 +115,6 @@ encapsulate(
|
||||
h512.absorb(_g_in);
|
||||
h512.finalize();
|
||||
h512.digest(_g_out);
|
||||
h512.reset();
|
||||
|
||||
pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, cipher);
|
||||
std::copy(_g_out0.begin(), _g_out0.end(), _kdf_in0.begin());
|
||||
@@ -189,7 +188,6 @@ decapsulate(
|
||||
h512.absorb(_g_in);
|
||||
h512.finalize();
|
||||
h512.digest(_g_out);
|
||||
h512.reset();
|
||||
|
||||
pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, c_prime);
|
||||
|
||||
@@ -214,4 +212,4 @@ decapsulate(
|
||||
return xof256;
|
||||
}
|
||||
|
||||
} // namespace kem
|
||||
}
|
||||
|
||||
@@ -64,4 +64,4 @@ decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
|
||||
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
|
||||
}
|
||||
|
||||
} // namespace kyber1024_kem
|
||||
}
|
||||
|
||||
@@ -64,4 +64,4 @@ decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
|
||||
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
|
||||
}
|
||||
|
||||
} // namespace kyber512_kem
|
||||
}
|
||||
|
||||
@@ -63,4 +63,4 @@ decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
|
||||
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
|
||||
}
|
||||
|
||||
} // namespace kyber768_kem
|
||||
}
|
||||
|
||||
@@ -103,14 +103,14 @@ constexpr std::array<field::zq_t, N / 2> POLY_MUL_ζ_EXP = compute_mul_ζ();
|
||||
// Implementation inspired from
|
||||
// https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L69-L144
|
||||
inline void
|
||||
ntt(std::span<field::zq_t> poly)
|
||||
ntt(std::span<field::zq_t, N> poly)
|
||||
{
|
||||
for (size_t l = LOG2N - 1; l >= 1; l--) {
|
||||
const size_t len = 1ul << l;
|
||||
const size_t lenx2 = len << 1;
|
||||
const size_t k_beg = N >> (l + 1);
|
||||
|
||||
for (size_t start = 0; start < N; start += lenx2) {
|
||||
for (size_t start = 0; start < poly.size(); start += lenx2) {
|
||||
const size_t k_now = k_beg + (start >> (l + 1));
|
||||
// Looking up precomputed constant, though it can be computed using
|
||||
//
|
||||
@@ -140,14 +140,14 @@ ntt(std::span<field::zq_t> poly)
|
||||
// Implementation inspired from
|
||||
// https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L146-L224
|
||||
inline void
|
||||
intt(std::span<field::zq_t> poly)
|
||||
intt(std::span<field::zq_t, N> poly)
|
||||
{
|
||||
for (size_t l = 1; l < LOG2N; l++) {
|
||||
const size_t len = 1ul << l;
|
||||
const size_t lenx2 = len << 1;
|
||||
const size_t k_beg = (N >> l) - 1;
|
||||
|
||||
for (size_t start = 0; start < N; start += lenx2) {
|
||||
for (size_t start = 0; start < poly.size(); start += lenx2) {
|
||||
const size_t k_now = k_beg - (start >> (l + 1));
|
||||
// Looking up precomputed constant, though it can be computed using
|
||||
//
|
||||
@@ -185,10 +185,10 @@ intt(std::span<field::zq_t> poly)
|
||||
// See page 6 of Kyber specification
|
||||
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
|
||||
static inline void
|
||||
basemul(std::span<const field::zq_t> f, // degree-1 polynomial
|
||||
std::span<const field::zq_t> g, // degree-1 polynomial
|
||||
std::span<field::zq_t> h, // degree-1 polynomial
|
||||
const field::zq_t ζ // zeta
|
||||
basemul(std::span<const field::zq_t, 2> f, // degree-1 polynomial
|
||||
std::span<const field::zq_t, 2> g, // degree-1 polynomial
|
||||
std::span<field::zq_t, 2> h, // degree-1 polynomial
|
||||
const field::zq_t ζ // zeta
|
||||
)
|
||||
{
|
||||
field::zq_t f0 = f[0];
|
||||
@@ -219,20 +219,23 @@ basemul(std::span<const field::zq_t> f, // degree-1 polynomial
|
||||
//
|
||||
// h = f ◦ g
|
||||
inline void
|
||||
polymul(std::span<const field::zq_t> f, // degree-255 polynomial
|
||||
std::span<const field::zq_t> g, // degree-255 polynomial
|
||||
std::span<field::zq_t> h // degree-255 polynomial
|
||||
polymul(std::span<const field::zq_t, N> f, // degree-255 polynomial
|
||||
std::span<const field::zq_t, N> g, // degree-255 polynomial
|
||||
std::span<field::zq_t, N> h // degree-255 polynomial
|
||||
)
|
||||
{
|
||||
constexpr size_t cnt = N >> 1;
|
||||
constexpr size_t cnt = f.size() >> 1;
|
||||
|
||||
using poly_t = std::span<const field::zq_t, 2>;
|
||||
using mut_poly_t = std::span<field::zq_t, 2>;
|
||||
|
||||
for (size_t i = 0; i < cnt; i++) {
|
||||
const size_t off = i << 1;
|
||||
basemul(f.subspan(off, 2),
|
||||
g.subspan(off, 2),
|
||||
h.subspan(off, 2),
|
||||
basemul(poly_t(f.subspan(off, 2)),
|
||||
poly_t(g.subspan(off, 2)),
|
||||
mut_poly_t(h.subspan(off, 2)),
|
||||
POLY_MUL_ζ_EXP[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ntt
|
||||
}
|
||||
|
||||
@@ -135,4 +135,4 @@ check_decap_params(const size_t k,
|
||||
return check_encap_params(k, η1, η2, du, dv);
|
||||
}
|
||||
|
||||
} // namespace kyber_params
|
||||
}
|
||||
|
||||
@@ -41,10 +41,9 @@ keygen(std::span<const uint8_t, 32> d,
|
||||
h512.absorb(d);
|
||||
h512.finalize();
|
||||
h512.digest(_g_out);
|
||||
h512.reset();
|
||||
|
||||
const auto rho = _g_out.template subspan<0, 32>();
|
||||
const auto sigma = _g_out.template subspan<32, 32>();
|
||||
const auto sigma = _g_out.template subspan<rho.size(), 32>();
|
||||
|
||||
// step 4, 5, 6, 7, 8
|
||||
std::array<field::zq_t, k * k * ntt::N> A_prime{};
|
||||
@@ -210,4 +209,4 @@ decrypt(std::span<const uint8_t, k * 12 * 32> seckey,
|
||||
kyber_utils::encode<1>(v, dec);
|
||||
}
|
||||
|
||||
} // namespace pke
|
||||
}
|
||||
|
||||
@@ -20,7 +20,10 @@ matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
|
||||
std::span<field::zq_t, a_rows * b_cols * ntt::N> c)
|
||||
requires(kyber_params::check_matrix_dim(a_cols, b_rows))
|
||||
{
|
||||
using poly_t = std::span<const field::zq_t, ntt::N>;
|
||||
|
||||
std::array<field::zq_t, ntt::N> tmp{};
|
||||
auto _tmp = std::span(tmp);
|
||||
|
||||
for (size_t i = 0; i < a_rows; i++) {
|
||||
for (size_t j = 0; j < b_cols; j++) {
|
||||
@@ -30,7 +33,9 @@ matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
|
||||
const size_t aoff = (i * a_cols + k) * ntt::N;
|
||||
const size_t boff = (k * b_cols + j) * ntt::N;
|
||||
|
||||
ntt::polymul(a.subspan(aoff, ntt::N), b.subspan(boff, ntt::N), tmp);
|
||||
ntt::polymul(poly_t(a.subspan(aoff, ntt::N)),
|
||||
poly_t(b.subspan(boff, ntt::N)),
|
||||
_tmp);
|
||||
|
||||
for (size_t l = 0; l < ntt::N; l++) {
|
||||
c[coff + l] += tmp[l];
|
||||
@@ -48,9 +53,11 @@ static inline void
|
||||
poly_vec_ntt(std::span<field::zq_t, k * ntt::N> vec)
|
||||
requires((k == 1) || kyber_params::check_k(k))
|
||||
{
|
||||
using poly_t = std::span<field::zq_t, ntt::N>;
|
||||
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
const size_t off = i * ntt::N;
|
||||
ntt::ntt(vec.subspan(off, ntt::N));
|
||||
ntt::ntt(poly_t(vec.subspan(off, ntt::N)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,9 +70,11 @@ static inline void
|
||||
poly_vec_intt(std::span<field::zq_t, k * ntt::N> vec)
|
||||
requires((k == 1) || kyber_params::check_k(k))
|
||||
{
|
||||
using poly_t = std::span<field::zq_t, ntt::N>;
|
||||
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
const size_t off = i * ntt::N;
|
||||
ntt::intt(vec.subspan(off, ntt::N));
|
||||
ntt::intt(poly_t(vec.subspan(off, ntt::N)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,4 +180,4 @@ poly_vec_decompress(std::span<field::zq_t, k * ntt::N> vec)
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kyber_utils
|
||||
}
|
||||
|
||||
@@ -61,4 +61,4 @@ public:
|
||||
inline void read(std::span<uint8_t> bytes) { state.squeeze(bytes); }
|
||||
};
|
||||
|
||||
} // namespace prng
|
||||
}
|
||||
|
||||
@@ -20,11 +20,9 @@ namespace kyber_utils {
|
||||
// See algorithm 1, defined in Kyber specification
|
||||
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
|
||||
inline void
|
||||
parse(shake128::shake128_t& hasher, // Squeezes bytes
|
||||
std::span<field::zq_t> poly // Degree 255 polynomial
|
||||
)
|
||||
parse(shake128::shake128_t& hasher, std::span<field::zq_t, ntt::N> poly)
|
||||
{
|
||||
constexpr size_t n = ntt::N;
|
||||
constexpr size_t n = poly.size();
|
||||
|
||||
size_t coeff_idx = 0;
|
||||
std::array<uint8_t, shake128::RATE / 8> buf{};
|
||||
@@ -32,7 +30,7 @@ parse(shake128::shake128_t& hasher, // Squeezes bytes
|
||||
while (coeff_idx < n) {
|
||||
hasher.squeeze(buf);
|
||||
|
||||
for (size_t off = 0; (off < sizeof(buf)) && (coeff_idx < n); off += 3) {
|
||||
for (size_t off = 0; (off < buf.size()) && (coeff_idx < n); off += 3) {
|
||||
const uint16_t d1 = (static_cast<uint16_t>(buf[off + 1] & 0x0f) << 8) |
|
||||
(static_cast<uint16_t>(buf[off + 0]) << 0);
|
||||
const uint16_t d2 = (static_cast<uint16_t>(buf[off + 2]) << 4) |
|
||||
@@ -81,7 +79,9 @@ generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat,
|
||||
shake128::shake128_t hasher{};
|
||||
hasher.absorb(xof_in);
|
||||
hasher.finalize();
|
||||
parse(hasher, mat.subspan(off, ntt::N));
|
||||
|
||||
using poly_t = std::span<field::zq_t, mat.size() / (k * k)>;
|
||||
parse(hasher, poly_t(mat.subspan(off, ntt::N)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,9 +95,7 @@ generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat,
|
||||
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
|
||||
template<size_t eta>
|
||||
static inline void
|
||||
cbd(std::span<const uint8_t, 64 * eta> prf, // Byte array of length 64 * eta
|
||||
std::span<field::zq_t> poly // Degree 255 polynomial
|
||||
)
|
||||
cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
|
||||
requires(kyber_params::check_eta(eta))
|
||||
{
|
||||
if constexpr (eta == 2) {
|
||||
@@ -176,8 +174,9 @@ generate_vector(std::span<field::zq_t, k * ntt::N> vec,
|
||||
hasher.finalize();
|
||||
hasher.squeeze(prf_out);
|
||||
|
||||
kyber_utils::cbd<eta>(prf_out, vec.subspan(off, ntt::N));
|
||||
using poly_t = std::span<field::zq_t, vec.size() / k>;
|
||||
kyber_utils::cbd<eta>(prf_out, poly_t(vec.subspan(off, ntt::N)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kyber_utils
|
||||
}
|
||||
|
||||
@@ -328,4 +328,4 @@ decode(std::span<const uint8_t, 32 * l> arr,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kyber_utils
|
||||
}
|
||||
|
||||
@@ -83,4 +83,4 @@ get_kem_cipher_len()
|
||||
return k * du * 32 + dv * 32;
|
||||
}
|
||||
|
||||
} // namespace kyber_utils
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "field.hpp"
|
||||
#include "ntt.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
@@ -15,15 +16,18 @@ TEST(KyberKEM, NumberTheoreticTransform)
|
||||
std::vector<field::zq_t> poly_a(ntt::N);
|
||||
std::vector<field::zq_t> poly_b(ntt::N);
|
||||
|
||||
auto _poly_a = std::span<field::zq_t, ntt::N>(poly_a);
|
||||
auto _poly_b = std::span<field::zq_t, ntt::N>(poly_b);
|
||||
|
||||
prng::prng_t prng;
|
||||
|
||||
for (size_t i = 0; i < ntt::N; i++) {
|
||||
poly_a[i] = field::zq_t::random(prng);
|
||||
poly_b[i] = poly_a[i];
|
||||
_poly_a[i] = field::zq_t::random(prng);
|
||||
}
|
||||
std::copy(_poly_a.begin(), _poly_a.end(), _poly_b.begin());
|
||||
|
||||
ntt::ntt(poly_b);
|
||||
ntt::intt(poly_b);
|
||||
ntt::ntt(_poly_b);
|
||||
ntt::intt(_poly_b);
|
||||
|
||||
EXPECT_EQ(poly_a, poly_b);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user