Add new header, with MACRO definition, for ease of forcing inlining of small functions

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2024-09-01 23:09:17 +04:00
parent 48c06432ee
commit 0f2849520b
14 changed files with 152 additions and 116 deletions

View File

@@ -12,7 +12,7 @@ namespace k_pke {
// K-PKE key generation algorithm, generating byte serialized public key and secret keym given a 32 -bytes input seed `d`.
// See algorithm 12 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1>
static inline constexpr void
constexpr void
keygen(std::span<const uint8_t, 32> d,
std::span<uint8_t, ml_kem_utils::get_pke_public_key_len(k)> pubkey,
std::span<uint8_t, ml_kem_utils::get_pke_secret_key_len(k)> seckey)
@@ -72,7 +72,7 @@ keygen(std::span<const uint8_t, 32> d,
//
// See algorithm 13 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result of modulus check on public key")]] static inline constexpr bool
[[nodiscard("Use result of modulus check on public key")]] constexpr bool
encrypt(std::span<const uint8_t, ml_kem_utils::get_pke_public_key_len(k)> pubkey,
std::span<const uint8_t, 32> msg,
std::span<const uint8_t, 32> rcoin,
@@ -149,7 +149,7 @@ encrypt(std::span<const uint8_t, ml_kem_utils::get_pke_public_key_len(k)> pubkey
//
// See algorithm 14 defined in K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t du, size_t dv>
static inline constexpr void
constexpr void
decrypt(std::span<const uint8_t, ml_kem_utils::get_pke_secret_key_len(k)> seckey,
std::span<const uint8_t, ml_kem_utils::get_pke_cipher_text_len(k, du, dv)> enc,
std::span<uint8_t, 32> dec)

View File

@@ -1,15 +1,16 @@
#pragma once
#include "ml_kem/internals/rng/prng.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
#include <bit>
#include <cstdint>
namespace ml_kem_field {
// Ml_kem Prime Field Modulus ( = 3329 )
static constexpr uint32_t Q = (1u << 8) * 13 + 1;
inline constexpr uint32_t Q = (1u << 8) * 13 + 1;
// Bit width of Ml_kem Prime Field Modulus ( = 12 )
static constexpr size_t Q_BIT_WIDTH = std::bit_width(Q);
inline constexpr size_t Q_BIT_WIDTH = std::bit_width(Q);
// Precomputed Barrett Reduction Constant
//
@@ -19,7 +20,7 @@ static constexpr size_t Q_BIT_WIDTH = std::bit_width(Q);
// r = floor((1 << 2k) / Q) = 5039
//
// See https://www.nayuki.io/page/barrett-reduction-algorithm.
static constexpr uint32_t R = (1u << (2 * Q_BIT_WIDTH)) / Q;
inline constexpr uint32_t R = (1u << (2 * Q_BIT_WIDTH)) / Q;
// Prime field Zq | q = 3329, with arithmetic operations defined over it.
//
@@ -33,7 +34,7 @@ private:
uint32_t v = 0u;
// Given a 32 -bit unsigned integer `v` such that `v` ∈ [0, 2*Q), this routine can be invoked for reducing `v` modulo prime Q.
static inline constexpr uint32_t reduce_once(const uint32_t v)
static forceinline constexpr uint32_t reduce_once(const uint32_t v)
{
const uint32_t t0 = v - Q;
const uint32_t t1 = -(t0 >> 31);
@@ -45,7 +46,7 @@ private:
// Given a 32 -bit unsigned integer `v` such that `v` ∈ [0, Q*Q), this routine can be invoked for reducing `v` modulo Q, using
// barrett reduction technique, following algorithm description @ https://www.nayuki.io/page/barrett-reduction-algorithm.
static inline constexpr uint32_t barrett_reduce(const uint32_t v)
static forceinline constexpr uint32_t barrett_reduce(const uint32_t v)
{
const uint64_t t0 = static_cast<uint64_t>(v) * static_cast<uint64_t>(R);
const uint32_t t1 = static_cast<uint32_t>(t0 >> (2 * Q_BIT_WIDTH));
@@ -57,35 +58,35 @@ private:
public:
// Constructor(s)
inline constexpr zq_t() = default;
inline constexpr zq_t(const uint16_t a /* Expects a ∈ [0, Q) */) { this->v = a; }
static inline constexpr zq_t from_non_reduced(const uint16_t a /* Doesn't expect that a ∈ [0, Q) */) { return barrett_reduce(a); }
forceinline constexpr zq_t() = default;
forceinline constexpr zq_t(const uint16_t a /* Expects a ∈ [0, Q) */) { this->v = a; }
static forceinline constexpr zq_t from_non_reduced(const uint16_t a /* Doesn't expect that a ∈ [0, Q) */) { return barrett_reduce(a); }
// Returns canonical value held under Zq type. Returned value must ∈ [0, Q).
inline constexpr uint32_t raw() const { return this->v; }
forceinline constexpr uint32_t raw() const { return this->v; }
static inline constexpr zq_t zero() { return zq_t(0u); }
static inline constexpr zq_t one() { return zq_t(1u); }
static forceinline constexpr zq_t zero() { return zq_t(0u); }
static forceinline constexpr zq_t one() { return zq_t(1u); }
// Modulo addition of two Zq elements.
inline constexpr zq_t operator+(const zq_t& rhs) const { return reduce_once(this->v + rhs.v); }
inline constexpr void operator+=(const zq_t& rhs) { *this = *this + rhs; }
forceinline constexpr zq_t operator+(const zq_t& rhs) const { return reduce_once(this->v + rhs.v); }
forceinline constexpr void operator+=(const zq_t& rhs) { *this = *this + rhs; }
// Modulo negation of a Zq element.
inline constexpr zq_t operator-() const { return zq_t(Q - this->v); }
forceinline constexpr zq_t operator-() const { return zq_t(Q - this->v); }
// Modulo subtraction of one Zq element from another one.
inline constexpr zq_t operator-(const zq_t& rhs) const { return *this + (-rhs); }
inline constexpr void operator-=(const zq_t& rhs) { *this = *this - rhs; }
forceinline constexpr zq_t operator-(const zq_t& rhs) const { return *this + (-rhs); }
forceinline constexpr void operator-=(const zq_t& rhs) { *this = *this - rhs; }
// Modulo multiplication of two Zq elements.
inline constexpr zq_t operator*(const zq_t& rhs) const { return barrett_reduce(this->v * rhs.v); }
inline constexpr void operator*=(const zq_t& rhs) { *this = *this * rhs; }
forceinline constexpr zq_t operator*(const zq_t& rhs) const { return barrett_reduce(this->v * rhs.v); }
forceinline constexpr void operator*=(const zq_t& rhs) { *this = *this * rhs; }
// Modulo exponentiation of Zq element.
//
// Taken from https://github.com/itzmeanjan/dilithium/blob/3fe6ab61/include/field.hpp#L144-L167.
inline constexpr zq_t operator^(const size_t n) const
forceinline constexpr zq_t operator^(const size_t n) const
{
zq_t base = *this;
@@ -108,15 +109,15 @@ public:
// Multiplicative inverse of Zq element. Also division of one Zq element by another one.
//
// Note, if Zq element is 0, we can't compute multiplicative inverse and 0 is returned.
inline constexpr zq_t inv() const { return *this ^ static_cast<size_t>((Q - 2)); }
inline constexpr zq_t operator/(const zq_t& rhs) const { return *this * rhs.inv(); }
forceinline constexpr zq_t inv() const { return *this ^ static_cast<size_t>((Q - 2)); }
forceinline constexpr zq_t operator/(const zq_t& rhs) const { return *this * rhs.inv(); }
// Comparison operators, see https://en.cppreference.com/w/cpp/language/default_comparisons
inline constexpr auto operator<=>(const zq_t&) const = default;
forceinline constexpr auto operator<=>(const zq_t&) const = default;
// Samples a random Zq element, using pseudo random number generator.
template<size_t bit_security_level>
static inline zq_t random(ml_kem_prng::prng_t<bit_security_level>& prng)
static forceinline zq_t random(ml_kem_prng::prng_t<bit_security_level>& prng)
{
uint16_t res = 0;
prng.read(std::span(reinterpret_cast<uint8_t*>(&res), sizeof(res)));

View File

@@ -12,7 +12,7 @@ namespace ml_kem {
// ML-KEM key generation algorithm, generating byte serialized public key and secret key, given 32 -bytes seed `d` and `z`.
// See algorithm 15 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t k, size_t eta1>
static inline constexpr void
constexpr void
keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
std::span<const uint8_t, 32> z, // used in CCA-KEM
std::span<uint8_t, ml_kem_utils::get_kem_public_key_len(k)> pubkey,
@@ -50,7 +50,7 @@ keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
//
// See algorithm 16 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result, it might fail because of malformed input public key")]] static inline constexpr bool
[[nodiscard("Use result, it might fail because of malformed input public key")]] constexpr bool
encapsulate(std::span<const uint8_t, 32> m,
std::span<const uint8_t, ml_kem_utils::get_kem_public_key_len(k)> pubkey,
std::span<uint8_t, ml_kem_utils::get_kem_cipher_text_len(k, du, dv)> cipher,
@@ -98,7 +98,7 @@ encapsulate(std::span<const uint8_t, 32> m,
//
// See algorithm 17 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
static inline constexpr void
constexpr void
decapsulate(std::span<const uint8_t, ml_kem_utils::get_kem_secret_key_len(k)> seckey,
std::span<const uint8_t, ml_kem_utils::get_kem_cipher_text_len(k, du, dv)> cipher,
std::span<uint8_t, 32> shared_secret)

View File

@@ -1,6 +1,7 @@
#pragma once
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include <span>
@@ -11,7 +12,7 @@ namespace ml_kem_utils {
// See formula 4.5 on page 18 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// Following implementation collects inspiration from https://github.com/FiloSottile/mlkem768/blob/cffbfb96/mlkem768.go#L395-L425.
template<size_t d>
static inline constexpr ml_kem_field::zq_t
forceinline constexpr ml_kem_field::zq_t
compress(const ml_kem_field::zq_t x)
requires(ml_kem_params::check_d(d))
{
@@ -31,7 +32,7 @@ compress(const ml_kem_field::zq_t x)
//
// See formula 4.6 on page 18 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t d>
static inline constexpr ml_kem_field::zq_t
forceinline constexpr ml_kem_field::zq_t
decompress(const ml_kem_field::zq_t x)
requires(ml_kem_params::check_d(d))
{
@@ -47,7 +48,7 @@ decompress(const ml_kem_field::zq_t x)
// Utility function to compress each of 256 coefficients of a degree-255 polynomial while mutating the input.
template<size_t d>
static inline constexpr void
constexpr void
poly_compress(std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
requires(ml_kem_params::check_d(d))
{
@@ -58,7 +59,7 @@ poly_compress(std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
// Utility function to decompress each of 256 coefficients of a degree-255 polynomial while mutating the input.
template<size_t d>
static inline constexpr void
constexpr void
poly_decompress(std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
requires(ml_kem_params::check_d(d))
{

View File

@@ -1,27 +1,28 @@
#pragma once
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
namespace ml_kem_ntt {
static constexpr size_t LOG2N = 8;
static constexpr size_t N = 1 << LOG2N;
inline constexpr size_t LOG2N = 8;
inline constexpr size_t N = 1 << LOG2N;
// First primitive 256 -th root of unity modulo q | q = 3329
//
// Meaning, 17 ** 256 == 1 mod q
static constexpr auto ζ = ml_kem_field::zq_t(17);
inline constexpr auto ζ = ml_kem_field::zq_t(17);
// Multiplicative inverse of N/ 2 over Z_q | q = 3329 and N = 256
//
// Meaning (N/ 2) * INV_N = 1 mod q
static constexpr auto INV_N = ml_kem_field::zq_t(N / 2).inv();
inline constexpr auto INV_N = ml_kem_field::zq_t(N / 2).inv();
// Given a 64 -bit unsigned integer, this routine extracts specified many contiguous bits from ( least significant bits ) LSB side
// and reverses their bit order, returning bit reversed `mbw` -bit wide number.
//
// See https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L30-L38 for source of inspiration.
template<size_t mbw>
static inline constexpr size_t
forceinline constexpr size_t
bit_rev(const size_t v)
{
size_t v_rev = 0ul;
@@ -35,7 +36,7 @@ bit_rev(const size_t v)
}
// Compile-time computed constants ( powers of ζ ), used for polynomial evaluation i.e. computation of NTT form.
static constexpr std::array<ml_kem_field::zq_t, N / 2> NTT_ζ_EXP = []() -> auto {
inline constexpr std::array<ml_kem_field::zq_t, N / 2> NTT_ζ_EXP = []() -> auto {
std::array<ml_kem_field::zq_t, N / 2> res{};
for (size_t i = 0; i < res.size(); i++) {
@@ -46,7 +47,7 @@ static constexpr std::array<ml_kem_field::zq_t, N / 2> NTT_ζ_EXP = []() -> auto
}();
// Compile-time computed constants ( negated powers of ζ ), used for polynomial interpolation i.e. computation of iNTT form.
static constexpr std::array<ml_kem_field::zq_t, N / 2> INTT_ζ_EXP = []() -> auto {
inline constexpr std::array<ml_kem_field::zq_t, N / 2> INTT_ζ_EXP = []() -> auto {
std::array<ml_kem_field::zq_t, N / 2> res{};
for (size_t i = 0; i < res.size(); i++) {
@@ -57,7 +58,7 @@ static constexpr std::array<ml_kem_field::zq_t, N / 2> INTT_ζ_EXP = []() -> aut
}();
// Compile-time computed constants ( powers of ζ ), used when multiplying two degree-255 polynomials in NTT domain.
static constexpr std::array<ml_kem_field::zq_t, N / 2> POLY_MUL_ζ_EXP = []() -> auto {
inline constexpr std::array<ml_kem_field::zq_t, N / 2> POLY_MUL_ζ_EXP = []() -> auto {
std::array<ml_kem_field::zq_t, N / 2> res{};
for (size_t i = 0; i < res.size(); i++) {
@@ -74,7 +75,7 @@ static constexpr std::array<ml_kem_field::zq_t, N / 2> POLY_MUL_ζ_EXP = []() ->
//
// Implementation inspired from https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L69-L144.
// See algorithm 8 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
static inline constexpr void
forceinline constexpr void
ntt(std::span<ml_kem_field::zq_t, N> poly)
{
for (size_t l = LOG2N - 1; l >= 1; l--) {
@@ -110,7 +111,7 @@ ntt(std::span<ml_kem_field::zq_t, N> poly)
//
// Implementation inspired from https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L146-L224.
// See algorithm 9 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
static inline constexpr void
forceinline constexpr void
intt(std::span<ml_kem_field::zq_t, N> poly)
{
for (size_t l = 1; l < LOG2N; l++) {
@@ -146,7 +147,7 @@ intt(std::span<ml_kem_field::zq_t, N> poly)
// Given two degree-1 polynomials, this routine computes resulting degree-1 polynomial h.
// See algorithm 11 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
static inline constexpr void
forceinline constexpr void
basemul(std::span<const ml_kem_field::zq_t, 2> f, std::span<const ml_kem_field::zq_t, 2> g, std::span<ml_kem_field::zq_t, 2> h, const ml_kem_field::zq_t ζ)
{
ml_kem_field::zq_t f0 = f[0];
@@ -178,7 +179,7 @@ basemul(std::span<const ml_kem_field::zq_t, 2> f, std::span<const ml_kem_field::
// h = f ◦ g
//
// See algorithm 10 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
static inline constexpr void
constexpr void
polymul(std::span<const ml_kem_field::zq_t, N> f, std::span<const ml_kem_field::zq_t, N> g, std::span<ml_kem_field::zq_t, N> h)
{
using poly_t = std::span<const ml_kem_field::zq_t, 2>;

View File

@@ -3,6 +3,7 @@
#include "ml_kem/internals/poly/compression.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/poly/serialize.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
#include "ml_kem/internals/utility/params.hpp"
namespace ml_kem_utils {
@@ -10,7 +11,7 @@ namespace ml_kem_utils {
// Given two matrices ( in NTT domain ) of compatible dimension, where each matrix element is a degree-255 polynomial over Z_q | q = 3329,
// this routine multiplies them, computing a resulting matrix.
template<size_t a_rows, size_t a_cols, size_t b_rows, size_t b_cols>
static inline constexpr void
constexpr void
matrix_multiply(std::span<const ml_kem_field::zq_t, a_rows * a_cols * ml_kem_ntt::N> a,
std::span<const ml_kem_field::zq_t, b_rows * b_cols * ml_kem_ntt::N> b,
std::span<ml_kem_field::zq_t, a_rows * b_cols * ml_kem_ntt::N> c)
@@ -42,7 +43,7 @@ matrix_multiply(std::span<const ml_kem_field::zq_t, a_rows * a_cols * ml_kem_ntt
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials ( where polynomial coefficients are in non-NTT form ),
// this routine applies in-place polynomial NTT over `k` polynomials.
template<size_t k>
static inline constexpr void
constexpr void
poly_vec_ntt(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
requires((k == 1) || ml_kem_params::check_k(k))
{
@@ -57,7 +58,7 @@ poly_vec_ntt(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials ( where polynomial coefficients are in NTT form i.e.
// they are placed in bit-reversed order ), this routine applies in-place polynomial iNTT over those `k` polynomials.
template<size_t k>
static inline constexpr void
constexpr void
poly_vec_intt(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
requires((k == 1) || ml_kem_params::check_k(k))
{
@@ -71,7 +72,7 @@ poly_vec_intt(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials, this routine adds it to another polynomial vector of same dimension.
template<size_t k>
static inline constexpr void
constexpr void
poly_vec_add_to(std::span<const ml_kem_field::zq_t, k * ml_kem_ntt::N> src, std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> dst)
requires((k == 1) || ml_kem_params::check_k(k))
{
@@ -84,7 +85,7 @@ poly_vec_add_to(std::span<const ml_kem_field::zq_t, k * ml_kem_ntt::N> src, std:
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials, this routine subtracts it to another polynomial vector of same dimension.
template<size_t k>
static inline constexpr void
constexpr void
poly_vec_sub_from(std::span<const ml_kem_field::zq_t, k * ml_kem_ntt::N> src, std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> dst)
requires((k == 1) || ml_kem_params::check_k(k))
{
@@ -98,7 +99,7 @@ poly_vec_sub_from(std::span<const ml_kem_field::zq_t, k * ml_kem_ntt::N> src, st
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials, this routine encodes each of those polynomials into 32 x l -bytes,
// writing to a (k x 32 x l) -bytes destination array.
template<size_t k, size_t l>
static inline constexpr void
constexpr void
poly_vec_encode(std::span<const ml_kem_field::zq_t, k * ml_kem_ntt::N> src, std::span<uint8_t, k * 32 * l> dst)
requires(ml_kem_params::check_k(k))
{
@@ -116,7 +117,7 @@ poly_vec_encode(std::span<const ml_kem_field::zq_t, k * ml_kem_ntt::N> src, std:
// Given a byte array of length (k x 32 x l) -bytes, this routine decodes them into k degree-255 polynomials, writing them to a
// column vector of dimension `k x 1`.
template<size_t k, size_t l>
static inline constexpr void
constexpr void
poly_vec_decode(std::span<const uint8_t, k * 32 * l> src, std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> dst)
requires(ml_kem_params::check_k(k))
{
@@ -133,7 +134,7 @@ poly_vec_decode(std::span<const uint8_t, k * 32 * l> src, std::span<ml_kem_field
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials, each of k * 256 coefficients are compressed, while mutating input.
template<size_t k, size_t d>
static inline constexpr void
constexpr void
poly_vec_compress(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
requires(ml_kem_params::check_k(k))
{
@@ -147,7 +148,7 @@ poly_vec_compress(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
// Given a vector ( of dimension `k x 1` ) of degree-255 polynomials, each of k * 256 coefficients are decompressed, while mutating input.
template<size_t k, size_t d>
static inline constexpr void
constexpr void
poly_vec_decompress(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec)
requires(ml_kem_params::check_k(k))
{

View File

@@ -1,6 +1,7 @@
#pragma once
#include "ml_kem/internals/math/field.hpp"
#include "ml_kem/internals/poly/ntt.hpp"
#include "ml_kem/internals/utility/force_inline.hpp"
#include "ml_kem/internals/utility/params.hpp"
#include "shake128.hpp"
#include "shake256.hpp"
@@ -15,7 +16,7 @@ namespace ml_kem_utils {
// statiscally close to randomly sampled elements of R_q.
//
// See algorithm 6 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
inline constexpr void
forceinline constexpr void
sample_ntt(shake128::shake128_t& hasher, std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
{
constexpr size_t n = poly.size();
@@ -48,7 +49,7 @@ sample_ntt(shake128::shake128_t& hasher, std::span<ml_kem_field::zq_t, ml_kem_nt
//
// See step (4-8) of algorithm 12/ 13 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, bool transpose>
static inline constexpr void
constexpr void
generate_matrix(std::span<ml_kem_field::zq_t, k * k * ml_kem_ntt::N> mat, std::span<const uint8_t, 32> rho)
requires(ml_kem_params::check_k(k))
{
@@ -82,7 +83,7 @@ generate_matrix(std::span<ml_kem_field::zq_t, k * k * ml_kem_ntt::N> mat, std::s
//
// See algorithm 7 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t eta>
static inline constexpr void
constexpr void
sample_poly_cbd(std::span<const uint8_t, 64 * eta> prf, std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
requires(ml_kem_params::check_eta(eta))
{
@@ -132,7 +133,7 @@ sample_poly_cbd(std::span<const uint8_t, 64 * eta> prf, std::span<ml_kem_field::
// Sample a polynomial vector from Bη, following step (9-12) of algorithm 12/ 13 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t k, size_t eta>
static inline constexpr void
constexpr void
generate_vector(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec, std::span<const uint8_t, 32> sigma, const uint8_t nonce)
requires((k == 1) || ml_kem_params::check_k(k))
{

View File

@@ -10,7 +10,7 @@ namespace ml_kem_utils {
//
// See algorithm 4 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t l>
static inline constexpr void
constexpr void
encode(std::span<const ml_kem_field::zq_t, ml_kem_ntt::N> poly, std::span<uint8_t, 32 * l> arr)
requires(ml_kem_params::check_l(l))
{
@@ -144,7 +144,7 @@ encode(std::span<const ml_kem_field::zq_t, ml_kem_ntt::N> poly, std::span<uint8_
//
// See algorithm 5 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
template<size_t l>
static inline constexpr void
constexpr void
decode(std::span<const uint8_t, 32 * l> arr, std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
requires(ml_kem_params::check_l(l))
{

View File

@@ -1,4 +1,5 @@
#pragma once
#include "ml_kem/internals/utility/force_inline.hpp"
#include "shake256.hpp"
#include <limits>
#include <random>
@@ -24,7 +25,7 @@ private:
public:
// Default constructor which seeds PRNG with system randomness.
inline prng_t()
forceinline prng_t()
{
std::array<uint8_t, bit_security_level / std::numeric_limits<uint8_t>::digits> seed{};
auto _seed = std::span(seed);
@@ -45,14 +46,14 @@ public:
}
// Explicit constructor which can be used for seeding PRNG.
inline explicit constexpr prng_t(std::span<const uint8_t, bit_security_level / std::numeric_limits<uint8_t>::digits> seed)
forceinline explicit constexpr prng_t(std::span<const uint8_t, bit_security_level / std::numeric_limits<uint8_t>::digits> seed)
{
state.absorb(seed);
state.finalize();
}
// Once PRNG is seeded i.e. PRNG object is constructed, you can request arbitrary many pseudo-random bytes from PRNG.
inline constexpr void read(std::span<uint8_t> bytes) { state.squeeze(bytes); }
forceinline constexpr void read(std::span<uint8_t> bytes) { state.squeeze(bytes); }
};
}

View File

@@ -0,0 +1,29 @@
#pragma once
// Following content is taken from https://github.com/itzmeanjan/raccoon/blob/bfa45f9f22ea7b98f5d6588a8513ff4182af79ca/include/raccoon/internals/utility/force_inline.hpp
#ifdef _MSC_VER
// MSVC
#define forceinline __forceinline
#elif defined(__GNUC__)
// GCC
#if defined(__cplusplus) && __cplusplus >= 201103L
#define forceinline inline __attribute__((__always_inline__))
#else
#define forceinline inline
#endif
#elif defined(__CLANG__)
// Clang
#if __has_attribute(__always_inline__)
#define forceinline inline __attribute__((__always_inline__))
#else
#define forceinline inline
#endif
#else
// Others
#define forceinline inline
#endif

View File

@@ -1,4 +1,5 @@
#pragma once
#include "ml_kem/internals/utility/force_inline.hpp"
#include "subtle.hpp"
#include <span>
@@ -7,7 +8,7 @@ namespace ml_kem_utils {
// Given two byte arrays of equal length, this routine can be used for comparing them in constant-time,
// producing truth value (0xffffffff) in case of equality, otherwise it returns false value (0x00000000).
template<size_t n>
static inline constexpr uint32_t
forceinline constexpr uint32_t
ct_memcmp(std::span<const uint8_t, n> bytes0, std::span<const uint8_t, n> bytes1)
{
uint32_t flag = -1u;
@@ -24,7 +25,7 @@ ct_memcmp(std::span<const uint8_t, n> bytes0, std::span<const uint8_t, n> bytes1
//
// In simple words, `sink = cond ? source0 ? source1`
template<size_t n>
static inline constexpr void
forceinline constexpr void
ct_cond_memcpy(const uint32_t cond, std::span<uint8_t, n> sink, std::span<const uint8_t, n> source0, std::span<const uint8_t, n> source1)
{
for (size_t i = 0; i < n; i++) {
@@ -33,42 +34,42 @@ ct_cond_memcpy(const uint32_t cond, std::span<uint8_t, n> sink, std::span<const
}
// Returns compile-time computable K-PKE public key byte length.
static inline constexpr size_t
forceinline constexpr size_t
get_pke_public_key_len(const size_t k)
{
return k * 12 * 32 + 32;
}
// Returns compile-time computable K-PKE secret key byte length.
static inline constexpr size_t
forceinline constexpr size_t
get_pke_secret_key_len(const size_t k)
{
return k * 12 * 32;
}
// Returns compile-time computable K-PKE cipher text byte length.
static inline constexpr size_t
forceinline constexpr size_t
get_pke_cipher_text_len(size_t k, size_t du, size_t dv)
{
return 32 * (k * du + dv);
}
// Returns compile-time computable ML-KEM public key byte length.
static inline constexpr size_t
forceinline constexpr size_t
get_kem_public_key_len(const size_t k)
{
return get_pke_public_key_len(k);
}
// Returns compile-time computable ML-KEM secret key byte length.
static inline constexpr size_t
forceinline constexpr size_t
get_kem_secret_key_len(const size_t k)
{
return get_pke_secret_key_len(k) + get_pke_public_key_len(k) + 32 + 32;
}
// Returns compile-time computable ML-KEM cipher text byte length.
static inline constexpr size_t
forceinline constexpr size_t
get_kem_cipher_text_len(size_t k, size_t du, size_t dv)
{
return get_pke_cipher_text_len(k, du, dv);

View File

@@ -6,35 +6,35 @@ namespace ml_kem_1024 {
// ML-KEM Key Encapsulation Mechanism instantiated with ML-KEM-1024 parameters
// See row 3 of table 2 of ML-KEM specification @ https://doi.org/10.6028/NIST.FIPS.203.ipd
static constexpr size_t k = 4;
static constexpr size_t η1 = 2;
static constexpr size_t η2 = 2;
static constexpr size_t du = 11;
static constexpr size_t dv = 5;
inline constexpr size_t k = 4;
inline constexpr size_t η1 = 2;
inline constexpr size_t η2 = 2;
inline constexpr size_t du = 11;
inline constexpr size_t dv = 5;
// 32 -bytes seed `d`, used in underlying K-PKE key generation
static constexpr size_t SEED_D_BYTE_LEN = 32;
inline constexpr size_t SEED_D_BYTE_LEN = 32;
// 32 -bytes seed `z`, used in ML-KEM key generation
static constexpr size_t SEED_Z_BYTE_LEN = 32;
inline constexpr size_t SEED_Z_BYTE_LEN = 32;
// 1568 -bytes ML-KEM-1024 public key
static constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
inline constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
// 3168 -bytes ML-KEM-1024 secret key
static constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
inline constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
// 32 -bytes seed `m`, used in ML-KEM encapsulation
static constexpr size_t SEED_M_BYTE_LEN = 32;
inline constexpr size_t SEED_M_BYTE_LEN = 32;
// 1568 -bytes ML-KEM-1024 cipher text
static constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
inline constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
// 32 -bytes ML-KEM-1024 shared secret
static constexpr size_t SHARED_SECRET_BYTE_LEN = 32;
inline constexpr size_t SHARED_SECRET_BYTE_LEN = 32;
// Computes a new ML-KEM-1024 keypair, given seed `d` and `z`.
inline constexpr void
constexpr void
keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
std::span<const uint8_t, SEED_Z_BYTE_LEN> z,
std::span<uint8_t, PKEY_BYTE_LEN> pubkey,
@@ -45,7 +45,7 @@ keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
// Given seed `m` and a ML-KEM-1024 public key, this routine computes a ML-KEM-1024 cipher text and a fixed size shared secret.
// If, input ML-KEM-1024 public key is malformed, encapsulation will fail, returning false.
[[nodiscard("If public key is malformed, encapsulation fails")]] inline constexpr bool
[[nodiscard("If public key is malformed, encapsulation fails")]] constexpr bool
encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
std::span<const uint8_t, PKEY_BYTE_LEN> pubkey,
std::span<uint8_t, CIPHER_TEXT_BYTE_LEN> cipher,
@@ -55,7 +55,7 @@ encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
}
// Given a ML-KEM-1024 secret key and a cipher text, this routine computes a fixed size shared secret.
inline constexpr void
constexpr void
decapsulate(std::span<const uint8_t, SKEY_BYTE_LEN> seckey, std::span<const uint8_t, CIPHER_TEXT_BYTE_LEN> cipher, std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
ml_kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);

View File

@@ -6,35 +6,35 @@ namespace ml_kem_512 {
// ML-KEM Key Encapsulation Mechanism instantiated with ML-KEM-512 parameters
// See row 1 of table 2 of ML-KEM specification @ https://doi.org/10.6028/NIST.FIPS.203.ipd
static constexpr size_t k = 2;
static constexpr size_t η1 = 3;
static constexpr size_t η2 = 2;
static constexpr size_t du = 10;
static constexpr size_t dv = 4;
inline constexpr size_t k = 2;
inline constexpr size_t η1 = 3;
inline constexpr size_t η2 = 2;
inline constexpr size_t du = 10;
inline constexpr size_t dv = 4;
// 32 -bytes seed `d`, used in underlying K-PKE key generation
static constexpr size_t SEED_D_BYTE_LEN = 32;
inline constexpr size_t SEED_D_BYTE_LEN = 32;
// 32 -bytes seed `z`, used in ML-KEM key generation
static constexpr size_t SEED_Z_BYTE_LEN = 32;
inline constexpr size_t SEED_Z_BYTE_LEN = 32;
// 800 -bytes ML-KEM-512 public key
static constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
inline constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
// 1632 -bytes ML-KEM-512 secret key
static constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
inline constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
// 32 -bytes seed `m`, used in ML-KEM encapsulation
static constexpr size_t SEED_M_BYTE_LEN = 32;
inline constexpr size_t SEED_M_BYTE_LEN = 32;
// 768 -bytes ML-KEM-512 cipher text
static constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
inline constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
// 32 -bytes ML-KEM-512 shared secret
static constexpr size_t SHARED_SECRET_BYTE_LEN = 32;
inline constexpr size_t SHARED_SECRET_BYTE_LEN = 32;
// Computes a new ML-KEM-512 keypair, given seed `d` and `z`.
inline constexpr void
constexpr void
keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
std::span<const uint8_t, SEED_Z_BYTE_LEN> z,
std::span<uint8_t, PKEY_BYTE_LEN> pubkey,
@@ -45,7 +45,7 @@ keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
// Given seed `m` and a ML-KEM-512 public key, this routine computes a ML-KEM-512 cipher text and a fixed size shared secret.
// If, input ML-KEM-512 public key is malformed, encapsulation will fail, returning false.
[[nodiscard("If public key is malformed, encapsulation fails")]] inline constexpr bool
[[nodiscard("If public key is malformed, encapsulation fails")]] constexpr bool
encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
std::span<const uint8_t, PKEY_BYTE_LEN> pubkey,
std::span<uint8_t, CIPHER_TEXT_BYTE_LEN> cipher,
@@ -55,7 +55,7 @@ encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
}
// Given a ML-KEM-512 secret key and a cipher text, this routine computes a fixed size shared secret.
inline constexpr void
constexpr void
decapsulate(std::span<const uint8_t, SKEY_BYTE_LEN> seckey, std::span<const uint8_t, CIPHER_TEXT_BYTE_LEN> cipher, std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
ml_kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);

View File

@@ -6,35 +6,35 @@ namespace ml_kem_768 {
// ML-KEM Key Encapsulation Mechanism instantiated with ML-KEM-768 parameters
// See row 2 of table 2 of ML-KEM specification @ https://doi.org/10.6028/NIST.FIPS.203.ipd
static constexpr size_t k = 3;
static constexpr size_t η1 = 2;
static constexpr size_t η2 = 2;
static constexpr size_t du = 10;
static constexpr size_t dv = 4;
inline constexpr size_t k = 3;
inline constexpr size_t η1 = 2;
inline constexpr size_t η2 = 2;
inline constexpr size_t du = 10;
inline constexpr size_t dv = 4;
// 32 -bytes seed `d`, used in underlying K-PKE key generation
static constexpr size_t SEED_D_BYTE_LEN = 32;
inline constexpr size_t SEED_D_BYTE_LEN = 32;
// 32 -bytes seed `z`, used in ML-KEM key generation
static constexpr size_t SEED_Z_BYTE_LEN = 32;
inline constexpr size_t SEED_Z_BYTE_LEN = 32;
// 1184 -bytes ML-KEM-768 public key
static constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
inline constexpr size_t PKEY_BYTE_LEN = ml_kem_utils::get_kem_public_key_len(k);
// 2400 -bytes ML-KEM-768 secret key
static constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
inline constexpr size_t SKEY_BYTE_LEN = ml_kem_utils::get_kem_secret_key_len(k);
// 32 -bytes seed `m`, used in ML-KEM encapsulation
static constexpr size_t SEED_M_BYTE_LEN = 32;
inline constexpr size_t SEED_M_BYTE_LEN = 32;
// 1088 -bytes ML-KEM-768 cipher text
static constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
inline constexpr size_t CIPHER_TEXT_BYTE_LEN = ml_kem_utils::get_kem_cipher_text_len(k, du, dv);
// 32 -bytes ML-KEM-768 shared secret
static constexpr size_t SHARED_SECRET_BYTE_LEN = 32;
inline constexpr size_t SHARED_SECRET_BYTE_LEN = 32;
// Computes a new ML-KEM-768 keypair, given seed `d` and `z`.
inline constexpr void
constexpr void
keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
std::span<const uint8_t, SEED_Z_BYTE_LEN> z,
std::span<uint8_t, PKEY_BYTE_LEN> pubkey,
@@ -45,7 +45,7 @@ keygen(std::span<const uint8_t, SEED_D_BYTE_LEN> d,
// Given seed `m` and a ML-KEM-768 public key, this routine computes a ML-KEM-768 cipher text and a fixed size shared secret.
// If, input ML-KEM-768 public key is malformed, encapsulation will fail, returning false.
[[nodiscard("If public key is malformed, encapsulation fails")]] inline constexpr bool
[[nodiscard("If public key is malformed, encapsulation fails")]] constexpr bool
encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
std::span<const uint8_t, PKEY_BYTE_LEN> pubkey,
std::span<uint8_t, CIPHER_TEXT_BYTE_LEN> cipher,
@@ -55,7 +55,7 @@ encapsulate(std::span<const uint8_t, SEED_M_BYTE_LEN> m,
}
// Given a ML-KEM-768 secret key and a cipher text, this routine computes a fixed size shared secret.
inline constexpr void
constexpr void
decapsulate(std::span<const uint8_t, SKEY_BYTE_LEN> seckey, std::span<const uint8_t, CIPHER_TEXT_BYTE_LEN> cipher, std::span<uint8_t, SHARED_SECRET_BYTE_LEN> shared_secret)
{
ml_kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher, shared_secret);