From c27f471999bd1962b3ca0435e0ce6b032d9ec864 Mon Sep 17 00:00:00 2001 From: Edward Chen Date: Wed, 10 Aug 2022 04:07:24 -0400 Subject: [PATCH] added ITE optimization --- examples/C/mpc/playground.c | 6 ++--- examples/circ.rs | 1 + scripts/build_mpc_c_test.zsh | 2 +- src/ir/opt/ite.rs | 48 ++++++++++++++++++++++++++++++++++++ src/ir/opt/mod.rs | 6 +++++ 5 files changed, 59 insertions(+), 4 deletions(-) create mode 100644 src/ir/opt/ite.rs diff --git a/examples/C/mpc/playground.c b/examples/C/mpc/playground.c index 9582786c..e75de495 100644 --- a/examples/C/mpc/playground.c +++ b/examples/C/mpc/playground.c @@ -1,7 +1,7 @@ int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b) { - int c[1]; if (a < b) { - c[0] = 1; + return 1; + } else { + return 2; } - return c[0]; } \ No newline at end of file diff --git a/examples/circ.rs b/examples/circ.rs index 8a8be75f..23f6523b 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -254,6 +254,7 @@ fn main() { // The linear scan pass produces more tuples, that must be eliminated Opt::Tuple, Opt::ConstantFold(Box::new(ignore.clone())), + Opt::Ite, // Inline Function Calls // Opt::Link, // Opt::Tuple, diff --git a/scripts/build_mpc_c_test.zsh b/scripts/build_mpc_c_test.zsh index bb272716..74c13c09 100755 --- a/scripts/build_mpc_c_test.zsh +++ b/scripts/build_mpc_c_test.zsh @@ -36,7 +36,7 @@ function mpc_test_bool { RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "b" } -# mpc_test 2 ./examples/C/mpc/playground.c +# mpc_test 2 ./examples/C/mpc/playground.c # build mpc arithmetic tests mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c diff --git a/src/ir/opt/ite.rs b/src/ir/opt/ite.rs new file mode 100644 index 00000000..05bc6d82 --- /dev/null +++ b/src/ir/opt/ite.rs @@ -0,0 +1,48 @@ +//! Rewrite ITE terms + +use crate::ir::opt::cfold::fold; +use crate::ir::opt::visit::RewritePass; +use crate::ir::term::*; + +/// ITE cache. +#[derive(Default)] +struct IteRewriter; + +impl RewritePass for IteRewriter { + fn visit Vec>( + &mut self, + _computation: &mut Computation, + orig: &Term, + rewritten_children: F, + ) -> Option { + let cs = rewritten_children(); + match &orig.op { + Op::Ite => { + let sel = cs[0].clone(); + let t = cs[1].clone(); + let f = cs[2].clone(); + match f.op { + Op::Ite => { + let child_sel = f.cs[0].clone(); + let child_t = f.cs[1].clone(); + if sel + == term![AND; term![Op::Not; child_sel.clone()], term![Op::Not; child_sel.clone()]] + { + Some(term![Op::Ite; child_sel, child_t, t]) + } else { + None + } + } + _ => None, + } + } + _ => None, + } + } +} + +/// Binarize (expand) n-ary terms. +pub fn rewrite_ites(c: &mut Computation) { + let mut pass = IteRewriter; + pass.traverse(c); +} diff --git a/src/ir/opt/mod.rs b/src/ir/opt/mod.rs index 23b5560b..455a43e1 100644 --- a/src/ir/opt/mod.rs +++ b/src/ir/opt/mod.rs @@ -3,6 +3,7 @@ pub mod binarize; pub mod cfold; pub mod flat; pub mod inline; +pub mod ite; pub mod link; pub mod mem; pub mod scalarize_vars; @@ -38,6 +39,8 @@ pub enum Opt { FlattenAssertions, /// Find outputs like `(= variable term)`, and substitute out `variable` Inline, + /// Ite peephole optimizations + Ite, /// Link function calls Link, /// Eliminate tuples @@ -108,6 +111,9 @@ pub fn opt>(mut fs: Functions, optimizations: I) -> Opt::Binarize => { binarize::binarize(comp); } + Opt::Ite => { + ite::rewrite_ites(comp); + } Opt::Inline => { let public_inputs = comp .metadata