diff --git a/sumcheck/src/cuda/includes/barretenberg/common/assert.hpp b/sumcheck/src/cuda/includes/barretenberg/common/assert.hpp new file mode 100644 index 0000000..86767fa --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/assert.hpp @@ -0,0 +1,20 @@ +#pragma once + +// NOLINTBEGIN +#if NDEBUG +// Compiler should optimize this out in release builds, without triggering an unused variable warning. +#define DONT_EVALUATE(expression) \ + { \ + true ? static_cast(0) : static_cast((expression)); \ + } +#define ASSERT(expression) DONT_EVALUATE((expression)) +#else +// cassert in wasi-sdk takes one second to compile, only include if needed +#include +#include +#include +#include +#define ASSERT(expression) assert((expression)) +#endif // NDEBUG + +// NOLINTEND \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/compiler_hints.hpp b/sumcheck/src/cuda/includes/barretenberg/common/compiler_hints.hpp new file mode 100644 index 0000000..9492475 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/compiler_hints.hpp @@ -0,0 +1,26 @@ +#pragma once + +#ifdef _WIN32 +#define BB_INLINE __forceinline inline +#else +#define BB_INLINE __attribute__((always_inline)) inline +#endif + +// TODO(AD): Other instrumentation? +#ifdef XRAY +#define BB_PROFILE [[clang::xray_always_instrument]] [[clang::noinline]] +#define BB_NO_PROFILE [[clang::xray_never_instrument]] +#else +#define BB_PROFILE +#define BB_NO_PROFILE +#endif + +// Optimization hints for clang - which outcome of an expression is expected for better +// branch-prediction optimization +#ifdef __clang__ +#define BB_LIKELY(x) __builtin_expect(!!(x), 1) +#define BB_UNLIKELY(x) __builtin_expect(!!(x), 0) +#else +#define BB_LIKELY(x) x +#define BB_UNLIKELY(x) x +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/constexpr_utils.hpp b/sumcheck/src/cuda/includes/barretenberg/common/constexpr_utils.hpp new file mode 100644 index 0000000..9b5c7da --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/constexpr_utils.hpp @@ -0,0 +1,162 @@ +#pragma once + +#include +#include +#include + +/** + * @brief constexpr_utils defines some helper methods that perform some stl-equivalent operations + * but in a constexpr context over quantities known at compile-time + * + * Current methods are: + * + * constexpr_for : loop over a range , where the size_t iterator `i` is a constexpr variable + * constexpr_find : find if an element is in an array + */ +namespace bb { + +/** + * @brief Implements a loop using a compile-time iterator. Requires c++20. + * Implementation (and description) from https://artificial-mind.net/blog/2020/10/31/constexpr-for + * + * @tparam Start the loop start value + * @tparam End the loop end value + * @tparam Inc how much the iterator increases by per iteration + * @tparam F a Lambda function that is executed once per loop + * + * @param f An rvalue reference to the lambda + * @details Implements a `for` loop where the iterator is a constexpr variable. + * Use this when you need to evaluate `if constexpr` statements on the iterator (or apply other constexpr expressions) + * Outside of this use-case avoid using this fn as it gives negligible performance increases vs regular loops. + * + * N.B. A side-effect of this method is that all loops will be unrolled + * (each loop iteration uses different iterator template parameters => unique constexpr_for implementation per + * iteration) + * Do not use this for large (~100+) loops! 
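 *
 * For reference, a minimal sketch of the required call form (the lambda must be a C++20 template
 * lambda, because the implementation below invokes `f.template operator()<Start>()`); the helpers
 * `do_first_thing` / `do_other_thing` are hypothetical stand-ins for caller code:
 *
 *     constexpr_for<0, 4, 1>([&]<size_t i>() {
 *         if constexpr (i == 0) {
 *             do_first_thing();
 *         } else {
 *             do_other_thing(i);
 *         }
 *     });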
+ * + * ############################## + * EXAMPLE USE OF `constexpr_for` + * ############################## + * + * constexpr_for<0, 10, 1>([&](){ + * if constexpr (i & 1 == 0) + * { + * foo[i] = even_container[i >> 1]; + * } + * else + * { + * foo[i] = odd_container[i >> 1]; + * } + * }); + * + * In the above example we are iterating from i = 0 to i < 10. + * The provided lambda function has captured everything in its surrounding scope (via `[&]`), + * which is where `foo`, `even_container` and `odd_container` have come from. + * + * We do not need to explicitly define the `class F` parameter as the compiler derives it from our provided input + * argument `F&& f` (i.e. the lambda function) + * + * In the loop itself we're evaluating a constexpr if statement that defines which code path is taken. + * + * The above example benefits from `constexpr_for` because a run-time `if` statement has been reduced to a compile-time + * `if` statement. N.B. this would only give measurable improvements if the `constexpr_for` statement is itself in a hot + * loop that's iterated over many (>thousands) times + */ +template constexpr void constexpr_for(F&& f) +{ + // Call function `f()` iff Start < End + if constexpr (Start < End) { + // F must be a template lambda with a single **typed** template parameter that represents the iterator + // (e.g. [&](){ ... } is good) + // (and [&](){ ... } won't compile!) + + /** + * Explaining f.template operator()() + * + * The following line must explicitly tell the compiler that is a template parameter by using the + * `template` keyword. + * (if we wrote f(), the compiler could legitimately interpret `<` as a less than symbol) + * + * The fragment `f.template` tells the compiler that we're calling a *templated* member of `f`. + * The "member" being called is the function operator, `operator()`, which must be explicitly provided + * (for any function X, `X(args)` is an alias for `X.operator()(args)`) + * The compiler has no alias `X.template (args)` for `X.template operator()(args)` so we must + * write it explicitly here + * + * To summarize what the next line tells the compiler... + * 1. I want to call a member of `f` that expects one or more template parameters + * 2. The member of `f` that I want to call is the function operator + * 3. The template parameter is `Start` + * 4. The function operator itself contains no arguments + */ + f.template operator()(); + + // Once we have executed `f`, we recursively call the `constexpr_for` function, increasing the value of `Start` + // by `Inc` + constexpr_for(f); + } +} + +/** + * @brief returns true/false depending on whether `key` is in `container` + * + * @tparam container i.e. what are we looking in? + * @tparam key i.e. what are we looking for? + * @return true found! + * @return false not found! + * + * @details method is constexpr and can be used in static_asserts + */ +template constexpr bool constexpr_find() +{ + // using ElementType = typename std::remove_extent::type; + bool found = false; + constexpr_for<0, container.size(), 1>([&]() { + if constexpr (std::get(container) == key) { + found = true; + } + }); + return found; +} + +/** + * @brief Create a constexpr array object whose elements contain a default value + * + * @tparam T type contained in the array + * @tparam Is index sequence + * @param value the value each array element is being initialized to + * @return constexpr std::array + * + * @details This method is used to create constexpr arrays whose encapsulated type: + * + * 1. 
HAS NO CONSTEXPR DEFAULT CONSTRUCTOR + * 2. HAS A CONSTEXPR COPY CONSTRUCTOR + * + * An example of this is bb::field_t + * (the default constructor does not default assign values to the field_t member variables for efficiency reasons, to + * reduce the time require to construct large arrays of field elements. This means the default constructor for field_t + * cannot be constexpr) + */ +template +constexpr std::array create_array(T value, std::index_sequence /*unused*/) +{ + // cast Is to void to remove the warning: unused value + std::array result = { { (static_cast(Is), value)... } }; + return result; +} + +/** + * @brief Create a constexpr array object whose values all are 0 + * + * @tparam T + * @tparam N + * @return constexpr std::array + * + * @details Use in the same context as create_array, i.e. when encapsulated type has a default constructor that is not + * constexpr + */ +template constexpr std::array create_empty_array() +{ + return create_array(T(0), std::make_index_sequence()); +} +}; // namespace bb \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/log.hpp b/sumcheck/src/cuda/includes/barretenberg/common/log.hpp new file mode 100644 index 0000000..3e67679 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/log.hpp @@ -0,0 +1,129 @@ +#pragma once +#include "../env/logstr.hpp" +#include "../stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp" +#include +#include +#include +#include + +#define BENCHMARK_INFO_PREFIX "##BENCHMARK_INFO_PREFIX##" +#define BENCHMARK_INFO_SEPARATOR "#" +#define BENCHMARK_INFO_SUFFIX "##BENCHMARK_INFO_SUFFIX##" + +template std::string format(Args... args) +{ + std::ostringstream os; + ((os << args), ...); + return os.str(); +} + +template void benchmark_format_chain(std::ostream& os, T const& first) +{ + // We will be saving these values to a CSV file, so we can't tolerate commas + std::stringstream current_argument; + current_argument << first; + std::string current_argument_string = current_argument.str(); + std::replace(current_argument_string.begin(), current_argument_string.end(), ',', ';'); + os << current_argument_string << BENCHMARK_INFO_SUFFIX; +} + +template +void benchmark_format_chain(std::ostream& os, T const& first, Args const&... args) +{ + // We will be saving these values to a CSV file, so we can't tolerate commas + std::stringstream current_argument; + current_argument << first; + std::string current_argument_string = current_argument.str(); + std::replace(current_argument_string.begin(), current_argument_string.end(), ',', ';'); + os << current_argument_string << BENCHMARK_INFO_SEPARATOR; + benchmark_format_chain(os, args...); +} + +template std::string benchmark_format(Args... args) +{ + std::ostringstream os; + os << BENCHMARK_INFO_PREFIX; + benchmark_format_chain(os, args...); + return os.str(); +} + +#if NDEBUG +template inline void debug(Args... args) +{ + logstr(format(args...).c_str()); +} +#else +template inline void debug(Args... /*unused*/) {} +#endif + +template inline void info(Args... args) +{ + logstr(format(args...).c_str()); +} + +template inline void important(Args... args) +{ + logstr(format("important: ", args...).c_str()); +} + +/** + * @brief Info used to store circuit statistics during CI/CD with concrete structure. Writes straight to log + * + * @details Automatically appends the necessary prefix and suffix, as well as separators. 
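 *
 * Under the definitions above, a call such as `benchmark_info("Ultra", "Foo", "construct", "gates", 1024)`
 * (illustrative argument values only) is expected to log one line of roughly this shape, with any commas in
 * the arguments replaced by ';' so the line stays CSV-friendly:
 *
 *     ##BENCHMARK_INFO_PREFIX##Ultra#Foo#construct#gates#1024##BENCHMARK_INFO_SUFFIX##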
+ * + * @tparam Args + * @param args + */ +#ifdef CI +template +inline void benchmark_info(Arg1 composer, Arg2 class_name, Arg3 operation, Arg4 metric, Arg5 value) +{ + logstr(benchmark_format(composer, class_name, operation, metric, value).c_str()); +} +#else +template inline void benchmark_info(Args... /*unused*/) {} +#endif + +/** + * @brief A class for saving benchmarks and printing them all at once in the end of the function. + * + */ +class BenchmarkInfoCollator { + + std::vector saved_benchmarks; + + public: + BenchmarkInfoCollator() = default; + BenchmarkInfoCollator(const BenchmarkInfoCollator& other) = default; + BenchmarkInfoCollator(BenchmarkInfoCollator&& other) = default; + BenchmarkInfoCollator& operator=(const BenchmarkInfoCollator& other) = default; + BenchmarkInfoCollator& operator=(BenchmarkInfoCollator&& other) = default; + +/** + * @brief Info used to store circuit statistics during CI/CD with concrete structure. Stores string in vector for now + * (used to flush all benchmarks at the end of test). + * + * @details Automatically appends the necessary prefix and suffix, as well as separators. + * + * @tparam Args + * @param args + */ +#ifdef CI + template + inline void benchmark_info_deferred(Arg1 composer, Arg2 class_name, Arg3 operation, Arg4 metric, Arg5 value) + { + saved_benchmarks.push_back(benchmark_format(composer, class_name, operation, metric, value).c_str()); + } +#else + explicit BenchmarkInfoCollator(std::vector saved_benchmarks) + : saved_benchmarks(std::move(saved_benchmarks)) + {} + template inline void benchmark_info_deferred(Args... /*unused*/) {} +#endif + ~BenchmarkInfoCollator() + { + for (auto& x : saved_benchmarks) { + logstr(x.c_str()); + } + } +}; diff --git a/sumcheck/src/cuda/includes/barretenberg/common/mem.hpp b/sumcheck/src/cuda/includes/barretenberg/common/mem.hpp new file mode 100644 index 0000000..a26e615 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/mem.hpp @@ -0,0 +1,82 @@ +#pragma once +#include "log.hpp" +#include "memory.h" +#include "wasm_export.hpp" +#include +#include +// #include + +#define pad(size, alignment) (size - (size % alignment) + ((size % alignment) == 0 ? 
0 : alignment)) + +#ifdef __APPLE__ +inline void* aligned_alloc(size_t alignment, size_t size) +{ + void* t = 0; + posix_memalign(&t, alignment, size); + if (t == 0) { + info("bad alloc of size: ", size); + std::abort(); + } + return t; +} + +inline void aligned_free(void* mem) +{ + free(mem); +} +#endif + +#if defined(__linux__) || defined(__wasm__) +inline void* protected_aligned_alloc(size_t alignment, size_t size) +{ + size += (size % alignment); + void* t = nullptr; + // pad size to alignment + if (size % alignment != 0) { + size += alignment - (size % alignment); + } + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + t = aligned_alloc(alignment, size); + if (t == nullptr) { + info("bad alloc of size: ", size); + std::abort(); + } + return t; +} + +#define aligned_alloc protected_aligned_alloc + +inline void aligned_free(void* mem) +{ + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory, cppcoreguidelines-no-malloc) + free(mem); +} +#endif + +#ifdef _WIN32 +inline void* aligned_alloc(size_t alignment, size_t size) +{ + return _aligned_malloc(size, alignment); +} + +inline void aligned_free(void* mem) +{ + _aligned_free(mem); +} +#endif + +// inline void print_malloc_info() +// { +// struct mallinfo minfo = mallinfo(); + +// info("Total non-mmapped bytes (arena): ", minfo.arena); +// info("Number of free chunks (ordblks): ", minfo.ordblks); +// info("Number of fastbin blocks (smblks): ", minfo.smblks); +// info("Number of mmapped regions (hblks): ", minfo.hblks); +// info("Space allocated in mmapped regions (hblkhd): ", minfo.hblkhd); +// info("Maximum total allocated space (usmblks): ", minfo.usmblks); +// info("Space available in freed fastbin blocks (fsmblks): ", minfo.fsmblks); +// info("Total allocated space (uordblks): ", minfo.uordblks); +// info("Total free space (fordblks): ", minfo.fordblks); +// info("Top-most, releasable space (keepcost): ", minfo.keepcost); +// } \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/net.hpp b/sumcheck/src/cuda/includes/barretenberg/common/net.hpp new file mode 100644 index 0000000..f0daa8f --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/net.hpp @@ -0,0 +1,15 @@ +#pragma once + +#if defined(__linux__) || defined(__wasm__) +#include +#include +#define ntohll be64toh +#define htonll htobe64 +#endif + +inline bool is_little_endian() +{ + constexpr int num = 42; + // NOLINTNEXTLINE Nope. nope nope nope nope nope. + return (*(char*)&num == 42); +} \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/op_count.cpp b/sumcheck/src/cuda/includes/barretenberg/common/op_count.cpp new file mode 100644 index 0000000..d95d46a --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/op_count.cpp @@ -0,0 +1,104 @@ + +#include +#ifdef BB_USE_OP_COUNT +#include "op_count.hpp" +#include +#include +#include + +namespace bb::detail { + +GlobalOpCountContainer::~GlobalOpCountContainer() +{ + // This is useful for printing counts at the end of non-benchmarks. + // See op_count_google_bench.hpp for benchmarks. 
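    // (The destructor runs when the owning container, e.g. the global GLOBAL_OP_COUNTS below, is torn
    // down at process exit, so enabling print() here would dump the accumulated counts on shutdown.)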
+ // print(); +} + +void GlobalOpCountContainer::add_entry(const char* key, const std::shared_ptr& count) +{ + std::unique_lock lock(mutex); + std::stringstream ss; + ss << std::this_thread::get_id(); + counts.push_back({ key, ss.str(), count }); +} + +void GlobalOpCountContainer::print() const +{ + std::cout << "print_op_counts() START" << std::endl; + for (const Entry& entry : counts) { + if (entry.count->count > 0) { + std::cout << entry.key << "\t" << entry.count->count << "\t[thread=" << entry.thread_id << "]" << std::endl; + } + if (entry.count->time > 0) { + std::cout << entry.key << "(t)\t" << static_cast(entry.count->time) / 1000000.0 + << "ms\t[thread=" << entry.thread_id << "]" << std::endl; + } + if (entry.count->cycles > 0) { + std::cout << entry.key << "(c)\t" << entry.count->cycles << "\t[thread=" << entry.thread_id << "]" + << std::endl; + } + } + std::cout << "print_op_counts() END" << std::endl; +} + +std::map GlobalOpCountContainer::get_aggregate_counts() const +{ + std::map aggregate_counts; + for (const Entry& entry : counts) { + if (entry.count->count > 0) { + aggregate_counts[entry.key] += entry.count->count; + } + if (entry.count->time > 0) { + aggregate_counts[entry.key + "(t)"] += entry.count->time; + } + if (entry.count->cycles > 0) { + aggregate_counts[entry.key + "(c)"] += entry.count->cycles; + } + } + return aggregate_counts; +} + +void GlobalOpCountContainer::clear() +{ + std::unique_lock lock(mutex); + for (Entry& entry : counts) { + *entry.count = OpStats(); + } +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +GlobalOpCountContainer GLOBAL_OP_COUNTS; + +OpCountCycleReporter::OpCountCycleReporter(OpStats* stats) + : stats(stats) +{ +#if __clang__ && (defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)) + // Don't support any other targets but x86 clang for now, this is a bit lazy but more than fits our needs + cycles = __builtin_ia32_rdtsc(); +#endif +} +OpCountCycleReporter::~OpCountCycleReporter() +{ +#if __clang__ && (defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)) + // Don't support any other targets but x86 clang for now, this is a bit lazy but more than fits our needs + stats->count += 1; + stats->cycles += __builtin_ia32_rdtsc() - cycles; +#endif +} +OpCountTimeReporter::OpCountTimeReporter(OpStats* stats) + : stats(stats) +{ + auto now = std::chrono::high_resolution_clock::now(); + auto now_ns = std::chrono::time_point_cast(now); + time = static_cast(now_ns.time_since_epoch().count()); +} +OpCountTimeReporter::~OpCountTimeReporter() +{ + auto now = std::chrono::high_resolution_clock::now(); + auto now_ns = std::chrono::time_point_cast(now); + stats->count += 1; + stats->time += static_cast(now_ns.time_since_epoch().count()) - time; +} +} // namespace bb::detail +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/common/op_count.hpp b/sumcheck/src/cuda/includes/barretenberg/common/op_count.hpp new file mode 100644 index 0000000..053bcfc --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/op_count.hpp @@ -0,0 +1,160 @@ + +#pragma once + +#include +#ifndef BB_USE_OP_COUNT +// require a semicolon to appease formatters +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TRACK() (void)0 +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TRACK_NAME(name) (void)0 +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_CYCLES_NAME(name) (void)0 +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) 
+#define BB_OP_COUNT_TIME_NAME(name) (void)0 +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_CYCLES() (void)0 +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TIME() (void)0 +#else +/** + * Provides an abstraction that counts operations based on function names. + * For efficiency, we spread out counts across threads. + */ + +#include "./compiler_hints.hpp" +#include +#include +#include +#include +#include +#include +#include +namespace bb::detail { +// Compile-time string +// See e.g. https://www.reddit.com/r/cpp_questions/comments/pumi9r/does_c20_not_support_string_literals_as_template/ +template struct OperationLabel { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) + constexpr OperationLabel(const char (&str)[N]) + { + for (std::size_t i = 0; i < N; ++i) { + value[i] = str[i]; + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) + char value[N]; +}; + +struct OpStats { + std::size_t count = 0; + std::size_t time = 0; + std::size_t cycles = 0; +}; + +// Contains all statically known op counts +struct GlobalOpCountContainer { + public: + struct Entry { + std::string key; + std::string thread_id; + std::shared_ptr count; + }; + ~GlobalOpCountContainer(); + std::mutex mutex; + std::vector counts; + void print() const; + // NOTE: Should be called when other threads aren't active + void clear(); + void add_entry(const char* key, const std::shared_ptr& count); + std::map get_aggregate_counts() const; +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +extern GlobalOpCountContainer GLOBAL_OP_COUNTS; + +template struct GlobalOpCount { + public: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + static thread_local std::shared_ptr stats; + + static OpStats* ensure_stats() + { + if (BB_UNLIKELY(stats == nullptr)) { + stats = std::make_shared(); + GLOBAL_OP_COUNTS.add_entry(Op.value, stats); + } + return stats.get(); + } + static constexpr void increment_op_count() + { +#ifndef BB_USE_OP_COUNT_TIME_ONLY + if (std::is_constant_evaluated()) { + // We do nothing if the compiler tries to run this + return; + } + ensure_stats(); + stats->count++; +#endif + } + static constexpr void add_cycle_time(std::size_t cycles) + { +#ifndef BB_USE_OP_COUNT_TRACK_ONLY + if (std::is_constant_evaluated()) { + // We do nothing if the compiler tries to run this + return; + } + ensure_stats(); + stats->cycles += cycles; +#else + static_cast(cycles); +#endif + } + static constexpr void add_clock_time(std::size_t time) + { +#ifndef BB_USE_OP_COUNT_TRACK_ONLY + if (std::is_constant_evaluated()) { + // We do nothing if the compiler tries to run this + return; + } + ensure_stats(); + stats->time += time; +#else + static_cast(time); +#endif + } +}; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +template thread_local std::shared_ptr GlobalOpCount::stats; + +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +struct OpCountCycleReporter { + OpStats* stats; + std::size_t cycles; + OpCountCycleReporter(OpStats* stats); + ~OpCountCycleReporter(); +}; +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +struct OpCountTimeReporter { + OpStats* stats; + std::size_t time; + OpCountTimeReporter(OpStats* stats); + ~OpCountTimeReporter(); +}; +} // namespace bb::detail + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TRACK_NAME(name) bb::detail::GlobalOpCount::increment_op_count() +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TRACK() 
BB_OP_COUNT_TRACK_NAME(__func__) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_CYCLES_NAME(name) \ + bb::detail::OpCountCycleReporter __bb_op_count_cyles(bb::detail::GlobalOpCount::ensure_stats()) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_CYCLES() BB_OP_COUNT_CYCLES_NAME(__func__) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TIME_NAME(name) \ + bb::detail::OpCountTimeReporter __bb_op_count_time(bb::detail::GlobalOpCount::ensure_stats()) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BB_OP_COUNT_TIME() BB_OP_COUNT_TIME_NAME(__func__) +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/common/slab_allocator.cpp b/sumcheck/src/cuda/includes/barretenberg/common/slab_allocator.cpp new file mode 100644 index 0000000..fb787c9 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/slab_allocator.cpp @@ -0,0 +1,243 @@ +#include "slab_allocator.hpp" +#include +#include +#include +#include +#include +#include + +#define LOGGING 0 + +/** + * If we can guarantee that all slabs will be released before the allocator is destroyed, we wouldn't need this. + * However, there is (and maybe again) cases where a global is holding onto a slab. In such a case you will have + * issues if the runtime frees the allocator before the slab is released. The effect is subtle, so it's worth + * protecting against rather than just saying "don't do globals". But you know, don't do globals... + * (Irony of global slab allocator noted). + */ +namespace { +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +bool allocator_destroyed = false; + +// Slabs that are being manually managed by the user. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +std::unordered_map> manual_slabs; +#ifndef NO_MULTITHREADING +// The manual slabs unordered map is not thread-safe, so we need to manage access to it when multithreaded. +std::mutex manual_slabs_mutex; +#endif +template inline void dbg_info(Args... args) +{ +#if LOGGING == 1 + info(args...); +#else + // Suppress warning. + (void)(sizeof...(args)); +#endif +} + +/** + * Allows preallocating memory slabs sized to serve the fact that these slabs of memory follow certain sizing + * patterns and numbers based on prover system type and circuit size. Without the slab allocator, memory + * fragmentation prevents proof construction when approaching memory space limits (4GB in WASM). + * + * If no circuit_size_hint is given to the constructor, it behaves as a standard memory allocator. + */ +class SlabAllocator { + private: + size_t circuit_size_hint_; + std::map> memory_store; +#ifndef NO_MULTITHREADING + std::mutex memory_store_mutex; +#endif + + public: + ~SlabAllocator(); + SlabAllocator() = default; + SlabAllocator(const SlabAllocator& other) = delete; + SlabAllocator(SlabAllocator&& other) = delete; + SlabAllocator& operator=(const SlabAllocator& other) = delete; + SlabAllocator& operator=(SlabAllocator&& other) = delete; + + void init(size_t circuit_size_hint); + + std::shared_ptr get(size_t size); + + size_t get_total_size(); + + private: + void release(void* ptr, size_t size); +}; + +SlabAllocator::~SlabAllocator() +{ + allocator_destroyed = true; + for (auto& e : memory_store) { + for (auto& p : e.second) { + aligned_free(p); + } + } +} + +void SlabAllocator::init(size_t circuit_size_hint) +{ + if (circuit_size_hint <= circuit_size_hint_) { + return; + } + + circuit_size_hint_ = circuit_size_hint; + + // Free any existing slabs. 
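    // Re-initialisation assumes no previously handed-out slabs are still live: as the header warns, a
    // slab released after this point is pooled again at its old (now undersized) size.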
+ for (auto& e : memory_store) { + for (auto& p : e.second) { + aligned_free(p); + } + } + memory_store.clear(); + + dbg_info("slab allocator initing for size: ", circuit_size_hint); + + if (circuit_size_hint == 0ULL) { + return; + } + + // Over-allocate because we know there are requests for circuit_size + n. (somewhat arbitrary n = 512) + size_t overalloc = 512; + size_t base_size = circuit_size_hint + overalloc; + + std::map prealloc_num; + + // Size comments below assume a base (circuit) size of 2^19, 524288 bytes. + + // /* 0.5 MiB */ prealloc_num[base_size * 1] = 2; // Batch invert skipped temporary. + // /* 2 MiB */ prealloc_num[base_size * 4] = 4 + // Composer base wire vectors. + // 1; // Miscellaneous. + // /* 6 MiB */ prealloc_num[base_size * 12] = 2 + // next_var_index, prev_var_index + // 2; // real_variable_index, real_variable_tags + /* 16 MiB */ prealloc_num[base_size * 32] = 11; // Composer base selector vectors. + /* 32 MiB */ prealloc_num[base_size * 32 * 2] = 1; // Miscellaneous. + /* 50 MiB */ prealloc_num[base_size * 32 * 3] = 1; // Variables. + /* 64 MiB */ prealloc_num[base_size * 32 * 4] = 1 + // SRS monomial points. + 4 + // Coset-fft wires. + 15 + // Coset-fft constraint selectors. + 8 + // Coset-fft perm selectors. + 1 + // Coset-fft sorted poly. + 1 + // Pippenger point_schedule. + 4; // Miscellaneous. + /* 128 MiB */ prealloc_num[base_size * 32 * 8] = 1 + // Proving key evaluation domain roots. + 2; // Pippenger point_pairs. + + for (auto& e : prealloc_num) { + for (size_t i = 0; i < e.second; ++i) { + auto size = e.first; + memory_store[size].push_back(aligned_alloc(32, size)); + dbg_info("Allocated memory slab of size: ", size, " total: ", get_total_size()); + } + } +} + +std::shared_ptr SlabAllocator::get(size_t req_size) +{ +#ifndef NO_MULTITHREADING + std::unique_lock lock(memory_store_mutex); +#endif + + auto it = memory_store.lower_bound(req_size); + + // Can use a preallocated slab that is less than 2 times the requested size. 
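    // lower_bound() yields the smallest pooled slab size >= req_size; the < 2x check below avoids
    // burning a much larger slab on a comparatively small request.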
+ if (it != memory_store.end() && it->first < req_size * 2) { + size_t size = it->first; + auto* ptr = it->second.back(); + it->second.pop_back(); + + if (it->second.empty()) { + memory_store.erase(it); + } + + if (req_size >= circuit_size_hint_ && size > req_size + req_size / 10) { + dbg_info("WARNING: Using memory slab of size: ", + size, + " for requested ", + req_size, + " total: ", + get_total_size()); + } else { + dbg_info("Reusing memory slab of size: ", size, " for requested ", req_size, " total: ", get_total_size()); + } + + return { ptr, [this, size](void* p) { + if (allocator_destroyed) { + aligned_free(p); + return; + } + this->release(p, size); + } }; + } + + if (req_size > static_cast(1024 * 1024)) { + dbg_info("WARNING: Allocating unmanaged memory slab of size: ", req_size); + } + if (req_size % 32 == 0) { + return { aligned_alloc(32, req_size), aligned_free }; + } + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + return { malloc(req_size), free }; +} + +size_t SlabAllocator::get_total_size() +{ + return std::accumulate(memory_store.begin(), memory_store.end(), size_t{ 0 }, [](size_t acc, const auto& kv) { + return acc + kv.first * kv.second.size(); + }); +} + +void SlabAllocator::release(void* ptr, size_t size) +{ +#ifndef NO_MULTITHREADING + std::unique_lock lock(memory_store_mutex); +#endif + memory_store[size].push_back(ptr); + // dbg_info("Pooled poly memory of size: ", size, " total: ", get_total_size()); +} +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +SlabAllocator allocator; +} // namespace + +namespace bb { +void init_slab_allocator(size_t circuit_subgroup_size) +{ + allocator.init(circuit_subgroup_size); +} + +// auto init = ([]() { +// init_slab_allocator(524288); +// return 0; +// })(); + +std::shared_ptr get_mem_slab(size_t size) +{ + return allocator.get(size); +} + +void* get_mem_slab_raw(size_t size) +{ + auto slab = get_mem_slab(size); +#ifndef NO_MULTITHREADING + std::unique_lock lock(manual_slabs_mutex); +#endif + manual_slabs[slab.get()] = slab; + return slab.get(); +} + +void free_mem_slab_raw(void* p) +{ + if (allocator_destroyed) { + aligned_free(p); + return; + } +#ifndef NO_MULTITHREADING + std::unique_lock lock(manual_slabs_mutex); +#endif + manual_slabs.erase(p); +} +} // namespace bb diff --git a/sumcheck/src/cuda/includes/barretenberg/common/slab_allocator.hpp b/sumcheck/src/cuda/includes/barretenberg/common/slab_allocator.hpp new file mode 100644 index 0000000..1eb03b1 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/slab_allocator.hpp @@ -0,0 +1,78 @@ +#pragma once +#include "./assert.hpp" +#include "./log.hpp" +#include +#include +#include +#include +#ifndef NO_MULTITHREADING +#include +#endif + +namespace bb { + +/** + * Allocates a bunch of memory slabs sized to serve an UltraPLONK proof construction. + * If you want normal memory allocator behavior, just don't call this init function. + * + * WARNING: If client code is still holding onto slabs from previous use, when those slabs + * are released they'll end up back in the allocator. That's probably not desired as presumably + * those slabs are now too small, so they're effectively leaked. But good client code should be releasing + * it's resources promptly anyway. It's not considered "proper use" to call init, take slab, and call init + * again, before releasing the slab. + * + * TODO: Take a composer type and allocate slabs according to those requirements? + * TODO: De-globalise. Init the allocator and pass around. 
Use a PolynomialFactory (PolynomialStore?). + * TODO: Consider removing, but once due-dilligence has been done that we no longer have memory limitations. + */ +void init_slab_allocator(size_t circuit_subgroup_size); + +/** + * Returns a slab from the preallocated pool of slabs, or fallback to a new heap allocation (32 byte aligned). + * Ref counted result so no need to manually free. + */ +std::shared_ptr get_mem_slab(size_t size); + +/** + * Sometimes you want a raw pointer to a slab so you can manage when it's released manually (e.g. c_binds, containers). + * This still gets a slab with a shared_ptr, but holds the shared_ptr internally until free_mem_slab_raw is called. + */ +void* get_mem_slab_raw(size_t size); + +void free_mem_slab_raw(void*); + +/** + * Allocator for containers such as std::vector. Makes them leverage the underlying slab allocator where possible. + */ +template class ContainerSlabAllocator { + public: + using value_type = T; + using pointer = T*; + using const_pointer = const T*; + using size_type = std::size_t; + + template struct rebind { + using other = ContainerSlabAllocator; + }; + + pointer allocate(size_type n) + { + // info("ContainerSlabAllocator allocating: ", n * sizeof(T)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(get_mem_slab_raw(n * sizeof(T))); + } + + void deallocate(pointer p, size_type /*unused*/) { free_mem_slab_raw(p); } + + friend bool operator==(const ContainerSlabAllocator& /*unused*/, const ContainerSlabAllocator& /*unused*/) + { + return true; + } + + friend bool operator!=(const ContainerSlabAllocator& /*unused*/, const ContainerSlabAllocator& /*unused*/) + { + return false; + } +}; + +} // namespace bb \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/throw_or_abort.hpp b/sumcheck/src/cuda/includes/barretenberg/common/throw_or_abort.hpp new file mode 100644 index 0000000..bb140d8 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/throw_or_abort.hpp @@ -0,0 +1,13 @@ +#pragma once +#include "log.hpp" +#include + +inline void throw_or_abort [[noreturn]] (std::string const& err) +{ +#ifndef __wasm__ + throw std::runtime_error(err); +#else + info("abort: ", err); + std::abort(); +#endif +} \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/common/wasm_export.hpp b/sumcheck/src/cuda/includes/barretenberg/common/wasm_export.hpp new file mode 100644 index 0000000..b9ba4ef --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/common/wasm_export.hpp @@ -0,0 +1,16 @@ +#pragma once +#ifdef __clang__ +#define WASM_EXPORT extern "C" __attribute__((visibility("default"))) __attribute__((annotate("wasm_export"))) +#define ASYNC_WASM_EXPORT \ + extern "C" __attribute__((visibility("default"))) __attribute__((annotate("async_wasm_export"))) +#else +#define WASM_EXPORT extern "C" __attribute__((visibility("default"))) +#define ASYNC_WASM_EXPORT extern "C" __attribute__((visibility("default"))) +#endif + +#ifdef __wasm__ +// Allow linker to not link this +#define WASM_IMPORT(name) extern "C" __attribute__((import_module("env"), import_name(name))) +#else +#define WASM_IMPORT(name) extern "C" +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/CMakeLists.txt b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/CMakeLists.txt new file mode 100644 index 0000000..8fad42d --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/CMakeLists.txt @@ -0,0 +1 @@ 
+barretenberg_module(crypto_blake3s) \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3-impl.hpp b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3-impl.hpp new file mode 100644 index 0000000..5d6c3f9 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3-impl.hpp @@ -0,0 +1,80 @@ +#pragma once +/* + BLAKE3 reference source code package - C implementations + + Intellectual property: + + The Rust code is copyright Jack O'Connor, 2019-2020. + The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020. + The assembly code is copyright Samuel Neves, 2019-2020. + + This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache + License 2.0. + + - CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0 + - Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0 + + More information about the BLAKE3 hash function can be found at + https://github.com/BLAKE3-team/BLAKE3. +*/ + +#ifndef BLAKE3_IMPL_H +#define BLAKE3_IMPL_H + +#include +#include +#include + +#include "blake3s.hpp" + +namespace blake3 { + +// Right rotates 32 bit inputs +constexpr uint32_t rotr32(uint32_t w, uint32_t c) +{ + return (w >> c) | (w << (32 - c)); +} + +constexpr uint32_t load32(const uint8_t* src) +{ + return (static_cast(src[0]) << 0) | (static_cast(src[1]) << 8) | + (static_cast(src[2]) << 16) | (static_cast(src[3]) << 24); +} + +constexpr void load_key_words(const std::array& key, key_array& key_words) +{ + key_words[0] = load32(&key[0]); + key_words[1] = load32(&key[4]); + key_words[2] = load32(&key[8]); + key_words[3] = load32(&key[12]); + key_words[4] = load32(&key[16]); + key_words[5] = load32(&key[20]); + key_words[6] = load32(&key[24]); + key_words[7] = load32(&key[28]); +} + +constexpr void store32(uint8_t* dst, uint32_t w) +{ + dst[0] = static_cast(w >> 0); + dst[1] = static_cast(w >> 8); + dst[2] = static_cast(w >> 16); + dst[3] = static_cast(w >> 24); +} + +constexpr void store_cv_words(out_array& bytes_out, key_array& cv_words) +{ + store32(&bytes_out[0], cv_words[0]); + store32(&bytes_out[4], cv_words[1]); + store32(&bytes_out[8], cv_words[2]); + store32(&bytes_out[12], cv_words[3]); + store32(&bytes_out[16], cv_words[4]); + store32(&bytes_out[20], cv_words[5]); + store32(&bytes_out[24], cv_words[6]); + store32(&bytes_out[28], cv_words[7]); +} + +} // namespace blake3 + +#include "blake3s.tcc" + +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3s.hpp b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3s.hpp new file mode 100644 index 0000000..a9bb85d --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3s.hpp @@ -0,0 +1,113 @@ +/* + BLAKE3 reference source code package - C implementations + + Intellectual property: + + The Rust code is copyright Jack O'Connor, 2019-2020. + The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020. + The assembly code is copyright Samuel Neves, 2019-2020. + + This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache + License 2.0. + + - CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0 + - Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0 + + More information about the BLAKE3 hash function can be found at + https://github.com/BLAKE3-team/BLAKE3. + + + NOTE: We have modified the original code from the BLAKE3 reference C implementation. 
+ The following code works ONLY for inputs of size less than 1024 bytes. This kind of constraint + on the input size greatly simplifies the code and helps us get rid of the recursive merkle-tree + like operations on chunks (data of size 1024 bytes). This is because we would always be using BLAKE3 + hashing for inputs of size 32 bytes (or lesser) in barretenberg. The full C++ version of BLAKE3 + from the original authors is in the module `../crypto/blake3s_full`. + + Also, the length of the output in this specific implementation is fixed at 32 bytes which is the only + version relevant to Barretenberg. +*/ +#pragma once +#include +#include +#include +#include +#include + +namespace blake3 { + +// internal flags +enum blake3_flags { + CHUNK_START = 1 << 0, + CHUNK_END = 1 << 1, + PARENT = 1 << 2, + ROOT = 1 << 3, + KEYED_HASH = 1 << 4, + DERIVE_KEY_CONTEXT = 1 << 5, + DERIVE_KEY_MATERIAL = 1 << 6, +}; + +// constants +enum blake3s_constant { + BLAKE3_KEY_LEN = 32, + BLAKE3_OUT_LEN = 32, + BLAKE3_BLOCK_LEN = 64, + BLAKE3_CHUNK_LEN = 1024, + BLAKE3_MAX_DEPTH = 54 +}; + +using key_array = std::array; +using block_array = std::array; +using state_array = std::array; +using out_array = std::array; + +static constexpr key_array IV = { 0x6A09E667UL, 0xBB67AE85UL, 0x3C6EF372UL, 0xA54FF53AUL, + 0x510E527FUL, 0x9B05688CUL, 0x1F83D9ABUL, 0x5BE0CD19UL }; + +static constexpr std::array MSG_SCHEDULE_0 = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }; +static constexpr std::array MSG_SCHEDULE_1 = { 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 }; +static constexpr std::array MSG_SCHEDULE_2 = { 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 }; +static constexpr std::array MSG_SCHEDULE_3 = { 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 }; +static constexpr std::array MSG_SCHEDULE_4 = { 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 }; +static constexpr std::array MSG_SCHEDULE_5 = { 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 }; +static constexpr std::array MSG_SCHEDULE_6 = { 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 }; +static constexpr std::array, 7> MSG_SCHEDULE = { + MSG_SCHEDULE_0, MSG_SCHEDULE_1, MSG_SCHEDULE_2, MSG_SCHEDULE_3, MSG_SCHEDULE_4, MSG_SCHEDULE_5, MSG_SCHEDULE_6, +}; + +struct blake3_hasher { + key_array key; + key_array cv; + block_array buf; + uint8_t buf_len = 0; + uint8_t blocks_compressed = 0; + uint8_t flags = 0; +}; + +inline const char* blake3_version() +{ + static const std::string version = "0.3.7"; + return version.c_str(); +} + +constexpr void blake3_hasher_init(blake3_hasher* self); +constexpr void blake3_hasher_update(blake3_hasher* self, const uint8_t* input, size_t input_len); +constexpr void blake3_hasher_finalize(const blake3_hasher* self, uint8_t* out); + +constexpr void g(state_array& state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y); +constexpr void round_fn(state_array& state, const uint32_t* msg, size_t round); + +constexpr void compress_pre( + state_array& state, const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags); + +constexpr void blake3_compress_in_place(key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags); + +constexpr void blake3_compress_xof( + const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags, uint8_t* out); + +constexpr std::array blake3s_constexpr(const uint8_t* input, size_t input_size); +inline std::vector blake3s(std::vector const& input); + +} // namespace blake3 + +#include "blake3-impl.hpp" diff --git 
a/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3s.tcc b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3s.tcc new file mode 100644 index 0000000..9e2a5ef --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/blake3s.tcc @@ -0,0 +1,263 @@ +#pragma once +/* + BLAKE3 reference source code package - C implementations + + Intellectual property: + + The Rust code is copyright Jack O'Connor, 2019-2020. + The C code is copyright Samuel Neves and Jack O'Connor, 2019-2020. + The assembly code is copyright Samuel Neves, 2019-2020. + + This work is released into the public domain with CC0 1.0. Alternatively, it is licensed under the Apache + License 2.0. + + - CC0 1.0 Universal : http://creativecommons.org/publicdomain/zero/1.0 + - Apache 2.0 : http://www.apache.org/licenses/LICENSE-2.0 + + More information about the BLAKE3 hash function can be found at + https://github.com/BLAKE3-team/BLAKE3. + + + NOTE: We have modified the original code from the BLAKE3 reference C implementation. + The following code works ONLY for inputs of size less than 1024 bytes. This kind of constraint + on the input size greatly simplifies the code and helps us get rid of the recursive merkle-tree + like operations on chunks (data of size 1024 bytes). This is because we would always be using BLAKE3 + hashing for inputs of size 32 bytes (or lesser) in barretenberg. The full C++ version of BLAKE3 + from the original authors is in the module `../crypto/blake3s_full`. + + Also, the length of the output in this specific implementation is fixed at 32 bytes which is the only + version relevant to Barretenberg. +*/ + +#include +#include + +#include "blake3s.hpp" + +namespace blake3 { + +/* + * Core Blake3s functions. These are similar to that of Blake2s except for a few + * constant parameters and fewer rounds. + * + */ +constexpr void g(state_array& state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y) +{ + state[a] = state[a] + state[b] + x; + state[d] = rotr32(state[d] ^ state[a], 16); + state[c] = state[c] + state[d]; + state[b] = rotr32(state[b] ^ state[c], 12); + state[a] = state[a] + state[b] + y; + state[d] = rotr32(state[d] ^ state[a], 8); + state[c] = state[c] + state[d]; + state[b] = rotr32(state[b] ^ state[c], 7); +} + +constexpr void round_fn(state_array& state, const uint32_t* msg, size_t round) +{ + // Select the message schedule based on the round. + const auto schedule = MSG_SCHEDULE[round]; + + // Mix the columns. + g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); + g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); + g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); + g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); + + // Mix the rows. 
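    // (These four calls act on the diagonals of the 4x4 state matrix.)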
+ g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); + g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); + g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); + g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); +} + +constexpr void compress_pre( + state_array& state, const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags) +{ + std::array block_words; + block_words[0] = load32(&block[0]); + block_words[1] = load32(&block[4]); + block_words[2] = load32(&block[8]); + block_words[3] = load32(&block[12]); + block_words[4] = load32(&block[16]); + block_words[5] = load32(&block[20]); + block_words[6] = load32(&block[24]); + block_words[7] = load32(&block[28]); + block_words[8] = load32(&block[32]); + block_words[9] = load32(&block[36]); + block_words[10] = load32(&block[40]); + block_words[11] = load32(&block[44]); + block_words[12] = load32(&block[48]); + block_words[13] = load32(&block[52]); + block_words[14] = load32(&block[56]); + block_words[15] = load32(&block[60]); + + state[0] = cv[0]; + state[1] = cv[1]; + state[2] = cv[2]; + state[3] = cv[3]; + state[4] = cv[4]; + state[5] = cv[5]; + state[6] = cv[6]; + state[7] = cv[7]; + state[8] = IV[0]; + state[9] = IV[1]; + state[10] = IV[2]; + state[11] = IV[3]; + state[12] = 0; + state[13] = 0; + state[14] = static_cast(block_len); + state[15] = static_cast(flags); + + round_fn(state, &block_words[0], 0); + round_fn(state, &block_words[0], 1); + round_fn(state, &block_words[0], 2); + round_fn(state, &block_words[0], 3); + round_fn(state, &block_words[0], 4); + round_fn(state, &block_words[0], 5); + round_fn(state, &block_words[0], 6); +} + +constexpr void blake3_compress_in_place(key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags) +{ + state_array state; + compress_pre(state, cv, block, block_len, flags); + cv[0] = state[0] ^ state[8]; + cv[1] = state[1] ^ state[9]; + cv[2] = state[2] ^ state[10]; + cv[3] = state[3] ^ state[11]; + cv[4] = state[4] ^ state[12]; + cv[5] = state[5] ^ state[13]; + cv[6] = state[6] ^ state[14]; + cv[7] = state[7] ^ state[15]; +} + +constexpr void blake3_compress_xof( + const key_array& cv, const uint8_t* block, uint8_t block_len, uint8_t flags, uint8_t* out) +{ + state_array state; + compress_pre(state, cv, block, block_len, flags); + + store32(&out[0], state[0] ^ state[8]); + store32(&out[4], state[1] ^ state[9]); + store32(&out[8], state[2] ^ state[10]); + store32(&out[12], state[3] ^ state[11]); + store32(&out[16], state[4] ^ state[12]); + store32(&out[20], state[5] ^ state[13]); + store32(&out[24], state[6] ^ state[14]); + store32(&out[28], state[7] ^ state[15]); + store32(&out[32], state[8] ^ cv[0]); + store32(&out[36], state[9] ^ cv[1]); + store32(&out[40], state[10] ^ cv[2]); + store32(&out[44], state[11] ^ cv[3]); + store32(&out[48], state[12] ^ cv[4]); + store32(&out[52], state[13] ^ cv[5]); + store32(&out[56], state[14] ^ cv[6]); + store32(&out[60], state[15] ^ cv[7]); +} + +constexpr uint8_t maybe_start_flag(const blake3_hasher* self) +{ + if (self->blocks_compressed == 0) { + return CHUNK_START; + } + return 0; +} + +struct output_t { + key_array input_cv = {}; + block_array block = {}; + uint8_t block_len = 0; + uint8_t flags = 0; +}; + +constexpr output_t make_output(const key_array& input_cv, const uint8_t* block, uint8_t block_len, uint8_t flags) +{ + output_t ret; + for (size_t i = 0; i < (BLAKE3_OUT_LEN >> 2); ++i) { + ret.input_cv[i] = input_cv[i]; + } + for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) { + ret.block[i] = 
block[i]; + } + ret.block_len = block_len; + ret.flags = flags; + return ret; +} + +constexpr void blake3_hasher_init(blake3_hasher* self) +{ + for (size_t i = 0; i < (BLAKE3_KEY_LEN >> 2); ++i) { + self->key[i] = IV[i]; + self->cv[i] = IV[i]; + } + for (size_t i = 0; i < BLAKE3_BLOCK_LEN; i++) { + self->buf[i] = 0; + } + self->buf_len = 0; + self->blocks_compressed = 0; + self->flags = 0; +} + +constexpr void blake3_hasher_update(blake3_hasher* self, const uint8_t* input, size_t input_len) +{ + if (input_len == 0) { + return; + } + + while (input_len > BLAKE3_BLOCK_LEN) { + blake3_compress_in_place(self->cv, input, BLAKE3_BLOCK_LEN, self->flags | maybe_start_flag(self)); + + self->blocks_compressed = static_cast(self->blocks_compressed + 1U); + input += BLAKE3_BLOCK_LEN; + input_len -= BLAKE3_BLOCK_LEN; + } + + size_t take = BLAKE3_BLOCK_LEN - (static_cast(self->buf_len)); + if (take > input_len) { + take = input_len; + } + uint8_t* dest = &self->buf[0] + (static_cast(self->buf_len)); + for (size_t i = 0; i < take; i++) { + dest[i] = input[i]; + } + + self->buf_len = static_cast(self->buf_len + static_cast(take)); + input_len -= take; +} + +constexpr void blake3_hasher_finalize(const blake3_hasher* self, uint8_t* out) +{ + uint8_t block_flags = self->flags | maybe_start_flag(self) | CHUNK_END; + output_t output = make_output(self->cv, &self->buf[0], self->buf_len, block_flags); + + block_array wide_buf; + blake3_compress_xof(output.input_cv, &output.block[0], output.block_len, output.flags | ROOT, &wide_buf[0]); + for (size_t i = 0; i < BLAKE3_OUT_LEN; i++) { + out[i] = wide_buf[i]; + } +} + +std::vector blake3s(std::vector const& input) +{ + blake3_hasher hasher; + blake3_hasher_init(&hasher); + blake3_hasher_update(&hasher, static_cast(input.data()), input.size()); + + std::vector output(BLAKE3_OUT_LEN); + blake3_hasher_finalize(&hasher, &output[0]); + return output; +} + +constexpr std::array blake3s_constexpr(const uint8_t* input, const size_t input_size) +{ + blake3_hasher hasher; + blake3_hasher_init(&hasher); + blake3_hasher_update(&hasher, input, input_size); + + std::array output; + blake3_hasher_finalize(&hasher, &output[0]); + return output; +} + +} // namespace blake3 diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/c_bind.cpp b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/c_bind.cpp new file mode 100644 index 0000000..7c195ae --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/blake3s/c_bind.cpp @@ -0,0 +1,11 @@ +#include "../../common/wasm_export.hpp" +#include "../../ecc/curves/bn254/fr.hpp" +#include "blake3s.hpp" + +WASM_EXPORT void blake3s_to_field(uint8_t const* data, size_t length, uint8_t* r) +{ + std::vector inputv(data, data + length); + std::vector output = blake3::blake3s(inputv); + auto result = bb::fr::serialize_from_buffer(output.data()); + bb::fr::serialize_to_buffer(result, r); +} diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/CMakeLists.txt b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/CMakeLists.txt new file mode 100644 index 0000000..c447763 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/CMakeLists.txt @@ -0,0 +1 @@ +barretenberg_module(crypto_keccak) \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/hash_types.hpp b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/hash_types.hpp new file mode 100644 index 0000000..cac2a35 --- /dev/null +++ 
b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/hash_types.hpp @@ -0,0 +1,20 @@ +/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm. + * Copyright 2018-2019 Pawel Bylica. + * Licensed under the Apache License, Version 2.0. + */ + +#pragma once + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +struct keccak256 { + uint64_t word64s[4]; +}; + +#ifdef __cplusplus +} +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.cpp b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.cpp new file mode 100644 index 0000000..6548189 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.cpp @@ -0,0 +1,133 @@ +/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm. + * Copyright 2018-2019 Pawel Bylica. + * Licensed under the Apache License, Version 2.0. + */ + +#include "keccak.hpp" + +#include "./hash_types.hpp" + +#if _MSC_VER +#include <string.h> +#define __builtin_memcpy memcpy +#endif + +#if _WIN32 +/* On Windows assume little endian. */ +#define __LITTLE_ENDIAN 1234 +#define __BIG_ENDIAN 4321 +#define __BYTE_ORDER __LITTLE_ENDIAN +#elif __APPLE__ +#include <machine/endian.h> +#else +#include <endian.h> +#endif + +#if __BYTE_ORDER == __LITTLE_ENDIAN +#define to_le64(X) X +#else +#define to_le64(X) __builtin_bswap64(X) +#endif + +#if __BYTE_ORDER == __LITTLE_ENDIAN +#define to_be64(X) __builtin_bswap64(X) +#else +#define to_be64(X) X +#endif + +/** Loads a 64-bit integer from the given memory location as a little-endian number. */ +static inline uint64_t load_le(const uint8_t* data) +{ + /* memcpy is the best way of expressing the intention. Every compiler will + optimize it to a single load instruction if the target architecture + supports unaligned memory access (GCC and clang even in O0). + This is a great trick because we are violating C/C++ memory alignment + restrictions with no performance penalty. 
*/ + uint64_t word; + __builtin_memcpy(&word, data, sizeof(word)); + return to_le64(word); +} + +static inline void keccak(uint64_t* out, size_t bits, const uint8_t* data, size_t size) +{ + static const size_t word_size = sizeof(uint64_t); + const size_t hash_size = bits / 8; + const size_t block_size = (1600 - bits * 2) / 8; + + size_t i; + uint64_t* state_iter; + uint64_t last_word = 0; + uint8_t* last_word_iter = (uint8_t*)&last_word; + + uint64_t state[25] = { 0 }; + + while (size >= block_size) { + for (i = 0; i < (block_size / word_size); ++i) { + state[i] ^= load_le(data); + data += word_size; + } + + ethash_keccakf1600(state); + + size -= block_size; + } + + state_iter = state; + + while (size >= word_size) { + *state_iter ^= load_le(data); + ++state_iter; + data += word_size; + size -= word_size; + } + + while (size > 0) { + *last_word_iter = *data; + ++last_word_iter; + ++data; + --size; + } + *last_word_iter = 0x01; + *state_iter ^= to_le64(last_word); + + state[(block_size / word_size) - 1] ^= 0x8000000000000000; + + ethash_keccakf1600(state); + + for (i = 0; i < (hash_size / word_size); ++i) + out[i] = to_le64(state[i]); +} + +struct keccak256 ethash_keccak256(const uint8_t* data, size_t size) NOEXCEPT +{ + struct keccak256 hash; + keccak(hash.word64s, 256, data, size); + return hash; +} + +struct keccak256 hash_field_elements(const uint64_t* limbs, size_t num_elements) +{ + uint8_t input_buffer[num_elements * 32]; + + for (size_t i = 0; i < num_elements; ++i) { + for (size_t j = 0; j < 4; ++j) { + uint64_t word = (limbs[i * 4 + j]); + size_t idx = i * 32 + j * 8; + input_buffer[idx] = (uint8_t)((word >> 56) & 0xff); + input_buffer[idx + 1] = (uint8_t)((word >> 48) & 0xff); + input_buffer[idx + 2] = (uint8_t)((word >> 40) & 0xff); + input_buffer[idx + 3] = (uint8_t)((word >> 32) & 0xff); + input_buffer[idx + 4] = (uint8_t)((word >> 24) & 0xff); + input_buffer[idx + 5] = (uint8_t)((word >> 16) & 0xff); + input_buffer[idx + 6] = (uint8_t)((word >> 8) & 0xff); + input_buffer[idx + 7] = (uint8_t)(word & 0xff); + } + } + + return ethash_keccak256(input_buffer, num_elements * 32); +} + +struct keccak256 hash_field_element(const uint64_t* limb) +{ + return hash_field_elements(limb, 1); +} diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.hpp b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.hpp new file mode 100644 index 0000000..8ad5e76 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccak.hpp @@ -0,0 +1,40 @@ +/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm. + * Copyright 2018-2019 Pawel Bylica. + * Licensed under the Apache License, Version 2.0. + */ + +#pragma once + +#include "./hash_types.hpp" + +#include + +#ifdef __cplusplus +#define NOEXCEPT noexcept +#else +#define NOEXCEPT +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * The Keccak-f[1600] function. + * + * The implementation of the Keccak-f function with 1600-bit width of the permutation (b). + * The size of the state is also 1600 bit what gives 25 64-bit words. + * + * @param state The state of 25 64-bit words on which the permutation is to be performed. 
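 *
 * (Within this module the permutation is driven by keccak() in keccak.cpp, which XORs each absorbed
 * block into the state and then applies ethash_keccakf1600.)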
+ */ +void ethash_keccakf1600(uint64_t state[25]) NOEXCEPT; + +struct keccak256 ethash_keccak256(const uint8_t* data, size_t size) NOEXCEPT; + +struct keccak256 hash_field_elements(const uint64_t* limbs, size_t num_elements); + +struct keccak256 hash_field_element(const uint64_t* limb); + +#ifdef __cplusplus +} +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccakf1600.cpp b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccakf1600.cpp new file mode 100644 index 0000000..919b49d --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/crypto/keccak/keccakf1600.cpp @@ -0,0 +1,235 @@ +/* ethash: C/C++ implementation of Ethash, the Ethereum Proof of Work algorithm. + * Copyright 2018-2019 Pawel Bylica. + * Licensed under the Apache License, Version 2.0. + */ + +#include "keccak.hpp" +#include + +static uint64_t rol(uint64_t x, unsigned s) +{ + return (x << s) | (x >> (64 - s)); +} + +static const uint64_t round_constants[24] = { + 0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, 0x000000000000808b, + 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, 0x000000000000008a, 0x0000000000000088, + 0x0000000080008009, 0x000000008000000a, 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, + 0x8000000000008003, 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a, + 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008, +}; + +void ethash_keccakf1600(uint64_t state[25]) NOEXCEPT +{ + /* The implementation based on the "simple" implementation by Ronny Van Keer. */ + + int round; + + uint64_t Aba, Abe, Abi, Abo, Abu; + uint64_t Aga, Age, Agi, Ago, Agu; + uint64_t Aka, Ake, Aki, Ako, Aku; + uint64_t Ama, Ame, Ami, Amo, Amu; + uint64_t Asa, Ase, Asi, Aso, Asu; + + uint64_t Eba, Ebe, Ebi, Ebo, Ebu; + uint64_t Ega, Ege, Egi, Ego, Egu; + uint64_t Eka, Eke, Eki, Eko, Eku; + uint64_t Ema, Eme, Emi, Emo, Emu; + uint64_t Esa, Ese, Esi, Eso, Esu; + + uint64_t Ba, Be, Bi, Bo, Bu; + + uint64_t Da, De, Di, Do, Du; + + Aba = state[0]; + Abe = state[1]; + Abi = state[2]; + Abo = state[3]; + Abu = state[4]; + Aga = state[5]; + Age = state[6]; + Agi = state[7]; + Ago = state[8]; + Agu = state[9]; + Aka = state[10]; + Ake = state[11]; + Aki = state[12]; + Ako = state[13]; + Aku = state[14]; + Ama = state[15]; + Ame = state[16]; + Ami = state[17]; + Amo = state[18]; + Amu = state[19]; + Asa = state[20]; + Ase = state[21]; + Asi = state[22]; + Aso = state[23]; + Asu = state[24]; + + for (round = 0; round < 24; round += 2) { + /* Round (round + 0): Axx -> Exx */ + + Ba = Aba ^ Aga ^ Aka ^ Ama ^ Asa; + Be = Abe ^ Age ^ Ake ^ Ame ^ Ase; + Bi = Abi ^ Agi ^ Aki ^ Ami ^ Asi; + Bo = Abo ^ Ago ^ Ako ^ Amo ^ Aso; + Bu = Abu ^ Agu ^ Aku ^ Amu ^ Asu; + + Da = Bu ^ rol(Be, 1); + De = Ba ^ rol(Bi, 1); + Di = Be ^ rol(Bo, 1); + Do = Bi ^ rol(Bu, 1); + Du = Bo ^ rol(Ba, 1); + + Ba = Aba ^ Da; + Be = rol(Age ^ De, 44); + Bi = rol(Aki ^ Di, 43); + Bo = rol(Amo ^ Do, 21); + Bu = rol(Asu ^ Du, 14); + Eba = Ba ^ (~Be & Bi) ^ round_constants[round]; + Ebe = Be ^ (~Bi & Bo); + Ebi = Bi ^ (~Bo & Bu); + Ebo = Bo ^ (~Bu & Ba); + Ebu = Bu ^ (~Ba & Be); + + Ba = rol(Abo ^ Do, 28); + Be = rol(Agu ^ Du, 20); + Bi = rol(Aka ^ Da, 3); + Bo = rol(Ame ^ De, 45); + Bu = rol(Asi ^ Di, 61); + Ega = Ba ^ (~Be & Bi); + Ege = Be ^ (~Bi & Bo); + Egi = Bi ^ (~Bo & Bu); + Ego = Bo ^ (~Bu & Ba); + Egu = Bu ^ (~Ba & Be); + + Ba = rol(Abe ^ De, 1); + Be = rol(Agi ^ Di, 6); + Bi = rol(Ako ^ Do, 25); + Bo = rol(Amu ^ Du, 8); + Bu = 
rol(Asa ^ Da, 18); + Eka = Ba ^ (~Be & Bi); + Eke = Be ^ (~Bi & Bo); + Eki = Bi ^ (~Bo & Bu); + Eko = Bo ^ (~Bu & Ba); + Eku = Bu ^ (~Ba & Be); + + Ba = rol(Abu ^ Du, 27); + Be = rol(Aga ^ Da, 36); + Bi = rol(Ake ^ De, 10); + Bo = rol(Ami ^ Di, 15); + Bu = rol(Aso ^ Do, 56); + Ema = Ba ^ (~Be & Bi); + Eme = Be ^ (~Bi & Bo); + Emi = Bi ^ (~Bo & Bu); + Emo = Bo ^ (~Bu & Ba); + Emu = Bu ^ (~Ba & Be); + + Ba = rol(Abi ^ Di, 62); + Be = rol(Ago ^ Do, 55); + Bi = rol(Aku ^ Du, 39); + Bo = rol(Ama ^ Da, 41); + Bu = rol(Ase ^ De, 2); + Esa = Ba ^ (~Be & Bi); + Ese = Be ^ (~Bi & Bo); + Esi = Bi ^ (~Bo & Bu); + Eso = Bo ^ (~Bu & Ba); + Esu = Bu ^ (~Ba & Be); + + /* Round (round + 1): Exx -> Axx */ + + Ba = Eba ^ Ega ^ Eka ^ Ema ^ Esa; + Be = Ebe ^ Ege ^ Eke ^ Eme ^ Ese; + Bi = Ebi ^ Egi ^ Eki ^ Emi ^ Esi; + Bo = Ebo ^ Ego ^ Eko ^ Emo ^ Eso; + Bu = Ebu ^ Egu ^ Eku ^ Emu ^ Esu; + + Da = Bu ^ rol(Be, 1); + De = Ba ^ rol(Bi, 1); + Di = Be ^ rol(Bo, 1); + Do = Bi ^ rol(Bu, 1); + Du = Bo ^ rol(Ba, 1); + + Ba = Eba ^ Da; + Be = rol(Ege ^ De, 44); + Bi = rol(Eki ^ Di, 43); + Bo = rol(Emo ^ Do, 21); + Bu = rol(Esu ^ Du, 14); + Aba = Ba ^ (~Be & Bi) ^ round_constants[round + 1]; + Abe = Be ^ (~Bi & Bo); + Abi = Bi ^ (~Bo & Bu); + Abo = Bo ^ (~Bu & Ba); + Abu = Bu ^ (~Ba & Be); + + Ba = rol(Ebo ^ Do, 28); + Be = rol(Egu ^ Du, 20); + Bi = rol(Eka ^ Da, 3); + Bo = rol(Eme ^ De, 45); + Bu = rol(Esi ^ Di, 61); + Aga = Ba ^ (~Be & Bi); + Age = Be ^ (~Bi & Bo); + Agi = Bi ^ (~Bo & Bu); + Ago = Bo ^ (~Bu & Ba); + Agu = Bu ^ (~Ba & Be); + + Ba = rol(Ebe ^ De, 1); + Be = rol(Egi ^ Di, 6); + Bi = rol(Eko ^ Do, 25); + Bo = rol(Emu ^ Du, 8); + Bu = rol(Esa ^ Da, 18); + Aka = Ba ^ (~Be & Bi); + Ake = Be ^ (~Bi & Bo); + Aki = Bi ^ (~Bo & Bu); + Ako = Bo ^ (~Bu & Ba); + Aku = Bu ^ (~Ba & Be); + + Ba = rol(Ebu ^ Du, 27); + Be = rol(Ega ^ Da, 36); + Bi = rol(Eke ^ De, 10); + Bo = rol(Emi ^ Di, 15); + Bu = rol(Eso ^ Do, 56); + Ama = Ba ^ (~Be & Bi); + Ame = Be ^ (~Bi & Bo); + Ami = Bi ^ (~Bo & Bu); + Amo = Bo ^ (~Bu & Ba); + Amu = Bu ^ (~Ba & Be); + + Ba = rol(Ebi ^ Di, 62); + Be = rol(Ego ^ Do, 55); + Bi = rol(Eku ^ Du, 39); + Bo = rol(Ema ^ Da, 41); + Bu = rol(Ese ^ De, 2); + Asa = Ba ^ (~Be & Bi); + Ase = Be ^ (~Bi & Bo); + Asi = Bi ^ (~Bo & Bu); + Aso = Bo ^ (~Bu & Ba); + Asu = Bu ^ (~Ba & Be); + } + + state[0] = Aba; + state[1] = Abe; + state[2] = Abi; + state[3] = Abo; + state[4] = Abu; + state[5] = Aga; + state[6] = Age; + state[7] = Agi; + state[8] = Ago; + state[9] = Agu; + state[10] = Aka; + state[11] = Ake; + state[12] = Aki; + state[13] = Ako; + state[14] = Aku; + state[15] = Ama; + state[16] = Ame; + state[17] = Ami; + state[18] = Amo; + state[19] = Amu; + state[20] = Asa; + state[21] = Ase; + state[22] = Asi; + state[23] = Aso; + state[24] = Asu; +} diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/bn254.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/bn254.hpp new file mode 100644 index 0000000..37d4212 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/bn254.hpp @@ -0,0 +1,26 @@ +#pragma once +#include "../bn254/fq.hpp" +#include "../bn254/fq12.hpp" +#include "../bn254/fq2.hpp" +#include "../bn254/fr.hpp" +#include "../bn254/g1.hpp" +#include "../bn254/g2.hpp" + +namespace bb::curve { +class BN254 { + public: + using ScalarField = bb::fr; + using BaseField = bb::fq; + using Group = typename bb::g1; + using Element = typename Group::element; + using AffineElement = typename Group::affine_element; + using G2AffineElement = typename 
bb::g2::affine_element; + using G2BaseField = typename bb::fq2; + using TargetField = bb::fq12; + + // TODO(#673): This flag is temporary. It is needed in the verifier classes (GeminiVerifier, etc.) while these + // classes are instantiated with "native" curve types. Eventually, the verifier classes will be instantiated only + // with stdlib types, and "native" verification will be acheived via a simulated builder. + static constexpr bool is_stdlib_type = false; +}; +} // namespace bb::curve \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq.hpp new file mode 100644 index 0000000..dce9f26 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq.hpp @@ -0,0 +1,115 @@ +#pragma once + +#include +#include + +#include "../../fields/field.hpp" + +// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays) +namespace bb { +class Bn254FqParams { + public: + static constexpr uint64_t modulus_0 = 0x3C208C16D87CFD47UL; + static constexpr uint64_t modulus_1 = 0x97816a916871ca8dUL; + static constexpr uint64_t modulus_2 = 0xb85045b68181585dUL; + static constexpr uint64_t modulus_3 = 0x30644e72e131a029UL; + + static constexpr uint64_t r_squared_0 = 0xF32CFC5B538AFA89UL; + static constexpr uint64_t r_squared_1 = 0xB5E71911D44501FBUL; + static constexpr uint64_t r_squared_2 = 0x47AB1EFF0A417FF6UL; + static constexpr uint64_t r_squared_3 = 0x06D89F71CAB8351FUL; + + static constexpr uint64_t cube_root_0 = 0x71930c11d782e155UL; + static constexpr uint64_t cube_root_1 = 0xa6bb947cffbe3323UL; + static constexpr uint64_t cube_root_2 = 0xaa303344d4741444UL; + static constexpr uint64_t cube_root_3 = 0x2c3b3f0d26594943UL; + + static constexpr uint64_t modulus_wasm_0 = 0x187cfd47; + static constexpr uint64_t modulus_wasm_1 = 0x10460b6; + static constexpr uint64_t modulus_wasm_2 = 0x1c72a34f; + static constexpr uint64_t modulus_wasm_3 = 0x2d522d0; + static constexpr uint64_t modulus_wasm_4 = 0x1585d978; + static constexpr uint64_t modulus_wasm_5 = 0x2db40c0; + static constexpr uint64_t modulus_wasm_6 = 0xa6e141; + static constexpr uint64_t modulus_wasm_7 = 0xe5c2634; + static constexpr uint64_t modulus_wasm_8 = 0x30644e; + + static constexpr uint64_t r_squared_wasm_0 = 0xe1a2a074659bac10UL; + static constexpr uint64_t r_squared_wasm_1 = 0x639855865406005aUL; + static constexpr uint64_t r_squared_wasm_2 = 0xff54c5802d3e2632UL; + static constexpr uint64_t r_squared_wasm_3 = 0x2a11a68c34ea65a6UL; + + static constexpr uint64_t cube_root_wasm_0 = 0x62b1a3a46a337995UL; + static constexpr uint64_t cube_root_wasm_1 = 0xadc97d2722e2726eUL; + static constexpr uint64_t cube_root_wasm_2 = 0x64ee82ede2db85faUL; + static constexpr uint64_t cube_root_wasm_3 = 0x0c0afea1488a03bbUL; + + static constexpr uint64_t primitive_root_0 = 0UL; + static constexpr uint64_t primitive_root_1 = 0UL; + static constexpr uint64_t primitive_root_2 = 0UL; + static constexpr uint64_t primitive_root_3 = 0UL; + + static constexpr uint64_t primitive_root_wasm_0 = 0x0000000000000000UL; + static constexpr uint64_t primitive_root_wasm_1 = 0x0000000000000000UL; + static constexpr uint64_t primitive_root_wasm_2 = 0x0000000000000000UL; + static constexpr uint64_t primitive_root_wasm_3 = 0x0000000000000000UL; + + static constexpr uint64_t endo_g1_lo = 0x7a7bd9d4391eb18d; + static constexpr uint64_t endo_g1_mid = 0x4ccef014a773d2cfUL; + static constexpr uint64_t endo_g1_hi = 0x0000000000000002UL; + static constexpr uint64_t endo_g2_lo = 
0xd91d232ec7e0b3d2UL; + static constexpr uint64_t endo_g2_mid = 0x0000000000000002UL; + static constexpr uint64_t endo_minus_b1_lo = 0x8211bbeb7d4f1129UL; + static constexpr uint64_t endo_minus_b1_mid = 0x6f4d8248eeb859fcUL; + static constexpr uint64_t endo_b2_lo = 0x89d3256894d213e2UL; + static constexpr uint64_t endo_b2_mid = 0UL; + + static constexpr uint64_t r_inv = 0x87d20782e4866389UL; + + static constexpr uint64_t coset_generators_0[8]{ + 0x7a17caa950ad28d7ULL, 0x4d750e37163c3674ULL, 0x20d251c4dbcb4411ULL, 0xf42f9552a15a51aeULL, + 0x4f4bc0b2b5ef64bdULL, 0x22a904407b7e725aULL, 0xf60647ce410d7ff7ULL, 0xc9638b5c069c8d94ULL, + }; + static constexpr uint64_t coset_generators_1[8]{ + 0x1f6ac17ae15521b9ULL, 0x29e3aca3d71c2cf7ULL, 0x345c97cccce33835ULL, 0x3ed582f5c2aa4372ULL, + 0x1a4b98fbe78db996ULL, 0x24c48424dd54c4d4ULL, 0x2f3d6f4dd31bd011ULL, 0x39b65a76c8e2db4fULL, + }; + static constexpr uint64_t coset_generators_2[8]{ + 0x334bea4e696bd284ULL, 0x99ba8dbde1e518b0ULL, 0x29312d5a5e5edcULL, 0x6697d49cd2d7a508ULL, + 0x5c65ec9f484e3a79ULL, 0xc2d4900ec0c780a5ULL, 0x2943337e3940c6d1ULL, 0x8fb1d6edb1ba0cfdULL, + }; + static constexpr uint64_t coset_generators_3[8]{ + 0x2a1f6744ce179d8eULL, 0x3829df06681f7cbdULL, 0x463456c802275bedULL, 0x543ece899c2f3b1cULL, + 0x180a96573d3d9f8ULL, 0xf8b21270ddbb927ULL, 0x1d9598e8a7e39857ULL, 0x2ba010aa41eb7786ULL, + }; + + static constexpr uint64_t coset_generators_wasm_0[8] = { 0xeb8a8ec140766463ULL, 0xfded87957d76333dULL, + 0x4c710c8092f2ff5eULL, 0x9af4916ba86fcb7fULL, + 0xe9781656bdec97a0ULL, 0xfbdb0f2afaec667aULL, + 0x4a5e94161069329bULL, 0x98e2190125e5febcULL }; + static constexpr uint64_t coset_generators_wasm_1[8] = { 0xf2b1f20626a3da49ULL, 0x56c12d76cb13587fULL, + 0x5251d378d7f4a143ULL, 0x4de2797ae4d5ea06ULL, + 0x49731f7cf1b732c9ULL, 0xad825aed9626b0ffULL, + 0xa91300efa307f9c3ULL, 0xa4a3a6f1afe94286ULL }; + static constexpr uint64_t coset_generators_wasm_2[8] = { 0xf905ef8d84d5fea4ULL, 0x93b7a45b84f1507eULL, + 0xe6b99ee0068dfab5ULL, 0x39bb9964882aa4ecULL, + 0x8cbd93e909c74f23ULL, 0x276f48b709e2a0fcULL, + 0x7a71433b8b7f4b33ULL, 0xcd733dc00d1bf56aULL }; + static constexpr uint64_t coset_generators_wasm_3[8] = { 0x2958a27c02b7cd5fULL, 0x06bc8a3277c371abULL, + 0x1484c05bce00b620ULL, 0x224cf685243dfa96ULL, + 0x30152cae7a7b3f0bULL, 0x0d791464ef86e357ULL, + 0x1b414a8e45c427ccULL, 0x290980b79c016c41ULL }; + + // used in msgpack schema serialization + static constexpr char schema_name[] = "fq"; + static constexpr bool has_high_2adicity = false; + + // The modulus is larger than BN254 scalar field modulus, so it maps to two BN254 scalars + static constexpr size_t NUM_BN254_SCALARS = 2; +}; + +using fq = field; + +} // namespace bb + +// NOLINTEND(cppcoreguidelines-avoid-c-arrays) diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq2.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq2.hpp new file mode 100644 index 0000000..fce8cc5 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fq2.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include "../../fields/field2.hpp" +#include "./fq.hpp" + +namespace bb { +struct Bn254Fq2Params { +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + static constexpr fq twist_coeff_b_0{ + 0x3bf938e377b802a8UL, 0x020b1b273633535dUL, 0x26b7edf049755260UL, 0x2514c6324384a86dUL + }; + static constexpr fq twist_coeff_b_1{ + 0x38e7ecccd1dcff67UL, 0x65f0b37d93ce0d3eUL, 0xd749d0dd22ac00aaUL, 0x0141b9ce4a688d4dUL + }; + static constexpr fq twist_mul_by_q_x_0{ + 0xb5773b104563ab30UL, 
0x347f91c8a9aa6454UL, 0x7a007127242e0991UL, 0x1956bcd8118214ecUL + }; + static constexpr fq twist_mul_by_q_x_1{ + 0x6e849f1ea0aa4757UL, 0xaa1c7b6d89f89141UL, 0xb6e713cdfae0ca3aUL, 0x26694fbb4e82ebc3UL + }; + static constexpr fq twist_mul_by_q_y_0{ + 0xe4bbdd0c2936b629UL, 0xbb30f162e133bacbUL, 0x31a9d1b6f9645366UL, 0x253570bea500f8ddUL + }; + static constexpr fq twist_mul_by_q_y_1{ + 0xa1d77ce45ffe77c7UL, 0x07affd117826d1dbUL, 0x6d16bd27bb7edc6bUL, 0x2c87200285defeccUL + }; + static constexpr fq twist_cube_root_0{ + 0x505ecc6f0dff1ac2UL, 0x2071416db35ec465UL, 0xf2b53469fa43ea78UL, 0x18545552044c99aaUL + }; + static constexpr fq twist_cube_root_1{ + 0xad607f911cfe17a8UL, 0xb6bb78aa154154c4UL, 0xb53dd351736b20dbUL, 0x1d8ed57c5cc33d41UL + }; +#else + static constexpr fq twist_coeff_b_0{ + 0xdc19fa4aab489658UL, 0xd416744fbbf6e69UL, 0x8f7734ed0a8a033aUL, 0x19316b8353ee09bbUL + }; + static constexpr fq twist_coeff_b_1{ + 0x1cfd999a3b9fece0UL, 0xbe166fb279c1a7c7UL, 0xe93a1ba45580154cUL, 0x283739c94d11a9baUL + }; + static constexpr fq twist_mul_by_q_x_0{ + 0xecdea09b24a59190UL, 0x17db8ffeae2fe1c2UL, 0xbb09c97c6dabac4dUL, 0x2492b3d41d289af3UL + }; + static constexpr fq twist_mul_by_q_x_1{ + 0xf1663598f1142ef1UL, 0x77ec057e0bf56062UL, 0xdd0baaecb677a631UL, 0x135e4e31d284d463UL + }; + static constexpr fq twist_mul_by_q_y_0{ + 0xf46e7f60db1f0678UL, 0x31fc2eba5bcc5c3eUL, 0xedb3adc3086a2411UL, 0x1d46bd0f837817bcUL + }; + static constexpr fq twist_mul_by_q_y_1{ + 0x6b3fbdf579a647d5UL, 0xcc568fb62ff64974UL, 0xc1bfbf4ac4348ac6UL, 0x15871d4d3940b4d3UL + }; + static constexpr fq twist_cube_root_0{ + 0x49d0cc74381383d0UL, 0x9611849fe4bbe3d6UL, 0xd1a231d73067c92aUL, 0x445c312767932c2UL + }; + static constexpr fq twist_cube_root_1{ + 0x35a58c718e7c28bbUL, 0x98d42c77e7b8901aUL, 0xf9c53da2d0ca8c84UL, 0x1a68dd04e1b8c51dUL + }; +#endif +}; + +using fq2 = field2; +} // namespace bb \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fr.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fr.hpp new file mode 100644 index 0000000..fcf2bc1 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/fr.hpp @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include + +#include "../../fields/field.hpp" + +// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays) + +namespace bb { +class Bn254FrParams { + public: + // Note: limbs here are combined as concat(_3, _2, _1, _0) + // E.g. 
this modulus forms the value: + // 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001 + // = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + static constexpr uint64_t modulus_0 = 0x43E1F593F0000001UL; + static constexpr uint64_t modulus_1 = 0x2833E84879B97091UL; + static constexpr uint64_t modulus_2 = 0xB85045B68181585DUL; + static constexpr uint64_t modulus_3 = 0x30644E72E131A029UL; + + static constexpr uint64_t r_squared_0 = 0x1BB8E645AE216DA7UL; + static constexpr uint64_t r_squared_1 = 0x53FE3AB1E35C59E3UL; + static constexpr uint64_t r_squared_2 = 0x8C49833D53BB8085UL; + static constexpr uint64_t r_squared_3 = 0x216D0B17F4E44A5UL; + + static constexpr uint64_t cube_root_0 = 0x93e7cede4a0329b3UL; + static constexpr uint64_t cube_root_1 = 0x7d4fdca77a96c167UL; + static constexpr uint64_t cube_root_2 = 0x8be4ba08b19a750aUL; + static constexpr uint64_t cube_root_3 = 0x1cbd5653a5661c25UL; + + static constexpr uint64_t primitive_root_0 = 0x636e735580d13d9cUL; + static constexpr uint64_t primitive_root_1 = 0xa22bf3742445ffd6UL; + static constexpr uint64_t primitive_root_2 = 0x56452ac01eb203d8UL; + static constexpr uint64_t primitive_root_3 = 0x1860ef942963f9e7UL; + + static constexpr uint64_t endo_g1_lo = 0x7a7bd9d4391eb18dUL; + static constexpr uint64_t endo_g1_mid = 0x4ccef014a773d2cfUL; + static constexpr uint64_t endo_g1_hi = 0x0000000000000002UL; + static constexpr uint64_t endo_g2_lo = 0xd91d232ec7e0b3d7UL; + static constexpr uint64_t endo_g2_mid = 0x0000000000000002UL; + static constexpr uint64_t endo_minus_b1_lo = 0x8211bbeb7d4f1128UL; + static constexpr uint64_t endo_minus_b1_mid = 0x6f4d8248eeb859fcUL; + static constexpr uint64_t endo_b2_lo = 0x89d3256894d213e3UL; + static constexpr uint64_t endo_b2_mid = 0UL; + + static constexpr uint64_t r_inv = 0xc2e1f593efffffffUL; + + static constexpr uint64_t coset_generators_0[8]{ + 0x5eef048d8fffffe7ULL, 0xb8538a9dfffffe2ULL, 0x3057819e4fffffdbULL, 0xdcedb5ba9fffffd6ULL, + 0x8983e9d6efffffd1ULL, 0x361a1df33fffffccULL, 0xe2b0520f8fffffc7ULL, 0x8f46862bdfffffc2ULL, + }; + static constexpr uint64_t coset_generators_1[8]{ + 0x12ee50ec1ce401d0ULL, 0x49eac781bc44cefaULL, 0x307f6d866832bb01ULL, 0x677be41c0793882aULL, + 0x9e785ab1a6f45554ULL, 0xd574d1474655227eULL, 0xc7147dce5b5efa7ULL, 0x436dbe728516bcd1ULL, + }; + static constexpr uint64_t coset_generators_2[8]{ + 0x29312d5a5e5ee7ULL, 0x6697d49cd2d7a515ULL, 0x5c65ec9f484e3a89ULL, 0xc2d4900ec0c780b7ULL, + 0x2943337e3940c6e5ULL, 0x8fb1d6edb1ba0d13ULL, 0xf6207a5d2a335342ULL, 0x5c8f1dcca2ac9970ULL, + }; + static constexpr uint64_t coset_generators_3[8]{ + 0x463456c802275bedULL, 0x543ece899c2f3b1cULL, 0x180a96573d3d9f8ULL, 0xf8b21270ddbb927ULL, + 0x1d9598e8a7e39857ULL, 0x2ba010aa41eb7786ULL, 0x39aa886bdbf356b5ULL, 0x47b5002d75fb35e5ULL, + }; + + static constexpr uint64_t modulus_wasm_0 = 0x10000001; + static constexpr uint64_t modulus_wasm_1 = 0x1f0fac9f; + static constexpr uint64_t modulus_wasm_2 = 0xe5c2450; + static constexpr uint64_t modulus_wasm_3 = 0x7d090f3; + static constexpr uint64_t modulus_wasm_4 = 0x1585d283; + static constexpr uint64_t modulus_wasm_5 = 0x2db40c0; + static constexpr uint64_t modulus_wasm_6 = 0xa6e141; + static constexpr uint64_t modulus_wasm_7 = 0xe5c2634; + static constexpr uint64_t modulus_wasm_8 = 0x30644e; + + static constexpr uint64_t r_squared_wasm_0 = 0x38c2e14b45b69bd4UL; + static constexpr uint64_t r_squared_wasm_1 = 0x0ffedb1885883377UL; + static constexpr uint64_t r_squared_wasm_2 = 0x7840f9f0abc6e54dUL; + static 
constexpr uint64_t r_squared_wasm_3 = 0x0a054a3e848b0f05UL; + + static constexpr uint64_t cube_root_wasm_0 = 0x7334a1ce7065364dUL; + static constexpr uint64_t cube_root_wasm_1 = 0xae21578e4a14d22aUL; + static constexpr uint64_t cube_root_wasm_2 = 0xcea2148a96b51265UL; + static constexpr uint64_t cube_root_wasm_3 = 0x0038f7edf614a198UL; + + static constexpr uint64_t primitive_root_wasm_0 = 0x2faf11711a27b370UL; + static constexpr uint64_t primitive_root_wasm_1 = 0xc23fe9fced28f1b8UL; + static constexpr uint64_t primitive_root_wasm_2 = 0x43a0fc9bbe2af541UL; + static constexpr uint64_t primitive_root_wasm_3 = 0x05d90b5719653a4fUL; + + static constexpr uint64_t coset_generators_wasm_0[8] = { 0xab46711cdffffcb2ULL, 0xdb1b52736ffffc09ULL, + 0x0af033c9fffffb60ULL, 0xf6e31f8c9ffffab6ULL, + 0x26b800e32ffffa0dULL, 0x568ce239bffff964ULL, + 0x427fcdfc5ffff8baULL, 0x7254af52effff811ULL }; + static constexpr uint64_t coset_generators_wasm_1[8] = { 0x2476607dbd2dfff1ULL, 0x9a3208a561c2b00bULL, + 0x0fedb0cd06576026ULL, 0x5d7570ac31329faeULL, + 0xd33118d3d5c74fc9ULL, 0x48ecc0fb7a5bffe3ULL, + 0x967480daa5373f6cULL, 0x0c30290249cbef86ULL }; + static constexpr uint64_t coset_generators_wasm_2[8] = { 0xe6b99ee0068dfc25ULL, 0x39bb9964882aa6a5ULL, + 0x8cbd93e909c75126ULL, 0x276f48b709e2a349ULL, + 0x7a71433b8b7f4dc9ULL, 0xcd733dc00d1bf84aULL, + 0x6824f28e0d374a6dULL, 0xbb26ed128ed3f4eeULL }; + static constexpr uint64_t coset_generators_wasm_3[8] = { 0x1484c05bce00b620ULL, 0x224cf685243dfa96ULL, + 0x30152cae7a7b3f0bULL, 0x0d791464ef86e357ULL, + 0x1b414a8e45c427ccULL, 0x290980b79c016c41ULL, + 0x066d686e110d108dULL, 0x14359e97674a5502ULL }; + + // used in msgpack schema serialization + static constexpr char schema_name[] = "fr"; + static constexpr bool has_high_2adicity = true; + + // This is a BN254 scalar, so it represents one BN254 scalar + static constexpr size_t NUM_BN254_SCALARS = 1; +}; + +using fr = field; + +} // namespace bb + +// NOLINTEND(cppcoreguidelines-avoid-c-arrays) diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/g1.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/g1.hpp new file mode 100644 index 0000000..3e4e8a1 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/g1.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "../../groups/group.hpp" +#include "./fq.hpp" +#include "./fr.hpp" + +namespace bb { +struct Bn254G1Params { + static constexpr bool USE_ENDOMORPHISM = true; + static constexpr bool can_hash_to_curve = true; + static constexpr bool small_elements = true; + static constexpr bool has_a = false; + static constexpr fq one_x = fq::one(); +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + static constexpr fq one_y{ 0xa6ba871b8b1e1b3aUL, 0x14f1d651eb8e167bUL, 0xccdd46def0f28c58UL, 0x1c14ef83340fbe5eUL }; +#else + static constexpr fq one_y{ 0x9d0709d62af99842UL, 0xf7214c0419c29186UL, 0xa603f5090339546dUL, 0x1b906c52ac7a88eaUL }; +#endif + static constexpr fq a{ 0UL, 0UL, 0UL, 0UL }; +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + static constexpr fq b{ 0x7a17caa950ad28d7UL, 0x1f6ac17ae15521b9UL, 0x334bea4e696bd284UL, 0x2a1f6744ce179d8eUL }; +#else + static constexpr fq b{ 0xeb8a8ec140766463UL, 0xf2b1f20626a3da49UL, 0xf905ef8d84d5fea4UL, 0x2958a27c02b7cd5fUL }; +#endif +}; + +using g1 = group; + +} // namespace bb + diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/g2.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/g2.hpp new file mode 100644 index 0000000..78ede0c --- /dev/null +++ 
b/sumcheck/src/cuda/includes/barretenberg/ecc/curves/bn254/g2.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "../../groups/group.hpp" +#include "./fq2.hpp" +#include "./fr.hpp" + +namespace bb { +struct Bn254G2Params { + static constexpr bool USE_ENDOMORPHISM = false; + static constexpr bool can_hash_to_curve = false; + static constexpr bool small_elements = false; + static constexpr bool has_a = false; + +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + static constexpr fq2 one_x{ { 0x8e83b5d102bc2026, 0xdceb1935497b0172, 0xfbb8264797811adf, 0x19573841af96503b }, + { 0xafb4737da84c6140, 0x6043dd5a5802d8c4, 0x09e950fc52a02f86, 0x14fef0833aea7b6b } }; + static constexpr fq2 one_y{ { 0x619dfa9d886be9f6, 0xfe7fd297f59e9b78, 0xff9e1a62231b7dfe, 0x28fd7eebae9e4206 }, + { 0x64095b56c71856ee, 0xdc57f922327d3cbb, 0x55f935be33351076, 0x0da4a0e693fd6482 } }; +#else + static constexpr fq2 one_x{ + { 0xe6df8b2cfb43050UL, 0x254c7d92a843857eUL, 0xf2006d8ad80dd622UL, 0x24a22107dfb004e3UL }, + { 0xe8e7528c0b334b65UL, 0x56e941e8b293cf69UL, 0xe1169545c074740bUL, 0x2ac61491edca4b42UL } + }; + static constexpr fq2 one_y{ + { 0xdc508d48384e8843UL, 0xd55415a8afd31226UL, 0x834bf204bacb6e00UL, 0x51b9758138c5c79UL }, + { 0x64067e0b46a5f641UL, 0x37726529a3a77875UL, 0x4454445bd915f391UL, 0x10d5ac894edeed3UL } + }; +#endif + static constexpr fq2 a = fq2::zero(); + static constexpr fq2 b = fq2::twist_coeff_b(); +}; + +using g2 = group; +} // namespace bb \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/fields/asm_macros.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/asm_macros.hpp new file mode 100644 index 0000000..bb657b0 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/asm_macros.hpp @@ -0,0 +1,990 @@ +#pragma once +// clang-format off + +/* + * Clear all flags via xorq opcode + **/ +#define CLEAR_FLAGS(empty_reg) \ + "xorq " empty_reg ", " empty_reg " \n\t" + +/** + * Load 4-limb field element, pointed to by a, into + * registers (lolo, lohi, hilo, hihi) + **/ +#define LOAD_FIELD_ELEMENT(a, lolo, lohi, hilo, hihi) \ + "movq 0(" a "), " lolo " \n\t" \ + "movq 8(" a "), " lohi " \n\t" \ + "movq 16(" a "), " hilo " \n\t" \ + "movq 24(" a "), " hihi " \n\t" + +/** + * Store 4-limb field element located in + * registers (lolo, lohi, hilo, hihi), into + * memory pointed to by r + **/ +#define STORE_FIELD_ELEMENT(r, lolo, lohi, hilo, hihi) \ + "movq " lolo ", 0(" r ") \n\t" \ + "movq " lohi ", 8(" r ") \n\t" \ + "movq " hilo ", 16(" r ") \n\t" \ + "movq " hihi ", 24(" r ") \n\t" + +#if !defined(__ADX__) || defined(DISABLE_ADX) +/** + * Take a 4-limb field element, in (%r12, %r13, %r14, %r15), + * and add 4-limb field element pointed to by a + **/ +#define ADD(b) \ + "addq 0(" b "), %%r12 \n\t" \ + "adcq 8(" b "), %%r13 \n\t" \ + "adcq 16(" b "), %%r14 \n\t" \ + "adcq 24(" b "), %%r15 \n\t" + +/** + * Take a 4-limb field element, in (%r12, %r13, %r14, %r15), + * and subtract 4-limb field element pointed to by b + **/ +#define SUB(b) \ + "subq 0(" b "), %%r12 \n\t" \ + "sbbq 8(" b "), %%r13 \n\t" \ + "sbbq 16(" b "), %%r14 \n\t" \ + "sbbq 24(" b "), %%r15 \n\t" + + +/** + * Take a 4-limb field element, in (%r12, %r13, %r14, %r15), + * add 4-limb field element pointed to by b, and reduce modulo p + **/ +#define ADD_REDUCE(b, modulus_0, modulus_1, modulus_2, modulus_3) \ + "addq 0(" b "), %%r12 \n\t" \ + "adcq 8(" b "), %%r13 \n\t" \ + "adcq 16(" b "), %%r14 \n\t" \ + "adcq 24(" b "), %%r15 \n\t" \ + "movq %%r12, %%r8 \n\t" \ + "movq %%r13, %%r9 \n\t" 
\ + "movq %%r14, %%r10 \n\t" \ + "movq %%r15, %%r11 \n\t" \ + "addq " modulus_0 ", %%r12 \n\t" \ + "adcq " modulus_1 ", %%r13 \n\t" \ + "adcq " modulus_2 ", %%r14 \n\t" \ + "adcq " modulus_3 ", %%r15 \n\t" \ + "cmovncq %%r8, %%r12 \n\t" \ + "cmovncq %%r9, %%r13 \n\t" \ + "cmovncq %%r10, %%r14 \n\t" \ + "cmovncq %%r11, %%r15 \n\t" + + + +/** + * Take a 4-limb integer, r, in (%r12, %r13, %r14, %r15) + * and conditionally subtract modulus, if r > p. + **/ +#define REDUCE_FIELD_ELEMENT(neg_modulus_0, neg_modulus_1, neg_modulus_2, neg_modulus_3) \ + /* Duplicate `r` */ \ + "movq %%r12, %%r8 \n\t" \ + "movq %%r13, %%r9 \n\t" \ + "movq %%r14, %%r10 \n\t" \ + "movq %%r15, %%r11 \n\t" \ + "addq " neg_modulus_0 ", %%r12 \n\t" /* r'[0] -= modulus.data[0] */ \ + "adcq " neg_modulus_1 ", %%r13 \n\t" /* r'[1] -= modulus.data[1] */ \ + "adcq " neg_modulus_2 ", %%r14 \n\t" /* r'[2] -= modulus.data[2] */ \ + "adcq " neg_modulus_3 ", %%r15 \n\t" /* r'[3] -= modulus.data[3] */ \ + \ + /* if r does not need to be reduced, overflow flag is 1 */ \ + /* set r' = r if this flag is set */ \ + "cmovncq %%r8, %%r12 \n\t" \ + "cmovncq %%r9, %%r13 \n\t" \ + "cmovncq %%r10, %%r14 \n\t" \ + "cmovncq %%r11, %%r15 \n\t" + +/** + * Compute Montgomery squaring of a + * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define SQR(a) \ + "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \ + \ + "xorq %%r8, %%r8 \n\t" /* clear flags */ \ + /* compute a[0] *a[1], a[0]*a[2], a[0]*a[3], a[1]*a[2], a[1]*a[3], a[2]*a[3] */ \ + "mulxq 8(" a "), %%r9, %%r10 \n\t" /* (r[1], r[2]) <- a[0] * a[1] */ \ + "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[1], t[2]) <- a[0] * a[2] */ \ + "mulxq 24(" a "), %%r11, %%r12 \n\t" /* (r[3], r[4]) <- a[0] * a[3] */ \ + \ + \ + /* accumulate products into result registers */ \ + "addq %%r8, %%r10 \n\t" /* r[2] += t[1] */ \ + "adcq %%r15, %%r11 \n\t" /* r[3] += t[2] */ \ + "movq 8(" a "), %%rdx \n\t" /* load a[1] into %r%dx */ \ + "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[5], t[6]) <- a[1] * a[2] */ \ + "mulxq 24(" a "), %%rdi, %%rcx \n\t" /* (t[3], t[4]) <- a[1] * a[3] */ \ + "movq 24(" a "), %%rdx \n\t" /* load a[3] into %%rdx */ \ + "mulxq 16(" a "), %%r13, %%r14 \n\t" /* (r[5], r[6]) <- a[3] * a[2] */ \ + "adcq %%rdi, %%r12 \n\t" /* r[4] += t[3] */ \ + "adcq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "addq %%r8, %%r11 \n\t" /* r[3] += t[5] */ \ + "adcq %%r15, %%r12 \n\t" /* r[4] += t[6] */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + \ + /* double result registers */ \ + "addq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \ + "adcq %%r10, %%r10 \n\t" /* r[2] = 2r[2] */ \ + "adcq %%r11, %%r11 \n\t" /* r[3] = 2r[3] */ \ + "adcq %%r12, %%r12 \n\t" /* r[4] = 2r[4] */ \ + "adcq %%r13, %%r13 \n\t" /* r[5] = 2r[5] */ \ + "adcq %%r14, %%r14 \n\t" /* r[6] = 2r[6] */ \ + \ + /* compute a[3]*a[3], a[2]*a[2], a[1]*a[1], a[0]*a[0] */ \ + "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \ + "mulxq %%rdx, %%r8, %%rcx \n\t" /* (r[0], t[4]) <- a[0] * a[0] */ \ + "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq %%rdx, %%rdx, %%rdi \n\t" /* (t[7], t[8]) <- a[2] * a[2] */ \ + /* add squares into result registers */ \ + "addq %%rdx, %%r12 \n\t" /* r[4] += t[7] */ \ + "adcq %%rdi, %%r13 \n\t" /* r[5] += t[8] */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "addq %%rcx, %%r9 \n\t" /* r[1] += t[4] */ \ + "movq 24(" a "), %%rdx \n\t" /* r[2] += flag_c */ \ + "mulxq %%rdx, %%rcx, %%r15 \n\t" /* (t[5], 
r[7]) <- a[3] * a[3] */ \ + "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq %%rdx, %%rdi, %%rdx \n\t" /* (t[3], t[6]) <- a[1] * a[1] */ \ + "adcq %%rdi, %%r10 \n\t" /* r[2] += t[3] */ \ + "adcq %%rdx, %%r11 \n\t" /* r[3] += t[6] */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + "addq %%rcx, %%r14 \n\t" /* r[6] += t[5] */ \ + "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \ + \ + /* perform modular reduction: r[0] */ \ + "movq %%r8, %%rdx \n\t" /* move r8 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \ + "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "addq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \ + "adcq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_c */ \ + "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "adcq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_c */ \ + "adcq $0, %%r11 \n\t" /* r[4] += flag_c */ \ + /* Partial fix "adcq $0, %%r12 \n\t"*/ /* r[4] += flag_c */ \ + "addq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \ + "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "mulxq %[modulus_3], %%r8, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcq %%rdi, %%r10 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \ + "adcq %%rdx, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + "addq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + \ + /* perform modular reduction: r[1] */ \ + "movq %%r9, %%rdx \n\t" /* move r9 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \ + "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "addq %%rdi, %%r9 \n\t" /* r[1] += t[0] (%r8 now free) */ \ + "adcq %%rcx, %%r10 \n\t" /* r[2] += t[1] + flag_c */ \ + "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "adcq %%rcx, %%r11 \n\t" /* r[3] += t[3] + flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + "addq %%rdi, %%r10 \n\t" /* r[2] += t[2] */ \ + "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "mulxq %[modulus_3], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "addq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + \ + /* perform modular reduction: r[2] */ \ + "movq %%r10, %%rdx \n\t" /* move r10 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \ + "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "addq %%rdi, %%r10 \n\t" /* r[2] += t[0] (%r8 now free) */ \ + "adcq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \ + "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "mulxq %[modulus_3], %%r10, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcq %%rcx, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ + "adcq %%r9, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \ + "adcq %%rdx, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \ + "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \ + "addq %%rdi, %%r11 \n\t" /* r[3] += t[2] */ \ + "adcq %%r8, 
%%r12 \n\t" /* r[4] += t[0] + flag_c */ \ + "adcq %%r10, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + \ + /* perform modular reduction: r[3] */ \ + "movq %%r11, %%rdx \n\t" /* move r11 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \ + "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "mulxq %[modulus_1], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "addq %%rdi, %%r11 \n\t" /* r[3] += t[0] (%r11 now free) */ \ + "adcq %%r8, %%r12 \n\t" /* r[4] += t[2] */ \ + "adcq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "mulxq %[modulus_3], %%r10, %%r11 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcq %%r9, %%r14 \n\t" /* r[6] += t[1] + flag_c */ \ + "adcq %%r11, %%r15 \n\t" /* r[7] += t[3] + flag_c */ \ + "addq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcq %%r8, %%r13 \n\t" /* r[5] += t[0] + flag_c */ \ + "adcq %%r10, %%r14 \n\t" /* r[6] += t[2] + flag_c */ \ + "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ + + +/** + * Compute Montgomery multiplication of a, b. + * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define MUL(a1, a2, a3, a4, b) \ + "movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \ + "xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \ + /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \ + "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \ + "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \ + "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \ + "mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \ + /* zero flags */ \ + \ + /* start computing modular reduction */ \ + "movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + \ + /* start first addition chain */ \ + "addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \ + "adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adcq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + \ + /* reduce by r[0] * k */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "addq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \ + "adcq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \ + "adcq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adcq $0, %%r10 \n\t" /* r[3] += flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + "addq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_c */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "adcq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \ + "adcq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ + "addq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_i */ \ + \ + /* modulus = 254 bits, so max(t[3]) = 62 bits */ \ + /* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \ + /* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \ + /* (which is very convenient because we're out of registers!) 
*/ \ + /* N.B. the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \ + \ + /* a[1] * b */ \ + "movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \ + "addq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \ + "adcq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + "addq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \ + "adcq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + "addq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + \ + /* reduce by r[1] * k */ \ + "movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "addq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \ + "adcq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + "adcq $0, %%r12 \n\t" /* r[4] += flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + "addq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "adcq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + \ + /* a[2] * b */ \ + "movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \ + "addq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + "adcq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + "addq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \ + "adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \ + "adcq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "addq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + \ + /* reduce by r[2] * k */ \ + "movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "addq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \ + "adcq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcq %%r11, %%r12 \n\t" /* 
r[4] += t[1] + flag_c */ \ + "adcq $0, %%r13 \n\t" /* r[5] += flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "addq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "adcq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \ + "adcq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ + "adcq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \ + "addq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + \ + /* a[3] * b */ \ + "movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \ + "addq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \ + "adcq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \ + "adcq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \ + "adcq $0, %%r15 \n\t" /* r[7] += + flag_c */ \ + "addq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \ + "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \ + \ + /* reduce by r[3] * k */ \ + "movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \ + "addq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \ + "adcq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "adcq $0, %%r14 \n\t" /* r[6] += flag_c */ \ + "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ \ + "addq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \ + "adcq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \ + "adcq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \ + "adcq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \ + "addq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \ + "adcq $0, %%r15 \n\t" /* r[7] += flag_c */ + + +/** + * Compute 256-bit multiplication of a, b. + * Result is stored, r. 
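+ * (i.e. the low 256 bits of the full 512-bit product, a*b mod 2^256: unlike MUL there is no Montgomery reduction, and the four result limbs are written straight to memory at r). + * A portable reference for what this macro computes, as an illustrative sketch only (assumes a compiler with unsigned __int128; `out` stands in for the memory at r): + * + * uint64_t out[4] = { 0, 0, 0, 0 }; + * for (int i = 0; i < 4; ++i) { + * uint64_t carry = 0; + * for (int j = 0; i + j < 4; ++j) { + * unsigned __int128 t = (unsigned __int128)a[i] * b[j] + out[i + j] + carry; + * out[i + j] = (uint64_t)t; + * carry = (uint64_t)(t >> 64); + * } + * } + *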
// in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define MUL_256(a, b, r) \ + "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \ + \ + /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \ + "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \ + "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \ + "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \ + "mulxq 16(" b "), %%r15, %%rax \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \ + /* zero flags */ \ + "xorq %%r10, %%r10 \n\t" /* clear r10 register, we use this when we need 0 */ \ + \ + \ + /* start first addition chain */ \ + "addq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \ + "adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adcq %%r10, %%rax \n\t" /* r[3] += flag_c */ \ + "addq %%rdi, %%rax \n\t" /* r[3] += t[2] + flag_c */ \ + \ + /* a[1] * b */ \ + "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \ + "addq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \ + "adcq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adcq %%rsi, %%rax \n\t" /* r[3] += t[1] + flag_c */ \ + "addq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \ + "adcq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \ + \ + /* a[2] * b */ \ + "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \ + "addq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcq %%r9, %%rax \n\t" /* r[3] += t[1] + flag_c */ \ + "addq %%rdi, %%rax \n\t" /* r[3] += t[0] + flag_c */ \ + \ + \ + /* a[3] * b */ \ + "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \ + "adcq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \ + "movq %%r13, 0(" r ") \n\t" \ + "movq %%r14, 8(" r ") \n\t" \ + "movq %%r15, 16(" r ") \n\t" \ + "movq %%rax, 24(" r ") \n\t" + + +#else // 6047895us +/** + * Take a 4-limb field element, in (%r12, %r13, %r14, %r15), + * and add 4-limb field element pointed to by a + **/ +#define ADD(b) \ + "adcxq 0(" b "), %%r12 \n\t" \ + "adcxq 8(" b "), %%r13 \n\t" \ + "adcxq 16(" b "), %%r14 \n\t" \ + "adcxq 24(" b "), %%r15 \n\t" + +/** + * Take a 4-limb field element, in (%r12, %r13, %r14, %r15), + * and subtract 4-limb field element pointed to by b + **/ +#define SUB(b) \ + "subq 0(" b "), %%r12 \n\t" \ + "sbbq 8(" b "), %%r13 \n\t" \ + "sbbq 16(" b "), %%r14 \n\t" \ + "sbbq 24(" b "), %%r15 \n\t" + +/** + * Take a 4-limb field element, in (%r12, %r13, %r14, %r15), + * add 4-limb field element pointed to by b, and reduce modulo p + **/ +#define ADD_REDUCE(b, modulus_0, modulus_1, modulus_2, modulus_3) \ + "adcxq 0(" b "), %%r12 \n\t" \ + "movq %%r12, %%r8 \n\t" \ + "adoxq " modulus_0 ", %%r12 \n\t" \ + "adcxq 8(" b "), %%r13 \n\t" \ + "movq %%r13, %%r9 \n\t" \ + "adoxq " modulus_1 ", %%r13 \n\t" \ + "adcxq 16(" b "), %%r14 \n\t" \ + "movq %%r14, %%r10 \n\t" \ + "adoxq " modulus_2 ", %%r14 \n\t" \ + "adcxq 24(" b "), %%r15 \n\t" \ + "movq %%r15, %%r11 \n\t" \ + "adoxq " modulus_3 ", %%r15 \n\t" \ + "cmovnoq %%r8, %%r12 \n\t" \ + "cmovnoq %%r9, %%r13 \n\t" \ + "cmovnoq %%r10, %%r14 \n\t" \ + "cmovnoq %%r11, 
%%r15 \n\t" + + +/** + * Take a 4-limb integer, r, in (%r12, %r13, %r14, %r15) + * and conditionally subtract modulus, if r > p. + **/ +#define REDUCE_FIELD_ELEMENT(neg_modulus_0, neg_modulus_1, neg_modulus_2, neg_modulus_3) \ + /* Duplicate `r` */ \ + "movq %%r12, %%r8 \n\t" \ + "movq %%r13, %%r9 \n\t" \ + "movq %%r14, %%r10 \n\t" \ + "movq %%r15, %%r11 \n\t" \ + /* Add the negative representation of 'modulus' into `r`. We do this instead */ \ + /* of subtracting, because we can use `adoxq`. */ \ + /* This opcode only has a dependence on the overflow */ \ + /* flag (sub/sbb changes both carry and overflow flags). */ \ + /* We can process an `adcxq` and `acoxq` opcode simultaneously. */ \ + "adoxq " neg_modulus_0 ", %%r12 \n\t" /* r'[0] -= modulus.data[0] */ \ + "adoxq " neg_modulus_1 ", %%r13 \n\t" /* r'[1] -= modulus.data[1] */ \ + "adoxq " neg_modulus_2 ", %%r14 \n\t" /* r'[2] -= modulus.data[2] */ \ + "adoxq " neg_modulus_3 ", %%r15 \n\t" /* r'[3] -= modulus.data[3] */ \ + \ + /* if r does not need to be reduced, overflow flag is 1 */ \ + /* set r' = r if this flag is set */ \ + "cmovnoq %%r8, %%r12 \n\t" \ + "cmovnoq %%r9, %%r13 \n\t" \ + "cmovnoq %%r10, %%r14 \n\t" \ + "cmovnoq %%r11, %%r15 \n\t" + + +/** + * Compute Montgomery squaring of a + * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define SQR(a) \ + "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \ + \ + "xorq %%r8, %%r8 \n\t" /* clear flags */ \ + /* compute a[0] *a[1], a[0]*a[2], a[0]*a[3], a[1]*a[2], a[1]*a[3], a[2]*a[3] */ \ + "mulxq 8(" a "), %%r9, %%r10 \n\t" /* (r[1], r[2]) <- a[0] * a[1] */ \ + "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[1], t[2]) <- a[0] * a[2] */ \ + "mulxq 24(" a "), %%r11, %%r12 \n\t" /* (r[3], r[4]) <- a[0] * a[3] */ \ + \ + \ + /* accumulate products into result registers */ \ + "adoxq %%r8, %%r10 \n\t" /* r[2] += t[1] */ \ + "adcxq %%r15, %%r11 \n\t" /* r[3] += t[2] */ \ + "movq 8(" a "), %%rdx \n\t" /* load a[1] into %r%dx */ \ + "mulxq 16(" a "), %%r8, %%r15 \n\t" /* (t[5], t[6]) <- a[1] * a[2] */ \ + "mulxq 24(" a "), %%rdi, %%rcx \n\t" /* (t[3], t[4]) <- a[1] * a[3] */ \ + "movq 24(" a "), %%rdx \n\t" /* load a[3] into %%rdx */ \ + "mulxq 16(" a "), %%r13, %%r14 \n\t" /* (r[5], r[6]) <- a[3] * a[2] */ \ + "adoxq %%r8, %%r11 \n\t" /* r[3] += t[5] */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] */ \ + "adoxq %%r15, %%r12 \n\t" /* r[4] += t[6] */ \ + "adcxq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \ + "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ + "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + \ + /* double result registers */ \ + "adoxq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \ + "adcxq %%r12, %%r12 \n\t" /* r[4] = 2r[4] */ \ + "adoxq %%r10, %%r10 \n\t" /* r[2] = 2r[2] */ \ + "adcxq %%r13, %%r13 \n\t" /* r[5] = 2r[5] */ \ + "adoxq %%r11, %%r11 \n\t" /* r[3] = 2r[3] */ \ + "adcxq %%r14, %%r14 \n\t" /* r[6] = 2r[6] */ \ + \ + /* compute a[3]*a[3], a[2]*a[2], a[1]*a[1], a[0]*a[0] */ \ + "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \ + "mulxq %%rdx, %%r8, %%rcx \n\t" /* (r[0], t[4]) <- a[0] * a[0] */ \ + "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq %%rdx, %%rdx, %%rdi \n\t" /* (t[7], t[8]) <- a[2] * a[2] */ \ + /* add squares into result registers */ \ + "adcxq %%rcx, %%r9 \n\t" /* r[1] += t[4] */ \ + "adoxq %%rdx, %%r12 \n\t" /* r[4] += t[7] */ \ + "adoxq %%rdi, %%r13 \n\t" /* r[5] += t[8] */ \ + "movq 24(" a "), 
%%rdx \n\t" /* load a[3] into %rdx */ \ + "mulxq %%rdx, %%rcx, %%r15 \n\t" /* (t[5], r[7]) <- a[3] * a[3] */ \ + "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq %%rdx, %%rdi, %%rdx \n\t" /* (t[3], t[6]) <- a[1] * a[1] */ \ + "adcxq %%rdi, %%r10 \n\t" /* r[2] += t[3] */ \ + "adcxq %%rdx, %%r11 \n\t" /* r[3] += t[6] */ \ + "adoxq %%rcx, %%r14 \n\t" /* r[6] += t[5] */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \ + \ + /* perform modular reduction: r[0] */ \ + "movq %%r8, %%rdx \n\t" /* move r8 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \ + "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "adoxq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \ + "mulxq %[modulus_3], %%r8, %%rdi \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ + "adoxq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_o */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ + "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ + "adoxq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_o */ \ + "adcxq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \ + "adoxq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_o */ \ + "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "adcxq %%rdi, %%r10 \n\t" /* r[2] += t[0] + flag_c */ \ + "adcxq %%rcx, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \ + \ + /* perform modular reduction: r[1] */ \ + "movq %%r9, %%rdx \n\t" /* move r9 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[9] * r_inv */ \ + "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "mulxq %[modulus_3], %%r8, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \ + "adoxq %%rcx, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \ + "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \ + "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ + "mulxq %[modulus_0], %%r8, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "adcxq %%r8, %%r9 \n\t" /* r[1] += t[0] (%r9 now free) */ \ + "adoxq %%rcx, %%r10 \n\t" /* r[2] += t[1] + flag_c */ \ + "mulxq %[modulus_1], %%r8, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "adcxq %%r8, %%r10 \n\t" /* r[2] += t[2] */ \ + "adoxq %%rcx, %%r11 \n\t" /* r[3] += t[3] + flag_o */ \ + "adcxq %%rdi, %%r11 \n\t" /* r[3] += t[0] + flag_c */ \ + \ + /* perform modular reduction: r[2] */ \ + "movq %%r10, %%rdx \n\t" /* move r10 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \ + "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "adoxq %%rcx, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \ + "adcxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_o */ \ + "adoxq %%r9, %%r13 \n\t" /* r[5] += t[1] + flag_o */ \ + "mulxq %[modulus_3], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adcxq %%r8, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ + "adoxq %%r9, %%r14 \n\t" 
/* r[6] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \ + "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "adcxq %%r8, %%r10 \n\t" /* r[2] += t[0] (%r10 now free) */ \ + "adoxq %%r9, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \ + "adcxq %%rdi, %%r11 \n\t" /* r[3] += t[2] */ \ + "adoxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_o */ \ + "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + \ + /* perform modular reduction: r[3] */ \ + "movq %%r11, %%rdx \n\t" /* move r11 into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%rdi \n\t" /* (%rdx, _) <- k = r[10] * r_inv */ \ + "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ + "mulxq %[modulus_1], %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ + "adoxq %%rdi, %%r11 \n\t" /* r[3] += t[0] (%r11 now free) */ \ + "adcxq %%r8, %%r12 \n\t" /* r[4] += t[2] */ \ + "adoxq %%rcx, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adcxq %%r9, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ + "mulxq %[modulus_3], %%r10, %%r11 \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ + "adoxq %%r8, %%r13 \n\t" /* r[5] += t[0] + flag_o */ \ + "adcxq %%r10, %%r14 \n\t" /* r[6] += t[2] + flag_c */ \ + "adoxq %%r9, %%r14 \n\t" /* r[6] += t[1] + flag_o */ \ + "adcxq %%r11, %%r15 \n\t" /* r[7] += t[3] + flag_c */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ + +/** + * Compute Montgomery multiplication of a, b. + * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define MUL(a1, a2, a3, a4, b) \ + "movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \ + "xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \ + /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \ + "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \ + "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \ + "mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \ + "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \ + /* zero flags */ \ + \ + /* start computing modular reduction */ \ + "movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + \ + /* start first addition chain */ \ + "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \ + "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \ + "adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + \ + /* reduce by r[0] * k */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \ + "adoxq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_i */ \ + "adoxq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \ + "adcxq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_o */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "adoxq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \ + "adcxq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adoxq %%r8, %%r15 
\n\t" /* r[2] += t[0] + flag_o */ \ + "adcxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + \ + /* modulus = 254 bits, so max(t[3]) = 62 bits */ \ + /* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \ + /* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \ + /* (which is very convenient because we're out of registers!) */ \ + /* N.B. the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \ + \ + /* a[1] * b */ \ + "movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \ + "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \ + "adoxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \ + "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \ + "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \ + "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + \ + /* reduce by r[1] * k */ \ + "movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "adcxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \ + "adoxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adoxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "adoxq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \ + "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \ + "adcxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + \ + /* a[2] * b */ \ + "movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \ + "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adoxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \ + "adcxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_o */ \ + "mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \ + "adoxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ + "adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + \ + /* reduce by r[2] * k */ \ + "movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) 
<- (modulus.data[1] * k) */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + "adoxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_o */ \ + "adoxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "adcxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_o */ \ + "adoxq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + "adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \ + "adcxq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + \ + /* a[3] * b */ \ + "movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \ + "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \ + "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \ + \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \ + "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \ + "adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \ + "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \ + "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += + flag_o */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ + \ + /* reduce by r[3] * k */ \ + "movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \ + "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \ + "adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \ + "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \ + "adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \ + "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \ + "adcxq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ + +/** + * Compute Montgomery multiplication of a, b. 
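+ * Rough scalar sketch of the word-by-word (CIOS) Montgomery loop that the interleaved ADCX/ADOX chains below
+ * unroll (illustrative only: `acc` is just notation for the running multi-limb accumulator, and this is not the
+ * exact instruction schedule used in the asm):
+ *
+ *   for (size_t i = 0; i < 4; ++i) {
+ *       acc += a[i] * b;                // 64x256-bit partial product, added at the current limb position
+ *       uint64_t k = acc[0] * r_inv;    // mod 2^64, chosen so that acc + k * modulus has a zero low limb
+ *       acc += k * modulus;
+ *       acc >>= 64;                     // drop the cancelled limb
+ *   }
+ *   // acc now holds a * b * 2^{-256} mod p in four limbs, only "coarsely" reduced (it may still exceed p).
+ *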
+ * Result is stored, in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define MUL_FOO(a1, a2, a3, a4, b) \ + "movq " a1 ", %%rdx \n\t" /* load a[0] into %rdx */ \ + "xorq %%r8, %%r8 \n\t" /* clear r10 register, we use this when we need 0 */ \ + /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \ + "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \ + "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \ + "mulxq 16(" b "), %%r15, %%r10 \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \ + "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \ + /* zero flags */ \ + \ + /* start computing modular reduction */ \ + "movq %%r13, %%rdx \n\t" /* move r[0] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r11 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + \ + /* start first addition chain */ \ + "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \ + "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_o */ \ + "adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + \ + /* reduce by r[0] * k */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[2] + flag_c */ \ + "adoxq %%r11, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_i */ \ + "adoxq %%r8, %%r13 \n\t" /* r[0] += t[0] (%r13 now free) */ \ + "adcxq %%r9, %%r14 \n\t" /* r[1] += t[1] + flag_o */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "adoxq %%rdi, %%r14 \n\t" /* r[1] += t[0] */ \ + "adcxq %%r11, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_o */ \ + "adcxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + \ + /* modulus = 254 bits, so max(t[3]) = 62 bits */ \ + /* b also 254 bits, so (a[0] * b[3]) = 62 bits */ \ + /* i.e. carry flag here is always 0 if b is in mont form, no need to update r[5] */ \ + /* (which is very convenient because we're out of registers!) */ \ + /* N.B. 
the value of r[4] now has a max of 63 bits and can accept another 62 bit value before overflowing */ \ + \ + /* a[1] * b */ \ + "movq " a2 ", %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r13 \n\t" /* (t[6], r[5]) <- (a[1] * b[3]) */ \ + "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \ + "adoxq %%r9, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \ + "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \ + "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \ + "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + \ + /* reduce by r[1] * k */ \ + "movq %%r14, %%rdx \n\t" /* move r[1] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "adcxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_o */ \ + "adoxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adoxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "adoxq %%r8, %%r14 \n\t" /* r[1] += t[0] (%r14 now free) */ \ + "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \ + "adcxq %%r11, %%r10 \n\t" /* r[3] += t[1] + flag_c */ \ + \ + /* a[2] * b */ \ + "movq " a3 ", %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[2]) */ \ + "adoxq %%rdi, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adoxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_c */ \ + "adcxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_o */ \ + "mulxq 24(" b "), %%rdi, %%r14 \n\t" /* (t[2], r[6]) <- (a[2] * b[3]) */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \ + "adoxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_c */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ + "adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r9, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + \ + /* reduce by r[2] * k */ \ + "movq %%r15, %%rdx \n\t" /* move r[2] into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[0], t[1]) <- (modulus.data[1] * k) */ \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[2] * k) */ \ + "adcxq %%rdi, %%r10 \n\t" /* r[3] += t[1] + flag_o */ \ + "adoxq %%r11, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcxq %%r8, %%r12 \n\t" /* r[4] += t[0] + flag_o */ \ + "adoxq %%r9, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ + "mulxq %[modulus_3], 
%%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[3] * k) */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "adcxq %%rdi, %%r13 \n\t" /* r[5] += t[1] + flag_o */ \ + "adoxq %%r11, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \ + "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ + "adoxq %%r8, %%r15 \n\t" /* r[2] += t[0] (%r15 now free) */ \ + "adcxq %%r9, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + \ + /* a[3] * b */ \ + "movq " a4 ", %%rdx \n\t" /* load a[3] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%r11 \n\t" /* (t[4], t[5]) <- (a[3] * b[1]) */ \ + "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] + flag_c */ \ + "adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_o */ \ + "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_c */ \ + "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \ + \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[3] * b[2]) */ \ + "mulxq 24(" b "), %%rdi, %%r15 \n\t" /* (t[6], r[7]) <- (a[3] * b[3]) */ \ + "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_c */ \ + "adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_o */ \ + "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_c */ \ + "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += + flag_o */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ + \ + /* reduce by r[3] * k */ \ + "movq %%r10, %%rdx \n\t" /* move r_inv into %rdx */ \ + "mulxq %[r_inv], %%rdx, %%r8 \n\t" /* (%rdx, _) <- k = r[1] * r_inv */ \ + "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus.data[0] * k) */ \ + "mulxq %[modulus_1], %%rdi, %%r11 \n\t" /* (t[2], t[3]) <- (modulus.data[1] * k) */ \ + "adoxq %%r8, %%r10 \n\t" /* r[3] += t[0] (%rsi now free) */ \ + "adcxq %%r9, %%r12 \n\t" /* r[4] += t[2] + flag_c */ \ + "adoxq %%rdi, %%r12 \n\t" /* r[4] += t[1] + flag_o */ \ + "adcxq %%r11, %%r13 \n\t" /* r[5] += t[3] + flag_c */ \ + \ + "mulxq %[modulus_2], %%r8, %%r9 \n\t" /* (t[4], t[5]) <- (modulus.data[2] * k) */ \ + "mulxq %[modulus_3], %%rdi, %%rdx \n\t" /* (t[6], t[7]) <- (modulus.data[3] * k) */ \ + "adoxq %%r8, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \ + "adcxq %%r9, %%r14 \n\t" /* r[6] += t[6] + flag_c */ \ + "adoxq %%rdi, %%r14 \n\t" /* r[6] += t[5] + flag_o */ \ + "adcxq %%rdx, %%r15 \n\t" /* r[7] += t[7] + flag_c */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ + +/** + * Compute 256-bit multiplication of a, b. + * Result is stored, r. 
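+ * (Unlike MUL/SQR above, no Montgomery reduction is performed here: only the low 256 bits of the integer
+ * product are kept, i.e. this macro computes a * b mod 2^256.)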
// in (%%r12, %%r13, %%r14, %%r15), in preparation for being stored in "r" + **/ +#define MUL_256(a, b, r) \ + "movq 0(" a "), %%rdx \n\t" /* load a[0] into %rdx */ \ + \ + /* front-load mul ops, can parallelize 4 of these but latency is 4 cycles */ \ + "mulxq 8(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- a[0] * b[1] */ \ + "mulxq 24(" b "), %%rdi, %%r12 \n\t" /* (t[2], r[4]) <- a[0] * b[3] (overwrite a[0]) */ \ + "mulxq 0(" b "), %%r13, %%r14 \n\t" /* (r[0], r[1]) <- a[0] * b[0] */ \ + "mulxq 16(" b "), %%r15, %%rax \n\t" /* (r[2] , r[3]) <- a[0] * b[2] */ \ + /* zero flags */ \ + "xorq %%r10, %%r10 \n\t" /* clear r10 register, we use this when we need 0 */ \ + \ + \ + /* start first addition chain */ \ + "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] */ \ + "adoxq %%rdi, %%rax \n\t" /* r[3] += t[2] + flag_o */ \ + "adcxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_c */ \ + "adcxq %%r10, %%rax \n\t" /* r[3] += flag_o */ \ + \ + /* a[1] * b */ \ + "movq 8(" a "), %%rdx \n\t" /* load a[1] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[1] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[4], t[5]) <- (a[1] * b[1]) */ \ + "adcxq %%r8, %%r14 \n\t" /* r[1] += t[0] + flag_c */ \ + "adoxq %%r9, %%r15 \n\t" /* r[2] += t[1] + flag_o */ \ + "adcxq %%rdi, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%rsi, %%rax \n\t" /* r[3] += t[1] + flag_o */ \ + \ + "mulxq 16(" b "), %%r8, %%r9 \n\t" /* (t[2], t[3]) <- (a[1] * b[2]) */ \ + "adcxq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \ + \ + /* a[2] * b */ \ + "movq 16(" a "), %%rdx \n\t" /* load a[2] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[2] * b[0]) */ \ + "mulxq 8(" b "), %%rdi, %%rsi \n\t" /* (t[0], t[1]) <- (a[2] * b[1]) */ \ + "adcxq %%r8, %%r15 \n\t" /* r[2] += t[0] + flag_c */ \ + "adoxq %%r9, %%rax \n\t" /* r[3] += t[1] + flag_o */ \ + "adcxq %%rdi, %%rax \n\t" /* r[3] += t[0] + flag_c */ \ + \ + \ + /* a[3] * b */ \ + "movq 24(" a "), %%rdx \n\t" /* load a[3] into %rdx */ \ + "mulxq 0(" b "), %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (a[3] * b[0]) */ \ + "adcxq %%r8, %%rax \n\t" /* r[3] += t[0] + flag_c */ \ + "movq %%r13, 0(" r ") \n\t" \ + "movq %%r14, 8(" r ") \n\t" \ + "movq %%r15, 16(" r ") \n\t" \ + "movq %%rax, 24(" r ") \n\t" +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field.hpp new file mode 100644 index 0000000..ce2de34 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field.hpp @@ -0,0 +1,10 @@ +#pragma once + +/** + * @brief Include order of header-only field class is structured to ensure linter/language server can resolve paths. 
+ * Declarations are defined in "field_declarations.hpp", definitions in "field_impl.hpp" (which includes the
+ * declarations header). Specialized definitions are in "field_impl_generic.hpp" and "field_impl_x64.hpp"
+ * (which include "field_impl.hpp").
+ */
+#include "./field_impl_generic.hpp"
+#include "./field_impl_x64.hpp"
diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_declarations.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_declarations.hpp
new file mode 100644
index 0000000..69cba7d
--- /dev/null
+++ b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_declarations.hpp
@@ -0,0 +1,682 @@
+#pragma once
+#include "../../common/assert.hpp"
+#include "../../common/compiler_hints.hpp"
+#include "../../numeric/random/engine.hpp"
+#include "../../numeric/uint128/uint128.hpp"
+#include "../../numeric/uint256/uint256.hpp"
+#include
+#include
+#include
+#include
+#include
+
+#ifndef DISABLE_ASM
+#ifdef __BMI2__
+#define BBERG_NO_ASM 0
+#else
+#define BBERG_NO_ASM 1
+#endif
+#else
+#define BBERG_NO_ASM 1
+#endif
+
+namespace bb {
+using namespace numeric;
+/**
+ * @brief General class for prime fields, see \ref field_docs["field documentation"] for general implementation reference
+ *
+ * @tparam Params_
+ */
+template <class Params_> struct alignas(32) field {
+  public:
+    using View = field;
+    using Params = Params_;
+    using in_buf = const uint8_t*;
+    using vec_in_buf = const uint8_t*;
+    using out_buf = uint8_t*;
+    using vec_out_buf = uint8_t**;
+
+#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
+#define WASM_NUM_LIMBS 9
+#define WASM_LIMB_BITS 29
+#endif
+
+    // We don't initialize data in the default constructor since we'd lose a lot of time on huge array initializations.
+    // Other alternatives have been noted, such as casting to get around constructors where they matter,
+    // however it is felt that sanitizer tools (e.g. MSAN) can detect garbage well, whereas doing
+    // hacky casts where needed would require rework to critical algos like MSM, FFT, Sumcheck.
+ // Instead, the recommended solution is use an explicit {} where initialization is important: + // field f; // not initialized + // field f{}; // zero-initialized + // std::array arr; // not initialized, good for huge N + // std::array arr {}; // zero-initialized, preferable for moderate N + field() = default; + + constexpr field(const numeric::uint256_t& input) noexcept + : data{ input.data[0], input.data[1], input.data[2], input.data[3] } + { + self_to_montgomery_form(); + } + + // NOLINTNEXTLINE (unsigned long is platform dependent, which we want in this case) + constexpr field(const unsigned long input) noexcept + : data{ input, 0, 0, 0 } + { + self_to_montgomery_form(); + } + + constexpr field(const unsigned int input) noexcept + : data{ input, 0, 0, 0 } + { + self_to_montgomery_form(); + } + + // NOLINTNEXTLINE (unsigned long long is platform dependent, which we want in this case) + constexpr field(const unsigned long long input) noexcept + : data{ input, 0, 0, 0 } + { + self_to_montgomery_form(); + } + + constexpr field(const int input) noexcept + : data{ 0, 0, 0, 0 } + { + if (input < 0) { + data[0] = static_cast(-input); + data[1] = 0; + data[2] = 0; + data[3] = 0; + self_to_montgomery_form(); + self_neg(); + self_reduce_once(); + } else { + data[0] = static_cast(input); + data[1] = 0; + data[2] = 0; + data[3] = 0; + self_to_montgomery_form(); + } + } + + constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept + : data{ a, b, c, d } {}; + + /** + * @brief Convert a 512-bit big integer into a field element. + * + * @details Used for deriving field elements from random values. 512-bits prevents biased output as 2^512>>modulus + * + */ + constexpr explicit field(const uint512_t& input) noexcept + { + uint256_t value = (input % modulus).lo; + data[0] = value.data[0]; + data[1] = value.data[1]; + data[2] = value.data[2]; + data[3] = value.data[3]; + self_to_montgomery_form(); + } + + constexpr explicit field(std::string input) noexcept + { + uint256_t value(input); + *this = field(value); + } + + constexpr explicit operator bool() const + { + field out = from_montgomery_form(); + ASSERT(out.data[0] == 0 || out.data[0] == 1); + return static_cast(out.data[0]); + } + + constexpr explicit operator uint8_t() const + { + field out = from_montgomery_form(); + return static_cast(out.data[0]); + } + + constexpr explicit operator uint16_t() const + { + field out = from_montgomery_form(); + return static_cast(out.data[0]); + } + + constexpr explicit operator uint32_t() const + { + field out = from_montgomery_form(); + return static_cast(out.data[0]); + } + + constexpr explicit operator uint64_t() const + { + field out = from_montgomery_form(); + return out.data[0]; + } + + constexpr explicit operator uint128_t() const + { + field out = from_montgomery_form(); + uint128_t lo = out.data[0]; + uint128_t hi = out.data[1]; + return (hi << 64) | lo; + } + + constexpr operator uint256_t() const noexcept + { + field out = from_montgomery_form(); + return uint256_t(out.data[0], out.data[1], out.data[2], out.data[3]); + } + + [[nodiscard]] constexpr uint256_t uint256_t_no_montgomery_conversion() const noexcept + { + return { data[0], data[1], data[2], data[3] }; + } + + constexpr field(const field& other) noexcept = default; + constexpr field(field&& other) noexcept = default; + constexpr field& operator=(const field& other) noexcept = default; + constexpr field& operator=(field&& other) noexcept = default; + constexpr ~field() noexcept = default; + alignas(32) 
uint64_t data[4]; // NOLINT + + static constexpr uint256_t modulus = + uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 }; +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + static constexpr uint256_t r_squared_uint{ + Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3 + }; +#else + static constexpr uint256_t r_squared_uint{ + Params_::r_squared_wasm_0, Params_::r_squared_wasm_1, Params_::r_squared_wasm_2, Params_::r_squared_wasm_3 + }; + static constexpr std::array wasm_modulus = { Params::modulus_wasm_0, Params::modulus_wasm_1, + Params::modulus_wasm_2, Params::modulus_wasm_3, + Params::modulus_wasm_4, Params::modulus_wasm_5, + Params::modulus_wasm_6, Params::modulus_wasm_7, + Params::modulus_wasm_8 }; + +#endif + static constexpr field cube_root_of_unity() + { + // endomorphism i.e. lambda * [P] = (beta * x, y) + if constexpr (Params::cube_root_0 != 0) { +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + constexpr field result{ + Params::cube_root_0, Params::cube_root_1, Params::cube_root_2, Params::cube_root_3 + }; +#else + constexpr field result{ + Params::cube_root_wasm_0, Params::cube_root_wasm_1, Params::cube_root_wasm_2, Params::cube_root_wasm_3 + }; +#endif + return result; + } else { + constexpr field two_inv = field(2).invert(); + constexpr field numerator = (-field(3)).sqrt() - field(1); + constexpr field result = two_inv * numerator; + return result; + } + } + + static constexpr field zero() { return field(0, 0, 0, 0); } + static constexpr field neg_one() { return -field(1); } + static constexpr field one() { return field(1); } + + static constexpr field external_coset_generator() + { +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const field result{ + Params::coset_generators_0[7], + Params::coset_generators_1[7], + Params::coset_generators_2[7], + Params::coset_generators_3[7], + }; +#else + const field result{ + Params::coset_generators_wasm_0[7], + Params::coset_generators_wasm_1[7], + Params::coset_generators_wasm_2[7], + Params::coset_generators_wasm_3[7], + }; +#endif + + return result; + } + + static constexpr field tag_coset_generator() + { +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const field result{ + Params::coset_generators_0[6], + Params::coset_generators_1[6], + Params::coset_generators_2[6], + Params::coset_generators_3[6], + }; +#else + const field result{ + Params::coset_generators_wasm_0[6], + Params::coset_generators_wasm_1[6], + Params::coset_generators_wasm_2[6], + Params::coset_generators_wasm_3[6], + }; +#endif + + return result; + } + + static constexpr field coset_generator(const size_t idx) + { + ASSERT(idx < 7); +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const field result{ + Params::coset_generators_0[idx], + Params::coset_generators_1[idx], + Params::coset_generators_2[idx], + Params::coset_generators_3[idx], + }; +#else + const field result{ + Params::coset_generators_wasm_0[idx], + Params::coset_generators_wasm_1[idx], + Params::coset_generators_wasm_2[idx], + Params::coset_generators_wasm_3[idx], + }; +#endif + + return result; + } + + BB_INLINE constexpr field operator*(const field& other) const noexcept; + BB_INLINE constexpr field operator+(const field& other) const noexcept; + BB_INLINE constexpr field operator-(const field& other) const noexcept; + BB_INLINE constexpr field operator-() const noexcept; + constexpr field operator/(const field& other) const noexcept; + + // prefix increment (++x) + BB_INLINE constexpr field 
operator++() noexcept; + // postfix increment (x++) + // NOLINTNEXTLINE + BB_INLINE constexpr field operator++(int) noexcept; + + BB_INLINE constexpr field& operator*=(const field& other) noexcept; + BB_INLINE constexpr field& operator+=(const field& other) noexcept; + BB_INLINE constexpr field& operator-=(const field& other) noexcept; + constexpr field& operator/=(const field& other) noexcept; + + // NOTE: comparison operators exist so that `field` is comparible with stl methods that require them. + // (e.g. std::sort) + // Finite fields do not have an explicit ordering, these should *NEVER* be used in algebraic algorithms. + BB_INLINE constexpr bool operator>(const field& other) const noexcept; + BB_INLINE constexpr bool operator<(const field& other) const noexcept; + BB_INLINE constexpr bool operator==(const field& other) const noexcept; + BB_INLINE constexpr bool operator!=(const field& other) const noexcept; + + BB_INLINE constexpr field to_montgomery_form() const noexcept; + BB_INLINE constexpr field from_montgomery_form() const noexcept; + + BB_INLINE constexpr field sqr() const noexcept; + BB_INLINE constexpr void self_sqr() noexcept; + + BB_INLINE constexpr field pow(const uint256_t& exponent) const noexcept; + BB_INLINE constexpr field pow(uint64_t exponent) const noexcept; + static_assert(Params::modulus_0 != 1); + static constexpr uint256_t modulus_minus_two = + uint256_t(Params::modulus_0 - 2ULL, Params::modulus_1, Params::modulus_2, Params::modulus_3); + constexpr field invert() const noexcept; + static void batch_invert(std::span coeffs) noexcept; + static void batch_invert(field* coeffs, size_t n) noexcept; + /** + * @brief Compute square root of the field element. + * + * @return if the element is a quadratic remainder, if it's not + */ + constexpr std::pair sqrt() const noexcept; + + BB_INLINE constexpr void self_neg() noexcept; + + BB_INLINE constexpr void self_to_montgomery_form() noexcept; + BB_INLINE constexpr void self_from_montgomery_form() noexcept; + + BB_INLINE constexpr void self_conditional_negate(uint64_t predicate) noexcept; + + BB_INLINE constexpr field reduce_once() const noexcept; + BB_INLINE constexpr void self_reduce_once() noexcept; + + BB_INLINE constexpr void self_set_msb() noexcept; + [[nodiscard]] BB_INLINE constexpr bool is_msb_set() const noexcept; + [[nodiscard]] BB_INLINE constexpr uint64_t is_msb_set_word() const noexcept; + + [[nodiscard]] BB_INLINE constexpr bool is_zero() const noexcept; + + static constexpr field get_root_of_unity(size_t subgroup_size) noexcept; + + static void serialize_to_buffer(const field& value, uint8_t* buffer) { write(buffer, value); } + + static field serialize_from_buffer(const uint8_t* buffer) { return from_buffer(buffer); } + + [[nodiscard]] BB_INLINE std::vector to_buffer() const { return to_buffer(*this); } + + struct wide_array { + uint64_t data[8]; // NOLINT + }; + BB_INLINE constexpr wide_array mul_512(const field& other) const noexcept; + BB_INLINE constexpr wide_array sqr_512() const noexcept; + + BB_INLINE constexpr field conditionally_subtract_from_double_modulus(const uint64_t predicate) const noexcept + { + if (predicate != 0) { + constexpr field p{ + twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] + }; + return p - *this; + } + return *this; + } + + /** + * For short Weierstrass curves y^2 = x^3 + b mod r, if there exists a cube root of unity mod r, + * we can take advantage of an enodmorphism to decompose a 254 bit scalar into 2 128 bit scalars. 
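+ * (This is the standard GLV (Gallant-Lambert-Vanstone) scalar decomposition.)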
+ * \beta = cube root of 1, mod q (q = order of fq) + * \lambda = cube root of 1, mod r (r = order of fr) + * + * For a point P1 = (X, Y), where Y^2 = X^3 + b, we know that + * the point P2 = (X * \beta, Y) is also a point on the curve + * We can represent P2 as a scalar multiplication of P1, where P2 = \lambda * P1 + * + * For a generic multiplication of P1 by a 254 bit scalar k, we can decompose k + * into 2 127 bit scalars (k1, k2), such that k = k1 - (k2 * \lambda) + * + * We can now represent (k * P1) as (k1 * P1) - (k2 * P2), where P2 = (X * \beta, Y). + * As k1, k2 have half the bit length of k, we have reduced the number of loop iterations of our + * scalar multiplication algorithm in half + * + * To find k1, k2, We use the extended euclidean algorithm to find 4 short scalars [a1, a2], [b1, b2] such that + * modulus = (a1 * b2) - (b1 * a2) + * We then compute scalars c1 = round(b2 * k / r), c2 = round(b1 * k / r), where + * k1 = (c1 * a1) + (c2 * a2), k2 = -((c1 * b1) + (c2 * b2)) + * We pre-compute scalars g1 = (2^256 * b1) / n, g2 = (2^256 * b2) / n, to avoid having to perform long division + * on 512-bit scalars + **/ + static void split_into_endomorphism_scalars(const field& k, field& k1, field& k2) + { + // if the modulus is a >= 255-bit integer, we need to use a basis where g1, g2 have been shifted by 2^384 + if constexpr (Params::modulus_3 >= 0x4000000000000000ULL) { + split_into_endomorphism_scalars_384(k, k1, k2); + } else { + std::pair, std::array> ret = split_into_endomorphism_scalars(k); + k1.data[0] = ret.first[0]; + k1.data[1] = ret.first[1]; + + // TODO(https://github.com/AztecProtocol/barretenberg/issues/851): We should move away from this hack by + // returning pair of uint64_t[2] instead of a half-set field +#if !defined(__clang__) && defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif + k2.data[0] = ret.second[0]; // NOLINT + k2.data[1] = ret.second[1]; +#if !defined(__clang__) && defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + } + } + + // NOTE: this form is only usable if the modulus is 254 bits or less, otherwise see + // split_into_endomorphism_scalars_384. + // TODO(https://github.com/AztecProtocol/barretenberg/issues/851): Unify these APIs. + static std::pair, std::array> split_into_endomorphism_scalars(const field& k) + { + static_assert(Params::modulus_3 < 0x4000000000000000ULL); + field input = k.reduce_once(); + + constexpr field endo_g1 = { Params::endo_g1_lo, Params::endo_g1_mid, Params::endo_g1_hi, 0 }; + + constexpr field endo_g2 = { Params::endo_g2_lo, Params::endo_g2_mid, 0, 0 }; + + constexpr field endo_minus_b1 = { Params::endo_minus_b1_lo, Params::endo_minus_b1_mid, 0, 0 }; + + constexpr field endo_b2 = { Params::endo_b2_lo, Params::endo_b2_mid, 0, 0 }; + + // compute c1 = (g2 * k) >> 256 + wide_array c1 = endo_g2.mul_512(input); + // compute c2 = (g1 * k) >> 256 + wide_array c2 = endo_g1.mul_512(input); + + // (the bit shifts are implicit, as we only utilize the high limbs of c1, c2 + + field c1_hi = { + c1.data[4], c1.data[5], c1.data[6], c1.data[7] + }; // *(field*)((uintptr_t)(&c1) + (4 * sizeof(uint64_t))); + field c2_hi = { + c2.data[4], c2.data[5], c2.data[6], c2.data[7] + }; // *(field*)((uintptr_t)(&c2) + (4 * sizeof(uint64_t))); + + // compute q1 = c1 * -b1 + wide_array q1 = c1_hi.mul_512(endo_minus_b1); + // compute q2 = c2 * b2 + wide_array q2 = c2_hi.mul_512(endo_b2); + + // FIX: Avoid using 512-bit multiplication as its not necessary. 
+ // c1_hi, c2_hi can be uint256_t's and the final result (without montgomery reduction) + // could be casted to a field. + field q1_lo{ q1.data[0], q1.data[1], q1.data[2], q1.data[3] }; + field q2_lo{ q2.data[0], q2.data[1], q2.data[2], q2.data[3] }; + + field t1 = (q2_lo - q1_lo).reduce_once(); + field beta = cube_root_of_unity(); + field t2 = (t1 * beta + input).reduce_once(); + return { + { t2.data[0], t2.data[1] }, + { t1.data[0], t1.data[1] }, + }; + } + + static void split_into_endomorphism_scalars_384(const field& input, field& k1_out, field& k2_out) + { + constexpr field minus_b1f{ + Params::endo_minus_b1_lo, + Params::endo_minus_b1_mid, + 0, + 0, + }; + constexpr field b2f{ + Params::endo_b2_lo, + Params::endo_b2_mid, + 0, + 0, + }; + constexpr uint256_t g1{ + Params::endo_g1_lo, + Params::endo_g1_mid, + Params::endo_g1_hi, + Params::endo_g1_hihi, + }; + constexpr uint256_t g2{ + Params::endo_g2_lo, + Params::endo_g2_mid, + Params::endo_g2_hi, + Params::endo_g2_hihi, + }; + + field kf = input.reduce_once(); + uint256_t k{ kf.data[0], kf.data[1], kf.data[2], kf.data[3] }; + + uint512_t c1 = (uint512_t(k) * static_cast(g1)) >> 384; + uint512_t c2 = (uint512_t(k) * static_cast(g2)) >> 384; + + field c1f{ c1.lo.data[0], c1.lo.data[1], c1.lo.data[2], c1.lo.data[3] }; + field c2f{ c2.lo.data[0], c2.lo.data[1], c2.lo.data[2], c2.lo.data[3] }; + + c1f.self_to_montgomery_form(); + c2f.self_to_montgomery_form(); + c1f = c1f * minus_b1f; + c2f = c2f * b2f; + field r2f = c1f - c2f; + field beta = cube_root_of_unity(); + field r1f = input.reduce_once() - r2f * beta; + k1_out = r1f; + k2_out = -r2f; + } + + // static constexpr auto coset_generators = compute_coset_generators(); + // static constexpr std::array coset_generators = compute_coset_generators((1 << 30U)); + + friend std::ostream& operator<<(std::ostream& os, const field& a) + { + field out = a.from_montgomery_form(); + std::ios_base::fmtflags f(os.flags()); + os << std::hex << "0x" << std::setfill('0') << std::setw(16) << out.data[3] << std::setw(16) << out.data[2] + << std::setw(16) << out.data[1] << std::setw(16) << out.data[0]; + os.flags(f); + return os; + } + + BB_INLINE static void __copy(const field& a, field& r) noexcept { r = a; } // NOLINT + BB_INLINE static void __swap(field& src, field& dest) noexcept // NOLINT + { + field T = dest; + dest = src; + src = T; + } + + static field random_element(numeric::RNG* engine = nullptr) noexcept; + + static constexpr field multiplicative_generator() noexcept; + + static constexpr uint256_t twice_modulus = modulus + modulus; + static constexpr uint256_t not_modulus = -modulus; + static constexpr uint256_t twice_not_modulus = -twice_modulus; + + struct wnaf_table { + uint8_t windows[64]; // NOLINT + + constexpr wnaf_table(const uint256_t& target) + : windows{ + static_cast(target.data[0] & 15), static_cast((target.data[0] >> 4) & 15), + static_cast((target.data[0] >> 8) & 15), static_cast((target.data[0] >> 12) & 15), + static_cast((target.data[0] >> 16) & 15), static_cast((target.data[0] >> 20) & 15), + static_cast((target.data[0] >> 24) & 15), static_cast((target.data[0] >> 28) & 15), + static_cast((target.data[0] >> 32) & 15), static_cast((target.data[0] >> 36) & 15), + static_cast((target.data[0] >> 40) & 15), static_cast((target.data[0] >> 44) & 15), + static_cast((target.data[0] >> 48) & 15), static_cast((target.data[0] >> 52) & 15), + static_cast((target.data[0] >> 56) & 15), static_cast((target.data[0] >> 60) & 15), + static_cast(target.data[1] & 15), 
static_cast((target.data[1] >> 4) & 15), + static_cast((target.data[1] >> 8) & 15), static_cast((target.data[1] >> 12) & 15), + static_cast((target.data[1] >> 16) & 15), static_cast((target.data[1] >> 20) & 15), + static_cast((target.data[1] >> 24) & 15), static_cast((target.data[1] >> 28) & 15), + static_cast((target.data[1] >> 32) & 15), static_cast((target.data[1] >> 36) & 15), + static_cast((target.data[1] >> 40) & 15), static_cast((target.data[1] >> 44) & 15), + static_cast((target.data[1] >> 48) & 15), static_cast((target.data[1] >> 52) & 15), + static_cast((target.data[1] >> 56) & 15), static_cast((target.data[1] >> 60) & 15), + static_cast(target.data[2] & 15), static_cast((target.data[2] >> 4) & 15), + static_cast((target.data[2] >> 8) & 15), static_cast((target.data[2] >> 12) & 15), + static_cast((target.data[2] >> 16) & 15), static_cast((target.data[2] >> 20) & 15), + static_cast((target.data[2] >> 24) & 15), static_cast((target.data[2] >> 28) & 15), + static_cast((target.data[2] >> 32) & 15), static_cast((target.data[2] >> 36) & 15), + static_cast((target.data[2] >> 40) & 15), static_cast((target.data[2] >> 44) & 15), + static_cast((target.data[2] >> 48) & 15), static_cast((target.data[2] >> 52) & 15), + static_cast((target.data[2] >> 56) & 15), static_cast((target.data[2] >> 60) & 15), + static_cast(target.data[3] & 15), static_cast((target.data[3] >> 4) & 15), + static_cast((target.data[3] >> 8) & 15), static_cast((target.data[3] >> 12) & 15), + static_cast((target.data[3] >> 16) & 15), static_cast((target.data[3] >> 20) & 15), + static_cast((target.data[3] >> 24) & 15), static_cast((target.data[3] >> 28) & 15), + static_cast((target.data[3] >> 32) & 15), static_cast((target.data[3] >> 36) & 15), + static_cast((target.data[3] >> 40) & 15), static_cast((target.data[3] >> 44) & 15), + static_cast((target.data[3] >> 48) & 15), static_cast((target.data[3] >> 52) & 15), + static_cast((target.data[3] >> 56) & 15), static_cast((target.data[3] >> 60) & 15) + } + {} + }; + +#if defined(__wasm__) || !defined(__SIZEOF_INT128__) + BB_INLINE static constexpr void wasm_madd(uint64_t& left_limb, + const std::array& right_limbs, + uint64_t& result_0, + uint64_t& result_1, + uint64_t& result_2, + uint64_t& result_3, + uint64_t& result_4, + uint64_t& result_5, + uint64_t& result_6, + uint64_t& result_7, + uint64_t& result_8); + BB_INLINE static constexpr void wasm_reduce(uint64_t& result_0, + uint64_t& result_1, + uint64_t& result_2, + uint64_t& result_3, + uint64_t& result_4, + uint64_t& result_5, + uint64_t& result_6, + uint64_t& result_7, + uint64_t& result_8); + BB_INLINE static constexpr std::array wasm_convert(const uint64_t* data); +#endif + BB_INLINE static constexpr std::pair mul_wide(uint64_t a, uint64_t b) noexcept; + + BB_INLINE static constexpr uint64_t mac( + uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& carry_out) noexcept; + + BB_INLINE static constexpr void mac( + uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t& out, uint64_t& carry_out) noexcept; + + BB_INLINE static constexpr uint64_t mac_mini(uint64_t a, uint64_t b, uint64_t c, uint64_t& out) noexcept; + + BB_INLINE static constexpr void mac_mini( + uint64_t a, uint64_t b, uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept; + + BB_INLINE static constexpr uint64_t mac_discard_lo(uint64_t a, uint64_t b, uint64_t c) noexcept; + + BB_INLINE static constexpr uint64_t addc(uint64_t a, uint64_t b, uint64_t carry_in, uint64_t& carry_out) noexcept; + + BB_INLINE static constexpr uint64_t 
sbb(uint64_t a, uint64_t b, uint64_t borrow_in, uint64_t& borrow_out) noexcept; + + BB_INLINE static constexpr uint64_t square_accumulate(uint64_t a, + uint64_t b, + uint64_t c, + uint64_t carry_in_lo, + uint64_t carry_in_hi, + uint64_t& carry_lo, + uint64_t& carry_hi) noexcept; + BB_INLINE constexpr field reduce() const noexcept; + BB_INLINE constexpr field add(const field& other) const noexcept; + BB_INLINE constexpr field subtract(const field& other) const noexcept; + BB_INLINE constexpr field subtract_coarse(const field& other) const noexcept; + BB_INLINE constexpr field montgomery_mul(const field& other) const noexcept; + BB_INLINE constexpr field montgomery_mul_big(const field& other) const noexcept; + BB_INLINE constexpr field montgomery_square() const noexcept; + +#if (BBERG_NO_ASM == 0) + BB_INLINE static field asm_mul(const field& a, const field& b) noexcept; + BB_INLINE static field asm_sqr(const field& a) noexcept; + BB_INLINE static field asm_add(const field& a, const field& b) noexcept; + BB_INLINE static field asm_sub(const field& a, const field& b) noexcept; + BB_INLINE static field asm_mul_with_coarse_reduction(const field& a, const field& b) noexcept; + BB_INLINE static field asm_sqr_with_coarse_reduction(const field& a) noexcept; + BB_INLINE static field asm_add_with_coarse_reduction(const field& a, const field& b) noexcept; + BB_INLINE static field asm_sub_with_coarse_reduction(const field& a, const field& b) noexcept; + BB_INLINE static field asm_add_without_reduction(const field& a, const field& b) noexcept; + BB_INLINE static void asm_self_sqr(const field& a) noexcept; + BB_INLINE static void asm_self_add(const field& a, const field& b) noexcept; + BB_INLINE static void asm_self_sub(const field& a, const field& b) noexcept; + BB_INLINE static void asm_self_mul_with_coarse_reduction(const field& a, const field& b) noexcept; + BB_INLINE static void asm_self_sqr_with_coarse_reduction(const field& a) noexcept; + BB_INLINE static void asm_self_add_with_coarse_reduction(const field& a, const field& b) noexcept; + BB_INLINE static void asm_self_sub_with_coarse_reduction(const field& a, const field& b) noexcept; + BB_INLINE static void asm_self_add_without_reduction(const field& a, const field& b) noexcept; + + BB_INLINE static void asm_conditional_negate(field& r, uint64_t predicate) noexcept; + BB_INLINE static field asm_reduce_once(const field& a) noexcept; + BB_INLINE static void asm_self_reduce_once(const field& a) noexcept; + static constexpr uint64_t zero_reference = 0x00ULL; +#endif + static constexpr size_t COSET_GENERATOR_SIZE = 15; + constexpr field tonelli_shanks_sqrt() const noexcept; + static constexpr size_t primitive_root_log_size() noexcept; + static constexpr std::array compute_coset_generators() noexcept; + +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + static constexpr uint128_t lo_mask = 0xffffffffffffffffUL; +#endif +}; +} // namespace bb diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl.hpp new file mode 100644 index 0000000..4b00a20 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl.hpp @@ -0,0 +1,673 @@ +#pragma once +#include "../../common/op_count.hpp" +#include "../../common/slab_allocator.hpp" +#include "../../common/throw_or_abort.hpp" +#include "../../numeric/bitop/get_msb.hpp" +#include "../../numeric/random/engine.hpp" +#include +#include +#include +#include + +#include "./field_declarations.hpp" + +namespace bb 
{ +using namespace numeric; +// clang-format off +// disable the following style guides: +// cppcoreguidelines-avoid-c-arrays : we make heavy use of c-style arrays here to prevent default-initialization of memory when constructing `field` objects. +// The intention is for field to act like a primitive numeric type with the performance/complexity trade-offs expected from this. +// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays) +// clang-format on +/** + * + * Mutiplication + * + **/ +template constexpr field field::operator*(const field& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::mul"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + // >= 255-bits or <= 64-bits. + return montgomery_mul(other); + } else { + if (std::is_constant_evaluated()) { + return montgomery_mul(other); + } + return asm_mul_with_coarse_reduction(*this, other); + } +} + +template constexpr field& field::operator*=(const field& other) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_mul"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + // >= 255-bits or <= 64-bits. + *this = operator*(other); + } else { + if (std::is_constant_evaluated()) { + *this = operator*(other); + } else { + asm_self_mul_with_coarse_reduction(*this, other); // asm_self_mul(*this, other); + } + } + return *this; +} + +/** + * + * Squaring + * + **/ +template constexpr field field::sqr() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::sqr"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + return montgomery_square(); + } else { + if (std::is_constant_evaluated()) { + return montgomery_square(); + } + return asm_sqr_with_coarse_reduction(*this); // asm_sqr(*this); + } +} + +template constexpr void field::self_sqr() noexcept +{ + BB_OP_COUNT_TRACK_NAME("f::self_sqr"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + *this = montgomery_square(); + } else { + if (std::is_constant_evaluated()) { + *this = montgomery_square(); + } else { + asm_self_sqr_with_coarse_reduction(*this); + } + } +} + +/** + * + * Addition + * + **/ +template constexpr field field::operator+(const field& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::add"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + return add(other); + } else { + if (std::is_constant_evaluated()) { + return add(other); + } + return asm_add_with_coarse_reduction(*this, other); // asm_add_without_reduction(*this, other); + } +} + +template constexpr field& field::operator+=(const field& other) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_add"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + (*this) = operator+(other); + } else { + if (std::is_constant_evaluated()) { + (*this) = operator+(other); + } else { + asm_self_add_with_coarse_reduction(*this, other); // asm_self_add(*this, other); + } + } + return *this; +} + +template constexpr field field::operator++() noexcept +{ + BB_OP_COUNT_TRACK_NAME("++f"); + return *this += 1; +} + +// NOLINTNEXTLINE(cert-dcl21-cpp) circular linting errors. 
If const is added, linter suggests removing +template constexpr field field::operator++(int) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::increment"); + field value_before_incrementing = *this; + *this += 1; + return value_before_incrementing; +} + +/** + * + * Subtraction + * + **/ +template constexpr field field::operator-(const field& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::sub"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + return subtract_coarse(other); // modulus - *this; + } else { + if (std::is_constant_evaluated()) { + return subtract_coarse(other); // subtract(other); + } + return asm_sub_with_coarse_reduction(*this, other); // asm_sub(*this, other); + } +} + +template constexpr field field::operator-() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("-f"); + if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] }; + return p - *this; // modulus - *this; + } + + // TODO(@zac-williamson): there are 3 ways we can make this more efficient + // 1: we subtract `p` from `*this` instead of `2p` + // 2: instead of `p - *this`, we use an asm block that does `p - *this` without the assembly reduction step + // 3: we replace `(p - *this).reduce_once()` with an assembly block that is equivalent to `p - *this`, + // but we call `REDUCE_FIELD_ELEMENT` with `not_twice_modulus` instead of `twice_modulus` + // not sure which is faster and whether any of the above might break something! + // + // More context below: + // the operator-(a, b) method's asm implementation has a sneaky was to check underflow. + // if `a - b` underflows we need to add in `2p`. Instead of conditional branching which would cause pipeline + // flushes, we add `2p` into the result of `a - b`. If the result triggers the overflow flag, then we know we are + // correcting an *underflow* produced from computing `a - b`. Finally...we use the overflow flag to conditionally + // move data into registers such that we end up with either `a - b` or `2p + (a - b)` (this is branchless). OK! So + // what's the problem? Well we assume that every field element lies between 0 and 2p - 1. But we are computing `2p - + // *this`! If *this = 0 then we exceed this bound hence the need for the extra reduction step. 
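+    // (The constant is 2p rather than p because, under the coarse-reduction convention above, *this may itself
+    // be as large as 2p - 1, so p - *this could go negative.)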
HOWEVER, we also know + // that 2p - *this won't underflow so we could skip the underflow check present in the assembly code + constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] }; + return (p - *this).reduce_once(); // modulus - *this; +} + +template constexpr field& field::operator-=(const field& other) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_sub"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + *this = subtract_coarse(other); // subtract(other); + } else { + if (std::is_constant_evaluated()) { + *this = subtract_coarse(other); // subtract(other); + } else { + asm_self_sub_with_coarse_reduction(*this, other); // asm_self_sub(*this, other); + } + } + return *this; +} + +template constexpr void field::self_neg() noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_neg"); + if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] }; + *this = p - *this; + } else { + constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] }; + *this = (p - *this).reduce_once(); + } +} + +template constexpr void field::self_conditional_negate(const uint64_t predicate) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_conditional_negate"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + *this = predicate ? -(*this) : *this; // NOLINT + } else { + if (std::is_constant_evaluated()) { + *this = predicate ? -(*this) : *this; // NOLINT + } else { + asm_conditional_negate(*this, predicate); + } + } +} + +/** + * @brief Greater-than operator + * @details comparison operators exist so that `field` is comparible with stl methods that require them. + * (e.g. std::sort) + * Finite fields do not have an explicit ordering, these should *NEVER* be used in algebraic algorithms. + * + * @tparam T + * @param other + * @return true + * @return false + */ +template constexpr bool field::operator>(const field& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::gt"); + const field left = reduce_once(); + const field right = other.reduce_once(); + const bool t0 = left.data[3] > right.data[3]; + const bool t1 = (left.data[3] == right.data[3]) && (left.data[2] > right.data[2]); + const bool t2 = + (left.data[3] == right.data[3]) && (left.data[2] == right.data[2]) && (left.data[1] > right.data[1]); + const bool t3 = (left.data[3] == right.data[3]) && (left.data[2] == right.data[2]) && + (left.data[1] == right.data[1]) && (left.data[0] > right.data[0]); + return (t0 || t1 || t2 || t3); +} + +/** + * @brief Less-than operator + * @details comparison operators exist so that `field` is comparible with stl methods that require them. + * (e.g. std::sort) + * Finite fields do not have an explicit ordering, these should *NEVER* be used in algebraic algorithms. 
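+ * For example, -field(1) is represented by the integer p - 1, so field(1) < -field(1) evaluates to true here,
+ * despite having no algebraic meaning.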
+ * + * @tparam T + * @param other + * @return true + * @return false + */ +template constexpr bool field::operator<(const field& other) const noexcept +{ + return (other > *this); +} + +template constexpr bool field::operator==(const field& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::eqeq"); + const field left = reduce_once(); + const field right = other.reduce_once(); + return (left.data[0] == right.data[0]) && (left.data[1] == right.data[1]) && (left.data[2] == right.data[2]) && + (left.data[3] == right.data[3]); +} + +template constexpr bool field::operator!=(const field& other) const noexcept +{ + return (!operator==(other)); +} + +template constexpr field field::to_montgomery_form() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::to_montgomery_form"); + constexpr field r_squared = + field{ r_squared_uint.data[0], r_squared_uint.data[1], r_squared_uint.data[2], r_squared_uint.data[3] }; + + field result = *this; + // TODO(@zac-williamson): are these reductions needed? + // Rationale: We want to take any 256-bit input and be able to convert into montgomery form. + // A basic heuristic we use is that any input into the `*` operator must be between [0, 2p - 1] + // to prevent overflows in the asm algorithm. + // However... r_squared is already reduced so perhaps we can relax this requirement? + // (would be good to identify a failure case where not calling self_reduce triggers an error) + result.self_reduce_once(); + result.self_reduce_once(); + result.self_reduce_once(); + return (result * r_squared).reduce_once(); +} + +template constexpr field field::from_montgomery_form() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::from_montgomery_form"); + constexpr field one_raw{ 1, 0, 0, 0 }; + return operator*(one_raw).reduce_once(); +} + +template constexpr void field::self_to_montgomery_form() noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_to_montgomery_form"); + constexpr field r_squared = + field{ r_squared_uint.data[0], r_squared_uint.data[1], r_squared_uint.data[2], r_squared_uint.data[3] }; + + self_reduce_once(); + self_reduce_once(); + self_reduce_once(); + *this *= r_squared; + self_reduce_once(); +} + +template constexpr void field::self_from_montgomery_form() noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_from_montgomery_form"); + constexpr field one_raw{ 1, 0, 0, 0 }; + *this *= one_raw; + self_reduce_once(); +} + +template constexpr field field::reduce_once() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::reduce_once"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + return reduce(); + } else { + if (std::is_constant_evaluated()) { + return reduce(); + } + return asm_reduce_once(*this); + } +} + +template constexpr void field::self_reduce_once() noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_reduce_once"); + if constexpr (BBERG_NO_ASM || (T::modulus_3 >= 0x4000000000000000ULL) || + (T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) { + *this = reduce(); + } else { + if (std::is_constant_evaluated()) { + *this = reduce(); + } else { + asm_self_reduce_once(*this); + } + } +} + +template constexpr field field::pow(const uint256_t& exponent) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::pow"); + field accumulator{ data[0], data[1], data[2], data[3] }; + field to_mul{ data[0], data[1], data[2], data[3] }; + const uint64_t maximum_set_bit = exponent.get_msb(); + + for (int i = static_cast(maximum_set_bit) - 1; i >= 0; --i) { + accumulator.self_sqr(); + if 
(exponent.get_bit(static_cast(i))) { + accumulator *= to_mul; + } + } + if (exponent == uint256_t(0)) { + accumulator = one(); + } else if (*this == zero()) { + accumulator = zero(); + } + return accumulator; +} + +template constexpr field field::pow(const uint64_t exponent) const noexcept +{ + return pow({ exponent, 0, 0, 0 }); +} + +template constexpr field field::invert() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::invert"); + if (*this == zero()) { + throw_or_abort("Trying to invert zero in the field"); + } + return pow(modulus_minus_two); +} + +template void field::batch_invert(field* coeffs, const size_t n) noexcept +{ + batch_invert(std::span{ coeffs, n }); +} + +template void field::batch_invert(std::span coeffs) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::batch_invert"); + const size_t n = coeffs.size(); + + auto temporaries_ptr = std::static_pointer_cast(get_mem_slab(n * sizeof(field))); + auto skipped_ptr = std::static_pointer_cast(get_mem_slab(n)); + auto temporaries = temporaries_ptr.get(); + auto* skipped = skipped_ptr.get(); + + field accumulator = one(); + for (size_t i = 0; i < n; ++i) { + temporaries[i] = accumulator; + if (coeffs[i].is_zero()) { + skipped[i] = true; + } else { + skipped[i] = false; + accumulator *= coeffs[i]; + } + } + + // std::vector temporaries; + // std::vector skipped; + // temporaries.reserve(n); + // skipped.reserve(n); + + // field accumulator = one(); + // for (size_t i = 0; i < n; ++i) { + // temporaries.emplace_back(accumulator); + // if (coeffs[i].is_zero()) { + // skipped.emplace_back(true); + // } else { + // skipped.emplace_back(false); + // accumulator *= coeffs[i]; + // } + // } + + accumulator = accumulator.invert(); + + field T0; + for (size_t i = n - 1; i < n; --i) { + if (!skipped[i]) { + T0 = accumulator * temporaries[i]; + accumulator *= coeffs[i]; + coeffs[i] = T0; + } + } +} + +template constexpr field field::tonelli_shanks_sqrt() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt"); + // Tonelli-shanks algorithm begins by finding a field element Q and integer S, + // such that (p - 1) = Q.2^{s} + + // We can compute the square root of a, by considering a^{(Q + 1) / 2} = R + // Once we have found such an R, we have + // R^{2} = a^{Q + 1} = a^{Q}a + // If a^{Q} = 1, we have found our square root. + // Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity. + // This is because t^{2^{s-1}} = a^{Q.2^{s-1}}. + // We know that (p - 1) = Q.w^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2} + // From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1 + // i.e. t^{2^{s-1}} = 1 + + // To proceed with computing our square root, we want to transform t into a smaller subgroup, + // specifically, the (s-2)'th roots of unity. + // We do this by finding some value b,such that + // (t.b^2)^{2^{s-2}} = 1 and R' = R.b + // Finding such a b is trivial, because from Euler's criterion, we know that, + // for any quadratic non-residue z, z^{(p - 1) / 2} = -1 + // i.e. z^{Q.2^{s-1}} = -1 + // => z^Q is a 2^{s-1}'th root of -1 + // => z^{Q^2} is a 2^{s-2}'th root of -1 + // Since t^{2^{s-1}} = 1, we know that t^{2^{s - 2}} = -1 + // => t.z^{Q^2} is a 2^{s - 2}'th root of unity. + + // We can iteratively transform t into ever smaller subgroups, until t = 1. 
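+    // Toy worked example (added for illustration, not in the original source): take p = 13, so
+    // (p - 1) = Q.2^{s} with Q = 3, s = 2, and a = 10 (a residue, since 6^2 = 36 = 10 mod 13).
+    //   R = a^{(Q + 1) / 2} = 10^2 = 9,   t = a^{Q} = 10^3 = 12 = -1, so t != 1 and we iterate.
+    //   z = 2 is a non-residue, giving c = z^{Q} = 8 and m = s = 2.
+    //   The smallest i with t^{2^{i}} = 1 is i = 1, hence b = c^{2^{m - i - 1}} = c = 8.
+    //   Then c = b^2 = 12, t = t.c = 1, R = R.b = 72 mod 13 = 7, and indeed 7^2 = 49 = 10 mod 13.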
+ // At each iteration, we need to find a new value for b, which we can obtain + // by repeatedly squaring z^{Q} + constexpr uint256_t Q = (modulus - 1) >> static_cast(primitive_root_log_size() - 1); + constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 2; + + // __to_montgomery_form(Q_minus_one_over_two, Q_minus_one_over_two); + field z = coset_generator(0); // the generator is a non-residue + field b = pow(Q_minus_one_over_two); + field r = operator*(b); // r = a^{(Q + 1) / 2} + field t = r * b; // t = a^{(Q - 1) / 2 + (Q + 1) / 2} = a^{Q} + + // check if t is a square with euler's criterion + // if not, we don't have a quadratic residue and a has no square root! + field check = t; + for (size_t i = 0; i < primitive_root_log_size() - 1; ++i) { + check.self_sqr(); + } + if (check != one()) { + return zero(); + } + field t1 = z.pow(Q_minus_one_over_two); + field t2 = t1 * z; + field c = t2 * t1; // z^Q + + size_t m = primitive_root_log_size(); + + while (t != one()) { + size_t i = 0; + field t2m = t; + + // find the smallest value of m, such that t^{2^m} = 1 + while (t2m != one()) { + t2m.self_sqr(); + i += 1; + } + + size_t j = m - i - 1; + b = c; + while (j > 0) { + b.self_sqr(); + --j; + } // b = z^2^(m-i-1) + + c = b.sqr(); + t = t * c; + r = r * b; + m = i; + } + return r; +} + +template constexpr std::pair> field::sqrt() const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::sqrt"); + field root; + if constexpr ((T::modulus_0 & 0x3UL) == 0x3UL) { + constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2; + root = pow(sqrt_exponent); + } else { + root = tonelli_shanks_sqrt(); + } + if ((root * root) == (*this)) { + return std::pair(true, root); + } + return std::pair(false, field::zero()); + +} // namespace bb; + +template constexpr field field::operator/(const field& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::div"); + return operator*(other.invert()); +} + +template constexpr field& field::operator/=(const field& other) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::self_div"); + *this = operator/(other); + return *this; +} + +template constexpr void field::self_set_msb() noexcept +{ + data[3] = 0ULL | (1ULL << 63ULL); +} + +template constexpr bool field::is_msb_set() const noexcept +{ + return (data[3] >> 63ULL) == 1ULL; +} + +template constexpr uint64_t field::is_msb_set_word() const noexcept +{ + return (data[3] >> 63ULL); +} + +template constexpr bool field::is_zero() const noexcept +{ + return ((data[0] | data[1] | data[2] | data[3]) == 0) || + (data[0] == T::modulus_0 && data[1] == T::modulus_1 && data[2] == T::modulus_2 && data[3] == T::modulus_3); +} + +template constexpr field field::get_root_of_unity(size_t subgroup_size) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + field r{ T::primitive_root_0, T::primitive_root_1, T::primitive_root_2, T::primitive_root_3 }; +#else + field r{ T::primitive_root_wasm_0, T::primitive_root_wasm_1, T::primitive_root_wasm_2, T::primitive_root_wasm_3 }; +#endif + for (size_t i = primitive_root_log_size(); i > subgroup_size; --i) { + r.self_sqr(); + } + return r; +} + +template field field::random_element(numeric::RNG* engine) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::random_element"); + if (engine == nullptr) { + engine = &numeric::get_randomness(); + } + + uint512_t source = engine->get_random_uint512(); + uint512_t q(modulus); + uint512_t reduced = source % q; + return field(reduced.lo); +} + +template constexpr size_t field::primitive_root_log_size() noexcept +{ + 
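+    // (Explanatory comment, not in the original source.) This computes the 2-adicity of the
+    // field's multiplicative group: the largest s such that 2^{s} divides (p - 1), found by
+    // counting the trailing zero bits of (p - 1). For the BN254 scalar field s = 28, which is also
+    // the log-size of the largest power-of-two subgroup that get_root_of_unity() can serve.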
BB_OP_COUNT_TRACK_NAME("fr::primitive_root_log_size"); + uint256_t target = modulus - 1; + size_t result = 0; + while (!target.get_bit(result)) { + ++result; + } + return result; +} + +template +constexpr std::array, field::COSET_GENERATOR_SIZE> field::compute_coset_generators() noexcept +{ + constexpr size_t n = COSET_GENERATOR_SIZE; + constexpr uint64_t subgroup_size = 1 << 30; + + std::array result{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + if (n > 0) { + result[0] = (multiplicative_generator()); + } + field work_variable = multiplicative_generator() + field(1); + + size_t count = 1; + while (count < n) { + // work_variable contains a new field element, and we need to test that, for all previous vector elements, + // result[i] / work_variable is not a member of our subgroup + field work_inverse = work_variable.invert(); + bool valid = true; + for (size_t j = 0; j < count; ++j) { + field subgroup_check = (work_inverse * result[j]).pow(subgroup_size); + if (subgroup_check == field(1)) { + valid = false; + break; + } + } + if (valid) { + result[count] = (work_variable); + ++count; + } + work_variable += field(1); + } + return result; +} + +template constexpr field field::multiplicative_generator() noexcept +{ + field target(1); + uint256_t p_minus_one_over_two = (modulus - 1) >> 1; + bool found = false; + while (!found) { + target += field(1); + found = (target.pow(p_minus_one_over_two) == -field(1)); + } + return target; +} + +} // namespace bb + +// clang-format off +// NOLINTEND(cppcoreguidelines-avoid-c-arrays) +// clang-format on diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl_generic.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl_generic.hpp new file mode 100644 index 0000000..69c5f92 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl_generic.hpp @@ -0,0 +1,945 @@ +#pragma once + +#include +#include + +#include "./field_impl.hpp" +#include "../../common/op_count.hpp" + +namespace bb { + using namespace numeric; +// NOLINTBEGIN(readability-implicit-bool-conversion) +template constexpr std::pair field::mul_wide(uint64_t a, uint64_t b) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t res = (static_cast(a) * static_cast(b)); + return { static_cast(res), static_cast(res >> 64) }; +#else + const uint64_t product = a * b; + return { product & 0xffffffffULL, product >> 32 }; +#endif +} + +template +constexpr uint64_t field::mac( + const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t carry_in, uint64_t& carry_out) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t res = static_cast(a) + (static_cast(b) * static_cast(c)) + + static_cast(carry_in); + carry_out = static_cast(res >> 64); + return static_cast(res); +#else + const uint64_t product = b * c + a + carry_in; + carry_out = product >> 32; + return product & 0xffffffffULL; +#endif +} + +template +constexpr void field::mac(const uint64_t a, + const uint64_t b, + const uint64_t c, + const uint64_t carry_in, + uint64_t& out, + uint64_t& carry_out) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t res = static_cast(a) + (static_cast(b) * static_cast(c)) + + static_cast(carry_in); + out = static_cast(res); + carry_out = static_cast(res >> 64); +#else + const uint64_t product = b * c + a + carry_in; + carry_out = product >> 32; + out = product & 0xffffffffULL; +#endif +} + +template +constexpr uint64_t field::mac_mini(const uint64_t a, + const 
uint64_t b, + const uint64_t c, + uint64_t& carry_out) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t res = static_cast(a) + (static_cast(b) * static_cast(c)); + carry_out = static_cast(res >> 64); + return static_cast(res); +#else + const uint64_t product = b * c + a; + carry_out = product >> 32; + return product & 0xffffffffULL; +#endif +} + +template +constexpr void field::mac_mini( + const uint64_t a, const uint64_t b, const uint64_t c, uint64_t& out, uint64_t& carry_out) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t res = static_cast(a) + (static_cast(b) * static_cast(c)); + out = static_cast(res); + carry_out = static_cast(res >> 64); +#else + const uint64_t result = b * c + a; + carry_out = result >> 32; + out = result & 0xffffffffULL; +#endif +} + +template +constexpr uint64_t field::mac_discard_lo(const uint64_t a, const uint64_t b, const uint64_t c) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t res = static_cast(a) + (static_cast(b) * static_cast(c)); + return static_cast(res >> 64); +#else + return (b * c + a) >> 32; +#endif +} + +template +constexpr uint64_t field::addc(const uint64_t a, + const uint64_t b, + const uint64_t carry_in, + uint64_t& carry_out) noexcept +{ + BB_OP_COUNT_TRACK(); +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + uint128_t res = static_cast(a) + static_cast(b) + static_cast(carry_in); + carry_out = static_cast(res >> 64); + return static_cast(res); +#else + uint64_t r = a + b; + const uint64_t carry_temp = r < a; + r += carry_in; + carry_out = carry_temp + (r < carry_in); + return r; +#endif +} + +template +constexpr uint64_t field::sbb(const uint64_t a, + const uint64_t b, + const uint64_t borrow_in, + uint64_t& borrow_out) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + uint128_t res = static_cast(a) - (static_cast(b) + static_cast(borrow_in >> 63)); + borrow_out = static_cast(res >> 64); + return static_cast(res); +#else + uint64_t t_1 = a - (borrow_in >> 63ULL); + uint64_t borrow_temp_1 = t_1 > a; + uint64_t t_2 = t_1 - b; + uint64_t borrow_temp_2 = t_2 > t_1; + borrow_out = 0ULL - (borrow_temp_1 | borrow_temp_2); + return t_2; +#endif +} + +template +constexpr uint64_t field::square_accumulate(const uint64_t a, + const uint64_t b, + const uint64_t c, + const uint64_t carry_in_lo, + const uint64_t carry_in_hi, + uint64_t& carry_lo, + uint64_t& carry_hi) noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const uint128_t product = static_cast(b) * static_cast(c); + const auto r0 = static_cast(product); + const auto r1 = static_cast(product >> 64); + uint64_t out = r0 + r0; + carry_lo = (out < r0); + out += a; + carry_lo += (out < a); + out += carry_in_lo; + carry_lo += (out < carry_in_lo); + carry_lo += r1; + carry_hi = (carry_lo < r1); + carry_lo += r1; + carry_hi += (carry_lo < r1); + carry_lo += carry_in_hi; + carry_hi += (carry_lo < carry_in_hi); + return out; +#else + const auto product = b * c; + const auto t0 = product + a + carry_in_lo; + const auto t1 = product + t0; + carry_hi = t1 < product; + const auto t2 = t1 + (carry_in_hi << 32); + carry_hi += t2 < t1; + carry_lo = t2 >> 32; + return t2 & 0xffffffffULL; +#endif +} + +template constexpr field field::reduce() const noexcept +{ + if constexpr (modulus.data[3] >= 0x4000000000000000ULL) { + uint256_t val{ data[0], data[1], data[2], data[3] }; + if (val >= modulus) { + val -= modulus; + } + return { val.data[0], val.data[1], 
val.data[2], val.data[3] }; + } + uint64_t t0 = data[0] + not_modulus.data[0]; + uint64_t c = t0 < data[0]; + auto t1 = addc(data[1], not_modulus.data[1], c, c); + auto t2 = addc(data[2], not_modulus.data[2], c, c); + auto t3 = addc(data[3], not_modulus.data[3], c, c); + const uint64_t selection_mask = 0ULL - c; // 0xffff... if we have overflowed. + const uint64_t selection_mask_inverse = ~selection_mask; + // if we overflow, we want to swap + return { + (data[0] & selection_mask_inverse) | (t0 & selection_mask), + (data[1] & selection_mask_inverse) | (t1 & selection_mask), + (data[2] & selection_mask_inverse) | (t2 & selection_mask), + (data[3] & selection_mask_inverse) | (t3 & selection_mask), + }; +} + +template constexpr field field::add(const field& other) const noexcept +{ + if constexpr (modulus.data[3] >= 0x4000000000000000ULL) { + uint64_t r0 = data[0] + other.data[0]; + uint64_t c = r0 < data[0]; + auto r1 = addc(data[1], other.data[1], c, c); + auto r2 = addc(data[2], other.data[2], c, c); + auto r3 = addc(data[3], other.data[3], c, c); + if (c) { + uint64_t b = 0; + r0 = sbb(r0, modulus.data[0], b, b); + r1 = sbb(r1, modulus.data[1], b, b); + r2 = sbb(r2, modulus.data[2], b, b); + r3 = sbb(r3, modulus.data[3], b, b); + // Since both values are in [0, 2**256), the result is in [0, 2**257-2]. Subtracting one p might not be + // enough. We need to ensure that we've underflown the 0 and that might require subtracting an additional p + if (!b) { + b = 0; + r0 = sbb(r0, modulus.data[0], b, b); + r1 = sbb(r1, modulus.data[1], b, b); + r2 = sbb(r2, modulus.data[2], b, b); + r3 = sbb(r3, modulus.data[3], b, b); + } + } + return { r0, r1, r2, r3 }; + } else { + uint64_t r0 = data[0] + other.data[0]; + uint64_t c = r0 < data[0]; + auto r1 = addc(data[1], other.data[1], c, c); + auto r2 = addc(data[2], other.data[2], c, c); + uint64_t r3 = data[3] + other.data[3] + c; + + uint64_t t0 = r0 + twice_not_modulus.data[0]; + c = t0 < twice_not_modulus.data[0]; + uint64_t t1 = addc(r1, twice_not_modulus.data[1], c, c); + uint64_t t2 = addc(r2, twice_not_modulus.data[2], c, c); + uint64_t t3 = addc(r3, twice_not_modulus.data[3], c, c); + const uint64_t selection_mask = 0ULL - c; + const uint64_t selection_mask_inverse = ~selection_mask; + + return { + (r0 & selection_mask_inverse) | (t0 & selection_mask), + (r1 & selection_mask_inverse) | (t1 & selection_mask), + (r2 & selection_mask_inverse) | (t2 & selection_mask), + (r3 & selection_mask_inverse) | (t3 & selection_mask), + }; + } +} + +template constexpr field field::subtract(const field& other) const noexcept +{ + uint64_t borrow = 0; + uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow); + uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow); + uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow); + uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow); + + r0 += (modulus.data[0] & borrow); + uint64_t carry = r0 < (modulus.data[0] & borrow); + r1 = addc(r1, modulus.data[1] & borrow, carry, carry); + r2 = addc(r2, modulus.data[2] & borrow, carry, carry); + r3 = addc(r3, (modulus.data[3] & borrow), carry, carry); + // The value being subtracted is in [0, 2**256), if we subtract 0 - 2*255 and then add p, the value will stay + // negative. If we are adding p, we need to check that we've overflown 2**256. 
If not, we should add p again + if (!carry) { + r0 += (modulus.data[0] & borrow); + uint64_t carry = r0 < (modulus.data[0] & borrow); + r1 = addc(r1, modulus.data[1] & borrow, carry, carry); + r2 = addc(r2, modulus.data[2] & borrow, carry, carry); + r3 = addc(r3, (modulus.data[3] & borrow), carry, carry); + } + return { r0, r1, r2, r3 }; +} + +/** + * @brief + * + * @tparam T + * @param other + * @return constexpr field + */ +template constexpr field field::subtract_coarse(const field& other) const noexcept +{ + if constexpr (modulus.data[3] >= 0x4000000000000000ULL) { + return subtract(other); + } + uint64_t borrow = 0; + uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow); + uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow); + uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow); + uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow); + + r0 += (twice_modulus.data[0] & borrow); + uint64_t carry = r0 < (twice_modulus.data[0] & borrow); + r1 = addc(r1, twice_modulus.data[1] & borrow, carry, carry); + r2 = addc(r2, twice_modulus.data[2] & borrow, carry, carry); + r3 += (twice_modulus.data[3] & borrow) + carry; + + return { r0, r1, r2, r3 }; +} + +/** + * @brief Mongtomery multiplication for moduli > 2²⁵⁴ + * + * @details Explanation of Montgomery form can be found in \ref field_docs_montgomery_explainer and the difference + * between WASM and generic versions is explained in \ref field_docs_architecture_details + */ +template constexpr field field::montgomery_mul_big(const field& other) const noexcept +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + uint64_t c = 0; + uint64_t t0 = 0; + uint64_t t1 = 0; + uint64_t t2 = 0; + uint64_t t3 = 0; + uint64_t t4 = 0; + uint64_t t5 = 0; + uint64_t k = 0; + for (const auto& element : data) { + c = 0; + mac(t0, element, other.data[0], c, t0, c); + mac(t1, element, other.data[1], c, t1, c); + mac(t2, element, other.data[2], c, t2, c); + mac(t3, element, other.data[3], c, t3, c); + t4 = addc(t4, c, 0, t5); + + c = 0; + k = t0 * T::r_inv; + c = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, k, modulus.data[1], c, t0, c); + mac(t2, k, modulus.data[2], c, t1, c); + mac(t3, k, modulus.data[3], c, t2, c); + t3 = addc(c, t4, 0, c); + t4 = t5 + c; + } + uint64_t borrow = 0; + uint64_t r0 = sbb(t0, modulus.data[0], borrow, borrow); + uint64_t r1 = sbb(t1, modulus.data[1], borrow, borrow); + uint64_t r2 = sbb(t2, modulus.data[2], borrow, borrow); + uint64_t r3 = sbb(t3, modulus.data[3], borrow, borrow); + borrow = borrow ^ (0ULL - t4); + r0 += (modulus.data[0] & borrow); + uint64_t carry = r0 < (modulus.data[0] & borrow); + r1 = addc(r1, modulus.data[1] & borrow, carry, carry); + r2 = addc(r2, modulus.data[2] & borrow, carry, carry); + r3 += (modulus.data[3] & borrow) + carry; + return { r0, r1, r2, r3 }; +#else + + // Convert 4 64-bit limbs to 9 29-bit limbs + auto left = wasm_convert(data); + auto right = wasm_convert(other.data); + constexpr uint64_t mask = 0x1fffffff; + uint64_t temp_0 = 0; + uint64_t temp_1 = 0; + uint64_t temp_2 = 0; + uint64_t temp_3 = 0; + uint64_t temp_4 = 0; + uint64_t temp_5 = 0; + uint64_t temp_6 = 0; + uint64_t temp_7 = 0; + uint64_t temp_8 = 0; + uint64_t temp_9 = 0; + uint64_t temp_10 = 0; + uint64_t temp_11 = 0; + uint64_t temp_12 = 0; + uint64_t temp_13 = 0; + uint64_t temp_14 = 0; + uint64_t temp_15 = 0; + uint64_t temp_16 = 0; + uint64_t temp_17 = 0; + + // Multiply-add 0th limb of the left argument by all 9 limbs of the right arguemnt + wasm_madd(left[0], right, temp_0, temp_1, temp_2, 
temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + // Instantly reduce + wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + // Continue for other limbs + wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); + wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); + + // After all multiplications and additions, convert relaxed form to strict (all limbs are 29 bits) + temp_10 += temp_9 >> WASM_LIMB_BITS; + temp_9 &= mask; + temp_11 += temp_10 >> WASM_LIMB_BITS; + temp_10 &= mask; + temp_12 += temp_11 >> WASM_LIMB_BITS; + temp_11 &= mask; + temp_13 += temp_12 >> WASM_LIMB_BITS; + temp_12 &= mask; + temp_14 += temp_13 >> WASM_LIMB_BITS; + temp_13 &= mask; + temp_15 += temp_14 >> WASM_LIMB_BITS; + temp_14 &= mask; + temp_16 += temp_15 >> WASM_LIMB_BITS; + temp_15 &= mask; + temp_17 += temp_16 >> WASM_LIMB_BITS; + temp_16 &= mask; + + uint64_t r_temp_0; + uint64_t r_temp_1; + uint64_t r_temp_2; + uint64_t r_temp_3; + uint64_t r_temp_4; + uint64_t r_temp_5; + uint64_t r_temp_6; + uint64_t r_temp_7; + uint64_t r_temp_8; + // Subtract modulus from result + r_temp_0 = temp_9 - wasm_modulus[0]; + r_temp_1 = temp_10 - wasm_modulus[1] - ((r_temp_0) >> 63); + r_temp_2 = temp_11 - wasm_modulus[2] - ((r_temp_1) >> 63); + r_temp_3 = temp_12 - wasm_modulus[3] - ((r_temp_2) >> 63); + r_temp_4 = temp_13 - wasm_modulus[4] - ((r_temp_3) >> 63); + r_temp_5 = temp_14 - wasm_modulus[5] - ((r_temp_4) >> 63); + r_temp_6 = temp_15 - wasm_modulus[6] - ((r_temp_5) >> 63); + r_temp_7 = temp_16 - wasm_modulus[7] - ((r_temp_6) >> 63); + r_temp_8 = temp_17 - wasm_modulus[8] - ((r_temp_7) >> 63); + + // Depending on whether the subtraction underflowed, choose original value or the result of subtraction + uint64_t new_mask = 0 - (r_temp_8 >> 63); + uint64_t inverse_mask = (~new_mask) & mask; + temp_9 = (temp_9 & new_mask) | (r_temp_0 & inverse_mask); + temp_10 = (temp_10 & new_mask) | (r_temp_1 & inverse_mask); + temp_11 = (temp_11 & new_mask) | (r_temp_2 & inverse_mask); + temp_12 = (temp_12 & new_mask) | (r_temp_3 & inverse_mask); + temp_13 = (temp_13 & new_mask) | 
(r_temp_4 & inverse_mask); + temp_14 = (temp_14 & new_mask) | (r_temp_5 & inverse_mask); + temp_15 = (temp_15 & new_mask) | (r_temp_6 & inverse_mask); + temp_16 = (temp_16 & new_mask) | (r_temp_7 & inverse_mask); + temp_17 = (temp_17 & new_mask) | (r_temp_8 & inverse_mask); + + // Convert back to 4 64-bit limbs + return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58), + (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52), + (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46), + (temp_15 >> 18) | (temp_16 << 11) | (temp_17 << 40) }; + +#endif +} + +#if defined(__wasm__) || !defined(__SIZEOF_INT128__) + +/** + * @brief Multiply left limb by a sequence of 9 limbs and put into result variables + * + */ +template +constexpr void field::wasm_madd(uint64_t& left_limb, + const std::array& right_limbs, + uint64_t& result_0, + uint64_t& result_1, + uint64_t& result_2, + uint64_t& result_3, + uint64_t& result_4, + uint64_t& result_5, + uint64_t& result_6, + uint64_t& result_7, + uint64_t& result_8) +{ + result_0 += left_limb * right_limbs[0]; + result_1 += left_limb * right_limbs[1]; + result_2 += left_limb * right_limbs[2]; + result_3 += left_limb * right_limbs[3]; + result_4 += left_limb * right_limbs[4]; + result_5 += left_limb * right_limbs[5]; + result_6 += left_limb * right_limbs[6]; + result_7 += left_limb * right_limbs[7]; + result_8 += left_limb * right_limbs[8]; +} + +/** + * @brief Perform 29-bit montgomery reduction on 1 limb (result_0 should be zero modulo 2**29 after this) + * + */ +template +constexpr void field::wasm_reduce(uint64_t& result_0, + uint64_t& result_1, + uint64_t& result_2, + uint64_t& result_3, + uint64_t& result_4, + uint64_t& result_5, + uint64_t& result_6, + uint64_t& result_7, + uint64_t& result_8) +{ + constexpr uint64_t mask = 0x1fffffff; + constexpr uint64_t r_inv = T::r_inv & mask; + uint64_t k = (result_0 * r_inv) & mask; + result_0 += k * wasm_modulus[0]; + result_1 += k * wasm_modulus[1] + (result_0 >> WASM_LIMB_BITS); + result_2 += k * wasm_modulus[2]; + result_3 += k * wasm_modulus[3]; + result_4 += k * wasm_modulus[4]; + result_5 += k * wasm_modulus[5]; + result_6 += k * wasm_modulus[6]; + result_7 += k * wasm_modulus[7]; + result_8 += k * wasm_modulus[8]; +} +/** + * @brief Convert 4 64-bit limbs into 9 29-bit limbs + * + */ +template constexpr std::array field::wasm_convert(const uint64_t* data) +{ + return { data[0] & 0x1fffffff, + (data[0] >> WASM_LIMB_BITS) & 0x1fffffff, + ((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6), + (data[1] >> 23) & 0x1fffffff, + ((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12), + (data[2] >> 17) & 0x1fffffff, + ((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18), + (data[3] >> 11) & 0x1fffffff, + (data[3] >> 40) & 0x1fffffff }; +} +#endif +template constexpr field field::montgomery_mul(const field& other) const noexcept +{ + if constexpr (modulus.data[3] >= 0x4000000000000000ULL) { + return montgomery_mul_big(other); + } +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + auto [t0, c] = mul_wide(data[0], other.data[0]); + uint64_t k = t0 * T::r_inv; + uint64_t a = mac_discard_lo(t0, k, modulus.data[0]); + + uint64_t t1 = mac_mini(a, data[0], other.data[1], a); + mac(t1, k, modulus.data[1], c, t0, c); + uint64_t t2 = mac_mini(a, data[0], other.data[2], a); + mac(t2, k, modulus.data[2], c, t1, c); + uint64_t t3 = mac_mini(a, data[0], other.data[3], a); + mac(t3, k, modulus.data[3], c, t2, c); + t3 = c + a; + + mac_mini(t0, data[1], other.data[0], t0, a); + k = t0 * T::r_inv; + c = mac_discard_lo(t0, k, 
modulus.data[0]); + mac(t1, data[1], other.data[1], a, t1, a); + mac(t1, k, modulus.data[1], c, t0, c); + mac(t2, data[1], other.data[2], a, t2, a); + mac(t2, k, modulus.data[2], c, t1, c); + mac(t3, data[1], other.data[3], a, t3, a); + mac(t3, k, modulus.data[3], c, t2, c); + t3 = c + a; + + mac_mini(t0, data[2], other.data[0], t0, a); + k = t0 * T::r_inv; + c = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, data[2], other.data[1], a, t1, a); + mac(t1, k, modulus.data[1], c, t0, c); + mac(t2, data[2], other.data[2], a, t2, a); + mac(t2, k, modulus.data[2], c, t1, c); + mac(t3, data[2], other.data[3], a, t3, a); + mac(t3, k, modulus.data[3], c, t2, c); + t3 = c + a; + + mac_mini(t0, data[3], other.data[0], t0, a); + k = t0 * T::r_inv; + c = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, data[3], other.data[1], a, t1, a); + mac(t1, k, modulus.data[1], c, t0, c); + mac(t2, data[3], other.data[2], a, t2, a); + mac(t2, k, modulus.data[2], c, t1, c); + mac(t3, data[3], other.data[3], a, t3, a); + mac(t3, k, modulus.data[3], c, t2, c); + t3 = c + a; + return { t0, t1, t2, t3 }; +#else + + // Convert 4 64-bit limbs to 9 29-bit ones + auto left = wasm_convert(data); + auto right = wasm_convert(other.data); + constexpr uint64_t mask = 0x1fffffff; + uint64_t temp_0 = 0; + uint64_t temp_1 = 0; + uint64_t temp_2 = 0; + uint64_t temp_3 = 0; + uint64_t temp_4 = 0; + uint64_t temp_5 = 0; + uint64_t temp_6 = 0; + uint64_t temp_7 = 0; + uint64_t temp_8 = 0; + uint64_t temp_9 = 0; + uint64_t temp_10 = 0; + uint64_t temp_11 = 0; + uint64_t temp_12 = 0; + uint64_t temp_13 = 0; + uint64_t temp_14 = 0; + uint64_t temp_15 = 0; + uint64_t temp_16 = 0; + + // Perform a series of multiplications and reductions (we multiply 1 limb of left argument by the whole right + // argument and then reduce) + wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); + wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); 
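+
+    // (Explanatory note, not in the original source.) With 9 limbs of 29 bits, a limb product fits
+    // in 58 bits, so the ~6 spare bits of a uint64_t are enough to accumulate all of the partial
+    // products and Montgomery-reduction terms above without any 128-bit intermediate, which this
+    // non-__int128 / wasm branch avoids by construction. The accumulators are therefore left in
+    // "relaxed" form until the single carry-propagation pass below.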
+ + // Convert result to unrelaxed form (all limbs are 29 bits) + temp_10 += temp_9 >> WASM_LIMB_BITS; + temp_9 &= mask; + temp_11 += temp_10 >> WASM_LIMB_BITS; + temp_10 &= mask; + temp_12 += temp_11 >> WASM_LIMB_BITS; + temp_11 &= mask; + temp_13 += temp_12 >> WASM_LIMB_BITS; + temp_12 &= mask; + temp_14 += temp_13 >> WASM_LIMB_BITS; + temp_13 &= mask; + temp_15 += temp_14 >> WASM_LIMB_BITS; + temp_14 &= mask; + temp_16 += temp_15 >> WASM_LIMB_BITS; + temp_15 &= mask; + + // Convert back to 4 64-bit limbs form + return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58), + (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52), + (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46), + (temp_15 >> 18) | (temp_16 << 11) }; +#endif +} + +template constexpr field field::montgomery_square() const noexcept +{ + if constexpr (modulus.data[3] >= 0x4000000000000000ULL) { + return montgomery_mul_big(*this); + } +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + uint64_t carry_hi = 0; + + auto [t0, carry_lo] = mul_wide(data[0], data[0]); + uint64_t t1 = square_accumulate(0, data[1], data[0], carry_lo, carry_hi, carry_lo, carry_hi); + uint64_t t2 = square_accumulate(0, data[2], data[0], carry_lo, carry_hi, carry_lo, carry_hi); + uint64_t t3 = square_accumulate(0, data[3], data[0], carry_lo, carry_hi, carry_lo, carry_hi); + + uint64_t round_carry = carry_lo; + uint64_t k = t0 * T::r_inv; + carry_lo = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo); + mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo); + mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo); + t3 = carry_lo + round_carry; + + t1 = mac_mini(t1, data[1], data[1], carry_lo); + carry_hi = 0; + t2 = square_accumulate(t2, data[2], data[1], carry_lo, carry_hi, carry_lo, carry_hi); + t3 = square_accumulate(t3, data[3], data[1], carry_lo, carry_hi, carry_lo, carry_hi); + round_carry = carry_lo; + k = t0 * T::r_inv; + carry_lo = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo); + mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo); + mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo); + t3 = carry_lo + round_carry; + + t2 = mac_mini(t2, data[2], data[2], carry_lo); + carry_hi = 0; + t3 = square_accumulate(t3, data[3], data[2], carry_lo, carry_hi, carry_lo, carry_hi); + round_carry = carry_lo; + k = t0 * T::r_inv; + carry_lo = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo); + mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo); + mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo); + t3 = carry_lo + round_carry; + + t3 = mac_mini(t3, data[3], data[3], carry_lo); + k = t0 * T::r_inv; + round_carry = carry_lo; + carry_lo = mac_discard_lo(t0, k, modulus.data[0]); + mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo); + mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo); + mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo); + t3 = carry_lo + round_carry; + return { t0, t1, t2, t3 }; +#else + // Convert from 4 64-bit limbs to 9 29-bit ones + auto left = wasm_convert(data); + constexpr uint64_t mask = 0x1fffffff; + uint64_t temp_0 = 0; + uint64_t temp_1 = 0; + uint64_t temp_2 = 0; + uint64_t temp_3 = 0; + uint64_t temp_4 = 0; + uint64_t temp_5 = 0; + uint64_t temp_6 = 0; + uint64_t temp_7 = 0; + uint64_t temp_8 = 0; + uint64_t temp_9 = 0; + uint64_t temp_10 = 0; + uint64_t temp_11 = 0; + uint64_t temp_12 = 0; + uint64_t temp_13 = 0; + uint64_t temp_14 = 0; + uint64_t temp_15 = 0; + uint64_t temp_16 = 0; + uint64_t 
acc; + // Perform multiplications, but accumulated results for limb k=i+j so that we can double them at the same time + temp_0 += left[0] * left[0]; + acc = 0; + acc += left[0] * left[1]; + temp_1 += (acc << 1); + acc = 0; + acc += left[0] * left[2]; + temp_2 += left[1] * left[1]; + temp_2 += (acc << 1); + acc = 0; + acc += left[0] * left[3]; + acc += left[1] * left[2]; + temp_3 += (acc << 1); + acc = 0; + acc += left[0] * left[4]; + acc += left[1] * left[3]; + temp_4 += left[2] * left[2]; + temp_4 += (acc << 1); + acc = 0; + acc += left[0] * left[5]; + acc += left[1] * left[4]; + acc += left[2] * left[3]; + temp_5 += (acc << 1); + acc = 0; + acc += left[0] * left[6]; + acc += left[1] * left[5]; + acc += left[2] * left[4]; + temp_6 += left[3] * left[3]; + temp_6 += (acc << 1); + acc = 0; + acc += left[0] * left[7]; + acc += left[1] * left[6]; + acc += left[2] * left[5]; + acc += left[3] * left[4]; + temp_7 += (acc << 1); + acc = 0; + acc += left[0] * left[8]; + acc += left[1] * left[7]; + acc += left[2] * left[6]; + acc += left[3] * left[5]; + temp_8 += left[4] * left[4]; + temp_8 += (acc << 1); + acc = 0; + acc += left[1] * left[8]; + acc += left[2] * left[7]; + acc += left[3] * left[6]; + acc += left[4] * left[5]; + temp_9 += (acc << 1); + acc = 0; + acc += left[2] * left[8]; + acc += left[3] * left[7]; + acc += left[4] * left[6]; + temp_10 += left[5] * left[5]; + temp_10 += (acc << 1); + acc = 0; + acc += left[3] * left[8]; + acc += left[4] * left[7]; + acc += left[5] * left[6]; + temp_11 += (acc << 1); + acc = 0; + acc += left[4] * left[8]; + acc += left[5] * left[7]; + temp_12 += left[6] * left[6]; + temp_12 += (acc << 1); + acc = 0; + acc += left[5] * left[8]; + acc += left[6] * left[7]; + temp_13 += (acc << 1); + acc = 0; + acc += left[6] * left[8]; + temp_14 += left[7] * left[7]; + temp_14 += (acc << 1); + acc = 0; + acc += left[7] * left[8]; + temp_15 += (acc << 1); + temp_16 += left[8] * left[8]; + + // Perform reductions + wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); + + // Convert to unrelaxed 29-bit form + temp_10 += temp_9 >> WASM_LIMB_BITS; + temp_9 &= mask; + temp_11 += temp_10 >> WASM_LIMB_BITS; + temp_10 &= mask; + temp_12 += temp_11 >> WASM_LIMB_BITS; + temp_11 &= mask; + temp_13 += temp_12 >> WASM_LIMB_BITS; + temp_12 &= mask; + temp_14 += temp_13 >> WASM_LIMB_BITS; + temp_13 &= mask; + temp_15 += temp_14 >> WASM_LIMB_BITS; + temp_14 &= mask; + temp_16 += temp_15 >> WASM_LIMB_BITS; + temp_15 &= mask; + // Convert to 4 64-bit form + return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58), + (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52), + (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46), + (temp_15 >> 18) | (temp_16 << 11) }; +#endif +} + +template constexpr struct field::wide_array field::mul_512(const field& 
other) const noexcept { +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + uint64_t carry_2 = 0; + auto [r0, carry] = mul_wide(data[0], other.data[0]); + uint64_t r1 = mac_mini(carry, data[0], other.data[1], carry); + uint64_t r2 = mac_mini(carry, data[0], other.data[2], carry); + uint64_t r3 = mac_mini(carry, data[0], other.data[3], carry_2); + + r1 = mac_mini(r1, data[1], other.data[0], carry); + r2 = mac(r2, data[1], other.data[1], carry, carry); + r3 = mac(r3, data[1], other.data[2], carry, carry); + uint64_t r4 = mac(carry_2, data[1], other.data[3], carry, carry_2); + + r2 = mac_mini(r2, data[2], other.data[0], carry); + r3 = mac(r3, data[2], other.data[1], carry, carry); + r4 = mac(r4, data[2], other.data[2], carry, carry); + uint64_t r5 = mac(carry_2, data[2], other.data[3], carry, carry_2); + + r3 = mac_mini(r3, data[3], other.data[0], carry); + r4 = mac(r4, data[3], other.data[1], carry, carry); + r5 = mac(r5, data[3], other.data[2], carry, carry); + uint64_t r6 = mac(carry_2, data[3], other.data[3], carry, carry_2); + + return { r0, r1, r2, r3, r4, r5, r6, carry_2 }; +#else + // Convert from 4 64-bit limbs to 9 29-bit limbs + auto left = wasm_convert(data); + auto right = wasm_convert(other.data); + constexpr uint64_t mask = 0x1fffffff; + uint64_t temp_0 = 0; + uint64_t temp_1 = 0; + uint64_t temp_2 = 0; + uint64_t temp_3 = 0; + uint64_t temp_4 = 0; + uint64_t temp_5 = 0; + uint64_t temp_6 = 0; + uint64_t temp_7 = 0; + uint64_t temp_8 = 0; + uint64_t temp_9 = 0; + uint64_t temp_10 = 0; + uint64_t temp_11 = 0; + uint64_t temp_12 = 0; + uint64_t temp_13 = 0; + uint64_t temp_14 = 0; + uint64_t temp_15 = 0; + uint64_t temp_16 = 0; + + // Multiply-add all limbs + wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); + + // Convert to unrelaxed 29-bit form + temp_1 += temp_0 >> WASM_LIMB_BITS; + temp_0 &= mask; + temp_2 += temp_1 >> WASM_LIMB_BITS; + temp_1 &= mask; + temp_3 += temp_2 >> WASM_LIMB_BITS; + temp_2 &= mask; + temp_4 += temp_3 >> WASM_LIMB_BITS; + temp_3 &= mask; + temp_5 += temp_4 >> WASM_LIMB_BITS; + temp_4 &= mask; + temp_6 += temp_5 >> WASM_LIMB_BITS; + temp_5 &= mask; + temp_7 += temp_6 >> WASM_LIMB_BITS; + temp_6 &= mask; + temp_8 += temp_7 >> WASM_LIMB_BITS; + temp_7 &= mask; + temp_9 += temp_8 >> WASM_LIMB_BITS; + temp_8 &= mask; + temp_10 += temp_9 >> WASM_LIMB_BITS; + temp_9 &= mask; + temp_11 += temp_10 >> WASM_LIMB_BITS; + temp_10 &= mask; + temp_12 += temp_11 >> WASM_LIMB_BITS; + temp_11 &= mask; + temp_13 += temp_12 >> WASM_LIMB_BITS; + temp_12 &= mask; + temp_14 += temp_13 >> WASM_LIMB_BITS; + temp_13 &= mask; + temp_15 += temp_14 >> WASM_LIMB_BITS; + temp_14 &= mask; + temp_16 += temp_15 
>> WASM_LIMB_BITS; + temp_15 &= mask; + + // Convert to 8 64-bit limbs + return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58), + (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52), + (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46), + (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40), + (temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63), + (temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57), + (temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51), + (temp_15 >> 13) | (temp_16 << 16) }; +#endif +} + +// NOLINTEND(readability-implicit-bool-conversion) +} // namespace bb \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl_x64.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl_x64.hpp new file mode 100644 index 0000000..5b661f8 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/fields/field_impl_x64.hpp @@ -0,0 +1,389 @@ +#pragma once + +#if (BBERG_NO_ASM == 0) +#include "./field_impl.hpp" +#include "asm_macros.hpp" +namespace bb { + +template field field::asm_mul_with_coarse_reduction(const field& a, const field& b) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_mul_with_coarse_reduction"); + + field r; + constexpr uint64_t r_inv = T::r_inv; + constexpr uint64_t modulus_0 = modulus.data[0]; + constexpr uint64_t modulus_1 = modulus.data[1]; + constexpr uint64_t modulus_2 = modulus.data[2]; + constexpr uint64_t modulus_3 = modulus.data[3]; + constexpr uint64_t zero_ref = 0; + + /** + * Registers: rax:rdx = multiplication accumulator + * %r12, %r13, %r14, %r15, %rax: work registers for `r` + * %r8, %r9, %rdi, %rsi: scratch registers for multiplication results + * %r10: zero register + * %0: pointer to `a` + * %1: pointer to `b` + * %2: pointer to `r` + **/ + __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1") + STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "%r"(&a), + "%r"(&b), + "r"(&r), + [modulus_0] "m"(modulus_0), + [modulus_1] "m"(modulus_1), + [modulus_2] "m"(modulus_2), + [modulus_3] "m"(modulus_3), + [r_inv] "m"(r_inv), + [zero_reference] "m"(zero_ref) + : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); + return r; +} + +template void field::asm_self_mul_with_coarse_reduction(const field& a, const field& b) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_self_mul_with_coarse_reduction"); + + constexpr uint64_t r_inv = T::r_inv; + constexpr uint64_t modulus_0 = modulus.data[0]; + constexpr uint64_t modulus_1 = modulus.data[1]; + constexpr uint64_t modulus_2 = modulus.data[2]; + constexpr uint64_t modulus_3 = modulus.data[3]; + constexpr uint64_t zero_ref = 0; + /** + * Registers: rax:rdx = multiplication accumulator + * %r12, %r13, %r14, %r15, %rax: work registers for `r` + * %r8, %r9, %rdi, %rsi: scratch registers for multiplication results + * %r10: zero register + * %0: pointer to `a` + * %1: pointer to `b` + * %2: pointer to `r` + **/ + __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1") + STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&b), + [modulus_0] "m"(modulus_0), + [modulus_1] "m"(modulus_1), + [modulus_2] "m"(modulus_2), + [modulus_3] "m"(modulus_3), + [r_inv] "m"(r_inv), + [zero_reference] "m"(zero_ref) + : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +} + +template field field::asm_sqr_with_coarse_reduction(const field& a) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_sqr_with_coarse_reduction"); + + field r; + constexpr 
uint64_t r_inv = T::r_inv; + constexpr uint64_t modulus_0 = modulus.data[0]; + constexpr uint64_t modulus_1 = modulus.data[1]; + constexpr uint64_t modulus_2 = modulus.data[2]; + constexpr uint64_t modulus_3 = modulus.data[3]; + constexpr uint64_t zero_ref = 0; + +// Our SQR implementation with BMI2 but without ADX has a bug. +// The case is extremely rare so fixing it is a bit of a waste of time. +// We'll use MUL instead. +#if !defined(__ADX__) || defined(DISABLE_ADX) + /** + * Registers: rax:rdx = multiplication accumulator + * %r12, %r13, %r14, %r15, %rax: work registers for `r` + * %r8, %r9, %rdi, %rsi: scratch registers for multiplication results + * %r10: zero register + * %0: pointer to `a` + * %1: pointer to `b` + * %2: pointer to `r` + **/ + __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1") + STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "%r"(&a), + "%r"(&a), + "r"(&r), + [modulus_0] "m"(modulus_0), + [modulus_1] "m"(modulus_1), + [modulus_2] "m"(modulus_2), + [modulus_3] "m"(modulus_3), + [r_inv] "m"(r_inv), + [zero_reference] "m"(zero_ref) + : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); + +#else + + /** + * Registers: rax:rdx = multiplication accumulator + * %r12, %r13, %r14, %r15, %rax: work registers for `r` + * %r8, %r9, %rdi, %rsi: scratch registers for multiplication results + * %[zero_reference]: memory location of zero value + * %0: pointer to `a` + * %[r_ptr]: memory location of pointer to `r` + **/ + __asm__(SQR("%0") + // "movq %[r_ptr], %%rsi \n\t" + STORE_FIELD_ELEMENT("%1", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&r), + [zero_reference] "m"(zero_ref), + [modulus_0] "m"(modulus_0), + [modulus_1] "m"(modulus_1), + [modulus_2] "m"(modulus_2), + [modulus_3] "m"(modulus_3), + [r_inv] "m"(r_inv) + : "%rcx", "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +#endif + return r; +} + +template void field::asm_self_sqr_with_coarse_reduction(const field& a) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_self_sqr_with_coarse_reduction"); + + constexpr uint64_t r_inv = T::r_inv; + constexpr uint64_t modulus_0 = modulus.data[0]; + constexpr uint64_t modulus_1 = modulus.data[1]; + constexpr uint64_t modulus_2 = modulus.data[2]; + constexpr uint64_t modulus_3 = modulus.data[3]; + constexpr uint64_t zero_ref = 0; + +// Our SQR implementation with BMI2 but without ADX has a bug. +// The case is extremely rare so fixing it is a bit of a waste of time. +// We'll use MUL instead. 
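+// (Explanatory note, not in the original source.) ADX is the x86 extension that adds ADCX/ADOX,
+// which carry through CF and OF independently and therefore allow two interleaved carry chains;
+// the SQR macro presumably relies on this, so the non-ADX fallback below reuses the MUL path.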
+#if !defined(__ADX__) || defined(DISABLE_ADX) + /** + * Registers: rax:rdx = multiplication accumulator + * %r12, %r13, %r14, %r15, %rax: work registers for `r` + * %r8, %r9, %rdi, %rsi: scratch registers for multiplication results + * %r10: zero register + * %0: pointer to `a` + * %1: pointer to `b` + * %2: pointer to `r` + **/ + __asm__(MUL("0(%0)", "8(%0)", "16(%0)", "24(%0)", "%1") + STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&a), + [modulus_0] "m"(modulus_0), + [modulus_1] "m"(modulus_1), + [modulus_2] "m"(modulus_2), + [modulus_3] "m"(modulus_3), + [r_inv] "m"(r_inv), + [zero_reference] "m"(zero_ref) + : "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); + +#else + /** + * Registers: rax:rdx = multiplication accumulator + * %r12, %r13, %r14, %r15, %rax: work registers for `r` + * %r8, %r9, %rdi, %rsi: scratch registers for multiplication results + * %[zero_reference]: memory location of zero value + * %0: pointer to `a` + * %[r_ptr]: memory location of pointer to `r` + **/ + __asm__(SQR("%0") + // "movq %[r_ptr], %%rsi \n\t" + STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + [zero_reference] "m"(zero_ref), + [modulus_0] "m"(modulus_0), + [modulus_1] "m"(modulus_1), + [modulus_2] "m"(modulus_2), + [modulus_3] "m"(modulus_3), + [r_inv] "m"(r_inv) + : "%rcx", "%rdx", "%rdi", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +#endif +} + +template field field::asm_add_with_coarse_reduction(const field& a, const field& b) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_add_with_coarse_reduction"); + + field r; + + constexpr uint64_t twice_not_modulus_0 = twice_not_modulus.data[0]; + constexpr uint64_t twice_not_modulus_1 = twice_not_modulus.data[1]; + constexpr uint64_t twice_not_modulus_2 = twice_not_modulus.data[2]; + constexpr uint64_t twice_not_modulus_3 = twice_not_modulus.data[3]; + + __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + ADD_REDUCE("%1", + "%[twice_not_modulus_0]", + "%[twice_not_modulus_1]", + "%[twice_not_modulus_2]", + "%[twice_not_modulus_3]") STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "%r"(&a), + "%r"(&b), + "r"(&r), + [twice_not_modulus_0] "m"(twice_not_modulus_0), + [twice_not_modulus_1] "m"(twice_not_modulus_1), + [twice_not_modulus_2] "m"(twice_not_modulus_2), + [twice_not_modulus_3] "m"(twice_not_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); + return r; +} + +template void field::asm_self_add_with_coarse_reduction(const field& a, const field& b) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_self_add_with_coarse_reduction"); + + constexpr uint64_t twice_not_modulus_0 = twice_not_modulus.data[0]; + constexpr uint64_t twice_not_modulus_1 = twice_not_modulus.data[1]; + constexpr uint64_t twice_not_modulus_2 = twice_not_modulus.data[2]; + constexpr uint64_t twice_not_modulus_3 = twice_not_modulus.data[3]; + + __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + ADD_REDUCE("%1", + "%[twice_not_modulus_0]", + "%[twice_not_modulus_1]", + "%[twice_not_modulus_2]", + "%[twice_not_modulus_3]") STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&b), + [twice_not_modulus_0] "m"(twice_not_modulus_0), + [twice_not_modulus_1] "m"(twice_not_modulus_1), + [twice_not_modulus_2] "m"(twice_not_modulus_2), + [twice_not_modulus_3] 
"m"(twice_not_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +} + +template field field::asm_sub_with_coarse_reduction(const field& a, const field& b) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_sub_with_coarse_reduction"); + + field r; + + constexpr uint64_t twice_modulus_0 = twice_modulus.data[0]; + constexpr uint64_t twice_modulus_1 = twice_modulus.data[1]; + constexpr uint64_t twice_modulus_2 = twice_modulus.data[2]; + constexpr uint64_t twice_modulus_3 = twice_modulus.data[3]; + + __asm__( + CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") SUB("%1") + REDUCE_FIELD_ELEMENT("%[twice_modulus_0]", "%[twice_modulus_1]", "%[twice_modulus_2]", "%[twice_modulus_3]") + STORE_FIELD_ELEMENT("%2", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&b), + "r"(&r), + [twice_modulus_0] "m"(twice_modulus_0), + [twice_modulus_1] "m"(twice_modulus_1), + [twice_modulus_2] "m"(twice_modulus_2), + [twice_modulus_3] "m"(twice_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); + return r; +} + +template void field::asm_self_sub_with_coarse_reduction(const field& a, const field& b) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_self_sub_with_coarse_reduction"); + + constexpr uint64_t twice_modulus_0 = twice_modulus.data[0]; + constexpr uint64_t twice_modulus_1 = twice_modulus.data[1]; + constexpr uint64_t twice_modulus_2 = twice_modulus.data[2]; + constexpr uint64_t twice_modulus_3 = twice_modulus.data[3]; + + __asm__( + CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") SUB("%1") + REDUCE_FIELD_ELEMENT("%[twice_modulus_0]", "%[twice_modulus_1]", "%[twice_modulus_2]", "%[twice_modulus_3]") + STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&b), + [twice_modulus_0] "m"(twice_modulus_0), + [twice_modulus_1] "m"(twice_modulus_1), + [twice_modulus_2] "m"(twice_modulus_2), + [twice_modulus_3] "m"(twice_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +} + +template void field::asm_conditional_negate(field& r, const uint64_t predicate) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_conditional_negate"); + + constexpr uint64_t twice_modulus_0 = twice_modulus.data[0]; + constexpr uint64_t twice_modulus_1 = twice_modulus.data[1]; + constexpr uint64_t twice_modulus_2 = twice_modulus.data[2]; + constexpr uint64_t twice_modulus_3 = twice_modulus.data[3]; + + __asm__(CLEAR_FLAGS("%%r8") LOAD_FIELD_ELEMENT( + "%1", "%%r8", "%%r9", "%%r10", "%%r11") "movq %[twice_modulus_0], %%r12 \n\t" + "movq %[twice_modulus_1], %%r13 \n\t" + "movq %[twice_modulus_2], %%r14 \n\t" + "movq %[twice_modulus_3], %%r15 \n\t" + "subq %%r8, %%r12 \n\t" + "sbbq %%r9, %%r13 \n\t" + "sbbq %%r10, %%r14 \n\t" + "sbbq %%r11, %%r15 \n\t" + "testq %0, %0 \n\t" + "cmovnzq %%r12, %%r8 \n\t" + "cmovnzq %%r13, %%r9 \n\t" + "cmovnzq %%r14, %%r10 \n\t" + "cmovnzq %%r15, %%r11 \n\t" STORE_FIELD_ELEMENT( + "%1", "%%r8", "%%r9", "%%r10", "%%r11") + : + : "r"(predicate), + "r"(&r), + [twice_modulus_0] "i"(twice_modulus_0), + [twice_modulus_1] "i"(twice_modulus_1), + [twice_modulus_2] "i"(twice_modulus_2), + [twice_modulus_3] "i"(twice_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +} + +template field field::asm_reduce_once(const field& a) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_reduce_once"); + + field r; + + constexpr uint64_t not_modulus_0 = not_modulus.data[0]; + 
constexpr uint64_t not_modulus_1 = not_modulus.data[1]; + constexpr uint64_t not_modulus_2 = not_modulus.data[2]; + constexpr uint64_t not_modulus_3 = not_modulus.data[3]; + + __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + REDUCE_FIELD_ELEMENT("%[not_modulus_0]", "%[not_modulus_1]", "%[not_modulus_2]", "%[not_modulus_3]") + STORE_FIELD_ELEMENT("%1", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + "r"(&r), + [not_modulus_0] "m"(not_modulus_0), + [not_modulus_1] "m"(not_modulus_1), + [not_modulus_2] "m"(not_modulus_2), + [not_modulus_3] "m"(not_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); + return r; +} + +template void field::asm_self_reduce_once(const field& a) noexcept +{ + BB_OP_COUNT_TRACK_NAME("fr::asm_self_reduce_once"); + + constexpr uint64_t not_modulus_0 = not_modulus.data[0]; + constexpr uint64_t not_modulus_1 = not_modulus.data[1]; + constexpr uint64_t not_modulus_2 = not_modulus.data[2]; + constexpr uint64_t not_modulus_3 = not_modulus.data[3]; + + __asm__(CLEAR_FLAGS("%%r12") LOAD_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + REDUCE_FIELD_ELEMENT("%[not_modulus_0]", "%[not_modulus_1]", "%[not_modulus_2]", "%[not_modulus_3]") + STORE_FIELD_ELEMENT("%0", "%%r12", "%%r13", "%%r14", "%%r15") + : + : "r"(&a), + [not_modulus_0] "m"(not_modulus_0), + [not_modulus_1] "m"(not_modulus_1), + [not_modulus_2] "m"(not_modulus_2), + [not_modulus_3] "m"(not_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "cc", "memory"); +} +} // namespace bb +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/affine_element.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/affine_element.hpp new file mode 100644 index 0000000..500ab2a --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/affine_element.hpp @@ -0,0 +1,192 @@ +#pragma once +#include "../../common/serialize.hpp" +#include "../../ecc/curves/bn254/fq2.hpp" +#include "../../numeric/uint256/uint256.hpp" +#include +#include +#include + +namespace bb::group_elements { +template +concept SupportsHashToCurve = T::can_hash_to_curve; +template class alignas(64) affine_element { + public: + using Fq = Fq_; + using Fr = Fr_; + + using in_buf = const uint8_t*; + using vec_in_buf = const uint8_t*; + using out_buf = uint8_t*; + using vec_out_buf = uint8_t**; + + affine_element() noexcept = default; + ~affine_element() noexcept = default; + + constexpr affine_element(const Fq& x, const Fq& y) noexcept; + + constexpr affine_element(const affine_element& other) noexcept = default; + + constexpr affine_element(affine_element&& other) noexcept = default; + + static constexpr affine_element one() noexcept { return { Params::one_x, Params::one_y }; }; + + /** + * @brief Reconstruct a point in affine coordinates from compressed form. + * @details #LARGE_MODULUS_AFFINE_POINT_COMPRESSION Point compression is only implemented for curves of a prime + * field F_p with p using < 256 bits. One possiblity for extending to a 256-bit prime field: + * https://patents.google.com/patent/US6252960B1/en. + * + * @param compressed compressed point + * @return constexpr affine_element + */ + template > 255) == uint256_t(0), void>> + static constexpr affine_element from_compressed(const uint256_t& compressed) noexcept; + + /** + * @brief Reconstruct a point in affine coordinates from compressed form. 
+ * @details #LARGE_MODULUS_AFFINE_POINT_COMPRESSION Point compression is implemented for curves of a prime + * field F_p with p being 256 bits. + * TODO(Suyash): Check with kesha if this is correct. + * + * @param compressed compressed point + * @return constexpr affine_element + */ + template > 255) == uint256_t(1), void>> + static constexpr std::array from_compressed_unsafe(const uint256_t& compressed) noexcept; + + constexpr affine_element& operator=(const affine_element& other) noexcept = default; + + constexpr affine_element& operator=(affine_element&& other) noexcept = default; + + constexpr affine_element operator+(const affine_element& other) const noexcept; + + template > 255) == uint256_t(0), void>> + [[nodiscard]] constexpr uint256_t compress() const noexcept; + + static affine_element infinity(); + constexpr affine_element set_infinity() const noexcept; + constexpr void self_set_infinity() noexcept; + + [[nodiscard]] constexpr bool is_point_at_infinity() const noexcept; + + [[nodiscard]] constexpr bool on_curve() const noexcept; + + static constexpr std::optional derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept; + + /** + * @brief Samples a random point on the curve. + * + * @return A randomly chosen point on the curve + */ + static affine_element random_element(numeric::RNG* engine = nullptr) noexcept; + static constexpr affine_element hash_to_curve(const std::vector& seed, uint8_t attempt_count = 0) noexcept + requires SupportsHashToCurve; + + constexpr bool operator==(const affine_element& other) const noexcept; + + constexpr affine_element operator-() const noexcept { return { x, -y }; } + + constexpr bool operator>(const affine_element& other) const noexcept; + constexpr bool operator<(const affine_element& other) const noexcept { return (other > *this); } + + /** + * @brief Serialize the point to the given buffer + * + * @details We support serializing the point at infinity for curves defined over a bb::field (i.e., a + * native field of prime order) and for points of bb::g2. + * + * @warning This will need to be updated if we serialize points over composite-order fields other than fq2! + * + */ + static void serialize_to_buffer(const affine_element& value, uint8_t* buffer, bool write_x_first = false) + { + using namespace serialize; + if (value.is_point_at_infinity()) { + // if we are infinity, just set all buffer bits to 1 + // we only need this case because the below gets mangled converting from montgomery for infinity points + memset(buffer, 255, sizeof(Fq) * 2); + } else { + // Note: for historic reasons we will need to redo downstream hashes if we want this to always be written in + // the same order in our various serialization flows + write(buffer, write_x_first ? value.x : value.y); + write(buffer, write_x_first ? value.y : value.x); + } + } + + /** + * @brief Restore point from a buffer + * + * @param buffer Buffer from which we deserialize the point + * + * @return Deserialized point + * + * @details We support serializing the point at infinity for curves defined over a bb::field (i.e., a + * native field of prime order) and for points of bb::g2. + * + * @warning This will need to be updated if we serialize points over composite-order fields other than fq2! + */ + static affine_element serialize_from_buffer(const uint8_t* buffer, bool write_x_first = false) + { + using namespace serialize; + // Does the buffer consist entirely of set bits? If so, we have a point at infinity + // Note that if it isn't, this loop should end early. 
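+ // (Round-trip sketch, illustrative rather than part of this routine — for an instantiation such as
+ //  bb::g1 the pair of helpers is expected to satisfy
+ //      uint8_t buf[sizeof(Fq) * 2];
+ //      serialize_to_buffer(affine_element::infinity(), buf);
+ //      affine_element p = serialize_from_buffer(buf);
+ //      ASSERT(p.is_point_at_infinity());
+ //  i.e. the all-ones sentinel written by serialize_to_buffer above survives the read below.)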
+ // We only need this case because the below gets mangled converting to montgomery for infinity points + bool is_point_at_infinity = + std::all_of(buffer, buffer + sizeof(Fq) * 2, [](uint8_t val) { return val == 255; }); + if (is_point_at_infinity) { + return affine_element::infinity(); + } + affine_element result; + // Note: for historic reasons we will need to redo downstream hashes if we want this to always be read in the + // same order in our various serialization flows + read(buffer, write_x_first ? result.x : result.y); + read(buffer, write_x_first ? result.y : result.x); + return result; + } + + /** + * @brief Serialize the point to a byte vector + * + * @return Vector with serialized representation of the point + */ + [[nodiscard]] inline std::vector to_buffer() const + { + std::vector buffer(sizeof(affine_element)); + affine_element::serialize_to_buffer(*this, &buffer[0]); + return buffer; + } + + friend std::ostream& operator<<(std::ostream& os, const affine_element& a) + { + os << "{ " << a.x << ", " << a.y << " }"; + return os; + } + Fq x; + Fq y; +}; + +template +inline void read(B& it, group_elements::affine_element& element) +{ + using namespace serialize; + std::array buffer; + read(it, buffer); + element = group_elements::affine_element::serialize_from_buffer( + buffer.data(), /* use legacy field order */ true); +} + +template +inline void write(B& it, group_elements::affine_element const& element) +{ + using namespace serialize; + std::array buffer; + group_elements::affine_element::serialize_to_buffer( + element, buffer.data(), /* use legacy field order */ true); + write(it, buffer); +} +} // namespace bb::group_elements + +#include "./affine_element_impl.hpp" diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/affine_element_impl.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/affine_element_impl.hpp new file mode 100644 index 0000000..4ed6fd4 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/affine_element_impl.hpp @@ -0,0 +1,290 @@ +#pragma once +#include "./element.hpp" +#include "../../crypto/blake3s/blake3s.hpp" +#include "../../crypto/keccak/keccak.hpp" + +namespace bb::group_elements { +template +constexpr affine_element::affine_element(const Fq& x, const Fq& y) noexcept + : x(x) + , y(y) +{} + +template +template +constexpr affine_element affine_element::from_compressed(const uint256_t& compressed) noexcept +{ + uint256_t x_coordinate = compressed; + x_coordinate.data[3] = x_coordinate.data[3] & (~0x8000000000000000ULL); + bool y_bit = compressed.get_bit(255); + + Fq x = Fq(x_coordinate); + Fq y2 = (x.sqr() * x + T::b); + if constexpr (T::has_a) { + y2 += (x * T::a); + } + auto [is_quadratic_remainder, y] = y2.sqrt(); + if (!is_quadratic_remainder) { + return affine_element(Fq::zero(), Fq::zero()); + } + if (uint256_t(y).get_bit(0) != y_bit) { + y = -y; + } + + return affine_element(x, y); +} + +template +template +constexpr std::array, 2> affine_element::from_compressed_unsafe( + const uint256_t& compressed) noexcept +{ + auto get_y_coordinate = [](const uint256_t& x_coordinate) { + Fq x = Fq(x_coordinate); + Fq y2 = (x.sqr() * x + T::b); + if constexpr (T::has_a) { + y2 += (x * T::a); + } + return y2.sqrt(); + }; + + uint256_t x_1 = compressed; + uint256_t x_2 = compressed + Fr::modulus; + auto [is_quadratic_remainder_1, y_1] = get_y_coordinate(x_1); + auto [is_quadratic_remainder_2, y_2] = get_y_coordinate(x_2); + + auto output_1 = is_quadratic_remainder_1 ? 
affine_element(Fq(x_1), y_1) + : affine_element(Fq::zero(), Fq::zero()); + auto output_2 = is_quadratic_remainder_2 ? affine_element(Fq(x_2), y_2) + : affine_element(Fq::zero(), Fq::zero()); + + return { output_1, output_2 }; +} + +template +constexpr affine_element affine_element::operator+( + const affine_element& other) const noexcept +{ + return affine_element(element(*this) + element(other)); +} + +template +template + +constexpr uint256_t affine_element::compress() const noexcept +{ + uint256_t out(x); + if (uint256_t(y).get_bit(0)) { + out.data[3] = out.data[3] | 0x8000000000000000ULL; + } + return out; +} + +template affine_element affine_element::infinity() +{ + affine_element e; + e.self_set_infinity(); + return e; +} + +template +constexpr affine_element affine_element::set_infinity() const noexcept +{ + affine_element result(*this); + result.self_set_infinity(); + return result; +} + +template constexpr void affine_element::self_set_infinity() noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + // We set the value of x equal to modulus to represent inifinty + x.data[0] = Fq::modulus.data[0]; + x.data[1] = Fq::modulus.data[1]; + x.data[2] = Fq::modulus.data[2]; + x.data[3] = Fq::modulus.data[3]; + + } else { + x.self_set_msb(); + } +} + +template constexpr bool affine_element::is_point_at_infinity() const noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + // We check if the value of x is equal to modulus to represent inifinty + return ((x.data[0] ^ Fq::modulus.data[0]) | (x.data[1] ^ Fq::modulus.data[1]) | + (x.data[2] ^ Fq::modulus.data[2]) | (x.data[3] ^ Fq::modulus.data[3])) == 0; + + } else { + return (x.is_msb_set()); + } +} + +template constexpr bool affine_element::on_curve() const noexcept +{ + if (is_point_at_infinity()) { + return true; + } + Fq xxx = x.sqr() * x + T::b; + Fq yy = y.sqr(); + if constexpr (T::has_a) { + xxx += (x * T::a); + } + return (xxx == yy); +} + +template +constexpr bool affine_element::operator==(const affine_element& other) const noexcept +{ + bool this_is_infinity = is_point_at_infinity(); + bool other_is_infinity = other.is_point_at_infinity(); + bool both_infinity = this_is_infinity && other_is_infinity; + bool only_one_is_infinity = this_is_infinity != other_is_infinity; + return !only_one_is_infinity && (both_infinity || ((x == other.x) && (y == other.y))); +} + +/** + * Comparison operators (for std::sort) + * + * @details CAUTION!! Don't use this operator. It has no meaning other than for use by std::sort. + **/ +template +constexpr bool affine_element::operator>(const affine_element& other) const noexcept +{ + // We are setting point at infinity to always be the lowest element + if (is_point_at_infinity()) { + return false; + } + if (other.is_point_at_infinity()) { + return true; + } + + if (x > other.x) { + return true; + } + if (x == other.x && y > other.y) { + return true; + } + return false; +} + +template +constexpr std::optional> affine_element::derive_from_x_coordinate( + const Fq& x, bool sign_bit) noexcept +{ + auto yy = x.sqr() * x + T::b; + if constexpr (T::has_a) { + yy += (x * T::a); + } + auto [found_root, y] = yy.sqrt(); + + if (found_root) { + if (uint256_t(y).get_bit(0) != sign_bit) { + y = -y; + } + return affine_element(x, y); + } + return std::nullopt; +} + +/** + * @brief Hash a seed buffer into a point + * + * @details ALGORITHM DESCRIPTION: + * 1. Initialize unsigned integer `attempt_count = 0` + * 2. 
Copy seed into a buffer whose size is 2 bytes greater than `seed` (initialized to 0) + * 3. Interpret `attempt_count` as a byte and write into buffer at [buffer.size() - 2] + * 4. Compute Blake3s hash of buffer + * 5. Set the end byte of the buffer to `1` + * 6. Compute Blake3s hash of buffer + * 7. Interpret the two hash outputs as the high / low 256 bits of a 512-bit integer (big-endian) + * 8. Derive x-coordinate of point by reducing the 512-bit integer modulo the curve's field modulus (Fq) + * 9. Compute y^2 from the curve formula y^2 = x^3 + ax + b (a, b are curve params. for BN254, a = 0, b = 3) + * 10. IF y^2 IS NOT A QUADRATIC RESIDUE + * 10a. increment `attempt_count` by 1 and go to step 2 + * 11. IF y^2 IS A QUADRATIC RESIDUE + * 11a. derive y coordinate via y = sqrt(y) + * 11b. Interpret most significant bit of 512-bit integer as a 'parity' bit + * 11c. If parity bit is set AND y's most significant bit is not set, invert y + * 11d. If parity bit is not set AND y's most significant bit is set, invert y + * N.B. last 2 steps are because the sqrt() algorithm can return 2 values, + * we need to a way to canonically distinguish between these 2 values and select a "preferred" one + * 11e. return (x, y) + * + * @note This algorihm is constexpr: we can hash-to-curve (and derive generators) at compile-time! + * @tparam Fq + * @tparam Fr + * @tparam T + * @param seed Bytes that uniquely define the point being generated + * @param attempt_count + * @return constexpr affine_element + */ +template +constexpr affine_element affine_element::hash_to_curve(const std::vector& seed, + uint8_t attempt_count) noexcept + requires SupportsHashToCurve + +{ + std::vector target_seed(seed); + // expand by 2 bytes to cover incremental hash attempts + const size_t seed_size = seed.size(); + for (size_t i = 0; i < 2; ++i) { + target_seed.push_back(0); + } + target_seed[seed_size] = attempt_count; + target_seed[seed_size + 1] = 0; + const auto hash_hi = blake3::blake3s_constexpr(&target_seed[0], target_seed.size()); + target_seed[seed_size + 1] = 1; + const auto hash_lo = blake3::blake3s_constexpr(&target_seed[0], target_seed.size()); + // custom serialize methods as common/serialize.hpp is not constexpr! + const auto read_uint256 = [](const uint8_t* in) { + const auto read_limb = [](const uint8_t* in, uint64_t& out) { + for (size_t i = 0; i < 8; ++i) { + out += static_cast(in[i]) << ((7 - i) * 8); + } + }; + uint256_t out = 0; + read_limb(&in[0], out.data[3]); + read_limb(&in[8], out.data[2]); + read_limb(&in[16], out.data[1]); + read_limb(&in[24], out.data[0]); + return out; + }; + // interpret 64 byte hash output as a uint512_t, reduce to Fq element + //(512 bits of entropy ensures result is not biased as 512 >> Fq::modulus.get_msb()) + Fq x(uint512_t(read_uint256(&hash_lo[0]), read_uint256(&hash_hi[0]))); + bool sign_bit = hash_hi[0] > 127; + std::optional result = derive_from_x_coordinate(x, sign_bit); + if (result.has_value()) { + return result.value(); + } + return hash_to_curve(seed, attempt_count + 1); +} + +template +affine_element affine_element::random_element(numeric::RNG* engine) noexcept +{ + if (engine == nullptr) { + engine = &numeric::get_randomness(); + } + + Fq x; + Fq y; + while (true) { + // Sample a random x-coordinate and check if it satisfies curve equation. + x = Fq::random_element(engine); + // Negate the y-coordinate based on a randomly sampled bit. 
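+ // (Aside, a usage sketch of the helper that drives both this loop and hash_to_curve above —
+ //  illustrative only: for any candidate x-coordinate one may attempt
+ //      std::optional<affine_element> p = derive_from_x_coordinate(x, /*sign_bit=*/false);
+ //      if (p.has_value()) { ASSERT(p->on_curve()); }
+ //  roughly half of all x values admit a square root of x^3 + ax + b, hence the retry loops.)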
+ bool sign_bit = (engine->get_random_uint8() & 1) != 0; + + std::optional result = derive_from_x_coordinate(x, sign_bit); + + if (result.has_value()) { + return result.value(); + } + } + throw_or_abort("affine_element::random_element error"); + return affine_element(x, y); +} + +} // namespace bb::group_elements diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/element.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/element.hpp new file mode 100644 index 0000000..a60090a --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/element.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include "affine_element.hpp" +#include "../../common/compiler_hints.hpp" +#include "../../common/mem.hpp" +#include "../../numeric/random/engine.hpp" +#include "../../numeric/uint256/uint256.hpp" +#include "wnaf.hpp" +#include +#include +#include + +namespace bb::group_elements { + +/** + * @brief element class. Implements ecc group arithmetic using Jacobian coordinates + * See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l + * + * Note: Currently subgroup checks are NOT IMPLEMENTED + * Our current Plonk implementation uses G1 points that have a cofactor of 1. + * All G2 points are precomputed (generator [1]_2 and trusted setup point [x]_2). + * Explicitly assume precomputed points are valid members of the prime-order subgroup for G2. + * @tparam Fq prime field the curve is defined over + * @tparam Fr prime field whose characteristic equals the size of the prime-order elliptic curve subgroup + * @tparam Params curve parameters + */ +template class alignas(32) element { + public: + static constexpr Fq curve_b = Params::b; + + element() noexcept = default; + + constexpr element(const Fq& a, const Fq& b, const Fq& c) noexcept; + constexpr element(const element& other) noexcept; + constexpr element(element&& other) noexcept; + constexpr element(const affine_element& other) noexcept; + constexpr ~element() noexcept = default; + + static constexpr element one() noexcept { return { Params::one_x, Params::one_y, Fq::one() }; }; + static constexpr element zero() noexcept + { + element zero; + zero.self_set_infinity(); + return zero; + }; + + constexpr element& operator=(const element& other) noexcept; + constexpr element& operator=(element&& other) noexcept; + + constexpr operator affine_element() const noexcept; + + static element random_element(numeric::RNG* engine = nullptr) noexcept; + + constexpr element dbl() const noexcept; + constexpr void self_dbl() noexcept; + constexpr void self_mixed_add_or_sub(const affine_element& other, uint64_t predicate) noexcept; + + constexpr element operator+(const element& other) const noexcept; + constexpr element operator+(const affine_element& other) const noexcept; + constexpr element operator+=(const element& other) noexcept; + constexpr element operator+=(const affine_element& other) noexcept; + + constexpr element operator-(const element& other) const noexcept; + constexpr element operator-(const affine_element& other) const noexcept; + constexpr element operator-() const noexcept; + constexpr element operator-=(const element& other) noexcept; + constexpr element operator-=(const affine_element& other) noexcept; + + friend constexpr element operator+(const affine_element& left, const element& right) noexcept + { + return right + left; + } + friend constexpr element operator-(const affine_element& left, const element& right) noexcept + { + return -right + left; + } + + element operator*(const Fr& exponent) const 
noexcept; + element operator*=(const Fr& exponent) noexcept; + + // If you end up implementing this, congrats, you've solved the DL problem! + // P.S. This is a joke, don't even attempt! 😂 + // constexpr Fr operator/(const element& other) noexcept {} + + constexpr element normalize() const noexcept; + static element infinity(); + BB_INLINE constexpr element set_infinity() const noexcept; + BB_INLINE constexpr void self_set_infinity() noexcept; + [[nodiscard]] BB_INLINE constexpr bool is_point_at_infinity() const noexcept; + [[nodiscard]] BB_INLINE constexpr bool on_curve() const noexcept; + BB_INLINE constexpr bool operator==(const element& other) const noexcept; + + static void batch_normalize(element* elements, size_t num_elements) noexcept; + static void batch_affine_add(const std::span>& first_group, + const std::span>& second_group, + const std::span>& results) noexcept; + static std::vector> batch_mul_with_endomorphism( + const std::span>& points, const Fr& scalar) noexcept; + + Fq x; + Fq y; + Fq z; + + private: + // For test access to mul_without_endomorphism + friend class TestElementPrivate; + element mul_without_endomorphism(const Fr& scalar) const noexcept; + element mul_with_endomorphism(const Fr& scalar) const noexcept; + + template > + static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept; + // { + // bool found_one = false; + // Fq yy; + // Fq x; + // Fq y; + // Fq t0; + // while (!found_one) { + // x = Fq::random_element(engine); + // yy = x.sqr() * x + Params::b; + // if constexpr (Params::has_a) { + // yy += (x * Params::a); + // } + // y = yy.sqrt(); + // t0 = y.sqr(); + // found_one = (yy == t0); + // } + // return { x, y, Fq::one() }; + // } + // for serialization: update with new fields + // TODO(https://github.com/AztecProtocol/barretenberg/issues/908) point at inifinty isn't handled + + static void conditional_negate_affine(const affine_element& in, + affine_element& out, + uint64_t predicate) noexcept; + + friend std::ostream& operator<<(std::ostream& os, const element& a) + { + os << "{ " << a.x << ", " << a.y << ", " << a.z << " }"; + return os; + } +}; + +template std::ostream& operator<<(std::ostream& os, element const& e) +{ + return os << "x:" << e.x << " y:" << e.y << " z:" << e.z; +} + +// constexpr element::one = element{ Params::one_x, Params::one_y, Fq::one() }; +// constexpr element::point_at_infinity = one.set_infinity(); +// constexpr element::curve_b = Params::b; +} // namespace bb::group_elements + +#include "./element_impl.hpp" + +template +bb::group_elements::affine_element operator*( + const bb::group_elements::affine_element& base, const Fr& exponent) noexcept +{ + return bb::group_elements::affine_element(bb::group_elements::element(base) * exponent); +} + +template +bb::group_elements::affine_element operator*(const bb::group_elements::element& base, + const Fr& exponent) noexcept +{ + return (bb::group_elements::element(base) * exponent); +} diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/element_impl.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/element_impl.hpp new file mode 100644 index 0000000..d85d5ed --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/element_impl.hpp @@ -0,0 +1,1219 @@ +#pragma once +#include "../../common/op_count.hpp" +#include "../../common/thread.hpp" +#include "./element.hpp" +#include "element.hpp" +#include + +// NOLINTBEGIN(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays) +namespace bb::group_elements { +template 
+constexpr element::element(const Fq& a, const Fq& b, const Fq& c) noexcept + : x(a) + , y(b) + , z(c) +{} + +template +constexpr element::element(const element& other) noexcept + : x(other.x) + , y(other.y) + , z(other.z) +{} + +template +constexpr element::element(element&& other) noexcept + : x(other.x) + , y(other.y) + , z(other.z) +{} + +template +constexpr element::element(const affine_element& other) noexcept + : x(other.x) + , y(other.y) + , z(Fq::one()) +{} + +template +constexpr element& element::operator=(const element& other) noexcept +{ + if (this == &other) { + return *this; + } + x = other.x; + y = other.y; + z = other.z; + return *this; +} + +template +constexpr element& element::operator=(element&& other) noexcept +{ + x = other.x; + y = other.y; + z = other.z; + return *this; +} + +template constexpr element::operator affine_element() const noexcept +{ + if (is_point_at_infinity()) { + affine_element result; + result.x = Fq(0); + result.y = Fq(0); + result.self_set_infinity(); + return result; + } + Fq z_inv = z.invert(); + Fq zz_inv = z_inv.sqr(); + Fq zzz_inv = zz_inv * z_inv; + affine_element result(x * zz_inv, y * zzz_inv); + if (is_point_at_infinity()) { + result.self_set_infinity(); + } + return result; +} + +template constexpr void element::self_dbl() noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + if (is_point_at_infinity()) { + return; + } + } else { + if (x.is_msb_set_word()) { + return; + } + } + + // T0 = x*x + Fq T0 = x.sqr(); + + // T1 = y*y + Fq T1 = y.sqr(); + + // T2 = T2*T1 = y*y*y*y + Fq T2 = T1.sqr(); + + // T1 = T1 + x = x + y*y + T1 += x; + + // T1 = T1 * T1 + T1.self_sqr(); + + // T3 = T0 + T2 = xx + y*y*y*y + Fq T3 = T0 + T2; + + // T1 = T1 - T3 = x*x + y*y*y*y + 2*x*x*y*y*y*y - x*x - y*y*y*y = 2*x*x*y*y*y*y = 2*S + T1 -= T3; + + // T1 = 2T1 = 4*S + T1 += T1; + + // T3 = 3T0 + T3 = T0 + T0; + T3 += T0; + if constexpr (T::has_a) { + T3 += (T::a * z.sqr().sqr()); + } + + // z2 = 2*y*z + z += z; + z *= y; + + // T0 = 2T1 + T0 = T1 + T1; + + // x2 = T3*T3 + x = T3.sqr(); + + // x2 = x2 - 2T1 + x -= T0; + + // T2 = 8T2 + T2 += T2; + T2 += T2; + T2 += T2; + + // y2 = T1 - x2 + y = T1 - x; + + // y2 = y2 * T3 - T2 + y *= T3; + y -= T2; +} + +template constexpr element element::dbl() const noexcept +{ + element result(*this); + result.self_dbl(); + return result; +} + +template +constexpr void element::self_mixed_add_or_sub(const affine_element& other, + const uint64_t predicate) noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + if (is_point_at_infinity()) { + conditional_negate_affine(other, *(affine_element*)this, predicate); // NOLINT + z = Fq::one(); + return; + } + } else { + const bool edge_case_trigger = x.is_msb_set() || other.x.is_msb_set(); + if (edge_case_trigger) { + if (x.is_msb_set()) { + conditional_negate_affine(other, *(affine_element*)this, predicate); // NOLINT + z = Fq::one(); + } + return; + } + } + + // T0 = z1.z1 + Fq T0 = z.sqr(); + + // T1 = x2.t0 - x1 = x2.z1.z1 - x1 + Fq T1 = other.x * T0; + T1 -= x; + + // T2 = T0.z1 = z1.z1.z1 + // T2 = T2.y2 - y1 = y2.z1.z1.z1 - y1 + Fq T2 = z * T0; + T2 *= other.y; + T2.self_conditional_negate(predicate); + T2 -= y; + + if (__builtin_expect(T1.is_zero(), 0)) { + if (T2.is_zero()) { + // y2 equals y1, x2 equals x1, double x1 + self_dbl(); + return; + } + self_set_infinity(); + return; + } + + // T2 = 2T2 = 2(y2.z1.z1.z1 - y1) = R + // z3 = z1 + H + T2 += T2; + z += T1; + + // T3 = T1*T1 = HH + Fq T3 = T1.sqr(); + + // z3 = z3 - z1z1 - 
HH + T0 += T3; + + // z3 = (z1 + H)*(z1 + H) + z.self_sqr(); + z -= T0; + + // T3 = 4HH + T3 += T3; + T3 += T3; + + // T1 = T1*T3 = 4HHH + T1 *= T3; + + // T3 = T3 * x1 = 4HH*x1 + T3 *= x; + + // T0 = 2T3 + T0 = T3 + T3; + + // T0 = T0 + T1 = 2(4HH*x1) + 4HHH + T0 += T1; + x = T2.sqr(); + + // x3 = x3 - T0 = R*R - 8HH*x1 -4HHH + x -= T0; + + // T3 = T3 - x3 = 4HH*x1 - x3 + T3 -= x; + + T1 *= y; + T1 += T1; + + // T3 = T2 * T3 = R*(4HH*x1 - x3) + T3 *= T2; + + // y3 = T3 - T1 + y = T3 - T1; +} + +template +constexpr element element::operator+=(const affine_element& other) noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + if (is_point_at_infinity()) { + *this = { other.x, other.y, Fq::one() }; + return *this; + } + } else { + const bool edge_case_trigger = x.is_msb_set() || other.x.is_msb_set(); + if (edge_case_trigger) { + if (x.is_msb_set()) { + *this = { other.x, other.y, Fq::one() }; + } + return *this; + } + } + + // T0 = z1.z1 + Fq T0 = z.sqr(); + + // T1 = x2.t0 - x1 = x2.z1.z1 - x1 + Fq T1 = other.x * T0; + T1 -= x; + + // T2 = T0.z1 = z1.z1.z1 + // T2 = T2.y2 - y1 = y2.z1.z1.z1 - y1 + Fq T2 = z * T0; + T2 *= other.y; + T2 -= y; + + if (__builtin_expect(T1.is_zero(), 0)) { + if (T2.is_zero()) { + self_dbl(); + return *this; + } + self_set_infinity(); + return *this; + } + + // T2 = 2T2 = 2(y2.z1.z1.z1 - y1) = R + // z3 = z1 + H + T2 += T2; + z += T1; + + // T3 = T1*T1 = HH + Fq T3 = T1.sqr(); + + // z3 = z3 - z1z1 - HH + T0 += T3; + + // z3 = (z1 + H)*(z1 + H) + z.self_sqr(); + z -= T0; + + // T3 = 4HH + T3 += T3; + T3 += T3; + + // T1 = T1*T3 = 4HHH + T1 *= T3; + + // T3 = T3 * x1 = 4HH*x1 + T3 *= x; + + // T0 = 2T3 + T0 = T3 + T3; + + // T0 = T0 + T1 = 2(4HH*x1) + 4HHH + T0 += T1; + x = T2.sqr(); + + // x3 = x3 - T0 = R*R - 8HH*x1 -4HHH + x -= T0; + + // T3 = T3 - x3 = 4HH*x1 - x3 + T3 -= x; + + T1 *= y; + T1 += T1; + + // T3 = T2 * T3 = R*(4HH*x1 - x3) + T3 *= T2; + + // y3 = T3 - T1 + y = T3 - T1; + return *this; +} + +template +constexpr element element::operator+(const affine_element& other) const noexcept +{ + element result(*this); + return (result += other); +} + +template +constexpr element element::operator-=(const affine_element& other) noexcept +{ + const affine_element to_add{ other.x, -other.y }; + return operator+=(to_add); +} + +template +constexpr element element::operator-(const affine_element& other) const noexcept +{ + element result(*this); + return (result -= other); +} + +template +constexpr element element::operator+=(const element& other) noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + bool p1_zero = is_point_at_infinity(); + bool p2_zero = other.is_point_at_infinity(); + if (__builtin_expect((p1_zero || p2_zero), 0)) { + if (p1_zero && !p2_zero) { + *this = other; + return *this; + } + if (p2_zero && !p1_zero) { + return *this; + } + self_set_infinity(); + return *this; + } + } else { + bool p1_zero = x.is_msb_set(); + bool p2_zero = other.x.is_msb_set(); + if (__builtin_expect((p1_zero || p2_zero), 0)) { + if (p1_zero && !p2_zero) { + *this = other; + return *this; + } + if (p2_zero && !p1_zero) { + return *this; + } + self_set_infinity(); + return *this; + } + } + Fq Z1Z1(z.sqr()); + Fq Z2Z2(other.z.sqr()); + Fq S2(Z1Z1 * z); + Fq U2(Z1Z1 * other.x); + S2 *= other.y; + Fq U1(Z2Z2 * x); + Fq S1(Z2Z2 * other.z); + S1 *= y; + + Fq F(S2 - S1); + + Fq H(U2 - U1); + + if (__builtin_expect(H.is_zero(), 0)) { + if (F.is_zero()) { + self_dbl(); + return *this; + } + self_set_infinity(); + return *this; + } 
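+ /* (Annotation in the variable names used above, a correspondence sketch rather than extra logic:
+  *  with affine coordinates x_i = X_i / Z_i^2 and y_i = Y_i / Z_i^3 we have
+  *      U1 = x1 * Z2^2,  U2 = x2 * Z1^2,   S1 = y1 * Z2^3,  S2 = y2 * Z1^3,
+  *      H  = U2 - U1  ~  (x2 - x1),        F  = S2 - S1  ~  (y2 - y1)  (doubled just below),
+  *  so H == 0 with F != 0 means the inputs are inverses of each other (the sum is infinity), and
+  *  H == 0 with F == 0 means the inputs are equal, which is why the branch above doubles.) */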
+ + F += F; + + Fq I(H + H); + I.self_sqr(); + + Fq J(H * I); + + U1 *= I; + + U2 = U1 + U1; + U2 += J; + + x = F.sqr(); + + x -= U2; + + J *= S1; + J += J; + + y = U1 - x; + + y *= F; + + y -= J; + + z += other.z; + + Z1Z1 += Z2Z2; + + z.self_sqr(); + z -= Z1Z1; + z *= H; + return *this; +} + +template +constexpr element element::operator+(const element& other) const noexcept +{ + BB_OP_COUNT_TRACK_NAME("element::operator+"); + element result(*this); + return (result += other); +} + +template +constexpr element element::operator-=(const element& other) noexcept +{ + const element to_add{ other.x, -other.y, other.z }; + return operator+=(to_add); +} + +template +constexpr element element::operator-(const element& other) const noexcept +{ + BB_OP_COUNT_TRACK(); + element result(*this); + return (result -= other); +} + +template constexpr element element::operator-() const noexcept +{ + return { x, -y, z }; +} + +template +element element::operator*(const Fr& exponent) const noexcept +{ + if constexpr (T::USE_ENDOMORPHISM) { + return mul_with_endomorphism(exponent); + } + return mul_without_endomorphism(exponent); +} + +template element element::operator*=(const Fr& exponent) noexcept +{ + *this = operator*(exponent); + return *this; +} + +template constexpr element element::normalize() const noexcept +{ + const affine_element converted = *this; + return element(converted); +} + +template element element::infinity() +{ + element e{}; + e.self_set_infinity(); + return e; +} + +template constexpr element element::set_infinity() const noexcept +{ + element result(*this); + result.self_set_infinity(); + return result; +} + +template constexpr void element::self_set_infinity() noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + // We set the value of x equal to modulus to represent inifinty + x.data[0] = Fq::modulus.data[0]; + x.data[1] = Fq::modulus.data[1]; + x.data[2] = Fq::modulus.data[2]; + x.data[3] = Fq::modulus.data[3]; + } else { + x.self_set_msb(); + } +} + +template constexpr bool element::is_point_at_infinity() const noexcept +{ + if constexpr (Fq::modulus.data[3] >= 0x4000000000000000ULL) { + // We check if the value of x is equal to modulus to represent inifinty + return ((x.data[0] ^ Fq::modulus.data[0]) | (x.data[1] ^ Fq::modulus.data[1]) | + (x.data[2] ^ Fq::modulus.data[2]) | (x.data[3] ^ Fq::modulus.data[3])) == 0; + } else { + return (x.is_msb_set()); + } +} + +template constexpr bool element::on_curve() const noexcept +{ + if (is_point_at_infinity()) { + return true; + } + // We specify the point at inifinity not by (0 \lambda 0), so z should not be 0 + if (z.is_zero()) { + return false; + } + Fq zz = z.sqr(); + Fq zzzz = zz.sqr(); + Fq bz_6 = zzzz * zz * T::b; + if constexpr (T::has_a) { + bz_6 += (x * T::a) * zzzz; + } + Fq xxx = x.sqr() * x + bz_6; + Fq yy = y.sqr(); + return (xxx == yy); +} + +template +constexpr bool element::operator==(const element& other) const noexcept +{ + // If one of points is not on curve, we have no business comparing them. + if ((!on_curve()) || (!other.on_curve())) { + return false; + } + bool am_infinity = is_point_at_infinity(); + bool is_infinity = other.is_point_at_infinity(); + bool both_infinity = am_infinity && is_infinity; + // If just one is infinity, then they are obviously not equal. 
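+ // (Worked form of the comparison below: affine coordinates are X/Z^2 and Y/Z^3, so
+ //      X1/Z1^2 == X2/Z2^2  <=>  X1*Z2^2 == X2*Z1^2   and
+ //      Y1/Z1^3 == Y2/Z2^3  <=>  Y1*Z2^3 == Y2*Z1^3,
+ //  which is exactly what lhs_x == rhs_x and lhs_y == rhs_y test, avoiding any field inversion.)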
+ if ((!both_infinity) && (am_infinity || is_infinity)) { + return false; + } + const Fq lhs_zz = z.sqr(); + const Fq lhs_zzz = lhs_zz * z; + const Fq rhs_zz = other.z.sqr(); + const Fq rhs_zzz = rhs_zz * other.z; + + const Fq lhs_x = x * rhs_zz; + const Fq lhs_y = y * rhs_zzz; + + const Fq rhs_x = other.x * lhs_zz; + const Fq rhs_y = other.y * lhs_zzz; + return both_infinity || ((lhs_x == rhs_x) && (lhs_y == rhs_y)); +} + +template +element element::random_element(numeric::RNG* engine) noexcept +{ + if constexpr (T::can_hash_to_curve) { + element result = random_coordinates_on_curve(engine); + result.z = Fq::random_element(engine); + Fq zz = result.z.sqr(); + Fq zzz = zz * result.z; + result.x *= zz; + result.y *= zzz; + return result; + } else { + Fr scalar = Fr::random_element(engine); + return (element{ T::one_x, T::one_y, Fq::one() } * scalar); + } +} + +template +element element::mul_without_endomorphism(const Fr& scalar) const noexcept +{ + const uint256_t converted_scalar(scalar); + + if (converted_scalar == 0) { + return element::infinity(); + } + + element accumulator(*this); + const uint64_t maximum_set_bit = converted_scalar.get_msb(); + // This is simpler and doublings of infinity should be fast. We should think if we want to defend against the + // timing leak here (if used with ECDSA it can sometimes lead to private key compromise) + for (uint64_t i = maximum_set_bit - 1; i < maximum_set_bit; --i) { + accumulator.self_dbl(); + if (converted_scalar.get_bit(i)) { + accumulator += *this; + } + } + return accumulator; +} + +namespace detail { +// Represents the result of +using EndoScalars = std::pair, std::array>; + +/** + * @brief Handles the WNAF computation for scalars that are split using an endomorphism, + * achieved through `split_into_endomorphism_scalars`. It facilitates efficient computation of elliptic curve + * point multiplication by optimizing the representation of these scalars. + * + * @tparam Element The data type of elements in the elliptic curve. + * @tparam NUM_ROUNDS The number of computation rounds for WNAF. + */ +template struct EndomorphismWnaf { + // NUM_WNAF_BITS: Number of bits per window in the WNAF representation. + static constexpr size_t NUM_WNAF_BITS = 4; + // table: Stores the WNAF representation of the scalars. + std::array table; + // skew and endo_skew: Indicate if our original scalar is even or odd. + bool skew = false; + bool endo_skew = false; + + /** + * @param scalars A pair of 128-bit scalars (as two uint64_t arrays), split using an endomorphism. 
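+ *
+ * (Call-pattern sketch — see mul_with_endomorphism below for the real use; illustrative only:
+ *      detail::EndoScalars split = Fr::split_into_endomorphism_scalars(converted_scalar);
+ *      detail::EndomorphismWnaf wnaf{ split };   // instantiated with the element type and round count
+ *  after which `table` holds the interleaved 4-bit window entries of the two half-scalars, and
+ *  `skew` / `endo_skew` record whether each half-scalar needed the odd/even parity adjustment.)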
+ */ + EndomorphismWnaf(const EndoScalars& scalars) + { + wnaf::fixed_wnaf(&scalars.first[0], &table[0], skew, 0, 2, NUM_WNAF_BITS); + wnaf::fixed_wnaf(&scalars.second[0], &table[1], endo_skew, 0, 2, NUM_WNAF_BITS); + } +}; + +} // namespace detail + +template +element element::mul_with_endomorphism(const Fr& scalar) const noexcept +{ + // Consider the infinity flag, return infinity if set + if (is_point_at_infinity()) { + return element::infinity(); + } + constexpr size_t NUM_ROUNDS = 32; + const Fr converted_scalar = scalar.from_montgomery_form(); + + if (converted_scalar.is_zero()) { + return element::infinity(); + } + static constexpr size_t LOOKUP_SIZE = 8; + std::array lookup_table; + + element d2 = dbl(); + lookup_table[0] = element(*this); + for (size_t i = 1; i < LOOKUP_SIZE; ++i) { + lookup_table[i] = lookup_table[i - 1] + d2; + } + + detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar); + detail::EndomorphismWnaf wnaf{ endo_scalars }; + element accumulator{ T::one_x, T::one_y, Fq::one() }; + accumulator.self_set_infinity(); + Fq beta = Fq::cube_root_of_unity(); + + for (size_t i = 0; i < NUM_ROUNDS * 2; ++i) { + uint64_t wnaf_entry = wnaf.table[i]; + uint64_t index = wnaf_entry & 0x0fffffffU; + bool sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((i & 1) == 1); + auto to_add = lookup_table[static_cast(index)]; + to_add.y.self_conditional_negate(sign ^ is_odd); + if (is_odd) { + to_add.x *= beta; + } + accumulator += to_add; + + if (i != ((2 * NUM_ROUNDS) - 1) && is_odd) { + for (size_t j = 0; j < 4; ++j) { + accumulator.self_dbl(); + } + } + } + + if (wnaf.skew) { + accumulator += -lookup_table[0]; + } + if (wnaf.endo_skew) { + accumulator += element{ lookup_table[0].x * beta, lookup_table[0].y, lookup_table[0].z }; + } + + return accumulator; +} + +/** + * @brief Pairwise affine add points in first and second group + * + * @param first_group + * @param second_group + * @param results + */ +template +void element::batch_affine_add(const std::span>& first_group, + const std::span>& second_group, + const std::span>& results) noexcept +{ + typedef affine_element affine_element; + const size_t num_points = first_group.size(); + ASSERT(second_group.size() == first_group.size()); + + // Space for temporary values + std::vector scratch_space(num_points); + + run_loop_in_parallel_if_effective( + num_points, + [&results, &first_group](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + results[i] = first_group[i]; + } + }, + /*finite_field_additions_per_iteration=*/0, + /*finite_field_multiplications_per_iteration=*/0, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/2); + + // TODO(#826): Same code as in batch mul + // we can mutate rhs but NOT lhs! 
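+ /* (Background for the chunked helper defined below — a minimal standalone sketch of Montgomery's
+  *  batch-inversion trick, assuming only the surrounding Fq type and <vector>; the helper name is
+  *  illustrative:
+  *
+  *      // Replaces each d[i] with 1/d[i] using a single field inversion.
+  *      static void batch_invert_sketch(Fq* d, size_t n)
+  *      {
+  *          std::vector<Fq> prefix(n);
+  *          Fq acc = Fq::one();
+  *          for (size_t i = 0; i < n; ++i) { prefix[i] = acc; acc *= d[i]; }
+  *          acc = acc.invert(); // the only inversion
+  *          for (size_t i = n; i-- > 0;) {
+  *              Fq inv = acc * prefix[i]; // equals 1 / d[i]
+  *              acc *= d[i];              // strip d[i] from the accumulator
+  *              d[i] = inv;
+  *          }
+  *      }
+  *
+  *  The addition lambda below interleaves this pattern with the chord-slope computation so that the
+  *  denominators (x2 - x1) are never inverted one at a time.) */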
+ // output is stored in rhs + /** + * @brief Perform point addition rhs[i]=rhs[i]+lhs[i] with batch inversion + * + */ + const auto batch_affine_add_chunked = + [](const affine_element* lhs, affine_element* rhs, const size_t point_count, Fq* personal_scratch_space) { + Fq batch_inversion_accumulator = Fq::one(); + + for (size_t i = 0; i < point_count; i += 1) { + personal_scratch_space[i] = lhs[i].x + rhs[i].x; // x2 + x1 + rhs[i].x -= lhs[i].x; // x2 - x1 + rhs[i].y -= lhs[i].y; // y2 - y1 + rhs[i].y *= batch_inversion_accumulator; // (y2 - y1)*accumulator_old + batch_inversion_accumulator *= (rhs[i].x); + } + batch_inversion_accumulator = batch_inversion_accumulator.invert(); + + for (size_t i = (point_count)-1; i < point_count; i -= 1) { + rhs[i].y *= batch_inversion_accumulator; // update accumulator + batch_inversion_accumulator *= rhs[i].x; + rhs[i].x = rhs[i].y.sqr(); + rhs[i].x = rhs[i].x - (personal_scratch_space[i]); // x3 = lambda_squared - x2 + // - x1 + personal_scratch_space[i] = lhs[i].x - rhs[i].x; + personal_scratch_space[i] *= rhs[i].y; + rhs[i].y = personal_scratch_space[i] - lhs[i].y; + } + }; + + /** + * @brief Perform batch affine addition in parallel + * + */ + const auto batch_affine_add_internal = + [num_points, &scratch_space, &batch_affine_add_chunked](const affine_element* lhs, affine_element* rhs) { + run_loop_in_parallel_if_effective( + num_points, + [lhs, &rhs, &scratch_space, &batch_affine_add_chunked](size_t start, size_t end) { + batch_affine_add_chunked(lhs + start, rhs + start, end - start, &scratch_space[0] + start); + }, + /*finite_field_additions_per_iteration=*/6, + /*finite_field_multiplications_per_iteration=*/6); + }; + batch_affine_add_internal(&second_group[0], &results[0]); +} + +/** + * @brief Multiply each point by the same scalar + * + * @details We use the fact that all points are being multiplied by the same scalar to batch the operations (perform + * batch affine additions and doublings with batch inversion trick) + * + * @param points The span of individual points that need to be scaled + * @param scalar The scalar we multiply all the points by + * @return std::vector> Vector of new points where each point is exponent⋅points[i] + */ +template +std::vector> element::batch_mul_with_endomorphism( + const std::span>& points, const Fr& scalar) noexcept +{ + BB_OP_COUNT_TIME(); + typedef affine_element affine_element; + const size_t num_points = points.size(); + + // Space for temporary values + std::vector scratch_space(num_points); + + // TODO(#826): Same code as in batch add + // we can mutate rhs but NOT lhs! 
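+ /* (Identity this routine relies on, stated as equations rather than new logic: for curves with
+  *  T::USE_ENDOMORPHISM there is a cube root of unity beta in Fq (Fq::cube_root_of_unity()) and a
+  *  matching lambda in Fr such that
+  *      lambda * (x, y) == (beta * x, y),
+  *  so a scalar split as k = k1 + lambda * k2, with k1 and k2 roughly half the bit length of k, gives
+  *      k * P == k1 * P + k2 * (beta * P.x, P.y).
+  *  The `to_add.x *= beta` steps in the wnaf loops below compute the (beta * P.x, P.y) half of this sum.) */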
+ // output is stored in rhs + /** + * @brief Perform point addition rhs[i]=rhs[i]+lhs[i] with batch inversion + * + */ + const auto batch_affine_add_chunked = + [](const affine_element* lhs, affine_element* rhs, const size_t point_count, Fq* personal_scratch_space) { + Fq batch_inversion_accumulator = Fq::one(); + + for (size_t i = 0; i < point_count; i += 1) { + personal_scratch_space[i] = lhs[i].x + rhs[i].x; // x2 + x1 + rhs[i].x -= lhs[i].x; // x2 - x1 + rhs[i].y -= lhs[i].y; // y2 - y1 + rhs[i].y *= batch_inversion_accumulator; // (y2 - y1)*accumulator_old + batch_inversion_accumulator *= (rhs[i].x); + } + batch_inversion_accumulator = batch_inversion_accumulator.invert(); + + for (size_t i = (point_count)-1; i < point_count; i -= 1) { + rhs[i].y *= batch_inversion_accumulator; // update accumulator + batch_inversion_accumulator *= rhs[i].x; + rhs[i].x = rhs[i].y.sqr(); + rhs[i].x = rhs[i].x - (personal_scratch_space[i]); // x3 = lambda_squared - x2 + // - x1 + personal_scratch_space[i] = lhs[i].x - rhs[i].x; + personal_scratch_space[i] *= rhs[i].y; + rhs[i].y = personal_scratch_space[i] - lhs[i].y; + } + }; + + /** + * @brief Perform batch affine addition in parallel + * + */ + const auto batch_affine_add_internal = + [num_points, &scratch_space, &batch_affine_add_chunked](const affine_element* lhs, affine_element* rhs) { + run_loop_in_parallel_if_effective( + num_points, + [lhs, &rhs, &scratch_space, &batch_affine_add_chunked](size_t start, size_t end) { + batch_affine_add_chunked(lhs + start, rhs + start, end - start, &scratch_space[0] + start); + }, + /*finite_field_additions_per_iteration=*/6, + /*finite_field_multiplications_per_iteration=*/6); + }; + + /** + * @brief Perform point doubling lhs[i]=lhs[i]+lhs[i] with batch inversion + * + */ + const auto batch_affine_double_chunked = + [](affine_element* lhs, const size_t point_count, Fq* personal_scratch_space) { + Fq batch_inversion_accumulator = Fq::one(); + + for (size_t i = 0; i < point_count; i += 1) { + + personal_scratch_space[i] = lhs[i].x.sqr(); + personal_scratch_space[i] = + personal_scratch_space[i] + personal_scratch_space[i] + personal_scratch_space[i]; + + personal_scratch_space[i] *= batch_inversion_accumulator; + + batch_inversion_accumulator *= (lhs[i].y + lhs[i].y); + } + batch_inversion_accumulator = batch_inversion_accumulator.invert(); + + Fq temp; + for (size_t i = (point_count)-1; i < point_count; i -= 1) { + + personal_scratch_space[i] *= batch_inversion_accumulator; + batch_inversion_accumulator *= (lhs[i].y + lhs[i].y); + + temp = lhs[i].x; + lhs[i].x = personal_scratch_space[i].sqr() - (lhs[i].x + lhs[i].x); + lhs[i].y = personal_scratch_space[i] * (temp - lhs[i].x) - lhs[i].y; + } + }; + /** + * @brief Perform point doubling in parallel + * + */ + const auto batch_affine_double = [num_points, &scratch_space, &batch_affine_double_chunked](affine_element* lhs) { + run_loop_in_parallel_if_effective( + num_points, + [&lhs, &scratch_space, &batch_affine_double_chunked](size_t start, size_t end) { + batch_affine_double_chunked(lhs + start, end - start, &scratch_space[0] + start); + }, + /*finite_field_additions_per_iteration=*/7, + /*finite_field_multiplications_per_iteration=*/6); + }; + + // We compute the resulting point through WNAF by evaluating (the (\sum_i (16ⁱ⋅ + // (a_i ∈ {-15,-13,-11,-9,-7,-5,-3,-1,1,3,5,7,9,11,13,15}))) - skew), where skew is 0 or 1. The result of the sum is + // always odd and skew is used to reconstruct an even scalar. 
This means that to construct scalar p-1, where p is
+ // the order of the scalar field, we first compute p through the sums and then subtract 1. However, since we are
+ // computing p⋅Point, we get a point at infinity, which is an edge case, and we don't want to handle edge cases in the
+ // hot loop since they slow the computation down. So it's better to just handle it here.
+ if (scalar == -Fr::one()) {
+
+ std::vector<affine_element> results(num_points);
+ run_loop_in_parallel_if_effective(
+ num_points,
+ [&results, &points](size_t start, size_t end) {
+ for (size_t i = start; i < end; ++i) {
+ results[i] = -points[i];
+ }
+ },
+ /*finite_field_additions_per_iteration=*/0,
+ /*finite_field_multiplications_per_iteration=*/0,
+ /*finite_field_inversions_per_iteration=*/0,
+ /*group_element_additions_per_iteration=*/0,
+ /*group_element_doublings_per_iteration=*/0,
+ /*scalar_multiplications_per_iteration=*/0,
+ /*sequential_copy_ops_per_iteration=*/1);
+ return results;
+ }
+ // Compute wnaf for scalar
+ const Fr converted_scalar = scalar.from_montgomery_form();
+
+ // If the scalar is zero, just set results to the point at infinity
+ if (converted_scalar.is_zero()) {
+ affine_element result{ Fq::zero(), Fq::zero() };
+ result.self_set_infinity();
+ std::vector<affine_element> results(num_points);
+ run_loop_in_parallel_if_effective(
+ num_points,
+ [&results, result](size_t start, size_t end) {
+ for (size_t i = start; i < end; ++i) {
+ results[i] = result;
+ }
+ },
+ /*finite_field_additions_per_iteration=*/0,
+ /*finite_field_multiplications_per_iteration=*/0,
+ /*finite_field_inversions_per_iteration=*/0,
+ /*group_element_additions_per_iteration=*/0,
+ /*group_element_doublings_per_iteration=*/0,
+ /*scalar_multiplications_per_iteration=*/0,
+ /*sequential_copy_ops_per_iteration=*/1);
+ return results;
+ }
+
+ constexpr size_t LOOKUP_SIZE = 8;
+ constexpr size_t NUM_ROUNDS = 32;
+ std::array<std::vector<affine_element>, LOOKUP_SIZE> lookup_table;
+ for (auto& table : lookup_table) {
+ table.resize(num_points);
+ }
+ // Initialize first entries in lookup table
+ std::vector<affine_element> temp_point_vector(num_points);
+ run_loop_in_parallel_if_effective(
+ num_points,
+ [&temp_point_vector, &lookup_table, &points](size_t start, size_t end) {
+ for (size_t i = start; i < end; ++i) {
+ // If the point is at infinity we fix up the result later
+ // To avoid 'trying to invert zero in the field' we set the point to 'one' here
+ temp_point_vector[i] = points[i].is_point_at_infinity() ? affine_element::one() : points[i];
+ lookup_table[0][i] = points[i].is_point_at_infinity() ?
affine_element::one() : points[i]; + } + }, + /*finite_field_additions_per_iteration=*/0, + /*finite_field_multiplications_per_iteration=*/0, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/2); + + // Construct lookup table + batch_affine_double(&temp_point_vector[0]); + for (size_t j = 1; j < LOOKUP_SIZE; ++j) { + run_loop_in_parallel_if_effective( + num_points, + [j, &lookup_table](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + lookup_table[j][i] = lookup_table[j - 1][i]; + } + }, + /*finite_field_additions_per_iteration=*/0, + /*finite_field_multiplications_per_iteration=*/0, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/1); + batch_affine_add_internal(&temp_point_vector[0], &lookup_table[j][0]); + } + + detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar); + detail::EndomorphismWnaf wnaf{ endo_scalars }; + + std::vector work_elements(num_points); + + constexpr Fq beta = Fq::cube_root_of_unity(); + uint64_t wnaf_entry = 0; + uint64_t index = 0; + bool sign = 0; + // Prepare elements for the first batch addition + for (size_t j = 0; j < 2; ++j) { + wnaf_entry = wnaf.table[j]; + index = wnaf_entry & 0x0fffffffU; + sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((j & 1) == 1); + run_loop_in_parallel_if_effective( + num_points, + [j, index, is_odd, sign, beta, &lookup_table, &work_elements, &temp_point_vector](size_t start, + size_t end) { + for (size_t i = start; i < end; ++i) { + + auto to_add = lookup_table[static_cast(index)][i]; + to_add.y.self_conditional_negate(sign ^ is_odd); + if (is_odd) { + to_add.x *= beta; + } + if (j == 0) { + work_elements[i] = to_add; + } else { + temp_point_vector[i] = to_add; + } + } + }, + /*finite_field_additions_per_iteration=*/1, + /*finite_field_multiplications_per_iteration=*/is_odd ? 1 : 0, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/1); + } + // First cycle of addition + batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); + // Run through SM logic in wnaf form (excluding the skew) + for (size_t j = 2; j < NUM_ROUNDS * 2; ++j) { + wnaf_entry = wnaf.table[j]; + index = wnaf_entry & 0x0fffffffU; + sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((j & 1) == 1); + if (!is_odd) { + for (size_t k = 0; k < 4; ++k) { + batch_affine_double(&work_elements[0]); + } + } + run_loop_in_parallel_if_effective( + num_points, + [index, is_odd, sign, beta, &lookup_table, &temp_point_vector](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + + auto to_add = lookup_table[static_cast(index)][i]; + to_add.y.self_conditional_negate(sign ^ is_odd); + if (is_odd) { + to_add.x *= beta; + } + temp_point_vector[i] = to_add; + } + }, + /*finite_field_additions_per_iteration=*/1, + /*finite_field_multiplications_per_iteration=*/is_odd ? 
1 : 0, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/1); + batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); + } + + // Apply skew for the first endo scalar + if (wnaf.skew) { + run_loop_in_parallel_if_effective( + num_points, + [&lookup_table, &temp_point_vector](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + + temp_point_vector[i] = -lookup_table[0][i]; + } + }, + /*finite_field_additions_per_iteration=*/0, + /*finite_field_multiplications_per_iteration=*/0, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/1); + batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); + } + // Apply skew for the second endo scalar + if (wnaf.endo_skew) { + run_loop_in_parallel_if_effective( + num_points, + [beta, &lookup_table, &temp_point_vector](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + temp_point_vector[i] = lookup_table[0][i]; + temp_point_vector[i].x *= beta; + } + }, + /*finite_field_additions_per_iteration=*/0, + /*finite_field_multiplications_per_iteration=*/1, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/1); + batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); + } + // handle points at infinity explicitly + run_loop_in_parallel_if_effective( + num_points, + [&](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + work_elements[i] = + points[i].is_point_at_infinity() ? work_elements[i].set_infinity() : work_elements[i]; + } + }, + /*finite_field_additions_per_iteration=*/0, + /*finite_field_multiplications_per_iteration=*/1, + /*finite_field_inversions_per_iteration=*/0, + /*group_element_additions_per_iteration=*/0, + /*group_element_doublings_per_iteration=*/0, + /*scalar_multiplications_per_iteration=*/0, + /*sequential_copy_ops_per_iteration=*/1); + + return work_elements; +} + +template +void element::conditional_negate_affine(const affine_element& in, + affine_element& out, + const uint64_t predicate) noexcept +{ + out = { in.x, predicate ? -in.y : in.y }; +} + +template +void element::batch_normalize(element* elements, const size_t num_elements) noexcept +{ + std::vector temporaries; + temporaries.reserve(num_elements * 2); + Fq accumulator = Fq::one(); + + // Iterate over the points, computing the product of their z-coordinates. + // At each iteration, store the currently-accumulated z-coordinate in `temporaries` + for (size_t i = 0; i < num_elements; ++i) { + temporaries.emplace_back(accumulator); + if (!elements[i].is_point_at_infinity()) { + accumulator *= elements[i].z; + } + } + // For the rest of this method we refer to the product of all z-coordinates as the 'global' z-coordinate + // Invert the global z-coordinate and store in `accumulator` + accumulator = accumulator.invert(); + + /** + * We now proceed to iterate back down the array of points. + * At each iteration we update the accumulator to contain the z-coordinate of the currently worked-upon + *z-coordinate. 
We can then multiply this accumulator with `temporaries`, to get a scalar that is equal to the + *inverse of the z-coordinate of the point at the next iteration cycle e.g. Imagine we have 4 points, such that: + * + * accumulator = 1 / z.data[0]*z.data[1]*z.data[2]*z.data[3] + * temporaries[3] = z.data[0]*z.data[1]*z.data[2] + * temporaries[2] = z.data[0]*z.data[1] + * temporaries[1] = z.data[0] + * temporaries[0] = 1 + * + * At the first iteration, accumulator * temporaries[3] = z.data[0]*z.data[1]*z.data[2] / + *z.data[0]*z.data[1]*z.data[2]*z.data[3] = (1 / z.data[3]) We then update accumulator, such that: + * + * accumulator = accumulator * z.data[3] = 1 / z.data[0]*z.data[1]*z.data[2] + * + * At the second iteration, accumulator * temporaries[2] = z.data[0]*z.data[1] / z.data[0]*z.data[1]*z.data[2] = + *(1 z.data[2]) And so on, until we have computed every z-inverse! + * + * We can then convert out of Jacobian form (x = X / Z^2, y = Y / Z^3) with 4 muls and 1 square. + **/ + for (size_t i = num_elements - 1; i < num_elements; --i) { + if (!elements[i].is_point_at_infinity()) { + Fq z_inv = accumulator * temporaries[i]; + Fq zz_inv = z_inv.sqr(); + elements[i].x *= zz_inv; + elements[i].y *= (zz_inv * z_inv); + accumulator *= elements[i].z; + } + elements[i].z = Fq::one(); + } +} + +template +template +element element::random_coordinates_on_curve(numeric::RNG* engine) noexcept +{ + bool found_one = false; + Fq yy; + Fq x; + Fq y; + while (!found_one) { + x = Fq::random_element(engine); + yy = x.sqr() * x + T::b; + if constexpr (T::has_a) { + yy += (x * T::a); + } + auto [found_root, y1] = yy.sqrt(); + y = y1; + found_one = found_root; + } + return { x, y, Fq::one() }; +} + +} // namespace bb::group_elements +// NOLINTEND(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays) diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group.hpp new file mode 100644 index 0000000..b660fdc --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include "../../common/assert.hpp" +#include "./affine_element.hpp" +#include "./element.hpp" +#include "./wnaf.hpp" +#include "../../common/constexpr_utils.hpp" +#include "../../crypto/blake3s/blake3s.hpp" +#include +#include +#include +#include +#include +namespace bb { + +/** + * @brief group class. Represents an elliptic curve group element. + * Group is parametrised by coordinate_field and subgroup_field + * + * Note: Currently subgroup checks are NOT IMPLEMENTED + * Our current Plonk implementation uses G1 points that have a cofactor of 1. + * All G2 points are precomputed (generator [1]_2 and trusted setup point [x]_2). + * Explicitly assume precomputed points are valid members of the prime-order subgroup for G2. 
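+ *
+ * (Usage sketch, illustrative only — `Fq254`, `Fr254` and `Bn254G1Params` are placeholder names for
+ *  whichever concrete field types and parameter struct the library instantiates this template with:
+ *      using G1 = bb::group<Fq254, Fr254, Bn254G1Params>;
+ *      G1::affine_element p = G1::affine_one;      // fixed generator
+ *      G1::element acc = G1::one + G1::one;        // Jacobian arithmetic
+ *      G1::affine_element q = acc;                 // convert back to affine coordinates
+ *  )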
+ * + * @tparam coordinate_field + * @tparam subgroup_field + * @tparam GroupParams + */ +template class group { + public: + // hoist coordinate_field, subgroup_field into the public namespace + using coordinate_field = _coordinate_field; + using subgroup_field = _subgroup_field; + using element = group_elements::element; + using affine_element = group_elements::affine_element; + using Fq = coordinate_field; + using Fr = subgroup_field; + static constexpr bool USE_ENDOMORPHISM = GroupParams::USE_ENDOMORPHISM; + static constexpr bool has_a = GroupParams::has_a; + + static constexpr element one{ GroupParams::one_x, GroupParams::one_y, coordinate_field::one() }; + static constexpr element point_at_infinity = one.set_infinity(); + static constexpr affine_element affine_one{ GroupParams::one_x, GroupParams::one_y }; + static constexpr affine_element affine_point_at_infinity = affine_one.set_infinity(); + static constexpr coordinate_field curve_a = GroupParams::a; + static constexpr coordinate_field curve_b = GroupParams::b; + + /** + * @brief Derives generator points via hash-to-curve + * + * ALGORITHM DESCRIPTION: + * 1. Each generator has an associated "generator index" described by its location in the vector + * 2. a 64-byte preimage buffer is generated with the following structure: + * bytes 0-31: BLAKE3 hash of domain_separator + * bytes 32-63: generator index in big-endian form + * 3. The hash-to-curve algorithm is used to hash the above into a group element: + * a. iterate `count` upwards from `0` + * b. append `count` to the preimage buffer as a 1-byte integer in big-endian form + * c. compute BLAKE3 hash of concat(preimage buffer, 0) + * d. compute BLAKE3 hash of concat(preimage buffer, 1) + * e. interpret (c, d) as (hi, low) limbs of a 512-bit integer + * f. reduce 512-bit integer modulo coordinate_field to produce x-coordinate + * g. attempt to derive y-coordinate. If not successful go to step (a) and continue + * h. if parity of y-coordinate's least significant bit does not match parity of most significant bit of + * (d), invert y-coordinate. + * j. return (x, y) + * + * NOTE: In step 3b it is sufficient to use 1 byte to store `count`. + * Step 3 has a 50% chance of returning, the probability of `count` exceeding 256 is 1 in 2^256 + * NOTE: The domain separator is included to ensure that it is possible to derive independent sets of + * index-addressable generators. + * NOTE: we produce 64 bytes of BLAKE3 output when producing x-coordinate field + * element, to ensure that x-coordinate is uniformly randomly distributed in the field. Using a 256-bit input adds + * significant bias when reducing modulo a ~256-bit coordinate_field + * NOTE: We ensure y-parity is linked to preimage + * hash because there is no canonical deterministic square root algorithm (i.e. 
if a field element has a square + * root, there are two of them and `field::sqrt` may return either one) + * @param num_generators + * @param domain_separator + * @return std::vector + */ + inline static constexpr std::vector derive_generators( + const std::vector& domain_separator_bytes, + const size_t num_generators, + const size_t starting_index = 0) + { + std::vector result; + const auto domain_hash = blake3::blake3s_constexpr(&domain_separator_bytes[0], domain_separator_bytes.size()); + std::vector generator_preimage; + generator_preimage.reserve(64); + std::copy(domain_hash.begin(), domain_hash.end(), std::back_inserter(generator_preimage)); + for (size_t i = 0; i < 32; ++i) { + generator_preimage.emplace_back(0); + } + for (size_t i = starting_index; i < starting_index + num_generators; ++i) { + auto generator_index = static_cast(i); + uint32_t mask = 0xff; + generator_preimage[32] = static_cast(generator_index >> 24); + generator_preimage[33] = static_cast((generator_index >> 16) & mask); + generator_preimage[34] = static_cast((generator_index >> 8) & mask); + generator_preimage[35] = static_cast(generator_index & mask); + result.push_back(affine_element::hash_to_curve(generator_preimage)); + } + return result; + } + + inline static constexpr std::vector derive_generators(const std::string_view& domain_separator, + const size_t num_generators, + const size_t starting_index = 0) + { + std::vector domain_bytes; + for (char i : domain_separator) { + domain_bytes.emplace_back(static_cast(i)); + } + return derive_generators(domain_bytes, num_generators, starting_index); + } + + BB_INLINE static void conditional_negate_affine(const affine_element* src, + affine_element* dest, + uint64_t predicate); +}; + +} // namespace bb + +#ifdef DISABLE_ASM +#include "group_impl_int128.tcc" +#else +#include "group_impl_asm.tcc" +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_asm.tcc b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_asm.tcc new file mode 100644 index 0000000..3ea790c --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_asm.tcc @@ -0,0 +1,162 @@ +#pragma once + +#ifndef DISABLE_ASM + +#include "barretenberg/ecc/groups/group.hpp" +#include + +namespace bb { +// copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries +// template +// inline void group::copy(const affine_element* src, affine_element* +// dest) +// { +// if constexpr (GroupParams::small_elements) { +// #if defined __AVX__ && defined USE_AVX +// ASSERT((((uintptr_t)src & 0x1f) == 0)); +// ASSERT((((uintptr_t)dest & 0x1f) == 0)); +// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t" +// "vmovdqa 32(%0), %%ymm1 \n\t" +// "vmovdqa %%ymm0, 0(%1) \n\t" +// "vmovdqa %%ymm1, 32(%1) \n\t" +// : +// : "r"(src), "r"(dest) +// : "%ymm0", "%ymm1", "memory"); +// #else +// *dest = *src; +// #endif +// } else { +// *dest = *src; +// } +// } + +// // copies src into dest. n.b. 
both src and dest must be aligned on 32 byte boundaries +// template +// inline void group::copy(const element* src, element* dest) +// { +// if constexpr (GroupParams::small_elements) { +// #if defined __AVX__ && defined USE_AVX +// ASSERT((((uintptr_t)src & 0x1f) == 0)); +// ASSERT((((uintptr_t)dest & 0x1f) == 0)); +// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t" +// "vmovdqa 32(%0), %%ymm1 \n\t" +// "vmovdqa 64(%0), %%ymm2 \n\t" +// "vmovdqa %%ymm0, 0(%1) \n\t" +// "vmovdqa %%ymm1, 32(%1) \n\t" +// "vmovdqa %%ymm2, 64(%1) \n\t" +// : +// : "r"(src), "r"(dest) +// : "%ymm0", "%ymm1", "%ymm2", "memory"); +// #else +// *dest = *src; +// #endif +// } else { +// *dest = src; +// } +// } + +// copies src into dest, inverting y-coordinate if 'predicate' is true +// n.b. requires src and dest to be aligned on 32 byte boundary +template +inline void group::conditional_negate_affine(const affine_element* src, + affine_element* dest, + uint64_t predicate) +{ + constexpr uint256_t twice_modulus = coordinate_field::modulus + coordinate_field::modulus; + + constexpr uint64_t twice_modulus_0 = twice_modulus.data[0]; + constexpr uint64_t twice_modulus_1 = twice_modulus.data[1]; + constexpr uint64_t twice_modulus_2 = twice_modulus.data[2]; + constexpr uint64_t twice_modulus_3 = twice_modulus.data[3]; + + if constexpr (GroupParams::small_elements) { +#if defined __AVX__ && defined USE_AVX + ASSERT((((uintptr_t)src & 0x1f) == 0)); + ASSERT((((uintptr_t)dest & 0x1f) == 0)); + __asm__ __volatile__("xorq %%r8, %%r8 \n\t" + "movq 32(%0), %%r8 \n\t" + "movq 40(%0), %%r9 \n\t" + "movq 48(%0), %%r10 \n\t" + "movq 56(%0), %%r11 \n\t" + "movq %[modulus_0], %%r12 \n\t" + "movq %[modulus_1], %%r13 \n\t" + "movq %[modulus_2], %%r14 \n\t" + "movq %[modulus_3], %%r15 \n\t" + "subq %%r8, %%r12 \n\t" + "sbbq %%r9, %%r13 \n\t" + "sbbq %%r10, %%r14 \n\t" + "sbbq %%r11, %%r15 \n\t" + "testq %2, %2 \n\t" + "cmovnzq %%r12, %%r8 \n\t" + "cmovnzq %%r13, %%r9 \n\t" + "cmovnzq %%r14, %%r10 \n\t" + "cmovnzq %%r15, %%r11 \n\t" + "vmovdqa 0(%0), %%ymm0 \n\t" + "vmovdqa %%ymm0, 0(%1) \n\t" + "movq %%r8, 32(%1) \n\t" + "movq %%r9, 40(%1) \n\t" + "movq %%r10, 48(%1) \n\t" + "movq %%r11, 56(%1) \n\t" + : + : "r"(src), + "r"(dest), + "r"(predicate), + [modulus_0] "i"(twice_modulus_0), + [modulus_1] "i"(twice_modulus_1), + [modulus_2] "i"(twice_modulus_2), + [modulus_3] "i"(twice_modulus_3) + : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "%ymm0", "memory", "cc"); +#else + __asm__ __volatile__("xorq %%r8, %%r8 \n\t" + "movq 32(%0), %%r8 \n\t" + "movq 40(%0), %%r9 \n\t" + "movq 48(%0), %%r10 \n\t" + "movq 56(%0), %%r11 \n\t" + "movq %[modulus_0], %%r12 \n\t" + "movq %[modulus_1], %%r13 \n\t" + "movq %[modulus_2], %%r14 \n\t" + "movq %[modulus_3], %%r15 \n\t" + "subq %%r8, %%r12 \n\t" + "sbbq %%r9, %%r13 \n\t" + "sbbq %%r10, %%r14 \n\t" + "sbbq %%r11, %%r15 \n\t" + "testq %2, %2 \n\t" + "cmovnzq %%r12, %%r8 \n\t" + "cmovnzq %%r13, %%r9 \n\t" + "cmovnzq %%r14, %%r10 \n\t" + "cmovnzq %%r15, %%r11 \n\t" + "movq 0(%0), %%r12 \n\t" + "movq 8(%0), %%r13 \n\t" + "movq 16(%0), %%r14 \n\t" + "movq 24(%0), %%r15 \n\t" + "movq %%r8, 32(%1) \n\t" + "movq %%r9, 40(%1) \n\t" + "movq %%r10, 48(%1) \n\t" + "movq %%r11, 56(%1) \n\t" + "movq %%r12, 0(%1) \n\t" + "movq %%r13, 8(%1) \n\t" + "movq %%r14, 16(%1) \n\t" + "movq %%r15, 24(%1) \n\t" + : + : "r"(src), + "r"(dest), + "r"(predicate), + [modulus_0] "i"(twice_modulus_0), + [modulus_1] "i"(twice_modulus_1), + [modulus_2] "i"(twice_modulus_2), + [modulus_3] "i"(twice_modulus_3) + : 
"%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "memory", "cc"); +#endif + } else { + if (predicate) { // NOLINT + coordinate_field::__copy(src->x, dest->x); + dest->y = -src->y; + } else { + copy_affine(*src, *dest); + } + } +} + +} // namespace bb + +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_int128.tcc b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_int128.tcc new file mode 100644 index 0000000..7f82c83 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/group_impl_int128.tcc @@ -0,0 +1,34 @@ +#pragma once + +#ifdef DISABLE_ASM + +#include "barretenberg/ecc/groups/group.hpp" +#include + +namespace bb { + +// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries +// template +// inline void group::copy(const affine_element* src, affine_element* +// dest) +// { +// *dest = *src; +// } + +// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries +// template +// inline void group::copy(const element* src, element* dest) +// { +// *dest = *src; +// } + +template +inline void group::conditional_negate_affine(const affine_element* src, + affine_element* dest, + uint64_t predicate) +{ + *dest = predicate ? -(*src) : (*src); +} +} // namespace bb + +#endif \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/ecc/groups/wnaf.hpp b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/wnaf.hpp new file mode 100644 index 0000000..462cc2b --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/ecc/groups/wnaf.hpp @@ -0,0 +1,513 @@ +#pragma once +#include "../../numeric/bitop/get_msb.hpp" +#include +#include + +// NOLINTBEGIN(readability-implicit-bool-conversion) +namespace bb::wnaf { +constexpr size_t SCALAR_BITS = 127; + +#define WNAF_SIZE(x) ((bb::wnaf::SCALAR_BITS + (x)-1) / (x)) // NOLINT(cppcoreguidelines-macro-usage) + +constexpr size_t get_optimal_bucket_width(const size_t num_points) +{ + if (num_points >= 14617149) { + return 21; + } + if (num_points >= 1139094) { + return 18; + } + // if (num_points >= 100000) + if (num_points >= 155975) { + return 15; + } + if (num_points >= 144834) + // if (num_points >= 100000) + { + return 14; + } + if (num_points >= 25067) { + return 12; + } + if (num_points >= 13926) { + return 11; + } + if (num_points >= 7659) { + return 10; + } + if (num_points >= 2436) { + return 9; + } + if (num_points >= 376) { + return 7; + } + if (num_points >= 231) { + return 6; + } + if (num_points >= 97) { + return 5; + } + if (num_points >= 35) { + return 4; + } + if (num_points >= 10) { + return 3; + } + if (num_points >= 2) { + return 2; + } + return 1; +} +constexpr size_t get_num_buckets(const size_t num_points) +{ + const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2); + return 1UL << bits_per_bucket; +} + +constexpr size_t get_num_rounds(const size_t num_points) +{ + const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2); + return WNAF_SIZE(bits_per_bucket + 1); +} + +template inline uint64_t get_wnaf_bits_const(const uint64_t* scalar) noexcept +{ + if constexpr (bits == 0) { + return 0ULL; + } else { + /** + * we want to take a 128 bit scalar and shift it down by (bit_position). + * We then wish to mask out `bits` number of bits. 
+ * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also + * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual + * bit position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - + * (bit_position & 63) + * + * So, step 1: + * get low limb and shift right by (bit_position & 63) + * get high limb and shift left by (64 - (bit_position & 63)) + * + */ + constexpr size_t lo_limb_idx = bit_position / 64; + constexpr size_t hi_limb_idx = (bit_position + bits - 1) / 64; + constexpr uint64_t lo_shift = bit_position & 63UL; + constexpr uint64_t bit_mask = (1UL << static_cast(bits)) - 1UL; + + uint64_t lo = (scalar[lo_limb_idx] >> lo_shift); + if constexpr (lo_limb_idx == hi_limb_idx) { + return lo & bit_mask; + } else { + constexpr uint64_t hi_shift = 64UL - (bit_position & 63UL); + uint64_t hi = ((scalar[hi_limb_idx] << (hi_shift))); + return (lo | hi) & bit_mask; + } + } +} + +inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const uint64_t bit_position) noexcept +{ + /** + * we want to take a 128 bit scalar and shift it down by (bit_position). + * We then wish to mask out `bits` number of bits. + * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also + * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual bit + * position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - (bit_position + * & 63) + * + * So, step 1: + * get low limb and shift right by (bit_position & 63) + * get high limb and shift left by (64 - (bit_position & 63)) + * + */ + const auto lo_limb_idx = static_cast(bit_position >> 6); + const auto hi_limb_idx = static_cast((bit_position + bits - 1) >> 6); + const uint64_t lo_shift = bit_position & 63UL; + const uint64_t bit_mask = (1UL << static_cast(bits)) - 1UL; + + const uint64_t lo = (scalar[lo_limb_idx] >> lo_shift); + const uint64_t hi_shift = bit_position ? 
64UL - (bit_position & 63UL) : 0; + const uint64_t hi = ((scalar[hi_limb_idx] << (hi_shift))); + const uint64_t hi_mask = bit_mask & (0ULL - (lo_limb_idx != hi_limb_idx)); + + return (lo & bit_mask) | (hi & hi_mask); +} + +inline void fixed_wnaf_packed( + const uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const uint64_t point_index, const size_t wnaf_bits) noexcept +{ + skew_map = ((scalar[0] & 1) == 0); + uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); + const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + + for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { + uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[(wnaf_entries - round_i)] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + previous = slice + predicate; + } + size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1)); + uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); + uint64_t predicate = ((slice & 1UL) == 0UL); + + wnaf[1] = ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); +} + +/** + * @brief Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication. + * + * WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn optimizes + * the number of point doublings in scalar multiplication, in turn aiding efficiency. + * + * @param scalar Pointer to 128-bit scalar for which WNAF is to be computed. + * @param wnaf Pointer to num_points+1 size array where the computed WNAF will be stored. + * @param skew_map Reference to a boolean variable which will be set based on the least significant bit of the scalar. + * @param point_index The index of the point being computed in the context of multiple point multiplication. + * @param num_points The number of points being computed in parallel. + * @param wnaf_bits The number of bits to use in each window of the WNAF representation. + */ +inline void fixed_wnaf(const uint64_t* scalar, + uint64_t* wnaf, + bool& skew_map, + const uint64_t point_index, + const uint64_t num_points, + const size_t wnaf_bits) noexcept +{ + skew_map = ((scalar[0] & 1) == 0); + uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); + const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + + for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { + uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[(wnaf_entries - round_i) * num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + previous = slice + predicate; + } + size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1)); + uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); + uint64_t predicate = ((slice & 1UL) == 0UL); + + wnaf[num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); +} + +/** + * Current flow... + * + * If a wnaf entry is even, we add +1 to it, and subtract 32 from the previous entry. 
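+ * (Editor's note: the two adjustments cancel out, e.g. with 5-bit windows adding +1 to window i adds
+ * 2^(5*i) to the represented scalar, while subtracting 32 from window i-1 removes 32 * 2^(5*(i-1)) =
+ * 2^(5*i), so the encoded value is unchanged.)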
+ * This works if the previous entry is odd. If we recursively apply this process, starting at the least significant + *window, this will always be the case. + * + * However, we want to skip over windows that are 0, which poses a problem. + * + * Scenario 1: even window followed by 0 window followed by any window 'x' + * + * We can't add 1 to the even window and subtract 32 from the 0 window, as we don't have a bucket that maps to -32 + * This means that we have to identify whether we are going to borrow 32 from 'x', requiring us to look at least 2 + *steps ahead + * + * Scenario 2: <0> <0> + * + * This problem proceeds indefinitely - if we have adjacent 0 windows, we do not know whether we need to track a + *borrow flag until we identify the next non-zero window + * + * Scenario 3: <0> + * + * This one works... + * + * Ok, so we should be a bit more limited with when we don't include window entries. + * The goal here is to identify short scalars, so we want to identify the most significant non-zero window + **/ +inline uint64_t get_num_scalar_bits(const uint64_t* scalar) +{ + const uint64_t msb_1 = numeric::get_msb(scalar[1]); + const uint64_t msb_0 = numeric::get_msb(scalar[0]); + + const uint64_t scalar_1_mask = (0ULL - (scalar[1] > 0)); + const uint64_t scalar_0_mask = (0ULL - (scalar[0] > 0)) & ~scalar_1_mask; + + const uint64_t msb = (scalar_1_mask & (msb_1 + 64)) | (scalar_0_mask & (msb_0)); + return msb; +} + +/** + * How to compute an x-bit wnaf slice? + * + * Iterate over number of slices in scalar. + * For each slice, if slice is even, ADD +1 to current slice and SUBTRACT 2^x from previous slice. + * (for 1st slice we instead add +1 and set the scalar's 'skew' value to 'true' (i.e. need to subtract 1 from it at the + * end of our scalar mul algo)) + * + * In *wnaf we store the following: + * 1. bits 0-30: ABSOLUTE value of wnaf (i.e. -3 goes to 3) + * 2. bit 31: 'predicate' bool (i.e. does the wnaf value need to be negated?) + * 3. bits 32-63: position in a point array that describes the elliptic curve point this wnaf slice is referencing + * + * N.B. IN OUR STDLIB ALGORITHMS THE SKEW VALUE REPRESENTS AN ADDITION NOT A SUBTRACTION (i.e. we add +1 at the end of + * the scalar mul algo we don't sub 1) (this is to eliminate situations which could produce the point at infinity as an + * output as our circuit logic cannot accommodate this edge case). + * + * Credits: Zac W. 
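+ *
+ * A usage sketch (editor's illustration, assuming a 5-bit window, a single point, and the 127-bit
+ * SCALAR_BITS defined above; buffer sizes follow the indexing used by this function):
+ *
+ *     constexpr size_t wnaf_bits = 5;
+ *     constexpr size_t max_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; // 26
+ *     uint64_t scalar[2] = { 0x0123456789abcdefULL, 0x1ULL };                   // low limb, high limb
+ *     uint64_t wnaf[max_entries] = {};
+ *     uint64_t round_counts[max_entries] = {};
+ *     bool skew = false;
+ *     fixed_wnaf_with_counts(scalar, wnaf, skew, round_counts, 0, 1, wnaf_bits); // point_index 0, num_points 1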
+ * + * @param scalar Pointer to the 128-bit non-montgomery scalar that is supposed to be transformed into wnaf + * @param wnaf Pointer to output array that needs to accommodate enough 64-bit WNAF entries + * @param skew_map Reference to output skew value, which if true shows that the point should be added once at the end of + * computation + * @param wnaf_round_counts Pointer to output array specifying the number of points participating in each round + * @param point_index The index of the point that should be multiplied by this scalar in the point array + * @param num_points Total points in the MSM (2*num_initial_points) + * + */ +inline void fixed_wnaf_with_counts(const uint64_t* scalar, + uint64_t* wnaf, + bool& skew_map, + uint64_t* wnaf_round_counts, + const uint64_t point_index, + const uint64_t num_points, + const size_t wnaf_bits) noexcept +{ + const size_t max_wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + if ((scalar[0] | scalar[1]) == 0ULL) { + skew_map = false; + for (size_t round_i = 0; round_i < max_wnaf_entries; ++round_i) { + wnaf[(round_i)*num_points] = 0xffffffffffffffffULL; + } + return; + } + const auto current_scalar_bits = static_cast(get_num_scalar_bits(scalar) + 1); + skew_map = ((scalar[0] & 1) == 0); + uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); + const auto wnaf_entries = static_cast((current_scalar_bits + wnaf_bits - 1) / wnaf_bits); + + if (wnaf_entries == 1) { + wnaf[(max_wnaf_entries - 1) * num_points] = (previous >> 1UL) | (point_index); + ++wnaf_round_counts[max_wnaf_entries - 1]; + for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) { + wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL; + } + return; + } + + // If there are several windows + for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { + + // Get a bit slice + uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + + // Get the predicate (last bit is zero) + uint64_t predicate = ((slice & 1UL) == 0UL); + + // Update round count + ++wnaf_round_counts[max_wnaf_entries - round_i]; + + // Calculate entry value + // If the last bit of current slice is 1, we simply put the previous value with the point index + // If the last bit of the current slice is 0, we negate everything, so that we subtract from the WNAF form and + // make it 0 + wnaf[(max_wnaf_entries - round_i) * num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + + // Update the previous value to the next windows + previous = slice + predicate; + } + // The final iteration for top bits + auto final_bits = static_cast(current_scalar_bits - (wnaf_bits * (wnaf_entries - 1))); + uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); + uint64_t predicate = ((slice & 1UL) == 0UL); + + ++wnaf_round_counts[(max_wnaf_entries - wnaf_entries + 1)]; + wnaf[((max_wnaf_entries - wnaf_entries + 1) * num_points)] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + + // Saving top bits + ++wnaf_round_counts[max_wnaf_entries - wnaf_entries]; + wnaf[(max_wnaf_entries - wnaf_entries) * num_points] = ((slice + predicate) >> 1UL) | (point_index); + + // Fill all unused slots with -1 + for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) { + wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL; + } +} + +template +inline void wnaf_round(uint64_t* 
scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept +{ + constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + constexpr auto log2_num_points = static_cast(numeric::get_msb(static_cast(num_points))); + + if constexpr (round_i < wnaf_entries - 1) { + uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[(wnaf_entries - round_i) << log2_num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index << 32UL); + wnaf_round(scalar, wnaf, point_index, slice + predicate); + } else { + constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits; + uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); + // uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index << 32UL); + wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); + } +} + +template +inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept +{ + constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits; + constexpr auto log2_num_points = static_cast(numeric::get_msb(static_cast(num_points))); + + if constexpr (round_i < wnaf_entries - 1) { + uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[(wnaf_entries - round_i) << log2_num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index << 32UL); + wnaf_round(scalar, wnaf, point_index, slice + predicate); + } else { + constexpr size_t final_bits = ((scalar_bits / wnaf_bits) * wnaf_bits == scalar_bits) + ? 
wnaf_bits + : scalar_bits - (scalar_bits / wnaf_bits) * wnaf_bits; + uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index << 32UL); + wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); + } +} + +template +inline void wnaf_round_packed(const uint64_t* scalar, + uint64_t* wnaf, + const uint64_t point_index, + const uint64_t previous) noexcept +{ + constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + + if constexpr (round_i < wnaf_entries - 1) { + uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + // uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[(wnaf_entries - round_i)] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + wnaf_round_packed(scalar, wnaf, point_index, slice + predicate); + } else { + constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits; + uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); + // uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[1] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index); + + wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); + } +} + +template +inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept +{ + skew_map = ((scalar[0] & 1) == 0); + uint64_t previous = get_wnaf_bits_const(scalar) + static_cast(skew_map); + wnaf_round(scalar, wnaf, point_index, previous); +} + +template +inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept +{ + skew_map = ((scalar[0] & 1) == 0); + uint64_t previous = get_wnaf_bits_const(scalar) + static_cast(skew_map); + wnaf_round(scalar, wnaf, point_index, previous); +} + +template +inline void wnaf_round_with_restricted_first_slice(uint64_t* scalar, + uint64_t* wnaf, + const uint64_t point_index, + const uint64_t previous) noexcept +{ + constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits; + constexpr auto log2_num_points = static_cast(numeric::get_msb(static_cast(num_points))); + constexpr size_t bits_in_first_slice = scalar_bits % wnaf_bits; + if constexpr (round_i == 1) { + uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + + wnaf[(wnaf_entries - round_i) << log2_num_points] = + ((((previous - (predicate << (bits_in_first_slice /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | + (predicate << 31UL)) | + (point_index << 32UL); + if (round_i == 1) { + std::cerr << "writing value " << std::hex << wnaf[(wnaf_entries - round_i) << log2_num_points] << std::dec + << " at index " << ((wnaf_entries - round_i) << log2_num_points) << std::endl; + } + wnaf_round_with_restricted_first_slice( + scalar, wnaf, point_index, slice + predicate); + + } else if constexpr (round_i < wnaf_entries - 1) { + uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[(wnaf_entries - round_i) << log2_num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index << 32UL); + wnaf_round_with_restricted_first_slice( + scalar, 
wnaf, point_index, slice + predicate); + } else { + uint64_t slice = get_wnaf_bits_const(scalar); + uint64_t predicate = ((slice & 1UL) == 0UL); + wnaf[num_points] = + ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + (point_index << 32UL); + wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); + } +} + +template +inline void fixed_wnaf_with_restricted_first_slice(uint64_t* scalar, + uint64_t* wnaf, + bool& skew_map, + const size_t point_index) noexcept +{ + constexpr size_t bits_in_first_slice = num_bits % wnaf_bits; + std::cerr << "bits in first slice = " << bits_in_first_slice << std::endl; + skew_map = ((scalar[0] & 1) == 0); + uint64_t previous = get_wnaf_bits_const(scalar) + static_cast(skew_map); + std::cerr << "previous = " << previous << std::endl; + wnaf_round_with_restricted_first_slice(scalar, wnaf, point_index, previous); +} + +// template +// inline void fixed_wnaf_packed(const uint64_t* scalar, +// uint64_t* wnaf, +// bool& skew_map, +// const uint64_t point_index) noexcept +// { +// skew_map = ((scalar[0] & 1) == 0); +// uint64_t previous = get_wnaf_bits_const(scalar) + (uint64_t)skew_map; +// wnaf_round_packed(scalar, wnaf, point_index, previous); +// } + +// template +// inline constexpr std::array fixed_wnaf(const uint64_t *scalar) const noexcept +// { +// bool skew_map = ((scalar[0] * 1) == 0); +// uint64_t previous = get_wnaf_bits_const(scalar) + (uint64_t)skew_map; +// std::array result; +// } +} // namespace bb::wnaf + +// NOLINTEND(readability-implicit-bool-conversion) diff --git a/sumcheck/src/cuda/includes/barretenberg/env/logstr.cpp b/sumcheck/src/cuda/includes/barretenberg/env/logstr.cpp new file mode 100644 index 0000000..ca836b2 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/env/logstr.cpp @@ -0,0 +1,9 @@ +#include + +extern "C" { + +void logstr(char const* str) +{ + std::cerr << str << std::endl; +} +} \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/env/logstr.hpp b/sumcheck/src/cuda/includes/barretenberg/env/logstr.hpp new file mode 100644 index 0000000..e093c3d --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/env/logstr.hpp @@ -0,0 +1 @@ +void logstr(char const*); diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/bitop.bench.cpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/bitop.bench.cpp new file mode 100644 index 0000000..aa2e6bc --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/bitop.bench.cpp @@ -0,0 +1,17 @@ +#include "count_leading_zeros.hpp" +#include + +using namespace benchmark; + +void count_leading_zeros(State& state) noexcept +{ + uint256_t input = 7; + for (auto _ : state) { + auto r = count_leading_zeros(input); + DoNotOptimize(r); + } +} +BENCHMARK(count_leading_zeros); + +// NOLINTNEXTLINE macro invokation triggers style errors from googletest code +BENCHMARK_MAIN(); diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/count_leading_zeros.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/count_leading_zeros.hpp new file mode 100644 index 0000000..acf8179 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/count_leading_zeros.hpp @@ -0,0 +1,52 @@ +#pragma once +#include "../uint128/uint128.hpp" +#include "../uint256/uint256.hpp" +#include + +namespace bb::numeric { + +/** + * Returns the number of leading 0 bits for a given integer type. 
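+ * (Editor's examples, mirroring the unit tests: count_leading_zeros(uint32_t(1)) == 31,
+ * count_leading_zeros(uint64_t(1) << 63) == 0; the wider types are scanned from the most significant
+ * limb down, so count_leading_zeros(uint256_t(1)) == 255.)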
+ * Implemented in terms of intrinsics which will use instructions such as `bsr` or `lzcnt` for best performance. + * Undefined behavior when input is 0. + */ +template constexpr inline size_t count_leading_zeros(T const& u); + +template <> constexpr inline size_t count_leading_zeros(uint32_t const& u) +{ + return static_cast(__builtin_clz(u)); +} + +template <> constexpr inline size_t count_leading_zeros(uint64_t const& u) +{ + return static_cast(__builtin_clzll(u)); +} + +template <> constexpr inline size_t count_leading_zeros(uint128_t const& u) +{ + auto hi = static_cast(u >> 64); + if (hi != 0U) { + return static_cast(__builtin_clzll(hi)); + } + auto lo = static_cast(u); + return static_cast(__builtin_clzll(lo)) + 64; +} + +template <> constexpr inline size_t count_leading_zeros(uint256_t const& u) +{ + if (u.data[3] != 0U) { + return count_leading_zeros(u.data[3]); + } + if (u.data[2] != 0U) { + return count_leading_zeros(u.data[2]) + 64; + } + if (u.data[1] != 0U) { + return count_leading_zeros(u.data[1]) + 128; + } + if (u.data[0] != 0U) { + return count_leading_zeros(u.data[0]) + 192; + } + return 256; +} + +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/count_leading_zeros.test.cpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/count_leading_zeros.test.cpp new file mode 100644 index 0000000..0849a88 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/count_leading_zeros.test.cpp @@ -0,0 +1,36 @@ +#include "count_leading_zeros.hpp" +#include + +using namespace bb; + +TEST(bitop, ClzUint3231) +{ + uint32_t a = 0b00000000000000000000000000000001; + EXPECT_EQ(numeric::count_leading_zeros(a), 31U); +} + +TEST(bitop, ClzUint320) +{ + uint32_t a = 0b10000000000000000000000000000001; + EXPECT_EQ(numeric::count_leading_zeros(a), 0U); +} + +TEST(bitop, ClzUint640) +{ + uint64_t a = 0b1000000000000000000000000000000100000000000000000000000000000000; + EXPECT_EQ(numeric::count_leading_zeros(a), 0U); +} + +TEST(bitop, ClzUint256255) +{ + uint256_t a = 0x1; + auto r = numeric::count_leading_zeros(a); + EXPECT_EQ(r, 255U); +} + +TEST(bitop, ClzUint256248) +{ + uint256_t a = 0x80; + auto r = numeric::count_leading_zeros(a); + EXPECT_EQ(r, 248U); +} diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/get_msb.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/get_msb.hpp new file mode 100644 index 0000000..fc456a6 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/get_msb.hpp @@ -0,0 +1,46 @@ +#pragma once +#include +#include +#include +namespace bb::numeric { + +// from http://supertech.csail.mit.edu/papers/debruijn.pdf +constexpr inline uint32_t get_msb32(const uint32_t in) +{ + constexpr std::array MultiplyDeBruijnBitPosition{ 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, + 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, + 24, 7, 19, 27, 23, 6, 26, 5, 4, 31 }; + + uint32_t v = in | (in >> 1); + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + + return MultiplyDeBruijnBitPosition[static_cast(v * static_cast(0x07C4ACDD)) >> + static_cast(27)]; +} + +constexpr inline uint64_t get_msb64(const uint64_t in) +{ + constexpr std::array de_bruijn_sequence{ 0, 47, 1, 56, 48, 27, 2, 60, 57, 49, 41, 37, 28, + 16, 3, 61, 54, 58, 35, 52, 50, 42, 21, 44, 38, 32, + 29, 23, 17, 11, 4, 62, 46, 55, 26, 59, 40, 36, 15, + 53, 34, 51, 20, 43, 31, 22, 10, 45, 25, 39, 14, 33, + 19, 30, 9, 24, 13, 18, 8, 12, 7, 6, 5, 63 }; + + uint64_t t = in | (in >> 1); + t |= t >> 2; + t |= t >> 4; + t |= t >> 
8; + t |= t >> 16; + t |= t >> 32; + return static_cast(de_bruijn_sequence[(t * 0x03F79D71B4CB0A89ULL) >> 58ULL]); +}; + +template constexpr inline T get_msb(const T in) +{ + return (sizeof(T) <= 4) ? get_msb32(in) : get_msb64(in); +} + +} // namespace bb::numeric \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/get_msb.test.cpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/get_msb.test.cpp new file mode 100644 index 0000000..3abdd73 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/get_msb.test.cpp @@ -0,0 +1,35 @@ +#include "get_msb.hpp" +#include + +using namespace bb; + +TEST(bitop, GetMsbUint640Value) +{ + uint64_t a = 0b00000000000000000000000000000000; + EXPECT_EQ(numeric::get_msb(a), 0U); +} + +TEST(bitop, GetMsbUint320) +{ + uint32_t a = 0b00000000000000000000000000000001; + EXPECT_EQ(numeric::get_msb(a), 0U); +} + +TEST(bitop, GetMsbUint3231) +{ + uint32_t a = 0b10000000000000000000000000000001; + EXPECT_EQ(numeric::get_msb(a), 31U); +} + +TEST(bitop, GetMsbUint6463) +{ + uint64_t a = 0b1000000000000000000000000000000100000000000000000000000000000000; + EXPECT_EQ(numeric::get_msb(a), 63U); +} + +TEST(bitop, GetMsbSizeT7) +{ + size_t a = 0x80; + auto r = numeric::get_msb(a); + EXPECT_EQ(r, 7U); +} diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/keep_n_lsb.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/keep_n_lsb.hpp new file mode 100644 index 0000000..c03a315 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/keep_n_lsb.hpp @@ -0,0 +1,11 @@ +#pragma once +#include + +namespace bb::numeric { + +template inline T keep_n_lsb(T const& input, size_t num_bits) +{ + return num_bits >= sizeof(T) * 8 ? input : input & ((T(1) << num_bits) - 1); +} + +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/pow.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/pow.hpp new file mode 100644 index 0000000..2e67aff --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/pow.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "./get_msb.hpp" +#include + +namespace bb::numeric { +constexpr uint64_t pow64(const uint64_t input, const uint64_t exponent) +{ + if (input == 0) { + return 0; + } + if (exponent == 0) { + return 1; + } + + uint64_t accumulator = input; + uint64_t to_mul = input; + const uint64_t maximum_set_bit = get_msb64(exponent); + + for (int i = static_cast(maximum_set_bit) - 1; i >= 0; --i) { + accumulator *= accumulator; + if (((exponent >> i) & 1) != 0U) { + accumulator *= to_mul; + } + } + return accumulator; +} + +constexpr bool is_power_of_two(uint64_t x) +{ + return (x != 0U) && ((x & (x - 1)) == 0U); +} + +} // namespace bb::numeric \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/rotate.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/rotate.hpp new file mode 100644 index 0000000..5be15b9 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/rotate.hpp @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace bb::numeric { + +constexpr inline uint64_t rotate64(const uint64_t value, const uint64_t rotation) +{ + return rotation != 0U ? (value >> rotation) + (value << (64 - rotation)) : value; +} + +constexpr inline uint32_t rotate32(const uint32_t value, const uint32_t rotation) +{ + return rotation != 0U ? 
(value >> rotation) + (value << (32 - rotation)) : value; +} +} // namespace bb::numeric \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/sparse_form.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/sparse_form.hpp new file mode 100644 index 0000000..db09c47 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/bitop/sparse_form.hpp @@ -0,0 +1,157 @@ +#pragma once +#include "../../common/throw_or_abort.hpp" +#include +#include +#include +#include + +#include "../uint256/uint256.hpp" + +namespace bb::numeric { + +inline std::vector slice_input(const uint256_t& input, const uint64_t base, const size_t num_slices) +{ + uint256_t target = input; + std::vector slices; + if (num_slices > 0) { + for (size_t i = 0; i < num_slices; ++i) { + slices.push_back((target % base).data[0]); + target /= base; + } + } else { + while (target > 0) { + slices.push_back((target % base).data[0]); + target /= base; + } + } + return slices; +} + +inline std::vector slice_input_using_variable_bases(const uint256_t& input, + const std::vector& bases) +{ + uint256_t target = input; + std::vector slices; + for (size_t i = 0; i < bases.size(); ++i) { + if (target >= bases[i] && i == bases.size() - 1) { + throw_or_abort(format("Last key slice greater than ", bases[i])); + } + slices.push_back((target % bases[i]).data[0]); + target /= bases[i]; + } + return slices; +} + +template constexpr std::array get_base_powers() +{ + std::array output{}; + output[0] = 1; + for (size_t i = 1; i < num_slices; ++i) { + output[i] = output[i - 1] * base; + } + return output; +} + +template constexpr uint256_t map_into_sparse_form(const uint64_t input) +{ + uint256_t out = 0UL; + auto converted = input; + + constexpr auto base_powers = get_base_powers(); + for (size_t i = 0; i < 32; ++i) { + uint64_t sparse_bit = ((converted >> i) & 1U); + if (sparse_bit) { + out += base_powers[i]; + } + } + return out; +} + +template constexpr uint64_t map_from_sparse_form(const uint256_t& input) +{ + uint256_t target = input; + uint64_t output = 0; + + constexpr auto bases = get_base_powers(); + + for (uint64_t i = 0; i < 32; ++i) { + const auto& base_power = bases[static_cast(31 - i)]; + uint256_t prev_threshold = 0; + for (uint64_t j = 1; j < base + 1; ++j) { + const auto threshold = prev_threshold + base_power; + if (target < threshold) { + bool bit = ((j - 1) & 1); + if (bit) { + output += (1ULL << (31ULL - i)); + } + if (j > 1) { + target -= (prev_threshold); + } + break; + } + prev_threshold = threshold; + } + } + + return output; +} + +template class sparse_int { + public: + sparse_int(const uint64_t input = 0) + : value(input) + { + for (size_t i = 0; i < num_bits; ++i) { + const uint64_t bit = (input >> i) & 1U; + limbs[i] = bit; + } + } + sparse_int(const sparse_int& other) noexcept = default; + sparse_int(sparse_int&& other) noexcept = default; + sparse_int& operator=(const sparse_int& other) noexcept = default; + sparse_int& operator=(sparse_int&& other) noexcept = default; + ~sparse_int() noexcept = default; + + sparse_int operator+(const sparse_int& other) const + { + sparse_int result(*this); + for (size_t i = 0; i < num_bits - 1; ++i) { + result.limbs[i] += other.limbs[i]; + if (result.limbs[i] >= base) { + result.limbs[i] -= base; + ++result.limbs[i + 1]; + } + } + result.limbs[num_bits - 1] += other.limbs[num_bits - 1]; + result.limbs[num_bits - 1] %= base; + result.value += other.value; + return result; + }; + + sparse_int operator+=(const sparse_int& other) + { + 
*this = *this + other; + return *this; + } + + [[nodiscard]] uint64_t get_value() const { return value; } + + [[nodiscard]] uint64_t get_sparse_value() const + { + uint64_t result = 0; + for (size_t i = num_bits - 1; i < num_bits; --i) { + result *= base; + result += limbs[i]; + } + return result; + } + + const std::array& get_limbs() const { return limbs; } + + private: + std::array limbs; + uint64_t value; + uint64_t sparse_value; +}; + +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/random/engine.cpp b/sumcheck/src/cuda/includes/barretenberg/numeric/random/engine.cpp new file mode 100644 index 0000000..4bbbb6f --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/random/engine.cpp @@ -0,0 +1,139 @@ +#include "engine.hpp" +#include "../../common/assert.hpp" +#include +#include +#include + +namespace bb::numeric { + +namespace { +auto generate_random_data() +{ + std::array random_data; + std::random_device source; + std::generate(std::begin(random_data), std::end(random_data), std::ref(source)); + return random_data; +} +} // namespace + +class RandomEngine : public RNG { + public: + uint8_t get_random_uint8() override + { + auto buf = generate_random_data(); + uint32_t out = buf[0]; + return static_cast(out); + } + + uint16_t get_random_uint16() override + { + auto buf = generate_random_data(); + uint32_t out = buf[0]; + return static_cast(out); + } + + uint32_t get_random_uint32() override + { + auto buf = generate_random_data(); + uint32_t out = buf[0]; + return static_cast(out); + } + + uint64_t get_random_uint64() override + { + auto buf = generate_random_data(); + auto lo = static_cast(buf[0]); + auto hi = static_cast(buf[1]); + return (lo + (hi << 32ULL)); + } + + uint128_t get_random_uint128() override + { + auto big = get_random_uint256(); + auto lo = static_cast(big.data[0]); + auto hi = static_cast(big.data[1]); + return (lo + (hi << static_cast(64ULL))); + } + + uint256_t get_random_uint256() override + { + const auto get64 = [](const std::array& buffer, const size_t offset) { + auto lo = static_cast(buffer[0 + offset]); + auto hi = static_cast(buffer[1 + offset]); + return (lo + (hi << 32ULL)); + }; + auto buf = generate_random_data(); + uint64_t lolo = get64(buf, 0); + uint64_t lohi = get64(buf, 2); + uint64_t hilo = get64(buf, 4); + uint64_t hihi = get64(buf, 6); + return { lolo, lohi, hilo, hihi }; + } +}; + +class DebugEngine : public RNG { + public: + DebugEngine() + // disable linting for this line: we want the DEBUG engine to produce predictable pseudorandom numbers! + // NOLINTNEXTLINE(cert-msc32-c, cert-msc51-cpp) + : engine(std::mt19937_64(12345)) + {} + + DebugEngine(std::uint_fast64_t seed) + : engine(std::mt19937_64(seed)) + {} + + uint8_t get_random_uint8() override { return static_cast(dist(engine)); } + + uint16_t get_random_uint16() override { return static_cast(dist(engine)); } + + uint32_t get_random_uint32() override { return static_cast(dist(engine)); } + + uint64_t get_random_uint64() override { return dist(engine); } + + uint128_t get_random_uint128() override + { + uint128_t hi = dist(engine); + uint128_t lo = dist(engine); + return (hi << 64) | lo; + } + + uint256_t get_random_uint256() override + { + // Do not inline in constructor call. Evaluation order is important for cross-compiler consistency. 
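+        // (Editor's note: if the four dist(engine) calls were written directly inside the braced
+        // initializer below, some toolchains have historically evaluated them in a different order,
+        // which would make this deterministic debug engine produce different values per compiler;
+        // naming the temporaries pins down a single evaluation sequence.)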
+ auto a = dist(engine); + auto b = dist(engine); + auto c = dist(engine); + auto d = dist(engine); + return { a, b, c, d }; + } + + private: + std::mt19937_64 engine; + std::uniform_int_distribution dist = std::uniform_int_distribution{ 0ULL, UINT64_MAX }; +}; + +/** + * Used by tests to ensure consistent behavior. + */ +RNG& get_debug_randomness(bool reset, std::uint_fast64_t seed) +{ + // static std::seed_seq seed({ 1, 2, 3, 4, 5 }); + static DebugEngine debug_engine = DebugEngine(); + if (reset) { + debug_engine = DebugEngine(seed); + } + return debug_engine; +} + +/** + * Default engine. If wanting consistent proof construction, uncomment the line to return the debug engine. + */ +RNG& get_randomness() +{ + // return get_debug_randomness(); + static RandomEngine engine; + return engine; +} + +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/random/engine.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/random/engine.hpp new file mode 100644 index 0000000..0e54341 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/random/engine.hpp @@ -0,0 +1,52 @@ +#pragma once +#include "../uint128/uint128.hpp" +#include "../uint256/uint256.hpp" +#include "../uintx/uintx.hpp" +#include "unistd.h" +#include +#include + +namespace bb::numeric { + +class RNG { + public: + virtual uint8_t get_random_uint8() = 0; + + virtual uint16_t get_random_uint16() = 0; + + virtual uint32_t get_random_uint32() = 0; + + virtual uint64_t get_random_uint64() = 0; + + virtual uint128_t get_random_uint128() = 0; + + virtual uint256_t get_random_uint256() = 0; + + virtual ~RNG() = default; + RNG() noexcept = default; + RNG(const RNG& other) = default; + RNG(RNG&& other) = default; + RNG& operator=(const RNG& other) = default; + RNG& operator=(RNG&& other) = default; + + uint512_t get_random_uint512() + { + // Do not inline in constructor call. Evaluation order is important for cross-compiler consistency. + auto lo = get_random_uint256(); + auto hi = get_random_uint256(); + return { lo, hi }; + } + + uint1024_t get_random_uint1024() + { + // Do not inline in constructor call. Evaluation order is important for cross-compiler consistency. 
+ auto lo = get_random_uint512(); + auto hi = get_random_uint512(); + return { lo, hi }; + } +}; + +RNG& get_debug_randomness(bool reset = false, std::uint_fast64_t seed = 12345); +RNG& get_randomness(); + +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/uint128/uint128.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/uint128/uint128.hpp new file mode 100644 index 0000000..b5cf4d3 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/uint128/uint128.hpp @@ -0,0 +1,212 @@ +#pragma once +#include +#include +#include + +#ifdef __i386__ +#include "../../common/serialize.hpp" +#include + +namespace bb::numeric { + +class alignas(32) uint128_t { + public: + uint32_t data[4]; // NOLINT + + constexpr uint128_t(const uint64_t a = 0) + : data{ static_cast(a), static_cast(a >> 32), 0, 0 } + {} + + constexpr uint128_t(const uint32_t a, const uint32_t b, const uint32_t c, const uint32_t d) + : data{ a, b, c, d } + {} + + constexpr uint128_t(const uint128_t& other) + : data{ other.data[0], other.data[1], other.data[2], other.data[3] } + {} + constexpr uint128_t(uint128_t&& other) = default; + + static constexpr uint128_t from_uint64(const uint64_t a) + { + return { static_cast(a), static_cast(a >> 32), 0, 0 }; + } + + constexpr explicit operator uint64_t() { return (static_cast(data[1]) << 32) + data[0]; } + + constexpr uint128_t& operator=(const uint128_t& other) = default; + constexpr uint128_t& operator=(uint128_t&& other) = default; + constexpr ~uint128_t() = default; + explicit constexpr operator bool() const { return static_cast(data[0]); }; + + template explicit constexpr operator T() const { return static_cast(data[0]); }; + + [[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const; + [[nodiscard]] constexpr uint64_t get_msb() const; + + [[nodiscard]] constexpr uint128_t slice(uint64_t start, uint64_t end) const; + [[nodiscard]] constexpr uint128_t pow(const uint128_t& exponent) const; + + constexpr uint128_t operator+(const uint128_t& other) const; + constexpr uint128_t operator-(const uint128_t& other) const; + constexpr uint128_t operator-() const; + + constexpr uint128_t operator*(const uint128_t& other) const; + constexpr uint128_t operator/(const uint128_t& other) const; + constexpr uint128_t operator%(const uint128_t& other) const; + + constexpr uint128_t operator>>(const uint128_t& other) const; + constexpr uint128_t operator<<(const uint128_t& other) const; + + constexpr uint128_t operator&(const uint128_t& other) const; + constexpr uint128_t operator^(const uint128_t& other) const; + constexpr uint128_t operator|(const uint128_t& other) const; + constexpr uint128_t operator~() const; + + constexpr bool operator==(const uint128_t& other) const; + constexpr bool operator!=(const uint128_t& other) const; + constexpr bool operator!() const; + + constexpr bool operator>(const uint128_t& other) const; + constexpr bool operator<(const uint128_t& other) const; + constexpr bool operator>=(const uint128_t& other) const; + constexpr bool operator<=(const uint128_t& other) const; + + static constexpr size_t length() { return 128; } + + constexpr uint128_t& operator+=(const uint128_t& other) + { + *this = *this + other; + return *this; + }; + constexpr uint128_t& operator-=(const uint128_t& other) + { + *this = *this - other; + return *this; + }; + constexpr uint128_t& operator*=(const uint128_t& other) + { + *this = *this * other; + return *this; + }; + constexpr uint128_t& operator/=(const uint128_t& other) + { + *this = *this / 
other; + return *this; + }; + constexpr uint128_t& operator%=(const uint128_t& other) + { + *this = *this % other; + return *this; + }; + + constexpr uint128_t& operator++() + { + *this += uint128_t(1); + return *this; + }; + constexpr uint128_t& operator--() + { + *this -= uint128_t(1); + return *this; + }; + + constexpr uint128_t& operator&=(const uint128_t& other) + { + *this = *this & other; + return *this; + }; + constexpr uint128_t& operator^=(const uint128_t& other) + { + *this = *this ^ other; + return *this; + }; + constexpr uint128_t& operator|=(const uint128_t& other) + { + *this = *this | other; + return *this; + }; + + constexpr uint128_t& operator>>=(const uint128_t& other) + { + *this = *this >> other; + return *this; + }; + constexpr uint128_t& operator<<=(const uint128_t& other) + { + *this = *this << other; + return *this; + }; + + [[nodiscard]] constexpr std::pair mul_extended(const uint128_t& other) const; + + [[nodiscard]] constexpr std::pair divmod(const uint128_t& b) const; + + private: + [[nodiscard]] static constexpr std::pair mul_wide(uint32_t a, uint32_t b); + [[nodiscard]] static constexpr std::pair addc(uint32_t a, uint32_t b, uint32_t carry_in); + [[nodiscard]] static constexpr uint32_t addc_discard_hi(uint32_t a, uint32_t b, uint32_t carry_in); + [[nodiscard]] static constexpr uint32_t sbb_discard_hi(uint32_t a, uint32_t b, uint32_t borrow_in); + + [[nodiscard]] static constexpr std::pair sbb(uint32_t a, uint32_t b, uint32_t borrow_in); + [[nodiscard]] static constexpr uint32_t mac_discard_hi(uint32_t a, uint32_t b, uint32_t c, uint32_t carry_in); + [[nodiscard]] static constexpr std::pair mac(uint32_t a, + uint32_t b, + uint32_t c, + uint32_t carry_in); +}; + +inline std::ostream& operator<<(std::ostream& os, uint128_t const& a) +{ + std::ios_base::fmtflags f(os.flags()); + os << std::hex << "0x" << std::setfill('0') << std::setw(8) << a.data[3] << std::setw(8) << a.data[2] + << std::setw(8) << a.data[1] << std::setw(8) << a.data[0]; + os.flags(f); + return os; +} + +template inline void read(B& it, uint128_t& value) +{ + using serialize::read; + uint32_t a = 0; + uint32_t b = 0; + uint32_t c = 0; + uint32_t d = 0; + read(it, d); + read(it, c); + read(it, b); + read(it, a); + value = uint128_t(a, b, c, d); +} + +template inline void write(B& it, uint128_t const& value) +{ + using serialize::write; + write(it, value.data[3]); + write(it, value.data[2]); + write(it, value.data[1]); + write(it, value.data[0]); +} + +} // namespace bb::numeric + +#include "./uint128_impl.hpp" + +// disable linter errors; we want to expose a global uint128_t type to mimic uint64_t, uint32_t etc +// NOLINTNEXTLINE(tidymisc-unused-using-decls, google-global-names-in-headers, misc-unused-using-decls) +using numeric::uint128_t; +#else +__extension__ using uint128_t = unsigned __int128; + +namespace std { +// can ignore linter error for streaming operations, we need to add to std namespace to support printing this type! 
+// NOLINTNEXTLINE(cert-dcl58-cpp) +inline std::ostream& operator<<(std::ostream& os, uint128_t const& a) +{ + std::ios_base::fmtflags f(os.flags()); + os << std::hex << "0x" << std::setfill('0') << std::setw(16) << static_cast(a >> 64) << std::setw(16) + << static_cast(a); + os.flags(f); + return os; +} +} // namespace std +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/uint128/uint128_impl.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/uint128/uint128_impl.hpp new file mode 100644 index 0000000..63d1b4c --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/uint128/uint128_impl.hpp @@ -0,0 +1,414 @@ +#ifdef __i386__ +#pragma once +#include "../bitop/get_msb.hpp" +#include "./uint128.hpp" +#include "../../common/assert.hpp" +namespace bb::numeric { + +constexpr std::pair uint128_t::mul_wide(const uint32_t a, const uint32_t b) +{ + const uint32_t a_lo = a & 0xffffULL; + const uint32_t a_hi = a >> 16ULL; + const uint32_t b_lo = b & 0xffffULL; + const uint32_t b_hi = b >> 16ULL; + + const uint32_t lo_lo = a_lo * b_lo; + const uint32_t hi_lo = a_hi * b_lo; + const uint32_t lo_hi = a_lo * b_hi; + const uint32_t hi_hi = a_hi * b_hi; + + const uint32_t cross = (lo_lo >> 16) + (hi_lo & 0xffffULL) + lo_hi; + + return { (cross << 16ULL) | (lo_lo & 0xffffULL), (hi_lo >> 16ULL) + (cross >> 16ULL) + hi_hi }; +} + +// compute a + b + carry, returning the carry +constexpr std::pair uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in) +{ + const uint32_t sum = a + b; + const auto carry_temp = static_cast(sum < a); + const uint32_t r = sum + carry_in; + const uint32_t carry_out = carry_temp + static_cast(r < carry_in); + return { r, carry_out }; +} + +constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in) +{ + return a + b + carry_in; +} + +constexpr std::pair uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in) +{ + const uint32_t t_1 = a - (borrow_in >> 31ULL); + const auto borrow_temp_1 = static_cast(t_1 > a); + const uint32_t t_2 = t_1 - b; + const auto borrow_temp_2 = static_cast(t_2 > t_1); + + return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) }; +} + +constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in) +{ + return a - b - (borrow_in >> 31ULL); +} + +// {r, carry_out} = a + carry_in + b * c +constexpr std::pair uint128_t::mac(const uint32_t a, + const uint32_t b, + const uint32_t c, + const uint32_t carry_in) +{ + std::pair result = mul_wide(b, c); + result.first += a; + const auto overflow_c = static_cast(result.first < a); + result.first += carry_in; + const auto overflow_carry = static_cast(result.first < carry_in); + result.second += (overflow_c + overflow_carry); + return result; +} + +constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a, + const uint32_t b, + const uint32_t c, + const uint32_t carry_in) +{ + return (b * c + a + carry_in); +} + +constexpr std::pair uint128_t::divmod(const uint128_t& b) const +{ + if (*this == 0 || b == 0) { + return { 0, 0 }; + } + if (b == 1) { + return { *this, 0 }; + } + if (*this == b) { + return { 1, 0 }; + } + if (b > *this) { + return { 0, *this }; + } + + uint128_t quotient = 0; + uint128_t remainder = *this; + + uint64_t bit_difference = get_msb() - b.get_msb(); + + uint128_t divisor = b << bit_difference; + uint128_t accumulator = uint128_t(1) << bit_difference; + + // if the divisor is bigger than the remainder, a and b have the same bit length + 
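+    // (Editor's worked example of the shift-and-subtract loop below: dividing 13 by 2, bit_difference
+    // = 3 - 1 = 2, so divisor starts at 2 << 2 = 8 and accumulator at 1 << 2 = 4. Since 8 <= 13 no
+    // initial halving is needed; the loop then subtracts 8 (quotient |= 4, remainder 5) and then 4
+    // (quotient |= 2, remainder 1), leaving quotient 6 and remainder 1.)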
if (divisor > remainder) { + divisor >>= 1; + accumulator >>= 1; + } + + // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder, + // and add to the quotient + while (remainder >= b) { + + // we've shunted 'divisor' up to have the same bit length as our remainder. + // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b + if (remainder >= divisor) { + remainder -= divisor; + // we can use OR here instead of +, as + // accumulator is always a nice power of two + quotient |= accumulator; + } + divisor >>= 1; + accumulator >>= 1; + } + + return { quotient, remainder }; +} + +constexpr std::pair uint128_t::mul_extended(const uint128_t& other) const +{ + const auto [r0, t0] = mul_wide(data[0], other.data[0]); + const auto [q0, t1] = mac(t0, data[0], other.data[1], 0); + const auto [q1, t2] = mac(t1, data[0], other.data[2], 0); + const auto [q2, z0] = mac(t2, data[0], other.data[3], 0); + + const auto [r1, t3] = mac(q0, data[1], other.data[0], 0); + const auto [q3, t4] = mac(q1, data[1], other.data[1], t3); + const auto [q4, t5] = mac(q2, data[1], other.data[2], t4); + const auto [q5, z1] = mac(z0, data[1], other.data[3], t5); + + const auto [r2, t6] = mac(q3, data[2], other.data[0], 0); + const auto [q6, t7] = mac(q4, data[2], other.data[1], t6); + const auto [q7, t8] = mac(q5, data[2], other.data[2], t7); + const auto [q8, z2] = mac(z1, data[2], other.data[3], t8); + + const auto [r3, t9] = mac(q6, data[3], other.data[0], 0); + const auto [r4, t10] = mac(q7, data[3], other.data[1], t9); + const auto [r5, t11] = mac(q8, data[3], other.data[2], t10); + const auto [r6, r7] = mac(z2, data[3], other.data[3], t11); + + uint128_t lo(r0, r1, r2, r3); + uint128_t hi(r4, r5, r6, r7); + return { lo, hi }; +} + +/** + * Viewing `this` uint128_t as a bit string, and counting bits from 0, slices a substring. + * @returns the uint128_t equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the + * `end`-th bit of `this`. + */ +constexpr uint128_t uint128_t::slice(const uint64_t start, const uint64_t end) const +{ + const uint64_t range = end - start; + const uint128_t mask = (range == 128) ? -uint128_t(1) : (uint128_t(1) << range) - 1; + return ((*this) >> start) & mask; +} + +constexpr uint128_t uint128_t::pow(const uint128_t& exponent) const +{ + uint128_t accumulator{ data[0], data[1], data[2], data[3] }; + uint128_t to_mul{ data[0], data[1], data[2], data[3] }; + const uint64_t maximum_set_bit = exponent.get_msb(); + + for (int i = static_cast(maximum_set_bit) - 1; i >= 0; --i) { + accumulator *= accumulator; + if (exponent.get_bit(static_cast(i))) { + accumulator *= to_mul; + } + } + if (exponent == uint128_t(0)) { + accumulator = uint128_t(1); + } else if (*this == uint128_t(0)) { + accumulator = uint128_t(0); + } + return accumulator; +} + +constexpr bool uint128_t::get_bit(const uint64_t bit_index) const +{ + ASSERT(bit_index < 128); + if (bit_index > 127) { + return false; + } + const auto idx = static_cast(bit_index >> 5); + const size_t shift = bit_index & 31; + return static_cast((data[idx] >> shift) & 1); +} + +constexpr uint64_t uint128_t::get_msb() const +{ + uint64_t idx = numeric::get_msb64(data[3]); + idx = (idx == 0 && data[3] == 0) ? numeric::get_msb64(data[2]) : idx + 32; + idx = (idx == 0 && data[2] == 0) ? numeric::get_msb64(data[1]) : idx + 32; + idx = (idx == 0 && data[1] == 0) ? 
numeric::get_msb64(data[0]) : idx + 32; + return idx; +} + +constexpr uint128_t uint128_t::operator+(const uint128_t& other) const +{ + const auto [r0, t0] = addc(data[0], other.data[0], 0); + const auto [r1, t1] = addc(data[1], other.data[1], t0); + const auto [r2, t2] = addc(data[2], other.data[2], t1); + const auto r3 = addc_discard_hi(data[3], other.data[3], t2); + return { r0, r1, r2, r3 }; +}; + +constexpr uint128_t uint128_t::operator-(const uint128_t& other) const +{ + + const auto [r0, t0] = sbb(data[0], other.data[0], 0); + const auto [r1, t1] = sbb(data[1], other.data[1], t0); + const auto [r2, t2] = sbb(data[2], other.data[2], t1); + const auto r3 = sbb_discard_hi(data[3], other.data[3], t2); + return { r0, r1, r2, r3 }; +} + +constexpr uint128_t uint128_t::operator-() const +{ + return uint128_t(0) - *this; +} + +constexpr uint128_t uint128_t::operator*(const uint128_t& other) const +{ + const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL); + const auto [q0, t1] = mac(0, data[0], other.data[1], t0); + const auto [q1, t2] = mac(0, data[0], other.data[2], t1); + const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2); + + const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL); + const auto [q3, t4] = mac(q1, data[1], other.data[1], t3); + const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4); + + const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL); + const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5); + + const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL); + + return { r0, r1, r2, r3 }; +} + +constexpr uint128_t uint128_t::operator/(const uint128_t& other) const +{ + return divmod(other).first; +} + +constexpr uint128_t uint128_t::operator%(const uint128_t& other) const +{ + return divmod(other).second; +} + +constexpr uint128_t uint128_t::operator&(const uint128_t& other) const +{ + return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] }; +} + +constexpr uint128_t uint128_t::operator^(const uint128_t& other) const +{ + return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] }; +} + +constexpr uint128_t uint128_t::operator|(const uint128_t& other) const +{ + return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] }; +} + +constexpr uint128_t uint128_t::operator~() const +{ + return { ~data[0], ~data[1], ~data[2], ~data[3] }; +} + +constexpr bool uint128_t::operator==(const uint128_t& other) const +{ + return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3]; +} + +constexpr bool uint128_t::operator!=(const uint128_t& other) const +{ + return !(*this == other); +} + +constexpr bool uint128_t::operator!() const +{ + return *this == uint128_t(0ULL); +} + +constexpr bool uint128_t::operator>(const uint128_t& other) const +{ + bool t0 = data[3] > other.data[3]; + bool t1 = data[3] == other.data[3] && data[2] > other.data[2]; + bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1]; + bool t3 = + data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0]; + return t0 || t1 || t2 || t3; +} + +constexpr bool uint128_t::operator>=(const uint128_t& other) const +{ + return (*this > other) || (*this == other); +} + +constexpr bool uint128_t::operator<(const uint128_t& other) const +{ + return other > *this; +} + +constexpr bool 
uint128_t::operator<=(const uint128_t& other) const +{ + return (*this < other) || (*this == other); +} + +constexpr uint128_t uint128_t::operator>>(const uint128_t& other) const +{ + uint32_t total_shift = other.data[0]; + + if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) { + return 0; + } + + if (total_shift == 0) { + return *this; + } + + uint32_t num_shifted_limbs = total_shift >> 5ULL; + uint32_t limb_shift = total_shift & 31ULL; + + std::array shifted_limbs = { 0, 0, 0, 0 }; + + if (limb_shift == 0) { + shifted_limbs[0] = data[0]; + shifted_limbs[1] = data[1]; + shifted_limbs[2] = data[2]; + shifted_limbs[3] = data[3]; + } else { + uint32_t remainder_shift = 32ULL - limb_shift; + + shifted_limbs[3] = data[3] >> limb_shift; + + uint32_t remainder = (data[3]) << remainder_shift; + + shifted_limbs[2] = (data[2] >> limb_shift) + remainder; + + remainder = (data[2]) << remainder_shift; + + shifted_limbs[1] = (data[1] >> limb_shift) + remainder; + + remainder = (data[1]) << remainder_shift; + + shifted_limbs[0] = (data[0] >> limb_shift) + remainder; + } + uint128_t result(0); + + for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) { + result.data[i] = shifted_limbs[static_cast(i + num_shifted_limbs)]; + } + + return result; +} + +constexpr uint128_t uint128_t::operator<<(const uint128_t& other) const +{ + uint32_t total_shift = other.data[0]; + + if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) { + return 0; + } + + if (total_shift == 0) { + return *this; + } + uint32_t num_shifted_limbs = total_shift >> 5ULL; + uint32_t limb_shift = total_shift & 31ULL; + + std::array shifted_limbs{ 0, 0, 0, 0 }; + + if (limb_shift == 0) { + shifted_limbs[0] = data[0]; + shifted_limbs[1] = data[1]; + shifted_limbs[2] = data[2]; + shifted_limbs[3] = data[3]; + } else { + uint32_t remainder_shift = 32ULL - limb_shift; + + shifted_limbs[0] = data[0] << limb_shift; + + uint32_t remainder = data[0] >> remainder_shift; + + shifted_limbs[1] = (data[1] << limb_shift) + remainder; + + remainder = data[1] >> remainder_shift; + + shifted_limbs[2] = (data[2] << limb_shift) + remainder; + + remainder = data[2] >> remainder_shift; + + shifted_limbs[3] = (data[3] << limb_shift) + remainder; + } + uint128_t result(0); + + for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) { + result.data[static_cast(i + num_shifted_limbs)] = shifted_limbs[i]; + } + + return result; +} + +} // namespace bb::numeric +#endif diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/uint256/uint256.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/uint256/uint256.hpp new file mode 100644 index 0000000..ae5cb1a --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/uint256/uint256.hpp @@ -0,0 +1,239 @@ +/** + * uint256_t + * Copyright Aztec 2020 + * + * An unsigned 256 bit integer type. + * + * Constructor and all methods are constexpr. + * Ideally, uint256_t should be able to be treated like any other literal type. + * + * Not optimized for performance, this code doesn't touch any of our hot paths when constructing PLONK proofs. 
+ **/ +#pragma once + +#include "../uint128/uint128.hpp" +#include "../../common/throw_or_abort.hpp" +#include +#include +#include +#include +#include + +namespace bb::numeric { + +class alignas(32) uint256_t { + + public: +#if defined(__wasm__) || !defined(__SIZEOF_INT128__) +#define WASM_NUM_LIMBS 9 +#define WASM_LIMB_BITS 29 +#endif + constexpr uint256_t(const uint64_t a = 0) noexcept + : data{ a, 0, 0, 0 } + {} + + constexpr uint256_t(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept + : data{ a, b, c, d } + {} + + constexpr uint256_t(const uint256_t& other) noexcept + : data{ other.data[0], other.data[1], other.data[2], other.data[3] } + {} + constexpr uint256_t(uint256_t&& other) noexcept = default; + + explicit constexpr uint256_t(std::string input) noexcept + { + /* Quick and dirty conversion from a single character to its hex equivelent */ + constexpr auto HexCharToInt = [](uint8_t Input) { + bool valid = + (Input >= 'a' && Input <= 'f') || (Input >= 'A' && Input <= 'F') || (Input >= '0' && Input <= '9'); + if (!valid) { + throw_or_abort("Error, uint256 constructed from string_view with invalid hex parameter"); + } + uint8_t res = + ((Input >= 'a') && (Input <= 'f')) ? (Input - (static_cast('a') - static_cast(10))) + : ((Input >= 'A') && (Input <= 'F')) ? (Input - (static_cast('A') - static_cast(10))) + : ((Input >= '0') && (Input <= '9')) ? (Input - static_cast('0')) + : 0; + return res; + }; + + std::array limbs{ 0, 0, 0, 0 }; + size_t start_index = 0; + if (input.size() == 66 && input[0] == '0' && input[1] == 'x') { + start_index = 2; + } else if (input.size() != 64) { + throw_or_abort("Error, uint256 constructed from string_view with invalid length"); + } + for (size_t j = 0; j < 4; ++j) { + + const size_t limb_index = start_index + j * 16; + for (size_t i = 0; i < 8; ++i) { + const size_t byte_index = limb_index + (i * 2); + uint8_t nibble_hi = HexCharToInt(static_cast(input[byte_index])); + uint8_t nibble_lo = HexCharToInt(static_cast(input[byte_index + 1])); + uint8_t byte = static_cast((nibble_hi * 16) + nibble_lo); + limbs[j] <<= 8; + limbs[j] += byte; + } + } + data[0] = limbs[3]; + data[1] = limbs[2]; + data[2] = limbs[1]; + data[3] = limbs[0]; + } + + static constexpr uint256_t from_uint128(const uint128_t a) noexcept + { + return { static_cast(a), static_cast(a >> 64), 0, 0 }; + } + + constexpr explicit operator uint128_t() { return (static_cast(data[1]) << 64) + data[0]; } + + constexpr uint256_t& operator=(const uint256_t& other) noexcept = default; + constexpr uint256_t& operator=(uint256_t&& other) noexcept = default; + constexpr ~uint256_t() noexcept = default; + + explicit constexpr operator bool() const { return static_cast(data[0]); }; + + template explicit constexpr operator T() const { return static_cast(data[0]); }; + + [[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const; + [[nodiscard]] constexpr uint64_t get_msb() const; + + [[nodiscard]] constexpr uint256_t slice(uint64_t start, uint64_t end) const; + [[nodiscard]] constexpr uint256_t pow(const uint256_t& exponent) const; + + constexpr uint256_t operator+(const uint256_t& other) const; + constexpr uint256_t operator-(const uint256_t& other) const; + constexpr uint256_t operator-() const; + + constexpr uint256_t operator*(const uint256_t& other) const; + constexpr uint256_t operator/(const uint256_t& other) const; + constexpr uint256_t operator%(const uint256_t& other) const; + + constexpr uint256_t operator>>(const uint256_t& other) const; + constexpr uint256_t 
operator<<(const uint256_t& other) const; + + constexpr uint256_t operator&(const uint256_t& other) const; + constexpr uint256_t operator^(const uint256_t& other) const; + constexpr uint256_t operator|(const uint256_t& other) const; + constexpr uint256_t operator~() const; + + constexpr bool operator==(const uint256_t& other) const; + constexpr bool operator!=(const uint256_t& other) const; + constexpr bool operator!() const; + + constexpr bool operator>(const uint256_t& other) const; + constexpr bool operator<(const uint256_t& other) const; + constexpr bool operator>=(const uint256_t& other) const; + constexpr bool operator<=(const uint256_t& other) const; + + static constexpr size_t length() { return 256; } + + constexpr uint256_t& operator+=(const uint256_t& other) + { + *this = *this + other; + return *this; + }; + constexpr uint256_t& operator-=(const uint256_t& other) + { + *this = *this - other; + return *this; + }; + constexpr uint256_t& operator*=(const uint256_t& other) + { + *this = *this * other; + return *this; + }; + constexpr uint256_t& operator/=(const uint256_t& other) + { + *this = *this / other; + return *this; + }; + constexpr uint256_t& operator%=(const uint256_t& other) + { + *this = *this % other; + return *this; + }; + + constexpr uint256_t& operator++() + { + *this += uint256_t(1); + return *this; + }; + constexpr uint256_t& operator--() + { + *this -= uint256_t(1); + return *this; + }; + + constexpr uint256_t& operator&=(const uint256_t& other) + { + *this = *this & other; + return *this; + }; + constexpr uint256_t& operator^=(const uint256_t& other) + { + *this = *this ^ other; + return *this; + }; + constexpr uint256_t& operator|=(const uint256_t& other) + { + *this = *this | other; + return *this; + }; + + constexpr uint256_t& operator>>=(const uint256_t& other) + { + *this = *this >> other; + return *this; + }; + constexpr uint256_t& operator<<=(const uint256_t& other) + { + *this = *this << other; + return *this; + }; + + [[nodiscard]] constexpr std::pair mul_extended(const uint256_t& other) const; + + uint64_t data[4]; // NOLINT + + [[nodiscard]] constexpr std::pair divmod(const uint256_t& b) const; + + private: + [[nodiscard]] static constexpr std::pair mul_wide(uint64_t a, uint64_t b); + [[nodiscard]] static constexpr std::pair addc(uint64_t a, uint64_t b, uint64_t carry_in); + [[nodiscard]] static constexpr uint64_t addc_discard_hi(uint64_t a, uint64_t b, uint64_t carry_in); + [[nodiscard]] static constexpr uint64_t sbb_discard_hi(uint64_t a, uint64_t b, uint64_t borrow_in); + [[nodiscard]] static constexpr std::pair sbb(uint64_t a, uint64_t b, uint64_t borrow_in); + [[nodiscard]] static constexpr uint64_t mac_discard_hi(uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in); + [[nodiscard]] static constexpr std::pair mac(uint64_t a, + uint64_t b, + uint64_t c, + uint64_t carry_in); +#if defined(__wasm__) || !defined(__SIZEOF_INT128__) + static constexpr void wasm_madd(const uint64_t& left_limb, + const uint64_t* right_limbs, + uint64_t& result_0, + uint64_t& result_1, + uint64_t& result_2, + uint64_t& result_3, + uint64_t& result_4, + uint64_t& result_5, + uint64_t& result_6, + uint64_t& result_7, + uint64_t& result_8); + [[nodiscard]] static constexpr std::array wasm_convert(const uint64_t* data); +#endif +}; + +inline std::ostream& operator<<(std::ostream& os, uint256_t const& a) +{ + std::ios_base::fmtflags f(os.flags()); + os << std::hex << "0x" << std::setfill('0') << std::setw(16) << a.data[3] << std::setw(16) << a.data[2] + << std::setw(16) << 
a.data[1] << std::setw(16) << a.data[0]; + os.flags(f); + return os; +} +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/uint256/uint256_impl.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/uint256/uint256_impl.hpp new file mode 100644 index 0000000..43cd810 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/uint256/uint256_impl.hpp @@ -0,0 +1,622 @@ +#pragma once +#include "../bitop/get_msb.hpp" +#include "./uint256.hpp" +#include "../../common/assert.hpp" +namespace bb::numeric { + +constexpr std::pair uint256_t::mul_wide(const uint64_t a, const uint64_t b) +{ + const uint64_t a_lo = a & 0xffffffffULL; + const uint64_t a_hi = a >> 32ULL; + const uint64_t b_lo = b & 0xffffffffULL; + const uint64_t b_hi = b >> 32ULL; + + const uint64_t lo_lo = a_lo * b_lo; + const uint64_t hi_lo = a_hi * b_lo; + const uint64_t lo_hi = a_lo * b_hi; + const uint64_t hi_hi = a_hi * b_hi; + + const uint64_t cross = (lo_lo >> 32ULL) + (hi_lo & 0xffffffffULL) + lo_hi; + + return { (cross << 32ULL) | (lo_lo & 0xffffffffULL), (hi_lo >> 32ULL) + (cross >> 32ULL) + hi_hi }; +} + +// compute a + b + carry, returning the carry +constexpr std::pair uint256_t::addc(const uint64_t a, const uint64_t b, const uint64_t carry_in) +{ + const uint64_t sum = a + b; + const auto carry_temp = static_cast(sum < a); + const uint64_t r = sum + carry_in; + const uint64_t carry_out = carry_temp + static_cast(r < carry_in); + return { r, carry_out }; +} + +constexpr uint64_t uint256_t::addc_discard_hi(const uint64_t a, const uint64_t b, const uint64_t carry_in) +{ + return a + b + carry_in; +} + +constexpr std::pair uint256_t::sbb(const uint64_t a, const uint64_t b, const uint64_t borrow_in) +{ + const uint64_t t_1 = a - (borrow_in >> 63ULL); + const auto borrow_temp_1 = static_cast(t_1 > a); + const uint64_t t_2 = t_1 - b; + const auto borrow_temp_2 = static_cast(t_2 > t_1); + + return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) }; +} + +constexpr uint64_t uint256_t::sbb_discard_hi(const uint64_t a, const uint64_t b, const uint64_t borrow_in) +{ + return a - b - (borrow_in >> 63ULL); +} + +// {r, carry_out} = a + carry_in + b * c +constexpr std::pair uint256_t::mac(const uint64_t a, + const uint64_t b, + const uint64_t c, + const uint64_t carry_in) +{ + std::pair result = mul_wide(b, c); + result.first += a; + const auto overflow_c = static_cast(result.first < a); + result.first += carry_in; + const auto overflow_carry = static_cast(result.first < carry_in); + result.second += (overflow_c + overflow_carry); + return result; +} + +constexpr uint64_t uint256_t::mac_discard_hi(const uint64_t a, + const uint64_t b, + const uint64_t c, + const uint64_t carry_in) +{ + return (b * c + a + carry_in); +} +#if defined(__wasm__) || !defined(__SIZEOF_INT128__) + +/** + * @brief Multiply one limb by 9 limbs and add to resulting limbs + * + */ +constexpr void uint256_t::wasm_madd(const uint64_t& left_limb, + const uint64_t* right_limbs, + uint64_t& result_0, + uint64_t& result_1, + uint64_t& result_2, + uint64_t& result_3, + uint64_t& result_4, + uint64_t& result_5, + uint64_t& result_6, + uint64_t& result_7, + uint64_t& result_8) +{ + result_0 += left_limb * right_limbs[0]; + result_1 += left_limb * right_limbs[1]; + result_2 += left_limb * right_limbs[2]; + result_3 += left_limb * right_limbs[3]; + result_4 += left_limb * right_limbs[4]; + result_5 += left_limb * right_limbs[5]; + result_6 += left_limb * right_limbs[6]; + result_7 += left_limb * right_limbs[7]; + result_8 += 
left_limb * right_limbs[8]; +} + +/** + * @brief Convert from 4 64-bit limbs to 9 29-bit limbs + * + */ +constexpr std::array uint256_t::wasm_convert(const uint64_t* data) +{ + return { data[0] & 0x1fffffff, + (data[0] >> 29) & 0x1fffffff, + ((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6), + (data[1] >> 23) & 0x1fffffff, + ((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12), + (data[2] >> 17) & 0x1fffffff, + ((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18), + (data[3] >> 11) & 0x1fffffff, + (data[3] >> 40) & 0x1fffffff }; +} +#endif +constexpr std::pair uint256_t::divmod(const uint256_t& b) const +{ + if (*this == 0 || b == 0) { + return { 0, 0 }; + } + if (b == 1) { + return { *this, 0 }; + } + if (*this == b) { + return { 1, 0 }; + } + if (b > *this) { + return { 0, *this }; + } + + uint256_t quotient = 0; + uint256_t remainder = *this; + + uint64_t bit_difference = get_msb() - b.get_msb(); + + uint256_t divisor = b << bit_difference; + uint256_t accumulator = uint256_t(1) << bit_difference; + + // if the divisor is bigger than the remainder, a and b have the same bit length + if (divisor > remainder) { + divisor >>= 1; + accumulator >>= 1; + } + + // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder, + // and add to the quotient + while (remainder >= b) { + + // we've shunted 'divisor' up to have the same bit length as our remainder. + // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b + if (remainder >= divisor) { + remainder -= divisor; + // we can use OR here instead of +, as + // accumulator is always a nice power of two + quotient |= accumulator; + } + divisor >>= 1; + accumulator >>= 1; + } + + return { quotient, remainder }; +} + +/** + * @brief Compute the result of multiplication modulu 2**512 + * + */ +constexpr std::pair uint256_t::mul_extended(const uint256_t& other) const +{ +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const auto [r0, t0] = mul_wide(data[0], other.data[0]); + const auto [q0, t1] = mac(t0, data[0], other.data[1], 0); + const auto [q1, t2] = mac(t1, data[0], other.data[2], 0); + const auto [q2, z0] = mac(t2, data[0], other.data[3], 0); + + const auto [r1, t3] = mac(q0, data[1], other.data[0], 0); + const auto [q3, t4] = mac(q1, data[1], other.data[1], t3); + const auto [q4, t5] = mac(q2, data[1], other.data[2], t4); + const auto [q5, z1] = mac(z0, data[1], other.data[3], t5); + + const auto [r2, t6] = mac(q3, data[2], other.data[0], 0); + const auto [q6, t7] = mac(q4, data[2], other.data[1], t6); + const auto [q7, t8] = mac(q5, data[2], other.data[2], t7); + const auto [q8, z2] = mac(z1, data[2], other.data[3], t8); + + const auto [r3, t9] = mac(q6, data[3], other.data[0], 0); + const auto [r4, t10] = mac(q7, data[3], other.data[1], t9); + const auto [r5, t11] = mac(q8, data[3], other.data[2], t10); + const auto [r6, r7] = mac(z2, data[3], other.data[3], t11); + + uint256_t lo(r0, r1, r2, r3); + uint256_t hi(r4, r5, r6, r7); + return { lo, hi }; +#else + // Convert 4 64-bit limbs to 9 29-bit limbs + const auto left = wasm_convert(data); + const auto right = wasm_convert(other.data); + constexpr uint64_t mask = 0x1fffffff; + uint64_t temp_0 = 0; + uint64_t temp_1 = 0; + uint64_t temp_2 = 0; + uint64_t temp_3 = 0; + uint64_t temp_4 = 0; + uint64_t temp_5 = 0; + uint64_t temp_6 = 0; + uint64_t temp_7 = 0; + uint64_t temp_8 = 0; + uint64_t temp_9 = 0; + uint64_t temp_10 = 0; + uint64_t temp_11 = 0; + uint64_t temp_12 = 0; + uint64_t temp_13 = 0; 
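wasm_convert above re-expresses a 256-bit value held in four 64-bit limbs as nine 29-bit limbs: a 29x29-bit product fits in 58 bits, so many such products can be accumulated in a 64-bit word ("relaxed form") before a single carry-propagation pass restores strict 29-bit limbs. Below is a standalone sketch of that radix conversion together with a repacking step that round-trips it; to_29bit_limbs and from_29bit_limbs are illustrative names, not library functions.

// Standalone sketch: 4 x 64-bit limbs <-> 9 x 29-bit limbs, same slicing as
// wasm_convert above, with a repack to show the representations agree.
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

static std::array<uint64_t, 9> to_29bit_limbs(const std::array<uint64_t, 4>& d)
{
    std::array<uint64_t, 9> out{};
    for (std::size_t i = 0; i < 9; ++i) {
        const std::size_t bit = 29 * i; // starting bit of this 29-bit limb
        const std::size_t word = bit / 64, off = bit % 64;
        uint64_t v = d[word] >> off;
        if (off > 35 && word + 1 < 4) { // limb straddles a 64-bit boundary
            v |= d[word + 1] << (64 - off);
        }
        out[i] = v & 0x1fffffffULL; // keep 29 bits
    }
    return out;
}

static std::array<uint64_t, 4> from_29bit_limbs(const std::array<uint64_t, 9>& l)
{
    std::array<uint64_t, 4> out{};
    for (std::size_t i = 0; i < 9; ++i) {
        const std::size_t bit = 29 * i;
        const std::size_t word = bit / 64, off = bit % 64;
        out[word] |= l[i] << off;
        if (off > 35 && word + 1 < 4) { // put the spilled bits into the next word
            out[word + 1] |= l[i] >> (64 - off);
        }
    }
    return out;
}

int main()
{
    const std::array<uint64_t, 4> x{ 0x0123456789abcdefULL, 0xfedcba9876543210ULL,
                                     0x0f1e2d3c4b5a6978ULL, 0x8796a5b4c3d2e1f0ULL };
    assert(from_29bit_limbs(to_29bit_limbs(x)) == x);
}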
+ uint64_t temp_14 = 0; + uint64_t temp_15 = 0; + uint64_t temp_16 = 0; + + // Multiply and addd all limbs + wasm_madd(left[0], &right[0], temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + wasm_madd(left[1], &right[0], temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9); + wasm_madd(left[2], &right[0], temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10); + wasm_madd(left[3], &right[0], temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11); + wasm_madd(left[4], &right[0], temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12); + wasm_madd(left[5], &right[0], temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13); + wasm_madd(left[6], &right[0], temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14); + wasm_madd(left[7], &right[0], temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15); + wasm_madd(left[8], &right[0], temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16); + + // Convert from relaxed form into strict 29-bit form (except for temp_16) + temp_1 += temp_0 >> WASM_LIMB_BITS; + temp_0 &= mask; + temp_2 += temp_1 >> WASM_LIMB_BITS; + temp_1 &= mask; + temp_3 += temp_2 >> WASM_LIMB_BITS; + temp_2 &= mask; + temp_4 += temp_3 >> WASM_LIMB_BITS; + temp_3 &= mask; + temp_5 += temp_4 >> WASM_LIMB_BITS; + temp_4 &= mask; + temp_6 += temp_5 >> WASM_LIMB_BITS; + temp_5 &= mask; + temp_7 += temp_6 >> WASM_LIMB_BITS; + temp_6 &= mask; + temp_8 += temp_7 >> WASM_LIMB_BITS; + temp_7 &= mask; + temp_9 += temp_8 >> WASM_LIMB_BITS; + temp_8 &= mask; + temp_10 += temp_9 >> WASM_LIMB_BITS; + temp_9 &= mask; + temp_11 += temp_10 >> WASM_LIMB_BITS; + temp_10 &= mask; + temp_12 += temp_11 >> WASM_LIMB_BITS; + temp_11 &= mask; + temp_13 += temp_12 >> WASM_LIMB_BITS; + temp_12 &= mask; + temp_14 += temp_13 >> WASM_LIMB_BITS; + temp_13 &= mask; + temp_15 += temp_14 >> WASM_LIMB_BITS; + temp_14 &= mask; + temp_16 += temp_15 >> WASM_LIMB_BITS; + temp_15 &= mask; + + // Convert to 2 4-64-bit limb uint256_t objects + return { { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58), + (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52), + (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46), + (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40) }, + { (temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63), + (temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57), + (temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51), + (temp_15 >> 13) | (temp_16 << 16) } }; +#endif +} + +/** + * Viewing `this` uint256_t as a bit string, and counting bits from 0, slices a substring. + * @returns the uint256_t equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the + * `end`-th bit of `this`. + */ +constexpr uint256_t uint256_t::slice(const uint64_t start, const uint64_t end) const +{ + const uint64_t range = end - start; + const uint256_t mask = (range == 256) ? 
-uint256_t(1) : (uint256_t(1) << range) - 1; + return ((*this) >> start) & mask; +} + +constexpr uint256_t uint256_t::pow(const uint256_t& exponent) const +{ + uint256_t accumulator{ data[0], data[1], data[2], data[3] }; + uint256_t to_mul{ data[0], data[1], data[2], data[3] }; + const uint64_t maximum_set_bit = exponent.get_msb(); + + for (int i = static_cast(maximum_set_bit) - 1; i >= 0; --i) { + accumulator *= accumulator; + if (exponent.get_bit(static_cast(i))) { + accumulator *= to_mul; + } + } + if (exponent == uint256_t(0)) { + accumulator = uint256_t(1); + } else if (*this == uint256_t(0)) { + accumulator = uint256_t(0); + } + return accumulator; +} + +constexpr bool uint256_t::get_bit(const uint64_t bit_index) const +{ + ASSERT(bit_index < 256); + if (bit_index > 255) { + return static_cast(0); + } + const auto idx = static_cast(bit_index >> 6); + const size_t shift = bit_index & 63; + return static_cast((data[idx] >> shift) & 1); +} + +constexpr uint64_t uint256_t::get_msb() const +{ + uint64_t idx = numeric::get_msb(data[3]); + idx = (idx == 0 && data[3] == 0) ? numeric::get_msb(data[2]) : idx + 64; + idx = (idx == 0 && data[2] == 0) ? numeric::get_msb(data[1]) : idx + 64; + idx = (idx == 0 && data[1] == 0) ? numeric::get_msb(data[0]) : idx + 64; + return idx; +} + +constexpr uint256_t uint256_t::operator+(const uint256_t& other) const +{ + const auto [r0, t0] = addc(data[0], other.data[0], 0); + const auto [r1, t1] = addc(data[1], other.data[1], t0); + const auto [r2, t2] = addc(data[2], other.data[2], t1); + const auto r3 = addc_discard_hi(data[3], other.data[3], t2); + return { r0, r1, r2, r3 }; +}; + +constexpr uint256_t uint256_t::operator-(const uint256_t& other) const +{ + + const auto [r0, t0] = sbb(data[0], other.data[0], 0); + const auto [r1, t1] = sbb(data[1], other.data[1], t0); + const auto [r2, t2] = sbb(data[2], other.data[2], t1); + const auto r3 = sbb_discard_hi(data[3], other.data[3], t2); + return { r0, r1, r2, r3 }; +} + +constexpr uint256_t uint256_t::operator-() const +{ + return uint256_t(0) - *this; +} + +constexpr uint256_t uint256_t::operator*(const uint256_t& other) const +{ + +#if defined(__SIZEOF_INT128__) && !defined(__wasm__) + const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL); + const auto [q0, t1] = mac(0, data[0], other.data[1], t0); + const auto [q1, t2] = mac(0, data[0], other.data[2], t1); + const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2); + + const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL); + const auto [q3, t4] = mac(q1, data[1], other.data[1], t3); + const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4); + + const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL); + const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5); + + const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL); + + return { r0, r1, r2, r3 }; +#else + // Convert 4 64-bit limbs to 9 29-bit limbs + const auto left = wasm_convert(data); + const auto right = wasm_convert(other.data); + uint64_t temp_0 = 0; + uint64_t temp_1 = 0; + uint64_t temp_2 = 0; + uint64_t temp_3 = 0; + uint64_t temp_4 = 0; + uint64_t temp_5 = 0; + uint64_t temp_6 = 0; + uint64_t temp_7 = 0; + uint64_t temp_8 = 0; + + // Multiply and add the product of left limb 0 by all right limbs + wasm_madd(left[0], &right[0], temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8); + // Multiply left limb 1 by limbs 0-7 ((1,8) doesn't need to be computed, because it overflows) + temp_1 += left[1] * right[0]; + temp_2 += left[1] * 
right[1]; + temp_3 += left[1] * right[2]; + temp_4 += left[1] * right[3]; + temp_5 += left[1] * right[4]; + temp_6 += left[1] * right[5]; + temp_7 += left[1] * right[6]; + temp_8 += left[1] * right[7]; + // Left limb 2 by right 0-6, etc + temp_2 += left[2] * right[0]; + temp_3 += left[2] * right[1]; + temp_4 += left[2] * right[2]; + temp_5 += left[2] * right[3]; + temp_6 += left[2] * right[4]; + temp_7 += left[2] * right[5]; + temp_8 += left[2] * right[6]; + temp_3 += left[3] * right[0]; + temp_4 += left[3] * right[1]; + temp_5 += left[3] * right[2]; + temp_6 += left[3] * right[3]; + temp_7 += left[3] * right[4]; + temp_8 += left[3] * right[5]; + temp_4 += left[4] * right[0]; + temp_5 += left[4] * right[1]; + temp_6 += left[4] * right[2]; + temp_7 += left[4] * right[3]; + temp_8 += left[4] * right[4]; + temp_5 += left[5] * right[0]; + temp_6 += left[5] * right[1]; + temp_7 += left[5] * right[2]; + temp_8 += left[5] * right[3]; + temp_6 += left[6] * right[0]; + temp_7 += left[6] * right[1]; + temp_8 += left[6] * right[2]; + temp_7 += left[7] * right[0]; + temp_8 += left[7] * right[1]; + temp_8 += left[8] * right[0]; + + // Convert from relaxed form to strict 29-bit form + constexpr uint64_t mask = 0x1fffffff; + temp_1 += temp_0 >> WASM_LIMB_BITS; + temp_0 &= mask; + temp_2 += temp_1 >> WASM_LIMB_BITS; + temp_1 &= mask; + temp_3 += temp_2 >> WASM_LIMB_BITS; + temp_2 &= mask; + temp_4 += temp_3 >> WASM_LIMB_BITS; + temp_3 &= mask; + temp_5 += temp_4 >> WASM_LIMB_BITS; + temp_4 &= mask; + temp_6 += temp_5 >> WASM_LIMB_BITS; + temp_5 &= mask; + temp_7 += temp_6 >> WASM_LIMB_BITS; + temp_6 &= mask; + temp_8 += temp_7 >> WASM_LIMB_BITS; + temp_7 &= mask; + + // Convert back to 4 64-bit limbs + return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58), + (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52), + (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46), + (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40) }; +#endif +} + +constexpr uint256_t uint256_t::operator/(const uint256_t& other) const +{ + return divmod(other).first; +} + +constexpr uint256_t uint256_t::operator%(const uint256_t& other) const +{ + return divmod(other).second; +} + +constexpr uint256_t uint256_t::operator&(const uint256_t& other) const +{ + return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] }; +} + +constexpr uint256_t uint256_t::operator^(const uint256_t& other) const +{ + return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] }; +} + +constexpr uint256_t uint256_t::operator|(const uint256_t& other) const +{ + return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] }; +} + +constexpr uint256_t uint256_t::operator~() const +{ + return { ~data[0], ~data[1], ~data[2], ~data[3] }; +} + +constexpr bool uint256_t::operator==(const uint256_t& other) const +{ + return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3]; +} + +constexpr bool uint256_t::operator!=(const uint256_t& other) const +{ + return !(*this == other); +} + +constexpr bool uint256_t::operator!() const +{ + return *this == uint256_t(0ULL); +} + +constexpr bool uint256_t::operator>(const uint256_t& other) const +{ + bool t0 = data[3] > other.data[3]; + bool t1 = data[3] == other.data[3] && data[2] > other.data[2]; + bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1]; + bool t3 = + data[3] == other.data[3] && 
data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0]; + return t0 || t1 || t2 || t3; +} + +constexpr bool uint256_t::operator>=(const uint256_t& other) const +{ + return (*this > other) || (*this == other); +} + +constexpr bool uint256_t::operator<(const uint256_t& other) const +{ + return other > *this; +} + +constexpr bool uint256_t::operator<=(const uint256_t& other) const +{ + return (*this < other) || (*this == other); +} + +constexpr uint256_t uint256_t::operator>>(const uint256_t& other) const +{ + uint64_t total_shift = other.data[0]; + + if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) { + return 0; + } + + if (total_shift == 0) { + return *this; + } + + uint64_t num_shifted_limbs = total_shift >> 6ULL; + uint64_t limb_shift = total_shift & 63ULL; + + std::array shifted_limbs = { 0, 0, 0, 0 }; + + if (limb_shift == 0) { + shifted_limbs[0] = data[0]; + shifted_limbs[1] = data[1]; + shifted_limbs[2] = data[2]; + shifted_limbs[3] = data[3]; + } else { + uint64_t remainder_shift = 64ULL - limb_shift; + + shifted_limbs[3] = data[3] >> limb_shift; + + uint64_t remainder = (data[3]) << remainder_shift; + + shifted_limbs[2] = (data[2] >> limb_shift) + remainder; + + remainder = (data[2]) << remainder_shift; + + shifted_limbs[1] = (data[1] >> limb_shift) + remainder; + + remainder = (data[1]) << remainder_shift; + + shifted_limbs[0] = (data[0] >> limb_shift) + remainder; + } + uint256_t result(0); + + for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) { + result.data[i] = shifted_limbs[static_cast(i + num_shifted_limbs)]; + } + + return result; +} + +constexpr uint256_t uint256_t::operator<<(const uint256_t& other) const +{ + uint64_t total_shift = other.data[0]; + + if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) { + return 0; + } + + if (total_shift == 0) { + return *this; + } + uint64_t num_shifted_limbs = total_shift >> 6ULL; + uint64_t limb_shift = total_shift & 63ULL; + + std::array shifted_limbs = { 0, 0, 0, 0 }; + + if (limb_shift == 0) { + shifted_limbs[0] = data[0]; + shifted_limbs[1] = data[1]; + shifted_limbs[2] = data[2]; + shifted_limbs[3] = data[3]; + } else { + uint64_t remainder_shift = 64ULL - limb_shift; + + shifted_limbs[0] = data[0] << limb_shift; + + uint64_t remainder = data[0] >> remainder_shift; + + shifted_limbs[1] = (data[1] << limb_shift) + remainder; + + remainder = data[1] >> remainder_shift; + + shifted_limbs[2] = (data[2] << limb_shift) + remainder; + + remainder = data[2] >> remainder_shift; + + shifted_limbs[3] = (data[3] << limb_shift) + remainder; + } + uint256_t result(0); + + for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) { + result.data[static_cast(i + num_shifted_limbs)] = shifted_limbs[i]; + } + + return result; +} + +} // namespace bb::numeric diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx.hpp new file mode 100644 index 0000000..4ffe48e --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx.hpp @@ -0,0 +1,178 @@ +/** + * uintx + * Copyright Aztec 2020 + * + * An unsigned 512 bit integer type. + * + * Constructor and all methods are constexpr. Ideally, uintx should be able to be treated like any other literal + *type. 
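The divmod routines above (for uint128_t and uint256_t, and for uintx below) all use the same shift-and-subtract long division: align the divisor with the dividend's top bit, then peel off decreasing power-of-two multiples of the divisor. Below is a standalone sketch of the identical loop on uint64_t, checked against the built-in / and % operators; divmod_sketch and msb_index are illustrative names.

// Standalone sketch of shift-and-subtract long division.
#include <cassert>
#include <cstdint>
#include <utility>

static int msb_index(uint64_t x) // position of the highest set bit, x != 0
{
    int i = 63;
    while (((x >> i) & 1) == 0) { --i; }
    return i;
}

static std::pair<uint64_t, uint64_t> divmod_sketch(uint64_t a, uint64_t b)
{
    assert(b != 0);
    if (a < b) { return { 0, a }; }

    uint64_t quotient = 0;
    uint64_t remainder = a;

    // Align the divisor with the dividend's highest bit.
    const int bit_difference = msb_index(a) - msb_index(b);
    uint64_t divisor = b << bit_difference;
    uint64_t accumulator = uint64_t(1) << bit_difference;
    if (divisor > remainder) { // same bit length, but divisor still larger
        divisor >>= 1;
        accumulator >>= 1;
    }

    // Peel off decreasing power-of-two multiples of b.
    while (remainder >= b) {
        if (remainder >= divisor) {
            remainder -= divisor;
            quotient |= accumulator; // accumulator is a power of two, so | == +
        }
        divisor >>= 1;
        accumulator >>= 1;
    }
    return { quotient, remainder };
}

int main()
{
    for (uint64_t a = 0; a < 200; ++a) {
        for (uint64_t b = 1; b < 50; ++b) {
            const auto [q, r] = divmod_sketch(a, b);
            assert(q == a / b && r == a % b);
        }
    }
}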
+ * + * Not optimized for performance, this code doesn"t touch any of our hot paths when constructing PLONK proofs + **/ +#pragma once + +#include "../uint256/uint256.hpp" +#include "../../common/assert.hpp" +#include "../../common/throw_or_abort.hpp" +#include +#include +#include + +namespace bb::numeric { + +template class uintx { + public: + constexpr uintx(const uint64_t data = 0) + : lo(data) + , hi(base_uint(0)) + {} + + constexpr uintx(const base_uint input_lo) + : lo(input_lo) + , hi(base_uint(0)) + {} + + constexpr uintx(const base_uint input_lo, const base_uint input_hi) + : lo(input_lo) + , hi(input_hi) + {} + + constexpr uintx(const uintx& other) + : lo(other.lo) + , hi(other.hi) + {} + + constexpr uintx(uintx&& other) noexcept = default; + + static constexpr size_t length() { return 2 * base_uint::length(); } + constexpr uintx& operator=(const uintx& other) = default; + constexpr uintx& operator=(uintx&& other) noexcept = default; + + constexpr ~uintx() = default; + explicit constexpr operator bool() const { return static_cast(lo.data[0]); }; + explicit constexpr operator uint8_t() const { return static_cast(lo.data[0]); }; + explicit constexpr operator uint16_t() const { return static_cast(lo.data[0]); }; + explicit constexpr operator uint32_t() const { return static_cast(lo.data[0]); }; + explicit constexpr operator uint64_t() const { return static_cast(lo.data[0]); }; + + explicit constexpr operator base_uint() const { return lo; } + + [[nodiscard]] constexpr bool get_bit(uint64_t bit_index) const; + [[nodiscard]] constexpr uint64_t get_msb() const; + constexpr uintx slice(uint64_t start, uint64_t end) const; + + constexpr uintx operator+(const uintx& other) const; + constexpr uintx operator-(const uintx& other) const; + constexpr uintx operator-() const; + + constexpr uintx operator*(const uintx& other) const; + constexpr uintx operator/(const uintx& other) const; + constexpr uintx operator%(const uintx& other) const; + + constexpr std::pair mul_extended(const uintx& other) const; + + constexpr uintx operator>>(uint64_t other) const; + constexpr uintx operator<<(uint64_t other) const; + + constexpr uintx operator&(const uintx& other) const; + constexpr uintx operator^(const uintx& other) const; + constexpr uintx operator|(const uintx& other) const; + constexpr uintx operator~() const; + + constexpr bool operator==(const uintx& other) const; + constexpr bool operator!=(const uintx& other) const; + constexpr bool operator!() const; + + constexpr bool operator>(const uintx& other) const; + constexpr bool operator<(const uintx& other) const; + constexpr bool operator>=(const uintx& other) const; + constexpr bool operator<=(const uintx& other) const; + + constexpr uintx& operator+=(const uintx& other) + { + *this = *this + other; + return *this; + }; + constexpr uintx& operator-=(const uintx& other) + { + *this = *this - other; + return *this; + }; + constexpr uintx& operator*=(const uintx& other) + { + *this = *this * other; + return *this; + }; + constexpr uintx& operator/=(const uintx& other) + { + *this = *this / other; + return *this; + }; + constexpr uintx& operator%=(const uintx& other) + { + *this = *this % other; + return *this; + }; + + constexpr uintx& operator++() + { + *this += uintx(1); + return *this; + }; + constexpr uintx& operator--() + { + *this -= uintx(1); + return *this; + }; + + constexpr uintx& operator&=(const uintx& other) + { + *this = *this & other; + return *this; + }; + constexpr uintx& operator^=(const uintx& other) + { + *this = *this ^ other; + 
return *this; + }; + constexpr uintx& operator|=(const uintx& other) + { + *this = *this | other; + return *this; + }; + + constexpr uintx& operator>>=(const uint64_t other) + { + *this = *this >> other; + return *this; + }; + constexpr uintx& operator<<=(const uint64_t other) + { + *this = *this << other; + return *this; + }; + + constexpr uintx invmod(const uintx& modulus) const; + constexpr uintx unsafe_invmod(const uintx& modulus) const; + + base_uint lo; + base_uint hi; + + constexpr std::pair divmod(const uintx& b) const; +}; + +template inline std::ostream& operator<<(std::ostream& os, uintx const& a) +{ + os << a.lo << ", " << a.hi << std::endl; + return os; +} + +using uint512_t = uintx; +using uint1024_t = uintx; + +} // namespace bb::numeric + +#include "./uintx_impl.hpp" + +using bb::numeric::uint1024_t; // NOLINT +using bb::numeric::uint512_t; // NOLINT diff --git a/sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx_impl.hpp b/sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx_impl.hpp new file mode 100644 index 0000000..4efea35 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/numeric/uintx/uintx_impl.hpp @@ -0,0 +1,339 @@ +#pragma once +#include "./uintx.hpp" +#include "../../common/assert.hpp" + +namespace bb::numeric { +template +constexpr std::pair, uintx> uintx::divmod(const uintx& b) const +{ + ASSERT(b != 0); + if (*this == 0) { + return { uintx(0), uintx(0) }; + } + if (b == 1) { + return { *this, uintx(0) }; + } + if (*this == b) { + return { uintx(1), uintx(0) }; + } + if (b > *this) { + return { uintx(0), *this }; + } + + uintx quotient(0); + uintx remainder = *this; + + uint64_t bit_difference = get_msb() - b.get_msb(); + + uintx divisor = b << bit_difference; + uintx accumulator = uintx(1) << bit_difference; + + // if the divisor is bigger than the remainder, a and b have the same bit length + if (divisor > remainder) { + divisor >>= 1; + accumulator >>= 1; + } + + // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder, + // and add to the quotient + while (remainder >= b) { + + // we've shunted 'divisor' up to have the same bit length as our remainder. + // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b + if (remainder >= divisor) { + remainder -= divisor; + // we can use OR here instead of +, as + // accumulator is always a nice power of two + quotient |= accumulator; + } + divisor >>= 1; + accumulator >>= 1; + } + + return std::make_pair(quotient, remainder); +} + +/** + * Computes invmod. Only for internal usage within the class. + * This is an insecure version of the algorithm that doesn't take into account the 0 case and cases when modulus is + *close to the top margin. + * + * @param modulus The modulus of the ring + * + * @return The inverse of *this modulo modulus + **/ +template constexpr uintx uintx::unsafe_invmod(const uintx& modulus) const +{ + + uintx t1 = 0; + uintx t2 = 1; + uintx r2 = (*this > modulus) ? *this % modulus : *this; + uintx r1 = modulus; + uintx q = 0; + while (r2 != 0) { + q = r1 / r2; + uintx temp_t1 = t1; + uintx temp_r1 = r1; + t1 = t2; + t2 = temp_t1 - q * t2; + r1 = r2; + r2 = temp_r1 - q * r2; + } + + if (t1 > modulus) { + return modulus + t1; + } + return t1; +} + +/** + * Computes the inverse of *this, modulo modulus, via the extended Euclidean algorithm. 
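Below is a standalone sketch of the extended Euclidean iteration that unsafe_invmod above implements, written with signed 64-bit integers for clarity (the class version instead uses unsigned wrap-around and the final `modulus + t1` fix-up); invmod_sketch is an illustrative name and the sketch assumes gcd(a, m) = 1.

// Standalone sketch: extended Euclidean inverse. Maintains the invariant
// r_i == t_i * a (mod m); when the remainder reaches gcd(a, m) == 1, the
// coefficient t1 is the inverse of a modulo m.
#include <cassert>
#include <cstdint>

static int64_t invmod_sketch(int64_t a, int64_t m)
{
    int64_t t1 = 0, t2 = 1; // Bezout coefficients paired with r1, r2
    int64_t r1 = m, r2 = a % m;
    while (r2 != 0) {
        const int64_t q = r1 / r2;
        const int64_t next_t = t1 - q * t2;
        const int64_t next_r = r1 - q * r2;
        t1 = t2; t2 = next_t;
        r1 = r2; r2 = next_r;
    }
    // r1 == gcd(a, m) == 1 here, so a * t1 == 1 (mod m).
    return t1 < 0 ? t1 + m : t1;
}

int main()
{
    const int64_t m = 1000003; // a prime, so every 0 < a < m is invertible
    for (int64_t a = 1; a < 50; ++a) {
        const int64_t inv = invmod_sketch(a, m);
        assert((a * inv) % m == 1);
    }
}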
+ * + * Delegates to appropriate unsafe_invmod (if the modulus is close to uintx top margin there is a need to expand) + * + * @param modulus The modulus + * @return The inverse of *this modulo modulus + **/ +template constexpr uintx uintx::invmod(const uintx& modulus) const +{ + ASSERT((*this) != 0); + if (modulus == 0) { + return 0; + } + if (modulus.get_msb() >= (2 * base_uint::length() - 1)) { + uintx> a_expanded(*this); + uintx> modulus_expanded(modulus); + return a_expanded.unsafe_invmod(modulus_expanded).lo; + } + return this->unsafe_invmod(modulus); +} + +/** + * Viewing `this` as a bit string, and counting bits from 0, slices a substring. + * @returns the uintx equal to the substring of bits from (and including) the `start`-th bit, to (but excluding) the + * `end`-th bit of `this`. + */ +template +constexpr uintx uintx::slice(const uint64_t start, const uint64_t end) const +{ + const uint64_t range = end - start; + const uintx mask = range == base_uint::length() ? -uintx(1) : (uintx(1) << range) - 1; + return ((*this) >> start) & mask; +} + +template constexpr bool uintx::get_bit(const uint64_t bit_index) const +{ + if (bit_index >= base_uint::length()) { + return hi.get_bit(bit_index - base_uint::length()); + } + return lo.get_bit(bit_index); +} + +template constexpr uint64_t uintx::get_msb() const +{ + uint64_t hi_idx = hi.get_msb(); + uint64_t lo_idx = lo.get_msb(); + return (hi_idx || (hi > base_uint(0))) ? (hi_idx + base_uint::length()) : lo_idx; +} + +template constexpr uintx uintx::operator+(const uintx& other) const +{ + base_uint res_lo = lo + other.lo; + bool carry = res_lo < lo; + base_uint res_hi = hi + other.hi + ((carry) ? base_uint(1) : base_uint(0)); + return { res_lo, res_hi }; +}; + +template constexpr uintx uintx::operator-(const uintx& other) const +{ + base_uint res_lo = lo - other.lo; + bool borrow = res_lo > lo; + base_uint res_hi = hi - other.hi - ((borrow) ? base_uint(1) : base_uint(0)); + return { res_lo, res_hi }; +} + +template constexpr uintx uintx::operator-() const +{ + return uintx(0) - *this; +} + +template constexpr uintx uintx::operator*(const uintx& other) const +{ + const auto lolo = lo.mul_extended(other.lo); + const auto lohi = lo.mul_extended(other.hi); + const auto hilo = hi.mul_extended(other.lo); + + base_uint top = lolo.second + hilo.first + lohi.first; + base_uint bottom = lolo.first; + return { bottom, top }; +} + +template +constexpr std::pair, uintx> uintx::mul_extended(const uintx& other) const +{ + const auto lolo = lo.mul_extended(other.lo); + const auto lohi = lo.mul_extended(other.hi); + const auto hilo = hi.mul_extended(other.lo); + const auto hihi = hi.mul_extended(other.hi); + + base_uint t0 = lolo.first; + base_uint t1 = lolo.second; + base_uint t2 = hilo.second; + base_uint t3 = hihi.second; + base_uint t2_carry(0); + base_uint t3_carry(0); + t1 += hilo.first; + t2_carry += (t1 < hilo.first ? base_uint(1) : base_uint(0)); + t1 += lohi.first; + t2_carry += (t1 < lohi.first ? base_uint(1) : base_uint(0)); + t2 += lohi.second; + t3_carry += (t2 < lohi.second ? base_uint(1) : base_uint(0)); + t2 += hihi.first; + t3_carry += (t2 < hihi.first ? base_uint(1) : base_uint(0)); + t2 += t2_carry; + t3_carry += (t2 < t2_carry ? 
base_uint(1) : base_uint(0)); + t3 += t3_carry; + return { uintx(t0, t1), uintx(t2, t3) }; +} + +template constexpr uintx uintx::operator/(const uintx& other) const +{ + return divmod(other).first; +} + +template constexpr uintx uintx::operator%(const uintx& other) const +{ + return divmod(other).second; +} +// 0x2af0296feca4188a80fd373ebe3c64da87a232934abb3a99f9c4cd59e6758a65 +// 0x1182c6cdb54193b51ca27c1932b95c82bebac691e3996e5ec5e1d4395f3023e3 +template constexpr uintx uintx::operator&(const uintx& other) const +{ + return { lo & other.lo, hi & other.hi }; +} + +template constexpr uintx uintx::operator^(const uintx& other) const +{ + return { lo ^ other.lo, hi ^ other.hi }; +} + +template constexpr uintx uintx::operator|(const uintx& other) const +{ + return { lo | other.lo, hi | other.hi }; +} + +template constexpr uintx uintx::operator~() const +{ + return { ~lo, ~hi }; +} + +template constexpr bool uintx::operator==(const uintx& other) const +{ + return ((lo == other.lo) && (hi == other.hi)); +} + +template constexpr bool uintx::operator!=(const uintx& other) const +{ + return !(*this == other); +} + +template constexpr bool uintx::operator!() const +{ + return *this == uintx(0ULL); +} + +template constexpr bool uintx::operator>(const uintx& other) const +{ + bool hi_gt = hi > other.hi; + bool lo_gt = lo > other.lo; + + bool gt = (hi_gt) || (lo_gt && (hi == other.hi)); + return gt; +} + +template constexpr bool uintx::operator>=(const uintx& other) const +{ + return (*this > other) || (*this == other); +} + +template constexpr bool uintx::operator<(const uintx& other) const +{ + return other > *this; +} + +template constexpr bool uintx::operator<=(const uintx& other) const +{ + return (*this < other) || (*this == other); +} + +template constexpr uintx uintx::operator>>(const uint64_t other) const +{ + const uint64_t total_shift = other; + if (total_shift >= length()) { + return uintx(0); + } + if (total_shift == 0) { + return *this; + } + const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb()); + + const uint64_t limb_shift = total_shift & static_cast(base_uint::length() - 1); + + std::array shifted_limbs = { 0, 0 }; + if (limb_shift == 0) { + shifted_limbs[0] = lo; + shifted_limbs[1] = hi; + } else { + const uint64_t remainder_shift = static_cast(base_uint::length()) - limb_shift; + + shifted_limbs[1] = hi >> limb_shift; + + base_uint remainder = (hi) << remainder_shift; + + shifted_limbs[0] = (lo >> limb_shift) + remainder; + } + uintx result(0); + if (num_shifted_limbs == 0) { + result.hi = shifted_limbs[1]; + result.lo = shifted_limbs[0]; + } else { + result.lo = shifted_limbs[1]; + } + return result; +} + +template constexpr uintx uintx::operator<<(const uint64_t other) const +{ + const uint64_t total_shift = other; + if (total_shift >= length()) { + return uintx(0); + } + if (total_shift == 0) { + return *this; + } + const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb()); + const uint64_t limb_shift = total_shift & static_cast(base_uint::length() - 1); + + std::array shifted_limbs = { 0, 0 }; + if (limb_shift == 0) { + shifted_limbs[0] = lo; + shifted_limbs[1] = hi; + } else { + const uint64_t remainder_shift = static_cast(base_uint::length()) - limb_shift; + + shifted_limbs[0] = lo << limb_shift; + + base_uint remainder = lo >> remainder_shift; + + shifted_limbs[1] = (hi << limb_shift) + remainder; + } + uintx result(0); + if (num_shifted_limbs == 0) { + result.hi = shifted_limbs[1]; + result.lo = 
shifted_limbs[0]; + } else { + result.hi = shifted_limbs[0]; + } + return result; +} +} // namespace bb::numeric \ No newline at end of file diff --git a/sumcheck/src/cuda/includes/barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp b/sumcheck/src/cuda/includes/barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp new file mode 100644 index 0000000..344c048 --- /dev/null +++ b/sumcheck/src/cuda/includes/barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp @@ -0,0 +1,27 @@ +/** + * @brief Defines particular circuit builder types expected to be used for circuit +construction in stdlib and contains macros for explicit instantiation. + * + * @details This file is designed to be included in header files to instruct the compiler that these classes exist and + * their instantiation will eventually take place. Given it has no dependencies, it causes no additional compilation or + * propagation. + */ +#pragma once +#include + +namespace bb { +class StandardFlavor; +class UltraFlavor; +class Bn254FrParams; +class Bn254FqParams; +template struct alignas(32) field; +template class UltraArith; +template class StandardCircuitBuilder_; +using StandardCircuitBuilder = StandardCircuitBuilder_>; +using StandardGrumpkinCircuitBuilder = StandardCircuitBuilder_>; +template class UltraCircuitBuilder_; +using UltraCircuitBuilder = UltraCircuitBuilder_>>; +template class MegaCircuitBuilder_; +using MegaCircuitBuilder = MegaCircuitBuilder_>; +class CircuitSimulatorBN254; +} // namespace bb
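One idiom recurs throughout the limb arithmetic in these headers: carries and borrows are detected purely from unsigned comparisons, since for unsigned types (a + b) < a holds exactly when the addition wrapped, and (a - b) > a exactly when it borrowed (see the addc and sbb helpers above). Below is a standalone sketch that checks this exhaustively on 8-bit operands against 16-bit reference arithmetic.

// Standalone sketch: wrap-around carry/borrow detection on uint8_t.
#include <cassert>
#include <cstdint>

int main()
{
    for (unsigned a = 0; a < 256; ++a) {
        for (unsigned b = 0; b < 256; ++b) {
            const uint8_t sum = static_cast<uint8_t>(a + b);
            const bool carry = sum < static_cast<uint8_t>(a);
            assert(carry == (a + b > 0xff)); // carry out of 8 bits

            const uint8_t diff = static_cast<uint8_t>(a - b);
            const bool borrow = diff > static_cast<uint8_t>(a);
            assert(borrow == (b > a)); // borrow into 8 bits
        }
    }
}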