mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 15:38:01 -05:00
Optimizing polynomial evaluation kernel
This commit is contained in:
@@ -166,6 +166,19 @@ template <class Params_> struct alignas(32) field {
|
||||
// Builds the field modulus as a uint256_t from the limbs modulus_0..modulus_3
// declared in Params. Usable in constant expressions and from device code.
static constexpr __device__ uint256_t get_modulus() {
    constexpr uint256_t modulus{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
    return modulus;
}
|
||||
// Returns the result of applying uint256_t's unary minus to the field modulus.
// Computed entirely at compile time from get_modulus().
static constexpr __device__ uint256_t get_not_modulus() {
    constexpr uint256_t p = get_modulus();
    constexpr uint256_t negated = -p;
    return negated;
}
|
||||
|
||||
// Returns modulus + modulus using uint256_t addition.
// Computed entirely at compile time from get_modulus().
static constexpr __device__ uint256_t get_twice_modulus() {
    constexpr uint256_t p = get_modulus();
    constexpr uint256_t doubled = p + p;
    return doubled;
}
|
||||
// Returns the result of applying uint256_t's unary minus to get_twice_modulus().
// Computed entirely at compile time.
static constexpr __device__ uint256_t get_twice_not_modulus() {
    constexpr uint256_t doubled = get_twice_modulus();
    constexpr uint256_t negated = -doubled;
    return negated;
}
|
||||
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
|
||||
static constexpr uint256_t r_squared_uint{
|
||||
Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3
|
||||
@@ -479,10 +492,6 @@ template <class Params_> struct alignas(32) field {
|
||||
|
||||
static constexpr field multiplicative_generator() noexcept;
|
||||
|
||||
static constexpr uint256_t twice_modulus = get_modulus() + get_modulus();
|
||||
static constexpr uint256_t not_modulus = -get_modulus();
|
||||
static constexpr uint256_t twice_not_modulus = -twice_modulus;
|
||||
|
||||
struct wnaf_table {
|
||||
uint8_t windows[64]; // NOLINT
|
||||
|
||||
|
||||
@@ -207,6 +207,7 @@ template <class T> __device__ constexpr void field<T>::self_neg() noexcept
|
||||
{
|
||||
BB_OP_COUNT_TRACK_NAME("fr::self_neg");
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
constexpr uint256_t twice_modulus = get_twice_modulus();
|
||||
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
|
||||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
|
||||
constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };
|
||||
|
||||
@@ -180,6 +180,7 @@ constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
|
||||
template <class T> __device__ constexpr field<T> field<T>::reduce() const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
constexpr uint256_t not_modulus = -modulus;
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
uint256_t val{ data[0], data[1], data[2], data[3] };
|
||||
if (val >= modulus) {
|
||||
@@ -205,6 +206,8 @@ template <class T> __device__ constexpr field<T> field<T>::reduce() const noexce
|
||||
|
||||
template <class T> __device__ constexpr field<T> field<T>::add(const field& other) const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
constexpr uint256_t twice_not_modulus = get_twice_not_modulus();
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
uint64_t r0 = data[0] + other.data[0];
|
||||
uint64_t c = r0 < data[0];
|
||||
@@ -287,6 +290,7 @@ template <class T> __device__ constexpr field<T> field<T>::subtract(const field&
|
||||
template <class T> __device__ constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
|
||||
{
|
||||
constexpr uint256_t modulus = get_modulus();
|
||||
constexpr uint256_t twice_modulus = get_twice_modulus();
|
||||
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
|
||||
return subtract(other);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "../includes/barretenberg/ecc/curves/bn254/fr.cuh"
|
||||
#include <stdio.h>
|
||||
|
||||
using namespace bb;
|
||||
|
||||
@@ -17,3 +18,40 @@ extern "C" __global__ void evaluate(fr* coeffs, fr* point, uint8_t num_vars, fr*
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Evaluates the monomial terms of a multilinear polynomial at `point` and
// reduces them to one partial sum per thread block, in place in `monomial_evals`.
//
// Each thread with global index idx < 2^num_vars computes
//     coeffs[idx].to_montgomery_form() * prod_{i : bit i of idx is set} point[i]
// and stores it in monomial_evals[idx]. The block then performs a pairwise tree
// reduction so that monomial_evals[blockIdx.x * blockDim.x] holds the sum of the
// block's terms, and thread 0 calls self_from_montgomery_form() on that element.
// The caller is expected to combine the per-block sums afterwards.
//
// Parameters:
//   coeffs         - 2^num_vars coefficients, converted with to_montgomery_form() here.
//   point          - num_vars evaluation coordinates, multiplied in as given
//                    (presumably already in Montgomery form; verify against caller).
//   num_vars       - number of variables; must be < 64 so that 2^num_vars fits in 64 bits.
//   monomial_evals - scratch/output buffer of 2^num_vars elements.
//   result, mutex  - currently unused; kept so the launch signature stays unchanged.
//
// Launch requirements: 1D grid of 1D blocks. The pairwise reduction only covers
// every element of a block when blockDim.x is a power of two.
extern "C" __global__ void evaluate_optimized(fr* coeffs, fr* point, uint8_t num_vars, fr* monomial_evals, fr* result, int* mutex) {
    // 64-bit index arithmetic so blockIdx.x * blockDim.x cannot wrap for large inputs.
    const uint64_t num_elems = uint64_t{1} << num_vars;
    const uint64_t tid = threadIdx.x;
    const uint64_t block_base = static_cast<uint64_t>(blockIdx.x) * blockDim.x;
    const uint64_t idx = block_base + tid;

    // Phase 1: per-thread monomial evaluation. This runs unconditionally (not inside
    // the reduction loop) so it also happens when blockDim.x == 1. Threads past the
    // end of the buffer skip the work but still reach the barrier below.
    if (idx < num_elems) {
        // Accumulate in a local and store once instead of read-modify-writing
        // global memory on every iteration.
        fr term = coeffs[idx].to_montgomery_form();
        for (int i = 0; i < num_vars; i++) {
            // A clear bit contributes a factor of one, so the multiplication is skipped.
            if ((idx >> i) & 1) {
                term *= point[i];
            }
        }
        monomial_evals[idx] = term;
    }
    __syncthreads();

    // Phase 2: in-place pairwise tree reduction within the block. The barrier sits
    // outside the divergent branch so every thread of the block reaches it.
    uint64_t step_size = 1;
    for (uint64_t active = blockDim.x >> 1; active > 0; active >>= 1) {
        if (tid < active) {
            const uint64_t fst = block_base + tid * step_size * 2;
            const uint64_t snd = fst + step_size;
            // snd >= num_elems means the partner slot does not exist (partially
            // filled last block); it contributes nothing to the sum.
            if (snd < num_elems) {
                monomial_evals[fst] += monomial_evals[snd];
            }
        }
        step_size <<= 1;
        __syncthreads();
    }

    // Thread 0 owns the block's partial sum at the block base index.
    if (tid == 0 && idx < num_elems) {
        monomial_evals[idx].self_from_montgomery_form();
    }
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use ff::{Field, PrimeField};
|
||||
use field::{FromFieldBinding, ToFieldBinding};
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use std::marker::PhantomData;
|
||||
use std::process::Output;
|
||||
use std::time::Instant;
|
||||
@@ -44,7 +45,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
let now = Instant::now();
|
||||
|
||||
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
|
||||
gpu.load_ptx(ptx, "multilinear", &["evaluate"])?;
|
||||
gpu.load_ptx(ptx, "multilinear", &["evaluate", "evaluate_optimized"])?;
|
||||
|
||||
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
|
||||
|
||||
@@ -56,36 +57,38 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
// copy to GPU
|
||||
let gpu_coeffs = gpu.htod_copy(
|
||||
poly_coeffs
|
||||
.into_iter()
|
||||
.into_par_iter()
|
||||
.map(|&coeff| F::to_canonical_form(coeff))
|
||||
.collect_vec(),
|
||||
.collect()
|
||||
)?;
|
||||
let gpu_eval_point = gpu.htod_copy(point_montgomery)?;
|
||||
let monomial_evals = gpu.htod_copy(vec![FieldBinding::default(); 1 << num_vars])?;
|
||||
let mutex = unsafe { gpu.alloc_zeros::<u32>(1)? };
|
||||
let result = gpu.htod_copy(vec![FieldBinding::default(); 1])?;
|
||||
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let f = gpu.get_func("multilinear", "evaluate").unwrap();
|
||||
let evaluate_optimized = gpu.get_func("multilinear", "evaluate_optimized").unwrap();
|
||||
|
||||
unsafe {
|
||||
f.launch(
|
||||
evaluate_optimized.launch(
|
||||
LaunchConfig::for_num_elems(1 << num_vars as u32),
|
||||
(&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals),
|
||||
)
|
||||
}?;
|
||||
(&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals, &result, &mutex),
|
||||
)?;
|
||||
};
|
||||
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
// TODO: Calculate the sum on the GPU side rather than copying the monomial evaluation results back to the host
|
||||
let monomial_evals = gpu.sync_reclaim(monomial_evals)?;
|
||||
println!("Time taken to synchronize: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
let result = monomial_evals
|
||||
.into_iter()
|
||||
.step_by(1024)
|
||||
.map(|eval| F::from_montgomery_form(eval))
|
||||
.sum::<F>();
|
||||
println!("Time taken to calculate sum: {:.2?}", now.elapsed());
|
||||
@@ -130,7 +133,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_poly() -> Result<(), DriverError> {
|
||||
let num_vars = 20;
|
||||
let num_vars = 16;
|
||||
let rng = OsRng::default();
|
||||
let poly_coeffs = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
|
||||
Reference in New Issue
Block a user