Simplify reduce

This commit is contained in:
Eduard S
2024-01-12 13:02:34 +00:00
parent bea09199be
commit 8f5ceea044

View File

@@ -7,17 +7,7 @@ use crate::{
use super::{ConstrDecomp, SignalFactory};
/// Reduces the degree of an PI by decomposing it in many PI with a maximum degree.
pub fn reduce_degre<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
ctx: &mut ConstrDecomp<F, V>,
constr: Expr<F, V>,
max_degree: usize,
signal_factory: &mut SF,
) -> Expr<F, V> {
reduce_degree_recursive(ctx, constr, max_degree, max_degree, signal_factory)
}
/// Actual recursive implementation of `reduce_degre`. Key here to understand the difference
/// Actual recursive implementation of `reduce_degree`. Key here to understand the difference
/// between: + total_max_degree: maximum degree of the PI the input expression is decomposed of.
/// + partial_max_degree: maximum degree of the root PI, that can substitute the orginal
/// expression.
@@ -32,18 +22,13 @@ pub fn reduce_degre<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sign
/// total_max_degree
/// };
/// ```
fn reduce_degree_recursive<
F: Field,
V: Clone + Eq + PartialEq + Hash + Debug,
SF: SignalFactory<V>,
>(
ctx: &mut ConstrDecomp<F, V>,
fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
decomp: &mut ConstrDecomp<F, V>,
constr: Expr<F, V>,
total_max_degree: usize,
partial_max_degree: usize,
max_degree: usize,
signal_factory: &mut SF,
) -> Expr<F, V> {
if constr.degree() <= partial_max_degree {
if constr.degree() <= max_degree {
return constr;
}
@@ -51,37 +36,21 @@ fn reduce_degree_recursive<
Expr::Const(_) => constr,
Expr::Sum(ses) => Expr::Sum(
ses.into_iter()
.map(|se| {
reduce_degree_recursive(
ctx,
se,
total_max_degree,
partial_max_degree,
signal_factory,
)
})
.map(|se| reduce_degree(decomp, se, max_degree, signal_factory))
.collect(),
),
Expr::Mul(ses) => reduce_degree_mul(
ctx,
ses,
total_max_degree,
partial_max_degree,
signal_factory,
),
Expr::Neg(se) => Expr::Neg(Box::new(reduce_degree_recursive(
ctx,
Expr::Mul(ses) => reduce_degree_mul(decomp, ses, max_degree, signal_factory),
Expr::Neg(se) => Expr::Neg(Box::new(reduce_degree(
decomp,
*se,
total_max_degree,
partial_max_degree,
max_degree,
signal_factory,
))),
// TODO: decompose in Pow expressions instead of Mul
Expr::Pow(se, exp) => reduce_degree_mul(
ctx,
decomp,
std::vec::from_elem(*se, exp as usize),
total_max_degree,
partial_max_degree,
max_degree,
signal_factory,
),
Expr::Query(_) => constr,
@@ -89,84 +58,43 @@ fn reduce_degree_recursive<
Expr::MI(_) => unimplemented!(),
}
}
fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
ctx: &mut ConstrDecomp<F, V>,
ses: Vec<Expr<F, V>>,
total_max_degree: usize,
partial_max_degree: usize,
decomp: &mut ConstrDecomp<F, V>,
mut ses: Vec<Expr<F, V>>,
max_degree: usize,
signal_factory: &mut SF,
) -> Expr<F, V> {
// base case, if partial_max_degree == 1, the root expresion can only be a variable
if partial_max_degree == 1 {
let reduction =
reduce_degree_mul(ctx, ses, total_max_degree, total_max_degree, signal_factory);
let signal = signal_factory.create("virtual signal");
ctx.auto_eq(signal.clone(), reduction);
return Expr::Query(signal);
assert!(max_degree > 1);
let mut tail = Vec::new();
// Remove multiplicands until ses is degree `max_degree-1`.
while ses.iter().map(|se| se.degree()).sum::<usize>() > max_degree - 1 {
tail.push(ses.pop().expect("ses.len() > 0"));
}
let ses = simplify_mul(ses);
// to reduce the problem for recursion, at least one expression should have lower degree than
// total_max_degree
let mut first = true;
let ses_reduced: Vec<Expr<F, V>> = ses
.into_iter()
.map(|se| {
let partial_max_degree = if first {
total_max_degree - 1
} else {
total_max_degree
};
let reduction = reduce_degree_recursive(
ctx,
se,
total_max_degree,
partial_max_degree,
signal_factory,
);
first = false;
reduction
})
.collect();
// for_root will be multipliers that will be included in the root expression
let mut for_root = Vec::new();
// to_simplify will be multipliers that will be recursively decomposed and subsituted by a
// virtual signal in the root expression
let mut to_simplify = Vec::new();
let mut current_degree = 0;
for se in ses_reduced {
if se.degree() + current_degree < partial_max_degree {
current_degree += se.degree();
for_root.push(se);
} else {
to_simplify.push(se);
if tail.len() == 0 {
// Input expression is below max_degree
Expr::Mul(ses)
} else if tail.len() == 1 && ses.len() == 0 {
// Input expression contains a single multiplicand with degree > 1, unwrap it and recurse.
reduce_degree(
decomp,
tail.pop().expect("tail.len() == 1"),
max_degree,
signal_factory,
)
} else {
// Only one multiplicand in the tail and it's already degree 1, so no reduction needed.
if tail.len() == 1 && tail[0].degree() == 1 {
ses.push(tail.pop().expect("tail.len() == 1"));
return Expr::Mul(ses);
}
// Reverse the tail to keep the original expression order
tail.reverse();
let reduction = reduce_degree_mul(decomp, tail, max_degree, signal_factory);
let signal = signal_factory.create("virtual signal");
decomp.auto_eq(signal.clone(), reduction);
ses.push(Expr::Query(signal));
Expr::Mul(ses)
}
assert!(!for_root.is_empty());
assert!(!to_simplify.is_empty());
let rest_signal = signal_factory.create("rest_expr");
for_root.push(Expr::Query(rest_signal.clone()));
let root_expr = Expr::Mul(for_root);
// recursion, for the part that exceeds the degree and will be substituted by a virtual signal
let simplified = reduce_degree_recursive(
ctx,
Expr::Mul(to_simplify),
total_max_degree,
total_max_degree,
signal_factory,
);
ctx.auto_eq(rest_signal, simplified);
root_expr
}
#[cfg(test)]
@@ -178,7 +106,7 @@ mod test {
sbpir::{query::Queriable, InternalSignal},
};
use super::{reduce_degre, reduce_degree_mul, SignalFactory};
use super::{reduce_degree, reduce_degree_mul, SignalFactory};
#[derive(Default)]
struct TestSignalFactory {
@@ -199,138 +127,138 @@ mod test {
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
let mut ctx = ConstrDecomp::new();
let mut decomp = ConstrDecomp::new();
let result = reduce_degree_mul(
&mut ctx,
&mut decomp,
vec![a.expr(), b.expr(), c.expr()],
2,
2,
// 2,
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "(a * v1)");
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((b * c) + (-v1))");
assert_eq!(ctx.constrs.len(), 1);
assert!(ctx
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((b * c) + (-v1))");
assert_eq!(decomp.constrs.len(), 1);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * c)"));
assert_eq!(ctx.auto_signals.len(), 1);
assert_eq!(decomp.auto_signals.len(), 1);
let mut ctx = ConstrDecomp::new();
let mut decomp = ConstrDecomp::new();
let result = reduce_degree_mul(
&mut ctx,
&mut decomp,
vec![(a + b), (b + c), (a + c)],
2,
2,
// 2,
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "((a + b) * v1)");
assert_eq!(
format!("{:#?}", ctx.constrs[0]),
format!("{:#?}", decomp.constrs[0]),
"(((b + c) * (a + c)) + (-v1))"
);
assert_eq!(ctx.constrs.len(), 1);
assert!(ctx
assert_eq!(decomp.constrs.len(), 1);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: ((b + c) * (a + c))"));
assert_eq!(ctx.auto_signals.len(), 1);
assert_eq!(decomp.auto_signals.len(), 1);
}
#[test]
fn test_reduce_degree() {
fn test_reduce_degree_all() {
let a: Queriable<Fr> = Queriable::Internal(InternalSignal::new("a"));
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
let d: Queriable<Fr> = Queriable::Internal(InternalSignal::new("d"));
let e: Queriable<Fr> = Queriable::Internal(InternalSignal::new("e"));
let mut ctx = ConstrDecomp::new();
let result = reduce_degre(
&mut ctx,
let mut decomp = ConstrDecomp::new();
let result = reduce_degree(
&mut decomp,
a * b * c * d * e,
2,
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "(a * v1)");
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", ctx.constrs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", ctx.constrs[2]), "((b * v2) + (-v1))");
assert_eq!(ctx.constrs.len(), 3);
assert!(ctx
assert_eq!(format!("{:#?}", result), "(a * v3)");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
assert_eq!(decomp.constrs.len(), 3);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
assert_eq!(ctx.auto_signals.len(), 3);
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
assert_eq!(decomp.auto_signals.len(), 3);
let mut ctx = ConstrDecomp::new();
let result = reduce_degre(
&mut ctx,
let mut decomp = ConstrDecomp::new();
let result = reduce_degree(
&mut decomp,
1.expr() - (a * b * c * d * e),
2,
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v1)))");
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", ctx.constrs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", ctx.constrs[2]), "((b * v2) + (-v1))");
assert_eq!(ctx.constrs.len(), 3);
assert!(ctx
assert_eq!(format!("{:#?}", result), "(0x1 + (-(a * v3)))");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((d * e) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((c * v1) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((b * v2) + (-v3))");
assert_eq!(decomp.constrs.len(), 3);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (d * e)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v1)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
assert_eq!(ctx.auto_signals.len(), 3);
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (b * v2)"));
assert_eq!(decomp.auto_signals.len(), 3);
let mut ctx = ConstrDecomp::new();
let result = reduce_degre(
&mut ctx,
let mut decomp = ConstrDecomp::new();
let result = reduce_degree(
&mut decomp,
Pow(Box::new(a.expr()), 4) - (b * c * d * e),
2,
&mut TestSignalFactory::default(),
);
assert_eq!(format!("{:#?}", result), "((a * v1) + (-(b * v3)))");
assert_eq!(format!("{:#?}", ctx.constrs[0]), "((a * a) + (-v2))");
assert_eq!(format!("{:#?}", ctx.constrs[1]), "((a * v2) + (-v1))");
assert_eq!(format!("{:#?}", ctx.constrs[2]), "((d * e) + (-v4))");
assert_eq!(format!("{:#?}", ctx.constrs[3]), "((c * v4) + (-v3))");
assert_eq!(ctx.constrs.len(), 4);
assert!(ctx
assert_eq!(format!("{:#?}", result), "((a * v2) + (-(b * v4)))");
assert_eq!(format!("{:#?}", decomp.constrs[0]), "((a * a) + (-v1))");
assert_eq!(format!("{:#?}", decomp.constrs[1]), "((a * v1) + (-v2))");
assert_eq!(format!("{:#?}", decomp.constrs[2]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", decomp.constrs[3]), "((c * v3) + (-v4))");
assert_eq!(decomp.constrs.len(), 4);
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * a)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * a)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * v2)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * v1)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (d * e)"));
assert!(ctx
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
assert!(decomp
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (c * v4)"));
assert_eq!(ctx.auto_signals.len(), 4);
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (c * v3)"));
assert_eq!(decomp.auto_signals.len(), 4);
}
}