mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
implemented overflow detection for fhe programs
This commit is contained in:
@@ -21,7 +21,6 @@ use std::hash::{Hash, Hasher};
|
||||
#[cfg(feature = "debugger")]
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
|
||||
|
||||
/**
|
||||
* Stores debug information about groups and stack traces.
|
||||
*/
|
||||
|
||||
@@ -134,4 +134,4 @@ impl Default for GroupLookup {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,14 +181,13 @@ mod tests {
|
||||
}
|
||||
|
||||
fn get_graph() -> CompilationResult<Operation> {
|
||||
|
||||
fn make_node(operation: Operation) -> NodeInfo<Operation> {
|
||||
NodeInfo {
|
||||
operation,
|
||||
NodeInfo {
|
||||
operation,
|
||||
#[cfg(feature = "debugger")]
|
||||
group_id: 0,
|
||||
#[cfg(feature = "debugger")]
|
||||
stack_id: 0
|
||||
stack_id: 0,
|
||||
}
|
||||
}
|
||||
let mut fe = CompilationResult::new();
|
||||
@@ -197,17 +196,11 @@ mod tests {
|
||||
|
||||
#[cfg(feature = "debugger")]
|
||||
{
|
||||
let in_1 = fe.add_node(make_node(
|
||||
Operation::PublicInput(NodeIndex::from(0))
|
||||
));
|
||||
let in_1 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(0))));
|
||||
|
||||
let in_2 = fe.add_node(make_node(
|
||||
Operation::PublicInput(NodeIndex::from(1))
|
||||
));
|
||||
let in_2 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(1))));
|
||||
|
||||
let in_3 = fe.add_node(make_node(
|
||||
Operation::PublicInput(NodeIndex::from(2))
|
||||
));
|
||||
let in_3 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(2))));
|
||||
|
||||
// Layer 2
|
||||
// sub_2 gets eliminated.
|
||||
@@ -305,12 +298,12 @@ mod tests {
|
||||
|
||||
fn get_expected() -> CompilationResult<Operation> {
|
||||
fn make_node(operation: Operation) -> NodeInfo<Operation> {
|
||||
NodeInfo {
|
||||
operation,
|
||||
NodeInfo {
|
||||
operation,
|
||||
#[cfg(feature = "debugger")]
|
||||
group_id: 0,
|
||||
#[cfg(feature = "debugger")]
|
||||
stack_id: 0
|
||||
stack_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,19 +311,12 @@ mod tests {
|
||||
|
||||
#[cfg(feature = "debugger")]
|
||||
{
|
||||
|
||||
// Layer 1
|
||||
let in_1 = fe.add_node(make_node(
|
||||
Operation::PublicInput(NodeIndex::from(0))
|
||||
));
|
||||
let in_1 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(0))));
|
||||
|
||||
let in_2 = fe.add_node(make_node(
|
||||
Operation::PublicInput(NodeIndex::from(1))
|
||||
));
|
||||
let in_2 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(1))));
|
||||
|
||||
let in_3 = fe.add_node(make_node(
|
||||
Operation::PublicInput(NodeIndex::from(2))
|
||||
));
|
||||
let in_3 = fe.add_node(make_node(Operation::PublicInput(NodeIndex::from(2))));
|
||||
|
||||
// Layer 2
|
||||
// sub_2 gets eliminated.
|
||||
|
||||
@@ -11,10 +11,9 @@ pub fn fhe_program_impl(
|
||||
metadata: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
|
||||
let input_fn = parse_macro_input!(input as ItemFn);
|
||||
let raw_fn = input_fn.span().source_text().unwrap_or_default();
|
||||
|
||||
|
||||
let fhe_program_name = &input_fn.sig.ident;
|
||||
let vis = &input_fn.vis;
|
||||
let body = &input_fn.block;
|
||||
|
||||
@@ -2,15 +2,13 @@ use crate::Ciphertext;
|
||||
use crate::InnerCiphertext;
|
||||
use crate::InnerPlaintext;
|
||||
use crate::PrivateKey;
|
||||
use crate::SealCiphertext;
|
||||
use crate::SealData;
|
||||
use crate::SealPlaintext;
|
||||
use crate::WithContext;
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use petgraph::stable_graph::StableGraph;
|
||||
use petgraph::Direction::Incoming;
|
||||
use rayon::vec;
|
||||
use crate::SealPlaintext;
|
||||
use crate::SealCiphertext;
|
||||
use crate::Plaintext;
|
||||
use seal_fhe::BfvEncryptionParametersBuilder;
|
||||
use seal_fhe::Context;
|
||||
use seal_fhe::Decryptor;
|
||||
@@ -18,11 +16,11 @@ use seal_fhe::Modulus;
|
||||
use seal_fhe::SecretKey;
|
||||
use semver::Version;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use sunscreen_compiler_common::GraphQuery;
|
||||
use sunscreen_compiler_common::Operation as OperationTrait;
|
||||
use sunscreen_compiler_common::Type;
|
||||
use sunscreen_compiler_common::{EdgeInfo, NodeInfo};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use sunscreen_fhe_program::Operation as FheOperation;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
@@ -45,10 +43,7 @@ pub struct BfvNodeType {
|
||||
/**
|
||||
* Gets the multiplicative depth of a node in the compilation graph.
|
||||
*/
|
||||
pub fn get_mult_depth<O>(
|
||||
graph: &StableGraph<NodeInfo<O>, EdgeInfo>,
|
||||
start_node: NodeIndex,
|
||||
) -> u64
|
||||
pub fn get_mult_depth<O>(graph: &StableGraph<NodeInfo<O>, EdgeInfo>, start_node: NodeIndex) -> u64
|
||||
where
|
||||
O: OperationTrait,
|
||||
{
|
||||
@@ -62,7 +57,12 @@ where
|
||||
while let Some((node, depth)) = queue.pop_front() {
|
||||
visited.insert(node, true);
|
||||
|
||||
let curr_depth = depth + graph.node_weight(node).unwrap().operation.is_multiplication() as u64;
|
||||
let curr_depth = depth
|
||||
+ graph
|
||||
.node_weight(node)
|
||||
.unwrap()
|
||||
.operation
|
||||
.is_multiplication() as u64;
|
||||
|
||||
max_depth = max_depth.max(curr_depth);
|
||||
|
||||
@@ -86,8 +86,7 @@ pub fn overflow_occurred(
|
||||
p: u64,
|
||||
pk: &PrivateKey,
|
||||
program_data: &[Option<SealData>],
|
||||
) -> bool
|
||||
{
|
||||
) -> bool {
|
||||
// Overflow only occurs at the output of an operation node
|
||||
let mut parents = graph.neighbors_directed(node, Incoming);
|
||||
if parents.clone().count() != 1 {
|
||||
@@ -98,7 +97,10 @@ pub fn overflow_occurred(
|
||||
|
||||
// Extract operand data
|
||||
let parent = parents.next().unwrap();
|
||||
let (left_op, right_op) = query.get_binary_operands(parent).expect(&format!("Parent node of {:?} is not a binary operation", node.index()));
|
||||
let (left_op, right_op) = query.get_binary_operands(parent).expect(&format!(
|
||||
"Parent node of {:?} is not a binary operation",
|
||||
node.index()
|
||||
));
|
||||
|
||||
let operand_nodes = [left_op, right_op];
|
||||
let mut op_coefficients: [Vec<Vec<u64>>; 2] = [Vec::new(), Vec::new()];
|
||||
@@ -123,7 +125,7 @@ pub fn overflow_occurred(
|
||||
SealData::Ciphertext(ct) => {
|
||||
let ciphertext = create_ciphertext_from_seal_data(ct, pk);
|
||||
op_coefficients[idx] = decrypt_inner_cipher(ciphertext, &pk.0.data);
|
||||
},
|
||||
}
|
||||
SealData::Plaintext(pt) => {
|
||||
let plaintext = create_plaintext_from_seal_data(pt, pk);
|
||||
op_coefficients[idx] = decrypt_inner_plain(plaintext);
|
||||
@@ -131,48 +133,18 @@ pub fn overflow_occurred(
|
||||
};
|
||||
}
|
||||
|
||||
// Extract current node's data
|
||||
let node_data = program_data
|
||||
.get(node.index())
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Couldn't find Option<SealData> in index {:?} of program_data",
|
||||
node.index()
|
||||
)
|
||||
})
|
||||
.clone()
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Option<SealData> in index {:?} was None",
|
||||
node.index()
|
||||
)
|
||||
});
|
||||
let result = match node_data {
|
||||
SealData::Ciphertext(ct) => {
|
||||
let ciphertext = create_ciphertext_from_seal_data(ct, pk);
|
||||
decrypt_inner_cipher(ciphertext, &pk.0.data)
|
||||
},
|
||||
SealData::Plaintext(pt) => {
|
||||
let plaintext = create_plaintext_from_seal_data(pt, pk);
|
||||
decrypt_inner_plain(plaintext)
|
||||
},
|
||||
};
|
||||
|
||||
// Overflow only occurs on arithmetic operations involving at least 1 ciphertext
|
||||
match graph.node_weight(node).unwrap().operation {
|
||||
FheOperation::Multiply | FheOperation::MultiplyPlaintext => mul_overflow_occurred(op_coefficients, result, p),
|
||||
FheOperation::Add | FheOperation::AddPlaintext => add_overflow_occurred(op_coefficients, result, p),
|
||||
FheOperation::Sub | FheOperation::SubPlaintext => sub_overflow_occurred(op_coefficients, result, p),
|
||||
_ => false
|
||||
FheOperation::Multiply | FheOperation::MultiplyPlaintext => {
|
||||
mul_overflow_occurred(op_coefficients, p)
|
||||
}
|
||||
FheOperation::Add | FheOperation::AddPlaintext => add_overflow_occurred(op_coefficients, p),
|
||||
FheOperation::Sub | FheOperation::SubPlaintext => sub_overflow_occurred(op_coefficients, p),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_overflow_occurred(
|
||||
operands: [Vec<Vec<u64>>; 2],
|
||||
result: Vec<Vec<u64>>,
|
||||
p: u64
|
||||
) -> bool
|
||||
{
|
||||
pub fn add_overflow_occurred(operands: [Vec<Vec<u64>>; 2], p: u64) -> bool {
|
||||
for (c0, c1) in operands[0].iter().zip(operands[1].iter()) {
|
||||
// Addition overflow
|
||||
for i in 0..c0.len() {
|
||||
@@ -184,25 +156,35 @@ pub fn add_overflow_occurred(
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
false
|
||||
}
|
||||
|
||||
pub fn sub_overflow_occurred(
|
||||
operands: [Vec<Vec<u64>>; 2],
|
||||
result: Vec<Vec<u64>>,
|
||||
p: u64
|
||||
) -> bool
|
||||
{
|
||||
true
|
||||
pub fn sub_overflow_occurred(operands: [Vec<Vec<u64>>; 2], p: u64) -> bool {
|
||||
let negated_coeffs = operands[1]
|
||||
.iter()
|
||||
.map(|vec| vec.iter().map(|x| p - x).collect())
|
||||
.collect();
|
||||
let new_operands = [operands[0].clone(), negated_coeffs];
|
||||
add_overflow_occurred(new_operands, p)
|
||||
}
|
||||
|
||||
pub fn mul_overflow_occurred(
|
||||
operands: [Vec<Vec<u64>>; 2],
|
||||
result: Vec<Vec<u64>>,
|
||||
p: u64
|
||||
) -> bool
|
||||
{
|
||||
true
|
||||
pub fn mul_overflow_occurred(operands: [Vec<Vec<u64>>; 2], p: u64) -> bool {
|
||||
for (poly1, poly2) in operands[0].iter().zip(&operands[1]) {
|
||||
let product = polynomial_mult(poly1, poly2);
|
||||
let product_mod = polynomial_mult_mod(poly1, poly2, p);
|
||||
|
||||
for (coeff, coeff_mod) in product.iter().zip(&product_mod) {
|
||||
let signed_coeff = if *coeff_mod > p / 2 {
|
||||
*coeff_mod as i64 - p as i64
|
||||
} else {
|
||||
*coeff_mod as i64
|
||||
};
|
||||
if *coeff as i64 != signed_coeff {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn polynomial_mult(a: &[u64], b: &[u64]) -> Vec<u64> {
|
||||
@@ -274,12 +256,12 @@ pub fn decrypt_inner_plain(inner_plain: InnerPlaintext) -> Vec<Vec<u64>> {
|
||||
for i in 0..inner_plain.as_seal_plaintext().unwrap().len() {
|
||||
let inner = inner_plain.as_seal_plaintext().unwrap().get(i).unwrap();
|
||||
let mut inner_coefficients = Vec::new();
|
||||
for j in 0..inner.len() {
|
||||
inner_coefficients.push(inner.get_coefficient(j));
|
||||
}
|
||||
for j in 0..inner.len() {
|
||||
inner_coefficients.push(inner.get_coefficient(j));
|
||||
}
|
||||
coefficients.push(inner_coefficients);
|
||||
}
|
||||
coefficients
|
||||
coefficients
|
||||
}
|
||||
|
||||
fn create_ciphertext_from_seal_data(ct: SealCiphertext, pk: &PrivateKey) -> InnerCiphertext {
|
||||
|
||||
@@ -188,7 +188,8 @@ pub async fn get_node_data(
|
||||
&bfv_session.program_data.clone(),
|
||||
);
|
||||
|
||||
let coefficients = decrypt_inner_cipher(sunscreen_ciphertext.inner, &pk.0.data);
|
||||
let coefficients =
|
||||
decrypt_inner_cipher(sunscreen_ciphertext.inner, &pk.0.data);
|
||||
|
||||
DebugNodeType::Bfv(BfvNodeType {
|
||||
// WARNING: `value` and `data_type` are nonsense values
|
||||
@@ -221,12 +222,15 @@ pub async fn get_node_data(
|
||||
|
||||
let multiplicative_depth = 0;
|
||||
|
||||
let mut coefficients: Vec<Vec<u64>> = Vec::new();
|
||||
let mut inner_coefficients = Vec::new();
|
||||
for i in 0..pt.len() {
|
||||
inner_coefficients.push(pt.get_coefficient(i));
|
||||
}
|
||||
coefficients.push(inner_coefficients);
|
||||
let overflowed = overflow_occurred(
|
||||
stable_graph,
|
||||
NodeIndex::new(nodeid),
|
||||
pk.0.params.plain_modulus,
|
||||
pk,
|
||||
&bfv_session.program_data.clone(),
|
||||
);
|
||||
|
||||
let coefficients = decrypt_inner_plain(sunscreen_plaintext.inner);
|
||||
|
||||
DebugNodeType::Bfv(BfvNodeType {
|
||||
// WARNING: `value` and `data_type` contain nonsense
|
||||
@@ -235,7 +239,7 @@ pub async fn get_node_data(
|
||||
noise_budget: None,
|
||||
coefficients,
|
||||
multiplicative_depth,
|
||||
overflowed: None,
|
||||
overflowed: Some(overflowed),
|
||||
noise_exceeded: None,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user