Add array support

This commit is contained in:
Rick Weber
2022-04-11 21:01:28 -07:00
parent 895559d551
commit ad930251dc
19 changed files with 361 additions and 65 deletions

18
.vscode/launch.json vendored
View File

@@ -151,6 +151,24 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug array integration tests in library 'sunscreen'",
"cargo": {
"args": [
"test",
"--no-run",
"--package=sunscreen",
],
"filter": {
"name": "array",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",

View File

@@ -1,17 +1,12 @@
use sunscreen::{
fhe_program,
types::{bfv::Rational, Cipher},
Ciphertext, CompiledFheProgram, Compiler, Params, PrivateKey,
Error,
PublicKey,
Runtime,
Ciphertext, CompiledFheProgram, Compiler, Error, Params, PrivateKey, PublicKey, Runtime,
};
#[fhe_program(scheme = "bfv")]
/// This program swaps NU tokens to receive ETH.
fn swap_nu(
nu_tokens_to_trade: Cipher<Rational>,
) -> Cipher<Rational> {
fn swap_nu(nu_tokens_to_trade: Cipher<Rational>) -> Cipher<Rational> {
let total_eth = 100.0;
let total_nu = 1_000.0;
@@ -45,7 +40,9 @@ impl Miner {
nu_tokens_to_trade: Ciphertext,
public_key: &PublicKey,
) -> Result<Ciphertext, Error> {
let results = self.runtime.run(&self.compiled_swap_nu, vec![nu_tokens_to_trade], public_key)?;
let results =
self.runtime
.run(&self.compiled_swap_nu, vec![nu_tokens_to_trade], public_key)?;
Ok(results[0].clone())
}
@@ -77,15 +74,13 @@ impl Alice {
}
pub fn create_transaction(&self, amount: f64) -> Result<Ciphertext, Error> {
Ok(self.runtime
.encrypt(Rational::try_from(amount)?, &self.public_key)?
)
Ok(self
.runtime
.encrypt(Rational::try_from(amount)?, &self.public_key)?)
}
pub fn check_received_eth(&self, received_eth: Ciphertext) -> Result<(), Error> {
let received_eth: Rational = self
.runtime
.decrypt(&received_eth, &self.private_key)?;
let received_eth: Rational = self.runtime.decrypt(&received_eth, &self.private_key)?;
let received_eth: f64 = received_eth.into();
@@ -105,8 +100,7 @@ fn main() -> Result<(), Error> {
let transaction = alice.create_transaction(20.0)?;
let encrypted_received_eth =
miner.run_contract(transaction, &alice.public_key)?;
let encrypted_received_eth = miner.run_contract(transaction, &alice.public_key)?;
alice.check_received_eth(encrypted_received_eth)?;

View File

@@ -17,7 +17,7 @@ use sunscreen::{
bfv::{Batched, Signed},
Cipher, FheType, TypeName,
},
Compiler, FheProgramFn, FheProgramInput, PlainModulusConstraint, Runtime, Error,
Compiler, Error, FheProgramFn, FheProgramInput, PlainModulusConstraint, Runtime,
};
use std::marker::PhantomData;

View File

@@ -158,8 +158,7 @@ fn main() -> Result<(), sunscreen::Error> {
// modulus.
let fhe_program = Compiler::with_fhe_program(dot_product)
.plain_modulus_constraint(PlainModulusConstraint::BatchingMinimum(24))
.compile()
?;
.compile()?;
let end = start.elapsed();
println!("Compiled in {}s", end.as_secs_f64());

View File

@@ -10,7 +10,7 @@ use sunscreen::{
* the result.
*
* `Signed` is Sunscreen's integer type compatible with FHE programs.
*
*
* The `Cipher<T>` type indicates that type `T` is encrypted, thus `Cipher<Signed>`
* is an encrypted [`Signed`] value.
*
@@ -24,26 +24,25 @@ fn simple_multiply(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
fn main() -> Result<(), Error> {
/*
* Here we compile the FHE program we previously declared. In the first step,
* we create our compiler and use the default settings.
*
* Afterwards, we simply compile. The `?` operator is Rust's standard
* error handling mechanism; it returns from the current function (`main`)
* when an error occurs (shouldn't happen) or emits our compiled
* program on success.
*
* The compiler transforms our program and chooses encryption scheme parameters.
* These parameters are a tradeoff between correctness and performance;
* if parameters are too small data corruption occurs, but if they're too large,
* your program runs more slowly than necessary.
*
* Sunscreen allows experts to explicitly set the scheme parameters,
* but the default behavior has Sunscreen pick parameters for you, yielding
* good performance maintaining correctness for nearly all applications.
*
*/
let fhe_program = Compiler::with_fhe_program(simple_multiply)
.compile()?;
* Here we compile the FHE program we previously declared. In the first step,
* we create our compiler and use the default settings.
*
* Afterwards, we simply compile. The `?` operator is Rust's standard
* error handling mechanism; it returns from the current function (`main`)
* when an error occurs (shouldn't happen) or emits our compiled
* program on success.
*
* The compiler transforms our program and chooses encryption scheme parameters.
* These parameters are a tradeoff between correctness and performance;
* if parameters are too small data corruption occurs, but if they're too large,
* your program runs more slowly than necessary.
*
* Sunscreen allows experts to explicitly set the scheme parameters,
* but the default behavior has Sunscreen pick parameters for you, yielding
* good performance maintaining correctness for nearly all applications.
*
*/
let fhe_program = Compiler::with_fhe_program(simple_multiply).compile()?;
/*
* Next, we construct a runtime, which provides the APIs for encryption,
@@ -52,7 +51,7 @@ fn main() -> Result<(), Error> {
let runtime = Runtime::new(&fhe_program.metadata.params)?;
/*
* Here, we generate a public and private key pair. Normally, Alice does this,
* Here, we generate a public and private key pair. Normally, Alice does this,
* sending the public key to bob, who then runs a computation.
*/
let (public_key, private_key) = runtime.generate_keys()?;

View File

@@ -62,7 +62,10 @@ use petgraph::{
Graph,
};
use serde::{Deserialize, Serialize};
use core::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use sunscreen_backend::compile_inplace;
use sunscreen_fhe_program::{
@@ -82,6 +85,8 @@ pub use sunscreen_runtime::{
PrivateKey, PublicKey, RequiredKeys, Runtime, WithContext,
};
use crate::types::{intern::FheProgramNode, NumCiphertexts};
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
/**
* Represents a literal node's data.
@@ -245,6 +250,11 @@ pub struct Context {
* FheProgramNode to impl Copy.
*/
pub indicies_store: Vec<NodeIndex>,
/**
* A cache of FheProgramNodes used by [`get_fhe_program_node()`](Self::get_fhe_program_node).
*/
pub fhe_program_nodes: HashMap<&'static [NodeIndex], std::rc::Rc<dyn Any>>,
}
impl PartialEq for FrontendCompilation {
@@ -300,6 +310,7 @@ impl Context {
},
params: params.clone(),
indicies_store: vec![],
fhe_program_nodes: HashMap::new(),
}
}
@@ -451,6 +462,42 @@ impl Context {
pub fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::Output, i)
}
/**
* Creates an FheProgramNode from the given nodeIds.
*
* # Undefined behavior
* Using the returned reference after this context is dropped will result
* in use-after-free.
*
* # Panics
* Calling this method multiple times with the same ids but a different
* type T will panic.
*/
pub unsafe fn get_fhe_program_node<T>(
&mut self,
ids: &'static [NodeIndex],
) -> &'static FheProgramNode<T>
where
T: NumCiphertexts + Sync + Send + 'static,
{
let node_ref = match self.fhe_program_nodes.get(ids) {
// Panic if T doesn't match the first time we got the same node ids.
Some(n) => n.downcast_ref::<FheProgramNode<T>>().unwrap(),
None => {
let new_node = FheProgramNode::<T>::new(ids);
self.fhe_program_nodes
.insert(ids, std::rc::Rc::new(new_node));
self.fhe_program_nodes[ids]
.downcast_ref::<FheProgramNode<T>>()
.unwrap()
}
};
// Extend the lifetime to 'static.
std::mem::transmute(node_ref)
}
}
impl FrontendCompilation {

View File

@@ -174,6 +174,12 @@ impl<const INT_BITS: usize> NumCiphertexts for Fractional<INT_BITS> {
impl<const INT_BITS: usize> FheProgramInputTrait for Fractional<INT_BITS> {}
impl<const INT_BITS: usize> Default for Fractional<INT_BITS> {
fn default() -> Self {
Self::from(0.0)
}
}
impl<const INT_BITS: usize> TypeName for Fractional<INT_BITS> {
fn type_name() -> Type {
Type {

View File

@@ -30,6 +30,12 @@ impl PartialEq for Rational {
}
}
impl Default for Rational {
fn default() -> Self {
Self::try_from(0.0).unwrap()
}
}
impl NumCiphertexts for Rational {
const NUM_CIPHERTEXTS: usize = Signed::NUM_CIPHERTEXTS + Signed::NUM_CIPHERTEXTS;
}

View File

@@ -41,6 +41,12 @@ impl std::fmt::Display for Signed {
}
}
impl Default for Signed {
fn default() -> Self {
Self::from(0)
}
}
fn significant_bits(val: u64) -> usize {
let bits = std::mem::size_of::<u64>() * 8;

View File

@@ -4,7 +4,7 @@ use crate::{
};
use petgraph::stable_graph::NodeIndex;
use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub};
use std::ops::{Add, Div, Index, Mul, Neg, Shl, Shr, Sub};
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/**
@@ -508,3 +508,29 @@ where
T::graph_cipher_rotate_right(self, x)
}
}
impl<T, const N: usize> Index<usize> for FheProgramNode<[T; N]>
where
T: NumCiphertexts + Sync + Send + 'static,
{
type Output = FheProgramNode<T>;
fn index<'a>(&'a self, idx: usize) -> &'a Self::Output {
// Indexing is nasty because Rust requires we return a reference,
// but we don't have an FheProgramNode sitting around to borrow.
// We have to conjure one out of the ether on the context and return
// a reference to that.
with_ctx(|ctx| {
let stride = self.ids.len() / N;
let start = idx * stride;
let end = (idx + 1) * stride;
let slice = &self.ids[start..end];
// This [] operator is only used during program compilation,
// which always has a context that lives until compilation
// completes. Thus, it is sound to call this method.
unsafe { ctx.get_fhe_program_node(slice) }
})
}
}

View File

@@ -1,5 +1,5 @@
pub use crate::{
types::{intern::FheProgramNode, Cipher, FheType},
types::{intern::FheProgramNode, Cipher, FheType, NumCiphertexts, TypeName},
with_ctx,
};
@@ -20,30 +20,19 @@ pub trait Input {
fn input() -> Self;
}
impl<T> Input for FheProgramNode<Cipher<T>>
where
T: FheType,
{
fn input() -> Self {
let mut ids = Vec::with_capacity(T::NUM_CIPHERTEXTS);
for _ in 0..T::NUM_CIPHERTEXTS {
ids.push(with_ctx(|ctx| ctx.add_ciphertext_input()));
}
FheProgramNode::new(&ids)
}
}
impl<T> Input for FheProgramNode<T>
where
T: FheType,
T: NumCiphertexts + TypeName,
{
fn input() -> Self {
let mut ids = Vec::with_capacity(T::NUM_CIPHERTEXTS);
for _ in 0..T::NUM_CIPHERTEXTS {
ids.push(with_ctx(|ctx| ctx.add_plaintext_input()));
if T::type_name().is_encrypted {
ids.push(with_ctx(|ctx| ctx.add_ciphertext_input()));
} else {
ids.push(with_ctx(|ctx| ctx.add_plaintext_input()));
}
}
FheProgramNode::new(&ids)

View File

@@ -101,7 +101,7 @@ pub trait LaneCount {
fn lane_count() -> usize;
}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
/**
* Declares a type T as being encrypted in an [`fhe_program`](crate::fhe_program).
*/
@@ -130,7 +130,3 @@ where
}
}
}
trait Foo {}
impl<T> Foo for T where T: FheType {}

133
sunscreen/tests/array.rs Normal file
View File

@@ -0,0 +1,133 @@
use sunscreen::{
fhe_program,
types::{bfv::Signed, Cipher},
Compiler, FheProgramInput, PlainModulusConstraint, Runtime,
};
#[test]
fn can_add_array_elements() {
fn add_impl<T, U>(x: T) -> U
where
T: std::ops::Index<usize, Output = U>,
U: std::ops::Add<Output = U> + Copy,
{
x[0] + x[1]
}
#[fhe_program(scheme = "bfv")]
fn add(x: [Cipher<Signed>; 2]) -> Cipher<Signed> {
add_impl(x)
}
let fhe_program = Compiler::with_fhe_program(add)
.additional_noise_budget(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&fhe_program.metadata.params).unwrap();
let (public_key, private_key) = runtime.generate_keys().unwrap();
let a = Signed::try_from(2).unwrap();
let b = Signed::try_from(4).unwrap();
let a_c = runtime.encrypt([a, b], &public_key).unwrap();
let result = runtime.run(&fhe_program, vec![a_c], &public_key).unwrap();
let c: Signed = runtime.decrypt(&result[0], &private_key).unwrap();
assert_eq!(c, add_impl([a, b]));
assert_eq!(c, a + b);
}
#[test]
fn multidimensional_arrays() {
fn determinant_impl<T, U, V>(x: T) -> V
where
T: std::ops::Index<usize, Output = U>,
U: std::ops::Index<usize, Output = V>,
V: std::ops::Add<Output = V> + std::ops::Mul<Output = V> + std::ops::Sub<Output = V> + Copy,
{
x[0][0] * (x[1][1] * x[2][2] - x[1][2] * x[2][1])
- x[0][1] * (x[1][0] * x[2][2] - x[1][2] * x[2][0])
+ x[0][2] * (x[1][0] * x[2][1] - x[1][1] * x[2][0])
}
#[fhe_program(scheme = "bfv")]
fn determinant(x: [[Cipher<Signed>; 3]; 3]) -> Cipher<Signed> {
determinant_impl(x)
}
let fhe_program = Compiler::with_fhe_program(determinant)
.additional_noise_budget(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&fhe_program.metadata.params).unwrap();
let (public_key, private_key) = runtime.generate_keys().unwrap();
let mut matrix = <[[Signed; 3]; 3]>::default();
for i in 0..3 {
for j in 0..3 {
let value: i64 = (3 * i + j) as i64;
matrix[i][j] = Signed::from(value);
}
}
matrix[0][0] = Signed::from(1);
let a_c = runtime.encrypt(matrix, &public_key).unwrap();
let result = runtime.run(&fhe_program, vec![a_c], &public_key).unwrap();
let c: Signed = runtime.decrypt(&result[0], &private_key).unwrap();
assert_eq!(c, Signed::from(-3));
assert_eq!(c, determinant_impl(matrix));
}
#[test]
fn cipher_plain_arrays() {
#[fhe_program(scheme = "bfv")]
fn dot(a: [Cipher<Signed>; 3], b: [Signed; 3]) -> Cipher<Signed> {
let mut sum = a[0] * b[0];
for i in 1..3 {
sum = sum + a[i] * b[i];
}
sum
}
let fhe_program = Compiler::with_fhe_program(dot)
.additional_noise_budget(5)
.plain_modulus_constraint(PlainModulusConstraint::Raw(500))
.compile()
.unwrap();
let runtime = Runtime::new(&fhe_program.metadata.params).unwrap();
let (public_key, private_key) = runtime.generate_keys().unwrap();
let mut data = <[Signed; 3]>::default();
let mut select = <[Signed; 3]>::default();
for i in 1..4 {
data[i - 1] = Signed::from(i as i64);
select[i - 1] = Signed::from((2 * i) as i64);
}
let select_c = runtime.encrypt(select, &public_key).unwrap();
let args: Vec<FheProgramInput> = vec![select_c.into(), data.into()];
let result = runtime.run(&fhe_program, args, &public_key).unwrap();
let c: Signed = runtime.decrypt(&result[0], &private_key).unwrap();
assert_eq!(c, Signed::from(28));
}

View File

@@ -406,3 +406,8 @@ fn can_negate() {
test_div(1e-23);
test_div(4294967295.);
}
#[test]
fn can_create_default() {
assert_eq!(Into::<f64>::into(Fractional::<64>::default()), 0.0f64);
}

View File

@@ -698,3 +698,8 @@ fn can_neg_cipher() {
assert_eq!(c, neg_impl(a));
}
#[test]
fn can_create_default() {
assert_eq!(Into::<f64>::into(Rational::default()), 0.0f64);
}

View File

@@ -467,3 +467,8 @@ fn can_mul_literal_cipher() {
assert_eq!(c, mul_fn(-4, a));
}
#[test]
fn can_create_default() {
assert_eq!(Into::<i64>::into(Signed::default()), 0);
}

View File

@@ -38,6 +38,7 @@ pub fn fhe_program_impl(
}
FnArg::Typed(t) => match (&*t.ty, &*t.pat) {
(Type::Path(_), Pat::Ident(i)) => (t, &i.ident),
(Type::Array(_), Pat::Ident(i)) => (t, &i.ident),
_ => {
return proc_macro::TokenStream::from(quote! {
compile_error!("fhe_program arguments' name must be a simple identifier and type must be a plain path.");

View File

@@ -0,0 +1,60 @@
use crate::{
FheProgramInputTrait, InnerPlaintext, NumCiphertexts, Params, Plaintext, Result,
TryIntoPlaintext, Type, TypeName, TypeNameInstance, WithContext,
};
use seal::Plaintext as SealPlaintext;
impl<T, const N: usize> TryIntoPlaintext for [T; N]
where
T: TryIntoPlaintext,
Self: TypeName,
{
fn try_into_plaintext(&self, params: &Params) -> Result<Plaintext> {
let element_plaintexts = self
.iter()
.map(|v| v.try_into_plaintext(params))
.collect::<Result<Vec<Plaintext>>>()?
.drain(0..)
.flat_map(|p| match p.inner {
InnerPlaintext::Seal(v) => v,
})
.collect::<Vec<WithContext<SealPlaintext>>>();
Ok(Plaintext {
inner: InnerPlaintext::Seal(element_plaintexts),
data_type: Self::type_name(),
})
}
}
impl<T, const N: usize> TypeName for [T; N]
where
T: TypeName,
{
fn type_name() -> Type {
let inner_type = T::type_name();
Type {
name: format!("[{};{}]", inner_type.name, N),
..inner_type
}
}
}
impl<T, const N: usize> TypeNameInstance for [T; N]
where
T: TypeName,
{
fn type_name_instance(&self) -> Type {
Self::type_name()
}
}
impl<T, const N: usize> FheProgramInputTrait for [T; N] where T: TypeName + TryIntoPlaintext {}
impl<T, const N: usize> NumCiphertexts for [T; N]
where
T: NumCiphertexts,
{
const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS * N;
}

View File

@@ -4,6 +4,7 @@
//! This crate contains the types and functions for executing a Sunscreen FHE program
//! (i.e. an [`FheProgram`](sunscreen_fhe_program::FheProgram)).
mod array;
mod error;
mod keys;
mod metadata;