mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-04-19 03:00:06 -04:00
Refactor operand type specification.
This commit is contained in:
@@ -1,17 +1,22 @@
|
||||
use sunscreen_ir::{IRTransform::*, IntermediateRepresentation, Operation::*, TransformList};
|
||||
|
||||
use petgraph::Direction;
|
||||
use petgraph::{
|
||||
Direction,
|
||||
visit::EdgeRef
|
||||
};
|
||||
|
||||
pub fn apply_insert_relinearizations(ir: &mut IntermediateRepresentation) {
|
||||
ir.forward_traverse(|query, id| match query.get_node(id).operation {
|
||||
Multiply(_a, _b) => {
|
||||
Multiply => {
|
||||
let mut transforms = TransformList::new();
|
||||
|
||||
let relin_node = transforms.push(AppendRelinearize(id.into()));
|
||||
|
||||
for e in query.get_neighbors(id, Direction::Outgoing) {
|
||||
transforms.push(RemoveEdge(id.into(), e.into()));
|
||||
transforms.push(AddEdge(relin_node.into(), e.into()));
|
||||
for e in query.edges_directed(id, Direction::Outgoing) {
|
||||
let operand_type = e.weight();
|
||||
|
||||
transforms.push(RemoveEdge(id.into(), e.target().into()));
|
||||
transforms.push(AddEdge(relin_node.into(), e.target().into(), *operand_type));
|
||||
}
|
||||
|
||||
transforms
|
||||
@@ -58,7 +63,7 @@ mod tests {
|
||||
.node_indices()
|
||||
.filter(|i| {
|
||||
match query.get_node(*i).operation {
|
||||
Operation::Relinearize(_) => true,
|
||||
Operation::Relinearize => true,
|
||||
_ => false
|
||||
}
|
||||
})
|
||||
@@ -71,7 +76,7 @@ mod tests {
|
||||
assert_eq!(
|
||||
relin_nodes
|
||||
.iter()
|
||||
.all(|id| { query.get_neighbors(*id, Direction::Incoming).count() == 1 }),
|
||||
.all(|id| { query.neighbors_directed(*id, Direction::Incoming).count() == 1 }),
|
||||
true
|
||||
);
|
||||
|
||||
@@ -79,10 +84,10 @@ mod tests {
|
||||
assert_eq!(
|
||||
relin_nodes.iter().all(|id| {
|
||||
query
|
||||
.get_neighbors(*id, Direction::Incoming)
|
||||
.neighbors_directed(*id, Direction::Incoming)
|
||||
.map(|id| query.get_node(id))
|
||||
.all(|node| match node.operation {
|
||||
Operation::Multiply(_a, _b) => true,
|
||||
Operation::Multiply => true,
|
||||
_ => false
|
||||
})
|
||||
}),
|
||||
@@ -92,7 +97,7 @@ mod tests {
|
||||
// The first relin node should point to add_2
|
||||
assert_eq!(
|
||||
query
|
||||
.get_neighbors(relin_nodes[0], Direction::Outgoing)
|
||||
.neighbors_directed(relin_nodes[0], Direction::Outgoing)
|
||||
.count(),
|
||||
1
|
||||
);
|
||||
@@ -100,7 +105,7 @@ mod tests {
|
||||
// The second relin node should point to nothing.
|
||||
assert_eq!(
|
||||
query
|
||||
.get_neighbors(relin_nodes[1], Direction::Outgoing)
|
||||
.neighbors_directed(relin_nodes[1], Direction::Outgoing)
|
||||
.count(),
|
||||
0
|
||||
);
|
||||
@@ -108,10 +113,10 @@ mod tests {
|
||||
// The first relin node should point to add_2
|
||||
assert_eq!(
|
||||
query
|
||||
.get_neighbors(relin_nodes[0], Direction::Outgoing)
|
||||
.neighbors_directed(relin_nodes[0], Direction::Outgoing)
|
||||
.all(|i| {
|
||||
match query.get_node(i).operation {
|
||||
Operation::Add(_a, _b) => true,
|
||||
Operation::Add => true,
|
||||
_ => false
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -12,8 +12,9 @@ use petgraph::{
|
||||
algo::is_isomorphic_matching,
|
||||
algo::toposort,
|
||||
algo::tred::*,
|
||||
Directed,
|
||||
graph::{Graph, NodeIndex},
|
||||
stable_graph::{Neighbors, StableGraph},
|
||||
stable_graph::{Edges, Neighbors, StableGraph},
|
||||
visit::{IntoNeighbors, IntoNodeIdentifiers},
|
||||
Direction,
|
||||
};
|
||||
@@ -44,16 +45,25 @@ impl NodeInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
/**
|
||||
* Contains information about an edge between nodes in the circuit graph.
|
||||
*/
|
||||
pub struct EdgeInfo;
|
||||
pub enum EdgeInfo {
|
||||
/**
|
||||
* The source node is the left input to a binary operation.
|
||||
*/
|
||||
LeftOperand,
|
||||
|
||||
impl EdgeInfo {
|
||||
fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
/**
|
||||
* The source node is the right input to fa binary operation.
|
||||
*/
|
||||
RightOperand,
|
||||
|
||||
/**
|
||||
* The source node is the single input to a unary operation.
|
||||
*/
|
||||
UnaryOperand,
|
||||
}
|
||||
|
||||
type IRGraph = StableGraph<NodeInfo, EdgeInfo>;
|
||||
@@ -105,8 +115,8 @@ impl IntermediateRepresentation {
|
||||
) -> NodeIndex {
|
||||
let new_node = self.graph.add_node(NodeInfo::new(operation));
|
||||
|
||||
self.graph.update_edge(x, new_node, EdgeInfo::new());
|
||||
self.graph.update_edge(y, new_node, EdgeInfo::new());
|
||||
self.graph.update_edge(x, new_node, EdgeInfo::LeftOperand);
|
||||
self.graph.update_edge(y, new_node, EdgeInfo::RightOperand);
|
||||
|
||||
new_node
|
||||
}
|
||||
@@ -114,7 +124,7 @@ impl IntermediateRepresentation {
|
||||
fn append_1_input_node(&mut self, operation: Operation, x: NodeIndex) -> NodeIndex {
|
||||
let new_node = self.graph.add_node(NodeInfo::new(operation));
|
||||
|
||||
self.graph.update_edge(x, new_node, EdgeInfo::new());
|
||||
self.graph.update_edge(x, new_node, EdgeInfo::UnaryOperand);
|
||||
|
||||
new_node
|
||||
}
|
||||
@@ -136,14 +146,14 @@ impl IntermediateRepresentation {
|
||||
* Appends a multiply operation that depends on the operands `x` and `y`.
|
||||
*/
|
||||
pub fn append_multiply(&mut self, x: NodeIndex, y: NodeIndex) -> NodeIndex {
|
||||
self.append_2_input_node(Operation::Multiply(x, y), x, y)
|
||||
self.append_2_input_node(Operation::Multiply, x, y)
|
||||
}
|
||||
|
||||
/**
|
||||
* Appends an add operation that depends on the operands `x` and `y`.
|
||||
*/
|
||||
pub fn append_add(&mut self, x: NodeIndex, y: NodeIndex) -> NodeIndex {
|
||||
self.append_2_input_node(Operation::Add(x, y), x, y)
|
||||
self.append_2_input_node(Operation::Add, x, y)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -173,14 +183,14 @@ impl IntermediateRepresentation {
|
||||
* Sppends a node designating `x` as an output of the circuit.
|
||||
*/
|
||||
pub fn append_output_ciphertext(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.append_1_input_node(Operation::OutputCiphertext(x), x)
|
||||
self.append_1_input_node(Operation::OutputCiphertext, x)
|
||||
}
|
||||
|
||||
/**
|
||||
* Appends an operation that relinearizes `x`.
|
||||
*/
|
||||
pub fn append_relinearize(&mut self, x: NodeIndex) -> NodeIndex {
|
||||
self.append_1_input_node(Operation::Relinearize(x), x)
|
||||
self.append_1_input_node(Operation::Relinearize, x)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -325,7 +335,7 @@ impl IntermediateRepresentation {
|
||||
.node_indices()
|
||||
.filter(|g| {
|
||||
match self.graph[*g].operation {
|
||||
Operation::OutputCiphertext(_) => true,
|
||||
Operation::OutputCiphertext => true,
|
||||
_ => false
|
||||
}
|
||||
})
|
||||
@@ -426,9 +436,13 @@ impl<'a> GraphQuery<'a> {
|
||||
* Typically, you want children writing forward traversal compiler passes and
|
||||
* parents when writing reverse traversal compiler passes.
|
||||
*/
|
||||
pub fn get_neighbors(&self, x: NodeIndex, direction: Direction) -> Neighbors<EdgeInfo> {
|
||||
pub fn neighbors_directed(&self, x: NodeIndex, direction: Direction) -> Neighbors<EdgeInfo> {
|
||||
self.0.graph.neighbors_directed(x, direction)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edges_directed(&self, x: NodeIndex, direction: Direction) -> Edges<EdgeInfo, Directed> {
|
||||
self.0.graph.edges_directed(x, direction)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -491,7 +505,7 @@ pub enum IRTransform {
|
||||
/**
|
||||
* Add a graph edge between two nodes.
|
||||
*/
|
||||
AddEdge(TransformNodeIndex, TransformNodeIndex),
|
||||
AddEdge(TransformNodeIndex, TransformNodeIndex, EdgeInfo),
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -613,11 +627,11 @@ impl TransformList {
|
||||
|
||||
None
|
||||
}
|
||||
AddEdge(x, y) => {
|
||||
AddEdge(x, y, edge_info) => {
|
||||
let x = self.materialize_index(*x);
|
||||
let y = self.materialize_index(*y);
|
||||
|
||||
ir.graph.update_edge(x, y, EdgeInfo::new());
|
||||
ir.graph.update_edge(x, y, *edge_info);
|
||||
|
||||
None
|
||||
}
|
||||
@@ -699,12 +713,12 @@ mod tests {
|
||||
nodes[1].1.operation,
|
||||
Operation::Literal(OuterLiteral::from(7i64))
|
||||
);
|
||||
assert_eq!(nodes[2].1.operation, Operation::Add(NodeIndex::from(0), NodeIndex::from(1)));
|
||||
assert_eq!(nodes[2].1.operation, Operation::Add);
|
||||
assert_eq!(
|
||||
nodes[3].1.operation,
|
||||
Operation::Literal(OuterLiteral::from(5u64))
|
||||
);
|
||||
assert_eq!(nodes[4].1.operation, Operation::Multiply(NodeIndex::from(2), NodeIndex::from(3)));
|
||||
assert_eq!(nodes[4].1.operation, Operation::Multiply);
|
||||
|
||||
assert_eq!(
|
||||
ir.graph
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::OuterLiteral;
|
||||
@@ -33,17 +32,17 @@ pub enum Operation {
|
||||
* a multiplication operation by reducing the resultant 3xN ciphertext down to
|
||||
* the cannonical 2xN.
|
||||
*/
|
||||
Relinearize(NodeIndex),
|
||||
Relinearize,
|
||||
|
||||
/**
|
||||
* Multiply two values. Either operand may be a literal or a ciphertext.
|
||||
*/
|
||||
Multiply(NodeIndex, NodeIndex),
|
||||
Multiply,
|
||||
|
||||
/**
|
||||
* Add two values. Either operand may be a literal or a ciphertext.
|
||||
*/
|
||||
Add(NodeIndex, NodeIndex),
|
||||
Add,
|
||||
|
||||
/**
|
||||
* Computes the additive inverse of a plaintext or ciphertext.
|
||||
@@ -69,5 +68,5 @@ pub enum Operation {
|
||||
/**
|
||||
* Represents a ciphertext output for the circuit.
|
||||
*/
|
||||
OutputCiphertext(NodeIndex),
|
||||
OutputCiphertext,
|
||||
}
|
||||
|
||||
@@ -4,15 +4,61 @@
|
||||
//! This crate contains the types and functions for executing a Sunscreen circuit
|
||||
//! (i.e. an [`IntermediateRepresentation`](sunscreen_ir::IntermediateRepresentation)).
|
||||
|
||||
use sunscreen_ir::{IntermediateRepresentation, Operation::*};
|
||||
use sunscreen_ir::{EdgeInfo, IntermediateRepresentation, Operation::*};
|
||||
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use petgraph::{stable_graph::NodeIndex, Direction};
|
||||
use petgraph::{stable_graph::{NodeIndex}, Direction, visit::{EdgeRef}};
|
||||
use seal::{Ciphertext, Evaluator, RelinearizationKeys};
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
/**
|
||||
* Gets the two input operands and returns a tuple of left, right. For some operations
|
||||
* (i.e. subtraction), order matters. While it's erroneous for a binary operations to have
|
||||
* anything other than a single left and single right operand, having more operands will result
|
||||
* in one being selected arbitrarily. Validating the [`IntermediateRepresentation`] will
|
||||
* reveal having the wrong number of operands.
|
||||
*
|
||||
* # Panics
|
||||
* Panics if the given node doesn't have at least one left and one right operand. Calling
|
||||
* [`validate()`](sunscreen_ir::IntermediateRepresentation::validate()) should reveal this
|
||||
* issue.
|
||||
*/
|
||||
pub fn get_left_right_operands(ir: &IntermediateRepresentation, index: NodeIndex) -> (NodeIndex, NodeIndex) {
|
||||
let left = ir.graph.edges_directed(index, Direction::Incoming)
|
||||
.filter(|e| *e.weight() == EdgeInfo::LeftOperand)
|
||||
.map(|e| e.source())
|
||||
.nth(0)
|
||||
.unwrap();
|
||||
|
||||
let right = ir.graph.edges_directed(index, Direction::Incoming)
|
||||
.filter(|e| *e.weight() == EdgeInfo::RightOperand)
|
||||
.map(|e| e.source())
|
||||
.nth(0)
|
||||
.unwrap();
|
||||
|
||||
(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the single unary input operand for the given node. If the [`IntermediateRepresentation`]
|
||||
* is malformed and the node has more than one UnaryOperand, one will be selected arbitrarily.
|
||||
* As such, one should validate the [`IntermediateRepresentation`] before calling this method.
|
||||
*
|
||||
* # Panics
|
||||
* Panics if the given node doesn't have at least one unary operant. Calling
|
||||
* [`validate()`](sunscreen_ir::IntermediateRepresentation::validate()) should reveal this
|
||||
* issue.
|
||||
*/
|
||||
pub fn get_unary_operand(ir: &IntermediateRepresentation, index: NodeIndex) -> NodeIndex {
|
||||
ir.graph.edges_directed(index, Direction::Incoming)
|
||||
.filter(|e| *e.weight() == EdgeInfo::UnaryOperand)
|
||||
.map(|e| e.source())
|
||||
.nth(0)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the given [`IntermediateRepresentation`] to completion with the given inputs. This
|
||||
* method performs no validation. You must verify the program is first valid. Programs produced
|
||||
@@ -72,29 +118,35 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
|
||||
}
|
||||
ShiftLeft => unimplemented!(),
|
||||
ShiftRight => unimplemented!(),
|
||||
Add(a_id, b_id) => {
|
||||
let a = get_ciphertext(&data, a_id.index());
|
||||
let b = get_ciphertext(&data, b_id.index());
|
||||
Add => {
|
||||
let (left, right) = get_left_right_operands(ir, index);
|
||||
|
||||
let a = get_ciphertext(&data, left.index());
|
||||
let b = get_ciphertext(&data, right.index());
|
||||
|
||||
let c = evaluator.add(&a, &b).unwrap();
|
||||
|
||||
data[index.index()].store(Some(Cow::Owned(c)));
|
||||
}
|
||||
Multiply(a_id, b_id) => {
|
||||
let a = get_ciphertext(&data, a_id.index());
|
||||
let b = get_ciphertext(&data, b_id.index());
|
||||
Multiply => {
|
||||
let (left, right) = get_left_right_operands(ir, index);
|
||||
|
||||
let a = get_ciphertext(&data, left.index());
|
||||
let b = get_ciphertext(&data, right.index());
|
||||
|
||||
let c = evaluator.multiply(&a, &b).unwrap();
|
||||
|
||||
data[index.index()].store(Some(Cow::Owned(c)));
|
||||
}
|
||||
SwapRows => unimplemented!(),
|
||||
Relinearize(a_id) => {
|
||||
Relinearize => {
|
||||
let relin_keys = relin_keys.as_ref().expect(
|
||||
"Fatal error: attempted to relinearize without relinearization keys.",
|
||||
);
|
||||
|
||||
let a = get_ciphertext(&data, a_id.index());
|
||||
let input = get_unary_operand(ir, index);
|
||||
|
||||
let a = get_ciphertext(&data, input.index());
|
||||
|
||||
let c = evaluator.relinearize(&a, relin_keys).unwrap();
|
||||
|
||||
@@ -103,8 +155,10 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
|
||||
Negate => unimplemented!(),
|
||||
Sub => unimplemented!(),
|
||||
Literal(_x) => unimplemented!(),
|
||||
OutputCiphertext(a_id) => {
|
||||
let a = get_ciphertext(&data, a_id.index());
|
||||
OutputCiphertext => {
|
||||
let input = get_unary_operand(ir, index);
|
||||
|
||||
let a = get_ciphertext(&data, input.index());
|
||||
|
||||
data[index.index()].store(Some(Cow::Borrowed(&a)));
|
||||
}
|
||||
@@ -117,8 +171,8 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
|
||||
ir.graph
|
||||
.node_indices()
|
||||
.filter_map(|id| match ir.graph[id].operation {
|
||||
OutputCiphertext(o_id) => {
|
||||
Some(get_ciphertext(&data, o_id.index()).clone().into_owned())
|
||||
OutputCiphertext => {
|
||||
Some(get_ciphertext(&data, id.index()).clone().into_owned())
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user