Change tabs into space, optimize ntt operator constructor (#170)

This commit is contained in:
Tancrède Lepoint
2023-04-04 04:28:01 -04:00
committed by GitHub
parent 34dd9f800c
commit 393316ffe1
45 changed files with 12127 additions and 12130 deletions

View File

@@ -4,40 +4,40 @@ use rand::thread_rng;
use std::sync::Arc;
pub fn ntt_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("ntt");
group.sample_size(50);
let mut rng = thread_rng();
let mut group = c.benchmark_group("ntt");
group.sample_size(50);
let mut rng = thread_rng();
for vector_size in [1024usize, 4096].iter() {
for p in [4611686018326724609u64, 40961u64] {
let p_nbits = 64 - p.leading_zeros();
let q = Modulus::new(p).unwrap();
let mut a = q.random_vec(*vector_size, &mut rng);
let op = NttOperator::new(&Arc::new(q), *vector_size).unwrap();
for vector_size in [1024usize, 4096].iter() {
for p in [4611686018326724609u64, 40961u64] {
let p_nbits = 64 - p.leading_zeros();
let q = Modulus::new(p).unwrap();
let mut a = q.random_vec(*vector_size, &mut rng);
let op = NttOperator::new(&Arc::new(q), *vector_size).unwrap();
group.bench_function(
BenchmarkId::new("forward", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| op.forward(&mut a)),
);
group.bench_function(
BenchmarkId::new("forward", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| op.forward(&mut a)),
);
group.bench_function(
BenchmarkId::new("forward_vt", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| unsafe { op.forward_vt(a.as_mut_ptr()) }),
);
group.bench_function(
BenchmarkId::new("forward_vt", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| unsafe { op.forward_vt(a.as_mut_ptr()) }),
);
group.bench_function(
BenchmarkId::new("backward", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| op.backward(&mut a)),
);
group.bench_function(
BenchmarkId::new("backward", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| op.backward(&mut a)),
);
group.bench_function(
BenchmarkId::new("backward_vt", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| unsafe { op.backward_vt(a.as_mut_ptr()) }),
);
}
}
group.bench_function(
BenchmarkId::new("backward_vt", format!("{vector_size}/{p_nbits}")),
|b| b.iter(|| unsafe { op.backward_vt(a.as_mut_ptr()) }),
);
}
}
group.finish();
group.finish();
}
criterion_group!(ntt, ntt_benchmark);

View File

@@ -5,53 +5,53 @@ use rand::{thread_rng, RngCore};
use std::sync::Arc;
pub fn rns_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("rns");
group.sample_size(50);
let mut group = c.benchmark_group("rns");
group.sample_size(50);
let q = [
4611686018326724609u64,
4611686018309947393,
4611686018282684417,
];
let p = [
4611686018257518593u64,
4611686018232352769,
4611686018171535361,
4611686018106523649,
];
let q = [
4611686018326724609u64,
4611686018309947393,
4611686018282684417,
];
let p = [
4611686018257518593u64,
4611686018232352769,
4611686018171535361,
4611686018106523649,
];
let mut rng = thread_rng();
let mut x = vec![];
for qi in &q {
x.push(rng.next_u64() % *qi);
}
let mut rng = thread_rng();
let mut x = vec![];
for qi in &q {
x.push(rng.next_u64() % *qi);
}
let rns_q = Arc::new(RnsContext::new(&q).unwrap());
let rns_p = Arc::new(RnsContext::new(&p).unwrap());
let scaler = RnsScaler::new(
&rns_q,
&rns_p,
ScalingFactor::new(&BigUint::from(1u64), &BigUint::from(46116860181065u64)),
);
let scaler_as_converter = RnsScaler::new(&rns_q, &rns_p, ScalingFactor::one());
let rns_q = Arc::new(RnsContext::new(&q).unwrap());
let rns_p = Arc::new(RnsContext::new(&p).unwrap());
let scaler = RnsScaler::new(
&rns_q,
&rns_p,
ScalingFactor::new(&BigUint::from(1u64), &BigUint::from(46116860181065u64)),
);
let scaler_as_converter = RnsScaler::new(&rns_q, &rns_p, ScalingFactor::one());
let mut y = vec![0; p.len()];
let mut y = vec![0; p.len()];
group.bench_function(
BenchmarkId::new("scaler", format!("{}->{}", q.len(), p.len())),
|b| {
b.iter(|| scaler.scale((&x).into(), (&mut y).into(), 0));
},
);
group.bench_function(
BenchmarkId::new("scaler", format!("{}->{}", q.len(), p.len())),
|b| {
b.iter(|| scaler.scale((&x).into(), (&mut y).into(), 0));
},
);
group.bench_function(
BenchmarkId::new("scaler_as_converter", format!("{}->{}", q.len(), p.len())),
|b| {
b.iter(|| scaler_as_converter.scale((&x).into(), (&mut y).into(), 0));
},
);
group.bench_function(
BenchmarkId::new("scaler_as_converter", format!("{}->{}", q.len(), p.len())),
|b| {
b.iter(|| scaler_as_converter.scale((&x).into(), (&mut y).into(), 0));
},
);
group.finish();
group.finish();
}
criterion_group!(rns, rns_benchmark);

View File

@@ -4,287 +4,287 @@ use fhe_math::rq::*;
use itertools::{izip, Itertools};
use rand::thread_rng;
use std::{
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
sync::Arc,
time::Duration,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
sync::Arc,
time::Duration,
};
static MODULI: &[u64; 4] = &[
562949954093057,
4611686018326724609,
4611686018309947393,
4611686018282684417,
562949954093057,
4611686018326724609,
4611686018309947393,
4611686018282684417,
];
static DEGREE: &[usize] = &[1024, 2048, 4096, 8192];
fn create_group(c: &mut Criterion, name: String) -> BenchmarkGroup<WallTime> {
let mut group = c.benchmark_group(name);
group.warm_up_time(Duration::from_millis(100));
group.measurement_time(Duration::from_secs(1));
group
let mut group = c.benchmark_group(name);
group.warm_up_time(Duration::from_millis(100));
group.measurement_time(Duration::from_secs(1));
group
}
macro_rules! bench_op {
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
let name = if $vt {
format!("{}_vt", $name)
} else {
$name.to_string()
};
let mut group = create_group($c, name);
let mut rng = thread_rng();
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
let name = if $vt {
format!("{}_vt", $name)
} else {
$name.to_string()
};
let mut group = create_group($c, name);
let mut rng = thread_rng();
for degree in DEGREE {
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
if $vt {
unsafe { q.allow_variable_time_computations() }
}
for degree in DEGREE {
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
if $vt {
unsafe { q.allow_variable_time_computations() }
}
group.bench_function(
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| $op(&p, &q));
},
);
}
}};
group.bench_function(
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| $op(&p, &q));
},
);
}
}};
}
macro_rules! bench_op_unary {
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
let name = if $vt {
format!("{}_vt", $name)
} else {
$name.to_string()
};
let mut group = create_group($c, name);
let mut rng = thread_rng();
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
let name = if $vt {
format!("{}_vt", $name)
} else {
$name.to_string()
};
let mut group = create_group($c, name);
let mut rng = thread_rng();
for degree in DEGREE {
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
if $vt {
unsafe { q.allow_variable_time_computations() }
}
for degree in DEGREE {
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
if $vt {
unsafe { q.allow_variable_time_computations() }
}
group.bench_function(
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| $op(&p));
},
);
}
}};
group.bench_function(
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| $op(&p));
},
);
}
}};
}
macro_rules! bench_op_assign {
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
let name = if $vt {
format!("{}_vt", $name)
} else {
$name.to_string()
};
let mut group = create_group($c, name);
let mut rng = thread_rng();
($c: expr, $name:expr, $op:expr, $vt:expr) => {{
let name = if $vt {
format!("{}_vt", $name)
} else {
$name.to_string()
};
let mut group = create_group($c, name);
let mut rng = thread_rng();
for degree in DEGREE {
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
if $vt {
unsafe { q.allow_variable_time_computations() }
}
for degree in DEGREE {
let ctx = Arc::new(Context::new(&MODULI[..1], *degree).unwrap());
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
if $vt {
unsafe { q.allow_variable_time_computations() }
}
group.bench_function(
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| $op(&mut p, &q));
},
);
}
}};
group.bench_function(
BenchmarkId::from_parameter(format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| $op(&mut p, &q));
},
);
}
}};
}
pub fn rq_op_benchmark(c: &mut Criterion) {
for vt in [false, true] {
bench_op!(c, "rq_add", <&Poly>::add, vt);
bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt);
bench_op!(c, "rq_sub", <&Poly>::sub, vt);
bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt);
bench_op!(c, "rq_mul", <&Poly>::mul, vt);
bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt);
bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt);
}
for vt in [false, true] {
bench_op!(c, "rq_add", <&Poly>::add, vt);
bench_op_assign!(c, "rq_add_assign", Poly::add_assign, vt);
bench_op!(c, "rq_sub", <&Poly>::sub, vt);
bench_op_assign!(c, "rq_sub_assign", Poly::sub_assign, vt);
bench_op!(c, "rq_mul", <&Poly>::mul, vt);
bench_op_assign!(c, "rq_mul_assign", Poly::mul_assign, vt);
bench_op_unary!(c, "rq_neg", <&Poly>::neg, vt);
}
}
pub fn rq_dot_product(c: &mut Criterion) {
let mut group = create_group(c, "rq_dot_product".to_string());
let mut rng = thread_rng();
for degree in DEGREE {
for i in [1, 4] {
let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap());
let p_vec = (0..256)
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
.collect_vec();
let mut q_vec = (0..256)
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
.collect_vec();
let mut out = Poly::zero(&ctx, Representation::Ntt);
let mut group = create_group(c, "rq_dot_product".to_string());
let mut rng = thread_rng();
for degree in DEGREE {
for i in [1, 4] {
let ctx = Arc::new(Context::new(&MODULI[..i], *degree).unwrap());
let p_vec = (0..256)
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
.collect_vec();
let mut q_vec = (0..256)
.map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
.collect_vec();
let mut out = Poly::zero(&ctx, Representation::Ntt);
group.bench_function(
BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| {
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
});
},
);
group.bench_function(
BenchmarkId::from_parameter(format!("naive/{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| {
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
});
},
);
q_vec
.iter_mut()
.for_each(|qi| qi.change_representation(Representation::NttShoup));
group.bench_function(
BenchmarkId::from_parameter(format!(
"naive_shoup/{}/{}",
degree,
ctx.modulus().bits()
)),
|b| {
b.iter(|| {
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
});
},
);
q_vec
.iter_mut()
.for_each(|qi| qi.change_representation(Representation::NttShoup));
group.bench_function(
BenchmarkId::from_parameter(format!(
"naive_shoup/{}/{}",
degree,
ctx.modulus().bits()
)),
|b| {
b.iter(|| {
izip!(p_vec.iter(), q_vec.iter()).for_each(|(pi, qi)| out += &(pi * qi))
});
},
);
q_vec
.iter_mut()
.for_each(|qi| qi.change_representation(Representation::Ntt));
group.bench_function(
BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| dot_product(p_vec.iter(), q_vec.iter()));
},
);
}
}
q_vec
.iter_mut()
.for_each(|qi| qi.change_representation(Representation::Ntt));
group.bench_function(
BenchmarkId::from_parameter(format!("opt/{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| dot_product(p_vec.iter(), q_vec.iter()));
},
);
}
}
}
pub fn rq_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("rq");
group.warm_up_time(Duration::from_millis(100));
group.measurement_time(Duration::from_secs(1));
let mut group = c.benchmark_group("rq");
group.warm_up_time(Duration::from_millis(100));
group.measurement_time(Duration::from_secs(1));
let mut rng = thread_rng();
for degree in DEGREE {
for nmoduli in 1..=MODULI.len() {
if !nmoduli.is_power_of_two() {
continue;
}
let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap());
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
q.change_representation(Representation::NttShoup);
let mut rng = thread_rng();
for degree in DEGREE {
for nmoduli in 1..=MODULI.len() {
if !nmoduli.is_power_of_two() {
continue;
}
let ctx = Arc::new(Context::new(&MODULI[..nmoduli], *degree).unwrap());
let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
q.change_representation(Representation::NttShoup);
group.bench_function(
BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| p = &p * &q);
},
);
group.bench_function(
BenchmarkId::new("mul_shoup", format!("{}/{}", degree, ctx.modulus().bits())),
|b| {
b.iter(|| p = &p * &q);
},
);
group.bench_function(
BenchmarkId::new(
"mul_shoup_assign",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| p *= &q);
},
);
group.bench_function(
BenchmarkId::new(
"mul_shoup_assign",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| p *= &q);
},
);
group.bench_function(
BenchmarkId::new(
"change_representation/PowerBasis_to_Ntt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
unsafe {
p.override_representation(Representation::PowerBasis);
}
p.change_representation(Representation::Ntt)
});
},
);
group.bench_function(
BenchmarkId::new(
"change_representation/PowerBasis_to_Ntt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
unsafe {
p.override_representation(Representation::PowerBasis);
}
p.change_representation(Representation::Ntt)
});
},
);
group.bench_function(
BenchmarkId::new(
"change_representation/Ntt_to_PowerBasis",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
unsafe {
p.override_representation(Representation::Ntt);
}
p.change_representation(Representation::PowerBasis)
});
},
);
group.bench_function(
BenchmarkId::new(
"change_representation/Ntt_to_PowerBasis",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
unsafe {
p.override_representation(Representation::Ntt);
}
p.change_representation(Representation::PowerBasis)
});
},
);
p.change_representation(Representation::Ntt);
q.change_representation(Representation::Ntt);
p.change_representation(Representation::Ntt);
q.change_representation(Representation::Ntt);
unsafe {
q.allow_variable_time_computations();
q.change_representation(Representation::NttShoup);
unsafe {
q.allow_variable_time_computations();
q.change_representation(Representation::NttShoup);
group.bench_function(
BenchmarkId::new(
"mul_shoup_vt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| p *= &q);
},
);
group.bench_function(
BenchmarkId::new(
"mul_shoup_vt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| p *= &q);
},
);
p.allow_variable_time_computations();
p.allow_variable_time_computations();
group.bench_function(
BenchmarkId::new(
"change_representation/PowerBasis_to_Ntt_vt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
p.override_representation(Representation::PowerBasis);
p.change_representation(Representation::Ntt)
});
},
);
group.bench_function(
BenchmarkId::new(
"change_representation/PowerBasis_to_Ntt_vt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
p.override_representation(Representation::PowerBasis);
p.change_representation(Representation::Ntt)
});
},
);
group.bench_function(
BenchmarkId::new(
"change_representation/Ntt_to_PowerBasis_vt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
p.override_representation(Representation::Ntt);
p.change_representation(Representation::PowerBasis)
});
},
);
}
}
}
group.bench_function(
BenchmarkId::new(
"change_representation/Ntt_to_PowerBasis_vt",
format!("{}/{}", degree, ctx.modulus().bits()),
),
|b| {
b.iter(|| {
p.override_representation(Representation::Ntt);
p.change_representation(Representation::PowerBasis)
});
},
);
}
}
}
group.finish();
group.finish();
}
criterion_group!(rq, rq_op_benchmark, rq_dot_product, rq_benchmark);

View File

@@ -3,53 +3,53 @@ use fhe_math::zq::Modulus;
use rand::thread_rng;
pub fn zq_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("zq");
group.sample_size(50);
let mut group = c.benchmark_group("zq");
group.sample_size(50);
let p = 4611686018326724609;
let mut rng = thread_rng();
let p = 4611686018326724609;
let mut rng = thread_rng();
for vector_size in [1024usize, 4096].iter() {
let q = Modulus::new(p).unwrap();
let mut a = q.random_vec(*vector_size, &mut rng);
let c = q.random_vec(*vector_size, &mut rng);
let c_shoup = q.shoup_vec(&c);
let scalar = c[0];
for vector_size in [1024usize, 4096].iter() {
let q = Modulus::new(p).unwrap();
let mut a = q.random_vec(*vector_size, &mut rng);
let c = q.random_vec(*vector_size, &mut rng);
let c_shoup = q.shoup_vec(&c);
let scalar = c[0];
group.bench_function(BenchmarkId::new("add_vec", vector_size), |b| {
b.iter(|| q.add_vec(&mut a, &c));
});
group.bench_function(BenchmarkId::new("add_vec", vector_size), |b| {
b.iter(|| q.add_vec(&mut a, &c));
});
group.bench_function(BenchmarkId::new("add_vec_vt", vector_size), |b| unsafe {
b.iter(|| q.add_vec_vt(&mut a, &c));
});
group.bench_function(BenchmarkId::new("add_vec_vt", vector_size), |b| unsafe {
b.iter(|| q.add_vec_vt(&mut a, &c));
});
group.bench_function(BenchmarkId::new("sub_vec", vector_size), |b| {
b.iter(|| q.sub_vec(&mut a, &c));
});
group.bench_function(BenchmarkId::new("sub_vec", vector_size), |b| {
b.iter(|| q.sub_vec(&mut a, &c));
});
group.bench_function(BenchmarkId::new("neg_vec", vector_size), |b| {
b.iter(|| q.neg_vec(&mut a));
});
group.bench_function(BenchmarkId::new("neg_vec", vector_size), |b| {
b.iter(|| q.neg_vec(&mut a));
});
group.bench_function(BenchmarkId::new("mul_vec", vector_size), |b| {
b.iter(|| q.mul_vec(&mut a, &c));
});
group.bench_function(BenchmarkId::new("mul_vec", vector_size), |b| {
b.iter(|| q.mul_vec(&mut a, &c));
});
group.bench_function(BenchmarkId::new("mul_vec_vt", vector_size), |b| unsafe {
b.iter(|| q.mul_vec_vt(&mut a, &c));
});
group.bench_function(BenchmarkId::new("mul_vec_vt", vector_size), |b| unsafe {
b.iter(|| q.mul_vec_vt(&mut a, &c));
});
group.bench_function(BenchmarkId::new("mul_shoup_vec", vector_size), |b| {
b.iter(|| q.mul_shoup_vec(&mut a, &c, &c_shoup));
});
group.bench_function(BenchmarkId::new("mul_shoup_vec", vector_size), |b| {
b.iter(|| q.mul_shoup_vec(&mut a, &c, &c_shoup));
});
group.bench_function(BenchmarkId::new("scalar_mul_vec", vector_size), |b| {
b.iter(|| q.scalar_mul_vec(&mut a, scalar));
});
}
group.bench_function(BenchmarkId::new("scalar_mul_vec", vector_size), |b| {
b.iter(|| q.scalar_mul_vec(&mut a, scalar));
});
}
group.finish();
group.finish();
}
criterion_group!(zq, zq_benchmark);

View File

@@ -8,63 +8,63 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Enum encapsulation all the possible errors from this library.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum Error {
/// Indicates an invalid modulus
#[error("Invalid modulus: modulus {0} should be between 2 and (1 << 62) - 1.")]
InvalidModulus(u64),
/// Indicates an invalid modulus
#[error("Invalid modulus: modulus {0} should be between 2 and (1 << 62) - 1.")]
InvalidModulus(u64),
/// Indicates an error in the serialization / deserialization.
#[error("{0}")]
Serialization(String),
/// Indicates an error in the serialization / deserialization.
#[error("{0}")]
Serialization(String),
/// Indicates that there is no more contexts to switch to.
#[error("This is the last context.")]
NoMoreContext,
/// Indicates that there is no more contexts to switch to.
#[error("This is the last context.")]
NoMoreContext,
/// Indicates that the provided context is invalid.
#[error("Invalid context provided.")]
InvalidContext,
/// Indicates that the provided context is invalid.
#[error("Invalid context provided.")]
InvalidContext,
/// Indicates an incorrect representation.
#[error("Incorrect representation: got {0:?}, expected {1:?}.")]
IncorrectRepresentation(Representation, Representation),
/// Indicates an incorrect representation.
#[error("Incorrect representation: got {0:?}, expected {1:?}.")]
IncorrectRepresentation(Representation, Representation),
/// Indicates that the seed size is incorrect.
#[error("Invalid seed: got {0} bytes, expected {1} bytes.")]
InvalidSeedSize(usize, usize),
/// Indicates that the seed size is incorrect.
#[error("Invalid seed: got {0} bytes, expected {1} bytes.")]
InvalidSeedSize(usize, usize),
/// Indicates a default error
/// TODO: To delete when transition is over
#[error("{0}")]
Default(String),
/// Indicates a default error
/// TODO: To delete when transition is over
#[error("{0}")]
Default(String),
}
#[cfg(test)]
mod tests {
use crate::{rq::Representation, Error};
use crate::{rq::Representation, Error};
#[test]
fn error_strings() {
assert_eq!(
Error::InvalidModulus(0).to_string(),
"Invalid modulus: modulus 0 should be between 2 and (1 << 62) - 1."
);
assert_eq!(Error::Serialization("test".to_string()).to_string(), "test");
assert_eq!(
Error::NoMoreContext.to_string(),
"This is the last context."
);
assert_eq!(
Error::InvalidContext.to_string(),
"Invalid context provided."
);
assert_eq!(
Error::IncorrectRepresentation(Representation::Ntt, Representation::NttShoup)
.to_string(),
"Incorrect representation: got Ntt, expected NttShoup."
);
assert_eq!(
Error::InvalidSeedSize(0, 1).to_string(),
"Invalid seed: got 0 bytes, expected 1 bytes."
);
}
#[test]
fn error_strings() {
assert_eq!(
Error::InvalidModulus(0).to_string(),
"Invalid modulus: modulus 0 should be between 2 and (1 << 62) - 1."
);
assert_eq!(Error::Serialization("test".to_string()).to_string(), "test");
assert_eq!(
Error::NoMoreContext.to_string(),
"This is the last context."
);
assert_eq!(
Error::InvalidContext.to_string(),
"Invalid context provided."
);
assert_eq!(
Error::IncorrectRepresentation(Representation::Ntt, Representation::NttShoup)
.to_string(),
"Incorrect representation: got Ntt, expected NttShoup."
);
assert_eq!(
Error::InvalidSeedSize(0, 1).to_string(),
"Invalid seed: got 0 bytes, expected 1 bytes."
);
}
}

View File

@@ -17,218 +17,218 @@ pub use scaler::{RnsScaler, ScalingFactor};
/// Context for a Residue Number System.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct RnsContext {
moduli_u64: Vec<u64>,
moduli: Vec<Modulus>,
q_tilde: Vec<u64>,
q_tilde_shoup: Vec<u64>,
q_star: Vec<BigUint>,
garner: Vec<BigUint>,
product: BigUint,
moduli_u64: Vec<u64>,
moduli: Vec<Modulus>,
q_tilde: Vec<u64>,
q_tilde_shoup: Vec<u64>,
q_star: Vec<BigUint>,
garner: Vec<BigUint>,
product: BigUint,
}
impl Debug for RnsContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RnsContext")
.field("moduli_u64", &self.moduli_u64)
// .field("moduli", &self.moduli)
// .field("q_tilde", &self.q_tilde)
// .field("q_tilde_shoup", &self.q_tilde_shoup)
// .field("q_star", &self.q_star)
// .field("garner", &self.garner)
.field("product", &self.product)
.finish()
}
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RnsContext")
.field("moduli_u64", &self.moduli_u64)
// .field("moduli", &self.moduli)
// .field("q_tilde", &self.q_tilde)
// .field("q_tilde_shoup", &self.q_tilde_shoup)
// .field("q_star", &self.q_star)
// .field("garner", &self.garner)
.field("product", &self.product)
.finish()
}
}
impl RnsContext {
/// Create a RNS context from a list of moduli.
///
/// Returns an error if the list is empty, or if the moduli are no coprime.
pub fn new(moduli_u64: &[u64]) -> Result<Self> {
if moduli_u64.is_empty() {
Err(Error::Default("The list of moduli is empty".to_string()))
} else {
let mut moduli = Vec::with_capacity(moduli_u64.len());
let mut q_tilde = Vec::with_capacity(moduli_u64.len());
let mut q_tilde_shoup = Vec::with_capacity(moduli_u64.len());
let mut q_star = Vec::with_capacity(moduli_u64.len());
let mut garner = Vec::with_capacity(moduli_u64.len());
let mut product = BigUint::one();
let mut product_dig = BigUintDig::one();
/// Create a RNS context from a list of moduli.
///
/// Returns an error if the list is empty, or if the moduli are no coprime.
pub fn new(moduli_u64: &[u64]) -> Result<Self> {
if moduli_u64.is_empty() {
Err(Error::Default("The list of moduli is empty".to_string()))
} else {
let mut moduli = Vec::with_capacity(moduli_u64.len());
let mut q_tilde = Vec::with_capacity(moduli_u64.len());
let mut q_tilde_shoup = Vec::with_capacity(moduli_u64.len());
let mut q_star = Vec::with_capacity(moduli_u64.len());
let mut garner = Vec::with_capacity(moduli_u64.len());
let mut product = BigUint::one();
let mut product_dig = BigUintDig::one();
for i in 0..moduli_u64.len() {
// Return None if the moduli are not coprime.
for j in 0..moduli_u64.len() {
if i != j {
let (d, _, _) = BigUintDig::from(moduli_u64[i])
.extended_gcd(&BigUintDig::from(moduli_u64[j]));
if d.cmp(&BigIntDig::from(1)) != Ordering::Equal {
return Err(Error::Default("The moduli are not coprime".to_string()));
}
}
}
for i in 0..moduli_u64.len() {
// Return None if the moduli are not coprime.
for j in 0..moduli_u64.len() {
if i != j {
let (d, _, _) = BigUintDig::from(moduli_u64[i])
.extended_gcd(&BigUintDig::from(moduli_u64[j]));
if d.cmp(&BigIntDig::from(1)) != Ordering::Equal {
return Err(Error::Default("The moduli are not coprime".to_string()));
}
}
}
product *= &BigUint::from(moduli_u64[i]);
product_dig *= &BigUintDig::from(moduli_u64[i]);
}
product *= &BigUint::from(moduli_u64[i]);
product_dig *= &BigUintDig::from(moduli_u64[i]);
}
for modulus in moduli_u64 {
moduli.push(Modulus::new(*modulus)?);
// q* = product / modulus
let q_star_i = &product / modulus;
// q~ = (product / modulus) ^ (-1) % modulus
let q_tilde_i = (&product_dig / modulus)
.mod_inverse(&BigUintDig::from(*modulus))
.unwrap()
.to_u64()
.unwrap();
// garner = (q*) * (q~)
let garner_i = &q_star_i * q_tilde_i;
q_tilde.push(q_tilde_i);
garner.push(garner_i);
q_star.push(q_star_i);
q_tilde_shoup.push(
Modulus::new(*modulus)
.unwrap()
.shoup(q_tilde_i.to_u64().unwrap()),
);
}
for modulus in moduli_u64 {
moduli.push(Modulus::new(*modulus)?);
// q* = product / modulus
let q_star_i = &product / modulus;
// q~ = (product / modulus) ^ (-1) % modulus
let q_tilde_i = (&product_dig / modulus)
.mod_inverse(&BigUintDig::from(*modulus))
.unwrap()
.to_u64()
.unwrap();
// garner = (q*) * (q~)
let garner_i = &q_star_i * q_tilde_i;
q_tilde.push(q_tilde_i);
garner.push(garner_i);
q_star.push(q_star_i);
q_tilde_shoup.push(
Modulus::new(*modulus)
.unwrap()
.shoup(q_tilde_i.to_u64().unwrap()),
);
}
Ok(Self {
moduli_u64: moduli_u64.to_owned(),
moduli,
q_tilde,
q_tilde_shoup,
q_star,
garner,
product,
})
}
}
Ok(Self {
moduli_u64: moduli_u64.to_owned(),
moduli,
q_tilde,
q_tilde_shoup,
q_star,
garner,
product,
})
}
}
/// Returns the product of the moduli used when creating the RNS context.
pub const fn modulus(&self) -> &BigUint {
&self.product
}
/// Returns the product of the moduli used when creating the RNS context.
pub const fn modulus(&self) -> &BigUint {
&self.product
}
/// Project a BigUint into its rests.
pub fn project(&self, a: &BigUint) -> Vec<u64> {
let mut rests = Vec::with_capacity(self.moduli_u64.len());
for modulus in &self.moduli_u64 {
rests.push((a % modulus).to_u64().unwrap())
}
rests
}
/// Project a BigUint into its rests.
pub fn project(&self, a: &BigUint) -> Vec<u64> {
let mut rests = Vec::with_capacity(self.moduli_u64.len());
for modulus in &self.moduli_u64 {
rests.push((a % modulus).to_u64().unwrap())
}
rests
}
/// Lift rests into a BigUint.
///
/// Aborts if the number of rests is different than the number of moduli in
/// debug mode.
pub fn lift(&self, rests: ArrayView1<u64>) -> BigUint {
let mut result = BigUint::zero();
izip!(rests.iter(), self.garner.iter())
.for_each(|(r_i, garner_i)| result += garner_i * *r_i);
result % &self.product
}
/// Lift rests into a BigUint.
///
/// Aborts if the number of rests is different than the number of moduli in
/// debug mode.
pub fn lift(&self, rests: ArrayView1<u64>) -> BigUint {
let mut result = BigUint::zero();
izip!(rests.iter(), self.garner.iter())
.for_each(|(r_i, garner_i)| result += garner_i * *r_i);
result % &self.product
}
/// Getter for the i-th garner coefficient.
pub fn get_garner(&self, i: usize) -> Option<&BigUint> {
self.garner.get(i)
}
/// Getter for the i-th garner coefficient.
pub fn get_garner(&self, i: usize) -> Option<&BigUint> {
self.garner.get(i)
}
}
#[cfg(test)]
mod tests {
use std::error::Error;
use std::error::Error;
use super::RnsContext;
use ndarray::ArrayView1;
use num_bigint::BigUint;
use rand::RngCore;
use super::RnsContext;
use ndarray::ArrayView1;
use num_bigint::BigUint;
use rand::RngCore;
#[test]
fn constructor() {
assert!(RnsContext::new(&[2]).is_ok());
assert!(RnsContext::new(&[2, 3]).is_ok());
assert!(RnsContext::new(&[4, 15, 1153]).is_ok());
#[test]
fn constructor() {
assert!(RnsContext::new(&[2]).is_ok());
assert!(RnsContext::new(&[2, 3]).is_ok());
assert!(RnsContext::new(&[4, 15, 1153]).is_ok());
let e = RnsContext::new(&[]);
assert!(e.is_err());
assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty");
let e = RnsContext::new(&[2, 4]);
assert!(e.is_err());
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
let e = RnsContext::new(&[2, 3, 5, 30]);
assert!(e.is_err());
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
}
let e = RnsContext::new(&[]);
assert!(e.is_err());
assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty");
let e = RnsContext::new(&[2, 4]);
assert!(e.is_err());
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
let e = RnsContext::new(&[2, 3, 5, 30]);
assert!(e.is_err());
assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
}
#[test]
fn garner() -> Result<(), Box<dyn Error>> {
let rns = RnsContext::new(&[4, 15, 1153])?;
#[test]
fn garner() -> Result<(), Box<dyn Error>> {
let rns = RnsContext::new(&[4, 15, 1153])?;
for i in 0..3 {
let gi = rns.get_garner(i);
assert!(gi.is_some());
assert_eq!(gi.unwrap(), &rns.garner[i]);
}
assert!(rns.get_garner(3).is_none());
for i in 0..3 {
let gi = rns.get_garner(i);
assert!(gi.is_some());
assert_eq!(gi.unwrap(), &rns.garner[i]);
}
assert!(rns.get_garner(3).is_none());
Ok(())
}
Ok(())
}
#[test]
fn modulus() -> Result<(), Box<dyn Error>> {
let mut rns = RnsContext::new(&[2])?;
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64));
#[test]
fn modulus() -> Result<(), Box<dyn Error>> {
let mut rns = RnsContext::new(&[2])?;
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64));
rns = RnsContext::new(&[2, 5])?;
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5));
rns = RnsContext::new(&[2, 5])?;
debug_assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5));
rns = RnsContext::new(&[4, 15, 1153])?;
debug_assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153));
rns = RnsContext::new(&[4, 15, 1153])?;
debug_assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153));
Ok(())
}
Ok(())
}
#[test]
fn project_lift() -> Result<(), Box<dyn Error>> {
let ntests = 100;
let rns = RnsContext::new(&[4, 15, 1153])?;
let product = 4u64 * 15 * 1153;
#[test]
fn project_lift() -> Result<(), Box<dyn Error>> {
let ntests = 100;
let rns = RnsContext::new(&[4, 15, 1153])?;
let product = 4u64 * 15 * 1153;
let mut rests = rns.project(&BigUint::from(0u64));
assert_eq!(&rests, &[0u64, 0, 0]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64));
let mut rests = rns.project(&BigUint::from(0u64));
assert_eq!(&rests, &[0u64, 0, 0]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64));
rests = rns.project(&BigUint::from(4u64));
assert_eq!(&rests, &[0u64, 4, 4]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64));
rests = rns.project(&BigUint::from(4u64));
assert_eq!(&rests, &[0u64, 4, 4]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64));
rests = rns.project(&BigUint::from(15u64));
assert_eq!(&rests, &[3u64, 0, 15]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64));
rests = rns.project(&BigUint::from(15u64));
assert_eq!(&rests, &[3u64, 0, 15]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64));
rests = rns.project(&BigUint::from(1153u64));
assert_eq!(&rests, &[1u64, 13, 0]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64));
rests = rns.project(&BigUint::from(1153u64));
assert_eq!(&rests, &[1u64, 13, 0]);
assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64));
rests = rns.project(&BigUint::from(product - 1));
assert_eq!(&rests, &[3u64, 14, 1152]);
assert_eq!(
rns.lift(ArrayView1::from(&rests)),
BigUint::from(product - 1)
);
rests = rns.project(&BigUint::from(product - 1));
assert_eq!(&rests, &[3u64, 14, 1152]);
assert_eq!(
rns.lift(ArrayView1::from(&rests)),
BigUint::from(product - 1)
);
let mut rng = rand::thread_rng();
let mut rng = rand::thread_rng();
for _ in 0..ntests {
let b = BigUint::from(rng.next_u64() % product);
rests = rns.project(&b);
assert_eq!(rns.lift(ArrayView1::from(&rests)), b);
}
for _ in 0..ntests {
let b = BigUint::from(rng.next_u64() % product);
rests = rns.project(&b);
assert_eq!(rns.lift(ArrayView1::from(&rests)), b);
}
Ok(())
}
Ok(())
}
}

View File

@@ -14,481 +14,481 @@ use std::{cmp::min, sync::Arc};
/// Scaling factor when performing a RNS scaling.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ScalingFactor {
numerator: BigUint,
denominator: BigUint,
pub(crate) is_one: bool,
numerator: BigUint,
denominator: BigUint,
pub(crate) is_one: bool,
}
impl ScalingFactor {
/// Create a new scaling factor. Aborts if the denominator is 0.
pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self {
assert_ne!(denominator, &BigUint::zero());
Self {
numerator: numerator.clone(),
denominator: denominator.clone(),
is_one: numerator == denominator,
}
}
/// Create a new scaling factor. Aborts if the denominator is 0.
pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self {
assert_ne!(denominator, &BigUint::zero());
Self {
numerator: numerator.clone(),
denominator: denominator.clone(),
is_one: numerator == denominator,
}
}
/// Returns the identity element of `Self`.
pub fn one() -> Self {
Self {
numerator: BigUint::one(),
denominator: BigUint::one(),
is_one: true,
}
}
/// Returns the identity element of `Self`.
pub fn one() -> Self {
Self {
numerator: BigUint::one(),
denominator: BigUint::one(),
is_one: true,
}
}
}
/// Scaler for a RNS context.
/// This is a helper struct to perform RNS scaling.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct RnsScaler {
from: Arc<RnsContext>,
to: Arc<RnsContext>,
scaling_factor: ScalingFactor,
from: Arc<RnsContext>,
to: Arc<RnsContext>,
scaling_factor: ScalingFactor,
gamma: Box<[u64]>,
gamma_shoup: Box<[u64]>,
theta_gamma_lo: u64,
theta_gamma_hi: u64,
theta_gamma_sign: bool,
gamma: Box<[u64]>,
gamma_shoup: Box<[u64]>,
theta_gamma_lo: u64,
theta_gamma_hi: u64,
theta_gamma_sign: bool,
omega: Box<[Box<[u64]>]>,
omega_shoup: Box<[Box<[u64]>]>,
theta_omega_lo: Box<[u64]>,
theta_omega_hi: Box<[u64]>,
theta_omega_sign: Box<[bool]>,
omega: Box<[Box<[u64]>]>,
omega_shoup: Box<[Box<[u64]>]>,
theta_omega_lo: Box<[u64]>,
theta_omega_hi: Box<[u64]>,
theta_omega_sign: Box<[bool]>,
theta_garner_lo: Box<[u64]>,
theta_garner_hi: Box<[u64]>,
theta_garner_shift: usize,
theta_garner_lo: Box<[u64]>,
theta_garner_hi: Box<[u64]>,
theta_garner_shift: usize,
}
impl RnsScaler {
/// Create a RNS scaler by numerator / denominator.
///
/// Aborts if denominator is equal to 0.
pub fn new(
from: &Arc<RnsContext>,
to: &Arc<RnsContext>,
scaling_factor: ScalingFactor,
) -> Self {
// Let's define gamma = round(numerator * from.product / denominator)
let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
Self::extract_projection_and_theta(
to,
&from.product,
&scaling_factor.numerator,
&scaling_factor.denominator,
false,
);
let gamma_shoup = izip!(&gamma, &to.moduli)
.map(|(wi, q)| q.shoup(*wi))
.collect_vec();
/// Create a RNS scaler by numerator / denominator.
///
/// Aborts if denominator is equal to 0.
pub fn new(
from: &Arc<RnsContext>,
to: &Arc<RnsContext>,
scaling_factor: ScalingFactor,
) -> Self {
// Let's define gamma = round(numerator * from.product / denominator)
let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
Self::extract_projection_and_theta(
to,
&from.product,
&scaling_factor.numerator,
&scaling_factor.denominator,
false,
);
let gamma_shoup = izip!(&gamma, &to.moduli)
.map(|(wi, q)| q.shoup(*wi))
.collect_vec();
// Let's define omega_i = round(from.garner_i * numerator / denominator)
let mut omega = Vec::with_capacity(to.moduli.len());
let mut omega_shoup = Vec::with_capacity(to.moduli.len());
for _ in &to.moduli {
omega.push(vec![0u64; from.moduli.len()].into_boxed_slice());
omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice());
}
let mut theta_omega_lo = Vec::with_capacity(from.garner.len());
let mut theta_omega_hi = Vec::with_capacity(from.garner.len());
let mut theta_omega_sign = Vec::with_capacity(from.garner.len());
for i in 0..from.garner.len() {
let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) =
Self::extract_projection_and_theta(
to,
&from.garner[i],
&scaling_factor.numerator,
&scaling_factor.denominator,
true,
);
for j in 0..to.moduli.len() {
let qj = &to.moduli[j];
omega[j][i] = qj.reduce(omega_i[j]);
omega_shoup[j][i] = qj.shoup(omega[j][i]);
}
theta_omega_lo.push(theta_omega_i_lo);
theta_omega_hi.push(theta_omega_i_hi);
theta_omega_sign.push(theta_omega_i_sign);
}
// Let's define omega_i = round(from.garner_i * numerator / denominator)
let mut omega = Vec::with_capacity(to.moduli.len());
let mut omega_shoup = Vec::with_capacity(to.moduli.len());
for _ in &to.moduli {
omega.push(vec![0u64; from.moduli.len()].into_boxed_slice());
omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice());
}
let mut theta_omega_lo = Vec::with_capacity(from.garner.len());
let mut theta_omega_hi = Vec::with_capacity(from.garner.len());
let mut theta_omega_sign = Vec::with_capacity(from.garner.len());
for i in 0..from.garner.len() {
let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) =
Self::extract_projection_and_theta(
to,
&from.garner[i],
&scaling_factor.numerator,
&scaling_factor.denominator,
true,
);
for j in 0..to.moduli.len() {
let qj = &to.moduli[j];
omega[j][i] = qj.reduce(omega_i[j]);
omega_shoup[j][i] = qj.shoup(omega[j][i]);
}
theta_omega_lo.push(theta_omega_i_lo);
theta_omega_hi.push(theta_omega_i_hi);
theta_omega_sign.push(theta_omega_i_sign);
}
// Determine the shift so that the sum of the scaled theta_garner fit on an U192
// (shift + 1) + log(q * n) <= 192
let theta_garner_shift = min(
from.moduli_u64
.iter()
.map(|qi| {
192 - 1
- ilog2(
((*qi as u128) * (from.moduli_u64.len() as u128)).next_power_of_two(),
)
})
.min()
.unwrap(),
127,
);
// Finally, define theta_garner_i = from.garner_i / product, also scaled by
// 2^127.
let mut theta_garner_lo = Vec::with_capacity(from.garner.len());
let mut theta_garner_hi = Vec::with_capacity(from.garner.len());
for garner_i in &from.garner {
let mut theta: BigUint =
((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
let theta_hi: BigUint = &theta >> 64;
theta -= &theta_hi << 64;
theta_garner_lo.push(theta.to_u64().unwrap());
theta_garner_hi.push(theta_hi.to_u64().unwrap());
}
// Determine the shift so that the sum of the scaled theta_garner fit on an U192
// (shift + 1) + log(q * n) <= 192
let theta_garner_shift = min(
from.moduli_u64
.iter()
.map(|qi| {
192 - 1
- ilog2(
((*qi as u128) * (from.moduli_u64.len() as u128)).next_power_of_two(),
)
})
.min()
.unwrap(),
127,
);
// Finally, define theta_garner_i = from.garner_i / product, also scaled by
// 2^127.
let mut theta_garner_lo = Vec::with_capacity(from.garner.len());
let mut theta_garner_hi = Vec::with_capacity(from.garner.len());
for garner_i in &from.garner {
let mut theta: BigUint =
((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
let theta_hi: BigUint = &theta >> 64;
theta -= &theta_hi << 64;
theta_garner_lo.push(theta.to_u64().unwrap());
theta_garner_hi.push(theta_hi.to_u64().unwrap());
}
Self {
from: from.clone(),
to: to.clone(),
scaling_factor,
gamma: gamma.into_boxed_slice(),
gamma_shoup: gamma_shoup.into_boxed_slice(),
theta_gamma_lo,
theta_gamma_hi,
theta_gamma_sign,
omega: omega.into_boxed_slice(),
omega_shoup: omega_shoup.into_boxed_slice(),
theta_omega_lo: theta_omega_lo.into_boxed_slice(),
theta_omega_hi: theta_omega_hi.into_boxed_slice(),
theta_omega_sign: theta_omega_sign.into_boxed_slice(),
theta_garner_lo: theta_garner_lo.into_boxed_slice(),
theta_garner_hi: theta_garner_hi.into_boxed_slice(),
theta_garner_shift,
}
}
Self {
from: from.clone(),
to: to.clone(),
scaling_factor,
gamma: gamma.into_boxed_slice(),
gamma_shoup: gamma_shoup.into_boxed_slice(),
theta_gamma_lo,
theta_gamma_hi,
theta_gamma_sign,
omega: omega.into_boxed_slice(),
omega_shoup: omega_shoup.into_boxed_slice(),
theta_omega_lo: theta_omega_lo.into_boxed_slice(),
theta_omega_hi: theta_omega_hi.into_boxed_slice(),
theta_omega_sign: theta_omega_sign.into_boxed_slice(),
theta_garner_lo: theta_garner_lo.into_boxed_slice(),
theta_garner_hi: theta_garner_hi.into_boxed_slice(),
theta_garner_shift,
}
}
// Let's define gamma = round(numerator * input / denominator)
// and theta_gamma such that theta_gamma = numerator * input / denominator -
// gamma. This function projects gamma in the RNS context, and scales
// theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the
// RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma =
// (-1)**theta_sign * (theta_lo + 2^64 * theta_hi).
fn extract_projection_and_theta(
ctx: &RnsContext,
input: &BigUint,
numerator: &BigUint,
denominator: &BigUint,
round_up: bool,
) -> (Vec<u64>, u64, u64, bool) {
let gamma = (numerator * input + (denominator >> 1)) / denominator;
let projected = ctx.project(&gamma);
// Let's define gamma = round(numerator * input / denominator)
// and theta_gamma such that theta_gamma = numerator * input / denominator -
// gamma. This function projects gamma in the RNS context, and scales
// theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the
// RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma =
// (-1)**theta_sign * (theta_lo + 2^64 * theta_hi).
fn extract_projection_and_theta(
ctx: &RnsContext,
input: &BigUint,
numerator: &BigUint,
denominator: &BigUint,
round_up: bool,
) -> (Vec<u64>, u64, u64, bool) {
let gamma = (numerator * input + (denominator >> 1)) / denominator;
let projected = ctx.project(&gamma);
let mut theta = (numerator * input) % denominator;
let mut theta_sign = false;
if denominator > &BigUint::one() {
// If denominator is odd, flip theta if theta > (denominator >> 1)
if denominator & BigUint::one() == BigUint::one() {
if theta > (denominator >> 1) {
theta_sign = true;
theta = denominator - theta;
}
} else {
// denominator is even, flip if theta >= (denominator >> 1)
if theta >= (denominator >> 1) {
theta_sign = true;
theta = denominator - theta;
}
}
}
// theta = ((theta << 127) + (denominator >> 1)) / denominator;
// We can now split theta into two u64 words.
if round_up {
if theta_sign {
theta = (theta << 127) / denominator;
} else {
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
}
} else if theta_sign {
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
} else {
theta = (theta << 127) / denominator;
}
let theta_hi_biguint: BigUint = &theta >> 64;
theta -= &theta_hi_biguint << 64;
let theta_lo = theta.to_u64().unwrap();
let theta_hi = theta_hi_biguint.to_u64().unwrap();
let mut theta = (numerator * input) % denominator;
let mut theta_sign = false;
if denominator > &BigUint::one() {
// If denominator is odd, flip theta if theta > (denominator >> 1)
if denominator & BigUint::one() == BigUint::one() {
if theta > (denominator >> 1) {
theta_sign = true;
theta = denominator - theta;
}
} else {
// denominator is even, flip if theta >= (denominator >> 1)
if theta >= (denominator >> 1) {
theta_sign = true;
theta = denominator - theta;
}
}
}
// theta = ((theta << 127) + (denominator >> 1)) / denominator;
// We can now split theta into two u64 words.
if round_up {
if theta_sign {
theta = (theta << 127) / denominator;
} else {
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
}
} else if theta_sign {
theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
} else {
theta = (theta << 127) / denominator;
}
let theta_hi_biguint: BigUint = &theta >> 64;
theta -= &theta_hi_biguint << 64;
let theta_lo = theta.to_u64().unwrap();
let theta_hi = theta_hi_biguint.to_u64().unwrap();
(projected, theta_lo, theta_hi, theta_sign)
}
(projected, theta_lo, theta_hi, theta_sign)
}
/// Output the RNS representation of the rests scaled by numerator *
/// denominator, and either rounded or floored.
///
/// Aborts if the number of rests is different than the number of moduli in
/// debug mode, or if the size is not in [1, ..., rests.len()].
pub fn scale_new(&self, rests: ArrayView1<u64>, size: usize) -> Vec<u64> {
let mut out = vec![0; size];
self.scale(rests, (&mut out).into(), 0);
out
}
/// Output the RNS representation of the rests scaled by numerator *
/// denominator, and either rounded or floored.
///
/// Aborts if the number of rests is different than the number of moduli in
/// debug mode, or if the size is not in [1, ..., rests.len()].
pub fn scale_new(&self, rests: ArrayView1<u64>, size: usize) -> Vec<u64> {
let mut out = vec![0; size];
self.scale(rests, (&mut out).into(), 0);
out
}
/// Compute the RNS representation of the rests scaled by numerator *
/// denominator, and either rounded or floored, and store the result in
/// `out`.
///
/// Aborts if the number of rests is different than the number of moduli in
/// debug mode, or if the size of out is not in [1, ..., rests.len()].
pub fn scale(
&self,
rests: ArrayView1<u64>,
mut out: ArrayViewMut1<u64>,
starting_index: usize,
) {
debug_assert_eq!(rests.len(), self.from.moduli_u64.len());
debug_assert!(!out.is_empty());
debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());
/// Compute the RNS representation of the rests scaled by numerator *
/// denominator, and either rounded or floored, and store the result in
/// `out`.
///
/// Aborts if the number of rests is different than the number of moduli in
/// debug mode, or if the size of out is not in [1, ..., rests.len()].
pub fn scale(
&self,
rests: ArrayView1<u64>,
mut out: ArrayViewMut1<u64>,
starting_index: usize,
) {
debug_assert_eq!(rests.len(), self.from.moduli_u64.len());
debug_assert!(!out.is_empty());
debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());
// First, let's compute the inner product of the rests with theta_omega.
let mut sum_theta_garner = U192::ZERO;
for (thetag_lo, thetag_hi, ri) in izip!(
self.theta_garner_lo.iter(),
self.theta_garner_hi.iter(),
rests
) {
let lo = (*ri as u128) * (*thetag_lo as u128);
let hi = (*ri as u128) * (*thetag_hi as u128) + (lo >> 64);
sum_theta_garner = sum_theta_garner.wrapping_add(&U192::from_words([
lo as u64,
hi as u64,
(hi >> 64) as u64,
]));
}
// Let's compute v = round(sum_theta_garner / 2^theta_garner_shift)
sum_theta_garner >>= self.theta_garner_shift - 1;
let v = <[u64; 3]>::from(sum_theta_garner);
let v = div_ceil((v[0] as u128) | ((v[1] as u128) << 64), 2);
// First, let's compute the inner product of the rests with theta_omega.
let mut sum_theta_garner = U192::ZERO;
for (thetag_lo, thetag_hi, ri) in izip!(
self.theta_garner_lo.iter(),
self.theta_garner_hi.iter(),
rests
) {
let lo = (*ri as u128) * (*thetag_lo as u128);
let hi = (*ri as u128) * (*thetag_hi as u128) + (lo >> 64);
sum_theta_garner = sum_theta_garner.wrapping_add(&U192::from_words([
lo as u64,
hi as u64,
(hi >> 64) as u64,
]));
}
// Let's compute v = round(sum_theta_garner / 2^theta_garner_shift)
sum_theta_garner >>= self.theta_garner_shift - 1;
let v = <[u64; 3]>::from(sum_theta_garner);
let v = div_ceil((v[0] as u128) | ((v[1] as u128) << 64), 2);
// If the scaling factor is not 1, compute the inner product with the
// theta_omega
let mut w_sign = false;
let mut w = 0u128;
if !self.scaling_factor.is_one {
let mut sum_theta_omega = U256::zero();
for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
self.theta_omega_lo.iter(),
self.theta_omega_hi.iter(),
self.theta_omega_sign.iter(),
rests
) {
let lo = (*ri as u128) * (*thetao_lo as u128);
let hi = (*ri as u128) * (*thetao_hi as u128) + (lo >> 64);
if *thetao_sign {
sum_theta_omega.wrapping_sub_assign(U256::from([
lo as u64,
hi as u64,
(hi >> 64) as u64,
0,
]));
} else {
sum_theta_omega.wrapping_add_assign(U256::from([
lo as u64,
hi as u64,
(hi >> 64) as u64,
0,
]));
}
}
// If the scaling factor is not 1, compute the inner product with the
// theta_omega
let mut w_sign = false;
let mut w = 0u128;
if !self.scaling_factor.is_one {
let mut sum_theta_omega = U256::zero();
for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
self.theta_omega_lo.iter(),
self.theta_omega_hi.iter(),
self.theta_omega_sign.iter(),
rests
) {
let lo = (*ri as u128) * (*thetao_lo as u128);
let hi = (*ri as u128) * (*thetao_hi as u128) + (lo >> 64);
if *thetao_sign {
sum_theta_omega.wrapping_sub_assign(U256::from([
lo as u64,
hi as u64,
(hi >> 64) as u64,
0,
]));
} else {
sum_theta_omega.wrapping_add_assign(U256::from([
lo as u64,
hi as u64,
(hi >> 64) as u64,
0,
]));
}
}
// Let's subtract v * theta_gamma to sum_theta_omega.
let vt_lo_lo = ((v as u64) as u128) * (self.theta_gamma_lo as u128);
let vt_lo_hi = ((v as u64) as u128) * (self.theta_gamma_hi as u128);
let vt_hi_lo = (v >> 64) * (self.theta_gamma_lo as u128);
let vt_hi_hi = (v >> 64) * (self.theta_gamma_hi as u128);
let vt_mi =
(vt_lo_lo >> 64) + ((vt_lo_hi as u64) as u128) + ((vt_hi_lo as u64) as u128);
let vt_hi = (vt_lo_hi >> 64) + (vt_mi >> 64) + ((vt_hi_hi as u64) as u128);
if self.theta_gamma_sign {
sum_theta_omega.wrapping_add_assign(U256::from([
vt_lo_lo as u64,
vt_mi as u64,
vt_hi as u64,
0,
]))
} else {
sum_theta_omega.wrapping_sub_assign(U256::from([
vt_lo_lo as u64,
vt_mi as u64,
vt_hi as u64,
0,
]))
}
// Let's subtract v * theta_gamma to sum_theta_omega.
let vt_lo_lo = ((v as u64) as u128) * (self.theta_gamma_lo as u128);
let vt_lo_hi = ((v as u64) as u128) * (self.theta_gamma_hi as u128);
let vt_hi_lo = (v >> 64) * (self.theta_gamma_lo as u128);
let vt_hi_hi = (v >> 64) * (self.theta_gamma_hi as u128);
let vt_mi =
(vt_lo_lo >> 64) + ((vt_lo_hi as u64) as u128) + ((vt_hi_lo as u64) as u128);
let vt_hi = (vt_lo_hi >> 64) + (vt_mi >> 64) + ((vt_hi_hi as u64) as u128);
if self.theta_gamma_sign {
sum_theta_omega.wrapping_add_assign(U256::from([
vt_lo_lo as u64,
vt_mi as u64,
vt_hi as u64,
0,
]))
} else {
sum_theta_omega.wrapping_sub_assign(U256::from([
vt_lo_lo as u64,
vt_mi as u64,
vt_hi as u64,
0,
]))
}
// Let's compute w = round(sum_theta_omega / 2^127).
w_sign = sum_theta_omega.msb() > 0;
// Let's compute w = round(sum_theta_omega / 2^127).
w_sign = sum_theta_omega.msb() > 0;
if w_sign {
w = u128::from(&((!sum_theta_omega) >> 126)) + 1;
w /= 2;
} else {
w = u128::from(&(sum_theta_omega >> 126));
w = div_ceil(w, 2)
}
}
if w_sign {
w = u128::from(&((!sum_theta_omega) >> 126)) + 1;
w /= 2;
} else {
w = u128::from(&(sum_theta_omega >> 126));
w = div_ceil(w, 2)
}
}
unsafe {
for i in 0..out.len() {
debug_assert!(starting_index + i < self.to.moduli.len());
debug_assert!(starting_index + i < self.omega.len());
debug_assert!(starting_index + i < self.omega_shoup.len());
debug_assert!(starting_index + i < self.gamma.len());
debug_assert!(starting_index + i < self.gamma_shoup.len());
let out_i = out.get_mut(i).unwrap();
let qi = self.to.moduli.get_unchecked(starting_index + i);
let omega_i = self.omega.get_unchecked(starting_index + i);
let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i);
let gamma_i = self.gamma.get_unchecked(starting_index + i);
let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i);
unsafe {
for i in 0..out.len() {
debug_assert!(starting_index + i < self.to.moduli.len());
debug_assert!(starting_index + i < self.omega.len());
debug_assert!(starting_index + i < self.omega_shoup.len());
debug_assert!(starting_index + i < self.gamma.len());
debug_assert!(starting_index + i < self.gamma_shoup.len());
let out_i = out.get_mut(i).unwrap();
let qi = self.to.moduli.get_unchecked(starting_index + i);
let omega_i = self.omega.get_unchecked(starting_index + i);
let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i);
let gamma_i = self.gamma.get_unchecked(starting_index + i);
let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i);
let mut yi = (qi.modulus() * 2
- qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i))
as u128;
let mut yi = (qi.modulus() * 2
- qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i))
as u128;
if !self.scaling_factor.is_one {
let wi = qi.lazy_reduce_u128(w);
yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128;
}
if !self.scaling_factor.is_one {
let wi = qi.lazy_reduce_u128(w);
yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128;
}
debug_assert!(rests.len() <= omega_i.len());
debug_assert!(rests.len() <= omega_shoup_i.len());
for j in 0..rests.len() {
yi += qi.lazy_mul_shoup(
*rests.get(j).unwrap(),
*omega_i.get_unchecked(j),
*omega_shoup_i.get_unchecked(j),
) as u128;
}
debug_assert!(rests.len() <= omega_i.len());
debug_assert!(rests.len() <= omega_shoup_i.len());
for j in 0..rests.len() {
yi += qi.lazy_mul_shoup(
*rests.get(j).unwrap(),
*omega_i.get_unchecked(j),
*omega_shoup_i.get_unchecked(j),
) as u128;
}
*out_i = qi.reduce_u128(yi)
}
}
}
*out_i = qi.reduce_u128(yi)
}
}
}
}
#[cfg(test)]
mod tests {
use std::{error::Error, sync::Arc};
use std::{error::Error, sync::Arc};
use super::RnsScaler;
use crate::rns::{scaler::ScalingFactor, RnsContext};
use fhe_util::catch_unwind;
use ndarray::ArrayView1;
use num_bigint::BigUint;
use num_traits::{ToPrimitive, Zero};
use rand::{thread_rng, RngCore};
use super::RnsScaler;
use crate::rns::{scaler::ScalingFactor, RnsContext};
use fhe_util::catch_unwind;
use ndarray::ArrayView1;
use num_bigint::BigUint;
use num_traits::{ToPrimitive, Zero};
use rand::{thread_rng, RngCore};
#[test]
fn constructor() -> Result<(), Box<dyn Error>> {
let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?);
#[test]
fn constructor() -> Result<(), Box<dyn Error>> {
let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?);
let scaler = RnsScaler::new(&q, &q, ScalingFactor::one());
assert_eq!(scaler.from, q);
let scaler = RnsScaler::new(&q, &q, ScalingFactor::one());
assert_eq!(scaler.from, q);
assert!(
catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err()
);
Ok(())
}
assert!(
catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err()
);
Ok(())
}
#[test]
fn scale_same_context() -> Result<(), Box<dyn Error>> {
let ntests = 1000;
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
let mut rng = thread_rng();
#[test]
fn scale_same_context() -> Result<(), Box<dyn Error>> {
let ntests = 1000;
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
let mut rng = thread_rng();
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
let n = BigUint::from(*numerator);
let d = BigUint::from(*denominator);
let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d));
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
let n = BigUint::from(*numerator);
let d = BigUint::from(*denominator);
let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d));
for _ in 0..ntests {
let x = vec![
rng.next_u64() % q.moduli_u64[0],
rng.next_u64() % q.moduli_u64[1],
rng.next_u64() % q.moduli_u64[2],
];
let mut x_lift = q.lift(ArrayView1::from(&x));
let x_sign = x_lift >= (q.modulus() >> 1);
if x_sign {
x_lift = q.modulus() - x_lift;
}
for _ in 0..ntests {
let x = vec![
rng.next_u64() % q.moduli_u64[0],
rng.next_u64() % q.moduli_u64[1],
rng.next_u64() % q.moduli_u64[2],
];
let mut x_lift = q.lift(ArrayView1::from(&x));
let x_sign = x_lift >= (q.modulus() >> 1);
if x_sign {
x_lift = q.modulus() - x_lift;
}
let z = scaler.scale_new((&x).into(), x.len());
let x_scaled_round = if x_sign {
if d.to_u64().unwrap() % 2 == 0 {
q.modulus()
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % q.modulus()
} else {
q.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % q.modulus()
}
} else {
&(&x_lift * &n + (&d >> 1)) / &d
};
assert_eq!(z, q.project(&x_scaled_round));
}
}
}
Ok(())
}
let z = scaler.scale_new((&x).into(), x.len());
let x_scaled_round = if x_sign {
if d.to_u64().unwrap() % 2 == 0 {
q.modulus()
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % q.modulus()
} else {
q.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % q.modulus()
}
} else {
&(&x_lift * &n + (&d >> 1)) / &d
};
assert_eq!(z, q.project(&x_scaled_round));
}
}
}
Ok(())
}
#[test]
fn scale_different_contexts() -> Result<(), Box<dyn Error>> {
let ntests = 100;
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
let r = Arc::new(RnsContext::new(&[
4u64,
4611686018326724609,
1153,
4611686018309947393,
4611686018282684417,
4611686018257518593,
4611686018232352769,
4611686018171535361,
4611686018106523649,
4611686018058289153,
])?);
let mut rng = thread_rng();
#[test]
fn scale_different_contexts() -> Result<(), Box<dyn Error>> {
let ntests = 100;
let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
let r = Arc::new(RnsContext::new(&[
4u64,
4611686018326724609,
1153,
4611686018309947393,
4611686018282684417,
4611686018257518593,
4611686018232352769,
4611686018171535361,
4611686018106523649,
4611686018058289153,
])?);
let mut rng = thread_rng();
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
let n = BigUint::from(*numerator);
let d = BigUint::from(*denominator);
let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d));
for _ in 0..ntests {
let x = vec![
rng.next_u64() % q.moduli_u64[0],
rng.next_u64() % q.moduli_u64[1],
rng.next_u64() % q.moduli_u64[2],
];
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
let n = BigUint::from(*numerator);
let d = BigUint::from(*denominator);
let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d));
for _ in 0..ntests {
let x = vec![
rng.next_u64() % q.moduli_u64[0],
rng.next_u64() % q.moduli_u64[1],
rng.next_u64() % q.moduli_u64[2],
];
let mut x_lift = q.lift(ArrayView1::from(&x));
let x_sign = x_lift >= (q.modulus() >> 1);
if x_sign {
x_lift = q.modulus() - x_lift;
}
let mut x_lift = q.lift(ArrayView1::from(&x));
let x_sign = x_lift >= (q.modulus() >> 1);
if x_sign {
x_lift = q.modulus() - x_lift;
}
let y = scaler.scale_new((&x).into(), r.moduli.len());
let x_scaled_round = if x_sign {
if d.to_u64().unwrap() % 2 == 0 {
r.modulus()
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % r.modulus()
} else {
r.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % r.modulus()
}
} else {
&(&x_lift * &n + (&d >> 1)) / &d
};
assert_eq!(y, r.project(&x_scaled_round));
}
}
}
Ok(())
}
let y = scaler.scale_new((&x).into(), r.moduli.len());
let x_scaled_round = if x_sign {
if d.to_u64().unwrap() % 2 == 0 {
r.modulus()
- (&(&x_lift * &n + ((&d >> 1usize) - 1u64)) / &d) % r.modulus()
} else {
r.modulus() - (&(&x_lift * &n + (&d >> 1)) / &d) % r.modulus()
}
} else {
&(&x_lift * &n + (&d >> 1)) / &d
};
assert_eq!(y, r.project(&x_scaled_round));
}
}
}
Ok(())
}
}

View File

@@ -3,233 +3,233 @@ use num_bigint::BigUint;
use std::{fmt::Debug, sync::Arc};
use crate::{
rns::RnsContext,
zq::{ntt::NttOperator, Modulus},
Error, Result,
rns::RnsContext,
zq::{ntt::NttOperator, Modulus},
Error, Result,
};
/// Struct that holds the context associated with elements in rq.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct Context {
pub(crate) moduli: Box<[u64]>,
pub(crate) q: Box<[Modulus]>,
pub(crate) rns: Arc<RnsContext>,
pub(crate) ops: Box<[NttOperator]>,
pub(crate) degree: usize,
pub(crate) bitrev: Box<[usize]>,
pub(crate) inv_last_qi_mod_qj: Box<[u64]>,
pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>,
pub(crate) next_context: Option<Arc<Context>>,
pub(crate) moduli: Box<[u64]>,
pub(crate) q: Box<[Modulus]>,
pub(crate) rns: Arc<RnsContext>,
pub(crate) ops: Box<[NttOperator]>,
pub(crate) degree: usize,
pub(crate) bitrev: Box<[usize]>,
pub(crate) inv_last_qi_mod_qj: Box<[u64]>,
pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>,
pub(crate) next_context: Option<Arc<Context>>,
}
impl Debug for Context {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Context")
.field("moduli", &self.moduli)
// .field("q", &self.q)
// .field("rns", &self.rns)
// .field("ops", &self.ops)
// .field("degree", &self.degree)
// .field("bitrev", &self.bitrev)
// .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj)
// .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup)
.field("next_context", &self.next_context)
.finish()
}
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Context")
.field("moduli", &self.moduli)
// .field("q", &self.q)
// .field("rns", &self.rns)
// .field("ops", &self.ops)
// .field("degree", &self.degree)
// .field("bitrev", &self.bitrev)
// .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj)
// .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup)
.field("next_context", &self.next_context)
.finish()
}
}
impl Context {
/// Creates a context from a list of moduli and a polynomial degree.
///
/// Returns an error if the moduli are not primes less than 62 bits which
/// supports the NTT of size `degree`.
pub fn new(moduli: &[u64], degree: usize) -> Result<Self> {
if !degree.is_power_of_two() || degree < 8 {
Err(Error::Default(
"The degree is not a power of two larger or equal to 8".to_string(),
))
} else {
let mut q = Vec::with_capacity(moduli.len());
let rns = Arc::new(RnsContext::new(moduli)?);
let mut ops = Vec::with_capacity(moduli.len());
for modulus in moduli {
let qi = Modulus::new(*modulus)?;
if let Some(op) = NttOperator::new(&qi, degree) {
q.push(qi);
ops.push(op);
} else {
return Err(Error::Default(
"Impossible to construct a Ntt operator".to_string(),
));
}
}
let bitrev = (0..degree)
.map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1))
.collect_vec();
/// Creates a context from a list of moduli and a polynomial degree.
///
/// Returns an error if the moduli are not primes less than 62 bits which
/// supports the NTT of size `degree`.
pub fn new(moduli: &[u64], degree: usize) -> Result<Self> {
if !degree.is_power_of_two() || degree < 8 {
Err(Error::Default(
"The degree is not a power of two larger or equal to 8".to_string(),
))
} else {
let mut q = Vec::with_capacity(moduli.len());
let rns = Arc::new(RnsContext::new(moduli)?);
let mut ops = Vec::with_capacity(moduli.len());
for modulus in moduli {
let qi = Modulus::new(*modulus)?;
if let Some(op) = NttOperator::new(&qi, degree) {
q.push(qi);
ops.push(op);
} else {
return Err(Error::Default(
"Impossible to construct a Ntt operator".to_string(),
));
}
}
let bitrev = (0..degree)
.map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1))
.collect_vec();
let mut inv_last_qi_mod_qj = vec![];
let mut inv_last_qi_mod_qj_shoup = vec![];
let q_last = moduli.last().unwrap();
for qi in &q[..q.len() - 1] {
let inv = qi.inv(qi.reduce(*q_last)).unwrap();
inv_last_qi_mod_qj.push(inv);
inv_last_qi_mod_qj_shoup.push(qi.shoup(inv));
}
let mut inv_last_qi_mod_qj = vec![];
let mut inv_last_qi_mod_qj_shoup = vec![];
let q_last = moduli.last().unwrap();
for qi in &q[..q.len() - 1] {
let inv = qi.inv(qi.reduce(*q_last)).unwrap();
inv_last_qi_mod_qj.push(inv);
inv_last_qi_mod_qj_shoup.push(qi.shoup(inv));
}
let next_context = if moduli.len() >= 2 {
Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?))
} else {
None
};
let next_context = if moduli.len() >= 2 {
Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?))
} else {
None
};
Ok(Self {
moduli: moduli.to_owned().into_boxed_slice(),
q: q.into_boxed_slice(),
rns,
ops: ops.into_boxed_slice(),
degree,
bitrev: bitrev.into_boxed_slice(),
inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(),
inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(),
next_context,
})
}
}
Ok(Self {
moduli: moduli.to_owned().into_boxed_slice(),
q: q.into_boxed_slice(),
rns,
ops: ops.into_boxed_slice(),
degree,
bitrev: bitrev.into_boxed_slice(),
inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(),
inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(),
next_context,
})
}
}
/// Returns the modulus as a BigUint.
pub fn modulus(&self) -> &BigUint {
self.rns.modulus()
}
/// Returns the modulus as a BigUint.
pub fn modulus(&self) -> &BigUint {
self.rns.modulus()
}
/// Returns a reference to the moduli in this context.
pub fn moduli(&self) -> &[u64] {
&self.moduli
}
/// Returns a reference to the moduli in this context.
pub fn moduli(&self) -> &[u64] {
&self.moduli
}
/// Returns a reference to the moduli as Modulus in this context.
pub fn moduli_operators(&self) -> &[Modulus] {
&self.q
}
/// Returns a reference to the moduli as Modulus in this context.
pub fn moduli_operators(&self) -> &[Modulus] {
&self.q
}
/// Returns the number of iterations to switch to a children context.
/// Returns an error if the context provided is not a child context.
pub fn niterations_to(&self, context: &Arc<Context>) -> Result<usize> {
if context.as_ref() == self {
return Ok(0);
}
/// Returns the number of iterations to switch to a children context.
/// Returns an error if the context provided is not a child context.
pub fn niterations_to(&self, context: &Arc<Context>) -> Result<usize> {
if context.as_ref() == self {
return Ok(0);
}
let mut niterations = 0;
let mut found = false;
let mut current_ctx = Arc::new(self.clone());
while current_ctx.next_context.is_some() {
niterations += 1;
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
if &current_ctx == context {
found = true;
break;
}
}
if found {
Ok(niterations)
} else {
Err(Error::InvalidContext)
}
}
let mut niterations = 0;
let mut found = false;
let mut current_ctx = Arc::new(self.clone());
while current_ctx.next_context.is_some() {
niterations += 1;
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
if &current_ctx == context {
found = true;
break;
}
}
if found {
Ok(niterations)
} else {
Err(Error::InvalidContext)
}
}
/// Returns the context after `i` iterations.
pub fn context_at_level(&self, i: usize) -> Result<Arc<Self>> {
if i >= self.moduli.len() {
Err(Error::Default(
"No context at the specified level".to_string(),
))
} else {
let mut current_ctx = Arc::new(self.clone());
for _ in 0..i {
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
}
Ok(current_ctx)
}
}
/// Returns the context after `i` iterations.
pub fn context_at_level(&self, i: usize) -> Result<Arc<Self>> {
if i >= self.moduli.len() {
Err(Error::Default(
"No context at the specified level".to_string(),
))
} else {
let mut current_ctx = Arc::new(self.clone());
for _ in 0..i {
current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
}
Ok(current_ctx)
}
}
}
#[cfg(test)]
mod tests {
use std::{error::Error, sync::Arc};
use std::{error::Error, sync::Arc};
use crate::{rq::Context, zq::ntt::supports_ntt};
use crate::{rq::Context, zq::ntt::supports_ntt};
const MODULI: &[u64; 5] = &[
1153,
4611686018326724609,
4611686018309947393,
4611686018232352769,
4611686018171535361,
];
const MODULI: &[u64; 5] = &[
1153,
4611686018326724609,
4611686018309947393,
4611686018232352769,
4611686018171535361,
];
#[test]
fn context_constructor() {
for modulus in MODULI {
// modulus is = 1 modulo 2 * 8
assert!(Context::new(&[*modulus], 8).is_ok());
#[test]
fn context_constructor() {
for modulus in MODULI {
// modulus is = 1 modulo 2 * 8
assert!(Context::new(&[*modulus], 8).is_ok());
if supports_ntt(*modulus, 128) {
assert!(Context::new(&[*modulus], 128).is_ok());
} else {
assert!(Context::new(&[*modulus], 128).is_err());
}
}
if supports_ntt(*modulus, 128) {
assert!(Context::new(&[*modulus], 128).is_ok());
} else {
assert!(Context::new(&[*modulus], 128).is_err());
}
}
// All moduli in MODULI are = 1 modulo 2 * 8
assert!(Context::new(MODULI, 8).is_ok());
// All moduli in MODULI are = 1 modulo 2 * 8
assert!(Context::new(MODULI, 8).is_ok());
// This should fail since 1153 != 1 moduli 2 * 128
assert!(Context::new(MODULI, 128).is_err());
}
// This should fail since 1153 != 1 moduli 2 * 128
assert!(Context::new(MODULI, 128).is_err());
}
#[test]
fn next_context() -> Result<(), Box<dyn Error>> {
// A context should have a children pointing to a context with one less modulus.
let context = Arc::new(Context::new(MODULI, 8)?);
assert_eq!(
context.next_context,
Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?))
);
#[test]
fn next_context() -> Result<(), Box<dyn Error>> {
// A context should have a children pointing to a context with one less modulus.
let context = Arc::new(Context::new(MODULI, 8)?);
assert_eq!(
context.next_context,
Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 8)?))
);
// We can go down the chain of the MODULI.len() - 1 context's.
let mut number_of_children = 0;
let mut current = context;
while current.next_context.is_some() {
number_of_children += 1;
current = current.next_context.as_ref().unwrap().clone();
}
assert_eq!(number_of_children, MODULI.len() - 1);
// We can go down the chain of the MODULI.len() - 1 context's.
let mut number_of_children = 0;
let mut current = context;
while current.next_context.is_some() {
number_of_children += 1;
current = current.next_context.as_ref().unwrap().clone();
}
assert_eq!(number_of_children, MODULI.len() - 1);
Ok(())
}
Ok(())
}
#[test]
fn niterations_to() -> Result<(), Box<dyn Error>> {
// A context should have a children pointing to a context with one less modulus.
let context = Arc::new(Context::new(MODULI, 8)?);
#[test]
fn niterations_to() -> Result<(), Box<dyn Error>> {
// A context should have a children pointing to a context with one less modulus.
let context = Arc::new(Context::new(MODULI, 8)?);
assert_eq!(context.niterations_to(&context).ok(), Some(0));
assert_eq!(context.niterations_to(&context).ok(), Some(0));
assert_eq!(
context
.niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?))
.err(),
Some(crate::Error::InvalidContext)
);
assert_eq!(
context
.niterations_to(&Arc::new(Context::new(&MODULI[1..], 8)?))
.err(),
Some(crate::Error::InvalidContext)
);
for i in 1..MODULI.len() {
assert_eq!(
context
.niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?))
.ok(),
Some(i)
);
}
for i in 1..MODULI.len() {
assert_eq!(
context
.niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 8)?))
.ok(),
Some(i)
);
}
Ok(())
}
Ok(())
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -4,8 +4,8 @@
use super::{Context, Poly, Representation};
use crate::{
rns::{RnsScaler, ScalingFactor},
Error, Result,
rns::{RnsScaler, ScalingFactor},
Error, Result,
};
use itertools::izip;
use ndarray::{s, Array2, Axis};
@@ -14,199 +14,200 @@ use std::sync::Arc;
/// Context extender.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Scaler {
from: Arc<Context>,
to: Arc<Context>,
number_common_moduli: usize,
scaler: RnsScaler,
from: Arc<Context>,
to: Arc<Context>,
number_common_moduli: usize,
scaler: RnsScaler,
}
impl Scaler {
/// Create a scaler from a context `from` to a context `to`.
pub fn new(from: &Arc<Context>, to: &Arc<Context>, factor: ScalingFactor) -> Result<Self> {
if from.degree != to.degree {
return Err(Error::Default("Incompatible degrees".to_string()));
}
/// Create a scaler from a context `from` to a context `to`.
pub fn new(from: &Arc<Context>, to: &Arc<Context>, factor: ScalingFactor) -> Result<Self> {
if from.degree != to.degree {
return Err(Error::Default("Incompatible degrees".to_string()));
}
let mut number_common_moduli = 0;
if factor.is_one {
for (qi, pi) in izip!(from.q.iter(), to.q.iter()) {
if qi == pi {
number_common_moduli += 1
} else {
break;
}
}
}
let mut number_common_moduli = 0;
if factor.is_one {
for (qi, pi) in izip!(from.q.iter(), to.q.iter()) {
if qi == pi {
number_common_moduli += 1
} else {
break;
}
}
}
let scaler = RnsScaler::new(&from.rns, &to.rns, factor);
let scaler = RnsScaler::new(&from.rns, &to.rns, factor);
Ok(Self {
from: from.clone(),
to: to.clone(),
number_common_moduli,
scaler,
})
}
Ok(Self {
from: from.clone(),
to: to.clone(),
number_common_moduli,
scaler,
})
}
/// Scale a polynomial
pub(crate) fn scale(&self, p: &Poly) -> Result<Poly> {
if p.ctx.as_ref() != self.from.as_ref() {
Err(Error::Default(
"The input polynomial does not have the correct context".to_string(),
))
} else {
let mut representation = p.representation.clone();
if representation == Representation::NttShoup {
representation = Representation::Ntt;
}
/// Scale a polynomial
pub(crate) fn scale(&self, p: &Poly) -> Result<Poly> {
if p.ctx.as_ref() != self.from.as_ref() {
Err(Error::Default(
"The input polynomial does not have the correct context".to_string(),
))
} else {
let mut representation = p.representation.clone();
if representation == Representation::NttShoup {
representation = Representation::Ntt;
}
let mut new_coefficients = Array2::<u64>::zeros((self.to.q.len(), self.to.degree));
let mut new_coefficients = Array2::<u64>::zeros((self.to.q.len(), self.to.degree));
if self.number_common_moduli > 0 {
new_coefficients
.slice_mut(s![..self.number_common_moduli, ..])
.assign(&p.coefficients.slice(s![..self.number_common_moduli, ..]));
}
if self.number_common_moduli > 0 {
new_coefficients
.slice_mut(s![..self.number_common_moduli, ..])
.assign(&p.coefficients.slice(s![..self.number_common_moduli, ..]));
}
if self.number_common_moduli < self.to.q.len() {
if p.representation == Representation::PowerBasis {
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.axis_iter_mut(Axis(1)),
p.coefficients.axis_iter(Axis(1))
)
.for_each(|(new_column, column)| {
self.scaler
.scale(column, new_column, self.number_common_moduli)
});
} else if self.number_common_moduli < self.to.q.len() {
let mut p_coefficients_powerbasis = p.coefficients.clone();
// Backward NTT
if p.allow_variable_time_computations {
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
.for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
} else {
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
.for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
}
// Conversion
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.axis_iter_mut(Axis(1)),
p_coefficients_powerbasis.axis_iter(Axis(1))
)
.for_each(|(new_column, column)| {
self.scaler
.scale(column, new_column, self.number_common_moduli)
});
// Forward NTT on the second half
if p.allow_variable_time_computations {
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.outer_iter_mut(),
&self.to.ops[self.number_common_moduli..]
)
.for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
} else {
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.outer_iter_mut(),
&self.to.ops[self.number_common_moduli..]
)
.for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
}
}
}
if self.number_common_moduli < self.to.q.len() {
if p.representation == Representation::PowerBasis {
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.axis_iter_mut(Axis(1)),
p.coefficients.axis_iter(Axis(1))
)
.for_each(|(new_column, column)| {
self.scaler
.scale(column, new_column, self.number_common_moduli)
});
} else if self.number_common_moduli < self.to.q.len() {
let mut p_coefficients_powerbasis = p.coefficients.clone();
// Backward NTT
if p.allow_variable_time_computations {
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
.for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
} else {
izip!(p_coefficients_powerbasis.outer_iter_mut(), p.ctx.ops.iter())
.for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
}
// Conversion
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.axis_iter_mut(Axis(1)),
p_coefficients_powerbasis.axis_iter(Axis(1))
)
.for_each(|(new_column, column)| {
self.scaler
.scale(column, new_column, self.number_common_moduli)
});
// Forward NTT on the second half
if p.allow_variable_time_computations {
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.outer_iter_mut(),
&self.to.ops[self.number_common_moduli..]
)
.for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
} else {
izip!(
new_coefficients
.slice_mut(s![self.number_common_moduli.., ..])
.outer_iter_mut(),
&self.to.ops[self.number_common_moduli..]
)
.for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
}
}
}
Ok(Poly {
ctx: self.to.clone(),
representation,
allow_variable_time_computations: p.allow_variable_time_computations,
coefficients: new_coefficients,
coefficients_shoup: None,
has_lazy_coefficients: false,
})
}
}
Ok(Poly {
ctx: self.to.clone(),
representation,
allow_variable_time_computations: p.allow_variable_time_computations,
coefficients: new_coefficients,
coefficients_shoup: None,
has_lazy_coefficients: false,
})
}
}
}
#[cfg(test)]
mod tests {
use super::{Scaler, ScalingFactor};
use crate::rq::{Context, Poly, Representation};
use itertools::Itertools;
use num_bigint::BigUint;
use num_traits::{One, Zero};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use super::{Scaler, ScalingFactor};
use crate::rq::{Context, Poly, Representation};
use itertools::Itertools;
use num_bigint::BigUint;
use num_traits::{One, Zero};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
// Moduli to be used in tests.
static Q: &[u64; 3] = &[
4611686018282684417,
4611686018326724609,
4611686018309947393,
];
// Moduli to be used in tests.
static Q: &[u64; 3] = &[
4611686018282684417,
4611686018326724609,
4611686018309947393,
];
static P: &[u64; 3] = &[
4611686018282684417,
4611686018309947393,
4611686018257518593,
];
static P: &[u64; 3] = &[
4611686018282684417,
4611686018309947393,
4611686018257518593,
];
#[test]
fn scaler() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let ntests = 100;
let from = Arc::new(Context::new(Q, 8)?);
let to = Arc::new(Context::new(P, 8)?);
#[test]
fn scaler() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let ntests = 100;
let from = Arc::new(Context::new(Q, 8)?);
let to = Arc::new(Context::new(P, 8)?);
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
let n = BigUint::from(*numerator);
let d = BigUint::from(*denominator);
for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
let n = BigUint::from(*numerator);
let d = BigUint::from(*denominator);
let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?;
let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?;
for _ in 0..ntests {
let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng);
let poly_biguint = Vec::<BigUint>::from(&poly);
for _ in 0..ntests {
let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng);
let poly_biguint = Vec::<BigUint>::from(&poly);
let scaled_poly = scaler.scale(&poly)?;
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
let scaled_poly = scaler.scale(&poly)?;
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
let expected = poly_biguint
.iter()
.map(|i| {
if i >= &(from.modulus() >> 1usize) {
if &d & BigUint::one() == BigUint::zero() {
to.modulus()
- (&(&(from.modulus() - i) * &n + ((&d >> 1usize) - 1u64))
/ &d) % to.modulus()
} else {
to.modulus()
- (&(&(from.modulus() - i) * &n + (&d >> 1)) / &d)
% to.modulus()
}
} else {
((i * &n + (&d >> 1)) / &d) % to.modulus()
}
})
.collect_vec();
assert_eq!(expected, scaled_biguint);
let expected = poly_biguint
.iter()
.map(|i| {
if i >= &(from.modulus() >> 1usize) {
if &d & BigUint::one() == BigUint::zero() {
to.modulus()
- (&(&(from.modulus() - i) * &n + ((&d >> 1usize) - 1u64))
/ &d)
% to.modulus()
} else {
to.modulus()
- (&(&(from.modulus() - i) * &n + (&d >> 1)) / &d)
% to.modulus()
}
} else {
((i * &n + (&d >> 1)) / &d) % to.modulus()
}
})
.collect_vec();
assert_eq!(expected, scaled_biguint);
poly.change_representation(Representation::Ntt);
let mut scaled_poly = scaler.scale(&poly)?;
scaled_poly.change_representation(Representation::PowerBasis);
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
assert_eq!(expected, scaled_biguint);
}
}
}
poly.change_representation(Representation::Ntt);
let mut scaled_poly = scaler.scale(&poly)?;
scaled_poly.change_representation(Representation::PowerBasis);
let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
assert_eq!(expected, scaled_biguint);
}
}
}
Ok(())
}
Ok(())
}
}

View File

@@ -8,59 +8,59 @@ use fhe_traits::{DeserializeWithContext, Serialize};
use protobuf::Message;
impl Serialize for Poly {
fn to_bytes(&self) -> Vec<u8> {
let rq = Rq::from(self);
rq.write_to_bytes().unwrap()
}
fn to_bytes(&self) -> Vec<u8> {
let rq = Rq::from(self);
rq.write_to_bytes().unwrap()
}
}
impl DeserializeWithContext for Poly {
type Error = Error;
type Context = Context;
type Error = Error;
type Context = Context;
fn from_bytes(bytes: &[u8], ctx: &Arc<Context>) -> Result<Self, Self::Error> {
let rq = Rq::parse_from_bytes(bytes).map_err(|e| Error::Serialization(e.to_string()))?;
Poly::try_convert_from(&rq, ctx, false, None)
}
fn from_bytes(bytes: &[u8], ctx: &Arc<Context>) -> Result<Self, Self::Error> {
let rq = Rq::parse_from_bytes(bytes).map_err(|e| Error::Serialization(e.to_string()))?;
Poly::try_convert_from(&rq, ctx, false, None)
}
}
#[cfg(test)]
mod tests {
use std::{error::Error, sync::Arc};
use std::{error::Error, sync::Arc};
use fhe_traits::{DeserializeWithContext, Serialize};
use rand::thread_rng;
use fhe_traits::{DeserializeWithContext, Serialize};
use rand::thread_rng;
use crate::rq::{Context, Poly, Representation};
use crate::rq::{Context, Poly, Representation};
const Q: &[u64; 3] = &[
4611686018282684417,
4611686018326724609,
4611686018309947393,
];
const Q: &[u64; 3] = &[
4611686018282684417,
4611686018326724609,
4611686018309947393,
];
#[test]
fn serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
#[test]
fn serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for qi in Q {
let ctx = Arc::new(Context::new(&[*qi], 8)?);
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
}
for qi in Q {
let ctx = Arc::new(Context::new(&[*qi], 8)?);
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
}
let ctx = Arc::new(Context::new(Q, 8)?);
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let ctx = Arc::new(Context::new(Q, 8)?);
let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
let p = Poly::random(&ctx, Representation::NttShoup, &mut rng);
assert_eq!(p, Poly::from_bytes(&p.to_bytes(), &ctx)?);
Ok(())
}
Ok(())
}
}

View File

@@ -9,19 +9,19 @@ use std::sync::Arc;
/// Context switcher.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Switcher {
pub(crate) scaler: Scaler,
pub(crate) scaler: Scaler,
}
impl Switcher {
/// Create a switcher from a context `from` to a context `to`.
pub fn new(from: &Arc<Context>, to: &Arc<Context>) -> Result<Self> {
Ok(Self {
scaler: Scaler::new(from, to, ScalingFactor::new(to.modulus(), from.modulus()))?,
})
}
/// Create a switcher from a context `from` to a context `to`.
pub fn new(from: &Arc<Context>, to: &Arc<Context>) -> Result<Self> {
Ok(Self {
scaler: Scaler::new(from, to, ScalingFactor::new(to.modulus(), from.modulus()))?,
})
}
/// Switch a polynomial.
pub(crate) fn switch(&self, p: &Poly) -> Result<Poly> {
self.scaler.scale(p)
}
/// Switch a polynomial.
pub(crate) fn switch(&self, p: &Poly) -> Result<Poly> {
self.scaler.scale(p)
}
}

View File

@@ -14,19 +14,19 @@ use std::sync::Arc;
/// blanket implementation <https://github.com/rust-lang/rust/issues/50133#issuecomment-488512355>.
pub trait TryConvertFrom<T>
where
Self: Sized,
Self: Sized,
{
/// Attempt to convert the `value` into a polynomial with a specific context
/// and under a specific representation. The representation may optional and
/// be specified as `None`; this is useful for example when converting from
/// a value that encodes the representation (e.g., serialization, protobuf,
/// etc.).
fn try_convert_from<R>(
value: T,
ctx: &Arc<Context>,
variable_time: bool,
representation: R,
) -> Result<Self>
where
R: Into<Option<Representation>>;
/// Attempt to convert the `value` into a polynomial with a specific context
/// and under a specific representation. The representation may optional and
/// be specified as `None`; this is useful for example when converting from
/// a value that encodes the representation (e.g., serialization, protobuf,
/// etc.).
fn try_convert_from<R>(
value: T,
ctx: &Arc<Context>,
variable_time: bool,
representation: R,
) -> Result<Self>
where
R: Into<Option<Representation>>;
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,489 +2,483 @@
use super::Modulus;
use fhe_util::is_prime;
use itertools::Itertools;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use std::iter::successors;
/// Returns whether a modulus p is prime and supports the Number Theoretic
/// Transform of size n.
///
/// Aborts if n is not a power of 2 that is >= 8.
pub fn supports_ntt(p: u64, n: usize) -> bool {
assert!(n >= 8 && n.is_power_of_two());
assert!(n >= 8 && n.is_power_of_two());
p % ((n as u64) << 1) == 1 && is_prime(p)
p % ((n as u64) << 1) == 1 && is_prime(p)
}
/// Number-Theoretic Transform operator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NttOperator {
p: Modulus,
p_twice: u64,
size: usize,
omegas: Box<[u64]>,
omegas_shoup: Box<[u64]>,
omegas_inv: Box<[u64]>,
zetas_inv: Box<[u64]>,
zetas_inv_shoup: Box<[u64]>,
size_inv: u64,
size_inv_shoup: u64,
p: Modulus,
p_twice: u64,
size: usize,
omegas: Box<[u64]>,
omegas_shoup: Box<[u64]>,
zetas_inv: Box<[u64]>,
zetas_inv_shoup: Box<[u64]>,
size_inv: u64,
size_inv_shoup: u64,
}
impl NttOperator {
/// Create an NTT operator given a modulus for a specific size.
///
/// Aborts if the size is not a power of 2 that is >= 8 in debug mode.
/// Returns None if the modulus does not support the NTT for this specific
/// size.
pub fn new(p: &Modulus, size: usize) -> Option<Self> {
if !supports_ntt(p.p, size) {
None
} else {
let omega = Self::primitive_root(size, p);
let omega_inv = p.inv(omega).unwrap();
/// Create an NTT operator given a modulus for a specific size.
///
/// Aborts if the size is not a power of 2 that is >= 8 in debug mode.
/// Returns None if the modulus does not support the NTT for this specific
/// size.
pub fn new(p: &Modulus, size: usize) -> Option<Self> {
if !supports_ntt(p.p, size) {
None
} else {
let size_inv = p.inv(size as u64)?;
let mut exp = 1u64;
let mut exp_inv = 1u64;
let mut powers = Vec::with_capacity(size + 1);
let mut powers_inv = Vec::with_capacity(size + 1);
for _ in 0..size + 1 {
powers.push(exp);
powers_inv.push(exp_inv);
exp = p.mul(exp, omega);
exp_inv = p.mul(exp_inv, omega_inv);
}
let omega = Self::primitive_root(size, p);
let omega_inv = p.inv(omega)?;
let mut omegas = Vec::with_capacity(size);
let mut omegas_inv = Vec::with_capacity(size);
let mut zetas_inv = Vec::with_capacity(size);
for i in 0..size {
let j = i.reverse_bits() >> (size.leading_zeros() + 1);
omegas.push(powers[j]);
omegas_inv.push(powers_inv[j]);
zetas_inv.push(powers_inv[j + 1]);
}
let powers = successors(Some(1u64), |n| Some(p.mul(*n, omega)))
.take(size)
.collect_vec();
let powers_inv = successors(Some(omega_inv), |n| Some(p.mul(*n, omega_inv)))
.take(size)
.collect_vec();
let size_inv = p.inv(size as u64).unwrap();
let mut omegas = Vec::with_capacity(size);
let mut zetas_inv = Vec::with_capacity(size);
for i in 0..size {
let j = i.reverse_bits() >> (size.leading_zeros() + 1);
omegas.push(powers[j]);
zetas_inv.push(powers_inv[j]);
}
let omegas_shoup = p.shoup_vec(&omegas);
let zetas_inv_shoup = p.shoup_vec(&zetas_inv);
let omegas_shoup = p.shoup_vec(&omegas);
let zetas_inv_shoup = p.shoup_vec(&zetas_inv);
Some(Self {
p: p.clone(),
p_twice: p.p * 2,
size,
omegas: omegas.into_boxed_slice(),
omegas_shoup: omegas_shoup.into_boxed_slice(),
omegas_inv: omegas_inv.into_boxed_slice(),
zetas_inv: zetas_inv.into_boxed_slice(),
zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(),
size_inv,
size_inv_shoup: p.shoup(size_inv),
})
}
}
Some(Self {
p: p.clone(),
p_twice: p.p * 2,
size,
omegas: omegas.into_boxed_slice(),
omegas_shoup: omegas_shoup.into_boxed_slice(),
zetas_inv: zetas_inv.into_boxed_slice(),
zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(),
size_inv,
size_inv_shoup: p.shoup(size_inv),
})
}
}
/// Compute the forward NTT in place.
/// Aborts if a is not of the size handled by the operator.
pub fn forward(&self, a: &mut [u64]) {
debug_assert_eq!(a.len(), self.size);
/// Compute the forward NTT in place.
/// Aborts if a is not of the size handled by the operator.
pub fn forward(&self, a: &mut [u64]) {
debug_assert_eq!(a.len(), self.size);
let n = self.size;
let a_ptr = a.as_mut_ptr();
let n = self.size;
let a_ptr = a.as_mut_ptr();
let mut l = n >> 1;
let mut m = 1;
let mut k = 1;
while l > 0 {
for i in 0..m {
unsafe {
let omega = *self.omegas.get_unchecked(k);
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
k += 1;
let mut l = n >> 1;
let mut m = 1;
let mut k = 1;
while l > 0 {
for i in 0..m {
unsafe {
let omega = *self.omegas.get_unchecked(k);
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
k += 1;
let s = 2 * i * l;
match l {
1 => {
// The last level should reduce the output
let uj = &mut *a_ptr.add(s);
let ujl = &mut *a_ptr.add(s + l);
self.butterfly(uj, ujl, omega, omega_shoup);
*uj = self.reduce3(*uj);
*ujl = self.reduce3(*ujl);
}
_ => {
for j in s..(s + l) {
self.butterfly(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
omega,
omega_shoup,
);
}
}
}
}
}
l >>= 1;
m <<= 1;
}
}
let s = 2 * i * l;
match l {
1 => {
// The last level should reduce the output
let uj = &mut *a_ptr.add(s);
let ujl = &mut *a_ptr.add(s + l);
self.butterfly(uj, ujl, omega, omega_shoup);
*uj = self.reduce3(*uj);
*ujl = self.reduce3(*ujl);
}
_ => {
for j in s..(s + l) {
self.butterfly(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
omega,
omega_shoup,
);
}
}
}
}
}
l >>= 1;
m <<= 1;
}
}
/// Compute the backward NTT in place.
/// Aborts if a is not of the size handled by the operator.
pub fn backward(&self, a: &mut [u64]) {
debug_assert_eq!(a.len(), self.size);
/// Compute the backward NTT in place.
/// Aborts if a is not of the size handled by the operator.
pub fn backward(&self, a: &mut [u64]) {
debug_assert_eq!(a.len(), self.size);
let a_ptr = a.as_mut_ptr();
let a_ptr = a.as_mut_ptr();
let mut k = 0;
let mut m = self.size >> 1;
let mut l = 1;
while m > 0 {
for i in 0..m {
let s = 2 * i * l;
unsafe {
let zeta_inv = *self.zetas_inv.get_unchecked(k);
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
k += 1;
match l {
1 => {
self.inv_butterfly(
&mut *a_ptr.add(s),
&mut *a_ptr.add(s + l),
zeta_inv,
zeta_inv_shoup,
);
}
_ => {
for j in s..(s + l) {
self.inv_butterfly(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
zeta_inv,
zeta_inv_shoup,
);
}
}
}
}
}
l <<= 1;
m >>= 1;
}
let mut k = 0;
let mut m = self.size >> 1;
let mut l = 1;
while m > 0 {
for i in 0..m {
let s = 2 * i * l;
unsafe {
let zeta_inv = *self.zetas_inv.get_unchecked(k);
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
k += 1;
match l {
1 => {
self.inv_butterfly(
&mut *a_ptr.add(s),
&mut *a_ptr.add(s + l),
zeta_inv,
zeta_inv_shoup,
);
}
_ => {
for j in s..(s + l) {
self.inv_butterfly(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
zeta_inv,
zeta_inv_shoup,
);
}
}
}
}
}
l <<= 1;
m >>= 1;
}
a.iter_mut()
.for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup));
}
a.iter_mut()
.for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup));
}
/// Compute the forward NTT in place in variable time in a lazily fashion.
/// This means that the output coefficients may be up to 4 times the
/// modulus.
///
/// # Safety
/// This function assumes that a_ptr points to at least `size` elements.
/// This function is not constant time and its timing may reveal information
/// about the value being reduced.
pub unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) {
let mut l = self.size >> 1;
let mut m = 1;
let mut k = 1;
while l > 0 {
for i in 0..m {
let omega = *self.omegas.get_unchecked(k);
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
k += 1;
/// Compute the forward NTT in place in variable time in a lazily fashion.
/// This means that the output coefficients may be up to 4 times the
/// modulus.
///
/// # Safety
/// This function assumes that a_ptr points to at least `size` elements.
/// This function is not constant time and its timing may reveal information
/// about the value being reduced.
pub unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) {
let mut l = self.size >> 1;
let mut m = 1;
let mut k = 1;
while l > 0 {
for i in 0..m {
let omega = *self.omegas.get_unchecked(k);
let omega_shoup = *self.omegas_shoup.get_unchecked(k);
k += 1;
let s = 2 * i * l;
match l {
1 => {
self.butterfly_vt(
&mut *a_ptr.add(s),
&mut *a_ptr.add(s + l),
omega,
omega_shoup,
);
}
_ => {
for j in s..(s + l) {
self.butterfly_vt(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
omega,
omega_shoup,
);
}
}
}
}
l >>= 1;
m <<= 1;
}
}
let s = 2 * i * l;
match l {
1 => {
self.butterfly_vt(
&mut *a_ptr.add(s),
&mut *a_ptr.add(s + l),
omega,
omega_shoup,
);
}
_ => {
for j in s..(s + l) {
self.butterfly_vt(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
omega,
omega_shoup,
);
}
}
}
}
l >>= 1;
m <<= 1;
}
}
/// Compute the forward NTT in place in variable time.
///
/// # Safety
/// This function assumes that a_ptr points to at least `size` elements.
/// This function is not constant time and its timing may reveal information
/// about the value being reduced.
pub unsafe fn forward_vt(&self, a_ptr: *mut u64) {
self.forward_vt_lazy(a_ptr);
for i in 0..self.size {
*a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i))
}
}
/// Compute the forward NTT in place in variable time.
///
/// # Safety
/// This function assumes that a_ptr points to at least `size` elements.
/// This function is not constant time and its timing may reveal information
/// about the value being reduced.
pub unsafe fn forward_vt(&self, a_ptr: *mut u64) {
self.forward_vt_lazy(a_ptr);
for i in 0..self.size {
*a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i))
}
}
/// Compute the backward NTT in place in variable time.
///
/// # Safety
/// This function assumes that a_ptr points to at least `size` elements.
/// This function is not constant time and its timing may reveal information
/// about the value being reduced.
pub unsafe fn backward_vt(&self, a_ptr: *mut u64) {
let mut k = 0;
let mut m = self.size >> 1;
let mut l = 1;
while m > 0 {
for i in 0..m {
let s = 2 * i * l;
let zeta_inv = *self.zetas_inv.get_unchecked(k);
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
k += 1;
match l {
1 => {
self.inv_butterfly_vt(
&mut *a_ptr.add(s),
&mut *a_ptr.add(s + l),
zeta_inv,
zeta_inv_shoup,
);
}
_ => {
for j in s..(s + l) {
self.inv_butterfly_vt(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
zeta_inv,
zeta_inv_shoup,
);
}
}
}
}
l <<= 1;
m >>= 1;
}
/// Compute the backward NTT in place in variable time.
///
/// # Safety
/// This function assumes that a_ptr points to at least `size` elements.
/// This function is not constant time and its timing may reveal information
/// about the value being reduced.
pub unsafe fn backward_vt(&self, a_ptr: *mut u64) {
let mut k = 0;
let mut m = self.size >> 1;
let mut l = 1;
while m > 0 {
for i in 0..m {
let s = 2 * i * l;
let zeta_inv = *self.zetas_inv.get_unchecked(k);
let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
k += 1;
match l {
1 => {
self.inv_butterfly_vt(
&mut *a_ptr.add(s),
&mut *a_ptr.add(s + l),
zeta_inv,
zeta_inv_shoup,
);
}
_ => {
for j in s..(s + l) {
self.inv_butterfly_vt(
&mut *a_ptr.add(j),
&mut *a_ptr.add(j + l),
zeta_inv,
zeta_inv_shoup,
);
}
}
}
}
l <<= 1;
m >>= 1;
}
for i in 0..self.size as isize {
*a_ptr.offset(i) =
self.p
.mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup)
}
}
for i in 0..self.size as isize {
*a_ptr.offset(i) =
self.p
.mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup)
}
}
/// Reduce a modulo p.
///
/// Aborts if a >= 4 * p.
const fn reduce3(&self, a: u64) -> u64 {
debug_assert!(a < 4 * self.p.p);
/// Reduce a modulo p.
///
/// Aborts if a >= 4 * p.
const fn reduce3(&self, a: u64) -> u64 {
debug_assert!(a < 4 * self.p.p);
let y = Modulus::reduce1(a, 2 * self.p.p);
Modulus::reduce1(y, self.p.p)
}
let y = Modulus::reduce1(a, 2 * self.p.p);
Modulus::reduce1(y, self.p.p)
}
/// Reduce a modulo p in variable time.
///
/// Aborts if a >= 4 * p.
const unsafe fn reduce3_vt(&self, a: u64) -> u64 {
debug_assert!(a < 4 * self.p.p);
/// Reduce a modulo p in variable time.
///
/// Aborts if a >= 4 * p.
const unsafe fn reduce3_vt(&self, a: u64) -> u64 {
debug_assert!(a < 4 * self.p.p);
let y = Modulus::reduce1_vt(a, 2 * self.p.p);
Modulus::reduce1_vt(y, self.p.p)
}
let y = Modulus::reduce1_vt(a, 2 * self.p.p);
Modulus::reduce1_vt(y, self.p.p)
}
/// NTT Butterfly.
fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
debug_assert!(w < self.p.p);
debug_assert_eq!(self.p.shoup(w), w_shoup);
/// NTT Butterfly.
fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
debug_assert!(w < self.p.p);
debug_assert_eq!(self.p.shoup(w), w_shoup);
*x = Modulus::reduce1(*x, self.p_twice);
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
*y = *x + self.p_twice - t;
*x += t;
*x = Modulus::reduce1(*x, self.p_twice);
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
*y = *x + self.p_twice - t;
*x += t;
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
}
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
}
/// NTT Butterfly in variable time.
unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
debug_assert!(w < self.p.p);
debug_assert_eq!(self.p.shoup(w), w_shoup);
/// NTT Butterfly in variable time.
unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
debug_assert!(w < self.p.p);
debug_assert_eq!(self.p.shoup(w), w_shoup);
*x = Modulus::reduce1_vt(*x, self.p_twice);
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
*y = *x + self.p_twice - t;
*x += t;
*x = Modulus::reduce1_vt(*x, self.p_twice);
let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
*y = *x + self.p_twice - t;
*x += t;
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
}
debug_assert!(*x < 4 * self.p.p);
debug_assert!(*y < 4 * self.p.p);
}
/// Inverse NTT butterfly.
fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
debug_assert!(z < self.p.p);
debug_assert_eq!(self.p.shoup(z), z_shoup);
/// Inverse NTT butterfly.
fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
debug_assert!(z < self.p.p);
debug_assert_eq!(self.p.shoup(z), z_shoup);
let t = *x;
*x = Modulus::reduce1(*y + t, self.p_twice);
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
let t = *x;
*x = Modulus::reduce1(*y + t, self.p_twice);
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
}
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
}
/// Inverse NTT butterfly in variable time
unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
debug_assert!(z < self.p.p);
debug_assert_eq!(self.p.shoup(z), z_shoup);
/// Inverse NTT butterfly in variable time
unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
debug_assert!(z < self.p.p);
debug_assert_eq!(self.p.shoup(z), z_shoup);
let t = *x;
*x = Modulus::reduce1_vt(*y + t, self.p_twice);
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
let t = *x;
*x = Modulus::reduce1_vt(*y + t, self.p_twice);
*y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
}
debug_assert!(*x < self.p_twice);
debug_assert!(*y < self.p_twice);
}
/// Returns a 2n-th primitive root modulo p.
///
/// Aborts if p is not prime or n is not a power of 2 that is >= 8.
fn primitive_root(n: usize, p: &Modulus) -> u64 {
debug_assert!(supports_ntt(p.p, n));
/// Returns a 2n-th primitive root modulo p.
///
/// Aborts if p is not prime or n is not a power of 2 that is >= 8.
fn primitive_root(n: usize, p: &Modulus) -> u64 {
debug_assert!(supports_ntt(p.p, n));
let lambda = (p.p - 1) / (2 * n as u64);
let lambda = (p.p - 1) / (2 * n as u64);
let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0);
for _ in 0..100 {
let mut root = rng.gen_range(0..p.p);
root = p.pow(root, lambda);
if Self::is_primitive_root(root, 2 * n, p) {
return root;
}
}
let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0);
for _ in 0..100 {
let mut root = rng.gen_range(0..p.p);
root = p.pow(root, lambda);
if Self::is_primitive_root(root, 2 * n, p) {
return root;
}
}
debug_assert!(false, "Couldn't find primitive root");
0
}
debug_assert!(false, "Couldn't find primitive root");
0
}
/// Returns whether a is a n-th primitive root of unity.
///
/// Aborts if a >= p in debug mode.
fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool {
debug_assert!(a < p.p);
debug_assert!(supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here.
/// Returns whether a is a n-th primitive root of unity.
///
/// Aborts if a >= p in debug mode.
fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool {
debug_assert!(a < p.p);
debug_assert!(supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here.
// A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p
// for all prime p dividing n.
(p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1)
}
// A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p
// for all prime p dividing n.
(p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1)
}
}
#[cfg(test)]
mod tests {
use rand::thread_rng;
use rand::thread_rng;
use super::{supports_ntt, NttOperator};
use crate::zq::Modulus;
use super::{supports_ntt, NttOperator};
use crate::zq::Modulus;
#[test]
fn constructor() {
for size in [8, 1024] {
for p in [1153, 4611686018326724609] {
let q = Modulus::new(p).unwrap();
let supports_ntt = supports_ntt(p, size);
#[test]
fn constructor() {
for size in [8, 1024] {
for p in [1153, 4611686018326724609] {
let q = Modulus::new(p).unwrap();
let supports_ntt = supports_ntt(p, size);
let op = NttOperator::new(&q, size);
let op = NttOperator::new(&q, size);
if supports_ntt {
assert!(op.is_some());
} else {
assert!(op.is_none());
}
}
}
}
if supports_ntt {
assert!(op.is_some());
} else {
assert!(op.is_none());
}
}
}
}
#[test]
fn bijection() {
let ntests = 100;
let mut rng = thread_rng();
#[test]
fn bijection() {
let ntests = 100;
let mut rng = thread_rng();
for size in [8, 1024] {
for p in [1153, 4611686018326724609] {
let q = Modulus::new(p).unwrap();
for size in [8, 1024] {
for p in [1153, 4611686018326724609] {
let q = Modulus::new(p).unwrap();
if supports_ntt(p, size) {
let op = NttOperator::new(&q, size).unwrap();
if supports_ntt(p, size) {
let op = NttOperator::new(&q, size).unwrap();
for _ in 0..ntests {
let mut a = q.random_vec(size, &mut rng);
let a_clone = a.clone();
let mut b = a.clone();
for _ in 0..ntests {
let mut a = q.random_vec(size, &mut rng);
let a_clone = a.clone();
let mut b = a.clone();
op.forward(&mut a);
assert_ne!(a, a_clone);
op.forward(&mut a);
assert_ne!(a, a_clone);
unsafe { op.forward_vt(b.as_mut_ptr()) }
assert_eq!(a, b);
unsafe { op.forward_vt(b.as_mut_ptr()) }
assert_eq!(a, b);
op.backward(&mut a);
assert_eq!(a, a_clone);
op.backward(&mut a);
assert_eq!(a, a_clone);
unsafe { op.backward_vt(b.as_mut_ptr()) }
assert_eq!(a, b);
}
}
}
}
}
unsafe { op.backward_vt(b.as_mut_ptr()) }
assert_eq!(a, b);
}
}
}
}
}
#[test]
fn forward_lazy() {
let ntests = 100;
let mut rng = thread_rng();
#[test]
fn forward_lazy() {
let ntests = 100;
let mut rng = thread_rng();
for size in [8, 1024] {
for p in [1153, 4611686018326724609] {
let q = Modulus::new(p).unwrap();
for size in [8, 1024] {
for p in [1153, 4611686018326724609] {
let q = Modulus::new(p).unwrap();
if supports_ntt(p, size) {
let op = NttOperator::new(&q, size).unwrap();
if supports_ntt(p, size) {
let op = NttOperator::new(&q, size).unwrap();
for _ in 0..ntests {
let mut a = q.random_vec(size, &mut rng);
let mut a_lazy = a.clone();
for _ in 0..ntests {
let mut a = q.random_vec(size, &mut rng);
let mut a_lazy = a.clone();
op.forward(&mut a);
op.forward(&mut a);
unsafe {
op.forward_vt_lazy(a_lazy.as_mut_ptr());
q.reduce_vec(&mut a_lazy);
}
unsafe {
op.forward_vt_lazy(a_lazy.as_mut_ptr());
q.reduce_vec(&mut a_lazy);
}
assert_eq!(a, a_lazy);
}
}
}
}
}
assert_eq!(a, a_lazy);
}
}
}
}
}
}

View File

@@ -7,113 +7,113 @@ use num_bigint::BigUint;
/// These optimized operations are possible when the modulus verifies
/// Equation (1) of <https://hal.archives-ouvertes.fr/hal-01242273/document>.
pub fn supports_opt(p: u64) -> bool {
if p.leading_zeros() < 1 {
return false;
}
if p.leading_zeros() < 1 {
return false;
}
// Let's multiply the inequality by (2^s0+1)*2^(3s0):
// we want to output true when
// (2^(3s0)+1) * 2^64 < 2^(3s0) * (2^s0+1) * p
let mut middle = BigUint::from(1u64) << (3 * p.leading_zeros() as usize);
let left_side = (&middle + 1u64) << 64;
middle *= (1u64 << p.leading_zeros()) + 1;
middle *= p;
// Let's multiply the inequality by (2^s0+1)*2^(3s0):
// we want to output true when
// (2^(3s0)+1) * 2^64 < 2^(3s0) * (2^s0+1) * p
let mut middle = BigUint::from(1u64) << (3 * p.leading_zeros() as usize);
let left_side = (&middle + 1u64) << 64;
middle *= (1u64 << p.leading_zeros()) + 1;
middle *= p;
left_side < middle
left_side < middle
}
/// Generate a `num_bits`-bit prime, congruent to 1 mod `modulo`, strictly
/// smaller than `upper_bound`. Note that `num_bits` must belong to (10..=62),
/// and upper_bound must be <= 1 << num_bits.
pub fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option<u64> {
if !(10..=62).contains(&num_bits) {
None
} else {
debug_assert!(
(1u64 << num_bits) >= upper_bound,
"upper_bound larger than number of bits"
);
if !(10..=62).contains(&num_bits) {
None
} else {
debug_assert!(
(1u64 << num_bits) >= upper_bound,
"upper_bound larger than number of bits"
);
let leading_zeros = (64 - num_bits) as u32;
let leading_zeros = (64 - num_bits) as u32;
let mut tentative_prime = upper_bound - 1;
while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros {
tentative_prime -= 1
}
let mut tentative_prime = upper_bound - 1;
while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros {
tentative_prime -= 1
}
while tentative_prime.leading_zeros() == leading_zeros
&& !is_prime(tentative_prime)
&& tentative_prime >= modulo
{
tentative_prime -= modulo
}
while tentative_prime.leading_zeros() == leading_zeros
&& !is_prime(tentative_prime)
&& tentative_prime >= modulo
{
tentative_prime -= modulo
}
if tentative_prime.leading_zeros() == leading_zeros && is_prime(tentative_prime) {
Some(tentative_prime)
} else {
None
}
}
if tentative_prime.leading_zeros() == leading_zeros && is_prime(tentative_prime) {
Some(tentative_prime)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::generate_prime;
use fhe_util::catch_unwind;
use super::generate_prime;
use fhe_util::catch_unwind;
// Verifies that the same moduli as in the NFLlib library are generated.
// <https://github.com/quarkslab/NFLlib/blob/master/include/nfl/params.hpp>
#[test]
fn nfl_62bit_primes() {
let mut generated = vec![];
let mut upper_bound = u64::MAX >> 2;
while generated.len() != 20 {
let p = generate_prime(62, 2 * 1048576, upper_bound);
assert!(p.is_some());
upper_bound = p.unwrap();
generated.push(upper_bound);
}
assert_eq!(
generated,
vec![
4611686018326724609,
4611686018309947393,
4611686018282684417,
4611686018257518593,
4611686018232352769,
4611686018171535361,
4611686018106523649,
4611686018058289153,
4611686018051997697,
4611686017974403073,
4611686017812922369,
4611686017781465089,
4611686017773076481,
4611686017678704641,
4611686017666121729,
4611686017647247361,
4611686017590624257,
4611686017554972673,
4611686017529806849,
4611686017517223937
]
)
}
// Verifies that the same moduli as in the NFLlib library are generated.
// <https://github.com/quarkslab/NFLlib/blob/master/include/nfl/params.hpp>
#[test]
fn nfl_62bit_primes() {
let mut generated = vec![];
let mut upper_bound = u64::MAX >> 2;
while generated.len() != 20 {
let p = generate_prime(62, 2 * 1048576, upper_bound);
assert!(p.is_some());
upper_bound = p.unwrap();
generated.push(upper_bound);
}
assert_eq!(
generated,
vec![
4611686018326724609,
4611686018309947393,
4611686018282684417,
4611686018257518593,
4611686018232352769,
4611686018171535361,
4611686018106523649,
4611686018058289153,
4611686018051997697,
4611686017974403073,
4611686017812922369,
4611686017781465089,
4611686017773076481,
4611686017678704641,
4611686017666121729,
4611686017647247361,
4611686017590624257,
4611686017554972673,
4611686017529806849,
4611686017517223937
]
)
}
#[test]
fn upper_bound() {
debug_assert!(catch_unwind(|| generate_prime(62, 2 * 1048576, (1 << 62) + 1)).is_err());
}
#[test]
fn upper_bound() {
debug_assert!(catch_unwind(|| generate_prime(62, 2 * 1048576, (1 << 62) + 1)).is_err());
}
#[test]
fn modulo_too_large() {
assert!(generate_prime(10, 2048, 1 << 10).is_none());
}
#[test]
fn modulo_too_large() {
assert!(generate_prime(10, 2048, 1 << 10).is_none());
}
#[test]
fn not_found() {
// 1033 is the smallest 11-bit prime congruent to 1 modulo 16, so looking for a
// smaller one should fail.
assert!(generate_prime(11, 16, 1033).is_none());
}
#[test]
fn not_found() {
// 1033 is the smallest 11-bit prime congruent to 1 modulo 16, so looking for a
// smaller one should fail.
assert!(generate_prime(11, 16, 1033).is_none());
}
}

View File

@@ -13,21 +13,21 @@ pub trait FheParameters {}
/// Indicates that an object is parametrized.
pub trait FheParametrized {
/// The type of the FHE parameters.
type Parameters: FheParameters;
/// The type of the FHE parameters.
type Parameters: FheParameters;
}
/// Indicates that Self parameters can be switched.
pub trait FheParametersSwitchable<S: FheParametrized>
where
Self: FheParametrized,
Self: FheParametrized,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Attempt to switch the underlying parameters using the associated
/// switcher.
fn switch_parameters(&mut self, switcher: &S) -> Result<(), Self::Error>;
/// Attempt to switch the underlying parameters using the associated
/// switcher.
fn switch_parameters(&mut self, switcher: &S) -> Result<(), Self::Error>;
}
/// Encoding used when encoding a [`FhePlaintext`].
@@ -36,137 +36,137 @@ pub trait FhePlaintextEncoding {}
/// A plaintext which will encode one (or more) value(s).
pub trait FhePlaintext
where
Self: Sized + FheParametrized,
Self: Sized + FheParametrized,
{
/// The type of the encoding.
type Encoding: FhePlaintextEncoding;
/// The type of the encoding.
type Encoding: FhePlaintextEncoding;
}
/// Encode a value using a specified encoding.
pub trait FheEncoder<V>
where
Self: FhePlaintext,
Self: FhePlaintext,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Attempt to encode a value using a specified encoding.
fn try_encode(
value: V,
encoding: Self::Encoding,
par: &Arc<Self::Parameters>,
) -> Result<Self, Self::Error>;
/// Attempt to encode a value using a specified encoding.
fn try_encode(
value: V,
encoding: Self::Encoding,
par: &Arc<Self::Parameters>,
) -> Result<Self, Self::Error>;
}
/// Encode a value using a specified encoding.
pub trait FheEncoderVariableTime<V>
where
Self: FhePlaintext,
Self: FhePlaintext,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Attempt to encode a value using a specified encoding.
/// # Safety
/// This encoding runs in variable time and may leak information about the
/// value.
unsafe fn try_encode_vt(
value: V,
encoding: Self::Encoding,
par: &Arc<Self::Parameters>,
) -> Result<Self, Self::Error>;
/// Attempt to encode a value using a specified encoding.
/// # Safety
/// This encoding runs in variable time and may leak information about the
/// value.
unsafe fn try_encode_vt(
value: V,
encoding: Self::Encoding,
par: &Arc<Self::Parameters>,
) -> Result<Self, Self::Error>;
}
/// Decode the value in the plaintext with the specified (optional) encoding.
pub trait FheDecoder<P: FhePlaintext>
where
Self: Sized,
Self: Sized,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Attempt to decode a [`FhePlaintext`] into a value, using an (optional)
/// encoding.
fn try_decode<O>(pt: &P, encoding: O) -> Result<Self, Self::Error>
where
O: Into<Option<P::Encoding>>;
/// Attempt to decode a [`FhePlaintext`] into a value, using an (optional)
/// encoding.
fn try_decode<O>(pt: &P, encoding: O) -> Result<Self, Self::Error>
where
O: Into<Option<P::Encoding>>;
}
/// A ciphertext which will encrypt a plaintext.
pub trait FheCiphertext
where
Self: Sized + Serialize + FheParametrized + DeserializeParametrized,
Self: Sized + Serialize + FheParametrized + DeserializeParametrized,
{
}
/// Encrypt a plaintext into a ciphertext.
pub trait FheEncrypter<
P: FhePlaintext<Parameters = Self::Parameters>,
C: FheCiphertext<Parameters = Self::Parameters>,
P: FhePlaintext<Parameters = Self::Parameters>,
C: FheCiphertext<Parameters = Self::Parameters>,
>: FheParametrized
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Try to encrypt an [`FhePlaintext`] into an [`FheCiphertext`].
fn try_encrypt<R: RngCore + CryptoRng>(&self, pt: &P, rng: &mut R) -> Result<C, Self::Error>;
/// Try to encrypt an [`FhePlaintext`] into an [`FheCiphertext`].
fn try_encrypt<R: RngCore + CryptoRng>(&self, pt: &P, rng: &mut R) -> Result<C, Self::Error>;
}
/// Decrypt a ciphertext into a plaintext
pub trait FheDecrypter<
P: FhePlaintext<Parameters = Self::Parameters>,
C: FheCiphertext<Parameters = Self::Parameters>,
P: FhePlaintext<Parameters = Self::Parameters>,
C: FheCiphertext<Parameters = Self::Parameters>,
>: FheParametrized
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Try to decrypt an [`FheCiphertext`] into an [`FhePlaintext`].
fn try_decrypt(&self, ct: &C) -> Result<P, Self::Error>;
/// Try to decrypt an [`FheCiphertext`] into an [`FhePlaintext`].
fn try_decrypt(&self, ct: &C) -> Result<P, Self::Error>;
}
/// Serialization.
pub trait Serialize {
/// Serialize `Self` into a vector of bytes.
fn to_bytes(&self) -> Vec<u8>;
/// Serialize `Self` into a vector of bytes.
fn to_bytes(&self) -> Vec<u8>;
}
/// Deserialization of a parametrized value.
pub trait DeserializeParametrized
where
Self: Sized,
Self: FheParametrized,
Self: Sized,
Self: FheParametrized,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Attempt to deserialize from a vector of bytes
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self, Self::Error>;
/// Attempt to deserialize from a vector of bytes
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self, Self::Error>;
}
/// Deserialization setting an explicit context.
pub trait DeserializeWithContext
where
Self: Sized,
Self: Sized,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// The type of context.
type Context;
/// The type of context.
type Context;
/// Attempt to deserialize from a vector of bytes
fn from_bytes(bytes: &[u8], ctx: &Arc<Self::Context>) -> Result<Self, Self::Error>;
/// Attempt to deserialize from a vector of bytes
fn from_bytes(bytes: &[u8], ctx: &Arc<Self::Context>) -> Result<Self, Self::Error>;
}
/// Deserialization without context.
pub trait Deserialize
where
Self: Sized,
Self: Sized,
{
/// The type of error returned.
type Error;
/// The type of error returned.
type Error;
/// Attempt to deserialize from a vector of bytes
fn try_deserialize(bytes: &[u8]) -> Result<Self, Self::Error>;
/// Attempt to deserialize from a vector of bytes
fn try_deserialize(bytes: &[u8]) -> Result<Self, Self::Error>;
}

File diff suppressed because it is too large Load Diff

View File

@@ -8,217 +8,217 @@ use std::ops::{Not, Shr, ShrAssign};
pub struct U256(u64, u64, u64, u64);
impl U256 {
/// Returns the additive identity element, 0.
pub const fn zero() -> Self {
Self(0, 0, 0, 0)
}
/// Returns the additive identity element, 0.
pub const fn zero() -> Self {
Self(0, 0, 0, 0)
}
/// Add an U256 to self, wrapping modulo 2^256.
pub fn wrapping_add_assign(&mut self, other: Self) {
let (a, c1) = self.0.overflowing_add(other.0);
let (b, c2) = self.1.overflowing_add(other.1);
let (c, c3) = b.overflowing_add(c1 as u64);
let (d, c4) = self.2.overflowing_add(other.2);
let (e, c5) = d.overflowing_add((c2 | c3) as u64);
let f = self.3.wrapping_add(other.3);
let g = f.wrapping_add((c4 | c5) as u64);
self.0 = a;
self.1 = c;
self.2 = e;
self.3 = g;
}
/// Add an U256 to self, wrapping modulo 2^256.
pub fn wrapping_add_assign(&mut self, other: Self) {
let (a, c1) = self.0.overflowing_add(other.0);
let (b, c2) = self.1.overflowing_add(other.1);
let (c, c3) = b.overflowing_add(c1 as u64);
let (d, c4) = self.2.overflowing_add(other.2);
let (e, c5) = d.overflowing_add((c2 | c3) as u64);
let f = self.3.wrapping_add(other.3);
let g = f.wrapping_add((c4 | c5) as u64);
self.0 = a;
self.1 = c;
self.2 = e;
self.3 = g;
}
/// Subtract an U256 to self, wrapping modulo 2^256.
pub fn wrapping_sub_assign(&mut self, other: Self) {
let (a, b1) = self.0.overflowing_sub(other.0);
let (b, b2) = self.1.overflowing_sub(other.1);
let (c, b3) = b.overflowing_sub(b1 as u64);
let (d, b4) = self.2.overflowing_sub(other.2);
let (e, b5) = d.overflowing_sub((b2 | b3) as u64);
let f = self.3.wrapping_sub(other.3);
let g = f.wrapping_sub((b4 | b5) as u64);
self.0 = a;
self.1 = c;
self.2 = e;
self.3 = g;
}
/// Subtract an U256 to self, wrapping modulo 2^256.
pub fn wrapping_sub_assign(&mut self, other: Self) {
let (a, b1) = self.0.overflowing_sub(other.0);
let (b, b2) = self.1.overflowing_sub(other.1);
let (c, b3) = b.overflowing_sub(b1 as u64);
let (d, b4) = self.2.overflowing_sub(other.2);
let (e, b5) = d.overflowing_sub((b2 | b3) as u64);
let f = self.3.wrapping_sub(other.3);
let g = f.wrapping_sub((b4 | b5) as u64);
self.0 = a;
self.1 = c;
self.2 = e;
self.3 = g;
}
/// Returns the most significant bit of the unsigned integer.
pub const fn msb(self) -> u64 {
self.3 >> 63
}
/// Returns the most significant bit of the unsigned integer.
pub const fn msb(self) -> u64 {
self.3 >> 63
}
}
impl From<[u64; 4]> for U256 {
fn from(a: [u64; 4]) -> Self {
Self(a[0], a[1], a[2], a[3])
}
fn from(a: [u64; 4]) -> Self {
Self(a[0], a[1], a[2], a[3])
}
}
impl From<[u128; 2]> for U256 {
fn from(a: [u128; 2]) -> Self {
Self(
a[0] as u64,
(a[0] >> 64) as u64,
a[1] as u64,
(a[1] >> 64) as u64,
)
}
fn from(a: [u128; 2]) -> Self {
Self(
a[0] as u64,
(a[0] >> 64) as u64,
a[1] as u64,
(a[1] >> 64) as u64,
)
}
}
impl From<U256> for [u64; 4] {
fn from(a: U256) -> [u64; 4] {
[a.0, a.1, a.2, a.3]
}
fn from(a: U256) -> [u64; 4] {
[a.0, a.1, a.2, a.3]
}
}
impl From<U256> for [u128; 2] {
fn from(a: U256) -> [u128; 2] {
[
(a.0 as u128) + ((a.1 as u128) << 64),
(a.2 as u128) + ((a.3 as u128) << 64),
]
}
fn from(a: U256) -> [u128; 2] {
[
(a.0 as u128) + ((a.1 as u128) << 64),
(a.2 as u128) + ((a.3 as u128) << 64),
]
}
}
impl From<&U256> for u128 {
fn from(v: &U256) -> Self {
debug_assert!(v.2 == 0 && v.3 == 0);
(v.0 as u128) + ((v.1 as u128) << 64)
}
fn from(v: &U256) -> Self {
debug_assert!(v.2 == 0 && v.3 == 0);
(v.0 as u128) + ((v.1 as u128) << 64)
}
}
impl Not for U256 {
type Output = Self;
type Output = Self;
fn not(self) -> Self {
Self(!self.0, !self.1, !self.2, !self.3)
}
fn not(self) -> Self {
Self(!self.0, !self.1, !self.2, !self.3)
}
}
impl Shr<usize> for U256 {
type Output = Self;
type Output = Self;
fn shr(self, rhs: usize) -> Self {
let mut r = self;
r >>= rhs;
r
}
fn shr(self, rhs: usize) -> Self {
let mut r = self;
r >>= rhs;
r
}
}
impl ShrAssign<usize> for U256 {
fn shr_assign(&mut self, rhs: usize) {
debug_assert!(rhs < 256);
fn shr_assign(&mut self, rhs: usize) {
debug_assert!(rhs < 256);
if rhs >= 192 {
self.0 = self.3 >> (rhs - 192);
self.1 = 0;
self.2 = 0;
self.3 = 0;
} else if rhs > 128 {
self.0 = (self.2 >> (rhs - 128)) | (self.3 << (192 - rhs));
self.1 = self.3 >> (rhs - 128);
self.2 = 0;
self.3 = 0;
} else if rhs == 128 {
self.0 = self.2;
self.1 = self.3;
self.2 = 0;
self.3 = 0;
} else if rhs > 64 {
self.0 = (self.1 >> (rhs - 64)) | (self.2 << (128 - rhs));
self.1 = (self.2 >> (rhs - 64)) | (self.3 << (128 - rhs));
self.2 = self.3 >> (rhs - 64);
self.3 = 0;
} else if rhs == 64 {
self.0 = self.1;
self.1 = self.2;
self.2 = self.3;
self.3 = 0;
} else if rhs > 0 {
self.0 = (self.0 >> rhs) | (self.1 << (64 - rhs));
self.1 = (self.1 >> rhs) | (self.2 << (64 - rhs));
self.2 = (self.2 >> rhs) | (self.3 << (64 - rhs));
self.3 >>= rhs;
}
}
if rhs >= 192 {
self.0 = self.3 >> (rhs - 192);
self.1 = 0;
self.2 = 0;
self.3 = 0;
} else if rhs > 128 {
self.0 = (self.2 >> (rhs - 128)) | (self.3 << (192 - rhs));
self.1 = self.3 >> (rhs - 128);
self.2 = 0;
self.3 = 0;
} else if rhs == 128 {
self.0 = self.2;
self.1 = self.3;
self.2 = 0;
self.3 = 0;
} else if rhs > 64 {
self.0 = (self.1 >> (rhs - 64)) | (self.2 << (128 - rhs));
self.1 = (self.2 >> (rhs - 64)) | (self.3 << (128 - rhs));
self.2 = self.3 >> (rhs - 64);
self.3 = 0;
} else if rhs == 64 {
self.0 = self.1;
self.1 = self.2;
self.2 = self.3;
self.3 = 0;
} else if rhs > 0 {
self.0 = (self.0 >> rhs) | (self.1 << (64 - rhs));
self.1 = (self.1 >> rhs) | (self.2 << (64 - rhs));
self.2 = (self.2 >> rhs) | (self.3 << (64 - rhs));
self.3 >>= rhs;
}
}
}
#[cfg(test)]
mod tests {
use super::U256;
use super::U256;
#[test]
fn zero() {
assert_eq!(u128::from(&U256::zero()), 0u128);
}
#[test]
fn zero() {
assert_eq!(u128::from(&U256::zero()), 0u128);
}
proptest! {
proptest! {
#[test]
fn u128(a: u128) {
prop_assert_eq!(a, u128::from(&U256::from([a, 0])));
}
#[test]
fn u128(a: u128) {
prop_assert_eq!(a, u128::from(&U256::from([a, 0])));
}
#[test]
fn from_into_u64(a: u64, b: u64, c: u64, d:u64) {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d])), [a, b, c, d]);
}
#[test]
fn from_into_u64(a: u64, b: u64, c: u64, d:u64) {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d])), [a, b, c, d]);
}
#[test]
fn from_into_u128(a: u128, b: u128) {
prop_assert_eq!(<[u128; 2]>::from(U256::from([a, b])), [a, b]);
}
#[test]
fn from_into_u128(a: u128, b: u128) {
prop_assert_eq!(<[u128; 2]>::from(U256::from([a, b])), [a, b]);
}
#[test]
fn shift(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 0), [a, b, c, d]);
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 64), [b, c, d, 0]);
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 128), [c, d, 0, 0]);
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 192), [d, 0, 0, 0]);
#[test]
fn shift(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 0), [a, b, c, d]);
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 64), [b, c, d, 0]);
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 128), [c, d, 0, 0]);
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> 192), [d, 0, 0, 0]);
prop_assume!(shift % 64 != 0);
if shift < 64 {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
} else if shift < 128 {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
} else if shift < 192 {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
} else {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(d >> (shift - 192)), 0, 0, 0]);
}
}
prop_assume!(shift % 64 != 0);
if shift < 64 {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
} else if shift < 128 {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
} else if shift < 192 {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
} else {
prop_assert_eq!(<[u64; 4]>::from(U256::from([a, b, c, d]) >> shift), [(d >> (shift - 192)), 0, 0, 0]);
}
}
#[test]
fn shift_assign(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
let mut u = U256::from([a, b, c, d]);
#[test]
fn shift_assign(a: u64, b: u64, c: u64, d:u64, shift in 0..256usize) {
let mut u = U256::from([a, b, c, d]);
u >>= 0;
prop_assert_eq!(<[u64; 4]>::from(u), [a, b, c, d]);
u >>= 0;
prop_assert_eq!(<[u64; 4]>::from(u), [a, b, c, d]);
u >>= 64;
prop_assert_eq!(<[u64; 4]>::from(u), [b, c, d, 0]);
u >>= 64;
prop_assert_eq!(<[u64; 4]>::from(u), [b, c, d, 0]);
u = U256::from([a, b, c, d]);
u >>= 128;
prop_assert_eq!(<[u64; 4]>::from(u), [c, d, 0, 0]);
u = U256::from([a, b, c, d]);
u >>= 128;
prop_assert_eq!(<[u64; 4]>::from(u), [c, d, 0, 0]);
u = U256::from([a, b, c, d]);
u >>= 192;
prop_assert_eq!(<[u64; 4]>::from(u), [d, 0, 0, 0]);
u = U256::from([a, b, c, d]);
u >>= 192;
prop_assert_eq!(<[u64; 4]>::from(u), [d, 0, 0, 0]);
prop_assume!(shift % 64 != 0);
u = U256::from([a, b, c, d]);
u >>= shift;
if shift < 64 {
prop_assert_eq!(<[u64; 4]>::from(u), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
} else if shift < 128 {
prop_assert_eq!(<[u64; 4]>::from(u), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
} else if shift < 192 {
prop_assert_eq!(<[u64; 4]>::from(u), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
} else {
prop_assert_eq!(<[u64; 4]>::from(u), [(d >> (shift - 192)), 0, 0, 0]);
}
}
}
prop_assume!(shift % 64 != 0);
u = U256::from([a, b, c, d]);
u >>= shift;
if shift < 64 {
prop_assert_eq!(<[u64; 4]>::from(u), [(a >> shift) | (b << (64 - shift)), (b >> shift) | (c << (64 - shift)), (c >> shift) | (d << (64 - shift)), d >> shift]);
} else if shift < 128 {
prop_assert_eq!(<[u64; 4]>::from(u), [(b >> (shift - 64)) | (c << (128 - shift)), (c >> (shift - 64)) | (d << (128 - shift)), (d >> (shift - 64)), 0]);
} else if shift < 192 {
prop_assert_eq!(<[u64; 4]>::from(u), [(c >> (shift - 128)) | (d << (192 - shift)), (d >> (shift - 128)), 0, 0]);
} else {
prop_assert_eq!(<[u64; 4]>::from(u), [(d >> (shift - 192)), 0, 0, 0]);
}
}
}
}

View File

@@ -1,7 +1,7 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use fhe::bfv::{
BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Multiplicator, Plaintext, PublicKey,
RelinearizationKey, SecretKey,
BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Multiplicator, Plaintext, PublicKey,
RelinearizationKey, SecretKey,
};
use fhe_math::rns::{RnsContext, ScalingFactor};
use fhe_math::zq::primes::generate_prime;
@@ -13,275 +13,275 @@ use rand::{rngs::OsRng, thread_rng};
use std::time::Duration;
pub fn bfv_benchmark(c: &mut Criterion) {
let mut rng = thread_rng();
let mut group = c.benchmark_group("bfv");
group.sample_size(10);
group.warm_up_time(Duration::from_millis(600));
group.measurement_time(Duration::from_millis(1000));
let mut rng = thread_rng();
let mut group = c.benchmark_group("bfv");
group.sample_size(10);
group.warm_up_time(Duration::from_millis(600));
group.measurement_time(Duration::from_millis(1000));
for par in BfvParameters::default_parameters_128(20) {
let sk = SecretKey::random(&par, &mut OsRng);
let ek = if par.moduli().len() > 1 {
Some(
EvaluationKeyBuilder::new(&sk)
.unwrap()
.enable_inner_sum()
.unwrap()
.enable_column_rotation(1)
.unwrap()
.enable_expansion(ilog2(par.degree() as u64))
.unwrap()
.build(&mut rng)
.unwrap(),
)
} else {
None
};
for par in BfvParameters::default_parameters_128(20) {
let sk = SecretKey::random(&par, &mut OsRng);
let ek = if par.moduli().len() > 1 {
Some(
EvaluationKeyBuilder::new(&sk)
.unwrap()
.enable_inner_sum()
.unwrap()
.enable_column_rotation(1)
.unwrap()
.enable_expansion(ilog2(par.degree() as u64))
.unwrap()
.build(&mut rng)
.unwrap(),
)
} else {
None
};
let rk = if par.moduli().len() > 1 {
Some(RelinearizationKey::new(&sk, &mut rng).unwrap())
} else {
None
};
let rk = if par.moduli().len() > 1 {
Some(RelinearizationKey::new(&sk, &mut rng).unwrap())
} else {
None
};
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par).unwrap();
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), &par).unwrap();
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
let c2: Ciphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par).unwrap();
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), &par).unwrap();
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
let c2: Ciphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
let q = par.moduli_sizes().iter().sum::<usize>();
let q = par.moduli_sizes().iter().sum::<usize>();
group.bench_function(
BenchmarkId::new("keygen_sk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| SecretKey::random(&par, &mut OsRng));
},
);
group.bench_function(
BenchmarkId::new("keygen_sk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| SecretKey::random(&par, &mut OsRng));
},
);
group.bench_function(
BenchmarkId::new("keygen_pk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| PublicKey::new(&sk, &mut rng));
},
);
group.bench_function(
BenchmarkId::new("keygen_pk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| PublicKey::new(&sk, &mut rng));
},
);
group.bench_function(
BenchmarkId::new("keygen_rk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| RelinearizationKey::new(&sk, &mut rng));
},
);
group.bench_function(
BenchmarkId::new("keygen_rk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| RelinearizationKey::new(&sk, &mut rng));
},
);
group.bench_function(
BenchmarkId::new("encode_poly", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), &par));
},
);
group.bench_function(
BenchmarkId::new("encode_poly", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), &par));
},
);
group.bench_function(
BenchmarkId::new("encode_simd", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par));
},
);
group.bench_function(
BenchmarkId::new("encode_simd", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), &par));
},
);
group.bench_function(
BenchmarkId::new("encrypt_sk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| {
let _: fhe::Result<Ciphertext> = sk.try_encrypt(&pt1, &mut rng);
});
},
);
group.bench_function(
BenchmarkId::new("encrypt_sk", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| {
let _: fhe::Result<Ciphertext> = sk.try_encrypt(&pt1, &mut rng);
});
},
);
group.bench_function(
BenchmarkId::new("add_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 + &c2);
},
);
group.bench_function(
BenchmarkId::new("add_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 + &c2);
},
);
group.bench_function(
BenchmarkId::new("add_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 += &c2);
},
);
group.bench_function(
BenchmarkId::new("add_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 += &c2);
},
);
group.bench_function(
BenchmarkId::new("add_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 + &pt2);
},
);
group.bench_function(
BenchmarkId::new("add_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 + &pt2);
},
);
group.bench_function(
BenchmarkId::new("add_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 += &pt2);
},
);
group.bench_function(
BenchmarkId::new("add_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 += &pt2);
},
);
group.bench_function(
BenchmarkId::new("sub_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 - &c2);
},
);
group.bench_function(
BenchmarkId::new("sub_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 - &c2);
},
);
group.bench_function(
BenchmarkId::new("sub_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 -= &c2);
},
);
group.bench_function(
BenchmarkId::new("sub_assign_ct", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 -= &c2);
},
);
group.bench_function(
BenchmarkId::new("sub_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 - &pt2);
},
);
group.bench_function(
BenchmarkId::new("sub_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = &c1 - &pt2);
},
);
group.bench_function(
BenchmarkId::new("sub_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 -= &pt2);
},
);
group.bench_function(
BenchmarkId::new("sub_assign_pt", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 -= &pt2);
},
);
group.bench_function(
BenchmarkId::new("neg", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = -&c2);
},
);
group.bench_function(
BenchmarkId::new("neg", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = -&c2);
},
);
let mut c3 = &c1 * &c1;
let c3_clone = c3.clone();
if let Some(rk) = rk.as_ref() {
group.bench_function(
BenchmarkId::new("relinearize", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| {
assert!(rk.relinearizes(&mut c3).is_ok());
c3 = c3_clone.clone();
});
},
);
}
let mut c3 = &c1 * &c1;
let c3_clone = c3.clone();
if let Some(rk) = rk.as_ref() {
group.bench_function(
BenchmarkId::new("relinearize", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| {
assert!(rk.relinearizes(&mut c3).is_ok());
c3 = c3_clone.clone();
});
},
);
}
if let Some(ek) = ek {
group.bench_function(
BenchmarkId::new("rotate_rows", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = ek.rotates_rows(&c1).unwrap());
},
);
if let Some(ek) = ek {
group.bench_function(
BenchmarkId::new("rotate_rows", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = ek.rotates_rows(&c1).unwrap());
},
);
group.bench_function(
BenchmarkId::new("rotate_columns", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = ek.rotates_columns_by(&c1, 1).unwrap());
},
);
group.bench_function(
BenchmarkId::new("rotate_columns", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = ek.rotates_columns_by(&c1, 1).unwrap());
},
);
group.bench_function(
BenchmarkId::new("inner_sum", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = ek.computes_inner_sum(&c1).unwrap());
},
);
group.bench_function(
BenchmarkId::new("inner_sum", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| c1 = ek.computes_inner_sum(&c1).unwrap());
},
);
for i in 1..=ilog2(par.degree() as u64) {
if par.degree() > 2048 && i > 4 {
continue; // Skip slow benchmarks
}
group.bench_function(
BenchmarkId::new(
format!("expand_{i}"),
format!("n={}/log(q)={}", par.degree(), q),
),
|b| {
b.iter(|| ek.expands(&c1, 1 << i).unwrap());
},
);
}
}
for i in 1..=ilog2(par.degree() as u64) {
if par.degree() > 2048 && i > 4 {
continue; // Skip slow benchmarks
}
group.bench_function(
BenchmarkId::new(
format!("expand_{i}"),
format!("n={}/log(q)={}", par.degree(), q),
),
|b| {
b.iter(|| ek.expands(&c1, 1 << i).unwrap());
},
);
}
}
group.bench_function(
BenchmarkId::new("mul", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| &c1 * &c2);
},
);
group.bench_function(
BenchmarkId::new("mul", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| &c1 * &c2);
},
);
group.bench_function(
BenchmarkId::new("square", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| &c1 * &c1);
},
);
group.bench_function(
BenchmarkId::new("square", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| &c1 * &c1);
},
);
if let Some(rk) = rk.as_ref() {
group.bench_function(
BenchmarkId::new(
"mul_then_relinearize",
format!("n={}/log(q)={}", par.degree(), q),
),
|b| {
b.iter(|| {
c3 = &c1 * &c2;
assert!(rk.relinearizes(&mut c3).is_ok());
});
},
);
if let Some(rk) = rk.as_ref() {
group.bench_function(
BenchmarkId::new(
"mul_then_relinearize",
format!("n={}/log(q)={}", par.degree(), q),
),
|b| {
b.iter(|| {
c3 = &c1 * &c2;
assert!(rk.relinearizes(&mut c3).is_ok());
});
},
);
// Default multiplication method
let multiplicator = Multiplicator::default(rk).unwrap();
// Default multiplication method
let multiplicator = Multiplicator::default(rk).unwrap();
group.bench_function(
BenchmarkId::new("mul_and_relin", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
},
);
group.bench_function(
BenchmarkId::new("mul_and_relin", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
},
);
// Second multiplication option.
let nmoduli = div_ceil(q, 62);
let mut extended_basis = par.moduli().to_vec();
let mut upper_bound = u64::MAX >> 2;
while extended_basis.len() != nmoduli + par.moduli().len() {
upper_bound = generate_prime(62, 2 * par.degree() as u64, upper_bound).unwrap();
if !extended_basis.contains(&upper_bound) {
extended_basis.push(upper_bound)
}
}
let rns_q = RnsContext::new(&extended_basis[..par.moduli().len()]).unwrap();
let rns_p = RnsContext::new(&extended_basis[par.moduli().len()..]).unwrap();
let mut multiplicator = Multiplicator::new(
ScalingFactor::one(),
ScalingFactor::new(rns_p.modulus(), rns_q.modulus()),
&extended_basis,
ScalingFactor::new(&BigUint::from(par.plaintext()), rns_p.modulus()),
&par,
)
.unwrap();
assert!(multiplicator.enable_relinearization(rk).is_ok());
group.bench_function(
BenchmarkId::new(
"mul_and_relin_2",
format!("n={}/log(q)={}", par.degree(), q),
),
|b| {
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
},
);
}
}
// Second multiplication option.
let nmoduli = div_ceil(q, 62);
let mut extended_basis = par.moduli().to_vec();
let mut upper_bound = u64::MAX >> 2;
while extended_basis.len() != nmoduli + par.moduli().len() {
upper_bound = generate_prime(62, 2 * par.degree() as u64, upper_bound).unwrap();
if !extended_basis.contains(&upper_bound) {
extended_basis.push(upper_bound)
}
}
let rns_q = RnsContext::new(&extended_basis[..par.moduli().len()]).unwrap();
let rns_p = RnsContext::new(&extended_basis[par.moduli().len()..]).unwrap();
let mut multiplicator = Multiplicator::new(
ScalingFactor::one(),
ScalingFactor::new(rns_p.modulus(), rns_q.modulus()),
&extended_basis,
ScalingFactor::new(&BigUint::from(par.plaintext()), rns_p.modulus()),
&par,
)
.unwrap();
assert!(multiplicator.enable_relinearization(rk).is_ok());
group.bench_function(
BenchmarkId::new(
"mul_and_relin_2",
format!("n={}/log(q)={}", par.degree(), q),
),
|b| {
b.iter(|| assert!(multiplicator.multiply(&c1, &c2).is_ok()));
},
);
}
}
group.finish();
group.finish();
}
criterion_group!(bfv, bfv_benchmark);

View File

@@ -6,66 +6,66 @@ use rand::{rngs::OsRng, thread_rng};
use std::time::Duration;
pub fn bfv_benchmark(c: &mut Criterion) {
let mut rng = thread_rng();
let mut group = c.benchmark_group("bfv_optimized_ops");
group.sample_size(10);
group.warm_up_time(Duration::from_secs(1));
group.measurement_time(Duration::from_secs(1));
let mut rng = thread_rng();
let mut group = c.benchmark_group("bfv_optimized_ops");
group.sample_size(10);
group.warm_up_time(Duration::from_secs(1));
group.measurement_time(Duration::from_secs(1));
for par in &BfvParameters::default_parameters_128(20)[2..] {
for size in [10, 128, 1000] {
let sk = SecretKey::random(par, &mut OsRng);
let pt1 =
Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), par).unwrap();
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
for par in &BfvParameters::default_parameters_128(20)[2..] {
for size in [10, 128, 1000] {
let sk = SecretKey::random(par, &mut OsRng);
let pt1 =
Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::poly(), par).unwrap();
let mut c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
let ct_vec = (0..size)
.map(|i| {
let pt =
Plaintext::try_encode(&(i..16u64).collect_vec(), Encoding::poly(), par)
.unwrap();
sk.try_encrypt(&pt, &mut rng).unwrap()
})
.collect_vec();
let pt_vec = (0..size)
.map(|i| {
Plaintext::try_encode(&(i..39u64).collect_vec(), Encoding::poly(), par).unwrap()
})
.collect_vec();
let ct_vec = (0..size)
.map(|i| {
let pt =
Plaintext::try_encode(&(i..16u64).collect_vec(), Encoding::poly(), par)
.unwrap();
sk.try_encrypt(&pt, &mut rng).unwrap()
})
.collect_vec();
let pt_vec = (0..size)
.map(|i| {
Plaintext::try_encode(&(i..39u64).collect_vec(), Encoding::poly(), par).unwrap()
})
.collect_vec();
group.bench_function(
BenchmarkId::new(
"dot_product/naive",
format!(
"size={}/degree={}/logq={}",
size,
par.degree(),
par.moduli_sizes().iter().sum::<usize>()
),
),
|b| {
b.iter(|| izip!(&ct_vec, &pt_vec).for_each(|(cti, pti)| c1 += &(cti * pti)));
},
);
group.bench_function(
BenchmarkId::new(
"dot_product/naive",
format!(
"size={}/degree={}/logq={}",
size,
par.degree(),
par.moduli_sizes().iter().sum::<usize>()
),
),
|b| {
b.iter(|| izip!(&ct_vec, &pt_vec).for_each(|(cti, pti)| c1 += &(cti * pti)));
},
);
group.bench_function(
BenchmarkId::new(
"dot_product/opt",
format!(
"size={}/degree={}/logq={}",
size,
par.degree(),
par.moduli_sizes().iter().sum::<usize>()
),
),
|b| {
b.iter(|| dot_product_scalar(ct_vec.iter(), pt_vec.iter()));
},
);
}
}
group.bench_function(
BenchmarkId::new(
"dot_product/opt",
format!(
"size={}/degree={}/logq={}",
size,
par.degree(),
par.moduli_sizes().iter().sum::<usize>()
),
),
|b| {
b.iter(|| dot_product_scalar(ct_vec.iter(), pt_vec.iter()));
},
);
}
}
group.finish();
group.finish();
}
criterion_group!(bfv, bfv_benchmark);

View File

@@ -6,30 +6,30 @@ use rand::{rngs::OsRng, thread_rng};
use std::time::Duration;
pub fn bfv_rgsw_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("bfv_rgsw");
group.sample_size(10);
group.warm_up_time(Duration::from_secs(1));
group.measurement_time(Duration::from_secs(1));
let mut group = c.benchmark_group("bfv_rgsw");
group.sample_size(10);
group.warm_up_time(Duration::from_secs(1));
group.measurement_time(Duration::from_secs(1));
for par in &BfvParameters::default_parameters_128(20)[2..] {
let mut rng = thread_rng();
let sk = SecretKey::random(par, &mut OsRng);
for par in &BfvParameters::default_parameters_128(20)[2..] {
let mut rng = thread_rng();
let sk = SecretKey::random(par, &mut OsRng);
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), par).unwrap();
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), par).unwrap();
let c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
let c2: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
let q = par.moduli_sizes().iter().sum::<usize>();
let pt1 = Plaintext::try_encode(&(1..16u64).collect_vec(), Encoding::simd(), par).unwrap();
let pt2 = Plaintext::try_encode(&(3..39u64).collect_vec(), Encoding::simd(), par).unwrap();
let c1: Ciphertext = sk.try_encrypt(&pt1, &mut rng).unwrap();
let c2: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng).unwrap();
let q = par.moduli_sizes().iter().sum::<usize>();
group.bench_function(
BenchmarkId::new("external", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| &c1 * &c2);
},
);
}
group.bench_function(
BenchmarkId::new("external", format!("n={}/log(q)={}", par.degree(), q)),
|b| {
b.iter(|| &c1 * &c2);
},
);
}
group.finish();
group.finish();
}
criterion_group!(bfv_rgsw, bfv_rgsw_benchmark);

View File

@@ -11,260 +11,260 @@ mod util;
use console::style;
use fhe::bfv;
use fhe_traits::{
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize,
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize,
};
use fhe_util::{ilog2, inverse, transcode_to_bytes};
use indicatif::HumanBytes;
use rand::{rngs::OsRng, thread_rng, RngCore};
use std::{env, error::Error, process::exit, sync::Arc};
use util::{
encode_database, generate_database, number_elements_per_plaintext,
timeit::{timeit, timeit_n},
encode_database, generate_database, number_elements_per_plaintext,
timeit::{timeit, timeit_n},
};
fn print_notice_and_exit(max_element_size: usize, error: Option<String>) {
println!(
"{} MulPIR with fhe.rs",
style(" overview:").magenta().bold()
);
println!(
"{} mulpir- [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
style(" usage:").magenta().bold()
);
println!(
"{} {} must be at least 1, and {} must be between 1 and {}",
style("constraints:").magenta().bold(),
style("database_size").blue(),
style("element_size").blue(),
max_element_size
);
if let Some(error) = error {
println!("{} {}", style(" error:").red().bold(), error);
}
exit(0);
println!(
"{} MulPIR with fhe.rs",
style(" overview:").magenta().bold()
);
println!(
"{} mulpir- [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
style(" usage:").magenta().bold()
);
println!(
"{} {} must be at least 1, and {} must be between 1 and {}",
style("constraints:").magenta().bold(),
style("database_size").blue(),
style("element_size").blue(),
max_element_size
);
if let Some(error) = error {
println!("{} {}", style(" error:").red().bold(), error);
}
exit(0);
}
fn main() -> Result<(), Box<dyn Error>> {
// We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf.
let degree = 8192;
let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1;
let moduli_sizes = [50, 55, 55];
// We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf.
let degree = 8192;
let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1;
let moduli_sizes = [50, 55, 55];
// Compute what is the maximum byte-length of an element to fit within one
// ciphertext. Each coefficient of the ciphertext polynomial can contain
// floor(log2(plaintext_modulus)) bits.
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
// Compute what is the maximum byte-length of an element to fit within one
// ciphertext. Each coefficient of the ciphertext polynomial can contain
// floor(log2(plaintext_modulus)) bits.
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
// This executable is a command line tool which enables to specify different
// database and element sizes.
let args: Vec<String> = env::args().skip(1).collect();
// This executable is a command line tool which enables to specify different
// database and element sizes.
let args: Vec<String> = env::args().skip(1).collect();
// Print the help if requested.
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
print_notice_and_exit(max_element_size, None)
}
// Print the help if requested.
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
print_notice_and_exit(max_element_size, None)
}
// Use the default values from <https://eprint.iacr.org/2019/1483.pdf>.
let mut database_size = 1 << 20;
let mut elements_size = 288;
// Use the default values from <https://eprint.iacr.org/2019/1483.pdf>.
let mut database_size = 1 << 20;
let mut elements_size = 288;
// Update the database size and/or element size depending on the arguments
// provided.
for arg in &args {
if arg.starts_with("--database_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--database_size` command".to_string()),
)
} else {
database_size = a[0].parse::<usize>().unwrap()
}
} else if arg.starts_with("--element_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--element_size` command".to_string()),
)
} else {
elements_size = a[0].parse::<usize>().unwrap()
}
} else {
print_notice_and_exit(
max_element_size,
Some(format!("Unrecognized command: {arg}")),
)
}
}
// Update the database size and/or element size depending on the arguments
// provided.
for arg in &args {
if arg.starts_with("--database_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--database_size` command".to_string()),
)
} else {
database_size = a[0].parse::<usize>().unwrap()
}
} else if arg.starts_with("--element_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--element_size` command".to_string()),
)
} else {
elements_size = a[0].parse::<usize>().unwrap()
}
} else {
print_notice_and_exit(
max_element_size,
Some(format!("Unrecognized command: {arg}")),
)
}
}
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
print_notice_and_exit(
max_element_size,
Some("Element or database sizes out of bound".to_string()),
)
}
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
print_notice_and_exit(
max_element_size,
Some("Element or database sizes out of bound".to_string()),
)
}
// The parameters are within bound, let's go! Let's first display some
// information about the database.
println!("# MulPIR with fhe.rs");
println!(
"database of {}",
HumanBytes((database_size * elements_size) as u64)
);
println!("\tdatabase_size = {database_size}");
println!("\telements_size = {elements_size}");
// The parameters are within bound, let's go! Let's first display some
// information about the database.
println!("# MulPIR with fhe.rs");
println!(
"database of {}",
HumanBytes((database_size * elements_size) as u64)
);
println!("\tdatabase_size = {database_size}");
println!("\telements_size = {elements_size}");
// Generation of a random database.
let database = timeit!("Database generation", {
generate_database(database_size, elements_size)
});
// Generation of a random database.
let database = timeit!("Database generation", {
generate_database(database_size, elements_size)
});
// Let's generate the BFV parameters structure.
let params = timeit!(
"Parameters generation",
Arc::new(
bfv::BfvParametersBuilder::new()
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build()
.unwrap()
)
);
// Let's generate the BFV parameters structure.
let params = timeit!(
"Parameters generation",
Arc::new(
bfv::BfvParametersBuilder::new()
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build()
.unwrap()
)
);
// Proprocess the database on the server side: the database will be reshaped
// so as to pack as many values as possible in every row so that it fits in one
// ciphertext, and each element will be encoded as a polynomial in Ntt
// representation.
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
encode_database(&database, params.clone(), 1)
});
// Proprocess the database on the server side: the database will be reshaped
// so as to pack as many values as possible in every row so that it fits in one
// ciphertext, and each element will be encoded as a polynomial in Ntt
// representation.
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
encode_database(&database, params.clone(), 1)
});
// Client setup: the client generates a secret key, an evaluation key for
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a
// relinearization key.
let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", {
let sk = bfv::SecretKey::random(&params, &mut OsRng);
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
println!("level = {level}");
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
.enable_expansion(level)?
.build(&mut thread_rng())?;
let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut thread_rng())?;
let ek_expansion_serialized = ek_expansion.to_bytes();
let rk_serialized = rk.to_bytes();
(sk, ek_expansion_serialized, rk_serialized)
});
println!(
"📄 Evaluation key (expansion): {}",
HumanBytes(ek_expansion_serialized.len() as u64)
);
println!(
"📄 Relinearization key: {}",
HumanBytes(rk_serialized.len() as u64)
);
// Client setup: the client generates a secret key, an evaluation key for
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a
// relinearization key.
let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", {
let sk = bfv::SecretKey::random(&params, &mut OsRng);
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
println!("level = {level}");
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
.enable_expansion(level)?
.build(&mut thread_rng())?;
let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut thread_rng())?;
let ek_expansion_serialized = ek_expansion.to_bytes();
let rk_serialized = rk.to_bytes();
(sk, ek_expansion_serialized, rk_serialized)
});
println!(
"📄 Evaluation key (expansion): {}",
HumanBytes(ek_expansion_serialized.len() as u64)
);
println!(
"📄 Relinearization key: {}",
HumanBytes(rk_serialized.len() as u64)
);
// Server setup: the server receives the evaluation and relinearization keys and
// deserializes them.
let (ek_expansion, rk) = timeit!("Server setup", {
(
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, &params)?,
bfv::RelinearizationKey::from_bytes(&rk_serialized, &params)?,
)
});
// Server setup: the server receives the evaluation and relinearization keys and
// deserializes them.
let (ek_expansion, rk) = timeit!("Server setup", {
(
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, &params)?,
bfv::RelinearizationKey::from_bytes(&rk_serialized, &params)?,
)
});
// Client query: when the client wants to retrieve the `index`-th row of the
// original database, it first computes to which row it corresponds in the
// original database, and then encrypt a selection vector with 0 everywhere,
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
// The ciphertext is set at level `1`, which means that one of the three moduli
// has been dropped already; the reason is that the expansion will happen at
// level 0 (with all three moduli) and then one of the moduli will be dropped
// to reduce the noise.
let index = (thread_rng().next_u64() as usize) % database_size;
let query = timeit!("Client query", {
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
let query_index = index
/ number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
let mut pt = vec![0u64; dim1 + dim2];
let inv = inverse(1 << level, plaintext_modulus).unwrap();
pt[query_index / dim2] = inv;
pt[dim1 + (query_index % dim2)] = inv;
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), &params)?;
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
query.to_bytes()
});
println!("📄 Query: {}", HumanBytes(query.len() as u64));
// Client query: when the client wants to retrieve the `index`-th row of the
// original database, it first computes to which row it corresponds in the
// original database, and then encrypt a selection vector with 0 everywhere,
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
// The ciphertext is set at level `1`, which means that one of the three moduli
// has been dropped already; the reason is that the expansion will happen at
// level 0 (with all three moduli) and then one of the moduli will be dropped
// to reduce the noise.
let index = (thread_rng().next_u64() as usize) % database_size;
let query = timeit!("Client query", {
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
let query_index = index
/ number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
let mut pt = vec![0u64; dim1 + dim2];
let inv = inverse(1 << level, plaintext_modulus).unwrap();
pt[query_index / dim2] = inv;
pt[dim1 + (query_index % dim2)] = inv;
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), &params)?;
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
query.to_bytes()
});
println!("📄 Query: {}", HumanBytes(query.len() as u64));
// Server response: The server receives the query, and after deserializing it,
// performs the following steps:
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
// If the client created the query correctly, the server will have obtained
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
// `dim1 + j`th ones encrypting `1`.
// 2- It computes the inner product of the first `dim1` ciphertexts with the
// columns if the database viewed as a dim1 * dim2 matrix.
// 3- It then multiplies the column of ciphertexts with the next `dim2`
// ciphertexts obtained after expansion of the query, then relinearize and
// modulus switch to the latest modulus to optimize communication.
// The operation is done `5` times to compute an average response time.
let response = timeit_n!("Server response", 5, {
let start = std::time::Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, &params)?;
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {:?}", start.elapsed());
// Server response: The server receives the query, and after deserializing it,
// performs the following steps:
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
// If the client created the query correctly, the server will have obtained
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
// `dim1 + j`th ones encrypting `1`.
// 2- It computes the inner product of the first `dim1` ciphertexts with the
// columns if the database viewed as a dim1 * dim2 matrix.
// 3- It then multiplies the column of ciphertexts with the next `dim2`
// ciphertexts obtained after expansion of the query, then relinearize and
// modulus switch to the latest modulus to optimize communication.
// The operation is done `5` times to compute an average response time.
let response = timeit_n!("Server response", 5, {
let start = std::time::Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, &params)?;
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {:?}", start.elapsed());
let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch =
move |i, database: &[bfv::Plaintext]| -> fhe::Result<bfv::Ciphertext> {
let column = database.iter().skip(i).step_by(dim2);
bfv::dot_product_scalar(query_vec.iter(), column)
};
let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch =
move |i, database: &[bfv::Plaintext]| -> fhe::Result<bfv::Ciphertext> {
let column = database.iter().skip(i).step_by(dim2);
bfv::dot_product_scalar(query_vec.iter(), column)
};
let mut out = bfv::Ciphertext::zero(&params);
for (i, ci) in expanded_query[dim1..].iter().enumerate() {
out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci)
}
rk.relinearizes(&mut out)?;
out.mod_switch_to_last_level();
out.to_bytes()
});
println!("📄 Response: {}", HumanBytes(response.len() as u64));
let mut out = bfv::Ciphertext::zero(&params);
for (i, ci) in expanded_query[dim1..].iter().enumerate() {
out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci)
}
rk.relinearizes(&mut out)?;
out.mod_switch_to_last_level();
out.to_bytes()
});
println!("📄 Response: {}", HumanBytes(response.len() as u64));
// Client processing: Upon reception of the response, the client decrypts.
// Finally, it outputs the plaintext bytes, offset by the correct value
// (remember the database was reshaped to maximize how many elements) were
// embedded in a single ciphertext.
let answer = timeit!("Client answer", {
let response = bfv::Ciphertext::from_bytes(&response, &params).unwrap();
// Client processing: Upon reception of the response, the client decrypts.
// Finally, it outputs the plaintext bytes, offset by the correct value
// (remember the database was reshaped to maximize how many elements) were
// embedded in a single ciphertext.
let answer = timeit!("Client answer", {
let response = bfv::Ciphertext::from_bytes(&response, &params).unwrap();
let pt = sk.try_decrypt(&response).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
let offset = index
% number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
let pt = sk.try_decrypt(&response).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
let offset = index
% number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
println!("Noise in response: {:?}", unsafe {
sk.measure_noise(&response)
});
println!("Noise in response: {:?}", unsafe {
sk.measure_noise(&response)
});
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
});
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
});
assert_eq!(&database[index], &answer);
assert_eq!(&database[index], &answer);
Ok(())
Ok(())
}

View File

@@ -12,8 +12,8 @@ use console::style;
use fhe::bfv;
use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation};
use fhe_traits::{
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime,
FheEncrypter, Serialize,
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime,
FheEncrypter, Serialize,
};
use fhe_util::{div_ceil, ilog2, inverse, transcode_bidirectional, transcode_to_bytes};
use indicatif::HumanBytes;
@@ -21,323 +21,323 @@ use itertools::Itertools;
use rand::{rngs::OsRng, thread_rng, RngCore};
use std::{env, error::Error, process::exit, sync::Arc};
use util::{
encode_database, generate_database, number_elements_per_plaintext,
timeit::{timeit, timeit_n},
encode_database, generate_database, number_elements_per_plaintext,
timeit::{timeit, timeit_n},
};
fn print_notice_and_exit(max_element_size: usize, error: Option<String>) {
println!(
"{} SealPIR with fhe.rs",
style(" overview:").magenta().bold()
);
println!(
"{} sealpir [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
style(" usage:").magenta().bold()
);
println!(
"{} {} must be at least 1, and {} must be between 1 and {}",
style("constraints:").magenta().bold(),
style("database_size").blue(),
style("element_size").blue(),
max_element_size
);
if let Some(error) = error {
println!("{} {}", style(" error:").red().bold(), error);
}
exit(0);
println!(
"{} SealPIR with fhe.rs",
style(" overview:").magenta().bold()
);
println!(
"{} sealpir [-h] [--help] [--database_size=<value>] [--element_size=<value>]",
style(" usage:").magenta().bold()
);
println!(
"{} {} must be at least 1, and {} must be between 1 and {}",
style("constraints:").magenta().bold(),
style("database_size").blue(),
style("element_size").blue(),
max_element_size
);
if let Some(error) = error {
println!("{} {}", style(" error:").red().bold(), error);
}
exit(0);
}
fn main() -> Result<(), Box<dyn Error>> {
let degree = 4096usize;
let plaintext_modulus = 2056193;
let moduli_sizes = [36, 36, 37];
let degree = 4096usize;
let plaintext_modulus = 2056193;
let moduli_sizes = [36, 36, 37];
// Compute what is the maximum byte-length of an element to fit within one
// ciphertext. Each coefficient of the ciphertext polynomial can contain
// floor(log2(plaintext_modulus)) bits.
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
// Compute what is the maximum byte-length of an element to fit within one
// ciphertext. Each coefficient of the ciphertext polynomial can contain
// floor(log2(plaintext_modulus)) bits.
let max_element_size = (ilog2(plaintext_modulus) * degree) / 8;
// This executable is a command line tool which enables to specify different
// database and element sizes.
let args: Vec<String> = env::args().skip(1).collect();
// This executable is a command line tool which enables to specify different
// database and element sizes.
let args: Vec<String> = env::args().skip(1).collect();
// Print the help if requested.
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
print_notice_and_exit(max_element_size, None)
}
// Print the help if requested.
if args.contains(&"-h".to_string()) || args.contains(&"--help".to_string()) {
print_notice_and_exit(max_element_size, None)
}
// Use the default values from <https://github.com/microsoft/SealPIR>.
let mut database_size = 1 << 16;
let mut elements_size = 1024;
// Use the default values from <https://github.com/microsoft/SealPIR>.
let mut database_size = 1 << 16;
let mut elements_size = 1024;
// Update the database size and/or element size depending on the arguments
// provided.
for arg in &args {
if arg.starts_with("--database_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--database_size` command".to_string()),
)
} else {
database_size = a[0].parse::<usize>().unwrap()
}
} else if arg.starts_with("--element_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--element_size` command".to_string()),
)
} else {
elements_size = a[0].parse::<usize>().unwrap()
}
} else {
print_notice_and_exit(
max_element_size,
Some(format!("Unrecognized command: {arg}")),
)
}
}
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
print_notice_and_exit(
max_element_size,
Some("Element or database sizes out of bound".to_string()),
)
}
// Update the database size and/or element size depending on the arguments
// provided.
for arg in &args {
if arg.starts_with("--database_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--database_size` command".to_string()),
)
} else {
database_size = a[0].parse::<usize>().unwrap()
}
} else if arg.starts_with("--element_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
if a.len() != 2 || a[0].parse::<usize>().is_err() {
print_notice_and_exit(
max_element_size,
Some("Invalid `--element_size` command".to_string()),
)
} else {
elements_size = a[0].parse::<usize>().unwrap()
}
} else {
print_notice_and_exit(
max_element_size,
Some(format!("Unrecognized command: {arg}")),
)
}
}
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
print_notice_and_exit(
max_element_size,
Some("Element or database sizes out of bound".to_string()),
)
}
// The parameters are within bound, let's go! Let's first display some
// information about the database.
println!("# SealPIR with fhe.rs");
println!(
"database of {}",
HumanBytes((database_size * elements_size) as u64)
);
println!("\tdatabase_size = {database_size}");
println!("\telements_size = {elements_size}");
// The parameters are within bound, let's go! Let's first display some
// information about the database.
println!("# SealPIR with fhe.rs");
println!(
"database of {}",
HumanBytes((database_size * elements_size) as u64)
);
println!("\tdatabase_size = {database_size}");
println!("\telements_size = {elements_size}");
// Generation of a random database.
let database = timeit!("Database generation", {
generate_database(database_size, elements_size)
});
// Generation of a random database.
let database = timeit!("Database generation", {
generate_database(database_size, elements_size)
});
// Let's generate the BFV parameters structure.
let params = timeit!(
"Parameters generation",
Arc::new(
bfv::BfvParametersBuilder::new()
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build()
.unwrap()
)
);
// Let's generate the BFV parameters structure.
let params = timeit!(
"Parameters generation",
Arc::new(
bfv::BfvParametersBuilder::new()
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build()
.unwrap()
)
);
// Proprocess the database on the server side: the database will be reshaped
// so as to pack as many values as possible in every row so that it fits in one
// ciphertext, and each element will be encoded as a polynomial in Ntt
// representation.
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
encode_database(&database, params.clone(), 1)
});
// Proprocess the database on the server side: the database will be reshaped
// so as to pack as many values as possible in every row so that it fits in one
// ciphertext, and each element will be encoded as a polynomial in Ntt
// representation.
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
encode_database(&database, params.clone(), 1)
});
// Client setup: the client generates a secret key, and an evaluation key for
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)).
let (sk, ek_expansion_serialized) = timeit!("Client setup", {
let sk = bfv::SecretKey::random(&params, &mut OsRng);
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
println!("expansion_level = {level}");
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
.enable_expansion(level)?
.build(&mut thread_rng())?;
let ek_expansion_serialized = ek_expansion.to_bytes();
(sk, ek_expansion_serialized)
});
println!(
"📄 Evaluation key: {}",
HumanBytes(ek_expansion_serialized.len() as u64)
);
// Client setup: the client generates a secret key, and an evaluation key for
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)).
let (sk, ek_expansion_serialized) = timeit!("Client setup", {
let sk = bfv::SecretKey::random(&params, &mut OsRng);
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
println!("expansion_level = {level}");
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
.enable_expansion(level)?
.build(&mut thread_rng())?;
let ek_expansion_serialized = ek_expansion.to_bytes();
(sk, ek_expansion_serialized)
});
println!(
"📄 Evaluation key: {}",
HumanBytes(ek_expansion_serialized.len() as u64)
);
// Server setup: the server receives the evaluation key and deserializes it.
let ek_expansion = timeit!(
"Server setup",
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, &params)?
);
// Server setup: the server receives the evaluation key and deserializes it.
let ek_expansion = timeit!(
"Server setup",
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, &params)?
);
// Client query: when the client wants to retrieve the `index`-th row of the
// original database, it first computes to which row it corresponds in the
// original database, and then encrypt a selection vector with 0 everywhere,
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
// The ciphertext is set at level `1`, which means that one of the three moduli
// has been dropped already; the reason is that the expansion will happen at
// level 0 (with all three moduli) and then one of the moduli will be dropped
// to reduce the noise.
let index = (thread_rng().next_u64() as usize) % database_size;
let query = timeit!("Client query", {
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
let query_index = index
/ number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
let mut pt = vec![0u64; dim1 + dim2];
let inv = inverse(1 << level, plaintext_modulus).unwrap();
pt[query_index / dim2] = inv;
pt[dim1 + (query_index % dim2)] = inv;
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), &params)?;
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
query.to_bytes()
});
println!("📄 Query: {}", HumanBytes(query.len() as u64));
// Client query: when the client wants to retrieve the `index`-th row of the
// original database, it first computes to which row it corresponds in the
// original database, and then encrypt a selection vector with 0 everywhere,
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
// The ciphertext is set at level `1`, which means that one of the three moduli
// has been dropped already; the reason is that the expansion will happen at
// level 0 (with all three moduli) and then one of the moduli will be dropped
// to reduce the noise.
let index = (thread_rng().next_u64() as usize) % database_size;
let query = timeit!("Client query", {
let level = ilog2((dim1 + dim2).next_power_of_two() as u64);
let query_index = index
/ number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
let mut pt = vec![0u64; dim1 + dim2];
let inv = inverse(1 << level, plaintext_modulus).unwrap();
pt[query_index / dim2] = inv;
pt[dim1 + (query_index % dim2)] = inv;
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), &params)?;
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut thread_rng())?;
query.to_bytes()
});
println!("📄 Query: {}", HumanBytes(query.len() as u64));
// Server response: The server receives the query, and after deserializing it,
// performs the following steps:
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
// If the client created the query correctly, the server will have obtained
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
// `dim1 + j`th ones encrypting `1`.
// 2- It computes the inner product of the first `dim1` ciphertexts with the
// columns if the database viewed as a dim1 * dim2 matrix, and modulo-switch
// the ciphertext once.
// 3- It parses the resulting ciphertexts as vector of plaintexts, and compute
// the inner product of the last `dim2` ciphertexts from step 1 with the
// transposed of the plaintext obtained above.
// The operation is done `5` times to compute an average response time.
let responses: Vec<Vec<u8>> = timeit_n!("Server response", 5, {
let start = std::time::Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, &params);
let query = query.unwrap();
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {}", DisplayDuration(start.elapsed()));
// Server response: The server receives the query, and after deserializing it,
// performs the following steps:
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
// If the client created the query correctly, the server will have obtained
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
// `dim1 + j`th ones encrypting `1`.
// 2- It computes the inner product of the first `dim1` ciphertexts with the
// columns if the database viewed as a dim1 * dim2 matrix, and modulo-switch
// the ciphertext once.
// 3- It parses the resulting ciphertexts as vector of plaintexts, and compute
// the inner product of the last `dim2` ciphertexts from step 1 with the
// transposed of the plaintext obtained above.
// The operation is done `5` times to compute an average response time.
let responses: Vec<Vec<u8>> = timeit_n!("Server response", 5, {
let start = std::time::Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, &params);
let query = query.unwrap();
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {}", DisplayDuration(start.elapsed()));
let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| {
let column = database.iter().skip(i).step_by(dim2);
let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?;
c.mod_switch_to_last_level();
Ok(c)
};
let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| {
let column = database.iter().skip(i).step_by(dim2);
let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?;
c.mod_switch_to_last_level();
Ok(c)
};
let dot_products = (0..dim2)
.map(|i| dot_product_mod_switch(i, &preprocessed_database))
.collect::<fhe::Result<Vec<bfv::Ciphertext>>>()?;
let dot_products = (0..dim2)
.map(|i| dot_product_mod_switch(i, &preprocessed_database))
.collect::<fhe::Result<Vec<bfv::Ciphertext>>>()?;
let fold = dot_products
.iter()
.map(|c| {
let mut pt_values = Vec::with_capacity(div_ceil(
2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize)),
ilog2(plaintext_modulus),
));
pt_values.append(&mut transcode_bidirectional(
c.get(0).unwrap().coefficients().as_slice().unwrap(),
64 - params.moduli()[0].leading_zeros() as usize,
ilog2(plaintext_modulus),
));
pt_values.append(&mut transcode_bidirectional(
c.get(1).unwrap().coefficients().as_slice().unwrap(),
64 - params.moduli()[0].leading_zeros() as usize,
ilog2(plaintext_modulus),
));
unsafe {
Ok(bfv::PlaintextVec::try_encode_vt(
&pt_values,
bfv::Encoding::poly_at_level(1),
&params,
)?
.0)
}
})
.collect::<fhe::Result<Vec<Vec<bfv::Plaintext>>>>()?;
(0..fold[0].len())
.map(|i| {
let mut outi = bfv::dot_product_scalar(
expanded_query[dim1..].iter(),
fold.iter().map(|pts| pts.get(i).unwrap()),
)?;
outi.mod_switch_to_last_level();
Ok(outi.to_bytes())
})
.collect::<fhe::Result<Vec<Vec<u8>>>>()?
});
println!(
"📄 Response: {}",
HumanBytes(responses.iter().map(|r| r.len()).sum::<usize>() as u64)
);
let fold = dot_products
.iter()
.map(|c| {
let mut pt_values = Vec::with_capacity(div_ceil(
2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize)),
ilog2(plaintext_modulus),
));
pt_values.append(&mut transcode_bidirectional(
c.get(0).unwrap().coefficients().as_slice().unwrap(),
64 - params.moduli()[0].leading_zeros() as usize,
ilog2(plaintext_modulus),
));
pt_values.append(&mut transcode_bidirectional(
c.get(1).unwrap().coefficients().as_slice().unwrap(),
64 - params.moduli()[0].leading_zeros() as usize,
ilog2(plaintext_modulus),
));
unsafe {
Ok(bfv::PlaintextVec::try_encode_vt(
&pt_values,
bfv::Encoding::poly_at_level(1),
&params,
)?
.0)
}
})
.collect::<fhe::Result<Vec<Vec<bfv::Plaintext>>>>()?;
(0..fold[0].len())
.map(|i| {
let mut outi = bfv::dot_product_scalar(
expanded_query[dim1..].iter(),
fold.iter().map(|pts| pts.get(i).unwrap()),
)?;
outi.mod_switch_to_last_level();
Ok(outi.to_bytes())
})
.collect::<fhe::Result<Vec<Vec<u8>>>>()?
});
println!(
"📄 Response: {}",
HumanBytes(responses.iter().map(|r| r.len()).sum::<usize>() as u64)
);
// Client processing: Upon reception of the response, the client decrypts
// the ciphertexts and recover the "ciphertexts" which were parsed as plaintext,
// which it decrypts too. Finally, it outputs the plaintext bytes, offset by the
// correct value (remember the database was reshaped to maximize how many
// elements) were embedded in a single ciphertext.
let answer = timeit!("Client answer", {
let responses = responses
.iter()
.map(|r| bfv::Ciphertext::from_bytes(r, &params).unwrap())
.collect_vec();
let decrypted_pt = responses
.iter()
.map(|r| sk.try_decrypt(r).unwrap())
.collect_vec();
let decrypted_vec = decrypted_pt
.iter()
.flat_map(|pt| Vec::<u64>::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap())
.collect_vec();
let expect_ncoefficients = div_ceil(
params.degree() * (64 - params.moduli()[0].leading_zeros() as usize),
ilog2(plaintext_modulus),
);
assert!(decrypted_vec.len() >= 2 * expect_ncoefficients);
let mut poly0 = transcode_bidirectional(
&decrypted_vec[..expect_ncoefficients],
ilog2(plaintext_modulus),
64 - params.moduli()[0].leading_zeros() as usize,
);
let mut poly1 = transcode_bidirectional(
&decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients],
ilog2(plaintext_modulus),
64 - params.moduli()[0].leading_zeros() as usize,
);
assert!(poly0.len() >= params.degree());
assert!(poly1.len() >= params.degree());
poly0.truncate(params.degree());
poly1.truncate(params.degree());
// Client processing: Upon reception of the response, the client decrypts
// the ciphertexts and recover the "ciphertexts" which were parsed as plaintext,
// which it decrypts too. Finally, it outputs the plaintext bytes, offset by the
// correct value (remember the database was reshaped to maximize how many
// elements) were embedded in a single ciphertext.
let answer = timeit!("Client answer", {
let responses = responses
.iter()
.map(|r| bfv::Ciphertext::from_bytes(r, &params).unwrap())
.collect_vec();
let decrypted_pt = responses
.iter()
.map(|r| sk.try_decrypt(r).unwrap())
.collect_vec();
let decrypted_vec = decrypted_pt
.iter()
.flat_map(|pt| Vec::<u64>::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap())
.collect_vec();
let expect_ncoefficients = div_ceil(
params.degree() * (64 - params.moduli()[0].leading_zeros() as usize),
ilog2(plaintext_modulus),
);
assert!(decrypted_vec.len() >= 2 * expect_ncoefficients);
let mut poly0 = transcode_bidirectional(
&decrypted_vec[..expect_ncoefficients],
ilog2(plaintext_modulus),
64 - params.moduli()[0].leading_zeros() as usize,
);
let mut poly1 = transcode_bidirectional(
&decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients],
ilog2(plaintext_modulus),
64 - params.moduli()[0].leading_zeros() as usize,
);
assert!(poly0.len() >= params.degree());
assert!(poly1.len() >= params.degree());
poly0.truncate(params.degree());
poly1.truncate(params.degree());
let ctx = Arc::new(Context::new(&params.moduli()[..1], params.degree())?);
let ct = bfv::Ciphertext::new(
vec![
Poly::try_convert_from(poly0, &ctx, true, Representation::Ntt)?,
Poly::try_convert_from(poly1, &ctx, true, Representation::Ntt)?,
],
&params,
)?;
let ctx = Arc::new(Context::new(&params.moduli()[..1], params.degree())?);
let ct = bfv::Ciphertext::new(
vec![
Poly::try_convert_from(poly0, &ctx, true, Representation::Ntt)?,
Poly::try_convert_from(poly1, &ctx, true, Representation::Ntt)?,
],
&params,
)?;
let pt = sk.try_decrypt(&ct).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
let offset = index
% number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
let pt = sk.try_decrypt(&ct).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
let offset = index
% number_elements_per_plaintext(
params.degree(),
ilog2(plaintext_modulus),
elements_size,
);
println!("Noise in response (ct): {:?}", unsafe {
sk.measure_noise(&ct)
});
println!("Noise in response (ct): {:?}", unsafe {
sk.measure_noise(&ct)
});
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
});
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
});
// Assert that the answer is indeed the `index`-th element of the initial
// database.
assert_eq!(&database[index], &answer);
// Assert that the answer is indeed the `index`-th element of the initial
// database.
assert_eq!(&database[index], &answer);
Ok(())
Ok(())
}

View File

@@ -7,37 +7,37 @@ use std::{cmp::min, fmt, sync::Arc, time::Duration};
/// Macros to time code and display a human-readable duration.
pub mod timeit {
#[allow(unused_macros)]
macro_rules! timeit_n {
($name:expr, $loops:expr, $code:expr) => {{
use util::DisplayDuration;
let start = std::time::Instant::now();
#[allow(unused_macros)]
macro_rules! timeit_n {
($name:expr, $loops:expr, $code:expr) => {{
use util::DisplayDuration;
let start = std::time::Instant::now();
#[allow(clippy::reversed_empty_ranges)]
for _ in 1..$loops {
let _ = $code;
}
let r = $code;
println!(
"{}: {}",
$name,
DisplayDuration(start.elapsed() / $loops)
);
r
}};
}
#[allow(clippy::reversed_empty_ranges)]
for _ in 1..$loops {
let _ = $code;
}
let r = $code;
println!(
"{}: {}",
$name,
DisplayDuration(start.elapsed() / $loops)
);
r
}};
}
#[allow(unused_macros)]
macro_rules! timeit {
($name:expr, $code:expr) => {{
timeit_n!($name, 1, $code)
}};
}
#[allow(unused_macros)]
macro_rules! timeit {
($name:expr, $code:expr) => {{
timeit_n!($name, 1, $code)
}};
}
#[allow(unused_imports)]
pub(crate) use timeit;
#[allow(unused_imports)]
pub(crate) use timeit_n;
#[allow(unused_imports)]
pub(crate) use timeit;
#[allow(unused_imports)]
pub(crate) use timeit_n;
}
/// Utility struct for displaying human-readable duration of the form "10.5 ms",
@@ -45,17 +45,17 @@ pub mod timeit {
pub struct DisplayDuration(pub Duration);
impl fmt::Display for DisplayDuration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let duration_ns = self.0.as_nanos();
if duration_ns < 1_000_u128 {
write!(f, "{duration_ns} ns")
} else if duration_ns < 1_000_000_u128 {
write!(f, "{} μs", (duration_ns + 500) / 1_000)
} else {
let duration_ms_times_10 = (duration_ns + 50_000) / (100_000);
write!(f, "{} ms", (duration_ms_times_10 as f64) / 10.0)
}
}
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let duration_ns = self.0.as_nanos();
if duration_ns < 1_000_u128 {
write!(f, "{duration_ns} ns")
} else if duration_ns < 1_000_000_u128 {
write!(f, "{} μs", (duration_ns + 500) / 1_000)
} else {
let duration_ms_times_10 = (duration_ns + 50_000) / (100_000);
write!(f, "{} ms", (duration_ms_times_10 as f64) / 10.0)
}
}
}
// Utility functions for Private Information Retrieval.
@@ -64,60 +64,60 @@ impl fmt::Display for DisplayDuration {
/// little endian encoding of the index. When the element size is less than 4B,
/// the encoding is truncated.
pub fn generate_database(database_size: usize, elements_size: usize) -> Vec<Vec<u8>> {
assert!(elements_size > 0 && database_size > 0);
let mut database = vec![vec![0u8; elements_size]; database_size];
for (i, element) in database.iter_mut().enumerate() {
element[..min(4, elements_size)]
.copy_from_slice(&(i as u32).to_le_bytes()[..min(4, elements_size)]);
}
database
assert!(elements_size > 0 && database_size > 0);
let mut database = vec![vec![0u8; elements_size]; database_size];
for (i, element) in database.iter_mut().enumerate() {
element[..min(4, elements_size)]
.copy_from_slice(&(i as u32).to_le_bytes()[..min(4, elements_size)]);
}
database
}
pub fn number_elements_per_plaintext(
degree: usize,
plaintext_nbits: usize,
elements_size: usize,
degree: usize,
plaintext_nbits: usize,
elements_size: usize,
) -> usize {
(plaintext_nbits * degree) / (elements_size * 8)
(plaintext_nbits * degree) / (elements_size * 8)
}
pub fn encode_database(
database: &Vec<Vec<u8>>,
par: Arc<bfv::BfvParameters>,
level: usize,
database: &Vec<Vec<u8>>,
par: Arc<bfv::BfvParameters>,
level: usize,
) -> (Vec<bfv::Plaintext>, (usize, usize)) {
assert!(!database.is_empty());
assert!(!database.is_empty());
let elements_size = database[0].len();
let plaintext_nbits = ilog2(par.plaintext());
let number_elements_per_plaintext =
number_elements_per_plaintext(par.degree(), plaintext_nbits, elements_size);
let number_rows =
(database.len() + number_elements_per_plaintext - 1) / number_elements_per_plaintext;
println!("number_rows = {number_rows}");
println!("number_elements_per_plaintext = {number_elements_per_plaintext}");
let dimension_1 = (number_rows as f64).sqrt().ceil() as usize;
let dimension_2 = (number_rows + dimension_1 - 1) / dimension_1;
println!("dimensions = {dimension_1} {dimension_2}");
println!("dimension = {}", dimension_1 * dimension_2);
let mut preprocessed_database =
vec![
bfv::Plaintext::zero(bfv::Encoding::poly_at_level(level), &par).unwrap();
dimension_1 * dimension_2
];
(0..number_rows).for_each(|i| {
let mut serialized_plaintext = vec![0u8; number_elements_per_plaintext * elements_size];
for j in 0..number_elements_per_plaintext {
if let Some(pt) = database.get(j + i * number_elements_per_plaintext) {
serialized_plaintext[j * elements_size..(j + 1) * elements_size].copy_from_slice(pt)
}
}
let pt_values = transcode_from_bytes(&serialized_plaintext, plaintext_nbits);
preprocessed_database[i] =
bfv::Plaintext::try_encode(&pt_values, bfv::Encoding::poly_at_level(level), &par)
.unwrap();
});
(preprocessed_database, (dimension_1, dimension_2))
let elements_size = database[0].len();
let plaintext_nbits = ilog2(par.plaintext());
let number_elements_per_plaintext =
number_elements_per_plaintext(par.degree(), plaintext_nbits, elements_size);
let number_rows =
(database.len() + number_elements_per_plaintext - 1) / number_elements_per_plaintext;
println!("number_rows = {number_rows}");
println!("number_elements_per_plaintext = {number_elements_per_plaintext}");
let dimension_1 = (number_rows as f64).sqrt().ceil() as usize;
let dimension_2 = (number_rows + dimension_1 - 1) / dimension_1;
println!("dimensions = {dimension_1} {dimension_2}");
println!("dimension = {}", dimension_1 * dimension_2);
let mut preprocessed_database =
vec![
bfv::Plaintext::zero(bfv::Encoding::poly_at_level(level), &par).unwrap();
dimension_1 * dimension_2
];
(0..number_rows).for_each(|i| {
let mut serialized_plaintext = vec![0u8; number_elements_per_plaintext * elements_size];
for j in 0..number_elements_per_plaintext {
if let Some(pt) = database.get(j + i * number_elements_per_plaintext) {
serialized_plaintext[j * elements_size..(j + 1) * elements_size].copy_from_slice(pt)
}
}
let pt_values = transcode_from_bytes(&serialized_plaintext, plaintext_nbits);
preprocessed_database[i] =
bfv::Plaintext::try_encode(&pt_values, bfv::Encoding::poly_at_level(level), &par)
.unwrap();
});
(preprocessed_database, (dimension_1, dimension_2))
}
#[allow(dead_code)]

View File

@@ -1,12 +1,12 @@
//! Ciphertext type in the BFV encryption scheme.
use crate::bfv::{
parameters::BfvParameters, proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom,
parameters::BfvParameters, proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom,
};
use crate::{Error, Result};
use fhe_math::rq::{Poly, Representation};
use fhe_traits::{
DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize,
DeserializeParametrized, DeserializeWithContext, FheCiphertext, FheParametrized, Serialize,
};
use protobuf::Message;
use rand::SeedableRng;
@@ -16,287 +16,287 @@ use std::sync::Arc;
/// A ciphertext encrypting a plaintext.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ciphertext {
/// The parameters of the underlying BFV encryption scheme.
pub(crate) par: Arc<BfvParameters>,
/// The parameters of the underlying BFV encryption scheme.
pub(crate) par: Arc<BfvParameters>,
/// The seed that generated the polynomial c1 in a fresh ciphertext.
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
/// The seed that generated the polynomial c1 in a fresh ciphertext.
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
/// The ciphertext elements.
pub(crate) c: Vec<Poly>,
/// The ciphertext elements.
pub(crate) c: Vec<Poly>,
/// The ciphertext level
pub(crate) level: usize,
/// The ciphertext level
pub(crate) level: usize,
}
impl Ciphertext {
/// Modulo switch the ciphertext to the last level.
pub fn mod_switch_to_last_level(&mut self) {
self.level = self.par.max_level();
let last_ctx = self.par.ctx_at_level(self.level).unwrap();
self.seed = None;
self.c.iter_mut().for_each(|ci| {
if ci.ctx() != last_ctx {
ci.change_representation(Representation::PowerBasis);
assert!(ci.mod_switch_down_to(last_ctx).is_ok());
ci.change_representation(Representation::Ntt);
}
});
}
/// Modulo switch the ciphertext to the last level.
pub fn mod_switch_to_last_level(&mut self) {
self.level = self.par.max_level();
let last_ctx = self.par.ctx_at_level(self.level).unwrap();
self.seed = None;
self.c.iter_mut().for_each(|ci| {
if ci.ctx() != last_ctx {
ci.change_representation(Representation::PowerBasis);
assert!(ci.mod_switch_down_to(last_ctx).is_ok());
ci.change_representation(Representation::Ntt);
}
});
}
/// Modulo switch the ciphertext to the next level.
pub fn mod_switch_to_next_level(&mut self) {
if self.level < self.par.max_level() {
self.seed = None;
self.c.iter_mut().for_each(|ci| {
ci.change_representation(Representation::PowerBasis);
assert!(ci.mod_switch_down_next().is_ok());
ci.change_representation(Representation::Ntt);
});
self.level += 1
}
}
/// Modulo switch the ciphertext to the next level.
pub fn mod_switch_to_next_level(&mut self) {
if self.level < self.par.max_level() {
self.seed = None;
self.c.iter_mut().for_each(|ci| {
ci.change_representation(Representation::PowerBasis);
assert!(ci.mod_switch_down_next().is_ok());
ci.change_representation(Representation::Ntt);
});
self.level += 1
}
}
/// Create a ciphertext from a vector of polynomials.
/// A ciphertext must contain at least two polynomials, and all polynomials
/// must be in Ntt representation and with the same context.
pub fn new(c: Vec<Poly>, par: &Arc<BfvParameters>) -> Result<Self> {
if c.len() < 2 {
return Err(Error::TooFewValues(c.len(), 2));
}
/// Create a ciphertext from a vector of polynomials.
/// A ciphertext must contain at least two polynomials, and all polynomials
/// must be in Ntt representation and with the same context.
pub fn new(c: Vec<Poly>, par: &Arc<BfvParameters>) -> Result<Self> {
if c.len() < 2 {
return Err(Error::TooFewValues(c.len(), 2));
}
let ctx = c[0].ctx();
let level = par.level_of_ctx(ctx)?;
let ctx = c[0].ctx();
let level = par.level_of_ctx(ctx)?;
// Check that all polynomials have the expected representation and context.
for ci in c.iter() {
if ci.representation() != &Representation::Ntt {
return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation(
ci.representation().clone(),
Representation::Ntt,
)));
}
if ci.ctx() != ctx {
return Err(Error::MathError(fhe_math::Error::InvalidContext));
}
}
// Check that all polynomials have the expected representation and context.
for ci in c.iter() {
if ci.representation() != &Representation::Ntt {
return Err(Error::MathError(fhe_math::Error::IncorrectRepresentation(
ci.representation().clone(),
Representation::Ntt,
)));
}
if ci.ctx() != ctx {
return Err(Error::MathError(fhe_math::Error::InvalidContext));
}
}
Ok(Self {
par: par.clone(),
seed: None,
c,
level,
})
}
Ok(Self {
par: par.clone(),
seed: None,
c,
level,
})
}
/// Get the i-th polynomial of the ciphertext.
pub fn get(&self, i: usize) -> Option<&Poly> {
self.c.get(i)
}
/// Get the i-th polynomial of the ciphertext.
pub fn get(&self, i: usize) -> Option<&Poly> {
self.c.get(i)
}
}
impl FheCiphertext for Ciphertext {}
impl FheParametrized for Ciphertext {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl Serialize for Ciphertext {
fn to_bytes(&self) -> Vec<u8> {
CiphertextProto::from(self).write_to_bytes().unwrap()
}
fn to_bytes(&self) -> Vec<u8> {
CiphertextProto::from(self).write_to_bytes().unwrap()
}
}
impl DeserializeParametrized for Ciphertext {
fn from_bytes(bytes: &[u8], par: &Arc<BfvParameters>) -> Result<Self> {
if let Ok(ctp) = CiphertextProto::parse_from_bytes(bytes) {
Ciphertext::try_convert_from(&ctp, par)
} else {
Err(Error::SerializationError)
}
}
fn from_bytes(bytes: &[u8], par: &Arc<BfvParameters>) -> Result<Self> {
if let Ok(ctp) = CiphertextProto::parse_from_bytes(bytes) {
Ciphertext::try_convert_from(&ctp, par)
} else {
Err(Error::SerializationError)
}
}
type Error = Error;
type Error = Error;
}
impl Ciphertext {
/// Generate the zero ciphertext.
pub fn zero(par: &Arc<BfvParameters>) -> Self {
Self {
par: par.clone(),
seed: None,
c: Default::default(),
level: 0,
}
}
/// Generate the zero ciphertext.
pub fn zero(par: &Arc<BfvParameters>) -> Self {
Self {
par: par.clone(),
seed: None,
c: Default::default(),
level: 0,
}
}
}
/// Conversions from and to protobuf.
impl From<&Ciphertext> for CiphertextProto {
fn from(ct: &Ciphertext) -> Self {
let mut proto = CiphertextProto::new();
for i in 0..ct.c.len() - 1 {
proto.c.push(ct.c[i].to_bytes())
}
if let Some(seed) = ct.seed {
proto.seed = seed.to_vec()
} else {
proto.c.push(ct.c[ct.c.len() - 1].to_bytes())
}
proto.level = ct.level as u32;
proto
}
fn from(ct: &Ciphertext) -> Self {
let mut proto = CiphertextProto::new();
for i in 0..ct.c.len() - 1 {
proto.c.push(ct.c[i].to_bytes())
}
if let Some(seed) = ct.seed {
proto.seed = seed.to_vec()
} else {
proto.c.push(ct.c[ct.c.len() - 1].to_bytes())
}
proto.level = ct.level as u32;
proto
}
}
impl TryConvertFrom<&CiphertextProto> for Ciphertext {
fn try_convert_from(value: &CiphertextProto, par: &Arc<BfvParameters>) -> Result<Self> {
if value.c.is_empty() || (value.c.len() == 1 && value.seed.is_empty()) {
return Err(Error::DefaultError("Not enough polynomials".to_string()));
}
fn try_convert_from(value: &CiphertextProto, par: &Arc<BfvParameters>) -> Result<Self> {
if value.c.is_empty() || (value.c.len() == 1 && value.seed.is_empty()) {
return Err(Error::DefaultError("Not enough polynomials".to_string()));
}
if value.level as usize > par.max_level() {
return Err(Error::DefaultError("Invalid level".to_string()));
}
if value.level as usize > par.max_level() {
return Err(Error::DefaultError("Invalid level".to_string()));
}
let ctx = par.ctx_at_level(value.level as usize)?;
let ctx = par.ctx_at_level(value.level as usize)?;
let mut seed = None;
let mut seed = None;
let mut c = Vec::with_capacity(value.c.len() + 1);
for cip in &value.c {
c.push(Poly::from_bytes(cip, ctx)?)
}
let mut c = Vec::with_capacity(value.c.len() + 1);
for cip in &value.c {
c.push(Poly::from_bytes(cip, ctx)?)
}
if !value.seed.is_empty() {
let try_seed = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
if try_seed.is_err() {
return Err(Error::MathError(fhe_math::Error::InvalidSeedSize(
value.seed.len(),
<ChaCha8Rng as SeedableRng>::Seed::default().len(),
)));
}
seed = try_seed.ok();
let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, seed.unwrap());
unsafe { c1.allow_variable_time_computations() }
c.push(c1)
}
if !value.seed.is_empty() {
let try_seed = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
if try_seed.is_err() {
return Err(Error::MathError(fhe_math::Error::InvalidSeedSize(
value.seed.len(),
<ChaCha8Rng as SeedableRng>::Seed::default().len(),
)));
}
seed = try_seed.ok();
let mut c1 = Poly::random_from_seed(ctx, Representation::Ntt, seed.unwrap());
unsafe { c1.allow_variable_time_computations() }
c.push(c1)
}
Ok(Ciphertext {
par: par.clone(),
seed,
c,
level: value.level as usize,
})
}
Ok(Ciphertext {
par: par.clone(),
seed,
c,
level: value.level as usize,
})
}
}
#[cfg(test)]
mod tests {
use crate::bfv::{
proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, BfvParameters,
Ciphertext, Encoding, Plaintext, SecretKey,
};
use fhe_traits::FheDecrypter;
use fhe_traits::{DeserializeParametrized, FheEncoder, FheEncrypter, Serialize};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use crate::bfv::{
proto::bfv::Ciphertext as CiphertextProto, traits::TryConvertFrom, BfvParameters,
Ciphertext, Encoding, Plaintext, SecretKey,
};
use fhe_traits::FheDecrypter;
use fhe_traits::{DeserializeParametrized, FheEncoder, FheEncrypter, Serialize};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct = sk.try_encrypt(&pt, &mut rng)?;
let ct_proto = CiphertextProto::from(&ct);
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, &params)?);
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct = sk.try_encrypt(&pt, &mut rng)?;
let ct_proto = CiphertextProto::from(&ct);
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, &params)?);
let ct = &ct * &ct;
let ct_proto = CiphertextProto::from(&ct);
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, &params)?)
}
Ok(())
}
let ct = &ct * &ct;
let ct_proto = CiphertextProto::from(&ct);
assert_eq!(ct, Ciphertext::try_convert_from(&ct_proto, &params)?)
}
Ok(())
}
#[test]
fn serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
let ct_bytes = ct.to_bytes();
assert_eq!(ct, Ciphertext::from_bytes(&ct_bytes, &params)?);
}
Ok(())
}
#[test]
fn serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
let ct_bytes = ct.to_bytes();
assert_eq!(ct, Ciphertext::from_bytes(&ct_bytes, &params)?);
}
Ok(())
}
#[test]
fn new() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
let mut ct3 = &ct * &ct;
#[test]
fn new() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
let mut ct3 = &ct * &ct;
let c0 = ct3.get(0).unwrap();
let c1 = ct3.get(1).unwrap();
let c2 = ct3.get(2).unwrap();
let c0 = ct3.get(0).unwrap();
let c1 = ct3.get(1).unwrap();
let c2 = ct3.get(2).unwrap();
assert_eq!(
ct3,
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?
);
assert_eq!(ct3.level, 0);
assert_eq!(
ct3,
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?
);
assert_eq!(ct3.level, 0);
ct3.mod_switch_to_last_level();
ct3.mod_switch_to_last_level();
let c0 = ct3.get(0).unwrap();
let c1 = ct3.get(1).unwrap();
let c2 = ct3.get(2).unwrap();
assert_eq!(
ct3,
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?
);
assert_eq!(ct3.level, params.max_level());
}
let c0 = ct3.get(0).unwrap();
let c1 = ct3.get(1).unwrap();
let c2 = ct3.get(2).unwrap();
assert_eq!(
ct3,
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?
);
assert_eq!(ct3.level, params.max_level());
}
Ok(())
}
Ok(())
}
#[test]
fn mod_switch_to_last_level() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
#[test]
fn mod_switch_to_last_level() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
assert_eq!(ct.level, 0);
ct.mod_switch_to_last_level();
assert_eq!(ct.level, params.max_level());
assert_eq!(ct.level, 0);
ct.mod_switch_to_last_level();
assert_eq!(ct.level, params.max_level());
let decrypted = sk.try_decrypt(&ct)?;
assert_eq!(decrypted.value, pt.value);
}
let decrypted = sk.try_decrypt(&ct)?;
assert_eq!(decrypted.value, pt.value);
}
Ok(())
}
Ok(())
}
}

View File

@@ -6,71 +6,71 @@ use fhe_traits::FhePlaintextEncoding;
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum EncodingEnum {
Poly,
Simd,
Poly,
Simd,
}
impl Display for EncodingEnum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
/// An encoding for the plaintext.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Encoding {
pub(crate) encoding: EncodingEnum,
pub(crate) level: usize,
pub(crate) encoding: EncodingEnum,
pub(crate) level: usize,
}
impl Encoding {
/// A Poly encoding encodes a vector as coefficients of a polynomial;
/// homomorphic operations are therefore polynomial operations.
pub fn poly() -> Self {
Self {
encoding: EncodingEnum::Poly,
level: 0,
}
}
/// A Poly encoding encodes a vector as coefficients of a polynomial;
/// homomorphic operations are therefore polynomial operations.
pub fn poly() -> Self {
Self {
encoding: EncodingEnum::Poly,
level: 0,
}
}
/// A Simd encoding encodes a vector so that homomorphic operations are
/// component-wise operations on the coefficients of the underlying vectors.
/// The Simd encoding require that the plaintext modulus is congruent to 1
/// modulo the degree of the underlying polynomial.
pub fn simd() -> Self {
Self {
encoding: EncodingEnum::Simd,
level: 0,
}
}
/// A Simd encoding encodes a vector so that homomorphic operations are
/// component-wise operations on the coefficients of the underlying vectors.
/// The Simd encoding require that the plaintext modulus is congruent to 1
/// modulo the degree of the underlying polynomial.
pub fn simd() -> Self {
Self {
encoding: EncodingEnum::Simd,
level: 0,
}
}
/// A poly encoding at a given level.
pub fn poly_at_level(level: usize) -> Self {
Self {
encoding: EncodingEnum::Poly,
level,
}
}
/// A poly encoding at a given level.
pub fn poly_at_level(level: usize) -> Self {
Self {
encoding: EncodingEnum::Poly,
level,
}
}
/// A simd encoding at a given level.
pub fn simd_at_level(level: usize) -> Self {
Self {
encoding: EncodingEnum::Simd,
level,
}
}
/// A simd encoding at a given level.
pub fn simd_at_level(level: usize) -> Self {
Self {
encoding: EncodingEnum::Simd,
level,
}
}
}
impl From<Encoding> for String {
fn from(e: Encoding) -> Self {
String::from(&e)
}
fn from(e: Encoding) -> Self {
String::from(&e)
}
}
impl From<&Encoding> for String {
fn from(e: &Encoding) -> Self {
format!("{e:?}")
}
fn from(e: &Encoding) -> Self {
format!("{e:?}")
}
}
impl FhePlaintextEncoding for Encoding {}

File diff suppressed because it is too large Load Diff

View File

@@ -2,14 +2,14 @@
use super::key_switching_key::KeySwitchingKey;
use crate::bfv::{
proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto},
traits::TryConvertFrom,
BfvParameters, Ciphertext, SecretKey,
proto::bfv::{GaloisKey as GaloisKeyProto, KeySwitchingKey as KeySwitchingKeyProto},
traits::TryConvertFrom,
BfvParameters, Ciphertext, SecretKey,
};
use crate::{Error, Result};
use fhe_math::rq::{
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
SubstitutionExponent,
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
SubstitutionExponent,
};
use protobuf::MessageField;
use rand::{CryptoRng, RngCore};
@@ -21,181 +21,181 @@ use zeroize::Zeroizing;
/// which switch from `s(x^i)` to `s(x)` where `s(x)` is the secret key.
#[derive(Debug, PartialEq, Eq)]
pub struct GaloisKey {
pub(crate) element: SubstitutionExponent,
pub(crate) ksk: KeySwitchingKey,
pub(crate) element: SubstitutionExponent,
pub(crate) ksk: KeySwitchingKey,
}
impl GaloisKey {
/// Generate a [`GaloisKey`] from a [`SecretKey`].
pub fn new<R: RngCore + CryptoRng>(
sk: &SecretKey,
exponent: usize,
ciphertext_level: usize,
galois_key_level: usize,
rng: &mut R,
) -> Result<Self> {
let ctx_galois_key = sk.par.ctx_at_level(galois_key_level)?;
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
/// Generate a [`GaloisKey`] from a [`SecretKey`].
pub fn new<R: RngCore + CryptoRng>(
sk: &SecretKey,
exponent: usize,
ciphertext_level: usize,
galois_key_level: usize,
rng: &mut R,
) -> Result<Self> {
let ctx_galois_key = sk.par.ctx_at_level(galois_key_level)?;
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
let ciphertext_exponent =
SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?;
let ciphertext_exponent =
SubstitutionExponent::new(ctx_ciphertext, exponent).map_err(Error::MathError)?;
let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?;
let s = Zeroizing::new(Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx_ciphertext,
false,
Representation::PowerBasis,
)?);
let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?);
let mut s_sub_switched_up = Zeroizing::new(s_sub.mod_switch_to(&switcher_up)?);
s_sub_switched_up.change_representation(Representation::PowerBasis);
let switcher_up = Switcher::new(ctx_ciphertext, ctx_galois_key)?;
let s = Zeroizing::new(Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx_ciphertext,
false,
Representation::PowerBasis,
)?);
let s_sub = Zeroizing::new(s.substitute(&ciphertext_exponent)?);
let mut s_sub_switched_up = Zeroizing::new(s_sub.mod_switch_to(&switcher_up)?);
s_sub_switched_up.change_representation(Representation::PowerBasis);
let ksk = KeySwitchingKey::new(
sk,
&s_sub_switched_up,
ciphertext_level,
galois_key_level,
rng,
)?;
let ksk = KeySwitchingKey::new(
sk,
&s_sub_switched_up,
ciphertext_level,
galois_key_level,
rng,
)?;
Ok(Self {
element: ciphertext_exponent,
ksk,
})
}
Ok(Self {
element: ciphertext_exponent,
ksk,
})
}
/// Relinearize a [`Ciphertext`] using the [`GaloisKey`]
pub fn relinearize(&self, ct: &Ciphertext) -> Result<Ciphertext> {
// assert_eq!(ct.par, self.ksk.par);
assert_eq!(ct.c.len(), 2);
/// Relinearize a [`Ciphertext`] using the [`GaloisKey`]
pub fn relinearize(&self, ct: &Ciphertext) -> Result<Ciphertext> {
// assert_eq!(ct.par, self.ksk.par);
assert_eq!(ct.c.len(), 2);
let mut c2 = ct.c[1].substitute(&self.element)?;
c2.change_representation(Representation::PowerBasis);
let (mut c0, mut c1) = self.ksk.key_switch(&c2)?;
let mut c2 = ct.c[1].substitute(&self.element)?;
c2.change_representation(Representation::PowerBasis);
let (mut c0, mut c1) = self.ksk.key_switch(&c2)?;
if c0.ctx() != ct.c[0].ctx() {
c0.change_representation(Representation::PowerBasis);
c1.change_representation(Representation::PowerBasis);
c0.mod_switch_down_to(ct.c[0].ctx())?;
c1.mod_switch_down_to(ct.c[1].ctx())?;
c0.change_representation(Representation::Ntt);
c1.change_representation(Representation::Ntt);
}
if c0.ctx() != ct.c[0].ctx() {
c0.change_representation(Representation::PowerBasis);
c1.change_representation(Representation::PowerBasis);
c0.mod_switch_down_to(ct.c[0].ctx())?;
c1.mod_switch_down_to(ct.c[1].ctx())?;
c0.change_representation(Representation::Ntt);
c1.change_representation(Representation::Ntt);
}
c0 += &ct.c[0].substitute(&self.element)?;
c0 += &ct.c[0].substitute(&self.element)?;
Ok(Ciphertext {
par: ct.par.clone(),
seed: None,
c: vec![c0, c1],
level: self.ksk.ciphertext_level,
})
}
Ok(Ciphertext {
par: ct.par.clone(),
seed: None,
c: vec![c0, c1],
level: self.ksk.ciphertext_level,
})
}
}
impl From<&GaloisKey> for GaloisKeyProto {
fn from(value: &GaloisKey) -> Self {
let mut gk = GaloisKeyProto::new();
gk.exponent = value.element.exponent as u32;
gk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
gk
}
fn from(value: &GaloisKey) -> Self {
let mut gk = GaloisKeyProto::new();
gk.exponent = value.element.exponent as u32;
gk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
gk
}
}
impl TryConvertFrom<&GaloisKeyProto> for GaloisKey {
fn try_convert_from(value: &GaloisKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
if par.moduli.len() == 1 {
Err(Error::DefaultError(
"Invalid parameters for a relinearization key".to_string(),
))
} else if value.ksk.is_some() {
let ksk = KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?;
fn try_convert_from(value: &GaloisKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
if par.moduli.len() == 1 {
Err(Error::DefaultError(
"Invalid parameters for a relinearization key".to_string(),
))
} else if value.ksk.is_some() {
let ksk = KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?;
let ctx = par.ctx_at_level(ksk.ciphertext_level)?;
let element = SubstitutionExponent::new(ctx, value.exponent as usize)
.map_err(Error::MathError)?;
let ctx = par.ctx_at_level(ksk.ciphertext_level)?;
let element = SubstitutionExponent::new(ctx, value.exponent as usize)
.map_err(Error::MathError)?;
Ok(GaloisKey { element, ksk })
} else {
Err(Error::DefaultError("Invalid serialization".to_string()))
}
}
Ok(GaloisKey { element, ksk })
} else {
Err(Error::DefaultError("Invalid serialization".to_string()))
}
}
}
#[cfg(test)]
mod tests {
use super::GaloisKey;
use crate::bfv::{
proto::bfv::GaloisKey as GaloisKeyProto, traits::TryConvertFrom, BfvParameters, Encoding,
Plaintext, SecretKey,
};
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use super::GaloisKey;
use crate::bfv::{
proto::bfv::GaloisKey as GaloisKeyProto, traits::TryConvertFrom, BfvParameters, Encoding,
Plaintext, SecretKey,
};
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn relinearization() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
for _ in 0..30 {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let row_size = params.degree() >> 1;
#[test]
fn relinearization() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
for _ in 0..30 {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let row_size = params.degree() >> 1;
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct = sk.try_encrypt(&pt, &mut rng)?;
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct = sk.try_encrypt(&pt, &mut rng)?;
for i in 1..2 * params.degree() {
if i & 1 == 0 {
assert!(GaloisKey::new(&sk, i, 0, 0, &mut rng).is_err())
} else {
let gk = GaloisKey::new(&sk, i, 0, 0, &mut rng)?;
let ct2 = gk.relinearize(&ct)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? });
for i in 1..2 * params.degree() {
if i & 1 == 0 {
assert!(GaloisKey::new(&sk, i, 0, 0, &mut rng).is_err())
} else {
let gk = GaloisKey::new(&sk, i, 0, 0, &mut rng)?;
let ct2 = gk.relinearize(&ct)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct2)? });
if i == 3 {
let pt = sk.try_decrypt(&ct2)?;
if i == 3 {
let pt = sk.try_decrypt(&ct2)?;
// The expected result is rotated one on the left
let mut expected = vec![0u64; params.degree()];
expected[..row_size - 1].copy_from_slice(&v[1..row_size]);
expected[row_size - 1] = v[0];
expected[row_size..2 * row_size - 1]
.copy_from_slice(&v[row_size + 1..]);
expected[2 * row_size - 1] = v[row_size];
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
} else if i == params.degree() * 2 - 1 {
let pt = sk.try_decrypt(&ct2)?;
// The expected result is rotated one on the left
let mut expected = vec![0u64; params.degree()];
expected[..row_size - 1].copy_from_slice(&v[1..row_size]);
expected[row_size - 1] = v[0];
expected[row_size..2 * row_size - 1]
.copy_from_slice(&v[row_size + 1..]);
expected[2 * row_size - 1] = v[row_size];
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
} else if i == params.degree() * 2 - 1 {
let pt = sk.try_decrypt(&ct2)?;
// The expected result has its rows swapped
let mut expected = vec![0u64; params.degree()];
expected[..row_size].copy_from_slice(&v[row_size..]);
expected[row_size..].copy_from_slice(&v[..row_size]);
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
}
}
}
}
}
Ok(())
}
// The expected result has its rows swapped
let mut expected = vec![0u64; params.degree()];
expected[..row_size].copy_from_slice(&v[row_size..]);
expected[row_size..].copy_from_slice(&v[..row_size]);
assert_eq!(&Vec::<u64>::try_decode(&pt, Encoding::simd())?, &expected)
}
}
}
}
}
Ok(())
}
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(4, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?;
let proto = GaloisKeyProto::from(&gk);
assert_eq!(gk, GaloisKey::try_convert_from(&proto, &params)?);
}
Ok(())
}
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(4, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let gk = GaloisKey::new(&sk, 9, 0, 0, &mut rng)?;
let proto = GaloisKeyProto::from(&gk);
assert_eq!(gk, GaloisKey::try_convert_from(&proto, &params)?);
}
Ok(())
}
}

View File

@@ -1,15 +1,15 @@
//! Key-switching keys for the BFV encryption scheme
use crate::bfv::{
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto,
traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey,
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto,
traits::TryConvertFrom as BfvTryConvertFrom, BfvParameters, SecretKey,
};
use crate::{Error, Result};
use fhe_math::rq::traits::TryConvertFrom;
use fhe_math::rq::Context;
use fhe_math::{
rns::RnsContext,
rq::{Poly, Representation},
rns::RnsContext,
rq::{Poly, Representation},
};
use fhe_traits::{DeserializeWithContext, Serialize};
use itertools::izip;
@@ -21,335 +21,335 @@ use zeroize::Zeroizing;
/// Key switching key for the BFV encryption scheme.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct KeySwitchingKey {
/// The parameters of the underlying BFV encryption scheme.
pub(crate) par: Arc<BfvParameters>,
/// The parameters of the underlying BFV encryption scheme.
pub(crate) par: Arc<BfvParameters>,
/// The (optional) seed that generated the polynomials c1.
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
/// The (optional) seed that generated the polynomials c1.
pub(crate) seed: Option<<ChaCha8Rng as SeedableRng>::Seed>,
/// The key switching elements c0.
pub(crate) c0: Box<[Poly]>,
/// The key switching elements c0.
pub(crate) c0: Box<[Poly]>,
/// The key switching elements c1.
pub(crate) c1: Box<[Poly]>,
/// The key switching elements c1.
pub(crate) c1: Box<[Poly]>,
/// The level and context of the polynomials that will be key switched.
pub(crate) ciphertext_level: usize,
pub(crate) ctx_ciphertext: Arc<Context>,
/// The level and context of the polynomials that will be key switched.
pub(crate) ciphertext_level: usize,
pub(crate) ctx_ciphertext: Arc<Context>,
/// The level and context of the key switching key.
pub(crate) ksk_level: usize,
pub(crate) ctx_ksk: Arc<Context>,
/// The level and context of the key switching key.
pub(crate) ksk_level: usize,
pub(crate) ctx_ksk: Arc<Context>,
}
impl KeySwitchingKey {
/// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial
/// `from`.
pub fn new<R: RngCore + CryptoRng>(
sk: &SecretKey,
from: &Poly,
ciphertext_level: usize,
ksk_level: usize,
rng: &mut R,
) -> Result<Self> {
let ctx_ksk = sk.par.ctx_at_level(ksk_level)?;
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
/// Generate a [`KeySwitchingKey`] to this [`SecretKey`] from a polynomial
/// `from`.
pub fn new<R: RngCore + CryptoRng>(
sk: &SecretKey,
from: &Poly,
ciphertext_level: usize,
ksk_level: usize,
rng: &mut R,
) -> Result<Self> {
let ctx_ksk = sk.par.ctx_at_level(ksk_level)?;
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
if ctx_ksk.moduli().len() == 1 {
return Err(Error::DefaultError(
"These parameters do not support key switching".to_string(),
));
}
if ctx_ksk.moduli().len() == 1 {
return Err(Error::DefaultError(
"These parameters do not support key switching".to_string(),
));
}
if from.ctx() != ctx_ksk {
return Err(Error::DefaultError(
"Incorrect context for polynomial from".to_string(),
));
}
if from.ctx() != ctx_ksk {
return Err(Error::DefaultError(
"Incorrect context for polynomial from".to_string(),
));
}
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
rng.fill(&mut seed);
let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len());
let c0 = Self::generate_c0(sk, from, &c1, rng)?;
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
rng.fill(&mut seed);
let c1 = Self::generate_c1(ctx_ksk, seed, ctx_ciphertext.moduli().len());
let c0 = Self::generate_c0(sk, from, &c1, rng)?;
Ok(Self {
par: sk.par.clone(),
seed: Some(seed),
c0: c0.into_boxed_slice(),
c1: c1.into_boxed_slice(),
ciphertext_level,
ctx_ciphertext: ctx_ciphertext.clone(),
ksk_level,
ctx_ksk: ctx_ksk.clone(),
})
}
Ok(Self {
par: sk.par.clone(),
seed: Some(seed),
c0: c0.into_boxed_slice(),
c1: c1.into_boxed_slice(),
ciphertext_level,
ctx_ciphertext: ctx_ciphertext.clone(),
ksk_level,
ctx_ksk: ctx_ksk.clone(),
})
}
/// Generate the c1's from the seed
fn generate_c1(
ctx: &Arc<Context>,
seed: <ChaCha8Rng as SeedableRng>::Seed,
size: usize,
) -> Vec<Poly> {
let mut c1 = Vec::with_capacity(size);
let mut rng = ChaCha8Rng::from_seed(seed);
(0..size).for_each(|_| {
let mut seed_i = <ChaCha8Rng as SeedableRng>::Seed::default();
rng.fill(&mut seed_i);
let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i);
unsafe { a.allow_variable_time_computations() }
c1.push(a);
});
c1
}
/// Generate the c1's from the seed
fn generate_c1(
ctx: &Arc<Context>,
seed: <ChaCha8Rng as SeedableRng>::Seed,
size: usize,
) -> Vec<Poly> {
let mut c1 = Vec::with_capacity(size);
let mut rng = ChaCha8Rng::from_seed(seed);
(0..size).for_each(|_| {
let mut seed_i = <ChaCha8Rng as SeedableRng>::Seed::default();
rng.fill(&mut seed_i);
let mut a = Poly::random_from_seed(ctx, Representation::NttShoup, seed_i);
unsafe { a.allow_variable_time_computations() }
c1.push(a);
});
c1
}
/// Generate the c0's from the c1's and the secret key
fn generate_c0<R: RngCore + CryptoRng>(
sk: &SecretKey,
from: &Poly,
c1: &[Poly],
rng: &mut R,
) -> Result<Vec<Poly>> {
if c1.is_empty() {
return Err(Error::DefaultError("Empty number of c1's".to_string()));
}
if from.representation() != &Representation::PowerBasis {
return Err(Error::DefaultError(
"Unexpected representation for from".to_string(),
));
}
/// Generate the c0's from the c1's and the secret key
fn generate_c0<R: RngCore + CryptoRng>(
sk: &SecretKey,
from: &Poly,
c1: &[Poly],
rng: &mut R,
) -> Result<Vec<Poly>> {
if c1.is_empty() {
return Err(Error::DefaultError("Empty number of c1's".to_string()));
}
if from.representation() != &Representation::PowerBasis {
return Err(Error::DefaultError(
"Unexpected representation for from".to_string(),
));
}
let size = c1.len();
let size = c1.len();
let mut s = Zeroizing::new(Poly::try_convert_from(
sk.coeffs.as_ref(),
c1[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut s = Zeroizing::new(Poly::try_convert_from(
sk.coeffs.as_ref(),
c1[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let rns = RnsContext::new(&sk.par.moduli[..size])?;
let c0 = c1
.iter()
.enumerate()
.map(|(i, c1i)| {
let mut a_s = Zeroizing::new(c1i.clone());
a_s.disallow_variable_time_computations();
a_s.change_representation(Representation::Ntt);
*a_s.as_mut() *= s.as_ref();
a_s.change_representation(Representation::PowerBasis);
let rns = RnsContext::new(&sk.par.moduli[..size])?;
let c0 = c1
.iter()
.enumerate()
.map(|(i, c1i)| {
let mut a_s = Zeroizing::new(c1i.clone());
a_s.disallow_variable_time_computations();
a_s.change_representation(Representation::Ntt);
*a_s.as_mut() *= s.as_ref();
a_s.change_representation(Representation::PowerBasis);
let mut b =
Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?;
b -= &a_s;
let mut b =
Poly::small(a_s.ctx(), Representation::PowerBasis, sk.par.variance, rng)?;
b -= &a_s;
let gi = rns.get_garner(i).unwrap();
let g_i_from = Zeroizing::new(gi * from);
b += &g_i_from;
let gi = rns.get_garner(i).unwrap();
let g_i_from = Zeroizing::new(gi * from);
b += &g_i_from;
// It is now safe to enable variable time computations.
unsafe { b.allow_variable_time_computations() }
b.change_representation(Representation::NttShoup);
Ok(b)
})
.collect::<Result<Vec<Poly>>>()?;
// It is now safe to enable variable time computations.
unsafe { b.allow_variable_time_computations() }
b.change_representation(Representation::NttShoup);
Ok(b)
})
.collect::<Result<Vec<Poly>>>()?;
Ok(c0)
}
Ok(c0)
}
/// Key switch a polynomial.
pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> {
if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() {
return Err(Error::DefaultError(
"The input polynomial does not have the correct context.".to_string(),
));
}
if p.representation() != &Representation::PowerBasis {
return Err(Error::DefaultError("Incorrect representation".to_string()));
}
/// Key switch a polynomial.
pub fn key_switch(&self, p: &Poly) -> Result<(Poly, Poly)> {
if p.ctx().as_ref() != self.ctx_ciphertext.as_ref() {
return Err(Error::DefaultError(
"The input polynomial does not have the correct context.".to_string(),
));
}
if p.representation() != &Representation::PowerBasis {
return Err(Error::DefaultError("Incorrect representation".to_string()));
}
let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
for (c2_i_coefficients, c0_i, c1_i) in izip!(
p.coefficients().outer_iter(),
self.c0.iter(),
self.c1.iter()
) {
let mut c2_i = unsafe {
Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time(
c2_i_coefficients.as_slice().unwrap(),
&self.ctx_ksk,
)
};
c0 += &(&c2_i * c0_i);
c2_i *= c1_i;
c1 += &c2_i;
}
Ok((c0, c1))
}
let mut c0 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
let mut c1 = Poly::zero(&self.ctx_ksk, Representation::Ntt);
for (c2_i_coefficients, c0_i, c1_i) in izip!(
p.coefficients().outer_iter(),
self.c0.iter(),
self.c1.iter()
) {
let mut c2_i = unsafe {
Poly::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time(
c2_i_coefficients.as_slice().unwrap(),
&self.ctx_ksk,
)
};
c0 += &(&c2_i * c0_i);
c2_i *= c1_i;
c1 += &c2_i;
}
Ok((c0, c1))
}
}
impl From<&KeySwitchingKey> for KeySwitchingKeyProto {
fn from(value: &KeySwitchingKey) -> Self {
let mut ksk = KeySwitchingKeyProto::new();
if let Some(seed) = value.seed.as_ref() {
ksk.seed = seed.to_vec();
} else {
ksk.c1.reserve_exact(value.c1.len());
for c1 in value.c1.iter() {
ksk.c1.push(c1.to_bytes())
}
}
ksk.c0.reserve_exact(value.c0.len());
for c0 in value.c0.iter() {
ksk.c0.push(c0.to_bytes())
}
ksk.ciphertext_level = value.ciphertext_level as u32;
ksk.ksk_level = value.ksk_level as u32;
ksk
}
fn from(value: &KeySwitchingKey) -> Self {
let mut ksk = KeySwitchingKeyProto::new();
if let Some(seed) = value.seed.as_ref() {
ksk.seed = seed.to_vec();
} else {
ksk.c1.reserve_exact(value.c1.len());
for c1 in value.c1.iter() {
ksk.c1.push(c1.to_bytes())
}
}
ksk.c0.reserve_exact(value.c0.len());
for c0 in value.c0.iter() {
ksk.c0.push(c0.to_bytes())
}
ksk.ciphertext_level = value.ciphertext_level as u32;
ksk.ksk_level = value.ksk_level as u32;
ksk
}
}
impl BfvTryConvertFrom<&KeySwitchingKeyProto> for KeySwitchingKey {
fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
let ciphertext_level = value.ciphertext_level as usize;
let ksk_level = value.ksk_level as usize;
let ctx_ksk = par.ctx_at_level(ksk_level)?;
let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?;
fn try_convert_from(value: &KeySwitchingKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
let ciphertext_level = value.ciphertext_level as usize;
let ksk_level = value.ksk_level as usize;
let ctx_ksk = par.ctx_at_level(ksk_level)?;
let ctx_ciphertext = par.ctx_at_level(ciphertext_level)?;
if value.c0.len() != ctx_ciphertext.moduli().len() {
return Err(Error::DefaultError(
"Incorrect number of values in c0".to_string(),
));
}
if value.c0.len() != ctx_ciphertext.moduli().len() {
return Err(Error::DefaultError(
"Incorrect number of values in c0".to_string(),
));
}
let seed = if value.seed.is_empty() {
if value.c1.len() != ctx_ciphertext.moduli().len() {
return Err(Error::DefaultError(
"Incorrect number of values in c1".to_string(),
));
}
None
} else {
let unwrapped = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
if unwrapped.is_err() {
return Err(Error::DefaultError("Invalid seed".to_string()));
}
Some(unwrapped.unwrap())
};
let seed = if value.seed.is_empty() {
if value.c1.len() != ctx_ciphertext.moduli().len() {
return Err(Error::DefaultError(
"Incorrect number of values in c1".to_string(),
));
}
None
} else {
let unwrapped = <ChaCha8Rng as SeedableRng>::Seed::try_from(value.seed.clone());
if unwrapped.is_err() {
return Err(Error::DefaultError("Invalid seed".to_string()));
}
Some(unwrapped.unwrap())
};
let c1 = if let Some(seed) = seed {
Self::generate_c1(ctx_ksk, seed, value.c0.len())
} else {
value
.c1
.iter()
.map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError))
.collect::<Result<Vec<Poly>>>()?
};
let c1 = if let Some(seed) = seed {
Self::generate_c1(ctx_ksk, seed, value.c0.len())
} else {
value
.c1
.iter()
.map(|c1i| Poly::from_bytes(c1i, ctx_ksk).map_err(Error::MathError))
.collect::<Result<Vec<Poly>>>()?
};
let c0 = value
.c0
.iter()
.map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError))
.collect::<Result<Vec<Poly>>>()?;
let c0 = value
.c0
.iter()
.map(|c0i| Poly::from_bytes(c0i, ctx_ksk).map_err(Error::MathError))
.collect::<Result<Vec<Poly>>>()?;
Ok(Self {
par: par.clone(),
seed,
c0: c0.into_boxed_slice(),
c1: c1.into_boxed_slice(),
ciphertext_level,
ctx_ciphertext: ctx_ciphertext.clone(),
ksk_level,
ctx_ksk: ctx_ksk.clone(),
})
}
Ok(Self {
par: par.clone(),
seed,
c0: c0.into_boxed_slice(),
c1: c1.into_boxed_slice(),
ciphertext_level,
ctx_ciphertext: ctx_ciphertext.clone(),
ksk_level,
ctx_ksk: ctx_ksk.clone(),
})
}
}
#[cfg(test)]
mod tests {
use crate::bfv::{
keys::key_switching_key::KeySwitchingKey,
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, traits::TryConvertFrom, BfvParameters,
SecretKey,
};
use fhe_math::{
rns::RnsContext,
rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation},
};
use num_bigint::BigUint;
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use crate::bfv::{
keys::key_switching_key::KeySwitchingKey,
proto::bfv::KeySwitchingKey as KeySwitchingKeyProto, traits::TryConvertFrom, BfvParameters,
SecretKey,
};
use fhe_math::{
rns::RnsContext,
rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation},
};
use num_bigint::BigUint;
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn constructor() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let ctx = params.ctx_at_level(0)?;
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng);
assert!(ksk.is_ok());
}
Ok(())
}
#[test]
fn constructor() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let ctx = params.ctx_at_level(0)?;
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng);
assert!(ksk.is_ok());
}
Ok(())
}
#[test]
fn key_switch() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [Arc::new(BfvParameters::default(6, 8))] {
for _ in 0..100 {
let sk = SecretKey::random(&params, &mut rng);
let ctx = params.ctx_at_level(0)?;
let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
let mut s = Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)
.map_err(crate::Error::MathError)?;
s.change_representation(Representation::Ntt);
#[test]
fn key_switch() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [Arc::new(BfvParameters::default(6, 8))] {
for _ in 0..100 {
let sk = SecretKey::random(&params, &mut rng);
let ctx = params.ctx_at_level(0)?;
let mut p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
let mut s = Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)
.map_err(crate::Error::MathError)?;
s.change_representation(Representation::Ntt);
let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng);
let (c0, c1) = ksk.key_switch(&input)?;
let mut input = Poly::random(ctx, Representation::PowerBasis, &mut rng);
let (c0, c1) = ksk.key_switch(&input)?;
let mut c2 = &c0 + &(&c1 * &s);
c2.change_representation(Representation::PowerBasis);
let mut c2 = &c0 + &(&c1 * &s);
c2.change_representation(Representation::PowerBasis);
input.change_representation(Representation::Ntt);
p.change_representation(Representation::Ntt);
let mut c3 = &input * &p;
c3.change_representation(Representation::PowerBasis);
input.change_representation(Representation::Ntt);
p.change_representation(Representation::Ntt);
let mut c3 = &input * &p;
c3.change_representation(Representation::PowerBasis);
let rns = RnsContext::new(&params.moduli)?;
Vec::<BigUint>::from(&(&c2 - &c3)).iter().for_each(|b| {
assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70)
});
}
}
Ok(())
}
let rns = RnsContext::new(&params.moduli)?;
Vec::<BigUint>::from(&(&c2 - &c3)).iter().for_each(|b| {
assert!(std::cmp::min(b.bits(), (rns.modulus() - b).bits()) <= 70)
});
}
}
Ok(())
}
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let ctx = params.ctx_at_level(0)?;
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
let ksk_proto = KeySwitchingKeyProto::from(&ksk);
assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, &params)?);
}
Ok(())
}
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let ctx = params.ctx_at_level(0)?;
let p = Poly::small(ctx, Representation::PowerBasis, 10, &mut rng)?;
let ksk = KeySwitchingKey::new(&sk, &p, 0, 0, &mut rng)?;
let ksk_proto = KeySwitchingKeyProto::from(&ksk);
assert_eq!(ksk, KeySwitchingKey::try_convert_from(&ksk_proto, &params)?);
}
Ok(())
}
}

View File

@@ -2,8 +2,8 @@
use crate::bfv::traits::TryConvertFrom;
use crate::bfv::{
proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto},
BfvParameters, Ciphertext, Encoding, Plaintext,
proto::bfv::{Ciphertext as CiphertextProto, PublicKey as PublicKeyProto},
BfvParameters, Ciphertext, Encoding, Plaintext,
};
use crate::{Error, Result};
use fhe_math::rq::{Poly, Representation};
@@ -18,188 +18,188 @@ use super::SecretKey;
/// Public key for the BFV encryption scheme.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct PublicKey {
pub(crate) par: Arc<BfvParameters>,
pub(crate) c: Ciphertext,
pub(crate) par: Arc<BfvParameters>,
pub(crate) c: Ciphertext,
}
impl PublicKey {
/// Generate a new [`PublicKey`] from a [`SecretKey`].
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Self {
let zero = Plaintext::zero(Encoding::poly(), &sk.par).unwrap();
let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap();
// The polynomials of a public key should not allow for variable time
// computation.
c.c.iter_mut()
.for_each(|p| p.disallow_variable_time_computations());
Self {
par: sk.par.clone(),
c,
}
}
/// Generate a new [`PublicKey`] from a [`SecretKey`].
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Self {
let zero = Plaintext::zero(Encoding::poly(), &sk.par).unwrap();
let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap();
// The polynomials of a public key should not allow for variable time
// computation.
c.c.iter_mut()
.for_each(|p| p.disallow_variable_time_computations());
Self {
par: sk.par.clone(),
c,
}
}
}
impl FheParametrized for PublicKey {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl FheEncrypter<Plaintext, Ciphertext> for PublicKey {
type Error = Error;
type Error = Error;
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<Ciphertext> {
let mut ct = self.c.clone();
while ct.level != pt.level {
ct.mod_switch_to_next_level();
}
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<Ciphertext> {
let mut ct = self.c.clone();
while ct.level != pt.level {
ct.mod_switch_to_next_level();
}
let ctx = self.par.ctx_at_level(ct.level)?;
let u = Zeroizing::new(Poly::small(
ctx,
Representation::Ntt,
self.par.variance,
rng,
)?);
let e1 = Zeroizing::new(Poly::small(
ctx,
Representation::Ntt,
self.par.variance,
rng,
)?);
let e2 = Zeroizing::new(Poly::small(
ctx,
Representation::Ntt,
self.par.variance,
rng,
)?);
let ctx = self.par.ctx_at_level(ct.level)?;
let u = Zeroizing::new(Poly::small(
ctx,
Representation::Ntt,
self.par.variance,
rng,
)?);
let e1 = Zeroizing::new(Poly::small(
ctx,
Representation::Ntt,
self.par.variance,
rng,
)?);
let e2 = Zeroizing::new(Poly::small(
ctx,
Representation::Ntt,
self.par.variance,
rng,
)?);
let m = Zeroizing::new(pt.to_poly());
let mut c0 = u.as_ref() * &ct.c[0];
c0 += &e1;
c0 += &m;
let mut c1 = u.as_ref() * &ct.c[1];
c1 += &e2;
let m = Zeroizing::new(pt.to_poly());
let mut c0 = u.as_ref() * &ct.c[0];
c0 += &e1;
c0 += &m;
let mut c1 = u.as_ref() * &ct.c[1];
c1 += &e2;
// It is now safe to enable variable time computations.
unsafe {
c0.allow_variable_time_computations();
c1.allow_variable_time_computations()
}
// It is now safe to enable variable time computations.
unsafe {
c0.allow_variable_time_computations();
c1.allow_variable_time_computations()
}
Ok(Ciphertext {
par: self.par.clone(),
seed: None,
c: vec![c0, c1],
level: ct.level,
})
}
Ok(Ciphertext {
par: self.par.clone(),
seed: None,
c: vec![c0, c1],
level: ct.level,
})
}
}
impl From<&PublicKey> for PublicKeyProto {
fn from(pk: &PublicKey) -> Self {
let mut proto = PublicKeyProto::new();
proto.c = MessageField::some(CiphertextProto::from(&pk.c));
proto
}
fn from(pk: &PublicKey) -> Self {
let mut proto = PublicKeyProto::new();
proto.c = MessageField::some(CiphertextProto::from(&pk.c));
proto
}
}
impl Serialize for PublicKey {
fn to_bytes(&self) -> Vec<u8> {
PublicKeyProto::from(self).write_to_bytes().unwrap()
}
fn to_bytes(&self) -> Vec<u8> {
PublicKeyProto::from(self).write_to_bytes().unwrap()
}
}
impl DeserializeParametrized for PublicKey {
type Error = Error;
type Error = Error;
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
let proto =
PublicKeyProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
if proto.c.is_some() {
let mut c = Ciphertext::try_convert_from(&proto.c.unwrap(), par)?;
if c.level != 0 {
Err(Error::SerializationError)
} else {
// The polynomials of a public key should not allow for variable time
// computation.
c.c.iter_mut()
.for_each(|p| p.disallow_variable_time_computations());
Ok(Self {
par: par.clone(),
c,
})
}
} else {
Err(Error::SerializationError)
}
}
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
let proto =
PublicKeyProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
if proto.c.is_some() {
let mut c = Ciphertext::try_convert_from(&proto.c.unwrap(), par)?;
if c.level != 0 {
Err(Error::SerializationError)
} else {
// The polynomials of a public key should not allow for variable time
// computation.
c.c.iter_mut()
.for_each(|p| p.disallow_variable_time_computations());
Ok(Self {
par: par.clone(),
c,
})
}
} else {
Err(Error::SerializationError)
}
}
}
#[cfg(test)]
mod tests {
use super::PublicKey;
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext, SecretKey};
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use super::PublicKey;
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext, SecretKey};
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn keygen() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
assert_eq!(pk.par, params);
assert_eq!(
sk.try_decrypt(&pk.c)?,
Plaintext::zero(Encoding::poly(), &params)?
);
Ok(())
}
#[test]
fn keygen() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
assert_eq!(pk.par, params);
assert_eq!(
sk.try_decrypt(&pk.c)?,
Plaintext::zero(Encoding::poly(), &params)?
);
Ok(())
}
#[test]
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
for level in 0..params.max_level() {
for _ in 0..20 {
let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
#[test]
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
for level in 0..params.max_level() {
for _ in 0..20 {
let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
let pt = Plaintext::try_encode(
&params.plaintext.random_vec(params.degree(), &mut rng),
Encoding::poly_at_level(level),
&params,
)?;
let ct = pk.try_encrypt(&pt, &mut rng)?;
let pt2 = sk.try_decrypt(&ct)?;
let pt = Plaintext::try_encode(
&params.plaintext.random_vec(params.degree(), &mut rng),
Encoding::poly_at_level(level),
&params,
)?;
let ct = pk.try_encrypt(&pt, &mut rng)?;
let pt2 = sk.try_decrypt(&ct)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
assert_eq!(pt2, pt);
}
}
}
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
assert_eq!(pt2, pt);
}
}
}
Ok(())
}
Ok(())
}
#[test]
fn test_serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
let bytes = pk.to_bytes();
assert_eq!(pk, PublicKey::from_bytes(&bytes, &params)?);
}
Ok(())
}
#[test]
fn test_serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
let bytes = pk.to_bytes();
assert_eq!(pk, PublicKey::from_bytes(&bytes, &params)?);
}
Ok(())
}
}

View File

@@ -4,15 +4,15 @@ use std::sync::Arc;
use super::key_switching_key::KeySwitchingKey;
use crate::bfv::{
proto::bfv::{
KeySwitchingKey as KeySwitchingKeyProto, RelinearizationKey as RelinearizationKeyProto,
},
traits::TryConvertFrom,
BfvParameters, Ciphertext, SecretKey,
proto::bfv::{
KeySwitchingKey as KeySwitchingKeyProto, RelinearizationKey as RelinearizationKeyProto,
},
traits::TryConvertFrom,
BfvParameters, Ciphertext, SecretKey,
};
use crate::{Error, Result};
use fhe_math::rq::{
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
switcher::Switcher, traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation,
};
use fhe_traits::{DeserializeParametrized, FheParametrized, Serialize};
use protobuf::{Message, MessageField};
@@ -24,284 +24,284 @@ use zeroize::Zeroizing;
/// which switch from `s^2` to `s` where `s` is the secret key.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct RelinearizationKey {
pub(crate) ksk: KeySwitchingKey,
pub(crate) ksk: KeySwitchingKey,
}
impl RelinearizationKey {
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Result<Self> {
Self::new_leveled_internal(sk, 0, 0, rng)
}
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
pub fn new<R: RngCore + CryptoRng>(sk: &SecretKey, rng: &mut R) -> Result<Self> {
Self::new_leveled_internal(sk, 0, 0, rng)
}
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
pub fn new_leveled<R: RngCore + CryptoRng>(
sk: &SecretKey,
ciphertext_level: usize,
key_level: usize,
rng: &mut R,
) -> Result<Self> {
Self::new_leveled_internal(sk, ciphertext_level, key_level, rng)
}
/// Generate a [`RelinearizationKey`] from a [`SecretKey`].
pub fn new_leveled<R: RngCore + CryptoRng>(
sk: &SecretKey,
ciphertext_level: usize,
key_level: usize,
rng: &mut R,
) -> Result<Self> {
Self::new_leveled_internal(sk, ciphertext_level, key_level, rng)
}
fn new_leveled_internal<R: RngCore + CryptoRng>(
sk: &SecretKey,
ciphertext_level: usize,
key_level: usize,
rng: &mut R,
) -> Result<Self> {
let ctx_relin_key = sk.par.ctx_at_level(key_level)?;
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
fn new_leveled_internal<R: RngCore + CryptoRng>(
sk: &SecretKey,
ciphertext_level: usize,
key_level: usize,
rng: &mut R,
) -> Result<Self> {
let ctx_relin_key = sk.par.ctx_at_level(key_level)?;
let ctx_ciphertext = sk.par.ctx_at_level(ciphertext_level)?;
if ctx_relin_key.moduli().len() == 1 {
return Err(Error::DefaultError(
"These parameters do not support key switching".to_string(),
));
}
if ctx_relin_key.moduli().len() == 1 {
return Err(Error::DefaultError(
"These parameters do not support key switching".to_string(),
));
}
let mut s = Zeroizing::new(Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx_ciphertext,
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref());
s2.change_representation(Representation::PowerBasis);
let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?;
let s2_switched_up = Zeroizing::new(s2.mod_switch_to(&switcher_up)?);
let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?;
Ok(Self { ksk })
}
let mut s = Zeroizing::new(Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx_ciphertext,
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut s2 = Zeroizing::new(s.as_ref() * s.as_ref());
s2.change_representation(Representation::PowerBasis);
let switcher_up = Switcher::new(ctx_ciphertext, ctx_relin_key)?;
let s2_switched_up = Zeroizing::new(s2.mod_switch_to(&switcher_up)?);
let ksk = KeySwitchingKey::new(sk, &s2_switched_up, ciphertext_level, key_level, rng)?;
Ok(Self { ksk })
}
/// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`]
pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> {
if ct.c.len() != 3 {
Err(Error::DefaultError(
"Only supports relinearization of ciphertext with 3 parts".to_string(),
))
} else if ct.level != self.ksk.ciphertext_level {
Err(Error::DefaultError(
"Ciphertext has incorrect level".to_string(),
))
} else {
let mut c2 = ct.c[2].clone();
c2.change_representation(Representation::PowerBasis);
/// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`]
pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> {
if ct.c.len() != 3 {
Err(Error::DefaultError(
"Only supports relinearization of ciphertext with 3 parts".to_string(),
))
} else if ct.level != self.ksk.ciphertext_level {
Err(Error::DefaultError(
"Ciphertext has incorrect level".to_string(),
))
} else {
let mut c2 = ct.c[2].clone();
c2.change_representation(Representation::PowerBasis);
#[allow(unused_mut)]
let (mut c0, mut c1) = self.relinearizes_poly(&c2)?;
#[allow(unused_mut)]
let (mut c0, mut c1) = self.relinearizes_poly(&c2)?;
if c0.ctx() != ct.c[0].ctx() {
c0.change_representation(Representation::PowerBasis);
c1.change_representation(Representation::PowerBasis);
c0.mod_switch_down_to(ct.c[0].ctx())?;
c1.mod_switch_down_to(ct.c[1].ctx())?;
c0.change_representation(Representation::Ntt);
c1.change_representation(Representation::Ntt);
}
if c0.ctx() != ct.c[0].ctx() {
c0.change_representation(Representation::PowerBasis);
c1.change_representation(Representation::PowerBasis);
c0.mod_switch_down_to(ct.c[0].ctx())?;
c1.mod_switch_down_to(ct.c[1].ctx())?;
c0.change_representation(Representation::Ntt);
c1.change_representation(Representation::Ntt);
}
ct.c[0] += &c0;
ct.c[1] += &c1;
ct.c.truncate(2);
Ok(())
}
}
ct.c[0] += &c0;
ct.c[1] += &c1;
ct.c.truncate(2);
Ok(())
}
}
/// Relinearize using polynomials.
pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> {
self.ksk.key_switch(c2)
}
/// Relinearize using polynomials.
pub(crate) fn relinearizes_poly(&self, c2: &Poly) -> Result<(Poly, Poly)> {
self.ksk.key_switch(c2)
}
}
impl From<&RelinearizationKey> for RelinearizationKeyProto {
fn from(value: &RelinearizationKey) -> Self {
let mut rk = RelinearizationKeyProto::new();
rk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
rk
}
fn from(value: &RelinearizationKey) -> Self {
let mut rk = RelinearizationKeyProto::new();
rk.ksk = MessageField::some(KeySwitchingKeyProto::from(&value.ksk));
rk
}
}
impl TryConvertFrom<&RelinearizationKeyProto> for RelinearizationKey {
fn try_convert_from(value: &RelinearizationKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
if par.moduli.len() == 1 {
Err(Error::DefaultError(
"Invalid parameters for a relinearization key".to_string(),
))
} else if value.ksk.is_some() {
Ok(RelinearizationKey {
ksk: KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?,
})
} else {
Err(Error::DefaultError("Invalid serialization".to_string()))
}
}
fn try_convert_from(value: &RelinearizationKeyProto, par: &Arc<BfvParameters>) -> Result<Self> {
if par.moduli.len() == 1 {
Err(Error::DefaultError(
"Invalid parameters for a relinearization key".to_string(),
))
} else if value.ksk.is_some() {
Ok(RelinearizationKey {
ksk: KeySwitchingKey::try_convert_from(value.ksk.as_ref().unwrap(), par)?,
})
} else {
Err(Error::DefaultError("Invalid serialization".to_string()))
}
}
}
impl Serialize for RelinearizationKey {
fn to_bytes(&self) -> Vec<u8> {
RelinearizationKeyProto::from(self)
.write_to_bytes()
.unwrap()
}
fn to_bytes(&self) -> Vec<u8> {
RelinearizationKeyProto::from(self)
.write_to_bytes()
.unwrap()
}
}
impl FheParametrized for RelinearizationKey {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl DeserializeParametrized for RelinearizationKey {
type Error = Error;
type Error = Error;
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
let rk = RelinearizationKeyProto::parse_from_bytes(bytes);
if let Ok(rk) = rk {
RelinearizationKey::try_convert_from(&rk, par)
} else {
Err(Error::DefaultError("Invalid serialization".to_string()))
}
}
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
let rk = RelinearizationKeyProto::parse_from_bytes(bytes);
if let Ok(rk) = rk {
RelinearizationKey::try_convert_from(&rk, par)
} else {
Err(Error::DefaultError("Invalid serialization".to_string()))
}
}
}
#[cfg(test)]
mod tests {
use super::RelinearizationKey;
use crate::bfv::{
proto::bfv::RelinearizationKey as RelinearizationKeyProto, traits::TryConvertFrom,
BfvParameters, Ciphertext, Encoding, SecretKey,
};
use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation};
use fhe_traits::{FheDecoder, FheDecrypter};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use super::RelinearizationKey;
use crate::bfv::{
proto::bfv::RelinearizationKey as RelinearizationKeyProto, traits::TryConvertFrom,
BfvParameters, Ciphertext, Encoding, SecretKey,
};
use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation};
use fhe_traits::{FheDecoder, FheDecrypter};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn relinearization() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [Arc::new(BfvParameters::default(6, 8))] {
for _ in 0..100 {
let sk = SecretKey::random(&params, &mut rng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
#[test]
fn relinearization() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [Arc::new(BfvParameters::default(6, 8))] {
for _ in 0..100 {
let sk = SecretKey::random(&params, &mut rng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let ctx = params.ctx_at_level(0)?;
let mut s = Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)
.map_err(crate::Error::MathError)?;
s.change_representation(Representation::Ntt);
let s2 = &s * &s;
let ctx = params.ctx_at_level(0)?;
let mut s = Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)
.map_err(crate::Error::MathError)?;
s.change_representation(Representation::Ntt);
let s2 = &s * &s;
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2,
// c1, c2) encrypting 0.
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
c0.change_representation(Representation::Ntt);
c0 -= &(&c1 * &s);
c0 -= &(&c2 * &s2);
let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?;
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 * s^2,
// c1, c2) encrypting 0.
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
c0.change_representation(Representation::Ntt);
c0 -= &(&c1 * &s);
c0 -= &(&c2 * &s2);
let mut ct = Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?;
// Relinearize the extended ciphertext!
rk.relinearizes(&mut ct)?;
assert_eq!(ct.c.len(), 2);
// Relinearize the extended ciphertext!
rk.relinearizes(&mut ct)?;
assert_eq!(ct.c.len(), 2);
// Check that the relinearization by polynomials works the same way
c2.change_representation(Representation::PowerBasis);
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
c0r.change_representation(Representation::PowerBasis);
c0r.mod_switch_down_to(c0.ctx())?;
c1r.change_representation(Representation::PowerBasis);
c1r.mod_switch_down_to(c1.ctx())?;
c0r.change_representation(Representation::Ntt);
c1r.change_representation(Representation::Ntt);
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], &params)?);
// Check that the relinearization by polynomials works the same way
c2.change_representation(Representation::PowerBasis);
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
c0r.change_representation(Representation::PowerBasis);
c0r.mod_switch_down_to(c0.ctx())?;
c1r.change_representation(Representation::PowerBasis);
c1r.mod_switch_down_to(c1.ctx())?;
c0r.change_representation(Representation::Ntt);
c1r.change_representation(Representation::Ntt);
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], &params)?);
// Print the noise and decrypt
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
let pt = sk.try_decrypt(&ct)?;
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
assert_eq!(w, &[0u64; 8]);
}
}
Ok(())
}
// Print the noise and decrypt
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
let pt = sk.try_decrypt(&ct)?;
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
assert_eq!(w, &[0u64; 8]);
}
}
Ok(())
}
#[test]
fn relinearization_leveled() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [Arc::new(BfvParameters::default(5, 8))] {
for ciphertext_level in 0..3 {
for key_level in 0..ciphertext_level {
for _ in 0..10 {
let sk = SecretKey::random(&params, &mut rng);
let rk = RelinearizationKey::new_leveled(
&sk,
ciphertext_level,
key_level,
&mut rng,
)?;
#[test]
fn relinearization_leveled() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [Arc::new(BfvParameters::default(5, 8))] {
for ciphertext_level in 0..3 {
for key_level in 0..ciphertext_level {
for _ in 0..10 {
let sk = SecretKey::random(&params, &mut rng);
let rk = RelinearizationKey::new_leveled(
&sk,
ciphertext_level,
key_level,
&mut rng,
)?;
let ctx = params.ctx_at_level(ciphertext_level)?;
let mut s = Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)
.map_err(crate::Error::MathError)?;
s.change_representation(Representation::Ntt);
let s2 = &s * &s;
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 *
// s^2, c1, c2) encrypting 0.
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
c0.change_representation(Representation::Ntt);
c0 -= &(&c1 * &s);
c0 -= &(&c2 * &s2);
let mut ct =
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?;
let ctx = params.ctx_at_level(ciphertext_level)?;
let mut s = Poly::try_convert_from(
sk.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)
.map_err(crate::Error::MathError)?;
s.change_representation(Representation::Ntt);
let s2 = &s * &s;
// Let's generate manually an "extended" ciphertext (c0 = e - c1 * s - c2 *
// s^2, c1, c2) encrypting 0.
let mut c2 = Poly::random(ctx, Representation::Ntt, &mut rng);
let c1 = Poly::random(ctx, Representation::Ntt, &mut rng);
let mut c0 = Poly::small(ctx, Representation::PowerBasis, 16, &mut rng)?;
c0.change_representation(Representation::Ntt);
c0 -= &(&c1 * &s);
c0 -= &(&c2 * &s2);
let mut ct =
Ciphertext::new(vec![c0.clone(), c1.clone(), c2.clone()], &params)?;
// Relinearize the extended ciphertext!
rk.relinearizes(&mut ct)?;
assert_eq!(ct.c.len(), 2);
// Relinearize the extended ciphertext!
rk.relinearizes(&mut ct)?;
assert_eq!(ct.c.len(), 2);
// Check that the relinearization by polynomials works the same way
c2.change_representation(Representation::PowerBasis);
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
c0r.change_representation(Representation::PowerBasis);
c0r.mod_switch_down_to(c0.ctx())?;
c1r.change_representation(Representation::PowerBasis);
c1r.mod_switch_down_to(c1.ctx())?;
c0r.change_representation(Representation::Ntt);
c1r.change_representation(Representation::Ntt);
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], &params)?);
// Check that the relinearization by polynomials works the same way
c2.change_representation(Representation::PowerBasis);
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c2)?;
c0r.change_representation(Representation::PowerBasis);
c0r.mod_switch_down_to(c0.ctx())?;
c1r.change_representation(Representation::PowerBasis);
c1r.mod_switch_down_to(c1.ctx())?;
c0r.change_representation(Representation::Ntt);
c1r.change_representation(Representation::Ntt);
assert_eq!(ct, Ciphertext::new(vec![&c0 + &c0r, &c1 + &c1r], &params)?);
// Print the noise and decrypt
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
let pt = sk.try_decrypt(&ct)?;
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
assert_eq!(w, &[0u64; 8]);
}
}
}
}
Ok(())
}
// Print the noise and decrypt
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
let pt = sk.try_decrypt(&ct)?;
let w = Vec::<u64>::try_decode(&pt, Encoding::poly())?;
assert_eq!(w, &[0u64; 8]);
}
}
}
}
Ok(())
}
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let proto = RelinearizationKeyProto::from(&rk);
assert_eq!(rk, RelinearizationKey::try_convert_from(&proto, &params)?);
}
Ok(())
}
#[test]
fn proto_conversion() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(3, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let proto = RelinearizationKeyProto::from(&rk);
assert_eq!(rk, RelinearizationKey::try_convert_from(&proto, &params)?);
}
Ok(())
}
}

View File

@@ -3,8 +3,8 @@
use crate::bfv::{BfvParameters, Ciphertext, Plaintext};
use crate::{Error, Result};
use fhe_math::{
rq::{traits::TryConvertFrom, Poly, Representation},
zq::Modulus,
rq::{traits::TryConvertFrom, Poly, Representation},
zq::Modulus,
};
use fhe_traits::{FheDecrypter, FheEncrypter, FheParametrized};
use fhe_util::sample_vec_cbd;
@@ -18,249 +18,249 @@ use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
/// Secret key for the BFV encryption scheme.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SecretKey {
pub(crate) par: Arc<BfvParameters>,
pub(crate) coeffs: Box<[i64]>,
pub(crate) par: Arc<BfvParameters>,
pub(crate) coeffs: Box<[i64]>,
}
impl Zeroize for SecretKey {
fn zeroize(&mut self) {
self.coeffs.zeroize();
}
fn zeroize(&mut self) {
self.coeffs.zeroize();
}
}
impl ZeroizeOnDrop for SecretKey {}
impl SecretKey {
/// Generate a random [`SecretKey`].
pub fn random<R: RngCore + CryptoRng>(par: &Arc<BfvParameters>, rng: &mut R) -> Self {
let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap();
Self::new(s_coefficients, par)
}
/// Generate a random [`SecretKey`].
pub fn random<R: RngCore + CryptoRng>(par: &Arc<BfvParameters>, rng: &mut R) -> Self {
let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap();
Self::new(s_coefficients, par)
}
/// Generate a [`SecretKey`] from its coefficients.
pub(crate) fn new(coeffs: Vec<i64>, par: &Arc<BfvParameters>) -> Self {
Self {
par: par.clone(),
coeffs: coeffs.into_boxed_slice(),
}
}
/// Generate a [`SecretKey`] from its coefficients.
pub(crate) fn new(coeffs: Vec<i64>, par: &Arc<BfvParameters>) -> Self {
Self {
par: par.clone(),
coeffs: coeffs.into_boxed_slice(),
}
}
/// Measure the noise in a [`Ciphertext`].
///
/// # Safety
///
/// This operations may run in a variable time depending on the value of the
/// noise.
pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result<usize> {
let plaintext = Zeroizing::new(self.try_decrypt(ct)?);
let m = Zeroizing::new(plaintext.to_poly());
/// Measure the noise in a [`Ciphertext`].
///
/// # Safety
///
/// This operations may run in a variable time depending on the value of the
/// noise.
pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result<usize> {
let plaintext = Zeroizing::new(self.try_decrypt(ct)?);
let m = Zeroizing::new(plaintext.to_poly());
// Let's create a secret key with the ciphertext context
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ct.c[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut si = s.clone();
// Let's create a secret key with the ciphertext context
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ct.c[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut si = s.clone();
// Let's disable variable time computations
let mut c = Zeroizing::new(ct.c[0].clone());
c.disallow_variable_time_computations();
// Let's disable variable time computations
let mut c = Zeroizing::new(ct.c[0].clone());
c.disallow_variable_time_computations();
for i in 1..ct.c.len() {
let mut cis = Zeroizing::new(ct.c[i].clone());
cis.disallow_variable_time_computations();
*cis.as_mut() *= si.as_ref();
*c.as_mut() += &cis;
*si.as_mut() *= s.as_ref();
}
*c.as_mut() -= &m;
c.change_representation(Representation::PowerBasis);
for i in 1..ct.c.len() {
let mut cis = Zeroizing::new(ct.c[i].clone());
cis.disallow_variable_time_computations();
*cis.as_mut() *= si.as_ref();
*c.as_mut() += &cis;
*si.as_mut() *= s.as_ref();
}
*c.as_mut() -= &m;
c.change_representation(Representation::PowerBasis);
let ciphertext_modulus = ct.c[0].ctx().modulus();
let mut noise = 0usize;
for coeff in Vec::<BigUint>::from(c.as_ref()) {
noise = std::cmp::max(
noise,
std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize,
)
}
let ciphertext_modulus = ct.c[0].ctx().modulus();
let mut noise = 0usize;
for coeff in Vec::<BigUint>::from(c.as_ref()) {
noise = std::cmp::max(
noise,
std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize,
)
}
Ok(noise)
}
Ok(noise)
}
pub(crate) fn encrypt_poly<R: RngCore + CryptoRng>(
&self,
p: &Poly,
rng: &mut R,
) -> Result<Ciphertext> {
assert_eq!(p.representation(), &Representation::Ntt);
pub(crate) fn encrypt_poly<R: RngCore + CryptoRng>(
&self,
p: &Poly,
rng: &mut R,
) -> Result<Ciphertext> {
assert_eq!(p.representation(), &Representation::Ntt);
let level = self.par.level_of_ctx(p.ctx())?;
let level = self.par.level_of_ctx(p.ctx())?;
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
thread_rng().fill(&mut seed);
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
thread_rng().fill(&mut seed);
// Let's create a secret key with the ciphertext context
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
p.ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
// Let's create a secret key with the ciphertext context
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
p.ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed);
let a_s = Zeroizing::new(&a * s.as_ref());
let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed);
let a_s = Zeroizing::new(&a * s.as_ref());
let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng)
.map_err(Error::MathError)?;
b -= &a_s;
b += p;
let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng)
.map_err(Error::MathError)?;
b -= &a_s;
b += p;
// It is now safe to enable variable time computations.
unsafe {
a.allow_variable_time_computations();
b.allow_variable_time_computations()
}
// It is now safe to enable variable time computations.
unsafe {
a.allow_variable_time_computations();
b.allow_variable_time_computations()
}
Ok(Ciphertext {
par: self.par.clone(),
seed: Some(seed),
c: vec![b, a],
level,
})
}
Ok(Ciphertext {
par: self.par.clone(),
seed: Some(seed),
c: vec![b, a],
level,
})
}
}
impl FheParametrized for SecretKey {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl FheEncrypter<Plaintext, Ciphertext> for SecretKey {
type Error = Error;
type Error = Error;
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<Ciphertext> {
assert_eq!(self.par, pt.par);
let m = Zeroizing::new(pt.to_poly());
self.encrypt_poly(m.as_ref(), rng)
}
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<Ciphertext> {
assert_eq!(self.par, pt.par);
let m = Zeroizing::new(pt.to_poly());
self.encrypt_poly(m.as_ref(), rng)
}
}
impl FheDecrypter<Plaintext, Ciphertext> for SecretKey {
type Error = Error;
type Error = Error;
fn try_decrypt(&self, ct: &Ciphertext) -> Result<Plaintext> {
if self.par != ct.par {
Err(Error::DefaultError(
"Incompatible BFV parameters".to_string(),
))
} else {
// Let's create a secret key with the ciphertext context
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ct.c[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut si = s.clone();
fn try_decrypt(&self, ct: &Ciphertext) -> Result<Plaintext> {
if self.par != ct.par {
Err(Error::DefaultError(
"Incompatible BFV parameters".to_string(),
))
} else {
// Let's create a secret key with the ciphertext context
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ct.c[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut si = s.clone();
let mut c = Zeroizing::new(ct.c[0].clone());
c.disallow_variable_time_computations();
let mut c = Zeroizing::new(ct.c[0].clone());
c.disallow_variable_time_computations();
for i in 1..ct.c.len() {
let mut cis = Zeroizing::new(ct.c[i].clone());
cis.disallow_variable_time_computations();
*cis.as_mut() *= si.as_ref();
*c.as_mut() += &cis;
*si.as_mut() *= s.as_ref();
}
c.change_representation(Representation::PowerBasis);
for i in 1..ct.c.len() {
let mut cis = Zeroizing::new(ct.c[i].clone());
cis.disallow_variable_time_computations();
*cis.as_mut() *= si.as_ref();
*c.as_mut() += &cis;
*si.as_mut() *= s.as_ref();
}
c.change_representation(Representation::PowerBasis);
let d = Zeroizing::new(c.scale(&self.par.scalers[ct.level])?);
let d = Zeroizing::new(c.scale(&self.par.scalers[ct.level])?);
// TODO: Can we handle plaintext moduli that are BigUint?
let v = Zeroizing::new(
Vec::<u64>::from(d.as_ref())
.iter_mut()
.map(|vi| *vi + self.par.plaintext.modulus())
.collect_vec(),
);
let mut w = v[..self.par.degree()].to_vec();
let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?;
q.reduce_vec(&mut w);
self.par.plaintext.reduce_vec(&mut w);
// TODO: Can we handle plaintext moduli that are BigUint?
let v = Zeroizing::new(
Vec::<u64>::from(d.as_ref())
.iter_mut()
.map(|vi| *vi + self.par.plaintext.modulus())
.collect_vec(),
);
let mut w = v[..self.par.degree()].to_vec();
let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?;
q.reduce_vec(&mut w);
self.par.plaintext.reduce_vec(&mut w);
let mut poly =
Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
let mut poly =
Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
let pt = Plaintext {
par: self.par.clone(),
value: w.into_boxed_slice(),
encoding: None,
poly_ntt: poly,
level: ct.level,
};
let pt = Plaintext {
par: self.par.clone(),
value: w.into_boxed_slice(),
encoding: None,
poly_ntt: poly,
level: ct.level,
};
Ok(pt)
}
}
Ok(pt)
}
}
}
#[cfg(test)]
mod tests {
use super::SecretKey;
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext};
use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use super::SecretKey;
use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext};
use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn keygen() {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let sk = SecretKey::random(&params, &mut rng);
assert_eq!(sk.par, params);
#[test]
fn keygen() {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let sk = SecretKey::random(&params, &mut rng);
assert_eq!(sk.par, params);
sk.coeffs.iter().for_each(|ci| {
// Check that this is a small polynomial
assert!((*ci).abs() <= 2 * sk.par.variance as i64)
})
}
sk.coeffs.iter().for_each(|ci| {
// Check that this is a small polynomial
assert!((*ci).abs() <= 2 * sk.par.variance as i64)
})
}
#[test]
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
for level in 0..params.max_level() {
for _ in 0..20 {
let sk = SecretKey::random(&params, &mut rng);
#[test]
fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(6, 8)),
] {
for level in 0..params.max_level() {
for _ in 0..20 {
let sk = SecretKey::random(&params, &mut rng);
let pt = Plaintext::try_encode(
&params.plaintext.random_vec(params.degree(), &mut rng),
Encoding::poly_at_level(level),
&params,
)?;
let ct = sk.try_encrypt(&pt, &mut rng)?;
let pt2 = sk.try_decrypt(&ct)?;
let pt = Plaintext::try_encode(
&params.plaintext.random_vec(params.degree(), &mut rng),
Encoding::poly_at_level(level),
&params,
)?;
let ct = sk.try_encrypt(&pt, &mut rng)?;
let pt2 = sk.try_decrypt(&ct)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
assert_eq!(pt2, pt);
}
}
}
println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
assert_eq!(pt2, pt);
}
}
}
Ok(())
}
Ok(())
}
}

View File

@@ -5,46 +5,46 @@ use itertools::{izip, Itertools};
use ndarray::{Array, Array2};
use crate::{
bfv::{Ciphertext, Plaintext},
Error, Result,
bfv::{Ciphertext, Plaintext},
Error, Result,
};
/// Computes the Fused-Mul-Add operation `out[i] += x[i] * y[i]`
unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) {
let n = out.len();
assert_eq!(x.len(), n);
assert_eq!(y.len(), n);
let n = out.len();
assert_eq!(x.len(), n);
assert_eq!(y.len(), n);
macro_rules! fma_at {
($idx:expr) => {
*out.get_unchecked_mut($idx) +=
(*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128);
};
}
macro_rules! fma_at {
($idx:expr) => {
*out.get_unchecked_mut($idx) +=
(*x.get_unchecked($idx) as u128) * (*y.get_unchecked($idx) as u128);
};
}
let r = n / 16;
for i in 0..r {
fma_at!(16 * i);
fma_at!(16 * i + 1);
fma_at!(16 * i + 2);
fma_at!(16 * i + 3);
fma_at!(16 * i + 4);
fma_at!(16 * i + 5);
fma_at!(16 * i + 6);
fma_at!(16 * i + 7);
fma_at!(16 * i + 8);
fma_at!(16 * i + 9);
fma_at!(16 * i + 10);
fma_at!(16 * i + 11);
fma_at!(16 * i + 12);
fma_at!(16 * i + 13);
fma_at!(16 * i + 14);
fma_at!(16 * i + 15);
}
let r = n / 16;
for i in 0..r {
fma_at!(16 * i);
fma_at!(16 * i + 1);
fma_at!(16 * i + 2);
fma_at!(16 * i + 3);
fma_at!(16 * i + 4);
fma_at!(16 * i + 5);
fma_at!(16 * i + 6);
fma_at!(16 * i + 7);
fma_at!(16 * i + 8);
fma_at!(16 * i + 9);
fma_at!(16 * i + 10);
fma_at!(16 * i + 11);
fma_at!(16 * i + 12);
fma_at!(16 * i + 13);
fma_at!(16 * i + 14);
fma_at!(16 * i + 15);
}
for i in 0..n % 16 {
fma_at!(16 * r + i);
}
for i in 0..n % 16 {
fma_at!(16 * r + i);
}
}
/// Compute the dot product between an iterator of [`Ciphertext`] and an
@@ -53,146 +53,146 @@ unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) {
/// number of parts.
pub fn dot_product_scalar<'a, I, J>(ct: I, pt: J) -> Result<Ciphertext>
where
I: Iterator<Item = &'a Ciphertext> + Clone,
J: Iterator<Item = &'a Plaintext> + Clone,
I: Iterator<Item = &'a Ciphertext> + Clone,
J: Iterator<Item = &'a Plaintext> + Clone,
{
let count = min(ct.clone().count(), pt.clone().count());
if count == 0 {
return Err(Error::DefaultError(
"At least one iterator is empty".to_string(),
));
}
let ct_first = ct.clone().next().unwrap();
let ctx = ct_first.c[0].ctx();
let count = min(ct.clone().count(), pt.clone().count());
if count == 0 {
return Err(Error::DefaultError(
"At least one iterator is empty".to_string(),
));
}
let ct_first = ct.clone().next().unwrap();
let ctx = ct_first.c[0].ctx();
if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| {
cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len()
}) {
return Err(Error::DefaultError("Mismatched parameters".to_string()));
}
if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) {
return Err(Error::DefaultError(
"Mismatched number of parts in the ciphertexts".to_string(),
));
}
if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| {
cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len()
}) {
return Err(Error::DefaultError("Mismatched parameters".to_string()));
}
if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) {
return Err(Error::DefaultError(
"Mismatched number of parts in the ciphertexts".to_string(),
));
}
let max_acc = ctx
.moduli()
.iter()
.map(|qi| 1u128 << (2 * qi.leading_zeros()))
.collect_vec();
let min_of_max = max_acc.iter().min().unwrap();
let max_acc = ctx
.moduli()
.iter()
.map(|qi| 1u128 << (2 * qi.leading_zeros()))
.collect_vec();
let min_of_max = max_acc.iter().min().unwrap();
if count as u128 > *min_of_max {
// Too many ciphertexts for the optimized method, instead, we call
// `poly_dot_product`.
let c = (0..ct_first.c.len())
.map(|i| {
poly_dot_product(
ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }),
pt.clone().map(|pti| &pti.poly_ntt),
)
.map_err(Error::MathError)
})
.collect::<Result<Vec<Poly>>>()?;
if count as u128 > *min_of_max {
// Too many ciphertexts for the optimized method, instead, we call
// `poly_dot_product`.
let c = (0..ct_first.c.len())
.map(|i| {
poly_dot_product(
ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }),
pt.clone().map(|pti| &pti.poly_ntt),
)
.map_err(Error::MathError)
})
.collect::<Result<Vec<Poly>>>()?;
Ok(Ciphertext {
par: ct_first.par.clone(),
seed: None,
c,
level: ct_first.level,
})
} else {
let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree()));
for (ciphertext, plaintext) in izip!(ct, pt) {
let pt_coefficients = plaintext.poly_ntt.coefficients();
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) {
let ci_coefficients = ci.coefficients();
for (mut accij, cij, pij) in izip!(
acci.outer_iter_mut(),
ci_coefficients.outer_iter(),
pt_coefficients.outer_iter()
) {
unsafe {
fma(
accij.as_slice_mut().unwrap(),
cij.as_slice().unwrap(),
pij.as_slice().unwrap(),
)
}
}
}
}
Ok(Ciphertext {
par: ct_first.par.clone(),
seed: None,
c,
level: ct_first.level,
})
} else {
let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree()));
for (ciphertext, plaintext) in izip!(ct, pt) {
let pt_coefficients = plaintext.poly_ntt.coefficients();
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) {
let ci_coefficients = ci.coefficients();
for (mut accij, cij, pij) in izip!(
acci.outer_iter_mut(),
ci_coefficients.outer_iter(),
pt_coefficients.outer_iter()
) {
unsafe {
fma(
accij.as_slice_mut().unwrap(),
cij.as_slice().unwrap(),
pij.as_slice().unwrap(),
)
}
}
}
}
// Reduce
let mut c = Vec::with_capacity(ct_first.c.len());
for acci in acc.outer_iter() {
let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree()));
for (mut outij, accij, q) in izip!(
coeffs.outer_iter_mut(),
acci.outer_iter(),
ctx.moduli_operators()
) {
for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) {
unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) }
}
}
c.push(Poly::try_convert_from(
coeffs,
ctx,
true,
Representation::Ntt,
)?)
}
// Reduce
let mut c = Vec::with_capacity(ct_first.c.len());
for acci in acc.outer_iter() {
let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree()));
for (mut outij, accij, q) in izip!(
coeffs.outer_iter_mut(),
acci.outer_iter(),
ctx.moduli_operators()
) {
for (outij_coeff, accij_coeff) in izip!(outij.iter_mut(), accij.iter()) {
unsafe { *outij_coeff = q.reduce_u128_vt(*accij_coeff) }
}
}
c.push(Poly::try_convert_from(
coeffs,
ctx,
true,
Representation::Ntt,
)?)
}
Ok(Ciphertext {
par: ct_first.par.clone(),
seed: None,
c,
level: ct_first.level,
})
}
Ok(Ciphertext {
par: ct_first.par.clone(),
seed: None,
c,
level: ct_first.level,
})
}
}
#[cfg(test)]
mod tests {
use super::dot_product_scalar;
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
use fhe_traits::{FheEncoder, FheEncrypter};
use itertools::{izip, Itertools};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use super::dot_product_scalar;
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
use fhe_traits::{FheEncoder, FheEncrypter};
use itertools::{izip, Itertools};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn test_dot_product_scalar() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(2, 16)),
] {
let sk = SecretKey::random(&params, &mut rng);
for size in 1..128 {
let ct = (0..size)
.map(|_| {
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap();
sk.try_encrypt(&pt, &mut rng).unwrap()
})
.collect_vec();
let pt = (0..size)
.map(|_| {
let v = params.plaintext.random_vec(params.degree(), &mut rng);
Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap()
})
.collect_vec();
#[test]
fn test_dot_product_scalar() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(1, 8)),
Arc::new(BfvParameters::default(2, 16)),
] {
let sk = SecretKey::random(&params, &mut rng);
for size in 1..128 {
let ct = (0..size)
.map(|_| {
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap();
sk.try_encrypt(&pt, &mut rng).unwrap()
})
.collect_vec();
let pt = (0..size)
.map(|_| {
let v = params.plaintext.random_vec(params.degree(), &mut rng);
Plaintext::try_encode(&v, Encoding::simd(), &params).unwrap()
})
.collect_vec();
let r = dot_product_scalar(ct.iter(), pt.iter())?;
let r = dot_product_scalar(ct.iter(), pt.iter())?;
let mut expected = Ciphertext::zero(&params);
izip!(&ct, &pt).for_each(|(cti, pti)| expected += &(cti * pti));
assert_eq!(r, expected);
}
}
Ok(())
}
let mut expected = Ciphertext::zero(&params);
izip!(&ct, &pt).for_each(|(cti, pti)| expected += &(cti * pti));
assert_eq!(r, expected);
}
}
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,16 @@
use std::sync::Arc;
use fhe_math::{
rns::ScalingFactor,
rq::{scaler::Scaler, Context, Representation},
zq::primes::generate_prime,
rns::ScalingFactor,
rq::{scaler::Scaler, Context, Representation},
zq::primes::generate_prime,
};
use fhe_util::div_ceil;
use num_bigint::BigUint;
use crate::{
bfv::{keys::RelinearizationKey, BfvParameters, Ciphertext},
Error, Result,
bfv::{keys::RelinearizationKey, BfvParameters, Ciphertext},
Error, Result,
};
/// Multiplicator that implements a strategy for multiplying. In particular, the
@@ -22,390 +22,390 @@ use crate::{
/// - Whether relinearization should be used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Multiplicator {
par: Arc<BfvParameters>,
pub(crate) extender_lhs: Scaler,
pub(crate) extender_rhs: Scaler,
pub(crate) down_scaler: Scaler,
pub(crate) base_ctx: Arc<Context>,
pub(crate) mul_ctx: Arc<Context>,
rk: Option<RelinearizationKey>,
mod_switch: bool,
level: usize,
par: Arc<BfvParameters>,
pub(crate) extender_lhs: Scaler,
pub(crate) extender_rhs: Scaler,
pub(crate) down_scaler: Scaler,
pub(crate) base_ctx: Arc<Context>,
pub(crate) mul_ctx: Arc<Context>,
rk: Option<RelinearizationKey>,
mod_switch: bool,
level: usize,
}
impl Multiplicator {
/// Construct a multiplicator using custom scaling factors and extended
/// basis.
pub fn new(
lhs_scaling_factor: ScalingFactor,
rhs_scaling_factor: ScalingFactor,
extended_basis: &[u64],
post_mul_scaling_factor: ScalingFactor,
par: &Arc<BfvParameters>,
) -> Result<Self> {
Self::new_leveled_internal(
lhs_scaling_factor,
rhs_scaling_factor,
extended_basis,
post_mul_scaling_factor,
0,
par,
)
}
/// Construct a multiplicator using custom scaling factors and extended
/// basis.
pub fn new(
lhs_scaling_factor: ScalingFactor,
rhs_scaling_factor: ScalingFactor,
extended_basis: &[u64],
post_mul_scaling_factor: ScalingFactor,
par: &Arc<BfvParameters>,
) -> Result<Self> {
Self::new_leveled_internal(
lhs_scaling_factor,
rhs_scaling_factor,
extended_basis,
post_mul_scaling_factor,
0,
par,
)
}
/// Construct a multiplicator using custom scaling factors and extended
/// basis at a given level.
pub fn new_leveled(
lhs_scaling_factor: ScalingFactor,
rhs_scaling_factor: ScalingFactor,
extended_basis: &[u64],
post_mul_scaling_factor: ScalingFactor,
level: usize,
par: &Arc<BfvParameters>,
) -> Result<Self> {
Self::new_leveled_internal(
lhs_scaling_factor,
rhs_scaling_factor,
extended_basis,
post_mul_scaling_factor,
level,
par,
)
}
/// Construct a multiplicator using custom scaling factors and extended
/// basis at a given level.
pub fn new_leveled(
lhs_scaling_factor: ScalingFactor,
rhs_scaling_factor: ScalingFactor,
extended_basis: &[u64],
post_mul_scaling_factor: ScalingFactor,
level: usize,
par: &Arc<BfvParameters>,
) -> Result<Self> {
Self::new_leveled_internal(
lhs_scaling_factor,
rhs_scaling_factor,
extended_basis,
post_mul_scaling_factor,
level,
par,
)
}
fn new_leveled_internal(
lhs_scaling_factor: ScalingFactor,
rhs_scaling_factor: ScalingFactor,
extended_basis: &[u64],
post_mul_scaling_factor: ScalingFactor,
level: usize,
par: &Arc<BfvParameters>,
) -> Result<Self> {
let base_ctx = par.ctx_at_level(level)?;
let mul_ctx = Arc::new(Context::new(extended_basis, par.degree())?);
let extender_lhs = Scaler::new(base_ctx, &mul_ctx, lhs_scaling_factor)?;
let extender_rhs = Scaler::new(base_ctx, &mul_ctx, rhs_scaling_factor)?;
let down_scaler = Scaler::new(&mul_ctx, base_ctx, post_mul_scaling_factor)?;
Ok(Self {
par: par.clone(),
extender_lhs,
extender_rhs,
down_scaler,
base_ctx: base_ctx.clone(),
mul_ctx,
rk: None,
mod_switch: false,
level,
})
}
fn new_leveled_internal(
lhs_scaling_factor: ScalingFactor,
rhs_scaling_factor: ScalingFactor,
extended_basis: &[u64],
post_mul_scaling_factor: ScalingFactor,
level: usize,
par: &Arc<BfvParameters>,
) -> Result<Self> {
let base_ctx = par.ctx_at_level(level)?;
let mul_ctx = Arc::new(Context::new(extended_basis, par.degree())?);
let extender_lhs = Scaler::new(base_ctx, &mul_ctx, lhs_scaling_factor)?;
let extender_rhs = Scaler::new(base_ctx, &mul_ctx, rhs_scaling_factor)?;
let down_scaler = Scaler::new(&mul_ctx, base_ctx, post_mul_scaling_factor)?;
Ok(Self {
par: par.clone(),
extender_lhs,
extender_rhs,
down_scaler,
base_ctx: base_ctx.clone(),
mul_ctx,
rk: None,
mod_switch: false,
level,
})
}
/// Default multiplication strategy using relinearization.
pub fn default(rk: &RelinearizationKey) -> Result<Self> {
let ctx = rk.ksk.par.ctx_at_level(rk.ksk.ciphertext_level)?;
/// Default multiplication strategy using relinearization.
pub fn default(rk: &RelinearizationKey) -> Result<Self> {
let ctx = rk.ksk.par.ctx_at_level(rk.ksk.ciphertext_level)?;
let modulus_size = rk.ksk.par.moduli_sizes()[..ctx.moduli().len()]
.iter()
.sum::<usize>();
let n_moduli = div_ceil(modulus_size + 60, 62);
let modulus_size = rk.ksk.par.moduli_sizes()[..ctx.moduli().len()]
.iter()
.sum::<usize>();
let n_moduli = div_ceil(modulus_size + 60, 62);
let mut extended_basis = Vec::with_capacity(ctx.moduli().len() + n_moduli);
extended_basis.append(&mut ctx.moduli().to_vec());
let mut upper_bound = 1 << 62;
while extended_basis.len() != ctx.moduli().len() + n_moduli {
upper_bound = generate_prime(62, 2 * rk.ksk.par.degree() as u64, upper_bound).unwrap();
if !extended_basis.contains(&upper_bound) && !ctx.moduli().contains(&upper_bound) {
extended_basis.push(upper_bound)
}
}
let mut extended_basis = Vec::with_capacity(ctx.moduli().len() + n_moduli);
extended_basis.append(&mut ctx.moduli().to_vec());
let mut upper_bound = 1 << 62;
while extended_basis.len() != ctx.moduli().len() + n_moduli {
upper_bound = generate_prime(62, 2 * rk.ksk.par.degree() as u64, upper_bound).unwrap();
if !extended_basis.contains(&upper_bound) && !ctx.moduli().contains(&upper_bound) {
extended_basis.push(upper_bound)
}
}
let mut multiplicator = Self::new_leveled_internal(
ScalingFactor::one(),
ScalingFactor::one(),
&extended_basis,
ScalingFactor::new(
&BigUint::from(rk.ksk.par.plaintext.modulus()),
ctx.modulus(),
),
rk.ksk.ciphertext_level,
&rk.ksk.par,
)?;
let mut multiplicator = Self::new_leveled_internal(
ScalingFactor::one(),
ScalingFactor::one(),
&extended_basis,
ScalingFactor::new(
&BigUint::from(rk.ksk.par.plaintext.modulus()),
ctx.modulus(),
),
rk.ksk.ciphertext_level,
&rk.ksk.par,
)?;
multiplicator.enable_relinearization(rk)?;
Ok(multiplicator)
}
multiplicator.enable_relinearization(rk)?;
Ok(multiplicator)
}
/// Enable relinearization after multiplication.
pub fn enable_relinearization(&mut self, rk: &RelinearizationKey) -> Result<()> {
let rk_ctx = self.par.ctx_at_level(rk.ksk.ciphertext_level)?;
if rk_ctx != &self.base_ctx {
return Err(Error::DefaultError(
"Invalid relinearization key context".to_string(),
));
}
self.rk = Some(rk.clone());
Ok(())
}
/// Enable relinearization after multiplication.
pub fn enable_relinearization(&mut self, rk: &RelinearizationKey) -> Result<()> {
let rk_ctx = self.par.ctx_at_level(rk.ksk.ciphertext_level)?;
if rk_ctx != &self.base_ctx {
return Err(Error::DefaultError(
"Invalid relinearization key context".to_string(),
));
}
self.rk = Some(rk.clone());
Ok(())
}
/// Enable modulus switching after multiplication (and relinearization, if
/// applicable).
pub fn enable_mod_switching(&mut self) -> Result<()> {
if self.par.ctx_at_level(self.par.max_level())? == &self.base_ctx {
Err(Error::DefaultError(
"Cannot modulo switch as this is already the last level".to_string(),
))
} else {
self.mod_switch = true;
Ok(())
}
}
/// Enable modulus switching after multiplication (and relinearization, if
/// applicable).
pub fn enable_mod_switching(&mut self) -> Result<()> {
if self.par.ctx_at_level(self.par.max_level())? == &self.base_ctx {
Err(Error::DefaultError(
"Cannot modulo switch as this is already the last level".to_string(),
))
} else {
self.mod_switch = true;
Ok(())
}
}
/// Multiply two ciphertexts using the defined multiplication strategy.
pub fn multiply(&self, lhs: &Ciphertext, rhs: &Ciphertext) -> Result<Ciphertext> {
if lhs.par != self.par || rhs.par != self.par {
return Err(Error::DefaultError(
"Ciphertexts do not have the same parameters".to_string(),
));
}
if lhs.level != self.level || rhs.level != self.level {
return Err(Error::DefaultError(
"Ciphertexts are not at expected level".to_string(),
));
}
if lhs.c.len() != 2 || rhs.c.len() != 2 {
return Err(Error::DefaultError(
"Multiplication can only be performed on ciphertexts of size 2".to_string(),
));
}
/// Multiply two ciphertexts using the defined multiplication strategy.
pub fn multiply(&self, lhs: &Ciphertext, rhs: &Ciphertext) -> Result<Ciphertext> {
if lhs.par != self.par || rhs.par != self.par {
return Err(Error::DefaultError(
"Ciphertexts do not have the same parameters".to_string(),
));
}
if lhs.level != self.level || rhs.level != self.level {
return Err(Error::DefaultError(
"Ciphertexts are not at expected level".to_string(),
));
}
if lhs.c.len() != 2 || rhs.c.len() != 2 {
return Err(Error::DefaultError(
"Multiplication can only be performed on ciphertexts of size 2".to_string(),
));
}
// Extend
let c00 = lhs.c[0].scale(&self.extender_lhs)?;
let c01 = lhs.c[1].scale(&self.extender_lhs)?;
let c10 = rhs.c[0].scale(&self.extender_rhs)?;
let c11 = rhs.c[1].scale(&self.extender_rhs)?;
// Extend
let c00 = lhs.c[0].scale(&self.extender_lhs)?;
let c01 = lhs.c[1].scale(&self.extender_lhs)?;
let c10 = rhs.c[0].scale(&self.extender_rhs)?;
let c11 = rhs.c[1].scale(&self.extender_rhs)?;
// Multiply
let mut c0 = &c00 * &c10;
let mut c1 = &c00 * &c11;
c1 += &(&c01 * &c10);
let mut c2 = &c01 * &c11;
c0.change_representation(Representation::PowerBasis);
c1.change_representation(Representation::PowerBasis);
c2.change_representation(Representation::PowerBasis);
// Multiply
let mut c0 = &c00 * &c10;
let mut c1 = &c00 * &c11;
c1 += &(&c01 * &c10);
let mut c2 = &c01 * &c11;
c0.change_representation(Representation::PowerBasis);
c1.change_representation(Representation::PowerBasis);
c2.change_representation(Representation::PowerBasis);
// Scale
let c0 = c0.scale(&self.down_scaler)?;
let c1 = c1.scale(&self.down_scaler)?;
let c2 = c2.scale(&self.down_scaler)?;
// Scale
let c0 = c0.scale(&self.down_scaler)?;
let c1 = c1.scale(&self.down_scaler)?;
let c2 = c2.scale(&self.down_scaler)?;
let mut c = vec![c0, c1, c2];
let mut c = vec![c0, c1, c2];
// Relinearize
if let Some(rk) = self.rk.as_ref() {
#[allow(unused_mut)]
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?;
// Relinearize
if let Some(rk) = self.rk.as_ref() {
#[allow(unused_mut)]
let (mut c0r, mut c1r) = rk.relinearizes_poly(&c[2])?;
if c0r.ctx() != c[0].ctx() {
c0r.change_representation(Representation::PowerBasis);
c1r.change_representation(Representation::PowerBasis);
c0r.mod_switch_down_to(c[0].ctx())?;
c1r.mod_switch_down_to(c[1].ctx())?;
} else {
c[0].change_representation(Representation::Ntt);
c[1].change_representation(Representation::Ntt);
}
if c0r.ctx() != c[0].ctx() {
c0r.change_representation(Representation::PowerBasis);
c1r.change_representation(Representation::PowerBasis);
c0r.mod_switch_down_to(c[0].ctx())?;
c1r.mod_switch_down_to(c[1].ctx())?;
} else {
c[0].change_representation(Representation::Ntt);
c[1].change_representation(Representation::Ntt);
}
c[0] += &c0r;
c[1] += &c1r;
c.truncate(2);
}
c[0] += &c0r;
c[1] += &c1r;
c.truncate(2);
}
// We construct a ciphertext, but it may not have the right representation for
// the polynomials yet.
let mut c = Ciphertext {
par: self.par.clone(),
seed: None,
c,
level: self.level,
};
// We construct a ciphertext, but it may not have the right representation for
// the polynomials yet.
let mut c = Ciphertext {
par: self.par.clone(),
seed: None,
c,
level: self.level,
};
if self.mod_switch {
c.mod_switch_to_next_level();
} else {
c.c.iter_mut()
.for_each(|p| p.change_representation(Representation::Ntt));
}
if self.mod_switch {
c.mod_switch_to_next_level();
} else {
c.c.iter_mut()
.for_each(|p| p.change_representation(Representation::Ntt));
}
Ok(c)
}
Ok(c)
}
}
#[cfg(test)]
mod tests {
use crate::bfv::{
BfvParameters, Ciphertext, Encoding, Plaintext, RelinearizationKey, SecretKey,
};
use fhe_math::{
rns::{RnsContext, ScalingFactor},
zq::primes::generate_prime,
};
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
use num_bigint::BigUint;
use rand::{rngs::OsRng, thread_rng};
use std::{error::Error, sync::Arc};
use crate::bfv::{
BfvParameters, Ciphertext, Encoding, Plaintext, RelinearizationKey, SecretKey,
};
use fhe_math::{
rns::{RnsContext, ScalingFactor},
zq::primes::generate_prime,
};
use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter};
use num_bigint::BigUint;
use rand::{rngs::OsRng, thread_rng};
use std::{error::Error, sync::Arc};
use super::Multiplicator;
use super::Multiplicator;
#[test]
fn mul() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(3, 8));
for _ in 0..30 {
// We will encode `values` in an Simd format, and check that the product is
// computed correctly.
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
#[test]
fn mul() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(3, 8));
for _ in 0..30 {
// We will encode `values` in an Simd format, and check that the product is
// computed correctly.
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
let sk = SecretKey::random(&par, &mut OsRng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
let sk = SecretKey::random(&par, &mut OsRng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
let mut multiplicator = Multiplicator::default(&rk)?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
let mut multiplicator = Multiplicator::default(&rk)?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
Ok(())
}
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
Ok(())
}
#[test]
fn mul_at_level() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(3, 8));
for _ in 0..15 {
for level in 0..2 {
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
#[test]
fn mul_at_level() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(3, 8));
for _ in 0..15 {
for level in 0..2 {
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
let sk = SecretKey::random(&par, &mut OsRng);
let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?;
let pt = Plaintext::try_encode(&values, Encoding::simd_at_level(level), &par)?;
let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
let ct2: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
assert_eq!(ct1.level, level);
assert_eq!(ct2.level, level);
let sk = SecretKey::random(&par, &mut OsRng);
let rk = RelinearizationKey::new_leveled(&sk, level, level, &mut rng)?;
let pt = Plaintext::try_encode(&values, Encoding::simd_at_level(level), &par)?;
let ct1: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
let ct2: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;
assert_eq!(ct1.level, level);
assert_eq!(ct2.level, level);
let mut multiplicator = Multiplicator::default(&rk).unwrap();
let ct3 = multiplicator.multiply(&ct1, &ct2).unwrap();
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
let mut multiplicator = Multiplicator::default(&rk).unwrap();
let ct3 = multiplicator.multiply(&ct1, &ct2).unwrap();
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, level + 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
}
Ok(())
}
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, level + 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
}
Ok(())
}
#[test]
fn mul_no_relin() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(6, 8));
for _ in 0..30 {
// We will encode `values` in an Simd format, and check that the product is
// computed correctly.
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
#[test]
fn mul_no_relin() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(6, 8));
for _ in 0..30 {
// We will encode `values` in an Simd format, and check that the product is
// computed correctly.
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
let sk = SecretKey::random(&par, &mut OsRng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
let sk = SecretKey::random(&par, &mut OsRng);
let rk = RelinearizationKey::new(&sk, &mut rng)?;
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
let mut multiplicator = Multiplicator::default(&rk)?;
// Remove the relinearization key.
multiplicator.rk = None;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
let mut multiplicator = Multiplicator::default(&rk)?;
// Remove the relinearization key.
multiplicator.rk = None;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
Ok(())
}
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
Ok(())
}
#[test]
fn different_mul_strategy() -> Result<(), Box<dyn Error>> {
// Implement the second multiplication strategy from <https://eprint.iacr.org/2021/204>
#[test]
fn different_mul_strategy() -> Result<(), Box<dyn Error>> {
// Implement the second multiplication strategy from <https://eprint.iacr.org/2021/204>
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(3, 8));
let mut extended_basis = par.moduli().to_vec();
extended_basis
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap());
extended_basis
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[3]).unwrap());
extended_basis
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[4]).unwrap());
let rns = RnsContext::new(&extended_basis[3..])?;
let mut rng = thread_rng();
let par = Arc::new(BfvParameters::default(3, 8));
let mut extended_basis = par.moduli().to_vec();
extended_basis
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[2]).unwrap());
extended_basis
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[3]).unwrap());
extended_basis
.push(generate_prime(62, 2 * par.degree() as u64, extended_basis[4]).unwrap());
let rns = RnsContext::new(&extended_basis[3..])?;
for _ in 0..30 {
// We will encode `values` in an Simd format, and check that the product is
// computed correctly.
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
for _ in 0..30 {
// We will encode `values` in an Simd format, and check that the product is
// computed correctly.
let values = par.plaintext.random_vec(par.degree(), &mut rng);
let mut expected = values.clone();
par.plaintext.mul_vec(&mut expected, &values);
let sk = SecretKey::random(&par, &mut OsRng);
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
let sk = SecretKey::random(&par, &mut OsRng);
let pt = Plaintext::try_encode(&values, Encoding::simd(), &par)?;
let ct1 = sk.try_encrypt(&pt, &mut rng)?;
let ct2 = sk.try_encrypt(&pt, &mut rng)?;
let mut multiplicator = Multiplicator::new(
ScalingFactor::one(),
ScalingFactor::new(rns.modulus(), par.ctx[0].modulus()),
&extended_basis,
ScalingFactor::new(&BigUint::from(par.plaintext()), rns.modulus()),
&par,
)?;
let mut multiplicator = Multiplicator::new(
ScalingFactor::one(),
ScalingFactor::new(rns.modulus(), par.ctx[0].modulus()),
&extended_basis,
ScalingFactor::new(&BigUint::from(par.plaintext()), rns.modulus()),
&par,
)?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
multiplicator.enable_mod_switching()?;
let ct3 = multiplicator.multiply(&ct1, &ct2)?;
assert_eq!(ct3.level, 1);
println!("Noise: {}", unsafe { sk.measure_noise(&ct3)? });
let pt = sk.try_decrypt(&ct3)?;
assert_eq!(Vec::<u64>::try_decode(&pt, Encoding::simd())?, expected);
}
Ok(())
}
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
//! Plaintext type in the BFV encryption scheme.
use crate::{
bfv::{BfvParameters, Encoding, PlaintextVec},
Error, Result,
bfv::{BfvParameters, Encoding, PlaintextVec},
Error, Result,
};
use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation};
use fhe_traits::{FheDecoder, FheEncoder, FheParametrized, FhePlaintext};
@@ -13,69 +13,69 @@ use super::encoding::EncodingEnum;
/// A plaintext object, that encodes a vector according to a specific encoding.
#[derive(Debug, Clone, Eq)]
pub struct Plaintext {
/// The parameters of the underlying BFV encryption scheme.
pub(crate) par: Arc<BfvParameters>,
/// The value after encoding.
pub(crate) value: Box<[u64]>,
/// The encoding of the plaintext, if known
pub(crate) encoding: Option<Encoding>,
/// The plaintext as a polynomial.
pub(crate) poly_ntt: Poly,
/// The level of the plaintext
pub(crate) level: usize,
/// The parameters of the underlying BFV encryption scheme.
pub(crate) par: Arc<BfvParameters>,
/// The value after encoding.
pub(crate) value: Box<[u64]>,
/// The encoding of the plaintext, if known
pub(crate) encoding: Option<Encoding>,
/// The plaintext as a polynomial.
pub(crate) poly_ntt: Poly,
/// The level of the plaintext
pub(crate) level: usize,
}
impl FheParametrized for Plaintext {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl FhePlaintext for Plaintext {
type Encoding = Encoding;
type Encoding = Encoding;
}
// Zeroizing of plaintexts.
impl ZeroizeOnDrop for Plaintext {}
impl Zeroize for Plaintext {
fn zeroize(&mut self) {
self.value.zeroize();
self.poly_ntt.zeroize();
}
fn zeroize(&mut self) {
self.value.zeroize();
self.poly_ntt.zeroize();
}
}
impl Plaintext {
pub(crate) fn to_poly(&self) -> Poly {
let mut m_v = Zeroizing::new(self.value.clone());
self.par
.plaintext
.scalar_mul_vec(&mut m_v, self.par.q_mod_t[self.level]);
let ctx = self.par.ctx_at_level(self.level).unwrap();
let mut m =
Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap();
m.change_representation(Representation::Ntt);
m *= &self.par.delta[self.level];
m
}
pub(crate) fn to_poly(&self) -> Poly {
let mut m_v = Zeroizing::new(self.value.clone());
self.par
.plaintext
.scalar_mul_vec(&mut m_v, self.par.q_mod_t[self.level]);
let ctx = self.par.ctx_at_level(self.level).unwrap();
let mut m =
Poly::try_convert_from(m_v.as_ref(), ctx, false, Representation::PowerBasis).unwrap();
m.change_representation(Representation::Ntt);
m *= &self.par.delta[self.level];
m
}
/// Generate a zero plaintext.
pub fn zero(encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
let level = encoding.level;
let ctx = par.ctx_at_level(level)?;
let value = vec![0u64; par.degree()];
let poly_ntt = Poly::zero(ctx, Representation::Ntt);
Ok(Self {
par: par.clone(),
value: value.into_boxed_slice(),
encoding: Some(encoding),
poly_ntt,
level,
})
}
/// Generate a zero plaintext.
pub fn zero(encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
let level = encoding.level;
let ctx = par.ctx_at_level(level)?;
let value = vec![0u64; par.degree()];
let poly_ntt = Poly::zero(ctx, Representation::Ntt);
Ok(Self {
par: par.clone(),
value: value.into_boxed_slice(),
encoding: Some(encoding),
poly_ntt,
level,
})
}
/// Returns the level of this plaintext.
pub fn level(&self) -> usize {
self.par.level_of_ctx(self.poly_ntt.ctx()).unwrap()
}
/// Returns the level of this plaintext.
pub fn level(&self) -> usize {
self.par.level_of_ctx(self.poly_ntt.ctx()).unwrap()
}
}
unsafe impl Send for Plaintext {}
@@ -83,320 +83,320 @@ unsafe impl Send for Plaintext {}
// Implement the equality manually; we want to say that two plaintexts are equal
// even if one of them doesn't store its encoding information.
impl PartialEq for Plaintext {
fn eq(&self, other: &Self) -> bool {
let mut eq = self.par == other.par;
eq &= self.value == other.value;
if self.encoding.is_some() && other.encoding.is_some() {
eq &= self.encoding.as_ref().unwrap() == other.encoding.as_ref().unwrap()
}
eq
}
fn eq(&self, other: &Self) -> bool {
let mut eq = self.par == other.par;
eq &= self.value == other.value;
if self.encoding.is_some() && other.encoding.is_some() {
eq &= self.encoding.as_ref().unwrap() == other.encoding.as_ref().unwrap()
}
eq
}
}
// Conversions.
impl TryConvertFrom<&Plaintext> for Poly {
fn try_convert_from<R>(
pt: &Plaintext,
ctx: &Arc<Context>,
variable_time: bool,
_: R,
) -> fhe_math::Result<Self>
where
R: Into<Option<Representation>>,
{
if ctx
!= pt
.par
.ctx_at_level(pt.level())
.map_err(|e| fhe_math::Error::Default(e.to_string()))?
{
Err(fhe_math::Error::Default(
"Incompatible contexts".to_string(),
))
} else {
Poly::try_convert_from(
pt.value.as_ref(),
ctx,
variable_time,
Representation::PowerBasis,
)
}
}
fn try_convert_from<R>(
pt: &Plaintext,
ctx: &Arc<Context>,
variable_time: bool,
_: R,
) -> fhe_math::Result<Self>
where
R: Into<Option<Representation>>,
{
if ctx
!= pt
.par
.ctx_at_level(pt.level())
.map_err(|e| fhe_math::Error::Default(e.to_string()))?
{
Err(fhe_math::Error::Default(
"Incompatible contexts".to_string(),
))
} else {
Poly::try_convert_from(
pt.value.as_ref(),
ctx,
variable_time,
Representation::PowerBasis,
)
}
}
}
// Encoding and decoding.
impl<'a, const N: usize, T> FheEncoder<&'a [T; N]> for Plaintext
where
Plaintext: FheEncoder<&'a [T], Error = Error>,
Plaintext: FheEncoder<&'a [T], Error = Error>,
{
type Error = Error;
fn try_encode(value: &'a [T; N], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
Plaintext::try_encode(value.as_ref(), encoding, par)
}
type Error = Error;
fn try_encode(value: &'a [T; N], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
Plaintext::try_encode(value.as_ref(), encoding, par)
}
}
impl<'a, T> FheEncoder<&'a Vec<T>> for Plaintext
where
Plaintext: FheEncoder<&'a [T], Error = Error>,
Plaintext: FheEncoder<&'a [T], Error = Error>,
{
type Error = Error;
fn try_encode(value: &'a Vec<T>, encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
Plaintext::try_encode(value.as_ref(), encoding, par)
}
type Error = Error;
fn try_encode(value: &'a Vec<T>, encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
Plaintext::try_encode(value.as_ref(), encoding, par)
}
}
impl<'a> FheEncoder<&'a [u64]> for Plaintext {
type Error = Error;
fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
if value.len() > par.degree() {
return Err(Error::TooManyValues(value.len(), par.degree()));
}
let v = PlaintextVec::try_encode(value, encoding, par)?;
Ok(v.0[0].clone())
}
type Error = Error;
fn try_encode(value: &'a [u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
if value.len() > par.degree() {
return Err(Error::TooManyValues(value.len(), par.degree()));
}
let v = PlaintextVec::try_encode(value, encoding, par)?;
Ok(v.0[0].clone())
}
}
impl<'a> FheEncoder<&'a [i64]> for Plaintext {
type Error = Error;
fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value));
Plaintext::try_encode(w.as_ref() as &[u64], encoding, par)
}
type Error = Error;
fn try_encode(value: &'a [i64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
let w = Zeroizing::new(par.plaintext.reduce_vec_i64(value));
Plaintext::try_encode(w.as_ref() as &[u64], encoding, par)
}
}
impl FheDecoder<Plaintext> for Vec<u64> {
fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>>
where
O: Into<Option<Encoding>>,
{
let encoding = encoding.into();
let enc: Encoding;
if pt.encoding.is_none() && encoding.is_none() {
return Err(Error::UnspecifiedInput("No encoding specified".to_string()));
} else if pt.encoding.is_some() {
enc = pt.encoding.as_ref().unwrap().clone();
if let Some(arg_enc) = encoding {
if arg_enc != enc {
return Err(Error::EncodingMismatch(arg_enc.into(), enc.into()));
}
}
} else {
enc = encoding.unwrap();
if let Some(pt_enc) = pt.encoding.as_ref() {
if pt_enc != &enc {
return Err(Error::EncodingMismatch(pt_enc.into(), enc.into()));
}
}
}
fn try_decode<O>(pt: &Plaintext, encoding: O) -> Result<Vec<u64>>
where
O: Into<Option<Encoding>>,
{
let encoding = encoding.into();
let enc: Encoding;
if pt.encoding.is_none() && encoding.is_none() {
return Err(Error::UnspecifiedInput("No encoding specified".to_string()));
} else if pt.encoding.is_some() {
enc = pt.encoding.as_ref().unwrap().clone();
if let Some(arg_enc) = encoding {
if arg_enc != enc {
return Err(Error::EncodingMismatch(arg_enc.into(), enc.into()));
}
}
} else {
enc = encoding.unwrap();
if let Some(pt_enc) = pt.encoding.as_ref() {
if pt_enc != &enc {
return Err(Error::EncodingMismatch(pt_enc.into(), enc.into()));
}
}
}
let mut w = pt.value.to_vec();
let mut w = pt.value.to_vec();
match enc.encoding {
EncodingEnum::Poly => Ok(w),
EncodingEnum::Simd => {
if let Some(op) = &pt.par.op {
op.forward(&mut w);
let mut w_reordered = w.clone();
for i in 0..pt.par.degree() {
w_reordered[i] = w[pt.par.matrix_reps_index_map[i]]
}
w.zeroize();
Ok(w_reordered)
} else {
Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()))
}
}
}
}
match enc.encoding {
EncodingEnum::Poly => Ok(w),
EncodingEnum::Simd => {
if let Some(op) = &pt.par.op {
op.forward(&mut w);
let mut w_reordered = w.clone();
for i in 0..pt.par.degree() {
w_reordered[i] = w[pt.par.matrix_reps_index_map[i]]
}
w.zeroize();
Ok(w_reordered)
} else {
Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()))
}
}
}
}
type Error = Error;
type Error = Error;
}
impl FheDecoder<Plaintext> for Vec<i64> {
fn try_decode<E>(pt: &Plaintext, encoding: E) -> Result<Vec<i64>>
where
E: Into<Option<Encoding>>,
{
let v = Vec::<u64>::try_decode(pt, encoding)?;
Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) })
}
fn try_decode<E>(pt: &Plaintext, encoding: E) -> Result<Vec<i64>>
where
E: Into<Option<Encoding>>,
{
let v = Vec::<u64>::try_decode(pt, encoding)?;
Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) })
}
type Error = Error;
type Error = Error;
}
#[cfg(test)]
mod tests {
use super::{Encoding, Plaintext};
use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder};
use fhe_math::rq::{Poly, Representation};
use fhe_traits::{FheDecoder, FheEncoder};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use zeroize::Zeroize;
use super::{Encoding, Plaintext};
use crate::bfv::parameters::{BfvParameters, BfvParametersBuilder};
use fhe_math::rq::{Poly, Representation};
use fhe_traits::{FheDecoder, FheEncoder};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use zeroize::Zeroize;
#[test]
fn try_encode() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
// The default test parameters support both Poly and Simd encodings
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
#[test]
fn try_encode() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
// The default test parameters support both Poly and Simd encodings
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), &params);
assert!(plaintext.is_err());
let plaintext = Plaintext::try_encode(&[0u64; 9], Encoding::poly(), &params);
assert!(plaintext.is_err());
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&[1u64], Encoding::poly(), &params);
assert!(plaintext.is_ok());
// The following parameters do not allow for Simd encoding
let params = Arc::new(
BfvParametersBuilder::new()
.set_degree(8)
.set_plaintext_modulus(2)
.set_moduli(&[4611686018326724609])
.build()?,
);
// The following parameters do not allow for Simd encoding
let params = Arc::new(
BfvParametersBuilder::new()
.set_degree(8)
.set_plaintext_modulus(2)
.set_moduli(&[4611686018326724609])
.build()?,
);
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_err());
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_err());
Ok(())
}
Ok(())
}
#[test]
fn encode_decode() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
#[test]
fn encode_decode() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
assert_eq!(b, a);
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
assert_eq!(b, a);
let a = unsafe { params.plaintext.center_vec_vt(&a) };
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?;
assert_eq!(b, a);
let a = unsafe { params.plaintext.center_vec_vt(&a) };
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?;
assert_eq!(b, a);
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
assert_eq!(b, a);
let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
assert_eq!(b, a);
Ok(())
}
Ok(())
}
#[test]
fn partial_eq() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
#[test]
fn partial_eq() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
assert_eq!(plaintext, same_plaintext);
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
let mut same_plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
assert_eq!(plaintext, same_plaintext);
// Equality also holds when there is no encoding specified. In this test, we use
// the fact that we can set it to None directly, but such a partial plaintext
// will be created during decryption since we do not specify the encoding at the
// time.
same_plaintext.encoding = None;
assert_eq!(plaintext, same_plaintext);
// Equality also holds when there is no encoding specified. In this test, we use
// the fact that we can set it to None directly, but such a partial plaintext
// will be created during decryption since we do not specify the encoding at the
// time.
same_plaintext.encoding = None;
assert_eq!(plaintext, same_plaintext);
Ok(())
}
Ok(())
}
#[test]
fn try_decode_errors() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
#[test]
fn try_decode_errors() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok());
let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd());
assert!(e.is_err());
assert_eq!(
e.unwrap_err(),
crate::Error::EncodingMismatch(Encoding::simd().into(), Encoding::poly().into())
);
let e = Vec::<u64>::try_decode(&plaintext, Encoding::poly_at_level(1));
assert!(e.is_err());
assert_eq!(
e.unwrap_err(),
crate::Error::EncodingMismatch(
Encoding::poly_at_level(1).into(),
Encoding::poly().into()
)
);
assert!(Vec::<u64>::try_decode(&plaintext, None).is_ok());
let e = Vec::<u64>::try_decode(&plaintext, Encoding::simd());
assert!(e.is_err());
assert_eq!(
e.unwrap_err(),
crate::Error::EncodingMismatch(Encoding::simd().into(), Encoding::poly().into())
);
let e = Vec::<u64>::try_decode(&plaintext, Encoding::poly_at_level(1));
assert!(e.is_err());
assert_eq!(
e.unwrap_err(),
crate::Error::EncodingMismatch(
Encoding::poly_at_level(1).into(),
Encoding::poly().into()
)
);
plaintext.encoding = None;
let e = Vec::<u64>::try_decode(&plaintext, None);
assert!(e.is_err());
assert_eq!(
e.unwrap_err(),
crate::Error::UnspecifiedInput("No encoding specified".to_string())
);
plaintext.encoding = None;
let e = Vec::<u64>::try_decode(&plaintext, None);
assert!(e.is_err());
assert_eq!(
e.unwrap_err(),
crate::Error::UnspecifiedInput("No encoding specified".to_string())
);
Ok(())
}
Ok(())
}
#[test]
fn zero() -> Result<(), Box<dyn Error>> {
let params = Arc::new(BfvParameters::default(1, 8));
let plaintext = Plaintext::zero(Encoding::poly(), &params)?;
#[test]
fn zero() -> Result<(), Box<dyn Error>> {
let params = Arc::new(BfvParameters::default(1, 8));
let plaintext = Plaintext::zero(Encoding::poly(), &params)?;
assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8]));
assert_eq!(
plaintext.poly_ntt,
Poly::zero(&params.ctx[0], Representation::Ntt)
);
assert_eq!(plaintext.value, Box::<[u64]>::from([0u64; 8]));
assert_eq!(
plaintext.poly_ntt,
Poly::zero(&params.ctx[0], Representation::Ntt)
);
Ok(())
}
Ok(())
}
#[test]
fn zeroize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
#[test]
fn zeroize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
let mut plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params)?;
plaintext.zeroize();
plaintext.zeroize();
assert_eq!(plaintext, Plaintext::zero(Encoding::poly(), &params)?);
assert_eq!(plaintext, Plaintext::zero(Encoding::poly(), &params)?);
Ok(())
}
Ok(())
}
#[test]
fn try_encode_level() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
// The default test parameters support both Poly and Simd encodings
let params = Arc::new(BfvParameters::default(10, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
#[test]
fn try_encode_level() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
// The default test parameters support both Poly and Simd encodings
let params = Arc::new(BfvParameters::default(10, 8));
let a = params.plaintext.random_vec(params.degree(), &mut rng);
for level in 0..10 {
let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), &params)?;
assert_eq!(plaintext.level(), level);
let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), &params)?;
assert_eq!(plaintext.level(), level);
}
for level in 0..10 {
let plaintext = Plaintext::try_encode(&a, Encoding::poly_at_level(level), &params)?;
assert_eq!(plaintext.level(), level);
let plaintext = Plaintext::try_encode(&a, Encoding::simd_at_level(level), &params)?;
assert_eq!(plaintext.level(), level);
}
Ok(())
}
Ok(())
}
}

View File

@@ -6,8 +6,8 @@ use fhe_util::div_ceil;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::{
bfv::{BfvParameters, Encoding, Plaintext},
Error, Result,
bfv::{BfvParameters, Encoding, Plaintext},
Error, Result,
};
use super::encoding::EncodingEnum;
@@ -17,148 +17,148 @@ use super::encoding::EncodingEnum;
pub struct PlaintextVec(pub Vec<Plaintext>);
impl FhePlaintext for PlaintextVec {
type Encoding = Encoding;
type Encoding = Encoding;
}
impl FheParametrized for PlaintextVec {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl Zeroize for PlaintextVec {
fn zeroize(&mut self) {
self.0.zeroize()
}
fn zeroize(&mut self) {
self.0.zeroize()
}
}
impl ZeroizeOnDrop for PlaintextVec {}
impl FheEncoderVariableTime<&[u64]> for PlaintextVec {
type Error = Error;
type Error = Error;
unsafe fn try_encode_vt(
value: &[u64],
encoding: Encoding,
par: &Arc<BfvParameters>,
) -> Result<Self> {
if value.is_empty() {
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
}
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
}
let ctx = par.ctx_at_level(encoding.level)?;
let num_plaintexts = div_ceil(value.len(), par.degree());
unsafe fn try_encode_vt(
value: &[u64],
encoding: Encoding,
par: &Arc<BfvParameters>,
) -> Result<Self> {
if value.is_empty() {
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
}
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
}
let ctx = par.ctx_at_level(encoding.level)?;
let num_plaintexts = div_ceil(value.len(), par.degree());
Ok(PlaintextVec(
(0..num_plaintexts)
.map(|i| {
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
let mut v = vec![0u64; par.degree()];
match encoding.encoding {
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
EncodingEnum::Simd => {
for i in 0..slice.len() {
v[par.matrix_reps_index_map[i]] = slice[i];
}
par.op.as_ref().unwrap().backward_vt(v.as_mut_ptr());
}
};
Ok(PlaintextVec(
(0..num_plaintexts)
.map(|i| {
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
let mut v = vec![0u64; par.degree()];
match encoding.encoding {
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
EncodingEnum::Simd => {
for i in 0..slice.len() {
v[par.matrix_reps_index_map[i]] = slice[i];
}
par.op.as_ref().unwrap().backward_vt(v.as_mut_ptr());
}
};
let mut poly =
Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
let mut poly =
Poly::try_convert_from(&v, ctx, true, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
Ok(Plaintext {
par: par.clone(),
value: v.into_boxed_slice(),
encoding: Some(encoding.clone()),
poly_ntt: poly,
level: encoding.level,
})
})
.collect::<Result<Vec<Plaintext>>>()?,
))
}
Ok(Plaintext {
par: par.clone(),
value: v.into_boxed_slice(),
encoding: Some(encoding.clone()),
poly_ntt: poly,
level: encoding.level,
})
})
.collect::<Result<Vec<Plaintext>>>()?,
))
}
}
impl FheEncoder<&[u64]> for PlaintextVec {
type Error = Error;
fn try_encode(value: &[u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
if value.is_empty() {
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
}
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
}
let ctx = par.ctx_at_level(encoding.level)?;
let num_plaintexts = div_ceil(value.len(), par.degree());
type Error = Error;
fn try_encode(value: &[u64], encoding: Encoding, par: &Arc<BfvParameters>) -> Result<Self> {
if value.is_empty() {
return Ok(PlaintextVec(vec![Plaintext::zero(encoding, par)?]));
}
if encoding.encoding == EncodingEnum::Simd && par.op.is_none() {
return Err(Error::EncodingNotSupported(EncodingEnum::Simd.to_string()));
}
let ctx = par.ctx_at_level(encoding.level)?;
let num_plaintexts = div_ceil(value.len(), par.degree());
Ok(PlaintextVec(
(0..num_plaintexts)
.map(|i| {
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
let mut v = vec![0u64; par.degree()];
match encoding.encoding {
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
EncodingEnum::Simd => {
for i in 0..slice.len() {
v[par.matrix_reps_index_map[i]] = slice[i];
}
par.op.as_ref().unwrap().backward(&mut v);
}
};
Ok(PlaintextVec(
(0..num_plaintexts)
.map(|i| {
let slice = &value[i * par.degree()..min(value.len(), (i + 1) * par.degree())];
let mut v = vec![0u64; par.degree()];
match encoding.encoding {
EncodingEnum::Poly => v[..slice.len()].copy_from_slice(slice),
EncodingEnum::Simd => {
for i in 0..slice.len() {
v[par.matrix_reps_index_map[i]] = slice[i];
}
par.op.as_ref().unwrap().backward(&mut v);
}
};
let mut poly =
Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
let mut poly =
Poly::try_convert_from(&v, ctx, false, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
Ok(Plaintext {
par: par.clone(),
value: v.into_boxed_slice(),
encoding: Some(encoding.clone()),
poly_ntt: poly,
level: encoding.level,
})
})
.collect::<Result<Vec<Plaintext>>>()?,
))
}
Ok(Plaintext {
par: par.clone(),
value: v.into_boxed_slice(),
encoding: Some(encoding.clone()),
poly_ntt: poly,
level: encoding.level,
})
})
.collect::<Result<Vec<Plaintext>>>()?,
))
}
}
#[cfg(test)]
mod tests {
use crate::bfv::{BfvParameters, Encoding, PlaintextVec};
use fhe_traits::{FheDecoder, FheEncoder, FheEncoderVariableTime};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
use crate::bfv::{BfvParameters, Encoding, PlaintextVec};
use fhe_traits::{FheDecoder, FheEncoder, FheEncoderVariableTime};
use rand::thread_rng;
use std::{error::Error, sync::Arc};
#[test]
fn encode_decode() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for _ in 0..20 {
for i in 1..5 {
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree() * i, &mut rng);
#[test]
fn encode_decode() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for _ in 0..20 {
for i in 1..5 {
let params = Arc::new(BfvParameters::default(1, 8));
let a = params.plaintext.random_vec(params.degree() * i, &mut rng);
let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), &params)?;
assert_eq!(plaintexts.0.len(), i);
let plaintexts = PlaintextVec::try_encode(&a, Encoding::poly_at_level(0), &params)?;
assert_eq!(plaintexts.0.len(), i);
for j in 0..i {
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
}
for j in 0..i {
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
}
let plaintexts = unsafe {
PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), &params)?
};
assert_eq!(plaintexts.0.len(), i);
let plaintexts = unsafe {
PlaintextVec::try_encode_vt(&a, Encoding::poly_at_level(0), &params)?
};
assert_eq!(plaintexts.0.len(), i);
for j in 0..i {
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
}
}
}
Ok(())
}
for j in 0..i {
let b = Vec::<u64>::try_decode(&plaintexts.0[j], Encoding::poly_at_level(0))?;
assert_eq!(b, &a[j * params.degree()..(j + 1) * params.degree()]);
}
}
}
Ok(())
}
}

View File

@@ -2,210 +2,210 @@ use std::ops::Mul;
use fhe_math::rq::{traits::TryConvertFrom as TryConvertFromPoly, Poly, Representation};
use fhe_traits::{
DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize,
DeserializeParametrized, FheCiphertext, FheEncrypter, FheParametrized, Serialize,
};
use protobuf::{Message, MessageField};
use rand::{CryptoRng, RngCore};
use zeroize::Zeroizing;
use crate::{
bfv::proto::bfv::{
KeySwitchingKey as KeySwitchingKeyProto, RGSWCiphertext as RGSWCiphertextProto,
},
Error, Result,
bfv::proto::bfv::{
KeySwitchingKey as KeySwitchingKeyProto, RGSWCiphertext as RGSWCiphertextProto,
},
Error, Result,
};
use super::{
keys::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, Ciphertext, Plaintext, SecretKey,
keys::KeySwitchingKey, traits::TryConvertFrom, BfvParameters, Ciphertext, Plaintext, SecretKey,
};
/// A RGSW ciphertext encrypting a plaintext.
#[derive(Debug, PartialEq, Eq)]
pub struct RGSWCiphertext {
ksk0: KeySwitchingKey,
ksk1: KeySwitchingKey,
ksk0: KeySwitchingKey,
ksk1: KeySwitchingKey,
}
impl FheParametrized for RGSWCiphertext {
type Parameters = BfvParameters;
type Parameters = BfvParameters;
}
impl From<&RGSWCiphertext> for RGSWCiphertextProto {
fn from(ct: &RGSWCiphertext) -> Self {
let mut proto = RGSWCiphertextProto::new();
proto.ksk0 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk0));
proto.ksk1 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk1));
proto
}
fn from(ct: &RGSWCiphertext) -> Self {
let mut proto = RGSWCiphertextProto::new();
proto.ksk0 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk0));
proto.ksk1 = MessageField::some(KeySwitchingKeyProto::from(&ct.ksk1));
proto
}
}
impl TryConvertFrom<&RGSWCiphertextProto> for RGSWCiphertext {
fn try_convert_from(
value: &RGSWCiphertextProto,
par: &std::sync::Arc<BfvParameters>,
) -> Result<Self> {
if value.ksk0.is_none() || value.ksk1.is_none() {
return Err(Error::SerializationError);
}
fn try_convert_from(
value: &RGSWCiphertextProto,
par: &std::sync::Arc<BfvParameters>,
) -> Result<Self> {
if value.ksk0.is_none() || value.ksk1.is_none() {
return Err(Error::SerializationError);
}
let ksk0 = KeySwitchingKey::try_convert_from(value.ksk0.as_ref().unwrap(), par)?;
let ksk1 = KeySwitchingKey::try_convert_from(value.ksk1.as_ref().unwrap(), par)?;
if ksk0.ksk_level != ksk0.ciphertext_level
|| ksk0.ciphertext_level != ksk1.ciphertext_level
|| ksk1.ciphertext_level != ksk1.ksk_level
{
return Err(Error::SerializationError);
}
let ksk0 = KeySwitchingKey::try_convert_from(value.ksk0.as_ref().unwrap(), par)?;
let ksk1 = KeySwitchingKey::try_convert_from(value.ksk1.as_ref().unwrap(), par)?;
if ksk0.ksk_level != ksk0.ciphertext_level
|| ksk0.ciphertext_level != ksk1.ciphertext_level
|| ksk1.ciphertext_level != ksk1.ksk_level
{
return Err(Error::SerializationError);
}
Ok(Self { ksk0, ksk1 })
}
Ok(Self { ksk0, ksk1 })
}
}
impl DeserializeParametrized for RGSWCiphertext {
type Error = Error;
type Error = Error;
fn from_bytes(bytes: &[u8], par: &std::sync::Arc<Self::Parameters>) -> Result<Self> {
let proto =
RGSWCiphertextProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
RGSWCiphertext::try_convert_from(&proto, par)
}
fn from_bytes(bytes: &[u8], par: &std::sync::Arc<Self::Parameters>) -> Result<Self> {
let proto =
RGSWCiphertextProto::parse_from_bytes(bytes).map_err(|_| Error::SerializationError)?;
RGSWCiphertext::try_convert_from(&proto, par)
}
}
impl Serialize for RGSWCiphertext {
fn to_bytes(&self) -> Vec<u8> {
RGSWCiphertextProto::from(self).write_to_bytes().unwrap()
}
fn to_bytes(&self) -> Vec<u8> {
RGSWCiphertextProto::from(self).write_to_bytes().unwrap()
}
}
impl FheCiphertext for RGSWCiphertext {}
impl FheEncrypter<Plaintext, RGSWCiphertext> for SecretKey {
type Error = Error;
type Error = Error;
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<RGSWCiphertext> {
let level = pt.level;
let ctx = self.par.ctx_at_level(level)?;
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<RGSWCiphertext> {
let level = pt.level;
let ctx = self.par.ctx_at_level(level)?;
let mut m = Zeroizing::new(pt.poly_ntt.clone());
let mut m_s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)?);
m_s.change_representation(Representation::Ntt);
*m_s.as_mut() *= m.as_ref();
m_s.change_representation(Representation::PowerBasis);
m.change_representation(Representation::PowerBasis);
let mut m = Zeroizing::new(pt.poly_ntt.clone());
let mut m_s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)?);
m_s.change_representation(Representation::Ntt);
*m_s.as_mut() *= m.as_ref();
m_s.change_representation(Representation::PowerBasis);
m.change_representation(Representation::PowerBasis);
let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?;
let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?;
let ksk0 = KeySwitchingKey::new(self, &m, pt.level, pt.level, rng)?;
let ksk1 = KeySwitchingKey::new(self, &m_s, pt.level, pt.level, rng)?;
Ok(RGSWCiphertext { ksk0, ksk1 })
}
Ok(RGSWCiphertext { ksk0, ksk1 })
}
}
impl Mul<&RGSWCiphertext> for &Ciphertext {
type Output = Ciphertext;
type Output = Ciphertext;
fn mul(self, rhs: &RGSWCiphertext) -> Self::Output {
assert_eq!(
self.par, rhs.ksk0.par,
"Ciphertext and RGSWCiphertext must have the same parameters"
);
assert_eq!(
self.level, rhs.ksk0.ciphertext_level,
"Ciphertext and RGSWCiphertext must have the same level"
);
assert_eq!(self.c.len(), 2, "Ciphertext must have two parts");
fn mul(self, rhs: &RGSWCiphertext) -> Self::Output {
assert_eq!(
self.par, rhs.ksk0.par,
"Ciphertext and RGSWCiphertext must have the same parameters"
);
assert_eq!(
self.level, rhs.ksk0.ciphertext_level,
"Ciphertext and RGSWCiphertext must have the same level"
);
assert_eq!(self.c.len(), 2, "Ciphertext must have two parts");
let mut ct0 = self.c[0].clone();
let mut ct1 = self.c[1].clone();
ct0.change_representation(Representation::PowerBasis);
ct1.change_representation(Representation::PowerBasis);
let mut ct0 = self.c[0].clone();
let mut ct1 = self.c[1].clone();
ct0.change_representation(Representation::PowerBasis);
ct1.change_representation(Representation::PowerBasis);
let (c0, c1) = rhs.ksk0.key_switch(&ct0).unwrap();
let (c0p, c1p) = rhs.ksk1.key_switch(&ct1).unwrap();
let (c0, c1) = rhs.ksk0.key_switch(&ct0).unwrap();
let (c0p, c1p) = rhs.ksk1.key_switch(&ct1).unwrap();
Ciphertext {
par: self.par.clone(),
seed: None,
c: vec![&c0 + &c0p, &c1 + &c1p],
level: self.level,
}
}
Ciphertext {
par: self.par.clone(),
seed: None,
c: vec![&c0 + &c0p, &c1 + &c1p],
level: self.level,
}
}
}
impl Mul<&Ciphertext> for &RGSWCiphertext {
type Output = Ciphertext;
type Output = Ciphertext;
fn mul(self, rhs: &Ciphertext) -> Self::Output {
rhs * self
}
fn mul(self, rhs: &Ciphertext) -> Self::Output {
rhs * self
}
}
#[cfg(test)]
mod tests {
use std::{error::Error, sync::Arc};
use std::{error::Error, sync::Arc};
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
use rand::thread_rng;
use crate::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, SecretKey};
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
use rand::thread_rng;
use super::RGSWCiphertext;
use super::RGSWCiphertext;
#[test]
fn external_product() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(2, 8)),
Arc::new(BfvParameters::default(8, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v1 = params.plaintext.random_vec(params.degree(), &mut rng);
let v2 = params.plaintext.random_vec(params.degree(), &mut rng);
#[test]
fn external_product() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(2, 8)),
Arc::new(BfvParameters::default(8, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v1 = params.plaintext.random_vec(params.degree(), &mut rng);
let v2 = params.plaintext.random_vec(params.degree(), &mut rng);
let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &params)?;
let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &params)?;
let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), &params)?;
let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), &params)?;
let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?;
let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?;
let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?;
let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?;
let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?;
let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?;
let product = &ct1 * &ct2;
let expected = sk.try_decrypt(&product)?;
let product = &ct1 * &ct2;
let expected = sk.try_decrypt(&product)?;
let ct3 = &ct1 * &ct2_rgsw;
let ct4 = &ct2_rgsw * &ct1;
let ct3 = &ct1 * &ct2_rgsw;
let ct4 = &ct2_rgsw * &ct1;
println!("Noise 1: {:?}", unsafe { sk.measure_noise(&ct3) });
println!("Noise 2: {:?}", unsafe { sk.measure_noise(&ct4) });
assert_eq!(expected, sk.try_decrypt(&ct3)?);
assert_eq!(expected, sk.try_decrypt(&ct4)?);
}
Ok(())
}
println!("Noise 1: {:?}", unsafe { sk.measure_noise(&ct3) });
println!("Noise 2: {:?}", unsafe { sk.measure_noise(&ct4) });
assert_eq!(expected, sk.try_decrypt(&ct3)?);
assert_eq!(expected, sk.try_decrypt(&ct4)?);
}
Ok(())
}
#[test]
fn serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(5, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?;
#[test]
fn serialize() -> Result<(), Box<dyn Error>> {
let mut rng = thread_rng();
for params in [
Arc::new(BfvParameters::default(6, 8)),
Arc::new(BfvParameters::default(5, 8)),
] {
let sk = SecretKey::random(&params, &mut rng);
let v = params.plaintext.random_vec(params.degree(), &mut rng);
let pt = Plaintext::try_encode(&v, Encoding::simd(), &params)?;
let ct: RGSWCiphertext = sk.try_encrypt(&pt, &mut rng)?;
let bytes = ct.to_bytes();
assert_eq!(RGSWCiphertext::from_bytes(&bytes, &params)?, ct);
}
let bytes = ct.to_bytes();
assert_eq!(RGSWCiphertext::from_bytes(&bytes, &params)?, ct);
}
Ok(())
}
Ok(())
}
}

View File

@@ -12,8 +12,8 @@ use std::sync::Arc;
/// blanket implementation <https://github.com/rust-lang/rust/issues/50133#issuecomment-488512355>.
pub trait TryConvertFrom<T>
where
Self: Sized,
Self: Sized,
{
/// Attempt to convert the `value` with a specific parameter.
fn try_convert_from(value: T, par: &Arc<BfvParameters>) -> Result<Self>;
/// Attempt to convert the `value` with a specific parameter.
fn try_convert_from(value: T, par: &Arc<BfvParameters>) -> Result<Self>;
}

View File

@@ -6,141 +6,141 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Enum encapsulating all the possible errors from this library.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum Error {
/// Indicates that an error from the underlying mathematical library was
/// encountered.
#[error("{0}")]
MathError(fhe_math::Error),
/// Indicates that an error from the underlying mathematical library was
/// encountered.
#[error("{0}")]
MathError(fhe_math::Error),
/// Indicates a serialization error.
#[error("Serialization error")]
SerializationError,
/// Indicates a serialization error.
#[error("Serialization error")]
SerializationError,
/// Indicates that too many values were provided.
#[error("Too many values provided: {0} exceeds limit {1}")]
TooManyValues(usize, usize),
/// Indicates that too many values were provided.
#[error("Too many values provided: {0} exceeds limit {1}")]
TooManyValues(usize, usize),
/// Indicates that too few values were provided.
#[error("Too few values provided: {0} is below limit {1}")]
TooFewValues(usize, usize),
/// Indicates that too few values were provided.
#[error("Too few values provided: {0} is below limit {1}")]
TooFewValues(usize, usize),
/// Indicates that an input is invalid.
#[error("{0}")]
UnspecifiedInput(String),
/// Indicates that an input is invalid.
#[error("{0}")]
UnspecifiedInput(String),
/// Indicates a mismatch in the encodings.
#[error("Encoding mismatch: found {0}, expected {1}")]
EncodingMismatch(String, String),
/// Indicates a mismatch in the encodings.
#[error("Encoding mismatch: found {0}, expected {1}")]
EncodingMismatch(String, String),
/// Indicates that the encoding is not supported.
#[error("Does not support {0} encoding")]
EncodingNotSupported(String),
/// Indicates that the encoding is not supported.
#[error("Does not support {0} encoding")]
EncodingNotSupported(String),
/// Indicates a parameter error.
#[error("{0}")]
ParametersError(ParametersError),
/// Indicates a parameter error.
#[error("{0}")]
ParametersError(ParametersError),
/// Indicates a default error
/// TODO: To delete eventually
#[error("{0}")]
DefaultError(String),
/// Indicates a default error
/// TODO: To delete eventually
#[error("{0}")]
DefaultError(String),
}
impl From<fhe_math::Error> for Error {
fn from(e: fhe_math::Error) -> Self {
Error::MathError(e)
}
fn from(e: fhe_math::Error) -> Self {
Error::MathError(e)
}
}
/// Separate enum to indicate parameters-related errors.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ParametersError {
/// Indicates that the degree is invalid.
#[error("Invalid degree: {0} is not a power of 2 larger than 8")]
InvalidDegree(usize),
/// Indicates that the degree is invalid.
#[error("Invalid degree: {0} is not a power of 2 larger than 8")]
InvalidDegree(usize),
/// Indicates that the moduli sizes are invalid.
#[error("Invalid modulus size: {0}, expected an integer between {1} and {2}")]
InvalidModulusSize(usize, usize, usize),
/// Indicates that the moduli sizes are invalid.
#[error("Invalid modulus size: {0}, expected an integer between {1} and {2}")]
InvalidModulusSize(usize, usize, usize),
/// Indicates that there exists not enough primes of this size.
#[error("Not enough primes of size {0} for polynomials of degree {1}")]
NotEnoughPrimes(usize, usize),
/// Indicates that there exists not enough primes of this size.
#[error("Not enough primes of size {0} for polynomials of degree {1}")]
NotEnoughPrimes(usize, usize),
/// Indicates that the plaintext is invalid.
#[error("{0}")]
InvalidPlaintext(String),
/// Indicates that the plaintext is invalid.
#[error("{0}")]
InvalidPlaintext(String),
/// Indicates that too many parameters were specified.
#[error("{0}")]
TooManySpecified(String),
/// Indicates that too many parameters were specified.
#[error("{0}")]
TooManySpecified(String),
/// Indicates that too few parameters were specified.
#[error("{0}")]
TooFewSpecified(String),
/// Indicates that too few parameters were specified.
#[error("{0}")]
TooFewSpecified(String),
}
#[cfg(test)]
mod tests {
use crate::{Error, ParametersError};
use crate::{Error, ParametersError};
#[test]
fn error_strings() {
assert_eq!(
Error::MathError(fhe_math::Error::InvalidContext).to_string(),
fhe_math::Error::InvalidContext.to_string()
);
assert_eq!(Error::SerializationError.to_string(), "Serialization error");
assert_eq!(
Error::TooManyValues(20, 17).to_string(),
"Too many values provided: 20 exceeds limit 17"
);
assert_eq!(
Error::TooFewValues(10, 17).to_string(),
"Too few values provided: 10 is below limit 17"
);
assert_eq!(
Error::UnspecifiedInput("test string".to_string()).to_string(),
"test string"
);
assert_eq!(
Error::EncodingMismatch("enc1".to_string(), "enc2".to_string()).to_string(),
"Encoding mismatch: found enc1, expected enc2"
);
assert_eq!(
Error::EncodingNotSupported("test".to_string()).to_string(),
"Does not support test encoding"
);
assert_eq!(
Error::ParametersError(ParametersError::InvalidDegree(10)).to_string(),
ParametersError::InvalidDegree(10).to_string()
);
}
#[test]
fn error_strings() {
assert_eq!(
Error::MathError(fhe_math::Error::InvalidContext).to_string(),
fhe_math::Error::InvalidContext.to_string()
);
assert_eq!(Error::SerializationError.to_string(), "Serialization error");
assert_eq!(
Error::TooManyValues(20, 17).to_string(),
"Too many values provided: 20 exceeds limit 17"
);
assert_eq!(
Error::TooFewValues(10, 17).to_string(),
"Too few values provided: 10 is below limit 17"
);
assert_eq!(
Error::UnspecifiedInput("test string".to_string()).to_string(),
"test string"
);
assert_eq!(
Error::EncodingMismatch("enc1".to_string(), "enc2".to_string()).to_string(),
"Encoding mismatch: found enc1, expected enc2"
);
assert_eq!(
Error::EncodingNotSupported("test".to_string()).to_string(),
"Does not support test encoding"
);
assert_eq!(
Error::ParametersError(ParametersError::InvalidDegree(10)).to_string(),
ParametersError::InvalidDegree(10).to_string()
);
}
#[test]
fn parameters_error_strings() {
assert_eq!(
ParametersError::InvalidDegree(10).to_string(),
"Invalid degree: 10 is not a power of 2 larger than 8"
);
assert_eq!(
ParametersError::InvalidModulusSize(1, 2, 3).to_string(),
"Invalid modulus size: 1, expected an integer between 2 and 3"
);
assert_eq!(
ParametersError::NotEnoughPrimes(1, 2).to_string(),
"Not enough primes of size 1 for polynomials of degree 2"
);
assert_eq!(
ParametersError::InvalidPlaintext("test".to_string()).to_string(),
"test"
);
assert_eq!(
ParametersError::TooManySpecified("test".to_string()).to_string(),
"test"
);
assert_eq!(
ParametersError::TooFewSpecified("test".to_string()).to_string(),
"test"
);
}
#[test]
fn parameters_error_strings() {
assert_eq!(
ParametersError::InvalidDegree(10).to_string(),
"Invalid degree: 10 is not a power of 2 larger than 8"
);
assert_eq!(
ParametersError::InvalidModulusSize(1, 2, 3).to_string(),
"Invalid modulus size: 1, expected an integer between 2 and 3"
);
assert_eq!(
ParametersError::NotEnoughPrimes(1, 2).to_string(),
"Not enough primes of size 1 for polynomials of degree 2"
);
assert_eq!(
ParametersError::InvalidPlaintext("test".to_string()).to_string(),
"test"
);
assert_eq!(
ParametersError::TooManySpecified("test".to_string()).to_string(),
"test"
);
assert_eq!(
ParametersError::TooFewSpecified("test".to_string()).to_string(),
"test"
);
}
}

View File

@@ -1,2 +1 @@
wrap_comments = true
hard_tabs = true