mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-08 23:18:00 -05:00
20
sumcheck/src/cuda/includes/barretenberg/common/assert.hpp
Normal file
@@ -0,0 +1,20 @@
#pragma once

// NOLINTBEGIN
#if NDEBUG
// Compiler should optimize this out in release builds, without triggering an unused variable warning.
#define DONT_EVALUATE(expression)                                                                                      \
    {                                                                                                                  \
        true ? static_cast<void>(0) : static_cast<void>((expression));                                                 \
    }
#define ASSERT(expression) DONT_EVALUATE((expression))
#else
// cassert in wasi-sdk takes one second to compile, only include if needed
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <string>
#define ASSERT(expression) assert((expression))
#endif // NDEBUG

// NOLINTEND
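A minimal usage sketch (not part of the diff; `check_positive` is a hypothetical caller) showing how the two ASSERT branches behave:

#include "assert.hpp"

int check_positive(int x)
{
    // Debug builds: assert fires if x <= 0. Release (NDEBUG) builds: the
    // expression is parsed and type-checked but never evaluated, so `x`
    // still counts as used and no warning is emitted.
    ASSERT(x > 0);
    return x * 2;
}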
@@ -0,0 +1,26 @@
#pragma once

#ifdef _WIN32
#define BB_INLINE __forceinline inline
#else
#define BB_INLINE __attribute__((always_inline)) inline
#endif

// TODO(AD): Other instrumentation?
#ifdef XRAY
#define BB_PROFILE [[clang::xray_always_instrument]] [[clang::noinline]]
#define BB_NO_PROFILE [[clang::xray_never_instrument]]
#else
#define BB_PROFILE
#define BB_NO_PROFILE
#endif

// Optimization hints for clang - which outcome of an expression is expected for better
// branch-prediction optimization
#ifdef __clang__
#define BB_LIKELY(x) __builtin_expect(!!(x), 1)
#define BB_UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#define BB_LIKELY(x) x
#define BB_UNLIKELY(x) x
#endif
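A brief sketch (assumed usage, not part of the diff) of the hints in a bounds check where the failure path is expected to be cold:

#include "compiler_hints.hpp"
#include <cstddef>

BB_INLINE bool in_bounds(size_t index, size_t size)
{
    // Tell clang the out-of-bounds branch is unlikely so the in-bounds path
    // is laid out as the fall-through; on other compilers this is a no-op.
    if (BB_UNLIKELY(index >= size)) {
        return false;
    }
    return true;
}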
@@ -0,0 +1,162 @@
#pragma once

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

/**
 * @brief constexpr_utils defines some helper methods that perform some stl-equivalent operations
 * but in a constexpr context over quantities known at compile-time
 *
 * Current methods are:
 *
 * constexpr_for : loop over a range, where the size_t iterator `i` is a constexpr variable
 * constexpr_find : find if an element is in an array
 */
namespace bb {

/**
 * @brief Implements a loop using a compile-time iterator. Requires c++20.
 * Implementation (and description) from https://artificial-mind.net/blog/2020/10/31/constexpr-for
 *
 * @tparam Start the loop start value
 * @tparam End the loop end value
 * @tparam Inc how much the iterator increases by per iteration
 * @tparam F a Lambda function that is executed once per loop
 *
 * @param f An rvalue reference to the lambda
 * @details Implements a `for` loop where the iterator is a constexpr variable.
 * Use this when you need to evaluate `if constexpr` statements on the iterator (or apply other constexpr expressions)
 * Outside of this use-case avoid using this fn as it gives negligible performance increases vs regular loops.
 *
 * N.B. A side-effect of this method is that all loops will be unrolled
 * (each loop iteration uses different iterator template parameters => unique constexpr_for implementation per
 * iteration)
 * Do not use this for large (~100+) loops!
 *
 * ##############################
 * EXAMPLE USE OF `constexpr_for`
 * ##############################
 *
 * constexpr_for<0, 10, 1>([&]<size_t i>(){
 *     if constexpr ((i & 1) == 0)
 *     {
 *         foo[i] = even_container[i >> 1];
 *     }
 *     else
 *     {
 *         foo[i] = odd_container[i >> 1];
 *     }
 * });
 *
 * In the above example we are iterating from i = 0 to i < 10.
 * The provided lambda function has captured everything in its surrounding scope (via `[&]`),
 * which is where `foo`, `even_container` and `odd_container` have come from.
 *
 * We do not need to explicitly define the `class F` parameter as the compiler derives it from our provided input
 * argument `F&& f` (i.e. the lambda function)
 *
 * In the loop itself we're evaluating a constexpr if statement that defines which code path is taken.
 *
 * The above example benefits from `constexpr_for` because a run-time `if` statement has been reduced to a compile-time
 * `if` statement. N.B. this would only give measurable improvements if the `constexpr_for` statement is itself in a hot
 * loop that's iterated over many (>thousands) times
 */
template <size_t Start, size_t End, size_t Inc, class F> constexpr void constexpr_for(F&& f)
{
    // Call function `f<Start>()` iff Start < End
    if constexpr (Start < End) {
        // F must be a template lambda with a single **typed** template parameter that represents the iterator
        // (e.g. [&]<size_t i>(){ ... } is good)
        // (and [&]<typename i>(){ ... } won't compile!)

        /**
         * Explaining f.template operator()<Start>()
         *
         * The following line must explicitly tell the compiler that <Start> is a template parameter by using the
         * `template` keyword.
         * (if we wrote f<Start>(), the compiler could legitimately interpret `<` as a less than symbol)
         *
         * The fragment `f.template` tells the compiler that we're calling a *templated* member of `f`.
         * The "member" being called is the function operator, `operator()`, which must be explicitly provided
         * (for any function X, `X(args)` is an alias for `X.operator()(args)`)
         * The compiler has no alias `X.template <tparam>(args)` for `X.template operator()<tparam>(args)` so we must
         * write it explicitly here
         *
         * To summarize what the next line tells the compiler...
         * 1. I want to call a member of `f` that expects one or more template parameters
         * 2. The member of `f` that I want to call is the function operator
         * 3. The template parameter is `Start`
         * 4. The function operator itself contains no arguments
         */
        f.template operator()<Start>();

        // Once we have executed `f`, we recursively call the `constexpr_for` function, increasing the value of `Start`
        // by `Inc`
        constexpr_for<Start + Inc, End, Inc>(f);
    }
}

/**
 * @brief returns true/false depending on whether `key` is in `container`
 *
 * @tparam container i.e. what are we looking in?
 * @tparam key i.e. what are we looking for?
 * @return true found!
 * @return false not found!
 *
 * @details method is constexpr and can be used in static_asserts
 */
template <const auto& container, auto key> constexpr bool constexpr_find()
{
    // using ElementType = typename std::remove_extent<ContainerType>::type;
    bool found = false;
    constexpr_for<0, container.size(), 1>([&]<size_t k>() {
        if constexpr (std::get<k>(container) == key) {
            found = true;
        }
    });
    return found;
}

/**
 * @brief Create a constexpr array object whose elements contain a default value
 *
 * @tparam T type contained in the array
 * @tparam Is index sequence
 * @param value the value each array element is being initialized to
 * @return constexpr std::array<T, sizeof...(Is)>
 *
 * @details This method is used to create constexpr arrays whose encapsulated type:
 *
 * 1. HAS NO CONSTEXPR DEFAULT CONSTRUCTOR
 * 2. HAS A CONSTEXPR COPY CONSTRUCTOR
 *
 * An example of this is bb::field_t
 * (the default constructor does not default assign values to the field_t member variables for efficiency reasons, to
 * reduce the time required to construct large arrays of field elements. This means the default constructor for field_t
 * cannot be constexpr)
 */
template <typename T, std::size_t... Is>
constexpr std::array<T, sizeof...(Is)> create_array(T value, std::index_sequence<Is...> /*unused*/)
{
    // cast Is to void to remove the warning: unused value
    std::array<T, sizeof...(Is)> result = { { (static_cast<void>(Is), value)... } };
    return result;
}

/**
 * @brief Create a constexpr array object whose values all are 0
 *
 * @tparam T
 * @tparam N
 * @return constexpr std::array<T, N>
 *
 * @details Use in the same context as create_array, i.e. when encapsulated type has a default constructor that is not
 * constexpr
 */
template <typename T, size_t N> constexpr std::array<T, N> create_empty_array()
{
    return create_array(T(0), std::make_index_sequence<N>());
}
}; // namespace bb
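A compile-time sketch (not part of the diff; `PRIMES` is a made-up example array) exercising both helpers:

#include "constexpr_utils.hpp"
#include <array>
#include <cstddef>

constexpr std::array<int, 4> PRIMES = { 2, 3, 5, 7 };

constexpr int sum_primes()
{
    int sum = 0;
    // Fully unrolled: each iteration is a distinct template instantiation.
    bb::constexpr_for<0, PRIMES.size(), 1>([&]<size_t i>() { sum += std::get<i>(PRIMES); });
    return sum;
}

static_assert(sum_primes() == 17);
static_assert(bb::constexpr_find<PRIMES, 5>());
static_assert(!bb::constexpr_find<PRIMES, 6>());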
129
sumcheck/src/cuda/includes/barretenberg/common/log.hpp
Normal file
@@ -0,0 +1,129 @@
#pragma once
#include "../env/logstr.hpp"
#include "../stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp"
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

#define BENCHMARK_INFO_PREFIX "##BENCHMARK_INFO_PREFIX##"
#define BENCHMARK_INFO_SEPARATOR "#"
#define BENCHMARK_INFO_SUFFIX "##BENCHMARK_INFO_SUFFIX##"

template <typename... Args> std::string format(Args... args)
{
    std::ostringstream os;
    ((os << args), ...);
    return os.str();
}

template <typename T> void benchmark_format_chain(std::ostream& os, T const& first)
{
    // We will be saving these values to a CSV file, so we can't tolerate commas
    std::stringstream current_argument;
    current_argument << first;
    std::string current_argument_string = current_argument.str();
    std::replace(current_argument_string.begin(), current_argument_string.end(), ',', ';');
    os << current_argument_string << BENCHMARK_INFO_SUFFIX;
}

template <typename T, typename... Args>
void benchmark_format_chain(std::ostream& os, T const& first, Args const&... args)
{
    // We will be saving these values to a CSV file, so we can't tolerate commas
    std::stringstream current_argument;
    current_argument << first;
    std::string current_argument_string = current_argument.str();
    std::replace(current_argument_string.begin(), current_argument_string.end(), ',', ';');
    os << current_argument_string << BENCHMARK_INFO_SEPARATOR;
    benchmark_format_chain(os, args...);
}

template <typename... Args> std::string benchmark_format(Args... args)
{
    std::ostringstream os;
    os << BENCHMARK_INFO_PREFIX;
    benchmark_format_chain(os, args...);
    return os.str();
}

#if NDEBUG
// debug() compiles out entirely in release (NDEBUG) builds.
template <typename... Args> inline void debug(Args... /*unused*/) {}
#else
template <typename... Args> inline void debug(Args... args)
{
    logstr(format(args...).c_str());
}
#endif

template <typename... Args> inline void info(Args... args)
{
    logstr(format(args...).c_str());
}

template <typename... Args> inline void important(Args... args)
{
    logstr(format("important: ", args...).c_str());
}

/**
 * @brief Info used to store circuit statistics during CI/CD with concrete structure. Writes straight to log
 *
 * @details Automatically appends the necessary prefix and suffix, as well as separators.
 *
 * @tparam Args
 * @param args
 */
#ifdef CI
template <typename Arg1, typename Arg2, typename Arg3, typename Arg4, typename Arg5>
inline void benchmark_info(Arg1 composer, Arg2 class_name, Arg3 operation, Arg4 metric, Arg5 value)
{
    logstr(benchmark_format(composer, class_name, operation, metric, value).c_str());
}
#else
template <typename... Args> inline void benchmark_info(Args... /*unused*/) {}
#endif

/**
 * @brief A class for saving benchmarks and printing them all at once at the end of a function.
 */
class BenchmarkInfoCollator {

    std::vector<std::string> saved_benchmarks;

  public:
    BenchmarkInfoCollator() = default;
    BenchmarkInfoCollator(const BenchmarkInfoCollator& other) = default;
    BenchmarkInfoCollator(BenchmarkInfoCollator&& other) = default;
    BenchmarkInfoCollator& operator=(const BenchmarkInfoCollator& other) = default;
    BenchmarkInfoCollator& operator=(BenchmarkInfoCollator&& other) = default;

    /**
     * @brief Info used to store circuit statistics during CI/CD with concrete structure. Stores string in vector for
     * now (used to flush all benchmarks at the end of test).
     *
     * @details Automatically appends the necessary prefix and suffix, as well as separators.
     *
     * @tparam Args
     * @param args
     */
#ifdef CI
    template <typename Arg1, typename Arg2, typename Arg3, typename Arg4, typename Arg5>
    inline void benchmark_info_deferred(Arg1 composer, Arg2 class_name, Arg3 operation, Arg4 metric, Arg5 value)
    {
        saved_benchmarks.push_back(benchmark_format(composer, class_name, operation, metric, value));
    }
#else
    explicit BenchmarkInfoCollator(std::vector<std::string> saved_benchmarks)
        : saved_benchmarks(std::move(saved_benchmarks))
    {}
    template <typename... Args> inline void benchmark_info_deferred(Args... /*unused*/) {}
#endif
    ~BenchmarkInfoCollator()
    {
        for (auto& x : saved_benchmarks) {
            logstr(x.c_str());
        }
    }
};
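For reference (derived from the definitions above, not part of the diff), the shape a five-argument benchmark line takes:

// benchmark_format("UltraPlonk", "Prover", "construct_proof", "time_ms", 1234) returns
//   "##BENCHMARK_INFO_PREFIX##UltraPlonk#Prover#construct_proof#time_ms#1234##BENCHMARK_INFO_SUFFIX##"
// with any ',' inside an argument already rewritten to ';' so the line stays CSV-safe.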
82
sumcheck/src/cuda/includes/barretenberg/common/mem.hpp
Normal file
@@ -0,0 +1,82 @@
#pragma once
#include "log.hpp"
#include "memory.h"
#include "wasm_export.hpp"
#include <cstdlib>
#include <memory>
// #include <malloc.h>

// Rounds size up to the next multiple of alignment (no-op if already aligned).
#define pad(size, alignment) (size - (size % alignment) + ((size % alignment) == 0 ? 0 : alignment))

#ifdef __APPLE__
inline void* aligned_alloc(size_t alignment, size_t size)
{
    void* t = 0;
    posix_memalign(&t, alignment, size);
    if (t == 0) {
        info("bad alloc of size: ", size);
        std::abort();
    }
    return t;
}

inline void aligned_free(void* mem)
{
    free(mem);
}
#endif

#if defined(__linux__) || defined(__wasm__)
inline void* protected_aligned_alloc(size_t alignment, size_t size)
{
    size += (size % alignment);
    void* t = nullptr;
    // pad size to alignment
    if (size % alignment != 0) {
        size += alignment - (size % alignment);
    }
    // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
    t = aligned_alloc(alignment, size);
    if (t == nullptr) {
        info("bad alloc of size: ", size);
        std::abort();
    }
    return t;
}

#define aligned_alloc protected_aligned_alloc

inline void aligned_free(void* mem)
{
    // NOLINTNEXTLINE(cppcoreguidelines-owning-memory, cppcoreguidelines-no-malloc)
    free(mem);
}
#endif

#ifdef _WIN32
inline void* aligned_alloc(size_t alignment, size_t size)
{
    return _aligned_malloc(size, alignment);
}

inline void aligned_free(void* mem)
{
    _aligned_free(mem);
}
#endif

// inline void print_malloc_info()
// {
//     struct mallinfo minfo = mallinfo();

//     info("Total non-mmapped bytes (arena): ", minfo.arena);
//     info("Number of free chunks (ordblks): ", minfo.ordblks);
//     info("Number of fastbin blocks (smblks): ", minfo.smblks);
//     info("Number of mmapped regions (hblks): ", minfo.hblks);
//     info("Space allocated in mmapped regions (hblkhd): ", minfo.hblkhd);
//     info("Maximum total allocated space (usmblks): ", minfo.usmblks);
//     info("Space available in freed fastbin blocks (fsmblks): ", minfo.fsmblks);
//     info("Total allocated space (uordblks): ", minfo.uordblks);
//     info("Total free space (fordblks): ", minfo.fordblks);
//     info("Top-most, releasable space (keepcost): ", minfo.keepcost);
// }
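The pad macro's rounding behaviour, spelled out (a sketch, not part of the diff):

// pad(100, 32) -> 100 - (100 % 32) + 32 = 96 + 32 = 128   (rounded up to the next multiple)
// pad(128, 32) -> 128 - 0 + 0 = 128                       (already aligned: unchanged)
static_assert(pad(100, 32) == 128);
static_assert(pad(128, 32) == 128);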
15
sumcheck/src/cuda/includes/barretenberg/common/net.hpp
Normal file
@@ -0,0 +1,15 @@
#pragma once

#if defined(__linux__) || defined(__wasm__)
#include <arpa/inet.h>
#include <endian.h>
#define ntohll be64toh
#define htonll htobe64
#endif

inline bool is_little_endian()
{
    constexpr int num = 42;
    // NOLINTNEXTLINE Nope. nope nope nope nope nope.
    return (*(char*)&num == 42);
}
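As an aside (not part of the diff): with C++20 the same question can be answered at compile time via <bit>, avoiding the pointer cast flagged by the NOLINT above:

#include <bit>

constexpr bool is_little_endian_cxx20()
{
    return std::endian::native == std::endian::little;
}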
104
sumcheck/src/cuda/includes/barretenberg/common/op_count.cpp
Normal file
@@ -0,0 +1,104 @@

#include <cstddef>
#ifdef BB_USE_OP_COUNT
#include "op_count.hpp"
#include <chrono>
#include <iostream>
#include <sstream>
#include <thread>

namespace bb::detail {

GlobalOpCountContainer::~GlobalOpCountContainer()
{
    // This is useful for printing counts at the end of non-benchmarks.
    // See op_count_google_bench.hpp for benchmarks.
    // print();
}

void GlobalOpCountContainer::add_entry(const char* key, const std::shared_ptr<OpStats>& count)
{
    std::unique_lock<std::mutex> lock(mutex);
    std::stringstream ss;
    ss << std::this_thread::get_id();
    counts.push_back({ key, ss.str(), count });
}

void GlobalOpCountContainer::print() const
{
    std::cout << "print_op_counts() START" << std::endl;
    for (const Entry& entry : counts) {
        if (entry.count->count > 0) {
            std::cout << entry.key << "\t" << entry.count->count << "\t[thread=" << entry.thread_id << "]" << std::endl;
        }
        if (entry.count->time > 0) {
            std::cout << entry.key << "(t)\t" << static_cast<double>(entry.count->time) / 1000000.0
                      << "ms\t[thread=" << entry.thread_id << "]" << std::endl;
        }
        if (entry.count->cycles > 0) {
            std::cout << entry.key << "(c)\t" << entry.count->cycles << "\t[thread=" << entry.thread_id << "]"
                      << std::endl;
        }
    }
    std::cout << "print_op_counts() END" << std::endl;
}

std::map<std::string, std::size_t> GlobalOpCountContainer::get_aggregate_counts() const
{
    std::map<std::string, std::size_t> aggregate_counts;
    for (const Entry& entry : counts) {
        if (entry.count->count > 0) {
            aggregate_counts[entry.key] += entry.count->count;
        }
        if (entry.count->time > 0) {
            aggregate_counts[entry.key + "(t)"] += entry.count->time;
        }
        if (entry.count->cycles > 0) {
            aggregate_counts[entry.key + "(c)"] += entry.count->cycles;
        }
    }
    return aggregate_counts;
}

void GlobalOpCountContainer::clear()
{
    std::unique_lock<std::mutex> lock(mutex);
    for (Entry& entry : counts) {
        *entry.count = OpStats();
    }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
GlobalOpCountContainer GLOBAL_OP_COUNTS;

OpCountCycleReporter::OpCountCycleReporter(OpStats* stats)
    : stats(stats)
{
#if __clang__ && (defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86))
    // Don't support any other targets but x86 clang for now, this is a bit lazy but more than fits our needs
    cycles = __builtin_ia32_rdtsc();
#endif
}
OpCountCycleReporter::~OpCountCycleReporter()
{
#if __clang__ && (defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86))
    // Don't support any other targets but x86 clang for now, this is a bit lazy but more than fits our needs
    stats->count += 1;
    stats->cycles += __builtin_ia32_rdtsc() - cycles;
#endif
}
OpCountTimeReporter::OpCountTimeReporter(OpStats* stats)
    : stats(stats)
{
    auto now = std::chrono::high_resolution_clock::now();
    auto now_ns = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
    time = static_cast<std::size_t>(now_ns.time_since_epoch().count());
}
OpCountTimeReporter::~OpCountTimeReporter()
{
    auto now = std::chrono::high_resolution_clock::now();
    auto now_ns = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
    stats->count += 1;
    stats->time += static_cast<std::size_t>(now_ns.time_since_epoch().count()) - time;
}
} // namespace bb::detail
#endif
160
sumcheck/src/cuda/includes/barretenberg/common/op_count.hpp
Normal file
@@ -0,0 +1,160 @@

#pragma once

#include <memory>
#ifndef BB_USE_OP_COUNT
// require a semicolon to appease formatters
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK() (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK_NAME(name) (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES_NAME(name) (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME_NAME(name) (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES() (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME() (void)0
#else
/**
 * Provides an abstraction that counts operations based on function names.
 * For efficiency, we spread out counts across threads.
 */

#include "./compiler_hints.hpp"
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <map>
#include <mutex>
#include <string>
#include <vector>
namespace bb::detail {
// Compile-time string
// See e.g. https://www.reddit.com/r/cpp_questions/comments/pumi9r/does_c20_not_support_string_literals_as_template/
template <std::size_t N> struct OperationLabel {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    constexpr OperationLabel(const char (&str)[N])
    {
        for (std::size_t i = 0; i < N; ++i) {
            value[i] = str[i];
        }
    }

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    char value[N];
};

struct OpStats {
    std::size_t count = 0;
    std::size_t time = 0;
    std::size_t cycles = 0;
};

// Contains all statically known op counts
struct GlobalOpCountContainer {
  public:
    struct Entry {
        std::string key;
        std::string thread_id;
        std::shared_ptr<OpStats> count;
    };
    ~GlobalOpCountContainer();
    std::mutex mutex;
    std::vector<Entry> counts;
    void print() const;
    // NOTE: Should be called when other threads aren't active
    void clear();
    void add_entry(const char* key, const std::shared_ptr<OpStats>& count);
    std::map<std::string, std::size_t> get_aggregate_counts() const;
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
extern GlobalOpCountContainer GLOBAL_OP_COUNTS;

template <OperationLabel Op> struct GlobalOpCount {
  public:
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
    static thread_local std::shared_ptr<OpStats> stats;

    static OpStats* ensure_stats()
    {
        if (BB_UNLIKELY(stats == nullptr)) {
            stats = std::make_shared<OpStats>();
            GLOBAL_OP_COUNTS.add_entry(Op.value, stats);
        }
        return stats.get();
    }
    static constexpr void increment_op_count()
    {
#ifndef BB_USE_OP_COUNT_TIME_ONLY
        if (std::is_constant_evaluated()) {
            // We do nothing if the compiler tries to run this
            return;
        }
        ensure_stats();
        stats->count++;
#endif
    }
    static constexpr void add_cycle_time(std::size_t cycles)
    {
#ifndef BB_USE_OP_COUNT_TRACK_ONLY
        if (std::is_constant_evaluated()) {
            // We do nothing if the compiler tries to run this
            return;
        }
        ensure_stats();
        stats->cycles += cycles;
#else
        static_cast<void>(cycles);
#endif
    }
    static constexpr void add_clock_time(std::size_t time)
    {
#ifndef BB_USE_OP_COUNT_TRACK_ONLY
        if (std::is_constant_evaluated()) {
            // We do nothing if the compiler tries to run this
            return;
        }
        ensure_stats();
        stats->time += time;
#else
        static_cast<void>(time);
#endif
    }
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
template <OperationLabel Op> thread_local std::shared_ptr<OpStats> GlobalOpCount<Op>::stats;

// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
struct OpCountCycleReporter {
    OpStats* stats;
    std::size_t cycles;
    OpCountCycleReporter(OpStats* stats);
    ~OpCountCycleReporter();
};
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
struct OpCountTimeReporter {
    OpStats* stats;
    std::size_t time;
    OpCountTimeReporter(OpStats* stats);
    ~OpCountTimeReporter();
};
} // namespace bb::detail

// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK_NAME(name) bb::detail::GlobalOpCount<name>::increment_op_count()
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK() BB_OP_COUNT_TRACK_NAME(__func__)
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES_NAME(name)                                                                                  \
    bb::detail::OpCountCycleReporter __bb_op_count_cycles(bb::detail::GlobalOpCount<name>::ensure_stats())
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES() BB_OP_COUNT_CYCLES_NAME(__func__)
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME_NAME(name)                                                                                    \
    bb::detail::OpCountTimeReporter __bb_op_count_time(bb::detail::GlobalOpCount<name>::ensure_stats())
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME() BB_OP_COUNT_TIME_NAME(__func__)
#endif
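Typical instrumentation, as a sketch (the function name is hypothetical, not part of the diff):

void multiply_polynomials()
{
    // Times the whole function body under the key __func__; compiles to (void)0
    // unless BB_USE_OP_COUNT is defined.
    BB_OP_COUNT_TIME();
    // Counts invocations under an explicit label instead of __func__.
    BB_OP_COUNT_TRACK_NAME("poly_mul");
    // ... work ...
}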
@@ -0,0 +1,243 @@
#include "slab_allocator.hpp"
#include <barretenberg/common/assert.hpp>
#include <barretenberg/common/log.hpp>
#include <barretenberg/common/mem.hpp>
#include <cstddef>
#include <numeric>
#include <unordered_map>

#define LOGGING 0

/**
 * If we could guarantee that all slabs will be released before the allocator is destroyed, we wouldn't need this.
 * However, there are (and may again be) cases where a global is holding onto a slab. In such a case you will have
 * issues if the runtime frees the allocator before the slab is released. The effect is subtle, so it's worth
 * protecting against rather than just saying "don't do globals". But you know, don't do globals...
 * (Irony of global slab allocator noted).
 */
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool allocator_destroyed = false;

// Slabs that are being manually managed by the user.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
std::unordered_map<void*, std::shared_ptr<void>> manual_slabs;
#ifndef NO_MULTITHREADING
// The manual slabs unordered map is not thread-safe, so we need to manage access to it when multithreaded.
std::mutex manual_slabs_mutex;
#endif
template <typename... Args> inline void dbg_info(Args... args)
{
#if LOGGING == 1
    info(args...);
#else
    // Suppress warning.
    (void)(sizeof...(args));
#endif
}

/**
 * Allows preallocating memory slabs, exploiting the fact that slab requests follow certain sizing
 * patterns and counts based on the prover system type and circuit size. Without the slab allocator, memory
 * fragmentation prevents proof construction when approaching memory space limits (4GB in WASM).
 *
 * If no circuit_size_hint is given to the constructor, it behaves as a standard memory allocator.
 */
class SlabAllocator {
  private:
    size_t circuit_size_hint_;
    std::map<size_t, std::list<void*>> memory_store;
#ifndef NO_MULTITHREADING
    std::mutex memory_store_mutex;
#endif

  public:
    ~SlabAllocator();
    SlabAllocator() = default;
    SlabAllocator(const SlabAllocator& other) = delete;
    SlabAllocator(SlabAllocator&& other) = delete;
    SlabAllocator& operator=(const SlabAllocator& other) = delete;
    SlabAllocator& operator=(SlabAllocator&& other) = delete;

    void init(size_t circuit_size_hint);

    std::shared_ptr<void> get(size_t size);

    size_t get_total_size();

  private:
    void release(void* ptr, size_t size);
};

SlabAllocator::~SlabAllocator()
{
    allocator_destroyed = true;
    for (auto& e : memory_store) {
        for (auto& p : e.second) {
            aligned_free(p);
        }
    }
}

void SlabAllocator::init(size_t circuit_size_hint)
{
    if (circuit_size_hint <= circuit_size_hint_) {
        return;
    }

    circuit_size_hint_ = circuit_size_hint;

    // Free any existing slabs.
    for (auto& e : memory_store) {
        for (auto& p : e.second) {
            aligned_free(p);
        }
    }
    memory_store.clear();

    dbg_info("slab allocator initing for size: ", circuit_size_hint);

    if (circuit_size_hint == 0ULL) {
        return;
    }

    // Over-allocate because we know there are requests for circuit_size + n. (somewhat arbitrary n = 512)
    size_t overalloc = 512;
    size_t base_size = circuit_size_hint + overalloc;

    std::map<size_t, size_t> prealloc_num;

    // Size comments below assume a base (circuit) size of 2^19, 524288 bytes.

    // /* 0.5 MiB */ prealloc_num[base_size * 1] = 2; // Batch invert skipped temporary.
    // /* 2 MiB */ prealloc_num[base_size * 4] = 4 + // Composer base wire vectors.
    //                                           1;  // Miscellaneous.
    // /* 6 MiB */ prealloc_num[base_size * 12] = 2 + // next_var_index, prev_var_index
    //                                            2;  // real_variable_index, real_variable_tags
    /* 16 MiB */ prealloc_num[base_size * 32] = 11;    // Composer base selector vectors.
    /* 32 MiB */ prealloc_num[base_size * 32 * 2] = 1; // Miscellaneous.
    /* 50 MiB */ prealloc_num[base_size * 32 * 3] = 1; // Variables.
    /* 64 MiB */ prealloc_num[base_size * 32 * 4] = 1 +  // SRS monomial points.
                                                    4 +  // Coset-fft wires.
                                                    15 + // Coset-fft constraint selectors.
                                                    8 +  // Coset-fft perm selectors.
                                                    1 +  // Coset-fft sorted poly.
                                                    1 +  // Pippenger point_schedule.
                                                    4;   // Miscellaneous.
    /* 128 MiB */ prealloc_num[base_size * 32 * 8] = 1 + // Proving key evaluation domain roots.
                                                     2;  // Pippenger point_pairs.

    for (auto& e : prealloc_num) {
        for (size_t i = 0; i < e.second; ++i) {
            auto size = e.first;
            memory_store[size].push_back(aligned_alloc(32, size));
            dbg_info("Allocated memory slab of size: ", size, " total: ", get_total_size());
        }
    }
}

std::shared_ptr<void> SlabAllocator::get(size_t req_size)
{
#ifndef NO_MULTITHREADING
    std::unique_lock<std::mutex> lock(memory_store_mutex);
#endif

    auto it = memory_store.lower_bound(req_size);

    // Can use a preallocated slab that is less than 2 times the requested size.
    if (it != memory_store.end() && it->first < req_size * 2) {
        size_t size = it->first;
        auto* ptr = it->second.back();
        it->second.pop_back();

        if (it->second.empty()) {
            memory_store.erase(it);
        }

        if (req_size >= circuit_size_hint_ && size > req_size + req_size / 10) {
            dbg_info("WARNING: Using memory slab of size: ",
                     size,
                     " for requested ",
                     req_size,
                     " total: ",
                     get_total_size());
        } else {
            dbg_info("Reusing memory slab of size: ", size, " for requested ", req_size, " total: ", get_total_size());
        }

        return { ptr, [this, size](void* p) {
                    if (allocator_destroyed) {
                        aligned_free(p);
                        return;
                    }
                    this->release(p, size);
                } };
    }

    if (req_size > static_cast<size_t>(1024 * 1024)) {
        dbg_info("WARNING: Allocating unmanaged memory slab of size: ", req_size);
    }
    if (req_size % 32 == 0) {
        return { aligned_alloc(32, req_size), aligned_free };
    }
    // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
    return { malloc(req_size), free };
}

size_t SlabAllocator::get_total_size()
{
    return std::accumulate(memory_store.begin(), memory_store.end(), size_t{ 0 }, [](size_t acc, const auto& kv) {
        return acc + kv.first * kv.second.size();
    });
}

void SlabAllocator::release(void* ptr, size_t size)
{
#ifndef NO_MULTITHREADING
    std::unique_lock<std::mutex> lock(memory_store_mutex);
#endif
    memory_store[size].push_back(ptr);
    // dbg_info("Pooled poly memory of size: ", size, " total: ", get_total_size());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SlabAllocator allocator;
} // namespace

namespace bb {
void init_slab_allocator(size_t circuit_subgroup_size)
{
    allocator.init(circuit_subgroup_size);
}

// auto init = ([]() {
//     init_slab_allocator(524288);
//     return 0;
// })();

std::shared_ptr<void> get_mem_slab(size_t size)
{
    return allocator.get(size);
}

void* get_mem_slab_raw(size_t size)
{
    auto slab = get_mem_slab(size);
#ifndef NO_MULTITHREADING
    std::unique_lock<std::mutex> lock(manual_slabs_mutex);
#endif
    manual_slabs[slab.get()] = slab;
    return slab.get();
}

void free_mem_slab_raw(void* p)
{
    if (allocator_destroyed) {
        aligned_free(p);
        return;
    }
#ifndef NO_MULTITHREADING
    std::unique_lock<std::mutex> lock(manual_slabs_mutex);
#endif
    manual_slabs.erase(p);
}
} // namespace bb
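A worked example of the reuse rule in get() (derived from the code above, not part of the diff):

// Suppose the pool holds slabs of sizes { 1 MiB, 4 MiB } after init().
//   get(3 MiB):   lower_bound -> the 4 MiB slab; 4 MiB < 2 * 3 MiB, so it is handed out.
//   get(1.5 MiB): lower_bound -> the 4 MiB slab; 4 MiB >= 2 * 1.5 MiB, so the pool is
//                 bypassed and a fresh 32-byte-aligned allocation is returned instead.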
@@ -0,0 +1,78 @@
#pragma once
#include "./assert.hpp"
#include "./log.hpp"
#include <list>
#include <map>
#include <memory>
#include <unordered_map>
#ifndef NO_MULTITHREADING
#include <mutex>
#endif

namespace bb {

/**
 * Allocates a bunch of memory slabs sized to serve an UltraPLONK proof construction.
 * If you want normal memory allocator behavior, just don't call this init function.
 *
 * WARNING: If client code is still holding onto slabs from previous use, when those slabs
 * are released they'll end up back in the allocator. That's probably not desired as presumably
 * those slabs are now too small, so they're effectively leaked. But good client code should be releasing
 * its resources promptly anyway. It's not considered "proper use" to call init, take a slab, and call init
 * again before releasing the slab.
 *
 * TODO: Take a composer type and allocate slabs according to those requirements?
 * TODO: De-globalise. Init the allocator and pass around. Use a PolynomialFactory (PolynomialStore?).
 * TODO: Consider removing, but only once due diligence has been done that we no longer have memory limitations.
 */
void init_slab_allocator(size_t circuit_subgroup_size);

/**
 * Returns a slab from the preallocated pool of slabs, or falls back to a new heap allocation (32 byte aligned).
 * Ref counted result so no need to manually free.
 */
std::shared_ptr<void> get_mem_slab(size_t size);

/**
 * Sometimes you want a raw pointer to a slab so you can manage when it's released manually (e.g. c_binds, containers).
 * This still gets a slab with a shared_ptr, but holds the shared_ptr internally until free_mem_slab_raw is called.
 */
void* get_mem_slab_raw(size_t size);

void free_mem_slab_raw(void*);

/**
 * Allocator for containers such as std::vector. Makes them leverage the underlying slab allocator where possible.
 */
template <typename T> class ContainerSlabAllocator {
  public:
    using value_type = T;
    using pointer = T*;
    using const_pointer = const T*;
    using size_type = std::size_t;

    template <typename U> struct rebind {
        using other = ContainerSlabAllocator<U>;
    };

    pointer allocate(size_type n)
    {
        // info("ContainerSlabAllocator allocating: ", n * sizeof(T));
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
        return reinterpret_cast<pointer>(get_mem_slab_raw(n * sizeof(T)));
    }

    void deallocate(pointer p, size_type /*unused*/) { free_mem_slab_raw(p); }

    friend bool operator==(const ContainerSlabAllocator<T>& /*unused*/, const ContainerSlabAllocator<T>& /*unused*/)
    {
        return true;
    }

    friend bool operator!=(const ContainerSlabAllocator<T>& /*unused*/, const ContainerSlabAllocator<T>& /*unused*/)
    {
        return false;
    }
};

} // namespace bb
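A usage sketch (not part of the diff; `scratch` and `example` are hypothetical) of a vector backed by the slab pool:

#include <cstdint>
#include <vector>

using slab_vector = std::vector<uint64_t, bb::ContainerSlabAllocator<uint64_t>>;

void example()
{
    bb::init_slab_allocator(1 << 19); // preallocate for a 2^19 circuit
    slab_vector scratch(1024);        // allocate() -> get_mem_slab_raw(1024 * sizeof(uint64_t))
}                                     // deallocate() -> free_mem_slab_raw on destruction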
@@ -0,0 +1,13 @@
#pragma once
#include "log.hpp"
#include <cstdlib>
#include <stdexcept>
#include <string>

inline void throw_or_abort [[noreturn]] (std::string const& err)
{
#ifndef __wasm__
    throw std::runtime_error(err);
#else
    info("abort: ", err);
    std::abort();
#endif
}
@@ -0,0 +1,16 @@
#pragma once
#ifdef __clang__
#define WASM_EXPORT extern "C" __attribute__((visibility("default"))) __attribute__((annotate("wasm_export")))
#define ASYNC_WASM_EXPORT                                                                                              \
    extern "C" __attribute__((visibility("default"))) __attribute__((annotate("async_wasm_export")))
#else
#define WASM_EXPORT extern "C" __attribute__((visibility("default")))
#define ASYNC_WASM_EXPORT extern "C" __attribute__((visibility("default")))
#endif

#ifdef __wasm__
// Allow linker to not link this
#define WASM_IMPORT(name) extern "C" __attribute__((import_module("env"), import_name(name)))
#else
#define WASM_IMPORT(name) extern "C"
#endif
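A sketch (hypothetical signatures, not part of the diff) of declaring an export and an import with these macros:

#include <cstdint>

// Exported with default visibility (and annotated for the wasm toolchain under clang).
WASM_EXPORT void my_entry_point(uint8_t const* input, uint8_t* output);

// Resolved against the host's "env" module when targeting wasm; plain extern "C" otherwise.
WASM_IMPORT("logstr") void logstr(char const* msg);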
@@ -0,0 +1 @@
barretenberg_module(crypto_blake3s)
@@ -0,0 +1,80 @@
#pragma once
/*
   BLAKE3 reference source code package - C implementations

   Intellectual property:

   The Rust code is copyright Jack O'Connor, 2019-2020.
   The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020.
   The assembly code is copyright Samuel Neves, 2019-2020.

   This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache
   License 2.0.

   - CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0
   - Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0

   More information about the BLAKE3 hash function can be found at
   https://github.com/BLAKE3-team/BLAKE3.
*/

#ifndef BLAKE3_IMPL_H
#define BLAKE3_IMPL_H

#include <cstddef>
#include <cstdint>
#include <cstring>

#include "blake3s.hpp"

namespace blake3 {

// Right rotates 32 bit inputs
constexpr uint32_t rotr32(uint32_t w, uint32_t c)
{
    return (w >> c) | (w << (32 - c));
}

constexpr uint32_t load32(const uint8_t* src)
{
    return (static_cast<uint32_t>(src[0]) << 0) | (static_cast<uint32_t>(src[1]) << 8) |
           (static_cast<uint32_t>(src[2]) << 16) | (static_cast<uint32_t>(src[3]) << 24);
}

constexpr void load_key_words(const std::array<uint8_t, BLAKE3_KEY_LEN>& key, key_array& key_words)
{
    key_words[0] = load32(&key[0]);
    key_words[1] = load32(&key[4]);
    key_words[2] = load32(&key[8]);
    key_words[3] = load32(&key[12]);
    key_words[4] = load32(&key[16]);
    key_words[5] = load32(&key[20]);
    key_words[6] = load32(&key[24]);
    key_words[7] = load32(&key[28]);
}

constexpr void store32(uint8_t* dst, uint32_t w)
{
    dst[0] = static_cast<uint8_t>(w >> 0);
    dst[1] = static_cast<uint8_t>(w >> 8);
    dst[2] = static_cast<uint8_t>(w >> 16);
    dst[3] = static_cast<uint8_t>(w >> 24);
}

constexpr void store_cv_words(out_array& bytes_out, key_array& cv_words)
{
    store32(&bytes_out[0], cv_words[0]);
    store32(&bytes_out[4], cv_words[1]);
    store32(&bytes_out[8], cv_words[2]);
    store32(&bytes_out[12], cv_words[3]);
    store32(&bytes_out[16], cv_words[4]);
    store32(&bytes_out[20], cv_words[5]);
    store32(&bytes_out[24], cv_words[6]);
    store32(&bytes_out[28], cv_words[7]);
}

} // namespace blake3

#include "blake3s.tcc"

#endif
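A round-trip sketch (not part of the diff) confirming that load32/store32 fix the byte order to little-endian regardless of host endianness:

constexpr bool load_store_roundtrip()
{
    uint8_t bytes[4] = { 0x78, 0x56, 0x34, 0x12 };
    if (blake3::load32(&bytes[0]) != 0x12345678UL) {
        return false;
    }
    uint8_t out[4] = { 0, 0, 0, 0 };
    blake3::store32(&out[0], 0x12345678UL);
    return out[0] == 0x78 && out[1] == 0x56 && out[2] == 0x34 && out[3] == 0x12;
}
static_assert(load_store_roundtrip());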
@@ -0,0 +1,113 @@
/*
   BLAKE3 reference source code package - C implementations

   Intellectual property:

   The Rust code is copyright Jack O'Connor, 2019-2020.
   The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020.
   The assembly code is copyright Samuel Neves, 2019-2020.

   This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache
   License 2.0.

   - CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0
   - Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0

   More information about the BLAKE3 hash function can be found at
   https://github.com/BLAKE3-team/BLAKE3.


   NOTE: We have modified the original code from the BLAKE3 reference C implementation.
   The following code works ONLY for inputs of size less than 1024 bytes. This constraint
   on the input size greatly simplifies the code and lets us drop the recursive Merkle-tree-like
   operations on chunks (data of size 1024 bytes). This is because we would always be using BLAKE3
   hashing for inputs of size 32 bytes (or less) in barretenberg. The full C++ version of BLAKE3
   from the original authors is in the module `../crypto/blake3s_full`.

   Also, the length of the output in this specific implementation is fixed at 32 bytes, which is the only
   version relevant to Barretenberg.
*/
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

namespace blake3 {

// internal flags
enum blake3_flags {
    CHUNK_START = 1 << 0,
    CHUNK_END = 1 << 1,
    PARENT = 1 << 2,
    ROOT = 1 << 3,
    KEYED_HASH = 1 << 4,
    DERIVE_KEY_CONTEXT = 1 << 5,
    DERIVE_KEY_MATERIAL = 1 << 6,
};

// constants
enum blake3s_constant {
    BLAKE3_KEY_LEN = 32,
    BLAKE3_OUT_LEN = 32,
    BLAKE3_BLOCK_LEN = 64,
    BLAKE3_CHUNK_LEN = 1024,
    BLAKE3_MAX_DEPTH = 54
};

using key_array = std::array<uint32_t, BLAKE3_KEY_LEN>;
using block_array = std::array<uint8_t, BLAKE3_BLOCK_LEN>;
using state_array = std::array<uint32_t, 16>;
using out_array = std::array<uint8_t, BLAKE3_OUT_LEN>;

static constexpr key_array IV = { 0x6A09E667UL, 0xBB67AE85UL, 0x3C6EF372UL, 0xA54FF53AUL,
                                  0x510E527FUL, 0x9B05688CUL, 0x1F83D9ABUL, 0x5BE0CD19UL };

static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_0 = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_1 = { 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_2 = { 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_3 = { 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_4 = { 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_5 = { 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_6 = { 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 };
static constexpr std::array<std::array<uint8_t, 16>, 7> MSG_SCHEDULE = {
    MSG_SCHEDULE_0, MSG_SCHEDULE_1, MSG_SCHEDULE_2, MSG_SCHEDULE_3, MSG_SCHEDULE_4, MSG_SCHEDULE_5, MSG_SCHEDULE_6,
};

struct blake3_hasher {
    key_array key;
    key_array cv;
    block_array buf;
    uint8_t buf_len = 0;
    uint8_t blocks_compressed = 0;
    uint8_t flags = 0;
};

inline const char* blake3_version()
{
    static const std::string version = "0.3.7";
    return version.c_str();
}

constexpr void blake3_hasher_init(blake3_hasher* self);
constexpr void blake3_hasher_update(blake3_hasher* self, const uint8_t* input, size_t input_len);
constexpr void blake3_hasher_finalize(const blake3_hasher* self, uint8_t* out);

constexpr void g(state_array& state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y);
constexpr void round_fn(state_array& state, const uint32_t* msg, size_t round);

constexpr void compress_pre(
    state_array& state, const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags);

constexpr void blake3_compress_in_place(key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags);

constexpr void blake3_compress_xof(
    const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags, uint8_t* out);

constexpr std::array<uint8_t, BLAKE3_OUT_LEN> blake3s_constexpr(const uint8_t* input, size_t input_size);
inline std::vector<uint8_t> blake3s(std::vector<uint8_t> const& input);

} // namespace blake3

#include "blake3-impl.hpp"
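A usage sketch (not part of the diff); remember this trimmed port only accepts inputs shorter than 1024 bytes:

#include <vector>

std::vector<uint8_t> hash_abc()
{
    std::vector<uint8_t> message = { 'a', 'b', 'c' };
    return blake3::blake3s(message); // 32-byte digest
}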
@@ -0,0 +1,263 @@
|
||||
#pragma once
|
||||
/*
|
||||
BLAKE3 reference source code package - C implementations
|
||||
|
||||
Intellectual property:
|
||||
|
||||
The Rust code is copyright Jack O'Connor, 2019-2020.
|
||||
The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020.
|
||||
The assembly code is copyright Samuel Neves, 2019-2020.
|
||||
|
||||
This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache
|
||||
License 2.0.
|
||||
|
||||
- CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0
|
||||
- Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
More information about the BLAKE3 hash function can be found at
|
||||
https://github.com/BLAKE3-team/BLAKE3.
|
||||
|
||||
|
||||
NOTE: We have modified the original code from the BLAKE3 reference C implementation.
|
||||
The following code works ONLY for inputs of size less than 1024 bytes. This kind of constraint
|
||||
on the input size greatly simplifies the code and helps us get rid of the recursive merkle-tree
|
||||
like operations on chunks (data of size 1024 bytes). This is because we would always be using BLAKE3
|
||||
hashing for inputs of size 32 bytes (or lesser) in barretenberg. The full C++ version of BLAKE3
|
||||
from the original authors is in the module `../crypto/blake3s_full`.
|
||||
|
||||
Also, the length of the output in this specific implementation is fixed at 32 bytes which is the only
|
||||
version relevant to Barretenberg.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "blake3s.hpp"
|
||||
|
||||
namespace blake3 {
|
||||
|
||||
/*
|
||||
* Core Blake3s functions. These are similar to that of Blake2s except for a few
|
||||
* constant parameters and fewer rounds.
|
||||
*
|
||||
*/
|
||||
constexpr void g(state_array& state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y)
|
||||
{
|
||||
state[a] = state[a] + state[b] + x;
|
||||
state[d] = rotr32(state[d] ^ state[a], 16);
|
||||
state[c] = state[c] + state[d];
|
||||
state[b] = rotr32(state[b] ^ state[c], 12);
|
||||
state[a] = state[a] + state[b] + y;
|
||||
state[d] = rotr32(state[d] ^ state[a], 8);
|
||||
state[c] = state[c] + state[d];
|
||||
state[b] = rotr32(state[b] ^ state[c], 7);
|
||||
}
|
||||
|
||||
constexpr void round_fn(state_array& state, const uint32_t* msg, size_t round)
|
||||
{
|
||||
// Select the message schedule based on the round.
|
||||
const auto schedule = MSG_SCHEDULE[round];
|
||||
|
||||
// Mix the columns.
|
||||
g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
|
||||
g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
|
||||
g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
|
||||
g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
|
||||
|
||||
// Mix the rows.
|
||||
g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
|
||||
g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
|
||||
g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
|
||||
g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
|
||||
}
|
||||
|
||||
constexpr void compress_pre(
|
||||
state_array& state, const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags)
|
||||
{
|
||||
std::array<uint32_t, 16> block_words;
|
||||
block_words[0] = load32(&block[0]);
|
||||
block_words[1] = load32(&block[4]);
|
||||
block_words[2] = load32(&block[8]);
|
||||
block_words[3] = load32(&block[12]);
|
||||
block_words[4] = load32(&block[16]);
|
||||
block_words[5] = load32(&block[20]);
|
||||
block_words[6] = load32(&block[24]);
|
||||
block_words[7] = load32(&block[28]);
|
||||
block_words[8] = load32(&block[32]);
|
||||
block_words[9] = load32(&block[36]);
|
||||
block_words[10] = load32(&block[40]);
|
||||
block_words[11] = load32(&block[44]);
|
||||
block_words[12] = load32(&block[48]);
|
||||
block_words[13] = load32(&block[52]);
|
||||
block_words[14] = load32(&block[56]);
|
||||
block_words[15] = load32(&block[60]);
|
||||
|
||||
state[0] = cv[0];
|
||||
state[1] = cv[1];
|
||||
state[2] = cv[2];
|
||||
state[3] = cv[3];
|
||||
state[4] = cv[4];
|
||||
state[5] = cv[5];
|
||||
state[6] = cv[6];
|
||||
state[7] = cv[7];
|
||||
state[8] = IV[0];
|
||||
state[9] = IV[1];
|
||||
state[10] = IV[2];
|
||||
state[11] = IV[3];
|
||||
state[12] = 0;
|
||||
state[13] = 0;
|
||||
state[14] = static_cast<uint32_t>(block_len);
|
||||
state[15] = static_cast<uint32_t>(flags);
|
||||
|
||||
round_fn(state, &block_words[0], 0);
|
||||
round_fn(state, &block_words[0], 1);
|
||||
round_fn(state, &block_words[0], 2);
|
||||
round_fn(state, &block_words[0], 3);
|
||||
round_fn(state, &block_words[0], 4);
|
||||
round_fn(state, &block_words[0], 5);
|
||||
round_fn(state, &block_words[0], 6);
|
||||
}
|
||||
|
||||
constexpr void blake3_compress_in_place(key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags)
|
||||
{
|
||||
state_array state;
|
||||
compress_pre(state, cv, block, block_len, flags);
|
||||
cv[0] = state[0] ^ state[8];
|
||||
    cv[1] = state[1] ^ state[9];
    cv[2] = state[2] ^ state[10];
    cv[3] = state[3] ^ state[11];
    cv[4] = state[4] ^ state[12];
    cv[5] = state[5] ^ state[13];
    cv[6] = state[6] ^ state[14];
    cv[7] = state[7] ^ state[15];
}

constexpr void blake3_compress_xof(
    const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags, uint8_t* out)
{
    state_array state;
    compress_pre(state, cv, block, block_len, flags);

    store32(&out[0], state[0] ^ state[8]);
    store32(&out[4], state[1] ^ state[9]);
    store32(&out[8], state[2] ^ state[10]);
    store32(&out[12], state[3] ^ state[11]);
    store32(&out[16], state[4] ^ state[12]);
    store32(&out[20], state[5] ^ state[13]);
    store32(&out[24], state[6] ^ state[14]);
    store32(&out[28], state[7] ^ state[15]);
    store32(&out[32], state[8] ^ cv[0]);
    store32(&out[36], state[9] ^ cv[1]);
    store32(&out[40], state[10] ^ cv[2]);
    store32(&out[44], state[11] ^ cv[3]);
    store32(&out[48], state[12] ^ cv[4]);
    store32(&out[52], state[13] ^ cv[5]);
    store32(&out[56], state[14] ^ cv[6]);
    store32(&out[60], state[15] ^ cv[7]);
}

constexpr uint8_t maybe_start_flag(const blake3_hasher* self)
{
    if (self->blocks_compressed == 0) {
        return CHUNK_START;
    }
    return 0;
}

struct output_t {
    key_array input_cv = {};
    block_array block = {};
    uint8_t block_len = 0;
    uint8_t flags = 0;
};

constexpr output_t make_output(const key_array& input_cv, const uint8_t* block, uint8_t block_len, uint8_t flags)
{
    output_t ret;
    for (size_t i = 0; i < (BLAKE3_OUT_LEN >> 2); ++i) {
        ret.input_cv[i] = input_cv[i];
    }
    for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) {
        ret.block[i] = block[i];
    }
    ret.block_len = block_len;
    ret.flags = flags;
    return ret;
}

constexpr void blake3_hasher_init(blake3_hasher* self)
{
    for (size_t i = 0; i < (BLAKE3_KEY_LEN >> 2); ++i) {
        self->key[i] = IV[i];
        self->cv[i] = IV[i];
    }
    for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) {
        self->buf[i] = 0;
    }
    self->buf_len = 0;
    self->blocks_compressed = 0;
    self->flags = 0;
}

constexpr void blake3_hasher_update(blake3_hasher* self, const uint8_t* input, size_t input_len)
{
    if (input_len == 0) {
        return;
    }

    while (input_len > BLAKE3_BLOCK_LEN) {
        blake3_compress_in_place(self->cv, input, BLAKE3_BLOCK_LEN, self->flags | maybe_start_flag(self));

        self->blocks_compressed = static_cast<uint8_t>(self->blocks_compressed + 1U);
        input += BLAKE3_BLOCK_LEN;
        input_len -= BLAKE3_BLOCK_LEN;
    }

    size_t take = BLAKE3_BLOCK_LEN - (static_cast<size_t>(self->buf_len));
    if (take > input_len) {
        take = input_len;
    }
    uint8_t* dest = &self->buf[0] + (static_cast<size_t>(self->buf_len));
    for (size_t i = 0; i < take; i++) {
        dest[i] = input[i];
    }

    self->buf_len = static_cast<uint8_t>(self->buf_len + static_cast<uint8_t>(take));
    input_len -= take;
}

constexpr void blake3_hasher_finalize(const blake3_hasher* self, uint8_t* out)
{
    uint8_t block_flags = self->flags | maybe_start_flag(self) | CHUNK_END;
    output_t output = make_output(self->cv, &self->buf[0], self->buf_len, block_flags);

    block_array wide_buf;
    blake3_compress_xof(output.input_cv, &output.block[0], output.block_len, output.flags | ROOT, &wide_buf[0]);
    for (size_t i = 0; i < BLAKE3_OUT_LEN; i++) {
        out[i] = wide_buf[i];
    }
}

std::vector<uint8_t> blake3s(std::vector<uint8_t> const& input)
{
    blake3_hasher hasher;
    blake3_hasher_init(&hasher);
    blake3_hasher_update(&hasher, static_cast<const uint8_t*>(input.data()), input.size());

    std::vector<uint8_t> output(BLAKE3_OUT_LEN);
    blake3_hasher_finalize(&hasher, &output[0]);
    return output;
}

constexpr std::array<uint8_t, BLAKE3_OUT_LEN> blake3s_constexpr(const uint8_t* input, const size_t input_size)
{
    blake3_hasher hasher;
    blake3_hasher_init(&hasher);
    blake3_hasher_update(&hasher, input, input_size);

    std::array<uint8_t, BLAKE3_OUT_LEN> output;
    blake3_hasher_finalize(&hasher, &output[0]);
    return output;
}

} // namespace blake3
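Editorial note (not part of the mirrored source): a minimal usage sketch of the two entry points above. It assumes blake3s.hpp and its dependencies are on the include path, and that the input stays within the single chunk this reduced implementation handles.

#include "blake3s.hpp"
#include <cstdio>
#include <vector>

int main()
{
    // Runtime path: hash a short message and print the 32-byte digest.
    std::vector<uint8_t> msg = { 'a', 'b', 'c' };
    std::vector<uint8_t> digest = blake3::blake3s(msg);
    for (uint8_t byte : digest) {
        printf("%02x", byte);
    }
    printf("\n");

    // Compile-time path: blake3s_constexpr can run during constant evaluation (C++20).
    constexpr uint8_t data[3] = { 'a', 'b', 'c' };
    constexpr auto ct = blake3::blake3s_constexpr(data, 3);
    static_assert(ct.size() == 32, "digest is BLAKE3_OUT_LEN (32) bytes");
    return 0;
}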
@@ -0,0 +1,11 @@
#include "../../common/wasm_export.hpp"
#include "../../ecc/curves/bn254/fr.hpp"
#include "blake3s.hpp"

WASM_EXPORT void blake3s_to_field(uint8_t const* data, size_t length, uint8_t* r)
{
    std::vector<uint8_t> inputv(data, data + length);
    std::vector<uint8_t> output = blake3::blake3s(inputv);
    auto result = bb::fr::serialize_from_buffer(output.data());
    bb::fr::serialize_to_buffer(result, r);
}
@@ -0,0 +1 @@
barretenberg_module(crypto_keccak)
@@ -0,0 +1,20 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
 * Copyright 2018-2019 Pawel Bylica.
 * Licensed under the Apache License, Version 2.0.
 */

#pragma once

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

struct keccak256 {
    uint64_t word64s[4];
};

#ifdef __cplusplus
}
#endif
133
sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.cpp
Normal file
@@ -0,0 +1,133 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
 * Copyright 2018-2019 Pawel Bylica.
 * Licensed under the Apache License, Version 2.0.
 */

#include "keccak.hpp"

#include "./hash_types.hpp"

#if _MSC_VER
#include <string.h>
#define __builtin_memcpy memcpy
#endif

#if _WIN32
/* On Windows assume little endian. */
#define __LITTLE_ENDIAN 1234
#define __BIG_ENDIAN 4321
#define __BYTE_ORDER __LITTLE_ENDIAN
#elif __APPLE__
#include <machine/endian.h>
#else
#include <endian.h>
#endif

#if __BYTE_ORDER == __LITTLE_ENDIAN
#define to_le64(X) X
#else
#define to_le64(X) __builtin_bswap64(X)
#endif

#if __BYTE_ORDER == __LITTLE_ENDIAN
#define to_be64(X) __builtin_bswap64(X)
#else
#define to_be64(X) X
#endif

/** Loads 64-bit integer from given memory location as little-endian number. */
static inline uint64_t load_le(const uint8_t* data)
{
    /* memcpy is the best way of expressing the intention. Every compiler will
       optimize it to a single load instruction if the target architecture
       supports unaligned memory access (GCC and clang even at -O0).
       This is a great trick because we are violating C/C++ memory alignment
       restrictions with no performance penalty. */
    uint64_t word;
    __builtin_memcpy(&word, data, sizeof(word));
    return to_le64(word);
}

static inline void keccak(uint64_t* out, size_t bits, const uint8_t* data, size_t size)
{
    static const size_t word_size = sizeof(uint64_t);
    const size_t hash_size = bits / 8;
    const size_t block_size = (1600 - bits * 2) / 8;

    size_t i;
    uint64_t* state_iter;
    uint64_t last_word = 0;
    uint8_t* last_word_iter = (uint8_t*)&last_word;

    uint64_t state[25] = { 0 };

    while (size >= block_size) {
        for (i = 0; i < (block_size / word_size); ++i) {
            state[i] ^= load_le(data);
            data += word_size;
        }

        ethash_keccakf1600(state);

        size -= block_size;
    }

    state_iter = state;

    while (size >= word_size) {
        *state_iter ^= load_le(data);
        ++state_iter;
        data += word_size;
        size -= word_size;
    }

    while (size > 0) {
        *last_word_iter = *data;
        ++last_word_iter;
        ++data;
        --size;
    }
    *last_word_iter = 0x01;
    *state_iter ^= to_le64(last_word);

    state[(block_size / word_size) - 1] ^= 0x8000000000000000;

    ethash_keccakf1600(state);

    for (i = 0; i < (hash_size / word_size); ++i)
        out[i] = to_le64(state[i]);
}

struct keccak256 ethash_keccak256(const uint8_t* data, size_t size) NOEXCEPT
{
    struct keccak256 hash;
    keccak(hash.word64s, 256, data, size);
    return hash;
}

struct keccak256 hash_field_elements(const uint64_t* limbs, size_t num_elements)
{
    uint8_t input_buffer[num_elements * 32];

    for (size_t i = 0; i < num_elements; ++i) {
        for (size_t j = 0; j < 4; ++j) {
            uint64_t word = (limbs[i * 4 + j]);
            size_t idx = i * 32 + j * 8;
            input_buffer[idx] = (uint8_t)((word >> 56) & 0xff);
            input_buffer[idx + 1] = (uint8_t)((word >> 48) & 0xff);
            input_buffer[idx + 2] = (uint8_t)((word >> 40) & 0xff);
            input_buffer[idx + 3] = (uint8_t)((word >> 32) & 0xff);
            input_buffer[idx + 4] = (uint8_t)((word >> 24) & 0xff);
            input_buffer[idx + 5] = (uint8_t)((word >> 16) & 0xff);
            input_buffer[idx + 6] = (uint8_t)((word >> 8) & 0xff);
            input_buffer[idx + 7] = (uint8_t)(word & 0xff);
        }
    }

    return ethash_keccak256(input_buffer, num_elements * 32);
}

struct keccak256 hash_field_element(const uint64_t* limb)
{
    return hash_field_elements(limb, 1);
}
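Editorial sanity-check sketch (not part of the mirror): Keccak-256 of the empty input is the well-known digest c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470, and because of the to_le64 store above, the in-memory bytes of word64s are the digest bytes regardless of host endianness.

#include "keccak.hpp"
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    struct keccak256 h = ethash_keccak256(nullptr, 0);
    // Read the digest as raw bytes; to_le64 guarantees memory order is digest order.
    const uint8_t* bytes = (const uint8_t*)h.word64s;
    for (size_t i = 0; i < 32; ++i) {
        printf("%02x", bytes[i]);
    }
    // Expected: c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470
    printf("\n");
    return 0;
}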
@@ -0,0 +1,40 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
 * Copyright 2018-2019 Pawel Bylica.
 * Licensed under the Apache License, Version 2.0.
 */

#pragma once

#include "./hash_types.hpp"

#include <stddef.h>

#ifdef __cplusplus
#define NOEXCEPT noexcept
#else
#define NOEXCEPT
#endif

#ifdef __cplusplus
extern "C" {
#endif

/**
 * The Keccak-f[1600] function.
 *
 * The implementation of the Keccak-f function with 1600-bit width of the permutation (b).
 * The state is likewise 1600 bits, which gives 25 64-bit words.
 *
 * @param state The state of 25 64-bit words on which the permutation is to be performed.
 */
void ethash_keccakf1600(uint64_t state[25]) NOEXCEPT;

struct keccak256 ethash_keccak256(const uint8_t* data, size_t size) NOEXCEPT;

struct keccak256 hash_field_elements(const uint64_t* limbs, size_t num_elements);

struct keccak256 hash_field_element(const uint64_t* limb);

#ifdef __cplusplus
}
#endif
@@ -0,0 +1,235 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
 * Copyright 2018-2019 Pawel Bylica.
 * Licensed under the Apache License, Version 2.0.
 */

#include "keccak.hpp"
#include <stdint.h>

static uint64_t rol(uint64_t x, unsigned s)
{
    return (x << s) | (x >> (64 - s));
}

static const uint64_t round_constants[24] = {
    0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, 0x000000000000808b,
    0x0000000080000001, 0x8000000080008081, 0x8000000000008009, 0x000000000000008a, 0x0000000000000088,
    0x0000000080008009, 0x000000008000000a, 0x000000008000808b, 0x800000000000008b, 0x8000000000008089,
    0x8000000000008003, 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a,
    0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008,
};

void ethash_keccakf1600(uint64_t state[25]) NOEXCEPT
{
    /* The implementation based on the "simple" implementation by Ronny Van Keer. */

    int round;

    uint64_t Aba, Abe, Abi, Abo, Abu;
    uint64_t Aga, Age, Agi, Ago, Agu;
    uint64_t Aka, Ake, Aki, Ako, Aku;
    uint64_t Ama, Ame, Ami, Amo, Amu;
    uint64_t Asa, Ase, Asi, Aso, Asu;

    uint64_t Eba, Ebe, Ebi, Ebo, Ebu;
    uint64_t Ega, Ege, Egi, Ego, Egu;
    uint64_t Eka, Eke, Eki, Eko, Eku;
    uint64_t Ema, Eme, Emi, Emo, Emu;
    uint64_t Esa, Ese, Esi, Eso, Esu;

    uint64_t Ba, Be, Bi, Bo, Bu;

    uint64_t Da, De, Di, Do, Du;

    Aba = state[0];
    Abe = state[1];
    Abi = state[2];
    Abo = state[3];
    Abu = state[4];
    Aga = state[5];
    Age = state[6];
    Agi = state[7];
    Ago = state[8];
    Agu = state[9];
    Aka = state[10];
    Ake = state[11];
    Aki = state[12];
    Ako = state[13];
    Aku = state[14];
    Ama = state[15];
    Ame = state[16];
    Ami = state[17];
    Amo = state[18];
    Amu = state[19];
    Asa = state[20];
    Ase = state[21];
    Asi = state[22];
    Aso = state[23];
    Asu = state[24];

    for (round = 0; round < 24; round += 2) {
        /* Round (round + 0): Axx -> Exx */

        Ba = Aba ^ Aga ^ Aka ^ Ama ^ Asa;
        Be = Abe ^ Age ^ Ake ^ Ame ^ Ase;
        Bi = Abi ^ Agi ^ Aki ^ Ami ^ Asi;
        Bo = Abo ^ Ago ^ Ako ^ Amo ^ Aso;
        Bu = Abu ^ Agu ^ Aku ^ Amu ^ Asu;

        Da = Bu ^ rol(Be, 1);
        De = Ba ^ rol(Bi, 1);
        Di = Be ^ rol(Bo, 1);
        Do = Bi ^ rol(Bu, 1);
        Du = Bo ^ rol(Ba, 1);

        Ba = Aba ^ Da;
        Be = rol(Age ^ De, 44);
        Bi = rol(Aki ^ Di, 43);
        Bo = rol(Amo ^ Do, 21);
        Bu = rol(Asu ^ Du, 14);
        Eba = Ba ^ (~Be & Bi) ^ round_constants[round];
        Ebe = Be ^ (~Bi & Bo);
        Ebi = Bi ^ (~Bo & Bu);
        Ebo = Bo ^ (~Bu & Ba);
        Ebu = Bu ^ (~Ba & Be);

        Ba = rol(Abo ^ Do, 28);
        Be = rol(Agu ^ Du, 20);
        Bi = rol(Aka ^ Da, 3);
        Bo = rol(Ame ^ De, 45);
        Bu = rol(Asi ^ Di, 61);
        Ega = Ba ^ (~Be & Bi);
        Ege = Be ^ (~Bi & Bo);
        Egi = Bi ^ (~Bo & Bu);
        Ego = Bo ^ (~Bu & Ba);
        Egu = Bu ^ (~Ba & Be);

        Ba = rol(Abe ^ De, 1);
        Be = rol(Agi ^ Di, 6);
        Bi = rol(Ako ^ Do, 25);
        Bo = rol(Amu ^ Du, 8);
        Bu = rol(Asa ^ Da, 18);
        Eka = Ba ^ (~Be & Bi);
        Eke = Be ^ (~Bi & Bo);
        Eki = Bi ^ (~Bo & Bu);
        Eko = Bo ^ (~Bu & Ba);
        Eku = Bu ^ (~Ba & Be);

        Ba = rol(Abu ^ Du, 27);
        Be = rol(Aga ^ Da, 36);
        Bi = rol(Ake ^ De, 10);
        Bo = rol(Ami ^ Di, 15);
        Bu = rol(Aso ^ Do, 56);
        Ema = Ba ^ (~Be & Bi);
        Eme = Be ^ (~Bi & Bo);
        Emi = Bi ^ (~Bo & Bu);
        Emo = Bo ^ (~Bu & Ba);
        Emu = Bu ^ (~Ba & Be);

        Ba = rol(Abi ^ Di, 62);
        Be = rol(Ago ^ Do, 55);
        Bi = rol(Aku ^ Du, 39);
        Bo = rol(Ama ^ Da, 41);
        Bu = rol(Ase ^ De, 2);
        Esa = Ba ^ (~Be & Bi);
        Ese = Be ^ (~Bi & Bo);
        Esi = Bi ^ (~Bo & Bu);
        Eso = Bo ^ (~Bu & Ba);
        Esu = Bu ^ (~Ba & Be);

        /* Round (round + 1): Exx -> Axx */

        Ba = Eba ^ Ega ^ Eka ^ Ema ^ Esa;
        Be = Ebe ^ Ege ^ Eke ^ Eme ^ Ese;
        Bi = Ebi ^ Egi ^ Eki ^ Emi ^ Esi;
        Bo = Ebo ^ Ego ^ Eko ^ Emo ^ Eso;
        Bu = Ebu ^ Egu ^ Eku ^ Emu ^ Esu;

        Da = Bu ^ rol(Be, 1);
        De = Ba ^ rol(Bi, 1);
        Di = Be ^ rol(Bo, 1);
        Do = Bi ^ rol(Bu, 1);
        Du = Bo ^ rol(Ba, 1);

        Ba = Eba ^ Da;
        Be = rol(Ege ^ De, 44);
        Bi = rol(Eki ^ Di, 43);
        Bo = rol(Emo ^ Do, 21);
        Bu = rol(Esu ^ Du, 14);
        Aba = Ba ^ (~Be & Bi) ^ round_constants[round + 1];
        Abe = Be ^ (~Bi & Bo);
        Abi = Bi ^ (~Bo & Bu);
        Abo = Bo ^ (~Bu & Ba);
        Abu = Bu ^ (~Ba & Be);

        Ba = rol(Ebo ^ Do, 28);
        Be = rol(Egu ^ Du, 20);
        Bi = rol(Eka ^ Da, 3);
        Bo = rol(Eme ^ De, 45);
        Bu = rol(Esi ^ Di, 61);
        Aga = Ba ^ (~Be & Bi);
        Age = Be ^ (~Bi & Bo);
        Agi = Bi ^ (~Bo & Bu);
        Ago = Bo ^ (~Bu & Ba);
        Agu = Bu ^ (~Ba & Be);

        Ba = rol(Ebe ^ De, 1);
        Be = rol(Egi ^ Di, 6);
        Bi = rol(Eko ^ Do, 25);
        Bo = rol(Emu ^ Du, 8);
        Bu = rol(Esa ^ Da, 18);
        Aka = Ba ^ (~Be & Bi);
        Ake = Be ^ (~Bi & Bo);
        Aki = Bi ^ (~Bo & Bu);
        Ako = Bo ^ (~Bu & Ba);
        Aku = Bu ^ (~Ba & Be);

        Ba = rol(Ebu ^ Du, 27);
        Be = rol(Ega ^ Da, 36);
        Bi = rol(Eke ^ De, 10);
        Bo = rol(Emi ^ Di, 15);
        Bu = rol(Eso ^ Do, 56);
        Ama = Ba ^ (~Be & Bi);
        Ame = Be ^ (~Bi & Bo);
        Ami = Bi ^ (~Bo & Bu);
        Amo = Bo ^ (~Bu & Ba);
        Amu = Bu ^ (~Ba & Be);

        Ba = rol(Ebi ^ Di, 62);
        Be = rol(Ego ^ Do, 55);
        Bi = rol(Eku ^ Du, 39);
        Bo = rol(Ema ^ Da, 41);
        Bu = rol(Ese ^ De, 2);
        Asa = Ba ^ (~Be & Bi);
        Ase = Be ^ (~Bi & Bo);
        Asi = Bi ^ (~Bo & Bu);
        Aso = Bo ^ (~Bu & Ba);
        Asu = Bu ^ (~Ba & Be);
    }

    state[0] = Aba;
    state[1] = Abe;
    state[2] = Abi;
    state[3] = Abo;
    state[4] = Abu;
    state[5] = Aga;
    state[6] = Age;
    state[7] = Agi;
    state[8] = Ago;
    state[9] = Agu;
    state[10] = Aka;
    state[11] = Ake;
    state[12] = Aki;
    state[13] = Ako;
    state[14] = Aku;
    state[15] = Ama;
    state[16] = Ame;
    state[17] = Ami;
    state[18] = Amo;
    state[19] = Amu;
    state[20] = Asa;
    state[21] = Ase;
    state[22] = Asi;
    state[23] = Aso;
    state[24] = Asu;
}
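Editorial note: the hand-rolled rol above is the standard 64-bit rotate. A small check (sketch, C++20 <bit>) that it agrees with std::rotl for every shift the rho offsets actually use; all offsets lie in [1, 63], so the (x >> (64 - s)) term never shifts by 64.

#include <bit>
#include <cassert>
#include <cstdint>

static uint64_t rol(uint64_t x, unsigned s) { return (x << s) | (x >> (64 - s)); }

int main()
{
    const uint64_t x = 0x0123456789abcdefUL;
    for (unsigned s = 1; s < 64; ++s) {
        assert(rol(x, s) == std::rotl(x, (int)s)); // identical for s in [1, 63]
    }
    return 0;
}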
@@ -0,0 +1,26 @@
#pragma once
#include "../bn254/fq.hpp"
#include "../bn254/fq12.hpp"
#include "../bn254/fq2.hpp"
#include "../bn254/fr.hpp"
#include "../bn254/g1.hpp"
#include "../bn254/g2.hpp"

namespace bb::curve {
class BN254 {
  public:
    using ScalarField = bb::fr;
    using BaseField = bb::fq;
    using Group = typename bb::g1;
    using Element = typename Group::element;
    using AffineElement = typename Group::affine_element;
    using G2AffineElement = typename bb::g2::affine_element;
    using G2BaseField = typename bb::fq2;
    using TargetField = bb::fq12;

    // TODO(#673): This flag is temporary. It is needed in the verifier classes (GeminiVerifier, etc.) while these
    // classes are instantiated with "native" curve types. Eventually, the verifier classes will be instantiated only
    // with stdlib types, and "native" verification will be achieved via a simulated builder.
    static constexpr bool is_stdlib_type = false;
};
} // namespace bb::curve
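Editorial sketch: these aliases exist so downstream code can be written once, generically over the curve. The function below is hypothetical (eval_poly and the include path are illustrative; only the alias surface comes from the class above, and the field arithmetic operators are assumed from barretenberg's field type).

#include "barretenberg/ecc/curves/bn254/bn254.hpp" // illustrative path within this tree
#include <cstddef>

// Horner evaluation over the curve's scalar field, written once for any Curve.
template <typename Curve>
typename Curve::ScalarField eval_poly(const typename Curve::ScalarField* coeffs,
                                      size_t n,
                                      const typename Curve::ScalarField& x)
{
    using Fr = typename Curve::ScalarField;
    Fr acc = Fr(0);
    for (size_t i = n; i-- > 0;) {
        acc = acc * x + coeffs[i]; // acc = acc*x + c_i
    }
    return acc;
}
// e.g. eval_poly<bb::curve::BN254>(coeffs, n, challenge), or later a stdlib curve type.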
115
sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq.hpp
Normal file
@@ -0,0 +1,115 @@
#pragma once

#include <cstdint>
#include <iomanip>

#include "../../fields/field.hpp"

// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
namespace bb {
class Bn254FqParams {
  public:
    static constexpr uint64_t modulus_0 = 0x3C208C16D87CFD47UL;
    static constexpr uint64_t modulus_1 = 0x97816a916871ca8dUL;
    static constexpr uint64_t modulus_2 = 0xb85045b68181585dUL;
    static constexpr uint64_t modulus_3 = 0x30644e72e131a029UL;

    static constexpr uint64_t r_squared_0 = 0xF32CFC5B538AFA89UL;
    static constexpr uint64_t r_squared_1 = 0xB5E71911D44501FBUL;
    static constexpr uint64_t r_squared_2 = 0x47AB1EFF0A417FF6UL;
    static constexpr uint64_t r_squared_3 = 0x06D89F71CAB8351FUL;

    static constexpr uint64_t cube_root_0 = 0x71930c11d782e155UL;
    static constexpr uint64_t cube_root_1 = 0xa6bb947cffbe3323UL;
    static constexpr uint64_t cube_root_2 = 0xaa303344d4741444UL;
    static constexpr uint64_t cube_root_3 = 0x2c3b3f0d26594943UL;

    static constexpr uint64_t modulus_wasm_0 = 0x187cfd47;
    static constexpr uint64_t modulus_wasm_1 = 0x10460b6;
    static constexpr uint64_t modulus_wasm_2 = 0x1c72a34f;
    static constexpr uint64_t modulus_wasm_3 = 0x2d522d0;
    static constexpr uint64_t modulus_wasm_4 = 0x1585d978;
    static constexpr uint64_t modulus_wasm_5 = 0x2db40c0;
    static constexpr uint64_t modulus_wasm_6 = 0xa6e141;
    static constexpr uint64_t modulus_wasm_7 = 0xe5c2634;
    static constexpr uint64_t modulus_wasm_8 = 0x30644e;

    static constexpr uint64_t r_squared_wasm_0 = 0xe1a2a074659bac10UL;
    static constexpr uint64_t r_squared_wasm_1 = 0x639855865406005aUL;
    static constexpr uint64_t r_squared_wasm_2 = 0xff54c5802d3e2632UL;
    static constexpr uint64_t r_squared_wasm_3 = 0x2a11a68c34ea65a6UL;

    static constexpr uint64_t cube_root_wasm_0 = 0x62b1a3a46a337995UL;
    static constexpr uint64_t cube_root_wasm_1 = 0xadc97d2722e2726eUL;
    static constexpr uint64_t cube_root_wasm_2 = 0x64ee82ede2db85faUL;
    static constexpr uint64_t cube_root_wasm_3 = 0x0c0afea1488a03bbUL;

    static constexpr uint64_t primitive_root_0 = 0UL;
    static constexpr uint64_t primitive_root_1 = 0UL;
    static constexpr uint64_t primitive_root_2 = 0UL;
    static constexpr uint64_t primitive_root_3 = 0UL;

    static constexpr uint64_t primitive_root_wasm_0 = 0x0000000000000000UL;
    static constexpr uint64_t primitive_root_wasm_1 = 0x0000000000000000UL;
    static constexpr uint64_t primitive_root_wasm_2 = 0x0000000000000000UL;
    static constexpr uint64_t primitive_root_wasm_3 = 0x0000000000000000UL;

    static constexpr uint64_t endo_g1_lo = 0x7a7bd9d4391eb18d;
    static constexpr uint64_t endo_g1_mid = 0x4ccef014a773d2cfUL;
    static constexpr uint64_t endo_g1_hi = 0x0000000000000002UL;
    static constexpr uint64_t endo_g2_lo = 0xd91d232ec7e0b3d2UL;
    static constexpr uint64_t endo_g2_mid = 0x0000000000000002UL;
    static constexpr uint64_t endo_minus_b1_lo = 0x8211bbeb7d4f1129UL;
    static constexpr uint64_t endo_minus_b1_mid = 0x6f4d8248eeb859fcUL;
    static constexpr uint64_t endo_b2_lo = 0x89d3256894d213e2UL;
    static constexpr uint64_t endo_b2_mid = 0UL;

    static constexpr uint64_t r_inv = 0x87d20782e4866389UL;

    static constexpr uint64_t coset_generators_0[8]{
        0x7a17caa950ad28d7ULL, 0x4d750e37163c3674ULL, 0x20d251c4dbcb4411ULL, 0xf42f9552a15a51aeULL,
        0x4f4bc0b2b5ef64bdULL, 0x22a904407b7e725aULL, 0xf60647ce410d7ff7ULL, 0xc9638b5c069c8d94ULL,
    };
    static constexpr uint64_t coset_generators_1[8]{
        0x1f6ac17ae15521b9ULL, 0x29e3aca3d71c2cf7ULL, 0x345c97cccce33835ULL, 0x3ed582f5c2aa4372ULL,
        0x1a4b98fbe78db996ULL, 0x24c48424dd54c4d4ULL, 0x2f3d6f4dd31bd011ULL, 0x39b65a76c8e2db4fULL,
    };
    static constexpr uint64_t coset_generators_2[8]{
        0x334bea4e696bd284ULL, 0x99ba8dbde1e518b0ULL, 0x29312d5a5e5edcULL, 0x6697d49cd2d7a508ULL,
        0x5c65ec9f484e3a79ULL, 0xc2d4900ec0c780a5ULL, 0x2943337e3940c6d1ULL, 0x8fb1d6edb1ba0cfdULL,
    };
    static constexpr uint64_t coset_generators_3[8]{
        0x2a1f6744ce179d8eULL, 0x3829df06681f7cbdULL, 0x463456c802275bedULL, 0x543ece899c2f3b1cULL,
        0x180a96573d3d9f8ULL, 0xf8b21270ddbb927ULL, 0x1d9598e8a7e39857ULL, 0x2ba010aa41eb7786ULL,
    };

    static constexpr uint64_t coset_generators_wasm_0[8] = { 0xeb8a8ec140766463ULL, 0xfded87957d76333dULL,
                                                             0x4c710c8092f2ff5eULL, 0x9af4916ba86fcb7fULL,
                                                             0xe9781656bdec97a0ULL, 0xfbdb0f2afaec667aULL,
                                                             0x4a5e94161069329bULL, 0x98e2190125e5febcULL };
    static constexpr uint64_t coset_generators_wasm_1[8] = { 0xf2b1f20626a3da49ULL, 0x56c12d76cb13587fULL,
                                                             0x5251d378d7f4a143ULL, 0x4de2797ae4d5ea06ULL,
                                                             0x49731f7cf1b732c9ULL, 0xad825aed9626b0ffULL,
                                                             0xa91300efa307f9c3ULL, 0xa4a3a6f1afe94286ULL };
    static constexpr uint64_t coset_generators_wasm_2[8] = { 0xf905ef8d84d5fea4ULL, 0x93b7a45b84f1507eULL,
                                                             0xe6b99ee0068dfab5ULL, 0x39bb9964882aa4ecULL,
                                                             0x8cbd93e909c74f23ULL, 0x276f48b709e2a0fcULL,
                                                             0x7a71433b8b7f4b33ULL, 0xcd733dc00d1bf56aULL };
    static constexpr uint64_t coset_generators_wasm_3[8] = { 0x2958a27c02b7cd5fULL, 0x06bc8a3277c371abULL,
                                                             0x1484c05bce00b620ULL, 0x224cf685243dfa96ULL,
                                                             0x30152cae7a7b3f0bULL, 0x0d791464ef86e357ULL,
                                                             0x1b414a8e45c427ccULL, 0x290980b79c016c41ULL };

    // used in msgpack schema serialization
    static constexpr char schema_name[] = "fq";
    static constexpr bool has_high_2adicity = false;

    // The modulus is larger than BN254 scalar field modulus, so it maps to two BN254 scalars
    static constexpr size_t NUM_BN254_SCALARS = 2;
};

using fq = field<Bn254FqParams>;

} // namespace bb

// NOLINTEND(cppcoreguidelines-avoid-c-arrays)
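Editorial sketch: the *_0 .. *_3 limbs concatenate, most-significant limb last, into a single 256-bit value. Reassembling the four modulus limbs above yields the 254-bit BN254 base field modulus q.

#include <cstdint>
#include <cstdio>

int main()
{
    // Bn254FqParams::modulus_{0..3}, least-significant limb first.
    const uint64_t m[4] = { 0x3C208C16D87CFD47UL, 0x97816a916871ca8dUL,
                            0xb85045b68181585dUL, 0x30644e72e131a029UL };
    // Printing most-significant limb first reproduces the modulus q.
    printf("q = 0x%016llx%016llx%016llx%016llx\n", (unsigned long long)m[3],
           (unsigned long long)m[2], (unsigned long long)m[1], (unsigned long long)m[0]);
    return 0;
}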
@@ -0,0 +1,62 @@
#pragma once

#include "../../fields/field2.hpp"
#include "./fq.hpp"

namespace bb {
struct Bn254Fq2Params {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    static constexpr fq twist_coeff_b_0{
        0x3bf938e377b802a8UL, 0x020b1b273633535dUL, 0x26b7edf049755260UL, 0x2514c6324384a86dUL
    };
    static constexpr fq twist_coeff_b_1{
        0x38e7ecccd1dcff67UL, 0x65f0b37d93ce0d3eUL, 0xd749d0dd22ac00aaUL, 0x0141b9ce4a688d4dUL
    };
    static constexpr fq twist_mul_by_q_x_0{
        0xb5773b104563ab30UL, 0x347f91c8a9aa6454UL, 0x7a007127242e0991UL, 0x1956bcd8118214ecUL
    };
    static constexpr fq twist_mul_by_q_x_1{
        0x6e849f1ea0aa4757UL, 0xaa1c7b6d89f89141UL, 0xb6e713cdfae0ca3aUL, 0x26694fbb4e82ebc3UL
    };
    static constexpr fq twist_mul_by_q_y_0{
        0xe4bbdd0c2936b629UL, 0xbb30f162e133bacbUL, 0x31a9d1b6f9645366UL, 0x253570bea500f8ddUL
    };
    static constexpr fq twist_mul_by_q_y_1{
        0xa1d77ce45ffe77c7UL, 0x07affd117826d1dbUL, 0x6d16bd27bb7edc6bUL, 0x2c87200285defeccUL
    };
    static constexpr fq twist_cube_root_0{
        0x505ecc6f0dff1ac2UL, 0x2071416db35ec465UL, 0xf2b53469fa43ea78UL, 0x18545552044c99aaUL
    };
    static constexpr fq twist_cube_root_1{
        0xad607f911cfe17a8UL, 0xb6bb78aa154154c4UL, 0xb53dd351736b20dbUL, 0x1d8ed57c5cc33d41UL
    };
#else
    static constexpr fq twist_coeff_b_0{
        0xdc19fa4aab489658UL, 0xd416744fbbf6e69UL, 0x8f7734ed0a8a033aUL, 0x19316b8353ee09bbUL
    };
    static constexpr fq twist_coeff_b_1{
        0x1cfd999a3b9fece0UL, 0xbe166fb279c1a7c7UL, 0xe93a1ba45580154cUL, 0x283739c94d11a9baUL
    };
    static constexpr fq twist_mul_by_q_x_0{
        0xecdea09b24a59190UL, 0x17db8ffeae2fe1c2UL, 0xbb09c97c6dabac4dUL, 0x2492b3d41d289af3UL
    };
    static constexpr fq twist_mul_by_q_x_1{
        0xf1663598f1142ef1UL, 0x77ec057e0bf56062UL, 0xdd0baaecb677a631UL, 0x135e4e31d284d463UL
    };
    static constexpr fq twist_mul_by_q_y_0{
        0xf46e7f60db1f0678UL, 0x31fc2eba5bcc5c3eUL, 0xedb3adc3086a2411UL, 0x1d46bd0f837817bcUL
    };
    static constexpr fq twist_mul_by_q_y_1{
        0x6b3fbdf579a647d5UL, 0xcc568fb62ff64974UL, 0xc1bfbf4ac4348ac6UL, 0x15871d4d3940b4d3UL
    };
    static constexpr fq twist_cube_root_0{
        0x49d0cc74381383d0UL, 0x9611849fe4bbe3d6UL, 0xd1a231d73067c92aUL, 0x445c312767932c2UL
    };
    static constexpr fq twist_cube_root_1{
        0x35a58c718e7c28bbUL, 0x98d42c77e7b8901aUL, 0xf9c53da2d0ca8c84UL, 0x1a68dd04e1b8c51dUL
    };
#endif
};

using fq2 = field2<fq, Bn254Fq2Params>;
} // namespace bb
121
sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fr.hpp
Normal file
@@ -0,0 +1,121 @@
#pragma once

#include <cstdint>
#include <iomanip>
#include <ostream>

#include "../../fields/field.hpp"

// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)

namespace bb {
class Bn254FrParams {
  public:
    // Note: limbs here are combined as concat(_3, _2, _1, _0)
    // E.g. this modulus forms the value:
    // 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001
    // = 21888242871839275222246405745257275088548364400416034343698204186575808495617
    static constexpr uint64_t modulus_0 = 0x43E1F593F0000001UL;
    static constexpr uint64_t modulus_1 = 0x2833E84879B97091UL;
    static constexpr uint64_t modulus_2 = 0xB85045B68181585DUL;
    static constexpr uint64_t modulus_3 = 0x30644E72E131A029UL;

    static constexpr uint64_t r_squared_0 = 0x1BB8E645AE216DA7UL;
    static constexpr uint64_t r_squared_1 = 0x53FE3AB1E35C59E3UL;
    static constexpr uint64_t r_squared_2 = 0x8C49833D53BB8085UL;
    static constexpr uint64_t r_squared_3 = 0x216D0B17F4E44A5UL;

    static constexpr uint64_t cube_root_0 = 0x93e7cede4a0329b3UL;
    static constexpr uint64_t cube_root_1 = 0x7d4fdca77a96c167UL;
    static constexpr uint64_t cube_root_2 = 0x8be4ba08b19a750aUL;
    static constexpr uint64_t cube_root_3 = 0x1cbd5653a5661c25UL;

    static constexpr uint64_t primitive_root_0 = 0x636e735580d13d9cUL;
    static constexpr uint64_t primitive_root_1 = 0xa22bf3742445ffd6UL;
    static constexpr uint64_t primitive_root_2 = 0x56452ac01eb203d8UL;
    static constexpr uint64_t primitive_root_3 = 0x1860ef942963f9e7UL;

    static constexpr uint64_t endo_g1_lo = 0x7a7bd9d4391eb18dUL;
    static constexpr uint64_t endo_g1_mid = 0x4ccef014a773d2cfUL;
    static constexpr uint64_t endo_g1_hi = 0x0000000000000002UL;
    static constexpr uint64_t endo_g2_lo = 0xd91d232ec7e0b3d7UL;
    static constexpr uint64_t endo_g2_mid = 0x0000000000000002UL;
    static constexpr uint64_t endo_minus_b1_lo = 0x8211bbeb7d4f1128UL;
    static constexpr uint64_t endo_minus_b1_mid = 0x6f4d8248eeb859fcUL;
    static constexpr uint64_t endo_b2_lo = 0x89d3256894d213e3UL;
    static constexpr uint64_t endo_b2_mid = 0UL;

    static constexpr uint64_t r_inv = 0xc2e1f593efffffffUL;

    static constexpr uint64_t coset_generators_0[8]{
        0x5eef048d8fffffe7ULL, 0xb8538a9dfffffe2ULL, 0x3057819e4fffffdbULL, 0xdcedb5ba9fffffd6ULL,
        0x8983e9d6efffffd1ULL, 0x361a1df33fffffccULL, 0xe2b0520f8fffffc7ULL, 0x8f46862bdfffffc2ULL,
    };
    static constexpr uint64_t coset_generators_1[8]{
        0x12ee50ec1ce401d0ULL, 0x49eac781bc44cefaULL, 0x307f6d866832bb01ULL, 0x677be41c0793882aULL,
        0x9e785ab1a6f45554ULL, 0xd574d1474655227eULL, 0xc7147dce5b5efa7ULL, 0x436dbe728516bcd1ULL,
    };
    static constexpr uint64_t coset_generators_2[8]{
        0x29312d5a5e5ee7ULL, 0x6697d49cd2d7a515ULL, 0x5c65ec9f484e3a89ULL, 0xc2d4900ec0c780b7ULL,
        0x2943337e3940c6e5ULL, 0x8fb1d6edb1ba0d13ULL, 0xf6207a5d2a335342ULL, 0x5c8f1dcca2ac9970ULL,
    };
    static constexpr uint64_t coset_generators_3[8]{
        0x463456c802275bedULL, 0x543ece899c2f3b1cULL, 0x180a96573d3d9f8ULL, 0xf8b21270ddbb927ULL,
        0x1d9598e8a7e39857ULL, 0x2ba010aa41eb7786ULL, 0x39aa886bdbf356b5ULL, 0x47b5002d75fb35e5ULL,
    };

    static constexpr uint64_t modulus_wasm_0 = 0x10000001;
    static constexpr uint64_t modulus_wasm_1 = 0x1f0fac9f;
    static constexpr uint64_t modulus_wasm_2 = 0xe5c2450;
    static constexpr uint64_t modulus_wasm_3 = 0x7d090f3;
    static constexpr uint64_t modulus_wasm_4 = 0x1585d283;
    static constexpr uint64_t modulus_wasm_5 = 0x2db40c0;
    static constexpr uint64_t modulus_wasm_6 = 0xa6e141;
    static constexpr uint64_t modulus_wasm_7 = 0xe5c2634;
    static constexpr uint64_t modulus_wasm_8 = 0x30644e;

    static constexpr uint64_t r_squared_wasm_0 = 0x38c2e14b45b69bd4UL;
    static constexpr uint64_t r_squared_wasm_1 = 0x0ffedb1885883377UL;
    static constexpr uint64_t r_squared_wasm_2 = 0x7840f9f0abc6e54dUL;
    static constexpr uint64_t r_squared_wasm_3 = 0x0a054a3e848b0f05UL;

    static constexpr uint64_t cube_root_wasm_0 = 0x7334a1ce7065364dUL;
    static constexpr uint64_t cube_root_wasm_1 = 0xae21578e4a14d22aUL;
    static constexpr uint64_t cube_root_wasm_2 = 0xcea2148a96b51265UL;
    static constexpr uint64_t cube_root_wasm_3 = 0x0038f7edf614a198UL;

    static constexpr uint64_t primitive_root_wasm_0 = 0x2faf11711a27b370UL;
    static constexpr uint64_t primitive_root_wasm_1 = 0xc23fe9fced28f1b8UL;
    static constexpr uint64_t primitive_root_wasm_2 = 0x43a0fc9bbe2af541UL;
    static constexpr uint64_t primitive_root_wasm_3 = 0x05d90b5719653a4fUL;

    static constexpr uint64_t coset_generators_wasm_0[8] = { 0xab46711cdffffcb2ULL, 0xdb1b52736ffffc09ULL,
                                                             0x0af033c9fffffb60ULL, 0xf6e31f8c9ffffab6ULL,
                                                             0x26b800e32ffffa0dULL, 0x568ce239bffff964ULL,
                                                             0x427fcdfc5ffff8baULL, 0x7254af52effff811ULL };
    static constexpr uint64_t coset_generators_wasm_1[8] = { 0x2476607dbd2dfff1ULL, 0x9a3208a561c2b00bULL,
                                                             0x0fedb0cd06576026ULL, 0x5d7570ac31329faeULL,
                                                             0xd33118d3d5c74fc9ULL, 0x48ecc0fb7a5bffe3ULL,
                                                             0x967480daa5373f6cULL, 0x0c30290249cbef86ULL };
    static constexpr uint64_t coset_generators_wasm_2[8] = { 0xe6b99ee0068dfc25ULL, 0x39bb9964882aa6a5ULL,
                                                             0x8cbd93e909c75126ULL, 0x276f48b709e2a349ULL,
                                                             0x7a71433b8b7f4dc9ULL, 0xcd733dc00d1bf84aULL,
                                                             0x6824f28e0d374a6dULL, 0xbb26ed128ed3f4eeULL };
    static constexpr uint64_t coset_generators_wasm_3[8] = { 0x1484c05bce00b620ULL, 0x224cf685243dfa96ULL,
                                                             0x30152cae7a7b3f0bULL, 0x0d791464ef86e357ULL,
                                                             0x1b414a8e45c427ccULL, 0x290980b79c016c41ULL,
                                                             0x066d686e110d108dULL, 0x14359e97674a5502ULL };

    // used in msgpack schema serialization
    static constexpr char schema_name[] = "fr";
    static constexpr bool has_high_2adicity = true;

    // This is a BN254 scalar, so it represents one BN254 scalar
    static constexpr size_t NUM_BN254_SCALARS = 1;
};

using fr = field<Bn254FrParams>;

} // namespace bb

// NOLINTEND(cppcoreguidelines-avoid-c-arrays)
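Editorial sketch: r_inv is the Montgomery reduction constant, -(modulus^{-1}) mod 2^64, so the wrapping 64-bit product modulus_0 * r_inv must come out to 2^64 - 1. A quick check using the two constants above:

#include <cstdint>
#include <cstdio>

int main()
{
    const uint64_t modulus_0 = 0x43E1F593F0000001UL; // Bn254FrParams::modulus_0
    const uint64_t r_inv = 0xc2e1f593efffffffUL;     // Bn254FrParams::r_inv
    // Wrapping product should print ffffffffffffffff, i.e. -1 mod 2^64.
    printf("%016llx\n", (unsigned long long)(modulus_0 * r_inv));
    return 0;
}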
@@ -0,0 +1,30 @@
#pragma once

#include "../../groups/group.hpp"
#include "./fq.hpp"
#include "./fr.hpp"

namespace bb {
struct Bn254G1Params {
    static constexpr bool USE_ENDOMORPHISM = true;
    static constexpr bool can_hash_to_curve = true;
    static constexpr bool small_elements = true;
    static constexpr bool has_a = false;
    static constexpr fq one_x = fq::one();
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    static constexpr fq one_y{ 0xa6ba871b8b1e1b3aUL, 0x14f1d651eb8e167bUL, 0xccdd46def0f28c58UL, 0x1c14ef83340fbe5eUL };
#else
    static constexpr fq one_y{ 0x9d0709d62af99842UL, 0xf7214c0419c29186UL, 0xa603f5090339546dUL, 0x1b906c52ac7a88eaUL };
#endif
    static constexpr fq a{ 0UL, 0UL, 0UL, 0UL };
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    static constexpr fq b{ 0x7a17caa950ad28d7UL, 0x1f6ac17ae15521b9UL, 0x334bea4e696bd284UL, 0x2a1f6744ce179d8eUL };
#else
    static constexpr fq b{ 0xeb8a8ec140766463UL, 0xf2b1f20626a3da49UL, 0xf905ef8d84d5fea4UL, 0x2958a27c02b7cd5fUL };
#endif
};

using g1 = group<fq, fr, Bn254G1Params>;

} // namespace bb
@@ -0,0 +1,34 @@
#pragma once

#include "../../groups/group.hpp"
#include "./fq2.hpp"
#include "./fr.hpp"

namespace bb {
struct Bn254G2Params {
    static constexpr bool USE_ENDOMORPHISM = false;
    static constexpr bool can_hash_to_curve = false;
    static constexpr bool small_elements = false;
    static constexpr bool has_a = false;

#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    static constexpr fq2 one_x{ { 0x8e83b5d102bc2026, 0xdceb1935497b0172, 0xfbb8264797811adf, 0x19573841af96503b },
                                { 0xafb4737da84c6140, 0x6043dd5a5802d8c4, 0x09e950fc52a02f86, 0x14fef0833aea7b6b } };
    static constexpr fq2 one_y{ { 0x619dfa9d886be9f6, 0xfe7fd297f59e9b78, 0xff9e1a62231b7dfe, 0x28fd7eebae9e4206 },
                                { 0x64095b56c71856ee, 0xdc57f922327d3cbb, 0x55f935be33351076, 0x0da4a0e693fd6482 } };
#else
    static constexpr fq2 one_x{
        { 0xe6df8b2cfb43050UL, 0x254c7d92a843857eUL, 0xf2006d8ad80dd622UL, 0x24a22107dfb004e3UL },
        { 0xe8e7528c0b334b65UL, 0x56e941e8b293cf69UL, 0xe1169545c074740bUL, 0x2ac61491edca4b42UL }
    };
    static constexpr fq2 one_y{
        { 0xdc508d48384e8843UL, 0xd55415a8afd31226UL, 0x834bf204bacb6e00UL, 0x51b9758138c5c79UL },
        { 0x64067e0b46a5f641UL, 0x37726529a3a77875UL, 0x4454445bd915f391UL, 0x10d5ac894edeed3UL }
    };
#endif
    static constexpr fq2 a = fq2::zero();
    static constexpr fq2 b = fq2::twist_coeff_b();
};

using g2 = group<fq2, fr, Bn254G2Params>;
} // namespace bb
@@ -0,0 +1,990 @@
#pragma once
// clang-format off

/*
 * Clear all flags via xorq opcode
 **/
#define CLEAR_FLAGS(empty_reg) \
    "xorq " empty_reg ", " empty_reg " \n\t"

/**
 * Load 4-limb field element, pointed to by a, into
 * registers (lolo, lohi, hilo, hihi)
 **/
#define LOAD_FIELD_ELEMENT(a, lolo, lohi, hilo, hihi) \
    "movq 0(" a "), " lolo " \n\t" \
    "movq 8(" a "), " lohi " \n\t" \
    "movq 16(" a "), " hilo " \n\t" \
    "movq 24(" a "), " hihi " \n\t"

/**
 * Store 4-limb field element located in
 * registers (lolo, lohi, hilo, hihi), into
 * memory pointed to by r
 **/
#define STORE_FIELD_ELEMENT(r, lolo, lohi, hilo, hihi) \
    "movq " lolo ", 0(" r ") \n\t" \
    "movq " lohi ", 8(" r ") \n\t" \
    "movq " hilo ", 16(" r ") \n\t" \
    "movq " hihi ", 24(" r ") \n\t"

#if !defined(__ADX__) || defined(DISABLE_ADX)
/**
 * Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
 * and add 4-limb field element pointed to by b
 **/
#define ADD(b) \
    "addq 0(" b "), %%r12 \n\t" \
    "adcq 8(" b "), %%r13 \n\t" \
    "adcq 16(" b "), %%r14 \n\t" \
    "adcq 24(" b "), %%r15 \n\t"

/**
 * Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
 * and subtract 4-limb field element pointed to by b
 **/
#define SUB(b) \
    "subq 0(" b "), %%r12 \n\t" \
    "sbbq 8(" b "), %%r13 \n\t" \
    "sbbq 16(" b "), %%r14 \n\t" \
    "sbbq 24(" b "), %%r15 \n\t"


/**
 * Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
 * add 4-limb field element pointed to by b, and reduce modulo p
 **/
#define ADD_REDUCE(b, modulus_0, modulus_1, modulus_2, modulus_3) \
    "addq 0(" b "), %%r12 \n\t" \
    "adcq 8(" b "), %%r13 \n\t" \
    "adcq 16(" b "), %%r14 \n\t" \
    "adcq 24(" b "), %%r15 \n\t" \
    "movq %%r12, %%r8 \n\t" \
    "movq %%r13, %%r9 \n\t" \
    "movq %%r14, %%r10 \n\t" \
    "movq %%r15, %%r11 \n\t" \
    "addq " modulus_0 ", %%r12 \n\t" \
    "adcq " modulus_1 ", %%r13 \n\t" \
    "adcq " modulus_2 ", %%r14 \n\t" \
    "adcq " modulus_3 ", %%r15 \n\t" \
    "cmovncq %%r8, %%r12 \n\t" \
    "cmovncq %%r9, %%r13 \n\t" \
    "cmovncq %%r10, %%r14 \n\t" \
    "cmovncq %%r11, %%r15 \n\t"



/**
 * Take a 4-limb integer, r, in (%r12, %r13, %r14, %r15)
 * and conditionally subtract modulus, if r >= p.
 **/
#define REDUCE_FIELD_ELEMENT(neg_modulus_0, neg_modulus_1, neg_modulus_2, neg_modulus_3) \
    /* Duplicate `r` */ \
    "movq %%r12, %%r8 \n\t" \
    "movq %%r13, %%r9 \n\t" \
    "movq %%r14, %%r10 \n\t" \
    "movq %%r15, %%r11 \n\t" \
    "addq " neg_modulus_0 ", %%r12 \n\t" /* r'[0] -= modulus.data[0] */ \
    "adcq " neg_modulus_1 ", %%r13 \n\t" /* r'[1] -= modulus.data[1] */ \
    "adcq " neg_modulus_2 ", %%r14 \n\t" /* r'[2] -= modulus.data[2] */ \
    "adcq " neg_modulus_3 ", %%r15 \n\t" /* r'[3] -= modulus.data[3] */ \
    \
    /* if r did not need to be reduced, the carry flag is clear */ \
    /* restore r' = r in that case */ \
    "cmovncq %%r8, %%r12 \n\t" \
    "cmovncq %%r9, %%r13 \n\t" \
    "cmovncq %%r10, %%r14 \n\t" \
    "cmovncq %%r11, %%r15 \n\t"

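Editorial sketch of what REDUCE_FIELD_ELEMENT computes, in plain C++ (assumes GCC/Clang's unsigned __int128): tentatively subtract the modulus, then keep the reduced value only when no borrow occurred. The asm does the selection branchlessly with cmov; this sketch uses a plain if for clarity.

#include <cstdint>

using u128 = unsigned __int128; // assumes GCC/Clang

// Conditionally subtract the 4-limb modulus p from r if r >= p.
static inline void reduce_once(uint64_t r[4], const uint64_t p[4])
{
    uint64_t t[4];
    uint64_t borrow = 0;
    for (int i = 0; i < 4; ++i) {
        u128 d = (u128)r[i] - p[i] - borrow;
        t[i] = (uint64_t)d;
        borrow = (uint64_t)(d >> 64) & 1; // 1 if the subtraction borrowed
    }
    if (!borrow) { // r >= p: keep the reduced value (the asm uses cmov instead of a branch)
        for (int i = 0; i < 4; ++i) {
            r[i] = t[i];
        }
    }
}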
/**
 * Compute Montgomery squaring of a
 * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r"
 **/
#define SQR(a) \
    "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
    \
    "xorq %%r8, %%r8 \n\t" /* clear flags */ \
    /* compute a[0] *a[1], a[0]*a[2], a[0]*a[3], a[1]*a[2], a[1]*a[3], a[2]*a[3] */ \
    "mulxq 8(" a "), %%r9, %%r10 \n\t" /* (r[1], r[2]) <- a[0] * a[1] */ \
    "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[1], t[2]) <- a[0] * a[2] */ \
    "mulxq 24(" a "), %%r11, %%r12 \n\t" /* (r[3], r[4]) <- a[0] * a[3] */ \
    \
    \
    /* accumulate products into result registers */ \
    "addq %%r8, %%r10 \n\t" /* r[2] += t[1] */ \
    "adcq %%r15, %%r11 \n\t" /* r[3] += t[2] */ \
    "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[5], t[6]) <- a[1] * a[2] */ \
    "mulxq 24(" a "), %%rdi, %%rcx \n\t" /* (t[3], t[4]) <- a[1] * a[3] */ \
    "movq 24(" a "), %%rdx \n\t" /* load a[3] into %%rdx */ \
    "mulxq 16(" a "), %%r13, %%r14 \n\t" /* (r[5], r[6]) <- a[3] * a[2] */ \
    "adcq %%rdi, %%r12 \n\t" /* r[4] += t[3] */ \
    "adcq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "addq %%r8, %%r11 \n\t" /* r[3] += t[5] */ \
    "adcq %%r15, %%r12 \n\t" /* r[4] += t[6] */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    \
    /* double result registers */ \
    "addq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \
    "adcq %%r10, %%r10 \n\t" /* r[2] = 2r[2] */ \
    "adcq %%r11, %%r11 \n\t" /* r[3] = 2r[3] */ \
    "adcq %%r12, %%r12 \n\t" /* r[4] = 2r[4] */ \
    "adcq %%r13, %%r13 \n\t" /* r[5] = 2r[5] */ \
    "adcq %%r14, %%r14 \n\t" /* r[6] = 2r[6] */ \
    \
    /* compute a[3]*a[3], a[2]*a[2], a[1]*a[1], a[0]*a[0] */ \
    "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
    "mulxq %%rdx, %%r8, %%rcx \n\t" /* (r[0], t[4]) <- a[0] * a[0] */ \
    "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq %%rdx, %%rdx, %%rdi \n\t" /* (t[7], t[8]) <- a[2] * a[2] */ \
    /* add squares into result registers */ \
    "addq %%rdx, %%r12 \n\t" /* r[4] += t[7] */ \
    "adcq %%rdi, %%r13 \n\t" /* r[5] += t[8] */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "addq %%rcx, %%r9 \n\t" /* r[1] += t[4] */ \
    "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq %%rdx, %%rcx, %%r15 \n\t" /* (t[5], r[7]) <- a[3] * a[3] */ \
    "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq %%rdx, %%rdi, %%rdx \n\t" /* (t[3], t[6]) <- a[1] * a[1] */ \
    "adcq %%rdi, %%r10 \n\t" /* r[2] += t[3] */ \
    "adcq %%rdx, %%r11 \n\t" /* r[3] += t[6] */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    "addq %%rcx, %%r14 \n\t" /* r[6] += t[5] */ \
    "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
    \
    /* perform modular reduction: r[0] */ \
    "movq %%r8, %%rdx \n\t" /* move r8 into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \
    "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
    "addq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \
    "adcq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_c */ \
    "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
    "adcq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_c */ \
    "adcq $0, %%r11 \n\t" /* r[4] += flag_c */ \
    /* Partial fix "adcq $0, %%r12 \n\t"*/ /* r[4] += flag_c */ \
    "addq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \
    "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
    "mulxq %[modulus_3], %%r8, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
    "adcq %%rdi, %%r10 \n\t" /* r[2] += t[0] + flag_c */ \
    "adcq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \
    "adcq %%rdx, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    "addq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    \
    /* perform modular reduction: r[1] */ \
    "movq %%r9, %%rdx \n\t" /* move r9 into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \
    "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
    "addq %%rdi, %%r9 \n\t" /* r[1] += t[0] (%r8 now free) */ \
    "adcq %%rcx, %%r10 \n\t" /* r[2] += t[1] + flag_c */ \
    "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
    "adcq %%rcx, %%r11 \n\t" /* r[3] += t[3] + flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    "addq %%rdi, %%r10 \n\t" /* r[2] += t[2] */ \
    "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
    "mulxq %[modulus_3], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
    "adcq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \
    "adcq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adcq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "addq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    \
    /* perform modular reduction: r[2] */ \
    "movq %%r10, %%rdx \n\t" /* move r10 into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \
    "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
    "addq %%rdi, %%r10 \n\t" /* r[2] += t[0] (%r8 now free) */ \
    "adcq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \
    "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
    "mulxq %[modulus_3], %%r10, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
    "adcq %%rcx, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
    "adcq %%r9, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
    "adcq %%rdx, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
    "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
    "addq %%rdi, %%r11 \n\t" /* r[3] += t[2] */ \
    "adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
    "adcq %%r10, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    \
    /* perform modular reduction: r[3] */ \
    "movq %%r11, %%rdx \n\t" /* move r11 into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \
    "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
    "mulxq %[modulus_1], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
    "addq %%rdi, %%r11 \n\t" /* r[3] += t[0] (%r11 now free) */ \
    "adcq %%r8, %%r12 \n\t" /* r[4] += t[2] */ \
    "adcq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
    "mulxq %[modulus_3], %%r10, %%r11 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
    "adcq %%r9, %%r14 \n\t" /* r[6] += t[1] + flag_c */ \
    "adcq %%r11, %%r15 \n\t" /* r[7] += t[3] + flag_c */ \
    "addq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adcq %%r8, %%r13 \n\t" /* r[5] += t[0] + flag_c */ \
    "adcq %%r10, %%r14 \n\t" /* r[6] += t[2] + flag_c */ \
    "adcq $0, %%r15 \n\t" /* r[7] += flag_c */


/**
 * Compute Montgomery multiplication of a, b.
 * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r"
 **/
#define MUL(a1, a2, a3, a4, b) \
    "movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \
    "xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \
    /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
    "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
    "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
    "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
    "mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \
    /* zero flags */ \
    \
    /* start computing modular reduction */ \
    "movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    \
    /* start first addition chain */ \
    "addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
    "adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    "adcq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    \
    /* reduce by r[0] * k */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
    "addq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \
    "adcq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \
    "adcq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    "adcq $0, %%r10 \n\t" /* r[3] += flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    "addq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_c */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "adcq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adcq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
    "adcq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
    "addq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    \
    /* modulus = 254 bits, so max(t[3]) = 62 bits */ \
    /* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \
    /* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \
    /* (which is very convenient because we're out of registers!) */ \
    /* N.B. the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \
    \
    /* a[1] * b */ \
    "movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
    "addq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
    "adcq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adcq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    "addq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \
    "adcq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adcq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    "addq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    \
    /* reduce by r[1] * k */ \
    "movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
    "addq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \
    "adcq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adcq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
    "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    "addq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "adcq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    \
    /* a[2] * b */ \
    "movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
    "addq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adcq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
    "adcq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    "addq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \
    "adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
    "adcq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "addq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    \
    /* reduce by r[2] * k */ \
    "movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
    "addq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \
    "adcq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adcq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "addq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
    "adcq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
    "adcq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
    "addq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    \
    /* a[3] * b */ \
    "movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \
    "addq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \
    "adcq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adcq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
    "adcq $0, %%r15 \n\t" /* r[7] += + flag_c */ \
    "addq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
    "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
    \
    /* reduce by r[3] * k */ \
    "movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
    "addq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \
    "adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
    "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
    "addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \
    "adcq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adcq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
    "adcq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
    "addq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
    "adcq $0, %%r15 \n\t" /* r[7] += flag_c */


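Editorial sketch of the per-limb reduction step that MUL interleaves with the schoolbook products (assumes GCC/Clang's unsigned __int128; names are illustrative): pick k = acc[0] * r_inv so that acc + k*p has a zero low limb, add it in, and the caller then drops that limb, shifting the accumulator down by 64 bits.

#include <cstdint>

using u128 = unsigned __int128; // assumes GCC/Clang

// One Montgomery reduction step over a 5-limb accumulator; r_inv = -p^{-1} mod 2^64.
static inline void montgomery_step(uint64_t acc[5], const uint64_t p[4], uint64_t r_inv)
{
    const uint64_t k = acc[0] * r_inv; // chosen so the low limb cancels
    uint64_t carry = 0;
    for (int i = 0; i < 4; ++i) {
        u128 t = (u128)k * p[i] + acc[i] + carry;
        acc[i] = (uint64_t)t; // acc[0] becomes 0 on the first iteration
        carry = (uint64_t)(t >> 64);
    }
    acc[4] += carry; // the asm tracks this top carry with bound arguments like those in the comments above
}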
/**
|
||||
* Compute 256-bit multiplication of a, b.
|
||||
* Result is stored, r. // in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r"
|
||||
**/
|
||||
#define MUL_256(a, b, r) \
    "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
    \
    /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
    "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
    "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
    "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
    "mulxq 16(" b "), %%r15, %%rax \n\t" /* (r[2], r[3]) <- a[0] * b[2] */ \
    /* zero flags */ \
    "xorq %%r10, %%r10 \n\t" /* clear r10; we use it whenever we need a zero */ \
    \
    /* start first addition chain */ \
    "addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
    "adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    "adcq %%r10, %%rax \n\t" /* r[3] += flag_c */ \
    "addq %%rdi, %%rax \n\t" /* r[3] += t[2] */ \
    \
    /* a[1] * b */ \
    "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
    "addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
    "adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    "adcq %%rsi, %%rax \n\t" /* r[3] += t[5] + flag_c */ \
    "addq %%rdi, %%r15 \n\t" /* r[2] += t[4] */ \
    \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
    "adcq %%r8, %%rax \n\t" /* r[3] += t[2] + flag_c */ \
    \
    /* a[2] * b */ \
    "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[2], t[3]) <- (a[2] * b[1]) */ \
    "addq %%r8, %%r15 \n\t" /* r[2] += t[0] */ \
    "adcq %%r9, %%rax \n\t" /* r[3] += t[1] + flag_c */ \
    "addq %%rdi, %%rax \n\t" /* r[3] += t[2] */ \
    \
    /* a[3] * b */ \
    "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
    "adcq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
    "movq %%r13, 0(" r ") \n\t" \
    "movq %%r14, 8(" r ") \n\t" \
    "movq %%r15, 16(" r ") \n\t" \
    "movq %%rax, 24(" r ") \n\t"

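// Editor's sketch (kept out of the build with #if 0): the plain-C++ computation that
// MUL_256 performs, assuming <cstdint>/<cstddef> and a compiler providing unsigned
// __int128. Only the low four 64-bit limbs of the 4x4 schoolbook product are kept,
// mirroring the truncated carry chains above.
#if 0
static inline void mul_256_sketch(const uint64_t a[4], const uint64_t b[4], uint64_t r[4])
{
    uint64_t out[4] = { 0, 0, 0, 0 };
    for (size_t i = 0; i < 4; ++i) {
        unsigned __int128 carry = 0;
        for (size_t j = 0; i + j < 4; ++j) {
            unsigned __int128 t = (unsigned __int128)a[i] * b[j] + out[i + j] + carry;
            out[i + j] = (uint64_t)t;
            carry = t >> 64; // carries past limb 3 are deliberately discarded
        }
    }
    for (size_t i = 0; i < 4; ++i) {
        r[i] = out[i];
    }
}
#endif
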
#else // 6047895us

/**
 * Take a 4-limb field element in (%r12, %r13, %r14, %r15)
 * and add the 4-limb field element pointed to by b.
 **/
#define ADD(b) \
    "adcxq 0(" b "), %%r12 \n\t" \
    "adcxq 8(" b "), %%r13 \n\t" \
    "adcxq 16(" b "), %%r14 \n\t" \
    "adcxq 24(" b "), %%r15 \n\t"

/**
 * Take a 4-limb field element in (%r12, %r13, %r14, %r15)
 * and subtract the 4-limb field element pointed to by b.
 **/
#define SUB(b) \
    "subq 0(" b "), %%r12 \n\t" \
    "sbbq 8(" b "), %%r13 \n\t" \
    "sbbq 16(" b "), %%r14 \n\t" \
    "sbbq 24(" b "), %%r15 \n\t"

/**
 * Take a 4-limb field element in (%r12, %r13, %r14, %r15),
 * add the 4-limb field element pointed to by b, and reduce modulo p.
 **/
#define ADD_REDUCE(b, modulus_0, modulus_1, modulus_2, modulus_3) \
    "adcxq 0(" b "), %%r12 \n\t" \
    "movq %%r12, %%r8 \n\t" \
    "adoxq " modulus_0 ", %%r12 \n\t" \
    "adcxq 8(" b "), %%r13 \n\t" \
    "movq %%r13, %%r9 \n\t" \
    "adoxq " modulus_1 ", %%r13 \n\t" \
    "adcxq 16(" b "), %%r14 \n\t" \
    "movq %%r14, %%r10 \n\t" \
    "adoxq " modulus_2 ", %%r14 \n\t" \
    "adcxq 24(" b "), %%r15 \n\t" \
    "movq %%r15, %%r11 \n\t" \
    "adoxq " modulus_3 ", %%r15 \n\t" \
    "cmovnoq %%r8, %%r12 \n\t" \
    "cmovnoq %%r9, %%r13 \n\t" \
    "cmovnoq %%r10, %%r14 \n\t" \
    "cmovnoq %%r11, %%r15 \n\t"

/**
 * Take a 4-limb integer r in (%r12, %r13, %r14, %r15)
 * and conditionally subtract the modulus, if r > p.
 **/
#define REDUCE_FIELD_ELEMENT(neg_modulus_0, neg_modulus_1, neg_modulus_2, neg_modulus_3) \
    /* Duplicate `r` */ \
    "movq %%r12, %%r8 \n\t" \
    "movq %%r13, %%r9 \n\t" \
    "movq %%r14, %%r10 \n\t" \
    "movq %%r15, %%r11 \n\t" \
    /* Add the negative representation of 'modulus' into `r`. We do this instead */ \
    /* of subtracting, because we can use `adoxq`. */ \
    /* This opcode only has a dependence on the overflow */ \
    /* flag (sub/sbb changes both carry and overflow flags). */ \
    /* We can process an `adcxq` and `adoxq` opcode simultaneously. */ \
    "adoxq " neg_modulus_0 ", %%r12 \n\t" /* r'[0] -= modulus.data[0] */ \
    "adoxq " neg_modulus_1 ", %%r13 \n\t" /* r'[1] -= modulus.data[1] */ \
    "adoxq " neg_modulus_2 ", %%r14 \n\t" /* r'[2] -= modulus.data[2] */ \
    "adoxq " neg_modulus_3 ", %%r15 \n\t" /* r'[3] -= modulus.data[3] */ \
    \
    /* if r did not need to be reduced, the overflow flag is 0 */ \
    /* and cmovno restores r' = r */ \
    "cmovnoq %%r8, %%r12 \n\t" \
    "cmovnoq %%r9, %%r13 \n\t" \
    "cmovnoq %%r10, %%r14 \n\t" \
    "cmovnoq %%r11, %%r15 \n\t"

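// Editor's sketch (kept out of the build with #if 0): the branchless reduction that
// REDUCE_FIELD_ELEMENT implements. Subtracting the modulus and keeping the original
// value on borrow is equivalent to the add-negative-modulus/cmovnoq trick above;
// assumes unsigned __int128 support.
#if 0
static inline void reduce_once_sketch(uint64_t r[4], const uint64_t p[4])
{
    uint64_t t[4];
    uint64_t borrow = 0;
    for (size_t i = 0; i < 4; ++i) {
        unsigned __int128 d = (unsigned __int128)r[i] - p[i] - borrow;
        t[i] = (uint64_t)d;
        borrow = (uint64_t)(d >> 64) & 1; // 1 if the subtraction wrapped
    }
    // borrow == 1 means r < p: keep r. Otherwise select r - p, without branching.
    uint64_t mask = 0 - (1 - borrow); // all-ones when r >= p
    for (size_t i = 0; i < 4; ++i) {
        r[i] = (r[i] & ~mask) | (t[i] & mask);
    }
}
#endif
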
/**
 * Compute the Montgomery squaring of a.
 * The result is left in (%%r12, %%r13, %%r14, %%r15), ready to be stored in "r".
 **/
#define SQR(a) \
    "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
    \
    "xorq %%r8, %%r8 \n\t" /* clear flags */ \
    /* compute a[0]*a[1], a[0]*a[2], a[0]*a[3], a[1]*a[2], a[1]*a[3], a[2]*a[3] */ \
    "mulxq 8(" a "), %%r9, %%r10 \n\t" /* (r[1], r[2]) <- a[0] * a[1] */ \
    "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[1], t[2]) <- a[0] * a[2] */ \
    "mulxq 24(" a "), %%r11, %%r12 \n\t" /* (r[3], r[4]) <- a[0] * a[3] */ \
    \
    /* accumulate products into result registers */ \
    "adoxq %%r8, %%r10 \n\t" /* r[2] += t[1] */ \
    "adcxq %%r15, %%r11 \n\t" /* r[3] += t[2] */ \
    "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[5], t[6]) <- a[1] * a[2] */ \
    "mulxq 24(" a "), %%rdi, %%rcx \n\t" /* (t[3], t[4]) <- a[1] * a[3] */ \
    "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq 16(" a "), %%r13, %%r14 \n\t" /* (r[5], r[6]) <- a[3] * a[2] */ \
    "adoxq %%r8, %%r11 \n\t" /* r[3] += t[5] */ \
    "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] */ \
    "adoxq %%r15, %%r12 \n\t" /* r[4] += t[6] */ \
    "adcxq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
    \
    /* double the result registers */ \
    "adoxq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \
    "adcxq %%r12, %%r12 \n\t" /* r[4] = 2r[4] */ \
    "adoxq %%r10, %%r10 \n\t" /* r[2] = 2r[2] */ \
    "adcxq %%r13, %%r13 \n\t" /* r[5] = 2r[5] */ \
    "adoxq %%r11, %%r11 \n\t" /* r[3] = 2r[3] */ \
    "adcxq %%r14, %%r14 \n\t" /* r[6] = 2r[6] */ \
    \
    /* compute a[3]*a[3], a[2]*a[2], a[1]*a[1], a[0]*a[0] */ \
    "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
    "mulxq %%rdx, %%r8, %%rcx \n\t" /* (r[0], t[4]) <- a[0] * a[0] */ \
    "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq %%rdx, %%rdx, %%rdi \n\t" /* (t[7], t[8]) <- a[2] * a[2] */ \
    /* add squares into result registers */ \
    "adcxq %%rcx, %%r9 \n\t" /* r[1] += t[4] */ \
    "adoxq %%rdx, %%r12 \n\t" /* r[4] += t[7] */ \
    "adoxq %%rdi, %%r13 \n\t" /* r[5] += t[8] */ \
    "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq %%rdx, %%rcx, %%r15 \n\t" /* (t[5], r[7]) <- a[3] * a[3] */ \
    "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq %%rdx, %%rdi, %%rdx \n\t" /* (t[3], t[6]) <- a[1] * a[1] */ \
    "adcxq %%rdi, %%r10 \n\t" /* r[2] += t[3] */ \
    "adcxq %%rdx, %%r11 \n\t" /* r[3] += t[6] */ \
    "adoxq %%rcx, %%r14 \n\t" /* r[6] += t[5] */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
    \
    /* perform modular reduction: r[0] */ \
    "movq %%r8, %%rdx \n\t" /* move r[0] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[0] * r_inv */ \
    "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
    "adoxq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \
    "mulxq %[modulus_3], %%r8, %%rdi \n\t" /* (t[2], t[3]) <- (modulus[3] * k) */ \
    "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
    "adoxq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_o */ \
    "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[4], t[5]) <- (modulus[1] * k) */ \
    "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
    "adoxq %%rcx, %%r10 \n\t" /* r[2] += t[5] + flag_o */ \
    "adcxq %%rdi, %%r9 \n\t" /* r[1] += t[4] */ \
    "adoxq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_o */ \
    "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[6], t[7]) <- (modulus[2] * k) */ \
    "adcxq %%rdi, %%r10 \n\t" /* r[2] += t[6] + flag_c */ \
    "adcxq %%rcx, %%r11 \n\t" /* r[3] += t[7] + flag_c */ \
    \
    /* perform modular reduction: r[1] */ \
    "movq %%r9, %%rdx \n\t" /* move r[1] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[2] * k) */ \
    "adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "mulxq %[modulus_3], %%r8, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[3] * k) */ \
    "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adoxq %%rcx, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
    "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
    "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
    "mulxq %[modulus_0], %%r8, %%rcx \n\t" /* (t[4], t[5]) <- (modulus[0] * k) */ \
    "adcxq %%r8, %%r9 \n\t" /* r[1] += t[4] (%r9 now free) */ \
    "adoxq %%rcx, %%r10 \n\t" /* r[2] += t[5] + flag_o */ \
    "mulxq %[modulus_1], %%r8, %%rcx \n\t" /* (t[6], t[7]) <- (modulus[1] * k) */ \
    "adcxq %%r8, %%r10 \n\t" /* r[2] += t[6] + flag_c */ \
    "adoxq %%rcx, %%r11 \n\t" /* r[3] += t[7] + flag_o */ \
    "adcxq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \
    \
    /* perform modular reduction: r[2] */ \
    "movq %%r10, %%rdx \n\t" /* move r[2] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[2] * r_inv */ \
    "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[1] * k) */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
    "adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adoxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
    "mulxq %[modulus_3], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus[3] * k) */ \
    "adcxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adoxq %%r9, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
    "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[6], t[7]) <- (modulus[0] * k) */ \
    "adcxq %%r8, %%r10 \n\t" /* r[2] += t[6] (%r10 now free) */ \
    "adoxq %%r9, %%r11 \n\t" /* r[3] += t[7] + flag_o */ \
    "adcxq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \
    "adoxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_o */ \
    "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
    \
    /* perform modular reduction: r[3] */ \
    "movq %%r11, %%rdx \n\t" /* move r[3] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[3] * r_inv */ \
    "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
    "mulxq %[modulus_1], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
    "adoxq %%rdi, %%r11 \n\t" /* r[3] += t[0] (%r11 now free) */ \
    "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] */ \
    "adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "adcxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus[2] * k) */ \
    "mulxq %[modulus_3], %%r10, %%r11 \n\t" /* (t[6], t[7]) <- (modulus[3] * k) */ \
    "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %%r10, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
    "adoxq %%r9, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \
    "adcxq %%r11, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */

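// Editor's sketch (kept out of the build with #if 0): the structure SQR exploits. Each
// cross product a[i]*a[j] (i < j) is computed once, the off-diagonal sum is doubled, and
// the diagonal squares are added on top, roughly halving the number of multiplies
// relative to a generic 4x4 product. Assumes unsigned __int128 support.
#if 0
static inline void sqr_512_sketch(const uint64_t a[4], uint64_t r[8])
{
    for (size_t i = 0; i < 8; ++i) {
        r[i] = 0;
    }
    // off-diagonal products, each computed once
    for (size_t i = 0; i < 4; ++i) {
        unsigned __int128 carry = 0;
        for (size_t j = i + 1; j < 4; ++j) {
            unsigned __int128 t = (unsigned __int128)a[i] * a[j] + r[i + j] + carry;
            r[i + j] = (uint64_t)t;
            carry = t >> 64;
        }
        r[i + 4] = (uint64_t)carry;
    }
    // double the off-diagonal sum (it fits in 511 bits, so nothing falls off the top)
    uint64_t top = 0;
    for (size_t i = 0; i < 8; ++i) {
        uint64_t hi = r[i] >> 63;
        r[i] = (r[i] << 1) | top;
        top = hi;
    }
    // add the diagonal squares a[i]^2 into limbs (2i, 2i+1)
    unsigned __int128 carry = 0;
    for (size_t i = 0; i < 4; ++i) {
        unsigned __int128 t = (unsigned __int128)a[i] * a[i] + r[2 * i] + carry;
        r[2 * i] = (uint64_t)t;
        unsigned __int128 t2 = (unsigned __int128)r[2 * i + 1] + (uint64_t)(t >> 64);
        r[2 * i + 1] = (uint64_t)t2;
        carry = t2 >> 64;
    }
}
#endif
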
/**
 * Compute the Montgomery multiplication of a and b.
 * The result is left in (%%r12, %%r13, %%r14, %%r15), ready to be stored in "r".
 **/
#define MUL(a1, a2, a3, a4, b) \
    "movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \
    "xorq %%r8, %%r8 \n\t" /* clear r8 and zero the carry/overflow flags */ \
    /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
    "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
    "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
    "mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2], r[3]) <- a[0] * b[2] */ \
    "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
    \
    /* start computing modular reduction */ \
    "movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[0] * r_inv */ \
    \
    /* start first addition chain */ \
    "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
    "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \
    "adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    \
    /* reduce by r[0] * k */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
    "adoxq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_c */ \
    "adoxq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \
    "adcxq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_c */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (modulus.data[1] * k) */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[6], t[7]) <- (modulus.data[2] * k) */ \
    "adoxq %%rdi, %%r14 \n\t" /* r[1] += t[4] + flag_o */ \
    "adcxq %%r11, %%r15 \n\t" /* r[2] += t[5] + flag_c */ \
    "adoxq %%r8, %%r15 \n\t" /* r[2] += t[6] + flag_o */ \
    "adcxq %%r9, %%r10 \n\t" /* r[3] += t[7] + flag_c */ \
    \
    /* modulus = 254 bits, so max(t[3]) = 62 bits */ \
    /* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \
    /* i.e. the carry flag here is always 0 if b is in Montgomery form, so no need to update r[5] */ \
    /* (which is very convenient because we're out of registers!) */ \
    /* N.B. the value of r[4] now has a max of 63 bits and can accept another 62-bit value before overflowing */ \
    \
    /* a[1] * b */ \
    "movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \
    "adoxq %%r8, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \
    "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[6] + flag_c */ \
    "adoxq %%r9, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
    "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
    "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
    "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
    "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[4] + flag_c */ \
    "adoxq %%r11, %%r10 \n\t" /* r[3] += t[5] + flag_o */ \
    \
    /* reduce by r[1] * k */ \
    "movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "adcxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adoxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adoxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
    "adoxq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \
    "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[2] + flag_c */ \
    "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
    "adcxq %%r11, %%r10 \n\t" /* r[3] += t[3] + flag_c */ \
    \
    /* a[2] * b */ \
    "movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[2] * b[2]) */ \
    "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \
    "adcxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adoxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
    "adcxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[4], r[6]) <- (a[2] * b[3]) */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
    "adoxq %%rdi, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
    "adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adoxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
    \
    /* reduce by r[2] * k */ \
    "movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[2] * r_inv */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus.data[2] * k) */ \
    "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adoxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adoxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (modulus.data[3] * k) */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[6], t[7]) <- (modulus.data[0] * k) */ \
    "adcxq %%rdi, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adoxq %%r11, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %%r8, %%r15 \n\t" /* r[2] += t[6] (%r15 now free) */ \
    "adcxq %%r9, %%r10 \n\t" /* r[3] += t[7] + flag_c */ \
    \
    /* a[3] * b */ \
    "movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (a[3] * b[1]) */ \
    "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \
    "adcxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
    "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (a[3] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \
    "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %%r9, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
    "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \
    "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
    \
    /* reduce by r[3] * k */ \
    "movq %%r10, %%rdx \n\t" /* move r[3] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[3] * r_inv */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
    "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] (%r10 now free) */ \
    "adcxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
    "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \
    "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %%r9, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
    "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \
    "adcxq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */

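// Editor's sketch (kept out of the build with #if 0): the word-by-word Montgomery
// multiplication (CIOS) that MUL interleaves. r_inv is -modulus^{-1} mod 2^64, matching
// the %[r_inv] operand above; for the 254-bit moduli used here the top carries cannot
// overflow, which is why a 5-limb accumulator suffices. Assumes unsigned __int128.
#if 0
static inline void montgomery_mul_sketch(
    const uint64_t a[4], const uint64_t b[4], const uint64_t p[4], uint64_t r_inv, uint64_t out[4])
{
    uint64_t t[5] = { 0, 0, 0, 0, 0 };
    for (size_t i = 0; i < 4; ++i) {
        // t += a[i] * b
        unsigned __int128 carry = 0;
        for (size_t j = 0; j < 4; ++j) {
            unsigned __int128 m = (unsigned __int128)a[i] * b[j] + t[j] + carry;
            t[j] = (uint64_t)m;
            carry = m >> 64;
        }
        t[4] += (uint64_t)carry;
        // choose k so the low limb of t + k * p becomes zero, then shift down one limb
        uint64_t k = t[0] * r_inv;
        carry = ((unsigned __int128)k * p[0] + t[0]) >> 64;
        for (size_t j = 1; j < 4; ++j) {
            unsigned __int128 m = (unsigned __int128)k * p[j] + t[j] + carry;
            t[j - 1] = (uint64_t)m;
            carry = m >> 64;
        }
        t[3] = t[4] + (uint64_t)carry;
        t[4] = 0;
    }
    // result is a*b*2^-256 mod p, possibly needing one final conditional subtraction of p
    for (size_t i = 0; i < 4; ++i) {
        out[i] = t[i];
    }
}
#endif
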
/**
 * Compute the Montgomery multiplication of a and b.
 * The result is left in (%%r12, %%r13, %%r14, %%r15), ready to be stored in "r".
 **/
#define MUL_FOO(a1, a2, a3, a4, b) \
    "movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \
    "xorq %%r8, %%r8 \n\t" /* clear r8 and zero the carry/overflow flags */ \
    /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
    "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
    "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
    "mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2], r[3]) <- a[0] * b[2] */ \
    "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
    \
    /* start computing modular reduction */ \
    "movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[0] * r_inv */ \
    \
    /* start first addition chain */ \
    "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
    "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \
    "adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    \
    /* reduce by r[0] * k */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
    "adoxq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_c */ \
    "adoxq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \
    "adcxq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_c */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (modulus.data[1] * k) */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[6], t[7]) <- (modulus.data[2] * k) */ \
    "adoxq %%rdi, %%r14 \n\t" /* r[1] += t[4] + flag_o */ \
    "adcxq %%r11, %%r15 \n\t" /* r[2] += t[5] + flag_c */ \
    "adoxq %%r8, %%r15 \n\t" /* r[2] += t[6] + flag_o */ \
    "adcxq %%r9, %%r10 \n\t" /* r[3] += t[7] + flag_c */ \
    \
    /* modulus = 254 bits, so max(t[3]) = 62 bits */ \
    /* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \
    /* i.e. the carry flag here is always 0 if b is in Montgomery form, so no need to update r[5] */ \
    /* (which is very convenient because we're out of registers!) */ \
    /* N.B. the value of r[4] now has a max of 63 bits and can accept another 62-bit value before overflowing */ \
    \
    /* a[1] * b */ \
    "movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \
    "adoxq %%r8, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \
    "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[6] + flag_c */ \
    "adoxq %%r9, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
    "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
    "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
    "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
    "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[4] + flag_c */ \
    "adoxq %%r11, %%r10 \n\t" /* r[3] += t[5] + flag_o */ \
    \
    /* reduce by r[1] * k */ \
    "movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
    "adcxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adoxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adoxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
    "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
    "adoxq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \
    "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[2] + flag_c */ \
    "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
    "adcxq %%r11, %%r10 \n\t" /* r[3] += t[3] + flag_c */ \
    \
    /* a[2] * b */ \
    "movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[2] * b[2]) */ \
    "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \
    "adcxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adoxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
    "adcxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    "mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[4], r[6]) <- (a[2] * b[3]) */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
    "adoxq %%rdi, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
    "adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adoxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
    \
    /* reduce by r[2] * k */ \
    "movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[2] * r_inv */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus.data[2] * k) */ \
    "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
    "adoxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
    "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
    "adoxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
    "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (modulus.data[3] * k) */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[6], t[7]) <- (modulus.data[0] * k) */ \
    "adcxq %%rdi, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
    "adoxq %%r11, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \
    "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
    "adoxq %%r8, %%r15 \n\t" /* r[2] += t[6] (%r15 now free) */ \
    "adcxq %%r9, %%r10 \n\t" /* r[3] += t[7] + flag_c */ \
    \
    /* a[3] * b */ \
    "movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (a[3] * b[1]) */ \
    "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \
    "adcxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
    "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (a[3] * b[2]) */ \
    "mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \
    "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %%r9, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
    "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \
    "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
    \
    /* reduce by r[3] * k */ \
    "movq %%r10, %%rdx \n\t" /* move r[3] into %rdx */ \
    "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[3] * r_inv */ \
    "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
    "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
    "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] (%r10 now free) */ \
    "adcxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
    "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
    "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
    \
    "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \
    "mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \
    "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
    "adcxq %%r9, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
    "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \
    "adcxq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
    "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */

/**
 * Compute 256-bit multiplication of a and b.
 * The low 256 bits of the result are written to the memory pointed to by r.
 **/
#define MUL_256(a, b, r) \
    "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
    \
    /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
    "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
    "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
    "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
    "mulxq 16(" b "), %%r15, %%rax \n\t" /* (r[2], r[3]) <- a[0] * b[2] */ \
    /* zero flags */ \
    "xorq %%r10, %%r10 \n\t" /* clear r10; we use it whenever we need a zero */ \
    \
    /* start first addition chain */ \
    "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
    "adoxq %%rdi, %%rax \n\t" /* r[3] += t[2] + flag_o */ \
    "adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
    "adcxq %%r10, %%rax \n\t" /* r[3] += flag_c */ \
    \
    /* a[1] * b */ \
    "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
    "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
    "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
    "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[4] + flag_c */ \
    "adoxq %%rsi, %%rax \n\t" /* r[3] += t[5] + flag_o */ \
    \
    "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
    "adcxq %%r8, %%rax \n\t" /* r[3] += t[2] + flag_c */ \
    \
    /* a[2] * b */ \
    "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
    "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[2], t[3]) <- (a[2] * b[1]) */ \
    "adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
    "adoxq %%r9, %%rax \n\t" /* r[3] += t[1] + flag_o */ \
    "adcxq %%rdi, %%rax \n\t" /* r[3] += t[2] + flag_c */ \
    \
    /* a[3] * b */ \
    "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
    "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
    "adcxq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
    "movq %%r13, 0(" r ") \n\t" \
    "movq %%r14, 8(" r ") \n\t" \
    "movq %%r15, 16(" r ") \n\t" \
    "movq %%rax, 24(" r ") \n\t"
#endif

10
sumcheck/src/cuda/includes/barretenberg/ecc/fields/field.hpp
Normal file
@@ -0,0 +1,10 @@
#pragma once

/**
 * @brief The include order of the header-only field class is structured to ensure the linter/language server can
 * resolve paths. Declarations are defined in "field_declarations.hpp"; definitions are in "field_impl.hpp" (which
 * includes the declarations header). Specialized definitions are in "field_impl_generic.hpp" and "field_impl_x64.hpp"
 * (which include "field_impl.hpp").
 */
#include "./field_impl_generic.hpp"
#include "./field_impl_x64.hpp"

@@ -0,0 +1,682 @@
#pragma once
#include "../../common/assert.hpp"
#include "../../common/compiler_hints.hpp"
#include "../../numeric/random/engine.hpp"
#include "../../numeric/uint128/uint128.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include <array>
#include <cstdint>
#include <iostream>
#include <random>
#include <span>

#ifndef DISABLE_ASM
#ifdef __BMI2__
#define BBERG_NO_ASM 0
#else
#define BBERG_NO_ASM 1
#endif
#else
#define BBERG_NO_ASM 1
#endif

namespace bb {
using namespace numeric;
/**
 * @brief General class for prime fields; see \ref field_docs "field documentation" for a general implementation
 * reference.
 *
 * @tparam Params_
 */
template <class Params_> struct alignas(32) field {
  public:
    using View = field;
    using Params = Params_;
    using in_buf = const uint8_t*;
    using vec_in_buf = const uint8_t*;
    using out_buf = uint8_t*;
    using vec_out_buf = uint8_t**;

#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
#define WASM_NUM_LIMBS 9
#define WASM_LIMB_BITS 29
#endif

    // We don't initialize data in the default constructor, since we'd lose a lot of time on huge array
    // initializations. Other alternatives have been noted, such as casting to get around constructors where they
    // matter, but it is felt that sanitizer tools (e.g. MSAN) can detect garbage well, whereas doing hacky casts
    // where needed would require reworking critical algos like MSM, FFT, Sumcheck. Instead, the recommended
    // solution is to use an explicit {} where initialization is important:
    //  field f;                     // not initialized
    //  field f{};                   // zero-initialized
    //  std::array<field, N> arr;    // not initialized, good for huge N
    //  std::array<field, N> arr{};  // zero-initialized, preferable for moderate N
    field() = default;

    constexpr field(const numeric::uint256_t& input) noexcept
        : data{ input.data[0], input.data[1], input.data[2], input.data[3] }
    {
        self_to_montgomery_form();
    }

    // NOLINTNEXTLINE (unsigned long is platform dependent, which we want in this case)
    constexpr field(const unsigned long input) noexcept
        : data{ input, 0, 0, 0 }
    {
        self_to_montgomery_form();
    }

    constexpr field(const unsigned int input) noexcept
        : data{ input, 0, 0, 0 }
    {
        self_to_montgomery_form();
    }

    // NOLINTNEXTLINE (unsigned long long is platform dependent, which we want in this case)
    constexpr field(const unsigned long long input) noexcept
        : data{ input, 0, 0, 0 }
    {
        self_to_montgomery_form();
    }

    constexpr field(const int input) noexcept
        : data{ 0, 0, 0, 0 }
    {
        if (input < 0) {
            data[0] = static_cast<uint64_t>(-input);
            self_to_montgomery_form();
            self_neg();
            self_reduce_once();
        } else {
            data[0] = static_cast<uint64_t>(input);
            self_to_montgomery_form();
        }
    }

    constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
        : data{ a, b, c, d }
    {}
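
    // Editor's note: the integer constructors above store the Montgomery representation
    // a_mont = a * R mod p (R = 2^256 for these 4-limb fields) via self_to_montgomery_form(),
    // so arithmetic composes in Montgomery form (e.g. field(2) * field(3) yields field(6) in
    // Montgomery form); from_montgomery_form() recovers the plain integer.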

    /**
     * @brief Convert a 512-bit big integer into a field element.
     *
     * @details Used for deriving field elements from random values. 512 bits prevents biased output, as
     * 2^512 >> modulus.
     */
    constexpr explicit field(const uint512_t& input) noexcept
    {
        uint256_t value = (input % modulus).lo;
        data[0] = value.data[0];
        data[1] = value.data[1];
        data[2] = value.data[2];
        data[3] = value.data[3];
        self_to_montgomery_form();
    }
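
    // Editor's note: reducing a uniform 512-bit value mod p leaves a statistical distance
    // from uniform on [0, p) of roughly p/2^512 (well under 2^-256 here), which is why random
    // field elements are derived from 512 bits rather than 256.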

    constexpr explicit field(std::string input) noexcept
    {
        uint256_t value(input);
        *this = field(value);
    }

    constexpr explicit operator bool() const
    {
        field out = from_montgomery_form();
        ASSERT(out.data[0] == 0 || out.data[0] == 1);
        return static_cast<bool>(out.data[0]);
    }

    constexpr explicit operator uint8_t() const
    {
        field out = from_montgomery_form();
        return static_cast<uint8_t>(out.data[0]);
    }

    constexpr explicit operator uint16_t() const
    {
        field out = from_montgomery_form();
        return static_cast<uint16_t>(out.data[0]);
    }

    constexpr explicit operator uint32_t() const
    {
        field out = from_montgomery_form();
        return static_cast<uint32_t>(out.data[0]);
    }

    constexpr explicit operator uint64_t() const
    {
        field out = from_montgomery_form();
        return out.data[0];
    }

    constexpr explicit operator uint128_t() const
    {
        field out = from_montgomery_form();
        uint128_t lo = out.data[0];
        uint128_t hi = out.data[1];
        return (hi << 64) | lo;
    }

    constexpr operator uint256_t() const noexcept
    {
        field out = from_montgomery_form();
        return uint256_t(out.data[0], out.data[1], out.data[2], out.data[3]);
    }

    [[nodiscard]] constexpr uint256_t uint256_t_no_montgomery_conversion() const noexcept
    {
        return { data[0], data[1], data[2], data[3] };
    }

    constexpr field(const field& other) noexcept = default;
    constexpr field(field&& other) noexcept = default;
    constexpr field& operator=(const field& other) noexcept = default;
    constexpr field& operator=(field&& other) noexcept = default;
    constexpr ~field() noexcept = default;
    alignas(32) uint64_t data[4]; // NOLINT
    static constexpr uint256_t modulus =
        uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    static constexpr uint256_t r_squared_uint{
        Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3
    };
#else
    static constexpr uint256_t r_squared_uint{
        Params_::r_squared_wasm_0, Params_::r_squared_wasm_1, Params_::r_squared_wasm_2, Params_::r_squared_wasm_3
    };
    static constexpr std::array<uint64_t, 9> wasm_modulus = { Params::modulus_wasm_0, Params::modulus_wasm_1,
                                                              Params::modulus_wasm_2, Params::modulus_wasm_3,
                                                              Params::modulus_wasm_4, Params::modulus_wasm_5,
                                                              Params::modulus_wasm_6, Params::modulus_wasm_7,
                                                              Params::modulus_wasm_8 };
#endif
    static constexpr field cube_root_of_unity()
    {
        // endomorphism, i.e. lambda * [P] = (beta * x, y)
        if constexpr (Params::cube_root_0 != 0) {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
            constexpr field result{
                Params::cube_root_0, Params::cube_root_1, Params::cube_root_2, Params::cube_root_3
            };
#else
            constexpr field result{
                Params::cube_root_wasm_0, Params::cube_root_wasm_1, Params::cube_root_wasm_2, Params::cube_root_wasm_3
            };
#endif
            return result;
        } else {
            constexpr field two_inv = field(2).invert();
            constexpr field numerator = (-field(3)).sqrt() - field(1);
            constexpr field result = two_inv * numerator;
            return result;
        }
    }
    static constexpr field zero() { return field(0, 0, 0, 0); }
    static constexpr field neg_one() { return -field(1); }
    static constexpr field one() { return field(1); }

    static constexpr field external_coset_generator()
    {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
        const field result{
            Params::coset_generators_0[7],
            Params::coset_generators_1[7],
            Params::coset_generators_2[7],
            Params::coset_generators_3[7],
        };
#else
        const field result{
            Params::coset_generators_wasm_0[7],
            Params::coset_generators_wasm_1[7],
            Params::coset_generators_wasm_2[7],
            Params::coset_generators_wasm_3[7],
        };
#endif
        return result;
    }

    static constexpr field tag_coset_generator()
    {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
        const field result{
            Params::coset_generators_0[6],
            Params::coset_generators_1[6],
            Params::coset_generators_2[6],
            Params::coset_generators_3[6],
        };
#else
        const field result{
            Params::coset_generators_wasm_0[6],
            Params::coset_generators_wasm_1[6],
            Params::coset_generators_wasm_2[6],
            Params::coset_generators_wasm_3[6],
        };
#endif
        return result;
    }

    static constexpr field coset_generator(const size_t idx)
    {
        ASSERT(idx < 7);
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
        const field result{
            Params::coset_generators_0[idx],
            Params::coset_generators_1[idx],
            Params::coset_generators_2[idx],
            Params::coset_generators_3[idx],
        };
#else
        const field result{
            Params::coset_generators_wasm_0[idx],
            Params::coset_generators_wasm_1[idx],
            Params::coset_generators_wasm_2[idx],
            Params::coset_generators_wasm_3[idx],
        };
#endif
        return result;
    }
    BB_INLINE constexpr field operator*(const field& other) const noexcept;
    BB_INLINE constexpr field operator+(const field& other) const noexcept;
    BB_INLINE constexpr field operator-(const field& other) const noexcept;
    BB_INLINE constexpr field operator-() const noexcept;
    constexpr field operator/(const field& other) const noexcept;

    // prefix increment (++x)
    BB_INLINE constexpr field operator++() noexcept;
    // postfix increment (x++)
    // NOLINTNEXTLINE
    BB_INLINE constexpr field operator++(int) noexcept;

    BB_INLINE constexpr field& operator*=(const field& other) noexcept;
    BB_INLINE constexpr field& operator+=(const field& other) noexcept;
    BB_INLINE constexpr field& operator-=(const field& other) noexcept;
    constexpr field& operator/=(const field& other) noexcept;

    // NOTE: comparison operators exist so that `field` is compatible with stl methods that require them
    // (e.g. std::sort). Finite fields have no canonical ordering; these should *NEVER* be used in
    // algebraic algorithms.
    BB_INLINE constexpr bool operator>(const field& other) const noexcept;
    BB_INLINE constexpr bool operator<(const field& other) const noexcept;
    BB_INLINE constexpr bool operator==(const field& other) const noexcept;
    BB_INLINE constexpr bool operator!=(const field& other) const noexcept;

    BB_INLINE constexpr field to_montgomery_form() const noexcept;
    BB_INLINE constexpr field from_montgomery_form() const noexcept;

    BB_INLINE constexpr field sqr() const noexcept;
    BB_INLINE constexpr void self_sqr() noexcept;

    BB_INLINE constexpr field pow(const uint256_t& exponent) const noexcept;
    BB_INLINE constexpr field pow(uint64_t exponent) const noexcept;
    static_assert(Params::modulus_0 != 1);
    static constexpr uint256_t modulus_minus_two =
        uint256_t(Params::modulus_0 - 2ULL, Params::modulus_1, Params::modulus_2, Params::modulus_3);
    constexpr field invert() const noexcept;
    static void batch_invert(std::span<field> coeffs) noexcept;
    static void batch_invert(field* coeffs, size_t n) noexcept;
    /**
     * @brief Compute the square root of the field element.
     *
     * @return <true, root> if the element is a quadratic residue, <false, 0> if it is not
     */
    constexpr std::pair<bool, field> sqrt() const noexcept;

    BB_INLINE constexpr void self_neg() noexcept;

    BB_INLINE constexpr void self_to_montgomery_form() noexcept;
    BB_INLINE constexpr void self_from_montgomery_form() noexcept;

    BB_INLINE constexpr void self_conditional_negate(uint64_t predicate) noexcept;

    BB_INLINE constexpr field reduce_once() const noexcept;
    BB_INLINE constexpr void self_reduce_once() noexcept;

    BB_INLINE constexpr void self_set_msb() noexcept;
    [[nodiscard]] BB_INLINE constexpr bool is_msb_set() const noexcept;
    [[nodiscard]] BB_INLINE constexpr uint64_t is_msb_set_word() const noexcept;

    [[nodiscard]] BB_INLINE constexpr bool is_zero() const noexcept;

    static constexpr field get_root_of_unity(size_t subgroup_size) noexcept;

    static void serialize_to_buffer(const field& value, uint8_t* buffer) { write(buffer, value); }

    static field serialize_from_buffer(const uint8_t* buffer) { return from_buffer<field>(buffer); }

    [[nodiscard]] BB_INLINE std::vector<uint8_t> to_buffer() const { return to_buffer(*this); }

    struct wide_array {
        uint64_t data[8]; // NOLINT
    };
    BB_INLINE constexpr wide_array mul_512(const field& other) const noexcept;
    BB_INLINE constexpr wide_array sqr_512() const noexcept;

    BB_INLINE constexpr field conditionally_subtract_from_double_modulus(const uint64_t predicate) const noexcept
    {
        if (predicate != 0) {
            constexpr field p{
                twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3]
            };
            return p - *this;
        }
        return *this;
    }
    /**
     * For short Weierstrass curves y^2 = x^3 + b mod r, if there exists a cube root of unity mod r,
     * we can take advantage of an endomorphism to decompose a 254-bit scalar into two 128-bit scalars.
     * \beta = cube root of 1, mod q (q = order of fq)
     * \lambda = cube root of 1, mod r (r = order of fr)
     *
     * For a point P1 = (X, Y), where Y^2 = X^3 + b, we know that
     * the point P2 = (X * \beta, Y) is also a point on the curve.
     * We can represent P2 as a scalar multiplication of P1, where P2 = \lambda * P1
     *
     * For a generic multiplication of P1 by a 254-bit scalar k, we can decompose k
     * into two 127-bit scalars (k1, k2), such that k = k1 - (k2 * \lambda)
     *
     * We can now represent (k * P1) as (k1 * P1) - (k2 * P2), where P2 = (X * \beta, Y).
     * As k1, k2 have half the bit length of k, we have halved the number of loop iterations
     * of our scalar multiplication algorithm.
     *
     * To find k1, k2, we use the extended Euclidean algorithm to find 4 short scalars [a1, a2], [b1, b2] such that
     * modulus = (a1 * b2) - (b1 * a2).
     * We then compute scalars c1 = round(b2 * k / r), c2 = round(b1 * k / r), where
     * k1 = (c1 * a1) + (c2 * a2), k2 = -((c1 * b1) + (c2 * b2)).
     * We pre-compute scalars g1 = (2^256 * b1) / n, g2 = (2^256 * b2) / n, to avoid having to perform long division
     * on 512-bit scalars.
     **/
    static void split_into_endomorphism_scalars(const field& k, field& k1, field& k2)
    {
        // if the modulus is a >= 255-bit integer, we need to use a basis where g1, g2 have been shifted by 2^384
        if constexpr (Params::modulus_3 >= 0x4000000000000000ULL) {
            split_into_endomorphism_scalars_384(k, k1, k2);
        } else {
            std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> ret = split_into_endomorphism_scalars(k);
            k1.data[0] = ret.first[0];
            k1.data[1] = ret.first[1];

            // TODO(https://github.com/AztecProtocol/barretenberg/issues/851): We should move away from this hack by
            // returning a pair of uint64_t[2] instead of a half-set field
#if !defined(__clang__) && defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
            k2.data[0] = ret.second[0]; // NOLINT
            k2.data[1] = ret.second[1];
#if !defined(__clang__) && defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
        }
    }

    // NOTE: this form is only usable if the modulus is 254 bits or less, otherwise see
    // split_into_endomorphism_scalars_384.
    // TODO(https://github.com/AztecProtocol/barretenberg/issues/851): Unify these APIs.
    static std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> split_into_endomorphism_scalars(const field& k)
    {
        static_assert(Params::modulus_3 < 0x4000000000000000ULL);
        field input = k.reduce_once();

        constexpr field endo_g1 = { Params::endo_g1_lo, Params::endo_g1_mid, Params::endo_g1_hi, 0 };

        constexpr field endo_g2 = { Params::endo_g2_lo, Params::endo_g2_mid, 0, 0 };

        constexpr field endo_minus_b1 = { Params::endo_minus_b1_lo, Params::endo_minus_b1_mid, 0, 0 };

        constexpr field endo_b2 = { Params::endo_b2_lo, Params::endo_b2_mid, 0, 0 };

        // compute c1 = (g2 * k) >> 256
        wide_array c1 = endo_g2.mul_512(input);
        // compute c2 = (g1 * k) >> 256
        wide_array c2 = endo_g1.mul_512(input);

        // (the bit shifts are implicit, as we only utilize the high limbs of c1, c2)

        field c1_hi = {
            c1.data[4], c1.data[5], c1.data[6], c1.data[7]
        }; // *(field*)((uintptr_t)(&c1) + (4 * sizeof(uint64_t)));
        field c2_hi = {
            c2.data[4], c2.data[5], c2.data[6], c2.data[7]
        }; // *(field*)((uintptr_t)(&c2) + (4 * sizeof(uint64_t)));

        // compute q1 = c1 * -b1
        wide_array q1 = c1_hi.mul_512(endo_minus_b1);
        // compute q2 = c2 * b2
        wide_array q2 = c2_hi.mul_512(endo_b2);

        // FIX: Avoid using 512-bit multiplication, as it's not necessary.
        // c1_hi, c2_hi can be uint256_t's, and the final result (without Montgomery reduction)
        // could be cast to a field.
        field q1_lo{ q1.data[0], q1.data[1], q1.data[2], q1.data[3] };
        field q2_lo{ q2.data[0], q2.data[1], q2.data[2], q2.data[3] };

        field t1 = (q2_lo - q1_lo).reduce_once();
        field beta = cube_root_of_unity();
        field t2 = (t1 * beta + input).reduce_once();
        return {
            { t2.data[0], t2.data[1] },
            { t1.data[0], t1.data[1] },
        };
    }
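
    // Editor's note (illustrative only): the decomposition above satisfies
    // k = k1 - k2 * lambda (mod r), with lambda = cube_root_of_unity() in the scalar
    // field, so k * P can be evaluated as k1 * P - k2 * (beta * P.x, P.y) using
    // half-width scalars. Conceptually:
    //   auto [k1_limbs, k2_limbs] = split_into_endomorphism_scalars(k);
    //   // k1_limbs and k2_limbs each hold a ~127-bit value in two 64-bit limbs;
    //   // recombining them as k1 - k2 * lambda mod r recovers k.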

    static void split_into_endomorphism_scalars_384(const field& input, field& k1_out, field& k2_out)
    {
        constexpr field minus_b1f{
            Params::endo_minus_b1_lo,
            Params::endo_minus_b1_mid,
            0,
            0,
        };
        constexpr field b2f{
            Params::endo_b2_lo,
            Params::endo_b2_mid,
            0,
            0,
        };
        constexpr uint256_t g1{
            Params::endo_g1_lo,
            Params::endo_g1_mid,
            Params::endo_g1_hi,
            Params::endo_g1_hihi,
        };
        constexpr uint256_t g2{
            Params::endo_g2_lo,
            Params::endo_g2_mid,
            Params::endo_g2_hi,
            Params::endo_g2_hihi,
        };

        field kf = input.reduce_once();
        uint256_t k{ kf.data[0], kf.data[1], kf.data[2], kf.data[3] };

        uint512_t c1 = (uint512_t(k) * static_cast<uint512_t>(g1)) >> 384;
        uint512_t c2 = (uint512_t(k) * static_cast<uint512_t>(g2)) >> 384;

        field c1f{ c1.lo.data[0], c1.lo.data[1], c1.lo.data[2], c1.lo.data[3] };
        field c2f{ c2.lo.data[0], c2.lo.data[1], c2.lo.data[2], c2.lo.data[3] };

        c1f.self_to_montgomery_form();
        c2f.self_to_montgomery_form();
        c1f = c1f * minus_b1f;
        c2f = c2f * b2f;
        field r2f = c1f - c2f;
        field beta = cube_root_of_unity();
        field r1f = input.reduce_once() - r2f * beta;
        k1_out = r1f;
        k2_out = -r2f;
    }

    // static constexpr auto coset_generators = compute_coset_generators();
    // static constexpr std::array<field, 15> coset_generators = compute_coset_generators((1 << 30U));

    friend std::ostream& operator<<(std::ostream& os, const field& a)
    {
        field out = a.from_montgomery_form();
        std::ios_base::fmtflags f(os.flags());
        os << std::hex << "0x" << std::setfill('0') << std::setw(16) << out.data[3] << std::setw(16) << out.data[2]
           << std::setw(16) << out.data[1] << std::setw(16) << out.data[0];
        os.flags(f);
        return os;
    }

    BB_INLINE static void __copy(const field& a, field& r) noexcept { r = a; } // NOLINT
    BB_INLINE static void __swap(field& src, field& dest) noexcept             // NOLINT
    {
        field T = dest;
        dest = src;
        src = T;
    }

    static field random_element(numeric::RNG* engine = nullptr) noexcept;

    static constexpr field multiplicative_generator() noexcept;

    static constexpr uint256_t twice_modulus = modulus + modulus;
    static constexpr uint256_t not_modulus = -modulus;
    static constexpr uint256_t twice_not_modulus = -twice_modulus;
struct wnaf_table {
|
||||
uint8_t windows[64]; // NOLINT
|
||||
|
||||
constexpr wnaf_table(const uint256_t& target)
|
||||
: windows{
|
||||
static_cast<uint8_t>(target.data[0] & 15), static_cast<uint8_t>((target.data[0] >> 4) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 8) & 15), static_cast<uint8_t>((target.data[0] >> 12) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 16) & 15), static_cast<uint8_t>((target.data[0] >> 20) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 24) & 15), static_cast<uint8_t>((target.data[0] >> 28) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 32) & 15), static_cast<uint8_t>((target.data[0] >> 36) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 40) & 15), static_cast<uint8_t>((target.data[0] >> 44) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 48) & 15), static_cast<uint8_t>((target.data[0] >> 52) & 15),
|
||||
static_cast<uint8_t>((target.data[0] >> 56) & 15), static_cast<uint8_t>((target.data[0] >> 60) & 15),
|
||||
static_cast<uint8_t>(target.data[1] & 15), static_cast<uint8_t>((target.data[1] >> 4) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 8) & 15), static_cast<uint8_t>((target.data[1] >> 12) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 16) & 15), static_cast<uint8_t>((target.data[1] >> 20) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 24) & 15), static_cast<uint8_t>((target.data[1] >> 28) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 32) & 15), static_cast<uint8_t>((target.data[1] >> 36) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 40) & 15), static_cast<uint8_t>((target.data[1] >> 44) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 48) & 15), static_cast<uint8_t>((target.data[1] >> 52) & 15),
|
||||
static_cast<uint8_t>((target.data[1] >> 56) & 15), static_cast<uint8_t>((target.data[1] >> 60) & 15),
|
||||
static_cast<uint8_t>(target.data[2] & 15), static_cast<uint8_t>((target.data[2] >> 4) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 8) & 15), static_cast<uint8_t>((target.data[2] >> 12) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 16) & 15), static_cast<uint8_t>((target.data[2] >> 20) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 24) & 15), static_cast<uint8_t>((target.data[2] >> 28) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 32) & 15), static_cast<uint8_t>((target.data[2] >> 36) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 40) & 15), static_cast<uint8_t>((target.data[2] >> 44) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 48) & 15), static_cast<uint8_t>((target.data[2] >> 52) & 15),
|
||||
static_cast<uint8_t>((target.data[2] >> 56) & 15), static_cast<uint8_t>((target.data[2] >> 60) & 15),
|
||||
static_cast<uint8_t>(target.data[3] & 15), static_cast<uint8_t>((target.data[3] >> 4) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 8) & 15), static_cast<uint8_t>((target.data[3] >> 12) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 16) & 15), static_cast<uint8_t>((target.data[3] >> 20) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 24) & 15), static_cast<uint8_t>((target.data[3] >> 28) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 32) & 15), static_cast<uint8_t>((target.data[3] >> 36) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 40) & 15), static_cast<uint8_t>((target.data[3] >> 44) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 48) & 15), static_cast<uint8_t>((target.data[3] >> 52) & 15),
|
||||
static_cast<uint8_t>((target.data[3] >> 56) & 15), static_cast<uint8_t>((target.data[3] >> 60) & 15)
|
||||
}
|
||||
{}
|
||||
};
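    // Note: `wnaf_table` simply slices a 256-bit value into 64 consecutive 4-bit windows,
    // least-significant window first. For example (illustrative):
    //   wnaf_table w(uint256_t(0xab));
    //   // w.windows[0] == 0xb, w.windows[1] == 0xa, all remaining windows == 0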

#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
    BB_INLINE static constexpr void wasm_madd(uint64_t& left_limb,
                                              const std::array<uint64_t, WASM_NUM_LIMBS>& right_limbs,
                                              uint64_t& result_0,
                                              uint64_t& result_1,
                                              uint64_t& result_2,
                                              uint64_t& result_3,
                                              uint64_t& result_4,
                                              uint64_t& result_5,
                                              uint64_t& result_6,
                                              uint64_t& result_7,
                                              uint64_t& result_8);
    BB_INLINE static constexpr void wasm_reduce(uint64_t& result_0,
                                                uint64_t& result_1,
                                                uint64_t& result_2,
                                                uint64_t& result_3,
                                                uint64_t& result_4,
                                                uint64_t& result_5,
                                                uint64_t& result_6,
                                                uint64_t& result_7,
                                                uint64_t& result_8);
    BB_INLINE static constexpr std::array<uint64_t, WASM_NUM_LIMBS> wasm_convert(const uint64_t* data);
#endif
    BB_INLINE static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b) noexcept;

    BB_INLINE static constexpr uint64_t mac(
        uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& carry_out) noexcept;

    BB_INLINE static constexpr void mac(
        uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& out, uint64_t& carry_out) noexcept;

    BB_INLINE static constexpr uint64_t mac_mini(uint64_t a, uint64_t b, uint64_t c, uint64_t& out) noexcept;

    BB_INLINE static constexpr void mac_mini(
        uint64_t a, uint64_t b, uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept;

    BB_INLINE static constexpr uint64_t mac_discard_lo(uint64_t a, uint64_t b, uint64_t c) noexcept;

    BB_INLINE static constexpr uint64_t addc(uint64_t a, uint64_t b, uint64_t carry_in, uint64_t& carry_out) noexcept;

    BB_INLINE static constexpr uint64_t sbb(uint64_t a, uint64_t b, uint64_t borrow_in, uint64_t& borrow_out) noexcept;

    BB_INLINE static constexpr uint64_t square_accumulate(uint64_t a,
                                                          uint64_t b,
                                                          uint64_t c,
                                                          uint64_t carry_in_lo,
                                                          uint64_t carry_in_hi,
                                                          uint64_t& carry_lo,
                                                          uint64_t& carry_hi) noexcept;
    BB_INLINE constexpr field reduce() const noexcept;
    BB_INLINE constexpr field add(const field& other) const noexcept;
    BB_INLINE constexpr field subtract(const field& other) const noexcept;
    BB_INLINE constexpr field subtract_coarse(const field& other) const noexcept;
    BB_INLINE constexpr field montgomery_mul(const field& other) const noexcept;
    BB_INLINE constexpr field montgomery_mul_big(const field& other) const noexcept;
    BB_INLINE constexpr field montgomery_square() const noexcept;

#if (BBERG_NO_ASM == 0)
    BB_INLINE static field asm_mul(const field& a, const field& b) noexcept;
    BB_INLINE static field asm_sqr(const field& a) noexcept;
    BB_INLINE static field asm_add(const field& a, const field& b) noexcept;
    BB_INLINE static field asm_sub(const field& a, const field& b) noexcept;
    BB_INLINE static field asm_mul_with_coarse_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static field asm_sqr_with_coarse_reduction(const field& a) noexcept;
    BB_INLINE static field asm_add_with_coarse_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static field asm_sub_with_coarse_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static field asm_add_without_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static void asm_self_sqr(const field& a) noexcept;
    BB_INLINE static void asm_self_add(const field& a, const field& b) noexcept;
    BB_INLINE static void asm_self_sub(const field& a, const field& b) noexcept;
    BB_INLINE static void asm_self_mul_with_coarse_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static void asm_self_sqr_with_coarse_reduction(const field& a) noexcept;
    BB_INLINE static void asm_self_add_with_coarse_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static void asm_self_sub_with_coarse_reduction(const field& a, const field& b) noexcept;
    BB_INLINE static void asm_self_add_without_reduction(const field& a, const field& b) noexcept;

    BB_INLINE static void asm_conditional_negate(field& r, uint64_t predicate) noexcept;
    BB_INLINE static field asm_reduce_once(const field& a) noexcept;
    BB_INLINE static void asm_self_reduce_once(const field& a) noexcept;
    static constexpr uint64_t zero_reference = 0x00ULL;
#endif
    static constexpr size_t COSET_GENERATOR_SIZE = 15;
    constexpr field tonelli_shanks_sqrt() const noexcept;
    static constexpr size_t primitive_root_log_size() noexcept;
    static constexpr std::array<field, COSET_GENERATOR_SIZE> compute_coset_generators() noexcept;

#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    static constexpr uint128_t lo_mask = 0xffffffffffffffffUL;
#endif
};
} // namespace bb
@@ -0,0 +1,673 @@
#pragma once
#include "../../common/op_count.hpp"
#include "../../common/slab_allocator.hpp"
#include "../../common/throw_or_abort.hpp"
#include "../../numeric/bitop/get_msb.hpp"
#include "../../numeric/random/engine.hpp"
#include <memory>
#include <span>
#include <type_traits>
#include <vector>

#include "./field_declarations.hpp"

namespace bb {
using namespace numeric;
// clang-format off
// disable the following style guides:
// cppcoreguidelines-avoid-c-arrays : we make heavy use of c-style arrays here to prevent default-initialization of memory when constructing `field` objects.
// The intention is for field to act like a primitive numeric type with the performance/complexity trade-offs expected from this.
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
// clang-format on
/**
 *
 * Multiplication
 *
 **/
template <class T> constexpr field<T> field<T>::operator*(const field& other) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::mul");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        // >= 255-bits or <= 64-bits.
        return montgomery_mul(other);
    } else {
        if (std::is_constant_evaluated()) {
            return montgomery_mul(other);
        }
        return asm_mul_with_coarse_reduction(*this, other);
    }
}
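
// The same dispatch pattern recurs throughout this file: the `if constexpr` branch takes the
// generic Montgomery path when the modulus is too large or too small for the asm kernels (or
// when BBERG_NO_ASM is set), and `std::is_constant_evaluated()` falls back to the
// constexpr-safe path during compile-time evaluation. Illustrative usage, assuming an `fr`
// instantiation of this template:
//   constexpr fr compile_time_product = fr(2) * fr(3); // routed through montgomery_mul
//   fr runtime_product = a * b;                        // routed through asm_mul_with_coarse_reduction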

template <class T> constexpr field<T>& field<T>::operator*=(const field& other) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_mul");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        // >= 255-bits or <= 64-bits.
        *this = operator*(other);
    } else {
        if (std::is_constant_evaluated()) {
            *this = operator*(other);
        } else {
            asm_self_mul_with_coarse_reduction(*this, other); // asm_self_mul(*this, other);
        }
    }
    return *this;
}

/**
 *
 * Squaring
 *
 **/
template <class T> constexpr field<T> field<T>::sqr() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::sqr");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        return montgomery_square();
    } else {
        if (std::is_constant_evaluated()) {
            return montgomery_square();
        }
        return asm_sqr_with_coarse_reduction(*this); // asm_sqr(*this);
    }
}

template <class T> constexpr void field<T>::self_sqr() noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_sqr");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        *this = montgomery_square();
    } else {
        if (std::is_constant_evaluated()) {
            *this = montgomery_square();
        } else {
            asm_self_sqr_with_coarse_reduction(*this);
        }
    }
}

/**
 *
 * Addition
 *
 **/
template <class T> constexpr field<T> field<T>::operator+(const field& other) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::add");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        return add(other);
    } else {
        if (std::is_constant_evaluated()) {
            return add(other);
        }
        return asm_add_with_coarse_reduction(*this, other); // asm_add_without_reduction(*this, other);
    }
}

template <class T> constexpr field<T>& field<T>::operator+=(const field& other) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_add");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        (*this) = operator+(other);
    } else {
        if (std::is_constant_evaluated()) {
            (*this) = operator+(other);
        } else {
            asm_self_add_with_coarse_reduction(*this, other); // asm_self_add(*this, other);
        }
    }
    return *this;
}

template <class T> constexpr field<T> field<T>::operator++() noexcept
{
    BB_OP_COUNT_TRACK_NAME("++f");
    return *this += 1;
}

// NOLINTNEXTLINE(cert-dcl21-cpp) circular linting errors. If const is added, linter suggests removing
template <class T> constexpr field<T> field<T>::operator++(int) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::increment");
    field<T> value_before_incrementing = *this;
    *this += 1;
    return value_before_incrementing;
}

/**
 *
 * Subtraction
 *
 **/
template <class T> constexpr field<T> field<T>::operator-(const field& other) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::sub");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        return subtract_coarse(other); // modulus - *this;
    } else {
        if (std::is_constant_evaluated()) {
            return subtract_coarse(other); // subtract(other);
        }
        return asm_sub_with_coarse_reduction(*this, other); // asm_sub(*this, other);
    }
}

template <class T> constexpr field<T> field<T>::operator-() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("-f");
    if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
        return p - *this; // modulus - *this;
    }

    // TODO(@zac-williamson): there are 3 ways we can make this more efficient
    // 1: we subtract `p` from `*this` instead of `2p`
    // 2: instead of `p - *this`, we use an asm block that does `p - *this` without the assembly reduction step
    // 3: we replace `(p - *this).reduce_once()` with an assembly block that is equivalent to `p - *this`,
    //    but we call `REDUCE_FIELD_ELEMENT` with `not_twice_modulus` instead of `twice_modulus`
    // not sure which is faster and whether any of the above might break something!
    //
    // More context below:
    // the operator-(a, b) method's asm implementation has a sneaky way to check underflow.
    // if `a - b` underflows we need to add in `2p`. Instead of conditional branching which would cause pipeline
    // flushes, we add `2p` into the result of `a - b`. If the result triggers the overflow flag, then we know we are
    // correcting an *underflow* produced from computing `a - b`. Finally...we use the overflow flag to conditionally
    // move data into registers such that we end up with either `a - b` or `2p + (a - b)` (this is branchless). OK! So
    // what's the problem? Well we assume that every field element lies between 0 and 2p - 1. But we are computing `2p -
    // *this`! If *this = 0 then we exceed this bound hence the need for the extra reduction step. HOWEVER, we also know
    // that 2p - *this won't underflow so we could skip the underflow check present in the assembly code
    constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] };
    return (p - *this).reduce_once(); // modulus - *this;
}

template <class T> constexpr field<T>& field<T>::operator-=(const field& other) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_sub");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        *this = subtract_coarse(other); // subtract(other);
    } else {
        if (std::is_constant_evaluated()) {
            *this = subtract_coarse(other); // subtract(other);
        } else {
            asm_self_sub_with_coarse_reduction(*this, other); // asm_self_sub(*this, other);
        }
    }
    return *this;
}

template <class T> constexpr void field<T>::self_neg() noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_neg");
    if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
        *this = p - *this;
    } else {
        constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] };
        *this = (p - *this).reduce_once();
    }
}

template <class T> constexpr void field<T>::self_conditional_negate(const uint64_t predicate) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_conditional_negate");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        *this = predicate ? -(*this) : *this; // NOLINT
    } else {
        if (std::is_constant_evaluated()) {
            *this = predicate ? -(*this) : *this; // NOLINT
        } else {
            asm_conditional_negate(*this, predicate);
        }
    }
}

/**
 * @brief Greater-than operator
 * @details comparison operators exist so that `field` is comparable with stl methods that require them.
 * (e.g. std::sort)
 * Finite fields do not have an explicit ordering; these should *NEVER* be used in algebraic algorithms.
 *
 * @tparam T
 * @param other
 * @return true
 * @return false
 */
template <class T> constexpr bool field<T>::operator>(const field& other) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::gt");
    const field left = reduce_once();
    const field right = other.reduce_once();
    const bool t0 = left.data[3] > right.data[3];
    const bool t1 = (left.data[3] == right.data[3]) && (left.data[2] > right.data[2]);
    const bool t2 =
        (left.data[3] == right.data[3]) && (left.data[2] == right.data[2]) && (left.data[1] > right.data[1]);
    const bool t3 = (left.data[3] == right.data[3]) && (left.data[2] == right.data[2]) &&
                    (left.data[1] == right.data[1]) && (left.data[0] > right.data[0]);
    return (t0 || t1 || t2 || t3);
}

/**
 * @brief Less-than operator
 * @details comparison operators exist so that `field` is comparable with stl methods that require them.
 * (e.g. std::sort)
 * Finite fields do not have an explicit ordering; these should *NEVER* be used in algebraic algorithms.
 *
 * @tparam T
 * @param other
 * @return true
 * @return false
 */
template <class T> constexpr bool field<T>::operator<(const field& other) const noexcept
{
    return (other > *this);
}

template <class T> constexpr bool field<T>::operator==(const field& other) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::eqeq");
    const field left = reduce_once();
    const field right = other.reduce_once();
    return (left.data[0] == right.data[0]) && (left.data[1] == right.data[1]) && (left.data[2] == right.data[2]) &&
           (left.data[3] == right.data[3]);
}

template <class T> constexpr bool field<T>::operator!=(const field& other) const noexcept
{
    return (!operator==(other));
}

template <class T> constexpr field<T> field<T>::to_montgomery_form() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::to_montgomery_form");
    constexpr field r_squared =
        field{ r_squared_uint.data[0], r_squared_uint.data[1], r_squared_uint.data[2], r_squared_uint.data[3] };

    field result = *this;
    // TODO(@zac-williamson): are these reductions needed?
    // Rationale: We want to take any 256-bit input and be able to convert into montgomery form.
    // A basic heuristic we use is that any input into the `*` operator must be between [0, 2p - 1]
    // to prevent overflows in the asm algorithm.
    // However... r_squared is already reduced so perhaps we can relax this requirement?
    // (would be good to identify a failure case where not calling self_reduce triggers an error)
    result.self_reduce_once();
    result.self_reduce_once();
    result.self_reduce_once();
    return (result * r_squared).reduce_once();
}

template <class T> constexpr field<T> field<T>::from_montgomery_form() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::from_montgomery_form");
    constexpr field one_raw{ 1, 0, 0, 0 };
    return operator*(one_raw).reduce_once();
}

template <class T> constexpr void field<T>::self_to_montgomery_form() noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_to_montgomery_form");
    constexpr field r_squared =
        field{ r_squared_uint.data[0], r_squared_uint.data[1], r_squared_uint.data[2], r_squared_uint.data[3] };

    self_reduce_once();
    self_reduce_once();
    self_reduce_once();
    *this *= r_squared;
    self_reduce_once();
}

template <class T> constexpr void field<T>::self_from_montgomery_form() noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_from_montgomery_form");
    constexpr field one_raw{ 1, 0, 0, 0 };
    *this *= one_raw;
    self_reduce_once();
}

template <class T> constexpr field<T> field<T>::reduce_once() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::reduce_once");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        return reduce();
    } else {
        if (std::is_constant_evaluated()) {
            return reduce();
        }
        return asm_reduce_once(*this);
    }
}

template <class T> constexpr void field<T>::self_reduce_once() noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_reduce_once");
    if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
                  (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
        *this = reduce();
    } else {
        if (std::is_constant_evaluated()) {
            *this = reduce();
        } else {
            asm_self_reduce_once(*this);
        }
    }
}

template <class T> constexpr field<T> field<T>::pow(const uint256_t& exponent) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::pow");
    field accumulator{ data[0], data[1], data[2], data[3] };
    field to_mul{ data[0], data[1], data[2], data[3] };
    const uint64_t maximum_set_bit = exponent.get_msb();

    for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
        accumulator.self_sqr();
        if (exponent.get_bit(static_cast<uint64_t>(i))) {
            accumulator *= to_mul;
        }
    }
    if (exponent == uint256_t(0)) {
        accumulator = one();
    } else if (*this == zero()) {
        accumulator = zero();
    }
    return accumulator;
}
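
// pow() is plain left-to-right square-and-multiply over the bits of the exponent.
// Worked example (illustrative): for exponent 13 = 0b1101, the accumulator starts at a
// (the top set bit), and the remaining bits 1, 0, 1 give:
//   bit 1: acc = a^2 * a = a^3
//   bit 0: acc = a^6
//   bit 1: acc = a^12 * a = a^13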

template <class T> constexpr field<T> field<T>::pow(const uint64_t exponent) const noexcept
{
    return pow({ exponent, 0, 0, 0 });
}

template <class T> constexpr field<T> field<T>::invert() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::invert");
    if (*this == zero()) {
        throw_or_abort("Trying to invert zero in the field");
    }
    return pow(modulus_minus_two);
}

template <class T> void field<T>::batch_invert(field* coeffs, const size_t n) noexcept
{
    batch_invert(std::span{ coeffs, n });
}

template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::batch_invert");
    const size_t n = coeffs.size();

    auto temporaries_ptr = std::static_pointer_cast<field[]>(get_mem_slab(n * sizeof(field)));
    auto skipped_ptr = std::static_pointer_cast<bool[]>(get_mem_slab(n));
    auto temporaries = temporaries_ptr.get();
    auto* skipped = skipped_ptr.get();

    field accumulator = one();
    for (size_t i = 0; i < n; ++i) {
        temporaries[i] = accumulator;
        if (coeffs[i].is_zero()) {
            skipped[i] = true;
        } else {
            skipped[i] = false;
            accumulator *= coeffs[i];
        }
    }

    // std::vector<field> temporaries;
    // std::vector<bool> skipped;
    // temporaries.reserve(n);
    // skipped.reserve(n);

    // field accumulator = one();
    // for (size_t i = 0; i < n; ++i) {
    //     temporaries.emplace_back(accumulator);
    //     if (coeffs[i].is_zero()) {
    //         skipped.emplace_back(true);
    //     } else {
    //         skipped.emplace_back(false);
    //         accumulator *= coeffs[i];
    //     }
    // }

    accumulator = accumulator.invert();

    field T0;
    for (size_t i = n - 1; i < n; --i) {
        if (!skipped[i]) {
            T0 = accumulator * temporaries[i];
            accumulator *= coeffs[i];
            coeffs[i] = T0;
        }
    }
}
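
// This is Montgomery's batch-inversion trick: one field inversion plus roughly 3n
// multiplications replaces n inversions. Illustrative trace for inputs (a, b, c), none zero:
//   forward pass:  temporaries = (1, a, ab); accumulator = abc
//   single invert: accumulator = (abc)^-1
//   backward pass: c^-1 = (abc)^-1 * ab, then accumulator becomes (ab)^-1, and so on.
// Zero entries are flagged as skipped so they pass through unchanged instead of aborting in
// invert(). The backward loop's `i < n` condition relies on unsigned wraparound to terminate.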

template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt");
    // The Tonelli-Shanks algorithm begins by finding integers Q and s,
    // such that (p - 1) = Q.2^{s}

    // We can compute the square root of a by considering a^{(Q + 1) / 2} = R
    // Once we have found such an R, we have
    // R^{2} = a^{Q + 1} = a^{Q}a
    // If a^{Q} = 1, we have found our square root.
    // Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity.
    // This is because t^{2^{s-1}} = a^{Q.2^{s-1}}.
    // We know that (p - 1) = Q.2^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2}
    // From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1
    // i.e. t^{2^{s-1}} = 1

    // To proceed with computing our square root, we want to transform t into a smaller subgroup,
    // specifically, the 2^{s-2}'th roots of unity.
    // We do this by finding some value b, such that
    // (t.b^2)^{2^{s-2}} = 1 and R' = R.b
    // Finding such a b is trivial, because from Euler's criterion, we know that,
    // for any quadratic non-residue z, z^{(p - 1) / 2} = -1
    // i.e. z^{Q.2^{s-1}} = -1
    // => z^Q is a 2^{s-1}'th root of -1
    // => z^{2Q} is a 2^{s-2}'th root of -1
    // Since t^{2^{s-1}} = 1, we know that t^{2^{s - 2}} = -1
    // => t.z^{2Q} is a 2^{s - 2}'th root of unity.

    // We can iteratively transform t into ever smaller subgroups, until t = 1.
    // At each iteration, we need to find a new value for b, which we can obtain
    // by repeatedly squaring z^{Q}
    constexpr uint256_t Q = (modulus - 1) >> static_cast<uint64_t>(primitive_root_log_size() - 1);
    constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 2;

    // __to_montgomery_form(Q_minus_one_over_two, Q_minus_one_over_two);
    field z = coset_generator(0); // the generator is a non-residue
    field b = pow(Q_minus_one_over_two);
    field r = operator*(b); // r = a^{(Q + 1) / 2}
    field t = r * b;        // t = a^{(Q - 1) / 2 + (Q + 1) / 2} = a^{Q}

    // check if t is a square with euler's criterion
    // if not, we don't have a quadratic residue and a has no square root!
    field check = t;
    for (size_t i = 0; i < primitive_root_log_size() - 1; ++i) {
        check.self_sqr();
    }
    if (check != one()) {
        return zero();
    }
    field t1 = z.pow(Q_minus_one_over_two);
    field t2 = t1 * z;
    field c = t2 * t1; // z^Q

    size_t m = primitive_root_log_size();

    while (t != one()) {
        size_t i = 0;
        field t2m = t;

        // find the smallest value of i, such that t^{2^i} = 1
        while (t2m != one()) {
            t2m.self_sqr();
            i += 1;
        }

        size_t j = m - i - 1;
        b = c;
        while (j > 0) {
            b.self_sqr();
            --j;
        } // b = z^2^(m-i-1)

        c = b.sqr();
        t = t * c;
        r = r * b;
        m = i;
    }
    return r;
}

template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::sqrt");
    field root;
    if constexpr ((T::modulus_0 & 0x3UL) == 0x3UL) {
        constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2;
        root = pow(sqrt_exponent);
    } else {
        root = tonelli_shanks_sqrt();
    }
    if ((root * root) == (*this)) {
        return std::pair<bool, field>(true, root);
    }
    return std::pair<bool, field>(false, field::zero());
}
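
// Illustrative usage: sqrt() reports success explicitly, since roughly half of all non-zero
// field elements have no square root:
//   auto [is_quadratic_residue, root] = x.sqrt();
//   if (is_quadratic_residue) { /* root * root == x */ }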

template <class T> constexpr field<T> field<T>::operator/(const field& other) const noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::div");
    return operator*(other.invert());
}

template <class T> constexpr field<T>& field<T>::operator/=(const field& other) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::self_div");
    *this = operator/(other);
    return *this;
}

template <class T> constexpr void field<T>::self_set_msb() noexcept
{
    data[3] = 0ULL | (1ULL << 63ULL);
}

template <class T> constexpr bool field<T>::is_msb_set() const noexcept
{
    return (data[3] >> 63ULL) == 1ULL;
}

template <class T> constexpr uint64_t field<T>::is_msb_set_word() const noexcept
{
    return (data[3] >> 63ULL);
}

template <class T> constexpr bool field<T>::is_zero() const noexcept
{
    return ((data[0] | data[1] | data[2] | data[3]) == 0) ||
           (data[0] == T::modulus_0 && data[1] == T::modulus_1 && data[2] == T::modulus_2 && data[3] == T::modulus_3);
}

template <class T> constexpr field<T> field<T>::get_root_of_unity(size_t subgroup_size) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    field r{ T::primitive_root_0, T::primitive_root_1, T::primitive_root_2, T::primitive_root_3 };
#else
    field r{ T::primitive_root_wasm_0, T::primitive_root_wasm_1, T::primitive_root_wasm_2, T::primitive_root_wasm_3 };
#endif
    for (size_t i = primitive_root_log_size(); i > subgroup_size; --i) {
        r.self_sqr();
    }
    return r;
}

template <class T> field<T> field<T>::random_element(numeric::RNG* engine) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::random_element");
    if (engine == nullptr) {
        engine = &numeric::get_randomness();
    }

    uint512_t source = engine->get_random_uint512();
    uint512_t q(modulus);
    uint512_t reduced = source % q;
    return field(reduced.lo);
}

template <class T> constexpr size_t field<T>::primitive_root_log_size() noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::primitive_root_log_size");
    uint256_t target = modulus - 1;
    size_t result = 0;
    while (!target.get_bit(result)) {
        ++result;
    }
    return result;
}
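
// i.e. primitive_root_log_size() returns the largest s such that 2^s divides (p - 1) -- the
// number of trailing zero bits of p - 1 -- which is the log-size of the largest power-of-two
// multiplicative subgroup of the field.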

template <class T>
constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute_coset_generators() noexcept
{
    constexpr size_t n = COSET_GENERATOR_SIZE;
    constexpr uint64_t subgroup_size = 1 << 30;

    std::array<field, COSET_GENERATOR_SIZE> result{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
    if (n > 0) {
        result[0] = (multiplicative_generator());
    }
    field work_variable = multiplicative_generator() + field(1);

    size_t count = 1;
    while (count < n) {
        // work_variable contains a new field element, and we need to test that, for all previous vector elements,
        // result[i] / work_variable is not a member of our subgroup
        field work_inverse = work_variable.invert();
        bool valid = true;
        for (size_t j = 0; j < count; ++j) {
            field subgroup_check = (work_inverse * result[j]).pow(subgroup_size);
            if (subgroup_check == field(1)) {
                valid = false;
                break;
            }
        }
        if (valid) {
            result[count] = (work_variable);
            ++count;
        }
        work_variable += field(1);
    }
    return result;
}

template <class T> constexpr field<T> field<T>::multiplicative_generator() noexcept
{
    field target(1);
    uint256_t p_minus_one_over_two = (modulus - 1) >> 1;
    bool found = false;
    while (!found) {
        target += field(1);
        found = (target.pow(p_minus_one_over_two) == -field(1));
    }
    return target;
}
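
// multiplicative_generator() linearly searches for the smallest element g > 1 with
// g^((p - 1) / 2) == -1; by Euler's criterion such a g is a quadratic non-residue, which is
// what the coset generation above requires. In practice the search terminates after a handful
// of iterations.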

} // namespace bb

// clang-format off
// NOLINTEND(cppcoreguidelines-avoid-c-arrays)
// clang-format on
@@ -0,0 +1,945 @@
#pragma once

#include <array>
#include <cstdint>

#include "./field_impl.hpp"
#include "../../common/op_count.hpp"

namespace bb {
using namespace numeric;
// NOLINTBEGIN(readability-implicit-bool-conversion)
template <class T> constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(uint64_t a, uint64_t b) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t res = (static_cast<uint128_t>(a) * static_cast<uint128_t>(b));
    return { static_cast<uint64_t>(res), static_cast<uint64_t>(res >> 64) };
#else
    const uint64_t product = a * b;
    return { product & 0xffffffffULL, product >> 32 };
#endif
}

template <class T>
constexpr uint64_t field<T>::mac(
    const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t carry_in, uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c)) +
                          static_cast<uint128_t>(carry_in);
    carry_out = static_cast<uint64_t>(res >> 64);
    return static_cast<uint64_t>(res);
#else
    const uint64_t product = b * c + a + carry_in;
    carry_out = product >> 32;
    return product & 0xffffffffULL;
#endif
}

template <class T>
constexpr void field<T>::mac(const uint64_t a,
                             const uint64_t b,
                             const uint64_t c,
                             const uint64_t carry_in,
                             uint64_t& out,
                             uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c)) +
                          static_cast<uint128_t>(carry_in);
    out = static_cast<uint64_t>(res);
    carry_out = static_cast<uint64_t>(res >> 64);
#else
    const uint64_t product = b * c + a + carry_in;
    carry_out = product >> 32;
    out = product & 0xffffffffULL;
#endif
}

template <class T>
constexpr uint64_t field<T>::mac_mini(const uint64_t a,
                                      const uint64_t b,
                                      const uint64_t c,
                                      uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
    carry_out = static_cast<uint64_t>(res >> 64);
    return static_cast<uint64_t>(res);
#else
    const uint64_t product = b * c + a;
    carry_out = product >> 32;
    return product & 0xffffffffULL;
#endif
}

template <class T>
constexpr void field<T>::mac_mini(
    const uint64_t a, const uint64_t b, const uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
    out = static_cast<uint64_t>(res);
    carry_out = static_cast<uint64_t>(res >> 64);
#else
    const uint64_t result = b * c + a;
    carry_out = result >> 32;
    out = result & 0xffffffffULL;
#endif
}

template <class T>
constexpr uint64_t field<T>::mac_discard_lo(const uint64_t a, const uint64_t b, const uint64_t c) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
    return static_cast<uint64_t>(res >> 64);
#else
    return (b * c + a) >> 32;
#endif
}

template <class T>
constexpr uint64_t field<T>::addc(const uint64_t a,
                                  const uint64_t b,
                                  const uint64_t carry_in,
                                  uint64_t& carry_out) noexcept
{
    BB_OP_COUNT_TRACK();
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    uint128_t res = static_cast<uint128_t>(a) + static_cast<uint128_t>(b) + static_cast<uint128_t>(carry_in);
    carry_out = static_cast<uint64_t>(res >> 64);
    return static_cast<uint64_t>(res);
#else
    uint64_t r = a + b;
    const uint64_t carry_temp = r < a;
    r += carry_in;
    carry_out = carry_temp + (r < carry_in);
    return r;
#endif
}
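
// addc computes a full 64-bit add-with-carry on both code paths: it returns
// (a + b + carry_in) mod 2^64 and writes the overflow bit into carry_out. Illustrative:
//   uint64_t carry = 0;
//   uint64_t lo = addc(0xffffffffffffffffULL, 1, 0, carry); // lo == 0, carry == 1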

template <class T>
constexpr uint64_t field<T>::sbb(const uint64_t a,
                                 const uint64_t b,
                                 const uint64_t borrow_in,
                                 uint64_t& borrow_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    uint128_t res = static_cast<uint128_t>(a) - (static_cast<uint128_t>(b) + static_cast<uint128_t>(borrow_in >> 63));
    borrow_out = static_cast<uint64_t>(res >> 64);
    return static_cast<uint64_t>(res);
#else
    uint64_t t_1 = a - (borrow_in >> 63ULL);
    uint64_t borrow_temp_1 = t_1 > a;
    uint64_t t_2 = t_1 - b;
    uint64_t borrow_temp_2 = t_2 > t_1;
    borrow_out = 0ULL - (borrow_temp_1 | borrow_temp_2);
    return t_2;
#endif
}

template <class T>
constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
                                               const uint64_t b,
                                               const uint64_t c,
                                               const uint64_t carry_in_lo,
                                               const uint64_t carry_in_hi,
                                               uint64_t& carry_lo,
                                               uint64_t& carry_hi) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const uint128_t product = static_cast<uint128_t>(b) * static_cast<uint128_t>(c);
    const auto r0 = static_cast<uint64_t>(product);
    const auto r1 = static_cast<uint64_t>(product >> 64);
    uint64_t out = r0 + r0;
    carry_lo = (out < r0);
    out += a;
    carry_lo += (out < a);
    out += carry_in_lo;
    carry_lo += (out < carry_in_lo);
    carry_lo += r1;
    carry_hi = (carry_lo < r1);
    carry_lo += r1;
    carry_hi += (carry_lo < r1);
    carry_lo += carry_in_hi;
    carry_hi += (carry_lo < carry_in_hi);
    return out;
#else
    const auto product = b * c;
    const auto t0 = product + a + carry_in_lo;
    const auto t1 = product + t0;
    carry_hi = t1 < product;
    const auto t2 = t1 + (carry_in_hi << 32);
    carry_hi += t2 < t1;
    carry_lo = t2 >> 32;
    return t2 & 0xffffffffULL;
#endif
}

template <class T> constexpr field<T> field<T>::reduce() const noexcept
{
    if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
        uint256_t val{ data[0], data[1], data[2], data[3] };
        if (val >= modulus) {
            val -= modulus;
        }
        return { val.data[0], val.data[1], val.data[2], val.data[3] };
    }
    uint64_t t0 = data[0] + not_modulus.data[0];
    uint64_t c = t0 < data[0];
    auto t1 = addc(data[1], not_modulus.data[1], c, c);
    auto t2 = addc(data[2], not_modulus.data[2], c, c);
    auto t3 = addc(data[3], not_modulus.data[3], c, c);
    const uint64_t selection_mask = 0ULL - c; // 0xffff... if we have overflowed.
    const uint64_t selection_mask_inverse = ~selection_mask;
    // if we overflow, we want to swap
    return {
        (data[0] & selection_mask_inverse) | (t0 & selection_mask),
        (data[1] & selection_mask_inverse) | (t1 & selection_mask),
        (data[2] & selection_mask_inverse) | (t2 & selection_mask),
        (data[3] & selection_mask_inverse) | (t3 & selection_mask),
    };
}
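
// The non-big-modulus branch above is a branchless conditional subtraction: adding
// not_modulus (i.e. 2^256 - p) to x overflows 2^256 exactly when x >= p, and in that case the
// truncated sum equals x - p, so the final carry selects between x and x - p without a
// data-dependent branch.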

template <class T> constexpr field<T> field<T>::add(const field& other) const noexcept
{
    if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
        uint64_t r0 = data[0] + other.data[0];
        uint64_t c = r0 < data[0];
        auto r1 = addc(data[1], other.data[1], c, c);
        auto r2 = addc(data[2], other.data[2], c, c);
        auto r3 = addc(data[3], other.data[3], c, c);
        if (c) {
            uint64_t b = 0;
            r0 = sbb(r0, modulus.data[0], b, b);
            r1 = sbb(r1, modulus.data[1], b, b);
            r2 = sbb(r2, modulus.data[2], b, b);
            r3 = sbb(r3, modulus.data[3], b, b);
            // Since both values are in [0, 2**256), the result is in [0, 2**257-2]. Subtracting one p might not be
            // enough. We need to ensure that we've underflowed past 0, and that might require subtracting an
            // additional p
            if (!b) {
                b = 0;
                r0 = sbb(r0, modulus.data[0], b, b);
                r1 = sbb(r1, modulus.data[1], b, b);
                r2 = sbb(r2, modulus.data[2], b, b);
                r3 = sbb(r3, modulus.data[3], b, b);
            }
        }
        return { r0, r1, r2, r3 };
    } else {
        uint64_t r0 = data[0] + other.data[0];
        uint64_t c = r0 < data[0];
        auto r1 = addc(data[1], other.data[1], c, c);
        auto r2 = addc(data[2], other.data[2], c, c);
        uint64_t r3 = data[3] + other.data[3] + c;

        uint64_t t0 = r0 + twice_not_modulus.data[0];
        c = t0 < twice_not_modulus.data[0];
        uint64_t t1 = addc(r1, twice_not_modulus.data[1], c, c);
        uint64_t t2 = addc(r2, twice_not_modulus.data[2], c, c);
        uint64_t t3 = addc(r3, twice_not_modulus.data[3], c, c);
        const uint64_t selection_mask = 0ULL - c;
        const uint64_t selection_mask_inverse = ~selection_mask;

        return {
            (r0 & selection_mask_inverse) | (t0 & selection_mask),
            (r1 & selection_mask_inverse) | (t1 & selection_mask),
            (r2 & selection_mask_inverse) | (t2 & selection_mask),
            (r3 & selection_mask_inverse) | (t3 & selection_mask),
        };
    }
}

template <class T> constexpr field<T> field<T>::subtract(const field& other) const noexcept
{
    uint64_t borrow = 0;
    uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
    uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow);
    uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow);
    uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow);

    r0 += (modulus.data[0] & borrow);
    uint64_t carry = r0 < (modulus.data[0] & borrow);
    r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
    r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
    r3 = addc(r3, (modulus.data[3] & borrow), carry, carry);
    // The value being subtracted is in [0, 2**256), if we subtract 0 - 2**255 and then add p, the value will stay
    // negative. If we are adding p, we need to check that we've overflown 2**256. If not, we should add p again
    if (!carry) {
        r0 += (modulus.data[0] & borrow);
        uint64_t carry = r0 < (modulus.data[0] & borrow);
        r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
        r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
        r3 = addc(r3, (modulus.data[3] & borrow), carry, carry);
    }
    return { r0, r1, r2, r3 };
}

/**
 * @brief Subtraction with coarse reduction: the result is reduced only into [0, 2p), not [0, p)
 *
 * @tparam T
 * @param other
 * @return constexpr field<T>
 */
template <class T> constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
{
    if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
        return subtract(other);
    }
    uint64_t borrow = 0;
    uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
    uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow);
    uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow);
    uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow);

    r0 += (twice_modulus.data[0] & borrow);
    uint64_t carry = r0 < (twice_modulus.data[0] & borrow);
    r1 = addc(r1, twice_modulus.data[1] & borrow, carry, carry);
    r2 = addc(r2, twice_modulus.data[2] & borrow, carry, carry);
    r3 += (twice_modulus.data[3] & borrow) + carry;

    return { r0, r1, r2, r3 };
}

/**
 * @brief Montgomery multiplication for moduli > 2²⁵⁴
 *
 * @details Explanation of Montgomery form can be found in \ref field_docs_montgomery_explainer and the difference
 * between WASM and generic versions is explained in \ref field_docs_architecture_details
 */
template <class T> constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    uint64_t c = 0;
    uint64_t t0 = 0;
    uint64_t t1 = 0;
    uint64_t t2 = 0;
    uint64_t t3 = 0;
    uint64_t t4 = 0;
    uint64_t t5 = 0;
    uint64_t k = 0;
    for (const auto& element : data) {
        c = 0;
        mac(t0, element, other.data[0], c, t0, c);
        mac(t1, element, other.data[1], c, t1, c);
        mac(t2, element, other.data[2], c, t2, c);
        mac(t3, element, other.data[3], c, t3, c);
        t4 = addc(t4, c, 0, t5);

        c = 0;
        k = t0 * T::r_inv;
        c = mac_discard_lo(t0, k, modulus.data[0]);
        mac(t1, k, modulus.data[1], c, t0, c);
        mac(t2, k, modulus.data[2], c, t1, c);
        mac(t3, k, modulus.data[3], c, t2, c);
        t3 = addc(c, t4, 0, c);
        t4 = t5 + c;
    }
    uint64_t borrow = 0;
    uint64_t r0 = sbb(t0, modulus.data[0], borrow, borrow);
    uint64_t r1 = sbb(t1, modulus.data[1], borrow, borrow);
    uint64_t r2 = sbb(t2, modulus.data[2], borrow, borrow);
    uint64_t r3 = sbb(t3, modulus.data[3], borrow, borrow);
    borrow = borrow ^ (0ULL - t4);
    r0 += (modulus.data[0] & borrow);
    uint64_t carry = r0 < (modulus.data[0] & borrow);
    r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
    r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
    r3 += (modulus.data[3] & borrow) + carry;
    return { r0, r1, r2, r3 };
#else

    // Convert 4 64-bit limbs to 9 29-bit limbs
    auto left = wasm_convert(data);
    auto right = wasm_convert(other.data);
    constexpr uint64_t mask = 0x1fffffff;
    uint64_t temp_0 = 0;
    uint64_t temp_1 = 0;
    uint64_t temp_2 = 0;
    uint64_t temp_3 = 0;
    uint64_t temp_4 = 0;
    uint64_t temp_5 = 0;
    uint64_t temp_6 = 0;
    uint64_t temp_7 = 0;
    uint64_t temp_8 = 0;
    uint64_t temp_9 = 0;
    uint64_t temp_10 = 0;
    uint64_t temp_11 = 0;
    uint64_t temp_12 = 0;
    uint64_t temp_13 = 0;
    uint64_t temp_14 = 0;
    uint64_t temp_15 = 0;
    uint64_t temp_16 = 0;
    uint64_t temp_17 = 0;

    // Multiply-add 0th limb of the left argument by all 9 limbs of the right argument
    wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    // Instantly reduce
    wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    // Continue for other limbs
    wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
    wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);

    // After all multiplications and additions, convert relaxed form to strict (all limbs are 29 bits)
    temp_10 += temp_9 >> WASM_LIMB_BITS;
    temp_9 &= mask;
    temp_11 += temp_10 >> WASM_LIMB_BITS;
    temp_10 &= mask;
    temp_12 += temp_11 >> WASM_LIMB_BITS;
    temp_11 &= mask;
    temp_13 += temp_12 >> WASM_LIMB_BITS;
    temp_12 &= mask;
    temp_14 += temp_13 >> WASM_LIMB_BITS;
    temp_13 &= mask;
    temp_15 += temp_14 >> WASM_LIMB_BITS;
    temp_14 &= mask;
    temp_16 += temp_15 >> WASM_LIMB_BITS;
    temp_15 &= mask;
    temp_17 += temp_16 >> WASM_LIMB_BITS;
    temp_16 &= mask;

    uint64_t r_temp_0;
    uint64_t r_temp_1;
    uint64_t r_temp_2;
    uint64_t r_temp_3;
    uint64_t r_temp_4;
    uint64_t r_temp_5;
    uint64_t r_temp_6;
    uint64_t r_temp_7;
    uint64_t r_temp_8;
    // Subtract modulus from result
    r_temp_0 = temp_9 - wasm_modulus[0];
    r_temp_1 = temp_10 - wasm_modulus[1] - ((r_temp_0) >> 63);
    r_temp_2 = temp_11 - wasm_modulus[2] - ((r_temp_1) >> 63);
    r_temp_3 = temp_12 - wasm_modulus[3] - ((r_temp_2) >> 63);
    r_temp_4 = temp_13 - wasm_modulus[4] - ((r_temp_3) >> 63);
    r_temp_5 = temp_14 - wasm_modulus[5] - ((r_temp_4) >> 63);
    r_temp_6 = temp_15 - wasm_modulus[6] - ((r_temp_5) >> 63);
    r_temp_7 = temp_16 - wasm_modulus[7] - ((r_temp_6) >> 63);
    r_temp_8 = temp_17 - wasm_modulus[8] - ((r_temp_7) >> 63);

    // Depending on whether the subtraction underflowed, choose original value or the result of subtraction
    uint64_t new_mask = 0 - (r_temp_8 >> 63);
    uint64_t inverse_mask = (~new_mask) & mask;
    temp_9 = (temp_9 & new_mask) | (r_temp_0 & inverse_mask);
    temp_10 = (temp_10 & new_mask) | (r_temp_1 & inverse_mask);
    temp_11 = (temp_11 & new_mask) | (r_temp_2 & inverse_mask);
    temp_12 = (temp_12 & new_mask) | (r_temp_3 & inverse_mask);
    temp_13 = (temp_13 & new_mask) | (r_temp_4 & inverse_mask);
    temp_14 = (temp_14 & new_mask) | (r_temp_5 & inverse_mask);
    temp_15 = (temp_15 & new_mask) | (r_temp_6 & inverse_mask);
    temp_16 = (temp_16 & new_mask) | (r_temp_7 & inverse_mask);
    temp_17 = (temp_17 & new_mask) | (r_temp_8 & inverse_mask);

    // Convert back to 4 64-bit limbs
    return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
             (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
             (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
             (temp_15 >> 18) | (temp_16 << 11) | (temp_17 << 40) };

#endif
}

#if defined(__wasm__) || !defined(__SIZEOF_INT128__)

/**
 * @brief Multiply left limb by a sequence of 9 limbs and put into result variables
 *
 */
template <class T>
constexpr void field<T>::wasm_madd(uint64_t& left_limb,
                                   const std::array<uint64_t, WASM_NUM_LIMBS>& right_limbs,
                                   uint64_t& result_0,
                                   uint64_t& result_1,
                                   uint64_t& result_2,
                                   uint64_t& result_3,
                                   uint64_t& result_4,
                                   uint64_t& result_5,
                                   uint64_t& result_6,
                                   uint64_t& result_7,
                                   uint64_t& result_8)
{
    result_0 += left_limb * right_limbs[0];
    result_1 += left_limb * right_limbs[1];
    result_2 += left_limb * right_limbs[2];
    result_3 += left_limb * right_limbs[3];
    result_4 += left_limb * right_limbs[4];
    result_5 += left_limb * right_limbs[5];
    result_6 += left_limb * right_limbs[6];
    result_7 += left_limb * right_limbs[7];
    result_8 += left_limb * right_limbs[8];
}

/**
 * @brief Perform 29-bit montgomery reduction on 1 limb (result_0 should be zero modulo 2**29 after this)
 *
 */
template <class T>
constexpr void field<T>::wasm_reduce(uint64_t& result_0,
                                     uint64_t& result_1,
                                     uint64_t& result_2,
                                     uint64_t& result_3,
                                     uint64_t& result_4,
                                     uint64_t& result_5,
                                     uint64_t& result_6,
                                     uint64_t& result_7,
                                     uint64_t& result_8)
{
    constexpr uint64_t mask = 0x1fffffff;
    constexpr uint64_t r_inv = T::r_inv & mask;
    uint64_t k = (result_0 * r_inv) & mask;
    result_0 += k * wasm_modulus[0];
    result_1 += k * wasm_modulus[1] + (result_0 >> WASM_LIMB_BITS);
    result_2 += k * wasm_modulus[2];
    result_3 += k * wasm_modulus[3];
    result_4 += k * wasm_modulus[4];
    result_5 += k * wasm_modulus[5];
    result_6 += k * wasm_modulus[6];
    result_7 += k * wasm_modulus[7];
    result_8 += k * wasm_modulus[8];
}
/**
 * @brief Convert 4 64-bit limbs into 9 29-bit limbs
 *
 */
template <class T> constexpr std::array<uint64_t, WASM_NUM_LIMBS> field<T>::wasm_convert(const uint64_t* data)
{
    return { data[0] & 0x1fffffff,
             (data[0] >> WASM_LIMB_BITS) & 0x1fffffff,
             ((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6),
             (data[1] >> 23) & 0x1fffffff,
             ((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12),
             (data[2] >> 17) & 0x1fffffff,
             ((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18),
             (data[3] >> 11) & 0x1fffffff,
             (data[3] >> 40) & 0x1fffffff };
}
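// Sanity check on the layout: 9 limbs x 29 bits = 261 bits, enough to hold a 256-bit value
// with headroom; limb i holds bits [29*i, 29*i + 29) of the input. E.g. limb 2 spans bits
// [58, 87), which is why it combines the top 6 bits of data[0] with the low 23 bits of
// data[1].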
|
||||
#endif
|
||||
template <class T> constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
{
    if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
        return montgomery_mul_big(other);
    }
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    auto [t0, c] = mul_wide(data[0], other.data[0]);
    uint64_t k = t0 * T::r_inv;
    uint64_t a = mac_discard_lo(t0, k, modulus.data[0]);

    uint64_t t1 = mac_mini(a, data[0], other.data[1], a);
    mac(t1, k, modulus.data[1], c, t0, c);
    uint64_t t2 = mac_mini(a, data[0], other.data[2], a);
    mac(t2, k, modulus.data[2], c, t1, c);
    uint64_t t3 = mac_mini(a, data[0], other.data[3], a);
    mac(t3, k, modulus.data[3], c, t2, c);
    t3 = c + a;

    mac_mini(t0, data[1], other.data[0], t0, a);
    k = t0 * T::r_inv;
    c = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, data[1], other.data[1], a, t1, a);
    mac(t1, k, modulus.data[1], c, t0, c);
    mac(t2, data[1], other.data[2], a, t2, a);
    mac(t2, k, modulus.data[2], c, t1, c);
    mac(t3, data[1], other.data[3], a, t3, a);
    mac(t3, k, modulus.data[3], c, t2, c);
    t3 = c + a;

    mac_mini(t0, data[2], other.data[0], t0, a);
    k = t0 * T::r_inv;
    c = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, data[2], other.data[1], a, t1, a);
    mac(t1, k, modulus.data[1], c, t0, c);
    mac(t2, data[2], other.data[2], a, t2, a);
    mac(t2, k, modulus.data[2], c, t1, c);
    mac(t3, data[2], other.data[3], a, t3, a);
    mac(t3, k, modulus.data[3], c, t2, c);
    t3 = c + a;

    mac_mini(t0, data[3], other.data[0], t0, a);
    k = t0 * T::r_inv;
    c = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, data[3], other.data[1], a, t1, a);
    mac(t1, k, modulus.data[1], c, t0, c);
    mac(t2, data[3], other.data[2], a, t2, a);
    mac(t2, k, modulus.data[2], c, t1, c);
    mac(t3, data[3], other.data[3], a, t3, a);
    mac(t3, k, modulus.data[3], c, t2, c);
    t3 = c + a;
    return { t0, t1, t2, t3 };
#else

    // Convert 4 64-bit limbs to 9 29-bit ones
    auto left = wasm_convert(data);
    auto right = wasm_convert(other.data);
    constexpr uint64_t mask = 0x1fffffff;
    uint64_t temp_0 = 0;
    uint64_t temp_1 = 0;
    uint64_t temp_2 = 0;
    uint64_t temp_3 = 0;
    uint64_t temp_4 = 0;
    uint64_t temp_5 = 0;
    uint64_t temp_6 = 0;
    uint64_t temp_7 = 0;
    uint64_t temp_8 = 0;
    uint64_t temp_9 = 0;
    uint64_t temp_10 = 0;
    uint64_t temp_11 = 0;
    uint64_t temp_12 = 0;
    uint64_t temp_13 = 0;
    uint64_t temp_14 = 0;
    uint64_t temp_15 = 0;
    uint64_t temp_16 = 0;

    // Perform a series of multiplications and reductions (we multiply 1 limb of left argument by the whole right
    // argument and then reduce)
    wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
    wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);

    // Convert result to unrelaxed form (all limbs are 29 bits)
    temp_10 += temp_9 >> WASM_LIMB_BITS;
    temp_9 &= mask;
    temp_11 += temp_10 >> WASM_LIMB_BITS;
    temp_10 &= mask;
    temp_12 += temp_11 >> WASM_LIMB_BITS;
    temp_11 &= mask;
    temp_13 += temp_12 >> WASM_LIMB_BITS;
    temp_12 &= mask;
    temp_14 += temp_13 >> WASM_LIMB_BITS;
    temp_13 &= mask;
    temp_15 += temp_14 >> WASM_LIMB_BITS;
    temp_14 &= mask;
    temp_16 += temp_15 >> WASM_LIMB_BITS;
    temp_15 &= mask;

    // Convert back to 4 64-bit limbs form
    return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
             (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
             (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
             (temp_15 >> 18) | (temp_16 << 11) };
#endif
}

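// Editorial note (a sketch of the math, not upstream code): montgomery_mul
// computes a * b * R^{-1} mod p for R = 2^256. Each of the four rounds above
// is one step of CIOS-style reduction: multiply one limb of `this` into the
// accumulator, then add k * modulus with k = t0 * r_inv, chosen so the lowest
// accumulator limb becomes zero and can be shifted out. A one-limb toy
// version, assuming a hypothetical 64-bit modulus `p` and precomputed
// `r_inv = -p^{-1} mod 2^64`:
//
//   // returns a * b * 2^{-64} mod p (inputs in Montgomery form)
//   uint64_t monty_mul_1limb(uint64_t a, uint64_t b, uint64_t p, uint64_t r_inv)
//   {
//       unsigned __int128 t = (unsigned __int128)a * b;
//       uint64_t k = (uint64_t)t * r_inv;  // makes the low limb of t + k*p vanish
//       t += (unsigned __int128)k * p;     // low 64 bits are now zero
//       uint64_t r = (uint64_t)(t >> 64);  // shift out the zero limb
//       return r >= p ? r - p : r;         // conditional final subtraction
//   }
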
template <class T> constexpr field<T> field<T>::montgomery_square() const noexcept
{
    if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
        return montgomery_mul_big(*this);
    }
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    uint64_t carry_hi = 0;

    auto [t0, carry_lo] = mul_wide(data[0], data[0]);
    uint64_t t1 = square_accumulate(0, data[1], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
    uint64_t t2 = square_accumulate(0, data[2], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
    uint64_t t3 = square_accumulate(0, data[3], data[0], carry_lo, carry_hi, carry_lo, carry_hi);

    uint64_t round_carry = carry_lo;
    uint64_t k = t0 * T::r_inv;
    carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
    mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
    mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
    t3 = carry_lo + round_carry;

    t1 = mac_mini(t1, data[1], data[1], carry_lo);
    carry_hi = 0;
    t2 = square_accumulate(t2, data[2], data[1], carry_lo, carry_hi, carry_lo, carry_hi);
    t3 = square_accumulate(t3, data[3], data[1], carry_lo, carry_hi, carry_lo, carry_hi);
    round_carry = carry_lo;
    k = t0 * T::r_inv;
    carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
    mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
    mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
    t3 = carry_lo + round_carry;

    t2 = mac_mini(t2, data[2], data[2], carry_lo);
    carry_hi = 0;
    t3 = square_accumulate(t3, data[3], data[2], carry_lo, carry_hi, carry_lo, carry_hi);
    round_carry = carry_lo;
    k = t0 * T::r_inv;
    carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
    mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
    mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
    t3 = carry_lo + round_carry;

    t3 = mac_mini(t3, data[3], data[3], carry_lo);
    k = t0 * T::r_inv;
    round_carry = carry_lo;
    carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
    mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
    mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
    mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
    t3 = carry_lo + round_carry;
    return { t0, t1, t2, t3 };
#else
    // Convert from 4 64-bit limbs to 9 29-bit ones
    auto left = wasm_convert(data);
    constexpr uint64_t mask = 0x1fffffff;
    uint64_t temp_0 = 0;
    uint64_t temp_1 = 0;
    uint64_t temp_2 = 0;
    uint64_t temp_3 = 0;
    uint64_t temp_4 = 0;
    uint64_t temp_5 = 0;
    uint64_t temp_6 = 0;
    uint64_t temp_7 = 0;
    uint64_t temp_8 = 0;
    uint64_t temp_9 = 0;
    uint64_t temp_10 = 0;
    uint64_t temp_11 = 0;
    uint64_t temp_12 = 0;
    uint64_t temp_13 = 0;
    uint64_t temp_14 = 0;
    uint64_t temp_15 = 0;
    uint64_t temp_16 = 0;
    uint64_t acc;
    // Perform multiplications, accumulating results for limb k = i + j so that we can double them at the same time
    temp_0 += left[0] * left[0];
    acc = 0;
    acc += left[0] * left[1];
    temp_1 += (acc << 1);
    acc = 0;
    acc += left[0] * left[2];
    temp_2 += left[1] * left[1];
    temp_2 += (acc << 1);
    acc = 0;
    acc += left[0] * left[3];
    acc += left[1] * left[2];
    temp_3 += (acc << 1);
    acc = 0;
    acc += left[0] * left[4];
    acc += left[1] * left[3];
    temp_4 += left[2] * left[2];
    temp_4 += (acc << 1);
    acc = 0;
    acc += left[0] * left[5];
    acc += left[1] * left[4];
    acc += left[2] * left[3];
    temp_5 += (acc << 1);
    acc = 0;
    acc += left[0] * left[6];
    acc += left[1] * left[5];
    acc += left[2] * left[4];
    temp_6 += left[3] * left[3];
    temp_6 += (acc << 1);
    acc = 0;
    acc += left[0] * left[7];
    acc += left[1] * left[6];
    acc += left[2] * left[5];
    acc += left[3] * left[4];
    temp_7 += (acc << 1);
    acc = 0;
    acc += left[0] * left[8];
    acc += left[1] * left[7];
    acc += left[2] * left[6];
    acc += left[3] * left[5];
    temp_8 += left[4] * left[4];
    temp_8 += (acc << 1);
    acc = 0;
    acc += left[1] * left[8];
    acc += left[2] * left[7];
    acc += left[3] * left[6];
    acc += left[4] * left[5];
    temp_9 += (acc << 1);
    acc = 0;
    acc += left[2] * left[8];
    acc += left[3] * left[7];
    acc += left[4] * left[6];
    temp_10 += left[5] * left[5];
    temp_10 += (acc << 1);
    acc = 0;
    acc += left[3] * left[8];
    acc += left[4] * left[7];
    acc += left[5] * left[6];
    temp_11 += (acc << 1);
    acc = 0;
    acc += left[4] * left[8];
    acc += left[5] * left[7];
    temp_12 += left[6] * left[6];
    temp_12 += (acc << 1);
    acc = 0;
    acc += left[5] * left[8];
    acc += left[6] * left[7];
    temp_13 += (acc << 1);
    acc = 0;
    acc += left[6] * left[8];
    temp_14 += left[7] * left[7];
    temp_14 += (acc << 1);
    acc = 0;
    acc += left[7] * left[8];
    temp_15 += (acc << 1);
    temp_16 += left[8] * left[8];

    // Perform reductions
    wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);

    // Convert to unrelaxed 29-bit form
    temp_10 += temp_9 >> WASM_LIMB_BITS;
    temp_9 &= mask;
    temp_11 += temp_10 >> WASM_LIMB_BITS;
    temp_10 &= mask;
    temp_12 += temp_11 >> WASM_LIMB_BITS;
    temp_11 &= mask;
    temp_13 += temp_12 >> WASM_LIMB_BITS;
    temp_12 &= mask;
    temp_14 += temp_13 >> WASM_LIMB_BITS;
    temp_13 &= mask;
    temp_15 += temp_14 >> WASM_LIMB_BITS;
    temp_14 &= mask;
    temp_16 += temp_15 >> WASM_LIMB_BITS;
    temp_15 &= mask;
    // Convert to 4 64-bit form
    return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
             (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
             (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
             (temp_15 >> 18) | (temp_16 << 11) };
#endif
}

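// Editorial note (illustrative, not upstream code): the squaring above uses
// the identity (sum_i x_i B^i)^2 = sum_i x_i^2 B^{2i} + 2 * sum_{i<j} x_i x_j B^{i+j},
// which is why off-diagonal products contributing to a given limb index
// k = i + j are gathered in `acc` and doubled once via `acc << 1`, roughly
// halving the multiplication count relative to calling montgomery_mul(*this).
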
template <class T> constexpr struct field<T>::wide_array field<T>::mul_512(const field& other) const noexcept {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    uint64_t carry_2 = 0;
    auto [r0, carry] = mul_wide(data[0], other.data[0]);
    uint64_t r1 = mac_mini(carry, data[0], other.data[1], carry);
    uint64_t r2 = mac_mini(carry, data[0], other.data[2], carry);
    uint64_t r3 = mac_mini(carry, data[0], other.data[3], carry_2);

    r1 = mac_mini(r1, data[1], other.data[0], carry);
    r2 = mac(r2, data[1], other.data[1], carry, carry);
    r3 = mac(r3, data[1], other.data[2], carry, carry);
    uint64_t r4 = mac(carry_2, data[1], other.data[3], carry, carry_2);

    r2 = mac_mini(r2, data[2], other.data[0], carry);
    r3 = mac(r3, data[2], other.data[1], carry, carry);
    r4 = mac(r4, data[2], other.data[2], carry, carry);
    uint64_t r5 = mac(carry_2, data[2], other.data[3], carry, carry_2);

    r3 = mac_mini(r3, data[3], other.data[0], carry);
    r4 = mac(r4, data[3], other.data[1], carry, carry);
    r5 = mac(r5, data[3], other.data[2], carry, carry);
    uint64_t r6 = mac(carry_2, data[3], other.data[3], carry, carry_2);

    return { r0, r1, r2, r3, r4, r5, r6, carry_2 };
#else
    // Convert from 4 64-bit limbs to 9 29-bit limbs
    auto left = wasm_convert(data);
    auto right = wasm_convert(other.data);
    constexpr uint64_t mask = 0x1fffffff;
    uint64_t temp_0 = 0;
    uint64_t temp_1 = 0;
    uint64_t temp_2 = 0;
    uint64_t temp_3 = 0;
    uint64_t temp_4 = 0;
    uint64_t temp_5 = 0;
    uint64_t temp_6 = 0;
    uint64_t temp_7 = 0;
    uint64_t temp_8 = 0;
    uint64_t temp_9 = 0;
    uint64_t temp_10 = 0;
    uint64_t temp_11 = 0;
    uint64_t temp_12 = 0;
    uint64_t temp_13 = 0;
    uint64_t temp_14 = 0;
    uint64_t temp_15 = 0;
    uint64_t temp_16 = 0;

    // Multiply-add all limbs
    wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);

    // Convert to unrelaxed 29-bit form
    temp_1 += temp_0 >> WASM_LIMB_BITS;
    temp_0 &= mask;
    temp_2 += temp_1 >> WASM_LIMB_BITS;
    temp_1 &= mask;
    temp_3 += temp_2 >> WASM_LIMB_BITS;
    temp_2 &= mask;
    temp_4 += temp_3 >> WASM_LIMB_BITS;
    temp_3 &= mask;
    temp_5 += temp_4 >> WASM_LIMB_BITS;
    temp_4 &= mask;
    temp_6 += temp_5 >> WASM_LIMB_BITS;
    temp_5 &= mask;
    temp_7 += temp_6 >> WASM_LIMB_BITS;
    temp_6 &= mask;
    temp_8 += temp_7 >> WASM_LIMB_BITS;
    temp_7 &= mask;
    temp_9 += temp_8 >> WASM_LIMB_BITS;
    temp_8 &= mask;
    temp_10 += temp_9 >> WASM_LIMB_BITS;
    temp_9 &= mask;
    temp_11 += temp_10 >> WASM_LIMB_BITS;
    temp_10 &= mask;
    temp_12 += temp_11 >> WASM_LIMB_BITS;
    temp_11 &= mask;
    temp_13 += temp_12 >> WASM_LIMB_BITS;
    temp_12 &= mask;
    temp_14 += temp_13 >> WASM_LIMB_BITS;
    temp_13 &= mask;
    temp_15 += temp_14 >> WASM_LIMB_BITS;
    temp_14 &= mask;
    temp_16 += temp_15 >> WASM_LIMB_BITS;
    temp_15 &= mask;

    // Convert to 8 64-bit limbs
    return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
             (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
             (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
             (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40),
             (temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63),
             (temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57),
             (temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51),
             (temp_15 >> 13) | (temp_16 << 16) };
#endif
}

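// Editorial note (hypothetical usage, not upstream code): unlike
// montgomery_mul, mul_512 returns the full 512-bit product with no modular
// reduction; the wide_array is assumed here to carry eight 64-bit limbs:
//
//   field<T> a = ..., b = ...;
//   auto wide = a.mul_512(b);  // 512-bit product; caller reduces it later
//
// Callers that need a field element are responsible for the reduction step.
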
// NOLINTEND(readability-implicit-bool-conversion)
} // namespace bb
@@ -0,0 +1,389 @@
#pragma once

#if (BBERG_NO_ASM == 0)
#include "./field_impl.hpp"
#include "asm_macros.hpp"
namespace bb {

template <class T> field<T> field<T>::asm_mul_with_coarse_reduction(const field& a, const field& b) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_mul_with_coarse_reduction");

    field r;
    constexpr uint64_t r_inv = T::r_inv;
    constexpr uint64_t modulus_0 = modulus.data[0];
    constexpr uint64_t modulus_1 = modulus.data[1];
    constexpr uint64_t modulus_2 = modulus.data[2];
    constexpr uint64_t modulus_3 = modulus.data[3];
    constexpr uint64_t zero_ref = 0;

    /**
     * Registers: rax:rdx = multiplication accumulator
     *            %r12, %r13, %r14, %r15, %rax: work registers for `r`
     *            %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
     *            %r10: zero register
     *            %0: pointer to `a`
     *            %1: pointer to `b`
     *            %2: pointer to `r`
     **/
    __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
                STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "%r"(&a),
              "%r"(&b),
              "r"(&r),
              [modulus_0] "m"(modulus_0),
              [modulus_1] "m"(modulus_1),
              [modulus_2] "m"(modulus_2),
              [modulus_3] "m"(modulus_3),
              [r_inv] "m"(r_inv),
              [zero_reference] "m"(zero_ref)
            : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
    return r;
}

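// Editorial note (an assumption stated for orientation, not upstream
// documentation): the "coarse reduction" variants in this file keep results
// in a relaxed range such as [0, 2p) rather than fully reduced [0, p); this
// is consistent with the ADD_REDUCE / REDUCE_FIELD_ELEMENT macros below
// operating against twice_not_modulus / twice_modulus, while asm_reduce_once
// at the end of the file maps a relaxed value back to the canonical range.
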
template <class T> void field<T>::asm_self_mul_with_coarse_reduction(const field& a, const field& b) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_self_mul_with_coarse_reduction");

    constexpr uint64_t r_inv = T::r_inv;
    constexpr uint64_t modulus_0 = modulus.data[0];
    constexpr uint64_t modulus_1 = modulus.data[1];
    constexpr uint64_t modulus_2 = modulus.data[2];
    constexpr uint64_t modulus_3 = modulus.data[3];
    constexpr uint64_t zero_ref = 0;
    /**
     * Registers: rax:rdx = multiplication accumulator
     *            %r12, %r13, %r14, %r15, %rax: work registers for `r`
     *            %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
     *            %r10: zero register
     *            %0: pointer to `a`
     *            %1: pointer to `b`
     *            %2: pointer to `r`
     **/
    __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
                STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              "r"(&b),
              [modulus_0] "m"(modulus_0),
              [modulus_1] "m"(modulus_1),
              [modulus_2] "m"(modulus_2),
              [modulus_3] "m"(modulus_3),
              [r_inv] "m"(r_inv),
              [zero_reference] "m"(zero_ref)
            : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}

template <class T> field<T> field<T>::asm_sqr_with_coarse_reduction(const field& a) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_sqr_with_coarse_reduction");

    field r;
    constexpr uint64_t r_inv = T::r_inv;
    constexpr uint64_t modulus_0 = modulus.data[0];
    constexpr uint64_t modulus_1 = modulus.data[1];
    constexpr uint64_t modulus_2 = modulus.data[2];
    constexpr uint64_t modulus_3 = modulus.data[3];
    constexpr uint64_t zero_ref = 0;

    // Our SQR implementation with BMI2 but without ADX has a bug.
    // The case is extremely rare so fixing it is a bit of a waste of time.
    // We'll use MUL instead.
#if !defined(__ADX__) || defined(DISABLE_ADX)
    /**
     * Registers: rax:rdx = multiplication accumulator
     *            %r12, %r13, %r14, %r15, %rax: work registers for `r`
     *            %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
     *            %r10: zero register
     *            %0: pointer to `a`
     *            %1: pointer to `b`
     *            %2: pointer to `r`
     **/
    __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
                STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "%r"(&a),
              "%r"(&a),
              "r"(&r),
              [modulus_0] "m"(modulus_0),
              [modulus_1] "m"(modulus_1),
              [modulus_2] "m"(modulus_2),
              [modulus_3] "m"(modulus_3),
              [r_inv] "m"(r_inv),
              [zero_reference] "m"(zero_ref)
            : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");

#else

    /**
     * Registers: rax:rdx = multiplication accumulator
     *            %r12, %r13, %r14, %r15, %rax: work registers for `r`
     *            %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
     *            %[zero_reference]: memory location of zero value
     *            %0: pointer to `a`
     *            %[r_ptr]: memory location of pointer to `r`
     **/
    __asm__(SQR("%0")
            // "movq %[r_ptr], %%rsi \n\t"
            STORE_FIELD_ELEMENT("%1", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              "r"(&r),
              [zero_reference] "m"(zero_ref),
              [modulus_0] "m"(modulus_0),
              [modulus_1] "m"(modulus_1),
              [modulus_2] "m"(modulus_2),
              [modulus_3] "m"(modulus_3),
              [r_inv] "m"(r_inv)
            : "%rcx", "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
#endif
    return r;
}

template <class T> void field<T>::asm_self_sqr_with_coarse_reduction(const field& a) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_self_sqr_with_coarse_reduction");

    constexpr uint64_t r_inv = T::r_inv;
    constexpr uint64_t modulus_0 = modulus.data[0];
    constexpr uint64_t modulus_1 = modulus.data[1];
    constexpr uint64_t modulus_2 = modulus.data[2];
    constexpr uint64_t modulus_3 = modulus.data[3];
    constexpr uint64_t zero_ref = 0;

    // Our SQR implementation with BMI2 but without ADX has a bug.
    // The case is extremely rare so fixing it is a bit of a waste of time.
    // We'll use MUL instead.
#if !defined(__ADX__) || defined(DISABLE_ADX)
    /**
     * Registers: rax:rdx = multiplication accumulator
     *            %r12, %r13, %r14, %r15, %rax: work registers for `r`
     *            %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
     *            %r10: zero register
     *            %0: pointer to `a`
     *            %1: pointer to `b`
     *            %2: pointer to `r`
     **/
    __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
                STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              "r"(&a),
              [modulus_0] "m"(modulus_0),
              [modulus_1] "m"(modulus_1),
              [modulus_2] "m"(modulus_2),
              [modulus_3] "m"(modulus_3),
              [r_inv] "m"(r_inv),
              [zero_reference] "m"(zero_ref)
            : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");

#else
    /**
     * Registers: rax:rdx = multiplication accumulator
     *            %r12, %r13, %r14, %r15, %rax: work registers for `r`
     *            %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
     *            %[zero_reference]: memory location of zero value
     *            %0: pointer to `a`
     *            %[r_ptr]: memory location of pointer to `r`
     **/
    __asm__(SQR("%0")
            // "movq %[r_ptr], %%rsi \n\t"
            STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              [zero_reference] "m"(zero_ref),
              [modulus_0] "m"(modulus_0),
              [modulus_1] "m"(modulus_1),
              [modulus_2] "m"(modulus_2),
              [modulus_3] "m"(modulus_3),
              [r_inv] "m"(r_inv)
            : "%rcx", "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
#endif
}

template <class T> field<T> field<T>::asm_add_with_coarse_reduction(const field& a, const field& b) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_add_with_coarse_reduction");

    field r;

    constexpr uint64_t twice_not_modulus_0 = twice_not_modulus.data[0];
    constexpr uint64_t twice_not_modulus_1 = twice_not_modulus.data[1];
    constexpr uint64_t twice_not_modulus_2 = twice_not_modulus.data[2];
    constexpr uint64_t twice_not_modulus_3 = twice_not_modulus.data[3];

    __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
                ADD_REDUCE("%1",
                           "%[twice_not_modulus_0]",
                           "%[twice_not_modulus_1]",
                           "%[twice_not_modulus_2]",
                           "%[twice_not_modulus_3]") STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "%r"(&a),
              "%r"(&b),
              "r"(&r),
              [twice_not_modulus_0] "m"(twice_not_modulus_0),
              [twice_not_modulus_1] "m"(twice_not_modulus_1),
              [twice_not_modulus_2] "m"(twice_not_modulus_2),
              [twice_not_modulus_3] "m"(twice_not_modulus_3)
            : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
    return r;
}

template <class T> void field<T>::asm_self_add_with_coarse_reduction(const field& a, const field& b) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_self_add_with_coarse_reduction");

    constexpr uint64_t twice_not_modulus_0 = twice_not_modulus.data[0];
    constexpr uint64_t twice_not_modulus_1 = twice_not_modulus.data[1];
    constexpr uint64_t twice_not_modulus_2 = twice_not_modulus.data[2];
    constexpr uint64_t twice_not_modulus_3 = twice_not_modulus.data[3];

    __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
                ADD_REDUCE("%1",
                           "%[twice_not_modulus_0]",
                           "%[twice_not_modulus_1]",
                           "%[twice_not_modulus_2]",
                           "%[twice_not_modulus_3]") STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              "r"(&b),
              [twice_not_modulus_0] "m"(twice_not_modulus_0),
              [twice_not_modulus_1] "m"(twice_not_modulus_1),
              [twice_not_modulus_2] "m"(twice_not_modulus_2),
              [twice_not_modulus_3] "m"(twice_not_modulus_3)
            : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}

template <class T> field<T> field<T>::asm_sub_with_coarse_reduction(const field& a, const field& b) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_sub_with_coarse_reduction");

    field r;

    constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
    constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
    constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
    constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];

    __asm__(
        CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") SUB("%1")
            REDUCE_FIELD_ELEMENT("%[twice_modulus_0]", "%[twice_modulus_1]", "%[twice_modulus_2]", "%[twice_modulus_3]")
                STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
        :
        : "r"(&a),
          "r"(&b),
          "r"(&r),
          [twice_modulus_0] "m"(twice_modulus_0),
          [twice_modulus_1] "m"(twice_modulus_1),
          [twice_modulus_2] "m"(twice_modulus_2),
          [twice_modulus_3] "m"(twice_modulus_3)
        : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
    return r;
}

template <class T> void field<T>::asm_self_sub_with_coarse_reduction(const field& a, const field& b) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_self_sub_with_coarse_reduction");

    constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
    constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
    constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
    constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];

    __asm__(
        CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") SUB("%1")
            REDUCE_FIELD_ELEMENT("%[twice_modulus_0]", "%[twice_modulus_1]", "%[twice_modulus_2]", "%[twice_modulus_3]")
                STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
        :
        : "r"(&a),
          "r"(&b),
          [twice_modulus_0] "m"(twice_modulus_0),
          [twice_modulus_1] "m"(twice_modulus_1),
          [twice_modulus_2] "m"(twice_modulus_2),
          [twice_modulus_3] "m"(twice_modulus_3)
        : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}

template <class T> void field<T>::asm_conditional_negate(field& r, const uint64_t predicate) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_conditional_negate");

    constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
    constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
    constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
    constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];

    __asm__(CLEAR_FLAGS("%%r8") LOAD_FIELD_ELEMENT(
                "%1", "%%r8", "%%r9", "%%r10", "%%r11") "movq %[twice_modulus_0], %%r12 \n\t"
            "movq %[twice_modulus_1], %%r13 \n\t"
            "movq %[twice_modulus_2], %%r14 \n\t"
            "movq %[twice_modulus_3], %%r15 \n\t"
            "subq %%r8, %%r12 \n\t"
            "sbbq %%r9, %%r13 \n\t"
            "sbbq %%r10, %%r14 \n\t"
            "sbbq %%r11, %%r15 \n\t"
            "testq %0, %0 \n\t"
            "cmovnzq %%r12, %%r8 \n\t"
            "cmovnzq %%r13, %%r9 \n\t"
            "cmovnzq %%r14, %%r10 \n\t"
            "cmovnzq %%r15, %%r11 \n\t" STORE_FIELD_ELEMENT(
                "%1", "%%r8", "%%r9", "%%r10", "%%r11")
            :
            : "r"(predicate),
              "r"(&r),
              [twice_modulus_0] "i"(twice_modulus_0),
              [twice_modulus_1] "i"(twice_modulus_1),
              [twice_modulus_2] "i"(twice_modulus_2),
              [twice_modulus_3] "i"(twice_modulus_3)
            : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}

template <class T> field<T> field<T>::asm_reduce_once(const field& a) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_reduce_once");

    field r;

    constexpr uint64_t not_modulus_0 = not_modulus.data[0];
    constexpr uint64_t not_modulus_1 = not_modulus.data[1];
    constexpr uint64_t not_modulus_2 = not_modulus.data[2];
    constexpr uint64_t not_modulus_3 = not_modulus.data[3];

    __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
                REDUCE_FIELD_ELEMENT("%[not_modulus_0]", "%[not_modulus_1]", "%[not_modulus_2]", "%[not_modulus_3]")
                    STORE_FIELD_ELEMENT("%1", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              "r"(&r),
              [not_modulus_0] "m"(not_modulus_0),
              [not_modulus_1] "m"(not_modulus_1),
              [not_modulus_2] "m"(not_modulus_2),
              [not_modulus_3] "m"(not_modulus_3)
            : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
    return r;
}

template <class T> void field<T>::asm_self_reduce_once(const field& a) noexcept
{
    BB_OP_COUNT_TRACK_NAME("fr::asm_self_reduce_once");

    constexpr uint64_t not_modulus_0 = not_modulus.data[0];
    constexpr uint64_t not_modulus_1 = not_modulus.data[1];
    constexpr uint64_t not_modulus_2 = not_modulus.data[2];
    constexpr uint64_t not_modulus_3 = not_modulus.data[3];

    __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
                REDUCE_FIELD_ELEMENT("%[not_modulus_0]", "%[not_modulus_1]", "%[not_modulus_2]", "%[not_modulus_3]")
                    STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
            :
            : "r"(&a),
              [not_modulus_0] "m"(not_modulus_0),
              [not_modulus_1] "m"(not_modulus_1),
              [not_modulus_2] "m"(not_modulus_2),
              [not_modulus_3] "m"(not_modulus_3)
            : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}
} // namespace bb
#endif
@@ -0,0 +1,192 @@
#pragma once
#include "../../common/serialize.hpp"
#include "../../ecc/curves/bn254/fq2.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include <cstring>
#include <type_traits>
#include <vector>

namespace bb::group_elements {
template <typename T>
concept SupportsHashToCurve = T::can_hash_to_curve;
template <typename Fq_, typename Fr_, typename Params> class alignas(64) affine_element {
  public:
    using Fq = Fq_;
    using Fr = Fr_;

    using in_buf = const uint8_t*;
    using vec_in_buf = const uint8_t*;
    using out_buf = uint8_t*;
    using vec_out_buf = uint8_t**;

    affine_element() noexcept = default;
    ~affine_element() noexcept = default;

    constexpr affine_element(const Fq& x, const Fq& y) noexcept;

    constexpr affine_element(const affine_element& other) noexcept = default;

    constexpr affine_element(affine_element&& other) noexcept = default;

    static constexpr affine_element one() noexcept { return { Params::one_x, Params::one_y }; };

    /**
     * @brief Reconstruct a point in affine coordinates from compressed form.
     * @details #LARGE_MODULUS_AFFINE_POINT_COMPRESSION Point compression is only implemented for curves of a prime
     * field F_p with p using fewer than 256 bits. One possibility for extending to a 256-bit prime field:
     * https://patents.google.com/patent/US6252960B1/en.
     *
     * @param compressed compressed point
     * @return constexpr affine_element
     */
    template <typename BaseField = Fq,
              typename CompileTimeEnabled = std::enable_if_t<(BaseField::modulus >> 255) == uint256_t(0), void>>
    static constexpr affine_element from_compressed(const uint256_t& compressed) noexcept;

    /**
     * @brief Reconstruct a point in affine coordinates from compressed form.
     * @details #LARGE_MODULUS_AFFINE_POINT_COMPRESSION Point compression is implemented for curves of a prime
     * field F_p with p being 256 bits.
     * TODO(Suyash): Check with kesha if this is correct.
     *
     * @param compressed compressed point
     * @return constexpr affine_element
     */
    template <typename BaseField = Fq,
              typename CompileTimeEnabled = std::enable_if_t<(BaseField::modulus >> 255) == uint256_t(1), void>>
    static constexpr std::array<affine_element, 2> from_compressed_unsafe(const uint256_t& compressed) noexcept;

    constexpr affine_element& operator=(const affine_element& other) noexcept = default;

    constexpr affine_element& operator=(affine_element&& other) noexcept = default;

    constexpr affine_element operator+(const affine_element& other) const noexcept;

    template <typename BaseField = Fq,
              typename CompileTimeEnabled = std::enable_if_t<(BaseField::modulus >> 255) == uint256_t(0), void>>
    [[nodiscard]] constexpr uint256_t compress() const noexcept;

    static affine_element infinity();
    constexpr affine_element set_infinity() const noexcept;
    constexpr void self_set_infinity() noexcept;

    [[nodiscard]] constexpr bool is_point_at_infinity() const noexcept;

    [[nodiscard]] constexpr bool on_curve() const noexcept;

    static constexpr std::optional<affine_element> derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept;

    /**
     * @brief Samples a random point on the curve.
     *
     * @return A randomly chosen point on the curve
     */
    static affine_element random_element(numeric::RNG* engine = nullptr) noexcept;
    static constexpr affine_element hash_to_curve(const std::vector<uint8_t>& seed, uint8_t attempt_count = 0) noexcept
        requires SupportsHashToCurve<Params>;

    constexpr bool operator==(const affine_element& other) const noexcept;

    constexpr affine_element operator-() const noexcept { return { x, -y }; }

    constexpr bool operator>(const affine_element& other) const noexcept;
    constexpr bool operator<(const affine_element& other) const noexcept { return (other > *this); }

    /**
     * @brief Serialize the point to the given buffer
     *
     * @details We support serializing the point at infinity for curves defined over a bb::field (i.e., a
     * native field of prime order) and for points of bb::g2.
     *
     * @warning This will need to be updated if we serialize points over composite-order fields other than fq2!
     */
    static void serialize_to_buffer(const affine_element& value, uint8_t* buffer, bool write_x_first = false)
    {
        using namespace serialize;
        if (value.is_point_at_infinity()) {
            // if we are infinity, just set all buffer bits to 1
            // we only need this case because the below gets mangled converting from montgomery for infinity points
            memset(buffer, 255, sizeof(Fq) * 2);
        } else {
            // Note: for historic reasons we will need to redo downstream hashes if we want this to always be written in
            // the same order in our various serialization flows
            write(buffer, write_x_first ? value.x : value.y);
            write(buffer, write_x_first ? value.y : value.x);
        }
    }

    /**
     * @brief Restore point from a buffer
     *
     * @param buffer Buffer from which we deserialize the point
     *
     * @return Deserialized point
     *
     * @details We support serializing the point at infinity for curves defined over a bb::field (i.e., a
     * native field of prime order) and for points of bb::g2.
     *
     * @warning This will need to be updated if we serialize points over composite-order fields other than fq2!
     */
    static affine_element serialize_from_buffer(const uint8_t* buffer, bool write_x_first = false)
    {
        using namespace serialize;
        // Does the buffer consist entirely of set bits? If so, we have a point at infinity
        // Note that if it isn't, this loop should end early.
        // We only need this case because the below gets mangled converting to montgomery for infinity points
        bool is_point_at_infinity =
            std::all_of(buffer, buffer + sizeof(Fq) * 2, [](uint8_t val) { return val == 255; });
        if (is_point_at_infinity) {
            return affine_element::infinity();
        }
        affine_element result;
        // Note: for historic reasons we will need to redo downstream hashes if we want this to always be read in the
        // same order in our various serialization flows
        read(buffer, write_x_first ? result.x : result.y);
        read(buffer, write_x_first ? result.y : result.x);
        return result;
    }

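    // Editorial sketch (hypothetical usage, not part of the upstream class):
    // serialize_to_buffer and serialize_from_buffer are inverses for a given
    // `write_x_first` flag, so a round trip through a raw buffer preserves
    // the point, including the all-0xff encoding of the point at infinity:
    //
    //   affine_element p = affine_element::one();
    //   uint8_t buf[sizeof(Fq) * 2];
    //   affine_element::serialize_to_buffer(p, buf);
    //   affine_element q = affine_element::serialize_from_buffer(buf);
    //   ASSERT(p == q);
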
    /**
     * @brief Serialize the point to a byte vector
     *
     * @return Vector with serialized representation of the point
     */
    [[nodiscard]] inline std::vector<uint8_t> to_buffer() const
    {
        std::vector<uint8_t> buffer(sizeof(affine_element));
        affine_element::serialize_to_buffer(*this, &buffer[0]);
        return buffer;
    }

    friend std::ostream& operator<<(std::ostream& os, const affine_element& a)
    {
        os << "{ " << a.x << ", " << a.y << " }";
        return os;
    }
    Fq x;
    Fq y;
};

template <typename B, typename Fq_, typename Fr_, typename Params>
inline void read(B& it, group_elements::affine_element<Fq_, Fr_, Params>& element)
{
    using namespace serialize;
    std::array<uint8_t, sizeof(element)> buffer;
    read(it, buffer);
    element = group_elements::affine_element<Fq_, Fr_, Params>::serialize_from_buffer(
        buffer.data(), /* use legacy field order */ true);
}

template <typename B, typename Fq_, typename Fr_, typename Params>
inline void write(B& it, group_elements::affine_element<Fq_, Fr_, Params> const& element)
{
    using namespace serialize;
    std::array<uint8_t, sizeof(element)> buffer;
    group_elements::affine_element<Fq_, Fr_, Params>::serialize_to_buffer(
        element, buffer.data(), /* use legacy field order */ true);
    write(it, buffer);
}
} // namespace bb::group_elements

#include "./affine_element_impl.hpp"
@@ -0,0 +1,290 @@
#pragma once
#include "./element.hpp"
#include "../../crypto/blake3s/blake3s.hpp"
#include "../../crypto/keccak/keccak.hpp"

namespace bb::group_elements {
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T>::affine_element(const Fq& x, const Fq& y) noexcept
    : x(x)
    , y(y)
{}

template <class Fq, class Fr, class T>
template <typename BaseField, typename CompileTimeEnabled>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::from_compressed(const uint256_t& compressed) noexcept
{
    uint256_t x_coordinate = compressed;
    x_coordinate.data[3] = x_coordinate.data[3] & (~0x8000000000000000ULL);
    bool y_bit = compressed.get_bit(255);

    Fq x = Fq(x_coordinate);
    Fq y2 = (x.sqr() * x + T::b);
    if constexpr (T::has_a) {
        y2 += (x * T::a);
    }
    auto [is_quadratic_remainder, y] = y2.sqrt();
    if (!is_quadratic_remainder) {
        return affine_element(Fq::zero(), Fq::zero());
    }
    if (uint256_t(y).get_bit(0) != y_bit) {
        y = -y;
    }

    return affine_element<Fq, Fr, T>(x, y);
}

template <class Fq, class Fr, class T>
template <typename BaseField, typename CompileTimeEnabled>
constexpr std::array<affine_element<Fq, Fr, T>, 2> affine_element<Fq, Fr, T>::from_compressed_unsafe(
    const uint256_t& compressed) noexcept
{
    auto get_y_coordinate = [](const uint256_t& x_coordinate) {
        Fq x = Fq(x_coordinate);
        Fq y2 = (x.sqr() * x + T::b);
        if constexpr (T::has_a) {
            y2 += (x * T::a);
        }
        return y2.sqrt();
    };

    uint256_t x_1 = compressed;
    uint256_t x_2 = compressed + Fr::modulus;
    auto [is_quadratic_remainder_1, y_1] = get_y_coordinate(x_1);
    auto [is_quadratic_remainder_2, y_2] = get_y_coordinate(x_2);

    auto output_1 = is_quadratic_remainder_1 ? affine_element<Fq, Fr, T>(Fq(x_1), y_1)
                                             : affine_element<Fq, Fr, T>(Fq::zero(), Fq::zero());
    auto output_2 = is_quadratic_remainder_2 ? affine_element<Fq, Fr, T>(Fq(x_2), y_2)
                                             : affine_element<Fq, Fr, T>(Fq::zero(), Fq::zero());

    return { output_1, output_2 };
}

template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::operator+(
    const affine_element<Fq, Fr, T>& other) const noexcept
{
    return affine_element(element<Fq, Fr, T>(*this) + element<Fq, Fr, T>(other));
}

template <class Fq, class Fr, class T>
template <typename BaseField, typename CompileTimeEnabled>
constexpr uint256_t affine_element<Fq, Fr, T>::compress() const noexcept
{
    uint256_t out(x);
    if (uint256_t(y).get_bit(0)) {
        out.data[3] = out.data[3] | 0x8000000000000000ULL;
    }
    return out;
}

template <class Fq, class Fr, class T> affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::infinity()
{
    affine_element e;
    e.self_set_infinity();
    return e;
}

template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::set_infinity() const noexcept
{
    affine_element result(*this);
    result.self_set_infinity();
    return result;
}

template <class Fq, class Fr, class T> constexpr void affine_element<Fq, Fr, T>::self_set_infinity() noexcept
{
    if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) {
        // We set the value of x equal to the modulus to represent infinity
        x.data[0] = Fq::modulus.data[0];
        x.data[1] = Fq::modulus.data[1];
        x.data[2] = Fq::modulus.data[2];
        x.data[3] = Fq::modulus.data[3];

    } else {
        x.self_set_msb();
    }
}

template <class Fq, class Fr, class T> constexpr bool affine_element<Fq, Fr, T>::is_point_at_infinity() const noexcept
{
    if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) {
        // We check if the value of x equals the modulus, which represents infinity
        return ((x.data[0] ^ Fq::modulus.data[0]) | (x.data[1] ^ Fq::modulus.data[1]) |
                (x.data[2] ^ Fq::modulus.data[2]) | (x.data[3] ^ Fq::modulus.data[3])) == 0;

    } else {
        return (x.is_msb_set());
    }
}

template <class Fq, class Fr, class T> constexpr bool affine_element<Fq, Fr, T>::on_curve() const noexcept
{
    if (is_point_at_infinity()) {
        return true;
    }
    Fq xxx = x.sqr() * x + T::b;
    Fq yy = y.sqr();
    if constexpr (T::has_a) {
        xxx += (x * T::a);
    }
    return (xxx == yy);
}

template <class Fq, class Fr, class T>
constexpr bool affine_element<Fq, Fr, T>::operator==(const affine_element& other) const noexcept
{
    bool this_is_infinity = is_point_at_infinity();
    bool other_is_infinity = other.is_point_at_infinity();
    bool both_infinity = this_is_infinity && other_is_infinity;
    bool only_one_is_infinity = this_is_infinity != other_is_infinity;
    return !only_one_is_infinity && (both_infinity || ((x == other.x) && (y == other.y)));
}

/**
 * Comparison operators (for std::sort)
 *
 * @details CAUTION!! Don't use this operator. It has no meaning other than for use by std::sort.
 **/
template <class Fq, class Fr, class T>
constexpr bool affine_element<Fq, Fr, T>::operator>(const affine_element& other) const noexcept
{
    // We are setting the point at infinity to always be the lowest element
    if (is_point_at_infinity()) {
        return false;
    }
    if (other.is_point_at_infinity()) {
        return true;
    }

    if (x > other.x) {
        return true;
    }
    if (x == other.x && y > other.y) {
        return true;
    }
    return false;
}

template <class Fq, class Fr, class T>
constexpr std::optional<affine_element<Fq, Fr, T>> affine_element<Fq, Fr, T>::derive_from_x_coordinate(
    const Fq& x, bool sign_bit) noexcept
{
    auto yy = x.sqr() * x + T::b;
    if constexpr (T::has_a) {
        yy += (x * T::a);
    }
    auto [found_root, y] = yy.sqrt();

    if (found_root) {
        if (uint256_t(y).get_bit(0) != sign_bit) {
            y = -y;
        }
        return affine_element(x, y);
    }
    return std::nullopt;
}

/**
 * @brief Hash a seed buffer into a point
 *
 * @details ALGORITHM DESCRIPTION:
 *      1. Initialize unsigned integer `attempt_count = 0`
 *      2. Copy seed into a buffer whose size is 2 bytes greater than `seed` (initialized to 0)
 *      3. Interpret `attempt_count` as a byte and write into buffer at [buffer.size() - 2]
 *      4. Compute Blake3s hash of buffer
 *      5. Set the end byte of the buffer to `1`
 *      6. Compute Blake3s hash of buffer
 *      7. Interpret the two hash outputs as the high / low 256 bits of a 512-bit integer (big-endian)
 *      8. Derive x-coordinate of point by reducing the 512-bit integer modulo the curve's field modulus (Fq)
 *      9. Compute y^2 from the curve formula y^2 = x^3 + ax + b (a, b are curve params; for BN254, a = 0, b = 3)
 *      10. IF y^2 IS NOT A QUADRATIC RESIDUE
 *          10a. increment `attempt_count` by 1 and go to step 2
 *      11. IF y^2 IS A QUADRATIC RESIDUE
 *          11a. derive y coordinate via y = sqrt(y^2)
 *          11b. Interpret most significant bit of 512-bit integer as a 'parity' bit
 *          11c. If parity bit is set AND y's most significant bit is not set, invert y
 *          11d. If parity bit is not set AND y's most significant bit is set, invert y
 *          N.B. last 2 steps are because the sqrt() algorithm can return 2 values;
 *          we need a way to canonically distinguish between these 2 values and select a "preferred" one
 *          11e. return (x, y)
 *
 * @note This algorithm is constexpr: we can hash-to-curve (and derive generators) at compile-time!
 * @tparam Fq
 * @tparam Fr
 * @tparam T
 * @param seed Bytes that uniquely define the point being generated
 * @param attempt_count
 * @return constexpr affine_element<Fq, Fr, T>
 */
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::hash_to_curve(const std::vector<uint8_t>& seed,
                                                                             uint8_t attempt_count) noexcept
    requires SupportsHashToCurve<T>
{
    std::vector<uint8_t> target_seed(seed);
    // expand by 2 bytes to cover incremental hash attempts
    const size_t seed_size = seed.size();
    for (size_t i = 0; i < 2; ++i) {
        target_seed.push_back(0);
    }
    target_seed[seed_size] = attempt_count;
    target_seed[seed_size + 1] = 0;
    const auto hash_hi = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
    target_seed[seed_size + 1] = 1;
    const auto hash_lo = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
    // custom serialize methods as common/serialize.hpp is not constexpr!
    const auto read_uint256 = [](const uint8_t* in) {
        const auto read_limb = [](const uint8_t* in, uint64_t& out) {
            for (size_t i = 0; i < 8; ++i) {
                out += static_cast<uint64_t>(in[i]) << ((7 - i) * 8);
            }
        };
        uint256_t out = 0;
        read_limb(&in[0], out.data[3]);
        read_limb(&in[8], out.data[2]);
        read_limb(&in[16], out.data[1]);
        read_limb(&in[24], out.data[0]);
        return out;
    };
    // interpret 64 byte hash output as a uint512_t, reduce to Fq element
    // (512 bits of entropy ensures result is not biased as 512 >> Fq::modulus.get_msb())
    Fq x(uint512_t(read_uint256(&hash_lo[0]), read_uint256(&hash_hi[0])));
    bool sign_bit = hash_hi[0] > 127;
    std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
    if (result.has_value()) {
        return result.value();
    }
    return hash_to_curve(seed, attempt_count + 1);
}

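// Editorial sketch (hypothetical usage, not upstream code): because
// hash_to_curve is constexpr, fixed generators can be derived at compile
// time from a domain-separating seed; distinct seeds yield independent
// points with unknown discrete logs relative to one another:
//
//   // inside a constexpr context:
//   std::vector<uint8_t> seed = { 't', 'e', 's', 't' };
//   auto g = affine_element<Fq, Fr, T>::hash_to_curve(seed);
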
template <typename Fq, typename Fr, typename T>
affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::random_element(numeric::RNG* engine) noexcept
{
    if (engine == nullptr) {
        engine = &numeric::get_randomness();
    }

    Fq x;
    Fq y;
    while (true) {
        // Sample a random x-coordinate and check if it satisfies curve equation.
        x = Fq::random_element(engine);
        // Negate the y-coordinate based on a randomly sampled bit.
        bool sign_bit = (engine->get_random_uint8() & 1) != 0;

        std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);

        if (result.has_value()) {
            return result.value();
        }
    }
    throw_or_abort("affine_element::random_element error");
    return affine_element<Fq, Fr, T>(x, y);
}

} // namespace bb::group_elements
168
sumcheck/src/cuda/includes/barretenberg/ecc/groups/element.hpp
Normal file
@@ -0,0 +1,168 @@
#pragma once

#include "affine_element.hpp"
#include "../../common/compiler_hints.hpp"
#include "../../common/mem.hpp"
#include "../../numeric/random/engine.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include "wnaf.hpp"
#include <array>
#include <random>
#include <vector>

namespace bb::group_elements {

/**
 * @brief element class. Implements ecc group arithmetic using Jacobian coordinates
 * See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l
 *
 * Note: Currently subgroup checks are NOT IMPLEMENTED
 * Our current Plonk implementation uses G1 points that have a cofactor of 1.
 * All G2 points are precomputed (generator [1]_2 and trusted setup point [x]_2).
 * Explicitly assume precomputed points are valid members of the prime-order subgroup for G2.
 * @tparam Fq prime field the curve is defined over
 * @tparam Fr prime field whose characteristic equals the size of the prime-order elliptic curve subgroup
 * @tparam Params curve parameters
 */
template <class Fq, class Fr, class Params> class alignas(32) element {
  public:
    static constexpr Fq curve_b = Params::b;

    element() noexcept = default;

    constexpr element(const Fq& a, const Fq& b, const Fq& c) noexcept;
    constexpr element(const element& other) noexcept;
    constexpr element(element&& other) noexcept;
    constexpr element(const affine_element<Fq, Fr, Params>& other) noexcept;
    constexpr ~element() noexcept = default;

    static constexpr element one() noexcept { return { Params::one_x, Params::one_y, Fq::one() }; };
    static constexpr element zero() noexcept
    {
        element zero;
        zero.self_set_infinity();
        return zero;
    };

    constexpr element& operator=(const element& other) noexcept;
    constexpr element& operator=(element&& other) noexcept;

    constexpr operator affine_element<Fq, Fr, Params>() const noexcept;

    static element random_element(numeric::RNG* engine = nullptr) noexcept;

    constexpr element dbl() const noexcept;
    constexpr void self_dbl() noexcept;
    constexpr void self_mixed_add_or_sub(const affine_element<Fq, Fr, Params>& other, uint64_t predicate) noexcept;

    constexpr element operator+(const element& other) const noexcept;
    constexpr element operator+(const affine_element<Fq, Fr, Params>& other) const noexcept;
    constexpr element operator+=(const element& other) noexcept;
    constexpr element operator+=(const affine_element<Fq, Fr, Params>& other) noexcept;

    constexpr element operator-(const element& other) const noexcept;
    constexpr element operator-(const affine_element<Fq, Fr, Params>& other) const noexcept;
    constexpr element operator-() const noexcept;
    constexpr element operator-=(const element& other) noexcept;
    constexpr element operator-=(const affine_element<Fq, Fr, Params>& other) noexcept;

    friend constexpr element operator+(const affine_element<Fq, Fr, Params>& left, const element& right) noexcept
    {
        return right + left;
    }
    friend constexpr element operator-(const affine_element<Fq, Fr, Params>& left, const element& right) noexcept
    {
        return -right + left;
    }

    element operator*(const Fr& exponent) const noexcept;
    element operator*=(const Fr& exponent) noexcept;

    // If you end up implementing this, congrats, you've solved the DL problem!
    // P.S. This is a joke, don't even attempt! 😂
    // constexpr Fr operator/(const element& other) noexcept {}

    constexpr element normalize() const noexcept;
    static element infinity();
    BB_INLINE constexpr element set_infinity() const noexcept;
    BB_INLINE constexpr void self_set_infinity() noexcept;
    [[nodiscard]] BB_INLINE constexpr bool is_point_at_infinity() const noexcept;
    [[nodiscard]] BB_INLINE constexpr bool on_curve() const noexcept;
    BB_INLINE constexpr bool operator==(const element& other) const noexcept;

    static void batch_normalize(element* elements, size_t num_elements) noexcept;
    static void batch_affine_add(const std::span<affine_element<Fq, Fr, Params>>& first_group,
                                 const std::span<affine_element<Fq, Fr, Params>>& second_group,
                                 const std::span<affine_element<Fq, Fr, Params>>& results) noexcept;
    static std::vector<affine_element<Fq, Fr, Params>> batch_mul_with_endomorphism(
        const std::span<affine_element<Fq, Fr, Params>>& points, const Fr& scalar) noexcept;

    Fq x;
    Fq y;
    Fq z;

  private:
    // For test access to mul_without_endomorphism
    friend class TestElementPrivate;
    element mul_without_endomorphism(const Fr& scalar) const noexcept;
    element mul_with_endomorphism(const Fr& scalar) const noexcept;

    template <typename = typename std::enable_if<Params::can_hash_to_curve>>
    static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept;
    // {
    //     bool found_one = false;
    //     Fq yy;
    //     Fq x;
    //     Fq y;
    //     Fq t0;
    //     while (!found_one) {
    //         x = Fq::random_element(engine);
    //         yy = x.sqr() * x + Params::b;
    //         if constexpr (Params::has_a) {
    //             yy += (x * Params::a);
    //         }
    //         y = yy.sqrt();
    //         t0 = y.sqr();
    //         found_one = (yy == t0);
    //     }
    //     return { x, y, Fq::one() };
    // }
    // for serialization: update with new fields
    // TODO(https://github.com/AztecProtocol/barretenberg/issues/908) point at infinity isn't handled

    static void conditional_negate_affine(const affine_element<Fq, Fr, Params>& in,
                                          affine_element<Fq, Fr, Params>& out,
                                          uint64_t predicate) noexcept;

    friend std::ostream& operator<<(std::ostream& os, const element& a)
    {
        os << "{ " << a.x << ", " << a.y << ", " << a.z << " }";
        return os;
    }
};

template <class Fq, class Fr, class Params> std::ostream& operator<<(std::ostream& os, element<Fq, Fr, Params> const& e)
|
||||
{
|
||||
return os << "x:" << e.x << " y:" << e.y << " z:" << e.z;
|
||||
}
|
||||
|
||||
// constexpr element<Fq, Fr, Params>::one = element<Fq, Fr, Params>{ Params::one_x, Params::one_y, Fq::one() };
|
||||
// constexpr element<Fq, Fr, Params>::point_at_infinity = one.set_infinity();
|
||||
// constexpr element<Fq, Fr, Params>::curve_b = Params::b;
|
||||
} // namespace bb::group_elements
|
||||
|
||||
#include "./element_impl.hpp"
|
||||
|
||||
template <class Fq, class Fr, class Params>
|
||||
bb::group_elements::affine_element<Fq, Fr, Params> operator*(
|
||||
const bb::group_elements::affine_element<Fq, Fr, Params>& base, const Fr& exponent) noexcept
|
||||
{
|
||||
return bb::group_elements::affine_element<Fq, Fr, Params>(bb::group_elements::element(base) * exponent);
|
||||
}
|
||||
|
||||
template <class Fq, class Fr, class Params>
|
||||
bb::group_elements::affine_element<Fq, Fr, Params> operator*(const bb::group_elements::element<Fq, Fr, Params>& base,
|
||||
const Fr& exponent) noexcept
|
||||
{
|
||||
return (bb::group_elements::element(base) * exponent);
|
||||
}
|
||||
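
/**
 * Usage sketch (illustrative only; `Bn254G1` is a hypothetical concrete
 * instantiation exposing `element`, `affine_element`, `Fr` and the generators
 * `one` / `affine_one`, as the `group` class in group.hpp below does):
 * @code
 *   using G1 = Bn254G1;
 *   G1::element P = G1::one;                             // projective generator
 *   G1::element R = P + P - G1::one;                     // projective/mixed arithmetic
 *   G1::affine_element Q = G1::affine_one * G1::Fr(42);  // scalar mul via operator* above
 * @endcode
 */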
1219
sumcheck/src/cuda/includes/barretenberg/ecc/groups/element_impl.hpp
Normal file
File diff suppressed because it is too large
129
sumcheck/src/cuda/includes/barretenberg/ecc/groups/group.hpp
Normal file
@@ -0,0 +1,129 @@
#pragma once

#include "../../common/assert.hpp"
#include "./affine_element.hpp"
#include "./element.hpp"
#include "./wnaf.hpp"
#include "../../common/constexpr_utils.hpp"
#include "../../crypto/blake3s/blake3s.hpp"
#include <array>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
namespace bb {

/**
 * @brief group class. Represents an elliptic curve group element.
 * Group is parametrised by coordinate_field and subgroup_field
 *
 * Note: Currently subgroup checks are NOT IMPLEMENTED
 * Our current Plonk implementation uses G1 points that have a cofactor of 1.
 * All G2 points are precomputed (generator [1]_2 and trusted setup point [x]_2).
 * Explicitly assume precomputed points are valid members of the prime-order subgroup for G2.
 *
 * @tparam coordinate_field
 * @tparam subgroup_field
 * @tparam GroupParams
 */
template <typename _coordinate_field, typename _subgroup_field, typename GroupParams> class group {
  public:
    // hoist coordinate_field, subgroup_field into the public namespace
    using coordinate_field = _coordinate_field;
    using subgroup_field = _subgroup_field;
    using element = group_elements::element<coordinate_field, subgroup_field, GroupParams>;
    using affine_element = group_elements::affine_element<coordinate_field, subgroup_field, GroupParams>;
    using Fq = coordinate_field;
    using Fr = subgroup_field;
    static constexpr bool USE_ENDOMORPHISM = GroupParams::USE_ENDOMORPHISM;
    static constexpr bool has_a = GroupParams::has_a;

    static constexpr element one{ GroupParams::one_x, GroupParams::one_y, coordinate_field::one() };
    static constexpr element point_at_infinity = one.set_infinity();
    static constexpr affine_element affine_one{ GroupParams::one_x, GroupParams::one_y };
    static constexpr affine_element affine_point_at_infinity = affine_one.set_infinity();
    static constexpr coordinate_field curve_a = GroupParams::a;
    static constexpr coordinate_field curve_b = GroupParams::b;

    /**
     * @brief Derives generator points via hash-to-curve
     *
     * ALGORITHM DESCRIPTION:
     *      1. Each generator has an associated "generator index" described by its location in the vector
     *      2. a 64-byte preimage buffer is generated with the following structure:
     *          bytes 0-31: BLAKE3 hash of domain_separator
     *          bytes 32-63: generator index in big-endian form
     *      3. The hash-to-curve algorithm is used to hash the above into a group element:
     *          a. iterate `count` upwards from `0`
     *          b. append `count` to the preimage buffer as a 1-byte integer in big-endian form
     *          c. compute BLAKE3 hash of concat(preimage buffer, 0)
     *          d. compute BLAKE3 hash of concat(preimage buffer, 1)
     *          e. interpret (c, d) as (hi, low) limbs of a 512-bit integer
     *          f. reduce 512-bit integer modulo coordinate_field to produce x-coordinate
     *          g. attempt to derive y-coordinate. If not successful go to step (a) and continue
     *          h. if parity of y-coordinate's least significant bit does not match parity of most significant bit of
     *             (d), invert y-coordinate.
     *          j. return (x, y)
     *
     * NOTE: In step 3b it is sufficient to use 1 byte to store `count`.
     * Each iteration of step 3 has a 50% chance of returning, so the probability of `count` exceeding 256 is 1 in 2^256.
     * NOTE: The domain separator is included to ensure that it is possible to derive independent sets of
     * index-addressable generators.
     * NOTE: we produce 64 bytes of BLAKE3 output when producing the x-coordinate field
     * element, to ensure that the x-coordinate is uniformly randomly distributed in the field. Using a 256-bit input adds
     * significant bias when reducing modulo a ~256-bit coordinate_field
     * NOTE: We ensure y-parity is linked to the preimage
     * hash because there is no canonical deterministic square root algorithm (i.e. if a field element has a square
     * root, there are two of them and `field::sqrt` may return either one)
     * @param num_generators
     * @param domain_separator
     * @return std::vector<affine_element>
     */
    inline static constexpr std::vector<affine_element> derive_generators(
        const std::vector<uint8_t>& domain_separator_bytes,
        const size_t num_generators,
        const size_t starting_index = 0)
    {
        std::vector<affine_element> result;
        const auto domain_hash = blake3::blake3s_constexpr(&domain_separator_bytes[0], domain_separator_bytes.size());
        std::vector<uint8_t> generator_preimage;
        generator_preimage.reserve(64);
        std::copy(domain_hash.begin(), domain_hash.end(), std::back_inserter(generator_preimage));
        for (size_t i = 0; i < 32; ++i) {
            generator_preimage.emplace_back(0);
        }
        for (size_t i = starting_index; i < starting_index + num_generators; ++i) {
            auto generator_index = static_cast<uint32_t>(i);
            uint32_t mask = 0xff;
            generator_preimage[32] = static_cast<uint8_t>(generator_index >> 24);
            generator_preimage[33] = static_cast<uint8_t>((generator_index >> 16) & mask);
            generator_preimage[34] = static_cast<uint8_t>((generator_index >> 8) & mask);
            generator_preimage[35] = static_cast<uint8_t>(generator_index & mask);
            result.push_back(affine_element::hash_to_curve(generator_preimage));
        }
        return result;
    }

    inline static constexpr std::vector<affine_element> derive_generators(const std::string_view& domain_separator,
                                                                          const size_t num_generators,
                                                                          const size_t starting_index = 0)
    {
        std::vector<uint8_t> domain_bytes;
        for (char i : domain_separator) {
            domain_bytes.emplace_back(static_cast<unsigned char>(i));
        }
        return derive_generators(domain_bytes, num_generators, starting_index);
    }
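
    /**
     * Usage sketch (illustrative; `G1` stands for any instantiation of this class):
     * @code
     *   auto gens = G1::derive_generators("pedersen_default", 4);     // indices 0..3
     *   auto more = G1::derive_generators("pedersen_default", 4, 4);  // indices 4..7
     *   // Same domain separator + same index always yields the same point;
     *   // distinct domain separators yield independent generator sets.
     * @endcode
     */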

    BB_INLINE static void conditional_negate_affine(const affine_element* src,
                                                    affine_element* dest,
                                                    uint64_t predicate);
};

} // namespace bb

#ifdef DISABLE_ASM
#include "group_impl_int128.tcc"
#else
#include "group_impl_asm.tcc"
#endif
162
sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_asm.tcc
Normal file
@@ -0,0 +1,162 @@
#pragma once

#ifndef DISABLE_ASM

#include "barretenberg/ecc/groups/group.hpp"
#include <cstdint>

namespace bb {
// copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const affine_element* src, affine_element*
// dest)
// {
//     if constexpr (GroupParams::small_elements) {
// #if defined __AVX__ && defined USE_AVX
//         ASSERT((((uintptr_t)src & 0x1f) == 0));
//         ASSERT((((uintptr_t)dest & 0x1f) == 0));
//         __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t"
//                              "vmovdqa 32(%0), %%ymm1 \n\t"
//                              "vmovdqa %%ymm0, 0(%1) \n\t"
//                              "vmovdqa %%ymm1, 32(%1) \n\t"
//                              :
//                              : "r"(src), "r"(dest)
//                              : "%ymm0", "%ymm1", "memory");
// #else
//         *dest = *src;
// #endif
//     } else {
//         *dest = *src;
//     }
// }

// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const element* src, element* dest)
// {
//     if constexpr (GroupParams::small_elements) {
// #if defined __AVX__ && defined USE_AVX
//         ASSERT((((uintptr_t)src & 0x1f) == 0));
//         ASSERT((((uintptr_t)dest & 0x1f) == 0));
//         __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t"
//                              "vmovdqa 32(%0), %%ymm1 \n\t"
//                              "vmovdqa 64(%0), %%ymm2 \n\t"
//                              "vmovdqa %%ymm0, 0(%1) \n\t"
//                              "vmovdqa %%ymm1, 32(%1) \n\t"
//                              "vmovdqa %%ymm2, 64(%1) \n\t"
//                              :
//                              : "r"(src), "r"(dest)
//                              : "%ymm0", "%ymm1", "%ymm2", "memory");
// #else
//         *dest = *src;
// #endif
//     } else {
//         *dest = *src;
//     }
// }

// copies src into dest, inverting y-coordinate if 'predicate' is true
// n.b. requires src and dest to be aligned on 32 byte boundary
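// Reference semantics (an illustrative plain-C++ sketch of what the asm below
// computes; the negation is performed as twice_modulus - y so the result stays
// in the unreduced range used for field elements, and the select is branchless
// via cmov rather than an if):
//   dest->x = src->x;
//   dest->y = predicate ? twice_modulus - src->y : src->y;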
template <typename coordinate_field, typename subgroup_field, typename GroupParams>
inline void group<coordinate_field, subgroup_field, GroupParams>::conditional_negate_affine(const affine_element* src,
                                                                                            affine_element* dest,
                                                                                            uint64_t predicate)
{
    constexpr uint256_t twice_modulus = coordinate_field::modulus + coordinate_field::modulus;

    constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
    constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
    constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
    constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];

    if constexpr (GroupParams::small_elements) {
#if defined __AVX__ && defined USE_AVX
        ASSERT((((uintptr_t)src & 0x1f) == 0));
        ASSERT((((uintptr_t)dest & 0x1f) == 0));
        __asm__ __volatile__("xorq %%r8, %%r8 \n\t"
                             "movq 32(%0), %%r8 \n\t"
                             "movq 40(%0), %%r9 \n\t"
                             "movq 48(%0), %%r10 \n\t"
                             "movq 56(%0), %%r11 \n\t"
                             "movq %[modulus_0], %%r12 \n\t"
                             "movq %[modulus_1], %%r13 \n\t"
                             "movq %[modulus_2], %%r14 \n\t"
                             "movq %[modulus_3], %%r15 \n\t"
                             "subq %%r8, %%r12 \n\t"
                             "sbbq %%r9, %%r13 \n\t"
                             "sbbq %%r10, %%r14 \n\t"
                             "sbbq %%r11, %%r15 \n\t"
                             "testq %2, %2 \n\t"
                             "cmovnzq %%r12, %%r8 \n\t"
                             "cmovnzq %%r13, %%r9 \n\t"
                             "cmovnzq %%r14, %%r10 \n\t"
                             "cmovnzq %%r15, %%r11 \n\t"
                             "vmovdqa 0(%0), %%ymm0 \n\t"
                             "vmovdqa %%ymm0, 0(%1) \n\t"
                             "movq %%r8, 32(%1) \n\t"
                             "movq %%r9, 40(%1) \n\t"
                             "movq %%r10, 48(%1) \n\t"
                             "movq %%r11, 56(%1) \n\t"
                             :
                             : "r"(src),
                               "r"(dest),
                               "r"(predicate),
                               [modulus_0] "i"(twice_modulus_0),
                               [modulus_1] "i"(twice_modulus_1),
                               [modulus_2] "i"(twice_modulus_2),
                               [modulus_3] "i"(twice_modulus_3)
                             : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "%ymm0", "memory", "cc");
#else
        __asm__ __volatile__("xorq %%r8, %%r8 \n\t"
                             "movq 32(%0), %%r8 \n\t"
                             "movq 40(%0), %%r9 \n\t"
                             "movq 48(%0), %%r10 \n\t"
                             "movq 56(%0), %%r11 \n\t"
                             "movq %[modulus_0], %%r12 \n\t"
                             "movq %[modulus_1], %%r13 \n\t"
                             "movq %[modulus_2], %%r14 \n\t"
                             "movq %[modulus_3], %%r15 \n\t"
                             "subq %%r8, %%r12 \n\t"
                             "sbbq %%r9, %%r13 \n\t"
                             "sbbq %%r10, %%r14 \n\t"
                             "sbbq %%r11, %%r15 \n\t"
                             "testq %2, %2 \n\t"
                             "cmovnzq %%r12, %%r8 \n\t"
                             "cmovnzq %%r13, %%r9 \n\t"
                             "cmovnzq %%r14, %%r10 \n\t"
                             "cmovnzq %%r15, %%r11 \n\t"
                             "movq 0(%0), %%r12 \n\t"
                             "movq 8(%0), %%r13 \n\t"
                             "movq 16(%0), %%r14 \n\t"
                             "movq 24(%0), %%r15 \n\t"
                             "movq %%r8, 32(%1) \n\t"
                             "movq %%r9, 40(%1) \n\t"
                             "movq %%r10, 48(%1) \n\t"
                             "movq %%r11, 56(%1) \n\t"
                             "movq %%r12, 0(%1) \n\t"
                             "movq %%r13, 8(%1) \n\t"
                             "movq %%r14, 16(%1) \n\t"
                             "movq %%r15, 24(%1) \n\t"
                             :
                             : "r"(src),
                               "r"(dest),
                               "r"(predicate),
                               [modulus_0] "i"(twice_modulus_0),
                               [modulus_1] "i"(twice_modulus_1),
                               [modulus_2] "i"(twice_modulus_2),
                               [modulus_3] "i"(twice_modulus_3)
                             : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "memory", "cc");
#endif
    } else {
        if (predicate) { // NOLINT
            coordinate_field::__copy(src->x, dest->x);
            dest->y = -src->y;
        } else {
            copy_affine(*src, *dest);
        }
    }
}

} // namespace bb

#endif
34
sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_int128.tcc
Normal file
@@ -0,0 +1,34 @@
#pragma once

#ifdef DISABLE_ASM

#include "barretenberg/ecc/groups/group.hpp"
#include <cstdint>

namespace bb {

// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const affine_element* src, affine_element*
// dest)
// {
//     *dest = *src;
// }

// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const element* src, element* dest)
// {
//     *dest = *src;
// }

template <typename coordinate_field, typename subgroup_field, typename GroupParams>
inline void group<coordinate_field, subgroup_field, GroupParams>::conditional_negate_affine(const affine_element* src,
                                                                                            affine_element* dest,
                                                                                            uint64_t predicate)
{
    *dest = predicate ? -(*src) : (*src);
}
} // namespace bb

#endif
513
sumcheck/src/cuda/includes/barretenberg/ecc/groups/wnaf.hpp
Normal file
@@ -0,0 +1,513 @@
#pragma once
#include "../../numeric/bitop/get_msb.hpp"
#include <cstdint>
#include <iostream>

// NOLINTBEGIN(readability-implicit-bool-conversion)
namespace bb::wnaf {
constexpr size_t SCALAR_BITS = 127;

#define WNAF_SIZE(x) ((bb::wnaf::SCALAR_BITS + (x)-1) / (x)) // NOLINT(cppcoreguidelines-macro-usage)

constexpr size_t get_optimal_bucket_width(const size_t num_points)
{
    if (num_points >= 14617149) {
        return 21;
    }
    if (num_points >= 1139094) {
        return 18;
    }
    // if (num_points >= 100000)
    if (num_points >= 155975) {
        return 15;
    }
    if (num_points >= 144834)
    // if (num_points >= 100000)
    {
        return 14;
    }
    if (num_points >= 25067) {
        return 12;
    }
    if (num_points >= 13926) {
        return 11;
    }
    if (num_points >= 7659) {
        return 10;
    }
    if (num_points >= 2436) {
        return 9;
    }
    if (num_points >= 376) {
        return 7;
    }
    if (num_points >= 231) {
        return 6;
    }
    if (num_points >= 97) {
        return 5;
    }
    if (num_points >= 35) {
        return 4;
    }
    if (num_points >= 10) {
        return 3;
    }
    if (num_points >= 2) {
        return 2;
    }
    return 1;
}
constexpr size_t get_num_buckets(const size_t num_points)
{
    const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2);
    return 1UL << bits_per_bucket;
}

constexpr size_t get_num_rounds(const size_t num_points)
{
    const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2);
    return WNAF_SIZE(bits_per_bucket + 1);
}
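
// Worked example (illustrative): for an MSM over 2^16 = 65536 points,
// num_points / 2 = 32768 falls in the ">= 25067" band above, so the bucket
// width is 12 bits and the number of rounds is
// get_num_rounds(65536) = WNAF_SIZE(13) = ceil(127 / 13) = 10.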

template <size_t bits, size_t bit_position> inline uint64_t get_wnaf_bits_const(const uint64_t* scalar) noexcept
{
    if constexpr (bits == 0) {
        return 0ULL;
    } else {
        /**
         * we want to take a 128 bit scalar and shift it down by (bit_position).
         * We then wish to mask out `bits` number of bits.
         * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also
         * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual
         * bit position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 -
         * (bit_position & 63)
         *
         * So, step 1:
         * get low limb and shift right by (bit_position & 63)
         * get high limb and shift left by (64 - (bit_position & 63))
         *
         */
        constexpr size_t lo_limb_idx = bit_position / 64;
        constexpr size_t hi_limb_idx = (bit_position + bits - 1) / 64;
        constexpr uint64_t lo_shift = bit_position & 63UL;
        constexpr uint64_t bit_mask = (1UL << static_cast<uint64_t>(bits)) - 1UL;

        uint64_t lo = (scalar[lo_limb_idx] >> lo_shift);
        if constexpr (lo_limb_idx == hi_limb_idx) {
            return lo & bit_mask;
        } else {
            constexpr uint64_t hi_shift = 64UL - (bit_position & 63UL);
            uint64_t hi = ((scalar[hi_limb_idx] << (hi_shift)));
            return (lo | hi) & bit_mask;
        }
    }
}

inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const uint64_t bit_position) noexcept
{
    /**
     * we want to take a 128 bit scalar and shift it down by (bit_position).
     * We then wish to mask out `bits` number of bits.
     * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also
     * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual bit
     * position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - (bit_position
     * & 63)
     *
     * So, step 1:
     * get low limb and shift right by (bit_position & 63)
     * get high limb and shift left by (64 - (bit_position & 63))
     *
     */
    const auto lo_limb_idx = static_cast<size_t>(bit_position >> 6);
    const auto hi_limb_idx = static_cast<size_t>((bit_position + bits - 1) >> 6);
    const uint64_t lo_shift = bit_position & 63UL;
    const uint64_t bit_mask = (1UL << static_cast<uint64_t>(bits)) - 1UL;

    const uint64_t lo = (scalar[lo_limb_idx] >> lo_shift);
    const uint64_t hi_shift = bit_position ? 64UL - (bit_position & 63UL) : 0;
    const uint64_t hi = ((scalar[hi_limb_idx] << (hi_shift)));
    const uint64_t hi_mask = bit_mask & (0ULL - (lo_limb_idx != hi_limb_idx));

    return (lo & bit_mask) | (hi & hi_mask);
}
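
// Worked example (illustrative): scalar = { 0xc000000000000000, 0x1 },
// bits = 5, bit_position = 62:
//   lo_limb_idx = 0, lo_shift = 62  -> lo = 0b11   (top two bits of limb 0)
//   hi_limb_idx = 1, hi_shift = 2   -> hi = 0b100  (limb 1 shifted into place)
//   result = (lo | hi) & 0x1f = 0b00111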

inline void fixed_wnaf_packed(
    const uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const uint64_t point_index, const size_t wnaf_bits) noexcept
{
    skew_map = ((scalar[0] & 1) == 0);
    uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
    const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;

    for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
        uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[(wnaf_entries - round_i)] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index);
        previous = slice + predicate;
    }
    size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1));
    uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
    uint64_t predicate = ((slice & 1UL) == 0UL);

    wnaf[1] = ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
              (point_index);
    wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
}

/**
 * @brief Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication.
 *
 * WNAF is a method for representing integers which minimizes the number of non-zero terms, which in turn minimizes
 * the number of point additions required during scalar multiplication, in turn aiding efficiency.
 *
 * @param scalar Pointer to 128-bit scalar for which WNAF is to be computed.
 * @param wnaf Pointer to num_points+1 size array where the computed WNAF will be stored.
 * @param skew_map Reference to a boolean variable which will be set based on the least significant bit of the scalar.
 * @param point_index The index of the point being computed in the context of multiple point multiplication.
 * @param num_points The number of points being computed in parallel.
 * @param wnaf_bits The number of bits to use in each window of the WNAF representation.
 */
inline void fixed_wnaf(const uint64_t* scalar,
                       uint64_t* wnaf,
                       bool& skew_map,
                       const uint64_t point_index,
                       const uint64_t num_points,
                       const size_t wnaf_bits) noexcept
{
    skew_map = ((scalar[0] & 1) == 0);
    uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
    const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;

    for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
        uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[(wnaf_entries - round_i) * num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index);
        previous = slice + predicate;
    }
    size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1));
    uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
    uint64_t predicate = ((slice & 1UL) == 0UL);

    wnaf[num_points] =
        ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
        (point_index);
    wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
}
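
// Usage sketch (illustrative): computing the strided WNAF table entries for one
// 127-bit scalar in a 2-point MSM; entries land at wnaf[round * num_points + i].
// `scalar_limbs` is a hypothetical 2-limb little-endian scalar below 2^127.
//   uint64_t scalar_limbs[2] = { 0x123456789abcdef1ULL, 0x0fedcba987654321ULL };
//   std::array<uint64_t, WNAF_SIZE(5) * 2> wnaf_table{};
//   bool skew = false;
//   fixed_wnaf(scalar_limbs, &wnaf_table[0], skew, /*point_index=*/0, /*num_points=*/2, /*wnaf_bits=*/5);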

/**
 * Current flow...
 *
 * If a wnaf entry is even, we add +1 to it, and subtract 32 from the previous entry.
 * This works if the previous entry is odd. If we recursively apply this process, starting at the least significant
 * window, this will always be the case.
 *
 * However, we want to skip over windows that are 0, which poses a problem.
 *
 * Scenario 1: even window followed by 0 window followed by any window 'x'
 *
 * We can't add 1 to the even window and subtract 32 from the 0 window, as we don't have a bucket that maps to -32
 * This means that we have to identify whether we are going to borrow 32 from 'x', requiring us to look at least 2
 * steps ahead
 *
 * Scenario 2: <even> <0> <0> <x>
 *
 * This problem proceeds indefinitely - if we have adjacent 0 windows, we do not know whether we need to track a
 * borrow flag until we identify the next non-zero window
 *
 * Scenario 3: <odd> <0>
 *
 * This one works...
 *
 * Ok, so we should be a bit more limited with when we don't include window entries.
 * The goal here is to identify short scalars, so we want to identify the most significant non-zero window
 **/
inline uint64_t get_num_scalar_bits(const uint64_t* scalar)
{
    const uint64_t msb_1 = numeric::get_msb(scalar[1]);
    const uint64_t msb_0 = numeric::get_msb(scalar[0]);

    const uint64_t scalar_1_mask = (0ULL - (scalar[1] > 0));
    const uint64_t scalar_0_mask = (0ULL - (scalar[0] > 0)) & ~scalar_1_mask;

    const uint64_t msb = (scalar_1_mask & (msb_1 + 64)) | (scalar_0_mask & (msb_0));
    return msb;
}

/**
 * How to compute an x-bit wnaf slice?
 *
 * Iterate over number of slices in scalar.
 * For each slice, if slice is even, ADD +1 to current slice and SUBTRACT 2^x from previous slice.
 * (for 1st slice we instead add +1 and set the scalar's 'skew' value to 'true' (i.e. need to subtract 1 from it at the
 * end of our scalar mul algo))
 *
 * In *wnaf we store the following:
 *  1. bits 0-30: ABSOLUTE value of wnaf (i.e. -3 goes to 3)
 *  2. bit 31: 'predicate' bool (i.e. does the wnaf value need to be negated?)
 *  3. bits 32-63: position in a point array that describes the elliptic curve point this wnaf slice is referencing
 *
 * N.B. IN OUR STDLIB ALGORITHMS THE SKEW VALUE REPRESENTS AN ADDITION NOT A SUBTRACTION (i.e. we add +1 at the end of
 * the scalar mul algo we don't sub 1) (this is to eliminate situations which could produce the point at infinity as an
 * output as our circuit logic cannot accommodate this edge case).
 *
 * Credits: Zac W.
 *
 * @param scalar Pointer to the 128-bit non-montgomery scalar that is supposed to be transformed into wnaf
 * @param wnaf Pointer to output array that needs to accommodate enough 64-bit WNAF entries
 * @param skew_map Reference to output skew value, which if true shows that the point should be added once at the end of
 * computation
 * @param wnaf_round_counts Pointer to output array specifying the number of points participating in each round
 * @param point_index The index of the point that should be multiplied by this scalar in the point array
 * @param num_points Total points in the MSM (2*num_initial_points)
 *
 */
inline void fixed_wnaf_with_counts(const uint64_t* scalar,
                                   uint64_t* wnaf,
                                   bool& skew_map,
                                   uint64_t* wnaf_round_counts,
                                   const uint64_t point_index,
                                   const uint64_t num_points,
                                   const size_t wnaf_bits) noexcept
{
    const size_t max_wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
    if ((scalar[0] | scalar[1]) == 0ULL) {
        skew_map = false;
        for (size_t round_i = 0; round_i < max_wnaf_entries; ++round_i) {
            wnaf[(round_i)*num_points] = 0xffffffffffffffffULL;
        }
        return;
    }
    const auto current_scalar_bits = static_cast<size_t>(get_num_scalar_bits(scalar) + 1);
    skew_map = ((scalar[0] & 1) == 0);
    uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
    const auto wnaf_entries = static_cast<size_t>((current_scalar_bits + wnaf_bits - 1) / wnaf_bits);

    if (wnaf_entries == 1) {
        wnaf[(max_wnaf_entries - 1) * num_points] = (previous >> 1UL) | (point_index);
        ++wnaf_round_counts[max_wnaf_entries - 1];
        for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) {
            wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL;
        }
        return;
    }

    // If there are several windows
    for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {

        // Get a bit slice
        uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);

        // Get the predicate (last bit is zero)
        uint64_t predicate = ((slice & 1UL) == 0UL);

        // Update round count
        ++wnaf_round_counts[max_wnaf_entries - round_i];

        // Calculate entry value
        // If the last bit of current slice is 1, we simply put the previous value with the point index
        // If the last bit of the current slice is 0, we negate everything, so that we subtract from the WNAF form and
        // make it 0
        wnaf[(max_wnaf_entries - round_i) * num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index);

        // Update the previous value to the next windows
        previous = slice + predicate;
    }
    // The final iteration for top bits
    auto final_bits = static_cast<size_t>(current_scalar_bits - (wnaf_bits * (wnaf_entries - 1)));
    uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
    uint64_t predicate = ((slice & 1UL) == 0UL);

    ++wnaf_round_counts[(max_wnaf_entries - wnaf_entries + 1)];
    wnaf[((max_wnaf_entries - wnaf_entries + 1) * num_points)] =
        ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
        (point_index);

    // Saving top bits
    ++wnaf_round_counts[max_wnaf_entries - wnaf_entries];
    wnaf[(max_wnaf_entries - wnaf_entries) * num_points] = ((slice + predicate) >> 1UL) | (point_index);

    // Fill all unused slots with -1
    for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) {
        wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL;
    }
}
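
// Decoding sketch (illustrative): unpacking a 64-bit entry `e` written by the
// functions above (when the point index is stored pre-shifted into bits 32-63):
//   uint64_t abs_halved = e & 0x7fffffffULL;    // bits 0-30: |wnaf value| >> 1
//   bool     negate     = ((e >> 31) & 1) != 0; // bit 31: negation predicate
//   uint64_t point_idx  = e >> 32;              // bits 32-63: point index
// The odd window value is (2 * abs_halved + 1), negated when `negate` is set.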

template <size_t num_points, size_t wnaf_bits, size_t round_i>
inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept
{
    constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
    constexpr auto log2_num_points = static_cast<size_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));

    if constexpr (round_i < wnaf_entries - 1) {
        uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[(wnaf_entries - round_i) << log2_num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index << 32UL);
        wnaf_round<num_points, wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
    } else {
        constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits;
        uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
        // uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index << 32UL);
        wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
    }
}

template <size_t scalar_bits, size_t num_points, size_t wnaf_bits, size_t round_i>
inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept
{
    constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits;
    constexpr auto log2_num_points = static_cast<uint64_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));

    if constexpr (round_i < wnaf_entries - 1) {
        uint64_t slice = get_wnaf_bits_const<wnaf_bits, round_i * wnaf_bits>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[(wnaf_entries - round_i) << log2_num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index << 32UL);
        wnaf_round<scalar_bits, num_points, wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
    } else {
        constexpr size_t final_bits = ((scalar_bits / wnaf_bits) * wnaf_bits == scalar_bits)
                                          ? wnaf_bits
                                          : scalar_bits - (scalar_bits / wnaf_bits) * wnaf_bits;
        uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index << 32UL);
        wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
    }
}

template <size_t wnaf_bits, size_t round_i>
inline void wnaf_round_packed(const uint64_t* scalar,
                              uint64_t* wnaf,
                              const uint64_t point_index,
                              const uint64_t previous) noexcept
{
    constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;

    if constexpr (round_i < wnaf_entries - 1) {
        uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
        // uint64_t slice = get_wnaf_bits_const<wnaf_bits, round_i * wnaf_bits>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[(wnaf_entries - round_i)] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index);
        wnaf_round_packed<wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
    } else {
        constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits;
        uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
        // uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[1] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index);

        wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
    }
}

template <size_t num_points, size_t wnaf_bits>
inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept
{
    skew_map = ((scalar[0] & 1) == 0);
    uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + static_cast<uint64_t>(skew_map);
    wnaf_round<num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
}

template <size_t num_bits, size_t num_points, size_t wnaf_bits>
inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept
{
    skew_map = ((scalar[0] & 1) == 0);
    uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + static_cast<uint64_t>(skew_map);
    wnaf_round<num_bits, num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
}

template <size_t scalar_bits, size_t num_points, size_t wnaf_bits, size_t round_i>
inline void wnaf_round_with_restricted_first_slice(uint64_t* scalar,
                                                   uint64_t* wnaf,
                                                   const uint64_t point_index,
                                                   const uint64_t previous) noexcept
{
    constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits;
    constexpr auto log2_num_points = static_cast<uint64_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
    constexpr size_t bits_in_first_slice = scalar_bits % wnaf_bits;
    if constexpr (round_i == 1) {
        uint64_t slice = get_wnaf_bits_const<wnaf_bits, (round_i - 1) * wnaf_bits + bits_in_first_slice>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);

        wnaf[(wnaf_entries - round_i) << log2_num_points] =
            ((((previous - (predicate << (bits_in_first_slice /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) |
             (predicate << 31UL)) |
            (point_index << 32UL);
        if (round_i == 1) {
            std::cerr << "writing value " << std::hex << wnaf[(wnaf_entries - round_i) << log2_num_points] << std::dec
                      << " at index " << ((wnaf_entries - round_i) << log2_num_points) << std::endl;
        }
        wnaf_round_with_restricted_first_slice<scalar_bits, num_points, wnaf_bits, round_i + 1>(
            scalar, wnaf, point_index, slice + predicate);

    } else if constexpr (round_i < wnaf_entries - 1) {
        uint64_t slice = get_wnaf_bits_const<wnaf_bits, (round_i - 1) * wnaf_bits + bits_in_first_slice>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[(wnaf_entries - round_i) << log2_num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index << 32UL);
        wnaf_round_with_restricted_first_slice<scalar_bits, num_points, wnaf_bits, round_i + 1>(
            scalar, wnaf, point_index, slice + predicate);
    } else {
        uint64_t slice = get_wnaf_bits_const<wnaf_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
        uint64_t predicate = ((slice & 1UL) == 0UL);
        wnaf[num_points] =
            ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
            (point_index << 32UL);
        wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
    }
}

template <size_t num_bits, size_t num_points, size_t wnaf_bits>
inline void fixed_wnaf_with_restricted_first_slice(uint64_t* scalar,
                                                   uint64_t* wnaf,
                                                   bool& skew_map,
                                                   const size_t point_index) noexcept
{
    constexpr size_t bits_in_first_slice = num_bits % wnaf_bits;
    std::cerr << "bits in first slice = " << bits_in_first_slice << std::endl;
    skew_map = ((scalar[0] & 1) == 0);
    uint64_t previous = get_wnaf_bits_const<bits_in_first_slice, 0>(scalar) + static_cast<uint64_t>(skew_map);
    std::cerr << "previous = " << previous << std::endl;
    wnaf_round_with_restricted_first_slice<num_bits, num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
}

// template <size_t wnaf_bits>
// inline void fixed_wnaf_packed(const uint64_t* scalar,
//                               uint64_t* wnaf,
//                               bool& skew_map,
//                               const uint64_t point_index) noexcept
// {
//     skew_map = ((scalar[0] & 1) == 0);
//     uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + (uint64_t)skew_map;
//     wnaf_round_packed<wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
// }

// template <size_t wnaf_bits>
// inline constexpr std::array<uint32_t, WNAF_SIZE(wnaf_bits)> fixed_wnaf(const uint64_t *scalar) const noexcept
// {
//     bool skew_map = ((scalar[0] & 1) == 0);
//     uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + (uint64_t)skew_map;
//     std::array<uint32_t, WNAF_SIZE(wnaf_bits)> result;
// }
} // namespace bb::wnaf

// NOLINTEND(readability-implicit-bool-conversion)
9
sumcheck/src/cuda/includes/barretenberg/env/logstr.cpp
vendored
Normal file
@@ -0,0 +1,9 @@
#include <iostream>

extern "C" {

void logstr(char const* str)
{
    std::cerr << str << std::endl;
}
}
1
sumcheck/src/cuda/includes/barretenberg/env/logstr.hpp
vendored
Normal file
@@ -0,0 +1 @@
void logstr(char const*);
@@ -0,0 +1,17 @@
#include "count_leading_zeros.hpp"
#include <benchmark/benchmark.h>

using namespace benchmark;

void count_leading_zeros(State& state) noexcept
{
    uint256_t input = 7;
    for (auto _ : state) {
        auto r = count_leading_zeros(input);
        DoNotOptimize(r);
    }
}
BENCHMARK(count_leading_zeros);

// NOLINTNEXTLINE macro invocation triggers style errors from googletest code
BENCHMARK_MAIN();
@@ -0,0 +1,52 @@
#pragma once
#include "../uint128/uint128.hpp"
#include "../uint256/uint256.hpp"
#include <cstdint>

namespace bb::numeric {

/**
 * Returns the number of leading 0 bits for a given integer type.
 * Implemented in terms of intrinsics which will use instructions such as `bsr` or `lzcnt` for best performance.
 * Undefined behavior when input is 0.
 */
template <typename T> constexpr inline size_t count_leading_zeros(T const& u);

template <> constexpr inline size_t count_leading_zeros<uint32_t>(uint32_t const& u)
{
    return static_cast<size_t>(__builtin_clz(u));
}

template <> constexpr inline size_t count_leading_zeros<uint64_t>(uint64_t const& u)
{
    return static_cast<size_t>(__builtin_clzll(u));
}

template <> constexpr inline size_t count_leading_zeros<uint128_t>(uint128_t const& u)
{
    auto hi = static_cast<uint64_t>(u >> 64);
    if (hi != 0U) {
        return static_cast<size_t>(__builtin_clzll(hi));
    }
    auto lo = static_cast<uint64_t>(u);
    return static_cast<size_t>(__builtin_clzll(lo)) + 64;
}

template <> constexpr inline size_t count_leading_zeros<uint256_t>(uint256_t const& u)
{
    if (u.data[3] != 0U) {
        return count_leading_zeros(u.data[3]);
    }
    if (u.data[2] != 0U) {
        return count_leading_zeros(u.data[2]) + 64;
    }
    if (u.data[1] != 0U) {
        return count_leading_zeros(u.data[1]) + 128;
    }
    if (u.data[0] != 0U) {
        return count_leading_zeros(u.data[0]) + 192;
    }
    return 256;
}

} // namespace bb::numeric
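
// Usage sketch (illustrative; the intrinsics are constexpr-friendly on
// gcc/clang, so these checks can run at compile time):
//   static_assert(bb::numeric::count_leading_zeros(uint32_t(0x80000000)) == 0);
//   static_assert(bb::numeric::count_leading_zeros(uint64_t(1)) == 63);
//   // For uint256_t, each all-zero high limb contributes 64 to the count.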
@@ -0,0 +1,36 @@
#include "count_leading_zeros.hpp"
#include <gtest/gtest.h>

using namespace bb;

TEST(bitop, ClzUint3231)
{
    uint32_t a = 0b00000000000000000000000000000001;
    EXPECT_EQ(numeric::count_leading_zeros(a), 31U);
}

TEST(bitop, ClzUint320)
{
    uint32_t a = 0b10000000000000000000000000000001;
    EXPECT_EQ(numeric::count_leading_zeros(a), 0U);
}

TEST(bitop, ClzUint640)
{
    uint64_t a = 0b1000000000000000000000000000000100000000000000000000000000000000;
    EXPECT_EQ(numeric::count_leading_zeros(a), 0U);
}

TEST(bitop, ClzUint256255)
{
    uint256_t a = 0x1;
    auto r = numeric::count_leading_zeros(a);
    EXPECT_EQ(r, 255U);
}

TEST(bitop, ClzUint256248)
{
    uint256_t a = 0x80;
    auto r = numeric::count_leading_zeros(a);
    EXPECT_EQ(r, 248U);
}
@@ -0,0 +1,46 @@
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
namespace bb::numeric {

// from http://supertech.csail.mit.edu/papers/debruijn.pdf
constexpr inline uint32_t get_msb32(const uint32_t in)
{
    constexpr std::array<uint8_t, 32> MultiplyDeBruijnBitPosition{ 0,  9,  1,  10, 13, 21, 2,  29, 11, 14, 16,
                                                                   18, 22, 25, 3,  30, 8,  12, 20, 28, 15, 17,
                                                                   24, 7,  19, 27, 23, 6,  26, 5,  4,  31 };

    uint32_t v = in | (in >> 1);
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;

    return MultiplyDeBruijnBitPosition[static_cast<uint32_t>(v * static_cast<uint32_t>(0x07C4ACDD)) >>
                                       static_cast<uint32_t>(27)];
}

constexpr inline uint64_t get_msb64(const uint64_t in)
{
    constexpr std::array<uint8_t, 64> de_bruijn_sequence{ 0,  47, 1,  56, 48, 27, 2,  60, 57, 49, 41, 37, 28,
                                                          16, 3,  61, 54, 58, 35, 52, 50, 42, 21, 44, 38, 32,
                                                          29, 23, 17, 11, 4,  62, 46, 55, 26, 59, 40, 36, 15,
                                                          53, 34, 51, 20, 43, 31, 22, 10, 45, 25, 39, 14, 33,
                                                          19, 30, 9,  24, 13, 18, 8,  12, 7,  6,  5,  63 };

    uint64_t t = in | (in >> 1);
    t |= t >> 2;
    t |= t >> 4;
    t |= t >> 8;
    t |= t >> 16;
    t |= t >> 32;
    return static_cast<uint64_t>(de_bruijn_sequence[(t * 0x03F79D71B4CB0A89ULL) >> 58ULL]);
}

template <typename T> constexpr inline T get_msb(const T in)
{
    return (sizeof(T) <= 4) ? get_msb32(in) : get_msb64(in);
}

} // namespace bb::numeric
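
// Usage sketch (illustrative): get_msb returns the index of the highest set
// bit, e.g. get_msb(uint32_t(0x80)) == 7 and get_msb(uint64_t(1) << 40) == 40.
// The shift cascade smears the top bit into all lower positions, and the
// de Bruijn multiply-and-shift then maps each distinct smeared pattern to a
// unique slot in the lookup table.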
@@ -0,0 +1,35 @@
#include "get_msb.hpp"
#include <gtest/gtest.h>

using namespace bb;

TEST(bitop, GetMsbUint640Value)
{
    uint64_t a = 0b00000000000000000000000000000000;
    EXPECT_EQ(numeric::get_msb(a), 0U);
}

TEST(bitop, GetMsbUint320)
{
    uint32_t a = 0b00000000000000000000000000000001;
    EXPECT_EQ(numeric::get_msb(a), 0U);
}

TEST(bitop, GetMsbUint3231)
{
    uint32_t a = 0b10000000000000000000000000000001;
    EXPECT_EQ(numeric::get_msb(a), 31U);
}

TEST(bitop, GetMsbUint6463)
{
    uint64_t a = 0b1000000000000000000000000000000100000000000000000000000000000000;
    EXPECT_EQ(numeric::get_msb(a), 63U);
}

TEST(bitop, GetMsbSizeT7)
{
    size_t a = 0x80;
    auto r = numeric::get_msb(a);
    EXPECT_EQ(r, 7U);
}
@@ -0,0 +1,11 @@
#pragma once
#include <cstddef>

namespace bb::numeric {

template <typename T> inline T keep_n_lsb(T const& input, size_t num_bits)
{
    return num_bits >= sizeof(T) * 8 ? input : input & ((T(1) << num_bits) - 1);
}

} // namespace bb::numeric
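
// e.g. (illustrative) keep_n_lsb(uint32_t(0xabcd), 8) == 0xcd; when num_bits
// covers the full width of T the input is returned unchanged, which also avoids
// the undefined behavior of shifting by >= the bit width.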
@@ -0,0 +1,34 @@
#pragma once

#include "./get_msb.hpp"
#include <cstdint>

namespace bb::numeric {
constexpr uint64_t pow64(const uint64_t input, const uint64_t exponent)
{
    if (input == 0) {
        return 0;
    }
    if (exponent == 0) {
        return 1;
    }

    uint64_t accumulator = input;
    uint64_t to_mul = input;
    const uint64_t maximum_set_bit = get_msb64(exponent);

    for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
        accumulator *= accumulator;
        if (((exponent >> i) & 1) != 0U) {
            accumulator *= to_mul;
        }
    }
    return accumulator;
}
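
// Worked example (illustrative) of the square-and-multiply loop above,
// pow64(3, 5): msb(5) = 2, accumulator starts at 3;
//   i = 1: square -> 9;  bit 1 of 5 is 0, no multiply
//   i = 0: square -> 81; bit 0 of 5 is 1, multiply -> 243 = 3^5
// (all arithmetic wraps modulo 2^64)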

constexpr bool is_power_of_two(uint64_t x)
{
    return (x != 0U) && ((x & (x - 1)) == 0U);
}

} // namespace bb::numeric
@@ -0,0 +1,16 @@
#pragma once
#include <cstddef>
#include <cstdint>

namespace bb::numeric {

constexpr inline uint64_t rotate64(const uint64_t value, const uint64_t rotation)
{
    return rotation != 0U ? (value >> rotation) + (value << (64 - rotation)) : value;
}

constexpr inline uint32_t rotate32(const uint32_t value, const uint32_t rotation)
{
    return rotation != 0U ? (value >> rotation) + (value << (32 - rotation)) : value;
}
} // namespace bb::numeric
@@ -0,0 +1,157 @@
#pragma once
#include "../../common/throw_or_abort.hpp"
#include <array> // std::array (used by get_base_powers below)
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

#include "../uint256/uint256.hpp"

namespace bb::numeric {

inline std::vector<uint64_t> slice_input(const uint256_t& input, const uint64_t base, const size_t num_slices)
{
    uint256_t target = input;
    std::vector<uint64_t> slices;
    if (num_slices > 0) {
        for (size_t i = 0; i < num_slices; ++i) {
            slices.push_back((target % base).data[0]);
            target /= base;
        }
    } else {
        while (target > 0) {
            slices.push_back((target % base).data[0]);
            target /= base;
        }
    }
    return slices;
}

inline std::vector<uint64_t> slice_input_using_variable_bases(const uint256_t& input,
                                                              const std::vector<uint64_t>& bases)
{
    uint256_t target = input;
    std::vector<uint64_t> slices;
    for (size_t i = 0; i < bases.size(); ++i) {
        if (target >= bases[i] && i == bases.size() - 1) {
            throw_or_abort(format("Last key slice greater than ", bases[i]));
        }
        slices.push_back((target % bases[i]).data[0]);
        target /= bases[i];
    }
    return slices;
}

template <uint64_t base, uint64_t num_slices> constexpr std::array<uint256_t, num_slices> get_base_powers()
{
    std::array<uint256_t, num_slices> output{};
    output[0] = 1;
    for (size_t i = 1; i < num_slices; ++i) {
        output[i] = output[i - 1] * base;
    }
    return output;
}

template <uint64_t base> constexpr uint256_t map_into_sparse_form(const uint64_t input)
{
    uint256_t out = 0UL;
    auto converted = input;

    constexpr auto base_powers = get_base_powers<base, 32>();
    for (size_t i = 0; i < 32; ++i) {
        uint64_t sparse_bit = ((converted >> i) & 1U);
        if (sparse_bit) {
            out += base_powers[i];
        }
    }
    return out;
}
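
// Worked example (illustrative): with base = 4, map_into_sparse_form<4>(0b1011)
// = 4^0 + 4^1 + 4^3 = 1 + 4 + 64 = 69. Each binary bit i of the input becomes
// the base-4 digit at position i, so digits start out as 0 or 1 and several
// sparse values can be added before any digit reaches the base and "carries".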

template <uint64_t base> constexpr uint64_t map_from_sparse_form(const uint256_t& input)
{
    uint256_t target = input;
    uint64_t output = 0;

    constexpr auto bases = get_base_powers<base, 32>();

    for (uint64_t i = 0; i < 32; ++i) {
        const auto& base_power = bases[static_cast<size_t>(31 - i)];
        uint256_t prev_threshold = 0;
        for (uint64_t j = 1; j < base + 1; ++j) {
            const auto threshold = prev_threshold + base_power;
            if (target < threshold) {
                bool bit = ((j - 1) & 1);
                if (bit) {
                    output += (1ULL << (31ULL - i));
                }
                if (j > 1) {
                    target -= (prev_threshold);
                }
                break;
            }
            prev_threshold = threshold;
        }
    }

    return output;
}

template <uint64_t base, size_t num_bits> class sparse_int {
  public:
    sparse_int(const uint64_t input = 0)
        : value(input)
    {
        for (size_t i = 0; i < num_bits; ++i) {
            const uint64_t bit = (input >> i) & 1U;
            limbs[i] = bit;
        }
    }
    sparse_int(const sparse_int& other) noexcept = default;
    sparse_int(sparse_int&& other) noexcept = default;
    sparse_int& operator=(const sparse_int& other) noexcept = default;
    sparse_int& operator=(sparse_int&& other) noexcept = default;
    ~sparse_int() noexcept = default;

    sparse_int operator+(const sparse_int& other) const
    {
        sparse_int result(*this);
        for (size_t i = 0; i < num_bits - 1; ++i) {
            result.limbs[i] += other.limbs[i];
            if (result.limbs[i] >= base) {
                result.limbs[i] -= base;
                ++result.limbs[i + 1];
            }
        }
        result.limbs[num_bits - 1] += other.limbs[num_bits - 1];
        result.limbs[num_bits - 1] %= base;
        result.value += other.value;
        return result;
    };

    sparse_int operator+=(const sparse_int& other)
    {
        *this = *this + other;
        return *this;
    }

    [[nodiscard]] uint64_t get_value() const { return value; }

    [[nodiscard]] uint64_t get_sparse_value() const
    {
        uint64_t result = 0;
        for (size_t i = num_bits - 1; i < num_bits; --i) {
            result *= base;
            result += limbs[i];
        }
        return result;
    }

    const std::array<uint64_t, num_bits>& get_limbs() const { return limbs; }

  private:
    std::array<uint64_t, num_bits> limbs;
    uint64_t value;
    uint64_t sparse_value;
};

} // namespace bb::numeric
@@ -0,0 +1,139 @@
|
||||
#include "engine.hpp"
|
||||
#include "../../common/assert.hpp"
|
||||
#include <array>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
|
||||
namespace bb::numeric {

namespace {
auto generate_random_data()
{
    std::array<unsigned int, 32> random_data;
    std::random_device source;
    std::generate(std::begin(random_data), std::end(random_data), std::ref(source));
    return random_data;
}
} // namespace

class RandomEngine : public RNG {
  public:
    uint8_t get_random_uint8() override
    {
        auto buf = generate_random_data();
        uint32_t out = buf[0];
        return static_cast<uint8_t>(out);
    }

    uint16_t get_random_uint16() override
    {
        auto buf = generate_random_data();
        uint32_t out = buf[0];
        return static_cast<uint16_t>(out);
    }

    uint32_t get_random_uint32() override
    {
        auto buf = generate_random_data();
        uint32_t out = buf[0];
        return static_cast<uint32_t>(out);
    }

    uint64_t get_random_uint64() override
    {
        auto buf = generate_random_data();
        auto lo = static_cast<uint64_t>(buf[0]);
        auto hi = static_cast<uint64_t>(buf[1]);
        return (lo + (hi << 32ULL));
    }

    uint128_t get_random_uint128() override
    {
        auto big = get_random_uint256();
        auto lo = static_cast<uint128_t>(big.data[0]);
        auto hi = static_cast<uint128_t>(big.data[1]);
        return (lo + (hi << static_cast<uint128_t>(64ULL)));
    }

    uint256_t get_random_uint256() override
    {
        const auto get64 = [](const std::array<uint32_t, 32>& buffer, const size_t offset) {
            auto lo = static_cast<uint64_t>(buffer[0 + offset]);
            auto hi = static_cast<uint64_t>(buffer[1 + offset]);
            return (lo + (hi << 32ULL));
        };
        auto buf = generate_random_data();
        uint64_t lolo = get64(buf, 0);
        uint64_t lohi = get64(buf, 2);
        uint64_t hilo = get64(buf, 4);
        uint64_t hihi = get64(buf, 6);
        return { lolo, lohi, hilo, hihi };
    }
};

class DebugEngine : public RNG {
  public:
    DebugEngine()
        // disable linting for this line: we want the DEBUG engine to produce predictable pseudorandom numbers!
        // NOLINTNEXTLINE(cert-msc32-c, cert-msc51-cpp)
        : engine(std::mt19937_64(12345))
    {}

    DebugEngine(std::uint_fast64_t seed)
        : engine(std::mt19937_64(seed))
    {}

    uint8_t get_random_uint8() override { return static_cast<uint8_t>(dist(engine)); }

    uint16_t get_random_uint16() override { return static_cast<uint16_t>(dist(engine)); }

    uint32_t get_random_uint32() override { return static_cast<uint32_t>(dist(engine)); }

    uint64_t get_random_uint64() override { return dist(engine); }

    uint128_t get_random_uint128() override
    {
        uint128_t hi = dist(engine);
        uint128_t lo = dist(engine);
        return (hi << 64) | lo;
    }

    uint256_t get_random_uint256() override
    {
        // Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
        auto a = dist(engine);
        auto b = dist(engine);
        auto c = dist(engine);
        auto d = dist(engine);
        return { a, b, c, d };
    }

  private:
    std::mt19937_64 engine;
    std::uniform_int_distribution<uint64_t> dist = std::uniform_int_distribution<uint64_t>{ 0ULL, UINT64_MAX };
};

/**
 * Used by tests to ensure consistent behavior.
 */
RNG& get_debug_randomness(bool reset, std::uint_fast64_t seed)
{
    // static std::seed_seq seed({ 1, 2, 3, 4, 5 });
    static DebugEngine debug_engine = DebugEngine();
    if (reset) {
        debug_engine = DebugEngine(seed);
    }
    return debug_engine;
}
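
// Illustrative usage (added note, not part of the original source): tests can
// pin the stream by resetting the debug engine with a fixed seed, e.g.
//   auto& rng = get_debug_randomness(/*reset=*/true, /*seed=*/42);
//   uint256_t sample = rng.get_random_uint256();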

/**
 * Default engine. For consistent proof construction, uncomment the line below to return the debug engine.
 */
RNG& get_randomness()
{
    // return get_debug_randomness();
    static RandomEngine engine;
    return engine;
}

} // namespace bb::numeric
@@ -0,0 +1,52 @@
#pragma once
#include "../uint128/uint128.hpp"
#include "../uint256/uint256.hpp"
#include "../uintx/uintx.hpp"
#include "unistd.h"
#include <cstdint>
#include <random>

namespace bb::numeric {

class RNG {
  public:
    virtual uint8_t get_random_uint8() = 0;

    virtual uint16_t get_random_uint16() = 0;

    virtual uint32_t get_random_uint32() = 0;

    virtual uint64_t get_random_uint64() = 0;

    virtual uint128_t get_random_uint128() = 0;

    virtual uint256_t get_random_uint256() = 0;

    virtual ~RNG() = default;
    RNG() noexcept = default;
    RNG(const RNG& other) = default;
    RNG(RNG&& other) = default;
    RNG& operator=(const RNG& other) = default;
    RNG& operator=(RNG&& other) = default;

    uint512_t get_random_uint512()
    {
        // Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
        auto lo = get_random_uint256();
        auto hi = get_random_uint256();
        return { lo, hi };
    }
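
    // Added note (not part of the original source): with a parenthesised call
    // such as uint512_t(get_random_uint256(), get_random_uint256()) the
    // argument evaluation order is unspecified, and some compilers have also
    // mis-sequenced braced initialisers, so the lo/hi halves could swap
    // between toolchains. Naming the temporaries pins the order.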

    uint1024_t get_random_uint1024()
    {
        // Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
        auto lo = get_random_uint512();
        auto hi = get_random_uint512();
        return { lo, hi };
    }
};

RNG& get_debug_randomness(bool reset = false, std::uint_fast64_t seed = 12345);
RNG& get_randomness();

} // namespace bb::numeric
@@ -0,0 +1,212 @@
#pragma once
#include <cstdint>
#include <iomanip>
#include <ostream>

#ifdef __i386__
#include "../../common/serialize.hpp"
#include <concepts>

namespace bb::numeric {

class alignas(32) uint128_t {
  public:
    uint32_t data[4]; // NOLINT

    constexpr uint128_t(const uint64_t a = 0)
        : data{ static_cast<uint32_t>(a), static_cast<uint32_t>(a >> 32), 0, 0 }
    {}

    constexpr uint128_t(const uint32_t a, const uint32_t b, const uint32_t c, const uint32_t d)
        : data{ a, b, c, d }
    {}

    constexpr uint128_t(const uint128_t& other)
        : data{ other.data[0], other.data[1], other.data[2], other.data[3] }
    {}
    constexpr uint128_t(uint128_t&& other) = default;

    static constexpr uint128_t from_uint64(const uint64_t a)
    {
        return { static_cast<uint32_t>(a), static_cast<uint32_t>(a >> 32), 0, 0 };
    }

    constexpr explicit operator uint64_t() { return (static_cast<uint64_t>(data[1]) << 32) + data[0]; }

    constexpr uint128_t& operator=(const uint128_t& other) = default;
    constexpr uint128_t& operator=(uint128_t&& other) = default;
    constexpr ~uint128_t() = default;
    explicit constexpr operator bool() const { return static_cast<bool>(data[0]); };

    template <std::integral T> explicit constexpr operator T() const { return static_cast<T>(data[0]); };

    [[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const;
    [[nodiscard]] constexpr uint64_t get_msb() const;

    [[nodiscard]] constexpr uint128_t slice(uint64_t start, uint64_t end) const;
    [[nodiscard]] constexpr uint128_t pow(const uint128_t& exponent) const;

    constexpr uint128_t operator+(const uint128_t& other) const;
    constexpr uint128_t operator-(const uint128_t& other) const;
    constexpr uint128_t operator-() const;

    constexpr uint128_t operator*(const uint128_t& other) const;
    constexpr uint128_t operator/(const uint128_t& other) const;
    constexpr uint128_t operator%(const uint128_t& other) const;

    constexpr uint128_t operator>>(const uint128_t& other) const;
    constexpr uint128_t operator<<(const uint128_t& other) const;

    constexpr uint128_t operator&(const uint128_t& other) const;
    constexpr uint128_t operator^(const uint128_t& other) const;
    constexpr uint128_t operator|(const uint128_t& other) const;
    constexpr uint128_t operator~() const;

    constexpr bool operator==(const uint128_t& other) const;
    constexpr bool operator!=(const uint128_t& other) const;
    constexpr bool operator!() const;

    constexpr bool operator>(const uint128_t& other) const;
    constexpr bool operator<(const uint128_t& other) const;
    constexpr bool operator>=(const uint128_t& other) const;
    constexpr bool operator<=(const uint128_t& other) const;

    static constexpr size_t length() { return 128; }

    constexpr uint128_t& operator+=(const uint128_t& other)
    {
        *this = *this + other;
        return *this;
    };
    constexpr uint128_t& operator-=(const uint128_t& other)
    {
        *this = *this - other;
        return *this;
    };
    constexpr uint128_t& operator*=(const uint128_t& other)
    {
        *this = *this * other;
        return *this;
    };
    constexpr uint128_t& operator/=(const uint128_t& other)
    {
        *this = *this / other;
        return *this;
    };
    constexpr uint128_t& operator%=(const uint128_t& other)
    {
        *this = *this % other;
        return *this;
    };

    constexpr uint128_t& operator++()
    {
        *this += uint128_t(1);
        return *this;
    };
    constexpr uint128_t& operator--()
    {
        *this -= uint128_t(1);
        return *this;
    };

    constexpr uint128_t& operator&=(const uint128_t& other)
    {
        *this = *this & other;
        return *this;
    };
    constexpr uint128_t& operator^=(const uint128_t& other)
    {
        *this = *this ^ other;
        return *this;
    };
    constexpr uint128_t& operator|=(const uint128_t& other)
    {
        *this = *this | other;
        return *this;
    };

    constexpr uint128_t& operator>>=(const uint128_t& other)
    {
        *this = *this >> other;
        return *this;
    };
    constexpr uint128_t& operator<<=(const uint128_t& other)
    {
        *this = *this << other;
        return *this;
    };

    [[nodiscard]] constexpr std::pair<uint128_t, uint128_t> mul_extended(const uint128_t& other) const;

    [[nodiscard]] constexpr std::pair<uint128_t, uint128_t> divmod(const uint128_t& b) const;

  private:
    [[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> mul_wide(uint32_t a, uint32_t b);
    [[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> addc(uint32_t a, uint32_t b, uint32_t carry_in);
    [[nodiscard]] static constexpr uint32_t addc_discard_hi(uint32_t a, uint32_t b, uint32_t carry_in);
    [[nodiscard]] static constexpr uint32_t sbb_discard_hi(uint32_t a, uint32_t b, uint32_t borrow_in);

    [[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> sbb(uint32_t a, uint32_t b, uint32_t borrow_in);
    [[nodiscard]] static constexpr uint32_t mac_discard_hi(uint32_t a, uint32_t b, uint32_t c, uint32_t carry_in);
    [[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> mac(uint32_t a,
                                                                     uint32_t b,
                                                                     uint32_t c,
                                                                     uint32_t carry_in);
};
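
// Added note (not part of the original source): limbs are little-endian, so
// data[0] holds bits 0..31 and data[3] holds bits 96..127. For example,
// uint128_t(0x0123456789abcdefULL) has data[0] == 0x89abcdef and
// data[1] == 0x01234567, with data[2] == data[3] == 0.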

inline std::ostream& operator<<(std::ostream& os, uint128_t const& a)
{
    std::ios_base::fmtflags f(os.flags());
    os << std::hex << "0x" << std::setfill('0') << std::setw(8) << a.data[3] << std::setw(8) << a.data[2]
       << std::setw(8) << a.data[1] << std::setw(8) << a.data[0];
    os.flags(f);
    return os;
}

template <typename B> inline void read(B& it, uint128_t& value)
{
    using serialize::read;
    uint32_t a = 0;
    uint32_t b = 0;
    uint32_t c = 0;
    uint32_t d = 0;
    read(it, d);
    read(it, c);
    read(it, b);
    read(it, a);
    value = uint128_t(a, b, c, d);
}

template <typename B> inline void write(B& it, uint128_t const& value)
{
    using serialize::write;
    write(it, value.data[3]);
    write(it, value.data[2]);
    write(it, value.data[1]);
    write(it, value.data[0]);
}

} // namespace bb::numeric

#include "./uint128_impl.hpp"

// disable linter errors; we want to expose a global uint128_t type to mimic uint64_t, uint32_t etc
// NOLINTNEXTLINE(tidymisc-unused-using-decls, google-global-names-in-headers, misc-unused-using-decls)
using bb::numeric::uint128_t;
#else
__extension__ using uint128_t = unsigned __int128;

namespace std {
// can ignore linter error for streaming operations, we need to add to std namespace to support printing this type!
// NOLINTNEXTLINE(cert-dcl58-cpp)
inline std::ostream& operator<<(std::ostream& os, uint128_t const& a)
{
    std::ios_base::fmtflags f(os.flags());
    os << std::hex << "0x" << std::setfill('0') << std::setw(16) << static_cast<uint64_t>(a >> 64) << std::setw(16)
       << static_cast<uint64_t>(a);
    os.flags(f);
    return os;
}
} // namespace std
#endif
@@ -0,0 +1,414 @@
#ifdef __i386__
#pragma once
#include "../bitop/get_msb.hpp"
#include "./uint128.hpp"
#include "../../common/assert.hpp"
namespace bb::numeric {

constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
{
    const uint32_t a_lo = a & 0xffffULL;
    const uint32_t a_hi = a >> 16ULL;
    const uint32_t b_lo = b & 0xffffULL;
    const uint32_t b_hi = b >> 16ULL;

    const uint32_t lo_lo = a_lo * b_lo;
    const uint32_t hi_lo = a_hi * b_lo;
    const uint32_t lo_hi = a_lo * b_hi;
    const uint32_t hi_hi = a_hi * b_hi;

    const uint32_t cross = (lo_lo >> 16) + (hi_lo & 0xffffULL) + lo_hi;

    return { (cross << 16ULL) | (lo_lo & 0xffffULL), (hi_lo >> 16ULL) + (cross >> 16ULL) + hi_hi };
}
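
// Added sketch (not part of the original source): mul_wide computes the full
// 64-bit product of two 32-bit limbs using 16-bit halves, returning
// { low 32 bits, high 32 bits }. E.g. mul_wide(0x10000, 0x10000) is 2^32,
// i.e. { 0, 1 }.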

// compute a + b + carry, returning the carry
constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
{
    const uint32_t sum = a + b;
    const auto carry_temp = static_cast<uint32_t>(sum < a);
    const uint32_t r = sum + carry_in;
    const uint32_t carry_out = carry_temp + static_cast<unsigned int>(r < carry_in);
    return { r, carry_out };
}

constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
{
    return a + b + carry_in;
}

constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
{
    const uint32_t t_1 = a - (borrow_in >> 31ULL);
    const auto borrow_temp_1 = static_cast<uint32_t>(t_1 > a);
    const uint32_t t_2 = t_1 - b;
    const auto borrow_temp_2 = static_cast<uint32_t>(t_2 > t_1);

    return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
}

constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
{
    return a - b - (borrow_in >> 31ULL);
}

// {r, carry_out} = a + carry_in + b * c
constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
                                                       const uint32_t b,
                                                       const uint32_t c,
                                                       const uint32_t carry_in)
{
    std::pair<uint32_t, uint32_t> result = mul_wide(b, c);
    result.first += a;
    const auto overflow_c = static_cast<uint32_t>(result.first < a);
    result.first += carry_in;
    const auto overflow_carry = static_cast<uint32_t>(result.first < carry_in);
    result.second += (overflow_c + overflow_carry);
    return result;
}

constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
                                             const uint32_t b,
                                             const uint32_t c,
                                             const uint32_t carry_in)
{
    return (b * c + a + carry_in);
}

constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
{
    if (*this == 0 || b == 0) {
        return { 0, 0 };
    }
    if (b == 1) {
        return { *this, 0 };
    }
    if (*this == b) {
        return { 1, 0 };
    }
    if (b > *this) {
        return { 0, *this };
    }

    uint128_t quotient = 0;
    uint128_t remainder = *this;

    uint64_t bit_difference = get_msb() - b.get_msb();

    uint128_t divisor = b << bit_difference;
    uint128_t accumulator = uint128_t(1) << bit_difference;

    // if the divisor is bigger than the remainder, a and b have the same bit length
    if (divisor > remainder) {
        divisor >>= 1;
        accumulator >>= 1;
    }

    // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
    // and add to the quotient
    while (remainder >= b) {

        // we've shunted 'divisor' up to have the same bit length as our remainder.
        // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
        if (remainder >= divisor) {
            remainder -= divisor;
            // we can use OR here instead of +, as
            // accumulator is always a nice power of two
            quotient |= accumulator;
        }
        divisor >>= 1;
        accumulator >>= 1;
    }

    return { quotient, remainder };
}
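
// Added worked example (not part of the original source): dividing 100 by 7,
// the divisor is first shifted up towards the remainder's bit length
// (7 << 3 = 56), then shift-and-subtract accumulates quotient bits
// 8 + 4 + 2 = 14 and leaves remainder 2, since 100 = 14 * 7 + 2.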

constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
{
    const auto [r0, t0] = mul_wide(data[0], other.data[0]);
    const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
    const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
    const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);

    const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
    const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
    const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
    const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);

    const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
    const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
    const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
    const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);

    const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
    const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
    const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
    const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);

    uint128_t lo(r0, r1, r2, r3);
    uint128_t hi(r4, r5, r6, r7);
    return { lo, hi };
}

/**
 * Viewing `this` uint128_t as a bit string, and counting bits from 0, slices a substring.
 * @returns the uint128_t equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the
 * `end`-th bit of `this`.
 */
constexpr uint128_t uint128_t::slice(const uint64_t start, const uint64_t end) const
{
    const uint64_t range = end - start;
    const uint128_t mask = (range == 128) ? -uint128_t(1) : (uint128_t(1) << range) - 1;
    return ((*this) >> start) & mask;
}

constexpr uint128_t uint128_t::pow(const uint128_t& exponent) const
{
    uint128_t accumulator{ data[0], data[1], data[2], data[3] };
    uint128_t to_mul{ data[0], data[1], data[2], data[3] };
    const uint64_t maximum_set_bit = exponent.get_msb();

    for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
        accumulator *= accumulator;
        if (exponent.get_bit(static_cast<uint64_t>(i))) {
            accumulator *= to_mul;
        }
    }
    if (exponent == uint128_t(0)) {
        accumulator = uint128_t(1);
    } else if (*this == uint128_t(0)) {
        accumulator = uint128_t(0);
    }
    return accumulator;
}

constexpr bool uint128_t::get_bit(const uint64_t bit_index) const
{
    ASSERT(bit_index < 128);
    if (bit_index > 127) {
        return false;
    }
    const auto idx = static_cast<size_t>(bit_index >> 5);
    const size_t shift = bit_index & 31;
    return static_cast<bool>((data[idx] >> shift) & 1);
}

constexpr uint64_t uint128_t::get_msb() const
{
    uint64_t idx = numeric::get_msb64(data[3]);
    idx = (idx == 0 && data[3] == 0) ? numeric::get_msb64(data[2]) : idx + 32;
    idx = (idx == 0 && data[2] == 0) ? numeric::get_msb64(data[1]) : idx + 32;
    idx = (idx == 0 && data[1] == 0) ? numeric::get_msb64(data[0]) : idx + 32;
    return idx;
}

constexpr uint128_t uint128_t::operator+(const uint128_t& other) const
{
    const auto [r0, t0] = addc(data[0], other.data[0], 0);
    const auto [r1, t1] = addc(data[1], other.data[1], t0);
    const auto [r2, t2] = addc(data[2], other.data[2], t1);
    const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
    return { r0, r1, r2, r3 };
};

constexpr uint128_t uint128_t::operator-(const uint128_t& other) const
{

    const auto [r0, t0] = sbb(data[0], other.data[0], 0);
    const auto [r1, t1] = sbb(data[1], other.data[1], t0);
    const auto [r2, t2] = sbb(data[2], other.data[2], t1);
    const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
    return { r0, r1, r2, r3 };
}

constexpr uint128_t uint128_t::operator-() const
{
    return uint128_t(0) - *this;
}

constexpr uint128_t uint128_t::operator*(const uint128_t& other) const
{
    const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
    const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
    const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
    const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);

    const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
    const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
    const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);

    const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
    const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);

    const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);

    return { r0, r1, r2, r3 };
}

constexpr uint128_t uint128_t::operator/(const uint128_t& other) const
{
    return divmod(other).first;
}

constexpr uint128_t uint128_t::operator%(const uint128_t& other) const
{
    return divmod(other).second;
}

constexpr uint128_t uint128_t::operator&(const uint128_t& other) const
{
    return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
}

constexpr uint128_t uint128_t::operator^(const uint128_t& other) const
{
    return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
}

constexpr uint128_t uint128_t::operator|(const uint128_t& other) const
{
    return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
}

constexpr uint128_t uint128_t::operator~() const
{
    return { ~data[0], ~data[1], ~data[2], ~data[3] };
}

constexpr bool uint128_t::operator==(const uint128_t& other) const
{
    return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
}

constexpr bool uint128_t::operator!=(const uint128_t& other) const
{
    return !(*this == other);
}

constexpr bool uint128_t::operator!() const
{
    return *this == uint128_t(0ULL);
}

constexpr bool uint128_t::operator>(const uint128_t& other) const
{
    bool t0 = data[3] > other.data[3];
    bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
    bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
    bool t3 =
        data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
    return t0 || t1 || t2 || t3;
}

constexpr bool uint128_t::operator>=(const uint128_t& other) const
{
    return (*this > other) || (*this == other);
}

constexpr bool uint128_t::operator<(const uint128_t& other) const
{
    return other > *this;
}

constexpr bool uint128_t::operator<=(const uint128_t& other) const
{
    return (*this < other) || (*this == other);
}

constexpr uint128_t uint128_t::operator>>(const uint128_t& other) const
{
    uint32_t total_shift = other.data[0];

    if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
        return 0;
    }

    if (total_shift == 0) {
        return *this;
    }

    uint32_t num_shifted_limbs = total_shift >> 5ULL;
    uint32_t limb_shift = total_shift & 31ULL;

    std::array<uint32_t, 4> shifted_limbs = { 0, 0, 0, 0 };

    if (limb_shift == 0) {
        shifted_limbs[0] = data[0];
        shifted_limbs[1] = data[1];
        shifted_limbs[2] = data[2];
        shifted_limbs[3] = data[3];
    } else {
        uint32_t remainder_shift = 32ULL - limb_shift;

        shifted_limbs[3] = data[3] >> limb_shift;

        uint32_t remainder = (data[3]) << remainder_shift;

        shifted_limbs[2] = (data[2] >> limb_shift) + remainder;

        remainder = (data[2]) << remainder_shift;

        shifted_limbs[1] = (data[1] >> limb_shift) + remainder;

        remainder = (data[1]) << remainder_shift;

        shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
    }
    uint128_t result(0);

    for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
        result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
    }

    return result;
}

constexpr uint128_t uint128_t::operator<<(const uint128_t& other) const
{
    uint32_t total_shift = other.data[0];

    if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
        return 0;
    }

    if (total_shift == 0) {
        return *this;
    }
    uint32_t num_shifted_limbs = total_shift >> 5ULL;
    uint32_t limb_shift = total_shift & 31ULL;

    std::array<uint32_t, 4> shifted_limbs{ 0, 0, 0, 0 };

    if (limb_shift == 0) {
        shifted_limbs[0] = data[0];
        shifted_limbs[1] = data[1];
        shifted_limbs[2] = data[2];
        shifted_limbs[3] = data[3];
    } else {
        uint32_t remainder_shift = 32ULL - limb_shift;

        shifted_limbs[0] = data[0] << limb_shift;

        uint32_t remainder = data[0] >> remainder_shift;

        shifted_limbs[1] = (data[1] << limb_shift) + remainder;

        remainder = data[1] >> remainder_shift;

        shifted_limbs[2] = (data[2] << limb_shift) + remainder;

        remainder = data[2] >> remainder_shift;

        shifted_limbs[3] = (data[3] << limb_shift) + remainder;
    }
    uint128_t result(0);

    for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
        result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
    }

    return result;
}
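
// Added sanity checks (not part of the original source); only compiled on
// i386 builds, and they assume numeric::get_msb64 is usable in constant
// expressions, as its use in the constexpr definitions above implies.
static_assert(uint128_t(3) * uint128_t(5) == uint128_t(15));
static_assert(uint128_t(100) / uint128_t(7) == uint128_t(14));
static_assert(uint128_t(100) % uint128_t(7) == uint128_t(2));
static_assert(uint128_t(0xABCD).slice(4, 12) == uint128_t(0xBC));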

} // namespace bb::numeric
#endif
@@ -0,0 +1,239 @@
/**
 * uint256_t
 * Copyright Aztec 2020
 *
 * An unsigned 256 bit integer type.
 *
 * Constructor and all methods are constexpr.
 * Ideally, uint256_t should be able to be treated like any other literal type.
 *
 * Not optimized for performance, this code doesn't touch any of our hot paths when constructing PLONK proofs.
 **/
#pragma once

#include "../uint128/uint128.hpp"
#include "../../common/throw_or_abort.hpp"
#include <concepts>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <sstream>

namespace bb::numeric {

class alignas(32) uint256_t {

  public:
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
#define WASM_NUM_LIMBS 9
#define WASM_LIMB_BITS 29
#endif
    constexpr uint256_t(const uint64_t a = 0) noexcept
        : data{ a, 0, 0, 0 }
    {}

    constexpr uint256_t(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
        : data{ a, b, c, d }
    {}

    constexpr uint256_t(const uint256_t& other) noexcept
        : data{ other.data[0], other.data[1], other.data[2], other.data[3] }
    {}
    constexpr uint256_t(uint256_t&& other) noexcept = default;

    explicit constexpr uint256_t(std::string input) noexcept
    {
        /* Quick and dirty conversion from a single character to its hex equivalent */
        constexpr auto HexCharToInt = [](uint8_t Input) {
            bool valid =
                (Input >= 'a' && Input <= 'f') || (Input >= 'A' && Input <= 'F') || (Input >= '0' && Input <= '9');
            if (!valid) {
                throw_or_abort("Error, uint256 constructed from string_view with invalid hex parameter");
            }
            uint8_t res =
                ((Input >= 'a') && (Input <= 'f'))   ? (Input - (static_cast<uint8_t>('a') - static_cast<uint8_t>(10)))
                : ((Input >= 'A') && (Input <= 'F')) ? (Input - (static_cast<uint8_t>('A') - static_cast<uint8_t>(10)))
                : ((Input >= '0') && (Input <= '9')) ? (Input - static_cast<uint8_t>('0'))
                                                     : 0;
            return res;
        };

        std::array<uint64_t, 4> limbs{ 0, 0, 0, 0 };
        size_t start_index = 0;
        if (input.size() == 66 && input[0] == '0' && input[1] == 'x') {
            start_index = 2;
        } else if (input.size() != 64) {
            throw_or_abort("Error, uint256 constructed from string_view with invalid length");
        }
        for (size_t j = 0; j < 4; ++j) {

            const size_t limb_index = start_index + j * 16;
            for (size_t i = 0; i < 8; ++i) {
                const size_t byte_index = limb_index + (i * 2);
                uint8_t nibble_hi = HexCharToInt(static_cast<uint8_t>(input[byte_index]));
                uint8_t nibble_lo = HexCharToInt(static_cast<uint8_t>(input[byte_index + 1]));
                uint8_t byte = static_cast<uint8_t>((nibble_hi * 16) + nibble_lo);
                limbs[j] <<= 8;
                limbs[j] += byte;
            }
        }
        data[0] = limbs[3];
        data[1] = limbs[2];
        data[2] = limbs[1];
        data[3] = limbs[0];
    }
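
    // Added usage note (not part of the original source): input must be
    // big-endian hex, either exactly 64 hex characters or "0x" followed by 64
    // (66 total); any other length or character aborts. E.g. a 64-character
    // string of 63 '0's followed by '3' parses to uint256_t(3).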

    static constexpr uint256_t from_uint128(const uint128_t a) noexcept
    {
        return { static_cast<uint64_t>(a), static_cast<uint64_t>(a >> 64), 0, 0 };
    }

    constexpr explicit operator uint128_t() { return (static_cast<uint128_t>(data[1]) << 64) + data[0]; }

    constexpr uint256_t& operator=(const uint256_t& other) noexcept = default;
    constexpr uint256_t& operator=(uint256_t&& other) noexcept = default;
    constexpr ~uint256_t() noexcept = default;

    explicit constexpr operator bool() const { return static_cast<bool>(data[0]); };

    template <std::integral T> explicit constexpr operator T() const { return static_cast<T>(data[0]); };

    [[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const;
    [[nodiscard]] constexpr uint64_t get_msb() const;

    [[nodiscard]] constexpr uint256_t slice(uint64_t start, uint64_t end) const;
    [[nodiscard]] constexpr uint256_t pow(const uint256_t& exponent) const;

    constexpr uint256_t operator+(const uint256_t& other) const;
    constexpr uint256_t operator-(const uint256_t& other) const;
    constexpr uint256_t operator-() const;

    constexpr uint256_t operator*(const uint256_t& other) const;
    constexpr uint256_t operator/(const uint256_t& other) const;
    constexpr uint256_t operator%(const uint256_t& other) const;

    constexpr uint256_t operator>>(const uint256_t& other) const;
    constexpr uint256_t operator<<(const uint256_t& other) const;

    constexpr uint256_t operator&(const uint256_t& other) const;
    constexpr uint256_t operator^(const uint256_t& other) const;
    constexpr uint256_t operator|(const uint256_t& other) const;
    constexpr uint256_t operator~() const;

    constexpr bool operator==(const uint256_t& other) const;
    constexpr bool operator!=(const uint256_t& other) const;
    constexpr bool operator!() const;

    constexpr bool operator>(const uint256_t& other) const;
    constexpr bool operator<(const uint256_t& other) const;
    constexpr bool operator>=(const uint256_t& other) const;
    constexpr bool operator<=(const uint256_t& other) const;

    static constexpr size_t length() { return 256; }

    constexpr uint256_t& operator+=(const uint256_t& other)
    {
        *this = *this + other;
        return *this;
    };
    constexpr uint256_t& operator-=(const uint256_t& other)
    {
        *this = *this - other;
        return *this;
    };
    constexpr uint256_t& operator*=(const uint256_t& other)
    {
        *this = *this * other;
        return *this;
    };
    constexpr uint256_t& operator/=(const uint256_t& other)
    {
        *this = *this / other;
        return *this;
    };
    constexpr uint256_t& operator%=(const uint256_t& other)
    {
        *this = *this % other;
        return *this;
    };

    constexpr uint256_t& operator++()
    {
        *this += uint256_t(1);
        return *this;
    };
    constexpr uint256_t& operator--()
    {
        *this -= uint256_t(1);
        return *this;
    };

    constexpr uint256_t& operator&=(const uint256_t& other)
    {
        *this = *this & other;
        return *this;
    };
    constexpr uint256_t& operator^=(const uint256_t& other)
    {
        *this = *this ^ other;
        return *this;
    };
    constexpr uint256_t& operator|=(const uint256_t& other)
    {
        *this = *this | other;
        return *this;
    };

    constexpr uint256_t& operator>>=(const uint256_t& other)
    {
        *this = *this >> other;
        return *this;
    };
    constexpr uint256_t& operator<<=(const uint256_t& other)
    {
        *this = *this << other;
        return *this;
    };

    [[nodiscard]] constexpr std::pair<uint256_t, uint256_t> mul_extended(const uint256_t& other) const;

    uint64_t data[4]; // NOLINT

    [[nodiscard]] constexpr std::pair<uint256_t, uint256_t> divmod(const uint256_t& b) const;

  private:
    [[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b);
    [[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> addc(uint64_t a, uint64_t b, uint64_t carry_in);
    [[nodiscard]] static constexpr uint64_t addc_discard_hi(uint64_t a, uint64_t b, uint64_t carry_in);
    [[nodiscard]] static constexpr uint64_t sbb_discard_hi(uint64_t a, uint64_t b, uint64_t borrow_in);
    [[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> sbb(uint64_t a, uint64_t b, uint64_t borrow_in);
    [[nodiscard]] static constexpr uint64_t mac_discard_hi(uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in);
    [[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> mac(uint64_t a,
                                                                     uint64_t b,
                                                                     uint64_t c,
                                                                     uint64_t carry_in);
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
    static constexpr void wasm_madd(const uint64_t& left_limb,
                                    const uint64_t* right_limbs,
                                    uint64_t& result_0,
                                    uint64_t& result_1,
                                    uint64_t& result_2,
                                    uint64_t& result_3,
                                    uint64_t& result_4,
                                    uint64_t& result_5,
                                    uint64_t& result_6,
                                    uint64_t& result_7,
                                    uint64_t& result_8);
    [[nodiscard]] static constexpr std::array<uint64_t, WASM_NUM_LIMBS> wasm_convert(const uint64_t* data);
#endif
};

inline std::ostream& operator<<(std::ostream& os, uint256_t const& a)
{
    std::ios_base::fmtflags f(os.flags());
    os << std::hex << "0x" << std::setfill('0') << std::setw(16) << a.data[3] << std::setw(16) << a.data[2]
       << std::setw(16) << a.data[1] << std::setw(16) << a.data[0];
    os.flags(f);
    return os;
}
} // namespace bb::numeric
@@ -0,0 +1,622 @@
#pragma once
#include "../bitop/get_msb.hpp"
#include "./uint256.hpp"
#include "../../common/assert.hpp"
namespace bb::numeric {

constexpr std::pair<uint64_t, uint64_t> uint256_t::mul_wide(const uint64_t a, const uint64_t b)
{
    const uint64_t a_lo = a & 0xffffffffULL;
    const uint64_t a_hi = a >> 32ULL;
    const uint64_t b_lo = b & 0xffffffffULL;
    const uint64_t b_hi = b >> 32ULL;

    const uint64_t lo_lo = a_lo * b_lo;
    const uint64_t hi_lo = a_hi * b_lo;
    const uint64_t lo_hi = a_lo * b_hi;
    const uint64_t hi_hi = a_hi * b_hi;

    const uint64_t cross = (lo_lo >> 32ULL) + (hi_lo & 0xffffffffULL) + lo_hi;

    return { (cross << 32ULL) | (lo_lo & 0xffffffffULL), (hi_lo >> 32ULL) + (cross >> 32ULL) + hi_hi };
}

// compute a + b + carry, returning the carry
constexpr std::pair<uint64_t, uint64_t> uint256_t::addc(const uint64_t a, const uint64_t b, const uint64_t carry_in)
{
    const uint64_t sum = a + b;
    const auto carry_temp = static_cast<uint64_t>(sum < a);
    const uint64_t r = sum + carry_in;
    const uint64_t carry_out = carry_temp + static_cast<uint64_t>(r < carry_in);
    return { r, carry_out };
}

constexpr uint64_t uint256_t::addc_discard_hi(const uint64_t a, const uint64_t b, const uint64_t carry_in)
{
    return a + b + carry_in;
}

constexpr std::pair<uint64_t, uint64_t> uint256_t::sbb(const uint64_t a, const uint64_t b, const uint64_t borrow_in)
{
    const uint64_t t_1 = a - (borrow_in >> 63ULL);
    const auto borrow_temp_1 = static_cast<uint64_t>(t_1 > a);
    const uint64_t t_2 = t_1 - b;
    const auto borrow_temp_2 = static_cast<uint64_t>(t_2 > t_1);

    return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
}

constexpr uint64_t uint256_t::sbb_discard_hi(const uint64_t a, const uint64_t b, const uint64_t borrow_in)
{
    return a - b - (borrow_in >> 63ULL);
}

// {r, carry_out} = a + carry_in + b * c
constexpr std::pair<uint64_t, uint64_t> uint256_t::mac(const uint64_t a,
                                                       const uint64_t b,
                                                       const uint64_t c,
                                                       const uint64_t carry_in)
{
    std::pair<uint64_t, uint64_t> result = mul_wide(b, c);
    result.first += a;
    const auto overflow_c = static_cast<uint64_t>(result.first < a);
    result.first += carry_in;
    const auto overflow_carry = static_cast<uint64_t>(result.first < carry_in);
    result.second += (overflow_c + overflow_carry);
    return result;
}

constexpr uint64_t uint256_t::mac_discard_hi(const uint64_t a,
                                             const uint64_t b,
                                             const uint64_t c,
                                             const uint64_t carry_in)
{
    return (b * c + a + carry_in);
}
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)

/**
 * @brief Multiply one limb by 9 limbs and add to resulting limbs
 *
 */
constexpr void uint256_t::wasm_madd(const uint64_t& left_limb,
                                    const uint64_t* right_limbs,
                                    uint64_t& result_0,
                                    uint64_t& result_1,
                                    uint64_t& result_2,
                                    uint64_t& result_3,
                                    uint64_t& result_4,
                                    uint64_t& result_5,
                                    uint64_t& result_6,
                                    uint64_t& result_7,
                                    uint64_t& result_8)
{
    result_0 += left_limb * right_limbs[0];
    result_1 += left_limb * right_limbs[1];
    result_2 += left_limb * right_limbs[2];
    result_3 += left_limb * right_limbs[3];
    result_4 += left_limb * right_limbs[4];
    result_5 += left_limb * right_limbs[5];
    result_6 += left_limb * right_limbs[6];
    result_7 += left_limb * right_limbs[7];
    result_8 += left_limb * right_limbs[8];
}
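
// Added note (not part of the original source): two 29-bit limbs multiply
// into at most 58 bits, so each 64-bit accumulator has 6 bits of headroom and
// can absorb up to 64 such products before a carry pass is needed; the at
// most 9 products accumulated per limb here are well within that budget.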

/**
 * @brief Convert from 4 64-bit limbs to 9 29-bit limbs
 *
 */
constexpr std::array<uint64_t, WASM_NUM_LIMBS> uint256_t::wasm_convert(const uint64_t* data)
{
    return { data[0] & 0x1fffffff,
             (data[0] >> 29) & 0x1fffffff,
             ((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6),
             (data[1] >> 23) & 0x1fffffff,
             ((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12),
             (data[2] >> 17) & 0x1fffffff,
             ((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18),
             (data[3] >> 11) & 0x1fffffff,
             (data[3] >> 40) & 0x1fffffff };
}
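
// Added note (not part of the original source): 9 limbs x 29 bits = 261 bits
// >= 256, and limbs straddling a 64-bit boundary are stitched from both
// words, e.g. limb 2 covers bits 58..86: the top 6 bits of data[0] plus the
// low 23 bits of data[1] (6 + 23 = 29).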
#endif
constexpr std::pair<uint256_t, uint256_t> uint256_t::divmod(const uint256_t& b) const
{
    if (*this == 0 || b == 0) {
        return { 0, 0 };
    }
    if (b == 1) {
        return { *this, 0 };
    }
    if (*this == b) {
        return { 1, 0 };
    }
    if (b > *this) {
        return { 0, *this };
    }

    uint256_t quotient = 0;
    uint256_t remainder = *this;

    uint64_t bit_difference = get_msb() - b.get_msb();

    uint256_t divisor = b << bit_difference;
    uint256_t accumulator = uint256_t(1) << bit_difference;

    // if the divisor is bigger than the remainder, a and b have the same bit length
    if (divisor > remainder) {
        divisor >>= 1;
        accumulator >>= 1;
    }

    // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
    // and add to the quotient
    while (remainder >= b) {

        // we've shunted 'divisor' up to have the same bit length as our remainder.
        // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
        if (remainder >= divisor) {
            remainder -= divisor;
            // we can use OR here instead of +, as
            // accumulator is always a nice power of two
            quotient |= accumulator;
        }
        divisor >>= 1;
        accumulator >>= 1;
    }

    return { quotient, remainder };
}

/**
 * @brief Compute the result of multiplication modulo 2**512
 *
 */
constexpr std::pair<uint256_t, uint256_t> uint256_t::mul_extended(const uint256_t& other) const
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const auto [r0, t0] = mul_wide(data[0], other.data[0]);
    const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
    const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
    const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);

    const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
    const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
    const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
    const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);

    const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
    const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
    const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
    const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);

    const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
    const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
    const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
    const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);

    uint256_t lo(r0, r1, r2, r3);
    uint256_t hi(r4, r5, r6, r7);
    return { lo, hi };
#else
    // Convert 4 64-bit limbs to 9 29-bit limbs
    const auto left = wasm_convert(data);
    const auto right = wasm_convert(other.data);
    constexpr uint64_t mask = 0x1fffffff;
    uint64_t temp_0 = 0;
    uint64_t temp_1 = 0;
    uint64_t temp_2 = 0;
    uint64_t temp_3 = 0;
    uint64_t temp_4 = 0;
    uint64_t temp_5 = 0;
    uint64_t temp_6 = 0;
    uint64_t temp_7 = 0;
    uint64_t temp_8 = 0;
    uint64_t temp_9 = 0;
    uint64_t temp_10 = 0;
    uint64_t temp_11 = 0;
    uint64_t temp_12 = 0;
    uint64_t temp_13 = 0;
    uint64_t temp_14 = 0;
    uint64_t temp_15 = 0;
    uint64_t temp_16 = 0;

    // Multiply and add all limbs
    wasm_madd(left[0], &right[0], temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    wasm_madd(left[1], &right[0], temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
    wasm_madd(left[2], &right[0], temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
    wasm_madd(left[3], &right[0], temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
    wasm_madd(left[4], &right[0], temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
    wasm_madd(left[5], &right[0], temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
    wasm_madd(left[6], &right[0], temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
    wasm_madd(left[7], &right[0], temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
    wasm_madd(left[8], &right[0], temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);

    // Convert from relaxed form into strict 29-bit form (except for temp_16)
    temp_1 += temp_0 >> WASM_LIMB_BITS;
    temp_0 &= mask;
    temp_2 += temp_1 >> WASM_LIMB_BITS;
    temp_1 &= mask;
    temp_3 += temp_2 >> WASM_LIMB_BITS;
    temp_2 &= mask;
    temp_4 += temp_3 >> WASM_LIMB_BITS;
    temp_3 &= mask;
    temp_5 += temp_4 >> WASM_LIMB_BITS;
    temp_4 &= mask;
    temp_6 += temp_5 >> WASM_LIMB_BITS;
    temp_5 &= mask;
    temp_7 += temp_6 >> WASM_LIMB_BITS;
    temp_6 &= mask;
    temp_8 += temp_7 >> WASM_LIMB_BITS;
    temp_7 &= mask;
    temp_9 += temp_8 >> WASM_LIMB_BITS;
    temp_8 &= mask;
    temp_10 += temp_9 >> WASM_LIMB_BITS;
    temp_9 &= mask;
    temp_11 += temp_10 >> WASM_LIMB_BITS;
    temp_10 &= mask;
    temp_12 += temp_11 >> WASM_LIMB_BITS;
    temp_11 &= mask;
    temp_13 += temp_12 >> WASM_LIMB_BITS;
    temp_12 &= mask;
    temp_14 += temp_13 >> WASM_LIMB_BITS;
    temp_13 &= mask;
    temp_15 += temp_14 >> WASM_LIMB_BITS;
    temp_14 &= mask;
    temp_16 += temp_15 >> WASM_LIMB_BITS;
    temp_15 &= mask;

    // Convert to two uint256_t objects of 4 64-bit limbs each
    return { { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
               (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
               (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
               (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40) },
             { (temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63),
               (temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57),
               (temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51),
               (temp_15 >> 13) | (temp_16 << 16) } };
#endif
}
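
// Added example (not part of the original source): mul_extended returns the
// full 512-bit product as { low 256 bits, high 256 bits }; e.g. multiplying
// 2^255 by 2 gives lo == 0 and hi == 1.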

/**
 * Viewing `this` uint256_t as a bit string, and counting bits from 0, slices a substring.
 * @returns the uint256_t equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the
 * `end`-th bit of `this`.
 */
constexpr uint256_t uint256_t::slice(const uint64_t start, const uint64_t end) const
{
    const uint64_t range = end - start;
    const uint256_t mask = (range == 256) ? -uint256_t(1) : (uint256_t(1) << range) - 1;
    return ((*this) >> start) & mask;
}

constexpr uint256_t uint256_t::pow(const uint256_t& exponent) const
{
    uint256_t accumulator{ data[0], data[1], data[2], data[3] };
    uint256_t to_mul{ data[0], data[1], data[2], data[3] };
    const uint64_t maximum_set_bit = exponent.get_msb();

    for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
        accumulator *= accumulator;
        if (exponent.get_bit(static_cast<uint64_t>(i))) {
            accumulator *= to_mul;
        }
    }
    if (exponent == uint256_t(0)) {
        accumulator = uint256_t(1);
    } else if (*this == uint256_t(0)) {
        accumulator = uint256_t(0);
    }
    return accumulator;
}

constexpr bool uint256_t::get_bit(const uint64_t bit_index) const
{
    ASSERT(bit_index < 256);
    if (bit_index > 255) {
        return static_cast<bool>(0);
    }
    const auto idx = static_cast<size_t>(bit_index >> 6);
    const size_t shift = bit_index & 63;
    return static_cast<bool>((data[idx] >> shift) & 1);
}

constexpr uint64_t uint256_t::get_msb() const
{
    uint64_t idx = numeric::get_msb(data[3]);
    idx = (idx == 0 && data[3] == 0) ? numeric::get_msb(data[2]) : idx + 64;
    idx = (idx == 0 && data[2] == 0) ? numeric::get_msb(data[1]) : idx + 64;
    idx = (idx == 0 && data[1] == 0) ? numeric::get_msb(data[0]) : idx + 64;
    return idx;
}

constexpr uint256_t uint256_t::operator+(const uint256_t& other) const
{
    const auto [r0, t0] = addc(data[0], other.data[0], 0);
    const auto [r1, t1] = addc(data[1], other.data[1], t0);
    const auto [r2, t2] = addc(data[2], other.data[2], t1);
    const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
    return { r0, r1, r2, r3 };
};

constexpr uint256_t uint256_t::operator-(const uint256_t& other) const
{

    const auto [r0, t0] = sbb(data[0], other.data[0], 0);
    const auto [r1, t1] = sbb(data[1], other.data[1], t0);
    const auto [r2, t2] = sbb(data[2], other.data[2], t1);
    const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
    return { r0, r1, r2, r3 };
}

constexpr uint256_t uint256_t::operator-() const
{
    return uint256_t(0) - *this;
}

constexpr uint256_t uint256_t::operator*(const uint256_t& other) const
{

#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
    const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
    const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
    const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
    const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);

    const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
    const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
    const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);

    const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
    const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);

    const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);

    return { r0, r1, r2, r3 };
#else
    // Convert 4 64-bit limbs to 9 29-bit limbs
    const auto left = wasm_convert(data);
    const auto right = wasm_convert(other.data);
    uint64_t temp_0 = 0;
    uint64_t temp_1 = 0;
    uint64_t temp_2 = 0;
    uint64_t temp_3 = 0;
    uint64_t temp_4 = 0;
    uint64_t temp_5 = 0;
    uint64_t temp_6 = 0;
    uint64_t temp_7 = 0;
    uint64_t temp_8 = 0;

    // Multiply and add the product of left limb 0 by all right limbs
    wasm_madd(left[0], &right[0], temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
    // Multiply left limb 1 by limbs 0-7 ((1,8) doesn't need to be computed, because it overflows)
    temp_1 += left[1] * right[0];
    temp_2 += left[1] * right[1];
    temp_3 += left[1] * right[2];
    temp_4 += left[1] * right[3];
    temp_5 += left[1] * right[4];
    temp_6 += left[1] * right[5];
    temp_7 += left[1] * right[6];
    temp_8 += left[1] * right[7];
    // Left limb 2 by right 0-6, etc
    temp_2 += left[2] * right[0];
    temp_3 += left[2] * right[1];
    temp_4 += left[2] * right[2];
    temp_5 += left[2] * right[3];
    temp_6 += left[2] * right[4];
    temp_7 += left[2] * right[5];
    temp_8 += left[2] * right[6];
    temp_3 += left[3] * right[0];
    temp_4 += left[3] * right[1];
    temp_5 += left[3] * right[2];
    temp_6 += left[3] * right[3];
    temp_7 += left[3] * right[4];
    temp_8 += left[3] * right[5];
    temp_4 += left[4] * right[0];
    temp_5 += left[4] * right[1];
    temp_6 += left[4] * right[2];
    temp_7 += left[4] * right[3];
    temp_8 += left[4] * right[4];
    temp_5 += left[5] * right[0];
    temp_6 += left[5] * right[1];
    temp_7 += left[5] * right[2];
    temp_8 += left[5] * right[3];
    temp_6 += left[6] * right[0];
    temp_7 += left[6] * right[1];
    temp_8 += left[6] * right[2];
    temp_7 += left[7] * right[0];
    temp_8 += left[7] * right[1];
    temp_8 += left[8] * right[0];

    // Convert from relaxed form to strict 29-bit form
    constexpr uint64_t mask = 0x1fffffff;
    temp_1 += temp_0 >> WASM_LIMB_BITS;
    temp_0 &= mask;
    temp_2 += temp_1 >> WASM_LIMB_BITS;
    temp_1 &= mask;
    temp_3 += temp_2 >> WASM_LIMB_BITS;
    temp_2 &= mask;
    temp_4 += temp_3 >> WASM_LIMB_BITS;
    temp_3 &= mask;
    temp_5 += temp_4 >> WASM_LIMB_BITS;
    temp_4 &= mask;
    temp_6 += temp_5 >> WASM_LIMB_BITS;
    temp_5 &= mask;
    temp_7 += temp_6 >> WASM_LIMB_BITS;
    temp_6 &= mask;
    temp_8 += temp_7 >> WASM_LIMB_BITS;
    temp_7 &= mask;

    // Convert back to 4 64-bit limbs
    return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
             (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
             (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
             (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40) };
#endif
}

constexpr uint256_t uint256_t::operator/(const uint256_t& other) const
{
    return divmod(other).first;
}

constexpr uint256_t uint256_t::operator%(const uint256_t& other) const
{
    return divmod(other).second;
}

constexpr uint256_t uint256_t::operator&(const uint256_t& other) const
{
    return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
}

constexpr uint256_t uint256_t::operator^(const uint256_t& other) const
{
    return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
}

constexpr uint256_t uint256_t::operator|(const uint256_t& other) const
{
    return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
}

constexpr uint256_t uint256_t::operator~() const
{
    return { ~data[0], ~data[1], ~data[2], ~data[3] };
}

constexpr bool uint256_t::operator==(const uint256_t& other) const
{
    return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
}

constexpr bool uint256_t::operator!=(const uint256_t& other) const
{
    return !(*this == other);
}

constexpr bool uint256_t::operator!() const
{
    return *this == uint256_t(0ULL);
}

constexpr bool uint256_t::operator>(const uint256_t& other) const
{
    bool t0 = data[3] > other.data[3];
    bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
    bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
    bool t3 =
        data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
    return t0 || t1 || t2 || t3;
}

constexpr bool uint256_t::operator>=(const uint256_t& other) const
{
    return (*this > other) || (*this == other);
}

constexpr bool uint256_t::operator<(const uint256_t& other) const
{
    return other > *this;
}

constexpr bool uint256_t::operator<=(const uint256_t& other) const
{
    return (*this < other) || (*this == other);
}

constexpr uint256_t uint256_t::operator>>(const uint256_t& other) const
{
    uint64_t total_shift = other.data[0];

    if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
        return 0;
    }

    if (total_shift == 0) {
        return *this;
    }

    uint64_t num_shifted_limbs = total_shift >> 6ULL;
    uint64_t limb_shift = total_shift & 63ULL;

    std::array<uint64_t, 4> shifted_limbs = { 0, 0, 0, 0 };

    if (limb_shift == 0) {
        shifted_limbs[0] = data[0];
        shifted_limbs[1] = data[1];
        shifted_limbs[2] = data[2];
        shifted_limbs[3] = data[3];
    } else {
        uint64_t remainder_shift = 64ULL - limb_shift;

        shifted_limbs[3] = data[3] >> limb_shift;

        uint64_t remainder = (data[3]) << remainder_shift;

        shifted_limbs[2] = (data[2] >> limb_shift) + remainder;

        remainder = (data[2]) << remainder_shift;

        shifted_limbs[1] = (data[1] >> limb_shift) + remainder;

        remainder = (data[1]) << remainder_shift;

        shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
    }
    uint256_t result(0);

    for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
        result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
    }

    return result;
}

constexpr uint256_t uint256_t::operator<<(const uint256_t& other) const
{
    uint64_t total_shift = other.data[0];

    if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
        return 0;
    }

    if (total_shift == 0) {
        return *this;
    }
    uint64_t num_shifted_limbs = total_shift >> 6ULL;
    uint64_t limb_shift = total_shift & 63ULL;

    std::array<uint64_t, 4> shifted_limbs = { 0, 0, 0, 0 };

    if (limb_shift == 0) {
        shifted_limbs[0] = data[0];
        shifted_limbs[1] = data[1];
        shifted_limbs[2] = data[2];
        shifted_limbs[3] = data[3];
    } else {
        uint64_t remainder_shift = 64ULL - limb_shift;

        shifted_limbs[0] = data[0] << limb_shift;

        uint64_t remainder = data[0] >> remainder_shift;

        shifted_limbs[1] = (data[1] << limb_shift) + remainder;

        remainder = data[1] >> remainder_shift;

        shifted_limbs[2] = (data[2] << limb_shift) + remainder;

        remainder = data[2] >> remainder_shift;

        shifted_limbs[3] = (data[3] << limb_shift) + remainder;
    }
    uint256_t result(0);

    for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
        result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
    }

    return result;
}
|
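// Illustrative sketch, not part of the original source: with the four-limb
// (least-significant-first) constructor from uint256.hpp, a whole-limb shift
// moves data between limbs directly. In a test translation unit one could check:
//
//   constexpr uint256_t x(0x0123456789abcdefULL, 0xfedcba9876543210ULL, 0, 0);
//   static_assert((x >> uint256_t(64)) == uint256_t(0xfedcba9876543210ULL, 0, 0, 0));
//   static_assert((x << uint256_t(64)) == uint256_t(0, 0x0123456789abcdefULL, 0xfedcba9876543210ULL, 0));
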
} // namespace bb::numeric
178
sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx.hpp
Normal file
@@ -0,0 +1,178 @@
/**
 * uintx
 * Copyright Aztec 2020
 *
 * An unsigned wide integer type, templated over its base type (instantiated below as 512-bit and 1024-bit variants).
 *
 * Constructor and all methods are constexpr. Ideally, uintx should be able to be treated like any other literal
 * type.
 *
 * Not optimized for performance; this code doesn't touch any of our hot paths when constructing PLONK proofs.
 **/
#pragma once

#include "../uint256/uint256.hpp"
#include "../../common/assert.hpp"
#include "../../common/throw_or_abort.hpp"
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <utility> // std::pair

namespace bb::numeric {

template <class base_uint> class uintx {
  public:
    constexpr uintx(const uint64_t data = 0)
        : lo(data)
        , hi(base_uint(0))
    {}

    constexpr uintx(const base_uint input_lo)
        : lo(input_lo)
        , hi(base_uint(0))
    {}

    constexpr uintx(const base_uint input_lo, const base_uint input_hi)
        : lo(input_lo)
        , hi(input_hi)
    {}

    constexpr uintx(const uintx& other)
        : lo(other.lo)
        , hi(other.hi)
    {}

    constexpr uintx(uintx&& other) noexcept = default;

    static constexpr size_t length() { return 2 * base_uint::length(); }
    constexpr uintx& operator=(const uintx& other) = default;
    constexpr uintx& operator=(uintx&& other) noexcept = default;

    constexpr ~uintx() = default;
    // Note: the narrowing conversions below inspect only the least-significant 64 bits.
    explicit constexpr operator bool() const { return static_cast<bool>(lo.data[0]); };
    explicit constexpr operator uint8_t() const { return static_cast<uint8_t>(lo.data[0]); };
    explicit constexpr operator uint16_t() const { return static_cast<uint16_t>(lo.data[0]); };
    explicit constexpr operator uint32_t() const { return static_cast<uint32_t>(lo.data[0]); };
    explicit constexpr operator uint64_t() const { return static_cast<uint64_t>(lo.data[0]); };

    explicit constexpr operator base_uint() const { return lo; }

    [[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const;
    [[nodiscard]] constexpr uint64_t get_msb() const;
    constexpr uintx slice(uint64_t start, uint64_t end) const;

    constexpr uintx operator+(const uintx& other) const;
    constexpr uintx operator-(const uintx& other) const;
    constexpr uintx operator-() const;

    constexpr uintx operator*(const uintx& other) const;
    constexpr uintx operator/(const uintx& other) const;
    constexpr uintx operator%(const uintx& other) const;

    constexpr std::pair<uintx, uintx> mul_extended(const uintx& other) const;

    constexpr uintx operator>>(uint64_t other) const;
    constexpr uintx operator<<(uint64_t other) const;

    constexpr uintx operator&(const uintx& other) const;
    constexpr uintx operator^(const uintx& other) const;
    constexpr uintx operator|(const uintx& other) const;
    constexpr uintx operator~() const;

    constexpr bool operator==(const uintx& other) const;
    constexpr bool operator!=(const uintx& other) const;
    constexpr bool operator!() const;

    constexpr bool operator>(const uintx& other) const;
    constexpr bool operator<(const uintx& other) const;
    constexpr bool operator>=(const uintx& other) const;
    constexpr bool operator<=(const uintx& other) const;

    constexpr uintx& operator+=(const uintx& other)
    {
        *this = *this + other;
        return *this;
    };
    constexpr uintx& operator-=(const uintx& other)
    {
        *this = *this - other;
        return *this;
    };
    constexpr uintx& operator*=(const uintx& other)
    {
        *this = *this * other;
        return *this;
    };
    constexpr uintx& operator/=(const uintx& other)
    {
        *this = *this / other;
        return *this;
    };
    constexpr uintx& operator%=(const uintx& other)
    {
        *this = *this % other;
        return *this;
    };

    constexpr uintx& operator++()
    {
        *this += uintx(1);
        return *this;
    };
    constexpr uintx& operator--()
    {
        *this -= uintx(1);
        return *this;
    };

    constexpr uintx& operator&=(const uintx& other)
    {
        *this = *this & other;
        return *this;
    };
    constexpr uintx& operator^=(const uintx& other)
    {
        *this = *this ^ other;
        return *this;
    };
    constexpr uintx& operator|=(const uintx& other)
    {
        *this = *this | other;
        return *this;
    };

    constexpr uintx& operator>>=(const uint64_t other)
    {
        *this = *this >> other;
        return *this;
    };
    constexpr uintx& operator<<=(const uint64_t other)
    {
        *this = *this << other;
        return *this;
    };

    constexpr uintx invmod(const uintx& modulus) const;
    constexpr uintx unsafe_invmod(const uintx& modulus) const;

    base_uint lo;
    base_uint hi;

    constexpr std::pair<uintx, uintx> divmod(const uintx& b) const;
};

template <class base_uint> inline std::ostream& operator<<(std::ostream& os, uintx<base_uint> const& a)
{
    os << a.lo << ", " << a.hi << std::endl;
    return os;
}

using uint512_t = uintx<numeric::uint256_t>;
using uint1024_t = uintx<uint512_t>;

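// Illustrative usage sketch, not part of the original header: uint512_t is two
// uint256_t limbs and arithmetic wraps modulo 2^512. In a test translation
// unit one could check:
//
//   constexpr uint512_t one(1);
//   constexpr uint512_t x = one << 300; // bit 300 set: lo = 0, hi = 2^44
//   static_assert(x.get_msb() == 300);
//   static_assert((x >> 300) == one);
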
} // namespace bb::numeric

#include "./uintx_impl.hpp"

using bb::numeric::uint1024_t; // NOLINT
using bb::numeric::uint512_t;  // NOLINT
339
sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx_impl.hpp
Normal file
@@ -0,0 +1,339 @@
#pragma once
#include "./uintx.hpp"
#include "../../common/assert.hpp"
#include <array> // std::array

namespace bb::numeric {
template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
{
    ASSERT(b != 0);
    if (*this == 0) {
        return { uintx(0), uintx(0) };
    }
    if (b == 1) {
        return { *this, uintx(0) };
    }
    if (*this == b) {
        return { uintx(1), uintx(0) };
    }
    if (b > *this) {
        return { uintx(0), *this };
    }

    uintx quotient(0);
    uintx remainder = *this;

    uint64_t bit_difference = get_msb() - b.get_msb();

    uintx divisor = b << bit_difference;
    uintx accumulator = uintx(1) << bit_difference;

    // if the divisor is bigger than the remainder, a and b have the same bit length
    if (divisor > remainder) {
        divisor >>= 1;
        accumulator >>= 1;
    }

    // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
    // and add to the quotient
    while (remainder >= b) {
        // we've shunted 'divisor' up to have the same bit length as our remainder.
        // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
        if (remainder >= divisor) {
            remainder -= divisor;
            // we can use OR here instead of +, as
            // accumulator is always a nice power of two
            quotient |= accumulator;
        }
        divisor >>= 1;
        accumulator >>= 1;
    }

    return std::make_pair(quotient, remainder);
}

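// Illustrative sketch, not part of the original source: divmod is binary long
// division. For a = 100, b = 7 it peels off shifted copies of b (7 << 3 = 56
// first), accumulating one quotient bit per aligned subtraction. In a test
// translation unit one could check:
//
//   constexpr uint512_t a(100), b(7);
//   constexpr auto qr = a.divmod(b);
//   static_assert(qr.first == uint512_t(14) && qr.second == uint512_t(2));
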
/**
 * Computes invmod. Only for internal usage within the class.
 * This is an insecure version of the algorithm that doesn't take into account the 0 case and cases when the modulus
 * is close to the top margin.
 *
 * @param modulus The modulus of the ring
 *
 * @return The inverse of *this modulo modulus
 **/
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::unsafe_invmod(const uintx& modulus) const
{
    uintx t1 = 0;
    uintx t2 = 1;
    uintx r2 = (*this > modulus) ? *this % modulus : *this;
    uintx r1 = modulus;
    uintx q = 0;
    while (r2 != 0) {
        q = r1 / r2;
        uintx temp_t1 = t1;
        uintx temp_r1 = r1;
        t1 = t2;
        t2 = temp_t1 - q * t2;
        r1 = r2;
        r2 = temp_r1 - q * r2;
    }

    // A "negative" Bezout coefficient wraps around; adding the modulus maps it back into [0, modulus).
    if (t1 > modulus) {
        return modulus + t1;
    }
    return t1;
}

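// Illustrative trace, not part of the original source (unsafe_invmod is meant
// for internal use; it is exercised here only for illustration): for *this = 3,
// modulus = 10, the extended-Euclid loop runs (q, r1, r2): (3, 3, 1) then
// (3, 1, 0), leaving t1 = -3, which wraps to 2^512 - 3; the final fix-up
// returns modulus + t1 = 7, and indeed 3 * 7 = 21 ≡ 1 (mod 10).
//
//   static_assert(uint512_t(3).unsafe_invmod(uint512_t(10)) == uint512_t(7));
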
/**
 * Computes the inverse of *this, modulo modulus, via the extended Euclidean algorithm.
 *
 * Delegates to the appropriate unsafe_invmod (if the modulus is close to the uintx top margin there is a need to
 * expand to double width).
 *
 * @param modulus The modulus
 * @return The inverse of *this modulo modulus
 **/
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::invmod(const uintx& modulus) const
{
    ASSERT((*this) != 0);
    if (modulus == 0) {
        return 0;
    }
    // If the modulus occupies the top bit, intermediate values could wrap: run the algorithm at double width.
    if (modulus.get_msb() >= (2 * base_uint::length() - 1)) {
        uintx<uintx<base_uint>> a_expanded(*this);
        uintx<uintx<base_uint>> modulus_expanded(modulus);
        return a_expanded.unsafe_invmod(modulus_expanded).lo;
    }
    return this->unsafe_invmod(modulus);
}

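// Illustrative sketch, not part of the original source:
//
//   constexpr uint512_t five(5), seven(7);
//   static_assert(five.invmod(seven) == uint512_t(3)); // 5 * 3 = 15 ≡ 1 (mod 7)
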
/**
 * Viewing `this` as a bit string, and counting bits from 0, slices a substring.
 * @returns the uintx equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the
 * `end`-th bit of `this`.
 */
template <class base_uint>
constexpr uintx<base_uint> uintx<base_uint>::slice(const uint64_t start, const uint64_t end) const
{
    const uint64_t range = end - start;
    const uintx mask = range == base_uint::length() ? -uintx(1) : (uintx(1) << range) - 1;
    return ((*this) >> start) & mask;
}

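// Illustrative sketch, not part of the original source: slice(2, 6) keeps bits
// 2..5 inclusive, right-aligned.
//
//   constexpr uint512_t v(0b110100);
//   static_assert(v.slice(2, 6) == uint512_t(0b1101));
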
template <class base_uint> constexpr bool uintx<base_uint>::get_bit(const uint64_t bit_index) const
{
    if (bit_index >= base_uint::length()) {
        return hi.get_bit(bit_index - base_uint::length());
    }
    return lo.get_bit(bit_index);
}

template <class base_uint> constexpr uint64_t uintx<base_uint>::get_msb() const
{
    uint64_t hi_idx = hi.get_msb();
    uint64_t lo_idx = lo.get_msb();
    return (hi_idx || (hi > base_uint(0))) ? (hi_idx + base_uint::length()) : lo_idx;
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator+(const uintx& other) const
{
    base_uint res_lo = lo + other.lo;
    // The low-limb addition wrapped iff the result is smaller than an operand.
    bool carry = res_lo < lo;
    base_uint res_hi = hi + other.hi + ((carry) ? base_uint(1) : base_uint(0));
    return { res_lo, res_hi };
};

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator-(const uintx& other) const
{
    base_uint res_lo = lo - other.lo;
    // The low-limb subtraction wrapped iff the result is larger than the minuend.
    bool borrow = res_lo > lo;
    base_uint res_hi = hi - other.hi - ((borrow) ? base_uint(1) : base_uint(0));
    return { res_lo, res_hi };
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator-() const
{
    return uintx(0) - *this;
}

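// Illustrative sketch, not part of the original source: addition and
// subtraction wrap modulo 2^(2 * base_uint::length()).
//
//   constexpr uint512_t zero(0), one(1);
//   static_assert((zero - one) + one == zero); // wraps to 2^512 - 1, then back
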
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator*(const uintx& other) const
{
    // Truncating schoolbook multiply: hi * other.hi and the high halves of the cross
    // terms would only contribute above bit 2 * base_uint::length(), so they are dropped.
    const auto lolo = lo.mul_extended(other.lo);
    const auto lohi = lo.mul_extended(other.hi);
    const auto hilo = hi.mul_extended(other.lo);

    base_uint top = lolo.second + hilo.first + lohi.first;
    base_uint bottom = lolo.first;
    return { bottom, top };
}

template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::mul_extended(const uintx& other) const
{
    const auto lolo = lo.mul_extended(other.lo);
    const auto lohi = lo.mul_extended(other.hi);
    const auto hilo = hi.mul_extended(other.lo);
    const auto hihi = hi.mul_extended(other.hi);

    // Accumulate the four partial products limb by limb, tracking carries out of t1 and t2.
    base_uint t0 = lolo.first;
    base_uint t1 = lolo.second;
    base_uint t2 = hilo.second;
    base_uint t3 = hihi.second;
    base_uint t2_carry(0);
    base_uint t3_carry(0);
    t1 += hilo.first;
    t2_carry += (t1 < hilo.first ? base_uint(1) : base_uint(0));
    t1 += lohi.first;
    t2_carry += (t1 < lohi.first ? base_uint(1) : base_uint(0));
    t2 += lohi.second;
    t3_carry += (t2 < lohi.second ? base_uint(1) : base_uint(0));
    t2 += hihi.first;
    t3_carry += (t2 < hihi.first ? base_uint(1) : base_uint(0));
    t2 += t2_carry;
    t3_carry += (t2 < t2_carry ? base_uint(1) : base_uint(0));
    t3 += t3_carry;
    return { uintx(t0, t1), uintx(t2, t3) };
}

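// Illustrative sketch, not part of the original source: mul_extended returns
// { low half, high half } of the full double-width product.
//
//   constexpr uint512_t a(uint256_t(0), uint256_t(1)); // a = 2^256
//   constexpr auto p = a.mul_extended(a);              // a * a = 2^512
//   static_assert(p.first == uint512_t(0) && p.second == uint512_t(1));
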
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator/(const uintx& other) const
{
    return divmod(other).first;
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator%(const uintx& other) const
{
    return divmod(other).second;
}

// 0x2af0296feca4188a80fd373ebe3c64da87a232934abb3a99f9c4cd59e6758a65
// 0x1182c6cdb54193b51ca27c1932b95c82bebac691e3996e5ec5e1d4395f3023e3
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator&(const uintx& other) const
{
    return { lo & other.lo, hi & other.hi };
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator^(const uintx& other) const
{
    return { lo ^ other.lo, hi ^ other.hi };
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator|(const uintx& other) const
{
    return { lo | other.lo, hi | other.hi };
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator~() const
{
    return { ~lo, ~hi };
}

template <class base_uint> constexpr bool uintx<base_uint>::operator==(const uintx& other) const
{
    return ((lo == other.lo) && (hi == other.hi));
}

template <class base_uint> constexpr bool uintx<base_uint>::operator!=(const uintx& other) const
{
    return !(*this == other);
}

template <class base_uint> constexpr bool uintx<base_uint>::operator!() const
{
    return *this == uintx(0ULL);
}

template <class base_uint> constexpr bool uintx<base_uint>::operator>(const uintx& other) const
{
    bool hi_gt = hi > other.hi;
    bool lo_gt = lo > other.lo;

    bool gt = (hi_gt) || (lo_gt && (hi == other.hi));
    return gt;
}

template <class base_uint> constexpr bool uintx<base_uint>::operator>=(const uintx& other) const
{
    return (*this > other) || (*this == other);
}

template <class base_uint> constexpr bool uintx<base_uint>::operator<(const uintx& other) const
{
    return other > *this;
}

template <class base_uint> constexpr bool uintx<base_uint>::operator<=(const uintx& other) const
{
    return (*this < other) || (*this == other);
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator>>(const uint64_t other) const
{
    const uint64_t total_shift = other;
    if (total_shift >= length()) {
        return uintx(0);
    }
    if (total_shift == 0) {
        return *this;
    }
    // base_uint::length() is a power of two, so the msb index of its value equals
    // log2(length()): shifting by it extracts the whole-limb part of the shift.
    const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb());

    const uint64_t limb_shift = total_shift & static_cast<uint64_t>(base_uint::length() - 1);

    std::array<base_uint, 2> shifted_limbs = { 0, 0 };
    if (limb_shift == 0) {
        shifted_limbs[0] = lo;
        shifted_limbs[1] = hi;
    } else {
        const uint64_t remainder_shift = static_cast<uint64_t>(base_uint::length()) - limb_shift;

        shifted_limbs[1] = hi >> limb_shift;

        // Bits shifted out of hi become the high bits of lo.
        base_uint remainder = (hi) << remainder_shift;

        shifted_limbs[0] = (lo >> limb_shift) + remainder;
    }
    uintx result(0);
    if (num_shifted_limbs == 0) {
        result.hi = shifted_limbs[1];
        result.lo = shifted_limbs[0];
    } else {
        result.lo = shifted_limbs[1];
    }
    return result;
}

template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator<<(const uint64_t other) const
{
    const uint64_t total_shift = other;
    if (total_shift >= length()) {
        return uintx(0);
    }
    if (total_shift == 0) {
        return *this;
    }
    const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb());
    const uint64_t limb_shift = total_shift & static_cast<uint64_t>(base_uint::length() - 1);

    std::array<base_uint, 2> shifted_limbs = { 0, 0 };
    if (limb_shift == 0) {
        shifted_limbs[0] = lo;
        shifted_limbs[1] = hi;
    } else {
        const uint64_t remainder_shift = static_cast<uint64_t>(base_uint::length()) - limb_shift;

        shifted_limbs[0] = lo << limb_shift;

        // Bits shifted out of lo become the low bits of hi.
        base_uint remainder = lo >> remainder_shift;

        shifted_limbs[1] = (hi << limb_shift) + remainder;
    }
    uintx result(0);
    if (num_shifted_limbs == 0) {
        result.hi = shifted_limbs[1];
        result.lo = shifted_limbs[0];
    } else {
        result.hi = shifted_limbs[0];
    }
    return result;
}
} // namespace bb::numeric
@@ -0,0 +1,27 @@
/**
 * @brief Defines particular circuit builder types expected to be used for circuit
 * construction in stdlib and contains macros for explicit instantiation.
 *
 * @details This file is designed to be included in header files to instruct the compiler that these classes exist and
 * their instantiation will eventually take place. Given it has no dependencies, it causes no additional compilation or
 * propagation.
 */
#pragma once
#include <concepts>

namespace bb {
class StandardFlavor;
class UltraFlavor;
class Bn254FrParams;
class Bn254FqParams;
template <class Params> struct alignas(32) field;
template <typename FF_> class UltraArith;
template <class FF> class StandardCircuitBuilder_;
using StandardCircuitBuilder = StandardCircuitBuilder_<field<Bn254FrParams>>;
using StandardGrumpkinCircuitBuilder = StandardCircuitBuilder_<field<Bn254FqParams>>;
template <class Arithmetization> class UltraCircuitBuilder_;
using UltraCircuitBuilder = UltraCircuitBuilder_<UltraArith<field<Bn254FrParams>>>;
template <class FF> class MegaCircuitBuilder_;
using MegaCircuitBuilder = MegaCircuitBuilder_<field<Bn254FrParams>>;
class CircuitSimulatorBN254;
} // namespace bb