Refactor operand type specification.

This commit is contained in:
Rick Weber
2021-11-21 15:17:53 -08:00
parent e6ca659f7d
commit 799bbcad2c
4 changed files with 126 additions and 54 deletions

View File

@@ -1,17 +1,22 @@
use sunscreen_ir::{IRTransform::*, IntermediateRepresentation, Operation::*, TransformList};
use petgraph::Direction;
use petgraph::{
Direction,
visit::EdgeRef
};
pub fn apply_insert_relinearizations(ir: &mut IntermediateRepresentation) {
ir.forward_traverse(|query, id| match query.get_node(id).operation {
Multiply(_a, _b) => {
Multiply => {
let mut transforms = TransformList::new();
let relin_node = transforms.push(AppendRelinearize(id.into()));
for e in query.get_neighbors(id, Direction::Outgoing) {
transforms.push(RemoveEdge(id.into(), e.into()));
transforms.push(AddEdge(relin_node.into(), e.into()));
for e in query.edges_directed(id, Direction::Outgoing) {
let operand_type = e.weight();
transforms.push(RemoveEdge(id.into(), e.target().into()));
transforms.push(AddEdge(relin_node.into(), e.target().into(), *operand_type));
}
transforms
@@ -58,7 +63,7 @@ mod tests {
.node_indices()
.filter(|i| {
match query.get_node(*i).operation {
Operation::Relinearize(_) => true,
Operation::Relinearize => true,
_ => false
}
})
@@ -71,7 +76,7 @@ mod tests {
assert_eq!(
relin_nodes
.iter()
.all(|id| { query.get_neighbors(*id, Direction::Incoming).count() == 1 }),
.all(|id| { query.neighbors_directed(*id, Direction::Incoming).count() == 1 }),
true
);
@@ -79,10 +84,10 @@ mod tests {
assert_eq!(
relin_nodes.iter().all(|id| {
query
.get_neighbors(*id, Direction::Incoming)
.neighbors_directed(*id, Direction::Incoming)
.map(|id| query.get_node(id))
.all(|node| match node.operation {
Operation::Multiply(_a, _b) => true,
Operation::Multiply => true,
_ => false
})
}),
@@ -92,7 +97,7 @@ mod tests {
// The first relin node should point to add_2
assert_eq!(
query
.get_neighbors(relin_nodes[0], Direction::Outgoing)
.neighbors_directed(relin_nodes[0], Direction::Outgoing)
.count(),
1
);
@@ -100,7 +105,7 @@ mod tests {
// The second relin node should point to nothing.
assert_eq!(
query
.get_neighbors(relin_nodes[1], Direction::Outgoing)
.neighbors_directed(relin_nodes[1], Direction::Outgoing)
.count(),
0
);
@@ -108,10 +113,10 @@ mod tests {
// The first relin node should point to add_2
assert_eq!(
query
.get_neighbors(relin_nodes[0], Direction::Outgoing)
.neighbors_directed(relin_nodes[0], Direction::Outgoing)
.all(|i| {
match query.get_node(i).operation {
Operation::Add(_a, _b) => true,
Operation::Add => true,
_ => false
}
}),

View File

@@ -12,8 +12,9 @@ use petgraph::{
algo::is_isomorphic_matching,
algo::toposort,
algo::tred::*,
Directed,
graph::{Graph, NodeIndex},
stable_graph::{Neighbors, StableGraph},
stable_graph::{Edges, Neighbors, StableGraph},
visit::{IntoNeighbors, IntoNodeIdentifiers},
Direction,
};
@@ -44,16 +45,25 @@ impl NodeInfo {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
/**
* Contains information about an edge between nodes in the circuit graph.
*/
pub struct EdgeInfo;
pub enum EdgeInfo {
/**
* The source node is the left input to a binary operation.
*/
LeftOperand,
impl EdgeInfo {
fn new() -> Self {
Self
}
/**
* The source node is the right input to fa binary operation.
*/
RightOperand,
/**
* The source node is the single input to a unary operation.
*/
UnaryOperand,
}
type IRGraph = StableGraph<NodeInfo, EdgeInfo>;
@@ -105,8 +115,8 @@ impl IntermediateRepresentation {
) -> NodeIndex {
let new_node = self.graph.add_node(NodeInfo::new(operation));
self.graph.update_edge(x, new_node, EdgeInfo::new());
self.graph.update_edge(y, new_node, EdgeInfo::new());
self.graph.update_edge(x, new_node, EdgeInfo::LeftOperand);
self.graph.update_edge(y, new_node, EdgeInfo::RightOperand);
new_node
}
@@ -114,7 +124,7 @@ impl IntermediateRepresentation {
fn append_1_input_node(&mut self, operation: Operation, x: NodeIndex) -> NodeIndex {
let new_node = self.graph.add_node(NodeInfo::new(operation));
self.graph.update_edge(x, new_node, EdgeInfo::new());
self.graph.update_edge(x, new_node, EdgeInfo::UnaryOperand);
new_node
}
@@ -136,14 +146,14 @@ impl IntermediateRepresentation {
* Appends a multiply operation that depends on the operands `x` and `y`.
*/
pub fn append_multiply(&mut self, x: NodeIndex, y: NodeIndex) -> NodeIndex {
self.append_2_input_node(Operation::Multiply(x, y), x, y)
self.append_2_input_node(Operation::Multiply, x, y)
}
/**
* Appends an add operation that depends on the operands `x` and `y`.
*/
pub fn append_add(&mut self, x: NodeIndex, y: NodeIndex) -> NodeIndex {
self.append_2_input_node(Operation::Add(x, y), x, y)
self.append_2_input_node(Operation::Add, x, y)
}
/**
@@ -173,14 +183,14 @@ impl IntermediateRepresentation {
* Sppends a node designating `x` as an output of the circuit.
*/
pub fn append_output_ciphertext(&mut self, x: NodeIndex) -> NodeIndex {
self.append_1_input_node(Operation::OutputCiphertext(x), x)
self.append_1_input_node(Operation::OutputCiphertext, x)
}
/**
* Appends an operation that relinearizes `x`.
*/
pub fn append_relinearize(&mut self, x: NodeIndex) -> NodeIndex {
self.append_1_input_node(Operation::Relinearize(x), x)
self.append_1_input_node(Operation::Relinearize, x)
}
/**
@@ -325,7 +335,7 @@ impl IntermediateRepresentation {
.node_indices()
.filter(|g| {
match self.graph[*g].operation {
Operation::OutputCiphertext(_) => true,
Operation::OutputCiphertext => true,
_ => false
}
})
@@ -426,9 +436,13 @@ impl<'a> GraphQuery<'a> {
* Typically, you want children writing forward traversal compiler passes and
* parents when writing reverse traversal compiler passes.
*/
pub fn get_neighbors(&self, x: NodeIndex, direction: Direction) -> Neighbors<EdgeInfo> {
pub fn neighbors_directed(&self, x: NodeIndex, direction: Direction) -> Neighbors<EdgeInfo> {
self.0.graph.neighbors_directed(x, direction)
}
}
pub fn edges_directed(&self, x: NodeIndex, direction: Direction) -> Edges<EdgeInfo, Directed> {
self.0.graph.edges_directed(x, direction)
}
}
#[derive(Debug, Clone)]
@@ -491,7 +505,7 @@ pub enum IRTransform {
/**
* Add a graph edge between two nodes.
*/
AddEdge(TransformNodeIndex, TransformNodeIndex),
AddEdge(TransformNodeIndex, TransformNodeIndex, EdgeInfo),
}
/**
@@ -613,11 +627,11 @@ impl TransformList {
None
}
AddEdge(x, y) => {
AddEdge(x, y, edge_info) => {
let x = self.materialize_index(*x);
let y = self.materialize_index(*y);
ir.graph.update_edge(x, y, EdgeInfo::new());
ir.graph.update_edge(x, y, *edge_info);
None
}
@@ -699,12 +713,12 @@ mod tests {
nodes[1].1.operation,
Operation::Literal(OuterLiteral::from(7i64))
);
assert_eq!(nodes[2].1.operation, Operation::Add(NodeIndex::from(0), NodeIndex::from(1)));
assert_eq!(nodes[2].1.operation, Operation::Add);
assert_eq!(
nodes[3].1.operation,
Operation::Literal(OuterLiteral::from(5u64))
);
assert_eq!(nodes[4].1.operation, Operation::Multiply(NodeIndex::from(2), NodeIndex::from(3)));
assert_eq!(nodes[4].1.operation, Operation::Multiply);
assert_eq!(
ir.graph

View File

@@ -1,4 +1,3 @@
use petgraph::stable_graph::NodeIndex;
use serde::{Deserialize, Serialize};
use crate::OuterLiteral;
@@ -33,17 +32,17 @@ pub enum Operation {
* a multiplication operation by reducing the resultant 3xN ciphertext down to
* the cannonical 2xN.
*/
Relinearize(NodeIndex),
Relinearize,
/**
* Multiply two values. Either operand may be a literal or a ciphertext.
*/
Multiply(NodeIndex, NodeIndex),
Multiply,
/**
* Add two values. Either operand may be a literal or a ciphertext.
*/
Add(NodeIndex, NodeIndex),
Add,
/**
* Computes the additive inverse of a plaintext or ciphertext.
@@ -69,5 +68,5 @@ pub enum Operation {
/**
* Represents a ciphertext output for the circuit.
*/
OutputCiphertext(NodeIndex),
OutputCiphertext,
}

View File

@@ -4,15 +4,61 @@
//! This crate contains the types and functions for executing a Sunscreen circuit
//! (i.e. an [`IntermediateRepresentation`](sunscreen_ir::IntermediateRepresentation)).
use sunscreen_ir::{IntermediateRepresentation, Operation::*};
use sunscreen_ir::{EdgeInfo, IntermediateRepresentation, Operation::*};
use crossbeam::atomic::AtomicCell;
use petgraph::{stable_graph::NodeIndex, Direction};
use petgraph::{stable_graph::{NodeIndex}, Direction, visit::{EdgeRef}};
use seal::{Ciphertext, Evaluator, RelinearizationKeys};
use std::borrow::Cow;
use std::sync::atomic::{AtomicUsize, Ordering};
/**
* Gets the two input operands and returns a tuple of left, right. For some operations
* (i.e. subtraction), order matters. While it's erroneous for a binary operations to have
* anything other than a single left and single right operand, having more operands will result
* in one being selected arbitrarily. Validating the [`IntermediateRepresentation`] will
* reveal having the wrong number of operands.
*
* # Panics
* Panics if the given node doesn't have at least one left and one right operand. Calling
* [`validate()`](sunscreen_ir::IntermediateRepresentation::validate()) should reveal this
* issue.
*/
pub fn get_left_right_operands(ir: &IntermediateRepresentation, index: NodeIndex) -> (NodeIndex, NodeIndex) {
let left = ir.graph.edges_directed(index, Direction::Incoming)
.filter(|e| *e.weight() == EdgeInfo::LeftOperand)
.map(|e| e.source())
.nth(0)
.unwrap();
let right = ir.graph.edges_directed(index, Direction::Incoming)
.filter(|e| *e.weight() == EdgeInfo::RightOperand)
.map(|e| e.source())
.nth(0)
.unwrap();
(left, right)
}
/**
* Gets the single unary input operand for the given node. If the [`IntermediateRepresentation`]
* is malformed and the node has more than one UnaryOperand, one will be selected arbitrarily.
* As such, one should validate the [`IntermediateRepresentation`] before calling this method.
*
* # Panics
* Panics if the given node doesn't have at least one unary operant. Calling
* [`validate()`](sunscreen_ir::IntermediateRepresentation::validate()) should reveal this
* issue.
*/
pub fn get_unary_operand(ir: &IntermediateRepresentation, index: NodeIndex) -> NodeIndex {
ir.graph.edges_directed(index, Direction::Incoming)
.filter(|e| *e.weight() == EdgeInfo::UnaryOperand)
.map(|e| e.source())
.nth(0)
.unwrap()
}
/**
* Run the given [`IntermediateRepresentation`] to completion with the given inputs. This
* method performs no validation. You must verify the program is first valid. Programs produced
@@ -72,29 +118,35 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
}
ShiftLeft => unimplemented!(),
ShiftRight => unimplemented!(),
Add(a_id, b_id) => {
let a = get_ciphertext(&data, a_id.index());
let b = get_ciphertext(&data, b_id.index());
Add => {
let (left, right) = get_left_right_operands(ir, index);
let a = get_ciphertext(&data, left.index());
let b = get_ciphertext(&data, right.index());
let c = evaluator.add(&a, &b).unwrap();
data[index.index()].store(Some(Cow::Owned(c)));
}
Multiply(a_id, b_id) => {
let a = get_ciphertext(&data, a_id.index());
let b = get_ciphertext(&data, b_id.index());
Multiply => {
let (left, right) = get_left_right_operands(ir, index);
let a = get_ciphertext(&data, left.index());
let b = get_ciphertext(&data, right.index());
let c = evaluator.multiply(&a, &b).unwrap();
data[index.index()].store(Some(Cow::Owned(c)));
}
SwapRows => unimplemented!(),
Relinearize(a_id) => {
Relinearize => {
let relin_keys = relin_keys.as_ref().expect(
"Fatal error: attempted to relinearize without relinearization keys.",
);
let a = get_ciphertext(&data, a_id.index());
let input = get_unary_operand(ir, index);
let a = get_ciphertext(&data, input.index());
let c = evaluator.relinearize(&a, relin_keys).unwrap();
@@ -103,8 +155,10 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
Negate => unimplemented!(),
Sub => unimplemented!(),
Literal(_x) => unimplemented!(),
OutputCiphertext(a_id) => {
let a = get_ciphertext(&data, a_id.index());
OutputCiphertext => {
let input = get_unary_operand(ir, index);
let a = get_ciphertext(&data, input.index());
data[index.index()].store(Some(Cow::Borrowed(&a)));
}
@@ -117,8 +171,8 @@ pub unsafe fn run_program_unchecked<E: Evaluator + Sync + Send>(
ir.graph
.node_indices()
.filter_map(|id| match ir.graph[id].operation {
OutputCiphertext(o_id) => {
Some(get_ciphertext(&data, o_id.index()).clone().into_owned())
OutputCiphertext => {
Some(get_ciphertext(&data, id.index()).clone().into_owned())
}
_ => None,
})