removed secretkey from debug_fhe_program, now uses privatekey instead

This commit is contained in:
Matthew Liu
2023-07-18 16:42:47 -07:00
parent bb59e579b1
commit 27f2153902
6 changed files with 144 additions and 49 deletions

View File

@@ -35,7 +35,7 @@ fn main() {
app.get_fhe_program("mad").unwrap(),
args1,
&public,
&private.0,
&private,
mad.source(),
)
.unwrap();
@@ -47,7 +47,7 @@ fn main() {
app.get_fhe_program("add_squares").unwrap(),
args2,
&public,
&private.0,
&private,
add_squares.source(),
)
.unwrap();

View File

@@ -293,7 +293,6 @@ where
pub group_stack: Vec<Group>,
}
// TODO: add modified support for `group_stack` with feature flag
impl<O, D> Context<O, D>
where
O: Operation,

View File

@@ -1,11 +1,7 @@
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use crate::{SealData};
use crate::{SealData, PrivateKey};
use sunscreen_compiler_common::CompilationResult;
use sunscreen_fhe_program::Operation;
@@ -65,7 +61,7 @@ pub struct BfvSession {
/**
* Used for decryption of ciphertexts for visualization.
*/
pub secret_key: SecretKey,
pub private_key: PrivateKey,
/**
* The source code of the FHE program.
@@ -78,7 +74,7 @@ impl BfvSession {
*/
pub fn new(
graph: &CompilationResult<Operation>,
secret_key: &SecretKey,
private_key: &PrivateKey,
source_code: &str,
) -> Self {
Self {
@@ -86,7 +82,7 @@ impl BfvSession {
// don't need a hashmap; if you don't encounter in the right order, it's all initialize das None so you
// can go back later and fill it in
program_data: vec![None; graph.node_count()],
secret_key: secret_key.clone(),
private_key: private_key.clone(),
source_code: source_code.to_owned(),
}
}

View File

@@ -50,7 +50,7 @@ pub struct PublicKey {
/**
* The private key used to decrypt ciphertexts.
*/
pub struct PrivateKey(pub WithContext<SealSecretKey>);
pub struct PrivateKey(pub(crate) WithContext<SealSecretKey>);
#[cfg(test)]
mod tests {

View File

@@ -1,11 +1,17 @@
use crate::{InnerPlaintext, SealData};
use crate::{InnerPlaintext, SealData, PrivateKey};
use static_assertions::const_assert;
use sunscreen_compiler_common::{GraphQuery, GraphQueryError};
use sunscreen_fhe_program::{FheProgram, FheProgramTrait, Literal, Operation::*};
#[cfg(feature = "debugger")]
use sunscreen_fhe_program::{SecurityLevel::TC128, SchemeType::Bfv};
#[cfg(feature = "debugger")]
use crate::debugger::sessions::{get_sessions, BfvSession};
#[cfg(feature = "debugger")]
use crate::serialization::WithContext;
use crossbeam::atomic::AtomicCell;
use petgraph::{stable_graph::NodeIndex, Direction};
#[cfg(test)]
@@ -92,9 +98,9 @@ impl From<SealError> for FheProgramRunFailure {
*/
pub struct DebugInfo<'a> {
/**
* The secret key associated with the debugger session. Used for decryption for visualization.
* The private key associated with the debugger session. Used for decryption for visualization.
*/
pub secret_key: &'a SecretKey,
pub private_key: &'a PrivateKey,
/**
* The name of the debugger session.
@@ -186,16 +192,14 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
let mut guard = get_sessions().lock().unwrap();
assert!(!guard.contains_key(&v.session_name));
let session = BfvSession::new(&ir.graph, v.secret_key, source_code);
let session = BfvSession::new(&ir.graph, v.private_key, source_code);
guard.insert(v.session_name.clone(), session.into());
}
None => {}
}
// #[cfg(feature = "debugger")]
// this function won't actually straight up decrypt stuff, it'll just insert ciphertexts correpsonding to nodeindex
// into sessions, and then we can decrypt those values on the backend before sending to the frontend
#[cfg(feature = "debugger")]
fn set_data(
data: &Vec<AtomicCell<Option<Arc<SealData>>>>,
node_index: NodeIndex,
@@ -214,27 +218,6 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
.unwrap_bfv_session_mut();
let node_val = get_data(data, node_index.index());
session.program_data[node_index.index()] = Arc::into_inner(node_val.unwrap().clone());
// pretty sure this code is not necessary: the point of this is that
// the program_infos are already created, so now it's just about
// storing the data into that struct
/*
// this should not happen
if lock.contains_key(&session) {
let program_info = lock.get(&session).unwrap();
// you don't need to match on results if you just want error propagation
let node_val = get_data(&data, node_index.index())?;
program_info.program_data.insert(node_index.index(), )
}
// Insert for new session
else {
let program_info = BfvSession::new(ir.graph, dbg_info.secret_key);
let node_val = get_data(&data, node_index.index())?;
program_info.program_data[node_index.index()] = node_val;
guard.insert(session, program_info);
}
*/
}
Ok(())
@@ -633,10 +616,22 @@ where
#[cfg(test)]
mod tests {
use crate::Params;
use super::*;
use seal_fhe::*;
use sunscreen_fhe_program::{FheProgramTrait, SchemeType};
fn setup_parameters(degree: u64) -> EncryptionParameters {
BfvEncryptionParametersBuilder::new()
.set_poly_modulus_degree(degree)
.set_plain_modulus(PlainModulus::batching(degree, 17).unwrap())
.set_coefficient_modulus(
CoefficientModulus::bfv_default(degree, SecurityLevel::default()).unwrap(),
)
.build()
.unwrap()
}
fn setup_scheme(
degree: u64,
) -> (
@@ -719,6 +714,21 @@ mod tests {
.unwrap()
};
#[cfg(feature = "debugger")]
let encryption_params = setup_parameters(degree);
#[cfg(feature = "debugger")]
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128,128,128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
#[cfg(feature = "debugger")]
let output = unsafe {
run_program_unchecked(
@@ -728,7 +738,7 @@ mod tests {
&None,
&None,
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "simple_add".to_owned(),
}),
"empty",
@@ -786,6 +796,21 @@ mod tests {
.unwrap()
};
#[cfg(feature = "debugger")]
let encryption_params = setup_parameters(degree);
#[cfg(feature = "debugger")]
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128,128,128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
#[cfg(feature = "debugger")]
let output = unsafe {
run_program_unchecked(
@@ -795,7 +820,7 @@ mod tests {
&Some(&relin_keys),
&None,
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "simple_mul".to_owned(),
}),
"empty"
@@ -854,6 +879,21 @@ mod tests {
.unwrap()
};
#[cfg(feature = "debugger")]
let encryption_params = setup_parameters(degree);
#[cfg(feature = "debugger")]
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128,128,128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
#[cfg(feature = "debugger")]
let output = unsafe {
run_program_unchecked(
@@ -863,7 +903,7 @@ mod tests {
&Some(&relin_keys),
&None,
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "can_mul_and_relinearize".to_owned(),
}),
"empty"
@@ -937,6 +977,21 @@ mod tests {
.unwrap()
};
#[cfg(feature = "debugger")]
let encryption_params = setup_parameters(degree);
#[cfg(feature = "debugger")]
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128,128,128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
#[cfg(feature = "debugger")]
let output = unsafe {
run_program_unchecked(
@@ -946,7 +1001,7 @@ mod tests {
&Some(&relin_keys),
&None,
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "add_reduction".to_owned(),
}),
"empty"
@@ -1002,6 +1057,22 @@ mod tests {
)
.unwrap()
};
#[cfg(feature = "debugger")]
let encryption_params = setup_parameters(degree);
#[cfg(feature = "debugger")]
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128,128,128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
#[cfg(feature = "debugger")]
let output = unsafe {
run_program_unchecked(
@@ -1011,7 +1082,7 @@ mod tests {
&None,
&Some(&galois_keys),
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "rotate_left".to_owned(),
}),
"empty"
@@ -1072,6 +1143,22 @@ mod tests {
)
.unwrap()
};
#[cfg(feature = "debugger")]
let encryption_params = setup_parameters(degree);
#[cfg(feature = "debugger")]
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128,128,128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
#[cfg(feature = "debugger")]
let output = unsafe {
run_program_unchecked(
@@ -1081,7 +1168,7 @@ mod tests {
&None,
&Some(&galois_keys),
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "rotate_right".to_owned(),
}),
"empty"
@@ -1131,6 +1218,19 @@ mod tests {
let ct_0 = encryptor.encrypt(&pt_0).unwrap();
let encryption_params = setup_parameters(degree);
let private_key = PrivateKey(WithContext {
params: Params {
lattice_dimension: encryption_params.get_poly_modulus_degree(),
coeff_modulus: /* encryption_params.get_coefficient_modulus() as Vec<u64>*/ vec![128, 128, 128],
plain_modulus: 1024,
scheme_type: Bfv,
security_level: TC128
},
data: private_key
});
let output = unsafe {
run_program_unchecked(
&ir,
@@ -1139,7 +1239,7 @@ mod tests {
&None,
&Some(&galois_keys),
Some(DebugInfo {
secret_key: &private_key,
private_key: &private_key,
session_name: "new_session".to_owned(),
}),
"empty"

View File

@@ -426,7 +426,7 @@ where
fhe_program: &CompiledFheProgram,
arguments: Vec<I>,
public_key: &PublicKey,
secret_key: &SecretKey,
private_key: &PrivateKey,
#[cfg(feature = "debugger")] source_code: &str,
) -> Result<()>
where
@@ -445,7 +445,7 @@ where
arguments,
public_key,
Some(DebugInfo {
secret_key,
private_key,
session_name,
}),
#[cfg(feature = "debugger")]