update all remaining function interfaces to use statically defined std::span

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2023-10-01 09:41:47 +05:30
parent 2c528a3857
commit 5930d75188
15 changed files with 62 additions and 50 deletions

View File

@@ -75,4 +75,4 @@ poly_decompress(std::span<field::zq_t, ntt::N> poly)
}
}
} // namespace kyber_utils
}

View File

@@ -205,4 +205,4 @@ private:
}
};
} // namespace field
}

View File

@@ -115,7 +115,6 @@ encapsulate(
h512.absorb(_g_in);
h512.finalize();
h512.digest(_g_out);
h512.reset();
pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, cipher);
std::copy(_g_out0.begin(), _g_out0.end(), _kdf_in0.begin());
@@ -189,7 +188,6 @@ decapsulate(
h512.absorb(_g_in);
h512.finalize();
h512.digest(_g_out);
h512.reset();
pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, c_prime);
@@ -214,4 +212,4 @@ decapsulate(
return xof256;
}
} // namespace kem
}

View File

@@ -64,4 +64,4 @@ decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
}
} // namespace kyber1024_kem
}

View File

@@ -64,4 +64,4 @@ decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
}
} // namespace kyber512_kem
}

View File

@@ -63,4 +63,4 @@ decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
}
} // namespace kyber768_kem
}

View File

@@ -103,14 +103,14 @@ constexpr std::array<field::zq_t, N / 2> POLY_MUL_ζ_EXP = compute_mul_ζ();
// Implementation inspired from
// https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L69-L144
inline void
ntt(std::span<field::zq_t> poly)
ntt(std::span<field::zq_t, N> poly)
{
for (size_t l = LOG2N - 1; l >= 1; l--) {
const size_t len = 1ul << l;
const size_t lenx2 = len << 1;
const size_t k_beg = N >> (l + 1);
for (size_t start = 0; start < N; start += lenx2) {
for (size_t start = 0; start < poly.size(); start += lenx2) {
const size_t k_now = k_beg + (start >> (l + 1));
// Looking up precomputed constant, though it can be computed using
//
@@ -140,14 +140,14 @@ ntt(std::span<field::zq_t> poly)
// Implementation inspired from
// https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L146-L224
inline void
intt(std::span<field::zq_t> poly)
intt(std::span<field::zq_t, N> poly)
{
for (size_t l = 1; l < LOG2N; l++) {
const size_t len = 1ul << l;
const size_t lenx2 = len << 1;
const size_t k_beg = (N >> l) - 1;
for (size_t start = 0; start < N; start += lenx2) {
for (size_t start = 0; start < poly.size(); start += lenx2) {
const size_t k_now = k_beg - (start >> (l + 1));
// Looking up precomputed constant, though it can be computed using
//
@@ -185,10 +185,10 @@ intt(std::span<field::zq_t> poly)
// See page 6 of Kyber specification
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
static inline void
basemul(std::span<const field::zq_t> f, // degree-1 polynomial
std::span<const field::zq_t> g, // degree-1 polynomial
std::span<field::zq_t> h, // degree-1 polynomial
const field::zq_t ζ // zeta
basemul(std::span<const field::zq_t, 2> f, // degree-1 polynomial
std::span<const field::zq_t, 2> g, // degree-1 polynomial
std::span<field::zq_t, 2> h, // degree-1 polynomial
const field::zq_t ζ // zeta
)
{
field::zq_t f0 = f[0];
@@ -219,20 +219,23 @@ basemul(std::span<const field::zq_t> f, // degree-1 polynomial
//
// h = f ◦ g
inline void
polymul(std::span<const field::zq_t> f, // degree-255 polynomial
std::span<const field::zq_t> g, // degree-255 polynomial
std::span<field::zq_t> h // degree-255 polynomial
polymul(std::span<const field::zq_t, N> f, // degree-255 polynomial
std::span<const field::zq_t, N> g, // degree-255 polynomial
std::span<field::zq_t, N> h // degree-255 polynomial
)
{
constexpr size_t cnt = N >> 1;
constexpr size_t cnt = f.size() >> 1;
using poly_t = std::span<const field::zq_t, 2>;
using mut_poly_t = std::span<field::zq_t, 2>;
for (size_t i = 0; i < cnt; i++) {
const size_t off = i << 1;
basemul(f.subspan(off, 2),
g.subspan(off, 2),
h.subspan(off, 2),
basemul(poly_t(f.subspan(off, 2)),
poly_t(g.subspan(off, 2)),
mut_poly_t(h.subspan(off, 2)),
POLY_MUL_ζ_EXP[i]);
}
}
} // namespace ntt
}

View File

@@ -135,4 +135,4 @@ check_decap_params(const size_t k,
return check_encap_params(k, η1, η2, du, dv);
}
} // namespace kyber_params
}

View File

@@ -41,10 +41,9 @@ keygen(std::span<const uint8_t, 32> d,
h512.absorb(d);
h512.finalize();
h512.digest(_g_out);
h512.reset();
const auto rho = _g_out.template subspan<0, 32>();
const auto sigma = _g_out.template subspan<32, 32>();
const auto sigma = _g_out.template subspan<rho.size(), 32>();
// step 4, 5, 6, 7, 8
std::array<field::zq_t, k * k * ntt::N> A_prime{};
@@ -210,4 +209,4 @@ decrypt(std::span<const uint8_t, k * 12 * 32> seckey,
kyber_utils::encode<1>(v, dec);
}
} // namespace pke
}

View File

@@ -20,7 +20,10 @@ matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
std::span<field::zq_t, a_rows * b_cols * ntt::N> c)
requires(kyber_params::check_matrix_dim(a_cols, b_rows))
{
using poly_t = std::span<const field::zq_t, ntt::N>;
std::array<field::zq_t, ntt::N> tmp{};
auto _tmp = std::span(tmp);
for (size_t i = 0; i < a_rows; i++) {
for (size_t j = 0; j < b_cols; j++) {
@@ -30,7 +33,9 @@ matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
const size_t aoff = (i * a_cols + k) * ntt::N;
const size_t boff = (k * b_cols + j) * ntt::N;
ntt::polymul(a.subspan(aoff, ntt::N), b.subspan(boff, ntt::N), tmp);
ntt::polymul(poly_t(a.subspan(aoff, ntt::N)),
poly_t(b.subspan(boff, ntt::N)),
_tmp);
for (size_t l = 0; l < ntt::N; l++) {
c[coff + l] += tmp[l];
@@ -48,9 +53,11 @@ static inline void
poly_vec_ntt(std::span<field::zq_t, k * ntt::N> vec)
requires((k == 1) || kyber_params::check_k(k))
{
using poly_t = std::span<field::zq_t, ntt::N>;
for (size_t i = 0; i < k; i++) {
const size_t off = i * ntt::N;
ntt::ntt(vec.subspan(off, ntt::N));
ntt::ntt(poly_t(vec.subspan(off, ntt::N)));
}
}
@@ -63,9 +70,11 @@ static inline void
poly_vec_intt(std::span<field::zq_t, k * ntt::N> vec)
requires((k == 1) || kyber_params::check_k(k))
{
using poly_t = std::span<field::zq_t, ntt::N>;
for (size_t i = 0; i < k; i++) {
const size_t off = i * ntt::N;
ntt::intt(vec.subspan(off, ntt::N));
ntt::intt(poly_t(vec.subspan(off, ntt::N)));
}
}
@@ -171,4 +180,4 @@ poly_vec_decompress(std::span<field::zq_t, k * ntt::N> vec)
}
}
} // namespace kyber_utils
}

View File

@@ -61,4 +61,4 @@ public:
inline void read(std::span<uint8_t> bytes) { state.squeeze(bytes); }
};
} // namespace prng
}

View File

@@ -20,11 +20,9 @@ namespace kyber_utils {
// See algorithm 1, defined in Kyber specification
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
inline void
parse(shake128::shake128_t& hasher, // Squeezes bytes
std::span<field::zq_t> poly // Degree 255 polynomial
)
parse(shake128::shake128_t& hasher, std::span<field::zq_t, ntt::N> poly)
{
constexpr size_t n = ntt::N;
constexpr size_t n = poly.size();
size_t coeff_idx = 0;
std::array<uint8_t, shake128::RATE / 8> buf{};
@@ -32,7 +30,7 @@ parse(shake128::shake128_t& hasher, // Squeezes bytes
while (coeff_idx < n) {
hasher.squeeze(buf);
for (size_t off = 0; (off < sizeof(buf)) && (coeff_idx < n); off += 3) {
for (size_t off = 0; (off < buf.size()) && (coeff_idx < n); off += 3) {
const uint16_t d1 = (static_cast<uint16_t>(buf[off + 1] & 0x0f) << 8) |
(static_cast<uint16_t>(buf[off + 0]) << 0);
const uint16_t d2 = (static_cast<uint16_t>(buf[off + 2]) << 4) |
@@ -81,7 +79,9 @@ generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat,
shake128::shake128_t hasher{};
hasher.absorb(xof_in);
hasher.finalize();
parse(hasher, mat.subspan(off, ntt::N));
using poly_t = std::span<field::zq_t, mat.size() / (k * k)>;
parse(hasher, poly_t(mat.subspan(off, ntt::N)));
}
}
}
@@ -95,9 +95,7 @@ generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat,
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
template<size_t eta>
static inline void
cbd(std::span<const uint8_t, 64 * eta> prf, // Byte array of length 64 * eta
std::span<field::zq_t> poly // Degree 255 polynomial
)
cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
requires(kyber_params::check_eta(eta))
{
if constexpr (eta == 2) {
@@ -176,8 +174,9 @@ generate_vector(std::span<field::zq_t, k * ntt::N> vec,
hasher.finalize();
hasher.squeeze(prf_out);
kyber_utils::cbd<eta>(prf_out, vec.subspan(off, ntt::N));
using poly_t = std::span<field::zq_t, vec.size() / k>;
kyber_utils::cbd<eta>(prf_out, poly_t(vec.subspan(off, ntt::N)));
}
}
} // namespace kyber_utils
}

View File

@@ -328,4 +328,4 @@ decode(std::span<const uint8_t, 32 * l> arr,
}
}
} // namespace kyber_utils
}

View File

@@ -83,4 +83,4 @@ get_kem_cipher_len()
return k * du * 32 + dv * 32;
}
} // namespace kyber_utils
}

View File

@@ -1,3 +1,4 @@
#include "field.hpp"
#include "ntt.hpp"
#include <gtest/gtest.h>
#include <vector>
@@ -15,15 +16,18 @@ TEST(KyberKEM, NumberTheoreticTransform)
std::vector<field::zq_t> poly_a(ntt::N);
std::vector<field::zq_t> poly_b(ntt::N);
auto _poly_a = std::span<field::zq_t, ntt::N>(poly_a);
auto _poly_b = std::span<field::zq_t, ntt::N>(poly_b);
prng::prng_t prng;
for (size_t i = 0; i < ntt::N; i++) {
poly_a[i] = field::zq_t::random(prng);
poly_b[i] = poly_a[i];
_poly_a[i] = field::zq_t::random(prng);
}
std::copy(_poly_a.begin(), _poly_a.end(), _poly_b.begin());
ntt::ntt(poly_b);
ntt::intt(poly_b);
ntt::ntt(_poly_b);
ntt::intt(_poly_b);
EXPECT_EQ(poly_a, poly_b);
}