From 758cc51f61bbbf2f07d22c6cd74dfd6379aba210 Mon Sep 17 00:00:00 2001 From: Rick Weber Date: Thu, 16 Dec 2021 11:45:08 -0800 Subject: [PATCH] Can multiply rationals --- sunscreen_compiler/src/types/rational.rs | 22 ++++++++++++++++++- sunscreen_compiler/tests/types.rs | 27 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/sunscreen_compiler/src/types/rational.rs b/sunscreen_compiler/src/types/rational.rs index 849b71cce..d90a63e1a 100644 --- a/sunscreen_compiler/src/types/rational.rs +++ b/sunscreen_compiler/src/types/rational.rs @@ -1,5 +1,5 @@ use crate::{TypeName, Params, InnerPlaintext, Plaintext, with_ctx}; -use crate::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, GraphAdd, CircuitNode}; +use crate::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, GraphAdd, GraphMul, CircuitNode}; use sunscreen_runtime::{Error}; use std::cmp::Eq; @@ -107,6 +107,26 @@ impl GraphAdd for Rational { den_2 ]; + CircuitNode::new(&ids) + }) + } +} + +impl GraphMul for Rational { + type Left = Self; + type Right = Self; + + fn graph_mul(a: CircuitNode, b: CircuitNode) -> CircuitNode { + with_ctx(|ctx| { + // Scale each numinator by the other's denominator. + let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]); + let mul_den = ctx.add_multiplication(a.ids[1], b.ids[1]); + + let ids = [ + mul_num, + mul_den + ]; + CircuitNode::new(&ids) }) } diff --git a/sunscreen_compiler/tests/types.rs b/sunscreen_compiler/tests/types.rs index a3af70ec9..c03509e23 100644 --- a/sunscreen_compiler/tests/types.rs +++ b/sunscreen_compiler/tests/types.rs @@ -131,4 +131,31 @@ fn can_add_rational_numbers() { let c: Rational = runtime.decrypt(&result[0], &secret).unwrap(); assert_eq!(c, (3.14).try_into().unwrap()); +} + +#[test] +fn can_mul_rational_numbers() { + #[circuit(scheme = "bfv")] + fn add(a: Rational, b: Rational) -> Rational { + a * b + } + + let circuit = Compiler::with_circuit(add) + .noise_margin_bits(5) + .plain_modulus_constraint(PlainModulusConstraint::Raw(500)) + .compile() + .unwrap(); + + let runtime = Runtime::new(&circuit.metadata.params).unwrap(); + + let (public, secret) = runtime.generate_keys().unwrap(); + + let a = runtime.encrypt(Rational::try_from(-3.14).unwrap(), &public).unwrap(); + let b = runtime.encrypt(Rational::try_from(3.14).unwrap(), &public).unwrap(); + + let result = runtime.run(&circuit, vec![a, b], &public).unwrap(); + + let c: Rational = runtime.decrypt(&result[0], &secret).unwrap(); + + assert_eq!(c, (-3.14 * 3.14).try_into().unwrap()); } \ No newline at end of file