Add tests for signed values. Fix a bug in signed encoder

This commit is contained in:
Rick Weber
2022-01-24 10:21:47 -08:00
parent ee1d5b5f78
commit 3dc12031be
5 changed files with 322 additions and 48 deletions

19
.vscode/launch.json vendored
View File

@@ -124,7 +124,6 @@
"test",
"--no-run",
"--package=sunscreen_compiler",
"can_encode_signed"
],
"filter": {
"name": "unsigned",
@@ -134,6 +133,24 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug signed integration tests in library 'sunscreen_compiler'",
"cargo": {
"args": [
"test",
"--no-run",
"--package=sunscreen_compiler",
],
"filter": {
"name": "signed",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",

View File

@@ -1,6 +1,6 @@
use seal::Plaintext as SealPlaintext;
use crate::types::{ops::GraphCipherAdd, Cipher, GraphCipherMul};
use crate::types::{ops::{GraphCipherAdd, GraphCipherMul, GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherConstAdd, GraphCipherConstMul}, Cipher, };
use crate::{
types::{intern::CircuitNode, BfvType, FheType, TypeNameInstance},
with_ctx, CircuitInputTrait, Params, TypeName as DeriveTypeName, WithContext,
@@ -44,16 +44,16 @@ impl TryIntoPlaintext for Signed {
) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
let mut seal_plaintext = SealPlaintext::new()?;
let unsigned_val = if self.val < 0 { -self.val } else { self.val } as u64;
let signed_val = if self.val < 0 { -self.val } else { self.val } as u64;
let sig_bits = significant_bits(unsigned_val);
let sig_bits = significant_bits(signed_val);
seal_plaintext.resize(sig_bits);
for i in 0..sig_bits {
let bit_value = (unsigned_val & 0x1 << i) >> i;
let bit_value = (signed_val & 0x1 << i) >> i;
let coeff_value = if self.val < 0 {
params.plain_modulus as u64 - bit_value
bit_value * (params.plain_modulus as u64 - bit_value)
} else {
bit_value
};
@@ -137,6 +137,41 @@ impl GraphCipherAdd for Signed {
}
}
impl GraphCipherPlainAdd for Signed {
type Left = Signed;
type Right = Signed;
fn graph_cipher_plain_add(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
impl GraphCipherConstAdd for Signed {
type Left = Self;
type Right = i64;
fn graph_cipher_const_add(
a: CircuitNode<Cipher<Self::Left>>,
b: i64,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let add = ctx.add_addition_plaintext(a.ids[0], lit);
CircuitNode::new(&[add])
})
}
}
impl GraphCipherMul for Signed {
type Left = Signed;
type Right = Signed;
@@ -152,3 +187,38 @@ impl GraphCipherMul for Signed {
})
}
}
impl GraphCipherConstMul for Signed {
type Left = Self;
type Right = i64;
fn graph_cipher_const_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: i64,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let add = ctx.add_multiplication_plaintext(a.ids[0], lit);
CircuitNode::new(&[add])
})
}
}
impl GraphCipherPlainMul for Signed {
type Left = Signed;
type Right = Signed;
fn graph_cipher_plain_mul(
a: CircuitNode<Cipher<Self::Left>>,
b: CircuitNode<Self::Right>,
) -> CircuitNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}

View File

@@ -228,6 +228,18 @@ where
}
}
// plain * cipher
impl<T> Mul<CircuitNode<Cipher<T>>> for CircuitNode<T>
where
T: FheType + GraphCipherPlainMul<Left = T, Right = T>,
{
type Output = CircuitNode<Cipher<T>>;
fn mul(self, rhs: CircuitNode<Cipher<T>>) -> Self::Output {
T::graph_cipher_plain_mul(rhs, self)
}
}
// cipher * literal
impl<T, U> Mul<T> for CircuitNode<Cipher<U>>
where

View File

@@ -1,39 +1,13 @@
use sunscreen_compiler::{
circuit,
types::{bfv::Signed, Cipher},
Compiler, PlainModulusConstraint, Runtime,
CircuitInput, Compiler, PlainModulusConstraint, Runtime,
};
#[test]
fn can_encode_signed() {
fn can_add_cipher_plain() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
a
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(10), &public).unwrap();
let result = runtime.run(&circuit, vec![a], &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, 10.into());
}
#[test]
fn can_add_signed_numbers() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
fn add(a: Cipher<Signed>, b: Signed) -> Cipher<Signed> {
a + b
}
@@ -48,9 +22,11 @@ fn can_add_signed_numbers() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
let b = runtime.encrypt(Signed::from(-5), &public).unwrap();
let b = Signed::from(-5);
let result = runtime.run(&circuit, vec![a, b], &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
@@ -58,13 +34,13 @@ fn can_add_signed_numbers() {
}
#[test]
fn can_multiply_signed_numbers() {
fn can_add_plain_cipher() {
#[circuit(scheme = "bfv")]
fn mul(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
a * b
fn add(a: Signed, b: Cipher<Signed>) -> Cipher<Signed> {
b + a
}
let circuit = Compiler::with_circuit(mul)
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
@@ -74,12 +50,182 @@ fn can_multiply_signed_numbers() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(17), &public).unwrap();
let b = runtime.encrypt(Signed::from(-4), &public).unwrap();
let a = Signed::from(-5);
let b = runtime.encrypt(Signed::from(15), &public).unwrap();
let result = runtime.run(&circuit, vec![a, b], &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, (-68).into());
assert_eq!(c, 10.into());
}
#[test]
fn can_add_cipher_literal() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
a + -4
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, 11.into());
}
#[test]
fn can_add_literal_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
-4 + a
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, 11.into());
}
#[test]
fn can_multiply_cipher_plain() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>, b: Signed) -> Cipher<Signed> {
a * b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
let b = Signed::from(-3);
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, (-45).into());
}
#[test]
fn can_multiply_plain_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Signed, b: Cipher<Signed>) -> Cipher<Signed> {
a * b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = Signed::from(15);
let b = runtime.encrypt(Signed::from(-3), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, (-45).into());
}
#[test]
fn can_multiply_cipher_literal() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
a * -3
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, (-45).into());
}
#[test]
fn can_multiply_literal_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Signed>) -> Cipher<Signed> {
-3 * a
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Signed::from(15), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Signed = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, (-45).into());
}

View File

@@ -36,7 +36,7 @@ fn can_add_unsigned_cipher_plain() {
#[test]
fn can_add_unsigned_plain_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Cipher<Unsigned>, b: Unsigned) -> Cipher<Unsigned> {
fn add(a: Unsigned, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
b + a
}
@@ -50,8 +50,8 @@ fn can_add_unsigned_plain_cipher() {
let (public, secret) = runtime.generate_keys().unwrap();
let a = runtime.encrypt(Unsigned::from(15), &public).unwrap();
let b = Unsigned::from(5);
let a = Unsigned::from(5);
let b = runtime.encrypt(Unsigned::from(15), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
@@ -145,6 +145,35 @@ fn can_add_multiply_cipher_plain() {
assert_eq!(c, 45.into());
}
#[test]
fn can_add_multiply_plain_cipher() {
#[circuit(scheme = "bfv")]
fn add(a: Unsigned, b: Cipher<Unsigned>) -> Cipher<Unsigned> {
a * b
}
let circuit = Compiler::with_circuit(add)
.noise_margin_bits(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&circuit.metadata.params).unwrap();
let (public, secret) = runtime.generate_keys().unwrap();
let a = Unsigned::from(15);
let b = runtime.encrypt(Unsigned::from(3), &public).unwrap();
let args: Vec<CircuitInput> = vec![a.into(), b.into()];
let result = runtime.run(&circuit, args, &public).unwrap();
let c: Unsigned = runtime.decrypt(&result[0], &secret).unwrap();
assert_eq!(c, 45.into());
}
#[test]
fn can_add_multiply_cipher_literal() {
#[circuit(scheme = "bfv")]