From c8c979de0dba64c2c88b5357173b511807188ce9 Mon Sep 17 00:00:00 2001 From: rickwebiii Date: Mon, 27 Feb 2023 15:47:52 -0800 Subject: [PATCH] mont reduce test --- sunscreen_math/src/webgpu_impl/scalarvec.rs | 57 +++++++++++- .../src/webgpu_impl/scalarvectest.rs | 127 ++++++++++++++++++ .../src/webgpu_impl/shaders/scalar.test.wgsl | 35 +++++ 3 files changed, 216 insertions(+), 3 deletions(-) diff --git a/sunscreen_math/src/webgpu_impl/scalarvec.rs b/sunscreen_math/src/webgpu_impl/scalarvec.rs index b507ddd74..b41907073 100644 --- a/sunscreen_math/src/webgpu_impl/scalarvec.rs +++ b/sunscreen_math/src/webgpu_impl/scalarvec.rs @@ -282,9 +282,12 @@ impl Mul<&GpuScalarVec> for &GpuScalarVec { mod tests { use rand::{thread_rng, RngCore}; - use crate::webgpu_impl::{ - scalarvectest::{mul_internal, Scalar29}, - GpuU32, Grid, + use crate::{ + webgpu_impl::{ + scalarvectest::{mul_internal, Scalar29}, + GpuU32, Grid, + }, + ScalarVec, }; use super::*; @@ -617,4 +620,52 @@ mod tests { } } } + + #[test] + fn can_montgomery_reduce() { + let a = (0..253) + .into_iter() + .map(|_| { + let vals = (0..17) + .into_iter() + .map(|_| thread_rng().next_u64()) + .collect::<Vec<_>>(); + + let vals: [u64; 17] = vals.try_into().unwrap(); + vals + }) + .collect::<Vec<_>>(); + let a_packed = a.iter().cloned().flatten().collect::<Vec<_>>(); + + let runtime = Runtime::get(); + + let a_vec = runtime.alloc_from_slice(&a_packed); + let dummy = GpuU32::new(0); + let len = GpuU32::new(a.len() as u32); + + let c = (0..253) + .into_iter() + .map(|_| Scalar::random(&mut thread_rng())) + .collect::<Vec<_>>(); + let c_gpu = ScalarVec::new(&c); + + let threadgroups = if a.len() % 128 == 0 { + a.len() / 128 + } else { + a.len() / 128 + 1 + }; + + runtime.run( + "test_montgomery_reduce", + &[&a_vec, &dummy.data, &c_gpu.data, &len.data], + &Grid::new(threadgroups as u32, 1, 1), + ); + + for (a, actual) in a.iter().zip(c_gpu.iter()) { + let expected = Scalar29::montgomery_reduce(&a).to_bytes(); + let expected = Scalar::from_bits(expected); 
+ + assert_eq!(expected, actual); + } + } } diff --git a/sunscreen_math/src/webgpu_impl/scalarvectest.rs b/sunscreen_math/src/webgpu_impl/scalarvectest.rs index 8d2c0d65b..95887d708 100644 --- a/sunscreen_math/src/webgpu_impl/scalarvectest.rs +++ b/sunscreen_math/src/webgpu_impl/scalarvectest.rs @@ -15,12 +15,64 @@ impl IndexMut<usize> for Scalar29 { } } +const L: Scalar29 = Scalar29([ + 0x1cf5d3edu32, + 0x009318d2, + 0x1de73596, + 0x1df3bd45, + 0x0000014d, + 0x00000000, + 0x00000000, + 0x00000000, + 0x00100000, +]); + impl Scalar29 { /// Return the zero scalar. pub fn zero() -> Scalar29 { Scalar29([0, 0, 0, 0, 0, 0, 0, 0, 0]) } + /// Pack the limbs of this `Scalar29` into 32 bytes. + pub fn to_bytes(&self) -> [u8; 32] { + let mut s = [0u8; 32]; + + s[0] = (self.0[0] >> 0) as u8; + s[1] = (self.0[0] >> 8) as u8; + s[2] = (self.0[0] >> 16) as u8; + s[3] = ((self.0[0] >> 24) | (self.0[1] << 5)) as u8; + s[4] = (self.0[1] >> 3) as u8; + s[5] = (self.0[1] >> 11) as u8; + s[6] = (self.0[1] >> 19) as u8; + s[7] = ((self.0[1] >> 27) | (self.0[2] << 2)) as u8; + s[8] = (self.0[2] >> 6) as u8; + s[9] = (self.0[2] >> 14) as u8; + s[10] = ((self.0[2] >> 22) | (self.0[3] << 7)) as u8; + s[11] = (self.0[3] >> 1) as u8; + s[12] = (self.0[3] >> 9) as u8; + s[13] = (self.0[3] >> 17) as u8; + s[14] = ((self.0[3] >> 25) | (self.0[4] << 4)) as u8; + s[15] = (self.0[4] >> 4) as u8; + s[16] = (self.0[4] >> 12) as u8; + s[17] = (self.0[4] >> 20) as u8; + s[18] = ((self.0[4] >> 28) | (self.0[5] << 1)) as u8; + s[19] = (self.0[5] >> 7) as u8; + s[20] = (self.0[5] >> 15) as u8; + s[21] = ((self.0[5] >> 23) | (self.0[6] << 6)) as u8; + s[22] = (self.0[6] >> 2) as u8; + s[23] = (self.0[6] >> 10) as u8; + s[24] = (self.0[6] >> 18) as u8; + s[25] = ((self.0[6] >> 26) | (self.0[7] << 3)) as u8; + s[26] = (self.0[7] >> 5) as u8; + s[27] = (self.0[7] >> 13) as u8; + s[28] = (self.0[7] >> 21) as u8; + s[29] = (self.0[8] >> 0) as u8; + s[30] = (self.0[8] >> 8) as u8; + s[31] = (self.0[8] >> 16) as u8; + + 
s + } + /// Unpack a 32 byte / 256 bit scalar into 9 29-bit limbs. pub fn from_bytes(bytes: &[u8; 32]) -> Scalar29 { let mut words = [0u32; 8]; @@ -46,6 +98,81 @@ impl Scalar29 { s } + + pub(crate) fn montgomery_reduce(limbs: &[u64; 17]) -> Scalar29 { + const LFACTOR: u32 = 0x12547e1b; + + #[inline(always)] + fn part1(sum: u64) -> (u64, u32) { + let p = (sum as u32).wrapping_mul(LFACTOR) & ((1u32 << 29) - 1); + ((sum + m(p, L[0])) >> 29, p) + } + + #[inline(always)] + fn part2(sum: u64) -> (u64, u32) { + let w = (sum as u32) & ((1u32 << 29) - 1); + (sum >> 29, w) + } + + // note: l5,l6,l7 are zero, so their multiplies can be skipped + let l = &L; + + // the first half computes the Montgomery adjustment factor n, and begins adding n*l to make limbs divisible by R + let (carry, n0) = part1(limbs[0]); + let (carry, n1) = part1(carry + limbs[1] + m(n0, l[1])); + let (carry, n2) = part1(carry + limbs[2] + m(n0, l[2]) + m(n1, l[1])); + let (carry, n3) = part1(carry + limbs[3] + m(n0, l[3]) + m(n1, l[2]) + m(n2, l[1])); + let (carry, n4) = + part1(carry + limbs[4] + m(n0, l[4]) + m(n1, l[3]) + m(n2, l[2]) + m(n3, l[1])); + let (carry, n5) = + part1(carry + limbs[5] + m(n1, l[4]) + m(n2, l[3]) + m(n3, l[2]) + m(n4, l[1])); + let (carry, n6) = + part1(carry + limbs[6] + m(n2, l[4]) + m(n3, l[3]) + m(n4, l[2]) + m(n5, l[1])); + let (carry, n7) = + part1(carry + limbs[7] + m(n3, l[4]) + m(n4, l[3]) + m(n5, l[2]) + m(n6, l[1])); + let (carry, n8) = part1( + carry + limbs[8] + m(n0, l[8]) + m(n4, l[4]) + m(n5, l[3]) + m(n6, l[2]) + m(n7, l[1]), + ); + + // limbs is divisible by R now, so we can divide by R by simply storing the upper half as the result + let (carry, r0) = part2( + carry + limbs[9] + m(n1, l[8]) + m(n5, l[4]) + m(n6, l[3]) + m(n7, l[2]) + m(n8, l[1]), + ); + let (carry, r1) = + part2(carry + limbs[10] + m(n2, l[8]) + m(n6, l[4]) + m(n7, l[3]) + m(n8, l[2])); + let (carry, r2) = part2(carry + limbs[11] + m(n3, l[8]) + m(n7, l[4]) + m(n8, l[3])); + let (carry, r3) 
= part2(carry + limbs[12] + m(n4, l[8]) + m(n8, l[4])); + let (carry, r4) = part2(carry + limbs[13] + m(n5, l[8])); + let (carry, r5) = part2(carry + limbs[14] + m(n6, l[8])); + let (carry, r6) = part2(carry + limbs[15] + m(n7, l[8])); + let (carry, r7) = part2(carry + limbs[16] + m(n8, l[8])); + let r8 = carry as u32; + + // result may be >= l, so attempt to subtract l + Scalar29::sub(&Scalar29([r0, r1, r2, r3, r4, r5, r6, r7, r8]), l) + } + + pub fn sub(a: &Scalar29, b: &Scalar29) -> Scalar29 { + let mut difference = Scalar29::zero(); + let mask = (1u32 << 29) - 1; + + // a - b + let mut borrow: u32 = 0; + for i in 0..9 { + borrow = a[i].wrapping_sub(b[i] + (borrow >> 31)); + difference[i] = borrow & mask; + } + + // conditionally add l if the difference is negative + let underflow_mask = ((borrow >> 31) ^ 1).wrapping_sub(1); + let mut carry: u32 = 0; + for i in 0..9 { + carry = (carry >> 29) + difference[i] + (L[i] & underflow_mask); + difference[i] = carry & mask; + } + + difference + } } fn m(x: u32, y: u32) -> u64 { diff --git a/sunscreen_math/src/webgpu_impl/shaders/scalar.test.wgsl b/sunscreen_math/src/webgpu_impl/shaders/scalar.test.wgsl index f60ceda96..e83deac9a 100644 --- a/sunscreen_math/src/webgpu_impl/shaders/scalar.test.wgsl +++ b/sunscreen_math/src/webgpu_impl/shaders/scalar.test.wgsl @@ -64,6 +64,41 @@ fn test_scalar_montgomery_reduce_part2( g_c[gid.x + 2u * g_len] = b.n; } +@compute +@workgroup_size(128, 1, 1) +fn test_montgomery_reduce( + @builtin(global_invocation_id) gid: vec3<u32>, +) { + if gid.x >= g_len { + unused_b(); + return; + } + + let limbs = array<u64, 17>( + u64(g_a[34u * gid.x + 0u], g_a[34u * gid.x + 1u]), + u64(g_a[34u * gid.x + 2u], g_a[34u * gid.x + 3u]), + u64(g_a[34u * gid.x + 4u], g_a[34u * gid.x + 5u]), + u64(g_a[34u * gid.x + 6u], g_a[34u * gid.x + 7u]), + u64(g_a[34u * gid.x + 8u], g_a[34u * gid.x + 9u]), + u64(g_a[34u * gid.x + 10u], g_a[34u * gid.x + 11u]), + u64(g_a[34u * gid.x + 12u], g_a[34u * gid.x + 13u]), + u64(g_a[34u * gid.x 
+ 14u], g_a[34u * gid.x + 15u]), + u64(g_a[34u * gid.x + 16u], g_a[34u * gid.x + 17u]), + u64(g_a[34u * gid.x + 18u], g_a[34u * gid.x + 19u]), + u64(g_a[34u * gid.x + 20u], g_a[34u * gid.x + 21u]), + u64(g_a[34u * gid.x + 22u], g_a[34u * gid.x + 23u]), + u64(g_a[34u * gid.x + 24u], g_a[34u * gid.x + 25u]), + u64(g_a[34u * gid.x + 26u], g_a[34u * gid.x + 27u]), + u64(g_a[34u * gid.x + 28u], g_a[34u * gid.x + 29u]), + u64(g_a[34u * gid.x + 30u], g_a[34u * gid.x + 31u]), + u64(g_a[34u * gid.x + 32u], g_a[34u * gid.x + 33u]), + ); + + var c = scalar29_montgomery_reduce(limbs); + + scalar29_pack_c(&c, gid.x, g_len); +} + @compute @workgroup_size(128, 1, 1) fn test_scalar_mul_internal(