mirror of
https://github.com/Sunscreen-tech/fhe.rs.git
synced 2026-01-08 20:18:03 -05:00
Change tabs into space, optimize ntt operator constructor (#170)
This commit is contained in:
@@ -4,40 +4,40 @@ use rand::thread_rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn ntt_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("ntt");
|
||||
group.sample_size(50);
|
||||
let mut rng = thread_rng();
|
||||
let mut group = c.benchmark_group("ntt");
|
||||
group.sample_size(50);
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for vector_size in [1024usize, 4096].iter() {
|
||||
for p in [4611686018326724609u64, 40961u64] {
|
||||
let p_nbits = 64 - p.leading_zeros();
|
||||
let q = Modulus::new(p).unwrap();
|
||||
let mut a = q.random_vec(*vector_size, &mut rng);
|
||||
let op = NttOperator::new(&Arc::new(q), *vector_size).unwrap();
|
||||
for vector_size in [1024usize, 4096].iter() {
|
||||
for p in [4611686018326724609u64, 40961u64] {
|
||||
let p_nbits = 64 - p.leading_zeros();
|
||||
let q = Modulus::new(p).unwrap();
|
||||
let mut a = q.random_vec(*vector_size, &mut rng);
|
||||
let op = NttOperator::new(&Arc::new(q), *vector_size).unwrap();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| op.forward(&mut a)),
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| op.forward(&mut a)),
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward_vt", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| unsafe { op.forward_vt(a.as_mut_ptr()) }),
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward_vt", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| unsafe { op.forward_vt(a.as_mut_ptr()) }),
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("backward", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| op.backward(&mut a)),
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("backward", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| op.backward(&mut a)),
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("backward_vt", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| unsafe { op.backward_vt(a.as_mut_ptr()) }),
|
||||
);
|
||||
}
|
||||
}
|
||||
group.bench_function(
|
||||
BenchmarkId::new("backward_vt", format!("{vector_size}/{p_nbits}")),
|
||||
|b| b.iter(|| unsafe { op.backward_vt(a.as_mut_ptr()) }),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(ntt, ntt_benchmark);
|
||||
|
||||
@@ -5,53 +5,53 @@ use rand::{thread_rng, RngCore};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn rns_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("rns");
|
||||
group.sample_size(50);
|
||||
let mut group = c.benchmark_group("rns");
|
||||
group.sample_size(50);
|
||||
|
||||
let q = [
|
||||
4611686018326724609u64,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
];
|
||||
let p = [
|
||||
4611686018257518593u64,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
4611686018106523649,
|
||||
];
|
||||
let q = [
|
||||
4611686018326724609u64,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
];
|
||||
let p = [
|
||||
4611686018257518593u64,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
4611686018106523649,
|
||||
];
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let mut x = vec![];
|
||||
for qi in &q {
|
||||
x.push(rng.next_u64() % *qi);
|
||||
}
|
||||
let mut rng = thread_rng();
|
||||
let mut x = vec![];
|
||||
for qi in &q {
|
||||
x.push(rng.next_u64() % *qi);
|
||||
}
|
||||
|
||||
let rns_q = Arc::new(RnsContext::new(&q).unwrap());
|
||||
let rns_p = Arc::new(RnsContext::new(&p).unwrap());
|
||||
let scaler = RnsScaler::new(
|
||||
&rns_q,
|
||||
&rns_p,
|
||||
ScalingFactor::new(&BigUint::from(1u64), &BigUint::from(46116860181065u64)),
|
||||
);
|
||||
let scaler_as_converter = RnsScaler::new(&rns_q, &rns_p, ScalingFactor::one());
|
||||
let rns_q = Arc::new(RnsContext::new(&q).unwrap());
|
||||
let rns_p = Arc::new(RnsContext::new(&p).unwrap());
|
||||
let scaler = RnsScaler::new(
|
||||
&rns_q,
|
||||
&rns_p,
|
||||
ScalingFactor::new(&BigUint::from(1u64), &BigUint::from(46116860181065u64)),
|
||||
);
|
||||
let scaler_as_converter = RnsScaler::new(&rns_q, &rns_p, ScalingFactor::one());
|
||||
|
||||
let mut y = vec![0; p.len()];
|
||||
let mut y = vec![0; p.len()];
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("scaler", format!("{}->{}", q.len(), p.len())),
|
||||
|b| {
|
||||
b.iter(|| scaler.scale((&x).into(), (&mut y).into(), 0));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("scaler", format!("{}->{}", q.len(), p.len())),
|
||||
|b| {
|
||||
b.iter(|| scaler.scale((&x).into(), (&mut y).into(), 0));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("scaler_as_converter", format!("{}->{}", q.len(), p.len())),
|
||||
|b| {
|
||||
b.iter(|| scaler_as_converter.scale((&x).into(), (&mut y).into(), 0));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("scaler_as_converter", format!("{}->{}", q.len(), p.len())),
|
||||
|b| {
|
||||
b.iter(|| scaler_as_converter.scale((&x).into(), (&mut y).into(), 0));
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(rns, rns_benchmark);
|
||||
|
||||
@@ -4,287 +4,287 @@ use fhe_math::rq::*;
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::thread_rng;
|
||||
use std::{
|
||||
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
static MODULI: &[u64; 4] = &[
|
||||
562949954093057,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
562949954093057,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
];
|
||||
|
||||
static DEGREE: &[usize] = &[1024, 2048, 4096, 8192];
|
||||
|
||||
fn create_group(c: &mut Criterion, name: String) -> BenchmarkGroup<WallTime> {
|
||||
let mut group = c.benchmark_group(name);
|
||||
group.warm_up_time(Duration::from_millis(100));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
group
|
||||
let mut group = c.benchmark_group(name);
|
||||
group.warm_up_time(Duration::from_millis(100));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
group
|
||||
}
|
||||
|
||||
macro_rules! bench_op {
|
||||
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
|
||||
let name = if $vt {
|
||||
format!("{}_vt", $name)
|
||||
} else {
|
||||
$name.to_string()
|
||||
};
|
||||
let mut group = create_group($c, name);
|
||||
let mut rng = thread_rng();
|
||||
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
|
||||
let name = if $vt {
|
||||
format!("{}_vt", $name)
|
||||
} else {
|
||||
$name.to_string()
|
||||
};
|
||||
let mut group = create_group($c, name);
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for degree in DEGREE {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
if $vt {
|
||||
unsafe { q.allow_variable_time_computations() }
|
||||
}
|
||||
for degree in DEGREE {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
if $vt {
|
||||
unsafe { q.allow_variable_time_computations() }
|
||||
}
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| $op(&p, &q));
|
||||
},
|
||||
);
|
||||
}
|
||||
}};
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| $op(&p, &q));
|
||||
},
|
||||
);
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! bench_op_unary {
|
||||
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
|
||||
let name = if $vt {
|
||||
format!("{}_vt", $name)
|
||||
} else {
|
||||
$name.to_string()
|
||||
};
|
||||
let mut group = create_group($c, name);
|
||||
let mut rng = thread_rng();
|
||||
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
|
||||
let name = if $vt {
|
||||
format!("{}_vt", $name)
|
||||
} else {
|
||||
$name.to_string()
|
||||
};
|
||||
let mut group = create_group($c, name);
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for degree in DEGREE {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
if $vt {
|
||||
unsafe { q.allow_variable_time_computations() }
|
||||
}
|
||||
for degree in DEGREE {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
if $vt {
|
||||
unsafe { q.allow_variable_time_computations() }
|
||||
}
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| $op(&p));
|
||||
},
|
||||
);
|
||||
}
|
||||
}};
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| $op(&p));
|
||||
},
|
||||
);
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! bench_op_assign {
|
||||
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
|
||||
let name = if $vt {
|
||||
format!("{}_vt", $name)
|
||||
} else {
|
||||
$name.to_string()
|
||||
};
|
||||
let mut group = create_group($c, name);
|
||||
let mut rng = thread_rng();
|
||||
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
|
||||
let name = if $vt {
|
||||
format!("{}_vt", $name)
|
||||
} else {
|
||||
$name.to_string()
|
||||
};
|
||||
let mut group = create_group($c, name);
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for degree in DEGREE {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
|
||||
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
if $vt {
|
||||
unsafe { q.allow_variable_time_computations() }
|
||||
}
|
||||
for degree in DEGREE {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
|
||||
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
if $vt {
|
||||
unsafe { q.allow_variable_time_computations() }
|
||||
}
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| $op(&mut p, &q));
|
||||
},
|
||||
);
|
||||
}
|
||||
}};
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| $op(&mut p, &q));
|
||||
},
|
||||
);
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
pub fn rq_op_benchmark(c: &mut Criterion) {
|
||||
for vt in [false, true] {
|
||||
bench_op!(c, "rq_add", <&Poly>::add, vt);
|
||||
bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt);
|
||||
bench_op!(c, "rq_sub", <&Poly>::sub, vt);
|
||||
bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt);
|
||||
bench_op!(c, "rq_mul", <&Poly>::mul, vt);
|
||||
bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt);
|
||||
bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt);
|
||||
}
|
||||
for vt in [false, true] {
|
||||
bench_op!(c, "rq_add", <&Poly>::add, vt);
|
||||
bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt);
|
||||
bench_op!(c, "rq_sub", <&Poly>::sub, vt);
|
||||
bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt);
|
||||
bench_op!(c, "rq_mul", <&Poly>::mul, vt);
|
||||
bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt);
|
||||
bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rq_dot_product(c: &mut Criterion) {
|
||||
let mut group = create_group(c, "rq_dot_product".to_string());
|
||||
let mut rng = thread_rng();
|
||||
for degree in DEGREE {
|
||||
for i in [1, 4] {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap());
|
||||
let p_vec = (0..256)
|
||||
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
|
||||
.collect_vec();
|
||||
let mut q_vec = (0..256)
|
||||
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
|
||||
.collect_vec();
|
||||
let mut out = Poly::zero(&ctx, Representation::Ntt);
|
||||
let mut group = create_group(c, "rq_dot_product".to_string());
|
||||
let mut rng = thread_rng();
|
||||
for degree in DEGREE {
|
||||
for i in [1, 4] {
|
||||
let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap());
|
||||
let p_vec = (0..256)
|
||||
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
|
||||
.collect_vec();
|
||||
let mut q_vec = (0..256)
|
||||
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
|
||||
.collect_vec();
|
||||
let mut out = Poly::zero(&ctx, Representation::Ntt);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
|
||||
});
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
q_vec
|
||||
.iter_mut()
|
||||
.for_each(|qi| qi.change_representation(Representation::NttShoup));
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"naive_shoup/{}/{}",
|
||||
degree,
|
||||
ctx.modulus().bits()
|
||||
)),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
|
||||
});
|
||||
},
|
||||
);
|
||||
q_vec
|
||||
.iter_mut()
|
||||
.for_each(|qi| qi.change_representation(Representation::NttShoup));
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"naive_shoup/{}/{}",
|
||||
degree,
|
||||
ctx.modulus().bits()
|
||||
)),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
q_vec
|
||||
.iter_mut()
|
||||
.for_each(|qi| qi.change_representation(Representation::Ntt));
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| dot_product(p_vec.iter(), q_vec.iter()));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
q_vec
|
||||
.iter_mut()
|
||||
.for_each(|qi| qi.change_representation(Representation::Ntt));
|
||||
group.bench_function(
|
||||
BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| dot_product(p_vec.iter(), q_vec.iter()));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rq_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("rq");
|
||||
group.warm_up_time(Duration::from_millis(100));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
let mut group = c.benchmark_group("rq");
|
||||
group.warm_up_time(Duration::from_millis(100));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
|
||||
let mut rng = thread_rng();
|
||||
for degree in DEGREE {
|
||||
for nmoduli in 1..=MODULI.len() {
|
||||
if !nmoduli.is_power_of_two() {
|
||||
continue;
|
||||
}
|
||||
let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap());
|
||||
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
q.change_representation(Representation::NttShoup);
|
||||
let mut rng = thread_rng();
|
||||
for degree in DEGREE {
|
||||
for nmoduli in 1..=MODULI.len() {
|
||||
if !nmoduli.is_power_of_two() {
|
||||
continue;
|
||||
}
|
||||
let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap());
|
||||
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
q.change_representation(Representation::NttShoup);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| p = &p * &q);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())),
|
||||
|b| {
|
||||
b.iter(|| p = &p * &q);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_shoup_assign",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| p *= &q);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_shoup_assign",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| p *= &q);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/PowerBasis_to_Ntt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
p.override_representation(Representation::PowerBasis);
|
||||
}
|
||||
p.change_representation(Representation::Ntt)
|
||||
});
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/PowerBasis_to_Ntt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
p.override_representation(Representation::PowerBasis);
|
||||
}
|
||||
p.change_representation(Representation::Ntt)
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/Ntt_to_PowerBasis",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
p.override_representation(Representation::Ntt);
|
||||
}
|
||||
p.change_representation(Representation::PowerBasis)
|
||||
});
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/Ntt_to_PowerBasis",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
p.override_representation(Representation::Ntt);
|
||||
}
|
||||
p.change_representation(Representation::PowerBasis)
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
p.change_representation(Representation::Ntt);
|
||||
q.change_representation(Representation::Ntt);
|
||||
p.change_representation(Representation::Ntt);
|
||||
q.change_representation(Representation::Ntt);
|
||||
|
||||
unsafe {
|
||||
q.allow_variable_time_computations();
|
||||
q.change_representation(Representation::NttShoup);
|
||||
unsafe {
|
||||
q.allow_variable_time_computations();
|
||||
q.change_representation(Representation::NttShoup);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_shoup_vt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| p *= &q);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_shoup_vt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| p *= &q);
|
||||
},
|
||||
);
|
||||
|
||||
p.allow_variable_time_computations();
|
||||
p.allow_variable_time_computations();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/PowerBasis_to_Ntt_vt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
p.override_representation(Representation::PowerBasis);
|
||||
p.change_representation(Representation::Ntt)
|
||||
});
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/PowerBasis_to_Ntt_vt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
p.override_representation(Representation::PowerBasis);
|
||||
p.change_representation(Representation::Ntt)
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/Ntt_to_PowerBasis_vt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
p.override_representation(Representation::Ntt);
|
||||
p.change_representation(Representation::PowerBasis)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"change_representation/Ntt_to_PowerBasis_vt",
|
||||
format!("{}/{}", degree, ctx.modulus().bits()),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
p.override_representation(Representation::Ntt);
|
||||
p.change_representation(Representation::PowerBasis)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(rq, rq_op_benchmark, rq_dot_product, rq_benchmark);
|
||||
|
||||
@@ -3,53 +3,53 @@ use fhe_math::zq::Modulus;
|
||||
use rand::thread_rng;
|
||||
|
||||
pub fn zq_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("zq");
|
||||
group.sample_size(50);
|
||||
let mut group = c.benchmark_group("zq");
|
||||
group.sample_size(50);
|
||||
|
||||
let p = 4611686018326724609;
|
||||
let mut rng = thread_rng();
|
||||
let p = 4611686018326724609;
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for vector_size in [1024usize, 4096].iter() {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
let mut a = q.random_vec(*vector_size, &mut rng);
|
||||
let c = q.random_vec(*vector_size, &mut rng);
|
||||
let c_shoup = q.shoup_vec(&c);
|
||||
let scalar = c[0];
|
||||
for vector_size in [1024usize, 4096].iter() {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
let mut a = q.random_vec(*vector_size, &mut rng);
|
||||
let c = q.random_vec(*vector_size, &mut rng);
|
||||
let c_shoup = q.shoup_vec(&c);
|
||||
let scalar = c[0];
|
||||
|
||||
group.bench_function(BenchmarkId::new("add_vec", vector_size), |b| {
|
||||
b.iter(|| q.add_vec(&mut a, &c));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("add_vec", vector_size), |b| {
|
||||
b.iter(|| q.add_vec(&mut a, &c));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("add_vec_vt", vector_size), |b| unsafe {
|
||||
b.iter(|| q.add_vec_vt(&mut a, &c));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("add_vec_vt", vector_size), |b| unsafe {
|
||||
b.iter(|| q.add_vec_vt(&mut a, &c));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("sub_vec", vector_size), |b| {
|
||||
b.iter(|| q.sub_vec(&mut a, &c));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("sub_vec", vector_size), |b| {
|
||||
b.iter(|| q.sub_vec(&mut a, &c));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("neg_vec", vector_size), |b| {
|
||||
b.iter(|| q.neg_vec(&mut a));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("neg_vec", vector_size), |b| {
|
||||
b.iter(|| q.neg_vec(&mut a));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("mul_vec", vector_size), |b| {
|
||||
b.iter(|| q.mul_vec(&mut a, &c));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("mul_vec", vector_size), |b| {
|
||||
b.iter(|| q.mul_vec(&mut a, &c));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("mul_vec_vt", vector_size), |b| unsafe {
|
||||
b.iter(|| q.mul_vec_vt(&mut a, &c));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("mul_vec_vt", vector_size), |b| unsafe {
|
||||
b.iter(|| q.mul_vec_vt(&mut a, &c));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("mul_shoup_vec", vector_size), |b| {
|
||||
b.iter(|| q.mul_shoup_vec(&mut a, &c, &c_shoup));
|
||||
});
|
||||
group.bench_function(BenchmarkId::new("mul_shoup_vec", vector_size), |b| {
|
||||
b.iter(|| q.mul_shoup_vec(&mut a, &c, &c_shoup));
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("scalar_mul_vec", vector_size), |b| {
|
||||
b.iter(|| q.scalar_mul_vec(&mut a, scalar));
|
||||
});
|
||||
}
|
||||
group.bench_function(BenchmarkId::new("scalar_mul_vec", vector_size), |b| {
|
||||
b.iter(|| q.scalar_mul_vec(&mut a, scalar));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(zq, zq_benchmark);
|
||||
|
||||
@@ -8,63 +8,63 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
/// Enum encapsulation all the possible errors from this library.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// Indicates an invalid modulus
|
||||
#[error("Invalid modulus: modulus {0} should be between 2 and (1 << 62) - 1.")]
|
||||
InvalidModulus(u64),
|
||||
/// Indicates an invalid modulus
|
||||
#[error("Invalid modulus: modulus {0} should be between 2 and (1 << 62) - 1.")]
|
||||
InvalidModulus(u64),
|
||||
|
||||
/// Indicates an error in the serialization / deserialization.
|
||||
#[error("{0}")]
|
||||
Serialization(String),
|
||||
/// Indicates an error in the serialization / deserialization.
|
||||
#[error("{0}")]
|
||||
Serialization(String),
|
||||
|
||||
/// Indicates that there is no more contexts to switch to.
|
||||
#[error("This is the last context.")]
|
||||
NoMoreContext,
|
||||
/// Indicates that there is no more contexts to switch to.
|
||||
#[error("This is the last context.")]
|
||||
NoMoreContext,
|
||||
|
||||
/// Indicates that the provided context is invalid.
|
||||
#[error("Invalid context provided.")]
|
||||
InvalidContext,
|
||||
/// Indicates that the provided context is invalid.
|
||||
#[error("Invalid context provided.")]
|
||||
InvalidContext,
|
||||
|
||||
/// Indicates an incorrect representation.
|
||||
#[error("Incorrect representation: got {0:?}, expected {1:?}.")]
|
||||
IncorrectRepresentation(Representation, Representation),
|
||||
/// Indicates an incorrect representation.
|
||||
#[error("Incorrect representation: got {0:?}, expected {1:?}.")]
|
||||
IncorrectRepresentation(Representation, Representation),
|
||||
|
||||
/// Indicates that the seed size is incorrect.
|
||||
#[error("Invalid seed: got {0} bytes, expected {1} bytes.")]
|
||||
InvalidSeedSize(usize, usize),
|
||||
/// Indicates that the seed size is incorrect.
|
||||
#[error("Invalid seed: got {0} bytes, expected {1} bytes.")]
|
||||
InvalidSeedSize(usize, usize),
|
||||
|
||||
/// Indicates a default error
|
||||
/// TODO: To delete when transition is over
|
||||
#[error("{0}")]
|
||||
Default(String),
|
||||
/// Indicates a default error
|
||||
/// TODO: To delete when transition is over
|
||||
#[error("{0}")]
|
||||
Default(String),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{rq::Representation, Error};
|
||||
use crate::{rq::Representation, Error};
|
||||
|
||||
#[test]
|
||||
fn error_strings() {
|
||||
assert_eq!(
|
||||
Error::InvalidModulus(0).to_string(),
|
||||
"Invalid modulus: modulus 0 should be between 2 and (1 << 62) - 1."
|
||||
);
|
||||
assert_eq!(Error::Serialization("test".to_string()).to_string(), "test");
|
||||
assert_eq!(
|
||||
Error::NoMoreContext.to_string(),
|
||||
"This is the last context."
|
||||
);
|
||||
assert_eq!(
|
||||
Error::InvalidContext.to_string(),
|
||||
"Invalid context provided."
|
||||
);
|
||||
assert_eq!(
|
||||
Error::IncorrectRepresentation(Representation::Ntt, Representation::NttShoup)
|
||||
.to_string(),
|
||||
"Incorrect representation: got Ntt, expected NttShoup."
|
||||
);
|
||||
assert_eq!(
|
||||
Error::InvalidSeedSize(0, 1).to_string(),
|
||||
"Invalid seed: got 0 bytes, expected 1 bytes."
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn error_strings() {
|
||||
assert_eq!(
|
||||
Error::InvalidModulus(0).to_string(),
|
||||
"Invalid modulus: modulus 0 should be between 2 and (1 << 62) - 1."
|
||||
);
|
||||
assert_eq!(Error::Serialization("test".to_string()).to_string(), "test");
|
||||
assert_eq!(
|
||||
Error::NoMoreContext.to_string(),
|
||||
"This is the last context."
|
||||
);
|
||||
assert_eq!(
|
||||
Error::InvalidContext.to_string(),
|
||||
"Invalid context provided."
|
||||
);
|
||||
assert_eq!(
|
||||
Error::IncorrectRepresentation(Representation::Ntt, Representation::NttShoup)
|
||||
.to_string(),
|
||||
"Incorrect representation: got Ntt, expected NttShoup."
|
||||
);
|
||||
assert_eq!(
|
||||
Error::InvalidSeedSize(0, 1).to_string(),
|
||||
"Invalid seed: got 0 bytes, expected 1 bytes."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,218 +17,218 @@ pub use scaler::{RnsScaler, ScalingFactor};
|
||||
/// Context for a Residue Number System.
|
||||
#[derive(Default, Clone, PartialEq, Eq)]
|
||||
pub struct RnsContext {
|
||||
moduli_u64: Vec<u64>,
|
||||
moduli: Vec<Modulus>,
|
||||
q_tilde: Vec<u64>,
|
||||
q_tilde_shoup: Vec<u64>,
|
||||
q_star: Vec<BigUint>,
|
||||
garner: Vec<BigUint>,
|
||||
product: BigUint,
|
||||
moduli_u64: Vec<u64>,
|
||||
moduli: Vec<Modulus>,
|
||||
q_tilde: Vec<u64>,
|
||||
q_tilde_shoup: Vec<u64>,
|
||||
q_star: Vec<BigUint>,
|
||||
garner: Vec<BigUint>,
|
||||
product: BigUint,
|
||||
}
|
||||
|
||||
impl Debug for RnsContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RnsContext")
|
||||
.field("moduli_u64", &self.moduli_u64)
|
||||
// .field("moduli", &self.moduli)
|
||||
// .field("q_tilde", &self.q_tilde)
|
||||
// .field("q_tilde_shoup", &self.q_tilde_shoup)
|
||||
// .field("q_star", &self.q_star)
|
||||
// .field("garner", &self.garner)
|
||||
.field("product", &self.product)
|
||||
.finish()
|
||||
}
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RnsContext")
|
||||
.field("moduli_u64", &self.moduli_u64)
|
||||
// .field("moduli", &self.moduli)
|
||||
// .field("q_tilde", &self.q_tilde)
|
||||
// .field("q_tilde_shoup", &self.q_tilde_shoup)
|
||||
// .field("q_star", &self.q_star)
|
||||
// .field("garner", &self.garner)
|
||||
.field("product", &self.product)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl RnsContext {
|
||||
/// Create a RNS context from a list of moduli.
|
||||
///
|
||||
/// Returns an error if the list is empty, or if the moduli are no coprime.
|
||||
pub fn new(moduli_u64: &[u64]) -> Result<Self> {
|
||||
if moduli_u64.is_empty() {
|
||||
Err(Error::Default("The list of moduli is empty".to_string()))
|
||||
} else {
|
||||
let mut moduli = Vec::with_capacity(moduli_u64.len());
|
||||
let mut q_tilde = Vec::with_capacity(moduli_u64.len());
|
||||
let mut q_tilde_shoup = Vec::with_capacity(moduli_u64.len());
|
||||
let mut q_star = Vec::with_capacity(moduli_u64.len());
|
||||
let mut garner = Vec::with_capacity(moduli_u64.len());
|
||||
let mut product = BigUint::one();
|
||||
let mut product_dig = BigUintDig::one();
|
||||
/// Create a RNS context from a list of moduli.
|
||||
///
|
||||
/// Returns an error if the list is empty, or if the moduli are no coprime.
|
||||
pub fn new(moduli_u64: &[u64]) -> Result<Self> {
|
||||
if moduli_u64.is_empty() {
|
||||
Err(Error::Default("The list of moduli is empty".to_string()))
|
||||
} else {
|
||||
let mut moduli = Vec::with_capacity(moduli_u64.len());
|
||||
let mut q_tilde = Vec::with_capacity(moduli_u64.len());
|
||||
let mut q_tilde_shoup = Vec::with_capacity(moduli_u64.len());
|
||||
let mut q_star = Vec::with_capacity(moduli_u64.len());
|
||||
let mut garner = Vec::with_capacity(moduli_u64.len());
|
||||
let mut product = BigUint::one();
|
||||
let mut product_dig = BigUintDig::one();
|
||||
|
||||
for i in 0..moduli_u64.len() {
|
||||
// Return None if the moduli are not coprime.
|
||||
for j in 0..moduli_u64.len() {
|
||||
if i != j {
|
||||
let (d, _, _) = BigUintDig::from(moduli_u64[i])
|
||||
.extended_gcd(&BigUintDig::from(moduli_u64[j]));
|
||||
if d.cmp(&BigIntDig::from(1)) != Ordering::Equal {
|
||||
return Err(Error::Default("The moduli are not coprime".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
for i in 0..moduli_u64.len() {
|
||||
// Return None if the moduli are not coprime.
|
||||
for j in 0..moduli_u64.len() {
|
||||
if i != j {
|
||||
let (d, _, _) = BigUintDig::from(moduli_u64[i])
|
||||
.extended_gcd(&BigUintDig::from(moduli_u64[j]));
|
||||
if d.cmp(&BigIntDig::from(1)) != Ordering::Equal {
|
||||
return Err(Error::Default("The moduli are not coprime".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
product *= &BigUint::from(moduli_u64[i]);
|
||||
product_dig *= &BigUintDig::from(moduli_u64[i]);
|
||||
}
|
||||
product *= &BigUint::from(moduli_u64[i]);
|
||||
product_dig *= &BigUintDig::from(moduli_u64[i]);
|
||||
}
|
||||
|
||||
for modulus in moduli_u64 {
|
||||
moduli.push(Modulus::new(*modulus)?);
|
||||
// q* = product / modulus
|
||||
let q_star_i = &product / modulus;
|
||||
// q~ = (product / modulus) ^ (-1) % modulus
|
||||
let q_tilde_i = (&product_dig / modulus)
|
||||
.mod_inverse(&BigUintDig::from(*modulus))
|
||||
.unwrap()
|
||||
.to_u64()
|
||||
.unwrap();
|
||||
// garner = (q*) * (q~)
|
||||
let garner_i = &q_star_i * q_tilde_i;
|
||||
q_tilde.push(q_tilde_i);
|
||||
garner.push(garner_i);
|
||||
q_star.push(q_star_i);
|
||||
q_tilde_shoup.push(
|
||||
Modulus::new(*modulus)
|
||||
.unwrap()
|
||||
.shoup(q_tilde_i.to_u64().unwrap()),
|
||||
);
|
||||
}
|
||||
for modulus in moduli_u64 {
|
||||
moduli.push(Modulus::new(*modulus)?);
|
||||
// q* = product / modulus
|
||||
let q_star_i = &product / modulus;
|
||||
// q~ = (product / modulus) ^ (-1) % modulus
|
||||
let q_tilde_i = (&product_dig / modulus)
|
||||
.mod_inverse(&BigUintDig::from(*modulus))
|
||||
.unwrap()
|
||||
.to_u64()
|
||||
.unwrap();
|
||||
// garner = (q*) * (q~)
|
||||
let garner_i = &q_star_i * q_tilde_i;
|
||||
q_tilde.push(q_tilde_i);
|
||||
garner.push(garner_i);
|
||||
q_star.push(q_star_i);
|
||||
q_tilde_shoup.push(
|
||||
Modulus::new(*modulus)
|
||||
.unwrap()
|
||||
.shoup(q_tilde_i.to_u64().unwrap()),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
moduli_u64: moduli_u64.to_owned(),
|
||||
moduli,
|
||||
q_tilde,
|
||||
q_tilde_shoup,
|
||||
q_star,
|
||||
garner,
|
||||
product,
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
moduli_u64: moduli_u64.to_owned(),
|
||||
moduli,
|
||||
q_tilde,
|
||||
q_tilde_shoup,
|
||||
q_star,
|
||||
garner,
|
||||
product,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the product of the moduli used when creating the RNS context.
|
||||
pub const fn modulus(&self) -> &BigUint {
|
||||
&self.product
|
||||
}
|
||||
/// Returns the product of the moduli used when creating the RNS context.
|
||||
pub const fn modulus(&self) -> &BigUint {
|
||||
&self.product
|
||||
}
|
||||
|
||||
/// Project a BigUint into its rests.
|
||||
pub fn project(&self, a: &BigUint) -> Vec<u64> {
|
||||
let mut rests = Vec::with_capacity(self.moduli_u64.len());
|
||||
for modulus in &self.moduli_u64 {
|
||||
rests.push((a % modulus).to_u64().unwrap())
|
||||
}
|
||||
rests
|
||||
}
|
||||
/// Project a BigUint into its rests.
|
||||
pub fn project(&self, a: &BigUint) -> Vec<u64> {
|
||||
let mut rests = Vec::with_capacity(self.moduli_u64.len());
|
||||
for modulus in &self.moduli_u64 {
|
||||
rests.push((a % modulus).to_u64().unwrap())
|
||||
}
|
||||
rests
|
||||
}
|
||||
|
||||
/// Lift rests into a BigUint.
|
||||
///
|
||||
/// Aborts if the number of rests is different than the number of moduli in
|
||||
/// debug mode.
|
||||
pub fn lift(&self, rests: ArrayView1<u64>) -> BigUint {
|
||||
let mut result = BigUint::zero();
|
||||
izip!(rests.iter(), self.garner.iter())
|
||||
.for_each(|(r_i, garner_i)| result += garner_i * *r_i);
|
||||
result % &self.product
|
||||
}
|
||||
/// Lift rests into a BigUint.
|
||||
///
|
||||
/// Aborts if the number of rests is different than the number of moduli in
|
||||
/// debug mode.
|
||||
pub fn lift(&self, rests: ArrayView1<u64>) -> BigUint {
|
||||
let mut result = BigUint::zero();
|
||||
izip!(rests.iter(), self.garner.iter())
|
||||
.for_each(|(r_i, garner_i)| result += garner_i * *r_i);
|
||||
result % &self.product
|
||||
}
|
||||
|
||||
/// Getter for the i-th garner coefficient.
|
||||
pub fn get_garner(&self, i: usize) -> Option<&BigUint> {
|
||||
self.garner.get(i)
|
||||
}
|
||||
/// Getter for the i-th garner coefficient.
|
||||
pub fn get_garner(&self, i: usize) -> Option<&BigUint> {
|
||||
self.garner.get(i)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::error::Error;
|
||||
use std::error::Error;
|
||||
|
||||
use super::RnsContext;
|
||||
use ndarray::ArrayView1;
|
||||
use num_bigint::BigUint;
|
||||
use rand::RngCore;
|
||||
use super::RnsContext;
|
||||
use ndarray::ArrayView1;
|
||||
use num_bigint::BigUint;
|
||||
use rand::RngCore;
|
||||
|
||||
#[test]
|
||||
fn constructor() {
|
||||
assert!(RnsContext::new(&[2]).is_ok());
|
||||
assert!(RnsContext::new(&[2, 3]).is_ok());
|
||||
assert!(RnsContext::new(&[4, 15, 1153]).is_ok());
|
||||
#[test]
|
||||
fn constructor() {
|
||||
assert!(RnsContext::new(&[2]).is_ok());
|
||||
assert!(RnsContext::new(&[2, 3]).is_ok());
|
||||
assert!(RnsContext::new(&[4, 15, 1153]).is_ok());
|
||||
|
||||
let e = RnsContext::new(&[]);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty");
|
||||
let e = RnsContext::new(&[2, 4]);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
|
||||
let e = RnsContext::new(&[2, 3, 5, 30]);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
|
||||
}
|
||||
let e = RnsContext::new(&[]);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty");
|
||||
let e = RnsContext::new(&[2, 4]);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
|
||||
let e = RnsContext::new(&[2, 3, 5, 30]);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn garner() -> Result<(), Box<dyn Error>> {
|
||||
let rns = RnsContext::new(&[4, 15, 1153])?;
|
||||
#[test]
|
||||
fn garner() -> Result<(), Box<dyn Error>> {
|
||||
let rns = RnsContext::new(&[4, 15, 1153])?;
|
||||
|
||||
for i in 0..3 {
|
||||
let gi = rns.get_garner(i);
|
||||
assert!(gi.is_some());
|
||||
assert_eq!(gi.unwrap(), &rns.garner[i]);
|
||||
}
|
||||
assert!(rns.get_garner(3).is_none());
|
||||
for i in 0..3 {
|
||||
let gi = rns.get_garner(i);
|
||||
assert!(gi.is_some());
|
||||
assert_eq!(gi.unwrap(), &rns.garner[i]);
|
||||
}
|
||||
assert!(rns.get_garner(3).is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn modulus() -> Result<(), Box<dyn Error>> {
|
||||
let mut rns = RnsContext::new(&[2])?;
|
||||
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64));
|
||||
#[test]
|
||||
fn modulus() -> Result<(), Box<dyn Error>> {
|
||||
let mut rns = RnsContext::new(&[2])?;
|
||||
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64));
|
||||
|
||||
rns = RnsContext::new(&[2, 5])?;
|
||||
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5));
|
||||
rns = RnsContext::new(&[2, 5])?;
|
||||
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5));
|
||||
|
||||
rns = RnsContext::new(&[4, 15, 1153])?;
|
||||
debug_assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153));
|
||||
rns = RnsContext::new(&[4, 15, 1153])?;
|
||||
debug_assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn project_lift() -> Result<(), Box<dyn Error>> {
|
||||
let ntests = 100;
|
||||
let rns = RnsContext::new(&[4, 15, 1153])?;
|
||||
let product = 4u64 * 15 * 1153;
|
||||
#[test]
|
||||
fn project_lift() -> Result<(), Box<dyn Error>> {
|
||||
let ntests = 100;
|
||||
let rns = RnsContext::new(&[4, 15, 1153])?;
|
||||
let product = 4u64 * 15 * 1153;
|
||||
|
||||
let mut rests = rns.project(&BigUint::from(0u64));
|
||||
assert_eq!(&rests, &[0u64, 0, 0]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64));
|
||||
let mut rests = rns.project(&BigUint::from(0u64));
|
||||
assert_eq!(&rests, &[0u64, 0, 0]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64));
|
||||
|
||||
rests = rns.project(&BigUint::from(4u64));
|
||||
assert_eq!(&rests, &[0u64, 4, 4]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64));
|
||||
rests = rns.project(&BigUint::from(4u64));
|
||||
assert_eq!(&rests, &[0u64, 4, 4]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64));
|
||||
|
||||
rests = rns.project(&BigUint::from(15u64));
|
||||
assert_eq!(&rests, &[3u64, 0, 15]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64));
|
||||
rests = rns.project(&BigUint::from(15u64));
|
||||
assert_eq!(&rests, &[3u64, 0, 15]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64));
|
||||
|
||||
rests = rns.project(&BigUint::from(1153u64));
|
||||
assert_eq!(&rests, &[1u64, 13, 0]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64));
|
||||
rests = rns.project(&BigUint::from(1153u64));
|
||||
assert_eq!(&rests, &[1u64, 13, 0]);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64));
|
||||
|
||||
rests = rns.project(&BigUint::from(product - 1));
|
||||
assert_eq!(&rests, &[3u64, 14, 1152]);
|
||||
assert_eq!(
|
||||
rns.lift(ArrayView1::from(&rests)),
|
||||
BigUint::from(product - 1)
|
||||
);
|
||||
rests = rns.project(&BigUint::from(product - 1));
|
||||
assert_eq!(&rests, &[3u64, 14, 1152]);
|
||||
assert_eq!(
|
||||
rns.lift(ArrayView1::from(&rests)),
|
||||
BigUint::from(product - 1)
|
||||
);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for _ in 0..ntests {
|
||||
let b = BigUint::from(rng.next_u64() % product);
|
||||
rests = rns.project(&b);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), b);
|
||||
}
|
||||
for _ in 0..ntests {
|
||||
let b = BigUint::from(rng.next_u64() % product);
|
||||
rests = rns.project(&b);
|
||||
assert_eq!(rns.lift(ArrayView1::from(&rests)), b);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,481 +14,481 @@ use std::{cmp::min, sync::Arc};
|
||||
/// Scaling factor when performing a RNS scaling.
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ScalingFactor {
|
||||
numerator: BigUint,
|
||||
denominator: BigUint,
|
||||
pub(crate) is_one: bool,
|
||||
numerator: BigUint,
|
||||
denominator: BigUint,
|
||||
pub(crate) is_one: bool,
|
||||
}
|
||||
|
||||
impl ScalingFactor {
|
||||
/// Create a new scaling factor. Aborts if the denominator is 0.
|
||||
pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self {
|
||||
assert_ne!(denominator, &BigUint::zero());
|
||||
Self {
|
||||
numerator: numerator.clone(),
|
||||
denominator: denominator.clone(),
|
||||
is_one: numerator == denominator,
|
||||
}
|
||||
}
|
||||
/// Create a new scaling factor. Aborts if the denominator is 0.
|
||||
pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self {
|
||||
assert_ne!(denominator, &BigUint::zero());
|
||||
Self {
|
||||
numerator: numerator.clone(),
|
||||
denominator: denominator.clone(),
|
||||
is_one: numerator == denominator,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the identity element of `Self`.
|
||||
pub fn one() -> Self {
|
||||
Self {
|
||||
numerator: BigUint::one(),
|
||||
denominator: BigUint::one(),
|
||||
is_one: true,
|
||||
}
|
||||
}
|
||||
/// Returns the identity element of `Self`.
|
||||
pub fn one() -> Self {
|
||||
Self {
|
||||
numerator: BigUint::one(),
|
||||
denominator: BigUint::one(),
|
||||
is_one: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scaler for a RNS context.
|
||||
/// This is a helper struct to perform RNS scaling.
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RnsScaler {
|
||||
from: Arc<RnsContext>,
|
||||
to: Arc<RnsContext>,
|
||||
scaling_factor: ScalingFactor,
|
||||
from: Arc<RnsContext>,
|
||||
to: Arc<RnsContext>,
|
||||
scaling_factor: ScalingFactor,
|
||||
|
||||
gamma: Box<[u64]>,
|
||||
gamma_shoup: Box<[u64]>,
|
||||
theta_gamma_lo: u64,
|
||||
theta_gamma_hi: u64,
|
||||
theta_gamma_sign: bool,
|
||||
gamma: Box<[u64]>,
|
||||
gamma_shoup: Box<[u64]>,
|
||||
theta_gamma_lo: u64,
|
||||
theta_gamma_hi: u64,
|
||||
theta_gamma_sign: bool,
|
||||
|
||||
omega: Box<[Box<[u64]>]>,
|
||||
omega_shoup: Box<[Box<[u64]>]>,
|
||||
theta_omega_lo: Box<[u64]>,
|
||||
theta_omega_hi: Box<[u64]>,
|
||||
theta_omega_sign: Box<[bool]>,
|
||||
omega: Box<[Box<[u64]>]>,
|
||||
omega_shoup: Box<[Box<[u64]>]>,
|
||||
theta_omega_lo: Box<[u64]>,
|
||||
theta_omega_hi: Box<[u64]>,
|
||||
theta_omega_sign: Box<[bool]>,
|
||||
|
||||
theta_garner_lo: Box<[u64]>,
|
||||
theta_garner_hi: Box<[u64]>,
|
||||
theta_garner_shift: usize,
|
||||
theta_garner_lo: Box<[u64]>,
|
||||
theta_garner_hi: Box<[u64]>,
|
||||
theta_garner_shift: usize,
|
||||
}
|
||||
|
||||
impl RnsScaler {
|
||||
/// Create a RNS scaler by numerator / denominator.
|
||||
///
|
||||
/// Aborts if denominator is equal to 0.
|
||||
pub fn new(
|
||||
from: &Arc<RnsContext>,
|
||||
to: &Arc<RnsContext>,
|
||||
scaling_factor: ScalingFactor,
|
||||
) -> Self {
|
||||
// Let's define gamma = round(numerator * from.product / denominator)
|
||||
let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
|
||||
Self::extract_projection_and_theta(
|
||||
to,
|
||||
&from.product,
|
||||
&scaling_factor.numerator,
|
||||
&scaling_factor.denominator,
|
||||
false,
|
||||
);
|
||||
let gamma_shoup = izip!(&gamma, &to.moduli)
|
||||
.map(|(wi, q)| q.shoup(*wi))
|
||||
.collect_vec();
|
||||
/// Create a RNS scaler by numerator / denominator.
|
||||
///
|
||||
/// Aborts if denominator is equal to 0.
|
||||
pub fn new(
|
||||
from: &Arc<RnsContext>,
|
||||
to: &Arc<RnsContext>,
|
||||
scaling_factor: ScalingFactor,
|
||||
) -> Self {
|
||||
// Let's define gamma = round(numerator * from.product / denominator)
|
||||
let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
|
||||
Self::extract_projection_and_theta(
|
||||
to,
|
||||
&from.product,
|
||||
&scaling_factor.numerator,
|
||||
&scaling_factor.denominator,
|
||||
false,
|
||||
);
|
||||
let gamma_shoup = izip!(&gamma, &to.moduli)
|
||||
.map(|(wi, q)| q.shoup(*wi))
|
||||
.collect_vec();
|
||||
|
||||
// Let's define omega_i = round(from.garner_i * numerator / denominator)
|
||||
let mut omega = Vec::with_capacity(to.moduli.len());
|
||||
let mut omega_shoup = Vec::with_capacity(to.moduli.len());
|
||||
for _ in &to.moduli {
|
||||
omega.push(vec![0u64; from.moduli.len()].into_boxed_slice());
|
||||
omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice());
|
||||
}
|
||||
let mut theta_omega_lo = Vec::with_capacity(from.garner.len());
|
||||
let mut theta_omega_hi = Vec::with_capacity(from.garner.len());
|
||||
let mut theta_omega_sign = Vec::with_capacity(from.garner.len());
|
||||
for i in 0..from.garner.len() {
|
||||
let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) =
|
||||
Self::extract_projection_and_theta(
|
||||
to,
|
||||
&from.garner[i],
|
||||
&scaling_factor.numerator,
|
||||
&scaling_factor.denominator,
|
||||
true,
|
||||
);
|
||||
for j in 0..to.moduli.len() {
|
||||
let qj = &to.moduli[j];
|
||||
omega[j][i] = qj.reduce(omega_i[j]);
|
||||
omega_shoup[j][i] = qj.shoup(omega[j][i]);
|
||||
}
|
||||
theta_omega_lo.push(theta_omega_i_lo);
|
||||
theta_omega_hi.push(theta_omega_i_hi);
|
||||
theta_omega_sign.push(theta_omega_i_sign);
|
||||
}
|
||||
// Let's define omega_i = round(from.garner_i * numerator / denominator)
|
||||
let mut omega = Vec::with_capacity(to.moduli.len());
|
||||
let mut omega_shoup = Vec::with_capacity(to.moduli.len());
|
||||
for _ in &to.moduli {
|
||||
omega.push(vec![0u64; from.moduli.len()].into_boxed_slice());
|
||||
omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice());
|
||||
}
|
||||
let mut theta_omega_lo = Vec::with_capacity(from.garner.len());
|
||||
let mut theta_omega_hi = Vec::with_capacity(from.garner.len());
|
||||
let mut theta_omega_sign = Vec::with_capacity(from.garner.len());
|
||||
for i in 0..from.garner.len() {
|
||||
let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) =
|
||||
Self::extract_projection_and_theta(
|
||||
to,
|
||||
&from.garner[i],
|
||||
&scaling_factor.numerator,
|
||||
&scaling_factor.denominator,
|
||||
true,
|
||||
);
|
||||
for j in 0..to.moduli.len() {
|
||||
let qj = &to.moduli[j];
|
||||
omega[j][i] = qj.reduce(omega_i[j]);
|
||||
omega_shoup[j][i] = qj.shoup(omega[j][i]);
|
||||
}
|
||||
theta_omega_lo.push(theta_omega_i_lo);
|
||||
theta_omega_hi.push(theta_omega_i_hi);
|
||||
theta_omega_sign.push(theta_omega_i_sign);
|
||||
}
|
||||
|
||||
// Determine the shift so that the sum of the scaled theta_garner fit on an U192
|
||||
// (shift + 1) + log(q * n) <= 192
|
||||
let theta_garner_shift = min(
|
||||
from.moduli_u64
|
||||
.iter()
|
||||
.map(|qi| {
|
||||
192 - 1
|
||||
- ilog2(
|
||||
((*qi as u128) * (from.moduli_u64.len() as u128)).next_power_of_two(),
|
||||
)
|
||||
})
|
||||
.min()
|
||||
.unwrap(),
|
||||
127,
|
||||
);
|
||||
// Finally, define theta_garner_i = from.garner_i / product, also scaled by
|
||||
// 2^127.
|
||||
let mut theta_garner_lo = Vec::with_capacity(from.garner.len());
|
||||
let mut theta_garner_hi = Vec::with_capacity(from.garner.len());
|
||||
for garner_i in &from.garner {
|
||||
let mut theta: BigUint =
|
||||
((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
|
||||
let theta_hi: BigUint = &theta >> 64;
|
||||
theta -= &theta_hi << 64;
|
||||
theta_garner_lo.push(theta.to_u64().unwrap());
|
||||
theta_garner_hi.push(theta_hi.to_u64().unwrap());
|
||||
}
|
||||
// Determine the shift so that the sum of the scaled theta_garner fit on an U192
|
||||
// (shift + 1) + log(q * n) <= 192
|
||||
let theta_garner_shift = min(
|
||||
from.moduli_u64
|
||||
.iter()
|
||||
.map(|qi| {
|
||||
192 - 1
|
||||
- ilog2(
|
||||
((*qi as u128) * (from.moduli_u64.len() as u128)).next_power_of_two(),
|
||||
)
|
||||
})
|
||||
.min()
|
||||
.unwrap(),
|
||||
127,
|
||||
);
|
||||
// Finally, define theta_garner_i = from.garner_i / product, also scaled by
|
||||
// 2^127.
|
||||
let mut theta_garner_lo = Vec::with_capacity(from.garner.len());
|
||||
let mut theta_garner_hi = Vec::with_capacity(from.garner.len());
|
||||
for garner_i in &from.garner {
|
||||
let mut theta: BigUint =
|
||||
((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
|
||||
let theta_hi: BigUint = &theta >> 64;
|
||||
theta -= &theta_hi << 64;
|
||||
theta_garner_lo.push(theta.to_u64().unwrap());
|
||||
theta_garner_hi.push(theta_hi.to_u64().unwrap());
|
||||
}
|
||||
|
||||
Self {
|
||||
from: from.clone(),
|
||||
to: to.clone(),
|
||||
scaling_factor,
|
||||
gamma: gamma.into_boxed_slice(),
|
||||
gamma_shoup: gamma_shoup.into_boxed_slice(),
|
||||
theta_gamma_lo,
|
||||
theta_gamma_hi,
|
||||
theta_gamma_sign,
|
||||
omega: omega.into_boxed_slice(),
|
||||
omega_shoup: omega_shoup.into_boxed_slice(),
|
||||
theta_omega_lo: theta_omega_lo.into_boxed_slice(),
|
||||
theta_omega_hi: theta_omega_hi.into_boxed_slice(),
|
||||
theta_omega_sign: theta_omega_sign.into_boxed_slice(),
|
||||
theta_garner_lo: theta_garner_lo.into_boxed_slice(),
|
||||
theta_garner_hi: theta_garner_hi.into_boxed_slice(),
|
||||
theta_garner_shift,
|
||||
}
|
||||
}
|
||||
Self {
|
||||
from: from.clone(),
|
||||
to: to.clone(),
|
||||
scaling_factor,
|
||||
gamma: gamma.into_boxed_slice(),
|
||||
gamma_shoup: gamma_shoup.into_boxed_slice(),
|
||||
theta_gamma_lo,
|
||||
theta_gamma_hi,
|
||||
theta_gamma_sign,
|
||||
omega: omega.into_boxed_slice(),
|
||||
omega_shoup: omega_shoup.into_boxed_slice(),
|
||||
theta_omega_lo: theta_omega_lo.into_boxed_slice(),
|
||||
theta_omega_hi: theta_omega_hi.into_boxed_slice(),
|
||||
theta_omega_sign: theta_omega_sign.into_boxed_slice(),
|
||||
theta_garner_lo: theta_garner_lo.into_boxed_slice(),
|
||||
theta_garner_hi: theta_garner_hi.into_boxed_slice(),
|
||||
theta_garner_shift,
|
||||
}
|
||||
}
|
||||
|
||||
// Let's define gamma = round(numerator * input / denominator)
|
||||
// and theta_gamma such that theta_gamma = numerator * input / denominator -
|
||||
// gamma. This function projects gamma in the RNS context, and scales
|
||||
// theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the
|
||||
// RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma =
|
||||
// (-1)**theta_sign * (theta_lo + 2^64 * theta_hi).
|
||||
fn extract_projection_and_theta(
|
||||
ctx: &RnsContext,
|
||||
input: &BigUint,
|
||||
numerator: &BigUint,
|
||||
denominator: &BigUint,
|
||||
round_up: bool,
|
||||
) -> (Vec<u64>, u64, u64, bool) {
|
||||
let gamma = (numerator * input + (denominator >> 1)) / denominator;
|
||||
let projected = ctx.project(&gamma);
|
||||
// Let's define gamma = round(numerator * input / denominator)
|
||||
// and theta_gamma such that theta_gamma = numerator * input / denominator -
|
||||
// gamma. This function projects gamma in the RNS context, and scales
|
||||
// theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the
|
||||
// RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma =
|
||||
// (-1)**theta_sign * (theta_lo + 2^64 * theta_hi).
|
||||
fn extract_projection_and_theta(
|
||||
ctx: &RnsContext,
|
||||
input: &BigUint,
|
||||
numerator: &BigUint,
|
||||
denominator: &BigUint,
|
||||
round_up: bool,
|
||||
) -> (Vec<u64>, u64, u64, bool) {
|
||||
let gamma = (numerator * input + (denominator >> 1)) / denominator;
|
||||
let projected = ctx.project(&gamma);
|
||||
|
||||
let mut theta = (numerator * input) % denominator;
|
||||
let mut theta_sign = false;
|
||||
if denominator > &BigUint::one() {
|
||||
// If denominator is odd, flip theta if theta > (denominator >> 1)
|
||||
if denominator & BigUint::one() == BigUint::one() {
|
||||
if theta > (denominator >> 1) {
|
||||
theta_sign = true;
|
||||
theta = denominator - theta;
|
||||
}
|
||||
} else {
|
||||
// denominator is even, flip if theta >= (denominator >> 1)
|
||||
if theta >= (denominator >> 1) {
|
||||
theta_sign = true;
|
||||
theta = denominator - theta;
|
||||
}
|
||||
}
|
||||
}
|
||||
// theta = ((theta << 127) + (denominator >> 1)) / denominator;
|
||||
// We can now split theta into two u64 words.
|
||||
if round_up {
|
||||
if theta_sign {
|
||||
theta = (theta << 127) / denominator;
|
||||
} else {
|
||||
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
|
||||
}
|
||||
} else if theta_sign {
|
||||
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
|
||||
} else {
|
||||
theta = (theta << 127) / denominator;
|
||||
}
|
||||
let theta_hi_biguint: BigUint = &theta >> 64;
|
||||
theta -= &theta_hi_biguint << 64;
|
||||
let theta_lo = theta.to_u64().unwrap();
|
||||
let theta_hi = theta_hi_biguint.to_u64().unwrap();
|
||||
let mut theta = (numerator * input) % denominator;
|
||||
let mut theta_sign = false;
|
||||
if denominator > &BigUint::one() {
|
||||
// If denominator is odd, flip theta if theta > (denominator >> 1)
|
||||
if denominator & BigUint::one() == BigUint::one() {
|
||||
if theta > (denominator >> 1) {
|
||||
theta_sign = true;
|
||||
theta = denominator - theta;
|
||||
}
|
||||
} else {
|
||||
// denominator is even, flip if theta >= (denominator >> 1)
|
||||
if theta >= (denominator >> 1) {
|
||||
theta_sign = true;
|
||||
theta = denominator - theta;
|
||||
}
|
||||
}
|
||||
}
|
||||
// theta = ((theta << 127) + (denominator >> 1)) / denominator;
|
||||
// We can now split theta into two u64 words.
|
||||
if round_up {
|
||||
if theta_sign {
|
||||
theta = (theta << 127) / denominator;
|
||||
} else {
|
||||
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
|
||||
}
|
||||
} else if theta_sign {
|
||||
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
|
||||
} else {
|
||||
theta = (theta << 127) / denominator;
|
||||
}
|
||||
let theta_hi_biguint: BigUint = &theta >> 64;
|
||||
theta -= &theta_hi_biguint << 64;
|
||||
let theta_lo = theta.to_u64().unwrap();
|
||||
let theta_hi = theta_hi_biguint.to_u64().unwrap();
|
||||
|
||||
(projected, theta_lo, theta_hi, theta_sign)
|
||||
}
|
||||
(projected, theta_lo, theta_hi, theta_sign)
|
||||
}
|
||||
|
||||
/// Output the RNS representation of the rests scaled by numerator *
|
||||
/// denominator, and either rounded or floored.
|
||||
///
|
||||
/// Aborts if the number of rests is different than the number of moduli in
|
||||
/// debug mode, or if the size is not in [1, ..., rests.len()].
|
||||
pub fn scale_new(&self, rests: ArrayView1<u64>, size: usize) -> Vec<u64> {
|
||||
let mut out = vec![0; size];
|
||||
self.scale(rests, (&mut out).into(), 0);
|
||||
out
|
||||
}
|
||||
/// Output the RNS representation of the rests scaled by numerator *
|
||||
/// denominator, and either rounded or floored.
|
||||
///
|
||||
/// Aborts if the number of rests is different than the number of moduli in
|
||||
/// debug mode, or if the size is not in [1, ..., rests.len()].
|
||||
pub fn scale_new(&self, rests: ArrayView1<u64>, size: usize) -> Vec<u64> {
|
||||
let mut out = vec![0; size];
|
||||
self.scale(rests, (&mut out).into(), 0);
|
||||
out
|
||||
}
|
||||
|
||||
/// Compute the RNS representation of the rests scaled by numerator *
|
||||
/// denominator, and either rounded or floored, and store the result in
|
||||
/// `out`.
|
||||
///
|
||||
/// Aborts if the number of rests is different than the number of moduli in
|
||||
/// debug mode, or if the size of out is not in [1, ..., rests.len()].
|
||||
pub fn scale(
|
||||
&self,
|
||||
rests: ArrayView1<u64>,
|
||||
mut out: ArrayViewMut1<u64>,
|
||||
starting_index: usize,
|
||||
) {
|
||||
debug_assert_eq!(rests.len(), self.from.moduli_u64.len());
|
||||
debug_assert!(!out.is_empty());
|
||||
debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());
|
||||
/// Compute the RNS representation of the rests scaled by numerator *
|
||||
/// denominator, and either rounded or floored, and store the result in
|
||||
/// `out`.
|
||||
///
|
||||
/// Aborts if the number of rests is different than the number of moduli in
|
||||
/// debug mode, or if the size of out is not in [1, ..., rests.len()].
|
||||
pub fn scale(
|
||||
&self,
|
||||
rests: ArrayView1<u64>,
|
||||
mut out: ArrayViewMut1<u64>,
|
||||
starting_index: usize,
|
||||
) {
|
||||
debug_assert_eq!(rests.len(), self.from.moduli_u64.len());
|
||||
debug_assert!(!out.is_empty());
|
||||
debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());
|
||||
|
||||
// First, let's compute the inner product of the rests with theta_omega.
|
||||
let mut sum_theta_garner = U192::ZERO;
|
||||
for (thetag_lo, thetag_hi, ri) in izip!(
|
||||
self.theta_garner_lo.iter(),
|
||||
self.theta_garner_hi.iter(),
|
||||
rests
|
||||
) {
|
||||
let lo = (*ri as u128) * (*thetag_lo as u128);
|
||||
let hi = (*ri as u128) * (*thetag_hi as u128) + (lo >> 64);
|
||||
sum_theta_garner = sum_theta_garner.wrapping_add(&U192::from_words([
|
||||
lo as u64,
|
||||
hi as u64,
|
||||
(hi >> 64) as u64,
|
||||
]));
|
||||
}
|
||||
// Let's compute v = round(sum_theta_garner / 2^theta_garner_shift)
|
||||
sum_theta_garner >>= self.theta_garner_shift - 1;
|
||||
let v = <[u64; 3]>::from(sum_theta_garner);
|
||||
let v = div_ceil((v[0] as u128) | ((v[1] as u128) << 64), 2);
|
||||
// First, let's compute the inner product of the rests with theta_omega.
|
||||
let mut sum_theta_garner = U192::ZERO;
|
||||
for (thetag_lo, thetag_hi, ri) in izip!(
|
||||
self.theta_garner_lo.iter(),
|
||||
self.theta_garner_hi.iter(),
|
||||
rests
|
||||
) {
|
||||
let lo = (*ri as u128) * (*thetag_lo as u128);
|
||||
let hi = (*ri as u128) * (*thetag_hi as u128) + (lo >> 64);
|
||||
sum_theta_garner = sum_theta_garner.wrapping_add(&U192::from_words([
|
||||
lo as u64,
|
||||
hi as u64,
|
||||
(hi >> 64) as u64,
|
||||
]));
|
||||
}
|
||||
// Let's compute v = round(sum_theta_garner / 2^theta_garner_shift)
|
||||
sum_theta_garner >>= self.theta_garner_shift - 1;
|
||||
let v = <[u64; 3]>::from(sum_theta_garner);
|
||||
let v = div_ceil((v[0] as u128) | ((v[1] as u128) << 64), 2);
|
||||
|
||||
// If the scaling factor is not 1, compute the inner product with the
|
||||
// theta_omega
|
||||
let mut w_sign = false;
|
||||
let mut w = 0u128;
|
||||
if !self.scaling_factor.is_one {
|
||||
let mut sum_theta_omega = U256::zero();
|
||||
for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
|
||||
self.theta_omega_lo.iter(),
|
||||
self.theta_omega_hi.iter(),
|
||||
self.theta_omega_sign.iter(),
|
||||
rests
|
||||
) {
|
||||
let lo = (*ri as u128) * (*thetao_lo as u128);
|
||||
let hi = (*ri as u128) * (*thetao_hi as u128) + (lo >> 64);
|
||||
if *thetao_sign {
|
||||
sum_theta_omega.wrapping_sub_assign(U256::from([
|
||||
lo as u64,
|
||||
hi as u64,
|
||||
(hi >> 64) as u64,
|
||||
0,
|
||||
]));
|
||||
} else {
|
||||
sum_theta_omega.wrapping_add_assign(U256::from([
|
||||
lo as u64,
|
||||
hi as u64,
|
||||
(hi >> 64) as u64,
|
||||
0,
|
||||
]));
|
||||
}
|
||||
}
|
||||
// If the scaling factor is not 1, compute the inner product with the
|
||||
// theta_omega
|
||||
let mut w_sign = false;
|
||||
let mut w = 0u128;
|
||||
if !self.scaling_factor.is_one {
|
||||
let mut sum_theta_omega = U256::zero();
|
||||
for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
|
||||
self.theta_omega_lo.iter(),
|
||||
self.theta_omega_hi.iter(),
|
||||
self.theta_omega_sign.iter(),
|
||||
rests
|
||||
) {
|
||||
let lo = (*ri as u128) * (*thetao_lo as u128);
|
||||
let hi = (*ri as u128) * (*thetao_hi as u128) + (lo >> 64);
|
||||
if *thetao_sign {
|
||||
sum_theta_omega.wrapping_sub_assign(U256::from([
|
||||
lo as u64,
|
||||
hi as u64,
|
||||
(hi >> 64) as u64,
|
||||
0,
|
||||
]));
|
||||
} else {
|
||||
sum_theta_omega.wrapping_add_assign(U256::from([
|
||||
lo as u64,
|
||||
hi as u64,
|
||||
(hi >> 64) as u64,
|
||||
0,
|
||||
]));
|
||||
}
|
||||
}
|
||||
|
||||
// Let's subtract v * theta_gamma to sum_theta_omega.
|
||||
let vt_lo_lo = ((v as u64) as u128) * (self.theta_gamma_lo as u128);
|
||||
let vt_lo_hi = ((v as u64) as u128) * (self.theta_gamma_hi as u128);
|
||||
let vt_hi_lo = (v >> 64) * (self.theta_gamma_lo as u128);
|
||||
let vt_hi_hi = (v >> 64) * (self.theta_gamma_hi as u128);
|
||||
let vt_mi =
|
||||
(vt_lo_lo >> 64) + ((vt_lo_hi as u64) as u128) + ((vt_hi_lo as u64) as u128);
|
||||
let vt_hi = (vt_lo_hi >> 64) + (vt_mi >> 64) + ((vt_hi_hi as u64) as u128);
|
||||
if self.theta_gamma_sign {
|
||||
sum_theta_omega.wrapping_add_assign(U256::from([
|
||||
vt_lo_lo as u64,
|
||||
vt_mi as u64,
|
||||
vt_hi as u64,
|
||||
0,
|
||||
]))
|
||||
} else {
|
||||
sum_theta_omega.wrapping_sub_assign(U256::from([
|
||||
vt_lo_lo as u64,
|
||||
vt_mi as u64,
|
||||
vt_hi as u64,
|
||||
0,
|
||||
]))
|
||||
}
|
||||
// Let's subtract v * theta_gamma to sum_theta_omega.
|
||||
let vt_lo_lo = ((v as u64) as u128) * (self.theta_gamma_lo as u128);
|
||||
let vt_lo_hi = ((v as u64) as u128) * (self.theta_gamma_hi as u128);
|
||||
let vt_hi_lo = (v >> 64) * (self.theta_gamma_lo as u128);
|
||||
let vt_hi_hi = (v >> 64) * (self.theta_gamma_hi as u128);
|
||||
let vt_mi =
|
||||
(vt_lo_lo >> 64) + ((vt_lo_hi as u64) as u128) + ((vt_hi_lo as u64) as u128);
|
||||
let vt_hi = (vt_lo_hi >> 64) + (vt_mi >> 64) + ((vt_hi_hi as u64) as u128);
|
||||
if self.theta_gamma_sign {
|
||||
sum_theta_omega.wrapping_add_assign(U256::from([
|
||||
vt_lo_lo as u64,
|
||||
vt_mi as u64,
|
||||
vt_hi as u64,
|
||||
0,
|
||||
]))
|
||||
} else {
|
||||
sum_theta_omega.wrapping_sub_assign(U256::from([
|
||||
vt_lo_lo as u64,
|
||||
vt_mi as u64,
|
||||
vt_hi as u64,
|
||||
0,
|
||||
]))
|
||||
}
|
||||
|
||||
// Let's compute w = round(sum_theta_omega / 2^127).
|
||||
w_sign = sum_theta_omega.msb() > 0;
|
||||
// Let's compute w = round(sum_theta_omega / 2^127).
|
||||
w_sign = sum_theta_omega.msb() > 0;
|
||||
|
||||
if w_sign {
|
||||
w = u128::from(&((!sum_theta_omega) >> 126)) + 1;
|
||||
w /= 2;
|
||||
} else {
|
||||
w = u128::from(&(sum_theta_omega >> 126));
|
||||
w = div_ceil(w, 2)
|
||||
}
|
||||
}
|
||||
if w_sign {
|
||||
w = u128::from(&((!sum_theta_omega) >> 126)) + 1;
|
||||
w /= 2;
|
||||
} else {
|
||||
w = u128::from(&(sum_theta_omega >> 126));
|
||||
w = div_ceil(w, 2)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
for i in 0..out.len() {
|
||||
debug_assert!(starting_index + i < self.to.moduli.len());
|
||||
debug_assert!(starting_index + i < self.omega.len());
|
||||
debug_assert!(starting_index + i < self.omega_shoup.len());
|
||||
debug_assert!(starting_index + i < self.gamma.len());
|
||||
debug_assert!(starting_index + i < self.gamma_shoup.len());
|
||||
let out_i = out.get_mut(i).unwrap();
|
||||
let qi = self.to.moduli.get_unchecked(starting_index + i);
|
||||
let omega_i = self.omega.get_unchecked(starting_index + i);
|
||||
let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i);
|
||||
let gamma_i = self.gamma.get_unchecked(starting_index + i);
|
||||
let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i);
|
||||
unsafe {
|
||||
for i in 0..out.len() {
|
||||
debug_assert!(starting_index + i < self.to.moduli.len());
|
||||
debug_assert!(starting_index + i < self.omega.len());
|
||||
debug_assert!(starting_index + i < self.omega_shoup.len());
|
||||
debug_assert!(starting_index + i < self.gamma.len());
|
||||
debug_assert!(starting_index + i < self.gamma_shoup.len());
|
||||
let out_i = out.get_mut(i).unwrap();
|
||||
let qi = self.to.moduli.get_unchecked(starting_index + i);
|
||||
let omega_i = self.omega.get_unchecked(starting_index + i);
|
||||
let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i);
|
||||
let gamma_i = self.gamma.get_unchecked(starting_index + i);
|
||||
let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i);
|
||||
|
||||
let mut yi = (qi.modulus() * 2
|
||||
- qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i))
|
||||
as u128;
|
||||
let mut yi = (qi.modulus() * 2
|
||||
- qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i))
|
||||
as u128;
|
||||
|
||||
if !self.scaling_factor.is_one {
|
||||
let wi = qi.lazy_reduce_u128(w);
|
||||
yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128;
|
||||
}
|
||||
if !self.scaling_factor.is_one {
|
||||
let wi = qi.lazy_reduce_u128(w);
|
||||
yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128;
|
||||
}
|
||||
|
||||
debug_assert!(rests.len() <= omega_i.len());
|
||||
debug_assert!(rests.len() <= omega_shoup_i.len());
|
||||
for j in 0..rests.len() {
|
||||
yi += qi.lazy_mul_shoup(
|
||||
*rests.get(j).unwrap(),
|
||||
*omega_i.get_unchecked(j),
|
||||
*omega_shoup_i.get_unchecked(j),
|
||||
) as u128;
|
||||
}
|
||||
debug_assert!(rests.len() <= omega_i.len());
|
||||
debug_assert!(rests.len() <= omega_shoup_i.len());
|
||||
for j in 0..rests.len() {
|
||||
yi += qi.lazy_mul_shoup(
|
||||
*rests.get(j).unwrap(),
|
||||
*omega_i.get_unchecked(j),
|
||||
*omega_shoup_i.get_unchecked(j),
|
||||
) as u128;
|
||||
}
|
||||
|
||||
*out_i = qi.reduce_u128(yi)
|
||||
}
|
||||
}
|
||||
}
|
||||
*out_i = qi.reduce_u128(yi)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{error::Error, sync::Arc};
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
use super::RnsScaler;
|
||||
use crate::rns::{scaler::ScalingFactor, RnsContext};
|
||||
use fhe_util::catch_unwind;
|
||||
use ndarray::ArrayView1;
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::{ToPrimitive, Zero};
|
||||
use rand::{thread_rng, RngCore};
|
||||
use super::RnsScaler;
|
||||
use crate::rns::{scaler::ScalingFactor, RnsContext};
|
||||
use fhe_util::catch_unwind;
|
||||
use ndarray::ArrayView1;
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::{ToPrimitive, Zero};
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
#[test]
|
||||
fn constructor() -> Result<(), Box<dyn Error>> {
|
||||
let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?);
|
||||
#[test]
|
||||
fn constructor() -> Result<(), Box<dyn Error>> {
|
||||
let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?);
|
||||
|
||||
let scaler = RnsScaler::new(&q, &q, ScalingFactor::one());
|
||||
assert_eq!(scaler.from, q);
|
||||
let scaler = RnsScaler::new(&q, &q, ScalingFactor::one());
|
||||
assert_eq!(scaler.from, q);
|
||||
|
||||
assert!(
|
||||
catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
assert!(
|
||||
catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scale_same_context() -> Result<(), Box<dyn Error>> {
|
||||
let ntests = 1000;
|
||||
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
|
||||
let mut rng = thread_rng();
|
||||
#[test]
|
||||
fn scale_same_context() -> Result<(), Box<dyn Error>> {
|
||||
let ntests = 1000;
|
||||
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
|
||||
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
|
||||
let n = BigUint::from(*numerator);
|
||||
let d = BigUint::from(*denominator);
|
||||
let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d));
|
||||
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
|
||||
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
|
||||
let n = BigUint::from(*numerator);
|
||||
let d = BigUint::from(*denominator);
|
||||
let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d));
|
||||
|
||||
for _ in 0..ntests {
|
||||
let x = vec![
|
||||
rng.next_u64() % q.moduli_u64[0],
|
||||
rng.next_u64() % q.moduli_u64[1],
|
||||
rng.next_u64() % q.moduli_u64[2],
|
||||
];
|
||||
let mut x_lift = q.lift(ArrayView1::from(&x));
|
||||
let x_sign = x_lift >= (q.modulus() >> 1);
|
||||
if x_sign {
|
||||
x_lift = q.modulus() - x_lift;
|
||||
}
|
||||
for _ in 0..ntests {
|
||||
let x = vec![
|
||||
rng.next_u64() % q.moduli_u64[0],
|
||||
rng.next_u64() % q.moduli_u64[1],
|
||||
rng.next_u64() % q.moduli_u64[2],
|
||||
];
|
||||
let mut x_lift = q.lift(ArrayView1::from(&x));
|
||||
let x_sign = x_lift >= (q.modulus() >> 1);
|
||||
if x_sign {
|
||||
x_lift = q.modulus() - x_lift;
|
||||
}
|
||||
|
||||
let z = scaler.scale_new((&x).into(), x.len());
|
||||
let x_scaled_round = if x_sign {
|
||||
if d.to_u64().unwrap() % 2 == 0 {
|
||||
q.modulus()
|
||||
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % q.modulus()
|
||||
} else {
|
||||
q.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % q.modulus()
|
||||
}
|
||||
} else {
|
||||
&(&x_lift * &n + (&d >> 1)) / &d
|
||||
};
|
||||
assert_eq!(z, q.project(&x_scaled_round));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
let z = scaler.scale_new((&x).into(), x.len());
|
||||
let x_scaled_round = if x_sign {
|
||||
if d.to_u64().unwrap() % 2 == 0 {
|
||||
q.modulus()
|
||||
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % q.modulus()
|
||||
} else {
|
||||
q.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % q.modulus()
|
||||
}
|
||||
} else {
|
||||
&(&x_lift * &n + (&d >> 1)) / &d
|
||||
};
|
||||
assert_eq!(z, q.project(&x_scaled_round));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scale_different_contexts() -> Result<(), Box<dyn Error>> {
|
||||
let ntests = 100;
|
||||
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
|
||||
let r = Arc::new(RnsContext::new(&[
|
||||
4u64,
|
||||
4611686018326724609,
|
||||
1153,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
4611686018257518593,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
4611686018106523649,
|
||||
4611686018058289153,
|
||||
])?);
|
||||
let mut rng = thread_rng();
|
||||
#[test]
|
||||
fn scale_different_contexts() -> Result<(), Box<dyn Error>> {
|
||||
let ntests = 100;
|
||||
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
|
||||
let r = Arc::new(RnsContext::new(&[
|
||||
4u64,
|
||||
4611686018326724609,
|
||||
1153,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
4611686018257518593,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
4611686018106523649,
|
||||
4611686018058289153,
|
||||
])?);
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
|
||||
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
|
||||
let n = BigUint::from(*numerator);
|
||||
let d = BigUint::from(*denominator);
|
||||
let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d));
|
||||
for _ in 0..ntests {
|
||||
let x = vec![
|
||||
rng.next_u64() % q.moduli_u64[0],
|
||||
rng.next_u64() % q.moduli_u64[1],
|
||||
rng.next_u64() % q.moduli_u64[2],
|
||||
];
|
||||
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
|
||||
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
|
||||
let n = BigUint::from(*numerator);
|
||||
let d = BigUint::from(*denominator);
|
||||
let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d));
|
||||
for _ in 0..ntests {
|
||||
let x = vec![
|
||||
rng.next_u64() % q.moduli_u64[0],
|
||||
rng.next_u64() % q.moduli_u64[1],
|
||||
rng.next_u64() % q.moduli_u64[2],
|
||||
];
|
||||
|
||||
let mut x_lift = q.lift(ArrayView1::from(&x));
|
||||
let x_sign = x_lift >= (q.modulus() >> 1);
|
||||
if x_sign {
|
||||
x_lift = q.modulus() - x_lift;
|
||||
}
|
||||
let mut x_lift = q.lift(ArrayView1::from(&x));
|
||||
let x_sign = x_lift >= (q.modulus() >> 1);
|
||||
if x_sign {
|
||||
x_lift = q.modulus() - x_lift;
|
||||
}
|
||||
|
||||
let y = scaler.scale_new((&x).into(), r.moduli.len());
|
||||
let x_scaled_round = if x_sign {
|
||||
if d.to_u64().unwrap() % 2 == 0 {
|
||||
r.modulus()
|
||||
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % r.modulus()
|
||||
} else {
|
||||
r.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % r.modulus()
|
||||
}
|
||||
} else {
|
||||
&(&x_lift * &n + (&d >> 1)) / &d
|
||||
};
|
||||
assert_eq!(y, r.project(&x_scaled_round));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
let y = scaler.scale_new((&x).into(), r.moduli.len());
|
||||
let x_scaled_round = if x_sign {
|
||||
if d.to_u64().unwrap() % 2 == 0 {
|
||||
r.modulus()
|
||||
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % r.modulus()
|
||||
} else {
|
||||
r.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % r.modulus()
|
||||
}
|
||||
} else {
|
||||
&(&x_lift * &n + (&d >> 1)) / &d
|
||||
};
|
||||
assert_eq!(y, r.project(&x_scaled_round));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,233 +3,233 @@ use num_bigint::BigUint;
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
rns::RnsContext,
|
||||
zq::{ntt::NttOperator, Modulus},
|
||||
Error, Result,
|
||||
rns::RnsContext,
|
||||
zq::{ntt::NttOperator, Modulus},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
/// Struct that holds the context associated with elements in rq.
|
||||
#[derive(Default, Clone, PartialEq, Eq)]
|
||||
pub struct Context {
|
||||
pub(crate) moduli: Box<[u64]>,
|
||||
pub(crate) q: Box<[Modulus]>,
|
||||
pub(crate) rns: Arc<RnsContext>,
|
||||
pub(crate) ops: Box<[NttOperator]>,
|
||||
pub(crate) degree: usize,
|
||||
pub(crate) bitrev: Box<[usize]>,
|
||||
pub(crate) inv_last_qi_mod_qj: Box<[u64]>,
|
||||
pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>,
|
||||
pub(crate) next_context: Option<Arc<Context>>,
|
||||
pub(crate) moduli: Box<[u64]>,
|
||||
pub(crate) q: Box<[Modulus]>,
|
||||
pub(crate) rns: Arc<RnsContext>,
|
||||
pub(crate) ops: Box<[NttOperator]>,
|
||||
pub(crate) degree: usize,
|
||||
pub(crate) bitrev: Box<[usize]>,
|
||||
pub(crate) inv_last_qi_mod_qj: Box<[u64]>,
|
||||
pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>,
|
||||
pub(crate) next_context: Option<Arc<Context>>,
|
||||
}
|
||||
|
||||
impl Debug for Context {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Context")
|
||||
.field("moduli", &self.moduli)
|
||||
// .field("q", &self.q)
|
||||
// .field("rns", &self.rns)
|
||||
// .field("ops", &self.ops)
|
||||
// .field("degree", &self.degree)
|
||||
// .field("bitrev", &self.bitrev)
|
||||
// .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj)
|
||||
// .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup)
|
||||
.field("next_context", &self.next_context)
|
||||
.finish()
|
||||
}
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Context")
|
||||
.field("moduli", &self.moduli)
|
||||
// .field("q", &self.q)
|
||||
// .field("rns", &self.rns)
|
||||
// .field("ops", &self.ops)
|
||||
// .field("degree", &self.degree)
|
||||
// .field("bitrev", &self.bitrev)
|
||||
// .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj)
|
||||
// .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup)
|
||||
.field("next_context", &self.next_context)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Creates a context from a list of moduli and a polynomial degree.
|
||||
///
|
||||
/// Returns an error if the moduli are not primes less than 62 bits which
|
||||
/// supports the NTT of size `degree`.
|
||||
pub fn new(moduli: &[u64], degree: usize) -> Result<Self> {
|
||||
if !degree.is_power_of_two() || degree < 8 {
|
||||
Err(Error::Default(
|
||||
"The degree is not a power of two larger or equal to 8".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut q = Vec::with_capacity(moduli.len());
|
||||
let rns = Arc::new(RnsContext::new(moduli)?);
|
||||
let mut ops = Vec::with_capacity(moduli.len());
|
||||
for modulus in moduli {
|
||||
let qi = Modulus::new(*modulus)?;
|
||||
if let Some(op) = NttOperator::new(&qi, degree) {
|
||||
q.push(qi);
|
||||
ops.push(op);
|
||||
} else {
|
||||
return Err(Error::Default(
|
||||
"Impossible to construct a Ntt operator".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let bitrev = (0..degree)
|
||||
.map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1))
|
||||
.collect_vec();
|
||||
/// Creates a context from a list of moduli and a polynomial degree.
|
||||
///
|
||||
/// Returns an error if the moduli are not primes less than 62 bits which
|
||||
/// supports the NTT of size `degree`.
|
||||
pub fn new(moduli: &[u64], degree: usize) -> Result<Self> {
|
||||
if !degree.is_power_of_two() || degree < 8 {
|
||||
Err(Error::Default(
|
||||
"The degree is not a power of two larger or equal to 8".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut q = Vec::with_capacity(moduli.len());
|
||||
let rns = Arc::new(RnsContext::new(moduli)?);
|
||||
let mut ops = Vec::with_capacity(moduli.len());
|
||||
for modulus in moduli {
|
||||
let qi = Modulus::new(*modulus)?;
|
||||
if let Some(op) = NttOperator::new(&qi, degree) {
|
||||
q.push(qi);
|
||||
ops.push(op);
|
||||
} else {
|
||||
return Err(Error::Default(
|
||||
"Impossible to construct a Ntt operator".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let bitrev = (0..degree)
|
||||
.map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1))
|
||||
.collect_vec();
|
||||
|
||||
let mut inv_last_qi_mod_qj = vec![];
|
||||
let mut inv_last_qi_mod_qj_shoup = vec![];
|
||||
let q_last = moduli.last().unwrap();
|
||||
for qi in &q[..q.len() - 1] {
|
||||
let inv = qi.inv(qi.reduce(*q_last)).unwrap();
|
||||
inv_last_qi_mod_qj.push(inv);
|
||||
inv_last_qi_mod_qj_shoup.push(qi.shoup(inv));
|
||||
}
|
||||
let mut inv_last_qi_mod_qj = vec![];
|
||||
let mut inv_last_qi_mod_qj_shoup = vec![];
|
||||
let q_last = moduli.last().unwrap();
|
||||
for qi in &q[..q.len() - 1] {
|
||||
let inv = qi.inv(qi.reduce(*q_last)).unwrap();
|
||||
inv_last_qi_mod_qj.push(inv);
|
||||
inv_last_qi_mod_qj_shoup.push(qi.shoup(inv));
|
||||
}
|
||||
|
||||
let next_context = if moduli.len() >= 2 {
|
||||
Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let next_context = if moduli.len() >= 2 {
|
||||
Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
moduli: moduli.to_owned().into_boxed_slice(),
|
||||
q: q.into_boxed_slice(),
|
||||
rns,
|
||||
ops: ops.into_boxed_slice(),
|
||||
degree,
|
||||
bitrev: bitrev.into_boxed_slice(),
|
||||
inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(),
|
||||
inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(),
|
||||
next_context,
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
moduli: moduli.to_owned().into_boxed_slice(),
|
||||
q: q.into_boxed_slice(),
|
||||
rns,
|
||||
ops: ops.into_boxed_slice(),
|
||||
degree,
|
||||
bitrev: bitrev.into_boxed_slice(),
|
||||
inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(),
|
||||
inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(),
|
||||
next_context,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the modulus as a BigUint.
|
||||
pub fn modulus(&self) -> &BigUint {
|
||||
self.rns.modulus()
|
||||
}
|
||||
/// Returns the modulus as a BigUint.
|
||||
pub fn modulus(&self) -> &BigUint {
|
||||
self.rns.modulus()
|
||||
}
|
||||
|
||||
/// Returns a reference to the moduli in this context.
|
||||
pub fn moduli(&self) -> &[u64] {
|
||||
&self.moduli
|
||||
}
|
||||
/// Returns a reference to the moduli in this context.
|
||||
pub fn moduli(&self) -> &[u64] {
|
||||
&self.moduli
|
||||
}
|
||||
|
||||
/// Returns a reference to the moduli as Modulus in this context.
|
||||
pub fn moduli_operators(&self) -> &[Modulus] {
|
||||
&self.q
|
||||
}
|
||||
/// Returns a reference to the moduli as Modulus in this context.
|
||||
pub fn moduli_operators(&self) -> &[Modulus] {
|
||||
&self.q
|
||||
}
|
||||
|
||||
/// Returns the number of iterations to switch to a children context.
|
||||
/// Returns an error if the context provided is not a child context.
|
||||
pub fn niterations_to(&self, context: &Arc<Context>) -> Result<usize> {
|
||||
if context.as_ref() == self {
|
||||
return Ok(0);
|
||||
}
|
||||
/// Returns the number of iterations to switch to a children context.
|
||||
/// Returns an error if the context provided is not a child context.
|
||||
pub fn niterations_to(&self, context: &Arc<Context>) -> Result<usize> {
|
||||
if context.as_ref() == self {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let mut niterations = 0;
|
||||
let mut found = false;
|
||||
let mut current_ctx = Arc::new(self.clone());
|
||||
while current_ctx.next_context.is_some() {
|
||||
niterations += 1;
|
||||
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
|
||||
if ¤t_ctx == context {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if found {
|
||||
Ok(niterations)
|
||||
} else {
|
||||
Err(Error::InvalidContext)
|
||||
}
|
||||
}
|
||||
let mut niterations = 0;
|
||||
let mut found = false;
|
||||
let mut current_ctx = Arc::new(self.clone());
|
||||
while current_ctx.next_context.is_some() {
|
||||
niterations += 1;
|
||||
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
|
||||
if ¤t_ctx == context {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if found {
|
||||
Ok(niterations)
|
||||
} else {
|
||||
Err(Error::InvalidContext)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the context after `i` iterations.
|
||||
pub fn context_at_level(&self, i: usize) -> Result<Arc<Self>> {
|
||||
if i >= self.moduli.len() {
|
||||
Err(Error::Default(
|
||||
"No context at the specified level".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut current_ctx = Arc::new(self.clone());
|
||||
for _ in 0..i {
|
||||
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
|
||||
}
|
||||
Ok(current_ctx)
|
||||
}
|
||||
}
|
||||
/// Returns the context after `i` iterations.
|
||||
pub fn context_at_level(&self, i: usize) -> Result<Arc<Self>> {
|
||||
if i >= self.moduli.len() {
|
||||
Err(Error::Default(
|
||||
"No context at the specified level".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut current_ctx = Arc::new(self.clone());
|
||||
for _ in 0..i {
|
||||
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
|
||||
}
|
||||
Ok(current_ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{error::Error, sync::Arc};
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
use crate::{rq::Context, zq::ntt::supports_ntt};
|
||||
use crate::{rq::Context, zq::ntt::supports_ntt};
|
||||
|
||||
const MODULI: &[u64; 5] = &[
|
||||
1153,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
];
|
||||
const MODULI: &[u64; 5] = &[
|
||||
1153,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn context_constructor() {
|
||||
for modulus in MODULI {
|
||||
// modulus is = 1 modulo 2 * 8
|
||||
assert!(Context::new(&[*modulus], 8).is_ok());
|
||||
#[test]
|
||||
fn context_constructor() {
|
||||
for modulus in MODULI {
|
||||
// modulus is = 1 modulo 2 * 8
|
||||
assert!(Context::new(&[*modulus], 8).is_ok());
|
||||
|
||||
if supports_ntt(*modulus, 128) {
|
||||
assert!(Context::new(&[*modulus], 128).is_ok());
|
||||
} else {
|
||||
assert!(Context::new(&[*modulus], 128).is_err());
|
||||
}
|
||||
}
|
||||
if supports_ntt(*modulus, 128) {
|
||||
assert!(Context::new(&[*modulus], 128).is_ok());
|
||||
} else {
|
||||
assert!(Context::new(&[*modulus], 128).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
// All moduli in MODULI are = 1 modulo 2 * 8
|
||||
assert!(Context::new(MODULI, 8).is_ok());
|
||||
// All moduli in MODULI are = 1 modulo 2 * 8
|
||||
assert!(Context::new(MODULI, 8).is_ok());
|
||||
|
||||
// This should fail since 1153 != 1 moduli 2 * 128
|
||||
assert!(Context::new(MODULI, 128).is_err());
|
||||
}
|
||||
// This should fail since 1153 != 1 moduli 2 * 128
|
||||
assert!(Context::new(MODULI, 128).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_context() -> Result<(), Box<dyn Error>> {
|
||||
// A context should have a children pointing to a context with one less modulus.
|
||||
let context = Arc::new(Context::new(MODULI, 8)?);
|
||||
assert_eq!(
|
||||
context.next_context,
|
||||
Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?))
|
||||
);
|
||||
#[test]
|
||||
fn next_context() -> Result<(), Box<dyn Error>> {
|
||||
// A context should have a children pointing to a context with one less modulus.
|
||||
let context = Arc::new(Context::new(MODULI, 8)?);
|
||||
assert_eq!(
|
||||
context.next_context,
|
||||
Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?))
|
||||
);
|
||||
|
||||
// We can go down the chain of the MODULI.len() - 1 context's.
|
||||
let mut number_of_children = 0;
|
||||
let mut current = context;
|
||||
while current.next_context.is_some() {
|
||||
number_of_children += 1;
|
||||
current = current.next_context.as_ref().unwrap().clone();
|
||||
}
|
||||
assert_eq!(number_of_children, MODULI.len() - 1);
|
||||
// We can go down the chain of the MODULI.len() - 1 context's.
|
||||
let mut number_of_children = 0;
|
||||
let mut current = context;
|
||||
while current.next_context.is_some() {
|
||||
number_of_children += 1;
|
||||
current = current.next_context.as_ref().unwrap().clone();
|
||||
}
|
||||
assert_eq!(number_of_children, MODULI.len() - 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn niterations_to() -> Result<(), Box<dyn Error>> {
|
||||
// A context should have a children pointing to a context with one less modulus.
|
||||
let context = Arc::new(Context::new(MODULI, 8)?);
|
||||
#[test]
|
||||
fn niterations_to() -> Result<(), Box<dyn Error>> {
|
||||
// A context should have a children pointing to a context with one less modulus.
|
||||
let context = Arc::new(Context::new(MODULI, 8)?);
|
||||
|
||||
assert_eq!(context.niterations_to(&context).ok(), Some(0));
|
||||
assert_eq!(context.niterations_to(&context).ok(), Some(0));
|
||||
|
||||
assert_eq!(
|
||||
context
|
||||
.niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?))
|
||||
.err(),
|
||||
Some(crate::Error::InvalidContext)
|
||||
);
|
||||
assert_eq!(
|
||||
context
|
||||
.niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?))
|
||||
.err(),
|
||||
Some(crate::Error::InvalidContext)
|
||||
);
|
||||
|
||||
for i in 1..MODULI.len() {
|
||||
assert_eq!(
|
||||
context
|
||||
.niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?))
|
||||
.ok(),
|
||||
Some(i)
|
||||
);
|
||||
}
|
||||
for i in 1..MODULI.len() {
|
||||
assert_eq!(
|
||||
context
|
||||
.niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?))
|
||||
.ok(),
|
||||
Some(i)
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,8 +4,8 @@
|
||||
|
||||
use super::{Context, Poly, Representation};
|
||||
use crate::{
|
||||
rns::{RnsScaler, ScalingFactor},
|
||||
Error, Result,
|
||||
rns::{RnsScaler, ScalingFactor},
|
||||
Error, Result,
|
||||
};
|
||||
use itertools::izip;
|
||||
use ndarray::{s, Array2, Axis};
|
||||
@@ -14,199 +14,200 @@ use std::sync::Arc;
|
||||
/// Context extender.
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Scaler {
|
||||
from: Arc<Context>,
|
||||
to: Arc<Context>,
|
||||
number_common_moduli: usize,
|
||||
scaler: RnsScaler,
|
||||
from: Arc<Context>,
|
||||
to: Arc<Context>,
|
||||
number_common_moduli: usize,
|
||||
scaler: RnsScaler,
|
||||
}
|
||||
|
||||
impl Scaler {
|
||||
/// Create a scaler from a context `from` to a context `to`.
|
||||
pub fn new(from: &Arc<Context>, to: &Arc<Context>, factor: ScalingFactor) -> Result<Self> {
|
||||
if from.degree != to.degree {
|
||||
return Err(Error::Default("Incompatible degrees".to_string()));
|
||||
}
|
||||
/// Create a scaler from a context `from` to a context `to`.
|
||||
pub fn new(from: &Arc<Context>, to: &Arc<Context>, factor: ScalingFactor) -> Result<Self> {
|
||||
if from.degree != to.degree {
|
||||
return Err(Error::Default("Incompatible degrees".to_string()));
|
||||
}
|
||||
|
||||
let mut number_common_moduli = 0;
|
||||
if factor.is_one {
|
||||
for (qi, pi) in izip!(from.q.iter(), to.q.iter()) {
|
||||
if qi == pi {
|
||||
number_common_moduli += 1
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut number_common_moduli = 0;
|
||||
if factor.is_one {
|
||||
for (qi, pi) in izip!(from.q.iter(), to.q.iter()) {
|
||||
if qi == pi {
|
||||
number_common_moduli += 1
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let scaler = RnsScaler::new(&from.rns, &to.rns, factor);
|
||||
let scaler = RnsScaler::new(&from.rns, &to.rns, factor);
|
||||
|
||||
Ok(Self {
|
||||
from: from.clone(),
|
||||
to: to.clone(),
|
||||
number_common_moduli,
|
||||
scaler,
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
from: from.clone(),
|
||||
to: to.clone(),
|
||||
number_common_moduli,
|
||||
scaler,
|
||||
})
|
||||
}
|
||||
|
||||
/// Scale a polynomial
|
||||
pub(crate) fn scale(&self, p: &Poly) -> Result<Poly> {
|
||||
if p.ctx.as_ref() != self.from.as_ref() {
|
||||
Err(Error::Default(
|
||||
"The input polynomial does not have the correct context".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut representation = p.representation.clone();
|
||||
if representation == Representation::NttShoup {
|
||||
representation = Representation::Ntt;
|
||||
}
|
||||
/// Scale a polynomial
|
||||
pub(crate) fn scale(&self, p: &Poly) -> Result<Poly> {
|
||||
if p.ctx.as_ref() != self.from.as_ref() {
|
||||
Err(Error::Default(
|
||||
"The input polynomial does not have the correct context".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut representation = p.representation.clone();
|
||||
if representation == Representation::NttShoup {
|
||||
representation = Representation::Ntt;
|
||||
}
|
||||
|
||||
let mut new_coefficients = Array2::<u64>::zeros((self.to.q.len(), self.to.degree));
|
||||
let mut new_coefficients = Array2::<u64>::zeros((self.to.q.len(), self.to.degree));
|
||||
|
||||
if self.number_common_moduli > 0 {
|
||||
new_coefficients
|
||||
.slice_mut(s![..self.number_common_moduli, ..])
|
||||
.assign(&p.coefficients.slice(s![..self.number_common_moduli, ..]));
|
||||
}
|
||||
if self.number_common_moduli > 0 {
|
||||
new_coefficients
|
||||
.slice_mut(s![..self.number_common_moduli, ..])
|
||||
.assign(&p.coefficients.slice(s![..self.number_common_moduli, ..]));
|
||||
}
|
||||
|
||||
if self.number_common_moduli < self.to.q.len() {
|
||||
if p.representation == Representation::PowerBasis {
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.axis_iter_mut(Axis(1)),
|
||||
p.coefficients.axis_iter(Axis(1))
|
||||
)
|
||||
.for_each(|(new_column, column)| {
|
||||
self.scaler
|
||||
.scale(column, new_column, self.number_common_moduli)
|
||||
});
|
||||
} else if self.number_common_moduli < self.to.q.len() {
|
||||
let mut p_coefficients_powerbasis = p.coefficients.clone();
|
||||
// Backward NTT
|
||||
if p.allow_variable_time_computations {
|
||||
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
|
||||
.for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
|
||||
} else {
|
||||
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
|
||||
.for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
|
||||
}
|
||||
// Conversion
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.axis_iter_mut(Axis(1)),
|
||||
p_coefficients_powerbasis.axis_iter(Axis(1))
|
||||
)
|
||||
.for_each(|(new_column, column)| {
|
||||
self.scaler
|
||||
.scale(column, new_column, self.number_common_moduli)
|
||||
});
|
||||
// Forward NTT on the second half
|
||||
if p.allow_variable_time_computations {
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.outer_iter_mut(),
|
||||
&self.to.ops[self.number_common_moduli..]
|
||||
)
|
||||
.for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
|
||||
} else {
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.outer_iter_mut(),
|
||||
&self.to.ops[self.number_common_moduli..]
|
||||
)
|
||||
.for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
||||
if self.number_common_moduli < self.to.q.len() {
|
||||
if p.representation == Representation::PowerBasis {
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.axis_iter_mut(Axis(1)),
|
||||
p.coefficients.axis_iter(Axis(1))
|
||||
)
|
||||
.for_each(|(new_column, column)| {
|
||||
self.scaler
|
||||
.scale(column, new_column, self.number_common_moduli)
|
||||
});
|
||||
} else if self.number_common_moduli < self.to.q.len() {
|
||||
let mut p_coefficients_powerbasis = p.coefficients.clone();
|
||||
// Backward NTT
|
||||
if p.allow_variable_time_computations {
|
||||
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
|
||||
.for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
|
||||
} else {
|
||||
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
|
||||
.for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
|
||||
}
|
||||
// Conversion
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.axis_iter_mut(Axis(1)),
|
||||
p_coefficients_powerbasis.axis_iter(Axis(1))
|
||||
)
|
||||
.for_each(|(new_column, column)| {
|
||||
self.scaler
|
||||
.scale(column, new_column, self.number_common_moduli)
|
||||
});
|
||||
// Forward NTT on the second half
|
||||
if p.allow_variable_time_computations {
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.outer_iter_mut(),
|
||||
&self.to.ops[self.number_common_moduli..]
|
||||
)
|
||||
.for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
|
||||
} else {
|
||||
izip!(
|
||||
new_coefficients
|
||||
.slice_mut(s![self.number_common_moduli.., ..])
|
||||
.outer_iter_mut(),
|
||||
&self.to.ops[self.number_common_moduli..]
|
||||
)
|
||||
.for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Poly {
|
||||
ctx: self.to.clone(),
|
||||
representation,
|
||||
allow_variable_time_computations: p.allow_variable_time_computations,
|
||||
coefficients: new_coefficients,
|
||||
coefficients_shoup: None,
|
||||
has_lazy_coefficients: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(Poly {
|
||||
ctx: self.to.clone(),
|
||||
representation,
|
||||
allow_variable_time_computations: p.allow_variable_time_computations,
|
||||
coefficients: new_coefficients,
|
||||
coefficients_shoup: None,
|
||||
has_lazy_coefficients: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Scaler, ScalingFactor};
|
||||
use crate::rq::{Context, Poly, Representation};
|
||||
use itertools::Itertools;
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::{One, Zero};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use super::{Scaler, ScalingFactor};
|
||||
use crate::rq::{Context, Poly, Representation};
|
||||
use itertools::Itertools;
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::{One, Zero};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
// Moduli to be used in tests.
|
||||
static Q: &[u64; 3] = &[
|
||||
4611686018282684417,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
];
|
||||
// Moduli to be used in tests.
|
||||
static Q: &[u64; 3] = &[
|
||||
4611686018282684417,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
];
|
||||
|
||||
static P: &[u64; 3] = &[
|
||||
4611686018282684417,
|
||||
4611686018309947393,
|
||||
4611686018257518593,
|
||||
];
|
||||
static P: &[u64; 3] = &[
|
||||
4611686018282684417,
|
||||
4611686018309947393,
|
||||
4611686018257518593,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn scaler() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let ntests = 100;
|
||||
let from = Arc::new(Context::new(Q, 8)?);
|
||||
let to = Arc::new(Context::new(P, 8)?);
|
||||
#[test]
|
||||
fn scaler() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let ntests = 100;
|
||||
let from = Arc::new(Context::new(Q, 8)?);
|
||||
let to = Arc::new(Context::new(P, 8)?);
|
||||
|
||||
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
|
||||
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
|
||||
let n = BigUint::from(*numerator);
|
||||
let d = BigUint::from(*denominator);
|
||||
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
|
||||
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
|
||||
let n = BigUint::from(*numerator);
|
||||
let d = BigUint::from(*denominator);
|
||||
|
||||
let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?;
|
||||
let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?;
|
||||
|
||||
for _ in 0..ntests {
|
||||
let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng);
|
||||
let poly_biguint = Vec::<BigUint>::from(&poly);
|
||||
for _ in 0..ntests {
|
||||
let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng);
|
||||
let poly_biguint = Vec::<BigUint>::from(&poly);
|
||||
|
||||
let scaled_poly = scaler.scale(&poly)?;
|
||||
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
|
||||
let scaled_poly = scaler.scale(&poly)?;
|
||||
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
|
||||
|
||||
let expected = poly_biguint
|
||||
.iter()
|
||||
.map(|i| {
|
||||
if i >= &(from.modulus() >> 1usize) {
|
||||
if &d & BigUint::one() == BigUint::zero() {
|
||||
to.modulus()
|
||||
- (&(&(from.modulus() - i) * &n + ((&d >> 1usize) - 1u64))
|
||||
/ &d) % to.modulus()
|
||||
} else {
|
||||
to.modulus()
|
||||
- (&(&(from.modulus() - i) * &n + (&d >> 1)) / &d)
|
||||
% to.modulus()
|
||||
}
|
||||
} else {
|
||||
((i * &n + (&d >> 1)) / &d) % to.modulus()
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
assert_eq!(expected, scaled_biguint);
|
||||
let expected = poly_biguint
|
||||
.iter()
|
||||
.map(|i| {
|
||||
if i >= &(from.modulus() >> 1usize) {
|
||||
if &d & BigUint::one() == BigUint::zero() {
|
||||
to.modulus()
|
||||
- (&(&(from.modulus() - i) * &n + ((&d >> 1usize) - 1u64))
|
||||
/ &d)
|
||||
% to.modulus()
|
||||
} else {
|
||||
to.modulus()
|
||||
- (&(&(from.modulus() - i) * &n + (&d >> 1)) / &d)
|
||||
% to.modulus()
|
||||
}
|
||||
} else {
|
||||
((i * &n + (&d >> 1)) / &d) % to.modulus()
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
assert_eq!(expected, scaled_biguint);
|
||||
|
||||
poly.change_representation(Representation::Ntt);
|
||||
let mut scaled_poly = scaler.scale(&poly)?;
|
||||
scaled_poly.change_representation(Representation::PowerBasis);
|
||||
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
|
||||
assert_eq!(expected, scaled_biguint);
|
||||
}
|
||||
}
|
||||
}
|
||||
poly.change_representation(Representation::Ntt);
|
||||
let mut scaled_poly = scaler.scale(&poly)?;
|
||||
scaled_poly.change_representation(Representation::PowerBasis);
|
||||
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
|
||||
assert_eq!(expected, scaled_biguint);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,59 +8,59 @@ use fhe_traits::{DeserializeWithContext, Serialize};
|
||||
use protobuf::Message;
|
||||
|
||||
impl Serialize for Poly {
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
let rq = Rq::from(self);
|
||||
rq.write_to_bytes().unwrap()
|
||||
}
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
let rq = Rq::from(self);
|
||||
rq.write_to_bytes().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl DeserializeWithContext for Poly {
|
||||
type Error = Error;
|
||||
type Context = Context;
|
||||
type Error = Error;
|
||||
type Context = Context;
|
||||
|
||||
fn from_bytes(bytes: &[u8], ctx: &Arc<Context>) -> Result<Self, Self::Error> {
|
||||
let rq = Rq::parse_from_bytes(bytes).map_err(|e| Error::Serialization(e.to_string()))?;
|
||||
Poly::try_convert_from(&rq, ctx, false, None)
|
||||
}
|
||||
fn from_bytes(bytes: &[u8], ctx: &Arc<Context>) -> Result<Self, Self::Error> {
|
||||
let rq = Rq::parse_from_bytes(bytes).map_err(|e| Error::Serialization(e.to_string()))?;
|
||||
Poly::try_convert_from(&rq, ctx, false, None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{error::Error, sync::Arc};
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
use fhe_traits::{DeserializeWithContext, Serialize};
|
||||
use rand::thread_rng;
|
||||
use fhe_traits::{DeserializeWithContext, Serialize};
|
||||
use rand::thread_rng;
|
||||
|
||||
use crate::rq::{Context, Poly, Representation};
|
||||
use crate::rq::{Context, Poly, Representation};
|
||||
|
||||
const Q: &[u64; 3] = &[
|
||||
4611686018282684417,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
];
|
||||
const Q: &[u64; 3] = &[
|
||||
4611686018282684417,
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
#[test]
|
||||
fn serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for qi in Q {
|
||||
let ctx = Arc::new(Context::new(&[*qi], 8)?);
|
||||
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
}
|
||||
for qi in Q {
|
||||
let ctx = Arc::new(Context::new(&[*qi], 8)?);
|
||||
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
}
|
||||
|
||||
let ctx = Arc::new(Context::new(Q, 8)?);
|
||||
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let ctx = Arc::new(Context::new(Q, 8)?);
|
||||
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
|
||||
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,19 +9,19 @@ use std::sync::Arc;
|
||||
/// Context switcher.
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Switcher {
|
||||
pub(crate) scaler: Scaler,
|
||||
pub(crate) scaler: Scaler,
|
||||
}
|
||||
|
||||
impl Switcher {
|
||||
/// Create a switcher from a context `from` to a context `to`.
|
||||
pub fn new(from: &Arc<Context>, to: &Arc<Context>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
scaler: Scaler::new(from, to, ScalingFactor::new(to.modulus(), from.modulus()))?,
|
||||
})
|
||||
}
|
||||
/// Create a switcher from a context `from` to a context `to`.
|
||||
pub fn new(from: &Arc<Context>, to: &Arc<Context>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
scaler: Scaler::new(from, to, ScalingFactor::new(to.modulus(), from.modulus()))?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Switch a polynomial.
|
||||
pub(crate) fn switch(&self, p: &Poly) -> Result<Poly> {
|
||||
self.scaler.scale(p)
|
||||
}
|
||||
/// Switch a polynomial.
|
||||
pub(crate) fn switch(&self, p: &Poly) -> Result<Poly> {
|
||||
self.scaler.scale(p)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,19 +14,19 @@ use std::sync::Arc;
|
||||
/// blanket implementation <https://github.com/rust-lang/rust/issues/50133#issuecomment-488512355>.
|
||||
pub trait TryConvertFrom<T>
|
||||
where
|
||||
Self: Sized,
|
||||
Self: Sized,
|
||||
{
|
||||
/// Attempt to convert the `value` into a polynomial with a specific context
|
||||
/// and under a specific representation. The representation may optional and
|
||||
/// be specified as `None`; this is useful for example when converting from
|
||||
/// a value that encodes the representation (e.g., serialization, protobuf,
|
||||
/// etc.).
|
||||
fn try_convert_from<R>(
|
||||
value: T,
|
||||
ctx: &Arc<Context>,
|
||||
variable_time: bool,
|
||||
representation: R,
|
||||
) -> Result<Self>
|
||||
where
|
||||
R: Into<Option<Representation>>;
|
||||
/// Attempt to convert the `value` into a polynomial with a specific context
|
||||
/// and under a specific representation. The representation may optional and
|
||||
/// be specified as `None`; this is useful for example when converting from
|
||||
/// a value that encodes the representation (e.g., serialization, protobuf,
|
||||
/// etc.).
|
||||
fn try_convert_from<R>(
|
||||
value: T,
|
||||
ctx: &Arc<Context>,
|
||||
variable_time: bool,
|
||||
representation: R,
|
||||
) -> Result<Self>
|
||||
where
|
||||
R: Into<Option<Representation>>;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,489 +2,483 @@
|
||||
|
||||
use super::Modulus;
|
||||
use fhe_util::is_prime;
|
||||
use itertools::Itertools;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
use std::iter::successors;
|
||||
|
||||
/// Returns whether a modulus p is prime and supports the Number Theoretic
|
||||
/// Transform of size n.
|
||||
///
|
||||
/// Aborts if n is not a power of 2 that is >= 8.
|
||||
pub fn supports_ntt(p: u64, n: usize) -> bool {
|
||||
assert!(n >= 8 && n.is_power_of_two());
|
||||
assert!(n >= 8 && n.is_power_of_two());
|
||||
|
||||
p % ((n as u64) << 1) == 1 && is_prime(p)
|
||||
p % ((n as u64) << 1) == 1 && is_prime(p)
|
||||
}
|
||||
|
||||
/// Number-Theoretic Transform operator.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct NttOperator {
|
||||
p: Modulus,
|
||||
p_twice: u64,
|
||||
size: usize,
|
||||
omegas: Box<[u64]>,
|
||||
omegas_shoup: Box<[u64]>,
|
||||
omegas_inv: Box<[u64]>,
|
||||
zetas_inv: Box<[u64]>,
|
||||
zetas_inv_shoup: Box<[u64]>,
|
||||
size_inv: u64,
|
||||
size_inv_shoup: u64,
|
||||
p: Modulus,
|
||||
p_twice: u64,
|
||||
size: usize,
|
||||
omegas: Box<[u64]>,
|
||||
omegas_shoup: Box<[u64]>,
|
||||
zetas_inv: Box<[u64]>,
|
||||
zetas_inv_shoup: Box<[u64]>,
|
||||
size_inv: u64,
|
||||
size_inv_shoup: u64,
|
||||
}
|
||||
|
||||
impl NttOperator {
|
||||
/// Create an NTT operator given a modulus for a specific size.
|
||||
///
|
||||
/// Aborts if the size is not a power of 2 that is >= 8 in debug mode.
|
||||
/// Returns None if the modulus does not support the NTT for this specific
|
||||
/// size.
|
||||
pub fn new(p: &Modulus, size: usize) -> Option<Self> {
|
||||
if !supports_ntt(p.p, size) {
|
||||
None
|
||||
} else {
|
||||
let omega = Self::primitive_root(size, p);
|
||||
let omega_inv = p.inv(omega).unwrap();
|
||||
/// Create an NTT operator given a modulus for a specific size.
|
||||
///
|
||||
/// Aborts if the size is not a power of 2 that is >= 8 in debug mode.
|
||||
/// Returns None if the modulus does not support the NTT for this specific
|
||||
/// size.
|
||||
pub fn new(p: &Modulus, size: usize) -> Option<Self> {
|
||||
if !supports_ntt(p.p, size) {
|
||||
None
|
||||
} else {
|
||||
let size_inv = p.inv(size as u64)?;
|
||||
|
||||
let mut exp = 1u64;
|
||||
let mut exp_inv = 1u64;
|
||||
let mut powers = Vec::with_capacity(size + 1);
|
||||
let mut powers_inv = Vec::with_capacity(size + 1);
|
||||
for _ in 0..size + 1 {
|
||||
powers.push(exp);
|
||||
powers_inv.push(exp_inv);
|
||||
exp = p.mul(exp, omega);
|
||||
exp_inv = p.mul(exp_inv, omega_inv);
|
||||
}
|
||||
let omega = Self::primitive_root(size, p);
|
||||
let omega_inv = p.inv(omega)?;
|
||||
|
||||
let mut omegas = Vec::with_capacity(size);
|
||||
let mut omegas_inv = Vec::with_capacity(size);
|
||||
let mut zetas_inv = Vec::with_capacity(size);
|
||||
for i in 0..size {
|
||||
let j = i.reverse_bits() >> (size.leading_zeros() + 1);
|
||||
omegas.push(powers[j]);
|
||||
omegas_inv.push(powers_inv[j]);
|
||||
zetas_inv.push(powers_inv[j + 1]);
|
||||
}
|
||||
let powers = successors(Some(1u64), |n| Some(p.mul(*n, omega)))
|
||||
.take(size)
|
||||
.collect_vec();
|
||||
let powers_inv = successors(Some(omega_inv), |n| Some(p.mul(*n, omega_inv)))
|
||||
.take(size)
|
||||
.collect_vec();
|
||||
|
||||
let size_inv = p.inv(size as u64).unwrap();
|
||||
let mut omegas = Vec::with_capacity(size);
|
||||
let mut zetas_inv = Vec::with_capacity(size);
|
||||
for i in 0..size {
|
||||
let j = i.reverse_bits() >> (size.leading_zeros() + 1);
|
||||
omegas.push(powers[j]);
|
||||
zetas_inv.push(powers_inv[j]);
|
||||
}
|
||||
|
||||
let omegas_shoup = p.shoup_vec(&omegas);
|
||||
let zetas_inv_shoup = p.shoup_vec(&zetas_inv);
|
||||
let omegas_shoup = p.shoup_vec(&omegas);
|
||||
let zetas_inv_shoup = p.shoup_vec(&zetas_inv);
|
||||
|
||||
Some(Self {
|
||||
p: p.clone(),
|
||||
p_twice: p.p * 2,
|
||||
size,
|
||||
omegas: omegas.into_boxed_slice(),
|
||||
omegas_shoup: omegas_shoup.into_boxed_slice(),
|
||||
omegas_inv: omegas_inv.into_boxed_slice(),
|
||||
zetas_inv: zetas_inv.into_boxed_slice(),
|
||||
zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(),
|
||||
size_inv,
|
||||
size_inv_shoup: p.shoup(size_inv),
|
||||
})
|
||||
}
|
||||
}
|
||||
Some(Self {
|
||||
p: p.clone(),
|
||||
p_twice: p.p * 2,
|
||||
size,
|
||||
omegas: omegas.into_boxed_slice(),
|
||||
omegas_shoup: omegas_shoup.into_boxed_slice(),
|
||||
zetas_inv: zetas_inv.into_boxed_slice(),
|
||||
zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(),
|
||||
size_inv,
|
||||
size_inv_shoup: p.shoup(size_inv),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the forward NTT in place.
|
||||
/// Aborts if a is not of the size handled by the operator.
|
||||
pub fn forward(&self, a: &mut [u64]) {
|
||||
debug_assert_eq!(a.len(), self.size);
|
||||
/// Compute the forward NTT in place.
|
||||
/// Aborts if a is not of the size handled by the operator.
|
||||
pub fn forward(&self, a: &mut [u64]) {
|
||||
debug_assert_eq!(a.len(), self.size);
|
||||
|
||||
let n = self.size;
|
||||
let a_ptr = a.as_mut_ptr();
|
||||
let n = self.size;
|
||||
let a_ptr = a.as_mut_ptr();
|
||||
|
||||
let mut l = n >> 1;
|
||||
let mut m = 1;
|
||||
let mut k = 1;
|
||||
while l > 0 {
|
||||
for i in 0..m {
|
||||
unsafe {
|
||||
let omega = *self.omegas.get_unchecked(k);
|
||||
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
let mut l = n >> 1;
|
||||
let mut m = 1;
|
||||
let mut k = 1;
|
||||
while l > 0 {
|
||||
for i in 0..m {
|
||||
unsafe {
|
||||
let omega = *self.omegas.get_unchecked(k);
|
||||
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
|
||||
let s = 2 * i * l;
|
||||
match l {
|
||||
1 => {
|
||||
// The last level should reduce the output
|
||||
let uj = &mut *a_ptr.add(s);
|
||||
let ujl = &mut *a_ptr.add(s + l);
|
||||
self.butterfly(uj, ujl, omega, omega_shoup);
|
||||
*uj = self.reduce3(*uj);
|
||||
*ujl = self.reduce3(*ujl);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.butterfly(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
omega,
|
||||
omega_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l >>= 1;
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
let s = 2 * i * l;
|
||||
match l {
|
||||
1 => {
|
||||
// The last level should reduce the output
|
||||
let uj = &mut *a_ptr.add(s);
|
||||
let ujl = &mut *a_ptr.add(s + l);
|
||||
self.butterfly(uj, ujl, omega, omega_shoup);
|
||||
*uj = self.reduce3(*uj);
|
||||
*ujl = self.reduce3(*ujl);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.butterfly(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
omega,
|
||||
omega_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l >>= 1;
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the backward NTT in place.
|
||||
/// Aborts if a is not of the size handled by the operator.
|
||||
pub fn backward(&self, a: &mut [u64]) {
|
||||
debug_assert_eq!(a.len(), self.size);
|
||||
/// Compute the backward NTT in place.
|
||||
/// Aborts if a is not of the size handled by the operator.
|
||||
pub fn backward(&self, a: &mut [u64]) {
|
||||
debug_assert_eq!(a.len(), self.size);
|
||||
|
||||
let a_ptr = a.as_mut_ptr();
|
||||
let a_ptr = a.as_mut_ptr();
|
||||
|
||||
let mut k = 0;
|
||||
let mut m = self.size >> 1;
|
||||
let mut l = 1;
|
||||
while m > 0 {
|
||||
for i in 0..m {
|
||||
let s = 2 * i * l;
|
||||
unsafe {
|
||||
let zeta_inv = *self.zetas_inv.get_unchecked(k);
|
||||
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
match l {
|
||||
1 => {
|
||||
self.inv_butterfly(
|
||||
&mut *a_ptr.add(s),
|
||||
&mut *a_ptr.add(s + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.inv_butterfly(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l <<= 1;
|
||||
m >>= 1;
|
||||
}
|
||||
let mut k = 0;
|
||||
let mut m = self.size >> 1;
|
||||
let mut l = 1;
|
||||
while m > 0 {
|
||||
for i in 0..m {
|
||||
let s = 2 * i * l;
|
||||
unsafe {
|
||||
let zeta_inv = *self.zetas_inv.get_unchecked(k);
|
||||
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
match l {
|
||||
1 => {
|
||||
self.inv_butterfly(
|
||||
&mut *a_ptr.add(s),
|
||||
&mut *a_ptr.add(s + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.inv_butterfly(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l <<= 1;
|
||||
m >>= 1;
|
||||
}
|
||||
|
||||
a.iter_mut()
|
||||
.for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup));
|
||||
}
|
||||
a.iter_mut()
|
||||
.for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup));
|
||||
}
|
||||
|
||||
/// Compute the forward NTT in place in variable time in a lazily fashion.
|
||||
/// This means that the output coefficients may be up to 4 times the
|
||||
/// modulus.
|
||||
///
|
||||
/// # Safety
|
||||
/// This function assumes that a_ptr points to at least `size` elements.
|
||||
/// This function is not constant time and its timing may reveal information
|
||||
/// about the value being reduced.
|
||||
pub unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) {
|
||||
let mut l = self.size >> 1;
|
||||
let mut m = 1;
|
||||
let mut k = 1;
|
||||
while l > 0 {
|
||||
for i in 0..m {
|
||||
let omega = *self.omegas.get_unchecked(k);
|
||||
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
/// Compute the forward NTT in place in variable time in a lazily fashion.
|
||||
/// This means that the output coefficients may be up to 4 times the
|
||||
/// modulus.
|
||||
///
|
||||
/// # Safety
|
||||
/// This function assumes that a_ptr points to at least `size` elements.
|
||||
/// This function is not constant time and its timing may reveal information
|
||||
/// about the value being reduced.
|
||||
pub unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) {
|
||||
let mut l = self.size >> 1;
|
||||
let mut m = 1;
|
||||
let mut k = 1;
|
||||
while l > 0 {
|
||||
for i in 0..m {
|
||||
let omega = *self.omegas.get_unchecked(k);
|
||||
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
|
||||
let s = 2 * i * l;
|
||||
match l {
|
||||
1 => {
|
||||
self.butterfly_vt(
|
||||
&mut *a_ptr.add(s),
|
||||
&mut *a_ptr.add(s + l),
|
||||
omega,
|
||||
omega_shoup,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.butterfly_vt(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
omega,
|
||||
omega_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l >>= 1;
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
let s = 2 * i * l;
|
||||
match l {
|
||||
1 => {
|
||||
self.butterfly_vt(
|
||||
&mut *a_ptr.add(s),
|
||||
&mut *a_ptr.add(s + l),
|
||||
omega,
|
||||
omega_shoup,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.butterfly_vt(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
omega,
|
||||
omega_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l >>= 1;
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the forward NTT in place in variable time.
|
||||
///
|
||||
/// # Safety
|
||||
/// This function assumes that a_ptr points to at least `size` elements.
|
||||
/// This function is not constant time and its timing may reveal information
|
||||
/// about the value being reduced.
|
||||
pub unsafe fn forward_vt(&self, a_ptr: *mut u64) {
|
||||
self.forward_vt_lazy(a_ptr);
|
||||
for i in 0..self.size {
|
||||
*a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i))
|
||||
}
|
||||
}
|
||||
/// Compute the forward NTT in place in variable time.
|
||||
///
|
||||
/// # Safety
|
||||
/// This function assumes that a_ptr points to at least `size` elements.
|
||||
/// This function is not constant time and its timing may reveal information
|
||||
/// about the value being reduced.
|
||||
pub unsafe fn forward_vt(&self, a_ptr: *mut u64) {
|
||||
self.forward_vt_lazy(a_ptr);
|
||||
for i in 0..self.size {
|
||||
*a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i))
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the backward NTT in place in variable time.
|
||||
///
|
||||
/// # Safety
|
||||
/// This function assumes that a_ptr points to at least `size` elements.
|
||||
/// This function is not constant time and its timing may reveal information
|
||||
/// about the value being reduced.
|
||||
pub unsafe fn backward_vt(&self, a_ptr: *mut u64) {
|
||||
let mut k = 0;
|
||||
let mut m = self.size >> 1;
|
||||
let mut l = 1;
|
||||
while m > 0 {
|
||||
for i in 0..m {
|
||||
let s = 2 * i * l;
|
||||
let zeta_inv = *self.zetas_inv.get_unchecked(k);
|
||||
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
match l {
|
||||
1 => {
|
||||
self.inv_butterfly_vt(
|
||||
&mut *a_ptr.add(s),
|
||||
&mut *a_ptr.add(s + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.inv_butterfly_vt(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l <<= 1;
|
||||
m >>= 1;
|
||||
}
|
||||
/// Compute the backward NTT in place in variable time.
|
||||
///
|
||||
/// # Safety
|
||||
/// This function assumes that a_ptr points to at least `size` elements.
|
||||
/// This function is not constant time and its timing may reveal information
|
||||
/// about the value being reduced.
|
||||
pub unsafe fn backward_vt(&self, a_ptr: *mut u64) {
|
||||
let mut k = 0;
|
||||
let mut m = self.size >> 1;
|
||||
let mut l = 1;
|
||||
while m > 0 {
|
||||
for i in 0..m {
|
||||
let s = 2 * i * l;
|
||||
let zeta_inv = *self.zetas_inv.get_unchecked(k);
|
||||
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
|
||||
k += 1;
|
||||
match l {
|
||||
1 => {
|
||||
self.inv_butterfly_vt(
|
||||
&mut *a_ptr.add(s),
|
||||
&mut *a_ptr.add(s + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
for j in s..(s + l) {
|
||||
self.inv_butterfly_vt(
|
||||
&mut *a_ptr.add(j),
|
||||
&mut *a_ptr.add(j + l),
|
||||
zeta_inv,
|
||||
zeta_inv_shoup,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
l <<= 1;
|
||||
m >>= 1;
|
||||
}
|
||||
|
||||
for i in 0..self.size as isize {
|
||||
*a_ptr.offset(i) =
|
||||
self.p
|
||||
.mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup)
|
||||
}
|
||||
}
|
||||
for i in 0..self.size as isize {
|
||||
*a_ptr.offset(i) =
|
||||
self.p
|
||||
.mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup)
|
||||
}
|
||||
}
|
||||
|
||||
/// Reduce a modulo p.
|
||||
///
|
||||
/// Aborts if a >= 4 * p.
|
||||
const fn reduce3(&self, a: u64) -> u64 {
|
||||
debug_assert!(a < 4 * self.p.p);
|
||||
/// Reduce a modulo p.
|
||||
///
|
||||
/// Aborts if a >= 4 * p.
|
||||
const fn reduce3(&self, a: u64) -> u64 {
|
||||
debug_assert!(a < 4 * self.p.p);
|
||||
|
||||
let y = Modulus::reduce1(a, 2 * self.p.p);
|
||||
Modulus::reduce1(y, self.p.p)
|
||||
}
|
||||
let y = Modulus::reduce1(a, 2 * self.p.p);
|
||||
Modulus::reduce1(y, self.p.p)
|
||||
}
|
||||
|
||||
/// Reduce a modulo p in variable time.
|
||||
///
|
||||
/// Aborts if a >= 4 * p.
|
||||
const unsafe fn reduce3_vt(&self, a: u64) -> u64 {
|
||||
debug_assert!(a < 4 * self.p.p);
|
||||
/// Reduce a modulo p in variable time.
|
||||
///
|
||||
/// Aborts if a >= 4 * p.
|
||||
const unsafe fn reduce3_vt(&self, a: u64) -> u64 {
|
||||
debug_assert!(a < 4 * self.p.p);
|
||||
|
||||
let y = Modulus::reduce1_vt(a, 2 * self.p.p);
|
||||
Modulus::reduce1_vt(y, self.p.p)
|
||||
}
|
||||
let y = Modulus::reduce1_vt(a, 2 * self.p.p);
|
||||
Modulus::reduce1_vt(y, self.p.p)
|
||||
}
|
||||
|
||||
/// NTT Butterfly.
|
||||
fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
debug_assert!(w < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(w), w_shoup);
|
||||
/// NTT Butterfly.
|
||||
fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
debug_assert!(w < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(w), w_shoup);
|
||||
|
||||
*x = Modulus::reduce1(*x, self.p_twice);
|
||||
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
|
||||
*y = *x + self.p_twice - t;
|
||||
*x += t;
|
||||
*x = Modulus::reduce1(*x, self.p_twice);
|
||||
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
|
||||
*y = *x + self.p_twice - t;
|
||||
*x += t;
|
||||
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
}
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
}
|
||||
|
||||
/// NTT Butterfly in variable time.
|
||||
unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
debug_assert!(w < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(w), w_shoup);
|
||||
/// NTT Butterfly in variable time.
|
||||
unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
debug_assert!(w < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(w), w_shoup);
|
||||
|
||||
*x = Modulus::reduce1_vt(*x, self.p_twice);
|
||||
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
|
||||
*y = *x + self.p_twice - t;
|
||||
*x += t;
|
||||
*x = Modulus::reduce1_vt(*x, self.p_twice);
|
||||
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
|
||||
*y = *x + self.p_twice - t;
|
||||
*x += t;
|
||||
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
}
|
||||
debug_assert!(*x < 4 * self.p.p);
|
||||
debug_assert!(*y < 4 * self.p.p);
|
||||
}
|
||||
|
||||
/// Inverse NTT butterfly.
|
||||
fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
debug_assert!(z < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(z), z_shoup);
|
||||
/// Inverse NTT butterfly.
|
||||
fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
debug_assert!(z < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(z), z_shoup);
|
||||
|
||||
let t = *x;
|
||||
*x = Modulus::reduce1(*y + t, self.p_twice);
|
||||
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
|
||||
let t = *x;
|
||||
*x = Modulus::reduce1(*y + t, self.p_twice);
|
||||
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
|
||||
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
}
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
}
|
||||
|
||||
/// Inverse NTT butterfly in variable time
|
||||
unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
debug_assert!(z < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(z), z_shoup);
|
||||
/// Inverse NTT butterfly in variable time
|
||||
unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
debug_assert!(z < self.p.p);
|
||||
debug_assert_eq!(self.p.shoup(z), z_shoup);
|
||||
|
||||
let t = *x;
|
||||
*x = Modulus::reduce1_vt(*y + t, self.p_twice);
|
||||
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
|
||||
let t = *x;
|
||||
*x = Modulus::reduce1_vt(*y + t, self.p_twice);
|
||||
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
|
||||
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
}
|
||||
debug_assert!(*x < self.p_twice);
|
||||
debug_assert!(*y < self.p_twice);
|
||||
}
|
||||
|
||||
/// Returns a 2n-th primitive root modulo p.
|
||||
///
|
||||
/// Aborts if p is not prime or n is not a power of 2 that is >= 8.
|
||||
fn primitive_root(n: usize, p: &Modulus) -> u64 {
|
||||
debug_assert!(supports_ntt(p.p, n));
|
||||
/// Returns a 2n-th primitive root modulo p.
|
||||
///
|
||||
/// Aborts if p is not prime or n is not a power of 2 that is >= 8.
|
||||
fn primitive_root(n: usize, p: &Modulus) -> u64 {
|
||||
debug_assert!(supports_ntt(p.p, n));
|
||||
|
||||
let lambda = (p.p - 1) / (2 * n as u64);
|
||||
let lambda = (p.p - 1) / (2 * n as u64);
|
||||
|
||||
let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0);
|
||||
for _ in 0..100 {
|
||||
let mut root = rng.gen_range(0..p.p);
|
||||
root = p.pow(root, lambda);
|
||||
if Self::is_primitive_root(root, 2 * n, p) {
|
||||
return root;
|
||||
}
|
||||
}
|
||||
let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0);
|
||||
for _ in 0..100 {
|
||||
let mut root = rng.gen_range(0..p.p);
|
||||
root = p.pow(root, lambda);
|
||||
if Self::is_primitive_root(root, 2 * n, p) {
|
||||
return root;
|
||||
}
|
||||
}
|
||||
|
||||
debug_assert!(false, "Couldn't find primitive root");
|
||||
0
|
||||
}
|
||||
debug_assert!(false, "Couldn't find primitive root");
|
||||
0
|
||||
}
|
||||
|
||||
/// Returns whether a is a n-th primitive root of unity.
|
||||
///
|
||||
/// Aborts if a >= p in debug mode.
|
||||
fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool {
|
||||
debug_assert!(a < p.p);
|
||||
debug_assert!(supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here.
|
||||
/// Returns whether a is a n-th primitive root of unity.
|
||||
///
|
||||
/// Aborts if a >= p in debug mode.
|
||||
fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool {
|
||||
debug_assert!(a < p.p);
|
||||
debug_assert!(supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here.
|
||||
|
||||
// A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p
|
||||
// for all prime p dividing n.
|
||||
(p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1)
|
||||
}
|
||||
// A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p
|
||||
// for all prime p dividing n.
|
||||
(p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::thread_rng;
|
||||
use rand::thread_rng;
|
||||
|
||||
use super::{supports_ntt, NttOperator};
|
||||
use crate::zq::Modulus;
|
||||
use super::{supports_ntt, NttOperator};
|
||||
use crate::zq::Modulus;
|
||||
|
||||
#[test]
|
||||
fn constructor() {
|
||||
for size in [8, 1024] {
|
||||
for p in [1153, 4611686018326724609] {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
let supports_ntt = supports_ntt(p, size);
|
||||
#[test]
|
||||
fn constructor() {
|
||||
for size in [8, 1024] {
|
||||
for p in [1153, 4611686018326724609] {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
let supports_ntt = supports_ntt(p, size);
|
||||
|
||||
let op = NttOperator::new(&q, size);
|
||||
let op = NttOperator::new(&q, size);
|
||||
|
||||
if supports_ntt {
|
||||
assert!(op.is_some());
|
||||
} else {
|
||||
assert!(op.is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if supports_ntt {
|
||||
assert!(op.is_some());
|
||||
} else {
|
||||
assert!(op.is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bijection() {
|
||||
let ntests = 100;
|
||||
let mut rng = thread_rng();
|
||||
#[test]
|
||||
fn bijection() {
|
||||
let ntests = 100;
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for size in [8, 1024] {
|
||||
for p in [1153, 4611686018326724609] {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
for size in [8, 1024] {
|
||||
for p in [1153, 4611686018326724609] {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
|
||||
if supports_ntt(p, size) {
|
||||
let op = NttOperator::new(&q, size).unwrap();
|
||||
if supports_ntt(p, size) {
|
||||
let op = NttOperator::new(&q, size).unwrap();
|
||||
|
||||
for _ in 0..ntests {
|
||||
let mut a = q.random_vec(size, &mut rng);
|
||||
let a_clone = a.clone();
|
||||
let mut b = a.clone();
|
||||
for _ in 0..ntests {
|
||||
let mut a = q.random_vec(size, &mut rng);
|
||||
let a_clone = a.clone();
|
||||
let mut b = a.clone();
|
||||
|
||||
op.forward(&mut a);
|
||||
assert_ne!(a, a_clone);
|
||||
op.forward(&mut a);
|
||||
assert_ne!(a, a_clone);
|
||||
|
||||
unsafe { op.forward_vt(b.as_mut_ptr()) }
|
||||
assert_eq!(a, b);
|
||||
unsafe { op.forward_vt(b.as_mut_ptr()) }
|
||||
assert_eq!(a, b);
|
||||
|
||||
op.backward(&mut a);
|
||||
assert_eq!(a, a_clone);
|
||||
op.backward(&mut a);
|
||||
assert_eq!(a, a_clone);
|
||||
|
||||
unsafe { op.backward_vt(b.as_mut_ptr()) }
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
unsafe { op.backward_vt(b.as_mut_ptr()) }
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_lazy() {
|
||||
let ntests = 100;
|
||||
let mut rng = thread_rng();
|
||||
#[test]
|
||||
fn forward_lazy() {
|
||||
let ntests = 100;
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for size in [8, 1024] {
|
||||
for p in [1153, 4611686018326724609] {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
for size in [8, 1024] {
|
||||
for p in [1153, 4611686018326724609] {
|
||||
let q = Modulus::new(p).unwrap();
|
||||
|
||||
if supports_ntt(p, size) {
|
||||
let op = NttOperator::new(&q, size).unwrap();
|
||||
if supports_ntt(p, size) {
|
||||
let op = NttOperator::new(&q, size).unwrap();
|
||||
|
||||
for _ in 0..ntests {
|
||||
let mut a = q.random_vec(size, &mut rng);
|
||||
let mut a_lazy = a.clone();
|
||||
for _ in 0..ntests {
|
||||
let mut a = q.random_vec(size, &mut rng);
|
||||
let mut a_lazy = a.clone();
|
||||
|
||||
op.forward(&mut a);
|
||||
op.forward(&mut a);
|
||||
|
||||
unsafe {
|
||||
op.forward_vt_lazy(a_lazy.as_mut_ptr());
|
||||
q.reduce_vec(&mut a_lazy);
|
||||
}
|
||||
unsafe {
|
||||
op.forward_vt_lazy(a_lazy.as_mut_ptr());
|
||||
q.reduce_vec(&mut a_lazy);
|
||||
}
|
||||
|
||||
assert_eq!(a, a_lazy);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_eq!(a, a_lazy);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,113 +7,113 @@ use num_bigint::BigUint;
|
||||
/// These optimized operations are possible when the modulus verifies
|
||||
/// Equation (1) of <https://hal.archives-ouvertes.fr/hal-01242273/document>.
|
||||
pub fn supports_opt(p: u64) -> bool {
|
||||
if p.leading_zeros() < 1 {
|
||||
return false;
|
||||
}
|
||||
if p.leading_zeros() < 1 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Let's multiply the inequality by (2^s0+1)*2^(3s0):
|
||||
// we want to output true when
|
||||
// (2^(3s0)+1) * 2^64 < 2^(3s0) * (2^s0+1) * p
|
||||
let mut middle = BigUint::from(1u64) << (3 * p.leading_zeros() as usize);
|
||||
let left_side = (&middle + 1u64) << 64;
|
||||
middle *= (1u64 << p.leading_zeros()) + 1;
|
||||
middle *= p;
|
||||
// Let's multiply the inequality by (2^s0+1)*2^(3s0):
|
||||
// we want to output true when
|
||||
// (2^(3s0)+1) * 2^64 < 2^(3s0) * (2^s0+1) * p
|
||||
let mut middle = BigUint::from(1u64) << (3 * p.leading_zeros() as usize);
|
||||
let left_side = (&middle + 1u64) << 64;
|
||||
middle *= (1u64 << p.leading_zeros()) + 1;
|
||||
middle *= p;
|
||||
|
||||
left_side < middle
|
||||
left_side < middle
|
||||
}
|
||||
|
||||
/// Generate a `num_bits`-bit prime, congruent to 1 mod `modulo`, strictly
|
||||
/// smaller than `upper_bound`. Note that `num_bits` must belong to (10..=62),
|
||||
/// and upper_bound must be <= 1 << num_bits.
|
||||
pub fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option<u64> {
|
||||
if !(10..=62).contains(&num_bits) {
|
||||
None
|
||||
} else {
|
||||
debug_assert!(
|
||||
(1u64 << num_bits) >= upper_bound,
|
||||
"upper_bound larger than number of bits"
|
||||
);
|
||||
if !(10..=62).contains(&num_bits) {
|
||||
None
|
||||
} else {
|
||||
debug_assert!(
|
||||
(1u64 << num_bits) >= upper_bound,
|
||||
"upper_bound larger than number of bits"
|
||||
);
|
||||
|
||||
let leading_zeros = (64 - num_bits) as u32;
|
||||
let leading_zeros = (64 - num_bits) as u32;
|
||||
|
||||
let mut tentative_prime = upper_bound - 1;
|
||||
while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros {
|
||||
tentative_prime -= 1
|
||||
}
|
||||
let mut tentative_prime = upper_bound - 1;
|
||||
while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros {
|
||||
tentative_prime -= 1
|
||||
}
|
||||
|
||||
while tentative_prime.leading_zeros() == leading_zeros
|
||||
&& !is_prime(tentative_prime)
|
||||
&& tentative_prime >= modulo
|
||||
{
|
||||
tentative_prime -= modulo
|
||||
}
|
||||
while tentative_prime.leading_zeros() == leading_zeros
|
||||
&& !is_prime(tentative_prime)
|
||||
&& tentative_prime >= modulo
|
||||
{
|
||||
tentative_prime -= modulo
|
||||
}
|
||||
|
||||
if tentative_prime.leading_zeros() == leading_zeros && is_prime(tentative_prime) {
|
||||
Some(tentative_prime)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
if tentative_prime.leading_zeros() == leading_zeros && is_prime(tentative_prime) {
|
||||
Some(tentative_prime)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::generate_prime;
|
||||
use fhe_util::catch_unwind;
|
||||
use super::generate_prime;
|
||||
use fhe_util::catch_unwind;
|
||||
|
||||
// Verifies that the same moduli as in the NFLlib library are generated.
|
||||
// <https://github.com/quarkslab/NFLlib/blob/master/include/nfl/params.hpp>
|
||||
#[test]
|
||||
fn nfl_62bit_primes() {
|
||||
let mut generated = vec![];
|
||||
let mut upper_bound = u64::MAX >> 2;
|
||||
while generated.len() != 20 {
|
||||
let p = generate_prime(62, 2 * 1048576, upper_bound);
|
||||
assert!(p.is_some());
|
||||
upper_bound = p.unwrap();
|
||||
generated.push(upper_bound);
|
||||
}
|
||||
assert_eq!(
|
||||
generated,
|
||||
vec![
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
4611686018257518593,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
4611686018106523649,
|
||||
4611686018058289153,
|
||||
4611686018051997697,
|
||||
4611686017974403073,
|
||||
4611686017812922369,
|
||||
4611686017781465089,
|
||||
4611686017773076481,
|
||||
4611686017678704641,
|
||||
4611686017666121729,
|
||||
4611686017647247361,
|
||||
4611686017590624257,
|
||||
4611686017554972673,
|
||||
4611686017529806849,
|
||||
4611686017517223937
|
||||
]
|
||||
)
|
||||
}
|
||||
// Verifies that the same moduli as in the NFLlib library are generated.
|
||||
// <https://github.com/quarkslab/NFLlib/blob/master/include/nfl/params.hpp>
|
||||
#[test]
|
||||
fn nfl_62bit_primes() {
|
||||
let mut generated = vec![];
|
||||
let mut upper_bound = u64::MAX >> 2;
|
||||
while generated.len() != 20 {
|
||||
let p = generate_prime(62, 2 * 1048576, upper_bound);
|
||||
assert!(p.is_some());
|
||||
upper_bound = p.unwrap();
|
||||
generated.push(upper_bound);
|
||||
}
|
||||
assert_eq!(
|
||||
generated,
|
||||
vec![
|
||||
4611686018326724609,
|
||||
4611686018309947393,
|
||||
4611686018282684417,
|
||||
4611686018257518593,
|
||||
4611686018232352769,
|
||||
4611686018171535361,
|
||||
4611686018106523649,
|
||||
4611686018058289153,
|
||||
4611686018051997697,
|
||||
4611686017974403073,
|
||||
4611686017812922369,
|
||||
4611686017781465089,
|
||||
4611686017773076481,
|
||||
4611686017678704641,
|
||||
4611686017666121729,
|
||||
4611686017647247361,
|
||||
4611686017590624257,
|
||||
4611686017554972673,
|
||||
4611686017529806849,
|
||||
4611686017517223937
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upper_bound() {
|
||||
debug_assert!(catch_unwind(|| generate_prime(62, 2 * 1048576, (1 << 62) + 1)).is_err());
|
||||
}
|
||||
#[test]
|
||||
fn upper_bound() {
|
||||
debug_assert!(catch_unwind(|| generate_prime(62, 2 * 1048576, (1 << 62) + 1)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn modulo_too_large() {
|
||||
assert!(generate_prime(10, 2048, 1 << 10).is_none());
|
||||
}
|
||||
#[test]
|
||||
fn modulo_too_large() {
|
||||
assert!(generate_prime(10, 2048, 1 << 10).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_found() {
|
||||
// 1033 is the smallest 11-bit prime congruent to 1 modulo 16, so looking for a
|
||||
// smaller one should fail.
|
||||
assert!(generate_prime(11, 16, 1033).is_none());
|
||||
}
|
||||
#[test]
|
||||
fn not_found() {
|
||||
// 1033 is the smallest 11-bit prime congruent to 1 modulo 16, so looking for a
|
||||
// smaller one should fail.
|
||||
assert!(generate_prime(11, 16, 1033).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,21 +13,21 @@ pub trait FheParameters {}
|
||||
|
||||
/// Indicates that an object is parametrized.
|
||||
pub trait FheParametrized {
|
||||
/// The type of the FHE parameters.
|
||||
type Parameters: FheParameters;
|
||||
/// The type of the FHE parameters.
|
||||
type Parameters: FheParameters;
|
||||
}
|
||||
|
||||
/// Indicates that Self parameters can be switched.
|
||||
pub trait FheParametersSwitchable<S: FheParametrized>
|
||||
where
|
||||
Self: FheParametrized,
|
||||
Self: FheParametrized,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Attempt to switch the underlying parameters using the associated
|
||||
/// switcher.
|
||||
fn switch_parameters(&mut self, switcher: &S) -> Result<(), Self::Error>;
|
||||
/// Attempt to switch the underlying parameters using the associated
|
||||
/// switcher.
|
||||
fn switch_parameters(&mut self, switcher: &S) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
/// Encoding used when encoding a [`FhePlaintext`].
|
||||
@@ -36,137 +36,137 @@ pub trait FhePlaintextEncoding {}
|
||||
/// A plaintext which will encode one (or more) value(s).
|
||||
pub trait FhePlaintext
|
||||
where
|
||||
Self: Sized + FheParametrized,
|
||||
Self: Sized + FheParametrized,
|
||||
{
|
||||
/// The type of the encoding.
|
||||
type Encoding: FhePlaintextEncoding;
|
||||
/// The type of the encoding.
|
||||
type Encoding: FhePlaintextEncoding;
|
||||
}
|
||||
|
||||
/// Encode a value using a specified encoding.
|
||||
pub trait FheEncoder<V>
|
||||
where
|
||||
Self: FhePlaintext,
|
||||
Self: FhePlaintext,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Attempt to encode a value using a specified encoding.
|
||||
fn try_encode(
|
||||
value: V,
|
||||
encoding: Self::Encoding,
|
||||
par: &Arc<Self::Parameters>,
|
||||
) -> Result<Self, Self::Error>;
|
||||
/// Attempt to encode a value using a specified encoding.
|
||||
fn try_encode(
|
||||
value: V,
|
||||
encoding: Self::Encoding,
|
||||
par: &Arc<Self::Parameters>,
|
||||
) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
/// Encode a value using a specified encoding.
|
||||
pub trait FheEncoderVariableTime<V>
|
||||
where
|
||||
Self: FhePlaintext,
|
||||
Self: FhePlaintext,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Attempt to encode a value using a specified encoding.
|
||||
/// # Safety
|
||||
/// This encoding runs in variable time and may leak information about the
|
||||
/// value.
|
||||
unsafe fn try_encode_vt(
|
||||
value: V,
|
||||
encoding: Self::Encoding,
|
||||
par: &Arc<Self::Parameters>,
|
||||
) -> Result<Self, Self::Error>;
|
||||
/// Attempt to encode a value using a specified encoding.
|
||||
/// # Safety
|
||||
/// This encoding runs in variable time and may leak information about the
|
||||
/// value.
|
||||
unsafe fn try_encode_vt(
|
||||
value: V,
|
||||
encoding: Self::Encoding,
|
||||
par: &Arc<Self::Parameters>,
|
||||
) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
/// Decode the value in the plaintext with the specified (optional) encoding.
|
||||
pub trait FheDecoder<P: FhePlaintext>
|
||||
where
|
||||
Self: Sized,
|
||||
Self: Sized,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Attempt to decode a [`FhePlaintext`] into a value, using an (optional)
|
||||
/// encoding.
|
||||
fn try_decode<O>(pt: &P, encoding: O) -> Result<Self, Self::Error>
|
||||
where
|
||||
O: Into<Option<P::Encoding>>;
|
||||
/// Attempt to decode a [`FhePlaintext`] into a value, using an (optional)
|
||||
/// encoding.
|
||||
fn try_decode<O>(pt: &P, encoding: O) -> Result<Self, Self::Error>
|
||||
where
|
||||
O: Into<Option<P::Encoding>>;
|
||||
}
|
||||
|
||||
/// A ciphertext which will encrypt a plaintext.
|
||||
pub trait FheCiphertext
|
||||
where
|
||||
Self: Sized + Serialize + FheParametrized + DeserializeParametrized,
|
||||
Self: Sized + Serialize + FheParametrized + DeserializeParametrized,
|
||||
{
|
||||
}
|
||||
|
||||
/// Encrypt a plaintext into a ciphertext.
|
||||
pub trait FheEncrypter<
|
||||
P: FhePlaintext<Parameters = Self::Parameters>,
|
||||
C: FheCiphertext<Parameters = Self::Parameters>,
|
||||
P: FhePlaintext<Parameters = Self::Parameters>,
|
||||
C: FheCiphertext<Parameters = Self::Parameters>,
|
||||
>: FheParametrized
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Try to encrypt an [`FhePlaintext`] into an [`FheCiphertext`].
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(&self, pt: &P, rng: &mut R) -> Result<C, Self::Error>;
|
||||
/// Try to encrypt an [`FhePlaintext`] into an [`FheCiphertext`].
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(&self, pt: &P, rng: &mut R) -> Result<C, Self::Error>;
|
||||
}
|
||||
|
||||
/// Decrypt a ciphertext into a plaintext
|
||||
pub trait FheDecrypter<
|
||||
P: FhePlaintext<Parameters = Self::Parameters>,
|
||||
C: FheCiphertext<Parameters = Self::Parameters>,
|
||||
P: FhePlaintext<Parameters = Self::Parameters>,
|
||||
C: FheCiphertext<Parameters = Self::Parameters>,
|
||||
>: FheParametrized
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Try to decrypt an [`FheCiphertext`] into an [`FhePlaintext`].
|
||||
fn try_decrypt(&self, ct: &C) -> Result<P, Self::Error>;
|
||||
/// Try to decrypt an [`FheCiphertext`] into an [`FhePlaintext`].
|
||||
fn try_decrypt(&self, ct: &C) -> Result<P, Self::Error>;
|
||||
}
|
||||
|
||||
/// Serialization.
|
||||
pub trait Serialize {
|
||||
/// Serialize `Self` into a vector of bytes.
|
||||
fn to_bytes(&self) -> Vec<u8>;
|
||||
/// Serialize `Self` into a vector of bytes.
|
||||
fn to_bytes(&self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Deserialization of a parametrized value.
|
||||
pub trait DeserializeParametrized
|
||||
where
|
||||
Self: Sized,
|
||||
Self: FheParametrized,
|
||||
Self: Sized,
|
||||
Self: FheParametrized,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Attempt to deserialize from a vector of bytes
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self, Self::Error>;
|
||||
/// Attempt to deserialize from a vector of bytes
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
/// Deserialization setting an explicit context.
|
||||
pub trait DeserializeWithContext
|
||||
where
|
||||
Self: Sized,
|
||||
Self: Sized,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// The type of context.
|
||||
type Context;
|
||||
/// The type of context.
|
||||
type Context;
|
||||
|
||||
/// Attempt to deserialize from a vector of bytes
|
||||
fn from_bytes(bytes: &[u8], ctx: &Arc<Self::Context>) -> Result<Self, Self::Error>;
|
||||
/// Attempt to deserialize from a vector of bytes
|
||||
fn from_bytes(bytes: &[u8], ctx: &Arc<Self::Context>) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
/// Deserialization without context.
|
||||
pub trait Deserialize
|
||||
where
|
||||
Self: Sized,
|
||||
Self: Sized,
|
||||
{
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
/// The type of error returned.
|
||||
type Error;
|
||||
|
||||
/// Attempt to deserialize from a vector of bytes
|
||||
fn try_deserialize(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
/// Attempt to deserialize from a vector of bytes
|
||||
fn try_deserialize(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,217 +8,217 @@ use std::ops::{Not, Shr, ShrAssign};
|
||||
pub struct U256(u64, u64, u64, u64);
|
||||
|
||||
impl U256 {
|
||||
/// Returns the additive identity element, 0.
|
||||
pub const fn zero() -> Self {
|
||||
Self(0, 0, 0, 0)
|
||||
}
|
||||
/// Returns the additive identity element, 0.
|
||||
pub const fn zero() -> Self {
|
||||
Self(0, 0, 0, 0)
|
||||
}
|
||||
|
||||
/// Add an U256 to self, wrapping modulo 2^256.
|
||||
pub fn wrapping_add_assign(&mut self, other: Self) {
|
||||
let (a, c1) = self.0.overflowing_add(other.0);
|
||||
let (b, c2) = self.1.overflowing_add(other.1);
|
||||
let (c, c3) = b.overflowing_add(c1 as u64);
|
||||
let (d, c4) = self.2.overflowing_add(other.2);
|
||||
let (e, c5) = d.overflowing_add((c2 | c3) as u64);
|
||||
let f = self.3.wrapping_add(other.3);
|
||||
let g = f.wrapping_add((c4 | c5) as u64);
|
||||
self.0 = a;
|
||||
self.1 = c;
|
||||
self.2 = e;
|
||||
self.3 = g;
|
||||
}
|
||||
/// Add an U256 to self, wrapping modulo 2^256.
|
||||
pub fn wrapping_add_assign(&mut self, other: Self) {
|
||||
let (a, c1) = self.0.overflowing_add(other.0);
|
||||
let (b, c2) = self.1.overflowing_add(other.1);
|
||||
let (c, c3) = b.overflowing_add(c1 as u64);
|
||||
let (d, c4) = self.2.overflowing_add(other.2);
|
||||
let (e, c5) = d.overflowing_add((c2 | c3) as u64);
|
||||
let f = self.3.wrapping_add(other.3);
|
||||
let g = f.wrapping_add((c4 | c5) as u64);
|
||||
self.0 = a;
|
||||
self.1 = c;
|
||||
self.2 = e;
|
||||
self.3 = g;
|
||||
}
|
||||
|
||||
/// Subtract an U256 to self, wrapping modulo 2^256.
|
||||
pub fn wrapping_sub_assign(&mut self, other: Self) {
|
||||
let (a, b1) = self.0.overflowing_sub(other.0);
|
||||
let (b, b2) = self.1.overflowing_sub(other.1);
|
||||
let (c, b3) = b.overflowing_sub(b1 as u64);
|
||||
let (d, b4) = self.2.overflowing_sub(other.2);
|
||||
let (e, b5) = d.overflowing_sub((b2 | b3) as u64);
|
||||
let f = self.3.wrapping_sub(other.3);
|
||||
let g = f.wrapping_sub((b4 | b5) as u64);
|
||||
self.0 = a;
|
||||
self.1 = c;
|
||||
self.2 = e;
|
||||
self.3 = g;
|
||||
}
|
||||
/// Subtract an U256 to self, wrapping modulo 2^256.
|
||||
pub fn wrapping_sub_assign(&mut self, other: Self) {
|
||||
let (a, b1) = self.0.overflowing_sub(other.0);
|
||||
let (b, b2) = self.1.overflowing_sub(other.1);
|
||||
let (c, b3) = b.overflowing_sub(b1 as u64);
|
||||
let (d, b4) = self.2.overflowing_sub(other.2);
|
||||
let (e, b5) = d.overflowing_sub((b2 | b3) as u64);
|
||||
let f = self.3.wrapping_sub(other.3);
|
||||
let g = f.wrapping_sub((b4 | b5) as u64);
|
||||
self.0 = a;
|
||||
self.1 = c;
|
||||
self.2 = e;
|
||||
self.3 = g;
|
||||
}
|
||||
|
||||
/// Returns the most significant bit of the unsigned integer.
|
||||
pub const fn msb(self) -> u64 {
|
||||
self.3 >> 63
|
||||
}
|
||||
/// Returns the most significant bit of the unsigned integer.
|
||||
pub const fn msb(self) -> u64 {
|
||||
self.3 >> 63
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u64; 4]> for U256 {
|
||||
fn from(a: [u64; 4]) -> Self {
|
||||
Self(a[0], a[1], a[2], a[3])
|
||||
}
|
||||
fn from(a: [u64; 4]) -> Self {
|
||||
Self(a[0], a[1], a[2], a[3])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u128; 2]> for U256 {
|
||||
fn from(a: [u128; 2]) -> Self {
|
||||
Self(
|
||||
a[0] as u64,
|
||||
(a[0] >> 64) as u64,
|
||||
a[1] as u64,
|
||||
(a[1] >> 64) as u64,
|
||||
)
|
||||
}
|
||||
fn from(a: [u128; 2]) -> Self {
|
||||
Self(
|
||||
a[0] as u64,
|
||||
(a[0] >> 64) as u64,
|
||||
a[1] as u64,
|
||||
(a[1] >> 64) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<U256> for [u64; 4] {
|
||||
fn from(a: U256) -> [u64; 4] {
|
||||
[a.0, a.1, a.2, a.3]
|
||||
}
|
||||
fn from(a: U256) -> [u64; 4] {
|
||||
[a.0, a.1, a.2, a.3]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<U256> for [u128; 2] {
|
||||
fn from(a: U256) -> [u128; 2] {
|
||||
[
|
||||
(a.0 as u128) + ((a.1 as u128) << 64),
|
||||
(a.2 as u128) + ((a.3 as u128) << 64),
|
||||
]
|
||||
}
|
||||
fn from(a: U256) -> [u128; 2] {
|
||||
[
|
||||
(a.0 as u128) + ((a.1 as u128) << 64),
|
||||
(a.2 as u128) + ((a.3 as u128) << 64),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&U256> for u128 {
|
||||
fn from(v: &U256) -> Self {
|
||||
debug_assert!(v.2 == 0 && v.3 == 0);
|
||||
(v.0 as u128) + ((v.1 as u128) << 64)
|
||||
}
|
||||
fn from(v: &U256) -> Self {
|
||||
debug_assert!(v.2 == 0 && v.3 == 0);
|
||||
(v.0 as u128) + ((v.1 as u128) << 64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Not for U256 {
|
||||
type Output = Self;
|
||||
type Output = Self;
|
||||
|
||||
fn not(self) -> Self {
|
||||
Self(!self.0, !self.1, !self.2, !self.3)
|
||||
}
|
||||
fn not(self) -> Self {
|
||||
Self(!self.0, !self.1, !self.2, !self.3)
|
||||
}
|
||||
}
|
||||
|
||||
impl Shr<usize> for U256 {
|
||||
type Output = Self;
|
||||
type Output = Self;
|
||||
|
||||
fn shr(self, rhs: usize) -> Self {
|
||||
let mut r = self;
|
||||
r >>= rhs;
|
||||
r
|
||||
}
|
||||
fn shr(self, rhs: usize) -> Self {
|
||||
let mut r = self;
|
||||
r >>= rhs;
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
impl ShrAssign<usize> for U256 {
|
||||
fn shr_assign(&mut self, rhs: usize) {
|
||||
debug_assert!(rhs < 256);
|
||||
fn shr_assign(&mut self, rhs: usize) {
|
||||
debug_assert!(rhs < 256);
|
||||
|
||||
if rhs >= 192 {
|
||||
self.0 = self.3 >> (rhs - 192);
|
||||
self.1 = 0;
|
||||
self.2 = 0;
|
||||
self.3 = 0;
|
||||
} else if rhs > 128 {
|
||||
self.0 = (self.2 >> (rhs - 128)) | (self.3 << (192 - rhs));
|
||||
self.1 = self.3 >> (rhs - 128);
|
||||
self.2 = 0;
|
||||
self.3 = 0;
|
||||
} else if rhs == 128 {
|
||||
self.0 = self.2;
|
||||
self.1 = self.3;
|
||||
self.2 = 0;
|
||||
self.3 = 0;
|
||||
} else if rhs > 64 {
|
||||
self.0 = (self.1 >> (rhs - 64)) | (self.2 << (128 - rhs));
|
||||
self.1 = (self.2 >> (rhs - 64)) | (self.3 << (128 - rhs));
|
||||
self.2 = self.3 >> (rhs - 64);
|
||||
self.3 = 0;
|
||||
} else if rhs == 64 {
|
||||
self.0 = self.1;
|
||||
self.1 = self.2;
|
||||
self.2 = self.3;
|
||||
self.3 = 0;
|
||||
} else if rhs > 0 {
|
||||
self.0 = (self.0 >> rhs) | (self.1 << (64 - rhs));
|
||||
self.1 = (self.1 >> rhs) | (self.2 << (64 - rhs));
|
||||
self.2 = (self.2 >> rhs) | (self.3 << (64 - rhs));
|
||||
self.3 >>= rhs;
|
||||
}
|
||||
}
|
||||
if rhs >= 192 {
|
||||
self.0 = self.3 >> (rhs - 192);
|
||||
self.1 = 0;
|
||||
self.2 = 0;
|
||||
self.3 = 0;
|
||||
} else if rhs > 128 {
|
||||
self.0 = (self.2 >> (rhs - 128)) | (self.3 << (192 - rhs));
|
||||
self.1 = self.3 >> (rhs - 128);
|
||||
self.2 = 0;
|
||||
self.3 = 0;
|
||||
} else if rhs == 128 {
|
||||
self.0 = self.2;
|
||||
self.1 = self.3;
|
||||
self.2 = 0;
|
||||
self.3 = 0;
|
||||
} else if rhs > 64 {
|
||||
self.0 = (self.1 >> (rhs - 64)) | (self.2 << (128 - rhs));
|
||||
self.1 = (self.2 >> (rhs - 64)) | (self.3 << (128 - rhs));
|
||||
self.2 = self.3 >> (rhs - 64);
|
||||
self.3 = 0;
|
||||
} else if rhs == 64 {
|
||||
self.0 = self.1;
|
||||
self.1 = self.2;
|
||||
self.2 = self.3;
|
||||
self.3 = 0;
|
||||
} else if rhs > 0 {
|
||||
self.0 = (self.0 >> rhs) | (self.1 << (64 - rhs));
|
||||
self.1 = (self.1 >> rhs) | (self.2 << (64 - rhs));
|
||||
self.2 = (self.2 >> rhs) | (self.3 << (64 - rhs));
|
||||
self.3 >>= rhs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::U256;
|
||||
use super::U256;
|
||||
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert_eq!(u128::from(&U256::zero()), 0u128);
|
||||
}
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert_eq!(u128::from(&U256::zero()), 0u128);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
proptest! {
|
||||
|
||||
#[test]
|
||||
fn u128(a: u128) {
|
||||
prop_assert_eq!(a, u128::from(&U256::from([a, 0])));
|
||||
}
|
||||
#[test]
|
||||
fn u128(a: u128) {
|
||||
prop_assert_eq!(a, u128::from(&U256::from([a, 0])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_into_u64(a: u64, b: u64, c: u64, d:u64) {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d])), [a, b, c, d]);
|
||||
}
|
||||
#[test]
|
||||
fn from_into_u64(a: u64, b: u64, c: u64, d:u64) {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d])), [a, b, c, d]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_into_u128(a: u128, b: u128) {
|
||||
prop_assert_eq!(<[u128; 2]>::from(U256::from([a, b])), [a, b]);
|
||||
}
|
||||
#[test]
|
||||
fn from_into_u128(a: u128, b: u128) {
|
||||
prop_assert_eq!(<[u128; 2]>::from(U256::from([a, b])), [a, b]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shift(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 0), [a, b, c, d]);
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 64), [b, c, d, 0]);
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 128), [c, d, 0, 0]);
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 192), [d, 0, 0, 0]);
|
||||
#[test]
|
||||
fn shift(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 0), [a, b, c, d]);
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 64), [b, c, d, 0]);
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 128), [c, d, 0, 0]);
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 192), [d, 0, 0, 0]);
|
||||
|
||||
prop_assume!(shift % 64 != 0);
|
||||
if shift < 64 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
|
||||
} else if shift < 128 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
|
||||
} else if shift < 192 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
|
||||
} else {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(d >> (shift - 192)), 0, 0, 0]);
|
||||
}
|
||||
}
|
||||
prop_assume!(shift % 64 != 0);
|
||||
if shift < 64 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
|
||||
} else if shift < 128 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
|
||||
} else if shift < 192 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
|
||||
} else {
|
||||
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(d >> (shift - 192)), 0, 0, 0]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shift_assign(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
|
||||
let mut u = U256::from([a, b, c, d]);
|
||||
#[test]
|
||||
fn shift_assign(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
|
||||
let mut u = U256::from([a, b, c, d]);
|
||||
|
||||
u >>= 0;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [a, b, c, d]);
|
||||
u >>= 0;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [a, b, c, d]);
|
||||
|
||||
u >>= 64;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [b, c, d, 0]);
|
||||
u >>= 64;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [b, c, d, 0]);
|
||||
|
||||
u = U256::from([a, b, c, d]);
|
||||
u >>= 128;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [c, d, 0, 0]);
|
||||
u = U256::from([a, b, c, d]);
|
||||
u >>= 128;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [c, d, 0, 0]);
|
||||
|
||||
u = U256::from([a, b, c, d]);
|
||||
u >>= 192;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [d, 0, 0, 0]);
|
||||
u = U256::from([a, b, c, d]);
|
||||
u >>= 192;
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [d, 0, 0, 0]);
|
||||
|
||||
prop_assume!(shift % 64 != 0);
|
||||
u = U256::from([a, b, c, d]);
|
||||
u >>= shift;
|
||||
if shift < 64 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
|
||||
} else if shift < 128 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
|
||||
} else if shift < 192 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
|
||||
} else {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(d >> (shift - 192)), 0, 0, 0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
prop_assume!(shift % 64 != 0);
|
||||
u = U256::from([a, b, c, d]);
|
||||
u >>= shift;
|
||||
if shift < 64 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
|
||||
} else if shift < 128 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
|
||||
} else if shift < 192 {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
|
||||
} else {
|
||||
prop_assert_eq!(<[u64; 4]>::from(u), [(d >> (shift - 192)), 0, 0, 0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use fhe::bfv::{
|
||||
BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Multiplicator, Plaintext, PublicKey,
|
||||
RelinearizationKey, SecretKey,
|
||||
BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Multiplicator, Plaintext, PublicKey,
|
||||
RelinearizationKey, SecretKey,
|
||||
};
|
||||
use fhe_math::rns::{RnsContext, ScalingFactor};
|
||||
use fhe_math::zq::primes::generate_prime;
|
||||
@@ -13,275 +13,275 @@ use rand::{rngs::OsRng, thread_rng};
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn bfv_benchmark(c: &mut Criterion) {
|
||||
let mut rng = thread_rng();
|
||||
let mut group = c.benchmark_group("bfv");
|
||||
group.sample_size(10);
|
||||
group.warm_up_time(Duration::from_millis(600));
|
||||
group.measurement_time(Duration::from_millis(1000));
|
||||
let mut rng = thread_rng();
|
||||
let mut group = c.benchmark_group("bfv");
|
||||
group.sample_size(10);
|
||||
group.warm_up_time(Duration::from_millis(600));
|
||||
group.measurement_time(Duration::from_millis(1000));
|
||||
|
||||
for par in BfvParameters::default_parameters_128(20) {
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let ek = if par.moduli().len() > 1 {
|
||||
Some(
|
||||
EvaluationKeyBuilder::new(&sk)
|
||||
.unwrap()
|
||||
.enable_inner_sum()
|
||||
.unwrap()
|
||||
.enable_column_rotation(1)
|
||||
.unwrap()
|
||||
.enable_expansion(ilog2(par.degree() as u64))
|
||||
.unwrap()
|
||||
.build(&mut rng)
|
||||
.unwrap(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
for par in BfvParameters::default_parameters_128(20) {
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let ek = if par.moduli().len() > 1 {
|
||||
Some(
|
||||
EvaluationKeyBuilder::new(&sk)
|
||||
.unwrap()
|
||||
.enable_inner_sum()
|
||||
.unwrap()
|
||||
.enable_column_rotation(1)
|
||||
.unwrap()
|
||||
.enable_expansion(ilog2(par.degree() as u64))
|
||||
.unwrap()
|
||||
.build(&mut rng)
|
||||
.unwrap(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let rk = if par.moduli().len() > 1 {
|
||||
Some(RelinearizationKey::new(&sk, &mut rng).unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let rk = if par.moduli().len() > 1 {
|
||||
Some(RelinearizationKey::new(&sk, &mut rng).unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par).unwrap();
|
||||
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), &par).unwrap();
|
||||
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
|
||||
let c2: Ciphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
|
||||
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par).unwrap();
|
||||
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), &par).unwrap();
|
||||
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
|
||||
let c2: Ciphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
|
||||
|
||||
let q = par.moduli_sizes().iter().sum::<usize>();
|
||||
let q = par.moduli_sizes().iter().sum::<usize>();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("keygen_sk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| SecretKey::random(&par, &mut OsRng));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("keygen_sk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| SecretKey::random(&par, &mut OsRng));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("keygen_pk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| PublicKey::new(&sk, &mut rng));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("keygen_pk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| PublicKey::new(&sk, &mut rng));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("keygen_rk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| RelinearizationKey::new(&sk, &mut rng));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("keygen_rk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| RelinearizationKey::new(&sk, &mut rng));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("encode_poly", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), &par));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("encode_poly", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), &par));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("encode_simd", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("encode_simd", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("encrypt_sk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let _: fhe::Result<Ciphertext> = sk.try_encrypt(&pt1, &mut rng);
|
||||
});
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("encrypt_sk", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let _: fhe::Result<Ciphertext> = sk.try_encrypt(&pt1, &mut rng);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 + &c2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 + &c2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 += &c2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 += &c2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 + &pt2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 + &pt2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 += &pt2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("add_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 += &pt2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 - &c2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 - &c2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 -= &c2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 -= &c2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 - &pt2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = &c1 - &pt2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 -= &pt2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("sub_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 -= &pt2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("neg", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = -&c2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("neg", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = -&c2);
|
||||
},
|
||||
);
|
||||
|
||||
let mut c3 = &c1 * &c1;
|
||||
let c3_clone = c3.clone();
|
||||
if let Some(rk) = rk.as_ref() {
|
||||
group.bench_function(
|
||||
BenchmarkId::new("relinearize", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
assert!(rk.relinearizes(&mut c3).is_ok());
|
||||
c3 = c3_clone.clone();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
let mut c3 = &c1 * &c1;
|
||||
let c3_clone = c3.clone();
|
||||
if let Some(rk) = rk.as_ref() {
|
||||
group.bench_function(
|
||||
BenchmarkId::new("relinearize", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
assert!(rk.relinearizes(&mut c3).is_ok());
|
||||
c3 = c3_clone.clone();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(ek) = ek {
|
||||
group.bench_function(
|
||||
BenchmarkId::new("rotate_rows", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = ek.rotates_rows(&c1).unwrap());
|
||||
},
|
||||
);
|
||||
if let Some(ek) = ek {
|
||||
group.bench_function(
|
||||
BenchmarkId::new("rotate_rows", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = ek.rotates_rows(&c1).unwrap());
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("rotate_columns", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = ek.rotates_columns_by(&c1, 1).unwrap());
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("rotate_columns", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = ek.rotates_columns_by(&c1, 1).unwrap());
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("inner_sum", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = ek.computes_inner_sum(&c1).unwrap());
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("inner_sum", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| c1 = ek.computes_inner_sum(&c1).unwrap());
|
||||
},
|
||||
);
|
||||
|
||||
for i in 1..=ilog2(par.degree() as u64) {
|
||||
if par.degree() > 2048 && i > 4 {
|
||||
continue; // Skip slow benchmarks
|
||||
}
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
format!("expand_{i}"),
|
||||
format!("n={}/log(q)={}", par.degree(), q),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| ek.expands(&c1, 1 << i).unwrap());
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
for i in 1..=ilog2(par.degree() as u64) {
|
||||
if par.degree() > 2048 && i > 4 {
|
||||
continue; // Skip slow benchmarks
|
||||
}
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
format!("expand_{i}"),
|
||||
format!("n={}/log(q)={}", par.degree(), q),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| ek.expands(&c1, 1 << i).unwrap());
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("mul", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| &c1 * &c2);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("mul", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| &c1 * &c2);
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("square", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| &c1 * &c1);
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("square", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| &c1 * &c1);
|
||||
},
|
||||
);
|
||||
|
||||
if let Some(rk) = rk.as_ref() {
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_then_relinearize",
|
||||
format!("n={}/log(q)={}", par.degree(), q),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
c3 = &c1 * &c2;
|
||||
assert!(rk.relinearizes(&mut c3).is_ok());
|
||||
});
|
||||
},
|
||||
);
|
||||
if let Some(rk) = rk.as_ref() {
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_then_relinearize",
|
||||
format!("n={}/log(q)={}", par.degree(), q),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
c3 = &c1 * &c2;
|
||||
assert!(rk.relinearizes(&mut c3).is_ok());
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
// Default multiplication method
|
||||
let multiplicator = Multiplicator::default(rk).unwrap();
|
||||
// Default multiplication method
|
||||
let multiplicator = Multiplicator::default(rk).unwrap();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("mul_and_relin", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("mul_and_relin", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
|
||||
},
|
||||
);
|
||||
|
||||
// Second multiplication option.
|
||||
let nmoduli = div_ceil(q, 62);
|
||||
let mut extended_basis = par.moduli().to_vec();
|
||||
let mut upper_bound = u64::MAX >> 2;
|
||||
while extended_basis.len() != nmoduli + par.moduli().len() {
|
||||
upper_bound = generate_prime(62, 2 * par.degree() as u64, upper_bound).unwrap();
|
||||
if !extended_basis.contains(&upper_bound) {
|
||||
extended_basis.push(upper_bound)
|
||||
}
|
||||
}
|
||||
let rns_q = RnsContext::new(&extended_basis[..par.moduli().len()]).unwrap();
|
||||
let rns_p = RnsContext::new(&extended_basis[par.moduli().len()..]).unwrap();
|
||||
let mut multiplicator = Multiplicator::new(
|
||||
ScalingFactor::one(),
|
||||
ScalingFactor::new(rns_p.modulus(), rns_q.modulus()),
|
||||
&extended_basis,
|
||||
ScalingFactor::new(&BigUint::from(par.plaintext()), rns_p.modulus()),
|
||||
&par,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(multiplicator.enable_relinearization(rk).is_ok());
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_and_relin_2",
|
||||
format!("n={}/log(q)={}", par.degree(), q),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
// Second multiplication option.
|
||||
let nmoduli = div_ceil(q, 62);
|
||||
let mut extended_basis = par.moduli().to_vec();
|
||||
let mut upper_bound = u64::MAX >> 2;
|
||||
while extended_basis.len() != nmoduli + par.moduli().len() {
|
||||
upper_bound = generate_prime(62, 2 * par.degree() as u64, upper_bound).unwrap();
|
||||
if !extended_basis.contains(&upper_bound) {
|
||||
extended_basis.push(upper_bound)
|
||||
}
|
||||
}
|
||||
let rns_q = RnsContext::new(&extended_basis[..par.moduli().len()]).unwrap();
|
||||
let rns_p = RnsContext::new(&extended_basis[par.moduli().len()..]).unwrap();
|
||||
let mut multiplicator = Multiplicator::new(
|
||||
ScalingFactor::one(),
|
||||
ScalingFactor::new(rns_p.modulus(), rns_q.modulus()),
|
||||
&extended_basis,
|
||||
ScalingFactor::new(&BigUint::from(par.plaintext()), rns_p.modulus()),
|
||||
&par,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(multiplicator.enable_relinearization(rk).is_ok());
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"mul_and_relin_2",
|
||||
format!("n={}/log(q)={}", par.degree(), q),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(bfv, bfv_benchmark);
|
||||
|
||||
@@ -6,66 +6,66 @@ use rand::{rngs::OsRng, thread_rng};
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn bfv_benchmark(c: &mut Criterion) {
|
||||
let mut rng = thread_rng();
|
||||
let mut group = c.benchmark_group("bfv_optimized_ops");
|
||||
group.sample_size(10);
|
||||
group.warm_up_time(Duration::from_secs(1));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
let mut rng = thread_rng();
|
||||
let mut group = c.benchmark_group("bfv_optimized_ops");
|
||||
group.sample_size(10);
|
||||
group.warm_up_time(Duration::from_secs(1));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
|
||||
for par in &BfvParameters::default_parameters_128(20)[2..] {
|
||||
for size in [10, 128, 1000] {
|
||||
let sk = SecretKey::random(par, &mut OsRng);
|
||||
let pt1 =
|
||||
Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), par).unwrap();
|
||||
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
|
||||
for par in &BfvParameters::default_parameters_128(20)[2..] {
|
||||
for size in [10, 128, 1000] {
|
||||
let sk = SecretKey::random(par, &mut OsRng);
|
||||
let pt1 =
|
||||
Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), par).unwrap();
|
||||
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
|
||||
|
||||
let ct_vec = (0..size)
|
||||
.map(|i| {
|
||||
let pt =
|
||||
Plaintext::try_encode(&(i..16u64).collect_vec(), Encoding::poly(), par)
|
||||
.unwrap();
|
||||
sk.try_encrypt(&pt, &mut rng).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let pt_vec = (0..size)
|
||||
.map(|i| {
|
||||
Plaintext::try_encode(&(i..39u64).collect_vec(), Encoding::poly(), par).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let ct_vec = (0..size)
|
||||
.map(|i| {
|
||||
let pt =
|
||||
Plaintext::try_encode(&(i..16u64).collect_vec(), Encoding::poly(), par)
|
||||
.unwrap();
|
||||
sk.try_encrypt(&pt, &mut rng).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let pt_vec = (0..size)
|
||||
.map(|i| {
|
||||
Plaintext::try_encode(&(i..39u64).collect_vec(), Encoding::poly(), par).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"dot_product/naive",
|
||||
format!(
|
||||
"size={}/degree={}/logq={}",
|
||||
size,
|
||||
par.degree(),
|
||||
par.moduli_sizes().iter().sum::<usize>()
|
||||
),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| izip!(&ct_vec, &pt_vec).for_each(|(cti, pti)| c1 += &(cti * pti)));
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"dot_product/naive",
|
||||
format!(
|
||||
"size={}/degree={}/logq={}",
|
||||
size,
|
||||
par.degree(),
|
||||
par.moduli_sizes().iter().sum::<usize>()
|
||||
),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| izip!(&ct_vec, &pt_vec).for_each(|(cti, pti)| c1 += &(cti * pti)));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"dot_product/opt",
|
||||
format!(
|
||||
"size={}/degree={}/logq={}",
|
||||
size,
|
||||
par.degree(),
|
||||
par.moduli_sizes().iter().sum::<usize>()
|
||||
),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| dot_product_scalar(ct_vec.iter(), pt_vec.iter()));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"dot_product/opt",
|
||||
format!(
|
||||
"size={}/degree={}/logq={}",
|
||||
size,
|
||||
par.degree(),
|
||||
par.moduli_sizes().iter().sum::<usize>()
|
||||
),
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| dot_product_scalar(ct_vec.iter(), pt_vec.iter()));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(bfv, bfv_benchmark);
|
||||
|
||||
@@ -6,30 +6,30 @@ use rand::{rngs::OsRng, thread_rng};
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn bfv_rgsw_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("bfv_rgsw");
|
||||
group.sample_size(10);
|
||||
group.warm_up_time(Duration::from_secs(1));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
let mut group = c.benchmark_group("bfv_rgsw");
|
||||
group.sample_size(10);
|
||||
group.warm_up_time(Duration::from_secs(1));
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
|
||||
for par in &BfvParameters::default_parameters_128(20)[2..] {
|
||||
let mut rng = thread_rng();
|
||||
let sk = SecretKey::random(par, &mut OsRng);
|
||||
for par in &BfvParameters::default_parameters_128(20)[2..] {
|
||||
let mut rng = thread_rng();
|
||||
let sk = SecretKey::random(par, &mut OsRng);
|
||||
|
||||
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), par).unwrap();
|
||||
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), par).unwrap();
|
||||
let c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
|
||||
let c2: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
|
||||
let q = par.moduli_sizes().iter().sum::<usize>();
|
||||
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), par).unwrap();
|
||||
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), par).unwrap();
|
||||
let c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
|
||||
let c2: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
|
||||
let q = par.moduli_sizes().iter().sum::<usize>();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("external", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| &c1 * &c2);
|
||||
},
|
||||
);
|
||||
}
|
||||
group.bench_function(
|
||||
BenchmarkId::new("external", format!("n={}/log(q)={}", par.degree(), q)),
|
||||
|b| {
|
||||
b.iter(|| &c1 * &c2);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(bfv_rgsw, bfv_rgsw_benchmark);
|
||||
|
||||
@@ -11,260 +11,260 @@ mod util;
|
||||
use console::style;
|
||||
use fhe::bfv;
|
||||
use fhe_traits::{
|
||||
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize,
|
||||
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize,
|
||||
};
|
||||
use fhe_util::{ilog2, inverse, transcode_to_bytes};
|
||||
use indicatif::HumanBytes;
|
||||
use rand::{rngs::OsRng, thread_rng, RngCore};
|
||||
use std::{env, error::Error, process::exit, sync::Arc};
|
||||
use util::{
|
||||
encode_database, generate_database, number_elements_per_plaintext,
|
||||
timeit::{timeit, timeit_n},
|
||||
encode_database, generate_database, number_elements_per_plaintext,
|
||||
timeit::{timeit, timeit_n},
|
||||
};
|
||||
|
||||
fn print_notice_and_exit(max_element_size: usize, error: Option<String>) {
|
||||
println!(
|
||||
"{} MulPIR with fhe.rs",
|
||||
style(" overview:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} mulpir- [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
|
||||
style(" usage:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} {} must be at least 1, and {} must be between 1 and {}",
|
||||
style("constraints:").magenta().bold(),
|
||||
style("database_size").blue(),
|
||||
style("element_size").blue(),
|
||||
max_element_size
|
||||
);
|
||||
if let Some(error) = error {
|
||||
println!("{} {}", style(" error:").red().bold(), error);
|
||||
}
|
||||
exit(0);
|
||||
println!(
|
||||
"{} MulPIR with fhe.rs",
|
||||
style(" overview:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} mulpir- [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
|
||||
style(" usage:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} {} must be at least 1, and {} must be between 1 and {}",
|
||||
style("constraints:").magenta().bold(),
|
||||
style("database_size").blue(),
|
||||
style("element_size").blue(),
|
||||
max_element_size
|
||||
);
|
||||
if let Some(error) = error {
|
||||
println!("{} {}", style(" error:").red().bold(), error);
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
// We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf.
|
||||
let degree = 8192;
|
||||
let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1;
|
||||
let moduli_sizes = [50, 55, 55];
|
||||
// We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf.
|
||||
let degree = 8192;
|
||||
let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1;
|
||||
let moduli_sizes = [50, 55, 55];
|
||||
|
||||
// Compute what is the maximum byte-length of an element to fit within one
|
||||
// ciphertext. Each coefficient of the ciphertext polynomial can contain
|
||||
// floor(log2(plaintext_modulus)) bits.
|
||||
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
|
||||
// Compute what is the maximum byte-length of an element to fit within one
|
||||
// ciphertext. Each coefficient of the ciphertext polynomial can contain
|
||||
// floor(log2(plaintext_modulus)) bits.
|
||||
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
|
||||
|
||||
// This executable is a command line tool which enables to specify different
|
||||
// database and element sizes.
|
||||
let args: Vec<String> = env::args().skip(1).collect();
|
||||
// This executable is a command line tool which enables to specify different
|
||||
// database and element sizes.
|
||||
let args: Vec<String> = env::args().skip(1).collect();
|
||||
|
||||
// Print the help if requested.
|
||||
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
|
||||
print_notice_and_exit(max_element_size, None)
|
||||
}
|
||||
// Print the help if requested.
|
||||
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
|
||||
print_notice_and_exit(max_element_size, None)
|
||||
}
|
||||
|
||||
// Use the default values from <https://eprint.iacr.org/2019/1483.pdf>.
|
||||
let mut database_size = 1 << 20;
|
||||
let mut elements_size = 288;
|
||||
// Use the default values from <https://eprint.iacr.org/2019/1483.pdf>.
|
||||
let mut database_size = 1 << 20;
|
||||
let mut elements_size = 288;
|
||||
|
||||
// Update the database size and/or element size depending on the arguments
|
||||
// provided.
|
||||
for arg in &args {
|
||||
if arg.starts_with("--database_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--database_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
database_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else if arg.starts_with("--element_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--element_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
elements_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some(format!("Unrecognized command: {arg}")),
|
||||
)
|
||||
}
|
||||
}
|
||||
// Update the database size and/or element size depending on the arguments
|
||||
// provided.
|
||||
for arg in &args {
|
||||
if arg.starts_with("--database_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--database_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
database_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else if arg.starts_with("--element_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--element_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
elements_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some(format!("Unrecognized command: {arg}")),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Element or database sizes out of bound".to_string()),
|
||||
)
|
||||
}
|
||||
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Element or database sizes out of bound".to_string()),
|
||||
)
|
||||
}
|
||||
|
||||
// The parameters are within bound, let's go! Let's first display some
|
||||
// information about the database.
|
||||
println!("# MulPIR with fhe.rs");
|
||||
println!(
|
||||
"database of {}",
|
||||
HumanBytes((database_size * elements_size) as u64)
|
||||
);
|
||||
println!("\tdatabase_size = {database_size}");
|
||||
println!("\telements_size = {elements_size}");
|
||||
// The parameters are within bound, let's go! Let's first display some
|
||||
// information about the database.
|
||||
println!("# MulPIR with fhe.rs");
|
||||
println!(
|
||||
"database of {}",
|
||||
HumanBytes((database_size * elements_size) as u64)
|
||||
);
|
||||
println!("\tdatabase_size = {database_size}");
|
||||
println!("\telements_size = {elements_size}");
|
||||
|
||||
// Generation of a random database.
|
||||
let database = timeit!("Database generation", {
|
||||
generate_database(database_size, elements_size)
|
||||
});
|
||||
// Generation of a random database.
|
||||
let database = timeit!("Database generation", {
|
||||
generate_database(database_size, elements_size)
|
||||
});
|
||||
|
||||
// Let's generate the BFV parameters structure.
|
||||
let params = timeit!(
|
||||
"Parameters generation",
|
||||
Arc::new(
|
||||
bfv::BfvParametersBuilder::new()
|
||||
.set_degree(degree)
|
||||
.set_plaintext_modulus(plaintext_modulus)
|
||||
.set_moduli_sizes(&moduli_sizes)
|
||||
.build()
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
// Let's generate the BFV parameters structure.
|
||||
let params = timeit!(
|
||||
"Parameters generation",
|
||||
Arc::new(
|
||||
bfv::BfvParametersBuilder::new()
|
||||
.set_degree(degree)
|
||||
.set_plaintext_modulus(plaintext_modulus)
|
||||
.set_moduli_sizes(&moduli_sizes)
|
||||
.build()
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
|
||||
// Proprocess the database on the server side: the database will be reshaped
|
||||
// so as to pack as many values as possible in every row so that it fits in one
|
||||
// ciphertext, and each element will be encoded as a polynomial in Ntt
|
||||
// representation.
|
||||
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
|
||||
encode_database(&database, params.clone(), 1)
|
||||
});
|
||||
// Proprocess the database on the server side: the database will be reshaped
|
||||
// so as to pack as many values as possible in every row so that it fits in one
|
||||
// ciphertext, and each element will be encoded as a polynomial in Ntt
|
||||
// representation.
|
||||
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
|
||||
encode_database(&database, params.clone(), 1)
|
||||
});
|
||||
|
||||
// Client setup: the client generates a secret key, an evaluation key for
|
||||
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
|
||||
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a
|
||||
// relinearization key.
|
||||
let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", {
|
||||
let sk = bfv::SecretKey::random(¶ms, &mut OsRng);
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
println!("level = {level}");
|
||||
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
|
||||
.enable_expansion(level)?
|
||||
.build(&mut thread_rng())?;
|
||||
let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut thread_rng())?;
|
||||
let ek_expansion_serialized = ek_expansion.to_bytes();
|
||||
let rk_serialized = rk.to_bytes();
|
||||
(sk, ek_expansion_serialized, rk_serialized)
|
||||
});
|
||||
println!(
|
||||
"📄 Evaluation key (expansion): {}",
|
||||
HumanBytes(ek_expansion_serialized.len() as u64)
|
||||
);
|
||||
println!(
|
||||
"📄 Relinearization key: {}",
|
||||
HumanBytes(rk_serialized.len() as u64)
|
||||
);
|
||||
// Client setup: the client generates a secret key, an evaluation key for
|
||||
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
|
||||
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a
|
||||
// relinearization key.
|
||||
let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", {
|
||||
let sk = bfv::SecretKey::random(¶ms, &mut OsRng);
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
println!("level = {level}");
|
||||
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
|
||||
.enable_expansion(level)?
|
||||
.build(&mut thread_rng())?;
|
||||
let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut thread_rng())?;
|
||||
let ek_expansion_serialized = ek_expansion.to_bytes();
|
||||
let rk_serialized = rk.to_bytes();
|
||||
(sk, ek_expansion_serialized, rk_serialized)
|
||||
});
|
||||
println!(
|
||||
"📄 Evaluation key (expansion): {}",
|
||||
HumanBytes(ek_expansion_serialized.len() as u64)
|
||||
);
|
||||
println!(
|
||||
"📄 Relinearization key: {}",
|
||||
HumanBytes(rk_serialized.len() as u64)
|
||||
);
|
||||
|
||||
// Server setup: the server receives the evaluation and relinearization keys and
|
||||
// deserializes them.
|
||||
let (ek_expansion, rk) = timeit!("Server setup", {
|
||||
(
|
||||
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?,
|
||||
bfv::RelinearizationKey::from_bytes(&rk_serialized, ¶ms)?,
|
||||
)
|
||||
});
|
||||
// Server setup: the server receives the evaluation and relinearization keys and
|
||||
// deserializes them.
|
||||
let (ek_expansion, rk) = timeit!("Server setup", {
|
||||
(
|
||||
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?,
|
||||
bfv::RelinearizationKey::from_bytes(&rk_serialized, ¶ms)?,
|
||||
)
|
||||
});
|
||||
|
||||
// Client query: when the client wants to retrieve the `index`-th row of the
|
||||
// original database, it first computes to which row it corresponds in the
|
||||
// original database, and then encrypt a selection vector with 0 everywhere,
|
||||
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
|
||||
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
|
||||
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
|
||||
// The ciphertext is set at level `1`, which means that one of the three moduli
|
||||
// has been dropped already; the reason is that the expansion will happen at
|
||||
// level 0 (with all three moduli) and then one of the moduli will be dropped
|
||||
// to reduce the noise.
|
||||
let index = (thread_rng().next_u64() as usize) % database_size;
|
||||
let query = timeit!("Client query", {
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
let query_index = index
|
||||
/ number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
let mut pt = vec![0u64; dim1 + dim2];
|
||||
let inv = inverse(1 << level, plaintext_modulus).unwrap();
|
||||
pt[query_index / dim2] = inv;
|
||||
pt[dim1 + (query_index % dim2)] = inv;
|
||||
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?;
|
||||
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
|
||||
query.to_bytes()
|
||||
});
|
||||
println!("📄 Query: {}", HumanBytes(query.len() as u64));
|
||||
// Client query: when the client wants to retrieve the `index`-th row of the
|
||||
// original database, it first computes to which row it corresponds in the
|
||||
// original database, and then encrypt a selection vector with 0 everywhere,
|
||||
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
|
||||
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
|
||||
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
|
||||
// The ciphertext is set at level `1`, which means that one of the three moduli
|
||||
// has been dropped already; the reason is that the expansion will happen at
|
||||
// level 0 (with all three moduli) and then one of the moduli will be dropped
|
||||
// to reduce the noise.
|
||||
let index = (thread_rng().next_u64() as usize) % database_size;
|
||||
let query = timeit!("Client query", {
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
let query_index = index
|
||||
/ number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
let mut pt = vec![0u64; dim1 + dim2];
|
||||
let inv = inverse(1 << level, plaintext_modulus).unwrap();
|
||||
pt[query_index / dim2] = inv;
|
||||
pt[dim1 + (query_index % dim2)] = inv;
|
||||
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?;
|
||||
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
|
||||
query.to_bytes()
|
||||
});
|
||||
println!("📄 Query: {}", HumanBytes(query.len() as u64));
|
||||
|
||||
// Server response: The server receives the query, and after deserializing it,
|
||||
// performs the following steps:
|
||||
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
|
||||
// If the client created the query correctly, the server will have obtained
|
||||
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
|
||||
// `dim1 + j`th ones encrypting `1`.
|
||||
// 2- It computes the inner product of the first `dim1` ciphertexts with the
|
||||
// columns if the database viewed as a dim1 * dim2 matrix.
|
||||
// 3- It then multiplies the column of ciphertexts with the next `dim2`
|
||||
// ciphertexts obtained after expansion of the query, then relinearize and
|
||||
// modulus switch to the latest modulus to optimize communication.
|
||||
// The operation is done `5` times to compute an average response time.
|
||||
let response = timeit_n!("Server response", 5, {
|
||||
let start = std::time::Instant::now();
|
||||
let query = bfv::Ciphertext::from_bytes(&query, ¶ms)?;
|
||||
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
|
||||
println!("Expand: {:?}", start.elapsed());
|
||||
// Server response: The server receives the query, and after deserializing it,
|
||||
// performs the following steps:
|
||||
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
|
||||
// If the client created the query correctly, the server will have obtained
|
||||
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
|
||||
// `dim1 + j`th ones encrypting `1`.
|
||||
// 2- It computes the inner product of the first `dim1` ciphertexts with the
|
||||
// columns if the database viewed as a dim1 * dim2 matrix.
|
||||
// 3- It then multiplies the column of ciphertexts with the next `dim2`
|
||||
// ciphertexts obtained after expansion of the query, then relinearize and
|
||||
// modulus switch to the latest modulus to optimize communication.
|
||||
// The operation is done `5` times to compute an average response time.
|
||||
let response = timeit_n!("Server response", 5, {
|
||||
let start = std::time::Instant::now();
|
||||
let query = bfv::Ciphertext::from_bytes(&query, ¶ms)?;
|
||||
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
|
||||
println!("Expand: {:?}", start.elapsed());
|
||||
|
||||
let query_vec = &expanded_query[..dim1];
|
||||
let dot_product_mod_switch =
|
||||
move |i, database: &[bfv::Plaintext]| -> fhe::Result<bfv::Ciphertext> {
|
||||
let column = database.iter().skip(i).step_by(dim2);
|
||||
bfv::dot_product_scalar(query_vec.iter(), column)
|
||||
};
|
||||
let query_vec = &expanded_query[..dim1];
|
||||
let dot_product_mod_switch =
|
||||
move |i, database: &[bfv::Plaintext]| -> fhe::Result<bfv::Ciphertext> {
|
||||
let column = database.iter().skip(i).step_by(dim2);
|
||||
bfv::dot_product_scalar(query_vec.iter(), column)
|
||||
};
|
||||
|
||||
let mut out = bfv::Ciphertext::zero(¶ms);
|
||||
for (i, ci) in expanded_query[dim1..].iter().enumerate() {
|
||||
out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci)
|
||||
}
|
||||
rk.relinearizes(&mut out)?;
|
||||
out.mod_switch_to_last_level();
|
||||
out.to_bytes()
|
||||
});
|
||||
println!("📄 Response: {}", HumanBytes(response.len() as u64));
|
||||
let mut out = bfv::Ciphertext::zero(¶ms);
|
||||
for (i, ci) in expanded_query[dim1..].iter().enumerate() {
|
||||
out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci)
|
||||
}
|
||||
rk.relinearizes(&mut out)?;
|
||||
out.mod_switch_to_last_level();
|
||||
out.to_bytes()
|
||||
});
|
||||
println!("📄 Response: {}", HumanBytes(response.len() as u64));
|
||||
|
||||
// Client processing: Upon reception of the response, the client decrypts.
|
||||
// Finally, it outputs the plaintext bytes, offset by the correct value
|
||||
// (remember the database was reshaped to maximize how many elements) were
|
||||
// embedded in a single ciphertext.
|
||||
let answer = timeit!("Client answer", {
|
||||
let response = bfv::Ciphertext::from_bytes(&response, ¶ms).unwrap();
|
||||
// Client processing: Upon reception of the response, the client decrypts.
|
||||
// Finally, it outputs the plaintext bytes, offset by the correct value
|
||||
// (remember the database was reshaped to maximize how many elements) were
|
||||
// embedded in a single ciphertext.
|
||||
let answer = timeit!("Client answer", {
|
||||
let response = bfv::Ciphertext::from_bytes(&response, ¶ms).unwrap();
|
||||
|
||||
let pt = sk.try_decrypt(&response).unwrap();
|
||||
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
|
||||
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
|
||||
let offset = index
|
||||
% number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
let pt = sk.try_decrypt(&response).unwrap();
|
||||
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
|
||||
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
|
||||
let offset = index
|
||||
% number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
|
||||
println!("Noise in response: {:?}", unsafe {
|
||||
sk.measure_noise(&response)
|
||||
});
|
||||
println!("Noise in response: {:?}", unsafe {
|
||||
sk.measure_noise(&response)
|
||||
});
|
||||
|
||||
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
|
||||
});
|
||||
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
|
||||
});
|
||||
|
||||
assert_eq!(&database[index], &answer);
|
||||
assert_eq!(&database[index], &answer);
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ use console::style;
|
||||
use fhe::bfv;
|
||||
use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation};
|
||||
use fhe_traits::{
|
||||
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime,
|
||||
FheEncrypter, Serialize,
|
||||
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime,
|
||||
FheEncrypter, Serialize,
|
||||
};
|
||||
use fhe_util::{div_ceil, ilog2, inverse, transcode_bidirectional, transcode_to_bytes};
|
||||
use indicatif::HumanBytes;
|
||||
@@ -21,323 +21,323 @@ use itertools::Itertools;
|
||||
use rand::{rngs::OsRng, thread_rng, RngCore};
|
||||
use std::{env, error::Error, process::exit, sync::Arc};
|
||||
use util::{
|
||||
encode_database, generate_database, number_elements_per_plaintext,
|
||||
timeit::{timeit, timeit_n},
|
||||
encode_database, generate_database, number_elements_per_plaintext,
|
||||
timeit::{timeit, timeit_n},
|
||||
};
|
||||
|
||||
fn print_notice_and_exit(max_element_size: usize, error: Option<String>) {
|
||||
println!(
|
||||
"{} SealPIR with fhe.rs",
|
||||
style(" overview:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} sealpir [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
|
||||
style(" usage:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} {} must be at least 1, and {} must be between 1 and {}",
|
||||
style("constraints:").magenta().bold(),
|
||||
style("database_size").blue(),
|
||||
style("element_size").blue(),
|
||||
max_element_size
|
||||
);
|
||||
if let Some(error) = error {
|
||||
println!("{} {}", style(" error:").red().bold(), error);
|
||||
}
|
||||
exit(0);
|
||||
println!(
|
||||
"{} SealPIR with fhe.rs",
|
||||
style(" overview:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} sealpir [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
|
||||
style(" usage:").magenta().bold()
|
||||
);
|
||||
println!(
|
||||
"{} {} must be at least 1, and {} must be between 1 and {}",
|
||||
style("constraints:").magenta().bold(),
|
||||
style("database_size").blue(),
|
||||
style("element_size").blue(),
|
||||
max_element_size
|
||||
);
|
||||
if let Some(error) = error {
|
||||
println!("{} {}", style(" error:").red().bold(), error);
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let degree = 4096usize;
|
||||
let plaintext_modulus = 2056193;
|
||||
let moduli_sizes = [36, 36, 37];
|
||||
let degree = 4096usize;
|
||||
let plaintext_modulus = 2056193;
|
||||
let moduli_sizes = [36, 36, 37];
|
||||
|
||||
// Compute what is the maximum byte-length of an element to fit within one
|
||||
// ciphertext. Each coefficient of the ciphertext polynomial can contain
|
||||
// floor(log2(plaintext_modulus)) bits.
|
||||
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
|
||||
// Compute what is the maximum byte-length of an element to fit within one
|
||||
// ciphertext. Each coefficient of the ciphertext polynomial can contain
|
||||
// floor(log2(plaintext_modulus)) bits.
|
||||
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
|
||||
|
||||
// This executable is a command line tool which enables to specify different
|
||||
// database and element sizes.
|
||||
let args: Vec<String> = env::args().skip(1).collect();
|
||||
// This executable is a command line tool which enables to specify different
|
||||
// database and element sizes.
|
||||
let args: Vec<String> = env::args().skip(1).collect();
|
||||
|
||||
// Print the help if requested.
|
||||
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
|
||||
print_notice_and_exit(max_element_size, None)
|
||||
}
|
||||
// Print the help if requested.
|
||||
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
|
||||
print_notice_and_exit(max_element_size, None)
|
||||
}
|
||||
|
||||
// Use the default values from <https://github.com/microsoft/SealPIR>.
|
||||
let mut database_size = 1 << 16;
|
||||
let mut elements_size = 1024;
|
||||
// Use the default values from <https://github.com/microsoft/SealPIR>.
|
||||
let mut database_size = 1 << 16;
|
||||
let mut elements_size = 1024;
|
||||
|
||||
// Update the database size and/or element size depending on the arguments
|
||||
// provided.
|
||||
for arg in &args {
|
||||
if arg.starts_with("--database_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--database_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
database_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else if arg.starts_with("--element_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--element_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
elements_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some(format!("Unrecognized command: {arg}")),
|
||||
)
|
||||
}
|
||||
}
|
||||
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Element or database sizes out of bound".to_string()),
|
||||
)
|
||||
}
|
||||
// Update the database size and/or element size depending on the arguments
|
||||
// provided.
|
||||
for arg in &args {
|
||||
if arg.starts_with("--database_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--database_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
database_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else if arg.starts_with("--element_size") {
|
||||
let a: Vec<&str> = arg.rsplit('=').collect();
|
||||
if a.len() != 2 || a[0].parse::<usize>().is_err() {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Invalid `--element_size` command".to_string()),
|
||||
)
|
||||
} else {
|
||||
elements_size = a[0].parse::<usize>().unwrap()
|
||||
}
|
||||
} else {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some(format!("Unrecognized command: {arg}")),
|
||||
)
|
||||
}
|
||||
}
|
||||
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
|
||||
print_notice_and_exit(
|
||||
max_element_size,
|
||||
Some("Element or database sizes out of bound".to_string()),
|
||||
)
|
||||
}
|
||||
|
||||
// The parameters are within bound, let's go! Let's first display some
|
||||
// information about the database.
|
||||
println!("# SealPIR with fhe.rs");
|
||||
println!(
|
||||
"database of {}",
|
||||
HumanBytes((database_size * elements_size) as u64)
|
||||
);
|
||||
println!("\tdatabase_size = {database_size}");
|
||||
println!("\telements_size = {elements_size}");
|
||||
// The parameters are within bound, let's go! Let's first display some
|
||||
// information about the database.
|
||||
println!("# SealPIR with fhe.rs");
|
||||
println!(
|
||||
"database of {}",
|
||||
HumanBytes((database_size * elements_size) as u64)
|
||||
);
|
||||
println!("\tdatabase_size = {database_size}");
|
||||
println!("\telements_size = {elements_size}");
|
||||
|
||||
// Generation of a random database.
|
||||
let database = timeit!("Database generation", {
|
||||
generate_database(database_size, elements_size)
|
||||
});
|
||||
// Generation of a random database.
|
||||
let database = timeit!("Database generation", {
|
||||
generate_database(database_size, elements_size)
|
||||
});
|
||||
|
||||
// Let's generate the BFV parameters structure.
|
||||
let params = timeit!(
|
||||
"Parameters generation",
|
||||
Arc::new(
|
||||
bfv::BfvParametersBuilder::new()
|
||||
.set_degree(degree)
|
||||
.set_plaintext_modulus(plaintext_modulus)
|
||||
.set_moduli_sizes(&moduli_sizes)
|
||||
.build()
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
// Let's generate the BFV parameters structure.
|
||||
let params = timeit!(
|
||||
"Parameters generation",
|
||||
Arc::new(
|
||||
bfv::BfvParametersBuilder::new()
|
||||
.set_degree(degree)
|
||||
.set_plaintext_modulus(plaintext_modulus)
|
||||
.set_moduli_sizes(&moduli_sizes)
|
||||
.build()
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
|
||||
// Proprocess the database on the server side: the database will be reshaped
|
||||
// so as to pack as many values as possible in every row so that it fits in one
|
||||
// ciphertext, and each element will be encoded as a polynomial in Ntt
|
||||
// representation.
|
||||
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
|
||||
encode_database(&database, params.clone(), 1)
|
||||
});
|
||||
// Proprocess the database on the server side: the database will be reshaped
|
||||
// so as to pack as many values as possible in every row so that it fits in one
|
||||
// ciphertext, and each element will be encoded as a polynomial in Ntt
|
||||
// representation.
|
||||
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
|
||||
encode_database(&database, params.clone(), 1)
|
||||
});
|
||||
|
||||
// Client setup: the client generates a secret key, and an evaluation key for
|
||||
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
|
||||
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)).
|
||||
let (sk, ek_expansion_serialized) = timeit!("Client setup", {
|
||||
let sk = bfv::SecretKey::random(¶ms, &mut OsRng);
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
println!("expansion_level = {level}");
|
||||
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
|
||||
.enable_expansion(level)?
|
||||
.build(&mut thread_rng())?;
|
||||
let ek_expansion_serialized = ek_expansion.to_bytes();
|
||||
(sk, ek_expansion_serialized)
|
||||
});
|
||||
println!(
|
||||
"📄 Evaluation key: {}",
|
||||
HumanBytes(ek_expansion_serialized.len() as u64)
|
||||
);
|
||||
// Client setup: the client generates a secret key, and an evaluation key for
|
||||
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
|
||||
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)).
|
||||
let (sk, ek_expansion_serialized) = timeit!("Client setup", {
|
||||
let sk = bfv::SecretKey::random(¶ms, &mut OsRng);
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
println!("expansion_level = {level}");
|
||||
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
|
||||
.enable_expansion(level)?
|
||||
.build(&mut thread_rng())?;
|
||||
let ek_expansion_serialized = ek_expansion.to_bytes();
|
||||
(sk, ek_expansion_serialized)
|
||||
});
|
||||
println!(
|
||||
"📄 Evaluation key: {}",
|
||||
HumanBytes(ek_expansion_serialized.len() as u64)
|
||||
);
|
||||
|
||||
// Server setup: the server receives the evaluation key and deserializes it.
|
||||
let ek_expansion = timeit!(
|
||||
"Server setup",
|
||||
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?
|
||||
);
|
||||
// Server setup: the server receives the evaluation key and deserializes it.
|
||||
let ek_expansion = timeit!(
|
||||
"Server setup",
|
||||
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?
|
||||
);
|
||||
|
||||
// Client query: when the client wants to retrieve the `index`-th row of the
|
||||
// original database, it first computes to which row it corresponds in the
|
||||
// original database, and then encrypt a selection vector with 0 everywhere,
|
||||
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
|
||||
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
|
||||
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
|
||||
// The ciphertext is set at level `1`, which means that one of the three moduli
|
||||
// has been dropped already; the reason is that the expansion will happen at
|
||||
// level 0 (with all three moduli) and then one of the moduli will be dropped
|
||||
// to reduce the noise.
|
||||
let index = (thread_rng().next_u64() as usize) % database_size;
|
||||
let query = timeit!("Client query", {
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
let query_index = index
|
||||
/ number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
let mut pt = vec![0u64; dim1 + dim2];
|
||||
let inv = inverse(1 << level, plaintext_modulus).unwrap();
|
||||
pt[query_index / dim2] = inv;
|
||||
pt[dim1 + (query_index % dim2)] = inv;
|
||||
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?;
|
||||
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
|
||||
query.to_bytes()
|
||||
});
|
||||
println!("📄 Query: {}", HumanBytes(query.len() as u64));
|
||||
// Client query: when the client wants to retrieve the `index`-th row of the
|
||||
// original database, it first computes to which row it corresponds in the
|
||||
// original database, and then encrypt a selection vector with 0 everywhere,
|
||||
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
|
||||
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
|
||||
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
|
||||
// The ciphertext is set at level `1`, which means that one of the three moduli
|
||||
// has been dropped already; the reason is that the expansion will happen at
|
||||
// level 0 (with all three moduli) and then one of the moduli will be dropped
|
||||
// to reduce the noise.
|
||||
let index = (thread_rng().next_u64() as usize) % database_size;
|
||||
let query = timeit!("Client query", {
|
||||
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
|
||||
let query_index = index
|
||||
/ number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
let mut pt = vec![0u64; dim1 + dim2];
|
||||
let inv = inverse(1 << level, plaintext_modulus).unwrap();
|
||||
pt[query_index / dim2] = inv;
|
||||
pt[dim1 + (query_index % dim2)] = inv;
|
||||
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?;
|
||||
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
|
||||
query.to_bytes()
|
||||
});
|
||||
println!("📄 Query: {}", HumanBytes(query.len() as u64));
|
||||
|
||||
// Server response: The server receives the query, and after deserializing it,
|
||||
// performs the following steps:
|
||||
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
|
||||
// If the client created the query correctly, the server will have obtained
|
||||
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
|
||||
// `dim1 + j`th ones encrypting `1`.
|
||||
// 2- It computes the inner product of the first `dim1` ciphertexts with the
|
||||
// columns if the database viewed as a dim1 * dim2 matrix, and modulo-switch
|
||||
// the ciphertext once.
|
||||
// 3- It parses the resulting ciphertexts as vector of plaintexts, and compute
|
||||
// the inner product of the last `dim2` ciphertexts from step 1 with the
|
||||
// transposed of the plaintext obtained above.
|
||||
// The operation is done `5` times to compute an average response time.
|
||||
let responses: Vec<Vec<u8>> = timeit_n!("Server response", 5, {
|
||||
let start = std::time::Instant::now();
|
||||
let query = bfv::Ciphertext::from_bytes(&query, ¶ms);
|
||||
let query = query.unwrap();
|
||||
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
|
||||
println!("Expand: {}", DisplayDuration(start.elapsed()));
|
||||
// Server response: The server receives the query, and after deserializing it,
|
||||
// performs the following steps:
|
||||
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
|
||||
// If the client created the query correctly, the server will have obtained
|
||||
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
|
||||
// `dim1 + j`th ones encrypting `1`.
|
||||
// 2- It computes the inner product of the first `dim1` ciphertexts with the
|
||||
// columns if the database viewed as a dim1 * dim2 matrix, and modulo-switch
|
||||
// the ciphertext once.
|
||||
// 3- It parses the resulting ciphertexts as vector of plaintexts, and compute
|
||||
// the inner product of the last `dim2` ciphertexts from step 1 with the
|
||||
// transposed of the plaintext obtained above.
|
||||
// The operation is done `5` times to compute an average response time.
|
||||
let responses: Vec<Vec<u8>> = timeit_n!("Server response", 5, {
|
||||
let start = std::time::Instant::now();
|
||||
let query = bfv::Ciphertext::from_bytes(&query, ¶ms);
|
||||
let query = query.unwrap();
|
||||
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
|
||||
println!("Expand: {}", DisplayDuration(start.elapsed()));
|
||||
|
||||
let query_vec = &expanded_query[..dim1];
|
||||
let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| {
|
||||
let column = database.iter().skip(i).step_by(dim2);
|
||||
let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?;
|
||||
c.mod_switch_to_last_level();
|
||||
Ok(c)
|
||||
};
|
||||
let query_vec = &expanded_query[..dim1];
|
||||
let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| {
|
||||
let column = database.iter().skip(i).step_by(dim2);
|
||||
let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?;
|
||||
c.mod_switch_to_last_level();
|
||||
Ok(c)
|
||||
};
|
||||
|
||||
let dot_products = (0..dim2)
|
||||
.map(|i| dot_product_mod_switch(i, &preprocessed_database))
|
||||
.collect::<fhe::Result<Vec<bfv::Ciphertext>>>()?;
|
||||
let dot_products = (0..dim2)
|
||||
.map(|i| dot_product_mod_switch(i, &preprocessed_database))
|
||||
.collect::<fhe::Result<Vec<bfv::Ciphertext>>>()?;
|
||||
|
||||
let fold = dot_products
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let mut pt_values = Vec::with_capacity(div_ceil(
|
||||
2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize)),
|
||||
ilog2(plaintext_modulus),
|
||||
));
|
||||
pt_values.append(&mut transcode_bidirectional(
|
||||
c.get(0).unwrap().coefficients().as_slice().unwrap(),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
ilog2(plaintext_modulus),
|
||||
));
|
||||
pt_values.append(&mut transcode_bidirectional(
|
||||
c.get(1).unwrap().coefficients().as_slice().unwrap(),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
ilog2(plaintext_modulus),
|
||||
));
|
||||
unsafe {
|
||||
Ok(bfv::PlaintextVec::try_encode_vt(
|
||||
&pt_values,
|
||||
bfv::Encoding::poly_at_level(1),
|
||||
¶ms,
|
||||
)?
|
||||
.0)
|
||||
}
|
||||
})
|
||||
.collect::<fhe::Result<Vec<Vec<bfv::Plaintext>>>>()?;
|
||||
(0..fold[0].len())
|
||||
.map(|i| {
|
||||
let mut outi = bfv::dot_product_scalar(
|
||||
expanded_query[dim1..].iter(),
|
||||
fold.iter().map(|pts| pts.get(i).unwrap()),
|
||||
)?;
|
||||
outi.mod_switch_to_last_level();
|
||||
Ok(outi.to_bytes())
|
||||
})
|
||||
.collect::<fhe::Result<Vec<Vec<u8>>>>()?
|
||||
});
|
||||
println!(
|
||||
"📄 Response: {}",
|
||||
HumanBytes(responses.iter().map(|r| r.len()).sum::<usize>() as u64)
|
||||
);
|
||||
let fold = dot_products
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let mut pt_values = Vec::with_capacity(div_ceil(
|
||||
2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize)),
|
||||
ilog2(plaintext_modulus),
|
||||
));
|
||||
pt_values.append(&mut transcode_bidirectional(
|
||||
c.get(0).unwrap().coefficients().as_slice().unwrap(),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
ilog2(plaintext_modulus),
|
||||
));
|
||||
pt_values.append(&mut transcode_bidirectional(
|
||||
c.get(1).unwrap().coefficients().as_slice().unwrap(),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
ilog2(plaintext_modulus),
|
||||
));
|
||||
unsafe {
|
||||
Ok(bfv::PlaintextVec::try_encode_vt(
|
||||
&pt_values,
|
||||
bfv::Encoding::poly_at_level(1),
|
||||
¶ms,
|
||||
)?
|
||||
.0)
|
||||
}
|
||||
})
|
||||
.collect::<fhe::Result<Vec<Vec<bfv::Plaintext>>>>()?;
|
||||
(0..fold[0].len())
|
||||
.map(|i| {
|
||||
let mut outi = bfv::dot_product_scalar(
|
||||
expanded_query[dim1..].iter(),
|
||||
fold.iter().map(|pts| pts.get(i).unwrap()),
|
||||
)?;
|
||||
outi.mod_switch_to_last_level();
|
||||
Ok(outi.to_bytes())
|
||||
})
|
||||
.collect::<fhe::Result<Vec<Vec<u8>>>>()?
|
||||
});
|
||||
println!(
|
||||
"📄 Response: {}",
|
||||
HumanBytes(responses.iter().map(|r| r.len()).sum::<usize>() as u64)
|
||||
);
|
||||
|
||||
// Client processing: Upon reception of the response, the client decrypts
|
||||
// the ciphertexts and recover the "ciphertexts" which were parsed as plaintext,
|
||||
// which it decrypts too. Finally, it outputs the plaintext bytes, offset by the
|
||||
// correct value (remember the database was reshaped to maximize how many
|
||||
// elements) were embedded in a single ciphertext.
|
||||
let answer = timeit!("Client answer", {
|
||||
let responses = responses
|
||||
.iter()
|
||||
.map(|r| bfv::Ciphertext::from_bytes(r, ¶ms).unwrap())
|
||||
.collect_vec();
|
||||
let decrypted_pt = responses
|
||||
.iter()
|
||||
.map(|r| sk.try_decrypt(r).unwrap())
|
||||
.collect_vec();
|
||||
let decrypted_vec = decrypted_pt
|
||||
.iter()
|
||||
.flat_map(|pt| Vec::<u64>::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap())
|
||||
.collect_vec();
|
||||
let expect_ncoefficients = div_ceil(
|
||||
params.degree() * (64 - params.moduli()[0].leading_zeros() as usize),
|
||||
ilog2(plaintext_modulus),
|
||||
);
|
||||
assert!(decrypted_vec.len() >= 2 * expect_ncoefficients);
|
||||
let mut poly0 = transcode_bidirectional(
|
||||
&decrypted_vec[..expect_ncoefficients],
|
||||
ilog2(plaintext_modulus),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
);
|
||||
let mut poly1 = transcode_bidirectional(
|
||||
&decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients],
|
||||
ilog2(plaintext_modulus),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
);
|
||||
assert!(poly0.len() >= params.degree());
|
||||
assert!(poly1.len() >= params.degree());
|
||||
poly0.truncate(params.degree());
|
||||
poly1.truncate(params.degree());
|
||||
// Client processing: Upon reception of the response, the client decrypts
|
||||
// the ciphertexts and recover the "ciphertexts" which were parsed as plaintext,
|
||||
// which it decrypts too. Finally, it outputs the plaintext bytes, offset by the
|
||||
// correct value (remember the database was reshaped to maximize how many
|
||||
// elements) were embedded in a single ciphertext.
|
||||
let answer = timeit!("Client answer", {
|
||||
let responses = responses
|
||||
.iter()
|
||||
.map(|r| bfv::Ciphertext::from_bytes(r, ¶ms).unwrap())
|
||||
.collect_vec();
|
||||
let decrypted_pt = responses
|
||||
.iter()
|
||||
.map(|r| sk.try_decrypt(r).unwrap())
|
||||
.collect_vec();
|
||||
let decrypted_vec = decrypted_pt
|
||||
.iter()
|
||||
.flat_map(|pt| Vec::<u64>::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap())
|
||||
.collect_vec();
|
||||
let expect_ncoefficients = div_ceil(
|
||||
params.degree() * (64 - params.moduli()[0].leading_zeros() as usize),
|
||||
ilog2(plaintext_modulus),
|
||||
);
|
||||
assert!(decrypted_vec.len() >= 2 * expect_ncoefficients);
|
||||
let mut poly0 = transcode_bidirectional(
|
||||
&decrypted_vec[..expect_ncoefficients],
|
||||
ilog2(plaintext_modulus),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
);
|
||||
let mut poly1 = transcode_bidirectional(
|
||||
&decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients],
|
||||
ilog2(plaintext_modulus),
|
||||
64 - params.moduli()[0].leading_zeros() as usize,
|
||||
);
|
||||
assert!(poly0.len() >= params.degree());
|
||||
assert!(poly1.len() >= params.degree());
|
||||
poly0.truncate(params.degree());
|
||||
poly1.truncate(params.degree());
|
||||
|
||||
let ctx = Arc::new(Context::new(¶ms.moduli()[..1], params.degree())?);
|
||||
let ct = bfv::Ciphertext::new(
|
||||
vec![
|
||||
Poly::try_convert_from(poly0, &ctx, true, Representation::Ntt)?,
|
||||
Poly::try_convert_from(poly1, &ctx, true, Representation::Ntt)?,
|
||||
],
|
||||
¶ms,
|
||||
)?;
|
||||
let ctx = Arc::new(Context::new(¶ms.moduli()[..1], params.degree())?);
|
||||
let ct = bfv::Ciphertext::new(
|
||||
vec![
|
||||
Poly::try_convert_from(poly0, &ctx, true, Representation::Ntt)?,
|
||||
Poly::try_convert_from(poly1, &ctx, true, Representation::Ntt)?,
|
||||
],
|
||||
¶ms,
|
||||
)?;
|
||||
|
||||
let pt = sk.try_decrypt(&ct).unwrap();
|
||||
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
|
||||
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
|
||||
let offset = index
|
||||
% number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
let pt = sk.try_decrypt(&ct).unwrap();
|
||||
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
|
||||
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
|
||||
let offset = index
|
||||
% number_elements_per_plaintext(
|
||||
params.degree(),
|
||||
ilog2(plaintext_modulus),
|
||||
elements_size,
|
||||
);
|
||||
|
||||
println!("Noise in response (ct): {:?}", unsafe {
|
||||
sk.measure_noise(&ct)
|
||||
});
|
||||
println!("Noise in response (ct): {:?}", unsafe {
|
||||
sk.measure_noise(&ct)
|
||||
});
|
||||
|
||||
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
|
||||
});
|
||||
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
|
||||
});
|
||||
|
||||
// Assert that the answer is indeed the `index`-th element of the initial
|
||||
// database.
|
||||
assert_eq!(&database[index], &answer);
|
||||
// Assert that the answer is indeed the `index`-th element of the initial
|
||||
// database.
|
||||
assert_eq!(&database[index], &answer);
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -7,37 +7,37 @@ use std::{cmp::min, fmt, sync::Arc, time::Duration};
|
||||
|
||||
/// Macros to time code and display a human-readable duration.
|
||||
pub mod timeit {
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! timeit_n {
|
||||
($name:expr, $loops:expr, $code:expr) => {{
|
||||
use util::DisplayDuration;
|
||||
let start = std::time::Instant::now();
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! timeit_n {
|
||||
($name:expr, $loops:expr, $code:expr) => {{
|
||||
use util::DisplayDuration;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
for _ in 1..$loops {
|
||||
let _ = $code;
|
||||
}
|
||||
let r = $code;
|
||||
println!(
|
||||
"⏱ {}: {}",
|
||||
$name,
|
||||
DisplayDuration(start.elapsed() / $loops)
|
||||
);
|
||||
r
|
||||
}};
|
||||
}
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
for _ in 1..$loops {
|
||||
let _ = $code;
|
||||
}
|
||||
let r = $code;
|
||||
println!(
|
||||
"⏱ {}: {}",
|
||||
$name,
|
||||
DisplayDuration(start.elapsed() / $loops)
|
||||
);
|
||||
r
|
||||
}};
|
||||
}
|
||||
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! timeit {
|
||||
($name:expr, $code:expr) => {{
|
||||
timeit_n!($name, 1, $code)
|
||||
}};
|
||||
}
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! timeit {
|
||||
($name:expr, $code:expr) => {{
|
||||
timeit_n!($name, 1, $code)
|
||||
}};
|
||||
}
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub(crate) use timeit;
|
||||
#[allow(unused_imports)]
|
||||
pub(crate) use timeit_n;
|
||||
#[allow(unused_imports)]
|
||||
pub(crate) use timeit;
|
||||
#[allow(unused_imports)]
|
||||
pub(crate) use timeit_n;
|
||||
}
|
||||
|
||||
/// Utility struct for displaying human-readable duration of the form "10.5 ms",
|
||||
@@ -45,17 +45,17 @@ pub mod timeit {
|
||||
pub struct DisplayDuration(pub Duration);
|
||||
|
||||
impl fmt::Display for DisplayDuration {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let duration_ns = self.0.as_nanos();
|
||||
if duration_ns < 1_000_u128 {
|
||||
write!(f, "{duration_ns} ns")
|
||||
} else if duration_ns < 1_000_000_u128 {
|
||||
write!(f, "{} μs", (duration_ns + 500) / 1_000)
|
||||
} else {
|
||||
let duration_ms_times_10 = (duration_ns + 50_000) / (100_000);
|
||||
write!(f, "{} ms", (duration_ms_times_10 as f64) / 10.0)
|
||||
}
|
||||
}
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let duration_ns = self.0.as_nanos();
|
||||
if duration_ns < 1_000_u128 {
|
||||
write!(f, "{duration_ns} ns")
|
||||
} else if duration_ns < 1_000_000_u128 {
|
||||
write!(f, "{} μs", (duration_ns + 500) / 1_000)
|
||||
} else {
|
||||
let duration_ms_times_10 = (duration_ns + 50_000) / (100_000);
|
||||
write!(f, "{} ms", (duration_ms_times_10 as f64) / 10.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions for Private Information Retrieval.
|
||||
@@ -64,60 +64,60 @@ impl fmt::Display for DisplayDuration {
|
||||
/// little endian encoding of the index. When the element size is less than 4B,
|
||||
/// the encoding is truncated.
|
||||
pub fn generate_database(database_size: usize, elements_size: usize) -> Vec<Vec<u8>> {
|
||||
assert!(elements_size > 0 && database_size > 0);
|
||||
let mut database = vec![vec![0u8; elements_size]; database_size];
|
||||
for (i, element) in database.iter_mut().enumerate() {
|
||||
element[..min(4, elements_size)]
|
||||
.copy_from_slice(&(i as u32).to_le_bytes()[..min(4, elements_size)]);
|
||||
}
|
||||
database
|
||||
assert!(elements_size > 0 && database_size > 0);
|
||||
let mut database = vec![vec![0u8; elements_size]; database_size];
|
||||
for (i, element) in database.iter_mut().enumerate() {
|
||||
element[..min(4, elements_size)]
|
||||
.copy_from_slice(&(i as u32).to_le_bytes()[..min(4, elements_size)]);
|
||||
}
|
||||
database
|
||||
}
|
||||
|
||||
pub fn number_elements_per_plaintext(
|
||||
degree: usize,
|
||||
plaintext_nbits: usize,
|
||||
elements_size: usize,
|
||||
degree: usize,
|
||||
plaintext_nbits: usize,
|
||||
elements_size: usize,
|
||||
) -> usize {
|
||||
(plaintext_nbits * degree) / (elements_size * 8)
|
||||
(plaintext_nbits * degree) / (elements_size * 8)
|
||||
}
|
||||
|
||||
pub fn encode_database(
|
||||
database: &Vec<Vec<u8>>,
|
||||
par: Arc<bfv::BfvParameters>,
|
||||
level: usize,
|
||||
database: &Vec<Vec<u8>>,
|
||||
par: Arc<bfv::BfvParameters>,
|
||||
level: usize,
|
||||
) -> (Vec<bfv::Plaintext>, (usize, usize)) {
|
||||
assert!(!database.is_empty());
|
||||
assert!(!database.is_empty());
|
||||
|
||||
let elements_size = database[0].len();
|
||||
let plaintext_nbits = ilog2(par.plaintext());
|
||||
let number_elements_per_plaintext =
|
||||
number_elements_per_plaintext(par.degree(), plaintext_nbits, elements_size);
|
||||
let number_rows =
|
||||
(database.len() + number_elements_per_plaintext - 1) / number_elements_per_plaintext;
|
||||
println!("number_rows = {number_rows}");
|
||||
println!("number_elements_per_plaintext = {number_elements_per_plaintext}");
|
||||
let dimension_1 = (number_rows as f64).sqrt().ceil() as usize;
|
||||
let dimension_2 = (number_rows + dimension_1 - 1) / dimension_1;
|
||||
println!("dimensions = {dimension_1} {dimension_2}");
|
||||
println!("dimension = {}", dimension_1 * dimension_2);
|
||||
let mut preprocessed_database =
|
||||
vec![
|
||||
bfv::Plaintext::zero(bfv::Encoding::poly_at_level(level), &par).unwrap();
|
||||
dimension_1 * dimension_2
|
||||
];
|
||||
(0..number_rows).for_each(|i| {
|
||||
let mut serialized_plaintext = vec![0u8; number_elements_per_plaintext * elements_size];
|
||||
for j in 0..number_elements_per_plaintext {
|
||||
if let Some(pt) = database.get(j + i * number_elements_per_plaintext) {
|
||||
serialized_plaintext[j * elements_size..(j + 1) * elements_size].copy_from_slice(pt)
|
||||
}
|
||||
}
|
||||
let pt_values = transcode_from_bytes(&serialized_plaintext, plaintext_nbits);
|
||||
preprocessed_database[i] =
|
||||
bfv::Plaintext::try_encode(&pt_values, bfv::Encoding::poly_at_level(level), &par)
|
||||
.unwrap();
|
||||
});
|
||||
(preprocessed_database, (dimension_1, dimension_2))
|
||||
let elements_size = database[0].len();
|
||||
let plaintext_nbits = ilog2(par.plaintext());
|
||||
let number_elements_per_plaintext =
|
||||
number_elements_per_plaintext(par.degree(), plaintext_nbits, elements_size);
|
||||
let number_rows =
|
||||
(database.len() + number_elements_per_plaintext - 1) / number_elements_per_plaintext;
|
||||
println!("number_rows = {number_rows}");
|
||||
println!("number_elements_per_plaintext = {number_elements_per_plaintext}");
|
||||
let dimension_1 = (number_rows as f64).sqrt().ceil() as usize;
|
||||
let dimension_2 = (number_rows + dimension_1 - 1) / dimension_1;
|
||||
println!("dimensions = {dimension_1} {dimension_2}");
|
||||
println!("dimension = {}", dimension_1 * dimension_2);
|
||||
let mut preprocessed_database =
|
||||
vec![
|
||||
bfv::Plaintext::zero(bfv::Encoding::poly_at_level(level), &par).unwrap();
|
||||
dimension_1 * dimension_2
|
||||
];
|
||||
(0..number_rows).for_each(|i| {
|
||||
let mut serialized_plaintext = vec![0u8; number_elements_per_plaintext * elements_size];
|
||||
for j in 0..number_elements_per_plaintext {
|
||||
if let Some(pt) = database.get(j + i * number_elements_per_plaintext) {
|
||||
serialized_plaintext[j * elements_size..(j + 1) * elements_size].copy_from_slice(pt)
|
||||
}
|
||||
}
|
||||
let pt_values = transcode_from_bytes(&serialized_plaintext, plaintext_nbits);
|
||||
preprocessed_database[i] =
|
||||
bfv::Plaintext::try_encode(&pt_values, bfv::Encoding::poly_at_level(level), &par)
|
||||
.unwrap();
|
||||
});
|
||||
(preprocessed_database, (dimension_1, dimension_2))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
//! Ciphertext type in the BFV encryption scheme.
|
||||
|
||||
use crate::bfv::{
|
||||
parameters::BfvParameters, proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom,
|
||||
parameters::BfvParameters, proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use fhe_math::rq::{Poly, Representation};
|
||||
use fhe_traits::{
|
||||
DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize,
|
||||
DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize,
|
||||
};
|
||||
use protobuf::Message;
|
||||
use rand::SeedableRng;
|
||||
@@ -16,287 +16,287 @@ use std::sync::Arc;
|
||||
/// A ciphertext encrypting a plaintext.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Ciphertext {
|
||||
/// The parameters of the underlying BFV encryption scheme.
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
/// The parameters of the underlying BFV encryption scheme.
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
|
||||
/// The seed that generated the polynomial c1 in a fresh ciphertext.
|
||||
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
|
||||
/// The seed that generated the polynomial c1 in a fresh ciphertext.
|
||||
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
|
||||
|
||||
/// The ciphertext elements.
|
||||
pub(crate) c: Vec<Poly>,
|
||||
/// The ciphertext elements.
|
||||
pub(crate) c: Vec<Poly>,
|
||||
|
||||
/// The ciphertext level
|
||||
pub(crate) level: usize,
|
||||
/// The ciphertext level
|
||||
pub(crate) level: usize,
|
||||
}
|
||||
|
||||
impl Ciphertext {
|
||||
/// Modulo switch the ciphertext to the last level.
|
||||
pub fn mod_switch_to_last_level(&mut self) {
|
||||
self.level = self.par.max_level();
|
||||
let last_ctx = self.par.ctx_at_level(self.level).unwrap();
|
||||
self.seed = None;
|
||||
self.c.iter_mut().for_each(|ci| {
|
||||
if ci.ctx() != last_ctx {
|
||||
ci.change_representation(Representation::PowerBasis);
|
||||
assert!(ci.mod_switch_down_to(last_ctx).is_ok());
|
||||
ci.change_representation(Representation::Ntt);
|
||||
}
|
||||
});
|
||||
}
|
||||
/// Modulo switch the ciphertext to the last level.
|
||||
pub fn mod_switch_to_last_level(&mut self) {
|
||||
self.level = self.par.max_level();
|
||||
let last_ctx = self.par.ctx_at_level(self.level).unwrap();
|
||||
self.seed = None;
|
||||
self.c.iter_mut().for_each(|ci| {
|
||||
if ci.ctx() != last_ctx {
|
||||
ci.change_representation(Representation::PowerBasis);
|
||||
assert!(ci.mod_switch_down_to(last_ctx).is_ok());
|
||||
ci.change_representation(Representation::Ntt);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Modulo switch the ciphertext to the next level.
|
||||
pub fn mod_switch_to_next_level(&mut self) {
|
||||
if self.level < self.par.max_level() {
|
||||
self.seed = None;
|
||||
self.c.iter_mut().for_each(|ci| {
|
||||
ci.change_representation(Representation::PowerBasis);
|
||||
assert!(ci.mod_switch_down_next().is_ok());
|
||||
ci.change_representation(Representation::Ntt);
|
||||
});
|
||||
self.level += 1
|
||||
}
|
||||
}
|
||||
/// Modulo switch the ciphertext to the next level.
|
||||
pub fn mod_switch_to_next_level(&mut self) {
|
||||
if self.level < self.par.max_level() {
|
||||
self.seed = None;
|
||||
self.c.iter_mut().for_each(|ci| {
|
||||
ci.change_representation(Representation::PowerBasis);
|
||||
assert!(ci.mod_switch_down_next().is_ok());
|
||||
ci.change_representation(Representation::Ntt);
|
||||
});
|
||||
self.level += 1
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a ciphertext from a vector of polynomials.
|
||||
/// A ciphertext must contain at least two polynomials, and all polynomials
|
||||
/// must be in Ntt representation and with the same context.
|
||||
pub fn new(c: Vec<Poly>, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if c.len() < 2 {
|
||||
return Err(Error::TooFewValues(c.len(), 2));
|
||||
}
|
||||
/// Create a ciphertext from a vector of polynomials.
|
||||
/// A ciphertext must contain at least two polynomials, and all polynomials
|
||||
/// must be in Ntt representation and with the same context.
|
||||
pub fn new(c: Vec<Poly>, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if c.len() < 2 {
|
||||
return Err(Error::TooFewValues(c.len(), 2));
|
||||
}
|
||||
|
||||
let ctx = c[0].ctx();
|
||||
let level = par.level_of_ctx(ctx)?;
|
||||
let ctx = c[0].ctx();
|
||||
let level = par.level_of_ctx(ctx)?;
|
||||
|
||||
// Check that all polynomials have the expected representation and context.
|
||||
for ci in c.iter() {
|
||||
if ci.representation() != &Representation::Ntt {
|
||||
return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation(
|
||||
ci.representation().clone(),
|
||||
Representation::Ntt,
|
||||
)));
|
||||
}
|
||||
if ci.ctx() != ctx {
|
||||
return Err(Error::MathError(fhe_math::Error::InvalidContext));
|
||||
}
|
||||
}
|
||||
// Check that all polynomials have the expected representation and context.
|
||||
for ci in c.iter() {
|
||||
if ci.representation() != &Representation::Ntt {
|
||||
return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation(
|
||||
ci.representation().clone(),
|
||||
Representation::Ntt,
|
||||
)));
|
||||
}
|
||||
if ci.ctx() != ctx {
|
||||
return Err(Error::MathError(fhe_math::Error::InvalidContext));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level,
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the i-th polynomial of the ciphertext.
|
||||
pub fn get(&self, i: usize) -> Option<&Poly> {
|
||||
self.c.get(i)
|
||||
}
|
||||
/// Get the i-th polynomial of the ciphertext.
|
||||
pub fn get(&self, i: usize) -> Option<&Poly> {
|
||||
self.c.get(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl FheCiphertext for Ciphertext {}
|
||||
|
||||
impl FheParametrized for Ciphertext {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl Serialize for Ciphertext {
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
CiphertextProto::from(self).write_to_bytes().unwrap()
|
||||
}
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
CiphertextProto::from(self).write_to_bytes().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl DeserializeParametrized for Ciphertext {
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if let Ok(ctp) = CiphertextProto::parse_from_bytes(bytes) {
|
||||
Ciphertext::try_convert_from(&ctp, par)
|
||||
} else {
|
||||
Err(Error::SerializationError)
|
||||
}
|
||||
}
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if let Ok(ctp) = CiphertextProto::parse_from_bytes(bytes) {
|
||||
Ciphertext::try_convert_from(&ctp, par)
|
||||
} else {
|
||||
Err(Error::SerializationError)
|
||||
}
|
||||
}
|
||||
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
}
|
||||
|
||||
impl Ciphertext {
|
||||
/// Generate the zero ciphertext.
|
||||
pub fn zero(par: &Arc<BfvParameters>) -> Self {
|
||||
Self {
|
||||
par: par.clone(),
|
||||
seed: None,
|
||||
c: Default::default(),
|
||||
level: 0,
|
||||
}
|
||||
}
|
||||
/// Generate the zero ciphertext.
|
||||
pub fn zero(par: &Arc<BfvParameters>) -> Self {
|
||||
Self {
|
||||
par: par.clone(),
|
||||
seed: None,
|
||||
c: Default::default(),
|
||||
level: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversions from and to protobuf.
|
||||
impl From<&Ciphertext> for CiphertextProto {
|
||||
fn from(ct: &Ciphertext) -> Self {
|
||||
let mut proto = CiphertextProto::new();
|
||||
for i in 0..ct.c.len() - 1 {
|
||||
proto.c.push(ct.c[i].to_bytes())
|
||||
}
|
||||
if let Some(seed) = ct.seed {
|
||||
proto.seed = seed.to_vec()
|
||||
} else {
|
||||
proto.c.push(ct.c[ct.c.len() - 1].to_bytes())
|
||||
}
|
||||
proto.level = ct.level as u32;
|
||||
proto
|
||||
}
|
||||
fn from(ct: &Ciphertext) -> Self {
|
||||
let mut proto = CiphertextProto::new();
|
||||
for i in 0..ct.c.len() - 1 {
|
||||
proto.c.push(ct.c[i].to_bytes())
|
||||
}
|
||||
if let Some(seed) = ct.seed {
|
||||
proto.seed = seed.to_vec()
|
||||
} else {
|
||||
proto.c.push(ct.c[ct.c.len() - 1].to_bytes())
|
||||
}
|
||||
proto.level = ct.level as u32;
|
||||
proto
|
||||
}
|
||||
}
|
||||
|
||||
impl TryConvertFrom<&CiphertextProto> for Ciphertext {
|
||||
fn try_convert_from(value: &CiphertextProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if value.c.is_empty() || (value.c.len() == 1 && value.seed.is_empty()) {
|
||||
return Err(Error::DefaultError("Not enough polynomials".to_string()));
|
||||
}
|
||||
fn try_convert_from(value: &CiphertextProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if value.c.is_empty() || (value.c.len() == 1 && value.seed.is_empty()) {
|
||||
return Err(Error::DefaultError("Not enough polynomials".to_string()));
|
||||
}
|
||||
|
||||
if value.level as usize > par.max_level() {
|
||||
return Err(Error::DefaultError("Invalid level".to_string()));
|
||||
}
|
||||
if value.level as usize > par.max_level() {
|
||||
return Err(Error::DefaultError("Invalid level".to_string()));
|
||||
}
|
||||
|
||||
let ctx = par.ctx_at_level(value.level as usize)?;
|
||||
let ctx = par.ctx_at_level(value.level as usize)?;
|
||||
|
||||
let mut seed = None;
|
||||
let mut seed = None;
|
||||
|
||||
let mut c = Vec::with_capacity(value.c.len() + 1);
|
||||
for cip in &value.c {
|
||||
c.push(Poly::from_bytes(cip, ctx)?)
|
||||
}
|
||||
let mut c = Vec::with_capacity(value.c.len() + 1);
|
||||
for cip in &value.c {
|
||||
c.push(Poly::from_bytes(cip, ctx)?)
|
||||
}
|
||||
|
||||
if !value.seed.is_empty() {
|
||||
let try_seed = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
|
||||
if try_seed.is_err() {
|
||||
return Err(Error::MathError(fhe_math::Error::InvalidSeedSize(
|
||||
value.seed.len(),
|
||||
<ChaCha8Rng as SeedableRng>::Seed::default().len(),
|
||||
)));
|
||||
}
|
||||
seed = try_seed.ok();
|
||||
let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, seed.unwrap());
|
||||
unsafe { c1.allow_variable_time_computations() }
|
||||
c.push(c1)
|
||||
}
|
||||
if !value.seed.is_empty() {
|
||||
let try_seed = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
|
||||
if try_seed.is_err() {
|
||||
return Err(Error::MathError(fhe_math::Error::InvalidSeedSize(
|
||||
value.seed.len(),
|
||||
<ChaCha8Rng as SeedableRng>::Seed::default().len(),
|
||||
)));
|
||||
}
|
||||
seed = try_seed.ok();
|
||||
let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, seed.unwrap());
|
||||
unsafe { c1.allow_variable_time_computations() }
|
||||
c.push(c1)
|
||||
}
|
||||
|
||||
Ok(Ciphertext {
|
||||
par: par.clone(),
|
||||
seed,
|
||||
c,
|
||||
level: value.level as usize,
|
||||
})
|
||||
}
|
||||
Ok(Ciphertext {
|
||||
par: par.clone(),
|
||||
seed,
|
||||
c,
|
||||
level: value.level as usize,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bfv::{
|
||||
proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, BfvParameters,
|
||||
Ciphertext, Encoding, Plaintext, SecretKey,
|
||||
};
|
||||
use fhe_traits::FheDecrypter;
|
||||
use fhe_traits::{DeserializeParametrized, FheEncoder, FheEncrypter, Serialize};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use crate::bfv::{
|
||||
proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, BfvParameters,
|
||||
Ciphertext, Encoding, Plaintext, SecretKey,
|
||||
};
|
||||
use fhe_traits::FheDecrypter;
|
||||
use fhe_traits::{DeserializeParametrized, FheEncoder, FheEncrypter, Serialize};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct_proto = CiphertextProto::from(&ct);
|
||||
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?);
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct_proto = CiphertextProto::from(&ct);
|
||||
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?);
|
||||
|
||||
let ct = &ct * &ct;
|
||||
let ct_proto = CiphertextProto::from(&ct);
|
||||
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
let ct = &ct * &ct;
|
||||
let ct_proto = CiphertextProto::from(&ct);
|
||||
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, ¶ms)?)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct_bytes = ct.to_bytes();
|
||||
assert_eq!(ct, Ciphertext::from_bytes(&ct_bytes, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct_bytes = ct.to_bytes();
|
||||
assert_eq!(ct, Ciphertext::from_bytes(&ct_bytes, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let mut ct3 = &ct * &ct;
|
||||
#[test]
|
||||
fn new() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let mut ct3 = &ct * &ct;
|
||||
|
||||
let c0 = ct3.get(0).unwrap();
|
||||
let c1 = ct3.get(1).unwrap();
|
||||
let c2 = ct3.get(2).unwrap();
|
||||
let c0 = ct3.get(0).unwrap();
|
||||
let c1 = ct3.get(1).unwrap();
|
||||
let c2 = ct3.get(2).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
ct3,
|
||||
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?
|
||||
);
|
||||
assert_eq!(ct3.level, 0);
|
||||
assert_eq!(
|
||||
ct3,
|
||||
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?
|
||||
);
|
||||
assert_eq!(ct3.level, 0);
|
||||
|
||||
ct3.mod_switch_to_last_level();
|
||||
ct3.mod_switch_to_last_level();
|
||||
|
||||
let c0 = ct3.get(0).unwrap();
|
||||
let c1 = ct3.get(1).unwrap();
|
||||
let c2 = ct3.get(2).unwrap();
|
||||
assert_eq!(
|
||||
ct3,
|
||||
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?
|
||||
);
|
||||
assert_eq!(ct3.level, params.max_level());
|
||||
}
|
||||
let c0 = ct3.get(0).unwrap();
|
||||
let c1 = ct3.get(1).unwrap();
|
||||
let c2 = ct3.get(2).unwrap();
|
||||
assert_eq!(
|
||||
ct3,
|
||||
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?
|
||||
);
|
||||
assert_eq!(ct3.level, params.max_level());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mod_switch_to_last_level() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
#[test]
|
||||
fn mod_switch_to_last_level() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
|
||||
assert_eq!(ct.level, 0);
|
||||
ct.mod_switch_to_last_level();
|
||||
assert_eq!(ct.level, params.max_level());
|
||||
assert_eq!(ct.level, 0);
|
||||
ct.mod_switch_to_last_level();
|
||||
assert_eq!(ct.level, params.max_level());
|
||||
|
||||
let decrypted = sk.try_decrypt(&ct)?;
|
||||
assert_eq!(decrypted.value, pt.value);
|
||||
}
|
||||
let decrypted = sk.try_decrypt(&ct)?;
|
||||
assert_eq!(decrypted.value, pt.value);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,71 +6,71 @@ use fhe_traits::FhePlaintextEncoding;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub(crate) enum EncodingEnum {
|
||||
Poly,
|
||||
Simd,
|
||||
Poly,
|
||||
Simd,
|
||||
}
|
||||
|
||||
impl Display for EncodingEnum {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
}
|
||||
|
||||
/// An encoding for the plaintext.
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct Encoding {
|
||||
pub(crate) encoding: EncodingEnum,
|
||||
pub(crate) level: usize,
|
||||
pub(crate) encoding: EncodingEnum,
|
||||
pub(crate) level: usize,
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
/// A Poly encoding encodes a vector as coefficients of a polynomial;
|
||||
/// homomorphic operations are therefore polynomial operations.
|
||||
pub fn poly() -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Poly,
|
||||
level: 0,
|
||||
}
|
||||
}
|
||||
/// A Poly encoding encodes a vector as coefficients of a polynomial;
|
||||
/// homomorphic operations are therefore polynomial operations.
|
||||
pub fn poly() -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Poly,
|
||||
level: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// A Simd encoding encodes a vector so that homomorphic operations are
|
||||
/// component-wise operations on the coefficients of the underlying vectors.
|
||||
/// The Simd encoding require that the plaintext modulus is congruent to 1
|
||||
/// modulo the degree of the underlying polynomial.
|
||||
pub fn simd() -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Simd,
|
||||
level: 0,
|
||||
}
|
||||
}
|
||||
/// A Simd encoding encodes a vector so that homomorphic operations are
|
||||
/// component-wise operations on the coefficients of the underlying vectors.
|
||||
/// The Simd encoding require that the plaintext modulus is congruent to 1
|
||||
/// modulo the degree of the underlying polynomial.
|
||||
pub fn simd() -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Simd,
|
||||
level: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// A poly encoding at a given level.
|
||||
pub fn poly_at_level(level: usize) -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Poly,
|
||||
level,
|
||||
}
|
||||
}
|
||||
/// A poly encoding at a given level.
|
||||
pub fn poly_at_level(level: usize) -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Poly,
|
||||
level,
|
||||
}
|
||||
}
|
||||
|
||||
/// A simd encoding at a given level.
|
||||
pub fn simd_at_level(level: usize) -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Simd,
|
||||
level,
|
||||
}
|
||||
}
|
||||
/// A simd encoding at a given level.
|
||||
pub fn simd_at_level(level: usize) -> Self {
|
||||
Self {
|
||||
encoding: EncodingEnum::Simd,
|
||||
level,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Encoding> for String {
|
||||
fn from(e: Encoding) -> Self {
|
||||
String::from(&e)
|
||||
}
|
||||
fn from(e: Encoding) -> Self {
|
||||
String::from(&e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Encoding> for String {
|
||||
fn from(e: &Encoding) -> Self {
|
||||
format!("{e:?}")
|
||||
}
|
||||
fn from(e: &Encoding) -> Self {
|
||||
format!("{e:?}")
|
||||
}
|
||||
}
|
||||
|
||||
impl FhePlaintextEncoding for Encoding {}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,14 +2,14 @@
|
||||
|
||||
use super::key_switching_key::KeySwitchingKey;
|
||||
use crate::bfv::{
|
||||
proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto},
|
||||
traits::TryConvertFrom,
|
||||
BfvParameters, Ciphertext, SecretKey,
|
||||
proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto},
|
||||
traits::TryConvertFrom,
|
||||
BfvParameters, Ciphertext, SecretKey,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use fhe_math::rq::{
|
||||
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
|
||||
SubstitutionExponent,
|
||||
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
|
||||
SubstitutionExponent,
|
||||
};
|
||||
use protobuf::MessageField;
|
||||
use rand::{CryptoRng, RngCore};
|
||||
@@ -21,181 +21,181 @@ use zeroize::Zeroizing;
|
||||
/// which switch from `s(x^i)` to `s(x)` where `s(x)` is the secret key.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct GaloisKey {
|
||||
pub(crate) element: SubstitutionExponent,
|
||||
pub(crate) ksk: KeySwitchingKey,
|
||||
pub(crate) element: SubstitutionExponent,
|
||||
pub(crate) ksk: KeySwitchingKey,
|
||||
}
|
||||
|
||||
impl GaloisKey {
|
||||
/// Generate a [`GaloisKey`] from a [`SecretKey`].
|
||||
pub fn new<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
exponent: usize,
|
||||
ciphertext_level: usize,
|
||||
galois_key_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
let ctx_galois_key = sk.par.ctx_at_level(galois_key_level)?;
|
||||
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
|
||||
/// Generate a [`GaloisKey`] from a [`SecretKey`].
|
||||
pub fn new<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
exponent: usize,
|
||||
ciphertext_level: usize,
|
||||
galois_key_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
let ctx_galois_key = sk.par.ctx_at_level(galois_key_level)?;
|
||||
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
|
||||
|
||||
let ciphertext_exponent =
|
||||
SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?;
|
||||
let ciphertext_exponent =
|
||||
SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?;
|
||||
|
||||
let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?;
|
||||
let s = Zeroizing::new(Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx_ciphertext,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?);
|
||||
let mut s_sub_switched_up = Zeroizing::new(s_sub.mod_switch_to(&switcher_up)?);
|
||||
s_sub_switched_up.change_representation(Representation::PowerBasis);
|
||||
let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?;
|
||||
let s = Zeroizing::new(Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx_ciphertext,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?);
|
||||
let mut s_sub_switched_up = Zeroizing::new(s_sub.mod_switch_to(&switcher_up)?);
|
||||
s_sub_switched_up.change_representation(Representation::PowerBasis);
|
||||
|
||||
let ksk = KeySwitchingKey::new(
|
||||
sk,
|
||||
&s_sub_switched_up,
|
||||
ciphertext_level,
|
||||
galois_key_level,
|
||||
rng,
|
||||
)?;
|
||||
let ksk = KeySwitchingKey::new(
|
||||
sk,
|
||||
&s_sub_switched_up,
|
||||
ciphertext_level,
|
||||
galois_key_level,
|
||||
rng,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
element: ciphertext_exponent,
|
||||
ksk,
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
element: ciphertext_exponent,
|
||||
ksk,
|
||||
})
|
||||
}
|
||||
|
||||
/// Relinearize a [`Ciphertext`] using the [`GaloisKey`]
|
||||
pub fn relinearize(&self, ct: &Ciphertext) -> Result<Ciphertext> {
|
||||
// assert_eq!(ct.par, self.ksk.par);
|
||||
assert_eq!(ct.c.len(), 2);
|
||||
/// Relinearize a [`Ciphertext`] using the [`GaloisKey`]
|
||||
pub fn relinearize(&self, ct: &Ciphertext) -> Result<Ciphertext> {
|
||||
// assert_eq!(ct.par, self.ksk.par);
|
||||
assert_eq!(ct.c.len(), 2);
|
||||
|
||||
let mut c2 = ct.c[1].substitute(&self.element)?;
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let (mut c0, mut c1) = self.ksk.key_switch(&c2)?;
|
||||
let mut c2 = ct.c[1].substitute(&self.element)?;
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let (mut c0, mut c1) = self.ksk.key_switch(&c2)?;
|
||||
|
||||
if c0.ctx() != ct.c[0].ctx() {
|
||||
c0.change_representation(Representation::PowerBasis);
|
||||
c1.change_representation(Representation::PowerBasis);
|
||||
c0.mod_switch_down_to(ct.c[0].ctx())?;
|
||||
c1.mod_switch_down_to(ct.c[1].ctx())?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c1.change_representation(Representation::Ntt);
|
||||
}
|
||||
if c0.ctx() != ct.c[0].ctx() {
|
||||
c0.change_representation(Representation::PowerBasis);
|
||||
c1.change_representation(Representation::PowerBasis);
|
||||
c0.mod_switch_down_to(ct.c[0].ctx())?;
|
||||
c1.mod_switch_down_to(ct.c[1].ctx())?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c1.change_representation(Representation::Ntt);
|
||||
}
|
||||
|
||||
c0 += &ct.c[0].substitute(&self.element)?;
|
||||
c0 += &ct.c[0].substitute(&self.element)?;
|
||||
|
||||
Ok(Ciphertext {
|
||||
par: ct.par.clone(),
|
||||
seed: None,
|
||||
c: vec![c0, c1],
|
||||
level: self.ksk.ciphertext_level,
|
||||
})
|
||||
}
|
||||
Ok(Ciphertext {
|
||||
par: ct.par.clone(),
|
||||
seed: None,
|
||||
c: vec![c0, c1],
|
||||
level: self.ksk.ciphertext_level,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&GaloisKey> for GaloisKeyProto {
|
||||
fn from(value: &GaloisKey) -> Self {
|
||||
let mut gk = GaloisKeyProto::new();
|
||||
gk.exponent = value.element.exponent as u32;
|
||||
gk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
|
||||
gk
|
||||
}
|
||||
fn from(value: &GaloisKey) -> Self {
|
||||
let mut gk = GaloisKeyProto::new();
|
||||
gk.exponent = value.element.exponent as u32;
|
||||
gk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
|
||||
gk
|
||||
}
|
||||
}
|
||||
|
||||
impl TryConvertFrom<&GaloisKeyProto> for GaloisKey {
|
||||
fn try_convert_from(value: &GaloisKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if par.moduli.len() == 1 {
|
||||
Err(Error::DefaultError(
|
||||
"Invalid parameters for a relinearization key".to_string(),
|
||||
))
|
||||
} else if value.ksk.is_some() {
|
||||
let ksk = KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?;
|
||||
fn try_convert_from(value: &GaloisKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if par.moduli.len() == 1 {
|
||||
Err(Error::DefaultError(
|
||||
"Invalid parameters for a relinearization key".to_string(),
|
||||
))
|
||||
} else if value.ksk.is_some() {
|
||||
let ksk = KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?;
|
||||
|
||||
let ctx = par.ctx_at_level(ksk.ciphertext_level)?;
|
||||
let element = SubstitutionExponent::new(ctx, value.exponent as usize)
|
||||
.map_err(Error::MathError)?;
|
||||
let ctx = par.ctx_at_level(ksk.ciphertext_level)?;
|
||||
let element = SubstitutionExponent::new(ctx, value.exponent as usize)
|
||||
.map_err(Error::MathError)?;
|
||||
|
||||
Ok(GaloisKey { element, ksk })
|
||||
} else {
|
||||
Err(Error::DefaultError("Invalid serialization".to_string()))
|
||||
}
|
||||
}
|
||||
Ok(GaloisKey { element, ksk })
|
||||
} else {
|
||||
Err(Error::DefaultError("Invalid serialization".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::GaloisKey;
|
||||
use crate::bfv::{
|
||||
proto::bfv::GaloisKey as GaloisKeyProto, traits::TryConvertFrom, BfvParameters, Encoding,
|
||||
Plaintext, SecretKey,
|
||||
};
|
||||
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use super::GaloisKey;
|
||||
use crate::bfv::{
|
||||
proto::bfv::GaloisKey as GaloisKeyProto, traits::TryConvertFrom, BfvParameters, Encoding,
|
||||
Plaintext, SecretKey,
|
||||
};
|
||||
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn relinearization() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
for _ in 0..30 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let row_size = params.degree() >> 1;
|
||||
#[test]
|
||||
fn relinearization() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
for _ in 0..30 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let row_size = params.degree() >> 1;
|
||||
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct = sk.try_encrypt(&pt, &mut rng)?;
|
||||
|
||||
for i in 1..2 * params.degree() {
|
||||
if i & 1 == 0 {
|
||||
assert!(GaloisKey::new(&sk, i, 0, 0, &mut rng).is_err())
|
||||
} else {
|
||||
let gk = GaloisKey::new(&sk, i, 0, 0, &mut rng)?;
|
||||
let ct2 = gk.relinearize(&ct)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? });
|
||||
for i in 1..2 * params.degree() {
|
||||
if i & 1 == 0 {
|
||||
assert!(GaloisKey::new(&sk, i, 0, 0, &mut rng).is_err())
|
||||
} else {
|
||||
let gk = GaloisKey::new(&sk, i, 0, 0, &mut rng)?;
|
||||
let ct2 = gk.relinearize(&ct)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? });
|
||||
|
||||
if i == 3 {
|
||||
let pt = sk.try_decrypt(&ct2)?;
|
||||
if i == 3 {
|
||||
let pt = sk.try_decrypt(&ct2)?;
|
||||
|
||||
// The expected result is rotated one on the left
|
||||
let mut expected = vec![0u64; params.degree()];
|
||||
expected[..row_size - 1].copy_from_slice(&v[1..row_size]);
|
||||
expected[row_size - 1] = v[0];
|
||||
expected[row_size..2 * row_size - 1]
|
||||
.copy_from_slice(&v[row_size + 1..]);
|
||||
expected[2 * row_size - 1] = v[row_size];
|
||||
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
|
||||
} else if i == params.degree() * 2 - 1 {
|
||||
let pt = sk.try_decrypt(&ct2)?;
|
||||
// The expected result is rotated one on the left
|
||||
let mut expected = vec![0u64; params.degree()];
|
||||
expected[..row_size - 1].copy_from_slice(&v[1..row_size]);
|
||||
expected[row_size - 1] = v[0];
|
||||
expected[row_size..2 * row_size - 1]
|
||||
.copy_from_slice(&v[row_size + 1..]);
|
||||
expected[2 * row_size - 1] = v[row_size];
|
||||
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
|
||||
} else if i == params.degree() * 2 - 1 {
|
||||
let pt = sk.try_decrypt(&ct2)?;
|
||||
|
||||
// The expected result has its rows swapped
|
||||
let mut expected = vec![0u64; params.degree()];
|
||||
expected[..row_size].copy_from_slice(&v[row_size..]);
|
||||
expected[row_size..].copy_from_slice(&v[..row_size]);
|
||||
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
// The expected result has its rows swapped
|
||||
let mut expected = vec![0u64; params.degree()];
|
||||
expected[..row_size].copy_from_slice(&v[row_size..]);
|
||||
expected[row_size..].copy_from_slice(&v[..row_size]);
|
||||
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(4, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?;
|
||||
let proto = GaloisKeyProto::from(&gk);
|
||||
assert_eq!(gk, GaloisKey::try_convert_from(&proto, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(4, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?;
|
||||
let proto = GaloisKeyProto::from(&gk);
|
||||
assert_eq!(gk, GaloisKey::try_convert_from(&proto, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
//! Key-switching keys for the BFV encryption scheme
|
||||
|
||||
use crate::bfv::{
|
||||
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto,
|
||||
traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey,
|
||||
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto,
|
||||
traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use fhe_math::rq::traits::TryConvertFrom;
|
||||
use fhe_math::rq::Context;
|
||||
use fhe_math::{
|
||||
rns::RnsContext,
|
||||
rq::{Poly, Representation},
|
||||
rns::RnsContext,
|
||||
rq::{Poly, Representation},
|
||||
};
|
||||
use fhe_traits::{DeserializeWithContext, Serialize};
|
||||
use itertools::izip;
|
||||
@@ -21,335 +21,335 @@ use zeroize::Zeroizing;
|
||||
/// Key switching key for the BFV encryption scheme.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct KeySwitchingKey {
|
||||
/// The parameters of the underlying BFV encryption scheme.
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
/// The parameters of the underlying BFV encryption scheme.
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
|
||||
/// The (optional) seed that generated the polynomials c1.
|
||||
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
|
||||
/// The (optional) seed that generated the polynomials c1.
|
||||
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
|
||||
|
||||
/// The key switching elements c0.
|
||||
pub(crate) c0: Box<[Poly]>,
|
||||
/// The key switching elements c0.
|
||||
pub(crate) c0: Box<[Poly]>,
|
||||
|
||||
/// The key switching elements c1.
|
||||
pub(crate) c1: Box<[Poly]>,
|
||||
/// The key switching elements c1.
|
||||
pub(crate) c1: Box<[Poly]>,
|
||||
|
||||
/// The level and context of the polynomials that will be key switched.
|
||||
pub(crate) ciphertext_level: usize,
|
||||
pub(crate) ctx_ciphertext: Arc<Context>,
|
||||
/// The level and context of the polynomials that will be key switched.
|
||||
pub(crate) ciphertext_level: usize,
|
||||
pub(crate) ctx_ciphertext: Arc<Context>,
|
||||
|
||||
/// The level and context of the key switching key.
|
||||
pub(crate) ksk_level: usize,
|
||||
pub(crate) ctx_ksk: Arc<Context>,
|
||||
/// The level and context of the key switching key.
|
||||
pub(crate) ksk_level: usize,
|
||||
pub(crate) ctx_ksk: Arc<Context>,
|
||||
}
|
||||
|
||||
impl KeySwitchingKey {
|
||||
/// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial
|
||||
/// `from`.
|
||||
pub fn new<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
from: &Poly,
|
||||
ciphertext_level: usize,
|
||||
ksk_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
let ctx_ksk = sk.par.ctx_at_level(ksk_level)?;
|
||||
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
|
||||
/// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial
|
||||
/// `from`.
|
||||
pub fn new<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
from: &Poly,
|
||||
ciphertext_level: usize,
|
||||
ksk_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
let ctx_ksk = sk.par.ctx_at_level(ksk_level)?;
|
||||
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
|
||||
|
||||
if ctx_ksk.moduli().len() == 1 {
|
||||
return Err(Error::DefaultError(
|
||||
"These parameters do not support key switching".to_string(),
|
||||
));
|
||||
}
|
||||
if ctx_ksk.moduli().len() == 1 {
|
||||
return Err(Error::DefaultError(
|
||||
"These parameters do not support key switching".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if from.ctx() != ctx_ksk {
|
||||
return Err(Error::DefaultError(
|
||||
"Incorrect context for polynomial from".to_string(),
|
||||
));
|
||||
}
|
||||
if from.ctx() != ctx_ksk {
|
||||
return Err(Error::DefaultError(
|
||||
"Incorrect context for polynomial from".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
|
||||
rng.fill(&mut seed);
|
||||
let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len());
|
||||
let c0 = Self::generate_c0(sk, from, &c1, rng)?;
|
||||
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
|
||||
rng.fill(&mut seed);
|
||||
let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len());
|
||||
let c0 = Self::generate_c0(sk, from, &c1, rng)?;
|
||||
|
||||
Ok(Self {
|
||||
par: sk.par.clone(),
|
||||
seed: Some(seed),
|
||||
c0: c0.into_boxed_slice(),
|
||||
c1: c1.into_boxed_slice(),
|
||||
ciphertext_level,
|
||||
ctx_ciphertext: ctx_ciphertext.clone(),
|
||||
ksk_level,
|
||||
ctx_ksk: ctx_ksk.clone(),
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
par: sk.par.clone(),
|
||||
seed: Some(seed),
|
||||
c0: c0.into_boxed_slice(),
|
||||
c1: c1.into_boxed_slice(),
|
||||
ciphertext_level,
|
||||
ctx_ciphertext: ctx_ciphertext.clone(),
|
||||
ksk_level,
|
||||
ctx_ksk: ctx_ksk.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate the c1's from the seed
|
||||
fn generate_c1(
|
||||
ctx: &Arc<Context>,
|
||||
seed: <ChaCha8Rng as SeedableRng>::Seed,
|
||||
size: usize,
|
||||
) -> Vec<Poly> {
|
||||
let mut c1 = Vec::with_capacity(size);
|
||||
let mut rng = ChaCha8Rng::from_seed(seed);
|
||||
(0..size).for_each(|_| {
|
||||
let mut seed_i = <ChaCha8Rng as SeedableRng>::Seed::default();
|
||||
rng.fill(&mut seed_i);
|
||||
let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i);
|
||||
unsafe { a.allow_variable_time_computations() }
|
||||
c1.push(a);
|
||||
});
|
||||
c1
|
||||
}
|
||||
/// Generate the c1's from the seed
|
||||
fn generate_c1(
|
||||
ctx: &Arc<Context>,
|
||||
seed: <ChaCha8Rng as SeedableRng>::Seed,
|
||||
size: usize,
|
||||
) -> Vec<Poly> {
|
||||
let mut c1 = Vec::with_capacity(size);
|
||||
let mut rng = ChaCha8Rng::from_seed(seed);
|
||||
(0..size).for_each(|_| {
|
||||
let mut seed_i = <ChaCha8Rng as SeedableRng>::Seed::default();
|
||||
rng.fill(&mut seed_i);
|
||||
let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i);
|
||||
unsafe { a.allow_variable_time_computations() }
|
||||
c1.push(a);
|
||||
});
|
||||
c1
|
||||
}
|
||||
|
||||
/// Generate the c0's from the c1's and the secret key
|
||||
fn generate_c0<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
from: &Poly,
|
||||
c1: &[Poly],
|
||||
rng: &mut R,
|
||||
) -> Result<Vec<Poly>> {
|
||||
if c1.is_empty() {
|
||||
return Err(Error::DefaultError("Empty number of c1's".to_string()));
|
||||
}
|
||||
if from.representation() != &Representation::PowerBasis {
|
||||
return Err(Error::DefaultError(
|
||||
"Unexpected representation for from".to_string(),
|
||||
));
|
||||
}
|
||||
/// Generate the c0's from the c1's and the secret key
|
||||
fn generate_c0<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
from: &Poly,
|
||||
c1: &[Poly],
|
||||
rng: &mut R,
|
||||
) -> Result<Vec<Poly>> {
|
||||
if c1.is_empty() {
|
||||
return Err(Error::DefaultError("Empty number of c1's".to_string()));
|
||||
}
|
||||
if from.representation() != &Representation::PowerBasis {
|
||||
return Err(Error::DefaultError(
|
||||
"Unexpected representation for from".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let size = c1.len();
|
||||
let size = c1.len();
|
||||
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
c1[0].ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
c1[0].ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
|
||||
let rns = RnsContext::new(&sk.par.moduli[..size])?;
|
||||
let c0 = c1
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c1i)| {
|
||||
let mut a_s = Zeroizing::new(c1i.clone());
|
||||
a_s.disallow_variable_time_computations();
|
||||
a_s.change_representation(Representation::Ntt);
|
||||
*a_s.as_mut() *= s.as_ref();
|
||||
a_s.change_representation(Representation::PowerBasis);
|
||||
let rns = RnsContext::new(&sk.par.moduli[..size])?;
|
||||
let c0 = c1
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c1i)| {
|
||||
let mut a_s = Zeroizing::new(c1i.clone());
|
||||
a_s.disallow_variable_time_computations();
|
||||
a_s.change_representation(Representation::Ntt);
|
||||
*a_s.as_mut() *= s.as_ref();
|
||||
a_s.change_representation(Representation::PowerBasis);
|
||||
|
||||
let mut b =
|
||||
Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?;
|
||||
b -= &a_s;
|
||||
let mut b =
|
||||
Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?;
|
||||
b -= &a_s;
|
||||
|
||||
let gi = rns.get_garner(i).unwrap();
|
||||
let g_i_from = Zeroizing::new(gi * from);
|
||||
b += &g_i_from;
|
||||
let gi = rns.get_garner(i).unwrap();
|
||||
let g_i_from = Zeroizing::new(gi * from);
|
||||
b += &g_i_from;
|
||||
|
||||
// It is now safe to enable variable time computations.
|
||||
unsafe { b.allow_variable_time_computations() }
|
||||
b.change_representation(Representation::NttShoup);
|
||||
Ok(b)
|
||||
})
|
||||
.collect::<Result<Vec<Poly>>>()?;
|
||||
// It is now safe to enable variable time computations.
|
||||
unsafe { b.allow_variable_time_computations() }
|
||||
b.change_representation(Representation::NttShoup);
|
||||
Ok(b)
|
||||
})
|
||||
.collect::<Result<Vec<Poly>>>()?;
|
||||
|
||||
Ok(c0)
|
||||
}
|
||||
Ok(c0)
|
||||
}
|
||||
|
||||
/// Key switch a polynomial.
|
||||
pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> {
|
||||
if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() {
|
||||
return Err(Error::DefaultError(
|
||||
"The input polynomial does not have the correct context.".to_string(),
|
||||
));
|
||||
}
|
||||
if p.representation() != &Representation::PowerBasis {
|
||||
return Err(Error::DefaultError("Incorrect representation".to_string()));
|
||||
}
|
||||
/// Key switch a polynomial.
|
||||
pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> {
|
||||
if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() {
|
||||
return Err(Error::DefaultError(
|
||||
"The input polynomial does not have the correct context.".to_string(),
|
||||
));
|
||||
}
|
||||
if p.representation() != &Representation::PowerBasis {
|
||||
return Err(Error::DefaultError("Incorrect representation".to_string()));
|
||||
}
|
||||
|
||||
let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
|
||||
let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
|
||||
for (c2_i_coefficients, c0_i, c1_i) in izip!(
|
||||
p.coefficients().outer_iter(),
|
||||
self.c0.iter(),
|
||||
self.c1.iter()
|
||||
) {
|
||||
let mut c2_i = unsafe {
|
||||
Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time(
|
||||
c2_i_coefficients.as_slice().unwrap(),
|
||||
&self.ctx_ksk,
|
||||
)
|
||||
};
|
||||
c0 += &(&c2_i * c0_i);
|
||||
c2_i *= c1_i;
|
||||
c1 += &c2_i;
|
||||
}
|
||||
Ok((c0, c1))
|
||||
}
|
||||
let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
|
||||
let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
|
||||
for (c2_i_coefficients, c0_i, c1_i) in izip!(
|
||||
p.coefficients().outer_iter(),
|
||||
self.c0.iter(),
|
||||
self.c1.iter()
|
||||
) {
|
||||
let mut c2_i = unsafe {
|
||||
Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time(
|
||||
c2_i_coefficients.as_slice().unwrap(),
|
||||
&self.ctx_ksk,
|
||||
)
|
||||
};
|
||||
c0 += &(&c2_i * c0_i);
|
||||
c2_i *= c1_i;
|
||||
c1 += &c2_i;
|
||||
}
|
||||
Ok((c0, c1))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&KeySwitchingKey> for KeySwitchingKeyProto {
|
||||
fn from(value: &KeySwitchingKey) -> Self {
|
||||
let mut ksk = KeySwitchingKeyProto::new();
|
||||
if let Some(seed) = value.seed.as_ref() {
|
||||
ksk.seed = seed.to_vec();
|
||||
} else {
|
||||
ksk.c1.reserve_exact(value.c1.len());
|
||||
for c1 in value.c1.iter() {
|
||||
ksk.c1.push(c1.to_bytes())
|
||||
}
|
||||
}
|
||||
ksk.c0.reserve_exact(value.c0.len());
|
||||
for c0 in value.c0.iter() {
|
||||
ksk.c0.push(c0.to_bytes())
|
||||
}
|
||||
ksk.ciphertext_level = value.ciphertext_level as u32;
|
||||
ksk.ksk_level = value.ksk_level as u32;
|
||||
ksk
|
||||
}
|
||||
fn from(value: &KeySwitchingKey) -> Self {
|
||||
let mut ksk = KeySwitchingKeyProto::new();
|
||||
if let Some(seed) = value.seed.as_ref() {
|
||||
ksk.seed = seed.to_vec();
|
||||
} else {
|
||||
ksk.c1.reserve_exact(value.c1.len());
|
||||
for c1 in value.c1.iter() {
|
||||
ksk.c1.push(c1.to_bytes())
|
||||
}
|
||||
}
|
||||
ksk.c0.reserve_exact(value.c0.len());
|
||||
for c0 in value.c0.iter() {
|
||||
ksk.c0.push(c0.to_bytes())
|
||||
}
|
||||
ksk.ciphertext_level = value.ciphertext_level as u32;
|
||||
ksk.ksk_level = value.ksk_level as u32;
|
||||
ksk
|
||||
}
|
||||
}
|
||||
|
||||
impl BfvTryConvertFrom<&KeySwitchingKeyProto> for KeySwitchingKey {
|
||||
fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
let ciphertext_level = value.ciphertext_level as usize;
|
||||
let ksk_level = value.ksk_level as usize;
|
||||
let ctx_ksk = par.ctx_at_level(ksk_level)?;
|
||||
let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?;
|
||||
fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
let ciphertext_level = value.ciphertext_level as usize;
|
||||
let ksk_level = value.ksk_level as usize;
|
||||
let ctx_ksk = par.ctx_at_level(ksk_level)?;
|
||||
let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?;
|
||||
|
||||
if value.c0.len() != ctx_ciphertext.moduli().len() {
|
||||
return Err(Error::DefaultError(
|
||||
"Incorrect number of values in c0".to_string(),
|
||||
));
|
||||
}
|
||||
if value.c0.len() != ctx_ciphertext.moduli().len() {
|
||||
return Err(Error::DefaultError(
|
||||
"Incorrect number of values in c0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let seed = if value.seed.is_empty() {
|
||||
if value.c1.len() != ctx_ciphertext.moduli().len() {
|
||||
return Err(Error::DefaultError(
|
||||
"Incorrect number of values in c1".to_string(),
|
||||
));
|
||||
}
|
||||
None
|
||||
} else {
|
||||
let unwrapped = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
|
||||
if unwrapped.is_err() {
|
||||
return Err(Error::DefaultError("Invalid seed".to_string()));
|
||||
}
|
||||
Some(unwrapped.unwrap())
|
||||
};
|
||||
let seed = if value.seed.is_empty() {
|
||||
if value.c1.len() != ctx_ciphertext.moduli().len() {
|
||||
return Err(Error::DefaultError(
|
||||
"Incorrect number of values in c1".to_string(),
|
||||
));
|
||||
}
|
||||
None
|
||||
} else {
|
||||
let unwrapped = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
|
||||
if unwrapped.is_err() {
|
||||
return Err(Error::DefaultError("Invalid seed".to_string()));
|
||||
}
|
||||
Some(unwrapped.unwrap())
|
||||
};
|
||||
|
||||
let c1 = if let Some(seed) = seed {
|
||||
Self::generate_c1(ctx_ksk, seed, value.c0.len())
|
||||
} else {
|
||||
value
|
||||
.c1
|
||||
.iter()
|
||||
.map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError))
|
||||
.collect::<Result<Vec<Poly>>>()?
|
||||
};
|
||||
let c1 = if let Some(seed) = seed {
|
||||
Self::generate_c1(ctx_ksk, seed, value.c0.len())
|
||||
} else {
|
||||
value
|
||||
.c1
|
||||
.iter()
|
||||
.map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError))
|
||||
.collect::<Result<Vec<Poly>>>()?
|
||||
};
|
||||
|
||||
let c0 = value
|
||||
.c0
|
||||
.iter()
|
||||
.map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError))
|
||||
.collect::<Result<Vec<Poly>>>()?;
|
||||
let c0 = value
|
||||
.c0
|
||||
.iter()
|
||||
.map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError))
|
||||
.collect::<Result<Vec<Poly>>>()?;
|
||||
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
seed,
|
||||
c0: c0.into_boxed_slice(),
|
||||
c1: c1.into_boxed_slice(),
|
||||
ciphertext_level,
|
||||
ctx_ciphertext: ctx_ciphertext.clone(),
|
||||
ksk_level,
|
||||
ctx_ksk: ctx_ksk.clone(),
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
seed,
|
||||
c0: c0.into_boxed_slice(),
|
||||
c1: c1.into_boxed_slice(),
|
||||
ciphertext_level,
|
||||
ctx_ciphertext: ctx_ciphertext.clone(),
|
||||
ksk_level,
|
||||
ctx_ksk: ctx_ksk.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bfv::{
|
||||
keys::key_switching_key::KeySwitchingKey,
|
||||
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, traits::TryConvertFrom, BfvParameters,
|
||||
SecretKey,
|
||||
};
|
||||
use fhe_math::{
|
||||
rns::RnsContext,
|
||||
rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation},
|
||||
};
|
||||
use num_bigint::BigUint;
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use crate::bfv::{
|
||||
keys::key_switching_key::KeySwitchingKey,
|
||||
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, traits::TryConvertFrom, BfvParameters,
|
||||
SecretKey,
|
||||
};
|
||||
use fhe_math::{
|
||||
rns::RnsContext,
|
||||
rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation},
|
||||
};
|
||||
use num_bigint::BigUint;
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn constructor() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
|
||||
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng);
|
||||
assert!(ksk.is_ok());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn constructor() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
|
||||
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng);
|
||||
assert!(ksk.is_ok());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_switch() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [Arc::new(BfvParameters::default(6, 8))] {
|
||||
for _ in 0..100 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
|
||||
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
|
||||
let mut s = Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
.map_err(crate::Error::MathError)?;
|
||||
s.change_representation(Representation::Ntt);
|
||||
#[test]
|
||||
fn key_switch() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [Arc::new(BfvParameters::default(6, 8))] {
|
||||
for _ in 0..100 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
|
||||
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
|
||||
let mut s = Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
.map_err(crate::Error::MathError)?;
|
||||
s.change_representation(Representation::Ntt);
|
||||
|
||||
let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng);
|
||||
let (c0, c1) = ksk.key_switch(&input)?;
|
||||
let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng);
|
||||
let (c0, c1) = ksk.key_switch(&input)?;
|
||||
|
||||
let mut c2 = &c0 + &(&c1 * &s);
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let mut c2 = &c0 + &(&c1 * &s);
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
|
||||
input.change_representation(Representation::Ntt);
|
||||
p.change_representation(Representation::Ntt);
|
||||
let mut c3 = &input * &p;
|
||||
c3.change_representation(Representation::PowerBasis);
|
||||
input.change_representation(Representation::Ntt);
|
||||
p.change_representation(Representation::Ntt);
|
||||
let mut c3 = &input * &p;
|
||||
c3.change_representation(Representation::PowerBasis);
|
||||
|
||||
let rns = RnsContext::new(¶ms.moduli)?;
|
||||
Vec::<BigUint>::from(&(&c2 - &c3)).iter().for_each(|b| {
|
||||
assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70)
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
let rns = RnsContext::new(¶ms.moduli)?;
|
||||
Vec::<BigUint>::from(&(&c2 - &c3)).iter().for_each(|b| {
|
||||
assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70)
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
|
||||
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
|
||||
let ksk_proto = KeySwitchingKeyProto::from(&ksk);
|
||||
assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
|
||||
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
|
||||
let ksk_proto = KeySwitchingKeyProto::from(&ksk);
|
||||
assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
use crate::bfv::traits::TryConvertFrom;
|
||||
use crate::bfv::{
|
||||
proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto},
|
||||
BfvParameters, Ciphertext, Encoding, Plaintext,
|
||||
proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto},
|
||||
BfvParameters, Ciphertext, Encoding, Plaintext,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use fhe_math::rq::{Poly, Representation};
|
||||
@@ -18,188 +18,188 @@ use super::SecretKey;
|
||||
/// Public key for the BFV encryption scheme.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct PublicKey {
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
pub(crate) c: Ciphertext,
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
pub(crate) c: Ciphertext,
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
/// Generate a new [`PublicKey`] from a [`SecretKey`].
|
||||
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Self {
|
||||
let zero = Plaintext::zero(Encoding::poly(), &sk.par).unwrap();
|
||||
let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap();
|
||||
// The polynomials of a public key should not allow for variable time
|
||||
// computation.
|
||||
c.c.iter_mut()
|
||||
.for_each(|p| p.disallow_variable_time_computations());
|
||||
Self {
|
||||
par: sk.par.clone(),
|
||||
c,
|
||||
}
|
||||
}
|
||||
/// Generate a new [`PublicKey`] from a [`SecretKey`].
|
||||
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Self {
|
||||
let zero = Plaintext::zero(Encoding::poly(), &sk.par).unwrap();
|
||||
let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap();
|
||||
// The polynomials of a public key should not allow for variable time
|
||||
// computation.
|
||||
c.c.iter_mut()
|
||||
.for_each(|p| p.disallow_variable_time_computations());
|
||||
Self {
|
||||
par: sk.par.clone(),
|
||||
c,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FheParametrized for PublicKey {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl FheEncrypter<Plaintext, Ciphertext> for PublicKey {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
pt: &Plaintext,
|
||||
rng: &mut R,
|
||||
) -> Result<Ciphertext> {
|
||||
let mut ct = self.c.clone();
|
||||
while ct.level != pt.level {
|
||||
ct.mod_switch_to_next_level();
|
||||
}
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
pt: &Plaintext,
|
||||
rng: &mut R,
|
||||
) -> Result<Ciphertext> {
|
||||
let mut ct = self.c.clone();
|
||||
while ct.level != pt.level {
|
||||
ct.mod_switch_to_next_level();
|
||||
}
|
||||
|
||||
let ctx = self.par.ctx_at_level(ct.level)?;
|
||||
let u = Zeroizing::new(Poly::small(
|
||||
ctx,
|
||||
Representation::Ntt,
|
||||
self.par.variance,
|
||||
rng,
|
||||
)?);
|
||||
let e1 = Zeroizing::new(Poly::small(
|
||||
ctx,
|
||||
Representation::Ntt,
|
||||
self.par.variance,
|
||||
rng,
|
||||
)?);
|
||||
let e2 = Zeroizing::new(Poly::small(
|
||||
ctx,
|
||||
Representation::Ntt,
|
||||
self.par.variance,
|
||||
rng,
|
||||
)?);
|
||||
let ctx = self.par.ctx_at_level(ct.level)?;
|
||||
let u = Zeroizing::new(Poly::small(
|
||||
ctx,
|
||||
Representation::Ntt,
|
||||
self.par.variance,
|
||||
rng,
|
||||
)?);
|
||||
let e1 = Zeroizing::new(Poly::small(
|
||||
ctx,
|
||||
Representation::Ntt,
|
||||
self.par.variance,
|
||||
rng,
|
||||
)?);
|
||||
let e2 = Zeroizing::new(Poly::small(
|
||||
ctx,
|
||||
Representation::Ntt,
|
||||
self.par.variance,
|
||||
rng,
|
||||
)?);
|
||||
|
||||
let m = Zeroizing::new(pt.to_poly());
|
||||
let mut c0 = u.as_ref() * &ct.c[0];
|
||||
c0 += &e1;
|
||||
c0 += &m;
|
||||
let mut c1 = u.as_ref() * &ct.c[1];
|
||||
c1 += &e2;
|
||||
let m = Zeroizing::new(pt.to_poly());
|
||||
let mut c0 = u.as_ref() * &ct.c[0];
|
||||
c0 += &e1;
|
||||
c0 += &m;
|
||||
let mut c1 = u.as_ref() * &ct.c[1];
|
||||
c1 += &e2;
|
||||
|
||||
// It is now safe to enable variable time computations.
|
||||
unsafe {
|
||||
c0.allow_variable_time_computations();
|
||||
c1.allow_variable_time_computations()
|
||||
}
|
||||
// It is now safe to enable variable time computations.
|
||||
unsafe {
|
||||
c0.allow_variable_time_computations();
|
||||
c1.allow_variable_time_computations()
|
||||
}
|
||||
|
||||
Ok(Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: None,
|
||||
c: vec![c0, c1],
|
||||
level: ct.level,
|
||||
})
|
||||
}
|
||||
Ok(Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: None,
|
||||
c: vec![c0, c1],
|
||||
level: ct.level,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&PublicKey> for PublicKeyProto {
|
||||
fn from(pk: &PublicKey) -> Self {
|
||||
let mut proto = PublicKeyProto::new();
|
||||
proto.c = MessageField::some(CiphertextProto::from(&pk.c));
|
||||
proto
|
||||
}
|
||||
fn from(pk: &PublicKey) -> Self {
|
||||
let mut proto = PublicKeyProto::new();
|
||||
proto.c = MessageField::some(CiphertextProto::from(&pk.c));
|
||||
proto
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for PublicKey {
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
PublicKeyProto::from(self).write_to_bytes().unwrap()
|
||||
}
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
PublicKeyProto::from(self).write_to_bytes().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl DeserializeParametrized for PublicKey {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
|
||||
let proto =
|
||||
PublicKeyProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
|
||||
if proto.c.is_some() {
|
||||
let mut c = Ciphertext::try_convert_from(&proto.c.unwrap(), par)?;
|
||||
if c.level != 0 {
|
||||
Err(Error::SerializationError)
|
||||
} else {
|
||||
// The polynomials of a public key should not allow for variable time
|
||||
// computation.
|
||||
c.c.iter_mut()
|
||||
.for_each(|p| p.disallow_variable_time_computations());
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
c,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
Err(Error::SerializationError)
|
||||
}
|
||||
}
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
|
||||
let proto =
|
||||
PublicKeyProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
|
||||
if proto.c.is_some() {
|
||||
let mut c = Ciphertext::try_convert_from(&proto.c.unwrap(), par)?;
|
||||
if c.level != 0 {
|
||||
Err(Error::SerializationError)
|
||||
} else {
|
||||
// The polynomials of a public key should not allow for variable time
|
||||
// computation.
|
||||
c.c.iter_mut()
|
||||
.for_each(|p| p.disallow_variable_time_computations());
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
c,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
Err(Error::SerializationError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::PublicKey;
|
||||
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext, SecretKey};
|
||||
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use super::PublicKey;
|
||||
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext, SecretKey};
|
||||
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn keygen() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let pk = PublicKey::new(&sk, &mut rng);
|
||||
assert_eq!(pk.par, params);
|
||||
assert_eq!(
|
||||
sk.try_decrypt(&pk.c)?,
|
||||
Plaintext::zero(Encoding::poly(), ¶ms)?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn keygen() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let pk = PublicKey::new(&sk, &mut rng);
|
||||
assert_eq!(pk.par, params);
|
||||
assert_eq!(
|
||||
sk.try_decrypt(&pk.c)?,
|
||||
Plaintext::zero(Encoding::poly(), ¶ms)?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
for level in 0..params.max_level() {
|
||||
for _ in 0..20 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let pk = PublicKey::new(&sk, &mut rng);
|
||||
#[test]
|
||||
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
for level in 0..params.max_level() {
|
||||
for _ in 0..20 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let pk = PublicKey::new(&sk, &mut rng);
|
||||
|
||||
let pt = Plaintext::try_encode(
|
||||
¶ms.plaintext.random_vec(params.degree(), &mut rng),
|
||||
Encoding::poly_at_level(level),
|
||||
¶ms,
|
||||
)?;
|
||||
let ct = pk.try_encrypt(&pt, &mut rng)?;
|
||||
let pt2 = sk.try_decrypt(&ct)?;
|
||||
let pt = Plaintext::try_encode(
|
||||
¶ms.plaintext.random_vec(params.degree(), &mut rng),
|
||||
Encoding::poly_at_level(level),
|
||||
¶ms,
|
||||
)?;
|
||||
let ct = pk.try_encrypt(&pt, &mut rng)?;
|
||||
let pt2 = sk.try_decrypt(&ct)?;
|
||||
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
assert_eq!(pt2, pt);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
assert_eq!(pt2, pt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let pk = PublicKey::new(&sk, &mut rng);
|
||||
let bytes = pk.to_bytes();
|
||||
assert_eq!(pk, PublicKey::from_bytes(&bytes, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn test_serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let pk = PublicKey::new(&sk, &mut rng);
|
||||
let bytes = pk.to_bytes();
|
||||
assert_eq!(pk, PublicKey::from_bytes(&bytes, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,15 +4,15 @@ use std::sync::Arc;
|
||||
|
||||
use super::key_switching_key::KeySwitchingKey;
|
||||
use crate::bfv::{
|
||||
proto::bfv::{
|
||||
KeySwitchingKey as KeySwitchingKeyProto, RelinearizationKey as RelinearizationKeyProto,
|
||||
},
|
||||
traits::TryConvertFrom,
|
||||
BfvParameters, Ciphertext, SecretKey,
|
||||
proto::bfv::{
|
||||
KeySwitchingKey as KeySwitchingKeyProto, RelinearizationKey as RelinearizationKeyProto,
|
||||
},
|
||||
traits::TryConvertFrom,
|
||||
BfvParameters, Ciphertext, SecretKey,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use fhe_math::rq::{
|
||||
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
|
||||
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
|
||||
};
|
||||
use fhe_traits::{DeserializeParametrized, FheParametrized, Serialize};
|
||||
use protobuf::{Message, MessageField};
|
||||
@@ -24,284 +24,284 @@ use zeroize::Zeroizing;
|
||||
/// which switch from `s^2` to `s` where `s` is the secret key.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct RelinearizationKey {
|
||||
pub(crate) ksk: KeySwitchingKey,
|
||||
pub(crate) ksk: KeySwitchingKey,
|
||||
}
|
||||
|
||||
impl RelinearizationKey {
|
||||
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
|
||||
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Result<Self> {
|
||||
Self::new_leveled_internal(sk, 0, 0, rng)
|
||||
}
|
||||
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
|
||||
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Result<Self> {
|
||||
Self::new_leveled_internal(sk, 0, 0, rng)
|
||||
}
|
||||
|
||||
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
|
||||
pub fn new_leveled<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
ciphertext_level: usize,
|
||||
key_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
Self::new_leveled_internal(sk, ciphertext_level, key_level, rng)
|
||||
}
|
||||
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
|
||||
pub fn new_leveled<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
ciphertext_level: usize,
|
||||
key_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
Self::new_leveled_internal(sk, ciphertext_level, key_level, rng)
|
||||
}
|
||||
|
||||
fn new_leveled_internal<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
ciphertext_level: usize,
|
||||
key_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
let ctx_relin_key = sk.par.ctx_at_level(key_level)?;
|
||||
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
|
||||
fn new_leveled_internal<R: RngCore + CryptoRng>(
|
||||
sk: &SecretKey,
|
||||
ciphertext_level: usize,
|
||||
key_level: usize,
|
||||
rng: &mut R,
|
||||
) -> Result<Self> {
|
||||
let ctx_relin_key = sk.par.ctx_at_level(key_level)?;
|
||||
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
|
||||
|
||||
if ctx_relin_key.moduli().len() == 1 {
|
||||
return Err(Error::DefaultError(
|
||||
"These parameters do not support key switching".to_string(),
|
||||
));
|
||||
}
|
||||
if ctx_relin_key.moduli().len() == 1 {
|
||||
return Err(Error::DefaultError(
|
||||
"These parameters do not support key switching".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx_ciphertext,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref());
|
||||
s2.change_representation(Representation::PowerBasis);
|
||||
let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?;
|
||||
let s2_switched_up = Zeroizing::new(s2.mod_switch_to(&switcher_up)?);
|
||||
let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?;
|
||||
Ok(Self { ksk })
|
||||
}
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx_ciphertext,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref());
|
||||
s2.change_representation(Representation::PowerBasis);
|
||||
let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?;
|
||||
let s2_switched_up = Zeroizing::new(s2.mod_switch_to(&switcher_up)?);
|
||||
let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?;
|
||||
Ok(Self { ksk })
|
||||
}
|
||||
|
||||
/// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`]
|
||||
pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> {
|
||||
if ct.c.len() != 3 {
|
||||
Err(Error::DefaultError(
|
||||
"Only supports relinearization of ciphertext with 3 parts".to_string(),
|
||||
))
|
||||
} else if ct.level != self.ksk.ciphertext_level {
|
||||
Err(Error::DefaultError(
|
||||
"Ciphertext has incorrect level".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut c2 = ct.c[2].clone();
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
/// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`]
|
||||
pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> {
|
||||
if ct.c.len() != 3 {
|
||||
Err(Error::DefaultError(
|
||||
"Only supports relinearization of ciphertext with 3 parts".to_string(),
|
||||
))
|
||||
} else if ct.level != self.ksk.ciphertext_level {
|
||||
Err(Error::DefaultError(
|
||||
"Ciphertext has incorrect level".to_string(),
|
||||
))
|
||||
} else {
|
||||
let mut c2 = ct.c[2].clone();
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
|
||||
#[allow(unused_mut)]
|
||||
let (mut c0, mut c1) = self.relinearizes_poly(&c2)?;
|
||||
#[allow(unused_mut)]
|
||||
let (mut c0, mut c1) = self.relinearizes_poly(&c2)?;
|
||||
|
||||
if c0.ctx() != ct.c[0].ctx() {
|
||||
c0.change_representation(Representation::PowerBasis);
|
||||
c1.change_representation(Representation::PowerBasis);
|
||||
c0.mod_switch_down_to(ct.c[0].ctx())?;
|
||||
c1.mod_switch_down_to(ct.c[1].ctx())?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c1.change_representation(Representation::Ntt);
|
||||
}
|
||||
if c0.ctx() != ct.c[0].ctx() {
|
||||
c0.change_representation(Representation::PowerBasis);
|
||||
c1.change_representation(Representation::PowerBasis);
|
||||
c0.mod_switch_down_to(ct.c[0].ctx())?;
|
||||
c1.mod_switch_down_to(ct.c[1].ctx())?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c1.change_representation(Representation::Ntt);
|
||||
}
|
||||
|
||||
ct.c[0] += &c0;
|
||||
ct.c[1] += &c1;
|
||||
ct.c.truncate(2);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
ct.c[0] += &c0;
|
||||
ct.c[1] += &c1;
|
||||
ct.c.truncate(2);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Relinearize using polynomials.
|
||||
pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> {
|
||||
self.ksk.key_switch(c2)
|
||||
}
|
||||
/// Relinearize using polynomials.
|
||||
pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> {
|
||||
self.ksk.key_switch(c2)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RelinearizationKey> for RelinearizationKeyProto {
|
||||
fn from(value: &RelinearizationKey) -> Self {
|
||||
let mut rk = RelinearizationKeyProto::new();
|
||||
rk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
|
||||
rk
|
||||
}
|
||||
fn from(value: &RelinearizationKey) -> Self {
|
||||
let mut rk = RelinearizationKeyProto::new();
|
||||
rk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
|
||||
rk
|
||||
}
|
||||
}
|
||||
|
||||
impl TryConvertFrom<&RelinearizationKeyProto> for RelinearizationKey {
|
||||
fn try_convert_from(value: &RelinearizationKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if par.moduli.len() == 1 {
|
||||
Err(Error::DefaultError(
|
||||
"Invalid parameters for a relinearization key".to_string(),
|
||||
))
|
||||
} else if value.ksk.is_some() {
|
||||
Ok(RelinearizationKey {
|
||||
ksk: KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?,
|
||||
})
|
||||
} else {
|
||||
Err(Error::DefaultError("Invalid serialization".to_string()))
|
||||
}
|
||||
}
|
||||
fn try_convert_from(value: &RelinearizationKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if par.moduli.len() == 1 {
|
||||
Err(Error::DefaultError(
|
||||
"Invalid parameters for a relinearization key".to_string(),
|
||||
))
|
||||
} else if value.ksk.is_some() {
|
||||
Ok(RelinearizationKey {
|
||||
ksk: KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?,
|
||||
})
|
||||
} else {
|
||||
Err(Error::DefaultError("Invalid serialization".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for RelinearizationKey {
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
RelinearizationKeyProto::from(self)
|
||||
.write_to_bytes()
|
||||
.unwrap()
|
||||
}
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
RelinearizationKeyProto::from(self)
|
||||
.write_to_bytes()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl FheParametrized for RelinearizationKey {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl DeserializeParametrized for RelinearizationKey {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
|
||||
let rk = RelinearizationKeyProto::parse_from_bytes(bytes);
|
||||
if let Ok(rk) = rk {
|
||||
RelinearizationKey::try_convert_from(&rk, par)
|
||||
} else {
|
||||
Err(Error::DefaultError("Invalid serialization".to_string()))
|
||||
}
|
||||
}
|
||||
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
|
||||
let rk = RelinearizationKeyProto::parse_from_bytes(bytes);
|
||||
if let Ok(rk) = rk {
|
||||
RelinearizationKey::try_convert_from(&rk, par)
|
||||
} else {
|
||||
Err(Error::DefaultError("Invalid serialization".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::RelinearizationKey;
|
||||
use crate::bfv::{
|
||||
proto::bfv::RelinearizationKey as RelinearizationKeyProto, traits::TryConvertFrom,
|
||||
BfvParameters, Ciphertext, Encoding, SecretKey,
|
||||
};
|
||||
use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation};
|
||||
use fhe_traits::{FheDecoder, FheDecrypter};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use super::RelinearizationKey;
|
||||
use crate::bfv::{
|
||||
proto::bfv::RelinearizationKey as RelinearizationKeyProto, traits::TryConvertFrom,
|
||||
BfvParameters, Ciphertext, Encoding, SecretKey,
|
||||
};
|
||||
use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation};
|
||||
use fhe_traits::{FheDecoder, FheDecrypter};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn relinearization() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [Arc::new(BfvParameters::default(6, 8))] {
|
||||
for _ in 0..100 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
#[test]
|
||||
fn relinearization() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [Arc::new(BfvParameters::default(6, 8))] {
|
||||
for _ in 0..100 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let mut s = Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
.map_err(crate::Error::MathError)?;
|
||||
s.change_representation(Representation::Ntt);
|
||||
let s2 = &s * &s;
|
||||
let ctx = params.ctx_at_level(0)?;
|
||||
let mut s = Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
.map_err(crate::Error::MathError)?;
|
||||
s.change_representation(Representation::Ntt);
|
||||
let s2 = &s * &s;
|
||||
|
||||
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2,
|
||||
// c1, c2) encrypting 0.
|
||||
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c0 -= &(&c1 * &s);
|
||||
c0 -= &(&c2 * &s2);
|
||||
let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?;
|
||||
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2,
|
||||
// c1, c2) encrypting 0.
|
||||
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c0 -= &(&c1 * &s);
|
||||
c0 -= &(&c2 * &s2);
|
||||
let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?;
|
||||
|
||||
// Relinearize the extended ciphertext!
|
||||
rk.relinearizes(&mut ct)?;
|
||||
assert_eq!(ct.c.len(), 2);
|
||||
// Relinearize the extended ciphertext!
|
||||
rk.relinearizes(&mut ct)?;
|
||||
assert_eq!(ct.c.len(), 2);
|
||||
|
||||
// Check that the relinearization by polynomials works the same way
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
|
||||
c0r.change_representation(Representation::PowerBasis);
|
||||
c0r.mod_switch_down_to(c0.ctx())?;
|
||||
c1r.change_representation(Representation::PowerBasis);
|
||||
c1r.mod_switch_down_to(c1.ctx())?;
|
||||
c0r.change_representation(Representation::Ntt);
|
||||
c1r.change_representation(Representation::Ntt);
|
||||
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?);
|
||||
// Check that the relinearization by polynomials works the same way
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
|
||||
c0r.change_representation(Representation::PowerBasis);
|
||||
c0r.mod_switch_down_to(c0.ctx())?;
|
||||
c1r.change_representation(Representation::PowerBasis);
|
||||
c1r.mod_switch_down_to(c1.ctx())?;
|
||||
c0r.change_representation(Representation::Ntt);
|
||||
c1r.change_representation(Representation::Ntt);
|
||||
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?);
|
||||
|
||||
// Print the noise and decrypt
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
let pt = sk.try_decrypt(&ct)?;
|
||||
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
|
||||
assert_eq!(w, &[0u64; 8]);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
// Print the noise and decrypt
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
let pt = sk.try_decrypt(&ct)?;
|
||||
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
|
||||
assert_eq!(w, &[0u64; 8]);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relinearization_leveled() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [Arc::new(BfvParameters::default(5, 8))] {
|
||||
for ciphertext_level in 0..3 {
|
||||
for key_level in 0..ciphertext_level {
|
||||
for _ in 0..10 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let rk = RelinearizationKey::new_leveled(
|
||||
&sk,
|
||||
ciphertext_level,
|
||||
key_level,
|
||||
&mut rng,
|
||||
)?;
|
||||
#[test]
|
||||
fn relinearization_leveled() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [Arc::new(BfvParameters::default(5, 8))] {
|
||||
for ciphertext_level in 0..3 {
|
||||
for key_level in 0..ciphertext_level {
|
||||
for _ in 0..10 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let rk = RelinearizationKey::new_leveled(
|
||||
&sk,
|
||||
ciphertext_level,
|
||||
key_level,
|
||||
&mut rng,
|
||||
)?;
|
||||
|
||||
let ctx = params.ctx_at_level(ciphertext_level)?;
|
||||
let mut s = Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
.map_err(crate::Error::MathError)?;
|
||||
s.change_representation(Representation::Ntt);
|
||||
let s2 = &s * &s;
|
||||
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 *
|
||||
// s^2, c1, c2) encrypting 0.
|
||||
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c0 -= &(&c1 * &s);
|
||||
c0 -= &(&c2 * &s2);
|
||||
let mut ct =
|
||||
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?;
|
||||
let ctx = params.ctx_at_level(ciphertext_level)?;
|
||||
let mut s = Poly::try_convert_from(
|
||||
sk.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
.map_err(crate::Error::MathError)?;
|
||||
s.change_representation(Representation::Ntt);
|
||||
let s2 = &s * &s;
|
||||
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 *
|
||||
// s^2, c1, c2) encrypting 0.
|
||||
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
|
||||
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
|
||||
c0.change_representation(Representation::Ntt);
|
||||
c0 -= &(&c1 * &s);
|
||||
c0 -= &(&c2 * &s2);
|
||||
let mut ct =
|
||||
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], ¶ms)?;
|
||||
|
||||
// Relinearize the extended ciphertext!
|
||||
rk.relinearizes(&mut ct)?;
|
||||
assert_eq!(ct.c.len(), 2);
|
||||
// Relinearize the extended ciphertext!
|
||||
rk.relinearizes(&mut ct)?;
|
||||
assert_eq!(ct.c.len(), 2);
|
||||
|
||||
// Check that the relinearization by polynomials works the same way
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
|
||||
c0r.change_representation(Representation::PowerBasis);
|
||||
c0r.mod_switch_down_to(c0.ctx())?;
|
||||
c1r.change_representation(Representation::PowerBasis);
|
||||
c1r.mod_switch_down_to(c1.ctx())?;
|
||||
c0r.change_representation(Representation::Ntt);
|
||||
c1r.change_representation(Representation::Ntt);
|
||||
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?);
|
||||
// Check that the relinearization by polynomials works the same way
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
|
||||
c0r.change_representation(Representation::PowerBasis);
|
||||
c0r.mod_switch_down_to(c0.ctx())?;
|
||||
c1r.change_representation(Representation::PowerBasis);
|
||||
c1r.mod_switch_down_to(c1.ctx())?;
|
||||
c0r.change_representation(Representation::Ntt);
|
||||
c1r.change_representation(Representation::Ntt);
|
||||
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], ¶ms)?);
|
||||
|
||||
// Print the noise and decrypt
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
let pt = sk.try_decrypt(&ct)?;
|
||||
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
|
||||
assert_eq!(w, &[0u64; 8]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
// Print the noise and decrypt
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
let pt = sk.try_decrypt(&ct)?;
|
||||
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
|
||||
assert_eq!(w, &[0u64; 8]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
let proto = RelinearizationKeyProto::from(&rk);
|
||||
assert_eq!(rk, RelinearizationKey::try_convert_from(&proto, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn proto_conversion() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(3, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
let proto = RelinearizationKeyProto::from(&rk);
|
||||
assert_eq!(rk, RelinearizationKey::try_convert_from(&proto, ¶ms)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
use crate::bfv::{BfvParameters, Ciphertext, Plaintext};
|
||||
use crate::{Error, Result};
|
||||
use fhe_math::{
|
||||
rq::{traits::TryConvertFrom, Poly, Representation},
|
||||
zq::Modulus,
|
||||
rq::{traits::TryConvertFrom, Poly, Representation},
|
||||
zq::Modulus,
|
||||
};
|
||||
use fhe_traits::{FheDecrypter, FheEncrypter, FheParametrized};
|
||||
use fhe_util::sample_vec_cbd;
|
||||
@@ -18,249 +18,249 @@ use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
|
||||
/// Secret key for the BFV encryption scheme.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct SecretKey {
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
pub(crate) coeffs: Box<[i64]>,
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
pub(crate) coeffs: Box<[i64]>,
|
||||
}
|
||||
|
||||
impl Zeroize for SecretKey {
|
||||
fn zeroize(&mut self) {
|
||||
self.coeffs.zeroize();
|
||||
}
|
||||
fn zeroize(&mut self) {
|
||||
self.coeffs.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl ZeroizeOnDrop for SecretKey {}
|
||||
|
||||
impl SecretKey {
|
||||
/// Generate a random [`SecretKey`].
|
||||
pub fn random<R: RngCore + CryptoRng>(par: &Arc<BfvParameters>, rng: &mut R) -> Self {
|
||||
let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap();
|
||||
Self::new(s_coefficients, par)
|
||||
}
|
||||
/// Generate a random [`SecretKey`].
|
||||
pub fn random<R: RngCore + CryptoRng>(par: &Arc<BfvParameters>, rng: &mut R) -> Self {
|
||||
let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap();
|
||||
Self::new(s_coefficients, par)
|
||||
}
|
||||
|
||||
/// Generate a [`SecretKey`] from its coefficients.
|
||||
pub(crate) fn new(coeffs: Vec<i64>, par: &Arc<BfvParameters>) -> Self {
|
||||
Self {
|
||||
par: par.clone(),
|
||||
coeffs: coeffs.into_boxed_slice(),
|
||||
}
|
||||
}
|
||||
/// Generate a [`SecretKey`] from its coefficients.
|
||||
pub(crate) fn new(coeffs: Vec<i64>, par: &Arc<BfvParameters>) -> Self {
|
||||
Self {
|
||||
par: par.clone(),
|
||||
coeffs: coeffs.into_boxed_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Measure the noise in a [`Ciphertext`].
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This operations may run in a variable time depending on the value of the
|
||||
/// noise.
|
||||
pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result<usize> {
|
||||
let plaintext = Zeroizing::new(self.try_decrypt(ct)?);
|
||||
let m = Zeroizing::new(plaintext.to_poly());
|
||||
/// Measure the noise in a [`Ciphertext`].
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This operations may run in a variable time depending on the value of the
|
||||
/// noise.
|
||||
pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result<usize> {
|
||||
let plaintext = Zeroizing::new(self.try_decrypt(ct)?);
|
||||
let m = Zeroizing::new(plaintext.to_poly());
|
||||
|
||||
// Let's create a secret key with the ciphertext context
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
ct.c[0].ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut si = s.clone();
|
||||
// Let's create a secret key with the ciphertext context
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
ct.c[0].ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut si = s.clone();
|
||||
|
||||
// Let's disable variable time computations
|
||||
let mut c = Zeroizing::new(ct.c[0].clone());
|
||||
c.disallow_variable_time_computations();
|
||||
// Let's disable variable time computations
|
||||
let mut c = Zeroizing::new(ct.c[0].clone());
|
||||
c.disallow_variable_time_computations();
|
||||
|
||||
for i in 1..ct.c.len() {
|
||||
let mut cis = Zeroizing::new(ct.c[i].clone());
|
||||
cis.disallow_variable_time_computations();
|
||||
*cis.as_mut() *= si.as_ref();
|
||||
*c.as_mut() += &cis;
|
||||
*si.as_mut() *= s.as_ref();
|
||||
}
|
||||
*c.as_mut() -= &m;
|
||||
c.change_representation(Representation::PowerBasis);
|
||||
for i in 1..ct.c.len() {
|
||||
let mut cis = Zeroizing::new(ct.c[i].clone());
|
||||
cis.disallow_variable_time_computations();
|
||||
*cis.as_mut() *= si.as_ref();
|
||||
*c.as_mut() += &cis;
|
||||
*si.as_mut() *= s.as_ref();
|
||||
}
|
||||
*c.as_mut() -= &m;
|
||||
c.change_representation(Representation::PowerBasis);
|
||||
|
||||
let ciphertext_modulus = ct.c[0].ctx().modulus();
|
||||
let mut noise = 0usize;
|
||||
for coeff in Vec::<BigUint>::from(c.as_ref()) {
|
||||
noise = std::cmp::max(
|
||||
noise,
|
||||
std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize,
|
||||
)
|
||||
}
|
||||
let ciphertext_modulus = ct.c[0].ctx().modulus();
|
||||
let mut noise = 0usize;
|
||||
for coeff in Vec::<BigUint>::from(c.as_ref()) {
|
||||
noise = std::cmp::max(
|
||||
noise,
|
||||
std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize,
|
||||
)
|
||||
}
|
||||
|
||||
Ok(noise)
|
||||
}
|
||||
Ok(noise)
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_poly<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
p: &Poly,
|
||||
rng: &mut R,
|
||||
) -> Result<Ciphertext> {
|
||||
assert_eq!(p.representation(), &Representation::Ntt);
|
||||
pub(crate) fn encrypt_poly<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
p: &Poly,
|
||||
rng: &mut R,
|
||||
) -> Result<Ciphertext> {
|
||||
assert_eq!(p.representation(), &Representation::Ntt);
|
||||
|
||||
let level = self.par.level_of_ctx(p.ctx())?;
|
||||
let level = self.par.level_of_ctx(p.ctx())?;
|
||||
|
||||
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
|
||||
thread_rng().fill(&mut seed);
|
||||
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
|
||||
thread_rng().fill(&mut seed);
|
||||
|
||||
// Let's create a secret key with the ciphertext context
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
p.ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
// Let's create a secret key with the ciphertext context
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
p.ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
|
||||
let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed);
|
||||
let a_s = Zeroizing::new(&a * s.as_ref());
|
||||
let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed);
|
||||
let a_s = Zeroizing::new(&a * s.as_ref());
|
||||
|
||||
let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng)
|
||||
.map_err(Error::MathError)?;
|
||||
b -= &a_s;
|
||||
b += p;
|
||||
let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng)
|
||||
.map_err(Error::MathError)?;
|
||||
b -= &a_s;
|
||||
b += p;
|
||||
|
||||
// It is now safe to enable variable time computations.
|
||||
unsafe {
|
||||
a.allow_variable_time_computations();
|
||||
b.allow_variable_time_computations()
|
||||
}
|
||||
// It is now safe to enable variable time computations.
|
||||
unsafe {
|
||||
a.allow_variable_time_computations();
|
||||
b.allow_variable_time_computations()
|
||||
}
|
||||
|
||||
Ok(Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: Some(seed),
|
||||
c: vec![b, a],
|
||||
level,
|
||||
})
|
||||
}
|
||||
Ok(Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: Some(seed),
|
||||
c: vec![b, a],
|
||||
level,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FheParametrized for SecretKey {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl FheEncrypter<Plaintext, Ciphertext> for SecretKey {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
pt: &Plaintext,
|
||||
rng: &mut R,
|
||||
) -> Result<Ciphertext> {
|
||||
assert_eq!(self.par, pt.par);
|
||||
let m = Zeroizing::new(pt.to_poly());
|
||||
self.encrypt_poly(m.as_ref(), rng)
|
||||
}
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
pt: &Plaintext,
|
||||
rng: &mut R,
|
||||
) -> Result<Ciphertext> {
|
||||
assert_eq!(self.par, pt.par);
|
||||
let m = Zeroizing::new(pt.to_poly());
|
||||
self.encrypt_poly(m.as_ref(), rng)
|
||||
}
|
||||
}
|
||||
|
||||
impl FheDecrypter<Plaintext, Ciphertext> for SecretKey {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn try_decrypt(&self, ct: &Ciphertext) -> Result<Plaintext> {
|
||||
if self.par != ct.par {
|
||||
Err(Error::DefaultError(
|
||||
"Incompatible BFV parameters".to_string(),
|
||||
))
|
||||
} else {
|
||||
// Let's create a secret key with the ciphertext context
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
ct.c[0].ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut si = s.clone();
|
||||
fn try_decrypt(&self, ct: &Ciphertext) -> Result<Plaintext> {
|
||||
if self.par != ct.par {
|
||||
Err(Error::DefaultError(
|
||||
"Incompatible BFV parameters".to_string(),
|
||||
))
|
||||
} else {
|
||||
// Let's create a secret key with the ciphertext context
|
||||
let mut s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
ct.c[0].ctx(),
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
s.change_representation(Representation::Ntt);
|
||||
let mut si = s.clone();
|
||||
|
||||
let mut c = Zeroizing::new(ct.c[0].clone());
|
||||
c.disallow_variable_time_computations();
|
||||
let mut c = Zeroizing::new(ct.c[0].clone());
|
||||
c.disallow_variable_time_computations();
|
||||
|
||||
for i in 1..ct.c.len() {
|
||||
let mut cis = Zeroizing::new(ct.c[i].clone());
|
||||
cis.disallow_variable_time_computations();
|
||||
*cis.as_mut() *= si.as_ref();
|
||||
*c.as_mut() += &cis;
|
||||
*si.as_mut() *= s.as_ref();
|
||||
}
|
||||
c.change_representation(Representation::PowerBasis);
|
||||
for i in 1..ct.c.len() {
|
||||
let mut cis = Zeroizing::new(ct.c[i].clone());
|
||||
cis.disallow_variable_time_computations();
|
||||
*cis.as_mut() *= si.as_ref();
|
||||
*c.as_mut() += &cis;
|
||||
*si.as_mut() *= s.as_ref();
|
||||
}
|
||||
c.change_representation(Representation::PowerBasis);
|
||||
|
||||
let d = Zeroizing::new(c.scale(&self.par.scalers[ct.level])?);
|
||||
let d = Zeroizing::new(c.scale(&self.par.scalers[ct.level])?);
|
||||
|
||||
// TODO: Can we handle plaintext moduli that are BigUint?
|
||||
let v = Zeroizing::new(
|
||||
Vec::<u64>::from(d.as_ref())
|
||||
.iter_mut()
|
||||
.map(|vi| *vi + self.par.plaintext.modulus())
|
||||
.collect_vec(),
|
||||
);
|
||||
let mut w = v[..self.par.degree()].to_vec();
|
||||
let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?;
|
||||
q.reduce_vec(&mut w);
|
||||
self.par.plaintext.reduce_vec(&mut w);
|
||||
// TODO: Can we handle plaintext moduli that are BigUint?
|
||||
let v = Zeroizing::new(
|
||||
Vec::<u64>::from(d.as_ref())
|
||||
.iter_mut()
|
||||
.map(|vi| *vi + self.par.plaintext.modulus())
|
||||
.collect_vec(),
|
||||
);
|
||||
let mut w = v[..self.par.degree()].to_vec();
|
||||
let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?;
|
||||
q.reduce_vec(&mut w);
|
||||
self.par.plaintext.reduce_vec(&mut w);
|
||||
|
||||
let mut poly =
|
||||
Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?;
|
||||
poly.change_representation(Representation::Ntt);
|
||||
let mut poly =
|
||||
Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?;
|
||||
poly.change_representation(Representation::Ntt);
|
||||
|
||||
let pt = Plaintext {
|
||||
par: self.par.clone(),
|
||||
value: w.into_boxed_slice(),
|
||||
encoding: None,
|
||||
poly_ntt: poly,
|
||||
level: ct.level,
|
||||
};
|
||||
let pt = Plaintext {
|
||||
par: self.par.clone(),
|
||||
value: w.into_boxed_slice(),
|
||||
encoding: None,
|
||||
poly_ntt: poly,
|
||||
level: ct.level,
|
||||
};
|
||||
|
||||
Ok(pt)
|
||||
}
|
||||
}
|
||||
Ok(pt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::SecretKey;
|
||||
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext};
|
||||
use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use super::SecretKey;
|
||||
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext};
|
||||
use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn keygen() {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
assert_eq!(sk.par, params);
|
||||
#[test]
|
||||
fn keygen() {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
assert_eq!(sk.par, params);
|
||||
|
||||
sk.coeffs.iter().for_each(|ci| {
|
||||
// Check that this is a small polynomial
|
||||
assert!((*ci).abs() <= 2 * sk.par.variance as i64)
|
||||
})
|
||||
}
|
||||
sk.coeffs.iter().for_each(|ci| {
|
||||
// Check that this is a small polynomial
|
||||
assert!((*ci).abs() <= 2 * sk.par.variance as i64)
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
for level in 0..params.max_level() {
|
||||
for _ in 0..20 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
#[test]
|
||||
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
] {
|
||||
for level in 0..params.max_level() {
|
||||
for _ in 0..20 {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
|
||||
let pt = Plaintext::try_encode(
|
||||
¶ms.plaintext.random_vec(params.degree(), &mut rng),
|
||||
Encoding::poly_at_level(level),
|
||||
¶ms,
|
||||
)?;
|
||||
let ct = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let pt2 = sk.try_decrypt(&ct)?;
|
||||
let pt = Plaintext::try_encode(
|
||||
¶ms.plaintext.random_vec(params.degree(), &mut rng),
|
||||
Encoding::poly_at_level(level),
|
||||
¶ms,
|
||||
)?;
|
||||
let ct = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let pt2 = sk.try_decrypt(&ct)?;
|
||||
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
assert_eq!(pt2, pt);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
|
||||
assert_eq!(pt2, pt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,46 +5,46 @@ use itertools::{izip, Itertools};
|
||||
use ndarray::{Array, Array2};
|
||||
|
||||
use crate::{
|
||||
bfv::{Ciphertext, Plaintext},
|
||||
Error, Result,
|
||||
bfv::{Ciphertext, Plaintext},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
/// Computes the Fused-Mul-Add operation `out[i] += x[i] * y[i]`
|
||||
unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) {
|
||||
let n = out.len();
|
||||
assert_eq!(x.len(), n);
|
||||
assert_eq!(y.len(), n);
|
||||
let n = out.len();
|
||||
assert_eq!(x.len(), n);
|
||||
assert_eq!(y.len(), n);
|
||||
|
||||
macro_rules! fma_at {
|
||||
($idx:expr) => {
|
||||
*out.get_unchecked_mut($idx) +=
|
||||
(*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128);
|
||||
};
|
||||
}
|
||||
macro_rules! fma_at {
|
||||
($idx:expr) => {
|
||||
*out.get_unchecked_mut($idx) +=
|
||||
(*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128);
|
||||
};
|
||||
}
|
||||
|
||||
let r = n / 16;
|
||||
for i in 0..r {
|
||||
fma_at!(16 * i);
|
||||
fma_at!(16 * i + 1);
|
||||
fma_at!(16 * i + 2);
|
||||
fma_at!(16 * i + 3);
|
||||
fma_at!(16 * i + 4);
|
||||
fma_at!(16 * i + 5);
|
||||
fma_at!(16 * i + 6);
|
||||
fma_at!(16 * i + 7);
|
||||
fma_at!(16 * i + 8);
|
||||
fma_at!(16 * i + 9);
|
||||
fma_at!(16 * i + 10);
|
||||
fma_at!(16 * i + 11);
|
||||
fma_at!(16 * i + 12);
|
||||
fma_at!(16 * i + 13);
|
||||
fma_at!(16 * i + 14);
|
||||
fma_at!(16 * i + 15);
|
||||
}
|
||||
let r = n / 16;
|
||||
for i in 0..r {
|
||||
fma_at!(16 * i);
|
||||
fma_at!(16 * i + 1);
|
||||
fma_at!(16 * i + 2);
|
||||
fma_at!(16 * i + 3);
|
||||
fma_at!(16 * i + 4);
|
||||
fma_at!(16 * i + 5);
|
||||
fma_at!(16 * i + 6);
|
||||
fma_at!(16 * i + 7);
|
||||
fma_at!(16 * i + 8);
|
||||
fma_at!(16 * i + 9);
|
||||
fma_at!(16 * i + 10);
|
||||
fma_at!(16 * i + 11);
|
||||
fma_at!(16 * i + 12);
|
||||
fma_at!(16 * i + 13);
|
||||
fma_at!(16 * i + 14);
|
||||
fma_at!(16 * i + 15);
|
||||
}
|
||||
|
||||
for i in 0..n % 16 {
|
||||
fma_at!(16 * r + i);
|
||||
}
|
||||
for i in 0..n % 16 {
|
||||
fma_at!(16 * r + i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the dot product between an iterator of [`Ciphertext`] and an
|
||||
@@ -53,146 +53,146 @@ unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) {
|
||||
/// number of parts.
|
||||
pub fn dot_product_scalar<'a, I, J>(ct: I, pt: J) -> Result<Ciphertext>
|
||||
where
|
||||
I: Iterator<Item = &'a Ciphertext> + Clone,
|
||||
J: Iterator<Item = &'a Plaintext> + Clone,
|
||||
I: Iterator<Item = &'a Ciphertext> + Clone,
|
||||
J: Iterator<Item = &'a Plaintext> + Clone,
|
||||
{
|
||||
let count = min(ct.clone().count(), pt.clone().count());
|
||||
if count == 0 {
|
||||
return Err(Error::DefaultError(
|
||||
"At least one iterator is empty".to_string(),
|
||||
));
|
||||
}
|
||||
let ct_first = ct.clone().next().unwrap();
|
||||
let ctx = ct_first.c[0].ctx();
|
||||
let count = min(ct.clone().count(), pt.clone().count());
|
||||
if count == 0 {
|
||||
return Err(Error::DefaultError(
|
||||
"At least one iterator is empty".to_string(),
|
||||
));
|
||||
}
|
||||
let ct_first = ct.clone().next().unwrap();
|
||||
let ctx = ct_first.c[0].ctx();
|
||||
|
||||
if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| {
|
||||
cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len()
|
||||
}) {
|
||||
return Err(Error::DefaultError("Mismatched parameters".to_string()));
|
||||
}
|
||||
if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) {
|
||||
return Err(Error::DefaultError(
|
||||
"Mismatched number of parts in the ciphertexts".to_string(),
|
||||
));
|
||||
}
|
||||
if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| {
|
||||
cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len()
|
||||
}) {
|
||||
return Err(Error::DefaultError("Mismatched parameters".to_string()));
|
||||
}
|
||||
if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) {
|
||||
return Err(Error::DefaultError(
|
||||
"Mismatched number of parts in the ciphertexts".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let max_acc = ctx
|
||||
.moduli()
|
||||
.iter()
|
||||
.map(|qi| 1u128 << (2 * qi.leading_zeros()))
|
||||
.collect_vec();
|
||||
let min_of_max = max_acc.iter().min().unwrap();
|
||||
let max_acc = ctx
|
||||
.moduli()
|
||||
.iter()
|
||||
.map(|qi| 1u128 << (2 * qi.leading_zeros()))
|
||||
.collect_vec();
|
||||
let min_of_max = max_acc.iter().min().unwrap();
|
||||
|
||||
if count as u128 > *min_of_max {
|
||||
// Too many ciphertexts for the optimized method, instead, we call
|
||||
// `poly_dot_product`.
|
||||
let c = (0..ct_first.c.len())
|
||||
.map(|i| {
|
||||
poly_dot_product(
|
||||
ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }),
|
||||
pt.clone().map(|pti| &pti.poly_ntt),
|
||||
)
|
||||
.map_err(Error::MathError)
|
||||
})
|
||||
.collect::<Result<Vec<Poly>>>()?;
|
||||
if count as u128 > *min_of_max {
|
||||
// Too many ciphertexts for the optimized method, instead, we call
|
||||
// `poly_dot_product`.
|
||||
let c = (0..ct_first.c.len())
|
||||
.map(|i| {
|
||||
poly_dot_product(
|
||||
ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }),
|
||||
pt.clone().map(|pti| &pti.poly_ntt),
|
||||
)
|
||||
.map_err(Error::MathError)
|
||||
})
|
||||
.collect::<Result<Vec<Poly>>>()?;
|
||||
|
||||
Ok(Ciphertext {
|
||||
par: ct_first.par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level: ct_first.level,
|
||||
})
|
||||
} else {
|
||||
let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree()));
|
||||
for (ciphertext, plaintext) in izip!(ct, pt) {
|
||||
let pt_coefficients = plaintext.poly_ntt.coefficients();
|
||||
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) {
|
||||
let ci_coefficients = ci.coefficients();
|
||||
for (mut accij, cij, pij) in izip!(
|
||||
acci.outer_iter_mut(),
|
||||
ci_coefficients.outer_iter(),
|
||||
pt_coefficients.outer_iter()
|
||||
) {
|
||||
unsafe {
|
||||
fma(
|
||||
accij.as_slice_mut().unwrap(),
|
||||
cij.as_slice().unwrap(),
|
||||
pij.as_slice().unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Ciphertext {
|
||||
par: ct_first.par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level: ct_first.level,
|
||||
})
|
||||
} else {
|
||||
let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree()));
|
||||
for (ciphertext, plaintext) in izip!(ct, pt) {
|
||||
let pt_coefficients = plaintext.poly_ntt.coefficients();
|
||||
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) {
|
||||
let ci_coefficients = ci.coefficients();
|
||||
for (mut accij, cij, pij) in izip!(
|
||||
acci.outer_iter_mut(),
|
||||
ci_coefficients.outer_iter(),
|
||||
pt_coefficients.outer_iter()
|
||||
) {
|
||||
unsafe {
|
||||
fma(
|
||||
accij.as_slice_mut().unwrap(),
|
||||
cij.as_slice().unwrap(),
|
||||
pij.as_slice().unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce
|
||||
let mut c = Vec::with_capacity(ct_first.c.len());
|
||||
for acci in acc.outer_iter() {
|
||||
let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree()));
|
||||
for (mut outij, accij, q) in izip!(
|
||||
coeffs.outer_iter_mut(),
|
||||
acci.outer_iter(),
|
||||
ctx.moduli_operators()
|
||||
) {
|
||||
for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) {
|
||||
unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) }
|
||||
}
|
||||
}
|
||||
c.push(Poly::try_convert_from(
|
||||
coeffs,
|
||||
ctx,
|
||||
true,
|
||||
Representation::Ntt,
|
||||
)?)
|
||||
}
|
||||
// Reduce
|
||||
let mut c = Vec::with_capacity(ct_first.c.len());
|
||||
for acci in acc.outer_iter() {
|
||||
let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree()));
|
||||
for (mut outij, accij, q) in izip!(
|
||||
coeffs.outer_iter_mut(),
|
||||
acci.outer_iter(),
|
||||
ctx.moduli_operators()
|
||||
) {
|
||||
for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) {
|
||||
unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) }
|
||||
}
|
||||
}
|
||||
c.push(Poly::try_convert_from(
|
||||
coeffs,
|
||||
ctx,
|
||||
true,
|
||||
Representation::Ntt,
|
||||
)?)
|
||||
}
|
||||
|
||||
Ok(Ciphertext {
|
||||
par: ct_first.par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level: ct_first.level,
|
||||
})
|
||||
}
|
||||
Ok(Ciphertext {
|
||||
par: ct_first.par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level: ct_first.level,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::dot_product_scalar;
|
||||
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
|
||||
use fhe_traits::{FheEncoder, FheEncrypter};
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use super::dot_product_scalar;
|
||||
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
|
||||
use fhe_traits::{FheEncoder, FheEncrypter};
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn test_dot_product_scalar() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(2, 16)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
for size in 1..128 {
|
||||
let ct = (0..size)
|
||||
.map(|_| {
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms).unwrap();
|
||||
sk.try_encrypt(&pt, &mut rng).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let pt = (0..size)
|
||||
.map(|_| {
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
Plaintext::try_encode(&v, Encoding::simd(), ¶ms).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
#[test]
|
||||
fn test_dot_product_scalar() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(1, 8)),
|
||||
Arc::new(BfvParameters::default(2, 16)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
for size in 1..128 {
|
||||
let ct = (0..size)
|
||||
.map(|_| {
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms).unwrap();
|
||||
sk.try_encrypt(&pt, &mut rng).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let pt = (0..size)
|
||||
.map(|_| {
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
Plaintext::try_encode(&v, Encoding::simd(), ¶ms).unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let r = dot_product_scalar(ct.iter(), pt.iter())?;
|
||||
let r = dot_product_scalar(ct.iter(), pt.iter())?;
|
||||
|
||||
let mut expected = Ciphertext::zero(¶ms);
|
||||
izip!(&ct, &pt).for_each(|(cti, pti)| expected += &(cti * pti));
|
||||
assert_eq!(r, expected);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
let mut expected = Ciphertext::zero(¶ms);
|
||||
izip!(&ct, &pt).for_each(|(cti, pti)| expected += &(cti * pti));
|
||||
assert_eq!(r, expected);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use fhe_math::{
|
||||
rns::ScalingFactor,
|
||||
rq::{scaler::Scaler, Context, Representation},
|
||||
zq::primes::generate_prime,
|
||||
rns::ScalingFactor,
|
||||
rq::{scaler::Scaler, Context, Representation},
|
||||
zq::primes::generate_prime,
|
||||
};
|
||||
use fhe_util::div_ceil;
|
||||
use num_bigint::BigUint;
|
||||
|
||||
use crate::{
|
||||
bfv::{keys::RelinearizationKey, BfvParameters, Ciphertext},
|
||||
Error, Result,
|
||||
bfv::{keys::RelinearizationKey, BfvParameters, Ciphertext},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
/// Multiplicator that implements a strategy for multiplying. In particular, the
|
||||
@@ -22,390 +22,390 @@ use crate::{
|
||||
/// - Whether relinearization should be used.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Multiplicator {
|
||||
par: Arc<BfvParameters>,
|
||||
pub(crate) extender_lhs: Scaler,
|
||||
pub(crate) extender_rhs: Scaler,
|
||||
pub(crate) down_scaler: Scaler,
|
||||
pub(crate) base_ctx: Arc<Context>,
|
||||
pub(crate) mul_ctx: Arc<Context>,
|
||||
rk: Option<RelinearizationKey>,
|
||||
mod_switch: bool,
|
||||
level: usize,
|
||||
par: Arc<BfvParameters>,
|
||||
pub(crate) extender_lhs: Scaler,
|
||||
pub(crate) extender_rhs: Scaler,
|
||||
pub(crate) down_scaler: Scaler,
|
||||
pub(crate) base_ctx: Arc<Context>,
|
||||
pub(crate) mul_ctx: Arc<Context>,
|
||||
rk: Option<RelinearizationKey>,
|
||||
mod_switch: bool,
|
||||
level: usize,
|
||||
}
|
||||
|
||||
impl Multiplicator {
|
||||
/// Construct a multiplicator using custom scaling factors and extended
|
||||
/// basis.
|
||||
pub fn new(
|
||||
lhs_scaling_factor: ScalingFactor,
|
||||
rhs_scaling_factor: ScalingFactor,
|
||||
extended_basis: &[u64],
|
||||
post_mul_scaling_factor: ScalingFactor,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
Self::new_leveled_internal(
|
||||
lhs_scaling_factor,
|
||||
rhs_scaling_factor,
|
||||
extended_basis,
|
||||
post_mul_scaling_factor,
|
||||
0,
|
||||
par,
|
||||
)
|
||||
}
|
||||
/// Construct a multiplicator using custom scaling factors and extended
|
||||
/// basis.
|
||||
pub fn new(
|
||||
lhs_scaling_factor: ScalingFactor,
|
||||
rhs_scaling_factor: ScalingFactor,
|
||||
extended_basis: &[u64],
|
||||
post_mul_scaling_factor: ScalingFactor,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
Self::new_leveled_internal(
|
||||
lhs_scaling_factor,
|
||||
rhs_scaling_factor,
|
||||
extended_basis,
|
||||
post_mul_scaling_factor,
|
||||
0,
|
||||
par,
|
||||
)
|
||||
}
|
||||
|
||||
/// Construct a multiplicator using custom scaling factors and extended
|
||||
/// basis at a given level.
|
||||
pub fn new_leveled(
|
||||
lhs_scaling_factor: ScalingFactor,
|
||||
rhs_scaling_factor: ScalingFactor,
|
||||
extended_basis: &[u64],
|
||||
post_mul_scaling_factor: ScalingFactor,
|
||||
level: usize,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
Self::new_leveled_internal(
|
||||
lhs_scaling_factor,
|
||||
rhs_scaling_factor,
|
||||
extended_basis,
|
||||
post_mul_scaling_factor,
|
||||
level,
|
||||
par,
|
||||
)
|
||||
}
|
||||
/// Construct a multiplicator using custom scaling factors and extended
|
||||
/// basis at a given level.
|
||||
pub fn new_leveled(
|
||||
lhs_scaling_factor: ScalingFactor,
|
||||
rhs_scaling_factor: ScalingFactor,
|
||||
extended_basis: &[u64],
|
||||
post_mul_scaling_factor: ScalingFactor,
|
||||
level: usize,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
Self::new_leveled_internal(
|
||||
lhs_scaling_factor,
|
||||
rhs_scaling_factor,
|
||||
extended_basis,
|
||||
post_mul_scaling_factor,
|
||||
level,
|
||||
par,
|
||||
)
|
||||
}
|
||||
|
||||
fn new_leveled_internal(
|
||||
lhs_scaling_factor: ScalingFactor,
|
||||
rhs_scaling_factor: ScalingFactor,
|
||||
extended_basis: &[u64],
|
||||
post_mul_scaling_factor: ScalingFactor,
|
||||
level: usize,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
let base_ctx = par.ctx_at_level(level)?;
|
||||
let mul_ctx = Arc::new(Context::new(extended_basis, par.degree())?);
|
||||
let extender_lhs = Scaler::new(base_ctx, &mul_ctx, lhs_scaling_factor)?;
|
||||
let extender_rhs = Scaler::new(base_ctx, &mul_ctx, rhs_scaling_factor)?;
|
||||
let down_scaler = Scaler::new(&mul_ctx, base_ctx, post_mul_scaling_factor)?;
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
extender_lhs,
|
||||
extender_rhs,
|
||||
down_scaler,
|
||||
base_ctx: base_ctx.clone(),
|
||||
mul_ctx,
|
||||
rk: None,
|
||||
mod_switch: false,
|
||||
level,
|
||||
})
|
||||
}
|
||||
fn new_leveled_internal(
|
||||
lhs_scaling_factor: ScalingFactor,
|
||||
rhs_scaling_factor: ScalingFactor,
|
||||
extended_basis: &[u64],
|
||||
post_mul_scaling_factor: ScalingFactor,
|
||||
level: usize,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
let base_ctx = par.ctx_at_level(level)?;
|
||||
let mul_ctx = Arc::new(Context::new(extended_basis, par.degree())?);
|
||||
let extender_lhs = Scaler::new(base_ctx, &mul_ctx, lhs_scaling_factor)?;
|
||||
let extender_rhs = Scaler::new(base_ctx, &mul_ctx, rhs_scaling_factor)?;
|
||||
let down_scaler = Scaler::new(&mul_ctx, base_ctx, post_mul_scaling_factor)?;
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
extender_lhs,
|
||||
extender_rhs,
|
||||
down_scaler,
|
||||
base_ctx: base_ctx.clone(),
|
||||
mul_ctx,
|
||||
rk: None,
|
||||
mod_switch: false,
|
||||
level,
|
||||
})
|
||||
}
|
||||
|
||||
/// Default multiplication strategy using relinearization.
|
||||
pub fn default(rk: &RelinearizationKey) -> Result<Self> {
|
||||
let ctx = rk.ksk.par.ctx_at_level(rk.ksk.ciphertext_level)?;
|
||||
/// Default multiplication strategy using relinearization.
|
||||
pub fn default(rk: &RelinearizationKey) -> Result<Self> {
|
||||
let ctx = rk.ksk.par.ctx_at_level(rk.ksk.ciphertext_level)?;
|
||||
|
||||
let modulus_size = rk.ksk.par.moduli_sizes()[..ctx.moduli().len()]
|
||||
.iter()
|
||||
.sum::<usize>();
|
||||
let n_moduli = div_ceil(modulus_size + 60, 62);
|
||||
let modulus_size = rk.ksk.par.moduli_sizes()[..ctx.moduli().len()]
|
||||
.iter()
|
||||
.sum::<usize>();
|
||||
let n_moduli = div_ceil(modulus_size + 60, 62);
|
||||
|
||||
let mut extended_basis = Vec::with_capacity(ctx.moduli().len() + n_moduli);
|
||||
extended_basis.append(&mut ctx.moduli().to_vec());
|
||||
let mut upper_bound = 1 << 62;
|
||||
while extended_basis.len() != ctx.moduli().len() + n_moduli {
|
||||
upper_bound = generate_prime(62, 2 * rk.ksk.par.degree() as u64, upper_bound).unwrap();
|
||||
if !extended_basis.contains(&upper_bound) && !ctx.moduli().contains(&upper_bound) {
|
||||
extended_basis.push(upper_bound)
|
||||
}
|
||||
}
|
||||
let mut extended_basis = Vec::with_capacity(ctx.moduli().len() + n_moduli);
|
||||
extended_basis.append(&mut ctx.moduli().to_vec());
|
||||
let mut upper_bound = 1 << 62;
|
||||
while extended_basis.len() != ctx.moduli().len() + n_moduli {
|
||||
upper_bound = generate_prime(62, 2 * rk.ksk.par.degree() as u64, upper_bound).unwrap();
|
||||
if !extended_basis.contains(&upper_bound) && !ctx.moduli().contains(&upper_bound) {
|
||||
extended_basis.push(upper_bound)
|
||||
}
|
||||
}
|
||||
|
||||
let mut multiplicator = Self::new_leveled_internal(
|
||||
ScalingFactor::one(),
|
||||
ScalingFactor::one(),
|
||||
&extended_basis,
|
||||
ScalingFactor::new(
|
||||
&BigUint::from(rk.ksk.par.plaintext.modulus()),
|
||||
ctx.modulus(),
|
||||
),
|
||||
rk.ksk.ciphertext_level,
|
||||
&rk.ksk.par,
|
||||
)?;
|
||||
let mut multiplicator = Self::new_leveled_internal(
|
||||
ScalingFactor::one(),
|
||||
ScalingFactor::one(),
|
||||
&extended_basis,
|
||||
ScalingFactor::new(
|
||||
&BigUint::from(rk.ksk.par.plaintext.modulus()),
|
||||
ctx.modulus(),
|
||||
),
|
||||
rk.ksk.ciphertext_level,
|
||||
&rk.ksk.par,
|
||||
)?;
|
||||
|
||||
multiplicator.enable_relinearization(rk)?;
|
||||
Ok(multiplicator)
|
||||
}
|
||||
multiplicator.enable_relinearization(rk)?;
|
||||
Ok(multiplicator)
|
||||
}
|
||||
|
||||
/// Enable relinearization after multiplication.
|
||||
pub fn enable_relinearization(&mut self, rk: &RelinearizationKey) -> Result<()> {
|
||||
let rk_ctx = self.par.ctx_at_level(rk.ksk.ciphertext_level)?;
|
||||
if rk_ctx != &self.base_ctx {
|
||||
return Err(Error::DefaultError(
|
||||
"Invalid relinearization key context".to_string(),
|
||||
));
|
||||
}
|
||||
self.rk = Some(rk.clone());
|
||||
Ok(())
|
||||
}
|
||||
/// Enable relinearization after multiplication.
|
||||
pub fn enable_relinearization(&mut self, rk: &RelinearizationKey) -> Result<()> {
|
||||
let rk_ctx = self.par.ctx_at_level(rk.ksk.ciphertext_level)?;
|
||||
if rk_ctx != &self.base_ctx {
|
||||
return Err(Error::DefaultError(
|
||||
"Invalid relinearization key context".to_string(),
|
||||
));
|
||||
}
|
||||
self.rk = Some(rk.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enable modulus switching after multiplication (and relinearization, if
|
||||
/// applicable).
|
||||
pub fn enable_mod_switching(&mut self) -> Result<()> {
|
||||
if self.par.ctx_at_level(self.par.max_level())? == &self.base_ctx {
|
||||
Err(Error::DefaultError(
|
||||
"Cannot modulo switch as this is already the last level".to_string(),
|
||||
))
|
||||
} else {
|
||||
self.mod_switch = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
/// Enable modulus switching after multiplication (and relinearization, if
|
||||
/// applicable).
|
||||
pub fn enable_mod_switching(&mut self) -> Result<()> {
|
||||
if self.par.ctx_at_level(self.par.max_level())? == &self.base_ctx {
|
||||
Err(Error::DefaultError(
|
||||
"Cannot modulo switch as this is already the last level".to_string(),
|
||||
))
|
||||
} else {
|
||||
self.mod_switch = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply two ciphertexts using the defined multiplication strategy.
|
||||
pub fn multiply(&self, lhs: &Ciphertext, rhs: &Ciphertext) -> Result<Ciphertext> {
|
||||
if lhs.par != self.par || rhs.par != self.par {
|
||||
return Err(Error::DefaultError(
|
||||
"Ciphertexts do not have the same parameters".to_string(),
|
||||
));
|
||||
}
|
||||
if lhs.level != self.level || rhs.level != self.level {
|
||||
return Err(Error::DefaultError(
|
||||
"Ciphertexts are not at expected level".to_string(),
|
||||
));
|
||||
}
|
||||
if lhs.c.len() != 2 || rhs.c.len() != 2 {
|
||||
return Err(Error::DefaultError(
|
||||
"Multiplication can only be performed on ciphertexts of size 2".to_string(),
|
||||
));
|
||||
}
|
||||
/// Multiply two ciphertexts using the defined multiplication strategy.
|
||||
pub fn multiply(&self, lhs: &Ciphertext, rhs: &Ciphertext) -> Result<Ciphertext> {
|
||||
if lhs.par != self.par || rhs.par != self.par {
|
||||
return Err(Error::DefaultError(
|
||||
"Ciphertexts do not have the same parameters".to_string(),
|
||||
));
|
||||
}
|
||||
if lhs.level != self.level || rhs.level != self.level {
|
||||
return Err(Error::DefaultError(
|
||||
"Ciphertexts are not at expected level".to_string(),
|
||||
));
|
||||
}
|
||||
if lhs.c.len() != 2 || rhs.c.len() != 2 {
|
||||
return Err(Error::DefaultError(
|
||||
"Multiplication can only be performed on ciphertexts of size 2".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Extend
|
||||
let c00 = lhs.c[0].scale(&self.extender_lhs)?;
|
||||
let c01 = lhs.c[1].scale(&self.extender_lhs)?;
|
||||
let c10 = rhs.c[0].scale(&self.extender_rhs)?;
|
||||
let c11 = rhs.c[1].scale(&self.extender_rhs)?;
|
||||
// Extend
|
||||
let c00 = lhs.c[0].scale(&self.extender_lhs)?;
|
||||
let c01 = lhs.c[1].scale(&self.extender_lhs)?;
|
||||
let c10 = rhs.c[0].scale(&self.extender_rhs)?;
|
||||
let c11 = rhs.c[1].scale(&self.extender_rhs)?;
|
||||
|
||||
// Multiply
|
||||
let mut c0 = &c00 * &c10;
|
||||
let mut c1 = &c00 * &c11;
|
||||
c1 += &(&c01 * &c10);
|
||||
let mut c2 = &c01 * &c11;
|
||||
c0.change_representation(Representation::PowerBasis);
|
||||
c1.change_representation(Representation::PowerBasis);
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
// Multiply
|
||||
let mut c0 = &c00 * &c10;
|
||||
let mut c1 = &c00 * &c11;
|
||||
c1 += &(&c01 * &c10);
|
||||
let mut c2 = &c01 * &c11;
|
||||
c0.change_representation(Representation::PowerBasis);
|
||||
c1.change_representation(Representation::PowerBasis);
|
||||
c2.change_representation(Representation::PowerBasis);
|
||||
|
||||
// Scale
|
||||
let c0 = c0.scale(&self.down_scaler)?;
|
||||
let c1 = c1.scale(&self.down_scaler)?;
|
||||
let c2 = c2.scale(&self.down_scaler)?;
|
||||
// Scale
|
||||
let c0 = c0.scale(&self.down_scaler)?;
|
||||
let c1 = c1.scale(&self.down_scaler)?;
|
||||
let c2 = c2.scale(&self.down_scaler)?;
|
||||
|
||||
let mut c = vec![c0, c1, c2];
|
||||
let mut c = vec![c0, c1, c2];
|
||||
|
||||
// Relinearize
|
||||
if let Some(rk) = self.rk.as_ref() {
|
||||
#[allow(unused_mut)]
|
||||
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?;
|
||||
// Relinearize
|
||||
if let Some(rk) = self.rk.as_ref() {
|
||||
#[allow(unused_mut)]
|
||||
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?;
|
||||
|
||||
if c0r.ctx() != c[0].ctx() {
|
||||
c0r.change_representation(Representation::PowerBasis);
|
||||
c1r.change_representation(Representation::PowerBasis);
|
||||
c0r.mod_switch_down_to(c[0].ctx())?;
|
||||
c1r.mod_switch_down_to(c[1].ctx())?;
|
||||
} else {
|
||||
c[0].change_representation(Representation::Ntt);
|
||||
c[1].change_representation(Representation::Ntt);
|
||||
}
|
||||
if c0r.ctx() != c[0].ctx() {
|
||||
c0r.change_representation(Representation::PowerBasis);
|
||||
c1r.change_representation(Representation::PowerBasis);
|
||||
c0r.mod_switch_down_to(c[0].ctx())?;
|
||||
c1r.mod_switch_down_to(c[1].ctx())?;
|
||||
} else {
|
||||
c[0].change_representation(Representation::Ntt);
|
||||
c[1].change_representation(Representation::Ntt);
|
||||
}
|
||||
|
||||
c[0] += &c0r;
|
||||
c[1] += &c1r;
|
||||
c.truncate(2);
|
||||
}
|
||||
c[0] += &c0r;
|
||||
c[1] += &c1r;
|
||||
c.truncate(2);
|
||||
}
|
||||
|
||||
// We construct a ciphertext, but it may not have the right representation for
|
||||
// the polynomials yet.
|
||||
let mut c = Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level: self.level,
|
||||
};
|
||||
// We construct a ciphertext, but it may not have the right representation for
|
||||
// the polynomials yet.
|
||||
let mut c = Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: None,
|
||||
c,
|
||||
level: self.level,
|
||||
};
|
||||
|
||||
if self.mod_switch {
|
||||
c.mod_switch_to_next_level();
|
||||
} else {
|
||||
c.c.iter_mut()
|
||||
.for_each(|p| p.change_representation(Representation::Ntt));
|
||||
}
|
||||
if self.mod_switch {
|
||||
c.mod_switch_to_next_level();
|
||||
} else {
|
||||
c.c.iter_mut()
|
||||
.for_each(|p| p.change_representation(Representation::Ntt));
|
||||
}
|
||||
|
||||
Ok(c)
|
||||
}
|
||||
Ok(c)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bfv::{
|
||||
BfvParameters, Ciphertext, Encoding, Plaintext, RelinearizationKey, SecretKey,
|
||||
};
|
||||
use fhe_math::{
|
||||
rns::{RnsContext, ScalingFactor},
|
||||
zq::primes::generate_prime,
|
||||
};
|
||||
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
|
||||
use num_bigint::BigUint;
|
||||
use rand::{rngs::OsRng, thread_rng};
|
||||
use std::{error::Error, sync::Arc};
|
||||
use crate::bfv::{
|
||||
BfvParameters, Ciphertext, Encoding, Plaintext, RelinearizationKey, SecretKey,
|
||||
};
|
||||
use fhe_math::{
|
||||
rns::{RnsContext, ScalingFactor},
|
||||
zq::primes::generate_prime,
|
||||
};
|
||||
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
|
||||
use num_bigint::BigUint;
|
||||
use rand::{rngs::OsRng, thread_rng};
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
use super::Multiplicator;
|
||||
use super::Multiplicator;
|
||||
|
||||
#[test]
|
||||
fn mul() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(3, 8));
|
||||
for _ in 0..30 {
|
||||
// We will encode `values` in an Simd format, and check that the product is
|
||||
// computed correctly.
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
#[test]
|
||||
fn mul() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(3, 8));
|
||||
for _ in 0..30 {
|
||||
// We will encode `values` in an Simd format, and check that the product is
|
||||
// computed correctly.
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
|
||||
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
|
||||
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
|
||||
let mut multiplicator = Multiplicator::default(&rk)?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
let mut multiplicator = Multiplicator::default(&rk)?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_at_level() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(3, 8));
|
||||
for _ in 0..15 {
|
||||
for level in 0..2 {
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
#[test]
|
||||
fn mul_at_level() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(3, 8));
|
||||
for _ in 0..15 {
|
||||
for level in 0..2 {
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd_at_level(level), &par)?;
|
||||
let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
assert_eq!(ct1.level, level);
|
||||
assert_eq!(ct2.level, level);
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd_at_level(level), &par)?;
|
||||
let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
assert_eq!(ct1.level, level);
|
||||
assert_eq!(ct2.level, level);
|
||||
|
||||
let mut multiplicator = Multiplicator::default(&rk).unwrap();
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2).unwrap();
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
let mut multiplicator = Multiplicator::default(&rk).unwrap();
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2).unwrap();
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, level + 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, level + 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_no_relin() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(6, 8));
|
||||
for _ in 0..30 {
|
||||
// We will encode `values` in an Simd format, and check that the product is
|
||||
// computed correctly.
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
#[test]
|
||||
fn mul_no_relin() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(6, 8));
|
||||
for _ in 0..30 {
|
||||
// We will encode `values` in an Simd format, and check that the product is
|
||||
// computed correctly.
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
|
||||
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let rk = RelinearizationKey::new(&sk, &mut rng)?;
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
|
||||
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
|
||||
let mut multiplicator = Multiplicator::default(&rk)?;
|
||||
// Remove the relinearization key.
|
||||
multiplicator.rk = None;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
let mut multiplicator = Multiplicator::default(&rk)?;
|
||||
// Remove the relinearization key.
|
||||
multiplicator.rk = None;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_mul_strategy() -> Result<(), Box<dyn Error>> {
|
||||
// Implement the second multiplication strategy from <https://eprint.iacr.org/2021/204>
|
||||
#[test]
|
||||
fn different_mul_strategy() -> Result<(), Box<dyn Error>> {
|
||||
// Implement the second multiplication strategy from <https://eprint.iacr.org/2021/204>
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(3, 8));
|
||||
let mut extended_basis = par.moduli().to_vec();
|
||||
extended_basis
|
||||
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap());
|
||||
extended_basis
|
||||
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[3]).unwrap());
|
||||
extended_basis
|
||||
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[4]).unwrap());
|
||||
let rns = RnsContext::new(&extended_basis[3..])?;
|
||||
let mut rng = thread_rng();
|
||||
let par = Arc::new(BfvParameters::default(3, 8));
|
||||
let mut extended_basis = par.moduli().to_vec();
|
||||
extended_basis
|
||||
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap());
|
||||
extended_basis
|
||||
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[3]).unwrap());
|
||||
extended_basis
|
||||
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[4]).unwrap());
|
||||
let rns = RnsContext::new(&extended_basis[3..])?;
|
||||
|
||||
for _ in 0..30 {
|
||||
// We will encode `values` in an Simd format, and check that the product is
|
||||
// computed correctly.
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
for _ in 0..30 {
|
||||
// We will encode `values` in an Simd format, and check that the product is
|
||||
// computed correctly.
|
||||
let values = par.plaintext.random_vec(par.degree(), &mut rng);
|
||||
let mut expected = values.clone();
|
||||
par.plaintext.mul_vec(&mut expected, &values);
|
||||
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
|
||||
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let sk = SecretKey::random(&par, &mut OsRng);
|
||||
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
|
||||
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
|
||||
|
||||
let mut multiplicator = Multiplicator::new(
|
||||
ScalingFactor::one(),
|
||||
ScalingFactor::new(rns.modulus(), par.ctx[0].modulus()),
|
||||
&extended_basis,
|
||||
ScalingFactor::new(&BigUint::from(par.plaintext()), rns.modulus()),
|
||||
&par,
|
||||
)?;
|
||||
let mut multiplicator = Multiplicator::new(
|
||||
ScalingFactor::one(),
|
||||
ScalingFactor::new(rns.modulus(), par.ctx[0].modulus()),
|
||||
&extended_basis,
|
||||
ScalingFactor::new(&BigUint::from(par.plaintext()), rns.modulus()),
|
||||
&par,
|
||||
)?;
|
||||
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
multiplicator.enable_mod_switching()?;
|
||||
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
|
||||
assert_eq!(ct3.level, 1);
|
||||
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
|
||||
let pt = sk.try_decrypt(&ct3)?;
|
||||
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
//! Plaintext type in the BFV encryption scheme.
|
||||
use crate::{
|
||||
bfv::{BfvParameters, Encoding, PlaintextVec},
|
||||
Error, Result,
|
||||
bfv::{BfvParameters, Encoding, PlaintextVec},
|
||||
Error, Result,
|
||||
};
|
||||
use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation};
|
||||
use fhe_traits::{FheDecoder, FheEncoder, FheParametrized, FhePlaintext};
|
||||
@@ -13,69 +13,69 @@ use super::encoding::EncodingEnum;
|
||||
/// A plaintext object, that encodes a vector according to a specific encoding.
|
||||
#[derive(Debug, Clone, Eq)]
|
||||
pub struct Plaintext {
|
||||
/// The parameters of the underlying BFV encryption scheme.
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
/// The value after encoding.
|
||||
pub(crate) value: Box<[u64]>,
|
||||
/// The encoding of the plaintext, if known
|
||||
pub(crate) encoding: Option<Encoding>,
|
||||
/// The plaintext as a polynomial.
|
||||
pub(crate) poly_ntt: Poly,
|
||||
/// The level of the plaintext
|
||||
pub(crate) level: usize,
|
||||
/// The parameters of the underlying BFV encryption scheme.
|
||||
pub(crate) par: Arc<BfvParameters>,
|
||||
/// The value after encoding.
|
||||
pub(crate) value: Box<[u64]>,
|
||||
/// The encoding of the plaintext, if known
|
||||
pub(crate) encoding: Option<Encoding>,
|
||||
/// The plaintext as a polynomial.
|
||||
pub(crate) poly_ntt: Poly,
|
||||
/// The level of the plaintext
|
||||
pub(crate) level: usize,
|
||||
}
|
||||
|
||||
impl FheParametrized for Plaintext {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl FhePlaintext for Plaintext {
|
||||
type Encoding = Encoding;
|
||||
type Encoding = Encoding;
|
||||
}
|
||||
|
||||
// Zeroizing of plaintexts.
|
||||
impl ZeroizeOnDrop for Plaintext {}
|
||||
|
||||
impl Zeroize for Plaintext {
|
||||
fn zeroize(&mut self) {
|
||||
self.value.zeroize();
|
||||
self.poly_ntt.zeroize();
|
||||
}
|
||||
fn zeroize(&mut self) {
|
||||
self.value.zeroize();
|
||||
self.poly_ntt.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl Plaintext {
|
||||
pub(crate) fn to_poly(&self) -> Poly {
|
||||
let mut m_v = Zeroizing::new(self.value.clone());
|
||||
self.par
|
||||
.plaintext
|
||||
.scalar_mul_vec(&mut m_v, self.par.q_mod_t[self.level]);
|
||||
let ctx = self.par.ctx_at_level(self.level).unwrap();
|
||||
let mut m =
|
||||
Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap();
|
||||
m.change_representation(Representation::Ntt);
|
||||
m *= &self.par.delta[self.level];
|
||||
m
|
||||
}
|
||||
pub(crate) fn to_poly(&self) -> Poly {
|
||||
let mut m_v = Zeroizing::new(self.value.clone());
|
||||
self.par
|
||||
.plaintext
|
||||
.scalar_mul_vec(&mut m_v, self.par.q_mod_t[self.level]);
|
||||
let ctx = self.par.ctx_at_level(self.level).unwrap();
|
||||
let mut m =
|
||||
Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap();
|
||||
m.change_representation(Representation::Ntt);
|
||||
m *= &self.par.delta[self.level];
|
||||
m
|
||||
}
|
||||
|
||||
/// Generate a zero plaintext.
|
||||
pub fn zero(encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
let level = encoding.level;
|
||||
let ctx = par.ctx_at_level(level)?;
|
||||
let value = vec![0u64; par.degree()];
|
||||
let poly_ntt = Poly::zero(ctx, Representation::Ntt);
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
value: value.into_boxed_slice(),
|
||||
encoding: Some(encoding),
|
||||
poly_ntt,
|
||||
level,
|
||||
})
|
||||
}
|
||||
/// Generate a zero plaintext.
|
||||
pub fn zero(encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
let level = encoding.level;
|
||||
let ctx = par.ctx_at_level(level)?;
|
||||
let value = vec![0u64; par.degree()];
|
||||
let poly_ntt = Poly::zero(ctx, Representation::Ntt);
|
||||
Ok(Self {
|
||||
par: par.clone(),
|
||||
value: value.into_boxed_slice(),
|
||||
encoding: Some(encoding),
|
||||
poly_ntt,
|
||||
level,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the level of this plaintext.
|
||||
pub fn level(&self) -> usize {
|
||||
self.par.level_of_ctx(self.poly_ntt.ctx()).unwrap()
|
||||
}
|
||||
/// Returns the level of this plaintext.
|
||||
pub fn level(&self) -> usize {
|
||||
self.par.level_of_ctx(self.poly_ntt.ctx()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for Plaintext {}
|
||||
@@ -83,320 +83,320 @@ unsafe impl Send for Plaintext {}
|
||||
// Implement the equality manually; we want to say that two plaintexts are equal
|
||||
// even if one of them doesn't store its encoding information.
|
||||
impl PartialEq for Plaintext {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let mut eq = self.par == other.par;
|
||||
eq &= self.value == other.value;
|
||||
if self.encoding.is_some() && other.encoding.is_some() {
|
||||
eq &= self.encoding.as_ref().unwrap() == other.encoding.as_ref().unwrap()
|
||||
}
|
||||
eq
|
||||
}
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let mut eq = self.par == other.par;
|
||||
eq &= self.value == other.value;
|
||||
if self.encoding.is_some() && other.encoding.is_some() {
|
||||
eq &= self.encoding.as_ref().unwrap() == other.encoding.as_ref().unwrap()
|
||||
}
|
||||
eq
|
||||
}
|
||||
}
|
||||
|
||||
// Conversions.
|
||||
impl TryConvertFrom<&Plaintext> for Poly {
|
||||
fn try_convert_from<R>(
|
||||
pt: &Plaintext,
|
||||
ctx: &Arc<Context>,
|
||||
variable_time: bool,
|
||||
_: R,
|
||||
) -> fhe_math::Result<Self>
|
||||
where
|
||||
R: Into<Option<Representation>>,
|
||||
{
|
||||
if ctx
|
||||
!= pt
|
||||
.par
|
||||
.ctx_at_level(pt.level())
|
||||
.map_err(|e| fhe_math::Error::Default(e.to_string()))?
|
||||
{
|
||||
Err(fhe_math::Error::Default(
|
||||
"Incompatible contexts".to_string(),
|
||||
))
|
||||
} else {
|
||||
Poly::try_convert_from(
|
||||
pt.value.as_ref(),
|
||||
ctx,
|
||||
variable_time,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
}
|
||||
}
|
||||
fn try_convert_from<R>(
|
||||
pt: &Plaintext,
|
||||
ctx: &Arc<Context>,
|
||||
variable_time: bool,
|
||||
_: R,
|
||||
) -> fhe_math::Result<Self>
|
||||
where
|
||||
R: Into<Option<Representation>>,
|
||||
{
|
||||
if ctx
|
||||
!= pt
|
||||
.par
|
||||
.ctx_at_level(pt.level())
|
||||
.map_err(|e| fhe_math::Error::Default(e.to_string()))?
|
||||
{
|
||||
Err(fhe_math::Error::Default(
|
||||
"Incompatible contexts".to_string(),
|
||||
))
|
||||
} else {
|
||||
Poly::try_convert_from(
|
||||
pt.value.as_ref(),
|
||||
ctx,
|
||||
variable_time,
|
||||
Representation::PowerBasis,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encoding and decoding.
|
||||
|
||||
impl<'a, const N: usize, T> FheEncoder<&'a [T; N]> for Plaintext
|
||||
where
|
||||
Plaintext: FheEncoder<&'a [T], Error = Error>,
|
||||
Plaintext: FheEncoder<&'a [T], Error = Error>,
|
||||
{
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a [T; N], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
Plaintext::try_encode(value.as_ref(), encoding, par)
|
||||
}
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a [T; N], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
Plaintext::try_encode(value.as_ref(), encoding, par)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> FheEncoder<&'a Vec<T>> for Plaintext
|
||||
where
|
||||
Plaintext: FheEncoder<&'a [T], Error = Error>,
|
||||
Plaintext: FheEncoder<&'a [T], Error = Error>,
|
||||
{
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a Vec<T>, encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
Plaintext::try_encode(value.as_ref(), encoding, par)
|
||||
}
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a Vec<T>, encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
Plaintext::try_encode(value.as_ref(), encoding, par)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> FheEncoder<&'a [u64]> for Plaintext {
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if value.len() > par.degree() {
|
||||
return Err(Error::TooManyValues(value.len(), par.degree()));
|
||||
}
|
||||
let v = PlaintextVec::try_encode(value, encoding, par)?;
|
||||
Ok(v.0[0].clone())
|
||||
}
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if value.len() > par.degree() {
|
||||
return Err(Error::TooManyValues(value.len(), par.degree()));
|
||||
}
|
||||
let v = PlaintextVec::try_encode(value, encoding, par)?;
|
||||
Ok(v.0[0].clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> FheEncoder<&'a [i64]> for Plaintext {
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value));
|
||||
Plaintext::try_encode(w.as_ref() as &[u64], encoding, par)
|
||||
}
|
||||
type Error = Error;
|
||||
fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value));
|
||||
Plaintext::try_encode(w.as_ref() as &[u64], encoding, par)
|
||||
}
|
||||
}
|
||||
|
||||
impl FheDecoder<Plaintext> for Vec<u64> {
|
||||
fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>>
|
||||
where
|
||||
O: Into<Option<Encoding>>,
|
||||
{
|
||||
let encoding = encoding.into();
|
||||
let enc: Encoding;
|
||||
if pt.encoding.is_none() && encoding.is_none() {
|
||||
return Err(Error::UnspecifiedInput("No encoding specified".to_string()));
|
||||
} else if pt.encoding.is_some() {
|
||||
enc = pt.encoding.as_ref().unwrap().clone();
|
||||
if let Some(arg_enc) = encoding {
|
||||
if arg_enc != enc {
|
||||
return Err(Error::EncodingMismatch(arg_enc.into(), enc.into()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
enc = encoding.unwrap();
|
||||
if let Some(pt_enc) = pt.encoding.as_ref() {
|
||||
if pt_enc != &enc {
|
||||
return Err(Error::EncodingMismatch(pt_enc.into(), enc.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>>
|
||||
where
|
||||
O: Into<Option<Encoding>>,
|
||||
{
|
||||
let encoding = encoding.into();
|
||||
let enc: Encoding;
|
||||
if pt.encoding.is_none() && encoding.is_none() {
|
||||
return Err(Error::UnspecifiedInput("No encoding specified".to_string()));
|
||||
} else if pt.encoding.is_some() {
|
||||
enc = pt.encoding.as_ref().unwrap().clone();
|
||||
if let Some(arg_enc) = encoding {
|
||||
if arg_enc != enc {
|
||||
return Err(Error::EncodingMismatch(arg_enc.into(), enc.into()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
enc = encoding.unwrap();
|
||||
if let Some(pt_enc) = pt.encoding.as_ref() {
|
||||
if pt_enc != &enc {
|
||||
return Err(Error::EncodingMismatch(pt_enc.into(), enc.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut w = pt.value.to_vec();
|
||||
let mut w = pt.value.to_vec();
|
||||
|
||||
match enc.encoding {
|
||||
EncodingEnum::Poly => Ok(w),
|
||||
EncodingEnum::Simd => {
|
||||
if let Some(op) = &pt.par.op {
|
||||
op.forward(&mut w);
|
||||
let mut w_reordered = w.clone();
|
||||
for i in 0..pt.par.degree() {
|
||||
w_reordered[i] = w[pt.par.matrix_reps_index_map[i]]
|
||||
}
|
||||
w.zeroize();
|
||||
Ok(w_reordered)
|
||||
} else {
|
||||
Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
match enc.encoding {
|
||||
EncodingEnum::Poly => Ok(w),
|
||||
EncodingEnum::Simd => {
|
||||
if let Some(op) = &pt.par.op {
|
||||
op.forward(&mut w);
|
||||
let mut w_reordered = w.clone();
|
||||
for i in 0..pt.par.degree() {
|
||||
w_reordered[i] = w[pt.par.matrix_reps_index_map[i]]
|
||||
}
|
||||
w.zeroize();
|
||||
Ok(w_reordered)
|
||||
} else {
|
||||
Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
}
|
||||
|
||||
impl FheDecoder<Plaintext> for Vec<i64> {
|
||||
fn try_decode<E>(pt: &Plaintext, encoding: E) -> Result<Vec<i64>>
|
||||
where
|
||||
E: Into<Option<Encoding>>,
|
||||
{
|
||||
let v = Vec::<u64>::try_decode(pt, encoding)?;
|
||||
Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) })
|
||||
}
|
||||
fn try_decode<E>(pt: &Plaintext, encoding: E) -> Result<Vec<i64>>
|
||||
where
|
||||
E: Into<Option<Encoding>>,
|
||||
{
|
||||
let v = Vec::<u64>::try_decode(pt, encoding)?;
|
||||
Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) })
|
||||
}
|
||||
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Encoding, Plaintext};
|
||||
use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder};
|
||||
use fhe_math::rq::{Poly, Representation};
|
||||
use fhe_traits::{FheDecoder, FheEncoder};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use zeroize::Zeroize;
|
||||
use super::{Encoding, Plaintext};
|
||||
use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder};
|
||||
use fhe_math::rq::{Poly, Representation};
|
||||
use fhe_traits::{FheDecoder, FheEncoder};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use zeroize::Zeroize;
|
||||
|
||||
#[test]
|
||||
fn try_encode() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
// The default test parameters support both Poly and Simd encodings
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
#[test]
|
||||
fn try_encode() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
// The default test parameters support both Poly and Simd encodings
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_err());
|
||||
let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_err());
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
|
||||
let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
|
||||
// The following parameters do not allow for Simd encoding
|
||||
let params = Arc::new(
|
||||
BfvParametersBuilder::new()
|
||||
.set_degree(8)
|
||||
.set_plaintext_modulus(2)
|
||||
.set_moduli(&[4611686018326724609])
|
||||
.build()?,
|
||||
);
|
||||
// The following parameters do not allow for Simd encoding
|
||||
let params = Arc::new(
|
||||
BfvParametersBuilder::new()
|
||||
.set_degree(8)
|
||||
.set_plaintext_modulus(2)
|
||||
.set_moduli(&[4611686018326724609])
|
||||
.build()?,
|
||||
);
|
||||
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_err());
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_err());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_decode() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
#[test]
|
||||
fn encode_decode() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
|
||||
assert_eq!(b, a);
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
|
||||
assert_eq!(b, a);
|
||||
|
||||
let a = unsafe { params.plaintext.center_vec_vt(&a) };
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?;
|
||||
assert_eq!(b, a);
|
||||
let a = unsafe { params.plaintext.center_vec_vt(&a) };
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?;
|
||||
assert_eq!(b, a);
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
|
||||
assert_eq!(b, a);
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), ¶ms);
|
||||
assert!(plaintext.is_ok());
|
||||
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
|
||||
assert_eq!(b, a);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn partial_eq() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
#[test]
|
||||
fn partial_eq() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
assert_eq!(plaintext, same_plaintext);
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
assert_eq!(plaintext, same_plaintext);
|
||||
|
||||
// Equality also holds when there is no encoding specified. In this test, we use
|
||||
// the fact that we can set it to None directly, but such a partial plaintext
|
||||
// will be created during decryption since we do not specify the encoding at the
|
||||
// time.
|
||||
same_plaintext.encoding = None;
|
||||
assert_eq!(plaintext, same_plaintext);
|
||||
// Equality also holds when there is no encoding specified. In this test, we use
|
||||
// the fact that we can set it to None directly, but such a partial plaintext
|
||||
// will be created during decryption since we do not specify the encoding at the
|
||||
// time.
|
||||
same_plaintext.encoding = None;
|
||||
assert_eq!(plaintext, same_plaintext);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_decode_errors() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
#[test]
|
||||
fn try_decode_errors() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
|
||||
assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok());
|
||||
let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd());
|
||||
assert!(e.is_err());
|
||||
assert_eq!(
|
||||
e.unwrap_err(),
|
||||
crate::Error::EncodingMismatch(Encoding::simd().into(), Encoding::poly().into())
|
||||
);
|
||||
let e = Vec::<u64>::try_decode(&plaintext, Encoding::poly_at_level(1));
|
||||
assert!(e.is_err());
|
||||
assert_eq!(
|
||||
e.unwrap_err(),
|
||||
crate::Error::EncodingMismatch(
|
||||
Encoding::poly_at_level(1).into(),
|
||||
Encoding::poly().into()
|
||||
)
|
||||
);
|
||||
assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok());
|
||||
let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd());
|
||||
assert!(e.is_err());
|
||||
assert_eq!(
|
||||
e.unwrap_err(),
|
||||
crate::Error::EncodingMismatch(Encoding::simd().into(), Encoding::poly().into())
|
||||
);
|
||||
let e = Vec::<u64>::try_decode(&plaintext, Encoding::poly_at_level(1));
|
||||
assert!(e.is_err());
|
||||
assert_eq!(
|
||||
e.unwrap_err(),
|
||||
crate::Error::EncodingMismatch(
|
||||
Encoding::poly_at_level(1).into(),
|
||||
Encoding::poly().into()
|
||||
)
|
||||
);
|
||||
|
||||
plaintext.encoding = None;
|
||||
let e = Vec::<u64>::try_decode(&plaintext, None);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(
|
||||
e.unwrap_err(),
|
||||
crate::Error::UnspecifiedInput("No encoding specified".to_string())
|
||||
);
|
||||
plaintext.encoding = None;
|
||||
let e = Vec::<u64>::try_decode(&plaintext, None);
|
||||
assert!(e.is_err());
|
||||
assert_eq!(
|
||||
e.unwrap_err(),
|
||||
crate::Error::UnspecifiedInput("No encoding specified".to_string())
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero() -> Result<(), Box<dyn Error>> {
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let plaintext = Plaintext::zero(Encoding::poly(), ¶ms)?;
|
||||
#[test]
|
||||
fn zero() -> Result<(), Box<dyn Error>> {
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let plaintext = Plaintext::zero(Encoding::poly(), ¶ms)?;
|
||||
|
||||
assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8]));
|
||||
assert_eq!(
|
||||
plaintext.poly_ntt,
|
||||
Poly::zero(¶ms.ctx[0], Representation::Ntt)
|
||||
);
|
||||
assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8]));
|
||||
assert_eq!(
|
||||
plaintext.poly_ntt,
|
||||
Poly::zero(¶ms.ctx[0], Representation::Ntt)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeroize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
#[test]
|
||||
fn zeroize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), ¶ms)?;
|
||||
|
||||
plaintext.zeroize();
|
||||
plaintext.zeroize();
|
||||
|
||||
assert_eq!(plaintext, Plaintext::zero(Encoding::poly(), ¶ms)?);
|
||||
assert_eq!(plaintext, Plaintext::zero(Encoding::poly(), ¶ms)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_encode_level() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
// The default test parameters support both Poly and Simd encodings
|
||||
let params = Arc::new(BfvParameters::default(10, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
#[test]
|
||||
fn try_encode_level() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
// The default test parameters support both Poly and Simd encodings
|
||||
let params = Arc::new(BfvParameters::default(10, 8));
|
||||
let a = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
for level in 0..10 {
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), ¶ms)?;
|
||||
assert_eq!(plaintext.level(), level);
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), ¶ms)?;
|
||||
assert_eq!(plaintext.level(), level);
|
||||
}
|
||||
for level in 0..10 {
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), ¶ms)?;
|
||||
assert_eq!(plaintext.level(), level);
|
||||
let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), ¶ms)?;
|
||||
assert_eq!(plaintext.level(), level);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ use fhe_util::div_ceil;
|
||||
use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||
|
||||
use crate::{
|
||||
bfv::{BfvParameters, Encoding, Plaintext},
|
||||
Error, Result,
|
||||
bfv::{BfvParameters, Encoding, Plaintext},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
use super::encoding::EncodingEnum;
|
||||
@@ -17,148 +17,148 @@ use super::encoding::EncodingEnum;
|
||||
pub struct PlaintextVec(pub Vec<Plaintext>);
|
||||
|
||||
impl FhePlaintext for PlaintextVec {
|
||||
type Encoding = Encoding;
|
||||
type Encoding = Encoding;
|
||||
}
|
||||
|
||||
impl FheParametrized for PlaintextVec {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl Zeroize for PlaintextVec {
|
||||
fn zeroize(&mut self) {
|
||||
self.0.zeroize()
|
||||
}
|
||||
fn zeroize(&mut self) {
|
||||
self.0.zeroize()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZeroizeOnDrop for PlaintextVec {}
|
||||
|
||||
impl FheEncoderVariableTime<&[u64]> for PlaintextVec {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
unsafe fn try_encode_vt(
|
||||
value: &[u64],
|
||||
encoding: Encoding,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
if value.is_empty() {
|
||||
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
|
||||
}
|
||||
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
|
||||
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
|
||||
}
|
||||
let ctx = par.ctx_at_level(encoding.level)?;
|
||||
let num_plaintexts = div_ceil(value.len(), par.degree());
|
||||
unsafe fn try_encode_vt(
|
||||
value: &[u64],
|
||||
encoding: Encoding,
|
||||
par: &Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
if value.is_empty() {
|
||||
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
|
||||
}
|
||||
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
|
||||
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
|
||||
}
|
||||
let ctx = par.ctx_at_level(encoding.level)?;
|
||||
let num_plaintexts = div_ceil(value.len(), par.degree());
|
||||
|
||||
Ok(PlaintextVec(
|
||||
(0..num_plaintexts)
|
||||
.map(|i| {
|
||||
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
|
||||
let mut v = vec![0u64; par.degree()];
|
||||
match encoding.encoding {
|
||||
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
|
||||
EncodingEnum::Simd => {
|
||||
for i in 0..slice.len() {
|
||||
v[par.matrix_reps_index_map[i]] = slice[i];
|
||||
}
|
||||
par.op.as_ref().unwrap().backward_vt(v.as_mut_ptr());
|
||||
}
|
||||
};
|
||||
Ok(PlaintextVec(
|
||||
(0..num_plaintexts)
|
||||
.map(|i| {
|
||||
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
|
||||
let mut v = vec![0u64; par.degree()];
|
||||
match encoding.encoding {
|
||||
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
|
||||
EncodingEnum::Simd => {
|
||||
for i in 0..slice.len() {
|
||||
v[par.matrix_reps_index_map[i]] = slice[i];
|
||||
}
|
||||
par.op.as_ref().unwrap().backward_vt(v.as_mut_ptr());
|
||||
}
|
||||
};
|
||||
|
||||
let mut poly =
|
||||
Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?;
|
||||
poly.change_representation(Representation::Ntt);
|
||||
let mut poly =
|
||||
Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?;
|
||||
poly.change_representation(Representation::Ntt);
|
||||
|
||||
Ok(Plaintext {
|
||||
par: par.clone(),
|
||||
value: v.into_boxed_slice(),
|
||||
encoding: Some(encoding.clone()),
|
||||
poly_ntt: poly,
|
||||
level: encoding.level,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<Plaintext>>>()?,
|
||||
))
|
||||
}
|
||||
Ok(Plaintext {
|
||||
par: par.clone(),
|
||||
value: v.into_boxed_slice(),
|
||||
encoding: Some(encoding.clone()),
|
||||
poly_ntt: poly,
|
||||
level: encoding.level,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<Plaintext>>>()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl FheEncoder<&[u64]> for PlaintextVec {
|
||||
type Error = Error;
|
||||
fn try_encode(value: &[u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if value.is_empty() {
|
||||
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
|
||||
}
|
||||
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
|
||||
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
|
||||
}
|
||||
let ctx = par.ctx_at_level(encoding.level)?;
|
||||
let num_plaintexts = div_ceil(value.len(), par.degree());
|
||||
type Error = Error;
|
||||
fn try_encode(value: &[u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
|
||||
if value.is_empty() {
|
||||
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
|
||||
}
|
||||
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
|
||||
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
|
||||
}
|
||||
let ctx = par.ctx_at_level(encoding.level)?;
|
||||
let num_plaintexts = div_ceil(value.len(), par.degree());
|
||||
|
||||
Ok(PlaintextVec(
|
||||
(0..num_plaintexts)
|
||||
.map(|i| {
|
||||
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
|
||||
let mut v = vec![0u64; par.degree()];
|
||||
match encoding.encoding {
|
||||
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
|
||||
EncodingEnum::Simd => {
|
||||
for i in 0..slice.len() {
|
||||
v[par.matrix_reps_index_map[i]] = slice[i];
|
||||
}
|
||||
par.op.as_ref().unwrap().backward(&mut v);
|
||||
}
|
||||
};
|
||||
Ok(PlaintextVec(
|
||||
(0..num_plaintexts)
|
||||
.map(|i| {
|
||||
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
|
||||
let mut v = vec![0u64; par.degree()];
|
||||
match encoding.encoding {
|
||||
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
|
||||
EncodingEnum::Simd => {
|
||||
for i in 0..slice.len() {
|
||||
v[par.matrix_reps_index_map[i]] = slice[i];
|
||||
}
|
||||
par.op.as_ref().unwrap().backward(&mut v);
|
||||
}
|
||||
};
|
||||
|
||||
let mut poly =
|
||||
Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?;
|
||||
poly.change_representation(Representation::Ntt);
|
||||
let mut poly =
|
||||
Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?;
|
||||
poly.change_representation(Representation::Ntt);
|
||||
|
||||
Ok(Plaintext {
|
||||
par: par.clone(),
|
||||
value: v.into_boxed_slice(),
|
||||
encoding: Some(encoding.clone()),
|
||||
poly_ntt: poly,
|
||||
level: encoding.level,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<Plaintext>>>()?,
|
||||
))
|
||||
}
|
||||
Ok(Plaintext {
|
||||
par: par.clone(),
|
||||
value: v.into_boxed_slice(),
|
||||
encoding: Some(encoding.clone()),
|
||||
poly_ntt: poly,
|
||||
level: encoding.level,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<Plaintext>>>()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bfv::{BfvParameters, Encoding, PlaintextVec};
|
||||
use fhe_traits::{FheDecoder, FheEncoder, FheEncoderVariableTime};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
use crate::bfv::{BfvParameters, Encoding, PlaintextVec};
|
||||
use fhe_traits::{FheDecoder, FheEncoder, FheEncoderVariableTime};
|
||||
use rand::thread_rng;
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
#[test]
|
||||
fn encode_decode() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for _ in 0..20 {
|
||||
for i in 1..5 {
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree() * i, &mut rng);
|
||||
#[test]
|
||||
fn encode_decode() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for _ in 0..20 {
|
||||
for i in 1..5 {
|
||||
let params = Arc::new(BfvParameters::default(1, 8));
|
||||
let a = params.plaintext.random_vec(params.degree() * i, &mut rng);
|
||||
|
||||
let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), ¶ms)?;
|
||||
assert_eq!(plaintexts.0.len(), i);
|
||||
let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), ¶ms)?;
|
||||
assert_eq!(plaintexts.0.len(), i);
|
||||
|
||||
for j in 0..i {
|
||||
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
|
||||
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
|
||||
}
|
||||
for j in 0..i {
|
||||
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
|
||||
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
|
||||
}
|
||||
|
||||
let plaintexts = unsafe {
|
||||
PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), ¶ms)?
|
||||
};
|
||||
assert_eq!(plaintexts.0.len(), i);
|
||||
let plaintexts = unsafe {
|
||||
PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), ¶ms)?
|
||||
};
|
||||
assert_eq!(plaintexts.0.len(), i);
|
||||
|
||||
for j in 0..i {
|
||||
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
|
||||
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
for j in 0..i {
|
||||
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
|
||||
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,210 +2,210 @@ use std::ops::Mul;
|
||||
|
||||
use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation};
|
||||
use fhe_traits::{
|
||||
DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize,
|
||||
DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize,
|
||||
};
|
||||
use protobuf::{Message, MessageField};
|
||||
use rand::{CryptoRng, RngCore};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use crate::{
|
||||
bfv::proto::bfv::{
|
||||
KeySwitchingKey as KeySwitchingKeyProto, RGSWCiphertext as RGSWCiphertextProto,
|
||||
},
|
||||
Error, Result,
|
||||
bfv::proto::bfv::{
|
||||
KeySwitchingKey as KeySwitchingKeyProto, RGSWCiphertext as RGSWCiphertextProto,
|
||||
},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
use super::{
|
||||
keys::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, Ciphertext, Plaintext, SecretKey,
|
||||
keys::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, Ciphertext, Plaintext, SecretKey,
|
||||
};
|
||||
|
||||
/// A RGSW ciphertext encrypting a plaintext.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct RGSWCiphertext {
|
||||
ksk0: KeySwitchingKey,
|
||||
ksk1: KeySwitchingKey,
|
||||
ksk0: KeySwitchingKey,
|
||||
ksk1: KeySwitchingKey,
|
||||
}
|
||||
|
||||
impl FheParametrized for RGSWCiphertext {
|
||||
type Parameters = BfvParameters;
|
||||
type Parameters = BfvParameters;
|
||||
}
|
||||
|
||||
impl From<&RGSWCiphertext> for RGSWCiphertextProto {
|
||||
fn from(ct: &RGSWCiphertext) -> Self {
|
||||
let mut proto = RGSWCiphertextProto::new();
|
||||
proto.ksk0 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk0));
|
||||
proto.ksk1 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk1));
|
||||
proto
|
||||
}
|
||||
fn from(ct: &RGSWCiphertext) -> Self {
|
||||
let mut proto = RGSWCiphertextProto::new();
|
||||
proto.ksk0 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk0));
|
||||
proto.ksk1 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk1));
|
||||
proto
|
||||
}
|
||||
}
|
||||
|
||||
impl TryConvertFrom<&RGSWCiphertextProto> for RGSWCiphertext {
|
||||
fn try_convert_from(
|
||||
value: &RGSWCiphertextProto,
|
||||
par: &std::sync::Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
if value.ksk0.is_none() || value.ksk1.is_none() {
|
||||
return Err(Error::SerializationError);
|
||||
}
|
||||
fn try_convert_from(
|
||||
value: &RGSWCiphertextProto,
|
||||
par: &std::sync::Arc<BfvParameters>,
|
||||
) -> Result<Self> {
|
||||
if value.ksk0.is_none() || value.ksk1.is_none() {
|
||||
return Err(Error::SerializationError);
|
||||
}
|
||||
|
||||
let ksk0 = KeySwitchingKey::try_convert_from(value.ksk0.as_ref().unwrap(), par)?;
|
||||
let ksk1 = KeySwitchingKey::try_convert_from(value.ksk1.as_ref().unwrap(), par)?;
|
||||
if ksk0.ksk_level != ksk0.ciphertext_level
|
||||
|| ksk0.ciphertext_level != ksk1.ciphertext_level
|
||||
|| ksk1.ciphertext_level != ksk1.ksk_level
|
||||
{
|
||||
return Err(Error::SerializationError);
|
||||
}
|
||||
let ksk0 = KeySwitchingKey::try_convert_from(value.ksk0.as_ref().unwrap(), par)?;
|
||||
let ksk1 = KeySwitchingKey::try_convert_from(value.ksk1.as_ref().unwrap(), par)?;
|
||||
if ksk0.ksk_level != ksk0.ciphertext_level
|
||||
|| ksk0.ciphertext_level != ksk1.ciphertext_level
|
||||
|| ksk1.ciphertext_level != ksk1.ksk_level
|
||||
{
|
||||
return Err(Error::SerializationError);
|
||||
}
|
||||
|
||||
Ok(Self { ksk0, ksk1 })
|
||||
}
|
||||
Ok(Self { ksk0, ksk1 })
|
||||
}
|
||||
}
|
||||
|
||||
impl DeserializeParametrized for RGSWCiphertext {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn from_bytes(bytes: &[u8], par: &std::sync::Arc<Self::Parameters>) -> Result<Self> {
|
||||
let proto =
|
||||
RGSWCiphertextProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
|
||||
RGSWCiphertext::try_convert_from(&proto, par)
|
||||
}
|
||||
fn from_bytes(bytes: &[u8], par: &std::sync::Arc<Self::Parameters>) -> Result<Self> {
|
||||
let proto =
|
||||
RGSWCiphertextProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
|
||||
RGSWCiphertext::try_convert_from(&proto, par)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for RGSWCiphertext {
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
RGSWCiphertextProto::from(self).write_to_bytes().unwrap()
|
||||
}
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
RGSWCiphertextProto::from(self).write_to_bytes().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl FheCiphertext for RGSWCiphertext {}
|
||||
|
||||
impl FheEncrypter<Plaintext, RGSWCiphertext> for SecretKey {
|
||||
type Error = Error;
|
||||
type Error = Error;
|
||||
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
pt: &Plaintext,
|
||||
rng: &mut R,
|
||||
) -> Result<RGSWCiphertext> {
|
||||
let level = pt.level;
|
||||
let ctx = self.par.ctx_at_level(level)?;
|
||||
fn try_encrypt<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
pt: &Plaintext,
|
||||
rng: &mut R,
|
||||
) -> Result<RGSWCiphertext> {
|
||||
let level = pt.level;
|
||||
let ctx = self.par.ctx_at_level(level)?;
|
||||
|
||||
let mut m = Zeroizing::new(pt.poly_ntt.clone());
|
||||
let mut m_s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
m_s.change_representation(Representation::Ntt);
|
||||
*m_s.as_mut() *= m.as_ref();
|
||||
m_s.change_representation(Representation::PowerBasis);
|
||||
m.change_representation(Representation::PowerBasis);
|
||||
let mut m = Zeroizing::new(pt.poly_ntt.clone());
|
||||
let mut m_s = Zeroizing::new(Poly::try_convert_from(
|
||||
self.coeffs.as_ref(),
|
||||
ctx,
|
||||
false,
|
||||
Representation::PowerBasis,
|
||||
)?);
|
||||
m_s.change_representation(Representation::Ntt);
|
||||
*m_s.as_mut() *= m.as_ref();
|
||||
m_s.change_representation(Representation::PowerBasis);
|
||||
m.change_representation(Representation::PowerBasis);
|
||||
|
||||
let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?;
|
||||
let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?;
|
||||
let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?;
|
||||
let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?;
|
||||
|
||||
Ok(RGSWCiphertext { ksk0, ksk1 })
|
||||
}
|
||||
Ok(RGSWCiphertext { ksk0, ksk1 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<&RGSWCiphertext> for &Ciphertext {
|
||||
type Output = Ciphertext;
|
||||
type Output = Ciphertext;
|
||||
|
||||
fn mul(self, rhs: &RGSWCiphertext) -> Self::Output {
|
||||
assert_eq!(
|
||||
self.par, rhs.ksk0.par,
|
||||
"Ciphertext and RGSWCiphertext must have the same parameters"
|
||||
);
|
||||
assert_eq!(
|
||||
self.level, rhs.ksk0.ciphertext_level,
|
||||
"Ciphertext and RGSWCiphertext must have the same level"
|
||||
);
|
||||
assert_eq!(self.c.len(), 2, "Ciphertext must have two parts");
|
||||
fn mul(self, rhs: &RGSWCiphertext) -> Self::Output {
|
||||
assert_eq!(
|
||||
self.par, rhs.ksk0.par,
|
||||
"Ciphertext and RGSWCiphertext must have the same parameters"
|
||||
);
|
||||
assert_eq!(
|
||||
self.level, rhs.ksk0.ciphertext_level,
|
||||
"Ciphertext and RGSWCiphertext must have the same level"
|
||||
);
|
||||
assert_eq!(self.c.len(), 2, "Ciphertext must have two parts");
|
||||
|
||||
let mut ct0 = self.c[0].clone();
|
||||
let mut ct1 = self.c[1].clone();
|
||||
ct0.change_representation(Representation::PowerBasis);
|
||||
ct1.change_representation(Representation::PowerBasis);
|
||||
let mut ct0 = self.c[0].clone();
|
||||
let mut ct1 = self.c[1].clone();
|
||||
ct0.change_representation(Representation::PowerBasis);
|
||||
ct1.change_representation(Representation::PowerBasis);
|
||||
|
||||
let (c0, c1) = rhs.ksk0.key_switch(&ct0).unwrap();
|
||||
let (c0p, c1p) = rhs.ksk1.key_switch(&ct1).unwrap();
|
||||
let (c0, c1) = rhs.ksk0.key_switch(&ct0).unwrap();
|
||||
let (c0p, c1p) = rhs.ksk1.key_switch(&ct1).unwrap();
|
||||
|
||||
Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: None,
|
||||
c: vec![&c0 + &c0p, &c1 + &c1p],
|
||||
level: self.level,
|
||||
}
|
||||
}
|
||||
Ciphertext {
|
||||
par: self.par.clone(),
|
||||
seed: None,
|
||||
c: vec![&c0 + &c0p, &c1 + &c1p],
|
||||
level: self.level,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<&Ciphertext> for &RGSWCiphertext {
|
||||
type Output = Ciphertext;
|
||||
type Output = Ciphertext;
|
||||
|
||||
fn mul(self, rhs: &Ciphertext) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
fn mul(self, rhs: &Ciphertext) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{error::Error, sync::Arc};
|
||||
use std::{error::Error, sync::Arc};
|
||||
|
||||
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
|
||||
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
|
||||
use rand::thread_rng;
|
||||
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
|
||||
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
|
||||
use rand::thread_rng;
|
||||
|
||||
use super::RGSWCiphertext;
|
||||
use super::RGSWCiphertext;
|
||||
|
||||
#[test]
|
||||
fn external_product() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(2, 8)),
|
||||
Arc::new(BfvParameters::default(8, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v1 = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let v2 = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
#[test]
|
||||
fn external_product() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(2, 8)),
|
||||
Arc::new(BfvParameters::default(8, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v1 = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let v2 = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
|
||||
let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), ¶ms)?;
|
||||
let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), ¶ms)?;
|
||||
let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), ¶ms)?;
|
||||
let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), ¶ms)?;
|
||||
|
||||
let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?;
|
||||
let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?;
|
||||
let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?;
|
||||
let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?;
|
||||
let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?;
|
||||
let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?;
|
||||
|
||||
let product = &ct1 * &ct2;
|
||||
let expected = sk.try_decrypt(&product)?;
|
||||
let product = &ct1 * &ct2;
|
||||
let expected = sk.try_decrypt(&product)?;
|
||||
|
||||
let ct3 = &ct1 * &ct2_rgsw;
|
||||
let ct4 = &ct2_rgsw * &ct1;
|
||||
let ct3 = &ct1 * &ct2_rgsw;
|
||||
let ct4 = &ct2_rgsw * &ct1;
|
||||
|
||||
println!("Noise 1: {:?}", unsafe { sk.measure_noise(&ct3) });
|
||||
println!("Noise 2: {:?}", unsafe { sk.measure_noise(&ct4) });
|
||||
assert_eq!(expected, sk.try_decrypt(&ct3)?);
|
||||
assert_eq!(expected, sk.try_decrypt(&ct4)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
println!("Noise 1: {:?}", unsafe { sk.measure_noise(&ct3) });
|
||||
println!("Noise 2: {:?}", unsafe { sk.measure_noise(&ct4) });
|
||||
assert_eq!(expected, sk.try_decrypt(&ct3)?);
|
||||
assert_eq!(expected, sk.try_decrypt(&ct4)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(5, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
#[test]
|
||||
fn serialize() -> Result<(), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
for params in [
|
||||
Arc::new(BfvParameters::default(6, 8)),
|
||||
Arc::new(BfvParameters::default(5, 8)),
|
||||
] {
|
||||
let sk = SecretKey::random(¶ms, &mut rng);
|
||||
let v = params.plaintext.random_vec(params.degree(), &mut rng);
|
||||
let pt = Plaintext::try_encode(&v, Encoding::simd(), ¶ms)?;
|
||||
let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?;
|
||||
|
||||
let bytes = ct.to_bytes();
|
||||
assert_eq!(RGSWCiphertext::from_bytes(&bytes, ¶ms)?, ct);
|
||||
}
|
||||
let bytes = ct.to_bytes();
|
||||
assert_eq!(RGSWCiphertext::from_bytes(&bytes, ¶ms)?, ct);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ use std::sync::Arc;
|
||||
/// blanket implementation <https://github.com/rust-lang/rust/issues/50133#issuecomment-488512355>.
|
||||
pub trait TryConvertFrom<T>
|
||||
where
|
||||
Self: Sized,
|
||||
Self: Sized,
|
||||
{
|
||||
/// Attempt to convert the `value` with a specific parameter.
|
||||
fn try_convert_from(value: T, par: &Arc<BfvParameters>) -> Result<Self>;
|
||||
/// Attempt to convert the `value` with a specific parameter.
|
||||
fn try_convert_from(value: T, par: &Arc<BfvParameters>) -> Result<Self>;
|
||||
}
|
||||
|
||||
@@ -6,141 +6,141 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
/// Enum encapsulating all the possible errors from this library.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// Indicates that an error from the underlying mathematical library was
|
||||
/// encountered.
|
||||
#[error("{0}")]
|
||||
MathError(fhe_math::Error),
|
||||
/// Indicates that an error from the underlying mathematical library was
|
||||
/// encountered.
|
||||
#[error("{0}")]
|
||||
MathError(fhe_math::Error),
|
||||
|
||||
/// Indicates a serialization error.
|
||||
#[error("Serialization error")]
|
||||
SerializationError,
|
||||
/// Indicates a serialization error.
|
||||
#[error("Serialization error")]
|
||||
SerializationError,
|
||||
|
||||
/// Indicates that too many values were provided.
|
||||
#[error("Too many values provided: {0} exceeds limit {1}")]
|
||||
TooManyValues(usize, usize),
|
||||
/// Indicates that too many values were provided.
|
||||
#[error("Too many values provided: {0} exceeds limit {1}")]
|
||||
TooManyValues(usize, usize),
|
||||
|
||||
/// Indicates that too few values were provided.
|
||||
#[error("Too few values provided: {0} is below limit {1}")]
|
||||
TooFewValues(usize, usize),
|
||||
/// Indicates that too few values were provided.
|
||||
#[error("Too few values provided: {0} is below limit {1}")]
|
||||
TooFewValues(usize, usize),
|
||||
|
||||
/// Indicates that an input is invalid.
|
||||
#[error("{0}")]
|
||||
UnspecifiedInput(String),
|
||||
/// Indicates that an input is invalid.
|
||||
#[error("{0}")]
|
||||
UnspecifiedInput(String),
|
||||
|
||||
/// Indicates a mismatch in the encodings.
|
||||
#[error("Encoding mismatch: found {0}, expected {1}")]
|
||||
EncodingMismatch(String, String),
|
||||
/// Indicates a mismatch in the encodings.
|
||||
#[error("Encoding mismatch: found {0}, expected {1}")]
|
||||
EncodingMismatch(String, String),
|
||||
|
||||
/// Indicates that the encoding is not supported.
|
||||
#[error("Does not support {0} encoding")]
|
||||
EncodingNotSupported(String),
|
||||
/// Indicates that the encoding is not supported.
|
||||
#[error("Does not support {0} encoding")]
|
||||
EncodingNotSupported(String),
|
||||
|
||||
/// Indicates a parameter error.
|
||||
#[error("{0}")]
|
||||
ParametersError(ParametersError),
|
||||
/// Indicates a parameter error.
|
||||
#[error("{0}")]
|
||||
ParametersError(ParametersError),
|
||||
|
||||
/// Indicates a default error
|
||||
/// TODO: To delete eventually
|
||||
#[error("{0}")]
|
||||
DefaultError(String),
|
||||
/// Indicates a default error
|
||||
/// TODO: To delete eventually
|
||||
#[error("{0}")]
|
||||
DefaultError(String),
|
||||
}
|
||||
|
||||
impl From<fhe_math::Error> for Error {
|
||||
fn from(e: fhe_math::Error) -> Self {
|
||||
Error::MathError(e)
|
||||
}
|
||||
fn from(e: fhe_math::Error) -> Self {
|
||||
Error::MathError(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Separate enum to indicate parameters-related errors.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
pub enum ParametersError {
|
||||
/// Indicates that the degree is invalid.
|
||||
#[error("Invalid degree: {0} is not a power of 2 larger than 8")]
|
||||
InvalidDegree(usize),
|
||||
/// Indicates that the degree is invalid.
|
||||
#[error("Invalid degree: {0} is not a power of 2 larger than 8")]
|
||||
InvalidDegree(usize),
|
||||
|
||||
/// Indicates that the moduli sizes are invalid.
|
||||
#[error("Invalid modulus size: {0}, expected an integer between {1} and {2}")]
|
||||
InvalidModulusSize(usize, usize, usize),
|
||||
/// Indicates that the moduli sizes are invalid.
|
||||
#[error("Invalid modulus size: {0}, expected an integer between {1} and {2}")]
|
||||
InvalidModulusSize(usize, usize, usize),
|
||||
|
||||
/// Indicates that there exists not enough primes of this size.
|
||||
#[error("Not enough primes of size {0} for polynomials of degree {1}")]
|
||||
NotEnoughPrimes(usize, usize),
|
||||
/// Indicates that there exists not enough primes of this size.
|
||||
#[error("Not enough primes of size {0} for polynomials of degree {1}")]
|
||||
NotEnoughPrimes(usize, usize),
|
||||
|
||||
/// Indicates that the plaintext is invalid.
|
||||
#[error("{0}")]
|
||||
InvalidPlaintext(String),
|
||||
/// Indicates that the plaintext is invalid.
|
||||
#[error("{0}")]
|
||||
InvalidPlaintext(String),
|
||||
|
||||
/// Indicates that too many parameters were specified.
|
||||
#[error("{0}")]
|
||||
TooManySpecified(String),
|
||||
/// Indicates that too many parameters were specified.
|
||||
#[error("{0}")]
|
||||
TooManySpecified(String),
|
||||
|
||||
/// Indicates that too few parameters were specified.
|
||||
#[error("{0}")]
|
||||
TooFewSpecified(String),
|
||||
/// Indicates that too few parameters were specified.
|
||||
#[error("{0}")]
|
||||
TooFewSpecified(String),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{Error, ParametersError};
|
||||
use crate::{Error, ParametersError};
|
||||
|
||||
#[test]
|
||||
fn error_strings() {
|
||||
assert_eq!(
|
||||
Error::MathError(fhe_math::Error::InvalidContext).to_string(),
|
||||
fhe_math::Error::InvalidContext.to_string()
|
||||
);
|
||||
assert_eq!(Error::SerializationError.to_string(), "Serialization error");
|
||||
assert_eq!(
|
||||
Error::TooManyValues(20, 17).to_string(),
|
||||
"Too many values provided: 20 exceeds limit 17"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::TooFewValues(10, 17).to_string(),
|
||||
"Too few values provided: 10 is below limit 17"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::UnspecifiedInput("test string".to_string()).to_string(),
|
||||
"test string"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::EncodingMismatch("enc1".to_string(), "enc2".to_string()).to_string(),
|
||||
"Encoding mismatch: found enc1, expected enc2"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::EncodingNotSupported("test".to_string()).to_string(),
|
||||
"Does not support test encoding"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::ParametersError(ParametersError::InvalidDegree(10)).to_string(),
|
||||
ParametersError::InvalidDegree(10).to_string()
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn error_strings() {
|
||||
assert_eq!(
|
||||
Error::MathError(fhe_math::Error::InvalidContext).to_string(),
|
||||
fhe_math::Error::InvalidContext.to_string()
|
||||
);
|
||||
assert_eq!(Error::SerializationError.to_string(), "Serialization error");
|
||||
assert_eq!(
|
||||
Error::TooManyValues(20, 17).to_string(),
|
||||
"Too many values provided: 20 exceeds limit 17"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::TooFewValues(10, 17).to_string(),
|
||||
"Too few values provided: 10 is below limit 17"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::UnspecifiedInput("test string".to_string()).to_string(),
|
||||
"test string"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::EncodingMismatch("enc1".to_string(), "enc2".to_string()).to_string(),
|
||||
"Encoding mismatch: found enc1, expected enc2"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::EncodingNotSupported("test".to_string()).to_string(),
|
||||
"Does not support test encoding"
|
||||
);
|
||||
assert_eq!(
|
||||
Error::ParametersError(ParametersError::InvalidDegree(10)).to_string(),
|
||||
ParametersError::InvalidDegree(10).to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parameters_error_strings() {
|
||||
assert_eq!(
|
||||
ParametersError::InvalidDegree(10).to_string(),
|
||||
"Invalid degree: 10 is not a power of 2 larger than 8"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::InvalidModulusSize(1, 2, 3).to_string(),
|
||||
"Invalid modulus size: 1, expected an integer between 2 and 3"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::NotEnoughPrimes(1, 2).to_string(),
|
||||
"Not enough primes of size 1 for polynomials of degree 2"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::InvalidPlaintext("test".to_string()).to_string(),
|
||||
"test"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::TooManySpecified("test".to_string()).to_string(),
|
||||
"test"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::TooFewSpecified("test".to_string()).to_string(),
|
||||
"test"
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn parameters_error_strings() {
|
||||
assert_eq!(
|
||||
ParametersError::InvalidDegree(10).to_string(),
|
||||
"Invalid degree: 10 is not a power of 2 larger than 8"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::InvalidModulusSize(1, 2, 3).to_string(),
|
||||
"Invalid modulus size: 1, expected an integer between 2 and 3"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::NotEnoughPrimes(1, 2).to_string(),
|
||||
"Not enough primes of size 1 for polynomials of degree 2"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::InvalidPlaintext("test".to_string()).to_string(),
|
||||
"test"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::TooManySpecified("test".to_string()).to_string(),
|
||||
"test"
|
||||
);
|
||||
assert_eq!(
|
||||
ParametersError::TooFewSpecified("test".to_string()).to_string(),
|
||||
"test"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
wrap_comments = true
|
||||
hard_tabs = true
|
||||
|
||||
Reference in New Issue
Block a user