mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
Add tests for signed values. Fix a bug in signed encoder
This commit is contained in:
19
.vscode/launch.json
vendored
19
.vscode/launch.json
vendored
@@ -124,7 +124,6 @@
|
||||
"test",
|
||||
"--no-run",
|
||||
"--package=sunscreen_compiler",
|
||||
"can_encode_signed"
|
||||
],
|
||||
"filter": {
|
||||
"name": "unsigned",
|
||||
@@ -134,6 +133,24 @@
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug signed integration tests in library 'sunscreen_compiler'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--package=sunscreen_compiler",
|
||||
],
|
||||
"filter": {
|
||||
"name": "signed",
|
||||
"kind": "test"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use seal::Plaintext as SealPlaintext;
|
||||
|
||||
use crate::types::{ops::GraphCipherAdd, Cipher, GraphCipherMul};
|
||||
use crate::types::{ops::{GraphCipherAdd, GraphCipherMul, GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherConstAdd, GraphCipherConstMul}, Cipher, };
|
||||
use crate::{
|
||||
types::{intern::CircuitNode, BfvType, FheType, TypeNameInstance},
|
||||
with_ctx, CircuitInputTrait, Params, TypeName as DeriveTypeName, WithContext,
|
||||
@@ -44,16 +44,16 @@ impl TryIntoPlaintext for Signed {
|
||||
) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
|
||||
let mut seal_plaintext = SealPlaintext::new()?;
|
||||
|
||||
let unsigned_val = if self.val < 0 { -self.val } else { self.val } as u64;
|
||||
let signed_val = if self.val < 0 { -self.val } else { self.val } as u64;
|
||||
|
||||
let sig_bits = significant_bits(unsigned_val);
|
||||
let sig_bits = significant_bits(signed_val);
|
||||
seal_plaintext.resize(sig_bits);
|
||||
|
||||
for i in 0..sig_bits {
|
||||
let bit_value = (unsigned_val & 0x1 << i) >> i;
|
||||
let bit_value = (signed_val & 0x1 << i) >> i;
|
||||
|
||||
let coeff_value = if self.val < 0 {
|
||||
params.plain_modulus as u64 - bit_value
|
||||
bit_value * (params.plain_modulus as u64 - bit_value)
|
||||
} else {
|
||||
bit_value
|
||||
};
|
||||
@@ -137,6 +137,41 @@ impl GraphCipherAdd for Signed {
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphCipherPlainAdd for Signed {
|
||||
type Left = Signed;
|
||||
type Right = Signed;
|
||||
|
||||
fn graph_cipher_plain_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphCipherConstAdd for Signed {
|
||||
type Left = Self;
|
||||
type Right = i64;
|
||||
|
||||
fn graph_cipher_const_add(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: i64,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let add = ctx.add_addition_plaintext(a.ids[0], lit);
|
||||
|
||||
CircuitNode::new(&[add])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphCipherMul for Signed {
|
||||
type Left = Signed;
|
||||
type Right = Signed;
|
||||
@@ -152,3 +187,38 @@ impl GraphCipherMul for Signed {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphCipherConstMul for Signed {
|
||||
type Left = Self;
|
||||
type Right = i64;
|
||||
|
||||
fn graph_cipher_const_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: i64,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let add = ctx.add_multiplication_plaintext(a.ids[0], lit);
|
||||
|
||||
CircuitNode::new(&[add])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphCipherPlainMul for Signed {
|
||||
type Left = Signed;
|
||||
type Right = Signed;
|
||||
|
||||
fn graph_cipher_plain_mul(
|
||||
a: CircuitNode<Cipher<Self::Left>>,
|
||||
b: CircuitNode<Self::Right>,
|
||||
) -> CircuitNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -228,6 +228,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// plain * cipher
|
||||
impl<T> Mul<CircuitNode<Cipher<T>>> for CircuitNode<T>
|
||||
where
|
||||
T: FheType + GraphCipherPlainMul<Left = T, Right = T>,
|
||||
{
|
||||
type Output = CircuitNode<Cipher<T>>;
|
||||
|
||||
fn mul(self, rhs: CircuitNode<Cipher<T>>) -> Self::Output {
|
||||
T::graph_cipher_plain_mul(rhs, self)
|
||||
}
|
||||
}
|
||||
|
||||
// cipher * literal
|
||||
impl<T, U> Mul<T> for CircuitNode<Cipher<U>>
|
||||
where
|
||||
|
||||
@@ -1,39 +1,13 @@
|
||||
use sunscreen_compiler::{
|
||||
circuit,
|
||||
types::{bfv::Signed, Cipher},
|
||||
Compiler, PlainModulusConstraint, Runtime,
|
||||
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn can_encode_signed() {
|
||||
fn can_add_cipher_plain() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
|
||||
a
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(10), &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a], &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, 10.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_signed_numbers() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
|
||||
fn add(a: Cipher<Signed>, b: Signed) -> Cipher<Signed> {
|
||||
a + b
|
||||
}
|
||||
|
||||
@@ -48,9 +22,11 @@ fn can_add_signed_numbers() {
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
let b = runtime.encrypt(Signed::from(-5), &public).unwrap();
|
||||
let b = Signed::from(-5);
|
||||
|
||||
let result = runtime.run(&circuit, vec![a, b], &public).unwrap();
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
@@ -58,13 +34,13 @@ fn can_add_signed_numbers() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_signed_numbers() {
|
||||
fn can_add_plain_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn mul(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
|
||||
a * b
|
||||
fn add(a: Signed, b: Cipher<Signed>) -> Cipher<Signed> {
|
||||
b + a
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(mul)
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
@@ -74,12 +50,182 @@ fn can_multiply_signed_numbers() {
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(17), &public).unwrap();
|
||||
let b = runtime.encrypt(Signed::from(-4), &public).unwrap();
|
||||
let a = Signed::from(-5);
|
||||
let b = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
|
||||
let result = runtime.run(&circuit, vec![a, b], &public).unwrap();
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (-68).into());
|
||||
assert_eq!(c, 10.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_cipher_literal() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
|
||||
a + -4
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, 11.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_literal_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
|
||||
-4 + a
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, 11.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_cipher_plain() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>, b: Signed) -> Cipher<Signed> {
|
||||
a * b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
let b = Signed::from(-3);
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (-45).into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_plain_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Signed, b: Cipher<Signed>) -> Cipher<Signed> {
|
||||
a * b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Signed::from(15);
|
||||
let b = runtime.encrypt(Signed::from(-3), &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (-45).into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_cipher_literal() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
|
||||
a * -3
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (-45).into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_literal_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
|
||||
-3 * a
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, (-45).into());
|
||||
}
|
||||
@@ -36,7 +36,7 @@ fn can_add_unsigned_cipher_plain() {
|
||||
#[test]
|
||||
fn can_add_unsigned_plain_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Cipher<Unsigned>, b: Unsigned) -> Cipher<Unsigned> {
|
||||
fn add(a: Unsigned, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
|
||||
b + a
|
||||
}
|
||||
|
||||
@@ -50,8 +50,8 @@ fn can_add_unsigned_plain_cipher() {
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = runtime.encrypt(Unsigned::from(15), &public).unwrap();
|
||||
let b = Unsigned::from(5);
|
||||
let a = Unsigned::from(5);
|
||||
let b = runtime.encrypt(Unsigned::from(15), &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
@@ -145,6 +145,35 @@ fn can_add_multiply_cipher_plain() {
|
||||
assert_eq!(c, 45.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_multiply_plain_cipher() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
fn add(a: Unsigned, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
|
||||
a * b
|
||||
}
|
||||
|
||||
let circuit = Compiler::with_circuit(add)
|
||||
.noise_margin_bits(5)
|
||||
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
|
||||
.compile()
|
||||
.unwrap();
|
||||
|
||||
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
|
||||
|
||||
let (public, secret) = runtime.generate_keys().unwrap();
|
||||
|
||||
let a = Unsigned::from(15);
|
||||
let b = runtime.encrypt(Unsigned::from(3), &public).unwrap();
|
||||
|
||||
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
|
||||
|
||||
let result = runtime.run(&circuit, args, &public).unwrap();
|
||||
|
||||
let c: Unsigned = runtime.decrypt(&result[0], &secret).unwrap();
|
||||
|
||||
assert_eq!(c, 45.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_multiply_cipher_literal() {
|
||||
#[circuit(scheme = "bfv")]
|
||||
|
||||
Reference in New Issue
Block a user