Add Montgomery reduction test

This commit is contained in:
rickwebiii
2023-02-27 15:47:52 -08:00
parent ab22e25df2
commit c8c979de0d
3 changed files with 216 additions and 3 deletions

View File

@@ -282,9 +282,12 @@ impl Mul<&GpuScalarVec> for &GpuScalarVec {
mod tests {
use rand::{thread_rng, RngCore};
use crate::webgpu_impl::{
scalarvectest::{mul_internal, Scalar29},
GpuU32, Grid,
use crate::{
webgpu_impl::{
scalarvectest::{mul_internal, Scalar29},
GpuU32, Grid,
},
ScalarVec,
};
use super::*;
@@ -617,4 +620,52 @@ mod tests {
}
}
}
#[test]
fn can_montgomery_reduce() {
    let mut rng = thread_rng();

    // Generate 253 random 17-limb wide products, the input shape
    // montgomery_reduce expects. Fill arrays directly rather than
    // collecting a Vec and converting with try_into().
    let a: Vec<[u64; 17]> = (0..253)
        .map(|_| {
            let mut vals = [0u64; 17];
            vals.fill_with(|| rng.next_u64());
            vals
        })
        .collect();

    // Flatten into one contiguous buffer for upload to the GPU.
    let a_packed = a.iter().flatten().copied().collect::<Vec<_>>();

    let runtime = Runtime::get();
    let a_vec = runtime.alloc_from_slice(&a_packed);

    // The kernel's second binding is unused by this entry point; bind a dummy.
    let dummy = GpuU32::new(0);
    let len = GpuU32::new(a.len() as u32);

    // Output buffer: arbitrary initial scalars, overwritten by the kernel.
    let c = (0..253)
        .map(|_| Scalar::random(&mut rng))
        .collect::<Vec<_>>();

    let c_gpu = ScalarVec::new(&c);

    // Round up so every input is covered by a 128-thread workgroup.
    let threadgroups = (a.len() + 127) / 128;

    runtime.run(
        "test_montgomery_reduce",
        &[&a_vec, &dummy.data, &c_gpu.data, &len.data],
        &Grid::new(threadgroups as u32, 1, 1),
    );

    // Check each GPU result against the CPU reference implementation.
    for (a, actual) in a.iter().zip(c_gpu.iter()) {
        let expected = Scalar29::montgomery_reduce(a).to_bytes();
        let expected = Scalar::from_bits(expected);

        assert_eq!(expected, actual);
    }
}
}

View File

@@ -15,12 +15,64 @@ impl IndexMut<usize> for Scalar29 {
}
}
/// The modulus used for reduction, stored as 9 little-endian 29-bit limbs.
/// `montgomery_reduce` and `sub` conditionally subtract / add this value to
/// keep results in canonical range.
/// NOTE(review): the value appears to be the ed25519 group order
/// l = 2^252 + 27742317777372353535851937790883648493, matching
/// curve25519-dalek's u32-backend `Scalar29` constant `L` — confirm.
const L: Scalar29 = Scalar29([
0x1cf5d3edu32,
0x009318d2,
0x1de73596,
0x1df3bd45,
0x0000014d,
0x00000000,
0x00000000,
0x00000000,
0x00100000,
]);
impl Scalar29 {
/// The additive identity: a scalar whose nine limbs are all zero.
pub fn zero() -> Scalar29 {
    Scalar29([0u32; 9])
}
/// Pack the limbs of this `Scalar29` into 32 bytes.
///
/// Limbs are 29 bits wide and stored little-endian, so consecutive output
/// bytes are assembled by shifting down the current limb; a byte that
/// straddles a limb boundary ORs together the high bits of one limb with
/// the low bits of the next (shifted up to align).
pub fn to_bytes(&self) -> [u8; 32] {
let mut s = [0u8; 32];
// limb 0 covers bits 0..29; byte 3 mixes in the low 5 bits of limb 1
s[0] = (self.0[0] >> 0) as u8;
s[1] = (self.0[0] >> 8) as u8;
s[2] = (self.0[0] >> 16) as u8;
s[3] = ((self.0[0] >> 24) | (self.0[1] << 5)) as u8;
s[4] = (self.0[1] >> 3) as u8;
s[5] = (self.0[1] >> 11) as u8;
s[6] = (self.0[1] >> 19) as u8;
s[7] = ((self.0[1] >> 27) | (self.0[2] << 2)) as u8;
s[8] = (self.0[2] >> 6) as u8;
s[9] = (self.0[2] >> 14) as u8;
s[10] = ((self.0[2] >> 22) | (self.0[3] << 7)) as u8;
s[11] = (self.0[3] >> 1) as u8;
s[12] = (self.0[3] >> 9) as u8;
s[13] = (self.0[3] >> 17) as u8;
s[14] = ((self.0[3] >> 25) | (self.0[4] << 4)) as u8;
s[15] = (self.0[4] >> 4) as u8;
s[16] = (self.0[4] >> 12) as u8;
s[17] = (self.0[4] >> 20) as u8;
s[18] = ((self.0[4] >> 28) | (self.0[5] << 1)) as u8;
s[19] = (self.0[5] >> 7) as u8;
s[20] = (self.0[5] >> 15) as u8;
s[21] = ((self.0[5] >> 23) | (self.0[6] << 6)) as u8;
s[22] = (self.0[6] >> 2) as u8;
s[23] = (self.0[6] >> 10) as u8;
s[24] = (self.0[6] >> 18) as u8;
s[25] = ((self.0[6] >> 26) | (self.0[7] << 3)) as u8;
s[26] = (self.0[7] >> 5) as u8;
s[27] = (self.0[7] >> 13) as u8;
s[28] = (self.0[7] >> 21) as u8;
// limb 8 supplies the final three bytes (bits 232..256 of the scalar)
s[29] = (self.0[8] >> 0) as u8;
s[30] = (self.0[8] >> 8) as u8;
s[31] = (self.0[8] >> 16) as u8;
s
}
/// Unpack a 32 byte / 256 bit scalar into 9 29-bit limbs.
pub fn from_bytes(bytes: &[u8; 32]) -> Scalar29 {
let mut words = [0u32; 8];
@@ -46,6 +98,81 @@ impl Scalar29 {
s
}
/// Montgomery-reduce a 17-limb intermediate product (e.g. from
/// `mul_internal`) into a 9-limb `Scalar29` that is below `L`.
/// NOTE(review): appears to mirror curve25519-dalek's u32-backend
/// `Scalar29::montgomery_reduce` — confirm against that crate.
pub(crate) fn montgomery_reduce(limbs: &[u64; 17]) -> Scalar29 {
// NOTE(review): presumably LFACTOR = -L^-1 mod 2^29, so that adding
// p * L in `part1` clears the low 29 bits of `sum` — verify against
// curve25519-dalek's LFACTOR constant.
const LFACTOR: u32 = 0x12547e1b;
// Pick the Montgomery digit p that makes `sum + p*L[0]` divisible by
// 2^29, then return (carry, p).
#[inline(always)]
fn part1(sum: u64) -> (u64, u32) {
let p = (sum as u32).wrapping_mul(LFACTOR) & ((1u32 << 29) - 1);
((sum + m(p, L[0])) >> 29, p)
}
// Split `sum` into a 29-bit result limb and the remaining carry.
#[inline(always)]
fn part2(sum: u64) -> (u64, u32) {
let w = (sum as u32) & ((1u32 << 29) - 1);
(sum >> 29, w)
}
// note: l5,l6,l7 are zero, so their multiplies can be skipped
let l = &L;
// the first half computes the Montgomery adjustment factor n, and begins adding n*l to make limbs divisible by R
let (carry, n0) = part1(limbs[0]);
let (carry, n1) = part1(carry + limbs[1] + m(n0, l[1]));
let (carry, n2) = part1(carry + limbs[2] + m(n0, l[2]) + m(n1, l[1]));
let (carry, n3) = part1(carry + limbs[3] + m(n0, l[3]) + m(n1, l[2]) + m(n2, l[1]));
let (carry, n4) =
part1(carry + limbs[4] + m(n0, l[4]) + m(n1, l[3]) + m(n2, l[2]) + m(n3, l[1]));
let (carry, n5) =
part1(carry + limbs[5] + m(n1, l[4]) + m(n2, l[3]) + m(n3, l[2]) + m(n4, l[1]));
let (carry, n6) =
part1(carry + limbs[6] + m(n2, l[4]) + m(n3, l[3]) + m(n4, l[2]) + m(n5, l[1]));
let (carry, n7) =
part1(carry + limbs[7] + m(n3, l[4]) + m(n4, l[3]) + m(n5, l[2]) + m(n6, l[1]));
let (carry, n8) = part1(
carry + limbs[8] + m(n0, l[8]) + m(n4, l[4]) + m(n5, l[3]) + m(n6, l[2]) + m(n7, l[1]),
);
// limbs is divisible by R now, so we can divide by R by simply storing the upper half as the result
let (carry, r0) = part2(
carry + limbs[9] + m(n1, l[8]) + m(n5, l[4]) + m(n6, l[3]) + m(n7, l[2]) + m(n8, l[1]),
);
let (carry, r1) =
part2(carry + limbs[10] + m(n2, l[8]) + m(n6, l[4]) + m(n7, l[3]) + m(n8, l[2]));
let (carry, r2) = part2(carry + limbs[11] + m(n3, l[8]) + m(n7, l[4]) + m(n8, l[3]));
let (carry, r3) = part2(carry + limbs[12] + m(n4, l[8]) + m(n8, l[4]));
let (carry, r4) = part2(carry + limbs[13] + m(n5, l[8]));
let (carry, r5) = part2(carry + limbs[14] + m(n6, l[8]));
let (carry, r6) = part2(carry + limbs[15] + m(n7, l[8]));
let (carry, r7) = part2(carry + limbs[16] + m(n8, l[8]));
let r8 = carry as u32;
// result may be >= l, so attempt to subtract l
Scalar29::sub(&Scalar29([r0, r1, r2, r3, r4, r5, r6, r7, r8]), l)
}
/// Compute `a - b` modulo `L`, branch-free.
pub fn sub(a: &Scalar29, b: &Scalar29) -> Scalar29 {
    let mask = (1u32 << 29) - 1;
    let mut out = Scalar29::zero();

    // Subtract limb-wise; after each step the top bit of `borrow`
    // records whether this limb underflowed.
    let mut borrow = 0u32;
    for i in 0..9 {
        borrow = a[i].wrapping_sub(b[i] + (borrow >> 31));
        out[i] = borrow & mask;
    }

    // All-ones when the final subtraction underflowed, zero otherwise;
    // used to add L back without branching.
    let underflow_mask = ((borrow >> 31) ^ 1).wrapping_sub(1);
    let mut carry = 0u32;
    for i in 0..9 {
        carry = (carry >> 29) + out[i] + (L[i] & underflow_mask);
        out[i] = carry & mask;
    }

    out
}
}
fn m(x: u32, y: u32) -> u64 {

View File

@@ -64,6 +64,41 @@ fn test_scalar_montgomery_reduce_part2(
g_c[gid.x + 2u * g_len] = b.n;
}
@compute
@workgroup_size(128, 1, 1)
fn test_montgomery_reduce(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    // One invocation per input; invocations past the end do nothing.
    if gid.x >= g_len {
        unused_b();
        return;
    }

    // Each input spans 34 consecutive u32 words: 17 limbs x 2 words per u64.
    let base = 34u * gid.x;

    var limbs: array<u64, 17>;
    for (var i = 0u; i < 17u; i = i + 1u) {
        limbs[i] = u64(g_a[base + 2u * i], g_a[base + 2u * i + 1u]);
    }

    var c = scalar29_montgomery_reduce(limbs);
    scalar29_pack_c(&c, gid.x, g_len);
}
@compute
@workgroup_size(128, 1, 1)
fn test_scalar_mul_internal(