Can create SIMD types

This commit is contained in:
Rick Weber
2022-01-26 15:43:01 -08:00
parent c19edf33fa
commit 628c0036a7
12 changed files with 434 additions and 6 deletions

18
.vscode/launch.json vendored
View File

@@ -151,6 +151,24 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug simd integration tests in library 'sunscreen_compiler'",
"cargo": {
"args": [
"test",
"--no-run",
"--package=sunscreen_compiler",
],
"filter": {
"name": "simd",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",

View File

@@ -438,6 +438,13 @@ impl Context {
self.add_2_input(Operation::RotateRight, left, right)
}
/**
* Adds a row swap.
*/
pub fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::SwapRows, x)
}
/**
* Add a node that captures the previous node as an output.
*/

View File

@@ -1,10 +1,12 @@
use crate::{
crate_version,
types::{
intern::{Cipher, CircuitNode},
ops::*,
BfvType, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName,
TypeNameInstance, Version,
},
CircuitInputTrait, InnerPlaintext, Params, Plaintext, WithContext,
with_ctx, CircuitInputTrait, InnerPlaintext, Literal, Params, Plaintext, WithContext,
};
use seal::{
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
@@ -203,6 +205,89 @@ impl<const LANES: usize> Into<[Vec<i64>; 2]> for Simd<LANES> {
}
}
impl<const LANES: usize> GraphCipherAdd for Simd<LANES> {
type Left = Self;
type Right = Self;
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherSub for Simd<LANES> {
type Left = Self;
type Right = Self;
fn graph_cipher_sub(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherMul for Simd<LANES> {
type Left = Self;
type Right = Self;
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherSwapRows for Simd<LANES> {
fn graph_cipher_swap_rows(x: CircuitNode<Cipher<Self>>) -> CircuitNode<Cipher<Self>> {
with_ctx(|ctx| {
let n = ctx.add_swap_rows(x.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherRotateLeft for Simd<LANES> {
fn graph_cipher_rotate_left(x: CircuitNode<Cipher<Self>>, y: u64) -> CircuitNode<Cipher<Self>> {
with_ctx(|ctx| {
let y = ctx.add_literal(Literal::U64(y));
let n = ctx.add_rotate_left(x.ids[0], y);
CircuitNode::new(&[n])
})
}
}
impl<const LANES: usize> GraphCipherRotateRight for Simd<LANES> {
fn graph_cipher_rotate_right(
x: CircuitNode<Cipher<Self>>,
y: u64,
) -> CircuitNode<Cipher<Self>> {
with_ctx(|ctx| {
let y = ctx.add_literal(Literal::U64(y));
let n = ctx.add_rotate_right(x.ids[0], y);
CircuitNode::new(&[n])
})
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -4,7 +4,7 @@ use crate::{
};
use petgraph::stable_graph::NodeIndex;
use std::ops::{Add, Div, Mul, Neg, Sub};
use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub};
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/**
@@ -458,3 +458,52 @@ where
T::graph_cipher_neg(self)
}
}
/**
* A trait that allows data types to swap_rows. E.g. [`Simd`](crate::types::bfv::Simd)
*/
pub trait SwapRows {
/**
* The result type. Typically, this should just be `Self`.
*/
type Output;
/**
* Performs a row swap.
*/
fn swap_rows(self) -> Self::Output;
}
// ciphertext
impl<T> SwapRows for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherSwapRows,
{
type Output = Self;
fn swap_rows(self) -> Self::Output {
T::graph_cipher_swap_rows(self)
}
}
impl<T> Shl<u64> for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherRotateLeft,
{
type Output = Self;
fn shl(self, x: u64) -> Self {
T::graph_cipher_rotate_left(self, x)
}
}
impl<T> Shr<u64> for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherRotateRight,
{
type Output = Self;
fn shr(self, x: u64) -> Self {
T::graph_cipher_rotate_right(self, x)
}
}

View File

@@ -59,7 +59,7 @@ pub mod bfv;
*/
pub mod intern;
/*
/**
* Contains the set of ops traits that dictate legal operations
* for FHE data types.
*/

View File

@@ -2,10 +2,12 @@ mod add;
mod div;
mod mul;
mod neg;
mod rotate;
mod sub;
pub use add::*;
pub use div::*;
pub use mul::*;
pub use neg::*;
pub use rotate::*;
pub use sub::*;

View File

@@ -12,5 +12,8 @@ pub trait GraphCipherNeg {
*/
type Val: FheType;
/**
* Negates the given ciphertext (e.g. -x).
*/
fn graph_cipher_neg(a: CircuitNode<Cipher<Self::Val>>) -> CircuitNode<Cipher<Self::Val>>;
}

View File

@@ -0,0 +1,34 @@
use crate::types::{intern::CircuitNode, Cipher, FheType};
/**
* Swaps the rows of the given ciphertext.
*/
pub trait GraphCipherSwapRows
where
Self: FheType,
{
/**
* Swap the rows in the given ciphertext.
*/
fn graph_cipher_swap_rows(x: CircuitNode<Cipher<Self>>) -> CircuitNode<Cipher<Self>>;
}
pub trait GraphCipherRotateLeft
where
Self: FheType,
{
fn graph_cipher_rotate_left(
x: CircuitNode<Cipher<Self>>,
amount: u64,
) -> CircuitNode<Cipher<Self>>;
}
pub trait GraphCipherRotateRight
where
Self: FheType,
{
fn graph_cipher_rotate_right(
x: CircuitNode<Cipher<Self>>,
amount: u64,
) -> CircuitNode<Cipher<Self>>;
}

View File

@@ -0,0 +1,218 @@
use sunscreen_compiler::{
circuit,
types::{bfv::Simd, Cipher},
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
};
#[test]
fn can_swap_rows_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a.swap_rows()
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![5, 6, 7, 8], vec![1, 2, 3, 4]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_rotate_left_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a << 1
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![2, 3, 4, 1], vec![6, 7, 8, 5]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_rotate_right_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a >> 1
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![4, 1, 2, 3], vec![8, 5, 6, 7]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_add_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a + b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![2, 4, 6, 8], vec![10, 12, 14, 16]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_sub_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a - b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![0, 0, 0, 0], vec![0, 0, 0, 0]];
assert_eq!(c, expected.try_into().unwrap());
}
#[test]
fn can_mul_cipher_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
a * b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
let a = runtime
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
.unwrap();
let b = runtime
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
.unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
let expected = [vec![1, 4, 9, 16], vec![25, 36, 49, 64]];
assert_eq!(c, expected.try_into().unwrap());
}

View File

@@ -100,7 +100,7 @@ pub fn circuit_impl(
fn build(&self, params: &sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation> {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{CircuitNode, Input}, NumCiphertexts, Type, TypeName, TypeNameInstance}};
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{CircuitNode, Input, SwapRows}, NumCiphertexts, Type, TypeName, TypeNameInstance}};
if SchemeType::Bfv != params.scheme_type {
return Err(Error::IncorrectScheme)

View File

@@ -121,7 +121,7 @@ pub enum Error {
BincodeError(String),
/**
* Called [`inner_as_seal_plaintext`](crate::InnerPlaintext::inner_as_seal_plaintext)
* Called [`inner_as_seal_plaintext`](crate::InnerPlaintext.inner_as_seal_plaintext)
* on non-Seal plaintext.
*/
NotASealPlaintext,

View File

@@ -237,7 +237,19 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
data[index.index()].store(Some(Cow::Owned(c.into())));
}
SwapRows => unimplemented!(),
SwapRows => {
let galois_keys = galois_keys
.as_ref()
.ok_or(CircuitRunFailure::MissingGaloisKeys)?;
let input = get_unary_operand(ir, index);
let x = get_ciphertext(&data, input.index())?;
let y = evaluator.rotate_columns(&x, galois_keys)?;
data[index.index()].store(Some(Cow::Owned(y.into())));
}
Relinearize => {
let relin_keys = relin_keys
.as_ref()