mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 20:37:55 -05:00
[WIP] Update eval kernel
This commit is contained in:
@@ -40,23 +40,38 @@ extern "C" __global__ void eval_by_coeff(fr* coeffs, fr* point, uint8_t num_vars
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void eval(fr* evals, fr* point, uint8_t num_vars, fr* buf) {
|
||||
// Folds the calling thread's chunk of `evals` in place and returns the result.
//
// The thread owns the `chunk_size` consecutive elements that start at
// `chunk_size * (blockIdx.x * blockDim.x + threadIdx.x)`. Every pass replaces
// the first element of each pair with `r * (second - first) + first`, where
// `r = point[point_index]`, then halves the number of live elements, doubles
// the distance between them and moves on to the next entry of `point`.
//
// Parameters:
//   evals       - table being folded; the thread's chunk is overwritten with
//                 intermediate values.
//   point       - values multiplied into each fold; one entry is consumed per
//                 pass, starting at `point_index`.
//   point_index - index into `point` used by the first pass.
//   chunk_size  - number of elements this thread folds. It is halved with a
//                 right shift, so the fold only covers the whole chunk when
//                 it is a power of two.
//
// Returns the value left at the first element of the chunk (unchanged input
// when `chunk_size` is 0 or 1).
//
// No bounds checks are performed here: the caller has to make sure that the
// chunk lies inside `evals` and that enough entries of `point` are available.
__device__ fr merge(fr* evals, fr* point, uint8_t point_index, u_int32_t chunk_size) {
    // Widen before multiplying so that neither the global thread id nor the
    // chunk offset can wrap around in 32-bit arithmetic.
    const size_t global_id = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    const size_t start = (size_t)chunk_size * global_id;

    // Distance between the first elements of two consecutive pairs.
    size_t step = 2;
    while (chunk_size > 1) {
        const u_int32_t pairs = chunk_size / 2;
        for (u_int32_t i = 0; i < pairs; i++) {
            const size_t fst = start + step * i;
            const size_t snd = fst + step / 2;
            evals[fst] = point[point_index] * (evals[snd] - evals[fst]) + evals[fst];
        }
        chunk_size >>= 1;
        step <<= 1;
        point_index++;
    }
    return evals[start];
}
|
||||
|
||||
// Kernel `eval`: in a loop, stores `merge(evals, point, i, chunk_size)` into
// `buf`, lets thread 0 of the block copy `buf` entries back into `evals`, and
// shrinks `num_threads` until it reaches 0.
//
// NOTE(review): this listing appears to contain the lines of two revisions of
// the kernel body at once. `i` and `num_threads` are each declared twice,
// `num_vars` is not a parameter of this signature, and the loop holds two
// writes to `buf`, two `memcpy` calls and two pairs of loop updates.
// Presumably only one line of each pair is meant to be live -- confirm
// against the repository before relying on this text.
extern "C" __global__ void eval(fr* evals, fr* buf, fr* point, u_int32_t size, u_int32_t chunk_size, uint8_t offset) {
|
||||
// Index of the thread inside its block.
const int tid = threadIdx.x;
|
||||
auto i = 0;
|
||||
// NOTE(review): `num_vars` is not declared by the signature above.
auto num_threads = 1 << (num_vars - 1);
|
||||
// Index of the thread across the whole grid.
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// Storing the result in a uint8_t drops any fractional part.
// Assumes chunk_size is a power of two -- TODO confirm.
uint8_t log2_chunk_size = log2(chunk_size);
|
||||
// NOTE(review): both operands are u_int32_t, so the division has already
// truncated before `ceil` is applied.
u_int32_t num_threads = ceil(size / chunk_size);
|
||||
auto i = offset;
|
||||
while (num_threads > 0) {
|
||||
// NOTE(review): the guard uses the block-local `tid`, while one of the
// writes below indexes `buf` with the grid-wide `idx`.
if (tid < num_threads) {
|
||||
buf[tid] = point[i] * (evals[2 * tid + 1] - evals[2 * tid]) + evals[2 * tid];
|
||||
buf[idx] = merge(evals, point, i, chunk_size);
|
||||
}
|
||||
// Placed after the conditional, so every thread of the block reaches it.
__syncthreads();
|
||||
// A single thread per block performs the copy back into `evals`.
if (tid == 0) {
|
||||
// The third argument is a byte count; 32 is presumably sizeof(fr) -- confirm.
memcpy(evals, buf, num_threads * 32);
|
||||
memcpy(&evals[chunk_size * blockIdx.x * blockDim.x], &buf[blockIdx.x * blockDim.x], num_threads * 32);
|
||||
}
|
||||
i++;
|
||||
num_threads >>= 1;
|
||||
i += log2_chunk_size;
|
||||
num_threads >>= log2_chunk_size;
|
||||
// Second barrier of the iteration, after the copy by thread 0.
__syncthreads();
|
||||
}
|
||||
if (tid == 0) {
|
||||
// Judging by the method name, converts buf[0] out of Montgomery form in
// place -- confirm against the definition of `fr`.
buf[0].self_from_montgomery_form();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,34 +68,55 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
.collect_vec();
|
||||
|
||||
// copy to GPU
|
||||
let evals = self.gpu.htod_copy(
|
||||
let mut evals = self.gpu.htod_copy(
|
||||
evals
|
||||
.into_par_iter()
|
||||
.map(|&eval| F::to_montgomery_form(eval))
|
||||
.collect(),
|
||||
)?;
|
||||
let gpu_eval_point = self.gpu.htod_copy(point)?;
|
||||
let buf = self
|
||||
.gpu
|
||||
.htod_copy(vec![FieldBinding::default(); 1 << (num_vars - 1)])?;
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
let eval = self.gpu.get_func("multilinear", "eval").unwrap();
|
||||
let mut num_vars = num_vars;
|
||||
let mut results = vec![];
|
||||
let mut offset = 0;
|
||||
while num_vars > 0 {
|
||||
let log2_chunk_size = 1;
|
||||
let chunk_size = 1 << log2_chunk_size;
|
||||
let (data_size_per_block, block_num) = if num_vars < 10 + log2_chunk_size {
|
||||
(1 << num_vars, 1)
|
||||
} else {
|
||||
(1 << (10 + log2_chunk_size), 1 << (num_vars - 10 - log2_chunk_size))
|
||||
};
|
||||
// each block produces single result and store to `buf`
|
||||
let buf = self.gpu.htod_copy(vec![
|
||||
FieldBinding::default();
|
||||
1 << (num_vars - log2_chunk_size)
|
||||
])?;
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
let eval = self.gpu.get_func("multilinear", "eval").unwrap();
|
||||
unsafe {
|
||||
eval.launch(
|
||||
LaunchConfig::for_num_elems((block_num << 10) as u32),
|
||||
(&evals, &buf, &gpu_eval_point, data_size_per_block, chunk_size, offset),
|
||||
)?;
|
||||
};
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
results = self.gpu.sync_reclaim(buf)?;
|
||||
println!("Time taken to synchronize: {:.2?}", now.elapsed());
|
||||
if block_num == 1 {
|
||||
break;
|
||||
} else {
|
||||
evals = self
|
||||
.gpu
|
||||
.htod_copy(results.iter().cloned().step_by(1 << 10).collect_vec())?;
|
||||
num_vars -= 10 + log2_chunk_size;
|
||||
offset += 10 + log2_chunk_size;
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
eval.launch(
|
||||
LaunchConfig::for_num_elems(1 << (num_vars - 1) as u32),
|
||||
(&evals, &gpu_eval_point, num_vars, &buf),
|
||||
)?;
|
||||
};
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
let buf = self.gpu.sync_reclaim(buf)?;
|
||||
println!("Time taken to synchronize: {:.2?}", now.elapsed());
|
||||
|
||||
Ok(F::from_canonical_form(buf[0]))
|
||||
Ok(F::from_montgomery_form(results[0]))
|
||||
}
|
||||
|
||||
pub fn eval_by_coeff(
|
||||
@@ -209,7 +230,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_eval() -> Result<(), DriverError> {
|
||||
let num_vars = 10;
|
||||
let num_vars = 18;
|
||||
let rng = OsRng::default();
|
||||
let evals = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
|
||||
Reference in New Issue
Block a user