Add barretenberg

Co-authored-by: jeong0982 <soowon1106@gmail.com>
Authored by DoHoonKim8 on 2024-06-12 09:01:49 +00:00
Committed by DoHoon Kim
parent 6e64ad017d
commit d5a15c396f
62 changed files with 11365 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
#pragma once
// NOLINTBEGIN
#if NDEBUG
// Compiler should optimize this out in release builds, without triggering an unused variable warning.
#define DONT_EVALUATE(expression) \
{ \
true ? static_cast<void>(0) : static_cast<void>((expression)); \
}
#define ASSERT(expression) DONT_EVALUATE((expression))
#else
// cassert in wasi-sdk takes one second to compile, only include if needed
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <string>
#define ASSERT(expression) assert((expression))
#endif // NDEBUG
// NOLINTEND
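A minimal usage sketch (the helper function is hypothetical): because the release branch never evaluates the expression, asserted expressions must be free of side effects.

// Hypothetical example: ASSERT aborts in debug builds, compiles away under NDEBUG.
inline int checked_div(int a, int b)
{
    ASSERT(b != 0); // must have no side effects: never evaluated in release builds
    return a / b;
}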

View File

@@ -0,0 +1,26 @@
#pragma once
#ifdef _WIN32
#define BB_INLINE __forceinline inline
#else
#define BB_INLINE __attribute__((always_inline)) inline
#endif
// TODO(AD): Other instrumentation?
#ifdef XRAY
#define BB_PROFILE [[clang::xray_always_instrument]] [[clang::noinline]]
#define BB_NO_PROFILE [[clang::xray_never_instrument]]
#else
#define BB_PROFILE
#define BB_NO_PROFILE
#endif
// Optimization hints for clang - which outcome of an expression is expected for better
// branch-prediction optimization
#ifdef __clang__
#define BB_LIKELY(x) __builtin_expect(!!(x), 1)
#define BB_UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#define BB_LIKELY(x) x
#define BB_UNLIKELY(x) x
#endif
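A short sketch of the intended use of the branch hints (function and names are illustrative):

// Hypothetical example: mark the rare branch so clang lays out the hot path fall-through.
inline int first_byte(const char* s)
{
    if (BB_UNLIKELY(s == nullptr)) {
        return -1; // cold path
    }
    return s[0]; // expected path
}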

View File

@@ -0,0 +1,162 @@
#pragma once
#include <cstddef>
#include <tuple>
#include <utility>
/**
* @brief constexpr_utils defines some helper methods that perform some stl-equivalent operations
* but in a constexpr context over quantities known at compile-time
*
* Current methods are:
*
 * constexpr_for : loop over a range, where the size_t iterator `i` is a constexpr variable
* constexpr_find : find if an element is in an array
*/
namespace bb {
/**
* @brief Implements a loop using a compile-time iterator. Requires c++20.
* Implementation (and description) from https://artificial-mind.net/blog/2020/10/31/constexpr-for
*
* @tparam Start the loop start value
* @tparam End the loop end value
* @tparam Inc how much the iterator increases by per iteration
* @tparam F a Lambda function that is executed once per loop
*
* @param f An rvalue reference to the lambda
* @details Implements a `for` loop where the iterator is a constexpr variable.
* Use this when you need to evaluate `if constexpr` statements on the iterator (or apply other constexpr expressions)
 * Outside of this use-case, avoid using this function, as it gives negligible performance increases vs regular loops.
*
* N.B. A side-effect of this method is that all loops will be unrolled
* (each loop iteration uses different iterator template parameters => unique constexpr_for implementation per
* iteration)
* Do not use this for large (~100+) loops!
*
* ##############################
* EXAMPLE USE OF `constexpr_for`
* ##############################
*
* constexpr_for<0, 10, 1>([&]<size_t i>(){
 * if constexpr ((i & 1) == 0)
* {
* foo[i] = even_container[i >> 1];
* }
* else
* {
* foo[i] = odd_container[i >> 1];
* }
* });
*
* In the above example we are iterating from i = 0 to i < 10.
* The provided lambda function has captured everything in its surrounding scope (via `[&]`),
* which is where `foo`, `even_container` and `odd_container` have come from.
*
* We do not need to explicitly define the `class F` parameter as the compiler derives it from our provided input
* argument `F&& f` (i.e. the lambda function)
*
* In the loop itself we're evaluating a constexpr if statement that defines which code path is taken.
*
* The above example benefits from `constexpr_for` because a run-time `if` statement has been reduced to a compile-time
* `if` statement. N.B. this would only give measurable improvements if the `constexpr_for` statement is itself in a hot
* loop that's iterated over many (>thousands) times
*/
template <size_t Start, size_t End, size_t Inc, class F> constexpr void constexpr_for(F&& f)
{
// Call function `f<Start>()` iff Start < End
if constexpr (Start < End) {
// F must be a template lambda with a single **typed** template parameter that represents the iterator
// (e.g. [&]<size_t i>(){ ... } is good)
// (and [&]<typename i>(){ ... } won't compile!)
/**
* Explaining f.template operator()<Start>()
*
* The following line must explicitly tell the compiler that <Start> is a template parameter by using the
* `template` keyword.
* (if we wrote f<Start>(), the compiler could legitimately interpret `<` as a less than symbol)
*
* The fragment `f.template` tells the compiler that we're calling a *templated* member of `f`.
* The "member" being called is the function operator, `operator()`, which must be explicitly provided
* (for any function X, `X(args)` is an alias for `X.operator()(args)`)
* The compiler has no alias `X.template <tparam>(args)` for `X.template operator()<tparam>(args)` so we must
* write it explicitly here
*
* To summarize what the next line tells the compiler...
* 1. I want to call a member of `f` that expects one or more template parameters
* 2. The member of `f` that I want to call is the function operator
* 3. The template parameter is `Start`
* 4. The function operator itself contains no arguments
*/
f.template operator()<Start>();
// Once we have executed `f`, we recursively call the `constexpr_for` function, increasing the value of `Start`
// by `Inc`
constexpr_for<Start + Inc, End, Inc>(f);
}
}
/**
* @brief returns true/false depending on whether `key` is in `container`
*
* @tparam container i.e. what are we looking in?
* @tparam key i.e. what are we looking for?
* @return true found!
* @return false not found!
*
* @details method is constexpr and can be used in static_asserts
*/
template <const auto& container, auto key> constexpr bool constexpr_find()
{
// using ElementType = typename std::remove_extent<ContainerType>::type;
bool found = false;
constexpr_for<0, container.size(), 1>([&]<size_t k>() {
if constexpr (std::get<k>(container) == key) {
found = true;
}
});
return found;
}
/**
* @brief Create a constexpr array object whose elements contain a default value
*
* @tparam T type contained in the array
* @tparam Is index sequence
* @param value the value each array element is being initialized to
* @return constexpr std::array<T, sizeof...(Is)>
*
* @details This method is used to create constexpr arrays whose encapsulated type:
*
* 1. HAS NO CONSTEXPR DEFAULT CONSTRUCTOR
* 2. HAS A CONSTEXPR COPY CONSTRUCTOR
*
* An example of this is bb::field_t
* (the default constructor does not default assign values to the field_t member variables for efficiency reasons, to
 * reduce the time required to construct large arrays of field elements. This means the default constructor for field_t
* cannot be constexpr)
*/
template <typename T, std::size_t... Is>
constexpr std::array<T, sizeof...(Is)> create_array(T value, std::index_sequence<Is...> /*unused*/)
{
// cast Is to void to remove the warning: unused value
std::array<T, sizeof...(Is)> result = { { (static_cast<void>(Is), value)... } };
return result;
}
/**
* @brief Create a constexpr array object whose values all are 0
*
* @tparam T
* @tparam N
* @return constexpr std::array<T, N>
*
* @details Use in the same context as create_array, i.e. when encapsulated type has a default constructor that is not
* constexpr
*/
template <typename T, size_t N> constexpr std::array<T, N> create_empty_array()
{
return create_array(T(0), std::make_index_sequence<N>());
}
}; // namespace bb
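A compile-time usage sketch (array name and values are illustrative), assuming C++20:

#include <array>
#include <cstddef>
static constexpr std::array<size_t, 4> SUPPORTED_SIZES{ 256, 512, 1024, 2048 };
static_assert(bb::constexpr_find<SUPPORTED_SIZES, size_t{ 1024 }>(), "found at compile time");
static_assert(!bb::constexpr_find<SUPPORTED_SIZES, size_t{ 768 }>());
// create_empty_array only needs T(0) and T's copy constructor to be constexpr.
static constexpr auto ZEROES = bb::create_empty_array<int, 8>();
static_assert(ZEROES[7] == 0);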

View File

@@ -0,0 +1,129 @@
#pragma once
#include "../env/logstr.hpp"
#include "../stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp"
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#define BENCHMARK_INFO_PREFIX "##BENCHMARK_INFO_PREFIX##"
#define BENCHMARK_INFO_SEPARATOR "#"
#define BENCHMARK_INFO_SUFFIX "##BENCHMARK_INFO_SUFFIX##"
template <typename... Args> std::string format(Args... args)
{
std::ostringstream os;
((os << args), ...);
return os.str();
}
template <typename T> void benchmark_format_chain(std::ostream& os, T const& first)
{
// We will be saving these values to a CSV file, so we can't tolerate commas
std::stringstream current_argument;
current_argument << first;
std::string current_argument_string = current_argument.str();
std::replace(current_argument_string.begin(), current_argument_string.end(), ',', ';');
os << current_argument_string << BENCHMARK_INFO_SUFFIX;
}
template <typename T, typename... Args>
void benchmark_format_chain(std::ostream& os, T const& first, Args const&... args)
{
// We will be saving these values to a CSV file, so we can't tolerate commas
std::stringstream current_argument;
current_argument << first;
std::string current_argument_string = current_argument.str();
std::replace(current_argument_string.begin(), current_argument_string.end(), ',', ';');
os << current_argument_string << BENCHMARK_INFO_SEPARATOR;
benchmark_format_chain(os, args...);
}
template <typename... Args> std::string benchmark_format(Args... args)
{
std::ostringstream os;
os << BENCHMARK_INFO_PREFIX;
benchmark_format_chain(os, args...);
return os.str();
}
#if NDEBUG
template <typename... Args> inline void debug(Args... /*unused*/) {}
#else
template <typename... Args> inline void debug(Args... args)
{
    logstr(format(args...).c_str());
}
#endif
template <typename... Args> inline void info(Args... args)
{
logstr(format(args...).c_str());
}
template <typename... Args> inline void important(Args... args)
{
logstr(format("important: ", args...).c_str());
}
/**
* @brief Info used to store circuit statistics during CI/CD with concrete structure. Writes straight to log
*
* @details Automatically appends the necessary prefix and suffix, as well as separators.
*
* @tparam Args
* @param args
*/
#ifdef CI
template <typename Arg1, typename Arg2, typename Arg3, typename Arg4, typename Arg5>
inline void benchmark_info(Arg1 composer, Arg2 class_name, Arg3 operation, Arg4 metric, Arg5 value)
{
logstr(benchmark_format(composer, class_name, operation, metric, value).c_str());
}
#else
template <typename... Args> inline void benchmark_info(Args... /*unused*/) {}
#endif
/**
 * @brief A class for saving benchmarks and printing them all at once when the collator goes out of scope.
*
*/
class BenchmarkInfoCollator {
std::vector<std::string> saved_benchmarks;
public:
BenchmarkInfoCollator() = default;
BenchmarkInfoCollator(const BenchmarkInfoCollator& other) = default;
BenchmarkInfoCollator(BenchmarkInfoCollator&& other) = default;
BenchmarkInfoCollator& operator=(const BenchmarkInfoCollator& other) = default;
BenchmarkInfoCollator& operator=(BenchmarkInfoCollator&& other) = default;
/**
 * @brief Info used to store circuit statistics during CI/CD with concrete structure. Stores the string in a
 * vector for now (used to flush all benchmarks at the end of a test).
*
* @details Automatically appends the necessary prefix and suffix, as well as separators.
*
* @tparam Args
* @param args
*/
#ifdef CI
template <typename Arg1, typename Arg2, typename Arg3, typename Arg4, typename Arg5>
inline void benchmark_info_deferred(Arg1 composer, Arg2 class_name, Arg3 operation, Arg4 metric, Arg5 value)
{
saved_benchmarks.push_back(benchmark_format(composer, class_name, operation, metric, value).c_str());
}
#else
explicit BenchmarkInfoCollator(std::vector<std::string> saved_benchmarks)
: saved_benchmarks(std::move(saved_benchmarks))
{}
template <typename... Args> inline void benchmark_info_deferred(Args... /*unused*/) {}
#endif
~BenchmarkInfoCollator()
{
for (auto& x : saved_benchmarks) {
logstr(x.c_str());
}
}
};
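A usage sketch (values illustrative): format() folds its arguments into one stream, and benchmark_format() wraps CSV-safe fields in the prefix/separator/suffix markers.

inline void log_example()
{
    info("circuit built with ", 1024, " gates"); // -> logstr("circuit built with 1024 gates")
    std::string row = benchmark_format("ultra", "Builder", "construct", "gates", 1024);
    // row == "##BENCHMARK_INFO_PREFIX##ultra#Builder#construct#gates#1024##BENCHMARK_INFO_SUFFIX##"
}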

View File

@@ -0,0 +1,82 @@
#pragma once
#include "log.hpp"
#include "memory.h"
#include "wasm_export.hpp"
#include <cstdlib>
#include <memory>
// #include <malloc.h>
#define pad(size, alignment) ((size) - ((size) % (alignment)) + (((size) % (alignment)) == 0 ? 0 : (alignment)))
#ifdef __APPLE__
inline void* aligned_alloc(size_t alignment, size_t size)
{
void* t = 0;
posix_memalign(&t, alignment, size);
if (t == 0) {
info("bad alloc of size: ", size);
std::abort();
}
return t;
}
inline void aligned_free(void* mem)
{
free(mem);
}
#endif
#if defined(__linux__) || defined(__wasm__)
inline void* protected_aligned_alloc(size_t alignment, size_t size)
{
size += (size % alignment);
void* t = nullptr;
// pad size to alignment
if (size % alignment != 0) {
size += alignment - (size % alignment);
}
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
t = aligned_alloc(alignment, size);
if (t == nullptr) {
info("bad alloc of size: ", size);
std::abort();
}
return t;
}
#define aligned_alloc protected_aligned_alloc
inline void aligned_free(void* mem)
{
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory, cppcoreguidelines-no-malloc)
free(mem);
}
#endif
#ifdef _WIN32
inline void* aligned_alloc(size_t alignment, size_t size)
{
return _aligned_malloc(size, alignment);
}
inline void aligned_free(void* mem)
{
_aligned_free(mem);
}
#endif
// inline void print_malloc_info()
// {
// struct mallinfo minfo = mallinfo();
// info("Total non-mmapped bytes (arena): ", minfo.arena);
// info("Number of free chunks (ordblks): ", minfo.ordblks);
// info("Number of fastbin blocks (smblks): ", minfo.smblks);
// info("Number of mmapped regions (hblks): ", minfo.hblks);
// info("Space allocated in mmapped regions (hblkhd): ", minfo.hblkhd);
// info("Maximum total allocated space (usmblks): ", minfo.usmblks);
// info("Space available in freed fastbin blocks (fsmblks): ", minfo.fsmblks);
// info("Total allocated space (uordblks): ", minfo.uordblks);
// info("Total free space (fordblks): ", minfo.fordblks);
// info("Top-most, releasable space (keepcost): ", minfo.keepcost);
// }
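A usage sketch (the wrapper function is hypothetical): allocations made through this header must be released with the matching aligned_free.

inline void aligned_alloc_example()
{
    void* buf = aligned_alloc(32, 1000); // on Linux/wasm the size is padded up to a multiple of 32
    aligned_free(buf);                   // pair with aligned_free; plain free is wrong on Windows
}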

View File

@@ -0,0 +1,15 @@
#pragma once
#if defined(__linux__) || defined(__wasm__)
#include <arpa/inet.h>
#include <endian.h>
#define ntohll be64toh
#define htonll htobe64
#endif
inline bool is_little_endian()
{
constexpr int num = 42;
// NOLINTNEXTLINE Nope. nope nope nope nope nope.
return (*(char*)&num == 42);
}
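A usage sketch (Linux/wasm only, where htonll is defined): htobe64 is already endianness-aware, so no is_little_endian() branch is needed around it.

inline uint64_t to_network_order(uint64_t x)
{
    return htonll(x); // big-endian on the wire regardless of host byte order
}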

View File

@@ -0,0 +1,104 @@
#include <cstddef>
#ifdef BB_USE_OP_COUNT
#include "op_count.hpp"
#include <iostream>
#include <sstream>
#include <thread>
namespace bb::detail {
GlobalOpCountContainer::~GlobalOpCountContainer()
{
// This is useful for printing counts at the end of non-benchmarks.
// See op_count_google_bench.hpp for benchmarks.
// print();
}
void GlobalOpCountContainer::add_entry(const char* key, const std::shared_ptr<OpStats>& count)
{
std::unique_lock<std::mutex> lock(mutex);
std::stringstream ss;
ss << std::this_thread::get_id();
counts.push_back({ key, ss.str(), count });
}
void GlobalOpCountContainer::print() const
{
std::cout << "print_op_counts() START" << std::endl;
for (const Entry& entry : counts) {
if (entry.count->count > 0) {
std::cout << entry.key << "\t" << entry.count->count << "\t[thread=" << entry.thread_id << "]" << std::endl;
}
if (entry.count->time > 0) {
std::cout << entry.key << "(t)\t" << static_cast<double>(entry.count->time) / 1000000.0
<< "ms\t[thread=" << entry.thread_id << "]" << std::endl;
}
if (entry.count->cycles > 0) {
std::cout << entry.key << "(c)\t" << entry.count->cycles << "\t[thread=" << entry.thread_id << "]"
<< std::endl;
}
}
std::cout << "print_op_counts() END" << std::endl;
}
std::map<std::string, std::size_t> GlobalOpCountContainer::get_aggregate_counts() const
{
std::map<std::string, std::size_t> aggregate_counts;
for (const Entry& entry : counts) {
if (entry.count->count > 0) {
aggregate_counts[entry.key] += entry.count->count;
}
if (entry.count->time > 0) {
aggregate_counts[entry.key + "(t)"] += entry.count->time;
}
if (entry.count->cycles > 0) {
aggregate_counts[entry.key + "(c)"] += entry.count->cycles;
}
}
return aggregate_counts;
}
void GlobalOpCountContainer::clear()
{
std::unique_lock<std::mutex> lock(mutex);
for (Entry& entry : counts) {
*entry.count = OpStats();
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
GlobalOpCountContainer GLOBAL_OP_COUNTS;
OpCountCycleReporter::OpCountCycleReporter(OpStats* stats)
: stats(stats)
{
#if __clang__ && (defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86))
// Don't support any other targets but x86 clang for now, this is a bit lazy but more than fits our needs
cycles = __builtin_ia32_rdtsc();
#endif
}
OpCountCycleReporter::~OpCountCycleReporter()
{
#if __clang__ && (defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86))
// Don't support any other targets but x86 clang for now, this is a bit lazy but more than fits our needs
stats->count += 1;
stats->cycles += __builtin_ia32_rdtsc() - cycles;
#endif
}
OpCountTimeReporter::OpCountTimeReporter(OpStats* stats)
: stats(stats)
{
auto now = std::chrono::high_resolution_clock::now();
auto now_ns = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
time = static_cast<std::size_t>(now_ns.time_since_epoch().count());
}
OpCountTimeReporter::~OpCountTimeReporter()
{
auto now = std::chrono::high_resolution_clock::now();
auto now_ns = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
stats->count += 1;
stats->time += static_cast<std::size_t>(now_ns.time_since_epoch().count()) - time;
}
} // namespace bb::detail
#endif

View File

@@ -0,0 +1,160 @@
#pragma once
#include <memory>
#ifndef BB_USE_OP_COUNT
// require a semicolon to appease formatters
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK() (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK_NAME(name) (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES_NAME(name) (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME_NAME(name) (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES() (void)0
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME() (void)0
#else
/**
* Provides an abstraction that counts operations based on function names.
* For efficiency, we spread out counts across threads.
*/
#include "./compiler_hints.hpp"
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <map>
#include <mutex>
#include <string>
#include <vector>
namespace bb::detail {
// Compile-time string
// See e.g. https://www.reddit.com/r/cpp_questions/comments/pumi9r/does_c20_not_support_string_literals_as_template/
template <std::size_t N> struct OperationLabel {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
constexpr OperationLabel(const char (&str)[N])
{
for (std::size_t i = 0; i < N; ++i) {
value[i] = str[i];
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
char value[N];
};
struct OpStats {
std::size_t count = 0;
std::size_t time = 0;
std::size_t cycles = 0;
};
// Contains all statically known op counts
struct GlobalOpCountContainer {
public:
struct Entry {
std::string key;
std::string thread_id;
std::shared_ptr<OpStats> count;
};
~GlobalOpCountContainer();
std::mutex mutex;
std::vector<Entry> counts;
void print() const;
// NOTE: Should be called when other threads aren't active
void clear();
void add_entry(const char* key, const std::shared_ptr<OpStats>& count);
std::map<std::string, std::size_t> get_aggregate_counts() const;
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
extern GlobalOpCountContainer GLOBAL_OP_COUNTS;
template <OperationLabel Op> struct GlobalOpCount {
public:
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local std::shared_ptr<OpStats> stats;
static OpStats* ensure_stats()
{
if (BB_UNLIKELY(stats == nullptr)) {
stats = std::make_shared<OpStats>();
GLOBAL_OP_COUNTS.add_entry(Op.value, stats);
}
return stats.get();
}
static constexpr void increment_op_count()
{
#ifndef BB_USE_OP_COUNT_TIME_ONLY
if (std::is_constant_evaluated()) {
// We do nothing if the compiler tries to run this
return;
}
ensure_stats();
stats->count++;
#endif
}
static constexpr void add_cycle_time(std::size_t cycles)
{
#ifndef BB_USE_OP_COUNT_TRACK_ONLY
if (std::is_constant_evaluated()) {
// We do nothing if the compiler tries to run this
return;
}
ensure_stats();
stats->cycles += cycles;
#else
static_cast<void>(cycles);
#endif
}
static constexpr void add_clock_time(std::size_t time)
{
#ifndef BB_USE_OP_COUNT_TRACK_ONLY
if (std::is_constant_evaluated()) {
// We do nothing if the compiler tries to run this
return;
}
ensure_stats();
stats->time += time;
#else
static_cast<void>(time);
#endif
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
template <OperationLabel Op> thread_local std::shared_ptr<OpStats> GlobalOpCount<Op>::stats;
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
struct OpCountCycleReporter {
OpStats* stats;
std::size_t cycles;
OpCountCycleReporter(OpStats* stats);
~OpCountCycleReporter();
};
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
struct OpCountTimeReporter {
OpStats* stats;
std::size_t time;
OpCountTimeReporter(OpStats* stats);
~OpCountTimeReporter();
};
} // namespace bb::detail
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK_NAME(name) bb::detail::GlobalOpCount<name>::increment_op_count()
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TRACK() BB_OP_COUNT_TRACK_NAME(__func__)
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES_NAME(name) \
    bb::detail::OpCountCycleReporter __bb_op_count_cycles(bb::detail::GlobalOpCount<name>::ensure_stats())
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_CYCLES() BB_OP_COUNT_CYCLES_NAME(__func__)
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME_NAME(name) \
bb::detail::OpCountTimeReporter __bb_op_count_time(bb::detail::GlobalOpCount<name>::ensure_stats())
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BB_OP_COUNT_TIME() BB_OP_COUNT_TIME_NAME(__func__)
#endif
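A usage sketch (function and label are hypothetical), assuming BB_USE_OP_COUNT is defined:

inline void compute_round()
{
    BB_OP_COUNT_TIME();                    // accumulates wall time under the key "compute_round"
    BB_OP_COUNT_TRACK_NAME("round_calls"); // bumps a named per-thread counter
    // ... work ...
    // Totals can later be printed with bb::detail::GLOBAL_OP_COUNTS.print().
}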

View File

@@ -0,0 +1,243 @@
#include "slab_allocator.hpp"
#include <barretenberg/common/assert.hpp>
#include <barretenberg/common/log.hpp>
#include <barretenberg/common/mem.hpp>
#include <cstddef>
#include <numeric>
#include <unordered_map>
#define LOGGING 0
/**
* If we can guarantee that all slabs will be released before the allocator is destroyed, we wouldn't need this.
 * However, there are (and may again be) cases where a global is holding onto a slab. In such a case you will have
* issues if the runtime frees the allocator before the slab is released. The effect is subtle, so it's worth
* protecting against rather than just saying "don't do globals". But you know, don't do globals...
* (Irony of global slab allocator noted).
*/
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool allocator_destroyed = false;
// Slabs that are being manually managed by the user.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
std::unordered_map<void*, std::shared_ptr<void>> manual_slabs;
#ifndef NO_MULTITHREADING
// The manual slabs unordered map is not thread-safe, so we need to manage access to it when multithreaded.
std::mutex manual_slabs_mutex;
#endif
template <typename... Args> inline void dbg_info(Args... args)
{
#if LOGGING == 1
info(args...);
#else
// Suppress warning.
(void)(sizeof...(args));
#endif
}
/**
 * Allows preallocating memory slabs, exploiting the fact that slab sizes and counts follow predictable
 * patterns determined by the prover system type and circuit size. Without the slab allocator, memory
 * fragmentation prevents proof construction when approaching memory space limits (4GB in WASM).
*
* If no circuit_size_hint is given to the constructor, it behaves as a standard memory allocator.
*/
class SlabAllocator {
private:
    size_t circuit_size_hint_ = 0; // zero-initialized so the comparison in init() is well-defined
std::map<size_t, std::list<void*>> memory_store;
#ifndef NO_MULTITHREADING
std::mutex memory_store_mutex;
#endif
public:
~SlabAllocator();
SlabAllocator() = default;
SlabAllocator(const SlabAllocator& other) = delete;
SlabAllocator(SlabAllocator&& other) = delete;
SlabAllocator& operator=(const SlabAllocator& other) = delete;
SlabAllocator& operator=(SlabAllocator&& other) = delete;
void init(size_t circuit_size_hint);
std::shared_ptr<void> get(size_t size);
size_t get_total_size();
private:
void release(void* ptr, size_t size);
};
SlabAllocator::~SlabAllocator()
{
allocator_destroyed = true;
for (auto& e : memory_store) {
for (auto& p : e.second) {
aligned_free(p);
}
}
}
void SlabAllocator::init(size_t circuit_size_hint)
{
if (circuit_size_hint <= circuit_size_hint_) {
return;
}
circuit_size_hint_ = circuit_size_hint;
// Free any existing slabs.
for (auto& e : memory_store) {
for (auto& p : e.second) {
aligned_free(p);
}
}
memory_store.clear();
dbg_info("slab allocator initing for size: ", circuit_size_hint);
if (circuit_size_hint == 0ULL) {
return;
}
// Over-allocate because we know there are requests for circuit_size + n. (somewhat arbitrary n = 512)
size_t overalloc = 512;
size_t base_size = circuit_size_hint + overalloc;
std::map<size_t, size_t> prealloc_num;
// Size comments below assume a base (circuit) size of 2^19, 524288 bytes.
// /* 0.5 MiB */ prealloc_num[base_size * 1] = 2; // Batch invert skipped temporary.
// /* 2 MiB */ prealloc_num[base_size * 4] = 4 + // Composer base wire vectors.
// 1; // Miscellaneous.
// /* 6 MiB */ prealloc_num[base_size * 12] = 2 + // next_var_index, prev_var_index
// 2; // real_variable_index, real_variable_tags
/* 16 MiB */ prealloc_num[base_size * 32] = 11; // Composer base selector vectors.
/* 32 MiB */ prealloc_num[base_size * 32 * 2] = 1; // Miscellaneous.
/* 50 MiB */ prealloc_num[base_size * 32 * 3] = 1; // Variables.
/* 64 MiB */ prealloc_num[base_size * 32 * 4] = 1 + // SRS monomial points.
4 + // Coset-fft wires.
15 + // Coset-fft constraint selectors.
8 + // Coset-fft perm selectors.
1 + // Coset-fft sorted poly.
1 + // Pippenger point_schedule.
4; // Miscellaneous.
/* 128 MiB */ prealloc_num[base_size * 32 * 8] = 1 + // Proving key evaluation domain roots.
2; // Pippenger point_pairs.
for (auto& e : prealloc_num) {
for (size_t i = 0; i < e.second; ++i) {
auto size = e.first;
memory_store[size].push_back(aligned_alloc(32, size));
dbg_info("Allocated memory slab of size: ", size, " total: ", get_total_size());
}
}
}
std::shared_ptr<void> SlabAllocator::get(size_t req_size)
{
#ifndef NO_MULTITHREADING
std::unique_lock<std::mutex> lock(memory_store_mutex);
#endif
auto it = memory_store.lower_bound(req_size);
// Can use a preallocated slab that is less than 2 times the requested size.
if (it != memory_store.end() && it->first < req_size * 2) {
size_t size = it->first;
auto* ptr = it->second.back();
it->second.pop_back();
if (it->second.empty()) {
memory_store.erase(it);
}
if (req_size >= circuit_size_hint_ && size > req_size + req_size / 10) {
dbg_info("WARNING: Using memory slab of size: ",
size,
" for requested ",
req_size,
" total: ",
get_total_size());
} else {
dbg_info("Reusing memory slab of size: ", size, " for requested ", req_size, " total: ", get_total_size());
}
return { ptr, [this, size](void* p) {
if (allocator_destroyed) {
aligned_free(p);
return;
}
this->release(p, size);
} };
}
if (req_size > static_cast<size_t>(1024 * 1024)) {
dbg_info("WARNING: Allocating unmanaged memory slab of size: ", req_size);
}
if (req_size % 32 == 0) {
return { aligned_alloc(32, req_size), aligned_free };
}
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
return { malloc(req_size), free };
}
size_t SlabAllocator::get_total_size()
{
return std::accumulate(memory_store.begin(), memory_store.end(), size_t{ 0 }, [](size_t acc, const auto& kv) {
return acc + kv.first * kv.second.size();
});
}
void SlabAllocator::release(void* ptr, size_t size)
{
#ifndef NO_MULTITHREADING
std::unique_lock<std::mutex> lock(memory_store_mutex);
#endif
memory_store[size].push_back(ptr);
// dbg_info("Pooled poly memory of size: ", size, " total: ", get_total_size());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SlabAllocator allocator;
} // namespace
namespace bb {
void init_slab_allocator(size_t circuit_subgroup_size)
{
allocator.init(circuit_subgroup_size);
}
// auto init = ([]() {
// init_slab_allocator(524288);
// return 0;
// })();
std::shared_ptr<void> get_mem_slab(size_t size)
{
return allocator.get(size);
}
void* get_mem_slab_raw(size_t size)
{
auto slab = get_mem_slab(size);
#ifndef NO_MULTITHREADING
std::unique_lock<std::mutex> lock(manual_slabs_mutex);
#endif
manual_slabs[slab.get()] = slab;
return slab.get();
}
void free_mem_slab_raw(void* p)
{
if (allocator_destroyed) {
aligned_free(p);
return;
}
#ifndef NO_MULTITHREADING
std::unique_lock<std::mutex> lock(manual_slabs_mutex);
#endif
manual_slabs.erase(p);
}
} // namespace bb

View File

@@ -0,0 +1,78 @@
#pragma once
#include "./assert.hpp"
#include "./log.hpp"
#include <list>
#include <map>
#include <memory>
#include <unordered_map>
#ifndef NO_MULTITHREADING
#include <mutex>
#endif
namespace bb {
/**
* Allocates a bunch of memory slabs sized to serve an UltraPLONK proof construction.
* If you want normal memory allocator behavior, just don't call this init function.
*
* WARNING: If client code is still holding onto slabs from previous use, when those slabs
* are released they'll end up back in the allocator. That's probably not desired as presumably
 * those slabs are now too small, so they're effectively leaked. But good client code should be releasing
 * its resources promptly anyway. It's not considered "proper use" to call init, take a slab, and call init
* again, before releasing the slab.
*
* TODO: Take a composer type and allocate slabs according to those requirements?
* TODO: De-globalise. Init the allocator and pass around. Use a PolynomialFactory (PolynomialStore?).
 * TODO: Consider removing, but only once due diligence has been done that we no longer have memory limitations.
*/
void init_slab_allocator(size_t circuit_subgroup_size);
/**
* Returns a slab from the preallocated pool of slabs, or fallback to a new heap allocation (32 byte aligned).
* Ref counted result so no need to manually free.
*/
std::shared_ptr<void> get_mem_slab(size_t size);
/**
* Sometimes you want a raw pointer to a slab so you can manage when it's released manually (e.g. c_binds, containers).
* This still gets a slab with a shared_ptr, but holds the shared_ptr internally until free_mem_slab_raw is called.
*/
void* get_mem_slab_raw(size_t size);
void free_mem_slab_raw(void*);
/**
* Allocator for containers such as std::vector. Makes them leverage the underlying slab allocator where possible.
*/
template <typename T> class ContainerSlabAllocator {
public:
using value_type = T;
using pointer = T*;
using const_pointer = const T*;
using size_type = std::size_t;
template <typename U> struct rebind {
using other = ContainerSlabAllocator<U>;
};
pointer allocate(size_type n)
{
// info("ContainerSlabAllocator allocating: ", n * sizeof(T));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
return reinterpret_cast<pointer>(get_mem_slab_raw(n * sizeof(T)));
}
void deallocate(pointer p, size_type /*unused*/) { free_mem_slab_raw(p); }
friend bool operator==(const ContainerSlabAllocator<T>& /*unused*/, const ContainerSlabAllocator<T>& /*unused*/)
{
return true;
}
friend bool operator!=(const ContainerSlabAllocator<T>& /*unused*/, const ContainerSlabAllocator<T>& /*unused*/)
{
return false;
}
};
} // namespace bb
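A usage sketch (sizes illustrative; assumes <cstdint> and <vector> are included):

inline void slab_example()
{
    bb::init_slab_allocator(1 << 19);      // preallocate for a 2^19 circuit
    auto poly = bb::get_mem_slab(1 << 25); // ref-counted; memory returns to the pool on release
    std::vector<uint64_t, bb::ContainerSlabAllocator<uint64_t>> scalars(1024);
}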

View File

@@ -0,0 +1,13 @@
#pragma once
#include "log.hpp"
#include <string>
inline void throw_or_abort [[noreturn]] (std::string const& err)
{
#ifndef __wasm__
throw std::runtime_error(err);
#else
info("abort: ", err);
std::abort();
#endif
}
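A usage sketch (the helper is hypothetical): prefer this over a bare `throw` so wasm builds, which compile without exceptions, log the message and abort instead.

inline void require(bool condition, std::string const& msg)
{
    if (!condition) {
        throw_or_abort(msg);
    }
}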

View File

@@ -0,0 +1,16 @@
#pragma once
#ifdef __clang__
#define WASM_EXPORT extern "C" __attribute__((visibility("default"))) __attribute__((annotate("wasm_export")))
#define ASYNC_WASM_EXPORT \
extern "C" __attribute__((visibility("default"))) __attribute__((annotate("async_wasm_export")))
#else
#define WASM_EXPORT extern "C" __attribute__((visibility("default")))
#define ASYNC_WASM_EXPORT extern "C" __attribute__((visibility("default")))
#endif
#ifdef __wasm__
// Declare the symbol as imported from the host "env" module so the linker leaves it unresolved.
#define WASM_IMPORT(name) extern "C" __attribute__((import_module("env"), import_name(name)))
#else
#define WASM_IMPORT(name) extern "C"
#endif
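A usage sketch (signatures hypothetical):

WASM_EXPORT void my_entry_point(char const* input, char* output); // visible to the wasm host
WASM_IMPORT("logstr") void logstr(char const* msg);               // resolved by the host "env" module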

View File

@@ -0,0 +1 @@
barretenberg_module(crypto_blake3s)

View File

@@ -0,0 +1,80 @@
#pragma once
/*
BLAKE3 reference source code package - C implementations
Intellectual property:
The Rust code is copyright Jack O'Connor, 2019-2020.
The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020.
The assembly code is copyright Samuel Neves, 2019-2020.
This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache
License 2.0.
- CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0
- Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0
More information about the BLAKE3 hash function can be found at
https://github.com/BLAKE3-team/BLAKE3.
*/
#ifndef BLAKE3_IMPL_H
#define BLAKE3_IMPL_H
#include <cstddef>
#include <cstdint>
#include <cstring>
#include "blake3s.hpp"
namespace blake3 {
// Right rotates 32 bit inputs
constexpr uint32_t rotr32(uint32_t w, uint32_t c)
{
return (w >> c) | (w << (32 - c));
}
constexpr uint32_t load32(const uint8_t* src)
{
return (static_cast<uint32_t>(src[0]) << 0) | (static_cast<uint32_t>(src[1]) << 8) |
(static_cast<uint32_t>(src[2]) << 16) | (static_cast<uint32_t>(src[3]) << 24);
}
constexpr void load_key_words(const std::array<uint8_t, BLAKE3_KEY_LEN>& key, key_array& key_words)
{
key_words[0] = load32(&key[0]);
key_words[1] = load32(&key[4]);
key_words[2] = load32(&key[8]);
key_words[3] = load32(&key[12]);
key_words[4] = load32(&key[16]);
key_words[5] = load32(&key[20]);
key_words[6] = load32(&key[24]);
key_words[7] = load32(&key[28]);
}
constexpr void store32(uint8_t* dst, uint32_t w)
{
dst[0] = static_cast<uint8_t>(w >> 0);
dst[1] = static_cast<uint8_t>(w >> 8);
dst[2] = static_cast<uint8_t>(w >> 16);
dst[3] = static_cast<uint8_t>(w >> 24);
}
constexpr void store_cv_words(out_array& bytes_out, key_array& cv_words)
{
store32(&bytes_out[0], cv_words[0]);
store32(&bytes_out[4], cv_words[1]);
store32(&bytes_out[8], cv_words[2]);
store32(&bytes_out[12], cv_words[3]);
store32(&bytes_out[16], cv_words[4]);
store32(&bytes_out[20], cv_words[5]);
store32(&bytes_out[24], cv_words[6]);
store32(&bytes_out[28], cv_words[7]);
}
} // namespace blake3
#include "blake3s.tcc"
#endif

View File

@@ -0,0 +1,113 @@
/*
BLAKE3 reference source code package - C implementations
Intellectual property:
The Rust code is copyright Jack O'Connor, 2019-2020.
The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020.
The assembly code is copyright Samuel Neves, 2019-2020.
This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache
License 2.0.
- CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0
- Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0
More information about the BLAKE3 hash function can be found at
https://github.com/BLAKE3-team/BLAKE3.
NOTE: We have modified the original code from the BLAKE3 reference C implementation.
The following code works ONLY for inputs of size less than 1024 bytes. This constraint
on the input size greatly simplifies the code and lets us drop the recursive Merkle-tree-like
operations on chunks (data of size 1024 bytes). This is because barretenberg only ever uses BLAKE3
hashing for inputs of 32 bytes or less. The full C++ version of BLAKE3
from the original authors is in the module `../crypto/blake3s_full`.
Also, the output length in this implementation is fixed at 32 bytes, which is the only
version relevant to Barretenberg.
*/
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
namespace blake3 {
// internal flags
enum blake3_flags {
CHUNK_START = 1 << 0,
CHUNK_END = 1 << 1,
PARENT = 1 << 2,
ROOT = 1 << 3,
KEYED_HASH = 1 << 4,
DERIVE_KEY_CONTEXT = 1 << 5,
DERIVE_KEY_MATERIAL = 1 << 6,
};
// constants
enum blake3s_constant {
BLAKE3_KEY_LEN = 32,
BLAKE3_OUT_LEN = 32,
BLAKE3_BLOCK_LEN = 64,
BLAKE3_CHUNK_LEN = 1024,
BLAKE3_MAX_DEPTH = 54
};
using key_array = std::array<uint32_t, BLAKE3_KEY_LEN>;
using block_array = std::array<uint8_t, BLAKE3_BLOCK_LEN>;
using state_array = std::array<uint32_t, 16>;
using out_array = std::array<uint8_t, BLAKE3_OUT_LEN>;
static constexpr key_array IV = { 0x6A09E667UL, 0xBB67AE85UL, 0x3C6EF372UL, 0xA54FF53AUL,
0x510E527FUL, 0x9B05688CUL, 0x1F83D9ABUL, 0x5BE0CD19UL };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_0 = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_1 = { 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_2 = { 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_3 = { 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_4 = { 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_5 = { 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 };
static constexpr std::array<uint8_t, 16> MSG_SCHEDULE_6 = { 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 };
static constexpr std::array<std::array<uint8_t, 16>, 7> MSG_SCHEDULE = {
MSG_SCHEDULE_0, MSG_SCHEDULE_1, MSG_SCHEDULE_2, MSG_SCHEDULE_3, MSG_SCHEDULE_4, MSG_SCHEDULE_5, MSG_SCHEDULE_6,
};
struct blake3_hasher {
key_array key;
key_array cv;
block_array buf;
uint8_t buf_len = 0;
uint8_t blocks_compressed = 0;
uint8_t flags = 0;
};
inline const char* blake3_version()
{
static const std::string version = "0.3.7";
return version.c_str();
}
constexpr void blake3_hasher_init(blake3_hasher* self);
constexpr void blake3_hasher_update(blake3_hasher* self, const uint8_t* input, size_t input_len);
constexpr void blake3_hasher_finalize(const blake3_hasher* self, uint8_t* out);
constexpr void g(state_array& state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y);
constexpr void round_fn(state_array& state, const uint32_t* msg, size_t round);
constexpr void compress_pre(
state_array& state, const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags);
constexpr void blake3_compress_in_place(key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags);
constexpr void blake3_compress_xof(
const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags, uint8_t* out);
constexpr std::array<uint8_t, BLAKE3_OUT_LEN> blake3s_constexpr(const uint8_t* input, size_t input_size);
inline std::vector<uint8_t> blake3s(std::vector<uint8_t> const& input);
} // namespace blake3
#include "blake3-impl.hpp"

View File

@@ -0,0 +1,263 @@
#pragma once
/*
BLAKE3 reference source code package - C implementations
Intellectual property:
The Rust code is copyright Jack O'Connor, 2019-2020.
The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020.
The assembly code is copyright Samuel Neves, 2019-2020.
This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache
License 2.0.
- CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0
- Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0
More information about the BLAKE3 hash function can be found at
https://github.com/BLAKE3-team/BLAKE3.
NOTE: We have modified the original code from the BLAKE3 reference C implementation.
The following code works ONLY for inputs of size less than 1024 bytes. This constraint
on the input size greatly simplifies the code and lets us drop the recursive Merkle-tree-like
operations on chunks (data of size 1024 bytes). This is because barretenberg only ever uses BLAKE3
hashing for inputs of 32 bytes or less. The full C++ version of BLAKE3
from the original authors is in the module `../crypto/blake3s_full`.
Also, the output length in this implementation is fixed at 32 bytes, which is the only
version relevant to Barretenberg.
*/
#include <iostream>
#include <type_traits>
#include "blake3s.hpp"
namespace blake3 {
/*
 * Core Blake3s functions. These are similar to those of Blake2s except for a few
 * constant parameters and fewer rounds.
*
*/
constexpr void g(state_array& state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y)
{
state[a] = state[a] + state[b] + x;
state[d] = rotr32(state[d] ^ state[a], 16);
state[c] = state[c] + state[d];
state[b] = rotr32(state[b] ^ state[c], 12);
state[a] = state[a] + state[b] + y;
state[d] = rotr32(state[d] ^ state[a], 8);
state[c] = state[c] + state[d];
state[b] = rotr32(state[b] ^ state[c], 7);
}
constexpr void round_fn(state_array& state, const uint32_t* msg, size_t round)
{
// Select the message schedule based on the round.
const auto schedule = MSG_SCHEDULE[round];
// Mix the columns.
g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
// Mix the rows.
g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
}
constexpr void compress_pre(
state_array& state, const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags)
{
std::array<uint32_t, 16> block_words;
block_words[0] = load32(&block[0]);
block_words[1] = load32(&block[4]);
block_words[2] = load32(&block[8]);
block_words[3] = load32(&block[12]);
block_words[4] = load32(&block[16]);
block_words[5] = load32(&block[20]);
block_words[6] = load32(&block[24]);
block_words[7] = load32(&block[28]);
block_words[8] = load32(&block[32]);
block_words[9] = load32(&block[36]);
block_words[10] = load32(&block[40]);
block_words[11] = load32(&block[44]);
block_words[12] = load32(&block[48]);
block_words[13] = load32(&block[52]);
block_words[14] = load32(&block[56]);
block_words[15] = load32(&block[60]);
state[0] = cv[0];
state[1] = cv[1];
state[2] = cv[2];
state[3] = cv[3];
state[4] = cv[4];
state[5] = cv[5];
state[6] = cv[6];
state[7] = cv[7];
state[8] = IV[0];
state[9] = IV[1];
state[10] = IV[2];
state[11] = IV[3];
state[12] = 0;
state[13] = 0;
state[14] = static_cast<uint32_t>(block_len);
state[15] = static_cast<uint32_t>(flags);
round_fn(state, &block_words[0], 0);
round_fn(state, &block_words[0], 1);
round_fn(state, &block_words[0], 2);
round_fn(state, &block_words[0], 3);
round_fn(state, &block_words[0], 4);
round_fn(state, &block_words[0], 5);
round_fn(state, &block_words[0], 6);
}
constexpr void blake3_compress_in_place(key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags)
{
state_array state;
compress_pre(state, cv, block, block_len, flags);
cv[0] = state[0] ^ state[8];
cv[1] = state[1] ^ state[9];
cv[2] = state[2] ^ state[10];
cv[3] = state[3] ^ state[11];
cv[4] = state[4] ^ state[12];
cv[5] = state[5] ^ state[13];
cv[6] = state[6] ^ state[14];
cv[7] = state[7] ^ state[15];
}
constexpr void blake3_compress_xof(
const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags, uint8_t* out)
{
state_array state;
compress_pre(state, cv, block, block_len, flags);
store32(&out[0], state[0] ^ state[8]);
store32(&out[4], state[1] ^ state[9]);
store32(&out[8], state[2] ^ state[10]);
store32(&out[12], state[3] ^ state[11]);
store32(&out[16], state[4] ^ state[12]);
store32(&out[20], state[5] ^ state[13]);
store32(&out[24], state[6] ^ state[14]);
store32(&out[28], state[7] ^ state[15]);
store32(&out[32], state[8] ^ cv[0]);
store32(&out[36], state[9] ^ cv[1]);
store32(&out[40], state[10] ^ cv[2]);
store32(&out[44], state[11] ^ cv[3]);
store32(&out[48], state[12] ^ cv[4]);
store32(&out[52], state[13] ^ cv[5]);
store32(&out[56], state[14] ^ cv[6]);
store32(&out[60], state[15] ^ cv[7]);
}
constexpr uint8_t maybe_start_flag(const blake3_hasher* self)
{
if (self->blocks_compressed == 0) {
return CHUNK_START;
}
return 0;
}
struct output_t {
key_array input_cv = {};
block_array block = {};
uint8_t block_len = 0;
uint8_t flags = 0;
};
constexpr output_t make_output(const key_array& input_cv, const uint8_t* block, uint8_t block_len, uint8_t flags)
{
output_t ret;
for (size_t i = 0; i < (BLAKE3_OUT_LEN >> 2); ++i) {
ret.input_cv[i] = input_cv[i];
}
for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) {
ret.block[i] = block[i];
}
ret.block_len = block_len;
ret.flags = flags;
return ret;
}
constexpr void blake3_hasher_init(blake3_hasher* self)
{
for (size_t i = 0; i < (BLAKE3_KEY_LEN >> 2); ++i) {
self->key[i] = IV[i];
self->cv[i] = IV[i];
}
for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) {
self->buf[i] = 0;
}
self->buf_len = 0;
self->blocks_compressed = 0;
self->flags = 0;
}
constexpr void blake3_hasher_update(blake3_hasher* self, const uint8_t* input, size_t input_len)
{
if (input_len == 0) {
return;
}
while (input_len > BLAKE3_BLOCK_LEN) {
blake3_compress_in_place(self->cv, input, BLAKE3_BLOCK_LEN, self->flags | maybe_start_flag(self));
self->blocks_compressed = static_cast<uint8_t>(self->blocks_compressed + 1U);
input += BLAKE3_BLOCK_LEN;
input_len -= BLAKE3_BLOCK_LEN;
}
size_t take = BLAKE3_BLOCK_LEN - (static_cast<size_t>(self->buf_len));
if (take > input_len) {
take = input_len;
}
uint8_t* dest = &self->buf[0] + (static_cast<size_t>(self->buf_len));
for (size_t i = 0; i < take; i++) {
dest[i] = input[i];
}
self->buf_len = static_cast<uint8_t>(self->buf_len + static_cast<uint8_t>(take));
input_len -= take;
}
constexpr void blake3_hasher_finalize(const blake3_hasher* self, uint8_t* out)
{
uint8_t block_flags = self->flags | maybe_start_flag(self) | CHUNK_END;
output_t output = make_output(self->cv, &self->buf[0], self->buf_len, block_flags);
block_array wide_buf;
blake3_compress_xof(output.input_cv, &output.block[0], output.block_len, output.flags | ROOT, &wide_buf[0]);
for (size_t i = 0; i < BLAKE3_OUT_LEN; i++) {
out[i] = wide_buf[i];
}
}
std::vector<uint8_t> blake3s(std::vector<uint8_t> const& input)
{
blake3_hasher hasher;
blake3_hasher_init(&hasher);
blake3_hasher_update(&hasher, static_cast<const uint8_t*>(input.data()), input.size());
std::vector<uint8_t> output(BLAKE3_OUT_LEN);
blake3_hasher_finalize(&hasher, &output[0]);
return output;
}
constexpr std::array<uint8_t, BLAKE3_OUT_LEN> blake3s_constexpr(const uint8_t* input, const size_t input_size)
{
blake3_hasher hasher;
blake3_hasher_init(&hasher);
blake3_hasher_update(&hasher, input, input_size);
std::array<uint8_t, BLAKE3_OUT_LEN> output;
blake3_hasher_finalize(&hasher, &output[0]);
return output;
}
} // namespace blake3
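A compile-time usage sketch (message illustrative; assumes the whole pipeline evaluates in a constant expression under C++20):

constexpr std::array<uint8_t, 3> MSG{ 'a', 'b', 'c' };
constexpr auto DIGEST = blake3::blake3s_constexpr(MSG.data(), MSG.size());
static_assert(DIGEST.size() == blake3::BLAKE3_OUT_LEN);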

View File

@@ -0,0 +1,11 @@
#include "../../common/wasm_export.hpp"
#include "../../ecc/curves/bn254/fr.hpp"
#include "blake3s.hpp"
WASM_EXPORT void blake3s_to_field(uint8_t const* data, size_t length, uint8_t* r)
{
std::vector<uint8_t> inputv(data, data + length);
std::vector<uint8_t> output = blake3::blake3s(inputv);
auto result = bb::fr::serialize_from_buffer(output.data());
bb::fr::serialize_to_buffer(result, r);
}

View File

@@ -0,0 +1 @@
barretenberg_module(crypto_keccak)

View File

@@ -0,0 +1,20 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
* Copyright 2018-2019 Pawel Bylica.
* Licensed under the Apache License, Version 2.0.
*/
#pragma once
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
struct keccak256 {
uint64_t word64s[4];
};
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,133 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
* Copyright 2018-2019 Pawel Bylica.
* Licensed under the Apache License, Version 2.0.
*/
#include "keccak.hpp"
#include "./hash_types.hpp"
#if _MSC_VER
#include <string.h>
#define __builtin_memcpy memcpy
#endif
#if _WIN32
/* On Windows assume little endian. */
#define __LITTLE_ENDIAN 1234
#define __BIG_ENDIAN 4321
#define __BYTE_ORDER __LITTLE_ENDIAN
#elif __APPLE__
#include <machine/endian.h>
#else
#include <endian.h>
#endif
#if __BYTE_ORDER == __LITTLE_ENDIAN
#define to_le64(X) X
#else
#define to_le64(X) __builtin_bswap64(X)
#endif
#if __BYTE_ORDER == __LITTLE_ENDIAN
#define to_be64(X) __builtin_bswap64(X)
#else
#define to_be64(X) X
#endif
/** Loads 64-bit integer from given memory location as little-endian number. */
static inline uint64_t load_le(const uint8_t* data)
{
    /* memcpy is the best way of expressing the intention. Every compiler will
       optimize it to a single load instruction if the target architecture
       supports unaligned memory access (GCC and clang even at -O0).
       This is a great trick because we are violating C/C++ memory alignment
       restrictions with no performance penalty. */
uint64_t word;
__builtin_memcpy(&word, data, sizeof(word));
return to_le64(word);
}
static inline void keccak(uint64_t* out, size_t bits, const uint8_t* data, size_t size)
{
static const size_t word_size = sizeof(uint64_t);
const size_t hash_size = bits / 8;
const size_t block_size = (1600 - bits * 2) / 8;
size_t i;
uint64_t* state_iter;
uint64_t last_word = 0;
uint8_t* last_word_iter = (uint8_t*)&last_word;
uint64_t state[25] = { 0 };
while (size >= block_size) {
for (i = 0; i < (block_size / word_size); ++i) {
state[i] ^= load_le(data);
data += word_size;
}
ethash_keccakf1600(state);
size -= block_size;
}
state_iter = state;
while (size >= word_size) {
*state_iter ^= load_le(data);
++state_iter;
data += word_size;
size -= word_size;
}
while (size > 0) {
*last_word_iter = *data;
++last_word_iter;
++data;
--size;
}
*last_word_iter = 0x01;
*state_iter ^= to_le64(last_word);
state[(block_size / word_size) - 1] ^= 0x8000000000000000;
ethash_keccakf1600(state);
for (i = 0; i < (hash_size / word_size); ++i)
out[i] = to_le64(state[i]);
}
struct keccak256 ethash_keccak256(const uint8_t* data, size_t size) NOEXCEPT
{
struct keccak256 hash;
keccak(hash.word64s, 256, data, size);
return hash;
}
struct keccak256 hash_field_elements(const uint64_t* limbs, size_t num_elements)
{
    uint8_t input_buffer[num_elements * 32]; // NB: variable-length array; relies on a GCC/clang extension
for (size_t i = 0; i < num_elements; ++i) {
for (size_t j = 0; j < 4; ++j) {
uint64_t word = (limbs[i * 4 + j]);
size_t idx = i * 32 + j * 8;
input_buffer[idx] = (uint8_t)((word >> 56) & 0xff);
input_buffer[idx + 1] = (uint8_t)((word >> 48) & 0xff);
input_buffer[idx + 2] = (uint8_t)((word >> 40) & 0xff);
input_buffer[idx + 3] = (uint8_t)((word >> 32) & 0xff);
input_buffer[idx + 4] = (uint8_t)((word >> 24) & 0xff);
input_buffer[idx + 5] = (uint8_t)((word >> 16) & 0xff);
input_buffer[idx + 6] = (uint8_t)((word >> 8) & 0xff);
input_buffer[idx + 7] = (uint8_t)(word & 0xff);
}
}
return ethash_keccak256(input_buffer, num_elements * 32);
}
struct keccak256 hash_field_element(const uint64_t* limb)
{
return hash_field_elements(limb, 1);
}
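A usage sketch (message illustrative): the 256-bit digest is returned as four 64-bit words in little-endian byte order.

static inline uint64_t keccak_example()
{
    const uint8_t msg[3] = { 'a', 'b', 'c' };
    struct keccak256 h = ethash_keccak256(msg, sizeof(msg));
    return h.word64s[0];
}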

View File

@@ -0,0 +1,40 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
* Copyright 2018-2019 Pawel Bylica.
* Licensed under the Apache License, Version 2.0.
*/
#pragma once
#include "./hash_types.hpp"
#include <stddef.h>
#ifdef __cplusplus
#define NOEXCEPT noexcept
#else
#define NOEXCEPT
#endif
#ifdef __cplusplus
extern "C" {
#endif
/**
* The Keccak-f[1600] function.
*
* The implementation of the Keccak-f function with 1600-bit width of the permutation (b).
 * The size of the state is also 1600 bits, which gives 25 64-bit words.
*
* @param state The state of 25 64-bit words on which the permutation is to be performed.
*/
void ethash_keccakf1600(uint64_t state[25]) NOEXCEPT;
struct keccak256 ethash_keccak256(const uint8_t* data, size_t size) NOEXCEPT;
struct keccak256 hash_field_elements(const uint64_t* limbs, size_t num_elements);
struct keccak256 hash_field_element(const uint64_t* limb);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,235 @@
/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm.
* Copyright 2018-2019 Pawel Bylica.
* Licensed under the Apache License, Version 2.0.
*/
#include "keccak.hpp"
#include <stdint.h>
static uint64_t rol(uint64_t x, unsigned s)
{
return (x << s) | (x >> (64 - s));
}
static const uint64_t round_constants[24] = {
0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, 0x000000000000808b,
0x0000000080000001, 0x8000000080008081, 0x8000000000008009, 0x000000000000008a, 0x0000000000000088,
0x0000000080008009, 0x000000008000000a, 0x000000008000808b, 0x800000000000008b, 0x8000000000008089,
0x8000000000008003, 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a,
0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008,
};
void ethash_keccakf1600(uint64_t state[25]) NOEXCEPT
{
    /* The implementation is based on the "simple" implementation by Ronny Van Keer. */
int round;
uint64_t Aba, Abe, Abi, Abo, Abu;
uint64_t Aga, Age, Agi, Ago, Agu;
uint64_t Aka, Ake, Aki, Ako, Aku;
uint64_t Ama, Ame, Ami, Amo, Amu;
uint64_t Asa, Ase, Asi, Aso, Asu;
uint64_t Eba, Ebe, Ebi, Ebo, Ebu;
uint64_t Ega, Ege, Egi, Ego, Egu;
uint64_t Eka, Eke, Eki, Eko, Eku;
uint64_t Ema, Eme, Emi, Emo, Emu;
uint64_t Esa, Ese, Esi, Eso, Esu;
uint64_t Ba, Be, Bi, Bo, Bu;
uint64_t Da, De, Di, Do, Du;
Aba = state[0];
Abe = state[1];
Abi = state[2];
Abo = state[3];
Abu = state[4];
Aga = state[5];
Age = state[6];
Agi = state[7];
Ago = state[8];
Agu = state[9];
Aka = state[10];
Ake = state[11];
Aki = state[12];
Ako = state[13];
Aku = state[14];
Ama = state[15];
Ame = state[16];
Ami = state[17];
Amo = state[18];
Amu = state[19];
Asa = state[20];
Ase = state[21];
Asi = state[22];
Aso = state[23];
Asu = state[24];
for (round = 0; round < 24; round += 2) {
/* Round (round + 0): Axx -> Exx */
Ba = Aba ^ Aga ^ Aka ^ Ama ^ Asa;
Be = Abe ^ Age ^ Ake ^ Ame ^ Ase;
Bi = Abi ^ Agi ^ Aki ^ Ami ^ Asi;
Bo = Abo ^ Ago ^ Ako ^ Amo ^ Aso;
Bu = Abu ^ Agu ^ Aku ^ Amu ^ Asu;
Da = Bu ^ rol(Be, 1);
De = Ba ^ rol(Bi, 1);
Di = Be ^ rol(Bo, 1);
Do = Bi ^ rol(Bu, 1);
Du = Bo ^ rol(Ba, 1);
Ba = Aba ^ Da;
Be = rol(Age ^ De, 44);
Bi = rol(Aki ^ Di, 43);
Bo = rol(Amo ^ Do, 21);
Bu = rol(Asu ^ Du, 14);
Eba = Ba ^ (~Be & Bi) ^ round_constants[round];
Ebe = Be ^ (~Bi & Bo);
Ebi = Bi ^ (~Bo & Bu);
Ebo = Bo ^ (~Bu & Ba);
Ebu = Bu ^ (~Ba & Be);
Ba = rol(Abo ^ Do, 28);
Be = rol(Agu ^ Du, 20);
Bi = rol(Aka ^ Da, 3);
Bo = rol(Ame ^ De, 45);
Bu = rol(Asi ^ Di, 61);
Ega = Ba ^ (~Be & Bi);
Ege = Be ^ (~Bi & Bo);
Egi = Bi ^ (~Bo & Bu);
Ego = Bo ^ (~Bu & Ba);
Egu = Bu ^ (~Ba & Be);
Ba = rol(Abe ^ De, 1);
Be = rol(Agi ^ Di, 6);
Bi = rol(Ako ^ Do, 25);
Bo = rol(Amu ^ Du, 8);
Bu = rol(Asa ^ Da, 18);
Eka = Ba ^ (~Be & Bi);
Eke = Be ^ (~Bi & Bo);
Eki = Bi ^ (~Bo & Bu);
Eko = Bo ^ (~Bu & Ba);
Eku = Bu ^ (~Ba & Be);
Ba = rol(Abu ^ Du, 27);
Be = rol(Aga ^ Da, 36);
Bi = rol(Ake ^ De, 10);
Bo = rol(Ami ^ Di, 15);
Bu = rol(Aso ^ Do, 56);
Ema = Ba ^ (~Be & Bi);
Eme = Be ^ (~Bi & Bo);
Emi = Bi ^ (~Bo & Bu);
Emo = Bo ^ (~Bu & Ba);
Emu = Bu ^ (~Ba & Be);
Ba = rol(Abi ^ Di, 62);
Be = rol(Ago ^ Do, 55);
Bi = rol(Aku ^ Du, 39);
Bo = rol(Ama ^ Da, 41);
Bu = rol(Ase ^ De, 2);
Esa = Ba ^ (~Be & Bi);
Ese = Be ^ (~Bi & Bo);
Esi = Bi ^ (~Bo & Bu);
Eso = Bo ^ (~Bu & Ba);
Esu = Bu ^ (~Ba & Be);
/* Round (round + 1): Exx -> Axx */
Ba = Eba ^ Ega ^ Eka ^ Ema ^ Esa;
Be = Ebe ^ Ege ^ Eke ^ Eme ^ Ese;
Bi = Ebi ^ Egi ^ Eki ^ Emi ^ Esi;
Bo = Ebo ^ Ego ^ Eko ^ Emo ^ Eso;
Bu = Ebu ^ Egu ^ Eku ^ Emu ^ Esu;
Da = Bu ^ rol(Be, 1);
De = Ba ^ rol(Bi, 1);
Di = Be ^ rol(Bo, 1);
Do = Bi ^ rol(Bu, 1);
Du = Bo ^ rol(Ba, 1);
Ba = Eba ^ Da;
Be = rol(Ege ^ De, 44);
Bi = rol(Eki ^ Di, 43);
Bo = rol(Emo ^ Do, 21);
Bu = rol(Esu ^ Du, 14);
Aba = Ba ^ (~Be & Bi) ^ round_constants[round + 1];
Abe = Be ^ (~Bi & Bo);
Abi = Bi ^ (~Bo & Bu);
Abo = Bo ^ (~Bu & Ba);
Abu = Bu ^ (~Ba & Be);
Ba = rol(Ebo ^ Do, 28);
Be = rol(Egu ^ Du, 20);
Bi = rol(Eka ^ Da, 3);
Bo = rol(Eme ^ De, 45);
Bu = rol(Esi ^ Di, 61);
Aga = Ba ^ (~Be & Bi);
Age = Be ^ (~Bi & Bo);
Agi = Bi ^ (~Bo & Bu);
Ago = Bo ^ (~Bu & Ba);
Agu = Bu ^ (~Ba & Be);
Ba = rol(Ebe ^ De, 1);
Be = rol(Egi ^ Di, 6);
Bi = rol(Eko ^ Do, 25);
Bo = rol(Emu ^ Du, 8);
Bu = rol(Esa ^ Da, 18);
Aka = Ba ^ (~Be & Bi);
Ake = Be ^ (~Bi & Bo);
Aki = Bi ^ (~Bo & Bu);
Ako = Bo ^ (~Bu & Ba);
Aku = Bu ^ (~Ba & Be);
Ba = rol(Ebu ^ Du, 27);
Be = rol(Ega ^ Da, 36);
Bi = rol(Eke ^ De, 10);
Bo = rol(Emi ^ Di, 15);
Bu = rol(Eso ^ Do, 56);
Ama = Ba ^ (~Be & Bi);
Ame = Be ^ (~Bi & Bo);
Ami = Bi ^ (~Bo & Bu);
Amo = Bo ^ (~Bu & Ba);
Amu = Bu ^ (~Ba & Be);
Ba = rol(Ebi ^ Di, 62);
Be = rol(Ego ^ Do, 55);
Bi = rol(Eku ^ Du, 39);
Bo = rol(Ema ^ Da, 41);
Bu = rol(Ese ^ De, 2);
Asa = Ba ^ (~Be & Bi);
Ase = Be ^ (~Bi & Bo);
Asi = Bi ^ (~Bo & Bu);
Aso = Bo ^ (~Bu & Ba);
Asu = Bu ^ (~Ba & Be);
}
state[0] = Aba;
state[1] = Abe;
state[2] = Abi;
state[3] = Abo;
state[4] = Abu;
state[5] = Aga;
state[6] = Age;
state[7] = Agi;
state[8] = Ago;
state[9] = Agu;
state[10] = Aka;
state[11] = Ake;
state[12] = Aki;
state[13] = Ako;
state[14] = Aku;
state[15] = Ama;
state[16] = Ame;
state[17] = Ami;
state[18] = Amo;
state[19] = Amu;
state[20] = Asa;
state[21] = Ase;
state[22] = Asi;
state[23] = Aso;
state[24] = Asu;
}
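/*
 * Reference (rolled) form of the permutation above, included as a hedged,
 * illustrative cross-check: the unrolled code computes the standard
 * Keccak-f[1600] round function (theta, rho, pi, chi, iota), two rounds per
 * loop iteration so the E-state never has to be copied back into the A-state.
 * This sketch assumes the same `rol` helper and 24-entry `round_constants`
 * table defined earlier in this file.
 */
static const int keccak_rho_offsets[24] = { 1,  3,  6,  10, 15, 21, 28, 36, 45, 55, 2,  14,
                                            27, 41, 56, 8,  25, 43, 62, 18, 39, 61, 20, 44 };
static const int keccak_pi_lanes[24] = { 10, 7,  11, 17, 18, 3, 5,  16, 8,  21, 24, 4,
                                         15, 23, 19, 13, 12, 2, 20, 14, 22, 9,  6,  1 };

static void keccakf1600_reference(uint64_t s[25])
{
    for (int round = 0; round < 24; ++round) {
        uint64_t b[5];
        uint64_t t;
        /* theta: xor each lane with the parities of two neighbouring columns */
        for (int x = 0; x < 5; ++x) {
            b[x] = s[x] ^ s[x + 5] ^ s[x + 10] ^ s[x + 15] ^ s[x + 20];
        }
        for (int x = 0; x < 5; ++x) {
            t = b[(x + 4) % 5] ^ rol(b[(x + 1) % 5], 1);
            for (int y = 0; y < 25; y += 5) {
                s[x + y] ^= t;
            }
        }
        /* rho + pi: rotate each lane, then permute lane positions */
        t = s[1];
        for (int i = 0; i < 24; ++i) {
            const uint64_t next = s[keccak_pi_lanes[i]];
            s[keccak_pi_lanes[i]] = rol(t, keccak_rho_offsets[i]);
            t = next;
        }
        /* chi: the same (B ^ (~B & B)) row mixing as the unrolled code */
        for (int y = 0; y < 25; y += 5) {
            for (int x = 0; x < 5; ++x) {
                b[x] = s[y + x];
            }
            for (int x = 0; x < 5; ++x) {
                s[y + x] = b[x] ^ (~b[(x + 1) % 5] & b[(x + 2) % 5]);
            }
        }
        /* iota: inject the round constant into lane (0, 0) */
        s[0] ^= round_constants[round];
    }
}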

View File

@@ -0,0 +1,26 @@
#pragma once
#include "../bn254/fq.hpp"
#include "../bn254/fq12.hpp"
#include "../bn254/fq2.hpp"
#include "../bn254/fr.hpp"
#include "../bn254/g1.hpp"
#include "../bn254/g2.hpp"
namespace bb::curve {
class BN254 {
public:
using ScalarField = bb::fr;
using BaseField = bb::fq;
using Group = typename bb::g1;
using Element = typename Group::element;
using AffineElement = typename Group::affine_element;
using G2AffineElement = typename bb::g2::affine_element;
using G2BaseField = typename bb::fq2;
using TargetField = bb::fq12;
// TODO(#673): This flag is temporary. It is needed in the verifier classes (GeminiVerifier, etc.) while these
// classes are instantiated with "native" curve types. Eventually, the verifier classes will be instantiated only
// with stdlib types, and "native" verification will be achieved via a simulated builder.
static constexpr bool is_stdlib_type = false;
};
} // namespace bb::curve
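// Hedged usage sketch (not part of the class above): the alias bundle lets
// protocol code be written once, generically over a `Curve` type. The helper
// below is illustrative only and assumes the conventional element operations
// (construction from AffineElement, operator* with ScalarField, operator+).
namespace bb::curve {
template <typename Curve>
typename Curve::Element linear_combination(const typename Curve::AffineElement& P,
                                           const typename Curve::AffineElement& Q,
                                           const typename Curve::ScalarField& a,
                                           const typename Curve::ScalarField& b)
{
    typename Curve::Element lhs(P);
    typename Curve::Element rhs(Q);
    return lhs * a + rhs * b; // a*P + b*Q, e.g. linear_combination<BN254>(p, q, a, b)
}
} // namespace bb::curve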

View File

@@ -0,0 +1,115 @@
#pragma once
#include <cstdint>
#include <iomanip>
#include "../../fields/field.hpp"
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
namespace bb {
class Bn254FqParams {
public:
static constexpr uint64_t modulus_0 = 0x3C208C16D87CFD47UL;
static constexpr uint64_t modulus_1 = 0x97816a916871ca8dUL;
static constexpr uint64_t modulus_2 = 0xb85045b68181585dUL;
static constexpr uint64_t modulus_3 = 0x30644e72e131a029UL;
static constexpr uint64_t r_squared_0 = 0xF32CFC5B538AFA89UL;
static constexpr uint64_t r_squared_1 = 0xB5E71911D44501FBUL;
static constexpr uint64_t r_squared_2 = 0x47AB1EFF0A417FF6UL;
static constexpr uint64_t r_squared_3 = 0x06D89F71CAB8351FUL;
static constexpr uint64_t cube_root_0 = 0x71930c11d782e155UL;
static constexpr uint64_t cube_root_1 = 0xa6bb947cffbe3323UL;
static constexpr uint64_t cube_root_2 = 0xaa303344d4741444UL;
static constexpr uint64_t cube_root_3 = 0x2c3b3f0d26594943UL;
static constexpr uint64_t modulus_wasm_0 = 0x187cfd47;
static constexpr uint64_t modulus_wasm_1 = 0x10460b6;
static constexpr uint64_t modulus_wasm_2 = 0x1c72a34f;
static constexpr uint64_t modulus_wasm_3 = 0x2d522d0;
static constexpr uint64_t modulus_wasm_4 = 0x1585d978;
static constexpr uint64_t modulus_wasm_5 = 0x2db40c0;
static constexpr uint64_t modulus_wasm_6 = 0xa6e141;
static constexpr uint64_t modulus_wasm_7 = 0xe5c2634;
static constexpr uint64_t modulus_wasm_8 = 0x30644e;
static constexpr uint64_t r_squared_wasm_0 = 0xe1a2a074659bac10UL;
static constexpr uint64_t r_squared_wasm_1 = 0x639855865406005aUL;
static constexpr uint64_t r_squared_wasm_2 = 0xff54c5802d3e2632UL;
static constexpr uint64_t r_squared_wasm_3 = 0x2a11a68c34ea65a6UL;
static constexpr uint64_t cube_root_wasm_0 = 0x62b1a3a46a337995UL;
static constexpr uint64_t cube_root_wasm_1 = 0xadc97d2722e2726eUL;
static constexpr uint64_t cube_root_wasm_2 = 0x64ee82ede2db85faUL;
static constexpr uint64_t cube_root_wasm_3 = 0x0c0afea1488a03bbUL;
static constexpr uint64_t primitive_root_0 = 0UL;
static constexpr uint64_t primitive_root_1 = 0UL;
static constexpr uint64_t primitive_root_2 = 0UL;
static constexpr uint64_t primitive_root_3 = 0UL;
static constexpr uint64_t primitive_root_wasm_0 = 0x0000000000000000UL;
static constexpr uint64_t primitive_root_wasm_1 = 0x0000000000000000UL;
static constexpr uint64_t primitive_root_wasm_2 = 0x0000000000000000UL;
static constexpr uint64_t primitive_root_wasm_3 = 0x0000000000000000UL;
static constexpr uint64_t endo_g1_lo = 0x7a7bd9d4391eb18d;
static constexpr uint64_t endo_g1_mid = 0x4ccef014a773d2cfUL;
static constexpr uint64_t endo_g1_hi = 0x0000000000000002UL;
static constexpr uint64_t endo_g2_lo = 0xd91d232ec7e0b3d2UL;
static constexpr uint64_t endo_g2_mid = 0x0000000000000002UL;
static constexpr uint64_t endo_minus_b1_lo = 0x8211bbeb7d4f1129UL;
static constexpr uint64_t endo_minus_b1_mid = 0x6f4d8248eeb859fcUL;
static constexpr uint64_t endo_b2_lo = 0x89d3256894d213e2UL;
static constexpr uint64_t endo_b2_mid = 0UL;
static constexpr uint64_t r_inv = 0x87d20782e4866389UL;
static constexpr uint64_t coset_generators_0[8]{
0x7a17caa950ad28d7ULL, 0x4d750e37163c3674ULL, 0x20d251c4dbcb4411ULL, 0xf42f9552a15a51aeULL,
0x4f4bc0b2b5ef64bdULL, 0x22a904407b7e725aULL, 0xf60647ce410d7ff7ULL, 0xc9638b5c069c8d94ULL,
};
static constexpr uint64_t coset_generators_1[8]{
0x1f6ac17ae15521b9ULL, 0x29e3aca3d71c2cf7ULL, 0x345c97cccce33835ULL, 0x3ed582f5c2aa4372ULL,
0x1a4b98fbe78db996ULL, 0x24c48424dd54c4d4ULL, 0x2f3d6f4dd31bd011ULL, 0x39b65a76c8e2db4fULL,
};
static constexpr uint64_t coset_generators_2[8]{
0x334bea4e696bd284ULL, 0x99ba8dbde1e518b0ULL, 0x29312d5a5e5edcULL, 0x6697d49cd2d7a508ULL,
0x5c65ec9f484e3a79ULL, 0xc2d4900ec0c780a5ULL, 0x2943337e3940c6d1ULL, 0x8fb1d6edb1ba0cfdULL,
};
static constexpr uint64_t coset_generators_3[8]{
0x2a1f6744ce179d8eULL, 0x3829df06681f7cbdULL, 0x463456c802275bedULL, 0x543ece899c2f3b1cULL,
0x180a96573d3d9f8ULL, 0xf8b21270ddbb927ULL, 0x1d9598e8a7e39857ULL, 0x2ba010aa41eb7786ULL,
};
static constexpr uint64_t coset_generators_wasm_0[8] = { 0xeb8a8ec140766463ULL, 0xfded87957d76333dULL,
0x4c710c8092f2ff5eULL, 0x9af4916ba86fcb7fULL,
0xe9781656bdec97a0ULL, 0xfbdb0f2afaec667aULL,
0x4a5e94161069329bULL, 0x98e2190125e5febcULL };
static constexpr uint64_t coset_generators_wasm_1[8] = { 0xf2b1f20626a3da49ULL, 0x56c12d76cb13587fULL,
0x5251d378d7f4a143ULL, 0x4de2797ae4d5ea06ULL,
0x49731f7cf1b732c9ULL, 0xad825aed9626b0ffULL,
0xa91300efa307f9c3ULL, 0xa4a3a6f1afe94286ULL };
static constexpr uint64_t coset_generators_wasm_2[8] = { 0xf905ef8d84d5fea4ULL, 0x93b7a45b84f1507eULL,
0xe6b99ee0068dfab5ULL, 0x39bb9964882aa4ecULL,
0x8cbd93e909c74f23ULL, 0x276f48b709e2a0fcULL,
0x7a71433b8b7f4b33ULL, 0xcd733dc00d1bf56aULL };
static constexpr uint64_t coset_generators_wasm_3[8] = { 0x2958a27c02b7cd5fULL, 0x06bc8a3277c371abULL,
0x1484c05bce00b620ULL, 0x224cf685243dfa96ULL,
0x30152cae7a7b3f0bULL, 0x0d791464ef86e357ULL,
0x1b414a8e45c427ccULL, 0x290980b79c016c41ULL };
// used in msgpack schema serialization
static constexpr char schema_name[] = "fq";
static constexpr bool has_high_2adicity = false;
// The modulus is larger than BN254 scalar field modulus, so it maps to two BN254 scalars
static constexpr size_t NUM_BN254_SCALARS = 2;
};
using fq = field<Bn254FqParams>;
} // namespace bb
// NOLINTEND(cppcoreguidelines-avoid-c-arrays)
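// Hedged sketch: r_inv above is the Montgomery constant -modulus^{-1} mod 2^64,
// used to cancel the low limb during Montgomery reduction. It can be recovered
// from modulus_0 alone via Newton-Hensel iteration (each step doubles the
// number of correct low bits) and checked at compile time:
namespace bb::detail_sketch {
constexpr uint64_t compute_montgomery_r_inv(uint64_t p0)
{
    uint64_t inv = 1;
    for (int i = 0; i < 6; ++i) { // 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 correct bits
        inv *= 2 - p0 * inv;
    }
    return 0 - inv; // negate modulo 2^64
}
static_assert(compute_montgomery_r_inv(Bn254FqParams::modulus_0) == Bn254FqParams::r_inv,
              "r_inv should equal -q^{-1} mod 2^64");
} // namespace bb::detail_sketch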

View File

@@ -0,0 +1,62 @@
#pragma once
#include "../../fields/field2.hpp"
#include "./fq.hpp"
namespace bb {
struct Bn254Fq2Params {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr fq twist_coeff_b_0{
0x3bf938e377b802a8UL, 0x020b1b273633535dUL, 0x26b7edf049755260UL, 0x2514c6324384a86dUL
};
static constexpr fq twist_coeff_b_1{
0x38e7ecccd1dcff67UL, 0x65f0b37d93ce0d3eUL, 0xd749d0dd22ac00aaUL, 0x0141b9ce4a688d4dUL
};
static constexpr fq twist_mul_by_q_x_0{
0xb5773b104563ab30UL, 0x347f91c8a9aa6454UL, 0x7a007127242e0991UL, 0x1956bcd8118214ecUL
};
static constexpr fq twist_mul_by_q_x_1{
0x6e849f1ea0aa4757UL, 0xaa1c7b6d89f89141UL, 0xb6e713cdfae0ca3aUL, 0x26694fbb4e82ebc3UL
};
static constexpr fq twist_mul_by_q_y_0{
0xe4bbdd0c2936b629UL, 0xbb30f162e133bacbUL, 0x31a9d1b6f9645366UL, 0x253570bea500f8ddUL
};
static constexpr fq twist_mul_by_q_y_1{
0xa1d77ce45ffe77c7UL, 0x07affd117826d1dbUL, 0x6d16bd27bb7edc6bUL, 0x2c87200285defeccUL
};
static constexpr fq twist_cube_root_0{
0x505ecc6f0dff1ac2UL, 0x2071416db35ec465UL, 0xf2b53469fa43ea78UL, 0x18545552044c99aaUL
};
static constexpr fq twist_cube_root_1{
0xad607f911cfe17a8UL, 0xb6bb78aa154154c4UL, 0xb53dd351736b20dbUL, 0x1d8ed57c5cc33d41UL
};
#else
static constexpr fq twist_coeff_b_0{
0xdc19fa4aab489658UL, 0xd416744fbbf6e69UL, 0x8f7734ed0a8a033aUL, 0x19316b8353ee09bbUL
};
static constexpr fq twist_coeff_b_1{
0x1cfd999a3b9fece0UL, 0xbe166fb279c1a7c7UL, 0xe93a1ba45580154cUL, 0x283739c94d11a9baUL
};
static constexpr fq twist_mul_by_q_x_0{
0xecdea09b24a59190UL, 0x17db8ffeae2fe1c2UL, 0xbb09c97c6dabac4dUL, 0x2492b3d41d289af3UL
};
static constexpr fq twist_mul_by_q_x_1{
0xf1663598f1142ef1UL, 0x77ec057e0bf56062UL, 0xdd0baaecb677a631UL, 0x135e4e31d284d463UL
};
static constexpr fq twist_mul_by_q_y_0{
0xf46e7f60db1f0678UL, 0x31fc2eba5bcc5c3eUL, 0xedb3adc3086a2411UL, 0x1d46bd0f837817bcUL
};
static constexpr fq twist_mul_by_q_y_1{
0x6b3fbdf579a647d5UL, 0xcc568fb62ff64974UL, 0xc1bfbf4ac4348ac6UL, 0x15871d4d3940b4d3UL
};
static constexpr fq twist_cube_root_0{
0x49d0cc74381383d0UL, 0x9611849fe4bbe3d6UL, 0xd1a231d73067c92aUL, 0x445c312767932c2UL
};
static constexpr fq twist_cube_root_1{
0x35a58c718e7c28bbUL, 0x98d42c77e7b8901aUL, 0xf9c53da2d0ca8c84UL, 0x1a68dd04e1b8c51dUL
};
#endif
};
using fq2 = field2<fq, Bn254Fq2Params>;
} // namespace bb
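// Hedged sketch of the multiplication rule the field2 template implements for
// this parameterisation: BN254's Fq2 is Fq[u]/(u^2 + 1), so for a = a0 + a1*u
// and b = b0 + b1*u the product is (a0*b0 - a1*b1) + (a0*b1 + a1*b0)*u, with
// the cross term recovered Karatsuba-style from (a0 + a1)*(b0 + b1):
namespace bb::detail_sketch {
struct fq2_pair {
    fq c0;
    fq c1;
};
inline fq2_pair fq2_mul(const fq2_pair& a, const fq2_pair& b)
{
    const fq t0 = a.c0 * b.c0;                   // a0*b0
    const fq t1 = a.c1 * b.c1;                   // a1*b1
    const fq t2 = (a.c0 + a.c1) * (b.c0 + b.c1); // (a0+a1)*(b0+b1)
    return { t0 - t1,                            // real part, using u^2 = -1
             t2 - t0 - t1 };                     // imaginary part: a0*b1 + a1*b0
}
} // namespace bb::detail_sketch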

View File

@@ -0,0 +1,121 @@
#pragma once
#include <cstdint>
#include <iomanip>
#include <ostream>
#include "../../fields/field.hpp"
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
namespace bb {
class Bn254FrParams {
public:
// Note: limbs here are combined as concat(_3, _2, _1, _0)
// E.g. this modulus forms the value:
// 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001
// = 21888242871839275222246405745257275088548364400416034343698204186575808495617
static constexpr uint64_t modulus_0 = 0x43E1F593F0000001UL;
static constexpr uint64_t modulus_1 = 0x2833E84879B97091UL;
static constexpr uint64_t modulus_2 = 0xB85045B68181585DUL;
static constexpr uint64_t modulus_3 = 0x30644E72E131A029UL;
static constexpr uint64_t r_squared_0 = 0x1BB8E645AE216DA7UL;
static constexpr uint64_t r_squared_1 = 0x53FE3AB1E35C59E3UL;
static constexpr uint64_t r_squared_2 = 0x8C49833D53BB8085UL;
static constexpr uint64_t r_squared_3 = 0x216D0B17F4E44A5UL;
static constexpr uint64_t cube_root_0 = 0x93e7cede4a0329b3UL;
static constexpr uint64_t cube_root_1 = 0x7d4fdca77a96c167UL;
static constexpr uint64_t cube_root_2 = 0x8be4ba08b19a750aUL;
static constexpr uint64_t cube_root_3 = 0x1cbd5653a5661c25UL;
static constexpr uint64_t primitive_root_0 = 0x636e735580d13d9cUL;
static constexpr uint64_t primitive_root_1 = 0xa22bf3742445ffd6UL;
static constexpr uint64_t primitive_root_2 = 0x56452ac01eb203d8UL;
static constexpr uint64_t primitive_root_3 = 0x1860ef942963f9e7UL;
static constexpr uint64_t endo_g1_lo = 0x7a7bd9d4391eb18dUL;
static constexpr uint64_t endo_g1_mid = 0x4ccef014a773d2cfUL;
static constexpr uint64_t endo_g1_hi = 0x0000000000000002UL;
static constexpr uint64_t endo_g2_lo = 0xd91d232ec7e0b3d7UL;
static constexpr uint64_t endo_g2_mid = 0x0000000000000002UL;
static constexpr uint64_t endo_minus_b1_lo = 0x8211bbeb7d4f1128UL;
static constexpr uint64_t endo_minus_b1_mid = 0x6f4d8248eeb859fcUL;
static constexpr uint64_t endo_b2_lo = 0x89d3256894d213e3UL;
static constexpr uint64_t endo_b2_mid = 0UL;
static constexpr uint64_t r_inv = 0xc2e1f593efffffffUL;
static constexpr uint64_t coset_generators_0[8]{
0x5eef048d8fffffe7ULL, 0xb8538a9dfffffe2ULL, 0x3057819e4fffffdbULL, 0xdcedb5ba9fffffd6ULL,
0x8983e9d6efffffd1ULL, 0x361a1df33fffffccULL, 0xe2b0520f8fffffc7ULL, 0x8f46862bdfffffc2ULL,
};
static constexpr uint64_t coset_generators_1[8]{
0x12ee50ec1ce401d0ULL, 0x49eac781bc44cefaULL, 0x307f6d866832bb01ULL, 0x677be41c0793882aULL,
0x9e785ab1a6f45554ULL, 0xd574d1474655227eULL, 0xc7147dce5b5efa7ULL, 0x436dbe728516bcd1ULL,
};
static constexpr uint64_t coset_generators_2[8]{
0x29312d5a5e5ee7ULL, 0x6697d49cd2d7a515ULL, 0x5c65ec9f484e3a89ULL, 0xc2d4900ec0c780b7ULL,
0x2943337e3940c6e5ULL, 0x8fb1d6edb1ba0d13ULL, 0xf6207a5d2a335342ULL, 0x5c8f1dcca2ac9970ULL,
};
static constexpr uint64_t coset_generators_3[8]{
0x463456c802275bedULL, 0x543ece899c2f3b1cULL, 0x180a96573d3d9f8ULL, 0xf8b21270ddbb927ULL,
0x1d9598e8a7e39857ULL, 0x2ba010aa41eb7786ULL, 0x39aa886bdbf356b5ULL, 0x47b5002d75fb35e5ULL,
};
static constexpr uint64_t modulus_wasm_0 = 0x10000001;
static constexpr uint64_t modulus_wasm_1 = 0x1f0fac9f;
static constexpr uint64_t modulus_wasm_2 = 0xe5c2450;
static constexpr uint64_t modulus_wasm_3 = 0x7d090f3;
static constexpr uint64_t modulus_wasm_4 = 0x1585d283;
static constexpr uint64_t modulus_wasm_5 = 0x2db40c0;
static constexpr uint64_t modulus_wasm_6 = 0xa6e141;
static constexpr uint64_t modulus_wasm_7 = 0xe5c2634;
static constexpr uint64_t modulus_wasm_8 = 0x30644e;
static constexpr uint64_t r_squared_wasm_0 = 0x38c2e14b45b69bd4UL;
static constexpr uint64_t r_squared_wasm_1 = 0x0ffedb1885883377UL;
static constexpr uint64_t r_squared_wasm_2 = 0x7840f9f0abc6e54dUL;
static constexpr uint64_t r_squared_wasm_3 = 0x0a054a3e848b0f05UL;
static constexpr uint64_t cube_root_wasm_0 = 0x7334a1ce7065364dUL;
static constexpr uint64_t cube_root_wasm_1 = 0xae21578e4a14d22aUL;
static constexpr uint64_t cube_root_wasm_2 = 0xcea2148a96b51265UL;
static constexpr uint64_t cube_root_wasm_3 = 0x0038f7edf614a198UL;
static constexpr uint64_t primitive_root_wasm_0 = 0x2faf11711a27b370UL;
static constexpr uint64_t primitive_root_wasm_1 = 0xc23fe9fced28f1b8UL;
static constexpr uint64_t primitive_root_wasm_2 = 0x43a0fc9bbe2af541UL;
static constexpr uint64_t primitive_root_wasm_3 = 0x05d90b5719653a4fUL;
static constexpr uint64_t coset_generators_wasm_0[8] = { 0xab46711cdffffcb2ULL, 0xdb1b52736ffffc09ULL,
0x0af033c9fffffb60ULL, 0xf6e31f8c9ffffab6ULL,
0x26b800e32ffffa0dULL, 0x568ce239bffff964ULL,
0x427fcdfc5ffff8baULL, 0x7254af52effff811ULL };
static constexpr uint64_t coset_generators_wasm_1[8] = { 0x2476607dbd2dfff1ULL, 0x9a3208a561c2b00bULL,
0x0fedb0cd06576026ULL, 0x5d7570ac31329faeULL,
0xd33118d3d5c74fc9ULL, 0x48ecc0fb7a5bffe3ULL,
0x967480daa5373f6cULL, 0x0c30290249cbef86ULL };
static constexpr uint64_t coset_generators_wasm_2[8] = { 0xe6b99ee0068dfc25ULL, 0x39bb9964882aa6a5ULL,
0x8cbd93e909c75126ULL, 0x276f48b709e2a349ULL,
0x7a71433b8b7f4dc9ULL, 0xcd733dc00d1bf84aULL,
0x6824f28e0d374a6dULL, 0xbb26ed128ed3f4eeULL };
static constexpr uint64_t coset_generators_wasm_3[8] = { 0x1484c05bce00b620ULL, 0x224cf685243dfa96ULL,
0x30152cae7a7b3f0bULL, 0x0d791464ef86e357ULL,
0x1b414a8e45c427ccULL, 0x290980b79c016c41ULL,
0x066d686e110d108dULL, 0x14359e97674a5502ULL };
// used in msgpack schema serialization
static constexpr char schema_name[] = "fr";
static constexpr bool has_high_2adicity = true;
// This is a BN254 scalar, so it represents one BN254 scalar
static constexpr size_t NUM_BN254_SCALARS = 1;
};
using fr = field<Bn254FrParams>;
} // namespace bb
// NOLINTEND(cppcoreguidelines-avoid-c-arrays)
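// Hedged sketch: the *_wasm constants above appear to be the same values
// repacked from four 64-bit limbs into nine 29-bit limbs (the representation
// the wasm build uses, lacking a 64x64 -> 128-bit multiplier). The repacking
// can be checked bit by bit at compile time:
namespace bb::detail_sketch {
constexpr uint64_t wasm_limb(const uint64_t (&limbs64)[4], int i)
{
    uint64_t out = 0;
    for (int bit = 29 * i; bit < 29 * (i + 1) && bit < 256; ++bit) {
        out |= ((limbs64[bit / 64] >> (bit % 64)) & 1UL) << (bit - 29 * i);
    }
    return out;
}
constexpr uint64_t fr_modulus_limbs[4] = {
    Bn254FrParams::modulus_0, Bn254FrParams::modulus_1,
    Bn254FrParams::modulus_2, Bn254FrParams::modulus_3
};
static_assert(wasm_limb(fr_modulus_limbs, 0) == Bn254FrParams::modulus_wasm_0, "");
static_assert(wasm_limb(fr_modulus_limbs, 1) == Bn254FrParams::modulus_wasm_1, "");
static_assert(wasm_limb(fr_modulus_limbs, 8) == Bn254FrParams::modulus_wasm_8, "");
} // namespace bb::detail_sketch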

View File

@@ -0,0 +1,30 @@
#pragma once
#include "../../groups/group.hpp"
#include "./fq.hpp"
#include "./fr.hpp"
namespace bb {
struct Bn254G1Params {
static constexpr bool USE_ENDOMORPHISM = true;
static constexpr bool can_hash_to_curve = true;
static constexpr bool small_elements = true;
static constexpr bool has_a = false;
static constexpr fq one_x = fq::one();
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr fq one_y{ 0xa6ba871b8b1e1b3aUL, 0x14f1d651eb8e167bUL, 0xccdd46def0f28c58UL, 0x1c14ef83340fbe5eUL };
#else
static constexpr fq one_y{ 0x9d0709d62af99842UL, 0xf7214c0419c29186UL, 0xa603f5090339546dUL, 0x1b906c52ac7a88eaUL };
#endif
static constexpr fq a{ 0UL, 0UL, 0UL, 0UL };
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr fq b{ 0x7a17caa950ad28d7UL, 0x1f6ac17ae15521b9UL, 0x334bea4e696bd284UL, 0x2a1f6744ce179d8eUL };
#else
static constexpr fq b{ 0xeb8a8ec140766463UL, 0xf2b1f20626a3da49UL, 0xf905ef8d84d5fea4UL, 0x2958a27c02b7cd5fUL };
#endif
};
using g1 = group<fq, fr, Bn254G1Params>;
} // namespace bb
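// Hedged sanity-check sketch: the stored generator (one_x, one_y) should lie
// on y^2 = x^3 + b (the curve has a = 0). The two #if branches above hold the
// same point; the raw limbs differ because coordinates are stored in
// Montgomery form and the wasm build appears to use a different Montgomery
// radix. Field operators absorb that, so one check covers both builds (the
// analogous fq2 check covers the G2 parameters in the next file):
namespace bb::detail_sketch {
inline bool g1_generator_on_curve()
{
    const fq x = Bn254G1Params::one_x;
    const fq y = Bn254G1Params::one_y;
    return (y * y) == (x * x * x + Bn254G1Params::b);
}
} // namespace bb::detail_sketch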

View File

@@ -0,0 +1,34 @@
#pragma once
#include "../../groups/group.hpp"
#include "./fq2.hpp"
#include "./fr.hpp"
namespace bb {
struct Bn254G2Params {
static constexpr bool USE_ENDOMORPHISM = false;
static constexpr bool can_hash_to_curve = false;
static constexpr bool small_elements = false;
static constexpr bool has_a = false;
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr fq2 one_x{ { 0x8e83b5d102bc2026, 0xdceb1935497b0172, 0xfbb8264797811adf, 0x19573841af96503b },
{ 0xafb4737da84c6140, 0x6043dd5a5802d8c4, 0x09e950fc52a02f86, 0x14fef0833aea7b6b } };
static constexpr fq2 one_y{ { 0x619dfa9d886be9f6, 0xfe7fd297f59e9b78, 0xff9e1a62231b7dfe, 0x28fd7eebae9e4206 },
{ 0x64095b56c71856ee, 0xdc57f922327d3cbb, 0x55f935be33351076, 0x0da4a0e693fd6482 } };
#else
static constexpr fq2 one_x{
{ 0xe6df8b2cfb43050UL, 0x254c7d92a843857eUL, 0xf2006d8ad80dd622UL, 0x24a22107dfb004e3UL },
{ 0xe8e7528c0b334b65UL, 0x56e941e8b293cf69UL, 0xe1169545c074740bUL, 0x2ac61491edca4b42UL }
};
static constexpr fq2 one_y{
{ 0xdc508d48384e8843UL, 0xd55415a8afd31226UL, 0x834bf204bacb6e00UL, 0x51b9758138c5c79UL },
{ 0x64067e0b46a5f641UL, 0x37726529a3a77875UL, 0x4454445bd915f391UL, 0x10d5ac894edeed3UL }
};
#endif
static constexpr fq2 a = fq2::zero();
static constexpr fq2 b = fq2::twist_coeff_b();
};
using g2 = group<fq2, fr, Bn254G2Params>;
} // namespace bb

View File

@@ -0,0 +1,990 @@
#pragma once
// clang-format off
/*
* Clear all flags via xorq opcode
**/
#define CLEAR_FLAGS(empty_reg) \
"xorq " empty_reg ", " empty_reg " \n\t"
/**
* Load 4-limb field element, pointed to by a, into
* registers (lolo, lohi, hilo, hihi)
**/
#define LOAD_FIELD_ELEMENT(a, lolo, lohi, hilo, hihi) \
"movq 0(" a "), " lolo " \n\t" \
"movq 8(" a "), " lohi " \n\t" \
"movq 16(" a "), " hilo " \n\t" \
"movq 24(" a "), " hihi " \n\t"
/**
* Store 4-limb field element located in
* registers (lolo, lohi, hilo, hihi), into
* memory pointed to by r
**/
#define STORE_FIELD_ELEMENT(r, lolo, lohi, hilo, hihi) \
"movq " lolo ", 0(" r ") \n\t" \
"movq " lohi ", 8(" r ") \n\t" \
"movq " hilo ", 16(" r ") \n\t" \
"movq " hihi ", 24(" r ") \n\t"
#if !defined(__ADX__) || defined(DISABLE_ADX)
/**
* Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
* and add 4-limb field element pointed to by b
**/
#define ADD(b) \
"addq 0(" b "), %%r12 \n\t" \
"adcq 8(" b "), %%r13 \n\t" \
"adcq 16(" b "), %%r14 \n\t" \
"adcq 24(" b "), %%r15 \n\t"
/**
* Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
* and subtract 4-limb field element pointed to by b
**/
#define SUB(b) \
"subq 0(" b "), %%r12 \n\t" \
"sbbq 8(" b "), %%r13 \n\t" \
"sbbq 16(" b "), %%r14 \n\t" \
"sbbq 24(" b "), %%r15 \n\t"
/**
* Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
* add 4-limb field element pointed to by b, and reduce modulo p
**/
#define ADD_REDUCE(b, modulus_0, modulus_1, modulus_2, modulus_3) \
"addq 0(" b "), %%r12 \n\t" \
"adcq 8(" b "), %%r13 \n\t" \
"adcq 16(" b "), %%r14 \n\t" \
"adcq 24(" b "), %%r15 \n\t" \
"movq %%r12, %%r8 \n\t" \
"movq %%r13, %%r9 \n\t" \
"movq %%r14, %%r10 \n\t" \
"movq %%r15, %%r11 \n\t" \
"addq " modulus_0 ", %%r12 \n\t" \
"adcq " modulus_1 ", %%r13 \n\t" \
"adcq " modulus_2 ", %%r14 \n\t" \
"adcq " modulus_3 ", %%r15 \n\t" \
"cmovncq %%r8, %%r12 \n\t" \
"cmovncq %%r9, %%r13 \n\t" \
"cmovncq %%r10, %%r14 \n\t" \
"cmovncq %%r11, %%r15 \n\t"
/**
* Take a 4-limb integer, r, in (%r12, %r13, %r14, %r15)
* and conditionally subtract modulus, if r > p.
**/
#define REDUCE_FIELD_ELEMENT(neg_modulus_0, neg_modulus_1, neg_modulus_2, neg_modulus_3) \
/* Duplicate `r` */ \
"movq %%r12, %%r8 \n\t" \
"movq %%r13, %%r9 \n\t" \
"movq %%r14, %%r10 \n\t" \
"movq %%r15, %%r11 \n\t" \
"addq " neg_modulus_0 ", %%r12 \n\t" /* r'[0] -= modulus.data[0] */ \
"adcq " neg_modulus_1 ", %%r13 \n\t" /* r'[1] -= modulus.data[1] */ \
"adcq " neg_modulus_2 ", %%r14 \n\t" /* r'[2] -= modulus.data[2] */ \
"adcq " neg_modulus_3 ", %%r15 \n\t" /* r'[3] -= modulus.data[3] */ \
\
    /* if r did not need to be reduced (r < p), adding -p does not carry */                                \
    /* so restore r' = r while the carry flag is clear */                                                  \
"cmovncq %%r8, %%r12 \n\t" \
"cmovncq %%r9, %%r13 \n\t" \
"cmovncq %%r10, %%r14 \n\t" \
"cmovncq %%r11, %%r15 \n\t"
/**
* Compute Montgomery squaring of a
* Result is stored in (%%r12, %%r13, %%r14, %%r15), ready to be written out to "r"
**/
#define SQR(a) \
"movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
\
"xorq %%r8, %%r8 \n\t" /* clear flags */ \
/* compute a[0] *a[1], a[0]*a[2], a[0]*a[3], a[1]*a[2], a[1]*a[3], a[2]*a[3] */ \
"mulxq 8(" a "), %%r9, %%r10 \n\t" /* (r[1], r[2]) <- a[0] * a[1] */ \
"mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[1], t[2]) <- a[0] * a[2] */ \
"mulxq 24(" a "), %%r11, %%r12 \n\t" /* (r[3], r[4]) <- a[0] * a[3] */ \
\
\
/* accumulate products into result registers */ \
"addq %%r8, %%r10 \n\t" /* r[2] += t[1] */ \
"adcq %%r15, %%r11 \n\t" /* r[3] += t[2] */ \
"movq 8(" a "), %%rdx \n\t" /* load a[1] into %r%dx */ \
"mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[5], t[6]) <- a[1] * a[2] */ \
"mulxq 24(" a "), %%rdi, %%rcx \n\t" /* (t[3], t[4]) <- a[1] * a[3] */ \
"movq 24(" a "), %%rdx \n\t" /* load a[3] into %%rdx */ \
"mulxq 16(" a "), %%r13, %%r14 \n\t" /* (r[5], r[6]) <- a[3] * a[2] */ \
"adcq %%rdi, %%r12 \n\t" /* r[4] += t[3] */ \
"adcq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"addq %%r8, %%r11 \n\t" /* r[3] += t[5] */ \
"adcq %%r15, %%r12 \n\t" /* r[4] += t[6] */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
\
/* double result registers */ \
"addq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \
"adcq %%r10, %%r10 \n\t" /* r[2] = 2r[2] */ \
"adcq %%r11, %%r11 \n\t" /* r[3] = 2r[3] */ \
"adcq %%r12, %%r12 \n\t" /* r[4] = 2r[4] */ \
"adcq %%r13, %%r13 \n\t" /* r[5] = 2r[5] */ \
"adcq %%r14, %%r14 \n\t" /* r[6] = 2r[6] */ \
\
/* compute a[3]*a[3], a[2]*a[2], a[1]*a[1], a[0]*a[0] */ \
"movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
"mulxq %%rdx, %%r8, %%rcx \n\t" /* (r[0], t[4]) <- a[0] * a[0] */ \
"movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq %%rdx, %%rdx, %%rdi \n\t" /* (t[7], t[8]) <- a[2] * a[2] */ \
/* add squares into result registers */ \
"addq %%rdx, %%r12 \n\t" /* r[4] += t[7] */ \
"adcq %%rdi, %%r13 \n\t" /* r[5] += t[8] */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"addq %%rcx, %%r9 \n\t" /* r[1] += t[4] */ \
"movq 24(" a "), %%rdx \n\t" /* r[2] += flag_c */ \
"mulxq %%rdx, %%rcx, %%r15 \n\t" /* (t[5], r[7]) <- a[3] * a[3] */ \
"movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq %%rdx, %%rdi, %%rdx \n\t" /* (t[3], t[6]) <- a[1] * a[1] */ \
"adcq %%rdi, %%r10 \n\t" /* r[2] += t[3] */ \
"adcq %%rdx, %%r11 \n\t" /* r[3] += t[6] */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
"addq %%rcx, %%r14 \n\t" /* r[6] += t[5] */ \
"adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
\
/* perform modular reduction: r[0] */ \
"movq %%r8, %%rdx \n\t" /* move r8 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \
"mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"addq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \
"adcq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_c */ \
"mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"adcq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_c */ \
"adcq $0, %%r11 \n\t" /* r[4] += flag_c */ \
/* Partial fix "adcq $0, %%r12 \n\t"*/ /* r[4] += flag_c */ \
"addq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \
"mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"mulxq %[modulus_3], %%r8, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcq %%rdi, %%r10 \n\t" /* r[2] += t[0] + flag_c */ \
"adcq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \
"adcq %%rdx, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
"addq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
\
/* perform modular reduction: r[1] */ \
"movq %%r9, %%rdx \n\t" /* move r9 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \
"mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"addq %%rdi, %%r9 \n\t" /* r[1] += t[0] (%r8 now free) */ \
"adcq %%rcx, %%r10 \n\t" /* r[2] += t[1] + flag_c */ \
"mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"adcq %%rcx, %%r11 \n\t" /* r[3] += t[3] + flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
"addq %%rdi, %%r10 \n\t" /* r[2] += t[2] */ \
"mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"mulxq %[modulus_3], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \
"adcq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"addq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
\
/* perform modular reduction: r[2] */ \
"movq %%r10, %%rdx \n\t" /* move r10 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \
"mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"addq %%rdi, %%r10 \n\t" /* r[2] += t[0] (%r8 now free) */ \
"adcq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \
"mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"mulxq %[modulus_3], %%r10, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcq %%rcx, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
"adcq %%r9, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
"adcq %%rdx, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
"adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
"addq %%rdi, %%r11 \n\t" /* r[3] += t[2] */ \
"adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
"adcq %%r10, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
\
/* perform modular reduction: r[3] */ \
"movq %%r11, %%rdx \n\t" /* move r11 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \
"mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"mulxq %[modulus_1], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"addq %%rdi, %%r11 \n\t" /* r[3] += t[0] (%r11 now free) */ \
"adcq %%r8, %%r12 \n\t" /* r[4] += t[2] */ \
"adcq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"mulxq %[modulus_3], %%r10, %%r11 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcq %%r9, %%r14 \n\t" /* r[6] += t[1] + flag_c */ \
"adcq %%r11, %%r15 \n\t" /* r[7] += t[3] + flag_c */ \
"addq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcq %%r8, %%r13 \n\t" /* r[5] += t[0] + flag_c */ \
"adcq %%r10, %%r14 \n\t" /* r[6] += t[2] + flag_c */ \
"adcq $0, %%r15 \n\t" /* r[7] += flag_c */
/**
* Compute Montgomery multiplication of a, b.
* Result is stored in (%%r12, %%r13, %%r14, %%r15), ready to be written out to "r"
**/
#define MUL(a1, a2, a3, a4, b) \
"movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \
"xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \
/* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
"mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
"mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
"mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
"mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \
/* zero flags */ \
\
/* start computing modular reduction */ \
"movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
\
/* start first addition chain */ \
"addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
"adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adcq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
\
/* reduce by r[0] * k */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"addq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \
"adcq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \
"adcq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adcq $0, %%r10 \n\t" /* r[3] += flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
"addq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_c */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"adcq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adcq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
"adcq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
"addq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_i */ \
\
/* modulus = 254 bits, so max(t[3]) = 62 bits */ \
/* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \
/* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \
/* (which is very convenient because we're out of registers!) */ \
/* N.B. the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \
\
/* a[1] * b */ \
"movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
"addq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
"adcq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adcq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
"addq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
\
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \
"adcq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
"addq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
\
/* reduce by r[1] * k */ \
"movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"addq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \
"adcq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adcq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
"adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
"addq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"adcq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
\
/* a[2] * b */ \
"movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
"addq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adcq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
"adcq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
"addq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \
"adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
"adcq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"addq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
\
/* reduce by r[2] * k */ \
"movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"addq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \
"adcq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"addq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
"adcq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
"adcq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
"addq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
\
/* a[3] * b */ \
"movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \
"addq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
\
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \
"adcq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
"adcq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
"adcq $0, %%r15 \n\t" /* r[7] += + flag_c */ \
"addq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
"adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
\
/* reduce by r[3] * k */ \
"movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
"addq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \
"adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \
"adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \
"addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
\
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \
"adcq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
"adcq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
"adcq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
"addq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
"adcq $0, %%r15 \n\t" /* r[7] += flag_c */
/**
* Compute 256-bit multiplication of a, b.
* Result is written directly to the memory pointed to by "r"
**/
#define MUL_256(a, b, r) \
"movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
\
/* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
"mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
"mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
"mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
"mulxq 16(" b "), %%r15, %%rax \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \
/* zero flags */ \
"xorq %%r10, %%r10 \n\t" /* clear r10 register, we use this when we need 0 */ \
\
\
/* start first addition chain */ \
"addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
"adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adcq %%r10, %%rax \n\t" /* r[3] += flag_c */ \
"addq %%rdi, %%rax \n\t" /* r[3] += t[2] + flag_c */ \
\
/* a[1] * b */ \
"movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
"addq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
"adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adcq %%rsi, %%rax \n\t" /* r[3] += t[1] + flag_c */ \
"addq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
\
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
"adcq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
\
/* a[2] * b */ \
"movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
"addq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adcq %%r9, %%rax \n\t" /* r[3] += t[1] + flag_c */ \
"addq %%rdi, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
\
\
/* a[3] * b */ \
"movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
"adcq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
"movq %%r13, 0(" r ") \n\t" \
"movq %%r14, 8(" r ") \n\t" \
"movq %%r15, 16(" r ") \n\t" \
"movq %%rax, 24(" r ") \n\t"
#else // 6047895us
/**
* Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
* and add 4-limb field element pointed to by b
**/
#define ADD(b) \
"adcxq 0(" b "), %%r12 \n\t" \
"adcxq 8(" b "), %%r13 \n\t" \
"adcxq 16(" b "), %%r14 \n\t" \
"adcxq 24(" b "), %%r15 \n\t"
/**
* Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
* and subtract 4-limb field element pointed to by b
**/
#define SUB(b) \
"subq 0(" b "), %%r12 \n\t" \
"sbbq 8(" b "), %%r13 \n\t" \
"sbbq 16(" b "), %%r14 \n\t" \
"sbbq 24(" b "), %%r15 \n\t"
/**
* Take a 4-limb field element, in (%r12, %r13, %r14, %r15),
* add 4-limb field element pointed to by b, and reduce modulo p
**/
#define ADD_REDUCE(b, modulus_0, modulus_1, modulus_2, modulus_3) \
"adcxq 0(" b "), %%r12 \n\t" \
"movq %%r12, %%r8 \n\t" \
"adoxq " modulus_0 ", %%r12 \n\t" \
"adcxq 8(" b "), %%r13 \n\t" \
"movq %%r13, %%r9 \n\t" \
"adoxq " modulus_1 ", %%r13 \n\t" \
"adcxq 16(" b "), %%r14 \n\t" \
"movq %%r14, %%r10 \n\t" \
"adoxq " modulus_2 ", %%r14 \n\t" \
"adcxq 24(" b "), %%r15 \n\t" \
"movq %%r15, %%r11 \n\t" \
"adoxq " modulus_3 ", %%r15 \n\t" \
"cmovnoq %%r8, %%r12 \n\t" \
"cmovnoq %%r9, %%r13 \n\t" \
"cmovnoq %%r10, %%r14 \n\t" \
"cmovnoq %%r11, %%r15 \n\t"
/**
* Take a 4-limb integer, r, in (%r12, %r13, %r14, %r15)
* and conditionally subtract modulus, if r > p.
**/
#define REDUCE_FIELD_ELEMENT(neg_modulus_0, neg_modulus_1, neg_modulus_2, neg_modulus_3) \
/* Duplicate `r` */ \
"movq %%r12, %%r8 \n\t" \
"movq %%r13, %%r9 \n\t" \
"movq %%r14, %%r10 \n\t" \
"movq %%r15, %%r11 \n\t" \
/* Add the negative representation of 'modulus' into `r`. We do this instead */ \
/* of subtracting, because we can use `adoxq`. */ \
/* This opcode only has a dependence on the overflow */ \
/* flag (sub/sbb changes both carry and overflow flags). */ \
/* We can process an `adcxq` and `acoxq` opcode simultaneously. */ \
"adoxq " neg_modulus_0 ", %%r12 \n\t" /* r'[0] -= modulus.data[0] */ \
"adoxq " neg_modulus_1 ", %%r13 \n\t" /* r'[1] -= modulus.data[1] */ \
"adoxq " neg_modulus_2 ", %%r14 \n\t" /* r'[2] -= modulus.data[2] */ \
"adoxq " neg_modulus_3 ", %%r15 \n\t" /* r'[3] -= modulus.data[3] */ \
\
    /* if r did not need to be reduced (r < p), adding -p does not overflow */                             \
    /* so restore r' = r while the overflow flag is clear */                                               \
"cmovnoq %%r8, %%r12 \n\t" \
"cmovnoq %%r9, %%r13 \n\t" \
"cmovnoq %%r10, %%r14 \n\t" \
"cmovnoq %%r11, %%r15 \n\t"
/**
* Compute Montgomery squaring of a
* Result is stored in (%%r12, %%r13, %%r14, %%r15), ready to be written out to "r"
**/
#define SQR(a) \
"movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
\
"xorq %%r8, %%r8 \n\t" /* clear flags */ \
/* compute a[0] *a[1], a[0]*a[2], a[0]*a[3], a[1]*a[2], a[1]*a[3], a[2]*a[3] */ \
"mulxq 8(" a "), %%r9, %%r10 \n\t" /* (r[1], r[2]) <- a[0] * a[1] */ \
"mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[1], t[2]) <- a[0] * a[2] */ \
"mulxq 24(" a "), %%r11, %%r12 \n\t" /* (r[3], r[4]) <- a[0] * a[3] */ \
\
\
/* accumulate products into result registers */ \
"adoxq %%r8, %%r10 \n\t" /* r[2] += t[1] */ \
"adcxq %%r15, %%r11 \n\t" /* r[3] += t[2] */ \
"movq 8(" a "), %%rdx \n\t" /* load a[1] into %r%dx */ \
"mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[5], t[6]) <- a[1] * a[2] */ \
"mulxq 24(" a "), %%rdi, %%rcx \n\t" /* (t[3], t[4]) <- a[1] * a[3] */ \
"movq 24(" a "), %%rdx \n\t" /* load a[3] into %%rdx */ \
"mulxq 16(" a "), %%r13, %%r14 \n\t" /* (r[5], r[6]) <- a[3] * a[2] */ \
"adoxq %%r8, %%r11 \n\t" /* r[3] += t[5] */ \
"adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] */ \
"adoxq %%r15, %%r12 \n\t" /* r[4] += t[6] */ \
"adcxq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
"adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
"adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
\
/* double result registers */ \
"adoxq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \
"adcxq %%r12, %%r12 \n\t" /* r[4] = 2r[4] */ \
"adoxq %%r10, %%r10 \n\t" /* r[2] = 2r[2] */ \
"adcxq %%r13, %%r13 \n\t" /* r[5] = 2r[5] */ \
"adoxq %%r11, %%r11 \n\t" /* r[3] = 2r[3] */ \
"adcxq %%r14, %%r14 \n\t" /* r[6] = 2r[6] */ \
\
/* compute a[3]*a[3], a[2]*a[2], a[1]*a[1], a[0]*a[0] */ \
"movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
"mulxq %%rdx, %%r8, %%rcx \n\t" /* (r[0], t[4]) <- a[0] * a[0] */ \
"movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq %%rdx, %%rdx, %%rdi \n\t" /* (t[7], t[8]) <- a[2] * a[2] */ \
/* add squares into result registers */ \
"adcxq %%rcx, %%r9 \n\t" /* r[1] += t[4] */ \
"adoxq %%rdx, %%r12 \n\t" /* r[4] += t[7] */ \
"adoxq %%rdi, %%r13 \n\t" /* r[5] += t[8] */ \
"movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
"mulxq %%rdx, %%rcx, %%r15 \n\t" /* (t[5], r[7]) <- a[3] * a[3] */ \
"movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq %%rdx, %%rdi, %%rdx \n\t" /* (t[3], t[6]) <- a[1] * a[1] */ \
"adcxq %%rdi, %%r10 \n\t" /* r[2] += t[3] */ \
"adcxq %%rdx, %%r11 \n\t" /* r[3] += t[6] */ \
"adoxq %%rcx, %%r14 \n\t" /* r[6] += t[5] */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
\
/* perform modular reduction: r[0] */ \
"movq %%r8, %%rdx \n\t" /* move r8 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \
"mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"adoxq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \
"mulxq %[modulus_3], %%r8, %%rdi \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
"adoxq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_o */ \
"adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
"mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
"adoxq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_o */ \
"adcxq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \
"adoxq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_o */ \
"mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"adcxq %%rdi, %%r10 \n\t" /* r[2] += t[0] + flag_c */ \
"adcxq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \
\
/* perform modular reduction: r[1] */ \
"movq %%r9, %%rdx \n\t" /* move r9 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \
"mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"mulxq %[modulus_3], %%r8, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
"adoxq %%rcx, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
"adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
"adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
"adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
"mulxq %[modulus_0], %%r8, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"adcxq %%r8, %%r9 \n\t" /* r[1] += t[0] (%r9 now free) */ \
"adoxq %%rcx, %%r10 \n\t" /* r[2] += t[1] + flag_c */ \
"mulxq %[modulus_1], %%r8, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"adcxq %%r8, %%r10 \n\t" /* r[2] += t[2] */ \
"adoxq %%rcx, %%r11 \n\t" /* r[3] += t[3] + flag_o */ \
"adcxq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \
\
/* perform modular reduction: r[2] */ \
"movq %%r10, %%rdx \n\t" /* move r10 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \
"mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"adoxq %%rcx, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \
"adcxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_o */ \
"adoxq %%r9, %%r13 \n\t" /* r[5] += t[1] + flag_o */ \
"mulxq %[modulus_3], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adcxq %%r8, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
"adoxq %%r9, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \
"adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"adcxq %%r8, %%r10 \n\t" /* r[2] += t[0] (%r10 now free) */ \
"adoxq %%r9, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \
"adcxq %%rdi, %%r11 \n\t" /* r[3] += t[2] */ \
"adoxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_o */ \
"adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
\
/* perform modular reduction: r[3] */ \
"movq %%r11, %%rdx \n\t" /* move r11 into %rdx */ \
"mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \
"mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \
"mulxq %[modulus_1], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \
"adoxq %%rdi, %%r11 \n\t" /* r[3] += t[0] (%r11 now free) */ \
"adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] */ \
"adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adcxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \
"mulxq %[modulus_3], %%r10, %%r11 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \
"adoxq %%r8, %%r13 \n\t" /* r[5] += t[0] + flag_o */ \
"adcxq %%r10, %%r14 \n\t" /* r[6] += t[2] + flag_c */ \
"adoxq %%r9, %%r14 \n\t" /* r[6] += t[1] + flag_o */ \
"adcxq %%r11, %%r15 \n\t" /* r[7] += t[3] + flag_c */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */
/**
* Compute Montgomery multiplication of a, b.
* Result is stored in (%%r12, %%r13, %%r14, %%r15), ready to be written out to "r"
**/
#define MUL(a1, a2, a3, a4, b) \
"movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \
"xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \
/* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
"mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
"mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
"mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \
"mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
/* zero flags */ \
\
/* start computing modular reduction */ \
"movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
\
/* start first addition chain */ \
"adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
"adoxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \
"adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
\
/* reduce by r[0] * k */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"adcxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
"adoxq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_i */ \
"adoxq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \
"adcxq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_o */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"adoxq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \
"adcxq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_o */ \
"adcxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
\
/* modulus = 254 bits, so max(t[3]) = 62 bits */ \
/* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \
/* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \
/* (which is very convenient because we're out of registers!) */ \
/* N.B. the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \
\
/* a[1] * b */ \
"movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \
"adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
"adoxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
"adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
"adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
"adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
"adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
\
/* reduce by r[1] * k */ \
"movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"adcxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \
"adoxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adoxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"adoxq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \
"adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
"adcxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
\
/* a[2] * b */ \
"movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \
"adoxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adoxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
"adcxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_o */ \
"mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
"adoxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
"adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
"adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
\
/* reduce by r[2] * k */ \
"movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"adcxq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
"adoxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_o */ \
"adoxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"adcxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_o */ \
"adoxq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
"adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \
"adcxq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
\
/* a[3] * b */ \
"movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \
"adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
"adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
\
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \
"adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
"adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \
"adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
"adcxq %[zero_reference], %%r15 \n\t" /* r[7] += + flag_o */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
\
/* reduce by r[3] * k */ \
"movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
"adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \
"adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
\
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \
"adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
"adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
"adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \
"adcxq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */
/**
* Compute Montgomery multiplication of a, b.
* Result is stored in (%%r12, %%r13, %%r14, %%r15), ready to be written out to "r"
**/
#define MUL_FOO(a1, a2, a3, a4, b) \
"movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \
"xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \
/* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
"mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
"mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
"mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \
"mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
/* zero flags */ \
\
/* start computing modular reduction */ \
"movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
\
/* start first addition chain */ \
"adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
"adoxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \
"adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
\
/* reduce by r[0] * k */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"adcxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \
"adoxq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_i */ \
"adoxq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \
"adcxq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_o */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"adoxq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \
"adcxq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_o */ \
"adcxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
\
/* modulus = 254 bits, so max(t[3]) = 62 bits */ \
/* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \
/* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \
/* (which is very convenient because we're out of registers!) */ \
/* N.B. the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \
\
/* a[1] * b */ \
"movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \
"adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
"adoxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
"adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
"adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
"adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
"adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
\
/* reduce by r[1] * k */ \
"movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"adcxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \
"adoxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adcxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adoxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"adoxq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \
"adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
"adcxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \
\
/* a[2] * b */ \
"movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \
"adoxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adoxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \
"adcxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_o */ \
"mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
"adoxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
"adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \
"adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
\
/* reduce by r[2] * k */ \
"movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \
"adcxq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \
"adoxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_o */ \
"adoxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \
"mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"adcxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_o */ \
"adoxq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \
"adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \
"adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \
"adcxq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
\
/* a[3] * b */ \
"movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \
"adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \
"adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \
"adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \
"adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \
\
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \
"mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \
"adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \
"adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \
"adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \
"adcxq %[zero_reference], %%r15 \n\t" /* r[7] += + flag_o */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \
\
/* reduce by r[3] * k */ \
"movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \
"mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \
"mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \
"mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \
"adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \
"adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \
"adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \
"adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \
\
"mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \
"mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \
"adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \
"adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \
"adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \
"adcxq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \
"adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */
/**
 * Compute the low 256 bits of the product of a and b.
 * The result is held in (%%r13, %%r14, %%r15, %%rax) and written out to memory at "r".
**/
#define MUL_256(a, b, r) \
"movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \
\
/* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \
"mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \
"mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \
"mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \
"mulxq 16(" b "), %%r15, %%rax \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \
/* zero flags */ \
"xorq %%r10, %%r10 \n\t" /* clear r10 register, we use this when we need 0 */ \
\
\
/* start first addition chain */ \
"adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \
"adoxq %%rdi, %%rax \n\t" /* r[3] += t[2] + flag_o */ \
"adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \
"adcxq %%r10, %%rax \n\t" /* r[3] += flag_o */ \
\
/* a[1] * b */ \
"movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \
"adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \
"adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \
"adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%rsi, %%rax \n\t" /* r[3] += t[1] + flag_o */ \
\
"mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \
"adcxq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
\
/* a[2] * b */ \
"movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \
"mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \
"adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \
"adoxq %%r9, %%rax \n\t" /* r[3] += t[1] + flag_o */ \
"adcxq %%rdi, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
\
\
/* a[3] * b */ \
"movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \
"mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \
"adcxq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \
"movq %%r13, 0(" r ") \n\t" \
"movq %%r14, 8(" r ") \n\t" \
"movq %%r15, 16(" r ") \n\t" \
"movq %%rax, 24(" r ") \n\t"
#endif

View File

@@ -0,0 +1,10 @@
#pragma once
/**
* @brief Include order of header-only field class is structured to ensure linter/language server can resolve paths.
 * Declarations are defined in "field_declarations.hpp", definitions in "field_impl.hpp" (which includes the
 * declarations header). Specialized definitions are in "field_impl_generic.hpp" and "field_impl_x64.hpp"
* (which include "field_impl.hpp")
*/
#include "./field_impl_generic.hpp"
#include "./field_impl_x64.hpp"

View File

@@ -0,0 +1,682 @@
#pragma once
#include "../../common/assert.hpp"
#include "../../common/compiler_hints.hpp"
#include "../../numeric/random/engine.hpp"
#include "../../numeric/uint128/uint128.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include <array>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <random>
#include <span>
#ifndef DISABLE_ASM
#ifdef __BMI2__
#define BBERG_NO_ASM 0
#else
#define BBERG_NO_ASM 1
#endif
#else
#define BBERG_NO_ASM 1
#endif
namespace bb {
using namespace numeric;
/**
 * @brief General class for prime fields; see \ref field_docs ["field documentation"] for a general implementation reference
*
* @tparam Params_
*/
template <class Params_> struct alignas(32) field {
public:
using View = field;
using Params = Params_;
using in_buf = const uint8_t*;
using vec_in_buf = const uint8_t*;
using out_buf = uint8_t*;
using vec_out_buf = uint8_t**;
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
#define WASM_NUM_LIMBS 9
#define WASM_LIMB_BITS 29
#endif
// We don't initialize data in the default constructor since we'd lose a lot of time on huge array initializations.
// Other alternatives have been noted, such as casting to get around constructors where they matter;
// however, it is felt that sanitizer tools (e.g. MSAN) can detect garbage well, whereas doing
// hacky casts where needed would require rework to critical algos like MSM, FFT, Sumcheck.
// Instead, the recommended solution is use an explicit {} where initialization is important:
// field f; // not initialized
// field f{}; // zero-initialized
// std::array<field, N> arr; // not initialized, good for huge N
// std::array<field, N> arr {}; // zero-initialized, preferable for moderate N
field() = default;
constexpr field(const numeric::uint256_t& input) noexcept
: data{ input.data[0], input.data[1], input.data[2], input.data[3] }
{
self_to_montgomery_form();
}
// NOLINTNEXTLINE (unsigned long is platform dependent, which we want in this case)
constexpr field(const unsigned long input) noexcept
: data{ input, 0, 0, 0 }
{
self_to_montgomery_form();
}
constexpr field(const unsigned int input) noexcept
: data{ input, 0, 0, 0 }
{
self_to_montgomery_form();
}
// NOLINTNEXTLINE (unsigned long long is platform dependent, which we want in this case)
constexpr field(const unsigned long long input) noexcept
: data{ input, 0, 0, 0 }
{
self_to_montgomery_form();
}
constexpr field(const int input) noexcept
: data{ 0, 0, 0, 0 }
{
if (input < 0) {
data[0] = static_cast<uint64_t>(-input);
data[1] = 0;
data[2] = 0;
data[3] = 0;
self_to_montgomery_form();
self_neg();
self_reduce_once();
} else {
data[0] = static_cast<uint64_t>(input);
data[1] = 0;
data[2] = 0;
data[3] = 0;
self_to_montgomery_form();
}
}
constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
: data{ a, b, c, d } {};
/**
* @brief Convert a 512-bit big integer into a field element.
*
 * @details Used for deriving field elements from random values. 512 bits prevent biased output, as 2^512 >> modulus
*
*/
constexpr explicit field(const uint512_t& input) noexcept
{
uint256_t value = (input % modulus).lo;
data[0] = value.data[0];
data[1] = value.data[1];
data[2] = value.data[2];
data[3] = value.data[3];
self_to_montgomery_form();
}
constexpr explicit field(std::string input) noexcept
{
uint256_t value(input);
*this = field(value);
}
constexpr explicit operator bool() const
{
field out = from_montgomery_form();
ASSERT(out.data[0] == 0 || out.data[0] == 1);
return static_cast<bool>(out.data[0]);
}
constexpr explicit operator uint8_t() const
{
field out = from_montgomery_form();
return static_cast<uint8_t>(out.data[0]);
}
constexpr explicit operator uint16_t() const
{
field out = from_montgomery_form();
return static_cast<uint16_t>(out.data[0]);
}
constexpr explicit operator uint32_t() const
{
field out = from_montgomery_form();
return static_cast<uint32_t>(out.data[0]);
}
constexpr explicit operator uint64_t() const
{
field out = from_montgomery_form();
return out.data[0];
}
constexpr explicit operator uint128_t() const
{
field out = from_montgomery_form();
uint128_t lo = out.data[0];
uint128_t hi = out.data[1];
return (hi << 64) | lo;
}
constexpr operator uint256_t() const noexcept
{
field out = from_montgomery_form();
return uint256_t(out.data[0], out.data[1], out.data[2], out.data[3]);
}
[[nodiscard]] constexpr uint256_t uint256_t_no_montgomery_conversion() const noexcept
{
return { data[0], data[1], data[2], data[3] };
}
constexpr field(const field& other) noexcept = default;
constexpr field(field&& other) noexcept = default;
constexpr field& operator=(const field& other) noexcept = default;
constexpr field& operator=(field&& other) noexcept = default;
constexpr ~field() noexcept = default;
alignas(32) uint64_t data[4]; // NOLINT
static constexpr uint256_t modulus =
uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr uint256_t r_squared_uint{
Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3
};
#else
static constexpr uint256_t r_squared_uint{
Params_::r_squared_wasm_0, Params_::r_squared_wasm_1, Params_::r_squared_wasm_2, Params_::r_squared_wasm_3
};
static constexpr std::array<uint64_t, 9> wasm_modulus = { Params::modulus_wasm_0, Params::modulus_wasm_1,
Params::modulus_wasm_2, Params::modulus_wasm_3,
Params::modulus_wasm_4, Params::modulus_wasm_5,
Params::modulus_wasm_6, Params::modulus_wasm_7,
Params::modulus_wasm_8 };
#endif
static constexpr field cube_root_of_unity()
{
// endomorphism i.e. lambda * [P] = (beta * x, y)
if constexpr (Params::cube_root_0 != 0) {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
constexpr field result{
Params::cube_root_0, Params::cube_root_1, Params::cube_root_2, Params::cube_root_3
};
#else
constexpr field result{
Params::cube_root_wasm_0, Params::cube_root_wasm_1, Params::cube_root_wasm_2, Params::cube_root_wasm_3
};
#endif
return result;
} else {
constexpr field two_inv = field(2).invert();
constexpr field numerator = (-field(3)).sqrt() - field(1);
constexpr field result = two_inv * numerator;
return result;
}
}
static constexpr field zero() { return field(0, 0, 0, 0); }
static constexpr field neg_one() { return -field(1); }
static constexpr field one() { return field(1); }
static constexpr field external_coset_generator()
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const field result{
Params::coset_generators_0[7],
Params::coset_generators_1[7],
Params::coset_generators_2[7],
Params::coset_generators_3[7],
};
#else
const field result{
Params::coset_generators_wasm_0[7],
Params::coset_generators_wasm_1[7],
Params::coset_generators_wasm_2[7],
Params::coset_generators_wasm_3[7],
};
#endif
return result;
}
static constexpr field tag_coset_generator()
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const field result{
Params::coset_generators_0[6],
Params::coset_generators_1[6],
Params::coset_generators_2[6],
Params::coset_generators_3[6],
};
#else
const field result{
Params::coset_generators_wasm_0[6],
Params::coset_generators_wasm_1[6],
Params::coset_generators_wasm_2[6],
Params::coset_generators_wasm_3[6],
};
#endif
return result;
}
static constexpr field coset_generator(const size_t idx)
{
ASSERT(idx < 7);
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const field result{
Params::coset_generators_0[idx],
Params::coset_generators_1[idx],
Params::coset_generators_2[idx],
Params::coset_generators_3[idx],
};
#else
const field result{
Params::coset_generators_wasm_0[idx],
Params::coset_generators_wasm_1[idx],
Params::coset_generators_wasm_2[idx],
Params::coset_generators_wasm_3[idx],
};
#endif
return result;
}
BB_INLINE constexpr field operator*(const field& other) const noexcept;
BB_INLINE constexpr field operator+(const field& other) const noexcept;
BB_INLINE constexpr field operator-(const field& other) const noexcept;
BB_INLINE constexpr field operator-() const noexcept;
constexpr field operator/(const field& other) const noexcept;
// prefix increment (++x)
BB_INLINE constexpr field operator++() noexcept;
// postfix increment (x++)
// NOLINTNEXTLINE
BB_INLINE constexpr field operator++(int) noexcept;
BB_INLINE constexpr field& operator*=(const field& other) noexcept;
BB_INLINE constexpr field& operator+=(const field& other) noexcept;
BB_INLINE constexpr field& operator-=(const field& other) noexcept;
constexpr field& operator/=(const field& other) noexcept;
// NOTE: comparison operators exist so that `field` is comparable with stl methods that require them
// (e.g. std::sort).
// Finite fields do not have an explicit ordering; these should *NEVER* be used in algebraic algorithms.
BB_INLINE constexpr bool operator>(const field& other) const noexcept;
BB_INLINE constexpr bool operator<(const field& other) const noexcept;
BB_INLINE constexpr bool operator==(const field& other) const noexcept;
BB_INLINE constexpr bool operator!=(const field& other) const noexcept;
BB_INLINE constexpr field to_montgomery_form() const noexcept;
BB_INLINE constexpr field from_montgomery_form() const noexcept;
BB_INLINE constexpr field sqr() const noexcept;
BB_INLINE constexpr void self_sqr() noexcept;
BB_INLINE constexpr field pow(const uint256_t& exponent) const noexcept;
BB_INLINE constexpr field pow(uint64_t exponent) const noexcept;
static_assert(Params::modulus_0 != 1);
static constexpr uint256_t modulus_minus_two =
uint256_t(Params::modulus_0 - 2ULL, Params::modulus_1, Params::modulus_2, Params::modulus_3);
constexpr field invert() const noexcept;
static void batch_invert(std::span<field> coeffs) noexcept;
static void batch_invert(field* coeffs, size_t n) noexcept;
/**
* @brief Compute square root of the field element.
*
 * @return <true, root> if the element is a quadratic residue, <false, 0> if it is not
*/
constexpr std::pair<bool, field> sqrt() const noexcept;
BB_INLINE constexpr void self_neg() noexcept;
BB_INLINE constexpr void self_to_montgomery_form() noexcept;
BB_INLINE constexpr void self_from_montgomery_form() noexcept;
BB_INLINE constexpr void self_conditional_negate(uint64_t predicate) noexcept;
BB_INLINE constexpr field reduce_once() const noexcept;
BB_INLINE constexpr void self_reduce_once() noexcept;
BB_INLINE constexpr void self_set_msb() noexcept;
[[nodiscard]] BB_INLINE constexpr bool is_msb_set() const noexcept;
[[nodiscard]] BB_INLINE constexpr uint64_t is_msb_set_word() const noexcept;
[[nodiscard]] BB_INLINE constexpr bool is_zero() const noexcept;
static constexpr field get_root_of_unity(size_t subgroup_size) noexcept;
static void serialize_to_buffer(const field& value, uint8_t* buffer) { write(buffer, value); }
static field serialize_from_buffer(const uint8_t* buffer) { return from_buffer<field>(buffer); }
[[nodiscard]] BB_INLINE std::vector<uint8_t> to_buffer() const { return to_buffer(*this); }
struct wide_array {
uint64_t data[8]; // NOLINT
};
BB_INLINE constexpr wide_array mul_512(const field& other) const noexcept;
BB_INLINE constexpr wide_array sqr_512() const noexcept;
BB_INLINE constexpr field conditionally_subtract_from_double_modulus(const uint64_t predicate) const noexcept
{
if (predicate != 0) {
constexpr field p{
twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3]
};
return p - *this;
}
return *this;
}
/**
* For short Weierstrass curves y^2 = x^3 + b mod r, if there exists a cube root of unity mod r,
 * we can take advantage of an endomorphism to decompose a 254-bit scalar into two 128-bit scalars.
* \beta = cube root of 1, mod q (q = order of fq)
* \lambda = cube root of 1, mod r (r = order of fr)
*
* For a point P1 = (X, Y), where Y^2 = X^3 + b, we know that
* the point P2 = (X * \beta, Y) is also a point on the curve
* We can represent P2 as a scalar multiplication of P1, where P2 = \lambda * P1
*
 * For a generic multiplication of P1 by a 254-bit scalar k, we can decompose k
 * into two 127-bit scalars (k1, k2), such that k = k1 - (k2 * \lambda)
*
* We can now represent (k * P1) as (k1 * P1) - (k2 * P2), where P2 = (X * \beta, Y).
* As k1, k2 have half the bit length of k, we have reduced the number of loop iterations of our
* scalar multiplication algorithm in half
*
 * To find k1, k2, we use the extended Euclidean algorithm to find 4 short scalars [a1, a2], [b1, b2] such that
* modulus = (a1 * b2) - (b1 * a2)
* We then compute scalars c1 = round(b2 * k / r), c2 = round(b1 * k / r), where
* k1 = (c1 * a1) + (c2 * a2), k2 = -((c1 * b1) + (c2 * b2))
* We pre-compute scalars g1 = (2^256 * b1) / n, g2 = (2^256 * b2) / n, to avoid having to perform long division
* on 512-bit scalars
**/
static void split_into_endomorphism_scalars(const field& k, field& k1, field& k2)
{
// if the modulus is a >= 255-bit integer, we need to use a basis where g1, g2 have been shifted by 2^384
if constexpr (Params::modulus_3 >= 0x4000000000000000ULL) {
split_into_endomorphism_scalars_384(k, k1, k2);
} else {
std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> ret = split_into_endomorphism_scalars(k);
k1.data[0] = ret.first[0];
k1.data[1] = ret.first[1];
// TODO(https://github.com/AztecProtocol/barretenberg/issues/851): We should move away from this hack by
// returning pair of uint64_t[2] instead of a half-set field
#if !defined(__clang__) && defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
k2.data[0] = ret.second[0]; // NOLINT
k2.data[1] = ret.second[1];
#if !defined(__clang__) && defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
}
}
// NOTE: this form is only usable if the modulus is 254 bits or less, otherwise see
// split_into_endomorphism_scalars_384.
// TODO(https://github.com/AztecProtocol/barretenberg/issues/851): Unify these APIs.
static std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> split_into_endomorphism_scalars(const field& k)
{
static_assert(Params::modulus_3 < 0x4000000000000000ULL);
field input = k.reduce_once();
constexpr field endo_g1 = { Params::endo_g1_lo, Params::endo_g1_mid, Params::endo_g1_hi, 0 };
constexpr field endo_g2 = { Params::endo_g2_lo, Params::endo_g2_mid, 0, 0 };
constexpr field endo_minus_b1 = { Params::endo_minus_b1_lo, Params::endo_minus_b1_mid, 0, 0 };
constexpr field endo_b2 = { Params::endo_b2_lo, Params::endo_b2_mid, 0, 0 };
// compute c1 = (g2 * k) >> 256
wide_array c1 = endo_g2.mul_512(input);
// compute c2 = (g1 * k) >> 256
wide_array c2 = endo_g1.mul_512(input);
// (the bit shifts are implicit, as we only utilize the high limbs of c1, c2)
field c1_hi = {
c1.data[4], c1.data[5], c1.data[6], c1.data[7]
}; // *(field*)((uintptr_t)(&c1) + (4 * sizeof(uint64_t)));
field c2_hi = {
c2.data[4], c2.data[5], c2.data[6], c2.data[7]
}; // *(field*)((uintptr_t)(&c2) + (4 * sizeof(uint64_t)));
// compute q1 = c1 * -b1
wide_array q1 = c1_hi.mul_512(endo_minus_b1);
// compute q2 = c2 * b2
wide_array q2 = c2_hi.mul_512(endo_b2);
// FIX: Avoid using 512-bit multiplication as it's not necessary.
// c1_hi, c2_hi can be uint256_t's and the final result (without montgomery reduction)
// could be cast to a field.
field q1_lo{ q1.data[0], q1.data[1], q1.data[2], q1.data[3] };
field q2_lo{ q2.data[0], q2.data[1], q2.data[2], q2.data[3] };
field t1 = (q2_lo - q1_lo).reduce_once();
field beta = cube_root_of_unity();
field t2 = (t1 * beta + input).reduce_once();
return {
{ t2.data[0], t2.data[1] },
{ t1.data[0], t1.data[1] },
};
}
static void split_into_endomorphism_scalars_384(const field& input, field& k1_out, field& k2_out)
{
constexpr field minus_b1f{
Params::endo_minus_b1_lo,
Params::endo_minus_b1_mid,
0,
0,
};
constexpr field b2f{
Params::endo_b2_lo,
Params::endo_b2_mid,
0,
0,
};
constexpr uint256_t g1{
Params::endo_g1_lo,
Params::endo_g1_mid,
Params::endo_g1_hi,
Params::endo_g1_hihi,
};
constexpr uint256_t g2{
Params::endo_g2_lo,
Params::endo_g2_mid,
Params::endo_g2_hi,
Params::endo_g2_hihi,
};
field kf = input.reduce_once();
uint256_t k{ kf.data[0], kf.data[1], kf.data[2], kf.data[3] };
uint512_t c1 = (uint512_t(k) * static_cast<uint512_t>(g1)) >> 384;
uint512_t c2 = (uint512_t(k) * static_cast<uint512_t>(g2)) >> 384;
field c1f{ c1.lo.data[0], c1.lo.data[1], c1.lo.data[2], c1.lo.data[3] };
field c2f{ c2.lo.data[0], c2.lo.data[1], c2.lo.data[2], c2.lo.data[3] };
c1f.self_to_montgomery_form();
c2f.self_to_montgomery_form();
c1f = c1f * minus_b1f;
c2f = c2f * b2f;
field r2f = c1f - c2f;
field beta = cube_root_of_unity();
field r1f = input.reduce_once() - r2f * beta;
k1_out = r1f;
k2_out = -r2f;
}
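// Usage note: given the decomposition k = k1 - k2 * lambda (mod r) produced by the two
// methods above, a scalar multiplication k * P can be evaluated as k1 * P - k2 * P_endo,
// where P_endo = (beta * P.x, P.y) applies the curve endomorphism. Sketch (illustrative
// only; the point arithmetic is not part of this class):
//
//   field k1, k2;
//   field::split_into_endomorphism_scalars(k, k1, k2);
//   // accumulate k1 * P - k2 * P_endo with a shared double-and-add loop of ~half length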
// static constexpr auto coset_generators = compute_coset_generators();
// static constexpr std::array<field, 15> coset_generators = compute_coset_generators((1 << 30U));
friend std::ostream& operator<<(std::ostream& os, const field& a)
{
field out = a.from_montgomery_form();
std::ios_base::fmtflags f(os.flags());
os << std::hex << "0x" << std::setfill('0') << std::setw(16) << out.data[3] << std::setw(16) << out.data[2]
<< std::setw(16) << out.data[1] << std::setw(16) << out.data[0];
os.flags(f);
return os;
}
BB_INLINE static void __copy(const field& a, field& r) noexcept { r = a; } // NOLINT
BB_INLINE static void __swap(field& src, field& dest) noexcept // NOLINT
{
field T = dest;
dest = src;
src = T;
}
static field random_element(numeric::RNG* engine = nullptr) noexcept;
static constexpr field multiplicative_generator() noexcept;
static constexpr uint256_t twice_modulus = modulus + modulus;
static constexpr uint256_t not_modulus = -modulus;
static constexpr uint256_t twice_not_modulus = -twice_modulus;
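// Decomposes a 256-bit scalar into 64 fixed 4-bit windows, least-significant window
// first: windows[i] = (target >> (4 * i)) & 15.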
struct wnaf_table {
uint8_t windows[64]; // NOLINT
constexpr wnaf_table(const uint256_t& target)
: windows{
static_cast<uint8_t>(target.data[0] & 15), static_cast<uint8_t>((target.data[0] >> 4) & 15),
static_cast<uint8_t>((target.data[0] >> 8) & 15), static_cast<uint8_t>((target.data[0] >> 12) & 15),
static_cast<uint8_t>((target.data[0] >> 16) & 15), static_cast<uint8_t>((target.data[0] >> 20) & 15),
static_cast<uint8_t>((target.data[0] >> 24) & 15), static_cast<uint8_t>((target.data[0] >> 28) & 15),
static_cast<uint8_t>((target.data[0] >> 32) & 15), static_cast<uint8_t>((target.data[0] >> 36) & 15),
static_cast<uint8_t>((target.data[0] >> 40) & 15), static_cast<uint8_t>((target.data[0] >> 44) & 15),
static_cast<uint8_t>((target.data[0] >> 48) & 15), static_cast<uint8_t>((target.data[0] >> 52) & 15),
static_cast<uint8_t>((target.data[0] >> 56) & 15), static_cast<uint8_t>((target.data[0] >> 60) & 15),
static_cast<uint8_t>(target.data[1] & 15), static_cast<uint8_t>((target.data[1] >> 4) & 15),
static_cast<uint8_t>((target.data[1] >> 8) & 15), static_cast<uint8_t>((target.data[1] >> 12) & 15),
static_cast<uint8_t>((target.data[1] >> 16) & 15), static_cast<uint8_t>((target.data[1] >> 20) & 15),
static_cast<uint8_t>((target.data[1] >> 24) & 15), static_cast<uint8_t>((target.data[1] >> 28) & 15),
static_cast<uint8_t>((target.data[1] >> 32) & 15), static_cast<uint8_t>((target.data[1] >> 36) & 15),
static_cast<uint8_t>((target.data[1] >> 40) & 15), static_cast<uint8_t>((target.data[1] >> 44) & 15),
static_cast<uint8_t>((target.data[1] >> 48) & 15), static_cast<uint8_t>((target.data[1] >> 52) & 15),
static_cast<uint8_t>((target.data[1] >> 56) & 15), static_cast<uint8_t>((target.data[1] >> 60) & 15),
static_cast<uint8_t>(target.data[2] & 15), static_cast<uint8_t>((target.data[2] >> 4) & 15),
static_cast<uint8_t>((target.data[2] >> 8) & 15), static_cast<uint8_t>((target.data[2] >> 12) & 15),
static_cast<uint8_t>((target.data[2] >> 16) & 15), static_cast<uint8_t>((target.data[2] >> 20) & 15),
static_cast<uint8_t>((target.data[2] >> 24) & 15), static_cast<uint8_t>((target.data[2] >> 28) & 15),
static_cast<uint8_t>((target.data[2] >> 32) & 15), static_cast<uint8_t>((target.data[2] >> 36) & 15),
static_cast<uint8_t>((target.data[2] >> 40) & 15), static_cast<uint8_t>((target.data[2] >> 44) & 15),
static_cast<uint8_t>((target.data[2] >> 48) & 15), static_cast<uint8_t>((target.data[2] >> 52) & 15),
static_cast<uint8_t>((target.data[2] >> 56) & 15), static_cast<uint8_t>((target.data[2] >> 60) & 15),
static_cast<uint8_t>(target.data[3] & 15), static_cast<uint8_t>((target.data[3] >> 4) & 15),
static_cast<uint8_t>((target.data[3] >> 8) & 15), static_cast<uint8_t>((target.data[3] >> 12) & 15),
static_cast<uint8_t>((target.data[3] >> 16) & 15), static_cast<uint8_t>((target.data[3] >> 20) & 15),
static_cast<uint8_t>((target.data[3] >> 24) & 15), static_cast<uint8_t>((target.data[3] >> 28) & 15),
static_cast<uint8_t>((target.data[3] >> 32) & 15), static_cast<uint8_t>((target.data[3] >> 36) & 15),
static_cast<uint8_t>((target.data[3] >> 40) & 15), static_cast<uint8_t>((target.data[3] >> 44) & 15),
static_cast<uint8_t>((target.data[3] >> 48) & 15), static_cast<uint8_t>((target.data[3] >> 52) & 15),
static_cast<uint8_t>((target.data[3] >> 56) & 15), static_cast<uint8_t>((target.data[3] >> 60) & 15)
}
{}
};
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
BB_INLINE static constexpr void wasm_madd(uint64_t& left_limb,
const std::array<uint64_t, WASM_NUM_LIMBS>& right_limbs,
uint64_t& result_0,
uint64_t& result_1,
uint64_t& result_2,
uint64_t& result_3,
uint64_t& result_4,
uint64_t& result_5,
uint64_t& result_6,
uint64_t& result_7,
uint64_t& result_8);
BB_INLINE static constexpr void wasm_reduce(uint64_t& result_0,
uint64_t& result_1,
uint64_t& result_2,
uint64_t& result_3,
uint64_t& result_4,
uint64_t& result_5,
uint64_t& result_6,
uint64_t& result_7,
uint64_t& result_8);
BB_INLINE static constexpr std::array<uint64_t, WASM_NUM_LIMBS> wasm_convert(const uint64_t* data);
#endif
BB_INLINE static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b) noexcept;
BB_INLINE static constexpr uint64_t mac(
uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& carry_out) noexcept;
BB_INLINE static constexpr void mac(
uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& out, uint64_t& carry_out) noexcept;
BB_INLINE static constexpr uint64_t mac_mini(uint64_t a, uint64_t b, uint64_t c, uint64_t& out) noexcept;
BB_INLINE static constexpr void mac_mini(
uint64_t a, uint64_t b, uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept;
BB_INLINE static constexpr uint64_t mac_discard_lo(uint64_t a, uint64_t b, uint64_t c) noexcept;
BB_INLINE static constexpr uint64_t addc(uint64_t a, uint64_t b, uint64_t carry_in, uint64_t& carry_out) noexcept;
BB_INLINE static constexpr uint64_t sbb(uint64_t a, uint64_t b, uint64_t borrow_in, uint64_t& borrow_out) noexcept;
BB_INLINE static constexpr uint64_t square_accumulate(uint64_t a,
uint64_t b,
uint64_t c,
uint64_t carry_in_lo,
uint64_t carry_in_hi,
uint64_t& carry_lo,
uint64_t& carry_hi) noexcept;
BB_INLINE constexpr field reduce() const noexcept;
BB_INLINE constexpr field add(const field& other) const noexcept;
BB_INLINE constexpr field subtract(const field& other) const noexcept;
BB_INLINE constexpr field subtract_coarse(const field& other) const noexcept;
BB_INLINE constexpr field montgomery_mul(const field& other) const noexcept;
BB_INLINE constexpr field montgomery_mul_big(const field& other) const noexcept;
BB_INLINE constexpr field montgomery_square() const noexcept;
#if (BBERG_NO_ASM == 0)
BB_INLINE static field asm_mul(const field& a, const field& b) noexcept;
BB_INLINE static field asm_sqr(const field& a) noexcept;
BB_INLINE static field asm_add(const field& a, const field& b) noexcept;
BB_INLINE static field asm_sub(const field& a, const field& b) noexcept;
BB_INLINE static field asm_mul_with_coarse_reduction(const field& a, const field& b) noexcept;
BB_INLINE static field asm_sqr_with_coarse_reduction(const field& a) noexcept;
BB_INLINE static field asm_add_with_coarse_reduction(const field& a, const field& b) noexcept;
BB_INLINE static field asm_sub_with_coarse_reduction(const field& a, const field& b) noexcept;
BB_INLINE static field asm_add_without_reduction(const field& a, const field& b) noexcept;
BB_INLINE static void asm_self_sqr(const field& a) noexcept;
BB_INLINE static void asm_self_add(const field& a, const field& b) noexcept;
BB_INLINE static void asm_self_sub(const field& a, const field& b) noexcept;
BB_INLINE static void asm_self_mul_with_coarse_reduction(const field& a, const field& b) noexcept;
BB_INLINE static void asm_self_sqr_with_coarse_reduction(const field& a) noexcept;
BB_INLINE static void asm_self_add_with_coarse_reduction(const field& a, const field& b) noexcept;
BB_INLINE static void asm_self_sub_with_coarse_reduction(const field& a, const field& b) noexcept;
BB_INLINE static void asm_self_add_without_reduction(const field& a, const field& b) noexcept;
BB_INLINE static void asm_conditional_negate(field& r, uint64_t predicate) noexcept;
BB_INLINE static field asm_reduce_once(const field& a) noexcept;
BB_INLINE static void asm_self_reduce_once(const field& a) noexcept;
static constexpr uint64_t zero_reference = 0x00ULL;
#endif
static constexpr size_t COSET_GENERATOR_SIZE = 15;
constexpr field tonelli_shanks_sqrt() const noexcept;
static constexpr size_t primitive_root_log_size() noexcept;
static constexpr std::array<field, COSET_GENERATOR_SIZE> compute_coset_generators() noexcept;
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr uint128_t lo_mask = 0xffffffffffffffffUL;
#endif
};
} // namespace bb

View File

@@ -0,0 +1,673 @@
#pragma once
#include "../../common/op_count.hpp"
#include "../../common/slab_allocator.hpp"
#include "../../common/throw_or_abort.hpp"
#include "../../numeric/bitop/get_msb.hpp"
#include "../../numeric/random/engine.hpp"
#include <memory>
#include <span>
#include <type_traits>
#include <vector>
#include "./field_declarations.hpp"
namespace bb {
using namespace numeric;
// clang-format off
// disable the following style guides:
// cppcoreguidelines-avoid-c-arrays : we make heavy use of c-style arrays here to prevent default-initialization of memory when constructing `field` objects.
// The intention is for field to act like a primitive numeric type with the performance/complexity trade-offs expected from this.
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
// clang-format on
/**
*
 * Multiplication
*
**/
template <class T> constexpr field<T> field<T>::operator*(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::mul");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
// >= 255-bits or <= 64-bits.
return montgomery_mul(other);
} else {
if (std::is_constant_evaluated()) {
return montgomery_mul(other);
}
return asm_mul_with_coarse_reduction(*this, other);
}
}
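// A minimal sketch (illustrative only) of the dispatch pattern used above and throughout
// this file: std::is_constant_evaluated() lets a single constexpr function fall back to a
// portable implementation inside constant expressions while using hand-tuned assembly at
// runtime. `fast_runtime_mul` is a hypothetical optimized routine, not part of this file:
//
//   constexpr uint64_t dispatch_example(uint64_t a, uint64_t b)
//   {
//       if (std::is_constant_evaluated()) {
//           return a * b;              // portable path, legal in a constexpr context
//       }
//       return fast_runtime_mul(a, b); // runtime-only path (e.g. asm-backed)
//   }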
template <class T> constexpr field<T>& field<T>::operator*=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_mul");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
// >= 255-bits or <= 64-bits.
*this = operator*(other);
} else {
if (std::is_constant_evaluated()) {
*this = operator*(other);
} else {
asm_self_mul_with_coarse_reduction(*this, other); // asm_self_mul(*this, other);
}
}
return *this;
}
/**
*
* Squaring
*
**/
template <class T> constexpr field<T> field<T>::sqr() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::sqr");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
return montgomery_square();
} else {
if (std::is_constant_evaluated()) {
return montgomery_square();
}
return asm_sqr_with_coarse_reduction(*this); // asm_sqr(*this);
}
}
template <class T> constexpr void field<T>::self_sqr() noexcept
{
BB_OP_COUNT_TRACK_NAME("f::self_sqr");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
*this = montgomery_square();
} else {
if (std::is_constant_evaluated()) {
*this = montgomery_square();
} else {
asm_self_sqr_with_coarse_reduction(*this);
}
}
}
/**
*
* Addition
*
**/
template <class T> constexpr field<T> field<T>::operator+(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::add");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
return add(other);
} else {
if (std::is_constant_evaluated()) {
return add(other);
}
return asm_add_with_coarse_reduction(*this, other); // asm_add_without_reduction(*this, other);
}
}
template <class T> constexpr field<T>& field<T>::operator+=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_add");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
(*this) = operator+(other);
} else {
if (std::is_constant_evaluated()) {
(*this) = operator+(other);
} else {
asm_self_add_with_coarse_reduction(*this, other); // asm_self_add(*this, other);
}
}
return *this;
}
template <class T> constexpr field<T> field<T>::operator++() noexcept
{
BB_OP_COUNT_TRACK_NAME("++f");
return *this += 1;
}
// NOLINTNEXTLINE(cert-dcl21-cpp) circular linting errors. If const is added, linter suggests removing
template <class T> constexpr field<T> field<T>::operator++(int) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::increment");
field<T> value_before_incrementing = *this;
*this += 1;
return value_before_incrementing;
}
/**
*
* Subtraction
*
**/
template <class T> constexpr field<T> field<T>::operator-(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::sub");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
return subtract_coarse(other);
} else {
if (std::is_constant_evaluated()) {
return subtract_coarse(other); // subtract(other);
}
return asm_sub_with_coarse_reduction(*this, other); // asm_sub(*this, other);
}
}
template <class T> constexpr field<T> field<T>::operator-() const noexcept
{
BB_OP_COUNT_TRACK_NAME("-f");
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
return p - *this; // modulus - *this;
}
// TODO(@zac-williamson): there are 3 ways we can make this more efficient
// 1: we subtract `p` from `*this` instead of `2p`
// 2: instead of `p - *this`, we use an asm block that does `p - *this` without the assembly reduction step
// 3: we replace `(p - *this).reduce_once()` with an assembly block that is equivalent to `p - *this`,
// but we call `REDUCE_FIELD_ELEMENT` with `not_twice_modulus` instead of `twice_modulus`
// not sure which is faster and whether any of the above might break something!
//
// More context below:
// the operator-(a, b) method's asm implementation has a sneaky way to check for underflow.
// if `a - b` underflows we need to add in `2p`. Instead of conditional branching, which would cause pipeline
// flushes, we add `2p` into the result of `a - b`. If the result triggers the overflow flag, then we know we are
// correcting an *underflow* produced from computing `a - b`. Finally...we use the overflow flag to conditionally
// move data into registers such that we end up with either `a - b` or `2p + (a - b)` (this is branchless). OK! So
// what's the problem? Well, we assume that every field element lies between 0 and 2p - 1. But we are computing
// `2p - *this`! If *this = 0 then we exceed this bound, hence the need for the extra reduction step. HOWEVER, we
// also know that `2p - *this` won't underflow, so we could skip the underflow check present in the assembly code.
constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] };
return (p - *this).reduce_once(); // modulus - *this;
}
template <class T> constexpr field<T>& field<T>::operator-=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_sub");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
*this = subtract_coarse(other); // subtract(other);
} else {
if (std::is_constant_evaluated()) {
*this = subtract_coarse(other); // subtract(other);
} else {
asm_self_sub_with_coarse_reduction(*this, other); // asm_self_sub(*this, other);
}
}
return *this;
}
template <class T> constexpr void field<T>::self_neg() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_neg");
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
*this = p - *this;
} else {
constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] };
*this = (p - *this).reduce_once();
}
}
template <class T> constexpr void field<T>::self_conditional_negate(const uint64_t predicate) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_conditional_negate");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
*this = predicate ? -(*this) : *this; // NOLINT
} else {
if (std::is_constant_evaluated()) {
*this = predicate ? -(*this) : *this; // NOLINT
} else {
asm_conditional_negate(*this, predicate);
}
}
}
/**
* @brief Greater-than operator
 * @details Comparison operators exist so that `field` is comparable with stl methods that require them
 * (e.g. std::sort).
 * Finite fields do not have an explicit ordering; these should *NEVER* be used in algebraic algorithms.
*
* @tparam T
* @param other
* @return true
* @return false
*/
template <class T> constexpr bool field<T>::operator>(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::gt");
const field left = reduce_once();
const field right = other.reduce_once();
const bool t0 = left.data[3] > right.data[3];
const bool t1 = (left.data[3] == right.data[3]) && (left.data[2] > right.data[2]);
const bool t2 =
(left.data[3] == right.data[3]) && (left.data[2] == right.data[2]) && (left.data[1] > right.data[1]);
const bool t3 = (left.data[3] == right.data[3]) && (left.data[2] == right.data[2]) &&
(left.data[1] == right.data[1]) && (left.data[0] > right.data[0]);
return (t0 || t1 || t2 || t3);
}
/**
* @brief Less-than operator
 * @details Comparison operators exist so that `field` is comparable with stl methods that require them
 * (e.g. std::sort).
 * Finite fields do not have an explicit ordering; these should *NEVER* be used in algebraic algorithms.
*
* @tparam T
* @param other
* @return true
* @return false
*/
template <class T> constexpr bool field<T>::operator<(const field& other) const noexcept
{
return (other > *this);
}
template <class T> constexpr bool field<T>::operator==(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::eqeq");
const field left = reduce_once();
const field right = other.reduce_once();
return (left.data[0] == right.data[0]) && (left.data[1] == right.data[1]) && (left.data[2] == right.data[2]) &&
(left.data[3] == right.data[3]);
}
template <class T> constexpr bool field<T>::operator!=(const field& other) const noexcept
{
return (!operator==(other));
}
template <class T> constexpr field<T> field<T>::to_montgomery_form() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::to_montgomery_form");
constexpr field r_squared =
field{ r_squared_uint.data[0], r_squared_uint.data[1], r_squared_uint.data[2], r_squared_uint.data[3] };
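// montgomery_mul(x, y) computes x * y * R^{-1} mod p, so multiplying by R^2 mod p
// yields x * R mod p, i.e. x in Montgomery form.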
field result = *this;
// TODO(@zac-williamson): are these reductions needed?
// Rationale: We want to take any 256-bit input and be able to convert into montgomery form.
// A basic heuristic we use is that any input into the `*` operator must be between [0, 2p - 1]
// to prevent overflows in the asm algorithm.
// However... r_squared is already reduced so perhaps we can relax this requirement?
// (would be good to identify a failure case where not calling self_reduce triggers an error)
result.self_reduce_once();
result.self_reduce_once();
result.self_reduce_once();
return (result * r_squared).reduce_once();
}
template <class T> constexpr field<T> field<T>::from_montgomery_form() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::from_montgomery_form");
constexpr field one_raw{ 1, 0, 0, 0 };
return operator*(one_raw).reduce_once();
}
template <class T> constexpr void field<T>::self_to_montgomery_form() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_to_montgomery_form");
constexpr field r_squared =
field{ r_squared_uint.data[0], r_squared_uint.data[1], r_squared_uint.data[2], r_squared_uint.data[3] };
self_reduce_once();
self_reduce_once();
self_reduce_once();
*this *= r_squared;
self_reduce_once();
}
template <class T> constexpr void field<T>::self_from_montgomery_form() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_from_montgomery_form");
constexpr field one_raw{ 1, 0, 0, 0 };
*this *= one_raw;
self_reduce_once();
}
template <class T> constexpr field<T> field<T>::reduce_once() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::reduce_once");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
return reduce();
} else {
if (std::is_constant_evaluated()) {
return reduce();
}
return asm_reduce_once(*this);
}
}
template <class T> constexpr void field<T>::self_reduce_once() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_reduce_once");
if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
*this = reduce();
} else {
if (std::is_constant_evaluated()) {
*this = reduce();
} else {
asm_self_reduce_once(*this);
}
}
}
template <class T> constexpr field<T> field<T>::pow(const uint256_t& exponent) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::pow");
field accumulator{ data[0], data[1], data[2], data[3] };
field to_mul{ data[0], data[1], data[2], data[3] };
const uint64_t maximum_set_bit = exponent.get_msb();
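// Standard square-and-multiply, consuming exponent bits from the most significant
// set bit downwards; starting the accumulator at the base accounts for that top bit.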
for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
accumulator.self_sqr();
if (exponent.get_bit(static_cast<uint64_t>(i))) {
accumulator *= to_mul;
}
}
if (exponent == uint256_t(0)) {
accumulator = one();
} else if (*this == zero()) {
accumulator = zero();
}
return accumulator;
}
template <class T> constexpr field<T> field<T>::pow(const uint64_t exponent) const noexcept
{
return pow({ exponent, 0, 0, 0 });
}
template <class T> constexpr field<T> field<T>::invert() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::invert");
if (*this == zero()) {
throw_or_abort("Trying to invert zero in the field");
}
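// Fermat's little theorem: a^{p - 2} = a^{-1} (mod p) for a != 0.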
return pow(modulus_minus_two);
}
template <class T> void field<T>::batch_invert(field* coeffs, const size_t n) noexcept
{
batch_invert(std::span{ coeffs, n });
}
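/**
 * @brief Invert a set of field elements with a single field inversion (Montgomery's trick).
 *
 * @details Builds the running prefix products of the inputs, inverts the grand total once, then
 * walks backwards: at step i, accumulator * temporaries[i] yields the inverse of coeffs[i].
 * Zero entries are skipped and left unchanged. Total cost is ~3n multiplications plus one
 * inversion, instead of n inversions. E.g. for inputs {a, b, c}: the prefixes are {1, a, ab};
 * from (abc)^{-1} we peel off c^{-1} = (abc)^{-1} * ab, then b^{-1}, then a^{-1}.
 * N.B. the reverse loop below uses an unsigned counter that wraps past zero to terminate.
 */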
template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::batch_invert");
const size_t n = coeffs.size();
auto temporaries_ptr = std::static_pointer_cast<field[]>(get_mem_slab(n * sizeof(field)));
auto skipped_ptr = std::static_pointer_cast<bool[]>(get_mem_slab(n));
auto temporaries = temporaries_ptr.get();
auto* skipped = skipped_ptr.get();
field accumulator = one();
for (size_t i = 0; i < n; ++i) {
temporaries[i] = accumulator;
if (coeffs[i].is_zero()) {
skipped[i] = true;
} else {
skipped[i] = false;
accumulator *= coeffs[i];
}
}
// std::vector<field> temporaries;
// std::vector<bool> skipped;
// temporaries.reserve(n);
// skipped.reserve(n);
// field accumulator = one();
// for (size_t i = 0; i < n; ++i) {
// temporaries.emplace_back(accumulator);
// if (coeffs[i].is_zero()) {
// skipped.emplace_back(true);
// } else {
// skipped.emplace_back(false);
// accumulator *= coeffs[i];
// }
// }
accumulator = accumulator.invert();
field T0;
for (size_t i = n - 1; i < n; --i) {
if (!skipped[i]) {
T0 = accumulator * temporaries[i];
accumulator *= coeffs[i];
coeffs[i] = T0;
}
}
}
template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt");
// The Tonelli-Shanks algorithm begins by finding a field element Q and integer S,
// such that (p - 1) = Q.2^{s}
// We can compute the square root of a, by considering a^{(Q + 1) / 2} = R
// Once we have found such an R, we have
// R^{2} = a^{Q + 1} = a^{Q}a
// If a^{Q} = 1, we have found our square root.
// Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity.
// This is because t^{2^{s-1}} = a^{Q.2^{s-1}}.
// We know that (p - 1) = Q.2^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2}
// From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1
// i.e. t^{2^{s-1}} = 1
// To proceed with computing our square root, we want to transform t into a smaller subgroup,
// specifically, the (s-2)'th roots of unity.
// We do this by finding some value b, such that
// (t.b^2)^{2^{s-2}} = 1 and R' = R.b
// Finding such a b is trivial, because from Euler's criterion, we know that,
// for any quadratic non-residue z, z^{(p - 1) / 2} = -1
// i.e. z^{Q.2^{s-1}} = -1
// => z^Q is a 2^{s-1}'th root of -1
// => z^{Q^2} is a 2^{s-2}'th root of -1
// Since t^{2^{s-1}} = 1, we know that t^{2^{s - 2}} = -1
// => t.z^{Q^2} is a 2^{s - 2}'th root of unity.
// We can iteratively transform t into ever smaller subgroups, until t = 1.
// At each iteration, we need to find a new value for b, which we can obtain
// by repeatedly squaring z^{Q}
constexpr uint256_t Q = (modulus - 1) >> static_cast<uint64_t>(primitive_root_log_size() - 1);
constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 2;
// __to_montgomery_form(Q_minus_one_over_two, Q_minus_one_over_two);
field z = coset_generator(0); // the generator is a non-residue
field b = pow(Q_minus_one_over_two);
field r = operator*(b); // r = a^{(Q + 1) / 2}
field t = r * b; // t = a^{(Q - 1) / 2 + (Q + 1) / 2} = a^{Q}
// check if t is a square with Euler's criterion
// if not, we don't have a quadratic residue and a has no square root!
field check = t;
for (size_t i = 0; i < primitive_root_log_size() - 1; ++i) {
check.self_sqr();
}
if (check != one()) {
return zero();
}
field t1 = z.pow(Q_minus_one_over_two);
field t2 = t1 * z;
field c = t2 * t1; // z^Q
size_t m = primitive_root_log_size();
while (t != one()) {
size_t i = 0;
field t2m = t;
// find the smallest i such that t^{2^i} = 1
while (t2m != one()) {
t2m.self_sqr();
i += 1;
}
size_t j = m - i - 1;
b = c;
while (j > 0) {
b.self_sqr();
--j;
} // b = c^{2^{m-i-1}}
c = b.sqr();
t = t * c;
r = r * b;
m = i;
}
return r;
}
template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::sqrt");
field root;
if constexpr ((T::modulus_0 & 0x3UL) == 0x3UL) {
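// p = 3 (mod 4), so (p + 1) / 4 is an integer; for a quadratic residue a,
// (a^{(p + 1) / 4})^2 = a^{(p + 1) / 2} = a * a^{(p - 1) / 2} = a.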
constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2;
root = pow(sqrt_exponent);
} else {
root = tonelli_shanks_sqrt();
}
if ((root * root) == (*this)) {
return std::pair<bool, field>(true, root);
}
return std::pair<bool, field>(false, field::zero());
}
template <class T> constexpr field<T> field<T>::operator/(const field& other) const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::div");
return operator*(other.invert());
}
template <class T> constexpr field<T>& field<T>::operator/=(const field& other) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_div");
*this = operator/(other);
return *this;
}
template <class T> constexpr void field<T>::self_set_msb() noexcept
{
data[3] = 0ULL | (1ULL << 63ULL);
}
template <class T> constexpr bool field<T>::is_msb_set() const noexcept
{
return (data[3] >> 63ULL) == 1ULL;
}
template <class T> constexpr uint64_t field<T>::is_msb_set_word() const noexcept
{
return (data[3] >> 63ULL);
}
template <class T> constexpr bool field<T>::is_zero() const noexcept
{
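// Elements may be coarsely reduced (stored in [0, 2p)), so zero has two
// representations: 0 and the modulus p itself.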
return ((data[0] | data[1] | data[2] | data[3]) == 0) ||
(data[0] == T::modulus_0 && data[1] == T::modulus_1 && data[2] == T::modulus_2 && data[3] == T::modulus_3);
}
template <class T> constexpr field<T> field<T>::get_root_of_unity(size_t subgroup_size) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
field r{ T::primitive_root_0, T::primitive_root_1, T::primitive_root_2, T::primitive_root_3 };
#else
field r{ T::primitive_root_wasm_0, T::primitive_root_wasm_1, T::primitive_root_wasm_2, T::primitive_root_wasm_3 };
#endif
for (size_t i = primitive_root_log_size(); i > subgroup_size; --i) {
r.self_sqr();
}
return r;
}
template <class T> field<T> field<T>::random_element(numeric::RNG* engine) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::random_element");
if (engine == nullptr) {
engine = &numeric::get_randomness();
}
uint512_t source = engine->get_random_uint512();
uint512_t q(modulus);
uint512_t reduced = source % q;
return field(reduced.lo);
}
template <class T> constexpr size_t field<T>::primitive_root_log_size() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::primitive_root_log_size");
uint256_t target = modulus - 1;
size_t result = 0;
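// count the trailing zero bits of (p - 1), i.e. the 2-adicity s in p - 1 = Q * 2^s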
while (!target.get_bit(result)) {
++result;
}
return result;
}
template <class T>
constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute_coset_generators() noexcept
{
constexpr size_t n = COSET_GENERATOR_SIZE;
constexpr uint64_t subgroup_size = 1 << 30;
std::array<field, COSET_GENERATOR_SIZE> result{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
if (n > 0) {
result[0] = (multiplicative_generator());
}
field work_variable = multiplicative_generator() + field(1);
size_t count = 1;
while (count < n) {
// work_variable contains a new field element, and we need to test that, for all previous vector elements,
// result[i] / work_variable is not a member of our subgroup
field work_inverse = work_variable.invert();
bool valid = true;
for (size_t j = 0; j < count; ++j) {
field subgroup_check = (work_inverse * result[j]).pow(subgroup_size);
if (subgroup_check == field(1)) {
valid = false;
break;
}
}
if (valid) {
result[count] = (work_variable);
++count;
}
work_variable += field(1);
}
return result;
}
template <class T> constexpr field<T> field<T>::multiplicative_generator() noexcept
{
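// Search for the smallest integer (starting from 2) that is a quadratic non-residue,
// detected via Euler's criterion: target^{(p - 1) / 2} == -1.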
field target(1);
uint256_t p_minus_one_over_two = (modulus - 1) >> 1;
bool found = false;
while (!found) {
target += field(1);
found = (target.pow(p_minus_one_over_two) == -field(1));
}
return target;
}
} // namespace bb
// clang-format off
// NOLINTEND(cppcoreguidelines-avoid-c-arrays)
// clang-format on

View File

@@ -0,0 +1,945 @@
#pragma once
#include <array>
#include <cstdint>
#include "./field_impl.hpp"
#include "../../common/op_count.hpp"
namespace bb {
using namespace numeric;
// NOLINTBEGIN(readability-implicit-bool-conversion)
template <class T> constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide(uint64_t a, uint64_t b) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = (static_cast<uint128_t>(a) * static_cast<uint128_t>(b));
return { static_cast<uint64_t>(res), static_cast<uint64_t>(res >> 64) };
#else
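// The wasm build represents field elements with 29-bit limbs (see WASM_LIMB_BITS), so on
// this path the operands are assumed to fit in 32 bits and the full product fits in
// 64 bits; results are split at 32 bits rather than 64.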
const uint64_t product = a * b;
return { product & 0xffffffffULL, product >> 32 };
#endif
}
template <class T>
constexpr uint64_t field<T>::mac(
const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t carry_in, uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c)) +
static_cast<uint128_t>(carry_in);
carry_out = static_cast<uint64_t>(res >> 64);
return static_cast<uint64_t>(res);
#else
const uint64_t product = b * c + a + carry_in;
carry_out = product >> 32;
return product & 0xffffffffULL;
#endif
}
template <class T>
constexpr void field<T>::mac(const uint64_t a,
const uint64_t b,
const uint64_t c,
const uint64_t carry_in,
uint64_t& out,
uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c)) +
static_cast<uint128_t>(carry_in);
out = static_cast<uint64_t>(res);
carry_out = static_cast<uint64_t>(res >> 64);
#else
const uint64_t product = b * c + a + carry_in;
carry_out = product >> 32;
out = product & 0xffffffffULL;
#endif
}
template <class T>
constexpr uint64_t field<T>::mac_mini(const uint64_t a,
const uint64_t b,
const uint64_t c,
uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
carry_out = static_cast<uint64_t>(res >> 64);
return static_cast<uint64_t>(res);
#else
const uint64_t product = b * c + a;
carry_out = product >> 32;
return product & 0xffffffffULL;
#endif
}
template <class T>
constexpr void field<T>::mac_mini(
const uint64_t a, const uint64_t b, const uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
out = static_cast<uint64_t>(res);
carry_out = static_cast<uint64_t>(res >> 64);
#else
const uint64_t result = b * c + a;
carry_out = result >> 32;
out = result & 0xffffffffULL;
#endif
}
template <class T>
constexpr uint64_t field<T>::mac_discard_lo(const uint64_t a, const uint64_t b, const uint64_t c) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
return static_cast<uint64_t>(res >> 64);
#else
return (b * c + a) >> 32;
#endif
}
template <class T>
constexpr uint64_t field<T>::addc(const uint64_t a,
const uint64_t b,
const uint64_t carry_in,
uint64_t& carry_out) noexcept
{
BB_OP_COUNT_TRACK();
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
uint128_t res = static_cast<uint128_t>(a) + static_cast<uint128_t>(b) + static_cast<uint128_t>(carry_in);
carry_out = static_cast<uint64_t>(res >> 64);
return static_cast<uint64_t>(res);
#else
uint64_t r = a + b;
const uint64_t carry_temp = r < a;
r += carry_in;
carry_out = carry_temp + (r < carry_in);
return r;
#endif
}
template <class T>
constexpr uint64_t field<T>::sbb(const uint64_t a,
const uint64_t b,
const uint64_t borrow_in,
uint64_t& borrow_out) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
uint128_t res = static_cast<uint128_t>(a) - (static_cast<uint128_t>(b) + static_cast<uint128_t>(borrow_in >> 63));
borrow_out = static_cast<uint64_t>(res >> 64);
return static_cast<uint64_t>(res);
#else
uint64_t t_1 = a - (borrow_in >> 63ULL);
uint64_t borrow_temp_1 = t_1 > a;
uint64_t t_2 = t_1 - b;
uint64_t borrow_temp_2 = t_2 > t_1;
borrow_out = 0ULL - (borrow_temp_1 | borrow_temp_2);
return t_2;
#endif
}
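// Editor's worked example (sketch): sbb computes a - b - borrow, where only the top bit of borrow_in is
// consulted (borrow_in >> 63), and an underflow propagates as an all-ones borrow word:
//
//   uint64_t borrow = 0;
//   uint64_t r = sbb(0, 1, borrow, borrow);
//   // r == 0xffffffffffffffff, borrow == 0xffffffffffffffff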
template <class T>
constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
const uint64_t b,
const uint64_t c,
const uint64_t carry_in_lo,
const uint64_t carry_in_hi,
uint64_t& carry_lo,
uint64_t& carry_hi) noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const uint128_t product = static_cast<uint128_t>(b) * static_cast<uint128_t>(c);
const auto r0 = static_cast<uint64_t>(product);
const auto r1 = static_cast<uint64_t>(product >> 64);
uint64_t out = r0 + r0;
carry_lo = (out < r0);
out += a;
carry_lo += (out < a);
out += carry_in_lo;
carry_lo += (out < carry_in_lo);
carry_lo += r1;
carry_hi = (carry_lo < r1);
carry_lo += r1;
carry_hi += (carry_lo < r1);
carry_lo += carry_in_hi;
carry_hi += (carry_lo < carry_in_hi);
return out;
#else
const auto product = b * c;
const auto t0 = product + a + carry_in_lo;
const auto t1 = product + t0;
carry_hi = t1 < product;
const auto t2 = t1 + (carry_in_hi << 32);
carry_hi += t2 < t1;
carry_lo = t2 >> 32;
return t2 & 0xffffffffULL;
#endif
}
template <class T> constexpr field<T> field<T>::reduce() const noexcept
{
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
uint256_t val{ data[0], data[1], data[2], data[3] };
if (val >= modulus) {
val -= modulus;
}
return { val.data[0], val.data[1], val.data[2], val.data[3] };
}
uint64_t t0 = data[0] + not_modulus.data[0];
uint64_t c = t0 < data[0];
auto t1 = addc(data[1], not_modulus.data[1], c, c);
auto t2 = addc(data[2], not_modulus.data[2], c, c);
auto t3 = addc(data[3], not_modulus.data[3], c, c);
const uint64_t selection_mask = 0ULL - c; // 0xffff... if we have overflowed.
const uint64_t selection_mask_inverse = ~selection_mask;
// if we overflow, we want to swap
return {
(data[0] & selection_mask_inverse) | (t0 & selection_mask),
(data[1] & selection_mask_inverse) | (t1 & selection_mask),
(data[2] & selection_mask_inverse) | (t2 & selection_mask),
(data[3] & selection_mask_inverse) | (t3 & selection_mask),
};
}
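// Editor's note (illustrative sketch): for moduli below 2**254 the branchless path above exploits
// not_modulus = 2**256 - p. Adding not_modulus to a value v overflows 2**256 exactly when v >= p, so the
// final carry doubles as a selector: selection_mask is all-ones (select the t limbs, i.e. v - p) on
// overflow and all-zeros (keep v unchanged) otherwise, avoiding a data-dependent branch.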
template <class T> constexpr field<T> field<T>::add(const field& other) const noexcept
{
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
uint64_t r0 = data[0] + other.data[0];
uint64_t c = r0 < data[0];
auto r1 = addc(data[1], other.data[1], c, c);
auto r2 = addc(data[2], other.data[2], c, c);
auto r3 = addc(data[3], other.data[3], c, c);
if (c) {
uint64_t b = 0;
r0 = sbb(r0, modulus.data[0], b, b);
r1 = sbb(r1, modulus.data[1], b, b);
r2 = sbb(r2, modulus.data[2], b, b);
r3 = sbb(r3, modulus.data[3], b, b);
// Since both operands are in [0, 2**256), the sum is in [0, 2**257 - 2], so subtracting p once might
// not be enough. We need to make sure we have underflowed past 0, which may require subtracting an
// additional p
if (!b) {
b = 0;
r0 = sbb(r0, modulus.data[0], b, b);
r1 = sbb(r1, modulus.data[1], b, b);
r2 = sbb(r2, modulus.data[2], b, b);
r3 = sbb(r3, modulus.data[3], b, b);
}
}
return { r0, r1, r2, r3 };
} else {
uint64_t r0 = data[0] + other.data[0];
uint64_t c = r0 < data[0];
auto r1 = addc(data[1], other.data[1], c, c);
auto r2 = addc(data[2], other.data[2], c, c);
uint64_t r3 = data[3] + other.data[3] + c;
uint64_t t0 = r0 + twice_not_modulus.data[0];
c = t0 < twice_not_modulus.data[0];
uint64_t t1 = addc(r1, twice_not_modulus.data[1], c, c);
uint64_t t2 = addc(r2, twice_not_modulus.data[2], c, c);
uint64_t t3 = addc(r3, twice_not_modulus.data[3], c, c);
const uint64_t selection_mask = 0ULL - c;
const uint64_t selection_mask_inverse = ~selection_mask;
return {
(r0 & selection_mask_inverse) | (t0 & selection_mask),
(r1 & selection_mask_inverse) | (t1 & selection_mask),
(r2 & selection_mask_inverse) | (t2 & selection_mask),
(r3 & selection_mask_inverse) | (t3 & selection_mask),
};
}
}
template <class T> constexpr field<T> field<T>::subtract(const field& other) const noexcept
{
uint64_t borrow = 0;
uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow);
uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow);
uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow);
r0 += (modulus.data[0] & borrow);
uint64_t carry = r0 < (modulus.data[0] & borrow);
r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
r3 = addc(r3, (modulus.data[3] & borrow), carry, carry);
// The value being subtracted is in [0, 2**256), so e.g. computing 0 - 2**255 and then adding p once can
// leave the value negative. When adding p we need to check whether we've overflowed past 2**256; if not,
// we should add p again
if (!carry) {
r0 += (modulus.data[0] & borrow);
uint64_t carry = r0 < (modulus.data[0] & borrow);
r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
r3 = addc(r3, (modulus.data[3] & borrow), carry, carry);
}
return { r0, r1, r2, r3 };
}
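// Editor's worked example (sketch): the borrow mask makes the modular correction branch-free in the limb
// additions -- if the raw subtraction borrows, `borrow` is all-ones and p (masked by `borrow`) is added
// back, with a second conditional addition for the rare case where one p is not enough.
// E.g. field<Params>(0) - field<Params>(1) == -field<Params>(1).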
/**
 * @brief Subtract `other`, leaving the result only coarsely reduced (in [0, 2p) rather than [0, p));
 * falls back to fully-reduced subtraction when the modulus is too large for coarse reduction
 *
 * @tparam T field parameters
 * @param other element to subtract
 * @return constexpr field<T>
 */
template <class T> constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
{
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return subtract(other);
}
uint64_t borrow = 0;
uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow);
uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow);
uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow);
r0 += (twice_modulus.data[0] & borrow);
uint64_t carry = r0 < (twice_modulus.data[0] & borrow);
r1 = addc(r1, twice_modulus.data[1] & borrow, carry, carry);
r2 = addc(r2, twice_modulus.data[2] & borrow, carry, carry);
r3 += (twice_modulus.data[3] & borrow) + carry;
return { r0, r1, r2, r3 };
}
/**
* @brief Montgomery multiplication for moduli > 2²⁵⁴
*
* @details Explanation of Montgomery form can be found in \ref field_docs_montgomery_explainer and the difference
* between WASM and generic versions is explained in \ref field_docs_architecture_details
*/
template <class T> constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
uint64_t c = 0;
uint64_t t0 = 0;
uint64_t t1 = 0;
uint64_t t2 = 0;
uint64_t t3 = 0;
uint64_t t4 = 0;
uint64_t t5 = 0;
uint64_t k = 0;
for (const auto& element : data) {
c = 0;
mac(t0, element, other.data[0], c, t0, c);
mac(t1, element, other.data[1], c, t1, c);
mac(t2, element, other.data[2], c, t2, c);
mac(t3, element, other.data[3], c, t3, c);
t4 = addc(t4, c, 0, t5);
c = 0;
k = t0 * T::r_inv;
c = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, k, modulus.data[1], c, t0, c);
mac(t2, k, modulus.data[2], c, t1, c);
mac(t3, k, modulus.data[3], c, t2, c);
t3 = addc(c, t4, 0, c);
t4 = t5 + c;
}
uint64_t borrow = 0;
uint64_t r0 = sbb(t0, modulus.data[0], borrow, borrow);
uint64_t r1 = sbb(t1, modulus.data[1], borrow, borrow);
uint64_t r2 = sbb(t2, modulus.data[2], borrow, borrow);
uint64_t r3 = sbb(t3, modulus.data[3], borrow, borrow);
borrow = borrow ^ (0ULL - t4);
r0 += (modulus.data[0] & borrow);
uint64_t carry = r0 < (modulus.data[0] & borrow);
r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
r3 += (modulus.data[3] & borrow) + carry;
return { r0, r1, r2, r3 };
#else
// Convert 4 64-bit limbs to 9 29-bit limbs
auto left = wasm_convert(data);
auto right = wasm_convert(other.data);
constexpr uint64_t mask = 0x1fffffff;
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
uint64_t temp_9 = 0;
uint64_t temp_10 = 0;
uint64_t temp_11 = 0;
uint64_t temp_12 = 0;
uint64_t temp_13 = 0;
uint64_t temp_14 = 0;
uint64_t temp_15 = 0;
uint64_t temp_16 = 0;
uint64_t temp_17 = 0;
// Multiply-add the 0th limb of the left argument by all 9 limbs of the right argument
wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
// Instantly reduce
wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
// Continue for other limbs
wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
// After all multiplications and additions, convert relaxed form to strict (all limbs are 29 bits)
temp_10 += temp_9 >> WASM_LIMB_BITS;
temp_9 &= mask;
temp_11 += temp_10 >> WASM_LIMB_BITS;
temp_10 &= mask;
temp_12 += temp_11 >> WASM_LIMB_BITS;
temp_11 &= mask;
temp_13 += temp_12 >> WASM_LIMB_BITS;
temp_12 &= mask;
temp_14 += temp_13 >> WASM_LIMB_BITS;
temp_13 &= mask;
temp_15 += temp_14 >> WASM_LIMB_BITS;
temp_14 &= mask;
temp_16 += temp_15 >> WASM_LIMB_BITS;
temp_15 &= mask;
temp_17 += temp_16 >> WASM_LIMB_BITS;
temp_16 &= mask;
uint64_t r_temp_0;
uint64_t r_temp_1;
uint64_t r_temp_2;
uint64_t r_temp_3;
uint64_t r_temp_4;
uint64_t r_temp_5;
uint64_t r_temp_6;
uint64_t r_temp_7;
uint64_t r_temp_8;
// Subtract modulus from result
r_temp_0 = temp_9 - wasm_modulus[0];
r_temp_1 = temp_10 - wasm_modulus[1] - ((r_temp_0) >> 63);
r_temp_2 = temp_11 - wasm_modulus[2] - ((r_temp_1) >> 63);
r_temp_3 = temp_12 - wasm_modulus[3] - ((r_temp_2) >> 63);
r_temp_4 = temp_13 - wasm_modulus[4] - ((r_temp_3) >> 63);
r_temp_5 = temp_14 - wasm_modulus[5] - ((r_temp_4) >> 63);
r_temp_6 = temp_15 - wasm_modulus[6] - ((r_temp_5) >> 63);
r_temp_7 = temp_16 - wasm_modulus[7] - ((r_temp_6) >> 63);
r_temp_8 = temp_17 - wasm_modulus[8] - ((r_temp_7) >> 63);
// Depending on whether the subtraction underflowed, choose original value or the result of subtraction
uint64_t new_mask = 0 - (r_temp_8 >> 63);
uint64_t inverse_mask = (~new_mask) & mask;
temp_9 = (temp_9 & new_mask) | (r_temp_0 & inverse_mask);
temp_10 = (temp_10 & new_mask) | (r_temp_1 & inverse_mask);
temp_11 = (temp_11 & new_mask) | (r_temp_2 & inverse_mask);
temp_12 = (temp_12 & new_mask) | (r_temp_3 & inverse_mask);
temp_13 = (temp_13 & new_mask) | (r_temp_4 & inverse_mask);
temp_14 = (temp_14 & new_mask) | (r_temp_5 & inverse_mask);
temp_15 = (temp_15 & new_mask) | (r_temp_6 & inverse_mask);
temp_16 = (temp_16 & new_mask) | (r_temp_7 & inverse_mask);
temp_17 = (temp_17 & new_mask) | (r_temp_8 & inverse_mask);
// Convert back to 4 64-bit limbs
return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
(temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
(temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
(temp_15 >> 18) | (temp_16 << 11) | (temp_17 << 40) };
#endif
}
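// Editor's note (illustrative sketch): montgomery_mul_big computes a * b * R^(-1) mod p, where R is the
// Montgomery radix (2**256 on the 64-bit limb path). Operands are therefore expected in Montgomery form;
// a raw integer is mapped in by multiplying with R^2 mod p first, e.g. (assuming this class's
// to_montgomery_form conversion):
//
//   field<Params> a_mont = a_raw.to_montgomery_form();
//   field<Params> b_mont = b_raw.to_montgomery_form();
//   field<Params> c_mont = a_mont.montgomery_mul_big(b_mont); // Montgomery form of (a * b) mod p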
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
/**
* @brief Multiply left limb by a sequence of 9 limbs and put into result variables
*
*/
template <class T>
constexpr void field<T>::wasm_madd(uint64_t& left_limb,
const std::array<uint64_t, WASM_NUM_LIMBS>& right_limbs,
uint64_t& result_0,
uint64_t& result_1,
uint64_t& result_2,
uint64_t& result_3,
uint64_t& result_4,
uint64_t& result_5,
uint64_t& result_6,
uint64_t& result_7,
uint64_t& result_8)
{
result_0 += left_limb * right_limbs[0];
result_1 += left_limb * right_limbs[1];
result_2 += left_limb * right_limbs[2];
result_3 += left_limb * right_limbs[3];
result_4 += left_limb * right_limbs[4];
result_5 += left_limb * right_limbs[5];
result_6 += left_limb * right_limbs[6];
result_7 += left_limb * right_limbs[7];
result_8 += left_limb * right_limbs[8];
}
/**
* @brief Perform a 29-bit Montgomery reduction step on 1 limb (result_0 should be zero modulo 2**29 after this)
*
*/
template <class T>
constexpr void field<T>::wasm_reduce(uint64_t& result_0,
uint64_t& result_1,
uint64_t& result_2,
uint64_t& result_3,
uint64_t& result_4,
uint64_t& result_5,
uint64_t& result_6,
uint64_t& result_7,
uint64_t& result_8)
{
constexpr uint64_t mask = 0x1fffffff;
constexpr uint64_t r_inv = T::r_inv & mask;
uint64_t k = (result_0 * r_inv) & mask;
result_0 += k * wasm_modulus[0];
result_1 += k * wasm_modulus[1] + (result_0 >> WASM_LIMB_BITS);
result_2 += k * wasm_modulus[2];
result_3 += k * wasm_modulus[3];
result_4 += k * wasm_modulus[4];
result_5 += k * wasm_modulus[5];
result_6 += k * wasm_modulus[6];
result_7 += k * wasm_modulus[7];
result_8 += k * wasm_modulus[8];
}
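// Editor's worked sketch: each wasm_reduce call is one radix-2**29 Montgomery reduction step. With
// r_inv == -p^(-1) mod 2**29, choosing k = result_0 * r_inv (mod 2**29) makes result_0 + k * p divisible
// by 2**29, so the low limb can be shifted into result_1 without losing information:
//
//   k = (result_0 * r_inv) & 0x1fffffff;             // result_0 + k * wasm_modulus[0] == 0 (mod 2**29)
//   result_1 += k * wasm_modulus[1] + (result_0 >> 29);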
/**
* @brief Convert 4 64-bit limbs into 9 29-bit limbs
*
*/
template <class T> constexpr std::array<uint64_t, WASM_NUM_LIMBS> field<T>::wasm_convert(const uint64_t* data)
{
return { data[0] & 0x1fffffff,
(data[0] >> WASM_LIMB_BITS) & 0x1fffffff,
((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6),
(data[1] >> 23) & 0x1fffffff,
((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12),
(data[2] >> 17) & 0x1fffffff,
((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18),
(data[3] >> 11) & 0x1fffffff,
(data[3] >> 40) & 0x1fffffff };
}
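// Editor's worked example (sketch): 9 limbs x 29 bits = 261 bits, enough to hold a 256-bit value with
// slack for lazy carries. The split walks the input 29 bits at a time, stitching limbs that straddle
// 64-bit word boundaries, e.g. limb 2 combines the top 6 bits of data[0] with the low 23 bits of data[1]:
//
//   ((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6)   // 6 + 23 = 29 bits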
#endif
template <class T> constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
{
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return montgomery_mul_big(other);
}
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
auto [t0, c] = mul_wide(data[0], other.data[0]);
uint64_t k = t0 * T::r_inv;
uint64_t a = mac_discard_lo(t0, k, modulus.data[0]);
uint64_t t1 = mac_mini(a, data[0], other.data[1], a);
mac(t1, k, modulus.data[1], c, t0, c);
uint64_t t2 = mac_mini(a, data[0], other.data[2], a);
mac(t2, k, modulus.data[2], c, t1, c);
uint64_t t3 = mac_mini(a, data[0], other.data[3], a);
mac(t3, k, modulus.data[3], c, t2, c);
t3 = c + a;
mac_mini(t0, data[1], other.data[0], t0, a);
k = t0 * T::r_inv;
c = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, data[1], other.data[1], a, t1, a);
mac(t1, k, modulus.data[1], c, t0, c);
mac(t2, data[1], other.data[2], a, t2, a);
mac(t2, k, modulus.data[2], c, t1, c);
mac(t3, data[1], other.data[3], a, t3, a);
mac(t3, k, modulus.data[3], c, t2, c);
t3 = c + a;
mac_mini(t0, data[2], other.data[0], t0, a);
k = t0 * T::r_inv;
c = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, data[2], other.data[1], a, t1, a);
mac(t1, k, modulus.data[1], c, t0, c);
mac(t2, data[2], other.data[2], a, t2, a);
mac(t2, k, modulus.data[2], c, t1, c);
mac(t3, data[2], other.data[3], a, t3, a);
mac(t3, k, modulus.data[3], c, t2, c);
t3 = c + a;
mac_mini(t0, data[3], other.data[0], t0, a);
k = t0 * T::r_inv;
c = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, data[3], other.data[1], a, t1, a);
mac(t1, k, modulus.data[1], c, t0, c);
mac(t2, data[3], other.data[2], a, t2, a);
mac(t2, k, modulus.data[2], c, t1, c);
mac(t3, data[3], other.data[3], a, t3, a);
mac(t3, k, modulus.data[3], c, t2, c);
t3 = c + a;
return { t0, t1, t2, t3 };
#else
// Convert 4 64-bit limbs to 9 29-bit ones
auto left = wasm_convert(data);
auto right = wasm_convert(other.data);
constexpr uint64_t mask = 0x1fffffff;
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
uint64_t temp_9 = 0;
uint64_t temp_10 = 0;
uint64_t temp_11 = 0;
uint64_t temp_12 = 0;
uint64_t temp_13 = 0;
uint64_t temp_14 = 0;
uint64_t temp_15 = 0;
uint64_t temp_16 = 0;
// Perform a series of multiplications and reductions (we multiply 1 limb of left argument by the whole right
// argument and then reduce)
wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
// Convert result to unrelaxed form (all limbs are 29 bits)
temp_10 += temp_9 >> WASM_LIMB_BITS;
temp_9 &= mask;
temp_11 += temp_10 >> WASM_LIMB_BITS;
temp_10 &= mask;
temp_12 += temp_11 >> WASM_LIMB_BITS;
temp_11 &= mask;
temp_13 += temp_12 >> WASM_LIMB_BITS;
temp_12 &= mask;
temp_14 += temp_13 >> WASM_LIMB_BITS;
temp_13 &= mask;
temp_15 += temp_14 >> WASM_LIMB_BITS;
temp_14 &= mask;
temp_16 += temp_15 >> WASM_LIMB_BITS;
temp_15 &= mask;
// Convert back to 4 64-bit limbs form
return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
(temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
(temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
(temp_15 >> 18) | (temp_16 << 11) };
#endif
}
template <class T> constexpr field<T> field<T>::montgomery_square() const noexcept
{
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return montgomery_mul_big(*this);
}
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
uint64_t carry_hi = 0;
auto [t0, carry_lo] = mul_wide(data[0], data[0]);
uint64_t t1 = square_accumulate(0, data[1], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
uint64_t t2 = square_accumulate(0, data[2], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
uint64_t t3 = square_accumulate(0, data[3], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
uint64_t round_carry = carry_lo;
uint64_t k = t0 * T::r_inv;
carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
t3 = carry_lo + round_carry;
t1 = mac_mini(t1, data[1], data[1], carry_lo);
carry_hi = 0;
t2 = square_accumulate(t2, data[2], data[1], carry_lo, carry_hi, carry_lo, carry_hi);
t3 = square_accumulate(t3, data[3], data[1], carry_lo, carry_hi, carry_lo, carry_hi);
round_carry = carry_lo;
k = t0 * T::r_inv;
carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
t3 = carry_lo + round_carry;
t2 = mac_mini(t2, data[2], data[2], carry_lo);
carry_hi = 0;
t3 = square_accumulate(t3, data[3], data[2], carry_lo, carry_hi, carry_lo, carry_hi);
round_carry = carry_lo;
k = t0 * T::r_inv;
carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
t3 = carry_lo + round_carry;
t3 = mac_mini(t3, data[3], data[3], carry_lo);
k = t0 * T::r_inv;
round_carry = carry_lo;
carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
t3 = carry_lo + round_carry;
return { t0, t1, t2, t3 };
#else
// Convert from 4 64-bit limbs to 9 29-bit ones
auto left = wasm_convert(data);
constexpr uint64_t mask = 0x1fffffff;
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
uint64_t temp_9 = 0;
uint64_t temp_10 = 0;
uint64_t temp_11 = 0;
uint64_t temp_12 = 0;
uint64_t temp_13 = 0;
uint64_t temp_14 = 0;
uint64_t temp_15 = 0;
uint64_t temp_16 = 0;
uint64_t acc;
// Perform multiplications, but accumulated results for limb k=i+j so that we can double them at the same time
temp_0 += left[0] * left[0];
acc = 0;
acc += left[0] * left[1];
temp_1 += (acc << 1);
acc = 0;
acc += left[0] * left[2];
temp_2 += left[1] * left[1];
temp_2 += (acc << 1);
acc = 0;
acc += left[0] * left[3];
acc += left[1] * left[2];
temp_3 += (acc << 1);
acc = 0;
acc += left[0] * left[4];
acc += left[1] * left[3];
temp_4 += left[2] * left[2];
temp_4 += (acc << 1);
acc = 0;
acc += left[0] * left[5];
acc += left[1] * left[4];
acc += left[2] * left[3];
temp_5 += (acc << 1);
acc = 0;
acc += left[0] * left[6];
acc += left[1] * left[5];
acc += left[2] * left[4];
temp_6 += left[3] * left[3];
temp_6 += (acc << 1);
acc = 0;
acc += left[0] * left[7];
acc += left[1] * left[6];
acc += left[2] * left[5];
acc += left[3] * left[4];
temp_7 += (acc << 1);
acc = 0;
acc += left[0] * left[8];
acc += left[1] * left[7];
acc += left[2] * left[6];
acc += left[3] * left[5];
temp_8 += left[4] * left[4];
temp_8 += (acc << 1);
acc = 0;
acc += left[1] * left[8];
acc += left[2] * left[7];
acc += left[3] * left[6];
acc += left[4] * left[5];
temp_9 += (acc << 1);
acc = 0;
acc += left[2] * left[8];
acc += left[3] * left[7];
acc += left[4] * left[6];
temp_10 += left[5] * left[5];
temp_10 += (acc << 1);
acc = 0;
acc += left[3] * left[8];
acc += left[4] * left[7];
acc += left[5] * left[6];
temp_11 += (acc << 1);
acc = 0;
acc += left[4] * left[8];
acc += left[5] * left[7];
temp_12 += left[6] * left[6];
temp_12 += (acc << 1);
acc = 0;
acc += left[5] * left[8];
acc += left[6] * left[7];
temp_13 += (acc << 1);
acc = 0;
acc += left[6] * left[8];
temp_14 += left[7] * left[7];
temp_14 += (acc << 1);
acc = 0;
acc += left[7] * left[8];
temp_15 += (acc << 1);
temp_16 += left[8] * left[8];
// Perform reductions
wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
// Convert to unrelaxed 29-bit form
temp_10 += temp_9 >> WASM_LIMB_BITS;
temp_9 &= mask;
temp_11 += temp_10 >> WASM_LIMB_BITS;
temp_10 &= mask;
temp_12 += temp_11 >> WASM_LIMB_BITS;
temp_11 &= mask;
temp_13 += temp_12 >> WASM_LIMB_BITS;
temp_12 &= mask;
temp_14 += temp_13 >> WASM_LIMB_BITS;
temp_13 &= mask;
temp_15 += temp_14 >> WASM_LIMB_BITS;
temp_14 &= mask;
temp_16 += temp_15 >> WASM_LIMB_BITS;
temp_15 &= mask;
// Convert to 4 64-bit form
return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
(temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
(temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
(temp_15 >> 18) | (temp_16 << 11) };
#endif
}
template <class T> constexpr struct field<T>::wide_array field<T>::mul_512(const field& other) const noexcept {
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
uint64_t carry_2 = 0;
auto [r0, carry] = mul_wide(data[0], other.data[0]);
uint64_t r1 = mac_mini(carry, data[0], other.data[1], carry);
uint64_t r2 = mac_mini(carry, data[0], other.data[2], carry);
uint64_t r3 = mac_mini(carry, data[0], other.data[3], carry_2);
r1 = mac_mini(r1, data[1], other.data[0], carry);
r2 = mac(r2, data[1], other.data[1], carry, carry);
r3 = mac(r3, data[1], other.data[2], carry, carry);
uint64_t r4 = mac(carry_2, data[1], other.data[3], carry, carry_2);
r2 = mac_mini(r2, data[2], other.data[0], carry);
r3 = mac(r3, data[2], other.data[1], carry, carry);
r4 = mac(r4, data[2], other.data[2], carry, carry);
uint64_t r5 = mac(carry_2, data[2], other.data[3], carry, carry_2);
r3 = mac_mini(r3, data[3], other.data[0], carry);
r4 = mac(r4, data[3], other.data[1], carry, carry);
r5 = mac(r5, data[3], other.data[2], carry, carry);
uint64_t r6 = mac(carry_2, data[3], other.data[3], carry, carry_2);
return { r0, r1, r2, r3, r4, r5, r6, carry_2 };
#else
// Convert from 4 64-bit limbs to 9 29-bit limbs
auto left = wasm_convert(data);
auto right = wasm_convert(other.data);
constexpr uint64_t mask = 0x1fffffff;
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
uint64_t temp_9 = 0;
uint64_t temp_10 = 0;
uint64_t temp_11 = 0;
uint64_t temp_12 = 0;
uint64_t temp_13 = 0;
uint64_t temp_14 = 0;
uint64_t temp_15 = 0;
uint64_t temp_16 = 0;
// Multiply-add all limbs
wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
// Convert to unrelaxed 29-bit form
temp_1 += temp_0 >> WASM_LIMB_BITS;
temp_0 &= mask;
temp_2 += temp_1 >> WASM_LIMB_BITS;
temp_1 &= mask;
temp_3 += temp_2 >> WASM_LIMB_BITS;
temp_2 &= mask;
temp_4 += temp_3 >> WASM_LIMB_BITS;
temp_3 &= mask;
temp_5 += temp_4 >> WASM_LIMB_BITS;
temp_4 &= mask;
temp_6 += temp_5 >> WASM_LIMB_BITS;
temp_5 &= mask;
temp_7 += temp_6 >> WASM_LIMB_BITS;
temp_6 &= mask;
temp_8 += temp_7 >> WASM_LIMB_BITS;
temp_7 &= mask;
temp_9 += temp_8 >> WASM_LIMB_BITS;
temp_8 &= mask;
temp_10 += temp_9 >> WASM_LIMB_BITS;
temp_9 &= mask;
temp_11 += temp_10 >> WASM_LIMB_BITS;
temp_10 &= mask;
temp_12 += temp_11 >> WASM_LIMB_BITS;
temp_11 &= mask;
temp_13 += temp_12 >> WASM_LIMB_BITS;
temp_12 &= mask;
temp_14 += temp_13 >> WASM_LIMB_BITS;
temp_13 &= mask;
temp_15 += temp_14 >> WASM_LIMB_BITS;
temp_14 &= mask;
temp_16 += temp_15 >> WASM_LIMB_BITS;
temp_15 &= mask;
// Convert to 8 64-bit limbs
return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
(temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
(temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
(temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40),
(temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63),
(temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57),
(temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51),
(temp_15 >> 13) | (temp_16 << 16) };
#endif
}
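// Editor's note (illustrative sketch): mul_512 returns the raw 512-bit product with no modular reduction,
// which is useful when the caller reduces against a different modulus or inspects the high half directly.
// Assuming wide_array exposes its limbs as data[0..7], least-significant first:
//
//   auto wide = a.mul_512(b);
//   // wide.data[0] = low 64 bits of a * b, wide.data[7] = high 64 bits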
// NOLINTEND(readability-implicit-bool-conversion)
} // namespace bb

View File

@@ -0,0 +1,389 @@
#pragma once
#if (BBERG_NO_ASM == 0)
#include "./field_impl.hpp"
#include "asm_macros.hpp"
namespace bb {
template <class T> field<T> field<T>::asm_mul_with_coarse_reduction(const field& a, const field& b) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_mul_with_coarse_reduction");
field r;
constexpr uint64_t r_inv = T::r_inv;
constexpr uint64_t modulus_0 = modulus.data[0];
constexpr uint64_t modulus_1 = modulus.data[1];
constexpr uint64_t modulus_2 = modulus.data[2];
constexpr uint64_t modulus_3 = modulus.data[3];
constexpr uint64_t zero_ref = 0;
/**
* Registers: rax:rdx = multiplication accumulator
* %r12, %r13, %r14, %r15, %rax: work registers for `r`
* %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
* %r10: zero register
* %0: pointer to `a`
* %1: pointer to `b`
* %2: pointer to `r`
**/
__asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "%r"(&a),
"%r"(&b),
"r"(&r),
[modulus_0] "m"(modulus_0),
[modulus_1] "m"(modulus_1),
[modulus_2] "m"(modulus_2),
[modulus_3] "m"(modulus_3),
[r_inv] "m"(r_inv),
[zero_reference] "m"(zero_ref)
: "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
return r;
}
template <class T> void field<T>::asm_self_mul_with_coarse_reduction(const field& a, const field& b) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_self_mul_with_coarse_reduction");
constexpr uint64_t r_inv = T::r_inv;
constexpr uint64_t modulus_0 = modulus.data[0];
constexpr uint64_t modulus_1 = modulus.data[1];
constexpr uint64_t modulus_2 = modulus.data[2];
constexpr uint64_t modulus_3 = modulus.data[3];
constexpr uint64_t zero_ref = 0;
/**
* Registers: rax:rdx = multiplication accumulator
* %r12, %r13, %r14, %r15, %rax: work registers for `r`
* %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
* %r10: zero register
* %0: pointer to `a`
* %1: pointer to `b`
* %2: pointer to `r`
**/
__asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&b),
[modulus_0] "m"(modulus_0),
[modulus_1] "m"(modulus_1),
[modulus_2] "m"(modulus_2),
[modulus_3] "m"(modulus_3),
[r_inv] "m"(r_inv),
[zero_reference] "m"(zero_ref)
: "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}
template <class T> field<T> field<T>::asm_sqr_with_coarse_reduction(const field& a) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_sqr_with_coarse_reduction");
field r;
constexpr uint64_t r_inv = T::r_inv;
constexpr uint64_t modulus_0 = modulus.data[0];
constexpr uint64_t modulus_1 = modulus.data[1];
constexpr uint64_t modulus_2 = modulus.data[2];
constexpr uint64_t modulus_3 = modulus.data[3];
constexpr uint64_t zero_ref = 0;
// Our SQR implementation with BMI2 but without ADX has a bug.
// The case is extremely rare so fixing it is a bit of a waste of time.
// We'll use MUL instead.
#if !defined(__ADX__) || defined(DISABLE_ADX)
/**
* Registers: rax:rdx = multiplication accumulator
* %r12, %r13, %r14, %r15, %rax: work registers for `r`
* %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
* %r10: zero register
* %0: pointer to `a`
* %1: pointer to `b`
* %2: pointer to `r`
**/
__asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "%r"(&a),
"%r"(&a),
"r"(&r),
[modulus_0] "m"(modulus_0),
[modulus_1] "m"(modulus_1),
[modulus_2] "m"(modulus_2),
[modulus_3] "m"(modulus_3),
[r_inv] "m"(r_inv),
[zero_reference] "m"(zero_ref)
: "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
#else
/**
* Registers: rax:rdx = multiplication accumulator
* %r12, %r13, %r14, %r15, %rax: work registers for `r`
* %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
* %[zero_reference]: memory location of zero value
* %0: pointer to `a`
* %[r_ptr]: memory location of pointer to `r`
**/
__asm__(SQR("%0")
// "movq %[r_ptr], %%rsi \n\t"
STORE_FIELD_ELEMENT("%1", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&r),
[zero_reference] "m"(zero_ref),
[modulus_0] "m"(modulus_0),
[modulus_1] "m"(modulus_1),
[modulus_2] "m"(modulus_2),
[modulus_3] "m"(modulus_3),
[r_inv] "m"(r_inv)
: "%rcx", "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
#endif
return r;
}
template <class T> void field<T>::asm_self_sqr_with_coarse_reduction(const field& a) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_self_sqr_with_coarse_reduction");
constexpr uint64_t r_inv = T::r_inv;
constexpr uint64_t modulus_0 = modulus.data[0];
constexpr uint64_t modulus_1 = modulus.data[1];
constexpr uint64_t modulus_2 = modulus.data[2];
constexpr uint64_t modulus_3 = modulus.data[3];
constexpr uint64_t zero_ref = 0;
// Our SQR implementation with BMI2 but without ADX has a bug.
// The case is extremely rare so fixing it is a bit of a waste of time.
// We'll use MUL instead.
#if !defined(__ADX__) || defined(DISABLE_ADX)
/**
* Registers: rax:rdx = multiplication accumulator
* %r12, %r13, %r14, %r15, %rax: work registers for `r`
* %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
* %r10: zero register
* %0: pointer to `a`
* %1: pointer to `b`
* %2: pointer to `r`
**/
__asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1")
STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&a),
[modulus_0] "m"(modulus_0),
[modulus_1] "m"(modulus_1),
[modulus_2] "m"(modulus_2),
[modulus_3] "m"(modulus_3),
[r_inv] "m"(r_inv),
[zero_reference] "m"(zero_ref)
: "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
#else
/**
* Registers: rax:rdx = multiplication accumulator
* %r12, %r13, %r14, %r15, %rax: work registers for `r`
* %r8, %r9, %rdi, %rsi: scratch registers for multiplication results
* %[zero_reference]: memory location of zero value
* %0: pointer to `a`
* %[r_ptr]: memory location of pointer to `r`
**/
__asm__(SQR("%0")
// "movq %[r_ptr], %%rsi \n\t"
STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
[zero_reference] "m"(zero_ref),
[modulus_0] "m"(modulus_0),
[modulus_1] "m"(modulus_1),
[modulus_2] "m"(modulus_2),
[modulus_3] "m"(modulus_3),
[r_inv] "m"(r_inv)
: "%rcx", "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
#endif
}
template <class T> field<T> field<T>::asm_add_with_coarse_reduction(const field& a, const field& b) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_add_with_coarse_reduction");
field r;
constexpr uint64_t twice_not_modulus_0 = twice_not_modulus.data[0];
constexpr uint64_t twice_not_modulus_1 = twice_not_modulus.data[1];
constexpr uint64_t twice_not_modulus_2 = twice_not_modulus.data[2];
constexpr uint64_t twice_not_modulus_3 = twice_not_modulus.data[3];
__asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
ADD_REDUCE("%1",
"%[twice_not_modulus_0]",
"%[twice_not_modulus_1]",
"%[twice_not_modulus_2]",
"%[twice_not_modulus_3]") STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "%r"(&a),
"%r"(&b),
"r"(&r),
[twice_not_modulus_0] "m"(twice_not_modulus_0),
[twice_not_modulus_1] "m"(twice_not_modulus_1),
[twice_not_modulus_2] "m"(twice_not_modulus_2),
[twice_not_modulus_3] "m"(twice_not_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
return r;
}
template <class T> void field<T>::asm_self_add_with_coarse_reduction(const field& a, const field& b) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_self_add_with_coarse_reduction");
constexpr uint64_t twice_not_modulus_0 = twice_not_modulus.data[0];
constexpr uint64_t twice_not_modulus_1 = twice_not_modulus.data[1];
constexpr uint64_t twice_not_modulus_2 = twice_not_modulus.data[2];
constexpr uint64_t twice_not_modulus_3 = twice_not_modulus.data[3];
__asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
ADD_REDUCE("%1",
"%[twice_not_modulus_0]",
"%[twice_not_modulus_1]",
"%[twice_not_modulus_2]",
"%[twice_not_modulus_3]") STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&b),
[twice_not_modulus_0] "m"(twice_not_modulus_0),
[twice_not_modulus_1] "m"(twice_not_modulus_1),
[twice_not_modulus_2] "m"(twice_not_modulus_2),
[twice_not_modulus_3] "m"(twice_not_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}
template <class T> field<T> field<T>::asm_sub_with_coarse_reduction(const field& a, const field& b) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_sub_with_coarse_reduction");
field r;
constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];
__asm__(
CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") SUB("%1")
REDUCE_FIELD_ELEMENT("%[twice_modulus_0]", "%[twice_modulus_1]", "%[twice_modulus_2]", "%[twice_modulus_3]")
STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&b),
"r"(&r),
[twice_modulus_0] "m"(twice_modulus_0),
[twice_modulus_1] "m"(twice_modulus_1),
[twice_modulus_2] "m"(twice_modulus_2),
[twice_modulus_3] "m"(twice_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
return r;
}
template <class T> void field<T>::asm_self_sub_with_coarse_reduction(const field& a, const field& b) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_self_sub_with_coarse_reduction");
constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];
__asm__(
CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") SUB("%1")
REDUCE_FIELD_ELEMENT("%[twice_modulus_0]", "%[twice_modulus_1]", "%[twice_modulus_2]", "%[twice_modulus_3]")
STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&b),
[twice_modulus_0] "m"(twice_modulus_0),
[twice_modulus_1] "m"(twice_modulus_1),
[twice_modulus_2] "m"(twice_modulus_2),
[twice_modulus_3] "m"(twice_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}
template <class T> void field<T>::asm_conditional_negate(field& r, const uint64_t predicate) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_conditional_negate");
constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];
__asm__(CLEAR_FLAGS("%%r8") LOAD_FIELD_ELEMENT(
"%1", "%%r8", "%%r9", "%%r10", "%%r11") "movq %[twice_modulus_0], %%r12 \n\t"
"movq %[twice_modulus_1], %%r13 \n\t"
"movq %[twice_modulus_2], %%r14 \n\t"
"movq %[twice_modulus_3], %%r15 \n\t"
"subq %%r8, %%r12 \n\t"
"sbbq %%r9, %%r13 \n\t"
"sbbq %%r10, %%r14 \n\t"
"sbbq %%r11, %%r15 \n\t"
"testq %0, %0 \n\t"
"cmovnzq %%r12, %%r8 \n\t"
"cmovnzq %%r13, %%r9 \n\t"
"cmovnzq %%r14, %%r10 \n\t"
"cmovnzq %%r15, %%r11 \n\t" STORE_FIELD_ELEMENT(
"%1", "%%r8", "%%r9", "%%r10", "%%r11")
:
: "r"(predicate),
"r"(&r),
[twice_modulus_0] "i"(twice_modulus_0),
[twice_modulus_1] "i"(twice_modulus_1),
[twice_modulus_2] "i"(twice_modulus_2),
[twice_modulus_3] "i"(twice_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}
template <class T> field<T> field<T>::asm_reduce_once(const field& a) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_reduce_once");
field r;
constexpr uint64_t not_modulus_0 = not_modulus.data[0];
constexpr uint64_t not_modulus_1 = not_modulus.data[1];
constexpr uint64_t not_modulus_2 = not_modulus.data[2];
constexpr uint64_t not_modulus_3 = not_modulus.data[3];
__asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
REDUCE_FIELD_ELEMENT("%[not_modulus_0]", "%[not_modulus_1]", "%[not_modulus_2]", "%[not_modulus_3]")
STORE_FIELD_ELEMENT("%1", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
"r"(&r),
[not_modulus_0] "m"(not_modulus_0),
[not_modulus_1] "m"(not_modulus_1),
[not_modulus_2] "m"(not_modulus_2),
[not_modulus_3] "m"(not_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
return r;
}
template <class T> void field<T>::asm_self_reduce_once(const field& a) noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::asm_self_reduce_once");
constexpr uint64_t not_modulus_0 = not_modulus.data[0];
constexpr uint64_t not_modulus_1 = not_modulus.data[1];
constexpr uint64_t not_modulus_2 = not_modulus.data[2];
constexpr uint64_t not_modulus_3 = not_modulus.data[3];
__asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
REDUCE_FIELD_ELEMENT("%[not_modulus_0]", "%[not_modulus_1]", "%[not_modulus_2]", "%[not_modulus_3]")
STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15")
:
: "r"(&a),
[not_modulus_0] "m"(not_modulus_0),
[not_modulus_1] "m"(not_modulus_1),
[not_modulus_2] "m"(not_modulus_2),
[not_modulus_3] "m"(not_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory");
}
} // namespace bb
#endif

View File

@@ -0,0 +1,192 @@
#pragma once
#include "../../common/serialize.hpp"
#include "../../ecc/curves/bn254/fq2.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include <cstring>
#include <type_traits>
#include <vector>
namespace bb::group_elements {
template <typename T>
concept SupportsHashToCurve = T::can_hash_to_curve;
template <typename Fq_, typename Fr_, typename Params> class alignas(64) affine_element {
public:
using Fq = Fq_;
using Fr = Fr_;
using in_buf = const uint8_t*;
using vec_in_buf = const uint8_t*;
using out_buf = uint8_t*;
using vec_out_buf = uint8_t**;
affine_element() noexcept = default;
~affine_element() noexcept = default;
constexpr affine_element(const Fq& x, const Fq& y) noexcept;
constexpr affine_element(const affine_element& other) noexcept = default;
constexpr affine_element(affine_element&& other) noexcept = default;
static constexpr affine_element one() noexcept { return { Params::one_x, Params::one_y }; };
/**
* @brief Reconstruct a point in affine coordinates from compressed form.
* @details #LARGE_MODULUS_AFFINE_POINT_COMPRESSION Point compression is only implemented for curves of a prime
* field F_p with p using < 256 bits. One possibility for extending to a 256-bit prime field:
* https://patents.google.com/patent/US6252960B1/en.
*
* @param compressed compressed point
* @return constexpr affine_element
*/
template <typename BaseField = Fq,
typename CompileTimeEnabled = std::enable_if_t<(BaseField::modulus >> 255) == uint256_t(0), void>>
static constexpr affine_element from_compressed(const uint256_t& compressed) noexcept;
/**
* @brief Reconstruct a point in affine coordinates from compressed form.
* @details #LARGE_MODULUS_AFFINE_POINT_COMPRESSION Point compression is implemented for curves of a prime
* field F_p with p being 256 bits.
* TODO(Suyash): Check with kesha if this is correct.
*
* @param compressed compressed point
* @return constexpr affine_element
*/
template <typename BaseField = Fq,
typename CompileTimeEnabled = std::enable_if_t<(BaseField::modulus >> 255) == uint256_t(1), void>>
static constexpr std::array<affine_element, 2> from_compressed_unsafe(const uint256_t& compressed) noexcept;
constexpr affine_element& operator=(const affine_element& other) noexcept = default;
constexpr affine_element& operator=(affine_element&& other) noexcept = default;
constexpr affine_element operator+(const affine_element& other) const noexcept;
template <typename BaseField = Fq,
typename CompileTimeEnabled = std::enable_if_t<(BaseField::modulus >> 255) == uint256_t(0), void>>
[[nodiscard]] constexpr uint256_t compress() const noexcept;
static affine_element infinity();
constexpr affine_element set_infinity() const noexcept;
constexpr void self_set_infinity() noexcept;
[[nodiscard]] constexpr bool is_point_at_infinity() const noexcept;
[[nodiscard]] constexpr bool on_curve() const noexcept;
static constexpr std::optional<affine_element> derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept;
/**
* @brief Samples a random point on the curve.
*
* @return A randomly chosen point on the curve
*/
static affine_element random_element(numeric::RNG* engine = nullptr) noexcept;
static constexpr affine_element hash_to_curve(const std::vector<uint8_t>& seed, uint8_t attempt_count = 0) noexcept
requires SupportsHashToCurve<Params>;
constexpr bool operator==(const affine_element& other) const noexcept;
constexpr affine_element operator-() const noexcept { return { x, -y }; }
constexpr bool operator>(const affine_element& other) const noexcept;
constexpr bool operator<(const affine_element& other) const noexcept { return (other > *this); }
/**
* @brief Serialize the point to the given buffer
*
* @details We support serializing the point at infinity for curves defined over a bb::field (i.e., a
* native field of prime order) and for points of bb::g2.
*
* @warning This will need to be updated if we serialize points over composite-order fields other than fq2!
*
*/
static void serialize_to_buffer(const affine_element& value, uint8_t* buffer, bool write_x_first = false)
{
using namespace serialize;
if (value.is_point_at_infinity()) {
// if we are at infinity, just set all buffer bits to 1
// we only need this special case because the code below gets mangled when converting from Montgomery form
// for infinity points
memset(buffer, 255, sizeof(Fq) * 2);
} else {
// Note: for historic reasons we will need to redo downstream hashes if we want this to always be written in
// the same order in our various serialization flows
write(buffer, write_x_first ? value.x : value.y);
write(buffer, write_x_first ? value.y : value.x);
}
}
/**
* @brief Restore point from a buffer
*
* @param buffer Buffer from which we deserialize the point
*
* @return Deserialized point
*
* @details We support serializing the point at infinity for curves defined over a bb::field (i.e., a
* native field of prime order) and for points of bb::g2.
*
* @warning This will need to be updated if we serialize points over composite-order fields other than fq2!
*/
static affine_element serialize_from_buffer(const uint8_t* buffer, bool write_x_first = false)
{
using namespace serialize;
// Does the buffer consist entirely of set bits? If so, we have a point at infinity
// Note that if it isn't, this loop should end early.
// We only need this special case because the code below gets mangled when converting to Montgomery form for infinity points
bool is_point_at_infinity =
std::all_of(buffer, buffer + sizeof(Fq) * 2, [](uint8_t val) { return val == 255; });
if (is_point_at_infinity) {
return affine_element::infinity();
}
affine_element result;
// Note: for historic reasons we will need to redo downstream hashes if we want this to always be read in the
// same order in our various serialization flows
read(buffer, write_x_first ? result.x : result.y);
read(buffer, write_x_first ? result.y : result.x);
return result;
}
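/**
 * Editor's note: an illustrative serialization round trip (sketch; `p` is any affine_element value):
 *
 *   uint8_t buf[sizeof(Fq) * 2];
 *   affine_element::serialize_to_buffer(p, buf);
 *   affine_element q = affine_element::serialize_from_buffer(buf);
 *   // q == p, with the point at infinity handled via the all-0xff sentinel
 */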
/**
* @brief Serialize the point to a byte vector
*
* @return Vector with serialized representation of the point
*/
[[nodiscard]] inline std::vector<uint8_t> to_buffer() const
{
std::vector<uint8_t> buffer(sizeof(affine_element));
affine_element::serialize_to_buffer(*this, &buffer[0]);
return buffer;
}
friend std::ostream& operator<<(std::ostream& os, const affine_element& a)
{
os << "{ " << a.x << ", " << a.y << " }";
return os;
}
Fq x;
Fq y;
};
template <typename B, typename Fq_, typename Fr_, typename Params>
inline void read(B& it, group_elements::affine_element<Fq_, Fr_, Params>& element)
{
using namespace serialize;
std::array<uint8_t, sizeof(element)> buffer;
read(it, buffer);
element = group_elements::affine_element<Fq_, Fr_, Params>::serialize_from_buffer(
buffer.data(), /* use legacy field order */ true);
}
template <typename B, typename Fq_, typename Fr_, typename Params>
inline void write(B& it, group_elements::affine_element<Fq_, Fr_, Params> const& element)
{
using namespace serialize;
std::array<uint8_t, sizeof(element)> buffer;
group_elements::affine_element<Fq_, Fr_, Params>::serialize_to_buffer(
element, buffer.data(), /* use legacy field order */ true);
write(it, buffer);
}
} // namespace bb::group_elements
#include "./affine_element_impl.hpp"

View File

@@ -0,0 +1,290 @@
#pragma once
#include "./element.hpp"
#include "../../crypto/blake3s/blake3s.hpp"
#include "../../crypto/keccak/keccak.hpp"
namespace bb::group_elements {
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T>::affine_element(const Fq& x, const Fq& y) noexcept
: x(x)
, y(y)
{}
template <class Fq, class Fr, class T>
template <typename BaseField, typename CompileTimeEnabled>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::from_compressed(const uint256_t& compressed) noexcept
{
uint256_t x_coordinate = compressed;
x_coordinate.data[3] = x_coordinate.data[3] & (~0x8000000000000000ULL);
bool y_bit = compressed.get_bit(255);
Fq x = Fq(x_coordinate);
Fq y2 = (x.sqr() * x + T::b);
if constexpr (T::has_a) {
y2 += (x * T::a);
}
auto [is_quadratic_remainder, y] = y2.sqrt();
if (!is_quadratic_remainder) {
return affine_element(Fq::zero(), Fq::zero());
}
if (uint256_t(y).get_bit(0) != y_bit) {
y = -y;
}
return affine_element<Fq, Fr, T>(x, y);
}
template <class Fq, class Fr, class T>
template <typename BaseField, typename CompileTimeEnabled>
constexpr std::array<affine_element<Fq, Fr, T>, 2> affine_element<Fq, Fr, T>::from_compressed_unsafe(
const uint256_t& compressed) noexcept
{
auto get_y_coordinate = [](const uint256_t& x_coordinate) {
Fq x = Fq(x_coordinate);
Fq y2 = (x.sqr() * x + T::b);
if constexpr (T::has_a) {
y2 += (x * T::a);
}
return y2.sqrt();
};
uint256_t x_1 = compressed;
uint256_t x_2 = compressed + Fr::modulus;
auto [is_quadratic_remainder_1, y_1] = get_y_coordinate(x_1);
auto [is_quadratic_remainder_2, y_2] = get_y_coordinate(x_2);
auto output_1 = is_quadratic_remainder_1 ? affine_element<Fq, Fr, T>(Fq(x_1), y_1)
: affine_element<Fq, Fr, T>(Fq::zero(), Fq::zero());
auto output_2 = is_quadratic_remainder_2 ? affine_element<Fq, Fr, T>(Fq(x_2), y_2)
: affine_element<Fq, Fr, T>(Fq::zero(), Fq::zero());
return { output_1, output_2 };
}
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::operator+(
const affine_element<Fq, Fr, T>& other) const noexcept
{
return affine_element(element<Fq, Fr, T>(*this) + element<Fq, Fr, T>(other));
}
template <class Fq, class Fr, class T>
template <typename BaseField, typename CompileTimeEnabled>
constexpr uint256_t affine_element<Fq, Fr, T>::compress() const noexcept
{
uint256_t out(x);
if (uint256_t(y).get_bit(0)) {
out.data[3] = out.data[3] | 0x8000000000000000ULL;
}
return out;
}
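// Editor's note (illustrative sketch, valid only when Fq::modulus fits in < 256 bits): compress() packs x
// into the low 255 bits and records the parity of y in bit 255; from_compressed() re-derives y from the
// curve equation and flips its sign to match the stored parity:
//
//   uint256_t c = p.compress();
//   affine_element q = affine_element::from_compressed(c);
//   // q == p for any curve point p (the point at infinity is not representable here)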
template <class Fq, class Fr, class T> affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::infinity()
{
affine_element e;
e.self_set_infinity();
return e;
}
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::set_infinity() const noexcept
{
affine_element result(*this);
result.self_set_infinity();
return result;
}
template <class Fq, class Fr, class T> constexpr void affine_element<Fq, Fr, T>::self_set_infinity() noexcept
{
if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) {
// We set the value of x equal to the modulus to represent infinity
x.data[0] = Fq::modulus.data[0];
x.data[1] = Fq::modulus.data[1];
x.data[2] = Fq::modulus.data[2];
x.data[3] = Fq::modulus.data[3];
} else {
x.self_set_msb();
}
}
template <class Fq, class Fr, class T> constexpr bool affine_element<Fq, Fr, T>::is_point_at_infinity() const noexcept
{
if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) {
// We check whether the value of x equals the modulus, which represents infinity
return ((x.data[0] ^ Fq::modulus.data[0]) | (x.data[1] ^ Fq::modulus.data[1]) |
(x.data[2] ^ Fq::modulus.data[2]) | (x.data[3] ^ Fq::modulus.data[3])) == 0;
} else {
return (x.is_msb_set());
}
}
template <class Fq, class Fr, class T> constexpr bool affine_element<Fq, Fr, T>::on_curve() const noexcept
{
if (is_point_at_infinity()) {
return true;
}
Fq xxx = x.sqr() * x + T::b;
Fq yy = y.sqr();
if constexpr (T::has_a) {
xxx += (x * T::a);
}
return (xxx == yy);
}
template <class Fq, class Fr, class T>
constexpr bool affine_element<Fq, Fr, T>::operator==(const affine_element& other) const noexcept
{
bool this_is_infinity = is_point_at_infinity();
bool other_is_infinity = other.is_point_at_infinity();
bool both_infinity = this_is_infinity && other_is_infinity;
bool only_one_is_infinity = this_is_infinity != other_is_infinity;
return !only_one_is_infinity && (both_infinity || ((x == other.x) && (y == other.y)));
}
/**
* Comparison operators (for std::sort)
*
* @details CAUTION!! Don't use this operator. It has no meaning other than for use by std::sort.
**/
template <class Fq, class Fr, class T>
constexpr bool affine_element<Fq, Fr, T>::operator>(const affine_element& other) const noexcept
{
// We set the point at infinity to always be the lowest element
if (is_point_at_infinity()) {
return false;
}
if (other.is_point_at_infinity()) {
return true;
}
if (x > other.x) {
return true;
}
if (x == other.x && y > other.y) {
return true;
}
return false;
}
template <class Fq, class Fr, class T>
constexpr std::optional<affine_element<Fq, Fr, T>> affine_element<Fq, Fr, T>::derive_from_x_coordinate(
const Fq& x, bool sign_bit) noexcept
{
auto yy = x.sqr() * x + T::b;
if constexpr (T::has_a) {
yy += (x * T::a);
}
auto [found_root, y] = yy.sqrt();
if (found_root) {
if (uint256_t(y).get_bit(0) != sign_bit) {
y = -y;
}
return affine_element(x, y);
}
return std::nullopt;
}
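/**
 * Example (an illustrative sketch, not part of the API): `derive_from_x_coordinate` inverts
 * `compress` above. Given a compressed point whose bit 255 holds the parity of y, the point
 * can be recovered as follows (assuming Fq elements fit in 255 bits, as `compress` requires):
 *
 *   uint256_t compressed = point.compress();
 *   bool y_parity = compressed.get_bit(255);
 *   uint256_t x_coord = compressed;
 *   x_coord.data[3] &= ~0x8000000000000000ULL; // clear the parity bit
 *   auto recovered = affine_element<Fq, Fr, T>::derive_from_x_coordinate(Fq(x_coord), y_parity);
 *   // recovered == point whenever the original point was not at infinity
 */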
/**
* @brief Hash a seed buffer into a point
*
* @details ALGORITHM DESCRIPTION:
* 1. Initialize unsigned integer `attempt_count = 0`
* 2. Copy seed into a buffer whose size is 2 bytes greater than `seed` (initialized to 0)
* 3. Interpret `attempt_count` as a byte and write into buffer at [buffer.size() - 2]
* 4. Compute Blake3s hash of buffer
* 5. Set the end byte of the buffer to `1`
* 6. Compute Blake3s hash of buffer
* 7. Interpret the two hash outputs as the high / low 256 bits of a 512-bit integer (big-endian)
* 8. Derive x-coordinate of point by reducing the 512-bit integer modulo the curve's field modulus (Fq)
* 9. Compute y^2 from the curve formula y^2 = x^3 + ax + b (a, b are curve params. for BN254, a = 0, b = 3)
* 10. IF y^2 IS NOT A QUADRATIC RESIDUE
* 10a. increment `attempt_count` by 1 and go to step 2
* 11. IF y^2 IS A QUADRATIC RESIDUE
* 11a. derive y coordinate via y = sqrt(y)
* 11b. Interpret most significant bit of 512-bit integer as a 'parity' bit
* 11c. If parity bit is set AND y's least significant bit is not set, invert y
* 11d. If parity bit is not set AND y's least significant bit is set, invert y
* N.B. the last 2 steps are needed because the sqrt() algorithm can return either of 2 roots,
* and we need a way to canonically distinguish between them and select a "preferred" one
* 11e. return (x, y)
*
* @note This algorithm is constexpr: we can hash-to-curve (and derive generators) at compile-time!
* @tparam Fq
* @tparam Fr
* @tparam T
* @param seed Bytes that uniquely define the point being generated
* @param attempt_count
* @return constexpr affine_element<Fq, Fr, T>
*/
template <class Fq, class Fr, class T>
constexpr affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::hash_to_curve(const std::vector<uint8_t>& seed,
uint8_t attempt_count) noexcept
requires SupportsHashToCurve<T>
{
std::vector<uint8_t> target_seed(seed);
// expand by 2 bytes to cover incremental hash attempts
const size_t seed_size = seed.size();
for (size_t i = 0; i < 2; ++i) {
target_seed.push_back(0);
}
target_seed[seed_size] = attempt_count;
target_seed[seed_size + 1] = 0;
const auto hash_hi = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
target_seed[seed_size + 1] = 1;
const auto hash_lo = blake3::blake3s_constexpr(&target_seed[0], target_seed.size());
// custom serialize methods as common/serialize.hpp is not constexpr!
const auto read_uint256 = [](const uint8_t* in) {
const auto read_limb = [](const uint8_t* in, uint64_t& out) {
for (size_t i = 0; i < 8; ++i) {
out += static_cast<uint64_t>(in[i]) << ((7 - i) * 8);
}
};
uint256_t out = 0;
read_limb(&in[0], out.data[3]);
read_limb(&in[8], out.data[2]);
read_limb(&in[16], out.data[1]);
read_limb(&in[24], out.data[0]);
return out;
};
// interpret 64 byte hash output as a uint512_t, reduce to Fq element
//(512 bits of entropy ensures result is not biased as 512 >> Fq::modulus.get_msb())
Fq x(uint512_t(read_uint256(&hash_lo[0]), read_uint256(&hash_hi[0])));
bool sign_bit = hash_hi[0] > 127;
std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
if (result.has_value()) {
return result.value();
}
return hash_to_curve(seed, attempt_count + 1);
}
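/**
 * Example usage (a sketch; `Bn254G1` is a hypothetical alias for any curve group whose
 * params satisfy SupportsHashToCurve). Because the whole pipeline (blake3s_constexpr,
 * field arithmetic, sqrt) is constexpr, derivation can happen at compile time via a
 * C++20 transient allocation inside constant evaluation:
 *
 *   constexpr auto point = [] {
 *       std::vector<uint8_t> seed{ 'e', 'x', 'a', 'm', 'p', 'l', 'e' };
 *       return Bn254G1::affine_element::hash_to_curve(seed, 0);
 *   }();
 */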
template <typename Fq, typename Fr, typename T>
affine_element<Fq, Fr, T> affine_element<Fq, Fr, T>::random_element(numeric::RNG* engine) noexcept
{
if (engine == nullptr) {
engine = &numeric::get_randomness();
}
while (true) {
// Sample a random x-coordinate and check whether it satisfies the curve equation.
Fq x = Fq::random_element(engine);
// Negate the y-coordinate based on a randomly sampled sign bit.
bool sign_bit = (engine->get_random_uint8() & 1) != 0;
std::optional<affine_element> result = derive_from_x_coordinate(x, sign_bit);
if (result.has_value()) {
return result.value();
}
}
}
} // namespace bb::group_elements

View File

@@ -0,0 +1,168 @@
#pragma once
#include "affine_element.hpp"
#include "../../common/compiler_hints.hpp"
#include "../../common/mem.hpp"
#include "../../numeric/random/engine.hpp"
#include "../../numeric/uint256/uint256.hpp"
#include "wnaf.hpp"
#include <array>
#include <random>
#include <vector>
namespace bb::group_elements {
/**
* @brief element class. Implements ecc group arithmetic using Jacobian coordinates
* See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l
*
* Note: Currently subgroup checks are NOT IMPLEMENTED
* Our current Plonk implementation uses G1 points that have a cofactor of 1.
* All G2 points are precomputed (generator [1]_2 and trusted setup point [x]_2).
* Explicitly assume precomputed points are valid members of the prime-order subgroup for G2.
* @tparam Fq prime field the curve is defined over
* @tparam Fr prime field whose characteristic equals the size of the prime-order elliptic curve subgroup
* @tparam Params curve parameters
*/
template <class Fq, class Fr, class Params> class alignas(32) element {
public:
static constexpr Fq curve_b = Params::b;
element() noexcept = default;
constexpr element(const Fq& a, const Fq& b, const Fq& c) noexcept;
constexpr element(const element& other) noexcept;
constexpr element(element&& other) noexcept;
constexpr element(const affine_element<Fq, Fr, Params>& other) noexcept;
constexpr ~element() noexcept = default;
static constexpr element one() noexcept { return { Params::one_x, Params::one_y, Fq::one() }; };
static constexpr element zero() noexcept
{
element zero;
zero.self_set_infinity();
return zero;
};
constexpr element& operator=(const element& other) noexcept;
constexpr element& operator=(element&& other) noexcept;
constexpr operator affine_element<Fq, Fr, Params>() const noexcept;
static element random_element(numeric::RNG* engine = nullptr) noexcept;
constexpr element dbl() const noexcept;
constexpr void self_dbl() noexcept;
constexpr void self_mixed_add_or_sub(const affine_element<Fq, Fr, Params>& other, uint64_t predicate) noexcept;
constexpr element operator+(const element& other) const noexcept;
constexpr element operator+(const affine_element<Fq, Fr, Params>& other) const noexcept;
constexpr element operator+=(const element& other) noexcept;
constexpr element operator+=(const affine_element<Fq, Fr, Params>& other) noexcept;
constexpr element operator-(const element& other) const noexcept;
constexpr element operator-(const affine_element<Fq, Fr, Params>& other) const noexcept;
constexpr element operator-() const noexcept;
constexpr element operator-=(const element& other) noexcept;
constexpr element operator-=(const affine_element<Fq, Fr, Params>& other) noexcept;
friend constexpr element operator+(const affine_element<Fq, Fr, Params>& left, const element& right) noexcept
{
return right + left;
}
friend constexpr element operator-(const affine_element<Fq, Fr, Params>& left, const element& right) noexcept
{
return -right + left;
}
element operator*(const Fr& exponent) const noexcept;
element operator*=(const Fr& exponent) noexcept;
// If you end up implementing this, congrats, you've solved the DL problem!
// P.S. This is a joke, don't even attempt! 😂
// constexpr Fr operator/(const element& other) noexcept {}
constexpr element normalize() const noexcept;
static element infinity();
BB_INLINE constexpr element set_infinity() const noexcept;
BB_INLINE constexpr void self_set_infinity() noexcept;
[[nodiscard]] BB_INLINE constexpr bool is_point_at_infinity() const noexcept;
[[nodiscard]] BB_INLINE constexpr bool on_curve() const noexcept;
BB_INLINE constexpr bool operator==(const element& other) const noexcept;
static void batch_normalize(element* elements, size_t num_elements) noexcept;
static void batch_affine_add(const std::span<affine_element<Fq, Fr, Params>>& first_group,
const std::span<affine_element<Fq, Fr, Params>>& second_group,
const std::span<affine_element<Fq, Fr, Params>>& results) noexcept;
static std::vector<affine_element<Fq, Fr, Params>> batch_mul_with_endomorphism(
const std::span<affine_element<Fq, Fr, Params>>& points, const Fr& scalar) noexcept;
Fq x;
Fq y;
Fq z;
private:
// For test access to mul_without_endomorphism
friend class TestElementPrivate;
element mul_without_endomorphism(const Fr& scalar) const noexcept;
element mul_with_endomorphism(const Fr& scalar) const noexcept;
template <typename = typename std::enable_if<Params::can_hash_to_curve>>
static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept;
// {
// bool found_one = false;
// Fq yy;
// Fq x;
// Fq y;
// Fq t0;
// while (!found_one) {
// x = Fq::random_element(engine);
// yy = x.sqr() * x + Params::b;
// if constexpr (Params::has_a) {
// yy += (x * Params::a);
// }
// y = yy.sqrt();
// t0 = y.sqr();
// found_one = (yy == t0);
// }
// return { x, y, Fq::one() };
// }
// for serialization: update with new fields
// TODO(https://github.com/AztecProtocol/barretenberg/issues/908) point at infinity isn't handled
static void conditional_negate_affine(const affine_element<Fq, Fr, Params>& in,
affine_element<Fq, Fr, Params>& out,
uint64_t predicate) noexcept;
friend std::ostream& operator<<(std::ostream& os, const element& a)
{
os << "{ " << a.x << ", " << a.y << ", " << a.z << " }";
return os;
}
};
template <class Fq, class Fr, class Params> std::ostream& operator<<(std::ostream& os, element<Fq, Fr, Params> const& e)
{
return os << "x:" << e.x << " y:" << e.y << " z:" << e.z;
}
// constexpr element<Fq, Fr, Params>::one = element<Fq, Fr, Params>{ Params::one_x, Params::one_y, Fq::one() };
// constexpr element<Fq, Fr, Params>::point_at_infinity = one.set_infinity();
// constexpr element<Fq, Fr, Params>::curve_b = Params::b;
} // namespace bb::group_elements
#include "./element_impl.hpp"
template <class Fq, class Fr, class Params>
bb::group_elements::affine_element<Fq, Fr, Params> operator*(
const bb::group_elements::affine_element<Fq, Fr, Params>& base, const Fr& exponent) noexcept
{
return bb::group_elements::affine_element<Fq, Fr, Params>(bb::group_elements::element(base) * exponent);
}
template <class Fq, class Fr, class Params>
bb::group_elements::affine_element<Fq, Fr, Params> operator*(const bb::group_elements::element<Fq, Fr, Params>& base,
const Fr& exponent) noexcept
{
return (bb::group_elements::element(base) * exponent);
}
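/**
 * Example (illustrative): with the operators above, scalar multiplication reads naturally.
 * `G1` is a placeholder for a concrete group instantiation such as a BN254 typedef:
 *
 *   G1::affine_element P = G1::affine_element::random_element();
 *   G1::subgroup_field k = G1::subgroup_field::random_element();
 *   G1::affine_element Q = P * k; // Jacobian mul under the hood, normalized back to affine
 */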

File diff suppressed because it is too large

View File

@@ -0,0 +1,129 @@
#pragma once
#include "../../common/assert.hpp"
#include "./affine_element.hpp"
#include "./element.hpp"
#include "./wnaf.hpp"
#include "../../common/constexpr_utils.hpp"
#include "../../crypto/blake3s/blake3s.hpp"
#include <array>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
namespace bb {
/**
* @brief group class. Represents an elliptic curve group element.
* Group is parametrised by coordinate_field and subgroup_field
*
* Note: Currently subgroup checks are NOT IMPLEMENTED
* Our current Plonk implementation uses G1 points that have a cofactor of 1.
* All G2 points are precomputed (generator [1]_2 and trusted setup point [x]_2).
* Explicitly assume precomputed points are valid members of the prime-order subgroup for G2.
*
* @tparam coordinate_field
* @tparam subgroup_field
* @tparam GroupParams
*/
template <typename _coordinate_field, typename _subgroup_field, typename GroupParams> class group {
public:
// hoist coordinate_field, subgroup_field into the public namespace
using coordinate_field = _coordinate_field;
using subgroup_field = _subgroup_field;
using element = group_elements::element<coordinate_field, subgroup_field, GroupParams>;
using affine_element = group_elements::affine_element<coordinate_field, subgroup_field, GroupParams>;
using Fq = coordinate_field;
using Fr = subgroup_field;
static constexpr bool USE_ENDOMORPHISM = GroupParams::USE_ENDOMORPHISM;
static constexpr bool has_a = GroupParams::has_a;
static constexpr element one{ GroupParams::one_x, GroupParams::one_y, coordinate_field::one() };
static constexpr element point_at_infinity = one.set_infinity();
static constexpr affine_element affine_one{ GroupParams::one_x, GroupParams::one_y };
static constexpr affine_element affine_point_at_infinity = affine_one.set_infinity();
static constexpr coordinate_field curve_a = GroupParams::a;
static constexpr coordinate_field curve_b = GroupParams::b;
/**
* @brief Derives generator points via hash-to-curve
*
* ALGORITHM DESCRIPTION:
* 1. Each generator has an associated "generator index" described by its location in the vector
* 2. a 64-byte preimage buffer is generated with the following structure:
* bytes 0-31: BLAKE3 hash of domain_separator
* bytes 32-63: generator index in big-endian form
* 3. The hash-to-curve algorithm is used to hash the above into a group element:
* a. iterate `count` upwards from `0`
* b. append `count` to the preimage buffer as a 1-byte integer in big-endian form
* c. compute BLAKE3 hash of concat(preimage buffer, 0)
* d. compute BLAKE3 hash of concat(preimage buffer, 1)
* e. interpret (c, d) as (hi, low) limbs of a 512-bit integer
* f. reduce 512-bit integer modulo coordinate_field to produce x-coordinate
* g. attempt to derive y-coordinate. If not successful go to step (a) and continue
* h. if the least significant bit of the y-coordinate does not match the most significant bit of
* (c) (i.e. the top bit of the 512-bit integer), negate the y-coordinate.
* i. return (x, y)
*
* NOTE: In step 3b it is sufficient to use 1 byte to store `count`.
* Each iteration of step 3 succeeds with probability ~1/2, so the probability of `count` exceeding 256 is ~1 in 2^256
* NOTE: The domain separator is included to ensure that it is possible to derive independent sets of
* index-addressable generators.
* NOTE: we produce 64 bytes of BLAKE3 output when producing x-coordinate field
* element, to ensure that x-coordinate is uniformly randomly distributed in the field. Using a 256-bit input adds
* significant bias when reducing modulo a ~256-bit coordinate_field
* NOTE: We ensure y-parity is linked to preimage
* hash because there is no canonical deterministic square root algorithm (i.e. if a field element has a square
* root, there are two of them and `field::sqrt` may return either one)
* @param num_generators
* @param domain_separator
* @return std::vector<affine_element>
*/
inline static constexpr std::vector<affine_element> derive_generators(
const std::vector<uint8_t>& domain_separator_bytes,
const size_t num_generators,
const size_t starting_index = 0)
{
std::vector<affine_element> result;
const auto domain_hash = blake3::blake3s_constexpr(&domain_separator_bytes[0], domain_separator_bytes.size());
std::vector<uint8_t> generator_preimage;
generator_preimage.reserve(64);
std::copy(domain_hash.begin(), domain_hash.end(), std::back_inserter(generator_preimage));
for (size_t i = 0; i < 32; ++i) {
generator_preimage.emplace_back(0);
}
for (size_t i = starting_index; i < starting_index + num_generators; ++i) {
auto generator_index = static_cast<uint32_t>(i);
uint32_t mask = 0xff;
generator_preimage[32] = static_cast<uint8_t>(generator_index >> 24);
generator_preimage[33] = static_cast<uint8_t>((generator_index >> 16) & mask);
generator_preimage[34] = static_cast<uint8_t>((generator_index >> 8) & mask);
generator_preimage[35] = static_cast<uint8_t>(generator_index & mask);
result.push_back(affine_element::hash_to_curve(generator_preimage));
}
return result;
}
inline static constexpr std::vector<affine_element> derive_generators(const std::string_view& domain_separator,
const size_t num_generators,
const size_t starting_index = 0)
{
std::vector<uint8_t> domain_bytes;
for (char i : domain_separator) {
domain_bytes.emplace_back(static_cast<unsigned char>(i));
}
return derive_generators(domain_bytes, num_generators, starting_index);
}
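/**
 * Example usage (illustrative): deriving a small set of independent, index-addressable
 * generators for a hypothetical Pedersen-style gadget:
 *
 *   auto gens = Group::derive_generators("example_domain_separator", 3);
 *   // gens[0..2] are affine points; re-running with the same domain separator and
 *   // indices reproduces them exactly, including at compile time.
 */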
BB_INLINE static void conditional_negate_affine(const affine_element* src,
affine_element* dest,
uint64_t predicate);
};
} // namespace bb
#ifdef DISABLE_ASM
#include "group_impl_int128.tcc"
#else
#include "group_impl_asm.tcc"
#endif

View File

@@ -0,0 +1,162 @@
#pragma once
#ifndef DISABLE_ASM
#include "barretenberg/ecc/groups/group.hpp"
#include <cstdint>
namespace bb {
// copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const affine_element* src, affine_element*
// dest)
// {
// if constexpr (GroupParams::small_elements) {
// #if defined __AVX__ && defined USE_AVX
// ASSERT((((uintptr_t)src & 0x1f) == 0));
// ASSERT((((uintptr_t)dest & 0x1f) == 0));
// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t"
// "vmovdqa 32(%0), %%ymm1 \n\t"
// "vmovdqa %%ymm0, 0(%1) \n\t"
// "vmovdqa %%ymm1, 32(%1) \n\t"
// :
// : "r"(src), "r"(dest)
// : "%ymm0", "%ymm1", "memory");
// #else
// *dest = *src;
// #endif
// } else {
// *dest = *src;
// }
// }
// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const element* src, element* dest)
// {
// if constexpr (GroupParams::small_elements) {
// #if defined __AVX__ && defined USE_AVX
// ASSERT((((uintptr_t)src & 0x1f) == 0));
// ASSERT((((uintptr_t)dest & 0x1f) == 0));
// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t"
// "vmovdqa 32(%0), %%ymm1 \n\t"
// "vmovdqa 64(%0), %%ymm2 \n\t"
// "vmovdqa %%ymm0, 0(%1) \n\t"
// "vmovdqa %%ymm1, 32(%1) \n\t"
// "vmovdqa %%ymm2, 64(%1) \n\t"
// :
// : "r"(src), "r"(dest)
// : "%ymm0", "%ymm1", "%ymm2", "memory");
// #else
// *dest = *src;
// #endif
// } else {
// *dest = src;
// }
// }
// copies src into dest, inverting y-coordinate if 'predicate' is true
// n.b. requires src and dest to be aligned on 32 byte boundary
template <typename coordinate_field, typename subgroup_field, typename GroupParams>
inline void group<coordinate_field, subgroup_field, GroupParams>::conditional_negate_affine(const affine_element* src,
affine_element* dest,
uint64_t predicate)
{
constexpr uint256_t twice_modulus = coordinate_field::modulus + coordinate_field::modulus;
constexpr uint64_t twice_modulus_0 = twice_modulus.data[0];
constexpr uint64_t twice_modulus_1 = twice_modulus.data[1];
constexpr uint64_t twice_modulus_2 = twice_modulus.data[2];
constexpr uint64_t twice_modulus_3 = twice_modulus.data[3];
if constexpr (GroupParams::small_elements) {
#if defined __AVX__ && defined USE_AVX
ASSERT((((uintptr_t)src & 0x1f) == 0));
ASSERT((((uintptr_t)dest & 0x1f) == 0));
__asm__ __volatile__("xorq %%r8, %%r8 \n\t"
"movq 32(%0), %%r8 \n\t"
"movq 40(%0), %%r9 \n\t"
"movq 48(%0), %%r10 \n\t"
"movq 56(%0), %%r11 \n\t"
"movq %[modulus_0], %%r12 \n\t"
"movq %[modulus_1], %%r13 \n\t"
"movq %[modulus_2], %%r14 \n\t"
"movq %[modulus_3], %%r15 \n\t"
"subq %%r8, %%r12 \n\t"
"sbbq %%r9, %%r13 \n\t"
"sbbq %%r10, %%r14 \n\t"
"sbbq %%r11, %%r15 \n\t"
"testq %2, %2 \n\t"
"cmovnzq %%r12, %%r8 \n\t"
"cmovnzq %%r13, %%r9 \n\t"
"cmovnzq %%r14, %%r10 \n\t"
"cmovnzq %%r15, %%r11 \n\t"
"vmovdqa 0(%0), %%ymm0 \n\t"
"vmovdqa %%ymm0, 0(%1) \n\t"
"movq %%r8, 32(%1) \n\t"
"movq %%r9, 40(%1) \n\t"
"movq %%r10, 48(%1) \n\t"
"movq %%r11, 56(%1) \n\t"
:
: "r"(src),
"r"(dest),
"r"(predicate),
[modulus_0] "i"(twice_modulus_0),
[modulus_1] "i"(twice_modulus_1),
[modulus_2] "i"(twice_modulus_2),
[modulus_3] "i"(twice_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "%ymm0", "memory", "cc");
#else
__asm__ __volatile__("xorq %%r8, %%r8 \n\t"
"movq 32(%0), %%r8 \n\t"
"movq 40(%0), %%r9 \n\t"
"movq 48(%0), %%r10 \n\t"
"movq 56(%0), %%r11 \n\t"
"movq %[modulus_0], %%r12 \n\t"
"movq %[modulus_1], %%r13 \n\t"
"movq %[modulus_2], %%r14 \n\t"
"movq %[modulus_3], %%r15 \n\t"
"subq %%r8, %%r12 \n\t"
"sbbq %%r9, %%r13 \n\t"
"sbbq %%r10, %%r14 \n\t"
"sbbq %%r11, %%r15 \n\t"
"testq %2, %2 \n\t"
"cmovnzq %%r12, %%r8 \n\t"
"cmovnzq %%r13, %%r9 \n\t"
"cmovnzq %%r14, %%r10 \n\t"
"cmovnzq %%r15, %%r11 \n\t"
"movq 0(%0), %%r12 \n\t"
"movq 8(%0), %%r13 \n\t"
"movq 16(%0), %%r14 \n\t"
"movq 24(%0), %%r15 \n\t"
"movq %%r8, 32(%1) \n\t"
"movq %%r9, 40(%1) \n\t"
"movq %%r10, 48(%1) \n\t"
"movq %%r11, 56(%1) \n\t"
"movq %%r12, 0(%1) \n\t"
"movq %%r13, 8(%1) \n\t"
"movq %%r14, 16(%1) \n\t"
"movq %%r15, 24(%1) \n\t"
:
: "r"(src),
"r"(dest),
"r"(predicate),
[modulus_0] "i"(twice_modulus_0),
[modulus_1] "i"(twice_modulus_1),
[modulus_2] "i"(twice_modulus_2),
[modulus_3] "i"(twice_modulus_3)
: "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "memory", "cc");
#endif
} else {
if (predicate) { // NOLINT
coordinate_field::__copy(src->x, dest->x);
dest->y = -src->y;
} else {
copy_affine(*src, *dest);
}
}
}
} // namespace bb
#endif

View File

@@ -0,0 +1,34 @@
#pragma once
#ifdef DISABLE_ASM
#include "barretenberg/ecc/groups/group.hpp"
#include <cstdint>
namespace bb {
// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const affine_element* src, affine_element*
// dest)
// {
// *dest = *src;
// }
// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries
// template <typename coordinate_field, typename subgroup_field, typename GroupParams>
// inline void group<coordinate_field, subgroup_field, GroupParams>::copy(const element* src, element* dest)
// {
// *dest = *src;
// }
template <typename coordinate_field, typename subgroup_field, typename GroupParams>
inline void group<coordinate_field, subgroup_field, GroupParams>::conditional_negate_affine(const affine_element* src,
affine_element* dest,
uint64_t predicate)
{
*dest = predicate ? -(*src) : (*src);
}
} // namespace bb
#endif

View File

@@ -0,0 +1,513 @@
#pragma once
#include "../../numeric/bitop/get_msb.hpp"
#include <cstdint>
#include <iostream>
// NOLINTBEGIN(readability-implicit-bool-conversion)
namespace bb::wnaf {
constexpr size_t SCALAR_BITS = 127;
#define WNAF_SIZE(x) ((bb::wnaf::SCALAR_BITS + (x)-1) / (x)) // NOLINT(cppcoreguidelines-macro-usage)
constexpr size_t get_optimal_bucket_width(const size_t num_points)
{
if (num_points >= 14617149) {
return 21;
}
if (num_points >= 1139094) {
return 18;
}
// if (num_points >= 100000)
if (num_points >= 155975) {
return 15;
}
if (num_points >= 144834)
// if (num_points >= 100000)
{
return 14;
}
if (num_points >= 25067) {
return 12;
}
if (num_points >= 13926) {
return 11;
}
if (num_points >= 7659) {
return 10;
}
if (num_points >= 2436) {
return 9;
}
if (num_points >= 376) {
return 7;
}
if (num_points >= 231) {
return 6;
}
if (num_points >= 97) {
return 5;
}
if (num_points >= 35) {
return 4;
}
if (num_points >= 10) {
return 3;
}
if (num_points >= 2) {
return 2;
}
return 1;
}
constexpr size_t get_num_buckets(const size_t num_points)
{
const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2);
return 1UL << bits_per_bucket;
}
constexpr size_t get_num_rounds(const size_t num_points)
{
const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2);
return WNAF_SIZE(bits_per_bucket + 1);
}
template <size_t bits, size_t bit_position> inline uint64_t get_wnaf_bits_const(const uint64_t* scalar) noexcept
{
if constexpr (bits == 0) {
return 0ULL;
} else {
/**
* we want to take a 128 bit scalar and shift it down by (bit_position).
* We then wish to mask out `bits` number of bits.
* Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also
* (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual
* bit position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 -
* (bit_position & 63)
*
* So, step 1:
* get low limb and shift right by (bit_position & 63)
* get high limb and shift left by (64 - (bit_position & 63))
*
*/
constexpr size_t lo_limb_idx = bit_position / 64;
constexpr size_t hi_limb_idx = (bit_position + bits - 1) / 64;
constexpr uint64_t lo_shift = bit_position & 63UL;
constexpr uint64_t bit_mask = (1UL << static_cast<uint64_t>(bits)) - 1UL;
uint64_t lo = (scalar[lo_limb_idx] >> lo_shift);
if constexpr (lo_limb_idx == hi_limb_idx) {
return lo & bit_mask;
} else {
constexpr uint64_t hi_shift = 64UL - (bit_position & 63UL);
uint64_t hi = ((scalar[hi_limb_idx] << (hi_shift)));
return (lo | hi) & bit_mask;
}
}
}
inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const uint64_t bit_position) noexcept
{
/**
* we want to take a 128 bit scalar and shift it down by (bit_position).
* We then wish to mask out `bits` number of bits.
* Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also
* (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual bit
* position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - (bit_position
* & 63)
*
* So, step 1:
* get low limb and shift right by (bit_position & 63)
* get high limb and shift left by (64 - (bit_position & 63))
*
*/
const auto lo_limb_idx = static_cast<size_t>(bit_position >> 6);
const auto hi_limb_idx = static_cast<size_t>((bit_position + bits - 1) >> 6);
const uint64_t lo_shift = bit_position & 63UL;
const uint64_t bit_mask = (1UL << static_cast<uint64_t>(bits)) - 1UL;
const uint64_t lo = (scalar[lo_limb_idx] >> lo_shift);
const uint64_t hi_shift = bit_position ? 64UL - (bit_position & 63UL) : 0;
const uint64_t hi = ((scalar[hi_limb_idx] << (hi_shift)));
const uint64_t hi_mask = bit_mask & (0ULL - (lo_limb_idx != hi_limb_idx));
return (lo & bit_mask) | (hi & hi_mask);
}
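// Worked example (illustrative): for scalar = { 0xfedcba9876543210, 0x0123456789abcdef },
// get_wnaf_bits(scalar, 4, 62) extracts bits 65..62, which straddle the limb boundary:
// lo = scalar[0] >> 62 = 0b11, hi = (scalar[1] << 2) & 0xf = 0b1100, result = 0b1111 = 15.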
inline void fixed_wnaf_packed(
const uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const uint64_t point_index, const size_t wnaf_bits) noexcept
{
skew_map = ((scalar[0] & 1) == 0);
uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i)] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
previous = slice + predicate;
}
size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1));
uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[1] = ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
}
/**
* @brief Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication.
*
* WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn optimizes
* the number of point doublings in scalar multiplication, in turn aiding efficiency.
*
* @param scalar Pointer to 128-bit scalar for which WNAF is to be computed.
* @param wnaf Pointer to num_points+1 size array where the computed WNAF will be stored.
* @param skew_map Reference to a boolean variable which will be set based on the least significant bit of the scalar.
* @param point_index The index of the point being computed in the context of multiple point multiplication.
* @param num_points The number of points being computed in parallel.
* @param wnaf_bits The number of bits to use in each window of the WNAF representation.
*/
inline void fixed_wnaf(const uint64_t* scalar,
uint64_t* wnaf,
bool& skew_map,
const uint64_t point_index,
const uint64_t num_points,
const size_t wnaf_bits) noexcept
{
skew_map = ((scalar[0] & 1) == 0);
uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i) * num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
previous = slice + predicate;
}
size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1));
uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
}
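// Example usage (a sketch): computing the 5-bit fixed WNAF of one 128-bit scalar.
// WNAF_SIZE(5) = 26 windows; for a single point, num_points = 1 and point_index = 0.
//
//   uint64_t scalar[2] = { 0x1234567890abcdefULL, 0x0fedcba098765432ULL };
//   uint64_t wnaf[WNAF_SIZE(5)] = { 0 };
//   bool skew = false;
//   bb::wnaf::fixed_wnaf(scalar, wnaf, skew, 0, 1, 5);
//   // each entry: bits 0-30 = |slice| >> 1, bit 31 = negation flag; the point index
//   // (already shifted into bits 32-63 by the caller in this overload) is OR'd in.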
/**
* Current flow...
*
* If a wnaf entry is even, we add +1 to it, and subtract 32 from the previous entry.
* This works if the previous entry is odd. If we recursively apply this process, starting at the least significant
*window, this will always be the case.
*
* However, we want to skip over windows that are 0, which poses a problem.
*
* Scenario 1: even window followed by 0 window followed by any window 'x'
*
* We can't add 1 to the even window and subtract 32 from the 0 window, as we don't have a bucket that maps to -32
* This means that we have to identify whether we are going to borrow 32 from 'x', requiring us to look at least 2
*steps ahead
*
* Scenario 2: <even> <0> <0> <x>
*
* This problem proceeds indefinitely - if we have adjacent 0 windows, we do not know whether we need to track a
*borrow flag until we identify the next non-zero window
*
* Scenario 3: <odd> <0>
*
* This one works...
*
* Ok, so we should be a bit more limited with when we don't include window entries.
* The goal here is to identify short scalars, so we want to identify the most significant non-zero window
**/
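/**
 * @brief Branch-free computation of the index of the most significant set bit of a 128-bit scalar
 * (i.e. the scalar's bit-length minus one; callers add 1 to recover the bit-length).
 * The two masks select whichever limb is the highest non-zero one. Returns 0 for a zero scalar.
 */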
inline uint64_t get_num_scalar_bits(const uint64_t* scalar)
{
const uint64_t msb_1 = numeric::get_msb(scalar[1]);
const uint64_t msb_0 = numeric::get_msb(scalar[0]);
const uint64_t scalar_1_mask = (0ULL - (scalar[1] > 0));
const uint64_t scalar_0_mask = (0ULL - (scalar[0] > 0)) & ~scalar_1_mask;
const uint64_t msb = (scalar_1_mask & (msb_1 + 64)) | (scalar_0_mask & (msb_0));
return msb;
}
/**
* How to compute an x-bit wnaf slice?
*
* Iterate over number of slices in scalar.
* For each slice, if slice is even, ADD +1 to current slice and SUBTRACT 2^x from previous slice.
* (for 1st slice we instead add +1 and set the scalar's 'skew' value to 'true' (i.e. need to subtract 1 from it at the
* end of our scalar mul algo))
*
* In *wnaf we store the following:
* 1. bits 0-30: ABSOLUTE value of wnaf (i.e. -3 goes to 3)
* 2. bit 31: 'predicate' bool (i.e. does the wnaf value need to be negated?)
* 3. bits 32-63: position in a point array that describes the elliptic curve point this wnaf slice is referencing
*
* N.B. IN OUR STDLIB ALGORITHMS THE SKEW VALUE REPRESENTS AN ADDITION NOT A SUBTRACTION (i.e. we add +1 at the end of
* the scalar mul algorithm; we don't subtract 1). This eliminates situations which could produce the point at infinity
* as an output, as our circuit logic cannot accommodate this edge case.
*
* Credits: Zac W.
*
* @param scalar Pointer to the 128-bit non-montgomery scalar that is supposed to be transformed into wnaf
* @param wnaf Pointer to output array that needs to accommodate enough 64-bit WNAF entries
* @param skew_map Reference to output skew value, which if true shows that the point should be added once at the end of
* computation
* @param wnaf_round_counts Pointer to output array specifying the number of points participating in each round
* @param point_index The index of the point that should be multiplied by this scalar in the point array
* @param num_points Total points in the MSM (2*num_initial_points)
*
*/
inline void fixed_wnaf_with_counts(const uint64_t* scalar,
uint64_t* wnaf,
bool& skew_map,
uint64_t* wnaf_round_counts,
const uint64_t point_index,
const uint64_t num_points,
const size_t wnaf_bits) noexcept
{
const size_t max_wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
if ((scalar[0] | scalar[1]) == 0ULL) {
skew_map = false;
for (size_t round_i = 0; round_i < max_wnaf_entries; ++round_i) {
wnaf[(round_i)*num_points] = 0xffffffffffffffffULL;
}
return;
}
const auto current_scalar_bits = static_cast<size_t>(get_num_scalar_bits(scalar) + 1);
skew_map = ((scalar[0] & 1) == 0);
uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
const auto wnaf_entries = static_cast<size_t>((current_scalar_bits + wnaf_bits - 1) / wnaf_bits);
if (wnaf_entries == 1) {
wnaf[(max_wnaf_entries - 1) * num_points] = (previous >> 1UL) | (point_index);
++wnaf_round_counts[max_wnaf_entries - 1];
for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) {
wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL;
}
return;
}
// If there are several windows
for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
// Get a bit slice
uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
// Get the predicate (last bit is zero)
uint64_t predicate = ((slice & 1UL) == 0UL);
// Update round count
++wnaf_round_counts[max_wnaf_entries - round_i];
// Calculate entry value
// If the last bit of current slice is 1, we simply put the previous value with the point index
// If the last bit of the current slice is 0, we negate everything, so that we subtract from the WNAF form and
// make it 0
wnaf[(max_wnaf_entries - round_i) * num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
// Update the previous value to the next windows
previous = slice + predicate;
}
// The final iteration for top bits
auto final_bits = static_cast<size_t>(current_scalar_bits - (wnaf_bits * (wnaf_entries - 1)));
uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
uint64_t predicate = ((slice & 1UL) == 0UL);
++wnaf_round_counts[(max_wnaf_entries - wnaf_entries + 1)];
wnaf[((max_wnaf_entries - wnaf_entries + 1) * num_points)] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
// Saving top bits
++wnaf_round_counts[max_wnaf_entries - wnaf_entries];
wnaf[(max_wnaf_entries - wnaf_entries) * num_points] = ((slice + predicate) >> 1UL) | (point_index);
// Fill all unused slots with -1
for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) {
wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL;
}
}
template <size_t num_points, size_t wnaf_bits, size_t round_i>
inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept
{
constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
constexpr auto log2_num_points = static_cast<size_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
if constexpr (round_i < wnaf_entries - 1) {
uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i) << log2_num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index << 32UL);
wnaf_round<num_points, wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
} else {
constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits;
uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
// uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index << 32UL);
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
}
}
template <size_t scalar_bits, size_t num_points, size_t wnaf_bits, size_t round_i>
inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept
{
constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits;
constexpr auto log2_num_points = static_cast<uint64_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
if constexpr (round_i < wnaf_entries - 1) {
uint64_t slice = get_wnaf_bits_const<wnaf_bits, round_i * wnaf_bits>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i) << log2_num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index << 32UL);
wnaf_round<scalar_bits, num_points, wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
} else {
constexpr size_t final_bits = ((scalar_bits / wnaf_bits) * wnaf_bits == scalar_bits)
? wnaf_bits
: scalar_bits - (scalar_bits / wnaf_bits) * wnaf_bits;
uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index << 32UL);
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
}
}
template <size_t wnaf_bits, size_t round_i>
inline void wnaf_round_packed(const uint64_t* scalar,
uint64_t* wnaf,
const uint64_t point_index,
const uint64_t previous) noexcept
{
constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
if constexpr (round_i < wnaf_entries - 1) {
uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
// uint64_t slice = get_wnaf_bits_const<wnaf_bits, round_i * wnaf_bits>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i)] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
wnaf_round_packed<wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
} else {
constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits;
uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
// uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[1] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index);
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
}
}
template <size_t num_points, size_t wnaf_bits>
inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept
{
skew_map = ((scalar[0] & 1) == 0);
uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + static_cast<uint64_t>(skew_map);
wnaf_round<num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
}
template <size_t num_bits, size_t num_points, size_t wnaf_bits>
inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept
{
skew_map = ((scalar[0] & 1) == 0);
uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + static_cast<uint64_t>(skew_map);
wnaf_round<num_bits, num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
}
template <size_t scalar_bits, size_t num_points, size_t wnaf_bits, size_t round_i>
inline void wnaf_round_with_restricted_first_slice(uint64_t* scalar,
uint64_t* wnaf,
const uint64_t point_index,
const uint64_t previous) noexcept
{
constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits;
constexpr auto log2_num_points = static_cast<uint64_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
constexpr size_t bits_in_first_slice = scalar_bits % wnaf_bits;
if constexpr (round_i == 1) {
uint64_t slice = get_wnaf_bits_const<wnaf_bits, (round_i - 1) * wnaf_bits + bits_in_first_slice>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i) << log2_num_points] =
((((previous - (predicate << (bits_in_first_slice /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) |
(predicate << 31UL)) |
(point_index << 32UL);
wnaf_round_with_restricted_first_slice<scalar_bits, num_points, wnaf_bits, round_i + 1>(
scalar, wnaf, point_index, slice + predicate);
} else if constexpr (round_i < wnaf_entries - 1) {
uint64_t slice = get_wnaf_bits_const<wnaf_bits, (round_i - 1) * wnaf_bits + bits_in_first_slice>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[(wnaf_entries - round_i) << log2_num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index << 32UL);
wnaf_round_with_restricted_first_slice<scalar_bits, num_points, wnaf_bits, round_i + 1>(
scalar, wnaf, point_index, slice + predicate);
} else {
uint64_t slice = get_wnaf_bits_const<wnaf_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
uint64_t predicate = ((slice & 1UL) == 0UL);
wnaf[num_points] =
((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
(point_index << 32UL);
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
}
}
template <size_t num_bits, size_t num_points, size_t wnaf_bits>
inline void fixed_wnaf_with_restricted_first_slice(uint64_t* scalar,
uint64_t* wnaf,
bool& skew_map,
const size_t point_index) noexcept
{
constexpr size_t bits_in_first_slice = num_bits % wnaf_bits;
std::cerr << "bits in first slice = " << bits_in_first_slice << std::endl;
skew_map = ((scalar[0] & 1) == 0);
uint64_t previous = get_wnaf_bits_const<bits_in_first_slice, 0>(scalar) + static_cast<uint64_t>(skew_map);
std::cerr << "previous = " << previous << std::endl;
wnaf_round_with_restricted_first_slice<num_bits, num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
}
// template <size_t wnaf_bits>
// inline void fixed_wnaf_packed(const uint64_t* scalar,
// uint64_t* wnaf,
// bool& skew_map,
// const uint64_t point_index) noexcept
// {
// skew_map = ((scalar[0] & 1) == 0);
// uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + (uint64_t)skew_map;
// wnaf_round_packed<wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
// }
// template <size_t wnaf_bits>
// inline constexpr std::array<uint32_t, WNAF_SIZE(wnaf_bits)> fixed_wnaf(const uint64_t *scalar) const noexcept
// {
// bool skew_map = ((scalar[0] * 1) == 0);
// uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + (uint64_t)skew_map;
// std::array<uint32_t, WNAF_SIZE(wnaf_bits)> result;
// }
} // namespace bb::wnaf
// NOLINTEND(readability-implicit-bool-conversion)

View File

@@ -0,0 +1,9 @@
#include <iostream>
extern "C" {
void logstr(char const* str)
{
std::cerr << str << std::endl;
}
}

View File

@@ -0,0 +1 @@
void logstr(char const*);

View File

@@ -0,0 +1,17 @@
#include "count_leading_zeros.hpp"
#include <benchmark/benchmark.h>
using namespace benchmark;
void count_leading_zeros(State& state) noexcept
{
uint256_t input = 7;
for (auto _ : state) {
auto r = count_leading_zeros(input);
DoNotOptimize(r);
}
}
BENCHMARK(count_leading_zeros);
// NOLINTNEXTLINE macro invocation triggers style errors from googletest code
BENCHMARK_MAIN();

View File

@@ -0,0 +1,52 @@
#pragma once
#include "../uint128/uint128.hpp"
#include "../uint256/uint256.hpp"
#include <cstdint>
namespace bb::numeric {
/**
* Returns the number of leading 0 bits for a given integer type.
* Implemented in terms of intrinsics which will use instructions such as `bsr` or `lzcnt` for best performance.
* Undefined behavior when input is 0.
*/
template <typename T> constexpr inline size_t count_leading_zeros(T const& u);
template <> constexpr inline size_t count_leading_zeros<uint32_t>(uint32_t const& u)
{
return static_cast<size_t>(__builtin_clz(u));
}
template <> constexpr inline size_t count_leading_zeros<uint64_t>(uint64_t const& u)
{
return static_cast<size_t>(__builtin_clzll(u));
}
template <> constexpr inline size_t count_leading_zeros<uint128_t>(uint128_t const& u)
{
auto hi = static_cast<uint64_t>(u >> 64);
if (hi != 0U) {
return static_cast<size_t>(__builtin_clzll(hi));
}
auto lo = static_cast<uint64_t>(u);
return static_cast<size_t>(__builtin_clzll(lo)) + 64;
}
template <> constexpr inline size_t count_leading_zeros<uint256_t>(uint256_t const& u)
{
if (u.data[3] != 0U) {
return count_leading_zeros(u.data[3]);
}
if (u.data[2] != 0U) {
return count_leading_zeros(u.data[2]) + 64;
}
if (u.data[1] != 0U) {
return count_leading_zeros(u.data[1]) + 128;
}
if (u.data[0] != 0U) {
return count_leading_zeros(u.data[0]) + 192;
}
return 256;
}
} // namespace bb::numeric
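// Example (illustrative): the bit-length of a non-zero N-bit value x is
// N - count_leading_zeros(x); e.g. for uint256_t x = 0x80,
// 256 - count_leading_zeros(x) == 256 - 248 == 8.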

View File

@@ -0,0 +1,36 @@
#include "count_leading_zeros.hpp"
#include <gtest/gtest.h>
using namespace bb;
TEST(bitop, ClzUint3231)
{
uint32_t a = 0b00000000000000000000000000000001;
EXPECT_EQ(numeric::count_leading_zeros(a), 31U);
}
TEST(bitop, ClzUint320)
{
uint32_t a = 0b10000000000000000000000000000001;
EXPECT_EQ(numeric::count_leading_zeros(a), 0U);
}
TEST(bitop, ClzUint640)
{
uint64_t a = 0b1000000000000000000000000000000100000000000000000000000000000000;
EXPECT_EQ(numeric::count_leading_zeros(a), 0U);
}
TEST(bitop, ClzUint256255)
{
uint256_t a = 0x1;
auto r = numeric::count_leading_zeros(a);
EXPECT_EQ(r, 255U);
}
TEST(bitop, ClzUint256248)
{
uint256_t a = 0x80;
auto r = numeric::count_leading_zeros(a);
EXPECT_EQ(r, 248U);
}

View File

@@ -0,0 +1,46 @@
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
namespace bb::numeric {
// from http://supertech.csail.mit.edu/papers/debruijn.pdf
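// The trick: the shift-or cascade smears the highest set bit downwards, turning the input
// into 2^(msb+1) - 1; multiplying by the de Bruijn constant then places a unique 5-bit
// (resp. 6-bit) pattern in the top bits, which the lookup table maps back to the msb index.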
constexpr inline uint32_t get_msb32(const uint32_t in)
{
constexpr std::array<uint8_t, 32> MultiplyDeBruijnBitPosition{ 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16,
18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17,
24, 7, 19, 27, 23, 6, 26, 5, 4, 31 };
uint32_t v = in | (in >> 1);
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
return MultiplyDeBruijnBitPosition[static_cast<uint32_t>(v * static_cast<uint32_t>(0x07C4ACDD)) >>
static_cast<uint32_t>(27)];
}
constexpr inline uint64_t get_msb64(const uint64_t in)
{
constexpr std::array<uint8_t, 64> de_bruijn_sequence{ 0, 47, 1, 56, 48, 27, 2, 60, 57, 49, 41, 37, 28,
16, 3, 61, 54, 58, 35, 52, 50, 42, 21, 44, 38, 32,
29, 23, 17, 11, 4, 62, 46, 55, 26, 59, 40, 36, 15,
53, 34, 51, 20, 43, 31, 22, 10, 45, 25, 39, 14, 33,
19, 30, 9, 24, 13, 18, 8, 12, 7, 6, 5, 63 };
uint64_t t = in | (in >> 1);
t |= t >> 2;
t |= t >> 4;
t |= t >> 8;
t |= t >> 16;
t |= t >> 32;
return static_cast<uint64_t>(de_bruijn_sequence[(t * 0x03F79D71B4CB0A89ULL) >> 58ULL]);
};
template <typename T> constexpr inline T get_msb(const T in)
{
return (sizeof(T) <= 4) ? get_msb32(in) : get_msb64(in);
}
} // namespace bb::numeric

View File

@@ -0,0 +1,35 @@
#include "get_msb.hpp"
#include <gtest/gtest.h>
using namespace bb;
TEST(bitop, GetMsbUint640Value)
{
uint64_t a = 0b00000000000000000000000000000000;
EXPECT_EQ(numeric::get_msb(a), 0U);
}
TEST(bitop, GetMsbUint320)
{
uint32_t a = 0b00000000000000000000000000000001;
EXPECT_EQ(numeric::get_msb(a), 0U);
}
TEST(bitop, GetMsbUint3231)
{
uint32_t a = 0b10000000000000000000000000000001;
EXPECT_EQ(numeric::get_msb(a), 31U);
}
TEST(bitop, GetMsbUint6463)
{
uint64_t a = 0b1000000000000000000000000000000100000000000000000000000000000000;
EXPECT_EQ(numeric::get_msb(a), 63U);
}
TEST(bitop, GetMsbSizeT7)
{
size_t a = 0x80;
auto r = numeric::get_msb(a);
EXPECT_EQ(r, 7U);
}

View File

@@ -0,0 +1,11 @@
#pragma once
#include <cstddef>
namespace bb::numeric {
template <typename T> inline T keep_n_lsb(T const& input, size_t num_bits)
{
return num_bits >= sizeof(T) * 8 ? input : input & ((T(1) << num_bits) - 1);
}
} // namespace bb::numeric

View File

@@ -0,0 +1,34 @@
#pragma once
#include "./get_msb.hpp"
#include <cstdint>
namespace bb::numeric {
constexpr uint64_t pow64(const uint64_t input, const uint64_t exponent)
{
if (input == 0) {
return 0;
}
if (exponent == 0) {
return 1;
}
uint64_t accumulator = input;
uint64_t to_mul = input;
const uint64_t maximum_set_bit = get_msb64(exponent);
for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
accumulator *= accumulator;
if (((exponent >> i) & 1) != 0U) {
accumulator *= to_mul;
}
}
return accumulator;
}
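// Worked example (illustrative): pow64(3, 5) — exponent 0b101, msb = 2. Starting from
// accumulator = 3: i = 1: square to 9 (bit 1 clear, no multiply); i = 0: square to 81,
// bit 0 set, multiply by 3 -> 243 = 3^5. Note the result wraps modulo 2^64 on overflow.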
constexpr bool is_power_of_two(uint64_t x)
{
return (x != 0U) && ((x & (x - 1)) == 0U);
}
} // namespace bb::numeric

View File

@@ -0,0 +1,16 @@
#pragma once
#include <cstddef>
#include <cstdint>
namespace bb::numeric {
constexpr inline uint64_t rotate64(const uint64_t value, const uint64_t rotation)
{
return rotation != 0U ? (value >> rotation) + (value << (64 - rotation)) : value;
}
constexpr inline uint32_t rotate32(const uint32_t value, const uint32_t rotation)
{
return rotation != 0U ? (value >> rotation) + (value << (32 - rotation)) : value;
}
} // namespace bb::numeric

View File

@@ -0,0 +1,157 @@
#pragma once
#include "../../common/throw_or_abort.hpp"
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>
#include "../uint256/uint256.hpp"
namespace bb::numeric {
inline std::vector<uint64_t> slice_input(const uint256_t& input, const uint64_t base, const size_t num_slices)
{
uint256_t target = input;
std::vector<uint64_t> slices;
if (num_slices > 0) {
for (size_t i = 0; i < num_slices; ++i) {
slices.push_back((target % base).data[0]);
target /= base;
}
} else {
while (target > 0) {
slices.push_back((target % base).data[0]);
target /= base;
}
}
return slices;
}
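// Example (illustrative): slice_input(1000, 16, 4) returns the base-16 digits of 1000
// least-significant first: { 8, 14, 3, 0 }, since 1000 = 0x3e8. With num_slices = 0 the
// loop instead emits digits until the value is exhausted: { 8, 14, 3 }.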
inline std::vector<uint64_t> slice_input_using_variable_bases(const uint256_t& input,
const std::vector<uint64_t>& bases)
{
uint256_t target = input;
std::vector<uint64_t> slices;
for (size_t i = 0; i < bases.size(); ++i) {
if (target >= bases[i] && i == bases.size() - 1) {
throw_or_abort(format("Last key slice greater than ", bases[i]));
}
slices.push_back((target % bases[i]).data[0]);
target /= bases[i];
}
return slices;
}
template <uint64_t base, uint64_t num_slices> constexpr std::array<uint256_t, num_slices> get_base_powers()
{
std::array<uint256_t, num_slices> output{};
output[0] = 1;
for (size_t i = 1; i < num_slices; ++i) {
output[i] = output[i - 1] * base;
}
return output;
}
template <uint64_t base> constexpr uint256_t map_into_sparse_form(const uint64_t input)
{
uint256_t out = 0UL;
auto converted = input;
constexpr auto base_powers = get_base_powers<base, 32>();
for (size_t i = 0; i < 32; ++i) {
uint64_t sparse_bit = ((converted >> i) & 1U);
if (sparse_bit) {
out += base_powers[i];
}
}
return out;
}
template <uint64_t base> constexpr uint64_t map_from_sparse_form(const uint256_t& input)
{
uint256_t target = input;
uint64_t output = 0;
constexpr auto bases = get_base_powers<base, 32>();
for (uint64_t i = 0; i < 32; ++i) {
const auto& base_power = bases[static_cast<size_t>(31 - i)];
uint256_t prev_threshold = 0;
for (uint64_t j = 1; j < base + 1; ++j) {
const auto threshold = prev_threshold + base_power;
if (target < threshold) {
bool bit = ((j - 1) & 1);
if (bit) {
output += (1ULL << (31ULL - i));
}
if (j > 1) {
target -= (prev_threshold);
}
break;
}
prev_threshold = threshold;
}
}
return output;
}
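// Example (illustrative): with base = 4, map_into_sparse_form<4>(0b1011) places each input
// bit into its own base-4 digit: 1*4^0 + 1*4^1 + 0*4^2 + 1*4^3 = 69, and
// map_from_sparse_form<4>(69) recovers 0b1011. Because each digit can absorb several
// additions before overflowing into its neighbour, sums of sparse values keep per-bit
// information separated, which is useful when accumulating bitwise operations without carries.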
template <uint64_t base, size_t num_bits> class sparse_int {
public:
sparse_int(const uint64_t input = 0)
: value(input)
{
for (size_t i = 0; i < num_bits; ++i) {
const uint64_t bit = (input >> i) & 1U;
limbs[i] = bit;
}
}
sparse_int(const sparse_int& other) noexcept = default;
sparse_int(sparse_int&& other) noexcept = default;
sparse_int& operator=(const sparse_int& other) noexcept = default;
sparse_int& operator=(sparse_int&& other) noexcept = default;
~sparse_int() noexcept = default;
sparse_int operator+(const sparse_int& other) const
{
sparse_int result(*this);
for (size_t i = 0; i < num_bits - 1; ++i) {
result.limbs[i] += other.limbs[i];
if (result.limbs[i] >= base) {
result.limbs[i] -= base;
++result.limbs[i + 1];
}
}
result.limbs[num_bits - 1] += other.limbs[num_bits - 1];
result.limbs[num_bits - 1] %= base;
result.value += other.value;
return result;
};
sparse_int operator+=(const sparse_int& other)
{
*this = *this + other;
return *this;
}
[[nodiscard]] uint64_t get_value() const { return value; }
[[nodiscard]] uint64_t get_sparse_value() const
{
uint64_t result = 0;
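// iterate from the most significant limb down to limb 0; decrementing past zero wraps the
// unsigned counter above num_bits, ending the loop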
for (size_t i = num_bits - 1; i < num_bits; --i) {
result *= base;
result += limbs[i];
}
return result;
}
const std::array<uint64_t, num_bits>& get_limbs() const { return limbs; }
private:
std::array<uint64_t, num_bits> limbs;
uint64_t value;
uint64_t sparse_value;
};
} // namespace bb::numeric

View File

@@ -0,0 +1,139 @@
#include "engine.hpp"
#include "../../common/assert.hpp"
#include <array>
#include <functional>
#include <random>
namespace bb::numeric {
namespace {
auto generate_random_data()
{
std::array<unsigned int, 32> random_data;
std::random_device source;
std::generate(std::begin(random_data), std::end(random_data), std::ref(source));
return random_data;
}
} // namespace
class RandomEngine : public RNG {
public:
uint8_t get_random_uint8() override
{
auto buf = generate_random_data();
uint32_t out = buf[0];
return static_cast<uint8_t>(out);
}
uint16_t get_random_uint16() override
{
auto buf = generate_random_data();
uint32_t out = buf[0];
return static_cast<uint16_t>(out);
}
uint32_t get_random_uint32() override
{
auto buf = generate_random_data();
uint32_t out = buf[0];
return static_cast<uint32_t>(out);
}
uint64_t get_random_uint64() override
{
auto buf = generate_random_data();
auto lo = static_cast<uint64_t>(buf[0]);
auto hi = static_cast<uint64_t>(buf[1]);
return (lo + (hi << 32ULL));
}
uint128_t get_random_uint128() override
{
auto big = get_random_uint256();
auto lo = static_cast<uint128_t>(big.data[0]);
auto hi = static_cast<uint128_t>(big.data[1]);
return (lo + (hi << static_cast<uint128_t>(64ULL)));
}
uint256_t get_random_uint256() override
{
const auto get64 = [](const std::array<uint32_t, 32>& buffer, const size_t offset) {
auto lo = static_cast<uint64_t>(buffer[0 + offset]);
auto hi = static_cast<uint64_t>(buffer[1 + offset]);
return (lo + (hi << 32ULL));
};
auto buf = generate_random_data();
uint64_t lolo = get64(buf, 0);
uint64_t lohi = get64(buf, 2);
uint64_t hilo = get64(buf, 4);
uint64_t hihi = get64(buf, 6);
return { lolo, lohi, hilo, hihi };
}
};
class DebugEngine : public RNG {
public:
DebugEngine()
// disable linting for this line: we want the DEBUG engine to produce predictable pseudorandom numbers!
// NOLINTNEXTLINE(cert-msc32-c, cert-msc51-cpp)
: engine(std::mt19937_64(12345))
{}
DebugEngine(std::uint_fast64_t seed)
: engine(std::mt19937_64(seed))
{}
uint8_t get_random_uint8() override { return static_cast<uint8_t>(dist(engine)); }
uint16_t get_random_uint16() override { return static_cast<uint16_t>(dist(engine)); }
uint32_t get_random_uint32() override { return static_cast<uint32_t>(dist(engine)); }
uint64_t get_random_uint64() override { return dist(engine); }
uint128_t get_random_uint128() override
{
uint128_t hi = dist(engine);
uint128_t lo = dist(engine);
return (hi << 64) | lo;
}
uint256_t get_random_uint256() override
{
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
auto a = dist(engine);
auto b = dist(engine);
auto c = dist(engine);
auto d = dist(engine);
return { a, b, c, d };
}
private:
std::mt19937_64 engine;
std::uniform_int_distribution<uint64_t> dist = std::uniform_int_distribution<uint64_t>{ 0ULL, UINT64_MAX };
};
/**
* Used by tests to ensure consistent behavior.
*/
RNG& get_debug_randomness(bool reset, std::uint_fast64_t seed)
{
// static std::seed_seq seed({ 1, 2, 3, 4, 5 });
static DebugEngine debug_engine = DebugEngine();
if (reset) {
debug_engine = DebugEngine(seed);
}
return debug_engine;
}
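/**
 * Example usage (a sketch): reproducible randomness in a test.
 *
 *   auto& engine = bb::numeric::get_debug_randomness(true, 42);
 *   uint256_t a = engine.get_random_uint256();
 *   bb::numeric::get_debug_randomness(true, 42); // re-seed with the same value
 *   uint256_t b = engine.get_random_uint256();
 *   // a == b: the same seed reproduces the same sequence
 */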
/**
* Default engine. If wanting consistent proof construction, uncomment the line to return the debug engine.
*/
RNG& get_randomness()
{
// return get_debug_randomness();
static RandomEngine engine;
return engine;
}
} // namespace bb::numeric

View File

@@ -0,0 +1,52 @@
#pragma once
#include "../uint128/uint128.hpp"
#include "../uint256/uint256.hpp"
#include "../uintx/uintx.hpp"
#include "unistd.h"
#include <cstdint>
#include <random>
namespace bb::numeric {
class RNG {
public:
virtual uint8_t get_random_uint8() = 0;
virtual uint16_t get_random_uint16() = 0;
virtual uint32_t get_random_uint32() = 0;
virtual uint64_t get_random_uint64() = 0;
virtual uint128_t get_random_uint128() = 0;
virtual uint256_t get_random_uint256() = 0;
virtual ~RNG() = default;
RNG() noexcept = default;
RNG(const RNG& other) = default;
RNG(RNG&& other) = default;
RNG& operator=(const RNG& other) = default;
RNG& operator=(RNG&& other) = default;
uint512_t get_random_uint512()
{
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
auto lo = get_random_uint256();
auto hi = get_random_uint256();
return { lo, hi };
}
uint1024_t get_random_uint1024()
{
// Do not inline in constructor call. Evaluation order is important for cross-compiler consistency.
auto lo = get_random_uint512();
auto hi = get_random_uint512();
return { lo, hi };
}
};
RNG& get_debug_randomness(bool reset = false, std::uint_fast64_t seed = 12345);
RNG& get_randomness();
} // namespace bb::numeric

View File

@@ -0,0 +1,212 @@
#pragma once
#include <cstdint>
#include <iomanip>
#include <ostream>
#ifdef __i386__
#include "../../common/serialize.hpp"
#include <concepts>
namespace bb::numeric {
class alignas(32) uint128_t {
public:
uint32_t data[4]; // NOLINT
constexpr uint128_t(const uint64_t a = 0)
: data{ static_cast<uint32_t>(a), static_cast<uint32_t>(a >> 32), 0, 0 }
{}
constexpr uint128_t(const uint32_t a, const uint32_t b, const uint32_t c, const uint32_t d)
: data{ a, b, c, d }
{}
constexpr uint128_t(const uint128_t& other)
: data{ other.data[0], other.data[1], other.data[2], other.data[3] }
{}
constexpr uint128_t(uint128_t&& other) = default;
static constexpr uint128_t from_uint64(const uint64_t a)
{
return { static_cast<uint32_t>(a), static_cast<uint32_t>(a >> 32), 0, 0 };
}
constexpr explicit operator uint64_t() { return (static_cast<uint64_t>(data[1]) << 32) + data[0]; }
constexpr uint128_t& operator=(const uint128_t& other) = default;
constexpr uint128_t& operator=(uint128_t&& other) = default;
constexpr ~uint128_t() = default;
explicit constexpr operator bool() const { return static_cast<bool>(data[0]); };
template <std::integral T> explicit constexpr operator T() const { return static_cast<T>(data[0]); };
[[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const;
[[nodiscard]] constexpr uint64_t get_msb() const;
[[nodiscard]] constexpr uint128_t slice(uint64_t start, uint64_t end) const;
[[nodiscard]] constexpr uint128_t pow(const uint128_t& exponent) const;
constexpr uint128_t operator+(const uint128_t& other) const;
constexpr uint128_t operator-(const uint128_t& other) const;
constexpr uint128_t operator-() const;
constexpr uint128_t operator*(const uint128_t& other) const;
constexpr uint128_t operator/(const uint128_t& other) const;
constexpr uint128_t operator%(const uint128_t& other) const;
constexpr uint128_t operator>>(const uint128_t& other) const;
constexpr uint128_t operator<<(const uint128_t& other) const;
constexpr uint128_t operator&(const uint128_t& other) const;
constexpr uint128_t operator^(const uint128_t& other) const;
constexpr uint128_t operator|(const uint128_t& other) const;
constexpr uint128_t operator~() const;
constexpr bool operator==(const uint128_t& other) const;
constexpr bool operator!=(const uint128_t& other) const;
constexpr bool operator!() const;
constexpr bool operator>(const uint128_t& other) const;
constexpr bool operator<(const uint128_t& other) const;
constexpr bool operator>=(const uint128_t& other) const;
constexpr bool operator<=(const uint128_t& other) const;
static constexpr size_t length() { return 128; }
constexpr uint128_t& operator+=(const uint128_t& other)
{
*this = *this + other;
return *this;
};
constexpr uint128_t& operator-=(const uint128_t& other)
{
*this = *this - other;
return *this;
};
constexpr uint128_t& operator*=(const uint128_t& other)
{
*this = *this * other;
return *this;
};
constexpr uint128_t& operator/=(const uint128_t& other)
{
*this = *this / other;
return *this;
};
constexpr uint128_t& operator%=(const uint128_t& other)
{
*this = *this % other;
return *this;
};
constexpr uint128_t& operator++()
{
*this += uint128_t(1);
return *this;
};
constexpr uint128_t& operator--()
{
*this -= uint128_t(1);
return *this;
};
constexpr uint128_t& operator&=(const uint128_t& other)
{
*this = *this & other;
return *this;
};
constexpr uint128_t& operator^=(const uint128_t& other)
{
*this = *this ^ other;
return *this;
};
constexpr uint128_t& operator|=(const uint128_t& other)
{
*this = *this | other;
return *this;
};
constexpr uint128_t& operator>>=(const uint128_t& other)
{
*this = *this >> other;
return *this;
};
constexpr uint128_t& operator<<=(const uint128_t& other)
{
*this = *this << other;
return *this;
};
[[nodiscard]] constexpr std::pair<uint128_t, uint128_t> mul_extended(const uint128_t& other) const;
[[nodiscard]] constexpr std::pair<uint128_t, uint128_t> divmod(const uint128_t& b) const;
private:
[[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> mul_wide(uint32_t a, uint32_t b);
[[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> addc(uint32_t a, uint32_t b, uint32_t carry_in);
[[nodiscard]] static constexpr uint32_t addc_discard_hi(uint32_t a, uint32_t b, uint32_t carry_in);
[[nodiscard]] static constexpr uint32_t sbb_discard_hi(uint32_t a, uint32_t b, uint32_t borrow_in);
[[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> sbb(uint32_t a, uint32_t b, uint32_t borrow_in);
[[nodiscard]] static constexpr uint32_t mac_discard_hi(uint32_t a, uint32_t b, uint32_t c, uint32_t carry_in);
[[nodiscard]] static constexpr std::pair<uint32_t, uint32_t> mac(uint32_t a,
uint32_t b,
uint32_t c,
uint32_t carry_in);
};
inline std::ostream& operator<<(std::ostream& os, uint128_t const& a)
{
std::ios_base::fmtflags f(os.flags());
os << std::hex << "0x" << std::setfill('0') << std::setw(8) << a.data[3] << std::setw(8) << a.data[2]
<< std::setw(8) << a.data[1] << std::setw(8) << a.data[0];
os.flags(f);
return os;
}
template <typename B> inline void read(B& it, uint128_t& value)
{
using serialize::read;
uint32_t a = 0;
uint32_t b = 0;
uint32_t c = 0;
uint32_t d = 0;
read(it, d);
read(it, c);
read(it, b);
read(it, a);
value = uint128_t(a, b, c, d);
}
template <typename B> inline void write(B& it, uint128_t const& value)
{
using serialize::write;
write(it, value.data[3]);
write(it, value.data[2]);
write(it, value.data[1]);
write(it, value.data[0]);
}
} // namespace bb::numeric
#include "./uint128_impl.hpp"
// disable linter errors; we want to expose a global uint128_t type to mimic uint64_t, uint32_t etc
// NOLINTNEXTLINE(google-global-names-in-headers, misc-unused-using-decls)
using bb::numeric::uint128_t;
#else
__extension__ using uint128_t = unsigned __int128;
namespace std {
// can ignore linter error for streaming operations, we need to add to std namespace to support printing this type!
// NOLINTNEXTLINE(cert-dcl58-cpp)
inline std::ostream& operator<<(std::ostream& os, uint128_t const& a)
{
std::ios_base::fmtflags f(os.flags());
os << std::hex << "0x" << std::setfill('0') << std::setw(16) << static_cast<uint64_t>(a >> 64) << std::setw(16)
<< static_cast<uint64_t>(a);
os.flags(f);
return os;
}
} // namespace std
#endif

View File

@@ -0,0 +1,414 @@
#ifdef __i386__
#pragma once
#include "../bitop/get_msb.hpp"
#include "./uint128.hpp"
#include "../../common/assert.hpp"
namespace bb::numeric {
constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
{
const uint32_t a_lo = a & 0xffffULL;
const uint32_t a_hi = a >> 16ULL;
const uint32_t b_lo = b & 0xffffULL;
const uint32_t b_hi = b >> 16ULL;
const uint32_t lo_lo = a_lo * b_lo;
const uint32_t hi_lo = a_hi * b_lo;
const uint32_t lo_hi = a_lo * b_hi;
const uint32_t hi_hi = a_hi * b_hi;
const uint32_t cross = (lo_lo >> 16) + (hi_lo & 0xffffULL) + lo_hi;
return { (cross << 16ULL) | (lo_lo & 0xffffULL), (hi_lo >> 16ULL) + (cross >> 16ULL) + hi_hi };
}
// compute a + b + carry, returning the carry
constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
{
const uint32_t sum = a + b;
const auto carry_temp = static_cast<uint32_t>(sum < a);
const uint32_t r = sum + carry_in;
const uint32_t carry_out = carry_temp + static_cast<unsigned int>(r < carry_in);
return { r, carry_out };
}
constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
{
return a + b + carry_in;
}
constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
{
const uint32_t t_1 = a - (borrow_in >> 31ULL);
const auto borrow_temp_1 = static_cast<uint32_t>(t_1 > a);
const uint32_t t_2 = t_1 - b;
const auto borrow_temp_2 = static_cast<uint32_t>(t_2 > t_1);
return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
}
constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
{
return a - b - (borrow_in >> 31ULL);
}
// {r, carry_out} = a + carry_in + b * c
constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
const uint32_t b,
const uint32_t c,
const uint32_t carry_in)
{
std::pair<uint32_t, uint32_t> result = mul_wide(b, c);
result.first += a;
const auto overflow_c = static_cast<uint32_t>(result.first < a);
result.first += carry_in;
const auto overflow_carry = static_cast<uint32_t>(result.first < carry_in);
result.second += (overflow_c + overflow_carry);
return result;
}
constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
const uint32_t b,
const uint32_t c,
const uint32_t carry_in)
{
return (b * c + a + carry_in);
}
constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
{
if (*this == 0 || b == 0) {
return { 0, 0 };
}
if (b == 1) {
return { *this, 0 };
}
if (*this == b) {
return { 1, 0 };
}
if (b > *this) {
return { 0, *this };
}
uint128_t quotient = 0;
uint128_t remainder = *this;
uint64_t bit_difference = get_msb() - b.get_msb();
uint128_t divisor = b << bit_difference;
uint128_t accumulator = uint128_t(1) << bit_difference;
// if the divisor is bigger than the remainder, a and b have the same bit length
if (divisor > remainder) {
divisor >>= 1;
accumulator >>= 1;
}
// while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
// and add to the quotient
while (remainder >= b) {
// we've shunted 'divisor' up to have the same bit length as our remainder.
// If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
if (remainder >= divisor) {
remainder -= divisor;
// we can use OR here instead of +, as
// accumulator is always a nice power of two
quotient |= accumulator;
}
divisor >>= 1;
accumulator >>= 1;
}
return { quotient, remainder };
}
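// Worked example of the shift-subtract loop above: 100 / 7. get_msb(100) = 6 and
// get_msb(7) = 2, so divisor starts at 7 << 4 = 112 > 100 and is immediately halved
// to 56 (accumulator = 8). The loop then peels off shifted multiples of 7:
//   remainder 100 >= 56 -> remainder = 44, quotient |= 8
//   remainder  44 >= 28 -> remainder = 16, quotient |= 4
//   remainder  16 >= 14 -> remainder =  2, quotient |= 2
//   remainder   2 <  7  -> loop exits
// giving quotient = 14, remainder = 2, i.e. 100 = 14 * 7 + 2.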
constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
{
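// Schoolbook 4x4 limb multiplication: r0..r7 below are the eight 32-bit limbs of the
// full 256-bit product, accumulated column by column with mac (multiply-accumulate
// with carry); lo holds the low 128 bits of the product and hi the high 128 bits.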
const auto [r0, t0] = mul_wide(data[0], other.data[0]);
const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);
const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);
const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);
const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);
uint128_t lo(r0, r1, r2, r3);
uint128_t hi(r4, r5, r6, r7);
return { lo, hi };
}
/**
* Viewing `this` uint128_t as a bit string, and counting bits from 0, slices a substring.
* @returns the uint128_t equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the
* `end`-th bit of `this`.
*/
constexpr uint128_t uint128_t::slice(const uint64_t start, const uint64_t end) const
{
const uint64_t range = end - start;
const uint128_t mask = (range == 128) ? -uint128_t(1) : (uint128_t(1) << range) - 1;
return ((*this) >> start) & mask;
}
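// e.g. x.slice(32, 64) returns bits 32..63 of x, i.e. the value of limb data[1].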
constexpr uint128_t uint128_t::pow(const uint128_t& exponent) const
{
uint128_t accumulator{ data[0], data[1], data[2], data[3] };
uint128_t to_mul{ data[0], data[1], data[2], data[3] };
const uint64_t maximum_set_bit = exponent.get_msb();
for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
accumulator *= accumulator;
if (exponent.get_bit(static_cast<uint64_t>(i))) {
accumulator *= to_mul;
}
}
if (exponent == uint128_t(0)) {
accumulator = uint128_t(1);
} else if (*this == uint128_t(0)) {
accumulator = uint128_t(0);
}
return accumulator;
}
constexpr bool uint128_t::get_bit(const uint64_t bit_index) const
{
ASSERT(bit_index < 128);
if (bit_index > 127) {
return false;
}
const auto idx = static_cast<size_t>(bit_index >> 5);
const size_t shift = bit_index & 31;
return static_cast<bool>((data[idx] >> shift) & 1);
}
constexpr uint64_t uint128_t::get_msb() const
{
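// Cascade down from the top limb: a step falls through to the next-lower limb only
// while every limb above it is zero; otherwise it adds 32 per remaining step, so the
// msb of the highest non-zero limb is offset into the full 128-bit index range.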
uint64_t idx = numeric::get_msb64(data[3]);
idx = (idx == 0 && data[3] == 0) ? numeric::get_msb64(data[2]) : idx + 32;
idx = (idx == 0 && data[2] == 0) ? numeric::get_msb64(data[1]) : idx + 32;
idx = (idx == 0 && data[1] == 0) ? numeric::get_msb64(data[0]) : idx + 32;
return idx;
}
constexpr uint128_t uint128_t::operator+(const uint128_t& other) const
{
const auto [r0, t0] = addc(data[0], other.data[0], 0);
const auto [r1, t1] = addc(data[1], other.data[1], t0);
const auto [r2, t2] = addc(data[2], other.data[2], t1);
const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
return { r0, r1, r2, r3 };
};
constexpr uint128_t uint128_t::operator-(const uint128_t& other) const
{
const auto [r0, t0] = sbb(data[0], other.data[0], 0);
const auto [r1, t1] = sbb(data[1], other.data[1], t0);
const auto [r2, t2] = sbb(data[2], other.data[2], t1);
const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
return { r0, r1, r2, r3 };
}
constexpr uint128_t uint128_t::operator-() const
{
return uint128_t(0) - *this;
}
constexpr uint128_t uint128_t::operator*(const uint128_t& other) const
{
const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);
const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);
const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);
const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);
return { r0, r1, r2, r3 };
}
constexpr uint128_t uint128_t::operator/(const uint128_t& other) const
{
return divmod(other).first;
}
constexpr uint128_t uint128_t::operator%(const uint128_t& other) const
{
return divmod(other).second;
}
constexpr uint128_t uint128_t::operator&(const uint128_t& other) const
{
return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
}
constexpr uint128_t uint128_t::operator^(const uint128_t& other) const
{
return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
}
constexpr uint128_t uint128_t::operator|(const uint128_t& other) const
{
return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
}
constexpr uint128_t uint128_t::operator~() const
{
return { ~data[0], ~data[1], ~data[2], ~data[3] };
}
constexpr bool uint128_t::operator==(const uint128_t& other) const
{
return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
}
constexpr bool uint128_t::operator!=(const uint128_t& other) const
{
return !(*this == other);
}
constexpr bool uint128_t::operator!() const
{
return *this == uint128_t(0ULL);
}
constexpr bool uint128_t::operator>(const uint128_t& other) const
{
bool t0 = data[3] > other.data[3];
bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
bool t3 =
data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
return t0 || t1 || t2 || t3;
}
constexpr bool uint128_t::operator>=(const uint128_t& other) const
{
return (*this > other) || (*this == other);
}
constexpr bool uint128_t::operator<(const uint128_t& other) const
{
return other > *this;
}
constexpr bool uint128_t::operator<=(const uint128_t& other) const
{
return (*this < other) || (*this == other);
}
constexpr uint128_t uint128_t::operator>>(const uint128_t& other) const
{
uint32_t total_shift = other.data[0];
if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
return 0;
}
if (total_shift == 0) {
return *this;
}
uint32_t num_shifted_limbs = total_shift >> 5ULL;
uint32_t limb_shift = total_shift & 31ULL;
std::array<uint32_t, 4> shifted_limbs = { 0, 0, 0, 0 };
if (limb_shift == 0) {
shifted_limbs[0] = data[0];
shifted_limbs[1] = data[1];
shifted_limbs[2] = data[2];
shifted_limbs[3] = data[3];
} else {
uint32_t remainder_shift = 32ULL - limb_shift;
shifted_limbs[3] = data[3] >> limb_shift;
uint32_t remainder = (data[3]) << remainder_shift;
shifted_limbs[2] = (data[2] >> limb_shift) + remainder;
remainder = (data[2]) << remainder_shift;
shifted_limbs[1] = (data[1] >> limb_shift) + remainder;
remainder = (data[1]) << remainder_shift;
shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
}
uint128_t result(0);
for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
}
return result;
}
constexpr uint128_t uint128_t::operator<<(const uint128_t& other) const
{
uint32_t total_shift = other.data[0];
if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
return 0;
}
if (total_shift == 0) {
return *this;
}
uint32_t num_shifted_limbs = total_shift >> 5ULL;
uint32_t limb_shift = total_shift & 31ULL;
std::array<uint32_t, 4> shifted_limbs{ 0, 0, 0, 0 };
if (limb_shift == 0) {
shifted_limbs[0] = data[0];
shifted_limbs[1] = data[1];
shifted_limbs[2] = data[2];
shifted_limbs[3] = data[3];
} else {
uint32_t remainder_shift = 32ULL - limb_shift;
shifted_limbs[0] = data[0] << limb_shift;
uint32_t remainder = data[0] >> remainder_shift;
shifted_limbs[1] = (data[1] << limb_shift) + remainder;
remainder = data[1] >> remainder_shift;
shifted_limbs[2] = (data[2] << limb_shift) + remainder;
remainder = data[2] >> remainder_shift;
shifted_limbs[3] = (data[3] << limb_shift) + remainder;
}
uint128_t result(0);
for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
}
return result;
}
} // namespace bb::numeric
#endif

View File

@@ -0,0 +1,239 @@
/**
* uint256_t
* Copyright Aztec 2020
*
* An unsigned 256 bit integer type.
*
* Constructor and all methods are constexpr.
* Ideally, uint256_t should be able to be treated like any other literal type.
*
* Not optimized for performance, this code doesn't touch any of our hot paths when constructing PLONK proofs.
**/
#pragma once
#include "../uint128/uint128.hpp"
#include "../../common/throw_or_abort.hpp"
#include <concepts>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <sstream>
namespace bb::numeric {
class alignas(32) uint256_t {
public:
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
#define WASM_NUM_LIMBS 9
#define WASM_LIMB_BITS 29
#endif
constexpr uint256_t(const uint64_t a = 0) noexcept
: data{ a, 0, 0, 0 }
{}
constexpr uint256_t(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept
: data{ a, b, c, d }
{}
constexpr uint256_t(const uint256_t& other) noexcept
: data{ other.data[0], other.data[1], other.data[2], other.data[3] }
{}
constexpr uint256_t(uint256_t&& other) noexcept = default;
explicit constexpr uint256_t(std::string input) noexcept
{
/* Quick and dirty conversion from a single character to its hex equivalent */
constexpr auto HexCharToInt = [](uint8_t Input) {
bool valid =
(Input >= 'a' && Input <= 'f') || (Input >= 'A' && Input <= 'F') || (Input >= '0' && Input <= '9');
if (!valid) {
throw_or_abort("Error, uint256 constructed from string_view with invalid hex parameter");
}
uint8_t res =
((Input >= 'a') && (Input <= 'f')) ? (Input - (static_cast<uint8_t>('a') - static_cast<uint8_t>(10)))
: ((Input >= 'A') && (Input <= 'F')) ? (Input - (static_cast<uint8_t>('A') - static_cast<uint8_t>(10)))
: ((Input >= '0') && (Input <= '9')) ? (Input - static_cast<uint8_t>('0'))
: 0;
return res;
};
std::array<uint64_t, 4> limbs{ 0, 0, 0, 0 };
size_t start_index = 0;
if (input.size() == 66 && input[0] == '0' && input[1] == 'x') {
start_index = 2;
} else if (input.size() != 64) {
throw_or_abort("Error, uint256 constructed from string_view with invalid length");
}
for (size_t j = 0; j < 4; ++j) {
const size_t limb_index = start_index + j * 16;
for (size_t i = 0; i < 8; ++i) {
const size_t byte_index = limb_index + (i * 2);
uint8_t nibble_hi = HexCharToInt(static_cast<uint8_t>(input[byte_index]));
uint8_t nibble_lo = HexCharToInt(static_cast<uint8_t>(input[byte_index + 1]));
uint8_t byte = static_cast<uint8_t>((nibble_hi * 16) + nibble_lo);
limbs[j] <<= 8;
limbs[j] += byte;
}
}
data[0] = limbs[3];
data[1] = limbs[2];
data[2] = limbs[1];
data[3] = limbs[0];
}
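// e.g. uint256_t("0x0000000000000000000000000000000000000000000000000000000000000001")
// yields data = { 1, 0, 0, 0 }: the string is big-endian, while data[0] holds the
// least significant limb.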
static constexpr uint256_t from_uint128(const uint128_t a) noexcept
{
return { static_cast<uint64_t>(a), static_cast<uint64_t>(a >> 64), 0, 0 };
}
constexpr explicit operator uint128_t() { return (static_cast<uint128_t>(data[1]) << 64) + data[0]; }
constexpr uint256_t& operator=(const uint256_t& other) noexcept = default;
constexpr uint256_t& operator=(uint256_t&& other) noexcept = default;
constexpr ~uint256_t() noexcept = default;
explicit constexpr operator bool() const { return static_cast<bool>(data[0]); };
template <std::integral T> explicit constexpr operator T() const { return static_cast<T>(data[0]); };
[[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const;
[[nodiscard]] constexpr uint64_t get_msb() const;
[[nodiscard]] constexpr uint256_t slice(uint64_t start, uint64_t end) const;
[[nodiscard]] constexpr uint256_t pow(const uint256_t& exponent) const;
constexpr uint256_t operator+(const uint256_t& other) const;
constexpr uint256_t operator-(const uint256_t& other) const;
constexpr uint256_t operator-() const;
constexpr uint256_t operator*(const uint256_t& other) const;
constexpr uint256_t operator/(const uint256_t& other) const;
constexpr uint256_t operator%(const uint256_t& other) const;
constexpr uint256_t operator>>(const uint256_t& other) const;
constexpr uint256_t operator<<(const uint256_t& other) const;
constexpr uint256_t operator&(const uint256_t& other) const;
constexpr uint256_t operator^(const uint256_t& other) const;
constexpr uint256_t operator|(const uint256_t& other) const;
constexpr uint256_t operator~() const;
constexpr bool operator==(const uint256_t& other) const;
constexpr bool operator!=(const uint256_t& other) const;
constexpr bool operator!() const;
constexpr bool operator>(const uint256_t& other) const;
constexpr bool operator<(const uint256_t& other) const;
constexpr bool operator>=(const uint256_t& other) const;
constexpr bool operator<=(const uint256_t& other) const;
static constexpr size_t length() { return 256; }
constexpr uint256_t& operator+=(const uint256_t& other)
{
*this = *this + other;
return *this;
};
constexpr uint256_t& operator-=(const uint256_t& other)
{
*this = *this - other;
return *this;
};
constexpr uint256_t& operator*=(const uint256_t& other)
{
*this = *this * other;
return *this;
};
constexpr uint256_t& operator/=(const uint256_t& other)
{
*this = *this / other;
return *this;
};
constexpr uint256_t& operator%=(const uint256_t& other)
{
*this = *this % other;
return *this;
};
constexpr uint256_t& operator++()
{
*this += uint256_t(1);
return *this;
};
constexpr uint256_t& operator--()
{
*this -= uint256_t(1);
return *this;
};
constexpr uint256_t& operator&=(const uint256_t& other)
{
*this = *this & other;
return *this;
};
constexpr uint256_t& operator^=(const uint256_t& other)
{
*this = *this ^ other;
return *this;
};
constexpr uint256_t& operator|=(const uint256_t& other)
{
*this = *this | other;
return *this;
};
constexpr uint256_t& operator>>=(const uint256_t& other)
{
*this = *this >> other;
return *this;
};
constexpr uint256_t& operator<<=(const uint256_t& other)
{
*this = *this << other;
return *this;
};
[[nodiscard]] constexpr std::pair<uint256_t, uint256_t> mul_extended(const uint256_t& other) const;
uint64_t data[4]; // NOLINT
[[nodiscard]] constexpr std::pair<uint256_t, uint256_t> divmod(const uint256_t& b) const;
private:
[[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> mul_wide(uint64_t a, uint64_t b);
[[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> addc(uint64_t a, uint64_t b, uint64_t carry_in);
[[nodiscard]] static constexpr uint64_t addc_discard_hi(uint64_t a, uint64_t b, uint64_t carry_in);
[[nodiscard]] static constexpr uint64_t sbb_discard_hi(uint64_t a, uint64_t b, uint64_t borrow_in);
[[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> sbb(uint64_t a, uint64_t b, uint64_t borrow_in);
[[nodiscard]] static constexpr uint64_t mac_discard_hi(uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in);
[[nodiscard]] static constexpr std::pair<uint64_t, uint64_t> mac(uint64_t a,
uint64_t b,
uint64_t c,
uint64_t carry_in);
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
static constexpr void wasm_madd(const uint64_t& left_limb,
const uint64_t* right_limbs,
uint64_t& result_0,
uint64_t& result_1,
uint64_t& result_2,
uint64_t& result_3,
uint64_t& result_4,
uint64_t& result_5,
uint64_t& result_6,
uint64_t& result_7,
uint64_t& result_8);
[[nodiscard]] static constexpr std::array<uint64_t, WASM_NUM_LIMBS> wasm_convert(const uint64_t* data);
#endif
};
inline std::ostream& operator<<(std::ostream& os, uint256_t const& a)
{
std::ios_base::fmtflags f(os.flags());
os << std::hex << "0x" << std::setfill('0') << std::setw(16) << a.data[3] << std::setw(16) << a.data[2]
<< std::setw(16) << a.data[1] << std::setw(16) << a.data[0];
os.flags(f);
return os;
}
} // namespace bb::numeric

View File

@@ -0,0 +1,622 @@
#pragma once
#include "../bitop/get_msb.hpp"
#include "./uint256.hpp"
#include "../../common/assert.hpp"
namespace bb::numeric {
constexpr std::pair<uint64_t, uint64_t> uint256_t::mul_wide(const uint64_t a, const uint64_t b)
{
const uint64_t a_lo = a & 0xffffffffULL;
const uint64_t a_hi = a >> 32ULL;
const uint64_t b_lo = b & 0xffffffffULL;
const uint64_t b_hi = b >> 32ULL;
const uint64_t lo_lo = a_lo * b_lo;
const uint64_t hi_lo = a_hi * b_lo;
const uint64_t lo_hi = a_lo * b_hi;
const uint64_t hi_hi = a_hi * b_hi;
const uint64_t cross = (lo_lo >> 32ULL) + (hi_lo & 0xffffffffULL) + lo_hi;
return { (cross << 32ULL) | (lo_lo & 0xffffffffULL), (hi_lo >> 32ULL) + (cross >> 32ULL) + hi_hi };
}
// compute a + b + carry, returning the carry
constexpr std::pair<uint64_t, uint64_t> uint256_t::addc(const uint64_t a, const uint64_t b, const uint64_t carry_in)
{
const uint64_t sum = a + b;
const auto carry_temp = static_cast<uint64_t>(sum < a);
const uint64_t r = sum + carry_in;
const uint64_t carry_out = carry_temp + static_cast<uint64_t>(r < carry_in);
return { r, carry_out };
}
constexpr uint64_t uint256_t::addc_discard_hi(const uint64_t a, const uint64_t b, const uint64_t carry_in)
{
return a + b + carry_in;
}
constexpr std::pair<uint64_t, uint64_t> uint256_t::sbb(const uint64_t a, const uint64_t b, const uint64_t borrow_in)
{
const uint64_t t_1 = a - (borrow_in >> 63ULL);
const auto borrow_temp_1 = static_cast<uint64_t>(t_1 > a);
const uint64_t t_2 = t_1 - b;
const auto borrow_temp_2 = static_cast<uint64_t>(t_2 > t_1);
return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
}
constexpr uint64_t uint256_t::sbb_discard_hi(const uint64_t a, const uint64_t b, const uint64_t borrow_in)
{
return a - b - (borrow_in >> 63ULL);
}
// {r, carry_out} = a + carry_in + b * c
constexpr std::pair<uint64_t, uint64_t> uint256_t::mac(const uint64_t a,
const uint64_t b,
const uint64_t c,
const uint64_t carry_in)
{
std::pair<uint64_t, uint64_t> result = mul_wide(b, c);
result.first += a;
const auto overflow_c = static_cast<uint64_t>(result.first < a);
result.first += carry_in;
const auto overflow_carry = static_cast<uint64_t>(result.first < carry_in);
result.second += (overflow_c + overflow_carry);
return result;
}
constexpr uint64_t uint256_t::mac_discard_hi(const uint64_t a,
const uint64_t b,
const uint64_t c,
const uint64_t carry_in)
{
return (b * c + a + carry_in);
}
#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
/**
* @brief Multiply one limb by 9 limbs and add to resulting limbs
*
*/
constexpr void uint256_t::wasm_madd(const uint64_t& left_limb,
const uint64_t* right_limbs,
uint64_t& result_0,
uint64_t& result_1,
uint64_t& result_2,
uint64_t& result_3,
uint64_t& result_4,
uint64_t& result_5,
uint64_t& result_6,
uint64_t& result_7,
uint64_t& result_8)
{
result_0 += left_limb * right_limbs[0];
result_1 += left_limb * right_limbs[1];
result_2 += left_limb * right_limbs[2];
result_3 += left_limb * right_limbs[3];
result_4 += left_limb * right_limbs[4];
result_5 += left_limb * right_limbs[5];
result_6 += left_limb * right_limbs[6];
result_7 += left_limb * right_limbs[7];
result_8 += left_limb * right_limbs[8];
}
/**
* @brief Convert from 4 64-bit limbs to 9 29-bit limbs
*
*/
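// 29-bit limbs leave headroom in 64-bit words: a 29x29-bit product fits in 58 bits, so
// up to 2^6 = 64 such products can be accumulated per column before any carry
// propagation is needed, and 9 limbs * 29 bits = 261 bits covers the 256-bit range.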
constexpr std::array<uint64_t, WASM_NUM_LIMBS> uint256_t::wasm_convert(const uint64_t* data)
{
return { data[0] & 0x1fffffff,
(data[0] >> 29) & 0x1fffffff,
((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6),
(data[1] >> 23) & 0x1fffffff,
((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12),
(data[2] >> 17) & 0x1fffffff,
((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18),
(data[3] >> 11) & 0x1fffffff,
(data[3] >> 40) & 0x1fffffff };
}
#endif
constexpr std::pair<uint256_t, uint256_t> uint256_t::divmod(const uint256_t& b) const
{
if (*this == 0 || b == 0) {
return { 0, 0 };
}
if (b == 1) {
return { *this, 0 };
}
if (*this == b) {
return { 1, 0 };
}
if (b > *this) {
return { 0, *this };
}
uint256_t quotient = 0;
uint256_t remainder = *this;
uint64_t bit_difference = get_msb() - b.get_msb();
uint256_t divisor = b << bit_difference;
uint256_t accumulator = uint256_t(1) << bit_difference;
// if the divisor is bigger than the remainder, a and b have the same bit length
if (divisor > remainder) {
divisor >>= 1;
accumulator >>= 1;
}
// while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
// and add to the quotient
while (remainder >= b) {
// we've shunted 'divisor' up to have the same bit length as our remainder.
// If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
if (remainder >= divisor) {
remainder -= divisor;
// we can use OR here instead of +, as
// accumulator is always a nice power of two
quotient |= accumulator;
}
divisor >>= 1;
accumulator >>= 1;
}
return { quotient, remainder };
}
/**
* @brief Compute the result of multiplication modulo 2**512
*
*/
constexpr std::pair<uint256_t, uint256_t> uint256_t::mul_extended(const uint256_t& other) const
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const auto [r0, t0] = mul_wide(data[0], other.data[0]);
const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);
const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);
const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);
const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);
uint256_t lo(r0, r1, r2, r3);
uint256_t hi(r4, r5, r6, r7);
return { lo, hi };
#else
// Convert 4 64-bit limbs to 9 29-bit limbs
const auto left = wasm_convert(data);
const auto right = wasm_convert(other.data);
constexpr uint64_t mask = 0x1fffffff;
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
uint64_t temp_9 = 0;
uint64_t temp_10 = 0;
uint64_t temp_11 = 0;
uint64_t temp_12 = 0;
uint64_t temp_13 = 0;
uint64_t temp_14 = 0;
uint64_t temp_15 = 0;
uint64_t temp_16 = 0;
// Multiply and add all limbs
wasm_madd(left[0], &right[0], temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
wasm_madd(left[1], &right[0], temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_madd(left[2], &right[0], temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_madd(left[3], &right[0], temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_madd(left[4], &right[0], temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_madd(left[5], &right[0], temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_madd(left[6], &right[0], temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_madd(left[7], &right[0], temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_madd(left[8], &right[0], temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
// Convert from relaxed form into strict 29-bit form (except for temp_16)
temp_1 += temp_0 >> WASM_LIMB_BITS;
temp_0 &= mask;
temp_2 += temp_1 >> WASM_LIMB_BITS;
temp_1 &= mask;
temp_3 += temp_2 >> WASM_LIMB_BITS;
temp_2 &= mask;
temp_4 += temp_3 >> WASM_LIMB_BITS;
temp_3 &= mask;
temp_5 += temp_4 >> WASM_LIMB_BITS;
temp_4 &= mask;
temp_6 += temp_5 >> WASM_LIMB_BITS;
temp_5 &= mask;
temp_7 += temp_6 >> WASM_LIMB_BITS;
temp_6 &= mask;
temp_8 += temp_7 >> WASM_LIMB_BITS;
temp_7 &= mask;
temp_9 += temp_8 >> WASM_LIMB_BITS;
temp_8 &= mask;
temp_10 += temp_9 >> WASM_LIMB_BITS;
temp_9 &= mask;
temp_11 += temp_10 >> WASM_LIMB_BITS;
temp_10 &= mask;
temp_12 += temp_11 >> WASM_LIMB_BITS;
temp_11 &= mask;
temp_13 += temp_12 >> WASM_LIMB_BITS;
temp_12 &= mask;
temp_14 += temp_13 >> WASM_LIMB_BITS;
temp_13 &= mask;
temp_15 += temp_14 >> WASM_LIMB_BITS;
temp_14 &= mask;
temp_16 += temp_15 >> WASM_LIMB_BITS;
temp_15 &= mask;
// Convert to 2 4-64-bit limb uint256_t objects
return { { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
(temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
(temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
(temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40) },
{ (temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63),
(temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57),
(temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51),
(temp_15 >> 13) | (temp_16 << 16) } };
#endif
}
/**
* Viewing `this` uint256_t as a bit string, and counting bits from 0, slices a substring.
* @returns the uint256_t equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the
* `end`-th bit of `this`.
*/
constexpr uint256_t uint256_t::slice(const uint64_t start, const uint64_t end) const
{
const uint64_t range = end - start;
const uint256_t mask = (range == 256) ? -uint256_t(1) : (uint256_t(1) << range) - 1;
return ((*this) >> start) & mask;
}
constexpr uint256_t uint256_t::pow(const uint256_t& exponent) const
{
uint256_t accumulator{ data[0], data[1], data[2], data[3] };
uint256_t to_mul{ data[0], data[1], data[2], data[3] };
const uint64_t maximum_set_bit = exponent.get_msb();
for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
accumulator *= accumulator;
if (exponent.get_bit(static_cast<uint64_t>(i))) {
accumulator *= to_mul;
}
}
if (exponent == uint256_t(0)) {
accumulator = uint256_t(1);
} else if (*this == uint256_t(0)) {
accumulator = uint256_t(0);
}
return accumulator;
}
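// Square-and-multiply trace for exponent 13 = 0b1101 (msb = 3): starting from
// accumulator = base, step i = 2 squares then multiplies (bit set) -> base^3,
// i = 1 squares only -> base^6, and i = 0 squares then multiplies -> base^13.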
constexpr bool uint256_t::get_bit(const uint64_t bit_index) const
{
ASSERT(bit_index < 256);
if (bit_index > 255) {
return false;
}
const auto idx = static_cast<size_t>(bit_index >> 6);
const size_t shift = bit_index & 63;
return static_cast<bool>((data[idx] >> shift) & 1);
}
constexpr uint64_t uint256_t::get_msb() const
{
uint64_t idx = numeric::get_msb(data[3]);
idx = (idx == 0 && data[3] == 0) ? numeric::get_msb(data[2]) : idx + 64;
idx = (idx == 0 && data[2] == 0) ? numeric::get_msb(data[1]) : idx + 64;
idx = (idx == 0 && data[1] == 0) ? numeric::get_msb(data[0]) : idx + 64;
return idx;
}
constexpr uint256_t uint256_t::operator+(const uint256_t& other) const
{
const auto [r0, t0] = addc(data[0], other.data[0], 0);
const auto [r1, t1] = addc(data[1], other.data[1], t0);
const auto [r2, t2] = addc(data[2], other.data[2], t1);
const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
return { r0, r1, r2, r3 };
};
constexpr uint256_t uint256_t::operator-(const uint256_t& other) const
{
const auto [r0, t0] = sbb(data[0], other.data[0], 0);
const auto [r1, t1] = sbb(data[1], other.data[1], t0);
const auto [r2, t2] = sbb(data[2], other.data[2], t1);
const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
return { r0, r1, r2, r3 };
}
constexpr uint256_t uint256_t::operator-() const
{
return uint256_t(0) - *this;
}
constexpr uint256_t uint256_t::operator*(const uint256_t& other) const
{
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);
const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);
const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);
const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);
return { r0, r1, r2, r3 };
#else
// Convert 4 64-bit limbs to 9 29-bit limbs
const auto left = wasm_convert(data);
const auto right = wasm_convert(other.data);
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
// Multiply and add the product of left limb 0 by all right limbs
wasm_madd(left[0], &right[0], temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
// Multiply left limb 1 by limbs 0-7 ((1,8) doesn't need to be computed, because it overflows)
temp_1 += left[1] * right[0];
temp_2 += left[1] * right[1];
temp_3 += left[1] * right[2];
temp_4 += left[1] * right[3];
temp_5 += left[1] * right[4];
temp_6 += left[1] * right[5];
temp_7 += left[1] * right[6];
temp_8 += left[1] * right[7];
// Left limb 2 by right 0-6, etc
temp_2 += left[2] * right[0];
temp_3 += left[2] * right[1];
temp_4 += left[2] * right[2];
temp_5 += left[2] * right[3];
temp_6 += left[2] * right[4];
temp_7 += left[2] * right[5];
temp_8 += left[2] * right[6];
temp_3 += left[3] * right[0];
temp_4 += left[3] * right[1];
temp_5 += left[3] * right[2];
temp_6 += left[3] * right[3];
temp_7 += left[3] * right[4];
temp_8 += left[3] * right[5];
temp_4 += left[4] * right[0];
temp_5 += left[4] * right[1];
temp_6 += left[4] * right[2];
temp_7 += left[4] * right[3];
temp_8 += left[4] * right[4];
temp_5 += left[5] * right[0];
temp_6 += left[5] * right[1];
temp_7 += left[5] * right[2];
temp_8 += left[5] * right[3];
temp_6 += left[6] * right[0];
temp_7 += left[6] * right[1];
temp_8 += left[6] * right[2];
temp_7 += left[7] * right[0];
temp_8 += left[7] * right[1];
temp_8 += left[8] * right[0];
// Convert from relaxed form to strict 29-bit form
constexpr uint64_t mask = 0x1fffffff;
temp_1 += temp_0 >> WASM_LIMB_BITS;
temp_0 &= mask;
temp_2 += temp_1 >> WASM_LIMB_BITS;
temp_1 &= mask;
temp_3 += temp_2 >> WASM_LIMB_BITS;
temp_2 &= mask;
temp_4 += temp_3 >> WASM_LIMB_BITS;
temp_3 &= mask;
temp_5 += temp_4 >> WASM_LIMB_BITS;
temp_4 &= mask;
temp_6 += temp_5 >> WASM_LIMB_BITS;
temp_5 &= mask;
temp_7 += temp_6 >> WASM_LIMB_BITS;
temp_6 &= mask;
temp_8 += temp_7 >> WASM_LIMB_BITS;
temp_7 &= mask;
// Convert back to 4 64-bit limbs
return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
(temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
(temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
(temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40) };
#endif
}
constexpr uint256_t uint256_t::operator/(const uint256_t& other) const
{
return divmod(other).first;
}
constexpr uint256_t uint256_t::operator%(const uint256_t& other) const
{
return divmod(other).second;
}
constexpr uint256_t uint256_t::operator&(const uint256_t& other) const
{
return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
}
constexpr uint256_t uint256_t::operator^(const uint256_t& other) const
{
return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
}
constexpr uint256_t uint256_t::operator|(const uint256_t& other) const
{
return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
}
constexpr uint256_t uint256_t::operator~() const
{
return { ~data[0], ~data[1], ~data[2], ~data[3] };
}
constexpr bool uint256_t::operator==(const uint256_t& other) const
{
return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
}
constexpr bool uint256_t::operator!=(const uint256_t& other) const
{
return !(*this == other);
}
constexpr bool uint256_t::operator!() const
{
return *this == uint256_t(0ULL);
}
constexpr bool uint256_t::operator>(const uint256_t& other) const
{
bool t0 = data[3] > other.data[3];
bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
bool t3 =
data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
return t0 || t1 || t2 || t3;
}
constexpr bool uint256_t::operator>=(const uint256_t& other) const
{
return (*this > other) || (*this == other);
}
constexpr bool uint256_t::operator<(const uint256_t& other) const
{
return other > *this;
}
constexpr bool uint256_t::operator<=(const uint256_t& other) const
{
return (*this < other) || (*this == other);
}
constexpr uint256_t uint256_t::operator>>(const uint256_t& other) const
{
uint64_t total_shift = other.data[0];
if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
return 0;
}
if (total_shift == 0) {
return *this;
}
uint64_t num_shifted_limbs = total_shift >> 6ULL;
uint64_t limb_shift = total_shift & 63ULL;
std::array<uint64_t, 4> shifted_limbs = { 0, 0, 0, 0 };
if (limb_shift == 0) {
shifted_limbs[0] = data[0];
shifted_limbs[1] = data[1];
shifted_limbs[2] = data[2];
shifted_limbs[3] = data[3];
} else {
uint64_t remainder_shift = 64ULL - limb_shift;
shifted_limbs[3] = data[3] >> limb_shift;
uint64_t remainder = (data[3]) << remainder_shift;
shifted_limbs[2] = (data[2] >> limb_shift) + remainder;
remainder = (data[2]) << remainder_shift;
shifted_limbs[1] = (data[1] >> limb_shift) + remainder;
remainder = (data[1]) << remainder_shift;
shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
}
uint256_t result(0);
for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
}
return result;
}
constexpr uint256_t uint256_t::operator<<(const uint256_t& other) const
{
uint64_t total_shift = other.data[0];
if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
return 0;
}
if (total_shift == 0) {
return *this;
}
uint64_t num_shifted_limbs = total_shift >> 6ULL;
uint64_t limb_shift = total_shift & 63ULL;
std::array<uint64_t, 4> shifted_limbs = { 0, 0, 0, 0 };
if (limb_shift == 0) {
shifted_limbs[0] = data[0];
shifted_limbs[1] = data[1];
shifted_limbs[2] = data[2];
shifted_limbs[3] = data[3];
} else {
uint64_t remainder_shift = 64ULL - limb_shift;
shifted_limbs[0] = data[0] << limb_shift;
uint64_t remainder = data[0] >> remainder_shift;
shifted_limbs[1] = (data[1] << limb_shift) + remainder;
remainder = data[1] >> remainder_shift;
shifted_limbs[2] = (data[2] << limb_shift) + remainder;
remainder = data[2] >> remainder_shift;
shifted_limbs[3] = (data[3] << limb_shift) + remainder;
}
uint256_t result(0);
for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
}
return result;
}
} // namespace bb::numeric

View File

@@ -0,0 +1,178 @@
/**
* uintx
* Copyright Aztec 2020
*
* Unsigned 512-bit and 1024-bit integer types.
*
* Constructor and all methods are constexpr. Ideally, uintx should be able to be treated like any other literal
* type.
*
* Not optimized for performance, this code doesn't touch any of our hot paths when constructing PLONK proofs.
**/
#pragma once
#include "../uint256/uint256.hpp"
#include "../../common/assert.hpp"
#include "../../common/throw_or_abort.hpp"
#include <cstdint>
#include <iomanip>
#include <iostream>
namespace bb::numeric {
template <class base_uint> class uintx {
public:
constexpr uintx(const uint64_t data = 0)
: lo(data)
, hi(base_uint(0))
{}
constexpr uintx(const base_uint input_lo)
: lo(input_lo)
, hi(base_uint(0))
{}
constexpr uintx(const base_uint input_lo, const base_uint input_hi)
: lo(input_lo)
, hi(input_hi)
{}
constexpr uintx(const uintx& other)
: lo(other.lo)
, hi(other.hi)
{}
constexpr uintx(uintx&& other) noexcept = default;
static constexpr size_t length() { return 2 * base_uint::length(); }
constexpr uintx& operator=(const uintx& other) = default;
constexpr uintx& operator=(uintx&& other) noexcept = default;
constexpr ~uintx() = default;
explicit constexpr operator bool() const { return static_cast<bool>(lo.data[0]); };
explicit constexpr operator uint8_t() const { return static_cast<uint8_t>(lo.data[0]); };
explicit constexpr operator uint16_t() const { return static_cast<uint16_t>(lo.data[0]); };
explicit constexpr operator uint32_t() const { return static_cast<uint32_t>(lo.data[0]); };
explicit constexpr operator uint64_t() const { return static_cast<uint64_t>(lo.data[0]); };
explicit constexpr operator base_uint() const { return lo; }
[[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const;
[[nodiscard]] constexpr uint64_t get_msb() const;
constexpr uintx slice(uint64_t start, uint64_t end) const;
constexpr uintx operator+(const uintx& other) const;
constexpr uintx operator-(const uintx& other) const;
constexpr uintx operator-() const;
constexpr uintx operator*(const uintx& other) const;
constexpr uintx operator/(const uintx& other) const;
constexpr uintx operator%(const uintx& other) const;
constexpr std::pair<uintx, uintx> mul_extended(const uintx& other) const;
constexpr uintx operator>>(uint64_t other) const;
constexpr uintx operator<<(uint64_t other) const;
constexpr uintx operator&(const uintx& other) const;
constexpr uintx operator^(const uintx& other) const;
constexpr uintx operator|(const uintx& other) const;
constexpr uintx operator~() const;
constexpr bool operator==(const uintx& other) const;
constexpr bool operator!=(const uintx& other) const;
constexpr bool operator!() const;
constexpr bool operator>(const uintx& other) const;
constexpr bool operator<(const uintx& other) const;
constexpr bool operator>=(const uintx& other) const;
constexpr bool operator<=(const uintx& other) const;
constexpr uintx& operator+=(const uintx& other)
{
*this = *this + other;
return *this;
};
constexpr uintx& operator-=(const uintx& other)
{
*this = *this - other;
return *this;
};
constexpr uintx& operator*=(const uintx& other)
{
*this = *this * other;
return *this;
};
constexpr uintx& operator/=(const uintx& other)
{
*this = *this / other;
return *this;
};
constexpr uintx& operator%=(const uintx& other)
{
*this = *this % other;
return *this;
};
constexpr uintx& operator++()
{
*this += uintx(1);
return *this;
};
constexpr uintx& operator--()
{
*this -= uintx(1);
return *this;
};
constexpr uintx& operator&=(const uintx& other)
{
*this = *this & other;
return *this;
};
constexpr uintx& operator^=(const uintx& other)
{
*this = *this ^ other;
return *this;
};
constexpr uintx& operator|=(const uintx& other)
{
*this = *this | other;
return *this;
};
constexpr uintx& operator>>=(const uint64_t other)
{
*this = *this >> other;
return *this;
};
constexpr uintx& operator<<=(const uint64_t other)
{
*this = *this << other;
return *this;
};
constexpr uintx invmod(const uintx& modulus) const;
constexpr uintx unsafe_invmod(const uintx& modulus) const;
base_uint lo;
base_uint hi;
constexpr std::pair<uintx, uintx> divmod(const uintx& b) const;
};
template <class base_uint> inline std::ostream& operator<<(std::ostream& os, uintx<base_uint> const& a)
{
os << a.lo << ", " << a.hi << std::endl;
return os;
}
using uint512_t = uintx<numeric::uint256_t>;
using uint1024_t = uintx<uint512_t>;
} // namespace bb::numeric
#include "./uintx_impl.hpp"
using bb::numeric::uint1024_t; // NOLINT
using bb::numeric::uint512_t; // NOLINT

View File

@@ -0,0 +1,339 @@
#pragma once
#include "./uintx.hpp"
#include "../../common/assert.hpp"
namespace bb::numeric {
template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
{
ASSERT(b != 0);
if (*this == 0) {
return { uintx(0), uintx(0) };
}
if (b == 1) {
return { *this, uintx(0) };
}
if (*this == b) {
return { uintx(1), uintx(0) };
}
if (b > *this) {
return { uintx(0), *this };
}
uintx quotient(0);
uintx remainder = *this;
uint64_t bit_difference = get_msb() - b.get_msb();
uintx divisor = b << bit_difference;
uintx accumulator = uintx(1) << bit_difference;
// if the divisor is bigger than the remainder, a and b have the same bit length
if (divisor > remainder) {
divisor >>= 1;
accumulator >>= 1;
}
// while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
// and add to the quotient
while (remainder >= b) {
// we've shunted 'divisor' up to have the same bit length as our remainder.
// If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
if (remainder >= divisor) {
remainder -= divisor;
// we can use OR here instead of +, as
// accumulator is always a nice power of two
quotient |= accumulator;
}
divisor >>= 1;
accumulator >>= 1;
}
return std::make_pair(quotient, remainder);
}
/**
* Computes invmod. Only for internal usage within the class.
* This is an insecure version of the algorithm: it does not handle the 0 case, nor moduli
* close to the top of the representable range.
*
* @param modulus The modulus of the ring
*
* @return The inverse of *this modulo modulus
**/
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::unsafe_invmod(const uintx& modulus) const
{
uintx t1 = 0;
uintx t2 = 1;
uintx r2 = (*this > modulus) ? *this % modulus : *this;
uintx r1 = modulus;
uintx q = 0;
while (r2 != 0) {
q = r1 / r2;
uintx temp_t1 = t1;
uintx temp_r1 = r1;
t1 = t2;
t2 = temp_t1 - q * t2;
r1 = r2;
r2 = temp_r1 - q * r2;
}
if (t1 > modulus) {
return modulus + t1;
}
return t1;
}
/**
* Computes the inverse of *this, modulo modulus, via the extended Euclidean algorithm.
*
* Delegates to the appropriate unsafe_invmod (if the modulus is close to the top of the uintx range, the operands are first expanded to a wider uintx).
*
* @param modulus The modulus
* @return The inverse of *this modulo modulus
**/
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::invmod(const uintx& modulus) const
{
ASSERT((*this) != 0);
if (modulus == 0) {
return 0;
}
if (modulus.get_msb() >= (2 * base_uint::length() - 1)) {
uintx<uintx<base_uint>> a_expanded(*this);
uintx<uintx<base_uint>> modulus_expanded(modulus);
return a_expanded.unsafe_invmod(modulus_expanded).lo;
}
return this->unsafe_invmod(modulus);
}
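// Illustrative check: for coprime a and m, (a * a.invmod(m)) % m == 1. With
// uint512_t values a = 3 and m = 7, invmod returns 5, since 3 * 5 = 15 = 2 * 7 + 1.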
/**
* Viewing `this` as a bit string, and counting bits from 0, slices a substring.
* @returns the uintx equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the
* `end`-th bit of `this`.
*/
template <class base_uint>
constexpr uintx<base_uint> uintx<base_uint>::slice(const uint64_t start, const uint64_t end) const
{
const uint64_t range = end - start;
const uintx mask = range == base_uint::length() ? -uintx(1) : (uintx(1) << range) - 1;
return ((*this) >> start) & mask;
}
template <class base_uint> constexpr bool uintx<base_uint>::get_bit(const uint64_t bit_index) const
{
if (bit_index >= base_uint::length()) {
return hi.get_bit(bit_index - base_uint::length());
}
return lo.get_bit(bit_index);
}
template <class base_uint> constexpr uint64_t uintx<base_uint>::get_msb() const
{
uint64_t hi_idx = hi.get_msb();
uint64_t lo_idx = lo.get_msb();
return (hi_idx || (hi > base_uint(0))) ? (hi_idx + base_uint::length()) : lo_idx;
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator+(const uintx& other) const
{
base_uint res_lo = lo + other.lo;
bool carry = res_lo < lo;
base_uint res_hi = hi + other.hi + ((carry) ? base_uint(1) : base_uint(0));
return { res_lo, res_hi };
};
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator-(const uintx& other) const
{
base_uint res_lo = lo - other.lo;
bool borrow = res_lo > lo;
base_uint res_hi = hi - other.hi - ((borrow) ? base_uint(1) : base_uint(0));
return { res_lo, res_hi };
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator-() const
{
return uintx(0) - *this;
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator*(const uintx& other) const
{
const auto lolo = lo.mul_extended(other.lo);
const auto lohi = lo.mul_extended(other.hi);
const auto hilo = hi.mul_extended(other.lo);
base_uint top = lolo.second + hilo.first + lohi.first;
base_uint bottom = lolo.first;
return { bottom, top };
}
template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::mul_extended(const uintx& other) const
{
const auto lolo = lo.mul_extended(other.lo);
const auto lohi = lo.mul_extended(other.hi);
const auto hilo = hi.mul_extended(other.lo);
const auto hihi = hi.mul_extended(other.hi);
base_uint t0 = lolo.first;
base_uint t1 = lolo.second;
base_uint t2 = hilo.second;
base_uint t3 = hihi.second;
base_uint t2_carry(0);
base_uint t3_carry(0);
t1 += hilo.first;
t2_carry += (t1 < hilo.first ? base_uint(1) : base_uint(0));
t1 += lohi.first;
t2_carry += (t1 < lohi.first ? base_uint(1) : base_uint(0));
t2 += lohi.second;
t3_carry += (t2 < lohi.second ? base_uint(1) : base_uint(0));
t2 += hihi.first;
t3_carry += (t2 < hihi.first ? base_uint(1) : base_uint(0));
t2 += t2_carry;
t3_carry += (t2 < t2_carry ? base_uint(1) : base_uint(0));
t3 += t3_carry;
return { uintx(t0, t1), uintx(t2, t3) };
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator/(const uintx& other) const
{
return divmod(other).first;
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator%(const uintx& other) const
{
return divmod(other).second;
}
// 0x2af0296feca4188a80fd373ebe3c64da87a232934abb3a99f9c4cd59e6758a65
// 0x1182c6cdb54193b51ca27c1932b95c82bebac691e3996e5ec5e1d4395f3023e3
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator&(const uintx& other) const
{
return { lo & other.lo, hi & other.hi };
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator^(const uintx& other) const
{
return { lo ^ other.lo, hi ^ other.hi };
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator|(const uintx& other) const
{
return { lo | other.lo, hi | other.hi };
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator~() const
{
return { ~lo, ~hi };
}
template <class base_uint> constexpr bool uintx<base_uint>::operator==(const uintx& other) const
{
return ((lo == other.lo) && (hi == other.hi));
}
template <class base_uint> constexpr bool uintx<base_uint>::operator!=(const uintx& other) const
{
return !(*this == other);
}
template <class base_uint> constexpr bool uintx<base_uint>::operator!() const
{
return *this == uintx(0ULL);
}
template <class base_uint> constexpr bool uintx<base_uint>::operator>(const uintx& other) const
{
bool hi_gt = hi > other.hi;
bool lo_gt = lo > other.lo;
bool gt = (hi_gt) || (lo_gt && (hi == other.hi));
return gt;
}
template <class base_uint> constexpr bool uintx<base_uint>::operator>=(const uintx& other) const
{
return (*this > other) || (*this == other);
}
template <class base_uint> constexpr bool uintx<base_uint>::operator<(const uintx& other) const
{
return other > *this;
}
template <class base_uint> constexpr bool uintx<base_uint>::operator<=(const uintx& other) const
{
return (*this < other) || (*this == other);
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator>>(const uint64_t other) const
{
const uint64_t total_shift = other;
if (total_shift >= length()) {
return uintx(0);
}
if (total_shift == 0) {
return *this;
}
const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb());
const uint64_t limb_shift = total_shift & static_cast<uint64_t>(base_uint::length() - 1);
std::array<base_uint, 2> shifted_limbs = { 0, 0 };
if (limb_shift == 0) {
shifted_limbs[0] = lo;
shifted_limbs[1] = hi;
} else {
const uint64_t remainder_shift = static_cast<uint64_t>(base_uint::length()) - limb_shift;
shifted_limbs[1] = hi >> limb_shift;
base_uint remainder = (hi) << remainder_shift;
shifted_limbs[0] = (lo >> limb_shift) + remainder;
}
uintx result(0);
if (num_shifted_limbs == 0) {
result.hi = shifted_limbs[1];
result.lo = shifted_limbs[0];
} else {
result.lo = shifted_limbs[1];
}
return result;
}
template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator<<(const uint64_t other) const
{
const uint64_t total_shift = other;
if (total_shift >= length()) {
return uintx(0);
}
if (total_shift == 0) {
return *this;
}
const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb());
const uint64_t limb_shift = total_shift & static_cast<uint64_t>(base_uint::length() - 1);
std::array<base_uint, 2> shifted_limbs = { 0, 0 };
if (limb_shift == 0) {
shifted_limbs[0] = lo;
shifted_limbs[1] = hi;
} else {
const uint64_t remainder_shift = static_cast<uint64_t>(base_uint::length()) - limb_shift;
shifted_limbs[0] = lo << limb_shift;
base_uint remainder = lo >> remainder_shift;
shifted_limbs[1] = (hi << limb_shift) + remainder;
}
uintx result(0);
if (num_shifted_limbs == 0) {
result.hi = shifted_limbs[1];
result.lo = shifted_limbs[0];
} else {
result.hi = shifted_limbs[0];
}
return result;
}
} // namespace bb::numeric

View File

@@ -0,0 +1,27 @@
/**
* @brief Defines particular circuit builder types expected to be used for circuit
* construction in stdlib and contains macros for explicit instantiation.
*
* @details This file is designed to be included in header files to instruct the compiler that these classes exist and
* their instantiation will eventually take place. Given it has no dependencies, it causes no additional compilation or
* propagation.
*/
#pragma once
#include <concepts>
namespace bb {
class StandardFlavor;
class UltraFlavor;
class Bn254FrParams;
class Bn254FqParams;
template <class Params> struct alignas(32) field;
template <typename FF_> class UltraArith;
template <class FF> class StandardCircuitBuilder_;
using StandardCircuitBuilder = StandardCircuitBuilder_<field<Bn254FrParams>>;
using StandardGrumpkinCircuitBuilder = StandardCircuitBuilder_<field<Bn254FqParams>>;
template <class Arithmetization> class UltraCircuitBuilder_;
using UltraCircuitBuilder = UltraCircuitBuilder_<UltraArith<field<Bn254FrParams>>>;
template <class FF> class MegaCircuitBuilder_;
using MegaCircuitBuilder = MegaCircuitBuilder_<field<Bn254FrParams>>;
class CircuitSimulatorBN254;
} // namespace bb
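// A typical consumer (an illustrative sketch; the header name and function below are
// assumptions) declares interfaces against these aliases without including the full
// builder definitions:
//
// #include "circuit_builders_fwd.hpp"
// namespace bb {
// void append_range_constraint(UltraCircuitBuilder& builder, uint32_t witness_index);
// } // namespace bb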