added .clang-format configuration file; reformatted source tree

Signed-off-by: Anjan Roy <hello@itzmeanjan.in>
This commit is contained in:
Anjan Roy
2023-11-10 22:48:41 +05:30
parent e91593e7e3
commit 4f0d00a168
12 changed files with 305 additions and 172 deletions

225
.clang-format Normal file
View File

@@ -0,0 +1,225 @@
---
Language: Cpp
# BasedOnStyle: Mozilla
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignArrayOfStructures: None
AlignConsecutiveAssignments:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: true
AlignConsecutiveBitFields:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: false
AlignConsecutiveDeclarations:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: false
AlignConsecutiveMacros:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: false
AlignEscapedNewlines: Right
AlignOperands: Align
AlignTrailingComments:
Kind: Always
OverEmptyLines: 0
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortEnumsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: TopLevel
AlwaysBreakAfterReturnType: TopLevel
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: Yes
AttributeMacros:
- __capability
BinPackArguments: false
BinPackParameters: false
BitFieldColonSpacing: Both
BraceWrapping:
AfterCaseLabel: false
AfterClass: true
AfterControlStatement: Never
AfterEnum: true
AfterExternBlock: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: true
AfterUnion: true
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: true
SplitEmptyRecord: false
SplitEmptyNamespace: true
BreakAfterAttributes: Never
BreakAfterJavaFieldAnnotations: false
BreakArrays: true
BreakBeforeBinaryOperators: None
BreakBeforeConceptDeclarations: Always
BreakBeforeBraces: Mozilla
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeComma
BreakInheritanceList: BeforeComma
BreakStringLiterals: true
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerIndentWidth: 2
ContinuationIndentWidth: 2
Cpp11BracedListStyle: false
DerivePointerAlignment: false
DisableFormat: false
EmptyLineAfterAccessModifier: Never
EmptyLineBeforeAccessModifier: LogicalBlock
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: false
ForEachMacros:
- foreach
- Q_FOREACH
- BOOST_FOREACH
IfMacros:
- KJ_IF_MAYBE
IncludeBlocks: Preserve
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
SortPriority: 0
CaseSensitive: false
- Regex: '^(<|"(gtest|gmock|isl|json)/)'
Priority: 3
SortPriority: 0
CaseSensitive: false
- Regex: '.*'
Priority: 1
SortPriority: 0
CaseSensitive: false
IncludeIsMainRegex: '(Test)?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
IndentCaseBlocks: false
IndentCaseLabels: true
IndentExternBlock: AfterExternBlock
IndentGotoLabels: true
IndentPPDirectives: None
IndentRequiresClause: true
IndentWidth: 2
IndentWrappedFunctionNames: false
InsertBraces: false
InsertNewlineAtEOF: false
InsertTrailingCommas: None
IntegerLiteralSeparator:
Binary: 0
BinaryMinDigits: 0
Decimal: 0
DecimalMinDigits: 0
Hex: 0
HexMinDigits: 0
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: true
LambdaBodyIndentation: Signature
LineEnding: DeriveLF
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 2
ObjCBreakBeforeNestedBlockParam: true
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: false
PackConstructorInitializers: BinPack
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakOpenParenthesis: 0
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyIndentedWhitespace: 0
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PPIndentWidth: -1
QualifierAlignment: Leave
ReferenceAlignment: Pointer
ReflowComments: true
RemoveBracesLLVM: false
RemoveSemicolon: false
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope
SeparateDefinitionBlocks: Leave
ShortNamespaceLines: 1
SortIncludes: CaseSensitive
SortJavaStaticImport: Before
SortUsingDeclarations: LexicographicNumeric
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: false
SpaceAroundPointerQualifiers: Default
SpaceBeforeAssignmentOperators: true
SpaceBeforeCaseColon: false
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeParensOptions:
AfterControlStatements: true
AfterForeachMacros: true
AfterFunctionDefinitionName: false
AfterFunctionDeclarationName: false
AfterIfMacros: true
AfterOverloadedOperator: false
AfterRequiresInClause: false
AfterRequiresInExpression: false
BeforeNonEmptyParentheses: false
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: Never
SpacesInConditionalStatement: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInLineCommentPrefix:
Minimum: 1
Maximum: -1
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Latest
StatementAttributeLikeMacros:
- Q_EMIT
StatementMacros:
- Q_UNUSED
- QT_REQUIRE_VERSION
TabWidth: 8
UseTab: Never
WhitespaceSensitiveMacros:
- BOOST_PP_STRINGIZE
- CF_SWIFT_NAME
- NS_SWIFT_NAME
- PP_STRINGIZE
- STRINGIZE
...

View File

@@ -61,4 +61,4 @@ clean:
rm -rf $(BUILD_DIR)
format: $(KYBER_SOURCES) $(TEST_SOURCES) $(BENCHMARK_SOURCES)
clang-format -i --style=Mozilla $^
clang-format -i $^

View File

@@ -28,10 +28,7 @@ struct zq_t
public:
// Given a 16 -bit unsigned integer `a`, this function constructs a Zq
// element, such that `a` is reduced modulo Q.
inline constexpr zq_t(const uint16_t a = 0u)
{
this->v = barrett_reduce(static_cast<uint32_t>(a));
}
inline constexpr zq_t(const uint16_t a = 0u) { this->v = barrett_reduce(static_cast<uint32_t>(a)); }
// Returns canonical value held under Zq type. Returned value must ∈ [0, Q).
inline constexpr uint32_t raw() const { return this->v; }
@@ -43,10 +40,7 @@ public:
static inline constexpr zq_t zero() { return zq_t(); }
// Modulo addition of two Zq elements.
inline constexpr zq_t operator+(const zq_t rhs) const
{
return zq_t(this->v + rhs.v);
}
inline constexpr zq_t operator+(const zq_t rhs) const { return zq_t(this->v + rhs.v); }
// Compound modulo addition of two Zq elements.
inline constexpr void operator+=(const zq_t rhs) { *this = *this + rhs; }
@@ -55,10 +49,7 @@ public:
inline constexpr zq_t operator-() const { return zq_t(Q - this->v); }
// Modulo subtraction of one Zq element from another one.
inline constexpr zq_t operator-(const zq_t rhs) const
{
return *this + (-rhs);
}
inline constexpr zq_t operator-(const zq_t rhs) const { return *this + (-rhs); }
// Compound modulo subtraction of two Zq elements.
inline constexpr void operator-=(const zq_t rhs) { *this = *this - rhs; }
@@ -94,26 +85,17 @@ public:
//
// Note, if Zq element is 0, we can't compute multiplicative inverse and 0 is
// returned.
inline constexpr zq_t inv() const
{
return *this ^ static_cast<size_t>((Q - 2));
}
inline constexpr zq_t inv() const { return *this ^ static_cast<size_t>((Q - 2)); }
// Modulo division of two Zq elements.
//
// Note, if denominator is 0, returned result is 0 too, becaue we can't
// compute multiplicative inverse of 0.
inline constexpr zq_t operator/(const zq_t rhs) const
{
return *this * rhs.inv();
}
inline constexpr zq_t operator/(const zq_t rhs) const { return *this * rhs.inv(); }
// Compare two Zq elements, returning truth value, in case they are same,
// otherwise returns false value.
inline constexpr bool operator==(const zq_t rhs) const
{
return this->v == rhs.v;
}
inline constexpr bool operator==(const zq_t rhs) const { return this->v == rhs.v; }
// Samples a random Zq element, using pseudo random number generator.
static inline zq_t random(prng::prng_t& prng)

View File

@@ -75,10 +75,9 @@ keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
// benchmarking underlying KEM's encapsulation implementation.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
static inline shake256::shake256_t
encapsulate(
std::span<const uint8_t, 32> m,
std::span<const uint8_t, kyber_utils::get_kem_public_key_len<k>()> pubkey,
std::span<uint8_t, kyber_utils::get_kem_cipher_len<k, du, dv>()> cipher)
encapsulate(std::span<const uint8_t, 32> m,
std::span<const uint8_t, kyber_utils::get_kem_public_key_len<k>()> pubkey,
std::span<uint8_t, kyber_utils::get_kem_cipher_len<k, du, dv>()> cipher)
requires(kyber_params::check_encap_params(k, eta1, eta2, du, dv))
{
std::array<uint8_t, 64> g_in{};
@@ -145,9 +144,8 @@ encapsulate(
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
static inline shake256::shake256_t
decapsulate(
std::span<const uint8_t, kyber_utils::get_kem_secret_key_len<k>()> seckey,
std::span<const uint8_t, kyber_utils::get_kem_cipher_len<k, du, dv>()> cipher)
decapsulate(std::span<const uint8_t, kyber_utils::get_kem_secret_key_len<k>()> seckey,
std::span<const uint8_t, kyber_utils::get_kem_cipher_len<k, du, dv>()> cipher)
requires(kyber_params::check_decap_params(k, eta1, eta2, du, dv))
{
constexpr size_t sklen = k * 12 * 32;

View File

@@ -58,8 +58,7 @@ encapsulate(std::span<const uint8_t, 32> m,
//
// Returned KDF can be used for deriving shared key of arbitrary bytes length.
inline shake256::shake256_t
decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
std::span<const uint8_t, CIPHER_LEN> cipher)
decapsulate(std::span<const uint8_t, SKEY_LEN> seckey, std::span<const uint8_t, CIPHER_LEN> cipher)
{
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
}

View File

@@ -58,8 +58,7 @@ encapsulate(std::span<const uint8_t, 32> m,
//
// Returned KDF can be used for deriving shared key of arbitrary bytes length.
inline shake256::shake256_t
decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
std::span<const uint8_t, CIPHER_LEN> cipher)
decapsulate(std::span<const uint8_t, SKEY_LEN> seckey, std::span<const uint8_t, CIPHER_LEN> cipher)
{
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
}

View File

@@ -57,8 +57,7 @@ encapsulate(std::span<const uint8_t, 32> m,
//
// Returned KDF can be used for deriving shared key of arbitrary bytes length.
inline shake256::shake256_t
decapsulate(std::span<const uint8_t, SKEY_LEN> seckey,
std::span<const uint8_t, CIPHER_LEN> cipher)
decapsulate(std::span<const uint8_t, SKEY_LEN> seckey, std::span<const uint8_t, CIPHER_LEN> cipher)
{
return kem::decapsulate<k, η1, η2, du, dv>(seckey, cipher);
}

View File

@@ -231,10 +231,7 @@ polymul(std::span<const field::zq_t, N> f, // degree-255 polynomial
for (size_t i = 0; i < cnt; i++) {
const size_t off = i << 1;
basemul(poly_t(f.subspan(off, 2)),
poly_t(g.subspan(off, 2)),
mut_poly_t(h.subspan(off, 2)),
POLY_MUL_ζ_EXP[i]);
basemul(poly_t(f.subspan(off, 2)), poly_t(g.subspan(off, 2)), mut_poly_t(h.subspan(off, 2)), POLY_MUL_ζ_EXP[i]);
}
}

View File

@@ -44,8 +44,7 @@ check_k(const size_t k)
consteval bool
check_l(const size_t l)
{
return (l == 1) || (l == 4) || (l == 5) || (l == 10) || (l == 11) ||
(l == 12);
return (l == 1) || (l == 4) || (l == 5) || (l == 10) || (l == 11) || (l == 12);
}
// Compile-time check to ensure that operand matrices are having compatible
@@ -77,11 +76,7 @@ check_keygen_params(const size_t k, const size_t eta1)
// See algorithm 5 and table 1 of Kyber specification
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
consteval bool
check_encrypt_params(const size_t k,
const size_t η1,
const size_t η2,
const size_t du,
const size_t dv)
check_encrypt_params(const size_t k, const size_t η1, const size_t η2, const size_t du, const size_t dv)
{
bool flg0 = (k == 2) && (η1 == 3) && (η2 == 2) && (du == 10) && (dv == 4);
bool flg1 = (k == 3) && (η1 == 2) && (η2 == 2) && (du == 10) && (dv == 4);
@@ -111,11 +106,7 @@ check_decrypt_params(const size_t k, const size_t du, const size_t dv)
// See algorithm 8 and table 1 of Kyber specification
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
consteval bool
check_encap_params(const size_t k,
const size_t η1,
const size_t η2,
const size_t du,
const size_t dv)
check_encap_params(const size_t k, const size_t η1, const size_t η2, const size_t du, const size_t dv)
{
return check_encrypt_params(k, η1, η2, du, dv);
}
@@ -126,11 +117,7 @@ check_encap_params(const size_t k,
// See algorithm 9 and table 1 of Kyber specification
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
consteval bool
check_decap_params(const size_t k,
const size_t η1,
const size_t η2,
const size_t du,
const size_t dv)
check_decap_params(const size_t k, const size_t η1, const size_t η2, const size_t du, const size_t dv)
{
return check_encap_params(k, η1, η2, du, dv);
}

View File

@@ -33,9 +33,7 @@ matrix_multiply(std::span<const field::zq_t, a_rows * a_cols * ntt::N> a,
const size_t aoff = (i * a_cols + k) * ntt::N;
const size_t boff = (k * b_cols + j) * ntt::N;
ntt::polymul(poly_t(a.subspan(aoff, ntt::N)),
poly_t(b.subspan(boff, ntt::N)),
_tmp);
ntt::polymul(poly_t(a.subspan(aoff, ntt::N)), poly_t(b.subspan(boff, ntt::N)), _tmp);
for (size_t l = 0; l < ntt::N; l++) {
c[coff + l] += tmp[l];
@@ -82,8 +80,7 @@ poly_vec_intt(std::span<field::zq_t, k * ntt::N> vec)
// routine adds it to another polynomial vector of same dimension
template<size_t k>
static inline constexpr void
poly_vec_add_to(std::span<const field::zq_t, k * ntt::N> src,
std::span<field::zq_t, k * ntt::N> dst)
poly_vec_add_to(std::span<const field::zq_t, k * ntt::N> src, std::span<field::zq_t, k * ntt::N> dst)
requires((k == 1) || kyber_params::check_k(k))
{
constexpr size_t cnt = k * ntt::N;
@@ -97,8 +94,7 @@ poly_vec_add_to(std::span<const field::zq_t, k * ntt::N> src,
// routine subtracts it to another polynomial vector of same dimension
template<size_t k>
static inline constexpr void
poly_vec_sub_from(std::span<const field::zq_t, k * ntt::N> src,
std::span<field::zq_t, k * ntt::N> dst)
poly_vec_sub_from(std::span<const field::zq_t, k * ntt::N> src, std::span<field::zq_t, k * ntt::N> dst)
requires((k == 1) || kyber_params::check_k(k))
{
constexpr size_t cnt = k * ntt::N;
@@ -113,8 +109,7 @@ poly_vec_sub_from(std::span<const field::zq_t, k * ntt::N> src,
// (k x 32 x l) -bytes destination array
template<size_t k, size_t l>
static inline void
poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src,
std::span<uint8_t, k * 32 * l> dst)
poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src, std::span<uint8_t, k * 32 * l> dst)
requires(kyber_params::check_k(k))
{
using poly_t = std::span<const field::zq_t, src.size() / k>;
@@ -124,8 +119,7 @@ poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src,
const size_t off0 = i * ntt::N;
const size_t off1 = i * l * 32;
kyber_utils::encode<l>(poly_t(src.subspan(off0, ntt::N)),
serialized_t(dst.subspan(off1, 32 * l)));
kyber_utils::encode<l>(poly_t(src.subspan(off0, ntt::N)), serialized_t(dst.subspan(off1, 32 * l)));
}
}
@@ -134,8 +128,7 @@ poly_vec_encode(std::span<const field::zq_t, k * ntt::N> src,
// k x 1
template<size_t k, size_t l>
static inline void
poly_vec_decode(std::span<const uint8_t, k * 32 * l> src,
std::span<field::zq_t, k * ntt::N> dst)
poly_vec_decode(std::span<const uint8_t, k * 32 * l> src, std::span<field::zq_t, k * ntt::N> dst)
requires(kyber_params::check_k(k))
{
using serialized_t = std::span<const uint8_t, src.size() / k>;
@@ -145,8 +138,7 @@ poly_vec_decode(std::span<const uint8_t, k * 32 * l> src,
const size_t off0 = i * l * 32;
const size_t off1 = i * ntt::N;
kyber_utils::decode<l>(serialized_t(src.subspan(off0, 32 * l)),
poly_t(dst.subspan(off1, ntt::N)));
kyber_utils::decode<l>(serialized_t(src.subspan(off0, 32 * l)), poly_t(dst.subspan(off1, ntt::N)));
}
}

View File

@@ -31,10 +31,8 @@ parse(shake128::shake128_t& hasher, std::span<field::zq_t, ntt::N> poly)
hasher.squeeze(buf);
for (size_t off = 0; (off < buf.size()) && (coeff_idx < n); off += 3) {
const uint16_t d1 = (static_cast<uint16_t>(buf[off + 1] & 0x0f) << 8) |
(static_cast<uint16_t>(buf[off + 0]) << 0);
const uint16_t d2 = (static_cast<uint16_t>(buf[off + 2]) << 4) |
(static_cast<uint16_t>(buf[off + 1] >> 4));
const uint16_t d1 = (static_cast<uint16_t>(buf[off + 1] & 0x0f) << 8) | static_cast<uint16_t>(buf[off + 0]);
const uint16_t d2 = (static_cast<uint16_t>(buf[off + 2]) << 4) | (static_cast<uint16_t>(buf[off + 1] >> 4));
if (d1 < field::Q) {
poly[coeff_idx] = field::zq_t(d1);
@@ -57,8 +55,7 @@ parse(shake128::shake128_t& hasher, std::span<field::zq_t, ntt::N> poly)
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
template<size_t k, bool transpose>
static inline void
generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat,
std::span<const uint8_t, 32> rho)
generate_matrix(std::span<field::zq_t, k * k * ntt::N> mat, std::span<const uint8_t, 32> rho)
requires(kyber_params::check_k(k))
{
std::array<uint8_t, rho.size() + 2> xof_in{};
@@ -113,10 +110,8 @@ cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
const uint8_t t1 = (word >> 1) & mask8;
const uint8_t t2 = t0 + t1;
poly[poff + 0] =
field::zq_t((t2 >> 0) & mask2) - field::zq_t((t2 >> 2) & mask2);
poly[poff + 1] =
field::zq_t((t2 >> 4) & mask2) - field::zq_t((t2 >> 6) & mask2);
poly[poff + 0] = field::zq_t((t2 >> 0) & mask2) - field::zq_t((t2 >> 2) & mask2);
poly[poff + 1] = field::zq_t((t2 >> 4) & mask2) - field::zq_t((t2 >> 6) & mask2);
}
} else {
static_assert(eta == 3, "η must be 3 !");
@@ -129,23 +124,18 @@ cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
const size_t boff = i * 3;
const size_t poff = i << 2;
const uint32_t word = (static_cast<uint32_t>(prf[boff + 2]) << 16) |
(static_cast<uint32_t>(prf[boff + 1]) << 8) |
(static_cast<uint32_t>(prf[boff + 0]) << 0);
const uint32_t word = (static_cast<uint32_t>(prf[boff + 2]) << 16) | (static_cast<uint32_t>(prf[boff + 1]) << 8) |
static_cast<uint32_t>(prf[boff + 0]);
const uint32_t t0 = (word >> 0) & mask24;
const uint32_t t1 = (word >> 1) & mask24;
const uint32_t t2 = (word >> 2) & mask24;
const uint32_t t3 = t0 + t1 + t2;
poly[poff + 0] =
field::zq_t((t3 >> 0) & mask3) - field::zq_t((t3 >> 3) & mask3);
poly[poff + 1] =
field::zq_t((t3 >> 6) & mask3) - field::zq_t((t3 >> 9) & mask3);
poly[poff + 2] =
field::zq_t((t3 >> 12) & mask3) - field::zq_t((t3 >> 15) & mask3);
poly[poff + 3] =
field::zq_t((t3 >> 18) & mask3) - field::zq_t((t3 >> 21) & mask3);
poly[poff + 0] = field::zq_t((t3 >> 0) & mask3) - field::zq_t((t3 >> 3) & mask3);
poly[poff + 1] = field::zq_t((t3 >> 6) & mask3) - field::zq_t((t3 >> 9) & mask3);
poly[poff + 2] = field::zq_t((t3 >> 12) & mask3) - field::zq_t((t3 >> 15) & mask3);
poly[poff + 3] = field::zq_t((t3 >> 18) & mask3) - field::zq_t((t3 >> 21) & mask3);
}
}
}
@@ -155,9 +145,7 @@ cbd(std::span<const uint8_t, 64 * eta> prf, std::span<field::zq_t, ntt::N> poly)
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
template<size_t k, size_t eta>
static inline void
generate_vector(std::span<field::zq_t, k * ntt::N> vec,
std::span<const uint8_t, 32> sigma,
const uint8_t nonce)
generate_vector(std::span<field::zq_t, k * ntt::N> vec, std::span<const uint8_t, 32> sigma, const uint8_t nonce)
requires((k == 1) || kyber_params::check_k(k))
{
std::array<uint8_t, 64 * eta> prf_out{};

View File

@@ -15,8 +15,7 @@ namespace kyber_utils {
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
template<size_t l>
static inline void
encode(std::span<const field::zq_t, ntt::N> poly,
std::span<uint8_t, 32 * l> arr)
encode(std::span<const field::zq_t, ntt::N> poly, std::span<uint8_t, 32 * l> arr)
requires(kyber_params::check_l(l))
{
std::fill(arr.begin(), arr.end(), 0);
@@ -42,8 +41,7 @@ encode(std::span<const field::zq_t, ntt::N> poly,
for (size_t i = 0; i < itr_cnt; i++) {
const size_t off = i << 1;
arr[i] = (static_cast<uint8_t>(poly[off + 1].raw() & msk) << 4) |
(static_cast<uint8_t>(poly[off + 0].raw() & msk) << 0);
arr[i] = (static_cast<uint8_t>(poly[off + 1].raw() & msk) << 4) | static_cast<uint8_t>(poly[off + 0].raw() & msk);
}
} else if constexpr (l == 5) {
constexpr size_t itr_cnt = ntt::N >> 3;
@@ -66,18 +64,13 @@ encode(std::span<const field::zq_t, ntt::N> poly,
const auto t6 = poly[poff + 6].raw();
const auto t7 = poly[poff + 7].raw();
arr[boff + 0] = (static_cast<uint8_t>(t1 & mask3) << 5) |
(static_cast<uint8_t>(t0 & mask5) << 0);
arr[boff + 1] = (static_cast<uint8_t>(t3 & mask1) << 7) |
(static_cast<uint8_t>(t2 & mask5) << 2) |
arr[boff + 0] = (static_cast<uint8_t>(t1 & mask3) << 5) | (static_cast<uint8_t>(t0 & mask5) << 0);
arr[boff + 1] = (static_cast<uint8_t>(t3 & mask1) << 7) | (static_cast<uint8_t>(t2 & mask5) << 2) |
static_cast<uint8_t>((t1 >> 3) & mask2);
arr[boff + 2] = (static_cast<uint8_t>(t4 & mask4) << 4) |
static_cast<uint8_t>((t3 >> 1) & mask4);
arr[boff + 3] = (static_cast<uint8_t>(t6 & mask2) << 6) |
(static_cast<uint8_t>(t5 & mask5) << 1) |
arr[boff + 2] = (static_cast<uint8_t>(t4 & mask4) << 4) | static_cast<uint8_t>((t3 >> 1) & mask4);
arr[boff + 3] = (static_cast<uint8_t>(t6 & mask2) << 6) | (static_cast<uint8_t>(t5 & mask5) << 1) |
static_cast<uint8_t>((t4 >> 4) & mask1);
arr[boff + 4] = (static_cast<uint8_t>(t7 & mask5) << 3) |
static_cast<uint8_t>((t6 >> 2) & mask3);
arr[boff + 4] = (static_cast<uint8_t>(t7 & mask5) << 3) | static_cast<uint8_t>((t6 >> 2) & mask3);
}
} else if constexpr (l == 10) {
constexpr size_t itr_cnt = ntt::N >> 2;
@@ -95,12 +88,9 @@ encode(std::span<const field::zq_t, ntt::N> poly,
const auto t3 = poly[poff + 3].raw();
arr[boff + 0] = static_cast<uint8_t>(t0);
arr[boff + 1] = static_cast<uint8_t>((t1 & mask6) << 2) |
static_cast<uint8_t>((t0 >> 8) & mask2);
arr[boff + 2] = static_cast<uint8_t>((t2 & mask4) << 4) |
static_cast<uint8_t>((t1 >> 6) & mask4);
arr[boff + 3] = static_cast<uint8_t>((t3 & mask2) << 6) |
static_cast<uint8_t>((t2 >> 4) & mask6);
arr[boff + 1] = static_cast<uint8_t>((t1 & mask6) << 2) | static_cast<uint8_t>((t0 >> 8) & mask2);
arr[boff + 2] = static_cast<uint8_t>((t2 & mask4) << 4) | static_cast<uint8_t>((t1 >> 6) & mask4);
arr[boff + 3] = static_cast<uint8_t>((t3 & mask2) << 6) | static_cast<uint8_t>((t2 >> 4) & mask6);
arr[boff + 4] = static_cast<uint8_t>(t3 >> 2);
}
} else if constexpr (l == 11) {
@@ -128,22 +118,15 @@ encode(std::span<const field::zq_t, ntt::N> poly,
const auto t7 = poly[poff + 7].raw();
arr[boff + 0] = static_cast<uint8_t>(t0 & mask8);
arr[boff + 1] = static_cast<uint8_t>((t1 & mask5) << 3) |
static_cast<uint8_t>((t0 >> 8) & mask3);
arr[boff + 2] = static_cast<uint8_t>((t2 & mask2) << 6) |
static_cast<uint8_t>((t1 >> 5) & mask6);
arr[boff + 1] = static_cast<uint8_t>((t1 & mask5) << 3) | static_cast<uint8_t>((t0 >> 8) & mask3);
arr[boff + 2] = static_cast<uint8_t>((t2 & mask2) << 6) | static_cast<uint8_t>((t1 >> 5) & mask6);
arr[boff + 3] = static_cast<uint8_t>((t2 >> 2) & mask8);
arr[boff + 4] = static_cast<uint8_t>((t3 & mask7) << 1) |
static_cast<uint8_t>((t2 >> 10) & mask1);
arr[boff + 5] = static_cast<uint8_t>((t4 & mask4) << 4) |
static_cast<uint8_t>((t3 >> 7) & mask4);
arr[boff + 6] = static_cast<uint8_t>((t5 & mask1) << 7) |
static_cast<uint8_t>((t4 >> 4) & mask7);
arr[boff + 4] = static_cast<uint8_t>((t3 & mask7) << 1) | static_cast<uint8_t>((t2 >> 10) & mask1);
arr[boff + 5] = static_cast<uint8_t>((t4 & mask4) << 4) | static_cast<uint8_t>((t3 >> 7) & mask4);
arr[boff + 6] = static_cast<uint8_t>((t5 & mask1) << 7) | static_cast<uint8_t>((t4 >> 4) & mask7);
arr[boff + 7] = static_cast<uint8_t>((t5 >> 1) & mask8);
arr[boff + 8] = static_cast<uint8_t>((t6 & mask6) << 2) |
static_cast<uint8_t>((t5 >> 9) & mask2);
arr[boff + 9] = static_cast<uint8_t>((t7 & mask3) << 5) |
static_cast<uint8_t>((t6 >> 6) & mask5);
arr[boff + 8] = static_cast<uint8_t>((t6 & mask6) << 2) | static_cast<uint8_t>((t5 >> 9) & mask2);
arr[boff + 9] = static_cast<uint8_t>((t7 & mask3) << 5) | static_cast<uint8_t>((t6 >> 6) & mask5);
arr[boff + 10] = static_cast<uint8_t>((t7 >> 3) & mask8);
}
} else {
@@ -160,8 +143,7 @@ encode(std::span<const field::zq_t, ntt::N> poly,
const auto t1 = poly[poff + 1].raw();
arr[boff + 0] = static_cast<uint8_t>(t0);
arr[boff + 1] = static_cast<uint8_t>((t1 & mask4) << 4) |
static_cast<uint8_t>((t0 >> 8) & mask4);
arr[boff + 1] = static_cast<uint8_t>((t1 & mask4) << 4) | static_cast<uint8_t>((t0 >> 8) & mask4);
arr[boff + 2] = static_cast<uint8_t>(t1 >> 4);
}
}
@@ -175,8 +157,7 @@ encode(std::span<const field::zq_t, ntt::N> poly,
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
template<size_t l>
static inline void
decode(std::span<const uint8_t, 32 * l> arr,
std::span<field::zq_t, ntt::N> poly)
decode(std::span<const uint8_t, 32 * l> arr, std::span<field::zq_t, ntt::N> poly)
requires(kyber_params::check_l(l))
{
if constexpr (l == 1) {
@@ -220,16 +201,16 @@ decode(std::span<const uint8_t, 32 * l> arr,
const size_t boff = i * 5;
const auto t0 = static_cast<uint16_t>(arr[boff + 0] & mask5);
const auto t1 = static_cast<uint16_t>((arr[boff + 1] & mask2) << 3) |
static_cast<uint16_t>((arr[boff + 0] >> 5) & mask3);
const auto t1 =
static_cast<uint16_t>((arr[boff + 1] & mask2) << 3) | static_cast<uint16_t>((arr[boff + 0] >> 5) & mask3);
const auto t2 = static_cast<uint16_t>((arr[boff + 1] >> 2) & mask5);
const auto t3 = static_cast<uint16_t>((arr[boff + 2] & mask4) << 1) |
static_cast<uint16_t>((arr[boff + 1] >> 7) & mask1);
const auto t4 = static_cast<uint16_t>((arr[boff + 3] & mask1) << 4) |
static_cast<uint16_t>((arr[boff + 2] >> 4) & mask4);
const auto t3 =
static_cast<uint16_t>((arr[boff + 2] & mask4) << 1) | static_cast<uint16_t>((arr[boff + 1] >> 7) & mask1);
const auto t4 =
static_cast<uint16_t>((arr[boff + 3] & mask1) << 4) | static_cast<uint16_t>((arr[boff + 2] >> 4) & mask4);
const auto t5 = static_cast<uint16_t>((arr[boff + 3] >> 1) & mask5);
const auto t6 = static_cast<uint16_t>((arr[boff + 4] & mask3) << 2) |
static_cast<uint16_t>((arr[boff + 3] >> 6) & mask2);
const auto t6 =
static_cast<uint16_t>((arr[boff + 4] & mask3) << 2) | static_cast<uint16_t>((arr[boff + 3] >> 6) & mask2);
const auto t7 = static_cast<uint16_t>((arr[boff + 4] >> 3) & mask5);
poly[poff + 0] = field::zq_t(t0);
@@ -251,14 +232,10 @@ decode(std::span<const uint8_t, 32 * l> arr,
const size_t poff = i << 2;
const size_t boff = i * 5;
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask2) << 8) |
static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2] & mask4) << 6) |
static_cast<uint16_t>(arr[boff + 1] >> 2);
const auto t2 = (static_cast<uint16_t>(arr[boff + 3] & mask6) << 4) |
static_cast<uint16_t>(arr[boff + 2] >> 4);
const auto t3 = (static_cast<uint16_t>(arr[boff + 4]) << 2) |
static_cast<uint16_t>(arr[boff + 3] >> 6);
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask2) << 8) | static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2] & mask4) << 6) | static_cast<uint16_t>(arr[boff + 1] >> 2);
const auto t2 = (static_cast<uint16_t>(arr[boff + 3] & mask6) << 4) | static_cast<uint16_t>(arr[boff + 2] >> 4);
const auto t3 = (static_cast<uint16_t>(arr[boff + 4]) << 2) | static_cast<uint16_t>(arr[boff + 3] >> 6);
poly[poff + 0] = field::zq_t(t0);
poly[poff + 1] = field::zq_t(t1);
@@ -279,24 +256,16 @@ decode(std::span<const uint8_t, 32 * l> arr,
const size_t poff = i << 3;
const size_t boff = i * 11;
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask3) << 8) |
static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2] & mask6) << 5) |
static_cast<uint16_t>(arr[boff + 1] >> 3);
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask3) << 8) | static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2] & mask6) << 5) | static_cast<uint16_t>(arr[boff + 1] >> 3);
const auto t2 = (static_cast<uint16_t>(arr[boff + 4] & mask1) << 10) |
(static_cast<uint16_t>(arr[boff + 3]) << 2) |
static_cast<uint16_t>(arr[boff + 2] >> 6);
const auto t3 = (static_cast<uint16_t>(arr[boff + 5] & mask4) << 7) |
static_cast<uint16_t>(arr[boff + 4] >> 1);
const auto t4 = (static_cast<uint16_t>(arr[boff + 6] & mask7) << 4) |
static_cast<uint16_t>(arr[boff + 5] >> 4);
(static_cast<uint16_t>(arr[boff + 3]) << 2) | static_cast<uint16_t>(arr[boff + 2] >> 6);
const auto t3 = (static_cast<uint16_t>(arr[boff + 5] & mask4) << 7) | static_cast<uint16_t>(arr[boff + 4] >> 1);
const auto t4 = (static_cast<uint16_t>(arr[boff + 6] & mask7) << 4) | static_cast<uint16_t>(arr[boff + 5] >> 4);
const auto t5 = (static_cast<uint16_t>(arr[boff + 8] & mask2) << 9) |
(static_cast<uint16_t>(arr[boff + 7]) << 1) |
static_cast<uint16_t>(arr[boff + 6] >> 7);
const auto t6 = (static_cast<uint16_t>(arr[boff + 9] & mask5) << 6) |
static_cast<uint16_t>(arr[boff + 8] >> 2);
const auto t7 = (static_cast<uint16_t>(arr[boff + 10]) << 3) |
static_cast<uint16_t>(arr[boff + 9] >> 5);
(static_cast<uint16_t>(arr[boff + 7]) << 1) | static_cast<uint16_t>(arr[boff + 6] >> 7);
const auto t6 = (static_cast<uint16_t>(arr[boff + 9] & mask5) << 6) | static_cast<uint16_t>(arr[boff + 8] >> 2);
const auto t7 = (static_cast<uint16_t>(arr[boff + 10]) << 3) | static_cast<uint16_t>(arr[boff + 9] >> 5);
poly[poff + 0] = field::zq_t(t0);
poly[poff + 1] = field::zq_t(t1);
@@ -317,10 +286,8 @@ decode(std::span<const uint8_t, 32 * l> arr,
const size_t poff = i << 1;
const size_t boff = i * 3;
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask4) << 8) |
static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2]) << 4) |
static_cast<uint16_t>(arr[boff + 1] >> 4);
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask4) << 8) | static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2]) << 4) | static_cast<uint16_t>(arr[boff + 1] >> 4);
poly[poff + 0] = field::zq_t(t0);
poly[poff + 1] = field::zq_t(t1);