Merge pull request #45 from Sunscreen-tech/rweber/simd

Rweber/simd
This commit is contained in:
rickwebiii
2022-01-26 15:43:30 -08:00
committed by GitHub
21 changed files with 745 additions and 76 deletions

36
.vscode/launch.json vendored
View File

@@ -115,24 +115,6 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unsigned integration tests in library 'sunscreen_compiler'",
"cargo": {
"args": [
"test",
"--no-run",
"--package=sunscreen_compiler",
],
"filter": {
"name": "unsigned",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
@@ -169,6 +151,24 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug simd integration tests in library 'sunscreen_compiler'",
"cargo": {
"args": [
"test",
"--no-run",
"--package=sunscreen_compiler",
],
"filter": {
"name": "simd",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",

View File

@@ -205,11 +205,7 @@ fn alice(
})
}
fn compile_circuits() -> (
CompiledCircuit,
CompiledCircuit,
CompiledCircuit,
) {
fn compile_circuits() -> (CompiledCircuit, CompiledCircuit, CompiledCircuit) {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Fractional<64>>, b: Cipher<Fractional<64>>) -> Cipher<Fractional<64>> {
a + b

View File

@@ -55,17 +55,17 @@ pub(crate) fn validate_nodes(ir: &Circuit) -> Vec<IRError> {
);
}
SubPlaintext => {
errors.append(
&mut validate_binary_op_has_correct_operands(
ir,
i,
OutputType::Ciphertext,
OutputType::Plaintext,
)
.iter()
.map(|e| IRError::NodeError(i, node_info.to_string(), *e))
.collect(),
);
errors.append(
&mut validate_binary_op_has_correct_operands(
ir,
i,
OutputType::Ciphertext,
OutputType::Plaintext,
)
.iter()
.map(|e| IRError::NodeError(i, node_info.to_string(), *e))
.collect(),
);
}
Multiply => {
errors.append(

View File

@@ -438,6 +438,13 @@ impl Context {
self.add_2_input(Operation::RotateRight, left, right)
}
/**
* Adds a row swap.
*/
pub fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::SwapRows, x)
}
/**
* Add a node that captures the previous node as an output.
*/

View File

@@ -3,7 +3,8 @@ use seal::Plaintext as SealPlaintext;
use crate::types::{
ops::{
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
GraphCipherMul, GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherSub, GraphCipherPlainSub, GraphPlainCipherSub, GraphCipherConstSub, GraphConstCipherSub
GraphCipherConstSub, GraphCipherMul, GraphCipherPlainAdd, GraphCipherPlainMul,
GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub, GraphPlainCipherSub,
},
Cipher,
};

View File

@@ -6,4 +6,4 @@ mod simd;
pub use fractional::*;
pub use rational::*;
pub use signed::*;
pub use simd::*;
pub use simd::*;

View File

@@ -1,6 +1,7 @@
use crate::types::{
bfv::Signed, intern::CircuitNode, BfvType, Cipher, FheType, GraphCipherAdd, GraphCipherDiv,
GraphCipherMul, GraphCipherSub, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, TypeName, ops::*,
bfv::Signed, intern::CircuitNode, ops::*, BfvType, Cipher, FheType, GraphCipherAdd,
GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, TryFromPlaintext,
TryIntoPlaintext, TypeName,
};
use crate::{with_ctx, CircuitInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
use std::cmp::Eq;
@@ -157,9 +158,11 @@ impl GraphCipherConstAdd for Rational {
with_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let b_num = ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let b_num =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let b_den = ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
let b_den =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b_den);
@@ -250,13 +253,15 @@ impl GraphCipherConstSub for Rational {
fn graph_cipher_const_sub(
a: CircuitNode<Cipher<Self::Left>>,
b: Self::Right
b: Self::Right,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let b_num = ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let b_den = ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
let b_num =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let b_den =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b_den);
@@ -283,8 +288,10 @@ impl GraphConstCipherSub for Rational {
with_ctx(|ctx| {
let a = Self::try_from(a).unwrap();
let a_num = ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
let a_den = ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
let a_num =
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
let a_den =
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
// Scale each numinator by the other's denominator.
let num_b_2 = ctx.add_multiplication_plaintext(b.ids[0], a_den);
@@ -351,8 +358,10 @@ impl GraphCipherConstMul for Rational {
with_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let num_b = ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let den_b = ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
let num_b =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let den_b =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], num_b);
@@ -436,8 +445,10 @@ impl GraphCipherConstDiv for Rational {
with_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let num_b = ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let den_b = ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
let num_b =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
let den_b =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], den_b);
@@ -461,8 +472,10 @@ impl GraphConstCipherDiv for Rational {
with_ctx(|ctx| {
let a = Self::try_from(a).unwrap();
let num_a = ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
let den_a = ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
let num_a =
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
let den_a =
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(b.ids[0], den_a);
@@ -473,4 +486,4 @@ impl GraphConstCipherDiv for Rational {
CircuitNode::new(&ids)
})
}
}
}

View File

@@ -1,6 +1,320 @@
/**
* A vectorized
*/
pub struct Simd {
use crate::{
crate_version,
types::{
intern::{Cipher, CircuitNode},
ops::*,
BfvType, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName,
TypeNameInstance, Version,
},
with_ctx, CircuitInputTrait, InnerPlaintext, Literal, Params, Plaintext, WithContext,
};
use seal::{
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
Result as SealResult,
};
use sunscreen_runtime::{Error as RuntimeError, Result as RuntimeResult};
}
/**
* A SIMD vector of signed integers. The vector has 2 rows of `LANES`
* columns. The `LANES` value must be a power of 2 up to 16384.
*
* # Remarks
* Plaintexts in the BFV scheme are polynomials. When the plaintext
* modulus is an appropriate prime number, one can decompose the
* cyclotomic field into ideals using the Chinese remainder theorem.
* Each ideal is a value independent of the other and forms a SIMD lane.
*
* In the BFV scheme using a vector encoding, plaintexts encode as a
* `2xN/2` matrix, where N is the scheme's polynomial degree.
* Homomorphic addition, subtraction, and multiplication
* operate element-wise, thus making the scheme similar to CPU SIMD
* instructions (e.g. Intel AVX or ARM Neon) with the minor distinction
* that BFV vector types have 2 rows of values.
*
* Unlike CPU vector instructions, which typically feature 4-16 lanes,
* BFV Simd vectors have thousands of lanes. The LANES values
* effectively demarks a constraint to the compiler that the polynomial
* degree must be at least 2*LANES. Should the compiler choose a larger
* degree for unrelated reasons (e.g. noise budget), the Simd type will
* automatically repeat the lanes so that rotation operations behave
* as if you only have `LANES` elements. For example, if `LANES` is
* 4 (not actually a legal value, but illustrative only!)
*
* To combine values across multiple lanes, one can use rotation
* operations. Unlike a shift, rotation operations cause elements to
* wrap around rather than truncate. The Simd type exposes these as the
* `<<`, `>>`, and `swap_rows` operators:
* * `x << n`, where n is a u64 rotates each row n places to the left.
* For example, `[0, 1, 2, 3; 4, 5, 6, 7] << 3` yields
* `[3, 0, 1, 2; 7, 4, 5, 6]` (note that real vectors have many more
* columns).
* * `x << n`, where n is a u64 rotates each lane n places to the left.
* For example, `[0, 1, 2, 3; 4, 5, 6, 7] >> 1` yields `[3, 0, 1, 2; 7, 4, 5, 6]`.
* * `x.swap_rows()` swaps the rows. For example, `[0, 1, 2, 3; 4, 5, 6, 7].swap_rows()` yields `[4, 5, 6, 7; 0, 1, 2, 3]`.
*
* # Performance
* The BFV scheme is parameterized by a number of values. Generally,
* the polynomial degree has primacy in determining execution time.
* A smaller polynomial degree results in a smaller noise budget, but
* each operation is faster. Additionally, a smaller polynomial degree
* results in fewer SIMD lanes in a plaintext.
*
* To maximally utilize circuit throughput, one should choose a `LANES`
* value equal to half the polynomial degree needed to accomodate the
* circuit's noise budget constraint.
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Simd<const LANES: usize> {
data: [Vec<i64>; 2],
}
impl<const LANES: usize> NumCiphertexts for Simd<LANES> {
const NUM_CIPHERTEXTS: usize = 1;
}
impl<const LANES: usize> TypeName for Simd<LANES> {
fn type_name() -> Type {
Type {
name: format!("sunscreen_compiler::types::Simd<{}>", LANES),
version: Version::parse(crate_version!()).expect("Crate version is not a valid semver"),
is_encrypted: false,
}
}
}
impl<const LANES: usize> TypeNameInstance for Simd<LANES> {
fn type_name_instance(&self) -> Type {
Self::type_name()
}
}
impl<const LANES: usize> CircuitInputTrait for Simd<LANES> {}
impl<const LANES: usize> FheType for Simd<LANES> {}
impl<const LANES: usize> BfvType for Simd<LANES> {}
impl<const LANES: usize> TryIntoPlaintext for Simd<LANES> {
fn try_into_plaintext(
&self,
params: &Params,
) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
if (params.lattice_dimension / 2) as usize % LANES != 0 {
return Err(RuntimeError::FheTypeError(
"LANES must be a power two".to_owned(),
));
}
if 2 * LANES > params.lattice_dimension as usize {
return Err(RuntimeError::FheTypeError(
"LANES must be <= polynomial degree / 2".to_owned(),
));
}
let encryption_params = BfvEncryptionParametersBuilder::new()
.set_poly_modulus_degree(params.lattice_dimension)
.set_plain_modulus(Modulus::new(params.plain_modulus)?)
.set_coefficient_modulus(
params
.coeff_modulus
.iter()
.map(|x| Modulus::new(*x))
.collect::<SealResult<Vec<Modulus>>>()?,
)
.build()?;
let context = SealContext::new(&encryption_params, false, params.security_level)?;
let encoder = BFVEncoder::new(&context)?;
let reps = params.lattice_dimension as usize / (2 * LANES);
let data = [self.data[0].repeat(reps), self.data[1].repeat(reps)].concat();
let plaintext = encoder.encode_signed(&data)?;
Ok(Plaintext {
data_type: Self::type_name(),
inner: InnerPlaintext::Seal(vec![WithContext {
params: params.clone(),
data: plaintext,
}]),
})
}
}
impl<const LANES: usize> TryFromPlaintext for Simd<LANES> {
fn try_from_plaintext(
plaintext: &Plaintext,
params: &Params,
) -> std::result::Result<Self, sunscreen_runtime::Error> {
let plaintext = plaintext.inner_as_seal_plaintext()?;
if plaintext.len() != 1 {
return Err(sunscreen_runtime::Error::FheTypeError(
"Expected 1 plaintext".to_owned(),
));
}
if plaintext[0].params != *params {
return Err(sunscreen_runtime::Error::ParameterMismatch);
}
let encryption_params = BfvEncryptionParametersBuilder::new()
.set_poly_modulus_degree(params.lattice_dimension)
.set_plain_modulus(Modulus::new(params.plain_modulus)?)
.set_coefficient_modulus(
params
.coeff_modulus
.iter()
.map(|x| Modulus::new(*x))
.collect::<SealResult<Vec<Modulus>>>()?,
)
.build()?;
let context = SealContext::new(&encryption_params, false, params.security_level)?;
let encoder = BFVEncoder::new(&context)?;
let data = encoder.decode_signed(&plaintext[0].data)?;
let (row_0, row_1) = data.split_at(params.lattice_dimension as usize / 2);
Ok(Self {
data: [
row_0.iter().take(LANES).map(|x| *x).collect(),
row_1.iter().take(LANES).map(|x| *x).collect(),
],
})
}
}
impl<const LANES: usize> TryFrom<[Vec<i64>; 2]> for Simd<LANES> {
type Error = RuntimeError;
fn try_from(data: [Vec<i64>; 2]) -> RuntimeResult<Self> {
if data[0].len() != data[1].len() || data[0].len() != LANES {
return Err(RuntimeError::FheTypeError(
format!("Invalid SIMD shape. Expected a 2x{} matrix", LANES).to_owned(),
));
}
Ok(Self { data })
}
}
impl<const LANES: usize> Into<[Vec<i64>; 2]> for Simd<LANES> {
fn into(self) -> [Vec<i64>; 2] {
self.data
}
}
impl<const LANES: usize> GraphCipherAdd for Simd<LANES> {
type Left = Self;
type Right = Self;
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherSub for Simd<LANES> {
type Left = Self;
type Right = Self;
fn graph_cipher_sub(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherMul for Simd<LANES> {
type Left = Self;
type Right = Self;
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherSwapRows for Simd<LANES> {
fn graph_cipher_swap_rows(x: CircuitNode<Cipher<Self>>) -> CircuitNode<Cipher<Self>> {
with_ctx(|ctx| {
let n = ctx.add_swap_rows(x.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherRotateLeft for Simd<LANES> {
fn graph_cipher_rotate_left(x: CircuitNode<Cipher<Self>>, y: u64) -> CircuitNode<Cipher<Self>> {
with_ctx(|ctx| {
let y = ctx.add_literal(Literal::U64(y));
let n = ctx.add_rotate_left(x.ids[0], y);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherRotateRight for Simd<LANES> {
fn graph_cipher_rotate_right(
x: CircuitNode<Cipher<Self>>,
y: u64,
) -> CircuitNode<Cipher<Self>> {
with_ctx(|ctx| {
let y = ctx.add_literal(Literal::U64(y));
let n = ctx.add_rotate_right(x.ids[0], y);
CircuitNode::new(&[n])
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::SchemeType;
use seal::{CoefficientModulus, PlainModulus, SecurityLevel};
#[test]
fn can_roundtrip_encode_simd() {
let data = [vec![0, 1, 2, 3], vec![4, 5, 6, 7]];
let params = Params {
lattice_dimension: 4096,
plain_modulus: PlainModulus::batching(4096, 16).unwrap().value(),
coeff_modulus: CoefficientModulus::bfv_default(4096, SecurityLevel::TC128)
.unwrap()
.iter()
.map(|x| x.value())
.collect::<Vec<u64>>(),
scheme_type: SchemeType::Bfv,
security_level: SecurityLevel::TC128,
};
let x = Simd::<4>::try_from(data.clone()).unwrap();
let plaintext = x.try_into_plaintext(&params).unwrap();
let y = Simd::<4>::try_from_plaintext(&plaintext, &params).unwrap();
assert_eq!(x, y);
}
}

View File

@@ -4,7 +4,7 @@ use crate::{
};
use petgraph::stable_graph::NodeIndex;
use std::ops::{Add, Div, Mul, Neg, Sub};
use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub};
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/**
@@ -411,7 +411,6 @@ where
}
}
// literal / cipher
impl<T> Div<CircuitNode<Cipher<T>>> for f64
where
@@ -459,3 +458,52 @@ where
T::graph_cipher_neg(self)
}
}
/**
* A trait that allows data types to swap_rows. E.g. [`Simd`](crate::types::bfv::Simd)
*/
pub trait SwapRows {
/**
* The result type. Typically, this should just be `Self`.
*/
type Output;
/**
* Performs a row swap.
*/
fn swap_rows(self) -> Self::Output;
}
// ciphertext
impl<T> SwapRows for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherSwapRows,
{
type Output = Self;
fn swap_rows(self) -> Self::Output {
T::graph_cipher_swap_rows(self)
}
}
impl<T> Shl<u64> for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherRotateLeft,
{
type Output = Self;
fn shl(self, x: u64) -> Self {
T::graph_cipher_rotate_left(self, x)
}
}
impl<T> Shr<u64> for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherRotateRight,
{
type Output = Self;
fn shr(self, x: u64) -> Self {
T::graph_cipher_rotate_right(self, x)
}
}

View File

@@ -59,7 +59,7 @@ pub mod bfv;
*/
pub mod intern;
/*
/**
* Contains the set of ops traits that dictate legal operations
* for FHE data types.
*/

View File

@@ -51,7 +51,7 @@ pub trait GraphCipherPlainDiv {
}
/**
* Called when a circuit encounters a / operation with a
* Called when a circuit encounters a / operation with a
* plaintext numerator and ciphertext denominator.
*/
pub trait GraphPlainCipherDiv {
@@ -98,7 +98,7 @@ pub trait GraphCipherConstDiv {
}
/**
* Called when a circuit encounters a / operation on a
* Called when a circuit encounters a / operation on a
* literal numerator and encrypted denominator.
*/
pub trait GraphConstCipherDiv {

View File

@@ -2,10 +2,12 @@ mod add;
mod div;
mod mul;
mod neg;
mod rotate;
mod sub;
pub use add::*;
pub use div::*;
pub use mul::*;
pub use neg::*;
pub use rotate::*;
pub use sub::*;

View File

@@ -12,5 +12,8 @@ pub trait GraphCipherNeg {
*/
type Val: FheType;
/**
* Negates the given ciphertext (e.g. -x).
*/
fn graph_cipher_neg(a: CircuitNode<Cipher<Self::Val>>) -> CircuitNode<Cipher<Self::Val>>;
}

View File

@@ -0,0 +1,34 @@
use crate::types::{intern::CircuitNode, Cipher, FheType};
/**
* Swaps the rows of the given ciphertext.
*/
pub trait GraphCipherSwapRows
where
Self: FheType,
{
/**
* Swap the rows in the given ciphertext.
*/
fn graph_cipher_swap_rows(x: CircuitNode<Cipher<Self>>) -> CircuitNode<Cipher<Self>>;
}
pub trait GraphCipherRotateLeft
where
Self: FheType,
{
fn graph_cipher_rotate_left(
x: CircuitNode<Cipher<Self>>,
amount: u64,
) -> CircuitNode<Cipher<Self>>;
}
pub trait GraphCipherRotateRight
where
Self: FheType,
{
fn graph_cipher_rotate_right(
x: CircuitNode<Cipher<Self>>,
amount: u64,
) -> CircuitNode<Cipher<Self>>;
}

View File

@@ -1,7 +1,7 @@
use sunscreen_compiler::{
circuit,
types::{bfv::Rational, Cipher},
Compiler, PlainModulusConstraint, Runtime, CircuitInput
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
};
#[test]
@@ -157,7 +157,6 @@ fn can_add_cipher_literal() {
assert_eq!(c, (-3.14).try_into().unwrap());
}
#[test]
fn can_add_literal_cipher() {
#[circuit(scheme = "bfv")]
@@ -188,7 +187,6 @@ fn can_add_literal_cipher() {
assert_eq!(c, (-3.14).try_into().unwrap());
}
#[test]
fn can_mul_cipher_cipher() {
#[circuit(scheme = "bfv")]
@@ -585,4 +583,4 @@ fn can_sub_literal_cipher() {
let c: Rational = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, (1.64).try_into().unwrap());
}
}

View File

@@ -0,0 +1,218 @@
use sunscreen_compiler::{
circuit,
types::{bfv::Simd, Cipher},
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
};
#[test]
fn can_swap_rows_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a.swap_rows()
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![5, 6, 7, 8], vec![1, 2, 3, 4]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_rotate_left_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a << 1
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![2, 3, 4, 1], vec![6, 7, 8, 5]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_rotate_right_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a >> 1
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![4, 1, 2, 3], vec![8, 5, 6, 7]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_add_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a + b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![2, 4, 6, 8], vec![10, 12, 14, 16]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_sub_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a - b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![0, 0, 0, 0], vec![0, 0, 0, 0]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_mul_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a * b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![1, 4, 9, 16], vec![25, 36, 49, 64]];
assert_eq!(c, expected.try_into().unwrap());
}

View File

@@ -100,7 +100,7 @@ pub fn circuit_impl(
fn build(&self, params: &sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation> {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{CircuitNode, Input}, NumCiphertexts, Type, TypeName, TypeNameInstance}};
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{CircuitNode, Input, SwapRows}, NumCiphertexts, Type, TypeName, TypeNameInstance}};
if SchemeType::Bfv != params.scheme_type {
return Err(Error::IncorrectScheme)

View File

@@ -351,10 +351,7 @@ fn can_collect_output() {
#[test]
fn can_collect_multiple_outputs() {
#[circuit(scheme = "bfv")]
fn circuit_with_args(
a: Cipher<Signed>,
b: Cipher<Signed>,
) -> (Cipher<Signed>, Cipher<Signed>) {
fn circuit_with_args(a: Cipher<Signed>, b: Cipher<Signed>) -> (Cipher<Signed>, Cipher<Signed>) {
(a + b * a, a)
}

View File

@@ -119,6 +119,12 @@ pub enum Error {
* An error occurred when serializing/deserializing with bincode.
*/
BincodeError(String),
/**
* Called [`inner_as_seal_plaintext`](crate::InnerPlaintext.inner_as_seal_plaintext)
* on non-Seal plaintext.
*/
NotASealPlaintext,
}
impl From<bincode::Error> for Error {

View File

@@ -72,6 +72,16 @@ impl InnerPlaintext {
pub fn from_bytes(data: &[u8]) -> Result<Self> {
Ok(bincode::deserialize(data)?)
}
/**
* Unwraps the enum and returns the underlying seal plaintexts, or
* returns an error if this plaintext isn't a Seal plaintext.
*/
pub fn as_seal_plaintext(&self) -> Result<&[WithContext<SealPlaintext>]> {
match self {
Self::Seal(d) => Ok(&d),
}
}
}
#[derive(Clone)]
@@ -118,6 +128,16 @@ pub struct Plaintext {
pub inner: InnerPlaintext,
}
impl Plaintext {
/**
* Unwraps the inner plaintext as a Seal plaintext variant. Returns an
* error if the inner plaintext is not a Seal plaintext.
*/
pub fn inner_as_seal_plaintext(&self) -> Result<&[WithContext<SealPlaintext>]> {
Ok(self.inner.as_seal_plaintext()?)
}
}
#[derive(Clone, Deserialize, Serialize)]
/**
* The underlying backend implementation of a ciphertext (e.g SEAL's [`Ciphertext`](seal::Ciphertext)).

View File

@@ -226,7 +226,7 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
let c = evaluator.multiply(&a, &b)?;
data[index.index()].store(Some(Cow::Owned(c.into())));
},
}
MultiplyPlaintext => {
let (left, right) = get_left_right_operands(ir, index);
@@ -236,8 +236,20 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
let c = evaluator.multiply_plain(&a, &b)?;
data[index.index()].store(Some(Cow::Owned(c.into())));
},
SwapRows => unimplemented!(),
}
SwapRows => {
let galois_keys = galois_keys
.as_ref()
.ok_or(CircuitRunFailure::MissingGaloisKeys)?;
let input = get_unary_operand(ir, index);
let x = get_ciphertext(&data, input.index())?;
let y = evaluator.rotate_columns(&x, galois_keys)?;
data[index.index()].store(Some(Cow::Owned(y.into())));
}
Relinearize => {
let relin_keys = relin_keys
.as_ref()
@@ -259,7 +271,7 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
let y = evaluator.negate(&x)?;
data[index.index()].store(Some(Cow::Owned(y.into())));
},
}
Sub => {
let (left, right) = get_left_right_operands(ir, index);
@@ -269,7 +281,7 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
let c = evaluator.sub(&a, &b)?;
data[index.index()].store(Some(Cow::Owned(c.into())));
},
}
SubPlaintext => {
let (left, right) = get_left_right_operands(ir, index);