mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
WIP
This commit is contained in:
20
.vscode/launch.json
vendored
20
.vscode/launch.json
vendored
@@ -100,7 +100,7 @@
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug tests in library 'sunscreen_frontend'",
|
||||
"name": "Debug tests in library 'sunscreen_compiler'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
@@ -115,5 +115,23 @@
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug tests in library 'sunscreen_compiler_macros'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--package=sunscreen_compiler_macros"
|
||||
],
|
||||
"filter": {
|
||||
"name": "circuit_tests",
|
||||
"kind": "test"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
]
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
mod rational;
|
||||
|
||||
use std::thread::{self, JoinHandle};
|
||||
use sunscreen_compiler::{circuit, Compiler};
|
||||
use sunscreen_compiler::{circuit};
|
||||
use rational::Rational;
|
||||
|
||||
fn alice() -> JoinHandle<()> {
|
||||
@@ -24,6 +24,6 @@ fn main() {
|
||||
let a = alice();
|
||||
let b = bob();
|
||||
|
||||
a.join();
|
||||
b.join();
|
||||
a.join().unwrap();
|
||||
b.join().unwrap();
|
||||
}
|
||||
@@ -1,11 +1,9 @@
|
||||
use sunscreen_compiler::{TypeName, Params, InnerPlaintext, Plaintext};
|
||||
use sunscreen_compiler::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, CircuitNode};
|
||||
use sunscreen_compiler::{TypeName, Params, InnerPlaintext, Plaintext, with_ctx};
|
||||
use sunscreen_compiler::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, GraphAdd, CircuitNode};
|
||||
use sunscreen_runtime::{Error};
|
||||
|
||||
use num::Rational64;
|
||||
|
||||
use std::ops::{Add, Mul, Sub, Div};
|
||||
|
||||
#[derive(Debug, Clone, Copy, TypeName, PartialEq, Eq)]
|
||||
pub struct Rational {
|
||||
num: Signed,
|
||||
@@ -16,9 +14,7 @@ pub struct Rational {
|
||||
// each type. It's spiritually similar to [`std::mem::sizeof`] except it returns the number
|
||||
// of plaintexts this type needs rather than the number of bytes.
|
||||
impl NumCiphertexts for Rational {
|
||||
fn num_ciphertexts() -> usize {
|
||||
Signed::num_ciphertexts() + Signed::num_ciphertexts()
|
||||
}
|
||||
const NUM_CIPHERTEXTS: usize = Signed::NUM_CIPHERTEXTS + Signed::NUM_CIPHERTEXTS;
|
||||
}
|
||||
|
||||
// This trait takes a plaintext and turns it into a [`Rational`]. [`Plaintext`] is a type generally
|
||||
@@ -30,12 +26,12 @@ impl NumCiphertexts for Rational {
|
||||
// This trait is needed so Runtime knows how to package this type after decryption.
|
||||
impl TryFromPlaintext for Rational {
|
||||
fn try_from_plaintext(plaintext: &Plaintext, params: &Params) -> Result<Self, Error> {
|
||||
let (num, den) = match plaintext.inner {
|
||||
let (num, den) = match &plaintext.inner {
|
||||
InnerPlaintext::Seal(p) => {
|
||||
// We encode Rationals as 2 plaintexts. Wrap each plaintext and delegate
|
||||
// to Signed::try_from_plaintext to compute our inner values.
|
||||
let num = Plaintext {inner: InnerPlaintext::Seal(vec![p[0]]) };
|
||||
let den = Plaintext {inner: InnerPlaintext::Seal(vec![p[1]]) };
|
||||
let num = Plaintext {inner: InnerPlaintext::Seal(vec![p[0].clone()]) };
|
||||
let den = Plaintext {inner: InnerPlaintext::Seal(vec![p[1].clone()]) };
|
||||
|
||||
(
|
||||
Signed::try_from_plaintext(&num, params)?,
|
||||
@@ -58,7 +54,7 @@ impl TryIntoPlaintext for Rational {
|
||||
|
||||
let (num, den) = match (num.inner, den.inner) {
|
||||
(InnerPlaintext::Seal(n), InnerPlaintext::Seal(d)) => {
|
||||
(n[0], d[0])
|
||||
(n[0].clone(), d[0].clone())
|
||||
}
|
||||
};
|
||||
|
||||
@@ -88,6 +84,19 @@ impl TryFrom<f64> for Rational {
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for CircuitNode<Rational> {
|
||||
impl GraphAdd for Rational {
|
||||
type Left = Self;
|
||||
type Right = Self;
|
||||
|
||||
fn graph_add(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
|
||||
with_ctx(|ctx| {
|
||||
let ids = [
|
||||
ctx.add_addition(a.ids[0], b.ids[0]),
|
||||
ctx.add_addition(a.ids[1], b.ids[1]),
|
||||
];
|
||||
|
||||
CircuitNode::new(&ids)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, Params, PlainModulusConstraint};
|
||||
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, PlainModulusConstraint};
|
||||
use sunscreen_runtime::Runtime;
|
||||
|
||||
/**
|
||||
|
||||
@@ -45,8 +45,7 @@ mod params;
|
||||
*/
|
||||
pub mod types;
|
||||
|
||||
use std::cell::{RefCell, RefMut};
|
||||
|
||||
use std::cell::{RefCell};
|
||||
use petgraph::{
|
||||
algo::is_isomorphic_matching,
|
||||
stable_graph::{NodeIndex, StableGraph},
|
||||
@@ -71,8 +70,6 @@ pub use sunscreen_runtime::{
|
||||
Params, PublicKey, RequiredKeys, Runtime, InnerPlaintext, Plaintext
|
||||
};
|
||||
|
||||
use types::{FheType, CircuitNode};
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
/**
|
||||
* Represents a literal node's data.
|
||||
@@ -223,15 +220,15 @@ thread_local! {
|
||||
*/
|
||||
pub fn with_ctx<F, R>(f: F) -> R
|
||||
where
|
||||
F: FnOnce(&'static mut Context) -> R,
|
||||
F: FnOnce(&mut Context) -> R,
|
||||
{
|
||||
CURRENT_CTX.with(|ctx| {
|
||||
let mut option: RefMut<'static, Option<&mut Context>> = ctx.borrow_mut();
|
||||
let ctx = option
|
||||
.as_mut()
|
||||
.expect("Called Ciphertext::new() outside of a context.");
|
||||
|
||||
f(ctx)
|
||||
let mut option = ctx.borrow_mut();
|
||||
let ctx = option
|
||||
.as_mut()
|
||||
.expect("Called Ciphertext::new() outside of a context.");
|
||||
|
||||
f(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -249,22 +246,14 @@ impl Context {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allocate_circuit_node<T: FheType>(&mut self, ids: &[NodeIndex]) -> CircuitNode<T> {
|
||||
let indicies = self.allocate_indicies(T::num_ciphertexts());
|
||||
|
||||
indicies.copy_from_slice(ids);
|
||||
|
||||
CircuitNode::new(indicies)
|
||||
}
|
||||
|
||||
fn allocate_indicies(&mut self, len: usize) -> &mut [NodeIndex] {
|
||||
pub(crate) unsafe fn allocate_indicies(&mut self, len: usize) -> &'static mut [NodeIndex] {
|
||||
let before_len = self.indicies_store.len();
|
||||
|
||||
self.indicies_store.resize(before_len + len, NodeIndex::new(0));
|
||||
|
||||
let (_, right) = self.indicies_store.split_at_mut(before_len);
|
||||
|
||||
right
|
||||
std::mem::transmute(right)
|
||||
}
|
||||
|
||||
fn add_2_input(&mut self, op: Operation, left: NodeIndex, right: NodeIndex) -> NodeIndex {
|
||||
|
||||
@@ -2,7 +2,7 @@ use seal::Plaintext as SealPlaintext;
|
||||
|
||||
use crate::{
|
||||
types::{BfvType, CircuitNode, FheType},
|
||||
Context, TypeName as DeriveTypeName, Params, with_ctx
|
||||
TypeName as DeriveTypeName, Params, with_ctx
|
||||
};
|
||||
use crate::types::{GraphAdd, GraphMul};
|
||||
|
||||
@@ -35,9 +35,11 @@ impl GraphAdd for Unsigned {
|
||||
type Left = Unsigned;
|
||||
type Right = Unsigned;
|
||||
|
||||
fn graph_add<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left> {
|
||||
fn graph_add(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
|
||||
with_ctx(|ctx| {
|
||||
ctx.allocate_circuit_node(&[ctx.add_addition(a.ids[0], b.ids[0])])
|
||||
let n = ctx.add_addition(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -46,13 +48,11 @@ impl GraphMul for Unsigned {
|
||||
type Left = Unsigned;
|
||||
type Right = Unsigned;
|
||||
|
||||
fn graph_mul<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left> {
|
||||
Context::with_ctx(|ctx| unsafe {
|
||||
let indicies = ctx.allocate_indicies(Self::num_ciphertexts());
|
||||
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
|
||||
with_ctx(|ctx| {
|
||||
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
indicies[0] = ctx.add_multiplication(a.ids[0], b.ids[0]);
|
||||
|
||||
CircuitNode::new(indicies)
|
||||
CircuitNode::new(&[n])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -102,9 +102,7 @@ impl TryFromPlaintext for Unsigned {
|
||||
}
|
||||
|
||||
impl NumCiphertexts for Unsigned {
|
||||
fn num_ciphertexts() -> usize {
|
||||
1
|
||||
}
|
||||
const NUM_CIPHERTEXTS: usize = 1;
|
||||
}
|
||||
|
||||
impl From<u64> for Unsigned {
|
||||
@@ -128,9 +126,7 @@ pub struct Signed {
|
||||
}
|
||||
|
||||
impl NumCiphertexts for Signed {
|
||||
fn num_ciphertexts() -> usize {
|
||||
1
|
||||
}
|
||||
const NUM_CIPHERTEXTS: usize = 1;
|
||||
}
|
||||
|
||||
impl FheType for Signed {}
|
||||
|
||||
@@ -10,7 +10,7 @@ pub use sunscreen_runtime::{
|
||||
};
|
||||
|
||||
pub use integer::{Signed, Unsigned};
|
||||
use std::ops::{Add, Mul, Div, Sub};
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize)]
|
||||
/**
|
||||
@@ -38,14 +38,16 @@ impl U64LiteralRef {
|
||||
* no [`CircuitNode`] should outlive the said context. Violating any of these condicitions may result
|
||||
* in memory corruption or use-after-free.
|
||||
*/
|
||||
pub struct CircuitNode<'a, T: FheType> {
|
||||
|
||||
ids: &'a [NodeIndex],
|
||||
pub struct CircuitNode<T: FheType> {
|
||||
/**
|
||||
* The ids on this node.
|
||||
*/
|
||||
pub ids: &'static [NodeIndex],
|
||||
|
||||
_phantom: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
impl <'a, T: FheType> CircuitNode<'a, T> {
|
||||
impl <T: FheType> CircuitNode<T> {
|
||||
/**
|
||||
* Creates a new circuit node with the given node index.
|
||||
*
|
||||
@@ -62,9 +64,15 @@ impl <'a, T: FheType> CircuitNode<'a, T> {
|
||||
* never outlive the backing context, use-after-free can occur.
|
||||
*
|
||||
*/
|
||||
pub unsafe fn new(ids: &'a [NodeIndex]) -> Self {
|
||||
pub fn new(ids: &[NodeIndex]) -> Self {
|
||||
let ids_dest = with_ctx(|ctx| {
|
||||
unsafe { ctx.allocate_indicies(ids.len()) }
|
||||
});
|
||||
|
||||
ids_dest.copy_from_slice(ids);
|
||||
|
||||
Self {
|
||||
ids,
|
||||
ids: ids_dest,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -79,14 +87,14 @@ impl <'a, T: FheType> CircuitNode<'a, T> {
|
||||
* never outlive the backing context, use-after-free can occur.
|
||||
*
|
||||
*/
|
||||
pub unsafe fn input() -> Self {
|
||||
let mut ids = with_ctx(|ctx| ctx.allocate_indicies(T::num_ciphertexts()));
|
||||
pub fn input() -> Self {
|
||||
let mut ids = Vec::with_capacity(T::NUM_CIPHERTEXTS);
|
||||
|
||||
for i in 0..T::num_ciphertexts() {
|
||||
ids[i] = with_ctx(|ctx| ctx.add_input());
|
||||
for _ in 0..T::NUM_CIPHERTEXTS {
|
||||
ids.push(with_ctx(|ctx| ctx.add_input()));
|
||||
}
|
||||
|
||||
unsafe { Self::new(ids) }
|
||||
|
||||
CircuitNode::new(&ids)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -99,14 +107,14 @@ impl <'a, T: FheType> CircuitNode<'a, T> {
|
||||
* never outlive the backing context, use-after-free can occur.
|
||||
*
|
||||
*/
|
||||
pub unsafe fn output(&self) -> Self {
|
||||
let mut ids = with_ctx(|ctx| ctx.allocate_indicies(T::num_ciphertexts()));
|
||||
pub fn output(&self) -> Self {
|
||||
let mut ids = Vec::with_capacity(self.ids.len());
|
||||
|
||||
for i in 0..self.ids.len() {
|
||||
ids[i] = with_ctx(|ctx| ctx.add_output(self.ids[i]));
|
||||
ids.push(with_ctx(|ctx| ctx.add_output(self.ids[i])));
|
||||
}
|
||||
|
||||
unsafe { Self::new(ids) }
|
||||
CircuitNode::new(&ids)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -134,7 +142,7 @@ pub trait GraphAdd {
|
||||
/**
|
||||
* Process the + operation
|
||||
*/
|
||||
fn graph_add<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left>;
|
||||
fn graph_add(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -154,7 +162,7 @@ pub trait GraphMul {
|
||||
/**
|
||||
* Process the * operation
|
||||
*/
|
||||
fn graph_mul<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left>;
|
||||
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -174,10 +182,10 @@ pub trait GraphDiv {
|
||||
/**
|
||||
* Process the + operation
|
||||
*/
|
||||
fn graph_mul<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left>;
|
||||
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left>;
|
||||
}
|
||||
|
||||
impl <'a, T> Add for CircuitNode<'a, T>
|
||||
impl <T> Add for CircuitNode<T>
|
||||
where T: FheType + GraphAdd<Left=T, Right=T>
|
||||
{
|
||||
type Output = Self;
|
||||
@@ -187,7 +195,7 @@ where T: FheType + GraphAdd<Left=T, Right=T>
|
||||
}
|
||||
}
|
||||
|
||||
impl <'a, T> Mul for CircuitNode<'a, T>
|
||||
impl <T> Mul for CircuitNode<T>
|
||||
where T: FheType + GraphMul<Left=T, Right=T>
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
@@ -213,7 +213,7 @@ fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
|
||||
|
||||
let return_type_sizes = tuple_inners.iter().map(|t| {
|
||||
quote! {
|
||||
#t ::num_ciphertexts(),
|
||||
#t ::NUM_CIPHERTEXTS,
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ pub trait NumCiphertexts {
|
||||
/**
|
||||
* Returns the number of ciphertexts this type decomposes into.
|
||||
*/
|
||||
fn num_ciphertexts() -> usize;
|
||||
const NUM_CIPHERTEXTS: usize;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user