mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
19
.vscode/launch.json
vendored
19
.vscode/launch.json
vendored
@@ -304,6 +304,25 @@
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug integration test 'zkp_program_tests'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--test=zkp_program_tests",
|
||||
"--package=sunscreen"
|
||||
],
|
||||
"filter": {
|
||||
"name": "zkp_program_tests",
|
||||
"kind": "test"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
|
||||
274
Cargo.lock
generated
274
Cargo.lock
generated
@@ -79,12 +79,54 @@ version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
|
||||
dependencies = [
|
||||
"block-padding",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block-padding"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae"
|
||||
|
||||
[[package]]
|
||||
name = "bulletproofs"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40e698f1df446cc6246afd823afbe2d121134d089c9102c1dd26d1264991ba32"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"clear_on_drop",
|
||||
"curve25519-dalek-ng",
|
||||
"digest",
|
||||
"merlin",
|
||||
"rand",
|
||||
"rand_core",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"sha3",
|
||||
"subtle-ng",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.2.1"
|
||||
@@ -182,6 +224,15 @@ dependencies = [
|
||||
"os_str_bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clear_on_drop"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38508a63f4979f0048febc9966fadbd48e5dab31fd0ec6a3f151bbf4a74f7423"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.48"
|
||||
@@ -207,6 +258,15 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam"
|
||||
version = "0.8.2"
|
||||
@@ -276,6 +336,29 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curve25519-dalek-ng"
|
||||
version = "4.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c359b7249347e46fb28804470d071c921156ad62b3eef5d34e2ba867533dec8"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"digest",
|
||||
"rand_core",
|
||||
"serde",
|
||||
"subtle-ng",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dot_prod"
|
||||
version = "0.1.0"
|
||||
@@ -427,6 +510,27 @@ dependencies = [
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.0"
|
||||
@@ -594,6 +698,15 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "keccak"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768"
|
||||
dependencies = [
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
@@ -663,6 +776,18 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "merlin"
|
||||
version = "3.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"keccak",
|
||||
"rand_core",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.16"
|
||||
@@ -807,6 +932,12 @@ version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.41"
|
||||
@@ -907,6 +1038,12 @@ version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.47"
|
||||
@@ -925,6 +1062,36 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.5.3"
|
||||
@@ -1149,6 +1316,18 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha3"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"digest",
|
||||
"keccak",
|
||||
"opaque-debug",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.1.0"
|
||||
@@ -1187,6 +1366,12 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "subtle-ng"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "734676eb262c623cec13c3155096e08d1f8f29adce39ba17948b18dad1e54142"
|
||||
|
||||
[[package]]
|
||||
name = "sunscreen"
|
||||
version = "0.7.0"
|
||||
@@ -1200,6 +1385,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sunscreen_backend",
|
||||
"sunscreen_compiler_common",
|
||||
"sunscreen_compiler_macros",
|
||||
"sunscreen_fhe_program",
|
||||
"sunscreen_runtime",
|
||||
@@ -1224,8 +1410,11 @@ name = "sunscreen_compiler_common"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"petgraph",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"semver",
|
||||
"serde",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1235,6 +1424,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde_json",
|
||||
"sunscreen_compiler_common",
|
||||
"syn",
|
||||
]
|
||||
|
||||
@@ -1268,17 +1458,38 @@ dependencies = [
|
||||
"sunscreen_fhe_program",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sunscreen_zkp_backend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bulletproofs",
|
||||
"curve25519-dalek-ng",
|
||||
"merlin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.100"
|
||||
version = "1.0.103"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52205623b1b0f064a4e71182c3b18ae902267282930c6d5462c91b859668426e"
|
||||
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.12.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.3.0"
|
||||
@@ -1308,6 +1519,26 @@ version = "0.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16"
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.6.0"
|
||||
@@ -1397,6 +1628,12 @@ version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.8"
|
||||
@@ -1418,6 +1655,12 @@ dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.3.1"
|
||||
@@ -1435,6 +1678,12 @@ version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.0"
|
||||
@@ -1620,3 +1869,24 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c394b5bd0c6f669e7275d9c20aa90ae064cb22e75a1cad54e1b34088034b149f"
|
||||
dependencies = [
|
||||
"zeroize_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize_derive"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f8f187641dad4f680d25c4bfc4225b418165984179f26ca76ec4fb6441d3a17"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
@@ -17,6 +17,8 @@ members = [
|
||||
"sunscreen_compiler_macros",
|
||||
"sunscreen_fhe_program",
|
||||
"sunscreen_runtime",
|
||||
"sunscreen_compiler_common",
|
||||
"sunscreen_zkp_backend",
|
||||
]
|
||||
exclude = [
|
||||
"mdBook",
|
||||
|
||||
@@ -42,7 +42,7 @@ impl PartialEq for Modulus {
|
||||
* Microsoft SEAL when constructing a SEALContext object. Normal users should not
|
||||
* have to specify the security level explicitly anywhere.
|
||||
*/
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[repr(i32)]
|
||||
pub enum SecurityLevel {
|
||||
/// 128-bit security level according to HomomorphicEncryption.org standard.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use core::hash::Hash;
|
||||
use std::ffi::{c_void, CString};
|
||||
use std::ptr::null_mut;
|
||||
|
||||
@@ -7,7 +8,7 @@ use crate::{bindgen, serialization::CompressionType, Context, FromBytes, ToBytes
|
||||
use serde::ser::Error;
|
||||
use serde::{Serialize, Serializer};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Eq)]
|
||||
/**
|
||||
* Class to store a plaintext element. The data for the plaintext is
|
||||
* a polynomial with coefficients modulo the plaintext modulus. The degree
|
||||
@@ -66,6 +67,15 @@ impl PartialEq for Plaintext {
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for Plaintext {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
for i in 0..self.len() {
|
||||
let c = self.get_coefficient(i);
|
||||
state.write_u64(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Plaintext {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
|
||||
@@ -24,14 +24,16 @@ bumpalo = "3.8.0"
|
||||
log = "0.4.14"
|
||||
num = "0.4.0"
|
||||
petgraph = "0.6.0"
|
||||
sunscreen_compiler_common = { path = "../sunscreen_compiler_common" }
|
||||
sunscreen_compiler_macros = { version = "0.7", path = "../sunscreen_compiler_macros" }
|
||||
sunscreen_backend = { version = "0.7", path = "../sunscreen_backend" }
|
||||
sunscreen_fhe_program = { version = "0.7", path = "../sunscreen_fhe_program" }
|
||||
sunscreen_runtime = { version = "0.7", path = "../sunscreen_runtime" }
|
||||
seal_fhe = { version = "0.7", path = "../seal_fhe" }
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
serde = { version = "1.0.147", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
sunscreen_compiler_common = { path = "../sunscreen_compiler_common" }
|
||||
serde_json = "1.0.72"
|
||||
float-cmp = "0.9.0"
|
||||
|
||||
|
||||
394
sunscreen/src/fhe/mod.rs
Normal file
394
sunscreen/src/fhe/mod.rs
Normal file
@@ -0,0 +1,394 @@
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_backend::compile_inplace;
|
||||
use sunscreen_compiler_common::{
|
||||
Context, EdgeInfo, FrontendCompilation, Operation as OperationTrait,
|
||||
};
|
||||
use sunscreen_fhe_program::{
|
||||
EdgeInfo as FheProgramEdgeInfo, FheProgram, Literal as FheProgramLiteral, NodeInfo,
|
||||
Operation as FheProgramOperation, SchemeType,
|
||||
};
|
||||
use sunscreen_runtime::{InnerPlaintext, Params};
|
||||
|
||||
use std::cell::RefCell;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Hash, Serialize, PartialEq, Eq)]
|
||||
/**
|
||||
* Represents a literal node's data.
|
||||
*/
|
||||
pub enum Literal {
|
||||
/**
|
||||
* An unsigned 64-bit integer.
|
||||
*/
|
||||
U64(u64),
|
||||
|
||||
/**
|
||||
* An encoded plaintext value.
|
||||
*/
|
||||
Plaintext(InnerPlaintext),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Deserialize, Serialize, PartialEq, Eq)]
|
||||
/**
|
||||
* Represents an operation occurring in the frontend AST.
|
||||
*/
|
||||
pub enum FheOperation {
|
||||
/**
|
||||
* This node indicates loading a cipher text from an input.
|
||||
*/
|
||||
InputCiphertext,
|
||||
|
||||
/**
|
||||
* This node indicates loading a plaintext from an input.
|
||||
*/
|
||||
InputPlaintext,
|
||||
|
||||
/**
|
||||
* Addition.
|
||||
*/
|
||||
Add,
|
||||
|
||||
/**
|
||||
* Add a ciphertext and plaintext value.
|
||||
*/
|
||||
AddPlaintext,
|
||||
|
||||
/**
|
||||
* Subtraction.
|
||||
*/
|
||||
Sub,
|
||||
|
||||
/**
|
||||
* Subtract a plaintext.
|
||||
*/
|
||||
SubPlaintext,
|
||||
|
||||
/**
|
||||
* Unary negation (i.e. given x, compute -x)
|
||||
*/
|
||||
Negate,
|
||||
|
||||
/**
|
||||
* Multiplication.
|
||||
*/
|
||||
Multiply,
|
||||
|
||||
/**
|
||||
* Multiply a ciphertext by a plaintext.
|
||||
*/
|
||||
MultiplyPlaintext,
|
||||
|
||||
/**
|
||||
* A literal that serves as an operand to other operations.
|
||||
*/
|
||||
Literal(Literal),
|
||||
|
||||
/**
|
||||
* Rotate left.
|
||||
*/
|
||||
RotateLeft,
|
||||
|
||||
/**
|
||||
* Rotate right.
|
||||
*/
|
||||
RotateRight,
|
||||
|
||||
/**
|
||||
* In the BFV scheme, swap rows in the Batched vectors.
|
||||
*/
|
||||
SwapRows,
|
||||
|
||||
/**
|
||||
* This node indicates the previous node's result should be a result of the [`fhe_program`](crate::fhe_program).
|
||||
*/
|
||||
Output,
|
||||
}
|
||||
|
||||
impl OperationTrait for FheOperation {
|
||||
fn is_binary(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
FheOperation::Add
|
||||
| FheOperation::Multiply
|
||||
| FheOperation::Sub
|
||||
| FheOperation::RotateLeft
|
||||
| FheOperation::RotateRight
|
||||
| FheOperation::SubPlaintext
|
||||
| FheOperation::AddPlaintext
|
||||
| FheOperation::MultiplyPlaintext
|
||||
)
|
||||
}
|
||||
|
||||
fn is_commutative(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
FheOperation::Add
|
||||
| FheOperation::Multiply
|
||||
| FheOperation::AddPlaintext
|
||||
| FheOperation::MultiplyPlaintext
|
||||
)
|
||||
}
|
||||
|
||||
fn is_unary(&self) -> bool {
|
||||
matches!(self, FheOperation::Negate | FheOperation::SwapRows)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The context for constructing the [`fhe_program`](crate::fhe_program) graph during compilation.
|
||||
*
|
||||
* This is an implementation detail of the
|
||||
* [`fhe_program`](crate::fhe_program) macro, and you shouldn't need
|
||||
* to construct one.
|
||||
*/
|
||||
pub type FheContext = Context<FheOperation, Params>;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
pub type FheFrontendCompilation = FrontendCompilation<FheOperation>;
|
||||
|
||||
thread_local! {
|
||||
/**
|
||||
* Contains the graph of a ZKP program during compilation. An
|
||||
* implementation detail and not for public consumption.
|
||||
*/
|
||||
pub static CURRENT_FHE_CTX: RefCell<Option<&'static mut FheContext>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the specified closure, injecting the current
|
||||
* [`fhe_program`](crate::fhe_program) context.
|
||||
*/
|
||||
pub fn with_fhe_ctx<F, R>(f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut FheContext) -> R,
|
||||
{
|
||||
CURRENT_FHE_CTX.with(|ctx| {
|
||||
let mut option = ctx.borrow_mut();
|
||||
let ctx = option
|
||||
.as_mut()
|
||||
.expect("Called Ciphertext::new() outside of a context.");
|
||||
|
||||
f(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines transformations to FHE program graphs.
|
||||
*/
|
||||
pub trait FheContextOps {
|
||||
/**
|
||||
* Add an encrypted input to this context.
|
||||
*/
|
||||
fn add_ciphertext_input(&mut self) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a plaintext input to this context.
|
||||
*/
|
||||
fn add_plaintext_input(&mut self) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Adds a plaintext literal to the
|
||||
* [`fhe_program`](crate::fhe_program) graph.
|
||||
*/
|
||||
fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a subtraction to this context.
|
||||
*/
|
||||
fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a subtraction to this context.
|
||||
*/
|
||||
fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Adds a negation to this context.
|
||||
*/
|
||||
fn add_negate(&mut self, x: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add an addition to this context.
|
||||
*/
|
||||
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Adds an addition to a plaintext.
|
||||
*/
|
||||
fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a multiplication to this context.
|
||||
*/
|
||||
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a multiplication to this context.
|
||||
*/
|
||||
fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Adds a literal to this context.
|
||||
*/
|
||||
fn add_literal(&mut self, literal: Literal) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a rotate left.
|
||||
*/
|
||||
fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a rotate right.
|
||||
*/
|
||||
fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
/**
|
||||
* Adds a row swap.
|
||||
*/
|
||||
fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex;
|
||||
|
||||
/**
|
||||
* Add a node that captures the previous node as an output.
|
||||
*/
|
||||
fn add_output(&mut self, i: NodeIndex) -> NodeIndex;
|
||||
}
|
||||
|
||||
impl FheContextOps for FheContext {
|
||||
fn add_ciphertext_input(&mut self) -> NodeIndex {
|
||||
self.add_node(FheOperation::InputCiphertext)
|
||||
}
|
||||
|
||||
fn add_plaintext_input(&mut self) -> NodeIndex {
|
||||
self.add_node(FheOperation::InputPlaintext)
|
||||
}
|
||||
|
||||
fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex {
|
||||
self.add_node(FheOperation::Literal(Literal::Plaintext(plaintext)))
|
||||
}
|
||||
|
||||
fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::Sub, left, right)
|
||||
}
|
||||
|
||||
fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::SubPlaintext, left, right)
|
||||
}
|
||||
|
||||
fn add_negate(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.add_unary_operation(FheOperation::Negate, x)
|
||||
}
|
||||
|
||||
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::Add, left, right)
|
||||
}
|
||||
|
||||
fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::AddPlaintext, left, right)
|
||||
}
|
||||
|
||||
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::Multiply, left, right)
|
||||
}
|
||||
|
||||
fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::MultiplyPlaintext, left, right)
|
||||
}
|
||||
|
||||
fn add_literal(&mut self, literal: Literal) -> NodeIndex {
|
||||
// See if we already have a node for the given literal. If so, just return it.
|
||||
// If not, make a new one.
|
||||
let existing_literal =
|
||||
self.graph
|
||||
.node_indices()
|
||||
.find(|&i| match &self.graph[i].operation {
|
||||
FheOperation::Literal(x) => *x == literal,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
match existing_literal {
|
||||
Some(x) => x,
|
||||
None => self.add_node(FheOperation::Literal(literal)),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::RotateLeft, left, right)
|
||||
}
|
||||
|
||||
fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(FheOperation::RotateRight, left, right)
|
||||
}
|
||||
|
||||
fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.add_unary_operation(FheOperation::SwapRows, x)
|
||||
}
|
||||
|
||||
fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
|
||||
self.add_unary_operation(FheOperation::Output, i)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extends FheFrontendCompilation to add a backend compilation method.
|
||||
*/
|
||||
pub trait FheCompile {
|
||||
/**
|
||||
* Performs frontend compilation of this intermediate representation into a backend [`FheProgram`],
|
||||
* then perform backend compilation and return the result.
|
||||
*/
|
||||
fn compile(&self) -> FheProgram;
|
||||
}
|
||||
|
||||
impl FheCompile for FheFrontendCompilation {
|
||||
fn compile(&self) -> FheProgram {
|
||||
let mut fhe_program = FheProgram::new(SchemeType::Bfv);
|
||||
|
||||
let mapped_graph = self.0.map(
|
||||
|id, n| match &n.operation {
|
||||
FheOperation::Add => NodeInfo::new(FheProgramOperation::Add),
|
||||
FheOperation::InputCiphertext => {
|
||||
// HACKHACK: Input nodes are always added first to the graph in the order
|
||||
// they're specified as function arguments. We should not depend on this.
|
||||
NodeInfo::new(FheProgramOperation::InputCiphertext(id.index()))
|
||||
}
|
||||
FheOperation::InputPlaintext => {
|
||||
// HACKHACK: Input nodes are always added first to the graph in the order
|
||||
// they're specified as function arguments. We should not depend on this.
|
||||
NodeInfo::new(FheProgramOperation::InputPlaintext(id.index()))
|
||||
}
|
||||
FheOperation::Literal(Literal::U64(x)) => {
|
||||
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::U64(*x)))
|
||||
}
|
||||
FheOperation::Literal(Literal::Plaintext(x)) => {
|
||||
// It's okay to unwrap here because fhe_program compilation will
|
||||
// catch the panic and return a compilation error.
|
||||
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::Plaintext(
|
||||
x.to_bytes().expect("Failed to serialize plaintext."),
|
||||
)))
|
||||
}
|
||||
FheOperation::Sub => NodeInfo::new(FheProgramOperation::Sub),
|
||||
FheOperation::SubPlaintext => NodeInfo::new(FheProgramOperation::SubPlaintext),
|
||||
FheOperation::Negate => NodeInfo::new(FheProgramOperation::Negate),
|
||||
FheOperation::Multiply => NodeInfo::new(FheProgramOperation::Multiply),
|
||||
FheOperation::MultiplyPlaintext => {
|
||||
NodeInfo::new(FheProgramOperation::MultiplyPlaintext)
|
||||
}
|
||||
FheOperation::Output => NodeInfo::new(FheProgramOperation::OutputCiphertext),
|
||||
FheOperation::RotateLeft => NodeInfo::new(FheProgramOperation::ShiftLeft),
|
||||
FheOperation::RotateRight => NodeInfo::new(FheProgramOperation::ShiftRight),
|
||||
FheOperation::SwapRows => NodeInfo::new(FheProgramOperation::SwapRows),
|
||||
FheOperation::AddPlaintext => NodeInfo::new(FheProgramOperation::AddPlaintext),
|
||||
},
|
||||
|_, e| match e {
|
||||
EdgeInfo::Left => FheProgramEdgeInfo::LeftOperand,
|
||||
EdgeInfo::Right => FheProgramEdgeInfo::RightOperand,
|
||||
EdgeInfo::Unary => FheProgramEdgeInfo::UnaryOperand,
|
||||
},
|
||||
);
|
||||
|
||||
fhe_program.graph = mapped_graph;
|
||||
|
||||
compile_inplace(fhe_program)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::fhe::{FheCompile, FheFrontendCompilation};
|
||||
use crate::params::{determine_params, PlainModulusConstraint};
|
||||
use crate::{
|
||||
Application, CallSignature, Error, FheProgramMetadata, FrontendCompilation, Params,
|
||||
RequiredKeys, Result, SchemeType, SecurityLevel,
|
||||
Application, CallSignature, Error, FheProgramMetadata, Params, RequiredKeys, Result,
|
||||
SchemeType, SecurityLevel,
|
||||
};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use sunscreen_runtime::CompiledFheProgram;
|
||||
@@ -24,7 +25,7 @@ pub trait FheProgramFn {
|
||||
/**
|
||||
* Compile the `#[fhe_program]`.
|
||||
*/
|
||||
fn build(&self, params: &Params) -> Result<FrontendCompilation>;
|
||||
fn build(&self, params: &Params) -> Result<FheFrontendCompilation>;
|
||||
|
||||
/**
|
||||
* Get the scheme type.
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
//! # Examples
|
||||
//! This example is further annotated in `examples/simple_multiply`.
|
||||
//! ```
|
||||
//! # use sunscreen::{fhe_program, Compiler, types::{bfv::Signed, Cipher}, PlainModulusConstraint, Params, Runtime, Context};
|
||||
//! # use sunscreen::{fhe_program, Compiler, types::{bfv::Signed, Cipher}, PlainModulusConstraint, Params, Runtime};
|
||||
//!
|
||||
//! #[fhe_program(scheme = "bfv")]
|
||||
//! fn simple_multiply(a: Cipher<Signed>, b: Cipher<Signed>) -> Cipher<Signed> {
|
||||
@@ -39,8 +39,14 @@
|
||||
//!
|
||||
|
||||
mod error;
|
||||
/**
|
||||
* This module contains types used internally when compiling
|
||||
* [`fhe_program`]s.
|
||||
*/
|
||||
pub mod fhe;
|
||||
mod fhe_compiler;
|
||||
mod params;
|
||||
mod zkp;
|
||||
|
||||
/**
|
||||
* This module contains types used during [`fhe_program`] construction.
|
||||
@@ -57,21 +63,13 @@ mod params;
|
||||
*/
|
||||
pub mod types;
|
||||
|
||||
use petgraph::{
|
||||
algo::is_isomorphic_matching,
|
||||
stable_graph::{NodeIndex, StableGraph},
|
||||
Graph,
|
||||
};
|
||||
use fhe::{FheOperation, Literal};
|
||||
use petgraph::stable_graph::StableGraph;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use sunscreen_backend::compile_inplace;
|
||||
use sunscreen_fhe_program::{
|
||||
EdgeInfo, FheProgram, Literal as FheProgramLiteral, NodeInfo, Operation as FheProgramOperation,
|
||||
};
|
||||
|
||||
pub use error::{Error, Result};
|
||||
pub use fhe_compiler::{Compiler, FheProgramFn};
|
||||
pub use params::PlainModulusConstraint;
|
||||
@@ -83,6 +81,8 @@ pub use sunscreen_runtime::{
|
||||
FheProgramInputTrait, FheProgramMetadata, InnerCiphertext, InnerPlaintext, Params, Plaintext,
|
||||
PrivateKey, PublicKey, RequiredKeys, Runtime, WithContext,
|
||||
};
|
||||
pub use zkp::ZkpProgramFn;
|
||||
pub use zkp::{with_zkp_ctx, ZkpContext, ZkpFrontendCompilation, CURRENT_ZKP_CTX};
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
/**
|
||||
@@ -141,98 +141,6 @@ impl Application {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
/**
|
||||
* Represents a literal node's data.
|
||||
*/
|
||||
pub enum Literal {
|
||||
/**
|
||||
* An unsigned 64-bit integer.
|
||||
*/
|
||||
U64(u64),
|
||||
|
||||
/**
|
||||
* An encoded plaintext value.
|
||||
*/
|
||||
Plaintext(InnerPlaintext),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
/**
|
||||
* Represents an operation occurring in the frontend AST.
|
||||
*/
|
||||
pub enum Operation {
|
||||
/**
|
||||
* This node indicates loading a cipher text from an input.
|
||||
*/
|
||||
InputCiphertext,
|
||||
|
||||
/**
|
||||
* This node indicates loading a plaintext from an input.
|
||||
*/
|
||||
InputPlaintext,
|
||||
|
||||
/**
|
||||
* Addition.
|
||||
*/
|
||||
Add,
|
||||
|
||||
/**
|
||||
* Add a ciphertext and plaintext value.
|
||||
*/
|
||||
AddPlaintext,
|
||||
|
||||
/**
|
||||
* Subtraction.
|
||||
*/
|
||||
Sub,
|
||||
|
||||
/**
|
||||
* Subtract a plaintext.
|
||||
*/
|
||||
SubPlaintext,
|
||||
|
||||
/**
|
||||
* Unary negation (i.e. given x, compute -x)
|
||||
*/
|
||||
Negate,
|
||||
|
||||
/**
|
||||
* Multiplication.
|
||||
*/
|
||||
Multiply,
|
||||
|
||||
/**
|
||||
* Multiply a ciphertext by a plaintext.
|
||||
*/
|
||||
MultiplyPlaintext,
|
||||
|
||||
/**
|
||||
* A literal that serves as an operand to other operations.
|
||||
*/
|
||||
Literal(Literal),
|
||||
|
||||
/**
|
||||
* Rotate left.
|
||||
*/
|
||||
RotateLeft,
|
||||
|
||||
/**
|
||||
* Rotate right.
|
||||
*/
|
||||
RotateRight,
|
||||
|
||||
/**
|
||||
* In the BFV scheme, swap rows in the Batched vectors.
|
||||
*/
|
||||
SwapRows,
|
||||
|
||||
/**
|
||||
* This node indicates the previous node's result should be a result of the [`fhe_program`].
|
||||
*/
|
||||
Output,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
/**
|
||||
* Information about an edge in the frontend IR.
|
||||
@@ -277,287 +185,13 @@ pub struct FrontendCompilation {
|
||||
/**
|
||||
* The dependency graph of the frontend's intermediate representation (IR) that backs an [`fhe_program`].
|
||||
*/
|
||||
pub graph: StableGraph<Operation, OperandInfo>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/**
|
||||
* The context for constructing the [`fhe_program`] graph during compilation.
|
||||
*
|
||||
* This is an implementation detail of the [`fhe_program`] macro, and you shouldn't need
|
||||
* to construct one.
|
||||
*/
|
||||
pub struct Context {
|
||||
/**
|
||||
* The frontend compilation result.
|
||||
*/
|
||||
pub compilation: FrontendCompilation,
|
||||
|
||||
/**
|
||||
* The set of parameters for which we're currently constructing the graph.
|
||||
*/
|
||||
pub params: Params,
|
||||
|
||||
/**
|
||||
* Stores indicies for graph nodes in a bump allocator. [`FheProgramNode`](crate::types::intern::FheProgramNode)
|
||||
* can request allocations of these. This allows it to use slices instead of Vecs, which allows
|
||||
* FheProgramNode to impl Copy.
|
||||
*/
|
||||
pub indicies_store: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
impl PartialEq for FrontendCompilation {
|
||||
fn eq(&self, b: &Self) -> bool {
|
||||
is_isomorphic_matching(
|
||||
&Graph::from(self.graph.clone()),
|
||||
&Graph::from(b.graph.clone()),
|
||||
|n1, n2| n1 == n2,
|
||||
|e1, e2| e1 == e2,
|
||||
)
|
||||
}
|
||||
pub graph: StableGraph<FheOperation, OperandInfo>,
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
/**
|
||||
* While constructing an [`fhe_program`], this refers to the current intermediate
|
||||
* representation. An implementation detail of the [`fhe_program`] macro.
|
||||
*/
|
||||
pub static CURRENT_CTX: RefCell<Option<&'static mut Context>> = RefCell::new(None);
|
||||
|
||||
/**
|
||||
* An arena containing slices of indicies. An implementation detail of the
|
||||
* [`fhe_program`] macro.
|
||||
*/
|
||||
pub static INDEX_ARENA: RefCell<bumpalo::Bump> = RefCell::new(bumpalo::Bump::new());
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the specified closure, injecting the current [`fhe_program`] context.
|
||||
*/
|
||||
pub fn with_ctx<F, R>(f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut Context) -> R,
|
||||
{
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
let mut option = ctx.borrow_mut();
|
||||
let ctx = option
|
||||
.as_mut()
|
||||
.expect("Called Ciphertext::new() outside of a context.");
|
||||
|
||||
f(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/**
|
||||
* Creates a new empty frontend intermediate representation context with the given scheme.
|
||||
*/
|
||||
pub fn new(params: &Params) -> Self {
|
||||
Self {
|
||||
compilation: FrontendCompilation {
|
||||
graph: StableGraph::new(),
|
||||
},
|
||||
params: params.clone(),
|
||||
indicies_store: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_2_input(&mut self, op: Operation, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
let new_id = self.compilation.graph.add_node(op);
|
||||
self.compilation
|
||||
.graph
|
||||
.add_edge(left, new_id, OperandInfo::Left);
|
||||
self.compilation
|
||||
.graph
|
||||
.add_edge(right, new_id, OperandInfo::Right);
|
||||
|
||||
new_id
|
||||
}
|
||||
|
||||
fn add_1_input(&mut self, op: Operation, i: NodeIndex) -> NodeIndex {
|
||||
let new_id = self.compilation.graph.add_node(op);
|
||||
self.compilation
|
||||
.graph
|
||||
.add_edge(i, new_id, OperandInfo::Unary);
|
||||
|
||||
new_id
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an input to this context.
|
||||
*/
|
||||
pub fn add_ciphertext_input(&mut self) -> NodeIndex {
|
||||
self.compilation.graph.add_node(Operation::InputCiphertext)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an input to this context.
|
||||
*/
|
||||
pub fn add_plaintext_input(&mut self) -> NodeIndex {
|
||||
self.compilation.graph.add_node(Operation::InputPlaintext)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a plaintext literal to the [`fhe_program`] graph.
|
||||
*/
|
||||
pub fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex {
|
||||
self.compilation
|
||||
.graph
|
||||
.add_node(Operation::Literal(Literal::Plaintext(plaintext)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a subtraction to this context.
|
||||
*/
|
||||
pub fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::Sub, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a subtraction to this context.
|
||||
*/
|
||||
pub fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::SubPlaintext, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a negation to this context.
|
||||
*/
|
||||
pub fn add_negate(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.add_1_input(Operation::Negate, x)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an addition to this context.
|
||||
*/
|
||||
pub fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::Add, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an addition to a plaintext.
|
||||
*/
|
||||
pub fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::AddPlaintext, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a multiplication to this context.
|
||||
*/
|
||||
pub fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::Multiply, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a multiplication to this context.
|
||||
*/
|
||||
pub fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::MultiplyPlaintext, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a literal to this context.
|
||||
*/
|
||||
pub fn add_literal(&mut self, literal: Literal) -> NodeIndex {
|
||||
// See if we already have a node for the given literal. If so, just return it.
|
||||
// If not, make a new one.
|
||||
let existing_literal =
|
||||
self.compilation
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|&i| match &self.compilation.graph[i] {
|
||||
Operation::Literal(x) => *x == literal,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
match existing_literal {
|
||||
Some(x) => x,
|
||||
None => self.compilation.graph.add_node(Operation::Literal(literal)),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a rotate left.
|
||||
*/
|
||||
pub fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::RotateLeft, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a rotate right.
|
||||
*/
|
||||
pub fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_2_input(Operation::RotateRight, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a row swap.
|
||||
*/
|
||||
pub fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.add_1_input(Operation::SwapRows, x)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a node that captures the previous node as an output.
|
||||
*/
|
||||
pub fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
|
||||
self.add_1_input(Operation::Output, i)
|
||||
}
|
||||
}
|
||||
|
||||
impl FrontendCompilation {
|
||||
/**
|
||||
* Performs frontend compilation of this intermediate representation into a backend [`FheProgram`],
|
||||
* then perform backend compilation and return the result.
|
||||
*/
|
||||
pub fn compile(&self) -> FheProgram {
|
||||
let mut fhe_program = FheProgram::new(SchemeType::Bfv);
|
||||
|
||||
let mapped_graph = self.graph.map(
|
||||
|id, n| match n {
|
||||
Operation::Add => NodeInfo::new(FheProgramOperation::Add),
|
||||
Operation::InputCiphertext => {
|
||||
// HACKHACK: Input nodes are always added first to the graph in the order
|
||||
// they're specified as function arguments. We should not depend on this.
|
||||
NodeInfo::new(FheProgramOperation::InputCiphertext(id.index()))
|
||||
}
|
||||
Operation::InputPlaintext => {
|
||||
// HACKHACK: Input nodes are always added first to the graph in the order
|
||||
// they're specified as function arguments. We should not depend on this.
|
||||
NodeInfo::new(FheProgramOperation::InputPlaintext(id.index()))
|
||||
}
|
||||
Operation::Literal(Literal::U64(x)) => {
|
||||
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::U64(*x)))
|
||||
}
|
||||
Operation::Literal(Literal::Plaintext(x)) => {
|
||||
// It's okay to unwrap here because fhe_program compilation will
|
||||
// catch the panic and return a compilation error.
|
||||
NodeInfo::new(FheProgramOperation::Literal(FheProgramLiteral::Plaintext(
|
||||
x.to_bytes().expect("Failed to serialize plaintext."),
|
||||
)))
|
||||
}
|
||||
Operation::Sub => NodeInfo::new(FheProgramOperation::Sub),
|
||||
Operation::SubPlaintext => NodeInfo::new(FheProgramOperation::SubPlaintext),
|
||||
Operation::Negate => NodeInfo::new(FheProgramOperation::Negate),
|
||||
Operation::Multiply => NodeInfo::new(FheProgramOperation::Multiply),
|
||||
Operation::MultiplyPlaintext => {
|
||||
NodeInfo::new(FheProgramOperation::MultiplyPlaintext)
|
||||
}
|
||||
Operation::Output => NodeInfo::new(FheProgramOperation::OutputCiphertext),
|
||||
Operation::RotateLeft => NodeInfo::new(FheProgramOperation::ShiftLeft),
|
||||
Operation::RotateRight => NodeInfo::new(FheProgramOperation::ShiftRight),
|
||||
Operation::SwapRows => NodeInfo::new(FheProgramOperation::SwapRows),
|
||||
Operation::AddPlaintext => NodeInfo::new(FheProgramOperation::AddPlaintext),
|
||||
},
|
||||
|_, e| match e {
|
||||
OperandInfo::Left => EdgeInfo::LeftOperand,
|
||||
OperandInfo::Right => EdgeInfo::RightOperand,
|
||||
OperandInfo::Unary => EdgeInfo::UnaryOperand,
|
||||
},
|
||||
);
|
||||
|
||||
fhe_program.graph = mapped_graph;
|
||||
|
||||
compile_inplace(fhe_program)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Error, FheProgramFn, Result, SecurityLevel};
|
||||
use crate::{fhe::FheCompile, Error, FheProgramFn, Result, SecurityLevel};
|
||||
|
||||
use log::{debug, trace};
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use crate::{
|
||||
fhe::{with_fhe_ctx, FheContextOps, Literal},
|
||||
types::{
|
||||
intern::{Cipher, FheProgramNode},
|
||||
ops::*,
|
||||
BfvType, FheType, LaneCount, NumCiphertexts, SwapRows, TryFromPlaintext, TryIntoPlaintext,
|
||||
Type, TypeName, TypeNameInstance, Version,
|
||||
},
|
||||
with_ctx, FheProgramInputTrait, InnerPlaintext, Literal, Params, Plaintext, WithContext,
|
||||
FheProgramInputTrait, InnerPlaintext, Params, Plaintext, WithContext,
|
||||
};
|
||||
use seal_fhe::{
|
||||
BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus,
|
||||
@@ -509,7 +510,7 @@ impl<const LANES: usize> GraphCipherAdd for Batched<LANES> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -525,7 +526,7 @@ impl<const LANES: usize> GraphCipherSub for Batched<LANES> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -541,7 +542,7 @@ impl<const LANES: usize> GraphCipherMul for Batched<LANES> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -557,8 +558,8 @@ impl<const LANES: usize> GraphCipherConstMul for Batched<LANES> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
let l = ctx.add_plaintext_literal(b.inner);
|
||||
let n = ctx.add_multiplication_plaintext(a.ids[0], l);
|
||||
|
||||
@@ -569,7 +570,7 @@ impl<const LANES: usize> GraphCipherConstMul for Batched<LANES> {
|
||||
|
||||
impl<const LANES: usize> GraphCipherSwapRows for Batched<LANES> {
|
||||
fn graph_cipher_swap_rows(x: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_swap_rows(x.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -582,7 +583,7 @@ impl<const LANES: usize> GraphCipherRotateLeft for Batched<LANES> {
|
||||
x: FheProgramNode<Cipher<Self>>,
|
||||
y: u64,
|
||||
) -> FheProgramNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let y = ctx.add_literal(Literal::U64(y));
|
||||
let n = ctx.add_rotate_left(x.ids[0], y);
|
||||
|
||||
@@ -596,7 +597,7 @@ impl<const LANES: usize> GraphCipherRotateRight for Batched<LANES> {
|
||||
x: FheProgramNode<Cipher<Self>>,
|
||||
y: u64,
|
||||
) -> FheProgramNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let y = ctx.add_literal(Literal::U64(y));
|
||||
let n = ctx.add_rotate_right(x.ids[0], y);
|
||||
|
||||
@@ -609,7 +610,7 @@ impl<const LANES: usize> GraphCipherNeg for Batched<LANES> {
|
||||
type Val = Self;
|
||||
|
||||
fn graph_cipher_neg(x: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_negate(x.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
use seal_fhe::Plaintext as SealPlaintext;
|
||||
|
||||
use crate::types::{
|
||||
ops::{
|
||||
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
|
||||
GraphCipherConstSub, GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd,
|
||||
GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub,
|
||||
GraphPlainCipherSub,
|
||||
use crate::{
|
||||
fhe::{with_fhe_ctx, FheContextOps},
|
||||
types::{
|
||||
ops::{
|
||||
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
|
||||
GraphCipherConstSub, GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd,
|
||||
GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub,
|
||||
GraphPlainCipherSub,
|
||||
},
|
||||
Cipher,
|
||||
},
|
||||
Cipher,
|
||||
};
|
||||
use crate::{
|
||||
types::{intern::FheProgramNode, BfvType, FheType, Type, Version},
|
||||
with_ctx, FheProgramInputTrait, Params, WithContext,
|
||||
FheProgramInputTrait, Params, WithContext,
|
||||
};
|
||||
|
||||
use sunscreen_runtime::{
|
||||
@@ -209,7 +212,7 @@ impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -225,7 +228,7 @@ impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -241,8 +244,8 @@ impl<const INT_BITS: usize> GraphCipherConstAdd for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let n = ctx.add_addition_plaintext(a.ids[0], lit);
|
||||
@@ -260,7 +263,7 @@ impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -276,7 +279,7 @@ impl<const INT_BITS: usize> GraphCipherPlainSub for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -292,7 +295,7 @@ impl<const INT_BITS: usize> GraphPlainCipherSub for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Self::Left>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
|
||||
let n = ctx.add_negate(n);
|
||||
|
||||
@@ -309,8 +312,8 @@ impl<const INT_BITS: usize> GraphCipherConstSub for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let n = ctx.add_subtraction_plaintext(a.ids[0], lit);
|
||||
@@ -328,8 +331,8 @@ impl<const INT_BITS: usize> GraphConstCipherSub for Fractional<INT_BITS> {
|
||||
a: Self::Left,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Right>> {
|
||||
with_ctx(|ctx| {
|
||||
let a = Self::from(a).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let a = Self::from(a).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(a.inner);
|
||||
let n = ctx.add_subtraction_plaintext(b.ids[0], lit);
|
||||
@@ -348,7 +351,7 @@ impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -364,7 +367,7 @@ impl<const INT_BITS: usize> GraphCipherPlainMul for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -380,8 +383,8 @@ impl<const INT_BITS: usize> GraphCipherConstMul for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
|
||||
let n = ctx.add_multiplication_plaintext(a.ids[0], lit);
|
||||
@@ -399,10 +402,10 @@ impl<const INT_BITS: usize> GraphCipherConstDiv for Fractional<INT_BITS> {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: f64,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::try_from(1. / b)
|
||||
.unwrap()
|
||||
.try_into_plaintext(&ctx.params)
|
||||
.try_into_plaintext(&ctx.data)
|
||||
.unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
@@ -418,7 +421,7 @@ impl<const INT_BITS: usize> GraphCipherNeg for Fractional<INT_BITS> {
|
||||
type Val = Fractional<INT_BITS>;
|
||||
|
||||
fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_negate(a.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::fhe::{with_fhe_ctx, FheContextOps};
|
||||
use crate::types::{
|
||||
bfv::Signed, intern::FheProgramNode, ops::*, BfvType, Cipher, FheType, GraphCipherAdd,
|
||||
GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, TryFromPlaintext,
|
||||
TryIntoPlaintext, TypeName,
|
||||
};
|
||||
use crate::{with_ctx, FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
|
||||
use crate::{FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
|
||||
use std::cmp::Eq;
|
||||
use std::ops::*;
|
||||
use sunscreen_runtime::Error;
|
||||
@@ -274,7 +275,7 @@ impl GraphCipherAdd for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
|
||||
let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
|
||||
@@ -297,7 +298,7 @@ impl GraphCipherPlainAdd for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
|
||||
let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
|
||||
@@ -320,14 +321,14 @@ impl GraphCipherConstAdd for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::try_from(b).unwrap();
|
||||
|
||||
let b_num =
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
let b_den =
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b_den);
|
||||
@@ -351,7 +352,7 @@ impl GraphCipherSub for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
|
||||
let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
|
||||
@@ -374,7 +375,7 @@ impl GraphCipherPlainSub for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
|
||||
let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
|
||||
@@ -397,7 +398,7 @@ impl GraphPlainCipherSub for Rational {
|
||||
a: FheProgramNode<Self::Left>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
|
||||
let num_b_2 = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
|
||||
@@ -420,13 +421,13 @@ impl GraphCipherConstSub for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::try_from(b).unwrap();
|
||||
|
||||
let b_num =
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
let b_den =
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b_den);
|
||||
@@ -450,13 +451,13 @@ impl GraphConstCipherSub for Rational {
|
||||
a: Self::Left,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Right>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let a = Self::try_from(a).unwrap();
|
||||
|
||||
let a_num =
|
||||
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
let a_den =
|
||||
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
// Scale each numinator by the other's denominator.
|
||||
let num_b_2 = ctx.add_multiplication_plaintext(b.ids[0], a_den);
|
||||
@@ -480,7 +481,7 @@ impl GraphCipherMul for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
let mul_den = ctx.add_multiplication(a.ids[1], b.ids[1]);
|
||||
@@ -500,7 +501,7 @@ impl GraphCipherPlainMul for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
|
||||
let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
|
||||
@@ -520,13 +521,13 @@ impl GraphCipherConstMul for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::try_from(b).unwrap();
|
||||
|
||||
let num_b =
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
let den_b =
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], num_b);
|
||||
@@ -547,7 +548,7 @@ impl GraphCipherDiv for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication(a.ids[0], b.ids[1]);
|
||||
let mul_den = ctx.add_multiplication(a.ids[1], b.ids[0]);
|
||||
@@ -567,7 +568,7 @@ impl GraphCipherPlainDiv for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
|
||||
let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
|
||||
@@ -587,7 +588,7 @@ impl GraphPlainCipherDiv for Rational {
|
||||
a: FheProgramNode<Self::Left>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
|
||||
let mul_den = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
|
||||
@@ -607,13 +608,13 @@ impl GraphCipherConstDiv for Rational {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::try_from(b).unwrap();
|
||||
|
||||
let num_b =
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.num.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
let den_b =
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(b.den.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication_plaintext(a.ids[0], den_b);
|
||||
@@ -634,13 +635,13 @@ impl GraphConstCipherDiv for Rational {
|
||||
a: Self::Left,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Right>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let a = Self::try_from(a).unwrap();
|
||||
|
||||
let num_a =
|
||||
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(a.num.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
let den_a =
|
||||
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.params).unwrap().inner);
|
||||
ctx.add_plaintext_literal(a.den.try_into_plaintext(&ctx.data).unwrap().inner);
|
||||
|
||||
// Scale each numinator by the other's denominator.
|
||||
let mul_num = ctx.add_multiplication_plaintext(b.ids[1], num_a);
|
||||
@@ -657,7 +658,7 @@ impl GraphCipherNeg for Rational {
|
||||
type Val = Self;
|
||||
|
||||
fn graph_cipher_neg(a: FheProgramNode<Cipher<Self::Val>>) -> FheProgramNode<Cipher<Self::Val>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let neg = ctx.add_negate(a.ids[0]);
|
||||
let ids = [neg, a.ids[1]];
|
||||
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
use seal_fhe::Plaintext as SealPlaintext;
|
||||
|
||||
use crate::types::{
|
||||
ops::{
|
||||
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstMul, GraphCipherConstSub,
|
||||
GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd, GraphCipherPlainMul,
|
||||
GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub, GraphPlainCipherSub,
|
||||
use crate::{
|
||||
fhe::{with_fhe_ctx, FheContextOps},
|
||||
types::{
|
||||
ops::{
|
||||
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstMul, GraphCipherConstSub,
|
||||
GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd, GraphCipherPlainMul,
|
||||
GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub, GraphPlainCipherSub,
|
||||
},
|
||||
Cipher,
|
||||
},
|
||||
Cipher,
|
||||
};
|
||||
use crate::{
|
||||
types::{intern::FheProgramNode, BfvType, FheType, TypeNameInstance},
|
||||
with_ctx, FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext,
|
||||
FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext,
|
||||
};
|
||||
|
||||
use sunscreen_runtime::{
|
||||
@@ -249,7 +252,7 @@ impl GraphCipherAdd for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -265,7 +268,7 @@ impl GraphCipherPlainAdd for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -281,8 +284,8 @@ impl GraphCipherConstAdd for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: i64,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let add = ctx.add_addition_plaintext(a.ids[0], lit);
|
||||
@@ -300,7 +303,7 @@ impl GraphCipherSub for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -316,7 +319,7 @@ impl GraphCipherPlainSub for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -332,7 +335,7 @@ impl GraphPlainCipherSub for Signed {
|
||||
a: FheProgramNode<Self::Left>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
|
||||
let n = ctx.add_negate(n);
|
||||
|
||||
@@ -349,8 +352,8 @@ impl GraphCipherConstSub for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: Self::Right,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let n = ctx.add_subtraction_plaintext(a.ids[0], lit);
|
||||
@@ -368,8 +371,8 @@ impl GraphConstCipherSub for Signed {
|
||||
a: i64,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Right>> {
|
||||
with_ctx(|ctx| {
|
||||
let a = Self::from(a).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let a = Self::from(a).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(a.inner);
|
||||
let n = ctx.add_subtraction_plaintext(b.ids[0], lit);
|
||||
@@ -384,7 +387,7 @@ impl GraphCipherNeg for Signed {
|
||||
type Val = Signed;
|
||||
|
||||
fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_negate(a.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -400,7 +403,7 @@ impl GraphCipherMul for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Cipher<Self::Right>>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
@@ -416,8 +419,8 @@ impl GraphCipherConstMul for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: i64,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.params).unwrap();
|
||||
with_fhe_ctx(|ctx| {
|
||||
let b = Self::from(b).try_into_plaintext(&ctx.data).unwrap();
|
||||
|
||||
let lit = ctx.add_plaintext_literal(b.inner);
|
||||
let add = ctx.add_multiplication_plaintext(a.ids[0], lit);
|
||||
@@ -435,7 +438,7 @@ impl GraphCipherPlainMul for Signed {
|
||||
a: FheProgramNode<Cipher<Self::Left>>,
|
||||
b: FheProgramNode<Self::Right>,
|
||||
) -> FheProgramNode<Cipher<Self::Left>> {
|
||||
with_ctx(|ctx| {
|
||||
with_fhe_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
|
||||
|
||||
FheProgramNode::new(&[n])
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::{
|
||||
fhe::with_fhe_ctx,
|
||||
types::{
|
||||
intern::FheLiteral, ops::*, Cipher, FheType, LaneCount, NumCiphertexts, SwapRows, Type,
|
||||
TypeName,
|
||||
},
|
||||
with_ctx, INDEX_ARENA,
|
||||
INDEX_ARENA,
|
||||
};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
@@ -36,8 +37,8 @@ use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
* construction.
|
||||
*
|
||||
* # Undefined behavior
|
||||
* These types must be constructed while a [`crate::CURRENT_CTX`] refers to a valid
|
||||
* [`crate::Context`]. Furthermore, no [`FheProgramNode`] should outlive the said context.
|
||||
* These types must be constructed while [`CURRENT_FHE_CTX`][crate::fhe::CURRENT_FHE_CTX] refers to a valid
|
||||
* [`FheContext`](crate::fhe::FheContext). Furthermore, no [`FheProgramNode`] should outlive the said context.
|
||||
* Violating any of these conditions may result in memory corruption or
|
||||
* use-after-free.
|
||||
*/
|
||||
@@ -93,7 +94,7 @@ impl<T: NumCiphertexts> FheProgramNode<T> {
|
||||
* Returns the plain modulus parameter for the given BFV scheme
|
||||
*/
|
||||
pub fn get_plain_modulus() -> u64 {
|
||||
with_ctx(|ctx| ctx.params.plain_modulus)
|
||||
with_fhe_ctx(|ctx| ctx.data.plain_modulus)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
pub use crate::{
|
||||
types::{intern::FheProgramNode, Cipher, FheType, NumCiphertexts, TypeName},
|
||||
with_ctx,
|
||||
};
|
||||
use crate::fhe::{with_fhe_ctx, FheContextOps};
|
||||
pub use crate::types::{intern::FheProgramNode, Cipher, FheType, NumCiphertexts, TypeName};
|
||||
|
||||
/**
|
||||
* Create an input node from an Fhe Program input argument.
|
||||
@@ -18,7 +16,8 @@ pub trait Input {
|
||||
* You should not call this, but rather allow the [`fhe_program`](crate::fhe_program) macro to do this on your behalf.
|
||||
*
|
||||
* # Undefined behavior
|
||||
* This type references memory in a backing [`crate::Context`] and without carefully ensuring FheProgramNodes
|
||||
* This type references memory in a backing
|
||||
* [`FheContext`](crate::fhe::FheContext) and without carefully ensuring FheProgramNodes
|
||||
* never outlive the backing context, use-after-free can occur.
|
||||
*
|
||||
*/
|
||||
@@ -36,9 +35,9 @@ where
|
||||
|
||||
for _ in 0..T::NUM_CIPHERTEXTS {
|
||||
if T::type_name().is_encrypted {
|
||||
ids.push(with_ctx(|ctx| ctx.add_ciphertext_input()));
|
||||
ids.push(with_fhe_ctx(|ctx| ctx.add_ciphertext_input()));
|
||||
} else {
|
||||
ids.push(with_ctx(|ctx| ctx.add_plaintext_input()));
|
||||
ids.push(with_fhe_ctx(|ctx| ctx.add_plaintext_input()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,16 +68,17 @@ where
|
||||
#[test]
|
||||
fn can_create_inputs() {
|
||||
use crate::{
|
||||
fhe::{FheContext, FheOperation, CURRENT_FHE_CTX},
|
||||
types::{bfv::Rational, intern::FheProgramNode},
|
||||
Context, Operation, Params, SchemeType, SecurityLevel, CURRENT_CTX,
|
||||
Params, SchemeType, SecurityLevel,
|
||||
};
|
||||
use std::cell::RefCell;
|
||||
use std::mem::transmute;
|
||||
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
let mut context = Context::new(&Params {
|
||||
CURRENT_FHE_CTX.with(|ctx| {
|
||||
let mut context = FheContext::new(Params {
|
||||
lattice_dimension: 0,
|
||||
coeff_modulus: vec![],
|
||||
plain_modulus: 0,
|
||||
@@ -122,26 +122,26 @@ fn can_create_inputs() {
|
||||
|
||||
offset += 2 * 6 * 6;
|
||||
|
||||
assert_eq!(context.compilation.graph.node_count(), offset);
|
||||
assert_eq!(context.graph.node_count(), offset);
|
||||
|
||||
for i in 0..2 {
|
||||
assert_eq!(
|
||||
context.compilation.graph[NodeIndex::from(i)],
|
||||
Operation::InputPlaintext
|
||||
context.graph[NodeIndex::from(i)].operation,
|
||||
FheOperation::InputPlaintext
|
||||
);
|
||||
}
|
||||
|
||||
for i in 2..14 {
|
||||
assert_eq!(
|
||||
context.compilation.graph[NodeIndex::from(i)],
|
||||
Operation::InputPlaintext
|
||||
context.graph[NodeIndex::from(i)].operation,
|
||||
FheOperation::InputPlaintext
|
||||
);
|
||||
}
|
||||
|
||||
for i in 14..context.compilation.graph.node_count() {
|
||||
for i in 14..context.graph.node_count() {
|
||||
assert_eq!(
|
||||
context.compilation.graph[NodeIndex::from(i as u32)],
|
||||
Operation::InputCiphertext
|
||||
context.graph[NodeIndex::from(i as u32)].operation,
|
||||
FheOperation::InputCiphertext
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
fhe::{with_fhe_ctx, FheContextOps},
|
||||
types::{intern::FheProgramNode, NumCiphertexts},
|
||||
with_ctx,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -18,8 +18,10 @@ pub trait Output {
|
||||
* You should not call this, but rather allow the [`fhe_program`](crate::fhe_program) macro to do this on your behalf.
|
||||
*
|
||||
* # Undefined behavior
|
||||
* This type references memory in a backing [`crate::Context`] and without carefully ensuring FheProgramNodes
|
||||
* never outlive the backing context, use-after-free can occur.
|
||||
* This type references memory in a backing
|
||||
* [`FheContext`](crate::fhe::FheContext) and without carefully
|
||||
* ensuring FheProgramNodes never outlive the backing context,
|
||||
* use-after-free can occur.
|
||||
*/
|
||||
fn output(&self) -> Self::Output;
|
||||
}
|
||||
@@ -34,7 +36,7 @@ where
|
||||
let mut ids = Vec::with_capacity(self.ids.len());
|
||||
|
||||
for i in 0..self.ids.len() {
|
||||
ids.push(with_ctx(|ctx| ctx.add_output(self.ids[i])));
|
||||
ids.push(with_fhe_ctx(|ctx| ctx.add_output(self.ids[i])));
|
||||
}
|
||||
|
||||
FheProgramNode::new(&ids)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::{with_ctx, Literal};
|
||||
use crate::{
|
||||
fhe::{with_fhe_ctx, FheContextOps},
|
||||
Literal,
|
||||
};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -13,6 +16,6 @@ impl U64LiteralRef {
|
||||
* graph, a reference to the existing literal is returned.
|
||||
*/
|
||||
pub fn node(val: u64) -> NodeIndex {
|
||||
with_ctx(|ctx| ctx.add_literal(Literal::U64(val)))
|
||||
with_fhe_ctx(|ctx| ctx.add_literal(Literal::U64(val)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +69,11 @@ pub mod intern;
|
||||
*/
|
||||
mod ops;
|
||||
|
||||
/**
|
||||
* Contains types used in creating zero-knowledge proof R1CS circuits.
|
||||
*/
|
||||
pub mod zkp;
|
||||
|
||||
use crate::types::ops::*;
|
||||
|
||||
pub use sunscreen_runtime::{
|
||||
|
||||
99
sunscreen/src/types/zkp/mod.rs
Normal file
99
sunscreen/src/types/zkp/mod.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
mod native_field;
|
||||
mod program_node;
|
||||
|
||||
pub use native_field::*;
|
||||
pub use program_node::*;
|
||||
|
||||
/**
|
||||
* A trait for adding two ZKP values together
|
||||
*/
|
||||
pub trait AddVar
|
||||
where
|
||||
Self: Sized + ZkpType,
|
||||
{
|
||||
/**
|
||||
* Add the 2 values.
|
||||
*/
|
||||
fn add(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait for multiplying two ZKP values together
|
||||
*/
|
||||
pub trait MulVar
|
||||
where
|
||||
Self: Sized + ZkpType,
|
||||
{
|
||||
/**
|
||||
* Compute lhs * rhs.
|
||||
*/
|
||||
fn mul(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait for dividing 2 zkp values.
|
||||
*/
|
||||
pub trait DivVar
|
||||
where
|
||||
Self: Sized + ZkpType,
|
||||
{
|
||||
/**
|
||||
* Compute lhs / rhs.
|
||||
*/
|
||||
fn div(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait for computing the Remainder of 2 zkp values.
|
||||
*/
|
||||
pub trait RemVar
|
||||
where
|
||||
Self: Sized + ZkpType,
|
||||
{
|
||||
/**
|
||||
* Compute lhs % rhs;
|
||||
*/
|
||||
fn rem(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait for subtracting 2 zkp values.
|
||||
*/
|
||||
pub trait SubVar
|
||||
where
|
||||
Self: Sized + ZkpType,
|
||||
{
|
||||
/**
|
||||
* Compute lhs - rhs.
|
||||
*/
|
||||
fn sub(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait for computing the additive inverse of a zkp value.
|
||||
*/
|
||||
pub trait NegVar
|
||||
where
|
||||
Self: Sized + ZkpType,
|
||||
{
|
||||
/**
|
||||
* Compute -lhs.
|
||||
*/
|
||||
fn neg(lhs: ProgramNode<Self>) -> ProgramNode<Self>;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of native field elements needed to represent a ZKP type.
|
||||
*/
|
||||
pub trait NumFieldElements {
|
||||
/**
|
||||
* The number of native field elements needed to represent this type.
|
||||
*/
|
||||
const NUM_NATIVE_FIELD_ELEMENTS: usize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Encapsulates all the traits required for a type to be used in ZKP
|
||||
* programs.
|
||||
*/
|
||||
pub trait ZkpType: NumFieldElements {}
|
||||
54
sunscreen/src/types/zkp/native_field.rs
Normal file
54
sunscreen/src/types/zkp/native_field.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use sunscreen_compiler_macros::TypeName;
|
||||
|
||||
use crate::{
|
||||
types::zkp::{AddVar, ProgramNode},
|
||||
zkp::{with_zkp_ctx, ZkpContextOps},
|
||||
};
|
||||
|
||||
use super::{MulVar, NegVar, NumFieldElements, ZkpType};
|
||||
|
||||
// Shouldn't need Clone + Copy, but there appears to be a bug in the Rust
|
||||
// compiler that prevents ProgramNode from being Copy if we don't.
|
||||
// https://github.com/rust-lang/rust/issues/104264
|
||||
#[derive(Clone, Copy, TypeName)]
|
||||
/**
|
||||
* The native field type in the underlying backend proof system. For
|
||||
* example, in Bulletproofs, this is [`Scalar`](https://docs.rs/curve25519-dalek-ng/4.1.1/curve25519_dalek_ng/scalar/struct.Scalar.html).
|
||||
*/
|
||||
pub struct NativeField {}
|
||||
|
||||
impl NumFieldElements for NativeField {
|
||||
const NUM_NATIVE_FIELD_ELEMENTS: usize = 1;
|
||||
}
|
||||
|
||||
impl ZkpType for NativeField {}
|
||||
|
||||
impl AddVar for NativeField {
|
||||
fn add(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
|
||||
with_zkp_ctx(|ctx| {
|
||||
let o = ctx.add_addition(lhs.ids[0], rhs.ids[0]);
|
||||
|
||||
ProgramNode::new(&[o])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MulVar for NativeField {
|
||||
fn mul(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
|
||||
with_zkp_ctx(|ctx| {
|
||||
let o = ctx.add_multiplication(lhs.ids[0], rhs.ids[0]);
|
||||
|
||||
ProgramNode::new(&[o])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl NegVar for NativeField {
|
||||
fn neg(lhs: ProgramNode<Self>) -> ProgramNode<Self> {
|
||||
with_zkp_ctx(|ctx| {
|
||||
let o = ctx.add_negate(lhs.ids[0]);
|
||||
|
||||
ProgramNode::new(&[o])
|
||||
})
|
||||
}
|
||||
}
|
||||
141
sunscreen/src/types/zkp/program_node.rs
Normal file
141
sunscreen/src/types/zkp/program_node.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
ops::{Add, Div, Mul, Neg, Rem, Sub},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
types::zkp::{AddVar, DivVar, MulVar, NegVar, RemVar, SubVar, ZkpType},
|
||||
zkp::{with_zkp_ctx, ZkpContextOps},
|
||||
INDEX_ARENA,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
/**
|
||||
* An implementation detail of the ZKP compiler. Each expression in a ZKP
|
||||
* program is expressed in terms of `ProgramNode`, which proxy and compose
|
||||
* the parse graph for a ZKP program.
|
||||
*
|
||||
* They proxy operations (+, -, /, etc) to their underlying type T to
|
||||
* manipulate the program graph as appropriate.
|
||||
*
|
||||
* # Remarks
|
||||
* For internal use only.
|
||||
*/
|
||||
pub struct ProgramNode<T>
|
||||
where
|
||||
T: ZkpType,
|
||||
{
|
||||
/**
|
||||
* The indices in the graph that compose the type backing this
|
||||
* `ProgramNode`.
|
||||
*/
|
||||
pub ids: &'static [NodeIndex],
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> ProgramNode<T>
|
||||
where
|
||||
T: ZkpType,
|
||||
{
|
||||
/**
|
||||
* Create a new Program node from the given indicies in the
|
||||
*/
|
||||
pub fn new(ids: &[NodeIndex]) -> Self {
|
||||
INDEX_ARENA.with(|allocator| {
|
||||
let allocator = allocator.borrow();
|
||||
let ids_dest = allocator.alloc_slice_copy(ids);
|
||||
|
||||
ids_dest.copy_from_slice(ids);
|
||||
|
||||
// The memory in the bump allocator is valid until we call reset, which
|
||||
// we do after creating the FHE program. At this time, no FheProgramNodes should
|
||||
// remain.
|
||||
// We invoke the dark transmutation ritual to turn a finite lifetime into a 'static.
|
||||
Self {
|
||||
ids: unsafe { std::mem::transmute(ids_dest) },
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a public program input of type T.
|
||||
*/
|
||||
pub fn input() -> Self {
|
||||
let mut ids = Vec::with_capacity(T::NUM_NATIVE_FIELD_ELEMENTS);
|
||||
|
||||
for _ in 0..T::NUM_NATIVE_FIELD_ELEMENTS {
|
||||
ids.push(with_zkp_ctx(|ctx| ctx.add_public_input()));
|
||||
}
|
||||
|
||||
Self::new(&ids)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Add<ProgramNode<T>> for ProgramNode<T>
|
||||
where
|
||||
T: AddVar + ZkpType,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
<T as AddVar>::add(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Mul<ProgramNode<T>> for ProgramNode<T>
|
||||
where
|
||||
T: MulVar + ZkpType,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
<T as MulVar>::mul(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Div<ProgramNode<T>> for ProgramNode<T>
|
||||
where
|
||||
T: DivVar + ZkpType,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
<T as DivVar>::div(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Rem<ProgramNode<T>> for ProgramNode<T>
|
||||
where
|
||||
T: RemVar + ZkpType,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn rem(self, rhs: Self) -> Self::Output {
|
||||
<T as RemVar>::rem(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Sub<ProgramNode<T>> for ProgramNode<T>
|
||||
where
|
||||
T: SubVar + ZkpType,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
<T as SubVar>::sub(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Neg for ProgramNode<T>
|
||||
where
|
||||
T: NegVar + ZkpType,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
<T as NegVar>::neg(self)
|
||||
}
|
||||
}
|
||||
167
sunscreen/src/zkp/mod.rs
Normal file
167
sunscreen/src/zkp/mod.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use sunscreen_runtime::CallSignature;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
use std::cell::RefCell;
|
||||
|
||||
/**
|
||||
* An internal representation of a ZKP program specification.
|
||||
*/
|
||||
pub trait ZkpProgramFn {
|
||||
/**
|
||||
* Create a circuit from this specification.
|
||||
*/
|
||||
fn build(&self) -> Result<ZkpFrontendCompilation>;
|
||||
|
||||
/**
|
||||
* Gets the call signature for this program.
|
||||
*/
|
||||
fn signature(&self) -> CallSignature;
|
||||
|
||||
/**
|
||||
* Gets the name of this program.
|
||||
*/
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use sunscreen_compiler_common::{
|
||||
Context, FrontendCompilation, Operation as OperationTrait, Render,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
|
||||
pub enum Operation {
|
||||
PrivateInput(NodeIndex),
|
||||
PublicInput(NodeIndex),
|
||||
HiddenInput(NodeIndex),
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
Neg,
|
||||
}
|
||||
|
||||
impl OperationTrait for Operation {
|
||||
fn is_binary(&self) -> bool {
|
||||
matches!(self, Operation::Add | Operation::Mul | Operation::Sub)
|
||||
}
|
||||
|
||||
fn is_commutative(&self) -> bool {
|
||||
matches!(self, Operation::Add | Operation::Mul)
|
||||
}
|
||||
|
||||
fn is_unary(&self) -> bool {
|
||||
matches!(self, Operation::Neg)
|
||||
}
|
||||
}
|
||||
|
||||
impl Operation {
|
||||
pub fn is_add(&self) -> bool {
|
||||
matches!(self, Operation::Add)
|
||||
}
|
||||
|
||||
pub fn is_sub(&self) -> bool {
|
||||
matches!(self, Operation::Sub)
|
||||
}
|
||||
|
||||
pub fn is_mul(&self) -> bool {
|
||||
matches!(self, Operation::Mul)
|
||||
}
|
||||
|
||||
pub fn is_neg(&self) -> bool {
|
||||
matches!(self, Operation::Neg)
|
||||
}
|
||||
|
||||
pub fn is_private_input(&self) -> bool {
|
||||
matches!(self, Operation::PrivateInput(_))
|
||||
}
|
||||
|
||||
pub fn is_public_input(&self) -> bool {
|
||||
matches!(self, Operation::PublicInput(_))
|
||||
}
|
||||
|
||||
pub fn is_hidden_input(&self) -> bool {
|
||||
matches!(self, Operation::HiddenInput(_))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An implementation detail of a ZKP program. During compilation, it holds
|
||||
* the graph of the program currently being constructed in an
|
||||
* [`#[zkp_program]`](crate::zkp_program) function.
|
||||
*
|
||||
* # Remarks
|
||||
* For internal use only.
|
||||
*/
|
||||
pub type ZkpContext = Context<Operation, u32>;
|
||||
/**
|
||||
* Contains the results of compiling a [`#[zkp_program]`](crate::zkp_program) function.
|
||||
*
|
||||
* # Remarks
|
||||
* For internal use only.
|
||||
*/
|
||||
pub type ZkpFrontendCompilation = FrontendCompilation<Operation>;
|
||||
|
||||
pub trait ZkpContextOps {
|
||||
fn add_public_input(&mut self) -> NodeIndex;
|
||||
|
||||
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
|
||||
|
||||
fn add_negate(&mut self, left: NodeIndex) -> NodeIndex;
|
||||
}
|
||||
|
||||
impl ZkpContextOps for ZkpContext {
|
||||
fn add_public_input(&mut self) -> NodeIndex {
|
||||
let node = self.add_node(Operation::PublicInput(NodeIndex::from(self.data)));
|
||||
self.data += 1;
|
||||
|
||||
node
|
||||
}
|
||||
|
||||
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(Operation::Add, left, right)
|
||||
}
|
||||
|
||||
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
self.add_binary_operation(Operation::Mul, left, right)
|
||||
}
|
||||
|
||||
fn add_negate(&mut self, left: NodeIndex) -> NodeIndex {
|
||||
self.add_unary_operation(Operation::Neg, left)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for Operation {
|
||||
fn render(&self) -> String {
|
||||
format!("{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
/**
|
||||
* Contains the graph of a ZKP program during compilation. An
|
||||
* implementation detail and not for public consumption.
|
||||
*/
|
||||
pub static CURRENT_ZKP_CTX: RefCell<Option<&'static mut ZkpContext>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the specified closure, injecting the current
|
||||
* [`fhe_program`](crate::fhe_program) context.
|
||||
*/
|
||||
pub fn with_zkp_ctx<F, R>(f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut ZkpContext) -> R,
|
||||
{
|
||||
CURRENT_ZKP_CTX.with(|ctx| {
|
||||
let mut option = ctx.borrow_mut();
|
||||
let ctx = option
|
||||
.as_mut()
|
||||
.expect("Called with_zkp_ctx() outside of a context.");
|
||||
|
||||
f(ctx)
|
||||
})
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
use sunscreen::{
|
||||
fhe::{FheFrontendCompilation, CURRENT_FHE_CTX},
|
||||
fhe_program,
|
||||
types::{bfv::Signed, Cipher, TypeName},
|
||||
CallSignature, FheProgramFn, FrontendCompilation, Params, SchemeType, SecurityLevel,
|
||||
CURRENT_CTX,
|
||||
CallSignature, FheProgramFn, Params, SchemeType, SecurityLevel,
|
||||
};
|
||||
|
||||
use serde_json::json;
|
||||
@@ -46,7 +46,7 @@ fn fhe_program_gets_called() {
|
||||
fn panicing_fhe_program_clears_ctx() {
|
||||
#[fhe_program(scheme = "bfv")]
|
||||
fn panic_fhe_program() {
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
CURRENT_FHE_CTX.with(|ctx| {
|
||||
let old = ctx.take();
|
||||
|
||||
assert!(old.is_some());
|
||||
@@ -71,7 +71,7 @@ fn panicing_fhe_program_clears_ctx() {
|
||||
|
||||
assert!(panic_result.is_err());
|
||||
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
CURRENT_FHE_CTX.with(|ctx| {
|
||||
let old = ctx.take();
|
||||
|
||||
assert!(old.is_none());
|
||||
@@ -102,7 +102,7 @@ fn capture_fhe_program_input_args() {
|
||||
|
||||
let context = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
|
||||
assert_eq!(context.graph.node_count(), 4);
|
||||
assert_eq!(context.0.node_count(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -122,48 +122,45 @@ fn can_add() {
|
||||
assert_eq!(fhe_program_with_args.signature(), expected_signature);
|
||||
assert_eq!(fhe_program_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let context: FrontendCompilation = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
let context = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"Add",
|
||||
"Add"
|
||||
"nodes": [
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "Add" },
|
||||
{ "operation": "Add" }
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
0,
|
||||
3,
|
||||
"Left"
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
0,
|
||||
3,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
3,
|
||||
4,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
"Right"
|
||||
]
|
||||
[
|
||||
1,
|
||||
3,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
3,
|
||||
4,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
"Right"
|
||||
]
|
||||
},
|
||||
]
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
context,
|
||||
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
|
||||
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -182,36 +179,33 @@ fn can_add_plaintext() {
|
||||
assert_eq!(fhe_program_with_args.signature(), expected_signature);
|
||||
assert_eq!(fhe_program_with_args.scheme_type(), SchemeType::Bfv);
|
||||
|
||||
let context: FrontendCompilation = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
let context = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputPlaintext",
|
||||
"AddPlaintext",
|
||||
"nodes": [
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputPlaintext" },
|
||||
{ "operation": "AddPlaintext" },
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
0,
|
||||
2,
|
||||
"Left"
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
0,
|
||||
2,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
1,
|
||||
2,
|
||||
"Right"
|
||||
],
|
||||
]
|
||||
},
|
||||
[
|
||||
1,
|
||||
2,
|
||||
"Right"
|
||||
],
|
||||
]
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
context,
|
||||
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
|
||||
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -235,44 +229,42 @@ fn can_mul() {
|
||||
let context = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"Multiply",
|
||||
"Multiply"
|
||||
"nodes": [
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "Multiply" },
|
||||
{ "operation": "Multiply" }
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
0,
|
||||
3,
|
||||
"Left"
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
"edges": [
|
||||
[
|
||||
0,
|
||||
3,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
3,
|
||||
4,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
"Right"
|
||||
]
|
||||
[
|
||||
1,
|
||||
3,
|
||||
"Right"
|
||||
],
|
||||
[
|
||||
3,
|
||||
4,
|
||||
"Left"
|
||||
],
|
||||
[
|
||||
2,
|
||||
4,
|
||||
"Right"
|
||||
]
|
||||
},
|
||||
]
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
context,
|
||||
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
|
||||
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -296,13 +288,12 @@ fn can_collect_output() {
|
||||
let context = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"Multiply",
|
||||
"Add",
|
||||
"Output"
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "Multiply" },
|
||||
{ "operation": "Add" },
|
||||
{ "operation": "Output" },
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
@@ -333,12 +324,13 @@ fn can_collect_output() {
|
||||
"Unary"
|
||||
]
|
||||
]
|
||||
},
|
||||
});
|
||||
|
||||
dbg!(serde_json::to_string(&context).unwrap());
|
||||
|
||||
assert_eq!(
|
||||
context,
|
||||
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
|
||||
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -365,14 +357,13 @@ fn can_collect_multiple_outputs() {
|
||||
let context = fhe_program_with_args.build(&get_params()).unwrap();
|
||||
|
||||
let expected = json!({
|
||||
"graph": {
|
||||
"nodes": [
|
||||
"InputCiphertext",
|
||||
"InputCiphertext",
|
||||
"Multiply",
|
||||
"Add",
|
||||
"Output",
|
||||
"Output"
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "InputCiphertext" },
|
||||
{ "operation": "Multiply" },
|
||||
{ "operation": "Add" },
|
||||
{ "operation": "Output" },
|
||||
{ "operation": "Output" },
|
||||
],
|
||||
"node_holes": [],
|
||||
"edge_property": "directed",
|
||||
@@ -408,11 +399,10 @@ fn can_collect_multiple_outputs() {
|
||||
"Unary"
|
||||
]
|
||||
]
|
||||
},
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
context,
|
||||
serde_json::from_value::<FrontendCompilation>(expected).unwrap()
|
||||
serde_json::from_value::<FheFrontendCompilation>(expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
11
sunscreen/tests/zkp_program_tests.rs
Normal file
11
sunscreen/tests/zkp_program_tests.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use sunscreen::{types::zkp::NativeField, zkp_program, ZkpProgramFn};
|
||||
|
||||
#[test]
|
||||
fn can_add_and_mul_native_fields() {
|
||||
#[zkp_program(backend = "bulletproofs")]
|
||||
fn add_mul(a: NativeField, b: NativeField) {
|
||||
a + b * a
|
||||
}
|
||||
|
||||
add_mul.build().unwrap();
|
||||
}
|
||||
@@ -7,5 +7,8 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
petgraph = "0.6.2"
|
||||
proc-macro2 = "1.0.47"
|
||||
quote = "1.0.21"
|
||||
semver = "1.0.14"
|
||||
serde = "1.0.147"
|
||||
serde = { version = "1.0.147", features = ["derive"] }
|
||||
syn = { version = "1.0.103", features = ["full"] }
|
||||
@@ -1,12 +1,15 @@
|
||||
use std::fmt::Debug;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use petgraph::algo::is_isomorphic_matching;
|
||||
use petgraph::stable_graph::{NodeIndex, StableGraph};
|
||||
use petgraph::visit::{EdgeRef, IntoEdgeReferences, IntoNodeReferences};
|
||||
use petgraph::Graph;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Operation, Render};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq, Eq)]
|
||||
/**
|
||||
* Information about a node in the compilation graph.
|
||||
*/
|
||||
@@ -29,7 +32,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
/**
|
||||
* Information about how one compiler graph node relates to another.
|
||||
*/
|
||||
@@ -79,7 +82,7 @@ impl Render for EdgeInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
/**
|
||||
* The result of a frontend compiler.
|
||||
*/
|
||||
@@ -98,6 +101,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> PartialEq for FrontendCompilation<O>
|
||||
where
|
||||
O: Operation,
|
||||
{
|
||||
/// FOR TESTING ONLY!!!
|
||||
/// Graph isomorphism is an NP-Complete problem!
|
||||
fn eq(&self, b: &Self) -> bool {
|
||||
is_isomorphic_matching(
|
||||
&Graph::from(self.0.clone()),
|
||||
&Graph::from(b.0.clone()),
|
||||
|n1, n2| n1 == n2,
|
||||
|e1, e2| e1 == e2,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> Debug for FrontendCompilation<O>
|
||||
where
|
||||
O: Operation,
|
||||
@@ -154,7 +173,7 @@ where
|
||||
/**
|
||||
* A compilation context. This stores the current parse graph.
|
||||
*/
|
||||
pub struct Context<O>
|
||||
pub struct Context<O, D>
|
||||
where
|
||||
O: Operation,
|
||||
{
|
||||
@@ -165,22 +184,22 @@ where
|
||||
|
||||
#[allow(unused)]
|
||||
/**
|
||||
* Consumers can use this to uniquely number their inputs.
|
||||
* Data given by the consumer.
|
||||
*/
|
||||
pub next_input_id: u32,
|
||||
pub data: D,
|
||||
}
|
||||
|
||||
impl<O> Context<O>
|
||||
impl<O, D> Context<O, D>
|
||||
where
|
||||
O: Operation,
|
||||
{
|
||||
/**
|
||||
* Create a new [`Context`].
|
||||
*/
|
||||
pub fn new() -> Self {
|
||||
pub fn new(data: D) -> Self {
|
||||
Self {
|
||||
graph: FrontendCompilation::<O>::new(),
|
||||
next_input_id: 0,
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,12 +240,3 @@ where
|
||||
node
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> Default for Context<O>
|
||||
where
|
||||
O: Operation,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,8 +62,8 @@ impl<'a, N, E> GraphQuery<'a, N, E> {
|
||||
*/
|
||||
pub trait TransformList<N, E>
|
||||
where
|
||||
N: Copy,
|
||||
E: Copy,
|
||||
N: Clone,
|
||||
E: Clone,
|
||||
{
|
||||
/**
|
||||
* Apply the transformations.
|
||||
@@ -87,8 +87,8 @@ where
|
||||
*/
|
||||
pub fn forward_traverse<N, E, F, T>(graph: &mut StableGraph<N, E>, callback: F)
|
||||
where
|
||||
N: Copy,
|
||||
E: Copy,
|
||||
N: Clone,
|
||||
E: Clone,
|
||||
T: TransformList<N, E>,
|
||||
F: FnMut(GraphQuery<N, E>, NodeIndex) -> T,
|
||||
{
|
||||
@@ -111,8 +111,8 @@ where
|
||||
*/
|
||||
pub fn reverse_traverse<N, E, F, T>(graph: &mut StableGraph<N, E>, callback: F)
|
||||
where
|
||||
N: Copy,
|
||||
E: Copy,
|
||||
N: Clone,
|
||||
E: Clone,
|
||||
T: TransformList<N, E>,
|
||||
F: FnMut(GraphQuery<N, E>, NodeIndex) -> T,
|
||||
{
|
||||
@@ -121,8 +121,8 @@ where
|
||||
|
||||
fn traverse<N, E, T, F>(graph: &mut StableGraph<N, E>, forward: bool, mut callback: F)
|
||||
where
|
||||
N: Copy,
|
||||
E: Copy,
|
||||
N: Clone,
|
||||
E: Clone,
|
||||
F: FnMut(GraphQuery<N, E>, NodeIndex) -> T,
|
||||
T: TransformList<N, E>,
|
||||
{
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
|
||||
mod context;
|
||||
mod graph;
|
||||
/**
|
||||
* Helper methods for macros.
|
||||
*/
|
||||
pub mod macros;
|
||||
|
||||
/**
|
||||
* A set of generic compiler transforms.
|
||||
*/
|
||||
@@ -41,7 +46,7 @@ pub trait Render {
|
||||
*
|
||||
* Also provides functions that describe properties of an operation.
|
||||
*/
|
||||
pub trait Operation: Clone + Copy + Debug + Hash + PartialEq + Eq {
|
||||
pub trait Operation: Clone + Debug + Hash + PartialEq + Eq {
|
||||
/**
|
||||
* Whether or not this operation commutes.
|
||||
*/
|
||||
|
||||
589
sunscreen_compiler_common/src/macros/mod.rs
Normal file
589
sunscreen_compiler_common/src/macros/mod.rs
Normal file
@@ -0,0 +1,589 @@
|
||||
use proc_macro2::{Span, TokenStream as TokenStream2};
|
||||
use quote::{format_ident, quote, quote_spanned};
|
||||
use syn::{
|
||||
parse_quote, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, FnArg, Ident,
|
||||
Index, Pat, ReturnType, Token, Type,
|
||||
};
|
||||
|
||||
mod type_name;
|
||||
|
||||
pub use type_name::derive_typename_impl;
|
||||
|
||||
#[derive(Debug)]
|
||||
/**
|
||||
* A type error that occurs in a program specification.
|
||||
*/
|
||||
pub enum ProgramTypeError {
|
||||
/**
|
||||
* The given type is illegal.
|
||||
*/
|
||||
IllegalType(Span),
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an input type T, returns
|
||||
* * ProgramNode<T> when T is a Path
|
||||
* * [map_input_type(T); N] when T is Array
|
||||
*/
|
||||
pub fn lift_type(arg_type: &Type) -> Result<Type, ProgramTypeError> {
|
||||
let transformed_type = match arg_type {
|
||||
Type::Path(ty) => parse_quote_spanned! {ty.span() => ProgramNode<#ty> },
|
||||
Type::Array(a) => {
|
||||
let inner_type = lift_type(&a.elem)?;
|
||||
let len = &a.len;
|
||||
|
||||
parse_quote_spanned! {a.span() =>
|
||||
[#inner_type; #len]
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(ProgramTypeError::IllegalType(arg_type.span()));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(transformed_type)
|
||||
}
|
||||
|
||||
/**
|
||||
* Emits code to make a program node for the given type T.
|
||||
*/
|
||||
pub fn create_program_node(var_name: &str, arg_type: &Type) -> TokenStream2 {
|
||||
let mapped_type = match lift_type(arg_type) {
|
||||
Ok(v) => v,
|
||||
Err(ProgramTypeError::IllegalType(s)) => {
|
||||
return quote_spanned! {
|
||||
s => compile_error!("Unsupported program input type.")
|
||||
};
|
||||
}
|
||||
};
|
||||
let var_name = format_ident!("{}", var_name);
|
||||
|
||||
let type_annotation = match arg_type {
|
||||
Type::Path(ty) => quote_spanned! { ty.span() => ProgramNode },
|
||||
Type::Array(a) => quote_spanned! { a.span() =>
|
||||
<#mapped_type>
|
||||
},
|
||||
_ => quote! {
|
||||
compile_error!("fhe_program arguments' name must be a simple identifier and type must be a plain path.");
|
||||
},
|
||||
};
|
||||
|
||||
quote_spanned! {arg_type.span() =>
|
||||
let #var_name: #mapped_type = #type_annotation::input();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/**
|
||||
* Errors that can occur when extracting the function signature of a
|
||||
* program.
|
||||
*/
|
||||
pub enum ExtractFnArgumentsError {
|
||||
/**
|
||||
* The method contains a reference to `self` or `&self`.
|
||||
*
|
||||
* # Remarks
|
||||
* FHE and ZKP programs must be pure functions.
|
||||
*/
|
||||
ContainsSelf(Span),
|
||||
|
||||
/**
|
||||
* The given type is not allowed.
|
||||
*/
|
||||
IllegalType(Span),
|
||||
|
||||
/**
|
||||
* The given type pattern is not of the a qualified path to a type.
|
||||
*/
|
||||
IllegalPat(Span),
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and parse the arguments of a function, returning a Vec of
|
||||
* the types and identifiers.
|
||||
*/
|
||||
pub fn extract_fn_arguments(
|
||||
args: &Punctuated<FnArg, Token!(,)>,
|
||||
) -> Result<Vec<(&Type, &Ident)>, ExtractFnArgumentsError> {
|
||||
let mut unwrapped_inputs = vec![];
|
||||
|
||||
for i in args {
|
||||
let input_type = match i {
|
||||
FnArg::Receiver(_) => {
|
||||
return Err(ExtractFnArgumentsError::ContainsSelf(i.span()));
|
||||
}
|
||||
FnArg::Typed(t) => match (&*t.ty, &*t.pat) {
|
||||
(Type::Path(_), Pat::Ident(i)) => (&*t.ty, &i.ident),
|
||||
(Type::Array(_), Pat::Ident(i)) => (&*t.ty, &i.ident),
|
||||
_ => {
|
||||
match &*t.pat {
|
||||
Pat::Ident(_) => {}
|
||||
_ => {
|
||||
return Err(ExtractFnArgumentsError::IllegalPat(t.span()));
|
||||
}
|
||||
};
|
||||
|
||||
return Err(ExtractFnArgumentsError::IllegalType(t.span()));
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
unwrapped_inputs.push(input_type);
|
||||
}
|
||||
|
||||
Ok(unwrapped_inputs)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/**
|
||||
* Errors that can occur when extracting the return value of an FHE
|
||||
* program.
|
||||
*/
|
||||
pub enum ExtractReturnTypesError {
|
||||
/**
|
||||
* The given return type is not allowed.
|
||||
*
|
||||
* # Remarks
|
||||
* ZKP programs don't return values.
|
||||
*
|
||||
* FHE programs must return either
|
||||
* * nothing (weird, but legal).
|
||||
* * a single FHE type.
|
||||
* * a tuple of FHE types.
|
||||
*
|
||||
*/
|
||||
IllegalType(Span),
|
||||
}
|
||||
|
||||
impl From<ProgramTypeError> for ExtractReturnTypesError {
|
||||
fn from(e: ProgramTypeError) -> Self {
|
||||
match e {
|
||||
ProgramTypeError::IllegalType(s) => Self::IllegalType(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unpacks the return types from a `ReturnType` and flattens them
|
||||
* into a Vec.
|
||||
* * Tuples will have more than one.
|
||||
* * Path, Paren, and Arrays will have one.
|
||||
* * Default has zero.
|
||||
*/
|
||||
pub fn extract_return_types(ret: &ReturnType) -> Result<Vec<Type>, ExtractReturnTypesError> {
|
||||
let return_types = match ret {
|
||||
ReturnType::Type(_, t) => match &**t {
|
||||
Type::Tuple(t) => t.elems.iter().cloned().collect::<Vec<Type>>(),
|
||||
Type::Paren(t) => {
|
||||
vec![*t.elem.clone()]
|
||||
}
|
||||
Type::Path(_) => {
|
||||
vec![*t.clone()]
|
||||
}
|
||||
Type::Array(_) => {
|
||||
vec![*t.clone()]
|
||||
}
|
||||
_ => return Err(ExtractReturnTypesError::IllegalType(t.span())),
|
||||
},
|
||||
ReturnType::Default => {
|
||||
vec![]
|
||||
}
|
||||
};
|
||||
|
||||
Ok(return_types)
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes an array of return types and packages them into a tuple
|
||||
* if needed.
|
||||
*/
|
||||
pub fn pack_return_type(return_types: &[Type]) -> Type {
|
||||
match return_types.len() {
|
||||
0 => parse_quote! { () },
|
||||
1 => return_types[0].clone(),
|
||||
_ => {
|
||||
parse_quote_spanned! {return_types[0].span() => ( #(#return_types),* ) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Emits code to create output nodes for each returned value in an
|
||||
* FHE program.
|
||||
*/
|
||||
pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 {
|
||||
match return_types.len() {
|
||||
1 => quote_spanned! { return_types[0].span() => v.output(); },
|
||||
_ => return_types
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let index = Index::from(i);
|
||||
|
||||
quote_spanned! {t.span() =>
|
||||
v.#index.output();
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Emits the call signature of an FHE or ZKP program.
|
||||
*/
|
||||
pub fn emit_signature(args: &[Type], return_types: &[Type]) -> TokenStream2 {
|
||||
let arg_type_names = args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let alias = format_ident!("T{}", i);
|
||||
|
||||
quote! {
|
||||
type #alias = #t;
|
||||
}
|
||||
})
|
||||
.collect::<Vec<TokenStream2>>();
|
||||
|
||||
let arg_get_types = arg_type_names.iter().enumerate().map(|(i, _)| {
|
||||
let alias = format_ident!("T{}", i);
|
||||
|
||||
quote! {
|
||||
#alias::type_name(),
|
||||
}
|
||||
});
|
||||
|
||||
let return_type_aliases = return_types.iter().enumerate().map(|(i, t)| {
|
||||
let alias = format_ident!("R{}", i);
|
||||
|
||||
quote! {
|
||||
type #alias = #t;
|
||||
}
|
||||
});
|
||||
|
||||
let return_type_names = return_types.iter().enumerate().map(|(i, _)| {
|
||||
let alias = format_ident!("R{}", i);
|
||||
|
||||
quote! {
|
||||
#alias ::type_name(),
|
||||
}
|
||||
});
|
||||
|
||||
let return_type_sizes = return_types.iter().enumerate().map(|(i, _)| {
|
||||
let alias = format_ident!("R{}", i);
|
||||
|
||||
quote! {
|
||||
#alias ::NUM_CIPHERTEXTS,
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
use sunscreen::types::TypeName;
|
||||
|
||||
#(#arg_type_names)*
|
||||
#(#return_type_aliases)*
|
||||
|
||||
sunscreen::CallSignature {
|
||||
arguments: vec![#(#arg_get_types)*],
|
||||
returns: vec![#(#return_type_names)*],
|
||||
num_ciphertexts: vec![#(#return_type_sizes)*],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use quote::ToTokens;
|
||||
use syn::parse_quote;
|
||||
|
||||
fn assert_syn_eq<T, U>(a: &T, b: &U)
|
||||
where
|
||||
T: ToTokens,
|
||||
U: ToTokens,
|
||||
{
|
||||
assert_eq!(
|
||||
format!("{}", a.to_token_stream()),
|
||||
format!("{}", b.to_token_stream())
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_syn_slice_eq<T>(a: &[T], b: &[T])
|
||||
where
|
||||
T: ToTokens,
|
||||
{
|
||||
assert_eq!(a.len(), b.len());
|
||||
|
||||
for (l, r) in a.iter().zip(b) {
|
||||
assert_syn_eq(l, r);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_plain_scalar_type() {
|
||||
let type_name = quote! {
|
||||
Rational
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = lift_type(&type_name).unwrap();
|
||||
|
||||
let expected: Type = parse_quote! {
|
||||
ProgramNode<Rational>
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_array_type() {
|
||||
let type_name = quote! {
|
||||
[Rational; 6]
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = lift_type(&type_name).unwrap();
|
||||
|
||||
let expected: Type = parse_quote! {
|
||||
[ProgramNode<Rational>; 6]
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_multi_dimensional_array_type() {
|
||||
let type_name = quote! {
|
||||
[[Rational; 6]; 7]
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = lift_type(&type_name).unwrap();
|
||||
|
||||
let expected: Type = parse_quote! {
|
||||
[[ProgramNode<Rational>; 6]; 7]
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_multi_dimensional_array_cipher_type() {
|
||||
let type_name = quote! {
|
||||
[[Cipher<Rational>; 6]; 7]
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = lift_type(&type_name).unwrap();
|
||||
|
||||
let expected: Type = parse_quote! {
|
||||
[[ProgramNode<Cipher<Rational> >; 6]; 7]
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_create_simple_fhe_program_node() {
|
||||
let type_name = quote! {
|
||||
Cipher<Rational>
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = create_program_node("horse", &type_name);
|
||||
|
||||
let expected = quote! {
|
||||
let horse: ProgramNode<Cipher<Rational> > = ProgramNode::input();
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_create_array_program_node() {
|
||||
let type_name = quote! {
|
||||
[Cipher<Rational>; 7]
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = create_program_node("horse", &type_name);
|
||||
|
||||
let expected = quote! {
|
||||
let horse: [ProgramNode<Cipher<Rational> >; 7] = <[ProgramNode<Cipher<Rational> >; 7]>::input();
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_create_multidimensional_array_program_node() {
|
||||
let type_name = quote! {
|
||||
[[Cipher<Rational>; 7]; 6]
|
||||
};
|
||||
|
||||
let type_name: Type = parse_quote!(#type_name);
|
||||
|
||||
let actual = create_program_node("horse", &type_name);
|
||||
|
||||
let expected = quote! {
|
||||
let horse: [[ProgramNode<Cipher<Rational> >; 7]; 6] = <[[ProgramNode<Cipher<Rational> >; 7]; 6]>::input();
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_extract_arguments() {
|
||||
let type_name = quote! {
|
||||
a: [[Cipher<Rational>; 7]; 6], b: Cipher<Rational>
|
||||
};
|
||||
|
||||
let args: Punctuated<FnArg, Token!(,)> = parse_quote!(#type_name);
|
||||
|
||||
let extracted = extract_fn_arguments(&args).unwrap();
|
||||
|
||||
let expected_t0: Type = parse_quote! { [[Cipher<Rational>; 7]; 6] };
|
||||
let expected_t1: Type = parse_quote! { Cipher<Rational> };
|
||||
let expected_i0: Ident = parse_quote! { a };
|
||||
let expected_i1: Ident = parse_quote! { b };
|
||||
|
||||
assert_eq!(extracted.len(), 2);
|
||||
assert_syn_eq(extracted[0].0, &expected_t0);
|
||||
assert_syn_eq(extracted[0].1, &expected_i0);
|
||||
assert_syn_eq(extracted[1].0, &expected_t1);
|
||||
assert_syn_eq(extracted[1].1, &expected_i1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disallows_self_arguments() {
|
||||
let type_name = quote! {
|
||||
&self, a: [[Cipher<Rational>; 7]; 6], b: Cipher<Rational>
|
||||
};
|
||||
|
||||
let args: Punctuated<FnArg, Token!(,)> = parse_quote!(#type_name);
|
||||
|
||||
let extracted = extract_fn_arguments(&args);
|
||||
|
||||
match extracted {
|
||||
Err(ExtractFnArgumentsError::ContainsSelf(_)) => {}
|
||||
_ => {
|
||||
panic!("Expected ExtractFnArgumentsError::ContainsSelf");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_extract_no_return_type() {
|
||||
let return_type: Type = parse_quote! {
|
||||
()
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
assert_syn_slice_eq(&extracted, &[]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_extract_single_return() {
|
||||
let return_type: Type = parse_quote! {
|
||||
Cipher<Signed>
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
assert_syn_slice_eq(&extracted, &[parse_quote! { Cipher<Signed> }]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_extract_single_paren_return() {
|
||||
let return_type: Type = parse_quote! {
|
||||
(Cipher<Signed>)
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
assert_syn_slice_eq(&extracted, &[parse_quote! { Cipher<Signed> }]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_extract_single_array_return() {
|
||||
let return_type: Type = parse_quote! {
|
||||
[[Cipher<Signed>; 6]; 7]
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
assert_syn_slice_eq(&extracted, &[parse_quote! { [[Cipher<Signed>; 6]; 7] }]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_extract_single_multiarray_return() {
|
||||
let return_type: Type = parse_quote! {
|
||||
([[Cipher<Signed>; 6]; 7], Cipher<Signed>)
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
assert_syn_slice_eq(
|
||||
&extracted,
|
||||
&[
|
||||
parse_quote! { [[Cipher<Signed>; 6]; 7] },
|
||||
parse_quote! { Cipher<Signed> },
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_capture_single_output() {
|
||||
let return_type: Type = parse_quote! {
|
||||
(Cipher<Signed>)
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
let actual = emit_output_capture(&extracted);
|
||||
|
||||
let expected = quote! {
|
||||
v.output();
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_capture_multiple_outputs() {
|
||||
let return_type: Type = parse_quote! {
|
||||
(Cipher<Signed>, [[Cipher<Signed>; 6]; 7])
|
||||
};
|
||||
|
||||
let return_type = ReturnType::Type(syn::token::RArrow::default(), Box::new(return_type));
|
||||
|
||||
let extracted = extract_return_types(&return_type).unwrap();
|
||||
|
||||
let actual = emit_output_capture(&extracted);
|
||||
|
||||
let expected = quote! {
|
||||
v.0.output();
|
||||
v.1.output();
|
||||
};
|
||||
|
||||
assert_syn_eq(&actual, &expected);
|
||||
}
|
||||
}
|
||||
46
sunscreen_compiler_common/src/macros/type_name.rs
Normal file
46
sunscreen_compiler_common/src/macros/type_name.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{DeriveInput, Ident, LitStr};
|
||||
|
||||
/**
|
||||
* The implementation for #[derive(TypeName)]
|
||||
*/
|
||||
pub fn derive_typename_impl(parse_stream: DeriveInput) -> TokenStream {
|
||||
let name = &parse_stream.ident;
|
||||
let name_contents = LitStr::new(&format!("{{}}::{}", name), name.span());
|
||||
let crate_name = std::env::var("CARGO_CRATE_NAME").unwrap();
|
||||
|
||||
// If the sunscreen crate itself tries to derive types, then it needs to refer
|
||||
// to itself in the first-person as "crate", not in the third-person as "sunscreen"
|
||||
let sunscreen_path = if crate_name == "sunscreen" {
|
||||
Ident::new("crate", Span::call_site())
|
||||
} else {
|
||||
Ident::new("sunscreen", Span::call_site())
|
||||
};
|
||||
|
||||
quote! {
|
||||
impl #sunscreen_path ::types::TypeName for #name {
|
||||
fn type_name() -> #sunscreen_path ::types::Type {
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
|
||||
#sunscreen_path ::types::Type {
|
||||
name: format!(#name_contents, module_path!()),
|
||||
version: #sunscreen_path ::types::Version ::parse(version).expect("Crate version is not a valid semver"),
|
||||
is_encrypted: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl #sunscreen_path ::types::TypeNameInstance for #name {
|
||||
fn type_name_instance(&self) -> #sunscreen_path ::types::Type {
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
|
||||
#sunscreen_path ::types::Type {
|
||||
name: format!(#name_contents, module_path!()),
|
||||
version: #sunscreen_path ::types::Version ::parse(version).expect("Crate version is not a valid semver"),
|
||||
is_encrypted: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -75,7 +75,7 @@ pub fn common_subexpression_elimination<O: Operation>(
|
||||
|
||||
// Key is left/unary+right operand and operation. Value is
|
||||
// the node that matches such a key.
|
||||
let mut visited_nodes = HashMap::<(NodeIndex, Option<NodeIndex>, O), NodeIndex>::new();
|
||||
let mut visited_nodes = HashMap::<(NodeIndex, Option<NodeIndex>, &O), NodeIndex>::new();
|
||||
|
||||
// Look through out immediate children. If we find any of the
|
||||
// type that share an edge with another node, consolidate them into
|
||||
@@ -101,31 +101,31 @@ pub fn common_subexpression_elimination<O: Operation>(
|
||||
)));
|
||||
};
|
||||
|
||||
let child_op = child_node.operation;
|
||||
let child_op = &child_node.operation;
|
||||
|
||||
if child_op.is_binary() {
|
||||
let (left, right) = get_binary_operands(&query, e);
|
||||
|
||||
match visited_nodes.get(&(left, Some(right), child_node.operation)) {
|
||||
match visited_nodes.get(&(left, Some(right), child_op)) {
|
||||
Some(equiv_node) => {
|
||||
move_edges(*equiv_node, e);
|
||||
}
|
||||
None => {
|
||||
visited_nodes.insert((left, Some(right), child_node.operation), e);
|
||||
visited_nodes.insert((left, Some(right), child_op), e);
|
||||
|
||||
if child_op.is_commutative() {
|
||||
visited_nodes.insert((right, Some(left), child_node.operation), e);
|
||||
visited_nodes.insert((right, Some(left), child_op), e);
|
||||
}
|
||||
}
|
||||
};
|
||||
} else if child_op.is_unary() {
|
||||
// Unary
|
||||
let equiv_node = visited_nodes.get(&(index, None, child_node.operation));
|
||||
let equiv_node = visited_nodes.get(&(index, None, child_op));
|
||||
|
||||
match equiv_node {
|
||||
Some(equiv_node) => move_edges(*equiv_node, e),
|
||||
None => {
|
||||
visited_nodes.insert((index, None, child_node.operation), e);
|
||||
visited_nodes.insert((index, None, child_op), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,8 +107,8 @@ impl<N, E> GraphTransforms<N, E> {
|
||||
|
||||
impl<N, E> Default for GraphTransforms<N, E>
|
||||
where
|
||||
N: Copy,
|
||||
E: Copy,
|
||||
N: Clone,
|
||||
E: Clone,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
@@ -117,18 +117,18 @@ where
|
||||
|
||||
impl<N, E> TransformList<N, E> for GraphTransforms<N, E>
|
||||
where
|
||||
N: Copy,
|
||||
E: Copy,
|
||||
N: Clone,
|
||||
E: Clone,
|
||||
{
|
||||
fn apply(&mut self, graph: &mut petgraph::stable_graph::StableGraph<N, E>) {
|
||||
for t in &self.transforms {
|
||||
let inserted_node = match t {
|
||||
Transform::AddNode(n) => Some(graph.add_node(*n)),
|
||||
Transform::AddNode(n) => Some(graph.add_node(n.clone())),
|
||||
Transform::AddEdge(start, end, info) => {
|
||||
let start = self.materialize_index(*start);
|
||||
let end = self.materialize_index(*end);
|
||||
|
||||
graph.add_edge(start, end, *info);
|
||||
graph.add_edge(start, end, info.clone());
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ proc-macro = true
|
||||
proc-macro2 = "1.0.32"
|
||||
quote = "1.0.10"
|
||||
syn = { version = "1.0.81", features = ["derive", "full", "fold"] }
|
||||
sunscreen_compiler_common ={ path = "../sunscreen_compiler_common" }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = "1.0.72"
|
||||
|
||||
160
sunscreen_compiler_macros/src/attr_parsing.rs
Normal file
160
sunscreen_compiler_macros/src/attr_parsing.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use proc_macro2::Span;
|
||||
use syn::{
|
||||
parse::ParseStream, punctuated::Punctuated, spanned::Spanned, Error as SynError, Expr, Lit,
|
||||
LitInt, LitStr, Result as SynResult, Token,
|
||||
};
|
||||
|
||||
pub enum AttrValue {
|
||||
/**
|
||||
* The attribute value is a string.
|
||||
*/
|
||||
String(Span, String),
|
||||
|
||||
/**
|
||||
* The attribute value is an integer.
|
||||
*/
|
||||
USize(Span, usize),
|
||||
|
||||
/**
|
||||
* The key is present but has no value associated with it.
|
||||
*/
|
||||
Present(Span),
|
||||
}
|
||||
|
||||
impl From<&LitStr> for AttrValue {
|
||||
fn from(lit: &LitStr) -> Self {
|
||||
Self::String(lit.span(), lit.value())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&LitInt> for AttrValue {
|
||||
type Error = SynError;
|
||||
|
||||
fn try_from(lit: &LitInt) -> SynResult<Self> {
|
||||
let val = lit.base10_parse::<usize>().map_err(|_| {
|
||||
SynError::new_spanned(
|
||||
lit,
|
||||
format!("{} is not a valid integer literal.", lit.base10_digits()),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Self::USize(lit.span(), val))
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrValue {
|
||||
pub fn get_type(&self) -> &str {
|
||||
match self {
|
||||
Self::String(_s, _x) => "String",
|
||||
Self::USize(_s, _x) => "usize",
|
||||
Self::Present(_s) => "None",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn span(&self) -> Span {
|
||||
match self {
|
||||
Self::String(s, _) => *s,
|
||||
Self::USize(s, _) => *s,
|
||||
Self::Present(s) => *s,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> SynResult<&str> {
|
||||
match self {
|
||||
Self::String(_, val) => Ok(val),
|
||||
_ => Err(SynError::new(
|
||||
self.span(),
|
||||
format!("Expected string literal, got {}", self.get_type()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_usize(&self) -> SynResult<usize> {
|
||||
match self {
|
||||
Self::USize(_, val) => Ok(*val),
|
||||
_ => Err(SynError::new(
|
||||
self.span(),
|
||||
format!("Expected usize literal, got {}", self.get_type()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to parse a list of attributes contained in an attribute and
|
||||
* returns them as a `HashMap<String, AttrValue>`. The list of items
|
||||
* is a comma-delimited list of either `key = value` pairs where value is
|
||||
* a string or numeric literal *or* merely a key.
|
||||
*
|
||||
* Parsing will fail and return an error on any syntax violation.
|
||||
*
|
||||
* # Example
|
||||
* In the below example, this function parses the contents between the
|
||||
* parentheses.
|
||||
*
|
||||
* ```no_test
|
||||
* // key1 takes a string value, key2 takes a usize, key3's presence
|
||||
* // indicates is a true boolean.
|
||||
* #[my_attribute(key1 = "string", key2 = 42, key3)]
|
||||
* fn my_function() {
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
pub fn try_parse_dict(input: ParseStream) -> SynResult<HashMap<String, AttrValue>> {
|
||||
// parses a,b,c, or a,b,c where a,b and c are Indent
|
||||
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
|
||||
|
||||
let mut attrs: HashMap<String, AttrValue> = HashMap::new();
|
||||
|
||||
for var in &vars {
|
||||
match var {
|
||||
Expr::Assign(a) => {
|
||||
let key = match &*a.left {
|
||||
Expr::Path(p) =>
|
||||
p.path.get_ident().ok_or_else(||SynError::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
|
||||
_ => { return Err(SynError::new_spanned(&a.left, "Key should be a plain identifier")) }
|
||||
};
|
||||
|
||||
let value: AttrValue = match &*a.right {
|
||||
Expr::Lit(l) => match &l.lit {
|
||||
Lit::Str(s) => s.into(),
|
||||
Lit::Int(x) => x.try_into()?,
|
||||
_ => {
|
||||
return Err(SynError::new_spanned(
|
||||
l,
|
||||
"Literal should be a string or integer",
|
||||
))
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
return Err(SynError::new_spanned(
|
||||
&a.right,
|
||||
"Value should be a string literal",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
attrs.insert(key, value);
|
||||
}
|
||||
Expr::Path(p) => {
|
||||
let key = p
|
||||
.path
|
||||
.get_ident()
|
||||
.ok_or_else(|| SynError::new_spanned(p, "Unknown identifier"))?
|
||||
.to_string();
|
||||
|
||||
attrs.insert(key, AttrValue::Present(p.span()));
|
||||
}
|
||||
_ => {
|
||||
return Err(SynError::new_spanned(
|
||||
var,
|
||||
"Expected `key = \"value\"` or `key`",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(attrs)
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
SynError(syn::Error),
|
||||
UnknownScheme(String),
|
||||
}
|
||||
|
||||
impl From<syn::Error> for Error {
|
||||
fn from(err: syn::Error) -> Self {
|
||||
Self::SynError(err)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
fhe_program_transforms::*,
|
||||
internals::{attr::Attrs, case::Scheme},
|
||||
internals::attr::{FheProgramAttrs, Scheme},
|
||||
};
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::{quote, quote_spanned};
|
||||
@@ -18,7 +18,7 @@ pub fn fhe_program_impl(
|
||||
let inputs = &input_fn.sig.inputs;
|
||||
let ret = &input_fn.sig.output;
|
||||
|
||||
let attr_params = parse_macro_input!(metadata as Attrs);
|
||||
let attr_params = parse_macro_input!(metadata as FheProgramAttrs);
|
||||
|
||||
let scheme_type = match attr_params.scheme {
|
||||
Scheme::Bfv => {
|
||||
@@ -119,19 +119,19 @@ pub fn fhe_program_impl(
|
||||
}
|
||||
|
||||
impl sunscreen::FheProgramFn for #fhe_program_struct_name {
|
||||
fn build(&self, params: &sunscreen::Params) -> sunscreen::Result<sunscreen::FrontendCompilation> {
|
||||
fn build(&self, params: &sunscreen::Params) -> sunscreen::Result<sunscreen::fhe::FheFrontendCompilation> {
|
||||
use std::cell::RefCell;
|
||||
use std::mem::transmute;
|
||||
use sunscreen::{CURRENT_CTX, Context, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}};
|
||||
use sunscreen::{fhe::{CURRENT_FHE_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}};
|
||||
|
||||
if SchemeType::Bfv != params.scheme_type {
|
||||
return Err(Error::IncorrectScheme)
|
||||
}
|
||||
|
||||
// TODO: Other schemes.
|
||||
let mut context = Context::new(params);
|
||||
let mut context = FheContext::new(params.clone());
|
||||
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
CURRENT_FHE_CTX.with(|ctx| {
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[forbid(unused_variables)]
|
||||
let internal = | #(#fhe_program_args)* | -> #fhe_program_return
|
||||
@@ -168,7 +168,7 @@ pub fn fhe_program_impl(
|
||||
ctx.swap(&RefCell::new(None));
|
||||
});
|
||||
|
||||
Ok(context.compilation)
|
||||
Ok(context.graph)
|
||||
}
|
||||
|
||||
fn signature(&self) -> sunscreen::CallSignature {
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
use super::case::Scheme;
|
||||
use proc_macro2::Span;
|
||||
use syn::{
|
||||
parse::{Parse, ParseStream},
|
||||
punctuated::Punctuated,
|
||||
spanned::Spanned,
|
||||
Error, Expr, Lit, LitInt, LitStr, Result, Token,
|
||||
Error as SynError, Expr, Lit, LitInt, LitStr, Result as SynResult, Token,
|
||||
};
|
||||
|
||||
use crate::internals::symbols::VALUE_KEYS;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum AttrValue {
|
||||
/**
|
||||
* The attribute value is a string.
|
||||
@@ -35,11 +33,11 @@ impl From<&LitStr> for AttrValue {
|
||||
}
|
||||
|
||||
impl TryFrom<&LitInt> for AttrValue {
|
||||
type Error = Error;
|
||||
type Error = SynError;
|
||||
|
||||
fn try_from(lit: &LitInt) -> Result<Self> {
|
||||
fn try_from(lit: &LitInt) -> SynResult<Self> {
|
||||
let val = lit.base10_parse::<usize>().map_err(|_| {
|
||||
Error::new_spanned(
|
||||
SynError::new_spanned(
|
||||
lit,
|
||||
format!("{} is not a valid integer literal.", lit.base10_digits()),
|
||||
)
|
||||
@@ -66,20 +64,20 @@ impl AttrValue {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> Result<&str> {
|
||||
pub fn as_str(&self) -> SynResult<&str> {
|
||||
match self {
|
||||
Self::String(_, val) => Ok(val),
|
||||
_ => Err(Error::new(
|
||||
_ => Err(SynError::new(
|
||||
self.span(),
|
||||
format!("Expected String, got {}", self.get_type()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_usize(&self) -> Result<usize> {
|
||||
pub fn as_usize(&self) -> SynResult<usize> {
|
||||
match self {
|
||||
Self::USize(_, val) => Ok(*val),
|
||||
_ => Err(Error::new(
|
||||
_ => Err(SynError::new(
|
||||
self.span(),
|
||||
format!("Expected String, got {}", self.get_type()),
|
||||
)),
|
||||
@@ -87,78 +85,129 @@ impl AttrValue {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Attrs {
|
||||
/**
|
||||
* Attempts to parse a list of attributes contained in an attribute and
|
||||
* returns them as a `HashMap<String, AttrValue>`. The list of items
|
||||
* is a comma-delimited list of either `key = value` pairs where value is
|
||||
* a string or numeric literal *or* merely a key.
|
||||
*
|
||||
* Parsing will fail and return an error on any syntax violation.
|
||||
*
|
||||
* # Example
|
||||
* In the below example, this function parses the contents between the
|
||||
* parentheses.
|
||||
*
|
||||
* ```no_test
|
||||
* // key1 takes a string value, key2 takes a usize, key3's presence
|
||||
* // indicates is a true boolean.
|
||||
* #[my_attribute(key1 = "string", key2 = 42, key3)]
|
||||
* fn my_function() {
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
fn try_parse_dict(input: ParseStream) -> SynResult<HashMap<String, AttrValue>> {
|
||||
// parses a,b,c, or a,b,c where a,b and c are Indent
|
||||
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
|
||||
|
||||
let mut attrs: HashMap<String, AttrValue> = HashMap::new();
|
||||
|
||||
for var in &vars {
|
||||
match var {
|
||||
Expr::Assign(a) => {
|
||||
let key = match &*a.left {
|
||||
Expr::Path(p) =>
|
||||
p.path.get_ident().ok_or_else(||SynError::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
|
||||
_ => { return Err(SynError::new_spanned(&a.left, "Key should be a plain identifier")) }
|
||||
};
|
||||
|
||||
let value: AttrValue = match &*a.right {
|
||||
Expr::Lit(l) => match &l.lit {
|
||||
Lit::Str(s) => s.into(),
|
||||
Lit::Int(x) => x.try_into()?,
|
||||
_ => {
|
||||
return Err(SynError::new_spanned(
|
||||
l,
|
||||
"Literal should be a string or integer",
|
||||
))
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
return Err(SynError::new_spanned(
|
||||
&a.right,
|
||||
"Value should be a string literal",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
attrs.insert(key, value);
|
||||
}
|
||||
Expr::Path(p) => {
|
||||
let key = p
|
||||
.path
|
||||
.get_ident()
|
||||
.ok_or_else(|| SynError::new_spanned(p, "Unknown identifier"))?
|
||||
.to_string();
|
||||
|
||||
attrs.insert(key, AttrValue::Present(p.span()));
|
||||
}
|
||||
_ => {
|
||||
return Err(SynError::new_spanned(
|
||||
var,
|
||||
"Expected `key = \"value\"` or `key`",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(attrs)
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Scheme {
|
||||
Bfv,
|
||||
}
|
||||
|
||||
impl TryFrom<&AttrValue> for Scheme {
|
||||
type Error = SynError;
|
||||
|
||||
fn try_from(value: &AttrValue) -> SynResult<Self> {
|
||||
let as_str = value.as_str()?;
|
||||
|
||||
let scheme = match as_str {
|
||||
"bfv" => Self::Bfv,
|
||||
_ => {
|
||||
return Err(SynError::new(
|
||||
value.span(),
|
||||
format!("Unknown scheme {}", as_str),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(scheme)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FheProgramAttrs {
|
||||
pub scheme: Scheme,
|
||||
pub chain_count: usize,
|
||||
}
|
||||
|
||||
impl Parse for Attrs {
|
||||
fn parse(input: ParseStream) -> Result<Self> {
|
||||
// parses a,b,c, or a,b,c where a,b and c are Indent
|
||||
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
|
||||
impl Parse for FheProgramAttrs {
|
||||
fn parse(input: ParseStream) -> SynResult<Self> {
|
||||
let attrs = try_parse_dict(input)?;
|
||||
|
||||
let mut attrs: HashMap<String, AttrValue> = HashMap::new();
|
||||
const VALUE_KEYS: &[&str] = &["scheme", "chain_count"];
|
||||
|
||||
for var in &vars {
|
||||
match var {
|
||||
Expr::Assign(a) => {
|
||||
let key = match &*a.left {
|
||||
Expr::Path(p) =>
|
||||
p.path.get_ident().ok_or_else(||Error::new_spanned(p, "Key should contain only a single path element (e.g, foo, not foo::bar)".to_owned()))?.to_string(),
|
||||
_ => { return Err(Error::new_spanned(&a.left, "Key should be a plain identifier")) }
|
||||
};
|
||||
|
||||
let value: AttrValue = match &*a.right {
|
||||
Expr::Lit(l) => match &l.lit {
|
||||
Lit::Str(s) => s.into(),
|
||||
Lit::Int(x) => x.try_into()?,
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
l,
|
||||
"Literal should be a string or integer",
|
||||
))
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
&a.right,
|
||||
"Value should be a string literal",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if !VALUE_KEYS.iter().any(|x| *x == key) {
|
||||
return Err(Error::new_spanned(a, "Unknown key".to_owned()));
|
||||
}
|
||||
|
||||
attrs.insert(key, value);
|
||||
}
|
||||
Expr::Path(p) => {
|
||||
let key = p
|
||||
.path
|
||||
.get_ident()
|
||||
.ok_or_else(|| Error::new_spanned(p, "Unknown identifier"))?
|
||||
.to_string();
|
||||
|
||||
if !VALUE_KEYS.iter().any(|x| *x == key) {
|
||||
return Err(Error::new_spanned(p, "Unknown key"));
|
||||
}
|
||||
|
||||
attrs.insert(key, AttrValue::Present(p.span()));
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
var,
|
||||
"Expected `key = \"value\"` or `key`",
|
||||
))
|
||||
}
|
||||
for i in attrs.keys() {
|
||||
if !VALUE_KEYS.iter().any(|x| x == i) {
|
||||
return Err(SynError::new(input.span(), &format!("Unknown key '{}'", i)));
|
||||
}
|
||||
}
|
||||
|
||||
let scheme_type = attrs
|
||||
let scheme: Scheme = attrs
|
||||
.get("scheme")
|
||||
.ok_or_else(|| Error::new_spanned(&vars, "required `scheme` is missing".to_owned()))?
|
||||
.as_str()?;
|
||||
.ok_or_else(|| SynError::new(input.span(), "required `scheme` is missing".to_owned()))?
|
||||
.try_into()?;
|
||||
|
||||
let chain_count = attrs
|
||||
.get("chain_count")
|
||||
@@ -166,10 +215,54 @@ impl Parse for Attrs {
|
||||
.unwrap_or(Ok(1))?;
|
||||
|
||||
Ok(Self {
|
||||
scheme: Scheme::parse(scheme_type).map_err(|_e| {
|
||||
Error::new_spanned(vars, format!("Unknown scheme '{}'", &scheme_type))
|
||||
})?,
|
||||
scheme,
|
||||
chain_count,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub enum BackendType {
|
||||
Bulletproofs,
|
||||
}
|
||||
|
||||
impl TryFrom<&AttrValue> for BackendType {
|
||||
type Error = SynError;
|
||||
|
||||
fn try_from(value: &AttrValue) -> SynResult<Self> {
|
||||
let as_str = value.as_str()?;
|
||||
|
||||
match as_str {
|
||||
"bulletproofs" => Ok(BackendType::Bulletproofs),
|
||||
_ => Err(SynError::new(
|
||||
value.span(),
|
||||
format!("Unknown backend `{}`", as_str.to_owned()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub struct ZkpProgramAttrs {
|
||||
backend_type: BackendType,
|
||||
}
|
||||
|
||||
impl Parse for ZkpProgramAttrs {
|
||||
fn parse(input: ParseStream) -> SynResult<Self> {
|
||||
let attrs = try_parse_dict(input)?;
|
||||
|
||||
const VALUE_KEYS: &[&str] = &["backend"];
|
||||
|
||||
for i in attrs.keys() {
|
||||
if !VALUE_KEYS.iter().any(|x| x == i) {
|
||||
return Err(SynError::new(input.span(), &format!("Unknown key '{}'", i)));
|
||||
}
|
||||
}
|
||||
|
||||
let backend_type = attrs.get("backend").ok_or_else(|| {
|
||||
SynError::new(input.span(), "required 'backend' is missing".to_owned())
|
||||
})?;
|
||||
let backend_type = BackendType::try_from(backend_type)?;
|
||||
|
||||
Ok(Self { backend_type })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1 @@
|
||||
use self::Scheme::*;
|
||||
use crate::error::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Scheme {
|
||||
Bfv,
|
||||
}
|
||||
|
||||
impl Scheme {
|
||||
pub fn parse(s: &str) -> Result<Self> {
|
||||
Ok(match s {
|
||||
"bfv" => Bfv,
|
||||
_ => Err(Error::UnknownScheme(s.to_owned()))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Following the pattern in serde (https://github.com/serde-rs)
|
||||
pub mod attr;
|
||||
pub mod case;
|
||||
pub mod symbols;
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pub const VALUE_KEYS: &[&str] = &["scheme", "chain_count"];
|
||||
@@ -6,11 +6,11 @@
|
||||
|
||||
extern crate proc_macro;
|
||||
|
||||
mod error;
|
||||
mod fhe_program;
|
||||
mod fhe_program_transforms;
|
||||
mod internals;
|
||||
mod type_name;
|
||||
mod zkp_program;
|
||||
|
||||
#[proc_macro_derive(TypeName)]
|
||||
/**
|
||||
@@ -65,3 +65,14 @@ pub fn fhe_program(
|
||||
) -> proc_macro::TokenStream {
|
||||
fhe_program::fhe_program_impl(metadata, input)
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
/**
|
||||
* Specifies a function to be a ZKP program. TODO: docs.
|
||||
*/
|
||||
pub fn zkp_program(
|
||||
metadata: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
zkp_program::zkp_program_impl(metadata, input)
|
||||
}
|
||||
|
||||
162
sunscreen_compiler_macros/src/zkp_program.rs
Normal file
162
sunscreen_compiler_macros/src/zkp_program.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::{quote, quote_spanned};
|
||||
use sunscreen_compiler_common::macros::{
|
||||
create_program_node, emit_signature, extract_fn_arguments, lift_type, ExtractFnArgumentsError,
|
||||
};
|
||||
use syn::{parse_macro_input, spanned::Spanned, ItemFn, ReturnType, Type};
|
||||
|
||||
use crate::internals::attr::ZkpProgramAttrs;
|
||||
|
||||
pub fn zkp_program_impl(
|
||||
metadata: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
let _attr_params = parse_macro_input!(metadata as ZkpProgramAttrs);
|
||||
let input_fn = parse_macro_input!(input as ItemFn);
|
||||
|
||||
let zkp_program_name = &input_fn.sig.ident;
|
||||
let vis = &input_fn.vis;
|
||||
let body = &input_fn.block;
|
||||
let inputs = &input_fn.sig.inputs;
|
||||
let ret = &input_fn.sig.output;
|
||||
|
||||
match ret {
|
||||
ReturnType::Default => {}
|
||||
_ => {
|
||||
return proc_macro::TokenStream::from(quote_spanned! {
|
||||
ret.span() => compile_error!("ZKP programs may not return values.")
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let unwrapped_inputs = match extract_fn_arguments(inputs) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return proc_macro::TokenStream::from(match e {
|
||||
ExtractFnArgumentsError::ContainsSelf(s) => {
|
||||
quote_spanned! {s => compile_error!("FHE programs must not contain `self`") }
|
||||
}
|
||||
ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! {
|
||||
s => compile_error! { "Expected Identifier" }
|
||||
},
|
||||
ExtractFnArgumentsError::IllegalType(s) => quote_spanned! {
|
||||
s => compile_error! { "FHE program arguments must be an array or named struct type" }
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let argument_types = unwrapped_inputs
|
||||
.iter()
|
||||
.map(|(t, _)| (**t).clone())
|
||||
.collect::<Vec<Type>>();
|
||||
|
||||
let zkp_program_args = unwrapped_inputs
|
||||
.iter()
|
||||
.map(|i| {
|
||||
let (ty, name) = i;
|
||||
let ty = lift_type(ty).unwrap();
|
||||
|
||||
quote! {
|
||||
#name: #ty,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<TokenStream>>();
|
||||
|
||||
let signature = emit_signature(&argument_types, &[]);
|
||||
|
||||
let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| {
|
||||
let var_name = format!("c_{}", i);
|
||||
|
||||
create_program_node(&var_name, t.0)
|
||||
});
|
||||
|
||||
let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| {
|
||||
let id = Ident::new(&format!("c_{}", i), Span::call_site());
|
||||
|
||||
quote_spanned! {t.0.span() =>
|
||||
#id
|
||||
}
|
||||
});
|
||||
|
||||
let zkp_program_struct_name =
|
||||
Ident::new(&format!("{}_struct", zkp_program_name), Span::call_site());
|
||||
|
||||
let zkp_program_name_literal = format!("{}", zkp_program_name);
|
||||
|
||||
proc_macro::TokenStream::from(quote! {
|
||||
#[allow(non_camel_case_types)]
|
||||
#[derive(Clone)]
|
||||
#vis struct #zkp_program_struct_name {
|
||||
}
|
||||
|
||||
impl sunscreen::ZkpProgramFn for #zkp_program_struct_name {
|
||||
fn build(&self) -> sunscreen::Result<sunscreen::ZkpFrontendCompilation> {
|
||||
use std::cell::RefCell;
|
||||
use std::mem::transmute;
|
||||
use sunscreen::{CURRENT_ZKP_CTX, ZkpContext, Error, INDEX_ARENA, Result, types::{zkp::ProgramNode, TypeName}};
|
||||
|
||||
let mut context = ZkpContext::new(0);
|
||||
|
||||
CURRENT_ZKP_CTX.with(|ctx| {
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[forbid(unused_variables)]
|
||||
let internal = | #(#zkp_program_args)* |
|
||||
#body
|
||||
;
|
||||
|
||||
// Transmute away the lifetime to 'static. So long as we are careful with internal()
|
||||
// panicing, this is safe because we set the context back to none before the function
|
||||
// returns.
|
||||
ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) })));
|
||||
|
||||
#(#var_decl)*
|
||||
|
||||
let panic_res = std::panic::catch_unwind(|| {
|
||||
internal(#(#args),*)
|
||||
});
|
||||
|
||||
// when panicing or not, we need to clear our indicies arena and
|
||||
// unset the context reference.
|
||||
match panic_res {
|
||||
Ok(v) => { },
|
||||
Err(err) => {
|
||||
INDEX_ARENA.with(|allocator| {
|
||||
allocator.borrow_mut().reset()
|
||||
});
|
||||
ctx.swap(&RefCell::new(None));
|
||||
std::panic::resume_unwind(err)
|
||||
}
|
||||
};
|
||||
|
||||
INDEX_ARENA.with(|allocator| {
|
||||
allocator.borrow_mut().reset()
|
||||
});
|
||||
ctx.swap(&RefCell::new(None));
|
||||
});
|
||||
|
||||
Ok(context.graph)
|
||||
}
|
||||
|
||||
fn signature(&self) -> sunscreen::CallSignature {
|
||||
#signature
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
#zkp_program_name_literal
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<str> for #zkp_program_struct_name {
|
||||
fn as_ref(&self) -> &str {
|
||||
use sunscreen::FheProgramFn;
|
||||
|
||||
self.name()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
#vis const #zkp_program_name: #zkp_program_struct_name = #zkp_program_struct_name {
|
||||
};
|
||||
})
|
||||
}
|
||||
@@ -18,7 +18,7 @@ readme = "crates-io.md"
|
||||
|
||||
[dependencies]
|
||||
petgraph = { version = "0.6.0", features = ["serde-1"] }
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
serde = { version = "1.0.147", features = ["derive"] }
|
||||
seal_fhe = { version = "0.7", path = "../seal_fhe" }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -37,7 +37,7 @@ use TransformNodeIndex::*;
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Hash, Deserialize, PartialEq, Eq)]
|
||||
/**
|
||||
* Sunscreen supports the BFV scheme.
|
||||
*/
|
||||
|
||||
@@ -27,7 +27,7 @@ petgraph = "0.6.0"
|
||||
num_cpus = "1.13.0"
|
||||
rayon = "1.5.1"
|
||||
rlp = "0.5.1"
|
||||
serde = "1.0.130"
|
||||
serde = "1.0.147"
|
||||
semver = "1.0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -22,7 +22,7 @@ pub use serialization::WithContext;
|
||||
use seal_fhe::{Ciphertext as SealCiphertext, Plaintext as SealPlaintext};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, Serialize, Deserialize, Eq)]
|
||||
/**
|
||||
* The underlying backend implementation of a plaintext (e.g. SEAL's [`Plaintext`](seal_fhe::Plaintext)).
|
||||
*/
|
||||
|
||||
@@ -63,7 +63,7 @@ pub enum RequiredKeys {
|
||||
PublicKey,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Serialize, Hash, Deserialize, PartialEq, Eq)]
|
||||
/**
|
||||
* The parameter set required for a given FHE program to run efficiently and correctly.
|
||||
*/
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::hash::Hash;
|
||||
|
||||
use crate::Params;
|
||||
use seal_fhe::{BfvEncryptionParametersBuilder, Context, FromBytes, Modulus, ToBytes};
|
||||
use serde::{
|
||||
@@ -6,7 +8,7 @@ use serde::{
|
||||
Deserialize, Serialize,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
#[derive(Debug, PartialEq, Hash, Eq, Clone)]
|
||||
/**
|
||||
* A data type that contains parameters for reconstructing a context
|
||||
* during deserialization (needed by SEAL).
|
||||
|
||||
22
sunscreen_zkp_backend/Cargo.toml
Normal file
22
sunscreen_zkp_backend/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "sunscreen_zkp_backend"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bulletproofs = { version = "4.0.0", features = ["yoloproofs"], optional = true }
|
||||
merlin = { version = "3.0.0", optional = true}
|
||||
|
||||
[features]
|
||||
bulletproofs = [
|
||||
"dep:bulletproofs",
|
||||
"dep:merlin"
|
||||
]
|
||||
|
||||
[dependencies.curve25519-dalek]
|
||||
version = "4"
|
||||
features = ["u64_backend", "serde"]
|
||||
default-features = false
|
||||
package = "curve25519-dalek-ng"
|
||||
77
sunscreen_zkp_backend/src/bulletproofs.rs
Normal file
77
sunscreen_zkp_backend/src/bulletproofs.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use bulletproofs::{
|
||||
r1cs::{ConstraintSystem, Prover, R1CSProof, Variable, Verifier},
|
||||
BulletproofGens, PedersenGens,
|
||||
};
|
||||
use curve25519_dalek::scalar::Scalar;
|
||||
use merlin::Transcript;
|
||||
use std::task::Context;
|
||||
|
||||
struct MulProof(R1CSProof);
|
||||
|
||||
impl MulProof {
|
||||
fn make_gens() -> (Transcript, PedersenGens, BulletproofGens) {
|
||||
let mut transcript = Transcript::new(b"Horse");
|
||||
transcript.append_message(b"dom-sep", b"MulProof");
|
||||
|
||||
let pc_gens = PedersenGens::default();
|
||||
let bp_gens = BulletproofGens::new(128, 1);
|
||||
|
||||
(transcript, pc_gens, bp_gens)
|
||||
}
|
||||
|
||||
pub fn prove(x: Scalar, y: Scalar, o: Scalar) -> Self {
|
||||
let (transcript, pc_gens, bp_gens) = Self::make_gens();
|
||||
|
||||
let mut prover = Prover::new(&pc_gens, transcript);
|
||||
|
||||
let inputs = vec![
|
||||
prover.allocate(Some(x)).unwrap(),
|
||||
prover.allocate(Some(y)).unwrap(),
|
||||
];
|
||||
|
||||
let outputs = vec![prover.allocate(Some(o)).unwrap()];
|
||||
|
||||
Self::gadget(&mut prover, inputs, outputs);
|
||||
|
||||
Self(prover.prove(&bp_gens).unwrap())
|
||||
}
|
||||
|
||||
pub fn verify(&self) -> bool {
|
||||
let (transcript, pc_gens, bp_gens) = Self::make_gens();
|
||||
|
||||
let mut verifier = Verifier::new(transcript);
|
||||
|
||||
let inputs = vec![
|
||||
verifier.allocate(None).unwrap(),
|
||||
verifier.allocate(None).unwrap(),
|
||||
];
|
||||
|
||||
let outputs = vec![verifier.allocate(None).unwrap()];
|
||||
|
||||
Self::gadget(&mut verifier, inputs, outputs);
|
||||
|
||||
verifier.verify(&self.0, &pc_gens, &bp_gens).is_ok()
|
||||
}
|
||||
|
||||
fn gadget<CS: ConstraintSystem>(cs: &mut CS, inputs: Vec<Variable>, outputs: Vec<Variable>) {
|
||||
let (_, _, o) = cs.multiply(
|
||||
inputs[0] + Scalar::from(0u32),
|
||||
inputs[1] + Scalar::from(0u32),
|
||||
);
|
||||
|
||||
inputs[0];
|
||||
|
||||
cs.constrain(o - outputs[0]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_use_bulletproofs_contstraints() {
|
||||
let proof = MulProof::prove(Scalar::from(7u32), Scalar::from(9u32), Scalar::from(63u32));
|
||||
|
||||
assert!(proof.verify());
|
||||
|
||||
let proof = MulProof::prove(Scalar::from(7u32), Scalar::from(9u32), Scalar::from(64u32));
|
||||
|
||||
assert!(!proof.verify());
|
||||
}
|
||||
2
sunscreen_zkp_backend/src/lib.rs
Normal file
2
sunscreen_zkp_backend/src/lib.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
#[cfg(feature = "bulletproofs")]
|
||||
pub mod bulletproofs;
|
||||
Reference in New Issue
Block a user