mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
WIP
This commit is contained in:
@@ -363,6 +363,7 @@ pub fn radix_sort_2(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use curve25519_dalek::traits::Identity;
|
||||
use rand::thread_rng;
|
||||
|
||||
use crate::{GpuRistrettoPointVec, GpuVec, RistrettoPointVec};
|
||||
|
||||
@@ -497,17 +498,18 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn can_prefix_sum_blocks_ristretto() {
|
||||
let cols = 128u32;
|
||||
let cols = 2u32;
|
||||
let rows = 1;
|
||||
|
||||
let data = (0..cols)
|
||||
.map(|x| RistrettoPoint::identity())
|
||||
.map(|x| {
|
||||
//RistrettoPoint::identity()
|
||||
RistrettoPoint::random(&mut thread_rng())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
//let data = [data.clone(), data.clone(), data].concat();
|
||||
|
||||
let runtime = Runtime::get();
|
||||
|
||||
let data_gpu = RistrettoPointVec::new(&data);
|
||||
|
||||
let (mut prefix_sums, mut block_totals, actual_num_blocks) =
|
||||
@@ -555,23 +557,24 @@ mod tests {
|
||||
|
||||
assert_eq!(sums_row.len(), data_row.len());
|
||||
|
||||
for (c_id, (res_chunk, data_chunk)) in sums_row
|
||||
.chunks(THREADS_PER_GROUP)
|
||||
.zip(data_row.chunks(THREADS_PER_GROUP))
|
||||
for (c_id, (actual_chunk, expected_chunk)) in sums_row
|
||||
.chunks(2 * RistrettoPoint::LOCAL_THREADS)
|
||||
.zip(data_row.chunks(2 * RistrettoPoint::LOCAL_THREADS))
|
||||
.enumerate()
|
||||
{
|
||||
// Check that the block totals match
|
||||
let expected_sum = data_chunk
|
||||
let expected_sum = expected_chunk
|
||||
.iter()
|
||||
.fold(RistrettoPoint::identity(), |s, x| s + x);
|
||||
|
||||
let actual = block_totals[row as usize * expected_num_blocks + c_id];
|
||||
let actual_sum = block_totals[row as usize * expected_num_blocks + c_id];
|
||||
|
||||
assert_eq!(actual, expected_sum);
|
||||
dbg!(actual_sum);
|
||||
assert_eq!(actual_sum.compress(), expected_sum.compress());
|
||||
|
||||
// Serially compute the chunk's prefix sum and check that the
|
||||
// prefix sum matches.
|
||||
let mut data_chunk = data_chunk.to_owned();
|
||||
let mut data_chunk = expected_chunk.to_owned();
|
||||
let mut sum = RistrettoPoint::identity();
|
||||
|
||||
for i in data_chunk.iter_mut() {
|
||||
@@ -581,7 +584,7 @@ mod tests {
|
||||
sum += val;
|
||||
}
|
||||
|
||||
for (a, e) in data_chunk.iter().zip(res_chunk.iter()) {
|
||||
for (a, e) in data_chunk.iter().zip(actual_chunk.iter()) {
|
||||
assert_eq!(a.compress(), e.compress());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,14 +10,14 @@
|
||||
/// the rows of a matrix in shared memory. This is potentially future work to
|
||||
/// improve performance.
|
||||
RistrettoPoint local_prefix_sum_ristretto(
|
||||
local u32* restrict data,
|
||||
local RistrettoPoint* restrict data,
|
||||
u32 log_len
|
||||
) {
|
||||
u32 local_id = get_local_id(0);
|
||||
u32 len = 0x1 << log_len;
|
||||
u32 len = 0x1 << (log_len + 1);
|
||||
|
||||
// Up sweep
|
||||
for (u32 i = 0; i < log_len; i++) {
|
||||
for (u32 i = 0; i < 1; i++) {//log_len; i++) {
|
||||
u32 two_n = 0x1 << i;
|
||||
u32 two_n_plus_1 = 0x1 << (i + 1);
|
||||
|
||||
@@ -26,29 +26,28 @@ RistrettoPoint local_prefix_sum_ristretto(
|
||||
u32 idx_2 = k + two_n_plus_1 - 1;
|
||||
|
||||
if (idx_1 < len && idx_2 < len) {
|
||||
RistrettoPoint a = RistrettoPoint_unpack_local(data, idx_1, len);
|
||||
RistrettoPoint b = RistrettoPoint_unpack_local(data, idx_2, len);
|
||||
RistrettoPoint a = data[idx_1];
|
||||
RistrettoPoint b = data[idx_2];
|
||||
|
||||
RistrettoPoint c = RistrettoPoint_add(&a, &b);
|
||||
|
||||
RistrettoPoint_pack_local(&c, data, idx_2, len);
|
||||
}
|
||||
|
||||
//data[local_id].X.limbs[0] = idx_2;
|
||||
//data[local_id].X.limbs[1] = 0;
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// The last element after up sweeping contains the sum of
|
||||
// all inputs. Write this to the block_totals.
|
||||
//RistrettoPoint sum = RistrettoPoint_unpack_local(data, len - 1, len);
|
||||
RistrettoPoint sum;
|
||||
|
||||
RistrettoPoint sum = data[len - 1];
|
||||
/*
|
||||
// Down sweep
|
||||
if (local_id == 0) {
|
||||
RistrettoPoint identity = RistrettoPoint_IDENTITY;
|
||||
|
||||
RistrettoPoint_pack_local(&identity, data, len - 1, len);
|
||||
data[len - 1] = identity;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -64,18 +63,18 @@ RistrettoPoint local_prefix_sum_ristretto(
|
||||
u32 idx_2 = k + two_d_plus_1 - 1;
|
||||
|
||||
if (idx_1 < len && idx_2 < len) {
|
||||
RistrettoPoint t = RistrettoPoint_unpack_local(data, idx_1, len);
|
||||
RistrettoPoint a = RistrettoPoint_unpack_local(data, idx_2, len);
|
||||
RistrettoPoint t = data[idx_1];
|
||||
RistrettoPoint a = data[idx_2];
|
||||
|
||||
RistrettoPoint_pack_local(&a, data, idx_1, len);
|
||||
data[idx_1] = a;
|
||||
|
||||
a = RistrettoPoint_add(&t, &a);
|
||||
|
||||
RistrettoPoint_pack_local(&a, data, idx_2, len);
|
||||
data[idx_2] = a;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
}*/
|
||||
|
||||
return sum;
|
||||
}
|
||||
@@ -102,7 +101,7 @@ kernel void prefix_sum_blocks_ristretto(
|
||||
u32 row_id = get_global_id(1);
|
||||
|
||||
// TODO: Prevent bank conflicts
|
||||
local u32 values_local[LOCAL_WORDS];
|
||||
local RistrettoPoint values_local[0x1 << (LOG_THREADS_PER_GROUP + 1)];
|
||||
|
||||
if (global_id < len) {
|
||||
RistrettoPoint val = RistrettoPoint_unpack(
|
||||
@@ -111,18 +110,18 @@ kernel void prefix_sum_blocks_ristretto(
|
||||
len
|
||||
);
|
||||
|
||||
RistrettoPoint_pack_local(&val, values_local, local_id, LOCAL_LEN);
|
||||
values_local[local_id] = val;
|
||||
} else {
|
||||
RistrettoPoint identity = RistrettoPoint_IDENTITY;
|
||||
|
||||
RistrettoPoint_pack_local(&identity, values_local, local_id, LOCAL_LEN);
|
||||
values_local[local_id] = identity;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
RistrettoPoint sum = local_prefix_sum_ristretto(
|
||||
values_local,
|
||||
LOG_THREADS_PER_GROUP + 1
|
||||
LOG_THREADS_PER_GROUP
|
||||
);
|
||||
|
||||
// TIL, multiple GPU threads writing to the same memory address is
|
||||
@@ -142,11 +141,7 @@ kernel void prefix_sum_blocks_ristretto(
|
||||
}
|
||||
|
||||
if (global_id < len) {
|
||||
RistrettoPoint val = RistrettoPoint_unpack_local(
|
||||
values_local,
|
||||
local_id,
|
||||
LOCAL_LEN
|
||||
);
|
||||
RistrettoPoint val = values_local[local_id];
|
||||
|
||||
RistrettoPoint_pack(
|
||||
&val,
|
||||
|
||||
Reference in New Issue
Block a user