mirror of
https://github.com/privacy-scaling-explorations/chiquito.git
synced 2026-01-10 06:28:06 -05:00
Simplify reduce
This commit is contained in:
@@ -7,17 +7,7 @@ use crate::{
|
||||
|
||||
use super::{ConstrDecomp, SignalFactory};
|
||||
|
||||
/// Reduces the degree of an PI by decomposing it in many PI with a maximum degree.
|
||||
pub fn reduce_degre<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
|
||||
ctx: &mut ConstrDecomp<F, V>,
|
||||
constr: Expr<F, V>,
|
||||
max_degree: usize,
|
||||
signal_factory: &mut SF,
|
||||
) -> Expr<F, V> {
|
||||
reduce_degree_recursive(ctx, constr, max_degree, max_degree, signal_factory)
|
||||
}
|
||||
|
||||
/// Actual recursive implementation of `reduce_degre`. Key here to understand the difference
|
||||
/// Actual recursive implementation of `reduce_degree`. Key here to understand the difference
|
||||
/// between: + total_max_degree: maximum degree of the PI the input expression is decomposed of.
|
||||
/// + partial_max_degree: maximum degree of the root PI, that can substitute the orginal
|
||||
/// expression.
|
||||
@@ -32,18 +22,13 @@ pub fn reduce_degre<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sign
|
||||
/// total_max_degree
|
||||
/// };
|
||||
/// ```
|
||||
fn reduce_degree_recursive<
|
||||
F: Field,
|
||||
V: Clone + Eq + PartialEq + Hash + Debug,
|
||||
SF: SignalFactory<V>,
|
||||
>(
|
||||
ctx: &mut ConstrDecomp<F, V>,
|
||||
fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
|
||||
decomp: &mut ConstrDecomp<F, V>,
|
||||
constr: Expr<F, V>,
|
||||
total_max_degree: usize,
|
||||
partial_max_degree: usize,
|
||||
max_degree: usize,
|
||||
signal_factory: &mut SF,
|
||||
) -> Expr<F, V> {
|
||||
if constr.degree() <= partial_max_degree {
|
||||
if constr.degree() <= max_degree {
|
||||
return constr;
|
||||
}
|
||||
|
||||
@@ -51,37 +36,21 @@ fn reduce_degree_recursive<
|
||||
Expr::Const(_) => constr,
|
||||
Expr::Sum(ses) => Expr::Sum(
|
||||
ses.into_iter()
|
||||
.map(|se| {
|
||||
reduce_degree_recursive(
|
||||
ctx,
|
||||
se,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
signal_factory,
|
||||
)
|
||||
})
|
||||
.map(|se| reduce_degree(decomp, se, max_degree, signal_factory))
|
||||
.collect(),
|
||||
),
|
||||
Expr::Mul(ses) => reduce_degree_mul(
|
||||
ctx,
|
||||
ses,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
signal_factory,
|
||||
),
|
||||
Expr::Neg(se) => Expr::Neg(Box::new(reduce_degree_recursive(
|
||||
ctx,
|
||||
Expr::Mul(ses) => reduce_degree_mul(decomp, ses, max_degree, signal_factory),
|
||||
Expr::Neg(se) => Expr::Neg(Box::new(reduce_degree(
|
||||
decomp,
|
||||
*se,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
max_degree,
|
||||
signal_factory,
|
||||
))),
|
||||
// TODO: decompose in Pow expressions instead of Mul
|
||||
Expr::Pow(se, exp) => reduce_degree_mul(
|
||||
ctx,
|
||||
decomp,
|
||||
std::vec::from_elem(*se, exp as usize),
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
max_degree,
|
||||
signal_factory,
|
||||
),
|
||||
Expr::Query(_) => constr,
|
||||
@@ -89,84 +58,43 @@ fn reduce_degree_recursive<
|
||||
Expr::MI(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
|
||||
ctx: &mut ConstrDecomp<F, V>,
|
||||
ses: Vec<Expr<F, V>>,
|
||||
total_max_degree: usize,
|
||||
partial_max_degree: usize,
|
||||
decomp: &mut ConstrDecomp<F, V>,
|
||||
mut ses: Vec<Expr<F, V>>,
|
||||
max_degree: usize,
|
||||
signal_factory: &mut SF,
|
||||
) -> Expr<F, V> {
|
||||
// base case, if partial_max_degree == 1, the root expresion can only be a variable
|
||||
if partial_max_degree == 1 {
|
||||
let reduction =
|
||||
reduce_degree_mul(ctx, ses, total_max_degree, total_max_degree, signal_factory);
|
||||
let signal = signal_factory.create("virtual signal");
|
||||
ctx.auto_eq(signal.clone(), reduction);
|
||||
return Expr::Query(signal);
|
||||
assert!(max_degree > 1);
|
||||
let mut tail = Vec::new();
|
||||
// Remove multiplicands until ses is degree `max_degree-1`.
|
||||
while ses.iter().map(|se| se.degree()).sum::<usize>() > max_degree - 1 {
|
||||
tail.push(ses.pop().expect("ses.len() > 0"));
|
||||
}
|
||||
|
||||
let ses = simplify_mul(ses);
|
||||
|
||||
// to reduce the problem for recursion, at least one expression should have lower degree than
|
||||
// total_max_degree
|
||||
let mut first = true;
|
||||
let ses_reduced: Vec<Expr<F, V>> = ses
|
||||
.into_iter()
|
||||
.map(|se| {
|
||||
let partial_max_degree = if first {
|
||||
total_max_degree - 1
|
||||
} else {
|
||||
total_max_degree
|
||||
};
|
||||
let reduction = reduce_degree_recursive(
|
||||
ctx,
|
||||
se,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
signal_factory,
|
||||
);
|
||||
first = false;
|
||||
|
||||
reduction
|
||||
})
|
||||
.collect();
|
||||
|
||||
// for_root will be multipliers that will be included in the root expression
|
||||
let mut for_root = Vec::new();
|
||||
// to_simplify will be multipliers that will be recursively decomposed and subsituted by a
|
||||
// virtual signal in the root expression
|
||||
let mut to_simplify = Vec::new();
|
||||
|
||||
let mut current_degree = 0;
|
||||
for se in ses_reduced {
|
||||
if se.degree() + current_degree < partial_max_degree {
|
||||
current_degree += se.degree();
|
||||
for_root.push(se);
|
||||
} else {
|
||||
to_simplify.push(se);
|
||||
if tail.len() == 0 {
|
||||
// Input expression is below max_degree
|
||||
Expr::Mul(ses)
|
||||
} else if tail.len() == 1 && ses.len() == 0 {
|
||||
// Input expression contains a single multiplicand with degree > 1, unwrap it and recurse.
|
||||
reduce_degree(
|
||||
decomp,
|
||||
tail.pop().expect("tail.len() == 1"),
|
||||
max_degree,
|
||||
signal_factory,
|
||||
)
|
||||
} else {
|
||||
// Only one multiplicand in the tail and it's already degree 1, so no reduction needed.
|
||||
if tail.len() == 1 && tail[0].degree() == 1 {
|
||||
ses.push(tail.pop().expect("tail.len() == 1"));
|
||||
return Expr::Mul(ses);
|
||||
}
|
||||
// Reverse the tail to keep the original expression order
|
||||
tail.reverse();
|
||||
let reduction = reduce_degree_mul(decomp, tail, max_degree, signal_factory);
|
||||
let signal = signal_factory.create("virtual signal");
|
||||
decomp.auto_eq(signal.clone(), reduction);
|
||||
ses.push(Expr::Query(signal));
|
||||
Expr::Mul(ses)
|
||||
}
|
||||
|
||||
assert!(!for_root.is_empty());
|
||||
assert!(!to_simplify.is_empty());
|
||||
|
||||
let rest_signal = signal_factory.create("rest_expr");
|
||||
for_root.push(Expr::Query(rest_signal.clone()));
|
||||
let root_expr = Expr::Mul(for_root);
|
||||
|
||||
// recursion, for the part that exceeds the degree and will be substituted by a virtual signal
|
||||
let simplified = reduce_degree_recursive(
|
||||
ctx,
|
||||
Expr::Mul(to_simplify),
|
||||
total_max_degree,
|
||||
total_max_degree,
|
||||
signal_factory,
|
||||
);
|
||||
|
||||
ctx.auto_eq(rest_signal, simplified);
|
||||
|
||||
root_expr
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -178,7 +106,7 @@ mod test {
|
||||
sbpir::{query::Queriable, InternalSignal},
|
||||
};
|
||||
|
||||
use super::{reduce_degre, reduce_degree_mul, SignalFactory};
|
||||
use super::{reduce_degree, reduce_degree_mul, SignalFactory};
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestSignalFactory {
|
||||
@@ -199,138 +127,138 @@ mod test {
|
||||
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
|
||||
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
|
||||
|
||||
let mut ctx = ConstrDecomp::new();
|
||||
let mut decomp = ConstrDecomp::new();
|
||||
let result = reduce_degree_mul(
|
||||
&mut ctx,
|
||||
&mut decomp,
|
||||
vec![a.expr(), b.expr(), c.expr()],
|
||||
2,
|
||||
2,
|
||||
// 2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "(a * v1)");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((b * c) + (-v1))");
|
||||
assert_eq!(ctx.constrs.len(), 1);
|
||||
assert!(ctx
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((b * c) + (-v1))");
|
||||
assert_eq!(decomp.constrs.len(), 1);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * c)"));
|
||||
assert_eq!(ctx.auto_signals.len(), 1);
|
||||
assert_eq!(decomp.auto_signals.len(), 1);
|
||||
|
||||
let mut ctx = ConstrDecomp::new();
|
||||
let mut decomp = ConstrDecomp::new();
|
||||
let result = reduce_degree_mul(
|
||||
&mut ctx,
|
||||
&mut decomp,
|
||||
vec![(a + b), (b + c), (a + c)],
|
||||
2,
|
||||
2,
|
||||
// 2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "((a + b) * v1)");
|
||||
assert_eq!(
|
||||
format!("{:#?}", ctx.constrs[0]),
|
||||
format!("{:#?}", decomp.constrs[0]),
|
||||
"(((b + c) * (a + c)) + (-v1))"
|
||||
);
|
||||
assert_eq!(ctx.constrs.len(), 1);
|
||||
assert!(ctx
|
||||
assert_eq!(decomp.constrs.len(), 1);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: ((b + c) * (a + c))"));
|
||||
assert_eq!(ctx.auto_signals.len(), 1);
|
||||
assert_eq!(decomp.auto_signals.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_degree() {
|
||||
fn test_reduce_degree_all() {
|
||||
let a: Queriable<Fr> = Queriable::Internal(InternalSignal::new("a"));
|
||||
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
|
||||
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
|
||||
let d: Queriable<Fr> = Queriable::Internal(InternalSignal::new("d"));
|
||||
let e: Queriable<Fr> = Queriable::Internal(InternalSignal::new("e"));
|
||||
|
||||
let mut ctx = ConstrDecomp::new();
|
||||
let result = reduce_degre(
|
||||
&mut ctx,
|
||||
let mut decomp = ConstrDecomp::new();
|
||||
let result = reduce_degree(
|
||||
&mut decomp,
|
||||
a * b * c * d * e,
|
||||
2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "(a * v1)");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((d * e) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[1]), "((c * v3) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[2]), "((b * v2) + (-v1))");
|
||||
assert_eq!(ctx.constrs.len(), 3);
|
||||
assert!(ctx
|
||||
assert_eq!(format!("{:#?}", result), "(a * v3)");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
|
||||
assert_eq!(decomp.constrs.len(), 3);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
|
||||
assert_eq!(ctx.auto_signals.len(), 3);
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
|
||||
assert_eq!(decomp.auto_signals.len(), 3);
|
||||
|
||||
let mut ctx = ConstrDecomp::new();
|
||||
let result = reduce_degre(
|
||||
&mut ctx,
|
||||
let mut decomp = ConstrDecomp::new();
|
||||
let result = reduce_degree(
|
||||
&mut decomp,
|
||||
1.expr() - (a * b * c * d * e),
|
||||
2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v1)))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((d * e) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[1]), "((c * v3) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[2]), "((b * v2) + (-v1))");
|
||||
assert_eq!(ctx.constrs.len(), 3);
|
||||
assert!(ctx
|
||||
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v3)))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
|
||||
assert_eq!(decomp.constrs.len(), 3);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
|
||||
assert_eq!(ctx.auto_signals.len(), 3);
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
|
||||
assert_eq!(decomp.auto_signals.len(), 3);
|
||||
|
||||
let mut ctx = ConstrDecomp::new();
|
||||
let result = reduce_degre(
|
||||
&mut ctx,
|
||||
let mut decomp = ConstrDecomp::new();
|
||||
let result = reduce_degree(
|
||||
&mut decomp,
|
||||
Pow(Box::new(a.expr()), 4) - (b * c * d * e),
|
||||
2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "((a * v1) + (-(b * v3)))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((a * a) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[1]), "((a * v2) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[2]), "((d * e) + (-v4))");
|
||||
assert_eq!(format!("{:#?}", ctx.constrs[3]), "((c * v4) + (-v3))");
|
||||
assert_eq!(ctx.constrs.len(), 4);
|
||||
assert!(ctx
|
||||
assert_eq!(format!("{:#?}", result), "((a * v2) + (-(b * v4)))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((a * a) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((a * v1) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((d * e) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[3]), "((c * v3) + (-v4))");
|
||||
assert_eq!(decomp.constrs.len(), 4);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * a)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * a)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * v2)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * v1)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (d * e)"));
|
||||
assert!(ctx
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (c * v4)"));
|
||||
assert_eq!(ctx.auto_signals.len(), 4);
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (c * v3)"));
|
||||
assert_eq!(decomp.auto_signals.len(), 4);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user