mirror of
https://github.com/privacy-scaling-explorations/chiquito.git
synced 2026-01-10 06:28:06 -05:00
Simplify reduce
This commit is contained in:
@@ -14,12 +14,13 @@ pub fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
|
||||
signal_factory: &mut SF,
|
||||
) -> (Expr<F, V>, ConstrDecomp<F, V>) {
|
||||
let mut decomp = ConstrDecomp::default();
|
||||
let expr = reduce_degree_recursive(&mut decomp, constr, max_degree, max_degree, signal_factory);
|
||||
let expr = reduce_degree_recursive(&mut decomp, constr, max_degree, signal_factory);
|
||||
|
||||
(expr, decomp)
|
||||
}
|
||||
|
||||
/// Actual recursive implementation of `reduce_degre`. Key here to understand the difference
|
||||
// TODO: Update doc
|
||||
/// Actual recursive implementation of `reduce_degree`. Key here to understand the difference
|
||||
/// between: + total_max_degree: maximum degree of the PI the input expression is decomposed of.
|
||||
/// + partial_max_degree: maximum degree of the root PI, that can substitute the orginal
|
||||
/// expression.
|
||||
@@ -41,11 +42,10 @@ fn reduce_degree_recursive<
|
||||
>(
|
||||
decomp: &mut ConstrDecomp<F, V>,
|
||||
constr: Expr<F, V>,
|
||||
total_max_degree: usize,
|
||||
partial_max_degree: usize,
|
||||
max_degree: usize,
|
||||
signal_factory: &mut SF,
|
||||
) -> Expr<F, V> {
|
||||
if constr.degree() <= partial_max_degree {
|
||||
if constr.degree() <= max_degree {
|
||||
return constr;
|
||||
}
|
||||
|
||||
@@ -53,37 +53,21 @@ fn reduce_degree_recursive<
|
||||
Expr::Const(_) => constr,
|
||||
Expr::Sum(ses) => Expr::Sum(
|
||||
ses.into_iter()
|
||||
.map(|se| {
|
||||
reduce_degree_recursive(
|
||||
decomp,
|
||||
se,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
signal_factory,
|
||||
)
|
||||
})
|
||||
.map(|se| reduce_degree_recursive(decomp, se, max_degree, signal_factory))
|
||||
.collect(),
|
||||
),
|
||||
Expr::Mul(ses) => reduce_degree_mul(
|
||||
decomp,
|
||||
ses,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
signal_factory,
|
||||
),
|
||||
Expr::Mul(ses) => reduce_degree_mul(decomp, ses, max_degree, signal_factory),
|
||||
Expr::Neg(se) => Expr::Neg(Box::new(reduce_degree_recursive(
|
||||
decomp,
|
||||
*se,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
max_degree,
|
||||
signal_factory,
|
||||
))),
|
||||
// TODO: decompose in Pow expressions instead of Mul
|
||||
Expr::Pow(se, exp) => reduce_degree_mul(
|
||||
decomp,
|
||||
std::vec::from_elem(*se, exp as usize),
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
max_degree,
|
||||
signal_factory,
|
||||
),
|
||||
Expr::Query(_) => constr,
|
||||
@@ -91,89 +75,46 @@ fn reduce_degree_recursive<
|
||||
Expr::MI(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
|
||||
decomp: &mut ConstrDecomp<F, V>,
|
||||
ses: Vec<Expr<F, V>>,
|
||||
total_max_degree: usize,
|
||||
partial_max_degree: usize,
|
||||
max_degree: usize,
|
||||
signal_factory: &mut SF,
|
||||
) -> Expr<F, V> {
|
||||
// base case, if partial_max_degree == 1, the root expresion can only be a variable
|
||||
if partial_max_degree == 1 {
|
||||
let reduction = reduce_degree_mul(
|
||||
assert!(max_degree > 1);
|
||||
let mut ses = simplify_mul(ses);
|
||||
// Sort ses by degree so that we first pick higher degree multiplicands
|
||||
ses.sort_by(|a, b| a.degree().cmp(&b.degree()));
|
||||
let mut tail = Vec::new();
|
||||
// Remove multiplicands until ses is degree `max_degree-1`.
|
||||
while ses.iter().map(|se| se.degree()).sum::<usize>() > max_degree - 1 {
|
||||
tail.push(ses.pop().expect("ses.len() > 0"));
|
||||
}
|
||||
if tail.len() == 0 {
|
||||
// Input expression is below max_degree
|
||||
Expr::Mul(ses)
|
||||
} else if tail.len() == 1 && ses.len() == 0 {
|
||||
// Input expression contains a single multiplicand with degree > 1, unwrap it and recurse.
|
||||
reduce_degree_recursive(
|
||||
decomp,
|
||||
ses,
|
||||
total_max_degree,
|
||||
total_max_degree,
|
||||
tail.pop().expect("tail.len() == 1"),
|
||||
max_degree,
|
||||
signal_factory,
|
||||
);
|
||||
)
|
||||
} else {
|
||||
// Only one multiplicand in the tail and it's already degree 1, so no reduction needed.
|
||||
if tail.len() == 1 && tail[0].degree() == 1 {
|
||||
ses.push(tail.pop().expect("tail.len() == 1"));
|
||||
return Expr::Mul(ses);
|
||||
}
|
||||
// Reverse the tail to keep the original expression order
|
||||
tail.reverse();
|
||||
let reduction = reduce_degree_mul(decomp, tail, max_degree, signal_factory);
|
||||
let signal = signal_factory.create("virtual signal");
|
||||
decomp.auto_eq(signal.clone(), reduction);
|
||||
return Expr::Query(signal);
|
||||
ses.push(Expr::Query(signal));
|
||||
Expr::Mul(ses)
|
||||
}
|
||||
|
||||
let ses = simplify_mul(ses);
|
||||
|
||||
// to reduce the problem for recursion, at least one expression should have lower degree than
|
||||
// total_max_degree
|
||||
let mut first = true;
|
||||
let ses_reduced: Vec<Expr<F, V>> = ses
|
||||
.into_iter()
|
||||
.map(|se| {
|
||||
let partial_max_degree = if first {
|
||||
total_max_degree - 1
|
||||
} else {
|
||||
total_max_degree
|
||||
};
|
||||
let reduction = reduce_degree_recursive(
|
||||
decomp,
|
||||
se,
|
||||
total_max_degree,
|
||||
partial_max_degree,
|
||||
signal_factory,
|
||||
);
|
||||
first = false;
|
||||
|
||||
reduction
|
||||
})
|
||||
.collect();
|
||||
|
||||
// for_root will be multipliers that will be included in the root expression
|
||||
let mut for_root = Vec::default();
|
||||
// to_simplify will be multipliers that will be recursively decomposed and subsituted by a
|
||||
// virtual signal in the root expression
|
||||
let mut to_simplify = Vec::default();
|
||||
|
||||
let mut current_degree = 0;
|
||||
for se in ses_reduced {
|
||||
if se.degree() + current_degree < partial_max_degree {
|
||||
current_degree += se.degree();
|
||||
for_root.push(se);
|
||||
} else {
|
||||
to_simplify.push(se);
|
||||
}
|
||||
}
|
||||
|
||||
assert!(!for_root.is_empty());
|
||||
assert!(!to_simplify.is_empty());
|
||||
|
||||
let rest_signal = signal_factory.create("rest_expr");
|
||||
for_root.push(Expr::Query(rest_signal.clone()));
|
||||
let root_expr = Expr::Mul(for_root);
|
||||
|
||||
// recursion, for the part that exceeds the degree and will be substituted by a virtual signal
|
||||
let simplified = reduce_degree_recursive(
|
||||
decomp,
|
||||
Expr::Mul(to_simplify),
|
||||
total_max_degree,
|
||||
total_max_degree,
|
||||
signal_factory,
|
||||
);
|
||||
|
||||
decomp.auto_eq(rest_signal, simplified);
|
||||
|
||||
root_expr
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -181,11 +122,11 @@ mod test {
|
||||
use halo2curves::bn256::Fr;
|
||||
|
||||
use crate::{
|
||||
poly::{reduce::reduce_degree, ConstrDecomp, Expr::*, ToExpr},
|
||||
poly::{ConstrDecomp, Expr::*, ToExpr},
|
||||
sbpir::{query::Queriable, InternalSignal},
|
||||
};
|
||||
|
||||
use super::{reduce_degree_mul, SignalFactory};
|
||||
use super::{reduce_degree, reduce_degree_mul, SignalFactory};
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestSignalFactory {
|
||||
@@ -211,7 +152,7 @@ mod test {
|
||||
&mut decomp,
|
||||
vec![a.expr(), b.expr(), c.expr()],
|
||||
2,
|
||||
2,
|
||||
// 2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
@@ -229,7 +170,7 @@ mod test {
|
||||
&mut decomp,
|
||||
vec![(a + b), (b + c), (a + c)],
|
||||
2,
|
||||
2,
|
||||
// 2,
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
@@ -247,7 +188,7 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_degree() {
|
||||
fn test_reduce_degree_all() {
|
||||
let a: Queriable<Fr> = Queriable::Internal(InternalSignal::new("a"));
|
||||
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
|
||||
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
|
||||
@@ -257,23 +198,23 @@ mod test {
|
||||
let (result, decomp) =
|
||||
reduce_degree(a * b * c * d * e, 2, &mut TestSignalFactory::default());
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "(a * v1)");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v3) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", result), "(a * v3)");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
|
||||
assert_eq!(decomp.constrs.len(), 3);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
|
||||
assert_eq!(decomp.auto_signals.len(), 3);
|
||||
|
||||
let (result, decomp) = reduce_degree(
|
||||
@@ -282,23 +223,23 @@ mod test {
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v1)))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v3) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v3)))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
|
||||
assert_eq!(decomp.constrs.len(), 3);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
|
||||
assert_eq!(decomp.auto_signals.len(), 3);
|
||||
|
||||
let (result, decomp) = reduce_degree(
|
||||
@@ -307,28 +248,28 @@ mod test {
|
||||
&mut TestSignalFactory::default(),
|
||||
);
|
||||
|
||||
assert_eq!(format!("{:#?}", result), "((a * v1) + (-(b * v3)))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((a * a) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((a * v2) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((d * e) + (-v4))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[3]), "((c * v4) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", result), "((a * v2) + (-(b * v4)))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((a * a) + (-v1))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((a * v1) + (-v2))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((d * e) + (-v3))");
|
||||
assert_eq!(format!("{:#?}", decomp.constrs[3]), "((c * v3) + (-v4))");
|
||||
assert_eq!(decomp.constrs.len(), 4);
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * a)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * a)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * v2)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * v1)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (d * e)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
|
||||
assert!(decomp
|
||||
.auto_signals
|
||||
.iter()
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (c * v4)"));
|
||||
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (c * v3)"));
|
||||
assert_eq!(decomp.auto_signals.len(), 4);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user