mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
Merge pull request #47 from Sunscreen-tech/rweber/nonFhe
Rweber/non fhe
This commit is contained in:
@@ -5,8 +5,8 @@ use crate::types::{
|
||||
};
|
||||
use crate::{with_ctx, CircuitInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
|
||||
use std::cmp::Eq;
|
||||
use sunscreen_runtime::Error;
|
||||
use std::ops::*;
|
||||
use sunscreen_runtime::Error;
|
||||
|
||||
use num::Rational64;
|
||||
|
||||
@@ -108,7 +108,7 @@ impl Add for Rational {
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
num: self.num * rhs.den + rhs.num * self.den,
|
||||
den: self.den * rhs.den
|
||||
den: self.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -121,7 +121,7 @@ impl Add<f64> for Rational {
|
||||
|
||||
Self::Output {
|
||||
num: self.num * rhs.den + rhs.num * self.den,
|
||||
den: self.den * rhs.den
|
||||
den: self.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -134,7 +134,7 @@ impl Add<Rational> for f64 {
|
||||
|
||||
Self::Output {
|
||||
num: lhs.num * rhs.den + rhs.num * lhs.den,
|
||||
den: lhs.den * rhs.den
|
||||
den: lhs.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,7 +145,7 @@ impl Mul for Rational {
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
num: self.num * rhs.num,
|
||||
den: self.den * rhs.den
|
||||
den: self.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -182,7 +182,7 @@ impl Sub for Rational {
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
num: self.num * rhs.den - rhs.num * self.den,
|
||||
den: self.den * rhs.den
|
||||
den: self.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,7 +195,7 @@ impl Sub<f64> for Rational {
|
||||
|
||||
Self::Output {
|
||||
num: self.num * rhs.den - rhs.num * self.den,
|
||||
den: self.den * rhs.den
|
||||
den: self.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -208,7 +208,7 @@ impl Sub<Rational> for f64 {
|
||||
|
||||
Self::Output {
|
||||
num: lhs.num * rhs.den - rhs.num * lhs.den,
|
||||
den: lhs.den * rhs.den
|
||||
den: lhs.den * rhs.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -219,7 +219,7 @@ impl Div for Rational {
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
num: self.num * rhs.den,
|
||||
den: self.den * rhs.num
|
||||
den: self.den * rhs.num,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -232,7 +232,7 @@ impl Div<f64> for Rational {
|
||||
|
||||
Self::Output {
|
||||
num: self.num * rhs.den,
|
||||
den: self.den * rhs.num
|
||||
den: self.den * rhs.num,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,7 +245,7 @@ impl Div<Rational> for f64 {
|
||||
|
||||
Self::Output {
|
||||
num: lhs.num * rhs.den,
|
||||
den: lhs.den * rhs.num
|
||||
den: lhs.den * rhs.num,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,7 +256,7 @@ impl Neg for Rational {
|
||||
fn neg(self) -> Self::Output {
|
||||
Self::Output {
|
||||
num: -self.num,
|
||||
den: self.den
|
||||
den: self.den,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -651,9 +651,7 @@ impl GraphConstCipherDiv for Rational {
|
||||
impl GraphCipherNeg for Rational {
|
||||
type Val = Self;
|
||||
|
||||
fn graph_cipher_neg(
|
||||
a: CircuitNode<Cipher<Self::Val>>,
|
||||
) -> CircuitNode<Cipher<Self::Val>> {
|
||||
fn graph_cipher_neg(a: CircuitNode<Cipher<Self::Val>>) -> CircuitNode<Cipher<Self::Val>> {
|
||||
with_ctx(|ctx| {
|
||||
let neg = ctx.add_negate(a.ids[0]);
|
||||
let ids = [neg, a.ids[1]];
|
||||
@@ -713,4 +711,4 @@ mod tests {
|
||||
|
||||
assert_eq!(-a, (-5.).try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ impl Add for Signed {
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
val: self.val + rhs.val
|
||||
val: self.val + rhs.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,7 +146,7 @@ impl Add<i64> for Signed {
|
||||
|
||||
fn add(self, rhs: i64) -> Self::Output {
|
||||
Self {
|
||||
val: self.val + rhs
|
||||
val: self.val + rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -156,7 +156,7 @@ impl Add<Signed> for i64 {
|
||||
|
||||
fn add(self, rhs: Signed) -> Self::Output {
|
||||
Self::Output {
|
||||
val: self + rhs.val
|
||||
val: self + rhs.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -166,7 +166,7 @@ impl Mul for Signed {
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
val: self.val * rhs.val
|
||||
val: self.val * rhs.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -176,7 +176,7 @@ impl Mul<i64> for Signed {
|
||||
|
||||
fn mul(self, rhs: i64) -> Self::Output {
|
||||
Self {
|
||||
val: self.val * rhs
|
||||
val: self.val * rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -186,7 +186,7 @@ impl Mul<Signed> for i64 {
|
||||
|
||||
fn mul(self, rhs: Signed) -> Self::Output {
|
||||
Self::Output {
|
||||
val: self * rhs.val
|
||||
val: self * rhs.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,7 +196,7 @@ impl Sub for Signed {
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self::Output {
|
||||
val: self.val - rhs.val
|
||||
val: self.val - rhs.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -206,7 +206,7 @@ impl Sub<i64> for Signed {
|
||||
|
||||
fn sub(self, rhs: i64) -> Self::Output {
|
||||
Self {
|
||||
val: self.val - rhs
|
||||
val: self.val - rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,7 +216,7 @@ impl Sub<Signed> for i64 {
|
||||
|
||||
fn sub(self, rhs: Signed) -> Self::Output {
|
||||
Self::Output {
|
||||
val: self - rhs.val
|
||||
val: self - rhs.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,9 +225,7 @@ impl Neg for Signed {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self::Output {
|
||||
val: -self.val
|
||||
}
|
||||
Self::Output { val: -self.val }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,4 +471,4 @@ mod tests {
|
||||
|
||||
assert_eq!(-a, (-5).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
crate_version,
|
||||
types::{
|
||||
intern::{Cipher, CircuitNode},
|
||||
intern::{Cipher, CircuitNode, SwapRows},
|
||||
ops::*,
|
||||
BfvType, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName,
|
||||
TypeNameInstance, Version,
|
||||
@@ -12,6 +12,7 @@ use seal::{
|
||||
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
|
||||
Result as SealResult,
|
||||
};
|
||||
use std::ops::*;
|
||||
use sunscreen_runtime::{Error as RuntimeError, Result as RuntimeResult};
|
||||
|
||||
/**
|
||||
@@ -63,9 +64,9 @@ use sunscreen_runtime::{Error as RuntimeError, Result as RuntimeResult};
|
||||
* value equal to half the polynomial degree needed to accomodate the
|
||||
* circuit's noise budget constraint.
|
||||
*/
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Simd<const LANES: usize> {
|
||||
data: [Vec<i64>; 2],
|
||||
data: [[i64; LANES]; 2],
|
||||
}
|
||||
|
||||
impl<const LANES: usize> NumCiphertexts for Simd<LANES> {
|
||||
@@ -178,8 +179,30 @@ impl<const LANES: usize> TryFromPlaintext for Simd<LANES> {
|
||||
|
||||
Ok(Self {
|
||||
data: [
|
||||
row_0.iter().take(LANES).map(|x| *x).collect(),
|
||||
row_1.iter().take(LANES).map(|x| *x).collect(),
|
||||
row_0
|
||||
.iter()
|
||||
.take(LANES)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.map_err(|_| {
|
||||
RuntimeError::FheTypeError(format!(
|
||||
"Failed to convert Vec to [i64;{}]",
|
||||
LANES
|
||||
))
|
||||
})?,
|
||||
row_1
|
||||
.iter()
|
||||
.take(LANES)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.map_err(|_| {
|
||||
RuntimeError::FheTypeError(format!(
|
||||
"Failed to convert Vec to [i64;{}]",
|
||||
LANES
|
||||
))
|
||||
})?,
|
||||
],
|
||||
})
|
||||
}
|
||||
@@ -189,19 +212,217 @@ impl<const LANES: usize> TryFrom<[Vec<i64>; 2]> for Simd<LANES> {
|
||||
type Error = RuntimeError;
|
||||
|
||||
fn try_from(data: [Vec<i64>; 2]) -> RuntimeResult<Self> {
|
||||
if data[0].len() != data[1].len() || data[0].len() != LANES {
|
||||
return Err(RuntimeError::FheTypeError(
|
||||
format!("Invalid SIMD shape. Expected a 2x{} matrix", LANES).to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self { data })
|
||||
Ok(Self {
|
||||
data: [
|
||||
data[0].clone().try_into().map_err(|_| {
|
||||
RuntimeError::FheTypeError(format!("Failed to convert Vec to [i64;{}]", LANES))
|
||||
})?,
|
||||
data[1].clone().try_into().map_err(|_| {
|
||||
RuntimeError::FheTypeError(format!("Failed to convert Vec to [i64;{}]", LANES))
|
||||
})?,
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Into<[Vec<i64>; 2]> for Simd<LANES> {
|
||||
fn into(self) -> [Vec<i64>; 2] {
|
||||
self.data
|
||||
[self.data[0].into(), self.data[1].into()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> From<[[i64; LANES]; 2]> for Simd<LANES> {
|
||||
fn from(data: [[i64; LANES]; 2]) -> Self {
|
||||
Self { data }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Into<[[i64; LANES]; 2]> for Simd<LANES> {
|
||||
fn into(self) -> [[i64; LANES]; 2] {
|
||||
[self.data[0], self.data[1]]
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Add for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let r_0: [i64; LANES] = self.data[0]
|
||||
.iter()
|
||||
.zip(rhs.data[0].iter())
|
||||
.map(|(x, y)| x + y)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let r_1: [i64; LANES] = self.data[1]
|
||||
.iter()
|
||||
.zip(rhs.data[1].iter())
|
||||
.map(|(x, y)| x + y)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self { data: [r_0, r_1] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Sub for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
let r_0: [i64; LANES] = self.data[0]
|
||||
.iter()
|
||||
.zip(rhs.data[0].iter())
|
||||
.map(|(x, y)| x - y)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let r_1: [i64; LANES] = self.data[1]
|
||||
.iter()
|
||||
.zip(rhs.data[1].iter())
|
||||
.map(|(x, y)| x - y)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self { data: [r_0, r_1] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Mul for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
let r_0: [i64; LANES] = self.data[0]
|
||||
.iter()
|
||||
.zip(rhs.data[0].iter())
|
||||
.map(|(x, y)| x * y)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let r_1: [i64; LANES] = self.data[1]
|
||||
.iter()
|
||||
.zip(rhs.data[1].iter())
|
||||
.map(|(x, y)| x * y)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self { data: [r_0, r_1] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Neg for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let r_0: [i64; LANES] = self.data[0]
|
||||
.iter()
|
||||
.map(|x| -x)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let r_1: [i64; LANES] = self.data[1]
|
||||
.iter()
|
||||
.map(|x| -x)
|
||||
.collect::<Vec<i64>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self { data: [r_0, r_1] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Shl<u64> for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn shl(self, x: u64) -> Self::Output {
|
||||
let r_0: [i64; LANES] = [
|
||||
self.data[0]
|
||||
.iter()
|
||||
.skip(x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
self.data[0]
|
||||
.iter()
|
||||
.take(x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
]
|
||||
.concat()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let r_1: [i64; LANES] = [
|
||||
self.data[1]
|
||||
.iter()
|
||||
.skip(x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
self.data[1]
|
||||
.iter()
|
||||
.take(x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
]
|
||||
.concat()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self { data: [r_0, r_1] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> Shr<u64> for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn shr(self, x: u64) -> Self::Output {
|
||||
let r_0: [i64; LANES] = [
|
||||
self.data[0]
|
||||
.iter()
|
||||
.skip(LANES - x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
self.data[0]
|
||||
.iter()
|
||||
.take(LANES - x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
]
|
||||
.concat()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let r_1: [i64; LANES] = [
|
||||
self.data[1]
|
||||
.iter()
|
||||
.skip(LANES - x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
self.data[1]
|
||||
.iter()
|
||||
.take(LANES -x as usize)
|
||||
.map(|x| *x)
|
||||
.collect::<Vec<i64>>(),
|
||||
]
|
||||
.concat()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self { data: [r_0, r_1] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> SwapRows for Simd<LANES> {
|
||||
type Output = Self;
|
||||
|
||||
fn swap_rows(self) -> Self::Output {
|
||||
Self {
|
||||
data: [
|
||||
self.data[1],
|
||||
self.data[0]
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,6 +509,19 @@ impl<const LANES: usize> GraphCipherRotateRight for Simd<LANES> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherNeg for Simd<LANES> {
|
||||
type Val = Self;
|
||||
|
||||
fn graph_cipher_neg(
|
||||
x: CircuitNode<Cipher<Self>>,
|
||||
) -> CircuitNode<Cipher<Self::Val>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_negate(x.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -317,4 +551,59 @@ mod tests {
|
||||
|
||||
assert_eq!(x, y);
|
||||
}
|
||||
|
||||
const A_VEC: [[i64; 4]; 2] = [[1, 2, 3, 4], [5, 6, 7, 8]];
|
||||
const B_VEC: [[i64; 4]; 2] = [[5, 6, 7, 8], [1, 2, 3, 4]];
|
||||
|
||||
#[test]
|
||||
fn can_add_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
let b = Simd::<4>::try_from(B_VEC).unwrap();
|
||||
|
||||
assert_eq!(a + b, [[6, 8, 10, 12], [6, 8, 10, 12]].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mul_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
let b = Simd::<4>::try_from(B_VEC).unwrap();
|
||||
|
||||
assert_eq!(a * b, [[5, 12, 21, 32], [5, 12, 21, 32]].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_sub_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
let b = Simd::<4>::try_from(B_VEC).unwrap();
|
||||
|
||||
assert_eq!(a - b, [[-4, -4, -4, -4], [4, 4, 4, 4]].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_neg_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
|
||||
assert_eq!(-a, [[-1, -2, -3, -4], [-5, -6, -7, -8]].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_shl_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
|
||||
assert_eq!(a << 3, [[4, 1, 2, 3], [8, 5, 6, 7]].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_shr_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
|
||||
assert_eq!(a >> 3, [[2, 3, 4, 1], [6, 7, 8, 5]].into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_swap_rows_non_fhe() {
|
||||
let a = Simd::<4>::try_from(A_VEC).unwrap();
|
||||
|
||||
assert_eq!(a.swap_rows(), [[5, 6, 7, 8], [1, 2, 3, 4]].into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,5 +104,4 @@ where
|
||||
|
||||
trait Foo {}
|
||||
|
||||
impl<T> Foo for T
|
||||
where T: FheType { }
|
||||
impl<T> Foo for T where T: FheType {}
|
||||
|
||||
@@ -36,9 +36,9 @@ fn can_encode_rational_numbers() {
|
||||
|
||||
type CipherRational = Cipher<Rational>;
|
||||
|
||||
fn add_impl<T, U, R>(a: T, b: U) -> R
|
||||
fn add_impl<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Add<U, Output = R>
|
||||
T: Add<U, Output = R>,
|
||||
{
|
||||
a + b
|
||||
}
|
||||
@@ -61,13 +61,9 @@ fn can_add_cipher_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
|
||||
|
||||
@@ -94,9 +90,7 @@ fn can_add_cipher_plain() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
|
||||
@@ -126,9 +120,7 @@ fn can_add_plain_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
|
||||
@@ -158,9 +150,7 @@ fn can_add_cipher_literal() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -189,9 +179,7 @@ fn can_add_literal_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -202,9 +190,9 @@ fn can_add_literal_cipher() {
|
||||
assert_eq!(c, add_impl(3.14, a));
|
||||
}
|
||||
|
||||
fn sub_impl<T, U, R>(a: T, b: U) -> R
|
||||
fn sub_impl<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Sub<U, Output = R>
|
||||
T: Sub<U, Output = R>,
|
||||
{
|
||||
a - b
|
||||
}
|
||||
@@ -227,13 +215,9 @@ fn can_sub_cipher_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
|
||||
|
||||
@@ -260,9 +244,7 @@ fn can_sub_cipher_plain() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
|
||||
@@ -292,11 +274,9 @@ fn can_sub_plain_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
|
||||
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b_c.into()];
|
||||
|
||||
@@ -325,9 +305,7 @@ fn can_sub_cipher_literal() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -356,9 +334,7 @@ fn can_sub_literal_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -369,9 +345,9 @@ fn can_sub_literal_cipher() {
|
||||
assert_eq!(c, sub_impl(3.14, a));
|
||||
}
|
||||
|
||||
fn mul_impl<T, U, R>(a: T, b: U) -> R
|
||||
fn mul_impl<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Mul<U, Output = R>
|
||||
T: Mul<U, Output = R>,
|
||||
{
|
||||
a * b
|
||||
}
|
||||
@@ -394,13 +370,9 @@ fn can_mul_cipher_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
|
||||
|
||||
@@ -427,9 +399,7 @@ fn can_mul_cipher_plain() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
|
||||
@@ -459,11 +429,9 @@ fn can_mul_plain_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
|
||||
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b_c.into()];
|
||||
|
||||
@@ -492,9 +460,7 @@ fn can_mul_cipher_literal() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -523,9 +489,7 @@ fn can_mul_literal_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -536,9 +500,9 @@ fn can_mul_literal_cipher() {
|
||||
assert_eq!(c, mul_impl(3.14, a));
|
||||
}
|
||||
|
||||
fn div_impl<T, U, R>(a: T, b: U) -> R
|
||||
fn div_impl<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Mul<U, Output = R>
|
||||
T: Mul<U, Output = R>,
|
||||
{
|
||||
a * b
|
||||
}
|
||||
@@ -561,13 +525,9 @@ fn can_div_cipher_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
|
||||
|
||||
@@ -594,9 +554,7 @@ fn can_div_cipher_plain() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
|
||||
@@ -626,11 +584,9 @@ fn can_div_plain_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-3.14).unwrap();
|
||||
|
||||
|
||||
let b = Rational::try_from(6.28).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
let b_c = runtime.encrypt(b, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b_c.into()];
|
||||
|
||||
@@ -659,9 +615,7 @@ fn can_div_cipher_literal() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -690,9 +644,7 @@ fn can_div_literal_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -705,9 +657,9 @@ fn can_div_literal_cipher() {
|
||||
|
||||
#[test]
|
||||
fn can_neg_cipher() {
|
||||
fn neg_impl<T>(x: T) -> T
|
||||
fn neg_impl<T>(x: T) -> T
|
||||
where
|
||||
T: Neg<Output = T>
|
||||
T: Neg<Output = T>,
|
||||
{
|
||||
-x
|
||||
}
|
||||
@@ -728,9 +680,7 @@ fn can_neg_cipher() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Rational::try_from(-6.28).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let a_c = runtime.encrypt(a, &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
@@ -739,4 +689,4 @@ fn can_neg_cipher() {
|
||||
let c: Rational = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, neg_impl(a));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@ use sunscreen_compiler::{
|
||||
|
||||
use std::ops::*;
|
||||
|
||||
fn add_fn<T, U, R>(a: T, b: U) -> R where
|
||||
fn add_fn<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Add<U, Output = R>,
|
||||
{
|
||||
a + b
|
||||
@@ -128,7 +129,8 @@ fn can_add_literal_cipher() {
|
||||
assert_eq!(c, add_fn(-4, a));
|
||||
}
|
||||
|
||||
fn sub_fn<T, U, R>(a: T, b: U) -> R where
|
||||
fn sub_fn<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Sub<U, Output = R>,
|
||||
{
|
||||
a - b
|
||||
@@ -250,7 +252,8 @@ fn can_sub_literal_cipher() {
|
||||
assert_eq!(c, sub_fn(-4, a));
|
||||
}
|
||||
|
||||
fn mul_fn<T, U, R>(a: T, b: U) -> R where
|
||||
fn mul_fn<T, U, R>(a: T, b: U) -> R
|
||||
where
|
||||
T: Mul<U, Output = R>,
|
||||
{
|
||||
a * b
|
||||
@@ -370,4 +373,4 @@ fn can_mul_literal_cipher() {
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, mul_fn(-4, a));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
use sunscreen_compiler::{
|
||||
circuit,
|
||||
types::{bfv::Simd, Cipher},
|
||||
types::{bfv::Simd, Cipher, intern::SwapRows,
|
||||
},
|
||||
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
|
||||
};
|
||||
|
||||
use std::ops::*;
|
||||
|
||||
#[test]
|
||||
fn can_swap_rows_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
fn swap_impl<T>(a: T) -> T
|
||||
where T: SwapRows<Output = T>
|
||||
{
|
||||
a.swap_rows()
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn swap_rows(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
swap_impl(a)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(swap_rows)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
@@ -23,11 +32,12 @@ fn can_swap_rows_cipher() {
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
let a = Simd::<4>::try_from(data).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
@@ -35,14 +45,21 @@ fn can_swap_rows_cipher() {
|
||||
|
||||
let expected = [vec![5, 6, 7, 8], vec![1, 2, 3, 4]];
|
||||
|
||||
assert_eq!(c, swap_impl(a));
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_rotate_left_cipher() {
|
||||
fn shl_impl<T>(x: T, y: u64) -> T
|
||||
where T: Shl<u64, Output = T>
|
||||
{
|
||||
x << y
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a << 1
|
||||
shl_impl(a, 1)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
@@ -57,26 +74,31 @@ fn can_rotate_left_cipher() {
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
let a = Simd::<4>::try_from(data).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![2, 3, 4, 1], vec![6, 7, 8, 5]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
assert_eq!(c, shl_impl(a, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_rotate_right_cipher() {
|
||||
fn shr_impl<T>(x: T, y: u64) -> T
|
||||
where T: Shr<u64, Output = T>
|
||||
{
|
||||
x >> y
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a >> 1
|
||||
shr_impl(a, 1)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
@@ -91,26 +113,31 @@ fn can_rotate_right_cipher() {
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
let a = Simd::<4>::try_from(data).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![4, 1, 2, 3], vec![8, 5, 6, 7]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
assert_eq!(c, shr_impl(a, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_cipher_cipher() {
|
||||
fn add_impl<T>(a: T, b: T) -> T
|
||||
where T: Add<T, Output = T>
|
||||
{
|
||||
a + b
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a + b
|
||||
add_impl(a, b)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
@@ -125,32 +152,39 @@ fn can_add_cipher_cipher() {
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
|
||||
let a = Simd::<4>::try_from(data.clone()).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let b = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
let b = Simd::<4>::try_from(data).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![2, 4, 6, 8], vec![10, 12, 14, 16]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
assert_eq!(c, add_impl(a, b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_sub_cipher_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
fn sub_impl<T>(a: T, b: T) -> T
|
||||
where
|
||||
T: Sub<T, Output = T>
|
||||
{
|
||||
a - b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn sub(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
sub_impl(a, b)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(sub)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
@@ -162,29 +196,35 @@ fn can_sub_cipher_cipher() {
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
|
||||
let a = Simd::<4>::try_from(data.clone()).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let b = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
let b = Simd::<4>::try_from(data.clone()).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![0, 0, 0, 0], vec![0, 0, 0, 0]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
assert_eq!(c, sub_impl(a, b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mul_cipher_cipher() {
|
||||
fn mul_impl<T>(a: T, b: T) -> T
|
||||
where T: Mul<T, Output = T>
|
||||
{
|
||||
a * b
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn mul(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a * b
|
||||
mul_impl(a, b)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(mul)
|
||||
@@ -199,20 +239,59 @@ fn can_mul_cipher_cipher() {
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
|
||||
let a = Simd::<4>::try_from(data.clone()).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
let b = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
let b = Simd::<4>::try_from(data).unwrap();
|
||||
let b_c = runtime
|
||||
.encrypt(b, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
let args: Vec<CircuitInput> = vec![a_c.into(), b_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![1, 4, 9, 16], vec![25, 36, 49, 64]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
assert_eq!(c, mul_impl(a, b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_neg_cipher_cipher() {
|
||||
fn neg_impl<T>(a: T) -> T
|
||||
where T: Neg<Output = T>
|
||||
{
|
||||
-a
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn mul(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
neg_impl(a)
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(mul)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = Simd::<4>::try_from(data.clone()).unwrap();
|
||||
let a_c = runtime
|
||||
.encrypt(a, &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a_c.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, neg_impl(a));
|
||||
}
|
||||
Reference in New Issue
Block a user