Simplify reduce

This commit is contained in:
Eduard S
2024-01-12 13:02:34 +00:00
parent ac6b6fcebb
commit 2ac7d77dd2

View File

@@ -14,12 +14,13 @@ pub fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
signal_factory: &mut SF,
) -> (Expr<F, V>, ConstrDecomp<F, V>) {
let mut decomp = ConstrDecomp::default();
let expr = reduce_degree_recursive(&mut decomp, constr, max_degree, max_degree, signal_factory);
let expr = reduce_degree_recursive(&mut decomp, constr, max_degree, signal_factory);
(expr, decomp)
}
/// Actual recursive implementation of `reduce_degre`. Key here to understand the difference
// TODO: Update doc
/// Actual recursive implementation of `reduce_degree`. Key here to understand the difference
/// between: + total_max_degree: maximum degree of the PI the input expression is decomposed of.
/// + partial_max_degree: maximum degree of the root PI, that can substitute the orginal
/// expression.
@@ -41,11 +42,10 @@ fn reduce_degree_recursive<
>(
decomp: &mut ConstrDecomp<F, V>,
constr: Expr<F, V>,
total_max_degree: usize,
partial_max_degree: usize,
max_degree: usize,
signal_factory: &mut SF,
) -> Expr<F, V> {
if constr.degree() <= partial_max_degree {
if constr.degree() <= max_degree {
return constr;
}
@@ -53,37 +53,21 @@ fn reduce_degree_recursive<
Expr::Const(_) => constr,
Expr::Sum(ses) => Expr::Sum(
ses.into_iter()
.map(|se| {
reduce_degree_recursive(
decomp,
se,
total_max_degree,
partial_max_degree,
signal_factory,
)
})
.map(|se| reduce_degree_recursive(decomp, se, max_degree, signal_factory))
.collect(),
),
Expr::Mul(ses) => reduce_degree_mul(
decomp,
ses,
total_max_degree,
partial_max_degree,
signal_factory,
),
Expr::Mul(ses) => reduce_degree_mul(decomp, ses, max_degree, signal_factory),
Expr::Neg(se) => Expr::Neg(Box::new(reduce_degree_recursive(
decomp,
*se,
total_max_degree,
partial_max_degree,
max_degree,
signal_factory,
))),
// TODO: decompose in Pow expressions instead of Mul
Expr::Pow(se, exp) => reduce_degree_mul(
decomp,
std::vec::from_elem(*se, exp as usize),
total_max_degree,
partial_max_degree,
max_degree,
signal_factory,
),
Expr::Query(_) => constr,
@@ -91,89 +75,46 @@ fn reduce_degree_recursive<
Expr::MI(_) => unimplemented!(),
}
}
fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
decomp: &mut ConstrDecomp<F, V>,
ses: Vec<Expr<F, V>>,
total_max_degree: usize,
partial_max_degree: usize,
max_degree: usize,
signal_factory: &mut SF,
) -> Expr<F, V> {
// base case, if partial_max_degree == 1, the root expresion can only be a variable
if partial_max_degree == 1 {
let reduction = reduce_degree_mul(
assert!(max_degree > 1);
let mut ses = simplify_mul(ses);
// Sort ses by degree so that we first pick higher degree multiplicands
ses.sort_by(|a, b| a.degree().cmp(&b.degree()));
let mut tail = Vec::new();
// Remove multiplicands until ses is degree `max_degree-1`.
while ses.iter().map(|se| se.degree()).sum::<usize>() > max_degree - 1 {
tail.push(ses.pop().expect("ses.len() > 0"));
}
if tail.len() == 0 {
// Input expression is below max_degree
Expr::Mul(ses)
} else if tail.len() == 1 && ses.len() == 0 {
// Input expression contains a single multiplicand with degree > 1, unwrap it and recurse.
reduce_degree_recursive(
decomp,
ses,
total_max_degree,
total_max_degree,
tail.pop().expect("tail.len() == 1"),
max_degree,
signal_factory,
);
)
} else {
// Only one multiplicand in the tail and it's already degree 1, so no reduction needed.
if tail.len() == 1 && tail[0].degree() == 1 {
ses.push(tail.pop().expect("tail.len() == 1"));
return Expr::Mul(ses);
}
// Reverse the tail to keep the original expression order
tail.reverse();
let reduction = reduce_degree_mul(decomp, tail, max_degree, signal_factory);
let signal = signal_factory.create("virtual signal");
decomp.auto_eq(signal.clone(), reduction);
return Expr::Query(signal);
ses.push(Expr::Query(signal));
Expr::Mul(ses)
}
let ses = simplify_mul(ses);
// to reduce the problem for recursion, at least one expression should have lower degree than
// total_max_degree
let mut first = true;
let ses_reduced: Vec<Expr<F, V>> = ses
.into_iter()
.map(|se| {
let partial_max_degree = if first {
total_max_degree - 1
} else {
total_max_degree
};
let reduction = reduce_degree_recursive(
decomp,
se,
total_max_degree,
partial_max_degree,
signal_factory,
);
first = false;
reduction
})
.collect();
// for_root will be multipliers that will be included in the root expression
let mut for_root = Vec::default();
// to_simplify will be multipliers that will be recursively decomposed and subsituted by a
// virtual signal in the root expression
let mut to_simplify = Vec::default();
let mut current_degree = 0;
for se in ses_reduced {
if se.degree() + current_degree < partial_max_degree {
current_degree += se.degree();
for_root.push(se);
} else {
to_simplify.push(se);
}
}
assert!(!for_root.is_empty());
assert!(!to_simplify.is_empty());
let rest_signal = signal_factory.create("rest_expr");
for_root.push(Expr::Query(rest_signal.clone()));
let root_expr = Expr::Mul(for_root);
// recursion, for the part that exceeds the degree and will be substituted by a virtual signal
let simplified = reduce_degree_recursive(
decomp,
Expr::Mul(to_simplify),
total_max_degree,
total_max_degree,
signal_factory,
);
decomp.auto_eq(rest_signal, simplified);
root_expr
}
#[cfg(test)]
@@ -181,11 +122,11 @@ mod test {
use halo2curves::bn256::Fr;
use crate::{
poly::{reduce::reduce_degree, ConstrDecomp, Expr::*, ToExpr},
poly::{ConstrDecomp, Expr::*, ToExpr},
sbpir::{query::Queriable, InternalSignal},
};
use super::{reduce_degree_mul, SignalFactory};
use super::{reduce_degree, reduce_degree_mul, SignalFactory};
#[derive(Default)]
struct TestSignalFactory {
@@ -211,7 +152,7 @@ mod test {
&mut decomp,
vec![a.expr(), b.expr(), c.expr()],
2,
2,
// 2,
&mut TestSignalFactory::default(),
);
@@ -229,7 +170,7 @@ mod test {
&mut decomp,
vec![(a + b), (b + c), (a + c)],
2,
2,
// 2,
&mut TestSignalFactory::default(),
);
@@ -247,7 +188,7 @@ mod test {
}
#[test]
fn test_reduce_degree() {
fn test_reduce_degree_all() {
let a: Queriable<Fr> = Queriable::Internal(InternalSignal::new("a"));
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
@@ -257,23 +198,23 @@ mod test {
let (result, decomp) =
reduce_degree(a * b * c * d * e, 2, &mut TestSignalFactory::default());
assert_eq!(format!("{:#?}", result), "(a * v1)");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v1))");
assert_eq!(format!("{:#?}", result), "(a * v3)");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
assert_eq!(decomp.constrs.len(), 3);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
assert_eq!(decomp.auto_signals.len(), 3);
let (result, decomp) = reduce_degree(
@@ -282,23 +223,23 @@ mod test {
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v1)))");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v1))");
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v3)))");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
assert_eq!(decomp.constrs.len(), 3);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
assert_eq!(decomp.auto_signals.len(), 3);
let (result, decomp) = reduce_degree(
@@ -307,28 +248,28 @@ mod test {
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "((a * v1) + (-(b * v3)))");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((a * a) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((a * v2) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((d * e) + (-v4))");
assert_eq!(format!("{:#?}", decomp.constrs[3]), "((c * v4) + (-v3))");
assert_eq!(format!("{:#?}", result), "((a * v2) + (-(b * v4)))");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((a * a) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((a * v1) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", decomp.constrs[3]), "((c * v3) + (-v4))");
assert_eq!(decomp.constrs.len(), 4);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * a)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * a)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * v2)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * v1)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (d * e)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (c * v4)"));
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (c * v3)"));
assert_eq!(decomp.auto_signals.len(), 4);
}
}