Merge pull request #47 from Sunscreen-tech/rweber/nonFhe

Rweber/non fhe
This commit is contained in:
rickwebiii
2022-01-27 19:00:11 -08:00
committed by GitHub
7 changed files with 503 additions and 187 deletions

View File

@@ -5,8 +5,8 @@ use crate::types::{
};
use crate::{with_ctx, CircuitInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
use std::cmp::Eq;
use sunscreen_runtime::Error;
use std::ops::*;
use sunscreen_runtime::Error;
use num::Rational64;
@@ -108,7 +108,7 @@ impl Add for Rational {
fn add(self, rhs: Self) -> Self::Output {
Self::Output {
num: self.num * rhs.den + rhs.num * self.den,
den: self.den * rhs.den
den: self.den * rhs.den,
}
}
}
@@ -121,7 +121,7 @@ impl Add<f64> for Rational {
Self::Output {
num: self.num * rhs.den + rhs.num * self.den,
den: self.den * rhs.den
den: self.den * rhs.den,
}
}
}
@@ -134,7 +134,7 @@ impl Add<Rational> for f64 {
Self::Output {
num: lhs.num * rhs.den + rhs.num * lhs.den,
den: lhs.den * rhs.den
den: lhs.den * rhs.den,
}
}
}
@@ -145,7 +145,7 @@ impl Mul for Rational {
fn mul(self, rhs: Self) -> Self::Output {
Self::Output {
num: self.num * rhs.num,
den: self.den * rhs.den
den: self.den * rhs.den,
}
}
}
@@ -182,7 +182,7 @@ impl Sub for Rational {
fn sub(self, rhs: Self) -> Self::Output {
Self::Output {
num: self.num * rhs.den - rhs.num * self.den,
den: self.den * rhs.den
den: self.den * rhs.den,
}
}
}
@@ -195,7 +195,7 @@ impl Sub<f64> for Rational {
Self::Output {
num: self.num * rhs.den - rhs.num * self.den,
den: self.den * rhs.den
den: self.den * rhs.den,
}
}
}
@@ -208,7 +208,7 @@ impl Sub<Rational> for f64 {
Self::Output {
num: lhs.num * rhs.den - rhs.num * lhs.den,
den: lhs.den * rhs.den
den: lhs.den * rhs.den,
}
}
}
@@ -219,7 +219,7 @@ impl Div for Rational {
fn div(self, rhs: Self) -> Self::Output {
Self::Output {
num: self.num * rhs.den,
den: self.den * rhs.num
den: self.den * rhs.num,
}
}
}
@@ -232,7 +232,7 @@ impl Div<f64> for Rational {
Self::Output {
num: self.num * rhs.den,
den: self.den * rhs.num
den: self.den * rhs.num,
}
}
}
@@ -245,7 +245,7 @@ impl Div<Rational> for f64 {
Self::Output {
num: lhs.num * rhs.den,
den: lhs.den * rhs.num
den: lhs.den * rhs.num,
}
}
}
@@ -256,7 +256,7 @@ impl Neg for Rational {
fn neg(self) -> Self::Output {
Self::Output {
num: -self.num,
den: self.den
den: self.den,
}
}
}
@@ -651,9 +651,7 @@ impl GraphConstCipherDiv for Rational {
impl GraphCipherNeg for Rational {
type Val = Self;
fn graph_cipher_neg(
a: CircuitNode<Cipher<Self::Val>>,
) -> CircuitNode<Cipher<Self::Val>> {
fn graph_cipher_neg(a: CircuitNode<Cipher<Self::Val>>) -> CircuitNode<Cipher<Self::Val>> {
with_ctx(|ctx| {
let neg = ctx.add_negate(a.ids[0]);
let ids = [neg, a.ids[1]];
@@ -713,4 +711,4 @@ mod tests {
assert_eq!(-a, (-5.).try_into().unwrap());
}
}
}

View File

@@ -136,7 +136,7 @@ impl Add for Signed {
fn add(self, rhs: Self) -> Self::Output {
Self::Output {
val: self.val + rhs.val
val: self.val + rhs.val,
}
}
}
@@ -146,7 +146,7 @@ impl Add<i64> for Signed {
fn add(self, rhs: i64) -> Self::Output {
Self {
val: self.val + rhs
val: self.val + rhs,
}
}
}
@@ -156,7 +156,7 @@ impl Add<Signed> for i64 {
fn add(self, rhs: Signed) -> Self::Output {
Self::Output {
val: self + rhs.val
val: self + rhs.val,
}
}
}
@@ -166,7 +166,7 @@ impl Mul for Signed {
fn mul(self, rhs: Self) -> Self::Output {
Self::Output {
val: self.val * rhs.val
val: self.val * rhs.val,
}
}
}
@@ -176,7 +176,7 @@ impl Mul<i64> for Signed {
fn mul(self, rhs: i64) -> Self::Output {
Self {
val: self.val * rhs
val: self.val * rhs,
}
}
}
@@ -186,7 +186,7 @@ impl Mul<Signed> for i64 {
fn mul(self, rhs: Signed) -> Self::Output {
Self::Output {
val: self * rhs.val
val: self * rhs.val,
}
}
}
@@ -196,7 +196,7 @@ impl Sub for Signed {
fn sub(self, rhs: Self) -> Self::Output {
Self::Output {
val: self.val - rhs.val
val: self.val - rhs.val,
}
}
}
@@ -206,7 +206,7 @@ impl Sub<i64> for Signed {
fn sub(self, rhs: i64) -> Self::Output {
Self {
val: self.val - rhs
val: self.val - rhs,
}
}
}
@@ -216,7 +216,7 @@ impl Sub<Signed> for i64 {
fn sub(self, rhs: Signed) -> Self::Output {
Self::Output {
val: self - rhs.val
val: self - rhs.val,
}
}
}
@@ -225,9 +225,7 @@ impl Neg for Signed {
type Output = Self;
fn neg(self) -> Self::Output {
Self::Output {
val: -self.val
}
Self::Output { val: -self.val }
}
}
@@ -473,4 +471,4 @@ mod tests {
assert_eq!(-a, (-5).into());
}
}
}

View File

@@ -1,7 +1,7 @@
use crate::{
crate_version,
types::{
intern::{Cipher, CircuitNode},
intern::{Cipher, CircuitNode, SwapRows},
ops::*,
BfvType, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName,
TypeNameInstance, Version,
@@ -12,6 +12,7 @@ use seal::{
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
Result as SealResult,
};
use std::ops::*;
use sunscreen_runtime::{Error as RuntimeError, Result as RuntimeResult};
/**
@@ -63,9 +64,9 @@ use sunscreen_runtime::{Error as RuntimeError, Result as RuntimeResult};
* value equal to half the polynomial degree needed to accomodate the
* circuit's noise budget constraint.
*/
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Simd<const LANES: usize> {
data: [Vec<i64>; 2],
data: [[i64; LANES]; 2],
}
impl<const LANES: usize> NumCiphertexts for Simd<LANES> {
@@ -178,8 +179,30 @@ impl<const LANES: usize> TryFromPlaintext for Simd<LANES> {
Ok(Self {
data: [
row_0.iter().take(LANES).map(|x| *x).collect(),
row_1.iter().take(LANES).map(|x| *x).collect(),
row_0
.iter()
.take(LANES)
.map(|x| *x)
.collect::<Vec<i64>>()
.try_into()
.map_err(|_| {
RuntimeError::FheTypeError(format!(
"Failed to convert Vec to [i64;{}]",
LANES
))
})?,
row_1
.iter()
.take(LANES)
.map(|x| *x)
.collect::<Vec<i64>>()
.try_into()
.map_err(|_| {
RuntimeError::FheTypeError(format!(
"Failed to convert Vec to [i64;{}]",
LANES
))
})?,
],
})
}
@@ -189,19 +212,217 @@ impl<const LANES: usize> TryFrom<[Vec<i64>; 2]> for Simd<LANES> {
type Error = RuntimeError;
fn try_from(data: [Vec<i64>; 2]) -> RuntimeResult<Self> {
if data[0].len() != data[1].len() || data[0].len() != LANES {
return Err(RuntimeError::FheTypeError(
format!("Invalid SIMD shape. Expected a 2x{} matrix", LANES).to_owned(),
));
}
Ok(Self { data })
Ok(Self {
data: [
data[0].clone().try_into().map_err(|_| {
RuntimeError::FheTypeError(format!("Failed to convert Vec to [i64;{}]", LANES))
})?,
data[1].clone().try_into().map_err(|_| {
RuntimeError::FheTypeError(format!("Failed to convert Vec to [i64;{}]", LANES))
})?,
],
})
}
}
impl<const LANES: usize> Into<[Vec<i64>; 2]> for Simd<LANES> {
fn into(self) -> [Vec<i64>; 2] {
self.data
[self.data[0].into(), self.data[1].into()]
}
}
impl<const LANES: usize> From<[[i64; LANES]; 2]> for Simd<LANES> {
fn from(data: [[i64; LANES]; 2]) -> Self {
Self { data }
}
}
impl<const LANES: usize> Into<[[i64; LANES]; 2]> for Simd<LANES> {
fn into(self) -> [[i64; LANES]; 2] {
[self.data[0], self.data[1]]
}
}
impl<const LANES: usize> Add for Simd<LANES> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
let r_0: [i64; LANES] = self.data[0]
.iter()
.zip(rhs.data[0].iter())
.map(|(x, y)| x + y)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
let r_1: [i64; LANES] = self.data[1]
.iter()
.zip(rhs.data[1].iter())
.map(|(x, y)| x + y)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
Self { data: [r_0, r_1] }
}
}
impl<const LANES: usize> Sub for Simd<LANES> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
let r_0: [i64; LANES] = self.data[0]
.iter()
.zip(rhs.data[0].iter())
.map(|(x, y)| x - y)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
let r_1: [i64; LANES] = self.data[1]
.iter()
.zip(rhs.data[1].iter())
.map(|(x, y)| x - y)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
Self { data: [r_0, r_1] }
}
}
impl<const LANES: usize> Mul for Simd<LANES> {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
let r_0: [i64; LANES] = self.data[0]
.iter()
.zip(rhs.data[0].iter())
.map(|(x, y)| x * y)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
let r_1: [i64; LANES] = self.data[1]
.iter()
.zip(rhs.data[1].iter())
.map(|(x, y)| x * y)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
Self { data: [r_0, r_1] }
}
}
impl<const LANES: usize> Neg for Simd<LANES> {
type Output = Self;
fn neg(self) -> Self::Output {
let r_0: [i64; LANES] = self.data[0]
.iter()
.map(|x| -x)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
let r_1: [i64; LANES] = self.data[1]
.iter()
.map(|x| -x)
.collect::<Vec<i64>>()
.try_into()
.unwrap();
Self { data: [r_0, r_1] }
}
}
impl<const LANES: usize> Shl<u64> for Simd<LANES> {
type Output = Self;
fn shl(self, x: u64) -> Self::Output {
let r_0: [i64; LANES] = [
self.data[0]
.iter()
.skip(x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
self.data[0]
.iter()
.take(x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
]
.concat()
.try_into()
.unwrap();
let r_1: [i64; LANES] = [
self.data[1]
.iter()
.skip(x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
self.data[1]
.iter()
.take(x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
]
.concat()
.try_into()
.unwrap();
Self { data: [r_0, r_1] }
}
}
impl<const LANES: usize> Shr<u64> for Simd<LANES> {
type Output = Self;
fn shr(self, x: u64) -> Self::Output {
let r_0: [i64; LANES] = [
self.data[0]
.iter()
.skip(LANES - x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
self.data[0]
.iter()
.take(LANES - x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
]
.concat()
.try_into()
.unwrap();
let r_1: [i64; LANES] = [
self.data[1]
.iter()
.skip(LANES - x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
self.data[1]
.iter()
.take(LANES -x as usize)
.map(|x| *x)
.collect::<Vec<i64>>(),
]
.concat()
.try_into()
.unwrap();
Self { data: [r_0, r_1] }
}
}
impl<const LANES: usize> SwapRows for Simd<LANES> {
type Output = Self;
fn swap_rows(self) -> Self::Output {
Self {
data: [
self.data[1],
self.data[0]
]
}
}
}
@@ -288,6 +509,19 @@ impl<const LANES: usize> GraphCipherRotateRight for Simd<LANES> {
}
}
impl<const LANES: usize> GraphCipherNeg for Simd<LANES> {
type Val = Self;
fn graph_cipher_neg(
x: CircuitNode<Cipher<Self>>,
) -> CircuitNode<Cipher<Self::Val>> {
with_ctx(|ctx| {
let n = ctx.add_negate(x.ids[0]);
CircuitNode::new(&[n])
})
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -317,4 +551,59 @@ mod tests {
assert_eq!(x, y);
}
const A_VEC: [[i64; 4]; 2] = [[1, 2, 3, 4], [5, 6, 7, 8]];
const B_VEC: [[i64; 4]; 2] = [[5, 6, 7, 8], [1, 2, 3, 4]];
#[test]
fn can_add_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
let b = Simd::<4>::try_from(B_VEC).unwrap();
assert_eq!(a + b, [[6, 8, 10, 12], [6, 8, 10, 12]].into());
}
#[test]
fn can_mul_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
let b = Simd::<4>::try_from(B_VEC).unwrap();
assert_eq!(a * b, [[5, 12, 21, 32], [5, 12, 21, 32]].into());
}
#[test]
fn can_sub_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
let b = Simd::<4>::try_from(B_VEC).unwrap();
assert_eq!(a - b, [[-4, -4, -4, -4], [4, 4, 4, 4]].into());
}
#[test]
fn can_neg_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
assert_eq!(-a, [[-1, -2, -3, -4], [-5, -6, -7, -8]].into());
}
#[test]
fn can_shl_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
assert_eq!(a << 3, [[4, 1, 2, 3], [8, 5, 6, 7]].into());
}
#[test]
fn can_shr_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
assert_eq!(a >> 3, [[2, 3, 4, 1], [6, 7, 8, 5]].into());
}
#[test]
fn can_swap_rows_non_fhe() {
let a = Simd::<4>::try_from(A_VEC).unwrap();
assert_eq!(a.swap_rows(), [[5, 6, 7, 8], [1, 2, 3, 4]].into());
}
}

View File

@@ -104,5 +104,4 @@ where
trait Foo {}
impl<T> Foo for T
where T: FheType { }
impl<T> Foo for T where T: FheType {}

View File

@@ -36,9 +36,9 @@ fn can_encode_rational_numbers() {
type CipherRational = Cipher<Rational>;
fn add_impl<T, U, R>(a: T, b: U) -> R
fn add_impl<T, U, R>(a: T, b: U) -> R
where
T: Add<U, Output = R>
T: Add<U, Output = R>,
{
a + b
}
@@ -61,13 +61,9 @@ fn can_add_cipher_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
@@ -94,9 +90,7 @@ fn can_add_cipher_plain() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
@@ -126,9 +120,7 @@ fn can_add_plain_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
@@ -158,9 +150,7 @@ fn can_add_cipher_literal() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -189,9 +179,7 @@ fn can_add_literal_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -202,9 +190,9 @@ fn can_add_literal_cipher() {
assert_eq!(c, add_impl(3.14, a));
}
fn sub_impl<T, U, R>(a: T, b: U) -> R
fn sub_impl<T, U, R>(a: T, b: U) -> R
where
T: Sub<U, Output = R>
T: Sub<U, Output = R>,
{
a - b
}
@@ -227,13 +215,9 @@ fn can_sub_cipher_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
@@ -260,9 +244,7 @@ fn can_sub_cipher_plain() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
@@ -292,11 +274,9 @@ fn can_sub_plain_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b_c.into()];
@@ -325,9 +305,7 @@ fn can_sub_cipher_literal() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -356,9 +334,7 @@ fn can_sub_literal_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -369,9 +345,9 @@ fn can_sub_literal_cipher() {
assert_eq!(c, sub_impl(3.14, a));
}
fn mul_impl<T, U, R>(a: T, b: U) -> R
fn mul_impl<T, U, R>(a: T, b: U) -> R
where
T: Mul<U, Output = R>
T: Mul<U, Output = R>,
{
a * b
}
@@ -394,13 +370,9 @@ fn can_mul_cipher_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
@@ -427,9 +399,7 @@ fn can_mul_cipher_plain() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
@@ -459,11 +429,9 @@ fn can_mul_plain_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b_c.into()];
@@ -492,9 +460,7 @@ fn can_mul_cipher_literal() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -523,9 +489,7 @@ fn can_mul_literal_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -536,9 +500,9 @@ fn can_mul_literal_cipher() {
assert_eq!(c, mul_impl(3.14, a));
}
fn div_impl<T, U, R>(a: T, b: U) -> R
fn div_impl<T, U, R>(a: T, b: U) -> R
where
T: Mul<U, Output = R>
T: Mul<U, Output = R>,
{
a * b
}
@@ -561,13 +525,9 @@ fn can_div_cipher_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap();
@@ -594,9 +554,7 @@ fn can_div_cipher_plain() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let b = Rational::try_from(6.28).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into(), b.into()];
@@ -626,11 +584,9 @@ fn can_div_plain_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-3.14).unwrap();
let b = Rational::try_from(6.28).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let b_c = runtime.encrypt(b, &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b_c.into()];
@@ -659,9 +615,7 @@ fn can_div_cipher_literal() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -690,9 +644,7 @@ fn can_div_literal_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -705,9 +657,9 @@ fn can_div_literal_cipher() {
#[test]
fn can_neg_cipher() {
fn neg_impl<T>(x: T) -> T
fn neg_impl<T>(x: T) -> T
where
T: Neg<Output = T>
T: Neg<Output = T>,
{
-x
}
@@ -728,9 +680,7 @@ fn can_neg_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = Rational::try_from(-6.28).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let a_c = runtime.encrypt(a, &public).unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
@@ -739,4 +689,4 @@ fn can_neg_cipher() {
let c: Rational = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, neg_impl(a));
}
}

View File

@@ -6,7 +6,8 @@ use sunscreen_compiler::{
use std::ops::*;
fn add_fn<T, U, R>(a: T, b: U) -> R where
fn add_fn<T, U, R>(a: T, b: U) -> R
where
T: Add<U, Output = R>,
{
a + b
@@ -128,7 +129,8 @@ fn can_add_literal_cipher() {
assert_eq!(c, add_fn(-4, a));
}
fn sub_fn<T, U, R>(a: T, b: U) -> R where
fn sub_fn<T, U, R>(a: T, b: U) -> R
where
T: Sub<U, Output = R>,
{
a - b
@@ -250,7 +252,8 @@ fn can_sub_literal_cipher() {
assert_eq!(c, sub_fn(-4, a));
}
fn mul_fn<T, U, R>(a: T, b: U) -> R where
fn mul_fn<T, U, R>(a: T, b: U) -> R
where
T: Mul<U, Output = R>,
{
a * b
@@ -370,4 +373,4 @@ fn can_mul_literal_cipher() {
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, mul_fn(-4, a));
}
}

View File

@@ -1,17 +1,26 @@
use sunscreen_compiler::{
circuit,
types::{bfv::Simd, Cipher},
types::{bfv::Simd, Cipher, intern::SwapRows,
},
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
};
use std::ops::*;
#[test]
fn can_swap_rows_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
fn swap_impl<T>(a: T) -> T
where T: SwapRows<Output = T>
{
a.swap_rows()
}
let circuit = Compiler::with_circuit(add)
#[circuit(scheme = "bfv")]
fn swap_rows(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
swap_impl(a)
}
let circuit = Compiler::with_circuit(swap_rows)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
@@ -23,11 +32,12 @@ fn can_swap_rows_cipher() {
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
let a = Simd::<4>::try_from(data).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let args: Vec<CircuitInput> = vec![a_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
@@ -35,14 +45,21 @@ fn can_swap_rows_cipher() {
let expected = [vec![5, 6, 7, 8], vec![1, 2, 3, 4]];
assert_eq!(c, swap_impl(a));
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_rotate_left_cipher() {
fn shl_impl<T>(x: T, y: u64) -> T
where T: Shl<u64, Output = T>
{
x << y
}
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a << 1
shl_impl(a, 1)
}
let circuit = Compiler::with_circuit(add)
@@ -57,26 +74,31 @@ fn can_rotate_left_cipher() {
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
let a = Simd::<4>::try_from(data).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let args: Vec<CircuitInput> = vec![a_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![2, 3, 4, 1], vec![6, 7, 8, 5]];
assert_eq!(c, expected.try_into().unwrap());
assert_eq!(c, shl_impl(a, 1));
}
#[test]
fn can_rotate_right_cipher() {
fn shr_impl<T>(x: T, y: u64) -> T
where T: Shr<u64, Output = T>
{
x >> y
}
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a >> 1
shr_impl(a, 1)
}
let circuit = Compiler::with_circuit(add)
@@ -91,26 +113,31 @@ fn can_rotate_right_cipher() {
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
let a = Simd::<4>::try_from(data).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let args: Vec<CircuitInput> = vec![a_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![4, 1, 2, 3], vec![8, 5, 6, 7]];
assert_eq!(c, expected.try_into().unwrap());
assert_eq!(c, shr_impl(a, 1));
}
#[test]
fn can_add_cipher_cipher() {
fn add_impl<T>(a: T, b: T) -> T
where T: Add<T, Output = T>
{
a + b
}
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a + b
add_impl(a, b)
}
let circuit = Compiler::with_circuit(add)
@@ -125,32 +152,39 @@ fn can_add_cipher_cipher() {
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
let a = Simd::<4>::try_from(data.clone()).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
let b = Simd::<4>::try_from(data).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let args: Vec<CircuitInput> = vec![a_c.into(), b_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![2, 4, 6, 8], vec![10, 12, 14, 16]];
assert_eq!(c, expected.try_into().unwrap());
assert_eq!(c, add_impl(a, b));
}
#[test]
fn can_sub_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
fn sub_impl<T>(a: T, b: T) -> T
where
T: Sub<T, Output = T>
{
a - b
}
let circuit = Compiler::with_circuit(add)
#[circuit(scheme = "bfv")]
fn sub(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
sub_impl(a, b)
}
let circuit = Compiler::with_circuit(sub)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
@@ -162,29 +196,35 @@ fn can_sub_cipher_cipher() {
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
let a = Simd::<4>::try_from(data.clone()).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
let b = Simd::<4>::try_from(data.clone()).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let args: Vec<CircuitInput> = vec![a_c.into(), b_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![0, 0, 0, 0], vec![0, 0, 0, 0]];
assert_eq!(c, expected.try_into().unwrap());
assert_eq!(c, sub_impl(a, b));
}
#[test]
fn can_mul_cipher_cipher() {
fn mul_impl<T>(a: T, b: T) -> T
where T: Mul<T, Output = T>
{
a * b
}
#[circuit(scheme = "bfv")]
fn mul(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a * b
mul_impl(a, b)
}
let circuit = Compiler::with_circuit(mul)
@@ -199,20 +239,59 @@ fn can_mul_cipher_cipher() {
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
let a = Simd::<4>::try_from(data.clone()).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
let b = Simd::<4>::try_from(data).unwrap();
let b_c = runtime
.encrypt(b, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let args: Vec<CircuitInput> = vec![a_c.into(), b_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![1, 4, 9, 16], vec![25, 36, 49, 64]];
assert_eq!(c, expected.try_into().unwrap());
assert_eq!(c, mul_impl(a, b));
}
#[test]
fn can_neg_cipher_cipher() {
fn neg_impl<T>(a: T) -> T
where T: Neg<Output = T>
{
-a
}
#[circuit(scheme = "bfv")]
fn mul(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
neg_impl(a)
}
let circuit = Compiler::with_circuit(mul)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = Simd::<4>::try_from(data.clone()).unwrap();
let a_c = runtime
.encrypt(a, &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a_c.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, neg_impl(a));
}