mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(optimizer): dag partitionning based on p_cut
This commit is contained in:
@@ -1 +1,5 @@
|
||||
pub mod keys_spec;
|
||||
pub mod partitionning;
|
||||
pub(crate) mod partitions;
|
||||
pub(crate) mod precision_cut;
|
||||
pub(crate) mod union_find;
|
||||
|
||||
@@ -0,0 +1,591 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use crate::dag::operator::{Operator, OperatorIndex};
|
||||
use crate::dag::unparametrized;
|
||||
|
||||
use super::partitions::{InstructionPartition, PartitionIndex, Partitions, Transition};
|
||||
use super::precision_cut::PrecisionCut;
|
||||
use super::union_find::UnionFind;
|
||||
type Op = Operator;
|
||||
|
||||
// Blocks of instructions connected by levelled operators.
// Built from a union-find over the dag (see `extract_levelled_block`).
pub struct Blocks {
    // Set of instruction indexes for each block
    pub blocks: Vec<Vec<usize>>,
    // Block index of each instruction (inverse of `blocks`)
    pub block_of: Vec<usize>,
}
|
||||
|
||||
impl Blocks {
|
||||
pub fn from(mut uf: UnionFind) -> Self {
|
||||
let mut block_of_canon: HashMap<usize, usize> = HashMap::new();
|
||||
let mut blocks: Vec<Vec<usize>> = vec![];
|
||||
let size = uf.parent.len();
|
||||
for op_i in 0..size {
|
||||
let canon = uf.find_canonical(op_i);
|
||||
// the canonic is always the smaller, so it's the first
|
||||
if canon == op_i {
|
||||
let block_i = blocks.len();
|
||||
_ = block_of_canon.insert(canon, block_i);
|
||||
blocks.push(vec![canon]);
|
||||
} else {
|
||||
let &block_i = block_of_canon.get(&canon).unwrap();
|
||||
blocks[block_i].push(op_i);
|
||||
}
|
||||
}
|
||||
let mut block_of = vec![0; size];
|
||||
for (i, block) in blocks.iter().enumerate() {
|
||||
for &a in block {
|
||||
block_of[a] = i;
|
||||
}
|
||||
}
|
||||
Self { blocks, block_of }
|
||||
}
|
||||
}
|
||||
|
||||
// Extract block of instructions connected by levelled ops.
|
||||
// This facilitates reasonning about conflicts on levelled ops.
|
||||
#[allow(clippy::match_same_arms)]
|
||||
fn extract_levelled_block(dag: &unparametrized::OperationDag) -> Blocks {
|
||||
let mut uf = UnionFind::new(dag.operators.len());
|
||||
for (op_i, op) in dag.operators.iter().enumerate() {
|
||||
match op {
|
||||
// Block entry point
|
||||
Operator::Input { .. } => (),
|
||||
// Block entry point and pre-exit point
|
||||
Op::Lut { .. } => (),
|
||||
// Connectors
|
||||
Op::UnsafeCast { input, .. } => uf.union(input.i, op_i),
|
||||
Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => {
|
||||
for input in inputs {
|
||||
uf.union(input.i, op_i);
|
||||
}
|
||||
}
|
||||
Op::Round { .. } => unreachable!("Round should have been expanded"),
|
||||
};
|
||||
}
|
||||
Blocks::from(uf)
|
||||
}
|
||||
|
||||
// Partition constraints collected for one levelled block.
#[derive(Clone, Debug, Default)]
struct BlockConstraints {
    // hard constraints, need to be resolved, given by PartitionFromOp
    forced: HashSet<usize>,
    // soft constraints, to have less inter-partition keyswitches in TLUs
    exit: HashSet<usize>,
}
|
||||
|
||||
/* For each levelled block collect BlockConstraints */
|
||||
fn levelled_blocks_constraints(
|
||||
dag: &unparametrized::OperationDag,
|
||||
blocks: &Blocks,
|
||||
p_cut: &PrecisionCut,
|
||||
) -> Vec<BlockConstraints> {
|
||||
let mut constraints_by_block = vec![BlockConstraints::default(); blocks.blocks.len()];
|
||||
for (block_i, ops_i) in blocks.blocks.iter().enumerate() {
|
||||
for &op_i in ops_i {
|
||||
let op = &dag.operators[op_i];
|
||||
if let Some(partition) = p_cut.partition(dag, op) {
|
||||
_ = constraints_by_block[block_i].forced.insert(partition);
|
||||
if let Some(input) = op_tlu_inputs(op) {
|
||||
let input_group = blocks.block_of[input.i];
|
||||
constraints_by_block[input_group].exit.extend([partition]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
constraints_by_block
|
||||
}
|
||||
|
||||
fn op_tlu_inputs(op: &Operator) -> Option<OperatorIndex> {
|
||||
match op {
|
||||
Op::Lut { input, .. } => Some(*input),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the unique element of a singleton set.
// Panics on an empty set; callers guarantee exactly one element.
fn get_singleton_value<V: Copy>(hashset: &HashSet<V>) -> V {
    let mut values = hashset.iter().copied();
    values.next().unwrap()
}
|
||||
|
||||
fn only_1_partition(dag: &unparametrized::OperationDag) -> Partitions {
|
||||
let mut instrs_partition = vec![InstructionPartition::new(0); dag.operators.len()];
|
||||
for (op_i, op) in dag.operators.iter().enumerate() {
|
||||
match op {
|
||||
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
|
||||
instrs_partition[op_i].inputs_transition = vec![None; inputs.len()];
|
||||
}
|
||||
Op::Lut { .. } | Op::UnsafeCast { .. } => {
|
||||
instrs_partition[op_i].inputs_transition = vec![None];
|
||||
}
|
||||
Op::Input { .. } => (),
|
||||
Op::Round { .. } => unreachable!(),
|
||||
}
|
||||
}
|
||||
Partitions {
|
||||
nb_partitions: 1,
|
||||
instrs_partition,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolves the partition of every instruction from per-block constraints.
// Each levelled block gets a single partition; luts keep the partition forced
// by `p_cut` and receive the transitions needed to connect with their block.
fn resolve_by_levelled_block(
    dag: &unparametrized::OperationDag,
    p_cut: &PrecisionCut,
    default_partition: PartitionIndex,
) -> Partitions {
    let blocks = extract_levelled_block(dag);
    let constraints_by_blocks = levelled_blocks_constraints(dag, &blocks, p_cut);
    // Partitions actually forced somewhere in the dag.
    let present_partitions: HashSet<PartitionIndex> = constraints_by_blocks
        .iter()
        .flat_map(|c| &c.forced)
        .copied()
        .collect();
    let nb_partitions = present_partitions.len().max(1); // no tlu = no constraints
    if nb_partitions == 1 {
        return only_1_partition(dag);
    }
    // Choose one partition per levelled block.
    let mut block_partition: Vec<PartitionIndex> = vec![];
    for constraints in constraints_by_blocks {
        let partition = match constraints.forced.len() {
            0 => {
                // No hard constraint: follow a unanimous soft (exit) hint if any.
                if constraints.exit.len() == 1 {
                    get_singleton_value(&constraints.exit)
                } else {
                    default_partition
                }
            }
            1 => get_singleton_value(&constraints.forced),
            _ => default_partition, // conflicts solved to default
        };
        // TODO1: Could choose based on the number of fast keyswitches added (case > 1)
        // TODO2: A conversion of an entry point could be deferred to the conflict until a conversion is needed
        //        This is equivalent to refining the levelled block
        // TODO3: This could even make some exit value used in a different representation and go out unconverted
        //        This can reduce the need to define extra parameters for internal ks
        block_partition.push(partition);
    }
    // Second pass: assign each instruction its partition and transitions.
    let mut instrs_p: Vec<InstructionPartition> =
        vec![InstructionPartition::new(default_partition); dag.operators.len()];
    let block_partition_of = |op_i| block_partition[blocks.block_of[op_i]];
    for (op_i, op) in dag.operators.iter().enumerate() {
        let group_partition = block_partition_of(op_i);
        match op {
            Op::Lut { input, .. } => {
                // A lut runs in its forced partition, converting its input
                // internally (ks) and its output immediately (fks) if needed.
                let instruction_partition = p_cut.partition(dag, op).unwrap();
                instrs_p[op_i].instruction_partition = instruction_partition;
                let input_partition = instrs_p[input.i].instruction_partition;
                instrs_p[op_i].inputs_transition = if input_partition == instruction_partition {
                    vec![None]
                } else {
                    vec![Some(Transition::Internal {
                        src_partition: input_partition,
                    })]
                };
                if group_partition != instruction_partition {
                    instrs_p[op_i].alternative_output_representation =
                        HashSet::from([group_partition]);
                }
            }
            Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => {
                // Levelled ops follow their block's partition; an input from
                // another partition is read through its additional representation.
                instrs_p[op_i].instruction_partition = group_partition;
                instrs_p[op_i].inputs_transition = vec![None; inputs.len()];
                for (i, input) in inputs.iter().enumerate() {
                    let input_partition = instrs_p[input.i].instruction_partition;
                    if group_partition != input_partition {
                        instrs_p[op_i].inputs_transition[i] = Some(Transition::Additional {
                            src_partition: input_partition,
                        });
                    }
                }
            }
            Op::UnsafeCast { input, .. } => {
                instrs_p[op_i].instruction_partition = group_partition;
                let input_partition = instrs_p[input.i].instruction_partition;
                instrs_p[op_i].inputs_transition = if group_partition == input_partition {
                    vec![None]
                } else {
                    vec![Some(Transition::Additional {
                        src_partition: input_partition,
                    })]
                }
            }
            Operator::Input { .. } => instrs_p[op_i].instruction_partition = group_partition,
            Op::Round { .. } => unreachable!("Round should have been expanded"),
        }
    }
    Partitions {
        nb_partitions,
        instrs_partition: instrs_p,
    }
    // Transitions recap:
    // Input has no transitions
    // Tlu has internal transitions based on input partition
    // Tlu has immediate external transition if needed
}
|
||||
|
||||
pub fn partitionning_with_preferred(
|
||||
dag: &unparametrized::OperationDag,
|
||||
p_cut: &PrecisionCut,
|
||||
default_partition: PartitionIndex,
|
||||
) -> Partitions {
|
||||
if p_cut.p_cut.is_empty() {
|
||||
only_1_partition(dag)
|
||||
} else {
|
||||
resolve_by_levelled_block(dag, p_cut, default_partition)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
pub mod tests {

    // Labels for the 2 partitions used throughout the tests
    pub const LOW_PRECISION_PARTITION: PartitionIndex = 0;
    pub const HIGH_PRECISION_PARTITION: PartitionIndex = 1;

    use super::*;
    use crate::dag::operator::{FunctionTable, Shape, Weights};
    use crate::dag::unparametrized;

    // Cut at 2 bits: lut inputs <= 2 bits go to the low partition,
    // anything above goes to the high partition.
    fn default_p_cut() -> PrecisionCut {
        PrecisionCut { p_cut: vec![2] }
    }

    // Partitionning with an empty p_cut, i.e. forcing a single partition.
    fn partitionning_no_p_cut(dag: &unparametrized::OperationDag) -> Partitions {
        let p_cut = PrecisionCut { p_cut: vec![] };
        partitionning_with_preferred(dag, &p_cut, LOW_PRECISION_PARTITION)
    }

    // Partitionning with the default cut and the low partition preferred.
    fn partitionning(dag: &unparametrized::OperationDag) -> Partitions {
        partitionning_with_preferred(dag, &default_p_cut(), LOW_PRECISION_PARTITION)
    }

    // Thin wrapper over the public entry point under test.
    fn partitionning_with_preferred(
        dag: &unparametrized::OperationDag,
        p_cut: &PrecisionCut,
        default_partition: usize,
    ) -> Partitions {
        super::partitionning_with_preferred(dag, p_cut, default_partition)
    }

    // Debug helper: prints each instruction with its partition,
    // output conversions (+FKS) and input transitions (&KS / +FKS).
    pub fn show_partitionning(
        dag: &unparametrized::OperationDag,
        partitions: &[InstructionPartition],
    ) {
        println!("Dag:");
        for (i, op) in dag.operators.iter().enumerate() {
            let partition = partitions[i].instruction_partition;
            print!("P {partition}");
            if partitions[i].alternative_output_representation.is_empty() {
                print!(" _");
            } else {
                print!(" +FKS{:?}", partitions[i].alternative_output_representation);
            };
            // partition
            if !partitions[i].inputs_transition.is_empty() {
                print!(" <- ");
                // type in
                for arg in &partitions[i].inputs_transition {
                    match arg {
                        None => print!("_,"),
                        Some(Transition::Internal { src_partition }) => {
                            print!("{src_partition}&KS,");
                        }
                        Some(Transition::Additional { src_partition }) => {
                            print!("{src_partition}+FKS,");
                        }
                    };
                }
            }
            println!();
            println!("%{i} <- {op}");
        }
    }

    // An empty p_cut yields one partition and no transitions at all.
    #[test]
    fn test_1_partition() {
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(16, Shape::number());
        _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 4, 8);
        let instrs_partition = partitionning_no_p_cut(&dag).instrs_partition;
        for instr_partition in instrs_partition {
            assert!(instr_partition.instruction_partition == LOW_PRECISION_PARTITION);
            assert!(instr_partition.no_transition());
        }
    }

    // A lone input creates no lut constraint: a single partition remains.
    #[test]
    fn test_1_input_2_partitions() {
        let mut dag = unparametrized::OperationDag::new();
        _ = dag.add_input(1, Shape::number());
        let partitions = partitionning(&dag);
        assert!(partitions.nb_partitions == 1);
        let instrs_partition = partitions.instrs_partition;
        assert!(instrs_partition[0].instruction_partition == LOW_PRECISION_PARTITION);
        assert!(partitions.nb_partitions == 1);
    }

    // A chain of luts: each lut's partition is fixed by its input precision.
    #[test]
    fn test_2_lut_sequence() {
        let mut dag = unparametrized::OperationDag::new();
        let mut expected_partitions = vec![];
        let input1 = dag.add_input(8, Shape::number());
        expected_partitions.push(HIGH_PRECISION_PARTITION);
        let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8);
        expected_partitions.push(HIGH_PRECISION_PARTITION);
        let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 1);
        expected_partitions.push(HIGH_PRECISION_PARTITION);
        let lut3 = dag.add_lut(lut2, FunctionTable::UNKWOWN, 1);
        expected_partitions.push(LOW_PRECISION_PARTITION);
        let lut4 = dag.add_lut(lut3, FunctionTable::UNKWOWN, 8);
        expected_partitions.push(LOW_PRECISION_PARTITION);
        let lut5 = dag.add_lut(lut4, FunctionTable::UNKWOWN, 8);
        expected_partitions.push(HIGH_PRECISION_PARTITION);
        let partitions = partitionning(&dag);
        assert!(partitions.nb_partitions == 2);
        let instrs_partition = partitions.instrs_partition;
        let consider = |op_i: OperatorIndex| &instrs_partition[op_i.i];
        show_partitionning(&dag, &instrs_partition);
        assert!(consider(input1).instruction_partition == HIGH_PRECISION_PARTITION); // no constraint
        assert!(consider(lut1).instruction_partition == expected_partitions[1]);
        assert!(consider(lut2).instruction_partition == expected_partitions[2]);
        assert!(consider(lut3).instruction_partition == expected_partitions[3]);
        assert!(consider(lut4).instruction_partition == expected_partitions[4]);
        assert!(consider(lut5).instruction_partition == expected_partitions[5]);
        assert!(instrs_partition.len() == 6);
    }

    // A dot mixing a raw input with a lut output: the single constrained
    // partition (low) absorbs the whole levelled block.
    #[test]
    fn test_mixed_dot_no_conflict_low() {
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(8, Shape::number());
        let input2 = dag.add_input(1, Shape::number());
        let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8);
        let _dot = dag.add_dot([input1, lut2], Weights::from([1, 1]));
        let partitions = partitionning(&dag);
        assert!(partitions.nb_partitions == 1);
    }

    // Same as above with the high partition as the only constrained one.
    #[test]
    fn test_mixed_dot_no_conflict_high() {
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(8, Shape::number());
        let input2 = dag.add_input(1, Shape::number());
        let lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, 1);
        let _dot = dag.add_dot([input2, lut2], Weights::from([1, 1]));
        let partitions = partitionning(&dag);
        assert!(partitions.nb_partitions == 1);
    }

    // Two luts of different partitions feed one dot: the conflict resolves
    // to the default (low) partition and lut1 gets a fast keyswitch.
    #[test]
    fn test_mixed_dot_conflict() {
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(8, Shape::number());
        let input2 = dag.add_input(1, Shape::number());
        let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8);
        let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, 8);
        let dot = dag.add_dot([lut1, lut2], Weights::from([1, 1]));
        let partitions = partitionning(&dag);
        let consider = |op_i: OperatorIndex| &partitions.instrs_partition[op_i.i];
        // input1
        let p = consider(input1);
        {
            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
            assert!(p.no_transition());
        };
        // input2
        let p = consider(input2);
        {
            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
            assert!(p.no_transition());
        };
        // lut1, used in low partition dot
        let p = consider(lut1);
        {
            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
            assert!(
                p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION])
            );
            assert!(p.inputs_transition == vec![None]);
        };
        // lut2
        let p = consider(lut2);
        {
            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
            assert!(p.no_transition());
        };
        // dot
        let p = consider(dot);
        {
            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
            assert!(p.alternative_output_representation.is_empty());
            assert!(
                p.inputs_transition
                    == vec![
                        Some(Transition::Additional {
                            src_partition: HIGH_PRECISION_PARTITION
                        }),
                        None
                    ]
            );
        };
    }

    // Rounded pattern: the round stages stay low precision, the luts are
    // high precision with an internal keyswitch back from the low partition.
    #[test]
    fn test_rounded_v3_first_layer_and_second_layer() {
        let acc_precision = 8;
        let precision = 6;
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(acc_precision, Shape::number());
        let rounded1 = dag.add_expanded_round(input1, precision);
        let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
        let rounded2 = dag.add_expanded_round(lut1, precision);
        let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision);
        let partitions = partitionning(&dag);
        let consider = |op_i| &partitions.instrs_partition[op_i];
        // First layer is fully LOW_PRECISION_PARTITION
        for op_i in input1.i..lut1.i {
            let p = consider(op_i);
            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
            assert!(p.no_transition());
        }
        // First lut is HIGH_PRECISION_PARTITION and immediately converted to LOW_PRECISION_PARTITION
        let p = consider(lut1.i);
        {
            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
            assert!(
                p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION])
            );
            assert!(
                p.inputs_transition
                    == vec![Some(Transition::Internal {
                        src_partition: LOW_PRECISION_PARTITION
                    })]
            );
        };
        // Second round stage is fully LOW_PRECISION_PARTITION
        for op_i in (lut1.i + 1)..lut2.i {
            let p = consider(op_i);
            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
        }
        // Final lut reads the low partition through an internal keyswitch
        let p = consider(lut2.i);
        {
            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
            assert!(p.alternative_output_representation.is_empty());
            assert!(
                p.inputs_transition
                    == vec![Some(Transition::Internal {
                        src_partition: LOW_PRECISION_PARTITION
                    })]
            );
        };
    }

    // Classic rounded pattern with a lut-produced accumulator as entry point.
    #[test]
    fn test_rounded_v3_classic_first_layer_second_layer() {
        let acc_precision = 8;
        let precision = 6;
        let mut dag = unparametrized::OperationDag::new();
        let free_input1 = dag.add_input(precision, Shape::number());
        let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
        let first_layer = free_input1.i..=input1.i;
        let rounded1 = dag.add_expanded_round(input1, precision);
        let rounded_layer: Vec<_> = ((input1.i + 1)..rounded1.i).collect();
        let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
        let partitions = partitionning(&dag);
        let consider = |op_i: usize| &partitions.instrs_partition[op_i];

        // First layer is fully HIGH_PRECISION_PARTITION
        for op_i in first_layer {
            let p = consider(op_i);
            assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
        }
        // input is converted with a fast keyswitch to LOW_PRECISION_PARTITION
        let p = consider(input1.i);
        assert!(p.alternative_output_representation == HashSet::from([LOW_PRECISION_PARTITION]));
        let read_converted = Some(Transition::Additional {
            src_partition: HIGH_PRECISION_PARTITION,
        });

        // Second layer, rounded part is LOW_PRECISION_PARTITION
        for &op_i in &rounded_layer {
            let p = consider(op_i);
            assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
        }
        // and the bit extract/erase dots read the converted representation
        let mut first_bit_extract_verified = false;
        let mut first_bit_erase_verified = false;
        for &op_i in &rounded_layer {
            let p = consider(op_i);
            if let Op::Dot { weights, .. } = &dag.operators[op_i] {
                let first_bit_extract = weights.values == [256] && !first_bit_extract_verified;
                let first_bit_erase = weights.values == [1, -1] && !first_bit_erase_verified;
                if first_bit_extract || first_bit_erase {
                    assert!(p.inputs_transition[0] == read_converted);
                }
                first_bit_extract_verified = first_bit_extract_verified || first_bit_extract;
                first_bit_erase_verified = first_bit_erase_verified || first_bit_erase;
            };
        }
        assert!(first_bit_extract_verified);
        assert!(first_bit_erase_verified);
        // Second layer, lut part is HIGH_PRECISION_PARTITION
        // and uses an internal conversion
        let p = consider(lut1.i);
        assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
        assert!(
            p.inputs_transition[0]
                == Some(Transition::Internal {
                    src_partition: LOW_PRECISION_PARTITION
                })
        );
    }

    // Same pattern but preferring the high partition: the rounded layer stays
    // high precision and only the inner luts run (and convert back from) low.
    #[test]
    fn test_rounded_v1_classic_first_layer_second_layer() {
        let acc_precision = 8;
        let precision = 6;
        let mut dag = unparametrized::OperationDag::new();
        let free_input1 = dag.add_input(precision, Shape::number());
        let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
        let first_layer = free_input1.i..=input1.i;
        let rounded1 = dag.add_expanded_round(input1, precision);
        let rounded_layer = (input1.i + 1)..rounded1.i;
        let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
        let partitions =
            partitionning_with_preferred(&dag, &default_p_cut(), HIGH_PRECISION_PARTITION);
        show_partitionning(&dag, &partitions.instrs_partition);
        let consider = |op_i: usize| &partitions.instrs_partition[op_i];

        // First layer is fully HIGH_PRECISION_PARTITION
        for op_i in first_layer {
            assert!(consider(op_i).instruction_partition == HIGH_PRECISION_PARTITION);
        }
        // input needs no conversion: its consumers are already high precision
        assert!(consider(input1.i)
            .alternative_output_representation
            .is_empty());
        let read_converted = Some(Transition::Additional {
            src_partition: LOW_PRECISION_PARTITION,
        });

        // Second layer, rounded part is mostly HIGH_PRECISION_PARTITION
        // Only the Lut is post-converted
        for op_i in rounded_layer {
            let p = consider(op_i);
            match &dag.operators[op_i] {
                Op::Lut { .. } => {
                    assert!(p.instruction_partition == LOW_PRECISION_PARTITION);
                    assert!(
                        p.alternative_output_representation
                            == HashSet::from([HIGH_PRECISION_PARTITION])
                    );
                }
                Op::Dot { weights, .. } => {
                    assert!(p.instruction_partition == HIGH_PRECISION_PARTITION);
                    assert!(p.inputs_transition[0].is_none());
                    if weights.values.len() == 2 {
                        assert!(p.inputs_transition[1] == read_converted);
                    }
                }
                _ => assert!(p.instruction_partition == HIGH_PRECISION_PARTITION),
            }
        }
    }
}
|
||||
@@ -0,0 +1,47 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Index of a partition among all the partitions of a dag.
pub type PartitionIndex = usize;
// Extra partitions in which a value is also made available (beyond its own).
pub type AdditionalRepresentations = HashSet<PartitionIndex>;

// How one input is made compatible with the instruction partition
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Transition {
    // The input relies on an already converted input (multi-representation value)
    Additional { src_partition: PartitionIndex },
    // The input can be converted directly by the internal instruction keyswitch
    Internal { src_partition: PartitionIndex },
}
|
||||
|
||||
// One instruction partition is computed for each instruction.
// It represents its partition and relations with other partitions.
#[derive(Clone, Debug, Default)]
pub struct InstructionPartition {
    // The partition assigned to the instruction
    pub instruction_partition: PartitionIndex,
    // How each input is made compatible with the instruction partition
    // (one entry per input, None when no conversion is needed)
    pub inputs_transition: Vec<Option<Transition>>,
    // How the output is made compatible with levelled operations
    // of other partitions (fast keyswitch targets)
    pub alternative_output_representation: AdditionalRepresentations,
}
|
||||
|
||||
impl InstructionPartition {
|
||||
pub fn new(instruction_partition: PartitionIndex) -> Self {
|
||||
Self {
|
||||
instruction_partition,
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn no_transition(&self) -> bool {
|
||||
self.alternative_output_representation.is_empty()
|
||||
&& self.inputs_transition.iter().all(Option::is_none)
|
||||
}
|
||||
}
|
||||
|
||||
// The partition assignment of a whole dag.
#[derive(Clone, Debug)]
pub struct Partitions {
    // Total number of distinct partitions
    pub nb_partitions: usize,
    // One InstructionPartition per dag instruction, in instruction order
    pub instrs_partition: Vec<InstructionPartition>,
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use crate::dag::operator::{Operator, Precision};
|
||||
use crate::dag::unparametrized;
|
||||
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
|
||||
|
||||
pub struct PrecisionCut {
    // partition 0 precision <= p_cut[0] < partition 1 precision <= p_cut[1] ...
    // precisions are in the sense of lut input precision and are sorted ascending
    pub p_cut: Vec<Precision>,
}
|
||||
|
||||
impl PrecisionCut {
|
||||
pub fn partition(
|
||||
&self,
|
||||
dag: &unparametrized::OperationDag,
|
||||
op: &Operator,
|
||||
) -> Option<PartitionIndex> {
|
||||
match op {
|
||||
Operator::Lut { input, .. } => {
|
||||
assert!(!self.p_cut.is_empty());
|
||||
for (partition, &p_cut) in self.p_cut.iter().enumerate() {
|
||||
if dag.out_precisions[input.i] <= p_cut {
|
||||
return Some(partition);
|
||||
}
|
||||
}
|
||||
Some(self.p_cut.len())
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PrecisionCut {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let mut prev_p_cut = 0;
|
||||
for (partition, &p_cut) in self.p_cut.iter().enumerate() {
|
||||
writeln!(
|
||||
f,
|
||||
"partition {partition}: {prev_p_cut} up through {p_cut} bits"
|
||||
)?;
|
||||
prev_p_cut = p_cut + 1;
|
||||
}
|
||||
writeln!(
|
||||
f,
|
||||
"partition {}: {prev_p_cut} bits and higher",
|
||||
self.p_cut.len()
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
// Disjoint-set (union-find) over instruction indexes.
// Used to detect instructions connected in a levelled block (in partitionning.rs).
pub struct UnionFind {
    // parent[i] points towards the canonical element of i's class;
    // a canonical element satisfies parent[i] == i.
    pub parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `size` singleton classes: each element is its own canonical.
    pub fn new(size: usize) -> Self {
        Self {
            parent: (0..size).collect(),
        }
    }

    /// Returns the canonical element of `a`'s class,
    /// compressing the path fully along the way.
    pub fn find_canonical(&mut self, a: usize) -> usize {
        // Locate the root of a's class.
        let mut root = a;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Path compression: point every node on the walk directly at the root.
        let mut node = a;
        while node != root {
            node = std::mem::replace(&mut self.parent[node], root);
        }
        root
    }

    /// Merges the classes of `a` and `b`.
    pub fn union(&mut self, a: usize, b: usize) {
        _ = self.united_common_ancestor(a, b);
    }

    // Uses slow path compression, immediate parent check and early recognition.
    // The smaller ancestor always wins, so the canonical is the class minimum.
    pub fn united_common_ancestor(&mut self, a: usize, b: usize) -> usize {
        let ancestor_a = self.parent[a];
        let ancestor_b = self.parent[b];
        // Same immediate ancestor: the classes are already united.
        if ancestor_a == ancestor_b {
            return ancestor_a;
        }
        let common_ancestor = match (
            a == ancestor_a && ancestor_b < ancestor_a,
            b == ancestor_b && ancestor_a < ancestor_b,
        ) {
            // a is a root: attach its class under b's smaller ancestor.
            (true, _) => ancestor_b,
            // b is a root: attach its class under a's smaller ancestor.
            (_, true) => ancestor_a,
            // Otherwise climb one level on both sides.
            _ => self.united_common_ancestor(ancestor_a, ancestor_b),
        };
        // Classic path compression.
        self.parent[a] = common_ancestor;
        self.parent[b] = common_ancestor;
        common_ancestor
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::super::partitionning::Blocks;
    use super::*;

    #[test]
    fn test_union_find() {
        let size = 10;
        let mut uf = UnionFind::new(size);
        for i in 0..size {
            // 0 is the smallest element, so it stays canonical throughout.
            assert!(uf.find_canonical(0) == 0);
            // i has not been united yet: it is still its own canonical.
            assert!(uf.find_canonical(i) == i);
            uf.union(i, 0);
            assert!(uf.find_canonical(i) == 0, "{} {:?}", i, &uf.parent[0..=i]);
        }
        eprintln!("{:?}", uf.parent);
        // All elements collapse into a single block rooted at 0.
        let expected_group: Vec<usize> = (0..10).collect();
        assert!(Blocks::from(uf).blocks == vec![expected_group]);
    }
}
|
||||
Reference in New Issue
Block a user