mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
fmt
This commit is contained in:
12
Cargo.lock
generated
12
Cargo.lock
generated
@@ -525,6 +525,17 @@ dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sunscreen_frontend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"seal",
|
||||
"sunscreen_backend",
|
||||
"sunscreen_frontend_macros",
|
||||
"sunscreen_frontend_types",
|
||||
"sunscreen_runtime",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sunscreen_frontend_macros"
|
||||
version = "0.1.0"
|
||||
@@ -543,6 +554,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"petgraph",
|
||||
"serde",
|
||||
"sunscreen_backend",
|
||||
"sunscreen_circuit",
|
||||
]
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
members = [
|
||||
"seal",
|
||||
"sunscreen_backend",
|
||||
"sunscreen_frontend",
|
||||
"sunscreen_frontend_types",
|
||||
"sunscreen_frontend_macros",
|
||||
"sunscreen_circuit",
|
||||
|
||||
@@ -194,7 +194,10 @@ impl ToString for NodeInfo {
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
fn new(operation: Operation) -> Self {
|
||||
/**
|
||||
* Creates a new NodeInfo from the given operation.
|
||||
*/
|
||||
pub fn new(operation: Operation) -> Self {
|
||||
Self { operation }
|
||||
}
|
||||
}
|
||||
|
||||
15
sunscreen_frontend/Cargo.toml
Normal file
15
sunscreen_frontend/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "sunscreen_frontend"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
sunscreen_frontend_types = { path = "../sunscreen_frontend_types" }
|
||||
sunscreen_frontend_macros = { path = "../sunscreen_frontend_macros" }
|
||||
sunscreen_backend = {path = "../sunscreen_backend" }
|
||||
|
||||
[dev-dependencies]
|
||||
sunscreen_runtime = { path = "../sunscreen_runtime" }
|
||||
seal = { path = "../seal" }
|
||||
3
sunscreen_frontend/src/lib.rs
Normal file
3
sunscreen_frontend/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub use sunscreen_backend::SchemeType;
|
||||
pub use sunscreen_frontend_macros::circuit;
|
||||
pub use sunscreen_frontend_types::{types, Context};
|
||||
23
sunscreen_frontend/tests/circuit_compilation.rs
Normal file
23
sunscreen_frontend/tests/circuit_compilation.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use sunscreen_frontend::{Context, circuit, types::*};
|
||||
use sunscreen_runtime::run_program_unchecked;
|
||||
use seal::{Evaluator, Context as SealContext, BfvEncryptionParametersBuilder};
|
||||
|
||||
fn setup_seal(circuit: &Circuit) {
|
||||
BfvEncryptionParametersBuilder::new()
|
||||
.
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_compile_and_run_simple_add() {
|
||||
#[circuit]
|
||||
fn add(a: Signed, b: Signed) -> Signed {
|
||||
a + b
|
||||
}
|
||||
|
||||
let circuit = add().compile();
|
||||
|
||||
unsafe {
|
||||
run_program_unchecked(&circuit, inputs: &[Ciphertext], evaluator: &E, relin_keys: Option<RelinearizationKeys>, galois_keys: Option<GaloisKeys>)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -141,15 +141,13 @@ pub fn circuit(
|
||||
let capture_outputs = match ret {
|
||||
ReturnType::Type(_, t) => {
|
||||
let tuple_inners = match &**t {
|
||||
Type::Tuple(t) => {
|
||||
t.elems.iter().map(|x| &*x).collect::<Vec<&Type>>()
|
||||
},
|
||||
Type::Tuple(t) => t.elems.iter().map(|x| &*x).collect::<Vec<&Type>>(),
|
||||
Type::Paren(t) => {
|
||||
vec![&*t.elem]
|
||||
},
|
||||
}
|
||||
Type::Path(_) => {
|
||||
vec![&**t]
|
||||
},
|
||||
}
|
||||
_ => {
|
||||
return proc_macro::TokenStream::from(quote! {
|
||||
compile_error!("Circuits must return a single Cipthertext or a tuple of Ciphertexts");
|
||||
@@ -162,17 +160,21 @@ pub fn circuit(
|
||||
v.output();
|
||||
}
|
||||
} else {
|
||||
tuple_inners.iter().enumerate().map(|(i, t)| {
|
||||
let index = Index::from(i);
|
||||
|
||||
quote_spanned! {t.span() =>
|
||||
v.#index.output();
|
||||
}
|
||||
}).collect()
|
||||
tuple_inners
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let index = Index::from(i);
|
||||
|
||||
quote_spanned! {t.span() =>
|
||||
v.#index.output();
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
},
|
||||
ReturnType::Default => {
|
||||
quote! { }
|
||||
}
|
||||
ReturnType::Default => {
|
||||
quote! {}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -204,7 +206,7 @@ pub fn circuit(
|
||||
|
||||
match panic_res {
|
||||
Ok(v) => { #capture_outputs },
|
||||
Err(err) => {
|
||||
Err(err) => {
|
||||
ctx.swap(&RefCell::new(None));
|
||||
std::panic::resume_unwind(err)
|
||||
}
|
||||
|
||||
@@ -249,45 +249,45 @@ fn can_collect_output() {
|
||||
let context = circuit_with_args();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"Multiply",
|
||||
"Add",
|
||||
"Output"
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"Multiply",
|
||||
"Add",
|
||||
"Output"
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
1,
|
||||
2,
|
||||
"Left"
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
1,
|
||||
2,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
0,
|
||||
2,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
0,
|
||||
3,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
2,
|
||||
3,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
3,
|
||||
4,
|
||||
"Unary"
|
||||
]
|
||||
[
|
||||
0,
|
||||
2,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
0,
|
||||
3,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
2,
|
||||
3,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
3,
|
||||
4,
|
||||
"Unary"
|
||||
]
|
||||
}
|
||||
});
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(context, serde_json::from_value(expected).unwrap());
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
petgraph = "0.6.0"
|
||||
sunscreen_backend = { path = "../sunscreen_backend" }
|
||||
sunscreen_circuit = { path = "../sunscreen_circuit" }
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
@@ -1,4 +1,4 @@
|
||||
mod types;
|
||||
pub mod types;
|
||||
|
||||
use std::cell::RefCell;
|
||||
|
||||
@@ -9,11 +9,17 @@ use petgraph::{
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use types::*;
|
||||
use sunscreen_backend::compile_inplace;
|
||||
use sunscreen_circuit::{
|
||||
Circuit, EdgeInfo, Literal as CircuitLiteral, NodeInfo, Operation as CircuitOperation,
|
||||
OuterLiteral as CircuitOuterLiteral,
|
||||
};
|
||||
|
||||
pub use sunscreen_circuit::SchemeType;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
pub enum Literal {
|
||||
U64(u64)
|
||||
U64(u64),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
@@ -25,14 +31,14 @@ pub enum Operation {
|
||||
RotateLeft,
|
||||
RotateRight,
|
||||
SwapRows,
|
||||
Output
|
||||
Output,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
pub enum OperandInfo {
|
||||
Left,
|
||||
Right,
|
||||
Unary
|
||||
Unary,
|
||||
}
|
||||
|
||||
pub trait Value {
|
||||
@@ -97,16 +103,24 @@ impl Context {
|
||||
pub fn add_literal(&mut self, literal: Literal) -> NodeIndex {
|
||||
// See if we already have a node for the given literal. If so, just return it.
|
||||
// If not, make a new one.
|
||||
let existing_literal = self.graph.node_indices().filter_map(|i| {
|
||||
match &self.graph[i] {
|
||||
Operation::Literal(x) => if *x == literal { Some(i) } else { None },
|
||||
_ => None
|
||||
}
|
||||
}).nth(0);
|
||||
let existing_literal = self
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|i| match &self.graph[i] {
|
||||
Operation::Literal(x) => {
|
||||
if *x == literal {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.nth(0);
|
||||
|
||||
match existing_literal {
|
||||
Some(x) => x,
|
||||
None => self.graph.add_node(Operation::Literal(literal))
|
||||
None => self.graph.add_node(Operation::Literal(literal)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,4 +135,34 @@ impl Context {
|
||||
pub fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
|
||||
self.add_1_input(Operation::Output, i)
|
||||
}
|
||||
|
||||
pub fn compile(&self) -> Circuit {
|
||||
let mut circuit = Circuit::new(SchemeType::Bfv);
|
||||
|
||||
let mapped_graph = self.graph.map(
|
||||
|id, n| match n {
|
||||
Operation::Add => NodeInfo::new(CircuitOperation::Add),
|
||||
Operation::InputCiphertext => {
|
||||
NodeInfo::new(CircuitOperation::InputCiphertext(id.index()))
|
||||
}
|
||||
Operation::Literal(Literal::U64(x)) => NodeInfo::new(CircuitOperation::Literal(
|
||||
CircuitOuterLiteral::Scalar(CircuitLiteral::U64(*x)),
|
||||
)),
|
||||
Operation::Multiply => NodeInfo::new(CircuitOperation::Multiply),
|
||||
Operation::Output => NodeInfo::new(CircuitOperation::OutputCiphertext),
|
||||
Operation::RotateLeft => NodeInfo::new(CircuitOperation::ShiftLeft),
|
||||
Operation::RotateRight => NodeInfo::new(CircuitOperation::ShiftRight),
|
||||
Operation::SwapRows => NodeInfo::new(CircuitOperation::SwapRows),
|
||||
},
|
||||
|_, e| match e {
|
||||
OperandInfo::Left => EdgeInfo::LeftOperand,
|
||||
OperandInfo::Right => EdgeInfo::RightOperand,
|
||||
OperandInfo::Unary => EdgeInfo::UnaryOperand,
|
||||
},
|
||||
);
|
||||
|
||||
circuit.graph = StableGraph::from(mapped_graph);
|
||||
|
||||
compile_inplace(circuit)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::ops::{Add, Mul, Shl, Shr};
|
||||
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Value, Context, CURRENT_CTX, Literal};
|
||||
use crate::{Context, Literal, Value, CURRENT_CTX};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize)]
|
||||
struct LiteralRef {
|
||||
@@ -12,10 +12,9 @@ struct LiteralRef {
|
||||
|
||||
impl LiteralRef {
|
||||
fn new(v: Literal) -> Self {
|
||||
with_ctx(|ctx|
|
||||
Self {
|
||||
with_ctx(|ctx| Self {
|
||||
id: ctx.add_literal(v),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +32,7 @@ impl Value for Signed {
|
||||
|
||||
fn output(&self) -> Self {
|
||||
with_ctx(|ctx| Self {
|
||||
id: ctx.add_output(self.id)
|
||||
id: ctx.add_output(self.id),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -67,7 +66,7 @@ impl Shl<u64> for Signed {
|
||||
let l = LiteralRef::new(Literal::U64(n));
|
||||
|
||||
with_ctx(|ctx| Self {
|
||||
id: ctx.add_rotate_left(self.id, l.id)
|
||||
id: ctx.add_rotate_left(self.id, l.id),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -79,7 +78,7 @@ impl Shr<u64> for Signed {
|
||||
let l = LiteralRef::new(Literal::U64(n));
|
||||
|
||||
with_ctx(|ctx| Self {
|
||||
id: ctx.add_rotate_right(self.id, l.id)
|
||||
id: ctx.add_rotate_right(self.id, l.id),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user