mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
@@ -2,7 +2,7 @@ use std::io::{self, Write};
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
use std::thread::{self, JoinHandle};
|
||||
use sunscreen_compiler::{
|
||||
circuit, types::Rational, Ciphertext, CompiledCircuit, Compiler, Params,
|
||||
circuit, types::{Cipher, Rational}, Ciphertext, CompiledCircuit, Compiler, Params,
|
||||
PlainModulusConstraint, PublicKey, Runtime, RuntimeError,
|
||||
};
|
||||
|
||||
@@ -192,22 +192,22 @@ fn compile_circuits() -> (
|
||||
CompiledCircuit,
|
||||
) {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational, b: Rational) -> Rational {
|
||||
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a + b
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn sub(a: Rational, b: Rational) -> Rational {
|
||||
fn sub(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a - b
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn mul(a: Rational, b: Rational) -> Rational {
|
||||
fn mul(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a * b
|
||||
}
|
||||
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn div(a: Rational, b: Rational) -> Rational {
|
||||
fn div(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a / b
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, PlainModulusConstraint};
|
||||
use sunscreen_compiler::{circuit, types::{Cipher, Unsigned}, Compiler, PlainModulusConstraint};
|
||||
use sunscreen_runtime::Runtime;
|
||||
|
||||
/**
|
||||
@@ -7,14 +7,17 @@ use sunscreen_runtime::Runtime;
|
||||
* the result. Circuits may take any number of parameters and return either a single result
|
||||
* or a tuple of results.
|
||||
*
|
||||
* The unsigned type refers to an unsigned integer modulo the plaintext
|
||||
* The [`Unsigned`] type refers to an unsigned integer modulo the plaintext
|
||||
* modulus (p). p is passed to the compiler via plain_modulus_constraint.
|
||||
*
|
||||
*
|
||||
* A `Cipher` type indicates the type is encrypted. Thus, a `Cipher<Unsigned>`
|
||||
* refers to an encrypted [`Unsigned`] value.
|
||||
*
|
||||
* One takes a circuit and passes them to the compiler, which transforms it into a form
|
||||
* suitable for execution.
|
||||
*/
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn simple_multiply(a: Unsigned, b: Unsigned) -> Unsigned {
|
||||
fn simple_multiply(a: Cipher<Unsigned>, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
|
||||
a * b
|
||||
}
|
||||
|
||||
|
||||
@@ -11,32 +11,50 @@ enum ParamsMode {
|
||||
Manual(Params),
|
||||
}
|
||||
|
||||
/**
|
||||
* The operations supported by a #[circuit] function.
|
||||
*/
|
||||
pub trait CircuitFn {
|
||||
/**
|
||||
* Get the call signature of the function
|
||||
*/
|
||||
fn signature(&self) -> CallSignature;
|
||||
|
||||
/**
|
||||
* Compile the circuit.
|
||||
*/
|
||||
fn build(&self, params: &Params) -> Result<FrontendCompilation>;
|
||||
|
||||
/**
|
||||
* Get the scheme type.
|
||||
*/
|
||||
fn scheme_type(&self) -> SchemeType;
|
||||
}
|
||||
|
||||
/**
|
||||
* A frontend circuit compiler for Sunscreen circuits.
|
||||
*/
|
||||
pub struct Compiler<F, G>
|
||||
pub struct Compiler<F>
|
||||
where
|
||||
G: Fn(&Params) -> Result<FrontendCompilation>,
|
||||
F: Fn() -> (SchemeType, G, CallSignature),
|
||||
F: CircuitFn,
|
||||
{
|
||||
circuit: F,
|
||||
circuit_fn: F,
|
||||
params_mode: ParamsMode,
|
||||
plain_modulus_constraint: Option<PlainModulusConstraint>,
|
||||
security_level: SecurityLevel,
|
||||
noise_margin: u32,
|
||||
}
|
||||
|
||||
impl<F, G> Compiler<F, G>
|
||||
impl<F> Compiler<F>
|
||||
where
|
||||
G: Fn(&Params) -> Result<FrontendCompilation>,
|
||||
F: Fn() -> (SchemeType, G, CallSignature),
|
||||
F: CircuitFn,
|
||||
{
|
||||
/**
|
||||
* Create a new compiler with the given circuit.
|
||||
*/
|
||||
pub fn with_circuit(circuit: F) -> Self {
|
||||
pub fn with_circuit(circuit_fn: F) -> Self {
|
||||
Self {
|
||||
circuit,
|
||||
circuit_fn,
|
||||
params_mode: ParamsMode::Search,
|
||||
plain_modulus_constraint: None,
|
||||
security_level: SecurityLevel::TC128,
|
||||
@@ -92,23 +110,25 @@ where
|
||||
* for running it.
|
||||
*/
|
||||
pub fn compile(self) -> Result<CompiledCircuit> {
|
||||
let (scheme, circuit_fn, signature) = (self.circuit)();
|
||||
let scheme = self.circuit_fn.scheme_type();
|
||||
let signature = self.circuit_fn.signature();
|
||||
|
||||
let (circuit, params) = match self.params_mode {
|
||||
ParamsMode::Manual(p) => (circuit_fn(&p), p.clone()),
|
||||
ParamsMode::Manual(p) => (self.circuit_fn.build(&p), p.clone()),
|
||||
ParamsMode::Search => {
|
||||
let constraint = self
|
||||
.plain_modulus_constraint
|
||||
.ok_or(Error::MissingPlainModulusConstraint)?;
|
||||
|
||||
let params = determine_params(
|
||||
&circuit_fn,
|
||||
let params = determine_params::<F>(
|
||||
&self.circuit_fn,
|
||||
constraint,
|
||||
self.security_level,
|
||||
self.noise_margin,
|
||||
scheme,
|
||||
)?;
|
||||
|
||||
(circuit_fn(¶ms), params.clone())
|
||||
(self.circuit_fn.build(¶ms), params.clone())
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
//! # Examples
|
||||
//! This example is further annotated in `examples/simple_multiply`.
|
||||
//! ```
|
||||
//! # use sunscreen_compiler::{circuit, Compiler, types::Unsigned, PlainModulusConstraint, Params, Runtime, Context};
|
||||
//! # use sunscreen_compiler::{circuit, Compiler, types::{Cipher, Unsigned}, PlainModulusConstraint, Params, Runtime, Context};
|
||||
//!
|
||||
//! #[circuit(scheme = "bfv")]
|
||||
//! fn simple_multiply(a: Unsigned, b: Unsigned) -> Unsigned {
|
||||
//! fn simple_multiply(a: Cipher<Unsigned>, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
|
||||
//! a * b
|
||||
//! }
|
||||
//!
|
||||
@@ -116,7 +116,7 @@ use sunscreen_circuit::{
|
||||
};
|
||||
|
||||
pub use clap::crate_version;
|
||||
pub use compiler::Compiler;
|
||||
pub use compiler::{CircuitFn, Compiler};
|
||||
pub use error::{Error, Result};
|
||||
pub use params::PlainModulusConstraint;
|
||||
pub use sunscreen_circuit::{SchemeType, SecurityLevel};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Error, FrontendCompilation, Result, SecurityLevel};
|
||||
use crate::{Error, Result, SecurityLevel, CircuitFn};
|
||||
|
||||
use log::{debug, trace};
|
||||
|
||||
@@ -47,7 +47,7 @@ pub fn determine_params<F>(
|
||||
scheme_type: SchemeType,
|
||||
) -> Result<Params>
|
||||
where
|
||||
F: Fn(&Params) -> Result<FrontendCompilation>,
|
||||
F: CircuitFn,
|
||||
{
|
||||
'order_loop: for (i, n) in LATTICE_DIMENSIONS.iter().enumerate() {
|
||||
let plaintext_modulus = match plaintext_constraint {
|
||||
@@ -144,7 +144,7 @@ where
|
||||
scheme_type: scheme_type,
|
||||
};
|
||||
|
||||
let ir = circuit_fn(¶ms)?.compile();
|
||||
let ir = circuit_fn.build(¶ms)?.compile();
|
||||
|
||||
let num_inputs = ir
|
||||
.graph
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use seal::Plaintext as SealPlaintext;
|
||||
|
||||
use crate::types::{GraphAdd, GraphMul, GraphSub};
|
||||
use crate::types::{GraphCipherAdd, GraphCipherPlainAdd, GraphCipherMul, GraphCipherSub, Cipher};
|
||||
use crate::{
|
||||
crate_version,
|
||||
types::{BfvType, CircuitNode, FheType, Type, Version},
|
||||
@@ -182,14 +182,14 @@ impl<const INT_BITS: usize> BfvType for Fractional<INT_BITS> {}
|
||||
|
||||
impl<const INT_BITS: usize> Fractional<INT_BITS> {}
|
||||
|
||||
impl<const INT_BITS: usize> GraphAdd for Fractional<INT_BITS> {
|
||||
impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
|
||||
type Left = Fractional<INT_BITS>;
|
||||
type Right = Fractional<INT_BITS>;
|
||||
|
||||
fn graph_add(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
@@ -198,14 +198,30 @@ impl<const INT_BITS: usize> GraphAdd for Fractional<INT_BITS> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const INT_BITS: usize> GraphSub for Fractional<INT_BITS> {
|
||||
impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
|
||||
type Left = Fractional<INT_BITS>;
|
||||
type Right = Fractional<INT_BITS>;
|
||||
|
||||
fn graph_sub(
|
||||
a: CircuitNode<Self::Left>,
|
||||
fn graph_cipher_plain_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
|
||||
type Left = Fractional<INT_BITS>;
|
||||
type Right = Fractional<INT_BITS>;
|
||||
|
||||
fn graph_cipher_sub(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
|
||||
|
||||
@@ -214,14 +230,14 @@ impl<const INT_BITS: usize> GraphSub for Fractional<INT_BITS> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const INT_BITS: usize> GraphMul for Fractional<INT_BITS> {
|
||||
impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
|
||||
type Left = Fractional<INT_BITS>;
|
||||
type Right = Fractional<INT_BITS>;
|
||||
|
||||
fn graph_mul(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use seal::Plaintext as SealPlaintext;
|
||||
|
||||
use crate::types::{GraphAdd, GraphMul};
|
||||
use crate::types::{GraphCipherAdd, GraphCipherMul, Cipher};
|
||||
use crate::{
|
||||
types::{BfvType, CircuitNode, FheType},
|
||||
with_ctx, Params, TypeName as DeriveTypeName,
|
||||
@@ -31,14 +31,14 @@ impl BfvType for Unsigned {}
|
||||
|
||||
impl Unsigned {}
|
||||
|
||||
impl GraphAdd for Unsigned {
|
||||
impl GraphCipherAdd for Unsigned {
|
||||
type Left = Unsigned;
|
||||
type Right = Unsigned;
|
||||
|
||||
fn graph_add(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
@@ -47,14 +47,14 @@ impl GraphAdd for Unsigned {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMul for Unsigned {
|
||||
impl GraphCipherMul for Unsigned {
|
||||
type Left = Unsigned;
|
||||
type Right = Unsigned;
|
||||
|
||||
fn graph_mul(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
@@ -232,14 +232,14 @@ impl Into<i64> for Signed {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphAdd for Signed {
|
||||
impl GraphCipherAdd for Signed {
|
||||
type Left = Signed;
|
||||
type Right = Signed;
|
||||
|
||||
fn graph_add(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
@@ -248,14 +248,14 @@ impl GraphAdd for Signed {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMul for Signed {
|
||||
impl GraphCipherMul for Signed {
|
||||
type Left = Signed;
|
||||
type Right = Signed;
|
||||
|
||||
fn graph_mul(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ impl U64LiteralRef {
|
||||
* Violating any of these condicitions may result in memory corruption or
|
||||
* use-after-free.
|
||||
*/
|
||||
pub struct CircuitNode<T: FheType> {
|
||||
pub struct CircuitNode<T: NumCiphertexts> {
|
||||
/**
|
||||
* The ids on this node. The 'static lifetime on this slice is a lie. The sunscreen
|
||||
* compiler must ensure that no CircuitNode exists after circuit construction.
|
||||
@@ -70,7 +70,7 @@ pub struct CircuitNode<T: FheType> {
|
||||
_phantom: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: FheType> CircuitNode<T> {
|
||||
impl<T: NumCiphertexts> CircuitNode<T> {
|
||||
/**
|
||||
* Creates a new circuit node with the given node index.
|
||||
*
|
||||
@@ -153,10 +153,35 @@ impl<T: FheType> CircuitNode<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/**
|
||||
* Called when a circuit encounters a + operation.
|
||||
* Declares a type T as being encrypted in a circuit.
|
||||
*/
|
||||
pub trait GraphAdd {
|
||||
pub struct Cipher<T>
|
||||
where T: FheType
|
||||
{
|
||||
_val: T,
|
||||
}
|
||||
|
||||
impl <T> NumCiphertexts for Cipher<T>
|
||||
where T: FheType
|
||||
{
|
||||
const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS;
|
||||
}
|
||||
|
||||
impl <T> TypeName for Cipher<T>
|
||||
where T: FheType + TypeName
|
||||
{
|
||||
fn type_name() -> Type {
|
||||
T::type_name()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a circuit encounters a + operation on two encrypted
|
||||
* types.
|
||||
*/
|
||||
pub trait GraphCipherAdd {
|
||||
/**
|
||||
* The type of the left operand
|
||||
*/
|
||||
@@ -170,16 +195,17 @@ pub trait GraphAdd {
|
||||
/**
|
||||
* Process the + operation
|
||||
*/
|
||||
fn graph_add(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left>;
|
||||
fn graph_cipher_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a circuit encounters a + operation.
|
||||
* Called when a circuit encounters a + operation on one encrypted
|
||||
* and one unencrypted type.
|
||||
*/
|
||||
pub trait GraphSub {
|
||||
pub trait GraphCipherPlainAdd {
|
||||
/**
|
||||
* The type of the left operand
|
||||
*/
|
||||
@@ -193,16 +219,39 @@ pub trait GraphSub {
|
||||
/**
|
||||
* Process the + operation
|
||||
*/
|
||||
fn graph_sub(
|
||||
a: CircuitNode<Self::Left>,
|
||||
fn graph_cipher_plain_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left>;
|
||||
) -> CircuitNode<Cipher<Self::Left>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a circuit encounters a * operation.
|
||||
* Called when a circuit encounters a - operation on two encrypted types.
|
||||
*/
|
||||
pub trait GraphMul {
|
||||
pub trait GraphCipherSub {
|
||||
/**
|
||||
* The type of the left operand
|
||||
*/
|
||||
type Left: FheType;
|
||||
|
||||
/**
|
||||
* The type of the right operand
|
||||
*/
|
||||
type Right: FheType;
|
||||
|
||||
/**
|
||||
* Process the + operation
|
||||
*/
|
||||
fn graph_cipher_sub(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a circuit encounters a * operation on two encrypted types.
|
||||
*/
|
||||
pub trait GraphCipherMul {
|
||||
/**
|
||||
* The type of the left operand
|
||||
*/
|
||||
@@ -216,16 +265,16 @@ pub trait GraphMul {
|
||||
/**
|
||||
* Process the * operation
|
||||
*/
|
||||
fn graph_mul(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left>;
|
||||
fn graph_cipher_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a circuit encounters a / operation.
|
||||
* Called when a circuit encounters a / operation on two encrypted types.
|
||||
*/
|
||||
pub trait GraphDiv {
|
||||
pub trait GraphCipherDiv {
|
||||
/**
|
||||
* The type of the left operand
|
||||
*/
|
||||
@@ -239,53 +288,67 @@ pub trait GraphDiv {
|
||||
/**
|
||||
* Process the + operation
|
||||
*/
|
||||
fn graph_div(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left>;
|
||||
fn graph_cipher_div(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>>;
|
||||
}
|
||||
|
||||
impl<T> Add for CircuitNode<T>
|
||||
// cipher + cipher
|
||||
impl<T> Add for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphAdd<Left = T, Right = T>,
|
||||
T: FheType + GraphCipherAdd<Left = T, Right = T>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
T::graph_add(self, rhs)
|
||||
T::graph_cipher_add(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Sub for CircuitNode<T>
|
||||
// cipher + plain
|
||||
impl<T> Add<CircuitNode<T>> for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphSub<Left = T, Right = T>,
|
||||
T: FheType + GraphCipherPlainAdd<Left = T, Right = T>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: CircuitNode<T>) -> Self::Output {
|
||||
T::graph_cipher_plain_add(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
// cipher - cipher
|
||||
impl<T> Sub for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphCipherSub<Left = T, Right = T>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
T::graph_sub(self, rhs)
|
||||
T::graph_cipher_sub(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Mul for CircuitNode<T>
|
||||
impl<T> Mul for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphMul<Left = T, Right = T>,
|
||||
T: FheType + GraphCipherMul<Left = T, Right = T>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
T::graph_mul(self, rhs)
|
||||
T::graph_cipher_mul(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Div for CircuitNode<T>
|
||||
impl<T> Div for CircuitNode<Cipher<T>>
|
||||
where
|
||||
T: FheType + GraphDiv<Left = T, Right = T>,
|
||||
T: FheType + GraphCipherDiv<Left = T, Right = T>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
T::graph_div(self, rhs)
|
||||
T::graph_cipher_div(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::types::{
|
||||
BfvType, CircuitNode, FheType, GraphAdd, GraphDiv, GraphMul, GraphSub, NumCiphertexts, Signed,
|
||||
BfvType, CircuitNode, Cipher, FheType, GraphCipherAdd, GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, Signed,
|
||||
TryFromPlaintext, TryIntoPlaintext,
|
||||
};
|
||||
use crate::{with_ctx, InnerPlaintext, Params, Plaintext, TypeName};
|
||||
@@ -96,14 +96,14 @@ impl Into<f64> for Rational {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphAdd for Rational {
|
||||
impl GraphCipherAdd for Rational {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_add(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
|
||||
@@ -119,14 +119,14 @@ impl GraphAdd for Rational {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphSub for Rational {
|
||||
impl GraphCipherSub for Rational {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_sub(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_sub(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
|
||||
@@ -142,14 +142,14 @@ impl GraphSub for Rational {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMul for Rational {
|
||||
impl GraphCipherMul for Rational {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_mul(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
@@ -162,14 +162,14 @@ impl GraphMul for Rational {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphDiv for Rational {
|
||||
impl GraphCipherDiv for Rational {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_div(
|
||||
a: CircuitNode<Self::Left>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Self::Left> {
|
||||
fn graph_cipher_div(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Cipher<Self::Right>>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[1]);
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use sunscreen_compiler::{
|
||||
circuit, types::Fractional, types::Rational, types::Signed, Compiler, PlainModulusConstraint,
|
||||
circuit, types::{Cipher, Fractional, Rational, Signed}, Compiler, PlainModulusConstraint,
|
||||
Runtime,
|
||||
};
|
||||
|
||||
type CipherSigned = Cipher<Signed>;
|
||||
|
||||
#[test]
|
||||
fn can_encode_signed() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Signed) -> Signed {
|
||||
fn add(a: CipherSigned) -> CipherSigned {
|
||||
a
|
||||
}
|
||||
|
||||
@@ -32,7 +34,7 @@ fn can_encode_signed() {
|
||||
#[test]
|
||||
fn can_add_signed_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Signed, b: Signed) -> Signed {
|
||||
fn add(a: CipherSigned, b: CipherSigned) -> CipherSigned {
|
||||
a + b
|
||||
}
|
||||
|
||||
@@ -59,7 +61,7 @@ fn can_add_signed_numbers() {
|
||||
#[test]
|
||||
fn can_multiply_signed_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn mul(a: Signed, b: Signed) -> Signed {
|
||||
fn mul(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
|
||||
a * b
|
||||
}
|
||||
|
||||
@@ -86,7 +88,7 @@ fn can_multiply_signed_numbers() {
|
||||
#[test]
|
||||
fn can_encode_rational_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational) -> Rational {
|
||||
fn add(a: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a
|
||||
}
|
||||
|
||||
@@ -111,10 +113,12 @@ fn can_encode_rational_numbers() {
|
||||
assert_eq!(c, (-3.14).try_into().unwrap());
|
||||
}
|
||||
|
||||
type CipherRational = Cipher<Rational>;
|
||||
|
||||
#[test]
|
||||
fn can_add_rational_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational, b: Rational) -> Rational {
|
||||
fn add(a: CipherRational, b: CipherRational) -> CipherRational {
|
||||
a + b
|
||||
}
|
||||
|
||||
@@ -145,7 +149,7 @@ fn can_add_rational_numbers() {
|
||||
#[test]
|
||||
fn can_mul_rational_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational, b: Rational) -> Rational {
|
||||
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a * b
|
||||
}
|
||||
|
||||
@@ -176,7 +180,7 @@ fn can_mul_rational_numbers() {
|
||||
#[test]
|
||||
fn can_div_rational_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational, b: Rational) -> Rational {
|
||||
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a / b
|
||||
}
|
||||
|
||||
@@ -207,7 +211,7 @@ fn can_div_rational_numbers() {
|
||||
#[test]
|
||||
fn can_sub_rational_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational, b: Rational) -> Rational {
|
||||
fn add(a: Cipher<Rational>, b: Cipher<Rational>) -> Cipher<Rational> {
|
||||
a - b
|
||||
}
|
||||
|
||||
@@ -235,10 +239,12 @@ fn can_sub_rational_numbers() {
|
||||
assert_eq!(c, (-6.28).try_into().unwrap());
|
||||
}
|
||||
|
||||
type CipherFractional = Cipher<Fractional::<64>>;
|
||||
|
||||
#[test]
|
||||
fn can_add_fractional_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> {
|
||||
fn add(a: CipherFractional, b: CipherFractional) -> CipherFractional {
|
||||
a + b
|
||||
}
|
||||
|
||||
@@ -252,7 +258,7 @@ fn can_add_fractional_numbers() {
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let add = |a: f64, b: f64| {
|
||||
let do_add = |a: f64, b: f64| {
|
||||
let a_c = runtime
|
||||
.encrypt(Fractional::<64>::try_from(a).unwrap(), &public)
|
||||
.unwrap();
|
||||
@@ -267,27 +273,27 @@ fn can_add_fractional_numbers() {
|
||||
assert_eq!(c, (a + b).try_into().unwrap());
|
||||
};
|
||||
|
||||
add(3.14, 3.14);
|
||||
add(-3.14, 3.14);
|
||||
add(0., 0.);
|
||||
add(7., 3.);
|
||||
add(1e9, 1e9);
|
||||
add(1e-8, 1e-7);
|
||||
add(-3.14, -3.14);
|
||||
add(3.14, -3.14);
|
||||
add(-7., -3.);
|
||||
add(-1e9, -1e9);
|
||||
add(-1e-8, -1e-7);
|
||||
do_add(3.14, 3.14);
|
||||
do_add(-3.14, 3.14);
|
||||
do_add(0., 0.);
|
||||
do_add(7., 3.);
|
||||
do_add(1e9, 1e9);
|
||||
do_add(1e-8, 1e-7);
|
||||
do_add(-3.14, -3.14);
|
||||
do_add(3.14, -3.14);
|
||||
do_add(-7., -3.);
|
||||
do_add(-1e9, -1e9);
|
||||
do_add(-1e-8, -1e-7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_sub_fractional_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> {
|
||||
fn sub(a: Cipher<Fractional<64>>, b: Cipher<Fractional<64>>) -> Cipher<Fractional<64>> {
|
||||
a - b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
let circuit = Compiler::with_circuit(sub)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
@@ -297,7 +303,7 @@ fn can_sub_fractional_numbers() {
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let add = |a: f64, b: f64| {
|
||||
let do_sub = |a: f64, b: f64| {
|
||||
let a_c = runtime
|
||||
.encrypt(Fractional::<64>::try_from(a).unwrap(), &public)
|
||||
.unwrap();
|
||||
@@ -312,23 +318,23 @@ fn can_sub_fractional_numbers() {
|
||||
assert_eq!(c, (a - b).try_into().unwrap());
|
||||
};
|
||||
|
||||
add(3.14, 3.14);
|
||||
add(-3.14, 3.14);
|
||||
add(0., 0.);
|
||||
add(7., 3.);
|
||||
add(1e9, 1e9);
|
||||
add(1e-8, 1e-7);
|
||||
add(-3.14, -3.14);
|
||||
add(3.14, -3.14);
|
||||
add(-7., -3.);
|
||||
add(-1e9, -1e9);
|
||||
add(-1e-8, -1e-7);
|
||||
do_sub(3.14, 3.14);
|
||||
do_sub(-3.14, 3.14);
|
||||
do_sub(0., 0.);
|
||||
do_sub(7., 3.);
|
||||
do_sub(1e9, 1e9);
|
||||
do_sub(1e-8, 1e-7);
|
||||
do_sub(-3.14, -3.14);
|
||||
do_sub(3.14, -3.14);
|
||||
do_sub(-7., -3.);
|
||||
do_sub(-1e9, -1e9);
|
||||
do_sub(-1e-8, -1e-7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mul_fractional_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn mul(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> {
|
||||
fn mul(a: Cipher<Fractional<64>>, b: Cipher<Fractional<64>>) -> Cipher<Fractional<64>> {
|
||||
a * b
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ pub fn circuit_impl(
|
||||
let circuit_name = &input_fn.sig.ident;
|
||||
let vis = &input_fn.vis;
|
||||
let body = &input_fn.block;
|
||||
let attrs = &input_fn.attrs;
|
||||
let inputs = &input_fn.sig.inputs;
|
||||
let ret = &input_fn.sig.output;
|
||||
|
||||
@@ -25,7 +24,7 @@ pub fn circuit_impl(
|
||||
let scheme_type = match attr_params.scheme {
|
||||
Scheme::Bfv => {
|
||||
quote! {
|
||||
SchemeType::Bfv
|
||||
sunscreen_compiler::SchemeType::Bfv
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -90,18 +89,19 @@ pub fn circuit_impl(
|
||||
}
|
||||
});
|
||||
|
||||
proc_macro::TokenStream::from(quote! {
|
||||
#(#attrs)*
|
||||
#vis fn #circuit_name() -> (
|
||||
sunscreen_compiler::SchemeType,
|
||||
impl Fn(&sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation>,
|
||||
sunscreen_compiler::CallSignature
|
||||
) {
|
||||
use std::cell::RefCell;
|
||||
use std::mem::transmute;
|
||||
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{CircuitNode, NumCiphertexts, Type, TypeName, TypeNameInstance}};
|
||||
let circuit_struct_name = Ident::new(&format!("{}_struct", circuit_name), Span::call_site());
|
||||
|
||||
let circuit = proc_macro::TokenStream::from(quote! {
|
||||
#[allow(non_camel_case_types)]
|
||||
#vis struct #circuit_struct_name {
|
||||
}
|
||||
|
||||
impl sunscreen_compiler::CircuitFn for #circuit_struct_name {
|
||||
fn build(&self, params: &sunscreen_compiler::Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation> {
|
||||
use std::cell::RefCell;
|
||||
use std::mem::transmute;
|
||||
use sunscreen_compiler::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{CircuitNode, NumCiphertexts, Type, TypeName, TypeNameInstance}};
|
||||
|
||||
let circuit_builder = |params: &Params| {
|
||||
if SchemeType::Bfv != params.scheme_type {
|
||||
return Err(Error::IncorrectScheme)
|
||||
}
|
||||
@@ -110,9 +110,9 @@ pub fn circuit_impl(
|
||||
let mut context = Context::new(params);
|
||||
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
let internal = | #(#circuit_args)* | -> #circuit_returns {
|
||||
let internal = | #(#circuit_args)* | -> #circuit_returns
|
||||
#body
|
||||
};
|
||||
;
|
||||
|
||||
// Transmute away the lifetime to 'static. So long as we are careful with internal()
|
||||
// panicing, this is safe because we set the context back to none before the funtion
|
||||
@@ -131,7 +131,7 @@ pub fn circuit_impl(
|
||||
Ok(v) => { #catpured_outputs },
|
||||
Err(err) => {
|
||||
INDEX_ARENA.with(|allocator| {
|
||||
unsafe { allocator.borrow_mut().reset() }
|
||||
allocator.borrow_mut().reset()
|
||||
});
|
||||
ctx.swap(&RefCell::new(None));
|
||||
std::panic::resume_unwind(err)
|
||||
@@ -139,19 +139,31 @@ pub fn circuit_impl(
|
||||
};
|
||||
|
||||
INDEX_ARENA.with(|allocator| {
|
||||
unsafe { allocator.borrow_mut().reset() }
|
||||
allocator.borrow_mut().reset()
|
||||
});
|
||||
ctx.swap(&RefCell::new(None));
|
||||
});
|
||||
|
||||
Ok(context.compilation)
|
||||
};
|
||||
}
|
||||
|
||||
#signature;
|
||||
fn signature(&self) -> sunscreen_compiler::CallSignature {
|
||||
use sunscreen_compiler::types::NumCiphertexts;
|
||||
|
||||
(#scheme_type, circuit_builder, signature)
|
||||
#signature
|
||||
}
|
||||
|
||||
fn scheme_type(&self) -> sunscreen_compiler::SchemeType {
|
||||
#scheme_type
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
const #circuit_name: #circuit_struct_name = #circuit_struct_name { };
|
||||
});
|
||||
|
||||
//panic!("{}", circuit);
|
||||
circuit
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -265,13 +277,28 @@ fn capture_outputs(ret: &ReturnType) -> TokenStream {
|
||||
}
|
||||
|
||||
fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
|
||||
let arg_type_names = args.iter().map(|t| {
|
||||
// We have to type alias arguments and returns because they might
|
||||
// be generic and cause an error during invocation.
|
||||
// E.g. Foo<Bar> causes an error when doing Foo<Bar>::func()
|
||||
// because you need :: after Foo.
|
||||
// So we make type aliases and invoke the function on the alias.
|
||||
let arg_type_names = args.iter().enumerate().map(|(i, t)| {
|
||||
let alias = ident("T", i);
|
||||
|
||||
quote! {
|
||||
#t ::type_name(),
|
||||
type #alias = #t;
|
||||
}
|
||||
}).collect::<Vec<TokenStream>>();
|
||||
|
||||
let arg_get_types = arg_type_names.iter().enumerate().map(|(i, _)| {
|
||||
let alias = ident("T", i);
|
||||
|
||||
quote! {
|
||||
#alias::type_name(),
|
||||
}
|
||||
});
|
||||
|
||||
let (return_type_names, return_type_sizes) = match ret {
|
||||
let (return_type_aliases, return_type_names, return_type_sizes) = match ret {
|
||||
ReturnType::Type(_, t) => {
|
||||
let tuple_inners = match &**t {
|
||||
Type::Tuple(t) => t.elems.iter().map(|x| &*x).collect::<Vec<&Type>>(),
|
||||
@@ -288,19 +315,37 @@ fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
|
||||
}
|
||||
};
|
||||
|
||||
let return_type_sizes = tuple_inners.iter().map(|t| {
|
||||
let return_type_aliases = tuple_inners
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let alias = ident("R", i);
|
||||
|
||||
quote! {
|
||||
type #alias = #t;
|
||||
}
|
||||
});
|
||||
|
||||
let return_type_sizes = tuple_inners.iter().enumerate().map(|(i, _)| {
|
||||
let alias = ident("R", i);
|
||||
|
||||
quote! {
|
||||
#t ::NUM_CIPHERTEXTS,
|
||||
#alias ::NUM_CIPHERTEXTS,
|
||||
}
|
||||
});
|
||||
|
||||
let type_names = tuple_inners.iter().map(|t| {
|
||||
let type_names = tuple_inners.iter().enumerate().map(|(i, _)| {
|
||||
let alias = ident("R", i);
|
||||
|
||||
quote! {
|
||||
#t ::type_name(),
|
||||
#alias ::type_name(),
|
||||
}
|
||||
});
|
||||
|
||||
(
|
||||
quote! {
|
||||
#(#return_type_aliases)*
|
||||
},
|
||||
quote! {
|
||||
vec![
|
||||
#(#type_names)*
|
||||
@@ -313,14 +358,23 @@ fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
|
||||
},
|
||||
)
|
||||
}
|
||||
ReturnType::Default => (quote! { vec![] }, quote! { vec![] }),
|
||||
ReturnType::Default => (quote! { }, quote! { vec![] }, quote! { vec![] }),
|
||||
};
|
||||
|
||||
quote! {
|
||||
let signature = sunscreen_compiler::CallSignature {
|
||||
arguments: vec![#(#arg_type_names)*],
|
||||
use sunscreen_compiler::types::TypeName;
|
||||
|
||||
#(#arg_type_names)*
|
||||
#return_type_aliases
|
||||
|
||||
sunscreen_compiler::CallSignature {
|
||||
arguments: vec![#(#arg_get_types)*],
|
||||
returns: #return_type_names,
|
||||
num_ciphertexts: #return_type_sizes,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ident(prefix: &str, i: usize) -> Ident {
|
||||
Ident::new(&format!("{}{}", prefix, i), Span::call_site())
|
||||
}
|
||||
@@ -33,19 +33,27 @@ pub fn derive_typename(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
|
||||
*
|
||||
* # Examples
|
||||
* ```rust
|
||||
* # use sunscreen_compiler::{circuit, types::Unsigned, Params, Context};
|
||||
* # use sunscreen_compiler::{circuit, types::{Cipher, Unsigned}, Params, Context};
|
||||
*
|
||||
* #[circuit(scheme = "bfv")]
|
||||
* fn multiply_add(a: Unsigned, b: Unsigned, c: Unsigned) -> Unsigned {
|
||||
* fn multiply_add(
|
||||
* a: Cipher<Unsigned>,
|
||||
* b: Cipher<Unsigned>,
|
||||
* c: Cipher<Unsigned>
|
||||
* ) -> Cipher<Unsigned> {
|
||||
* a * b + c
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* ```rust
|
||||
* # use sunscreen_compiler::{circuit, types::Unsigned, Params, Context};
|
||||
* # use sunscreen_compiler::{circuit, types::{Cipher, Unsigned}, Params, Context};
|
||||
*
|
||||
* #[circuit(scheme = "bfv")]
|
||||
* fn multi_out(a: Unsigned, b: Unsigned, c: Unsigned) -> (Unsigned, Unsigned) {
|
||||
* fn multi_out(
|
||||
* a: Cipher<Unsigned>,
|
||||
* b: Cipher<Unsigned>,
|
||||
* c: Cipher<Unsigned>
|
||||
* ) -> (Cipher<Unsigned>, Cipher<Unsigned>) {
|
||||
* (a + b, b + c)
|
||||
* }
|
||||
* ```
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use sunscreen_compiler::{
|
||||
types::{TypeName, Unsigned},
|
||||
CallSignature, FrontendCompilation, Params, SchemeType, SecurityLevel, CURRENT_CTX,
|
||||
types::{TypeName, Cipher, Unsigned},
|
||||
CallSignature, FrontendCompilation, Params, SchemeType, SecurityLevel, CURRENT_CTX, CircuitFn
|
||||
};
|
||||
use sunscreen_compiler_macros::circuit;
|
||||
|
||||
@@ -16,6 +16,8 @@ fn get_params() -> Params {
|
||||
}
|
||||
}
|
||||
|
||||
type CipherUnsigned = Cipher<Unsigned>;
|
||||
|
||||
#[test]
|
||||
fn circuit_gets_called() {
|
||||
static mut FOO: u32 = 0;
|
||||
@@ -27,18 +29,16 @@ fn circuit_gets_called() {
|
||||
};
|
||||
}
|
||||
|
||||
let (scheme, compile_fn, signature) = simple_circuit();
|
||||
|
||||
let expected_signature = CallSignature {
|
||||
arguments: vec![],
|
||||
returns: vec![],
|
||||
num_ciphertexts: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(signature, expected_signature);
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(simple_circuit.signature(), expected_signature);
|
||||
assert_eq!(simple_circuit.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let _context = compile_fn(&get_params()).unwrap();
|
||||
let _context = simple_circuit.build(&get_params()).unwrap();
|
||||
|
||||
assert_eq!(unsafe { FOO }, 20);
|
||||
}
|
||||
@@ -58,18 +58,16 @@ fn panicing_circuit_clears_ctx() {
|
||||
}
|
||||
|
||||
let panic_result = std::panic::catch_unwind(|| {
|
||||
let (scheme, compile_fn, signature) = panic_circuit();
|
||||
|
||||
let expected_signature = CallSignature {
|
||||
arguments: vec![],
|
||||
returns: vec![],
|
||||
num_ciphertexts: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(signature, expected_signature);
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(panic_circuit.signature(), expected_signature);
|
||||
assert_eq!(panic_circuit.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let _context = compile_fn(&get_params()).unwrap();
|
||||
let _context = panic_circuit.build(&get_params()).unwrap();
|
||||
});
|
||||
|
||||
assert_eq!(panic_result.is_err(), true);
|
||||
@@ -93,9 +91,7 @@ fn capture_circuit_input_args() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn circuit_with_args(_a: Unsigned, _b: Unsigned, _c: Unsigned, _d: Unsigned) {}
|
||||
|
||||
let (scheme, compile_fn, signature) = circuit_with_args();
|
||||
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let type_name = Unsigned::type_name();
|
||||
|
||||
@@ -110,9 +106,9 @@ fn capture_circuit_input_args() {
|
||||
num_ciphertexts: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(expected_signature, signature);
|
||||
assert_eq!(expected_signature, circuit_with_args.signature());
|
||||
|
||||
let context = compile_fn(&get_params()).unwrap();
|
||||
let context = circuit_with_args.build(&get_params()).unwrap();
|
||||
|
||||
assert_eq!(context.graph.node_count(), 4);
|
||||
}
|
||||
@@ -120,12 +116,10 @@ fn capture_circuit_input_args() {
|
||||
#[test]
|
||||
fn can_add() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn circuit_with_args(a: Unsigned, b: Unsigned, c: Unsigned) {
|
||||
fn circuit_with_args(a: CipherUnsigned, b: CipherUnsigned, c: CipherUnsigned) {
|
||||
let _ = a + b + c;
|
||||
}
|
||||
|
||||
let (scheme, compile_fn, signature) = circuit_with_args();
|
||||
|
||||
let type_name = Unsigned::type_name();
|
||||
|
||||
let expected_signature = CallSignature {
|
||||
@@ -133,10 +127,10 @@ fn can_add() {
|
||||
returns: vec![],
|
||||
num_ciphertexts: vec![],
|
||||
};
|
||||
assert_eq!(signature, expected_signature);
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(circuit_with_args.signature(), expected_signature);
|
||||
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let context: FrontendCompilation = compile_fn(&get_params()).unwrap();
|
||||
let context: FrontendCompilation = circuit_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
|
||||
@@ -184,12 +178,10 @@ fn can_add() {
|
||||
#[test]
|
||||
fn can_mul() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn circuit_with_args(a: Unsigned, b: Unsigned, c: Unsigned) {
|
||||
fn circuit_with_args(a: CipherUnsigned, b: CipherUnsigned, c: CipherUnsigned) {
|
||||
let _ = a * b * c;
|
||||
}
|
||||
|
||||
let (scheme, compile_fn, signature) = circuit_with_args();
|
||||
|
||||
let type_name = Unsigned::type_name();
|
||||
|
||||
let expected_signature = CallSignature {
|
||||
@@ -197,10 +189,10 @@ fn can_mul() {
|
||||
returns: vec![],
|
||||
num_ciphertexts: vec![],
|
||||
};
|
||||
assert_eq!(signature, expected_signature);
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(circuit_with_args.signature(), expected_signature);
|
||||
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let context = compile_fn(&get_params()).unwrap();
|
||||
let context = circuit_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
@@ -247,12 +239,10 @@ fn can_mul() {
|
||||
#[test]
|
||||
fn can_collect_output() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn circuit_with_args(a: Unsigned, b: Unsigned) -> Unsigned {
|
||||
fn circuit_with_args(a: Cipher<Unsigned>, b: CipherUnsigned) -> CipherUnsigned {
|
||||
a + b * a
|
||||
}
|
||||
|
||||
let (scheme, compile_fn, signature) = circuit_with_args();
|
||||
|
||||
let type_name = Unsigned::type_name();
|
||||
|
||||
let expected_signature = CallSignature {
|
||||
@@ -260,10 +250,10 @@ fn can_collect_output() {
|
||||
returns: vec![type_name.clone()],
|
||||
num_ciphertexts: vec![1],
|
||||
};
|
||||
assert_eq!(signature, expected_signature);
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(circuit_with_args.signature(), expected_signature);
|
||||
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let context = compile_fn(&get_params()).unwrap();
|
||||
let context = circuit_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
@@ -315,12 +305,10 @@ fn can_collect_output() {
|
||||
#[test]
|
||||
fn can_collect_multiple_outputs() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn circuit_with_args(a: Unsigned, b: Unsigned) -> (Unsigned, Unsigned) {
|
||||
fn circuit_with_args(a: Cipher<Unsigned>, b: CipherUnsigned) -> (Cipher<Unsigned>, Cipher<Unsigned>) {
|
||||
(a + b * a, a)
|
||||
}
|
||||
|
||||
let (scheme, compile_fn, signature) = circuit_with_args();
|
||||
|
||||
let type_name = Unsigned::type_name();
|
||||
|
||||
let expected_signature = CallSignature {
|
||||
@@ -328,10 +316,10 @@ fn can_collect_multiple_outputs() {
|
||||
returns: vec![type_name.clone(), type_name.clone()],
|
||||
num_ciphertexts: vec![1, 1],
|
||||
};
|
||||
assert_eq!(signature, expected_signature);
|
||||
assert_eq!(scheme, SchemeType::Bfv);
|
||||
assert_eq!(circuit_with_args.signature(), expected_signature);
|
||||
assert_eq!(circuit_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let context = compile_fn(&get_params()).unwrap();
|
||||
let context = circuit_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use sunscreen_compiler::{types::*, *};
|
||||
|
||||
type CipherUnsigned = Cipher<Unsigned>;
|
||||
|
||||
#[test]
|
||||
fn can_encrypt_decrypt() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn foo(a: Unsigned, b: Unsigned) -> Unsigned {
|
||||
fn foo(a: CipherUnsigned, b: CipherUnsigned) -> CipherUnsigned {
|
||||
a + b
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user