Rweber/zkp (#185)

Start ZKP compiler and refactor common code.
This commit is contained in:
rickwebiii
2022-11-15 12:43:04 -08:00
committed by GitHub
parent 390a27b1cd
commit 28ea71118f
52 changed files with 2742 additions and 776 deletions

19
.vscode/launch.json vendored
View File

@@ -304,6 +304,25 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug integration test 'zkp_program_tests'",
"cargo": {
"args": [
"test",
"--no-run",
"--test=zkp_program_tests",
"--package=sunscreen"
],
"filter": {
"name": "zkp_program_tests",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",

274
Cargo.lock generated
View File

@@ -79,12 +79,54 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "block-buffer"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
dependencies = [
"block-padding",
"generic-array",
]
[[package]]
name = "block-padding"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae"
[[package]]
name = "bulletproofs"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40e698f1df446cc6246afd823afbe2d121134d089c9102c1dd26d1264991ba32"
dependencies = [
"byteorder",
"clear_on_drop",
"curve25519-dalek-ng",
"digest",
"merlin",
"rand",
"rand_core",
"serde",
"serde_derive",
"sha3",
"subtle-ng",
"thiserror",
]
[[package]]
name = "bumpalo"
version = "3.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "1.2.1"
@@ -182,6 +224,15 @@ dependencies = [
"os_str_bytes",
]
[[package]]
name = "clear_on_drop"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38508a63f4979f0048febc9966fadbd48e5dab31fd0ec6a3f151bbf4a74f7423"
dependencies = [
"cc",
]
[[package]]
name = "cmake"
version = "0.1.48"
@@ -207,6 +258,15 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
[[package]]
name = "cpufeatures"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320"
dependencies = [
"libc",
]
[[package]]
name = "crossbeam"
version = "0.8.2"
@@ -276,6 +336,29 @@ dependencies = [
"once_cell",
]
[[package]]
name = "curve25519-dalek-ng"
version = "4.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c359b7249347e46fb28804470d071c921156ad62b3eef5d34e2ba867533dec8"
dependencies = [
"byteorder",
"digest",
"rand_core",
"serde",
"subtle-ng",
"zeroize",
]
[[package]]
name = "digest"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
dependencies = [
"generic-array",
]
[[package]]
name = "dot_prod"
version = "0.1.0"
@@ -427,6 +510,27 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "glob"
version = "0.3.0"
@@ -594,6 +698,15 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "keccak"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768"
dependencies = [
"cpufeatures",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@@ -663,6 +776,18 @@ dependencies = [
"autocfg",
]
[[package]]
name = "merlin"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d"
dependencies = [
"byteorder",
"keccak",
"rand_core",
"zeroize",
]
[[package]]
name = "mime"
version = "0.3.16"
@@ -807,6 +932,12 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
[[package]]
name = "opaque-debug"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl"
version = "0.10.41"
@@ -907,6 +1038,12 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.47"
@@ -925,6 +1062,36 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rayon"
version = "1.5.3"
@@ -1149,6 +1316,18 @@ dependencies = [
"serde",
]
[[package]]
name = "sha3"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809"
dependencies = [
"block-buffer",
"digest",
"keccak",
"opaque-debug",
]
[[package]]
name = "shlex"
version = "1.1.0"
@@ -1187,6 +1366,12 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subtle-ng"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "734676eb262c623cec13c3155096e08d1f8f29adce39ba17948b18dad1e54142"
[[package]]
name = "sunscreen"
version = "0.7.0"
@@ -1200,6 +1385,7 @@ dependencies = [
"serde",
"serde_json",
"sunscreen_backend",
"sunscreen_compiler_common",
"sunscreen_compiler_macros",
"sunscreen_fhe_program",
"sunscreen_runtime",
@@ -1224,8 +1410,11 @@ name = "sunscreen_compiler_common"
version = "0.1.0"
dependencies = [
"petgraph",
"proc-macro2",
"quote",
"semver",
"serde",
"syn",
]
[[package]]
@@ -1235,6 +1424,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde_json",
"sunscreen_compiler_common",
"syn",
]
@@ -1268,17 +1458,38 @@ dependencies = [
"sunscreen_fhe_program",
]
[[package]]
name = "sunscreen_zkp_backend"
version = "0.1.0"
dependencies = [
"bulletproofs",
"curve25519-dalek-ng",
"merlin",
]
[[package]]
name = "syn"
version = "1.0.100"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52205623b1b0f064a4e71182c3b18ae902267282930c6d5462c91b859668426e"
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f"
dependencies = [
"proc-macro2",
"quote",
"syn",
"unicode-xid",
]
[[package]]
name = "tempfile"
version = "3.3.0"
@@ -1308,6 +1519,26 @@ version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16"
[[package]]
name = "thiserror"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
@@ -1397,6 +1628,12 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "typenum"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicode-bidi"
version = "0.3.8"
@@ -1418,6 +1655,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-xid"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "url"
version = "2.3.1"
@@ -1435,6 +1678,12 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "want"
version = "0.3.0"
@@ -1620,3 +1869,24 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]
[[package]]
name = "zeroize"
version = "1.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c394b5bd0c6f669e7275d9c20aa90ae064cb22e75a1cad54e1b34088034b149f"
dependencies = [
"zeroize_derive",
]
[[package]]
name = "zeroize_derive"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f8f187641dad4f680d25c4bfc4225b418165984179f26ca76ec4fb6441d3a17"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]

View File

@@ -17,6 +17,8 @@ members = [
"sunscreen_compiler_macros",
"sunscreen_fhe_program",
"sunscreen_runtime",
"sunscreen_compiler_common",
"sunscreen_zkp_backend",
]
exclude = [
"mdBook",

View File

@@ -42,7 +42,7 @@ impl PartialEq for Modulus {
* Microsoft SEAL when constructing a SEALContext object. Normal users should not
* have to specify the security level explicitly anywhere.
*/
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[repr(i32)]
pub enum SecurityLevel {
/// 128-bit security level according to HomomorphicEncryption.org standard.

View File

@@ -1,3 +1,4 @@
use core::hash::Hash;
use std::ffi::{c_void, CString};
use std::ptr::null_mut;
@@ -7,7 +8,7 @@ use crate::{bindgen, serialization::CompressionType, Context, FromBytes, ToBytes
use serde::ser::Error;
use serde::{Serialize, Serializer};
#[derive(Debug)]
#[derive(Debug, Eq)]
/**
* Class to store a plaintext element. The data for the plaintext is
* a polynomial with coefficients modulo the plaintext modulus. The degree
@@ -66,6 +67,15 @@ impl PartialEq for Plaintext {
}
}
impl Hash for Plaintext {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
for i in 0..self.len() {
let c = self.get_coefficient(i);
state.write_u64(c);
}
}
}
impl Serialize for Plaintext {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where

View File

@@ -24,14 +24,16 @@ bumpalo = "3.8.0"
log = "0.4.14"
num = "0.4.0"
petgraph = "0.6.0"
sunscreen_compiler_common = { path = "../sunscreen_compiler_common" }
sunscreen_compiler_macros = { version = "0.7", path = "../sunscreen_compiler_macros" }
sunscreen_backend = { version = "0.7", path = "../sunscreen_backend" }
sunscreen_fhe_program = { version = "0.7", path = "../sunscreen_fhe_program" }
sunscreen_runtime = { version = "0.7", path = "../sunscreen_runtime" }
seal_fhe = { version = "0.7", path = "../seal_fhe" }
serde = { version = "1.0.130", features = ["derive"] }
serde = { version = "1.0.147", features = ["derive"] }
[dev-dependencies]
sunscreen_compiler_common = { path = "../sunscreen_compiler_common" }
serde_json = "1.0.72"
float-cmp = "0.9.0"

394
sunscreen/src/fhe/mod.rs Normal file
View File

@@ -0,0 +1,394 @@
use petgraph::stable_graph::NodeIndex;
use serde::{Deserialize, Serialize};
use sunscreen_backend::compile_inplace;
use sunscreen_compiler_common::{
Context, EdgeInfo, FrontendCompilation, Operation as OperationTrait,
};
use sunscreen_fhe_program::{
EdgeInfo as FheProgramEdgeInfo, FheProgram, Literal as FheProgramLiteral, NodeInfo,
Operation as FheProgramOperation, SchemeType,
};
use sunscreen_runtime::{InnerPlaintext, Params};
use std::cell::RefCell;
#[derive(Clone, Debug, Deserialize, Hash, Serialize, PartialEq, Eq)]
/**
* Represents a literal node's data.
*/
pub enum Literal {
/**
* An unsigned 64-bit integer.
*/
U64(u64),
/**
* An encoded plaintext value.
*/
Plaintext(InnerPlaintext),
}
#[derive(Clone, Debug, Hash, Deserialize, Serialize, PartialEq, Eq)]
/**
* Represents an operation occurring in the frontend AST.
*/
pub enum FheOperation {
/**
* This node indicates loading a cipher text from an input.
*/
InputCiphertext,
/**
* This node indicates loading a plaintext from an input.
*/
InputPlaintext,
/**
* Addition.
*/
Add,
/**
* Add a ciphertext and plaintext value.
*/
AddPlaintext,
/**
* Subtraction.
*/
Sub,
/**
* Subtract a plaintext.
*/
SubPlaintext,
/**
* Unary negation (i.e. given x, compute -x)
*/
Negate,
/**
* Multiplication.
*/
Multiply,
/**
* Multiply a ciphertext by a plaintext.
*/
MultiplyPlaintext,
/**
* A literal that serves as an operand to other operations.
*/
Literal(Literal),
/**
* Rotate left.
*/
RotateLeft,
/**
* Rotate right.
*/
RotateRight,
/**
* In the BFV scheme, swap rows in the Batched vectors.
*/
SwapRows,
/**
* This node indicates the previous node's result should be a result of the [`fhe_program`](crate::fhe_program).
*/
Output,
}
impl OperationTrait for FheOperation {
fn is_binary(&self) -> bool {
matches!(
self,
FheOperation::Add
| FheOperation::Multiply
| FheOperation::Sub
| FheOperation::RotateLeft
| FheOperation::RotateRight
| FheOperation::SubPlaintext
| FheOperation::AddPlaintext
| FheOperation::MultiplyPlaintext
)
}
fn is_commutative(&self) -> bool {
matches!(
self,
FheOperation::Add
| FheOperation::Multiply
| FheOperation::AddPlaintext
| FheOperation::MultiplyPlaintext
)
}
fn is_unary(&self) -> bool {
matches!(self, FheOperation::Negate | FheOperation::SwapRows)
}
}
/**
* The context for constructing the [`fhe_program`](crate::fhe_program) graph during compilation.
*
* This is an implementation detail of the
* [`fhe_program`](crate::fhe_program) macro, and you shouldn't need
* to construct one.
*/
pub type FheContext = Context<FheOperation, Params>;
/**
*
*/
pub type FheFrontendCompilation = FrontendCompilation<FheOperation>;
thread_local! {
/**
* Contains the graph of a ZKP program during compilation. An
* implementation detail and not for public consumption.
*/
pub static CURRENT_FHE_CTX: RefCell<Option<&'static mut FheContext>> = RefCell::new(None);
}
/**
* Runs the specified closure, injecting the current
* [`fhe_program`](crate::fhe_program) context.
*/
pub fn with_fhe_ctx<F, R>(f: F) -> R
where
F: FnOnce(&mut FheContext) -> R,
{
CURRENT_FHE_CTX.with(|ctx| {
let mut option = ctx.borrow_mut();
let ctx = option
.as_mut()
.expect("Called Ciphertext::new() outside of a context.");
f(ctx)
})
}
/**
* Defines transformations to FHE program graphs.
*/
pub trait FheContextOps {
/**
* Add an encrypted input to this context.
*/
fn add_ciphertext_input(&mut self) -> NodeIndex;
/**
* Add a plaintext input to this context.
*/
fn add_plaintext_input(&mut self) -> NodeIndex;
/**
* Adds a plaintext literal to the
* [`fhe_program`](crate::fhe_program) graph.
*/
fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex;
/**
* Add a subtraction to this context.
*/
fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Add a subtraction to this context.
*/
fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Adds a negation to this context.
*/
fn add_negate(&mut self, x: NodeIndex) -> NodeIndex;
/**
* Add an addition to this context.
*/
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Adds an addition to a plaintext.
*/
fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Add a multiplication to this context.
*/
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Add a multiplication to this context.
*/
fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Adds a literal to this context.
*/
fn add_literal(&mut self, literal: Literal) -> NodeIndex;
/**
* Add a rotate left.
*/
fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Add a rotate right.
*/
fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/**
* Adds a row swap.
*/
fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex;
/**
* Add a node that captures the previous node as an output.
*/
fn add_output(&mut self, i: NodeIndex) -> NodeIndex;
}
impl FheContextOps for FheContext {
fn add_ciphertext_input(&mut self) -> NodeIndex {
self.add_node(FheOperation::InputCiphertext)
}
fn add_plaintext_input(&mut self) -> NodeIndex {
self.add_node(FheOperation::InputPlaintext)
}
fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex {
self.add_node(FheOperation::Literal(Literal::Plaintext(plaintext)))
}
fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::Sub, left, right)
}
fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::SubPlaintext, left, right)
}
fn add_negate(&mut self, x: NodeIndex) -> NodeIndex {
self.add_unary_operation(FheOperation::Negate, x)
}
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::Add, left, right)
}
fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::AddPlaintext, left, right)
}
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::Multiply, left, right)
}
fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::MultiplyPlaintext, left, right)
}
fn add_literal(&mut self, literal: Literal) -> NodeIndex {
// See if we already have a node for the given literal. If so, just return it.
// If not, make a new one.
let existing_literal =
self.graph
.node_indices()
.find(|&i| match &self.graph[i].operation {
FheOperation::Literal(x) => *x == literal,
_ => false,
});
match existing_literal {
Some(x) => x,
None => self.add_node(FheOperation::Literal(literal)),
}
}
fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::RotateLeft, left, right)
}
fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(FheOperation::RotateRight, left, right)
}
fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
self.add_unary_operation(FheOperation::SwapRows, x)
}
fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
self.add_unary_operation(FheOperation::Output, i)
}
}
/**
* Extends FheFrontendCompilation to add a backend compilation method.
*/
pub trait FheCompile {
/**
* Performs frontend compilation of this intermediate representation into a backend [`FheProgram`],
* then perform backend compilation and return the result.
*/
fn compile(&self) -> FheProgram;
}
impl FheCompile for FheFrontendCompilation {
fn compile(&self) -> FheProgram {
let mut fhe_program = FheProgram::new(SchemeType::Bfv);
let mapped_graph = self.0.map(
|id, n| match &n.operation {
FheOperation::Add => NodeInfo::new(FheProgramOperation::Add),
FheOperation::InputCiphertext => {
// HACKHACK: Input nodes are always added first to the graph in the order
// they're specified as function arguments. We should not depend on this.
NodeInfo::new(FheProgramOperation::InputCiphertext(id.index()))
}
FheOperation::InputPlaintext => {
// HACKHACK: Input nodes are always added first to the graph in the order
// they're specified as function arguments. We should not depend on this.
NodeInfo::new(FheProgramOperation::InputPlaintext(id.index()))
}
FheOperation::Literal(Literal::U64(x)) => {
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::U64(*x)))
}
FheOperation::Literal(Literal::Plaintext(x)) => {
// It's okay to unwrap here because fhe_program compilation will
// catch the panic and return a compilation error.
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::Plaintext(
x.to_bytes().expect("Failed to serialize plaintext."),
)))
}
FheOperation::Sub => NodeInfo::new(FheProgramOperation::Sub),
FheOperation::SubPlaintext => NodeInfo::new(FheProgramOperation::SubPlaintext),
FheOperation::Negate => NodeInfo::new(FheProgramOperation::Negate),
FheOperation::Multiply => NodeInfo::new(FheProgramOperation::Multiply),
FheOperation::MultiplyPlaintext => {
NodeInfo::new(FheProgramOperation::MultiplyPlaintext)
}
FheOperation::Output => NodeInfo::new(FheProgramOperation::OutputCiphertext),
FheOperation::RotateLeft => NodeInfo::new(FheProgramOperation::ShiftLeft),
FheOperation::RotateRight => NodeInfo::new(FheProgramOperation::ShiftRight),
FheOperation::SwapRows => NodeInfo::new(FheProgramOperation::SwapRows),
FheOperation::AddPlaintext => NodeInfo::new(FheProgramOperation::AddPlaintext),
},
|_, e| match e {
EdgeInfo::Left => FheProgramEdgeInfo::LeftOperand,
EdgeInfo::Right => FheProgramEdgeInfo::RightOperand,
EdgeInfo::Unary => FheProgramEdgeInfo::UnaryOperand,
},
);
fhe_program.graph = mapped_graph;
compile_inplace(fhe_program)
}
}

View File

@@ -1,7 +1,8 @@
use crate::fhe::{FheCompile, FheFrontendCompilation};
use crate::params::{determine_params, PlainModulusConstraint};
use crate::{
Application, CallSignature, Error, FheProgramMetadata, FrontendCompilation, Params,
RequiredKeys, Result, SchemeType, SecurityLevel,
Application, CallSignature, Error, FheProgramMetadata, Params, RequiredKeys, Result,
SchemeType, SecurityLevel,
};
use std::collections::{HashMap, HashSet};
use sunscreen_runtime::CompiledFheProgram;
@@ -24,7 +25,7 @@ pub trait FheProgramFn {
/**
* Compile the `#[fhe_program]`.
*/
fn build(&self, params: &Params) -> Result<FrontendCompilation>;
fn build(&self, params: &Params) -> Result<FheFrontendCompilation>;
/**
* Get the scheme type.

View File

@@ -7,7 +7,7 @@
//! # Examples
//! This example is further annotated in `examples/simple_multiply`.
//! ```
//! # use sunscreen::{fhe_program, Compiler, types::{bfv::Signed, Cipher}, PlainModulusConstraint, Params, Runtime, Context};
//! # use sunscreen::{fhe_program, Compiler, types::{bfv::Signed, Cipher}, PlainModulusConstraint, Params, Runtime};
//!
//! #[fhe_program(scheme = "bfv")]
//! fn simple_multiply(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
@@ -39,8 +39,14 @@
//!
mod error;
/**
* This module contains types used internally when compiling
* [`fhe_program`]s.
*/
pub mod fhe;
mod fhe_compiler;
mod params;
mod zkp;
/**
* This module contains types used during [`fhe_program`] construction.
@@ -57,21 +63,13 @@ mod params;
*/
pub mod types;
use petgraph::{
algo::is_isomorphic_matching,
stable_graph::{NodeIndex, StableGraph},
Graph,
};
use fhe::{FheOperation, Literal};
use petgraph::stable_graph::StableGraph;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use sunscreen_backend::compile_inplace;
use sunscreen_fhe_program::{
EdgeInfo, FheProgram, Literal as FheProgramLiteral, NodeInfo, Operation as FheProgramOperation,
};
pub use error::{Error, Result};
pub use fhe_compiler::{Compiler, FheProgramFn};
pub use params::PlainModulusConstraint;
@@ -83,6 +81,8 @@ pub use sunscreen_runtime::{
FheProgramInputTrait, FheProgramMetadata, InnerCiphertext, InnerPlaintext, Params, Plaintext,
PrivateKey, PublicKey, RequiredKeys, Runtime, WithContext,
};
pub use zkp::ZkpProgramFn;
pub use zkp::{with_zkp_ctx, ZkpContext, ZkpFrontendCompilation, CURRENT_ZKP_CTX};
#[derive(Clone, Serialize, Deserialize)]
/**
@@ -141,98 +141,6 @@ impl Application {
}
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
/**
* Represents a literal node's data.
*/
pub enum Literal {
/**
* An unsigned 64-bit integer.
*/
U64(u64),
/**
* An encoded plaintext value.
*/
Plaintext(InnerPlaintext),
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
/**
* Represents an operation occurring in the frontend AST.
*/
pub enum Operation {
/**
* This node indicates loading a cipher text from an input.
*/
InputCiphertext,
/**
* This node indicates loading a plaintext from an input.
*/
InputPlaintext,
/**
* Addition.
*/
Add,
/**
* Add a ciphertext and plaintext value.
*/
AddPlaintext,
/**
* Subtraction.
*/
Sub,
/**
* Subtract a plaintext.
*/
SubPlaintext,
/**
* Unary negation (i.e. given x, compute -x)
*/
Negate,
/**
* Multiplication.
*/
Multiply,
/**
* Multiply a ciphertext by a plaintext.
*/
MultiplyPlaintext,
/**
* A literal that serves as an operand to other operations.
*/
Literal(Literal),
/**
* Rotate left.
*/
RotateLeft,
/**
* Rotate right.
*/
RotateRight,
/**
* In the BFV scheme, swap rows in the Batched vectors.
*/
SwapRows,
/**
* This node indicates the previous node's result should be a result of the [`fhe_program`].
*/
Output,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
/**
* Information about an edge in the frontend IR.
@@ -277,287 +185,13 @@ pub struct FrontendCompilation {
/**
* The dependency graph of the frontend's intermediate representation (IR) that backs an [`fhe_program`].
*/
pub graph: StableGraph<Operation, OperandInfo>,
}
#[derive(Clone, Debug)]
/**
* The context for constructing the [`fhe_program`] graph during compilation.
*
* This is an implementation detail of the [`fhe_program`] macro, and you shouldn't need
* to construct one.
*/
pub struct Context {
/**
* The frontend compilation result.
*/
pub compilation: FrontendCompilation,
/**
* The set of parameters for which we're currently constructing the graph.
*/
pub params: Params,
/**
* Stores indicies for graph nodes in a bump allocator. [`FheProgramNode`](crate::types::intern::FheProgramNode)
* can request allocations of these. This allows it to use slices instead of Vecs, which allows
* FheProgramNode to impl Copy.
*/
pub indicies_store: Vec<NodeIndex>,
}
impl PartialEq for FrontendCompilation {
fn eq(&self, b: &Self) -> bool {
is_isomorphic_matching(
&Graph::from(self.graph.clone()),
&Graph::from(b.graph.clone()),
|n1, n2| n1 == n2,
|e1, e2| e1 == e2,
)
}
pub graph: StableGraph<FheOperation, OperandInfo>,
}
thread_local! {
/**
* While constructing an [`fhe_program`], this refers to the current intermediate
* representation. An implementation detail of the [`fhe_program`] macro.
*/
pub static CURRENT_CTX: RefCell<Option<&'static mut Context>> = RefCell::new(None);
/**
* An arena containing slices of indicies. An implementation detail of the
* [`fhe_program`] macro.
*/
pub static INDEX_ARENA: RefCell<bumpalo::Bump> = RefCell::new(bumpalo::Bump::new());
}
/**
* Runs the specified closure, injecting the current [`fhe_program`] context.
*/
pub fn with_ctx<F, R>(f: F) -> R
where
F: FnOnce(&mut Context) -> R,
{
CURRENT_CTX.with(|ctx| {
let mut option = ctx.borrow_mut();
let ctx = option
.as_mut()
.expect("Called Ciphertext::new() outside of a context.");
f(ctx)
})
}
impl Context {
/**
* Creates a new empty frontend intermediate representation context with the given scheme.
*/
pub fn new(params: &Params) -> Self {
Self {
compilation: FrontendCompilation {
graph: StableGraph::new(),
},
params: params.clone(),
indicies_store: vec![],
}
}
fn add_2_input(&mut self, op: Operation, left: NodeIndex, right: NodeIndex) -> NodeIndex {
let new_id = self.compilation.graph.add_node(op);
self.compilation
.graph
.add_edge(left, new_id, OperandInfo::Left);
self.compilation
.graph
.add_edge(right, new_id, OperandInfo::Right);
new_id
}
fn add_1_input(&mut self, op: Operation, i: NodeIndex) -> NodeIndex {
let new_id = self.compilation.graph.add_node(op);
self.compilation
.graph
.add_edge(i, new_id, OperandInfo::Unary);
new_id
}
/**
* Add an input to this context.
*/
pub fn add_ciphertext_input(&mut self) -> NodeIndex {
self.compilation.graph.add_node(Operation::InputCiphertext)
}
/**
* Add an input to this context.
*/
pub fn add_plaintext_input(&mut self) -> NodeIndex {
self.compilation.graph.add_node(Operation::InputPlaintext)
}
/**
* Adds a plaintext literal to the [`fhe_program`] graph.
*/
pub fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex {
self.compilation
.graph
.add_node(Operation::Literal(Literal::Plaintext(plaintext)))
}
/**
* Add a subtraction to this context.
*/
pub fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::Sub, left, right)
}
/**
* Add a subtraction to this context.
*/
pub fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::SubPlaintext, left, right)
}
/**
* Adds a negation to this context.
*/
pub fn add_negate(&mut self, x: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::Negate, x)
}
/**
* Add an addition to this context.
*/
pub fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::Add, left, right)
}
/**
* Adds an addition to a plaintext.
*/
pub fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::AddPlaintext, left, right)
}
/**
* Add a multiplication to this context.
*/
pub fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::Multiply, left, right)
}
/**
* Add a multiplication to this context.
*/
pub fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::MultiplyPlaintext, left, right)
}
/**
* Adds a literal to this context.
*/
pub fn add_literal(&mut self, literal: Literal) -> NodeIndex {
// See if we already have a node for the given literal. If so, just return it.
// If not, make a new one.
let existing_literal =
self.compilation
.graph
.node_indices()
.find(|&i| match &self.compilation.graph[i] {
Operation::Literal(x) => *x == literal,
_ => false,
});
match existing_literal {
Some(x) => x,
None => self.compilation.graph.add_node(Operation::Literal(literal)),
}
}
/**
* Add a rotate left.
*/
pub fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::RotateLeft, left, right)
}
/**
* Add a rotate right.
*/
pub fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_2_input(Operation::RotateRight, left, right)
}
/**
* Adds a row swap.
*/
pub fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::SwapRows, x)
}
/**
* Add a node that captures the previous node as an output.
*/
pub fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
self.add_1_input(Operation::Output, i)
}
}
impl FrontendCompilation {
/**
* Performs frontend compilation of this intermediate representation into a backend [`FheProgram`],
* then perform backend compilation and return the result.
*/
pub fn compile(&self) -> FheProgram {
let mut fhe_program = FheProgram::new(SchemeType::Bfv);
let mapped_graph = self.graph.map(
|id, n| match n {
Operation::Add => NodeInfo::new(FheProgramOperation::Add),
Operation::InputCiphertext => {
// HACKHACK: Input nodes are always added first to the graph in the order
// they're specified as function arguments. We should not depend on this.
NodeInfo::new(FheProgramOperation::InputCiphertext(id.index()))
}
Operation::InputPlaintext => {
// HACKHACK: Input nodes are always added first to the graph in the order
// they're specified as function arguments. We should not depend on this.
NodeInfo::new(FheProgramOperation::InputPlaintext(id.index()))
}
Operation::Literal(Literal::U64(x)) => {
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::U64(*x)))
}
Operation::Literal(Literal::Plaintext(x)) => {
// It's okay to unwrap here because fhe_program compilation will
// catch the panic and return a compilation error.
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::Plaintext(
x.to_bytes().expect("Failed to serialize plaintext."),
)))
}
Operation::Sub => NodeInfo::new(FheProgramOperation::Sub),
Operation::SubPlaintext => NodeInfo::new(FheProgramOperation::SubPlaintext),
Operation::Negate => NodeInfo::new(FheProgramOperation::Negate),
Operation::Multiply => NodeInfo::new(FheProgramOperation::Multiply),
Operation::MultiplyPlaintext => {
NodeInfo::new(FheProgramOperation::MultiplyPlaintext)
}
Operation::Output => NodeInfo::new(FheProgramOperation::OutputCiphertext),
Operation::RotateLeft => NodeInfo::new(FheProgramOperation::ShiftLeft),
Operation::RotateRight => NodeInfo::new(FheProgramOperation::ShiftRight),
Operation::SwapRows => NodeInfo::new(FheProgramOperation::SwapRows),
Operation::AddPlaintext => NodeInfo::new(FheProgramOperation::AddPlaintext),
},
|_, e| match e {
OperandInfo::Left => EdgeInfo::LeftOperand,
OperandInfo::Right => EdgeInfo::RightOperand,
OperandInfo::Unary => EdgeInfo::UnaryOperand,
},
);
fhe_program.graph = mapped_graph;
compile_inplace(fhe_program)
}
}

View File

@@ -1,4 +1,4 @@
use crate::{Error, FheProgramFn, Result, SecurityLevel};
use crate::{fhe::FheCompile, Error, FheProgramFn, Result, SecurityLevel};
use log::{debug, trace};

View File

@@ -1,11 +1,12 @@
use crate::{
fhe::{with_fhe_ctx, FheContextOps, Literal},
types::{
intern::{Cipher, FheProgramNode},
ops::*,
BfvType, FheType, LaneCount, NumCiphertexts, SwapRows, TryFromPlaintext, TryIntoPlaintext,
Type, TypeName, TypeNameInstance, Version,
},
with_ctx, FheProgramInputTrait, InnerPlaintext, Literal, Params, Plaintext, WithContext,
FheProgramInputTrait, InnerPlaintext, Params, Plaintext, WithContext,
};
use seal_fhe::{
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
@@ -509,7 +510,7 @@ impl<const LANES: usize> GraphCipherAdd for Batched<LANES> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -525,7 +526,7 @@ impl<const LANES: usize> GraphCipherSub for Batched<LANES> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -541,7 +542,7 @@ impl<const LANES: usize> GraphCipherMul for Batched<LANES> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -557,8 +558,8 @@ impl<const LANES: usize> GraphCipherConstMul for Batched<LANES> {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let l = ctx.add_plaintext_literal(b.inner);
let n = ctx.add_multiplication_plaintext(a.ids[0], l);
@@ -569,7 +570,7 @@ impl<const LANES: usize> GraphCipherConstMul for Batched<LANES> {
impl<const LANES: usize> GraphCipherSwapRows for Batched<LANES> {
fn graph_cipher_swap_rows(x: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_swap_rows(x.ids[0]);
FheProgramNode::new(&[n])
@@ -582,7 +583,7 @@ impl<const LANES: usize> GraphCipherRotateLeft for Batched<LANES> {
x: FheProgramNode<Cipher<Self>>,
y: u64,
) -> FheProgramNode<Cipher<Self>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let y = ctx.add_literal(Literal::U64(y));
let n = ctx.add_rotate_left(x.ids[0], y);
@@ -596,7 +597,7 @@ impl<const LANES: usize> GraphCipherRotateRight for Batched<LANES> {
x: FheProgramNode<Cipher<Self>>,
y: u64,
) -> FheProgramNode<Cipher<Self>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let y = ctx.add_literal(Literal::U64(y));
let n = ctx.add_rotate_right(x.ids[0], y);
@@ -609,7 +610,7 @@ impl<const LANES: usize> GraphCipherNeg for Batched<LANES> {
type Val = Self;
fn graph_cipher_neg(x: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_negate(x.ids[0]);
FheProgramNode::new(&[n])

View File

@@ -1,17 +1,20 @@
use seal_fhe::Plaintext as SealPlaintext;
use crate::types::{
ops::{
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
GraphCipherConstSub, GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd,
GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub,
GraphPlainCipherSub,
use crate::{
fhe::{with_fhe_ctx, FheContextOps},
types::{
ops::{
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
GraphCipherConstSub, GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd,
GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub,
GraphPlainCipherSub,
},
Cipher,
},
Cipher,
};
use crate::{
types::{intern::FheProgramNode, BfvType, FheType, Type, Version},
with_ctx, FheProgramInputTrait, Params, WithContext,
FheProgramInputTrait, Params, WithContext,
};
use sunscreen_runtime::{
@@ -209,7 +212,7 @@ impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -225,7 +228,7 @@ impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -241,8 +244,8 @@ impl<const INT_BITS: usize> GraphCipherConstAdd for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let n = ctx.add_addition_plaintext(a.ids[0], lit);
@@ -260,7 +263,7 @@ impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -276,7 +279,7 @@ impl<const INT_BITS: usize> GraphCipherPlainSub for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction_plaintext(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -292,7 +295,7 @@ impl<const INT_BITS: usize> GraphPlainCipherSub for Fractional<INT_BITS> {
a: FheProgramNode<Self::Left>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
let n = ctx.add_negate(n);
@@ -309,8 +312,8 @@ impl<const INT_BITS: usize> GraphCipherConstSub for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let n = ctx.add_subtraction_plaintext(a.ids[0], lit);
@@ -328,8 +331,8 @@ impl<const INT_BITS: usize> GraphConstCipherSub for Fractional<INT_BITS> {
a: Self::Left,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Right>> {
with_ctx(|ctx| {
let a = Self::from(a).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let a = Self::from(a).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(a.inner);
let n = ctx.add_subtraction_plaintext(b.ids[0], lit);
@@ -348,7 +351,7 @@ impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -364,7 +367,7 @@ impl<const INT_BITS: usize> GraphCipherPlainMul for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -380,8 +383,8 @@ impl<const INT_BITS: usize> GraphCipherConstMul for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let n = ctx.add_multiplication_plaintext(a.ids[0], lit);
@@ -399,10 +402,10 @@ impl<const INT_BITS: usize> GraphCipherConstDiv for Fractional<INT_BITS> {
a: FheProgramNode<Cipher<Self::Left>>,
b: f64,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let b = Self::try_from(1. / b)
.unwrap()
.try_into_plaintext(&ctx.params)
.try_into_plaintext(&ctx.data)
.unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
@@ -418,7 +421,7 @@ impl<const INT_BITS: usize> GraphCipherNeg for Fractional<INT_BITS> {
type Val = Fractional<INT_BITS>;
fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_negate(a.ids[0]);
FheProgramNode::new(&[n])

View File

@@ -1,9 +1,10 @@
use crate::fhe::{with_fhe_ctx, FheContextOps};
use crate::types::{
bfv::Signed, intern::FheProgramNode, ops::*, BfvType, Cipher, FheType, GraphCipherAdd,
GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, TryFromPlaintext,
TryIntoPlaintext, TypeName,
};
use crate::{with_ctx, FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
use crate::{FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
use std::cmp::Eq;
use std::ops::*;
use sunscreen_runtime::Error;
@@ -274,7 +275,7 @@ impl GraphCipherAdd for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
@@ -297,7 +298,7 @@ impl GraphCipherPlainAdd for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
@@ -320,14 +321,14 @@ impl GraphCipherConstAdd for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let b_num =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
let b_den =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b_den);
@@ -351,7 +352,7 @@ impl GraphCipherSub for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
@@ -374,7 +375,7 @@ impl GraphCipherPlainSub for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
@@ -397,7 +398,7 @@ impl GraphPlainCipherSub for Rational {
a: FheProgramNode<Self::Left>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
let num_b_2 = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
@@ -420,13 +421,13 @@ impl GraphCipherConstSub for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let b_num =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
let b_den =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
// Scale each numinator by the other's denominator.
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b_den);
@@ -450,13 +451,13 @@ impl GraphConstCipherSub for Rational {
a: Self::Left,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Right>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let a = Self::try_from(a).unwrap();
let a_num =
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.data).unwrap().inner);
let a_den =
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.data).unwrap().inner);
// Scale each numinator by the other's denominator.
let num_b_2 = ctx.add_multiplication_plaintext(b.ids[0], a_den);
@@ -480,7 +481,7 @@ impl GraphCipherMul for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
let mul_den = ctx.add_multiplication(a.ids[1], b.ids[1]);
@@ -500,7 +501,7 @@ impl GraphCipherPlainMul for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
@@ -520,13 +521,13 @@ impl GraphCipherConstMul for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let num_b =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
let den_b =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], num_b);
@@ -547,7 +548,7 @@ impl GraphCipherDiv for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[1]);
let mul_den = ctx.add_multiplication(a.ids[1], b.ids[0]);
@@ -567,7 +568,7 @@ impl GraphCipherPlainDiv for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
@@ -587,7 +588,7 @@ impl GraphPlainCipherDiv for Rational {
a: FheProgramNode<Self::Left>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
let mul_den = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
@@ -607,13 +608,13 @@ impl GraphCipherConstDiv for Rational {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let b = Self::try_from(b).unwrap();
let num_b =
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
let den_b =
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], den_b);
@@ -634,13 +635,13 @@ impl GraphConstCipherDiv for Rational {
a: Self::Left,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Right>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let a = Self::try_from(a).unwrap();
let num_a =
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.data).unwrap().inner);
let den_a =
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.data).unwrap().inner);
// Scale each numinator by the other's denominator.
let mul_num = ctx.add_multiplication_plaintext(b.ids[1], num_a);
@@ -657,7 +658,7 @@ impl GraphCipherNeg for Rational {
type Val = Self;
fn graph_cipher_neg(a: FheProgramNode<Cipher<Self::Val>>) -> FheProgramNode<Cipher<Self::Val>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let neg = ctx.add_negate(a.ids[0]);
let ids = [neg, a.ids[1]];

View File

@@ -1,16 +1,19 @@
use seal_fhe::Plaintext as SealPlaintext;
use crate::types::{
ops::{
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstMul, GraphCipherConstSub,
GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd, GraphCipherPlainMul,
GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub, GraphPlainCipherSub,
use crate::{
fhe::{with_fhe_ctx, FheContextOps},
types::{
ops::{
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstMul, GraphCipherConstSub,
GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd, GraphCipherPlainMul,
GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub, GraphPlainCipherSub,
},
Cipher,
},
Cipher,
};
use crate::{
types::{intern::FheProgramNode, BfvType, FheType, TypeNameInstance},
with_ctx, FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext,
FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext,
};
use sunscreen_runtime::{
@@ -249,7 +252,7 @@ impl GraphCipherAdd for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_addition(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -265,7 +268,7 @@ impl GraphCipherPlainAdd for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -281,8 +284,8 @@ impl GraphCipherConstAdd for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: i64,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let add = ctx.add_addition_plaintext(a.ids[0], lit);
@@ -300,7 +303,7 @@ impl GraphCipherSub for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -316,7 +319,7 @@ impl GraphCipherPlainSub for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction_plaintext(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -332,7 +335,7 @@ impl GraphPlainCipherSub for Signed {
a: FheProgramNode<Self::Left>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
let n = ctx.add_negate(n);
@@ -349,8 +352,8 @@ impl GraphCipherConstSub for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let n = ctx.add_subtraction_plaintext(a.ids[0], lit);
@@ -368,8 +371,8 @@ impl GraphConstCipherSub for Signed {
a: i64,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Right>> {
with_ctx(|ctx| {
let a = Self::from(a).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let a = Self::from(a).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(a.inner);
let n = ctx.add_subtraction_plaintext(b.ids[0], lit);
@@ -384,7 +387,7 @@ impl GraphCipherNeg for Signed {
type Val = Signed;
fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_negate(a.ids[0]);
FheProgramNode::new(&[n])
@@ -400,7 +403,7 @@ impl GraphCipherMul for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])
@@ -416,8 +419,8 @@ impl GraphCipherConstMul for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: i64,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
with_fhe_ctx(|ctx| {
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
let lit = ctx.add_plaintext_literal(b.inner);
let add = ctx.add_multiplication_plaintext(a.ids[0], lit);
@@ -435,7 +438,7 @@ impl GraphCipherPlainMul for Signed {
a: FheProgramNode<Cipher<Self::Left>>,
b: FheProgramNode<Self::Right>,
) -> FheProgramNode<Cipher<Self::Left>> {
with_ctx(|ctx| {
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
FheProgramNode::new(&[n])

View File

@@ -1,9 +1,10 @@
use crate::{
fhe::with_fhe_ctx,
types::{
intern::FheLiteral, ops::*, Cipher, FheType, LaneCount, NumCiphertexts, SwapRows, Type,
TypeName,
},
with_ctx, INDEX_ARENA,
INDEX_ARENA,
};
use petgraph::stable_graph::NodeIndex;
@@ -36,8 +37,8 @@ use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub};
* construction.
*
* # Undefined behavior
* These types must be constructed while a [`crate::CURRENT_CTX`] refers to a valid
* [`crate::Context`]. Furthermore, no [`FheProgramNode`] should outlive the said context.
* These types must be constructed while [`CURRENT_FHE_CTX`][crate::fhe::CURRENT_FHE_CTX] refers to a valid
* [`FheContext`](crate::fhe::FheContext). Furthermore, no [`FheProgramNode`] should outlive the said context.
* Violating any of these conditions may result in memory corruption or
* use-after-free.
*/
@@ -93,7 +94,7 @@ impl<T: NumCiphertexts> FheProgramNode<T> {
* Returns the plain modulus parameter for the given BFV scheme
*/
pub fn get_plain_modulus() -> u64 {
with_ctx(|ctx| ctx.params.plain_modulus)
with_fhe_ctx(|ctx| ctx.data.plain_modulus)
}
}

View File

@@ -1,7 +1,5 @@
pub use crate::{
types::{intern::FheProgramNode, Cipher, FheType, NumCiphertexts, TypeName},
with_ctx,
};
use crate::fhe::{with_fhe_ctx, FheContextOps};
pub use crate::types::{intern::FheProgramNode, Cipher, FheType, NumCiphertexts, TypeName};
/**
* Create an input node from an Fhe Program input argument.
@@ -18,7 +16,8 @@ pub trait Input {
* You should not call this, but rather allow the [`fhe_program`](crate::fhe_program) macro to do this on your behalf.
*
* # Undefined behavior
* This type references memory in a backing [`crate::Context`] and without carefully ensuring FheProgramNodes
* This type references memory in a backing
* [`FheContext`](crate::fhe::FheContext) and without carefully ensuring FheProgramNodes
* never outlive the backing context, use-after-free can occur.
*
*/
@@ -36,9 +35,9 @@ where
for _ in 0..T::NUM_CIPHERTEXTS {
if T::type_name().is_encrypted {
ids.push(with_ctx(|ctx| ctx.add_ciphertext_input()));
ids.push(with_fhe_ctx(|ctx| ctx.add_ciphertext_input()));
} else {
ids.push(with_ctx(|ctx| ctx.add_plaintext_input()));
ids.push(with_fhe_ctx(|ctx| ctx.add_plaintext_input()));
}
}
@@ -69,16 +68,17 @@ where
#[test]
fn can_create_inputs() {
use crate::{
fhe::{FheContext, FheOperation, CURRENT_FHE_CTX},
types::{bfv::Rational, intern::FheProgramNode},
Context, Operation, Params, SchemeType, SecurityLevel, CURRENT_CTX,
Params, SchemeType, SecurityLevel,
};
use std::cell::RefCell;
use std::mem::transmute;
use petgraph::stable_graph::NodeIndex;
CURRENT_CTX.with(|ctx| {
let mut context = Context::new(&Params {
CURRENT_FHE_CTX.with(|ctx| {
let mut context = FheContext::new(Params {
lattice_dimension: 0,
coeff_modulus: vec![],
plain_modulus: 0,
@@ -122,26 +122,26 @@ fn can_create_inputs() {
offset += 2 * 6 * 6;
assert_eq!(context.compilation.graph.node_count(), offset);
assert_eq!(context.graph.node_count(), offset);
for i in 0..2 {
assert_eq!(
context.compilation.graph[NodeIndex::from(i)],
Operation::InputPlaintext
context.graph[NodeIndex::from(i)].operation,
FheOperation::InputPlaintext
);
}
for i in 2..14 {
assert_eq!(
context.compilation.graph[NodeIndex::from(i)],
Operation::InputPlaintext
context.graph[NodeIndex::from(i)].operation,
FheOperation::InputPlaintext
);
}
for i in 14..context.compilation.graph.node_count() {
for i in 14..context.graph.node_count() {
assert_eq!(
context.compilation.graph[NodeIndex::from(i as u32)],
Operation::InputCiphertext
context.graph[NodeIndex::from(i as u32)].operation,
FheOperation::InputCiphertext
);
}
});

View File

@@ -1,6 +1,6 @@
use crate::{
fhe::{with_fhe_ctx, FheContextOps},
types::{intern::FheProgramNode, NumCiphertexts},
with_ctx,
};
/**
@@ -18,8 +18,10 @@ pub trait Output {
* You should not call this, but rather allow the [`fhe_program`](crate::fhe_program) macro to do this on your behalf.
*
* # Undefined behavior
* This type references memory in a backing [`crate::Context`] and without carefully ensuring FheProgramNodes
* never outlive the backing context, use-after-free can occur.
* This type references memory in a backing
* [`FheContext`](crate::fhe::FheContext) and without carefully
* ensuring FheProgramNodes never outlive the backing context,
* use-after-free can occur.
*/
fn output(&self) -> Self::Output;
}
@@ -34,7 +36,7 @@ where
let mut ids = Vec::with_capacity(self.ids.len());
for i in 0..self.ids.len() {
ids.push(with_ctx(|ctx| ctx.add_output(self.ids[i])));
ids.push(with_fhe_ctx(|ctx| ctx.add_output(self.ids[i])));
}
FheProgramNode::new(&ids)

View File

@@ -1,4 +1,7 @@
use crate::{with_ctx, Literal};
use crate::{
fhe::{with_fhe_ctx, FheContextOps},
Literal,
};
use petgraph::stable_graph::NodeIndex;
#[derive(Clone, Copy)]
@@ -13,6 +16,6 @@ impl U64LiteralRef {
* graph, a reference to the existing literal is returned.
*/
pub fn node(val: u64) -> NodeIndex {
with_ctx(|ctx| ctx.add_literal(Literal::U64(val)))
with_fhe_ctx(|ctx| ctx.add_literal(Literal::U64(val)))
}
}

View File

@@ -69,6 +69,11 @@ pub mod intern;
*/
mod ops;
/**
* Contains types used in creating zero-knowledge proof R1CS circuits.
*/
pub mod zkp;
use crate::types::ops::*;
pub use sunscreen_runtime::{

View File

@@ -0,0 +1,99 @@
mod native_field;
mod program_node;
pub use native_field::*;
pub use program_node::*;
/**
* A trait for adding two ZKP values together
*/
pub trait AddVar
where
Self: Sized + ZkpType,
{
/**
* Add the 2 values.
*/
fn add(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* A trait for multiplying two ZKP values together
*/
pub trait MulVar
where
Self: Sized + ZkpType,
{
/**
* Compute lhs * rhs.
*/
fn mul(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* A trait for dividing 2 zkp values.
*/
pub trait DivVar
where
Self: Sized + ZkpType,
{
/**
* Compute lhs / rhs.
*/
fn div(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* A trait for computing the Remainder of 2 zkp values.
*/
pub trait RemVar
where
Self: Sized + ZkpType,
{
/**
* Compute lhs % rhs;
*/
fn rem(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* A trait for subtracting 2 zkp values.
*/
pub trait SubVar
where
Self: Sized + ZkpType,
{
/**
* Compute lhs - rhs.
*/
fn sub(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* A trait for computing the additive inverse of a zkp value.
*/
pub trait NegVar
where
Self: Sized + ZkpType,
{
/**
* Compute -lhs.
*/
fn neg(lhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* The number of native field elements needed to represent a ZKP type.
*/
pub trait NumFieldElements {
/**
* The number of native field elements needed to represent this type.
*/
const NUM_NATIVE_FIELD_ELEMENTS: usize;
}
/**
* Encapsulates all the traits required for a type to be used in ZKP
* programs.
*/
pub trait ZkpType: NumFieldElements {}

View File

@@ -0,0 +1,54 @@
use sunscreen_compiler_macros::TypeName;
use crate::{
types::zkp::{AddVar, ProgramNode},
zkp::{with_zkp_ctx, ZkpContextOps},
};
use super::{MulVar, NegVar, NumFieldElements, ZkpType};
// Shouldn't need Clone + Copy, but there appears to be a bug in the Rust
// compiler that prevents ProgramNode from being Copy if we don't.
// https://github.com/rust-lang/rust/issues/104264
#[derive(Clone, Copy, TypeName)]
/**
* The native field type in the underlying backend proof system. For
* example, in Bulletproofs, this is [`Scalar`](https://docs.rs/curve25519-dalek-ng/4.1.1/curve25519_dalek_ng/scalar/struct.Scalar.html).
*/
pub struct NativeField {}
impl NumFieldElements for NativeField {
const NUM_NATIVE_FIELD_ELEMENTS: usize = 1;
}
impl ZkpType for NativeField {}
impl AddVar for NativeField {
fn add(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
with_zkp_ctx(|ctx| {
let o = ctx.add_addition(lhs.ids[0], rhs.ids[0]);
ProgramNode::new(&[o])
})
}
}
impl MulVar for NativeField {
fn mul(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
with_zkp_ctx(|ctx| {
let o = ctx.add_multiplication(lhs.ids[0], rhs.ids[0]);
ProgramNode::new(&[o])
})
}
}
impl NegVar for NativeField {
fn neg(lhs: ProgramNode<Self>) -> ProgramNode<Self> {
with_zkp_ctx(|ctx| {
let o = ctx.add_negate(lhs.ids[0]);
ProgramNode::new(&[o])
})
}
}

View File

@@ -0,0 +1,141 @@
use petgraph::stable_graph::NodeIndex;
use std::{
marker::PhantomData,
ops::{Add, Div, Mul, Neg, Rem, Sub},
};
use crate::{
types::zkp::{AddVar, DivVar, MulVar, NegVar, RemVar, SubVar, ZkpType},
zkp::{with_zkp_ctx, ZkpContextOps},
INDEX_ARENA,
};
#[derive(Clone, Copy)]
/**
* An implementation detail of the ZKP compiler. Each expression in a ZKP
* program is expressed in terms of `ProgramNode`, which proxy and compose
* the parse graph for a ZKP program.
*
* They proxy operations (+, -, /, etc) to their underlying type T to
* manipulate the program graph as appropriate.
*
* # Remarks
* For internal use only.
*/
pub struct ProgramNode<T>
where
T: ZkpType,
{
/**
* The indices in the graph that compose the type backing this
* `ProgramNode`.
*/
pub ids: &'static [NodeIndex],
_phantom: PhantomData<T>,
}
impl<T> ProgramNode<T>
where
T: ZkpType,
{
/**
* Create a new Program node from the given indicies in the
*/
pub fn new(ids: &[NodeIndex]) -> Self {
INDEX_ARENA.with(|allocator| {
let allocator = allocator.borrow();
let ids_dest = allocator.alloc_slice_copy(ids);
ids_dest.copy_from_slice(ids);
// The memory in the bump allocator is valid until we call reset, which
// we do after creating the FHE program. At this time, no FheProgramNodes should
// remain.
// We invoke the dark transmutation ritual to turn a finite lifetime into a 'static.
Self {
ids: unsafe { std::mem::transmute(ids_dest) },
_phantom: std::marker::PhantomData,
}
})
}
/**
* Creates a public program input of type T.
*/
pub fn input() -> Self {
let mut ids = Vec::with_capacity(T::NUM_NATIVE_FIELD_ELEMENTS);
for _ in 0..T::NUM_NATIVE_FIELD_ELEMENTS {
ids.push(with_zkp_ctx(|ctx| ctx.add_public_input()));
}
Self::new(&ids)
}
}
impl<T> Add<ProgramNode<T>> for ProgramNode<T>
where
T: AddVar + ZkpType,
{
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
<T as AddVar>::add(self, rhs)
}
}
impl<T> Mul<ProgramNode<T>> for ProgramNode<T>
where
T: MulVar + ZkpType,
{
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
<T as MulVar>::mul(self, rhs)
}
}
impl<T> Div<ProgramNode<T>> for ProgramNode<T>
where
T: DivVar + ZkpType,
{
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
<T as DivVar>::div(self, rhs)
}
}
impl<T> Rem<ProgramNode<T>> for ProgramNode<T>
where
T: RemVar + ZkpType,
{
type Output = Self;
fn rem(self, rhs: Self) -> Self::Output {
<T as RemVar>::rem(self, rhs)
}
}
impl<T> Sub<ProgramNode<T>> for ProgramNode<T>
where
T: SubVar + ZkpType,
{
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
<T as SubVar>::sub(self, rhs)
}
}
impl<T> Neg for ProgramNode<T>
where
T: NegVar + ZkpType,
{
type Output = Self;
fn neg(self) -> Self::Output {
<T as NegVar>::neg(self)
}
}

167
sunscreen/src/zkp/mod.rs Normal file
View File

@@ -0,0 +1,167 @@
use sunscreen_runtime::CallSignature;
use crate::Result;
use std::cell::RefCell;
/**
* An internal representation of a ZKP program specification.
*/
pub trait ZkpProgramFn {
/**
* Create a circuit from this specification.
*/
fn build(&self) -> Result<ZkpFrontendCompilation>;
/**
* Gets the call signature for this program.
*/
fn signature(&self) -> CallSignature;
/**
* Gets the name of this program.
*/
fn name(&self) -> &str;
}
use std::fmt::Debug;
use petgraph::stable_graph::NodeIndex;
use sunscreen_compiler_common::{
Context, FrontendCompilation, Operation as OperationTrait, Render,
};
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Operation {
PrivateInput(NodeIndex),
PublicInput(NodeIndex),
HiddenInput(NodeIndex),
Add,
Sub,
Mul,
Neg,
}
impl OperationTrait for Operation {
fn is_binary(&self) -> bool {
matches!(self, Operation::Add | Operation::Mul | Operation::Sub)
}
fn is_commutative(&self) -> bool {
matches!(self, Operation::Add | Operation::Mul)
}
fn is_unary(&self) -> bool {
matches!(self, Operation::Neg)
}
}
impl Operation {
pub fn is_add(&self) -> bool {
matches!(self, Operation::Add)
}
pub fn is_sub(&self) -> bool {
matches!(self, Operation::Sub)
}
pub fn is_mul(&self) -> bool {
matches!(self, Operation::Mul)
}
pub fn is_neg(&self) -> bool {
matches!(self, Operation::Neg)
}
pub fn is_private_input(&self) -> bool {
matches!(self, Operation::PrivateInput(_))
}
pub fn is_public_input(&self) -> bool {
matches!(self, Operation::PublicInput(_))
}
pub fn is_hidden_input(&self) -> bool {
matches!(self, Operation::HiddenInput(_))
}
}
/**
* An implementation detail of a ZKP program. During compilation, it holds
* the graph of the program currently being constructed in an
* [`#[zkp_program]`](crate::zkp_program) function.
*
* # Remarks
* For internal use only.
*/
pub type ZkpContext = Context<Operation, u32>;
/**
* Contains the results of compiling a [`#[zkp_program]`](crate::zkp_program) function.
*
* # Remarks
* For internal use only.
*/
pub type ZkpFrontendCompilation = FrontendCompilation<Operation>;
pub trait ZkpContextOps {
fn add_public_input(&mut self) -> NodeIndex;
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
fn add_negate(&mut self, left: NodeIndex) -> NodeIndex;
}
impl ZkpContextOps for ZkpContext {
fn add_public_input(&mut self) -> NodeIndex {
let node = self.add_node(Operation::PublicInput(NodeIndex::from(self.data)));
self.data += 1;
node
}
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(Operation::Add, left, right)
}
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
self.add_binary_operation(Operation::Mul, left, right)
}
fn add_negate(&mut self, left: NodeIndex) -> NodeIndex {
self.add_unary_operation(Operation::Neg, left)
}
}
impl Render for Operation {
fn render(&self) -> String {
format!("{:?}", self)
}
}
thread_local! {
/**
* Contains the graph of a ZKP program during compilation. An
* implementation detail and not for public consumption.
*/
pub static CURRENT_ZKP_CTX: RefCell<Option<&'static mut ZkpContext>> = RefCell::new(None);
}
/**
* Runs the specified closure, injecting the current
* [`fhe_program`](crate::fhe_program) context.
*/
pub fn with_zkp_ctx<F, R>(f: F) -> R
where
F: FnOnce(&mut ZkpContext) -> R,
{
CURRENT_ZKP_CTX.with(|ctx| {
let mut option = ctx.borrow_mut();
let ctx = option
.as_mut()
.expect("Called with_zkp_ctx() outside of a context.");
f(ctx)
})
}

View File

@@ -1,8 +1,8 @@
use sunscreen::{
fhe::{FheFrontendCompilation, CURRENT_FHE_CTX},
fhe_program,
types::{bfv::Signed, Cipher, TypeName},
CallSignature, FheProgramFn, FrontendCompilation, Params, SchemeType, SecurityLevel,
CURRENT_CTX,
CallSignature, FheProgramFn, Params, SchemeType, SecurityLevel,
};
use serde_json::json;
@@ -46,7 +46,7 @@ fn fhe_program_gets_called() {
fn panicing_fhe_program_clears_ctx() {
#[fhe_program(scheme = "bfv")]
fn panic_fhe_program() {
CURRENT_CTX.with(|ctx| {
CURRENT_FHE_CTX.with(|ctx| {
let old = ctx.take();
assert!(old.is_some());
@@ -71,7 +71,7 @@ fn panicing_fhe_program_clears_ctx() {
assert!(panic_result.is_err());
CURRENT_CTX.with(|ctx| {
CURRENT_FHE_CTX.with(|ctx| {
let old = ctx.take();
assert!(old.is_none());
@@ -102,7 +102,7 @@ fn capture_fhe_program_input_args() {
let context = fhe_program_with_args.build(&get_params()).unwrap();
assert_eq!(context.graph.node_count(), 4);
assert_eq!(context.0.node_count(), 4);
}
#[test]
@@ -122,48 +122,45 @@ fn can_add() {
assert_eq!(fhe_program_with_args.signature(), expected_signature);
assert_eq!(fhe_program_with_args.scheme_type(), SchemeType::Bfv);
let context: FrontendCompilation = fhe_program_with_args.build(&get_params()).unwrap();
let context = fhe_program_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
"nodes": [
"InputCiphertext",
"InputCiphertext",
"InputCiphertext",
"Add",
"Add"
"nodes": [
{ "operation": "InputCiphertext" },
{ "operation": "InputCiphertext" },
{ "operation": "InputCiphertext" },
{ "operation": "Add" },
{ "operation": "Add" }
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
0,
3,
"Left"
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
0,
3,
"Left"
],
[
1,
3,
"Right"
],
[
3,
4,
"Left"
],
[
2,
4,
"Right"
]
[
1,
3,
"Right"
],
[
3,
4,
"Left"
],
[
2,
4,
"Right"
]
},
]
});
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
);
}
@@ -182,36 +179,33 @@ fn can_add_plaintext() {
assert_eq!(fhe_program_with_args.signature(), expected_signature);
assert_eq!(fhe_program_with_args.scheme_type(), SchemeType::Bfv);
let context: FrontendCompilation = fhe_program_with_args.build(&get_params()).unwrap();
let context = fhe_program_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
"nodes": [
"InputCiphertext",
"InputPlaintext",
"AddPlaintext",
"nodes": [
{ "operation": "InputCiphertext" },
{ "operation": "InputPlaintext" },
{ "operation": "AddPlaintext" },
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
0,
2,
"Left"
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
0,
2,
"Left"
],
[
1,
2,
"Right"
],
]
},
[
1,
2,
"Right"
],
]
});
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
);
}
@@ -235,44 +229,42 @@ fn can_mul() {
let context = fhe_program_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
"nodes": [
"InputCiphertext",
"InputCiphertext",
"InputCiphertext",
"Multiply",
"Multiply"
"nodes": [
{ "operation": "InputCiphertext" },
{ "operation": "InputCiphertext" },
{ "operation": "InputCiphertext" },
{ "operation": "Multiply" },
{ "operation": "Multiply" }
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
0,
3,
"Left"
],
"node_holes": [],
"edge_property": "directed",
"edges": [
[
0,
3,
"Left"
],
[
1,
3,
"Right"
],
[
3,
4,
"Left"
],
[
2,
4,
"Right"
]
[
1,
3,
"Right"
],
[
3,
4,
"Left"
],
[
2,
4,
"Right"
]
},
]
});
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
);
}
@@ -296,13 +288,12 @@ fn can_collect_output() {
let context = fhe_program_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
"nodes": [
"InputCiphertext",
"InputCiphertext",
"Multiply",
"Add",
"Output"
{ "operation": "InputCiphertext" },
{ "operation": "InputCiphertext" },
{ "operation": "Multiply" },
{ "operation": "Add" },
{ "operation": "Output" },
],
"node_holes": [],
"edge_property": "directed",
@@ -333,12 +324,13 @@ fn can_collect_output() {
"Unary"
]
]
},
});
dbg!(serde_json::to_string(&context).unwrap());
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
);
}
@@ -365,14 +357,13 @@ fn can_collect_multiple_outputs() {
let context = fhe_program_with_args.build(&get_params()).unwrap();
let expected = json!({
"graph": {
"nodes": [
"InputCiphertext",
"InputCiphertext",
"Multiply",
"Add",
"Output",
"Output"
{ "operation": "InputCiphertext" },
{ "operation": "InputCiphertext" },
{ "operation": "Multiply" },
{ "operation": "Add" },
{ "operation": "Output" },
{ "operation": "Output" },
],
"node_holes": [],
"edge_property": "directed",
@@ -408,11 +399,10 @@ fn can_collect_multiple_outputs() {
"Unary"
]
]
},
});
assert_eq!(
context,
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
);
}

View File

@@ -0,0 +1,11 @@
use sunscreen::{types::zkp::NativeField, zkp_program, ZkpProgramFn};
#[test]
fn can_add_and_mul_native_fields() {
#[zkp_program(backend = "bulletproofs")]
fn add_mul(a: NativeField, b: NativeField) {
a + b * a
}
add_mul.build().unwrap();
}

View File

@@ -7,5 +7,8 @@ edition = "2021"
[dependencies]
petgraph = "0.6.2"
proc-macro2 = "1.0.47"
quote = "1.0.21"
semver = "1.0.14"
serde = "1.0.147"
serde = { version = "1.0.147", features = ["derive"] }
syn = { version = "1.0.103", features = ["full"] }

View File

@@ -1,12 +1,15 @@
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use petgraph::algo::is_isomorphic_matching;
use petgraph::stable_graph::{NodeIndex, StableGraph};
use petgraph::visit::{EdgeRef, IntoEdgeReferences, IntoNodeReferences};
use petgraph::Graph;
use serde::{Deserialize, Serialize};
use crate::{Operation, Render};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq, Eq)]
/**
* Information about a node in the compilation graph.
*/
@@ -29,7 +32,7 @@ where
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
/**
* Information about how one compiler graph node relates to another.
*/
@@ -79,7 +82,7 @@ impl Render for EdgeInfo {
}
}
#[derive(Clone)]
#[derive(Clone, Deserialize, Serialize)]
/**
* The result of a frontend compiler.
*/
@@ -98,6 +101,22 @@ where
}
}
impl<O> PartialEq for FrontendCompilation<O>
where
O: Operation,
{
/// FOR TESTING ONLY!!!
/// Graph isomorphism is an NP-Complete problem!
fn eq(&self, b: &Self) -> bool {
is_isomorphic_matching(
&Graph::from(self.0.clone()),
&Graph::from(b.0.clone()),
|n1, n2| n1 == n2,
|e1, e2| e1 == e2,
)
}
}
impl<O> Debug for FrontendCompilation<O>
where
O: Operation,
@@ -154,7 +173,7 @@ where
/**
* A compilation context. This stores the current parse graph.
*/
pub struct Context<O>
pub struct Context<O, D>
where
O: Operation,
{
@@ -165,22 +184,22 @@ where
#[allow(unused)]
/**
* Consumers can use this to uniquely number their inputs.
* Data given by the consumer.
*/
pub next_input_id: u32,
pub data: D,
}
impl<O> Context<O>
impl<O, D> Context<O, D>
where
O: Operation,
{
/**
* Create a new [`Context`].
*/
pub fn new() -> Self {
pub fn new(data: D) -> Self {
Self {
graph: FrontendCompilation::<O>::new(),
next_input_id: 0,
data,
}
}
@@ -221,12 +240,3 @@ where
node
}
}
impl<O> Default for Context<O>
where
O: Operation,
{
fn default() -> Self {
Self::new()
}
}

View File

@@ -62,8 +62,8 @@ impl<'a, N, E> GraphQuery<'a, N, E> {
*/
pub trait TransformList<N, E>
where
N: Copy,
E: Copy,
N: Clone,
E: Clone,
{
/**
* Apply the transformations.
@@ -87,8 +87,8 @@ where
*/
pub fn forward_traverse<N, E, F, T>(graph: &mut StableGraph<N, E>, callback: F)
where
N: Copy,
E: Copy,
N: Clone,
E: Clone,
T: TransformList<N, E>,
F: FnMut(GraphQuery<N, E>, NodeIndex) -> T,
{
@@ -111,8 +111,8 @@ where
*/
pub fn reverse_traverse<N, E, F, T>(graph: &mut StableGraph<N, E>, callback: F)
where
N: Copy,
E: Copy,
N: Clone,
E: Clone,
T: TransformList<N, E>,
F: FnMut(GraphQuery<N, E>, NodeIndex) -> T,
{
@@ -121,8 +121,8 @@ where
fn traverse<N, E, T, F>(graph: &mut StableGraph<N, E>, forward: bool, mut callback: F)
where
N: Copy,
E: Copy,
N: Clone,
E: Clone,
F: FnMut(GraphQuery<N, E>, NodeIndex) -> T,
T: TransformList<N, E>,
{

View File

@@ -6,6 +6,11 @@
mod context;
mod graph;
/**
* Helper methods for macros.
*/
pub mod macros;
/**
* A set of generic compiler transforms.
*/
@@ -41,7 +46,7 @@ pub trait Render {
*
* Also provides functions that describe properties of an operation.
*/
pub trait Operation: Clone + Copy + Debug + Hash + PartialEq + Eq {
pub trait Operation: Clone + Debug + Hash + PartialEq + Eq {
/**
* Whether or not this operation commutes.
*/

View File

@@ -0,0 +1,589 @@
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote, quote_spanned};
use syn::{
parse_quote, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, FnArg, Ident,
Index, Pat, ReturnType, Token, Type,
};
mod type_name;
pub use type_name::derive_typename_impl;
#[derive(Debug)]
/**
* A type error that occurs in a program specification.
*/
pub enum ProgramTypeError {
/**
* The given type is illegal.
*/
IllegalType(Span),
}
/**
* Given an input type T, returns
* * ProgramNode<T> when T is a Path
* * [map_input_type(T); N] when T is Array
*/
pub fn lift_type(arg_type: &Type) -> Result<Type, ProgramTypeError> {
let transformed_type = match arg_type {
Type::Path(ty) => parse_quote_spanned! {ty.span() => ProgramNode<#ty> },
Type::Array(a) => {
let inner_type = lift_type(&a.elem)?;
let len = &a.len;
parse_quote_spanned! {a.span() =>
[#inner_type; #len]
}
}
_ => {
return Err(ProgramTypeError::IllegalType(arg_type.span()));
}
};
Ok(transformed_type)
}
/**
* Emits code to make a program node for the given type T.
*/
pub fn create_program_node(var_name: &str, arg_type: &Type) -> TokenStream2 {
let mapped_type = match lift_type(arg_type) {
Ok(v) => v,
Err(ProgramTypeError::IllegalType(s)) => {
return quote_spanned! {
s => compile_error!("Unsupported program input type.")
};
}
};
let var_name = format_ident!("{}", var_name);
let type_annotation = match arg_type {
Type::Path(ty) => quote_spanned! { ty.span() => ProgramNode },
Type::Array(a) => quote_spanned! { a.span() =>
<#mapped_type>
},
_ => quote! {
compile_error!("fhe_program arguments' name must be a simple identifier and type must be a plain path.");
},
};
quote_spanned! {arg_type.span() =>
let #var_name: #mapped_type = #type_annotation::input();
}
}
#[derive(Debug)]
/**
* Errors that can occur when extracting the function signature of a
* program.
*/
pub enum ExtractFnArgumentsError {
/**
* The method contains a reference to `self` or `&self`.
*
* # Remarks
* FHE and ZKP programs must be pure functions.
*/
ContainsSelf(Span),
/**
* The given type is not allowed.
*/
IllegalType(Span),
/**
* The given type pattern is not of the a qualified path to a type.
*/
IllegalPat(Span),
}
/**
* Validate and parse the arguments of a function, returning a Vec of
* the types and identifiers.
*/
pub fn extract_fn_arguments(
args: &Punctuated<FnArg, Token!(,)>,
) -> Result<Vec<(&Type, &Ident)>, ExtractFnArgumentsError> {
let mut unwrapped_inputs = vec![];
for i in args {
let input_type = match i {
FnArg::Receiver(_) => {
return Err(ExtractFnArgumentsError::ContainsSelf(i.span()));
}
FnArg::Typed(t) => match (&*t.ty, &*t.pat) {
(Type::Path(_), Pat::Ident(i)) => (&*t.ty, &i.ident),
(Type::Array(_), Pat::Ident(i)) => (&*t.ty, &i.ident),
_ => {
match &*t.pat {
Pat::Ident(_) => {}
_ => {
return Err(ExtractFnArgumentsError::IllegalPat(t.span()));
}
};
return Err(ExtractFnArgumentsError::IllegalType(t.span()));
}
},
};
unwrapped_inputs.push(input_type);
}
Ok(unwrapped_inputs)
}
#[derive(Debug)]
/**
* Errors that can occur when extracting the return value of an FHE
* program.
*/
pub enum ExtractReturnTypesError {
/**
* The given return type is not allowed.
*
* # Remarks
* ZKP programs don't return values.
*
* FHE programs must return either
* * nothing (weird, but legal).
* * a single FHE type.
* * a tuple of FHE types.
*
*/
IllegalType(Span),
}
impl From<ProgramTypeError> for ExtractReturnTypesError {
fn from(e: ProgramTypeError) -> Self {
match e {
ProgramTypeError::IllegalType(s) => Self::IllegalType(s),
}
}
}
/**
* Unpacks the return types from a `ReturnType` and flattens them
* into a Vec.
* * Tuples will have more than one.
* * Path, Paren, and Arrays will have one.
* * Default has zero.
*/
pub fn extract_return_types(ret: &ReturnType) -> Result<Vec<Type>, ExtractReturnTypesError> {
let return_types = match ret {
ReturnType::Type(_, t) => match &**t {
Type::Tuple(t) => t.elems.iter().cloned().collect::<Vec<Type>>(),
Type::Paren(t) => {
vec![*t.elem.clone()]
}
Type::Path(_) => {
vec![*t.clone()]
}
Type::Array(_) => {
vec![*t.clone()]
}
_ => return Err(ExtractReturnTypesError::IllegalType(t.span())),
},
ReturnType::Default => {
vec![]
}
};
Ok(return_types)
}
/**
* Takes an array of return types and packages them into a tuple
* if needed.
*/
pub fn pack_return_type(return_types: &[Type]) -> Type {
match return_types.len() {
0 => parse_quote! { () },
1 => return_types[0].clone(),
_ => {
parse_quote_spanned! {return_types[0].span() => ( #(#return_types),* ) }
}
}
}
/**
* Emits code to create output nodes for each returned value in an
* FHE program.
*/
pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 {
match return_types.len() {
1 => quote_spanned! { return_types[0].span() => v.output(); },
_ => return_types
.iter()
.enumerate()
.map(|(i, t)| {
let index = Index::from(i);
quote_spanned! {t.span() =>
v.#index.output();
}
})
.collect(),
}
}
/**
* Emits the call signature of an FHE or ZKP program.
*/
pub fn emit_signature(args: &[Type], return_types: &[Type]) -> TokenStream2 {
let arg_type_names = args
.iter()
.enumerate()
.map(|(i, t)| {
let alias = format_ident!("T{}", i);
quote! {
type #alias = #t;
}
})
.collect::<Vec<TokenStream2>>();
let arg_get_types = arg_type_names.iter().enumerate().map(|(i, _)| {
let alias = format_ident!("T{}", i);
quote! {
#alias::type_name(),
}
});
let return_type_aliases = return_types.iter().enumerate().map(|(i, t)| {
let alias = format_ident!("R{}", i);
quote! {
type #alias = #t;
}
});
let return_type_names = return_types.iter().enumerate().map(|(i, _)| {
let alias = format_ident!("R{}", i);
quote! {
#alias ::type_name(),
}
});
let return_type_sizes = return_types.iter().enumerate().map(|(i, _)| {
let alias = format_ident!("R{}", i);
quote! {
#alias ::NUM_CIPHERTEXTS,
}
});
quote! {
use sunscreen::types::TypeName;
#(#arg_type_names)*
#(#return_type_aliases)*
sunscreen::CallSignature {
arguments: vec![#(#arg_get_types)*],
returns: vec![#(#return_type_names)*],
num_ciphertexts: vec![#(#return_type_sizes)*],
}
}
}
#[cfg(test)]
mod test {
use super::*;
use quote::ToTokens;
use syn::parse_quote;
fn assert_syn_eq<T, U>(a: &T, b: &U)
where
T: ToTokens,
U: ToTokens,
{
assert_eq!(
format!("{}", a.to_token_stream()),
format!("{}", b.to_token_stream())
);
}
fn assert_syn_slice_eq<T>(a: &[T], b: &[T])
where
T: ToTokens,
{
assert_eq!(a.len(), b.len());
for (l, r) in a.iter().zip(b) {
assert_syn_eq(l, r);
}
}
#[test]
fn transform_plain_scalar_type() {
let type_name = quote! {
Rational
};
let type_name: Type = parse_quote!(#type_name);
let actual = lift_type(&type_name).unwrap();
let expected: Type = parse_quote! {
ProgramNode<Rational>
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn transform_array_type() {
let type_name = quote! {
[Rational; 6]
};
let type_name: Type = parse_quote!(#type_name);
let actual = lift_type(&type_name).unwrap();
let expected: Type = parse_quote! {
[ProgramNode<Rational>; 6]
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn transform_multi_dimensional_array_type() {
let type_name = quote! {
[[Rational; 6]; 7]
};
let type_name: Type = parse_quote!(#type_name);
let actual = lift_type(&type_name).unwrap();
let expected: Type = parse_quote! {
[[ProgramNode<Rational>; 6]; 7]
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn transform_multi_dimensional_array_cipher_type() {
let type_name = quote! {
[[Cipher<Rational>; 6]; 7]
};
let type_name: Type = parse_quote!(#type_name);
let actual = lift_type(&type_name).unwrap();
let expected: Type = parse_quote! {
[[ProgramNode<Cipher<Rational> >; 6]; 7]
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn can_create_simple_fhe_program_node() {
let type_name = quote! {
Cipher<Rational>
};
let type_name: Type = parse_quote!(#type_name);
let actual = create_program_node("horse", &type_name);
let expected = quote! {
let horse: ProgramNode<Cipher<Rational> > = ProgramNode::input();
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn can_create_array_program_node() {
let type_name = quote! {
[Cipher<Rational>; 7]
};
let type_name: Type = parse_quote!(#type_name);
let actual = create_program_node("horse", &type_name);
let expected = quote! {
let horse: [ProgramNode<Cipher<Rational> >; 7] = <[ProgramNode<Cipher<Rational> >; 7]>::input();
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn can_create_multidimensional_array_program_node() {
let type_name = quote! {
[[Cipher<Rational>; 7]; 6]
};
let type_name: Type = parse_quote!(#type_name);
let actual = create_program_node("horse", &type_name);
let expected = quote! {
let horse: [[ProgramNode<Cipher<Rational> >; 7]; 6] = <[[ProgramNode<Cipher<Rational> >; 7]; 6]>::input();
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn can_extract_arguments() {
let type_name = quote! {
a: [[Cipher<Rational>; 7]; 6], b: Cipher<Rational>
};
let args: Punctuated<FnArg, Token!(,)> = parse_quote!(#type_name);
let extracted = extract_fn_arguments(&args).unwrap();
let expected_t0: Type = parse_quote! { [[Cipher<Rational>; 7]; 6] };
let expected_t1: Type = parse_quote! { Cipher<Rational> };
let expected_i0: Ident = parse_quote! { a };
let expected_i1: Ident = parse_quote! { b };
assert_eq!(extracted.len(), 2);
assert_syn_eq(extracted[0].0, &expected_t0);
assert_syn_eq(extracted[0].1, &expected_i0);
assert_syn_eq(extracted[1].0, &expected_t1);
assert_syn_eq(extracted[1].1, &expected_i1);
}
#[test]
fn disallows_self_arguments() {
let type_name = quote! {
&self, a: [[Cipher<Rational>; 7]; 6], b: Cipher<Rational>
};
let args: Punctuated<FnArg, Token!(,)> = parse_quote!(#type_name);
let extracted = extract_fn_arguments(&args);
match extracted {
Err(ExtractFnArgumentsError::ContainsSelf(_)) => {}
_ => {
panic!("Expected ExtractFnArgumentsError::ContainsSelf");
}
};
}
#[test]
fn can_extract_no_return_type() {
let return_type: Type = parse_quote! {
()
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
assert_syn_slice_eq(&extracted, &[]);
}
#[test]
fn can_extract_single_return() {
let return_type: Type = parse_quote! {
Cipher<Signed>
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
assert_syn_slice_eq(&extracted, &[parse_quote! { Cipher<Signed> }]);
}
#[test]
fn can_extract_single_paren_return() {
let return_type: Type = parse_quote! {
(Cipher<Signed>)
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
assert_syn_slice_eq(&extracted, &[parse_quote! { Cipher<Signed> }]);
}
#[test]
fn can_extract_single_array_return() {
let return_type: Type = parse_quote! {
[[Cipher<Signed>; 6]; 7]
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
assert_syn_slice_eq(&extracted, &[parse_quote! { [[Cipher<Signed>; 6]; 7] }]);
}
#[test]
fn can_extract_single_multiarray_return() {
let return_type: Type = parse_quote! {
([[Cipher<Signed>; 6]; 7], Cipher<Signed>)
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
assert_syn_slice_eq(
&extracted,
&[
parse_quote! { [[Cipher<Signed>; 6]; 7] },
parse_quote! { Cipher<Signed> },
],
);
}
#[test]
fn can_capture_single_output() {
let return_type: Type = parse_quote! {
(Cipher<Signed>)
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
let actual = emit_output_capture(&extracted);
let expected = quote! {
v.output();
};
assert_syn_eq(&actual, &expected);
}
#[test]
fn can_capture_multiple_outputs() {
let return_type: Type = parse_quote! {
(Cipher<Signed>, [[Cipher<Signed>; 6]; 7])
};
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
let extracted = extract_return_types(&return_type).unwrap();
let actual = emit_output_capture(&extracted);
let expected = quote! {
v.0.output();
v.1.output();
};
assert_syn_eq(&actual, &expected);
}
}

View File

@@ -0,0 +1,46 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{DeriveInput, Ident, LitStr};
/**
* The implementation for #[derive(TypeName)]
*/
pub fn derive_typename_impl(parse_stream: DeriveInput) -> TokenStream {
let name = &parse_stream.ident;
let name_contents = LitStr::new(&format!("{{}}::{}", name), name.span());
let crate_name = std::env::var("CARGO_CRATE_NAME").unwrap();
// If the sunscreen crate itself tries to derive types, then it needs to refer
// to itself in the first-person as "crate", not in the third-person as "sunscreen"
let sunscreen_path = if crate_name == "sunscreen" {
Ident::new("crate", Span::call_site())
} else {
Ident::new("sunscreen", Span::call_site())
};
quote! {
impl #sunscreen_path ::types::TypeName for #name {
fn type_name() -> #sunscreen_path ::types::Type {
let version = env!("CARGO_PKG_VERSION");
#sunscreen_path ::types::Type {
name: format!(#name_contents, module_path!()),
version: #sunscreen_path ::types::Version ::parse(version).expect("Crate version is not a valid semver"),
is_encrypted: false
}
}
}
impl #sunscreen_path ::types::TypeNameInstance for #name {
fn type_name_instance(&self) -> #sunscreen_path ::types::Type {
let version = env!("CARGO_PKG_VERSION");
#sunscreen_path ::types::Type {
name: format!(#name_contents, module_path!()),
version: #sunscreen_path ::types::Version ::parse(version).expect("Crate version is not a valid semver"),
is_encrypted: false,
}
}
}
}
}

View File

@@ -75,7 +75,7 @@ pub fn common_subexpression_elimination<O: Operation>(
// Key is left/unary+right operand and operation. Value is
// the node that matches such a key.
let mut visited_nodes = HashMap::<(NodeIndex, Option<NodeIndex>, O), NodeIndex>::new();
let mut visited_nodes = HashMap::<(NodeIndex, Option<NodeIndex>, &O), NodeIndex>::new();
// Look through out immediate children. If we find any of the
// type that share an edge with another node, consolidate them into
@@ -101,31 +101,31 @@ pub fn common_subexpression_elimination<O: Operation>(
)));
};
let child_op = child_node.operation;
let child_op = &child_node.operation;
if child_op.is_binary() {
let (left, right) = get_binary_operands(&query, e);
match visited_nodes.get(&(left, Some(right), child_node.operation)) {
match visited_nodes.get(&(left, Some(right), child_op)) {
Some(equiv_node) => {
move_edges(*equiv_node, e);
}
None => {
visited_nodes.insert((left, Some(right), child_node.operation), e);
visited_nodes.insert((left, Some(right), child_op), e);
if child_op.is_commutative() {
visited_nodes.insert((right, Some(left), child_node.operation), e);
visited_nodes.insert((right, Some(left), child_op), e);
}
}
};
} else if child_op.is_unary() {
// Unary
let equiv_node = visited_nodes.get(&(index, None, child_node.operation));
let equiv_node = visited_nodes.get(&(index, None, child_op));
match equiv_node {
Some(equiv_node) => move_edges(*equiv_node, e),
None => {
visited_nodes.insert((index, None, child_node.operation), e);
visited_nodes.insert((index, None, child_op), e);
}
}
}

View File

@@ -107,8 +107,8 @@ impl<N, E> GraphTransforms<N, E> {
impl<N, E> Default for GraphTransforms<N, E>
where
N: Copy,
E: Copy,
N: Clone,
E: Clone,
{
fn default() -> Self {
Self::new()
@@ -117,18 +117,18 @@ where
impl<N, E> TransformList<N, E> for GraphTransforms<N, E>
where
N: Copy,
E: Copy,
N: Clone,
E: Clone,
{
fn apply(&mut self, graph: &mut petgraph::stable_graph::StableGraph<N, E>) {
for t in &self.transforms {
let inserted_node = match t {
Transform::AddNode(n) => Some(graph.add_node(*n)),
Transform::AddNode(n) => Some(graph.add_node(n.clone())),
Transform::AddEdge(start, end, info) => {
let start = self.materialize_index(*start);
let end = self.materialize_index(*end);
graph.add_edge(start, end, *info);
graph.add_edge(start, end, info.clone());
None
}

View File

@@ -23,6 +23,7 @@ proc-macro = true
proc-macro2 = "1.0.32"
quote = "1.0.10"
syn = { version = "1.0.81", features = ["derive", "full", "fold"] }
sunscreen_compiler_common ={ path = "../sunscreen_compiler_common" }
[dev-dependencies]
serde_json = "1.0.72"

View File

@@ -0,0 +1,160 @@
use std::collections::HashMap;
use proc_macro2::Span;
use syn::{
parse::ParseStream, punctuated::Punctuated, spanned::Spanned, Error as SynError, Expr, Lit,
LitInt, LitStr, Result as SynResult, Token,
};
pub enum AttrValue {
/**
* The attribute value is a string.
*/
String(Span, String),
/**
* The attribute value is an integer.
*/
USize(Span, usize),
/**
* The key is present but has no value associated with it.
*/
Present(Span),
}
impl From<&LitStr> for AttrValue {
fn from(lit: &LitStr) -> Self {
Self::String(lit.span(), lit.value())
}
}
impl TryFrom<&LitInt> for AttrValue {
type Error = SynError;
fn try_from(lit: &LitInt) -> SynResult<Self> {
let val = lit.base10_parse::<usize>().map_err(|_| {
SynError::new_spanned(
lit,
format!("{} is not a valid integer literal.", lit.base10_digits()),
)
})?;
Ok(Self::USize(lit.span(), val))
}
}
impl AttrValue {
pub fn get_type(&self) -> &str {
match self {
Self::String(_s, _x) => "String",
Self::USize(_s, _x) => "usize",
Self::Present(_s) => "None",
}
}
pub fn span(&self) -> Span {
match self {
Self::String(s, _) => *s,
Self::USize(s, _) => *s,
Self::Present(s) => *s,
}
}
pub fn as_str(&self) -> SynResult<&str> {
match self {
Self::String(_, val) => Ok(val),
_ => Err(SynError::new(
self.span(),
format!("Expected string literal, got {}", self.get_type()),
)),
}
}
pub fn as_usize(&self) -> SynResult<usize> {
match self {
Self::USize(_, val) => Ok(*val),
_ => Err(SynError::new(
self.span(),
format!("Expected usize literal, got {}", self.get_type()),
)),
}
}
}
/**
* Attempts to parse a list of attributes contained in an attribute and
* returns them as a `HashMap<String, AttrValue>`. The list of items
* is a comma-delimited list of either `key = value` pairs where value is
* a string or numeric literal *or* merely a key.
*
* Parsing will fail and return an error on any syntax violation.
*
* # Example
* In the below example, this function parses the contents between the
* parentheses.
*
* ```no_test
* // key1 takes a string value, key2 takes a usize, key3's presence
* // indicates is a true boolean.
* #[my_attribute(key1 = "string", key2 = 42, key3)]
* fn my_function() {
* }
* ```
*/
pub fn try_parse_dict(input: ParseStream) -> SynResult<HashMap<String, AttrValue>> {
// parses a,b,c, or a,b,c where a,b and c are Indent
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
let mut attrs: HashMap<String, AttrValue> = HashMap::new();
for var in &vars {
match var {
Expr::Assign(a) => {
let key = match &*a.left {
Expr::Path(p) =>
p.path.get_ident().ok_or_else(||SynError::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
_ => { return Err(SynError::new_spanned(&a.left, "Key should be a plain identifier")) }
};
let value: AttrValue = match &*a.right {
Expr::Lit(l) => match &l.lit {
Lit::Str(s) => s.into(),
Lit::Int(x) => x.try_into()?,
_ => {
return Err(SynError::new_spanned(
l,
"Literal should be a string or integer",
))
}
},
_ => {
return Err(SynError::new_spanned(
&a.right,
"Value should be a string literal",
))
}
};
attrs.insert(key, value);
}
Expr::Path(p) => {
let key = p
.path
.get_ident()
.ok_or_else(|| SynError::new_spanned(p, "Unknown identifier"))?
.to_string();
attrs.insert(key, AttrValue::Present(p.span()));
}
_ => {
return Err(SynError::new_spanned(
var,
"Expected `key = \"value\"` or `key`",
))
}
}
}
Ok(attrs)
}

View File

@@ -1,13 +0,0 @@
#[derive(Debug)]
pub enum Error {
SynError(syn::Error),
UnknownScheme(String),
}
impl From<syn::Error> for Error {
fn from(err: syn::Error) -> Self {
Self::SynError(err)
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,6 +1,6 @@
use crate::{
fhe_program_transforms::*,
internals::{attr::Attrs, case::Scheme},
internals::attr::{FheProgramAttrs, Scheme},
};
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
@@ -18,7 +18,7 @@ pub fn fhe_program_impl(
let inputs = &input_fn.sig.inputs;
let ret = &input_fn.sig.output;
let attr_params = parse_macro_input!(metadata as Attrs);
let attr_params = parse_macro_input!(metadata as FheProgramAttrs);
let scheme_type = match attr_params.scheme {
Scheme::Bfv => {
@@ -119,19 +119,19 @@ pub fn fhe_program_impl(
}
impl sunscreen::FheProgramFn for #fhe_program_struct_name {
fn build(&self, params: &sunscreen::Params) -> sunscreen::Result<sunscreen::FrontendCompilation> {
fn build(&self, params: &sunscreen::Params) -> sunscreen::Result<sunscreen::fhe::FheFrontendCompilation> {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}};
use sunscreen::{fhe::{CURRENT_FHE_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}};
if SchemeType::Bfv != params.scheme_type {
return Err(Error::IncorrectScheme)
}
// TODO: Other schemes.
let mut context = Context::new(params);
let mut context = FheContext::new(params.clone());
CURRENT_CTX.with(|ctx| {
CURRENT_FHE_CTX.with(|ctx| {
#[allow(clippy::type_complexity)]
#[forbid(unused_variables)]
let internal = | #(#fhe_program_args)* | -> #fhe_program_return
@@ -168,7 +168,7 @@ pub fn fhe_program_impl(
ctx.swap(&RefCell::new(None));
});
Ok(context.compilation)
Ok(context.graph)
}
fn signature(&self) -> sunscreen::CallSignature {

View File

@@ -1,16 +1,14 @@
use super::case::Scheme;
use proc_macro2::Span;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
Error, Expr, Lit, LitInt, LitStr, Result, Token,
Error as SynError, Expr, Lit, LitInt, LitStr, Result as SynResult, Token,
};
use crate::internals::symbols::VALUE_KEYS;
use std::collections::HashMap;
#[derive(Debug)]
enum AttrValue {
/**
* The attribute value is a string.
@@ -35,11 +33,11 @@ impl From<&LitStr> for AttrValue {
}
impl TryFrom<&LitInt> for AttrValue {
type Error = Error;
type Error = SynError;
fn try_from(lit: &LitInt) -> Result<Self> {
fn try_from(lit: &LitInt) -> SynResult<Self> {
let val = lit.base10_parse::<usize>().map_err(|_| {
Error::new_spanned(
SynError::new_spanned(
lit,
format!("{} is not a valid integer literal.", lit.base10_digits()),
)
@@ -66,20 +64,20 @@ impl AttrValue {
}
}
pub fn as_str(&self) -> Result<&str> {
pub fn as_str(&self) -> SynResult<&str> {
match self {
Self::String(_, val) => Ok(val),
_ => Err(Error::new(
_ => Err(SynError::new(
self.span(),
format!("Expected String, got {}", self.get_type()),
)),
}
}
pub fn as_usize(&self) -> Result<usize> {
pub fn as_usize(&self) -> SynResult<usize> {
match self {
Self::USize(_, val) => Ok(*val),
_ => Err(Error::new(
_ => Err(SynError::new(
self.span(),
format!("Expected String, got {}", self.get_type()),
)),
@@ -87,78 +85,129 @@ impl AttrValue {
}
}
pub struct Attrs {
/**
* Attempts to parse a list of attributes contained in an attribute and
* returns them as a `HashMap<String, AttrValue>`. The list of items
* is a comma-delimited list of either `key = value` pairs where value is
* a string or numeric literal *or* merely a key.
*
* Parsing will fail and return an error on any syntax violation.
*
* # Example
* In the below example, this function parses the contents between the
* parentheses.
*
* ```no_test
* // key1 takes a string value, key2 takes a usize, key3's presence
* // indicates is a true boolean.
* #[my_attribute(key1 = "string", key2 = 42, key3)]
* fn my_function() {
* }
* ```
*/
fn try_parse_dict(input: ParseStream) -> SynResult<HashMap<String, AttrValue>> {
// parses a,b,c, or a,b,c where a,b and c are Indent
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
let mut attrs: HashMap<String, AttrValue> = HashMap::new();
for var in &vars {
match var {
Expr::Assign(a) => {
let key = match &*a.left {
Expr::Path(p) =>
p.path.get_ident().ok_or_else(||SynError::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
_ => { return Err(SynError::new_spanned(&a.left, "Key should be a plain identifier")) }
};
let value: AttrValue = match &*a.right {
Expr::Lit(l) => match &l.lit {
Lit::Str(s) => s.into(),
Lit::Int(x) => x.try_into()?,
_ => {
return Err(SynError::new_spanned(
l,
"Literal should be a string or integer",
))
}
},
_ => {
return Err(SynError::new_spanned(
&a.right,
"Value should be a string literal",
))
}
};
attrs.insert(key, value);
}
Expr::Path(p) => {
let key = p
.path
.get_ident()
.ok_or_else(|| SynError::new_spanned(p, "Unknown identifier"))?
.to_string();
attrs.insert(key, AttrValue::Present(p.span()));
}
_ => {
return Err(SynError::new_spanned(
var,
"Expected `key = \"value\"` or `key`",
))
}
}
}
Ok(attrs)
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Scheme {
Bfv,
}
impl TryFrom<&AttrValue> for Scheme {
type Error = SynError;
fn try_from(value: &AttrValue) -> SynResult<Self> {
let as_str = value.as_str()?;
let scheme = match as_str {
"bfv" => Self::Bfv,
_ => {
return Err(SynError::new(
value.span(),
format!("Unknown scheme {}", as_str),
));
}
};
Ok(scheme)
}
}
pub struct FheProgramAttrs {
pub scheme: Scheme,
pub chain_count: usize,
}
impl Parse for Attrs {
fn parse(input: ParseStream) -> Result<Self> {
// parses a,b,c, or a,b,c where a,b and c are Indent
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
impl Parse for FheProgramAttrs {
fn parse(input: ParseStream) -> SynResult<Self> {
let attrs = try_parse_dict(input)?;
let mut attrs: HashMap<String, AttrValue> = HashMap::new();
const VALUE_KEYS: &[&str] = &["scheme", "chain_count"];
for var in &vars {
match var {
Expr::Assign(a) => {
let key = match &*a.left {
Expr::Path(p) =>
p.path.get_ident().ok_or_else(||Error::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
_ => { return Err(Error::new_spanned(&a.left, "Key should be a plain identifier")) }
};
let value: AttrValue = match &*a.right {
Expr::Lit(l) => match &l.lit {
Lit::Str(s) => s.into(),
Lit::Int(x) => x.try_into()?,
_ => {
return Err(Error::new_spanned(
l,
"Literal should be a string or integer",
))
}
},
_ => {
return Err(Error::new_spanned(
&a.right,
"Value should be a string literal",
))
}
};
if !VALUE_KEYS.iter().any(|x| *x == key) {
return Err(Error::new_spanned(a, "Unknown key".to_owned()));
}
attrs.insert(key, value);
}
Expr::Path(p) => {
let key = p
.path
.get_ident()
.ok_or_else(|| Error::new_spanned(p, "Unknown identifier"))?
.to_string();
if !VALUE_KEYS.iter().any(|x| *x == key) {
return Err(Error::new_spanned(p, "Unknown key"));
}
attrs.insert(key, AttrValue::Present(p.span()));
}
_ => {
return Err(Error::new_spanned(
var,
"Expected `key = \"value\"` or `key`",
))
}
for i in attrs.keys() {
if !VALUE_KEYS.iter().any(|x| x == i) {
return Err(SynError::new(input.span(), &format!("Unknown key '{}'", i)));
}
}
let scheme_type = attrs
let scheme: Scheme = attrs
.get("scheme")
.ok_or_else(|| Error::new_spanned(&vars, "required `scheme` is missing".to_owned()))?
.as_str()?;
.ok_or_else(|| SynError::new(input.span(), "required `scheme` is missing".to_owned()))?
.try_into()?;
let chain_count = attrs
.get("chain_count")
@@ -166,10 +215,54 @@ impl Parse for Attrs {
.unwrap_or(Ok(1))?;
Ok(Self {
scheme: Scheme::parse(scheme_type).map_err(|_e| {
Error::new_spanned(vars, format!("Unknown scheme '{}'", &scheme_type))
})?,
scheme,
chain_count,
})
}
}
pub enum BackendType {
Bulletproofs,
}
impl TryFrom<&AttrValue> for BackendType {
type Error = SynError;
fn try_from(value: &AttrValue) -> SynResult<Self> {
let as_str = value.as_str()?;
match as_str {
"bulletproofs" => Ok(BackendType::Bulletproofs),
_ => Err(SynError::new(
value.span(),
format!("Unknown backend `{}`", as_str.to_owned()),
)),
}
}
}
#[allow(unused)]
pub struct ZkpProgramAttrs {
backend_type: BackendType,
}
impl Parse for ZkpProgramAttrs {
fn parse(input: ParseStream) -> SynResult<Self> {
let attrs = try_parse_dict(input)?;
const VALUE_KEYS: &[&str] = &["backend"];
for i in attrs.keys() {
if !VALUE_KEYS.iter().any(|x| x == i) {
return Err(SynError::new(input.span(), &format!("Unknown key '{}'", i)));
}
}
let backend_type = attrs.get("backend").ok_or_else(|| {
SynError::new(input.span(), "required 'backend' is missing".to_owned())
})?;
let backend_type = BackendType::try_from(backend_type)?;
Ok(Self { backend_type })
}
}

View File

@@ -1,16 +1 @@
use self::Scheme::*;
use crate::error::*;
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Scheme {
Bfv,
}
impl Scheme {
pub fn parse(s: &str) -> Result<Self> {
Ok(match s {
"bfv" => Bfv,
_ => Err(Error::UnknownScheme(s.to_owned()))?,
})
}
}

View File

@@ -1,4 +1,3 @@
// Following the pattern in serde (https://github.com/serde-rs)
pub mod attr;
pub mod case;
pub mod symbols;

View File

@@ -1 +0,0 @@
pub const VALUE_KEYS: &[&str] = &["scheme", "chain_count"];

View File

@@ -6,11 +6,11 @@
extern crate proc_macro;
mod error;
mod fhe_program;
mod fhe_program_transforms;
mod internals;
mod type_name;
mod zkp_program;
#[proc_macro_derive(TypeName)]
/**
@@ -65,3 +65,14 @@ pub fn fhe_program(
) -> proc_macro::TokenStream {
fhe_program::fhe_program_impl(metadata, input)
}
#[proc_macro_attribute]
/**
* Specifies a function to be a ZKP program. TODO: docs.
*/
pub fn zkp_program(
metadata: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
zkp_program::zkp_program_impl(metadata, input)
}

View File

@@ -0,0 +1,162 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned};
use sunscreen_compiler_common::macros::{
create_program_node, emit_signature, extract_fn_arguments, lift_type, ExtractFnArgumentsError,
};
use syn::{parse_macro_input, spanned::Spanned, ItemFn, ReturnType, Type};
use crate::internals::attr::ZkpProgramAttrs;
pub fn zkp_program_impl(
metadata: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let _attr_params = parse_macro_input!(metadata as ZkpProgramAttrs);
let input_fn = parse_macro_input!(input as ItemFn);
let zkp_program_name = &input_fn.sig.ident;
let vis = &input_fn.vis;
let body = &input_fn.block;
let inputs = &input_fn.sig.inputs;
let ret = &input_fn.sig.output;
match ret {
ReturnType::Default => {}
_ => {
return proc_macro::TokenStream::from(quote_spanned! {
ret.span() => compile_error!("ZKP programs may not return values.")
});
}
};
let unwrapped_inputs = match extract_fn_arguments(inputs) {
Ok(v) => v,
Err(e) => {
return proc_macro::TokenStream::from(match e {
ExtractFnArgumentsError::ContainsSelf(s) => {
quote_spanned! {s => compile_error!("FHE programs must not contain `self`") }
}
ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! {
s => compile_error! { "Expected Identifier" }
},
ExtractFnArgumentsError::IllegalType(s) => quote_spanned! {
s => compile_error! { "FHE program arguments must be an array or named struct type" }
},
});
}
};
let argument_types = unwrapped_inputs
.iter()
.map(|(t, _)| (**t).clone())
.collect::<Vec<Type>>();
let zkp_program_args = unwrapped_inputs
.iter()
.map(|i| {
let (ty, name) = i;
let ty = lift_type(ty).unwrap();
quote! {
#name: #ty,
}
})
.collect::<Vec<TokenStream>>();
let signature = emit_signature(&argument_types, &[]);
let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| {
let var_name = format!("c_{}", i);
create_program_node(&var_name, t.0)
});
let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| {
let id = Ident::new(&format!("c_{}", i), Span::call_site());
quote_spanned! {t.0.span() =>
#id
}
});
let zkp_program_struct_name =
Ident::new(&format!("{}_struct", zkp_program_name), Span::call_site());
let zkp_program_name_literal = format!("{}", zkp_program_name);
proc_macro::TokenStream::from(quote! {
#[allow(non_camel_case_types)]
#[derive(Clone)]
#vis struct #zkp_program_struct_name {
}
impl sunscreen::ZkpProgramFn for #zkp_program_struct_name {
fn build(&self) -> sunscreen::Result<sunscreen::ZkpFrontendCompilation> {
use std::cell::RefCell;
use std::mem::transmute;
use sunscreen::{CURRENT_ZKP_CTX, ZkpContext, Error, INDEX_ARENA, Result, types::{zkp::ProgramNode, TypeName}};
let mut context = ZkpContext::new(0);
CURRENT_ZKP_CTX.with(|ctx| {
#[allow(clippy::type_complexity)]
#[forbid(unused_variables)]
let internal = | #(#zkp_program_args)* |
#body
;
// Transmute away the lifetime to 'static. So long as we are careful with internal()
// panicing, this is safe because we set the context back to none before the function
// returns.
ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) })));
#(#var_decl)*
let panic_res = std::panic::catch_unwind(|| {
internal(#(#args),*)
});
// when panicing or not, we need to clear our indicies arena and
// unset the context reference.
match panic_res {
Ok(v) => { },
Err(err) => {
INDEX_ARENA.with(|allocator| {
allocator.borrow_mut().reset()
});
ctx.swap(&RefCell::new(None));
std::panic::resume_unwind(err)
}
};
INDEX_ARENA.with(|allocator| {
allocator.borrow_mut().reset()
});
ctx.swap(&RefCell::new(None));
});
Ok(context.graph)
}
fn signature(&self) -> sunscreen::CallSignature {
#signature
}
fn name(&self) -> &str {
#zkp_program_name_literal
}
}
impl AsRef<str> for #zkp_program_struct_name {
fn as_ref(&self) -> &str {
use sunscreen::FheProgramFn;
self.name()
}
}
#[allow(non_upper_case_globals)]
#vis const #zkp_program_name: #zkp_program_struct_name = #zkp_program_struct_name {
};
})
}

View File

@@ -18,7 +18,7 @@ readme = "crates-io.md"
[dependencies]
petgraph = { version = "0.6.0", features = ["serde-1"] }
serde = { version = "1.0.130", features = ["derive"] }
serde = { version = "1.0.147", features = ["derive"] }
seal_fhe = { version = "0.7", path = "../seal_fhe" }
[dev-dependencies]

View File

@@ -37,7 +37,7 @@ use TransformNodeIndex::*;
use std::collections::HashSet;
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, Serialize, Hash, Deserialize, PartialEq, Eq)]
/**
* Sunscreen supports the BFV scheme.
*/

View File

@@ -27,7 +27,7 @@ petgraph = "0.6.0"
num_cpus = "1.13.0"
rayon = "1.5.1"
rlp = "0.5.1"
serde = "1.0.130"
serde = "1.0.147"
semver = "1.0.4"
[dev-dependencies]

View File

@@ -22,7 +22,7 @@ pub use serialization::WithContext;
use seal_fhe::{Ciphertext as SealCiphertext, Plaintext as SealPlaintext};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, Serialize, Deserialize, Eq)]
/**
* The underlying backend implementation of a plaintext (e.g. SEAL's [`Plaintext`](seal_fhe::Plaintext)).
*/

View File

@@ -63,7 +63,7 @@ pub enum RequiredKeys {
PublicKey,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Serialize, Hash, Deserialize, PartialEq, Eq)]
/**
* The parameter set required for a given FHE program to run efficiently and correctly.
*/

View File

@@ -1,3 +1,5 @@
use std::hash::Hash;
use crate::Params;
use seal_fhe::{BfvEncryptionParametersBuilder, Context, FromBytes, Modulus, ToBytes};
use serde::{
@@ -6,7 +8,7 @@ use serde::{
Deserialize, Serialize,
};
#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, PartialEq, Hash, Eq, Clone)]
/**
* A data type that contains parameters for reconstructing a context
* during deserialization (needed by SEAL).

View File

@@ -0,0 +1,22 @@
[package]
name = "sunscreen_zkp_backend"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bulletproofs = { version = "4.0.0", features = ["yoloproofs"], optional = true }
merlin = { version = "3.0.0", optional = true}
[features]
bulletproofs = [
"dep:bulletproofs",
"dep:merlin"
]
[dependencies.curve25519-dalek]
version = "4"
features = ["u64_backend", "serde"]
default-features = false
package = "curve25519-dalek-ng"

View File

@@ -0,0 +1,77 @@
use bulletproofs::{
r1cs::{ConstraintSystem, Prover, R1CSProof, Variable, Verifier},
BulletproofGens, PedersenGens,
};
use curve25519_dalek::scalar::Scalar;
use merlin::Transcript;
use std::task::Context;
struct MulProof(R1CSProof);
impl MulProof {
fn make_gens() -> (Transcript, PedersenGens, BulletproofGens) {
let mut transcript = Transcript::new(b"Horse");
transcript.append_message(b"dom-sep", b"MulProof");
let pc_gens = PedersenGens::default();
let bp_gens = BulletproofGens::new(128, 1);
(transcript, pc_gens, bp_gens)
}
pub fn prove(x: Scalar, y: Scalar, o: Scalar) -> Self {
let (transcript, pc_gens, bp_gens) = Self::make_gens();
let mut prover = Prover::new(&pc_gens, transcript);
let inputs = vec![
prover.allocate(Some(x)).unwrap(),
prover.allocate(Some(y)).unwrap(),
];
let outputs = vec![prover.allocate(Some(o)).unwrap()];
Self::gadget(&mut prover, inputs, outputs);
Self(prover.prove(&bp_gens).unwrap())
}
pub fn verify(&self) -> bool {
let (transcript, pc_gens, bp_gens) = Self::make_gens();
let mut verifier = Verifier::new(transcript);
let inputs = vec![
verifier.allocate(None).unwrap(),
verifier.allocate(None).unwrap(),
];
let outputs = vec![verifier.allocate(None).unwrap()];
Self::gadget(&mut verifier, inputs, outputs);
verifier.verify(&self.0, &pc_gens, &bp_gens).is_ok()
}
fn gadget<CS: ConstraintSystem>(cs: &mut CS, inputs: Vec<Variable>, outputs: Vec<Variable>) {
let (_, _, o) = cs.multiply(
inputs[0] + Scalar::from(0u32),
inputs[1] + Scalar::from(0u32),
);
inputs[0];
cs.constrain(o - outputs[0]);
}
}
#[test]
fn can_use_bulletproofs_contstraints() {
let proof = MulProof::prove(Scalar::from(7u32), Scalar::from(9u32), Scalar::from(63u32));
assert!(proof.verify());
let proof = MulProof::prove(Scalar::from(7u32), Scalar::from(9u32), Scalar::from(64u32));
assert!(!proof.verify());
}

View File

@@ -0,0 +1,2 @@
#[cfg(feature = "bulletproofs")]
pub mod bulletproofs;