This commit is contained in:
Rick Weber
2021-11-30 15:23:57 -08:00
parent feae81a051
commit 2875b150a7
11 changed files with 177 additions and 74 deletions

12
Cargo.lock generated
View File

@@ -525,6 +525,17 @@ dependencies = [
"serde_json",
]
[[package]]
name = "sunscreen_frontend"
version = "0.1.0"
dependencies = [
"seal",
"sunscreen_backend",
"sunscreen_frontend_macros",
"sunscreen_frontend_types",
"sunscreen_runtime",
]
[[package]]
name = "sunscreen_frontend_macros"
version = "0.1.0"
@@ -543,6 +554,7 @@ version = "0.1.0"
dependencies = [
"petgraph",
"serde",
"sunscreen_backend",
"sunscreen_circuit",
]

View File

@@ -2,6 +2,7 @@
members = [
"seal",
"sunscreen_backend",
"sunscreen_frontend",
"sunscreen_frontend_types",
"sunscreen_frontend_macros",
"sunscreen_circuit",

View File

@@ -194,7 +194,10 @@ impl ToString for NodeInfo {
}
impl NodeInfo {
fn new(operation: Operation) -> Self {
/**
* Creates a new NodeInfo from the given operation.
*/
pub fn new(operation: Operation) -> Self {
Self { operation }
}
}

View File

@@ -0,0 +1,15 @@
[package]
name = "sunscreen_frontend"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sunscreen_frontend_types = { path = "../sunscreen_frontend_types" }
sunscreen_frontend_macros = { path = "../sunscreen_frontend_macros" }
sunscreen_backend = {path = "../sunscreen_backend" }
[dev-dependencies]
sunscreen_runtime = { path = "../sunscreen_runtime" }
seal = { path = "../seal" }

View File

@@ -0,0 +1,3 @@
pub use sunscreen_backend::SchemeType;
pub use sunscreen_frontend_macros::circuit;
pub use sunscreen_frontend_types::{types, Context};

View File

@@ -0,0 +1,23 @@
use sunscreen_frontend::{Context, circuit, types::*};
use sunscreen_runtime::run_program_unchecked;
use seal::{Evaluator, Context as SealContext, BfvEncryptionParametersBuilder};
fn setup_seal(circuit: &Circuit) {
BfvEncryptionParametersBuilder::new()
.
}
#[test]
fn can_compile_and_run_simple_add() {
#[circuit]
fn add(a: Signed, b: Signed) -> Signed {
a + b
}
let circuit = add().compile();
unsafe {
run_program_unchecked(&circuit, inputs: &[Ciphertext], evaluator: &E, relin_keys: Option<RelinearizationKeys>, galois_keys: Option<GaloisKeys>)
}
}

View File

@@ -141,15 +141,13 @@ pub fn circuit(
let capture_outputs = match ret {
ReturnType::Type(_, t) => {
let tuple_inners = match &**t {
Type::Tuple(t) => {
t.elems.iter().map(|x| &*x).collect::<Vec<&Type>>()
},
Type::Tuple(t) => t.elems.iter().map(|x| &*x).collect::<Vec<&Type>>(),
Type::Paren(t) => {
vec![&*t.elem]
},
}
Type::Path(_) => {
vec![&**t]
},
}
_ => {
return proc_macro::TokenStream::from(quote! {
compile_error!("Circuits must return a single Cipthertext or a tuple of Ciphertexts");
@@ -162,17 +160,21 @@ pub fn circuit(
v.output();
}
} else {
tuple_inners.iter().enumerate().map(|(i, t)| {
let index = Index::from(i);
quote_spanned! {t.span() =>
v.#index.output();
}
}).collect()
tuple_inners
.iter()
.enumerate()
.map(|(i, t)| {
let index = Index::from(i);
quote_spanned! {t.span() =>
v.#index.output();
}
})
.collect()
}
},
ReturnType::Default => {
quote! { }
}
ReturnType::Default => {
quote! {}
}
};
@@ -204,7 +206,7 @@ pub fn circuit(
match panic_res {
Ok(v) => { #capture_outputs },
Err(err) => {
Err(err) => {
ctx.swap(&RefCell::new(None));
std::panic::resume_unwind(err)
}

View File

@@ -249,45 +249,45 @@ fn can_collect_output() {
let context = circuit_with_args();
let expected = json!({
"graph": {
"nodes": [
"InputCiphertext",
"InputCiphertext",
"Multiply",
"Add",
"Output"
"graph": {
"nodes": [
"InputCiphertext",
"InputCiphertext",
"Multiply",
"Add",
"Output"
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
1,
2,
"Left"
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
1,
2,
"Left"
],
[
0,
2,
"Right"
],
[
0,
3,
"Left"
],
[
2,
3,
"Right"
],
[
3,
4,
"Unary"
]
[
0,
2,
"Right"
],
[
0,
3,
"Left"
],
[
2,
3,
"Right"
],
[
3,
4,
"Unary"
]
}
});
]
}
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
}

View File

@@ -7,5 +7,6 @@ edition = "2021"
[dependencies]
petgraph = "0.6.0"
sunscreen_backend = { path = "../sunscreen_backend" }
sunscreen_circuit = { path = "../sunscreen_circuit" }
serde = { version = "1.0.130", features = ["derive"] }

View File

@@ -1,4 +1,4 @@
mod types;
pub mod types;
use std::cell::RefCell;
@@ -9,11 +9,17 @@ use petgraph::{
};
use serde::{Deserialize, Serialize};
pub use types::*;
use sunscreen_backend::compile_inplace;
use sunscreen_circuit::{
Circuit, EdgeInfo, Literal as CircuitLiteral, NodeInfo, Operation as CircuitOperation,
OuterLiteral as CircuitOuterLiteral,
};
pub use sunscreen_circuit::SchemeType;
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum Literal {
U64(u64)
U64(u64),
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
@@ -25,14 +31,14 @@ pub enum Operation {
RotateLeft,
RotateRight,
SwapRows,
Output
Output,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum OperandInfo {
Left,
Right,
Unary
Unary,
}
pub trait Value {
@@ -97,16 +103,24 @@ impl Context {
pub fn add_literal(&mut self, literal: Literal) -> NodeIndex {
// See if we already have a node for the given literal. If so, just return it.
// If not, make a new one.
let existing_literal = self.graph.node_indices().filter_map(|i| {
match &self.graph[i] {
Operation::Literal(x) => if *x == literal { Some(i) } else { None },
_ => None
}
}).nth(0);
let existing_literal = self
.graph
.node_indices()
.filter_map(|i| match &self.graph[i] {
Operation::Literal(x) => {
if *x == literal {
Some(i)
} else {
None
}
}
_ => None,
})
.nth(0);
match existing_literal {
Some(x) => x,
None => self.graph.add_node(Operation::Literal(literal))
None => self.graph.add_node(Operation::Literal(literal)),
}
}
@@ -121,4 +135,34 @@ impl Context {
pub fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::Output, i)
}
pub fn compile(&self) -> Circuit {
let mut circuit = Circuit::new(SchemeType::Bfv);
let mapped_graph = self.graph.map(
|id, n| match n {
Operation::Add => NodeInfo::new(CircuitOperation::Add),
Operation::InputCiphertext => {
NodeInfo::new(CircuitOperation::InputCiphertext(id.index()))
}
Operation::Literal(Literal::U64(x)) => NodeInfo::new(CircuitOperation::Literal(
CircuitOuterLiteral::Scalar(CircuitLiteral::U64(*x)),
)),
Operation::Multiply => NodeInfo::new(CircuitOperation::Multiply),
Operation::Output => NodeInfo::new(CircuitOperation::OutputCiphertext),
Operation::RotateLeft => NodeInfo::new(CircuitOperation::ShiftLeft),
Operation::RotateRight => NodeInfo::new(CircuitOperation::ShiftRight),
Operation::SwapRows => NodeInfo::new(CircuitOperation::SwapRows),
},
|_, e| match e {
OperandInfo::Left => EdgeInfo::LeftOperand,
OperandInfo::Right => EdgeInfo::RightOperand,
OperandInfo::Unary => EdgeInfo::UnaryOperand,
},
);
circuit.graph = StableGraph::from(mapped_graph);
compile_inplace(circuit)
}
}

View File

@@ -1,9 +1,9 @@
use std::ops::{Add, Mul, Shl, Shr};
use petgraph::stable_graph::NodeIndex;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use crate::{Value, Context, CURRENT_CTX, Literal};
use crate::{Context, Literal, Value, CURRENT_CTX};
#[derive(Clone, Copy, Serialize, Deserialize)]
struct LiteralRef {
@@ -12,10 +12,9 @@ struct LiteralRef {
impl LiteralRef {
fn new(v: Literal) -> Self {
with_ctx(|ctx|
Self {
with_ctx(|ctx| Self {
id: ctx.add_literal(v),
})
})
}
}
@@ -33,7 +32,7 @@ impl Value for Signed {
fn output(&self) -> Self {
with_ctx(|ctx| Self {
id: ctx.add_output(self.id)
id: ctx.add_output(self.id),
})
}
}
@@ -67,7 +66,7 @@ impl Shl<u64> for Signed {
let l = LiteralRef::new(Literal::U64(n));
with_ctx(|ctx| Self {
id: ctx.add_rotate_left(self.id, l.id)
id: ctx.add_rotate_left(self.id, l.id),
})
}
}
@@ -79,7 +78,7 @@ impl Shr<u64> for Signed {
let l = LiteralRef::new(Literal::U64(n));
with_ctx(|ctx| Self {
id: ctx.add_rotate_right(self.id, l.id)
id: ctx.add_rotate_right(self.id, l.id),
})
}
}