Merge pull request #31 from Sunscreen-tech/rweber/cipher

Rweber/cipher
This commit is contained in:
rickwebiii
2022-01-19 15:47:04 -08:00
committed by GitHub
14 changed files with 398 additions and 238 deletions

View File

@@ -2,7 +2,7 @@ use std::io::{self, Write};
use std::sync::mpsc::{Receiver, Sender};
use std::thread::{self, JoinHandle};
use sunscreen_compiler::{
circuit, types::Rational, Ciphertext, CompiledCircuit, Compiler, Params,
circuit, types::{Cipher, Rational}, Ciphertext, CompiledCircuit, Compiler, Params,
PlainModulusConstraint, PublicKey, Runtime, RuntimeError,
};
@@ -192,22 +192,22 @@ fn compile_circuits() -> (
CompiledCircuit,
) {
#[circuit(scheme = "bfv")]
fn add(a: Rational, b: Rational) -> Rational {
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a + b
}
#[circuit(scheme = "bfv")]
fn sub(a: Rational, b: Rational) -> Rational {
fn sub(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a - b
}
#[circuit(scheme = "bfv")]
fn mul(a: Rational, b: Rational) -> Rational {
fn mul(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a * b
}
#[circuit(scheme = "bfv")]
fn div(a: Rational, b: Rational) -> Rational {
fn div(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a / b
}

View File

@@ -1,4 +1,4 @@
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, PlainModulusConstraint};
use sunscreen_compiler::{circuit, types::{Cipher, Unsigned}, Compiler, PlainModulusConstraint};
use sunscreen_runtime::Runtime;
/**
@@ -7,14 +7,17 @@ use sunscreen_runtime::Runtime;
* the result. Circuits may take any number of parameters and return either a single result
* or a tuple of results.
*
* The unsigned type refers to an unsigned integer modulo the plaintext
* The [`Unsigned`] type refers to an unsigned integer modulo the plaintext
* modulus (p). p is passed to the compiler via plain_modulus_constraint.
*
*
* A `Cipher` type indicates the type is encrypted. Thus, a `Cipher<Unsigned>`
* refers to an encrypted [`Unsigned`] value.
*
* One takes a circuit and passes them to the compiler, which transforms it into a form
* suitable for execution.
*/
#[circuit(scheme = "bfv")]
fn simple_multiply(a: Unsigned, b: Unsigned) -> Unsigned {
fn simple_multiply(a: Cipher<Unsigned>, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
a * b
}

View File

@@ -11,32 +11,50 @@ enum ParamsMode {
Manual(Params),
}
/**
* The operations supported by a #[circuit] function.
*/
pub trait CircuitFn {
/**
* Get the call signature of the function
*/
fn signature(&self) -> CallSignature;
/**
* Compile the circuit.
*/
fn build(&self, params: &Params) -> Result<FrontendCompilation>;
/**
* Get the scheme type.
*/
fn scheme_type(&self) -> SchemeType;
}
/**
* A frontend circuit compiler for Sunscreen circuits.
*/
pub struct Compiler<F, G>
pub struct Compiler<F>
where
G: Fn(&Params) -> Result<FrontendCompilation>,
F: Fn() -> (SchemeType, G, CallSignature),
F: CircuitFn,
{
circuit: F,
circuit_fn: F,
params_mode: ParamsMode,
plain_modulus_constraint: Option<PlainModulusConstraint>,
security_level: SecurityLevel,
noise_margin: u32,
}
impl<F, G> Compiler<F, G>
impl<F> Compiler<F>
where
G: Fn(&Params) -> Result<FrontendCompilation>,
F: Fn() -> (SchemeType, G, CallSignature),
F: CircuitFn,
{
/**
* Create a new compiler with the given circuit.
*/
pub fn with_circuit(circuit: F) -> Self {
pub fn with_circuit(circuit_fn: F) -> Self {
Self {
circuit,
circuit_fn,
params_mode: ParamsMode::Search,
plain_modulus_constraint: None,
security_level: SecurityLevel::TC128,
@@ -92,23 +110,25 @@ where
* for running it.
*/
pub fn compile(self) -> Result<CompiledCircuit> {
let (scheme, circuit_fn, signature) = (self.circuit)();
let scheme = self.circuit_fn.scheme_type();
let signature = self.circuit_fn.signature();
let (circuit, params) = match self.params_mode {
ParamsMode::Manual(p) => (circuit_fn(&p), p.clone()),
ParamsMode::Manual(p) => (self.circuit_fn.build(&p), p.clone()),
ParamsMode::Search => {
let constraint = self
.plain_modulus_constraint
.ok_or(Error::MissingPlainModulusConstraint)?;
let params = determine_params(
&circuit_fn,
let params = determine_params::<F>(
&self.circuit_fn,
constraint,
self.security_level,
self.noise_margin,
scheme,
)?;
(circuit_fn(&params), params.clone())
(self.circuit_fn.build(&params), params.clone())
}
};

View File

@@ -7,10 +7,10 @@
//! # Examples
//! This example is further annotated in `examples/simple_multiply`.
//! ```
//! # use sunscreen_compiler::{circuit, Compiler, types::Unsigned, PlainModulusConstraint, Params, Runtime, Context};
//! # use sunscreen_compiler::{circuit, Compiler, types::{Cipher, Unsigned}, PlainModulusConstraint, Params, Runtime, Context};
//!
//! #[circuit(scheme = "bfv")]
//! fn simple_multiply(a: Unsigned, b: Unsigned) -> Unsigned {
//! fn simple_multiply(a: Cipher<Unsigned>, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
//! a * b
//! }
//!
@@ -116,7 +116,7 @@ use sunscreen_circuit::{
};
pub use clap::crate_version;
pub use compiler::Compiler;
pub use compiler::{CircuitFn, Compiler};
pub use error::{Error, Result};
pub use params::PlainModulusConstraint;
pub use sunscreen_circuit::{SchemeType, SecurityLevel};

View File

@@ -1,4 +1,4 @@
use crate::{Error, FrontendCompilation, Result, SecurityLevel};
use crate::{Error, Result, SecurityLevel, CircuitFn};
use log::{debug, trace};
@@ -47,7 +47,7 @@ pub fn determine_params<F>(
scheme_type: SchemeType,
) -> Result<Params>
where
F: Fn(&Params) -> Result<FrontendCompilation>,
F: CircuitFn,
{
'order_loop: for (i, n) in LATTICE_DIMENSIONS.iter().enumerate() {
let plaintext_modulus = match plaintext_constraint {
@@ -144,7 +144,7 @@ where
scheme_type: scheme_type,
};
let ir = circuit_fn(&params)?.compile();
let ir = circuit_fn.build(&params)?.compile();
let num_inputs = ir
.graph

View File

@@ -1,6 +1,6 @@
use seal::Plaintext as SealPlaintext;
use crate::types::{GraphAdd, GraphMul, GraphSub};
use crate::types::{GraphCipherAdd, GraphCipherPlainAdd, GraphCipherMul, GraphCipherSub, Cipher};
use crate::{
crate_version,
types::{BfvType, CircuitNode, FheType, Type, Version},
@@ -182,14 +182,14 @@ impl<const INT_BITS: usize> BfvType for Fractional<INT_BITS> {}
impl<const INT_BITS: usize> Fractional<INT_BITS> {}
impl<const INT_BITS: usize> GraphAdd for Fractional<INT_BITS> {
impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = Fractional<INT_BITS>;
fn graph_add(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
@@ -198,14 +198,30 @@ impl<const INT_BITS: usize> GraphAdd for Fractional<INT_BITS> {
}
}
impl<const INT_BITS: usize> GraphSub for Fractional<INT_BITS> {
impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = Fractional<INT_BITS>;
fn graph_sub(
a: CircuitNode<Self::Left>,
fn graph_cipher_plain_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = Fractional<INT_BITS>;
fn graph_cipher_sub(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
@@ -214,14 +230,14 @@ impl<const INT_BITS: usize> GraphSub for Fractional<INT_BITS> {
}
}
impl<const INT_BITS: usize> GraphMul for Fractional<INT_BITS> {
impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = Fractional<INT_BITS>;
fn graph_mul(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);

View File

@@ -1,6 +1,6 @@
use seal::Plaintext as SealPlaintext;
use crate::types::{GraphAdd, GraphMul};
use crate::types::{GraphCipherAdd, GraphCipherMul, Cipher};
use crate::{
types::{BfvType, CircuitNode, FheType},
with_ctx, Params, TypeName as DeriveTypeName,
@@ -31,14 +31,14 @@ impl BfvType for Unsigned {}
impl Unsigned {}
impl GraphAdd for Unsigned {
impl GraphCipherAdd for Unsigned {
type Left = Unsigned;
type Right = Unsigned;
fn graph_add(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
@@ -47,14 +47,14 @@ impl GraphAdd for Unsigned {
}
}
impl GraphMul for Unsigned {
impl GraphCipherMul for Unsigned {
type Left = Unsigned;
type Right = Unsigned;
fn graph_mul(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
@@ -232,14 +232,14 @@ impl Into<i64> for Signed {
}
}
impl GraphAdd for Signed {
impl GraphCipherAdd for Signed {
type Left = Signed;
type Right = Signed;
fn graph_add(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
@@ -248,14 +248,14 @@ impl GraphAdd for Signed {
}
}
impl GraphMul for Signed {
impl GraphCipherMul for Signed {
type Left = Signed;
type Right = Signed;
fn graph_mul(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);

View File

@@ -60,7 +60,7 @@ impl U64LiteralRef {
* Violating any of these condicitions may result in memory corruption or
* use-after-free.
*/
pub struct CircuitNode<T: FheType> {
pub struct CircuitNode<T: NumCiphertexts> {
/**
* The ids on this node. The 'static lifetime on this slice is a lie. The sunscreen
* compiler must ensure that no CircuitNode exists after circuit construction.
@@ -70,7 +70,7 @@ pub struct CircuitNode<T: FheType> {
_phantom: std::marker::PhantomData<T>,
}
impl<T: FheType> CircuitNode<T> {
impl<T: NumCiphertexts> CircuitNode<T> {
/**
* Creates a new circuit node with the given node index.
*
@@ -153,10 +153,35 @@ impl<T: FheType> CircuitNode<T> {
}
}
#[derive(Copy, Clone)]
/**
* Called when a circuit encounters a + operation.
* Declares a type T as being encrypted in a circuit.
*/
pub trait GraphAdd {
pub struct Cipher<T>
where T: FheType
{
_val: T,
}
impl <T> NumCiphertexts for Cipher<T>
where T: FheType
{
const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS;
}
impl <T> TypeName for Cipher<T>
where T: FheType + TypeName
{
fn type_name() -> Type {
T::type_name()
}
}
/**
* Called when a circuit encounters a + operation on two encrypted
* types.
*/
pub trait GraphCipherAdd {
/**
* The type of the left operand
*/
@@ -170,16 +195,17 @@ pub trait GraphAdd {
/**
* Process the + operation
*/
fn graph_add(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left>;
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>>;
}
/**
* Called when a circuit encounters a + operation.
* Called when a circuit encounters a + operation on one encrypted
* and one unencrypted type.
*/
pub trait GraphSub {
pub trait GraphCipherPlainAdd {
/**
* The type of the left operand
*/
@@ -193,16 +219,39 @@ pub trait GraphSub {
/**
* Process the + operation
*/
fn graph_sub(
a: CircuitNode<Self::Left>,
fn graph_cipher_plain_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left>;
) -> CircuitNode<Cipher<Self::Left>>;
}
/**
* Called when a circuit encounters a * operation.
* Called when a circuit encounters a - operation on two encrypted types.
*/
pub trait GraphMul {
pub trait GraphCipherSub {
/**
* The type of the left operand
*/
type Left: FheType;
/**
* The type of the right operand
*/
type Right: FheType;
/**
* Process the + operation
*/
fn graph_cipher_sub(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>>;
}
/**
* Called when a circuit encounters a * operation on two encrypted types.
*/
pub trait GraphCipherMul {
/**
* The type of the left operand
*/
@@ -216,16 +265,16 @@ pub trait GraphMul {
/**
* Process the * operation
*/
fn graph_mul(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left>;
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>>;
}
/**
* Called when a circuit encounters a / operation.
* Called when a circuit encounters a / operation on two encrypted types.
*/
pub trait GraphDiv {
pub trait GraphCipherDiv {
/**
* The type of the left operand
*/
@@ -239,53 +288,67 @@ pub trait GraphDiv {
/**
* Process the + operation
*/
fn graph_div(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left>;
fn graph_cipher_div(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>>;
}
impl<T> Add for CircuitNode<T>
// cipher + cipher
impl<T> Add for CircuitNode<Cipher<T>>
where
T: FheType + GraphAdd<Left = T, Right = T>,
T: FheType + GraphCipherAdd<Left = T, Right = T>,
{
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
T::graph_add(self, rhs)
T::graph_cipher_add(self, rhs)
}
}
impl<T> Sub for CircuitNode<T>
// cipher + plain
impl<T> Add<CircuitNode<T>> for CircuitNode<Cipher<T>>
where
T: FheType + GraphSub<Left = T, Right = T>,
T: FheType + GraphCipherPlainAdd<Left = T, Right = T>,
{
type Output = Self;
fn add(self, rhs: CircuitNode<T>) -> Self::Output {
T::graph_cipher_plain_add(self, rhs)
}
}
// cipher - cipher
impl<T> Sub for CircuitNode<Cipher<T>>
where
T: FheType + GraphCipherSub<Left = T, Right = T>,
{
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
T::graph_sub(self, rhs)
T::graph_cipher_sub(self, rhs)
}
}
impl<T> Mul for CircuitNode<T>
impl<T> Mul for CircuitNode<Cipher<T>>
where
T: FheType + GraphMul<Left = T, Right = T>,
T: FheType + GraphCipherMul<Left = T, Right = T>,
{
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
T::graph_mul(self, rhs)
T::graph_cipher_mul(self, rhs)
}
}
impl<T> Div for CircuitNode<T>
impl<T> Div for CircuitNode<Cipher<T>>
where
T: FheType + GraphDiv<Left = T, Right = T>,
T: FheType + GraphCipherDiv<Left = T, Right = T>,
{
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
T::graph_div(self, rhs)
T::graph_cipher_div(self, rhs)
}
}

View File

@@ -1,5 +1,5 @@
use crate::types::{
BfvType, CircuitNode, FheType, GraphAdd, GraphDiv, GraphMul, GraphSub, NumCiphertexts, Signed,
BfvType, CircuitNode, Cipher, FheType, GraphCipherAdd, GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, Signed,
TryFromPlaintext, TryIntoPlaintext,
};
use crate::{with_ctx, InnerPlaintext, Params, Plaintext, TypeName};
@@ -96,14 +96,14 @@ impl Into<f64> for Rational {
}
}
impl GraphAdd for Rational {
impl GraphCipherAdd for Rational {
type Left = Self;
type Right = Self;
fn graph_add(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
@@ -119,14 +119,14 @@ impl GraphAdd for Rational {
}
}
impl GraphSub for Rational {
impl GraphCipherSub for Rational {
type Left = Self;
type Right = Self;
fn graph_sub(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_sub(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
@@ -142,14 +142,14 @@ impl GraphSub for Rational {
}
}
impl GraphMul for Rational {
impl GraphCipherMul for Rational {
type Left = Self;
type Right = Self;
fn graph_mul(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
@@ -162,14 +162,14 @@ impl GraphMul for Rational {
}
}
impl GraphDiv for Rational {
impl GraphCipherDiv for Rational {
type Left = Self;
type Right = Self;
fn graph_div(
a: CircuitNode<Self::Left>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Self::Left> {
fn graph_cipher_div(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Cipher<Self::Right>>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[1]);

View File

@@ -1,12 +1,14 @@
use sunscreen_compiler::{
circuit, types::Fractional, types::Rational, types::Signed, Compiler, PlainModulusConstraint,
circuit, types::{Cipher, Fractional, Rational, Signed}, Compiler, PlainModulusConstraint,
Runtime,
};
type CipherSigned = Cipher<Signed>;
#[test]
fn can_encode_signed() {
#[circuit(scheme = "bfv")]
fn add(a: Signed) -> Signed {
fn add(a: CipherSigned) -> CipherSigned {
a
}
@@ -32,7 +34,7 @@ fn can_encode_signed() {
#[test]
fn can_add_signed_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Signed, b: Signed) -> Signed {
fn add(a: CipherSigned, b: CipherSigned) -> CipherSigned {
a + b
}
@@ -59,7 +61,7 @@ fn can_add_signed_numbers() {
#[test]
fn can_multiply_signed_numbers() {
#[circuit(scheme = "bfv")]
fn mul(a: Signed, b: Signed) -> Signed {
fn mul(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
a * b
}
@@ -86,7 +88,7 @@ fn can_multiply_signed_numbers() {
#[test]
fn can_encode_rational_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Rational) -> Rational {
fn add(a: Cipher<Rational>) -> Cipher<Rational> {
a
}
@@ -111,10 +113,12 @@ fn can_encode_rational_numbers() {
assert_eq!(c, (-3.14).try_into().unwrap());
}
type CipherRational = Cipher<Rational>;
#[test]
fn can_add_rational_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Rational, b: Rational) -> Rational {
fn add(a: CipherRational, b: CipherRational) -> CipherRational {
a + b
}
@@ -145,7 +149,7 @@ fn can_add_rational_numbers() {
#[test]
fn can_mul_rational_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Rational, b: Rational) -> Rational {
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a * b
}
@@ -176,7 +180,7 @@ fn can_mul_rational_numbers() {
#[test]
fn can_div_rational_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Rational, b: Rational) -> Rational {
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a / b
}
@@ -207,7 +211,7 @@ fn can_div_rational_numbers() {
#[test]
fn can_sub_rational_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Rational, b: Rational) -> Rational {
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
a - b
}
@@ -235,10 +239,12 @@ fn can_sub_rational_numbers() {
assert_eq!(c, (-6.28).try_into().unwrap());
}
type CipherFractional = Cipher<Fractional::<64>>;
#[test]
fn can_add_fractional_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> {
fn add(a: CipherFractional, b: CipherFractional) -> CipherFractional {
a + b
}
@@ -252,7 +258,7 @@ fn can_add_fractional_numbers() {
let (public, secret) = runtime.generate_keys().unwrap();
let add = |a: f64, b: f64| {
let do_add = |a: f64, b: f64| {
let a_c = runtime
.encrypt(Fractional::<64>::try_from(a).unwrap(), &public)
.unwrap();
@@ -267,27 +273,27 @@ fn can_add_fractional_numbers() {
assert_eq!(c, (a + b).try_into().unwrap());
};
add(3.14, 3.14);
add(-3.14, 3.14);
add(0., 0.);
add(7., 3.);
add(1e9, 1e9);
add(1e-8, 1e-7);
add(-3.14, -3.14);
add(3.14, -3.14);
add(-7., -3.);
add(-1e9, -1e9);
add(-1e-8, -1e-7);
do_add(3.14, 3.14);
do_add(-3.14, 3.14);
do_add(0., 0.);
do_add(7., 3.);
do_add(1e9, 1e9);
do_add(1e-8, 1e-7);
do_add(-3.14, -3.14);
do_add(3.14, -3.14);
do_add(-7., -3.);
do_add(-1e9, -1e9);
do_add(-1e-8, -1e-7);
}
#[test]
fn can_sub_fractional_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> {
fn sub(a: Cipher<Fractional<64>>, b: Cipher<Fractional<64>>) -> Cipher<Fractional<64>> {
a - b
}
let circuit = Compiler::with_circuit(add)
let circuit = Compiler::with_circuit(sub)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
@@ -297,7 +303,7 @@ fn can_sub_fractional_numbers() {
let (public, secret) = runtime.generate_keys().unwrap();
let add = |a: f64, b: f64| {
let do_sub = |a: f64, b: f64| {
let a_c = runtime
.encrypt(Fractional::<64>::try_from(a).unwrap(), &public)
.unwrap();
@@ -312,23 +318,23 @@ fn can_sub_fractional_numbers() {
assert_eq!(c, (a - b).try_into().unwrap());
};
add(3.14, 3.14);
add(-3.14, 3.14);
add(0., 0.);
add(7., 3.);
add(1e9, 1e9);
add(1e-8, 1e-7);
add(-3.14, -3.14);
add(3.14, -3.14);
add(-7., -3.);
add(-1e9, -1e9);
add(-1e-8, -1e-7);
do_sub(3.14, 3.14);
do_sub(-3.14, 3.14);
do_sub(0., 0.);
do_sub(7., 3.);
do_sub(1e9, 1e9);
do_sub(1e-8, 1e-7);
do_sub(-3.14, -3.14);
do_sub(3.14, -3.14);
do_sub(-7., -3.);
do_sub(-1e9, -1e9);
do_sub(-1e-8, -1e-7);
}
#[test]
fn can_mul_fractional_numbers() {
#[circuit(scheme = "bfv")]
fn mul(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> {
fn mul(a: Cipher<Fractional<64>>, b: Cipher<Fractional<64>>) -> Cipher<Fractional<64>> {
a * b
}

View File

@@ -14,7 +14,6 @@ pub fn circuit_impl(
let circuit_name = &input_fn.sig.ident;
let vis = &input_fn.vis;
let body = &input_fn.block;
let attrs = &input_fn.attrs;
let inputs = &input_fn.sig.inputs;
let ret = &input_fn.sig.output;
@@ -25,7 +24,7 @@ pub fn circuit_impl(
let scheme_type = match attr_params.scheme {
Scheme::Bfv => {
quote! {
SchemeType::Bfv
sunscreen_compiler::SchemeType::Bfv
}
}
};
@@ -90,18 +89,19 @@ pub fn circuit_impl(
}
});
proc_macro::TokenStream::from(quote! {
#(#attrs)*
#vis fn #circuit_name() -> (
sunscreen_compiler::SchemeType,
impl Fn(&sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation>,
sunscreen_compiler::CallSignature
) {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{CircuitNode, NumCiphertexts, Type, TypeName, TypeNameInstance}};
let circuit_struct_name = Ident::new(&format!("{}_struct", circuit_name), Span::call_site());
let circuit = proc_macro::TokenStream::from(quote! {
#[allow(non_camel_case_types)]
#vis struct #circuit_struct_name {
}
impl sunscreen_compiler::CircuitFn for #circuit_struct_name {
fn build(&self, params: &sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation> {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{CircuitNode, NumCiphertexts, Type, TypeName, TypeNameInstance}};
let circuit_builder = |params: &Params| {
if SchemeType::Bfv != params.scheme_type {
return Err(Error::IncorrectScheme)
}
@@ -110,9 +110,9 @@ pub fn circuit_impl(
let mut context = Context::new(params);
CURRENT_CTX.with(|ctx| {
let internal = | #(#circuit_args)* | -> #circuit_returns {
let internal = | #(#circuit_args)* | -> #circuit_returns
#body
};
;
// Transmute away the lifetime to 'static. So long as we are careful with internal()
// panicing, this is safe because we set the context back to none before the funtion
@@ -131,7 +131,7 @@ pub fn circuit_impl(
Ok(v) => { #catpured_outputs },
Err(err) => {
INDEX_ARENA.with(|allocator| {
unsafe { allocator.borrow_mut().reset() }
allocator.borrow_mut().reset()
});
ctx.swap(&RefCell::new(None));
std::panic::resume_unwind(err)
@@ -139,19 +139,31 @@ pub fn circuit_impl(
};
INDEX_ARENA.with(|allocator| {
unsafe { allocator.borrow_mut().reset() }
allocator.borrow_mut().reset()
});
ctx.swap(&RefCell::new(None));
});
Ok(context.compilation)
};
}
#signature;
fn signature(&self) -> sunscreen_compiler::CallSignature {
use sunscreen_compiler::types::NumCiphertexts;
(#scheme_type, circuit_builder, signature)
#signature
}
fn scheme_type(&self) -> sunscreen_compiler::SchemeType {
#scheme_type
}
}
})
#[allow(non_upper_case_globals)]
const #circuit_name: #circuit_struct_name = #circuit_struct_name { };
});
//panic!("{}", circuit);
circuit
}
/**
@@ -265,13 +277,28 @@ fn capture_outputs(ret: &ReturnType) -> TokenStream {
}
fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
let arg_type_names = args.iter().map(|t| {
// We have to type alias arguments and returns because they might
// be generic and cause an error during invocation.
// E.g. Foo<Bar> causes an error when doing Foo<Bar>::func()
// because you need :: after Foo.
// So we make type aliases and invoke the function on the alias.
let arg_type_names = args.iter().enumerate().map(|(i, t)| {
let alias = ident("T", i);
quote! {
#t ::type_name(),
type #alias = #t;
}
}).collect::<Vec<TokenStream>>();
let arg_get_types = arg_type_names.iter().enumerate().map(|(i, _)| {
let alias = ident("T", i);
quote! {
#alias::type_name(),
}
});
let (return_type_names, return_type_sizes) = match ret {
let (return_type_aliases, return_type_names, return_type_sizes) = match ret {
ReturnType::Type(_, t) => {
let tuple_inners = match &**t {
Type::Tuple(t) => t.elems.iter().map(|x| &*x).collect::<Vec<&Type>>(),
@@ -288,19 +315,37 @@ fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
}
};
let return_type_sizes = tuple_inners.iter().map(|t| {
let return_type_aliases = tuple_inners
.iter()
.enumerate()
.map(|(i, t)| {
let alias = ident("R", i);
quote! {
type #alias = #t;
}
});
let return_type_sizes = tuple_inners.iter().enumerate().map(|(i, _)| {
let alias = ident("R", i);
quote! {
#t ::NUM_CIPHERTEXTS,
#alias ::NUM_CIPHERTEXTS,
}
});
let type_names = tuple_inners.iter().map(|t| {
let type_names = tuple_inners.iter().enumerate().map(|(i, _)| {
let alias = ident("R", i);
quote! {
#t ::type_name(),
#alias ::type_name(),
}
});
(
quote! {
#(#return_type_aliases)*
},
quote! {
vec![
#(#type_names)*
@@ -313,14 +358,23 @@ fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
},
)
}
ReturnType::Default => (quote! { vec![] }, quote! { vec![] }),
ReturnType::Default => (quote! { }, quote! { vec![] }, quote! { vec![] }),
};
quote! {
let signature = sunscreen_compiler::CallSignature {
arguments: vec![#(#arg_type_names)*],
use sunscreen_compiler::types::TypeName;
#(#arg_type_names)*
#return_type_aliases
sunscreen_compiler::CallSignature {
arguments: vec![#(#arg_get_types)*],
returns: #return_type_names,
num_ciphertexts: #return_type_sizes,
};
}
}
}
fn ident(prefix: &str, i: usize) -> Ident {
Ident::new(&format!("{}{}", prefix, i), Span::call_site())
}

View File

@@ -33,19 +33,27 @@ pub fn derive_typename(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
*
* # Examples
* ```rust
* # use sunscreen_compiler::{circuit, types::Unsigned, Params, Context};
* # use sunscreen_compiler::{circuit, types::{Cipher, Unsigned}, Params, Context};
*
* #[circuit(scheme = "bfv")]
* fn multiply_add(a: Unsigned, b: Unsigned, c: Unsigned) -> Unsigned {
* fn multiply_add(
* a: Cipher<Unsigned>,
* b: Cipher<Unsigned>,
* c: Cipher<Unsigned>
* ) -> Cipher<Unsigned> {
* a * b + c
* }
* ```
*
* ```rust
* # use sunscreen_compiler::{circuit, types::Unsigned, Params, Context};
* # use sunscreen_compiler::{circuit, types::{Cipher, Unsigned}, Params, Context};
*
* #[circuit(scheme = "bfv")]
* fn multi_out(a: Unsigned, b: Unsigned, c: Unsigned) -> (Unsigned, Unsigned) {
* fn multi_out(
* a: Cipher<Unsigned>,
* b: Cipher<Unsigned>,
* c: Cipher<Unsigned>
* ) -> (Cipher<Unsigned>, Cipher<Unsigned>) {
* (a + b, b + c)
* }
* ```

View File

@@ -1,6 +1,6 @@
use sunscreen_compiler::{
types::{TypeName, Unsigned},
CallSignature, FrontendCompilation, Params, SchemeType, SecurityLevel, CURRENT_CTX,
types::{TypeName, Cipher, Unsigned},
CallSignature, FrontendCompilation, Params, SchemeType, SecurityLevel, CURRENT_CTX, CircuitFn
};
use sunscreen_compiler_macros::circuit;
@@ -16,6 +16,8 @@ fn get_params() -> Params {
}
}
type CipherUnsigned = Cipher<Unsigned>;
#[test]
fn circuit_gets_called() {
static mut FOO: u32 = 0;
@@ -27,18 +29,16 @@ fn circuit_gets_called() {
};
}
let (scheme, compile_fn, signature) = simple_circuit();
let expected_signature = CallSignature {
arguments: vec![],
returns: vec![],
num_ciphertexts: vec![],
};
assert_eq!(signature, expected_signature);
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(simple_circuit.signature(), expected_signature);
assert_eq!(simple_circuit.scheme_type(), SchemeType::Bfv);
let _context = compile_fn(&get_params()).unwrap();
let _context = simple_circuit.build(&get_params()).unwrap();
assert_eq!(unsafe { FOO }, 20);
}
@@ -58,18 +58,16 @@ fn panicing_circuit_clears_ctx() {
}
let panic_result = std::panic::catch_unwind(|| {
let (scheme, compile_fn, signature) = panic_circuit();
let expected_signature = CallSignature {
arguments: vec![],
returns: vec![],
num_ciphertexts: vec![],
};
assert_eq!(signature, expected_signature);
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(panic_circuit.signature(), expected_signature);
assert_eq!(panic_circuit.scheme_type(), SchemeType::Bfv);
let _context = compile_fn(&get_params()).unwrap();
let _context = panic_circuit.build(&get_params()).unwrap();
});
assert_eq!(panic_result.is_err(), true);
@@ -93,9 +91,7 @@ fn capture_circuit_input_args() {
#[circuit(scheme = "bfv")]
fn circuit_with_args(_a: Unsigned, _b: Unsigned, _c: Unsigned, _d: Unsigned) {}
let (scheme, compile_fn, signature) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
let type_name = Unsigned::type_name();
@@ -110,9 +106,9 @@ fn capture_circuit_input_args() {
num_ciphertexts: vec![],
};
assert_eq!(expected_signature, signature);
assert_eq!(expected_signature, circuit_with_args.signature());
let context = compile_fn(&get_params()).unwrap();
let context = circuit_with_args.build(&get_params()).unwrap();
assert_eq!(context.graph.node_count(), 4);
}
@@ -120,12 +116,10 @@ fn capture_circuit_input_args() {
#[test]
fn can_add() {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned, c: Unsigned) {
fn circuit_with_args(a: CipherUnsigned, b: CipherUnsigned, c: CipherUnsigned) {
let _ = a + b + c;
}
let (scheme, compile_fn, signature) = circuit_with_args();
let type_name = Unsigned::type_name();
let expected_signature = CallSignature {
@@ -133,10 +127,10 @@ fn can_add() {
returns: vec![],
num_ciphertexts: vec![],
};
assert_eq!(signature, expected_signature);
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(circuit_with_args.signature(), expected_signature);
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
let context: FrontendCompilation = compile_fn(&get_params()).unwrap();
let context: FrontendCompilation = circuit_with_args.build(&get_params()).unwrap();
let expected = json!({
@@ -184,12 +178,10 @@ fn can_add() {
#[test]
fn can_mul() {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned, c: Unsigned) {
fn circuit_with_args(a: CipherUnsigned, b: CipherUnsigned, c: CipherUnsigned) {
let _ = a * b * c;
}
let (scheme, compile_fn, signature) = circuit_with_args();
let type_name = Unsigned::type_name();
let expected_signature = CallSignature {
@@ -197,10 +189,10 @@ fn can_mul() {
returns: vec![],
num_ciphertexts: vec![],
};
assert_eq!(signature, expected_signature);
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(circuit_with_args.signature(), expected_signature);
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let context = circuit_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -247,12 +239,10 @@ fn can_mul() {
#[test]
fn can_collect_output() {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned) -> Unsigned {
fn circuit_with_args(a: Cipher<Unsigned>, b: CipherUnsigned) -> CipherUnsigned {
a + b * a
}
let (scheme, compile_fn, signature) = circuit_with_args();
let type_name = Unsigned::type_name();
let expected_signature = CallSignature {
@@ -260,10 +250,10 @@ fn can_collect_output() {
returns: vec![type_name.clone()],
num_ciphertexts: vec![1],
};
assert_eq!(signature, expected_signature);
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(circuit_with_args.signature(), expected_signature);
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let context = circuit_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -315,12 +305,10 @@ fn can_collect_output() {
#[test]
fn can_collect_multiple_outputs() {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned) -> (Unsigned, Unsigned) {
fn circuit_with_args(a: Cipher<Unsigned>, b: CipherUnsigned) -> (Cipher<Unsigned>, Cipher<Unsigned>) {
(a + b * a, a)
}
let (scheme, compile_fn, signature) = circuit_with_args();
let type_name = Unsigned::type_name();
let expected_signature = CallSignature {
@@ -328,10 +316,10 @@ fn can_collect_multiple_outputs() {
returns: vec![type_name.clone(), type_name.clone()],
num_ciphertexts: vec![1, 1],
};
assert_eq!(signature, expected_signature);
assert_eq!(scheme, SchemeType::Bfv);
assert_eq!(circuit_with_args.signature(), expected_signature);
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let context = circuit_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {

View File

@@ -1,9 +1,11 @@
use sunscreen_compiler::{types::*, *};
type CipherUnsigned = Cipher<Unsigned>;
#[test]
fn can_encrypt_decrypt() {
#[circuit(scheme = "bfv")]
fn foo(a: Unsigned, b: Unsigned) -> Unsigned {
fn foo(a: CipherUnsigned, b: CipherUnsigned) -> CipherUnsigned {
a + b
}