Files
Sunscreen/sunscreen/tests/array.rs
Sam Tay 5faf981178 Hackathon; or, various compiler improvements (#272)
* Misc doc fixes

* Fix sunscreen zkp exports

* Fix broken api doc reference

* Add starter zkp example

* Use ZkpRuntime::new in sudoku example

* Use ? over unwrap in zkp examples

* Refactor pattern matching

No functionality changes

* Disallow `mut` args in fhe/zkp programs

* Play around with allowing cipher|plain values

* Allow user-declared plain|cipher values

NOTE: Not fully implemented. Will not work on Rational types until we
factor out literal->plaintext into a proper trait.

This allows, e.g.

```rust
fn simple_sum(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
    let mut sum = fhe_var(0);
    sum = sum + a;
    sum = sum + b;
    fhe_out(sum)
}
```

* Refactor array::output()

* More targeted compiler error messages on invalid return values

* Add option for var.into() rather than fhe_out(var)

* Fix incorrect macro invocation

* Add trait for inserting const as plaintext

* Impl all arithmetic operations for indeterminate nodes

* Offer an `fhe_var!` macro

* Offer a zkp_var! macro

* Offer a (safe) debug impl for zkp program nodes

* Fix tests

* Add test for fhe_var!

* Simplify tf out of sudoku

* Simplify fhe input() codegen

* Marginally better compiler error messages on invalid fhe program arg types

* Fix error for fhe program argument attributes

* Throw appropriate compiler error on generics

* Silence clippy warnings in generated code

I believe these are typically ignored by default when consuming proc macros, but we might as well be explicit.

* Fixup quote_spanned invocations

Unsure how important this is, but see here: https://docs.rs/quote/latest/quote/macro.quote_spanned.html#syntax

* Automatically call `.into()` on fhe prog return values

* Factor fhe_program_impl

* Further factor fhe_program_impl

So that token generation happens in helper methods, and the ultimate output() func is readable

* Fix doctests

* Fix clippy warnings

* Remove TODOs

* Add missing example runs to CI

* Oops: fix 232 > 64

* Allow arbitrary expressions in fhe_var!

* Use custom "into" to support impls on []

* Support explicit #[private] params

* Remove `backend = "bulletproofs"` attribute

* Address PR review
2023-07-05 17:07:21 -05:00

289 lines
7.2 KiB
Rust

#![allow(clippy::needless_range_loop)]
use sunscreen::{
fhe_program,
types::{bfv::Signed, Cipher},
Compiler, FheProgramInput, PlainModulusConstraint, Runtime,
};
#[test]
fn can_add_array_elements() {
    /// Adds the first two elements of any indexable container. It is used both
    /// inside the FHE program and as the plaintext reference computation.
    fn add_impl<T, U>(x: T) -> U
    where
        T: std::ops::Index<usize, Output = U>,
        U: std::ops::Add<Output = U> + Copy,
    {
        let first = x[0];
        let second = x[1];
        first + second
    }

    #[fhe_program(scheme = "bfv")]
    fn add(x: [Cipher<Signed>; 2]) -> Cipher<Signed> {
        add_impl(x)
    }

    let app = Compiler::new()
        .fhe_program(add)
        .additional_noise_budget(5)
        .plain_modulus_constraint(PlainModulusConstraint::Raw(500))
        .compile()
        .unwrap();

    let runtime = Runtime::new_fhe(app.params()).unwrap();
    let (public_key, private_key) = runtime.generate_keys().unwrap();

    let lhs = Signed::from(2);
    let rhs = Signed::from(4);
    let operands = [lhs, rhs];

    let encrypted = runtime.encrypt(operands, &public_key).unwrap();
    let program = app.get_fhe_program(add).unwrap();
    let outputs = runtime.run(program, vec![encrypted], &public_key).unwrap();

    let sum: Signed = runtime.decrypt(&outputs[0], &private_key).unwrap();
    // The encrypted result must agree with both the generic helper and plain `+`.
    assert_eq!(sum, add_impl(operands));
    assert_eq!(sum, lhs + rhs);
}
#[test]
fn multidimensional_arrays() {
    /// Computes a 3x3 determinant by cofactor expansion along the first row.
    /// It is generic over any doubly-indexable container, so it runs both
    /// under FHE and on plaintext values.
    fn determinant_impl<T, U, V>(x: T) -> V
    where
        T: std::ops::Index<usize, Output = U>,
        U: std::ops::Index<usize, Output = V>,
        V: std::ops::Add<Output = V>
            + std::ops::Mul<Output = V>
            + std::ops::Sub<Output = V>
            + Copy,
    {
        let minor_0 = x[1][1] * x[2][2] - x[1][2] * x[2][1];
        let minor_1 = x[1][0] * x[2][2] - x[1][2] * x[2][0];
        let minor_2 = x[1][0] * x[2][1] - x[1][1] * x[2][0];
        x[0][0] * minor_0 - x[0][1] * minor_1 + x[0][2] * minor_2
    }

    #[fhe_program(scheme = "bfv")]
    fn determinant(x: [[Cipher<Signed>; 3]; 3]) -> Cipher<Signed> {
        determinant_impl(x)
    }

    let app = Compiler::new()
        .fhe_program(determinant)
        .additional_noise_budget(5)
        .plain_modulus_constraint(PlainModulusConstraint::Raw(500))
        .compile()
        .unwrap();

    let runtime = Runtime::new_fhe(app.params()).unwrap();
    let (public_key, private_key) = runtime.generate_keys().unwrap();

    // Fill the matrix with 0..=8 in row-major order.
    let mut matrix = <[[Signed; 3]; 3]>::default();
    for (i, row) in matrix.iter_mut().enumerate() {
        for (j, cell) in row.iter_mut().enumerate() {
            *cell = Signed::from((3 * i + j) as i64);
        }
    }
    // The 0..=8 matrix is singular (det = 0). Bumping the top-left entry to 1
    // gives a non-zero determinant of -3.
    matrix[0][0] = Signed::from(1);

    let encrypted = runtime.encrypt(matrix, &public_key).unwrap();
    let program = app.get_fhe_program(determinant).unwrap();
    let outputs = runtime.run(program, vec![encrypted], &public_key).unwrap();

    let det: Signed = runtime.decrypt(&outputs[0], &private_key).unwrap();
    assert_eq!(det, Signed::from(-3));
    assert_eq!(det, determinant_impl(matrix));
}
/// Checks that nested-array indexing inside an FHE program keeps Rust's
/// row-major meaning: `x[1][2]` must read row 1, column 2, not the transposed
/// element `x[2][1]`.
#[test]
fn multidimensional_is_row_major() {
    // Returns a single element. The previous name, `determinant`, was a
    // copy-paste leftover and misdescribed what the program does.
    #[fhe_program(scheme = "bfv")]
    fn select_element(x: [[Cipher<Signed>; 3]; 3]) -> Cipher<Signed> {
        x[1][2]
    }

    let app = Compiler::new()
        .fhe_program(select_element)
        .additional_noise_budget(5)
        .plain_modulus_constraint(PlainModulusConstraint::Raw(500))
        .compile()
        .unwrap();

    let runtime = Runtime::new_fhe(app.params()).unwrap();
    let (public_key, private_key) = runtime.generate_keys().unwrap();

    // Fill with 0..=8 in row-major order so every cell holds a distinct value.
    // x[1][2] is 5, while a transposed read of x[2][1] would give 7.
    // The old `matrix[0][0] = 1` override was removed: it made [0][0] collide
    // with [0][1] and served no purpose here.
    let mut matrix = <[[Signed; 3]; 3]>::default();
    for (i, row) in matrix.iter_mut().enumerate() {
        for (j, cell) in row.iter_mut().enumerate() {
            *cell = Signed::from((3 * i + j) as i64);
        }
    }

    let encrypted = runtime.encrypt(matrix, &public_key).unwrap();
    let result = runtime
        .run(
            app.get_fhe_program(select_element).unwrap(),
            vec![encrypted],
            &public_key,
        )
        .unwrap();

    let c: Signed = runtime.decrypt(&result[0], &private_key).unwrap();
    assert_eq!(c, Signed::from(5));
    assert_eq!(c, matrix[1][2]);
}
#[test]
fn cipher_plain_arrays() {
    // Dot product of an encrypted vector `a` with a plaintext vector `b`.
    #[fhe_program(scheme = "bfv")]
    fn dot(a: [Cipher<Signed>; 3], b: [Signed; 3]) -> Cipher<Signed> {
        let mut acc = a[0] * b[0];
        for k in 1..3 {
            acc = acc + a[k] * b[k];
        }
        acc
    }

    let app = Compiler::new()
        .fhe_program(dot)
        .additional_noise_budget(5)
        .plain_modulus_constraint(PlainModulusConstraint::Raw(500))
        .compile()
        .unwrap();

    let runtime = Runtime::new_fhe(app.params()).unwrap();
    let (public_key, private_key) = runtime.generate_keys().unwrap();

    // The plaintext operand is [1, 2, 3]; the encrypted operand is [2, 4, 6].
    let mut plain_vals = <[Signed; 3]>::default();
    let mut cipher_vals = <[Signed; 3]>::default();
    for (idx, (plain_slot, cipher_slot)) in plain_vals
        .iter_mut()
        .zip(cipher_vals.iter_mut())
        .enumerate()
    {
        let n = (idx + 1) as i64;
        *plain_slot = Signed::from(n);
        *cipher_slot = Signed::from(2 * n);
    }

    let encrypted = runtime.encrypt(cipher_vals, &public_key).unwrap();
    let args: Vec<FheProgramInput> = vec![encrypted.into(), plain_vals.into()];
    let program = app.get_fhe_program(dot).unwrap();
    let outputs = runtime.run(program, args, &public_key).unwrap();

    let product: Signed = runtime.decrypt(&outputs[0], &private_key).unwrap();
    // 2*1 + 4*2 + 6*3 = 28
    assert_eq!(product, Signed::from(28));
}
#[test]
fn can_mutate_array() {
    // Doubles each element in place, then returns the sum of the array.
    #[fhe_program(scheme = "bfv")]
    fn mult(a: [Cipher<Signed>; 6]) -> Cipher<Signed> {
        // FHE program arguments may not be declared `mut`, so rebind locally.
        let mut a = a;
        for idx in 0..a.len() {
            a[idx] = a[idx] * 2;
        }
        let mut total = a[0];
        for idx in 1..a.len() {
            total = total + a[idx];
        }
        total
    }

    let app = Compiler::new()
        .fhe_program(mult)
        .additional_noise_budget(5)
        .plain_modulus_constraint(PlainModulusConstraint::Raw(500))
        .compile()
        .unwrap();

    let runtime = Runtime::new_fhe(app.params()).unwrap();
    let (public_key, private_key) = runtime.generate_keys().unwrap();

    // Inputs are 0, 1, ..., 5.
    let mut inputs = <[Signed; 6]>::default();
    for (i, slot) in inputs.iter_mut().enumerate() {
        *slot = Signed::from(i as i64);
    }

    let encrypted = runtime.encrypt(inputs, &public_key).unwrap();
    let program = app.get_fhe_program(mult).unwrap();
    let outputs = runtime.run(program, vec![encrypted], &public_key).unwrap();

    let total: Signed = runtime.decrypt(&outputs[0], &private_key).unwrap();
    // 2 * (0 + 1 + 2 + 3 + 4 + 5) = 30
    assert_eq!(total, Signed::from(30));
}
#[test]
fn can_return_array() {
    // Doubles each element and returns the whole array as the program output.
    #[fhe_program(scheme = "bfv")]
    fn mult(a: [Cipher<Signed>; 6]) -> [Cipher<Signed>; 6] {
        // FHE program arguments may not be declared `mut`, so rebind locally.
        let mut a = a;
        for idx in 0..a.len() {
            a[idx] = a[idx] * 2;
        }
        a
    }

    let app = Compiler::new()
        .fhe_program(mult)
        .additional_noise_budget(5)
        .plain_modulus_constraint(PlainModulusConstraint::Raw(500))
        .compile()
        .unwrap();

    let runtime = Runtime::new_fhe(app.params()).unwrap();
    let (public_key, private_key) = runtime.generate_keys().unwrap();

    // Inputs are 0, 1, ..., 5.
    let mut inputs = <[Signed; 6]>::default();
    for (i, slot) in inputs.iter_mut().enumerate() {
        *slot = Signed::from(i as i64);
    }

    let encrypted = runtime.encrypt(inputs, &public_key).unwrap();
    let program = app.get_fhe_program(mult).unwrap();
    let outputs = runtime.run(program, vec![encrypted], &public_key).unwrap();

    let doubled: [Signed; 6] = runtime.decrypt(&outputs[0], &private_key).unwrap();

    // Compute the plaintext reference by doubling a copy of the inputs.
    let mut expected = inputs;
    for value in expected.iter_mut() {
        *value = *value * 2;
    }
    assert_eq!(doubled, expected);
}