From 393316ffe1d02efe70e26310ff04318b2e185e87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tancr=C3=A8de=20Lepoint?= Date: Tue, 4 Apr 2023 04:28:01 -0400 Subject: [PATCH] Change tabs into space, optimize ntt operator constructor (#170) --- crates/fhe-math/benches/ntt.rs | 56 +- crates/fhe-math/benches/rns.rs | 80 +- crates/fhe-math/benches/rq.rs | 468 +-- crates/fhe-math/benches/zq.rs | 72 +- crates/fhe-math/src/errors.rs | 96 +- crates/fhe-math/src/rns/mod.rs | 352 +- crates/fhe-math/src/rns/scaler.rs | 854 ++--- crates/fhe-math/src/rq/context.rs | 382 +- crates/fhe-math/src/rq/convert.rs | 1112 +++--- crates/fhe-math/src/rq/mod.rs | 2030 +++++----- crates/fhe-math/src/rq/ops.rs | 1186 +++--- crates/fhe-math/src/rq/scaler.rs | 345 +- crates/fhe-math/src/rq/serialize.rs | 80 +- crates/fhe-math/src/rq/switcher.rs | 22 +- crates/fhe-math/src/rq/traits.rs | 28 +- crates/fhe-math/src/zq/mod.rs | 2008 +++++----- crates/fhe-math/src/zq/ntt.rs | 804 ++-- crates/fhe-math/src/zq/primes.rs | 178 +- crates/fhe-traits/src/lib.rs | 146 +- crates/fhe-util/src/lib.rs | 3342 ++++++++--------- crates/fhe-util/src/u256.rs | 332 +- crates/fhe/benches/bfv.rs | 480 +-- crates/fhe/benches/bfv_optimized_ops.rs | 110 +- crates/fhe/benches/bfv_rgsw.rs | 40 +- crates/fhe/examples/mulpir.rs | 444 +-- crates/fhe/examples/sealpir.rs | 584 +-- crates/fhe/examples/util.rs | 168 +- crates/fhe/src/bfv/ciphertext.rs | 456 +-- crates/fhe/src/bfv/encoding.rs | 90 +- crates/fhe/src/bfv/keys/evaluation_key.rs | 1471 ++++---- crates/fhe/src/bfv/keys/galois_key.rs | 302 +- crates/fhe/src/bfv/keys/key_switching_key.rs | 580 +-- crates/fhe/src/bfv/keys/public_key.rs | 304 +- .../fhe/src/bfv/keys/relinearization_key.rs | 484 +-- crates/fhe/src/bfv/keys/secret_key.rs | 392 +- crates/fhe/src/bfv/ops/dot_product.rs | 322 +- crates/fhe/src/bfv/ops/mod.rs | 978 ++--- crates/fhe/src/bfv/ops/mul.rs | 686 ++-- crates/fhe/src/bfv/parameters.rs | 1068 +++--- crates/fhe/src/bfv/plaintext.rs | 588 +-- crates/fhe/src/bfv/plaintext_vec.rs | 232 +- crates/fhe/src/bfv/rgsw_ciphertext.rs | 278 +- crates/fhe/src/bfv/traits.rs | 6 +- crates/fhe/src/errors.rs | 220 +- rustfmt.toml | 1 - 45 files changed, 12127 insertions(+), 12130 deletions(-) diff --git a/crates/fhe-math/benches/ntt.rs b/crates/fhe-math/benches/ntt.rs index 2ff98d6..9e4d5d7 100644 --- a/crates/fhe-math/benches/ntt.rs +++ b/crates/fhe-math/benches/ntt.rs @@ -4,40 +4,40 @@ use rand::thread_rng; use std::sync::Arc; pub fn ntt_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("ntt"); - group.sample_size(50); - let mut rng = thread_rng(); + let mut group = c.benchmark_group("ntt"); + group.sample_size(50); + let mut rng = thread_rng(); - for vector_size in [1024usize, 4096].iter() { - for p in [4611686018326724609u64, 40961u64] { - let p_nbits = 64 - p.leading_zeros(); - let q = Modulus::new(p).unwrap(); - let mut a = q.random_vec(*vector_size, &mut rng); - let op = NttOperator::new(&Arc::new(q), *vector_size).unwrap(); + for vector_size in [1024usize, 4096].iter() { + for p in [4611686018326724609u64, 40961u64] { + let p_nbits = 64 - p.leading_zeros(); + let q = Modulus::new(p).unwrap(); + let mut a = q.random_vec(*vector_size, &mut rng); + let op = NttOperator::new(&Arc::new(q), *vector_size).unwrap(); - group.bench_function( - BenchmarkId::new("forward", format!("{vector_size}/{p_nbits}")), - |b| b.iter(|| op.forward(&mut a)), - ); + group.bench_function( + BenchmarkId::new("forward", format!("{vector_size}/{p_nbits}")), + |b| b.iter(|| op.forward(&mut a)), + ); - group.bench_function( - BenchmarkId::new("forward_vt", format!("{vector_size}/{p_nbits}")), - |b| b.iter(|| unsafe { op.forward_vt(a.as_mut_ptr()) }), - ); + group.bench_function( + BenchmarkId::new("forward_vt", format!("{vector_size}/{p_nbits}")), + |b| b.iter(|| unsafe { op.forward_vt(a.as_mut_ptr()) }), + ); - group.bench_function( - BenchmarkId::new("backward", format!("{vector_size}/{p_nbits}")), - |b| b.iter(|| op.backward(&mut a)), - ); + group.bench_function( + BenchmarkId::new("backward", format!("{vector_size}/{p_nbits}")), + |b| b.iter(|| op.backward(&mut a)), + ); - group.bench_function( - BenchmarkId::new("backward_vt", format!("{vector_size}/{p_nbits}")), - |b| b.iter(|| unsafe { op.backward_vt(a.as_mut_ptr()) }), - ); - } - } + group.bench_function( + BenchmarkId::new("backward_vt", format!("{vector_size}/{p_nbits}")), + |b| b.iter(|| unsafe { op.backward_vt(a.as_mut_ptr()) }), + ); + } + } - group.finish(); + group.finish(); } criterion_group!(ntt, ntt_benchmark); diff --git a/crates/fhe-math/benches/rns.rs b/crates/fhe-math/benches/rns.rs index cfbd5a5..96ee92f 100644 --- a/crates/fhe-math/benches/rns.rs +++ b/crates/fhe-math/benches/rns.rs @@ -5,53 +5,53 @@ use rand::{thread_rng, RngCore}; use std::sync::Arc; pub fn rns_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("rns"); - group.sample_size(50); + let mut group = c.benchmark_group("rns"); + group.sample_size(50); - let q = [ - 4611686018326724609u64, - 4611686018309947393, - 4611686018282684417, - ]; - let p = [ - 4611686018257518593u64, - 4611686018232352769, - 4611686018171535361, - 4611686018106523649, - ]; + let q = [ + 4611686018326724609u64, + 4611686018309947393, + 4611686018282684417, + ]; + let p = [ + 4611686018257518593u64, + 4611686018232352769, + 4611686018171535361, + 4611686018106523649, + ]; - let mut rng = thread_rng(); - let mut x = vec![]; - for qi in &q { - x.push(rng.next_u64() % *qi); - } + let mut rng = thread_rng(); + let mut x = vec![]; + for qi in &q { + x.push(rng.next_u64() % *qi); + } - let rns_q = Arc::new(RnsContext::new(&q).unwrap()); - let rns_p = Arc::new(RnsContext::new(&p).unwrap()); - let scaler = RnsScaler::new( - &rns_q, - &rns_p, - ScalingFactor::new(&BigUint::from(1u64), &BigUint::from(46116860181065u64)), - ); - let scaler_as_converter = RnsScaler::new(&rns_q, &rns_p, ScalingFactor::one()); + let rns_q = Arc::new(RnsContext::new(&q).unwrap()); + let rns_p = Arc::new(RnsContext::new(&p).unwrap()); + let scaler = RnsScaler::new( + &rns_q, + &rns_p, + ScalingFactor::new(&BigUint::from(1u64), &BigUint::from(46116860181065u64)), + ); + let scaler_as_converter = RnsScaler::new(&rns_q, &rns_p, ScalingFactor::one()); - let mut y = vec![0; p.len()]; + let mut y = vec![0; p.len()]; - group.bench_function( - BenchmarkId::new("scaler", format!("{}->{}", q.len(), p.len())), - |b| { - b.iter(|| scaler.scale((&x).into(), (&mut y).into(), 0)); - }, - ); + group.bench_function( + BenchmarkId::new("scaler", format!("{}->{}", q.len(), p.len())), + |b| { + b.iter(|| scaler.scale((&x).into(), (&mut y).into(), 0)); + }, + ); - group.bench_function( - BenchmarkId::new("scaler_as_converter", format!("{}->{}", q.len(), p.len())), - |b| { - b.iter(|| scaler_as_converter.scale((&x).into(), (&mut y).into(), 0)); - }, - ); + group.bench_function( + BenchmarkId::new("scaler_as_converter", format!("{}->{}", q.len(), p.len())), + |b| { + b.iter(|| scaler_as_converter.scale((&x).into(), (&mut y).into(), 0)); + }, + ); - group.finish(); + group.finish(); } criterion_group!(rns, rns_benchmark); diff --git a/crates/fhe-math/benches/rq.rs b/crates/fhe-math/benches/rq.rs index fd1eb98..324a2e1 100644 --- a/crates/fhe-math/benches/rq.rs +++ b/crates/fhe-math/benches/rq.rs @@ -4,287 +4,287 @@ use fhe_math::rq::*; use itertools::{izip, Itertools}; use rand::thread_rng; use std::{ - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, - sync::Arc, - time::Duration, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + sync::Arc, + time::Duration, }; static MODULI: &[u64; 4] = &[ - 562949954093057, - 4611686018326724609, - 4611686018309947393, - 4611686018282684417, + 562949954093057, + 4611686018326724609, + 4611686018309947393, + 4611686018282684417, ]; static DEGREE: &[usize] = &[1024, 2048, 4096, 8192]; fn create_group(c: &mut Criterion, name: String) -> BenchmarkGroup { - let mut group = c.benchmark_group(name); - group.warm_up_time(Duration::from_millis(100)); - group.measurement_time(Duration::from_secs(1)); - group + let mut group = c.benchmark_group(name); + group.warm_up_time(Duration::from_millis(100)); + group.measurement_time(Duration::from_secs(1)); + group } macro_rules! bench_op { - ($c: expr, $name:expr, $op:expr, $vt:expr) => {{ - let name = if $vt { - format!("{}_vt", $name) - } else { - $name.to_string() - }; - let mut group = create_group($c, name); - let mut rng = thread_rng(); + ($c: expr, $name:expr, $op:expr, $vt:expr) => {{ + let name = if $vt { + format!("{}_vt", $name) + } else { + $name.to_string() + }; + let mut group = create_group($c, name); + let mut rng = thread_rng(); - for degree in DEGREE { - let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); - if $vt { - unsafe { q.allow_variable_time_computations() } - } + for degree in DEGREE { + let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + if $vt { + unsafe { q.allow_variable_time_computations() } + } - group.bench_function( - BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())), - |b| { - b.iter(|| $op(&p, &q)); - }, - ); - } - }}; + group.bench_function( + BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())), + |b| { + b.iter(|| $op(&p, &q)); + }, + ); + } + }}; } macro_rules! bench_op_unary { - ($c: expr, $name:expr, $op:expr, $vt:expr) => {{ - let name = if $vt { - format!("{}_vt", $name) - } else { - $name.to_string() - }; - let mut group = create_group($c, name); - let mut rng = thread_rng(); + ($c: expr, $name:expr, $op:expr, $vt:expr) => {{ + let name = if $vt { + format!("{}_vt", $name) + } else { + $name.to_string() + }; + let mut group = create_group($c, name); + let mut rng = thread_rng(); - for degree in DEGREE { - let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); - if $vt { - unsafe { q.allow_variable_time_computations() } - } + for degree in DEGREE { + let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + if $vt { + unsafe { q.allow_variable_time_computations() } + } - group.bench_function( - BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())), - |b| { - b.iter(|| $op(&p)); - }, - ); - } - }}; + group.bench_function( + BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())), + |b| { + b.iter(|| $op(&p)); + }, + ); + } + }}; } macro_rules! bench_op_assign { - ($c: expr, $name:expr, $op:expr, $vt:expr) => {{ - let name = if $vt { - format!("{}_vt", $name) - } else { - $name.to_string() - }; - let mut group = create_group($c, name); - let mut rng = thread_rng(); + ($c: expr, $name:expr, $op:expr, $vt:expr) => {{ + let name = if $vt { + format!("{}_vt", $name) + } else { + $name.to_string() + }; + let mut group = create_group($c, name); + let mut rng = thread_rng(); - for degree in DEGREE { - let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); - let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); - if $vt { - unsafe { q.allow_variable_time_computations() } - } + for degree in DEGREE { + let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap()); + let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + if $vt { + unsafe { q.allow_variable_time_computations() } + } - group.bench_function( - BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())), - |b| { - b.iter(|| $op(&mut p, &q)); - }, - ); - } - }}; + group.bench_function( + BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())), + |b| { + b.iter(|| $op(&mut p, &q)); + }, + ); + } + }}; } pub fn rq_op_benchmark(c: &mut Criterion) { - for vt in [false, true] { - bench_op!(c, "rq_add", <&Poly>::add, vt); - bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt); - bench_op!(c, "rq_sub", <&Poly>::sub, vt); - bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt); - bench_op!(c, "rq_mul", <&Poly>::mul, vt); - bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt); - bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt); - } + for vt in [false, true] { + bench_op!(c, "rq_add", <&Poly>::add, vt); + bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt); + bench_op!(c, "rq_sub", <&Poly>::sub, vt); + bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt); + bench_op!(c, "rq_mul", <&Poly>::mul, vt); + bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt); + bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt); + } } pub fn rq_dot_product(c: &mut Criterion) { - let mut group = create_group(c, "rq_dot_product".to_string()); - let mut rng = thread_rng(); - for degree in DEGREE { - for i in [1, 4] { - let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap()); - let p_vec = (0..256) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) - .collect_vec(); - let mut q_vec = (0..256) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) - .collect_vec(); - let mut out = Poly::zero(&ctx, Representation::Ntt); + let mut group = create_group(c, "rq_dot_product".to_string()); + let mut rng = thread_rng(); + for degree in DEGREE { + for i in [1, 4] { + let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap()); + let p_vec = (0..256) + .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .collect_vec(); + let mut q_vec = (0..256) + .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .collect_vec(); + let mut out = Poly::zero(&ctx, Representation::Ntt); - group.bench_function( - BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())), - |b| { - b.iter(|| { - izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi)) - }); - }, - ); + group.bench_function( + BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())), + |b| { + b.iter(|| { + izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi)) + }); + }, + ); - q_vec - .iter_mut() - .for_each(|qi| qi.change_representation(Representation::NttShoup)); - group.bench_function( - BenchmarkId::from_parameter(format!( - "naive_shoup/{}/{}", - degree, - ctx.modulus().bits() - )), - |b| { - b.iter(|| { - izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi)) - }); - }, - ); + q_vec + .iter_mut() + .for_each(|qi| qi.change_representation(Representation::NttShoup)); + group.bench_function( + BenchmarkId::from_parameter(format!( + "naive_shoup/{}/{}", + degree, + ctx.modulus().bits() + )), + |b| { + b.iter(|| { + izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi)) + }); + }, + ); - q_vec - .iter_mut() - .for_each(|qi| qi.change_representation(Representation::Ntt)); - group.bench_function( - BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())), - |b| { - b.iter(|| dot_product(p_vec.iter(), q_vec.iter())); - }, - ); - } - } + q_vec + .iter_mut() + .for_each(|qi| qi.change_representation(Representation::Ntt)); + group.bench_function( + BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())), + |b| { + b.iter(|| dot_product(p_vec.iter(), q_vec.iter())); + }, + ); + } + } } pub fn rq_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("rq"); - group.warm_up_time(Duration::from_millis(100)); - group.measurement_time(Duration::from_secs(1)); + let mut group = c.benchmark_group("rq"); + group.warm_up_time(Duration::from_millis(100)); + group.measurement_time(Duration::from_secs(1)); - let mut rng = thread_rng(); - for degree in DEGREE { - for nmoduli in 1..=MODULI.len() { - if !nmoduli.is_power_of_two() { - continue; - } - let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap()); - let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); - q.change_representation(Representation::NttShoup); + let mut rng = thread_rng(); + for degree in DEGREE { + for nmoduli in 1..=MODULI.len() { + if !nmoduli.is_power_of_two() { + continue; + } + let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap()); + let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + q.change_representation(Representation::NttShoup); - group.bench_function( - BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())), - |b| { - b.iter(|| p = &p * &q); - }, - ); + group.bench_function( + BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())), + |b| { + b.iter(|| p = &p * &q); + }, + ); - group.bench_function( - BenchmarkId::new( - "mul_shoup_assign", - format!("{}/{}", degree, ctx.modulus().bits()), - ), - |b| { - b.iter(|| p *= &q); - }, - ); + group.bench_function( + BenchmarkId::new( + "mul_shoup_assign", + format!("{}/{}", degree, ctx.modulus().bits()), + ), + |b| { + b.iter(|| p *= &q); + }, + ); - group.bench_function( - BenchmarkId::new( - "change_representation/PowerBasis_to_Ntt", - format!("{}/{}", degree, ctx.modulus().bits()), - ), - |b| { - b.iter(|| { - unsafe { - p.override_representation(Representation::PowerBasis); - } - p.change_representation(Representation::Ntt) - }); - }, - ); + group.bench_function( + BenchmarkId::new( + "change_representation/PowerBasis_to_Ntt", + format!("{}/{}", degree, ctx.modulus().bits()), + ), + |b| { + b.iter(|| { + unsafe { + p.override_representation(Representation::PowerBasis); + } + p.change_representation(Representation::Ntt) + }); + }, + ); - group.bench_function( - BenchmarkId::new( - "change_representation/Ntt_to_PowerBasis", - format!("{}/{}", degree, ctx.modulus().bits()), - ), - |b| { - b.iter(|| { - unsafe { - p.override_representation(Representation::Ntt); - } - p.change_representation(Representation::PowerBasis) - }); - }, - ); + group.bench_function( + BenchmarkId::new( + "change_representation/Ntt_to_PowerBasis", + format!("{}/{}", degree, ctx.modulus().bits()), + ), + |b| { + b.iter(|| { + unsafe { + p.override_representation(Representation::Ntt); + } + p.change_representation(Representation::PowerBasis) + }); + }, + ); - p.change_representation(Representation::Ntt); - q.change_representation(Representation::Ntt); + p.change_representation(Representation::Ntt); + q.change_representation(Representation::Ntt); - unsafe { - q.allow_variable_time_computations(); - q.change_representation(Representation::NttShoup); + unsafe { + q.allow_variable_time_computations(); + q.change_representation(Representation::NttShoup); - group.bench_function( - BenchmarkId::new( - "mul_shoup_vt", - format!("{}/{}", degree, ctx.modulus().bits()), - ), - |b| { - b.iter(|| p *= &q); - }, - ); + group.bench_function( + BenchmarkId::new( + "mul_shoup_vt", + format!("{}/{}", degree, ctx.modulus().bits()), + ), + |b| { + b.iter(|| p *= &q); + }, + ); - p.allow_variable_time_computations(); + p.allow_variable_time_computations(); - group.bench_function( - BenchmarkId::new( - "change_representation/PowerBasis_to_Ntt_vt", - format!("{}/{}", degree, ctx.modulus().bits()), - ), - |b| { - b.iter(|| { - p.override_representation(Representation::PowerBasis); - p.change_representation(Representation::Ntt) - }); - }, - ); + group.bench_function( + BenchmarkId::new( + "change_representation/PowerBasis_to_Ntt_vt", + format!("{}/{}", degree, ctx.modulus().bits()), + ), + |b| { + b.iter(|| { + p.override_representation(Representation::PowerBasis); + p.change_representation(Representation::Ntt) + }); + }, + ); - group.bench_function( - BenchmarkId::new( - "change_representation/Ntt_to_PowerBasis_vt", - format!("{}/{}", degree, ctx.modulus().bits()), - ), - |b| { - b.iter(|| { - p.override_representation(Representation::Ntt); - p.change_representation(Representation::PowerBasis) - }); - }, - ); - } - } - } + group.bench_function( + BenchmarkId::new( + "change_representation/Ntt_to_PowerBasis_vt", + format!("{}/{}", degree, ctx.modulus().bits()), + ), + |b| { + b.iter(|| { + p.override_representation(Representation::Ntt); + p.change_representation(Representation::PowerBasis) + }); + }, + ); + } + } + } - group.finish(); + group.finish(); } criterion_group!(rq, rq_op_benchmark, rq_dot_product, rq_benchmark); diff --git a/crates/fhe-math/benches/zq.rs b/crates/fhe-math/benches/zq.rs index 07ae72c..a7fb390 100644 --- a/crates/fhe-math/benches/zq.rs +++ b/crates/fhe-math/benches/zq.rs @@ -3,53 +3,53 @@ use fhe_math::zq::Modulus; use rand::thread_rng; pub fn zq_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("zq"); - group.sample_size(50); + let mut group = c.benchmark_group("zq"); + group.sample_size(50); - let p = 4611686018326724609; - let mut rng = thread_rng(); + let p = 4611686018326724609; + let mut rng = thread_rng(); - for vector_size in [1024usize, 4096].iter() { - let q = Modulus::new(p).unwrap(); - let mut a = q.random_vec(*vector_size, &mut rng); - let c = q.random_vec(*vector_size, &mut rng); - let c_shoup = q.shoup_vec(&c); - let scalar = c[0]; + for vector_size in [1024usize, 4096].iter() { + let q = Modulus::new(p).unwrap(); + let mut a = q.random_vec(*vector_size, &mut rng); + let c = q.random_vec(*vector_size, &mut rng); + let c_shoup = q.shoup_vec(&c); + let scalar = c[0]; - group.bench_function(BenchmarkId::new("add_vec", vector_size), |b| { - b.iter(|| q.add_vec(&mut a, &c)); - }); + group.bench_function(BenchmarkId::new("add_vec", vector_size), |b| { + b.iter(|| q.add_vec(&mut a, &c)); + }); - group.bench_function(BenchmarkId::new("add_vec_vt", vector_size), |b| unsafe { - b.iter(|| q.add_vec_vt(&mut a, &c)); - }); + group.bench_function(BenchmarkId::new("add_vec_vt", vector_size), |b| unsafe { + b.iter(|| q.add_vec_vt(&mut a, &c)); + }); - group.bench_function(BenchmarkId::new("sub_vec", vector_size), |b| { - b.iter(|| q.sub_vec(&mut a, &c)); - }); + group.bench_function(BenchmarkId::new("sub_vec", vector_size), |b| { + b.iter(|| q.sub_vec(&mut a, &c)); + }); - group.bench_function(BenchmarkId::new("neg_vec", vector_size), |b| { - b.iter(|| q.neg_vec(&mut a)); - }); + group.bench_function(BenchmarkId::new("neg_vec", vector_size), |b| { + b.iter(|| q.neg_vec(&mut a)); + }); - group.bench_function(BenchmarkId::new("mul_vec", vector_size), |b| { - b.iter(|| q.mul_vec(&mut a, &c)); - }); + group.bench_function(BenchmarkId::new("mul_vec", vector_size), |b| { + b.iter(|| q.mul_vec(&mut a, &c)); + }); - group.bench_function(BenchmarkId::new("mul_vec_vt", vector_size), |b| unsafe { - b.iter(|| q.mul_vec_vt(&mut a, &c)); - }); + group.bench_function(BenchmarkId::new("mul_vec_vt", vector_size), |b| unsafe { + b.iter(|| q.mul_vec_vt(&mut a, &c)); + }); - group.bench_function(BenchmarkId::new("mul_shoup_vec", vector_size), |b| { - b.iter(|| q.mul_shoup_vec(&mut a, &c, &c_shoup)); - }); + group.bench_function(BenchmarkId::new("mul_shoup_vec", vector_size), |b| { + b.iter(|| q.mul_shoup_vec(&mut a, &c, &c_shoup)); + }); - group.bench_function(BenchmarkId::new("scalar_mul_vec", vector_size), |b| { - b.iter(|| q.scalar_mul_vec(&mut a, scalar)); - }); - } + group.bench_function(BenchmarkId::new("scalar_mul_vec", vector_size), |b| { + b.iter(|| q.scalar_mul_vec(&mut a, scalar)); + }); + } - group.finish(); + group.finish(); } criterion_group!(zq, zq_benchmark); diff --git a/crates/fhe-math/src/errors.rs b/crates/fhe-math/src/errors.rs index 5cad785..c3b331d 100644 --- a/crates/fhe-math/src/errors.rs +++ b/crates/fhe-math/src/errors.rs @@ -8,63 +8,63 @@ pub type Result = std::result::Result; /// Enum encapsulation all the possible errors from this library. #[derive(Debug, Error, PartialEq, Eq)] pub enum Error { - /// Indicates an invalid modulus - #[error("Invalid modulus: modulus {0} should be between 2 and (1 << 62) - 1.")] - InvalidModulus(u64), + /// Indicates an invalid modulus + #[error("Invalid modulus: modulus {0} should be between 2 and (1 << 62) - 1.")] + InvalidModulus(u64), - /// Indicates an error in the serialization / deserialization. - #[error("{0}")] - Serialization(String), + /// Indicates an error in the serialization / deserialization. + #[error("{0}")] + Serialization(String), - /// Indicates that there is no more contexts to switch to. - #[error("This is the last context.")] - NoMoreContext, + /// Indicates that there is no more contexts to switch to. + #[error("This is the last context.")] + NoMoreContext, - /// Indicates that the provided context is invalid. - #[error("Invalid context provided.")] - InvalidContext, + /// Indicates that the provided context is invalid. + #[error("Invalid context provided.")] + InvalidContext, - /// Indicates an incorrect representation. - #[error("Incorrect representation: got {0:?}, expected {1:?}.")] - IncorrectRepresentation(Representation, Representation), + /// Indicates an incorrect representation. + #[error("Incorrect representation: got {0:?}, expected {1:?}.")] + IncorrectRepresentation(Representation, Representation), - /// Indicates that the seed size is incorrect. - #[error("Invalid seed: got {0} bytes, expected {1} bytes.")] - InvalidSeedSize(usize, usize), + /// Indicates that the seed size is incorrect. + #[error("Invalid seed: got {0} bytes, expected {1} bytes.")] + InvalidSeedSize(usize, usize), - /// Indicates a default error - /// TODO: To delete when transition is over - #[error("{0}")] - Default(String), + /// Indicates a default error + /// TODO: To delete when transition is over + #[error("{0}")] + Default(String), } #[cfg(test)] mod tests { - use crate::{rq::Representation, Error}; + use crate::{rq::Representation, Error}; - #[test] - fn error_strings() { - assert_eq!( - Error::InvalidModulus(0).to_string(), - "Invalid modulus: modulus 0 should be between 2 and (1 << 62) - 1." - ); - assert_eq!(Error::Serialization("test".to_string()).to_string(), "test"); - assert_eq!( - Error::NoMoreContext.to_string(), - "This is the last context." - ); - assert_eq!( - Error::InvalidContext.to_string(), - "Invalid context provided." - ); - assert_eq!( - Error::IncorrectRepresentation(Representation::Ntt, Representation::NttShoup) - .to_string(), - "Incorrect representation: got Ntt, expected NttShoup." - ); - assert_eq!( - Error::InvalidSeedSize(0, 1).to_string(), - "Invalid seed: got 0 bytes, expected 1 bytes." - ); - } + #[test] + fn error_strings() { + assert_eq!( + Error::InvalidModulus(0).to_string(), + "Invalid modulus: modulus 0 should be between 2 and (1 << 62) - 1." + ); + assert_eq!(Error::Serialization("test".to_string()).to_string(), "test"); + assert_eq!( + Error::NoMoreContext.to_string(), + "This is the last context." + ); + assert_eq!( + Error::InvalidContext.to_string(), + "Invalid context provided." + ); + assert_eq!( + Error::IncorrectRepresentation(Representation::Ntt, Representation::NttShoup) + .to_string(), + "Incorrect representation: got Ntt, expected NttShoup." + ); + assert_eq!( + Error::InvalidSeedSize(0, 1).to_string(), + "Invalid seed: got 0 bytes, expected 1 bytes." + ); + } } diff --git a/crates/fhe-math/src/rns/mod.rs b/crates/fhe-math/src/rns/mod.rs index 649c9b3..03e4169 100644 --- a/crates/fhe-math/src/rns/mod.rs +++ b/crates/fhe-math/src/rns/mod.rs @@ -17,218 +17,218 @@ pub use scaler::{RnsScaler, ScalingFactor}; /// Context for a Residue Number System. #[derive(Default, Clone, PartialEq, Eq)] pub struct RnsContext { - moduli_u64: Vec, - moduli: Vec, - q_tilde: Vec, - q_tilde_shoup: Vec, - q_star: Vec, - garner: Vec, - product: BigUint, + moduli_u64: Vec, + moduli: Vec, + q_tilde: Vec, + q_tilde_shoup: Vec, + q_star: Vec, + garner: Vec, + product: BigUint, } impl Debug for RnsContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RnsContext") - .field("moduli_u64", &self.moduli_u64) - // .field("moduli", &self.moduli) - // .field("q_tilde", &self.q_tilde) - // .field("q_tilde_shoup", &self.q_tilde_shoup) - // .field("q_star", &self.q_star) - // .field("garner", &self.garner) - .field("product", &self.product) - .finish() - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RnsContext") + .field("moduli_u64", &self.moduli_u64) + // .field("moduli", &self.moduli) + // .field("q_tilde", &self.q_tilde) + // .field("q_tilde_shoup", &self.q_tilde_shoup) + // .field("q_star", &self.q_star) + // .field("garner", &self.garner) + .field("product", &self.product) + .finish() + } } impl RnsContext { - /// Create a RNS context from a list of moduli. - /// - /// Returns an error if the list is empty, or if the moduli are no coprime. - pub fn new(moduli_u64: &[u64]) -> Result { - if moduli_u64.is_empty() { - Err(Error::Default("The list of moduli is empty".to_string())) - } else { - let mut moduli = Vec::with_capacity(moduli_u64.len()); - let mut q_tilde = Vec::with_capacity(moduli_u64.len()); - let mut q_tilde_shoup = Vec::with_capacity(moduli_u64.len()); - let mut q_star = Vec::with_capacity(moduli_u64.len()); - let mut garner = Vec::with_capacity(moduli_u64.len()); - let mut product = BigUint::one(); - let mut product_dig = BigUintDig::one(); + /// Create a RNS context from a list of moduli. + /// + /// Returns an error if the list is empty, or if the moduli are no coprime. + pub fn new(moduli_u64: &[u64]) -> Result { + if moduli_u64.is_empty() { + Err(Error::Default("The list of moduli is empty".to_string())) + } else { + let mut moduli = Vec::with_capacity(moduli_u64.len()); + let mut q_tilde = Vec::with_capacity(moduli_u64.len()); + let mut q_tilde_shoup = Vec::with_capacity(moduli_u64.len()); + let mut q_star = Vec::with_capacity(moduli_u64.len()); + let mut garner = Vec::with_capacity(moduli_u64.len()); + let mut product = BigUint::one(); + let mut product_dig = BigUintDig::one(); - for i in 0..moduli_u64.len() { - // Return None if the moduli are not coprime. - for j in 0..moduli_u64.len() { - if i != j { - let (d, _, _) = BigUintDig::from(moduli_u64[i]) - .extended_gcd(&BigUintDig::from(moduli_u64[j])); - if d.cmp(&BigIntDig::from(1)) != Ordering::Equal { - return Err(Error::Default("The moduli are not coprime".to_string())); - } - } - } + for i in 0..moduli_u64.len() { + // Return None if the moduli are not coprime. + for j in 0..moduli_u64.len() { + if i != j { + let (d, _, _) = BigUintDig::from(moduli_u64[i]) + .extended_gcd(&BigUintDig::from(moduli_u64[j])); + if d.cmp(&BigIntDig::from(1)) != Ordering::Equal { + return Err(Error::Default("The moduli are not coprime".to_string())); + } + } + } - product *= &BigUint::from(moduli_u64[i]); - product_dig *= &BigUintDig::from(moduli_u64[i]); - } + product *= &BigUint::from(moduli_u64[i]); + product_dig *= &BigUintDig::from(moduli_u64[i]); + } - for modulus in moduli_u64 { - moduli.push(Modulus::new(*modulus)?); - // q* = product / modulus - let q_star_i = &product / modulus; - // q~ = (product / modulus) ^ (-1) % modulus - let q_tilde_i = (&product_dig / modulus) - .mod_inverse(&BigUintDig::from(*modulus)) - .unwrap() - .to_u64() - .unwrap(); - // garner = (q*) * (q~) - let garner_i = &q_star_i * q_tilde_i; - q_tilde.push(q_tilde_i); - garner.push(garner_i); - q_star.push(q_star_i); - q_tilde_shoup.push( - Modulus::new(*modulus) - .unwrap() - .shoup(q_tilde_i.to_u64().unwrap()), - ); - } + for modulus in moduli_u64 { + moduli.push(Modulus::new(*modulus)?); + // q* = product / modulus + let q_star_i = &product / modulus; + // q~ = (product / modulus) ^ (-1) % modulus + let q_tilde_i = (&product_dig / modulus) + .mod_inverse(&BigUintDig::from(*modulus)) + .unwrap() + .to_u64() + .unwrap(); + // garner = (q*) * (q~) + let garner_i = &q_star_i * q_tilde_i; + q_tilde.push(q_tilde_i); + garner.push(garner_i); + q_star.push(q_star_i); + q_tilde_shoup.push( + Modulus::new(*modulus) + .unwrap() + .shoup(q_tilde_i.to_u64().unwrap()), + ); + } - Ok(Self { - moduli_u64: moduli_u64.to_owned(), - moduli, - q_tilde, - q_tilde_shoup, - q_star, - garner, - product, - }) - } - } + Ok(Self { + moduli_u64: moduli_u64.to_owned(), + moduli, + q_tilde, + q_tilde_shoup, + q_star, + garner, + product, + }) + } + } - /// Returns the product of the moduli used when creating the RNS context. - pub const fn modulus(&self) -> &BigUint { - &self.product - } + /// Returns the product of the moduli used when creating the RNS context. + pub const fn modulus(&self) -> &BigUint { + &self.product + } - /// Project a BigUint into its rests. - pub fn project(&self, a: &BigUint) -> Vec { - let mut rests = Vec::with_capacity(self.moduli_u64.len()); - for modulus in &self.moduli_u64 { - rests.push((a % modulus).to_u64().unwrap()) - } - rests - } + /// Project a BigUint into its rests. + pub fn project(&self, a: &BigUint) -> Vec { + let mut rests = Vec::with_capacity(self.moduli_u64.len()); + for modulus in &self.moduli_u64 { + rests.push((a % modulus).to_u64().unwrap()) + } + rests + } - /// Lift rests into a BigUint. - /// - /// Aborts if the number of rests is different than the number of moduli in - /// debug mode. - pub fn lift(&self, rests: ArrayView1) -> BigUint { - let mut result = BigUint::zero(); - izip!(rests.iter(), self.garner.iter()) - .for_each(|(r_i, garner_i)| result += garner_i * *r_i); - result % &self.product - } + /// Lift rests into a BigUint. + /// + /// Aborts if the number of rests is different than the number of moduli in + /// debug mode. + pub fn lift(&self, rests: ArrayView1) -> BigUint { + let mut result = BigUint::zero(); + izip!(rests.iter(), self.garner.iter()) + .for_each(|(r_i, garner_i)| result += garner_i * *r_i); + result % &self.product + } - /// Getter for the i-th garner coefficient. - pub fn get_garner(&self, i: usize) -> Option<&BigUint> { - self.garner.get(i) - } + /// Getter for the i-th garner coefficient. + pub fn get_garner(&self, i: usize) -> Option<&BigUint> { + self.garner.get(i) + } } #[cfg(test)] mod tests { - use std::error::Error; + use std::error::Error; - use super::RnsContext; - use ndarray::ArrayView1; - use num_bigint::BigUint; - use rand::RngCore; + use super::RnsContext; + use ndarray::ArrayView1; + use num_bigint::BigUint; + use rand::RngCore; - #[test] - fn constructor() { - assert!(RnsContext::new(&[2]).is_ok()); - assert!(RnsContext::new(&[2, 3]).is_ok()); - assert!(RnsContext::new(&[4, 15, 1153]).is_ok()); + #[test] + fn constructor() { + assert!(RnsContext::new(&[2]).is_ok()); + assert!(RnsContext::new(&[2, 3]).is_ok()); + assert!(RnsContext::new(&[4, 15, 1153]).is_ok()); - let e = RnsContext::new(&[]); - assert!(e.is_err()); - assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty"); - let e = RnsContext::new(&[2, 4]); - assert!(e.is_err()); - assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime"); - let e = RnsContext::new(&[2, 3, 5, 30]); - assert!(e.is_err()); - assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime"); - } + let e = RnsContext::new(&[]); + assert!(e.is_err()); + assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty"); + let e = RnsContext::new(&[2, 4]); + assert!(e.is_err()); + assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime"); + let e = RnsContext::new(&[2, 3, 5, 30]); + assert!(e.is_err()); + assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime"); + } - #[test] - fn garner() -> Result<(), Box> { - let rns = RnsContext::new(&[4, 15, 1153])?; + #[test] + fn garner() -> Result<(), Box> { + let rns = RnsContext::new(&[4, 15, 1153])?; - for i in 0..3 { - let gi = rns.get_garner(i); - assert!(gi.is_some()); - assert_eq!(gi.unwrap(), &rns.garner[i]); - } - assert!(rns.get_garner(3).is_none()); + for i in 0..3 { + let gi = rns.get_garner(i); + assert!(gi.is_some()); + assert_eq!(gi.unwrap(), &rns.garner[i]); + } + assert!(rns.get_garner(3).is_none()); - Ok(()) - } + Ok(()) + } - #[test] - fn modulus() -> Result<(), Box> { - let mut rns = RnsContext::new(&[2])?; - debug_assert_eq!(rns.modulus(), &BigUint::from(2u64)); + #[test] + fn modulus() -> Result<(), Box> { + let mut rns = RnsContext::new(&[2])?; + debug_assert_eq!(rns.modulus(), &BigUint::from(2u64)); - rns = RnsContext::new(&[2, 5])?; - debug_assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5)); + rns = RnsContext::new(&[2, 5])?; + debug_assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5)); - rns = RnsContext::new(&[4, 15, 1153])?; - debug_assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153)); + rns = RnsContext::new(&[4, 15, 1153])?; + debug_assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153)); - Ok(()) - } + Ok(()) + } - #[test] - fn project_lift() -> Result<(), Box> { - let ntests = 100; - let rns = RnsContext::new(&[4, 15, 1153])?; - let product = 4u64 * 15 * 1153; + #[test] + fn project_lift() -> Result<(), Box> { + let ntests = 100; + let rns = RnsContext::new(&[4, 15, 1153])?; + let product = 4u64 * 15 * 1153; - let mut rests = rns.project(&BigUint::from(0u64)); - assert_eq!(&rests, &[0u64, 0, 0]); - assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64)); + let mut rests = rns.project(&BigUint::from(0u64)); + assert_eq!(&rests, &[0u64, 0, 0]); + assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64)); - rests = rns.project(&BigUint::from(4u64)); - assert_eq!(&rests, &[0u64, 4, 4]); - assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64)); + rests = rns.project(&BigUint::from(4u64)); + assert_eq!(&rests, &[0u64, 4, 4]); + assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64)); - rests = rns.project(&BigUint::from(15u64)); - assert_eq!(&rests, &[3u64, 0, 15]); - assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64)); + rests = rns.project(&BigUint::from(15u64)); + assert_eq!(&rests, &[3u64, 0, 15]); + assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64)); - rests = rns.project(&BigUint::from(1153u64)); - assert_eq!(&rests, &[1u64, 13, 0]); - assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64)); + rests = rns.project(&BigUint::from(1153u64)); + assert_eq!(&rests, &[1u64, 13, 0]); + assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64)); - rests = rns.project(&BigUint::from(product - 1)); - assert_eq!(&rests, &[3u64, 14, 1152]); - assert_eq!( - rns.lift(ArrayView1::from(&rests)), - BigUint::from(product - 1) - ); + rests = rns.project(&BigUint::from(product - 1)); + assert_eq!(&rests, &[3u64, 14, 1152]); + assert_eq!( + rns.lift(ArrayView1::from(&rests)), + BigUint::from(product - 1) + ); - let mut rng = rand::thread_rng(); + let mut rng = rand::thread_rng(); - for _ in 0..ntests { - let b = BigUint::from(rng.next_u64() % product); - rests = rns.project(&b); - assert_eq!(rns.lift(ArrayView1::from(&rests)), b); - } + for _ in 0..ntests { + let b = BigUint::from(rng.next_u64() % product); + rests = rns.project(&b); + assert_eq!(rns.lift(ArrayView1::from(&rests)), b); + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe-math/src/rns/scaler.rs b/crates/fhe-math/src/rns/scaler.rs index 414696e..782a28e 100644 --- a/crates/fhe-math/src/rns/scaler.rs +++ b/crates/fhe-math/src/rns/scaler.rs @@ -14,481 +14,481 @@ use std::{cmp::min, sync::Arc}; /// Scaling factor when performing a RNS scaling. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct ScalingFactor { - numerator: BigUint, - denominator: BigUint, - pub(crate) is_one: bool, + numerator: BigUint, + denominator: BigUint, + pub(crate) is_one: bool, } impl ScalingFactor { - /// Create a new scaling factor. Aborts if the denominator is 0. - pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self { - assert_ne!(denominator, &BigUint::zero()); - Self { - numerator: numerator.clone(), - denominator: denominator.clone(), - is_one: numerator == denominator, - } - } + /// Create a new scaling factor. Aborts if the denominator is 0. + pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self { + assert_ne!(denominator, &BigUint::zero()); + Self { + numerator: numerator.clone(), + denominator: denominator.clone(), + is_one: numerator == denominator, + } + } - /// Returns the identity element of `Self`. - pub fn one() -> Self { - Self { - numerator: BigUint::one(), - denominator: BigUint::one(), - is_one: true, - } - } + /// Returns the identity element of `Self`. + pub fn one() -> Self { + Self { + numerator: BigUint::one(), + denominator: BigUint::one(), + is_one: true, + } + } } /// Scaler for a RNS context. /// This is a helper struct to perform RNS scaling. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct RnsScaler { - from: Arc, - to: Arc, - scaling_factor: ScalingFactor, + from: Arc, + to: Arc, + scaling_factor: ScalingFactor, - gamma: Box<[u64]>, - gamma_shoup: Box<[u64]>, - theta_gamma_lo: u64, - theta_gamma_hi: u64, - theta_gamma_sign: bool, + gamma: Box<[u64]>, + gamma_shoup: Box<[u64]>, + theta_gamma_lo: u64, + theta_gamma_hi: u64, + theta_gamma_sign: bool, - omega: Box<[Box<[u64]>]>, - omega_shoup: Box<[Box<[u64]>]>, - theta_omega_lo: Box<[u64]>, - theta_omega_hi: Box<[u64]>, - theta_omega_sign: Box<[bool]>, + omega: Box<[Box<[u64]>]>, + omega_shoup: Box<[Box<[u64]>]>, + theta_omega_lo: Box<[u64]>, + theta_omega_hi: Box<[u64]>, + theta_omega_sign: Box<[bool]>, - theta_garner_lo: Box<[u64]>, - theta_garner_hi: Box<[u64]>, - theta_garner_shift: usize, + theta_garner_lo: Box<[u64]>, + theta_garner_hi: Box<[u64]>, + theta_garner_shift: usize, } impl RnsScaler { - /// Create a RNS scaler by numerator / denominator. - /// - /// Aborts if denominator is equal to 0. - pub fn new( - from: &Arc, - to: &Arc, - scaling_factor: ScalingFactor, - ) -> Self { - // Let's define gamma = round(numerator * from.product / denominator) - let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) = - Self::extract_projection_and_theta( - to, - &from.product, - &scaling_factor.numerator, - &scaling_factor.denominator, - false, - ); - let gamma_shoup = izip!(&gamma, &to.moduli) - .map(|(wi, q)| q.shoup(*wi)) - .collect_vec(); + /// Create a RNS scaler by numerator / denominator. + /// + /// Aborts if denominator is equal to 0. + pub fn new( + from: &Arc, + to: &Arc, + scaling_factor: ScalingFactor, + ) -> Self { + // Let's define gamma = round(numerator * from.product / denominator) + let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) = + Self::extract_projection_and_theta( + to, + &from.product, + &scaling_factor.numerator, + &scaling_factor.denominator, + false, + ); + let gamma_shoup = izip!(&gamma, &to.moduli) + .map(|(wi, q)| q.shoup(*wi)) + .collect_vec(); - // Let's define omega_i = round(from.garner_i * numerator / denominator) - let mut omega = Vec::with_capacity(to.moduli.len()); - let mut omega_shoup = Vec::with_capacity(to.moduli.len()); - for _ in &to.moduli { - omega.push(vec![0u64; from.moduli.len()].into_boxed_slice()); - omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice()); - } - let mut theta_omega_lo = Vec::with_capacity(from.garner.len()); - let mut theta_omega_hi = Vec::with_capacity(from.garner.len()); - let mut theta_omega_sign = Vec::with_capacity(from.garner.len()); - for i in 0..from.garner.len() { - let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) = - Self::extract_projection_and_theta( - to, - &from.garner[i], - &scaling_factor.numerator, - &scaling_factor.denominator, - true, - ); - for j in 0..to.moduli.len() { - let qj = &to.moduli[j]; - omega[j][i] = qj.reduce(omega_i[j]); - omega_shoup[j][i] = qj.shoup(omega[j][i]); - } - theta_omega_lo.push(theta_omega_i_lo); - theta_omega_hi.push(theta_omega_i_hi); - theta_omega_sign.push(theta_omega_i_sign); - } + // Let's define omega_i = round(from.garner_i * numerator / denominator) + let mut omega = Vec::with_capacity(to.moduli.len()); + let mut omega_shoup = Vec::with_capacity(to.moduli.len()); + for _ in &to.moduli { + omega.push(vec![0u64; from.moduli.len()].into_boxed_slice()); + omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice()); + } + let mut theta_omega_lo = Vec::with_capacity(from.garner.len()); + let mut theta_omega_hi = Vec::with_capacity(from.garner.len()); + let mut theta_omega_sign = Vec::with_capacity(from.garner.len()); + for i in 0..from.garner.len() { + let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) = + Self::extract_projection_and_theta( + to, + &from.garner[i], + &scaling_factor.numerator, + &scaling_factor.denominator, + true, + ); + for j in 0..to.moduli.len() { + let qj = &to.moduli[j]; + omega[j][i] = qj.reduce(omega_i[j]); + omega_shoup[j][i] = qj.shoup(omega[j][i]); + } + theta_omega_lo.push(theta_omega_i_lo); + theta_omega_hi.push(theta_omega_i_hi); + theta_omega_sign.push(theta_omega_i_sign); + } - // Determine the shift so that the sum of the scaled theta_garner fit on an U192 - // (shift + 1) + log(q * n) <= 192 - let theta_garner_shift = min( - from.moduli_u64 - .iter() - .map(|qi| { - 192 - 1 - - ilog2( - ((*qi as u128) * (from.moduli_u64.len() as u128)).next_power_of_two(), - ) - }) - .min() - .unwrap(), - 127, - ); - // Finally, define theta_garner_i = from.garner_i / product, also scaled by - // 2^127. - let mut theta_garner_lo = Vec::with_capacity(from.garner.len()); - let mut theta_garner_hi = Vec::with_capacity(from.garner.len()); - for garner_i in &from.garner { - let mut theta: BigUint = - ((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product; - let theta_hi: BigUint = &theta >> 64; - theta -= &theta_hi << 64; - theta_garner_lo.push(theta.to_u64().unwrap()); - theta_garner_hi.push(theta_hi.to_u64().unwrap()); - } + // Determine the shift so that the sum of the scaled theta_garner fit on an U192 + // (shift + 1) + log(q * n) <= 192 + let theta_garner_shift = min( + from.moduli_u64 + .iter() + .map(|qi| { + 192 - 1 + - ilog2( + ((*qi as u128) * (from.moduli_u64.len() as u128)).next_power_of_two(), + ) + }) + .min() + .unwrap(), + 127, + ); + // Finally, define theta_garner_i = from.garner_i / product, also scaled by + // 2^127. + let mut theta_garner_lo = Vec::with_capacity(from.garner.len()); + let mut theta_garner_hi = Vec::with_capacity(from.garner.len()); + for garner_i in &from.garner { + let mut theta: BigUint = + ((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product; + let theta_hi: BigUint = &theta >> 64; + theta -= &theta_hi << 64; + theta_garner_lo.push(theta.to_u64().unwrap()); + theta_garner_hi.push(theta_hi.to_u64().unwrap()); + } - Self { - from: from.clone(), - to: to.clone(), - scaling_factor, - gamma: gamma.into_boxed_slice(), - gamma_shoup: gamma_shoup.into_boxed_slice(), - theta_gamma_lo, - theta_gamma_hi, - theta_gamma_sign, - omega: omega.into_boxed_slice(), - omega_shoup: omega_shoup.into_boxed_slice(), - theta_omega_lo: theta_omega_lo.into_boxed_slice(), - theta_omega_hi: theta_omega_hi.into_boxed_slice(), - theta_omega_sign: theta_omega_sign.into_boxed_slice(), - theta_garner_lo: theta_garner_lo.into_boxed_slice(), - theta_garner_hi: theta_garner_hi.into_boxed_slice(), - theta_garner_shift, - } - } + Self { + from: from.clone(), + to: to.clone(), + scaling_factor, + gamma: gamma.into_boxed_slice(), + gamma_shoup: gamma_shoup.into_boxed_slice(), + theta_gamma_lo, + theta_gamma_hi, + theta_gamma_sign, + omega: omega.into_boxed_slice(), + omega_shoup: omega_shoup.into_boxed_slice(), + theta_omega_lo: theta_omega_lo.into_boxed_slice(), + theta_omega_hi: theta_omega_hi.into_boxed_slice(), + theta_omega_sign: theta_omega_sign.into_boxed_slice(), + theta_garner_lo: theta_garner_lo.into_boxed_slice(), + theta_garner_hi: theta_garner_hi.into_boxed_slice(), + theta_garner_shift, + } + } - // Let's define gamma = round(numerator * input / denominator) - // and theta_gamma such that theta_gamma = numerator * input / denominator - - // gamma. This function projects gamma in the RNS context, and scales - // theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the - // RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma = - // (-1)**theta_sign * (theta_lo + 2^64 * theta_hi). - fn extract_projection_and_theta( - ctx: &RnsContext, - input: &BigUint, - numerator: &BigUint, - denominator: &BigUint, - round_up: bool, - ) -> (Vec, u64, u64, bool) { - let gamma = (numerator * input + (denominator >> 1)) / denominator; - let projected = ctx.project(&gamma); + // Let's define gamma = round(numerator * input / denominator) + // and theta_gamma such that theta_gamma = numerator * input / denominator - + // gamma. This function projects gamma in the RNS context, and scales + // theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the + // RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma = + // (-1)**theta_sign * (theta_lo + 2^64 * theta_hi). + fn extract_projection_and_theta( + ctx: &RnsContext, + input: &BigUint, + numerator: &BigUint, + denominator: &BigUint, + round_up: bool, + ) -> (Vec, u64, u64, bool) { + let gamma = (numerator * input + (denominator >> 1)) / denominator; + let projected = ctx.project(&gamma); - let mut theta = (numerator * input) % denominator; - let mut theta_sign = false; - if denominator > &BigUint::one() { - // If denominator is odd, flip theta if theta > (denominator >> 1) - if denominator & BigUint::one() == BigUint::one() { - if theta > (denominator >> 1) { - theta_sign = true; - theta = denominator - theta; - } - } else { - // denominator is even, flip if theta >= (denominator >> 1) - if theta >= (denominator >> 1) { - theta_sign = true; - theta = denominator - theta; - } - } - } - // theta = ((theta << 127) + (denominator >> 1)) / denominator; - // We can now split theta into two u64 words. - if round_up { - if theta_sign { - theta = (theta << 127) / denominator; - } else { - theta = ((theta << 127) + denominator - BigUint::one()) / denominator; - } - } else if theta_sign { - theta = ((theta << 127) + denominator - BigUint::one()) / denominator; - } else { - theta = (theta << 127) / denominator; - } - let theta_hi_biguint: BigUint = &theta >> 64; - theta -= &theta_hi_biguint << 64; - let theta_lo = theta.to_u64().unwrap(); - let theta_hi = theta_hi_biguint.to_u64().unwrap(); + let mut theta = (numerator * input) % denominator; + let mut theta_sign = false; + if denominator > &BigUint::one() { + // If denominator is odd, flip theta if theta > (denominator >> 1) + if denominator & BigUint::one() == BigUint::one() { + if theta > (denominator >> 1) { + theta_sign = true; + theta = denominator - theta; + } + } else { + // denominator is even, flip if theta >= (denominator >> 1) + if theta >= (denominator >> 1) { + theta_sign = true; + theta = denominator - theta; + } + } + } + // theta = ((theta << 127) + (denominator >> 1)) / denominator; + // We can now split theta into two u64 words. + if round_up { + if theta_sign { + theta = (theta << 127) / denominator; + } else { + theta = ((theta << 127) + denominator - BigUint::one()) / denominator; + } + } else if theta_sign { + theta = ((theta << 127) + denominator - BigUint::one()) / denominator; + } else { + theta = (theta << 127) / denominator; + } + let theta_hi_biguint: BigUint = &theta >> 64; + theta -= &theta_hi_biguint << 64; + let theta_lo = theta.to_u64().unwrap(); + let theta_hi = theta_hi_biguint.to_u64().unwrap(); - (projected, theta_lo, theta_hi, theta_sign) - } + (projected, theta_lo, theta_hi, theta_sign) + } - /// Output the RNS representation of the rests scaled by numerator * - /// denominator, and either rounded or floored. - /// - /// Aborts if the number of rests is different than the number of moduli in - /// debug mode, or if the size is not in [1, ..., rests.len()]. - pub fn scale_new(&self, rests: ArrayView1, size: usize) -> Vec { - let mut out = vec![0; size]; - self.scale(rests, (&mut out).into(), 0); - out - } + /// Output the RNS representation of the rests scaled by numerator * + /// denominator, and either rounded or floored. + /// + /// Aborts if the number of rests is different than the number of moduli in + /// debug mode, or if the size is not in [1, ..., rests.len()]. + pub fn scale_new(&self, rests: ArrayView1, size: usize) -> Vec { + let mut out = vec![0; size]; + self.scale(rests, (&mut out).into(), 0); + out + } - /// Compute the RNS representation of the rests scaled by numerator * - /// denominator, and either rounded or floored, and store the result in - /// `out`. - /// - /// Aborts if the number of rests is different than the number of moduli in - /// debug mode, or if the size of out is not in [1, ..., rests.len()]. - pub fn scale( - &self, - rests: ArrayView1, - mut out: ArrayViewMut1, - starting_index: usize, - ) { - debug_assert_eq!(rests.len(), self.from.moduli_u64.len()); - debug_assert!(!out.is_empty()); - debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len()); + /// Compute the RNS representation of the rests scaled by numerator * + /// denominator, and either rounded or floored, and store the result in + /// `out`. + /// + /// Aborts if the number of rests is different than the number of moduli in + /// debug mode, or if the size of out is not in [1, ..., rests.len()]. + pub fn scale( + &self, + rests: ArrayView1, + mut out: ArrayViewMut1, + starting_index: usize, + ) { + debug_assert_eq!(rests.len(), self.from.moduli_u64.len()); + debug_assert!(!out.is_empty()); + debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len()); - // First, let's compute the inner product of the rests with theta_omega. - let mut sum_theta_garner = U192::ZERO; - for (thetag_lo, thetag_hi, ri) in izip!( - self.theta_garner_lo.iter(), - self.theta_garner_hi.iter(), - rests - ) { - let lo = (*ri as u128) * (*thetag_lo as u128); - let hi = (*ri as u128) * (*thetag_hi as u128) + (lo >> 64); - sum_theta_garner = sum_theta_garner.wrapping_add(&U192::from_words([ - lo as u64, - hi as u64, - (hi >> 64) as u64, - ])); - } - // Let's compute v = round(sum_theta_garner / 2^theta_garner_shift) - sum_theta_garner >>= self.theta_garner_shift - 1; - let v = <[u64; 3]>::from(sum_theta_garner); - let v = div_ceil((v[0] as u128) | ((v[1] as u128) << 64), 2); + // First, let's compute the inner product of the rests with theta_omega. + let mut sum_theta_garner = U192::ZERO; + for (thetag_lo, thetag_hi, ri) in izip!( + self.theta_garner_lo.iter(), + self.theta_garner_hi.iter(), + rests + ) { + let lo = (*ri as u128) * (*thetag_lo as u128); + let hi = (*ri as u128) * (*thetag_hi as u128) + (lo >> 64); + sum_theta_garner = sum_theta_garner.wrapping_add(&U192::from_words([ + lo as u64, + hi as u64, + (hi >> 64) as u64, + ])); + } + // Let's compute v = round(sum_theta_garner / 2^theta_garner_shift) + sum_theta_garner >>= self.theta_garner_shift - 1; + let v = <[u64; 3]>::from(sum_theta_garner); + let v = div_ceil((v[0] as u128) | ((v[1] as u128) << 64), 2); - // If the scaling factor is not 1, compute the inner product with the - // theta_omega - let mut w_sign = false; - let mut w = 0u128; - if !self.scaling_factor.is_one { - let mut sum_theta_omega = U256::zero(); - for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!( - self.theta_omega_lo.iter(), - self.theta_omega_hi.iter(), - self.theta_omega_sign.iter(), - rests - ) { - let lo = (*ri as u128) * (*thetao_lo as u128); - let hi = (*ri as u128) * (*thetao_hi as u128) + (lo >> 64); - if *thetao_sign { - sum_theta_omega.wrapping_sub_assign(U256::from([ - lo as u64, - hi as u64, - (hi >> 64) as u64, - 0, - ])); - } else { - sum_theta_omega.wrapping_add_assign(U256::from([ - lo as u64, - hi as u64, - (hi >> 64) as u64, - 0, - ])); - } - } + // If the scaling factor is not 1, compute the inner product with the + // theta_omega + let mut w_sign = false; + let mut w = 0u128; + if !self.scaling_factor.is_one { + let mut sum_theta_omega = U256::zero(); + for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!( + self.theta_omega_lo.iter(), + self.theta_omega_hi.iter(), + self.theta_omega_sign.iter(), + rests + ) { + let lo = (*ri as u128) * (*thetao_lo as u128); + let hi = (*ri as u128) * (*thetao_hi as u128) + (lo >> 64); + if *thetao_sign { + sum_theta_omega.wrapping_sub_assign(U256::from([ + lo as u64, + hi as u64, + (hi >> 64) as u64, + 0, + ])); + } else { + sum_theta_omega.wrapping_add_assign(U256::from([ + lo as u64, + hi as u64, + (hi >> 64) as u64, + 0, + ])); + } + } - // Let's subtract v * theta_gamma to sum_theta_omega. - let vt_lo_lo = ((v as u64) as u128) * (self.theta_gamma_lo as u128); - let vt_lo_hi = ((v as u64) as u128) * (self.theta_gamma_hi as u128); - let vt_hi_lo = (v >> 64) * (self.theta_gamma_lo as u128); - let vt_hi_hi = (v >> 64) * (self.theta_gamma_hi as u128); - let vt_mi = - (vt_lo_lo >> 64) + ((vt_lo_hi as u64) as u128) + ((vt_hi_lo as u64) as u128); - let vt_hi = (vt_lo_hi >> 64) + (vt_mi >> 64) + ((vt_hi_hi as u64) as u128); - if self.theta_gamma_sign { - sum_theta_omega.wrapping_add_assign(U256::from([ - vt_lo_lo as u64, - vt_mi as u64, - vt_hi as u64, - 0, - ])) - } else { - sum_theta_omega.wrapping_sub_assign(U256::from([ - vt_lo_lo as u64, - vt_mi as u64, - vt_hi as u64, - 0, - ])) - } + // Let's subtract v * theta_gamma to sum_theta_omega. + let vt_lo_lo = ((v as u64) as u128) * (self.theta_gamma_lo as u128); + let vt_lo_hi = ((v as u64) as u128) * (self.theta_gamma_hi as u128); + let vt_hi_lo = (v >> 64) * (self.theta_gamma_lo as u128); + let vt_hi_hi = (v >> 64) * (self.theta_gamma_hi as u128); + let vt_mi = + (vt_lo_lo >> 64) + ((vt_lo_hi as u64) as u128) + ((vt_hi_lo as u64) as u128); + let vt_hi = (vt_lo_hi >> 64) + (vt_mi >> 64) + ((vt_hi_hi as u64) as u128); + if self.theta_gamma_sign { + sum_theta_omega.wrapping_add_assign(U256::from([ + vt_lo_lo as u64, + vt_mi as u64, + vt_hi as u64, + 0, + ])) + } else { + sum_theta_omega.wrapping_sub_assign(U256::from([ + vt_lo_lo as u64, + vt_mi as u64, + vt_hi as u64, + 0, + ])) + } - // Let's compute w = round(sum_theta_omega / 2^127). - w_sign = sum_theta_omega.msb() > 0; + // Let's compute w = round(sum_theta_omega / 2^127). + w_sign = sum_theta_omega.msb() > 0; - if w_sign { - w = u128::from(&((!sum_theta_omega) >> 126)) + 1; - w /= 2; - } else { - w = u128::from(&(sum_theta_omega >> 126)); - w = div_ceil(w, 2) - } - } + if w_sign { + w = u128::from(&((!sum_theta_omega) >> 126)) + 1; + w /= 2; + } else { + w = u128::from(&(sum_theta_omega >> 126)); + w = div_ceil(w, 2) + } + } - unsafe { - for i in 0..out.len() { - debug_assert!(starting_index + i < self.to.moduli.len()); - debug_assert!(starting_index + i < self.omega.len()); - debug_assert!(starting_index + i < self.omega_shoup.len()); - debug_assert!(starting_index + i < self.gamma.len()); - debug_assert!(starting_index + i < self.gamma_shoup.len()); - let out_i = out.get_mut(i).unwrap(); - let qi = self.to.moduli.get_unchecked(starting_index + i); - let omega_i = self.omega.get_unchecked(starting_index + i); - let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i); - let gamma_i = self.gamma.get_unchecked(starting_index + i); - let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i); + unsafe { + for i in 0..out.len() { + debug_assert!(starting_index + i < self.to.moduli.len()); + debug_assert!(starting_index + i < self.omega.len()); + debug_assert!(starting_index + i < self.omega_shoup.len()); + debug_assert!(starting_index + i < self.gamma.len()); + debug_assert!(starting_index + i < self.gamma_shoup.len()); + let out_i = out.get_mut(i).unwrap(); + let qi = self.to.moduli.get_unchecked(starting_index + i); + let omega_i = self.omega.get_unchecked(starting_index + i); + let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i); + let gamma_i = self.gamma.get_unchecked(starting_index + i); + let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i); - let mut yi = (qi.modulus() * 2 - - qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i)) - as u128; + let mut yi = (qi.modulus() * 2 + - qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i)) + as u128; - if !self.scaling_factor.is_one { - let wi = qi.lazy_reduce_u128(w); - yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128; - } + if !self.scaling_factor.is_one { + let wi = qi.lazy_reduce_u128(w); + yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128; + } - debug_assert!(rests.len() <= omega_i.len()); - debug_assert!(rests.len() <= omega_shoup_i.len()); - for j in 0..rests.len() { - yi += qi.lazy_mul_shoup( - *rests.get(j).unwrap(), - *omega_i.get_unchecked(j), - *omega_shoup_i.get_unchecked(j), - ) as u128; - } + debug_assert!(rests.len() <= omega_i.len()); + debug_assert!(rests.len() <= omega_shoup_i.len()); + for j in 0..rests.len() { + yi += qi.lazy_mul_shoup( + *rests.get(j).unwrap(), + *omega_i.get_unchecked(j), + *omega_shoup_i.get_unchecked(j), + ) as u128; + } - *out_i = qi.reduce_u128(yi) - } - } - } + *out_i = qi.reduce_u128(yi) + } + } + } } #[cfg(test)] mod tests { - use std::{error::Error, sync::Arc}; + use std::{error::Error, sync::Arc}; - use super::RnsScaler; - use crate::rns::{scaler::ScalingFactor, RnsContext}; - use fhe_util::catch_unwind; - use ndarray::ArrayView1; - use num_bigint::BigUint; - use num_traits::{ToPrimitive, Zero}; - use rand::{thread_rng, RngCore}; + use super::RnsScaler; + use crate::rns::{scaler::ScalingFactor, RnsContext}; + use fhe_util::catch_unwind; + use ndarray::ArrayView1; + use num_bigint::BigUint; + use num_traits::{ToPrimitive, Zero}; + use rand::{thread_rng, RngCore}; - #[test] - fn constructor() -> Result<(), Box> { - let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?); + #[test] + fn constructor() -> Result<(), Box> { + let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?); - let scaler = RnsScaler::new(&q, &q, ScalingFactor::one()); - assert_eq!(scaler.from, q); + let scaler = RnsScaler::new(&q, &q, ScalingFactor::one()); + assert_eq!(scaler.from, q); - assert!( - catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err() - ); - Ok(()) - } + assert!( + catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err() + ); + Ok(()) + } - #[test] - fn scale_same_context() -> Result<(), Box> { - let ntests = 1000; - let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?); - let mut rng = thread_rng(); + #[test] + fn scale_same_context() -> Result<(), Box> { + let ntests = 1000; + let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?); + let mut rng = thread_rng(); - for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { - for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { - let n = BigUint::from(*numerator); - let d = BigUint::from(*denominator); - let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d)); + for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { + for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { + let n = BigUint::from(*numerator); + let d = BigUint::from(*denominator); + let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d)); - for _ in 0..ntests { - let x = vec![ - rng.next_u64() % q.moduli_u64[0], - rng.next_u64() % q.moduli_u64[1], - rng.next_u64() % q.moduli_u64[2], - ]; - let mut x_lift = q.lift(ArrayView1::from(&x)); - let x_sign = x_lift >= (q.modulus() >> 1); - if x_sign { - x_lift = q.modulus() - x_lift; - } + for _ in 0..ntests { + let x = vec![ + rng.next_u64() % q.moduli_u64[0], + rng.next_u64() % q.moduli_u64[1], + rng.next_u64() % q.moduli_u64[2], + ]; + let mut x_lift = q.lift(ArrayView1::from(&x)); + let x_sign = x_lift >= (q.modulus() >> 1); + if x_sign { + x_lift = q.modulus() - x_lift; + } - let z = scaler.scale_new((&x).into(), x.len()); - let x_scaled_round = if x_sign { - if d.to_u64().unwrap() % 2 == 0 { - q.modulus() - - (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % q.modulus() - } else { - q.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % q.modulus() - } - } else { - &(&x_lift * &n + (&d >> 1)) / &d - }; - assert_eq!(z, q.project(&x_scaled_round)); - } - } - } - Ok(()) - } + let z = scaler.scale_new((&x).into(), x.len()); + let x_scaled_round = if x_sign { + if d.to_u64().unwrap() % 2 == 0 { + q.modulus() + - (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % q.modulus() + } else { + q.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % q.modulus() + } + } else { + &(&x_lift * &n + (&d >> 1)) / &d + }; + assert_eq!(z, q.project(&x_scaled_round)); + } + } + } + Ok(()) + } - #[test] - fn scale_different_contexts() -> Result<(), Box> { - let ntests = 100; - let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?); - let r = Arc::new(RnsContext::new(&[ - 4u64, - 4611686018326724609, - 1153, - 4611686018309947393, - 4611686018282684417, - 4611686018257518593, - 4611686018232352769, - 4611686018171535361, - 4611686018106523649, - 4611686018058289153, - ])?); - let mut rng = thread_rng(); + #[test] + fn scale_different_contexts() -> Result<(), Box> { + let ntests = 100; + let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?); + let r = Arc::new(RnsContext::new(&[ + 4u64, + 4611686018326724609, + 1153, + 4611686018309947393, + 4611686018282684417, + 4611686018257518593, + 4611686018232352769, + 4611686018171535361, + 4611686018106523649, + 4611686018058289153, + ])?); + let mut rng = thread_rng(); - for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { - for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { - let n = BigUint::from(*numerator); - let d = BigUint::from(*denominator); - let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d)); - for _ in 0..ntests { - let x = vec![ - rng.next_u64() % q.moduli_u64[0], - rng.next_u64() % q.moduli_u64[1], - rng.next_u64() % q.moduli_u64[2], - ]; + for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { + for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { + let n = BigUint::from(*numerator); + let d = BigUint::from(*denominator); + let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d)); + for _ in 0..ntests { + let x = vec![ + rng.next_u64() % q.moduli_u64[0], + rng.next_u64() % q.moduli_u64[1], + rng.next_u64() % q.moduli_u64[2], + ]; - let mut x_lift = q.lift(ArrayView1::from(&x)); - let x_sign = x_lift >= (q.modulus() >> 1); - if x_sign { - x_lift = q.modulus() - x_lift; - } + let mut x_lift = q.lift(ArrayView1::from(&x)); + let x_sign = x_lift >= (q.modulus() >> 1); + if x_sign { + x_lift = q.modulus() - x_lift; + } - let y = scaler.scale_new((&x).into(), r.moduli.len()); - let x_scaled_round = if x_sign { - if d.to_u64().unwrap() % 2 == 0 { - r.modulus() - - (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % r.modulus() - } else { - r.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % r.modulus() - } - } else { - &(&x_lift * &n + (&d >> 1)) / &d - }; - assert_eq!(y, r.project(&x_scaled_round)); - } - } - } - Ok(()) - } + let y = scaler.scale_new((&x).into(), r.moduli.len()); + let x_scaled_round = if x_sign { + if d.to_u64().unwrap() % 2 == 0 { + r.modulus() + - (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % r.modulus() + } else { + r.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % r.modulus() + } + } else { + &(&x_lift * &n + (&d >> 1)) / &d + }; + assert_eq!(y, r.project(&x_scaled_round)); + } + } + } + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/context.rs b/crates/fhe-math/src/rq/context.rs index 5760810..04f6f2a 100644 --- a/crates/fhe-math/src/rq/context.rs +++ b/crates/fhe-math/src/rq/context.rs @@ -3,233 +3,233 @@ use num_bigint::BigUint; use std::{fmt::Debug, sync::Arc}; use crate::{ - rns::RnsContext, - zq::{ntt::NttOperator, Modulus}, - Error, Result, + rns::RnsContext, + zq::{ntt::NttOperator, Modulus}, + Error, Result, }; /// Struct that holds the context associated with elements in rq. #[derive(Default, Clone, PartialEq, Eq)] pub struct Context { - pub(crate) moduli: Box<[u64]>, - pub(crate) q: Box<[Modulus]>, - pub(crate) rns: Arc, - pub(crate) ops: Box<[NttOperator]>, - pub(crate) degree: usize, - pub(crate) bitrev: Box<[usize]>, - pub(crate) inv_last_qi_mod_qj: Box<[u64]>, - pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>, - pub(crate) next_context: Option>, + pub(crate) moduli: Box<[u64]>, + pub(crate) q: Box<[Modulus]>, + pub(crate) rns: Arc, + pub(crate) ops: Box<[NttOperator]>, + pub(crate) degree: usize, + pub(crate) bitrev: Box<[usize]>, + pub(crate) inv_last_qi_mod_qj: Box<[u64]>, + pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>, + pub(crate) next_context: Option>, } impl Debug for Context { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Context") - .field("moduli", &self.moduli) - // .field("q", &self.q) - // .field("rns", &self.rns) - // .field("ops", &self.ops) - // .field("degree", &self.degree) - // .field("bitrev", &self.bitrev) - // .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj) - // .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup) - .field("next_context", &self.next_context) - .finish() - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Context") + .field("moduli", &self.moduli) + // .field("q", &self.q) + // .field("rns", &self.rns) + // .field("ops", &self.ops) + // .field("degree", &self.degree) + // .field("bitrev", &self.bitrev) + // .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj) + // .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup) + .field("next_context", &self.next_context) + .finish() + } } impl Context { - /// Creates a context from a list of moduli and a polynomial degree. - /// - /// Returns an error if the moduli are not primes less than 62 bits which - /// supports the NTT of size `degree`. - pub fn new(moduli: &[u64], degree: usize) -> Result { - if !degree.is_power_of_two() || degree < 8 { - Err(Error::Default( - "The degree is not a power of two larger or equal to 8".to_string(), - )) - } else { - let mut q = Vec::with_capacity(moduli.len()); - let rns = Arc::new(RnsContext::new(moduli)?); - let mut ops = Vec::with_capacity(moduli.len()); - for modulus in moduli { - let qi = Modulus::new(*modulus)?; - if let Some(op) = NttOperator::new(&qi, degree) { - q.push(qi); - ops.push(op); - } else { - return Err(Error::Default( - "Impossible to construct a Ntt operator".to_string(), - )); - } - } - let bitrev = (0..degree) - .map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1)) - .collect_vec(); + /// Creates a context from a list of moduli and a polynomial degree. + /// + /// Returns an error if the moduli are not primes less than 62 bits which + /// supports the NTT of size `degree`. + pub fn new(moduli: &[u64], degree: usize) -> Result { + if !degree.is_power_of_two() || degree < 8 { + Err(Error::Default( + "The degree is not a power of two larger or equal to 8".to_string(), + )) + } else { + let mut q = Vec::with_capacity(moduli.len()); + let rns = Arc::new(RnsContext::new(moduli)?); + let mut ops = Vec::with_capacity(moduli.len()); + for modulus in moduli { + let qi = Modulus::new(*modulus)?; + if let Some(op) = NttOperator::new(&qi, degree) { + q.push(qi); + ops.push(op); + } else { + return Err(Error::Default( + "Impossible to construct a Ntt operator".to_string(), + )); + } + } + let bitrev = (0..degree) + .map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1)) + .collect_vec(); - let mut inv_last_qi_mod_qj = vec![]; - let mut inv_last_qi_mod_qj_shoup = vec![]; - let q_last = moduli.last().unwrap(); - for qi in &q[..q.len() - 1] { - let inv = qi.inv(qi.reduce(*q_last)).unwrap(); - inv_last_qi_mod_qj.push(inv); - inv_last_qi_mod_qj_shoup.push(qi.shoup(inv)); - } + let mut inv_last_qi_mod_qj = vec![]; + let mut inv_last_qi_mod_qj_shoup = vec![]; + let q_last = moduli.last().unwrap(); + for qi in &q[..q.len() - 1] { + let inv = qi.inv(qi.reduce(*q_last)).unwrap(); + inv_last_qi_mod_qj.push(inv); + inv_last_qi_mod_qj_shoup.push(qi.shoup(inv)); + } - let next_context = if moduli.len() >= 2 { - Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?)) - } else { - None - }; + let next_context = if moduli.len() >= 2 { + Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?)) + } else { + None + }; - Ok(Self { - moduli: moduli.to_owned().into_boxed_slice(), - q: q.into_boxed_slice(), - rns, - ops: ops.into_boxed_slice(), - degree, - bitrev: bitrev.into_boxed_slice(), - inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(), - inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(), - next_context, - }) - } - } + Ok(Self { + moduli: moduli.to_owned().into_boxed_slice(), + q: q.into_boxed_slice(), + rns, + ops: ops.into_boxed_slice(), + degree, + bitrev: bitrev.into_boxed_slice(), + inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(), + inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(), + next_context, + }) + } + } - /// Returns the modulus as a BigUint. - pub fn modulus(&self) -> &BigUint { - self.rns.modulus() - } + /// Returns the modulus as a BigUint. + pub fn modulus(&self) -> &BigUint { + self.rns.modulus() + } - /// Returns a reference to the moduli in this context. - pub fn moduli(&self) -> &[u64] { - &self.moduli - } + /// Returns a reference to the moduli in this context. + pub fn moduli(&self) -> &[u64] { + &self.moduli + } - /// Returns a reference to the moduli as Modulus in this context. - pub fn moduli_operators(&self) -> &[Modulus] { - &self.q - } + /// Returns a reference to the moduli as Modulus in this context. + pub fn moduli_operators(&self) -> &[Modulus] { + &self.q + } - /// Returns the number of iterations to switch to a children context. - /// Returns an error if the context provided is not a child context. - pub fn niterations_to(&self, context: &Arc) -> Result { - if context.as_ref() == self { - return Ok(0); - } + /// Returns the number of iterations to switch to a children context. + /// Returns an error if the context provided is not a child context. + pub fn niterations_to(&self, context: &Arc) -> Result { + if context.as_ref() == self { + return Ok(0); + } - let mut niterations = 0; - let mut found = false; - let mut current_ctx = Arc::new(self.clone()); - while current_ctx.next_context.is_some() { - niterations += 1; - current_ctx = current_ctx.next_context.as_ref().unwrap().clone(); - if ¤t_ctx == context { - found = true; - break; - } - } - if found { - Ok(niterations) - } else { - Err(Error::InvalidContext) - } - } + let mut niterations = 0; + let mut found = false; + let mut current_ctx = Arc::new(self.clone()); + while current_ctx.next_context.is_some() { + niterations += 1; + current_ctx = current_ctx.next_context.as_ref().unwrap().clone(); + if ¤t_ctx == context { + found = true; + break; + } + } + if found { + Ok(niterations) + } else { + Err(Error::InvalidContext) + } + } - /// Returns the context after `i` iterations. - pub fn context_at_level(&self, i: usize) -> Result> { - if i >= self.moduli.len() { - Err(Error::Default( - "No context at the specified level".to_string(), - )) - } else { - let mut current_ctx = Arc::new(self.clone()); - for _ in 0..i { - current_ctx = current_ctx.next_context.as_ref().unwrap().clone(); - } - Ok(current_ctx) - } - } + /// Returns the context after `i` iterations. + pub fn context_at_level(&self, i: usize) -> Result> { + if i >= self.moduli.len() { + Err(Error::Default( + "No context at the specified level".to_string(), + )) + } else { + let mut current_ctx = Arc::new(self.clone()); + for _ in 0..i { + current_ctx = current_ctx.next_context.as_ref().unwrap().clone(); + } + Ok(current_ctx) + } + } } #[cfg(test)] mod tests { - use std::{error::Error, sync::Arc}; + use std::{error::Error, sync::Arc}; - use crate::{rq::Context, zq::ntt::supports_ntt}; + use crate::{rq::Context, zq::ntt::supports_ntt}; - const MODULI: &[u64; 5] = &[ - 1153, - 4611686018326724609, - 4611686018309947393, - 4611686018232352769, - 4611686018171535361, - ]; + const MODULI: &[u64; 5] = &[ + 1153, + 4611686018326724609, + 4611686018309947393, + 4611686018232352769, + 4611686018171535361, + ]; - #[test] - fn context_constructor() { - for modulus in MODULI { - // modulus is = 1 modulo 2 * 8 - assert!(Context::new(&[*modulus], 8).is_ok()); + #[test] + fn context_constructor() { + for modulus in MODULI { + // modulus is = 1 modulo 2 * 8 + assert!(Context::new(&[*modulus], 8).is_ok()); - if supports_ntt(*modulus, 128) { - assert!(Context::new(&[*modulus], 128).is_ok()); - } else { - assert!(Context::new(&[*modulus], 128).is_err()); - } - } + if supports_ntt(*modulus, 128) { + assert!(Context::new(&[*modulus], 128).is_ok()); + } else { + assert!(Context::new(&[*modulus], 128).is_err()); + } + } - // All moduli in MODULI are = 1 modulo 2 * 8 - assert!(Context::new(MODULI, 8).is_ok()); + // All moduli in MODULI are = 1 modulo 2 * 8 + assert!(Context::new(MODULI, 8).is_ok()); - // This should fail since 1153 != 1 moduli 2 * 128 - assert!(Context::new(MODULI, 128).is_err()); - } + // This should fail since 1153 != 1 moduli 2 * 128 + assert!(Context::new(MODULI, 128).is_err()); + } - #[test] - fn next_context() -> Result<(), Box> { - // A context should have a children pointing to a context with one less modulus. - let context = Arc::new(Context::new(MODULI, 8)?); - assert_eq!( - context.next_context, - Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?)) - ); + #[test] + fn next_context() -> Result<(), Box> { + // A context should have a children pointing to a context with one less modulus. + let context = Arc::new(Context::new(MODULI, 8)?); + assert_eq!( + context.next_context, + Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?)) + ); - // We can go down the chain of the MODULI.len() - 1 context's. - let mut number_of_children = 0; - let mut current = context; - while current.next_context.is_some() { - number_of_children += 1; - current = current.next_context.as_ref().unwrap().clone(); - } - assert_eq!(number_of_children, MODULI.len() - 1); + // We can go down the chain of the MODULI.len() - 1 context's. + let mut number_of_children = 0; + let mut current = context; + while current.next_context.is_some() { + number_of_children += 1; + current = current.next_context.as_ref().unwrap().clone(); + } + assert_eq!(number_of_children, MODULI.len() - 1); - Ok(()) - } + Ok(()) + } - #[test] - fn niterations_to() -> Result<(), Box> { - // A context should have a children pointing to a context with one less modulus. - let context = Arc::new(Context::new(MODULI, 8)?); + #[test] + fn niterations_to() -> Result<(), Box> { + // A context should have a children pointing to a context with one less modulus. + let context = Arc::new(Context::new(MODULI, 8)?); - assert_eq!(context.niterations_to(&context).ok(), Some(0)); + assert_eq!(context.niterations_to(&context).ok(), Some(0)); - assert_eq!( - context - .niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?)) - .err(), - Some(crate::Error::InvalidContext) - ); + assert_eq!( + context + .niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?)) + .err(), + Some(crate::Error::InvalidContext) + ); - for i in 1..MODULI.len() { - assert_eq!( - context - .niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?)) - .ok(), - Some(i) - ); - } + for i in 1..MODULI.len() { + assert_eq!( + context + .niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?)) + .ok(), + Some(i) + ); + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/convert.rs b/crates/fhe-math/src/rq/convert.rs index 70d6c04..88eeded 100644 --- a/crates/fhe-math/src/rq/convert.rs +++ b/crates/fhe-math/src/rq/convert.rs @@ -2,8 +2,8 @@ use super::{traits::TryConvertFrom, Context, Poly, Representation}; use crate::{ - proto::rq::{rq, Rq}, - Error, Result, + proto::rq::{rq, Rq}, + Error, Result, }; use itertools::{izip, Itertools}; use ndarray::{Array2, ArrayView, Axis}; @@ -13,655 +13,655 @@ use std::sync::Arc; use zeroize::{Zeroize, Zeroizing}; impl From<&Poly> for Rq { - fn from(p: &Poly) -> Self { - assert!(!p.has_lazy_coefficients); + fn from(p: &Poly) -> Self { + assert!(!p.has_lazy_coefficients); - let mut proto = Rq::new(); - match p.representation { - Representation::PowerBasis => { - proto.representation = EnumOrUnknown::new(rq::Representation::POWERBASIS); - } - Representation::Ntt => { - proto.representation = EnumOrUnknown::new(rq::Representation::NTT); - } - Representation::NttShoup => { - proto.representation = EnumOrUnknown::new(rq::Representation::NTTSHOUP); - } - } - let mut serialization_length = 0; - p.ctx - .q - .iter() - .for_each(|qi| serialization_length += qi.serialization_length(p.ctx.degree)); - let mut serialization = Vec::with_capacity(serialization_length); + let mut proto = Rq::new(); + match p.representation { + Representation::PowerBasis => { + proto.representation = EnumOrUnknown::new(rq::Representation::POWERBASIS); + } + Representation::Ntt => { + proto.representation = EnumOrUnknown::new(rq::Representation::NTT); + } + Representation::NttShoup => { + proto.representation = EnumOrUnknown::new(rq::Representation::NTTSHOUP); + } + } + let mut serialization_length = 0; + p.ctx + .q + .iter() + .for_each(|qi| serialization_length += qi.serialization_length(p.ctx.degree)); + let mut serialization = Vec::with_capacity(serialization_length); - izip!(p.coefficients.outer_iter(), p.ctx.q.iter()) - .for_each(|(v, qi)| serialization.append(&mut qi.serialize_vec(v.as_slice().unwrap()))); - proto.coefficients = serialization; - proto.degree = p.ctx.degree as u32; - proto.allow_variable_time = p.allow_variable_time_computations; - proto - } + izip!(p.coefficients.outer_iter(), p.ctx.q.iter()) + .for_each(|(v, qi)| serialization.append(&mut qi.serialize_vec(v.as_slice().unwrap()))); + proto.coefficients = serialization; + proto.degree = p.ctx.degree as u32; + proto.allow_variable_time = p.allow_variable_time_computations; + proto + } } impl TryConvertFrom> for Poly { - fn try_convert_from( - mut v: Vec, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = representation.into(); - match repr { - Some(Representation::Ntt) => { - if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { - Ok(Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }) - } else { - Err(Error::Default( - "In Ntt representation, all coefficients must be specified".to_string(), - )) - } - } - Some(Representation::NttShoup) => { - if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { - let mut p = Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }; - p.compute_coefficients_shoup(); - Ok(p) - } else { - Err(Error::Default( - "In NttShoup representation, all coefficients must be specified" - .to_string(), - )) - } - } - Some(Representation::PowerBasis) => { - if v.len() == ctx.q.len() * ctx.degree { - let coefficients = - Array2::from_shape_vec((ctx.q.len(), ctx.degree), v).unwrap(); - Ok(Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }) - } else if v.len() <= ctx.degree { - let mut out = Self::zero(ctx, repr.unwrap()); - if variable_time { - unsafe { - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( - |(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - wi[..v.len()].copy_from_slice(&v); - qi.reduce_vec_vt(wi); - }, - ); - out.allow_variable_time_computations(); - } - } else { - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( - |(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - wi[..v.len()].copy_from_slice(&v); - qi.reduce_vec(wi); - }, - ); - v.zeroize(); - } - Ok(out) - } else { - Err(Error::Default("In PowerBasis representation, either all coefficients must be specified, or only coefficients up to the degree".to_string())) - } - } - None => Err(Error::Default( - "When converting from a vector, the representation needs to be specified" - .to_string(), - )), - } - } + fn try_convert_from( + mut v: Vec, + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + let repr = representation.into(); + match repr { + Some(Representation::Ntt) => { + if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { + Ok(Self { + ctx: ctx.clone(), + representation: repr.unwrap(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + }) + } else { + Err(Error::Default( + "In Ntt representation, all coefficients must be specified".to_string(), + )) + } + } + Some(Representation::NttShoup) => { + if let Ok(coefficients) = Array2::from_shape_vec((ctx.q.len(), ctx.degree), v) { + let mut p = Self { + ctx: ctx.clone(), + representation: repr.unwrap(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + }; + p.compute_coefficients_shoup(); + Ok(p) + } else { + Err(Error::Default( + "In NttShoup representation, all coefficients must be specified" + .to_string(), + )) + } + } + Some(Representation::PowerBasis) => { + if v.len() == ctx.q.len() * ctx.degree { + let coefficients = + Array2::from_shape_vec((ctx.q.len(), ctx.degree), v).unwrap(); + Ok(Self { + ctx: ctx.clone(), + representation: repr.unwrap(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + }) + } else if v.len() <= ctx.degree { + let mut out = Self::zero(ctx, repr.unwrap()); + if variable_time { + unsafe { + izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( + |(mut w, qi)| { + let wi = w.as_slice_mut().unwrap(); + wi[..v.len()].copy_from_slice(&v); + qi.reduce_vec_vt(wi); + }, + ); + out.allow_variable_time_computations(); + } + } else { + izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each( + |(mut w, qi)| { + let wi = w.as_slice_mut().unwrap(); + wi[..v.len()].copy_from_slice(&v); + qi.reduce_vec(wi); + }, + ); + v.zeroize(); + } + Ok(out) + } else { + Err(Error::Default("In PowerBasis representation, either all coefficients must be specified, or only coefficients up to the degree".to_string())) + } + } + None => Err(Error::Default( + "When converting from a vector, the representation needs to be specified" + .to_string(), + )), + } + } } impl TryConvertFrom<&Rq> for Poly { - fn try_convert_from( - value: &Rq, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = value.representation.enum_value_or_default(); - let representation_from_proto = match repr { - rq::Representation::POWERBASIS => Representation::PowerBasis, - rq::Representation::NTT => Representation::Ntt, - rq::Representation::NTTSHOUP => Representation::NttShoup, - _ => return Err(Error::Default("Unknown representation".to_string())), - }; + fn try_convert_from( + value: &Rq, + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + let repr = value.representation.enum_value_or_default(); + let representation_from_proto = match repr { + rq::Representation::POWERBASIS => Representation::PowerBasis, + rq::Representation::NTT => Representation::Ntt, + rq::Representation::NTTSHOUP => Representation::NttShoup, + _ => return Err(Error::Default("Unknown representation".to_string())), + }; - let variable_time = variable_time || value.allow_variable_time; + let variable_time = variable_time || value.allow_variable_time; - if let Some(r) = representation.into() as Option { - if r != representation_from_proto { - return Err(Error::Default("The representation asked for does not match the representation in the serialization".to_string())); - } - } + if let Some(r) = representation.into() as Option { + if r != representation_from_proto { + return Err(Error::Default("The representation asked for does not match the representation in the serialization".to_string())); + } + } - let degree = value.degree as usize; - if degree % 8 != 0 || degree < 8 { - return Err(Error::Default("Invalid degree".to_string())); - } + let degree = value.degree as usize; + if degree % 8 != 0 || degree < 8 { + return Err(Error::Default("Invalid degree".to_string())); + } - let mut expected_nbytes = 0; - ctx.q - .iter() - .for_each(|qi| expected_nbytes += qi.serialization_length(degree)); - if value.coefficients.len() != expected_nbytes { - return Err(Error::Default("Invalid coefficients".to_string())); - } + let mut expected_nbytes = 0; + ctx.q + .iter() + .for_each(|qi| expected_nbytes += qi.serialization_length(degree)); + if value.coefficients.len() != expected_nbytes { + return Err(Error::Default("Invalid coefficients".to_string())); + } - let mut coefficients = Vec::with_capacity(ctx.q.len() * ctx.degree); - let mut index = 0; - for i in 0..ctx.q.len() { - let qi = &ctx.q[i]; - let size = qi.serialization_length(degree); - let mut v = qi.deserialize_vec(&value.coefficients[index..index + size]); - coefficients.append(&mut v); - index += size; - } + let mut coefficients = Vec::with_capacity(ctx.q.len() * ctx.degree); + let mut index = 0; + for i in 0..ctx.q.len() { + let qi = &ctx.q[i]; + let size = qi.serialization_length(degree); + let mut v = qi.deserialize_vec(&value.coefficients[index..index + size]); + coefficients.append(&mut v); + index += size; + } - Poly::try_convert_from(coefficients, ctx, variable_time, representation_from_proto) - } + Poly::try_convert_from(coefficients, ctx, variable_time, representation_from_proto) + } } impl TryConvertFrom> for Poly { - fn try_convert_from( - a: Array2, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - if a.shape() != [ctx.q.len(), ctx.degree] { - Err(Error::Default( - "The array of coefficient does not have the correct shape".to_string(), - )) - } else if let Some(repr) = representation.into() { - let mut p = Self { - ctx: ctx.clone(), - representation: repr, - allow_variable_time_computations: variable_time, - coefficients: a, - coefficients_shoup: None, - has_lazy_coefficients: false, - }; - if p.representation == Representation::NttShoup { - p.compute_coefficients_shoup() - } - Ok(p) - } else { - Err(Error::Default("When converting from a 2-dimensional array, the representation needs to be specified".to_string())) - } - } + fn try_convert_from( + a: Array2, + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + if a.shape() != [ctx.q.len(), ctx.degree] { + Err(Error::Default( + "The array of coefficient does not have the correct shape".to_string(), + )) + } else if let Some(repr) = representation.into() { + let mut p = Self { + ctx: ctx.clone(), + representation: repr, + allow_variable_time_computations: variable_time, + coefficients: a, + coefficients_shoup: None, + has_lazy_coefficients: false, + }; + if p.representation == Representation::NttShoup { + p.compute_coefficients_shoup() + } + Ok(p) + } else { + Err(Error::Default("When converting from a 2-dimensional array, the representation needs to be specified".to_string())) + } + } } impl<'a> TryConvertFrom<&'a [u64]> for Poly { - fn try_convert_from( - v: &'a [u64], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.to_vec(), ctx, variable_time, representation) - } + fn try_convert_from( + v: &'a [u64], + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + Poly::try_convert_from(v.to_vec(), ctx, variable_time, representation) + } } impl<'a> TryConvertFrom<&'a [i64]> for Poly { - fn try_convert_from( - v: &'a [i64], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - if representation.into() != Some(Representation::PowerBasis) { - Err(Error::Default( - "Converting signed integer require to import in PowerBasis representation" - .to_string(), - )) - } else if v.len() <= ctx.degree { - let mut out = Self::zero(ctx, Representation::PowerBasis); - if variable_time { - unsafe { out.allow_variable_time_computations() } - } - izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut w, qi)| { - let wi = w.as_slice_mut().unwrap(); - if variable_time { - unsafe { wi[..v.len()].copy_from_slice(&qi.reduce_vec_i64_vt(v)) } - } else { - wi[..v.len()].copy_from_slice(Zeroizing::new(qi.reduce_vec_i64(v)).as_ref()); - } - }); - Ok(out) - } else { - Err(Error::Default("In PowerBasis representation with signed integers, only `degree` coefficients can be specified".to_string())) - } - } + fn try_convert_from( + v: &'a [i64], + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + if representation.into() != Some(Representation::PowerBasis) { + Err(Error::Default( + "Converting signed integer require to import in PowerBasis representation" + .to_string(), + )) + } else if v.len() <= ctx.degree { + let mut out = Self::zero(ctx, Representation::PowerBasis); + if variable_time { + unsafe { out.allow_variable_time_computations() } + } + izip!(out.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut w, qi)| { + let wi = w.as_slice_mut().unwrap(); + if variable_time { + unsafe { wi[..v.len()].copy_from_slice(&qi.reduce_vec_i64_vt(v)) } + } else { + wi[..v.len()].copy_from_slice(Zeroizing::new(qi.reduce_vec_i64(v)).as_ref()); + } + }); + Ok(out) + } else { + Err(Error::Default("In PowerBasis representation with signed integers, only `degree` coefficients can be specified".to_string())) + } + } } impl<'a> TryConvertFrom<&'a Vec> for Poly { - fn try_convert_from( - v: &'a Vec, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref() as &[i64], ctx, variable_time, representation) - } + fn try_convert_from( + v: &'a Vec, + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + Poly::try_convert_from(v.as_ref() as &[i64], ctx, variable_time, representation) + } } impl<'a> TryConvertFrom<&'a [BigUint]> for Poly { - fn try_convert_from( - v: &'a [BigUint], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - let repr = representation.into(); + fn try_convert_from( + v: &'a [BigUint], + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + let repr = representation.into(); - if v.len() > ctx.degree { - Err(Error::Default( - "The slice contains too many big integers compared to the polynomial degree" - .to_string(), - )) - } else if repr.is_some() { - let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); + if v.len() > ctx.degree { + Err(Error::Default( + "The slice contains too many big integers compared to the polynomial degree" + .to_string(), + )) + } else if repr.is_some() { + let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); - izip!(coefficients.axis_iter_mut(Axis(1)), v).for_each(|(mut c, vi)| { - c.assign(&ArrayView::from(&ctx.rns.project(vi))); - }); + izip!(coefficients.axis_iter_mut(Axis(1)), v).for_each(|(mut c, vi)| { + c.assign(&ArrayView::from(&ctx.rns.project(vi))); + }); - let mut p = Self { - ctx: ctx.clone(), - representation: repr.unwrap(), - allow_variable_time_computations: variable_time, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }; + let mut p = Self { + ctx: ctx.clone(), + representation: repr.unwrap(), + allow_variable_time_computations: variable_time, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + }; - match p.representation { - Representation::PowerBasis => Ok(p), - Representation::Ntt => Ok(p), - Representation::NttShoup => { - p.compute_coefficients_shoup(); - Ok(p) - } - } - } else { - Err(Error::Default( - "When converting from a vector, the representation needs to be specified" - .to_string(), - )) - } - } + match p.representation { + Representation::PowerBasis => Ok(p), + Representation::Ntt => Ok(p), + Representation::NttShoup => { + p.compute_coefficients_shoup(); + Ok(p) + } + } + } else { + Err(Error::Default( + "When converting from a vector, the representation needs to be specified" + .to_string(), + )) + } + } } impl<'a> TryConvertFrom<&'a Vec> for Poly { - fn try_convert_from( - v: &'a Vec, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.to_vec(), ctx, variable_time, representation) - } + fn try_convert_from( + v: &'a Vec, + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + Poly::try_convert_from(v.to_vec(), ctx, variable_time, representation) + } } impl<'a, const N: usize> TryConvertFrom<&'a [u64; N]> for Poly { - fn try_convert_from( - v: &'a [u64; N], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) - } + fn try_convert_from( + v: &'a [u64; N], + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) + } } impl<'a, const N: usize> TryConvertFrom<&'a [BigUint; N]> for Poly { - fn try_convert_from( - v: &'a [BigUint; N], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) - } + fn try_convert_from( + v: &'a [BigUint; N], + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) + } } impl<'a, const N: usize> TryConvertFrom<&'a [i64; N]> for Poly { - fn try_convert_from( - v: &'a [i64; N], - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>, - { - Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) - } + fn try_convert_from( + v: &'a [i64; N], + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>, + { + Poly::try_convert_from(v.as_ref(), ctx, variable_time, representation) + } } impl From<&Poly> for Vec { - fn from(p: &Poly) -> Self { - p.coefficients.as_slice().unwrap().to_vec() - } + fn from(p: &Poly) -> Self { + p.coefficients.as_slice().unwrap().to_vec() + } } impl From<&Poly> for Vec { - fn from(p: &Poly) -> Self { - izip!(p.coefficients.axis_iter(Axis(1))) - .map(|c| p.ctx.rns.lift(c)) - .collect_vec() - } + fn from(p: &Poly) -> Self { + izip!(p.coefficients.axis_iter(Axis(1))) + .map(|c| p.ctx.rns.lift(c)) + .collect_vec() + } } #[cfg(test)] mod tests { - use crate::{ - proto::rq::Rq, - rq::{traits::TryConvertFrom, Context, Poly, Representation}, - Error as CrateError, - }; - use num_bigint::BigUint; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use crate::{ + proto::rq::Rq, + rq::{traits::TryConvertFrom, Context, Poly, Representation}, + Error as CrateError, + }; + use num_bigint::BigUint; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - static MODULI: &[u64; 3] = &[1153, 4611686018326724609, 4611686018309947393]; + static MODULI: &[u64; 3] = &[1153, 4611686018326724609, 4611686018309947393]; - #[test] - fn proto() -> Result<(), Box> { - let mut rng = thread_rng(); - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let proto = Rq::from(&p); - assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis)?, - p - ); - assert_eq!( + #[test] + fn proto() -> Result<(), Box> { + let mut rng = thread_rng(); + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let proto = Rq::from(&p); + assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); + assert_eq!( + Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis)?, + p + ); + assert_eq!( Poly::try_convert_from(&proto, &ctx, false, Representation::Ntt) .expect_err("Should fail because of mismatched representations"), CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) ); - assert_eq!( + assert_eq!( Poly::try_convert_from(&proto, &ctx, false, Representation::NttShoup) .expect_err("Should fail because of mismatched representations"), CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) ); - } + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let proto = Rq::from(&p); - assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis)?, - p - ); - assert_eq!( + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let proto = Rq::from(&p); + assert_eq!(Poly::try_convert_from(&proto, &ctx, false, None)?, p); + assert_eq!( + Poly::try_convert_from(&proto, &ctx, false, Representation::PowerBasis)?, + p + ); + assert_eq!( Poly::try_convert_from(&proto, &ctx, false, Representation::Ntt) .expect_err("Should fail because of mismatched representations"), CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) ); - assert_eq!( + assert_eq!( Poly::try_convert_from(&proto, &ctx, false, Representation::NttShoup) .expect_err("Should fail because of mismatched representations"), CrateError::Default("The representation asked for does not match the representation in the serialization".to_string()) ); - let ctx = Arc::new(Context::new(&MODULI[0..1], 8)?); - assert_eq!( - Poly::try_convert_from(&proto, &ctx, false, None) - .expect_err("Should fail because of incorrect context"), - CrateError::Default("Invalid coefficients".to_string()) - ); + let ctx = Arc::new(Context::new(&MODULI[0..1], 8)?); + assert_eq!( + Poly::try_convert_from(&proto, &ctx, false, None) + .expect_err("Should fail because of incorrect context"), + CrateError::Default("Invalid coefficients".to_string()) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn try_convert_from_slice_zero() -> Result<(), Box> { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + #[test] + fn try_convert_from_slice_zero() -> Result<(), Box> { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); - // Power Basis - assert_eq!( - Poly::try_convert_from(&[0u64], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(&[0i64], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(&[0i64; 8], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from( - &[0u64; 9], // One too many - &ctx, - false, - Representation::PowerBasis, - ) - .is_err()); + // Power Basis + assert_eq!( + Poly::try_convert_from(&[0u64], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert_eq!( + Poly::try_convert_from(&[0i64], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert_eq!( + Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert_eq!( + Poly::try_convert_from(&[0i64; 8], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from( + &[0u64; 9], // One too many + &ctx, + false, + Representation::PowerBasis, + ) + .is_err()); - // Ntt - assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); - assert!(Poly::try_convert_from(&[0i64], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); - assert!(Poly::try_convert_from(&[0i64; 8], &ctx, false, Representation::Ntt).is_err()); - assert!(Poly::try_convert_from( - &[0u64; 9], // One too many - &ctx, - false, - Representation::Ntt, - ) - .is_err()); - } + // Ntt + assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from(&[0i64], &ctx, false, Representation::Ntt).is_err()); + assert_eq!( + Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::Ntt)?, + Poly::zero(&ctx, Representation::Ntt) + ); + assert!(Poly::try_convert_from(&[0i64; 8], &ctx, false, Representation::Ntt).is_err()); + assert!(Poly::try_convert_from( + &[0u64; 9], // One too many + &ctx, + false, + Representation::Ntt, + ) + .is_err()); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - assert_eq!( - Poly::try_convert_from( - Vec::::default(), - &ctx, - false, - Representation::PowerBasis, - )?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!( - Poly::try_convert_from(Vec::::default(), &ctx, false, Representation::Ntt) - .is_err() - ); + let ctx = Arc::new(Context::new(MODULI, 8)?); + assert_eq!( + Poly::try_convert_from( + Vec::::default(), + &ctx, + false, + Representation::PowerBasis, + )?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!( + Poly::try_convert_from(Vec::::default(), &ctx, false, Representation::Ntt) + .is_err() + ); - assert_eq!( - Poly::try_convert_from(&[0u64], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); + assert_eq!( + Poly::try_convert_from(&[0u64], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(&[0u64], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::Ntt).is_err()); + assert_eq!( + Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(&[0u64; 8], &ctx, false, Representation::Ntt).is_err()); - assert!( - Poly::try_convert_from(&[0u64; 9], &ctx, false, Representation::PowerBasis).is_err() - ); - assert!(Poly::try_convert_from(&[0u64; 9], &ctx, false, Representation::Ntt).is_err()); + assert!( + Poly::try_convert_from(&[0u64; 9], &ctx, false, Representation::PowerBasis).is_err() + ); + assert!(Poly::try_convert_from(&[0u64; 9], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(&[0u64; 24], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(&[0u64; 24], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); + assert_eq!( + Poly::try_convert_from(&[0u64; 24], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert_eq!( + Poly::try_convert_from(&[0u64; 24], &ctx, false, Representation::Ntt)?, + Poly::zero(&ctx, Representation::Ntt) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn try_convert_from_vec_zero() -> Result<(), Box> { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - assert_eq!( - Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![], &ctx, false, Representation::Ntt).is_err()); + #[test] + fn try_convert_from_vec_zero() -> Result<(), Box> { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + assert_eq!( + Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(vec![], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(vec![0], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); + assert_eq!( + Poly::try_convert_from(vec![0], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); + assert_eq!( + Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert_eq!( + Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::Ntt)?, + Poly::zero(&ctx, Representation::Ntt) + ); - assert!( - Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::PowerBasis) - .is_err() - ); - assert!(Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::Ntt).is_err()); - } + assert!( + Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::PowerBasis) + .is_err() + ); + assert!(Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::Ntt).is_err()); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - assert_eq!( - Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![], &ctx, false, Representation::Ntt).is_err()); + let ctx = Arc::new(Context::new(MODULI, 8)?); + assert_eq!( + Poly::try_convert_from(vec![], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(vec![], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(vec![0], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); + assert_eq!( + Poly::try_convert_from(vec![0], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(vec![0], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert!(Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::Ntt).is_err()); + assert_eq!( + Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert!(Poly::try_convert_from(vec![0; 8], &ctx, false, Representation::Ntt).is_err()); - assert!( - Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::PowerBasis).is_err() - ); - assert!(Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::Ntt).is_err()); + assert!( + Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::PowerBasis).is_err() + ); + assert!(Poly::try_convert_from(vec![0; 9], &ctx, false, Representation::Ntt).is_err()); - assert_eq!( - Poly::try_convert_from(vec![0; 24], &ctx, false, Representation::PowerBasis)?, - Poly::zero(&ctx, Representation::PowerBasis) - ); - assert_eq!( - Poly::try_convert_from(vec![0; 24], &ctx, false, Representation::Ntt)?, - Poly::zero(&ctx, Representation::Ntt) - ); + assert_eq!( + Poly::try_convert_from(vec![0; 24], &ctx, false, Representation::PowerBasis)?, + Poly::zero(&ctx, Representation::PowerBasis) + ); + assert_eq!( + Poly::try_convert_from(vec![0; 24], &ctx, false, Representation::Ntt)?, + Poly::zero(&ctx, Representation::Ntt) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn biguint() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let p_coeffs = Vec::::from(&p); - let q = Poly::try_convert_from( - p_coeffs.as_slice(), - &ctx, - false, - Representation::PowerBasis, - )?; - assert_eq!(p, q); - } + #[test] + fn biguint() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p_coeffs = Vec::::from(&p); + let q = Poly::try_convert_from( + p_coeffs.as_slice(), + &ctx, + false, + Representation::PowerBasis, + )?; + assert_eq!(p, q); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let p_coeffs = Vec::::from(&p); - assert_eq!(p_coeffs.len(), ctx.degree); - let q = Poly::try_convert_from( - p_coeffs.as_slice(), - &ctx, - false, - Representation::PowerBasis, - )?; - assert_eq!(p, q); - } - Ok(()) - } + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let p_coeffs = Vec::::from(&p); + assert_eq!(p_coeffs.len(), ctx.degree); + let q = Poly::try_convert_from( + p_coeffs.as_slice(), + &ctx, + false, + Representation::PowerBasis, + )?; + assert_eq!(p, q); + } + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/mod.rs b/crates/fhe-math/src/rq/mod.rs index c8705da..ea16fed 100644 --- a/crates/fhe-math/src/rq/mod.rs +++ b/crates/fhe-math/src/rq/mod.rs @@ -28,1083 +28,1083 @@ use zeroize::{Zeroize, Zeroizing}; /// Possible representations of the underlying polynomial. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub enum Representation { - /// This is the list of coefficients ci, such that the polynomial is c0 + c1 - /// * x + ... + c_(degree - 1) * x^(degree - 1) - #[default] - PowerBasis, - /// This is the NTT representation of the PowerBasis representation. - Ntt, - /// This is a "Shoup" representation of the Ntt representation used for - /// faster multiplication. - NttShoup, + /// This is the list of coefficients ci, such that the polynomial is c0 + c1 + /// * x + ... + c_(degree - 1) * x^(degree - 1) + #[default] + PowerBasis, + /// This is the NTT representation of the PowerBasis representation. + Ntt, + /// This is a "Shoup" representation of the Ntt representation used for + /// faster multiplication. + NttShoup, } /// An exponent for a substitution. #[derive(Debug, PartialEq, Eq)] pub struct SubstitutionExponent { - /// The value of the exponent. - pub exponent: usize, + /// The value of the exponent. + pub exponent: usize, - ctx: Arc, - power_bitrev: Vec, + ctx: Arc, + power_bitrev: Vec, } impl SubstitutionExponent { - /// Creates a substitution element from an exponent. - /// Returns an error if the exponent is even modulo 2 * degree. - pub fn new(ctx: &Arc, exponent: usize) -> Result { - let exponent = exponent % (2 * ctx.degree); - if exponent & 1 == 0 { - return Err(Error::Default( - "The exponent should be odd modulo 2 * degree".to_string(), - )); - } - let mut power = (exponent - 1) / 2; - let mask = ctx.degree - 1; - let power_bitrev = (0..ctx.degree) - .map(|_| { - let r = (power & mask).reverse_bits() >> (ctx.degree.leading_zeros() + 1); - power += exponent; - r - }) - .collect_vec(); - Ok(Self { - ctx: ctx.clone(), - exponent, - power_bitrev, - }) - } + /// Creates a substitution element from an exponent. + /// Returns an error if the exponent is even modulo 2 * degree. + pub fn new(ctx: &Arc, exponent: usize) -> Result { + let exponent = exponent % (2 * ctx.degree); + if exponent & 1 == 0 { + return Err(Error::Default( + "The exponent should be odd modulo 2 * degree".to_string(), + )); + } + let mut power = (exponent - 1) / 2; + let mask = ctx.degree - 1; + let power_bitrev = (0..ctx.degree) + .map(|_| { + let r = (power & mask).reverse_bits() >> (ctx.degree.leading_zeros() + 1); + power += exponent; + r + }) + .collect_vec(); + Ok(Self { + ctx: ctx.clone(), + exponent, + power_bitrev, + }) + } } /// Struct that holds a polynomial for a specific context. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct Poly { - ctx: Arc, - representation: Representation, - has_lazy_coefficients: bool, - allow_variable_time_computations: bool, - coefficients: Array2, - coefficients_shoup: Option>, + ctx: Arc, + representation: Representation, + has_lazy_coefficients: bool, + allow_variable_time_computations: bool, + coefficients: Array2, + coefficients_shoup: Option>, } impl AsRef for Poly { - fn as_ref(&self) -> &Poly { - self - } + fn as_ref(&self) -> &Poly { + self + } } impl AsMut for Poly { - fn as_mut(&mut self) -> &mut Poly { - self - } + fn as_mut(&mut self) -> &mut Poly { + self + } } impl Poly { - /// Creates a polynomial holding the constant 0. - pub fn zero(ctx: &Arc, representation: Representation) -> Self { - Self { - ctx: ctx.clone(), - representation: representation.clone(), - allow_variable_time_computations: false, - has_lazy_coefficients: false, - coefficients: Array2::zeros((ctx.q.len(), ctx.degree)), - coefficients_shoup: if representation == Representation::NttShoup { - Some(Array2::zeros((ctx.q.len(), ctx.degree))) - } else { - None - }, - } - } + /// Creates a polynomial holding the constant 0. + pub fn zero(ctx: &Arc, representation: Representation) -> Self { + Self { + ctx: ctx.clone(), + representation: representation.clone(), + allow_variable_time_computations: false, + has_lazy_coefficients: false, + coefficients: Array2::zeros((ctx.q.len(), ctx.degree)), + coefficients_shoup: if representation == Representation::NttShoup { + Some(Array2::zeros((ctx.q.len(), ctx.degree))) + } else { + None + }, + } + } - /// Enable variable time computations when this polynomial is involved. - /// - /// # Safety - /// - /// By default, this is marked as unsafe, but is usually safe when only - /// public data is processed. - pub unsafe fn allow_variable_time_computations(&mut self) { - self.allow_variable_time_computations = true - } + /// Enable variable time computations when this polynomial is involved. + /// + /// # Safety + /// + /// By default, this is marked as unsafe, but is usually safe when only + /// public data is processed. + pub unsafe fn allow_variable_time_computations(&mut self) { + self.allow_variable_time_computations = true + } - /// Disable variable time computations when this polynomial is involved. - pub fn disallow_variable_time_computations(&mut self) { - self.allow_variable_time_computations = false - } + /// Disable variable time computations when this polynomial is involved. + pub fn disallow_variable_time_computations(&mut self) { + self.allow_variable_time_computations = false + } - /// Current representation of the polynomial. - pub const fn representation(&self) -> &Representation { - &self.representation - } + /// Current representation of the polynomial. + pub const fn representation(&self) -> &Representation { + &self.representation + } - /// Change the representation of the underlying polynomial. - pub fn change_representation(&mut self, to: Representation) { - match self.representation { - Representation::PowerBasis => { - match to { - Representation::Ntt => self.ntt_forward(), - Representation::NttShoup => { - self.ntt_forward(); - self.compute_coefficients_shoup(); - } - Representation::PowerBasis => {} // no-op - } - } - Representation::Ntt => { - match to { - Representation::PowerBasis => self.ntt_backward(), - Representation::NttShoup => self.compute_coefficients_shoup(), - Representation::Ntt => {} // no-op - } - } - Representation::NttShoup => { - if to != Representation::NttShoup { - // We are not sure whether this polynomial was sensitive or not, - // so for security, we zeroize the Shoup coefficients. - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - self.coefficients_shoup = None - } - match to { - Representation::PowerBasis => self.ntt_backward(), - Representation::Ntt => {} // no-op - Representation::NttShoup => {} // no-op - } - } - } + /// Change the representation of the underlying polynomial. + pub fn change_representation(&mut self, to: Representation) { + match self.representation { + Representation::PowerBasis => { + match to { + Representation::Ntt => self.ntt_forward(), + Representation::NttShoup => { + self.ntt_forward(); + self.compute_coefficients_shoup(); + } + Representation::PowerBasis => {} // no-op + } + } + Representation::Ntt => { + match to { + Representation::PowerBasis => self.ntt_backward(), + Representation::NttShoup => self.compute_coefficients_shoup(), + Representation::Ntt => {} // no-op + } + } + Representation::NttShoup => { + if to != Representation::NttShoup { + // We are not sure whether this polynomial was sensitive or not, + // so for security, we zeroize the Shoup coefficients. + self.coefficients_shoup + .as_mut() + .unwrap() + .as_slice_mut() + .unwrap() + .zeroize(); + self.coefficients_shoup = None + } + match to { + Representation::PowerBasis => self.ntt_backward(), + Representation::Ntt => {} // no-op + Representation::NttShoup => {} // no-op + } + } + } - self.representation = to; - } + self.representation = to; + } - /// Compute the Shoup representation of the coefficients. - fn compute_coefficients_shoup(&mut self) { - let mut coefficients_shoup = Array2::zeros((self.ctx.q.len(), self.ctx.degree)); - izip!( - coefficients_shoup.outer_iter_mut(), - self.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v_shoup, v, qi)| { - v_shoup - .as_slice_mut() - .unwrap() - .copy_from_slice(&qi.shoup_vec(v.as_slice().unwrap())) - }); - self.coefficients_shoup = Some(coefficients_shoup) - } + /// Compute the Shoup representation of the coefficients. + fn compute_coefficients_shoup(&mut self) { + let mut coefficients_shoup = Array2::zeros((self.ctx.q.len(), self.ctx.degree)); + izip!( + coefficients_shoup.outer_iter_mut(), + self.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v_shoup, v, qi)| { + v_shoup + .as_slice_mut() + .unwrap() + .copy_from_slice(&qi.shoup_vec(v.as_slice().unwrap())) + }); + self.coefficients_shoup = Some(coefficients_shoup) + } - /// Override the internal representation to a given representation. - /// - /// # Safety - /// - /// Prefer the `change_representation` function to safely modify the - /// polynomial representation. If the `to` representation is NttShoup, the - /// coefficients are still computed correctly to avoid being in an unstable - /// state. Similarly, if we override a representation which was NttShoup, we - /// zeroize the existing Shoup coefficients. - pub unsafe fn override_representation(&mut self, to: Representation) { - if to == Representation::NttShoup { - self.compute_coefficients_shoup() - } else if self.coefficients_shoup.is_some() { - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - self.coefficients_shoup = None - } - self.representation = to; - } + /// Override the internal representation to a given representation. + /// + /// # Safety + /// + /// Prefer the `change_representation` function to safely modify the + /// polynomial representation. If the `to` representation is NttShoup, the + /// coefficients are still computed correctly to avoid being in an unstable + /// state. Similarly, if we override a representation which was NttShoup, we + /// zeroize the existing Shoup coefficients. + pub unsafe fn override_representation(&mut self, to: Representation) { + if to == Representation::NttShoup { + self.compute_coefficients_shoup() + } else if self.coefficients_shoup.is_some() { + self.coefficients_shoup + .as_mut() + .unwrap() + .as_slice_mut() + .unwrap() + .zeroize(); + self.coefficients_shoup = None + } + self.representation = to; + } - /// Generate a random polynomial. - pub fn random( - ctx: &Arc, - representation: Representation, - rng: &mut R, - ) -> Self { - let mut p = Poly::zero(ctx, representation); - izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| { - v.as_slice_mut() - .unwrap() - .copy_from_slice(&qi.random_vec(ctx.degree, rng)) - }); - if p.representation == Representation::NttShoup { - p.compute_coefficients_shoup() - } - p - } + /// Generate a random polynomial. + pub fn random( + ctx: &Arc, + representation: Representation, + rng: &mut R, + ) -> Self { + let mut p = Poly::zero(ctx, representation); + izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| { + v.as_slice_mut() + .unwrap() + .copy_from_slice(&qi.random_vec(ctx.degree, rng)) + }); + if p.representation == Representation::NttShoup { + p.compute_coefficients_shoup() + } + p + } - /// Generate a random polynomial deterministically from a seed. - pub fn random_from_seed( - ctx: &Arc, - representation: Representation, - seed: ::Seed, - ) -> Self { - // Let's hash the seed into a ChaCha8Rng seed. - let mut hasher = Sha256::new(); - hasher.update(seed); - let mut prng = - ChaCha8Rng::from_seed(::Seed::from(hasher.finalize())); - let mut p = Poly::zero(ctx, representation); - izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| { - v.as_slice_mut() - .unwrap() - .copy_from_slice(&qi.random_vec(ctx.degree, &mut prng)) - }); - if p.representation == Representation::NttShoup { - p.compute_coefficients_shoup() - } - p - } + /// Generate a random polynomial deterministically from a seed. + pub fn random_from_seed( + ctx: &Arc, + representation: Representation, + seed: ::Seed, + ) -> Self { + // Let's hash the seed into a ChaCha8Rng seed. + let mut hasher = Sha256::new(); + hasher.update(seed); + let mut prng = + ChaCha8Rng::from_seed(::Seed::from(hasher.finalize())); + let mut p = Poly::zero(ctx, representation); + izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| { + v.as_slice_mut() + .unwrap() + .copy_from_slice(&qi.random_vec(ctx.degree, &mut prng)) + }); + if p.representation == Representation::NttShoup { + p.compute_coefficients_shoup() + } + p + } - /// Generate a small polynomial and convert into the specified - /// representation. - /// - /// Returns an error if the variance does not belong to [1, ..., 16]. - pub fn small( - ctx: &Arc, - representation: Representation, - variance: usize, - rng: &mut T, - ) -> Result { - if !(1..=16).contains(&variance) { - Err(Error::Default( - "The variance should be an integer between 1 and 16".to_string(), - )) - } else { - let coeffs = Zeroizing::new( - sample_vec_cbd(ctx.degree, variance, rng) - .map_err(|e| Error::Default(e.to_string()))?, - ); - let mut p = Poly::try_convert_from( - coeffs.as_ref() as &[i64], - ctx, - false, - Representation::PowerBasis, - )?; - if representation != Representation::PowerBasis { - p.change_representation(representation); - } - Ok(p) - } - } + /// Generate a small polynomial and convert into the specified + /// representation. + /// + /// Returns an error if the variance does not belong to [1, ..., 16]. + pub fn small( + ctx: &Arc, + representation: Representation, + variance: usize, + rng: &mut T, + ) -> Result { + if !(1..=16).contains(&variance) { + Err(Error::Default( + "The variance should be an integer between 1 and 16".to_string(), + )) + } else { + let coeffs = Zeroizing::new( + sample_vec_cbd(ctx.degree, variance, rng) + .map_err(|e| Error::Default(e.to_string()))?, + ); + let mut p = Poly::try_convert_from( + coeffs.as_ref() as &[i64], + ctx, + false, + Representation::PowerBasis, + )?; + if representation != Representation::PowerBasis { + p.change_representation(representation); + } + Ok(p) + } + } - /// Access the polynomial coefficients in RNS representation. - pub fn coefficients(&self) -> ArrayView2 { - self.coefficients.view() - } + /// Access the polynomial coefficients in RNS representation. + pub fn coefficients(&self) -> ArrayView2 { + self.coefficients.view() + } - /// Computes the forward Ntt on the coefficients - fn ntt_forward(&mut self) { - if self.allow_variable_time_computations { - izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) - .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) }); - } else { - izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) - .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap())); - } - } + /// Computes the forward Ntt on the coefficients + fn ntt_forward(&mut self) { + if self.allow_variable_time_computations { + izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) + .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) }); + } else { + izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) + .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap())); + } + } - /// Computes the backward Ntt on the coefficients - fn ntt_backward(&mut self) { - if self.allow_variable_time_computations { - izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) - .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) }); - } else { - izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) - .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap())); - } - } + /// Computes the backward Ntt on the coefficients + fn ntt_backward(&mut self) { + if self.allow_variable_time_computations { + izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) + .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) }); + } else { + izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter()) + .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap())); + } + } - /// Substitute x by x^i in a polynomial. - /// In PowerBasis representation, i can be any integer that is not a - /// multiple of 2 * degree. In Ntt and NttShoup representation, i can be any - /// odd integer that is not a multiple of 2 * degree. - pub fn substitute(&self, i: &SubstitutionExponent) -> Result { - let mut q = Poly::zero(&self.ctx, self.representation.clone()); - if self.allow_variable_time_computations { - unsafe { q.allow_variable_time_computations() } - } - match self.representation { - Representation::Ntt => { - izip!( - q.coefficients.outer_iter_mut(), - self.coefficients.outer_iter() - ) - .for_each(|(mut q_row, p_row)| { - for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) { - q_row[*j] = p_row[*k] - } - }); - } - Representation::NttShoup => { - izip!( - q.coefficients.outer_iter_mut(), - self.coefficients.outer_iter() - ) - .for_each(|(mut q_row, p_row)| { - for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) { - q_row[*j] = p_row[*k] - } - }); - izip!( - q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(), - self.coefficients_shoup.as_ref().unwrap().outer_iter() - ) - .for_each(|(mut q_row, p_row)| { - for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) { - q_row[*j] = p_row[*k] - } - }); - } - Representation::PowerBasis => { - let mut power = 0usize; - let mask = self.ctx.degree - 1; - for j in 0..self.ctx.degree { - izip!( - self.ctx.q.iter(), - q.coefficients.slice_mut(s![.., power & mask]), - self.coefficients.slice(s![.., j]) - ) - .for_each(|(qi, qij, pij)| { - if power & self.ctx.degree != 0 { - *qij = qi.sub(*qij, *pij) - } else { - *qij = qi.add(*qij, *pij) - } - }); - power += i.exponent - } - } - } + /// Substitute x by x^i in a polynomial. + /// In PowerBasis representation, i can be any integer that is not a + /// multiple of 2 * degree. In Ntt and NttShoup representation, i can be any + /// odd integer that is not a multiple of 2 * degree. + pub fn substitute(&self, i: &SubstitutionExponent) -> Result { + let mut q = Poly::zero(&self.ctx, self.representation.clone()); + if self.allow_variable_time_computations { + unsafe { q.allow_variable_time_computations() } + } + match self.representation { + Representation::Ntt => { + izip!( + q.coefficients.outer_iter_mut(), + self.coefficients.outer_iter() + ) + .for_each(|(mut q_row, p_row)| { + for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) { + q_row[*j] = p_row[*k] + } + }); + } + Representation::NttShoup => { + izip!( + q.coefficients.outer_iter_mut(), + self.coefficients.outer_iter() + ) + .for_each(|(mut q_row, p_row)| { + for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) { + q_row[*j] = p_row[*k] + } + }); + izip!( + q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(), + self.coefficients_shoup.as_ref().unwrap().outer_iter() + ) + .for_each(|(mut q_row, p_row)| { + for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) { + q_row[*j] = p_row[*k] + } + }); + } + Representation::PowerBasis => { + let mut power = 0usize; + let mask = self.ctx.degree - 1; + for j in 0..self.ctx.degree { + izip!( + self.ctx.q.iter(), + q.coefficients.slice_mut(s![.., power & mask]), + self.coefficients.slice(s![.., j]) + ) + .for_each(|(qi, qij, pij)| { + if power & self.ctx.degree != 0 { + *qij = qi.sub(*qij, *pij) + } else { + *qij = qi.add(*qij, *pij) + } + }); + power += i.exponent + } + } + } - Ok(q) - } + Ok(q) + } - /// Create a polynomial which can only be multiplied by a polynomial in - /// NttShoup representation. All other operations may panic. - /// - /// # Safety - /// This operation also creates a polynomial that allows variable time - /// operations. - pub unsafe fn create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( - power_basis_coefficients: &[u64], - ctx: &Arc, - ) -> Self { - let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); - izip!(coefficients.outer_iter_mut(), ctx.q.iter(), ctx.ops.iter()).for_each( - |(mut p, qi, op)| { - p.as_slice_mut() - .unwrap() - .clone_from_slice(power_basis_coefficients); - qi.lazy_reduce_vec(p.as_slice_mut().unwrap()); - op.forward_vt_lazy(p.as_mut_ptr()); - }, - ); - Self { - ctx: ctx.clone(), - representation: Representation::Ntt, - allow_variable_time_computations: true, - coefficients, - coefficients_shoup: None, - has_lazy_coefficients: true, - } - } + /// Create a polynomial which can only be multiplied by a polynomial in + /// NttShoup representation. All other operations may panic. + /// + /// # Safety + /// This operation also creates a polynomial that allows variable time + /// operations. + pub unsafe fn create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + power_basis_coefficients: &[u64], + ctx: &Arc, + ) -> Self { + let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree)); + izip!(coefficients.outer_iter_mut(), ctx.q.iter(), ctx.ops.iter()).for_each( + |(mut p, qi, op)| { + p.as_slice_mut() + .unwrap() + .clone_from_slice(power_basis_coefficients); + qi.lazy_reduce_vec(p.as_slice_mut().unwrap()); + op.forward_vt_lazy(p.as_mut_ptr()); + }, + ); + Self { + ctx: ctx.clone(), + representation: Representation::Ntt, + allow_variable_time_computations: true, + coefficients, + coefficients_shoup: None, + has_lazy_coefficients: true, + } + } - /// Modulus switch down the polynomial by dividing and rounding each - /// coefficient by the last modulus in the chain, then drops the last - /// modulus, as described in Algorithm 2 of . - /// - /// Returns an error if there is no next context or if the representation - /// is not PowerBasis. - pub fn mod_switch_down_next(&mut self) -> Result<()> { - if self.ctx.next_context.is_none() { - return Err(Error::NoMoreContext); - } + /// Modulus switch down the polynomial by dividing and rounding each + /// coefficient by the last modulus in the chain, then drops the last + /// modulus, as described in Algorithm 2 of . + /// + /// Returns an error if there is no next context or if the representation + /// is not PowerBasis. + pub fn mod_switch_down_next(&mut self) -> Result<()> { + if self.ctx.next_context.is_none() { + return Err(Error::NoMoreContext); + } - if self.representation != Representation::PowerBasis { - return Err(Error::IncorrectRepresentation( - self.representation.clone(), - Representation::PowerBasis, - )); - } + if self.representation != Representation::PowerBasis { + return Err(Error::IncorrectRepresentation( + self.representation.clone(), + Representation::PowerBasis, + )); + } - // Unwrap the next_context. - let next_context = self.ctx.next_context.as_ref().unwrap(); + // Unwrap the next_context. + let next_context = self.ctx.next_context.as_ref().unwrap(); - let q_len = self.ctx.q.len(); - let q_last = self.ctx.q.last().unwrap(); - let q_last_div_2 = q_last.modulus() / 2; + let q_len = self.ctx.q.len(); + let q_last = self.ctx.q.last().unwrap(); + let q_last_div_2 = q_last.modulus() / 2; - // Add (q_last - 1) / 2 to change from flooring to rounding - let (mut q_new_polys, mut q_last_poly) = - self.coefficients.view_mut().split_at(Axis(0), q_len - 1); + // Add (q_last - 1) / 2 to change from flooring to rounding + let (mut q_new_polys, mut q_last_poly) = + self.coefficients.view_mut().split_at(Axis(0), q_len - 1); - if self.allow_variable_time_computations { - unsafe { - q_last_poly - .iter_mut() - .for_each(|coeff| *coeff = q_last.add_vt(*coeff, q_last_div_2)); - izip!( - q_new_polys.outer_iter_mut(), - self.ctx.q.iter(), - self.ctx.inv_last_qi_mod_qj.iter(), - self.ctx.inv_last_qi_mod_qj_shoup.iter(), - ) - .for_each(|(coeffs, qi, inv, inv_shoup)| { - let q_last_div_2_mod_qi = qi.modulus() - qi.reduce_vt(q_last_div_2); // Up to qi.modulus() - for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) { - // (x mod q_last - q_L/2) mod q_i - let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus() + if self.allow_variable_time_computations { + unsafe { + q_last_poly + .iter_mut() + .for_each(|coeff| *coeff = q_last.add_vt(*coeff, q_last_div_2)); + izip!( + q_new_polys.outer_iter_mut(), + self.ctx.q.iter(), + self.ctx.inv_last_qi_mod_qj.iter(), + self.ctx.inv_last_qi_mod_qj_shoup.iter(), + ) + .for_each(|(coeffs, qi, inv, inv_shoup)| { + let q_last_div_2_mod_qi = qi.modulus() - qi.reduce_vt(q_last_div_2); // Up to qi.modulus() + for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) { + // (x mod q_last - q_L/2) mod q_i + let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus() - // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i - // = (x - x mod q_last + q_L/2) mod q_i - *coeff += 3 * qi.modulus() - tmp; // Up to 4 * qi.modulus() + // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i + // = (x - x mod q_last + q_L/2) mod q_i + *coeff += 3 * qi.modulus() - tmp; // Up to 4 * qi.modulus() - // q_last^{-1} * (x - x mod q_last) mod q_i - *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup); - } - }); - } - } else { - q_last_poly - .iter_mut() - .for_each(|coeff| *coeff = q_last.add(*coeff, q_last_div_2)); - izip!( - q_new_polys.outer_iter_mut(), - self.ctx.q.iter(), - self.ctx.inv_last_qi_mod_qj.iter(), - self.ctx.inv_last_qi_mod_qj_shoup.iter(), - ) - .for_each(|(coeffs, qi, inv, inv_shoup)| { - let q_last_div_2_mod_qi = qi.modulus() - qi.reduce(q_last_div_2); // Up to qi.modulus() - for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) { - // (x mod q_last - q_L/2) mod q_i - let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus() + // q_last^{-1} * (x - x mod q_last) mod q_i + *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup); + } + }); + } + } else { + q_last_poly + .iter_mut() + .for_each(|coeff| *coeff = q_last.add(*coeff, q_last_div_2)); + izip!( + q_new_polys.outer_iter_mut(), + self.ctx.q.iter(), + self.ctx.inv_last_qi_mod_qj.iter(), + self.ctx.inv_last_qi_mod_qj_shoup.iter(), + ) + .for_each(|(coeffs, qi, inv, inv_shoup)| { + let q_last_div_2_mod_qi = qi.modulus() - qi.reduce(q_last_div_2); // Up to qi.modulus() + for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) { + // (x mod q_last - q_L/2) mod q_i + let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus() - // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i - // = (x - x mod q_last + q_L/2) mod q_i - *coeff += 3 * qi.modulus() - tmp; // Up to 4 * qi.modulus() + // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i + // = (x - x mod q_last + q_L/2) mod q_i + *coeff += 3 * qi.modulus() - tmp; // Up to 4 * qi.modulus() - // q_last^{-1} * (x - x mod q_last) mod q_i - *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup); - } - }); - } + // q_last^{-1} * (x - x mod q_last) mod q_i + *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup); + } + }); + } - // Remove the last row, and update the context. - if !self.allow_variable_time_computations { - q_last_poly.as_slice_mut().unwrap().zeroize(); - } - self.coefficients.remove_index(Axis(0), q_len - 1); - self.ctx = next_context.clone(); + // Remove the last row, and update the context. + if !self.allow_variable_time_computations { + q_last_poly.as_slice_mut().unwrap().zeroize(); + } + self.coefficients.remove_index(Axis(0), q_len - 1); + self.ctx = next_context.clone(); - Ok(()) - } + Ok(()) + } - /// Modulo switch down to a smaller context. - /// - /// Returns an error if there is the provided context is not a child of the - /// current context, or if the polynomial is not in PowerBasis - /// representation. - pub fn mod_switch_down_to(&mut self, context: &Arc) -> Result<()> { - let niterations = self.ctx.niterations_to(context)?; - for _ in 0..niterations { - self.mod_switch_down_next()?; - } - assert_eq!(&self.ctx, context); - Ok(()) - } + /// Modulo switch down to a smaller context. + /// + /// Returns an error if there is the provided context is not a child of the + /// current context, or if the polynomial is not in PowerBasis + /// representation. + pub fn mod_switch_down_to(&mut self, context: &Arc) -> Result<()> { + let niterations = self.ctx.niterations_to(context)?; + for _ in 0..niterations { + self.mod_switch_down_next()?; + } + assert_eq!(&self.ctx, context); + Ok(()) + } - /// Modulo switch to another context. The target context needs not to be - /// related to the current context. - pub fn mod_switch_to(&self, switcher: &Switcher) -> Result { - switcher.switch(self) - } + /// Modulo switch to another context. The target context needs not to be + /// related to the current context. + pub fn mod_switch_to(&self, switcher: &Switcher) -> Result { + switcher.switch(self) + } - /// Scale a polynomial using a scaler. - pub fn scale(&self, scaler: &Scaler) -> Result { - scaler.scale(self) - } + /// Scale a polynomial using a scaler. + pub fn scale(&self, scaler: &Scaler) -> Result { + scaler.scale(self) + } - /// Returns the context of the underlying polynomial - pub fn ctx(&self) -> &Arc { - &self.ctx - } + /// Returns the context of the underlying polynomial + pub fn ctx(&self) -> &Arc { + &self.ctx + } - /// Multiplies a polynomial in PowerBasis representation by x^(-power). - pub fn multiply_inverse_power_of_x(&mut self, power: usize) -> Result<()> { - if self.representation != Representation::PowerBasis { - return Err(Error::IncorrectRepresentation( - self.representation.clone(), - Representation::PowerBasis, - )); - } + /// Multiplies a polynomial in PowerBasis representation by x^(-power). + pub fn multiply_inverse_power_of_x(&mut self, power: usize) -> Result<()> { + if self.representation != Representation::PowerBasis { + return Err(Error::IncorrectRepresentation( + self.representation.clone(), + Representation::PowerBasis, + )); + } - let shift = ((self.ctx.degree << 1) - power) % (self.ctx.degree << 1); - let mask = self.ctx.degree - 1; - let original_coefficients = self.coefficients.clone(); - izip!( - self.coefficients.outer_iter_mut(), - original_coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut coeffs, orig_coeffs, qi)| { - for k in 0..self.ctx.degree { - let index = shift + k; - if index & self.ctx.degree == 0 { - coeffs[index & mask] = orig_coeffs[k]; - } else { - coeffs[index & mask] = qi.neg(orig_coeffs[k]); - } - } - }); - Ok(()) - } + let shift = ((self.ctx.degree << 1) - power) % (self.ctx.degree << 1); + let mask = self.ctx.degree - 1; + let original_coefficients = self.coefficients.clone(); + izip!( + self.coefficients.outer_iter_mut(), + original_coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut coeffs, orig_coeffs, qi)| { + for k in 0..self.ctx.degree { + let index = shift + k; + if index & self.ctx.degree == 0 { + coeffs[index & mask] = orig_coeffs[k]; + } else { + coeffs[index & mask] = qi.neg(orig_coeffs[k]); + } + } + }); + Ok(()) + } } impl Zeroize for Poly { - fn zeroize(&mut self) { - self.coefficients.as_slice_mut().unwrap().zeroize(); - if let Some(s) = self.coefficients_shoup.as_mut() { - s.as_slice_mut().unwrap().zeroize(); - } - } + fn zeroize(&mut self) { + self.coefficients.as_slice_mut().unwrap().zeroize(); + if let Some(s) = self.coefficients_shoup.as_mut() { + s.as_slice_mut().unwrap().zeroize(); + } + } } #[cfg(test)] mod tests { - use super::{switcher::Switcher, Context, Poly, Representation}; - use crate::{rq::SubstitutionExponent, zq::Modulus}; - use fhe_util::variance; - use itertools::Itertools; - use num_bigint::BigUint; - use num_traits::{One, Zero}; - use rand::{thread_rng, Rng, SeedableRng}; - use rand_chacha::ChaCha8Rng; - use std::{error::Error, sync::Arc}; - - // Moduli to be used in tests. - const MODULI: &[u64; 5] = &[ - 1153, - 4611686018326724609, - 4611686018309947393, - 4611686018232352769, - 4611686018171535361, - ]; - - #[test] - fn poly_zero() -> Result<(), Box> { - let reference = &[ - BigUint::zero(), - BigUint::zero(), - BigUint::zero(), - BigUint::zero(), - BigUint::zero(), - BigUint::zero(), - BigUint::zero(), - BigUint::zero(), - ]; - - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); - let q = Poly::zero(&ctx, Representation::Ntt); - assert_ne!(p, q); - assert_eq!(Vec::::from(&p), &[0; 8]); - assert_eq!(Vec::::from(&q), &[0; 8]); - } - - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); - let q = Poly::zero(&ctx, Representation::Ntt); - assert_ne!(p, q); - assert_eq!(Vec::::from(&p), [0; 8 * MODULI.len()]); - assert_eq!(Vec::::from(&q), [0; 8 * MODULI.len()]); - assert_eq!(Vec::::from(&p), reference); - assert_eq!(Vec::::from(&q), reference); - - Ok(()) - } - - #[test] - fn ctx() -> Result<(), Box> { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); - assert_eq!(p.ctx(), &ctx); - } - - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::zero(&ctx, Representation::PowerBasis); - assert_eq!(p.ctx(), &ctx); - - Ok(()) - } - - #[test] - fn random() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - let mut seed = ::Seed::default(); - thread_rng().fill(&mut seed); - - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - assert_eq!(p, q); - } - - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - assert_eq!(p, q); - - thread_rng().fill(&mut seed); - let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); - assert_ne!(p, q); - - let r = Poly::random(&ctx, Representation::Ntt, &mut rng); - assert_ne!(p, r); - assert_ne!(q, r); - } - Ok(()) - } - - #[test] - fn coefficients() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..50 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let p_coefficients = Vec::::from(&p); - assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) - } - - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let p_coefficients = Vec::::from(&p); - assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) - } - Ok(()) - } - - #[test] - fn modulus() -> Result<(), Box> { - for modulus in MODULI { - let modulus_biguint = BigUint::from(*modulus); - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - assert_eq!(ctx.modulus(), &modulus_biguint) - } - - let mut modulus_biguint = BigUint::one(); - MODULI.iter().for_each(|m| modulus_biguint *= *m); - let ctx = Arc::new(Context::new(MODULI, 8)?); - assert_eq!(ctx.modulus(), &modulus_biguint); - - Ok(()) - } - - #[test] - fn allow_variable_time_computations() -> Result<(), Box> { - let mut rng = thread_rng(); - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let mut p = Poly::random(&ctx, Representation::default(), &mut rng); - assert!(!p.allow_variable_time_computations); - - unsafe { p.allow_variable_time_computations() } - assert!(p.allow_variable_time_computations); - - let q = p.clone(); - assert!(q.allow_variable_time_computations); - - p.disallow_variable_time_computations(); - assert!(!p.allow_variable_time_computations); - } - - let ctx = Arc::new(Context::new(MODULI, 8)?); - let mut p = Poly::random(&ctx, Representation::default(), &mut rng); - assert!(!p.allow_variable_time_computations); - - unsafe { p.allow_variable_time_computations() } - assert!(p.allow_variable_time_computations); - - let q = p.clone(); - assert!(q.allow_variable_time_computations); - - // Allowing variable time propagates. - let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); - unsafe { p.allow_variable_time_computations() } - let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); - - assert!(!q.allow_variable_time_computations); - q *= &p; - assert!(q.allow_variable_time_computations); - - q.disallow_variable_time_computations(); - q += &p; - assert!(q.allow_variable_time_computations); - - q.disallow_variable_time_computations(); - q -= &p; - assert!(q.allow_variable_time_computations); - - q = -&p; - assert!(q.allow_variable_time_computations); - - Ok(()) - } - - #[test] - fn change_representation() -> Result<(), Box> { - let mut rng = thread_rng(); - let ctx = Arc::new(Context::new(MODULI, 8)?); - - let mut p = Poly::random(&ctx, Representation::default(), &mut rng); - assert_eq!(p.representation, Representation::default()); - assert_eq!(p.representation(), &Representation::default()); - - p.change_representation(Representation::PowerBasis); - assert_eq!(p.representation, Representation::PowerBasis); - assert_eq!(p.representation(), &Representation::PowerBasis); - assert!(p.coefficients_shoup.is_none()); - let q = p.clone(); - - p.change_representation(Representation::Ntt); - assert_eq!(p.representation, Representation::Ntt); - assert_eq!(p.representation(), &Representation::Ntt); - assert_ne!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_none()); - let q_ntt = p.clone(); - - p.change_representation(Representation::NttShoup); - assert_eq!(p.representation, Representation::NttShoup); - assert_eq!(p.representation(), &Representation::NttShoup); - assert_ne!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_some()); - let q_ntt_shoup = p.clone(); - - p.change_representation(Representation::PowerBasis); - assert_eq!(p, q); - - p.change_representation(Representation::NttShoup); - assert_eq!(p, q_ntt_shoup); - - p.change_representation(Representation::Ntt); - assert_eq!(p, q_ntt); - - p.change_representation(Representation::PowerBasis); - assert_eq!(p, q); - - Ok(()) - } - - #[test] - fn override_representation() -> Result<(), Box> { - let mut rng = thread_rng(); - let ctx = Arc::new(Context::new(MODULI, 8)?); - - let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - assert_eq!(p.representation(), &p.representation); - let q = p.clone(); - - unsafe { p.override_representation(Representation::Ntt) } - assert_eq!(p.representation, Representation::Ntt); - assert_eq!(p.representation(), &p.representation); - assert_eq!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_none()); - - unsafe { p.override_representation(Representation::NttShoup) } - assert_eq!(p.representation, Representation::NttShoup); - assert_eq!(p.representation(), &p.representation); - assert_eq!(p.coefficients, q.coefficients); - assert!(p.coefficients_shoup.is_some()); - - unsafe { p.override_representation(Representation::PowerBasis) } - assert_eq!(p, q); - - unsafe { p.override_representation(Representation::NttShoup) } - assert!(p.coefficients_shoup.is_some()); - - unsafe { p.override_representation(Representation::Ntt) } - assert!(p.coefficients_shoup.is_none()); - - Ok(()) - } - - #[test] - fn small() -> Result<(), Box> { - let mut rng = thread_rng(); - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let q = Modulus::new(*modulus).unwrap(); - - let e = Poly::small(&ctx, Representation::PowerBasis, 0, &mut rng); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err().to_string(), - "The variance should be an integer between 1 and 16" - ); - let e = Poly::small(&ctx, Representation::PowerBasis, 17, &mut rng); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err().to_string(), - "The variance should be an integer between 1 and 16" - ); - - for i in 1..=16 { - let p = Poly::small(&ctx, Representation::PowerBasis, i, &mut rng)?; - let coefficients = p.coefficients().to_slice().unwrap(); - let v = unsafe { q.center_vec_vt(coefficients) }; - - assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * i as i64); - } - } - - // Generate a very large polynomial to check the variance (here equal to 8). - let ctx = Arc::new(Context::new(&[4611686018326724609], 1 << 18)?); - let q = Modulus::new(4611686018326724609).unwrap(); - let p = Poly::small(&ctx, Representation::PowerBasis, 8, &mut thread_rng())?; - let coefficients = p.coefficients().to_slice().unwrap(); - let v = unsafe { q.center_vec_vt(coefficients) }; - assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 16); - assert_eq!(variance(&v).round(), 8.0); - - Ok(()) - } - - #[test] - fn substitute() -> Result<(), Box> { - let mut rng = thread_rng(); - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut p_ntt = p.clone(); - p_ntt.change_representation(Representation::Ntt); - let mut p_ntt_shoup = p.clone(); - p_ntt_shoup.change_representation(Representation::NttShoup); - let p_coeffs = Vec::::from(&p); - - // Substitution by a multiple of 2 * degree, or even numbers, should fail - assert!(SubstitutionExponent::new(&ctx, 0).is_err()); - assert!(SubstitutionExponent::new(&ctx, 2).is_err()); - assert!(SubstitutionExponent::new(&ctx, 16).is_err()); - - // Substitution by 1 leaves the polynomials unchanged - assert_eq!(p, p.substitute(&SubstitutionExponent::new(&ctx, 1)?)?); - assert_eq!( - p_ntt, - p_ntt.substitute(&SubstitutionExponent::new(&ctx, 1)?)? - ); - assert_eq!( - p_ntt_shoup, - p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 1)?)? - ); - - // Substitution by 3 - let mut q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; - let mut v = vec![0u64; 8]; - for i in 0..8 { - v[(3 * i) % 8] = if ((3 * i) / 8) & 1 == 1 && p_coeffs[i] > 0 { - *modulus - p_coeffs[i] - } else { - p_coeffs[i] - }; - } - assert_eq!(&Vec::::from(&q), &v); - - let q_ntt = p_ntt.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; - q.change_representation(Representation::Ntt); - assert_eq!(q, q_ntt); - - let q_ntt_shoup = p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; - q.change_representation(Representation::NttShoup); - assert_eq!(q, q_ntt_shoup); - - // 11 = 3^(-1) % 16 - assert_eq!( - p, - p.substitute(&SubstitutionExponent::new(&ctx, 3)?)? - .substitute(&SubstitutionExponent::new(&ctx, 11)?)? - ); - assert_eq!( - p_ntt, - p_ntt - .substitute(&SubstitutionExponent::new(&ctx, 3)?)? - .substitute(&SubstitutionExponent::new(&ctx, 11)?)? - ); - assert_eq!( - p_ntt_shoup, - p_ntt_shoup - .substitute(&SubstitutionExponent::new(&ctx, 3)?)? - .substitute(&SubstitutionExponent::new(&ctx, 11)?)? - ); - } - - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut p_ntt = p.clone(); - p_ntt.change_representation(Representation::Ntt); - let mut p_ntt_shoup = p.clone(); - p_ntt_shoup.change_representation(Representation::NttShoup); - - assert_eq!( - p, - p.substitute(&SubstitutionExponent::new(&ctx, 3)?)? - .substitute(&SubstitutionExponent::new(&ctx, 11)?)? - ); - assert_eq!( - p_ntt, - p_ntt - .substitute(&SubstitutionExponent::new(&ctx, 3)?)? - .substitute(&SubstitutionExponent::new(&ctx, 11)?)? - ); - assert_eq!( - p_ntt_shoup, - p_ntt_shoup - .substitute(&SubstitutionExponent::new(&ctx, 3)?)? - .substitute(&SubstitutionExponent::new(&ctx, 11)?)? - ); - - Ok(()) - } - - #[test] - fn mod_switch_down_next() -> Result<(), Box> { - let mut rng = thread_rng(); - let ntests = 100; - let ctx = Arc::new(Context::new(MODULI, 8)?); - - for _ in 0..ntests { - // If the polynomial has incorrect representation, an error is returned - let e = Poly::random(&ctx, Representation::Ntt, &mut rng).mod_switch_down_next(); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::IncorrectRepresentation( - Representation::Ntt, - Representation::PowerBasis - ) - ); - - // Otherwise, no error happens and the coefficients evolve as expected. - let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut reference = Vec::::from(&p); - let mut current_ctx = ctx.clone(); - assert_eq!(p.ctx, current_ctx); - while current_ctx.next_context.is_some() { - let denominator = current_ctx.modulus().clone(); - current_ctx = current_ctx.next_context.as_ref().unwrap().clone(); - let numerator = current_ctx.modulus().clone(); - assert!(p.mod_switch_down_next().is_ok()); - assert_eq!(p.ctx, current_ctx); - let p_biguint = Vec::::from(&p); - assert_eq!( - p_biguint, - reference - .iter() - .map( - |b| (((b * &numerator) + (&denominator >> 1)) / &denominator) - % current_ctx.modulus() - ) - .collect_vec() - ); - reference = p_biguint.clone(); - } - } - Ok(()) - } - - #[test] - fn mod_switch_down_to() -> Result<(), Box> { - let mut rng = thread_rng(); - let ntests = 100; - let ctx1 = Arc::new(Context::new(MODULI, 8)?); - let ctx2 = Arc::new(Context::new(&MODULI[..2], 8)?); - - for _ in 0..ntests { - let mut p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); - let reference = Vec::::from(&p); - - p.mod_switch_down_to(&ctx2)?; - - assert_eq!(p.ctx, ctx2); - assert_eq!( - Vec::::from(&p), - reference - .iter() - .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus()) - .collect_vec() - ); - } - - Ok(()) - } - - #[test] - fn mod_switch_to() -> Result<(), Box> { - let mut rng = thread_rng(); - let ntests = 100; - let ctx1 = Arc::new(Context::new(&MODULI[..2], 8)?); - let ctx2 = Arc::new(Context::new(&MODULI[3..], 8)?); - let switcher = Switcher::new(&ctx1, &ctx2)?; - for _ in 0..ntests { - let p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); - let reference = Vec::::from(&p); - - let q = p.mod_switch_to(&switcher)?; - - assert_eq!(q.ctx, ctx2); - assert_eq!( - Vec::::from(&q), - reference - .iter() - .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus()) - .collect_vec() - ); - } - Ok(()) - } - - #[test] - fn mul_x_power() -> Result<(), Box> { - let mut rng = thread_rng(); - let ctx = Arc::new(Context::new(MODULI, 8)?); - let e = Poly::random(&ctx, Representation::Ntt, &mut rng).multiply_inverse_power_of_x(1); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::IncorrectRepresentation(Representation::Ntt, Representation::PowerBasis) - ); - - let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = p.clone(); - - p.multiply_inverse_power_of_x(0)?; - assert_eq!(p, q); - - p.multiply_inverse_power_of_x(1)?; - assert_ne!(p, q); - - p.multiply_inverse_power_of_x(2 * ctx.degree - 1)?; - assert_eq!(p, q); - - p.multiply_inverse_power_of_x(ctx.degree)?; - assert_eq!( - Vec::::from(&p) - .iter() - .map(|c| ctx.modulus() - c) - .collect_vec(), - Vec::::from(&q) - ); - - Ok(()) - } + use super::{switcher::Switcher, Context, Poly, Representation}; + use crate::{rq::SubstitutionExponent, zq::Modulus}; + use fhe_util::variance; + use itertools::Itertools; + use num_bigint::BigUint; + use num_traits::{One, Zero}; + use rand::{thread_rng, Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + use std::{error::Error, sync::Arc}; + + // Moduli to be used in tests. + const MODULI: &[u64; 5] = &[ + 1153, + 4611686018326724609, + 4611686018309947393, + 4611686018232352769, + 4611686018171535361, + ]; + + #[test] + fn poly_zero() -> Result<(), Box> { + let reference = &[ + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + BigUint::zero(), + ]; + + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::zero(&ctx, Representation::PowerBasis); + let q = Poly::zero(&ctx, Representation::Ntt); + assert_ne!(p, q); + assert_eq!(Vec::::from(&p), &[0; 8]); + assert_eq!(Vec::::from(&q), &[0; 8]); + } + + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::zero(&ctx, Representation::PowerBasis); + let q = Poly::zero(&ctx, Representation::Ntt); + assert_ne!(p, q); + assert_eq!(Vec::::from(&p), [0; 8 * MODULI.len()]); + assert_eq!(Vec::::from(&q), [0; 8 * MODULI.len()]); + assert_eq!(Vec::::from(&p), reference); + assert_eq!(Vec::::from(&q), reference); + + Ok(()) + } + + #[test] + fn ctx() -> Result<(), Box> { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::zero(&ctx, Representation::PowerBasis); + assert_eq!(p.ctx(), &ctx); + } + + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::zero(&ctx, Representation::PowerBasis); + assert_eq!(p.ctx(), &ctx); + + Ok(()) + } + + #[test] + fn random() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + let mut seed = ::Seed::default(); + thread_rng().fill(&mut seed); + + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + assert_eq!(p, q); + } + + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + assert_eq!(p, q); + + thread_rng().fill(&mut seed); + let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed); + assert_ne!(p, q); + + let r = Poly::random(&ctx, Representation::Ntt, &mut rng); + assert_ne!(p, r); + assert_ne!(q, r); + } + Ok(()) + } + + #[test] + fn coefficients() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..50 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p_coefficients = Vec::::from(&p); + assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) + } + + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let p_coefficients = Vec::::from(&p); + assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap()) + } + Ok(()) + } + + #[test] + fn modulus() -> Result<(), Box> { + for modulus in MODULI { + let modulus_biguint = BigUint::from(*modulus); + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + assert_eq!(ctx.modulus(), &modulus_biguint) + } + + let mut modulus_biguint = BigUint::one(); + MODULI.iter().for_each(|m| modulus_biguint *= *m); + let ctx = Arc::new(Context::new(MODULI, 8)?); + assert_eq!(ctx.modulus(), &modulus_biguint); + + Ok(()) + } + + #[test] + fn allow_variable_time_computations() -> Result<(), Box> { + let mut rng = thread_rng(); + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let mut p = Poly::random(&ctx, Representation::default(), &mut rng); + assert!(!p.allow_variable_time_computations); + + unsafe { p.allow_variable_time_computations() } + assert!(p.allow_variable_time_computations); + + let q = p.clone(); + assert!(q.allow_variable_time_computations); + + p.disallow_variable_time_computations(); + assert!(!p.allow_variable_time_computations); + } + + let ctx = Arc::new(Context::new(MODULI, 8)?); + let mut p = Poly::random(&ctx, Representation::default(), &mut rng); + assert!(!p.allow_variable_time_computations); + + unsafe { p.allow_variable_time_computations() } + assert!(p.allow_variable_time_computations); + + let q = p.clone(); + assert!(q.allow_variable_time_computations); + + // Allowing variable time propagates. + let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng); + unsafe { p.allow_variable_time_computations() } + let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng); + + assert!(!q.allow_variable_time_computations); + q *= &p; + assert!(q.allow_variable_time_computations); + + q.disallow_variable_time_computations(); + q += &p; + assert!(q.allow_variable_time_computations); + + q.disallow_variable_time_computations(); + q -= &p; + assert!(q.allow_variable_time_computations); + + q = -&p; + assert!(q.allow_variable_time_computations); + + Ok(()) + } + + #[test] + fn change_representation() -> Result<(), Box> { + let mut rng = thread_rng(); + let ctx = Arc::new(Context::new(MODULI, 8)?); + + let mut p = Poly::random(&ctx, Representation::default(), &mut rng); + assert_eq!(p.representation, Representation::default()); + assert_eq!(p.representation(), &Representation::default()); + + p.change_representation(Representation::PowerBasis); + assert_eq!(p.representation, Representation::PowerBasis); + assert_eq!(p.representation(), &Representation::PowerBasis); + assert!(p.coefficients_shoup.is_none()); + let q = p.clone(); + + p.change_representation(Representation::Ntt); + assert_eq!(p.representation, Representation::Ntt); + assert_eq!(p.representation(), &Representation::Ntt); + assert_ne!(p.coefficients, q.coefficients); + assert!(p.coefficients_shoup.is_none()); + let q_ntt = p.clone(); + + p.change_representation(Representation::NttShoup); + assert_eq!(p.representation, Representation::NttShoup); + assert_eq!(p.representation(), &Representation::NttShoup); + assert_ne!(p.coefficients, q.coefficients); + assert!(p.coefficients_shoup.is_some()); + let q_ntt_shoup = p.clone(); + + p.change_representation(Representation::PowerBasis); + assert_eq!(p, q); + + p.change_representation(Representation::NttShoup); + assert_eq!(p, q_ntt_shoup); + + p.change_representation(Representation::Ntt); + assert_eq!(p, q_ntt); + + p.change_representation(Representation::PowerBasis); + assert_eq!(p, q); + + Ok(()) + } + + #[test] + fn override_representation() -> Result<(), Box> { + let mut rng = thread_rng(); + let ctx = Arc::new(Context::new(MODULI, 8)?); + + let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + assert_eq!(p.representation(), &p.representation); + let q = p.clone(); + + unsafe { p.override_representation(Representation::Ntt) } + assert_eq!(p.representation, Representation::Ntt); + assert_eq!(p.representation(), &p.representation); + assert_eq!(p.coefficients, q.coefficients); + assert!(p.coefficients_shoup.is_none()); + + unsafe { p.override_representation(Representation::NttShoup) } + assert_eq!(p.representation, Representation::NttShoup); + assert_eq!(p.representation(), &p.representation); + assert_eq!(p.coefficients, q.coefficients); + assert!(p.coefficients_shoup.is_some()); + + unsafe { p.override_representation(Representation::PowerBasis) } + assert_eq!(p, q); + + unsafe { p.override_representation(Representation::NttShoup) } + assert!(p.coefficients_shoup.is_some()); + + unsafe { p.override_representation(Representation::Ntt) } + assert!(p.coefficients_shoup.is_none()); + + Ok(()) + } + + #[test] + fn small() -> Result<(), Box> { + let mut rng = thread_rng(); + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let q = Modulus::new(*modulus).unwrap(); + + let e = Poly::small(&ctx, Representation::PowerBasis, 0, &mut rng); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err().to_string(), + "The variance should be an integer between 1 and 16" + ); + let e = Poly::small(&ctx, Representation::PowerBasis, 17, &mut rng); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err().to_string(), + "The variance should be an integer between 1 and 16" + ); + + for i in 1..=16 { + let p = Poly::small(&ctx, Representation::PowerBasis, i, &mut rng)?; + let coefficients = p.coefficients().to_slice().unwrap(); + let v = unsafe { q.center_vec_vt(coefficients) }; + + assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * i as i64); + } + } + + // Generate a very large polynomial to check the variance (here equal to 8). + let ctx = Arc::new(Context::new(&[4611686018326724609], 1 << 18)?); + let q = Modulus::new(4611686018326724609).unwrap(); + let p = Poly::small(&ctx, Representation::PowerBasis, 8, &mut thread_rng())?; + let coefficients = p.coefficients().to_slice().unwrap(); + let v = unsafe { q.center_vec_vt(coefficients) }; + assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 16); + assert_eq!(variance(&v).round(), 8.0); + + Ok(()) + } + + #[test] + fn substitute() -> Result<(), Box> { + let mut rng = thread_rng(); + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut p_ntt = p.clone(); + p_ntt.change_representation(Representation::Ntt); + let mut p_ntt_shoup = p.clone(); + p_ntt_shoup.change_representation(Representation::NttShoup); + let p_coeffs = Vec::::from(&p); + + // Substitution by a multiple of 2 * degree, or even numbers, should fail + assert!(SubstitutionExponent::new(&ctx, 0).is_err()); + assert!(SubstitutionExponent::new(&ctx, 2).is_err()); + assert!(SubstitutionExponent::new(&ctx, 16).is_err()); + + // Substitution by 1 leaves the polynomials unchanged + assert_eq!(p, p.substitute(&SubstitutionExponent::new(&ctx, 1)?)?); + assert_eq!( + p_ntt, + p_ntt.substitute(&SubstitutionExponent::new(&ctx, 1)?)? + ); + assert_eq!( + p_ntt_shoup, + p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 1)?)? + ); + + // Substitution by 3 + let mut q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; + let mut v = vec![0u64; 8]; + for i in 0..8 { + v[(3 * i) % 8] = if ((3 * i) / 8) & 1 == 1 && p_coeffs[i] > 0 { + *modulus - p_coeffs[i] + } else { + p_coeffs[i] + }; + } + assert_eq!(&Vec::::from(&q), &v); + + let q_ntt = p_ntt.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; + q.change_representation(Representation::Ntt); + assert_eq!(q, q_ntt); + + let q_ntt_shoup = p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 3)?)?; + q.change_representation(Representation::NttShoup); + assert_eq!(q, q_ntt_shoup); + + // 11 = 3^(-1) % 16 + assert_eq!( + p, + p.substitute(&SubstitutionExponent::new(&ctx, 3)?)? + .substitute(&SubstitutionExponent::new(&ctx, 11)?)? + ); + assert_eq!( + p_ntt, + p_ntt + .substitute(&SubstitutionExponent::new(&ctx, 3)?)? + .substitute(&SubstitutionExponent::new(&ctx, 11)?)? + ); + assert_eq!( + p_ntt_shoup, + p_ntt_shoup + .substitute(&SubstitutionExponent::new(&ctx, 3)?)? + .substitute(&SubstitutionExponent::new(&ctx, 11)?)? + ); + } + + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut p_ntt = p.clone(); + p_ntt.change_representation(Representation::Ntt); + let mut p_ntt_shoup = p.clone(); + p_ntt_shoup.change_representation(Representation::NttShoup); + + assert_eq!( + p, + p.substitute(&SubstitutionExponent::new(&ctx, 3)?)? + .substitute(&SubstitutionExponent::new(&ctx, 11)?)? + ); + assert_eq!( + p_ntt, + p_ntt + .substitute(&SubstitutionExponent::new(&ctx, 3)?)? + .substitute(&SubstitutionExponent::new(&ctx, 11)?)? + ); + assert_eq!( + p_ntt_shoup, + p_ntt_shoup + .substitute(&SubstitutionExponent::new(&ctx, 3)?)? + .substitute(&SubstitutionExponent::new(&ctx, 11)?)? + ); + + Ok(()) + } + + #[test] + fn mod_switch_down_next() -> Result<(), Box> { + let mut rng = thread_rng(); + let ntests = 100; + let ctx = Arc::new(Context::new(MODULI, 8)?); + + for _ in 0..ntests { + // If the polynomial has incorrect representation, an error is returned + let e = Poly::random(&ctx, Representation::Ntt, &mut rng).mod_switch_down_next(); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::IncorrectRepresentation( + Representation::Ntt, + Representation::PowerBasis + ) + ); + + // Otherwise, no error happens and the coefficients evolve as expected. + let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut reference = Vec::::from(&p); + let mut current_ctx = ctx.clone(); + assert_eq!(p.ctx, current_ctx); + while current_ctx.next_context.is_some() { + let denominator = current_ctx.modulus().clone(); + current_ctx = current_ctx.next_context.as_ref().unwrap().clone(); + let numerator = current_ctx.modulus().clone(); + assert!(p.mod_switch_down_next().is_ok()); + assert_eq!(p.ctx, current_ctx); + let p_biguint = Vec::::from(&p); + assert_eq!( + p_biguint, + reference + .iter() + .map( + |b| (((b * &numerator) + (&denominator >> 1)) / &denominator) + % current_ctx.modulus() + ) + .collect_vec() + ); + reference = p_biguint.clone(); + } + } + Ok(()) + } + + #[test] + fn mod_switch_down_to() -> Result<(), Box> { + let mut rng = thread_rng(); + let ntests = 100; + let ctx1 = Arc::new(Context::new(MODULI, 8)?); + let ctx2 = Arc::new(Context::new(&MODULI[..2], 8)?); + + for _ in 0..ntests { + let mut p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); + let reference = Vec::::from(&p); + + p.mod_switch_down_to(&ctx2)?; + + assert_eq!(p.ctx, ctx2); + assert_eq!( + Vec::::from(&p), + reference + .iter() + .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus()) + .collect_vec() + ); + } + + Ok(()) + } + + #[test] + fn mod_switch_to() -> Result<(), Box> { + let mut rng = thread_rng(); + let ntests = 100; + let ctx1 = Arc::new(Context::new(&MODULI[..2], 8)?); + let ctx2 = Arc::new(Context::new(&MODULI[3..], 8)?); + let switcher = Switcher::new(&ctx1, &ctx2)?; + for _ in 0..ntests { + let p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng); + let reference = Vec::::from(&p); + + let q = p.mod_switch_to(&switcher)?; + + assert_eq!(q.ctx, ctx2); + assert_eq!( + Vec::::from(&q), + reference + .iter() + .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus()) + .collect_vec() + ); + } + Ok(()) + } + + #[test] + fn mul_x_power() -> Result<(), Box> { + let mut rng = thread_rng(); + let ctx = Arc::new(Context::new(MODULI, 8)?); + let e = Poly::random(&ctx, Representation::Ntt, &mut rng).multiply_inverse_power_of_x(1); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::IncorrectRepresentation(Representation::Ntt, Representation::PowerBasis) + ); + + let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let q = p.clone(); + + p.multiply_inverse_power_of_x(0)?; + assert_eq!(p, q); + + p.multiply_inverse_power_of_x(1)?; + assert_ne!(p, q); + + p.multiply_inverse_power_of_x(2 * ctx.degree - 1)?; + assert_eq!(p, q); + + p.multiply_inverse_power_of_x(ctx.degree)?; + assert_eq!( + Vec::::from(&p) + .iter() + .map(|c| ctx.modulus() - c) + .collect_vec(), + Vec::::from(&q) + ); + + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/ops.rs b/crates/fhe-math/src/rq/ops.rs index a58758a..71994e1 100644 --- a/crates/fhe-math/src/rq/ops.rs +++ b/crates/fhe-math/src/rq/ops.rs @@ -6,343 +6,343 @@ use itertools::{izip, Itertools}; use ndarray::Array2; use num_bigint::BigUint; use std::{ - cmp::min, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + cmp::min, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; use zeroize::Zeroize; impl AddAssign<&Poly> for Poly { - fn add_assign(&mut self, p: &Poly) { - assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); - assert_ne!( - self.representation, - Representation::NttShoup, - "Cannot add to a polynomial in NttShoup representation" - ); - assert_eq!( - self.representation, p.representation, - "Incompatible representations" - ); - debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); - self.allow_variable_time_computations |= p.allow_variable_time_computations; - if self.allow_variable_time_computations { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| unsafe { - qi.add_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } else { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.add_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } - } + fn add_assign(&mut self, p: &Poly) { + assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); + assert_ne!( + self.representation, + Representation::NttShoup, + "Cannot add to a polynomial in NttShoup representation" + ); + assert_eq!( + self.representation, p.representation, + "Incompatible representations" + ); + debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); + self.allow_variable_time_computations |= p.allow_variable_time_computations; + if self.allow_variable_time_computations { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| unsafe { + qi.add_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.add_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } + } } impl Add<&Poly> for &Poly { - type Output = Poly; - fn add(self, p: &Poly) -> Poly { - let mut q = self.clone(); - q += p; - q - } + type Output = Poly; + fn add(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q += p; + q + } } impl Add for Poly { - type Output = Poly; - fn add(self, mut p: Poly) -> Poly { - p += &self; - p - } + type Output = Poly; + fn add(self, mut p: Poly) -> Poly { + p += &self; + p + } } impl SubAssign<&Poly> for Poly { - fn sub_assign(&mut self, p: &Poly) { - assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); - assert_ne!( - self.representation, - Representation::NttShoup, - "Cannot subtract from a polynomial in NttShoup representation" - ); - assert_eq!( - self.representation, p.representation, - "Incompatible representations" - ); - debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); - self.allow_variable_time_computations |= p.allow_variable_time_computations; - if self.allow_variable_time_computations { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| unsafe { - qi.sub_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } else { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.sub_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } - } + fn sub_assign(&mut self, p: &Poly) { + assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients); + assert_ne!( + self.representation, + Representation::NttShoup, + "Cannot subtract from a polynomial in NttShoup representation" + ); + assert_eq!( + self.representation, p.representation, + "Incompatible representations" + ); + debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); + self.allow_variable_time_computations |= p.allow_variable_time_computations; + if self.allow_variable_time_computations { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| unsafe { + qi.sub_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.sub_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } + } } impl Sub<&Poly> for &Poly { - type Output = Poly; - fn sub(self, p: &Poly) -> Poly { - let mut q = self.clone(); - q -= p; - q - } + type Output = Poly; + fn sub(self, p: &Poly) -> Poly { + let mut q = self.clone(); + q -= p; + q + } } impl MulAssign<&Poly> for Poly { - fn mul_assign(&mut self, p: &Poly) { - assert!(!p.has_lazy_coefficients); - assert_ne!( - self.representation, - Representation::NttShoup, - "Cannot multiply to a polynomial in NttShoup representation" - ); - if self.has_lazy_coefficients && self.representation == Representation::Ntt { - assert!( + fn mul_assign(&mut self, p: &Poly) { + assert!(!p.has_lazy_coefficients); + assert_ne!( + self.representation, + Representation::NttShoup, + "Cannot multiply to a polynomial in NttShoup representation" + ); + if self.has_lazy_coefficients && self.representation == Representation::Ntt { + assert!( p.representation == Representation::NttShoup, "Can only multiply a polynomial with lazy coefficients by an NttShoup representation." ); - } else { - assert_eq!( - self.representation, - Representation::Ntt, - "Multiplication requires an Ntt representation." - ); - } - debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); - self.allow_variable_time_computations |= p.allow_variable_time_computations; + } else { + assert_eq!( + self.representation, + Representation::Ntt, + "Multiplication requires an Ntt representation." + ); + } + debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts"); + self.allow_variable_time_computations |= p.allow_variable_time_computations; - match p.representation { - Representation::Ntt => { - if self.allow_variable_time_computations { - unsafe { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()); - }); - } - } else { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } - } - Representation::NttShoup => { - if self.allow_variable_time_computations { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - p.coefficients_shoup.as_ref().unwrap().outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, v2_shoup, qi)| unsafe { - qi.mul_shoup_vec_vt( - v1.as_slice_mut().unwrap(), - v2.as_slice().unwrap(), - v2_shoup.as_slice().unwrap(), - ) - }); - } else { - izip!( - self.coefficients.outer_iter_mut(), - p.coefficients.outer_iter(), - p.coefficients_shoup.as_ref().unwrap().outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, v2_shoup, qi)| { - qi.mul_shoup_vec( - v1.as_slice_mut().unwrap(), - v2.as_slice().unwrap(), - v2_shoup.as_slice().unwrap(), - ) - }); - } - self.has_lazy_coefficients = false - } - _ => { - panic!("Multiplication requires a multipliand in Ntt or NttShoup representation.") - } - } - } + match p.representation { + Representation::Ntt => { + if self.allow_variable_time_computations { + unsafe { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()); + }); + } + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } + } + Representation::NttShoup => { + if self.allow_variable_time_computations { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + p.coefficients_shoup.as_ref().unwrap().outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, v2_shoup, qi)| unsafe { + qi.mul_shoup_vec_vt( + v1.as_slice_mut().unwrap(), + v2.as_slice().unwrap(), + v2_shoup.as_slice().unwrap(), + ) + }); + } else { + izip!( + self.coefficients.outer_iter_mut(), + p.coefficients.outer_iter(), + p.coefficients_shoup.as_ref().unwrap().outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, v2_shoup, qi)| { + qi.mul_shoup_vec( + v1.as_slice_mut().unwrap(), + v2.as_slice().unwrap(), + v2_shoup.as_slice().unwrap(), + ) + }); + } + self.has_lazy_coefficients = false + } + _ => { + panic!("Multiplication requires a multipliand in Ntt or NttShoup representation.") + } + } + } } impl MulAssign<&BigUint> for Poly { - fn mul_assign(&mut self, p: &BigUint) { - let v: Vec = vec![p.clone()]; - let mut q = Poly::try_convert_from( - v.as_ref() as &[BigUint], - &self.ctx, - self.allow_variable_time_computations, - self.representation.clone(), - ) - .unwrap(); - q.change_representation(Representation::Ntt); - if self.allow_variable_time_computations { - unsafe { - izip!( - self.coefficients.outer_iter_mut(), - q.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } - } else { - izip!( - self.coefficients.outer_iter_mut(), - q.coefficients.outer_iter(), - self.ctx.q.iter() - ) - .for_each(|(mut v1, v2, qi)| { - qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) - }); - } - } + fn mul_assign(&mut self, p: &BigUint) { + let v: Vec = vec![p.clone()]; + let mut q = Poly::try_convert_from( + v.as_ref() as &[BigUint], + &self.ctx, + self.allow_variable_time_computations, + self.representation.clone(), + ) + .unwrap(); + q.change_representation(Representation::Ntt); + if self.allow_variable_time_computations { + unsafe { + izip!( + self.coefficients.outer_iter_mut(), + q.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } + } else { + izip!( + self.coefficients.outer_iter_mut(), + q.coefficients.outer_iter(), + self.ctx.q.iter() + ) + .for_each(|(mut v1, v2, qi)| { + qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap()) + }); + } + } } impl Mul<&Poly> for &Poly { - type Output = Poly; - fn mul(self, p: &Poly) -> Poly { - match self.representation { - Representation::NttShoup => { - // TODO: To test, and do the same thing for add, sub, and neg - let mut q = p.clone(); - if q.representation == Representation::NttShoup { - q.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - unsafe { q.override_representation(Representation::Ntt) } - } - q *= self; - q - } - _ => { - let mut q = self.clone(); - q *= p; - q - } - } - } + type Output = Poly; + fn mul(self, p: &Poly) -> Poly { + match self.representation { + Representation::NttShoup => { + // TODO: To test, and do the same thing for add, sub, and neg + let mut q = p.clone(); + if q.representation == Representation::NttShoup { + q.coefficients_shoup + .as_mut() + .unwrap() + .as_slice_mut() + .unwrap() + .zeroize(); + unsafe { q.override_representation(Representation::Ntt) } + } + q *= self; + q + } + _ => { + let mut q = self.clone(); + q *= p; + q + } + } + } } impl Mul<&BigUint> for &Poly { - type Output = Poly; - fn mul(self, p: &BigUint) -> Poly { - let mut q = self.clone(); - q *= p; - q - } + type Output = Poly; + fn mul(self, p: &BigUint) -> Poly { + let mut q = self.clone(); + q *= p; + q + } } impl Mul<&Poly> for &BigUint { - type Output = Poly; - fn mul(self, p: &Poly) -> Poly { - p * self - } + type Output = Poly; + fn mul(self, p: &Poly) -> Poly { + p * self + } } impl Neg for &Poly { - type Output = Poly; + type Output = Poly; - fn neg(self) -> Poly { - assert!(!self.has_lazy_coefficients); - let mut out = self.clone(); - if self.allow_variable_time_computations { - izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) - .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); - } else { - izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) - .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap())); - } - out - } + fn neg(self) -> Poly { + assert!(!self.has_lazy_coefficients); + let mut out = self.clone(); + if self.allow_variable_time_computations { + izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) + .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); + } else { + izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter()) + .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap())); + } + out + } } impl Neg for Poly { - type Output = Poly; + type Output = Poly; - fn neg(mut self) -> Poly { - assert!(!self.has_lazy_coefficients); - if self.allow_variable_time_computations { - izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) - .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); - } else { - izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) - .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap())); - } - self - } + fn neg(mut self) -> Poly { + assert!(!self.has_lazy_coefficients); + if self.allow_variable_time_computations { + izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) + .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) }); + } else { + izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter()) + .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap())); + } + self + } } /// Computes the Fused-Mul-Add operation `out[i] += x[i] * y[i]` unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) { - let n = out.len(); - assert_eq!(x.len(), n); - assert_eq!(y.len(), n); + let n = out.len(); + assert_eq!(x.len(), n); + assert_eq!(y.len(), n); - macro_rules! fma_at { - ($idx:expr) => { - *out.get_unchecked_mut($idx) += - (*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128); - }; - } + macro_rules! fma_at { + ($idx:expr) => { + *out.get_unchecked_mut($idx) += + (*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128); + }; + } - let r = n / 16; - for i in 0..r { - fma_at!(16 * i); - fma_at!(16 * i + 1); - fma_at!(16 * i + 2); - fma_at!(16 * i + 3); - fma_at!(16 * i + 4); - fma_at!(16 * i + 5); - fma_at!(16 * i + 6); - fma_at!(16 * i + 7); - fma_at!(16 * i + 8); - fma_at!(16 * i + 9); - fma_at!(16 * i + 10); - fma_at!(16 * i + 11); - fma_at!(16 * i + 12); - fma_at!(16 * i + 13); - fma_at!(16 * i + 14); - fma_at!(16 * i + 15); - } + let r = n / 16; + for i in 0..r { + fma_at!(16 * i); + fma_at!(16 * i + 1); + fma_at!(16 * i + 2); + fma_at!(16 * i + 3); + fma_at!(16 * i + 4); + fma_at!(16 * i + 5); + fma_at!(16 * i + 6); + fma_at!(16 * i + 7); + fma_at!(16 * i + 8); + fma_at!(16 * i + 9); + fma_at!(16 * i + 10); + fma_at!(16 * i + 11); + fma_at!(16 * i + 12); + fma_at!(16 * i + 13); + fma_at!(16 * i + 14); + fma_at!(16 * i + 15); + } - for i in 0..n % 16 { - fma_at!(16 * r + i); - } + for i in 0..n % 16 { + fma_at!(16 * r + i); + } } /// Compute the dot product between two iterators of polynomials. @@ -350,351 +350,351 @@ unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) { /// is not in Ntt or NttShoup representation. pub fn dot_product<'a, 'b, I, J>(p: I, q: J) -> Result where - I: Iterator + Clone, - J: Iterator + Clone, + I: Iterator + Clone, + J: Iterator + Clone, { - debug_assert!(!p - .clone() - .any(|pi| pi.representation == Representation::PowerBasis)); - debug_assert!(!q - .clone() - .any(|qi| qi.representation == Representation::PowerBasis)); + debug_assert!(!p + .clone() + .any(|pi| pi.representation == Representation::PowerBasis)); + debug_assert!(!q + .clone() + .any(|qi| qi.representation == Representation::PowerBasis)); - let count = min(p.clone().count(), q.clone().count()); - if count == 0 { - return Err(Error::Default("At least one iterator is empty".to_string())); - } + let count = min(p.clone().count(), q.clone().count()); + if count == 0 { + return Err(Error::Default("At least one iterator is empty".to_string())); + } - let p_first = p.clone().next().unwrap(); + let p_first = p.clone().next().unwrap(); - // Initialize the accumulator - let mut acc: Array2 = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree)); - let acc_ptr = acc.as_mut_ptr(); + // Initialize the accumulator + let mut acc: Array2 = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree)); + let acc_ptr = acc.as_mut_ptr(); - // Current number of products accumulated - let mut num_acc = vec![1u128; p_first.ctx.q.len()]; - let num_acc_ptr = num_acc.as_mut_ptr(); + // Current number of products accumulated + let mut num_acc = vec![1u128; p_first.ctx.q.len()]; + let num_acc_ptr = num_acc.as_mut_ptr(); - // Maximum number of products that can be accumulated - let max_acc = p_first - .ctx - .q - .iter() - .map(|qi| 1u128 << (2 * qi.modulus().leading_zeros())) - .collect_vec(); - let max_acc_ptr = max_acc.as_ptr(); + // Maximum number of products that can be accumulated + let max_acc = p_first + .ctx + .q + .iter() + .map(|qi| 1u128 << (2 * qi.modulus().leading_zeros())) + .collect_vec(); + let max_acc_ptr = max_acc.as_ptr(); - let q_ptr = p_first.ctx.q.as_ptr(); - let degree = p_first.ctx.degree as isize; + let q_ptr = p_first.ctx.q.as_ptr(); + let degree = p_first.ctx.degree as isize; - let min_of_max = max_acc.iter().min().unwrap(); + let min_of_max = max_acc.iter().min().unwrap(); - let out_slice = acc.as_slice_mut().unwrap(); - if count as u128 > *min_of_max { - for (pi, qi) in izip!(p, q) { - let pij = pi.coefficients(); - let qij = qi.coefficients(); - let pi_slice = pij.as_slice().unwrap(); - let qi_slice = qij.as_slice().unwrap(); - unsafe { - fma(out_slice, pi_slice, qi_slice); + let out_slice = acc.as_slice_mut().unwrap(); + if count as u128 > *min_of_max { + for (pi, qi) in izip!(p, q) { + let pij = pi.coefficients(); + let qij = qi.coefficients(); + let pi_slice = pij.as_slice().unwrap(); + let qi_slice = qij.as_slice().unwrap(); + unsafe { + fma(out_slice, pi_slice, qi_slice); - for j in 0..p_first.ctx.q.len() as isize { - let qj = &*q_ptr.offset(j); - *num_acc_ptr.offset(j) += 1; - if *num_acc_ptr.offset(j) == *max_acc_ptr.offset(j) { - if p_first.allow_variable_time_computations { - for i in j * degree..(j + 1) * degree { - *acc_ptr.offset(i) = qj.reduce_u128_vt(*acc_ptr.offset(i)) as u128; - } - } else { - for i in j * degree..(j + 1) * degree { - *acc_ptr.offset(i) = qj.reduce_u128(*acc_ptr.offset(i)) as u128; - } - } - *num_acc_ptr.offset(j) = 1; - } - } - } - } - } else { - // We don't need to check the condition on the max, it should shave off a few - // cycles. - for (pi, qi) in izip!(p, q) { - let pij = pi.coefficients(); - let qij = qi.coefficients(); - let pi_slice = pij.as_slice().unwrap(); - let qi_slice = qij.as_slice().unwrap(); - unsafe { fma(out_slice, pi_slice, qi_slice) } - } - } - // Last reduction to create the coefficients - let mut coeffs: Array2 = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree)); - izip!( - coeffs.outer_iter_mut(), - acc.outer_iter(), - p_first.ctx.q.iter() - ) - .for_each(|(mut coeffsj, accj, m)| { - if p_first.allow_variable_time_computations { - izip!(coeffsj.iter_mut(), accj.iter()) - .for_each(|(cj, accjk)| *cj = unsafe { m.reduce_u128_vt(*accjk) }); - } else { - izip!(coeffsj.iter_mut(), accj.iter()) - .for_each(|(cj, accjk)| *cj = m.reduce_u128(*accjk)); - } - }); + for j in 0..p_first.ctx.q.len() as isize { + let qj = &*q_ptr.offset(j); + *num_acc_ptr.offset(j) += 1; + if *num_acc_ptr.offset(j) == *max_acc_ptr.offset(j) { + if p_first.allow_variable_time_computations { + for i in j * degree..(j + 1) * degree { + *acc_ptr.offset(i) = qj.reduce_u128_vt(*acc_ptr.offset(i)) as u128; + } + } else { + for i in j * degree..(j + 1) * degree { + *acc_ptr.offset(i) = qj.reduce_u128(*acc_ptr.offset(i)) as u128; + } + } + *num_acc_ptr.offset(j) = 1; + } + } + } + } + } else { + // We don't need to check the condition on the max, it should shave off a few + // cycles. + for (pi, qi) in izip!(p, q) { + let pij = pi.coefficients(); + let qij = qi.coefficients(); + let pi_slice = pij.as_slice().unwrap(); + let qi_slice = qij.as_slice().unwrap(); + unsafe { fma(out_slice, pi_slice, qi_slice) } + } + } + // Last reduction to create the coefficients + let mut coeffs: Array2 = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree)); + izip!( + coeffs.outer_iter_mut(), + acc.outer_iter(), + p_first.ctx.q.iter() + ) + .for_each(|(mut coeffsj, accj, m)| { + if p_first.allow_variable_time_computations { + izip!(coeffsj.iter_mut(), accj.iter()) + .for_each(|(cj, accjk)| *cj = unsafe { m.reduce_u128_vt(*accjk) }); + } else { + izip!(coeffsj.iter_mut(), accj.iter()) + .for_each(|(cj, accjk)| *cj = m.reduce_u128(*accjk)); + } + }); - Ok(Poly { - ctx: p_first.ctx.clone(), - representation: Representation::Ntt, - allow_variable_time_computations: p_first.allow_variable_time_computations, - coefficients: coeffs, - coefficients_shoup: None, - has_lazy_coefficients: false, - }) + Ok(Poly { + ctx: p_first.ctx.clone(), + representation: Representation::Ntt, + allow_variable_time_computations: p_first.allow_variable_time_computations, + coefficients: coeffs, + coefficients_shoup: None, + has_lazy_coefficients: false, + }) } #[cfg(test)] mod tests { - use itertools::{izip, Itertools}; - use rand::thread_rng; + use itertools::{izip, Itertools}; + use rand::thread_rng; - use super::dot_product; - use crate::{ - rq::{Context, Poly, Representation}, - zq::Modulus, - }; - use std::{error::Error, sync::Arc}; + use super::dot_product; + use crate::{ + rq::{Context, Poly, Representation}, + zq::Modulus, + }; + use std::{error::Error, sync::Arc}; - static MODULI: &[u64; 3] = &[1153, 4611686018326724609, 4611686018309947393]; + static MODULI: &[u64; 3] = &[1153, 4611686018326724609, 4611686018309947393]; - #[test] - fn add() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let m = Modulus::new(*modulus).unwrap(); + #[test] + fn add() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let r = &p + &q; - assert_eq!(r.representation, Representation::PowerBasis); - let mut a = Vec::::from(&p); - m.add_vec(&mut a, &Vec::::from(&q)); - assert_eq!(Vec::::from(&r), a); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let r = &p + &q; + assert_eq!(r.representation, Representation::PowerBasis); + let mut a = Vec::::from(&p); + m.add_vec(&mut a, &Vec::::from(&q)); + assert_eq!(Vec::::from(&r), a); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); - let r = &p + &q; - assert_eq!(r.representation, Representation::Ntt); - let mut a = Vec::::from(&p); - m.add_vec(&mut a, &Vec::::from(&q)); - assert_eq!(Vec::::from(&r), a); - } + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let r = &p + &q; + assert_eq!(r.representation, Representation::Ntt); + let mut a = Vec::::from(&p); + m.add_vec(&mut a, &Vec::::from(&q)); + assert_eq!(Vec::::from(&r), a); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut a = Vec::::from(&p); - let b = Vec::::from(&q); - for i in 0..MODULI.len() { - let m = Modulus::new(MODULI[i]).unwrap(); - m.add_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) - } - let r = &p + &q; - assert_eq!(r.representation, Representation::PowerBasis); - assert_eq!(Vec::::from(&r), a); - } - Ok(()) - } + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut a = Vec::::from(&p); + let b = Vec::::from(&q); + for i in 0..MODULI.len() { + let m = Modulus::new(MODULI[i]).unwrap(); + m.add_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + } + let r = &p + &q; + assert_eq!(r.representation, Representation::PowerBasis); + assert_eq!(Vec::::from(&r), a); + } + Ok(()) + } - #[test] - fn sub() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let m = Modulus::new(*modulus).unwrap(); + #[test] + fn sub() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let r = &p - &q; - assert_eq!(r.representation, Representation::PowerBasis); - let mut a = Vec::::from(&p); - m.sub_vec(&mut a, &Vec::::from(&q)); - assert_eq!(Vec::::from(&r), a); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let r = &p - &q; + assert_eq!(r.representation, Representation::PowerBasis); + let mut a = Vec::::from(&p); + m.sub_vec(&mut a, &Vec::::from(&q)); + assert_eq!(Vec::::from(&r), a); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); - let r = &p - &q; - assert_eq!(r.representation, Representation::Ntt); - let mut a = Vec::::from(&p); - m.sub_vec(&mut a, &Vec::::from(&q)); - assert_eq!(Vec::::from(&r), a); - } + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let r = &p - &q; + assert_eq!(r.representation, Representation::Ntt); + let mut a = Vec::::from(&p); + m.sub_vec(&mut a, &Vec::::from(&q)); + assert_eq!(Vec::::from(&r), a); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut a = Vec::::from(&p); - let b = Vec::::from(&q); - for i in 0..MODULI.len() { - let m = Modulus::new(MODULI[i]).unwrap(); - m.sub_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) - } - let r = &p - &q; - assert_eq!(r.representation, Representation::PowerBasis); - assert_eq!(Vec::::from(&r), a); - } - Ok(()) - } + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut a = Vec::::from(&p); + let b = Vec::::from(&q); + for i in 0..MODULI.len() { + let m = Modulus::new(MODULI[i]).unwrap(); + m.sub_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + } + let r = &p - &q; + assert_eq!(r.representation, Representation::PowerBasis); + assert_eq!(Vec::::from(&r), a); + } + Ok(()) + } - #[test] - fn mul() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let m = Modulus::new(*modulus).unwrap(); + #[test] + fn mul() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); - let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); - let mut a = Vec::::from(&p); - m.mul_vec(&mut a, &Vec::::from(&q)); - assert_eq!(Vec::::from(&r), a); - } + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let r = &p * &q; + assert_eq!(r.representation, Representation::Ntt); + let mut a = Vec::::from(&p); + m.mul_vec(&mut a, &Vec::::from(&q)); + assert_eq!(Vec::::from(&r), a); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::Ntt, &mut rng); - let mut a = Vec::::from(&p); - let b = Vec::::from(&q); - for i in 0..MODULI.len() { - let m = Modulus::new(MODULI[i]).unwrap(); - m.mul_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) - } - let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); - assert_eq!(Vec::::from(&r), a); - } - Ok(()) - } + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let q = Poly::random(&ctx, Representation::Ntt, &mut rng); + let mut a = Vec::::from(&p); + let b = Vec::::from(&q); + for i in 0..MODULI.len() { + let m = Modulus::new(MODULI[i]).unwrap(); + m.mul_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + } + let r = &p * &q; + assert_eq!(r.representation, Representation::Ntt); + assert_eq!(Vec::::from(&r), a); + } + Ok(()) + } - #[test] - fn mul_shoup() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let m = Modulus::new(*modulus).unwrap(); + #[test] + fn mul_shoup() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); - let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); - let mut a = Vec::::from(&p); - m.mul_vec(&mut a, &Vec::::from(&q)); - assert_eq!(Vec::::from(&r), a); - } + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); + let r = &p * &q; + assert_eq!(r.representation, Representation::Ntt); + let mut a = Vec::::from(&p); + m.mul_vec(&mut a, &Vec::::from(&q)); + assert_eq!(Vec::::from(&r), a); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); - let mut a = Vec::::from(&p); - let b = Vec::::from(&q); - for i in 0..MODULI.len() { - let m = Modulus::new(MODULI[i]).unwrap(); - m.mul_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) - } - let r = &p * &q; - assert_eq!(r.representation, Representation::Ntt); - assert_eq!(Vec::::from(&r), a); - } - Ok(()) - } + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let q = Poly::random(&ctx, Representation::NttShoup, &mut rng); + let mut a = Vec::::from(&p); + let b = Vec::::from(&q); + for i in 0..MODULI.len() { + let m = Modulus::new(MODULI[i]).unwrap(); + m.mul_vec(&mut a[i * 8..(i + 1) * 8], &b[i * 8..(i + 1) * 8]) + } + let r = &p * &q; + assert_eq!(r.representation, Representation::Ntt); + assert_eq!(Vec::::from(&r), a); + } + Ok(()) + } - #[test] - fn neg() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..100 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); - let m = Modulus::new(*modulus).unwrap(); + #[test] + fn neg() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..100 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); + let m = Modulus::new(*modulus).unwrap(); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let r = -&p; - assert_eq!(r.representation, Representation::PowerBasis); - let mut a = Vec::::from(&p); - m.neg_vec(&mut a); - assert_eq!(Vec::::from(&r), a); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let r = -&p; + assert_eq!(r.representation, Representation::PowerBasis); + let mut a = Vec::::from(&p); + m.neg_vec(&mut a); + assert_eq!(Vec::::from(&r), a); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - let r = -&p; - assert_eq!(r.representation, Representation::Ntt); - let mut a = Vec::::from(&p); - m.neg_vec(&mut a); - assert_eq!(Vec::::from(&r), a); - } + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + let r = -&p; + assert_eq!(r.representation, Representation::Ntt); + let mut a = Vec::::from(&p); + m.neg_vec(&mut a); + assert_eq!(Vec::::from(&r), a); + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - let mut a = Vec::::from(&p); - for i in 0..MODULI.len() { - let m = Modulus::new(MODULI[i]).unwrap(); - m.neg_vec(&mut a[i * 8..(i + 1) * 8]) - } - let r = -&p; - assert_eq!(r.representation, Representation::PowerBasis); - assert_eq!(Vec::::from(&r), a); + let ctx = Arc::new(Context::new(MODULI, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + let mut a = Vec::::from(&p); + for i in 0..MODULI.len() { + let m = Modulus::new(MODULI[i]).unwrap(); + m.neg_vec(&mut a[i * 8..(i + 1) * 8]) + } + let r = -&p; + assert_eq!(r.representation, Representation::PowerBasis); + assert_eq!(Vec::::from(&r), a); - let r = -p; - assert_eq!(r.representation, Representation::PowerBasis); - assert_eq!(Vec::::from(&r), a); - } - Ok(()) - } + let r = -p; + assert_eq!(r.representation, Representation::PowerBasis); + assert_eq!(Vec::::from(&r), a); + } + Ok(()) + } - #[test] - fn test_dot_product() -> Result<(), Box> { - let mut rng = thread_rng(); - for _ in 0..20 { - for modulus in MODULI { - let ctx = Arc::new(Context::new(&[*modulus], 8)?); + #[test] + fn test_dot_product() -> Result<(), Box> { + let mut rng = thread_rng(); + for _ in 0..20 { + for modulus in MODULI { + let ctx = Arc::new(Context::new(&[*modulus], 8)?); - for len in 1..50 { - let p = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) - .collect_vec(); - let q = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) - .collect_vec(); - let r = dot_product(p.iter(), q.iter())?; + for len in 1..50 { + let p = (0..len) + .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .collect_vec(); + let q = (0..len) + .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .collect_vec(); + let r = dot_product(p.iter(), q.iter())?; - let mut expected = Poly::zero(&ctx, Representation::Ntt); - izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi)); - assert_eq!(r, expected); - } - } + let mut expected = Poly::zero(&ctx, Representation::Ntt); + izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi)); + assert_eq!(r, expected); + } + } - let ctx = Arc::new(Context::new(MODULI, 8)?); - for len in 1..50 { - let p = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) - .collect_vec(); - let q = (0..len) - .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) - .collect_vec(); - let r = dot_product(p.iter(), q.iter())?; + let ctx = Arc::new(Context::new(MODULI, 8)?); + for len in 1..50 { + let p = (0..len) + .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .collect_vec(); + let q = (0..len) + .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng)) + .collect_vec(); + let r = dot_product(p.iter(), q.iter())?; - let mut expected = Poly::zero(&ctx, Representation::Ntt); - izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi)); - assert_eq!(r, expected); - } - } - Ok(()) - } + let mut expected = Poly::zero(&ctx, Representation::Ntt); + izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi)); + assert_eq!(r, expected); + } + } + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/scaler.rs b/crates/fhe-math/src/rq/scaler.rs index a3f218e..3a11a93 100644 --- a/crates/fhe-math/src/rq/scaler.rs +++ b/crates/fhe-math/src/rq/scaler.rs @@ -4,8 +4,8 @@ use super::{Context, Poly, Representation}; use crate::{ - rns::{RnsScaler, ScalingFactor}, - Error, Result, + rns::{RnsScaler, ScalingFactor}, + Error, Result, }; use itertools::izip; use ndarray::{s, Array2, Axis}; @@ -14,199 +14,200 @@ use std::sync::Arc; /// Context extender. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct Scaler { - from: Arc, - to: Arc, - number_common_moduli: usize, - scaler: RnsScaler, + from: Arc, + to: Arc, + number_common_moduli: usize, + scaler: RnsScaler, } impl Scaler { - /// Create a scaler from a context `from` to a context `to`. - pub fn new(from: &Arc, to: &Arc, factor: ScalingFactor) -> Result { - if from.degree != to.degree { - return Err(Error::Default("Incompatible degrees".to_string())); - } + /// Create a scaler from a context `from` to a context `to`. + pub fn new(from: &Arc, to: &Arc, factor: ScalingFactor) -> Result { + if from.degree != to.degree { + return Err(Error::Default("Incompatible degrees".to_string())); + } - let mut number_common_moduli = 0; - if factor.is_one { - for (qi, pi) in izip!(from.q.iter(), to.q.iter()) { - if qi == pi { - number_common_moduli += 1 - } else { - break; - } - } - } + let mut number_common_moduli = 0; + if factor.is_one { + for (qi, pi) in izip!(from.q.iter(), to.q.iter()) { + if qi == pi { + number_common_moduli += 1 + } else { + break; + } + } + } - let scaler = RnsScaler::new(&from.rns, &to.rns, factor); + let scaler = RnsScaler::new(&from.rns, &to.rns, factor); - Ok(Self { - from: from.clone(), - to: to.clone(), - number_common_moduli, - scaler, - }) - } + Ok(Self { + from: from.clone(), + to: to.clone(), + number_common_moduli, + scaler, + }) + } - /// Scale a polynomial - pub(crate) fn scale(&self, p: &Poly) -> Result { - if p.ctx.as_ref() != self.from.as_ref() { - Err(Error::Default( - "The input polynomial does not have the correct context".to_string(), - )) - } else { - let mut representation = p.representation.clone(); - if representation == Representation::NttShoup { - representation = Representation::Ntt; - } + /// Scale a polynomial + pub(crate) fn scale(&self, p: &Poly) -> Result { + if p.ctx.as_ref() != self.from.as_ref() { + Err(Error::Default( + "The input polynomial does not have the correct context".to_string(), + )) + } else { + let mut representation = p.representation.clone(); + if representation == Representation::NttShoup { + representation = Representation::Ntt; + } - let mut new_coefficients = Array2::::zeros((self.to.q.len(), self.to.degree)); + let mut new_coefficients = Array2::::zeros((self.to.q.len(), self.to.degree)); - if self.number_common_moduli > 0 { - new_coefficients - .slice_mut(s![..self.number_common_moduli, ..]) - .assign(&p.coefficients.slice(s![..self.number_common_moduli, ..])); - } + if self.number_common_moduli > 0 { + new_coefficients + .slice_mut(s![..self.number_common_moduli, ..]) + .assign(&p.coefficients.slice(s![..self.number_common_moduli, ..])); + } - if self.number_common_moduli < self.to.q.len() { - if p.representation == Representation::PowerBasis { - izip!( - new_coefficients - .slice_mut(s![self.number_common_moduli.., ..]) - .axis_iter_mut(Axis(1)), - p.coefficients.axis_iter(Axis(1)) - ) - .for_each(|(new_column, column)| { - self.scaler - .scale(column, new_column, self.number_common_moduli) - }); - } else if self.number_common_moduli < self.to.q.len() { - let mut p_coefficients_powerbasis = p.coefficients.clone(); - // Backward NTT - if p.allow_variable_time_computations { - izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter()) - .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) }); - } else { - izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter()) - .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap())); - } - // Conversion - izip!( - new_coefficients - .slice_mut(s![self.number_common_moduli.., ..]) - .axis_iter_mut(Axis(1)), - p_coefficients_powerbasis.axis_iter(Axis(1)) - ) - .for_each(|(new_column, column)| { - self.scaler - .scale(column, new_column, self.number_common_moduli) - }); - // Forward NTT on the second half - if p.allow_variable_time_computations { - izip!( - new_coefficients - .slice_mut(s![self.number_common_moduli.., ..]) - .outer_iter_mut(), - &self.to.ops[self.number_common_moduli..] - ) - .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) }); - } else { - izip!( - new_coefficients - .slice_mut(s![self.number_common_moduli.., ..]) - .outer_iter_mut(), - &self.to.ops[self.number_common_moduli..] - ) - .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap())); - } - } - } + if self.number_common_moduli < self.to.q.len() { + if p.representation == Representation::PowerBasis { + izip!( + new_coefficients + .slice_mut(s![self.number_common_moduli.., ..]) + .axis_iter_mut(Axis(1)), + p.coefficients.axis_iter(Axis(1)) + ) + .for_each(|(new_column, column)| { + self.scaler + .scale(column, new_column, self.number_common_moduli) + }); + } else if self.number_common_moduli < self.to.q.len() { + let mut p_coefficients_powerbasis = p.coefficients.clone(); + // Backward NTT + if p.allow_variable_time_computations { + izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter()) + .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) }); + } else { + izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter()) + .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap())); + } + // Conversion + izip!( + new_coefficients + .slice_mut(s![self.number_common_moduli.., ..]) + .axis_iter_mut(Axis(1)), + p_coefficients_powerbasis.axis_iter(Axis(1)) + ) + .for_each(|(new_column, column)| { + self.scaler + .scale(column, new_column, self.number_common_moduli) + }); + // Forward NTT on the second half + if p.allow_variable_time_computations { + izip!( + new_coefficients + .slice_mut(s![self.number_common_moduli.., ..]) + .outer_iter_mut(), + &self.to.ops[self.number_common_moduli..] + ) + .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) }); + } else { + izip!( + new_coefficients + .slice_mut(s![self.number_common_moduli.., ..]) + .outer_iter_mut(), + &self.to.ops[self.number_common_moduli..] + ) + .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap())); + } + } + } - Ok(Poly { - ctx: self.to.clone(), - representation, - allow_variable_time_computations: p.allow_variable_time_computations, - coefficients: new_coefficients, - coefficients_shoup: None, - has_lazy_coefficients: false, - }) - } - } + Ok(Poly { + ctx: self.to.clone(), + representation, + allow_variable_time_computations: p.allow_variable_time_computations, + coefficients: new_coefficients, + coefficients_shoup: None, + has_lazy_coefficients: false, + }) + } + } } #[cfg(test)] mod tests { - use super::{Scaler, ScalingFactor}; - use crate::rq::{Context, Poly, Representation}; - use itertools::Itertools; - use num_bigint::BigUint; - use num_traits::{One, Zero}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use super::{Scaler, ScalingFactor}; + use crate::rq::{Context, Poly, Representation}; + use itertools::Itertools; + use num_bigint::BigUint; + use num_traits::{One, Zero}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - // Moduli to be used in tests. - static Q: &[u64; 3] = &[ - 4611686018282684417, - 4611686018326724609, - 4611686018309947393, - ]; + // Moduli to be used in tests. + static Q: &[u64; 3] = &[ + 4611686018282684417, + 4611686018326724609, + 4611686018309947393, + ]; - static P: &[u64; 3] = &[ - 4611686018282684417, - 4611686018309947393, - 4611686018257518593, - ]; + static P: &[u64; 3] = &[ + 4611686018282684417, + 4611686018309947393, + 4611686018257518593, + ]; - #[test] - fn scaler() -> Result<(), Box> { - let mut rng = thread_rng(); - let ntests = 100; - let from = Arc::new(Context::new(Q, 8)?); - let to = Arc::new(Context::new(P, 8)?); + #[test] + fn scaler() -> Result<(), Box> { + let mut rng = thread_rng(); + let ntests = 100; + let from = Arc::new(Context::new(Q, 8)?); + let to = Arc::new(Context::new(P, 8)?); - for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { - for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { - let n = BigUint::from(*numerator); - let d = BigUint::from(*denominator); + for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { + for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { + let n = BigUint::from(*numerator); + let d = BigUint::from(*denominator); - let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?; + let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?; - for _ in 0..ntests { - let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng); - let poly_biguint = Vec::::from(&poly); + for _ in 0..ntests { + let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng); + let poly_biguint = Vec::::from(&poly); - let scaled_poly = scaler.scale(&poly)?; - let scaled_biguint = Vec::::from(&scaled_poly); + let scaled_poly = scaler.scale(&poly)?; + let scaled_biguint = Vec::::from(&scaled_poly); - let expected = poly_biguint - .iter() - .map(|i| { - if i >= &(from.modulus() >> 1usize) { - if &d & BigUint::one() == BigUint::zero() { - to.modulus() - - (&(&(from.modulus() - i) * &n + ((&d >> 1usize) - 1u64)) - / &d) % to.modulus() - } else { - to.modulus() - - (&(&(from.modulus() - i) * &n + (&d >> 1)) / &d) - % to.modulus() - } - } else { - ((i * &n + (&d >> 1)) / &d) % to.modulus() - } - }) - .collect_vec(); - assert_eq!(expected, scaled_biguint); + let expected = poly_biguint + .iter() + .map(|i| { + if i >= &(from.modulus() >> 1usize) { + if &d & BigUint::one() == BigUint::zero() { + to.modulus() + - (&(&(from.modulus() - i) * &n + ((&d >> 1usize) - 1u64)) + / &d) + % to.modulus() + } else { + to.modulus() + - (&(&(from.modulus() - i) * &n + (&d >> 1)) / &d) + % to.modulus() + } + } else { + ((i * &n + (&d >> 1)) / &d) % to.modulus() + } + }) + .collect_vec(); + assert_eq!(expected, scaled_biguint); - poly.change_representation(Representation::Ntt); - let mut scaled_poly = scaler.scale(&poly)?; - scaled_poly.change_representation(Representation::PowerBasis); - let scaled_biguint = Vec::::from(&scaled_poly); - assert_eq!(expected, scaled_biguint); - } - } - } + poly.change_representation(Representation::Ntt); + let mut scaled_poly = scaler.scale(&poly)?; + scaled_poly.change_representation(Representation::PowerBasis); + let scaled_biguint = Vec::::from(&scaled_poly); + assert_eq!(expected, scaled_biguint); + } + } + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/serialize.rs b/crates/fhe-math/src/rq/serialize.rs index 1ce6ad0..3b51805 100644 --- a/crates/fhe-math/src/rq/serialize.rs +++ b/crates/fhe-math/src/rq/serialize.rs @@ -8,59 +8,59 @@ use fhe_traits::{DeserializeWithContext, Serialize}; use protobuf::Message; impl Serialize for Poly { - fn to_bytes(&self) -> Vec { - let rq = Rq::from(self); - rq.write_to_bytes().unwrap() - } + fn to_bytes(&self) -> Vec { + let rq = Rq::from(self); + rq.write_to_bytes().unwrap() + } } impl DeserializeWithContext for Poly { - type Error = Error; - type Context = Context; + type Error = Error; + type Context = Context; - fn from_bytes(bytes: &[u8], ctx: &Arc) -> Result { - let rq = Rq::parse_from_bytes(bytes).map_err(|e| Error::Serialization(e.to_string()))?; - Poly::try_convert_from(&rq, ctx, false, None) - } + fn from_bytes(bytes: &[u8], ctx: &Arc) -> Result { + let rq = Rq::parse_from_bytes(bytes).map_err(|e| Error::Serialization(e.to_string()))?; + Poly::try_convert_from(&rq, ctx, false, None) + } } #[cfg(test)] mod tests { - use std::{error::Error, sync::Arc}; + use std::{error::Error, sync::Arc}; - use fhe_traits::{DeserializeWithContext, Serialize}; - use rand::thread_rng; + use fhe_traits::{DeserializeWithContext, Serialize}; + use rand::thread_rng; - use crate::rq::{Context, Poly, Representation}; + use crate::rq::{Context, Poly, Representation}; - const Q: &[u64; 3] = &[ - 4611686018282684417, - 4611686018326724609, - 4611686018309947393, - ]; + const Q: &[u64; 3] = &[ + 4611686018282684417, + 4611686018326724609, + 4611686018309947393, + ]; - #[test] - fn serialize() -> Result<(), Box> { - let mut rng = thread_rng(); + #[test] + fn serialize() -> Result<(), Box> { + let mut rng = thread_rng(); - for qi in Q { - let ctx = Arc::new(Context::new(&[*qi], 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::NttShoup, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - } + for qi in Q { + let ctx = Arc::new(Context::new(&[*qi], 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::random(&ctx, Representation::NttShoup, &mut rng); + assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + } - let ctx = Arc::new(Context::new(Q, 8)?); - let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::Ntt, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - let p = Poly::random(&ctx, Representation::NttShoup, &mut rng); - assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let ctx = Arc::new(Context::new(Q, 8)?); + let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng); + assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::random(&ctx, Representation::Ntt, &mut rng); + assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); + let p = Poly::random(&ctx, Representation::NttShoup, &mut rng); + assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?); - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe-math/src/rq/switcher.rs b/crates/fhe-math/src/rq/switcher.rs index b822e01..4f7319d 100644 --- a/crates/fhe-math/src/rq/switcher.rs +++ b/crates/fhe-math/src/rq/switcher.rs @@ -9,19 +9,19 @@ use std::sync::Arc; /// Context switcher. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct Switcher { - pub(crate) scaler: Scaler, + pub(crate) scaler: Scaler, } impl Switcher { - /// Create a switcher from a context `from` to a context `to`. - pub fn new(from: &Arc, to: &Arc) -> Result { - Ok(Self { - scaler: Scaler::new(from, to, ScalingFactor::new(to.modulus(), from.modulus()))?, - }) - } + /// Create a switcher from a context `from` to a context `to`. + pub fn new(from: &Arc, to: &Arc) -> Result { + Ok(Self { + scaler: Scaler::new(from, to, ScalingFactor::new(to.modulus(), from.modulus()))?, + }) + } - /// Switch a polynomial. - pub(crate) fn switch(&self, p: &Poly) -> Result { - self.scaler.scale(p) - } + /// Switch a polynomial. + pub(crate) fn switch(&self, p: &Poly) -> Result { + self.scaler.scale(p) + } } diff --git a/crates/fhe-math/src/rq/traits.rs b/crates/fhe-math/src/rq/traits.rs index a83ef39..2ca35ca 100644 --- a/crates/fhe-math/src/rq/traits.rs +++ b/crates/fhe-math/src/rq/traits.rs @@ -14,19 +14,19 @@ use std::sync::Arc; /// blanket implementation . pub trait TryConvertFrom where - Self: Sized, + Self: Sized, { - /// Attempt to convert the `value` into a polynomial with a specific context - /// and under a specific representation. The representation may optional and - /// be specified as `None`; this is useful for example when converting from - /// a value that encodes the representation (e.g., serialization, protobuf, - /// etc.). - fn try_convert_from( - value: T, - ctx: &Arc, - variable_time: bool, - representation: R, - ) -> Result - where - R: Into>; + /// Attempt to convert the `value` into a polynomial with a specific context + /// and under a specific representation. The representation may optional and + /// be specified as `None`; this is useful for example when converting from + /// a value that encodes the representation (e.g., serialization, protobuf, + /// etc.). + fn try_convert_from( + value: T, + ctx: &Arc, + variable_time: bool, + representation: R, + ) -> Result + where + R: Into>; } diff --git a/crates/fhe-math/src/zq/mod.rs b/crates/fhe-math/src/zq/mod.rs index ec0f7f7..6bb511d 100644 --- a/crates/fhe-math/src/zq/mod.rs +++ b/crates/fhe-math/src/zq/mod.rs @@ -15,1074 +15,1074 @@ use rand::{distributions::Uniform, CryptoRng, Rng, RngCore}; /// Structure encapsulating an integer modulus up to 62 bits. #[derive(Debug, Clone, PartialEq)] pub struct Modulus { - p: u64, - nbits: usize, - barrett_hi: u64, - barrett_lo: u64, - leading_zeros: u32, - pub(crate) supports_opt: bool, - distribution: Uniform, + p: u64, + nbits: usize, + barrett_hi: u64, + barrett_lo: u64, + leading_zeros: u32, + pub(crate) supports_opt: bool, + distribution: Uniform, } // We need to declare Eq manually because of the `Uniform` member. impl Eq for Modulus {} impl Modulus { - /// Create a modulus from an integer of at most 62 bits. - pub fn new(p: u64) -> Result { - if p < 2 || (p >> 62) != 0 { - Err(Error::InvalidModulus(p)) - } else { - let barrett = ((BigUint::from(1u64) << 128usize) / p).to_u128().unwrap(); // 2^128 / p - Ok(Self { - p, - nbits: 64 - p.leading_zeros() as usize, - barrett_hi: (barrett >> 64) as u64, - barrett_lo: barrett as u64, - leading_zeros: p.leading_zeros(), - supports_opt: primes::supports_opt(p), - distribution: Uniform::from(0..p), - }) - } - } - - /// Returns the value of the modulus. - pub const fn modulus(&self) -> u64 { - self.p - } - - /// Performs the modular addition of a and b in constant time. - /// Aborts if a >= p or b >= p in debug mode. - pub const fn add(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.p && b < self.p); - Self::reduce1(a + b, self.p) - } - - /// Performs the modular addition of a and b in variable time. - /// Aborts if a >= p or b >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being added. - pub const unsafe fn add_vt(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.p && b < self.p); - Self::reduce1_vt(a + b, self.p) - } - - /// Performs the modular subtraction of a and b in constant time. - /// Aborts if a >= p or b >= p in debug mode. - pub const fn sub(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.p && b < self.p); - Self::reduce1(a + self.p - b, self.p) - } - - /// Performs the modular subtraction of a and b in constant time. - /// Aborts if a >= p or b >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being subtracted. - const unsafe fn sub_vt(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.p && b < self.p); - Self::reduce1_vt(a + self.p - b, self.p) - } - - /// Performs the modular multiplication of a and b in constant time. - /// Aborts if a >= p or b >= p in debug mode. - pub const fn mul(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.p && b < self.p); - self.reduce_u128((a as u128) * (b as u128)) - } - - /// Performs the modular multiplication of a and b in constant time. - /// Aborts if a >= p or b >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being multiplied. - const unsafe fn mul_vt(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.p && b < self.p); - Self::reduce1_vt(self.lazy_reduce_u128((a as u128) * (b as u128)), self.p) - } - - /// Optimized modular multiplication of a and b in constant time. - /// - /// Aborts if a >= p or b >= p in debug mode. - pub const fn mul_opt(&self, a: u64, b: u64) -> u64 { - debug_assert!(self.supports_opt); - debug_assert!(a < self.p && b < self.p); - - self.reduce_opt_u128((a as u128) * (b as u128)) - } - - /// Optimized modular multiplication of a and b in variable time. - /// Aborts if a >= p or b >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being multiplied. - const unsafe fn mul_opt_vt(&self, a: u64, b: u64) -> u64 { - debug_assert!(self.supports_opt); - debug_assert!(a < self.p && b < self.p); - - self.reduce_opt_u128_vt((a as u128) * (b as u128)) - } - - /// Modular negation in constant time. - /// - /// Aborts if a >= p in debug mode. - pub const fn neg(&self, a: u64) -> u64 { - debug_assert!(a < self.p); - Self::reduce1(self.p - a, self.p) - } - - /// Modular negation in variable time. - /// Aborts if a >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being negated. - const unsafe fn neg_vt(&self, a: u64) -> u64 { - debug_assert!(a < self.p); - Self::reduce1_vt(self.p - a, self.p) - } - - /// Compute the Shoup representation of a. - /// - /// Aborts if a >= p in debug mode. - pub const fn shoup(&self, a: u64) -> u64 { - debug_assert!(a < self.p); - - (((a as u128) << 64) / (self.p as u128)) as u64 - } - - /// Shoup multiplication of a and b in constant time. - /// - /// Aborts if b >= p or b_shoup != shoup(b) in debug mode. - pub const fn mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 { - Self::reduce1(self.lazy_mul_shoup(a, b, b_shoup), self.p) - } - - /// Shoup multiplication of a and b in variable time. - /// Aborts if b >= p or b_shoup != shoup(b) in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being multiplied. - const unsafe fn mul_shoup_vt(&self, a: u64, b: u64, b_shoup: u64) -> u64 { - Self::reduce1_vt(self.lazy_mul_shoup(a, b, b_shoup), self.p) - } - - /// Lazy Shoup multiplication of a and b in constant time. - /// The output is in the interval [0, 2 * p). - /// - /// Aborts if b >= p or b_shoup != shoup(b) in debug mode. - pub const fn lazy_mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 { - debug_assert!(b < self.p); - debug_assert!(b_shoup == self.shoup(b)); - - let q = ((a as u128) * (b_shoup as u128)) >> 64; - let r = ((a as u128) * (b as u128) - q * (self.p as u128)) as u64; - - debug_assert!(r < 2 * self.p); - - r - } - - /// Modular addition of vectors in place in constant time. - /// - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - pub fn add_vec(&self, a: &mut [u64], b: &[u64]) { - debug_assert_eq!(a.len(), b.len()); - - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi)); - } - - /// Modular addition of vectors in place in variable time. - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being added. - pub unsafe fn add_vec_vt(&self, a: &mut [u64], b: &[u64]) { - let n = a.len(); - debug_assert_eq!(n, b.len()); - - let p = self.p; - macro_rules! add_at { - ($idx:expr) => { - *a.get_unchecked_mut($idx) = - Self::reduce1_vt(*a.get_unchecked_mut($idx) + *b.get_unchecked($idx), p); - }; - } - - if n % 16 == 0 { - for i in 0..n / 16 { - add_at!(16 * i); - add_at!(16 * i + 1); - add_at!(16 * i + 2); - add_at!(16 * i + 3); - add_at!(16 * i + 4); - add_at!(16 * i + 5); - add_at!(16 * i + 6); - add_at!(16 * i + 7); - add_at!(16 * i + 8); - add_at!(16 * i + 9); - add_at!(16 * i + 10); - add_at!(16 * i + 11); - add_at!(16 * i + 12); - add_at!(16 * i + 13); - add_at!(16 * i + 14); - add_at!(16 * i + 15); - } - } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi)); - } - } - - /// Modular subtraction of vectors in place in constant time. - /// - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - pub fn sub_vec(&self, a: &mut [u64], b: &[u64]) { - debug_assert_eq!(a.len(), b.len()); - - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi)); - } - - /// Modular subtraction of vectors in place in variable time. - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being subtracted. - pub unsafe fn sub_vec_vt(&self, a: &mut [u64], b: &[u64]) { - let n = a.len(); - debug_assert_eq!(n, b.len()); - - let p = self.p; - macro_rules! sub_at { - ($idx:expr) => { - *a.get_unchecked_mut($idx) = - Self::reduce1_vt(p + *a.get_unchecked_mut($idx) - *b.get_unchecked($idx), p); - }; - } - - if n % 16 == 0 { - for i in 0..n / 16 { - sub_at!(16 * i); - sub_at!(16 * i + 1); - sub_at!(16 * i + 2); - sub_at!(16 * i + 3); - sub_at!(16 * i + 4); - sub_at!(16 * i + 5); - sub_at!(16 * i + 6); - sub_at!(16 * i + 7); - sub_at!(16 * i + 8); - sub_at!(16 * i + 9); - sub_at!(16 * i + 10); - sub_at!(16 * i + 11); - sub_at!(16 * i + 12); - sub_at!(16 * i + 13); - sub_at!(16 * i + 14); - sub_at!(16 * i + 15); - } - } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi)); - } - } - - /// Modular multiplication of vectors in place in constant time. - /// - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - pub fn mul_vec(&self, a: &mut [u64], b: &[u64]) { - debug_assert_eq!(a.len(), b.len()); - - if self.supports_opt { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi)); - } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi)); - } - } - - /// Modular scalar multiplication of vectors in place in constant time. - /// - /// Aborts if any of the values in a is >= p in debug mode. - pub fn scalar_mul_vec(&self, a: &mut [u64], b: u64) { - let b_shoup = self.shoup(b); - a.iter_mut() - .for_each(|ai| *ai = self.mul_shoup(*ai, b, b_shoup)); - } - - /// Modular scalar multiplication of vectors in place in variable time. - /// Aborts if any of the values in a is >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being multiplied. - pub unsafe fn scalar_mul_vec_vt(&self, a: &mut [u64], b: u64) { - let b_shoup = self.shoup(b); - a.iter_mut() - .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup)); - } - - /// Modular multiplication of vectors in place in variable time. - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being subtracted. - pub unsafe fn mul_vec_vt(&self, a: &mut [u64], b: &[u64]) { - debug_assert_eq!(a.len(), b.len()); - - if self.supports_opt { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi)); - } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi)); - } - } - - /// Compute the Shoup representation of a vector. - /// - /// Aborts if any of the values of the vector is >= p in debug mode. - pub fn shoup_vec(&self, a: &[u64]) -> Vec { - a.iter().map(|ai| self.shoup(*ai)).collect_vec() - } - - /// Shoup modular multiplication of vectors in place in constant time. - /// - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - pub fn mul_shoup_vec(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) { - debug_assert_eq!(a.len(), b.len()); - debug_assert_eq!(a.len(), b_shoup.len()); - debug_assert_eq!(&b_shoup, &self.shoup_vec(b)); - - izip!(a.iter_mut(), b.iter(), b_shoup.iter()) - .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup)); - } - - /// Shoup modular multiplication of vectors in place in variable time. - /// Aborts if a and b differ in size, and if any of their values is >= p in - /// debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being multiplied. - pub unsafe fn mul_shoup_vec_vt(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) { - debug_assert_eq!(a.len(), b.len()); - debug_assert_eq!(a.len(), b_shoup.len()); - debug_assert_eq!(&b_shoup, &self.shoup_vec(b)); - - izip!(a.iter_mut(), b.iter(), b_shoup.iter()) - .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup)); - } - - /// Reduce a vector in place in constant time. - pub fn reduce_vec(&self, a: &mut [u64]) { - a.iter_mut().for_each(|ai| *ai = self.reduce(*ai)); - } - - /// Center a value modulo p as i64 in variable time. - /// TODO: To test and to make constant time? - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being centered. - const unsafe fn center_vt(&self, a: u64) -> i64 { - debug_assert!(a < self.p); - - if a >= self.p >> 1 { - (a as i64) - (self.p as i64) - } else { - a as i64 - } - } - - /// Center a vector in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being centered. - pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec { - a.iter().map(|ai| self.center_vt(*ai)).collect_vec() - } - - /// Reduce a vector in place in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being reduced. - pub unsafe fn reduce_vec_vt(&self, a: &mut [u64]) { - a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai)); - } - - /// Modular reduction of a i64 in constant time. - const fn reduce_i64(&self, a: i64) -> u64 { - self.reduce_u128((((self.p as i128) << 64) + (a as i128)) as u128) - } - - /// Modular reduction of a i64 in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being reduced. - const unsafe fn reduce_i64_vt(&self, a: i64) -> u64 { - self.reduce_u128_vt((((self.p as i128) << 64) + (a as i128)) as u128) - } - - /// Reduce a vector in place in constant time. - pub fn reduce_vec_i64(&self, a: &[i64]) -> Vec { - a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec() - } - - /// Reduce a vector in place in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being reduced. - pub unsafe fn reduce_vec_i64_vt(&self, a: &[i64]) -> Vec { - a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect_vec() - } - - /// Reduce a vector in constant time. - pub fn reduce_vec_new(&self, a: &[u64]) -> Vec { - a.iter().map(|ai| self.reduce(*ai)).collect_vec() - } - - /// Reduce a vector in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being reduced. - pub unsafe fn reduce_vec_new_vt(&self, a: &[u64]) -> Vec { - a.iter().map(|bi| self.reduce_vt(*bi)).collect_vec() - } - - /// Modular negation of a vector in place in constant time. - /// - /// Aborts if any of the values in the vector is >= p in debug mode. - pub fn neg_vec(&self, a: &mut [u64]) { - izip!(a.iter_mut()).for_each(|ai| *ai = self.neg(*ai)); - } - - /// Modular negation of a vector in place in variable time. - /// Aborts if any of the values in the vector is >= p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being negated. - pub unsafe fn neg_vec_vt(&self, a: &mut [u64]) { - izip!(a.iter_mut()).for_each(|ai| *ai = self.neg_vt(*ai)); - } - - /// Modular exponentiation in variable time. - /// - /// Aborts if a >= p or n >= p in debug mode. - pub fn pow(&self, a: u64, n: u64) -> u64 { - debug_assert!(a < self.p && n < self.p); - - if n == 0 { - 1 - } else if n == 1 { - a - } else { - let mut r = a; - let mut i = (62 - n.leading_zeros()) as isize; - while i >= 0 { - r = self.mul(r, r); - if (n >> i) & 1 == 1 { - r = self.mul(r, a); - } - i -= 1; - } - r - } - } - - /// Modular inversion in variable time. - /// - /// Returns None if p is not prime or a = 0. - /// Aborts if a >= p in debug mode. - pub fn inv(&self, a: u64) -> std::option::Option { - if !is_prime(self.p) || a == 0 { - None - } else { - let r = self.pow(a, self.p - 2); - debug_assert_eq!(self.mul(a, r), 1); - Some(r) - } - } - - /// Modular reduction of a u128 in constant time. - pub const fn reduce_u128(&self, a: u128) -> u64 { - Self::reduce1(self.lazy_reduce_u128(a), self.p) - } - - /// Modular reduction of a u128 in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub const unsafe fn reduce_u128_vt(&self, a: u128) -> u64 { - Self::reduce1_vt(self.lazy_reduce_u128(a), self.p) - } - - /// Modular reduction of a u64 in constant time. - pub const fn reduce(&self, a: u64) -> u64 { - Self::reduce1(self.lazy_reduce(a), self.p) - } - - /// Modular reduction of a u64 in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub const unsafe fn reduce_vt(&self, a: u64) -> u64 { - Self::reduce1_vt(self.lazy_reduce(a), self.p) - } - - /// Optimized modular reduction of a u128 in constant time. - pub const fn reduce_opt_u128(&self, a: u128) -> u64 { - debug_assert!(self.supports_opt); - Self::reduce1(self.lazy_reduce_opt_u128(a), self.p) - } - - /// Optimized modular reduction of a u128 in constant time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub(crate) const unsafe fn reduce_opt_u128_vt(&self, a: u128) -> u64 { - debug_assert!(self.supports_opt); - Self::reduce1_vt(self.lazy_reduce_opt_u128(a), self.p) - } - - /// Optimized modular reduction of a u64 in constant time. - pub const fn reduce_opt(&self, a: u64) -> u64 { - Self::reduce1(self.lazy_reduce_opt(a), self.p) - } - - /// Optimized modular reduction of a u64 in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub const unsafe fn reduce_opt_vt(&self, a: u64) -> u64 { - Self::reduce1_vt(self.lazy_reduce_opt(a), self.p) - } - - /// Return x mod p in constant time. - /// Aborts if x >= 2 * p in debug mode. - const fn reduce1(x: u64, p: u64) -> u64 { - debug_assert!(p >> 63 == 0); - debug_assert!(x < 2 * p); - - let (y, _) = x.overflowing_sub(p); - let xp = x ^ p; - let yp = y ^ p; - let xy = xp ^ yp; - let xxy = x ^ xy; - let xxy = xxy >> 63; - let (c, _) = xxy.overflowing_sub(1); - let r = (c & y) | ((!c) & x); - - debug_assert!(r == x % p); - - r - } - - /// Return x mod p in variable time. - /// Aborts if x >= 2 * p in debug mode. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 { - debug_assert!(p >> 63 == 0); - debug_assert!(x < 2 * p); - - if x >= p { - x - p - } else { - x - } - } - - /// Lazy modular reduction of a in constant time. - /// The output is in the interval [0, 2 * p). - pub const fn lazy_reduce_u128(&self, a: u128) -> u64 { - let a_lo = a as u64; - let a_hi = (a >> 64) as u64; - let p_lo_lo = ((a_lo as u128) * (self.barrett_lo as u128)) >> 64; - let p_hi_lo = (a_hi as u128) * (self.barrett_lo as u128); - let p_lo_hi = (a_lo as u128) * (self.barrett_hi as u128); - - let q = ((p_lo_hi + p_hi_lo + p_lo_lo) >> 64) + (a_hi as u128) * (self.barrett_hi as u128); - let r = (a - q * (self.p as u128)) as u64; - - debug_assert!((r as u128) < 2 * (self.p as u128)); - debug_assert!(r % self.p == (a % (self.p as u128)) as u64); - - r - } - - /// Lazy modular reduction of a in constant time. - /// The output is in the interval [0, 2 * p). - pub const fn lazy_reduce(&self, a: u64) -> u64 { - let p_lo_lo = ((a as u128) * (self.barrett_lo as u128)) >> 64; - let p_lo_hi = (a as u128) * (self.barrett_hi as u128); - - let q = (p_lo_hi + p_lo_lo) >> 64; - let r = (a as u128 - q * (self.p as u128)) as u64; - - debug_assert!((r as u128) < 2 * (self.p as u128)); - debug_assert!(r % self.p == a % self.p); - - r - } - - /// Lazy optimized modular reduction of a in constant time. - /// The output is in the interval [0, 2 * p). - /// - /// Aborts if the input is >= p ^ 2 in debug mode. - pub const fn lazy_reduce_opt_u128(&self, a: u128) -> u64 { - debug_assert!(a < (self.p as u128) * (self.p as u128)); - - let q = (((self.barrett_lo as u128) * (a >> 64)) + (a << self.leading_zeros)) >> 64; - let r = (a - q * (self.p as u128)) as u64; - - debug_assert!((r as u128) < 2 * (self.p as u128)); - debug_assert!(r % self.p == (a % (self.p as u128)) as u64); - - r - } - - /// Lazy optimized modular reduction of a in constant time. - /// The output is in the interval [0, 2 * p). - const fn lazy_reduce_opt(&self, a: u64) -> u64 { - let q = a >> (64 - self.leading_zeros); - let r = ((a as u128) - (q as u128) * (self.p as u128)) as u64; - - debug_assert!((r as u128) < 2 * (self.p as u128)); - debug_assert!(r % self.p == a % self.p); - - r - } - - /// Lazy modular reduction of a vector in constant time. - /// The output coefficients are in the interval [0, 2 * p). - pub fn lazy_reduce_vec(&self, a: &mut [u64]) { - if self.supports_opt { - a.iter_mut().for_each(|ai| *ai = self.lazy_reduce_opt(*ai)) - } else { - a.iter_mut().for_each(|ai| *ai = self.lazy_reduce(*ai)) - } - } - - /// Returns a random vector. - pub fn random_vec(&self, size: usize, rng: &mut R) -> Vec { - rng.sample_iter(self.distribution).take(size).collect_vec() - } - - /// Length of the serialization of a vector of size `size`. - /// - /// Panics if the size is not a multiple of 8. - pub const fn serialization_length(&self, size: usize) -> usize { - assert!(size % 8 == 0); - let p_nbits = 64 - (self.p - 1).leading_zeros() as usize; - p_nbits * size / 8 - } - - /// Serialize a vector of elements of length a multiple of 8. - /// - /// Panics if the length of the vector is not a multiple of 8. - pub fn serialize_vec(&self, a: &[u64]) -> Vec { - let p_nbits = 64 - (self.p - 1).leading_zeros() as usize; - transcode_to_bytes(a, p_nbits) - } - - /// Deserialize a vector of bytes into a vector of elements mod p. - pub fn deserialize_vec(&self, b: &[u8]) -> Vec { - let p_nbits = 64 - (self.p - 1).leading_zeros() as usize; - transcode_from_bytes(b, p_nbits) - } + /// Create a modulus from an integer of at most 62 bits. + pub fn new(p: u64) -> Result { + if p < 2 || (p >> 62) != 0 { + Err(Error::InvalidModulus(p)) + } else { + let barrett = ((BigUint::from(1u64) << 128usize) / p).to_u128().unwrap(); // 2^128 / p + Ok(Self { + p, + nbits: 64 - p.leading_zeros() as usize, + barrett_hi: (barrett >> 64) as u64, + barrett_lo: barrett as u64, + leading_zeros: p.leading_zeros(), + supports_opt: primes::supports_opt(p), + distribution: Uniform::from(0..p), + }) + } + } + + /// Returns the value of the modulus. + pub const fn modulus(&self) -> u64 { + self.p + } + + /// Performs the modular addition of a and b in constant time. + /// Aborts if a >= p or b >= p in debug mode. + pub const fn add(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.p && b < self.p); + Self::reduce1(a + b, self.p) + } + + /// Performs the modular addition of a and b in variable time. + /// Aborts if a >= p or b >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being added. + pub const unsafe fn add_vt(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.p && b < self.p); + Self::reduce1_vt(a + b, self.p) + } + + /// Performs the modular subtraction of a and b in constant time. + /// Aborts if a >= p or b >= p in debug mode. + pub const fn sub(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.p && b < self.p); + Self::reduce1(a + self.p - b, self.p) + } + + /// Performs the modular subtraction of a and b in constant time. + /// Aborts if a >= p or b >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being subtracted. + const unsafe fn sub_vt(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.p && b < self.p); + Self::reduce1_vt(a + self.p - b, self.p) + } + + /// Performs the modular multiplication of a and b in constant time. + /// Aborts if a >= p or b >= p in debug mode. + pub const fn mul(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.p && b < self.p); + self.reduce_u128((a as u128) * (b as u128)) + } + + /// Performs the modular multiplication of a and b in constant time. + /// Aborts if a >= p or b >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being multiplied. + const unsafe fn mul_vt(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.p && b < self.p); + Self::reduce1_vt(self.lazy_reduce_u128((a as u128) * (b as u128)), self.p) + } + + /// Optimized modular multiplication of a and b in constant time. + /// + /// Aborts if a >= p or b >= p in debug mode. + pub const fn mul_opt(&self, a: u64, b: u64) -> u64 { + debug_assert!(self.supports_opt); + debug_assert!(a < self.p && b < self.p); + + self.reduce_opt_u128((a as u128) * (b as u128)) + } + + /// Optimized modular multiplication of a and b in variable time. + /// Aborts if a >= p or b >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being multiplied. + const unsafe fn mul_opt_vt(&self, a: u64, b: u64) -> u64 { + debug_assert!(self.supports_opt); + debug_assert!(a < self.p && b < self.p); + + self.reduce_opt_u128_vt((a as u128) * (b as u128)) + } + + /// Modular negation in constant time. + /// + /// Aborts if a >= p in debug mode. + pub const fn neg(&self, a: u64) -> u64 { + debug_assert!(a < self.p); + Self::reduce1(self.p - a, self.p) + } + + /// Modular negation in variable time. + /// Aborts if a >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being negated. + const unsafe fn neg_vt(&self, a: u64) -> u64 { + debug_assert!(a < self.p); + Self::reduce1_vt(self.p - a, self.p) + } + + /// Compute the Shoup representation of a. + /// + /// Aborts if a >= p in debug mode. + pub const fn shoup(&self, a: u64) -> u64 { + debug_assert!(a < self.p); + + (((a as u128) << 64) / (self.p as u128)) as u64 + } + + /// Shoup multiplication of a and b in constant time. + /// + /// Aborts if b >= p or b_shoup != shoup(b) in debug mode. + pub const fn mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 { + Self::reduce1(self.lazy_mul_shoup(a, b, b_shoup), self.p) + } + + /// Shoup multiplication of a and b in variable time. + /// Aborts if b >= p or b_shoup != shoup(b) in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being multiplied. + const unsafe fn mul_shoup_vt(&self, a: u64, b: u64, b_shoup: u64) -> u64 { + Self::reduce1_vt(self.lazy_mul_shoup(a, b, b_shoup), self.p) + } + + /// Lazy Shoup multiplication of a and b in constant time. + /// The output is in the interval [0, 2 * p). + /// + /// Aborts if b >= p or b_shoup != shoup(b) in debug mode. + pub const fn lazy_mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 { + debug_assert!(b < self.p); + debug_assert!(b_shoup == self.shoup(b)); + + let q = ((a as u128) * (b_shoup as u128)) >> 64; + let r = ((a as u128) * (b as u128) - q * (self.p as u128)) as u64; + + debug_assert!(r < 2 * self.p); + + r + } + + /// Modular addition of vectors in place in constant time. + /// + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + pub fn add_vec(&self, a: &mut [u64], b: &[u64]) { + debug_assert_eq!(a.len(), b.len()); + + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi)); + } + + /// Modular addition of vectors in place in variable time. + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being added. + pub unsafe fn add_vec_vt(&self, a: &mut [u64], b: &[u64]) { + let n = a.len(); + debug_assert_eq!(n, b.len()); + + let p = self.p; + macro_rules! add_at { + ($idx:expr) => { + *a.get_unchecked_mut($idx) = + Self::reduce1_vt(*a.get_unchecked_mut($idx) + *b.get_unchecked($idx), p); + }; + } + + if n % 16 == 0 { + for i in 0..n / 16 { + add_at!(16 * i); + add_at!(16 * i + 1); + add_at!(16 * i + 2); + add_at!(16 * i + 3); + add_at!(16 * i + 4); + add_at!(16 * i + 5); + add_at!(16 * i + 6); + add_at!(16 * i + 7); + add_at!(16 * i + 8); + add_at!(16 * i + 9); + add_at!(16 * i + 10); + add_at!(16 * i + 11); + add_at!(16 * i + 12); + add_at!(16 * i + 13); + add_at!(16 * i + 14); + add_at!(16 * i + 15); + } + } else { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi)); + } + } + + /// Modular subtraction of vectors in place in constant time. + /// + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + pub fn sub_vec(&self, a: &mut [u64], b: &[u64]) { + debug_assert_eq!(a.len(), b.len()); + + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi)); + } + + /// Modular subtraction of vectors in place in variable time. + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being subtracted. + pub unsafe fn sub_vec_vt(&self, a: &mut [u64], b: &[u64]) { + let n = a.len(); + debug_assert_eq!(n, b.len()); + + let p = self.p; + macro_rules! sub_at { + ($idx:expr) => { + *a.get_unchecked_mut($idx) = + Self::reduce1_vt(p + *a.get_unchecked_mut($idx) - *b.get_unchecked($idx), p); + }; + } + + if n % 16 == 0 { + for i in 0..n / 16 { + sub_at!(16 * i); + sub_at!(16 * i + 1); + sub_at!(16 * i + 2); + sub_at!(16 * i + 3); + sub_at!(16 * i + 4); + sub_at!(16 * i + 5); + sub_at!(16 * i + 6); + sub_at!(16 * i + 7); + sub_at!(16 * i + 8); + sub_at!(16 * i + 9); + sub_at!(16 * i + 10); + sub_at!(16 * i + 11); + sub_at!(16 * i + 12); + sub_at!(16 * i + 13); + sub_at!(16 * i + 14); + sub_at!(16 * i + 15); + } + } else { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi)); + } + } + + /// Modular multiplication of vectors in place in constant time. + /// + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + pub fn mul_vec(&self, a: &mut [u64], b: &[u64]) { + debug_assert_eq!(a.len(), b.len()); + + if self.supports_opt { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi)); + } else { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi)); + } + } + + /// Modular scalar multiplication of vectors in place in constant time. + /// + /// Aborts if any of the values in a is >= p in debug mode. + pub fn scalar_mul_vec(&self, a: &mut [u64], b: u64) { + let b_shoup = self.shoup(b); + a.iter_mut() + .for_each(|ai| *ai = self.mul_shoup(*ai, b, b_shoup)); + } + + /// Modular scalar multiplication of vectors in place in variable time. + /// Aborts if any of the values in a is >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being multiplied. + pub unsafe fn scalar_mul_vec_vt(&self, a: &mut [u64], b: u64) { + let b_shoup = self.shoup(b); + a.iter_mut() + .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup)); + } + + /// Modular multiplication of vectors in place in variable time. + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being subtracted. + pub unsafe fn mul_vec_vt(&self, a: &mut [u64], b: &[u64]) { + debug_assert_eq!(a.len(), b.len()); + + if self.supports_opt { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi)); + } else { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi)); + } + } + + /// Compute the Shoup representation of a vector. + /// + /// Aborts if any of the values of the vector is >= p in debug mode. + pub fn shoup_vec(&self, a: &[u64]) -> Vec { + a.iter().map(|ai| self.shoup(*ai)).collect_vec() + } + + /// Shoup modular multiplication of vectors in place in constant time. + /// + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + pub fn mul_shoup_vec(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) { + debug_assert_eq!(a.len(), b.len()); + debug_assert_eq!(a.len(), b_shoup.len()); + debug_assert_eq!(&b_shoup, &self.shoup_vec(b)); + + izip!(a.iter_mut(), b.iter(), b_shoup.iter()) + .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup)); + } + + /// Shoup modular multiplication of vectors in place in variable time. + /// Aborts if a and b differ in size, and if any of their values is >= p in + /// debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being multiplied. + pub unsafe fn mul_shoup_vec_vt(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) { + debug_assert_eq!(a.len(), b.len()); + debug_assert_eq!(a.len(), b_shoup.len()); + debug_assert_eq!(&b_shoup, &self.shoup_vec(b)); + + izip!(a.iter_mut(), b.iter(), b_shoup.iter()) + .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup)); + } + + /// Reduce a vector in place in constant time. + pub fn reduce_vec(&self, a: &mut [u64]) { + a.iter_mut().for_each(|ai| *ai = self.reduce(*ai)); + } + + /// Center a value modulo p as i64 in variable time. + /// TODO: To test and to make constant time? + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being centered. + const unsafe fn center_vt(&self, a: u64) -> i64 { + debug_assert!(a < self.p); + + if a >= self.p >> 1 { + (a as i64) - (self.p as i64) + } else { + a as i64 + } + } + + /// Center a vector in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being centered. + pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec { + a.iter().map(|ai| self.center_vt(*ai)).collect_vec() + } + + /// Reduce a vector in place in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being reduced. + pub unsafe fn reduce_vec_vt(&self, a: &mut [u64]) { + a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai)); + } + + /// Modular reduction of a i64 in constant time. + const fn reduce_i64(&self, a: i64) -> u64 { + self.reduce_u128((((self.p as i128) << 64) + (a as i128)) as u128) + } + + /// Modular reduction of a i64 in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being reduced. + const unsafe fn reduce_i64_vt(&self, a: i64) -> u64 { + self.reduce_u128_vt((((self.p as i128) << 64) + (a as i128)) as u128) + } + + /// Reduce a vector in place in constant time. + pub fn reduce_vec_i64(&self, a: &[i64]) -> Vec { + a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec() + } + + /// Reduce a vector in place in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being reduced. + pub unsafe fn reduce_vec_i64_vt(&self, a: &[i64]) -> Vec { + a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect_vec() + } + + /// Reduce a vector in constant time. + pub fn reduce_vec_new(&self, a: &[u64]) -> Vec { + a.iter().map(|ai| self.reduce(*ai)).collect_vec() + } + + /// Reduce a vector in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being reduced. + pub unsafe fn reduce_vec_new_vt(&self, a: &[u64]) -> Vec { + a.iter().map(|bi| self.reduce_vt(*bi)).collect_vec() + } + + /// Modular negation of a vector in place in constant time. + /// + /// Aborts if any of the values in the vector is >= p in debug mode. + pub fn neg_vec(&self, a: &mut [u64]) { + izip!(a.iter_mut()).for_each(|ai| *ai = self.neg(*ai)); + } + + /// Modular negation of a vector in place in variable time. + /// Aborts if any of the values in the vector is >= p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the values being negated. + pub unsafe fn neg_vec_vt(&self, a: &mut [u64]) { + izip!(a.iter_mut()).for_each(|ai| *ai = self.neg_vt(*ai)); + } + + /// Modular exponentiation in variable time. + /// + /// Aborts if a >= p or n >= p in debug mode. + pub fn pow(&self, a: u64, n: u64) -> u64 { + debug_assert!(a < self.p && n < self.p); + + if n == 0 { + 1 + } else if n == 1 { + a + } else { + let mut r = a; + let mut i = (62 - n.leading_zeros()) as isize; + while i >= 0 { + r = self.mul(r, r); + if (n >> i) & 1 == 1 { + r = self.mul(r, a); + } + i -= 1; + } + r + } + } + + /// Modular inversion in variable time. + /// + /// Returns None if p is not prime or a = 0. + /// Aborts if a >= p in debug mode. + pub fn inv(&self, a: u64) -> std::option::Option { + if !is_prime(self.p) || a == 0 { + None + } else { + let r = self.pow(a, self.p - 2); + debug_assert_eq!(self.mul(a, r), 1); + Some(r) + } + } + + /// Modular reduction of a u128 in constant time. + pub const fn reduce_u128(&self, a: u128) -> u64 { + Self::reduce1(self.lazy_reduce_u128(a), self.p) + } + + /// Modular reduction of a u128 in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub const unsafe fn reduce_u128_vt(&self, a: u128) -> u64 { + Self::reduce1_vt(self.lazy_reduce_u128(a), self.p) + } + + /// Modular reduction of a u64 in constant time. + pub const fn reduce(&self, a: u64) -> u64 { + Self::reduce1(self.lazy_reduce(a), self.p) + } + + /// Modular reduction of a u64 in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub const unsafe fn reduce_vt(&self, a: u64) -> u64 { + Self::reduce1_vt(self.lazy_reduce(a), self.p) + } + + /// Optimized modular reduction of a u128 in constant time. + pub const fn reduce_opt_u128(&self, a: u128) -> u64 { + debug_assert!(self.supports_opt); + Self::reduce1(self.lazy_reduce_opt_u128(a), self.p) + } + + /// Optimized modular reduction of a u128 in constant time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub(crate) const unsafe fn reduce_opt_u128_vt(&self, a: u128) -> u64 { + debug_assert!(self.supports_opt); + Self::reduce1_vt(self.lazy_reduce_opt_u128(a), self.p) + } + + /// Optimized modular reduction of a u64 in constant time. + pub const fn reduce_opt(&self, a: u64) -> u64 { + Self::reduce1(self.lazy_reduce_opt(a), self.p) + } + + /// Optimized modular reduction of a u64 in variable time. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub const unsafe fn reduce_opt_vt(&self, a: u64) -> u64 { + Self::reduce1_vt(self.lazy_reduce_opt(a), self.p) + } + + /// Return x mod p in constant time. + /// Aborts if x >= 2 * p in debug mode. + const fn reduce1(x: u64, p: u64) -> u64 { + debug_assert!(p >> 63 == 0); + debug_assert!(x < 2 * p); + + let (y, _) = x.overflowing_sub(p); + let xp = x ^ p; + let yp = y ^ p; + let xy = xp ^ yp; + let xxy = x ^ xy; + let xxy = xxy >> 63; + let (c, _) = xxy.overflowing_sub(1); + let r = (c & y) | ((!c) & x); + + debug_assert!(r == x % p); + + r + } + + /// Return x mod p in variable time. + /// Aborts if x >= 2 * p in debug mode. + /// + /// # Safety + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 { + debug_assert!(p >> 63 == 0); + debug_assert!(x < 2 * p); + + if x >= p { + x - p + } else { + x + } + } + + /// Lazy modular reduction of a in constant time. + /// The output is in the interval [0, 2 * p). + pub const fn lazy_reduce_u128(&self, a: u128) -> u64 { + let a_lo = a as u64; + let a_hi = (a >> 64) as u64; + let p_lo_lo = ((a_lo as u128) * (self.barrett_lo as u128)) >> 64; + let p_hi_lo = (a_hi as u128) * (self.barrett_lo as u128); + let p_lo_hi = (a_lo as u128) * (self.barrett_hi as u128); + + let q = ((p_lo_hi + p_hi_lo + p_lo_lo) >> 64) + (a_hi as u128) * (self.barrett_hi as u128); + let r = (a - q * (self.p as u128)) as u64; + + debug_assert!((r as u128) < 2 * (self.p as u128)); + debug_assert!(r % self.p == (a % (self.p as u128)) as u64); + + r + } + + /// Lazy modular reduction of a in constant time. + /// The output is in the interval [0, 2 * p). + pub const fn lazy_reduce(&self, a: u64) -> u64 { + let p_lo_lo = ((a as u128) * (self.barrett_lo as u128)) >> 64; + let p_lo_hi = (a as u128) * (self.barrett_hi as u128); + + let q = (p_lo_hi + p_lo_lo) >> 64; + let r = (a as u128 - q * (self.p as u128)) as u64; + + debug_assert!((r as u128) < 2 * (self.p as u128)); + debug_assert!(r % self.p == a % self.p); + + r + } + + /// Lazy optimized modular reduction of a in constant time. + /// The output is in the interval [0, 2 * p). + /// + /// Aborts if the input is >= p ^ 2 in debug mode. + pub const fn lazy_reduce_opt_u128(&self, a: u128) -> u64 { + debug_assert!(a < (self.p as u128) * (self.p as u128)); + + let q = (((self.barrett_lo as u128) * (a >> 64)) + (a << self.leading_zeros)) >> 64; + let r = (a - q * (self.p as u128)) as u64; + + debug_assert!((r as u128) < 2 * (self.p as u128)); + debug_assert!(r % self.p == (a % (self.p as u128)) as u64); + + r + } + + /// Lazy optimized modular reduction of a in constant time. + /// The output is in the interval [0, 2 * p). + const fn lazy_reduce_opt(&self, a: u64) -> u64 { + let q = a >> (64 - self.leading_zeros); + let r = ((a as u128) - (q as u128) * (self.p as u128)) as u64; + + debug_assert!((r as u128) < 2 * (self.p as u128)); + debug_assert!(r % self.p == a % self.p); + + r + } + + /// Lazy modular reduction of a vector in constant time. + /// The output coefficients are in the interval [0, 2 * p). + pub fn lazy_reduce_vec(&self, a: &mut [u64]) { + if self.supports_opt { + a.iter_mut().for_each(|ai| *ai = self.lazy_reduce_opt(*ai)) + } else { + a.iter_mut().for_each(|ai| *ai = self.lazy_reduce(*ai)) + } + } + + /// Returns a random vector. + pub fn random_vec(&self, size: usize, rng: &mut R) -> Vec { + rng.sample_iter(self.distribution).take(size).collect_vec() + } + + /// Length of the serialization of a vector of size `size`. + /// + /// Panics if the size is not a multiple of 8. + pub const fn serialization_length(&self, size: usize) -> usize { + assert!(size % 8 == 0); + let p_nbits = 64 - (self.p - 1).leading_zeros() as usize; + p_nbits * size / 8 + } + + /// Serialize a vector of elements of length a multiple of 8. + /// + /// Panics if the length of the vector is not a multiple of 8. + pub fn serialize_vec(&self, a: &[u64]) -> Vec { + let p_nbits = 64 - (self.p - 1).leading_zeros() as usize; + transcode_to_bytes(a, p_nbits) + } + + /// Deserialize a vector of bytes into a vector of elements mod p. + pub fn deserialize_vec(&self, b: &[u8]) -> Vec { + let p_nbits = 64 - (self.p - 1).leading_zeros() as usize; + transcode_from_bytes(b, p_nbits) + } } #[cfg(test)] mod tests { - use super::{primes, Modulus}; - use fhe_util::catch_unwind; - use itertools::{izip, Itertools}; - use proptest::collection::vec as prop_vec; - use proptest::prelude::{any, BoxedStrategy, Just, Strategy}; - use rand::{thread_rng, RngCore}; + use super::{primes, Modulus}; + use fhe_util::catch_unwind; + use itertools::{izip, Itertools}; + use proptest::collection::vec as prop_vec; + use proptest::prelude::{any, BoxedStrategy, Just, Strategy}; + use rand::{thread_rng, RngCore}; - // Utility functions for the proptests. + // Utility functions for the proptests. - fn valid_moduli() -> impl Strategy { - any::().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok()) - } + fn valid_moduli() -> impl Strategy { + any::().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok()) + } - fn vecs() -> BoxedStrategy<(Vec, Vec)> { - prop_vec(any::(), 1..100) - .prop_flat_map(|vec| { - let len = vec.len(); - (Just(vec), prop_vec(any::(), len)) - }) - .boxed() - } + fn vecs() -> BoxedStrategy<(Vec, Vec)> { + prop_vec(any::(), 1..100) + .prop_flat_map(|vec| { + let len = vec.len(); + (Just(vec), prop_vec(any::(), len)) + }) + .boxed() + } - proptest! { - #[test] - fn constructor(p: u64) { - // 63 and 64-bit integers do not work. - prop_assert!(Modulus::new(p | (1u64 << 62)).is_err()); - prop_assert!(Modulus::new(p | (1u64 << 63)).is_err()); + proptest! { + #[test] + fn constructor(p: u64) { + // 63 and 64-bit integers do not work. + prop_assert!(Modulus::new(p | (1u64 << 62)).is_err()); + prop_assert!(Modulus::new(p | (1u64 << 63)).is_err()); - // p = 0 & 1 do not work. - prop_assert!(Modulus::new(0u64).is_err()); - prop_assert!(Modulus::new(1u64).is_err()); + // p = 0 & 1 do not work. + prop_assert!(Modulus::new(0u64).is_err()); + prop_assert!(Modulus::new(1u64).is_err()); - // Otherwise, all moduli should work. - prop_assume!(p >> 2 >= 2); - let q = Modulus::new(p >> 2); - prop_assert!(q.is_ok()); - prop_assert_eq!(q.unwrap().modulus(), p >> 2); - } + // Otherwise, all moduli should work. + prop_assume!(p >> 2 >= 2); + let q = Modulus::new(p >> 2); + prop_assert!(q.is_ok()); + prop_assert_eq!(q.unwrap().modulus(), p >> 2); + } - #[test] - fn neg(p in valid_moduli(), mut a: u64, mut q: u64) { - a = p.reduce(a); - prop_assert_eq!(p.neg(a), (p.modulus() - a) % p.modulus()); - unsafe { prop_assert_eq!(p.neg_vt(a), (p.modulus() - a) % p.modulus()) } + #[test] + fn neg(p in valid_moduli(), mut a: u64, mut q: u64) { + a = p.reduce(a); + prop_assert_eq!(p.neg(a), (p.modulus() - a) % p.modulus()); + unsafe { prop_assert_eq!(p.neg_vt(a), (p.modulus() - a) % p.modulus()) } - q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p - prop_assert!(catch_unwind(|| p.neg(q)).is_err()); - } + q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p + prop_assert!(catch_unwind(|| p.neg(q)).is_err()); + } - #[test] - fn add(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { - a = p.reduce(a); - b = p.reduce(b); - prop_assert_eq!(p.add(a, b), (a + b) % p.modulus()); - unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % p.modulus()) } + #[test] + fn add(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { + a = p.reduce(a); + b = p.reduce(b); + prop_assert_eq!(p.add(a, b), (a + b) % p.modulus()); + unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % p.modulus()) } - q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p - prop_assert!(catch_unwind(|| p.add(q, a)).is_err()); - prop_assert!(catch_unwind(|| p.add(a, q)).is_err()); - } + q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p + prop_assert!(catch_unwind(|| p.add(q, a)).is_err()); + prop_assert!(catch_unwind(|| p.add(a, q)).is_err()); + } - #[test] - fn sub(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { - a = p.reduce(a); - b = p.reduce(b); - prop_assert_eq!(p.sub(a, b), (a + p.modulus() - b) % p.modulus()); - unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + p.modulus() - b) % p.modulus()) } + #[test] + fn sub(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { + a = p.reduce(a); + b = p.reduce(b); + prop_assert_eq!(p.sub(a, b), (a + p.modulus() - b) % p.modulus()); + unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + p.modulus() - b) % p.modulus()) } - q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p - prop_assert!(catch_unwind(|| p.sub(q, a)).is_err()); - prop_assert!(catch_unwind(|| p.sub(a, q)).is_err()); - } + q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p + prop_assert!(catch_unwind(|| p.sub(q, a)).is_err()); + prop_assert!(catch_unwind(|| p.sub(a, q)).is_err()); + } - #[test] - fn mul(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { - a = p.reduce(a); - b = p.reduce(b); - prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) } + #[test] + fn mul(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { + a = p.reduce(a); + b = p.reduce(b); + prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)); + unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) } - q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p - prop_assert!(catch_unwind(|| p.mul(q, a)).is_err()); - prop_assert!(catch_unwind(|| p.mul(a, q)).is_err()); - } + q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p + prop_assert!(catch_unwind(|| p.mul(q, a)).is_err()); + prop_assert!(catch_unwind(|| p.mul(a, q)).is_err()); + } - #[test] - fn mul_shoup(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { - a = p.reduce(a); - b = p.reduce(b); - q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p + #[test] + fn mul_shoup(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) { + a = p.reduce(a); + b = p.reduce(b); + q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus(); // q > p - // Compute shoup representation - let b_shoup = p.shoup(b); - prop_assert!(catch_unwind(|| p.shoup(q)).is_err()); + // Compute shoup representation + let b_shoup = p.shoup(b); + prop_assert!(catch_unwind(|| p.shoup(q)).is_err()); - // Check that the multiplication yields the expected result - prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) } + // Check that the multiplication yields the expected result + prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)); + unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) } - // Check that the multiplication with incorrect b_shoup panics in debug mode - prop_assert!(catch_unwind(|| p.mul_shoup(a, q, b_shoup)).is_err()); - prop_assume!(a != b); - prop_assert!(catch_unwind(|| p.mul_shoup(a, a, b_shoup)).is_err()); - } + // Check that the multiplication with incorrect b_shoup panics in debug mode + prop_assert!(catch_unwind(|| p.mul_shoup(a, q, b_shoup)).is_err()); + prop_assume!(a != b); + prop_assert!(catch_unwind(|| p.mul_shoup(a, a, b_shoup)).is_err()); + } - #[test] - fn reduce(p in valid_moduli(), a: u64) { - prop_assert_eq!(p.reduce(a), a % p.modulus()); - unsafe { prop_assert_eq!(p.reduce_vt(a), a % p.modulus()) } - if p.supports_opt { - prop_assert_eq!(p.reduce_opt(a), a % p.modulus()); - unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % p.modulus()) } - } - } + #[test] + fn reduce(p in valid_moduli(), a: u64) { + prop_assert_eq!(p.reduce(a), a % p.modulus()); + unsafe { prop_assert_eq!(p.reduce_vt(a), a % p.modulus()) } + if p.supports_opt { + prop_assert_eq!(p.reduce_opt(a), a % p.modulus()); + unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % p.modulus()) } + } + } - #[test] - fn lazy_reduce(p in valid_moduli(), a: u64) { - prop_assert!(p.lazy_reduce(a) < 2 * p.modulus()); - prop_assert_eq!(p.lazy_reduce(a) % p.modulus(), p.reduce(a)); - } + #[test] + fn lazy_reduce(p in valid_moduli(), a: u64) { + prop_assert!(p.lazy_reduce(a) < 2 * p.modulus()); + prop_assert_eq!(p.lazy_reduce(a) % p.modulus(), p.reduce(a)); + } - #[test] - fn reduce_i64(p in valid_moduli(), a: i64) { - let b = if a < 0 { p.neg(p.reduce(-a as u64)) } else { p.reduce(a as u64) }; - prop_assert_eq!(p.reduce_i64(a), b); - unsafe { prop_assert_eq!(p.reduce_i64_vt(a), b) } - } + #[test] + fn reduce_i64(p in valid_moduli(), a: i64) { + let b = if a < 0 { p.neg(p.reduce(-a as u64)) } else { p.reduce(a as u64) }; + prop_assert_eq!(p.reduce_i64(a), b); + unsafe { prop_assert_eq!(p.reduce_i64_vt(a), b) } + } - #[test] - fn reduce_u128(p in valid_moduli(), mut a: u128) { - prop_assert_eq!(p.reduce_u128(a) as u128, a % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (p.modulus() as u128)) } - if p.supports_opt { - let p_square = (p.modulus() as u128) * (p.modulus() as u128); - a %= p_square; - prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (p.modulus() as u128)) } - } - } + #[test] + fn reduce_u128(p in valid_moduli(), mut a: u128) { + prop_assert_eq!(p.reduce_u128(a) as u128, a % (p.modulus() as u128)); + unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (p.modulus() as u128)) } + if p.supports_opt { + let p_square = (p.modulus() as u128) * (p.modulus() as u128); + a %= p_square; + prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (p.modulus() as u128)); + unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (p.modulus() as u128)) } + } + } - #[test] - fn add_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { - p.reduce_vec(&mut a); - p.reduce_vec(&mut b); - let c = a.clone(); - p.add_vec(&mut a, &b); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); - a = c.clone(); - unsafe { p.add_vec_vt(&mut a, &b) } - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); - } + #[test] + fn add_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { + p.reduce_vec(&mut a); + p.reduce_vec(&mut b); + let c = a.clone(); + p.add_vec(&mut a, &b); + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); + a = c.clone(); + unsafe { p.add_vec_vt(&mut a, &b) } + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); + } - #[test] - fn sub_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { - p.reduce_vec(&mut a); - p.reduce_vec(&mut b); - let c = a.clone(); - p.sub_vec(&mut a, &b); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); - a = c.clone(); - unsafe { p.sub_vec_vt(&mut a, &b) } - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); - } + #[test] + fn sub_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { + p.reduce_vec(&mut a); + p.reduce_vec(&mut b); + let c = a.clone(); + p.sub_vec(&mut a, &b); + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); + a = c.clone(); + unsafe { p.sub_vec_vt(&mut a, &b) } + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); + } - #[test] - fn mul_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { - p.reduce_vec(&mut a); - p.reduce_vec(&mut b); - let c = a.clone(); - p.mul_vec(&mut a, &b); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); - a = c.clone(); - unsafe { p.mul_vec_vt(&mut a, &b); } - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); - } + #[test] + fn mul_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { + p.reduce_vec(&mut a); + p.reduce_vec(&mut b); + let c = a.clone(); + p.mul_vec(&mut a, &b); + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); + a = c.clone(); + unsafe { p.mul_vec_vt(&mut a, &b); } + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); + } - #[test] - fn scalar_mul_vec(p in valid_moduli(), mut a: Vec, mut b: u64) { - p.reduce_vec(&mut a); - b = p.reduce(b); - let c = a.clone(); + #[test] + fn scalar_mul_vec(p in valid_moduli(), mut a: Vec, mut b: u64) { + p.reduce_vec(&mut a); + b = p.reduce(b); + let c = a.clone(); - p.scalar_mul_vec(&mut a, b); - prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); + p.scalar_mul_vec(&mut a, b); + prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); - a = c.clone(); - unsafe { p.scalar_mul_vec_vt(&mut a, b) } - prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); - } + a = c.clone(); + unsafe { p.scalar_mul_vec_vt(&mut a, b) } + prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); + } - #[test] - fn mul_shoup_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { - p.reduce_vec(&mut a); - p.reduce_vec(&mut b); - let b_shoup = p.shoup_vec(&b); - let c = a.clone(); - p.mul_shoup_vec(&mut a, &b, &b_shoup); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); - a = c.clone(); - unsafe { p.mul_shoup_vec_vt(&mut a, &b, &b_shoup) } - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); - } + #[test] + fn mul_shoup_vec(p in valid_moduli(), (mut a, mut b) in vecs()) { + p.reduce_vec(&mut a); + p.reduce_vec(&mut b); + let b_shoup = p.shoup_vec(&b); + let c = a.clone(); + p.mul_shoup_vec(&mut a, &b, &b_shoup); + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); + a = c.clone(); + unsafe { p.mul_shoup_vec_vt(&mut a, &b, &b_shoup) } + prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); + } - #[test] - fn reduce_vec(p in valid_moduli(), a: Vec) { - let mut b = a.clone(); - p.reduce_vec(&mut b); - prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); + #[test] + fn reduce_vec(p in valid_moduli(), a: Vec) { + let mut b = a.clone(); + p.reduce_vec(&mut b); + prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); - b = a.clone(); - unsafe { p.reduce_vec_vt(&mut b) } - prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); - } + b = a.clone(); + unsafe { p.reduce_vec_vt(&mut b) } + prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); + } - #[test] - fn lazy_reduce_vec(p in valid_moduli(), a: Vec) { - let mut b = a.clone(); - p.lazy_reduce_vec(&mut b); - prop_assert!(b.iter().all(|bi| *bi < 2 * p.modulus())); - prop_assert!(izip!(a, b).all(|(ai, bi)| bi % p.modulus() == ai % p.modulus())); - } + #[test] + fn lazy_reduce_vec(p in valid_moduli(), a: Vec) { + let mut b = a.clone(); + p.lazy_reduce_vec(&mut b); + prop_assert!(b.iter().all(|bi| *bi < 2 * p.modulus())); + prop_assert!(izip!(a, b).all(|(ai, bi)| bi % p.modulus() == ai % p.modulus())); + } - #[test] - fn reduce_vec_new(p in valid_moduli(), a: Vec) { - let b = p.reduce_vec_new(&a); - prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); - prop_assert_eq!(p.reduce_vec_new(&a), unsafe { p.reduce_vec_new_vt(&a) }); - } + #[test] + fn reduce_vec_new(p in valid_moduli(), a: Vec) { + let b = p.reduce_vec_new(&a); + prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); + prop_assert_eq!(p.reduce_vec_new(&a), unsafe { p.reduce_vec_new_vt(&a) }); + } - #[test] - fn reduce_vec_i64(p in valid_moduli(), a: Vec) { - let b = p.reduce_vec_i64(&a); - prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec()); - let b = unsafe { p.reduce_vec_i64_vt(&a) }; - prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec()); - } + #[test] + fn reduce_vec_i64(p in valid_moduli(), a: Vec) { + let b = p.reduce_vec_i64(&a); + prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec()); + let b = unsafe { p.reduce_vec_i64_vt(&a) }; + prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec()); + } - #[test] - fn neg_vec(p in valid_moduli(), mut a: Vec) { - p.reduce_vec(&mut a); - let mut b = a.clone(); - p.neg_vec(&mut b); - prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec()); - b = a.clone(); - unsafe { p.neg_vec_vt(&mut b); } - prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec()); - } + #[test] + fn neg_vec(p in valid_moduli(), mut a: Vec) { + p.reduce_vec(&mut a); + let mut b = a.clone(); + p.neg_vec(&mut b); + prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec()); + b = a.clone(); + unsafe { p.neg_vec_vt(&mut b); } + prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec()); + } - #[test] - fn random_vec(p in valid_moduli(), size in 1..1000usize) { - let mut rng = thread_rng(); + #[test] + fn random_vec(p in valid_moduli(), size in 1..1000usize) { + let mut rng = thread_rng(); - let v = p.random_vec(size, &mut rng); - prop_assert_eq!(v.len(), size); + let v = p.random_vec(size, &mut rng); + prop_assert_eq!(v.len(), size); - let w = p.random_vec(size, &mut rng); - prop_assert_eq!(w.len(), size); + let w = p.random_vec(size, &mut rng); + prop_assert_eq!(w.len(), size); - if p.modulus().leading_zeros() <= 30 { - prop_assert_ne!(v, w); // This will hold with probability at least 2^(-30) - } - } + if p.modulus().leading_zeros() <= 30 { + prop_assert_ne!(v, w); // This will hold with probability at least 2^(-30) + } + } - #[test] - fn serialize(p in valid_moduli(), mut a in prop_vec(any::(), 8)) { - p.reduce_vec(&mut a); - let b = p.serialize_vec(&a); - let c = p.deserialize_vec(&b); - prop_assert_eq!(a, c); - } - } + #[test] + fn serialize(p in valid_moduli(), mut a in prop_vec(any::(), 8)) { + p.reduce_vec(&mut a); + let b = p.serialize_vec(&a); + let c = p.deserialize_vec(&b); + prop_assert_eq!(a, c); + } + } - // TODO: Make a proptest. - #[test] - fn mul_opt() { - let ntests = 100; - let mut rng = rand::thread_rng(); + // TODO: Make a proptest. + #[test] + fn mul_opt() { + let ntests = 100; + let mut rng = rand::thread_rng(); - #[allow(clippy::single_element_loop)] - for p in [4611686018326724609] { - let q = Modulus::new(p).unwrap(); - assert!(primes::supports_opt(p)); + #[allow(clippy::single_element_loop)] + for p in [4611686018326724609] { + let q = Modulus::new(p).unwrap(); + assert!(primes::supports_opt(p)); - assert_eq!(q.mul_opt(0, 1), 0); - assert_eq!(q.mul_opt(1, 1), 1); - assert_eq!(q.mul_opt(2 % p, 3 % p), 6 % p); - assert_eq!(q.mul_opt(p - 1, 1), p - 1); - assert_eq!(q.mul_opt(p - 1, 2 % p), p - 2); + assert_eq!(q.mul_opt(0, 1), 0); + assert_eq!(q.mul_opt(1, 1), 1); + assert_eq!(q.mul_opt(2 % p, 3 % p), 6 % p); + assert_eq!(q.mul_opt(p - 1, 1), p - 1); + assert_eq!(q.mul_opt(p - 1, 2 % p), p - 2); - assert!(catch_unwind(|| q.mul_opt(p, 1)).is_err()); - assert!(catch_unwind(|| q.mul_opt(p << 1, 1)).is_err()); - assert!(catch_unwind(|| q.mul_opt(0, p)).is_err()); - assert!(catch_unwind(|| q.mul_opt(0, p << 1)).is_err()); + assert!(catch_unwind(|| q.mul_opt(p, 1)).is_err()); + assert!(catch_unwind(|| q.mul_opt(p << 1, 1)).is_err()); + assert!(catch_unwind(|| q.mul_opt(0, p)).is_err()); + assert!(catch_unwind(|| q.mul_opt(0, p << 1)).is_err()); - for _ in 0..ntests { - let a = rng.next_u64() % p; - let b = rng.next_u64() % p; - assert_eq!( - q.mul_opt(a, b), - (((a as u128) * (b as u128)) % (p as u128)) as u64 - ); - } - } - } + for _ in 0..ntests { + let a = rng.next_u64() % p; + let b = rng.next_u64() % p; + assert_eq!( + q.mul_opt(a, b), + (((a as u128) * (b as u128)) % (p as u128)) as u64 + ); + } + } + } - // TODO: Make a proptest. - #[test] - fn pow() { - let ntests = 10; - let mut rng = rand::thread_rng(); + // TODO: Make a proptest. + #[test] + fn pow() { + let ntests = 10; + let mut rng = rand::thread_rng(); - for p in [2u64, 3, 17, 1987, 4611686018326724609] { - let q = Modulus::new(p).unwrap(); + for p in [2u64, 3, 17, 1987, 4611686018326724609] { + let q = Modulus::new(p).unwrap(); - assert_eq!(q.pow(p - 1, 0), 1); - assert_eq!(q.pow(p - 1, 1), p - 1); - assert_eq!(q.pow(p - 1, 2 % p), 1); - assert_eq!(q.pow(1, p - 2), 1); - assert_eq!(q.pow(1, p - 1), 1); + assert_eq!(q.pow(p - 1, 0), 1); + assert_eq!(q.pow(p - 1, 1), p - 1); + assert_eq!(q.pow(p - 1, 2 % p), 1); + assert_eq!(q.pow(1, p - 2), 1); + assert_eq!(q.pow(1, p - 1), 1); - assert!(catch_unwind(|| q.pow(p, 1)).is_err()); - assert!(catch_unwind(|| q.pow(p << 1, 1)).is_err()); - assert!(catch_unwind(|| q.pow(0, p)).is_err()); - assert!(catch_unwind(|| q.pow(0, p << 1)).is_err()); + assert!(catch_unwind(|| q.pow(p, 1)).is_err()); + assert!(catch_unwind(|| q.pow(p << 1, 1)).is_err()); + assert!(catch_unwind(|| q.pow(0, p)).is_err()); + assert!(catch_unwind(|| q.pow(0, p << 1)).is_err()); - for _ in 0..ntests { - let a = rng.next_u64() % p; - let b = (rng.next_u64() % p) % 1000; - let mut c = b; - let mut r = 1; - while c > 0 { - r = q.mul(r, a); - c -= 1; - } - assert_eq!(q.pow(a, b), r); - } - } - } + for _ in 0..ntests { + let a = rng.next_u64() % p; + let b = (rng.next_u64() % p) % 1000; + let mut c = b; + let mut r = 1; + while c > 0 { + r = q.mul(r, a); + c -= 1; + } + assert_eq!(q.pow(a, b), r); + } + } + } - // TODO: Make a proptest. - #[test] - fn inv() { - let ntests = 100; - let mut rng = rand::thread_rng(); + // TODO: Make a proptest. + #[test] + fn inv() { + let ntests = 100; + let mut rng = rand::thread_rng(); - for p in [2u64, 3, 17, 1987, 4611686018326724609] { - let q = Modulus::new(p).unwrap(); + for p in [2u64, 3, 17, 1987, 4611686018326724609] { + let q = Modulus::new(p).unwrap(); - assert!(q.inv(0).is_none()); - assert_eq!(q.inv(1).unwrap(), 1); - assert_eq!(q.inv(p - 1).unwrap(), p - 1); + assert!(q.inv(0).is_none()); + assert_eq!(q.inv(1).unwrap(), 1); + assert_eq!(q.inv(p - 1).unwrap(), p - 1); - assert!(catch_unwind(|| q.inv(p)).is_err()); - assert!(catch_unwind(|| q.inv(p << 1)).is_err()); + assert!(catch_unwind(|| q.inv(p)).is_err()); + assert!(catch_unwind(|| q.inv(p << 1)).is_err()); - for _ in 0..ntests { - let a = rng.next_u64() % p; - let b = q.inv(a); + for _ in 0..ntests { + let a = rng.next_u64() % p; + let b = q.inv(a); - if a == 0 { - assert!(b.is_none()) - } else { - assert!(b.is_some()); - assert_eq!(q.mul(a, b.unwrap()), 1) - } - } - } - } + if a == 0 { + assert!(b.is_none()) + } else { + assert!(b.is_some()); + assert_eq!(q.mul(a, b.unwrap()), 1) + } + } + } + } } diff --git a/crates/fhe-math/src/zq/ntt.rs b/crates/fhe-math/src/zq/ntt.rs index 8c01c0b..5befc87 100644 --- a/crates/fhe-math/src/zq/ntt.rs +++ b/crates/fhe-math/src/zq/ntt.rs @@ -2,489 +2,483 @@ use super::Modulus; use fhe_util::is_prime; +use itertools::Itertools; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; +use std::iter::successors; /// Returns whether a modulus p is prime and supports the Number Theoretic /// Transform of size n. /// /// Aborts if n is not a power of 2 that is >= 8. pub fn supports_ntt(p: u64, n: usize) -> bool { - assert!(n >= 8 && n.is_power_of_two()); + assert!(n >= 8 && n.is_power_of_two()); - p % ((n as u64) << 1) == 1 && is_prime(p) + p % ((n as u64) << 1) == 1 && is_prime(p) } /// Number-Theoretic Transform operator. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NttOperator { - p: Modulus, - p_twice: u64, - size: usize, - omegas: Box<[u64]>, - omegas_shoup: Box<[u64]>, - omegas_inv: Box<[u64]>, - zetas_inv: Box<[u64]>, - zetas_inv_shoup: Box<[u64]>, - size_inv: u64, - size_inv_shoup: u64, + p: Modulus, + p_twice: u64, + size: usize, + omegas: Box<[u64]>, + omegas_shoup: Box<[u64]>, + zetas_inv: Box<[u64]>, + zetas_inv_shoup: Box<[u64]>, + size_inv: u64, + size_inv_shoup: u64, } impl NttOperator { - /// Create an NTT operator given a modulus for a specific size. - /// - /// Aborts if the size is not a power of 2 that is >= 8 in debug mode. - /// Returns None if the modulus does not support the NTT for this specific - /// size. - pub fn new(p: &Modulus, size: usize) -> Option { - if !supports_ntt(p.p, size) { - None - } else { - let omega = Self::primitive_root(size, p); - let omega_inv = p.inv(omega).unwrap(); + /// Create an NTT operator given a modulus for a specific size. + /// + /// Aborts if the size is not a power of 2 that is >= 8 in debug mode. + /// Returns None if the modulus does not support the NTT for this specific + /// size. + pub fn new(p: &Modulus, size: usize) -> Option { + if !supports_ntt(p.p, size) { + None + } else { + let size_inv = p.inv(size as u64)?; - let mut exp = 1u64; - let mut exp_inv = 1u64; - let mut powers = Vec::with_capacity(size + 1); - let mut powers_inv = Vec::with_capacity(size + 1); - for _ in 0..size + 1 { - powers.push(exp); - powers_inv.push(exp_inv); - exp = p.mul(exp, omega); - exp_inv = p.mul(exp_inv, omega_inv); - } + let omega = Self::primitive_root(size, p); + let omega_inv = p.inv(omega)?; - let mut omegas = Vec::with_capacity(size); - let mut omegas_inv = Vec::with_capacity(size); - let mut zetas_inv = Vec::with_capacity(size); - for i in 0..size { - let j = i.reverse_bits() >> (size.leading_zeros() + 1); - omegas.push(powers[j]); - omegas_inv.push(powers_inv[j]); - zetas_inv.push(powers_inv[j + 1]); - } + let powers = successors(Some(1u64), |n| Some(p.mul(*n, omega))) + .take(size) + .collect_vec(); + let powers_inv = successors(Some(omega_inv), |n| Some(p.mul(*n, omega_inv))) + .take(size) + .collect_vec(); - let size_inv = p.inv(size as u64).unwrap(); + let mut omegas = Vec::with_capacity(size); + let mut zetas_inv = Vec::with_capacity(size); + for i in 0..size { + let j = i.reverse_bits() >> (size.leading_zeros() + 1); + omegas.push(powers[j]); + zetas_inv.push(powers_inv[j]); + } - let omegas_shoup = p.shoup_vec(&omegas); - let zetas_inv_shoup = p.shoup_vec(&zetas_inv); + let omegas_shoup = p.shoup_vec(&omegas); + let zetas_inv_shoup = p.shoup_vec(&zetas_inv); - Some(Self { - p: p.clone(), - p_twice: p.p * 2, - size, - omegas: omegas.into_boxed_slice(), - omegas_shoup: omegas_shoup.into_boxed_slice(), - omegas_inv: omegas_inv.into_boxed_slice(), - zetas_inv: zetas_inv.into_boxed_slice(), - zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(), - size_inv, - size_inv_shoup: p.shoup(size_inv), - }) - } - } + Some(Self { + p: p.clone(), + p_twice: p.p * 2, + size, + omegas: omegas.into_boxed_slice(), + omegas_shoup: omegas_shoup.into_boxed_slice(), + zetas_inv: zetas_inv.into_boxed_slice(), + zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(), + size_inv, + size_inv_shoup: p.shoup(size_inv), + }) + } + } - /// Compute the forward NTT in place. - /// Aborts if a is not of the size handled by the operator. - pub fn forward(&self, a: &mut [u64]) { - debug_assert_eq!(a.len(), self.size); + /// Compute the forward NTT in place. + /// Aborts if a is not of the size handled by the operator. + pub fn forward(&self, a: &mut [u64]) { + debug_assert_eq!(a.len(), self.size); - let n = self.size; - let a_ptr = a.as_mut_ptr(); + let n = self.size; + let a_ptr = a.as_mut_ptr(); - let mut l = n >> 1; - let mut m = 1; - let mut k = 1; - while l > 0 { - for i in 0..m { - unsafe { - let omega = *self.omegas.get_unchecked(k); - let omega_shoup = *self.omegas_shoup.get_unchecked(k); - k += 1; + let mut l = n >> 1; + let mut m = 1; + let mut k = 1; + while l > 0 { + for i in 0..m { + unsafe { + let omega = *self.omegas.get_unchecked(k); + let omega_shoup = *self.omegas_shoup.get_unchecked(k); + k += 1; - let s = 2 * i * l; - match l { - 1 => { - // The last level should reduce the output - let uj = &mut *a_ptr.add(s); - let ujl = &mut *a_ptr.add(s + l); - self.butterfly(uj, ujl, omega, omega_shoup); - *uj = self.reduce3(*uj); - *ujl = self.reduce3(*ujl); - } - _ => { - for j in s..(s + l) { - self.butterfly( - &mut *a_ptr.add(j), - &mut *a_ptr.add(j + l), - omega, - omega_shoup, - ); - } - } - } - } - } - l >>= 1; - m <<= 1; - } - } + let s = 2 * i * l; + match l { + 1 => { + // The last level should reduce the output + let uj = &mut *a_ptr.add(s); + let ujl = &mut *a_ptr.add(s + l); + self.butterfly(uj, ujl, omega, omega_shoup); + *uj = self.reduce3(*uj); + *ujl = self.reduce3(*ujl); + } + _ => { + for j in s..(s + l) { + self.butterfly( + &mut *a_ptr.add(j), + &mut *a_ptr.add(j + l), + omega, + omega_shoup, + ); + } + } + } + } + } + l >>= 1; + m <<= 1; + } + } - /// Compute the backward NTT in place. - /// Aborts if a is not of the size handled by the operator. - pub fn backward(&self, a: &mut [u64]) { - debug_assert_eq!(a.len(), self.size); + /// Compute the backward NTT in place. + /// Aborts if a is not of the size handled by the operator. + pub fn backward(&self, a: &mut [u64]) { + debug_assert_eq!(a.len(), self.size); - let a_ptr = a.as_mut_ptr(); + let a_ptr = a.as_mut_ptr(); - let mut k = 0; - let mut m = self.size >> 1; - let mut l = 1; - while m > 0 { - for i in 0..m { - let s = 2 * i * l; - unsafe { - let zeta_inv = *self.zetas_inv.get_unchecked(k); - let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k); - k += 1; - match l { - 1 => { - self.inv_butterfly( - &mut *a_ptr.add(s), - &mut *a_ptr.add(s + l), - zeta_inv, - zeta_inv_shoup, - ); - } - _ => { - for j in s..(s + l) { - self.inv_butterfly( - &mut *a_ptr.add(j), - &mut *a_ptr.add(j + l), - zeta_inv, - zeta_inv_shoup, - ); - } - } - } - } - } - l <<= 1; - m >>= 1; - } + let mut k = 0; + let mut m = self.size >> 1; + let mut l = 1; + while m > 0 { + for i in 0..m { + let s = 2 * i * l; + unsafe { + let zeta_inv = *self.zetas_inv.get_unchecked(k); + let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k); + k += 1; + match l { + 1 => { + self.inv_butterfly( + &mut *a_ptr.add(s), + &mut *a_ptr.add(s + l), + zeta_inv, + zeta_inv_shoup, + ); + } + _ => { + for j in s..(s + l) { + self.inv_butterfly( + &mut *a_ptr.add(j), + &mut *a_ptr.add(j + l), + zeta_inv, + zeta_inv_shoup, + ); + } + } + } + } + } + l <<= 1; + m >>= 1; + } - a.iter_mut() - .for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup)); - } + a.iter_mut() + .for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup)); + } - /// Compute the forward NTT in place in variable time in a lazily fashion. - /// This means that the output coefficients may be up to 4 times the - /// modulus. - /// - /// # Safety - /// This function assumes that a_ptr points to at least `size` elements. - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) { - let mut l = self.size >> 1; - let mut m = 1; - let mut k = 1; - while l > 0 { - for i in 0..m { - let omega = *self.omegas.get_unchecked(k); - let omega_shoup = *self.omegas_shoup.get_unchecked(k); - k += 1; + /// Compute the forward NTT in place in variable time in a lazily fashion. + /// This means that the output coefficients may be up to 4 times the + /// modulus. + /// + /// # Safety + /// This function assumes that a_ptr points to at least `size` elements. + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) { + let mut l = self.size >> 1; + let mut m = 1; + let mut k = 1; + while l > 0 { + for i in 0..m { + let omega = *self.omegas.get_unchecked(k); + let omega_shoup = *self.omegas_shoup.get_unchecked(k); + k += 1; - let s = 2 * i * l; - match l { - 1 => { - self.butterfly_vt( - &mut *a_ptr.add(s), - &mut *a_ptr.add(s + l), - omega, - omega_shoup, - ); - } - _ => { - for j in s..(s + l) { - self.butterfly_vt( - &mut *a_ptr.add(j), - &mut *a_ptr.add(j + l), - omega, - omega_shoup, - ); - } - } - } - } - l >>= 1; - m <<= 1; - } - } + let s = 2 * i * l; + match l { + 1 => { + self.butterfly_vt( + &mut *a_ptr.add(s), + &mut *a_ptr.add(s + l), + omega, + omega_shoup, + ); + } + _ => { + for j in s..(s + l) { + self.butterfly_vt( + &mut *a_ptr.add(j), + &mut *a_ptr.add(j + l), + omega, + omega_shoup, + ); + } + } + } + } + l >>= 1; + m <<= 1; + } + } - /// Compute the forward NTT in place in variable time. - /// - /// # Safety - /// This function assumes that a_ptr points to at least `size` elements. - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub unsafe fn forward_vt(&self, a_ptr: *mut u64) { - self.forward_vt_lazy(a_ptr); - for i in 0..self.size { - *a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i)) - } - } + /// Compute the forward NTT in place in variable time. + /// + /// # Safety + /// This function assumes that a_ptr points to at least `size` elements. + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub unsafe fn forward_vt(&self, a_ptr: *mut u64) { + self.forward_vt_lazy(a_ptr); + for i in 0..self.size { + *a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i)) + } + } - /// Compute the backward NTT in place in variable time. - /// - /// # Safety - /// This function assumes that a_ptr points to at least `size` elements. - /// This function is not constant time and its timing may reveal information - /// about the value being reduced. - pub unsafe fn backward_vt(&self, a_ptr: *mut u64) { - let mut k = 0; - let mut m = self.size >> 1; - let mut l = 1; - while m > 0 { - for i in 0..m { - let s = 2 * i * l; - let zeta_inv = *self.zetas_inv.get_unchecked(k); - let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k); - k += 1; - match l { - 1 => { - self.inv_butterfly_vt( - &mut *a_ptr.add(s), - &mut *a_ptr.add(s + l), - zeta_inv, - zeta_inv_shoup, - ); - } - _ => { - for j in s..(s + l) { - self.inv_butterfly_vt( - &mut *a_ptr.add(j), - &mut *a_ptr.add(j + l), - zeta_inv, - zeta_inv_shoup, - ); - } - } - } - } - l <<= 1; - m >>= 1; - } + /// Compute the backward NTT in place in variable time. + /// + /// # Safety + /// This function assumes that a_ptr points to at least `size` elements. + /// This function is not constant time and its timing may reveal information + /// about the value being reduced. + pub unsafe fn backward_vt(&self, a_ptr: *mut u64) { + let mut k = 0; + let mut m = self.size >> 1; + let mut l = 1; + while m > 0 { + for i in 0..m { + let s = 2 * i * l; + let zeta_inv = *self.zetas_inv.get_unchecked(k); + let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k); + k += 1; + match l { + 1 => { + self.inv_butterfly_vt( + &mut *a_ptr.add(s), + &mut *a_ptr.add(s + l), + zeta_inv, + zeta_inv_shoup, + ); + } + _ => { + for j in s..(s + l) { + self.inv_butterfly_vt( + &mut *a_ptr.add(j), + &mut *a_ptr.add(j + l), + zeta_inv, + zeta_inv_shoup, + ); + } + } + } + } + l <<= 1; + m >>= 1; + } - for i in 0..self.size as isize { - *a_ptr.offset(i) = - self.p - .mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup) - } - } + for i in 0..self.size as isize { + *a_ptr.offset(i) = + self.p + .mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup) + } + } - /// Reduce a modulo p. - /// - /// Aborts if a >= 4 * p. - const fn reduce3(&self, a: u64) -> u64 { - debug_assert!(a < 4 * self.p.p); + /// Reduce a modulo p. + /// + /// Aborts if a >= 4 * p. + const fn reduce3(&self, a: u64) -> u64 { + debug_assert!(a < 4 * self.p.p); - let y = Modulus::reduce1(a, 2 * self.p.p); - Modulus::reduce1(y, self.p.p) - } + let y = Modulus::reduce1(a, 2 * self.p.p); + Modulus::reduce1(y, self.p.p) + } - /// Reduce a modulo p in variable time. - /// - /// Aborts if a >= 4 * p. - const unsafe fn reduce3_vt(&self, a: u64) -> u64 { - debug_assert!(a < 4 * self.p.p); + /// Reduce a modulo p in variable time. + /// + /// Aborts if a >= 4 * p. + const unsafe fn reduce3_vt(&self, a: u64) -> u64 { + debug_assert!(a < 4 * self.p.p); - let y = Modulus::reduce1_vt(a, 2 * self.p.p); - Modulus::reduce1_vt(y, self.p.p) - } + let y = Modulus::reduce1_vt(a, 2 * self.p.p); + Modulus::reduce1_vt(y, self.p.p) + } - /// NTT Butterfly. - fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) { - debug_assert!(*x < 4 * self.p.p); - debug_assert!(*y < 4 * self.p.p); - debug_assert!(w < self.p.p); - debug_assert_eq!(self.p.shoup(w), w_shoup); + /// NTT Butterfly. + fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) { + debug_assert!(*x < 4 * self.p.p); + debug_assert!(*y < 4 * self.p.p); + debug_assert!(w < self.p.p); + debug_assert_eq!(self.p.shoup(w), w_shoup); - *x = Modulus::reduce1(*x, self.p_twice); - let t = self.p.lazy_mul_shoup(*y, w, w_shoup); - *y = *x + self.p_twice - t; - *x += t; + *x = Modulus::reduce1(*x, self.p_twice); + let t = self.p.lazy_mul_shoup(*y, w, w_shoup); + *y = *x + self.p_twice - t; + *x += t; - debug_assert!(*x < 4 * self.p.p); - debug_assert!(*y < 4 * self.p.p); - } + debug_assert!(*x < 4 * self.p.p); + debug_assert!(*y < 4 * self.p.p); + } - /// NTT Butterfly in variable time. - unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) { - debug_assert!(*x < 4 * self.p.p); - debug_assert!(*y < 4 * self.p.p); - debug_assert!(w < self.p.p); - debug_assert_eq!(self.p.shoup(w), w_shoup); + /// NTT Butterfly in variable time. + unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) { + debug_assert!(*x < 4 * self.p.p); + debug_assert!(*y < 4 * self.p.p); + debug_assert!(w < self.p.p); + debug_assert_eq!(self.p.shoup(w), w_shoup); - *x = Modulus::reduce1_vt(*x, self.p_twice); - let t = self.p.lazy_mul_shoup(*y, w, w_shoup); - *y = *x + self.p_twice - t; - *x += t; + *x = Modulus::reduce1_vt(*x, self.p_twice); + let t = self.p.lazy_mul_shoup(*y, w, w_shoup); + *y = *x + self.p_twice - t; + *x += t; - debug_assert!(*x < 4 * self.p.p); - debug_assert!(*y < 4 * self.p.p); - } + debug_assert!(*x < 4 * self.p.p); + debug_assert!(*y < 4 * self.p.p); + } - /// Inverse NTT butterfly. - fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) { - debug_assert!(*x < self.p_twice); - debug_assert!(*y < self.p_twice); - debug_assert!(z < self.p.p); - debug_assert_eq!(self.p.shoup(z), z_shoup); + /// Inverse NTT butterfly. + fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) { + debug_assert!(*x < self.p_twice); + debug_assert!(*y < self.p_twice); + debug_assert!(z < self.p.p); + debug_assert_eq!(self.p.shoup(z), z_shoup); - let t = *x; - *x = Modulus::reduce1(*y + t, self.p_twice); - *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup); + let t = *x; + *x = Modulus::reduce1(*y + t, self.p_twice); + *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup); - debug_assert!(*x < self.p_twice); - debug_assert!(*y < self.p_twice); - } + debug_assert!(*x < self.p_twice); + debug_assert!(*y < self.p_twice); + } - /// Inverse NTT butterfly in variable time - unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) { - debug_assert!(*x < self.p_twice); - debug_assert!(*y < self.p_twice); - debug_assert!(z < self.p.p); - debug_assert_eq!(self.p.shoup(z), z_shoup); + /// Inverse NTT butterfly in variable time + unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) { + debug_assert!(*x < self.p_twice); + debug_assert!(*y < self.p_twice); + debug_assert!(z < self.p.p); + debug_assert_eq!(self.p.shoup(z), z_shoup); - let t = *x; - *x = Modulus::reduce1_vt(*y + t, self.p_twice); - *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup); + let t = *x; + *x = Modulus::reduce1_vt(*y + t, self.p_twice); + *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup); - debug_assert!(*x < self.p_twice); - debug_assert!(*y < self.p_twice); - } + debug_assert!(*x < self.p_twice); + debug_assert!(*y < self.p_twice); + } - /// Returns a 2n-th primitive root modulo p. - /// - /// Aborts if p is not prime or n is not a power of 2 that is >= 8. - fn primitive_root(n: usize, p: &Modulus) -> u64 { - debug_assert!(supports_ntt(p.p, n)); + /// Returns a 2n-th primitive root modulo p. + /// + /// Aborts if p is not prime or n is not a power of 2 that is >= 8. + fn primitive_root(n: usize, p: &Modulus) -> u64 { + debug_assert!(supports_ntt(p.p, n)); - let lambda = (p.p - 1) / (2 * n as u64); + let lambda = (p.p - 1) / (2 * n as u64); - let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0); - for _ in 0..100 { - let mut root = rng.gen_range(0..p.p); - root = p.pow(root, lambda); - if Self::is_primitive_root(root, 2 * n, p) { - return root; - } - } + let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0); + for _ in 0..100 { + let mut root = rng.gen_range(0..p.p); + root = p.pow(root, lambda); + if Self::is_primitive_root(root, 2 * n, p) { + return root; + } + } - debug_assert!(false, "Couldn't find primitive root"); - 0 - } + debug_assert!(false, "Couldn't find primitive root"); + 0 + } - /// Returns whether a is a n-th primitive root of unity. - /// - /// Aborts if a >= p in debug mode. - fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool { - debug_assert!(a < p.p); - debug_assert!(supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here. + /// Returns whether a is a n-th primitive root of unity. + /// + /// Aborts if a >= p in debug mode. + fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool { + debug_assert!(a < p.p); + debug_assert!(supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here. - // A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p - // for all prime p dividing n. - (p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1) - } + // A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p + // for all prime p dividing n. + (p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1) + } } #[cfg(test)] mod tests { - use rand::thread_rng; + use rand::thread_rng; - use super::{supports_ntt, NttOperator}; - use crate::zq::Modulus; + use super::{supports_ntt, NttOperator}; + use crate::zq::Modulus; - #[test] - fn constructor() { - for size in [8, 1024] { - for p in [1153, 4611686018326724609] { - let q = Modulus::new(p).unwrap(); - let supports_ntt = supports_ntt(p, size); + #[test] + fn constructor() { + for size in [8, 1024] { + for p in [1153, 4611686018326724609] { + let q = Modulus::new(p).unwrap(); + let supports_ntt = supports_ntt(p, size); - let op = NttOperator::new(&q, size); + let op = NttOperator::new(&q, size); - if supports_ntt { - assert!(op.is_some()); - } else { - assert!(op.is_none()); - } - } - } - } + if supports_ntt { + assert!(op.is_some()); + } else { + assert!(op.is_none()); + } + } + } + } - #[test] - fn bijection() { - let ntests = 100; - let mut rng = thread_rng(); + #[test] + fn bijection() { + let ntests = 100; + let mut rng = thread_rng(); - for size in [8, 1024] { - for p in [1153, 4611686018326724609] { - let q = Modulus::new(p).unwrap(); + for size in [8, 1024] { + for p in [1153, 4611686018326724609] { + let q = Modulus::new(p).unwrap(); - if supports_ntt(p, size) { - let op = NttOperator::new(&q, size).unwrap(); + if supports_ntt(p, size) { + let op = NttOperator::new(&q, size).unwrap(); - for _ in 0..ntests { - let mut a = q.random_vec(size, &mut rng); - let a_clone = a.clone(); - let mut b = a.clone(); + for _ in 0..ntests { + let mut a = q.random_vec(size, &mut rng); + let a_clone = a.clone(); + let mut b = a.clone(); - op.forward(&mut a); - assert_ne!(a, a_clone); + op.forward(&mut a); + assert_ne!(a, a_clone); - unsafe { op.forward_vt(b.as_mut_ptr()) } - assert_eq!(a, b); + unsafe { op.forward_vt(b.as_mut_ptr()) } + assert_eq!(a, b); - op.backward(&mut a); - assert_eq!(a, a_clone); + op.backward(&mut a); + assert_eq!(a, a_clone); - unsafe { op.backward_vt(b.as_mut_ptr()) } - assert_eq!(a, b); - } - } - } - } - } + unsafe { op.backward_vt(b.as_mut_ptr()) } + assert_eq!(a, b); + } + } + } + } + } - #[test] - fn forward_lazy() { - let ntests = 100; - let mut rng = thread_rng(); + #[test] + fn forward_lazy() { + let ntests = 100; + let mut rng = thread_rng(); - for size in [8, 1024] { - for p in [1153, 4611686018326724609] { - let q = Modulus::new(p).unwrap(); + for size in [8, 1024] { + for p in [1153, 4611686018326724609] { + let q = Modulus::new(p).unwrap(); - if supports_ntt(p, size) { - let op = NttOperator::new(&q, size).unwrap(); + if supports_ntt(p, size) { + let op = NttOperator::new(&q, size).unwrap(); - for _ in 0..ntests { - let mut a = q.random_vec(size, &mut rng); - let mut a_lazy = a.clone(); + for _ in 0..ntests { + let mut a = q.random_vec(size, &mut rng); + let mut a_lazy = a.clone(); - op.forward(&mut a); + op.forward(&mut a); - unsafe { - op.forward_vt_lazy(a_lazy.as_mut_ptr()); - q.reduce_vec(&mut a_lazy); - } + unsafe { + op.forward_vt_lazy(a_lazy.as_mut_ptr()); + q.reduce_vec(&mut a_lazy); + } - assert_eq!(a, a_lazy); - } - } - } - } - } + assert_eq!(a, a_lazy); + } + } + } + } + } } diff --git a/crates/fhe-math/src/zq/primes.rs b/crates/fhe-math/src/zq/primes.rs index 87dc9de..7e81c3b 100644 --- a/crates/fhe-math/src/zq/primes.rs +++ b/crates/fhe-math/src/zq/primes.rs @@ -7,113 +7,113 @@ use num_bigint::BigUint; /// These optimized operations are possible when the modulus verifies /// Equation (1) of . pub fn supports_opt(p: u64) -> bool { - if p.leading_zeros() < 1 { - return false; - } + if p.leading_zeros() < 1 { + return false; + } - // Let's multiply the inequality by (2^s0+1)*2^(3s0): - // we want to output true when - // (2^(3s0)+1) * 2^64 < 2^(3s0) * (2^s0+1) * p - let mut middle = BigUint::from(1u64) << (3 * p.leading_zeros() as usize); - let left_side = (&middle + 1u64) << 64; - middle *= (1u64 << p.leading_zeros()) + 1; - middle *= p; + // Let's multiply the inequality by (2^s0+1)*2^(3s0): + // we want to output true when + // (2^(3s0)+1) * 2^64 < 2^(3s0) * (2^s0+1) * p + let mut middle = BigUint::from(1u64) << (3 * p.leading_zeros() as usize); + let left_side = (&middle + 1u64) << 64; + middle *= (1u64 << p.leading_zeros()) + 1; + middle *= p; - left_side < middle + left_side < middle } /// Generate a `num_bits`-bit prime, congruent to 1 mod `modulo`, strictly /// smaller than `upper_bound`. Note that `num_bits` must belong to (10..=62), /// and upper_bound must be <= 1 << num_bits. pub fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option { - if !(10..=62).contains(&num_bits) { - None - } else { - debug_assert!( - (1u64 << num_bits) >= upper_bound, - "upper_bound larger than number of bits" - ); + if !(10..=62).contains(&num_bits) { + None + } else { + debug_assert!( + (1u64 << num_bits) >= upper_bound, + "upper_bound larger than number of bits" + ); - let leading_zeros = (64 - num_bits) as u32; + let leading_zeros = (64 - num_bits) as u32; - let mut tentative_prime = upper_bound - 1; - while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros { - tentative_prime -= 1 - } + let mut tentative_prime = upper_bound - 1; + while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros { + tentative_prime -= 1 + } - while tentative_prime.leading_zeros() == leading_zeros - && !is_prime(tentative_prime) - && tentative_prime >= modulo - { - tentative_prime -= modulo - } + while tentative_prime.leading_zeros() == leading_zeros + && !is_prime(tentative_prime) + && tentative_prime >= modulo + { + tentative_prime -= modulo + } - if tentative_prime.leading_zeros() == leading_zeros && is_prime(tentative_prime) { - Some(tentative_prime) - } else { - None - } - } + if tentative_prime.leading_zeros() == leading_zeros && is_prime(tentative_prime) { + Some(tentative_prime) + } else { + None + } + } } #[cfg(test)] mod tests { - use super::generate_prime; - use fhe_util::catch_unwind; + use super::generate_prime; + use fhe_util::catch_unwind; - // Verifies that the same moduli as in the NFLlib library are generated. - // - #[test] - fn nfl_62bit_primes() { - let mut generated = vec![]; - let mut upper_bound = u64::MAX >> 2; - while generated.len() != 20 { - let p = generate_prime(62, 2 * 1048576, upper_bound); - assert!(p.is_some()); - upper_bound = p.unwrap(); - generated.push(upper_bound); - } - assert_eq!( - generated, - vec![ - 4611686018326724609, - 4611686018309947393, - 4611686018282684417, - 4611686018257518593, - 4611686018232352769, - 4611686018171535361, - 4611686018106523649, - 4611686018058289153, - 4611686018051997697, - 4611686017974403073, - 4611686017812922369, - 4611686017781465089, - 4611686017773076481, - 4611686017678704641, - 4611686017666121729, - 4611686017647247361, - 4611686017590624257, - 4611686017554972673, - 4611686017529806849, - 4611686017517223937 - ] - ) - } + // Verifies that the same moduli as in the NFLlib library are generated. + // + #[test] + fn nfl_62bit_primes() { + let mut generated = vec![]; + let mut upper_bound = u64::MAX >> 2; + while generated.len() != 20 { + let p = generate_prime(62, 2 * 1048576, upper_bound); + assert!(p.is_some()); + upper_bound = p.unwrap(); + generated.push(upper_bound); + } + assert_eq!( + generated, + vec![ + 4611686018326724609, + 4611686018309947393, + 4611686018282684417, + 4611686018257518593, + 4611686018232352769, + 4611686018171535361, + 4611686018106523649, + 4611686018058289153, + 4611686018051997697, + 4611686017974403073, + 4611686017812922369, + 4611686017781465089, + 4611686017773076481, + 4611686017678704641, + 4611686017666121729, + 4611686017647247361, + 4611686017590624257, + 4611686017554972673, + 4611686017529806849, + 4611686017517223937 + ] + ) + } - #[test] - fn upper_bound() { - debug_assert!(catch_unwind(|| generate_prime(62, 2 * 1048576, (1 << 62) + 1)).is_err()); - } + #[test] + fn upper_bound() { + debug_assert!(catch_unwind(|| generate_prime(62, 2 * 1048576, (1 << 62) + 1)).is_err()); + } - #[test] - fn modulo_too_large() { - assert!(generate_prime(10, 2048, 1 << 10).is_none()); - } + #[test] + fn modulo_too_large() { + assert!(generate_prime(10, 2048, 1 << 10).is_none()); + } - #[test] - fn not_found() { - // 1033 is the smallest 11-bit prime congruent to 1 modulo 16, so looking for a - // smaller one should fail. - assert!(generate_prime(11, 16, 1033).is_none()); - } + #[test] + fn not_found() { + // 1033 is the smallest 11-bit prime congruent to 1 modulo 16, so looking for a + // smaller one should fail. + assert!(generate_prime(11, 16, 1033).is_none()); + } } diff --git a/crates/fhe-traits/src/lib.rs b/crates/fhe-traits/src/lib.rs index f0d76a9..2825d08 100644 --- a/crates/fhe-traits/src/lib.rs +++ b/crates/fhe-traits/src/lib.rs @@ -13,21 +13,21 @@ pub trait FheParameters {} /// Indicates that an object is parametrized. pub trait FheParametrized { - /// The type of the FHE parameters. - type Parameters: FheParameters; + /// The type of the FHE parameters. + type Parameters: FheParameters; } /// Indicates that Self parameters can be switched. pub trait FheParametersSwitchable where - Self: FheParametrized, + Self: FheParametrized, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Attempt to switch the underlying parameters using the associated - /// switcher. - fn switch_parameters(&mut self, switcher: &S) -> Result<(), Self::Error>; + /// Attempt to switch the underlying parameters using the associated + /// switcher. + fn switch_parameters(&mut self, switcher: &S) -> Result<(), Self::Error>; } /// Encoding used when encoding a [`FhePlaintext`]. @@ -36,137 +36,137 @@ pub trait FhePlaintextEncoding {} /// A plaintext which will encode one (or more) value(s). pub trait FhePlaintext where - Self: Sized + FheParametrized, + Self: Sized + FheParametrized, { - /// The type of the encoding. - type Encoding: FhePlaintextEncoding; + /// The type of the encoding. + type Encoding: FhePlaintextEncoding; } /// Encode a value using a specified encoding. pub trait FheEncoder where - Self: FhePlaintext, + Self: FhePlaintext, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Attempt to encode a value using a specified encoding. - fn try_encode( - value: V, - encoding: Self::Encoding, - par: &Arc, - ) -> Result; + /// Attempt to encode a value using a specified encoding. + fn try_encode( + value: V, + encoding: Self::Encoding, + par: &Arc, + ) -> Result; } /// Encode a value using a specified encoding. pub trait FheEncoderVariableTime where - Self: FhePlaintext, + Self: FhePlaintext, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Attempt to encode a value using a specified encoding. - /// # Safety - /// This encoding runs in variable time and may leak information about the - /// value. - unsafe fn try_encode_vt( - value: V, - encoding: Self::Encoding, - par: &Arc, - ) -> Result; + /// Attempt to encode a value using a specified encoding. + /// # Safety + /// This encoding runs in variable time and may leak information about the + /// value. + unsafe fn try_encode_vt( + value: V, + encoding: Self::Encoding, + par: &Arc, + ) -> Result; } /// Decode the value in the plaintext with the specified (optional) encoding. pub trait FheDecoder where - Self: Sized, + Self: Sized, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Attempt to decode a [`FhePlaintext`] into a value, using an (optional) - /// encoding. - fn try_decode(pt: &P, encoding: O) -> Result - where - O: Into>; + /// Attempt to decode a [`FhePlaintext`] into a value, using an (optional) + /// encoding. + fn try_decode(pt: &P, encoding: O) -> Result + where + O: Into>; } /// A ciphertext which will encrypt a plaintext. pub trait FheCiphertext where - Self: Sized + Serialize + FheParametrized + DeserializeParametrized, + Self: Sized + Serialize + FheParametrized + DeserializeParametrized, { } /// Encrypt a plaintext into a ciphertext. pub trait FheEncrypter< - P: FhePlaintext, - C: FheCiphertext, + P: FhePlaintext, + C: FheCiphertext, >: FheParametrized { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Try to encrypt an [`FhePlaintext`] into an [`FheCiphertext`]. - fn try_encrypt(&self, pt: &P, rng: &mut R) -> Result; + /// Try to encrypt an [`FhePlaintext`] into an [`FheCiphertext`]. + fn try_encrypt(&self, pt: &P, rng: &mut R) -> Result; } /// Decrypt a ciphertext into a plaintext pub trait FheDecrypter< - P: FhePlaintext, - C: FheCiphertext, + P: FhePlaintext, + C: FheCiphertext, >: FheParametrized { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Try to decrypt an [`FheCiphertext`] into an [`FhePlaintext`]. - fn try_decrypt(&self, ct: &C) -> Result; + /// Try to decrypt an [`FheCiphertext`] into an [`FhePlaintext`]. + fn try_decrypt(&self, ct: &C) -> Result; } /// Serialization. pub trait Serialize { - /// Serialize `Self` into a vector of bytes. - fn to_bytes(&self) -> Vec; + /// Serialize `Self` into a vector of bytes. + fn to_bytes(&self) -> Vec; } /// Deserialization of a parametrized value. pub trait DeserializeParametrized where - Self: Sized, - Self: FheParametrized, + Self: Sized, + Self: FheParametrized, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Attempt to deserialize from a vector of bytes - fn from_bytes(bytes: &[u8], par: &Arc) -> Result; + /// Attempt to deserialize from a vector of bytes + fn from_bytes(bytes: &[u8], par: &Arc) -> Result; } /// Deserialization setting an explicit context. pub trait DeserializeWithContext where - Self: Sized, + Self: Sized, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// The type of context. - type Context; + /// The type of context. + type Context; - /// Attempt to deserialize from a vector of bytes - fn from_bytes(bytes: &[u8], ctx: &Arc) -> Result; + /// Attempt to deserialize from a vector of bytes + fn from_bytes(bytes: &[u8], ctx: &Arc) -> Result; } /// Deserialization without context. pub trait Deserialize where - Self: Sized, + Self: Sized, { - /// The type of error returned. - type Error; + /// The type of error returned. + type Error; - /// Attempt to deserialize from a vector of bytes - fn try_deserialize(bytes: &[u8]) -> Result; + /// Attempt to deserialize from a vector of bytes + fn try_deserialize(bytes: &[u8]) -> Result; } diff --git a/crates/fhe-util/src/lib.rs b/crates/fhe-util/src/lib.rs index f89b8a5..370ade9 100644 --- a/crates/fhe-util/src/lib.rs +++ b/crates/fhe-util/src/lib.rs @@ -19,1751 +19,1751 @@ use std::{mem::size_of, panic::UnwindSafe}; /// Define catch_unwind to silence the panic in unit tests. pub fn catch_unwind(f: F) -> std::thread::Result where - F: FnOnce() -> R + UnwindSafe, + F: FnOnce() -> R + UnwindSafe, { - let prev_hook = std::panic::take_hook(); - std::panic::set_hook(Box::new(|_| {})); - let r = std::panic::catch_unwind(f); - std::panic::set_hook(prev_hook); - r + let prev_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(|_| {})); + let r = std::panic::catch_unwind(f); + std::panic::set_hook(prev_hook); + r } /// Returns whether the modulus p is prime; this function is 100% accurate. pub fn is_prime(p: u64) -> bool { - probably_prime(&BigUint::from(p), 0) + probably_prime(&BigUint::from(p), 0) } /// Sample a vector of independent centered binomial distributions of a given /// variance. Returns an error if the variance is strictly larger than 16. pub fn sample_vec_cbd( - vector_size: usize, - variance: usize, - rng: &mut R, + vector_size: usize, + variance: usize, + rng: &mut R, ) -> Result, &'static str> { - if !(1..=16).contains(&variance) { - return Err("The variance should be between 1 and 16"); - } + if !(1..=16).contains(&variance) { + return Err("The variance should be between 1 and 16"); + } - let mut out = Vec::with_capacity(vector_size); + let mut out = Vec::with_capacity(vector_size); - let number_bits = 4 * variance; - let mask_add = ((u64::MAX >> (64 - number_bits)) >> (2 * variance)) as u128; - let mask_sub = mask_add << (2 * variance); + let number_bits = 4 * variance; + let mask_add = ((u64::MAX >> (64 - number_bits)) >> (2 * variance)) as u128; + let mask_sub = mask_add << (2 * variance); - let mut current_pool = 0u128; - let mut current_pool_nbits = 0; + let mut current_pool = 0u128; + let mut current_pool_nbits = 0; - for _ in 0..vector_size { - if current_pool_nbits < number_bits { - current_pool |= (rng.next_u64() as u128) << current_pool_nbits; - current_pool_nbits += 64; - } - debug_assert!(current_pool_nbits >= number_bits); - out.push( - ((current_pool & mask_add).count_ones() as i64) - - ((current_pool & mask_sub).count_ones() as i64), - ); - current_pool >>= number_bits; - current_pool_nbits -= number_bits; - } + for _ in 0..vector_size { + if current_pool_nbits < number_bits { + current_pool |= (rng.next_u64() as u128) << current_pool_nbits; + current_pool_nbits += 64; + } + debug_assert!(current_pool_nbits >= number_bits); + out.push( + ((current_pool & mask_add).count_ones() as i64) + - ((current_pool & mask_sub).count_ones() as i64), + ); + current_pool >>= number_bits; + current_pool_nbits -= number_bits; + } - Ok(out) + Ok(out) } /// Transcodes a vector of u64 of `nbits`-bit numbers into a vector of bytes. pub fn transcode_to_bytes(a: &[u64], nbits: usize) -> Vec { - assert!(nbits <= 64); - assert!(nbits > 0); + assert!(nbits <= 64); + assert!(nbits > 0); - let mask = (u64::MAX >> (64 - nbits)) as u128; - let nbytes = div_ceil(a.len() * nbits, 8); - let mut out = Vec::with_capacity(nbytes); + let mask = (u64::MAX >> (64 - nbits)) as u128; + let nbytes = div_ceil(a.len() * nbits, 8); + let mut out = Vec::with_capacity(nbytes); - let mut current_index = 0; - let mut current_value = 0u128; - let mut current_value_nbits = 0; - while current_index < a.len() { - if current_value_nbits < 8 { - debug_assert!(64 - a[current_index].leading_zeros() <= nbits as u32); - current_value |= ((a[current_index] as u128) & mask) << current_value_nbits; - current_value_nbits += nbits; - current_index += 1; - } - while current_value_nbits >= 8 { - out.push(current_value as u8); - current_value >>= 8; - current_value_nbits -= 8; - } - } - if current_value_nbits > 0 { - assert!(current_value_nbits < 8); - assert_eq!(out.len(), nbytes - 1); - out.push(current_value as u8) - } else { - assert_eq!(out.len(), nbytes); - assert_eq!(current_value, 0); - } - out + let mut current_index = 0; + let mut current_value = 0u128; + let mut current_value_nbits = 0; + while current_index < a.len() { + if current_value_nbits < 8 { + debug_assert!(64 - a[current_index].leading_zeros() <= nbits as u32); + current_value |= ((a[current_index] as u128) & mask) << current_value_nbits; + current_value_nbits += nbits; + current_index += 1; + } + while current_value_nbits >= 8 { + out.push(current_value as u8); + current_value >>= 8; + current_value_nbits -= 8; + } + } + if current_value_nbits > 0 { + assert!(current_value_nbits < 8); + assert_eq!(out.len(), nbytes - 1); + out.push(current_value as u8) + } else { + assert_eq!(out.len(), nbytes); + assert_eq!(current_value, 0); + } + out } /// Transcodes a vector of u8 into a vector of u64 of `nbits`-bit numbers. pub fn transcode_from_bytes(b: &[u8], nbits: usize) -> Vec { - assert!(nbits <= 64); - assert!(nbits > 0); - let mask = (u64::MAX >> (64 - nbits)) as u128; + assert!(nbits <= 64); + assert!(nbits > 0); + let mask = (u64::MAX >> (64 - nbits)) as u128; - let nelements = div_ceil(b.len() * 8, nbits); - let mut out = Vec::with_capacity(nelements); + let nelements = div_ceil(b.len() * 8, nbits); + let mut out = Vec::with_capacity(nelements); - let mut current_value = 0u128; - let mut current_value_nbits = 0; - let mut current_index = 0; - while current_index < b.len() { - if current_value_nbits < nbits { - current_value |= (b[current_index] as u128) << current_value_nbits; - current_value_nbits += 8; - current_index += 1; - } - while current_value_nbits >= nbits { - out.push((current_value & mask) as u64); - current_value >>= nbits; - current_value_nbits -= nbits; - } - } - if current_value_nbits > 0 { - assert_eq!(out.len(), nelements - 1); - out.push(current_value as u64); - } else { - assert_eq!(out.len(), nelements); - assert_eq!(current_value, 0); - } - out + let mut current_value = 0u128; + let mut current_value_nbits = 0; + let mut current_index = 0; + while current_index < b.len() { + if current_value_nbits < nbits { + current_value |= (b[current_index] as u128) << current_value_nbits; + current_value_nbits += 8; + current_index += 1; + } + while current_value_nbits >= nbits { + out.push((current_value & mask) as u64); + current_value >>= nbits; + current_value_nbits -= nbits; + } + } + if current_value_nbits > 0 { + assert_eq!(out.len(), nelements - 1); + out.push(current_value as u64); + } else { + assert_eq!(out.len(), nelements); + assert_eq!(current_value, 0); + } + out } /// Transcodes a vector of u64 of `input_nbits`-bit numbers into a vector of u64 /// of `output_nbits`-bit numbers. pub fn transcode_bidirectional(a: &[u64], input_nbits: usize, output_nbits: usize) -> Vec { - assert!(input_nbits <= 64); - assert!(output_nbits <= 64); - assert!(input_nbits > 0); - assert!(output_nbits > 0); - let input_mask = (u64::MAX >> (64 - input_nbits)) as u128; - let output_mask = (u64::MAX >> (64 - output_nbits)) as u128; - let output_size = div_ceil(a.len() * input_nbits, output_nbits); - let mut out = Vec::with_capacity(output_size); + assert!(input_nbits <= 64); + assert!(output_nbits <= 64); + assert!(input_nbits > 0); + assert!(output_nbits > 0); + let input_mask = (u64::MAX >> (64 - input_nbits)) as u128; + let output_mask = (u64::MAX >> (64 - output_nbits)) as u128; + let output_size = div_ceil(a.len() * input_nbits, output_nbits); + let mut out = Vec::with_capacity(output_size); - let mut current_index = 0; - let mut current_value = 0u128; - let mut current_value_nbits = 0; - while current_index < a.len() { - if current_value_nbits < output_nbits { - debug_assert!(64 - a[current_index].leading_zeros() <= input_nbits as u32); - current_value |= ((a[current_index] as u128) & input_mask) << current_value_nbits; - current_value_nbits += input_nbits; - current_index += 1; - } - while current_value_nbits >= output_nbits { - out.push((current_value & output_mask) as u64); - current_value >>= output_nbits; - current_value_nbits -= output_nbits; - } - } - if current_value_nbits > 0 { - assert!(current_value_nbits < output_nbits); - assert_eq!(out.len(), output_size - 1); - out.push(current_value as u64) - } else { - assert_eq!(out.len(), output_size); - assert_eq!(current_value, 0); - } - out + let mut current_index = 0; + let mut current_value = 0u128; + let mut current_value_nbits = 0; + while current_index < a.len() { + if current_value_nbits < output_nbits { + debug_assert!(64 - a[current_index].leading_zeros() <= input_nbits as u32); + current_value |= ((a[current_index] as u128) & input_mask) << current_value_nbits; + current_value_nbits += input_nbits; + current_index += 1; + } + while current_value_nbits >= output_nbits { + out.push((current_value & output_mask) as u64); + current_value >>= output_nbits; + current_value_nbits -= output_nbits; + } + } + if current_value_nbits > 0 { + assert!(current_value_nbits < output_nbits); + assert_eq!(out.len(), output_size - 1); + out.push(current_value as u64) + } else { + assert_eq!(out.len(), output_size); + assert_eq!(current_value, 0); + } + out } /// Computes the modular multiplicative inverse of `a` modulo `p`. Returns /// `None` if `a` is not invertible modulo `p`. pub fn inverse(a: u64, p: u64) -> Option { - let p = BigUint::from(p); - let a = BigUint::from(a); - a.mod_inverse(p)?.to_u64() + let p = BigUint::from(p); + let a = BigUint::from(a); + a.mod_inverse(p)?.to_u64() } /// Returns the number of bits b such that 2^b <= value /// to simulate the `.ilog2()` function from . /// Panics when `value` is 0. pub fn ilog2(value: T) -> usize { - assert!(value > T::zero()); - // For this, we compute sizeof(T) - 1 - value.leading_zeros(). Indeed, when 2^b - // <= value < 2^(b+1), then value.leading_zeros() = sizeof(T) - (b + 1). - size_of::() * 8 - 1 - value.leading_zeros() as usize + assert!(value > T::zero()); + // For this, we compute sizeof(T) - 1 - value.leading_zeros(). Indeed, when 2^b + // <= value < 2^(b+1), then value.leading_zeros() = sizeof(T) - (b + 1). + size_of::() * 8 - 1 - value.leading_zeros() as usize } /// Returns the ceil of a divided by b, to simulate the /// `.div_ceil()` function from . /// Panics when `b` is 0. pub fn div_ceil(a: T, b: T) -> T { - assert!(b > T::zero()); - (a + b - T::one()) / b + assert!(b > T::zero()); + (a + b - T::one()) / b } /// Compute the sample variance of a list of values. /// Panics if the length of value is < 2. pub fn variance(values: &[T]) -> f64 { - assert!(values.len() > 1); - let mean = values.iter().fold(0f64, |acc, i| acc + i.to_f64().unwrap()) / (values.len() as f64); - values.iter().fold(0f64, |acc, i| { - acc + (i.to_f64().unwrap() - mean) * (i.to_f64().unwrap() - mean) - }) / ((values.len() as f64) - 1.0) + assert!(values.len() > 1); + let mean = values.iter().fold(0f64, |acc, i| acc + i.to_f64().unwrap()) / (values.len() as f64); + values.iter().fold(0f64, |acc, i| { + acc + (i.to_f64().unwrap() - mean) * (i.to_f64().unwrap() - mean) + }) / ((values.len() as f64) - 1.0) } #[cfg(test)] mod tests { - use itertools::Itertools; - use rand::{thread_rng, RngCore}; + use itertools::Itertools; + use rand::{thread_rng, RngCore}; - use crate::{div_ceil, ilog2, variance}; + use crate::{div_ceil, ilog2, variance}; - use super::{ - inverse, is_prime, sample_vec_cbd, transcode_bidirectional, transcode_from_bytes, - transcode_to_bytes, - }; + use super::{ + inverse, is_prime, sample_vec_cbd, transcode_bidirectional, transcode_from_bytes, + transcode_to_bytes, + }; - #[test] - fn prime() { - assert!(is_prime(2)); - assert!(is_prime(3)); - assert!(is_prime(5)); - assert!(is_prime(7)); - assert!(is_prime(4611686018326724609)); + #[test] + fn prime() { + assert!(is_prime(2)); + assert!(is_prime(3)); + assert!(is_prime(5)); + assert!(is_prime(7)); + assert!(is_prime(4611686018326724609)); - assert!(!is_prime(0)); - assert!(!is_prime(1)); - assert!(!is_prime(4)); - assert!(!is_prime(6)); - assert!(!is_prime(8)); - assert!(!is_prime(9)); - assert!(!is_prime(4611686018326724607)); - } + assert!(!is_prime(0)); + assert!(!is_prime(1)); + assert!(!is_prime(4)); + assert!(!is_prime(6)); + assert!(!is_prime(8)); + assert!(!is_prime(9)); + assert!(!is_prime(4611686018326724607)); + } - #[test] - fn ilog2_is_correct() { - assert_eq!(ilog2(1), 0); - assert_eq!(ilog2(2), 1); - assert_eq!(ilog2(3), 1); - assert_eq!(ilog2(4), 2); - for i in 2..=110 { - assert_eq!(ilog2(1u128 << i), i); - assert_eq!(ilog2((1u128 << i) + 1), i); - assert_eq!(ilog2((1u128 << (i + 1)) - 1), i); - } - } + #[test] + fn ilog2_is_correct() { + assert_eq!(ilog2(1), 0); + assert_eq!(ilog2(2), 1); + assert_eq!(ilog2(3), 1); + assert_eq!(ilog2(4), 2); + for i in 2..=110 { + assert_eq!(ilog2(1u128 << i), i); + assert_eq!(ilog2((1u128 << i) + 1), i); + assert_eq!(ilog2((1u128 << (i + 1)) - 1), i); + } + } - #[test] - fn div_ceil_is_correct() { - for _ in 0..100 { - let a = (thread_rng().next_u32() >> 1) as usize; - assert_eq!(div_ceil(a, 1), a); - assert_eq!(div_ceil(a, 2), (a >> 1) + (a & 1)); - assert_eq!(div_ceil(a, 8), (a + 7) / 8); - } - } + #[test] + fn div_ceil_is_correct() { + for _ in 0..100 { + let a = (thread_rng().next_u32() >> 1) as usize; + assert_eq!(div_ceil(a, 1), a); + assert_eq!(div_ceil(a, 2), (a >> 1) + (a & 1)); + assert_eq!(div_ceil(a, 8), (a + 7) / 8); + } + } - #[test] - fn sample_cbd() { - assert!(sample_vec_cbd(10, 0, &mut thread_rng()).is_err()); - assert!(sample_vec_cbd(10, 17, &mut thread_rng()).is_err()); + #[test] + fn sample_cbd() { + assert!(sample_vec_cbd(10, 0, &mut thread_rng()).is_err()); + assert!(sample_vec_cbd(10, 17, &mut thread_rng()).is_err()); - for var in 1..=16 { - for size in 0..=100 { - let v = sample_vec_cbd(size, var, &mut thread_rng()).unwrap(); - assert_eq!(v.len(), size); - } + for var in 1..=16 { + for size in 0..=100 { + let v = sample_vec_cbd(size, var, &mut thread_rng()).unwrap(); + assert_eq!(v.len(), size); + } - // Verifies that the min, max are in absolute value smaller than 2 * var - let v = sample_vec_cbd(100000, var, &mut thread_rng()).unwrap(); - assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * var as i64); + // Verifies that the min, max are in absolute value smaller than 2 * var + let v = sample_vec_cbd(100000, var, &mut thread_rng()).unwrap(); + assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * var as i64); - // Verifies that the variance is correct. We could probably refine the bound - // but for now, we will just check that the rounded value is equal to the - // variance. - assert!(variance(&v).round() == (var as f64)); - } - } + // Verifies that the variance is correct. We could probably refine the bound + // but for now, we will just check that the rounded value is equal to the + // variance. + assert!(variance(&v).round() == (var as f64)); + } + } - #[test] - fn transcode_self_consistency() { - let mut rng = thread_rng(); + #[test] + fn transcode_self_consistency() { + let mut rng = thread_rng(); - for size in 1..=100 { - let input = (0..size).map(|_| rng.next_u64()).collect_vec(); - for input_nbits in 1..63 { - let masked_input = input - .iter() - .map(|i| (*i) & (u64::MAX >> (64 - input_nbits))) - .collect_vec(); - let bytes = transcode_to_bytes(&masked_input, input_nbits); - let bytes_as_u64 = transcode_bidirectional(&masked_input, input_nbits, 8); - assert_eq!(bytes, bytes_as_u64.iter().map(|e| *e as u8).collect_vec()); + for size in 1..=100 { + let input = (0..size).map(|_| rng.next_u64()).collect_vec(); + for input_nbits in 1..63 { + let masked_input = input + .iter() + .map(|i| (*i) & (u64::MAX >> (64 - input_nbits))) + .collect_vec(); + let bytes = transcode_to_bytes(&masked_input, input_nbits); + let bytes_as_u64 = transcode_bidirectional(&masked_input, input_nbits, 8); + assert_eq!(bytes, bytes_as_u64.iter().map(|e| *e as u8).collect_vec()); - let input_from_bytes = transcode_from_bytes(&bytes, input_nbits); - assert!(input_from_bytes.len() >= masked_input.len()); - assert_eq!(input_from_bytes[..masked_input.len()], masked_input); + let input_from_bytes = transcode_from_bytes(&bytes, input_nbits); + assert!(input_from_bytes.len() >= masked_input.len()); + assert_eq!(input_from_bytes[..masked_input.len()], masked_input); - let input_from_u64 = transcode_bidirectional(&bytes_as_u64, 8, input_nbits); - assert!(input_from_u64.len() >= masked_input.len()); - assert_eq!(input_from_u64[..masked_input.len()], masked_input); + let input_from_u64 = transcode_bidirectional(&bytes_as_u64, 8, input_nbits); + assert!(input_from_u64.len() >= masked_input.len()); + assert_eq!(input_from_u64[..masked_input.len()], masked_input); - for output_nbits in 1..63 { - let output = transcode_bidirectional(&masked_input, input_nbits, output_nbits); - let input_from_output = - transcode_bidirectional(&output, output_nbits, input_nbits); - assert!(input_from_output.len() >= masked_input.len()); - assert_eq!(input_from_output[..masked_input.len()], masked_input); - } - } - } - } + for output_nbits in 1..63 { + let output = transcode_bidirectional(&masked_input, input_nbits, output_nbits); + let input_from_output = + transcode_bidirectional(&output, output_nbits, input_nbits); + assert!(input_from_output.len() >= masked_input.len()); + assert_eq!(input_from_output[..masked_input.len()], masked_input); + } + } + } + } - #[test] - fn inv_kats() { - // KATs for inversion generated in Sage using the following code. - /* - sage: for p in range(2, 1000, 7): - ....: for a in range(1, 30, 3): - ....: if gcd(a, p) == 1: - ....: i = ZZ(a)^(-1) % p - ....: print("assert_eq!(inverse({}, {}), Some({}));".format(a, p, i)) - ....: else: - ....: print("assert!(inverse({}, {}).is_none());".format(a, p)) - */ - assert_eq!(inverse(1, 2), Some(1)); - assert!(inverse(4, 2).is_none()); - assert_eq!(inverse(7, 2), Some(1)); - assert!(inverse(10, 2).is_none()); - assert_eq!(inverse(13, 2), Some(1)); - assert!(inverse(16, 2).is_none()); - assert_eq!(inverse(19, 2), Some(1)); - assert!(inverse(22, 2).is_none()); - assert_eq!(inverse(25, 2), Some(1)); - assert!(inverse(28, 2).is_none()); - assert_eq!(inverse(1, 9), Some(1)); - assert_eq!(inverse(4, 9), Some(7)); - assert_eq!(inverse(7, 9), Some(4)); - assert_eq!(inverse(10, 9), Some(1)); - assert_eq!(inverse(13, 9), Some(7)); - assert_eq!(inverse(16, 9), Some(4)); - assert_eq!(inverse(19, 9), Some(1)); - assert_eq!(inverse(22, 9), Some(7)); - assert_eq!(inverse(25, 9), Some(4)); - assert_eq!(inverse(28, 9), Some(1)); - assert_eq!(inverse(1, 16), Some(1)); - assert!(inverse(4, 16).is_none()); - assert_eq!(inverse(7, 16), Some(7)); - assert!(inverse(10, 16).is_none()); - assert_eq!(inverse(13, 16), Some(5)); - assert!(inverse(16, 16).is_none()); - assert_eq!(inverse(19, 16), Some(11)); - assert!(inverse(22, 16).is_none()); - assert_eq!(inverse(25, 16), Some(9)); - assert!(inverse(28, 16).is_none()); - assert_eq!(inverse(1, 23), Some(1)); - assert_eq!(inverse(4, 23), Some(6)); - assert_eq!(inverse(7, 23), Some(10)); - assert_eq!(inverse(10, 23), Some(7)); - assert_eq!(inverse(13, 23), Some(16)); - assert_eq!(inverse(16, 23), Some(13)); - assert_eq!(inverse(19, 23), Some(17)); - assert_eq!(inverse(22, 23), Some(22)); - assert_eq!(inverse(25, 23), Some(12)); - assert_eq!(inverse(28, 23), Some(14)); - assert_eq!(inverse(1, 30), Some(1)); - assert!(inverse(4, 30).is_none()); - assert_eq!(inverse(7, 30), Some(13)); - assert!(inverse(10, 30).is_none()); - assert_eq!(inverse(13, 30), Some(7)); - assert!(inverse(16, 30).is_none()); - assert_eq!(inverse(19, 30), Some(19)); - assert!(inverse(22, 30).is_none()); - assert!(inverse(25, 30).is_none()); - assert!(inverse(28, 30).is_none()); - assert_eq!(inverse(1, 37), Some(1)); - assert_eq!(inverse(4, 37), Some(28)); - assert_eq!(inverse(7, 37), Some(16)); - assert_eq!(inverse(10, 37), Some(26)); - assert_eq!(inverse(13, 37), Some(20)); - assert_eq!(inverse(16, 37), Some(7)); - assert_eq!(inverse(19, 37), Some(2)); - assert_eq!(inverse(22, 37), Some(32)); - assert_eq!(inverse(25, 37), Some(3)); - assert_eq!(inverse(28, 37), Some(4)); - assert_eq!(inverse(1, 44), Some(1)); - assert!(inverse(4, 44).is_none()); - assert_eq!(inverse(7, 44), Some(19)); - assert!(inverse(10, 44).is_none()); - assert_eq!(inverse(13, 44), Some(17)); - assert!(inverse(16, 44).is_none()); - assert_eq!(inverse(19, 44), Some(7)); - assert!(inverse(22, 44).is_none()); - assert_eq!(inverse(25, 44), Some(37)); - assert!(inverse(28, 44).is_none()); - assert_eq!(inverse(1, 51), Some(1)); - assert_eq!(inverse(4, 51), Some(13)); - assert_eq!(inverse(7, 51), Some(22)); - assert_eq!(inverse(10, 51), Some(46)); - assert_eq!(inverse(13, 51), Some(4)); - assert_eq!(inverse(16, 51), Some(16)); - assert_eq!(inverse(19, 51), Some(43)); - assert_eq!(inverse(22, 51), Some(7)); - assert_eq!(inverse(25, 51), Some(49)); - assert_eq!(inverse(28, 51), Some(31)); - assert_eq!(inverse(1, 58), Some(1)); - assert!(inverse(4, 58).is_none()); - assert_eq!(inverse(7, 58), Some(25)); - assert!(inverse(10, 58).is_none()); - assert_eq!(inverse(13, 58), Some(9)); - assert!(inverse(16, 58).is_none()); - assert_eq!(inverse(19, 58), Some(55)); - assert!(inverse(22, 58).is_none()); - assert_eq!(inverse(25, 58), Some(7)); - assert!(inverse(28, 58).is_none()); - assert_eq!(inverse(1, 65), Some(1)); - assert_eq!(inverse(4, 65), Some(49)); - assert_eq!(inverse(7, 65), Some(28)); - assert!(inverse(10, 65).is_none()); - assert!(inverse(13, 65).is_none()); - assert_eq!(inverse(16, 65), Some(61)); - assert_eq!(inverse(19, 65), Some(24)); - assert_eq!(inverse(22, 65), Some(3)); - assert!(inverse(25, 65).is_none()); - assert_eq!(inverse(28, 65), Some(7)); - assert_eq!(inverse(1, 72), Some(1)); - assert!(inverse(4, 72).is_none()); - assert_eq!(inverse(7, 72), Some(31)); - assert!(inverse(10, 72).is_none()); - assert_eq!(inverse(13, 72), Some(61)); - assert!(inverse(16, 72).is_none()); - assert_eq!(inverse(19, 72), Some(19)); - assert!(inverse(22, 72).is_none()); - assert_eq!(inverse(25, 72), Some(49)); - assert!(inverse(28, 72).is_none()); - assert_eq!(inverse(1, 79), Some(1)); - assert_eq!(inverse(4, 79), Some(20)); - assert_eq!(inverse(7, 79), Some(34)); - assert_eq!(inverse(10, 79), Some(8)); - assert_eq!(inverse(13, 79), Some(73)); - assert_eq!(inverse(16, 79), Some(5)); - assert_eq!(inverse(19, 79), Some(25)); - assert_eq!(inverse(22, 79), Some(18)); - assert_eq!(inverse(25, 79), Some(19)); - assert_eq!(inverse(28, 79), Some(48)); - assert_eq!(inverse(1, 86), Some(1)); - assert!(inverse(4, 86).is_none()); - assert_eq!(inverse(7, 86), Some(37)); - assert!(inverse(10, 86).is_none()); - assert_eq!(inverse(13, 86), Some(53)); - assert!(inverse(16, 86).is_none()); - assert_eq!(inverse(19, 86), Some(77)); - assert!(inverse(22, 86).is_none()); - assert_eq!(inverse(25, 86), Some(31)); - assert!(inverse(28, 86).is_none()); - assert_eq!(inverse(1, 93), Some(1)); - assert_eq!(inverse(4, 93), Some(70)); - assert_eq!(inverse(7, 93), Some(40)); - assert_eq!(inverse(10, 93), Some(28)); - assert_eq!(inverse(13, 93), Some(43)); - assert_eq!(inverse(16, 93), Some(64)); - assert_eq!(inverse(19, 93), Some(49)); - assert_eq!(inverse(22, 93), Some(55)); - assert_eq!(inverse(25, 93), Some(67)); - assert_eq!(inverse(28, 93), Some(10)); - assert_eq!(inverse(1, 100), Some(1)); - assert!(inverse(4, 100).is_none()); - assert_eq!(inverse(7, 100), Some(43)); - assert!(inverse(10, 100).is_none()); - assert_eq!(inverse(13, 100), Some(77)); - assert!(inverse(16, 100).is_none()); - assert_eq!(inverse(19, 100), Some(79)); - assert!(inverse(22, 100).is_none()); - assert!(inverse(25, 100).is_none()); - assert!(inverse(28, 100).is_none()); - assert_eq!(inverse(1, 107), Some(1)); - assert_eq!(inverse(4, 107), Some(27)); - assert_eq!(inverse(7, 107), Some(46)); - assert_eq!(inverse(10, 107), Some(75)); - assert_eq!(inverse(13, 107), Some(33)); - assert_eq!(inverse(16, 107), Some(87)); - assert_eq!(inverse(19, 107), Some(62)); - assert_eq!(inverse(22, 107), Some(73)); - assert_eq!(inverse(25, 107), Some(30)); - assert_eq!(inverse(28, 107), Some(65)); - assert_eq!(inverse(1, 114), Some(1)); - assert!(inverse(4, 114).is_none()); - assert_eq!(inverse(7, 114), Some(49)); - assert!(inverse(10, 114).is_none()); - assert_eq!(inverse(13, 114), Some(79)); - assert!(inverse(16, 114).is_none()); - assert!(inverse(19, 114).is_none()); - assert!(inverse(22, 114).is_none()); - assert_eq!(inverse(25, 114), Some(73)); - assert!(inverse(28, 114).is_none()); - assert_eq!(inverse(1, 121), Some(1)); - assert_eq!(inverse(4, 121), Some(91)); - assert_eq!(inverse(7, 121), Some(52)); - assert_eq!(inverse(10, 121), Some(109)); - assert_eq!(inverse(13, 121), Some(28)); - assert_eq!(inverse(16, 121), Some(53)); - assert_eq!(inverse(19, 121), Some(51)); - assert!(inverse(22, 121).is_none()); - assert_eq!(inverse(25, 121), Some(92)); - assert_eq!(inverse(28, 121), Some(13)); - assert_eq!(inverse(1, 128), Some(1)); - assert!(inverse(4, 128).is_none()); - assert_eq!(inverse(7, 128), Some(55)); - assert!(inverse(10, 128).is_none()); - assert_eq!(inverse(13, 128), Some(69)); - assert!(inverse(16, 128).is_none()); - assert_eq!(inverse(19, 128), Some(27)); - assert!(inverse(22, 128).is_none()); - assert_eq!(inverse(25, 128), Some(41)); - assert!(inverse(28, 128).is_none()); - assert_eq!(inverse(1, 135), Some(1)); - assert_eq!(inverse(4, 135), Some(34)); - assert_eq!(inverse(7, 135), Some(58)); - assert!(inverse(10, 135).is_none()); - assert_eq!(inverse(13, 135), Some(52)); - assert_eq!(inverse(16, 135), Some(76)); - assert_eq!(inverse(19, 135), Some(64)); - assert_eq!(inverse(22, 135), Some(43)); - assert!(inverse(25, 135).is_none()); - assert_eq!(inverse(28, 135), Some(82)); - assert_eq!(inverse(1, 142), Some(1)); - assert!(inverse(4, 142).is_none()); - assert_eq!(inverse(7, 142), Some(61)); - assert!(inverse(10, 142).is_none()); - assert_eq!(inverse(13, 142), Some(11)); - assert!(inverse(16, 142).is_none()); - assert_eq!(inverse(19, 142), Some(15)); - assert!(inverse(22, 142).is_none()); - assert_eq!(inverse(25, 142), Some(125)); - assert!(inverse(28, 142).is_none()); - assert_eq!(inverse(1, 149), Some(1)); - assert_eq!(inverse(4, 149), Some(112)); - assert_eq!(inverse(7, 149), Some(64)); - assert_eq!(inverse(10, 149), Some(15)); - assert_eq!(inverse(13, 149), Some(23)); - assert_eq!(inverse(16, 149), Some(28)); - assert_eq!(inverse(19, 149), Some(102)); - assert_eq!(inverse(22, 149), Some(61)); - assert_eq!(inverse(25, 149), Some(6)); - assert_eq!(inverse(28, 149), Some(16)); - assert_eq!(inverse(1, 156), Some(1)); - assert!(inverse(4, 156).is_none()); - assert_eq!(inverse(7, 156), Some(67)); - assert!(inverse(10, 156).is_none()); - assert!(inverse(13, 156).is_none()); - assert!(inverse(16, 156).is_none()); - assert_eq!(inverse(19, 156), Some(115)); - assert!(inverse(22, 156).is_none()); - assert_eq!(inverse(25, 156), Some(25)); - assert!(inverse(28, 156).is_none()); - assert_eq!(inverse(1, 163), Some(1)); - assert_eq!(inverse(4, 163), Some(41)); - assert_eq!(inverse(7, 163), Some(70)); - assert_eq!(inverse(10, 163), Some(49)); - assert_eq!(inverse(13, 163), Some(138)); - assert_eq!(inverse(16, 163), Some(51)); - assert_eq!(inverse(19, 163), Some(103)); - assert_eq!(inverse(22, 163), Some(126)); - assert_eq!(inverse(25, 163), Some(150)); - assert_eq!(inverse(28, 163), Some(99)); - assert_eq!(inverse(1, 170), Some(1)); - assert!(inverse(4, 170).is_none()); - assert_eq!(inverse(7, 170), Some(73)); - assert!(inverse(10, 170).is_none()); - assert_eq!(inverse(13, 170), Some(157)); - assert!(inverse(16, 170).is_none()); - assert_eq!(inverse(19, 170), Some(9)); - assert!(inverse(22, 170).is_none()); - assert!(inverse(25, 170).is_none()); - assert!(inverse(28, 170).is_none()); - assert_eq!(inverse(1, 177), Some(1)); - assert_eq!(inverse(4, 177), Some(133)); - assert_eq!(inverse(7, 177), Some(76)); - assert_eq!(inverse(10, 177), Some(124)); - assert_eq!(inverse(13, 177), Some(109)); - assert_eq!(inverse(16, 177), Some(166)); - assert_eq!(inverse(19, 177), Some(28)); - assert_eq!(inverse(22, 177), Some(169)); - assert_eq!(inverse(25, 177), Some(85)); - assert_eq!(inverse(28, 177), Some(19)); - assert_eq!(inverse(1, 184), Some(1)); - assert!(inverse(4, 184).is_none()); - assert_eq!(inverse(7, 184), Some(79)); - assert!(inverse(10, 184).is_none()); - assert_eq!(inverse(13, 184), Some(85)); - assert!(inverse(16, 184).is_none()); - assert_eq!(inverse(19, 184), Some(155)); - assert!(inverse(22, 184).is_none()); - assert_eq!(inverse(25, 184), Some(81)); - assert!(inverse(28, 184).is_none()); - assert_eq!(inverse(1, 191), Some(1)); - assert_eq!(inverse(4, 191), Some(48)); - assert_eq!(inverse(7, 191), Some(82)); - assert_eq!(inverse(10, 191), Some(172)); - assert_eq!(inverse(13, 191), Some(147)); - assert_eq!(inverse(16, 191), Some(12)); - assert_eq!(inverse(19, 191), Some(181)); - assert_eq!(inverse(22, 191), Some(165)); - assert_eq!(inverse(25, 191), Some(107)); - assert_eq!(inverse(28, 191), Some(116)); - assert_eq!(inverse(1, 198), Some(1)); - assert!(inverse(4, 198).is_none()); - assert_eq!(inverse(7, 198), Some(85)); - assert!(inverse(10, 198).is_none()); - assert_eq!(inverse(13, 198), Some(61)); - assert!(inverse(16, 198).is_none()); - assert_eq!(inverse(19, 198), Some(73)); - assert!(inverse(22, 198).is_none()); - assert_eq!(inverse(25, 198), Some(103)); - assert!(inverse(28, 198).is_none()); - assert_eq!(inverse(1, 205), Some(1)); - assert_eq!(inverse(4, 205), Some(154)); - assert_eq!(inverse(7, 205), Some(88)); - assert!(inverse(10, 205).is_none()); - assert_eq!(inverse(13, 205), Some(142)); - assert_eq!(inverse(16, 205), Some(141)); - assert_eq!(inverse(19, 205), Some(54)); - assert_eq!(inverse(22, 205), Some(28)); - assert!(inverse(25, 205).is_none()); - assert_eq!(inverse(28, 205), Some(22)); - assert_eq!(inverse(1, 212), Some(1)); - assert!(inverse(4, 212).is_none()); - assert_eq!(inverse(7, 212), Some(91)); - assert!(inverse(10, 212).is_none()); - assert_eq!(inverse(13, 212), Some(49)); - assert!(inverse(16, 212).is_none()); - assert_eq!(inverse(19, 212), Some(67)); - assert!(inverse(22, 212).is_none()); - assert_eq!(inverse(25, 212), Some(17)); - assert!(inverse(28, 212).is_none()); - assert_eq!(inverse(1, 219), Some(1)); - assert_eq!(inverse(4, 219), Some(55)); - assert_eq!(inverse(7, 219), Some(94)); - assert_eq!(inverse(10, 219), Some(22)); - assert_eq!(inverse(13, 219), Some(118)); - assert_eq!(inverse(16, 219), Some(178)); - assert_eq!(inverse(19, 219), Some(196)); - assert_eq!(inverse(22, 219), Some(10)); - assert_eq!(inverse(25, 219), Some(184)); - assert_eq!(inverse(28, 219), Some(133)); - assert_eq!(inverse(1, 226), Some(1)); - assert!(inverse(4, 226).is_none()); - assert_eq!(inverse(7, 226), Some(97)); - assert!(inverse(10, 226).is_none()); - assert_eq!(inverse(13, 226), Some(87)); - assert!(inverse(16, 226).is_none()); - assert_eq!(inverse(19, 226), Some(119)); - assert!(inverse(22, 226).is_none()); - assert_eq!(inverse(25, 226), Some(217)); - assert!(inverse(28, 226).is_none()); - assert_eq!(inverse(1, 233), Some(1)); - assert_eq!(inverse(4, 233), Some(175)); - assert_eq!(inverse(7, 233), Some(100)); - assert_eq!(inverse(10, 233), Some(70)); - assert_eq!(inverse(13, 233), Some(18)); - assert_eq!(inverse(16, 233), Some(102)); - assert_eq!(inverse(19, 233), Some(184)); - assert_eq!(inverse(22, 233), Some(53)); - assert_eq!(inverse(25, 233), Some(28)); - assert_eq!(inverse(28, 233), Some(25)); - assert_eq!(inverse(1, 240), Some(1)); - assert!(inverse(4, 240).is_none()); - assert_eq!(inverse(7, 240), Some(103)); - assert!(inverse(10, 240).is_none()); - assert_eq!(inverse(13, 240), Some(37)); - assert!(inverse(16, 240).is_none()); - assert_eq!(inverse(19, 240), Some(139)); - assert!(inverse(22, 240).is_none()); - assert!(inverse(25, 240).is_none()); - assert!(inverse(28, 240).is_none()); - assert_eq!(inverse(1, 247), Some(1)); - assert_eq!(inverse(4, 247), Some(62)); - assert_eq!(inverse(7, 247), Some(106)); - assert_eq!(inverse(10, 247), Some(173)); - assert!(inverse(13, 247).is_none()); - assert_eq!(inverse(16, 247), Some(139)); - assert!(inverse(19, 247).is_none()); - assert_eq!(inverse(22, 247), Some(146)); - assert_eq!(inverse(25, 247), Some(168)); - assert_eq!(inverse(28, 247), Some(150)); - assert_eq!(inverse(1, 254), Some(1)); - assert!(inverse(4, 254).is_none()); - assert_eq!(inverse(7, 254), Some(109)); - assert!(inverse(10, 254).is_none()); - assert_eq!(inverse(13, 254), Some(215)); - assert!(inverse(16, 254).is_none()); - assert_eq!(inverse(19, 254), Some(107)); - assert!(inverse(22, 254).is_none()); - assert_eq!(inverse(25, 254), Some(61)); - assert!(inverse(28, 254).is_none()); - assert_eq!(inverse(1, 261), Some(1)); - assert_eq!(inverse(4, 261), Some(196)); - assert_eq!(inverse(7, 261), Some(112)); - assert_eq!(inverse(10, 261), Some(235)); - assert_eq!(inverse(13, 261), Some(241)); - assert_eq!(inverse(16, 261), Some(49)); - assert_eq!(inverse(19, 261), Some(55)); - assert_eq!(inverse(22, 261), Some(178)); - assert_eq!(inverse(25, 261), Some(94)); - assert_eq!(inverse(28, 261), Some(28)); - assert_eq!(inverse(1, 268), Some(1)); - assert!(inverse(4, 268).is_none()); - assert_eq!(inverse(7, 268), Some(115)); - assert!(inverse(10, 268).is_none()); - assert_eq!(inverse(13, 268), Some(165)); - assert!(inverse(16, 268).is_none()); - assert_eq!(inverse(19, 268), Some(127)); - assert!(inverse(22, 268).is_none()); - assert_eq!(inverse(25, 268), Some(193)); - assert!(inverse(28, 268).is_none()); - assert_eq!(inverse(1, 275), Some(1)); - assert_eq!(inverse(4, 275), Some(69)); - assert_eq!(inverse(7, 275), Some(118)); - assert!(inverse(10, 275).is_none()); - assert_eq!(inverse(13, 275), Some(127)); - assert_eq!(inverse(16, 275), Some(86)); - assert_eq!(inverse(19, 275), Some(29)); - assert!(inverse(22, 275).is_none()); - assert!(inverse(25, 275).is_none()); - assert_eq!(inverse(28, 275), Some(167)); - assert_eq!(inverse(1, 282), Some(1)); - assert!(inverse(4, 282).is_none()); - assert_eq!(inverse(7, 282), Some(121)); - assert!(inverse(10, 282).is_none()); - assert_eq!(inverse(13, 282), Some(217)); - assert!(inverse(16, 282).is_none()); - assert_eq!(inverse(19, 282), Some(193)); - assert!(inverse(22, 282).is_none()); - assert_eq!(inverse(25, 282), Some(79)); - assert!(inverse(28, 282).is_none()); - assert_eq!(inverse(1, 289), Some(1)); - assert_eq!(inverse(4, 289), Some(217)); - assert_eq!(inverse(7, 289), Some(124)); - assert_eq!(inverse(10, 289), Some(29)); - assert_eq!(inverse(13, 289), Some(89)); - assert_eq!(inverse(16, 289), Some(271)); - assert_eq!(inverse(19, 289), Some(213)); - assert_eq!(inverse(22, 289), Some(92)); - assert_eq!(inverse(25, 289), Some(185)); - assert_eq!(inverse(28, 289), Some(31)); - assert_eq!(inverse(1, 296), Some(1)); - assert!(inverse(4, 296).is_none()); - assert_eq!(inverse(7, 296), Some(127)); - assert!(inverse(10, 296).is_none()); - assert_eq!(inverse(13, 296), Some(205)); - assert!(inverse(16, 296).is_none()); - assert_eq!(inverse(19, 296), Some(187)); - assert!(inverse(22, 296).is_none()); - assert_eq!(inverse(25, 296), Some(225)); - assert!(inverse(28, 296).is_none()); - assert_eq!(inverse(1, 303), Some(1)); - assert_eq!(inverse(4, 303), Some(76)); - assert_eq!(inverse(7, 303), Some(130)); - assert_eq!(inverse(10, 303), Some(91)); - assert_eq!(inverse(13, 303), Some(70)); - assert_eq!(inverse(16, 303), Some(19)); - assert_eq!(inverse(19, 303), Some(16)); - assert_eq!(inverse(22, 303), Some(124)); - assert_eq!(inverse(25, 303), Some(97)); - assert_eq!(inverse(28, 303), Some(184)); - assert_eq!(inverse(1, 310), Some(1)); - assert!(inverse(4, 310).is_none()); - assert_eq!(inverse(7, 310), Some(133)); - assert!(inverse(10, 310).is_none()); - assert_eq!(inverse(13, 310), Some(167)); - assert!(inverse(16, 310).is_none()); - assert_eq!(inverse(19, 310), Some(49)); - assert!(inverse(22, 310).is_none()); - assert!(inverse(25, 310).is_none()); - assert!(inverse(28, 310).is_none()); - assert_eq!(inverse(1, 317), Some(1)); - assert_eq!(inverse(4, 317), Some(238)); - assert_eq!(inverse(7, 317), Some(136)); - assert_eq!(inverse(10, 317), Some(222)); - assert_eq!(inverse(13, 317), Some(122)); - assert_eq!(inverse(16, 317), Some(218)); - assert_eq!(inverse(19, 317), Some(267)); - assert_eq!(inverse(22, 317), Some(245)); - assert_eq!(inverse(25, 317), Some(279)); - assert_eq!(inverse(28, 317), Some(34)); - assert_eq!(inverse(1, 324), Some(1)); - assert!(inverse(4, 324).is_none()); - assert_eq!(inverse(7, 324), Some(139)); - assert!(inverse(10, 324).is_none()); - assert_eq!(inverse(13, 324), Some(25)); - assert!(inverse(16, 324).is_none()); - assert_eq!(inverse(19, 324), Some(307)); - assert!(inverse(22, 324).is_none()); - assert_eq!(inverse(25, 324), Some(13)); - assert!(inverse(28, 324).is_none()); - assert_eq!(inverse(1, 331), Some(1)); - assert_eq!(inverse(4, 331), Some(83)); - assert_eq!(inverse(7, 331), Some(142)); - assert_eq!(inverse(10, 331), Some(298)); - assert_eq!(inverse(13, 331), Some(51)); - assert_eq!(inverse(16, 331), Some(269)); - assert_eq!(inverse(19, 331), Some(122)); - assert_eq!(inverse(22, 331), Some(316)); - assert_eq!(inverse(25, 331), Some(53)); - assert_eq!(inverse(28, 331), Some(201)); - assert_eq!(inverse(1, 338), Some(1)); - assert!(inverse(4, 338).is_none()); - assert_eq!(inverse(7, 338), Some(145)); - assert!(inverse(10, 338).is_none()); - assert!(inverse(13, 338).is_none()); - assert!(inverse(16, 338).is_none()); - assert_eq!(inverse(19, 338), Some(89)); - assert!(inverse(22, 338).is_none()); - assert_eq!(inverse(25, 338), Some(311)); - assert!(inverse(28, 338).is_none()); - assert_eq!(inverse(1, 345), Some(1)); - assert_eq!(inverse(4, 345), Some(259)); - assert_eq!(inverse(7, 345), Some(148)); - assert!(inverse(10, 345).is_none()); - assert_eq!(inverse(13, 345), Some(292)); - assert_eq!(inverse(16, 345), Some(151)); - assert_eq!(inverse(19, 345), Some(109)); - assert_eq!(inverse(22, 345), Some(298)); - assert!(inverse(25, 345).is_none()); - assert_eq!(inverse(28, 345), Some(37)); - assert_eq!(inverse(1, 352), Some(1)); - assert!(inverse(4, 352).is_none()); - assert_eq!(inverse(7, 352), Some(151)); - assert!(inverse(10, 352).is_none()); - assert_eq!(inverse(13, 352), Some(325)); - assert!(inverse(16, 352).is_none()); - assert_eq!(inverse(19, 352), Some(315)); - assert!(inverse(22, 352).is_none()); - assert_eq!(inverse(25, 352), Some(169)); - assert!(inverse(28, 352).is_none()); - assert_eq!(inverse(1, 359), Some(1)); - assert_eq!(inverse(4, 359), Some(90)); - assert_eq!(inverse(7, 359), Some(154)); - assert_eq!(inverse(10, 359), Some(36)); - assert_eq!(inverse(13, 359), Some(221)); - assert_eq!(inverse(16, 359), Some(202)); - assert_eq!(inverse(19, 359), Some(189)); - assert_eq!(inverse(22, 359), Some(49)); - assert_eq!(inverse(25, 359), Some(158)); - assert_eq!(inverse(28, 359), Some(218)); - assert_eq!(inverse(1, 366), Some(1)); - assert!(inverse(4, 366).is_none()); - assert_eq!(inverse(7, 366), Some(157)); - assert!(inverse(10, 366).is_none()); - assert_eq!(inverse(13, 366), Some(169)); - assert!(inverse(16, 366).is_none()); - assert_eq!(inverse(19, 366), Some(289)); - assert!(inverse(22, 366).is_none()); - assert_eq!(inverse(25, 366), Some(205)); - assert!(inverse(28, 366).is_none()); - assert_eq!(inverse(1, 373), Some(1)); - assert_eq!(inverse(4, 373), Some(280)); - assert_eq!(inverse(7, 373), Some(160)); - assert_eq!(inverse(10, 373), Some(112)); - assert_eq!(inverse(13, 373), Some(287)); - assert_eq!(inverse(16, 373), Some(70)); - assert_eq!(inverse(19, 373), Some(216)); - assert_eq!(inverse(22, 373), Some(17)); - assert_eq!(inverse(25, 373), Some(194)); - assert_eq!(inverse(28, 373), Some(40)); - assert_eq!(inverse(1, 380), Some(1)); - assert!(inverse(4, 380).is_none()); - assert_eq!(inverse(7, 380), Some(163)); - assert!(inverse(10, 380).is_none()); - assert_eq!(inverse(13, 380), Some(117)); - assert!(inverse(16, 380).is_none()); - assert!(inverse(19, 380).is_none()); - assert!(inverse(22, 380).is_none()); - assert!(inverse(25, 380).is_none()); - assert!(inverse(28, 380).is_none()); - assert_eq!(inverse(1, 387), Some(1)); - assert_eq!(inverse(4, 387), Some(97)); - assert_eq!(inverse(7, 387), Some(166)); - assert_eq!(inverse(10, 387), Some(271)); - assert_eq!(inverse(13, 387), Some(268)); - assert_eq!(inverse(16, 387), Some(121)); - assert_eq!(inverse(19, 387), Some(163)); - assert_eq!(inverse(22, 387), Some(88)); - assert_eq!(inverse(25, 387), Some(31)); - assert_eq!(inverse(28, 387), Some(235)); - assert_eq!(inverse(1, 394), Some(1)); - assert!(inverse(4, 394).is_none()); - assert_eq!(inverse(7, 394), Some(169)); - assert!(inverse(10, 394).is_none()); - assert_eq!(inverse(13, 394), Some(91)); - assert!(inverse(16, 394).is_none()); - assert_eq!(inverse(19, 394), Some(83)); - assert!(inverse(22, 394).is_none()); - assert_eq!(inverse(25, 394), Some(331)); - assert!(inverse(28, 394).is_none()); - assert_eq!(inverse(1, 401), Some(1)); - assert_eq!(inverse(4, 401), Some(301)); - assert_eq!(inverse(7, 401), Some(172)); - assert_eq!(inverse(10, 401), Some(361)); - assert_eq!(inverse(13, 401), Some(216)); - assert_eq!(inverse(16, 401), Some(376)); - assert_eq!(inverse(19, 401), Some(190)); - assert_eq!(inverse(22, 401), Some(237)); - assert_eq!(inverse(25, 401), Some(385)); - assert_eq!(inverse(28, 401), Some(43)); - assert_eq!(inverse(1, 408), Some(1)); - assert!(inverse(4, 408).is_none()); - assert_eq!(inverse(7, 408), Some(175)); - assert!(inverse(10, 408).is_none()); - assert_eq!(inverse(13, 408), Some(157)); - assert!(inverse(16, 408).is_none()); - assert_eq!(inverse(19, 408), Some(43)); - assert!(inverse(22, 408).is_none()); - assert_eq!(inverse(25, 408), Some(49)); - assert!(inverse(28, 408).is_none()); - assert_eq!(inverse(1, 415), Some(1)); - assert_eq!(inverse(4, 415), Some(104)); - assert_eq!(inverse(7, 415), Some(178)); - assert!(inverse(10, 415).is_none()); - assert_eq!(inverse(13, 415), Some(32)); - assert_eq!(inverse(16, 415), Some(26)); - assert_eq!(inverse(19, 415), Some(284)); - assert_eq!(inverse(22, 415), Some(283)); - assert!(inverse(25, 415).is_none()); - assert_eq!(inverse(28, 415), Some(252)); - assert_eq!(inverse(1, 422), Some(1)); - assert!(inverse(4, 422).is_none()); - assert_eq!(inverse(7, 422), Some(181)); - assert!(inverse(10, 422).is_none()); - assert_eq!(inverse(13, 422), Some(65)); - assert!(inverse(16, 422).is_none()); - assert_eq!(inverse(19, 422), Some(311)); - assert!(inverse(22, 422).is_none()); - assert_eq!(inverse(25, 422), Some(287)); - assert!(inverse(28, 422).is_none()); - assert_eq!(inverse(1, 429), Some(1)); - assert_eq!(inverse(4, 429), Some(322)); - assert_eq!(inverse(7, 429), Some(184)); - assert_eq!(inverse(10, 429), Some(43)); - assert!(inverse(13, 429).is_none()); - assert_eq!(inverse(16, 429), Some(295)); - assert_eq!(inverse(19, 429), Some(271)); - assert!(inverse(22, 429).is_none()); - assert_eq!(inverse(25, 429), Some(103)); - assert_eq!(inverse(28, 429), Some(46)); - assert_eq!(inverse(1, 436), Some(1)); - assert!(inverse(4, 436).is_none()); - assert_eq!(inverse(7, 436), Some(187)); - assert!(inverse(10, 436).is_none()); - assert_eq!(inverse(13, 436), Some(369)); - assert!(inverse(16, 436).is_none()); - assert_eq!(inverse(19, 436), Some(23)); - assert!(inverse(22, 436).is_none()); - assert_eq!(inverse(25, 436), Some(157)); - assert!(inverse(28, 436).is_none()); - assert_eq!(inverse(1, 443), Some(1)); - assert_eq!(inverse(4, 443), Some(111)); - assert_eq!(inverse(7, 443), Some(190)); - assert_eq!(inverse(10, 443), Some(133)); - assert_eq!(inverse(13, 443), Some(409)); - assert_eq!(inverse(16, 443), Some(360)); - assert_eq!(inverse(19, 443), Some(70)); - assert_eq!(inverse(22, 443), Some(141)); - assert_eq!(inverse(25, 443), Some(319)); - assert_eq!(inverse(28, 443), Some(269)); - assert_eq!(inverse(1, 450), Some(1)); - assert!(inverse(4, 450).is_none()); - assert_eq!(inverse(7, 450), Some(193)); - assert!(inverse(10, 450).is_none()); - assert_eq!(inverse(13, 450), Some(277)); - assert!(inverse(16, 450).is_none()); - assert_eq!(inverse(19, 450), Some(379)); - assert!(inverse(22, 450).is_none()); - assert!(inverse(25, 450).is_none()); - assert!(inverse(28, 450).is_none()); - assert_eq!(inverse(1, 457), Some(1)); - assert_eq!(inverse(4, 457), Some(343)); - assert_eq!(inverse(7, 457), Some(196)); - assert_eq!(inverse(10, 457), Some(320)); - assert_eq!(inverse(13, 457), Some(211)); - assert_eq!(inverse(16, 457), Some(200)); - assert_eq!(inverse(19, 457), Some(433)); - assert_eq!(inverse(22, 457), Some(187)); - assert_eq!(inverse(25, 457), Some(128)); - assert_eq!(inverse(28, 457), Some(49)); - assert_eq!(inverse(1, 464), Some(1)); - assert!(inverse(4, 464).is_none()); - assert_eq!(inverse(7, 464), Some(199)); - assert!(inverse(10, 464).is_none()); - assert_eq!(inverse(13, 464), Some(357)); - assert!(inverse(16, 464).is_none()); - assert_eq!(inverse(19, 464), Some(171)); - assert!(inverse(22, 464).is_none()); - assert_eq!(inverse(25, 464), Some(297)); - assert!(inverse(28, 464).is_none()); - assert_eq!(inverse(1, 471), Some(1)); - assert_eq!(inverse(4, 471), Some(118)); - assert_eq!(inverse(7, 471), Some(202)); - assert_eq!(inverse(10, 471), Some(424)); - assert_eq!(inverse(13, 471), Some(145)); - assert_eq!(inverse(16, 471), Some(265)); - assert_eq!(inverse(19, 471), Some(124)); - assert_eq!(inverse(22, 471), Some(364)); - assert_eq!(inverse(25, 471), Some(358)); - assert_eq!(inverse(28, 471), Some(286)); - assert_eq!(inverse(1, 478), Some(1)); - assert!(inverse(4, 478).is_none()); - assert_eq!(inverse(7, 478), Some(205)); - assert!(inverse(10, 478).is_none()); - assert_eq!(inverse(13, 478), Some(331)); - assert!(inverse(16, 478).is_none()); - assert_eq!(inverse(19, 478), Some(151)); - assert!(inverse(22, 478).is_none()); - assert_eq!(inverse(25, 478), Some(153)); - assert!(inverse(28, 478).is_none()); - assert_eq!(inverse(1, 485), Some(1)); - assert_eq!(inverse(4, 485), Some(364)); - assert_eq!(inverse(7, 485), Some(208)); - assert!(inverse(10, 485).is_none()); - assert_eq!(inverse(13, 485), Some(112)); - assert_eq!(inverse(16, 485), Some(91)); - assert_eq!(inverse(19, 485), Some(434)); - assert_eq!(inverse(22, 485), Some(463)); - assert!(inverse(25, 485).is_none()); - assert_eq!(inverse(28, 485), Some(52)); - assert_eq!(inverse(1, 492), Some(1)); - assert!(inverse(4, 492).is_none()); - assert_eq!(inverse(7, 492), Some(211)); - assert!(inverse(10, 492).is_none()); - assert_eq!(inverse(13, 492), Some(265)); - assert!(inverse(16, 492).is_none()); - assert_eq!(inverse(19, 492), Some(259)); - assert!(inverse(22, 492).is_none()); - assert_eq!(inverse(25, 492), Some(433)); - assert!(inverse(28, 492).is_none()); - assert_eq!(inverse(1, 499), Some(1)); - assert_eq!(inverse(4, 499), Some(125)); - assert_eq!(inverse(7, 499), Some(214)); - assert_eq!(inverse(10, 499), Some(50)); - assert_eq!(inverse(13, 499), Some(192)); - assert_eq!(inverse(16, 499), Some(156)); - assert_eq!(inverse(19, 499), Some(394)); - assert_eq!(inverse(22, 499), Some(431)); - assert_eq!(inverse(25, 499), Some(20)); - assert_eq!(inverse(28, 499), Some(303)); - assert_eq!(inverse(1, 506), Some(1)); - assert!(inverse(4, 506).is_none()); - assert_eq!(inverse(7, 506), Some(217)); - assert!(inverse(10, 506).is_none()); - assert_eq!(inverse(13, 506), Some(39)); - assert!(inverse(16, 506).is_none()); - assert_eq!(inverse(19, 506), Some(293)); - assert!(inverse(22, 506).is_none()); - assert_eq!(inverse(25, 506), Some(81)); - assert!(inverse(28, 506).is_none()); - assert_eq!(inverse(1, 513), Some(1)); - assert_eq!(inverse(4, 513), Some(385)); - assert_eq!(inverse(7, 513), Some(220)); - assert_eq!(inverse(10, 513), Some(154)); - assert_eq!(inverse(13, 513), Some(79)); - assert_eq!(inverse(16, 513), Some(481)); - assert!(inverse(19, 513).is_none()); - assert_eq!(inverse(22, 513), Some(70)); - assert_eq!(inverse(25, 513), Some(472)); - assert_eq!(inverse(28, 513), Some(55)); - assert_eq!(inverse(1, 520), Some(1)); - assert!(inverse(4, 520).is_none()); - assert_eq!(inverse(7, 520), Some(223)); - assert!(inverse(10, 520).is_none()); - assert!(inverse(13, 520).is_none()); - assert!(inverse(16, 520).is_none()); - assert_eq!(inverse(19, 520), Some(219)); - assert!(inverse(22, 520).is_none()); - assert!(inverse(25, 520).is_none()); - assert!(inverse(28, 520).is_none()); - assert_eq!(inverse(1, 527), Some(1)); - assert_eq!(inverse(4, 527), Some(132)); - assert_eq!(inverse(7, 527), Some(226)); - assert_eq!(inverse(10, 527), Some(369)); - assert_eq!(inverse(13, 527), Some(446)); - assert_eq!(inverse(16, 527), Some(33)); - assert_eq!(inverse(19, 527), Some(111)); - assert_eq!(inverse(22, 527), Some(24)); - assert_eq!(inverse(25, 527), Some(253)); - assert_eq!(inverse(28, 527), Some(320)); - assert_eq!(inverse(1, 534), Some(1)); - assert!(inverse(4, 534).is_none()); - assert_eq!(inverse(7, 534), Some(229)); - assert!(inverse(10, 534).is_none()); - assert_eq!(inverse(13, 534), Some(493)); - assert!(inverse(16, 534).is_none()); - assert_eq!(inverse(19, 534), Some(253)); - assert!(inverse(22, 534).is_none()); - assert_eq!(inverse(25, 534), Some(235)); - assert!(inverse(28, 534).is_none()); - assert_eq!(inverse(1, 541), Some(1)); - assert_eq!(inverse(4, 541), Some(406)); - assert_eq!(inverse(7, 541), Some(232)); - assert_eq!(inverse(10, 541), Some(487)); - assert_eq!(inverse(13, 541), Some(333)); - assert_eq!(inverse(16, 541), Some(372)); - assert_eq!(inverse(19, 541), Some(57)); - assert_eq!(inverse(22, 541), Some(123)); - assert_eq!(inverse(25, 541), Some(303)); - assert_eq!(inverse(28, 541), Some(58)); - assert_eq!(inverse(1, 548), Some(1)); - assert!(inverse(4, 548).is_none()); - assert_eq!(inverse(7, 548), Some(235)); - assert!(inverse(10, 548).is_none()); - assert_eq!(inverse(13, 548), Some(253)); - assert!(inverse(16, 548).is_none()); - assert_eq!(inverse(19, 548), Some(375)); - assert!(inverse(22, 548).is_none()); - assert_eq!(inverse(25, 548), Some(285)); - assert!(inverse(28, 548).is_none()); - assert_eq!(inverse(1, 555), Some(1)); - assert_eq!(inverse(4, 555), Some(139)); - assert_eq!(inverse(7, 555), Some(238)); - assert!(inverse(10, 555).is_none()); - assert_eq!(inverse(13, 555), Some(427)); - assert_eq!(inverse(16, 555), Some(451)); - assert_eq!(inverse(19, 555), Some(409)); - assert_eq!(inverse(22, 555), Some(328)); - assert!(inverse(25, 555).is_none()); - assert_eq!(inverse(28, 555), Some(337)); - assert_eq!(inverse(1, 562), Some(1)); - assert!(inverse(4, 562).is_none()); - assert_eq!(inverse(7, 562), Some(241)); - assert!(inverse(10, 562).is_none()); - assert_eq!(inverse(13, 562), Some(173)); - assert!(inverse(16, 562).is_none()); - assert_eq!(inverse(19, 562), Some(355)); - assert!(inverse(22, 562).is_none()); - assert_eq!(inverse(25, 562), Some(45)); - assert!(inverse(28, 562).is_none()); - assert_eq!(inverse(1, 569), Some(1)); - assert_eq!(inverse(4, 569), Some(427)); - assert_eq!(inverse(7, 569), Some(244)); - assert_eq!(inverse(10, 569), Some(57)); - assert_eq!(inverse(13, 569), Some(394)); - assert_eq!(inverse(16, 569), Some(249)); - assert_eq!(inverse(19, 569), Some(30)); - assert_eq!(inverse(22, 569), Some(388)); - assert_eq!(inverse(25, 569), Some(478)); - assert_eq!(inverse(28, 569), Some(61)); - assert_eq!(inverse(1, 576), Some(1)); - assert!(inverse(4, 576).is_none()); - assert_eq!(inverse(7, 576), Some(247)); - assert!(inverse(10, 576).is_none()); - assert_eq!(inverse(13, 576), Some(133)); - assert!(inverse(16, 576).is_none()); - assert_eq!(inverse(19, 576), Some(91)); - assert!(inverse(22, 576).is_none()); - assert_eq!(inverse(25, 576), Some(553)); - assert!(inverse(28, 576).is_none()); - assert_eq!(inverse(1, 583), Some(1)); - assert_eq!(inverse(4, 583), Some(146)); - assert_eq!(inverse(7, 583), Some(250)); - assert_eq!(inverse(10, 583), Some(175)); - assert_eq!(inverse(13, 583), Some(314)); - assert_eq!(inverse(16, 583), Some(328)); - assert_eq!(inverse(19, 583), Some(491)); - assert!(inverse(22, 583).is_none()); - assert_eq!(inverse(25, 583), Some(70)); - assert_eq!(inverse(28, 583), Some(354)); - assert_eq!(inverse(1, 590), Some(1)); - assert!(inverse(4, 590).is_none()); - assert_eq!(inverse(7, 590), Some(253)); - assert!(inverse(10, 590).is_none()); - assert_eq!(inverse(13, 590), Some(227)); - assert!(inverse(16, 590).is_none()); - assert_eq!(inverse(19, 590), Some(559)); - assert!(inverse(22, 590).is_none()); - assert!(inverse(25, 590).is_none()); - assert!(inverse(28, 590).is_none()); - assert_eq!(inverse(1, 597), Some(1)); - assert_eq!(inverse(4, 597), Some(448)); - assert_eq!(inverse(7, 597), Some(256)); - assert_eq!(inverse(10, 597), Some(418)); - assert_eq!(inverse(13, 597), Some(46)); - assert_eq!(inverse(16, 597), Some(112)); - assert_eq!(inverse(19, 597), Some(220)); - assert_eq!(inverse(22, 597), Some(190)); - assert_eq!(inverse(25, 597), Some(406)); - assert_eq!(inverse(28, 597), Some(64)); - assert_eq!(inverse(1, 604), Some(1)); - assert!(inverse(4, 604).is_none()); - assert_eq!(inverse(7, 604), Some(259)); - assert!(inverse(10, 604).is_none()); - assert_eq!(inverse(13, 604), Some(93)); - assert!(inverse(16, 604).is_none()); - assert_eq!(inverse(19, 604), Some(159)); - assert!(inverse(22, 604).is_none()); - assert_eq!(inverse(25, 604), Some(145)); - assert!(inverse(28, 604).is_none()); - assert_eq!(inverse(1, 611), Some(1)); - assert_eq!(inverse(4, 611), Some(153)); - assert_eq!(inverse(7, 611), Some(262)); - assert_eq!(inverse(10, 611), Some(550)); - assert!(inverse(13, 611).is_none()); - assert_eq!(inverse(16, 611), Some(191)); - assert_eq!(inverse(19, 611), Some(193)); - assert_eq!(inverse(22, 611), Some(250)); - assert_eq!(inverse(25, 611), Some(220)); - assert_eq!(inverse(28, 611), Some(371)); - assert_eq!(inverse(1, 618), Some(1)); - assert!(inverse(4, 618).is_none()); - assert_eq!(inverse(7, 618), Some(265)); - assert!(inverse(10, 618).is_none()); - assert_eq!(inverse(13, 618), Some(523)); - assert!(inverse(16, 618).is_none()); - assert_eq!(inverse(19, 618), Some(553)); - assert!(inverse(22, 618).is_none()); - assert_eq!(inverse(25, 618), Some(445)); - assert!(inverse(28, 618).is_none()); - assert_eq!(inverse(1, 625), Some(1)); - assert_eq!(inverse(4, 625), Some(469)); - assert_eq!(inverse(7, 625), Some(268)); - assert!(inverse(10, 625).is_none()); - assert_eq!(inverse(13, 625), Some(577)); - assert_eq!(inverse(16, 625), Some(586)); - assert_eq!(inverse(19, 625), Some(329)); - assert_eq!(inverse(22, 625), Some(483)); - assert!(inverse(25, 625).is_none()); - assert_eq!(inverse(28, 625), Some(67)); - assert_eq!(inverse(1, 632), Some(1)); - assert!(inverse(4, 632).is_none()); - assert_eq!(inverse(7, 632), Some(271)); - assert!(inverse(10, 632).is_none()); - assert_eq!(inverse(13, 632), Some(389)); - assert!(inverse(16, 632).is_none()); - assert_eq!(inverse(19, 632), Some(499)); - assert!(inverse(22, 632).is_none()); - assert_eq!(inverse(25, 632), Some(177)); - assert!(inverse(28, 632).is_none()); - assert_eq!(inverse(1, 639), Some(1)); - assert_eq!(inverse(4, 639), Some(160)); - assert_eq!(inverse(7, 639), Some(274)); - assert_eq!(inverse(10, 639), Some(64)); - assert_eq!(inverse(13, 639), Some(295)); - assert_eq!(inverse(16, 639), Some(40)); - assert_eq!(inverse(19, 639), Some(370)); - assert_eq!(inverse(22, 639), Some(610)); - assert_eq!(inverse(25, 639), Some(409)); - assert_eq!(inverse(28, 639), Some(388)); - assert_eq!(inverse(1, 646), Some(1)); - assert!(inverse(4, 646).is_none()); - assert_eq!(inverse(7, 646), Some(277)); - assert!(inverse(10, 646).is_none()); - assert_eq!(inverse(13, 646), Some(497)); - assert!(inverse(16, 646).is_none()); - assert!(inverse(19, 646).is_none()); - assert!(inverse(22, 646).is_none()); - assert_eq!(inverse(25, 646), Some(491)); - assert!(inverse(28, 646).is_none()); - assert_eq!(inverse(1, 653), Some(1)); - assert_eq!(inverse(4, 653), Some(490)); - assert_eq!(inverse(7, 653), Some(280)); - assert_eq!(inverse(10, 653), Some(196)); - assert_eq!(inverse(13, 653), Some(201)); - assert_eq!(inverse(16, 653), Some(449)); - assert_eq!(inverse(19, 653), Some(275)); - assert_eq!(inverse(22, 653), Some(564)); - assert_eq!(inverse(25, 653), Some(209)); - assert_eq!(inverse(28, 653), Some(70)); - assert_eq!(inverse(1, 660), Some(1)); - assert!(inverse(4, 660).is_none()); - assert_eq!(inverse(7, 660), Some(283)); - assert!(inverse(10, 660).is_none()); - assert_eq!(inverse(13, 660), Some(457)); - assert!(inverse(16, 660).is_none()); - assert_eq!(inverse(19, 660), Some(139)); - assert!(inverse(22, 660).is_none()); - assert!(inverse(25, 660).is_none()); - assert!(inverse(28, 660).is_none()); - assert_eq!(inverse(1, 667), Some(1)); - assert_eq!(inverse(4, 667), Some(167)); - assert_eq!(inverse(7, 667), Some(286)); - assert_eq!(inverse(10, 667), Some(467)); - assert_eq!(inverse(13, 667), Some(154)); - assert_eq!(inverse(16, 667), Some(542)); - assert_eq!(inverse(19, 667), Some(316)); - assert_eq!(inverse(22, 667), Some(91)); - assert_eq!(inverse(25, 667), Some(587)); - assert_eq!(inverse(28, 667), Some(405)); - assert_eq!(inverse(1, 674), Some(1)); - assert!(inverse(4, 674).is_none()); - assert_eq!(inverse(7, 674), Some(289)); - assert!(inverse(10, 674).is_none()); - assert_eq!(inverse(13, 674), Some(363)); - assert!(inverse(16, 674).is_none()); - assert_eq!(inverse(19, 674), Some(71)); - assert!(inverse(22, 674).is_none()); - assert_eq!(inverse(25, 674), Some(27)); - assert!(inverse(28, 674).is_none()); - assert_eq!(inverse(1, 681), Some(1)); - assert_eq!(inverse(4, 681), Some(511)); - assert_eq!(inverse(7, 681), Some(292)); - assert_eq!(inverse(10, 681), Some(613)); - assert_eq!(inverse(13, 681), Some(262)); - assert_eq!(inverse(16, 681), Some(298)); - assert_eq!(inverse(19, 681), Some(466)); - assert_eq!(inverse(22, 681), Some(31)); - assert_eq!(inverse(25, 681), Some(109)); - assert_eq!(inverse(28, 681), Some(73)); - assert_eq!(inverse(1, 688), Some(1)); - assert!(inverse(4, 688).is_none()); - assert_eq!(inverse(7, 688), Some(295)); - assert!(inverse(10, 688).is_none()); - assert_eq!(inverse(13, 688), Some(53)); - assert!(inverse(16, 688).is_none()); - assert_eq!(inverse(19, 688), Some(507)); - assert!(inverse(22, 688).is_none()); - assert_eq!(inverse(25, 688), Some(633)); - assert!(inverse(28, 688).is_none()); - assert_eq!(inverse(1, 695), Some(1)); - assert_eq!(inverse(4, 695), Some(174)); - assert_eq!(inverse(7, 695), Some(298)); - assert!(inverse(10, 695).is_none()); - assert_eq!(inverse(13, 695), Some(107)); - assert_eq!(inverse(16, 695), Some(391)); - assert_eq!(inverse(19, 695), Some(439)); - assert_eq!(inverse(22, 695), Some(158)); - assert!(inverse(25, 695).is_none()); - assert_eq!(inverse(28, 695), Some(422)); - assert_eq!(inverse(1, 702), Some(1)); - assert!(inverse(4, 702).is_none()); - assert_eq!(inverse(7, 702), Some(301)); - assert!(inverse(10, 702).is_none()); - assert!(inverse(13, 702).is_none()); - assert!(inverse(16, 702).is_none()); - assert_eq!(inverse(19, 702), Some(37)); - assert!(inverse(22, 702).is_none()); - assert_eq!(inverse(25, 702), Some(337)); - assert!(inverse(28, 702).is_none()); - assert_eq!(inverse(1, 709), Some(1)); - assert_eq!(inverse(4, 709), Some(532)); - assert_eq!(inverse(7, 709), Some(304)); - assert_eq!(inverse(10, 709), Some(71)); - assert_eq!(inverse(13, 709), Some(600)); - assert_eq!(inverse(16, 709), Some(133)); - assert_eq!(inverse(19, 709), Some(112)); - assert_eq!(inverse(22, 709), Some(419)); - assert_eq!(inverse(25, 709), Some(312)); - assert_eq!(inverse(28, 709), Some(76)); - assert_eq!(inverse(1, 716), Some(1)); - assert!(inverse(4, 716).is_none()); - assert_eq!(inverse(7, 716), Some(307)); - assert!(inverse(10, 716).is_none()); - assert_eq!(inverse(13, 716), Some(661)); - assert!(inverse(16, 716).is_none()); - assert_eq!(inverse(19, 716), Some(603)); - assert!(inverse(22, 716).is_none()); - assert_eq!(inverse(25, 716), Some(401)); - assert!(inverse(28, 716).is_none()); - assert_eq!(inverse(1, 723), Some(1)); - assert_eq!(inverse(4, 723), Some(181)); - assert_eq!(inverse(7, 723), Some(310)); - assert_eq!(inverse(10, 723), Some(217)); - assert_eq!(inverse(13, 723), Some(445)); - assert_eq!(inverse(16, 723), Some(226)); - assert_eq!(inverse(19, 723), Some(685)); - assert_eq!(inverse(22, 723), Some(493)); - assert_eq!(inverse(25, 723), Some(376)); - assert_eq!(inverse(28, 723), Some(439)); - assert_eq!(inverse(1, 730), Some(1)); - assert!(inverse(4, 730).is_none()); - assert_eq!(inverse(7, 730), Some(313)); - assert!(inverse(10, 730).is_none()); - assert_eq!(inverse(13, 730), Some(337)); - assert!(inverse(16, 730).is_none()); - assert_eq!(inverse(19, 730), Some(269)); - assert!(inverse(22, 730).is_none()); - assert!(inverse(25, 730).is_none()); - assert!(inverse(28, 730).is_none()); - assert_eq!(inverse(1, 737), Some(1)); - assert_eq!(inverse(4, 737), Some(553)); - assert_eq!(inverse(7, 737), Some(316)); - assert_eq!(inverse(10, 737), Some(516)); - assert_eq!(inverse(13, 737), Some(567)); - assert_eq!(inverse(16, 737), Some(691)); - assert_eq!(inverse(19, 737), Some(194)); - assert!(inverse(22, 737).is_none()); - assert_eq!(inverse(25, 737), Some(59)); - assert_eq!(inverse(28, 737), Some(79)); - assert_eq!(inverse(1, 744), Some(1)); - assert!(inverse(4, 744).is_none()); - assert_eq!(inverse(7, 744), Some(319)); - assert!(inverse(10, 744).is_none()); - assert_eq!(inverse(13, 744), Some(229)); - assert!(inverse(16, 744).is_none()); - assert_eq!(inverse(19, 744), Some(235)); - assert!(inverse(22, 744).is_none()); - assert_eq!(inverse(25, 744), Some(625)); - assert!(inverse(28, 744).is_none()); - assert_eq!(inverse(1, 751), Some(1)); - assert_eq!(inverse(4, 751), Some(188)); - assert_eq!(inverse(7, 751), Some(322)); - assert_eq!(inverse(10, 751), Some(676)); - assert_eq!(inverse(13, 751), Some(520)); - assert_eq!(inverse(16, 751), Some(47)); - assert_eq!(inverse(19, 751), Some(672)); - assert_eq!(inverse(22, 751), Some(239)); - assert_eq!(inverse(25, 751), Some(721)); - assert_eq!(inverse(28, 751), Some(456)); - assert_eq!(inverse(1, 758), Some(1)); - assert!(inverse(4, 758).is_none()); - assert_eq!(inverse(7, 758), Some(325)); - assert!(inverse(10, 758).is_none()); - assert_eq!(inverse(13, 758), Some(175)); - assert!(inverse(16, 758).is_none()); - assert_eq!(inverse(19, 758), Some(399)); - assert!(inverse(22, 758).is_none()); - assert_eq!(inverse(25, 758), Some(91)); - assert!(inverse(28, 758).is_none()); - assert_eq!(inverse(1, 765), Some(1)); - assert_eq!(inverse(4, 765), Some(574)); - assert_eq!(inverse(7, 765), Some(328)); - assert!(inverse(10, 765).is_none()); - assert_eq!(inverse(13, 765), Some(412)); - assert_eq!(inverse(16, 765), Some(526)); - assert_eq!(inverse(19, 765), Some(604)); - assert_eq!(inverse(22, 765), Some(313)); - assert!(inverse(25, 765).is_none()); - assert_eq!(inverse(28, 765), Some(82)); - assert_eq!(inverse(1, 772), Some(1)); - assert!(inverse(4, 772).is_none()); - assert_eq!(inverse(7, 772), Some(331)); - assert!(inverse(10, 772).is_none()); - assert_eq!(inverse(13, 772), Some(297)); - assert!(inverse(16, 772).is_none()); - assert_eq!(inverse(19, 772), Some(447)); - assert!(inverse(22, 772).is_none()); - assert_eq!(inverse(25, 772), Some(525)); - assert!(inverse(28, 772).is_none()); - assert_eq!(inverse(1, 779), Some(1)); - assert_eq!(inverse(4, 779), Some(195)); - assert_eq!(inverse(7, 779), Some(334)); - assert_eq!(inverse(10, 779), Some(78)); - assert_eq!(inverse(13, 779), Some(60)); - assert_eq!(inverse(16, 779), Some(633)); - assert!(inverse(19, 779).is_none()); - assert_eq!(inverse(22, 779), Some(602)); - assert_eq!(inverse(25, 779), Some(187)); - assert_eq!(inverse(28, 779), Some(473)); - assert_eq!(inverse(1, 786), Some(1)); - assert!(inverse(4, 786).is_none()); - assert_eq!(inverse(7, 786), Some(337)); - assert!(inverse(10, 786).is_none()); - assert_eq!(inverse(13, 786), Some(121)); - assert!(inverse(16, 786).is_none()); - assert_eq!(inverse(19, 786), Some(331)); - assert!(inverse(22, 786).is_none()); - assert_eq!(inverse(25, 786), Some(283)); - assert!(inverse(28, 786).is_none()); - assert_eq!(inverse(1, 793), Some(1)); - assert_eq!(inverse(4, 793), Some(595)); - assert_eq!(inverse(7, 793), Some(340)); - assert_eq!(inverse(10, 793), Some(238)); - assert!(inverse(13, 793).is_none()); - assert_eq!(inverse(16, 793), Some(347)); - assert_eq!(inverse(19, 793), Some(167)); - assert_eq!(inverse(22, 793), Some(757)); - assert_eq!(inverse(25, 793), Some(571)); - assert_eq!(inverse(28, 793), Some(85)); - assert_eq!(inverse(1, 800), Some(1)); - assert!(inverse(4, 800).is_none()); - assert_eq!(inverse(7, 800), Some(343)); - assert!(inverse(10, 800).is_none()); - assert_eq!(inverse(13, 800), Some(677)); - assert!(inverse(16, 800).is_none()); - assert_eq!(inverse(19, 800), Some(379)); - assert!(inverse(22, 800).is_none()); - assert!(inverse(25, 800).is_none()); - assert!(inverse(28, 800).is_none()); - assert_eq!(inverse(1, 807), Some(1)); - assert_eq!(inverse(4, 807), Some(202)); - assert_eq!(inverse(7, 807), Some(346)); - assert_eq!(inverse(10, 807), Some(565)); - assert_eq!(inverse(13, 807), Some(745)); - assert_eq!(inverse(16, 807), Some(454)); - assert_eq!(inverse(19, 807), Some(85)); - assert_eq!(inverse(22, 807), Some(697)); - assert_eq!(inverse(25, 807), Some(226)); - assert_eq!(inverse(28, 807), Some(490)); - assert_eq!(inverse(1, 814), Some(1)); - assert!(inverse(4, 814).is_none()); - assert_eq!(inverse(7, 814), Some(349)); - assert!(inverse(10, 814).is_none()); - assert_eq!(inverse(13, 814), Some(501)); - assert!(inverse(16, 814).is_none()); - assert_eq!(inverse(19, 814), Some(557)); - assert!(inverse(22, 814).is_none()); - assert_eq!(inverse(25, 814), Some(521)); - assert!(inverse(28, 814).is_none()); - assert_eq!(inverse(1, 821), Some(1)); - assert_eq!(inverse(4, 821), Some(616)); - assert_eq!(inverse(7, 821), Some(352)); - assert_eq!(inverse(10, 821), Some(739)); - assert_eq!(inverse(13, 821), Some(379)); - assert_eq!(inverse(16, 821), Some(154)); - assert_eq!(inverse(19, 821), Some(605)); - assert_eq!(inverse(22, 821), Some(112)); - assert_eq!(inverse(25, 821), Some(624)); - assert_eq!(inverse(28, 821), Some(88)); - assert_eq!(inverse(1, 828), Some(1)); - assert!(inverse(4, 828).is_none()); - assert_eq!(inverse(7, 828), Some(355)); - assert!(inverse(10, 828).is_none()); - assert_eq!(inverse(13, 828), Some(637)); - assert!(inverse(16, 828).is_none()); - assert_eq!(inverse(19, 828), Some(523)); - assert!(inverse(22, 828).is_none()); - assert_eq!(inverse(25, 828), Some(265)); - assert!(inverse(28, 828).is_none()); - assert_eq!(inverse(1, 835), Some(1)); - assert_eq!(inverse(4, 835), Some(209)); - assert_eq!(inverse(7, 835), Some(358)); - assert!(inverse(10, 835).is_none()); - assert_eq!(inverse(13, 835), Some(257)); - assert_eq!(inverse(16, 835), Some(261)); - assert_eq!(inverse(19, 835), Some(44)); - assert_eq!(inverse(22, 835), Some(38)); - assert!(inverse(25, 835).is_none()); - assert_eq!(inverse(28, 835), Some(507)); - assert_eq!(inverse(1, 842), Some(1)); - assert!(inverse(4, 842).is_none()); - assert_eq!(inverse(7, 842), Some(361)); - assert!(inverse(10, 842).is_none()); - assert_eq!(inverse(13, 842), Some(583)); - assert!(inverse(16, 842).is_none()); - assert_eq!(inverse(19, 842), Some(133)); - assert!(inverse(22, 842).is_none()); - assert_eq!(inverse(25, 842), Some(741)); - assert!(inverse(28, 842).is_none()); - assert_eq!(inverse(1, 849), Some(1)); - assert_eq!(inverse(4, 849), Some(637)); - assert_eq!(inverse(7, 849), Some(364)); - assert_eq!(inverse(10, 849), Some(85)); - assert_eq!(inverse(13, 849), Some(196)); - assert_eq!(inverse(16, 849), Some(796)); - assert_eq!(inverse(19, 849), Some(715)); - assert_eq!(inverse(22, 849), Some(193)); - assert_eq!(inverse(25, 849), Some(34)); - assert_eq!(inverse(28, 849), Some(91)); - assert_eq!(inverse(1, 856), Some(1)); - assert!(inverse(4, 856).is_none()); - assert_eq!(inverse(7, 856), Some(367)); - assert!(inverse(10, 856).is_none()); - assert_eq!(inverse(13, 856), Some(461)); - assert!(inverse(16, 856).is_none()); - assert_eq!(inverse(19, 856), Some(811)); - assert!(inverse(22, 856).is_none()); - assert_eq!(inverse(25, 856), Some(137)); - assert!(inverse(28, 856).is_none()); - assert_eq!(inverse(1, 863), Some(1)); - assert_eq!(inverse(4, 863), Some(216)); - assert_eq!(inverse(7, 863), Some(370)); - assert_eq!(inverse(10, 863), Some(259)); - assert_eq!(inverse(13, 863), Some(332)); - assert_eq!(inverse(16, 863), Some(54)); - assert_eq!(inverse(19, 863), Some(318)); - assert_eq!(inverse(22, 863), Some(510)); - assert_eq!(inverse(25, 863), Some(794)); - assert_eq!(inverse(28, 863), Some(524)); - assert_eq!(inverse(1, 870), Some(1)); - assert!(inverse(4, 870).is_none()); - assert_eq!(inverse(7, 870), Some(373)); - assert!(inverse(10, 870).is_none()); - assert_eq!(inverse(13, 870), Some(67)); - assert!(inverse(16, 870).is_none()); - assert_eq!(inverse(19, 870), Some(229)); - assert!(inverse(22, 870).is_none()); - assert!(inverse(25, 870).is_none()); - assert!(inverse(28, 870).is_none()); - assert_eq!(inverse(1, 877), Some(1)); - assert_eq!(inverse(4, 877), Some(658)); - assert_eq!(inverse(7, 877), Some(376)); - assert_eq!(inverse(10, 877), Some(614)); - assert_eq!(inverse(13, 877), Some(135)); - assert_eq!(inverse(16, 877), Some(603)); - assert_eq!(inverse(19, 877), Some(277)); - assert_eq!(inverse(22, 877), Some(598)); - assert_eq!(inverse(25, 877), Some(421)); - assert_eq!(inverse(28, 877), Some(94)); - assert_eq!(inverse(1, 884), Some(1)); - assert!(inverse(4, 884).is_none()); - assert_eq!(inverse(7, 884), Some(379)); - assert!(inverse(10, 884).is_none()); - assert!(inverse(13, 884).is_none()); - assert!(inverse(16, 884).is_none()); - assert_eq!(inverse(19, 884), Some(791)); - assert!(inverse(22, 884).is_none()); - assert_eq!(inverse(25, 884), Some(389)); - assert!(inverse(28, 884).is_none()); - assert_eq!(inverse(1, 891), Some(1)); - assert_eq!(inverse(4, 891), Some(223)); - assert_eq!(inverse(7, 891), Some(382)); - assert_eq!(inverse(10, 891), Some(802)); - assert_eq!(inverse(13, 891), Some(754)); - assert_eq!(inverse(16, 891), Some(724)); - assert_eq!(inverse(19, 891), Some(469)); - assert!(inverse(22, 891).is_none()); - assert_eq!(inverse(25, 891), Some(499)); - assert_eq!(inverse(28, 891), Some(541)); - assert_eq!(inverse(1, 898), Some(1)); - assert!(inverse(4, 898).is_none()); - assert_eq!(inverse(7, 898), Some(385)); - assert!(inverse(10, 898).is_none()); - assert_eq!(inverse(13, 898), Some(829)); - assert!(inverse(16, 898).is_none()); - assert_eq!(inverse(19, 898), Some(709)); - assert!(inverse(22, 898).is_none()); - assert_eq!(inverse(25, 898), Some(467)); - assert!(inverse(28, 898).is_none()); - assert_eq!(inverse(1, 905), Some(1)); - assert_eq!(inverse(4, 905), Some(679)); - assert_eq!(inverse(7, 905), Some(388)); - assert!(inverse(10, 905).is_none()); - assert_eq!(inverse(13, 905), Some(557)); - assert_eq!(inverse(16, 905), Some(396)); - assert_eq!(inverse(19, 905), Some(524)); - assert_eq!(inverse(22, 905), Some(288)); - assert!(inverse(25, 905).is_none()); - assert_eq!(inverse(28, 905), Some(97)); - assert_eq!(inverse(1, 912), Some(1)); - assert!(inverse(4, 912).is_none()); - assert_eq!(inverse(7, 912), Some(391)); - assert!(inverse(10, 912).is_none()); - assert_eq!(inverse(13, 912), Some(421)); - assert!(inverse(16, 912).is_none()); - assert!(inverse(19, 912).is_none()); - assert!(inverse(22, 912).is_none()); - assert_eq!(inverse(25, 912), Some(73)); - assert!(inverse(28, 912).is_none()); - assert_eq!(inverse(1, 919), Some(1)); - assert_eq!(inverse(4, 919), Some(230)); - assert_eq!(inverse(7, 919), Some(394)); - assert_eq!(inverse(10, 919), Some(92)); - assert_eq!(inverse(13, 919), Some(707)); - assert_eq!(inverse(16, 919), Some(517)); - assert_eq!(inverse(19, 919), Some(387)); - assert_eq!(inverse(22, 919), Some(376)); - assert_eq!(inverse(25, 919), Some(772)); - assert_eq!(inverse(28, 919), Some(558)); - assert_eq!(inverse(1, 926), Some(1)); - assert!(inverse(4, 926).is_none()); - assert_eq!(inverse(7, 926), Some(397)); - assert!(inverse(10, 926).is_none()); - assert_eq!(inverse(13, 926), Some(285)); - assert!(inverse(16, 926).is_none()); - assert_eq!(inverse(19, 926), Some(195)); - assert!(inverse(22, 926).is_none()); - assert_eq!(inverse(25, 926), Some(889)); - assert!(inverse(28, 926).is_none()); - assert_eq!(inverse(1, 933), Some(1)); - assert_eq!(inverse(4, 933), Some(700)); - assert_eq!(inverse(7, 933), Some(400)); - assert_eq!(inverse(10, 933), Some(280)); - assert_eq!(inverse(13, 933), Some(646)); - assert_eq!(inverse(16, 933), Some(175)); - assert_eq!(inverse(19, 933), Some(442)); - assert_eq!(inverse(22, 933), Some(721)); - assert_eq!(inverse(25, 933), Some(112)); - assert_eq!(inverse(28, 933), Some(100)); - assert_eq!(inverse(1, 940), Some(1)); - assert!(inverse(4, 940).is_none()); - assert_eq!(inverse(7, 940), Some(403)); - assert!(inverse(10, 940).is_none()); - assert_eq!(inverse(13, 940), Some(217)); - assert!(inverse(16, 940).is_none()); - assert_eq!(inverse(19, 940), Some(99)); - assert!(inverse(22, 940).is_none()); - assert!(inverse(25, 940).is_none()); - assert!(inverse(28, 940).is_none()); - assert_eq!(inverse(1, 947), Some(1)); - assert_eq!(inverse(4, 947), Some(237)); - assert_eq!(inverse(7, 947), Some(406)); - assert_eq!(inverse(10, 947), Some(663)); - assert_eq!(inverse(13, 947), Some(510)); - assert_eq!(inverse(16, 947), Some(296)); - assert_eq!(inverse(19, 947), Some(648)); - assert_eq!(inverse(22, 947), Some(904)); - assert_eq!(inverse(25, 947), Some(644)); - assert_eq!(inverse(28, 947), Some(575)); - assert_eq!(inverse(1, 954), Some(1)); - assert!(inverse(4, 954).is_none()); - assert_eq!(inverse(7, 954), Some(409)); - assert!(inverse(10, 954).is_none()); - assert_eq!(inverse(13, 954), Some(367)); - assert!(inverse(16, 954).is_none()); - assert_eq!(inverse(19, 954), Some(703)); - assert!(inverse(22, 954).is_none()); - assert_eq!(inverse(25, 954), Some(229)); - assert!(inverse(28, 954).is_none()); - assert_eq!(inverse(1, 961), Some(1)); - assert_eq!(inverse(4, 961), Some(721)); - assert_eq!(inverse(7, 961), Some(412)); - assert_eq!(inverse(10, 961), Some(865)); - assert_eq!(inverse(13, 961), Some(74)); - assert_eq!(inverse(16, 961), Some(901)); - assert_eq!(inverse(19, 961), Some(607)); - assert_eq!(inverse(22, 961), Some(830)); - assert_eq!(inverse(25, 961), Some(346)); - assert_eq!(inverse(28, 961), Some(103)); - assert_eq!(inverse(1, 968), Some(1)); - assert!(inverse(4, 968).is_none()); - assert_eq!(inverse(7, 968), Some(415)); - assert!(inverse(10, 968).is_none()); - assert_eq!(inverse(13, 968), Some(149)); - assert!(inverse(16, 968).is_none()); - assert_eq!(inverse(19, 968), Some(51)); - assert!(inverse(22, 968).is_none()); - assert_eq!(inverse(25, 968), Some(697)); - assert!(inverse(28, 968).is_none()); - assert_eq!(inverse(1, 975), Some(1)); - assert_eq!(inverse(4, 975), Some(244)); - assert_eq!(inverse(7, 975), Some(418)); - assert!(inverse(10, 975).is_none()); - assert!(inverse(13, 975).is_none()); - assert_eq!(inverse(16, 975), Some(61)); - assert_eq!(inverse(19, 975), Some(154)); - assert_eq!(inverse(22, 975), Some(133)); - assert!(inverse(25, 975).is_none()); - assert_eq!(inverse(28, 975), Some(592)); - assert_eq!(inverse(1, 982), Some(1)); - assert!(inverse(4, 982).is_none()); - assert_eq!(inverse(7, 982), Some(421)); - assert!(inverse(10, 982).is_none()); - assert_eq!(inverse(13, 982), Some(831)); - assert!(inverse(16, 982).is_none()); - assert_eq!(inverse(19, 982), Some(827)); - assert!(inverse(22, 982).is_none()); - assert_eq!(inverse(25, 982), Some(275)); - assert!(inverse(28, 982).is_none()); - assert_eq!(inverse(1, 989), Some(1)); - assert_eq!(inverse(4, 989), Some(742)); - assert_eq!(inverse(7, 989), Some(424)); - assert_eq!(inverse(10, 989), Some(99)); - assert_eq!(inverse(13, 989), Some(913)); - assert_eq!(inverse(16, 989), Some(680)); - assert_eq!(inverse(19, 989), Some(937)); - assert_eq!(inverse(22, 989), Some(45)); - assert_eq!(inverse(25, 989), Some(633)); - assert_eq!(inverse(28, 989), Some(106)); - assert_eq!(inverse(1, 996), Some(1)); - assert!(inverse(4, 996).is_none()); - assert_eq!(inverse(7, 996), Some(427)); - assert!(inverse(10, 996).is_none()); - assert_eq!(inverse(13, 996), Some(613)); - assert!(inverse(16, 996).is_none()); - assert_eq!(inverse(19, 996), Some(367)); - assert!(inverse(22, 996).is_none()); - assert_eq!(inverse(25, 996), Some(757)); - assert!(inverse(28, 996).is_none()); - } + #[test] + fn inv_kats() { + // KATs for inversion generated in Sage using the following code. + /* + sage: for p in range(2, 1000, 7): + ....: for a in range(1, 30, 3): + ....: if gcd(a, p) == 1: + ....: i = ZZ(a)^(-1) % p + ....: print("assert_eq!(inverse({}, {}), Some({}));".format(a, p, i)) + ....: else: + ....: print("assert!(inverse({}, {}).is_none());".format(a, p)) + */ + assert_eq!(inverse(1, 2), Some(1)); + assert!(inverse(4, 2).is_none()); + assert_eq!(inverse(7, 2), Some(1)); + assert!(inverse(10, 2).is_none()); + assert_eq!(inverse(13, 2), Some(1)); + assert!(inverse(16, 2).is_none()); + assert_eq!(inverse(19, 2), Some(1)); + assert!(inverse(22, 2).is_none()); + assert_eq!(inverse(25, 2), Some(1)); + assert!(inverse(28, 2).is_none()); + assert_eq!(inverse(1, 9), Some(1)); + assert_eq!(inverse(4, 9), Some(7)); + assert_eq!(inverse(7, 9), Some(4)); + assert_eq!(inverse(10, 9), Some(1)); + assert_eq!(inverse(13, 9), Some(7)); + assert_eq!(inverse(16, 9), Some(4)); + assert_eq!(inverse(19, 9), Some(1)); + assert_eq!(inverse(22, 9), Some(7)); + assert_eq!(inverse(25, 9), Some(4)); + assert_eq!(inverse(28, 9), Some(1)); + assert_eq!(inverse(1, 16), Some(1)); + assert!(inverse(4, 16).is_none()); + assert_eq!(inverse(7, 16), Some(7)); + assert!(inverse(10, 16).is_none()); + assert_eq!(inverse(13, 16), Some(5)); + assert!(inverse(16, 16).is_none()); + assert_eq!(inverse(19, 16), Some(11)); + assert!(inverse(22, 16).is_none()); + assert_eq!(inverse(25, 16), Some(9)); + assert!(inverse(28, 16).is_none()); + assert_eq!(inverse(1, 23), Some(1)); + assert_eq!(inverse(4, 23), Some(6)); + assert_eq!(inverse(7, 23), Some(10)); + assert_eq!(inverse(10, 23), Some(7)); + assert_eq!(inverse(13, 23), Some(16)); + assert_eq!(inverse(16, 23), Some(13)); + assert_eq!(inverse(19, 23), Some(17)); + assert_eq!(inverse(22, 23), Some(22)); + assert_eq!(inverse(25, 23), Some(12)); + assert_eq!(inverse(28, 23), Some(14)); + assert_eq!(inverse(1, 30), Some(1)); + assert!(inverse(4, 30).is_none()); + assert_eq!(inverse(7, 30), Some(13)); + assert!(inverse(10, 30).is_none()); + assert_eq!(inverse(13, 30), Some(7)); + assert!(inverse(16, 30).is_none()); + assert_eq!(inverse(19, 30), Some(19)); + assert!(inverse(22, 30).is_none()); + assert!(inverse(25, 30).is_none()); + assert!(inverse(28, 30).is_none()); + assert_eq!(inverse(1, 37), Some(1)); + assert_eq!(inverse(4, 37), Some(28)); + assert_eq!(inverse(7, 37), Some(16)); + assert_eq!(inverse(10, 37), Some(26)); + assert_eq!(inverse(13, 37), Some(20)); + assert_eq!(inverse(16, 37), Some(7)); + assert_eq!(inverse(19, 37), Some(2)); + assert_eq!(inverse(22, 37), Some(32)); + assert_eq!(inverse(25, 37), Some(3)); + assert_eq!(inverse(28, 37), Some(4)); + assert_eq!(inverse(1, 44), Some(1)); + assert!(inverse(4, 44).is_none()); + assert_eq!(inverse(7, 44), Some(19)); + assert!(inverse(10, 44).is_none()); + assert_eq!(inverse(13, 44), Some(17)); + assert!(inverse(16, 44).is_none()); + assert_eq!(inverse(19, 44), Some(7)); + assert!(inverse(22, 44).is_none()); + assert_eq!(inverse(25, 44), Some(37)); + assert!(inverse(28, 44).is_none()); + assert_eq!(inverse(1, 51), Some(1)); + assert_eq!(inverse(4, 51), Some(13)); + assert_eq!(inverse(7, 51), Some(22)); + assert_eq!(inverse(10, 51), Some(46)); + assert_eq!(inverse(13, 51), Some(4)); + assert_eq!(inverse(16, 51), Some(16)); + assert_eq!(inverse(19, 51), Some(43)); + assert_eq!(inverse(22, 51), Some(7)); + assert_eq!(inverse(25, 51), Some(49)); + assert_eq!(inverse(28, 51), Some(31)); + assert_eq!(inverse(1, 58), Some(1)); + assert!(inverse(4, 58).is_none()); + assert_eq!(inverse(7, 58), Some(25)); + assert!(inverse(10, 58).is_none()); + assert_eq!(inverse(13, 58), Some(9)); + assert!(inverse(16, 58).is_none()); + assert_eq!(inverse(19, 58), Some(55)); + assert!(inverse(22, 58).is_none()); + assert_eq!(inverse(25, 58), Some(7)); + assert!(inverse(28, 58).is_none()); + assert_eq!(inverse(1, 65), Some(1)); + assert_eq!(inverse(4, 65), Some(49)); + assert_eq!(inverse(7, 65), Some(28)); + assert!(inverse(10, 65).is_none()); + assert!(inverse(13, 65).is_none()); + assert_eq!(inverse(16, 65), Some(61)); + assert_eq!(inverse(19, 65), Some(24)); + assert_eq!(inverse(22, 65), Some(3)); + assert!(inverse(25, 65).is_none()); + assert_eq!(inverse(28, 65), Some(7)); + assert_eq!(inverse(1, 72), Some(1)); + assert!(inverse(4, 72).is_none()); + assert_eq!(inverse(7, 72), Some(31)); + assert!(inverse(10, 72).is_none()); + assert_eq!(inverse(13, 72), Some(61)); + assert!(inverse(16, 72).is_none()); + assert_eq!(inverse(19, 72), Some(19)); + assert!(inverse(22, 72).is_none()); + assert_eq!(inverse(25, 72), Some(49)); + assert!(inverse(28, 72).is_none()); + assert_eq!(inverse(1, 79), Some(1)); + assert_eq!(inverse(4, 79), Some(20)); + assert_eq!(inverse(7, 79), Some(34)); + assert_eq!(inverse(10, 79), Some(8)); + assert_eq!(inverse(13, 79), Some(73)); + assert_eq!(inverse(16, 79), Some(5)); + assert_eq!(inverse(19, 79), Some(25)); + assert_eq!(inverse(22, 79), Some(18)); + assert_eq!(inverse(25, 79), Some(19)); + assert_eq!(inverse(28, 79), Some(48)); + assert_eq!(inverse(1, 86), Some(1)); + assert!(inverse(4, 86).is_none()); + assert_eq!(inverse(7, 86), Some(37)); + assert!(inverse(10, 86).is_none()); + assert_eq!(inverse(13, 86), Some(53)); + assert!(inverse(16, 86).is_none()); + assert_eq!(inverse(19, 86), Some(77)); + assert!(inverse(22, 86).is_none()); + assert_eq!(inverse(25, 86), Some(31)); + assert!(inverse(28, 86).is_none()); + assert_eq!(inverse(1, 93), Some(1)); + assert_eq!(inverse(4, 93), Some(70)); + assert_eq!(inverse(7, 93), Some(40)); + assert_eq!(inverse(10, 93), Some(28)); + assert_eq!(inverse(13, 93), Some(43)); + assert_eq!(inverse(16, 93), Some(64)); + assert_eq!(inverse(19, 93), Some(49)); + assert_eq!(inverse(22, 93), Some(55)); + assert_eq!(inverse(25, 93), Some(67)); + assert_eq!(inverse(28, 93), Some(10)); + assert_eq!(inverse(1, 100), Some(1)); + assert!(inverse(4, 100).is_none()); + assert_eq!(inverse(7, 100), Some(43)); + assert!(inverse(10, 100).is_none()); + assert_eq!(inverse(13, 100), Some(77)); + assert!(inverse(16, 100).is_none()); + assert_eq!(inverse(19, 100), Some(79)); + assert!(inverse(22, 100).is_none()); + assert!(inverse(25, 100).is_none()); + assert!(inverse(28, 100).is_none()); + assert_eq!(inverse(1, 107), Some(1)); + assert_eq!(inverse(4, 107), Some(27)); + assert_eq!(inverse(7, 107), Some(46)); + assert_eq!(inverse(10, 107), Some(75)); + assert_eq!(inverse(13, 107), Some(33)); + assert_eq!(inverse(16, 107), Some(87)); + assert_eq!(inverse(19, 107), Some(62)); + assert_eq!(inverse(22, 107), Some(73)); + assert_eq!(inverse(25, 107), Some(30)); + assert_eq!(inverse(28, 107), Some(65)); + assert_eq!(inverse(1, 114), Some(1)); + assert!(inverse(4, 114).is_none()); + assert_eq!(inverse(7, 114), Some(49)); + assert!(inverse(10, 114).is_none()); + assert_eq!(inverse(13, 114), Some(79)); + assert!(inverse(16, 114).is_none()); + assert!(inverse(19, 114).is_none()); + assert!(inverse(22, 114).is_none()); + assert_eq!(inverse(25, 114), Some(73)); + assert!(inverse(28, 114).is_none()); + assert_eq!(inverse(1, 121), Some(1)); + assert_eq!(inverse(4, 121), Some(91)); + assert_eq!(inverse(7, 121), Some(52)); + assert_eq!(inverse(10, 121), Some(109)); + assert_eq!(inverse(13, 121), Some(28)); + assert_eq!(inverse(16, 121), Some(53)); + assert_eq!(inverse(19, 121), Some(51)); + assert!(inverse(22, 121).is_none()); + assert_eq!(inverse(25, 121), Some(92)); + assert_eq!(inverse(28, 121), Some(13)); + assert_eq!(inverse(1, 128), Some(1)); + assert!(inverse(4, 128).is_none()); + assert_eq!(inverse(7, 128), Some(55)); + assert!(inverse(10, 128).is_none()); + assert_eq!(inverse(13, 128), Some(69)); + assert!(inverse(16, 128).is_none()); + assert_eq!(inverse(19, 128), Some(27)); + assert!(inverse(22, 128).is_none()); + assert_eq!(inverse(25, 128), Some(41)); + assert!(inverse(28, 128).is_none()); + assert_eq!(inverse(1, 135), Some(1)); + assert_eq!(inverse(4, 135), Some(34)); + assert_eq!(inverse(7, 135), Some(58)); + assert!(inverse(10, 135).is_none()); + assert_eq!(inverse(13, 135), Some(52)); + assert_eq!(inverse(16, 135), Some(76)); + assert_eq!(inverse(19, 135), Some(64)); + assert_eq!(inverse(22, 135), Some(43)); + assert!(inverse(25, 135).is_none()); + assert_eq!(inverse(28, 135), Some(82)); + assert_eq!(inverse(1, 142), Some(1)); + assert!(inverse(4, 142).is_none()); + assert_eq!(inverse(7, 142), Some(61)); + assert!(inverse(10, 142).is_none()); + assert_eq!(inverse(13, 142), Some(11)); + assert!(inverse(16, 142).is_none()); + assert_eq!(inverse(19, 142), Some(15)); + assert!(inverse(22, 142).is_none()); + assert_eq!(inverse(25, 142), Some(125)); + assert!(inverse(28, 142).is_none()); + assert_eq!(inverse(1, 149), Some(1)); + assert_eq!(inverse(4, 149), Some(112)); + assert_eq!(inverse(7, 149), Some(64)); + assert_eq!(inverse(10, 149), Some(15)); + assert_eq!(inverse(13, 149), Some(23)); + assert_eq!(inverse(16, 149), Some(28)); + assert_eq!(inverse(19, 149), Some(102)); + assert_eq!(inverse(22, 149), Some(61)); + assert_eq!(inverse(25, 149), Some(6)); + assert_eq!(inverse(28, 149), Some(16)); + assert_eq!(inverse(1, 156), Some(1)); + assert!(inverse(4, 156).is_none()); + assert_eq!(inverse(7, 156), Some(67)); + assert!(inverse(10, 156).is_none()); + assert!(inverse(13, 156).is_none()); + assert!(inverse(16, 156).is_none()); + assert_eq!(inverse(19, 156), Some(115)); + assert!(inverse(22, 156).is_none()); + assert_eq!(inverse(25, 156), Some(25)); + assert!(inverse(28, 156).is_none()); + assert_eq!(inverse(1, 163), Some(1)); + assert_eq!(inverse(4, 163), Some(41)); + assert_eq!(inverse(7, 163), Some(70)); + assert_eq!(inverse(10, 163), Some(49)); + assert_eq!(inverse(13, 163), Some(138)); + assert_eq!(inverse(16, 163), Some(51)); + assert_eq!(inverse(19, 163), Some(103)); + assert_eq!(inverse(22, 163), Some(126)); + assert_eq!(inverse(25, 163), Some(150)); + assert_eq!(inverse(28, 163), Some(99)); + assert_eq!(inverse(1, 170), Some(1)); + assert!(inverse(4, 170).is_none()); + assert_eq!(inverse(7, 170), Some(73)); + assert!(inverse(10, 170).is_none()); + assert_eq!(inverse(13, 170), Some(157)); + assert!(inverse(16, 170).is_none()); + assert_eq!(inverse(19, 170), Some(9)); + assert!(inverse(22, 170).is_none()); + assert!(inverse(25, 170).is_none()); + assert!(inverse(28, 170).is_none()); + assert_eq!(inverse(1, 177), Some(1)); + assert_eq!(inverse(4, 177), Some(133)); + assert_eq!(inverse(7, 177), Some(76)); + assert_eq!(inverse(10, 177), Some(124)); + assert_eq!(inverse(13, 177), Some(109)); + assert_eq!(inverse(16, 177), Some(166)); + assert_eq!(inverse(19, 177), Some(28)); + assert_eq!(inverse(22, 177), Some(169)); + assert_eq!(inverse(25, 177), Some(85)); + assert_eq!(inverse(28, 177), Some(19)); + assert_eq!(inverse(1, 184), Some(1)); + assert!(inverse(4, 184).is_none()); + assert_eq!(inverse(7, 184), Some(79)); + assert!(inverse(10, 184).is_none()); + assert_eq!(inverse(13, 184), Some(85)); + assert!(inverse(16, 184).is_none()); + assert_eq!(inverse(19, 184), Some(155)); + assert!(inverse(22, 184).is_none()); + assert_eq!(inverse(25, 184), Some(81)); + assert!(inverse(28, 184).is_none()); + assert_eq!(inverse(1, 191), Some(1)); + assert_eq!(inverse(4, 191), Some(48)); + assert_eq!(inverse(7, 191), Some(82)); + assert_eq!(inverse(10, 191), Some(172)); + assert_eq!(inverse(13, 191), Some(147)); + assert_eq!(inverse(16, 191), Some(12)); + assert_eq!(inverse(19, 191), Some(181)); + assert_eq!(inverse(22, 191), Some(165)); + assert_eq!(inverse(25, 191), Some(107)); + assert_eq!(inverse(28, 191), Some(116)); + assert_eq!(inverse(1, 198), Some(1)); + assert!(inverse(4, 198).is_none()); + assert_eq!(inverse(7, 198), Some(85)); + assert!(inverse(10, 198).is_none()); + assert_eq!(inverse(13, 198), Some(61)); + assert!(inverse(16, 198).is_none()); + assert_eq!(inverse(19, 198), Some(73)); + assert!(inverse(22, 198).is_none()); + assert_eq!(inverse(25, 198), Some(103)); + assert!(inverse(28, 198).is_none()); + assert_eq!(inverse(1, 205), Some(1)); + assert_eq!(inverse(4, 205), Some(154)); + assert_eq!(inverse(7, 205), Some(88)); + assert!(inverse(10, 205).is_none()); + assert_eq!(inverse(13, 205), Some(142)); + assert_eq!(inverse(16, 205), Some(141)); + assert_eq!(inverse(19, 205), Some(54)); + assert_eq!(inverse(22, 205), Some(28)); + assert!(inverse(25, 205).is_none()); + assert_eq!(inverse(28, 205), Some(22)); + assert_eq!(inverse(1, 212), Some(1)); + assert!(inverse(4, 212).is_none()); + assert_eq!(inverse(7, 212), Some(91)); + assert!(inverse(10, 212).is_none()); + assert_eq!(inverse(13, 212), Some(49)); + assert!(inverse(16, 212).is_none()); + assert_eq!(inverse(19, 212), Some(67)); + assert!(inverse(22, 212).is_none()); + assert_eq!(inverse(25, 212), Some(17)); + assert!(inverse(28, 212).is_none()); + assert_eq!(inverse(1, 219), Some(1)); + assert_eq!(inverse(4, 219), Some(55)); + assert_eq!(inverse(7, 219), Some(94)); + assert_eq!(inverse(10, 219), Some(22)); + assert_eq!(inverse(13, 219), Some(118)); + assert_eq!(inverse(16, 219), Some(178)); + assert_eq!(inverse(19, 219), Some(196)); + assert_eq!(inverse(22, 219), Some(10)); + assert_eq!(inverse(25, 219), Some(184)); + assert_eq!(inverse(28, 219), Some(133)); + assert_eq!(inverse(1, 226), Some(1)); + assert!(inverse(4, 226).is_none()); + assert_eq!(inverse(7, 226), Some(97)); + assert!(inverse(10, 226).is_none()); + assert_eq!(inverse(13, 226), Some(87)); + assert!(inverse(16, 226).is_none()); + assert_eq!(inverse(19, 226), Some(119)); + assert!(inverse(22, 226).is_none()); + assert_eq!(inverse(25, 226), Some(217)); + assert!(inverse(28, 226).is_none()); + assert_eq!(inverse(1, 233), Some(1)); + assert_eq!(inverse(4, 233), Some(175)); + assert_eq!(inverse(7, 233), Some(100)); + assert_eq!(inverse(10, 233), Some(70)); + assert_eq!(inverse(13, 233), Some(18)); + assert_eq!(inverse(16, 233), Some(102)); + assert_eq!(inverse(19, 233), Some(184)); + assert_eq!(inverse(22, 233), Some(53)); + assert_eq!(inverse(25, 233), Some(28)); + assert_eq!(inverse(28, 233), Some(25)); + assert_eq!(inverse(1, 240), Some(1)); + assert!(inverse(4, 240).is_none()); + assert_eq!(inverse(7, 240), Some(103)); + assert!(inverse(10, 240).is_none()); + assert_eq!(inverse(13, 240), Some(37)); + assert!(inverse(16, 240).is_none()); + assert_eq!(inverse(19, 240), Some(139)); + assert!(inverse(22, 240).is_none()); + assert!(inverse(25, 240).is_none()); + assert!(inverse(28, 240).is_none()); + assert_eq!(inverse(1, 247), Some(1)); + assert_eq!(inverse(4, 247), Some(62)); + assert_eq!(inverse(7, 247), Some(106)); + assert_eq!(inverse(10, 247), Some(173)); + assert!(inverse(13, 247).is_none()); + assert_eq!(inverse(16, 247), Some(139)); + assert!(inverse(19, 247).is_none()); + assert_eq!(inverse(22, 247), Some(146)); + assert_eq!(inverse(25, 247), Some(168)); + assert_eq!(inverse(28, 247), Some(150)); + assert_eq!(inverse(1, 254), Some(1)); + assert!(inverse(4, 254).is_none()); + assert_eq!(inverse(7, 254), Some(109)); + assert!(inverse(10, 254).is_none()); + assert_eq!(inverse(13, 254), Some(215)); + assert!(inverse(16, 254).is_none()); + assert_eq!(inverse(19, 254), Some(107)); + assert!(inverse(22, 254).is_none()); + assert_eq!(inverse(25, 254), Some(61)); + assert!(inverse(28, 254).is_none()); + assert_eq!(inverse(1, 261), Some(1)); + assert_eq!(inverse(4, 261), Some(196)); + assert_eq!(inverse(7, 261), Some(112)); + assert_eq!(inverse(10, 261), Some(235)); + assert_eq!(inverse(13, 261), Some(241)); + assert_eq!(inverse(16, 261), Some(49)); + assert_eq!(inverse(19, 261), Some(55)); + assert_eq!(inverse(22, 261), Some(178)); + assert_eq!(inverse(25, 261), Some(94)); + assert_eq!(inverse(28, 261), Some(28)); + assert_eq!(inverse(1, 268), Some(1)); + assert!(inverse(4, 268).is_none()); + assert_eq!(inverse(7, 268), Some(115)); + assert!(inverse(10, 268).is_none()); + assert_eq!(inverse(13, 268), Some(165)); + assert!(inverse(16, 268).is_none()); + assert_eq!(inverse(19, 268), Some(127)); + assert!(inverse(22, 268).is_none()); + assert_eq!(inverse(25, 268), Some(193)); + assert!(inverse(28, 268).is_none()); + assert_eq!(inverse(1, 275), Some(1)); + assert_eq!(inverse(4, 275), Some(69)); + assert_eq!(inverse(7, 275), Some(118)); + assert!(inverse(10, 275).is_none()); + assert_eq!(inverse(13, 275), Some(127)); + assert_eq!(inverse(16, 275), Some(86)); + assert_eq!(inverse(19, 275), Some(29)); + assert!(inverse(22, 275).is_none()); + assert!(inverse(25, 275).is_none()); + assert_eq!(inverse(28, 275), Some(167)); + assert_eq!(inverse(1, 282), Some(1)); + assert!(inverse(4, 282).is_none()); + assert_eq!(inverse(7, 282), Some(121)); + assert!(inverse(10, 282).is_none()); + assert_eq!(inverse(13, 282), Some(217)); + assert!(inverse(16, 282).is_none()); + assert_eq!(inverse(19, 282), Some(193)); + assert!(inverse(22, 282).is_none()); + assert_eq!(inverse(25, 282), Some(79)); + assert!(inverse(28, 282).is_none()); + assert_eq!(inverse(1, 289), Some(1)); + assert_eq!(inverse(4, 289), Some(217)); + assert_eq!(inverse(7, 289), Some(124)); + assert_eq!(inverse(10, 289), Some(29)); + assert_eq!(inverse(13, 289), Some(89)); + assert_eq!(inverse(16, 289), Some(271)); + assert_eq!(inverse(19, 289), Some(213)); + assert_eq!(inverse(22, 289), Some(92)); + assert_eq!(inverse(25, 289), Some(185)); + assert_eq!(inverse(28, 289), Some(31)); + assert_eq!(inverse(1, 296), Some(1)); + assert!(inverse(4, 296).is_none()); + assert_eq!(inverse(7, 296), Some(127)); + assert!(inverse(10, 296).is_none()); + assert_eq!(inverse(13, 296), Some(205)); + assert!(inverse(16, 296).is_none()); + assert_eq!(inverse(19, 296), Some(187)); + assert!(inverse(22, 296).is_none()); + assert_eq!(inverse(25, 296), Some(225)); + assert!(inverse(28, 296).is_none()); + assert_eq!(inverse(1, 303), Some(1)); + assert_eq!(inverse(4, 303), Some(76)); + assert_eq!(inverse(7, 303), Some(130)); + assert_eq!(inverse(10, 303), Some(91)); + assert_eq!(inverse(13, 303), Some(70)); + assert_eq!(inverse(16, 303), Some(19)); + assert_eq!(inverse(19, 303), Some(16)); + assert_eq!(inverse(22, 303), Some(124)); + assert_eq!(inverse(25, 303), Some(97)); + assert_eq!(inverse(28, 303), Some(184)); + assert_eq!(inverse(1, 310), Some(1)); + assert!(inverse(4, 310).is_none()); + assert_eq!(inverse(7, 310), Some(133)); + assert!(inverse(10, 310).is_none()); + assert_eq!(inverse(13, 310), Some(167)); + assert!(inverse(16, 310).is_none()); + assert_eq!(inverse(19, 310), Some(49)); + assert!(inverse(22, 310).is_none()); + assert!(inverse(25, 310).is_none()); + assert!(inverse(28, 310).is_none()); + assert_eq!(inverse(1, 317), Some(1)); + assert_eq!(inverse(4, 317), Some(238)); + assert_eq!(inverse(7, 317), Some(136)); + assert_eq!(inverse(10, 317), Some(222)); + assert_eq!(inverse(13, 317), Some(122)); + assert_eq!(inverse(16, 317), Some(218)); + assert_eq!(inverse(19, 317), Some(267)); + assert_eq!(inverse(22, 317), Some(245)); + assert_eq!(inverse(25, 317), Some(279)); + assert_eq!(inverse(28, 317), Some(34)); + assert_eq!(inverse(1, 324), Some(1)); + assert!(inverse(4, 324).is_none()); + assert_eq!(inverse(7, 324), Some(139)); + assert!(inverse(10, 324).is_none()); + assert_eq!(inverse(13, 324), Some(25)); + assert!(inverse(16, 324).is_none()); + assert_eq!(inverse(19, 324), Some(307)); + assert!(inverse(22, 324).is_none()); + assert_eq!(inverse(25, 324), Some(13)); + assert!(inverse(28, 324).is_none()); + assert_eq!(inverse(1, 331), Some(1)); + assert_eq!(inverse(4, 331), Some(83)); + assert_eq!(inverse(7, 331), Some(142)); + assert_eq!(inverse(10, 331), Some(298)); + assert_eq!(inverse(13, 331), Some(51)); + assert_eq!(inverse(16, 331), Some(269)); + assert_eq!(inverse(19, 331), Some(122)); + assert_eq!(inverse(22, 331), Some(316)); + assert_eq!(inverse(25, 331), Some(53)); + assert_eq!(inverse(28, 331), Some(201)); + assert_eq!(inverse(1, 338), Some(1)); + assert!(inverse(4, 338).is_none()); + assert_eq!(inverse(7, 338), Some(145)); + assert!(inverse(10, 338).is_none()); + assert!(inverse(13, 338).is_none()); + assert!(inverse(16, 338).is_none()); + assert_eq!(inverse(19, 338), Some(89)); + assert!(inverse(22, 338).is_none()); + assert_eq!(inverse(25, 338), Some(311)); + assert!(inverse(28, 338).is_none()); + assert_eq!(inverse(1, 345), Some(1)); + assert_eq!(inverse(4, 345), Some(259)); + assert_eq!(inverse(7, 345), Some(148)); + assert!(inverse(10, 345).is_none()); + assert_eq!(inverse(13, 345), Some(292)); + assert_eq!(inverse(16, 345), Some(151)); + assert_eq!(inverse(19, 345), Some(109)); + assert_eq!(inverse(22, 345), Some(298)); + assert!(inverse(25, 345).is_none()); + assert_eq!(inverse(28, 345), Some(37)); + assert_eq!(inverse(1, 352), Some(1)); + assert!(inverse(4, 352).is_none()); + assert_eq!(inverse(7, 352), Some(151)); + assert!(inverse(10, 352).is_none()); + assert_eq!(inverse(13, 352), Some(325)); + assert!(inverse(16, 352).is_none()); + assert_eq!(inverse(19, 352), Some(315)); + assert!(inverse(22, 352).is_none()); + assert_eq!(inverse(25, 352), Some(169)); + assert!(inverse(28, 352).is_none()); + assert_eq!(inverse(1, 359), Some(1)); + assert_eq!(inverse(4, 359), Some(90)); + assert_eq!(inverse(7, 359), Some(154)); + assert_eq!(inverse(10, 359), Some(36)); + assert_eq!(inverse(13, 359), Some(221)); + assert_eq!(inverse(16, 359), Some(202)); + assert_eq!(inverse(19, 359), Some(189)); + assert_eq!(inverse(22, 359), Some(49)); + assert_eq!(inverse(25, 359), Some(158)); + assert_eq!(inverse(28, 359), Some(218)); + assert_eq!(inverse(1, 366), Some(1)); + assert!(inverse(4, 366).is_none()); + assert_eq!(inverse(7, 366), Some(157)); + assert!(inverse(10, 366).is_none()); + assert_eq!(inverse(13, 366), Some(169)); + assert!(inverse(16, 366).is_none()); + assert_eq!(inverse(19, 366), Some(289)); + assert!(inverse(22, 366).is_none()); + assert_eq!(inverse(25, 366), Some(205)); + assert!(inverse(28, 366).is_none()); + assert_eq!(inverse(1, 373), Some(1)); + assert_eq!(inverse(4, 373), Some(280)); + assert_eq!(inverse(7, 373), Some(160)); + assert_eq!(inverse(10, 373), Some(112)); + assert_eq!(inverse(13, 373), Some(287)); + assert_eq!(inverse(16, 373), Some(70)); + assert_eq!(inverse(19, 373), Some(216)); + assert_eq!(inverse(22, 373), Some(17)); + assert_eq!(inverse(25, 373), Some(194)); + assert_eq!(inverse(28, 373), Some(40)); + assert_eq!(inverse(1, 380), Some(1)); + assert!(inverse(4, 380).is_none()); + assert_eq!(inverse(7, 380), Some(163)); + assert!(inverse(10, 380).is_none()); + assert_eq!(inverse(13, 380), Some(117)); + assert!(inverse(16, 380).is_none()); + assert!(inverse(19, 380).is_none()); + assert!(inverse(22, 380).is_none()); + assert!(inverse(25, 380).is_none()); + assert!(inverse(28, 380).is_none()); + assert_eq!(inverse(1, 387), Some(1)); + assert_eq!(inverse(4, 387), Some(97)); + assert_eq!(inverse(7, 387), Some(166)); + assert_eq!(inverse(10, 387), Some(271)); + assert_eq!(inverse(13, 387), Some(268)); + assert_eq!(inverse(16, 387), Some(121)); + assert_eq!(inverse(19, 387), Some(163)); + assert_eq!(inverse(22, 387), Some(88)); + assert_eq!(inverse(25, 387), Some(31)); + assert_eq!(inverse(28, 387), Some(235)); + assert_eq!(inverse(1, 394), Some(1)); + assert!(inverse(4, 394).is_none()); + assert_eq!(inverse(7, 394), Some(169)); + assert!(inverse(10, 394).is_none()); + assert_eq!(inverse(13, 394), Some(91)); + assert!(inverse(16, 394).is_none()); + assert_eq!(inverse(19, 394), Some(83)); + assert!(inverse(22, 394).is_none()); + assert_eq!(inverse(25, 394), Some(331)); + assert!(inverse(28, 394).is_none()); + assert_eq!(inverse(1, 401), Some(1)); + assert_eq!(inverse(4, 401), Some(301)); + assert_eq!(inverse(7, 401), Some(172)); + assert_eq!(inverse(10, 401), Some(361)); + assert_eq!(inverse(13, 401), Some(216)); + assert_eq!(inverse(16, 401), Some(376)); + assert_eq!(inverse(19, 401), Some(190)); + assert_eq!(inverse(22, 401), Some(237)); + assert_eq!(inverse(25, 401), Some(385)); + assert_eq!(inverse(28, 401), Some(43)); + assert_eq!(inverse(1, 408), Some(1)); + assert!(inverse(4, 408).is_none()); + assert_eq!(inverse(7, 408), Some(175)); + assert!(inverse(10, 408).is_none()); + assert_eq!(inverse(13, 408), Some(157)); + assert!(inverse(16, 408).is_none()); + assert_eq!(inverse(19, 408), Some(43)); + assert!(inverse(22, 408).is_none()); + assert_eq!(inverse(25, 408), Some(49)); + assert!(inverse(28, 408).is_none()); + assert_eq!(inverse(1, 415), Some(1)); + assert_eq!(inverse(4, 415), Some(104)); + assert_eq!(inverse(7, 415), Some(178)); + assert!(inverse(10, 415).is_none()); + assert_eq!(inverse(13, 415), Some(32)); + assert_eq!(inverse(16, 415), Some(26)); + assert_eq!(inverse(19, 415), Some(284)); + assert_eq!(inverse(22, 415), Some(283)); + assert!(inverse(25, 415).is_none()); + assert_eq!(inverse(28, 415), Some(252)); + assert_eq!(inverse(1, 422), Some(1)); + assert!(inverse(4, 422).is_none()); + assert_eq!(inverse(7, 422), Some(181)); + assert!(inverse(10, 422).is_none()); + assert_eq!(inverse(13, 422), Some(65)); + assert!(inverse(16, 422).is_none()); + assert_eq!(inverse(19, 422), Some(311)); + assert!(inverse(22, 422).is_none()); + assert_eq!(inverse(25, 422), Some(287)); + assert!(inverse(28, 422).is_none()); + assert_eq!(inverse(1, 429), Some(1)); + assert_eq!(inverse(4, 429), Some(322)); + assert_eq!(inverse(7, 429), Some(184)); + assert_eq!(inverse(10, 429), Some(43)); + assert!(inverse(13, 429).is_none()); + assert_eq!(inverse(16, 429), Some(295)); + assert_eq!(inverse(19, 429), Some(271)); + assert!(inverse(22, 429).is_none()); + assert_eq!(inverse(25, 429), Some(103)); + assert_eq!(inverse(28, 429), Some(46)); + assert_eq!(inverse(1, 436), Some(1)); + assert!(inverse(4, 436).is_none()); + assert_eq!(inverse(7, 436), Some(187)); + assert!(inverse(10, 436).is_none()); + assert_eq!(inverse(13, 436), Some(369)); + assert!(inverse(16, 436).is_none()); + assert_eq!(inverse(19, 436), Some(23)); + assert!(inverse(22, 436).is_none()); + assert_eq!(inverse(25, 436), Some(157)); + assert!(inverse(28, 436).is_none()); + assert_eq!(inverse(1, 443), Some(1)); + assert_eq!(inverse(4, 443), Some(111)); + assert_eq!(inverse(7, 443), Some(190)); + assert_eq!(inverse(10, 443), Some(133)); + assert_eq!(inverse(13, 443), Some(409)); + assert_eq!(inverse(16, 443), Some(360)); + assert_eq!(inverse(19, 443), Some(70)); + assert_eq!(inverse(22, 443), Some(141)); + assert_eq!(inverse(25, 443), Some(319)); + assert_eq!(inverse(28, 443), Some(269)); + assert_eq!(inverse(1, 450), Some(1)); + assert!(inverse(4, 450).is_none()); + assert_eq!(inverse(7, 450), Some(193)); + assert!(inverse(10, 450).is_none()); + assert_eq!(inverse(13, 450), Some(277)); + assert!(inverse(16, 450).is_none()); + assert_eq!(inverse(19, 450), Some(379)); + assert!(inverse(22, 450).is_none()); + assert!(inverse(25, 450).is_none()); + assert!(inverse(28, 450).is_none()); + assert_eq!(inverse(1, 457), Some(1)); + assert_eq!(inverse(4, 457), Some(343)); + assert_eq!(inverse(7, 457), Some(196)); + assert_eq!(inverse(10, 457), Some(320)); + assert_eq!(inverse(13, 457), Some(211)); + assert_eq!(inverse(16, 457), Some(200)); + assert_eq!(inverse(19, 457), Some(433)); + assert_eq!(inverse(22, 457), Some(187)); + assert_eq!(inverse(25, 457), Some(128)); + assert_eq!(inverse(28, 457), Some(49)); + assert_eq!(inverse(1, 464), Some(1)); + assert!(inverse(4, 464).is_none()); + assert_eq!(inverse(7, 464), Some(199)); + assert!(inverse(10, 464).is_none()); + assert_eq!(inverse(13, 464), Some(357)); + assert!(inverse(16, 464).is_none()); + assert_eq!(inverse(19, 464), Some(171)); + assert!(inverse(22, 464).is_none()); + assert_eq!(inverse(25, 464), Some(297)); + assert!(inverse(28, 464).is_none()); + assert_eq!(inverse(1, 471), Some(1)); + assert_eq!(inverse(4, 471), Some(118)); + assert_eq!(inverse(7, 471), Some(202)); + assert_eq!(inverse(10, 471), Some(424)); + assert_eq!(inverse(13, 471), Some(145)); + assert_eq!(inverse(16, 471), Some(265)); + assert_eq!(inverse(19, 471), Some(124)); + assert_eq!(inverse(22, 471), Some(364)); + assert_eq!(inverse(25, 471), Some(358)); + assert_eq!(inverse(28, 471), Some(286)); + assert_eq!(inverse(1, 478), Some(1)); + assert!(inverse(4, 478).is_none()); + assert_eq!(inverse(7, 478), Some(205)); + assert!(inverse(10, 478).is_none()); + assert_eq!(inverse(13, 478), Some(331)); + assert!(inverse(16, 478).is_none()); + assert_eq!(inverse(19, 478), Some(151)); + assert!(inverse(22, 478).is_none()); + assert_eq!(inverse(25, 478), Some(153)); + assert!(inverse(28, 478).is_none()); + assert_eq!(inverse(1, 485), Some(1)); + assert_eq!(inverse(4, 485), Some(364)); + assert_eq!(inverse(7, 485), Some(208)); + assert!(inverse(10, 485).is_none()); + assert_eq!(inverse(13, 485), Some(112)); + assert_eq!(inverse(16, 485), Some(91)); + assert_eq!(inverse(19, 485), Some(434)); + assert_eq!(inverse(22, 485), Some(463)); + assert!(inverse(25, 485).is_none()); + assert_eq!(inverse(28, 485), Some(52)); + assert_eq!(inverse(1, 492), Some(1)); + assert!(inverse(4, 492).is_none()); + assert_eq!(inverse(7, 492), Some(211)); + assert!(inverse(10, 492).is_none()); + assert_eq!(inverse(13, 492), Some(265)); + assert!(inverse(16, 492).is_none()); + assert_eq!(inverse(19, 492), Some(259)); + assert!(inverse(22, 492).is_none()); + assert_eq!(inverse(25, 492), Some(433)); + assert!(inverse(28, 492).is_none()); + assert_eq!(inverse(1, 499), Some(1)); + assert_eq!(inverse(4, 499), Some(125)); + assert_eq!(inverse(7, 499), Some(214)); + assert_eq!(inverse(10, 499), Some(50)); + assert_eq!(inverse(13, 499), Some(192)); + assert_eq!(inverse(16, 499), Some(156)); + assert_eq!(inverse(19, 499), Some(394)); + assert_eq!(inverse(22, 499), Some(431)); + assert_eq!(inverse(25, 499), Some(20)); + assert_eq!(inverse(28, 499), Some(303)); + assert_eq!(inverse(1, 506), Some(1)); + assert!(inverse(4, 506).is_none()); + assert_eq!(inverse(7, 506), Some(217)); + assert!(inverse(10, 506).is_none()); + assert_eq!(inverse(13, 506), Some(39)); + assert!(inverse(16, 506).is_none()); + assert_eq!(inverse(19, 506), Some(293)); + assert!(inverse(22, 506).is_none()); + assert_eq!(inverse(25, 506), Some(81)); + assert!(inverse(28, 506).is_none()); + assert_eq!(inverse(1, 513), Some(1)); + assert_eq!(inverse(4, 513), Some(385)); + assert_eq!(inverse(7, 513), Some(220)); + assert_eq!(inverse(10, 513), Some(154)); + assert_eq!(inverse(13, 513), Some(79)); + assert_eq!(inverse(16, 513), Some(481)); + assert!(inverse(19, 513).is_none()); + assert_eq!(inverse(22, 513), Some(70)); + assert_eq!(inverse(25, 513), Some(472)); + assert_eq!(inverse(28, 513), Some(55)); + assert_eq!(inverse(1, 520), Some(1)); + assert!(inverse(4, 520).is_none()); + assert_eq!(inverse(7, 520), Some(223)); + assert!(inverse(10, 520).is_none()); + assert!(inverse(13, 520).is_none()); + assert!(inverse(16, 520).is_none()); + assert_eq!(inverse(19, 520), Some(219)); + assert!(inverse(22, 520).is_none()); + assert!(inverse(25, 520).is_none()); + assert!(inverse(28, 520).is_none()); + assert_eq!(inverse(1, 527), Some(1)); + assert_eq!(inverse(4, 527), Some(132)); + assert_eq!(inverse(7, 527), Some(226)); + assert_eq!(inverse(10, 527), Some(369)); + assert_eq!(inverse(13, 527), Some(446)); + assert_eq!(inverse(16, 527), Some(33)); + assert_eq!(inverse(19, 527), Some(111)); + assert_eq!(inverse(22, 527), Some(24)); + assert_eq!(inverse(25, 527), Some(253)); + assert_eq!(inverse(28, 527), Some(320)); + assert_eq!(inverse(1, 534), Some(1)); + assert!(inverse(4, 534).is_none()); + assert_eq!(inverse(7, 534), Some(229)); + assert!(inverse(10, 534).is_none()); + assert_eq!(inverse(13, 534), Some(493)); + assert!(inverse(16, 534).is_none()); + assert_eq!(inverse(19, 534), Some(253)); + assert!(inverse(22, 534).is_none()); + assert_eq!(inverse(25, 534), Some(235)); + assert!(inverse(28, 534).is_none()); + assert_eq!(inverse(1, 541), Some(1)); + assert_eq!(inverse(4, 541), Some(406)); + assert_eq!(inverse(7, 541), Some(232)); + assert_eq!(inverse(10, 541), Some(487)); + assert_eq!(inverse(13, 541), Some(333)); + assert_eq!(inverse(16, 541), Some(372)); + assert_eq!(inverse(19, 541), Some(57)); + assert_eq!(inverse(22, 541), Some(123)); + assert_eq!(inverse(25, 541), Some(303)); + assert_eq!(inverse(28, 541), Some(58)); + assert_eq!(inverse(1, 548), Some(1)); + assert!(inverse(4, 548).is_none()); + assert_eq!(inverse(7, 548), Some(235)); + assert!(inverse(10, 548).is_none()); + assert_eq!(inverse(13, 548), Some(253)); + assert!(inverse(16, 548).is_none()); + assert_eq!(inverse(19, 548), Some(375)); + assert!(inverse(22, 548).is_none()); + assert_eq!(inverse(25, 548), Some(285)); + assert!(inverse(28, 548).is_none()); + assert_eq!(inverse(1, 555), Some(1)); + assert_eq!(inverse(4, 555), Some(139)); + assert_eq!(inverse(7, 555), Some(238)); + assert!(inverse(10, 555).is_none()); + assert_eq!(inverse(13, 555), Some(427)); + assert_eq!(inverse(16, 555), Some(451)); + assert_eq!(inverse(19, 555), Some(409)); + assert_eq!(inverse(22, 555), Some(328)); + assert!(inverse(25, 555).is_none()); + assert_eq!(inverse(28, 555), Some(337)); + assert_eq!(inverse(1, 562), Some(1)); + assert!(inverse(4, 562).is_none()); + assert_eq!(inverse(7, 562), Some(241)); + assert!(inverse(10, 562).is_none()); + assert_eq!(inverse(13, 562), Some(173)); + assert!(inverse(16, 562).is_none()); + assert_eq!(inverse(19, 562), Some(355)); + assert!(inverse(22, 562).is_none()); + assert_eq!(inverse(25, 562), Some(45)); + assert!(inverse(28, 562).is_none()); + assert_eq!(inverse(1, 569), Some(1)); + assert_eq!(inverse(4, 569), Some(427)); + assert_eq!(inverse(7, 569), Some(244)); + assert_eq!(inverse(10, 569), Some(57)); + assert_eq!(inverse(13, 569), Some(394)); + assert_eq!(inverse(16, 569), Some(249)); + assert_eq!(inverse(19, 569), Some(30)); + assert_eq!(inverse(22, 569), Some(388)); + assert_eq!(inverse(25, 569), Some(478)); + assert_eq!(inverse(28, 569), Some(61)); + assert_eq!(inverse(1, 576), Some(1)); + assert!(inverse(4, 576).is_none()); + assert_eq!(inverse(7, 576), Some(247)); + assert!(inverse(10, 576).is_none()); + assert_eq!(inverse(13, 576), Some(133)); + assert!(inverse(16, 576).is_none()); + assert_eq!(inverse(19, 576), Some(91)); + assert!(inverse(22, 576).is_none()); + assert_eq!(inverse(25, 576), Some(553)); + assert!(inverse(28, 576).is_none()); + assert_eq!(inverse(1, 583), Some(1)); + assert_eq!(inverse(4, 583), Some(146)); + assert_eq!(inverse(7, 583), Some(250)); + assert_eq!(inverse(10, 583), Some(175)); + assert_eq!(inverse(13, 583), Some(314)); + assert_eq!(inverse(16, 583), Some(328)); + assert_eq!(inverse(19, 583), Some(491)); + assert!(inverse(22, 583).is_none()); + assert_eq!(inverse(25, 583), Some(70)); + assert_eq!(inverse(28, 583), Some(354)); + assert_eq!(inverse(1, 590), Some(1)); + assert!(inverse(4, 590).is_none()); + assert_eq!(inverse(7, 590), Some(253)); + assert!(inverse(10, 590).is_none()); + assert_eq!(inverse(13, 590), Some(227)); + assert!(inverse(16, 590).is_none()); + assert_eq!(inverse(19, 590), Some(559)); + assert!(inverse(22, 590).is_none()); + assert!(inverse(25, 590).is_none()); + assert!(inverse(28, 590).is_none()); + assert_eq!(inverse(1, 597), Some(1)); + assert_eq!(inverse(4, 597), Some(448)); + assert_eq!(inverse(7, 597), Some(256)); + assert_eq!(inverse(10, 597), Some(418)); + assert_eq!(inverse(13, 597), Some(46)); + assert_eq!(inverse(16, 597), Some(112)); + assert_eq!(inverse(19, 597), Some(220)); + assert_eq!(inverse(22, 597), Some(190)); + assert_eq!(inverse(25, 597), Some(406)); + assert_eq!(inverse(28, 597), Some(64)); + assert_eq!(inverse(1, 604), Some(1)); + assert!(inverse(4, 604).is_none()); + assert_eq!(inverse(7, 604), Some(259)); + assert!(inverse(10, 604).is_none()); + assert_eq!(inverse(13, 604), Some(93)); + assert!(inverse(16, 604).is_none()); + assert_eq!(inverse(19, 604), Some(159)); + assert!(inverse(22, 604).is_none()); + assert_eq!(inverse(25, 604), Some(145)); + assert!(inverse(28, 604).is_none()); + assert_eq!(inverse(1, 611), Some(1)); + assert_eq!(inverse(4, 611), Some(153)); + assert_eq!(inverse(7, 611), Some(262)); + assert_eq!(inverse(10, 611), Some(550)); + assert!(inverse(13, 611).is_none()); + assert_eq!(inverse(16, 611), Some(191)); + assert_eq!(inverse(19, 611), Some(193)); + assert_eq!(inverse(22, 611), Some(250)); + assert_eq!(inverse(25, 611), Some(220)); + assert_eq!(inverse(28, 611), Some(371)); + assert_eq!(inverse(1, 618), Some(1)); + assert!(inverse(4, 618).is_none()); + assert_eq!(inverse(7, 618), Some(265)); + assert!(inverse(10, 618).is_none()); + assert_eq!(inverse(13, 618), Some(523)); + assert!(inverse(16, 618).is_none()); + assert_eq!(inverse(19, 618), Some(553)); + assert!(inverse(22, 618).is_none()); + assert_eq!(inverse(25, 618), Some(445)); + assert!(inverse(28, 618).is_none()); + assert_eq!(inverse(1, 625), Some(1)); + assert_eq!(inverse(4, 625), Some(469)); + assert_eq!(inverse(7, 625), Some(268)); + assert!(inverse(10, 625).is_none()); + assert_eq!(inverse(13, 625), Some(577)); + assert_eq!(inverse(16, 625), Some(586)); + assert_eq!(inverse(19, 625), Some(329)); + assert_eq!(inverse(22, 625), Some(483)); + assert!(inverse(25, 625).is_none()); + assert_eq!(inverse(28, 625), Some(67)); + assert_eq!(inverse(1, 632), Some(1)); + assert!(inverse(4, 632).is_none()); + assert_eq!(inverse(7, 632), Some(271)); + assert!(inverse(10, 632).is_none()); + assert_eq!(inverse(13, 632), Some(389)); + assert!(inverse(16, 632).is_none()); + assert_eq!(inverse(19, 632), Some(499)); + assert!(inverse(22, 632).is_none()); + assert_eq!(inverse(25, 632), Some(177)); + assert!(inverse(28, 632).is_none()); + assert_eq!(inverse(1, 639), Some(1)); + assert_eq!(inverse(4, 639), Some(160)); + assert_eq!(inverse(7, 639), Some(274)); + assert_eq!(inverse(10, 639), Some(64)); + assert_eq!(inverse(13, 639), Some(295)); + assert_eq!(inverse(16, 639), Some(40)); + assert_eq!(inverse(19, 639), Some(370)); + assert_eq!(inverse(22, 639), Some(610)); + assert_eq!(inverse(25, 639), Some(409)); + assert_eq!(inverse(28, 639), Some(388)); + assert_eq!(inverse(1, 646), Some(1)); + assert!(inverse(4, 646).is_none()); + assert_eq!(inverse(7, 646), Some(277)); + assert!(inverse(10, 646).is_none()); + assert_eq!(inverse(13, 646), Some(497)); + assert!(inverse(16, 646).is_none()); + assert!(inverse(19, 646).is_none()); + assert!(inverse(22, 646).is_none()); + assert_eq!(inverse(25, 646), Some(491)); + assert!(inverse(28, 646).is_none()); + assert_eq!(inverse(1, 653), Some(1)); + assert_eq!(inverse(4, 653), Some(490)); + assert_eq!(inverse(7, 653), Some(280)); + assert_eq!(inverse(10, 653), Some(196)); + assert_eq!(inverse(13, 653), Some(201)); + assert_eq!(inverse(16, 653), Some(449)); + assert_eq!(inverse(19, 653), Some(275)); + assert_eq!(inverse(22, 653), Some(564)); + assert_eq!(inverse(25, 653), Some(209)); + assert_eq!(inverse(28, 653), Some(70)); + assert_eq!(inverse(1, 660), Some(1)); + assert!(inverse(4, 660).is_none()); + assert_eq!(inverse(7, 660), Some(283)); + assert!(inverse(10, 660).is_none()); + assert_eq!(inverse(13, 660), Some(457)); + assert!(inverse(16, 660).is_none()); + assert_eq!(inverse(19, 660), Some(139)); + assert!(inverse(22, 660).is_none()); + assert!(inverse(25, 660).is_none()); + assert!(inverse(28, 660).is_none()); + assert_eq!(inverse(1, 667), Some(1)); + assert_eq!(inverse(4, 667), Some(167)); + assert_eq!(inverse(7, 667), Some(286)); + assert_eq!(inverse(10, 667), Some(467)); + assert_eq!(inverse(13, 667), Some(154)); + assert_eq!(inverse(16, 667), Some(542)); + assert_eq!(inverse(19, 667), Some(316)); + assert_eq!(inverse(22, 667), Some(91)); + assert_eq!(inverse(25, 667), Some(587)); + assert_eq!(inverse(28, 667), Some(405)); + assert_eq!(inverse(1, 674), Some(1)); + assert!(inverse(4, 674).is_none()); + assert_eq!(inverse(7, 674), Some(289)); + assert!(inverse(10, 674).is_none()); + assert_eq!(inverse(13, 674), Some(363)); + assert!(inverse(16, 674).is_none()); + assert_eq!(inverse(19, 674), Some(71)); + assert!(inverse(22, 674).is_none()); + assert_eq!(inverse(25, 674), Some(27)); + assert!(inverse(28, 674).is_none()); + assert_eq!(inverse(1, 681), Some(1)); + assert_eq!(inverse(4, 681), Some(511)); + assert_eq!(inverse(7, 681), Some(292)); + assert_eq!(inverse(10, 681), Some(613)); + assert_eq!(inverse(13, 681), Some(262)); + assert_eq!(inverse(16, 681), Some(298)); + assert_eq!(inverse(19, 681), Some(466)); + assert_eq!(inverse(22, 681), Some(31)); + assert_eq!(inverse(25, 681), Some(109)); + assert_eq!(inverse(28, 681), Some(73)); + assert_eq!(inverse(1, 688), Some(1)); + assert!(inverse(4, 688).is_none()); + assert_eq!(inverse(7, 688), Some(295)); + assert!(inverse(10, 688).is_none()); + assert_eq!(inverse(13, 688), Some(53)); + assert!(inverse(16, 688).is_none()); + assert_eq!(inverse(19, 688), Some(507)); + assert!(inverse(22, 688).is_none()); + assert_eq!(inverse(25, 688), Some(633)); + assert!(inverse(28, 688).is_none()); + assert_eq!(inverse(1, 695), Some(1)); + assert_eq!(inverse(4, 695), Some(174)); + assert_eq!(inverse(7, 695), Some(298)); + assert!(inverse(10, 695).is_none()); + assert_eq!(inverse(13, 695), Some(107)); + assert_eq!(inverse(16, 695), Some(391)); + assert_eq!(inverse(19, 695), Some(439)); + assert_eq!(inverse(22, 695), Some(158)); + assert!(inverse(25, 695).is_none()); + assert_eq!(inverse(28, 695), Some(422)); + assert_eq!(inverse(1, 702), Some(1)); + assert!(inverse(4, 702).is_none()); + assert_eq!(inverse(7, 702), Some(301)); + assert!(inverse(10, 702).is_none()); + assert!(inverse(13, 702).is_none()); + assert!(inverse(16, 702).is_none()); + assert_eq!(inverse(19, 702), Some(37)); + assert!(inverse(22, 702).is_none()); + assert_eq!(inverse(25, 702), Some(337)); + assert!(inverse(28, 702).is_none()); + assert_eq!(inverse(1, 709), Some(1)); + assert_eq!(inverse(4, 709), Some(532)); + assert_eq!(inverse(7, 709), Some(304)); + assert_eq!(inverse(10, 709), Some(71)); + assert_eq!(inverse(13, 709), Some(600)); + assert_eq!(inverse(16, 709), Some(133)); + assert_eq!(inverse(19, 709), Some(112)); + assert_eq!(inverse(22, 709), Some(419)); + assert_eq!(inverse(25, 709), Some(312)); + assert_eq!(inverse(28, 709), Some(76)); + assert_eq!(inverse(1, 716), Some(1)); + assert!(inverse(4, 716).is_none()); + assert_eq!(inverse(7, 716), Some(307)); + assert!(inverse(10, 716).is_none()); + assert_eq!(inverse(13, 716), Some(661)); + assert!(inverse(16, 716).is_none()); + assert_eq!(inverse(19, 716), Some(603)); + assert!(inverse(22, 716).is_none()); + assert_eq!(inverse(25, 716), Some(401)); + assert!(inverse(28, 716).is_none()); + assert_eq!(inverse(1, 723), Some(1)); + assert_eq!(inverse(4, 723), Some(181)); + assert_eq!(inverse(7, 723), Some(310)); + assert_eq!(inverse(10, 723), Some(217)); + assert_eq!(inverse(13, 723), Some(445)); + assert_eq!(inverse(16, 723), Some(226)); + assert_eq!(inverse(19, 723), Some(685)); + assert_eq!(inverse(22, 723), Some(493)); + assert_eq!(inverse(25, 723), Some(376)); + assert_eq!(inverse(28, 723), Some(439)); + assert_eq!(inverse(1, 730), Some(1)); + assert!(inverse(4, 730).is_none()); + assert_eq!(inverse(7, 730), Some(313)); + assert!(inverse(10, 730).is_none()); + assert_eq!(inverse(13, 730), Some(337)); + assert!(inverse(16, 730).is_none()); + assert_eq!(inverse(19, 730), Some(269)); + assert!(inverse(22, 730).is_none()); + assert!(inverse(25, 730).is_none()); + assert!(inverse(28, 730).is_none()); + assert_eq!(inverse(1, 737), Some(1)); + assert_eq!(inverse(4, 737), Some(553)); + assert_eq!(inverse(7, 737), Some(316)); + assert_eq!(inverse(10, 737), Some(516)); + assert_eq!(inverse(13, 737), Some(567)); + assert_eq!(inverse(16, 737), Some(691)); + assert_eq!(inverse(19, 737), Some(194)); + assert!(inverse(22, 737).is_none()); + assert_eq!(inverse(25, 737), Some(59)); + assert_eq!(inverse(28, 737), Some(79)); + assert_eq!(inverse(1, 744), Some(1)); + assert!(inverse(4, 744).is_none()); + assert_eq!(inverse(7, 744), Some(319)); + assert!(inverse(10, 744).is_none()); + assert_eq!(inverse(13, 744), Some(229)); + assert!(inverse(16, 744).is_none()); + assert_eq!(inverse(19, 744), Some(235)); + assert!(inverse(22, 744).is_none()); + assert_eq!(inverse(25, 744), Some(625)); + assert!(inverse(28, 744).is_none()); + assert_eq!(inverse(1, 751), Some(1)); + assert_eq!(inverse(4, 751), Some(188)); + assert_eq!(inverse(7, 751), Some(322)); + assert_eq!(inverse(10, 751), Some(676)); + assert_eq!(inverse(13, 751), Some(520)); + assert_eq!(inverse(16, 751), Some(47)); + assert_eq!(inverse(19, 751), Some(672)); + assert_eq!(inverse(22, 751), Some(239)); + assert_eq!(inverse(25, 751), Some(721)); + assert_eq!(inverse(28, 751), Some(456)); + assert_eq!(inverse(1, 758), Some(1)); + assert!(inverse(4, 758).is_none()); + assert_eq!(inverse(7, 758), Some(325)); + assert!(inverse(10, 758).is_none()); + assert_eq!(inverse(13, 758), Some(175)); + assert!(inverse(16, 758).is_none()); + assert_eq!(inverse(19, 758), Some(399)); + assert!(inverse(22, 758).is_none()); + assert_eq!(inverse(25, 758), Some(91)); + assert!(inverse(28, 758).is_none()); + assert_eq!(inverse(1, 765), Some(1)); + assert_eq!(inverse(4, 765), Some(574)); + assert_eq!(inverse(7, 765), Some(328)); + assert!(inverse(10, 765).is_none()); + assert_eq!(inverse(13, 765), Some(412)); + assert_eq!(inverse(16, 765), Some(526)); + assert_eq!(inverse(19, 765), Some(604)); + assert_eq!(inverse(22, 765), Some(313)); + assert!(inverse(25, 765).is_none()); + assert_eq!(inverse(28, 765), Some(82)); + assert_eq!(inverse(1, 772), Some(1)); + assert!(inverse(4, 772).is_none()); + assert_eq!(inverse(7, 772), Some(331)); + assert!(inverse(10, 772).is_none()); + assert_eq!(inverse(13, 772), Some(297)); + assert!(inverse(16, 772).is_none()); + assert_eq!(inverse(19, 772), Some(447)); + assert!(inverse(22, 772).is_none()); + assert_eq!(inverse(25, 772), Some(525)); + assert!(inverse(28, 772).is_none()); + assert_eq!(inverse(1, 779), Some(1)); + assert_eq!(inverse(4, 779), Some(195)); + assert_eq!(inverse(7, 779), Some(334)); + assert_eq!(inverse(10, 779), Some(78)); + assert_eq!(inverse(13, 779), Some(60)); + assert_eq!(inverse(16, 779), Some(633)); + assert!(inverse(19, 779).is_none()); + assert_eq!(inverse(22, 779), Some(602)); + assert_eq!(inverse(25, 779), Some(187)); + assert_eq!(inverse(28, 779), Some(473)); + assert_eq!(inverse(1, 786), Some(1)); + assert!(inverse(4, 786).is_none()); + assert_eq!(inverse(7, 786), Some(337)); + assert!(inverse(10, 786).is_none()); + assert_eq!(inverse(13, 786), Some(121)); + assert!(inverse(16, 786).is_none()); + assert_eq!(inverse(19, 786), Some(331)); + assert!(inverse(22, 786).is_none()); + assert_eq!(inverse(25, 786), Some(283)); + assert!(inverse(28, 786).is_none()); + assert_eq!(inverse(1, 793), Some(1)); + assert_eq!(inverse(4, 793), Some(595)); + assert_eq!(inverse(7, 793), Some(340)); + assert_eq!(inverse(10, 793), Some(238)); + assert!(inverse(13, 793).is_none()); + assert_eq!(inverse(16, 793), Some(347)); + assert_eq!(inverse(19, 793), Some(167)); + assert_eq!(inverse(22, 793), Some(757)); + assert_eq!(inverse(25, 793), Some(571)); + assert_eq!(inverse(28, 793), Some(85)); + assert_eq!(inverse(1, 800), Some(1)); + assert!(inverse(4, 800).is_none()); + assert_eq!(inverse(7, 800), Some(343)); + assert!(inverse(10, 800).is_none()); + assert_eq!(inverse(13, 800), Some(677)); + assert!(inverse(16, 800).is_none()); + assert_eq!(inverse(19, 800), Some(379)); + assert!(inverse(22, 800).is_none()); + assert!(inverse(25, 800).is_none()); + assert!(inverse(28, 800).is_none()); + assert_eq!(inverse(1, 807), Some(1)); + assert_eq!(inverse(4, 807), Some(202)); + assert_eq!(inverse(7, 807), Some(346)); + assert_eq!(inverse(10, 807), Some(565)); + assert_eq!(inverse(13, 807), Some(745)); + assert_eq!(inverse(16, 807), Some(454)); + assert_eq!(inverse(19, 807), Some(85)); + assert_eq!(inverse(22, 807), Some(697)); + assert_eq!(inverse(25, 807), Some(226)); + assert_eq!(inverse(28, 807), Some(490)); + assert_eq!(inverse(1, 814), Some(1)); + assert!(inverse(4, 814).is_none()); + assert_eq!(inverse(7, 814), Some(349)); + assert!(inverse(10, 814).is_none()); + assert_eq!(inverse(13, 814), Some(501)); + assert!(inverse(16, 814).is_none()); + assert_eq!(inverse(19, 814), Some(557)); + assert!(inverse(22, 814).is_none()); + assert_eq!(inverse(25, 814), Some(521)); + assert!(inverse(28, 814).is_none()); + assert_eq!(inverse(1, 821), Some(1)); + assert_eq!(inverse(4, 821), Some(616)); + assert_eq!(inverse(7, 821), Some(352)); + assert_eq!(inverse(10, 821), Some(739)); + assert_eq!(inverse(13, 821), Some(379)); + assert_eq!(inverse(16, 821), Some(154)); + assert_eq!(inverse(19, 821), Some(605)); + assert_eq!(inverse(22, 821), Some(112)); + assert_eq!(inverse(25, 821), Some(624)); + assert_eq!(inverse(28, 821), Some(88)); + assert_eq!(inverse(1, 828), Some(1)); + assert!(inverse(4, 828).is_none()); + assert_eq!(inverse(7, 828), Some(355)); + assert!(inverse(10, 828).is_none()); + assert_eq!(inverse(13, 828), Some(637)); + assert!(inverse(16, 828).is_none()); + assert_eq!(inverse(19, 828), Some(523)); + assert!(inverse(22, 828).is_none()); + assert_eq!(inverse(25, 828), Some(265)); + assert!(inverse(28, 828).is_none()); + assert_eq!(inverse(1, 835), Some(1)); + assert_eq!(inverse(4, 835), Some(209)); + assert_eq!(inverse(7, 835), Some(358)); + assert!(inverse(10, 835).is_none()); + assert_eq!(inverse(13, 835), Some(257)); + assert_eq!(inverse(16, 835), Some(261)); + assert_eq!(inverse(19, 835), Some(44)); + assert_eq!(inverse(22, 835), Some(38)); + assert!(inverse(25, 835).is_none()); + assert_eq!(inverse(28, 835), Some(507)); + assert_eq!(inverse(1, 842), Some(1)); + assert!(inverse(4, 842).is_none()); + assert_eq!(inverse(7, 842), Some(361)); + assert!(inverse(10, 842).is_none()); + assert_eq!(inverse(13, 842), Some(583)); + assert!(inverse(16, 842).is_none()); + assert_eq!(inverse(19, 842), Some(133)); + assert!(inverse(22, 842).is_none()); + assert_eq!(inverse(25, 842), Some(741)); + assert!(inverse(28, 842).is_none()); + assert_eq!(inverse(1, 849), Some(1)); + assert_eq!(inverse(4, 849), Some(637)); + assert_eq!(inverse(7, 849), Some(364)); + assert_eq!(inverse(10, 849), Some(85)); + assert_eq!(inverse(13, 849), Some(196)); + assert_eq!(inverse(16, 849), Some(796)); + assert_eq!(inverse(19, 849), Some(715)); + assert_eq!(inverse(22, 849), Some(193)); + assert_eq!(inverse(25, 849), Some(34)); + assert_eq!(inverse(28, 849), Some(91)); + assert_eq!(inverse(1, 856), Some(1)); + assert!(inverse(4, 856).is_none()); + assert_eq!(inverse(7, 856), Some(367)); + assert!(inverse(10, 856).is_none()); + assert_eq!(inverse(13, 856), Some(461)); + assert!(inverse(16, 856).is_none()); + assert_eq!(inverse(19, 856), Some(811)); + assert!(inverse(22, 856).is_none()); + assert_eq!(inverse(25, 856), Some(137)); + assert!(inverse(28, 856).is_none()); + assert_eq!(inverse(1, 863), Some(1)); + assert_eq!(inverse(4, 863), Some(216)); + assert_eq!(inverse(7, 863), Some(370)); + assert_eq!(inverse(10, 863), Some(259)); + assert_eq!(inverse(13, 863), Some(332)); + assert_eq!(inverse(16, 863), Some(54)); + assert_eq!(inverse(19, 863), Some(318)); + assert_eq!(inverse(22, 863), Some(510)); + assert_eq!(inverse(25, 863), Some(794)); + assert_eq!(inverse(28, 863), Some(524)); + assert_eq!(inverse(1, 870), Some(1)); + assert!(inverse(4, 870).is_none()); + assert_eq!(inverse(7, 870), Some(373)); + assert!(inverse(10, 870).is_none()); + assert_eq!(inverse(13, 870), Some(67)); + assert!(inverse(16, 870).is_none()); + assert_eq!(inverse(19, 870), Some(229)); + assert!(inverse(22, 870).is_none()); + assert!(inverse(25, 870).is_none()); + assert!(inverse(28, 870).is_none()); + assert_eq!(inverse(1, 877), Some(1)); + assert_eq!(inverse(4, 877), Some(658)); + assert_eq!(inverse(7, 877), Some(376)); + assert_eq!(inverse(10, 877), Some(614)); + assert_eq!(inverse(13, 877), Some(135)); + assert_eq!(inverse(16, 877), Some(603)); + assert_eq!(inverse(19, 877), Some(277)); + assert_eq!(inverse(22, 877), Some(598)); + assert_eq!(inverse(25, 877), Some(421)); + assert_eq!(inverse(28, 877), Some(94)); + assert_eq!(inverse(1, 884), Some(1)); + assert!(inverse(4, 884).is_none()); + assert_eq!(inverse(7, 884), Some(379)); + assert!(inverse(10, 884).is_none()); + assert!(inverse(13, 884).is_none()); + assert!(inverse(16, 884).is_none()); + assert_eq!(inverse(19, 884), Some(791)); + assert!(inverse(22, 884).is_none()); + assert_eq!(inverse(25, 884), Some(389)); + assert!(inverse(28, 884).is_none()); + assert_eq!(inverse(1, 891), Some(1)); + assert_eq!(inverse(4, 891), Some(223)); + assert_eq!(inverse(7, 891), Some(382)); + assert_eq!(inverse(10, 891), Some(802)); + assert_eq!(inverse(13, 891), Some(754)); + assert_eq!(inverse(16, 891), Some(724)); + assert_eq!(inverse(19, 891), Some(469)); + assert!(inverse(22, 891).is_none()); + assert_eq!(inverse(25, 891), Some(499)); + assert_eq!(inverse(28, 891), Some(541)); + assert_eq!(inverse(1, 898), Some(1)); + assert!(inverse(4, 898).is_none()); + assert_eq!(inverse(7, 898), Some(385)); + assert!(inverse(10, 898).is_none()); + assert_eq!(inverse(13, 898), Some(829)); + assert!(inverse(16, 898).is_none()); + assert_eq!(inverse(19, 898), Some(709)); + assert!(inverse(22, 898).is_none()); + assert_eq!(inverse(25, 898), Some(467)); + assert!(inverse(28, 898).is_none()); + assert_eq!(inverse(1, 905), Some(1)); + assert_eq!(inverse(4, 905), Some(679)); + assert_eq!(inverse(7, 905), Some(388)); + assert!(inverse(10, 905).is_none()); + assert_eq!(inverse(13, 905), Some(557)); + assert_eq!(inverse(16, 905), Some(396)); + assert_eq!(inverse(19, 905), Some(524)); + assert_eq!(inverse(22, 905), Some(288)); + assert!(inverse(25, 905).is_none()); + assert_eq!(inverse(28, 905), Some(97)); + assert_eq!(inverse(1, 912), Some(1)); + assert!(inverse(4, 912).is_none()); + assert_eq!(inverse(7, 912), Some(391)); + assert!(inverse(10, 912).is_none()); + assert_eq!(inverse(13, 912), Some(421)); + assert!(inverse(16, 912).is_none()); + assert!(inverse(19, 912).is_none()); + assert!(inverse(22, 912).is_none()); + assert_eq!(inverse(25, 912), Some(73)); + assert!(inverse(28, 912).is_none()); + assert_eq!(inverse(1, 919), Some(1)); + assert_eq!(inverse(4, 919), Some(230)); + assert_eq!(inverse(7, 919), Some(394)); + assert_eq!(inverse(10, 919), Some(92)); + assert_eq!(inverse(13, 919), Some(707)); + assert_eq!(inverse(16, 919), Some(517)); + assert_eq!(inverse(19, 919), Some(387)); + assert_eq!(inverse(22, 919), Some(376)); + assert_eq!(inverse(25, 919), Some(772)); + assert_eq!(inverse(28, 919), Some(558)); + assert_eq!(inverse(1, 926), Some(1)); + assert!(inverse(4, 926).is_none()); + assert_eq!(inverse(7, 926), Some(397)); + assert!(inverse(10, 926).is_none()); + assert_eq!(inverse(13, 926), Some(285)); + assert!(inverse(16, 926).is_none()); + assert_eq!(inverse(19, 926), Some(195)); + assert!(inverse(22, 926).is_none()); + assert_eq!(inverse(25, 926), Some(889)); + assert!(inverse(28, 926).is_none()); + assert_eq!(inverse(1, 933), Some(1)); + assert_eq!(inverse(4, 933), Some(700)); + assert_eq!(inverse(7, 933), Some(400)); + assert_eq!(inverse(10, 933), Some(280)); + assert_eq!(inverse(13, 933), Some(646)); + assert_eq!(inverse(16, 933), Some(175)); + assert_eq!(inverse(19, 933), Some(442)); + assert_eq!(inverse(22, 933), Some(721)); + assert_eq!(inverse(25, 933), Some(112)); + assert_eq!(inverse(28, 933), Some(100)); + assert_eq!(inverse(1, 940), Some(1)); + assert!(inverse(4, 940).is_none()); + assert_eq!(inverse(7, 940), Some(403)); + assert!(inverse(10, 940).is_none()); + assert_eq!(inverse(13, 940), Some(217)); + assert!(inverse(16, 940).is_none()); + assert_eq!(inverse(19, 940), Some(99)); + assert!(inverse(22, 940).is_none()); + assert!(inverse(25, 940).is_none()); + assert!(inverse(28, 940).is_none()); + assert_eq!(inverse(1, 947), Some(1)); + assert_eq!(inverse(4, 947), Some(237)); + assert_eq!(inverse(7, 947), Some(406)); + assert_eq!(inverse(10, 947), Some(663)); + assert_eq!(inverse(13, 947), Some(510)); + assert_eq!(inverse(16, 947), Some(296)); + assert_eq!(inverse(19, 947), Some(648)); + assert_eq!(inverse(22, 947), Some(904)); + assert_eq!(inverse(25, 947), Some(644)); + assert_eq!(inverse(28, 947), Some(575)); + assert_eq!(inverse(1, 954), Some(1)); + assert!(inverse(4, 954).is_none()); + assert_eq!(inverse(7, 954), Some(409)); + assert!(inverse(10, 954).is_none()); + assert_eq!(inverse(13, 954), Some(367)); + assert!(inverse(16, 954).is_none()); + assert_eq!(inverse(19, 954), Some(703)); + assert!(inverse(22, 954).is_none()); + assert_eq!(inverse(25, 954), Some(229)); + assert!(inverse(28, 954).is_none()); + assert_eq!(inverse(1, 961), Some(1)); + assert_eq!(inverse(4, 961), Some(721)); + assert_eq!(inverse(7, 961), Some(412)); + assert_eq!(inverse(10, 961), Some(865)); + assert_eq!(inverse(13, 961), Some(74)); + assert_eq!(inverse(16, 961), Some(901)); + assert_eq!(inverse(19, 961), Some(607)); + assert_eq!(inverse(22, 961), Some(830)); + assert_eq!(inverse(25, 961), Some(346)); + assert_eq!(inverse(28, 961), Some(103)); + assert_eq!(inverse(1, 968), Some(1)); + assert!(inverse(4, 968).is_none()); + assert_eq!(inverse(7, 968), Some(415)); + assert!(inverse(10, 968).is_none()); + assert_eq!(inverse(13, 968), Some(149)); + assert!(inverse(16, 968).is_none()); + assert_eq!(inverse(19, 968), Some(51)); + assert!(inverse(22, 968).is_none()); + assert_eq!(inverse(25, 968), Some(697)); + assert!(inverse(28, 968).is_none()); + assert_eq!(inverse(1, 975), Some(1)); + assert_eq!(inverse(4, 975), Some(244)); + assert_eq!(inverse(7, 975), Some(418)); + assert!(inverse(10, 975).is_none()); + assert!(inverse(13, 975).is_none()); + assert_eq!(inverse(16, 975), Some(61)); + assert_eq!(inverse(19, 975), Some(154)); + assert_eq!(inverse(22, 975), Some(133)); + assert!(inverse(25, 975).is_none()); + assert_eq!(inverse(28, 975), Some(592)); + assert_eq!(inverse(1, 982), Some(1)); + assert!(inverse(4, 982).is_none()); + assert_eq!(inverse(7, 982), Some(421)); + assert!(inverse(10, 982).is_none()); + assert_eq!(inverse(13, 982), Some(831)); + assert!(inverse(16, 982).is_none()); + assert_eq!(inverse(19, 982), Some(827)); + assert!(inverse(22, 982).is_none()); + assert_eq!(inverse(25, 982), Some(275)); + assert!(inverse(28, 982).is_none()); + assert_eq!(inverse(1, 989), Some(1)); + assert_eq!(inverse(4, 989), Some(742)); + assert_eq!(inverse(7, 989), Some(424)); + assert_eq!(inverse(10, 989), Some(99)); + assert_eq!(inverse(13, 989), Some(913)); + assert_eq!(inverse(16, 989), Some(680)); + assert_eq!(inverse(19, 989), Some(937)); + assert_eq!(inverse(22, 989), Some(45)); + assert_eq!(inverse(25, 989), Some(633)); + assert_eq!(inverse(28, 989), Some(106)); + assert_eq!(inverse(1, 996), Some(1)); + assert!(inverse(4, 996).is_none()); + assert_eq!(inverse(7, 996), Some(427)); + assert!(inverse(10, 996).is_none()); + assert_eq!(inverse(13, 996), Some(613)); + assert!(inverse(16, 996).is_none()); + assert_eq!(inverse(19, 996), Some(367)); + assert!(inverse(22, 996).is_none()); + assert_eq!(inverse(25, 996), Some(757)); + assert!(inverse(28, 996).is_none()); + } } diff --git a/crates/fhe-util/src/u256.rs b/crates/fhe-util/src/u256.rs index d043e81..ee679ca 100644 --- a/crates/fhe-util/src/u256.rs +++ b/crates/fhe-util/src/u256.rs @@ -8,217 +8,217 @@ use std::ops::{Not, Shr, ShrAssign}; pub struct U256(u64, u64, u64, u64); impl U256 { - /// Returns the additive identity element, 0. - pub const fn zero() -> Self { - Self(0, 0, 0, 0) - } + /// Returns the additive identity element, 0. + pub const fn zero() -> Self { + Self(0, 0, 0, 0) + } - /// Add an U256 to self, wrapping modulo 2^256. - pub fn wrapping_add_assign(&mut self, other: Self) { - let (a, c1) = self.0.overflowing_add(other.0); - let (b, c2) = self.1.overflowing_add(other.1); - let (c, c3) = b.overflowing_add(c1 as u64); - let (d, c4) = self.2.overflowing_add(other.2); - let (e, c5) = d.overflowing_add((c2 | c3) as u64); - let f = self.3.wrapping_add(other.3); - let g = f.wrapping_add((c4 | c5) as u64); - self.0 = a; - self.1 = c; - self.2 = e; - self.3 = g; - } + /// Add an U256 to self, wrapping modulo 2^256. + pub fn wrapping_add_assign(&mut self, other: Self) { + let (a, c1) = self.0.overflowing_add(other.0); + let (b, c2) = self.1.overflowing_add(other.1); + let (c, c3) = b.overflowing_add(c1 as u64); + let (d, c4) = self.2.overflowing_add(other.2); + let (e, c5) = d.overflowing_add((c2 | c3) as u64); + let f = self.3.wrapping_add(other.3); + let g = f.wrapping_add((c4 | c5) as u64); + self.0 = a; + self.1 = c; + self.2 = e; + self.3 = g; + } - /// Subtract an U256 to self, wrapping modulo 2^256. - pub fn wrapping_sub_assign(&mut self, other: Self) { - let (a, b1) = self.0.overflowing_sub(other.0); - let (b, b2) = self.1.overflowing_sub(other.1); - let (c, b3) = b.overflowing_sub(b1 as u64); - let (d, b4) = self.2.overflowing_sub(other.2); - let (e, b5) = d.overflowing_sub((b2 | b3) as u64); - let f = self.3.wrapping_sub(other.3); - let g = f.wrapping_sub((b4 | b5) as u64); - self.0 = a; - self.1 = c; - self.2 = e; - self.3 = g; - } + /// Subtract an U256 to self, wrapping modulo 2^256. + pub fn wrapping_sub_assign(&mut self, other: Self) { + let (a, b1) = self.0.overflowing_sub(other.0); + let (b, b2) = self.1.overflowing_sub(other.1); + let (c, b3) = b.overflowing_sub(b1 as u64); + let (d, b4) = self.2.overflowing_sub(other.2); + let (e, b5) = d.overflowing_sub((b2 | b3) as u64); + let f = self.3.wrapping_sub(other.3); + let g = f.wrapping_sub((b4 | b5) as u64); + self.0 = a; + self.1 = c; + self.2 = e; + self.3 = g; + } - /// Returns the most significant bit of the unsigned integer. - pub const fn msb(self) -> u64 { - self.3 >> 63 - } + /// Returns the most significant bit of the unsigned integer. + pub const fn msb(self) -> u64 { + self.3 >> 63 + } } impl From<[u64; 4]> for U256 { - fn from(a: [u64; 4]) -> Self { - Self(a[0], a[1], a[2], a[3]) - } + fn from(a: [u64; 4]) -> Self { + Self(a[0], a[1], a[2], a[3]) + } } impl From<[u128; 2]> for U256 { - fn from(a: [u128; 2]) -> Self { - Self( - a[0] as u64, - (a[0] >> 64) as u64, - a[1] as u64, - (a[1] >> 64) as u64, - ) - } + fn from(a: [u128; 2]) -> Self { + Self( + a[0] as u64, + (a[0] >> 64) as u64, + a[1] as u64, + (a[1] >> 64) as u64, + ) + } } impl From for [u64; 4] { - fn from(a: U256) -> [u64; 4] { - [a.0, a.1, a.2, a.3] - } + fn from(a: U256) -> [u64; 4] { + [a.0, a.1, a.2, a.3] + } } impl From for [u128; 2] { - fn from(a: U256) -> [u128; 2] { - [ - (a.0 as u128) + ((a.1 as u128) << 64), - (a.2 as u128) + ((a.3 as u128) << 64), - ] - } + fn from(a: U256) -> [u128; 2] { + [ + (a.0 as u128) + ((a.1 as u128) << 64), + (a.2 as u128) + ((a.3 as u128) << 64), + ] + } } impl From<&U256> for u128 { - fn from(v: &U256) -> Self { - debug_assert!(v.2 == 0 && v.3 == 0); - (v.0 as u128) + ((v.1 as u128) << 64) - } + fn from(v: &U256) -> Self { + debug_assert!(v.2 == 0 && v.3 == 0); + (v.0 as u128) + ((v.1 as u128) << 64) + } } impl Not for U256 { - type Output = Self; + type Output = Self; - fn not(self) -> Self { - Self(!self.0, !self.1, !self.2, !self.3) - } + fn not(self) -> Self { + Self(!self.0, !self.1, !self.2, !self.3) + } } impl Shr for U256 { - type Output = Self; + type Output = Self; - fn shr(self, rhs: usize) -> Self { - let mut r = self; - r >>= rhs; - r - } + fn shr(self, rhs: usize) -> Self { + let mut r = self; + r >>= rhs; + r + } } impl ShrAssign for U256 { - fn shr_assign(&mut self, rhs: usize) { - debug_assert!(rhs < 256); + fn shr_assign(&mut self, rhs: usize) { + debug_assert!(rhs < 256); - if rhs >= 192 { - self.0 = self.3 >> (rhs - 192); - self.1 = 0; - self.2 = 0; - self.3 = 0; - } else if rhs > 128 { - self.0 = (self.2 >> (rhs - 128)) | (self.3 << (192 - rhs)); - self.1 = self.3 >> (rhs - 128); - self.2 = 0; - self.3 = 0; - } else if rhs == 128 { - self.0 = self.2; - self.1 = self.3; - self.2 = 0; - self.3 = 0; - } else if rhs > 64 { - self.0 = (self.1 >> (rhs - 64)) | (self.2 << (128 - rhs)); - self.1 = (self.2 >> (rhs - 64)) | (self.3 << (128 - rhs)); - self.2 = self.3 >> (rhs - 64); - self.3 = 0; - } else if rhs == 64 { - self.0 = self.1; - self.1 = self.2; - self.2 = self.3; - self.3 = 0; - } else if rhs > 0 { - self.0 = (self.0 >> rhs) | (self.1 << (64 - rhs)); - self.1 = (self.1 >> rhs) | (self.2 << (64 - rhs)); - self.2 = (self.2 >> rhs) | (self.3 << (64 - rhs)); - self.3 >>= rhs; - } - } + if rhs >= 192 { + self.0 = self.3 >> (rhs - 192); + self.1 = 0; + self.2 = 0; + self.3 = 0; + } else if rhs > 128 { + self.0 = (self.2 >> (rhs - 128)) | (self.3 << (192 - rhs)); + self.1 = self.3 >> (rhs - 128); + self.2 = 0; + self.3 = 0; + } else if rhs == 128 { + self.0 = self.2; + self.1 = self.3; + self.2 = 0; + self.3 = 0; + } else if rhs > 64 { + self.0 = (self.1 >> (rhs - 64)) | (self.2 << (128 - rhs)); + self.1 = (self.2 >> (rhs - 64)) | (self.3 << (128 - rhs)); + self.2 = self.3 >> (rhs - 64); + self.3 = 0; + } else if rhs == 64 { + self.0 = self.1; + self.1 = self.2; + self.2 = self.3; + self.3 = 0; + } else if rhs > 0 { + self.0 = (self.0 >> rhs) | (self.1 << (64 - rhs)); + self.1 = (self.1 >> rhs) | (self.2 << (64 - rhs)); + self.2 = (self.2 >> rhs) | (self.3 << (64 - rhs)); + self.3 >>= rhs; + } + } } #[cfg(test)] mod tests { - use super::U256; + use super::U256; - #[test] - fn zero() { - assert_eq!(u128::from(&U256::zero()), 0u128); - } + #[test] + fn zero() { + assert_eq!(u128::from(&U256::zero()), 0u128); + } - proptest! { + proptest! { - #[test] - fn u128(a: u128) { - prop_assert_eq!(a, u128::from(&U256::from([a, 0]))); - } + #[test] + fn u128(a: u128) { + prop_assert_eq!(a, u128::from(&U256::from([a, 0]))); + } - #[test] - fn from_into_u64(a: u64, b: u64, c: u64, d:u64) { - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d])), [a, b, c, d]); - } + #[test] + fn from_into_u64(a: u64, b: u64, c: u64, d:u64) { + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d])), [a, b, c, d]); + } - #[test] - fn from_into_u128(a: u128, b: u128) { - prop_assert_eq!(<[u128; 2]>::from(U256::from([a, b])), [a, b]); - } + #[test] + fn from_into_u128(a: u128, b: u128) { + prop_assert_eq!(<[u128; 2]>::from(U256::from([a, b])), [a, b]); + } - #[test] - fn shift(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) { - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 0), [a, b, c, d]); - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 64), [b, c, d, 0]); - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 128), [c, d, 0, 0]); - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 192), [d, 0, 0, 0]); + #[test] + fn shift(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) { + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 0), [a, b, c, d]); + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 64), [b, c, d, 0]); + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 128), [c, d, 0, 0]); + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 192), [d, 0, 0, 0]); - prop_assume!(shift % 64 != 0); - if shift < 64 { - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]); - } else if shift < 128 { - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]); - } else if shift < 192 { - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]); - } else { - prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(d >> (shift - 192)), 0, 0, 0]); - } - } + prop_assume!(shift % 64 != 0); + if shift < 64 { + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]); + } else if shift < 128 { + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]); + } else if shift < 192 { + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]); + } else { + prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(d >> (shift - 192)), 0, 0, 0]); + } + } - #[test] - fn shift_assign(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) { - let mut u = U256::from([a, b, c, d]); + #[test] + fn shift_assign(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) { + let mut u = U256::from([a, b, c, d]); - u >>= 0; - prop_assert_eq!(<[u64; 4]>::from(u), [a, b, c, d]); + u >>= 0; + prop_assert_eq!(<[u64; 4]>::from(u), [a, b, c, d]); - u >>= 64; - prop_assert_eq!(<[u64; 4]>::from(u), [b, c, d, 0]); + u >>= 64; + prop_assert_eq!(<[u64; 4]>::from(u), [b, c, d, 0]); - u = U256::from([a, b, c, d]); - u >>= 128; - prop_assert_eq!(<[u64; 4]>::from(u), [c, d, 0, 0]); + u = U256::from([a, b, c, d]); + u >>= 128; + prop_assert_eq!(<[u64; 4]>::from(u), [c, d, 0, 0]); - u = U256::from([a, b, c, d]); - u >>= 192; - prop_assert_eq!(<[u64; 4]>::from(u), [d, 0, 0, 0]); + u = U256::from([a, b, c, d]); + u >>= 192; + prop_assert_eq!(<[u64; 4]>::from(u), [d, 0, 0, 0]); - prop_assume!(shift % 64 != 0); - u = U256::from([a, b, c, d]); - u >>= shift; - if shift < 64 { - prop_assert_eq!(<[u64; 4]>::from(u), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]); - } else if shift < 128 { - prop_assert_eq!(<[u64; 4]>::from(u), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]); - } else if shift < 192 { - prop_assert_eq!(<[u64; 4]>::from(u), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]); - } else { - prop_assert_eq!(<[u64; 4]>::from(u), [(d >> (shift - 192)), 0, 0, 0]); - } - } - } + prop_assume!(shift % 64 != 0); + u = U256::from([a, b, c, d]); + u >>= shift; + if shift < 64 { + prop_assert_eq!(<[u64; 4]>::from(u), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]); + } else if shift < 128 { + prop_assert_eq!(<[u64; 4]>::from(u), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]); + } else if shift < 192 { + prop_assert_eq!(<[u64; 4]>::from(u), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]); + } else { + prop_assert_eq!(<[u64; 4]>::from(u), [(d >> (shift - 192)), 0, 0, 0]); + } + } + } } diff --git a/crates/fhe/benches/bfv.rs b/crates/fhe/benches/bfv.rs index 3c4f7a8..7ddd466 100644 --- a/crates/fhe/benches/bfv.rs +++ b/crates/fhe/benches/bfv.rs @@ -1,7 +1,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use fhe::bfv::{ - BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Multiplicator, Plaintext, PublicKey, - RelinearizationKey, SecretKey, + BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Multiplicator, Plaintext, PublicKey, + RelinearizationKey, SecretKey, }; use fhe_math::rns::{RnsContext, ScalingFactor}; use fhe_math::zq::primes::generate_prime; @@ -13,275 +13,275 @@ use rand::{rngs::OsRng, thread_rng}; use std::time::Duration; pub fn bfv_benchmark(c: &mut Criterion) { - let mut rng = thread_rng(); - let mut group = c.benchmark_group("bfv"); - group.sample_size(10); - group.warm_up_time(Duration::from_millis(600)); - group.measurement_time(Duration::from_millis(1000)); + let mut rng = thread_rng(); + let mut group = c.benchmark_group("bfv"); + group.sample_size(10); + group.warm_up_time(Duration::from_millis(600)); + group.measurement_time(Duration::from_millis(1000)); - for par in BfvParameters::default_parameters_128(20) { - let sk = SecretKey::random(&par, &mut OsRng); - let ek = if par.moduli().len() > 1 { - Some( - EvaluationKeyBuilder::new(&sk) - .unwrap() - .enable_inner_sum() - .unwrap() - .enable_column_rotation(1) - .unwrap() - .enable_expansion(ilog2(par.degree() as u64)) - .unwrap() - .build(&mut rng) - .unwrap(), - ) - } else { - None - }; + for par in BfvParameters::default_parameters_128(20) { + let sk = SecretKey::random(&par, &mut OsRng); + let ek = if par.moduli().len() > 1 { + Some( + EvaluationKeyBuilder::new(&sk) + .unwrap() + .enable_inner_sum() + .unwrap() + .enable_column_rotation(1) + .unwrap() + .enable_expansion(ilog2(par.degree() as u64)) + .unwrap() + .build(&mut rng) + .unwrap(), + ) + } else { + None + }; - let rk = if par.moduli().len() > 1 { - Some(RelinearizationKey::new(&sk, &mut rng).unwrap()) - } else { - None - }; + let rk = if par.moduli().len() > 1 { + Some(RelinearizationKey::new(&sk, &mut rng).unwrap()) + } else { + None + }; - let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par).unwrap(); - let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), &par).unwrap(); - let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap(); - let c2: Ciphertext = sk.try_encrypt(&pt2, &mut rng).unwrap(); + let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par).unwrap(); + let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), &par).unwrap(); + let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap(); + let c2: Ciphertext = sk.try_encrypt(&pt2, &mut rng).unwrap(); - let q = par.moduli_sizes().iter().sum::(); + let q = par.moduli_sizes().iter().sum::(); - group.bench_function( - BenchmarkId::new("keygen_sk", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| SecretKey::random(&par, &mut OsRng)); - }, - ); + group.bench_function( + BenchmarkId::new("keygen_sk", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| SecretKey::random(&par, &mut OsRng)); + }, + ); - group.bench_function( - BenchmarkId::new("keygen_pk", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| PublicKey::new(&sk, &mut rng)); - }, - ); + group.bench_function( + BenchmarkId::new("keygen_pk", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| PublicKey::new(&sk, &mut rng)); + }, + ); - group.bench_function( - BenchmarkId::new("keygen_rk", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| RelinearizationKey::new(&sk, &mut rng)); - }, - ); + group.bench_function( + BenchmarkId::new("keygen_rk", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| RelinearizationKey::new(&sk, &mut rng)); + }, + ); - group.bench_function( - BenchmarkId::new("encode_poly", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), &par)); - }, - ); + group.bench_function( + BenchmarkId::new("encode_poly", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), &par)); + }, + ); - group.bench_function( - BenchmarkId::new("encode_simd", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par)); - }, - ); + group.bench_function( + BenchmarkId::new("encode_simd", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par)); + }, + ); - group.bench_function( - BenchmarkId::new("encrypt_sk", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| { - let _: fhe::Result = sk.try_encrypt(&pt1, &mut rng); - }); - }, - ); + group.bench_function( + BenchmarkId::new("encrypt_sk", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| { + let _: fhe::Result = sk.try_encrypt(&pt1, &mut rng); + }); + }, + ); - group.bench_function( - BenchmarkId::new("add_ct", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = &c1 + &c2); - }, - ); + group.bench_function( + BenchmarkId::new("add_ct", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = &c1 + &c2); + }, + ); - group.bench_function( - BenchmarkId::new("add_assign_ct", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 += &c2); - }, - ); + group.bench_function( + BenchmarkId::new("add_assign_ct", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 += &c2); + }, + ); - group.bench_function( - BenchmarkId::new("add_pt", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = &c1 + &pt2); - }, - ); + group.bench_function( + BenchmarkId::new("add_pt", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = &c1 + &pt2); + }, + ); - group.bench_function( - BenchmarkId::new("add_assign_pt", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 += &pt2); - }, - ); + group.bench_function( + BenchmarkId::new("add_assign_pt", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 += &pt2); + }, + ); - group.bench_function( - BenchmarkId::new("sub_ct", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = &c1 - &c2); - }, - ); + group.bench_function( + BenchmarkId::new("sub_ct", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = &c1 - &c2); + }, + ); - group.bench_function( - BenchmarkId::new("sub_assign_ct", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 -= &c2); - }, - ); + group.bench_function( + BenchmarkId::new("sub_assign_ct", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 -= &c2); + }, + ); - group.bench_function( - BenchmarkId::new("sub_pt", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = &c1 - &pt2); - }, - ); + group.bench_function( + BenchmarkId::new("sub_pt", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = &c1 - &pt2); + }, + ); - group.bench_function( - BenchmarkId::new("sub_assign_pt", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 -= &pt2); - }, - ); + group.bench_function( + BenchmarkId::new("sub_assign_pt", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 -= &pt2); + }, + ); - group.bench_function( - BenchmarkId::new("neg", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = -&c2); - }, - ); + group.bench_function( + BenchmarkId::new("neg", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = -&c2); + }, + ); - let mut c3 = &c1 * &c1; - let c3_clone = c3.clone(); - if let Some(rk) = rk.as_ref() { - group.bench_function( - BenchmarkId::new("relinearize", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| { - assert!(rk.relinearizes(&mut c3).is_ok()); - c3 = c3_clone.clone(); - }); - }, - ); - } + let mut c3 = &c1 * &c1; + let c3_clone = c3.clone(); + if let Some(rk) = rk.as_ref() { + group.bench_function( + BenchmarkId::new("relinearize", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| { + assert!(rk.relinearizes(&mut c3).is_ok()); + c3 = c3_clone.clone(); + }); + }, + ); + } - if let Some(ek) = ek { - group.bench_function( - BenchmarkId::new("rotate_rows", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = ek.rotates_rows(&c1).unwrap()); - }, - ); + if let Some(ek) = ek { + group.bench_function( + BenchmarkId::new("rotate_rows", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = ek.rotates_rows(&c1).unwrap()); + }, + ); - group.bench_function( - BenchmarkId::new("rotate_columns", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = ek.rotates_columns_by(&c1, 1).unwrap()); - }, - ); + group.bench_function( + BenchmarkId::new("rotate_columns", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = ek.rotates_columns_by(&c1, 1).unwrap()); + }, + ); - group.bench_function( - BenchmarkId::new("inner_sum", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| c1 = ek.computes_inner_sum(&c1).unwrap()); - }, - ); + group.bench_function( + BenchmarkId::new("inner_sum", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| c1 = ek.computes_inner_sum(&c1).unwrap()); + }, + ); - for i in 1..=ilog2(par.degree() as u64) { - if par.degree() > 2048 && i > 4 { - continue; // Skip slow benchmarks - } - group.bench_function( - BenchmarkId::new( - format!("expand_{i}"), - format!("n={}/log(q)={}", par.degree(), q), - ), - |b| { - b.iter(|| ek.expands(&c1, 1 << i).unwrap()); - }, - ); - } - } + for i in 1..=ilog2(par.degree() as u64) { + if par.degree() > 2048 && i > 4 { + continue; // Skip slow benchmarks + } + group.bench_function( + BenchmarkId::new( + format!("expand_{i}"), + format!("n={}/log(q)={}", par.degree(), q), + ), + |b| { + b.iter(|| ek.expands(&c1, 1 << i).unwrap()); + }, + ); + } + } - group.bench_function( - BenchmarkId::new("mul", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| &c1 * &c2); - }, - ); + group.bench_function( + BenchmarkId::new("mul", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| &c1 * &c2); + }, + ); - group.bench_function( - BenchmarkId::new("square", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| &c1 * &c1); - }, - ); + group.bench_function( + BenchmarkId::new("square", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| &c1 * &c1); + }, + ); - if let Some(rk) = rk.as_ref() { - group.bench_function( - BenchmarkId::new( - "mul_then_relinearize", - format!("n={}/log(q)={}", par.degree(), q), - ), - |b| { - b.iter(|| { - c3 = &c1 * &c2; - assert!(rk.relinearizes(&mut c3).is_ok()); - }); - }, - ); + if let Some(rk) = rk.as_ref() { + group.bench_function( + BenchmarkId::new( + "mul_then_relinearize", + format!("n={}/log(q)={}", par.degree(), q), + ), + |b| { + b.iter(|| { + c3 = &c1 * &c2; + assert!(rk.relinearizes(&mut c3).is_ok()); + }); + }, + ); - // Default multiplication method - let multiplicator = Multiplicator::default(rk).unwrap(); + // Default multiplication method + let multiplicator = Multiplicator::default(rk).unwrap(); - group.bench_function( - BenchmarkId::new("mul_and_relin", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok())); - }, - ); + group.bench_function( + BenchmarkId::new("mul_and_relin", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok())); + }, + ); - // Second multiplication option. - let nmoduli = div_ceil(q, 62); - let mut extended_basis = par.moduli().to_vec(); - let mut upper_bound = u64::MAX >> 2; - while extended_basis.len() != nmoduli + par.moduli().len() { - upper_bound = generate_prime(62, 2 * par.degree() as u64, upper_bound).unwrap(); - if !extended_basis.contains(&upper_bound) { - extended_basis.push(upper_bound) - } - } - let rns_q = RnsContext::new(&extended_basis[..par.moduli().len()]).unwrap(); - let rns_p = RnsContext::new(&extended_basis[par.moduli().len()..]).unwrap(); - let mut multiplicator = Multiplicator::new( - ScalingFactor::one(), - ScalingFactor::new(rns_p.modulus(), rns_q.modulus()), - &extended_basis, - ScalingFactor::new(&BigUint::from(par.plaintext()), rns_p.modulus()), - &par, - ) - .unwrap(); - assert!(multiplicator.enable_relinearization(rk).is_ok()); - group.bench_function( - BenchmarkId::new( - "mul_and_relin_2", - format!("n={}/log(q)={}", par.degree(), q), - ), - |b| { - b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok())); - }, - ); - } - } + // Second multiplication option. + let nmoduli = div_ceil(q, 62); + let mut extended_basis = par.moduli().to_vec(); + let mut upper_bound = u64::MAX >> 2; + while extended_basis.len() != nmoduli + par.moduli().len() { + upper_bound = generate_prime(62, 2 * par.degree() as u64, upper_bound).unwrap(); + if !extended_basis.contains(&upper_bound) { + extended_basis.push(upper_bound) + } + } + let rns_q = RnsContext::new(&extended_basis[..par.moduli().len()]).unwrap(); + let rns_p = RnsContext::new(&extended_basis[par.moduli().len()..]).unwrap(); + let mut multiplicator = Multiplicator::new( + ScalingFactor::one(), + ScalingFactor::new(rns_p.modulus(), rns_q.modulus()), + &extended_basis, + ScalingFactor::new(&BigUint::from(par.plaintext()), rns_p.modulus()), + &par, + ) + .unwrap(); + assert!(multiplicator.enable_relinearization(rk).is_ok()); + group.bench_function( + BenchmarkId::new( + "mul_and_relin_2", + format!("n={}/log(q)={}", par.degree(), q), + ), + |b| { + b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok())); + }, + ); + } + } - group.finish(); + group.finish(); } criterion_group!(bfv, bfv_benchmark); diff --git a/crates/fhe/benches/bfv_optimized_ops.rs b/crates/fhe/benches/bfv_optimized_ops.rs index aec2dd7..52c826f 100644 --- a/crates/fhe/benches/bfv_optimized_ops.rs +++ b/crates/fhe/benches/bfv_optimized_ops.rs @@ -6,66 +6,66 @@ use rand::{rngs::OsRng, thread_rng}; use std::time::Duration; pub fn bfv_benchmark(c: &mut Criterion) { - let mut rng = thread_rng(); - let mut group = c.benchmark_group("bfv_optimized_ops"); - group.sample_size(10); - group.warm_up_time(Duration::from_secs(1)); - group.measurement_time(Duration::from_secs(1)); + let mut rng = thread_rng(); + let mut group = c.benchmark_group("bfv_optimized_ops"); + group.sample_size(10); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(1)); - for par in &BfvParameters::default_parameters_128(20)[2..] { - for size in [10, 128, 1000] { - let sk = SecretKey::random(par, &mut OsRng); - let pt1 = - Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), par).unwrap(); - let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap(); + for par in &BfvParameters::default_parameters_128(20)[2..] { + for size in [10, 128, 1000] { + let sk = SecretKey::random(par, &mut OsRng); + let pt1 = + Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), par).unwrap(); + let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap(); - let ct_vec = (0..size) - .map(|i| { - let pt = - Plaintext::try_encode(&(i..16u64).collect_vec(), Encoding::poly(), par) - .unwrap(); - sk.try_encrypt(&pt, &mut rng).unwrap() - }) - .collect_vec(); - let pt_vec = (0..size) - .map(|i| { - Plaintext::try_encode(&(i..39u64).collect_vec(), Encoding::poly(), par).unwrap() - }) - .collect_vec(); + let ct_vec = (0..size) + .map(|i| { + let pt = + Plaintext::try_encode(&(i..16u64).collect_vec(), Encoding::poly(), par) + .unwrap(); + sk.try_encrypt(&pt, &mut rng).unwrap() + }) + .collect_vec(); + let pt_vec = (0..size) + .map(|i| { + Plaintext::try_encode(&(i..39u64).collect_vec(), Encoding::poly(), par).unwrap() + }) + .collect_vec(); - group.bench_function( - BenchmarkId::new( - "dot_product/naive", - format!( - "size={}/degree={}/logq={}", - size, - par.degree(), - par.moduli_sizes().iter().sum::() - ), - ), - |b| { - b.iter(|| izip!(&ct_vec, &pt_vec).for_each(|(cti, pti)| c1 += &(cti * pti))); - }, - ); + group.bench_function( + BenchmarkId::new( + "dot_product/naive", + format!( + "size={}/degree={}/logq={}", + size, + par.degree(), + par.moduli_sizes().iter().sum::() + ), + ), + |b| { + b.iter(|| izip!(&ct_vec, &pt_vec).for_each(|(cti, pti)| c1 += &(cti * pti))); + }, + ); - group.bench_function( - BenchmarkId::new( - "dot_product/opt", - format!( - "size={}/degree={}/logq={}", - size, - par.degree(), - par.moduli_sizes().iter().sum::() - ), - ), - |b| { - b.iter(|| dot_product_scalar(ct_vec.iter(), pt_vec.iter())); - }, - ); - } - } + group.bench_function( + BenchmarkId::new( + "dot_product/opt", + format!( + "size={}/degree={}/logq={}", + size, + par.degree(), + par.moduli_sizes().iter().sum::() + ), + ), + |b| { + b.iter(|| dot_product_scalar(ct_vec.iter(), pt_vec.iter())); + }, + ); + } + } - group.finish(); + group.finish(); } criterion_group!(bfv, bfv_benchmark); diff --git a/crates/fhe/benches/bfv_rgsw.rs b/crates/fhe/benches/bfv_rgsw.rs index c436948..94d2338 100644 --- a/crates/fhe/benches/bfv_rgsw.rs +++ b/crates/fhe/benches/bfv_rgsw.rs @@ -6,30 +6,30 @@ use rand::{rngs::OsRng, thread_rng}; use std::time::Duration; pub fn bfv_rgsw_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("bfv_rgsw"); - group.sample_size(10); - group.warm_up_time(Duration::from_secs(1)); - group.measurement_time(Duration::from_secs(1)); + let mut group = c.benchmark_group("bfv_rgsw"); + group.sample_size(10); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(1)); - for par in &BfvParameters::default_parameters_128(20)[2..] { - let mut rng = thread_rng(); - let sk = SecretKey::random(par, &mut OsRng); + for par in &BfvParameters::default_parameters_128(20)[2..] { + let mut rng = thread_rng(); + let sk = SecretKey::random(par, &mut OsRng); - let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), par).unwrap(); - let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), par).unwrap(); - let c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap(); - let c2: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng).unwrap(); - let q = par.moduli_sizes().iter().sum::(); + let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), par).unwrap(); + let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), par).unwrap(); + let c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap(); + let c2: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng).unwrap(); + let q = par.moduli_sizes().iter().sum::(); - group.bench_function( - BenchmarkId::new("external", format!("n={}/log(q)={}", par.degree(), q)), - |b| { - b.iter(|| &c1 * &c2); - }, - ); - } + group.bench_function( + BenchmarkId::new("external", format!("n={}/log(q)={}", par.degree(), q)), + |b| { + b.iter(|| &c1 * &c2); + }, + ); + } - group.finish(); + group.finish(); } criterion_group!(bfv_rgsw, bfv_rgsw_benchmark); diff --git a/crates/fhe/examples/mulpir.rs b/crates/fhe/examples/mulpir.rs index 03cf6fe..dcdfe6c 100644 --- a/crates/fhe/examples/mulpir.rs +++ b/crates/fhe/examples/mulpir.rs @@ -11,260 +11,260 @@ mod util; use console::style; use fhe::bfv; use fhe_traits::{ - DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize, + DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize, }; use fhe_util::{ilog2, inverse, transcode_to_bytes}; use indicatif::HumanBytes; use rand::{rngs::OsRng, thread_rng, RngCore}; use std::{env, error::Error, process::exit, sync::Arc}; use util::{ - encode_database, generate_database, number_elements_per_plaintext, - timeit::{timeit, timeit_n}, + encode_database, generate_database, number_elements_per_plaintext, + timeit::{timeit, timeit_n}, }; fn print_notice_and_exit(max_element_size: usize, error: Option) { - println!( - "{} MulPIR with fhe.rs", - style(" overview:").magenta().bold() - ); - println!( - "{} mulpir- [-h] [--help] [--database_size=] [--element_size=]", - style(" usage:").magenta().bold() - ); - println!( - "{} {} must be at least 1, and {} must be between 1 and {}", - style("constraints:").magenta().bold(), - style("database_size").blue(), - style("element_size").blue(), - max_element_size - ); - if let Some(error) = error { - println!("{} {}", style(" error:").red().bold(), error); - } - exit(0); + println!( + "{} MulPIR with fhe.rs", + style(" overview:").magenta().bold() + ); + println!( + "{} mulpir- [-h] [--help] [--database_size=] [--element_size=]", + style(" usage:").magenta().bold() + ); + println!( + "{} {} must be at least 1, and {} must be between 1 and {}", + style("constraints:").magenta().bold(), + style("database_size").blue(), + style("element_size").blue(), + max_element_size + ); + if let Some(error) = error { + println!("{} {}", style(" error:").red().bold(), error); + } + exit(0); } fn main() -> Result<(), Box> { - // We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf. - let degree = 8192; - let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1; - let moduli_sizes = [50, 55, 55]; + // We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf. + let degree = 8192; + let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1; + let moduli_sizes = [50, 55, 55]; - // Compute what is the maximum byte-length of an element to fit within one - // ciphertext. Each coefficient of the ciphertext polynomial can contain - // floor(log2(plaintext_modulus)) bits. - let max_element_size = (ilog2(plaintext_modulus) * degree) / 8; + // Compute what is the maximum byte-length of an element to fit within one + // ciphertext. Each coefficient of the ciphertext polynomial can contain + // floor(log2(plaintext_modulus)) bits. + let max_element_size = (ilog2(plaintext_modulus) * degree) / 8; - // This executable is a command line tool which enables to specify different - // database and element sizes. - let args: Vec = env::args().skip(1).collect(); + // This executable is a command line tool which enables to specify different + // database and element sizes. + let args: Vec = env::args().skip(1).collect(); - // Print the help if requested. - if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) { - print_notice_and_exit(max_element_size, None) - } + // Print the help if requested. + if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) { + print_notice_and_exit(max_element_size, None) + } - // Use the default values from . - let mut database_size = 1 << 20; - let mut elements_size = 288; + // Use the default values from . + let mut database_size = 1 << 20; + let mut elements_size = 288; - // Update the database size and/or element size depending on the arguments - // provided. - for arg in &args { - if arg.starts_with("--database_size") { - let a: Vec<&str> = arg.rsplit('=').collect(); - if a.len() != 2 || a[0].parse::().is_err() { - print_notice_and_exit( - max_element_size, - Some("Invalid `--database_size` command".to_string()), - ) - } else { - database_size = a[0].parse::().unwrap() - } - } else if arg.starts_with("--element_size") { - let a: Vec<&str> = arg.rsplit('=').collect(); - if a.len() != 2 || a[0].parse::().is_err() { - print_notice_and_exit( - max_element_size, - Some("Invalid `--element_size` command".to_string()), - ) - } else { - elements_size = a[0].parse::().unwrap() - } - } else { - print_notice_and_exit( - max_element_size, - Some(format!("Unrecognized command: {arg}")), - ) - } - } + // Update the database size and/or element size depending on the arguments + // provided. + for arg in &args { + if arg.starts_with("--database_size") { + let a: Vec<&str> = arg.rsplit('=').collect(); + if a.len() != 2 || a[0].parse::().is_err() { + print_notice_and_exit( + max_element_size, + Some("Invalid `--database_size` command".to_string()), + ) + } else { + database_size = a[0].parse::().unwrap() + } + } else if arg.starts_with("--element_size") { + let a: Vec<&str> = arg.rsplit('=').collect(); + if a.len() != 2 || a[0].parse::().is_err() { + print_notice_and_exit( + max_element_size, + Some("Invalid `--element_size` command".to_string()), + ) + } else { + elements_size = a[0].parse::().unwrap() + } + } else { + print_notice_and_exit( + max_element_size, + Some(format!("Unrecognized command: {arg}")), + ) + } + } - if elements_size > max_element_size || elements_size == 0 || database_size == 0 { - print_notice_and_exit( - max_element_size, - Some("Element or database sizes out of bound".to_string()), - ) - } + if elements_size > max_element_size || elements_size == 0 || database_size == 0 { + print_notice_and_exit( + max_element_size, + Some("Element or database sizes out of bound".to_string()), + ) + } - // The parameters are within bound, let's go! Let's first display some - // information about the database. - println!("# MulPIR with fhe.rs"); - println!( - "database of {}", - HumanBytes((database_size * elements_size) as u64) - ); - println!("\tdatabase_size = {database_size}"); - println!("\telements_size = {elements_size}"); + // The parameters are within bound, let's go! Let's first display some + // information about the database. + println!("# MulPIR with fhe.rs"); + println!( + "database of {}", + HumanBytes((database_size * elements_size) as u64) + ); + println!("\tdatabase_size = {database_size}"); + println!("\telements_size = {elements_size}"); - // Generation of a random database. - let database = timeit!("Database generation", { - generate_database(database_size, elements_size) - }); + // Generation of a random database. + let database = timeit!("Database generation", { + generate_database(database_size, elements_size) + }); - // Let's generate the BFV parameters structure. - let params = timeit!( - "Parameters generation", - Arc::new( - bfv::BfvParametersBuilder::new() - .set_degree(degree) - .set_plaintext_modulus(plaintext_modulus) - .set_moduli_sizes(&moduli_sizes) - .build() - .unwrap() - ) - ); + // Let's generate the BFV parameters structure. + let params = timeit!( + "Parameters generation", + Arc::new( + bfv::BfvParametersBuilder::new() + .set_degree(degree) + .set_plaintext_modulus(plaintext_modulus) + .set_moduli_sizes(&moduli_sizes) + .build() + .unwrap() + ) + ); - // Proprocess the database on the server side: the database will be reshaped - // so as to pack as many values as possible in every row so that it fits in one - // ciphertext, and each element will be encoded as a polynomial in Ntt - // representation. - let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", { - encode_database(&database, params.clone(), 1) - }); + // Proprocess the database on the server side: the database will be reshaped + // so as to pack as many values as possible in every row so that it fits in one + // ciphertext, and each element will be encoded as a polynomial in Ntt + // representation. + let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", { + encode_database(&database, params.clone(), 1) + }); - // Client setup: the client generates a secret key, an evaluation key for - // the server will which enable to obliviously expand a ciphertext up to (dim1 + - // dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a - // relinearization key. - let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", { - let sk = bfv::SecretKey::random(¶ms, &mut OsRng); - let level = ilog2((dim1 + dim2).next_power_of_two() as u64); - println!("level = {level}"); - let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)? - .enable_expansion(level)? - .build(&mut thread_rng())?; - let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut thread_rng())?; - let ek_expansion_serialized = ek_expansion.to_bytes(); - let rk_serialized = rk.to_bytes(); - (sk, ek_expansion_serialized, rk_serialized) - }); - println!( - "📄 Evaluation key (expansion): {}", - HumanBytes(ek_expansion_serialized.len() as u64) - ); - println!( - "📄 Relinearization key: {}", - HumanBytes(rk_serialized.len() as u64) - ); + // Client setup: the client generates a secret key, an evaluation key for + // the server will which enable to obliviously expand a ciphertext up to (dim1 + + // dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a + // relinearization key. + let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", { + let sk = bfv::SecretKey::random(¶ms, &mut OsRng); + let level = ilog2((dim1 + dim2).next_power_of_two() as u64); + println!("level = {level}"); + let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)? + .enable_expansion(level)? + .build(&mut thread_rng())?; + let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut thread_rng())?; + let ek_expansion_serialized = ek_expansion.to_bytes(); + let rk_serialized = rk.to_bytes(); + (sk, ek_expansion_serialized, rk_serialized) + }); + println!( + "📄 Evaluation key (expansion): {}", + HumanBytes(ek_expansion_serialized.len() as u64) + ); + println!( + "📄 Relinearization key: {}", + HumanBytes(rk_serialized.len() as u64) + ); - // Server setup: the server receives the evaluation and relinearization keys and - // deserializes them. - let (ek_expansion, rk) = timeit!("Server setup", { - ( - bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?, - bfv::RelinearizationKey::from_bytes(&rk_serialized, ¶ms)?, - ) - }); + // Server setup: the server receives the evaluation and relinearization keys and + // deserializes them. + let (ek_expansion, rk) = timeit!("Server setup", { + ( + bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?, + bfv::RelinearizationKey::from_bytes(&rk_serialized, ¶ms)?, + ) + }); - // Client query: when the client wants to retrieve the `index`-th row of the - // original database, it first computes to which row it corresponds in the - // original database, and then encrypt a selection vector with 0 everywhere, - // except at two indices i and (dim1 + j) such that `query_index = i * dim 2 + - // j` where it sets the value (2^level)^(-1) modulo the plaintext space. - // It then encodes this vector as a `polynomial` and encrypt the plaintext. - // The ciphertext is set at level `1`, which means that one of the three moduli - // has been dropped already; the reason is that the expansion will happen at - // level 0 (with all three moduli) and then one of the moduli will be dropped - // to reduce the noise. - let index = (thread_rng().next_u64() as usize) % database_size; - let query = timeit!("Client query", { - let level = ilog2((dim1 + dim2).next_power_of_two() as u64); - let query_index = index - / number_elements_per_plaintext( - params.degree(), - ilog2(plaintext_modulus), - elements_size, - ); - let mut pt = vec![0u64; dim1 + dim2]; - let inv = inverse(1 << level, plaintext_modulus).unwrap(); - pt[query_index / dim2] = inv; - pt[dim1 + (query_index % dim2)] = inv; - let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?; - let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?; - query.to_bytes() - }); - println!("📄 Query: {}", HumanBytes(query.len() as u64)); + // Client query: when the client wants to retrieve the `index`-th row of the + // original database, it first computes to which row it corresponds in the + // original database, and then encrypt a selection vector with 0 everywhere, + // except at two indices i and (dim1 + j) such that `query_index = i * dim 2 + + // j` where it sets the value (2^level)^(-1) modulo the plaintext space. + // It then encodes this vector as a `polynomial` and encrypt the plaintext. + // The ciphertext is set at level `1`, which means that one of the three moduli + // has been dropped already; the reason is that the expansion will happen at + // level 0 (with all three moduli) and then one of the moduli will be dropped + // to reduce the noise. + let index = (thread_rng().next_u64() as usize) % database_size; + let query = timeit!("Client query", { + let level = ilog2((dim1 + dim2).next_power_of_two() as u64); + let query_index = index + / number_elements_per_plaintext( + params.degree(), + ilog2(plaintext_modulus), + elements_size, + ); + let mut pt = vec![0u64; dim1 + dim2]; + let inv = inverse(1 << level, plaintext_modulus).unwrap(); + pt[query_index / dim2] = inv; + pt[dim1 + (query_index % dim2)] = inv; + let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?; + let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?; + query.to_bytes() + }); + println!("📄 Query: {}", HumanBytes(query.len() as u64)); - // Server response: The server receives the query, and after deserializing it, - // performs the following steps: - // 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts. - // If the client created the query correctly, the server will have obtained - // `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and - // `dim1 + j`th ones encrypting `1`. - // 2- It computes the inner product of the first `dim1` ciphertexts with the - // columns if the database viewed as a dim1 * dim2 matrix. - // 3- It then multiplies the column of ciphertexts with the next `dim2` - // ciphertexts obtained after expansion of the query, then relinearize and - // modulus switch to the latest modulus to optimize communication. - // The operation is done `5` times to compute an average response time. - let response = timeit_n!("Server response", 5, { - let start = std::time::Instant::now(); - let query = bfv::Ciphertext::from_bytes(&query, ¶ms)?; - let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?; - println!("Expand: {:?}", start.elapsed()); + // Server response: The server receives the query, and after deserializing it, + // performs the following steps: + // 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts. + // If the client created the query correctly, the server will have obtained + // `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and + // `dim1 + j`th ones encrypting `1`. + // 2- It computes the inner product of the first `dim1` ciphertexts with the + // columns if the database viewed as a dim1 * dim2 matrix. + // 3- It then multiplies the column of ciphertexts with the next `dim2` + // ciphertexts obtained after expansion of the query, then relinearize and + // modulus switch to the latest modulus to optimize communication. + // The operation is done `5` times to compute an average response time. + let response = timeit_n!("Server response", 5, { + let start = std::time::Instant::now(); + let query = bfv::Ciphertext::from_bytes(&query, ¶ms)?; + let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?; + println!("Expand: {:?}", start.elapsed()); - let query_vec = &expanded_query[..dim1]; - let dot_product_mod_switch = - move |i, database: &[bfv::Plaintext]| -> fhe::Result { - let column = database.iter().skip(i).step_by(dim2); - bfv::dot_product_scalar(query_vec.iter(), column) - }; + let query_vec = &expanded_query[..dim1]; + let dot_product_mod_switch = + move |i, database: &[bfv::Plaintext]| -> fhe::Result { + let column = database.iter().skip(i).step_by(dim2); + bfv::dot_product_scalar(query_vec.iter(), column) + }; - let mut out = bfv::Ciphertext::zero(¶ms); - for (i, ci) in expanded_query[dim1..].iter().enumerate() { - out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci) - } - rk.relinearizes(&mut out)?; - out.mod_switch_to_last_level(); - out.to_bytes() - }); - println!("📄 Response: {}", HumanBytes(response.len() as u64)); + let mut out = bfv::Ciphertext::zero(¶ms); + for (i, ci) in expanded_query[dim1..].iter().enumerate() { + out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci) + } + rk.relinearizes(&mut out)?; + out.mod_switch_to_last_level(); + out.to_bytes() + }); + println!("📄 Response: {}", HumanBytes(response.len() as u64)); - // Client processing: Upon reception of the response, the client decrypts. - // Finally, it outputs the plaintext bytes, offset by the correct value - // (remember the database was reshaped to maximize how many elements) were - // embedded in a single ciphertext. - let answer = timeit!("Client answer", { - let response = bfv::Ciphertext::from_bytes(&response, ¶ms).unwrap(); + // Client processing: Upon reception of the response, the client decrypts. + // Finally, it outputs the plaintext bytes, offset by the correct value + // (remember the database was reshaped to maximize how many elements) were + // embedded in a single ciphertext. + let answer = timeit!("Client answer", { + let response = bfv::Ciphertext::from_bytes(&response, ¶ms).unwrap(); - let pt = sk.try_decrypt(&response).unwrap(); - let pt = Vec::::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap(); - let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus)); - let offset = index - % number_elements_per_plaintext( - params.degree(), - ilog2(plaintext_modulus), - elements_size, - ); + let pt = sk.try_decrypt(&response).unwrap(); + let pt = Vec::::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap(); + let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus)); + let offset = index + % number_elements_per_plaintext( + params.degree(), + ilog2(plaintext_modulus), + elements_size, + ); - println!("Noise in response: {:?}", unsafe { - sk.measure_noise(&response) - }); + println!("Noise in response: {:?}", unsafe { + sk.measure_noise(&response) + }); - plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec() - }); + plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec() + }); - assert_eq!(&database[index], &answer); + assert_eq!(&database[index], &answer); - Ok(()) + Ok(()) } diff --git a/crates/fhe/examples/sealpir.rs b/crates/fhe/examples/sealpir.rs index 0d8e169..c97e1b8 100644 --- a/crates/fhe/examples/sealpir.rs +++ b/crates/fhe/examples/sealpir.rs @@ -12,8 +12,8 @@ use console::style; use fhe::bfv; use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation}; use fhe_traits::{ - DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime, - FheEncrypter, Serialize, + DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime, + FheEncrypter, Serialize, }; use fhe_util::{div_ceil, ilog2, inverse, transcode_bidirectional, transcode_to_bytes}; use indicatif::HumanBytes; @@ -21,323 +21,323 @@ use itertools::Itertools; use rand::{rngs::OsRng, thread_rng, RngCore}; use std::{env, error::Error, process::exit, sync::Arc}; use util::{ - encode_database, generate_database, number_elements_per_plaintext, - timeit::{timeit, timeit_n}, + encode_database, generate_database, number_elements_per_plaintext, + timeit::{timeit, timeit_n}, }; fn print_notice_and_exit(max_element_size: usize, error: Option) { - println!( - "{} SealPIR with fhe.rs", - style(" overview:").magenta().bold() - ); - println!( - "{} sealpir [-h] [--help] [--database_size=] [--element_size=]", - style(" usage:").magenta().bold() - ); - println!( - "{} {} must be at least 1, and {} must be between 1 and {}", - style("constraints:").magenta().bold(), - style("database_size").blue(), - style("element_size").blue(), - max_element_size - ); - if let Some(error) = error { - println!("{} {}", style(" error:").red().bold(), error); - } - exit(0); + println!( + "{} SealPIR with fhe.rs", + style(" overview:").magenta().bold() + ); + println!( + "{} sealpir [-h] [--help] [--database_size=] [--element_size=]", + style(" usage:").magenta().bold() + ); + println!( + "{} {} must be at least 1, and {} must be between 1 and {}", + style("constraints:").magenta().bold(), + style("database_size").blue(), + style("element_size").blue(), + max_element_size + ); + if let Some(error) = error { + println!("{} {}", style(" error:").red().bold(), error); + } + exit(0); } fn main() -> Result<(), Box> { - let degree = 4096usize; - let plaintext_modulus = 2056193; - let moduli_sizes = [36, 36, 37]; + let degree = 4096usize; + let plaintext_modulus = 2056193; + let moduli_sizes = [36, 36, 37]; - // Compute what is the maximum byte-length of an element to fit within one - // ciphertext. Each coefficient of the ciphertext polynomial can contain - // floor(log2(plaintext_modulus)) bits. - let max_element_size = (ilog2(plaintext_modulus) * degree) / 8; + // Compute what is the maximum byte-length of an element to fit within one + // ciphertext. Each coefficient of the ciphertext polynomial can contain + // floor(log2(plaintext_modulus)) bits. + let max_element_size = (ilog2(plaintext_modulus) * degree) / 8; - // This executable is a command line tool which enables to specify different - // database and element sizes. - let args: Vec = env::args().skip(1).collect(); + // This executable is a command line tool which enables to specify different + // database and element sizes. + let args: Vec = env::args().skip(1).collect(); - // Print the help if requested. - if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) { - print_notice_and_exit(max_element_size, None) - } + // Print the help if requested. + if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) { + print_notice_and_exit(max_element_size, None) + } - // Use the default values from . - let mut database_size = 1 << 16; - let mut elements_size = 1024; + // Use the default values from . + let mut database_size = 1 << 16; + let mut elements_size = 1024; - // Update the database size and/or element size depending on the arguments - // provided. - for arg in &args { - if arg.starts_with("--database_size") { - let a: Vec<&str> = arg.rsplit('=').collect(); - if a.len() != 2 || a[0].parse::().is_err() { - print_notice_and_exit( - max_element_size, - Some("Invalid `--database_size` command".to_string()), - ) - } else { - database_size = a[0].parse::().unwrap() - } - } else if arg.starts_with("--element_size") { - let a: Vec<&str> = arg.rsplit('=').collect(); - if a.len() != 2 || a[0].parse::().is_err() { - print_notice_and_exit( - max_element_size, - Some("Invalid `--element_size` command".to_string()), - ) - } else { - elements_size = a[0].parse::().unwrap() - } - } else { - print_notice_and_exit( - max_element_size, - Some(format!("Unrecognized command: {arg}")), - ) - } - } - if elements_size > max_element_size || elements_size == 0 || database_size == 0 { - print_notice_and_exit( - max_element_size, - Some("Element or database sizes out of bound".to_string()), - ) - } + // Update the database size and/or element size depending on the arguments + // provided. + for arg in &args { + if arg.starts_with("--database_size") { + let a: Vec<&str> = arg.rsplit('=').collect(); + if a.len() != 2 || a[0].parse::().is_err() { + print_notice_and_exit( + max_element_size, + Some("Invalid `--database_size` command".to_string()), + ) + } else { + database_size = a[0].parse::().unwrap() + } + } else if arg.starts_with("--element_size") { + let a: Vec<&str> = arg.rsplit('=').collect(); + if a.len() != 2 || a[0].parse::().is_err() { + print_notice_and_exit( + max_element_size, + Some("Invalid `--element_size` command".to_string()), + ) + } else { + elements_size = a[0].parse::().unwrap() + } + } else { + print_notice_and_exit( + max_element_size, + Some(format!("Unrecognized command: {arg}")), + ) + } + } + if elements_size > max_element_size || elements_size == 0 || database_size == 0 { + print_notice_and_exit( + max_element_size, + Some("Element or database sizes out of bound".to_string()), + ) + } - // The parameters are within bound, let's go! Let's first display some - // information about the database. - println!("# SealPIR with fhe.rs"); - println!( - "database of {}", - HumanBytes((database_size * elements_size) as u64) - ); - println!("\tdatabase_size = {database_size}"); - println!("\telements_size = {elements_size}"); + // The parameters are within bound, let's go! Let's first display some + // information about the database. + println!("# SealPIR with fhe.rs"); + println!( + "database of {}", + HumanBytes((database_size * elements_size) as u64) + ); + println!("\tdatabase_size = {database_size}"); + println!("\telements_size = {elements_size}"); - // Generation of a random database. - let database = timeit!("Database generation", { - generate_database(database_size, elements_size) - }); + // Generation of a random database. + let database = timeit!("Database generation", { + generate_database(database_size, elements_size) + }); - // Let's generate the BFV parameters structure. - let params = timeit!( - "Parameters generation", - Arc::new( - bfv::BfvParametersBuilder::new() - .set_degree(degree) - .set_plaintext_modulus(plaintext_modulus) - .set_moduli_sizes(&moduli_sizes) - .build() - .unwrap() - ) - ); + // Let's generate the BFV parameters structure. + let params = timeit!( + "Parameters generation", + Arc::new( + bfv::BfvParametersBuilder::new() + .set_degree(degree) + .set_plaintext_modulus(plaintext_modulus) + .set_moduli_sizes(&moduli_sizes) + .build() + .unwrap() + ) + ); - // Proprocess the database on the server side: the database will be reshaped - // so as to pack as many values as possible in every row so that it fits in one - // ciphertext, and each element will be encoded as a polynomial in Ntt - // representation. - let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", { - encode_database(&database, params.clone(), 1) - }); + // Proprocess the database on the server side: the database will be reshaped + // so as to pack as many values as possible in every row so that it fits in one + // ciphertext, and each element will be encoded as a polynomial in Ntt + // representation. + let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", { + encode_database(&database, params.clone(), 1) + }); - // Client setup: the client generates a secret key, and an evaluation key for - // the server will which enable to obliviously expand a ciphertext up to (dim1 + - // dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)). - let (sk, ek_expansion_serialized) = timeit!("Client setup", { - let sk = bfv::SecretKey::random(¶ms, &mut OsRng); - let level = ilog2((dim1 + dim2).next_power_of_two() as u64); - println!("expansion_level = {level}"); - let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)? - .enable_expansion(level)? - .build(&mut thread_rng())?; - let ek_expansion_serialized = ek_expansion.to_bytes(); - (sk, ek_expansion_serialized) - }); - println!( - "📄 Evaluation key: {}", - HumanBytes(ek_expansion_serialized.len() as u64) - ); + // Client setup: the client generates a secret key, and an evaluation key for + // the server will which enable to obliviously expand a ciphertext up to (dim1 + + // dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)). + let (sk, ek_expansion_serialized) = timeit!("Client setup", { + let sk = bfv::SecretKey::random(¶ms, &mut OsRng); + let level = ilog2((dim1 + dim2).next_power_of_two() as u64); + println!("expansion_level = {level}"); + let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)? + .enable_expansion(level)? + .build(&mut thread_rng())?; + let ek_expansion_serialized = ek_expansion.to_bytes(); + (sk, ek_expansion_serialized) + }); + println!( + "📄 Evaluation key: {}", + HumanBytes(ek_expansion_serialized.len() as u64) + ); - // Server setup: the server receives the evaluation key and deserializes it. - let ek_expansion = timeit!( - "Server setup", - bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)? - ); + // Server setup: the server receives the evaluation key and deserializes it. + let ek_expansion = timeit!( + "Server setup", + bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)? + ); - // Client query: when the client wants to retrieve the `index`-th row of the - // original database, it first computes to which row it corresponds in the - // original database, and then encrypt a selection vector with 0 everywhere, - // except at two indices i and (dim1 + j) such that `query_index = i * dim 2 + - // j` where it sets the value (2^level)^(-1) modulo the plaintext space. - // It then encodes this vector as a `polynomial` and encrypt the plaintext. - // The ciphertext is set at level `1`, which means that one of the three moduli - // has been dropped already; the reason is that the expansion will happen at - // level 0 (with all three moduli) and then one of the moduli will be dropped - // to reduce the noise. - let index = (thread_rng().next_u64() as usize) % database_size; - let query = timeit!("Client query", { - let level = ilog2((dim1 + dim2).next_power_of_two() as u64); - let query_index = index - / number_elements_per_plaintext( - params.degree(), - ilog2(plaintext_modulus), - elements_size, - ); - let mut pt = vec![0u64; dim1 + dim2]; - let inv = inverse(1 << level, plaintext_modulus).unwrap(); - pt[query_index / dim2] = inv; - pt[dim1 + (query_index % dim2)] = inv; - let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?; - let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?; - query.to_bytes() - }); - println!("📄 Query: {}", HumanBytes(query.len() as u64)); + // Client query: when the client wants to retrieve the `index`-th row of the + // original database, it first computes to which row it corresponds in the + // original database, and then encrypt a selection vector with 0 everywhere, + // except at two indices i and (dim1 + j) such that `query_index = i * dim 2 + + // j` where it sets the value (2^level)^(-1) modulo the plaintext space. + // It then encodes this vector as a `polynomial` and encrypt the plaintext. + // The ciphertext is set at level `1`, which means that one of the three moduli + // has been dropped already; the reason is that the expansion will happen at + // level 0 (with all three moduli) and then one of the moduli will be dropped + // to reduce the noise. + let index = (thread_rng().next_u64() as usize) % database_size; + let query = timeit!("Client query", { + let level = ilog2((dim1 + dim2).next_power_of_two() as u64); + let query_index = index + / number_elements_per_plaintext( + params.degree(), + ilog2(plaintext_modulus), + elements_size, + ); + let mut pt = vec![0u64; dim1 + dim2]; + let inv = inverse(1 << level, plaintext_modulus).unwrap(); + pt[query_index / dim2] = inv; + pt[dim1 + (query_index % dim2)] = inv; + let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?; + let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?; + query.to_bytes() + }); + println!("📄 Query: {}", HumanBytes(query.len() as u64)); - // Server response: The server receives the query, and after deserializing it, - // performs the following steps: - // 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts. - // If the client created the query correctly, the server will have obtained - // `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and - // `dim1 + j`th ones encrypting `1`. - // 2- It computes the inner product of the first `dim1` ciphertexts with the - // columns if the database viewed as a dim1 * dim2 matrix, and modulo-switch - // the ciphertext once. - // 3- It parses the resulting ciphertexts as vector of plaintexts, and compute - // the inner product of the last `dim2` ciphertexts from step 1 with the - // transposed of the plaintext obtained above. - // The operation is done `5` times to compute an average response time. - let responses: Vec> = timeit_n!("Server response", 5, { - let start = std::time::Instant::now(); - let query = bfv::Ciphertext::from_bytes(&query, ¶ms); - let query = query.unwrap(); - let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?; - println!("Expand: {}", DisplayDuration(start.elapsed())); + // Server response: The server receives the query, and after deserializing it, + // performs the following steps: + // 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts. + // If the client created the query correctly, the server will have obtained + // `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and + // `dim1 + j`th ones encrypting `1`. + // 2- It computes the inner product of the first `dim1` ciphertexts with the + // columns if the database viewed as a dim1 * dim2 matrix, and modulo-switch + // the ciphertext once. + // 3- It parses the resulting ciphertexts as vector of plaintexts, and compute + // the inner product of the last `dim2` ciphertexts from step 1 with the + // transposed of the plaintext obtained above. + // The operation is done `5` times to compute an average response time. + let responses: Vec> = timeit_n!("Server response", 5, { + let start = std::time::Instant::now(); + let query = bfv::Ciphertext::from_bytes(&query, ¶ms); + let query = query.unwrap(); + let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?; + println!("Expand: {}", DisplayDuration(start.elapsed())); - let query_vec = &expanded_query[..dim1]; - let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| { - let column = database.iter().skip(i).step_by(dim2); - let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?; - c.mod_switch_to_last_level(); - Ok(c) - }; + let query_vec = &expanded_query[..dim1]; + let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| { + let column = database.iter().skip(i).step_by(dim2); + let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?; + c.mod_switch_to_last_level(); + Ok(c) + }; - let dot_products = (0..dim2) - .map(|i| dot_product_mod_switch(i, &preprocessed_database)) - .collect::>>()?; + let dot_products = (0..dim2) + .map(|i| dot_product_mod_switch(i, &preprocessed_database)) + .collect::>>()?; - let fold = dot_products - .iter() - .map(|c| { - let mut pt_values = Vec::with_capacity(div_ceil( - 2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize)), - ilog2(plaintext_modulus), - )); - pt_values.append(&mut transcode_bidirectional( - c.get(0).unwrap().coefficients().as_slice().unwrap(), - 64 - params.moduli()[0].leading_zeros() as usize, - ilog2(plaintext_modulus), - )); - pt_values.append(&mut transcode_bidirectional( - c.get(1).unwrap().coefficients().as_slice().unwrap(), - 64 - params.moduli()[0].leading_zeros() as usize, - ilog2(plaintext_modulus), - )); - unsafe { - Ok(bfv::PlaintextVec::try_encode_vt( - &pt_values, - bfv::Encoding::poly_at_level(1), - ¶ms, - )? - .0) - } - }) - .collect::>>>()?; - (0..fold[0].len()) - .map(|i| { - let mut outi = bfv::dot_product_scalar( - expanded_query[dim1..].iter(), - fold.iter().map(|pts| pts.get(i).unwrap()), - )?; - outi.mod_switch_to_last_level(); - Ok(outi.to_bytes()) - }) - .collect::>>>()? - }); - println!( - "📄 Response: {}", - HumanBytes(responses.iter().map(|r| r.len()).sum::() as u64) - ); + let fold = dot_products + .iter() + .map(|c| { + let mut pt_values = Vec::with_capacity(div_ceil( + 2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize)), + ilog2(plaintext_modulus), + )); + pt_values.append(&mut transcode_bidirectional( + c.get(0).unwrap().coefficients().as_slice().unwrap(), + 64 - params.moduli()[0].leading_zeros() as usize, + ilog2(plaintext_modulus), + )); + pt_values.append(&mut transcode_bidirectional( + c.get(1).unwrap().coefficients().as_slice().unwrap(), + 64 - params.moduli()[0].leading_zeros() as usize, + ilog2(plaintext_modulus), + )); + unsafe { + Ok(bfv::PlaintextVec::try_encode_vt( + &pt_values, + bfv::Encoding::poly_at_level(1), + ¶ms, + )? + .0) + } + }) + .collect::>>>()?; + (0..fold[0].len()) + .map(|i| { + let mut outi = bfv::dot_product_scalar( + expanded_query[dim1..].iter(), + fold.iter().map(|pts| pts.get(i).unwrap()), + )?; + outi.mod_switch_to_last_level(); + Ok(outi.to_bytes()) + }) + .collect::>>>()? + }); + println!( + "📄 Response: {}", + HumanBytes(responses.iter().map(|r| r.len()).sum::() as u64) + ); - // Client processing: Upon reception of the response, the client decrypts - // the ciphertexts and recover the "ciphertexts" which were parsed as plaintext, - // which it decrypts too. Finally, it outputs the plaintext bytes, offset by the - // correct value (remember the database was reshaped to maximize how many - // elements) were embedded in a single ciphertext. - let answer = timeit!("Client answer", { - let responses = responses - .iter() - .map(|r| bfv::Ciphertext::from_bytes(r, ¶ms).unwrap()) - .collect_vec(); - let decrypted_pt = responses - .iter() - .map(|r| sk.try_decrypt(r).unwrap()) - .collect_vec(); - let decrypted_vec = decrypted_pt - .iter() - .flat_map(|pt| Vec::::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap()) - .collect_vec(); - let expect_ncoefficients = div_ceil( - params.degree() * (64 - params.moduli()[0].leading_zeros() as usize), - ilog2(plaintext_modulus), - ); - assert!(decrypted_vec.len() >= 2 * expect_ncoefficients); - let mut poly0 = transcode_bidirectional( - &decrypted_vec[..expect_ncoefficients], - ilog2(plaintext_modulus), - 64 - params.moduli()[0].leading_zeros() as usize, - ); - let mut poly1 = transcode_bidirectional( - &decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients], - ilog2(plaintext_modulus), - 64 - params.moduli()[0].leading_zeros() as usize, - ); - assert!(poly0.len() >= params.degree()); - assert!(poly1.len() >= params.degree()); - poly0.truncate(params.degree()); - poly1.truncate(params.degree()); + // Client processing: Upon reception of the response, the client decrypts + // the ciphertexts and recover the "ciphertexts" which were parsed as plaintext, + // which it decrypts too. Finally, it outputs the plaintext bytes, offset by the + // correct value (remember the database was reshaped to maximize how many + // elements) were embedded in a single ciphertext. + let answer = timeit!("Client answer", { + let responses = responses + .iter() + .map(|r| bfv::Ciphertext::from_bytes(r, ¶ms).unwrap()) + .collect_vec(); + let decrypted_pt = responses + .iter() + .map(|r| sk.try_decrypt(r).unwrap()) + .collect_vec(); + let decrypted_vec = decrypted_pt + .iter() + .flat_map(|pt| Vec::::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap()) + .collect_vec(); + let expect_ncoefficients = div_ceil( + params.degree() * (64 - params.moduli()[0].leading_zeros() as usize), + ilog2(plaintext_modulus), + ); + assert!(decrypted_vec.len() >= 2 * expect_ncoefficients); + let mut poly0 = transcode_bidirectional( + &decrypted_vec[..expect_ncoefficients], + ilog2(plaintext_modulus), + 64 - params.moduli()[0].leading_zeros() as usize, + ); + let mut poly1 = transcode_bidirectional( + &decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients], + ilog2(plaintext_modulus), + 64 - params.moduli()[0].leading_zeros() as usize, + ); + assert!(poly0.len() >= params.degree()); + assert!(poly1.len() >= params.degree()); + poly0.truncate(params.degree()); + poly1.truncate(params.degree()); - let ctx = Arc::new(Context::new(¶ms.moduli()[..1], params.degree())?); - let ct = bfv::Ciphertext::new( - vec![ - Poly::try_convert_from(poly0, &ctx, true, Representation::Ntt)?, - Poly::try_convert_from(poly1, &ctx, true, Representation::Ntt)?, - ], - ¶ms, - )?; + let ctx = Arc::new(Context::new(¶ms.moduli()[..1], params.degree())?); + let ct = bfv::Ciphertext::new( + vec![ + Poly::try_convert_from(poly0, &ctx, true, Representation::Ntt)?, + Poly::try_convert_from(poly1, &ctx, true, Representation::Ntt)?, + ], + ¶ms, + )?; - let pt = sk.try_decrypt(&ct).unwrap(); - let pt = Vec::::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap(); - let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus)); - let offset = index - % number_elements_per_plaintext( - params.degree(), - ilog2(plaintext_modulus), - elements_size, - ); + let pt = sk.try_decrypt(&ct).unwrap(); + let pt = Vec::::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap(); + let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus)); + let offset = index + % number_elements_per_plaintext( + params.degree(), + ilog2(plaintext_modulus), + elements_size, + ); - println!("Noise in response (ct): {:?}", unsafe { - sk.measure_noise(&ct) - }); + println!("Noise in response (ct): {:?}", unsafe { + sk.measure_noise(&ct) + }); - plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec() - }); + plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec() + }); - // Assert that the answer is indeed the `index`-th element of the initial - // database. - assert_eq!(&database[index], &answer); + // Assert that the answer is indeed the `index`-th element of the initial + // database. + assert_eq!(&database[index], &answer); - Ok(()) + Ok(()) } diff --git a/crates/fhe/examples/util.rs b/crates/fhe/examples/util.rs index 5541ed0..f01d418 100644 --- a/crates/fhe/examples/util.rs +++ b/crates/fhe/examples/util.rs @@ -7,37 +7,37 @@ use std::{cmp::min, fmt, sync::Arc, time::Duration}; /// Macros to time code and display a human-readable duration. pub mod timeit { - #[allow(unused_macros)] - macro_rules! timeit_n { - ($name:expr, $loops:expr, $code:expr) => {{ - use util::DisplayDuration; - let start = std::time::Instant::now(); + #[allow(unused_macros)] + macro_rules! timeit_n { + ($name:expr, $loops:expr, $code:expr) => {{ + use util::DisplayDuration; + let start = std::time::Instant::now(); - #[allow(clippy::reversed_empty_ranges)] - for _ in 1..$loops { - let _ = $code; - } - let r = $code; - println!( - "⏱ {}: {}", - $name, - DisplayDuration(start.elapsed() / $loops) - ); - r - }}; - } + #[allow(clippy::reversed_empty_ranges)] + for _ in 1..$loops { + let _ = $code; + } + let r = $code; + println!( + "⏱ {}: {}", + $name, + DisplayDuration(start.elapsed() / $loops) + ); + r + }}; + } - #[allow(unused_macros)] - macro_rules! timeit { - ($name:expr, $code:expr) => {{ - timeit_n!($name, 1, $code) - }}; - } + #[allow(unused_macros)] + macro_rules! timeit { + ($name:expr, $code:expr) => {{ + timeit_n!($name, 1, $code) + }}; + } - #[allow(unused_imports)] - pub(crate) use timeit; - #[allow(unused_imports)] - pub(crate) use timeit_n; + #[allow(unused_imports)] + pub(crate) use timeit; + #[allow(unused_imports)] + pub(crate) use timeit_n; } /// Utility struct for displaying human-readable duration of the form "10.5 ms", @@ -45,17 +45,17 @@ pub mod timeit { pub struct DisplayDuration(pub Duration); impl fmt::Display for DisplayDuration { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let duration_ns = self.0.as_nanos(); - if duration_ns < 1_000_u128 { - write!(f, "{duration_ns} ns") - } else if duration_ns < 1_000_000_u128 { - write!(f, "{} μs", (duration_ns + 500) / 1_000) - } else { - let duration_ms_times_10 = (duration_ns + 50_000) / (100_000); - write!(f, "{} ms", (duration_ms_times_10 as f64) / 10.0) - } - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let duration_ns = self.0.as_nanos(); + if duration_ns < 1_000_u128 { + write!(f, "{duration_ns} ns") + } else if duration_ns < 1_000_000_u128 { + write!(f, "{} μs", (duration_ns + 500) / 1_000) + } else { + let duration_ms_times_10 = (duration_ns + 50_000) / (100_000); + write!(f, "{} ms", (duration_ms_times_10 as f64) / 10.0) + } + } } // Utility functions for Private Information Retrieval. @@ -64,60 +64,60 @@ impl fmt::Display for DisplayDuration { /// little endian encoding of the index. When the element size is less than 4B, /// the encoding is truncated. pub fn generate_database(database_size: usize, elements_size: usize) -> Vec> { - assert!(elements_size > 0 && database_size > 0); - let mut database = vec![vec![0u8; elements_size]; database_size]; - for (i, element) in database.iter_mut().enumerate() { - element[..min(4, elements_size)] - .copy_from_slice(&(i as u32).to_le_bytes()[..min(4, elements_size)]); - } - database + assert!(elements_size > 0 && database_size > 0); + let mut database = vec![vec![0u8; elements_size]; database_size]; + for (i, element) in database.iter_mut().enumerate() { + element[..min(4, elements_size)] + .copy_from_slice(&(i as u32).to_le_bytes()[..min(4, elements_size)]); + } + database } pub fn number_elements_per_plaintext( - degree: usize, - plaintext_nbits: usize, - elements_size: usize, + degree: usize, + plaintext_nbits: usize, + elements_size: usize, ) -> usize { - (plaintext_nbits * degree) / (elements_size * 8) + (plaintext_nbits * degree) / (elements_size * 8) } pub fn encode_database( - database: &Vec>, - par: Arc, - level: usize, + database: &Vec>, + par: Arc, + level: usize, ) -> (Vec, (usize, usize)) { - assert!(!database.is_empty()); + assert!(!database.is_empty()); - let elements_size = database[0].len(); - let plaintext_nbits = ilog2(par.plaintext()); - let number_elements_per_plaintext = - number_elements_per_plaintext(par.degree(), plaintext_nbits, elements_size); - let number_rows = - (database.len() + number_elements_per_plaintext - 1) / number_elements_per_plaintext; - println!("number_rows = {number_rows}"); - println!("number_elements_per_plaintext = {number_elements_per_plaintext}"); - let dimension_1 = (number_rows as f64).sqrt().ceil() as usize; - let dimension_2 = (number_rows + dimension_1 - 1) / dimension_1; - println!("dimensions = {dimension_1} {dimension_2}"); - println!("dimension = {}", dimension_1 * dimension_2); - let mut preprocessed_database = - vec![ - bfv::Plaintext::zero(bfv::Encoding::poly_at_level(level), &par).unwrap(); - dimension_1 * dimension_2 - ]; - (0..number_rows).for_each(|i| { - let mut serialized_plaintext = vec![0u8; number_elements_per_plaintext * elements_size]; - for j in 0..number_elements_per_plaintext { - if let Some(pt) = database.get(j + i * number_elements_per_plaintext) { - serialized_plaintext[j * elements_size..(j + 1) * elements_size].copy_from_slice(pt) - } - } - let pt_values = transcode_from_bytes(&serialized_plaintext, plaintext_nbits); - preprocessed_database[i] = - bfv::Plaintext::try_encode(&pt_values, bfv::Encoding::poly_at_level(level), &par) - .unwrap(); - }); - (preprocessed_database, (dimension_1, dimension_2)) + let elements_size = database[0].len(); + let plaintext_nbits = ilog2(par.plaintext()); + let number_elements_per_plaintext = + number_elements_per_plaintext(par.degree(), plaintext_nbits, elements_size); + let number_rows = + (database.len() + number_elements_per_plaintext - 1) / number_elements_per_plaintext; + println!("number_rows = {number_rows}"); + println!("number_elements_per_plaintext = {number_elements_per_plaintext}"); + let dimension_1 = (number_rows as f64).sqrt().ceil() as usize; + let dimension_2 = (number_rows + dimension_1 - 1) / dimension_1; + println!("dimensions = {dimension_1} {dimension_2}"); + println!("dimension = {}", dimension_1 * dimension_2); + let mut preprocessed_database = + vec![ + bfv::Plaintext::zero(bfv::Encoding::poly_at_level(level), &par).unwrap(); + dimension_1 * dimension_2 + ]; + (0..number_rows).for_each(|i| { + let mut serialized_plaintext = vec![0u8; number_elements_per_plaintext * elements_size]; + for j in 0..number_elements_per_plaintext { + if let Some(pt) = database.get(j + i * number_elements_per_plaintext) { + serialized_plaintext[j * elements_size..(j + 1) * elements_size].copy_from_slice(pt) + } + } + let pt_values = transcode_from_bytes(&serialized_plaintext, plaintext_nbits); + preprocessed_database[i] = + bfv::Plaintext::try_encode(&pt_values, bfv::Encoding::poly_at_level(level), &par) + .unwrap(); + }); + (preprocessed_database, (dimension_1, dimension_2)) } #[allow(dead_code)] diff --git a/crates/fhe/src/bfv/ciphertext.rs b/crates/fhe/src/bfv/ciphertext.rs index 0cf3653..a0522c1 100644 --- a/crates/fhe/src/bfv/ciphertext.rs +++ b/crates/fhe/src/bfv/ciphertext.rs @@ -1,12 +1,12 @@ //! Ciphertext type in the BFV encryption scheme. use crate::bfv::{ - parameters::BfvParameters, proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, + parameters::BfvParameters, proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, }; use crate::{Error, Result}; use fhe_math::rq::{Poly, Representation}; use fhe_traits::{ - DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize, + DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize, }; use protobuf::Message; use rand::SeedableRng; @@ -16,287 +16,287 @@ use std::sync::Arc; /// A ciphertext encrypting a plaintext. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Ciphertext { - /// The parameters of the underlying BFV encryption scheme. - pub(crate) par: Arc, + /// The parameters of the underlying BFV encryption scheme. + pub(crate) par: Arc, - /// The seed that generated the polynomial c1 in a fresh ciphertext. - pub(crate) seed: Option<::Seed>, + /// The seed that generated the polynomial c1 in a fresh ciphertext. + pub(crate) seed: Option<::Seed>, - /// The ciphertext elements. - pub(crate) c: Vec, + /// The ciphertext elements. + pub(crate) c: Vec, - /// The ciphertext level - pub(crate) level: usize, + /// The ciphertext level + pub(crate) level: usize, } impl Ciphertext { - /// Modulo switch the ciphertext to the last level. - pub fn mod_switch_to_last_level(&mut self) { - self.level = self.par.max_level(); - let last_ctx = self.par.ctx_at_level(self.level).unwrap(); - self.seed = None; - self.c.iter_mut().for_each(|ci| { - if ci.ctx() != last_ctx { - ci.change_representation(Representation::PowerBasis); - assert!(ci.mod_switch_down_to(last_ctx).is_ok()); - ci.change_representation(Representation::Ntt); - } - }); - } + /// Modulo switch the ciphertext to the last level. + pub fn mod_switch_to_last_level(&mut self) { + self.level = self.par.max_level(); + let last_ctx = self.par.ctx_at_level(self.level).unwrap(); + self.seed = None; + self.c.iter_mut().for_each(|ci| { + if ci.ctx() != last_ctx { + ci.change_representation(Representation::PowerBasis); + assert!(ci.mod_switch_down_to(last_ctx).is_ok()); + ci.change_representation(Representation::Ntt); + } + }); + } - /// Modulo switch the ciphertext to the next level. - pub fn mod_switch_to_next_level(&mut self) { - if self.level < self.par.max_level() { - self.seed = None; - self.c.iter_mut().for_each(|ci| { - ci.change_representation(Representation::PowerBasis); - assert!(ci.mod_switch_down_next().is_ok()); - ci.change_representation(Representation::Ntt); - }); - self.level += 1 - } - } + /// Modulo switch the ciphertext to the next level. + pub fn mod_switch_to_next_level(&mut self) { + if self.level < self.par.max_level() { + self.seed = None; + self.c.iter_mut().for_each(|ci| { + ci.change_representation(Representation::PowerBasis); + assert!(ci.mod_switch_down_next().is_ok()); + ci.change_representation(Representation::Ntt); + }); + self.level += 1 + } + } - /// Create a ciphertext from a vector of polynomials. - /// A ciphertext must contain at least two polynomials, and all polynomials - /// must be in Ntt representation and with the same context. - pub fn new(c: Vec, par: &Arc) -> Result { - if c.len() < 2 { - return Err(Error::TooFewValues(c.len(), 2)); - } + /// Create a ciphertext from a vector of polynomials. + /// A ciphertext must contain at least two polynomials, and all polynomials + /// must be in Ntt representation and with the same context. + pub fn new(c: Vec, par: &Arc) -> Result { + if c.len() < 2 { + return Err(Error::TooFewValues(c.len(), 2)); + } - let ctx = c[0].ctx(); - let level = par.level_of_ctx(ctx)?; + let ctx = c[0].ctx(); + let level = par.level_of_ctx(ctx)?; - // Check that all polynomials have the expected representation and context. - for ci in c.iter() { - if ci.representation() != &Representation::Ntt { - return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation( - ci.representation().clone(), - Representation::Ntt, - ))); - } - if ci.ctx() != ctx { - return Err(Error::MathError(fhe_math::Error::InvalidContext)); - } - } + // Check that all polynomials have the expected representation and context. + for ci in c.iter() { + if ci.representation() != &Representation::Ntt { + return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation( + ci.representation().clone(), + Representation::Ntt, + ))); + } + if ci.ctx() != ctx { + return Err(Error::MathError(fhe_math::Error::InvalidContext)); + } + } - Ok(Self { - par: par.clone(), - seed: None, - c, - level, - }) - } + Ok(Self { + par: par.clone(), + seed: None, + c, + level, + }) + } - /// Get the i-th polynomial of the ciphertext. - pub fn get(&self, i: usize) -> Option<&Poly> { - self.c.get(i) - } + /// Get the i-th polynomial of the ciphertext. + pub fn get(&self, i: usize) -> Option<&Poly> { + self.c.get(i) + } } impl FheCiphertext for Ciphertext {} impl FheParametrized for Ciphertext { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl Serialize for Ciphertext { - fn to_bytes(&self) -> Vec { - CiphertextProto::from(self).write_to_bytes().unwrap() - } + fn to_bytes(&self) -> Vec { + CiphertextProto::from(self).write_to_bytes().unwrap() + } } impl DeserializeParametrized for Ciphertext { - fn from_bytes(bytes: &[u8], par: &Arc) -> Result { - if let Ok(ctp) = CiphertextProto::parse_from_bytes(bytes) { - Ciphertext::try_convert_from(&ctp, par) - } else { - Err(Error::SerializationError) - } - } + fn from_bytes(bytes: &[u8], par: &Arc) -> Result { + if let Ok(ctp) = CiphertextProto::parse_from_bytes(bytes) { + Ciphertext::try_convert_from(&ctp, par) + } else { + Err(Error::SerializationError) + } + } - type Error = Error; + type Error = Error; } impl Ciphertext { - /// Generate the zero ciphertext. - pub fn zero(par: &Arc) -> Self { - Self { - par: par.clone(), - seed: None, - c: Default::default(), - level: 0, - } - } + /// Generate the zero ciphertext. + pub fn zero(par: &Arc) -> Self { + Self { + par: par.clone(), + seed: None, + c: Default::default(), + level: 0, + } + } } /// Conversions from and to protobuf. impl From<&Ciphertext> for CiphertextProto { - fn from(ct: &Ciphertext) -> Self { - let mut proto = CiphertextProto::new(); - for i in 0..ct.c.len() - 1 { - proto.c.push(ct.c[i].to_bytes()) - } - if let Some(seed) = ct.seed { - proto.seed = seed.to_vec() - } else { - proto.c.push(ct.c[ct.c.len() - 1].to_bytes()) - } - proto.level = ct.level as u32; - proto - } + fn from(ct: &Ciphertext) -> Self { + let mut proto = CiphertextProto::new(); + for i in 0..ct.c.len() - 1 { + proto.c.push(ct.c[i].to_bytes()) + } + if let Some(seed) = ct.seed { + proto.seed = seed.to_vec() + } else { + proto.c.push(ct.c[ct.c.len() - 1].to_bytes()) + } + proto.level = ct.level as u32; + proto + } } impl TryConvertFrom<&CiphertextProto> for Ciphertext { - fn try_convert_from(value: &CiphertextProto, par: &Arc) -> Result { - if value.c.is_empty() || (value.c.len() == 1 && value.seed.is_empty()) { - return Err(Error::DefaultError("Not enough polynomials".to_string())); - } + fn try_convert_from(value: &CiphertextProto, par: &Arc) -> Result { + if value.c.is_empty() || (value.c.len() == 1 && value.seed.is_empty()) { + return Err(Error::DefaultError("Not enough polynomials".to_string())); + } - if value.level as usize > par.max_level() { - return Err(Error::DefaultError("Invalid level".to_string())); - } + if value.level as usize > par.max_level() { + return Err(Error::DefaultError("Invalid level".to_string())); + } - let ctx = par.ctx_at_level(value.level as usize)?; + let ctx = par.ctx_at_level(value.level as usize)?; - let mut seed = None; + let mut seed = None; - let mut c = Vec::with_capacity(value.c.len() + 1); - for cip in &value.c { - c.push(Poly::from_bytes(cip, ctx)?) - } + let mut c = Vec::with_capacity(value.c.len() + 1); + for cip in &value.c { + c.push(Poly::from_bytes(cip, ctx)?) + } - if !value.seed.is_empty() { - let try_seed = ::Seed::try_from(value.seed.clone()); - if try_seed.is_err() { - return Err(Error::MathError(fhe_math::Error::InvalidSeedSize( - value.seed.len(), - ::Seed::default().len(), - ))); - } - seed = try_seed.ok(); - let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, seed.unwrap()); - unsafe { c1.allow_variable_time_computations() } - c.push(c1) - } + if !value.seed.is_empty() { + let try_seed = ::Seed::try_from(value.seed.clone()); + if try_seed.is_err() { + return Err(Error::MathError(fhe_math::Error::InvalidSeedSize( + value.seed.len(), + ::Seed::default().len(), + ))); + } + seed = try_seed.ok(); + let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, seed.unwrap()); + unsafe { c1.allow_variable_time_computations() } + c.push(c1) + } - Ok(Ciphertext { - par: par.clone(), - seed, - c, - level: value.level as usize, - }) - } + Ok(Ciphertext { + par: par.clone(), + seed, + c, + level: value.level as usize, + }) + } } #[cfg(test)] mod tests { - use crate::bfv::{ - proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, BfvParameters, - Ciphertext, Encoding, Plaintext, SecretKey, - }; - use fhe_traits::FheDecrypter; - use fhe_traits::{DeserializeParametrized, FheEncoder, FheEncrypter, Serialize}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use crate::bfv::{ + proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, BfvParameters, + Ciphertext, Encoding, Plaintext, SecretKey, + }; + use fhe_traits::FheDecrypter; + use fhe_traits::{DeserializeParametrized, FheEncoder, FheEncrypter, Serialize}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn proto_conversion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; - let ct = sk.try_encrypt(&pt, &mut rng)?; - let ct_proto = CiphertextProto::from(&ct); - assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?); + #[test] + fn proto_conversion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; + let ct = sk.try_encrypt(&pt, &mut rng)?; + let ct_proto = CiphertextProto::from(&ct); + assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?); - let ct = &ct * &ct; - let ct_proto = CiphertextProto::from(&ct); - assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?) - } - Ok(()) - } + let ct = &ct * &ct; + let ct_proto = CiphertextProto::from(&ct); + assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?) + } + Ok(()) + } - #[test] - fn serialize() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; - let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; - let ct_bytes = ct.to_bytes(); - assert_eq!(ct, Ciphertext::from_bytes(&ct_bytes, ¶ms)?); - } - Ok(()) - } + #[test] + fn serialize() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; + let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + let ct_bytes = ct.to_bytes(); + assert_eq!(ct, Ciphertext::from_bytes(&ct_bytes, ¶ms)?); + } + Ok(()) + } - #[test] - fn new() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; - let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; - let mut ct3 = &ct * &ct; + #[test] + fn new() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; + let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + let mut ct3 = &ct * &ct; - let c0 = ct3.get(0).unwrap(); - let c1 = ct3.get(1).unwrap(); - let c2 = ct3.get(2).unwrap(); + let c0 = ct3.get(0).unwrap(); + let c1 = ct3.get(1).unwrap(); + let c2 = ct3.get(2).unwrap(); - assert_eq!( - ct3, - Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)? - ); - assert_eq!(ct3.level, 0); + assert_eq!( + ct3, + Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)? + ); + assert_eq!(ct3.level, 0); - ct3.mod_switch_to_last_level(); + ct3.mod_switch_to_last_level(); - let c0 = ct3.get(0).unwrap(); - let c1 = ct3.get(1).unwrap(); - let c2 = ct3.get(2).unwrap(); - assert_eq!( - ct3, - Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)? - ); - assert_eq!(ct3.level, params.max_level()); - } + let c0 = ct3.get(0).unwrap(); + let c1 = ct3.get(1).unwrap(); + let c2 = ct3.get(2).unwrap(); + assert_eq!( + ct3, + Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)? + ); + assert_eq!(ct3.level, params.max_level()); + } - Ok(()) - } + Ok(()) + } - #[test] - fn mod_switch_to_last_level() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; - let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + #[test] + fn mod_switch_to_last_level() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; + let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; - assert_eq!(ct.level, 0); - ct.mod_switch_to_last_level(); - assert_eq!(ct.level, params.max_level()); + assert_eq!(ct.level, 0); + ct.mod_switch_to_last_level(); + assert_eq!(ct.level, params.max_level()); - let decrypted = sk.try_decrypt(&ct)?; - assert_eq!(decrypted.value, pt.value); - } + let decrypted = sk.try_decrypt(&ct)?; + assert_eq!(decrypted.value, pt.value); + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/encoding.rs b/crates/fhe/src/bfv/encoding.rs index 83dc776..bafd338 100644 --- a/crates/fhe/src/bfv/encoding.rs +++ b/crates/fhe/src/bfv/encoding.rs @@ -6,71 +6,71 @@ use fhe_traits::FhePlaintextEncoding; #[derive(Debug, Clone, Eq, PartialEq)] pub(crate) enum EncodingEnum { - Poly, - Simd, + Poly, + Simd, } impl Display for EncodingEnum { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } } /// An encoding for the plaintext. #[derive(Debug, Clone, Eq, PartialEq)] pub struct Encoding { - pub(crate) encoding: EncodingEnum, - pub(crate) level: usize, + pub(crate) encoding: EncodingEnum, + pub(crate) level: usize, } impl Encoding { - /// A Poly encoding encodes a vector as coefficients of a polynomial; - /// homomorphic operations are therefore polynomial operations. - pub fn poly() -> Self { - Self { - encoding: EncodingEnum::Poly, - level: 0, - } - } + /// A Poly encoding encodes a vector as coefficients of a polynomial; + /// homomorphic operations are therefore polynomial operations. + pub fn poly() -> Self { + Self { + encoding: EncodingEnum::Poly, + level: 0, + } + } - /// A Simd encoding encodes a vector so that homomorphic operations are - /// component-wise operations on the coefficients of the underlying vectors. - /// The Simd encoding require that the plaintext modulus is congruent to 1 - /// modulo the degree of the underlying polynomial. - pub fn simd() -> Self { - Self { - encoding: EncodingEnum::Simd, - level: 0, - } - } + /// A Simd encoding encodes a vector so that homomorphic operations are + /// component-wise operations on the coefficients of the underlying vectors. + /// The Simd encoding require that the plaintext modulus is congruent to 1 + /// modulo the degree of the underlying polynomial. + pub fn simd() -> Self { + Self { + encoding: EncodingEnum::Simd, + level: 0, + } + } - /// A poly encoding at a given level. - pub fn poly_at_level(level: usize) -> Self { - Self { - encoding: EncodingEnum::Poly, - level, - } - } + /// A poly encoding at a given level. + pub fn poly_at_level(level: usize) -> Self { + Self { + encoding: EncodingEnum::Poly, + level, + } + } - /// A simd encoding at a given level. - pub fn simd_at_level(level: usize) -> Self { - Self { - encoding: EncodingEnum::Simd, - level, - } - } + /// A simd encoding at a given level. + pub fn simd_at_level(level: usize) -> Self { + Self { + encoding: EncodingEnum::Simd, + level, + } + } } impl From for String { - fn from(e: Encoding) -> Self { - String::from(&e) - } + fn from(e: Encoding) -> Self { + String::from(&e) + } } impl From<&Encoding> for String { - fn from(e: &Encoding) -> Self { - format!("{e:?}") - } + fn from(e: &Encoding) -> Self { + format!("{e:?}") + } } impl FhePlaintextEncoding for Encoding {} diff --git a/crates/fhe/src/bfv/keys/evaluation_key.rs b/crates/fhe/src/bfv/keys/evaluation_key.rs index b59fe55..b92da5e 100644 --- a/crates/fhe/src/bfv/keys/evaluation_key.rs +++ b/crates/fhe/src/bfv/keys/evaluation_key.rs @@ -1,10 +1,10 @@ //! Leveled evaluation keys for the BFV encryption scheme. use crate::bfv::{ - keys::GaloisKey, - proto::bfv::{EvaluationKey as EvaluationKeyProto, GaloisKey as GaloisKeyProto}, - traits::TryConvertFrom, - BfvParameters, Ciphertext, SecretKey, + keys::GaloisKey, + proto::bfv::{EvaluationKey as EvaluationKeyProto, GaloisKey as GaloisKeyProto}, + traits::TryConvertFrom, + BfvParameters, Ciphertext, SecretKey, }; use crate::{Error, Result}; use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}; @@ -26,843 +26,846 @@ use zeroize::{Zeroize, ZeroizeOnDrop}; /// - inner sum #[derive(Debug, PartialEq, Eq)] pub struct EvaluationKey { - par: Arc, + par: Arc, - ciphertext_level: usize, - evaluation_key_level: usize, + ciphertext_level: usize, + evaluation_key_level: usize, - /// Map from Galois keys exponents to Galois keys - gk: HashMap, + /// Map from Galois keys exponents to Galois keys + gk: HashMap, - /// Map from rotation index to Galois key exponent - rot_to_gk_exponent: HashMap, + /// Map from rotation index to Galois key exponent + rot_to_gk_exponent: HashMap, - /// Monomials used in expansion - monomials: Vec, + /// Monomials used in expansion + monomials: Vec, } impl EvaluationKey { - /// Reports whether the evaluation key enables to compute an homomorphic - /// inner sums. - pub fn supports_inner_sum(&self) -> bool { - if self.evaluation_key_level == self.par.moduli().len() { - false - } else { - let mut ret = self.gk.contains_key(&(self.par.degree() * 2 - 1)); - let mut i = 1; - while i < self.par.degree() / 2 { - ret &= self - .gk - .contains_key(self.rot_to_gk_exponent.get(&i).unwrap()); - i *= 2 - } - ret - } - } + /// Reports whether the evaluation key enables to compute an homomorphic + /// inner sums. + pub fn supports_inner_sum(&self) -> bool { + if self.evaluation_key_level == self.par.moduli().len() { + false + } else { + let mut ret = self.gk.contains_key(&(self.par.degree() * 2 - 1)); + let mut i = 1; + while i < self.par.degree() / 2 { + ret &= self + .gk + .contains_key(self.rot_to_gk_exponent.get(&i).unwrap()); + i *= 2 + } + ret + } + } - /// Computes the homomorphic inner sum. - pub fn computes_inner_sum(&self, ct: &Ciphertext) -> Result { - if !self.supports_inner_sum() { - Err(Error::DefaultError( - "This key does not support the inner sum functionality".to_string(), - )) - } else { - let mut out = ct.clone(); + /// Computes the homomorphic inner sum. + pub fn computes_inner_sum(&self, ct: &Ciphertext) -> Result { + if !self.supports_inner_sum() { + Err(Error::DefaultError( + "This key does not support the inner sum functionality".to_string(), + )) + } else { + let mut out = ct.clone(); - let mut i = 1; - while i < ct.par.degree() / 2 { - let gk = self - .gk - .get(self.rot_to_gk_exponent.get(&i).unwrap()) - .unwrap(); - out += &gk.relinearize(&out)?; - i *= 2 - } + let mut i = 1; + while i < ct.par.degree() / 2 { + let gk = self + .gk + .get(self.rot_to_gk_exponent.get(&i).unwrap()) + .unwrap(); + out += &gk.relinearize(&out)?; + i *= 2 + } - let gk = self.gk.get(&(self.par.degree() * 2 - 1)).unwrap(); - out += &gk.relinearize(&out)?; + let gk = self.gk.get(&(self.par.degree() * 2 - 1)).unwrap(); + out += &gk.relinearize(&out)?; - Ok(out) - } - } + Ok(out) + } + } - /// Reports whether the evaluation key enables to rotate the rows of the - /// plaintext. - pub fn supports_row_rotation(&self) -> bool { - if self.evaluation_key_level == self.par.moduli().len() { - false - } else { - self.gk.contains_key(&(self.par.degree() * 2 - 1)) - } - } + /// Reports whether the evaluation key enables to rotate the rows of the + /// plaintext. + pub fn supports_row_rotation(&self) -> bool { + if self.evaluation_key_level == self.par.moduli().len() { + false + } else { + self.gk.contains_key(&(self.par.degree() * 2 - 1)) + } + } - /// Homomorphically rotate the rows of the plaintext - pub fn rotates_rows(&self, ct: &Ciphertext) -> Result { - if !self.supports_row_rotation() { - Err(Error::DefaultError( - "This key does not support the row rotation functionality".to_string(), - )) - } else { - let gk = self.gk.get(&(self.par.degree() * 2 - 1)).unwrap(); - gk.relinearize(ct) - } - } + /// Homomorphically rotate the rows of the plaintext + pub fn rotates_rows(&self, ct: &Ciphertext) -> Result { + if !self.supports_row_rotation() { + Err(Error::DefaultError( + "This key does not support the row rotation functionality".to_string(), + )) + } else { + let gk = self.gk.get(&(self.par.degree() * 2 - 1)).unwrap(); + gk.relinearize(ct) + } + } - /// Reports whether the evaluation key enables to rotate the columns of the - /// plaintext. - pub fn supports_column_rotation_by(&self, i: usize) -> bool { - if self.evaluation_key_level == self.par.moduli().len() { - false - } else if let Some(exp) = self.rot_to_gk_exponent.get(&i) { - self.gk.contains_key(exp) - } else { - false - } - } + /// Reports whether the evaluation key enables to rotate the columns of the + /// plaintext. + pub fn supports_column_rotation_by(&self, i: usize) -> bool { + if self.evaluation_key_level == self.par.moduli().len() { + false + } else if let Some(exp) = self.rot_to_gk_exponent.get(&i) { + self.gk.contains_key(exp) + } else { + false + } + } - /// Homomorphically rotate the columns of the plaintext - pub fn rotates_columns_by(&self, ct: &Ciphertext, i: usize) -> Result { - if !self.supports_column_rotation_by(i) { - Err(Error::DefaultError( - "This key does not support rotating the columns by this index".to_string(), - )) - } else { - let gk = self - .gk - .get(self.rot_to_gk_exponent.get(&i).unwrap()) - .unwrap(); - gk.relinearize(ct) - } - } + /// Homomorphically rotate the columns of the plaintext + pub fn rotates_columns_by(&self, ct: &Ciphertext, i: usize) -> Result { + if !self.supports_column_rotation_by(i) { + Err(Error::DefaultError( + "This key does not support rotating the columns by this index".to_string(), + )) + } else { + let gk = self + .gk + .get(self.rot_to_gk_exponent.get(&i).unwrap()) + .unwrap(); + gk.relinearize(ct) + } + } - /// Reports whether the evaluation key supports oblivious expansion. - pub fn supports_expansion(&self, level: usize) -> bool { - if level == 0 { - true - } else if self.evaluation_key_level == self.par.moduli().len() { - false - } else { - let mut ret = level < self.par.degree().leading_zeros() as usize; - for l in 0..level { - ret &= self.gk.contains_key(&((self.par.degree() >> l) + 1)); - } - ret - } - } + /// Reports whether the evaluation key supports oblivious expansion. + pub fn supports_expansion(&self, level: usize) -> bool { + if level == 0 { + true + } else if self.evaluation_key_level == self.par.moduli().len() { + false + } else { + let mut ret = level < self.par.degree().leading_zeros() as usize; + for l in 0..level { + ret &= self.gk.contains_key(&((self.par.degree() >> l) + 1)); + } + ret + } + } - /// Obliviously expands the ciphertext. Returns an error if this evaluation - /// does not support expansion to level = ceil(log2(size)), or if the - /// ciphertext does not have size 2. The output is a vector of `size` - /// ciphertexts. - pub fn expands(&self, ct: &Ciphertext, size: usize) -> Result> { - let level = ilog2(size.next_power_of_two() as u64); - if ct.c.len() != 2 { - Err(Error::DefaultError( - "The ciphertext is not of size 2".to_string(), - )) - } else if level == 0 { - Ok(vec![ct.clone()]) - } else if self.supports_expansion(level) { - let mut out = vec![Ciphertext::zero(&ct.par); 1 << level]; - out[0] = ct.clone(); + /// Obliviously expands the ciphertext. Returns an error if this evaluation + /// does not support expansion to level = ceil(log2(size)), or if the + /// ciphertext does not have size 2. The output is a vector of `size` + /// ciphertexts. + pub fn expands(&self, ct: &Ciphertext, size: usize) -> Result> { + let level = ilog2(size.next_power_of_two() as u64); + if ct.c.len() != 2 { + Err(Error::DefaultError( + "The ciphertext is not of size 2".to_string(), + )) + } else if level == 0 { + Ok(vec![ct.clone()]) + } else if self.supports_expansion(level) { + let mut out = vec![Ciphertext::zero(&ct.par); 1 << level]; + out[0] = ct.clone(); - // We use the Oblivious expansion algorithm of - // https://eprint.iacr.org/2019/1483.pdf - for l in 0..level { - let monomial = &self.monomials[l]; - let gk = self.gk.get(&((self.par.degree() >> l) + 1)).unwrap(); - for i in 0..(1 << l) { - let sub = gk.relinearize(&out[i])?; - if (1 << l) | i < size { - out[(1 << l) | i] = &out[i] - ⊂ - out[(1 << l) | i].c[0] *= monomial; - out[(1 << l) | i].c[1] *= monomial; - } - out[i] += ⊂ - } - } - out.truncate(size); - Ok(out) - } else { - Err(Error::DefaultError( - "This key does not support expansion at this level".to_string(), - )) - } - } + // We use the Oblivious expansion algorithm of + // https://eprint.iacr.org/2019/1483.pdf + for l in 0..level { + let monomial = &self.monomials[l]; + let gk = self.gk.get(&((self.par.degree() >> l) + 1)).unwrap(); + for i in 0..(1 << l) { + let sub = gk.relinearize(&out[i])?; + if (1 << l) | i < size { + out[(1 << l) | i] = &out[i] - ⊂ + out[(1 << l) | i].c[0] *= monomial; + out[(1 << l) | i].c[1] *= monomial; + } + out[i] += ⊂ + } + } + out.truncate(size); + Ok(out) + } else { + Err(Error::DefaultError( + "This key does not support expansion at this level".to_string(), + )) + } + } - fn construct_rot_to_gk_exponent(par: &Arc) -> HashMap { - let mut m = HashMap::new(); - let q = Modulus::new(2 * par.degree() as u64).unwrap(); - for i in 1..par.degree() / 2 { - let exp = q.pow(3, i as u64) as usize; - m.insert(i, exp); - } - m - } + fn construct_rot_to_gk_exponent(par: &Arc) -> HashMap { + let mut m = HashMap::new(); + let q = Modulus::new(2 * par.degree() as u64).unwrap(); + for i in 1..par.degree() / 2 { + let exp = q.pow(3, i as u64) as usize; + m.insert(i, exp); + } + m + } } impl FheParametrized for EvaluationKey { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl Serialize for EvaluationKey { - fn to_bytes(&self) -> Vec { - let ekp = EvaluationKeyProto::from(self); - ekp.write_to_bytes().unwrap() - } + fn to_bytes(&self) -> Vec { + let ekp = EvaluationKeyProto::from(self); + ekp.write_to_bytes().unwrap() + } } impl DeserializeParametrized for EvaluationKey { - type Error = Error; + type Error = Error; - fn from_bytes(bytes: &[u8], par: &Arc) -> Result { - let gkp = EvaluationKeyProto::parse_from_bytes(bytes); - if let Ok(gkp) = gkp { - EvaluationKey::try_convert_from(&gkp, par) - } else { - Err(Error::DefaultError("Invalid serialization".to_string())) - } - } + fn from_bytes(bytes: &[u8], par: &Arc) -> Result { + let gkp = EvaluationKeyProto::parse_from_bytes(bytes); + if let Ok(gkp) = gkp { + EvaluationKey::try_convert_from(&gkp, par) + } else { + Err(Error::DefaultError("Invalid serialization".to_string())) + } + } } /// Builder for a leveled evaluation key from the secret key. #[derive(Debug)] pub struct EvaluationKeyBuilder { - sk: SecretKey, - ciphertext_level: usize, - evaluation_key_level: usize, - inner_sum: bool, - row_rotation: bool, - expansion_level: usize, - column_rotation: HashSet, - rot_to_gk_exponent: HashMap, + sk: SecretKey, + ciphertext_level: usize, + evaluation_key_level: usize, + inner_sum: bool, + row_rotation: bool, + expansion_level: usize, + column_rotation: HashSet, + rot_to_gk_exponent: HashMap, } impl Zeroize for EvaluationKeyBuilder { - fn zeroize(&mut self) { - self.sk.zeroize() - } + fn zeroize(&mut self) { + self.sk.zeroize() + } } impl ZeroizeOnDrop for EvaluationKeyBuilder {} impl EvaluationKeyBuilder { - /// Creates a new builder from the [`SecretKey`]. - pub fn new(sk: &SecretKey) -> Result { - Ok(Self { - sk: sk.clone(), - ciphertext_level: 0, - evaluation_key_level: 0, - inner_sum: false, - row_rotation: false, - expansion_level: 0, - column_rotation: HashSet::new(), - rot_to_gk_exponent: EvaluationKey::construct_rot_to_gk_exponent(&sk.par), - }) - } + /// Creates a new builder from the [`SecretKey`]. + pub fn new(sk: &SecretKey) -> Result { + Ok(Self { + sk: sk.clone(), + ciphertext_level: 0, + evaluation_key_level: 0, + inner_sum: false, + row_rotation: false, + expansion_level: 0, + column_rotation: HashSet::new(), + rot_to_gk_exponent: EvaluationKey::construct_rot_to_gk_exponent(&sk.par), + }) + } - /// Creates a new builder from the [`SecretKey`], for operations on - /// ciphertexts at level `ciphertext_level` using keys at level - /// `evaluation_key_level`. This raises an error if the key level is larger - /// than the ciphertext level, or if the ciphertext level is larger than the - /// maximum level supported by these parameters. - pub fn new_leveled( - sk: &SecretKey, - ciphertext_level: usize, - evaluation_key_level: usize, - ) -> Result { - if ciphertext_level < evaluation_key_level || ciphertext_level > sk.par.max_level() { - return Err(Error::DefaultError("Unexpected levels".to_string())); - } + /// Creates a new builder from the [`SecretKey`], for operations on + /// ciphertexts at level `ciphertext_level` using keys at level + /// `evaluation_key_level`. This raises an error if the key level is larger + /// than the ciphertext level, or if the ciphertext level is larger than the + /// maximum level supported by these parameters. + pub fn new_leveled( + sk: &SecretKey, + ciphertext_level: usize, + evaluation_key_level: usize, + ) -> Result { + if ciphertext_level < evaluation_key_level || ciphertext_level > sk.par.max_level() { + return Err(Error::DefaultError("Unexpected levels".to_string())); + } - Ok(Self { - sk: sk.clone(), - ciphertext_level, - evaluation_key_level, - inner_sum: false, - row_rotation: false, - expansion_level: 0, - column_rotation: HashSet::new(), - rot_to_gk_exponent: EvaluationKey::construct_rot_to_gk_exponent(&sk.par), - }) - } + Ok(Self { + sk: sk.clone(), + ciphertext_level, + evaluation_key_level, + inner_sum: false, + row_rotation: false, + expansion_level: 0, + column_rotation: HashSet::new(), + rot_to_gk_exponent: EvaluationKey::construct_rot_to_gk_exponent(&sk.par), + }) + } - /// Allow expansion by this evaluation key. - #[allow(unused_must_use)] - pub fn enable_expansion(&mut self, level: usize) -> Result<&mut Self> { - if self - .sk - .par - .ctx_at_level(self.evaluation_key_level)? - .moduli() - .len() == 1 - { - Err(Error::DefaultError( - "Not enough moduli to enable expansion".to_string(), - )) - } else if level >= 64 - self.sk.par.degree().leading_zeros() as usize { - Err(Error::DefaultError("Invalid level 2".to_string())) - } else { - self.expansion_level = level; - Ok(self) - } - } + /// Allow expansion by this evaluation key. + #[allow(unused_must_use)] + pub fn enable_expansion(&mut self, level: usize) -> Result<&mut Self> { + if self + .sk + .par + .ctx_at_level(self.evaluation_key_level)? + .moduli() + .len() + == 1 + { + Err(Error::DefaultError( + "Not enough moduli to enable expansion".to_string(), + )) + } else if level >= 64 - self.sk.par.degree().leading_zeros() as usize { + Err(Error::DefaultError("Invalid level 2".to_string())) + } else { + self.expansion_level = level; + Ok(self) + } + } - /// Allow this evaluation key to compute homomorphic inner sums. - #[allow(unused_must_use)] - pub fn enable_inner_sum(&mut self) -> Result<&mut Self> { - if self - .sk - .par - .ctx_at_level(self.evaluation_key_level)? - .moduli() - .len() == 1 - { - Err(Error::DefaultError( - "Not enough moduli to enable relinearization".to_string(), - )) - } else { - self.inner_sum = true; - Ok(self) - } - } + /// Allow this evaluation key to compute homomorphic inner sums. + #[allow(unused_must_use)] + pub fn enable_inner_sum(&mut self) -> Result<&mut Self> { + if self + .sk + .par + .ctx_at_level(self.evaluation_key_level)? + .moduli() + .len() + == 1 + { + Err(Error::DefaultError( + "Not enough moduli to enable relinearization".to_string(), + )) + } else { + self.inner_sum = true; + Ok(self) + } + } - /// Allow this evaluation key to homomorphically rotate the plaintext rows. - #[allow(unused_must_use)] - pub fn enable_row_rotation(&mut self) -> Result<&mut Self> { - if self - .sk - .par - .ctx_at_level(self.evaluation_key_level)? - .moduli() - .len() == 1 - { - Err(Error::DefaultError( - "Not enough moduli to enable relinearization".to_string(), - )) - } else { - self.row_rotation = true; - Ok(self) - } - } + /// Allow this evaluation key to homomorphically rotate the plaintext rows. + #[allow(unused_must_use)] + pub fn enable_row_rotation(&mut self) -> Result<&mut Self> { + if self + .sk + .par + .ctx_at_level(self.evaluation_key_level)? + .moduli() + .len() + == 1 + { + Err(Error::DefaultError( + "Not enough moduli to enable relinearization".to_string(), + )) + } else { + self.row_rotation = true; + Ok(self) + } + } - /// Allow this evaluation key to homomorphically rotate the plaintext - /// columns. - #[allow(unused_must_use)] - pub fn enable_column_rotation(&mut self, i: usize) -> Result<&mut Self> { - if let Some(exp) = self.rot_to_gk_exponent.get(&i) { - self.column_rotation.insert(*exp); - Ok(self) - } else { - Err(Error::DefaultError("Invalid column index".to_string())) - } - } + /// Allow this evaluation key to homomorphically rotate the plaintext + /// columns. + #[allow(unused_must_use)] + pub fn enable_column_rotation(&mut self, i: usize) -> Result<&mut Self> { + if let Some(exp) = self.rot_to_gk_exponent.get(&i) { + self.column_rotation.insert(*exp); + Ok(self) + } else { + Err(Error::DefaultError("Invalid column index".to_string())) + } + } - /// Build an [`EvaluationKey`] with the specified attributes. - pub fn build(&mut self, rng: &mut R) -> Result { - let mut ek = EvaluationKey { - gk: HashMap::default(), - par: self.sk.par.clone(), - rot_to_gk_exponent: self.rot_to_gk_exponent.clone(), - monomials: Vec::with_capacity(ilog2(self.sk.par.degree() as u64)), - ciphertext_level: self.ciphertext_level, - evaluation_key_level: self.evaluation_key_level, - }; + /// Build an [`EvaluationKey`] with the specified attributes. + pub fn build(&mut self, rng: &mut R) -> Result { + let mut ek = EvaluationKey { + gk: HashMap::default(), + par: self.sk.par.clone(), + rot_to_gk_exponent: self.rot_to_gk_exponent.clone(), + monomials: Vec::with_capacity(ilog2(self.sk.par.degree() as u64)), + ciphertext_level: self.ciphertext_level, + evaluation_key_level: self.evaluation_key_level, + }; - let mut indices = self.column_rotation.clone(); + let mut indices = self.column_rotation.clone(); - if self.row_rotation { - indices.insert(self.sk.par.degree() * 2 - 1); - } + if self.row_rotation { + indices.insert(self.sk.par.degree() * 2 - 1); + } - if self.inner_sum { - // Add the required indices to the set of indices - indices.insert(self.sk.par.degree() * 2 - 1); - let mut i = 1; - while i < self.sk.par.degree() / 2 { - indices.insert(*ek.rot_to_gk_exponent.get(&i).unwrap()); - i *= 2 - } - } + if self.inner_sum { + // Add the required indices to the set of indices + indices.insert(self.sk.par.degree() * 2 - 1); + let mut i = 1; + while i < self.sk.par.degree() / 2 { + indices.insert(*ek.rot_to_gk_exponent.get(&i).unwrap()); + i *= 2 + } + } - for l in 0..self.expansion_level { - indices.insert((self.sk.par.degree() >> l) + 1); - } + for l in 0..self.expansion_level { + indices.insert((self.sk.par.degree() >> l) + 1); + } - let ciphertext_ctx = self.sk.par.ctx_at_level(self.ciphertext_level)?; - for l in 0..ilog2(self.sk.par.degree() as u64) { - let mut monomial = vec![0i64; self.sk.par.degree()]; - monomial[self.sk.par.degree() - (1 << l)] = -1; - let mut monomial = Poly::try_convert_from( - &monomial, - ciphertext_ctx, - true, - Representation::PowerBasis, - )?; - unsafe { monomial.allow_variable_time_computations() } - monomial.change_representation(Representation::NttShoup); - ek.monomials.push(monomial); - } + let ciphertext_ctx = self.sk.par.ctx_at_level(self.ciphertext_level)?; + for l in 0..ilog2(self.sk.par.degree() as u64) { + let mut monomial = vec![0i64; self.sk.par.degree()]; + monomial[self.sk.par.degree() - (1 << l)] = -1; + let mut monomial = Poly::try_convert_from( + &monomial, + ciphertext_ctx, + true, + Representation::PowerBasis, + )?; + unsafe { monomial.allow_variable_time_computations() } + monomial.change_representation(Representation::NttShoup); + ek.monomials.push(monomial); + } - for index in indices { - ek.gk.insert( - index, - GaloisKey::new( - &self.sk, - index, - self.ciphertext_level, - self.evaluation_key_level, - rng, - )?, - ); - } + for index in indices { + ek.gk.insert( + index, + GaloisKey::new( + &self.sk, + index, + self.ciphertext_level, + self.evaluation_key_level, + rng, + )?, + ); + } - Ok(ek) - } + Ok(ek) + } } impl From<&EvaluationKey> for EvaluationKeyProto { - fn from(ek: &EvaluationKey) -> Self { - let mut proto = EvaluationKeyProto::new(); - for (_, gk) in ek.gk.iter() { - proto.gk.push(GaloisKeyProto::from(gk)) - } - proto.ciphertext_level = ek.ciphertext_level as u32; - proto.evaluation_key_level = ek.evaluation_key_level as u32; - proto - } + fn from(ek: &EvaluationKey) -> Self { + let mut proto = EvaluationKeyProto::new(); + for (_, gk) in ek.gk.iter() { + proto.gk.push(GaloisKeyProto::from(gk)) + } + proto.ciphertext_level = ek.ciphertext_level as u32; + proto.evaluation_key_level = ek.evaluation_key_level as u32; + proto + } } impl TryConvertFrom<&EvaluationKeyProto> for EvaluationKey { - fn try_convert_from(value: &EvaluationKeyProto, par: &Arc) -> Result { - let mut gk = HashMap::new(); - for gkp in &value.gk { - let key = GaloisKey::try_convert_from(gkp, par)?; - if key.ksk.ciphertext_level != value.ciphertext_level as usize { - return Err(Error::DefaultError( - "Galois key has incorrect ciphertext level".to_string(), - )); - } - if key.ksk.ksk_level != value.evaluation_key_level as usize { - return Err(Error::DefaultError( - "Galois key has incorrect evaluation key level".to_string(), - )); - } - gk.insert(key.element.exponent, key); - } + fn try_convert_from(value: &EvaluationKeyProto, par: &Arc) -> Result { + let mut gk = HashMap::new(); + for gkp in &value.gk { + let key = GaloisKey::try_convert_from(gkp, par)?; + if key.ksk.ciphertext_level != value.ciphertext_level as usize { + return Err(Error::DefaultError( + "Galois key has incorrect ciphertext level".to_string(), + )); + } + if key.ksk.ksk_level != value.evaluation_key_level as usize { + return Err(Error::DefaultError( + "Galois key has incorrect evaluation key level".to_string(), + )); + } + gk.insert(key.element.exponent, key); + } - let ciphertext_ctx = par.ctx_at_level(value.ciphertext_level as usize)?; - let mut monomials = Vec::with_capacity(ilog2(par.degree() as u64)); - for l in 0..ilog2(par.degree() as u64) { - let mut monomial = vec![0i64; par.degree()]; - monomial[par.degree() - (1 << l)] = -1; - let mut monomial = Poly::try_convert_from( - &monomial, - ciphertext_ctx, - true, - Representation::PowerBasis, - )?; - unsafe { monomial.allow_variable_time_computations() } - monomial.change_representation(Representation::NttShoup); - monomials.push(monomial); - } + let ciphertext_ctx = par.ctx_at_level(value.ciphertext_level as usize)?; + let mut monomials = Vec::with_capacity(ilog2(par.degree() as u64)); + for l in 0..ilog2(par.degree() as u64) { + let mut monomial = vec![0i64; par.degree()]; + monomial[par.degree() - (1 << l)] = -1; + let mut monomial = Poly::try_convert_from( + &monomial, + ciphertext_ctx, + true, + Representation::PowerBasis, + )?; + unsafe { monomial.allow_variable_time_computations() } + monomial.change_representation(Representation::NttShoup); + monomials.push(monomial); + } - Ok(EvaluationKey { - gk, - par: par.clone(), - rot_to_gk_exponent: EvaluationKey::construct_rot_to_gk_exponent(par), - monomials, - ciphertext_level: value.ciphertext_level as usize, - evaluation_key_level: value.evaluation_key_level as usize, - }) - } + Ok(EvaluationKey { + gk, + par: par.clone(), + rot_to_gk_exponent: EvaluationKey::construct_rot_to_gk_exponent(par), + monomials, + ciphertext_level: value.ciphertext_level as usize, + evaluation_key_level: value.evaluation_key_level as usize, + }) + } } #[cfg(test)] mod tests { - use super::{EvaluationKey, EvaluationKeyBuilder}; - use crate::bfv::{ - proto::bfv::EvaluationKey as LeveledEvaluationKeyProto, traits::TryConvertFrom, - BfvParameters, Encoding, Plaintext, SecretKey, - }; - use fhe_traits::{ - DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize, - }; - use fhe_util::ilog2; - use itertools::izip; - use rand::thread_rng; - use std::{cmp::min, error::Error, sync::Arc}; + use super::{EvaluationKey, EvaluationKeyBuilder}; + use crate::bfv::{ + proto::bfv::EvaluationKey as LeveledEvaluationKeyProto, traits::TryConvertFrom, + BfvParameters, Encoding, Plaintext, SecretKey, + }; + use fhe_traits::{ + DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize, + }; + use fhe_util::ilog2; + use itertools::izip; + use rand::thread_rng; + use std::{cmp::min, error::Error, sync::Arc}; - #[test] - fn builder() -> Result<(), Box> { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(6, 8)); - let sk = SecretKey::random(¶ms, &mut rng); + #[test] + fn builder() -> Result<(), Box> { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(6, 8)); + let sk = SecretKey::random(¶ms, &mut rng); - let max_level = params.max_level(); - for ciphertext_level in 0..=max_level { - for evaluation_key_level in 0..=min(max_level - 1, ciphertext_level) { - let mut builder = - EvaluationKeyBuilder::new_leveled(&sk, ciphertext_level, evaluation_key_level)?; + let max_level = params.max_level(); + for ciphertext_level in 0..=max_level { + for evaluation_key_level in 0..=min(max_level - 1, ciphertext_level) { + let mut builder = + EvaluationKeyBuilder::new_leveled(&sk, ciphertext_level, evaluation_key_level)?; - assert!(!builder.build(&mut rng)?.supports_row_rotation()); - assert!(!builder.build(&mut rng)?.supports_column_rotation_by(0)); - assert!(!builder.build(&mut rng)?.supports_column_rotation_by(1)); - assert!(!builder.build(&mut rng)?.supports_inner_sum()); - assert!(!builder.build(&mut rng)?.supports_expansion(1)); - assert!(builder.build(&mut rng)?.supports_expansion(0)); - assert!(builder.enable_column_rotation(0).is_err()); - assert!(builder - .enable_expansion(64 - params.degree().leading_zeros() as usize) - .is_err()); + assert!(!builder.build(&mut rng)?.supports_row_rotation()); + assert!(!builder.build(&mut rng)?.supports_column_rotation_by(0)); + assert!(!builder.build(&mut rng)?.supports_column_rotation_by(1)); + assert!(!builder.build(&mut rng)?.supports_inner_sum()); + assert!(!builder.build(&mut rng)?.supports_expansion(1)); + assert!(builder.build(&mut rng)?.supports_expansion(0)); + assert!(builder.enable_column_rotation(0).is_err()); + assert!(builder + .enable_expansion(64 - params.degree().leading_zeros() as usize) + .is_err()); - builder.enable_column_rotation(1)?; - assert!(builder.build(&mut rng)?.supports_column_rotation_by(1)); - assert!(!builder.build(&mut rng)?.supports_row_rotation()); - assert!(!builder.build(&mut rng)?.supports_inner_sum()); - assert!(!builder.build(&mut rng)?.supports_expansion(1)); + builder.enable_column_rotation(1)?; + assert!(builder.build(&mut rng)?.supports_column_rotation_by(1)); + assert!(!builder.build(&mut rng)?.supports_row_rotation()); + assert!(!builder.build(&mut rng)?.supports_inner_sum()); + assert!(!builder.build(&mut rng)?.supports_expansion(1)); - builder.enable_row_rotation()?; - assert!(builder.build(&mut rng)?.supports_row_rotation()); - assert!(!builder.build(&mut rng)?.supports_inner_sum()); - assert!(!builder.build(&mut rng)?.supports_expansion(1)); + builder.enable_row_rotation()?; + assert!(builder.build(&mut rng)?.supports_row_rotation()); + assert!(!builder.build(&mut rng)?.supports_inner_sum()); + assert!(!builder.build(&mut rng)?.supports_expansion(1)); - builder.enable_inner_sum()?; - assert!(builder.build(&mut rng)?.supports_inner_sum()); - assert!(builder.build(&mut rng)?.supports_expansion(1)); - assert!(!builder - .build(&mut rng)? - .supports_expansion(64 - 1 - params.degree().leading_zeros() as usize)); + builder.enable_inner_sum()?; + assert!(builder.build(&mut rng)?.supports_inner_sum()); + assert!(builder.build(&mut rng)?.supports_expansion(1)); + assert!(!builder + .build(&mut rng)? + .supports_expansion(64 - 1 - params.degree().leading_zeros() as usize)); - builder.enable_expansion(64 - 1 - params.degree().leading_zeros() as usize)?; - assert!(builder - .build(&mut rng)? - .supports_expansion(64 - 1 - params.degree().leading_zeros() as usize)); + builder.enable_expansion(64 - 1 - params.degree().leading_zeros() as usize)?; + assert!(builder + .build(&mut rng)? + .supports_expansion(64 - 1 - params.degree().leading_zeros() as usize)); - assert!(builder.build(&mut rng).is_ok()); + assert!(builder.build(&mut rng).is_ok()); - // Enabling inner sum enables row rotation and a few column rotations :) - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_inner_sum()? - .build(&mut rng)?; - assert!(ek.supports_inner_sum()); - assert!(ek.supports_row_rotation()); - let mut i = 1; - while i < params.degree() / 2 { - assert!(ek.supports_column_rotation_by(i)); - i *= 2 - } - assert!(!ek.supports_column_rotation_by(params.degree() / 2 - 1)); - } - } + // Enabling inner sum enables row rotation and a few column rotations :) + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_inner_sum()? + .build(&mut rng)?; + assert!(ek.supports_inner_sum()); + assert!(ek.supports_row_rotation()); + let mut i = 1; + while i < params.degree() / 2 { + assert!(ek.supports_column_rotation_by(i)); + i *= 2 + } + assert!(!ek.supports_column_rotation_by(params.degree() / 2 - 1)); + } + } - let mut builder = - EvaluationKeyBuilder::new_leveled(&sk, params.max_level(), params.max_level())?; - let e = builder.enable_inner_sum(); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::DefaultError("Not enough moduli to enable relinearization".to_string()) - ); + let mut builder = + EvaluationKeyBuilder::new_leveled(&sk, params.max_level(), params.max_level())?; + let e = builder.enable_inner_sum(); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::DefaultError("Not enough moduli to enable relinearization".to_string()) + ); - let e = EvaluationKeyBuilder::new_leveled(&sk, 0, 1); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::DefaultError("Unexpected levels".to_string()) - ); + let e = EvaluationKeyBuilder::new_leveled(&sk, 0, 1); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::DefaultError("Unexpected levels".to_string()) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn inner_sum() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(5, 8)), - ] { - for _ in 0..25 { - for ciphertext_level in 0..=params.max_level() { - for evaluation_key_level in 0..=min(params.max_level() - 1, ciphertext_level) { - let sk = SecretKey::random(¶ms, &mut rng); - let ek = EvaluationKeyBuilder::new_leveled( - &sk, - ciphertext_level, - evaluation_key_level, - )? - .enable_inner_sum()? - .build(&mut rng)?; + #[test] + fn inner_sum() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(5, 8)), + ] { + for _ in 0..25 { + for ciphertext_level in 0..=params.max_level() { + for evaluation_key_level in 0..=min(params.max_level() - 1, ciphertext_level) { + let sk = SecretKey::random(¶ms, &mut rng); + let ek = EvaluationKeyBuilder::new_leveled( + &sk, + ciphertext_level, + evaluation_key_level, + )? + .enable_inner_sum()? + .build(&mut rng)?; - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let expected = params - .plaintext - .reduce_u128(v.iter().map(|vi| *vi as u128).sum()); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let expected = params + .plaintext + .reduce_u128(v.iter().map(|vi| *vi as u128).sum()); - let pt = Plaintext::try_encode( - &v, - Encoding::simd_at_level(ciphertext_level), - ¶ms, - )?; - let ct = sk.try_encrypt(&pt, &mut rng)?; + let pt = Plaintext::try_encode( + &v, + Encoding::simd_at_level(ciphertext_level), + ¶ms, + )?; + let ct = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = ek.computes_inner_sum(&ct)?; - let pt = sk.try_decrypt(&ct2)?; - assert_eq!( - Vec::::try_decode(&pt, Encoding::simd_at_level(ciphertext_level))?, - vec![expected; params.degree()] - ) - } - } - } - } - Ok(()) - } + let ct2 = ek.computes_inner_sum(&ct)?; + let pt = sk.try_decrypt(&ct2)?; + assert_eq!( + Vec::::try_decode(&pt, Encoding::simd_at_level(ciphertext_level))?, + vec![expected; params.degree()] + ) + } + } + } + } + Ok(()) + } - #[test] - fn row_rotation() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(5, 8)), - ] { - for _ in 0..50 { - for ciphertext_level in 0..=params.max_level() { - for evaluation_key_level in 0..=min(params.max_level() - 1, ciphertext_level) { - let sk = SecretKey::random(¶ms, &mut rng); - let ek = EvaluationKeyBuilder::new_leveled( - &sk, - ciphertext_level, - evaluation_key_level, - )? - .enable_row_rotation()? - .build(&mut rng)?; + #[test] + fn row_rotation() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(5, 8)), + ] { + for _ in 0..50 { + for ciphertext_level in 0..=params.max_level() { + for evaluation_key_level in 0..=min(params.max_level() - 1, ciphertext_level) { + let sk = SecretKey::random(¶ms, &mut rng); + let ek = EvaluationKeyBuilder::new_leveled( + &sk, + ciphertext_level, + evaluation_key_level, + )? + .enable_row_rotation()? + .build(&mut rng)?; - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let row_size = params.degree() >> 1; - let mut expected = vec![0u64; params.degree()]; - expected[..row_size].copy_from_slice(&v[row_size..]); - expected[row_size..].copy_from_slice(&v[..row_size]); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let row_size = params.degree() >> 1; + let mut expected = vec![0u64; params.degree()]; + expected[..row_size].copy_from_slice(&v[row_size..]); + expected[row_size..].copy_from_slice(&v[..row_size]); - let pt = Plaintext::try_encode( - &v, - Encoding::simd_at_level(ciphertext_level), - ¶ms, - )?; - let ct = sk.try_encrypt(&pt, &mut rng)?; + let pt = Plaintext::try_encode( + &v, + Encoding::simd_at_level(ciphertext_level), + ¶ms, + )?; + let ct = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = ek.rotates_rows(&ct)?; - let pt = sk.try_decrypt(&ct2)?; - assert_eq!( - Vec::::try_decode(&pt, Encoding::simd_at_level(ciphertext_level))?, - expected - ) - } - } - } - } - Ok(()) - } + let ct2 = ek.rotates_rows(&ct)?; + let pt = sk.try_decrypt(&ct2)?; + assert_eq!( + Vec::::try_decode(&pt, Encoding::simd_at_level(ciphertext_level))?, + expected + ) + } + } + } + } + Ok(()) + } - #[test] - fn column_rotation() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(5, 8)), - ] { - let row_size = params.degree() >> 1; - for _ in 0..50 { - for i in 1..row_size { - for ciphertext_level in 0..=params.max_level() { - for evaluation_key_level in - 0..=min(params.max_level() - 1, ciphertext_level) - { - let sk = SecretKey::random(¶ms, &mut rng); - let ek = EvaluationKeyBuilder::new_leveled( - &sk, - ciphertext_level, - evaluation_key_level, - )? - .enable_column_rotation(i)? - .build(&mut rng)?; + #[test] + fn column_rotation() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(5, 8)), + ] { + let row_size = params.degree() >> 1; + for _ in 0..50 { + for i in 1..row_size { + for ciphertext_level in 0..=params.max_level() { + for evaluation_key_level in + 0..=min(params.max_level() - 1, ciphertext_level) + { + let sk = SecretKey::random(¶ms, &mut rng); + let ek = EvaluationKeyBuilder::new_leveled( + &sk, + ciphertext_level, + evaluation_key_level, + )? + .enable_column_rotation(i)? + .build(&mut rng)?; - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let row_size = params.degree() >> 1; - let mut expected = vec![0u64; params.degree()]; - expected[..row_size - i].copy_from_slice(&v[i..row_size]); - expected[row_size - i..row_size].copy_from_slice(&v[..i]); - expected[row_size..2 * row_size - i] - .copy_from_slice(&v[row_size + i..]); - expected[2 * row_size - i..] - .copy_from_slice(&v[row_size..row_size + i]); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let row_size = params.degree() >> 1; + let mut expected = vec![0u64; params.degree()]; + expected[..row_size - i].copy_from_slice(&v[i..row_size]); + expected[row_size - i..row_size].copy_from_slice(&v[..i]); + expected[row_size..2 * row_size - i] + .copy_from_slice(&v[row_size + i..]); + expected[2 * row_size - i..] + .copy_from_slice(&v[row_size..row_size + i]); - let pt = Plaintext::try_encode( - &v, - Encoding::simd_at_level(ciphertext_level), - ¶ms, - )?; - let ct = sk.try_encrypt(&pt, &mut rng)?; + let pt = Plaintext::try_encode( + &v, + Encoding::simd_at_level(ciphertext_level), + ¶ms, + )?; + let ct = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = ek.rotates_columns_by(&ct, i)?; - let pt = sk.try_decrypt(&ct2)?; - assert_eq!( - Vec::::try_decode( - &pt, - Encoding::simd_at_level(ciphertext_level) - )?, - expected - ) - } - } - } - } - } - Ok(()) - } + let ct2 = ek.rotates_columns_by(&ct, i)?; + let pt = sk.try_decrypt(&ct2)?; + assert_eq!( + Vec::::try_decode( + &pt, + Encoding::simd_at_level(ciphertext_level) + )?, + expected + ) + } + } + } + } + } + Ok(()) + } - #[test] - fn expansion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(5, 8)), - ] { - let log_degree = 64 - 1 - params.degree().leading_zeros(); - for _ in 0..15 { - for i in 1..1 + log_degree as usize { - for ciphertext_level in 0..=params.max_level() { - for evaluation_key_level in - 0..=min(params.max_level() - 1, ciphertext_level) - { - let sk = SecretKey::random(¶ms, &mut rng); - let ek = EvaluationKeyBuilder::new_leveled( - &sk, - ciphertext_level, - evaluation_key_level, - )? - .enable_expansion(i)? - .build(&mut rng)?; + #[test] + fn expansion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(5, 8)), + ] { + let log_degree = 64 - 1 - params.degree().leading_zeros(); + for _ in 0..15 { + for i in 1..1 + log_degree as usize { + for ciphertext_level in 0..=params.max_level() { + for evaluation_key_level in + 0..=min(params.max_level() - 1, ciphertext_level) + { + let sk = SecretKey::random(¶ms, &mut rng); + let ek = EvaluationKeyBuilder::new_leveled( + &sk, + ciphertext_level, + evaluation_key_level, + )? + .enable_expansion(i)? + .build(&mut rng)?; - assert!(ek.supports_expansion(i)); - assert!(!ek.supports_expansion(i + 1)); - let v = params.plaintext.random_vec(1 << i, &mut rng); - let pt = Plaintext::try_encode( - &v, - Encoding::poly_at_level(ciphertext_level), - ¶ms, - )?; - let ct = sk.try_encrypt(&pt, &mut rng)?; + assert!(ek.supports_expansion(i)); + assert!(!ek.supports_expansion(i + 1)); + let v = params.plaintext.random_vec(1 << i, &mut rng); + let pt = Plaintext::try_encode( + &v, + Encoding::poly_at_level(ciphertext_level), + ¶ms, + )?; + let ct = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = ek.expands(&ct, 1 << i)?; - assert_eq!(ct2.len(), 1 << i); - for (vi, ct2i) in izip!(&v, &ct2) { - let mut expected = vec![0u64; params.degree()]; - expected[0] = params.plaintext.mul(*vi, (1 << i) as u64); - let pt = sk.try_decrypt(ct2i)?; - assert_eq!( - expected, - Vec::::try_decode( - &pt, - Encoding::poly_at_level(ciphertext_level) - )? - ); - println!("Noise: {:?}", unsafe { sk.measure_noise(ct2i) }) - } - } - } - } - } - } - Ok(()) - } + let ct2 = ek.expands(&ct, 1 << i)?; + assert_eq!(ct2.len(), 1 << i); + for (vi, ct2i) in izip!(&v, &ct2) { + let mut expected = vec![0u64; params.degree()]; + expected[0] = params.plaintext.mul(*vi, (1 << i) as u64); + let pt = sk.try_decrypt(ct2i)?; + assert_eq!( + expected, + Vec::::try_decode( + &pt, + Encoding::poly_at_level(ciphertext_level) + )? + ); + println!("Noise: {:?}", unsafe { sk.measure_noise(ct2i) }) + } + } + } + } + } + } + Ok(()) + } - #[test] - fn proto_conversion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(5, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); + #[test] + fn proto_conversion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(5, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)?.build(&mut rng)?; + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)?.build(&mut rng)?; - let proto = LeveledEvaluationKeyProto::from(&ek); - assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); + let proto = LeveledEvaluationKeyProto::from(&ek); + assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); - if params.moduli.len() > 1 { - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_row_rotation()? - .build(&mut rng)?; + if params.moduli.len() > 1 { + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_row_rotation()? + .build(&mut rng)?; - let proto = LeveledEvaluationKeyProto::from(&ek); - assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); + let proto = LeveledEvaluationKeyProto::from(&ek); + assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_inner_sum()? - .build(&mut rng)?; - let proto = LeveledEvaluationKeyProto::from(&ek); - assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_inner_sum()? + .build(&mut rng)?; + let proto = LeveledEvaluationKeyProto::from(&ek); + assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_expansion(ilog2(params.degree() as u64))? - .build(&mut rng)?; - let proto = LeveledEvaluationKeyProto::from(&ek); - assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_expansion(ilog2(params.degree() as u64))? + .build(&mut rng)?; + let proto = LeveledEvaluationKeyProto::from(&ek); + assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_inner_sum()? - .enable_expansion(ilog2(params.degree() as u64))? - .build(&mut rng)?; - let proto = LeveledEvaluationKeyProto::from(&ek); - assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); - } - } - Ok(()) - } + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_inner_sum()? + .enable_expansion(ilog2(params.degree() as u64))? + .build(&mut rng)?; + let proto = LeveledEvaluationKeyProto::from(&ek); + assert_eq!(ek, EvaluationKey::try_convert_from(&proto, ¶ms)?); + } + } + Ok(()) + } - #[test] - fn serialize() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); + #[test] + fn serialize() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)?.build(&mut rng)?; - let bytes = ek.to_bytes(); - assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)?.build(&mut rng)?; + let bytes = ek.to_bytes(); + assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); - if params.moduli.len() > 1 { - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_row_rotation()? - .build(&mut rng)?; - let bytes = ek.to_bytes(); - assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); + if params.moduli.len() > 1 { + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_row_rotation()? + .build(&mut rng)?; + let bytes = ek.to_bytes(); + assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_inner_sum()? - .build(&mut rng)?; - let bytes = ek.to_bytes(); - assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_inner_sum()? + .build(&mut rng)?; + let bytes = ek.to_bytes(); + assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_expansion(ilog2(params.degree() as u64))? - .build(&mut rng)?; - let bytes = ek.to_bytes(); - assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_expansion(ilog2(params.degree() as u64))? + .build(&mut rng)?; + let bytes = ek.to_bytes(); + assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); - let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? - .enable_inner_sum()? - .enable_expansion(ilog2(params.degree() as u64))? - .build(&mut rng)?; - let bytes = ek.to_bytes(); - assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); - } - } - Ok(()) - } + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_inner_sum()? + .enable_expansion(ilog2(params.degree() as u64))? + .build(&mut rng)?; + let bytes = ek.to_bytes(); + assert_eq!(ek, EvaluationKey::from_bytes(&bytes, ¶ms)?); + } + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/keys/galois_key.rs b/crates/fhe/src/bfv/keys/galois_key.rs index 51fdfa9..4931f74 100644 --- a/crates/fhe/src/bfv/keys/galois_key.rs +++ b/crates/fhe/src/bfv/keys/galois_key.rs @@ -2,14 +2,14 @@ use super::key_switching_key::KeySwitchingKey; use crate::bfv::{ - proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto}, - traits::TryConvertFrom, - BfvParameters, Ciphertext, SecretKey, + proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto}, + traits::TryConvertFrom, + BfvParameters, Ciphertext, SecretKey, }; use crate::{Error, Result}; use fhe_math::rq::{ - switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation, - SubstitutionExponent, + switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation, + SubstitutionExponent, }; use protobuf::MessageField; use rand::{CryptoRng, RngCore}; @@ -21,181 +21,181 @@ use zeroize::Zeroizing; /// which switch from `s(x^i)` to `s(x)` where `s(x)` is the secret key. #[derive(Debug, PartialEq, Eq)] pub struct GaloisKey { - pub(crate) element: SubstitutionExponent, - pub(crate) ksk: KeySwitchingKey, + pub(crate) element: SubstitutionExponent, + pub(crate) ksk: KeySwitchingKey, } impl GaloisKey { - /// Generate a [`GaloisKey`] from a [`SecretKey`]. - pub fn new( - sk: &SecretKey, - exponent: usize, - ciphertext_level: usize, - galois_key_level: usize, - rng: &mut R, - ) -> Result { - let ctx_galois_key = sk.par.ctx_at_level(galois_key_level)?; - let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; + /// Generate a [`GaloisKey`] from a [`SecretKey`]. + pub fn new( + sk: &SecretKey, + exponent: usize, + ciphertext_level: usize, + galois_key_level: usize, + rng: &mut R, + ) -> Result { + let ctx_galois_key = sk.par.ctx_at_level(galois_key_level)?; + let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; - let ciphertext_exponent = - SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?; + let ciphertext_exponent = + SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?; - let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?; - let s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx_ciphertext, - false, - Representation::PowerBasis, - )?); - let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?); - let mut s_sub_switched_up = Zeroizing::new(s_sub.mod_switch_to(&switcher_up)?); - s_sub_switched_up.change_representation(Representation::PowerBasis); + let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?; + let s = Zeroizing::new(Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx_ciphertext, + false, + Representation::PowerBasis, + )?); + let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?); + let mut s_sub_switched_up = Zeroizing::new(s_sub.mod_switch_to(&switcher_up)?); + s_sub_switched_up.change_representation(Representation::PowerBasis); - let ksk = KeySwitchingKey::new( - sk, - &s_sub_switched_up, - ciphertext_level, - galois_key_level, - rng, - )?; + let ksk = KeySwitchingKey::new( + sk, + &s_sub_switched_up, + ciphertext_level, + galois_key_level, + rng, + )?; - Ok(Self { - element: ciphertext_exponent, - ksk, - }) - } + Ok(Self { + element: ciphertext_exponent, + ksk, + }) + } - /// Relinearize a [`Ciphertext`] using the [`GaloisKey`] - pub fn relinearize(&self, ct: &Ciphertext) -> Result { - // assert_eq!(ct.par, self.ksk.par); - assert_eq!(ct.c.len(), 2); + /// Relinearize a [`Ciphertext`] using the [`GaloisKey`] + pub fn relinearize(&self, ct: &Ciphertext) -> Result { + // assert_eq!(ct.par, self.ksk.par); + assert_eq!(ct.c.len(), 2); - let mut c2 = ct.c[1].substitute(&self.element)?; - c2.change_representation(Representation::PowerBasis); - let (mut c0, mut c1) = self.ksk.key_switch(&c2)?; + let mut c2 = ct.c[1].substitute(&self.element)?; + c2.change_representation(Representation::PowerBasis); + let (mut c0, mut c1) = self.ksk.key_switch(&c2)?; - if c0.ctx() != ct.c[0].ctx() { - c0.change_representation(Representation::PowerBasis); - c1.change_representation(Representation::PowerBasis); - c0.mod_switch_down_to(ct.c[0].ctx())?; - c1.mod_switch_down_to(ct.c[1].ctx())?; - c0.change_representation(Representation::Ntt); - c1.change_representation(Representation::Ntt); - } + if c0.ctx() != ct.c[0].ctx() { + c0.change_representation(Representation::PowerBasis); + c1.change_representation(Representation::PowerBasis); + c0.mod_switch_down_to(ct.c[0].ctx())?; + c1.mod_switch_down_to(ct.c[1].ctx())?; + c0.change_representation(Representation::Ntt); + c1.change_representation(Representation::Ntt); + } - c0 += &ct.c[0].substitute(&self.element)?; + c0 += &ct.c[0].substitute(&self.element)?; - Ok(Ciphertext { - par: ct.par.clone(), - seed: None, - c: vec![c0, c1], - level: self.ksk.ciphertext_level, - }) - } + Ok(Ciphertext { + par: ct.par.clone(), + seed: None, + c: vec![c0, c1], + level: self.ksk.ciphertext_level, + }) + } } impl From<&GaloisKey> for GaloisKeyProto { - fn from(value: &GaloisKey) -> Self { - let mut gk = GaloisKeyProto::new(); - gk.exponent = value.element.exponent as u32; - gk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk)); - gk - } + fn from(value: &GaloisKey) -> Self { + let mut gk = GaloisKeyProto::new(); + gk.exponent = value.element.exponent as u32; + gk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk)); + gk + } } impl TryConvertFrom<&GaloisKeyProto> for GaloisKey { - fn try_convert_from(value: &GaloisKeyProto, par: &Arc) -> Result { - if par.moduli.len() == 1 { - Err(Error::DefaultError( - "Invalid parameters for a relinearization key".to_string(), - )) - } else if value.ksk.is_some() { - let ksk = KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?; + fn try_convert_from(value: &GaloisKeyProto, par: &Arc) -> Result { + if par.moduli.len() == 1 { + Err(Error::DefaultError( + "Invalid parameters for a relinearization key".to_string(), + )) + } else if value.ksk.is_some() { + let ksk = KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?; - let ctx = par.ctx_at_level(ksk.ciphertext_level)?; - let element = SubstitutionExponent::new(ctx, value.exponent as usize) - .map_err(Error::MathError)?; + let ctx = par.ctx_at_level(ksk.ciphertext_level)?; + let element = SubstitutionExponent::new(ctx, value.exponent as usize) + .map_err(Error::MathError)?; - Ok(GaloisKey { element, ksk }) - } else { - Err(Error::DefaultError("Invalid serialization".to_string())) - } - } + Ok(GaloisKey { element, ksk }) + } else { + Err(Error::DefaultError("Invalid serialization".to_string())) + } + } } #[cfg(test)] mod tests { - use super::GaloisKey; - use crate::bfv::{ - proto::bfv::GaloisKey as GaloisKeyProto, traits::TryConvertFrom, BfvParameters, Encoding, - Plaintext, SecretKey, - }; - use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use super::GaloisKey; + use crate::bfv::{ + proto::bfv::GaloisKey as GaloisKeyProto, traits::TryConvertFrom, BfvParameters, Encoding, + Plaintext, SecretKey, + }; + use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn relinearization() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(3, 8)), - ] { - for _ in 0..30 { - let sk = SecretKey::random(¶ms, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let row_size = params.degree() >> 1; + #[test] + fn relinearization() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(3, 8)), + ] { + for _ in 0..30 { + let sk = SecretKey::random(¶ms, &mut rng); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let row_size = params.degree() >> 1; - let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; - let ct = sk.try_encrypt(&pt, &mut rng)?; + let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?; + let ct = sk.try_encrypt(&pt, &mut rng)?; - for i in 1..2 * params.degree() { - if i & 1 == 0 { - assert!(GaloisKey::new(&sk, i, 0, 0, &mut rng).is_err()) - } else { - let gk = GaloisKey::new(&sk, i, 0, 0, &mut rng)?; - let ct2 = gk.relinearize(&ct)?; - println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? }); + for i in 1..2 * params.degree() { + if i & 1 == 0 { + assert!(GaloisKey::new(&sk, i, 0, 0, &mut rng).is_err()) + } else { + let gk = GaloisKey::new(&sk, i, 0, 0, &mut rng)?; + let ct2 = gk.relinearize(&ct)?; + println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? }); - if i == 3 { - let pt = sk.try_decrypt(&ct2)?; + if i == 3 { + let pt = sk.try_decrypt(&ct2)?; - // The expected result is rotated one on the left - let mut expected = vec![0u64; params.degree()]; - expected[..row_size - 1].copy_from_slice(&v[1..row_size]); - expected[row_size - 1] = v[0]; - expected[row_size..2 * row_size - 1] - .copy_from_slice(&v[row_size + 1..]); - expected[2 * row_size - 1] = v[row_size]; - assert_eq!(&Vec::::try_decode(&pt, Encoding::simd())?, &expected) - } else if i == params.degree() * 2 - 1 { - let pt = sk.try_decrypt(&ct2)?; + // The expected result is rotated one on the left + let mut expected = vec![0u64; params.degree()]; + expected[..row_size - 1].copy_from_slice(&v[1..row_size]); + expected[row_size - 1] = v[0]; + expected[row_size..2 * row_size - 1] + .copy_from_slice(&v[row_size + 1..]); + expected[2 * row_size - 1] = v[row_size]; + assert_eq!(&Vec::::try_decode(&pt, Encoding::simd())?, &expected) + } else if i == params.degree() * 2 - 1 { + let pt = sk.try_decrypt(&ct2)?; - // The expected result has its rows swapped - let mut expected = vec![0u64; params.degree()]; - expected[..row_size].copy_from_slice(&v[row_size..]); - expected[row_size..].copy_from_slice(&v[..row_size]); - assert_eq!(&Vec::::try_decode(&pt, Encoding::simd())?, &expected) - } - } - } - } - } - Ok(()) - } + // The expected result has its rows swapped + let mut expected = vec![0u64; params.degree()]; + expected[..row_size].copy_from_slice(&v[row_size..]); + expected[row_size..].copy_from_slice(&v[..row_size]); + assert_eq!(&Vec::::try_decode(&pt, Encoding::simd())?, &expected) + } + } + } + } + } + Ok(()) + } - #[test] - fn proto_conversion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(4, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?; - let proto = GaloisKeyProto::from(&gk); - assert_eq!(gk, GaloisKey::try_convert_from(&proto, ¶ms)?); - } - Ok(()) - } + #[test] + fn proto_conversion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(4, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?; + let proto = GaloisKeyProto::from(&gk); + assert_eq!(gk, GaloisKey::try_convert_from(&proto, ¶ms)?); + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/keys/key_switching_key.rs b/crates/fhe/src/bfv/keys/key_switching_key.rs index 8e2aaf0..d57ae76 100644 --- a/crates/fhe/src/bfv/keys/key_switching_key.rs +++ b/crates/fhe/src/bfv/keys/key_switching_key.rs @@ -1,15 +1,15 @@ //! Key-switching keys for the BFV encryption scheme use crate::bfv::{ - proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, - traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey, + proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, + traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey, }; use crate::{Error, Result}; use fhe_math::rq::traits::TryConvertFrom; use fhe_math::rq::Context; use fhe_math::{ - rns::RnsContext, - rq::{Poly, Representation}, + rns::RnsContext, + rq::{Poly, Representation}, }; use fhe_traits::{DeserializeWithContext, Serialize}; use itertools::izip; @@ -21,335 +21,335 @@ use zeroize::Zeroizing; /// Key switching key for the BFV encryption scheme. #[derive(Debug, PartialEq, Eq, Clone)] pub struct KeySwitchingKey { - /// The parameters of the underlying BFV encryption scheme. - pub(crate) par: Arc, + /// The parameters of the underlying BFV encryption scheme. + pub(crate) par: Arc, - /// The (optional) seed that generated the polynomials c1. - pub(crate) seed: Option<::Seed>, + /// The (optional) seed that generated the polynomials c1. + pub(crate) seed: Option<::Seed>, - /// The key switching elements c0. - pub(crate) c0: Box<[Poly]>, + /// The key switching elements c0. + pub(crate) c0: Box<[Poly]>, - /// The key switching elements c1. - pub(crate) c1: Box<[Poly]>, + /// The key switching elements c1. + pub(crate) c1: Box<[Poly]>, - /// The level and context of the polynomials that will be key switched. - pub(crate) ciphertext_level: usize, - pub(crate) ctx_ciphertext: Arc, + /// The level and context of the polynomials that will be key switched. + pub(crate) ciphertext_level: usize, + pub(crate) ctx_ciphertext: Arc, - /// The level and context of the key switching key. - pub(crate) ksk_level: usize, - pub(crate) ctx_ksk: Arc, + /// The level and context of the key switching key. + pub(crate) ksk_level: usize, + pub(crate) ctx_ksk: Arc, } impl KeySwitchingKey { - /// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial - /// `from`. - pub fn new( - sk: &SecretKey, - from: &Poly, - ciphertext_level: usize, - ksk_level: usize, - rng: &mut R, - ) -> Result { - let ctx_ksk = sk.par.ctx_at_level(ksk_level)?; - let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; + /// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial + /// `from`. + pub fn new( + sk: &SecretKey, + from: &Poly, + ciphertext_level: usize, + ksk_level: usize, + rng: &mut R, + ) -> Result { + let ctx_ksk = sk.par.ctx_at_level(ksk_level)?; + let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; - if ctx_ksk.moduli().len() == 1 { - return Err(Error::DefaultError( - "These parameters do not support key switching".to_string(), - )); - } + if ctx_ksk.moduli().len() == 1 { + return Err(Error::DefaultError( + "These parameters do not support key switching".to_string(), + )); + } - if from.ctx() != ctx_ksk { - return Err(Error::DefaultError( - "Incorrect context for polynomial from".to_string(), - )); - } + if from.ctx() != ctx_ksk { + return Err(Error::DefaultError( + "Incorrect context for polynomial from".to_string(), + )); + } - let mut seed = ::Seed::default(); - rng.fill(&mut seed); - let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len()); - let c0 = Self::generate_c0(sk, from, &c1, rng)?; + let mut seed = ::Seed::default(); + rng.fill(&mut seed); + let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len()); + let c0 = Self::generate_c0(sk, from, &c1, rng)?; - Ok(Self { - par: sk.par.clone(), - seed: Some(seed), - c0: c0.into_boxed_slice(), - c1: c1.into_boxed_slice(), - ciphertext_level, - ctx_ciphertext: ctx_ciphertext.clone(), - ksk_level, - ctx_ksk: ctx_ksk.clone(), - }) - } + Ok(Self { + par: sk.par.clone(), + seed: Some(seed), + c0: c0.into_boxed_slice(), + c1: c1.into_boxed_slice(), + ciphertext_level, + ctx_ciphertext: ctx_ciphertext.clone(), + ksk_level, + ctx_ksk: ctx_ksk.clone(), + }) + } - /// Generate the c1's from the seed - fn generate_c1( - ctx: &Arc, - seed: ::Seed, - size: usize, - ) -> Vec { - let mut c1 = Vec::with_capacity(size); - let mut rng = ChaCha8Rng::from_seed(seed); - (0..size).for_each(|_| { - let mut seed_i = ::Seed::default(); - rng.fill(&mut seed_i); - let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i); - unsafe { a.allow_variable_time_computations() } - c1.push(a); - }); - c1 - } + /// Generate the c1's from the seed + fn generate_c1( + ctx: &Arc, + seed: ::Seed, + size: usize, + ) -> Vec { + let mut c1 = Vec::with_capacity(size); + let mut rng = ChaCha8Rng::from_seed(seed); + (0..size).for_each(|_| { + let mut seed_i = ::Seed::default(); + rng.fill(&mut seed_i); + let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i); + unsafe { a.allow_variable_time_computations() } + c1.push(a); + }); + c1 + } - /// Generate the c0's from the c1's and the secret key - fn generate_c0( - sk: &SecretKey, - from: &Poly, - c1: &[Poly], - rng: &mut R, - ) -> Result> { - if c1.is_empty() { - return Err(Error::DefaultError("Empty number of c1's".to_string())); - } - if from.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError( - "Unexpected representation for from".to_string(), - )); - } + /// Generate the c0's from the c1's and the secret key + fn generate_c0( + sk: &SecretKey, + from: &Poly, + c1: &[Poly], + rng: &mut R, + ) -> Result> { + if c1.is_empty() { + return Err(Error::DefaultError("Empty number of c1's".to_string())); + } + if from.representation() != &Representation::PowerBasis { + return Err(Error::DefaultError( + "Unexpected representation for from".to_string(), + )); + } - let size = c1.len(); + let size = c1.len(); - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - c1[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + let mut s = Zeroizing::new(Poly::try_convert_from( + sk.coeffs.as_ref(), + c1[0].ctx(), + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); - let rns = RnsContext::new(&sk.par.moduli[..size])?; - let c0 = c1 - .iter() - .enumerate() - .map(|(i, c1i)| { - let mut a_s = Zeroizing::new(c1i.clone()); - a_s.disallow_variable_time_computations(); - a_s.change_representation(Representation::Ntt); - *a_s.as_mut() *= s.as_ref(); - a_s.change_representation(Representation::PowerBasis); + let rns = RnsContext::new(&sk.par.moduli[..size])?; + let c0 = c1 + .iter() + .enumerate() + .map(|(i, c1i)| { + let mut a_s = Zeroizing::new(c1i.clone()); + a_s.disallow_variable_time_computations(); + a_s.change_representation(Representation::Ntt); + *a_s.as_mut() *= s.as_ref(); + a_s.change_representation(Representation::PowerBasis); - let mut b = - Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; - b -= &a_s; + let mut b = + Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?; + b -= &a_s; - let gi = rns.get_garner(i).unwrap(); - let g_i_from = Zeroizing::new(gi * from); - b += &g_i_from; + let gi = rns.get_garner(i).unwrap(); + let g_i_from = Zeroizing::new(gi * from); + b += &g_i_from; - // It is now safe to enable variable time computations. - unsafe { b.allow_variable_time_computations() } - b.change_representation(Representation::NttShoup); - Ok(b) - }) - .collect::>>()?; + // It is now safe to enable variable time computations. + unsafe { b.allow_variable_time_computations() } + b.change_representation(Representation::NttShoup); + Ok(b) + }) + .collect::>>()?; - Ok(c0) - } + Ok(c0) + } - /// Key switch a polynomial. - pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> { - if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { - return Err(Error::DefaultError( - "The input polynomial does not have the correct context.".to_string(), - )); - } - if p.representation() != &Representation::PowerBasis { - return Err(Error::DefaultError("Incorrect representation".to_string())); - } + /// Key switch a polynomial. + pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> { + if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() { + return Err(Error::DefaultError( + "The input polynomial does not have the correct context.".to_string(), + )); + } + if p.representation() != &Representation::PowerBasis { + return Err(Error::DefaultError("Incorrect representation".to_string())); + } - let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); - for (c2_i_coefficients, c0_i, c1_i) in izip!( - p.coefficients().outer_iter(), - self.c0.iter(), - self.c1.iter() - ) { - let mut c2_i = unsafe { - Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( - c2_i_coefficients.as_slice().unwrap(), - &self.ctx_ksk, - ) - }; - c0 += &(&c2_i * c0_i); - c2_i *= c1_i; - c1 += &c2_i; - } - Ok((c0, c1)) - } + let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt); + for (c2_i_coefficients, c0_i, c1_i) in izip!( + p.coefficients().outer_iter(), + self.c0.iter(), + self.c1.iter() + ) { + let mut c2_i = unsafe { + Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( + c2_i_coefficients.as_slice().unwrap(), + &self.ctx_ksk, + ) + }; + c0 += &(&c2_i * c0_i); + c2_i *= c1_i; + c1 += &c2_i; + } + Ok((c0, c1)) + } } impl From<&KeySwitchingKey> for KeySwitchingKeyProto { - fn from(value: &KeySwitchingKey) -> Self { - let mut ksk = KeySwitchingKeyProto::new(); - if let Some(seed) = value.seed.as_ref() { - ksk.seed = seed.to_vec(); - } else { - ksk.c1.reserve_exact(value.c1.len()); - for c1 in value.c1.iter() { - ksk.c1.push(c1.to_bytes()) - } - } - ksk.c0.reserve_exact(value.c0.len()); - for c0 in value.c0.iter() { - ksk.c0.push(c0.to_bytes()) - } - ksk.ciphertext_level = value.ciphertext_level as u32; - ksk.ksk_level = value.ksk_level as u32; - ksk - } + fn from(value: &KeySwitchingKey) -> Self { + let mut ksk = KeySwitchingKeyProto::new(); + if let Some(seed) = value.seed.as_ref() { + ksk.seed = seed.to_vec(); + } else { + ksk.c1.reserve_exact(value.c1.len()); + for c1 in value.c1.iter() { + ksk.c1.push(c1.to_bytes()) + } + } + ksk.c0.reserve_exact(value.c0.len()); + for c0 in value.c0.iter() { + ksk.c0.push(c0.to_bytes()) + } + ksk.ciphertext_level = value.ciphertext_level as u32; + ksk.ksk_level = value.ksk_level as u32; + ksk + } } impl BfvTryConvertFrom<&KeySwitchingKeyProto> for KeySwitchingKey { - fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc) -> Result { - let ciphertext_level = value.ciphertext_level as usize; - let ksk_level = value.ksk_level as usize; - let ctx_ksk = par.ctx_at_level(ksk_level)?; - let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?; + fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc) -> Result { + let ciphertext_level = value.ciphertext_level as usize; + let ksk_level = value.ksk_level as usize; + let ctx_ksk = par.ctx_at_level(ksk_level)?; + let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?; - if value.c0.len() != ctx_ciphertext.moduli().len() { - return Err(Error::DefaultError( - "Incorrect number of values in c0".to_string(), - )); - } + if value.c0.len() != ctx_ciphertext.moduli().len() { + return Err(Error::DefaultError( + "Incorrect number of values in c0".to_string(), + )); + } - let seed = if value.seed.is_empty() { - if value.c1.len() != ctx_ciphertext.moduli().len() { - return Err(Error::DefaultError( - "Incorrect number of values in c1".to_string(), - )); - } - None - } else { - let unwrapped = ::Seed::try_from(value.seed.clone()); - if unwrapped.is_err() { - return Err(Error::DefaultError("Invalid seed".to_string())); - } - Some(unwrapped.unwrap()) - }; + let seed = if value.seed.is_empty() { + if value.c1.len() != ctx_ciphertext.moduli().len() { + return Err(Error::DefaultError( + "Incorrect number of values in c1".to_string(), + )); + } + None + } else { + let unwrapped = ::Seed::try_from(value.seed.clone()); + if unwrapped.is_err() { + return Err(Error::DefaultError("Invalid seed".to_string())); + } + Some(unwrapped.unwrap()) + }; - let c1 = if let Some(seed) = seed { - Self::generate_c1(ctx_ksk, seed, value.c0.len()) - } else { - value - .c1 - .iter() - .map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError)) - .collect::>>()? - }; + let c1 = if let Some(seed) = seed { + Self::generate_c1(ctx_ksk, seed, value.c0.len()) + } else { + value + .c1 + .iter() + .map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError)) + .collect::>>()? + }; - let c0 = value - .c0 - .iter() - .map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError)) - .collect::>>()?; + let c0 = value + .c0 + .iter() + .map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError)) + .collect::>>()?; - Ok(Self { - par: par.clone(), - seed, - c0: c0.into_boxed_slice(), - c1: c1.into_boxed_slice(), - ciphertext_level, - ctx_ciphertext: ctx_ciphertext.clone(), - ksk_level, - ctx_ksk: ctx_ksk.clone(), - }) - } + Ok(Self { + par: par.clone(), + seed, + c0: c0.into_boxed_slice(), + c1: c1.into_boxed_slice(), + ciphertext_level, + ctx_ciphertext: ctx_ciphertext.clone(), + ksk_level, + ctx_ksk: ctx_ksk.clone(), + }) + } } #[cfg(test)] mod tests { - use crate::bfv::{ - keys::key_switching_key::KeySwitchingKey, - proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, traits::TryConvertFrom, BfvParameters, - SecretKey, - }; - use fhe_math::{ - rns::RnsContext, - rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}, - }; - use num_bigint::BigUint; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use crate::bfv::{ + keys::key_switching_key::KeySwitchingKey, + proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, traits::TryConvertFrom, BfvParameters, + SecretKey, + }; + use fhe_math::{ + rns::RnsContext, + rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}, + }; + use num_bigint::BigUint; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn constructor() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(3, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng); - assert!(ksk.is_ok()); - } - Ok(()) - } + #[test] + fn constructor() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(3, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(0)?; + let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng); + assert!(ksk.is_ok()); + } + Ok(()) + } - #[test] - fn key_switch() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [Arc::new(BfvParameters::default(6, 8))] { - for _ in 0..100 { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(0)?; - let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); + #[test] + fn key_switch() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [Arc::new(BfvParameters::default(6, 8))] { + for _ in 0..100 { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(0)?; + let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; + let mut s = Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx, + false, + Representation::PowerBasis, + ) + .map_err(crate::Error::MathError)?; + s.change_representation(Representation::Ntt); - let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); - let (c0, c1) = ksk.key_switch(&input)?; + let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng); + let (c0, c1) = ksk.key_switch(&input)?; - let mut c2 = &c0 + &(&c1 * &s); - c2.change_representation(Representation::PowerBasis); + let mut c2 = &c0 + &(&c1 * &s); + c2.change_representation(Representation::PowerBasis); - input.change_representation(Representation::Ntt); - p.change_representation(Representation::Ntt); - let mut c3 = &input * &p; - c3.change_representation(Representation::PowerBasis); + input.change_representation(Representation::Ntt); + p.change_representation(Representation::Ntt); + let mut c3 = &input * &p; + c3.change_representation(Representation::PowerBasis); - let rns = RnsContext::new(¶ms.moduli)?; - Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { - assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70) - }); - } - } - Ok(()) - } + let rns = RnsContext::new(¶ms.moduli)?; + Vec::::from(&(&c2 - &c3)).iter().for_each(|b| { + assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70) + }); + } + } + Ok(()) + } - #[test] - fn proto_conversion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(3, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let ctx = params.ctx_at_level(0)?; - let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; - let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; - let ksk_proto = KeySwitchingKeyProto::from(&ksk); - assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?); - } - Ok(()) - } + #[test] + fn proto_conversion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(3, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let ctx = params.ctx_at_level(0)?; + let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?; + let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?; + let ksk_proto = KeySwitchingKeyProto::from(&ksk); + assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?); + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/keys/public_key.rs b/crates/fhe/src/bfv/keys/public_key.rs index a057717..cabb077 100644 --- a/crates/fhe/src/bfv/keys/public_key.rs +++ b/crates/fhe/src/bfv/keys/public_key.rs @@ -2,8 +2,8 @@ use crate::bfv::traits::TryConvertFrom; use crate::bfv::{ - proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto}, - BfvParameters, Ciphertext, Encoding, Plaintext, + proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto}, + BfvParameters, Ciphertext, Encoding, Plaintext, }; use crate::{Error, Result}; use fhe_math::rq::{Poly, Representation}; @@ -18,188 +18,188 @@ use super::SecretKey; /// Public key for the BFV encryption scheme. #[derive(Debug, PartialEq, Eq, Clone)] pub struct PublicKey { - pub(crate) par: Arc, - pub(crate) c: Ciphertext, + pub(crate) par: Arc, + pub(crate) c: Ciphertext, } impl PublicKey { - /// Generate a new [`PublicKey`] from a [`SecretKey`]. - pub fn new(sk: &SecretKey, rng: &mut R) -> Self { - let zero = Plaintext::zero(Encoding::poly(), &sk.par).unwrap(); - let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap(); - // The polynomials of a public key should not allow for variable time - // computation. - c.c.iter_mut() - .for_each(|p| p.disallow_variable_time_computations()); - Self { - par: sk.par.clone(), - c, - } - } + /// Generate a new [`PublicKey`] from a [`SecretKey`]. + pub fn new(sk: &SecretKey, rng: &mut R) -> Self { + let zero = Plaintext::zero(Encoding::poly(), &sk.par).unwrap(); + let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap(); + // The polynomials of a public key should not allow for variable time + // computation. + c.c.iter_mut() + .for_each(|p| p.disallow_variable_time_computations()); + Self { + par: sk.par.clone(), + c, + } + } } impl FheParametrized for PublicKey { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl FheEncrypter for PublicKey { - type Error = Error; + type Error = Error; - fn try_encrypt( - &self, - pt: &Plaintext, - rng: &mut R, - ) -> Result { - let mut ct = self.c.clone(); - while ct.level != pt.level { - ct.mod_switch_to_next_level(); - } + fn try_encrypt( + &self, + pt: &Plaintext, + rng: &mut R, + ) -> Result { + let mut ct = self.c.clone(); + while ct.level != pt.level { + ct.mod_switch_to_next_level(); + } - let ctx = self.par.ctx_at_level(ct.level)?; - let u = Zeroizing::new(Poly::small( - ctx, - Representation::Ntt, - self.par.variance, - rng, - )?); - let e1 = Zeroizing::new(Poly::small( - ctx, - Representation::Ntt, - self.par.variance, - rng, - )?); - let e2 = Zeroizing::new(Poly::small( - ctx, - Representation::Ntt, - self.par.variance, - rng, - )?); + let ctx = self.par.ctx_at_level(ct.level)?; + let u = Zeroizing::new(Poly::small( + ctx, + Representation::Ntt, + self.par.variance, + rng, + )?); + let e1 = Zeroizing::new(Poly::small( + ctx, + Representation::Ntt, + self.par.variance, + rng, + )?); + let e2 = Zeroizing::new(Poly::small( + ctx, + Representation::Ntt, + self.par.variance, + rng, + )?); - let m = Zeroizing::new(pt.to_poly()); - let mut c0 = u.as_ref() * &ct.c[0]; - c0 += &e1; - c0 += &m; - let mut c1 = u.as_ref() * &ct.c[1]; - c1 += &e2; + let m = Zeroizing::new(pt.to_poly()); + let mut c0 = u.as_ref() * &ct.c[0]; + c0 += &e1; + c0 += &m; + let mut c1 = u.as_ref() * &ct.c[1]; + c1 += &e2; - // It is now safe to enable variable time computations. - unsafe { - c0.allow_variable_time_computations(); - c1.allow_variable_time_computations() - } + // It is now safe to enable variable time computations. + unsafe { + c0.allow_variable_time_computations(); + c1.allow_variable_time_computations() + } - Ok(Ciphertext { - par: self.par.clone(), - seed: None, - c: vec![c0, c1], - level: ct.level, - }) - } + Ok(Ciphertext { + par: self.par.clone(), + seed: None, + c: vec![c0, c1], + level: ct.level, + }) + } } impl From<&PublicKey> for PublicKeyProto { - fn from(pk: &PublicKey) -> Self { - let mut proto = PublicKeyProto::new(); - proto.c = MessageField::some(CiphertextProto::from(&pk.c)); - proto - } + fn from(pk: &PublicKey) -> Self { + let mut proto = PublicKeyProto::new(); + proto.c = MessageField::some(CiphertextProto::from(&pk.c)); + proto + } } impl Serialize for PublicKey { - fn to_bytes(&self) -> Vec { - PublicKeyProto::from(self).write_to_bytes().unwrap() - } + fn to_bytes(&self) -> Vec { + PublicKeyProto::from(self).write_to_bytes().unwrap() + } } impl DeserializeParametrized for PublicKey { - type Error = Error; + type Error = Error; - fn from_bytes(bytes: &[u8], par: &Arc) -> Result { - let proto = - PublicKeyProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?; - if proto.c.is_some() { - let mut c = Ciphertext::try_convert_from(&proto.c.unwrap(), par)?; - if c.level != 0 { - Err(Error::SerializationError) - } else { - // The polynomials of a public key should not allow for variable time - // computation. - c.c.iter_mut() - .for_each(|p| p.disallow_variable_time_computations()); - Ok(Self { - par: par.clone(), - c, - }) - } - } else { - Err(Error::SerializationError) - } - } + fn from_bytes(bytes: &[u8], par: &Arc) -> Result { + let proto = + PublicKeyProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?; + if proto.c.is_some() { + let mut c = Ciphertext::try_convert_from(&proto.c.unwrap(), par)?; + if c.level != 0 { + Err(Error::SerializationError) + } else { + // The polynomials of a public key should not allow for variable time + // computation. + c.c.iter_mut() + .for_each(|p| p.disallow_variable_time_computations()); + Ok(Self { + par: par.clone(), + c, + }) + } + } else { + Err(Error::SerializationError) + } + } } #[cfg(test)] mod tests { - use super::PublicKey; - use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext, SecretKey}; - use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use super::PublicKey; + use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext, SecretKey}; + use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn keygen() -> Result<(), Box> { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(1, 8)); - let sk = SecretKey::random(¶ms, &mut rng); - let pk = PublicKey::new(&sk, &mut rng); - assert_eq!(pk.par, params); - assert_eq!( - sk.try_decrypt(&pk.c)?, - Plaintext::zero(Encoding::poly(), ¶ms)? - ); - Ok(()) - } + #[test] + fn keygen() -> Result<(), Box> { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(1, 8)); + let sk = SecretKey::random(¶ms, &mut rng); + let pk = PublicKey::new(&sk, &mut rng); + assert_eq!(pk.par, params); + assert_eq!( + sk.try_decrypt(&pk.c)?, + Plaintext::zero(Encoding::poly(), ¶ms)? + ); + Ok(()) + } - #[test] - fn encrypt_decrypt() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - for level in 0..params.max_level() { - for _ in 0..20 { - let sk = SecretKey::random(¶ms, &mut rng); - let pk = PublicKey::new(&sk, &mut rng); + #[test] + fn encrypt_decrypt() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + for level in 0..params.max_level() { + for _ in 0..20 { + let sk = SecretKey::random(¶ms, &mut rng); + let pk = PublicKey::new(&sk, &mut rng); - let pt = Plaintext::try_encode( - ¶ms.plaintext.random_vec(params.degree(), &mut rng), - Encoding::poly_at_level(level), - ¶ms, - )?; - let ct = pk.try_encrypt(&pt, &mut rng)?; - let pt2 = sk.try_decrypt(&ct)?; + let pt = Plaintext::try_encode( + ¶ms.plaintext.random_vec(params.degree(), &mut rng), + Encoding::poly_at_level(level), + ¶ms, + )?; + let ct = pk.try_encrypt(&pt, &mut rng)?; + let pt2 = sk.try_decrypt(&ct)?; - println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); - assert_eq!(pt2, pt); - } - } - } + println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); + assert_eq!(pt2, pt); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn test_serialize() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let pk = PublicKey::new(&sk, &mut rng); - let bytes = pk.to_bytes(); - assert_eq!(pk, PublicKey::from_bytes(&bytes, ¶ms)?); - } - Ok(()) - } + #[test] + fn test_serialize() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let pk = PublicKey::new(&sk, &mut rng); + let bytes = pk.to_bytes(); + assert_eq!(pk, PublicKey::from_bytes(&bytes, ¶ms)?); + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/keys/relinearization_key.rs b/crates/fhe/src/bfv/keys/relinearization_key.rs index 888ea73..ac0fb73 100644 --- a/crates/fhe/src/bfv/keys/relinearization_key.rs +++ b/crates/fhe/src/bfv/keys/relinearization_key.rs @@ -4,15 +4,15 @@ use std::sync::Arc; use super::key_switching_key::KeySwitchingKey; use crate::bfv::{ - proto::bfv::{ - KeySwitchingKey as KeySwitchingKeyProto, RelinearizationKey as RelinearizationKeyProto, - }, - traits::TryConvertFrom, - BfvParameters, Ciphertext, SecretKey, + proto::bfv::{ + KeySwitchingKey as KeySwitchingKeyProto, RelinearizationKey as RelinearizationKeyProto, + }, + traits::TryConvertFrom, + BfvParameters, Ciphertext, SecretKey, }; use crate::{Error, Result}; use fhe_math::rq::{ - switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation, + switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation, }; use fhe_traits::{DeserializeParametrized, FheParametrized, Serialize}; use protobuf::{Message, MessageField}; @@ -24,284 +24,284 @@ use zeroize::Zeroizing; /// which switch from `s^2` to `s` where `s` is the secret key. #[derive(Debug, PartialEq, Eq, Clone)] pub struct RelinearizationKey { - pub(crate) ksk: KeySwitchingKey, + pub(crate) ksk: KeySwitchingKey, } impl RelinearizationKey { - /// Generate a [`RelinearizationKey`] from a [`SecretKey`]. - pub fn new(sk: &SecretKey, rng: &mut R) -> Result { - Self::new_leveled_internal(sk, 0, 0, rng) - } + /// Generate a [`RelinearizationKey`] from a [`SecretKey`]. + pub fn new(sk: &SecretKey, rng: &mut R) -> Result { + Self::new_leveled_internal(sk, 0, 0, rng) + } - /// Generate a [`RelinearizationKey`] from a [`SecretKey`]. - pub fn new_leveled( - sk: &SecretKey, - ciphertext_level: usize, - key_level: usize, - rng: &mut R, - ) -> Result { - Self::new_leveled_internal(sk, ciphertext_level, key_level, rng) - } + /// Generate a [`RelinearizationKey`] from a [`SecretKey`]. + pub fn new_leveled( + sk: &SecretKey, + ciphertext_level: usize, + key_level: usize, + rng: &mut R, + ) -> Result { + Self::new_leveled_internal(sk, ciphertext_level, key_level, rng) + } - fn new_leveled_internal( - sk: &SecretKey, - ciphertext_level: usize, - key_level: usize, - rng: &mut R, - ) -> Result { - let ctx_relin_key = sk.par.ctx_at_level(key_level)?; - let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; + fn new_leveled_internal( + sk: &SecretKey, + ciphertext_level: usize, + key_level: usize, + rng: &mut R, + ) -> Result { + let ctx_relin_key = sk.par.ctx_at_level(key_level)?; + let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?; - if ctx_relin_key.moduli().len() == 1 { - return Err(Error::DefaultError( - "These parameters do not support key switching".to_string(), - )); - } + if ctx_relin_key.moduli().len() == 1 { + return Err(Error::DefaultError( + "These parameters do not support key switching".to_string(), + )); + } - let mut s = Zeroizing::new(Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx_ciphertext, - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref()); - s2.change_representation(Representation::PowerBasis); - let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?; - let s2_switched_up = Zeroizing::new(s2.mod_switch_to(&switcher_up)?); - let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?; - Ok(Self { ksk }) - } + let mut s = Zeroizing::new(Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx_ciphertext, + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); + let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref()); + s2.change_representation(Representation::PowerBasis); + let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?; + let s2_switched_up = Zeroizing::new(s2.mod_switch_to(&switcher_up)?); + let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?; + Ok(Self { ksk }) + } - /// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`] - pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> { - if ct.c.len() != 3 { - Err(Error::DefaultError( - "Only supports relinearization of ciphertext with 3 parts".to_string(), - )) - } else if ct.level != self.ksk.ciphertext_level { - Err(Error::DefaultError( - "Ciphertext has incorrect level".to_string(), - )) - } else { - let mut c2 = ct.c[2].clone(); - c2.change_representation(Representation::PowerBasis); + /// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`] + pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> { + if ct.c.len() != 3 { + Err(Error::DefaultError( + "Only supports relinearization of ciphertext with 3 parts".to_string(), + )) + } else if ct.level != self.ksk.ciphertext_level { + Err(Error::DefaultError( + "Ciphertext has incorrect level".to_string(), + )) + } else { + let mut c2 = ct.c[2].clone(); + c2.change_representation(Representation::PowerBasis); - #[allow(unused_mut)] - let (mut c0, mut c1) = self.relinearizes_poly(&c2)?; + #[allow(unused_mut)] + let (mut c0, mut c1) = self.relinearizes_poly(&c2)?; - if c0.ctx() != ct.c[0].ctx() { - c0.change_representation(Representation::PowerBasis); - c1.change_representation(Representation::PowerBasis); - c0.mod_switch_down_to(ct.c[0].ctx())?; - c1.mod_switch_down_to(ct.c[1].ctx())?; - c0.change_representation(Representation::Ntt); - c1.change_representation(Representation::Ntt); - } + if c0.ctx() != ct.c[0].ctx() { + c0.change_representation(Representation::PowerBasis); + c1.change_representation(Representation::PowerBasis); + c0.mod_switch_down_to(ct.c[0].ctx())?; + c1.mod_switch_down_to(ct.c[1].ctx())?; + c0.change_representation(Representation::Ntt); + c1.change_representation(Representation::Ntt); + } - ct.c[0] += &c0; - ct.c[1] += &c1; - ct.c.truncate(2); - Ok(()) - } - } + ct.c[0] += &c0; + ct.c[1] += &c1; + ct.c.truncate(2); + Ok(()) + } + } - /// Relinearize using polynomials. - pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> { - self.ksk.key_switch(c2) - } + /// Relinearize using polynomials. + pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> { + self.ksk.key_switch(c2) + } } impl From<&RelinearizationKey> for RelinearizationKeyProto { - fn from(value: &RelinearizationKey) -> Self { - let mut rk = RelinearizationKeyProto::new(); - rk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk)); - rk - } + fn from(value: &RelinearizationKey) -> Self { + let mut rk = RelinearizationKeyProto::new(); + rk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk)); + rk + } } impl TryConvertFrom<&RelinearizationKeyProto> for RelinearizationKey { - fn try_convert_from(value: &RelinearizationKeyProto, par: &Arc) -> Result { - if par.moduli.len() == 1 { - Err(Error::DefaultError( - "Invalid parameters for a relinearization key".to_string(), - )) - } else if value.ksk.is_some() { - Ok(RelinearizationKey { - ksk: KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?, - }) - } else { - Err(Error::DefaultError("Invalid serialization".to_string())) - } - } + fn try_convert_from(value: &RelinearizationKeyProto, par: &Arc) -> Result { + if par.moduli.len() == 1 { + Err(Error::DefaultError( + "Invalid parameters for a relinearization key".to_string(), + )) + } else if value.ksk.is_some() { + Ok(RelinearizationKey { + ksk: KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?, + }) + } else { + Err(Error::DefaultError("Invalid serialization".to_string())) + } + } } impl Serialize for RelinearizationKey { - fn to_bytes(&self) -> Vec { - RelinearizationKeyProto::from(self) - .write_to_bytes() - .unwrap() - } + fn to_bytes(&self) -> Vec { + RelinearizationKeyProto::from(self) + .write_to_bytes() + .unwrap() + } } impl FheParametrized for RelinearizationKey { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl DeserializeParametrized for RelinearizationKey { - type Error = Error; + type Error = Error; - fn from_bytes(bytes: &[u8], par: &Arc) -> Result { - let rk = RelinearizationKeyProto::parse_from_bytes(bytes); - if let Ok(rk) = rk { - RelinearizationKey::try_convert_from(&rk, par) - } else { - Err(Error::DefaultError("Invalid serialization".to_string())) - } - } + fn from_bytes(bytes: &[u8], par: &Arc) -> Result { + let rk = RelinearizationKeyProto::parse_from_bytes(bytes); + if let Ok(rk) = rk { + RelinearizationKey::try_convert_from(&rk, par) + } else { + Err(Error::DefaultError("Invalid serialization".to_string())) + } + } } #[cfg(test)] mod tests { - use super::RelinearizationKey; - use crate::bfv::{ - proto::bfv::RelinearizationKey as RelinearizationKeyProto, traits::TryConvertFrom, - BfvParameters, Ciphertext, Encoding, SecretKey, - }; - use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}; - use fhe_traits::{FheDecoder, FheDecrypter}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use super::RelinearizationKey; + use crate::bfv::{ + proto::bfv::RelinearizationKey as RelinearizationKeyProto, traits::TryConvertFrom, + BfvParameters, Ciphertext, Encoding, SecretKey, + }; + use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}; + use fhe_traits::{FheDecoder, FheDecrypter}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn relinearization() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [Arc::new(BfvParameters::default(6, 8))] { - for _ in 0..100 { - let sk = SecretKey::random(¶ms, &mut rng); - let rk = RelinearizationKey::new(&sk, &mut rng)?; + #[test] + fn relinearization() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [Arc::new(BfvParameters::default(6, 8))] { + for _ in 0..100 { + let sk = SecretKey::random(¶ms, &mut rng); + let rk = RelinearizationKey::new(&sk, &mut rng)?; - let ctx = params.ctx_at_level(0)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); - let s2 = &s * &s; + let ctx = params.ctx_at_level(0)?; + let mut s = Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx, + false, + Representation::PowerBasis, + ) + .map_err(crate::Error::MathError)?; + s.change_representation(Representation::Ntt); + let s2 = &s * &s; - // Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2, - // c1, c2) encrypting 0. - let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng); - let c1 = Poly::random(ctx, Representation::Ntt, &mut rng); - let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?; - c0.change_representation(Representation::Ntt); - c0 -= &(&c1 * &s); - c0 -= &(&c2 * &s2); - let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?; + // Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2, + // c1, c2) encrypting 0. + let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng); + let c1 = Poly::random(ctx, Representation::Ntt, &mut rng); + let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?; + c0.change_representation(Representation::Ntt); + c0 -= &(&c1 * &s); + c0 -= &(&c2 * &s2); + let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?; - // Relinearize the extended ciphertext! - rk.relinearizes(&mut ct)?; - assert_eq!(ct.c.len(), 2); + // Relinearize the extended ciphertext! + rk.relinearizes(&mut ct)?; + assert_eq!(ct.c.len(), 2); - // Check that the relinearization by polynomials works the same way - c2.change_representation(Representation::PowerBasis); - let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?; - c0r.change_representation(Representation::PowerBasis); - c0r.mod_switch_down_to(c0.ctx())?; - c1r.change_representation(Representation::PowerBasis); - c1r.mod_switch_down_to(c1.ctx())?; - c0r.change_representation(Representation::Ntt); - c1r.change_representation(Representation::Ntt); - assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?); + // Check that the relinearization by polynomials works the same way + c2.change_representation(Representation::PowerBasis); + let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?; + c0r.change_representation(Representation::PowerBasis); + c0r.mod_switch_down_to(c0.ctx())?; + c1r.change_representation(Representation::PowerBasis); + c1r.mod_switch_down_to(c1.ctx())?; + c0r.change_representation(Representation::Ntt); + c1r.change_representation(Representation::Ntt); + assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?); - // Print the noise and decrypt - println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); - let pt = sk.try_decrypt(&ct)?; - let w = Vec::::try_decode(&pt, Encoding::poly())?; - assert_eq!(w, &[0u64; 8]); - } - } - Ok(()) - } + // Print the noise and decrypt + println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); + let pt = sk.try_decrypt(&ct)?; + let w = Vec::::try_decode(&pt, Encoding::poly())?; + assert_eq!(w, &[0u64; 8]); + } + } + Ok(()) + } - #[test] - fn relinearization_leveled() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [Arc::new(BfvParameters::default(5, 8))] { - for ciphertext_level in 0..3 { - for key_level in 0..ciphertext_level { - for _ in 0..10 { - let sk = SecretKey::random(¶ms, &mut rng); - let rk = RelinearizationKey::new_leveled( - &sk, - ciphertext_level, - key_level, - &mut rng, - )?; + #[test] + fn relinearization_leveled() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [Arc::new(BfvParameters::default(5, 8))] { + for ciphertext_level in 0..3 { + for key_level in 0..ciphertext_level { + for _ in 0..10 { + let sk = SecretKey::random(¶ms, &mut rng); + let rk = RelinearizationKey::new_leveled( + &sk, + ciphertext_level, + key_level, + &mut rng, + )?; - let ctx = params.ctx_at_level(ciphertext_level)?; - let mut s = Poly::try_convert_from( - sk.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - ) - .map_err(crate::Error::MathError)?; - s.change_representation(Representation::Ntt); - let s2 = &s * &s; - // Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * - // s^2, c1, c2) encrypting 0. - let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng); - let c1 = Poly::random(ctx, Representation::Ntt, &mut rng); - let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?; - c0.change_representation(Representation::Ntt); - c0 -= &(&c1 * &s); - c0 -= &(&c2 * &s2); - let mut ct = - Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?; + let ctx = params.ctx_at_level(ciphertext_level)?; + let mut s = Poly::try_convert_from( + sk.coeffs.as_ref(), + ctx, + false, + Representation::PowerBasis, + ) + .map_err(crate::Error::MathError)?; + s.change_representation(Representation::Ntt); + let s2 = &s * &s; + // Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * + // s^2, c1, c2) encrypting 0. + let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng); + let c1 = Poly::random(ctx, Representation::Ntt, &mut rng); + let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?; + c0.change_representation(Representation::Ntt); + c0 -= &(&c1 * &s); + c0 -= &(&c2 * &s2); + let mut ct = + Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?; - // Relinearize the extended ciphertext! - rk.relinearizes(&mut ct)?; - assert_eq!(ct.c.len(), 2); + // Relinearize the extended ciphertext! + rk.relinearizes(&mut ct)?; + assert_eq!(ct.c.len(), 2); - // Check that the relinearization by polynomials works the same way - c2.change_representation(Representation::PowerBasis); - let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?; - c0r.change_representation(Representation::PowerBasis); - c0r.mod_switch_down_to(c0.ctx())?; - c1r.change_representation(Representation::PowerBasis); - c1r.mod_switch_down_to(c1.ctx())?; - c0r.change_representation(Representation::Ntt); - c1r.change_representation(Representation::Ntt); - assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?); + // Check that the relinearization by polynomials works the same way + c2.change_representation(Representation::PowerBasis); + let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?; + c0r.change_representation(Representation::PowerBasis); + c0r.mod_switch_down_to(c0.ctx())?; + c1r.change_representation(Representation::PowerBasis); + c1r.mod_switch_down_to(c1.ctx())?; + c0r.change_representation(Representation::Ntt); + c1r.change_representation(Representation::Ntt); + assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?); - // Print the noise and decrypt - println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); - let pt = sk.try_decrypt(&ct)?; - let w = Vec::::try_decode(&pt, Encoding::poly())?; - assert_eq!(w, &[0u64; 8]); - } - } - } - } - Ok(()) - } + // Print the noise and decrypt + println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); + let pt = sk.try_decrypt(&ct)?; + let w = Vec::::try_decode(&pt, Encoding::poly())?; + assert_eq!(w, &[0u64; 8]); + } + } + } + } + Ok(()) + } - #[test] - fn proto_conversion() -> Result<(), Box> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(3, 8)), - ] { - let sk = SecretKey::random(¶ms, &mut rng); - let rk = RelinearizationKey::new(&sk, &mut rng)?; - let proto = RelinearizationKeyProto::from(&rk); - assert_eq!(rk, RelinearizationKey::try_convert_from(&proto, ¶ms)?); - } - Ok(()) - } + #[test] + fn proto_conversion() -> Result<(), Box> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(3, 8)), + ] { + let sk = SecretKey::random(¶ms, &mut rng); + let rk = RelinearizationKey::new(&sk, &mut rng)?; + let proto = RelinearizationKeyProto::from(&rk); + assert_eq!(rk, RelinearizationKey::try_convert_from(&proto, ¶ms)?); + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/keys/secret_key.rs b/crates/fhe/src/bfv/keys/secret_key.rs index 3efc0e6..786045c 100644 --- a/crates/fhe/src/bfv/keys/secret_key.rs +++ b/crates/fhe/src/bfv/keys/secret_key.rs @@ -3,8 +3,8 @@ use crate::bfv::{BfvParameters, Ciphertext, Plaintext}; use crate::{Error, Result}; use fhe_math::{ - rq::{traits::TryConvertFrom, Poly, Representation}, - zq::Modulus, + rq::{traits::TryConvertFrom, Poly, Representation}, + zq::Modulus, }; use fhe_traits::{FheDecrypter, FheEncrypter, FheParametrized}; use fhe_util::sample_vec_cbd; @@ -18,249 +18,249 @@ use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing}; /// Secret key for the BFV encryption scheme. #[derive(Debug, PartialEq, Eq, Clone)] pub struct SecretKey { - pub(crate) par: Arc, - pub(crate) coeffs: Box<[i64]>, + pub(crate) par: Arc, + pub(crate) coeffs: Box<[i64]>, } impl Zeroize for SecretKey { - fn zeroize(&mut self) { - self.coeffs.zeroize(); - } + fn zeroize(&mut self) { + self.coeffs.zeroize(); + } } impl ZeroizeOnDrop for SecretKey {} impl SecretKey { - /// Generate a random [`SecretKey`]. - pub fn random(par: &Arc, rng: &mut R) -> Self { - let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap(); - Self::new(s_coefficients, par) - } + /// Generate a random [`SecretKey`]. + pub fn random(par: &Arc, rng: &mut R) -> Self { + let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap(); + Self::new(s_coefficients, par) + } - /// Generate a [`SecretKey`] from its coefficients. - pub(crate) fn new(coeffs: Vec, par: &Arc) -> Self { - Self { - par: par.clone(), - coeffs: coeffs.into_boxed_slice(), - } - } + /// Generate a [`SecretKey`] from its coefficients. + pub(crate) fn new(coeffs: Vec, par: &Arc) -> Self { + Self { + par: par.clone(), + coeffs: coeffs.into_boxed_slice(), + } + } - /// Measure the noise in a [`Ciphertext`]. - /// - /// # Safety - /// - /// This operations may run in a variable time depending on the value of the - /// noise. - pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result { - let plaintext = Zeroizing::new(self.try_decrypt(ct)?); - let m = Zeroizing::new(plaintext.to_poly()); + /// Measure the noise in a [`Ciphertext`]. + /// + /// # Safety + /// + /// This operations may run in a variable time depending on the value of the + /// noise. + pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result { + let plaintext = Zeroizing::new(self.try_decrypt(ct)?); + let m = Zeroizing::new(plaintext.to_poly()); - // Let's create a secret key with the ciphertext context - let mut s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - ct.c[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - let mut si = s.clone(); + // Let's create a secret key with the ciphertext context + let mut s = Zeroizing::new(Poly::try_convert_from( + self.coeffs.as_ref(), + ct.c[0].ctx(), + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); + let mut si = s.clone(); - // Let's disable variable time computations - let mut c = Zeroizing::new(ct.c[0].clone()); - c.disallow_variable_time_computations(); + // Let's disable variable time computations + let mut c = Zeroizing::new(ct.c[0].clone()); + c.disallow_variable_time_computations(); - for i in 1..ct.c.len() { - let mut cis = Zeroizing::new(ct.c[i].clone()); - cis.disallow_variable_time_computations(); - *cis.as_mut() *= si.as_ref(); - *c.as_mut() += &cis; - *si.as_mut() *= s.as_ref(); - } - *c.as_mut() -= &m; - c.change_representation(Representation::PowerBasis); + for i in 1..ct.c.len() { + let mut cis = Zeroizing::new(ct.c[i].clone()); + cis.disallow_variable_time_computations(); + *cis.as_mut() *= si.as_ref(); + *c.as_mut() += &cis; + *si.as_mut() *= s.as_ref(); + } + *c.as_mut() -= &m; + c.change_representation(Representation::PowerBasis); - let ciphertext_modulus = ct.c[0].ctx().modulus(); - let mut noise = 0usize; - for coeff in Vec::::from(c.as_ref()) { - noise = std::cmp::max( - noise, - std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize, - ) - } + let ciphertext_modulus = ct.c[0].ctx().modulus(); + let mut noise = 0usize; + for coeff in Vec::::from(c.as_ref()) { + noise = std::cmp::max( + noise, + std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize, + ) + } - Ok(noise) - } + Ok(noise) + } - pub(crate) fn encrypt_poly( - &self, - p: &Poly, - rng: &mut R, - ) -> Result { - assert_eq!(p.representation(), &Representation::Ntt); + pub(crate) fn encrypt_poly( + &self, + p: &Poly, + rng: &mut R, + ) -> Result { + assert_eq!(p.representation(), &Representation::Ntt); - let level = self.par.level_of_ctx(p.ctx())?; + let level = self.par.level_of_ctx(p.ctx())?; - let mut seed = ::Seed::default(); - thread_rng().fill(&mut seed); + let mut seed = ::Seed::default(); + thread_rng().fill(&mut seed); - // Let's create a secret key with the ciphertext context - let mut s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - p.ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); + // Let's create a secret key with the ciphertext context + let mut s = Zeroizing::new(Poly::try_convert_from( + self.coeffs.as_ref(), + p.ctx(), + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); - let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed); - let a_s = Zeroizing::new(&a * s.as_ref()); + let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed); + let a_s = Zeroizing::new(&a * s.as_ref()); - let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng) - .map_err(Error::MathError)?; - b -= &a_s; - b += p; + let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng) + .map_err(Error::MathError)?; + b -= &a_s; + b += p; - // It is now safe to enable variable time computations. - unsafe { - a.allow_variable_time_computations(); - b.allow_variable_time_computations() - } + // It is now safe to enable variable time computations. + unsafe { + a.allow_variable_time_computations(); + b.allow_variable_time_computations() + } - Ok(Ciphertext { - par: self.par.clone(), - seed: Some(seed), - c: vec![b, a], - level, - }) - } + Ok(Ciphertext { + par: self.par.clone(), + seed: Some(seed), + c: vec![b, a], + level, + }) + } } impl FheParametrized for SecretKey { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl FheEncrypter for SecretKey { - type Error = Error; + type Error = Error; - fn try_encrypt( - &self, - pt: &Plaintext, - rng: &mut R, - ) -> Result { - assert_eq!(self.par, pt.par); - let m = Zeroizing::new(pt.to_poly()); - self.encrypt_poly(m.as_ref(), rng) - } + fn try_encrypt( + &self, + pt: &Plaintext, + rng: &mut R, + ) -> Result { + assert_eq!(self.par, pt.par); + let m = Zeroizing::new(pt.to_poly()); + self.encrypt_poly(m.as_ref(), rng) + } } impl FheDecrypter for SecretKey { - type Error = Error; + type Error = Error; - fn try_decrypt(&self, ct: &Ciphertext) -> Result { - if self.par != ct.par { - Err(Error::DefaultError( - "Incompatible BFV parameters".to_string(), - )) - } else { - // Let's create a secret key with the ciphertext context - let mut s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - ct.c[0].ctx(), - false, - Representation::PowerBasis, - )?); - s.change_representation(Representation::Ntt); - let mut si = s.clone(); + fn try_decrypt(&self, ct: &Ciphertext) -> Result<Plaintext> { + if self.par != ct.par { + Err(Error::DefaultError( + "Incompatible BFV parameters".to_string(), + )) + } else { + // Let's create a secret key with the ciphertext context + let mut s = Zeroizing::new(Poly::try_convert_from( + self.coeffs.as_ref(), + ct.c[0].ctx(), + false, + Representation::PowerBasis, + )?); + s.change_representation(Representation::Ntt); + let mut si = s.clone(); - let mut c = Zeroizing::new(ct.c[0].clone()); - c.disallow_variable_time_computations(); + let mut c = Zeroizing::new(ct.c[0].clone()); + c.disallow_variable_time_computations(); - for i in 1..ct.c.len() { - let mut cis = Zeroizing::new(ct.c[i].clone()); - cis.disallow_variable_time_computations(); - *cis.as_mut() *= si.as_ref(); - *c.as_mut() += &cis; - *si.as_mut() *= s.as_ref(); - } - c.change_representation(Representation::PowerBasis); + for i in 1..ct.c.len() { + let mut cis = Zeroizing::new(ct.c[i].clone()); + cis.disallow_variable_time_computations(); + *cis.as_mut() *= si.as_ref(); + *c.as_mut() += &cis; + *si.as_mut() *= s.as_ref(); + } + c.change_representation(Representation::PowerBasis); - let d = Zeroizing::new(c.scale(&self.par.scalers[ct.level])?); + let d = Zeroizing::new(c.scale(&self.par.scalers[ct.level])?); - // TODO: Can we handle plaintext moduli that are BigUint? - let v = Zeroizing::new( - Vec::<u64>::from(d.as_ref()) - .iter_mut() - .map(|vi| *vi + self.par.plaintext.modulus()) - .collect_vec(), - ); - let mut w = v[..self.par.degree()].to_vec(); - let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?; - q.reduce_vec(&mut w); - self.par.plaintext.reduce_vec(&mut w); + // TODO: Can we handle plaintext moduli that are BigUint? + let v = Zeroizing::new( + Vec::<u64>::from(d.as_ref()) + .iter_mut() + .map(|vi| *vi + self.par.plaintext.modulus()) + .collect_vec(), + ); + let mut w = v[..self.par.degree()].to_vec(); + let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?; + q.reduce_vec(&mut w); + self.par.plaintext.reduce_vec(&mut w); - let mut poly = - Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?; - poly.change_representation(Representation::Ntt); + let mut poly = + Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?; + poly.change_representation(Representation::Ntt); - let pt = Plaintext { - par: self.par.clone(), - value: w.into_boxed_slice(), - encoding: None, - poly_ntt: poly, - level: ct.level, - }; + let pt = Plaintext { + par: self.par.clone(), + value: w.into_boxed_slice(), + encoding: None, + poly_ntt: poly, + level: ct.level, + }; - Ok(pt) - } - } + Ok(pt) + } + } } #[cfg(test)] mod tests { - use super::SecretKey; - use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext}; - use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use super::SecretKey; + use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext}; + use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn keygen() { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(1, 8)); - let sk = SecretKey::random(&params, &mut rng); - assert_eq!(sk.par, params); + #[test] + fn keygen() { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(1, 8)); + let sk = SecretKey::random(&params, &mut rng); + assert_eq!(sk.par, params); - sk.coeffs.iter().for_each(|ci| { - // Check that this is a small polynomial - assert!((*ci).abs() <= 2 * sk.par.variance as i64) - }) - } + sk.coeffs.iter().for_each(|ci| { + // Check that this is a small polynomial + assert!((*ci).abs() <= 2 * sk.par.variance as i64) + }) + } - #[test] - fn encrypt_decrypt() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - for level in 0..params.max_level() { - for _ in 0..20 { - let sk = SecretKey::random(&params, &mut rng); + #[test] + fn encrypt_decrypt() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + for level in 0..params.max_level() { + for _ in 0..20 { + let sk = SecretKey::random(&params, &mut rng); - let pt = Plaintext::try_encode( - &params.plaintext.random_vec(params.degree(), &mut rng), - Encoding::poly_at_level(level), - &params, - )?; - let ct = sk.try_encrypt(&pt, &mut rng)?; - let pt2 = sk.try_decrypt(&ct)?; + let pt = Plaintext::try_encode( + &params.plaintext.random_vec(params.degree(), &mut rng), + Encoding::poly_at_level(level), + &params, + )?; + let ct = sk.try_encrypt(&pt, &mut rng)?; + let pt2 = sk.try_decrypt(&ct)?; - println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); - assert_eq!(pt2, pt); - } - } - } + println!("Noise: {}", unsafe { sk.measure_noise(&ct)? }); + assert_eq!(pt2, pt); + } + } + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index 1bb5d7b..9a63f38 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -5,46 +5,46 @@ use itertools::{izip, Itertools}; use ndarray::{Array, Array2}; use crate::{ - bfv::{Ciphertext, Plaintext}, - Error, Result, + bfv::{Ciphertext, Plaintext}, + Error, Result, }; /// Computes the Fused-Mul-Add operation `out[i] += x[i] * y[i]` unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) { - let n = out.len(); - assert_eq!(x.len(), n); - assert_eq!(y.len(), n); + let n = out.len(); + assert_eq!(x.len(), n); + assert_eq!(y.len(), n); - macro_rules! fma_at { - ($idx:expr) => { - *out.get_unchecked_mut($idx) += - (*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128); - }; - } + macro_rules! fma_at { + ($idx:expr) => { + *out.get_unchecked_mut($idx) += + (*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128); + }; + } - let r = n / 16; - for i in 0..r { - fma_at!(16 * i); - fma_at!(16 * i + 1); - fma_at!(16 * i + 2); - fma_at!(16 * i + 3); - fma_at!(16 * i + 4); - fma_at!(16 * i + 5); - fma_at!(16 * i + 6); - fma_at!(16 * i + 7); - fma_at!(16 * i + 8); - fma_at!(16 * i + 9); - fma_at!(16 * i + 10); - fma_at!(16 * i + 11); - fma_at!(16 * i + 12); - fma_at!(16 * i + 13); - fma_at!(16 * i + 14); - fma_at!(16 * i + 15); - } + let r = n / 16; + for i in 0..r { + fma_at!(16 * i); + fma_at!(16 * i + 1); + fma_at!(16 * i + 2); + fma_at!(16 * i + 3); + fma_at!(16 * i + 4); + fma_at!(16 * i + 5); + fma_at!(16 * i + 6); + fma_at!(16 * i + 7); + fma_at!(16 * i + 8); + fma_at!(16 * i + 9); + fma_at!(16 * i + 10); + fma_at!(16 * i + 11); + fma_at!(16 * i + 12); + fma_at!(16 * i + 13); + fma_at!(16 * i + 14); + fma_at!(16 * i + 15); + } - for i in 0..n % 16 { - fma_at!(16 * r + i); - } + for i in 0..n % 16 { + fma_at!(16 * r + i); + } } /// Compute the dot product between an iterator of [`Ciphertext`] and an @@ -53,146 +53,146 @@ unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) { /// number of parts. pub fn dot_product_scalar<'a, I, J>(ct: I, pt: J) -> Result<Ciphertext> where - I: Iterator<Item = &'a Ciphertext> + Clone, - J: Iterator<Item = &'a Plaintext> + Clone, + I: Iterator<Item = &'a Ciphertext> + Clone, + J: Iterator<Item = &'a Plaintext> + Clone, { - let count = min(ct.clone().count(), pt.clone().count()); - if count == 0 { - return Err(Error::DefaultError( - "At least one iterator is empty".to_string(), - )); - } - let ct_first = ct.clone().next().unwrap(); - let ctx = ct_first.c[0].ctx(); + let count = min(ct.clone().count(), pt.clone().count()); + if count == 0 { + return Err(Error::DefaultError( + "At least one iterator is empty".to_string(), + )); + } + let ct_first = ct.clone().next().unwrap(); + let ctx = ct_first.c[0].ctx(); - if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| { - cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len() - }) { - return Err(Error::DefaultError("Mismatched parameters".to_string())); - } - if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) { - return Err(Error::DefaultError( - "Mismatched number of parts in the ciphertexts".to_string(), - )); - } + if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| { + cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len() + }) { + return Err(Error::DefaultError("Mismatched parameters".to_string())); + } + if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) { + return Err(Error::DefaultError( + "Mismatched number of parts in the ciphertexts".to_string(), + )); + } - let max_acc = ctx - .moduli() - .iter() - .map(|qi| 1u128 << (2 * qi.leading_zeros())) - .collect_vec(); - let min_of_max = max_acc.iter().min().unwrap(); + let max_acc = ctx + .moduli() + .iter() + .map(|qi| 1u128 << (2 * qi.leading_zeros())) + .collect_vec(); + let min_of_max = max_acc.iter().min().unwrap(); - if count as u128 > *min_of_max { - // Too many ciphertexts for the optimized method, instead, we call - // `poly_dot_product`. - let c = (0..ct_first.c.len()) - .map(|i| { - poly_dot_product( - ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }), - pt.clone().map(|pti| &pti.poly_ntt), - ) - .map_err(Error::MathError) - }) - .collect::<Result<Vec<Poly>>>()?; + if count as u128 > *min_of_max { + // Too many ciphertexts for the optimized method, instead, we call + // `poly_dot_product`. + let c = (0..ct_first.c.len()) + .map(|i| { + poly_dot_product( + ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }), + pt.clone().map(|pti| &pti.poly_ntt), + ) + .map_err(Error::MathError) + }) + .collect::<Result<Vec<Poly>>>()?; - Ok(Ciphertext { - par: ct_first.par.clone(), - seed: None, - c, - level: ct_first.level, - }) - } else { - let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree())); - for (ciphertext, plaintext) in izip!(ct, pt) { - let pt_coefficients = plaintext.poly_ntt.coefficients(); - for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) { - let ci_coefficients = ci.coefficients(); - for (mut accij, cij, pij) in izip!( - acci.outer_iter_mut(), - ci_coefficients.outer_iter(), - pt_coefficients.outer_iter() - ) { - unsafe { - fma( - accij.as_slice_mut().unwrap(), - cij.as_slice().unwrap(), - pij.as_slice().unwrap(), - ) - } - } - } - } + Ok(Ciphertext { + par: ct_first.par.clone(), + seed: None, + c, + level: ct_first.level, + }) + } else { + let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree())); + for (ciphertext, plaintext) in izip!(ct, pt) { + let pt_coefficients = plaintext.poly_ntt.coefficients(); + for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) { + let ci_coefficients = ci.coefficients(); + for (mut accij, cij, pij) in izip!( + acci.outer_iter_mut(), + ci_coefficients.outer_iter(), + pt_coefficients.outer_iter() + ) { + unsafe { + fma( + accij.as_slice_mut().unwrap(), + cij.as_slice().unwrap(), + pij.as_slice().unwrap(), + ) + } + } + } + } - // Reduce - let mut c = Vec::with_capacity(ct_first.c.len()); - for acci in acc.outer_iter() { - let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree())); - for (mut outij, accij, q) in izip!( - coeffs.outer_iter_mut(), - acci.outer_iter(), - ctx.moduli_operators() - ) { - for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) { - unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) } - } - } - c.push(Poly::try_convert_from( - coeffs, - ctx, - true, - Representation::Ntt, - )?) - } + // Reduce + let mut c = Vec::with_capacity(ct_first.c.len()); + for acci in acc.outer_iter() { + let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree())); + for (mut outij, accij, q) in izip!( + coeffs.outer_iter_mut(), + acci.outer_iter(), + ctx.moduli_operators() + ) { + for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) { + unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) } + } + } + c.push(Poly::try_convert_from( + coeffs, + ctx, + true, + Representation::Ntt, + )?) + } - Ok(Ciphertext { - par: ct_first.par.clone(), - seed: None, - c, - level: ct_first.level, - }) - } + Ok(Ciphertext { + par: ct_first.par.clone(), + seed: None, + c, + level: ct_first.level, + }) + } } #[cfg(test)] mod tests { - use super::dot_product_scalar; - use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey}; - use fhe_traits::{FheEncoder, FheEncrypter}; - use itertools::{izip, Itertools}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use super::dot_product_scalar; + use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey}; + use fhe_traits::{FheEncoder, FheEncrypter}; + use itertools::{izip, Itertools}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn test_dot_product_scalar() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(2, 16)), - ] { - let sk = SecretKey::random(&params, &mut rng); - for size in 1..128 { - let ct = (0..size) - .map(|_| { - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap(); - sk.try_encrypt(&pt, &mut rng).unwrap() - }) - .collect_vec(); - let pt = (0..size) - .map(|_| { - let v = params.plaintext.random_vec(params.degree(), &mut rng); - Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap() - }) - .collect_vec(); + #[test] + fn test_dot_product_scalar() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(2, 16)), + ] { + let sk = SecretKey::random(&params, &mut rng); + for size in 1..128 { + let ct = (0..size) + .map(|_| { + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap(); + sk.try_encrypt(&pt, &mut rng).unwrap() + }) + .collect_vec(); + let pt = (0..size) + .map(|_| { + let v = params.plaintext.random_vec(params.degree(), &mut rng); + Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap() + }) + .collect_vec(); - let r = dot_product_scalar(ct.iter(), pt.iter())?; + let r = dot_product_scalar(ct.iter(), pt.iter())?; - let mut expected = Ciphertext::zero(&params); - izip!(&ct, &pt).for_each(|(cti, pti)| expected += &(cti * pti)); - assert_eq!(r, expected); - } - } - Ok(()) - } + let mut expected = Ciphertext::zero(&params); + izip!(&ct, &pt).for_each(|(cti, pti)| expected += &(cti * pti)); + assert_eq!(r, expected); + } + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/ops/mod.rs b/crates/fhe/src/bfv/ops/mod.rs index e5be7a5..7cbc8dd 100644 --- a/crates/fhe/src/bfv/ops/mod.rs +++ b/crates/fhe/src/bfv/ops/mod.rs @@ -13,611 +13,611 @@ use itertools::{izip, Itertools}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; impl Add<&Ciphertext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn add(self, rhs: &Ciphertext) -> Ciphertext { - let mut self_clone = self.clone(); - self_clone += rhs; - self_clone - } + fn add(self, rhs: &Ciphertext) -> Ciphertext { + let mut self_clone = self.clone(); + self_clone += rhs; + self_clone + } } impl AddAssign<&Ciphertext> for Ciphertext { - fn add_assign(&mut self, rhs: &Ciphertext) { - assert_eq!(self.par, rhs.par); + fn add_assign(&mut self, rhs: &Ciphertext) { + assert_eq!(self.par, rhs.par); - if self.c.is_empty() { - *self = rhs.clone() - } else if !rhs.c.is_empty() { - assert_eq!(self.level, rhs.level); - assert_eq!(self.c.len(), rhs.c.len()); - izip!(&mut self.c, &rhs.c).for_each(|(c1i, c2i)| *c1i += c2i); - self.seed = None - } - } + if self.c.is_empty() { + *self = rhs.clone() + } else if !rhs.c.is_empty() { + assert_eq!(self.level, rhs.level); + assert_eq!(self.c.len(), rhs.c.len()); + izip!(&mut self.c, &rhs.c).for_each(|(c1i, c2i)| *c1i += c2i); + self.seed = None + } + } } impl Add<&Plaintext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn add(self, rhs: &Plaintext) -> Ciphertext { - let mut self_clone = self.clone(); - self_clone += rhs; - self_clone - } + fn add(self, rhs: &Plaintext) -> Ciphertext { + let mut self_clone = self.clone(); + self_clone += rhs; + self_clone + } } impl Add<&Ciphertext> for &Plaintext { - type Output = Ciphertext; + type Output = Ciphertext; - fn add(self, rhs: &Ciphertext) -> Ciphertext { - rhs + self - } + fn add(self, rhs: &Ciphertext) -> Ciphertext { + rhs + self + } } impl AddAssign<&Plaintext> for Ciphertext { - fn add_assign(&mut self, rhs: &Plaintext) { - assert_eq!(self.par, rhs.par); - assert!(!self.c.is_empty()); - assert_eq!(self.level, rhs.level); + fn add_assign(&mut self, rhs: &Plaintext) { + assert_eq!(self.par, rhs.par); + assert!(!self.c.is_empty()); + assert_eq!(self.level, rhs.level); - let poly = rhs.to_poly(); - self.c[0] += &poly; - self.seed = None - } + let poly = rhs.to_poly(); + self.c[0] += &poly; + self.seed = None + } } impl Sub<&Ciphertext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn sub(self, rhs: &Ciphertext) -> Ciphertext { - let mut self_clone = self.clone(); - self_clone -= rhs; - self_clone - } + fn sub(self, rhs: &Ciphertext) -> Ciphertext { + let mut self_clone = self.clone(); + self_clone -= rhs; + self_clone + } } impl SubAssign<&Ciphertext> for Ciphertext { - fn sub_assign(&mut self, rhs: &Ciphertext) { - assert_eq!(self.par, rhs.par); + fn sub_assign(&mut self, rhs: &Ciphertext) { + assert_eq!(self.par, rhs.par); - if self.c.is_empty() { - *self = -rhs - } else if !rhs.c.is_empty() { - assert_eq!(self.level, rhs.level); - assert_eq!(self.c.len(), rhs.c.len()); - izip!(&mut self.c, &rhs.c).for_each(|(c1i, c2i)| *c1i -= c2i); - self.seed = None - } - } + if self.c.is_empty() { + *self = -rhs + } else if !rhs.c.is_empty() { + assert_eq!(self.level, rhs.level); + assert_eq!(self.c.len(), rhs.c.len()); + izip!(&mut self.c, &rhs.c).for_each(|(c1i, c2i)| *c1i -= c2i); + self.seed = None + } + } } impl Sub<&Plaintext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn sub(self, rhs: &Plaintext) -> Ciphertext { - let mut self_clone = self.clone(); - self_clone -= rhs; - self_clone - } + fn sub(self, rhs: &Plaintext) -> Ciphertext { + let mut self_clone = self.clone(); + self_clone -= rhs; + self_clone + } } impl Sub<&Ciphertext> for &Plaintext { - type Output = Ciphertext; + type Output = Ciphertext; - fn sub(self, rhs: &Ciphertext) -> Ciphertext { - -(rhs - self) - } + fn sub(self, rhs: &Ciphertext) -> Ciphertext { + -(rhs - self) + } } impl SubAssign<&Plaintext> for Ciphertext { - fn sub_assign(&mut self, rhs: &Plaintext) { - assert_eq!(self.par, rhs.par); - assert!(!self.c.is_empty()); - assert_eq!(self.level, rhs.level); + fn sub_assign(&mut self, rhs: &Plaintext) { + assert_eq!(self.par, rhs.par); + assert!(!self.c.is_empty()); + assert_eq!(self.level, rhs.level); - let poly = rhs.to_poly(); - self.c[0] -= &poly; - self.seed = None - } + let poly = rhs.to_poly(); + self.c[0] -= &poly; + self.seed = None + } } impl Neg for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn neg(self) -> Ciphertext { - let c = self.c.iter().map(|c1i| -c1i).collect_vec(); - Ciphertext { - par: self.par.clone(), - seed: None, - c, - level: self.level, - } - } + fn neg(self) -> Ciphertext { + let c = self.c.iter().map(|c1i| -c1i).collect_vec(); + Ciphertext { + par: self.par.clone(), + seed: None, + c, + level: self.level, + } + } } impl Neg for Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn neg(mut self) -> Ciphertext { - self.c.iter_mut().for_each(|c1i| *c1i = -&*c1i); - self.seed = None; - self - } + fn neg(mut self) -> Ciphertext { + self.c.iter_mut().for_each(|c1i| *c1i = -&*c1i); + self.seed = None; + self + } } impl MulAssign<&Plaintext> for Ciphertext { - fn mul_assign(&mut self, rhs: &Plaintext) { - assert_eq!(self.par, rhs.par); - if !self.c.is_empty() { - assert_eq!(self.level, rhs.level); - self.c.iter_mut().for_each(|ci| *ci *= &rhs.poly_ntt); - } - self.seed = None - } + fn mul_assign(&mut self, rhs: &Plaintext) { + assert_eq!(self.par, rhs.par); + if !self.c.is_empty() { + assert_eq!(self.level, rhs.level); + self.c.iter_mut().for_each(|ci| *ci *= &rhs.poly_ntt); + } + self.seed = None + } } impl Mul<&Plaintext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn mul(self, rhs: &Plaintext) -> Ciphertext { - let mut self_clone = self.clone(); - self_clone *= rhs; - self_clone - } + fn mul(self, rhs: &Plaintext) -> Ciphertext { + let mut self_clone = self.clone(); + self_clone *= rhs; + self_clone + } } impl Mul<&Ciphertext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn mul(self, rhs: &Ciphertext) -> Ciphertext { - if self.c.is_empty() { - return self.clone(); - } + fn mul(self, rhs: &Ciphertext) -> Ciphertext { + if self.c.is_empty() { + return self.clone(); + } - if rhs == self { - // Squaring operation - let mp = &self.par.mul_params[self.level]; + if rhs == self { + // Squaring operation + let mp = &self.par.mul_params[self.level]; - // Scale all ciphertexts - // let mut now = std::time::SystemTime::now(); - let self_c = self - .c - .iter() - .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) - .collect::<Result<Vec<Poly>>>() - .unwrap(); - // println!("Extend: {:?}", now.elapsed().unwrap()); + // Scale all ciphertexts + // let mut now = std::time::SystemTime::now(); + let self_c = self + .c + .iter() + .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) + .collect::<Result<Vec<Poly>>>() + .unwrap(); + // println!("Extend: {:?}", now.elapsed().unwrap()); - // Multiply - // now = std::time::SystemTime::now(); - let mut c = vec![Poly::zero(&mp.to, Representation::Ntt); 2 * self_c.len() - 1]; - for i in 0..self_c.len() { - for j in 0..self_c.len() { - c[i + j] += &(&self_c[i] * &self_c[j]) - } - } - // println!("Multiply: {:?}", now.elapsed().unwrap()); + // Multiply + // now = std::time::SystemTime::now(); + let mut c = vec![Poly::zero(&mp.to, Representation::Ntt); 2 * self_c.len() - 1]; + for i in 0..self_c.len() { + for j in 0..self_c.len() { + c[i + j] += &(&self_c[i] * &self_c[j]) + } + } + // println!("Multiply: {:?}", now.elapsed().unwrap()); - // Scale - // now = std::time::SystemTime::now(); - let c = c - .iter_mut() - .map(|ci| { - ci.change_representation(Representation::PowerBasis); - let mut ci = ci.scale(&mp.down_scaler).map_err(Error::MathError)?; - ci.change_representation(Representation::Ntt); - Ok(ci) - }) - .collect::<Result<Vec<Poly>>>() - .unwrap(); - // println!("Scale: {:?}", now.elapsed().unwrap()); + // Scale + // now = std::time::SystemTime::now(); + let c = c + .iter_mut() + .map(|ci| { + ci.change_representation(Representation::PowerBasis); + let mut ci = ci.scale(&mp.down_scaler).map_err(Error::MathError)?; + ci.change_representation(Representation::Ntt); + Ok(ci) + }) + .collect::<Result<Vec<Poly>>>() + .unwrap(); + // println!("Scale: {:?}", now.elapsed().unwrap()); - Ciphertext { - par: self.par.clone(), - seed: None, - c, - level: rhs.level, - } - } else { - assert_eq!(self.par, rhs.par); - assert_eq!(self.level, rhs.level); + Ciphertext { + par: self.par.clone(), + seed: None, + c, + level: rhs.level, + } + } else { + assert_eq!(self.par, rhs.par); + assert_eq!(self.level, rhs.level); - let mp = &self.par.mul_params[self.level]; + let mp = &self.par.mul_params[self.level]; - // Scale all ciphertexts - // let mut now = std::time::SystemTime::now(); - let self_c = self - .c - .iter() - .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) - .collect::<Result<Vec<Poly>>>() - .unwrap(); - let other_c = rhs - .c - .iter() - .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) - .collect::<Result<Vec<Poly>>>() - .unwrap(); - // println!("Extend: {:?}", now.elapsed().unwrap()); + // Scale all ciphertexts + // let mut now = std::time::SystemTime::now(); + let self_c = self + .c + .iter() + .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) + .collect::<Result<Vec<Poly>>>() + .unwrap(); + let other_c = rhs + .c + .iter() + .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) + .collect::<Result<Vec<Poly>>>() + .unwrap(); + // println!("Extend: {:?}", now.elapsed().unwrap()); - // Multiply - // now = std::time::SystemTime::now(); - let mut c = - vec![Poly::zero(&mp.to, Representation::Ntt); self_c.len() + other_c.len() - 1]; - for i in 0..self_c.len() { - for j in 0..other_c.len() { - c[i + j] += &(&self_c[i] * &other_c[j]) - } - } - // println!("Multiply: {:?}", now.elapsed().unwrap()); + // Multiply + // now = std::time::SystemTime::now(); + let mut c = + vec![Poly::zero(&mp.to, Representation::Ntt); self_c.len() + other_c.len() - 1]; + for i in 0..self_c.len() { + for j in 0..other_c.len() { + c[i + j] += &(&self_c[i] * &other_c[j]) + } + } + // println!("Multiply: {:?}", now.elapsed().unwrap()); - // Scale - // now = std::time::SystemTime::now(); - let c = c - .iter_mut() - .map(|ci| { - ci.change_representation(Representation::PowerBasis); - let mut ci = ci.scale(&mp.down_scaler).map_err(Error::MathError)?; - ci.change_representation(Representation::Ntt); - Ok(ci) - }) - .collect::<Result<Vec<Poly>>>() - .unwrap(); - // println!("Scale: {:?}", now.elapsed().unwrap()); + // Scale + // now = std::time::SystemTime::now(); + let c = c + .iter_mut() + .map(|ci| { + ci.change_representation(Representation::PowerBasis); + let mut ci = ci.scale(&mp.down_scaler).map_err(Error::MathError)?; + ci.change_representation(Representation::Ntt); + Ok(ci) + }) + .collect::<Result<Vec<Poly>>>() + .unwrap(); + // println!("Scale: {:?}", now.elapsed().unwrap()); - Ciphertext { - par: self.par.clone(), - seed: None, - c, - level: rhs.level, - } - } - } + Ciphertext { + par: self.par.clone(), + seed: None, + c, + level: rhs.level, + } + } + } } #[cfg(test)] mod tests { - use crate::bfv::{ - encoding::EncodingEnum, BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey, - }; - use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; - use rand::{rngs::OsRng, thread_rng}; - use std::{error::Error, sync::Arc}; + use crate::bfv::{ + encoding::EncodingEnum, BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey, + }; + use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; + use rand::{rngs::OsRng, thread_rng}; + use std::{error::Error, sync::Arc}; - #[test] - fn add() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); + #[test] + fn add() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let zero = Ciphertext::zero(&params); - for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let b = params.plaintext.random_vec(params.degree(), &mut rng); - let mut c = a.clone(); - params.plaintext.add_vec(&mut c, &b); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let zero = Ciphertext::zero(&params); + for _ in 0..50 { + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let b = params.plaintext.random_vec(params.degree(), &mut rng); + let mut c = a.clone(); + params.plaintext.add_vec(&mut c, &b); - let sk = SecretKey::random(&params, &mut rng); + let sk = SecretKey::random(&params, &mut rng); - for encoding in [Encoding::poly(), Encoding::simd()] { - let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; - let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; + for encoding in [Encoding::poly(), Encoding::simd()] { + let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; + let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; - let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; - assert_eq!(ct_a, &ct_a + &zero); - assert_eq!(ct_a, &zero + &ct_a); - let ct_b: Ciphertext = sk.try_encrypt(&pt_b, &mut rng)?; - let ct_c = &ct_a + &ct_b; - ct_a += &ct_b; + let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; + assert_eq!(ct_a, &ct_a + &zero); + assert_eq!(ct_a, &zero + &ct_a); + let ct_b: Ciphertext = sk.try_encrypt(&pt_b, &mut rng)?; + let ct_c = &ct_a + &ct_b; + ct_a += &ct_b; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - let pt_c = sk.try_decrypt(&ct_a)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - } - } - } + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + let pt_c = sk.try_decrypt(&ct_a)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn add_scalar() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); + #[test] + fn add_scalar() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let b = params.plaintext.random_vec(params.degree(), &mut rng); - let mut c = a.clone(); - params.plaintext.add_vec(&mut c, &b); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + for _ in 0..50 { + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let b = params.plaintext.random_vec(params.degree(), &mut rng); + let mut c = a.clone(); + params.plaintext.add_vec(&mut c, &b); - let sk = SecretKey::random(&params, &mut rng); + let sk = SecretKey::random(&params, &mut rng); - for encoding in [Encoding::poly(), Encoding::simd()] { - let zero = Plaintext::zero(encoding.clone(), &params)?; - let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; - let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; + for encoding in [Encoding::poly(), Encoding::simd()] { + let zero = Plaintext::zero(encoding.clone(), &params)?; + let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; + let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; - let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; - assert_eq!( - Vec::<u64>::try_decode( - &sk.try_decrypt(&(&ct_a + &zero))?, - encoding.clone() - )?, - a - ); - assert_eq!( - Vec::<u64>::try_decode( - &sk.try_decrypt(&(&zero + &ct_a))?, - encoding.clone() - )?, - a - ); - let ct_c = &ct_a + &pt_b; - ct_a += &pt_b; + let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; + assert_eq!( + Vec::<u64>::try_decode( + &sk.try_decrypt(&(&ct_a + &zero))?, + encoding.clone() + )?, + a + ); + assert_eq!( + Vec::<u64>::try_decode( + &sk.try_decrypt(&(&zero + &ct_a))?, + encoding.clone() + )?, + a + ); + let ct_c = &ct_a + &pt_b; + ct_a += &pt_b; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - let pt_c = sk.try_decrypt(&ct_a)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - } - } - } + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + let pt_c = sk.try_decrypt(&ct_a)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn sub() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - let zero = Ciphertext::zero(&params); - for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let mut a_neg = a.clone(); - params.plaintext.neg_vec(&mut a_neg); - let b = params.plaintext.random_vec(params.degree(), &mut rng); - let mut c = a.clone(); - params.plaintext.sub_vec(&mut c, &b); + #[test] + fn sub() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + let zero = Ciphertext::zero(&params); + for _ in 0..50 { + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let mut a_neg = a.clone(); + params.plaintext.neg_vec(&mut a_neg); + let b = params.plaintext.random_vec(params.degree(), &mut rng); + let mut c = a.clone(); + params.plaintext.sub_vec(&mut c, &b); - let sk = SecretKey::random(&params, &mut rng); + let sk = SecretKey::random(&params, &mut rng); - for encoding in [Encoding::poly(), Encoding::simd()] { - let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; - let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; + for encoding in [Encoding::poly(), Encoding::simd()] { + let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; + let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; - let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; - assert_eq!(ct_a, &ct_a - &zero); - assert_eq!( - Vec::<u64>::try_decode( - &sk.try_decrypt(&(&zero - &ct_a))?, - encoding.clone() - )?, - a_neg - ); - let ct_b: Ciphertext = sk.try_encrypt(&pt_b, &mut rng)?; - let ct_c = &ct_a - &ct_b; - ct_a -= &ct_b; + let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; + assert_eq!(ct_a, &ct_a - &zero); + assert_eq!( + Vec::<u64>::try_decode( + &sk.try_decrypt(&(&zero - &ct_a))?, + encoding.clone() + )?, + a_neg + ); + let ct_b: Ciphertext = sk.try_encrypt(&pt_b, &mut rng)?; + let ct_c = &ct_a - &ct_b; + ct_a -= &ct_b; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - let pt_c = sk.try_decrypt(&ct_a)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - } - } - } + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + let pt_c = sk.try_decrypt(&ct_a)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn sub_scalar() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let mut a_neg = a.clone(); - params.plaintext.neg_vec(&mut a_neg); - let b = params.plaintext.random_vec(params.degree(), &mut rng); - let mut c = a.clone(); - params.plaintext.sub_vec(&mut c, &b); + #[test] + fn sub_scalar() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + for _ in 0..50 { + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let mut a_neg = a.clone(); + params.plaintext.neg_vec(&mut a_neg); + let b = params.plaintext.random_vec(params.degree(), &mut rng); + let mut c = a.clone(); + params.plaintext.sub_vec(&mut c, &b); - let sk = SecretKey::random(&params, &mut rng); + let sk = SecretKey::random(&params, &mut rng); - for encoding in [Encoding::poly(), Encoding::simd()] { - let zero = Plaintext::zero(encoding.clone(), &params)?; - let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; - let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; + for encoding in [Encoding::poly(), Encoding::simd()] { + let zero = Plaintext::zero(encoding.clone(), &params)?; + let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; + let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; - let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; - assert_eq!( - Vec::<u64>::try_decode( - &sk.try_decrypt(&(&ct_a - &zero))?, - encoding.clone() - )?, - a - ); - assert_eq!( - Vec::<u64>::try_decode( - &sk.try_decrypt(&(&zero - &ct_a))?, - encoding.clone() - )?, - a_neg - ); - let ct_c = &ct_a - &pt_b; - ct_a -= &pt_b; + let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; + assert_eq!( + Vec::<u64>::try_decode( + &sk.try_decrypt(&(&ct_a - &zero))?, + encoding.clone() + )?, + a + ); + assert_eq!( + Vec::<u64>::try_decode( + &sk.try_decrypt(&(&zero - &ct_a))?, + encoding.clone() + )?, + a_neg + ); + let ct_c = &ct_a - &pt_b; + ct_a -= &pt_b; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - let pt_c = sk.try_decrypt(&ct_a)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - } - } - } + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + let pt_c = sk.try_decrypt(&ct_a)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn neg() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let mut c = a.clone(); - params.plaintext.neg_vec(&mut c); + #[test] + fn neg() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + for _ in 0..50 { + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let mut c = a.clone(); + params.plaintext.neg_vec(&mut c); - let sk = SecretKey::random(&params, &mut rng); - for encoding in [Encoding::poly(), Encoding::simd()] { - let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; + let sk = SecretKey::random(&params, &mut rng); + for encoding in [Encoding::poly(), Encoding::simd()] { + let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; - let ct_a: Ciphertext = sk.try_encrypt(&pt_a, &mut rng)?; + let ct_a: Ciphertext = sk.try_encrypt(&pt_a, &mut rng)?; - let ct_c = -&ct_a; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + let ct_c = -&ct_a; + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - let ct_c = -ct_a; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - } - } - } + let ct_c = -ct_a; + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn mul_scalar() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); + #[test] + fn mul_scalar() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(1, 8)), - Arc::new(BfvParameters::default(6, 8)), - ] { - for _ in 0..50 { - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let b = params.plaintext.random_vec(params.degree(), &mut rng); + for params in [ + Arc::new(BfvParameters::default(1, 8)), + Arc::new(BfvParameters::default(6, 8)), + ] { + for _ in 0..50 { + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let b = params.plaintext.random_vec(params.degree(), &mut rng); - let sk = SecretKey::random(&params, &mut rng); - for encoding in [Encoding::poly(), Encoding::simd()] { - let mut c = vec![0u64; params.degree()]; - match encoding.encoding { - EncodingEnum::Poly => { - for i in 0..params.degree() { - for j in 0..params.degree() { - if i + j >= params.degree() { - c[(i + j) % params.degree()] = params.plaintext.sub( - c[(i + j) % params.degree()], - params.plaintext.mul(a[i], b[j]), - ); - } else { - c[i + j] = params - .plaintext - .add(c[i + j], params.plaintext.mul(a[i], b[j])); - } - } - } - } - EncodingEnum::Simd => { - c = a.clone(); - params.plaintext.mul_vec(&mut c, &b); - } - } + let sk = SecretKey::random(&params, &mut rng); + for encoding in [Encoding::poly(), Encoding::simd()] { + let mut c = vec![0u64; params.degree()]; + match encoding.encoding { + EncodingEnum::Poly => { + for i in 0..params.degree() { + for j in 0..params.degree() { + if i + j >= params.degree() { + c[(i + j) % params.degree()] = params.plaintext.sub( + c[(i + j) % params.degree()], + params.plaintext.mul(a[i], b[j]), + ); + } else { + c[i + j] = params + .plaintext + .add(c[i + j], params.plaintext.mul(a[i], b[j])); + } + } + } + } + EncodingEnum::Simd => { + c = a.clone(); + params.plaintext.mul_vec(&mut c, &b); + } + } - let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; - let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; + let pt_a = Plaintext::try_encode(&a, encoding.clone(), &params)?; + let pt_b = Plaintext::try_encode(&b, encoding.clone(), &params)?; - let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; - let ct_c = &ct_a * &pt_b; - ct_a *= &pt_b; + let mut ct_a = sk.try_encrypt(&pt_a, &mut rng)?; + let ct_c = &ct_a * &pt_b; + ct_a *= &pt_b; - let pt_c = sk.try_decrypt(&ct_c)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - let pt_c = sk.try_decrypt(&ct_a)?; - assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); - } - } - } + let pt_c = sk.try_decrypt(&ct_c)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + let pt_c = sk.try_decrypt(&ct_a)?; + assert_eq!(Vec::<u64>::try_decode(&pt_c, encoding.clone())?, c); + } + } + } - Ok(()) - } + Ok(()) + } - #[test] - fn mul() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for par in [ - Arc::new(BfvParameters::default(2, 8)), - Arc::new(BfvParameters::default(8, 8)), - ] { - for _ in 0..1 { - // We will encode `values` in an Simd format, and check that the product is - // computed correctly. - let v1 = par.plaintext.random_vec(par.degree(), &mut rng); - let v2 = par.plaintext.random_vec(par.degree(), &mut rng); - let mut expected = v1.clone(); - par.plaintext.mul_vec(&mut expected, &v2); + #[test] + fn mul() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for par in [ + Arc::new(BfvParameters::default(2, 8)), + Arc::new(BfvParameters::default(8, 8)), + ] { + for _ in 0..1 { + // We will encode `values` in an Simd format, and check that the product is + // computed correctly. + let v1 = par.plaintext.random_vec(par.degree(), &mut rng); + let v2 = par.plaintext.random_vec(par.degree(), &mut rng); + let mut expected = v1.clone(); + par.plaintext.mul_vec(&mut expected, &v2); - let sk = SecretKey::random(&par, &mut OsRng); - let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &par)?; - let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &par)?; + let sk = SecretKey::random(&par, &mut OsRng); + let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &par)?; + let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &par)?; - let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; - let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; - let ct3 = &ct1 * &ct2; - let ct4 = &ct3 * &ct3; + let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; + let ct3 = &ct1 * &ct2; + let ct4 = &ct3 * &ct3; - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - let e = expected.clone(); - par.plaintext.mul_vec(&mut expected, &e); - println!("Noise: {}", unsafe { sk.measure_noise(&ct4)? }); - let pt = sk.try_decrypt(&ct4)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - } - } - Ok(()) - } + let e = expected.clone(); + par.plaintext.mul_vec(&mut expected, &e); + println!("Noise: {}", unsafe { sk.measure_noise(&ct4)? }); + let pt = sk.try_decrypt(&ct4)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + } + } + Ok(()) + } - #[test] - fn square() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let par = Arc::new(BfvParameters::default(6, 8)); - for _ in 0..20 { - // We will encode `values` in an Simd format, and check that the product is - // computed correctly. - let v = par.plaintext.random_vec(par.degree(), &mut rng); - let mut expected = v.clone(); - par.plaintext.mul_vec(&mut expected, &v); + #[test] + fn square() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let par = Arc::new(BfvParameters::default(6, 8)); + for _ in 0..20 { + // We will encode `values` in an Simd format, and check that the product is + // computed correctly. + let v = par.plaintext.random_vec(par.degree(), &mut rng); + let mut expected = v.clone(); + par.plaintext.mul_vec(&mut expected, &v); - let sk = SecretKey::random(&par, &mut OsRng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), &par)?; + let sk = SecretKey::random(&par, &mut OsRng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), &par)?; - let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = &ct1 * &ct1; + let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + let ct2 = &ct1 * &ct1; - println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? }); - let pt = sk.try_decrypt(&ct2)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - } - Ok(()) - } + println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? }); + let pt = sk.try_decrypt(&ct2)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/ops/mul.rs b/crates/fhe/src/bfv/ops/mul.rs index 7ce7811..d61920c 100644 --- a/crates/fhe/src/bfv/ops/mul.rs +++ b/crates/fhe/src/bfv/ops/mul.rs @@ -1,16 +1,16 @@ use std::sync::Arc; use fhe_math::{ - rns::ScalingFactor, - rq::{scaler::Scaler, Context, Representation}, - zq::primes::generate_prime, + rns::ScalingFactor, + rq::{scaler::Scaler, Context, Representation}, + zq::primes::generate_prime, }; use fhe_util::div_ceil; use num_bigint::BigUint; use crate::{ - bfv::{keys::RelinearizationKey, BfvParameters, Ciphertext}, - Error, Result, + bfv::{keys::RelinearizationKey, BfvParameters, Ciphertext}, + Error, Result, }; /// Multiplicator that implements a strategy for multiplying. In particular, the @@ -22,390 +22,390 @@ use crate::{ /// - Whether relinearization should be used. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Multiplicator { - par: Arc<BfvParameters>, - pub(crate) extender_lhs: Scaler, - pub(crate) extender_rhs: Scaler, - pub(crate) down_scaler: Scaler, - pub(crate) base_ctx: Arc<Context>, - pub(crate) mul_ctx: Arc<Context>, - rk: Option<RelinearizationKey>, - mod_switch: bool, - level: usize, + par: Arc<BfvParameters>, + pub(crate) extender_lhs: Scaler, + pub(crate) extender_rhs: Scaler, + pub(crate) down_scaler: Scaler, + pub(crate) base_ctx: Arc<Context>, + pub(crate) mul_ctx: Arc<Context>, + rk: Option<RelinearizationKey>, + mod_switch: bool, + level: usize, } impl Multiplicator { - /// Construct a multiplicator using custom scaling factors and extended - /// basis. - pub fn new( - lhs_scaling_factor: ScalingFactor, - rhs_scaling_factor: ScalingFactor, - extended_basis: &[u64], - post_mul_scaling_factor: ScalingFactor, - par: &Arc<BfvParameters>, - ) -> Result<Self> { - Self::new_leveled_internal( - lhs_scaling_factor, - rhs_scaling_factor, - extended_basis, - post_mul_scaling_factor, - 0, - par, - ) - } + /// Construct a multiplicator using custom scaling factors and extended + /// basis. + pub fn new( + lhs_scaling_factor: ScalingFactor, + rhs_scaling_factor: ScalingFactor, + extended_basis: &[u64], + post_mul_scaling_factor: ScalingFactor, + par: &Arc<BfvParameters>, + ) -> Result<Self> { + Self::new_leveled_internal( + lhs_scaling_factor, + rhs_scaling_factor, + extended_basis, + post_mul_scaling_factor, + 0, + par, + ) + } - /// Construct a multiplicator using custom scaling factors and extended - /// basis at a given level. - pub fn new_leveled( - lhs_scaling_factor: ScalingFactor, - rhs_scaling_factor: ScalingFactor, - extended_basis: &[u64], - post_mul_scaling_factor: ScalingFactor, - level: usize, - par: &Arc<BfvParameters>, - ) -> Result<Self> { - Self::new_leveled_internal( - lhs_scaling_factor, - rhs_scaling_factor, - extended_basis, - post_mul_scaling_factor, - level, - par, - ) - } + /// Construct a multiplicator using custom scaling factors and extended + /// basis at a given level. + pub fn new_leveled( + lhs_scaling_factor: ScalingFactor, + rhs_scaling_factor: ScalingFactor, + extended_basis: &[u64], + post_mul_scaling_factor: ScalingFactor, + level: usize, + par: &Arc<BfvParameters>, + ) -> Result<Self> { + Self::new_leveled_internal( + lhs_scaling_factor, + rhs_scaling_factor, + extended_basis, + post_mul_scaling_factor, + level, + par, + ) + } - fn new_leveled_internal( - lhs_scaling_factor: ScalingFactor, - rhs_scaling_factor: ScalingFactor, - extended_basis: &[u64], - post_mul_scaling_factor: ScalingFactor, - level: usize, - par: &Arc<BfvParameters>, - ) -> Result<Self> { - let base_ctx = par.ctx_at_level(level)?; - let mul_ctx = Arc::new(Context::new(extended_basis, par.degree())?); - let extender_lhs = Scaler::new(base_ctx, &mul_ctx, lhs_scaling_factor)?; - let extender_rhs = Scaler::new(base_ctx, &mul_ctx, rhs_scaling_factor)?; - let down_scaler = Scaler::new(&mul_ctx, base_ctx, post_mul_scaling_factor)?; - Ok(Self { - par: par.clone(), - extender_lhs, - extender_rhs, - down_scaler, - base_ctx: base_ctx.clone(), - mul_ctx, - rk: None, - mod_switch: false, - level, - }) - } + fn new_leveled_internal( + lhs_scaling_factor: ScalingFactor, + rhs_scaling_factor: ScalingFactor, + extended_basis: &[u64], + post_mul_scaling_factor: ScalingFactor, + level: usize, + par: &Arc<BfvParameters>, + ) -> Result<Self> { + let base_ctx = par.ctx_at_level(level)?; + let mul_ctx = Arc::new(Context::new(extended_basis, par.degree())?); + let extender_lhs = Scaler::new(base_ctx, &mul_ctx, lhs_scaling_factor)?; + let extender_rhs = Scaler::new(base_ctx, &mul_ctx, rhs_scaling_factor)?; + let down_scaler = Scaler::new(&mul_ctx, base_ctx, post_mul_scaling_factor)?; + Ok(Self { + par: par.clone(), + extender_lhs, + extender_rhs, + down_scaler, + base_ctx: base_ctx.clone(), + mul_ctx, + rk: None, + mod_switch: false, + level, + }) + } - /// Default multiplication strategy using relinearization. - pub fn default(rk: &RelinearizationKey) -> Result<Self> { - let ctx = rk.ksk.par.ctx_at_level(rk.ksk.ciphertext_level)?; + /// Default multiplication strategy using relinearization. + pub fn default(rk: &RelinearizationKey) -> Result<Self> { + let ctx = rk.ksk.par.ctx_at_level(rk.ksk.ciphertext_level)?; - let modulus_size = rk.ksk.par.moduli_sizes()[..ctx.moduli().len()] - .iter() - .sum::<usize>(); - let n_moduli = div_ceil(modulus_size + 60, 62); + let modulus_size = rk.ksk.par.moduli_sizes()[..ctx.moduli().len()] + .iter() + .sum::<usize>(); + let n_moduli = div_ceil(modulus_size + 60, 62); - let mut extended_basis = Vec::with_capacity(ctx.moduli().len() + n_moduli); - extended_basis.append(&mut ctx.moduli().to_vec()); - let mut upper_bound = 1 << 62; - while extended_basis.len() != ctx.moduli().len() + n_moduli { - upper_bound = generate_prime(62, 2 * rk.ksk.par.degree() as u64, upper_bound).unwrap(); - if !extended_basis.contains(&upper_bound) && !ctx.moduli().contains(&upper_bound) { - extended_basis.push(upper_bound) - } - } + let mut extended_basis = Vec::with_capacity(ctx.moduli().len() + n_moduli); + extended_basis.append(&mut ctx.moduli().to_vec()); + let mut upper_bound = 1 << 62; + while extended_basis.len() != ctx.moduli().len() + n_moduli { + upper_bound = generate_prime(62, 2 * rk.ksk.par.degree() as u64, upper_bound).unwrap(); + if !extended_basis.contains(&upper_bound) && !ctx.moduli().contains(&upper_bound) { + extended_basis.push(upper_bound) + } + } - let mut multiplicator = Self::new_leveled_internal( - ScalingFactor::one(), - ScalingFactor::one(), - &extended_basis, - ScalingFactor::new( - &BigUint::from(rk.ksk.par.plaintext.modulus()), - ctx.modulus(), - ), - rk.ksk.ciphertext_level, - &rk.ksk.par, - )?; + let mut multiplicator = Self::new_leveled_internal( + ScalingFactor::one(), + ScalingFactor::one(), + &extended_basis, + ScalingFactor::new( + &BigUint::from(rk.ksk.par.plaintext.modulus()), + ctx.modulus(), + ), + rk.ksk.ciphertext_level, + &rk.ksk.par, + )?; - multiplicator.enable_relinearization(rk)?; - Ok(multiplicator) - } + multiplicator.enable_relinearization(rk)?; + Ok(multiplicator) + } - /// Enable relinearization after multiplication. - pub fn enable_relinearization(&mut self, rk: &RelinearizationKey) -> Result<()> { - let rk_ctx = self.par.ctx_at_level(rk.ksk.ciphertext_level)?; - if rk_ctx != &self.base_ctx { - return Err(Error::DefaultError( - "Invalid relinearization key context".to_string(), - )); - } - self.rk = Some(rk.clone()); - Ok(()) - } + /// Enable relinearization after multiplication. + pub fn enable_relinearization(&mut self, rk: &RelinearizationKey) -> Result<()> { + let rk_ctx = self.par.ctx_at_level(rk.ksk.ciphertext_level)?; + if rk_ctx != &self.base_ctx { + return Err(Error::DefaultError( + "Invalid relinearization key context".to_string(), + )); + } + self.rk = Some(rk.clone()); + Ok(()) + } - /// Enable modulus switching after multiplication (and relinearization, if - /// applicable). - pub fn enable_mod_switching(&mut self) -> Result<()> { - if self.par.ctx_at_level(self.par.max_level())? == &self.base_ctx { - Err(Error::DefaultError( - "Cannot modulo switch as this is already the last level".to_string(), - )) - } else { - self.mod_switch = true; - Ok(()) - } - } + /// Enable modulus switching after multiplication (and relinearization, if + /// applicable). + pub fn enable_mod_switching(&mut self) -> Result<()> { + if self.par.ctx_at_level(self.par.max_level())? == &self.base_ctx { + Err(Error::DefaultError( + "Cannot modulo switch as this is already the last level".to_string(), + )) + } else { + self.mod_switch = true; + Ok(()) + } + } - /// Multiply two ciphertexts using the defined multiplication strategy. - pub fn multiply(&self, lhs: &Ciphertext, rhs: &Ciphertext) -> Result<Ciphertext> { - if lhs.par != self.par || rhs.par != self.par { - return Err(Error::DefaultError( - "Ciphertexts do not have the same parameters".to_string(), - )); - } - if lhs.level != self.level || rhs.level != self.level { - return Err(Error::DefaultError( - "Ciphertexts are not at expected level".to_string(), - )); - } - if lhs.c.len() != 2 || rhs.c.len() != 2 { - return Err(Error::DefaultError( - "Multiplication can only be performed on ciphertexts of size 2".to_string(), - )); - } + /// Multiply two ciphertexts using the defined multiplication strategy. + pub fn multiply(&self, lhs: &Ciphertext, rhs: &Ciphertext) -> Result<Ciphertext> { + if lhs.par != self.par || rhs.par != self.par { + return Err(Error::DefaultError( + "Ciphertexts do not have the same parameters".to_string(), + )); + } + if lhs.level != self.level || rhs.level != self.level { + return Err(Error::DefaultError( + "Ciphertexts are not at expected level".to_string(), + )); + } + if lhs.c.len() != 2 || rhs.c.len() != 2 { + return Err(Error::DefaultError( + "Multiplication can only be performed on ciphertexts of size 2".to_string(), + )); + } - // Extend - let c00 = lhs.c[0].scale(&self.extender_lhs)?; - let c01 = lhs.c[1].scale(&self.extender_lhs)?; - let c10 = rhs.c[0].scale(&self.extender_rhs)?; - let c11 = rhs.c[1].scale(&self.extender_rhs)?; + // Extend + let c00 = lhs.c[0].scale(&self.extender_lhs)?; + let c01 = lhs.c[1].scale(&self.extender_lhs)?; + let c10 = rhs.c[0].scale(&self.extender_rhs)?; + let c11 = rhs.c[1].scale(&self.extender_rhs)?; - // Multiply - let mut c0 = &c00 * &c10; - let mut c1 = &c00 * &c11; - c1 += &(&c01 * &c10); - let mut c2 = &c01 * &c11; - c0.change_representation(Representation::PowerBasis); - c1.change_representation(Representation::PowerBasis); - c2.change_representation(Representation::PowerBasis); + // Multiply + let mut c0 = &c00 * &c10; + let mut c1 = &c00 * &c11; + c1 += &(&c01 * &c10); + let mut c2 = &c01 * &c11; + c0.change_representation(Representation::PowerBasis); + c1.change_representation(Representation::PowerBasis); + c2.change_representation(Representation::PowerBasis); - // Scale - let c0 = c0.scale(&self.down_scaler)?; - let c1 = c1.scale(&self.down_scaler)?; - let c2 = c2.scale(&self.down_scaler)?; + // Scale + let c0 = c0.scale(&self.down_scaler)?; + let c1 = c1.scale(&self.down_scaler)?; + let c2 = c2.scale(&self.down_scaler)?; - let mut c = vec![c0, c1, c2]; + let mut c = vec![c0, c1, c2]; - // Relinearize - if let Some(rk) = self.rk.as_ref() { - #[allow(unused_mut)] - let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?; + // Relinearize + if let Some(rk) = self.rk.as_ref() { + #[allow(unused_mut)] + let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?; - if c0r.ctx() != c[0].ctx() { - c0r.change_representation(Representation::PowerBasis); - c1r.change_representation(Representation::PowerBasis); - c0r.mod_switch_down_to(c[0].ctx())?; - c1r.mod_switch_down_to(c[1].ctx())?; - } else { - c[0].change_representation(Representation::Ntt); - c[1].change_representation(Representation::Ntt); - } + if c0r.ctx() != c[0].ctx() { + c0r.change_representation(Representation::PowerBasis); + c1r.change_representation(Representation::PowerBasis); + c0r.mod_switch_down_to(c[0].ctx())?; + c1r.mod_switch_down_to(c[1].ctx())?; + } else { + c[0].change_representation(Representation::Ntt); + c[1].change_representation(Representation::Ntt); + } - c[0] += &c0r; - c[1] += &c1r; - c.truncate(2); - } + c[0] += &c0r; + c[1] += &c1r; + c.truncate(2); + } - // We construct a ciphertext, but it may not have the right representation for - // the polynomials yet. - let mut c = Ciphertext { - par: self.par.clone(), - seed: None, - c, - level: self.level, - }; + // We construct a ciphertext, but it may not have the right representation for + // the polynomials yet. + let mut c = Ciphertext { + par: self.par.clone(), + seed: None, + c, + level: self.level, + }; - if self.mod_switch { - c.mod_switch_to_next_level(); - } else { - c.c.iter_mut() - .for_each(|p| p.change_representation(Representation::Ntt)); - } + if self.mod_switch { + c.mod_switch_to_next_level(); + } else { + c.c.iter_mut() + .for_each(|p| p.change_representation(Representation::Ntt)); + } - Ok(c) - } + Ok(c) + } } #[cfg(test)] mod tests { - use crate::bfv::{ - BfvParameters, Ciphertext, Encoding, Plaintext, RelinearizationKey, SecretKey, - }; - use fhe_math::{ - rns::{RnsContext, ScalingFactor}, - zq::primes::generate_prime, - }; - use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; - use num_bigint::BigUint; - use rand::{rngs::OsRng, thread_rng}; - use std::{error::Error, sync::Arc}; + use crate::bfv::{ + BfvParameters, Ciphertext, Encoding, Plaintext, RelinearizationKey, SecretKey, + }; + use fhe_math::{ + rns::{RnsContext, ScalingFactor}, + zq::primes::generate_prime, + }; + use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; + use num_bigint::BigUint; + use rand::{rngs::OsRng, thread_rng}; + use std::{error::Error, sync::Arc}; - use super::Multiplicator; + use super::Multiplicator; - #[test] - fn mul() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let par = Arc::new(BfvParameters::default(3, 8)); - for _ in 0..30 { - // We will encode `values` in an Simd format, and check that the product is - // computed correctly. - let values = par.plaintext.random_vec(par.degree(), &mut rng); - let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + #[test] + fn mul() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let par = Arc::new(BfvParameters::default(3, 8)); + for _ in 0..30 { + // We will encode `values` in an Simd format, and check that the product is + // computed correctly. + let values = par.plaintext.random_vec(par.degree(), &mut rng); + let mut expected = values.clone(); + par.plaintext.mul_vec(&mut expected, &values); - let sk = SecretKey::random(&par, &mut OsRng); - let rk = RelinearizationKey::new(&sk, &mut rng)?; - let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; - let ct1 = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = sk.try_encrypt(&pt, &mut rng)?; + let sk = SecretKey::random(&par, &mut OsRng); + let rk = RelinearizationKey::new(&sk, &mut rng)?; + let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; + let ct1 = sk.try_encrypt(&pt, &mut rng)?; + let ct2 = sk.try_encrypt(&pt, &mut rng)?; - let mut multiplicator = Multiplicator::default(&rk)?; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + let mut multiplicator = Multiplicator::default(&rk)?; + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - multiplicator.enable_mod_switching()?; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - assert_eq!(ct3.level, 1); - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - } - Ok(()) - } + multiplicator.enable_mod_switching()?; + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + assert_eq!(ct3.level, 1); + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + } + Ok(()) + } - #[test] - fn mul_at_level() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let par = Arc::new(BfvParameters::default(3, 8)); - for _ in 0..15 { - for level in 0..2 { - let values = par.plaintext.random_vec(par.degree(), &mut rng); - let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + #[test] + fn mul_at_level() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let par = Arc::new(BfvParameters::default(3, 8)); + for _ in 0..15 { + for level in 0..2 { + let values = par.plaintext.random_vec(par.degree(), &mut rng); + let mut expected = values.clone(); + par.plaintext.mul_vec(&mut expected, &values); - let sk = SecretKey::random(&par, &mut OsRng); - let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?; - let pt = Plaintext::try_encode(&values, Encoding::simd_at_level(level), &par)?; - let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; - let ct2: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; - assert_eq!(ct1.level, level); - assert_eq!(ct2.level, level); + let sk = SecretKey::random(&par, &mut OsRng); + let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?; + let pt = Plaintext::try_encode(&values, Encoding::simd_at_level(level), &par)?; + let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; + assert_eq!(ct1.level, level); + assert_eq!(ct2.level, level); - let mut multiplicator = Multiplicator::default(&rk).unwrap(); - let ct3 = multiplicator.multiply(&ct1, &ct2).unwrap(); - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + let mut multiplicator = Multiplicator::default(&rk).unwrap(); + let ct3 = multiplicator.multiply(&ct1, &ct2).unwrap(); + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - multiplicator.enable_mod_switching()?; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - assert_eq!(ct3.level, level + 1); - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - } - } - Ok(()) - } + multiplicator.enable_mod_switching()?; + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + assert_eq!(ct3.level, level + 1); + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + } + } + Ok(()) + } - #[test] - fn mul_no_relin() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let par = Arc::new(BfvParameters::default(6, 8)); - for _ in 0..30 { - // We will encode `values` in an Simd format, and check that the product is - // computed correctly. - let values = par.plaintext.random_vec(par.degree(), &mut rng); - let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + #[test] + fn mul_no_relin() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let par = Arc::new(BfvParameters::default(6, 8)); + for _ in 0..30 { + // We will encode `values` in an Simd format, and check that the product is + // computed correctly. + let values = par.plaintext.random_vec(par.degree(), &mut rng); + let mut expected = values.clone(); + par.plaintext.mul_vec(&mut expected, &values); - let sk = SecretKey::random(&par, &mut OsRng); - let rk = RelinearizationKey::new(&sk, &mut rng)?; - let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; - let ct1 = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = sk.try_encrypt(&pt, &mut rng)?; + let sk = SecretKey::random(&par, &mut OsRng); + let rk = RelinearizationKey::new(&sk, &mut rng)?; + let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; + let ct1 = sk.try_encrypt(&pt, &mut rng)?; + let ct2 = sk.try_encrypt(&pt, &mut rng)?; - let mut multiplicator = Multiplicator::default(&rk)?; - // Remove the relinearization key. - multiplicator.rk = None; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + let mut multiplicator = Multiplicator::default(&rk)?; + // Remove the relinearization key. + multiplicator.rk = None; + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - multiplicator.enable_mod_switching()?; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - assert_eq!(ct3.level, 1); - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - } - Ok(()) - } + multiplicator.enable_mod_switching()?; + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + assert_eq!(ct3.level, 1); + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + } + Ok(()) + } - #[test] - fn different_mul_strategy() -> Result<(), Box<dyn Error>> { - // Implement the second multiplication strategy from <https://eprint.iacr.org/2021/204> + #[test] + fn different_mul_strategy() -> Result<(), Box<dyn Error>> { + // Implement the second multiplication strategy from <https://eprint.iacr.org/2021/204> - let mut rng = thread_rng(); - let par = Arc::new(BfvParameters::default(3, 8)); - let mut extended_basis = par.moduli().to_vec(); - extended_basis - .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap()); - extended_basis - .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[3]).unwrap()); - extended_basis - .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[4]).unwrap()); - let rns = RnsContext::new(&extended_basis[3..])?; + let mut rng = thread_rng(); + let par = Arc::new(BfvParameters::default(3, 8)); + let mut extended_basis = par.moduli().to_vec(); + extended_basis + .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap()); + extended_basis + .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[3]).unwrap()); + extended_basis + .push(generate_prime(62, 2 * par.degree() as u64, extended_basis[4]).unwrap()); + let rns = RnsContext::new(&extended_basis[3..])?; - for _ in 0..30 { - // We will encode `values` in an Simd format, and check that the product is - // computed correctly. - let values = par.plaintext.random_vec(par.degree(), &mut rng); - let mut expected = values.clone(); - par.plaintext.mul_vec(&mut expected, &values); + for _ in 0..30 { + // We will encode `values` in an Simd format, and check that the product is + // computed correctly. + let values = par.plaintext.random_vec(par.degree(), &mut rng); + let mut expected = values.clone(); + par.plaintext.mul_vec(&mut expected, &values); - let sk = SecretKey::random(&par, &mut OsRng); - let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; - let ct1 = sk.try_encrypt(&pt, &mut rng)?; - let ct2 = sk.try_encrypt(&pt, &mut rng)?; + let sk = SecretKey::random(&par, &mut OsRng); + let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?; + let ct1 = sk.try_encrypt(&pt, &mut rng)?; + let ct2 = sk.try_encrypt(&pt, &mut rng)?; - let mut multiplicator = Multiplicator::new( - ScalingFactor::one(), - ScalingFactor::new(rns.modulus(), par.ctx[0].modulus()), - &extended_basis, - ScalingFactor::new(&BigUint::from(par.plaintext()), rns.modulus()), - &par, - )?; + let mut multiplicator = Multiplicator::new( + ScalingFactor::one(), + ScalingFactor::new(rns.modulus(), par.ctx[0].modulus()), + &extended_basis, + ScalingFactor::new(&BigUint::from(par.plaintext()), rns.modulus()), + &par, + )?; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - multiplicator.enable_mod_switching()?; - let ct3 = multiplicator.multiply(&ct1, &ct2)?; - assert_eq!(ct3.level, 1); - println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); - let pt = sk.try_decrypt(&ct3)?; - assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); - } + multiplicator.enable_mod_switching()?; + let ct3 = multiplicator.multiply(&ct1, &ct2)?; + assert_eq!(ct3.level, 1); + println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? }); + let pt = sk.try_decrypt(&ct3)?; + assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected); + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index a176111..f20bd24 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -3,9 +3,9 @@ use crate::bfv::proto::bfv::Parameters; use crate::{Error, ParametersError, Result}; use fhe_math::{ - rns::{RnsContext, ScalingFactor}, - rq::{scaler::Scaler, traits::TryConvertFrom, Context, Poly, Representation}, - zq::{ntt::NttOperator, primes::generate_prime, Modulus}, + rns::{RnsContext, ScalingFactor}, + rq::{scaler::Scaler, traits::TryConvertFrom, Context, Poly, Representation}, + zq::{ntt::NttOperator, primes::generate_prime, Modulus}, }; use fhe_traits::{Deserialize, FheParameters, Serialize}; use fhe_util::div_ceil; @@ -20,67 +20,67 @@ use std::sync::Arc; /// Parameters for the BFV encryption scheme. #[derive(PartialEq, Eq)] pub struct BfvParameters { - /// Number of coefficients in a polynomial. - polynomial_degree: usize, + /// Number of coefficients in a polynomial. + polynomial_degree: usize, - /// Modulus of the plaintext. - plaintext_modulus: u64, + /// Modulus of the plaintext. + plaintext_modulus: u64, - /// Vector of coprime moduli q_i for the ciphertext. - /// One and only one of `ciphertext_moduli` or `ciphertext_moduli_sizes` - /// must be specified. - pub(crate) moduli: Box<[u64]>, + /// Vector of coprime moduli q_i for the ciphertext. + /// One and only one of `ciphertext_moduli` or `ciphertext_moduli_sizes` + /// must be specified. + pub(crate) moduli: Box<[u64]>, - /// Vector of the sized of the coprime moduli q_i for the ciphertext. - /// One and only one of `ciphertext_moduli` or `ciphertext_moduli_sizes` - /// must be specified. - moduli_sizes: Box<[usize]>, + /// Vector of the sized of the coprime moduli q_i for the ciphertext. + /// One and only one of `ciphertext_moduli` or `ciphertext_moduli_sizes` + /// must be specified. + moduli_sizes: Box<[usize]>, - /// Error variance - pub(crate) variance: usize, + /// Error variance + pub(crate) variance: usize, - /// Context for the underlying polynomials - pub(crate) ctx: Vec<Arc<Context>>, + /// Context for the underlying polynomials + pub(crate) ctx: Vec<Arc<Context>>, - /// Ntt operator for the SIMD plaintext, if possible. - pub(crate) op: Option<Arc<NttOperator>>, + /// Ntt operator for the SIMD plaintext, if possible. + pub(crate) op: Option<Arc<NttOperator>>, - /// Scaling polynomial for the plaintext - pub(crate) delta: Box<[Poly]>, + /// Scaling polynomial for the plaintext + pub(crate) delta: Box<[Poly]>, - /// Q modulo the plaintext modulus - pub(crate) q_mod_t: Box<[u64]>, + /// Q modulo the plaintext modulus + pub(crate) q_mod_t: Box<[u64]>, - /// Down scaler for the plaintext - pub(crate) scalers: Box<[Scaler]>, + /// Down scaler for the plaintext + pub(crate) scalers: Box<[Scaler]>, - /// Plaintext Modulus - pub(crate) plaintext: Modulus, + /// Plaintext Modulus + pub(crate) plaintext: Modulus, - // Parameters for the multiplications - pub(crate) mul_params: Box<[MultiplicationParameters]>, + // Parameters for the multiplications + pub(crate) mul_params: Box<[MultiplicationParameters]>, - pub(crate) matrix_reps_index_map: Box<[usize]>, + pub(crate) matrix_reps_index_map: Box<[usize]>, } impl Debug for BfvParameters { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("BfvParameters") - .field("polynomial_degree", &self.polynomial_degree) - .field("plaintext_modulus", &self.plaintext_modulus) - .field("moduli", &self.moduli) - // .field("moduli_sizes", &self.moduli_sizes) - // .field("variance", &self.variance) - // .field("ctx", &self.ctx) - // .field("op", &self.op) - // .field("delta", &self.delta) - // .field("q_mod_t", &self.q_mod_t) - // .field("scaler", &self.scaler) - // .field("plaintext", &self.plaintext) - // .field("mul_params", &self.mul_params) - // .field("matrix_reps_index_map", &self.matrix_reps_index_map) - .finish() - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BfvParameters") + .field("polynomial_degree", &self.polynomial_degree) + .field("plaintext_modulus", &self.plaintext_modulus) + .field("moduli", &self.moduli) + // .field("moduli_sizes", &self.moduli_sizes) + // .field("variance", &self.variance) + // .field("ctx", &self.ctx) + // .field("op", &self.op) + // .field("delta", &self.delta) + // .field("q_mod_t", &self.q_mod_t) + // .field("scaler", &self.scaler) + // .field("plaintext", &self.plaintext) + // .field("mul_params", &self.mul_params) + // .field("matrix_reps_index_map", &self.matrix_reps_index_map) + .finish() + } } impl FheParameters for BfvParameters {} @@ -88,570 +88,570 @@ impl FheParameters for BfvParameters {} unsafe impl Send for BfvParameters {} impl BfvParameters { - /// Returns the underlying polynomial degree - pub const fn degree(&self) -> usize { - self.polynomial_degree - } + /// Returns the underlying polynomial degree + pub const fn degree(&self) -> usize { + self.polynomial_degree + } - /// Returns a reference to the ciphertext moduli - pub fn moduli(&self) -> &[u64] { - &self.moduli - } + /// Returns a reference to the ciphertext moduli + pub fn moduli(&self) -> &[u64] { + &self.moduli + } - /// Returns a reference to the ciphertext moduli - pub fn moduli_sizes(&self) -> &[usize] { - &self.moduli_sizes - } + /// Returns a reference to the ciphertext moduli + pub fn moduli_sizes(&self) -> &[usize] { + &self.moduli_sizes + } - /// Returns the plaintext modulus - pub const fn plaintext(&self) -> u64 { - self.plaintext_modulus - } + /// Returns the plaintext modulus + pub const fn plaintext(&self) -> u64 { + self.plaintext_modulus + } - /// Returns the maximum level allowed by these parameters. - pub fn max_level(&self) -> usize { - self.moduli.len() - 1 - } + /// Returns the maximum level allowed by these parameters. + pub fn max_level(&self) -> usize { + self.moduli.len() - 1 + } - /// Returns the context corresponding to the level. - pub(crate) fn ctx_at_level(&self, level: usize) -> Result<&Arc<Context>> { - self.ctx - .get(level) - .ok_or_else(|| Error::DefaultError("No context".to_string())) - } + /// Returns the context corresponding to the level. + pub(crate) fn ctx_at_level(&self, level: usize) -> Result<&Arc<Context>> { + self.ctx + .get(level) + .ok_or_else(|| Error::DefaultError("No context".to_string())) + } - /// Returns the level of a given context - pub(crate) fn level_of_ctx(&self, ctx: &Arc<Context>) -> Result<usize> { - self.ctx[0].niterations_to(ctx).map_err(Error::MathError) - } + /// Returns the level of a given context + pub(crate) fn level_of_ctx(&self, ctx: &Arc<Context>) -> Result<usize> { + self.ctx[0].niterations_to(ctx).map_err(Error::MathError) + } - /// Vector of default parameters providing about 128 bits of security - /// according to the <https://homomorphicencryption.org> standard. - pub fn default_parameters_128(plaintext_nbits: usize) -> Vec<Arc<BfvParameters>> { - debug_assert!(plaintext_nbits < 64); + /// Vector of default parameters providing about 128 bits of security + /// according to the <https://homomorphicencryption.org> standard. + pub fn default_parameters_128(plaintext_nbits: usize) -> Vec<Arc<BfvParameters>> { + debug_assert!(plaintext_nbits < 64); - let mut n_and_qs = HashMap::new(); - n_and_qs.insert(1024, vec![0x7e00001]); - n_and_qs.insert(2048, vec![0x3fffffff000001]); - n_and_qs.insert(4096, vec![0xffffee001, 0xffffc4001, 0x1ffffe0001]); - n_and_qs.insert( - 8192, - vec![ - 0x7fffffd8001, - 0x7fffffc8001, - 0xfffffffc001, - 0xffffff6c001, - 0xfffffebc001, - ], - ); - n_and_qs.insert( - 16384, - vec![ - 0xfffffffd8001, - 0xfffffffa0001, - 0xfffffff00001, - 0x1fffffff68001, - 0x1fffffff50001, - 0x1ffffffee8001, - 0x1ffffffea0001, - 0x1ffffffe88001, - 0x1ffffffe48001, - ], - ); - n_and_qs.insert( - 32768, - vec![ - 0x7fffffffe90001, - 0x7fffffffbf0001, - 0x7fffffffbd0001, - 0x7fffffffba0001, - 0x7fffffffaa0001, - 0x7fffffffa50001, - 0x7fffffff9f0001, - 0x7fffffff7e0001, - 0x7fffffff770001, - 0x7fffffff380001, - 0x7fffffff330001, - 0x7fffffff2d0001, - 0x7fffffff170001, - 0x7fffffff150001, - 0x7ffffffef00001, - 0xfffffffff70001, - ], - ); + let mut n_and_qs = HashMap::new(); + n_and_qs.insert(1024, vec![0x7e00001]); + n_and_qs.insert(2048, vec![0x3fffffff000001]); + n_and_qs.insert(4096, vec![0xffffee001, 0xffffc4001, 0x1ffffe0001]); + n_and_qs.insert( + 8192, + vec![ + 0x7fffffd8001, + 0x7fffffc8001, + 0xfffffffc001, + 0xffffff6c001, + 0xfffffebc001, + ], + ); + n_and_qs.insert( + 16384, + vec![ + 0xfffffffd8001, + 0xfffffffa0001, + 0xfffffff00001, + 0x1fffffff68001, + 0x1fffffff50001, + 0x1ffffffee8001, + 0x1ffffffea0001, + 0x1ffffffe88001, + 0x1ffffffe48001, + ], + ); + n_and_qs.insert( + 32768, + vec![ + 0x7fffffffe90001, + 0x7fffffffbf0001, + 0x7fffffffbd0001, + 0x7fffffffba0001, + 0x7fffffffaa0001, + 0x7fffffffa50001, + 0x7fffffff9f0001, + 0x7fffffff7e0001, + 0x7fffffff770001, + 0x7fffffff380001, + 0x7fffffff330001, + 0x7fffffff2d0001, + 0x7fffffff170001, + 0x7fffffff150001, + 0x7ffffffef00001, + 0xfffffffff70001, + ], + ); - let mut params = vec![]; + let mut params = vec![]; - for n in n_and_qs.keys().sorted() { - let moduli = n_and_qs.get(n).unwrap(); - if let Some(plaintext_modulus) = generate_prime( - plaintext_nbits, - 2 * *n as u64, - u64::MAX >> (64 - plaintext_nbits), - ) { - params.push(Arc::new( - BfvParametersBuilder::new() - .set_degree(*n as usize) - .set_plaintext_modulus(plaintext_modulus) - .set_moduli(moduli) - .build() - .unwrap(), - )) - } - } + for n in n_and_qs.keys().sorted() { + let moduli = n_and_qs.get(n).unwrap(); + if let Some(plaintext_modulus) = generate_prime( + plaintext_nbits, + 2 * *n as u64, + u64::MAX >> (64 - plaintext_nbits), + ) { + params.push(Arc::new( + BfvParametersBuilder::new() + .set_degree(*n as usize) + .set_plaintext_modulus(plaintext_modulus) + .set_moduli(moduli) + .build() + .unwrap(), + )) + } + } - params - } + params + } - #[cfg(test)] - pub fn default(num_moduli: usize, degree: usize) -> Self { - if !degree.is_power_of_two() || degree < 8 { - panic!("Invalid degree"); - } - BfvParametersBuilder::new() - .set_degree(degree) - .set_plaintext_modulus(1153) - .set_moduli_sizes(&vec![62usize; num_moduli]) - .build() - .unwrap() - } + #[cfg(test)] + pub fn default(num_moduli: usize, degree: usize) -> Self { + if !degree.is_power_of_two() || degree < 8 { + panic!("Invalid degree"); + } + BfvParametersBuilder::new() + .set_degree(degree) + .set_plaintext_modulus(1153) + .set_moduli_sizes(&vec![62usize; num_moduli]) + .build() + .unwrap() + } } /// Builder for parameters for the Bfv encryption scheme. #[derive(Debug)] pub struct BfvParametersBuilder { - degree: usize, - plaintext: u64, - variance: usize, - ciphertext_moduli: Vec<u64>, - ciphertext_moduli_sizes: Vec<usize>, + degree: usize, + plaintext: u64, + variance: usize, + ciphertext_moduli: Vec<u64>, + ciphertext_moduli_sizes: Vec<usize>, } impl BfvParametersBuilder { - /// Creates a new instance of the builder - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - Self { - degree: Default::default(), - plaintext: Default::default(), - variance: 10, - ciphertext_moduli: Default::default(), - ciphertext_moduli_sizes: Default::default(), - } - } + /// Creates a new instance of the builder + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + degree: Default::default(), + plaintext: Default::default(), + variance: 10, + ciphertext_moduli: Default::default(), + ciphertext_moduli_sizes: Default::default(), + } + } - /// Sets the polynomial degree. Returns an error if the degree is not - /// a power of two larger or equal to 8. - pub fn set_degree(&mut self, degree: usize) -> &mut Self { - self.degree = degree; - self - } + /// Sets the polynomial degree. Returns an error if the degree is not + /// a power of two larger or equal to 8. + pub fn set_degree(&mut self, degree: usize) -> &mut Self { + self.degree = degree; + self + } - /// Sets the plaintext modulus. Returns an error if the plaintext is not - /// between 2 and 2^62 - 1. - pub fn set_plaintext_modulus(&mut self, plaintext: u64) -> &mut Self { - self.plaintext = plaintext; - self - } + /// Sets the plaintext modulus. Returns an error if the plaintext is not + /// between 2 and 2^62 - 1. + pub fn set_plaintext_modulus(&mut self, plaintext: u64) -> &mut Self { + self.plaintext = plaintext; + self + } - /// Sets the sizes of the ciphertext moduli. - /// Only one of `set_moduli_sizes` and `set_moduli` - /// can be specified. - pub fn set_moduli_sizes(&mut self, sizes: &[usize]) -> &mut Self { - self.ciphertext_moduli_sizes = sizes.to_owned(); - self - } + /// Sets the sizes of the ciphertext moduli. + /// Only one of `set_moduli_sizes` and `set_moduli` + /// can be specified. + pub fn set_moduli_sizes(&mut self, sizes: &[usize]) -> &mut Self { + self.ciphertext_moduli_sizes = sizes.to_owned(); + self + } - /// Sets the ciphertext moduli to use. - /// Only one of `set_moduli_sizes` and `set_moduli` - /// can be specified. - pub fn set_moduli(&mut self, moduli: &[u64]) -> &mut Self { - self.ciphertext_moduli = moduli.to_owned(); - self - } + /// Sets the ciphertext moduli to use. + /// Only one of `set_moduli_sizes` and `set_moduli` + /// can be specified. + pub fn set_moduli(&mut self, moduli: &[u64]) -> &mut Self { + self.ciphertext_moduli = moduli.to_owned(); + self + } - /// Sets the error variance. Returns an error if the variance is not between - /// one and sixteen. - pub fn set_variance(&mut self, variance: usize) -> &mut Self { - self.variance = variance; - self - } + /// Sets the error variance. Returns an error if the variance is not between + /// one and sixteen. + pub fn set_variance(&mut self, variance: usize) -> &mut Self { + self.variance = variance; + self + } - /// Generate ciphertext moduli with the specified sizes - fn generate_moduli(moduli_sizes: &[usize], degree: usize) -> Result<Vec<u64>> { - let mut moduli = vec![]; - for size in moduli_sizes { - if *size > 62 || *size < 10 { - return Err(Error::ParametersError(ParametersError::InvalidModulusSize( - *size, 10, 62, - ))); - } + /// Generate ciphertext moduli with the specified sizes + fn generate_moduli(moduli_sizes: &[usize], degree: usize) -> Result<Vec<u64>> { + let mut moduli = vec![]; + for size in moduli_sizes { + if *size > 62 || *size < 10 { + return Err(Error::ParametersError(ParametersError::InvalidModulusSize( + *size, 10, 62, + ))); + } - let mut upper_bound = 1 << size; - loop { - if let Some(prime) = generate_prime(*size, 2 * degree as u64, upper_bound) { - if !moduli.contains(&prime) { - moduli.push(prime); - break; - } else { - upper_bound = prime; - } - } else { - return Err(Error::ParametersError(ParametersError::NotEnoughPrimes( - *size, degree, - ))); - } - } - } + let mut upper_bound = 1 << size; + loop { + if let Some(prime) = generate_prime(*size, 2 * degree as u64, upper_bound) { + if !moduli.contains(&prime) { + moduli.push(prime); + break; + } else { + upper_bound = prime; + } + } else { + return Err(Error::ParametersError(ParametersError::NotEnoughPrimes( + *size, degree, + ))); + } + } + } - Ok(moduli) - } + Ok(moduli) + } - /// Build a new `BfvParameters`. - pub fn build(&self) -> Result<BfvParameters> { - // Check that the degree is a power of 2 (and large enough). - if self.degree < 8 || !self.degree.is_power_of_two() { - return Err(Error::ParametersError(ParametersError::InvalidDegree( - self.degree, - ))); - } + /// Build a new `BfvParameters`. + pub fn build(&self) -> Result<BfvParameters> { + // Check that the degree is a power of 2 (and large enough). + if self.degree < 8 || !self.degree.is_power_of_two() { + return Err(Error::ParametersError(ParametersError::InvalidDegree( + self.degree, + ))); + } - // This checks that the plaintext modulus is valid. - // TODO: Check bound on the plaintext modulus. - let plaintext_modulus = Modulus::new(self.plaintext).map_err(|e| { - Error::ParametersError(ParametersError::InvalidPlaintext(e.to_string())) - })?; + // This checks that the plaintext modulus is valid. + // TODO: Check bound on the plaintext modulus. + let plaintext_modulus = Modulus::new(self.plaintext).map_err(|e| { + Error::ParametersError(ParametersError::InvalidPlaintext(e.to_string())) + })?; - // Check that one of `ciphertext_moduli` and `ciphertext_moduli_sizes` is - // specified. - if !self.ciphertext_moduli.is_empty() && !self.ciphertext_moduli_sizes.is_empty() { - return Err(Error::ParametersError(ParametersError::TooManySpecified( - "Only one of `ciphertext_moduli` and `ciphertext_moduli_sizes` can be specified" - .to_string(), - ))); - } else if self.ciphertext_moduli.is_empty() && self.ciphertext_moduli_sizes.is_empty() { - return Err(Error::ParametersError(ParametersError::TooFewSpecified( - "One of `ciphertext_moduli` and `ciphertext_moduli_sizes` must be specified" - .to_string(), - ))); - } + // Check that one of `ciphertext_moduli` and `ciphertext_moduli_sizes` is + // specified. + if !self.ciphertext_moduli.is_empty() && !self.ciphertext_moduli_sizes.is_empty() { + return Err(Error::ParametersError(ParametersError::TooManySpecified( + "Only one of `ciphertext_moduli` and `ciphertext_moduli_sizes` can be specified" + .to_string(), + ))); + } else if self.ciphertext_moduli.is_empty() && self.ciphertext_moduli_sizes.is_empty() { + return Err(Error::ParametersError(ParametersError::TooFewSpecified( + "One of `ciphertext_moduli` and `ciphertext_moduli_sizes` must be specified" + .to_string(), + ))); + } - // Get or generate the moduli - let mut moduli = self.ciphertext_moduli.clone(); - if !self.ciphertext_moduli_sizes.is_empty() { - moduli = Self::generate_moduli(&self.ciphertext_moduli_sizes, self.degree)? - } + // Get or generate the moduli + let mut moduli = self.ciphertext_moduli.clone(); + if !self.ciphertext_moduli_sizes.is_empty() { + moduli = Self::generate_moduli(&self.ciphertext_moduli_sizes, self.degree)? + } - // Recomputes the moduli sizes - let moduli_sizes = moduli - .iter() - .map(|m| 64 - m.leading_zeros() as usize) - .collect_vec(); + // Recomputes the moduli sizes + let moduli_sizes = moduli + .iter() + .map(|m| 64 - m.leading_zeros() as usize) + .collect_vec(); - // Create n+1 moduli of 62 bits for multiplication. - let mut extended_basis = Vec::with_capacity(moduli.len() + 1); - let mut upper_bound = 1 << 62; - while extended_basis.len() != moduli.len() + 1 { - upper_bound = generate_prime(62, 2 * self.degree as u64, upper_bound).unwrap(); - if !extended_basis.contains(&upper_bound) && !moduli.contains(&upper_bound) { - extended_basis.push(upper_bound) - } - } + // Create n+1 moduli of 62 bits for multiplication. + let mut extended_basis = Vec::with_capacity(moduli.len() + 1); + let mut upper_bound = 1 << 62; + while extended_basis.len() != moduli.len() + 1 { + upper_bound = generate_prime(62, 2 * self.degree as u64, upper_bound).unwrap(); + if !extended_basis.contains(&upper_bound) && !moduli.contains(&upper_bound) { + extended_basis.push(upper_bound) + } + } - let op = NttOperator::new(&plaintext_modulus, self.degree); + let op = NttOperator::new(&plaintext_modulus, self.degree); - let plaintext_ctx = Arc::new(Context::new(&moduli[..1], self.degree)?); + let plaintext_ctx = Arc::new(Context::new(&moduli[..1], self.degree)?); - let mut delta_rests = vec![]; - for m in &moduli { - let q = Modulus::new(*m)?; - delta_rests.push(q.inv(q.neg(plaintext_modulus.modulus())).unwrap()) - } + let mut delta_rests = vec![]; + for m in &moduli { + let q = Modulus::new(*m)?; + delta_rests.push(q.inv(q.neg(plaintext_modulus.modulus())).unwrap()) + } - let mut ctx = Vec::with_capacity(moduli.len()); - let mut delta = Vec::with_capacity(moduli.len()); - let mut q_mod_t = Vec::with_capacity(moduli.len()); - let mut scalers = Vec::with_capacity(moduli.len()); - let mut mul_params = Vec::with_capacity(moduli.len()); - for i in 0..moduli.len() { - let rns = RnsContext::new(&moduli[..moduli.len() - i]).unwrap(); - let ctx_i = Arc::new(Context::new(&moduli[..moduli.len() - i], self.degree).unwrap()); - let mut p = Poly::try_convert_from( - &[rns.lift((&delta_rests).into())], - &ctx_i, - true, - Representation::PowerBasis, - )?; - p.change_representation(Representation::NttShoup); - delta.push(p); + let mut ctx = Vec::with_capacity(moduli.len()); + let mut delta = Vec::with_capacity(moduli.len()); + let mut q_mod_t = Vec::with_capacity(moduli.len()); + let mut scalers = Vec::with_capacity(moduli.len()); + let mut mul_params = Vec::with_capacity(moduli.len()); + for i in 0..moduli.len() { + let rns = RnsContext::new(&moduli[..moduli.len() - i]).unwrap(); + let ctx_i = Arc::new(Context::new(&moduli[..moduli.len() - i], self.degree).unwrap()); + let mut p = Poly::try_convert_from( + &[rns.lift((&delta_rests).into())], + &ctx_i, + true, + Representation::PowerBasis, + )?; + p.change_representation(Representation::NttShoup); + delta.push(p); - q_mod_t.push( - (rns.modulus() % plaintext_modulus.modulus()) - .to_u64() - .unwrap(), - ); + q_mod_t.push( + (rns.modulus() % plaintext_modulus.modulus()) + .to_u64() + .unwrap(), + ); - scalers.push(Scaler::new( - &ctx_i, - &plaintext_ctx, - ScalingFactor::new(&BigUint::from(plaintext_modulus.modulus()), rns.modulus()), - )?); + scalers.push(Scaler::new( + &ctx_i, + &plaintext_ctx, + ScalingFactor::new(&BigUint::from(plaintext_modulus.modulus()), rns.modulus()), + )?); - // For the first multiplication, we want to extend to a context that - // is ~60 bits larger. - let modulus_size = moduli_sizes[..moduli_sizes.len() - i].iter().sum::<usize>(); - let n_moduli = div_ceil(modulus_size + 60, 62); - let mut mul_1_moduli = vec![]; - mul_1_moduli.append(&mut moduli[..moduli_sizes.len() - i].to_vec()); - mul_1_moduli.append(&mut extended_basis[..n_moduli].to_vec()); - let mul_1_ctx = Arc::new(Context::new(&mul_1_moduli, self.degree)?); - mul_params.push(MultiplicationParameters::new( - &ctx_i, - &mul_1_ctx, - ScalingFactor::one(), - ScalingFactor::new(&BigUint::from(plaintext_modulus.modulus()), ctx_i.modulus()), - )?); + // For the first multiplication, we want to extend to a context that + // is ~60 bits larger. + let modulus_size = moduli_sizes[..moduli_sizes.len() - i].iter().sum::<usize>(); + let n_moduli = div_ceil(modulus_size + 60, 62); + let mut mul_1_moduli = vec![]; + mul_1_moduli.append(&mut moduli[..moduli_sizes.len() - i].to_vec()); + mul_1_moduli.append(&mut extended_basis[..n_moduli].to_vec()); + let mul_1_ctx = Arc::new(Context::new(&mul_1_moduli, self.degree)?); + mul_params.push(MultiplicationParameters::new( + &ctx_i, + &mul_1_ctx, + ScalingFactor::one(), + ScalingFactor::new(&BigUint::from(plaintext_modulus.modulus()), ctx_i.modulus()), + )?); - ctx.push(ctx_i); - } + ctx.push(ctx_i); + } - // We use the same code as SEAL - // https://github.com/microsoft/SEAL/blob/82b07db635132e297282649e2ab5908999089ad2/native/src/seal/batchencoder.cpp - let row_size = self.degree >> 1; - let m = self.degree << 1; - let gen = 3; - let mut pos = 1; - let mut matrix_reps_index_map = vec![0usize; self.degree]; - for i in 0..row_size { - let index1 = (pos - 1) >> 1; - let index2 = (m - pos - 1) >> 1; - matrix_reps_index_map[i] = index1.reverse_bits() >> (self.degree.leading_zeros() + 1); - matrix_reps_index_map[row_size | i] = - index2.reverse_bits() >> (self.degree.leading_zeros() + 1); - pos *= gen; - pos &= m - 1; - } + // We use the same code as SEAL + // https://github.com/microsoft/SEAL/blob/82b07db635132e297282649e2ab5908999089ad2/native/src/seal/batchencoder.cpp + let row_size = self.degree >> 1; + let m = self.degree << 1; + let gen = 3; + let mut pos = 1; + let mut matrix_reps_index_map = vec![0usize; self.degree]; + for i in 0..row_size { + let index1 = (pos - 1) >> 1; + let index2 = (m - pos - 1) >> 1; + matrix_reps_index_map[i] = index1.reverse_bits() >> (self.degree.leading_zeros() + 1); + matrix_reps_index_map[row_size | i] = + index2.reverse_bits() >> (self.degree.leading_zeros() + 1); + pos *= gen; + pos &= m - 1; + } - Ok(BfvParameters { - polynomial_degree: self.degree, - plaintext_modulus: self.plaintext, - moduli: moduli.into_boxed_slice(), - moduli_sizes: moduli_sizes.into_boxed_slice(), - variance: self.variance, - ctx, - op: op.map(Arc::new), - delta: delta.into_boxed_slice(), - q_mod_t: q_mod_t.into_boxed_slice(), - scalers: scalers.into_boxed_slice(), - plaintext: plaintext_modulus, - mul_params: mul_params.into_boxed_slice(), - matrix_reps_index_map: matrix_reps_index_map.into_boxed_slice(), - }) - } + Ok(BfvParameters { + polynomial_degree: self.degree, + plaintext_modulus: self.plaintext, + moduli: moduli.into_boxed_slice(), + moduli_sizes: moduli_sizes.into_boxed_slice(), + variance: self.variance, + ctx, + op: op.map(Arc::new), + delta: delta.into_boxed_slice(), + q_mod_t: q_mod_t.into_boxed_slice(), + scalers: scalers.into_boxed_slice(), + plaintext: plaintext_modulus, + mul_params: mul_params.into_boxed_slice(), + matrix_reps_index_map: matrix_reps_index_map.into_boxed_slice(), + }) + } } impl Serialize for BfvParameters { - fn to_bytes(&self) -> Vec<u8> { - let mut params = Parameters::new(); - params.degree = self.polynomial_degree as u32; - params.plaintext = self.plaintext_modulus; - params.moduli = self.moduli.to_vec(); - params.variance = self.variance as u32; - params.write_to_bytes().unwrap() - } + fn to_bytes(&self) -> Vec<u8> { + let mut params = Parameters::new(); + params.degree = self.polynomial_degree as u32; + params.plaintext = self.plaintext_modulus; + params.moduli = self.moduli.to_vec(); + params.variance = self.variance as u32; + params.write_to_bytes().unwrap() + } } impl Deserialize for BfvParameters { - fn try_deserialize(bytes: &[u8]) -> Result<Self> { - if let Ok(params) = Parameters::parse_from_bytes(bytes) { - BfvParametersBuilder::new() - .set_degree(params.degree as usize) - .set_plaintext_modulus(params.plaintext) - .set_moduli(&params.moduli) - .set_variance(params.variance as usize) - .build() - } else { - Err(Error::SerializationError) - } - } - type Error = Error; + fn try_deserialize(bytes: &[u8]) -> Result<Self> { + if let Ok(params) = Parameters::parse_from_bytes(bytes) { + BfvParametersBuilder::new() + .set_degree(params.degree as usize) + .set_plaintext_modulus(params.plaintext) + .set_moduli(&params.moduli) + .set_variance(params.variance as usize) + .build() + } else { + Err(Error::SerializationError) + } + } + type Error = Error; } /// Multiplication parameters #[derive(Debug, PartialEq, Eq, Default)] pub(crate) struct MultiplicationParameters { - pub(crate) extender: Scaler, - pub(crate) down_scaler: Scaler, - pub(crate) from: Arc<Context>, - pub(crate) to: Arc<Context>, + pub(crate) extender: Scaler, + pub(crate) down_scaler: Scaler, + pub(crate) from: Arc<Context>, + pub(crate) to: Arc<Context>, } impl MultiplicationParameters { - fn new( - from: &Arc<Context>, - to: &Arc<Context>, - up_self_factor: ScalingFactor, - down_factor: ScalingFactor, - ) -> Result<Self> { - Ok(Self { - extender: Scaler::new(from, to, up_self_factor)?, - down_scaler: Scaler::new(to, from, down_factor)?, - from: from.clone(), - to: to.clone(), - }) - } + fn new( + from: &Arc<Context>, + to: &Arc<Context>, + up_self_factor: ScalingFactor, + down_factor: ScalingFactor, + ) -> Result<Self> { + Ok(Self { + extender: Scaler::new(from, to, up_self_factor)?, + down_scaler: Scaler::new(to, from, down_factor)?, + from: from.clone(), + to: to.clone(), + }) + } } #[cfg(test)] mod tests { - use super::{BfvParameters, BfvParametersBuilder}; - use fhe_traits::{Deserialize, Serialize}; - use std::error::Error; + use super::{BfvParameters, BfvParametersBuilder}; + use fhe_traits::{Deserialize, Serialize}; + use std::error::Error; - // TODO: To fix when errors handling is fixed. - // #[test] - // fn builder() -> Result<(), Box<dyn Error>> { - // let params = BfvParametersBuilder::new().build(); - // assert!(params.is_err_and(|e| e.to_string() == "Unspecified degree")); + // TODO: To fix when errors handling is fixed. + // #[test] + // fn builder() -> Result<(), Box<dyn Error>> { + // let params = BfvParametersBuilder::new().build(); + // assert!(params.is_err_and(|e| e.to_string() == "Unspecified degree")); - // assert!(BfvParametersBuilder::new() - // .set_degree(7) - // .build() - // .is_err_and( - // |e| e.to_string() == "The degree should be a power of two larger or equal to - // 8" )); + // assert!(BfvParametersBuilder::new() + // .set_degree(7) + // .build() + // .is_err_and( + // |e| e.to_string() == "The degree should be a power of two larger or equal to + // 8" )); - // assert!(BfvParametersBuilder::new() - // .set_degree(1023) - // .build() - // .is_err_and( - // |e| e.to_string() == "The degree should be a power of two larger or equal to - // 8" )); + // assert!(BfvParametersBuilder::new() + // .set_degree(1023) + // .build() + // .is_err_and( + // |e| e.to_string() == "The degree should be a power of two larger or equal to + // 8" )); - // let params = BfvParametersBuilder::new().set_degree(1024).build(); - // assert!(params.is_err_and(|e| e.to_string() == "Unspecified plaintext - // modulus")); + // let params = BfvParametersBuilder::new().set_degree(1024).build(); + // assert!(params.is_err_and(|e| e.to_string() == "Unspecified plaintext + // modulus")); - // assert!(BfvParametersBuilder::new() - // .set_degree(1024) - // .set_plaintext_modulus(0) - // .build() - // .is_err_and(|e| e.to_string() == "modulus should be between 2 and 2^62-1")); + // assert!(BfvParametersBuilder::new() + // .set_degree(1024) + // .set_plaintext_modulus(0) + // .build() + // .is_err_and(|e| e.to_string() == "modulus should be between 2 and 2^62-1")); - // let params = BfvParametersBuilder::new() - // .set_degree(1024) - // .set_plaintext_modulus(2) - // .build(); - // assert!(params.is_err_and(|e| e.to_string() == "Unspecified ciphertext - // moduli")); + // let params = BfvParametersBuilder::new() + // .set_degree(1024) + // .set_plaintext_modulus(2) + // .build(); + // assert!(params.is_err_and(|e| e.to_string() == "Unspecified ciphertext + // moduli")); - // assert!(BfvParametersBuilder::new() - // .set_degree(1024) - // .set_plaintext_modulus(2) - // .set_moduli(&[]) - // .build() - // .is_err_and(|e| e.to_string() == "Unspecified ciphertext moduli")); + // assert!(BfvParametersBuilder::new() + // .set_degree(1024) + // .set_plaintext_modulus(2) + // .set_moduli(&[]) + // .build() + // .is_err_and(|e| e.to_string() == "Unspecified ciphertext moduli")); - // assert!(BfvParametersBuilder::new() - // .set_degree(1024) - // .set_plaintext_modulus(2) - // .set_moduli(&[1153]) - // .set_moduli_sizes(&[62]) - // .build() - // .is_err_and(|e| e.to_string() == "The set of ciphertext moduli is already - // specified")); + // assert!(BfvParametersBuilder::new() + // .set_degree(1024) + // .set_plaintext_modulus(2) + // .set_moduli(&[1153]) + // .set_moduli_sizes(&[62]) + // .build() + // .is_err_and(|e| e.to_string() == "The set of ciphertext moduli is already + // specified")); - // assert!(BfvParametersBuilder::new() - // .set_degree(8) - // .set_plaintext_modulus(2) - // .set_moduli(&[1]) - // .build() - // .is_err_and(|e| e.to_string() == "modulus should be between 2 and 2^62-1")); + // assert!(BfvParametersBuilder::new() + // .set_degree(8) + // .set_plaintext_modulus(2) + // .set_moduli(&[1]) + // .build() + // .is_err_and(|e| e.to_string() == "modulus should be between 2 and 2^62-1")); - // let params = BfvParametersBuilder::new() - // .set_degree(8) - // .set_plaintext_modulus(2) - // .set_moduli(&[2]) - // .build(); - // assert!(params.is_err_and(|e| e.to_string() == "Impossible to construct a Ntt - // operator")); + // let params = BfvParametersBuilder::new() + // .set_degree(8) + // .set_plaintext_modulus(2) + // .set_moduli(&[2]) + // .build(); + // assert!(params.is_err_and(|e| e.to_string() == "Impossible to construct a Ntt + // operator")); - // let params = BfvParametersBuilder::new() - // .set_degree(8) - // .set_plaintext_modulus(2) - // .set_moduli(&[1153]) - // .build(); - // assert!(params.is_ok()); + // let params = BfvParametersBuilder::new() + // .set_degree(8) + // .set_plaintext_modulus(2) + // .set_moduli(&[1153]) + // .build(); + // assert!(params.is_ok()); - // let params = params.unwrap(); - // assert_eq!(params.ciphertext_moduli, vec![1153]); - // assert_eq!(params.moduli(), vec![1153]); - // assert_eq!(params.plaintext_modulus, 2); - // assert_eq!(params.polynomial_degree, 8); - // assert_eq!(params.degree(), 8); - // assert_eq!(params.variance, 1); - // assert!(params.op.is_none()); + // let params = params.unwrap(); + // assert_eq!(params.ciphertext_moduli, vec![1153]); + // assert_eq!(params.moduli(), vec![1153]); + // assert_eq!(params.plaintext_modulus, 2); + // assert_eq!(params.polynomial_degree, 8); + // assert_eq!(params.degree(), 8); + // assert_eq!(params.variance, 1); + // assert!(params.op.is_none()); - // Ok(()) - // } + // Ok(()) + // } - #[test] - fn default() { - let params = BfvParameters::default(1, 8); - assert_eq!(params.moduli.len(), 1); - assert_eq!(params.degree(), 8); + #[test] + fn default() { + let params = BfvParameters::default(1, 8); + assert_eq!(params.moduli.len(), 1); + assert_eq!(params.degree(), 8); - let params = BfvParameters::default(2, 16); - assert_eq!(params.moduli.len(), 2); - assert_eq!(params.degree(), 16); - } + let params = BfvParameters::default(2, 16); + assert_eq!(params.moduli.len(), 2); + assert_eq!(params.degree(), 16); + } - #[test] - fn ciphertext_moduli() -> Result<(), Box<dyn Error>> { - let params = BfvParametersBuilder::new() - .set_degree(8) - .set_plaintext_modulus(2) - .set_moduli_sizes(&[62, 62, 62, 61, 60, 11]) - .build()?; - assert_eq!( - params.moduli.to_vec(), - &[ - 4611686018427387761, - 4611686018427387617, - 4611686018427387409, - 2305843009213693921, - 1152921504606846577, - 2017 - ] - ); + #[test] + fn ciphertext_moduli() -> Result<(), Box<dyn Error>> { + let params = BfvParametersBuilder::new() + .set_degree(8) + .set_plaintext_modulus(2) + .set_moduli_sizes(&[62, 62, 62, 61, 60, 11]) + .build()?; + assert_eq!( + params.moduli.to_vec(), + &[ + 4611686018427387761, + 4611686018427387617, + 4611686018427387409, + 2305843009213693921, + 1152921504606846577, + 2017 + ] + ); - let params = BfvParametersBuilder::new() - .set_degree(8) - .set_plaintext_modulus(2) - .set_moduli(&[ - 4611686018427387761, - 4611686018427387617, - 4611686018427387409, - 2305843009213693921, - 1152921504606846577, - 2017, - ]) - .build()?; - assert_eq!(params.moduli_sizes.to_vec(), &[62, 62, 62, 61, 60, 11]); + let params = BfvParametersBuilder::new() + .set_degree(8) + .set_plaintext_modulus(2) + .set_moduli(&[ + 4611686018427387761, + 4611686018427387617, + 4611686018427387409, + 2305843009213693921, + 1152921504606846577, + 2017, + ]) + .build()?; + assert_eq!(params.moduli_sizes.to_vec(), &[62, 62, 62, 61, 60, 11]); - Ok(()) - } + Ok(()) + } - #[test] - fn serialize() -> Result<(), Box<dyn Error>> { - let params = BfvParametersBuilder::new() - .set_degree(8) - .set_plaintext_modulus(2) - .set_moduli_sizes(&[62, 62, 62, 61, 60, 11]) - .set_variance(4) - .build()?; - let bytes = params.to_bytes(); - assert_eq!(BfvParameters::try_deserialize(&bytes)?, params); - Ok(()) - } + #[test] + fn serialize() -> Result<(), Box<dyn Error>> { + let params = BfvParametersBuilder::new() + .set_degree(8) + .set_plaintext_modulus(2) + .set_moduli_sizes(&[62, 62, 62, 61, 60, 11]) + .set_variance(4) + .build()?; + let bytes = params.to_bytes(); + assert_eq!(BfvParameters::try_deserialize(&bytes)?, params); + Ok(()) + } } diff --git a/crates/fhe/src/bfv/plaintext.rs b/crates/fhe/src/bfv/plaintext.rs index a0599da..c661d22 100644 --- a/crates/fhe/src/bfv/plaintext.rs +++ b/crates/fhe/src/bfv/plaintext.rs @@ -1,7 +1,7 @@ //! Plaintext type in the BFV encryption scheme. use crate::{ - bfv::{BfvParameters, Encoding, PlaintextVec}, - Error, Result, + bfv::{BfvParameters, Encoding, PlaintextVec}, + Error, Result, }; use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation}; use fhe_traits::{FheDecoder, FheEncoder, FheParametrized, FhePlaintext}; @@ -13,69 +13,69 @@ use super::encoding::EncodingEnum; /// A plaintext object, that encodes a vector according to a specific encoding. #[derive(Debug, Clone, Eq)] pub struct Plaintext { - /// The parameters of the underlying BFV encryption scheme. - pub(crate) par: Arc<BfvParameters>, - /// The value after encoding. - pub(crate) value: Box<[u64]>, - /// The encoding of the plaintext, if known - pub(crate) encoding: Option<Encoding>, - /// The plaintext as a polynomial. - pub(crate) poly_ntt: Poly, - /// The level of the plaintext - pub(crate) level: usize, + /// The parameters of the underlying BFV encryption scheme. + pub(crate) par: Arc<BfvParameters>, + /// The value after encoding. + pub(crate) value: Box<[u64]>, + /// The encoding of the plaintext, if known + pub(crate) encoding: Option<Encoding>, + /// The plaintext as a polynomial. + pub(crate) poly_ntt: Poly, + /// The level of the plaintext + pub(crate) level: usize, } impl FheParametrized for Plaintext { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl FhePlaintext for Plaintext { - type Encoding = Encoding; + type Encoding = Encoding; } // Zeroizing of plaintexts. impl ZeroizeOnDrop for Plaintext {} impl Zeroize for Plaintext { - fn zeroize(&mut self) { - self.value.zeroize(); - self.poly_ntt.zeroize(); - } + fn zeroize(&mut self) { + self.value.zeroize(); + self.poly_ntt.zeroize(); + } } impl Plaintext { - pub(crate) fn to_poly(&self) -> Poly { - let mut m_v = Zeroizing::new(self.value.clone()); - self.par - .plaintext - .scalar_mul_vec(&mut m_v, self.par.q_mod_t[self.level]); - let ctx = self.par.ctx_at_level(self.level).unwrap(); - let mut m = - Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap(); - m.change_representation(Representation::Ntt); - m *= &self.par.delta[self.level]; - m - } + pub(crate) fn to_poly(&self) -> Poly { + let mut m_v = Zeroizing::new(self.value.clone()); + self.par + .plaintext + .scalar_mul_vec(&mut m_v, self.par.q_mod_t[self.level]); + let ctx = self.par.ctx_at_level(self.level).unwrap(); + let mut m = + Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap(); + m.change_representation(Representation::Ntt); + m *= &self.par.delta[self.level]; + m + } - /// Generate a zero plaintext. - pub fn zero(encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { - let level = encoding.level; - let ctx = par.ctx_at_level(level)?; - let value = vec![0u64; par.degree()]; - let poly_ntt = Poly::zero(ctx, Representation::Ntt); - Ok(Self { - par: par.clone(), - value: value.into_boxed_slice(), - encoding: Some(encoding), - poly_ntt, - level, - }) - } + /// Generate a zero plaintext. + pub fn zero(encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + let level = encoding.level; + let ctx = par.ctx_at_level(level)?; + let value = vec![0u64; par.degree()]; + let poly_ntt = Poly::zero(ctx, Representation::Ntt); + Ok(Self { + par: par.clone(), + value: value.into_boxed_slice(), + encoding: Some(encoding), + poly_ntt, + level, + }) + } - /// Returns the level of this plaintext. - pub fn level(&self) -> usize { - self.par.level_of_ctx(self.poly_ntt.ctx()).unwrap() - } + /// Returns the level of this plaintext. + pub fn level(&self) -> usize { + self.par.level_of_ctx(self.poly_ntt.ctx()).unwrap() + } } unsafe impl Send for Plaintext {} @@ -83,320 +83,320 @@ unsafe impl Send for Plaintext {} // Implement the equality manually; we want to say that two plaintexts are equal // even if one of them doesn't store its encoding information. impl PartialEq for Plaintext { - fn eq(&self, other: &Self) -> bool { - let mut eq = self.par == other.par; - eq &= self.value == other.value; - if self.encoding.is_some() && other.encoding.is_some() { - eq &= self.encoding.as_ref().unwrap() == other.encoding.as_ref().unwrap() - } - eq - } + fn eq(&self, other: &Self) -> bool { + let mut eq = self.par == other.par; + eq &= self.value == other.value; + if self.encoding.is_some() && other.encoding.is_some() { + eq &= self.encoding.as_ref().unwrap() == other.encoding.as_ref().unwrap() + } + eq + } } // Conversions. impl TryConvertFrom<&Plaintext> for Poly { - fn try_convert_from<R>( - pt: &Plaintext, - ctx: &Arc<Context>, - variable_time: bool, - _: R, - ) -> fhe_math::Result<Self> - where - R: Into<Option<Representation>>, - { - if ctx - != pt - .par - .ctx_at_level(pt.level()) - .map_err(|e| fhe_math::Error::Default(e.to_string()))? - { - Err(fhe_math::Error::Default( - "Incompatible contexts".to_string(), - )) - } else { - Poly::try_convert_from( - pt.value.as_ref(), - ctx, - variable_time, - Representation::PowerBasis, - ) - } - } + fn try_convert_from<R>( + pt: &Plaintext, + ctx: &Arc<Context>, + variable_time: bool, + _: R, + ) -> fhe_math::Result<Self> + where + R: Into<Option<Representation>>, + { + if ctx + != pt + .par + .ctx_at_level(pt.level()) + .map_err(|e| fhe_math::Error::Default(e.to_string()))? + { + Err(fhe_math::Error::Default( + "Incompatible contexts".to_string(), + )) + } else { + Poly::try_convert_from( + pt.value.as_ref(), + ctx, + variable_time, + Representation::PowerBasis, + ) + } + } } // Encoding and decoding. impl<'a, const N: usize, T> FheEncoder<&'a [T; N]> for Plaintext where - Plaintext: FheEncoder<&'a [T], Error = Error>, + Plaintext: FheEncoder<&'a [T], Error = Error>, { - type Error = Error; - fn try_encode(value: &'a [T; N], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { - Plaintext::try_encode(value.as_ref(), encoding, par) - } + type Error = Error; + fn try_encode(value: &'a [T; N], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + Plaintext::try_encode(value.as_ref(), encoding, par) + } } impl<'a, T> FheEncoder<&'a Vec<T>> for Plaintext where - Plaintext: FheEncoder<&'a [T], Error = Error>, + Plaintext: FheEncoder<&'a [T], Error = Error>, { - type Error = Error; - fn try_encode(value: &'a Vec<T>, encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { - Plaintext::try_encode(value.as_ref(), encoding, par) - } + type Error = Error; + fn try_encode(value: &'a Vec<T>, encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + Plaintext::try_encode(value.as_ref(), encoding, par) + } } impl<'a> FheEncoder<&'a [u64]> for Plaintext { - type Error = Error; - fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { - if value.len() > par.degree() { - return Err(Error::TooManyValues(value.len(), par.degree())); - } - let v = PlaintextVec::try_encode(value, encoding, par)?; - Ok(v.0[0].clone()) - } + type Error = Error; + fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + if value.len() > par.degree() { + return Err(Error::TooManyValues(value.len(), par.degree())); + } + let v = PlaintextVec::try_encode(value, encoding, par)?; + Ok(v.0[0].clone()) + } } impl<'a> FheEncoder<&'a [i64]> for Plaintext { - type Error = Error; - fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { - let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value)); - Plaintext::try_encode(w.as_ref() as &[u64], encoding, par) - } + type Error = Error; + fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value)); + Plaintext::try_encode(w.as_ref() as &[u64], encoding, par) + } } impl FheDecoder<Plaintext> for Vec<u64> { - fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>> - where - O: Into<Option<Encoding>>, - { - let encoding = encoding.into(); - let enc: Encoding; - if pt.encoding.is_none() && encoding.is_none() { - return Err(Error::UnspecifiedInput("No encoding specified".to_string())); - } else if pt.encoding.is_some() { - enc = pt.encoding.as_ref().unwrap().clone(); - if let Some(arg_enc) = encoding { - if arg_enc != enc { - return Err(Error::EncodingMismatch(arg_enc.into(), enc.into())); - } - } - } else { - enc = encoding.unwrap(); - if let Some(pt_enc) = pt.encoding.as_ref() { - if pt_enc != &enc { - return Err(Error::EncodingMismatch(pt_enc.into(), enc.into())); - } - } - } + fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>> + where + O: Into<Option<Encoding>>, + { + let encoding = encoding.into(); + let enc: Encoding; + if pt.encoding.is_none() && encoding.is_none() { + return Err(Error::UnspecifiedInput("No encoding specified".to_string())); + } else if pt.encoding.is_some() { + enc = pt.encoding.as_ref().unwrap().clone(); + if let Some(arg_enc) = encoding { + if arg_enc != enc { + return Err(Error::EncodingMismatch(arg_enc.into(), enc.into())); + } + } + } else { + enc = encoding.unwrap(); + if let Some(pt_enc) = pt.encoding.as_ref() { + if pt_enc != &enc { + return Err(Error::EncodingMismatch(pt_enc.into(), enc.into())); + } + } + } - let mut w = pt.value.to_vec(); + let mut w = pt.value.to_vec(); - match enc.encoding { - EncodingEnum::Poly => Ok(w), - EncodingEnum::Simd => { - if let Some(op) = &pt.par.op { - op.forward(&mut w); - let mut w_reordered = w.clone(); - for i in 0..pt.par.degree() { - w_reordered[i] = w[pt.par.matrix_reps_index_map[i]] - } - w.zeroize(); - Ok(w_reordered) - } else { - Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string())) - } - } - } - } + match enc.encoding { + EncodingEnum::Poly => Ok(w), + EncodingEnum::Simd => { + if let Some(op) = &pt.par.op { + op.forward(&mut w); + let mut w_reordered = w.clone(); + for i in 0..pt.par.degree() { + w_reordered[i] = w[pt.par.matrix_reps_index_map[i]] + } + w.zeroize(); + Ok(w_reordered) + } else { + Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string())) + } + } + } + } - type Error = Error; + type Error = Error; } impl FheDecoder<Plaintext> for Vec<i64> { - fn try_decode<E>(pt: &Plaintext, encoding: E) -> Result<Vec<i64>> - where - E: Into<Option<Encoding>>, - { - let v = Vec::<u64>::try_decode(pt, encoding)?; - Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) }) - } + fn try_decode<E>(pt: &Plaintext, encoding: E) -> Result<Vec<i64>> + where + E: Into<Option<Encoding>>, + { + let v = Vec::<u64>::try_decode(pt, encoding)?; + Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) }) + } - type Error = Error; + type Error = Error; } #[cfg(test)] mod tests { - use super::{Encoding, Plaintext}; - use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder}; - use fhe_math::rq::{Poly, Representation}; - use fhe_traits::{FheDecoder, FheEncoder}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; - use zeroize::Zeroize; + use super::{Encoding, Plaintext}; + use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder}; + use fhe_math::rq::{Poly, Representation}; + use fhe_traits::{FheDecoder, FheEncoder}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; + use zeroize::Zeroize; - #[test] - fn try_encode() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - // The default test parameters support both Poly and Simd encodings - let params = Arc::new(BfvParameters::default(1, 8)); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + #[test] + fn try_encode() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + // The default test parameters support both Poly and Simd encodings + let params = Arc::new(BfvParameters::default(1, 8)); + let a = params.plaintext.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), &params); - assert!(plaintext.is_err()); + let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), &params); + assert!(plaintext.is_err()); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); - assert!(plaintext.is_ok()); + let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); + assert!(plaintext.is_ok()); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); - assert!(plaintext.is_ok()); + let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + assert!(plaintext.is_ok()); - let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), &params); - assert!(plaintext.is_ok()); + let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), &params); + assert!(plaintext.is_ok()); - // The following parameters do not allow for Simd encoding - let params = Arc::new( - BfvParametersBuilder::new() - .set_degree(8) - .set_plaintext_modulus(2) - .set_moduli(&[4611686018326724609]) - .build()?, - ); + // The following parameters do not allow for Simd encoding + let params = Arc::new( + BfvParametersBuilder::new() + .set_degree(8) + .set_plaintext_modulus(2) + .set_moduli(&[4611686018326724609]) + .build()?, + ); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + let a = params.plaintext.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); - assert!(plaintext.is_ok()); + let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); + assert!(plaintext.is_ok()); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); - assert!(plaintext.is_err()); + let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + assert!(plaintext.is_err()); - Ok(()) - } + Ok(()) + } - #[test] - fn encode_decode() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(1, 8)); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + #[test] + fn encode_decode() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(1, 8)); + let a = params.plaintext.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); - assert!(plaintext.is_ok()); - let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?; - assert_eq!(b, a); + let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + assert!(plaintext.is_ok()); + let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?; + assert_eq!(b, a); - let a = unsafe { params.plaintext.center_vec_vt(&a) }; - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); - assert!(plaintext.is_ok()); - let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?; - assert_eq!(b, a); + let a = unsafe { params.plaintext.center_vec_vt(&a) }; + let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); + assert!(plaintext.is_ok()); + let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?; + assert_eq!(b, a); - let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); - assert!(plaintext.is_ok()); - let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?; - assert_eq!(b, a); + let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params); + assert!(plaintext.is_ok()); + let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?; + assert_eq!(b, a); - Ok(()) - } + Ok(()) + } - #[test] - fn partial_eq() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(1, 8)); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + #[test] + fn partial_eq() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(1, 8)); + let a = params.plaintext.random_vec(params.degree(), &mut rng); - let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; - let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; - assert_eq!(plaintext, same_plaintext); + let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + assert_eq!(plaintext, same_plaintext); - // Equality also holds when there is no encoding specified. In this test, we use - // the fact that we can set it to None directly, but such a partial plaintext - // will be created during decryption since we do not specify the encoding at the - // time. - same_plaintext.encoding = None; - assert_eq!(plaintext, same_plaintext); + // Equality also holds when there is no encoding specified. In this test, we use + // the fact that we can set it to None directly, but such a partial plaintext + // will be created during decryption since we do not specify the encoding at the + // time. + same_plaintext.encoding = None; + assert_eq!(plaintext, same_plaintext); - Ok(()) - } + Ok(()) + } - #[test] - fn try_decode_errors() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(1, 8)); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + #[test] + fn try_decode_errors() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(1, 8)); + let a = params.plaintext.random_vec(params.degree(), &mut rng); - let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; - assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok()); - let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd()); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::EncodingMismatch(Encoding::simd().into(), Encoding::poly().into()) - ); - let e = Vec::<u64>::try_decode(&plaintext, Encoding::poly_at_level(1)); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::EncodingMismatch( - Encoding::poly_at_level(1).into(), - Encoding::poly().into() - ) - ); + assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok()); + let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd()); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::EncodingMismatch(Encoding::simd().into(), Encoding::poly().into()) + ); + let e = Vec::<u64>::try_decode(&plaintext, Encoding::poly_at_level(1)); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::EncodingMismatch( + Encoding::poly_at_level(1).into(), + Encoding::poly().into() + ) + ); - plaintext.encoding = None; - let e = Vec::<u64>::try_decode(&plaintext, None); - assert!(e.is_err()); - assert_eq!( - e.unwrap_err(), - crate::Error::UnspecifiedInput("No encoding specified".to_string()) - ); + plaintext.encoding = None; + let e = Vec::<u64>::try_decode(&plaintext, None); + assert!(e.is_err()); + assert_eq!( + e.unwrap_err(), + crate::Error::UnspecifiedInput("No encoding specified".to_string()) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn zero() -> Result<(), Box<dyn Error>> { - let params = Arc::new(BfvParameters::default(1, 8)); - let plaintext = Plaintext::zero(Encoding::poly(), &params)?; + #[test] + fn zero() -> Result<(), Box<dyn Error>> { + let params = Arc::new(BfvParameters::default(1, 8)); + let plaintext = Plaintext::zero(Encoding::poly(), &params)?; - assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8])); - assert_eq!( - plaintext.poly_ntt, - Poly::zero(&params.ctx[0], Representation::Ntt) - ); + assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8])); + assert_eq!( + plaintext.poly_ntt, + Poly::zero(&params.ctx[0], Representation::Ntt) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn zeroize() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - let params = Arc::new(BfvParameters::default(1, 8)); - let a = params.plaintext.random_vec(params.degree(), &mut rng); - let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; + #[test] + fn zeroize() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + let params = Arc::new(BfvParameters::default(1, 8)); + let a = params.plaintext.random_vec(params.degree(), &mut rng); + let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?; - plaintext.zeroize(); + plaintext.zeroize(); - assert_eq!(plaintext, Plaintext::zero(Encoding::poly(), &params)?); + assert_eq!(plaintext, Plaintext::zero(Encoding::poly(), &params)?); - Ok(()) - } + Ok(()) + } - #[test] - fn try_encode_level() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - // The default test parameters support both Poly and Simd encodings - let params = Arc::new(BfvParameters::default(10, 8)); - let a = params.plaintext.random_vec(params.degree(), &mut rng); + #[test] + fn try_encode_level() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + // The default test parameters support both Poly and Simd encodings + let params = Arc::new(BfvParameters::default(10, 8)); + let a = params.plaintext.random_vec(params.degree(), &mut rng); - for level in 0..10 { - let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), &params)?; - assert_eq!(plaintext.level(), level); - let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), &params)?; - assert_eq!(plaintext.level(), level); - } + for level in 0..10 { + let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), &params)?; + assert_eq!(plaintext.level(), level); + let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), &params)?; + assert_eq!(plaintext.level(), level); + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/plaintext_vec.rs b/crates/fhe/src/bfv/plaintext_vec.rs index 4e25fbf..bc07da4 100644 --- a/crates/fhe/src/bfv/plaintext_vec.rs +++ b/crates/fhe/src/bfv/plaintext_vec.rs @@ -6,8 +6,8 @@ use fhe_util::div_ceil; use zeroize::{Zeroize, ZeroizeOnDrop}; use crate::{ - bfv::{BfvParameters, Encoding, Plaintext}, - Error, Result, + bfv::{BfvParameters, Encoding, Plaintext}, + Error, Result, }; use super::encoding::EncodingEnum; @@ -17,148 +17,148 @@ use super::encoding::EncodingEnum; pub struct PlaintextVec(pub Vec<Plaintext>); impl FhePlaintext for PlaintextVec { - type Encoding = Encoding; + type Encoding = Encoding; } impl FheParametrized for PlaintextVec { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl Zeroize for PlaintextVec { - fn zeroize(&mut self) { - self.0.zeroize() - } + fn zeroize(&mut self) { + self.0.zeroize() + } } impl ZeroizeOnDrop for PlaintextVec {} impl FheEncoderVariableTime<&[u64]> for PlaintextVec { - type Error = Error; + type Error = Error; - unsafe fn try_encode_vt( - value: &[u64], - encoding: Encoding, - par: &Arc<BfvParameters>, - ) -> Result<Self> { - if value.is_empty() { - return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?])); - } - if encoding.encoding == EncodingEnum::Simd && par.op.is_none() { - return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string())); - } - let ctx = par.ctx_at_level(encoding.level)?; - let num_plaintexts = div_ceil(value.len(), par.degree()); + unsafe fn try_encode_vt( + value: &[u64], + encoding: Encoding, + par: &Arc<BfvParameters>, + ) -> Result<Self> { + if value.is_empty() { + return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?])); + } + if encoding.encoding == EncodingEnum::Simd && par.op.is_none() { + return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string())); + } + let ctx = par.ctx_at_level(encoding.level)?; + let num_plaintexts = div_ceil(value.len(), par.degree()); - Ok(PlaintextVec( - (0..num_plaintexts) - .map(|i| { - let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())]; - let mut v = vec![0u64; par.degree()]; - match encoding.encoding { - EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice), - EncodingEnum::Simd => { - for i in 0..slice.len() { - v[par.matrix_reps_index_map[i]] = slice[i]; - } - par.op.as_ref().unwrap().backward_vt(v.as_mut_ptr()); - } - }; + Ok(PlaintextVec( + (0..num_plaintexts) + .map(|i| { + let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())]; + let mut v = vec![0u64; par.degree()]; + match encoding.encoding { + EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice), + EncodingEnum::Simd => { + for i in 0..slice.len() { + v[par.matrix_reps_index_map[i]] = slice[i]; + } + par.op.as_ref().unwrap().backward_vt(v.as_mut_ptr()); + } + }; - let mut poly = - Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?; - poly.change_representation(Representation::Ntt); + let mut poly = + Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?; + poly.change_representation(Representation::Ntt); - Ok(Plaintext { - par: par.clone(), - value: v.into_boxed_slice(), - encoding: Some(encoding.clone()), - poly_ntt: poly, - level: encoding.level, - }) - }) - .collect::<Result<Vec<Plaintext>>>()?, - )) - } + Ok(Plaintext { + par: par.clone(), + value: v.into_boxed_slice(), + encoding: Some(encoding.clone()), + poly_ntt: poly, + level: encoding.level, + }) + }) + .collect::<Result<Vec<Plaintext>>>()?, + )) + } } impl FheEncoder<&[u64]> for PlaintextVec { - type Error = Error; - fn try_encode(value: &[u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { - if value.is_empty() { - return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?])); - } - if encoding.encoding == EncodingEnum::Simd && par.op.is_none() { - return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string())); - } - let ctx = par.ctx_at_level(encoding.level)?; - let num_plaintexts = div_ceil(value.len(), par.degree()); + type Error = Error; + fn try_encode(value: &[u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> { + if value.is_empty() { + return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?])); + } + if encoding.encoding == EncodingEnum::Simd && par.op.is_none() { + return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string())); + } + let ctx = par.ctx_at_level(encoding.level)?; + let num_plaintexts = div_ceil(value.len(), par.degree()); - Ok(PlaintextVec( - (0..num_plaintexts) - .map(|i| { - let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())]; - let mut v = vec![0u64; par.degree()]; - match encoding.encoding { - EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice), - EncodingEnum::Simd => { - for i in 0..slice.len() { - v[par.matrix_reps_index_map[i]] = slice[i]; - } - par.op.as_ref().unwrap().backward(&mut v); - } - }; + Ok(PlaintextVec( + (0..num_plaintexts) + .map(|i| { + let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())]; + let mut v = vec![0u64; par.degree()]; + match encoding.encoding { + EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice), + EncodingEnum::Simd => { + for i in 0..slice.len() { + v[par.matrix_reps_index_map[i]] = slice[i]; + } + par.op.as_ref().unwrap().backward(&mut v); + } + }; - let mut poly = - Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?; - poly.change_representation(Representation::Ntt); + let mut poly = + Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?; + poly.change_representation(Representation::Ntt); - Ok(Plaintext { - par: par.clone(), - value: v.into_boxed_slice(), - encoding: Some(encoding.clone()), - poly_ntt: poly, - level: encoding.level, - }) - }) - .collect::<Result<Vec<Plaintext>>>()?, - )) - } + Ok(Plaintext { + par: par.clone(), + value: v.into_boxed_slice(), + encoding: Some(encoding.clone()), + poly_ntt: poly, + level: encoding.level, + }) + }) + .collect::<Result<Vec<Plaintext>>>()?, + )) + } } #[cfg(test)] mod tests { - use crate::bfv::{BfvParameters, Encoding, PlaintextVec}; - use fhe_traits::{FheDecoder, FheEncoder, FheEncoderVariableTime}; - use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use crate::bfv::{BfvParameters, Encoding, PlaintextVec}; + use fhe_traits::{FheDecoder, FheEncoder, FheEncoderVariableTime}; + use rand::thread_rng; + use std::{error::Error, sync::Arc}; - #[test] - fn encode_decode() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for _ in 0..20 { - for i in 1..5 { - let params = Arc::new(BfvParameters::default(1, 8)); - let a = params.plaintext.random_vec(params.degree() * i, &mut rng); + #[test] + fn encode_decode() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for _ in 0..20 { + for i in 1..5 { + let params = Arc::new(BfvParameters::default(1, 8)); + let a = params.plaintext.random_vec(params.degree() * i, &mut rng); - let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), &params)?; - assert_eq!(plaintexts.0.len(), i); + let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), &params)?; + assert_eq!(plaintexts.0.len(), i); - for j in 0..i { - let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?; - assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); - } + for j in 0..i { + let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?; + assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); + } - let plaintexts = unsafe { - PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), &params)? - }; - assert_eq!(plaintexts.0.len(), i); + let plaintexts = unsafe { + PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), &params)? + }; + assert_eq!(plaintexts.0.len(), i); - for j in 0..i { - let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?; - assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); - } - } - } - Ok(()) - } + for j in 0..i { + let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?; + assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]); + } + } + } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/rgsw_ciphertext.rs b/crates/fhe/src/bfv/rgsw_ciphertext.rs index 3b70e2c..f1e7fe2 100644 --- a/crates/fhe/src/bfv/rgsw_ciphertext.rs +++ b/crates/fhe/src/bfv/rgsw_ciphertext.rs @@ -2,210 +2,210 @@ use std::ops::Mul; use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation}; use fhe_traits::{ - DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize, + DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize, }; use protobuf::{Message, MessageField}; use rand::{CryptoRng, RngCore}; use zeroize::Zeroizing; use crate::{ - bfv::proto::bfv::{ - KeySwitchingKey as KeySwitchingKeyProto, RGSWCiphertext as RGSWCiphertextProto, - }, - Error, Result, + bfv::proto::bfv::{ + KeySwitchingKey as KeySwitchingKeyProto, RGSWCiphertext as RGSWCiphertextProto, + }, + Error, Result, }; use super::{ - keys::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, Ciphertext, Plaintext, SecretKey, + keys::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, Ciphertext, Plaintext, SecretKey, }; /// A RGSW ciphertext encrypting a plaintext. #[derive(Debug, PartialEq, Eq)] pub struct RGSWCiphertext { - ksk0: KeySwitchingKey, - ksk1: KeySwitchingKey, + ksk0: KeySwitchingKey, + ksk1: KeySwitchingKey, } impl FheParametrized for RGSWCiphertext { - type Parameters = BfvParameters; + type Parameters = BfvParameters; } impl From<&RGSWCiphertext> for RGSWCiphertextProto { - fn from(ct: &RGSWCiphertext) -> Self { - let mut proto = RGSWCiphertextProto::new(); - proto.ksk0 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk0)); - proto.ksk1 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk1)); - proto - } + fn from(ct: &RGSWCiphertext) -> Self { + let mut proto = RGSWCiphertextProto::new(); + proto.ksk0 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk0)); + proto.ksk1 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk1)); + proto + } } impl TryConvertFrom<&RGSWCiphertextProto> for RGSWCiphertext { - fn try_convert_from( - value: &RGSWCiphertextProto, - par: &std::sync::Arc<BfvParameters>, - ) -> Result<Self> { - if value.ksk0.is_none() || value.ksk1.is_none() { - return Err(Error::SerializationError); - } + fn try_convert_from( + value: &RGSWCiphertextProto, + par: &std::sync::Arc<BfvParameters>, + ) -> Result<Self> { + if value.ksk0.is_none() || value.ksk1.is_none() { + return Err(Error::SerializationError); + } - let ksk0 = KeySwitchingKey::try_convert_from(value.ksk0.as_ref().unwrap(), par)?; - let ksk1 = KeySwitchingKey::try_convert_from(value.ksk1.as_ref().unwrap(), par)?; - if ksk0.ksk_level != ksk0.ciphertext_level - || ksk0.ciphertext_level != ksk1.ciphertext_level - || ksk1.ciphertext_level != ksk1.ksk_level - { - return Err(Error::SerializationError); - } + let ksk0 = KeySwitchingKey::try_convert_from(value.ksk0.as_ref().unwrap(), par)?; + let ksk1 = KeySwitchingKey::try_convert_from(value.ksk1.as_ref().unwrap(), par)?; + if ksk0.ksk_level != ksk0.ciphertext_level + || ksk0.ciphertext_level != ksk1.ciphertext_level + || ksk1.ciphertext_level != ksk1.ksk_level + { + return Err(Error::SerializationError); + } - Ok(Self { ksk0, ksk1 }) - } + Ok(Self { ksk0, ksk1 }) + } } impl DeserializeParametrized for RGSWCiphertext { - type Error = Error; + type Error = Error; - fn from_bytes(bytes: &[u8], par: &std::sync::Arc<Self::Parameters>) -> Result<Self> { - let proto = - RGSWCiphertextProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?; - RGSWCiphertext::try_convert_from(&proto, par) - } + fn from_bytes(bytes: &[u8], par: &std::sync::Arc<Self::Parameters>) -> Result<Self> { + let proto = + RGSWCiphertextProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?; + RGSWCiphertext::try_convert_from(&proto, par) + } } impl Serialize for RGSWCiphertext { - fn to_bytes(&self) -> Vec<u8> { - RGSWCiphertextProto::from(self).write_to_bytes().unwrap() - } + fn to_bytes(&self) -> Vec<u8> { + RGSWCiphertextProto::from(self).write_to_bytes().unwrap() + } } impl FheCiphertext for RGSWCiphertext {} impl FheEncrypter<Plaintext, RGSWCiphertext> for SecretKey { - type Error = Error; + type Error = Error; - fn try_encrypt<R: RngCore + CryptoRng>( - &self, - pt: &Plaintext, - rng: &mut R, - ) -> Result<RGSWCiphertext> { - let level = pt.level; - let ctx = self.par.ctx_at_level(level)?; + fn try_encrypt<R: RngCore + CryptoRng>( + &self, + pt: &Plaintext, + rng: &mut R, + ) -> Result<RGSWCiphertext> { + let level = pt.level; + let ctx = self.par.ctx_at_level(level)?; - let mut m = Zeroizing::new(pt.poly_ntt.clone()); - let mut m_s = Zeroizing::new(Poly::try_convert_from( - self.coeffs.as_ref(), - ctx, - false, - Representation::PowerBasis, - )?); - m_s.change_representation(Representation::Ntt); - *m_s.as_mut() *= m.as_ref(); - m_s.change_representation(Representation::PowerBasis); - m.change_representation(Representation::PowerBasis); + let mut m = Zeroizing::new(pt.poly_ntt.clone()); + let mut m_s = Zeroizing::new(Poly::try_convert_from( + self.coeffs.as_ref(), + ctx, + false, + Representation::PowerBasis, + )?); + m_s.change_representation(Representation::Ntt); + *m_s.as_mut() *= m.as_ref(); + m_s.change_representation(Representation::PowerBasis); + m.change_representation(Representation::PowerBasis); - let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?; - let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?; + let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?; + let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?; - Ok(RGSWCiphertext { ksk0, ksk1 }) - } + Ok(RGSWCiphertext { ksk0, ksk1 }) + } } impl Mul<&RGSWCiphertext> for &Ciphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn mul(self, rhs: &RGSWCiphertext) -> Self::Output { - assert_eq!( - self.par, rhs.ksk0.par, - "Ciphertext and RGSWCiphertext must have the same parameters" - ); - assert_eq!( - self.level, rhs.ksk0.ciphertext_level, - "Ciphertext and RGSWCiphertext must have the same level" - ); - assert_eq!(self.c.len(), 2, "Ciphertext must have two parts"); + fn mul(self, rhs: &RGSWCiphertext) -> Self::Output { + assert_eq!( + self.par, rhs.ksk0.par, + "Ciphertext and RGSWCiphertext must have the same parameters" + ); + assert_eq!( + self.level, rhs.ksk0.ciphertext_level, + "Ciphertext and RGSWCiphertext must have the same level" + ); + assert_eq!(self.c.len(), 2, "Ciphertext must have two parts"); - let mut ct0 = self.c[0].clone(); - let mut ct1 = self.c[1].clone(); - ct0.change_representation(Representation::PowerBasis); - ct1.change_representation(Representation::PowerBasis); + let mut ct0 = self.c[0].clone(); + let mut ct1 = self.c[1].clone(); + ct0.change_representation(Representation::PowerBasis); + ct1.change_representation(Representation::PowerBasis); - let (c0, c1) = rhs.ksk0.key_switch(&ct0).unwrap(); - let (c0p, c1p) = rhs.ksk1.key_switch(&ct1).unwrap(); + let (c0, c1) = rhs.ksk0.key_switch(&ct0).unwrap(); + let (c0p, c1p) = rhs.ksk1.key_switch(&ct1).unwrap(); - Ciphertext { - par: self.par.clone(), - seed: None, - c: vec![&c0 + &c0p, &c1 + &c1p], - level: self.level, - } - } + Ciphertext { + par: self.par.clone(), + seed: None, + c: vec![&c0 + &c0p, &c1 + &c1p], + level: self.level, + } + } } impl Mul<&Ciphertext> for &RGSWCiphertext { - type Output = Ciphertext; + type Output = Ciphertext; - fn mul(self, rhs: &Ciphertext) -> Self::Output { - rhs * self - } + fn mul(self, rhs: &Ciphertext) -> Self::Output { + rhs * self + } } #[cfg(test)] mod tests { - use std::{error::Error, sync::Arc}; + use std::{error::Error, sync::Arc}; - use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey}; - use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize}; - use rand::thread_rng; + use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey}; + use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize}; + use rand::thread_rng; - use super::RGSWCiphertext; + use super::RGSWCiphertext; - #[test] - fn external_product() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(2, 8)), - Arc::new(BfvParameters::default(8, 8)), - ] { - let sk = SecretKey::random(&params, &mut rng); - let v1 = params.plaintext.random_vec(params.degree(), &mut rng); - let v2 = params.plaintext.random_vec(params.degree(), &mut rng); + #[test] + fn external_product() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(2, 8)), + Arc::new(BfvParameters::default(8, 8)), + ] { + let sk = SecretKey::random(&params, &mut rng); + let v1 = params.plaintext.random_vec(params.degree(), &mut rng); + let v2 = params.plaintext.random_vec(params.degree(), &mut rng); - let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &params)?; - let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &params)?; + let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &params)?; + let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &params)?; - let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; - let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; - let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?; + let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; + let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?; - let product = &ct1 * &ct2; - let expected = sk.try_decrypt(&product)?; + let product = &ct1 * &ct2; + let expected = sk.try_decrypt(&product)?; - let ct3 = &ct1 * &ct2_rgsw; - let ct4 = &ct2_rgsw * &ct1; + let ct3 = &ct1 * &ct2_rgsw; + let ct4 = &ct2_rgsw * &ct1; - println!("Noise 1: {:?}", unsafe { sk.measure_noise(&ct3) }); - println!("Noise 2: {:?}", unsafe { sk.measure_noise(&ct4) }); - assert_eq!(expected, sk.try_decrypt(&ct3)?); - assert_eq!(expected, sk.try_decrypt(&ct4)?); - } - Ok(()) - } + println!("Noise 1: {:?}", unsafe { sk.measure_noise(&ct3) }); + println!("Noise 2: {:?}", unsafe { sk.measure_noise(&ct4) }); + assert_eq!(expected, sk.try_decrypt(&ct3)?); + assert_eq!(expected, sk.try_decrypt(&ct4)?); + } + Ok(()) + } - #[test] - fn serialize() -> Result<(), Box<dyn Error>> { - let mut rng = thread_rng(); - for params in [ - Arc::new(BfvParameters::default(6, 8)), - Arc::new(BfvParameters::default(5, 8)), - ] { - let sk = SecretKey::random(&params, &mut rng); - let v = params.plaintext.random_vec(params.degree(), &mut rng); - let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?; - let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?; + #[test] + fn serialize() -> Result<(), Box<dyn Error>> { + let mut rng = thread_rng(); + for params in [ + Arc::new(BfvParameters::default(6, 8)), + Arc::new(BfvParameters::default(5, 8)), + ] { + let sk = SecretKey::random(&params, &mut rng); + let v = params.plaintext.random_vec(params.degree(), &mut rng); + let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?; + let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?; - let bytes = ct.to_bytes(); - assert_eq!(RGSWCiphertext::from_bytes(&bytes, &params)?, ct); - } + let bytes = ct.to_bytes(); + assert_eq!(RGSWCiphertext::from_bytes(&bytes, &params)?, ct); + } - Ok(()) - } + Ok(()) + } } diff --git a/crates/fhe/src/bfv/traits.rs b/crates/fhe/src/bfv/traits.rs index 11bcbd3..06f3cf5 100644 --- a/crates/fhe/src/bfv/traits.rs +++ b/crates/fhe/src/bfv/traits.rs @@ -12,8 +12,8 @@ use std::sync::Arc; /// blanket implementation <https://github.com/rust-lang/rust/issues/50133#issuecomment-488512355>. pub trait TryConvertFrom<T> where - Self: Sized, + Self: Sized, { - /// Attempt to convert the `value` with a specific parameter. - fn try_convert_from(value: T, par: &Arc<BfvParameters>) -> Result<Self>; + /// Attempt to convert the `value` with a specific parameter. + fn try_convert_from(value: T, par: &Arc<BfvParameters>) -> Result<Self>; } diff --git a/crates/fhe/src/errors.rs b/crates/fhe/src/errors.rs index 47f3433..4642822 100644 --- a/crates/fhe/src/errors.rs +++ b/crates/fhe/src/errors.rs @@ -6,141 +6,141 @@ pub type Result<T> = std::result::Result<T, Error>; /// Enum encapsulating all the possible errors from this library. #[derive(Debug, Error, PartialEq, Eq)] pub enum Error { - /// Indicates that an error from the underlying mathematical library was - /// encountered. - #[error("{0}")] - MathError(fhe_math::Error), + /// Indicates that an error from the underlying mathematical library was + /// encountered. + #[error("{0}")] + MathError(fhe_math::Error), - /// Indicates a serialization error. - #[error("Serialization error")] - SerializationError, + /// Indicates a serialization error. + #[error("Serialization error")] + SerializationError, - /// Indicates that too many values were provided. - #[error("Too many values provided: {0} exceeds limit {1}")] - TooManyValues(usize, usize), + /// Indicates that too many values were provided. + #[error("Too many values provided: {0} exceeds limit {1}")] + TooManyValues(usize, usize), - /// Indicates that too few values were provided. - #[error("Too few values provided: {0} is below limit {1}")] - TooFewValues(usize, usize), + /// Indicates that too few values were provided. + #[error("Too few values provided: {0} is below limit {1}")] + TooFewValues(usize, usize), - /// Indicates that an input is invalid. - #[error("{0}")] - UnspecifiedInput(String), + /// Indicates that an input is invalid. + #[error("{0}")] + UnspecifiedInput(String), - /// Indicates a mismatch in the encodings. - #[error("Encoding mismatch: found {0}, expected {1}")] - EncodingMismatch(String, String), + /// Indicates a mismatch in the encodings. + #[error("Encoding mismatch: found {0}, expected {1}")] + EncodingMismatch(String, String), - /// Indicates that the encoding is not supported. - #[error("Does not support {0} encoding")] - EncodingNotSupported(String), + /// Indicates that the encoding is not supported. + #[error("Does not support {0} encoding")] + EncodingNotSupported(String), - /// Indicates a parameter error. - #[error("{0}")] - ParametersError(ParametersError), + /// Indicates a parameter error. + #[error("{0}")] + ParametersError(ParametersError), - /// Indicates a default error - /// TODO: To delete eventually - #[error("{0}")] - DefaultError(String), + /// Indicates a default error + /// TODO: To delete eventually + #[error("{0}")] + DefaultError(String), } impl From<fhe_math::Error> for Error { - fn from(e: fhe_math::Error) -> Self { - Error::MathError(e) - } + fn from(e: fhe_math::Error) -> Self { + Error::MathError(e) + } } /// Separate enum to indicate parameters-related errors. #[derive(Debug, Error, PartialEq, Eq)] pub enum ParametersError { - /// Indicates that the degree is invalid. - #[error("Invalid degree: {0} is not a power of 2 larger than 8")] - InvalidDegree(usize), + /// Indicates that the degree is invalid. + #[error("Invalid degree: {0} is not a power of 2 larger than 8")] + InvalidDegree(usize), - /// Indicates that the moduli sizes are invalid. - #[error("Invalid modulus size: {0}, expected an integer between {1} and {2}")] - InvalidModulusSize(usize, usize, usize), + /// Indicates that the moduli sizes are invalid. + #[error("Invalid modulus size: {0}, expected an integer between {1} and {2}")] + InvalidModulusSize(usize, usize, usize), - /// Indicates that there exists not enough primes of this size. - #[error("Not enough primes of size {0} for polynomials of degree {1}")] - NotEnoughPrimes(usize, usize), + /// Indicates that there exists not enough primes of this size. + #[error("Not enough primes of size {0} for polynomials of degree {1}")] + NotEnoughPrimes(usize, usize), - /// Indicates that the plaintext is invalid. - #[error("{0}")] - InvalidPlaintext(String), + /// Indicates that the plaintext is invalid. + #[error("{0}")] + InvalidPlaintext(String), - /// Indicates that too many parameters were specified. - #[error("{0}")] - TooManySpecified(String), + /// Indicates that too many parameters were specified. + #[error("{0}")] + TooManySpecified(String), - /// Indicates that too few parameters were specified. - #[error("{0}")] - TooFewSpecified(String), + /// Indicates that too few parameters were specified. + #[error("{0}")] + TooFewSpecified(String), } #[cfg(test)] mod tests { - use crate::{Error, ParametersError}; + use crate::{Error, ParametersError}; - #[test] - fn error_strings() { - assert_eq!( - Error::MathError(fhe_math::Error::InvalidContext).to_string(), - fhe_math::Error::InvalidContext.to_string() - ); - assert_eq!(Error::SerializationError.to_string(), "Serialization error"); - assert_eq!( - Error::TooManyValues(20, 17).to_string(), - "Too many values provided: 20 exceeds limit 17" - ); - assert_eq!( - Error::TooFewValues(10, 17).to_string(), - "Too few values provided: 10 is below limit 17" - ); - assert_eq!( - Error::UnspecifiedInput("test string".to_string()).to_string(), - "test string" - ); - assert_eq!( - Error::EncodingMismatch("enc1".to_string(), "enc2".to_string()).to_string(), - "Encoding mismatch: found enc1, expected enc2" - ); - assert_eq!( - Error::EncodingNotSupported("test".to_string()).to_string(), - "Does not support test encoding" - ); - assert_eq!( - Error::ParametersError(ParametersError::InvalidDegree(10)).to_string(), - ParametersError::InvalidDegree(10).to_string() - ); - } + #[test] + fn error_strings() { + assert_eq!( + Error::MathError(fhe_math::Error::InvalidContext).to_string(), + fhe_math::Error::InvalidContext.to_string() + ); + assert_eq!(Error::SerializationError.to_string(), "Serialization error"); + assert_eq!( + Error::TooManyValues(20, 17).to_string(), + "Too many values provided: 20 exceeds limit 17" + ); + assert_eq!( + Error::TooFewValues(10, 17).to_string(), + "Too few values provided: 10 is below limit 17" + ); + assert_eq!( + Error::UnspecifiedInput("test string".to_string()).to_string(), + "test string" + ); + assert_eq!( + Error::EncodingMismatch("enc1".to_string(), "enc2".to_string()).to_string(), + "Encoding mismatch: found enc1, expected enc2" + ); + assert_eq!( + Error::EncodingNotSupported("test".to_string()).to_string(), + "Does not support test encoding" + ); + assert_eq!( + Error::ParametersError(ParametersError::InvalidDegree(10)).to_string(), + ParametersError::InvalidDegree(10).to_string() + ); + } - #[test] - fn parameters_error_strings() { - assert_eq!( - ParametersError::InvalidDegree(10).to_string(), - "Invalid degree: 10 is not a power of 2 larger than 8" - ); - assert_eq!( - ParametersError::InvalidModulusSize(1, 2, 3).to_string(), - "Invalid modulus size: 1, expected an integer between 2 and 3" - ); - assert_eq!( - ParametersError::NotEnoughPrimes(1, 2).to_string(), - "Not enough primes of size 1 for polynomials of degree 2" - ); - assert_eq!( - ParametersError::InvalidPlaintext("test".to_string()).to_string(), - "test" - ); - assert_eq!( - ParametersError::TooManySpecified("test".to_string()).to_string(), - "test" - ); - assert_eq!( - ParametersError::TooFewSpecified("test".to_string()).to_string(), - "test" - ); - } + #[test] + fn parameters_error_strings() { + assert_eq!( + ParametersError::InvalidDegree(10).to_string(), + "Invalid degree: 10 is not a power of 2 larger than 8" + ); + assert_eq!( + ParametersError::InvalidModulusSize(1, 2, 3).to_string(), + "Invalid modulus size: 1, expected an integer between 2 and 3" + ); + assert_eq!( + ParametersError::NotEnoughPrimes(1, 2).to_string(), + "Not enough primes of size 1 for polynomials of degree 2" + ); + assert_eq!( + ParametersError::InvalidPlaintext("test".to_string()).to_string(), + "test" + ); + assert_eq!( + ParametersError::TooManySpecified("test".to_string()).to_string(), + "test" + ); + assert_eq!( + ParametersError::TooFewSpecified("test".to_string()).to_string(), + "test" + ); + } } diff --git a/rustfmt.toml b/rustfmt.toml index d63158b..b2715b2 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1 @@ wrap_comments = true -hard_tabs = true