mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 03:25:05 -05:00
chore(optimizer): fix formatting
This commit is contained in:
@@ -54,7 +54,8 @@ mod tests {
|
||||
&SHARED_CACHES,
|
||||
p_cut,
|
||||
default_partition,
|
||||
).map(|v| v.1)
|
||||
)
|
||||
.map(|v| v.1)
|
||||
}
|
||||
|
||||
fn optimize_single(dag: &unparametrized::OperationDag) -> Option<Parameters> {
|
||||
@@ -71,15 +72,12 @@ mod tests {
|
||||
if sol_multi.is_none() {
|
||||
return None;
|
||||
}
|
||||
let equiv = sol_mono.best_solution.unwrap().complexity
|
||||
== sol_multi.as_ref().unwrap().complexity;
|
||||
let equiv =
|
||||
sol_mono.best_solution.unwrap().complexity == sol_multi.as_ref().unwrap().complexity;
|
||||
if !equiv {
|
||||
eprintln!("Not same complexity");
|
||||
eprintln!("Single: {:?}", sol_mono.best_solution.unwrap());
|
||||
eprintln!(
|
||||
"Multi: {:?}",
|
||||
sol_multi.clone().unwrap().complexity
|
||||
);
|
||||
eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity);
|
||||
eprintln!("Multi: {:?}", sol_multi.unwrap());
|
||||
}
|
||||
Some(equiv)
|
||||
@@ -136,9 +134,7 @@ mod tests {
|
||||
let feasible_2 = sol_single_2.best_solution.is_some();
|
||||
let feasible_multi = sol_multi.is_some();
|
||||
if (feasible_1 && feasible_2) != feasible_multi {
|
||||
eprintln!(
|
||||
"Not same feasibility {feasible_1} {feasible_2} {feasible_multi}"
|
||||
);
|
||||
eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}");
|
||||
return Some(false);
|
||||
}
|
||||
if sol_multi.is_none() {
|
||||
@@ -150,15 +146,13 @@ mod tests {
|
||||
let cost_1 = sol_single_1.best_solution.unwrap().complexity;
|
||||
let cost_2 = sol_single_2.best_solution.unwrap().complexity;
|
||||
let cost_multi = sol_multi.complexity;
|
||||
let equiv =
|
||||
cost_1 + cost_2 == cost_multi
|
||||
let equiv = cost_1 + cost_2 == cost_multi
|
||||
&& cost_1 == sol_multi_1.complexity
|
||||
&& cost_2 == sol_multi_2.complexity
|
||||
&& sol_multi.micro_params.ks[0][0].unwrap().decomp ==
|
||||
sol_multi_1.micro_params.ks[0][0].unwrap().decomp
|
||||
&& sol_multi.micro_params.ks[1][1].unwrap().decomp ==
|
||||
sol_multi_2.micro_params.ks[0][0].unwrap().decomp
|
||||
;
|
||||
&& sol_multi.micro_params.ks[0][0].unwrap().decomp
|
||||
== sol_multi_1.micro_params.ks[0][0].unwrap().decomp
|
||||
&& sol_multi.micro_params.ks[1][1].unwrap().decomp
|
||||
== sol_multi_2.micro_params.ks[0][0].unwrap().decomp;
|
||||
if !equiv {
|
||||
eprintln!("Not same complexity");
|
||||
eprintln!("Multi: {cost_multi:?}");
|
||||
@@ -194,7 +188,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn dag_lut_sum_of_2_partitions_2_layer(precision1: u8, precision2: u8, final_lut: bool) -> unparametrized::OperationDag {
|
||||
fn dag_lut_sum_of_2_partitions_2_layer(
|
||||
precision1: u8,
|
||||
precision2: u8,
|
||||
final_lut: bool,
|
||||
) -> unparametrized::OperationDag {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision1, Shape::number());
|
||||
let input2 = dag.add_input(precision2, Shape::number());
|
||||
@@ -212,12 +210,12 @@ mod tests {
|
||||
#[test]
|
||||
fn optimize_multi_independant_2_partitions_finally_added() {
|
||||
let default_partition = 0;
|
||||
let single_precision_sol : Vec<_> = (0..11).map(
|
||||
|precision| {
|
||||
let single_precision_sol: Vec<_> = (0..11)
|
||||
.map(|precision| {
|
||||
let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false);
|
||||
optimize_single(&dag)
|
||||
}
|
||||
).collect();
|
||||
})
|
||||
.collect();
|
||||
|
||||
for precision1 in 1..11 {
|
||||
for precision2 in (precision1 + 1)..11 {
|
||||
@@ -240,7 +238,10 @@ mod tests {
|
||||
assert!(sol_multi.complexity < sol_2.complexity);
|
||||
}
|
||||
eprintln!("{:?}", sol_multi.micro_params.fks);
|
||||
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity;
|
||||
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
|
||||
[default_partition]
|
||||
.unwrap()
|
||||
.complexity;
|
||||
let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
|
||||
let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
|
||||
if REAL_FAST_KS {
|
||||
@@ -267,7 +268,8 @@ mod tests {
|
||||
};
|
||||
assert!(
|
||||
sol_multi_without_fks / perfect_complexity < maximal_relative_degratdation,
|
||||
"{precision1} {precision2} {} < {maximal_relative_degratdation}", sol_multi_without_fks / perfect_complexity
|
||||
"{precision1} {precision2} {} < {maximal_relative_degratdation}",
|
||||
sol_multi_without_fks / perfect_complexity
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -276,12 +278,12 @@ mod tests {
|
||||
#[test]
|
||||
fn optimize_multi_independant_2_partitions_finally_added_and_luted() {
|
||||
let default_partition = 0;
|
||||
let single_precision_sol : Vec<_> = (0..11).map(
|
||||
|precision| {
|
||||
let single_precision_sol: Vec<_> = (0..11)
|
||||
.map(|precision| {
|
||||
let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true);
|
||||
optimize_single(&dag)
|
||||
}
|
||||
).collect();
|
||||
})
|
||||
.collect();
|
||||
for precision1 in 1..11 {
|
||||
for precision2 in (precision1 + 1)..11 {
|
||||
let p_cut = Some(PrecisionCut {
|
||||
@@ -301,7 +303,10 @@ mod tests {
|
||||
// The smallest the precision the more fks noise dominate
|
||||
assert!(sol_1.complexity < sol_multi.complexity);
|
||||
assert!(sol_multi.complexity < sol_2.complexity);
|
||||
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity;
|
||||
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
|
||||
[default_partition]
|
||||
.unwrap()
|
||||
.complexity;
|
||||
let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
|
||||
let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
|
||||
let relative_degradation = sol_multi_without_fks / perfect_complexity;
|
||||
@@ -325,7 +330,8 @@ mod tests {
|
||||
};
|
||||
assert!(
|
||||
relative_degradation < maxim_relative_degradation,
|
||||
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
|
||||
"{precision1} {precision2} {}",
|
||||
sol_multi_without_fks / perfect_complexity
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -362,17 +368,24 @@ mod tests {
|
||||
dag
|
||||
}
|
||||
|
||||
fn test_optimize_v3_expanded_round(precision_acc: usize, precision_tlu: usize, minimal_speedup: f64) {
|
||||
fn test_optimize_v3_expanded_round(
|
||||
precision_acc: usize,
|
||||
precision_tlu: usize,
|
||||
minimal_speedup: f64,
|
||||
) {
|
||||
let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu);
|
||||
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
|
||||
let sol_mono = solo_key::optimize::tests::optimize(&dag)
|
||||
.best_solution
|
||||
.unwrap();
|
||||
let sol = optimize_rounded(&dag).unwrap();
|
||||
let speedup = sol_mono.complexity / sol.complexity;
|
||||
assert!(speedup >= minimal_speedup,
|
||||
assert!(
|
||||
speedup >= minimal_speedup,
|
||||
"Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}"
|
||||
);
|
||||
let expected_ks = [
|
||||
[true, true], // KS[0], KS[0->1]
|
||||
[false, true],// KS[1]
|
||||
[true, true], // KS[0], KS[0->1]
|
||||
[false, true], // KS[1]
|
||||
];
|
||||
let expected_fks = [
|
||||
[false, false],
|
||||
@@ -390,7 +403,6 @@ mod tests {
|
||||
test_optimize_v3_expanded_round(16, 8, 5.5);
|
||||
} else {
|
||||
test_optimize_v3_expanded_round(16, 8, 3.9);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,10 +421,13 @@ mod tests {
|
||||
let input1 = dag.add_input(16, Shape::number());
|
||||
_ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16);
|
||||
let sol = optimize_rounded(&dag).unwrap();
|
||||
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
|
||||
let sol_mono = solo_key::optimize::tests::optimize(&dag)
|
||||
.best_solution
|
||||
.unwrap();
|
||||
let minimal_speedup = 8.6;
|
||||
let speedup = sol_mono.complexity / sol.complexity;
|
||||
assert!(speedup >= minimal_speedup,
|
||||
assert!(
|
||||
speedup >= minimal_speedup,
|
||||
"Speedup {speedup} smaller than {minimal_speedup}"
|
||||
);
|
||||
}
|
||||
@@ -436,14 +451,13 @@ mod tests {
|
||||
let rounded1 = dag.add_expanded_round(input1, 1);
|
||||
let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1);
|
||||
let sol = optimize_rounded(&dag).unwrap();
|
||||
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
|
||||
let sol_mono = solo_key::optimize::tests::optimize(&dag)
|
||||
.best_solution
|
||||
.unwrap();
|
||||
let speedup = sol_mono.complexity / sol.complexity;
|
||||
let minimal_speedup = if REAL_FAST_KS {
|
||||
80.0
|
||||
} else {
|
||||
30.0
|
||||
};
|
||||
assert!(speedup >= minimal_speedup,
|
||||
let minimal_speedup = if REAL_FAST_KS { 80.0 } else { 30.0 };
|
||||
assert!(
|
||||
speedup >= minimal_speedup,
|
||||
"Speedup {speedup} smaller than {minimal_speedup}"
|
||||
);
|
||||
}
|
||||
@@ -455,7 +469,7 @@ mod tests {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let min_precision = 6;
|
||||
let max_precision = 8;
|
||||
let mut input_precisions : Vec<_> = (min_precision..=max_precision).collect();
|
||||
let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect();
|
||||
if decreasing {
|
||||
input_precisions.reverse();
|
||||
}
|
||||
@@ -463,9 +477,13 @@ mod tests {
|
||||
for &out_precision in &input_precisions {
|
||||
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
|
||||
}
|
||||
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *input_precisions.last().unwrap());
|
||||
lut_input = dag.add_lut(
|
||||
lut_input,
|
||||
FunctionTable::UNKWOWN,
|
||||
*input_precisions.last().unwrap(),
|
||||
);
|
||||
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision);
|
||||
let mut p_cut = PrecisionCut { p_cut:vec![] };
|
||||
let mut p_cut = PrecisionCut { p_cut: vec![] };
|
||||
let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
|
||||
assert!(sol.macro_params.len() == 1);
|
||||
let mut complexity = sol.complexity;
|
||||
@@ -479,19 +497,27 @@ mod tests {
|
||||
eprintln!("PCUT {p_cut}");
|
||||
let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
|
||||
let nb_partitions = sol.macro_params.len();
|
||||
assert!(nb_partitions == (p_cut.p_cut.len() + 1),
|
||||
"bad nb partitions {} {p_cut}", sol.macro_params.len());
|
||||
assert!(sol.complexity < complexity,
|
||||
"{} < {complexity} {out_precision} / {max_precision}", sol.complexity);
|
||||
assert!(
|
||||
nb_partitions == (p_cut.p_cut.len() + 1),
|
||||
"bad nb partitions {} {p_cut}",
|
||||
sol.macro_params.len()
|
||||
);
|
||||
assert!(
|
||||
sol.complexity < complexity,
|
||||
"{} < {complexity} {out_precision} / {max_precision}",
|
||||
sol.complexity
|
||||
);
|
||||
for (src, dst) in cross_partition(nb_partitions) {
|
||||
let ks = sol.micro_params.ks[src][dst];
|
||||
eprintln!("{} {src} {dst}", ks.is_some());
|
||||
let expected_ks =
|
||||
(!decreasing || src == dst + 1)
|
||||
&& (decreasing || src + 1 == dst)
|
||||
|| (src == dst && (src == 0 || src == nb_partitions - 1))
|
||||
;
|
||||
assert!(ks.is_some() == expected_ks, "{:?} {:?}", ks.is_some(), expected_ks);
|
||||
let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 == dst)
|
||||
|| (src == dst && (src == 0 || src == nb_partitions - 1));
|
||||
assert!(
|
||||
ks.is_some() == expected_ks,
|
||||
"{:?} {:?}",
|
||||
ks.is_some(),
|
||||
expected_ks
|
||||
);
|
||||
let fks = sol.micro_params.fks[src][dst];
|
||||
assert!(fks.is_none());
|
||||
}
|
||||
@@ -515,16 +541,16 @@ mod tests {
|
||||
// max v0 weight for each precision
|
||||
1_073_741_824,
|
||||
1_073_741_824, // 2**30, 1b
|
||||
536_870_912, // 2**29, 2b
|
||||
268_435_456, // 2**28, 3b
|
||||
67_108_864, // 2**26, 4b
|
||||
16_777_216, // 2**24, 5b
|
||||
4_194_304, // 2**22, 6b
|
||||
1_048_576, // 2**20, 7b
|
||||
262_144, // 2**18, 8b
|
||||
65_536, // 2**16, 9b
|
||||
16384, // 2**14, 10b
|
||||
2048, // 2**11, 11b
|
||||
536_870_912, // 2**29, 2b
|
||||
268_435_456, // 2**28, 3b
|
||||
67_108_864, // 2**26, 4b
|
||||
16_777_216, // 2**24, 5b
|
||||
4_194_304, // 2**22, 6b
|
||||
1_048_576, // 2**20, 7b
|
||||
262_144, // 2**18, 8b
|
||||
65_536, // 2**16, 9b
|
||||
16384, // 2**14, 10b
|
||||
2048, // 2**11, 11b
|
||||
];
|
||||
|
||||
#[test]
|
||||
@@ -539,7 +565,7 @@ mod tests {
|
||||
let mut optimal_complexity = sol_single.as_ref().unwrap().complexity;
|
||||
let mut optimal_p_error = sol_single.unwrap().p_error;
|
||||
for &out_precision in &precisions[1..] {
|
||||
let noise_factor = MAX_WEIGHT[out_precision] as f64;
|
||||
let noise_factor = MAX_WEIGHT[out_precision] as f64;
|
||||
add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor);
|
||||
let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor));
|
||||
optimal_complexity += sol_single.as_ref().unwrap().complexity;
|
||||
@@ -566,10 +592,20 @@ mod tests {
|
||||
let mut lut_input = dag.add_input(precisions[0], Shape::number());
|
||||
for out_precision in precisions {
|
||||
let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64;
|
||||
lut_input = dag.add_levelled_op([lut_input], LevelledComplexity::ZERO, noise_factor, Shape::number(), "");
|
||||
lut_input = dag.add_levelled_op(
|
||||
[lut_input],
|
||||
LevelledComplexity::ZERO,
|
||||
noise_factor,
|
||||
Shape::number(),
|
||||
"",
|
||||
);
|
||||
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
|
||||
}
|
||||
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *precisions.last().unwrap());
|
||||
_ = dag.add_lut(
|
||||
lut_input,
|
||||
FunctionTable::UNKWOWN,
|
||||
*precisions.last().unwrap(),
|
||||
);
|
||||
let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
|
||||
assert!(sol_single.is_none());
|
||||
let sol = optimize(&dag, &None, 0);
|
||||
|
||||
Reference in New Issue
Block a user