Update barretenberg to device code

This commit is contained in:
DoHoonKim
2024-06-17 01:36:51 +09:00
committed by DoHoon Kim
parent 22250b1b9a
commit 911f54d5ac
27 changed files with 191 additions and 757 deletions

View File

@@ -1,10 +1,10 @@
#pragma once
#include "../bn254/fq.hpp"
#include "../bn254/fq12.hpp"
#include "../bn254/fq2.hpp"
#include "../bn254/fr.hpp"
#include "../bn254/g1.hpp"
#include "../bn254/g2.hpp"
#include "../bn254/fq.cuh"
#include "../bn254/fq12.cuh"
#include "../bn254/fq2.cuh"
#include "../bn254/fr.cuh"
#include "../bn254/g1.cuh"
#include "../bn254/g2.cuh"
namespace bb::curve {
class BN254 {

View File

@@ -3,7 +3,7 @@
#include <cstdint>
#include <iomanip>
#include "../../fields/field.hpp"
#include "../../fields/field.cuh"
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
namespace bb {

View File

@@ -1,7 +1,7 @@
#pragma once
#include "../../fields/field2.hpp"
#include "./fq.hpp"
#include "../../fields/field2.cuh"
#include "./fq.cuh"
namespace bb {
struct Bn254Fq2Params {

View File

@@ -4,7 +4,7 @@
#include <iomanip>
#include <ostream>
#include "../../fields/field.hpp"
#include "../../fields/field.cuh"
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)

View File

@@ -1,8 +1,8 @@
#pragma once
#include "../../groups/group.hpp"
#include "./fq.hpp"
#include "./fr.hpp"
#include "../../groups/group.cuh"
#include "./fq.cuh"
#include "./fr.cuh"
namespace bb {
struct Bn254G1Params {

View File

@@ -1,8 +1,8 @@
#pragma once
#include "../../groups/group.hpp"
#include "./fq2.hpp"
#include "./fr.hpp"
#include "../../groups/group.cuh"
#include "./fq2.cuh"
#include "./fr.cuh"
namespace bb {
struct Bn254G2Params {

View File

@@ -6,5 +6,5 @@
* declarations header). Specialized definitions are in "field_impl_generic.hpp" and "field_impl_x64.hpp"
* (which include "field_impl.hpp")
*/
#include "./field_impl_generic.hpp"
#include "./field_impl_x64.hpp"
#include "./field_impl_generic.cuh"
#include "./field_impl_x64.cuh"

View File

@@ -1,9 +1,8 @@
#pragma once
#include "../../common/assert.hpp"
#include "../../common/compiler_hints.hpp"
#include "../../numeric/random/engine.hpp"
#include "../../numeric/uint128/uint128.hpp"
#include "../../numeric/uint256/uint256_impl.hpp"
#include "../../numeric/uint256/uint256_impl.cuh"
#include <array>
#include <cstdint>
#include <iostream>
@@ -52,33 +51,33 @@ template <class Params_> struct alignas(32) field {
// std::array<field, N> arr {}; // zero-initialized, preferable for moderate N
field() = default;
constexpr field(const numeric::uint256_t& input) noexcept
__device__ constexpr field(const numeric::uint256_t& input) noexcept
: data{ input.data[0], input.data[1], input.data[2], input.data[3] }
{
self_to_montgomery_form();
}
// NOLINTNEXTLINE (unsigned long is platform dependent, which we want in this case)
constexpr field(const unsigned long input) noexcept
__device__ constexpr field(const unsigned long input) noexcept
: data{ input, 0, 0, 0 }
{
self_to_montgomery_form();
}
constexpr field(const unsigned int input) noexcept
__device__ constexpr field(const unsigned int input) noexcept
: data{ input, 0, 0, 0 }
{
self_to_montgomery_form();
}
// NOLINTNEXTLINE (unsigned long long is platform dependent, which we want in this case)
constexpr field(const unsigned long long input) noexcept
__device__ constexpr field(const unsigned long long input) noexcept
: data{ input, 0, 0, 0 }
{
self_to_montgomery_form();
}
constexpr field(const int input) noexcept
__device__ constexpr field(const int input) noexcept
: data{ 0, 0, 0, 0 }
{
if (input < 0) {
@@ -98,63 +97,47 @@ template <class Params_> struct alignas(32) field {
}
}
constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
__device__ constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
: data{ a, b, c, d } {};
/**
* @brief Convert a 512-bit big integer into a field element.
*
* @details Used for deriving field elements from random values. 512-bits prevents biased output as 2^512>>modulus
*
*/
constexpr explicit field(const uint512_t& input) noexcept
{
uint256_t value = (input % modulus).lo;
data[0] = value.data[0];
data[1] = value.data[1];
data[2] = value.data[2];
data[3] = value.data[3];
self_to_montgomery_form();
}
constexpr explicit field(std::string input) noexcept
__device__ constexpr explicit field(std::string input) noexcept
{
uint256_t value(input);
*this = field(value);
}
constexpr explicit operator bool() const
__device__ constexpr explicit operator bool() const
{
field out = from_montgomery_form();
ASSERT(out.data[0] == 0 || out.data[0] == 1);
return static_cast<bool>(out.data[0]);
}
constexpr explicit operator uint8_t() const
__device__ constexpr explicit operator uint8_t() const
{
field out = from_montgomery_form();
return static_cast<uint8_t>(out.data[0]);
}
constexpr explicit operator uint16_t() const
__device__ constexpr explicit operator uint16_t() const
{
field out = from_montgomery_form();
return static_cast<uint16_t>(out.data[0]);
}
constexpr explicit operator uint32_t() const
__device__ constexpr explicit operator uint32_t() const
{
field out = from_montgomery_form();
return static_cast<uint32_t>(out.data[0]);
}
constexpr explicit operator uint64_t() const
__device__ constexpr explicit operator uint64_t() const
{
field out = from_montgomery_form();
return out.data[0];
}
constexpr explicit operator uint128_t() const
__device__ constexpr explicit operator uint128_t() const
{
field out = from_montgomery_form();
uint128_t lo = out.data[0];
@@ -162,7 +145,7 @@ template <class Params_> struct alignas(32) field {
return (hi << 64) | lo;
}
constexpr operator uint256_t() const noexcept
__device__ constexpr operator uint256_t() const noexcept
{
field out = from_montgomery_form();
return uint256_t(out.data[0], out.data[1], out.data[2], out.data[3]);
@@ -180,8 +163,9 @@ template <class Params_> struct alignas(32) field {
constexpr ~field() noexcept = default;
alignas(32) uint64_t data[4]; // NOLINT
static constexpr uint256_t modulus =
uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
static constexpr __device__ uint256_t get_modulus() {
return uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
}
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr uint256_t r_squared_uint{
Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3
@@ -287,45 +271,45 @@ template <class Params_> struct alignas(32) field {
return result;
}
BB_INLINE constexpr field operator*(const field& other) const noexcept;
BB_INLINE constexpr field operator+(const field& other) const noexcept;
BB_INLINE constexpr field operator-(const field& other) const noexcept;
BB_INLINE constexpr field operator-() const noexcept;
constexpr field operator/(const field& other) const noexcept;
BB_INLINE __device__ constexpr field operator*(const field& other) const noexcept;
BB_INLINE __device__ constexpr field operator+(const field& other) const noexcept;
BB_INLINE __device__ constexpr field operator-(const field& other) const noexcept;
BB_INLINE __device__ constexpr field operator-() const noexcept;
__device__ constexpr field operator/(const field& other) const noexcept;
// prefix increment (++x)
BB_INLINE constexpr field operator++() noexcept;
BB_INLINE __device__ constexpr field operator++() noexcept;
// postfix increment (x++)
// NOLINTNEXTLINE
BB_INLINE constexpr field operator++(int) noexcept;
BB_INLINE __device__ constexpr field operator++(int) noexcept;
BB_INLINE constexpr field& operator*=(const field& other) noexcept;
BB_INLINE constexpr field& operator+=(const field& other) noexcept;
BB_INLINE constexpr field& operator-=(const field& other) noexcept;
constexpr field& operator/=(const field& other) noexcept;
BB_INLINE __device__ constexpr field& operator*=(const field& other) noexcept;
BB_INLINE __device__ constexpr field& operator+=(const field& other) noexcept;
BB_INLINE __device__ constexpr field& operator-=(const field& other) noexcept;
__device__ constexpr field& operator/=(const field& other) noexcept;
// NOTE: comparison operators exist so that `field` is comparable with STL methods that require them.
// (e.g. std::sort)
// Finite fields do not have an explicit ordering, these should *NEVER* be used in algebraic algorithms.
BB_INLINE constexpr bool operator>(const field& other) const noexcept;
BB_INLINE constexpr bool operator<(const field& other) const noexcept;
BB_INLINE constexpr bool operator==(const field& other) const noexcept;
BB_INLINE constexpr bool operator!=(const field& other) const noexcept;
BB_INLINE __device__ constexpr bool operator>(const field& other) const noexcept;
BB_INLINE __device__ constexpr bool operator<(const field& other) const noexcept;
BB_INLINE __device__ constexpr bool operator==(const field& other) const noexcept;
BB_INLINE __device__ constexpr bool operator!=(const field& other) const noexcept;
BB_INLINE constexpr field to_montgomery_form() const noexcept;
BB_INLINE constexpr field from_montgomery_form() const noexcept;
BB_INLINE __device__ constexpr field to_montgomery_form() const noexcept;
BB_INLINE __device__ constexpr field from_montgomery_form() const noexcept;
BB_INLINE constexpr field sqr() const noexcept;
BB_INLINE constexpr void self_sqr() noexcept;
BB_INLINE __device__ constexpr field sqr() const noexcept;
BB_INLINE __device__ constexpr void self_sqr() noexcept;
BB_INLINE constexpr field pow(const uint256_t& exponent) const noexcept;
BB_INLINE constexpr field pow(uint64_t exponent) const noexcept;
BB_INLINE __device__ constexpr field pow(const uint256_t& exponent) const noexcept;
BB_INLINE __device__ constexpr field pow(uint64_t exponent) const noexcept;
static_assert(Params::modulus_0 != 1);
static constexpr uint256_t modulus_minus_two =
uint256_t(Params::modulus_0 - 2ULL, Params::modulus_1, Params::modulus_2, Params::modulus_3);
constexpr field invert() const noexcept;
static void batch_invert(std::span<field> coeffs) noexcept;
static void batch_invert(field* coeffs, size_t n) noexcept;
__device__ constexpr field invert() const noexcept;
static __device__ void batch_invert(std::span<field> coeffs) noexcept;
static __device__ void batch_invert(field* coeffs, size_t n) noexcept;
/**
* @brief Compute square root of the field element.
*
@@ -333,17 +317,17 @@ template <class Params_> struct alignas(32) field {
*/
constexpr std::pair<bool, field> sqrt() const noexcept;
BB_INLINE constexpr void self_neg() noexcept;
BB_INLINE __device__ constexpr void self_neg() noexcept;
BB_INLINE constexpr void self_to_montgomery_form() noexcept;
BB_INLINE constexpr void self_from_montgomery_form() noexcept;
BB_INLINE __device__ constexpr void self_to_montgomery_form() noexcept;
BB_INLINE __device__ constexpr void self_from_montgomery_form() noexcept;
BB_INLINE constexpr void self_conditional_negate(uint64_t predicate) noexcept;
BB_INLINE __device__ constexpr void self_conditional_negate(uint64_t predicate) noexcept;
BB_INLINE constexpr field reduce_once() const noexcept;
BB_INLINE constexpr void self_reduce_once() noexcept;
BB_INLINE __device__ constexpr field reduce_once() const noexcept;
BB_INLINE __device__ constexpr void self_reduce_once() noexcept;
BB_INLINE constexpr void self_set_msb() noexcept;
BB_INLINE __device__ constexpr void self_set_msb() noexcept;
[[nodiscard]] BB_INLINE constexpr bool is_msb_set() const noexcept;
[[nodiscard]] BB_INLINE constexpr uint64_t is_msb_set_word() const noexcept;
@@ -472,53 +456,6 @@ template <class Params_> struct alignas(32) field {
};
}
static void split_into_endomorphism_scalars_384(const field& input, field& k1_out, field& k2_out)
{
constexpr field minus_b1f{
Params::endo_minus_b1_lo,
Params::endo_minus_b1_mid,
0,
0,
};
constexpr field b2f{
Params::endo_b2_lo,
Params::endo_b2_mid,
0,
0,
};
constexpr uint256_t g1{
Params::endo_g1_lo,
Params::endo_g1_mid,
Params::endo_g1_hi,
Params::endo_g1_hihi,
};
constexpr uint256_t g2{
Params::endo_g2_lo,
Params::endo_g2_mid,
Params::endo_g2_hi,
Params::endo_g2_hihi,
};
field kf = input.reduce_once();
uint256_t k{ kf.data[0], kf.data[1], kf.data[2], kf.data[3] };
uint512_t c1 = (uint512_t(k) * static_cast<uint512_t>(g1)) >> 384;
uint512_t c2 = (uint512_t(k) * static_cast<uint512_t>(g2)) >> 384;
field c1f{ c1.lo.data[0], c1.lo.data[1], c1.lo.data[2], c1.lo.data[3] };
field c2f{ c2.lo.data[0], c2.lo.data[1], c2.lo.data[2], c2.lo.data[3] };
c1f.self_to_montgomery_form();
c2f.self_to_montgomery_form();
c1f = c1f * minus_b1f;
c2f = c2f * b2f;
field r2f = c1f - c2f;
field beta = cube_root_of_unity();
field r1f = input.reduce_once() - r2f * beta;
k1_out = r1f;
k2_out = -r2f;
}
// static constexpr auto coset_generators = compute_coset_generators();
// static constexpr std::array<field, 15> coset_generators = compute_coset_generators((1 << 30U));
@@ -540,12 +477,10 @@ template <class Params_> struct alignas(32) field {
src = T;
}
static field random_element(numeric::RNG* engine = nullptr) noexcept;
static constexpr field multiplicative_generator() noexcept;
static constexpr uint256_t twice_modulus = modulus + modulus;
static constexpr uint256_t not_modulus = -modulus;
static constexpr uint256_t twice_modulus = get_modulus() + get_modulus();
static constexpr uint256_t not_modulus = -get_modulus();
static constexpr uint256_t twice_not_modulus = -twice_modulus;
struct wnaf_table {
@@ -612,9 +547,9 @@ template <class Params_> struct alignas(32) field {
uint64_t& result_8);
BB_INLINE static constexpr std::array<uint64_t, WASM_NUM_LIMBS> wasm_convert(const uint64_t* data);
#endif
BB_INLINE static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b) noexcept;
BB_INLINE __device__ static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b) noexcept;
BB_INLINE static constexpr uint64_t mac(
BB_INLINE __device__ static constexpr uint64_t mac(
uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& carry_out) noexcept;
BB_INLINE static constexpr void mac(

View File

@@ -3,13 +3,12 @@
#include "../../common/slab_allocator.hpp"
#include "../../common/throw_or_abort.hpp"
#include "../../numeric/bitop/get_msb.hpp"
#include "../../numeric/random/engine.hpp"
#include <memory>
#include <span>
#include <type_traits>
#include <vector>
#include "./field_declarations.hpp"
#include "./field_declarations.cuh"
namespace bb {
using namespace numeric;
@@ -24,7 +23,7 @@ using namespace numeric;
* Multiplication
*
**/
template <class T> constexpr field<T> field<T>::operator*(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::operator*(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::mul");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -39,7 +38,7 @@ template <class T> constexpr field<T> field<T>::operator*(const field& other) co
}
}
template <class T> constexpr field<T>& field<T>::operator*=(const field& other) noexcept
template <class T> __device__ constexpr field<T>& field<T>::operator*=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_mul");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -61,7 +60,7 @@ template <class T> constexpr field<T>& field<T>::operator*=(const field& other)
* Squaring
*
**/
template <class T> constexpr field<T> field<T>::sqr() const noexcept
template <class T> __device__ constexpr field<T> field<T>::sqr() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::sqr");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -75,7 +74,7 @@ template <class T> constexpr field<T> field<T>::sqr() const noexcept
}
}
template <class T> constexpr void field<T>::self_sqr() noexcept
template <class T> __device__ constexpr void field<T>::self_sqr() noexcept
{
BB_OP_COUNT_TRACK_NAME("f::self_sqr");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -95,7 +94,7 @@ template <class T> constexpr void field<T>::self_sqr() noexcept
* Addition
*
**/
template <class T> constexpr field<T> field<T>::operator+(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::operator+(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::add");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -109,7 +108,7 @@ template <class T> constexpr field<T> field<T>::operator+(const field& other) co
}
}
template <class T> constexpr field<T>& field<T>::operator+=(const field& other) noexcept
template <class T> __device__ constexpr field<T>& field<T>::operator+=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_add");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -125,14 +124,14 @@ template <class T> constexpr field<T>& field<T>::operator+=(const field& other)
return *this;
}
template <class T> constexpr field<T> field<T>::operator++() noexcept
template <class T> __device__ constexpr field<T> field<T>::operator++() noexcept
{
BB_OP_COUNT_TRACK_NAME("++f");
return *this += 1;
}
// NOLINTNEXTLINE(cert-dcl21-cpp) circular linting errors. If const is added, linter suggests removing
template <class T> constexpr field<T> field<T>::operator++(int) noexcept
template <class T> __device__ constexpr field<T> field<T>::operator++(int) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::increment");
field<T> value_before_incrementing = *this;
@@ -145,7 +144,7 @@ template <class T> constexpr field<T> field<T>::operator++(int) noexcept
* Subtraction
*
**/
template <class T> constexpr field<T> field<T>::operator-(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::operator-(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::sub");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -159,7 +158,7 @@ template <class T> constexpr field<T> field<T>::operator-(const field& other) co
}
}
template <class T> constexpr field<T> field<T>::operator-() const noexcept
template <class T> __device__ constexpr field<T> field<T>::operator-() const noexcept
{
BB_OP_COUNT_TRACK_NAME("-f");
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -188,7 +187,7 @@ template <class T> constexpr field<T> field<T>::operator-() const noexcept
return (p - *this).reduce_once(); // modulus - *this;
}
template <class T> constexpr field<T>& field<T>::operator-=(const field& other) noexcept
template <class T> __device__ constexpr field<T>& field<T>::operator-=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_sub");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -204,9 +203,10 @@ template <class T> constexpr field<T>& field<T>::operator-=(const field& other)
return *this;
}
template <class T> constexpr void field<T>::self_neg() noexcept
template <class T> __device__ constexpr void field<T>::self_neg() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_neg");
constexpr uint256_t modulus = get_modulus();
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
@@ -217,7 +217,7 @@ template <class T> constexpr void field<T>::self_neg() noexcept
}
}
template <class T> constexpr void field<T>::self_conditional_negate(const uint64_t predicate) noexcept
template <class T> __device__ constexpr void field<T>::self_conditional_negate(const uint64_t predicate) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_conditional_negate");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -243,7 +243,7 @@ template <class T> constexpr void field<T>::self_conditional_negate(const uint64
* @return true
* @return false
*/
template <class T> constexpr bool field<T>::operator>(const field& other) const noexcept
template <class T> __device__ constexpr bool field<T>::operator>(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::gt");
const field left = reduce_once();
@@ -268,12 +268,12 @@ template <class T> constexpr bool field<T>::operator>(const field& other) const
* @return true
* @return false
*/
template <class T> constexpr bool field<T>::operator<(const field& other) const noexcept
template <class T> __device__ constexpr bool field<T>::operator<(const field& other) const noexcept
{
return (other > *this);
}
template <class T> constexpr bool field<T>::operator==(const field& other) const noexcept
template <class T> __device__ constexpr bool field<T>::operator==(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::eqeq");
const field left = reduce_once();
@@ -282,12 +282,12 @@ template <class T> constexpr bool field<T>::operator==(const field& other) const
(left.data[3] == right.data[3]);
}
template <class T> constexpr bool field<T>::operator!=(const field& other) const noexcept
template <class T> __device__ constexpr bool field<T>::operator!=(const field& other) const noexcept
{
return (!operator==(other));
}
template <class T> constexpr field<T> field<T>::to_montgomery_form() const noexcept
template <class T> __device__ constexpr field<T> field<T>::to_montgomery_form() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::to_montgomery_form");
constexpr field r_squared =
@@ -306,14 +306,14 @@ template <class T> constexpr field<T> field<T>::to_montgomery_form() const noexc
return (result * r_squared).reduce_once();
}
template <class T> constexpr field<T> field<T>::from_montgomery_form() const noexcept
template <class T> __device__ constexpr field<T> field<T>::from_montgomery_form() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::from_montgomery_form");
constexpr field one_raw{ 1, 0, 0, 0 };
return operator*(one_raw).reduce_once();
}
template <class T> constexpr void field<T>::self_to_montgomery_form() noexcept
template <class T> __device__ constexpr void field<T>::self_to_montgomery_form() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_to_montgomery_form");
constexpr field r_squared =
@@ -326,7 +326,7 @@ template <class T> constexpr void field<T>::self_to_montgomery_form() noexcept
self_reduce_once();
}
template <class T> constexpr void field<T>::self_from_montgomery_form() noexcept
template <class T> __device__ constexpr void field<T>::self_from_montgomery_form() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_from_montgomery_form");
constexpr field one_raw{ 1, 0, 0, 0 };
@@ -334,7 +334,7 @@ template <class T> constexpr void field<T>::self_from_montgomery_form() noexcept
self_reduce_once();
}
template <class T> constexpr field<T> field<T>::reduce_once() const noexcept
template <class T> __device__ constexpr field<T> field<T>::reduce_once() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::reduce_once");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -348,7 +348,7 @@ template <class T> constexpr field<T> field<T>::reduce_once() const noexcept
}
}
template <class T> constexpr void field<T>::self_reduce_once() noexcept
template <class T> __device__ constexpr void field<T>::self_reduce_once() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_reduce_once");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
@@ -363,7 +363,7 @@ template <class T> constexpr void field<T>::self_reduce_once() noexcept
}
}
template <class T> constexpr field<T> field<T>::pow(const uint256_t& exponent) const noexcept
template <class T> __device__ constexpr field<T> field<T>::pow(const uint256_t& exponent) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::pow");
field accumulator{ data[0], data[1], data[2], data[3] };
@@ -384,12 +384,12 @@ template <class T> constexpr field<T> field<T>::pow(const uint256_t& exponent) c
return accumulator;
}
template <class T> constexpr field<T> field<T>::pow(const uint64_t exponent) const noexcept
template <class T> __device__ constexpr field<T> field<T>::pow(const uint64_t exponent) const noexcept
{
return pow({ exponent, 0, 0, 0 });
}
template <class T> constexpr field<T> field<T>::invert() const noexcept
template <class T> __device__ constexpr field<T> field<T>::invert() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::invert");
if (*this == zero()) {
@@ -398,12 +398,12 @@ template <class T> constexpr field<T> field<T>::invert() const noexcept
return pow(modulus_minus_two);
}
template <class T> void field<T>::batch_invert(field* coeffs, const size_t n) noexcept
template <class T> __device__ void field<T>::batch_invert(field* coeffs, const size_t n) noexcept
{
batch_invert(std::span{ coeffs, n });
}
template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
template <class T> __device__ void field<T>::batch_invert(std::span<field> coeffs) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::batch_invert");
const size_t n = coeffs.size();
@@ -452,7 +452,7 @@ template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
}
}
template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
template <class T> __device__ constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt");
// Tonelli-shanks algorithm begins by finding a field element Q and integer S,
@@ -532,7 +532,7 @@ template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noex
return r;
}
template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
template <class T> __device__ constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::sqrt");
field root;
@@ -549,41 +549,41 @@ template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const no
} // namespace bb;
template <class T> constexpr field<T> field<T>::operator/(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::operator/(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::div");
return operator*(other.invert());
}
template <class T> constexpr field<T>& field<T>::operator/=(const field& other) noexcept
template <class T> __device__ constexpr field<T>& field<T>::operator/=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_div");
*this = operator/(other);
return *this;
}
template <class T> constexpr void field<T>::self_set_msb() noexcept
template <class T> __device__ constexpr void field<T>::self_set_msb() noexcept
{
data[3] = 0ULL | (1ULL << 63ULL);
}
template <class T> constexpr bool field<T>::is_msb_set() const noexcept
template <class T> __device__ constexpr bool field<T>::is_msb_set() const noexcept
{
return (data[3] >> 63ULL) == 1ULL;
}
template <class T> constexpr uint64_t field<T>::is_msb_set_word() const noexcept
template <class T> __device__ constexpr uint64_t field<T>::is_msb_set_word() const noexcept
{
return (data[3] >> 63ULL);
}
template <class T> constexpr bool field<T>::is_zero() const noexcept
template <class T> __device__ constexpr bool field<T>::is_zero() const noexcept
{
return ((data[0] | data[1] | data[2] | data[3]) == 0) ||
(data[0] == T::modulus_0 && data[1] == T::modulus_1 && data[2] == T::modulus_2 && data[3] == T::modulus_3);
}
template <class T> constexpr field<T> field<T>::get_root_of_unity(size_t subgroup_size) noexcept
template <class T> __device__ constexpr field<T> field<T>::get_root_of_unity(size_t subgroup_size) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
field r{ T::primitive_root_0, T::primitive_root_1, T::primitive_root_2, T::primitive_root_3 };
@@ -596,20 +596,7 @@ template <class T> constexpr field<T> field<T>::get_root_of_unity(size_t subgrou
return r;
}
template <class T> field<T> field<T>::random_element(numeric::RNG* engine) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::random_element");
if (engine == nullptr) {
engine = &numeric::get_randomness();
}
uint512_t source = engine->get_random_uint512();
uint512_t q(modulus);
uint512_t reduced = source % q;
return field(reduced.lo);
}
template <class T> constexpr size_t field<T>::primitive_root_log_size() noexcept
template <class T> __device__ constexpr size_t field<T>::primitive_root_log_size() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::primitive_root_log_size");
uint256_t target = modulus - 1;
@@ -620,7 +607,7 @@ template <class T> constexpr size_t field<T>::primitive_root_log_size() noexcept
return result;
}
template <class T>
template <class T> __device__
constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute_coset_generators() noexcept
{
constexpr size_t n = COSET_GENERATOR_SIZE;
@@ -654,7 +641,7 @@ constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute
return result;
}
template <class T> constexpr field<T> field<T>::multiplicative_generator() noexcept
template <class T> __device__ constexpr field<T> field<T>::multiplicative_generator() noexcept
{
field target(1);
uint256_t p_minus_one_over_two = (modulus - 1) >> 1;

View File

@@ -3,13 +3,13 @@
#include <array>
#include <cstdint>
#include "./field_impl.hpp"
#include "./field_impl.cuh"
#include "../../common/op_count.hpp"
namespace bb {
using namespace numeric;
// NOLINTBEGIN(readability-implicit-bool-conversion)
template <class T> constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(uint64_t a, uint64_t b) noexcept
template <class T> __device__ constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(uint64_t a, uint64_t b) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = (static_cast<uint128_t>(a) * static_cast<uint128_t>(b));
@@ -20,7 +20,7 @@ template <class T> constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(ui
#endif
}
template <class T>
template <class T> __device__
constexpr uint64_t field<T>::mac(
const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t carry_in, uint64_t& carry_out) noexcept
{
@@ -36,7 +36,7 @@ constexpr uint64_t field<T>::mac(
#endif
}
template <class T>
template <class T> __device__
constexpr void field<T>::mac(const uint64_t a,
const uint64_t b,
const uint64_t c,
@@ -56,7 +56,7 @@ constexpr void field<T>::mac(const uint64_t a,
#endif
}
template <class T>
template <class T> __device__
constexpr uint64_t field<T>::mac_mini(const uint64_t a,
const uint64_t b,
const uint64_t c,
@@ -73,7 +73,7 @@ constexpr uint64_t field<T>::mac_mini(const uint64_t a,
#endif
}
template <class T>
template <class T> __device__
constexpr void field<T>::mac_mini(
const uint64_t a, const uint64_t b, const uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept
{
@@ -88,7 +88,7 @@ constexpr void field<T>::mac_mini(
#endif
}
template <class T>
template <class T> __device__
constexpr uint64_t field<T>::mac_discard_lo(const uint64_t a, const uint64_t b, const uint64_t c) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
@@ -99,7 +99,7 @@ constexpr uint64_t field<T>::mac_discard_lo(const uint64_t a, const uint64_t b,
#endif
}
template <class T>
template <class T> __device__
constexpr uint64_t field<T>::addc(const uint64_t a,
const uint64_t b,
const uint64_t carry_in,
@@ -119,7 +119,7 @@ constexpr uint64_t field<T>::addc(const uint64_t a,
#endif
}
template <class T>
template <class T> __device__
constexpr uint64_t field<T>::sbb(const uint64_t a,
const uint64_t b,
const uint64_t borrow_in,
@@ -139,7 +139,7 @@ constexpr uint64_t field<T>::sbb(const uint64_t a,
#endif
}
template <class T>
template <class T> __device__
constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
const uint64_t b,
const uint64_t c,
@@ -177,8 +177,9 @@ constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
#endif
}
template <class T> constexpr field<T> field<T>::reduce() const noexcept
template <class T> __device__ constexpr field<T> field<T>::reduce() const noexcept
{
constexpr uint256_t modulus = get_modulus();
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
uint256_t val{ data[0], data[1], data[2], data[3] };
if (val >= modulus) {
@@ -202,7 +203,7 @@ template <class T> constexpr field<T> field<T>::reduce() const noexcept
};
}
template <class T> constexpr field<T> field<T>::add(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::add(const field& other) const noexcept
{
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
uint64_t r0 = data[0] + other.data[0];
@@ -251,7 +252,7 @@ template <class T> constexpr field<T> field<T>::add(const field& other) const no
}
}
template <class T> constexpr field<T> field<T>::subtract(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::subtract(const field& other) const noexcept
{
uint64_t borrow = 0;
uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
@@ -283,8 +284,9 @@ template <class T> constexpr field<T> field<T>::subtract(const field& other) con
* @param other
* @return constexpr field<T>
*/
template <class T> constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
{
constexpr uint256_t modulus = get_modulus();
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return subtract(other);
}
@@ -309,8 +311,9 @@ template <class T> constexpr field<T> field<T>::subtract_coarse(const field& oth
* @details Explanation of Montgomery form can be found in \ref field_docs_montgomery_explainer and the difference
* between WASM and generic versions is explained in \ref field_docs_architecture_details
*/
template <class T> constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
{
constexpr uint256_t modulus = get_modulus();
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
uint64_t c = 0;
uint64_t t0 = 0;
@@ -531,8 +534,9 @@ template <class T> constexpr std::array<uint64_t, WASM_NUM_LIMBS> field<T>::wasm
(data[3] >> 40) & 0x1fffffff };
}
#endif
template <class T> constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
template <class T> __device__ constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
{
constexpr uint256_t modulus = get_modulus();
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return montgomery_mul_big(other);
}
@@ -653,6 +657,7 @@ template <class T> constexpr field<T> field<T>::montgomery_mul(const field& othe
template <class T> constexpr field<T> field<T>::montgomery_square() const noexcept
{
constexpr uint256_t modulus = get_modulus();
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return montgomery_mul_big(*this);
}

View File

@@ -1,7 +1,7 @@
#pragma once
#if (BBERG_NO_ASM == 0)
#include "./field_impl.hpp"
#include "./field_impl.cuh"
#include "asm_macros.hpp"
namespace bb {

View File

@@ -1,7 +1,7 @@
#pragma once
#include "../../common/serialize.hpp"
#include "../../ecc/curves/bn254/fq2.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include "../../ecc/curves/bn254/fq2.cuh"
#include "../../numeric/uint256/uint256.cuh"
#include <cstring>
#include <type_traits>
#include <vector>
@@ -76,12 +76,6 @@ template <typename Fq_, typename Fr_, typename Params> class alignas(64) affine_
static constexpr std::optional<affine_element> derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept;
/**
* @brief Samples a random point on the curve.
*
* @return A randomly chosen point on the curve
*/
static affine_element random_element(numeric::RNG* engine = nullptr) noexcept;
static constexpr affine_element hash_to_curve(const std::vector<uint8_t>& seed, uint8_t attempt_count = 0) noexcept
requires SupportsHashToCurve<Params>;

View File

@@ -1,5 +1,5 @@
#pragma once
#include "./element.hpp"
#include "./element.cuh"
#include "../../crypto/blake3s/blake3s.hpp"
#include "../../crypto/keccak/keccak.hpp"
@@ -187,104 +187,4 @@ constexpr std::optional<affine_element<Fq, Fr, T>> affine_element<Fq, Fr, T>::de
}
return std::nullopt;
}
/**
 * @brief Hash a seed buffer into a point on the curve.
 *
 * @details ALGORITHM DESCRIPTION:
 * 1. Copy `seed` into a buffer that is 2 bytes longer than `seed` (the extra bytes start at 0)
 * 2. Write `attempt_count` into the second-to-last byte of the buffer
 * 3. Set the last byte of the buffer to `0` and compute the Blake3s hash of the buffer (`hash_hi`)
 * 4. Set the last byte of the buffer to `1` and compute the Blake3s hash of the buffer (`hash_lo`)
 * 5. Interpret the two hash outputs as the high / low 256 bits of a 512-bit integer (big-endian bytes)
 * 6. Derive the x-coordinate by constructing an Fq element from that 512-bit integer
 * 7. Hand `x` and a sign bit (the most significant bit of `hash_hi`) to `derive_from_x_coordinate`
 * 8. IF no point was derived for this `x`
 *   8a. increment `attempt_count` by 1 (wrapping modulo 256) and go to step 2
 * 9. IF a point was derived
 *   9a. return it
 *
 * The selection of y from y^2 (y^2 = x^3 + ax + b) and the use of the sign bit to pick one of the two
 * square roots happen inside `derive_from_x_coordinate`, which is not part of this function.
 *
 * Retries are performed in a loop rather than by recursion, so the seed is copied only once and the
 * call depth does not grow with the number of failed attempts.
 *
 * @note This algorithm is constexpr: we can hash-to-curve (and derive generators) at compile-time!
 * @tparam Fq
 * @tparam Fr
 * @tparam T
 * @param seed Bytes that uniquely define the point being generated
 * @param attempt_count Value of the attempt counter to start from
 * @return constexpr affine_element<Fq, Fr, T>
 */
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::hash_to_curve(const std::vector<uint8_t>& seed,
                                                                             uint8_t attempt_count) noexcept
    requires SupportsHashToCurve<T>
{
    // custom serialize methods as common/serialize.hpp is not constexpr!
    const auto read_uint256 = [](const uint8_t* in) {
        // Accumulates 8 big-endian bytes into `out` (callers pass a zeroed limb).
        const auto read_limb = [](const uint8_t* in, uint64_t& out) {
            for (size_t i = 0; i < 8; ++i) {
                out += static_cast<uint64_t>(in[i]) << ((7 - i) * 8);
            }
        };
        uint256_t out = 0;
        read_limb(&in[0], out.data[3]);
        read_limb(&in[8], out.data[2]);
        read_limb(&in[16], out.data[1]);
        read_limb(&in[24], out.data[0]);
        return out;
    };

    // expand by 2 bytes to cover incremental hash attempts
    const size_t seed_size = seed.size();
    std::vector<uint8_t> target_seed(seed);
    target_seed.push_back(0);
    target_seed.push_back(0);

    while (true) {
        target_seed[seed_size] = attempt_count;
        target_seed[seed_size + 1] = 0;
        const auto hash_hi = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
        target_seed[seed_size + 1] = 1;
        const auto hash_lo = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());

        // interpret 64 byte hash output as a uint512_t, reduce to Fq element
        //(512 bits of entropy ensures result is not biased as 512 >> Fq::modulus.get_msb())
        Fq x(uint512_t(read_uint256(&hash_lo[0]), read_uint256(&hash_hi[0])));
        const bool sign_bit = hash_hi[0] > 127;
        std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
        if (result.has_value()) {
            return result.value();
        }
        // uint8_t arithmetic wraps modulo 256, matching the narrowing of `attempt_count + 1`
        // that the recursive formulation performed when passing the argument.
        ++attempt_count;
    }
}
/**
 * @brief Samples a random point on the curve.
 *
 * @details Repeatedly draws a random x-coordinate and a random sign bit and asks
 * `derive_from_x_coordinate` for a matching point; the first successful derivation is returned.
 * The loop only exits through that return, so no fallback path exists after it.
 *
 * @param engine Source of randomness; when null, `numeric::get_randomness()` is used.
 * @return A randomly chosen point on the curve
 */
template <typename Fq, typename Fr, typename T>
affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::random_element(numeric::RNG* engine) noexcept
{
    if (engine == nullptr) {
        engine = &numeric::get_randomness();
    }

    while (true) {
        // Sample a random x-coordinate and check if it satisfies curve equation.
        const Fq x = Fq::random_element(engine);
        // Negate the y-coordinate based on a randomly sampled bit.
        const bool sign_bit = (engine->get_random_uint8() & 1) != 0;

        std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
        if (result.has_value()) {
            return result.value();
        }
    }
}
} // namespace bb::group_elements

View File

@@ -1,10 +1,9 @@
#pragma once
#include "affine_element.hpp"
#include "affine_element.cuh"
#include "../../common/compiler_hints.hpp"
#include "../../common/mem.hpp"
#include "../../numeric/random/engine.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include "../../numeric/uint256/uint256.cuh"
#include "wnaf.hpp"
#include <array>
#include <random>
@@ -49,8 +48,6 @@ template <class Fq, class Fr, class Params> class alignas(32) element {
constexpr operator affine_element<Fq, Fr, Params>() const noexcept;
static element random_element(numeric::RNG* engine = nullptr) noexcept;
constexpr element dbl() const noexcept;
constexpr void self_dbl() noexcept;
constexpr void self_mixed_add_or_sub(const affine_element<Fq, Fr, Params>& other, uint64_t predicate) noexcept;
@@ -107,29 +104,6 @@ template <class Fq, class Fr, class Params> class alignas(32) element {
element mul_without_endomorphism(const Fr& scalar) const noexcept;
element mul_with_endomorphism(const Fr& scalar) const noexcept;
template <typename = typename std::enable_if<Params::can_hash_to_curve>>
static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept;
// {
// bool found_one = false;
// Fq yy;
// Fq x;
// Fq y;
// Fq t0;
// while (!found_one) {
// x = Fq::random_element(engine);
// yy = x.sqr() * x + Params::b;
// if constexpr (Params::has_a) {
// yy += (x * Params::a);
// }
// y = yy.sqrt();
// t0 = y.sqr();
// found_one = (yy == t0);
// }
// return { x, y, Fq::one() };
// }
// for serialization: update with new fields
// TODO(https://github.com/AztecProtocol/barretenberg/issues/908) point at infinity isn't handled
static void conditional_negate_affine(const affine_element<Fq, Fr, Params>& in,
affine_element<Fq, Fr, Params>& out,
uint64_t predicate) noexcept;

View File

@@ -1,8 +1,7 @@
#pragma once
#include "../../common/op_count.hpp"
#include "../../common/thread.hpp"
#include "./element.hpp"
#include "element.hpp"
#include "./element.cuh"
#include <cstdint>
// NOLINTBEGIN(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays)
@@ -577,23 +576,6 @@ constexpr bool element<Fq, Fr, T>::operator==(const element& other) const noexce
return both_infinity || ((lhs_x == rhs_x) && (lhs_y == rhs_y));
}
/**
 * @brief Produce a random group element in projective (x, y, z) form.
 *
 * @details When the curve parameters allow hashing to the curve, a random affine-style point is taken
 * from `random_coordinates_on_curve` and then rescaled by a random z (x *= z^2, y *= z^3).
 * Otherwise the generator (one_x, one_y, 1) is multiplied by a random scalar.
 *
 * @param engine Source of randomness forwarded to the field samplers.
 */
template <class Fq, class Fr, class T>
element<Fq, Fr, T> element<Fq, Fr, T>::random_element(numeric::RNG* engine) noexcept
{
    if constexpr (!T::can_hash_to_curve) {
        const Fr scalar = Fr::random_element(engine);
        return (element{ T::one_x, T::one_y, Fq::one() } * scalar);
    } else {
        element point = random_coordinates_on_curve(engine);
        point.z = Fq::random_element(engine);
        const Fq z_squared = point.z.sqr();
        const Fq z_cubed = z_squared * point.z;
        point.x *= z_squared;
        point.y *= z_cubed;
        return point;
    }
}
template <class Fq, class Fr, class T>
element<Fq, Fr, T> element<Fq, Fr, T>::mul_without_endomorphism(const Fr& scalar) const noexcept
{
@@ -1193,27 +1175,5 @@ void element<Fq, Fr, T>::batch_normalize(element* elements, const size_t num_ele
elements[i].z = Fq::one();
}
}
/**
 * @brief Rejection-sample a point with z = 1 whose coordinates satisfy the curve equation.
 *
 * @details Draws random x values until x^3 (+ a*x when the curve has an `a` term) + b has a square
 * root, as reported by `Fq::sqrt()`, and returns (x, root, 1).
 *
 * @param engine Source of randomness forwarded to `Fq::random_element`.
 */
template <typename Fq, typename Fr, typename T>
template <typename>
element<Fq, Fr, T> element<Fq, Fr, T>::random_coordinates_on_curve(numeric::RNG* engine) noexcept
{
    while (true) {
        const Fq x = Fq::random_element(engine);
        Fq rhs = x.sqr() * x + T::b;
        if constexpr (T::has_a) {
            rhs += (x * T::a);
        }
        const auto [is_square, root] = rhs.sqrt();
        if (is_square) {
            return { x, root, Fq::one() };
        }
    }
}
} // namespace bb::group_elements
// NOLINTEND(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays)

View File

@@ -1,8 +1,8 @@
#pragma once
#include "../../common/assert.hpp"
#include "./affine_element.hpp"
#include "./element.hpp"
#include "./affine_element.cuh"
#include "./element.cuh"
#include "./wnaf.hpp"
#include "../../common/constexpr_utils.hpp"
#include "../../crypto/blake3s/blake3s.hpp"

View File

@@ -1,157 +0,0 @@
#pragma once
#include "../../common/throw_or_abort.hpp"
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>
#include "../uint256/uint256.hpp"
namespace bb::numeric {
inline std::vector<uint64_t> slice_input(const uint256_t& input, const uint64_t base, const size_t num_slices)
{
uint256_t target = input;
std::vector<uint64_t> slices;
if (num_slices > 0) {
for (size_t i = 0; i < num_slices; ++i) {
slices.push_back((target % base).data[0]);
target /= base;
}
} else {
while (target > 0) {
slices.push_back((target % base).data[0]);
target /= base;
}
}
return slices;
}
inline std::vector<uint64_t> slice_input_using_variable_bases(const uint256_t& input,
const std::vector<uint64_t>& bases)
{
uint256_t target = input;
std::vector<uint64_t> slices;
for (size_t i = 0; i < bases.size(); ++i) {
if (target >= bases[i] && i == bases.size() - 1) {
throw_or_abort(format("Last key slice greater than ", bases[i]));
}
slices.push_back((target % bases[i]).data[0]);
target /= bases[i];
}
return slices;
}
/**
 * @brief Build the table { base^0, base^1, ..., base^(num_slices - 1) }.
 *
 * @tparam base Radix whose powers are tabulated.
 * @tparam num_slices Number of powers to produce; 0 yields an empty array.
 * @return Array whose i-th entry is the previous entry multiplied by `base`, starting from 1.
 */
template <uint64_t base, uint64_t num_slices> constexpr std::array<uint256_t, num_slices> get_base_powers()
{
    std::array<uint256_t, num_slices> output{};
    // A zero-length table has no element 0 to seed; writing to it would be out of bounds.
    if constexpr (num_slices > 0) {
        output[0] = 1;
        for (size_t i = 1; i < num_slices; ++i) {
            output[i] = output[i - 1] * base;
        }
    }
    return output;
}
/**
 * @brief Re-encode the low 32 bits of `input` as a base-`base` number with 0/1 digits.
 *
 * @details For every set bit i (0 <= i < 32) of `input`, base^i is added to the result.
 * Bits 32 and above of `input` are not inspected.
 */
template <uint64_t base> constexpr uint256_t map_into_sparse_form(const uint64_t input)
{
    constexpr auto powers = get_base_powers<base, 32>();

    uint256_t sparse = 0UL;
    for (size_t bit_index = 0; bit_index < 32; ++bit_index) {
        const bool bit_is_set = ((input >> bit_index) & 1U) != 0;
        if (bit_is_set) {
            sparse += powers[bit_index];
        }
    }
    return sparse;
}
/**
 * @brief Collapse a 32-digit base-`base` value into a 32-bit mask of digit parities.
 *
 * @details Digits are processed from the most significant (base^31) down to base^0. For each
 * position the digit d is located by comparing the remaining value against successive multiples
 * of that position's power; bit (31 - i) of the output is set when d is odd, and d * base^(31 - i)
 * is then removed from the remaining value. If the remaining value is not below base * base^(31 - i),
 * no multiple matches and the position is left untouched.
 *
 * @param input Value in sparse (base-`base`) form.
 * @return 64-bit word whose low 32 bits hold the per-digit parity bits.
 */
template <uint64_t base> constexpr uint64_t map_from_sparse_form(const uint256_t& input)
{
uint256_t target = input;
uint64_t output = 0;
constexpr auto bases = get_base_powers<base, 32>();
for (uint64_t i = 0; i < 32; ++i) {
// Power of the digit currently being extracted (most significant first).
const auto& base_power = bases[static_cast<size_t>(31 - i)];
// Running value of (j - 1) * base_power.
uint256_t prev_threshold = 0;
for (uint64_t j = 1; j < base + 1; ++j) {
// threshold == j * base_power; the digit equals j - 1 once target drops below it.
const auto threshold = prev_threshold + base_power;
if (target < threshold) {
// Only the parity of the digit is kept.
bool bit = ((j - 1) & 1);
if (bit) {
output += (1ULL << (31ULL - i));
}
// Strip this digit's contribution; nothing to strip when the digit is 0.
if (j > 1) {
target -= (prev_threshold);
}
break;
}
prev_threshold = threshold;
}
}
return output;
}
/**
 * @brief Integer that is tracked both as a plain 64-bit value and as `num_bits` base-`base` digits.
 *
 * @details Construction spreads the bits of the input over the digit array (digit i = bit i).
 * Addition adds digit-wise with a single carry into the next digit, reduces the top digit modulo
 * `base`, and adds the plain values.
 *
 * @tparam base Radix of the digit representation.
 * @tparam num_bits Number of digits stored.
 */
template <uint64_t base, size_t num_bits> class sparse_int {
  public:
    /**
     * @brief Build from a 64-bit value; digit i is bit i of `input`, digits past bit 63 are zero.
     */
    sparse_int(const uint64_t input = 0)
        : value(input)
    {
        for (size_t i = 0; i < num_bits; ++i) {
            // Shifting a 64-bit value by 64 or more is undefined, so higher digits are set to 0 directly.
            const uint64_t bit = (i < 64) ? ((input >> i) & 1U) : 0U;
            limbs[i] = bit;
        }
    }
    sparse_int(const sparse_int& other) noexcept = default;
    sparse_int(sparse_int&& other) noexcept = default;

    sparse_int& operator=(const sparse_int& other) noexcept = default;
    sparse_int& operator=(sparse_int&& other) noexcept = default;
    ~sparse_int() noexcept = default;

    /**
     * @brief Digit-wise sum of two sparse integers.
     * @return A new object holding the summed digits and the summed plain values.
     */
    sparse_int operator+(const sparse_int& other) const
    {
        sparse_int result(*this);
        for (size_t i = 0; i < num_bits - 1; ++i) {
            result.limbs[i] += other.limbs[i];
            if (result.limbs[i] >= base) {
                // Carry one unit into the next digit.
                result.limbs[i] -= base;
                ++result.limbs[i + 1];
            }
        }
        // The top digit has nowhere to carry to, so it simply wraps modulo base.
        result.limbs[num_bits - 1] += other.limbs[num_bits - 1];
        result.limbs[num_bits - 1] %= base;
        result.value += other.value;
        return result;
    };

    /**
     * @brief In-place addition.
     * @return Reference to this object after the addition.
     */
    sparse_int& operator+=(const sparse_int& other)
    {
        *this = *this + other;
        return *this;
    }

    /** @brief Plain 64-bit value tracked alongside the digits. */
    [[nodiscard]] uint64_t get_value() const { return value; }

    /**
     * @brief Evaluate the digits as a base-`base` number (digit 0 is least significant).
     */
    [[nodiscard]] uint64_t get_sparse_value() const
    {
        uint64_t result = 0;
        // Counts down from the top digit; the unsigned index wraps past 0 and ends the loop.
        for (size_t i = num_bits - 1; i < num_bits; --i) {
            result *= base;
            result += limbs[i];
        }
        return result;
    }

    /** @brief Read-only access to the digit array. */
    const std::array<uint64_t, num_bits>& get_limbs() const { return limbs; }

  private:
    std::array<uint64_t, num_bits> limbs;
    uint64_t value;
    // Not read by any member of this class; zero-initialised so that copies never read an
    // indeterminate value.
    uint64_t sparse_value = 0;
};
} // namespace bb::numeric

View File

@@ -1,139 +0,0 @@
#include "engine.hpp"
#include "../../common/assert.hpp"
#include <array>
#include <functional>
#include <random>
namespace bb::numeric {
namespace {
/**
 * @brief Fill a 32-word buffer with values drawn from a freshly constructed std::random_device.
 * @return The filled buffer.
 */
auto generate_random_data()
{
    std::random_device entropy_source;
    std::array<unsigned int, 32> words;
    // One draw per element, in index order.
    for (auto& word : words) {
        word = entropy_source();
    }
    return words;
}
} // namespace
/**
 * @brief RNG implementation backed by `generate_random_data()`.
 *
 * Every getter requests a fresh 32-word buffer and uses only the leading words it needs.
 */
class RandomEngine : public RNG {
  public:
    /** @brief Lowest 8 bits of the first word of a fresh buffer. */
    uint8_t get_random_uint8() override
    {
        const auto words = generate_random_data();
        const uint32_t first = words[0];
        return static_cast<uint8_t>(first);
    }

    /** @brief Lowest 16 bits of the first word of a fresh buffer. */
    uint16_t get_random_uint16() override
    {
        const auto words = generate_random_data();
        const uint32_t first = words[0];
        return static_cast<uint16_t>(first);
    }

    /** @brief First word of a fresh buffer. */
    uint32_t get_random_uint32() override
    {
        const auto words = generate_random_data();
        const uint32_t first = words[0];
        return static_cast<uint32_t>(first);
    }

    /** @brief Words 0 (low half) and 1 (high half) of a fresh buffer combined into 64 bits. */
    uint64_t get_random_uint64() override
    {
        const auto words = generate_random_data();
        const auto low_half = static_cast<uint64_t>(words[0]);
        const auto high_half = static_cast<uint64_t>(words[1]);
        // The halves occupy disjoint bit ranges, so OR-ing them equals adding them.
        return ((high_half << 32ULL) | low_half);
    }

    /** @brief Lowest two limbs of a random 256-bit value combined into 128 bits. */
    uint128_t get_random_uint128() override
    {
        const auto wide = get_random_uint256();
        const auto low_half = static_cast<uint128_t>(wide.data[0]);
        const auto high_half = static_cast<uint128_t>(wide.data[1]);
        return (low_half + (high_half << static_cast<uint128_t>(64ULL)));
    }

    /** @brief Words 0..7 of a fresh buffer packed pairwise into four 64-bit limbs. */
    uint256_t get_random_uint256() override
    {
        const auto words = generate_random_data();
        // Joins words[offset] (low half) and words[offset + 1] (high half).
        const auto limb_at = [&words](const size_t offset) -> uint64_t {
            const auto low_half = static_cast<uint64_t>(words[0 + offset]);
            const auto high_half = static_cast<uint64_t>(words[1 + offset]);
            return (low_half + (high_half << 32ULL));
        };
        const uint64_t limb0 = limb_at(0);
        const uint64_t limb1 = limb_at(2);
        const uint64_t limb2 = limb_at(4);
        const uint64_t limb3 = limb_at(6);
        return { limb0, limb1, limb2, limb3 };
    }
};
/**
 * @brief RNG implementation driven by a seeded std::mt19937_64, giving repeatable output.
 *
 * All getters draw 64-bit values from `dist(engine)`; narrower results are obtained by casting a
 * single draw, wider results by combining several draws in a fixed order.
 */
class DebugEngine : public RNG {
public:
// Default construction uses the fixed seed 12345.
DebugEngine()
// disable linting for this line: we want the DEBUG engine to produce predictable pseudorandom numbers!
// NOLINTNEXTLINE(cert-msc32-c, cert-msc51-cpp)
: engine(std::mt19937_64(12345))
{}
// Construct with a caller-chosen seed.
DebugEngine(std::uint_fast64_t seed)
: engine(std::mt19937_64(seed))
{}
// Each of the following consumes exactly one 64-bit draw and truncates it where needed.
uint8_t get_random_uint8() override { return static_cast<uint8_t>(dist(engine)); }
uint16_t get_random_uint16() override { return static_cast<uint16_t>(dist(engine)); }
uint32_t get_random_uint32() override { return static_cast<uint32_t>(dist(engine)); }
uint64_t get_random_uint64() override { return dist(engine); }
// Two draws: the first becomes the high 64 bits, the second the low 64 bits.
uint128_t get_random_uint128() override
{
uint128_t hi = dist(engine);
uint128_t lo = dist(engine);
return (hi << 64) | lo;
}
// Four draws, passed to the result in the order they were drawn.
uint256_t get_random_uint256() override
{
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
auto a = dist(engine);
auto b = dist(engine);
auto c = dist(engine);
auto d = dist(engine);
return { a, b, c, d };
}
private:
std::mt19937_64 engine;
// Uniform over the full 64-bit range [0, UINT64_MAX].
std::uniform_int_distribution<uint64_t> dist = std::uniform_int_distribution<uint64_t>{ 0ULL, UINT64_MAX };
};
/**
 * Used by tests to ensure consistent behavior.
 *
 * Returns a function-local static DebugEngine. It is created with the default seed on first use;
 * `seed` only takes effect when `reset` is true, in which case the engine is replaced by one
 * constructed from `seed`.
 */
RNG& get_debug_randomness(bool reset, std::uint_fast64_t seed)
{
    static DebugEngine shared_engine{};
    if (!reset) {
        return shared_engine;
    }
    shared_engine = DebugEngine(seed);
    return shared_engine;
}
/**
 * Default engine: a function-local static RandomEngine shared by all callers.
 * If wanting consistent proof construction, uncomment the line to return the debug engine.
 */
RNG& get_randomness()
{
    // return get_debug_randomness();
    static RandomEngine default_engine;
    RNG& shared = default_engine;
    return shared;
}
} // namespace bb::numeric

View File

@@ -1,52 +0,0 @@
#pragma once
#include "../uint128/uint128.hpp"
#include "../uint256/uint256.hpp"
#include "../uintx/uintx.hpp"
#include "unistd.h"
#include <cstdint>
#include <random>
namespace bb::numeric {
/**
 * @brief Abstract interface for random number sources.
 *
 * Implementations provide the 8- to 256-bit getters; the 512- and 1024-bit getters are built here
 * from two calls to the next smaller getter.
 */
class RNG {
public:
// Primitive getters every engine must implement.
virtual uint8_t get_random_uint8() = 0;
virtual uint16_t get_random_uint16() = 0;
virtual uint32_t get_random_uint32() = 0;
virtual uint64_t get_random_uint64() = 0;
virtual uint128_t get_random_uint128() = 0;
virtual uint256_t get_random_uint256() = 0;
virtual ~RNG() = default;
RNG() noexcept = default;
RNG(const RNG& other) = default;
RNG(RNG&& other) = default;
RNG& operator=(const RNG& other) = default;
RNG& operator=(RNG&& other) = default;
// Two 256-bit draws: the first is the low half, the second the high half.
uint512_t get_random_uint512()
{
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
auto lo = get_random_uint256();
auto hi = get_random_uint256();
return { lo, hi };
}
// Two 512-bit draws: the first is the low half, the second the high half.
uint1024_t get_random_uint1024()
{
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
auto lo = get_random_uint512();
auto hi = get_random_uint512();
return { lo, hi };
}
};
RNG& get_debug_randomness(bool reset = false, std::uint_fast64_t seed = 12345);
RNG& get_randomness();
} // namespace bb::numeric

View File

@@ -5,7 +5,7 @@
#include "../../common/assert.hpp"
namespace bb::numeric {
constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
{
const uint32_t a_lo = a & 0xffffULL;
const uint32_t a_hi = a >> 16ULL;
@@ -23,7 +23,7 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, co
}
// compute a + b + carry, returning the carry
constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
{
const uint32_t sum = a + b;
const auto carry_temp = static_cast<uint32_t>(sum < a);
@@ -32,12 +32,12 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const
return { r, carry_out };
}
constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
// compute a + b + carry_in, keeping only the low 32 bits (the carry out is dropped)
__device__ constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
{
    const uint32_t partial_sum = a + b;
    return partial_sum + carry_in;
}
constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
{
const uint32_t t_1 = a - (borrow_in >> 31ULL);
const auto borrow_temp_1 = static_cast<uint32_t>(t_1 > a);
@@ -47,13 +47,13 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const u
return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
}
constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
// compute a - b - borrow, keeping only the low 32 bits (the borrow out is dropped)
__device__ constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
{
    // Only the top bit of borrow_in contributes to the subtraction.
    const uint32_t incoming_borrow = borrow_in >> 31ULL;
    const uint32_t difference = a - b;
    return difference - incoming_borrow;
}
// {r, carry_out} = a + carry_in + b * c
constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
__device__ constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
const uint32_t b,
const uint32_t c,
const uint32_t carry_in)
@@ -67,7 +67,7 @@ constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
return result;
}
constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
__device__ constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
const uint32_t b,
const uint32_t c,
const uint32_t carry_in)
@@ -75,7 +75,7 @@ constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
return (b * c + a + carry_in);
}
constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
__device__ constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
{
if (*this == 0 || b == 0) {
return { 0, 0 };
@@ -123,7 +123,7 @@ constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b)
return { quotient, remainder };
}
constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
__device__ constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
{
const auto [r0, t0] = mul_wide(data[0], other.data[0]);
const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);

View File

@@ -1,6 +1,6 @@
#pragma once
#include "../bitop/get_msb.hpp"
#include "./uint256.hpp"
#include "./uint256.cuh"
#include "../../common/assert.hpp"
namespace bb::numeric {

View File

@@ -11,7 +11,7 @@
**/
#pragma once
#include "../uint256/uint256.hpp"
#include "../uint256/uint256.cuh"
#include "../../common/assert.hpp"
#include "../../common/throw_or_abort.hpp"
#include <cstdint>
@@ -166,13 +166,4 @@ template <class base_uint> inline std::ostream& operator<<(std::ostream& os, uin
os << a.lo << ", " << a.hi << std::endl;
return os;
}
using uint512_t = uintx<numeric::uint256_t>;
using uint1024_t = uintx<uint512_t>;
} // namespace bb::numeric
#include "./uintx_impl.hpp"
using bb::numeric::uint1024_t; // NOLINT
using bb::numeric::uint512_t; // NOLINT

View File

@@ -1,5 +1,5 @@
#pragma once
#include "./uintx.hpp"
#include "./uintx.cuh"
#include "../../common/assert.hpp"
namespace bb::numeric {

View File

@@ -0,0 +1,10 @@
#ifndef __PRIME_FIELD_H__
#define __PRIME_FIELD_H__
#include <stdint.h>
// Plain container of four 64-bit limbs (32 bytes of payload).
// NOTE(review): presumably mirrors the limb layout of bb::fq so values can be copied across a
// language boundary byte-for-byte — confirm against the field declaration.
struct FieldBinding {
uint64_t data[4]; // limb storage; the meaning of each index is not defined in this header
};
#endif

View File

@@ -0,0 +1,25 @@
#include <cassert>
#include <iostream>
#include <cstring>
#include "prime_field.h"
#include "./barretenberg/ecc/curves/bn254/fq.hpp"
int main() {
bb::fq a = bb::fq(1UL);
bb::fq b = bb::fq(2U);
bb::fq c = a + b;
assert(c == bb::fq(3U));
c = a * b;
assert(c == bb::fq(2U));
// memory layout test
struct Field val_c { { 1UL, 2UL, 3UL, 4UL } };
bb::fq val;
std::memcpy(&val, &val_c, 32);
assert(val.data[0] == 1UL);
assert(val.data[1] == 2UL);
assert(val.data[2] == 3UL);
assert(val.data[3] == 4UL);
std::cout << "test ended" << std::endl;
return 0;
}

View File

@@ -0,0 +1 @@
#include "prime_field.h"