diff --git a/Cargo.lock b/Cargo.lock index 1ea70278b..2c8712638 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,7 +100,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" [[package]] -name = "calculator" +name = "calculator_fractional" +version = "0.1.0" +dependencies = [ + "sunscreen_compiler", + "sunscreen_runtime", +] + +[[package]] +name = "calculator_rational" version = "0.1.0" dependencies = [ "sunscreen_compiler", diff --git a/Cargo.toml b/Cargo.toml index ac30cfc17..9c1eb8221 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,8 @@ members = [ "emsdk", "examples/simple_multiply", - "examples/calculator", + "examples/calculator_rational", + "examples/calculator_fractional", "seal", "seal_bench", "sunscreen_backend", diff --git a/examples/calculator_fractional/Cargo.toml b/examples/calculator_fractional/Cargo.toml new file mode 100644 index 000000000..1fc8307ac --- /dev/null +++ b/examples/calculator_fractional/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "calculator_fractional" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +sunscreen_compiler = { path = "../../sunscreen_compiler" } +sunscreen_runtime = { path = "../../sunscreen_runtime" } \ No newline at end of file diff --git a/examples/calculator_fractional/src/main.rs b/examples/calculator_fractional/src/main.rs new file mode 100644 index 000000000..3e72b4b81 --- /dev/null +++ b/examples/calculator_fractional/src/main.rs @@ -0,0 +1,336 @@ +use std::io::{self, Write}; +use std::sync::mpsc::{Receiver, Sender}; +use std::thread::{self, JoinHandle}; +use sunscreen_compiler::{ + circuit, + types::{bfv::Fractional, Cipher}, + Ciphertext, CompiledCircuit, Compiler, Params, PlainModulusConstraint, PublicKey, Runtime, + RuntimeError, +}; + +fn help() { + println!("This is a privacy preserving calculator. You can add, subtract, multiply, divide decimal values. The operation is sent to Bob in cleartext while the operands + are encrypted. Bob chooses a circuit corresponding to the selected operation and computes the result. Additionally, Bob saves the last computed value as `ans`, which you may use as either operand."); + println!("Since this example is to demo encryption, not parsing, you must insert exactly one space between the operand and values."); + println!("Type exit to quit."); + println!("Example:"); + println!(">> 3 + 6.5"); + println!("9.5"); + println!(">> ans / 5"); + println!("1.9"); + println!(""); +} + +enum Term { + Ans, + F64(f64), + Encrypted(Ciphertext), +} + +#[derive(PartialEq)] +enum Operand { + Add, + Sub, + Mul, + Div, +} + +struct Expression { + left: Term, + op: Operand, + right: Term, +} + +enum ParseResult { + Help, + Exit, + Expression(Expression), +} + +enum Error { + ParseError, + /** + * The operation is not supported: + * * division with ans as denominator. + */ + Unsupported, +} + +fn parse_input(line: &str) -> Result { + if line == "help" { + return Ok(ParseResult::Help); + } else if line == "exit" { + return Ok(ParseResult::Exit); + } + + let mut terms = line.split(" "); + + let left = terms.next().ok_or(Error::ParseError)?; + + let left_term = if left == "ans" { + Term::Ans + } else { + Term::F64(left.parse::().map_err(|_| Error::ParseError)?) + }; + + let operand = terms.next().ok_or(Error::ParseError)?; + + let operand = if operand == "+" { + Operand::Add + } else if operand == "-" { + Operand::Sub + } else if operand == "/" { + Operand::Div + } else if operand == "*" { + Operand::Mul + } else { + return Err(Error::ParseError); + }; + + let right = terms.next().ok_or(Error::ParseError)?; + + let right_term = if right == "ans" { + if operand == Operand::Div { + return Err(Error::Unsupported); + } + + Term::Ans + } else { + let right = right.parse::().map_err(|_| Error::ParseError)?; + + let right = if operand == Operand::Div { + 1. / right + } else { + right + }; + + Term::F64(right) + }; + + Ok(ParseResult::Expression(Expression { + left: left_term, + op: operand, + right: right_term, + })) +} + +fn encrypt_term(runtime: &Runtime, public_key: &PublicKey, input: Term) -> Term { + match input { + Term::Ans => Term::Ans, + Term::F64(v) => Term::Encrypted( + runtime + .encrypt(Fractional::<64>::try_from(v).unwrap(), &public_key) + .unwrap(), + ), + _ => { + panic!("This shouldn't happen."); + } + } +} + +fn alice( + send_pub: Sender, + send_calc: Sender, + recv_params: Receiver, + recv_res: Receiver, +) -> JoinHandle<()> { + thread::spawn(move || { + let stdin = io::stdin(); + let mut stdout = io::stdout(); + + println!("Bob's secret calculator. Type `help` for help."); + + // Bob needs to send us the scheme parameters compatible with his circuits. + let params = recv_params.recv().unwrap(); + + let runtime = Runtime::new(¶ms).unwrap(); + + let (public, secret) = runtime.generate_keys().unwrap(); + + // Send Bob a copy of our public keys. + send_pub.send(public.clone()).unwrap(); + + loop { + print!(">> "); + + stdout.flush().unwrap(); + + let mut line = String::new(); + stdin.read_line(&mut line).unwrap(); + let line = line.trim(); + + // Read the line and parse it into operands and an operator. + let parsed = parse_input(&line); + + let Expression { left, right, op } = match parsed { + Ok(ParseResult::Expression(val)) => val, + Ok(ParseResult::Exit) => std::process::exit(0), + Ok(ParseResult::Help) => { + help(); + continue; + } + Err(_) => { + println!("Parse error. Try again."); + continue; + } + }; + + // Encrypt the left and right terms. + let encrypt_left = encrypt_term(&runtime, &public, left); + let encrypt_right = encrypt_term(&runtime, &public, right); + + // Send Bob our encrypted operation. + send_calc + .send(Expression { + left: encrypt_left, + right: encrypt_right, + op: op, + }) + .unwrap(); + + // Get our result from Bob and print it. + let result: Ciphertext = recv_res.recv().unwrap(); + let result: Fractional<64> = match runtime.decrypt(&result, &secret) { + Ok(v) => v, + Err(RuntimeError::TooMuchNoise) => { + println!("Decryption failed: too much noise"); + continue; + } + Err(e) => panic!("{:#?}", e), + }; + let result: f64 = result.into(); + + println!("{}", result); + } + }) +} + +fn compile_circuits() -> ( + CompiledCircuit, + CompiledCircuit, + CompiledCircuit, +) { + #[circuit(scheme = "bfv")] + fn add(a: Cipher>, b: Cipher>) -> Cipher> { + a + b + } + + #[circuit(scheme = "bfv")] + fn sub(a: Cipher>, b: Cipher>) -> Cipher> { + a - b + } + + #[circuit(scheme = "bfv")] + fn mul(a: Cipher>, b: Cipher>) -> Cipher> { + a * b + } + + // In order for ciphertexts to be compatible between circuits, they must all use the same + // parameters. + // With rational numbers, each of these circuits produces roughly the same amount of noise. + // To be sure, we compile one of them with the default parameter search, and explicitly + // pass these parameters when compiling the other circuits so they are compatible. + let mul_circuit = Compiler::with_circuit(mul) + // We need to make the noise margin large enough so we can do a few repeated calculations. + .noise_margin_bits(32) + .plain_modulus_constraint(PlainModulusConstraint::Raw(1_000_000)) + .compile() + .unwrap(); + + let add_circuit = Compiler::with_circuit(add) + .with_params(&mul_circuit.metadata.params) + .compile() + .unwrap(); + + let sub_circuit = Compiler::with_circuit(sub) + .with_params(&mul_circuit.metadata.params) + .compile() + .unwrap(); + + (add_circuit, sub_circuit, mul_circuit) +} + +fn bob( + recv_pub: Receiver, + recv_calc: Receiver, + send_params: Sender, + send_res: Sender, +) -> JoinHandle<()> { + thread::spawn(move || { + let (add, sub, mul) = compile_circuits(); + + send_params.send(add.metadata.params.clone()).unwrap(); + + let public_key = recv_pub.recv().unwrap(); + + let runtime = Runtime::new(&add.metadata.params).unwrap(); + + let mut ans = runtime + .encrypt(Fractional::<64>::try_from(0f64).unwrap(), &public_key) + .unwrap(); + + loop { + let Expression { left, right, op } = recv_calc.recv().unwrap(); + + let left = match left { + Term::Ans => ans.clone(), + Term::Encrypted(c) => c, + _ => panic!("Alice sent us a plaintext!"), + }; + + let right = match right { + Term::Ans => ans.clone(), + Term::Encrypted(c) => c, + _ => panic!("Alice sent us a plaintext!"), + }; + + let mut c = match op { + Operand::Add => runtime.run(&add, vec![left, right], &public_key).unwrap(), + Operand::Sub => runtime.run(&sub, vec![left, right], &public_key).unwrap(), + Operand::Mul => runtime.run(&mul, vec![left, right], &public_key).unwrap(), + // To do division, Alice must send us 1 / b and we + // multiply. + Operand::Div => runtime.run(&mul, vec![left, right], &public_key).unwrap(), + }; + + // Our circuit produces a single value, so move the value out of the vector. + let c = c.drain(0..).next().unwrap(); + ans = c.clone(); + + send_res.send(c).unwrap(); + } + }) +} + +fn main() { + // A channel for Alice to send her public keys to Bob. + let (send_alice_pub, receive_alice_pub) = std::sync::mpsc::channel::(); + + // A channel for Alice to send calculation requests to Bob. + let (send_alice_calc, receive_alice_calc) = std::sync::mpsc::channel::(); + + // A channel for Bob to send scheme params to Alice + let (send_bob_params, receive_bob_params) = std::sync::mpsc::channel::(); + + // A channel for Bob to send calculation results to Alice. + let (send_bob_result, receive_bob_result) = std::sync::mpsc::channel::(); + + // We intentionally break Alice and Bob's roles into different functions to clearly + // show the separation of their roles. In a real application, they're usually on + // different machines communicating over a real protocol (e.g. TCP sockets). + let a = alice( + send_alice_pub, + send_alice_calc, + receive_bob_params, + receive_bob_result, + ); + let b = bob( + receive_alice_pub, + receive_alice_calc, + send_bob_params, + send_bob_result, + ); + + a.join().unwrap(); + b.join().unwrap(); +} diff --git a/examples/calculator/Cargo.toml b/examples/calculator_rational/Cargo.toml similarity index 90% rename from examples/calculator/Cargo.toml rename to examples/calculator_rational/Cargo.toml index 873ceecb6..fcc8dbedf 100644 --- a/examples/calculator/Cargo.toml +++ b/examples/calculator_rational/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "calculator" +name = "calculator_rational" version = "0.1.0" edition = "2021" diff --git a/examples/calculator/src/main.rs b/examples/calculator_rational/src/main.rs similarity index 100% rename from examples/calculator/src/main.rs rename to examples/calculator_rational/src/main.rs diff --git a/sunscreen_compiler/src/types/bfv/mod.rs b/sunscreen_compiler/src/types/bfv/mod.rs index 85283cea3..4c302b1f2 100644 --- a/sunscreen_compiler/src/types/bfv/mod.rs +++ b/sunscreen_compiler/src/types/bfv/mod.rs @@ -1,7 +1,9 @@ mod fractional; mod rational; mod signed; +mod simd; pub use fractional::*; pub use rational::*; pub use signed::*; +pub use simd::*; \ No newline at end of file diff --git a/sunscreen_compiler/src/types/bfv/simd.rs b/sunscreen_compiler/src/types/bfv/simd.rs new file mode 100644 index 000000000..d5bd18a33 --- /dev/null +++ b/sunscreen_compiler/src/types/bfv/simd.rs @@ -0,0 +1,6 @@ +/** + * A vectorized + */ +pub struct Simd { + +} \ No newline at end of file