implemented overflow detection for fhe programs

This commit is contained in:
Matthew Liu
2023-07-30 16:05:01 -07:00
parent c1ec13f787
commit e2f51008e3
6 changed files with 78 additions and 108 deletions

View File

@@ -21,7 +21,6 @@ use std::hash::{Hash, Hasher};
#[cfg(feature = "debugger")]
use std::collections::hash_map::DefaultHasher;
/**
* Stores debug information about groups and stack traces.
*/

View File

@@ -134,4 +134,4 @@ impl Default for GroupLookup {
fn default() -> Self {
Self::new()
}
}
}

View File

@@ -181,14 +181,13 @@ mod tests {
}
fn get_graph() -> CompilationResult<Operation> {
fn make_node(operation: Operation) -> NodeInfo<Operation> {
NodeInfo {
operation,
NodeInfo {
operation,
#[cfg(feature = "debugger")]
group_id: 0,
#[cfg(feature = "debugger")]
stack_id: 0
stack_id: 0,
}
}
let mut fe = CompilationResult::new();
@@ -197,17 +196,11 @@ mod tests {
#[cfg(feature = "debugger")]
{
let in_1 = fe.add_node(make_node(
Operation::PublicInput(NodeIndex::from(0))
));
let in_1 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(0))));
let in_2 = fe.add_node(make_node(
Operation::PublicInput(NodeIndex::from(1))
));
let in_2 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(1))));
let in_3 = fe.add_node(make_node(
Operation::PublicInput(NodeIndex::from(2))
));
let in_3 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(2))));
// Layer 2
// sub_2 gets eliminated.
@@ -305,12 +298,12 @@ mod tests {
fn get_expected() -> CompilationResult<Operation> {
fn make_node(operation: Operation) -> NodeInfo<Operation> {
NodeInfo {
operation,
NodeInfo {
operation,
#[cfg(feature = "debugger")]
group_id: 0,
#[cfg(feature = "debugger")]
stack_id: 0
stack_id: 0,
}
}
@@ -318,19 +311,12 @@ mod tests {
#[cfg(feature = "debugger")]
{
// Layer 1
let in_1 = fe.add_node(make_node(
Operation::PublicInput(NodeIndex::from(0))
));
let in_1 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(0))));
let in_2 = fe.add_node(make_node(
Operation::PublicInput(NodeIndex::from(1))
));
let in_2 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(1))));
let in_3 = fe.add_node(make_node(
Operation::PublicInput(NodeIndex::from(2))
));
let in_3 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(2))));
// Layer 2
// sub_2 gets eliminated.

View File

@@ -11,10 +11,9 @@ pub fn fhe_program_impl(
metadata: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let input_fn = parse_macro_input!(input as ItemFn);
let raw_fn = input_fn.span().source_text().unwrap_or_default();
let fhe_program_name = &input_fn.sig.ident;
let vis = &input_fn.vis;
let body = &input_fn.block;

View File

@@ -2,15 +2,13 @@ use crate::Ciphertext;
use crate::InnerCiphertext;
use crate::InnerPlaintext;
use crate::PrivateKey;
use crate::SealCiphertext;
use crate::SealData;
use crate::SealPlaintext;
use crate::WithContext;
use petgraph::stable_graph::NodeIndex;
use petgraph::stable_graph::StableGraph;
use petgraph::Direction::Incoming;
use rayon::vec;
use crate::SealPlaintext;
use crate::SealCiphertext;
use crate::Plaintext;
use seal_fhe::BfvEncryptionParametersBuilder;
use seal_fhe::Context;
use seal_fhe::Decryptor;
@@ -18,11 +16,11 @@ use seal_fhe::Modulus;
use seal_fhe::SecretKey;
use semver::Version;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use sunscreen_compiler_common::GraphQuery;
use sunscreen_compiler_common::Operation as OperationTrait;
use sunscreen_compiler_common::Type;
use sunscreen_compiler_common::{EdgeInfo, NodeInfo};
use std::collections::{HashMap, VecDeque};
use sunscreen_fhe_program::Operation as FheOperation;
#[derive(Clone, Serialize, Deserialize)]
@@ -45,10 +43,7 @@ pub struct BfvNodeType {
/**
* Gets the multiplicative depth of a node in the compilation graph.
*/
pub fn get_mult_depth<O>(
graph: &StableGraph<NodeInfo<O>, EdgeInfo>,
start_node: NodeIndex,
) -> u64
pub fn get_mult_depth<O>(graph: &StableGraph<NodeInfo<O>, EdgeInfo>, start_node: NodeIndex) -> u64
where
O: OperationTrait,
{
@@ -62,7 +57,12 @@ where
while let Some((node, depth)) = queue.pop_front() {
visited.insert(node, true);
let curr_depth = depth + graph.node_weight(node).unwrap().operation.is_multiplication() as u64;
let curr_depth = depth
+ graph
.node_weight(node)
.unwrap()
.operation
.is_multiplication() as u64;
max_depth = max_depth.max(curr_depth);
@@ -86,8 +86,7 @@ pub fn overflow_occurred(
p: u64,
pk: &PrivateKey,
program_data: &[Option<SealData>],
) -> bool
{
) -> bool {
// Overflow only occurs at the output of an operation node
let mut parents = graph.neighbors_directed(node, Incoming);
if parents.clone().count() != 1 {
@@ -98,7 +97,10 @@ pub fn overflow_occurred(
// Extract operand data
let parent = parents.next().unwrap();
let (left_op, right_op) = query.get_binary_operands(parent).expect(&format!("Parent node of {:?} is not a binary operation", node.index()));
let (left_op, right_op) = query.get_binary_operands(parent).expect(&format!(
"Parent node of {:?} is not a binary operation",
node.index()
));
let operand_nodes = [left_op, right_op];
let mut op_coefficients: [Vec<Vec<u64>>; 2] = [Vec::new(), Vec::new()];
@@ -123,7 +125,7 @@ pub fn overflow_occurred(
SealData::Ciphertext(ct) => {
let ciphertext = create_ciphertext_from_seal_data(ct, pk);
op_coefficients[idx] = decrypt_inner_cipher(ciphertext, &pk.0.data);
},
}
SealData::Plaintext(pt) => {
let plaintext = create_plaintext_from_seal_data(pt, pk);
op_coefficients[idx] = decrypt_inner_plain(plaintext);
@@ -131,48 +133,18 @@ pub fn overflow_occurred(
};
}
// Extract current node's data
let node_data = program_data
.get(node.index())
.unwrap_or_else(|| {
panic!(
"Couldn't find Option<SealData> in index {:?} of program_data",
node.index()
)
})
.clone()
.unwrap_or_else(|| {
panic!(
"Option<SealData> in index {:?} was None",
node.index()
)
});
let result = match node_data {
SealData::Ciphertext(ct) => {
let ciphertext = create_ciphertext_from_seal_data(ct, pk);
decrypt_inner_cipher(ciphertext, &pk.0.data)
},
SealData::Plaintext(pt) => {
let plaintext = create_plaintext_from_seal_data(pt, pk);
decrypt_inner_plain(plaintext)
},
};
// Overflow only occurs on arithmetic operations involving at least 1 ciphertext
match graph.node_weight(node).unwrap().operation {
FheOperation::Multiply | FheOperation::MultiplyPlaintext => mul_overflow_occurred(op_coefficients, result, p),
FheOperation::Add | FheOperation::AddPlaintext => add_overflow_occurred(op_coefficients, result, p),
FheOperation::Sub | FheOperation::SubPlaintext => sub_overflow_occurred(op_coefficients, result, p),
_ => false
FheOperation::Multiply | FheOperation::MultiplyPlaintext => {
mul_overflow_occurred(op_coefficients, p)
}
FheOperation::Add | FheOperation::AddPlaintext => add_overflow_occurred(op_coefficients, p),
FheOperation::Sub | FheOperation::SubPlaintext => sub_overflow_occurred(op_coefficients, p),
_ => false,
}
}
pub fn add_overflow_occurred(
operands: [Vec<Vec<u64>>; 2],
result: Vec<Vec<u64>>,
p: u64
) -> bool
{
pub fn add_overflow_occurred(operands: [Vec<Vec<u64>>; 2], p: u64) -> bool {
for (c0, c1) in operands[0].iter().zip(operands[1].iter()) {
// Addition overflow
for i in 0..c0.len() {
@@ -184,25 +156,35 @@ pub fn add_overflow_occurred(
}
}
}
true
false
}
pub fn sub_overflow_occurred(
operands: [Vec<Vec<u64>>; 2],
result: Vec<Vec<u64>>,
p: u64
) -> bool
{
true
pub fn sub_overflow_occurred(operands: [Vec<Vec<u64>>; 2], p: u64) -> bool {
let negated_coeffs = operands[1]
.iter()
.map(|vec| vec.iter().map(|x| p - x).collect())
.collect();
let new_operands = [operands[0].clone(), negated_coeffs];
add_overflow_occurred(new_operands, p)
}
pub fn mul_overflow_occurred(
operands: [Vec<Vec<u64>>; 2],
result: Vec<Vec<u64>>,
p: u64
) -> bool
{
true
pub fn mul_overflow_occurred(operands: [Vec<Vec<u64>>; 2], p: u64) -> bool {
for (poly1, poly2) in operands[0].iter().zip(&operands[1]) {
let product = polynomial_mult(poly1, poly2);
let product_mod = polynomial_mult_mod(poly1, poly2, p);
for (coeff, coeff_mod) in product.iter().zip(&product_mod) {
let signed_coeff = if *coeff_mod > p / 2 {
*coeff_mod as i64 - p as i64
} else {
*coeff_mod as i64
};
if *coeff as i64 != signed_coeff {
return true;
}
}
}
false
}
fn polynomial_mult(a: &[u64], b: &[u64]) -> Vec<u64> {
@@ -274,12 +256,12 @@ pub fn decrypt_inner_plain(inner_plain: InnerPlaintext) -> Vec<Vec<u64>> {
for i in 0..inner_plain.as_seal_plaintext().unwrap().len() {
let inner = inner_plain.as_seal_plaintext().unwrap().get(i).unwrap();
let mut inner_coefficients = Vec::new();
for j in 0..inner.len() {
inner_coefficients.push(inner.get_coefficient(j));
}
for j in 0..inner.len() {
inner_coefficients.push(inner.get_coefficient(j));
}
coefficients.push(inner_coefficients);
}
coefficients
coefficients
}
fn create_ciphertext_from_seal_data(ct: SealCiphertext, pk: &PrivateKey) -> InnerCiphertext {

View File

@@ -188,7 +188,8 @@ pub async fn get_node_data(
&bfv_session.program_data.clone(),
);
let coefficients = decrypt_inner_cipher(sunscreen_ciphertext.inner, &pk.0.data);
let coefficients =
decrypt_inner_cipher(sunscreen_ciphertext.inner, &pk.0.data);
DebugNodeType::Bfv(BfvNodeType {
// WARNING: `value` and `data_type` are nonsense values
@@ -221,12 +222,15 @@ pub async fn get_node_data(
let multiplicative_depth = 0;
let mut coefficients: Vec<Vec<u64>> = Vec::new();
let mut inner_coefficients = Vec::new();
for i in 0..pt.len() {
inner_coefficients.push(pt.get_coefficient(i));
}
coefficients.push(inner_coefficients);
let overflowed = overflow_occurred(
stable_graph,
NodeIndex::new(nodeid),
pk.0.params.plain_modulus,
pk,
&bfv_session.program_data.clone(),
);
let coefficients = decrypt_inner_plain(sunscreen_plaintext.inner);
DebugNodeType::Bfv(BfvNodeType {
// WARNING: `value` and `data_type` contain nonsense
@@ -235,7 +239,7 @@ pub async fn get_node_data(
noise_budget: None,
coefficients,
multiplicative_depth,
overflowed: None,
overflowed: Some(overflowed),
noise_exceeded: None,
})
}