mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
Can create SIMD types
This commit is contained in:
18
.vscode/launch.json
vendored
18
.vscode/launch.json
vendored
@@ -151,6 +151,24 @@
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug simd integration tests in library 'sunscreen_compiler'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--package=sunscreen_compiler",
|
||||
],
|
||||
"filter": {
|
||||
"name": "simd",
|
||||
"kind": "test"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
|
||||
@@ -438,6 +438,13 @@ impl Context {
|
||||
self.add_2_input(Operation::RotateRight, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a row swap.
|
||||
*/
|
||||
pub fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.add_1_input(Operation::SwapRows, x)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a node that captures the previous node as an output.
|
||||
*/
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use crate::{
|
||||
crate_version,
|
||||
types::{
|
||||
intern::{Cipher, CircuitNode},
|
||||
ops::*,
|
||||
BfvType, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName,
|
||||
TypeNameInstance, Version,
|
||||
},
|
||||
CircuitInputTrait, InnerPlaintext, Params, Plaintext, WithContext,
|
||||
with_ctx, CircuitInputTrait, InnerPlaintext, Literal, Params, Plaintext, WithContext,
|
||||
};
|
||||
use seal::{
|
||||
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
|
||||
@@ -203,6 +205,89 @@ impl<const LANES: usize> Into<[Vec<i64>; 2]> for Simd<LANES> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherAdd for Simd<LANES> {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_cipher_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherSub for Simd<LANES> {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_cipher_sub(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherMul for Simd<LANES> {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_cipher_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherSwapRows for Simd<LANES> {
|
||||
fn graph_cipher_swap_rows(x: CircuitNode<Cipher<Self>>) -> CircuitNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_swap_rows(x.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherRotateLeft for Simd<LANES> {
|
||||
fn graph_cipher_rotate_left(x: CircuitNode<Cipher<Self>>, y: u64) -> CircuitNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
let y = ctx.add_literal(Literal::U64(y));
|
||||
let n = ctx.add_rotate_left(x.ids[0], y);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> GraphCipherRotateRight for Simd<LANES> {
|
||||
fn graph_cipher_rotate_right(
|
||||
x: CircuitNode<Cipher<Self>>,
|
||||
y: u64,
|
||||
) -> CircuitNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
let y = ctx.add_literal(Literal::U64(y));
|
||||
let n = ctx.add_rotate_right(x.ids[0], y);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
use std::ops::{Add, Div, Mul, Neg, Sub};
|
||||
use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
/**
|
||||
@@ -458,3 +458,52 @@ where
|
||||
T::graph_cipher_neg(self)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait that allows data types to swap_rows. E.g. [`Simd`](crate::types::bfv::Simd)
|
||||
*/
|
||||
pub trait SwapRows {
|
||||
/**
|
||||
* The result type. Typically, this should just be `Self`.
|
||||
*/
|
||||
type Output;
|
||||
|
||||
/**
|
||||
* Performs a row swap.
|
||||
*/
|
||||
fn swap_rows(self) -> Self::Output;
|
||||
}
|
||||
|
||||
// ciphertext
|
||||
impl<T> SwapRows for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphCipherSwapRows,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn swap_rows(self) -> Self::Output {
|
||||
T::graph_cipher_swap_rows(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Shl<u64> for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphCipherRotateLeft,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn shl(self, x: u64) -> Self {
|
||||
T::graph_cipher_rotate_left(self, x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Shr<u64> for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphCipherRotateRight,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn shr(self, x: u64) -> Self {
|
||||
T::graph_cipher_rotate_right(self, x)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ pub mod bfv;
|
||||
*/
|
||||
pub mod intern;
|
||||
|
||||
/*
|
||||
/**
|
||||
* Contains the set of ops traits that dictate legal operations
|
||||
* for FHE data types.
|
||||
*/
|
||||
|
||||
@@ -2,10 +2,12 @@ mod add;
|
||||
mod div;
|
||||
mod mul;
|
||||
mod neg;
|
||||
mod rotate;
|
||||
mod sub;
|
||||
|
||||
pub use add::*;
|
||||
pub use div::*;
|
||||
pub use mul::*;
|
||||
pub use neg::*;
|
||||
pub use rotate::*;
|
||||
pub use sub::*;
|
||||
|
||||
@@ -12,5 +12,8 @@ pub trait GraphCipherNeg {
|
||||
*/
|
||||
type Val: FheType;
|
||||
|
||||
/**
|
||||
* Negates the given ciphertext (e.g. -x).
|
||||
*/
|
||||
fn graph_cipher_neg(a: CircuitNode<Cipher<Self::Val>>) -> CircuitNode<Cipher<Self::Val>>;
|
||||
}
|
||||
|
||||
34
sunscreen_compiler/src/types/ops/rotate.rs
Normal file
34
sunscreen_compiler/src/types/ops/rotate.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use crate::types::{intern::CircuitNode, Cipher, FheType};
|
||||
|
||||
/**
|
||||
* Swaps the rows of the given ciphertext.
|
||||
*/
|
||||
pub trait GraphCipherSwapRows
|
||||
where
|
||||
Self: FheType,
|
||||
{
|
||||
/**
|
||||
* Swap the rows in the given ciphertext.
|
||||
*/
|
||||
fn graph_cipher_swap_rows(x: CircuitNode<Cipher<Self>>) -> CircuitNode<Cipher<Self>>;
|
||||
}
|
||||
|
||||
pub trait GraphCipherRotateLeft
|
||||
where
|
||||
Self: FheType,
|
||||
{
|
||||
fn graph_cipher_rotate_left(
|
||||
x: CircuitNode<Cipher<Self>>,
|
||||
amount: u64,
|
||||
) -> CircuitNode<Cipher<Self>>;
|
||||
}
|
||||
|
||||
pub trait GraphCipherRotateRight
|
||||
where
|
||||
Self: FheType,
|
||||
{
|
||||
fn graph_cipher_rotate_right(
|
||||
x: CircuitNode<Cipher<Self>>,
|
||||
amount: u64,
|
||||
) -> CircuitNode<Cipher<Self>>;
|
||||
}
|
||||
218
sunscreen_compiler/tests/simd.rs
Normal file
218
sunscreen_compiler/tests/simd.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use sunscreen_compiler::{
|
||||
circuit,
|
||||
types::{bfv::Simd, Cipher},
|
||||
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn can_swap_rows_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a.swap_rows()
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![5, 6, 7, 8], vec![1, 2, 3, 4]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_rotate_left_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a << 1
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![2, 3, 4, 1], vec![6, 7, 8, 5]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_rotate_right_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a >> 1
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![4, 1, 2, 3], vec![8, 5, 6, 7]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_cipher_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a + b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
|
||||
.unwrap();
|
||||
let b = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![2, 4, 6, 8], vec![10, 12, 14, 16]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_sub_cipher_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a - b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
|
||||
.unwrap();
|
||||
let b = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![0, 0, 0, 0], vec![0, 0, 0, 0]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mul_cipher_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Simd<4>>, b: Cipher<Simd<4>>) -> Cipher<Simd<4>> {
|
||||
a * b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(0))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let data = [vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
|
||||
|
||||
let a = runtime
|
||||
.encrypt(Simd::<4>::try_from(data.clone()).unwrap(), &public)
|
||||
.unwrap();
|
||||
let b = runtime
|
||||
.encrypt(Simd::<4>::try_from(data).unwrap(), &public)
|
||||
.unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Simd<4> = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
let expected = [vec![1, 4, 9, 16], vec![25, 36, 49, 64]];
|
||||
|
||||
assert_eq!(c, expected.try_into().unwrap());
|
||||
}
|
||||
@@ -100,7 +100,7 @@ pub fn circuit_impl(
|
||||
fn build(&self, params: &sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation> {
|
||||
use std::cell::RefCell;
|
||||
use std::mem::transmute;
|
||||
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{CircuitNode, Input}, NumCiphertexts, Type, TypeName, TypeNameInstance}};
|
||||
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{CircuitNode, Input, SwapRows}, NumCiphertexts, Type, TypeName, TypeNameInstance}};
|
||||
|
||||
if SchemeType::Bfv != params.scheme_type {
|
||||
return Err(Error::IncorrectScheme)
|
||||
|
||||
@@ -121,7 +121,7 @@ pub enum Error {
|
||||
BincodeError(String),
|
||||
|
||||
/**
|
||||
* Called [`inner_as_seal_plaintext`](crate::InnerPlaintext::inner_as_seal_plaintext)
|
||||
* Called [`inner_as_seal_plaintext`](crate::InnerPlaintext.inner_as_seal_plaintext)
|
||||
* on non-Seal plaintext.
|
||||
*/
|
||||
NotASealPlaintext,
|
||||
|
||||
@@ -237,7 +237,19 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
|
||||
|
||||
data[index.index()].store(Some(Cow::Owned(c.into())));
|
||||
}
|
||||
SwapRows => unimplemented!(),
|
||||
SwapRows => {
|
||||
let galois_keys = galois_keys
|
||||
.as_ref()
|
||||
.ok_or(CircuitRunFailure::MissingGaloisKeys)?;
|
||||
|
||||
let input = get_unary_operand(ir, index);
|
||||
|
||||
let x = get_ciphertext(&data, input.index())?;
|
||||
|
||||
let y = evaluator.rotate_columns(&x, galois_keys)?;
|
||||
|
||||
data[index.index()].store(Some(Cow::Owned(y.into())));
|
||||
}
|
||||
Relinearize => {
|
||||
let relin_keys = relin_keys
|
||||
.as_ref()
|
||||
|
||||
Reference in New Issue
Block a user