diff --git a/sunscreen_compiler/src/error.rs b/sunscreen_compiler/src/error.rs index cd11fcb46..9fd104c0c 100644 --- a/sunscreen_compiler/src/error.rs +++ b/sunscreen_compiler/src/error.rs @@ -23,6 +23,11 @@ pub enum Error { * An internal error occurred in the SEAL library. */ SealError(seal::Error), + + /** + * An Error occurred in the Sunscreen runtime. + */ + RuntimeError(crate::RuntimeError) } impl From for Error { @@ -31,6 +36,12 @@ impl From for Error { } } +impl From for Error { + fn from(err: crate::RuntimeError) -> Self { + Self::RuntimeError(err) + } +} + /** * Wrapper around [`Result`](std::result::Result) with this crate's error type. */ diff --git a/sunscreen_compiler/src/lib.rs b/sunscreen_compiler/src/lib.rs index b3b527fa2..dbc9a65b6 100644 --- a/sunscreen_compiler/src/lib.rs +++ b/sunscreen_compiler/src/lib.rs @@ -33,7 +33,7 @@ pub use error::{Error, Result}; pub use params::PlainModulusConstraint; pub use sunscreen_circuit::{SchemeType, SecurityLevel}; pub use sunscreen_compiler_macros::*; -pub use sunscreen_runtime::{Arguments, CallSignature, CircuitMetadata, Params, RequiredKeys}; +pub use sunscreen_runtime::{Arguments, CallSignature, CircuitMetadata, Error as RuntimeError, Params, RequiredKeys, RuntimeBuilder, Runtime}; #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] /** diff --git a/sunscreen_compiler_macros/src/decrypt.rs b/sunscreen_compiler_macros/src/decrypt.rs new file mode 100644 index 000000000..8f5b9d5db --- /dev/null +++ b/sunscreen_compiler_macros/src/decrypt.rs @@ -0,0 +1,116 @@ +use proc_macro2::{TokenStream}; +use quote::{quote}; +use syn::{punctuated::Punctuated, Ident, parse::{Parse, ParseStream}, Token, Expr, parse_macro_input, ExprPath, Index, Error, Result}; + +pub struct DecryptArgs { + pub return_types: Vec, + pub runtime_ident: Ident, + pub return_bundle_ident: Ident, +} + +impl Parse for DecryptArgs { + fn parse(input: ParseStream) -> Result { + let vars = Punctuated::::parse_terminated(input)?; + + let mut runtime_ident: Option = None; + let mut return_bundle_ident: Option = None; + let mut return_types = vec![]; + + if vars.len() < 2 { + return Err(Error::new_spanned( + vars, + "Usage: decrypt_impl!(runtime, return_val, T1, T2, ...)" + )); + }; + + for (i, var) in vars.iter().enumerate() { + match var { + Expr::Path(p) => { + if i == 0 { + runtime_ident = Some(p.path.get_ident().ok_or(Error::new_spanned(p, "Not a variable"))?.clone()); + } else if i == 1 { + return_bundle_ident = Some(p.path.get_ident().ok_or(Error::new_spanned(p, "Not a variable"))?.clone()); + } else { + return_types.push(p.clone()) + } + }, + _ => { + return Err(Error::new_spanned( + var, + "Usage: decrypt_impl!(runtime, return_val, T1, T2, ...)" + )); + } + }; + } + + Ok(Self { + return_bundle_ident: return_bundle_ident.unwrap(), + runtime_ident: runtime_ident.unwrap(), + return_types: return_types, + }) + } +} + +pub fn decrypt_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let parsed = parse_macro_input!(input as DecryptArgs); + + let tok = decrypt_internal(&parsed).into(); + + //panic!("{}", tok); + + tok +} + +fn decrypt_internal(input: &DecryptArgs) -> TokenStream { + let validate = validate_types(input); + let return_types = &input.return_types; + + TokenStream::from(quote! { + (|| -> sunscreen_compiler::Result<(#(#return_types,)*)> { + #validate + + Ok(()) + })() + }) +} + +fn validate_types(args: &DecryptArgs) -> TokenStream { + let runtime = &args.runtime_ident; + let return_types = &args.return_types; + + let validate = args.return_types.iter().enumerate().map(|(i, t)| { + let id = Index::from(i); + + quote! { + if #t::type_name() != #runtime.get_metadata().signature.returns[#id] { + return Err(Error::ReturnMismatch( + RuntimeError::ReturnMismatch { + expected: #runtime.get_metadata().signature.returns.clone(), + actual: vec![#(#return_types ::type_name(),)*], + } + )); + } + } + }); + + let len = Index::from(args.return_types.len()); + + quote!{ + (|| -> sunscreen_compiler::Result<()> { + use sunscreen_compiler::*; + + if #runtime.get_metadata().signature.returns.len() != #len { + return Err(Error::RuntimeError( + RuntimeError::ReturnMismatch { + expected: #runtime.get_metadata().signature.returns.clone(), + actual: vec![#(#return_types ::type_name(),)*], + } + )); + } + + #(#validate)* + + Ok(()) + })()?; + } +} \ No newline at end of file diff --git a/sunscreen_compiler_macros/src/lib.rs b/sunscreen_compiler_macros/src/lib.rs index e5cb35674..eee23ad33 100644 --- a/sunscreen_compiler_macros/src/lib.rs +++ b/sunscreen_compiler_macros/src/lib.rs @@ -7,6 +7,7 @@ extern crate proc_macro; mod circuit; +mod decrypt; mod error; mod internals; mod type_name; @@ -56,3 +57,29 @@ pub fn circuit( ) -> proc_macro::TokenStream { circuit::circuit_impl(metadata, input) } + +#[proc_macro] +/** + * Decrypts an output parameter set using the given runtime. The first argument + * to this macro is an identifier to a runtime. The second argument is the identifier + * of the return bundle to decrypt. 3rd-Nth arguments are the expected return types + * from the circuit, in order. The macro returns a `Result`. + * + * # Remarks + * This macro validates the given types against the circuit's return interface + * for correctness, then decrypts each item. If successful, this macro returns + * an Ok(T) where T is: + * * The unit type `()` if the circuit returned nothing. + * * The single argument matching the lone type parameter + * if the circuit returns one argument. + * * A tuple of composed of the types passed to the macro if the circuit returns + * more than one argument. + * + * The types passed in arguments 3-N must exactly match those in the return interface + * of the circuit. Circuits that return nothing, while useless, are legal. In this case, + * you should only pass the first two arguments. In the event of failure, this function + * returns the underlying issue. + */ +pub fn decrypt(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + decrypt::decrypt_impl(input) +} \ No newline at end of file diff --git a/sunscreen_compiler_macros/tests/decrypt.rs b/sunscreen_compiler_macros/tests/decrypt.rs new file mode 100644 index 000000000..49fb89337 --- /dev/null +++ b/sunscreen_compiler_macros/tests/decrypt.rs @@ -0,0 +1,31 @@ +use sunscreen_compiler_macros::{decrypt}; +use sunscreen_compiler::{*, types::*}; + +#[test] +fn error_on_no_args() { + #[circuit(scheme = "bfv")] + fn foo(a: Unsigned, b: Unsigned) -> Unsigned { + a + b + } + + let (circuit, metadata) = Compiler::with_circuit(foo) + .noise_margin_bits(5) + .plain_modulus_constraint(PlainModulusConstraint::Raw(500)) + .compile() + .unwrap(); + + let runtime = RuntimeBuilder::new(&metadata).build().unwrap(); + + let (public, secret) = runtime.generate_keys().unwrap(); + + let args = runtime.encrypt_args( + &Arguments::new() + .arg(Unsigned::from(5)) + .arg(Unsigned::from(15)), + &public + ).unwrap(); + + let result = runtime.run(&circuit, args).unwrap(); + + decrypt!(runtime, result).unwrap(); +} \ No newline at end of file diff --git a/sunscreen_runtime/src/args.rs b/sunscreen_runtime/src/args.rs index 846332cf3..cb1026f9d 100644 --- a/sunscreen_runtime/src/args.rs +++ b/sunscreen_runtime/src/args.rs @@ -38,4 +38,9 @@ pub struct InputBundle { pub(crate) galois_keys: Option, pub(crate) relin_keys: Option, pub(crate) public_keys: Option, -} \ No newline at end of file +} + +/** + * The encrypted result of running a circuit. + */ +pub struct OutputBundle(pub(crate) Vec); \ No newline at end of file diff --git a/sunscreen_runtime/src/error.rs b/sunscreen_runtime/src/error.rs index cedd0873f..7b4e12efe 100644 --- a/sunscreen_runtime/src/error.rs +++ b/sunscreen_runtime/src/error.rs @@ -1,6 +1,6 @@ use crate::Type; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] /** * Represents an error that can occur in this crate. */ @@ -48,6 +48,21 @@ pub enum Error { * The given arguments. */ actual: Vec + }, + + /** + * The given return types do not match the circuit interface. + */ + ReturnMismatch { + /** + * The return types in the call signature of the circuit. + */ + expected: Vec, + + /** + * The given return types. + */ + actual: Vec } } diff --git a/sunscreen_runtime/src/runtime.rs b/sunscreen_runtime/src/runtime.rs index 08f3ad582..04ad95914 100644 --- a/sunscreen_runtime/src/runtime.rs +++ b/sunscreen_runtime/src/runtime.rs @@ -46,6 +46,13 @@ impl Runtime { Ok(keys) } + /** + * Returns the metadata for this runtime's associated circuit. + */ + pub fn get_metadata(&self) -> &CircuitMetadata { + &self.metadata + } + /** * Generates Galois keys needed for SIMD rotations. */ @@ -84,7 +91,7 @@ impl Runtime { &self, ir: &Circuit, input_bundle: InputBundle, - ) -> Result> { + ) -> Result { ir.validate()?; // Aside from circuit correctness, check that the required keys are given. @@ -104,9 +111,9 @@ impl Runtime { Context::Seal(context) => { let evaluator = BFVEvaluator::new(&context)?; - Ok(unsafe { + Ok(OutputBundle(unsafe { run_program_unchecked(ir, &input_bundle.ciphertexts, &evaluator, input_bundle.relin_keys, input_bundle.galois_keys) - }) + })) }, } }