mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-01-13 23:58:11 -05:00
Can multiply rationals
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::{TypeName, Params, InnerPlaintext, Plaintext, with_ctx};
|
||||
use crate::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, GraphAdd, CircuitNode};
|
||||
use crate::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, GraphAdd, GraphMul, CircuitNode};
|
||||
use sunscreen_runtime::{Error};
|
||||
use std::cmp::Eq;
|
||||
|
||||
@@ -107,6 +107,26 @@ impl GraphAdd for Rational {
|
||||
den_2
|
||||
];
|
||||
|
||||
CircuitNode::new(&ids)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMul for Rational {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
|
||||
with_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
let mul_den = ctx.add_multiplication(a.ids[1], b.ids[1]);
|
||||
|
||||
let ids = [
|
||||
mul_num,
|
||||
mul_den
|
||||
];
|
||||
|
||||
CircuitNode::new(&ids)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -131,4 +131,31 @@ fn can_add_rational_numbers() {
|
||||
let c: Rational = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (3.14).try_into().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mul_rational_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Rational, b: Rational) -> Rational {
|
||||
a * b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Rational::try_from(-3.14).unwrap(), &public).unwrap();
|
||||
let b = runtime.encrypt(Rational::try_from(3.14).unwrap(), &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a, b], &public).unwrap();
|
||||
|
||||
let c: Rational = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (-3.14 * 3.14).try_into().unwrap());
|
||||
}
|
||||
Reference in New Issue
Block a user