Merge pull request #5 from Sunscreen-tech/interface

Interface
This commit is contained in:
rickwebiii
2021-12-06 20:18:52 -08:00
committed by GitHub
25 changed files with 485 additions and 193 deletions

6
Cargo.lock generated
View File

@@ -507,7 +507,7 @@ name = "simple_multiply"
version = "0.1.0"
dependencies = [
"seal",
"sunscreen_frontend",
"sunscreen_compiler",
"sunscreen_runtime",
]
@@ -540,7 +540,7 @@ dependencies = [
]
[[package]]
name = "sunscreen_frontend"
name = "sunscreen_compiler"
version = "0.1.0"
dependencies = [
"env_logger 0.9.0",
@@ -558,7 +558,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde_json",
"sunscreen_frontend",
"sunscreen_compiler",
"sunscreen_frontend_types",
"syn",
"trybuild",

View File

@@ -4,7 +4,7 @@ members = [
"examples/simple_multiply",
"seal",
"sunscreen_backend",
"sunscreen_frontend",
"sunscreen_compiler",
"sunscreen_frontend_types",
"sunscreen_frontend_macros",
"sunscreen_circuit",

View File

@@ -7,5 +7,5 @@ edition = "2021"
[dependencies]
seal = { path = "../../seal" }
sunscreen_frontend = { path = "../../sunscreen_frontend" }
sunscreen_compiler = { path = "../../sunscreen_compiler" }
sunscreen_runtime = { path = "../../sunscreen_runtime" }

View File

@@ -1,5 +1,5 @@
use seal::BFVScalarEncoder;
use sunscreen_frontend::{circuit, types::Signed, Compiler, Params, PlainModulusConstraint};
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, Params, PlainModulusConstraint};
use sunscreen_runtime::RuntimeBuilder;
/**
@@ -8,15 +8,15 @@ use sunscreen_runtime::RuntimeBuilder;
* the result. Circuits may take any number of parameters and return either a single result
* or a tuple of results.
*
* Currently, the Signed type is the only legal type in circuit parameters and return values,
* Currently, the Unsigned type is the only legal type in circuit parameters and return values,
* which serves as a placeholder that allows the compiler to build up the circuit. Don't attach
* much meaning to it in its current form; this example in fact uses unsigned values!
*
* One takes a circuit and passes them to the compiler, which transforms it into a form
* suitable for execution.
*/
#[circuit]
fn simple_multiply(a: Signed, b: Signed) -> Signed {
#[circuit(scheme = "bfv")]
fn simple_multiply(a: Unsigned, b: Unsigned) -> Unsigned {
a * b
}

View File

@@ -197,9 +197,7 @@ impl CoefficientModulus {
Ok(coefficients
.iter()
.map(|handle| {
Modulus { handle: *handle }
})
.map(|handle| Modulus { handle: *handle })
.collect())
}

View File

@@ -30,7 +30,7 @@ use TransformNodeIndex::*;
use std::collections::HashSet;
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
/**
* There 3 primary FHE schemes in use today: BFV, CKKS, and TFHE. BFV is generally the best choice for algorithms
* requiring exact arithmetic on integers easily expressed as addition and multiplication. CKKS is generally best

View File

@@ -1,5 +1,5 @@
[package]
name = "sunscreen_frontend"
name = "sunscreen_compiler"
version = "0.1.0"
edition = "2021"

View File

@@ -0,0 +1,5 @@
pub use sunscreen_frontend_macros::circuit;
pub use sunscreen_frontend_types::{
types, Compiler, Context, Error, FrontendCompilation, Params, PlainModulusConstraint, Result,
SchemeType, SecurityLevel, Value, CURRENT_CTX,
};

View File

@@ -1,4 +1,4 @@
use sunscreen_frontend::{circuit, types::*, Compiler, Params, PlainModulusConstraint};
use sunscreen_compiler::{circuit, types::*, Compiler, Params, PlainModulusConstraint};
use sunscreen_runtime::RuntimeBuilder;
use seal::BFVScalarEncoder;
@@ -7,8 +7,8 @@ use seal::BFVScalarEncoder;
fn can_compile_and_run_simple_add() {
let _ = env_logger::try_init();
#[circuit]
fn c(a: Signed, b: Signed) -> Signed {
#[circuit(scheme = "bfv")]
fn c(a: Unsigned, b: Unsigned) -> Unsigned {
a + b
}
@@ -47,8 +47,8 @@ fn can_compile_and_run_simple_add() {
fn can_compile_and_run_simple_mul() {
let _ = env_logger::try_init();
#[circuit]
fn c(a: Signed, b: Signed) -> Signed {
#[circuit(scheme = "bfv")]
fn c(a: Unsigned, b: Unsigned) -> Unsigned {
a * b
}
@@ -89,8 +89,8 @@ fn can_compile_and_run_simple_mul() {
fn can_compile_and_run_mul_reduction() {
let _ = env_logger::try_init();
#[circuit]
fn c(a: Signed, b: Signed, c: Signed, d: Signed) -> Signed {
#[circuit(scheme = "bfv")]
fn c(a: Unsigned, b: Unsigned, c: Unsigned, d: Unsigned) -> Unsigned {
a * b * c * d
}
@@ -137,8 +137,8 @@ fn can_compile_and_run_mul_reduction() {
fn can_compile_and_run_add_reduction() {
let _ = env_logger::try_init();
#[circuit]
fn c(a: Signed, b: Signed, c: Signed, d: Signed) -> Signed {
#[circuit(scheme = "bfv")]
fn c(a: Unsigned, b: Unsigned, c: Unsigned, d: Unsigned) -> Unsigned {
a + b + c + d
}

View File

@@ -1,5 +0,0 @@
pub use sunscreen_frontend_macros::circuit;
pub use sunscreen_frontend_types::{
types, Compiler, Context, Params, PlainModulusConstraint, SchemeType, SecurityLevel, Value,
CURRENT_CTX,
};

View File

@@ -17,4 +17,4 @@ sunscreen_frontend_types = { path = "../sunscreen_frontend_types" }
[dev-dependencies]
trybuild = "1.0.52"
serde_json = "1.0.72"
sunscreen_frontend = { path = "../sunscreen_frontend" }
sunscreen_compiler = { path = "../sunscreen_compiler" }

View File

@@ -0,0 +1,13 @@
#[derive(Debug)]
pub enum Error {
SynError(syn::Error),
UnknownScheme(String),
}
impl From<syn::Error> for Error {
fn from(err: syn::Error) -> Self {
Self::SynError(err)
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,91 @@
use super::case::Scheme;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Error, Expr, Lit, Result, Token,
};
use crate::internals::symbols::VALUE_KEYS;
use std::collections::HashMap;
pub struct Attrs {
pub scheme: Scheme,
}
impl Parse for Attrs {
fn parse(input: ParseStream) -> Result<Self> {
// parses a,b,c, or a,b,c where a,b and c are Indent
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
let mut attrs: HashMap<String, Option<String>> = HashMap::new();
for var in &vars {
match var {
Expr::Assign(a) => {
let key = match &*a.left {
Expr::Path(p) =>
p.path.get_ident().ok_or(Error::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
_ => { return Err(Error::new_spanned(&a.left, "Key should be a plain identifier")) }
};
let value = match &*a.right {
Expr::Lit(l) => match &l.lit {
Lit::Str(s) => s.value(),
_ => return Err(Error::new_spanned(l, "Literal should be a string")),
},
_ => {
return Err(Error::new_spanned(
&a.right,
"Value should be a string literal",
))
}
};
if !VALUE_KEYS.iter().any(|x| *x == key) {
return Err(Error::new_spanned(a, "Unknown key".to_owned()));
}
attrs.insert(key, Some(value));
}
Expr::Path(p) => {
let key = p
.path
.get_ident()
.ok_or(Error::new_spanned(p, "Unknown identifier"))?
.to_string();
if !VALUE_KEYS.iter().any(|x| *x == key) {
return Err(Error::new_spanned(p, "Unknown key"));
}
attrs.insert(key, None);
}
_ => {
return Err(Error::new_spanned(
var,
"Expected `key = \"value\"` or `key`",
))
}
}
}
let scheme_type = attrs
.get("scheme")
.ok_or(Error::new_spanned(
&vars,
"required `scheme` is missing".to_owned(),
))?
.as_ref()
.ok_or(Error::new_spanned(
&vars,
"`scheme` requires a value".to_owned(),
))?;
Ok(Self {
scheme: Scheme::parse(&scheme_type).map_err(|_e| {
Error::new_spanned(vars, format!("Unknown variant {}", &scheme_type))
})?,
})
}
}

View File

@@ -0,0 +1,16 @@
use self::Scheme::*;
use crate::error::*;
#[derive(Copy, Clone, PartialEq)]
pub enum Scheme {
Bfv,
}
impl Scheme {
pub fn parse(s: &str) -> Result<Self> {
Ok(match s {
"bfv" => Bfv,
_ => Err(Error::UnknownScheme(s.to_owned()))?,
})
}
}

View File

@@ -0,0 +1,4 @@
// Following the pattern in serde (https://github.com/serde-rs)
pub mod attr;
pub mod case;
pub mod symbols;

View File

@@ -0,0 +1 @@
pub const VALUE_KEYS: &[&str] = &["scheme"];

View File

@@ -6,6 +6,11 @@
extern crate proc_macro;
mod error;
mod internals;
use crate::internals::{attr::Attrs, case::Scheme};
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::{
@@ -29,7 +34,7 @@ pub fn derive_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let expanded = quote! {
// The generated impl.
impl #impl_generics sunscreen_frontend::Value for #name #ty_generics #where_clause {
impl #impl_generics sunscreen_compiler::Value for #name #ty_generics #where_clause {
fn new(id: usize) {
#new
}
@@ -45,7 +50,7 @@ fn add_trait_bounds(mut generics: Generics) -> Generics {
if let GenericParam::Type(ref mut type_param) = *param {
type_param
.bounds
.push(parse_quote!(sunscreen_frontend::Value));
.push(parse_quote!(sunscreen_compiler::Value));
}
}
generics
@@ -101,26 +106,30 @@ fn new_body(data: &Data) -> TokenStream {
* This function gets run by the compiler to build up the circuit you specify and does not
* directly or eagerly perform homomorphic operations.
*
* # Parameters
* * `scheme` (required): Designates the scheme this circuit uses. Today, this must be `"bfv"`.
*
* # Examples
* ```rust
* # use sunscreen_frontend_types::{types::Signed, Params, Context};
* # use sunscreen_frontend_macros::{circuit};
*
* #[circuit]
* fn multiply_add(a: Signed, b: Signed, c: Signed) -> Signed {
* # use sunscreen_compiler::{circuit, types::Unsigned, Params, Context};
*
* #[circuit(scheme = "bfv")]
* fn multiply_add(a: Unsigned, b: Unsigned, c: Unsigned) -> Unsigned {
* a * b + c
* }
* ```
*
* ```rust
* #[circuit]
* fn multi_out(a: Signed, b: Signed, c: Signed) -> (Signed, Signed) {
* # use sunscreen_compiler::{circuit, types::Unsigned, Params, Context};
*
* #[circuit(scheme = "bfv")]
* fn multi_out(a: Unsigned, b: Unsigned, c: Unsigned) -> (Unsigned, Unsigned) {
* (a + b, b + c)
* }
* ```
*/
pub fn circuit(
_metadata: proc_macro::TokenStream,
metadata: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let input_fn = parse_macro_input!(input as ItemFn);
@@ -134,6 +143,16 @@ pub fn circuit(
let mut unwrapped_inputs = vec![];
let attr_params = parse_macro_input!(metadata as Attrs);
let scheme_type = match attr_params.scheme {
Scheme::Bfv => {
quote! {
SchemeType::Bfv
}
}
};
for i in inputs {
let input_type = match i {
FnArg::Receiver(_) => {
@@ -154,19 +173,31 @@ pub fn circuit(
unwrapped_inputs.push(input_type);
}
let circuit_args = unwrapped_inputs
.iter()
.map(|i| {
let name = &i.pat;
let ty = &i.ty;
quote! {
#name: CircuitNode<#ty>,
}
})
.collect::<Vec<TokenStream>>();
let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| {
let id = Ident::new(&format!("c_{}", i), Span::call_site());
let ty = &t.ty;
quote_spanned! {t.span() =>
let #id = #ty ::new();
let #id: CircuitNode<#ty> = CircuitNode::input();
}
});
let args = unwrapped_inputs.iter().enumerate().map(|(i, _)| {
let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| {
let id = Ident::new(&format!("c_{}", i), Span::call_site());
quote! {
quote_spanned! {t.span() =>
#id
}
});
@@ -213,43 +244,50 @@ pub fn circuit(
proc_macro::TokenStream::from(quote! {
#(#attrs)*
#vis fn #circuit_name(params: &Params) -> sunscreen_frontend::Context {
#vis fn #circuit_name() -> (sunscreen_compiler::SchemeType, impl Fn(&Params) -> sunscreen_compiler::Result<sunscreen_compiler::FrontendCompilation>) {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen_frontend::{CURRENT_CTX, Context, Params, SchemeType, Value};
use sunscreen_compiler::{CURRENT_CTX, Context, Error, Result, Params, SchemeType, Value, types::CircuitNode};
// TODO: Other schemes.
let mut context = Context::new(SchemeType::Bfv);
let mut cur_id = 0usize;
CURRENT_CTX.with(|ctx| {
fn internal(#inputs) #ret {
#body
let circuit_builder = |params: &Params| {
if SchemeType::Bfv != params.scheme_type {
return Err(Error::IncorrectScheme)
}
// Transmute away the lifetime to 'static. So long as we are careful with internal()
// panicing, this is safe because we set the context back to none before the funtion
// returns.
ctx.swap(&RefCell::new(Some(unsafe { transmute(&context) })));
// TODO: Other schemes.
let mut context = Context::new(params);
#(#var_decl)*
CURRENT_CTX.with(|ctx| {
let internal = | #(#circuit_args)* | {
#body
};
let panic_res = std::panic::catch_unwind(|| {
internal(#(#args),*)
// Transmute away the lifetime to 'static. So long as we are careful with internal()
// panicing, this is safe because we set the context back to none before the funtion
// returns.
ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) })));
#(#var_decl)*
let panic_res = std::panic::catch_unwind(|| {
internal(#(#args),*)
});
match panic_res {
Ok(v) => { #capture_outputs },
Err(err) => {
ctx.swap(&RefCell::new(None));
std::panic::resume_unwind(err)
}
};
ctx.swap(&RefCell::new(None));
});
match panic_res {
Ok(v) => { #capture_outputs },
Err(err) => {
ctx.swap(&RefCell::new(None));
std::panic::resume_unwind(err)
}
};
Ok(context.compilation)
};
ctx.swap(&RefCell::new(None));
});
context
(#scheme_type, circuit_builder)
}
})
}

View File

@@ -1,6 +1,6 @@
use sunscreen_frontend_macros::circuit;
use sunscreen_frontend_types::{
types::Signed, Params, SchemeType, SecurityLevel, CURRENT_CTX
types::Unsigned, FrontendCompilation, Params, SchemeType, SecurityLevel, CURRENT_CTX,
};
use serde_json::json;
@@ -19,21 +19,25 @@ fn get_params() -> Params {
fn circuit_gets_called() {
static mut FOO: u32 = 0;
#[circuit]
#[circuit(scheme = "bfv")]
fn simple_circuit() {
unsafe {
FOO = 20;
};
}
simple_circuit(&get_params());
let (scheme, compile_fn) = simple_circuit();
assert_eq!(scheme, SchemeType::Bfv);
let _context = compile_fn(&get_params()).unwrap();
assert_eq!(unsafe { FOO }, 20);
}
#[test]
fn panicing_circuit_clears_ctx() {
#[circuit]
#[circuit(scheme = "bfv")]
fn panic_circuit() {
CURRENT_CTX.with(|ctx| {
let old = ctx.take();
@@ -46,7 +50,11 @@ fn panicing_circuit_clears_ctx() {
}
let panic_result = std::panic::catch_unwind(|| {
panic_circuit(&get_params());
let (scheme, compile_fn) = panic_circuit();
assert_eq!(scheme, SchemeType::Bfv);
let _context = compile_fn(&get_params()).unwrap();
});
assert_eq!(panic_result.is_err(), true);
@@ -67,22 +75,30 @@ fn compile_failures() {
#[test]
fn capture_circuit_input_args() {
#[circuit]
fn circuit_with_args(_a: Signed, _b: Signed, _c: Signed, _d: Signed) {}
#[circuit(scheme = "bfv")]
fn circuit_with_args(_a: Unsigned, _b: Unsigned, _c: Unsigned, _d: Unsigned) {}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
assert_eq!(context.graph.node_count(), 4);
}
#[test]
fn can_add() {
#[circuit]
fn circuit_with_args(a: Signed, b: Signed, c: Signed) {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned, c: Unsigned) {
let _ = a + b + c;
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context: FrontendCompilation = compile_fn(&get_params()).unwrap();
let expected = json!({
@@ -119,20 +135,26 @@ fn can_add() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}
#[test]
fn can_mul() {
#[circuit]
fn circuit_with_args(a: Signed, b: Signed, c: Signed) {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned, c: Unsigned) {
let _ = a * b * c;
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -168,20 +190,26 @@ fn can_mul() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}
#[test]
fn can_rotate_left() {
#[circuit]
fn circuit_with_args(a: Signed) {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned) {
let _ = a << 4;
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -209,20 +237,26 @@ fn can_rotate_left() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}
#[test]
fn can_rotate_right() {
#[circuit]
fn circuit_with_args(a: Signed) {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned) {
let _ = a >> 4;
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context: FrontendCompilation = compile_fn(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -250,20 +284,26 @@ fn can_rotate_right() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}
#[test]
fn can_collect_output() {
#[circuit]
fn circuit_with_args(a: Signed, b: Signed) -> Signed {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned) -> Unsigned {
a + b * a
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -304,20 +344,26 @@ fn can_collect_output() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}
#[test]
fn can_collect_multiple_outputs() {
#[circuit]
fn circuit_with_args(a: Signed, b: Signed) -> (Signed, Signed) {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned, b: Unsigned) -> (Unsigned, Unsigned) {
(a + b * a, a)
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -364,22 +410,28 @@ fn can_collect_multiple_outputs() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}
#[test]
fn literals_consolidate() {
#[circuit]
fn circuit_with_args(a: Signed) {
#[circuit(scheme = "bfv")]
fn circuit_with_args(a: Unsigned) {
let _ = a << 4;
let _ = a << 4;
let _ = a << 3;
}
let context = circuit_with_args(&get_params());
let (scheme, compile_fn) = circuit_with_args();
assert_eq!(scheme, SchemeType::Bfv);
let context = compile_fn(&get_params()).unwrap();
let expected = json!({
"graph": {
@@ -434,8 +486,10 @@ fn literals_consolidate() {
]
]
},
"scheme": "Bfv"
});
assert_eq!(context, serde_json::from_value(expected).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
);
}

View File

@@ -3,7 +3,7 @@ use sunscreen_frontend_macros::{circuit};
struct Foo {}
impl Foo {
#[circuit]
#[circuit(scheme = "bfv")]
fn panic_circuit(&self) {
}
}

View File

@@ -1,7 +1,13 @@
error: circuits must not take a reference to self
--> tests/compile_failures/self_arg.rs:6:5
|
6 | #[circuit]
| ^^^^^^^^^^
|
= note: this error originates in the attribute macro `circuit` (in Nightly builds, run with -Z macro-backtrace for more info)
--> tests/compile_failures/self_arg.rs:6:5
|
6 | #[circuit(scheme = "bfv")]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ in this procedural macro expansion
|
::: src/lib.rs
|
| / pub fn circuit(
| | metadata: proc_macro::TokenStream,
| | input: proc_macro::TokenStream,
| | ) -> proc_macro::TokenStream {
| |____________________________- in this expansion of `#[circuit]`

View File

@@ -1,5 +1,5 @@
use crate::params::{determine_params, PlainModulusConstraint};
use crate::{Context, Error, Params, Result, SchemeType, SecurityLevel};
use crate::{Error, FrontendCompilation, Params, Result, SchemeType, SecurityLevel};
use sunscreen_circuit::Circuit;
#[derive(Debug, Clone)]
@@ -11,9 +11,10 @@ enum ParamsMode {
/**
* A frontend circuit compiler for Sunscreen circuits.
*/
pub struct Compiler<F>
pub struct Compiler<F, G>
where
F: Fn(&Params) -> Context,
G: Fn(&Params) -> Result<FrontendCompilation>,
F: Fn() -> (SchemeType, G),
{
circuit: F,
params_mode: ParamsMode,
@@ -22,9 +23,10 @@ where
noise_margin: u32,
}
impl<F> Compiler<F>
impl<F, G> Compiler<F, G>
where
F: Fn(&Params) -> Context,
G: Fn(&Params) -> Result<FrontendCompilation>,
F: Fn() -> (SchemeType, G),
{
/**
* Create a new compiler with the given circuit.
@@ -87,25 +89,29 @@ where
* for running it.
*/
pub fn compile(self) -> Result<(Circuit, Params)> {
let (scheme, circuit_fn) = (self.circuit)();
let (circuit, params) = match self.params_mode {
ParamsMode::Manual(p) => ((self.circuit)(&p), p.clone()),
ParamsMode::Manual(p) => {
(circuit_fn(&p), p.clone())
},
ParamsMode::Search => {
let constraint = self
.plain_modulus_constraint
.ok_or(Error::MissingPlainModulusConstraint)?;
let params = determine_params(
&self.circuit,
&circuit_fn,
constraint,
self.security_level,
self.noise_margin,
SchemeType::Bfv,
scheme
)?;
((self.circuit)(&params), params.clone())
(circuit_fn(&params), params.clone())
}
};
Ok((circuit.compile(), params))
Ok((circuit?.compile(), params))
}
}

View File

@@ -13,6 +13,11 @@ pub enum Error {
* No parameters were found that satisfy the given circuit.
*/
NoParams,
/**
* Attempted to compile the given circuit with the wrong scheme.
*/
IncorrectScheme,
}
/**

View File

@@ -128,21 +128,32 @@ pub trait Value {
#[derive(Clone, Debug, Deserialize, Serialize)]
/**
* The context under which std::ops:* on Value trait implementors should insert new nodes.
* Contains the frontend compilation graph.
*/
pub struct Context {
pub struct FrontendCompilation {
/**
* The dependency graph of the frontend's intermediate representation (IR) that backs a circuit.
*/
pub graph: StableGraph<Operation, OperandInfo>,
/**
* The type of scheme this circuit uses.
*/
pub scheme: SchemeType,
}
impl PartialEq for Context {
#[derive(Clone, Debug)]
/**
* The context under which std::ops:* on Value trait implementors should insert new nodes.
*/
pub struct Context {
/**
* The frontend compilation result.
*/
pub compilation: FrontendCompilation,
/**
* The set of parameters for which we're currently constructing the graph.
*/
pub params: Params,
}
impl PartialEq for FrontendCompilation {
fn eq(&self, b: &Self) -> bool {
is_isomorphic_matching(
&Graph::from(self.graph.clone()),
@@ -164,24 +175,32 @@ impl Context {
/**
* Creates a new empty frontend intermediate representation context with the given scheme.
*/
pub fn new(scheme: SchemeType) -> Self {
pub fn new(params: &Params) -> Self {
Self {
graph: StableGraph::new(),
scheme,
compilation: FrontendCompilation {
graph: StableGraph::new(),
},
params: params.clone(),
}
}
fn add_2_input(&mut self, op: Operation, left: NodeIndex, right: NodeIndex) -> NodeIndex {
let new_id = self.graph.add_node(op);
self.graph.add_edge(left, new_id, OperandInfo::Left);
self.graph.add_edge(right, new_id, OperandInfo::Right);
let new_id = self.compilation.graph.add_node(op);
self.compilation
.graph
.add_edge(left, new_id, OperandInfo::Left);
self.compilation
.graph
.add_edge(right, new_id, OperandInfo::Right);
new_id
}
fn add_1_input(&mut self, op: Operation, i: NodeIndex) -> NodeIndex {
let new_id = self.graph.add_node(op);
self.graph.add_edge(i, new_id, OperandInfo::Unary);
let new_id = self.compilation.graph.add_node(op);
self.compilation
.graph
.add_edge(i, new_id, OperandInfo::Unary);
new_id
}
@@ -190,7 +209,7 @@ impl Context {
* Add an input this context.
*/
pub fn add_input(&mut self) -> NodeIndex {
self.graph.add_node(Operation::InputCiphertext)
self.compilation.graph.add_node(Operation::InputCiphertext)
}
/**
@@ -214,9 +233,10 @@ impl Context {
// See if we already have a node for the given literal. If so, just return it.
// If not, make a new one.
let existing_literal = self
.compilation
.graph
.node_indices()
.filter_map(|i| match &self.graph[i] {
.filter_map(|i| match &self.compilation.graph[i] {
Operation::Literal(x) => {
if *x == literal {
Some(i)
@@ -230,7 +250,7 @@ impl Context {
match existing_literal {
Some(x) => x,
None => self.graph.add_node(Operation::Literal(literal)),
None => self.compilation.graph.add_node(Operation::Literal(literal)),
}
}
@@ -254,7 +274,9 @@ impl Context {
pub fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::Output, i)
}
}
impl FrontendCompilation {
/**
* Performs frontend compilation of this intermediate representation into a backend [`Circuit`],
* then perform backend compilation and return the result.

View File

@@ -1,4 +1,4 @@
use crate::{Context, Error, Result, SecurityLevel};
use crate::{Error, FrontendCompilation, Result, SecurityLevel};
use log::{debug, trace};
@@ -47,7 +47,7 @@ pub fn determine_params<F>(
scheme_type: SchemeType,
) -> Result<Params>
where
F: Fn(&Params) -> Context,
F: Fn(&Params) -> Result<FrontendCompilation>,
{
'order_loop: for (i, n) in LATTICE_DIMENSIONS.iter().enumerate() {
let plaintext_modulus = match plaintext_constraint {
@@ -144,7 +144,7 @@ where
scheme_type: scheme_type,
};
let ir = circuit_fn(&params).compile();
let ir = circuit_fn(&params)?.compile();
let num_inputs = ir
.graph

View File

@@ -3,90 +3,128 @@ use std::ops::{Add, Mul, Shl, Shr};
use petgraph::stable_graph::NodeIndex;
use serde::{Deserialize, Serialize};
use crate::{Context, Literal, Value, CURRENT_CTX};
use crate::{Context, Literal, CURRENT_CTX};
#[derive(Clone, Copy, Serialize, Deserialize)]
struct LiteralRef {
pub id: NodeIndex,
struct U64LiteralRef {}
impl FheType for U64LiteralRef {}
impl BfvType for U64LiteralRef {}
impl U64LiteralRef {
pub fn new(val: u64) -> NodeIndex {
with_ctx(|ctx| ctx.add_literal(Literal::U64(val)))
}
}
impl LiteralRef {
fn new(v: Literal) -> Self {
with_ctx(|ctx| Self {
id: ctx.add_literal(v),
})
/**
* Denotes the given rust type is an encoding in an FHE scheme
*/
pub trait FheType {}
/**
* Denotes the given type is valid under the [SchemeType::BFV](crate::SchemeType::Bfv).
*/
pub trait BfvType: FheType {}
impl CircuitNode<Unsigned> {
/**
* Returns the plain modulus parameter for the given BFV scheme
*/
pub fn get_plain_modulus() -> u64 {
with_ctx(|ctx| ctx.params.plain_modulus)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
/**
* A type that wraps an FheType during graph construction
*/
pub struct CircuitNode<T: FheType> {
/**
* The node's index
*/
pub id: NodeIndex,
_phantom: std::marker::PhantomData<T>,
}
impl<T: FheType> CircuitNode<T> {
/**
* Creates a new circuit node with the given node index.
*/
pub fn new(id: NodeIndex) -> Self {
Self {
id,
_phantom: std::marker::PhantomData,
}
}
/**
* Creates a new CircuitNode denoted as an input to a circuit graph.
*/
pub fn input() -> Self {
with_ctx(|ctx| Self::new(ctx.add_input()))
}
/**
* Denote this node as an output by appending an output circuit node.
*/
pub fn output(&self) -> Self {
with_ctx(|ctx| Self::new(ctx.add_output(self.id)))
}
}
#[derive(Clone, Copy)]
/**
* Represents a single signed integer encrypted as a ciphertext. Suitable for use
* Represents a single unsigned integer encrypted as a ciphertext. Suitable for use
* as an input or output for a Sunscreen circuit.
*/
pub struct Signed {
pub struct Unsigned {
/**
* The internal graph node id of this input or output.
*/
pub id: NodeIndex,
}
impl Value for Signed {
fn new() -> Self {
with_ctx(|ctx| Self {
id: ctx.add_input(),
})
}
impl FheType for Unsigned {}
impl BfvType for Unsigned {}
fn output(&self) -> Self {
with_ctx(|ctx| Self {
id: ctx.add_output(self.id),
})
}
}
impl Unsigned {}
impl Signed {}
impl Add for Signed {
impl Add for CircuitNode<Unsigned> {
type Output = Self;
fn add(self, other: Self) -> Self {
with_ctx(|ctx| Self {
id: ctx.add_addition(self.id, other.id),
})
with_ctx(|ctx| Self::new(ctx.add_addition(self.id, other.id)))
}
}
impl Mul for Signed {
impl Mul for CircuitNode<Unsigned> {
type Output = Self;
fn mul(self, other: Self) -> Self {
with_ctx(|ctx| Self {
id: ctx.add_multiplication(self.id, other.id),
})
with_ctx(|ctx| Self::new(ctx.add_multiplication(self.id, other.id)))
}
}
impl Shl<u64> for Signed {
impl Shl<u64> for CircuitNode<Unsigned> {
type Output = Self;
fn shl(self, n: u64) -> Self {
let l = LiteralRef::new(Literal::U64(n));
let l = U64LiteralRef::new(n);
with_ctx(|ctx| Self {
id: ctx.add_rotate_left(self.id, l.id),
})
with_ctx(|ctx| Self::new(ctx.add_rotate_left(self.id, l)))
}
}
impl Shr<u64> for Signed {
impl Shr<u64> for CircuitNode<Unsigned> {
type Output = Self;
fn shr(self, n: u64) -> Self {
let l = LiteralRef::new(Literal::U64(n));
let l = U64LiteralRef::new(n);
with_ctx(|ctx| Self {
id: ctx.add_rotate_right(self.id, l.id),
})
with_ctx(|ctx| Self::new(ctx.add_rotate_right(self.id, l)))
}
}