mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 15:38:01 -05:00
Update barretenberg to device code
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
#include "../bn254/fq.hpp"
|
||||
#include "../bn254/fq12.hpp"
|
||||
#include "../bn254/fq2.hpp"
|
||||
#include "../bn254/fr.hpp"
|
||||
#include "../bn254/g1.hpp"
|
||||
#include "../bn254/g2.hpp"
|
||||
#include "../bn254/fq.cuh"
|
||||
#include "../bn254/fq12.cuh"
|
||||
#include "../bn254/fq2.cuh"
|
||||
#include "../bn254/fr.cuh"
|
||||
#include "../bn254/g1.cuh"
|
||||
#include "../bn254/g2.cuh"
|
||||
|
||||
namespace bb::curve {
|
||||
class BN254 {
|
||||
@@ -3,7 +3,7 @@
|
||||
#include <cstdint>
|
||||
#include <iomanip>
|
||||
|
||||
#include "../../fields/field.hpp"
|
||||
#include "../../fields/field.cuh"
|
||||
|
||||
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
|
||||
namespace bb {
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../fields/field2.hpp"
|
||||
#include "./fq.hpp"
|
||||
#include "../../fields/field2.cuh"
|
||||
#include "./fq.cuh"
|
||||
|
||||
namespace bb {
|
||||
struct Bn254Fq2Params {
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <iomanip>
|
||||
#include <ostream>
|
||||
|
||||
#include "../../fields/field.hpp"
|
||||
#include "../../fields/field.cuh"
|
||||
|
||||
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../groups/group.hpp"
|
||||
#include "./fq.hpp"
|
||||
#include "./fr.hpp"
|
||||
#include "../../groups/group.cuh"
|
||||
#include "./fq.cuh"
|
||||
#include "./fr.cuh"
|
||||
|
||||
namespace bb {
|
||||
struct Bn254G1Params {
|
||||
@@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../groups/group.hpp"
|
||||
#include "./fq2.hpp"
|
||||
#include "./fr.hpp"
|
||||
#include "../../groups/group.cuh"
|
||||
#include "./fq2.cuh"
|
||||
#include "./fr.cuh"
|
||||
|
||||
namespace bb {
|
||||
struct Bn254G2Params {
|
||||
@@ -6,5 +6,5 @@
|
||||
* declarations header) Spectialized definitions are in "field_impl_generic.hpp" and "field_impl_x64.hpp"
|
||||
* (which include "field_impl.hpp")
|
||||
*/
|
||||
#include "./field_impl_generic.hpp"
|
||||
#include "./field_impl_x64.hpp"
|
||||
#include "./field_impl_generic.cuh"
|
||||
#include "./field_impl_x64.cuh"
|
||||
@@ -1,9 +1,8 @@
|
||||
#pragma once
|
||||
#include "../../common/assert.hpp"
|
||||
#include "../../common/compiler_hints.hpp"
|
||||
#include "../../numeric/random/engine.hpp"
|
||||
#include "../../numeric/uint128/uint128.hpp"
|
||||
#include "../../numeric/uint256/uint256_impl.hpp"
|
||||
#include "../../numeric/uint256/uint256_impl.cuh"
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
@@ -52,33 +51,33 @@ template <class Params_> struct alignas(32) field {
|
||||
// std::array<field, N> arr {}; // zero-initialized, preferable for moderate N
|
||||
field() = default;
|
||||
|
||||
constexpr field(const numeric::uint256_t& input) noexcept
|
||||
__device__ constexpr field(const numeric::uint256_t& input) noexcept
|
||||
: data{ input.data[0], input.data[1], input.data[2], input.data[3] }
|
||||
{
|
||||
self_to_montgomery_form();
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE (unsigned long is platform dependent, which we want in this case)
|
||||
constexpr field(const unsigned long input) noexcept
|
||||
__device__ constexpr field(const unsigned long input) noexcept
|
||||
: data{ input, 0, 0, 0 }
|
||||
{
|
||||
self_to_montgomery_form();
|
||||
}
|
||||
|
||||
constexpr field(const unsigned int input) noexcept
|
||||
__device__ constexpr field(const unsigned int input) noexcept
|
||||
: data{ input, 0, 0, 0 }
|
||||
{
|
||||
self_to_montgomery_form();
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE (unsigned long long is platform dependent, which we want in this case)
|
||||
constexpr field(const unsigned long long input) noexcept
|
||||
__device__ constexpr field(const unsigned long long input) noexcept
|
||||
: data{ input, 0, 0, 0 }
|
||||
{
|
||||
self_to_montgomery_form();
|
||||
}
|
||||
|
||||
constexpr field(const int input) noexcept
|
||||
__device__ constexpr field(const int input) noexcept
|
||||
: data{ 0, 0, 0, 0 }
|
||||
{
|
||||
if (input < 0) {
|
||||
@@ -98,63 +97,47 @@ template <class Params_> struct alignas(32) field {
|
||||
}
|
||||
}
|
||||
|
||||
constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
|
||||
__device__ constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
|
||||
: data{ a, b, c, d } {};
|
||||
|
||||
/**
|
||||
* @brief Convert a 512-bit big integer into a field element.
|
||||
*
|
||||
* @details Used for deriving field elements from random values. 512-bits prevents biased output as 2^512>>modulus
|
||||
*
|
||||
*/
|
||||
constexpr explicit field(const uint512_t& input) noexcept
|
||||
{
|
||||
uint256_t value = (input % modulus).lo;
|
||||
data[0] = value.data[0];
|
||||
data[1] = value.data[1];
|
||||
data[2] = value.data[2];
|
||||
data[3] = value.data[3];
|
||||
self_to_montgomery_form();
|
||||
}
|
||||
|
||||
constexpr explicit field(std::string input) noexcept
|
||||
__device__ constexpr explicit field(std::string input) noexcept
|
||||
{
|
||||
uint256_t value(input);
|
||||
*this = field(value);
|
||||
}
|
||||
|
||||
constexpr explicit operator bool() const
|
||||
__device__ constexpr explicit operator bool() const
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
ASSERT(out.data[0] == 0 || out.data[0] == 1);
|
||||
return static_cast<bool>(out.data[0]);
|
||||
}
|
||||
|
||||
constexpr explicit operator uint8_t() const
|
||||
__device__ constexpr explicit operator uint8_t() const
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
return static_cast<uint8_t>(out.data[0]);
|
||||
}
|
||||
|
||||
constexpr explicit operator uint16_t() const
|
||||
__device__ constexpr explicit operator uint16_t() const
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
return static_cast<uint16_t>(out.data[0]);
|
||||
}
|
||||
|
||||
constexpr explicit operator uint32_t() const
|
||||
__device__ constexpr explicit operator uint32_t() const
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
return static_cast<uint32_t>(out.data[0]);
|
||||
}
|
||||
|
||||
constexpr explicit operator uint64_t() const
|
||||
__device__ constexpr explicit operator uint64_t() const
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
return out.data[0];
|
||||
}
|
||||
|
||||
constexpr explicit operator uint128_t() const
|
||||
__device__ constexpr explicit operator uint128_t() const
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
uint128_t lo = out.data[0];
|
||||
@@ -162,7 +145,7 @@ template <class Params_> struct alignas(32) field {
|
||||
return (hi << 64) | lo;
|
||||
}
|
||||
|
||||
constexpr operator uint256_t() const noexcept
|
||||
__device__ constexpr operator uint256_t() const noexcept
|
||||
{
|
||||
field out = from_montgomery_form();
|
||||
return uint256_t(out.data[0], out.data[1], out.data[2], out.data[3]);
|
||||
@@ -180,8 +163,9 @@ template <class Params_> struct alignas(32) field {
|
||||
constexpr ~field() noexcept = default;
|
||||
alignas(32) uint64_t data[4]; // NOLINT
|
||||
|
||||
static constexpr uint256_t modulus =
|
||||
uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
|
||||
static constexpr __device__ uint256_t get_modulus() {
|
||||
return uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
|
||||
}
|
||||
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
|
||||
static constexpr uint256_t r_squared_uint{
|
||||
Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3
|
||||
@@ -287,45 +271,45 @@ template <class Params_> struct alignas(32) field {
|
||||
return result;
|
||||
}
|
||||
|
||||
BB_INLINE constexpr field operator*(const field& other) const noexcept;
|
||||
BB_INLINE constexpr field operator+(const field& other) const noexcept;
|
||||
BB_INLINE constexpr field operator-(const field& other) const noexcept;
|
||||
BB_INLINE constexpr field operator-() const noexcept;
|
||||
constexpr field operator/(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr field operator*(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr field operator+(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr field operator-(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr field operator-() const noexcept;
|
||||
__device__ constexpr field operator/(const field& other) const noexcept;
|
||||
|
||||
// prefix increment (++x)
|
||||
BB_INLINE constexpr field operator++() noexcept;
|
||||
BB_INLINE __device__ constexpr field operator++() noexcept;
|
||||
// postfix increment (x++)
|
||||
// NOLINTNEXTLINE
|
||||
BB_INLINE constexpr field operator++(int) noexcept;
|
||||
BB_INLINE __device__ constexpr field operator++(int) noexcept;
|
||||
|
||||
BB_INLINE constexpr field& operator*=(const field& other) noexcept;
|
||||
BB_INLINE constexpr field& operator+=(const field& other) noexcept;
|
||||
BB_INLINE constexpr field& operator-=(const field& other) noexcept;
|
||||
constexpr field& operator/=(const field& other) noexcept;
|
||||
BB_INLINE __device__ constexpr field& operator*=(const field& other) noexcept;
|
||||
BB_INLINE __device__ constexpr field& operator+=(const field& other) noexcept;
|
||||
BB_INLINE __device__ constexpr field& operator-=(const field& other) noexcept;
|
||||
__device__ constexpr field& operator/=(const field& other) noexcept;
|
||||
|
||||
// NOTE: comparison operators exist so that `field` is comparible with stl methods that require them.
|
||||
// (e.g. std::sort)
|
||||
// Finite fields do not have an explicit ordering, these should *NEVER* be used in algebraic algorithms.
|
||||
BB_INLINE constexpr bool operator>(const field& other) const noexcept;
|
||||
BB_INLINE constexpr bool operator<(const field& other) const noexcept;
|
||||
BB_INLINE constexpr bool operator==(const field& other) const noexcept;
|
||||
BB_INLINE constexpr bool operator!=(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr bool operator>(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr bool operator<(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr bool operator==(const field& other) const noexcept;
|
||||
BB_INLINE __device__ constexpr bool operator!=(const field& other) const noexcept;
|
||||
|
||||
BB_INLINE constexpr field to_montgomery_form() const noexcept;
|
||||
BB_INLINE constexpr field from_montgomery_form() const noexcept;
|
||||
BB_INLINE __device__ constexpr field to_montgomery_form() const noexcept;
|
||||
BB_INLINE __device__ constexpr field from_montgomery_form() const noexcept;
|
||||
|
||||
BB_INLINE constexpr field sqr() const noexcept;
|
||||
BB_INLINE constexpr void self_sqr() noexcept;
|
||||
BB_INLINE __device__ constexpr field sqr() const noexcept;
|
||||
BB_INLINE __device__ constexpr void self_sqr() noexcept;
|
||||
|
||||
BB_INLINE constexpr field pow(const uint256_t& exponent) const noexcept;
|
||||
BB_INLINE constexpr field pow(uint64_t exponent) const noexcept;
|
||||
BB_INLINE __device__ constexpr field pow(const uint256_t& exponent) const noexcept;
|
||||
BB_INLINE __device__ constexpr field pow(uint64_t exponent) const noexcept;
|
||||
static_assert(Params::modulus_0 != 1);
|
||||
static constexpr uint256_t modulus_minus_two =
|
||||
uint256_t(Params::modulus_0 - 2ULL, Params::modulus_1, Params::modulus_2, Params::modulus_3);
|
||||
constexpr field invert() const noexcept;
|
||||
static void batch_invert(std::span<field> coeffs) noexcept;
|
||||
static void batch_invert(field* coeffs, size_t n) noexcept;
|
||||
__device__ constexpr field invert() const noexcept;
|
||||
static __device__ void batch_invert(std::span<field> coeffs) noexcept;
|
||||
static __device__ void batch_invert(field* coeffs, size_t n) noexcept;
|
||||
/**
|
||||
* @brief Compute square root of the field element.
|
||||
*
|
||||
@@ -333,17 +317,17 @@ template <class Params_> struct alignas(32) field {
|
||||
*/
|
||||
constexpr std::pair<bool, field> sqrt() const noexcept;
|
||||
|
||||
BB_INLINE constexpr void self_neg() noexcept;
|
||||
BB_INLINE __device__ constexpr void self_neg() noexcept;
|
||||
|
||||
BB_INLINE constexpr void self_to_montgomery_form() noexcept;
|
||||
BB_INLINE constexpr void self_from_montgomery_form() noexcept;
|
||||
BB_INLINE __device__ constexpr void self_to_montgomery_form() noexcept;
|
||||
BB_INLINE __device__ constexpr void self_from_montgomery_form() noexcept;
|
||||
|
||||
BB_INLINE constexpr void self_conditional_negate(uint64_t predicate) noexcept;
|
||||
BB_INLINE __device__ constexpr void self_conditional_negate(uint64_t predicate) noexcept;
|
||||
|
||||
BB_INLINE constexpr field reduce_once() const noexcept;
|
||||
BB_INLINE constexpr void self_reduce_once() noexcept;
|
||||
BB_INLINE __device__ constexpr field reduce_once() const noexcept;
|
||||
BB_INLINE __device__ constexpr void self_reduce_once() noexcept;
|
||||
|
||||
BB_INLINE constexpr void self_set_msb() noexcept;
|
||||
BB_INLINE __device__ constexpr void self_set_msb() noexcept;
|
||||
[[nodiscard]] BB_INLINE constexpr bool is_msb_set() const noexcept;
|
||||
[[nodiscard]] BB_INLINE constexpr uint64_t is_msb_set_word() const noexcept;
|
||||
|
||||
@@ -472,53 +456,6 @@ template <class Params_> struct alignas(32) field {
|
||||
};
|
||||
}
|
||||
|
||||
static void split_into_endomorphism_scalars_384(const field& input, field& k1_out, field& k2_out)
|
||||
{
|
||||
constexpr field minus_b1f{
|
||||
Params::endo_minus_b1_lo,
|
||||
Params::endo_minus_b1_mid,
|
||||
0,
|
||||
0,
|
||||
};
|
||||
constexpr field b2f{
|
||||
Params::endo_b2_lo,
|
||||
Params::endo_b2_mid,
|
||||
0,
|
||||
0,
|
||||
};
|
||||
constexpr uint256_t g1{
|
||||
Params::endo_g1_lo,
|
||||
Params::endo_g1_mid,
|
||||
Params::endo_g1_hi,
|
||||
Params::endo_g1_hihi,
|
||||
};
|
||||
constexpr uint256_t g2{
|
||||
Params::endo_g2_lo,
|
||||
Params::endo_g2_mid,
|
||||
Params::endo_g2_hi,
|
||||
Params::endo_g2_hihi,
|
||||
};
|
||||
|
||||
field kf = input.reduce_once();
|
||||
uint256_t k{ kf.data[0], kf.data[1], kf.data[2], kf.data[3] };
|
||||
|
||||
uint512_t c1 = (uint512_t(k) * static_cast<uint512_t>(g1)) >> 384;
|
||||
uint512_t c2 = (uint512_t(k) * static_cast<uint512_t>(g2)) >> 384;
|
||||
|
||||
field c1f{ c1.lo.data[0], c1.lo.data[1], c1.lo.data[2], c1.lo.data[3] };
|
||||
field c2f{ c2.lo.data[0], c2.lo.data[1], c2.lo.data[2], c2.lo.data[3] };
|
||||
|
||||
c1f.self_to_montgomery_form();
|
||||
c2f.self_to_montgomery_form();
|
||||
c1f = c1f * minus_b1f;
|
||||
c2f = c2f * b2f;
|
||||
field r2f = c1f - c2f;
|
||||
field beta = cube_root_of_unity();
|
||||
field r1f = input.reduce_once() - r2f * beta;
|
||||
k1_out = r1f;
|
||||
k2_out = -r2f;
|
||||
}
|
||||
|
||||
// static constexpr auto coset_generators = compute_coset_generators();
|
||||
// static constexpr std::array<field, 15> coset_generators = compute_coset_generators((1 << 30U));
|
||||
|
||||
@@ -540,12 +477,10 @@ template <class Params_> struct alignas(32) field {
|
||||
src = T;
|
||||
}
|
||||
|
||||
static field random_element(numeric::RNG* engine = nullptr) noexcept;
|
||||
|
||||
static constexpr field multiplicative_generator() noexcept;
|
||||
|
||||
static constexpr uint256_t twice_modulus = modulus + modulus;
|
||||
static constexpr uint256_t not_modulus = -modulus;
|
||||
static constexpr uint256_t twice_modulus = get_modulus() + get_modulus();
|
||||
static constexpr uint256_t not_modulus = -get_modulus();
|
||||
static constexpr uint256_t twice_not_modulus = -twice_modulus;
|
||||
|
||||
struct wnaf_table {
|
||||
@@ -612,9 +547,9 @@ template <class Params_> struct alignas(32) field {
|
||||
uint64_t& result_8);
|
||||
BB_INLINE static constexpr std::array<uint64_t, WASM_NUM_LIMBS> wasm_convert(const uint64_t* data);
|
||||
#endif
|
||||
BB_INLINE static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b) noexcept;
|
||||
BB_INLINE __device__ static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b) noexcept;
|
||||
|
||||
BB_INLINE static constexpr uint64_t mac(
|
||||
BB_INLINE __device__ static constexpr uint64_t mac(
|
||||
uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& carry_out) noexcept;
|
||||
|
||||
BB_INLINE static constexpr void mac(
|
||||
@@ -3,13 +3,12 @@
|
||||
#include "../../common/slab_allocator.hpp"
|
||||
#include "../../common/throw_or_abort.hpp"
|
||||
#include "../../numeric/bitop/get_msb.hpp"
|
||||
#include "../../numeric/random/engine.hpp"
|
||||
#include <memory>
|
||||
#include <span>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "./field_declarations.hpp"
|
||||
#include "./field_declarations.cuh"
|
||||
|
||||
namespace bb {
|
||||
using namespace numeric;
|
||||
@@ -24,7 +23,7 @@ using namespace numeric;
|
||||
* Mutiplication
|
||||
*
|
||||
**/
|
||||
template <class T> constexpr field<T> field<T>::operator*(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator*(const field& other) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::mul");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -39,7 +38,7 @@ template <class T> constexpr field<T> field<T>::operator*(const field& other) co
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T>& field<T>::operator*=(const field& other) noexcept
|
||||
template <class T> __device__ constexpr field<T>& field<T>::operator*=(const field& other) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_mul");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -61,7 +60,7 @@ template <class T> constexpr field<T>& field<T>::operator*=(const field& other)
|
||||
* Squaring
|
||||
*
|
||||
**/
|
||||
template <class T> constexpr field<T> field<T>::sqr() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::sqr() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::sqr");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -75,7 +74,7 @@ template <class T> constexpr field<T> field<T>::sqr() const noexcept
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_sqr() noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_sqr() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("f::self_sqr");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -95,7 +94,7 @@ template <class T> constexpr void field<T>::self_sqr() noexcept
|
||||
* Addition
|
||||
*
|
||||
**/
|
||||
template <class T> constexpr field<T> field<T>::operator+(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator+(const field& other) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::add");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -109,7 +108,7 @@ template <class T> constexpr field<T> field<T>::operator+(const field& other) co
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T>& field<T>::operator+=(const field& other) noexcept
|
||||
template <class T> __device__ constexpr field<T>& field<T>::operator+=(const field& other) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_add");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -125,14 +124,14 @@ template <class T> constexpr field<T>& field<T>::operator+=(const field& other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::operator++() noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator++() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("++f");
|
||||
return *this += 1;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cert-dcl21-cpp) circular linting errors. If const is added, linter suggests removing
|
||||
template <class T> constexpr field<T> field<T>::operator++(int) noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator++(int) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::increment");
|
||||
field<T> value_before_incrementing = *this;
|
||||
@@ -145,7 +144,7 @@ template <class T> constexpr field<T> field<T>::operator++(int) noexcept
|
||||
* Subtraction
|
||||
*
|
||||
**/
|
||||
template <class T> constexpr field<T> field<T>::operator-(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator-(const field& other) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::sub");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -159,7 +158,7 @@ template <class T> constexpr field<T> field<T>::operator-(const field& other) co
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::operator-() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator-() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("-f");
|
||||
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -188,7 +187,7 @@ template <class T> constexpr field<T> field<T>::operator-() const noexcept
|
||||
return (p - *this).reduce_once(); // modulus - *this;
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T>& field<T>::operator-=(const field& other) noexcept
|
||||
template <class T> __device__ constexpr field<T>& field<T>::operator-=(const field& other) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_sub");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -204,9 +203,10 @@ template <class T> constexpr field<T>& field<T>::operator-=(const field& other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_neg() noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_neg() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_neg");
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
|
||||
constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
|
||||
@@ -217,7 +217,7 @@ template <class T> constexpr void field<T>::self_neg() noexcept
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_conditional_negate(const uint64_t predicate) noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_conditional_negate(const uint64_t predicate) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_conditional_negate");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -243,7 +243,7 @@ template <class T> constexpr void field<T>::self_conditional_negate(const uint64
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
template <class T> constexpr bool field<T>::operator>(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr bool field<T>::operator>(const field& other) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::gt");
|
||||
const field left = reduce_once();
|
||||
@@ -268,12 +268,12 @@ template <class T> constexpr bool field<T>::operator>(const field& other) const
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
template <class T> constexpr bool field<T>::operator<(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr bool field<T>::operator<(const field& other) const noexcept
|
||||
{
|
||||
return (other > *this);
|
||||
}
|
||||
|
||||
template <class T> constexpr bool field<T>::operator==(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr bool field<T>::operator==(const field& other) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::eqeq");
|
||||
const field left = reduce_once();
|
||||
@@ -282,12 +282,12 @@ template <class T> constexpr bool field<T>::operator==(const field& other) const
|
||||
(left.data[3] == right.data[3]);
|
||||
}
|
||||
|
||||
template <class T> constexpr bool field<T>::operator!=(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr bool field<T>::operator!=(const field& other) const noexcept
|
||||
{
|
||||
return (!operator==(other));
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::to_montgomery_form() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::to_montgomery_form() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::to_montgomery_form");
|
||||
constexpr field r_squared =
|
||||
@@ -306,14 +306,14 @@ template <class T> constexpr field<T> field<T>::to_montgomery_form() const noexc
|
||||
return (result * r_squared).reduce_once();
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::from_montgomery_form() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::from_montgomery_form() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::from_montgomery_form");
|
||||
constexpr field one_raw{ 1, 0, 0, 0 };
|
||||
return operator*(one_raw).reduce_once();
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_to_montgomery_form() noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_to_montgomery_form() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_to_montgomery_form");
|
||||
constexpr field r_squared =
|
||||
@@ -326,7 +326,7 @@ template <class T> constexpr void field<T>::self_to_montgomery_form() noexcept
|
||||
self_reduce_once();
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_from_montgomery_form() noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_from_montgomery_form() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_from_montgomery_form");
|
||||
constexpr field one_raw{ 1, 0, 0, 0 };
|
||||
@@ -334,7 +334,7 @@ template <class T> constexpr void field<T>::self_from_montgomery_form() noexcept
|
||||
self_reduce_once();
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::reduce_once() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::reduce_once() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::reduce_once");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -348,7 +348,7 @@ template <class T> constexpr field<T> field<T>::reduce_once() const noexcept
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_reduce_once() noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_reduce_once() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_reduce_once");
|
||||
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
@@ -363,7 +363,7 @@ template <class T> constexpr void field<T>::self_reduce_once() noexcept
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::pow(const uint256_t& exponent) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::pow(const uint256_t& exponent) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::pow");
|
||||
field accumulator{ data[0], data[1], data[2], data[3] };
|
||||
@@ -384,12 +384,12 @@ template <class T> constexpr field<T> field<T>::pow(const uint256_t& exponent) c
|
||||
return accumulator;
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::pow(const uint64_t exponent) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::pow(const uint64_t exponent) const noexcept
|
||||
{
|
||||
return pow({ exponent, 0, 0, 0 });
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::invert() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::invert() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::invert");
|
||||
if (*this == zero()) {
|
||||
@@ -398,12 +398,12 @@ template <class T> constexpr field<T> field<T>::invert() const noexcept
|
||||
return pow(modulus_minus_two);
|
||||
}
|
||||
|
||||
template <class T> void field<T>::batch_invert(field* coeffs, const size_t n) noexcept
|
||||
template <class T> __device__ void field<T>::batch_invert(field* coeffs, const size_t n) noexcept
|
||||
{
|
||||
batch_invert(std::span{ coeffs, n });
|
||||
}
|
||||
|
||||
template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
|
||||
template <class T> __device__ void field<T>::batch_invert(std::span<field> coeffs) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::batch_invert");
|
||||
const size_t n = coeffs.size();
|
||||
@@ -452,7 +452,7 @@ template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt");
|
||||
// Tonelli-shanks algorithm begins by finding a field element Q and integer S,
|
||||
@@ -532,7 +532,7 @@ template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noex
|
||||
return r;
|
||||
}
|
||||
|
||||
template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
|
||||
template <class T> __device__ constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::sqrt");
|
||||
field root;
|
||||
@@ -549,41 +549,41 @@ template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const no
|
||||
|
||||
} // namespace bb;
|
||||
|
||||
template <class T> constexpr field<T> field<T>::operator/(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::operator/(const field& other) const noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::div");
|
||||
return operator*(other.invert());
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T>& field<T>::operator/=(const field& other) noexcept
|
||||
template <class T> __device__ constexpr field<T>& field<T>::operator/=(const field& other) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_div");
|
||||
*this = operator/(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T> constexpr void field<T>::self_set_msb() noexcept
|
||||
template <class T> __device__ constexpr void field<T>::self_set_msb() noexcept
|
||||
{
|
||||
data[3] = 0ULL | (1ULL << 63ULL);
|
||||
}
|
||||
|
||||
template <class T> constexpr bool field<T>::is_msb_set() const noexcept
|
||||
template <class T> __device__ constexpr bool field<T>::is_msb_set() const noexcept
|
||||
{
|
||||
return (data[3] >> 63ULL) == 1ULL;
|
||||
}
|
||||
|
||||
template <class T> constexpr uint64_t field<T>::is_msb_set_word() const noexcept
|
||||
template <class T> __device__ constexpr uint64_t field<T>::is_msb_set_word() const noexcept
|
||||
{
|
||||
return (data[3] >> 63ULL);
|
||||
}
|
||||
|
||||
template <class T> constexpr bool field<T>::is_zero() const noexcept
|
||||
template <class T> __device__ constexpr bool field<T>::is_zero() const noexcept
|
||||
{
|
||||
return ((data[0] | data[1] | data[2] | data[3]) == 0) ||
|
||||
(data[0] == T::modulus_0 && data[1] == T::modulus_1 && data[2] == T::modulus_2 && data[3] == T::modulus_3);
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::get_root_of_unity(size_t subgroup_size) noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::get_root_of_unity(size_t subgroup_size) noexcept
|
||||
{
|
||||
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
|
||||
field r{ T::primitive_root_0, T::primitive_root_1, T::primitive_root_2, T::primitive_root_3 };
|
||||
@@ -596,20 +596,7 @@ template <class T> constexpr field<T> field<T>::get_root_of_unity(size_t subgrou
|
||||
return r;
|
||||
}
|
||||
|
||||
template <class T> field<T> field<T>::random_element(numeric::RNG* engine) noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::random_element");
|
||||
if (engine == nullptr) {
|
||||
engine = &numeric::get_randomness();
|
||||
}
|
||||
|
||||
uint512_t source = engine->get_random_uint512();
|
||||
uint512_t q(modulus);
|
||||
uint512_t reduced = source % q;
|
||||
return field(reduced.lo);
|
||||
}
|
||||
|
||||
template <class T> constexpr size_t field<T>::primitive_root_log_size() noexcept
|
||||
template <class T> __device__ constexpr size_t field<T>::primitive_root_log_size() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::primitive_root_log_size");
|
||||
uint256_t target = modulus - 1;
|
||||
@@ -620,7 +607,7 @@ template <class T> constexpr size_t field<T>::primitive_root_log_size() noexcept
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute_coset_generators() noexcept
|
||||
{
|
||||
constexpr size_t n = COSET_GENERATOR_SIZE;
|
||||
@@ -654,7 +641,7 @@ constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::multiplicative_generator() noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::multiplicative_generator() noexcept
|
||||
{
|
||||
field target(1);
|
||||
uint256_t p_minus_one_over_two = (modulus - 1) >> 1;
|
||||
@@ -3,13 +3,13 @@
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "./field_impl.hpp"
|
||||
#include "./field_impl.cuh"
|
||||
#include "../../common/op_count.hpp"
|
||||
|
||||
namespace bb {
|
||||
using namespace numeric;
|
||||
// NOLINTBEGIN(readability-implicit-bool-conversion)
|
||||
template <class T> constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(uint64_t a, uint64_t b) noexcept
|
||||
template <class T> __device__ constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(uint64_t a, uint64_t b) noexcept
|
||||
{
|
||||
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
|
||||
const uint128_t res = (static_cast<uint128_t>(a) * static_cast<uint128_t>(b));
|
||||
@@ -20,7 +20,7 @@ template <class T> constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(ui
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr uint64_t field<T>::mac(
|
||||
const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t carry_in, uint64_t& carry_out) noexcept
|
||||
{
|
||||
@@ -36,7 +36,7 @@ constexpr uint64_t field<T>::mac(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr void field<T>::mac(const uint64_t a,
|
||||
const uint64_t b,
|
||||
const uint64_t c,
|
||||
@@ -56,7 +56,7 @@ constexpr void field<T>::mac(const uint64_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr uint64_t field<T>::mac_mini(const uint64_t a,
|
||||
const uint64_t b,
|
||||
const uint64_t c,
|
||||
@@ -73,7 +73,7 @@ constexpr uint64_t field<T>::mac_mini(const uint64_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr void field<T>::mac_mini(
|
||||
const uint64_t a, const uint64_t b, const uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept
|
||||
{
|
||||
@@ -88,7 +88,7 @@ constexpr void field<T>::mac_mini(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr uint64_t field<T>::mac_discard_lo(const uint64_t a, const uint64_t b, const uint64_t c) noexcept
|
||||
{
|
||||
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
|
||||
@@ -99,7 +99,7 @@ constexpr uint64_t field<T>::mac_discard_lo(const uint64_t a, const uint64_t b,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr uint64_t field<T>::addc(const uint64_t a,
|
||||
const uint64_t b,
|
||||
const uint64_t carry_in,
|
||||
@@ -119,7 +119,7 @@ constexpr uint64_t field<T>::addc(const uint64_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr uint64_t field<T>::sbb(const uint64_t a,
|
||||
const uint64_t b,
|
||||
const uint64_t borrow_in,
|
||||
@@ -139,7 +139,7 @@ constexpr uint64_t field<T>::sbb(const uint64_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class T> __device__
|
||||
constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
|
||||
const uint64_t b,
|
||||
const uint64_t c,
|
||||
@@ -177,8 +177,9 @@ constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::reduce() const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::reduce() const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
uint256_t val{ data[0], data[1], data[2], data[3] };
|
||||
if (val >= modulus) {
|
||||
@@ -202,7 +203,7 @@ template <class T> constexpr field<T> field<T>::reduce() const noexcept
|
||||
};
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::add(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::add(const field& other) const noexcept
|
||||
{
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
uint64_t r0 = data[0] + other.data[0];
|
||||
@@ -251,7 +252,7 @@ template <class T> constexpr field<T> field<T>::add(const field& other) const no
|
||||
}
|
||||
}
|
||||
|
||||
template <class T> constexpr field<T> field<T>::subtract(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::subtract(const field& other) const noexcept
|
||||
{
|
||||
uint64_t borrow = 0;
|
||||
uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
|
||||
@@ -283,8 +284,9 @@ template <class T> constexpr field<T> field<T>::subtract(const field& other) con
|
||||
* @param other
|
||||
* @return constexpr field<T>
|
||||
*/
|
||||
template <class T> constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
return subtract(other);
|
||||
}
|
||||
@@ -309,8 +311,9 @@ template <class T> constexpr field<T> field<T>::subtract_coarse(const field& oth
|
||||
* @details Explanation of Montgomery form can be found in \ref field_docs_montgomery_explainer and the difference
|
||||
* between WASM and generic versions is explained in \ref field_docs_architecture_details
|
||||
*/
|
||||
template <class T> constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
|
||||
uint64_t c = 0;
|
||||
uint64_t t0 = 0;
|
||||
@@ -531,8 +534,9 @@ template <class T> constexpr std::array<uint64_t, WASM_NUM_LIMBS> field<T>::wasm
|
||||
(data[3] >> 40) & 0x1fffffff };
|
||||
}
|
||||
#endif
|
||||
template <class T> constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
|
||||
template <class T> __device__ constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
return montgomery_mul_big(other);
|
||||
}
|
||||
@@ -653,6 +657,7 @@ template <class T> constexpr field<T> field<T>::montgomery_mul(const field& othe
|
||||
|
||||
template <class T> constexpr field<T> field<T>::montgomery_square() const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
return montgomery_mul_big(*this);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#if (BBERG_NO_ASM == 0)
|
||||
#include "./field_impl.hpp"
|
||||
#include "./field_impl.cuh"
|
||||
#include "asm_macros.hpp"
|
||||
namespace bb {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
#include "../../common/serialize.hpp"
|
||||
#include "../../ecc/curves/bn254/fq2.hpp"
|
||||
#include "../../numeric/uint256/uint256.hpp"
|
||||
#include "../../ecc/curves/bn254/fq2.cuh"
|
||||
#include "../../numeric/uint256/uint256.cuh"
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@@ -76,12 +76,6 @@ template <typename Fq_, typename Fr_, typename Params> class alignas(64) affine_
|
||||
|
||||
static constexpr std::optional<affine_element> derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Samples a random point on the curve.
|
||||
*
|
||||
* @return A randomly chosen point on the curve
|
||||
*/
|
||||
static affine_element random_element(numeric::RNG* engine = nullptr) noexcept;
|
||||
static constexpr affine_element hash_to_curve(const std::vector<uint8_t>& seed, uint8_t attempt_count = 0) noexcept
|
||||
requires SupportsHashToCurve<Params>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
#include "./element.hpp"
|
||||
#include "./element.cuh"
|
||||
#include "../../crypto/blake3s/blake3s.hpp"
|
||||
#include "../../crypto/keccak/keccak.hpp"
|
||||
|
||||
@@ -187,104 +187,4 @@ constexpr std::optional<affine_element<Fq, Fr, T>> affine_element<Fq, Fr, T>::de
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Hash a seed buffer into a point
|
||||
*
|
||||
* @details ALGORITHM DESCRIPTION:
|
||||
* 1. Initialize unsigned integer `attempt_count = 0`
|
||||
* 2. Copy seed into a buffer whose size is 2 bytes greater than `seed` (initialized to 0)
|
||||
* 3. Interpret `attempt_count` as a byte and write into buffer at [buffer.size() - 2]
|
||||
* 4. Compute Blake3s hash of buffer
|
||||
* 5. Set the end byte of the buffer to `1`
|
||||
* 6. Compute Blake3s hash of buffer
|
||||
* 7. Interpret the two hash outputs as the high / low 256 bits of a 512-bit integer (big-endian)
|
||||
* 8. Derive x-coordinate of point by reducing the 512-bit integer modulo the curve's field modulus (Fq)
|
||||
* 9. Compute y^2 from the curve formula y^2 = x^3 + ax + b (a, b are curve params. for BN254, a = 0, b = 3)
|
||||
* 10. IF y^2 IS NOT A QUADRATIC RESIDUE
|
||||
* 10a. increment `attempt_count` by 1 and go to step 2
|
||||
* 11. IF y^2 IS A QUADRATIC RESIDUE
|
||||
* 11a. derive y coordinate via y = sqrt(y)
|
||||
* 11b. Interpret most significant bit of 512-bit integer as a 'parity' bit
|
||||
* 11c. If parity bit is set AND y's most significant bit is not set, invert y
|
||||
* 11d. If parity bit is not set AND y's most significant bit is set, invert y
|
||||
* N.B. last 2 steps are because the sqrt() algorithm can return 2 values,
|
||||
* we need to a way to canonically distinguish between these 2 values and select a "preferred" one
|
||||
* 11e. return (x, y)
|
||||
*
|
||||
* @note This algorihm is constexpr: we can hash-to-curve (and derive generators) at compile-time!
|
||||
* @tparam Fq
|
||||
* @tparam Fr
|
||||
* @tparam T
|
||||
* @param seed Bytes that uniquely define the point being generated
|
||||
* @param attempt_count
|
||||
* @return constexpr affine_element<Fq, Fr, T>
|
||||
*/
|
||||
template <class Fq, class Fr, class T>
|
||||
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::hash_to_curve(const std::vector<uint8_t>& seed,
|
||||
uint8_t attempt_count) noexcept
|
||||
requires SupportsHashToCurve<T>
|
||||
|
||||
{
|
||||
std::vector<uint8_t> target_seed(seed);
|
||||
// expand by 2 bytes to cover incremental hash attempts
|
||||
const size_t seed_size = seed.size();
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
target_seed.push_back(0);
|
||||
}
|
||||
target_seed[seed_size] = attempt_count;
|
||||
target_seed[seed_size + 1] = 0;
|
||||
const auto hash_hi = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
|
||||
target_seed[seed_size + 1] = 1;
|
||||
const auto hash_lo = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
|
||||
// custom serialize methods as common/serialize.hpp is not constexpr!
|
||||
const auto read_uint256 = [](const uint8_t* in) {
|
||||
const auto read_limb = [](const uint8_t* in, uint64_t& out) {
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
out += static_cast<uint64_t>(in[i]) << ((7 - i) * 8);
|
||||
}
|
||||
};
|
||||
uint256_t out = 0;
|
||||
read_limb(&in[0], out.data[3]);
|
||||
read_limb(&in[8], out.data[2]);
|
||||
read_limb(&in[16], out.data[1]);
|
||||
read_limb(&in[24], out.data[0]);
|
||||
return out;
|
||||
};
|
||||
// interpret 64 byte hash output as a uint512_t, reduce to Fq element
|
||||
//(512 bits of entropy ensures result is not biased as 512 >> Fq::modulus.get_msb())
|
||||
Fq x(uint512_t(read_uint256(&hash_lo[0]), read_uint256(&hash_hi[0])));
|
||||
bool sign_bit = hash_hi[0] > 127;
|
||||
std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
|
||||
if (result.has_value()) {
|
||||
return result.value();
|
||||
}
|
||||
return hash_to_curve(seed, attempt_count + 1);
|
||||
}
|
||||
|
||||
template <typename Fq, typename Fr, typename T>
|
||||
affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::random_element(numeric::RNG* engine) noexcept
|
||||
{
|
||||
if (engine == nullptr) {
|
||||
engine = &numeric::get_randomness();
|
||||
}
|
||||
|
||||
Fq x;
|
||||
Fq y;
|
||||
while (true) {
|
||||
// Sample a random x-coordinate and check if it satisfies curve equation.
|
||||
x = Fq::random_element(engine);
|
||||
// Negate the y-coordinate based on a randomly sampled bit.
|
||||
bool sign_bit = (engine->get_random_uint8() & 1) != 0;
|
||||
|
||||
std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
|
||||
|
||||
if (result.has_value()) {
|
||||
return result.value();
|
||||
}
|
||||
}
|
||||
throw_or_abort("affine_element::random_element error");
|
||||
return affine_element<Fq, Fr, T>(x, y);
|
||||
}
|
||||
|
||||
} // namespace bb::group_elements
|
||||
@@ -1,10 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "affine_element.hpp"
|
||||
#include "affine_element.cuh"
|
||||
#include "../../common/compiler_hints.hpp"
|
||||
#include "../../common/mem.hpp"
|
||||
#include "../../numeric/random/engine.hpp"
|
||||
#include "../../numeric/uint256/uint256.hpp"
|
||||
#include "../../numeric/uint256/uint256.cuh"
|
||||
#include "wnaf.hpp"
|
||||
#include <array>
|
||||
#include <random>
|
||||
@@ -49,8 +48,6 @@ template <class Fq, class Fr, class Params> class alignas(32) element {
|
||||
|
||||
constexpr operator affine_element<Fq, Fr, Params>() const noexcept;
|
||||
|
||||
static element random_element(numeric::RNG* engine = nullptr) noexcept;
|
||||
|
||||
constexpr element dbl() const noexcept;
|
||||
constexpr void self_dbl() noexcept;
|
||||
constexpr void self_mixed_add_or_sub(const affine_element<Fq, Fr, Params>& other, uint64_t predicate) noexcept;
|
||||
@@ -107,29 +104,6 @@ template <class Fq, class Fr, class Params> class alignas(32) element {
|
||||
element mul_without_endomorphism(const Fr& scalar) const noexcept;
|
||||
element mul_with_endomorphism(const Fr& scalar) const noexcept;
|
||||
|
||||
template <typename = typename std::enable_if<Params::can_hash_to_curve>>
|
||||
static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept;
|
||||
// {
|
||||
// bool found_one = false;
|
||||
// Fq yy;
|
||||
// Fq x;
|
||||
// Fq y;
|
||||
// Fq t0;
|
||||
// while (!found_one) {
|
||||
// x = Fq::random_element(engine);
|
||||
// yy = x.sqr() * x + Params::b;
|
||||
// if constexpr (Params::has_a) {
|
||||
// yy += (x * Params::a);
|
||||
// }
|
||||
// y = yy.sqrt();
|
||||
// t0 = y.sqr();
|
||||
// found_one = (yy == t0);
|
||||
// }
|
||||
// return { x, y, Fq::one() };
|
||||
// }
|
||||
// for serialization: update with new fields
|
||||
// TODO(https://github.com/AztecProtocol/barretenberg/issues/908) point at inifinty isn't handled
|
||||
|
||||
static void conditional_negate_affine(const affine_element<Fq, Fr, Params>& in,
|
||||
affine_element<Fq, Fr, Params>& out,
|
||||
uint64_t predicate) noexcept;
|
||||
@@ -1,8 +1,7 @@
|
||||
#pragma once
|
||||
#include "../../common/op_count.hpp"
|
||||
#include "../../common/thread.hpp"
|
||||
#include "./element.hpp"
|
||||
#include "element.hpp"
|
||||
#include "./element.cuh"
|
||||
#include <cstdint>
|
||||
|
||||
// NOLINTBEGIN(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays)
|
||||
@@ -577,23 +576,6 @@ constexpr bool element<Fq, Fr, T>::operator==(const element& other) const noexce
|
||||
return both_infinity || ((lhs_x == rhs_x) && (lhs_y == rhs_y));
|
||||
}
|
||||
|
||||
template <class Fq, class Fr, class T>
|
||||
element<Fq, Fr, T> element<Fq, Fr, T>::random_element(numeric::RNG* engine) noexcept
|
||||
{
|
||||
if constexpr (T::can_hash_to_curve) {
|
||||
element result = random_coordinates_on_curve(engine);
|
||||
result.z = Fq::random_element(engine);
|
||||
Fq zz = result.z.sqr();
|
||||
Fq zzz = zz * result.z;
|
||||
result.x *= zz;
|
||||
result.y *= zzz;
|
||||
return result;
|
||||
} else {
|
||||
Fr scalar = Fr::random_element(engine);
|
||||
return (element{ T::one_x, T::one_y, Fq::one() } * scalar);
|
||||
}
|
||||
}
|
||||
|
||||
template <class Fq, class Fr, class T>
|
||||
element<Fq, Fr, T> element<Fq, Fr, T>::mul_without_endomorphism(const Fr& scalar) const noexcept
|
||||
{
|
||||
@@ -1193,27 +1175,5 @@ void element<Fq, Fr, T>::batch_normalize(element* elements, const size_t num_ele
|
||||
elements[i].z = Fq::one();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Fq, typename Fr, typename T>
|
||||
template <typename>
|
||||
element<Fq, Fr, T> element<Fq, Fr, T>::random_coordinates_on_curve(numeric::RNG* engine) noexcept
|
||||
{
|
||||
bool found_one = false;
|
||||
Fq yy;
|
||||
Fq x;
|
||||
Fq y;
|
||||
while (!found_one) {
|
||||
x = Fq::random_element(engine);
|
||||
yy = x.sqr() * x + T::b;
|
||||
if constexpr (T::has_a) {
|
||||
yy += (x * T::a);
|
||||
}
|
||||
auto [found_root, y1] = yy.sqrt();
|
||||
y = y1;
|
||||
found_one = found_root;
|
||||
}
|
||||
return { x, y, Fq::one() };
|
||||
}
|
||||
|
||||
} // namespace bb::group_elements
|
||||
// NOLINTEND(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays)
|
||||
@@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../common/assert.hpp"
|
||||
#include "./affine_element.hpp"
|
||||
#include "./element.hpp"
|
||||
#include "./affine_element.cuh"
|
||||
#include "./element.cuh"
|
||||
#include "./wnaf.hpp"
|
||||
#include "../../common/constexpr_utils.hpp"
|
||||
#include "../../crypto/blake3s/blake3s.hpp"
|
||||
@@ -1,157 +0,0 @@
|
||||
#pragma once
|
||||
#include "../../common/throw_or_abort.hpp"
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "../uint256/uint256.hpp"
|
||||
|
||||
namespace bb::numeric {
|
||||
|
||||
inline std::vector<uint64_t> slice_input(const uint256_t& input, const uint64_t base, const size_t num_slices)
|
||||
{
|
||||
uint256_t target = input;
|
||||
std::vector<uint64_t> slices;
|
||||
if (num_slices > 0) {
|
||||
for (size_t i = 0; i < num_slices; ++i) {
|
||||
slices.push_back((target % base).data[0]);
|
||||
target /= base;
|
||||
}
|
||||
} else {
|
||||
while (target > 0) {
|
||||
slices.push_back((target % base).data[0]);
|
||||
target /= base;
|
||||
}
|
||||
}
|
||||
return slices;
|
||||
}
|
||||
|
||||
inline std::vector<uint64_t> slice_input_using_variable_bases(const uint256_t& input,
|
||||
const std::vector<uint64_t>& bases)
|
||||
{
|
||||
uint256_t target = input;
|
||||
std::vector<uint64_t> slices;
|
||||
for (size_t i = 0; i < bases.size(); ++i) {
|
||||
if (target >= bases[i] && i == bases.size() - 1) {
|
||||
throw_or_abort(format("Last key slice greater than ", bases[i]));
|
||||
}
|
||||
slices.push_back((target % bases[i]).data[0]);
|
||||
target /= bases[i];
|
||||
}
|
||||
return slices;
|
||||
}
|
||||
|
||||
template <uint64_t base, uint64_t num_slices> constexpr std::array<uint256_t, num_slices> get_base_powers()
|
||||
{
|
||||
std::array<uint256_t, num_slices> output{};
|
||||
output[0] = 1;
|
||||
for (size_t i = 1; i < num_slices; ++i) {
|
||||
output[i] = output[i - 1] * base;
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
template <uint64_t base> constexpr uint256_t map_into_sparse_form(const uint64_t input)
|
||||
{
|
||||
uint256_t out = 0UL;
|
||||
auto converted = input;
|
||||
|
||||
constexpr auto base_powers = get_base_powers<base, 32>();
|
||||
for (size_t i = 0; i < 32; ++i) {
|
||||
uint64_t sparse_bit = ((converted >> i) & 1U);
|
||||
if (sparse_bit) {
|
||||
out += base_powers[i];
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <uint64_t base> constexpr uint64_t map_from_sparse_form(const uint256_t& input)
|
||||
{
|
||||
uint256_t target = input;
|
||||
uint64_t output = 0;
|
||||
|
||||
constexpr auto bases = get_base_powers<base, 32>();
|
||||
|
||||
for (uint64_t i = 0; i < 32; ++i) {
|
||||
const auto& base_power = bases[static_cast<size_t>(31 - i)];
|
||||
uint256_t prev_threshold = 0;
|
||||
for (uint64_t j = 1; j < base + 1; ++j) {
|
||||
const auto threshold = prev_threshold + base_power;
|
||||
if (target < threshold) {
|
||||
bool bit = ((j - 1) & 1);
|
||||
if (bit) {
|
||||
output += (1ULL << (31ULL - i));
|
||||
}
|
||||
if (j > 1) {
|
||||
target -= (prev_threshold);
|
||||
}
|
||||
break;
|
||||
}
|
||||
prev_threshold = threshold;
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
template <uint64_t base, size_t num_bits> class sparse_int {
|
||||
public:
|
||||
sparse_int(const uint64_t input = 0)
|
||||
: value(input)
|
||||
{
|
||||
for (size_t i = 0; i < num_bits; ++i) {
|
||||
const uint64_t bit = (input >> i) & 1U;
|
||||
limbs[i] = bit;
|
||||
}
|
||||
}
|
||||
sparse_int(const sparse_int& other) noexcept = default;
|
||||
sparse_int(sparse_int&& other) noexcept = default;
|
||||
sparse_int& operator=(const sparse_int& other) noexcept = default;
|
||||
sparse_int& operator=(sparse_int&& other) noexcept = default;
|
||||
~sparse_int() noexcept = default;
|
||||
|
||||
sparse_int operator+(const sparse_int& other) const
|
||||
{
|
||||
sparse_int result(*this);
|
||||
for (size_t i = 0; i < num_bits - 1; ++i) {
|
||||
result.limbs[i] += other.limbs[i];
|
||||
if (result.limbs[i] >= base) {
|
||||
result.limbs[i] -= base;
|
||||
++result.limbs[i + 1];
|
||||
}
|
||||
}
|
||||
result.limbs[num_bits - 1] += other.limbs[num_bits - 1];
|
||||
result.limbs[num_bits - 1] %= base;
|
||||
result.value += other.value;
|
||||
return result;
|
||||
};
|
||||
|
||||
sparse_int operator+=(const sparse_int& other)
|
||||
{
|
||||
*this = *this + other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
[[nodiscard]] uint64_t get_value() const { return value; }
|
||||
|
||||
[[nodiscard]] uint64_t get_sparse_value() const
|
||||
{
|
||||
uint64_t result = 0;
|
||||
for (size_t i = num_bits - 1; i < num_bits; --i) {
|
||||
result *= base;
|
||||
result += limbs[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::array<uint64_t, num_bits>& get_limbs() const { return limbs; }
|
||||
|
||||
private:
|
||||
std::array<uint64_t, num_bits> limbs;
|
||||
uint64_t value;
|
||||
uint64_t sparse_value;
|
||||
};
|
||||
|
||||
} // namespace bb::numeric
|
||||
@@ -1,139 +0,0 @@
|
||||
#include "engine.hpp"
|
||||
#include "../../common/assert.hpp"
|
||||
#include <array>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
|
||||
namespace bb::numeric {
|
||||
|
||||
namespace {
|
||||
auto generate_random_data()
|
||||
{
|
||||
std::array<unsigned int, 32> random_data;
|
||||
std::random_device source;
|
||||
std::generate(std::begin(random_data), std::end(random_data), std::ref(source));
|
||||
return random_data;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class RandomEngine : public RNG {
|
||||
public:
|
||||
uint8_t get_random_uint8() override
|
||||
{
|
||||
auto buf = generate_random_data();
|
||||
uint32_t out = buf[0];
|
||||
return static_cast<uint8_t>(out);
|
||||
}
|
||||
|
||||
uint16_t get_random_uint16() override
|
||||
{
|
||||
auto buf = generate_random_data();
|
||||
uint32_t out = buf[0];
|
||||
return static_cast<uint16_t>(out);
|
||||
}
|
||||
|
||||
uint32_t get_random_uint32() override
|
||||
{
|
||||
auto buf = generate_random_data();
|
||||
uint32_t out = buf[0];
|
||||
return static_cast<uint32_t>(out);
|
||||
}
|
||||
|
||||
uint64_t get_random_uint64() override
|
||||
{
|
||||
auto buf = generate_random_data();
|
||||
auto lo = static_cast<uint64_t>(buf[0]);
|
||||
auto hi = static_cast<uint64_t>(buf[1]);
|
||||
return (lo + (hi << 32ULL));
|
||||
}
|
||||
|
||||
uint128_t get_random_uint128() override
|
||||
{
|
||||
auto big = get_random_uint256();
|
||||
auto lo = static_cast<uint128_t>(big.data[0]);
|
||||
auto hi = static_cast<uint128_t>(big.data[1]);
|
||||
return (lo + (hi << static_cast<uint128_t>(64ULL)));
|
||||
}
|
||||
|
||||
uint256_t get_random_uint256() override
|
||||
{
|
||||
const auto get64 = [](const std::array<uint32_t, 32>& buffer, const size_t offset) {
|
||||
auto lo = static_cast<uint64_t>(buffer[0 + offset]);
|
||||
auto hi = static_cast<uint64_t>(buffer[1 + offset]);
|
||||
return (lo + (hi << 32ULL));
|
||||
};
|
||||
auto buf = generate_random_data();
|
||||
uint64_t lolo = get64(buf, 0);
|
||||
uint64_t lohi = get64(buf, 2);
|
||||
uint64_t hilo = get64(buf, 4);
|
||||
uint64_t hihi = get64(buf, 6);
|
||||
return { lolo, lohi, hilo, hihi };
|
||||
}
|
||||
};
|
||||
|
||||
class DebugEngine : public RNG {
|
||||
public:
|
||||
DebugEngine()
|
||||
// disable linting for this line: we want the DEBUG engine to produce predictable pseudorandom numbers!
|
||||
// NOLINTNEXTLINE(cert-msc32-c, cert-msc51-cpp)
|
||||
: engine(std::mt19937_64(12345))
|
||||
{}
|
||||
|
||||
DebugEngine(std::uint_fast64_t seed)
|
||||
: engine(std::mt19937_64(seed))
|
||||
{}
|
||||
|
||||
uint8_t get_random_uint8() override { return static_cast<uint8_t>(dist(engine)); }
|
||||
|
||||
uint16_t get_random_uint16() override { return static_cast<uint16_t>(dist(engine)); }
|
||||
|
||||
uint32_t get_random_uint32() override { return static_cast<uint32_t>(dist(engine)); }
|
||||
|
||||
uint64_t get_random_uint64() override { return dist(engine); }
|
||||
|
||||
uint128_t get_random_uint128() override
|
||||
{
|
||||
uint128_t hi = dist(engine);
|
||||
uint128_t lo = dist(engine);
|
||||
return (hi << 64) | lo;
|
||||
}
|
||||
|
||||
uint256_t get_random_uint256() override
|
||||
{
|
||||
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
|
||||
auto a = dist(engine);
|
||||
auto b = dist(engine);
|
||||
auto c = dist(engine);
|
||||
auto d = dist(engine);
|
||||
return { a, b, c, d };
|
||||
}
|
||||
|
||||
private:
|
||||
std::mt19937_64 engine;
|
||||
std::uniform_int_distribution<uint64_t> dist = std::uniform_int_distribution<uint64_t>{ 0ULL, UINT64_MAX };
|
||||
};
|
||||
|
||||
/**
|
||||
* Used by tests to ensure consistent behavior.
|
||||
*/
|
||||
RNG& get_debug_randomness(bool reset, std::uint_fast64_t seed)
|
||||
{
|
||||
// static std::seed_seq seed({ 1, 2, 3, 4, 5 });
|
||||
static DebugEngine debug_engine = DebugEngine();
|
||||
if (reset) {
|
||||
debug_engine = DebugEngine(seed);
|
||||
}
|
||||
return debug_engine;
|
||||
}
|
||||
|
||||
/**
|
||||
* Default engine. If wanting consistent proof construction, uncomment the line to return the debug engine.
|
||||
*/
|
||||
RNG& get_randomness()
|
||||
{
|
||||
// return get_debug_randomness();
|
||||
static RandomEngine engine;
|
||||
return engine;
|
||||
}
|
||||
|
||||
} // namespace bb::numeric
|
||||
@@ -1,52 +0,0 @@
|
||||
#pragma once
|
||||
#include "../uint128/uint128.hpp"
|
||||
#include "../uint256/uint256.hpp"
|
||||
#include "../uintx/uintx.hpp"
|
||||
#include "unistd.h"
|
||||
#include <cstdint>
|
||||
#include <random>
|
||||
|
||||
namespace bb::numeric {
|
||||
|
||||
class RNG {
|
||||
public:
|
||||
virtual uint8_t get_random_uint8() = 0;
|
||||
|
||||
virtual uint16_t get_random_uint16() = 0;
|
||||
|
||||
virtual uint32_t get_random_uint32() = 0;
|
||||
|
||||
virtual uint64_t get_random_uint64() = 0;
|
||||
|
||||
virtual uint128_t get_random_uint128() = 0;
|
||||
|
||||
virtual uint256_t get_random_uint256() = 0;
|
||||
|
||||
virtual ~RNG() = default;
|
||||
RNG() noexcept = default;
|
||||
RNG(const RNG& other) = default;
|
||||
RNG(RNG&& other) = default;
|
||||
RNG& operator=(const RNG& other) = default;
|
||||
RNG& operator=(RNG&& other) = default;
|
||||
|
||||
uint512_t get_random_uint512()
|
||||
{
|
||||
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
|
||||
auto lo = get_random_uint256();
|
||||
auto hi = get_random_uint256();
|
||||
return { lo, hi };
|
||||
}
|
||||
|
||||
uint1024_t get_random_uint1024()
|
||||
{
|
||||
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
|
||||
auto lo = get_random_uint512();
|
||||
auto hi = get_random_uint512();
|
||||
return { lo, hi };
|
||||
}
|
||||
};
|
||||
|
||||
RNG& get_debug_randomness(bool reset = false, std::uint_fast64_t seed = 12345);
|
||||
RNG& get_randomness();
|
||||
|
||||
} // namespace bb::numeric
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "../../common/assert.hpp"
|
||||
namespace bb::numeric {
|
||||
|
||||
constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
|
||||
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
|
||||
{
|
||||
const uint32_t a_lo = a & 0xffffULL;
|
||||
const uint32_t a_hi = a >> 16ULL;
|
||||
@@ -23,7 +23,7 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, co
|
||||
}
|
||||
|
||||
// compute a + b + carry, returning the carry
|
||||
constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
|
||||
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
|
||||
{
|
||||
const uint32_t sum = a + b;
|
||||
const auto carry_temp = static_cast<uint32_t>(sum < a);
|
||||
@@ -32,12 +32,12 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const
|
||||
return { r, carry_out };
|
||||
}
|
||||
|
||||
constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
|
||||
__device__ constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
|
||||
{
|
||||
return a + b + carry_in;
|
||||
}
|
||||
|
||||
constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
|
||||
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
|
||||
{
|
||||
const uint32_t t_1 = a - (borrow_in >> 31ULL);
|
||||
const auto borrow_temp_1 = static_cast<uint32_t>(t_1 > a);
|
||||
@@ -47,13 +47,13 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const u
|
||||
return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
|
||||
}
|
||||
|
||||
constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
|
||||
__device__ constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
|
||||
{
|
||||
return a - b - (borrow_in >> 31ULL);
|
||||
}
|
||||
|
||||
// {r, carry_out} = a + carry_in + b * c
|
||||
constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
|
||||
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
|
||||
const uint32_t b,
|
||||
const uint32_t c,
|
||||
const uint32_t carry_in)
|
||||
@@ -67,7 +67,7 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
|
||||
__device__ constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
|
||||
const uint32_t b,
|
||||
const uint32_t c,
|
||||
const uint32_t carry_in)
|
||||
@@ -75,7 +75,7 @@ constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
|
||||
return (b * c + a + carry_in);
|
||||
}
|
||||
|
||||
constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
|
||||
__device__ constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
|
||||
{
|
||||
if (*this == 0 || b == 0) {
|
||||
return { 0, 0 };
|
||||
@@ -123,7 +123,7 @@ constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b)
|
||||
return { quotient, remainder };
|
||||
}
|
||||
|
||||
constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
|
||||
__device__ constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
|
||||
{
|
||||
const auto [r0, t0] = mul_wide(data[0], other.data[0]);
|
||||
const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
#include "../bitop/get_msb.hpp"
|
||||
#include "./uint256.hpp"
|
||||
#include "./uint256.cuh"
|
||||
#include "../../common/assert.hpp"
|
||||
namespace bb::numeric {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
**/
|
||||
#pragma once
|
||||
|
||||
#include "../uint256/uint256.hpp"
|
||||
#include "../uint256/uint256.cuh"
|
||||
#include "../../common/assert.hpp"
|
||||
#include "../../common/throw_or_abort.hpp"
|
||||
#include <cstdint>
|
||||
@@ -166,13 +166,4 @@ template <class base_uint> inline std::ostream& operator<<(std::ostream& os, uin
|
||||
os << a.lo << ", " << a.hi << std::endl;
|
||||
return os;
|
||||
}
|
||||
|
||||
using uint512_t = uintx<numeric::uint256_t>;
|
||||
using uint1024_t = uintx<uint512_t>;
|
||||
|
||||
} // namespace bb::numeric
|
||||
|
||||
#include "./uintx_impl.hpp"
|
||||
|
||||
using bb::numeric::uint1024_t; // NOLINT
|
||||
using bb::numeric::uint512_t; // NOLINT
|
||||
@@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
#include "./uintx.hpp"
|
||||
#include "./uintx.cuh"
|
||||
#include "../../common/assert.hpp"
|
||||
|
||||
namespace bb::numeric {
|
||||
10
sumcheck/src/cuda/includes/prime_field.h
Normal file
10
sumcheck/src/cuda/includes/prime_field.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef __PRIME_FIELD_H__
|
||||
#define __PRIME_FIELD_H__
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
struct FieldBinding {
|
||||
uint64_t data[4];
|
||||
};
|
||||
|
||||
#endif
|
||||
25
sumcheck/src/cuda/includes/test.cpp
Normal file
25
sumcheck/src/cuda/includes/test.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
|
||||
#include "prime_field.h"
|
||||
#include "./barretenberg/ecc/curves/bn254/fq.hpp"
|
||||
|
||||
int main() {
|
||||
bb::fq a = bb::fq(1UL);
|
||||
bb::fq b = bb::fq(2U);
|
||||
bb::fq c = a + b;
|
||||
assert(c == bb::fq(3U));
|
||||
c = a * b;
|
||||
assert(c == bb::fq(2U));
|
||||
// memory layout test
|
||||
struct Field val_c { { 1UL, 2UL, 3UL, 4UL } };
|
||||
bb::fq val;
|
||||
std::memcpy(&val, &val_c, 32);
|
||||
assert(val.data[0] == 1UL);
|
||||
assert(val.data[1] == 2UL);
|
||||
assert(val.data[2] == 3UL);
|
||||
assert(val.data[3] == 4UL);
|
||||
std::cout << "test ended" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
#include "prime_field.h"
|
||||
|
||||
Reference in New Issue
Block a user