From 9d8332b6b8620bb89564bbcdebab96a97a3fd4fc Mon Sep 17 00:00:00 2001 From: rickwebiii Date: Tue, 6 Jun 2023 14:06:09 -0700 Subject: [PATCH] WIP --- sunscreen_math/src/opencl_impl/radix_sort.rs | 27 +++++----- .../shaders/ristrettopoint_prefixsum.cl | 49 +++++++++---------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/sunscreen_math/src/opencl_impl/radix_sort.rs b/sunscreen_math/src/opencl_impl/radix_sort.rs index 7a54881b3..4c18ba7bd 100644 --- a/sunscreen_math/src/opencl_impl/radix_sort.rs +++ b/sunscreen_math/src/opencl_impl/radix_sort.rs @@ -363,6 +363,7 @@ pub fn radix_sort_2( #[cfg(test)] mod tests { use curve25519_dalek::traits::Identity; + use rand::thread_rng; use crate::{GpuRistrettoPointVec, GpuVec, RistrettoPointVec}; @@ -497,17 +498,18 @@ mod tests { #[test] fn can_prefix_sum_blocks_ristretto() { - let cols = 128u32; + let cols = 2u32; let rows = 1; let data = (0..cols) - .map(|x| RistrettoPoint::identity()) + .map(|x| { + //RistrettoPoint::identity() + RistrettoPoint::random(&mut thread_rng()) + }) .collect::>(); //let data = [data.clone(), data.clone(), data].concat(); - let runtime = Runtime::get(); - let data_gpu = RistrettoPointVec::new(&data); let (mut prefix_sums, mut block_totals, actual_num_blocks) = @@ -555,23 +557,24 @@ mod tests { assert_eq!(sums_row.len(), data_row.len()); - for (c_id, (res_chunk, data_chunk)) in sums_row - .chunks(THREADS_PER_GROUP) - .zip(data_row.chunks(THREADS_PER_GROUP)) + for (c_id, (actual_chunk, expected_chunk)) in sums_row + .chunks(2 * RistrettoPoint::LOCAL_THREADS) + .zip(data_row.chunks(2 * RistrettoPoint::LOCAL_THREADS)) .enumerate() { // Check that the block totals match - let expected_sum = data_chunk + let expected_sum = expected_chunk .iter() .fold(RistrettoPoint::identity(), |s, x| s + x); - let actual = block_totals[row as usize * expected_num_blocks + c_id]; + let actual_sum = block_totals[row as usize * expected_num_blocks + c_id]; - assert_eq!(actual, expected_sum); + dbg!(actual_sum); + assert_eq!(actual_sum.compress(), expected_sum.compress()); // Serially compute the chunk's prefix sum and check that the // prefix sum matches. - let mut data_chunk = data_chunk.to_owned(); + let mut data_chunk = expected_chunk.to_owned(); let mut sum = RistrettoPoint::identity(); for i in data_chunk.iter_mut() { @@ -581,7 +584,7 @@ mod tests { sum += val; } - for (a, e) in data_chunk.iter().zip(res_chunk.iter()) { + for (a, e) in data_chunk.iter().zip(actual_chunk.iter()) { assert_eq!(a.compress(), e.compress()); } } diff --git a/sunscreen_math/src/opencl_impl/shaders/ristrettopoint_prefixsum.cl b/sunscreen_math/src/opencl_impl/shaders/ristrettopoint_prefixsum.cl index 22a8db4c8..18c03f9da 100644 --- a/sunscreen_math/src/opencl_impl/shaders/ristrettopoint_prefixsum.cl +++ b/sunscreen_math/src/opencl_impl/shaders/ristrettopoint_prefixsum.cl @@ -10,14 +10,14 @@ /// the rows of a matrix in shared memory. This is potentially future work to /// improve performance. RistrettoPoint local_prefix_sum_ristretto( - local u32* restrict data, + local RistrettoPoint* restrict data, u32 log_len ) { u32 local_id = get_local_id(0); - u32 len = 0x1 << log_len; + u32 len = 0x1 << (log_len + 1); // Up sweep - for (u32 i = 0; i < log_len; i++) { + for (u32 i = 0; i < 1; i++) {//log_len; i++) { u32 two_n = 0x1 << i; u32 two_n_plus_1 = 0x1 << (i + 1); @@ -26,29 +26,28 @@ RistrettoPoint local_prefix_sum_ristretto( u32 idx_2 = k + two_n_plus_1 - 1; if (idx_1 < len && idx_2 < len) { - RistrettoPoint a = RistrettoPoint_unpack_local(data, idx_1, len); - RistrettoPoint b = RistrettoPoint_unpack_local(data, idx_2, len); + RistrettoPoint a = data[idx_1]; + RistrettoPoint b = data[idx_2]; RistrettoPoint c = RistrettoPoint_add(&a, &b); - - RistrettoPoint_pack_local(&c, data, idx_2, len); } + //data[local_id].X.limbs[0] = idx_2; + //data[local_id].X.limbs[1] = 0; + barrier(CLK_LOCAL_MEM_FENCE); } - + barrier(CLK_LOCAL_MEM_FENCE); // The last element after up sweeping contains the sum of // all inputs. Write this to the block_totals. - //RistrettoPoint sum = RistrettoPoint_unpack_local(data, len - 1, len); - RistrettoPoint sum; - + RistrettoPoint sum = data[len - 1]; +/* // Down sweep if (local_id == 0) { RistrettoPoint identity = RistrettoPoint_IDENTITY; - - RistrettoPoint_pack_local(&identity, data, len - 1, len); + data[len - 1] = identity; } barrier(CLK_LOCAL_MEM_FENCE); @@ -64,18 +63,18 @@ RistrettoPoint local_prefix_sum_ristretto( u32 idx_2 = k + two_d_plus_1 - 1; if (idx_1 < len && idx_2 < len) { - RistrettoPoint t = RistrettoPoint_unpack_local(data, idx_1, len); - RistrettoPoint a = RistrettoPoint_unpack_local(data, idx_2, len); + RistrettoPoint t = data[idx_1]; + RistrettoPoint a = data[idx_2]; - RistrettoPoint_pack_local(&a, data, idx_1, len); + data[idx_1] = a; a = RistrettoPoint_add(&t, &a); - RistrettoPoint_pack_local(&a, data, idx_2, len); + data[idx_2] = a; } barrier(CLK_LOCAL_MEM_FENCE); - } + }*/ return sum; } @@ -102,7 +101,7 @@ kernel void prefix_sum_blocks_ristretto( u32 row_id = get_global_id(1); // TODO: Prevent bank conflicts - local u32 values_local[LOCAL_WORDS]; + local RistrettoPoint values_local[0x1 << (LOG_THREADS_PER_GROUP + 1)]; if (global_id < len) { RistrettoPoint val = RistrettoPoint_unpack( @@ -111,18 +110,18 @@ kernel void prefix_sum_blocks_ristretto( len ); - RistrettoPoint_pack_local(&val, values_local, local_id, LOCAL_LEN); + values_local[local_id] = val; } else { RistrettoPoint identity = RistrettoPoint_IDENTITY; - RistrettoPoint_pack_local(&identity, values_local, local_id, LOCAL_LEN); + values_local[local_id] = identity; } barrier(CLK_LOCAL_MEM_FENCE); RistrettoPoint sum = local_prefix_sum_ristretto( values_local, - LOG_THREADS_PER_GROUP + 1 + LOG_THREADS_PER_GROUP ); // TIL, multiple GPU threads writing to the same memory address is @@ -142,11 +141,7 @@ kernel void prefix_sum_blocks_ristretto( } if (global_id < len) { - RistrettoPoint val = RistrettoPoint_unpack_local( - values_local, - local_id, - LOCAL_LEN - ); + RistrettoPoint val = values_local[local_id]; RistrettoPoint_pack( &val,