mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
mont reduce test
This commit is contained in:
@@ -282,9 +282,12 @@ impl Mul<&GpuScalarVec> for &GpuScalarVec {
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::webgpu_impl::{
|
||||
scalarvectest::{mul_internal, Scalar29},
|
||||
GpuU32, Grid,
|
||||
use crate::{
|
||||
webgpu_impl::{
|
||||
scalarvectest::{mul_internal, Scalar29},
|
||||
GpuU32, Grid,
|
||||
},
|
||||
ScalarVec,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -617,4 +620,52 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks that the `test_montgomery_reduce` GPU kernel produces the same
/// result as the CPU reference `Scalar29::montgomery_reduce` for a batch of
/// random 17-limb inputs.
#[test]
fn can_montgomery_reduce() {
    // Batch size. 253 is not a multiple of the workgroup size, so the last
    // workgroup also exercises the kernel's `gid.x >= g_len` bounds check.
    const COUNT: usize = 253;
    // Threads per workgroup; matches `@workgroup_size(128, 1, 1)` on the kernel.
    const WORKGROUP_SIZE: usize = 128;
    // The CPU reference accumulates limbs with plain `+` on `u64`, which
    // panics on overflow when overflow checks are enabled (the default for
    // test builds). Dropping the top two bits keeps every accumulated sum
    // (one 62-bit limb, a carry below 2^35 and at most six products below
    // 2^58) under 2^63, so neither implementation ever has to wrap.
    const LIMB_SHIFT: u32 = 2;

    let a = (0..COUNT)
        .map(|_| {
            let mut limbs = [0u64; 17];
            for limb in limbs.iter_mut() {
                *limb = thread_rng().next_u64() >> LIMB_SHIFT;
            }
            limbs
        })
        .collect::<Vec<_>>();
    // One contiguous buffer holding 17 consecutive u64 values per input.
    let a_packed = a.iter().flatten().copied().collect::<Vec<_>>();

    let runtime = Runtime::get();

    let a_vec = runtime.alloc_from_slice(&a_packed);
    // NOTE(review): presumably a placeholder for the kernel's second binding,
    // which the kernel only touches through `unused_b()` — confirm.
    let dummy = GpuU32::new(0);
    let len = GpuU32::new(a.len() as u32);

    // Output buffer, seeded with random scalars that the kernel is expected
    // to overwrite.
    let c = (0..COUNT)
        .map(|_| Scalar::random(&mut thread_rng()))
        .collect::<Vec<_>>();
    let c_gpu = ScalarVec::new(&c);

    // Round up so a partially filled final workgroup is still dispatched.
    let threadgroups = (a.len() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;

    runtime.run(
        "test_montgomery_reduce",
        &[&a_vec, &dummy.data, &c_gpu.data, &len.data],
        &Grid::new(threadgroups as u32, 1, 1),
    );

    for (limbs, actual) in a.iter().zip(c_gpu.iter()) {
        let expected = Scalar29::montgomery_reduce(limbs).to_bytes();
        let expected = Scalar::from_bits(expected);

        assert_eq!(expected, actual);
    }
}
|
||||
}
|
||||
|
||||
@@ -15,12 +15,64 @@ impl IndexMut<usize> for Scalar29 {
|
||||
}
|
||||
}
|
||||
|
||||
const L: Scalar29 = Scalar29([
|
||||
0x1cf5d3edu32,
|
||||
0x009318d2,
|
||||
0x1de73596,
|
||||
0x1df3bd45,
|
||||
0x0000014d,
|
||||
0x00000000,
|
||||
0x00000000,
|
||||
0x00000000,
|
||||
0x00100000,
|
||||
]);
|
||||
|
||||
impl Scalar29 {
|
||||
/// Return the zero scalar.
|
||||
pub fn zero() -> Scalar29 {
|
||||
Scalar29([0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
}
|
||||
|
||||
/// Pack the limbs of this `Scalar29` into 32 bytes.
|
||||
pub fn to_bytes(&self) -> [u8; 32] {
|
||||
let mut s = [0u8; 32];
|
||||
|
||||
s[0] = (self.0[0] >> 0) as u8;
|
||||
s[1] = (self.0[0] >> 8) as u8;
|
||||
s[2] = (self.0[0] >> 16) as u8;
|
||||
s[3] = ((self.0[0] >> 24) | (self.0[1] << 5)) as u8;
|
||||
s[4] = (self.0[1] >> 3) as u8;
|
||||
s[5] = (self.0[1] >> 11) as u8;
|
||||
s[6] = (self.0[1] >> 19) as u8;
|
||||
s[7] = ((self.0[1] >> 27) | (self.0[2] << 2)) as u8;
|
||||
s[8] = (self.0[2] >> 6) as u8;
|
||||
s[9] = (self.0[2] >> 14) as u8;
|
||||
s[10] = ((self.0[2] >> 22) | (self.0[3] << 7)) as u8;
|
||||
s[11] = (self.0[3] >> 1) as u8;
|
||||
s[12] = (self.0[3] >> 9) as u8;
|
||||
s[13] = (self.0[3] >> 17) as u8;
|
||||
s[14] = ((self.0[3] >> 25) | (self.0[4] << 4)) as u8;
|
||||
s[15] = (self.0[4] >> 4) as u8;
|
||||
s[16] = (self.0[4] >> 12) as u8;
|
||||
s[17] = (self.0[4] >> 20) as u8;
|
||||
s[18] = ((self.0[4] >> 28) | (self.0[5] << 1)) as u8;
|
||||
s[19] = (self.0[5] >> 7) as u8;
|
||||
s[20] = (self.0[5] >> 15) as u8;
|
||||
s[21] = ((self.0[5] >> 23) | (self.0[6] << 6)) as u8;
|
||||
s[22] = (self.0[6] >> 2) as u8;
|
||||
s[23] = (self.0[6] >> 10) as u8;
|
||||
s[24] = (self.0[6] >> 18) as u8;
|
||||
s[25] = ((self.0[6] >> 26) | (self.0[7] << 3)) as u8;
|
||||
s[26] = (self.0[7] >> 5) as u8;
|
||||
s[27] = (self.0[7] >> 13) as u8;
|
||||
s[28] = (self.0[7] >> 21) as u8;
|
||||
s[29] = (self.0[8] >> 0) as u8;
|
||||
s[30] = (self.0[8] >> 8) as u8;
|
||||
s[31] = (self.0[8] >> 16) as u8;
|
||||
|
||||
s
|
||||
}
|
||||
|
||||
/// Unpack a 32 byte / 256 bit scalar into 9 29-bit limbs.
|
||||
pub fn from_bytes(bytes: &[u8; 32]) -> Scalar29 {
|
||||
let mut words = [0u32; 8];
|
||||
@@ -46,6 +98,81 @@ impl Scalar29 {
|
||||
|
||||
s
|
||||
}
|
||||
|
||||
pub(crate) fn montgomery_reduce(limbs: &[u64; 17]) -> Scalar29 {
|
||||
const LFACTOR: u32 = 0x12547e1b;
|
||||
|
||||
#[inline(always)]
|
||||
fn part1(sum: u64) -> (u64, u32) {
|
||||
let p = (sum as u32).wrapping_mul(LFACTOR) & ((1u32 << 29) - 1);
|
||||
((sum + m(p, L[0])) >> 29, p)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn part2(sum: u64) -> (u64, u32) {
|
||||
let w = (sum as u32) & ((1u32 << 29) - 1);
|
||||
(sum >> 29, w)
|
||||
}
|
||||
|
||||
// note: l5,l6,l7 are zero, so their multiplies can be skipped
|
||||
let l = &L;
|
||||
|
||||
// the first half computes the Montgomery adjustment factor n, and begins adding n*l to make limbs divisible by R
|
||||
let (carry, n0) = part1(limbs[0]);
|
||||
let (carry, n1) = part1(carry + limbs[1] + m(n0, l[1]));
|
||||
let (carry, n2) = part1(carry + limbs[2] + m(n0, l[2]) + m(n1, l[1]));
|
||||
let (carry, n3) = part1(carry + limbs[3] + m(n0, l[3]) + m(n1, l[2]) + m(n2, l[1]));
|
||||
let (carry, n4) =
|
||||
part1(carry + limbs[4] + m(n0, l[4]) + m(n1, l[3]) + m(n2, l[2]) + m(n3, l[1]));
|
||||
let (carry, n5) =
|
||||
part1(carry + limbs[5] + m(n1, l[4]) + m(n2, l[3]) + m(n3, l[2]) + m(n4, l[1]));
|
||||
let (carry, n6) =
|
||||
part1(carry + limbs[6] + m(n2, l[4]) + m(n3, l[3]) + m(n4, l[2]) + m(n5, l[1]));
|
||||
let (carry, n7) =
|
||||
part1(carry + limbs[7] + m(n3, l[4]) + m(n4, l[3]) + m(n5, l[2]) + m(n6, l[1]));
|
||||
let (carry, n8) = part1(
|
||||
carry + limbs[8] + m(n0, l[8]) + m(n4, l[4]) + m(n5, l[3]) + m(n6, l[2]) + m(n7, l[1]),
|
||||
);
|
||||
|
||||
// limbs is divisible by R now, so we can divide by R by simply storing the upper half as the result
|
||||
let (carry, r0) = part2(
|
||||
carry + limbs[9] + m(n1, l[8]) + m(n5, l[4]) + m(n6, l[3]) + m(n7, l[2]) + m(n8, l[1]),
|
||||
);
|
||||
let (carry, r1) =
|
||||
part2(carry + limbs[10] + m(n2, l[8]) + m(n6, l[4]) + m(n7, l[3]) + m(n8, l[2]));
|
||||
let (carry, r2) = part2(carry + limbs[11] + m(n3, l[8]) + m(n7, l[4]) + m(n8, l[3]));
|
||||
let (carry, r3) = part2(carry + limbs[12] + m(n4, l[8]) + m(n8, l[4]));
|
||||
let (carry, r4) = part2(carry + limbs[13] + m(n5, l[8]));
|
||||
let (carry, r5) = part2(carry + limbs[14] + m(n6, l[8]));
|
||||
let (carry, r6) = part2(carry + limbs[15] + m(n7, l[8]));
|
||||
let (carry, r7) = part2(carry + limbs[16] + m(n8, l[8]));
|
||||
let r8 = carry as u32;
|
||||
|
||||
// result may be >= l, so attempt to subtract l
|
||||
Scalar29::sub(&Scalar29([r0, r1, r2, r3, r4, r5, r6, r7, r8]), l)
|
||||
}
|
||||
|
||||
pub fn sub(a: &Scalar29, b: &Scalar29) -> Scalar29 {
|
||||
let mut difference = Scalar29::zero();
|
||||
let mask = (1u32 << 29) - 1;
|
||||
|
||||
// a - b
|
||||
let mut borrow: u32 = 0;
|
||||
for i in 0..9 {
|
||||
borrow = a[i].wrapping_sub(b[i] + (borrow >> 31));
|
||||
difference[i] = borrow & mask;
|
||||
}
|
||||
|
||||
// conditionally add l if the difference is negative
|
||||
let underflow_mask = ((borrow >> 31) ^ 1).wrapping_sub(1);
|
||||
let mut carry: u32 = 0;
|
||||
for i in 0..9 {
|
||||
carry = (carry >> 29) + difference[i] + (L[i] & underflow_mask);
|
||||
difference[i] = carry & mask;
|
||||
}
|
||||
|
||||
difference
|
||||
}
|
||||
}
|
||||
|
||||
fn m(x: u32, y: u32) -> u64 {
|
||||
|
||||
@@ -64,6 +64,41 @@ fn test_scalar_montgomery_reduce_part2(
|
||||
g_c[gid.x + 2u * g_len] = b.n;
|
||||
}
|
||||
|
||||
// Test kernel: every invocation below `g_len` reads 17 two-word limbs from
// `g_a`, runs `scalar29_montgomery_reduce` on them and stores the result
// through `scalar29_pack_c`.
@compute
@workgroup_size(128, 1, 1)
fn test_montgomery_reduce(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    // Invocations past the end of the input do no work.
    if gid.x >= g_len {
        // NOTE(review): presumably references the otherwise unused `b`
        // binding so it is kept in the pipeline — confirm.
        unused_b();
        return;
    }

    // Each input occupies 34 consecutive u32 words of `g_a`: 17 pairs, each
    // pair handed to the `u64` constructor in storage order.
    let base = 34u * gid.x;

    let limbs = array<u64, 17>(
        u64(g_a[base + 0u], g_a[base + 1u]),
        u64(g_a[base + 2u], g_a[base + 3u]),
        u64(g_a[base + 4u], g_a[base + 5u]),
        u64(g_a[base + 6u], g_a[base + 7u]),
        u64(g_a[base + 8u], g_a[base + 9u]),
        u64(g_a[base + 10u], g_a[base + 11u]),
        u64(g_a[base + 12u], g_a[base + 13u]),
        u64(g_a[base + 14u], g_a[base + 15u]),
        u64(g_a[base + 16u], g_a[base + 17u]),
        u64(g_a[base + 18u], g_a[base + 19u]),
        u64(g_a[base + 20u], g_a[base + 21u]),
        u64(g_a[base + 22u], g_a[base + 23u]),
        u64(g_a[base + 24u], g_a[base + 25u]),
        u64(g_a[base + 26u], g_a[base + 27u]),
        u64(g_a[base + 28u], g_a[base + 29u]),
        u64(g_a[base + 30u], g_a[base + 31u]),
        u64(g_a[base + 32u], g_a[base + 33u]),
    );

    // `var` because `scalar29_pack_c` takes a pointer to the result.
    var c = scalar29_montgomery_reduce(limbs);

    scalar29_pack_c(&c, gid.x, g_len);
}
|
||||
|
||||
@compute
|
||||
@workgroup_size(128, 1, 1)
|
||||
fn test_scalar_mul_internal(
|
||||
|
||||
Reference in New Issue
Block a user