Optimizing polynomial evaluation kernel

This commit is contained in:
DoHoonKim8
2024-06-24 09:50:36 +00:00
committed by DoHoon Kim
parent 6b14910ba4
commit cdc22ba4c0
5 changed files with 69 additions and 14 deletions

View File

@@ -166,6 +166,19 @@ template <class Params_> struct alignas(32) field {
static constexpr __device__ uint256_t get_modulus() {
return uint256_t{ Params::modulus_0, Params::modulus_1, Params::modulus_2, Params::modulus_3 };
}
static constexpr __device__ uint256_t get_not_modulus() {
constexpr uint256_t modulus = get_modulus();
return -modulus;
}
static constexpr __device__ uint256_t get_twice_modulus() {
constexpr uint256_t modulus = get_modulus();
return modulus + modulus;
}
static constexpr __device__ uint256_t get_twice_not_modulus() {
constexpr uint256_t twice_modulus = get_twice_modulus();
return -twice_modulus;
}
#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
static constexpr uint256_t r_squared_uint{
Params_::r_squared_0, Params_::r_squared_1, Params_::r_squared_2, Params_::r_squared_3
@@ -479,10 +492,6 @@ template <class Params_> struct alignas(32) field {
static constexpr field multiplicative_generator() noexcept;
static constexpr uint256_t twice_modulus = get_modulus() + get_modulus();
static constexpr uint256_t not_modulus = -get_modulus();
static constexpr uint256_t twice_not_modulus = -twice_modulus;
struct wnaf_table {
uint8_t windows[64]; // NOLINT

View File

@@ -207,6 +207,7 @@ template <class T> __device__ constexpr void field<T>::self_neg() noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::self_neg");
constexpr uint256_t modulus = get_modulus();
constexpr uint256_t twice_modulus = get_twice_modulus();
if constexpr ((T::modulus_3 >= 0x4000000000000000ULL) ||
(T::modulus_1 == 0 && T::modulus_2 == 0 && T::modulus_3 == 0)) {
constexpr field p{ modulus.data[0], modulus.data[1], modulus.data[2], modulus.data[3] };

View File

@@ -180,6 +180,7 @@ constexpr uint64_t field<T>::square_accumulate(const uint64_t a,
template <class T> __device__ constexpr field<T> field<T>::reduce() const noexcept
{
constexpr uint256_t modulus = get_modulus();
constexpr uint256_t not_modulus = -modulus;
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
uint256_t val{ data[0], data[1], data[2], data[3] };
if (val >= modulus) {
@@ -205,6 +206,8 @@ template <class T> __device__ constexpr field<T> field<T>::reduce() const noexce
template <class T> __device__ constexpr field<T> field<T>::add(const field& other) const noexcept
{
constexpr uint256_t modulus = get_modulus();
constexpr uint256_t twice_not_modulus = get_twice_not_modulus();
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
uint64_t r0 = data[0] + other.data[0];
uint64_t c = r0 < data[0];
@@ -287,6 +290,7 @@ template <class T> __device__ constexpr field<T> field<T>::subtract(const field&
template <class T> __device__ constexpr field<T> field<T>::subtract_coarse(const field& other) const noexcept
{
constexpr uint256_t modulus = get_modulus();
constexpr uint256_t twice_modulus = get_twice_modulus();
if constexpr (modulus.data[3] >= 0x4000000000000000ULL) {
return subtract(other);
}

View File

@@ -1,4 +1,5 @@
#include "../includes/barretenberg/ecc/curves/bn254/fr.cuh"
#include <stdio.h>
using namespace bb;
@@ -17,3 +18,40 @@ extern "C" __global__ void evaluate(fr* coeffs, fr* point, uint8_t num_vars, fr*
}
return;
}
extern "C" __global__ void evaluate_optimized(fr* coeffs, fr* point, uint8_t num_vars, fr* monomial_evals, fr* result, int* mutex) {
const int tid = threadIdx.x;
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
auto step_size = 1;
int number_of_threads = blockDim.x >> 1;
bool evaluated = false;
while (number_of_threads > 0)
{
if (!evaluated) {
fr coeff = coeffs[idx].to_montgomery_form();
monomial_evals[idx] = coeff;
for (int i = 0; i < num_vars; i++) {
monomial_evals[idx] *= (((idx >> i) & 1) ? point[i] : fr::one());
}
evaluated = true;
__syncthreads();
continue;
}
if (tid < number_of_threads) // still alive?
{
const auto fst = blockIdx.x * blockDim.x + tid * step_size * 2;
const auto snd = fst + step_size;
monomial_evals[fst] += monomial_evals[snd];
}
step_size <<= 1;
number_of_threads >>= 1;
__syncthreads();
}
if (tid == 0) {
monomial_evals[idx].self_from_montgomery_form();
}
}

View File

@@ -4,6 +4,7 @@ use ff::{Field, PrimeField};
use field::{FromFieldBinding, ToFieldBinding};
use halo2curves::bn256::Fr;
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::marker::PhantomData;
use std::process::Output;
use std::time::Instant;
@@ -44,7 +45,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
let now = Instant::now();
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
gpu.load_ptx(ptx, "multilinear", &["evaluate"])?;
gpu.load_ptx(ptx, "multilinear", &["evaluate", "evaluate_optimized"])?;
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
@@ -56,36 +57,38 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
// copy to GPU
let gpu_coeffs = gpu.htod_copy(
poly_coeffs
.into_iter()
.into_par_iter()
.map(|&coeff| F::to_canonical_form(coeff))
.collect_vec(),
.collect()
)?;
let gpu_eval_point = gpu.htod_copy(point_montgomery)?;
let monomial_evals = gpu.htod_copy(vec![FieldBinding::default(); 1 << num_vars])?;
let mutex = unsafe { gpu.alloc_zeros::<u32>(1)? };
let result = gpu.htod_copy(vec![FieldBinding::default(); 1])?;
println!("Time taken to initialise data: {:.2?}", now.elapsed());
let now = Instant::now();
let f = gpu.get_func("multilinear", "evaluate").unwrap();
let evaluate_optimized = gpu.get_func("multilinear", "evaluate_optimized").unwrap();
unsafe {
f.launch(
evaluate_optimized.launch(
LaunchConfig::for_num_elems(1 << num_vars as u32),
(&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals),
)
}?;
(&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals, &result, &mutex),
)?;
};
println!("Time taken to call kernel: {:.2?}", now.elapsed());
let now = Instant::now();
// TODO : Calculate the sum in GPU side rather than copying the monomial evaluation results
let monomial_evals = gpu.sync_reclaim(monomial_evals)?;
println!("Time taken to synchronize: {:.2?}", now.elapsed());
let now = Instant::now();
let result = monomial_evals
.into_iter()
.step_by(1024)
.map(|eval| F::from_montgomery_form(eval))
.sum::<F>();
println!("Time taken to calculate sum: {:.2?}", now.elapsed());
@@ -130,7 +133,7 @@ mod tests {
#[test]
fn test_evaluate_poly() -> Result<(), DriverError> {
let num_vars = 20;
let num_vars = 16;
let rng = OsRng::default();
let poly_coeffs = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();