This commit is contained in:
rickwebiii
2023-06-06 14:06:09 -07:00
parent 4623ec97a6
commit 9d8332b6b8
2 changed files with 37 additions and 39 deletions

View File

@@ -363,6 +363,7 @@ pub fn radix_sort_2(
#[cfg(test)]
mod tests {
use curve25519_dalek::traits::Identity;
use rand::thread_rng;
use crate::{GpuRistrettoPointVec, GpuVec, RistrettoPointVec};
@@ -497,17 +498,18 @@ mod tests {
#[test]
fn can_prefix_sum_blocks_ristretto() {
let cols = 128u32;
let cols = 2u32;
let rows = 1;
let data = (0..cols)
.map(|x| RistrettoPoint::identity())
.map(|x| {
//RistrettoPoint::identity()
RistrettoPoint::random(&mut thread_rng())
})
.collect::<Vec<_>>();
//let data = [data.clone(), data.clone(), data].concat();
let runtime = Runtime::get();
let data_gpu = RistrettoPointVec::new(&data);
let (mut prefix_sums, mut block_totals, actual_num_blocks) =
@@ -555,23 +557,24 @@ mod tests {
assert_eq!(sums_row.len(), data_row.len());
for (c_id, (res_chunk, data_chunk)) in sums_row
.chunks(THREADS_PER_GROUP)
.zip(data_row.chunks(THREADS_PER_GROUP))
for (c_id, (actual_chunk, expected_chunk)) in sums_row
.chunks(2 * RistrettoPoint::LOCAL_THREADS)
.zip(data_row.chunks(2 * RistrettoPoint::LOCAL_THREADS))
.enumerate()
{
// Check that the block totals match
let expected_sum = data_chunk
let expected_sum = expected_chunk
.iter()
.fold(RistrettoPoint::identity(), |s, x| s + x);
let actual = block_totals[row as usize * expected_num_blocks + c_id];
let actual_sum = block_totals[row as usize * expected_num_blocks + c_id];
assert_eq!(actual, expected_sum);
dbg!(actual_sum);
assert_eq!(actual_sum.compress(), expected_sum.compress());
// Serially compute the chunk's prefix sum and check that the
// prefix sum matches.
let mut data_chunk = data_chunk.to_owned();
let mut data_chunk = expected_chunk.to_owned();
let mut sum = RistrettoPoint::identity();
for i in data_chunk.iter_mut() {
@@ -581,7 +584,7 @@ mod tests {
sum += val;
}
for (a, e) in data_chunk.iter().zip(res_chunk.iter()) {
for (a, e) in data_chunk.iter().zip(actual_chunk.iter()) {
assert_eq!(a.compress(), e.compress());
}
}

View File

@@ -10,14 +10,14 @@
/// the rows of a matrix in shared memory. This is potentially future work to
/// improve performance.
RistrettoPoint local_prefix_sum_ristretto(
local u32* restrict data,
local RistrettoPoint* restrict data,
u32 log_len
) {
u32 local_id = get_local_id(0);
u32 len = 0x1 << log_len;
u32 len = 0x1 << (log_len + 1);
// Up sweep
for (u32 i = 0; i < log_len; i++) {
for (u32 i = 0; i < 1; i++) {//log_len; i++) {
u32 two_n = 0x1 << i;
u32 two_n_plus_1 = 0x1 << (i + 1);
@@ -26,29 +26,28 @@ RistrettoPoint local_prefix_sum_ristretto(
u32 idx_2 = k + two_n_plus_1 - 1;
if (idx_1 < len && idx_2 < len) {
RistrettoPoint a = RistrettoPoint_unpack_local(data, idx_1, len);
RistrettoPoint b = RistrettoPoint_unpack_local(data, idx_2, len);
RistrettoPoint a = data[idx_1];
RistrettoPoint b = data[idx_2];
RistrettoPoint c = RistrettoPoint_add(&a, &b);
RistrettoPoint_pack_local(&c, data, idx_2, len);
}
//data[local_id].X.limbs[0] = idx_2;
//data[local_id].X.limbs[1] = 0;
barrier(CLK_LOCAL_MEM_FENCE);
}
barrier(CLK_LOCAL_MEM_FENCE);
// The last element after up sweeping contains the sum of
// all inputs. Write this to the block_totals.
//RistrettoPoint sum = RistrettoPoint_unpack_local(data, len - 1, len);
RistrettoPoint sum;
RistrettoPoint sum = data[len - 1];
/*
// Down sweep
if (local_id == 0) {
RistrettoPoint identity = RistrettoPoint_IDENTITY;
RistrettoPoint_pack_local(&identity, data, len - 1, len);
data[len - 1] = identity;
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -64,18 +63,18 @@ RistrettoPoint local_prefix_sum_ristretto(
u32 idx_2 = k + two_d_plus_1 - 1;
if (idx_1 < len && idx_2 < len) {
RistrettoPoint t = RistrettoPoint_unpack_local(data, idx_1, len);
RistrettoPoint a = RistrettoPoint_unpack_local(data, idx_2, len);
RistrettoPoint t = data[idx_1];
RistrettoPoint a = data[idx_2];
RistrettoPoint_pack_local(&a, data, idx_1, len);
data[idx_1] = a;
a = RistrettoPoint_add(&t, &a);
RistrettoPoint_pack_local(&a, data, idx_2, len);
data[idx_2] = a;
}
barrier(CLK_LOCAL_MEM_FENCE);
}
}*/
return sum;
}
@@ -102,7 +101,7 @@ kernel void prefix_sum_blocks_ristretto(
u32 row_id = get_global_id(1);
// TODO: Prevent bank conflicts
local u32 values_local[LOCAL_WORDS];
local RistrettoPoint values_local[0x1 << (LOG_THREADS_PER_GROUP + 1)];
if (global_id < len) {
RistrettoPoint val = RistrettoPoint_unpack(
@@ -111,18 +110,18 @@ kernel void prefix_sum_blocks_ristretto(
len
);
RistrettoPoint_pack_local(&val, values_local, local_id, LOCAL_LEN);
values_local[local_id] = val;
} else {
RistrettoPoint identity = RistrettoPoint_IDENTITY;
RistrettoPoint_pack_local(&identity, values_local, local_id, LOCAL_LEN);
values_local[local_id] = identity;
}
barrier(CLK_LOCAL_MEM_FENCE);
RistrettoPoint sum = local_prefix_sum_ristretto(
values_local,
LOG_THREADS_PER_GROUP + 1
LOG_THREADS_PER_GROUP
);
// TIL, multiple GPU threads writing to the same memory address is
@@ -142,11 +141,7 @@ kernel void prefix_sum_blocks_ristretto(
}
if (global_id < len) {
RistrettoPoint val = RistrettoPoint_unpack_local(
values_local,
local_id,
LOCAL_LEN
);
RistrettoPoint val = values_local[local_id];
RistrettoPoint_pack(
&val,