From 38646b7559d388b7eef3a7658f40a5a19a687c7e Mon Sep 17 00:00:00 2001
From: rudy
Date: Fri, 24 Feb 2023 11:12:00 +0100
Subject: [PATCH] feat(optimizer): dag partitionning based on p_cut

---
 .../optimization/dag/multi_parameters/mod.rs  |   4 +
 .../dag/multi_parameters/partitionning.rs     | 591 ++++++++++++++++++
 .../dag/multi_parameters/partitions.rs        |  47 ++
 .../dag/multi_parameters/precision_cut.rs     |  49 ++
 .../dag/multi_parameters/union_find.rs        |  70 +++
 5 files changed, 761 insertions(+)
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/union_find.rs

diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs
index 0772764fb..8885311e3 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs
@@ -1 +1,5 @@
 pub mod keys_spec;
+pub mod partitionning;
+pub(crate) mod partitions;
+pub(crate) mod precision_cut;
+pub(crate) mod union_find;
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs
new file mode 100644
index 000000000..1fd8eeed0
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs
@@ -0,0 +1,591 @@
+use std::collections::{HashMap, HashSet};
+
+use crate::dag::operator::{Operator, OperatorIndex};
+use crate::dag::unparametrized;
+
+use super::partitions::{InstructionPartition, PartitionIndex, Partitions, Transition};
+use super::precision_cut::PrecisionCut;
+use super::union_find::UnionFind;
+type Op = Operator;
+
+// Blocks of instructions
+pub struct Blocks {
+    // Set of instruction indexes for each block
+    pub blocks: Vec<Vec<usize>>,
+    // Block index of each instruction
+    pub block_of: Vec<usize>,
+}
+
+impl Blocks {
+    pub fn from(mut uf: UnionFind) -> Self {
+        let mut block_of_canon: HashMap<usize, usize> = HashMap::new();
+        let mut blocks: Vec<Vec<usize>> = vec![];
+        let size = uf.parent.len();
+        for op_i in 0..size {
+            let canon = uf.find_canonical(op_i);
+            // the canonical element is always the smallest, so it comes first
+            if canon == op_i {
+                let block_i = blocks.len();
+                _ = block_of_canon.insert(canon, block_i);
+                blocks.push(vec![canon]);
+            } else {
+                let &block_i = block_of_canon.get(&canon).unwrap();
+                blocks[block_i].push(op_i);
+            }
+        }
+        let mut block_of = vec![0; size];
+        for (i, block) in blocks.iter().enumerate() {
+            for &a in block {
+                block_of[a] = i;
+            }
+        }
+        Self { blocks, block_of }
+    }
+}
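+
+// Example (illustrative): for a `UnionFind` of size 4 on which `uf.union(0, 1)`
+// has been called, `Blocks::from(uf)` yields blocks == [[0, 1], [2], [3]] and
+// block_of == [0, 0, 1, 2].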
+
+// Extract blocks of instructions connected by levelled ops.
+// This facilitates reasoning about conflicts on levelled ops.
+#[allow(clippy::match_same_arms)]
+fn extract_levelled_block(dag: &unparametrized::OperationDag) -> Blocks {
+    let mut uf = UnionFind::new(dag.operators.len());
+    for (op_i, op) in dag.operators.iter().enumerate() {
+        match op {
+            // Block entry point
+            Operator::Input { .. } => (),
+            // Block entry point and pre-exit point
+            Op::Lut { .. } => (),
+            // Connectors
+            Op::UnsafeCast { input, .. } => uf.union(input.i, op_i),
+            Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => {
+                for input in inputs {
+                    uf.union(input.i, op_i);
+                }
+            }
+            Op::Round { .. } => unreachable!("Round should have been expanded"),
+        };
+    }
+    Blocks::from(uf)
+}
+
+#[derive(Clone, Debug, Default)]
+struct BlockConstraints {
+    forced: HashSet<PartitionIndex>, // hard constraints that need to be resolved, given by PartitionFromOp
+    exit: HashSet<PartitionIndex>, // soft constraints, to have fewer inter-partition keyswitches in TLUs
+}
+
+/* For each levelled block collect BlockConstraints */
+fn levelled_blocks_constraints(
+    dag: &unparametrized::OperationDag,
+    blocks: &Blocks,
+    p_cut: &PrecisionCut,
+) -> Vec<BlockConstraints> {
+    let mut constraints_by_block = vec![BlockConstraints::default(); blocks.blocks.len()];
+    for (block_i, ops_i) in blocks.blocks.iter().enumerate() {
+        for &op_i in ops_i {
+            let op = &dag.operators[op_i];
+            if let Some(partition) = p_cut.partition(dag, op) {
+                _ = constraints_by_block[block_i].forced.insert(partition);
+                if let Some(input) = op_tlu_inputs(op) {
+                    let input_group = blocks.block_of[input.i];
+                    constraints_by_block[input_group].exit.extend([partition]);
+                }
+            }
+        }
+    }
+    constraints_by_block
+}
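+
+// Example for `levelled_blocks_constraints` (illustrative): with p_cut = [2],
+// a Lut whose input precision is 8 forces partition 1 on its own levelled
+// block, and records partition 1 as an exit constraint on the block that
+// produces its input.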
+
+fn op_tlu_inputs(op: &Operator) -> Option<OperatorIndex> {
+    match op {
+        Op::Lut { input, .. } => Some(*input),
+        _ => None,
+    }
+}
+
+fn get_singleton_value<V: Copy>(hashset: &HashSet<V>) -> V {
+    *hashset.iter().next().unwrap()
+}
+
+fn only_1_partition(dag: &unparametrized::OperationDag) -> Partitions {
+    let mut instrs_partition = vec![InstructionPartition::new(0); dag.operators.len()];
+    for (op_i, op) in dag.operators.iter().enumerate() {
+        match op {
+            Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
+                instrs_partition[op_i].inputs_transition = vec![None; inputs.len()];
+            }
+            Op::Lut { .. } | Op::UnsafeCast { .. } => {
+                instrs_partition[op_i].inputs_transition = vec![None];
+            }
+            Op::Input { .. } => (),
+            Op::Round { .. } => unreachable!(),
+        }
+    }
+    Partitions {
+        nb_partitions: 1,
+        instrs_partition,
+    }
+}
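+
+// Resolution rule used by `resolve_by_levelled_block` below (recap): a block
+// with exactly one forced partition gets that partition; a block with no
+// forced partition gets its single exit partition if it has exactly one,
+// otherwise default_partition; conflicting forced partitions fall back to
+// default_partition.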
} => unreachable!("Round should have been expanded"), + } + } + Partitions { + nb_partitions, + instrs_partition: instrs_p, + } + // Now we can generate transitions + // Input has no transtions + // Tlu has internal transtions based on input partition + // Tlu has immediate external transition if needed +} + +pub fn partitionning_with_preferred( + dag: &unparametrized::OperationDag, + p_cut: &PrecisionCut, + default_partition: PartitionIndex, +) -> Partitions { + if p_cut.p_cut.is_empty() { + only_1_partition(dag) + } else { + resolve_by_levelled_block(dag, p_cut, default_partition) + } +} + +#[cfg(test)] +pub mod tests { + + // 2 Partitions labels + pub const LOW_PRECISION_PARTITION: PartitionIndex = 0; + pub const HIGH_PRECISION_PARTITION: PartitionIndex = 1; + + use super::*; + use crate::dag::operator::{FunctionTable, Shape, Weights}; + use crate::dag::unparametrized; + + fn default_p_cut() -> PrecisionCut { + PrecisionCut { p_cut: vec![2] } + } + + fn partitionning_no_p_cut(dag: &unparametrized::OperationDag) -> Partitions { + let p_cut = PrecisionCut { p_cut: vec![] }; + partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION) + } + + fn partitionning(dag: &unparametrized::OperationDag) -> Partitions { + partitionning_with_preferred(dag, &default_p_cut(), LOW_PRECISION_PARTITION) + } + + fn partitionning_with_preferred( + dag: &unparametrized::OperationDag, + p_cut: &PrecisionCut, + default_partition: usize, + ) -> Partitions { + super::partitionning_with_preferred(dag, p_cut, default_partition) + } + + pub fn show_partitionning( + dag: &unparametrized::OperationDag, + partitions: &[InstructionPartition], + ) { + println!("Dag:"); + for (i, op) in dag.operators.iter().enumerate() { + let partition = partitions[i].instruction_partition; + print!("P {partition}"); + if partitions[i].alternative_output_representation.is_empty() { + print!(" _"); + } else { + print!(" +FKS{:?}", partitions[i].alternative_output_representation); + }; + // partition + if !partitions[i].inputs_transition.is_empty() { + print!(" <- "); + // type in + for arg in &partitions[i].inputs_transition { + match arg { + None => print!("_,"), + Some(Transition::Internal { src_partition }) => { + print!("{src_partition}&KS,"); + } + Some(Transition::Additional { src_partition }) => { + print!("{src_partition}+FKS,"); + } + }; + } + } + println!(); + println!("%{i} <- {op}"); + } + } + + #[test] + fn test_1_partition() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(16, Shape::number()); + _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8); + let instrs_partition = partitionning_no_p_cut(&dag).instrs_partition; + for instr_partition in instrs_partition { + assert!(instr_partition.instruction_partition == LOW_PRECISION_PARTITION); + assert!(instr_partition.no_transition()); + } + } + + #[test] + fn test_1_input_2_partitions() { + let mut dag = unparametrized::OperationDag::new(); + _ = dag.add_input(1, Shape::number()); + let partitions = partitionning(&dag); + assert!(partitions.nb_partitions == 1); + let instrs_partition = partitions.instrs_partition; + assert!(instrs_partition[0].instruction_partition == LOW_PRECISION_PARTITION); + assert!(partitions.nb_partitions == 1); + } + + #[test] + fn test_2_lut_sequence() { + let mut dag = unparametrized::OperationDag::new(); + let mut expected_partitions = vec![]; + let input1 = dag.add_input(8, Shape::number()); + expected_partitions.push(HIGH_PRECISION_PARTITION); + let lut1 = dag.add_lut(input1, 
+
+#[cfg(test)]
+pub mod tests {
+
+    // 2 partition labels
+    pub const LOW_PRECISION_PARTITION: PartitionIndex = 0;
+    pub const HIGH_PRECISION_PARTITION: PartitionIndex = 1;
+
+    use super::*;
+    use crate::dag::operator::{FunctionTable, Shape, Weights};
+    use crate::dag::unparametrized;
+
+    fn default_p_cut() -> PrecisionCut {
+        PrecisionCut { p_cut: vec![2] }
+    }
+
+    fn partitionning_no_p_cut(dag: &unparametrized::OperationDag) -> Partitions {
+        let p_cut = PrecisionCut { p_cut: vec![] };
+        partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION)
+    }
+
+    fn partitionning(dag: &unparametrized::OperationDag) -> Partitions {
+        partitionning_with_preferred(dag, &default_p_cut(), LOW_PRECISION_PARTITION)
+    }
+
+    fn partitionning_with_preferred(
+        dag: &unparametrized::OperationDag,
+        p_cut: &PrecisionCut,
+        default_partition: usize,
+    ) -> Partitions {
+        super::partitionning_with_preferred(dag, p_cut, default_partition)
+    }
+
+    pub fn show_partitionning(
+        dag: &unparametrized::OperationDag,
+        partitions: &[InstructionPartition],
+    ) {
+        println!("Dag:");
+        for (i, op) in dag.operators.iter().enumerate() {
+            let partition = partitions[i].instruction_partition;
+            print!("P {partition}");
+            if partitions[i].alternative_output_representation.is_empty() {
+                print!(" _");
+            } else {
+                print!(" +FKS{:?}", partitions[i].alternative_output_representation);
+            };
+            // partition
+            if !partitions[i].inputs_transition.is_empty() {
+                print!(" <- ");
+                // type in
+                for arg in &partitions[i].inputs_transition {
+                    match arg {
+                        None => print!("_,"),
+                        Some(Transition::Internal { src_partition }) => {
+                            print!("{src_partition}&KS,");
+                        }
+                        Some(Transition::Additional { src_partition }) => {
+                            print!("{src_partition}+FKS,");
+                        }
+                    };
+                }
+            }
+            println!();
+            println!("%{i} <- {op}");
+        }
+    }
+
+    #[test]
+    fn test_1_partition() {
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(16, Shape::number());
+        _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8);
+        let instrs_partition = partitionning_no_p_cut(&dag).instrs_partition;
+        for instr_partition in instrs_partition {
+            assert!(instr_partition.instruction_partition == LOW_PRECISION_PARTITION);
+            assert!(instr_partition.no_transition());
+        }
+    }
+
+    #[test]
+    fn test_1_input_2_partitions() {
+        let mut dag = unparametrized::OperationDag::new();
+        _ = dag.add_input(1, Shape::number());
+        let partitions = partitionning(&dag);
+        assert!(partitions.nb_partitions == 1);
+        let instrs_partition = partitions.instrs_partition;
+        assert!(instrs_partition[0].instruction_partition == LOW_PRECISION_PARTITION);
+        assert!(partitions.nb_partitions == 1);
+    }
+
+    #[test]
+    fn test_2_lut_sequence() {
+        let mut dag = unparametrized::OperationDag::new();
+        let mut expected_partitions = vec![];
+        let input1 = dag.add_input(8, Shape::number());
+        expected_partitions.push(HIGH_PRECISION_PARTITION);
+        let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8);
+        expected_partitions.push(HIGH_PRECISION_PARTITION);
+        let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 1);
+        expected_partitions.push(HIGH_PRECISION_PARTITION);
+        let lut3 = dag.add_lut(lut2, FunctionTable::UNKWOWN, 1);
+        expected_partitions.push(LOW_PRECISION_PARTITION);
+        let lut4 = dag.add_lut(lut3, FunctionTable::UNKWOWN, 8);
+        expected_partitions.push(LOW_PRECISION_PARTITION);
+        let lut5 = dag.add_lut(lut4, FunctionTable::UNKWOWN, 8);
+        expected_partitions.push(HIGH_PRECISION_PARTITION);
+        let partitions = partitionning(&dag);
+        assert!(partitions.nb_partitions == 2);
+        let instrs_partition = partitions.instrs_partition;
+        let consider = |op_i: OperatorIndex| &instrs_partition[op_i.i];
+        show_partitionning(&dag, &instrs_partition);
+        assert!(consider(input1).instruction_partition == HIGH_PRECISION_PARTITION); // no constraint
+        assert!(consider(lut1).instruction_partition == expected_partitions[1]);
+        assert!(consider(lut2).instruction_partition == expected_partitions[2]);
+        assert!(consider(lut3).instruction_partition == expected_partitions[3]);
+        assert!(consider(lut4).instruction_partition == expected_partitions[4]);
+        assert!(consider(lut5).instruction_partition == expected_partitions[5]);
+        assert!(instrs_partition.len() == 6);
+    }
+
+    #[test]
+    fn test_mixed_dot_no_conflict_low() {
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(8, Shape::number());
+        let input2 = dag.add_input(1, Shape::number());
+        let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8);
+        let _dot = dag.add_dot([input1, lut2], Weights::from([1, 1]));
+        let partitions = partitionning(&dag);
+        assert!(partitions.nb_partitions == 1);
+    }
+
+    #[test]
+    fn test_mixed_dot_no_conflict_high() {
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(8, Shape::number());
+        let input2 = dag.add_input(1, Shape::number());
+        let lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, 1);
+        let _dot = dag.add_dot([input2, lut2], Weights::from([1, 1]));
+        let partitions = partitionning(&dag);
+        assert!(partitions.nb_partitions == 1);
+    }
+
+    #[test]
+    fn test_mixed_dot_conflict() {
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(8, Shape::number());
+        let input2 = dag.add_input(1, Shape::number());
+        let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8);
+        let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8);
+        let dot = dag.add_dot([lut1, lut2], Weights::from([1, 1]));
+        let partitions = partitionning(&dag);
+        let consider = |op_i: OperatorIndex| &partitions.instrs_partition[op_i.i];
+        // input1
+        let p = consider(input1);
+        {
+            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+            assert!(p.no_transition());
+        };
+        // input2
+        let p = consider(input2);
+        {
+            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+            assert!(p.no_transition());
+        };
+        // lut1, used in low partition dot
+        let p = consider(lut1);
+        {
+            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+            assert!(
+                p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION])
+            );
+            assert!(p.inputs_transition == vec![None]);
+        };
+        // lut2
+        let p = consider(lut2);
+        {
+            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+            assert!(p.no_transition());
+        };
+        // dot
+        let p = consider(dot);
+        {
+            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+            assert!(p.alternative_output_representation.is_empty());
+            assert!(
+                p.inputs_transition
+                    == vec![
+                        Some(Transition::Additional {
+                            src_partition: HIGH_PRECISION_PARTITION
+                        }),
+                        None
+                    ]
+            );
+        };
+    }
+
+    #[test]
+    fn test_rounded_v3_first_layer_and_second_layer() {
+        let acc_precision = 8;
+        let precision = 6;
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(acc_precision, Shape::number());
+        let rounded1 = dag.add_expanded_round(input1, precision);
+        let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
+        let rounded2 = dag.add_expanded_round(lut1, precision);
+        let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision);
+        let partitions = partitionning(&dag);
+        let consider = |op_i| &partitions.instrs_partition[op_i];
+        // First layer is fully LOW_PRECISION_PARTITION
+        for op_i in input1.i..lut1.i {
+            let p = consider(op_i);
+            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+            assert!(p.no_transition());
+        }
+        // First lut is HIGH_PRECISION_PARTITION and immediately converted to LOW_PRECISION_PARTITION
+        let p = consider(lut1.i);
+        {
+            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+            assert!(
+                p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION])
+            );
+            assert!(
+                p.inputs_transition
+                    == vec![Some(Transition::Internal {
+                        src_partition: LOW_PRECISION_PARTITION
+                    })]
+            );
+        };
+        for op_i in (lut1.i + 1)..lut2.i {
+            let p = consider(op_i);
+            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+        }
+        let p = consider(lut2.i);
+        {
+            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+            assert!(p.alternative_output_representation.is_empty());
+            assert!(
+                p.inputs_transition
+                    == vec![Some(Transition::Internal {
+                        src_partition: LOW_PRECISION_PARTITION
+                    })]
+            );
+        };
+    }
+
+    #[test]
+    fn test_rounded_v3_classic_first_layer_second_layer() {
+        let acc_precision = 8;
+        let precision = 6;
+        let mut dag = unparametrized::OperationDag::new();
+        let free_input1 = dag.add_input(precision, Shape::number());
+        let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
+        let first_layer = free_input1.i..=input1.i;
+        let rounded1 = dag.add_expanded_round(input1, precision);
+        let rounded_layer: Vec<_> = ((input1.i + 1)..rounded1.i).collect();
+        let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
+        let partitions = partitionning(&dag);
+        let consider = |op_i: usize| &partitions.instrs_partition[op_i];
+
+        // First layer is fully HIGH_PRECISION_PARTITION
+        for op_i in first_layer {
+            let p = consider(op_i);
+            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+        }
+        // input is converted with a fast keyswitch to LOW_PRECISION_PARTITION
+        let p = consider(input1.i);
+        assert!(p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION]));
+        let read_converted = Some(Transition::Additional {
+            src_partition: HIGH_PRECISION_PARTITION,
+        });
+
+        // Second layer, rounded part is LOW_PRECISION_PARTITION
+        for &op_i in &rounded_layer {
+            let p = consider(op_i);
+            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+        }
+        // and uses the conversion result
+        let mut first_bit_extract_verified = false;
+        let mut first_bit_erase_verified = false;
+        for &op_i in &rounded_layer {
+            let p = consider(op_i);
+            if let Op::Dot { weights, .. } = &dag.operators[op_i] {
+                let first_bit_extract = weights.values == [256] && !first_bit_extract_verified;
+                let first_bit_erase = weights.values == [1, -1] && !first_bit_erase_verified;
+                if first_bit_extract || first_bit_erase {
+                    assert!(p.inputs_transition[0] == read_converted);
+                }
+                first_bit_extract_verified = first_bit_extract_verified || first_bit_extract;
+                first_bit_erase_verified = first_bit_erase_verified || first_bit_erase;
+            };
+        }
+        assert!(first_bit_extract_verified);
+        assert!(first_bit_erase_verified);
+        // Second layer, lut part is HIGH_PRECISION_PARTITION
+        // and uses an internal conversion
+        let p = consider(lut1.i);
+        assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+        assert!(
+            p.inputs_transition[0]
+                == Some(Transition::Internal {
+                    src_partition: LOW_PRECISION_PARTITION
+                })
+        );
+    }
+
+    #[test]
+    fn test_rounded_v1_classic_first_layer_second_layer() {
+        let acc_precision = 8;
+        let precision = 6;
+        let mut dag = unparametrized::OperationDag::new();
+        let free_input1 = dag.add_input(precision, Shape::number());
+        let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
+        let first_layer = free_input1.i..=input1.i;
+        let rounded1 = dag.add_expanded_round(input1, precision);
+        let rounded_layer = (input1.i + 1)..rounded1.i;
+        let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
+        let partitions =
+            partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION);
+        show_partitionning(&dag, &partitions.instrs_partition);
+        let consider = |op_i: usize| &partitions.instrs_partition[op_i];
+
+        // First layer is fully HIGH_PRECISION_PARTITION
+        for op_i in first_layer {
+            assert!(consider(op_i).instruction_partition == HIGH_PRECISION_PARTITION);
+        }
+        // input keeps a single representation (no fast keyswitch conversion here)
+        assert!(consider(input1.i)
+            .alternative_output_representation
+            .is_empty());
+        let read_converted = Some(Transition::Additional {
+            src_partition: LOW_PRECISION_PARTITION,
+        });
+
+        // Second layer, rounded part is mostly HIGH_PRECISION_PARTITION
+        // Only the Lut is post-converted
+        for op_i in rounded_layer {
+            let p = consider(op_i);
+            match &dag.operators[op_i] {
+                Op::Lut { .. } => {
+                    assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
+                    assert!(
+                        p.alternative_output_representation
+                            == HashSet::from([HIGH_PRECISION_PARTITION])
+                    );
+                }
+                Op::Dot { weights, .. } => {
+                    assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
+                    assert!(p.inputs_transition[0].is_none());
+                    if weights.values.len() == 2 {
+                        assert!(p.inputs_transition[1] == read_converted);
+                    }
+                }
+                _ => assert!(p.instruction_partition == HIGH_PRECISION_PARTITION),
+            }
+        }
+    }
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs
new file mode 100644
index 000000000..376c24a3e
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs
@@ -0,0 +1,47 @@
+use std::collections::HashSet;
+
+pub type PartitionIndex = usize;
+pub type AdditionalRepresentations = HashSet<PartitionIndex>;
+
+// How one input is made compatible with the instruction partition
+#[derive(Clone, Debug, PartialEq, Eq)]
+
+pub enum Transition {
+    // The input relies on an already converted input, a multi-representation value
+    Additional { src_partition: PartitionIndex },
+    // The input can be converted directly by the instruction's internal keyswitch
+    Internal { src_partition: PartitionIndex },
+}
+
+// One instruction partition is computed for each instruction.
+// It represents its partition and relations with other partitions.
+#[derive(Clone, Debug, Default)]
+pub struct InstructionPartition {
+    // The partition assigned to the instruction
+    pub instruction_partition: PartitionIndex,
+    // How the inputs are made compatible with the instruction partition
+    pub inputs_transition: Vec<Option<Transition>>,
+    // How the output is made compatible with levelled operations
+    pub alternative_output_representation: AdditionalRepresentations,
+}
+
+impl InstructionPartition {
+    pub fn new(instruction_partition: PartitionIndex) -> Self {
+        Self {
+            instruction_partition,
+            ..Self::default()
+        }
+    }
+
+    #[cfg(test)]
+    pub fn no_transition(&self) -> bool {
+        self.alternative_output_representation.is_empty()
+            && self.inputs_transition.iter().all(Option::is_none)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct Partitions {
+    pub nb_partitions: usize,
+    pub instrs_partition: Vec<InstructionPartition>,
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs
new file mode 100644
index 000000000..75aebaeb1
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs
@@ -0,0 +1,49 @@
+use crate::dag::operator::{Operator, Precision};
+use crate::dag::unparametrized;
+use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
+
+pub struct PrecisionCut {
+    // partition 0 precision <= p_cut[0] < partition 1 precision <= p_cut[1] ...
+    // precisions are in the sense of Lut input precision and are sorted
+    pub p_cut: Vec<Precision>,
+}
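+
+// Example (illustrative): p_cut = [2, 6] defines three partitions:
+//   Lut input precision <= 2  -> partition 0
+//   3..=6                     -> partition 1
+//   7 and above               -> partition 2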
+
+impl PrecisionCut {
+    pub fn partition(
+        &self,
+        dag: &unparametrized::OperationDag,
+        op: &Operator,
+    ) -> Option<PartitionIndex> {
+        match op {
+            Operator::Lut { input, .. } => {
+                assert!(!self.p_cut.is_empty());
+                for (partition, &p_cut) in self.p_cut.iter().enumerate() {
+                    if dag.out_precisions[input.i] <= p_cut {
+                        return Some(partition);
+                    }
+                }
+                Some(self.p_cut.len())
+            }
+            _ => None,
+        }
+    }
+}
+
+impl std::fmt::Display for PrecisionCut {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let mut prev_p_cut = 0;
+        for (partition, &p_cut) in self.p_cut.iter().enumerate() {
+            writeln!(
+                f,
+                "partition {partition}: {prev_p_cut} up through {p_cut} bits"
+            )?;
+            prev_p_cut = p_cut + 1;
+        }
+        writeln!(
+            f,
+            "partition {}: {prev_p_cut} bits and higher",
+            self.p_cut.len()
+        )?;
+        Ok(())
+    }
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/union_find.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/union_find.rs
new file mode 100644
index 000000000..101e2a35d
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/union_find.rs
@@ -0,0 +1,70 @@
+pub struct UnionFind {
+    pub parent: Vec<usize>,
+}
+
+impl UnionFind {
+    // Used to detect instructions connected in levelled blocks (in partitionning.rs).
+    pub fn new(size: usize) -> Self {
+        Self {
+            parent: (0..size).collect(),
+        }
+    }
+
+    pub fn find_canonical(&mut self, a: usize) -> usize {
+        let parent = self.parent[a];
+        if a == parent {
+            return a;
+        }
+        let canonical = self.find_canonical(parent);
+        self.parent[a] = canonical;
+        canonical
+    }
+
+    pub fn union(&mut self, a: usize, b: usize) {
+        _ = self.united_common_ancestor(a, b);
+    }
+
+    // uses partial path compression, an immediate parent check and an early exit
+    pub fn united_common_ancestor(&mut self, a: usize, b: usize) -> usize {
+        let parent_a = self.parent[a];
+        let parent_b = self.parent[b];
+        if parent_a == parent_b {
+            return parent_a; // common ancestor
+        }
+        let common_ancestor = if a == parent_a && parent_b < parent_a {
+            // a is canonical, unite class_a under the smaller ancestor parent_b
+            parent_b
+        } else if b == parent_b && parent_a < parent_b {
+            // b is canonical, unite class_b under the smaller ancestor parent_a
+            parent_a
+        } else {
+            self.united_common_ancestor(parent_a, parent_b) // loop
+        };
+        // classic path compression
+        self.parent[a] = common_ancestor;
+        self.parent[b] = common_ancestor;
+        common_ancestor
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::super::partitionning::Blocks;
+    use super::*;
+
+    #[test]
+    fn test_union_find() {
+        let size = 10;
+        let mut uf = UnionFind::new(size);
+        for i in 0..size {
+            assert!(uf.find_canonical(0) == 0);
+            assert!(uf.find_canonical(i) == i);
+            uf.union(i, 0);
+            assert!(uf.find_canonical(i) == 0, "{} {:?}", i, &uf.parent[0..=i]);
+        }
+        eprintln!("{:?}", uf.parent);
+        let expected_group: Vec<usize> = (0..10).collect();
+        assert!(Blocks::from(uf).blocks == vec![expected_group]);
+    }
+}
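+
+// Example (illustrative): after `union(2, 3)` on a fresh `UnionFind::new(4)`,
+// parent becomes [0, 1, 2, 2] and `find_canonical(3)` returns 2.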