This commit is contained in:
Rick Weber
2021-12-15 22:53:04 -08:00
parent d9b947597a
commit 88b757fbbc
9 changed files with 97 additions and 77 deletions

20
.vscode/launch.json vendored
View File

@@ -100,7 +100,7 @@
{
"type": "lldb",
"request": "launch",
"name": "Debug tests in library 'sunscreen_frontend'",
"name": "Debug tests in library 'sunscreen_compiler'",
"cargo": {
"args": [
"test",
@@ -115,5 +115,23 @@
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug tests in library 'sunscreen_compiler_macros'",
"cargo": {
"args": [
"test",
"--no-run",
"--package=sunscreen_compiler_macros"
],
"filter": {
"name": "circuit_tests",
"kind": "test"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
]
}

View File

@@ -1,7 +1,7 @@
mod rational;
use std::thread::{self, JoinHandle};
use sunscreen_compiler::{circuit, Compiler};
use sunscreen_compiler::{circuit};
use rational::Rational;
fn alice() -> JoinHandle<()> {
@@ -24,6 +24,6 @@ fn main() {
let a = alice();
let b = bob();
a.join();
b.join();
a.join().unwrap();
b.join().unwrap();
}

View File

@@ -1,11 +1,9 @@
use sunscreen_compiler::{TypeName, Params, InnerPlaintext, Plaintext};
use sunscreen_compiler::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, CircuitNode};
use sunscreen_compiler::{TypeName, Params, InnerPlaintext, Plaintext, with_ctx};
use sunscreen_compiler::types::{BfvType, FheType, NumCiphertexts, TryIntoPlaintext, TryFromPlaintext, Signed, GraphAdd, CircuitNode};
use sunscreen_runtime::{Error};
use num::Rational64;
use std::ops::{Add, Mul, Sub, Div};
#[derive(Debug, Clone, Copy, TypeName, PartialEq, Eq)]
pub struct Rational {
num: Signed,
@@ -16,9 +14,7 @@ pub struct Rational {
// each type. It's spiritually similar to [`std::mem::sizeof`] except it returns the number
// of plaintexts this type needs rather than the number of bytes.
impl NumCiphertexts for Rational {
fn num_ciphertexts() -> usize {
Signed::num_ciphertexts() + Signed::num_ciphertexts()
}
const NUM_CIPHERTEXTS: usize = Signed::NUM_CIPHERTEXTS + Signed::NUM_CIPHERTEXTS;
}
// This trait takes a plaintext and turns it into a [`Rational`]. [`Plaintext`] is a type generally
@@ -30,12 +26,12 @@ impl NumCiphertexts for Rational {
// This trait is needed so Runtime knows how to package this type after decryption.
impl TryFromPlaintext for Rational {
fn try_from_plaintext(plaintext: &Plaintext, params: &Params) -> Result<Self, Error> {
let (num, den) = match plaintext.inner {
let (num, den) = match &plaintext.inner {
InnerPlaintext::Seal(p) => {
// We encode Rationals as 2 plaintexts. Wrap each plaintext and delegate
// to Signed::try_from_plaintext to compute our inner values.
let num = Plaintext {inner: InnerPlaintext::Seal(vec![p[0]]) };
let den = Plaintext {inner: InnerPlaintext::Seal(vec![p[1]]) };
let num = Plaintext {inner: InnerPlaintext::Seal(vec![p[0].clone()]) };
let den = Plaintext {inner: InnerPlaintext::Seal(vec![p[1].clone()]) };
(
Signed::try_from_plaintext(&num, params)?,
@@ -58,7 +54,7 @@ impl TryIntoPlaintext for Rational {
let (num, den) = match (num.inner, den.inner) {
(InnerPlaintext::Seal(n), InnerPlaintext::Seal(d)) => {
(n[0], d[0])
(n[0].clone(), d[0].clone())
}
};
@@ -88,6 +84,19 @@ impl TryFrom<f64> for Rational {
}
}
impl Add for CircuitNode<Rational> {
impl GraphAdd for Rational {
type Left = Self;
type Right = Self;
fn graph_add(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
with_ctx(|ctx| {
let ids = [
ctx.add_addition(a.ids[0], b.ids[0]),
ctx.add_addition(a.ids[1], b.ids[1]),
];
CircuitNode::new(&ids)
})
}
}

View File

@@ -1,4 +1,4 @@
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, Params, PlainModulusConstraint};
use sunscreen_compiler::{circuit, types::Unsigned, Compiler, PlainModulusConstraint};
use sunscreen_runtime::Runtime;
/**

View File

@@ -45,8 +45,7 @@ mod params;
*/
pub mod types;
use std::cell::{RefCell, RefMut};
use std::cell::{RefCell};
use petgraph::{
algo::is_isomorphic_matching,
stable_graph::{NodeIndex, StableGraph},
@@ -71,8 +70,6 @@ pub use sunscreen_runtime::{
Params, PublicKey, RequiredKeys, Runtime, InnerPlaintext, Plaintext
};
use types::{FheType, CircuitNode};
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
/**
* Represents a literal node's data.
@@ -223,15 +220,15 @@ thread_local! {
*/
pub fn with_ctx<F, R>(f: F) -> R
where
F: FnOnce(&'static mut Context) -> R,
F: FnOnce(&mut Context) -> R,
{
CURRENT_CTX.with(|ctx| {
let mut option: RefMut<'static, Option<&mut Context>> = ctx.borrow_mut();
let ctx = option
.as_mut()
.expect("Called Ciphertext::new() outside of a context.");
f(ctx)
let mut option = ctx.borrow_mut();
let ctx = option
.as_mut()
.expect("Called Ciphertext::new() outside of a context.");
f(ctx)
})
}
@@ -249,22 +246,14 @@ impl Context {
}
}
pub fn allocate_circuit_node<T: FheType>(&mut self, ids: &[NodeIndex]) -> CircuitNode<T> {
let indicies = self.allocate_indicies(T::num_ciphertexts());
indicies.copy_from_slice(ids);
CircuitNode::new(indicies)
}
fn allocate_indicies(&mut self, len: usize) -> &mut [NodeIndex] {
pub(crate) unsafe fn allocate_indicies(&mut self, len: usize) -> &'static mut [NodeIndex] {
let before_len = self.indicies_store.len();
self.indicies_store.resize(before_len + len, NodeIndex::new(0));
let (_, right) = self.indicies_store.split_at_mut(before_len);
right
std::mem::transmute(right)
}
fn add_2_input(&mut self, op: Operation, left: NodeIndex, right: NodeIndex) -> NodeIndex {

View File

@@ -2,7 +2,7 @@ use seal::Plaintext as SealPlaintext;
use crate::{
types::{BfvType, CircuitNode, FheType},
Context, TypeName as DeriveTypeName, Params, with_ctx
TypeName as DeriveTypeName, Params, with_ctx
};
use crate::types::{GraphAdd, GraphMul};
@@ -35,9 +35,11 @@ impl GraphAdd for Unsigned {
type Left = Unsigned;
type Right = Unsigned;
fn graph_add<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left> {
fn graph_add(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
with_ctx(|ctx| {
ctx.allocate_circuit_node(&[ctx.add_addition(a.ids[0], b.ids[0])])
let n = ctx.add_addition(a.ids[0], b.ids[0]);
CircuitNode::new(&[n])
})
}
}
@@ -46,13 +48,11 @@ impl GraphMul for Unsigned {
type Left = Unsigned;
type Right = Unsigned;
fn graph_mul<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left> {
Context::with_ctx(|ctx| unsafe {
let indicies = ctx.allocate_indicies(Self::num_ciphertexts());
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left> {
with_ctx(|ctx| {
let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
indicies[0] = ctx.add_multiplication(a.ids[0], b.ids[0]);
CircuitNode::new(indicies)
CircuitNode::new(&[n])
})
}
}
@@ -102,9 +102,7 @@ impl TryFromPlaintext for Unsigned {
}
impl NumCiphertexts for Unsigned {
fn num_ciphertexts() -> usize {
1
}
const NUM_CIPHERTEXTS: usize = 1;
}
impl From<u64> for Unsigned {
@@ -128,9 +126,7 @@ pub struct Signed {
}
impl NumCiphertexts for Signed {
fn num_ciphertexts() -> usize {
1
}
const NUM_CIPHERTEXTS: usize = 1;
}
impl FheType for Signed {}

View File

@@ -10,7 +10,7 @@ pub use sunscreen_runtime::{
};
pub use integer::{Signed, Unsigned};
use std::ops::{Add, Mul, Div, Sub};
use std::ops::{Add, Mul};
#[derive(Clone, Copy, Serialize, Deserialize)]
/**
@@ -38,14 +38,16 @@ impl U64LiteralRef {
* no [`CircuitNode`] should outlive the said context. Violating any of these condicitions may result
* in memory corruption or use-after-free.
*/
pub struct CircuitNode<'a, T: FheType> {
ids: &'a [NodeIndex],
pub struct CircuitNode<T: FheType> {
/**
* The ids on this node.
*/
pub ids: &'static [NodeIndex],
_phantom: std::marker::PhantomData<T>,
}
impl <'a, T: FheType> CircuitNode<'a, T> {
impl <T: FheType> CircuitNode<T> {
/**
* Creates a new circuit node with the given node index.
*
@@ -62,9 +64,15 @@ impl <'a, T: FheType> CircuitNode<'a, T> {
* never outlive the backing context, use-after-free can occur.
*
*/
pub unsafe fn new(ids: &'a [NodeIndex]) -> Self {
pub fn new(ids: &[NodeIndex]) -> Self {
let ids_dest = with_ctx(|ctx| {
unsafe { ctx.allocate_indicies(ids.len()) }
});
ids_dest.copy_from_slice(ids);
Self {
ids,
ids: ids_dest,
_phantom: std::marker::PhantomData,
}
}
@@ -79,14 +87,14 @@ impl <'a, T: FheType> CircuitNode<'a, T> {
* never outlive the backing context, use-after-free can occur.
*
*/
pub unsafe fn input() -> Self {
let mut ids = with_ctx(|ctx| ctx.allocate_indicies(T::num_ciphertexts()));
pub fn input() -> Self {
let mut ids = Vec::with_capacity(T::NUM_CIPHERTEXTS);
for i in 0..T::num_ciphertexts() {
ids[i] = with_ctx(|ctx| ctx.add_input());
for _ in 0..T::NUM_CIPHERTEXTS {
ids.push(with_ctx(|ctx| ctx.add_input()));
}
unsafe { Self::new(ids) }
CircuitNode::new(&ids)
}
/**
@@ -99,14 +107,14 @@ impl <'a, T: FheType> CircuitNode<'a, T> {
* never outlive the backing context, use-after-free can occur.
*
*/
pub unsafe fn output(&self) -> Self {
let mut ids = with_ctx(|ctx| ctx.allocate_indicies(T::num_ciphertexts()));
pub fn output(&self) -> Self {
let mut ids = Vec::with_capacity(self.ids.len());
for i in 0..self.ids.len() {
ids[i] = with_ctx(|ctx| ctx.add_output(self.ids[i]));
ids.push(with_ctx(|ctx| ctx.add_output(self.ids[i])));
}
unsafe { Self::new(ids) }
CircuitNode::new(&ids)
}
/**
@@ -134,7 +142,7 @@ pub trait GraphAdd {
/**
* Process the + operation
*/
fn graph_add<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left>;
fn graph_add(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left>;
}
/**
@@ -154,7 +162,7 @@ pub trait GraphMul {
/**
* Process the * operation
*/
fn graph_mul<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left>;
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left>;
}
/**
@@ -174,10 +182,10 @@ pub trait GraphDiv {
/**
* Process the + operation
*/
fn graph_mul<'a>(a: CircuitNode<'a, Self::Left>, b: CircuitNode<'a, Self::Right>) -> CircuitNode<'a, Self::Left>;
fn graph_mul(a: CircuitNode<Self::Left>, b: CircuitNode<Self::Right>) -> CircuitNode<Self::Left>;
}
impl <'a, T> Add for CircuitNode<'a, T>
impl <T> Add for CircuitNode<T>
where T: FheType + GraphAdd<Left=T, Right=T>
{
type Output = Self;
@@ -187,7 +195,7 @@ where T: FheType + GraphAdd<Left=T, Right=T>
}
}
impl <'a, T> Mul for CircuitNode<'a, T>
impl <T> Mul for CircuitNode<T>
where T: FheType + GraphMul<Left=T, Right=T>
{
type Output = Self;

View File

@@ -213,7 +213,7 @@ fn create_signature(args: &[&Type], ret: &ReturnType) -> TokenStream {
let return_type_sizes = tuple_inners.iter().map(|t| {
quote! {
#t ::num_ciphertexts(),
#t ::NUM_CIPHERTEXTS,
}
});

View File

@@ -98,7 +98,7 @@ pub trait NumCiphertexts {
/**
* Returns the number of ciphertexts this type decomposes into.
*/
fn num_ciphertexts() -> usize;
const NUM_CIPHERTEXTS: usize;
}
/**