diff --git a/sunscreen/src/fhe/mod.rs b/sunscreen/src/fhe/mod.rs index 5c28806a8..2fd5310c2 100644 --- a/sunscreen/src/fhe/mod.rs +++ b/sunscreen/src/fhe/mod.rs @@ -179,9 +179,10 @@ where let ctx = option .as_mut() .expect("Called Ciphertext::new() outside of a context.") - .unwrap_fhe_mut(); + .unwrap_fhe_mut() + .unwrap(); - f(ctx.unwrap()) + f(ctx) }) } diff --git a/sunscreen/src/lib.rs b/sunscreen/src/lib.rs index d198be1f5..85922f96e 100644 --- a/sunscreen/src/lib.rs +++ b/sunscreen/src/lib.rs @@ -65,6 +65,7 @@ pub mod types; use fhe::{FheOperation, Literal, FheContext}; use petgraph::stable_graph::StableGraph; +use seal_fhe::Context; use serde::{Deserialize, Serialize}; use sunscreen_runtime::{marker, Fhe, FheZkp, Zkp}; use sunscreen_zkp_backend::CompiledZkpProgram; @@ -88,8 +89,7 @@ pub use sunscreen_runtime::{ pub use sunscreen_zkp_backend::{BackendField, Error as ZkpError, Result as ZkpResult, ZkpBackend}; pub use zkp::ZkpProgramFn; pub use zkp::{ - invoke_gadget, with_zkp_ctx, ZkpContext, ZkpContextOps, ZkpData, ZkpFrontendCompilation, - CURRENT_ZKP_CTX, + invoke_gadget, with_zkp_ctx, ZkpContext, ZkpContextOps, ZkpData, ZkpFrontendCompilation }; pub use fhe::{ with_fhe_ctx, CURRENT_PROGRAM_CTX diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index b0da3721e..e62b55761 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -37,7 +37,7 @@ use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub}; * construction. * * # Undefined behavior - * These types must be constructed while [`CURRENT_FHE_CTX`][crate::fhe::CURRENT_FHE_CTX] refers to a valid + * These types must be constructed while [`CURRENT_PROGRAM_CTX`][crate::fhe::CURRENT_PROGRAM_CTX] refers to a valid * [`FheContext`](crate::fhe::FheContext). Furthermore, no [`FheProgramNode`] should outlive the said context. * Violating any of these conditions may result in memory corruption or * use-after-free. diff --git a/sunscreen/src/types/intern/input.rs b/sunscreen/src/types/intern/input.rs index 88eb3e5fa..659d760da 100644 --- a/sunscreen/src/types/intern/input.rs +++ b/sunscreen/src/types/intern/input.rs @@ -1,4 +1,4 @@ -use crate::fhe::{with_fhe_ctx, FheContextOps}; +use crate::{fhe::{with_fhe_ctx, FheContextOps}, ContextEnum}; pub use crate::types::{intern::FheProgramNode, Cipher, FheType, NumCiphertexts, TypeName}; /** @@ -86,7 +86,7 @@ fn can_create_inputs() { security_level: SecurityLevel::TC128, }); - ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) }))); + ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut ContextEnum::Fhe(context.clone())) }))); let scalar_node: FheProgramNode = FheProgramNode::input(); let mut offset = 0; diff --git a/sunscreen/src/zkp/mod.rs b/sunscreen/src/zkp/mod.rs index e976d5051..c1103b237 100644 --- a/sunscreen/src/zkp/mod.rs +++ b/sunscreen/src/zkp/mod.rs @@ -4,7 +4,7 @@ use sunscreen_compiler_common::DebugData; use sunscreen_runtime::CallSignature; use sunscreen_zkp_backend::{BackendField, BigInt, Gadget, Operation as JitOperation}; -use crate::Result; +use crate::{Result, CURRENT_PROGRAM_CTX}; use std::collections::HashMap; use std::hash::Hash; @@ -187,6 +187,7 @@ impl Operation { * An implementation detail of a ZKP program. During compilation, it * tracks how many public and private inputs have been added. */ +#[derive(Clone)] pub struct ZkpData { next_public_input: usize, next_private_input: usize, @@ -367,14 +368,6 @@ impl Render for Operation { } } -thread_local! { - /** - * Contains the graph of a ZKP program during compilation. An - * implementation detail and not for public consumption. - */ - pub static CURRENT_ZKP_CTX: RefCell> = RefCell::new(None); -} - /** * Runs the specified closure, injecting the current * [`fhe_program`](crate::fhe_program) context. @@ -383,11 +376,13 @@ pub fn with_zkp_ctx(f: F) -> R where F: FnOnce(&mut ZkpContext) -> R, { - CURRENT_ZKP_CTX.with(|ctx| { + CURRENT_PROGRAM_CTX.with(|ctx| { let mut option = ctx.borrow_mut(); let ctx = option .as_mut() - .expect("Called with_zkp_ctx() outside of a context."); + .expect("Called with_zkp_ctx() outside of a context.") + .unwrap_zkp_mut() + .unwrap(); f(ctx) }) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 74d7456c4..4afa105ad 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -134,7 +134,7 @@ pub fn fhe_program_impl( fn build(&self, params: &sunscreen::Params) -> sunscreen::Result { use std::cell::RefCell; use std::mem::transmute; - use sunscreen::{fhe::{CURRENT_PROGRAM_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}}; + use sunscreen::{fhe::{CURRENT_PROGRAM_CTX, FheContext}, ContextEnum, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}}; if SchemeType::Bfv != params.scheme_type { return Err(Error::IncorrectScheme) @@ -153,7 +153,7 @@ pub fn fhe_program_impl( // Transmute away the lifetime to 'static. So long as we are careful with internal() // panicing, this is safe because we set the context back to none before the funtion // returns. - ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) }))); + ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut ContextEnum::Fhe(context.clone())) }))); #(#var_decl)* diff --git a/sunscreen_compiler_macros/src/zkp_program.rs b/sunscreen_compiler_macros/src/zkp_program.rs index 446ccba55..3a6885a84 100644 --- a/sunscreen_compiler_macros/src/zkp_program.rs +++ b/sunscreen_compiler_macros/src/zkp_program.rs @@ -171,7 +171,7 @@ fn parse_inner(_attr_params: ZkpProgramAttrs, input_fn: ItemFn) -> Result sunscreen::Result { use std::cell::RefCell; use std::mem::transmute; - use sunscreen::{CURRENT_PROGRAM_CTX, ZkpContext, ZkpData, Error, INDEX_ARENA, Result, types::{zkp::{ProgramNode, CreateZkpProgramInput, ConstrainEq, IntoProgramNode}, TypeName}}; + use sunscreen::{CURRENT_PROGRAM_CTX, ZkpContext, ContextEnum, ZkpData, Error, INDEX_ARENA, Result, types::{zkp::{ProgramNode, CreateZkpProgramInput, ConstrainEq, IntoProgramNode}, TypeName}}; let mut context = ZkpContext::new(ZkpData::new()); @@ -179,7 +179,7 @@ fn parse_inner(_attr_params: ZkpProgramAttrs, input_fn: ItemFn) -> Result