mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: new dag with Lut/Dot/LevelledOp
This commit is contained in:
@@ -6,6 +6,8 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
delegate = "0.6.2"
|
||||
derive_more = "0.99.17"
|
||||
concrete-commons = { git = "ssh://git@github.com/zama-ai/concrete_internal.git", branch = "fix/optimizer_compat" }
|
||||
concrete-npe = { git = "ssh://git@github.com/zama-ai/concrete_internal.git", branch = "fix/optimizer_compat" }
|
||||
statrs = "0.15.0"
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::graph::operator::{Operator, OperatorIndex};
|
||||
use crate::graph::parameter_indexed::{AtomicPatternParametersIndexed, InputParameterIndexed};
|
||||
use crate::graph::parameter_indexed::{
|
||||
InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
|
||||
};
|
||||
use crate::graph::unparametrized::UnparameterizedOperator;
|
||||
use crate::graph::{parameter_indexed, range_parametrized, unparametrized};
|
||||
use crate::parameters::{
|
||||
BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters,
|
||||
@@ -92,46 +95,65 @@ impl Range {
|
||||
|
||||
#[must_use]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
pub fn minimal_unify(_g: unparametrized::AtomicPatternDag) -> parameter_indexed::AtomicPatternDag {
|
||||
pub fn minimal_unify(_g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn convert_maximal(
|
||||
op: Operator<(), ()>,
|
||||
) -> Operator<InputParameterIndexed, AtomicPatternParametersIndexed> {
|
||||
fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed {
|
||||
let external_glwe_index = 0;
|
||||
let internal_lwe_index = 1;
|
||||
let br_decomposition_index = 0;
|
||||
let ks_decomposition_index = 0;
|
||||
match op {
|
||||
Operator::Input { out_precision, .. } => Operator::Input {
|
||||
Operator::Input {
|
||||
out_precision,
|
||||
out_shape,
|
||||
..
|
||||
} => Operator::Input {
|
||||
out_precision,
|
||||
out_shape,
|
||||
extra_data: InputParameterIndexed {
|
||||
lwe_dimension_index: external_glwe_index,
|
||||
},
|
||||
},
|
||||
Operator::AtomicPattern {
|
||||
in_precision,
|
||||
out_precision,
|
||||
multisum_inputs,
|
||||
..
|
||||
} => Operator::AtomicPattern {
|
||||
in_precision,
|
||||
out_precision,
|
||||
multisum_inputs,
|
||||
extra_data: AtomicPatternParametersIndexed {
|
||||
input_lwe_dimensionlwe_dimension_index: external_glwe_index,
|
||||
Operator::Lut { input, table, .. } => Operator::Lut {
|
||||
input,
|
||||
table,
|
||||
extra_data: LutParametersIndexed {
|
||||
input_lwe_dimension_index: external_glwe_index,
|
||||
ks_decomposition_parameter_index: ks_decomposition_index,
|
||||
internal_lwe_dimension_index: internal_lwe_index,
|
||||
br_decomposition_parameter_index: br_decomposition_index,
|
||||
output_glwe_params_index: external_glwe_index,
|
||||
},
|
||||
},
|
||||
Operator::Dot {
|
||||
inputs, weights, ..
|
||||
} => Operator::Dot {
|
||||
inputs,
|
||||
weights,
|
||||
extra_data: (),
|
||||
},
|
||||
Operator::LevelledOp {
|
||||
inputs,
|
||||
complexity,
|
||||
manp,
|
||||
out_shape,
|
||||
comment,
|
||||
..
|
||||
} => Operator::LevelledOp {
|
||||
inputs,
|
||||
complexity,
|
||||
manp,
|
||||
comment,
|
||||
out_shape,
|
||||
extra_data: (),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::AtomicPatternDag {
|
||||
pub fn maximal_unify(g: unparametrized::OperationDag) -> parameter_indexed::OperationDag {
|
||||
let operators: Vec<_> = g.operators.into_iter().map(convert_maximal).collect();
|
||||
|
||||
let parameters = ParameterCount {
|
||||
@@ -147,20 +169,25 @@ pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::
|
||||
};
|
||||
|
||||
for (i, op) in operators.iter().enumerate() {
|
||||
let index = OperatorIndex { i };
|
||||
match op {
|
||||
Operator::Input { .. } => {
|
||||
reverse_map.glwe[0].push(OperatorIndex(i));
|
||||
reverse_map.glwe[0].push(index);
|
||||
}
|
||||
Operator::AtomicPattern { .. } => {
|
||||
reverse_map.glwe[0].push(OperatorIndex(i));
|
||||
reverse_map.glwe[1].push(OperatorIndex(i));
|
||||
reverse_map.br_decomposition[0].push(OperatorIndex(i));
|
||||
reverse_map.ks_decomposition[0].push(OperatorIndex(i));
|
||||
Operator::Lut { .. } => {
|
||||
reverse_map.glwe[0].push(index);
|
||||
reverse_map.glwe[1].push(index);
|
||||
reverse_map.br_decomposition[0].push(index);
|
||||
reverse_map.ks_decomposition[0].push(index);
|
||||
}
|
||||
Operator::Dot { .. } | Operator::LevelledOp { .. } => {
|
||||
reverse_map.glwe[0].push(index);
|
||||
reverse_map.glwe[1].push(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parameter_indexed::AtomicPatternDag {
|
||||
parameter_indexed::OperationDag {
|
||||
operators,
|
||||
parameters_count: parameters,
|
||||
reverse_map,
|
||||
@@ -169,16 +196,16 @@ pub fn maximal_unify(g: unparametrized::AtomicPatternDag) -> parameter_indexed::
|
||||
|
||||
#[must_use]
|
||||
pub fn domains_to_ranges(
|
||||
parameter_indexed::AtomicPatternDag {
|
||||
parameter_indexed::OperationDag {
|
||||
operators,
|
||||
parameters_count,
|
||||
reverse_map,
|
||||
}: parameter_indexed::AtomicPatternDag,
|
||||
}: parameter_indexed::OperationDag,
|
||||
domains: ParameterDomains,
|
||||
) -> range_parametrized::AtomicPatternDag {
|
||||
) -> range_parametrized::OperationDag {
|
||||
let mut constrained_glwe_parameter_indexes = HashSet::new();
|
||||
for op in &operators {
|
||||
if let Operator::AtomicPattern { extra_data, .. } = op {
|
||||
if let Operator::Lut { extra_data, .. } = op {
|
||||
let _ = constrained_glwe_parameter_indexes.insert(extra_data.output_glwe_params_index);
|
||||
}
|
||||
}
|
||||
@@ -202,7 +229,7 @@ pub fn domains_to_ranges(
|
||||
ks_decomposition: vec![domains.ks_decomposition; parameters_count.ks_decomposition],
|
||||
};
|
||||
|
||||
range_parametrized::AtomicPatternDag {
|
||||
range_parametrized::OperationDag {
|
||||
operators,
|
||||
parameter_ranges,
|
||||
reverse_map,
|
||||
@@ -217,24 +244,27 @@ pub fn domains_to_ranges(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::weight::Weight;
|
||||
use crate::graph::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
|
||||
|
||||
#[test]
|
||||
fn test_maximal_unify() {
|
||||
let mut graph = unparametrized::AtomicPatternDag::new();
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
|
||||
let input1 = graph.add_input(1);
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
|
||||
let input2 = graph.add_input(2);
|
||||
let input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let atomic_pattern1 =
|
||||
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
|
||||
let cpx_add = LevelledComplexity::ADDITION;
|
||||
let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
|
||||
|
||||
let _atomic_pattern2 = graph.add_atomic_pattern(
|
||||
4,
|
||||
4,
|
||||
vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
|
||||
);
|
||||
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
|
||||
|
||||
let concat =
|
||||
graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
|
||||
|
||||
let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
|
||||
|
||||
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
|
||||
|
||||
let graph_params = maximal_unify(graph);
|
||||
|
||||
@@ -250,24 +280,19 @@ mod tests {
|
||||
assert_eq!(
|
||||
graph_params.reverse_map.glwe,
|
||||
vec![
|
||||
vec![
|
||||
OperatorIndex(0),
|
||||
OperatorIndex(1),
|
||||
OperatorIndex(2),
|
||||
OperatorIndex(3)
|
||||
],
|
||||
vec![OperatorIndex(2), OperatorIndex(3)]
|
||||
vec![input1, input2, sum1, lut1, concat, dot, lut2],
|
||||
vec![sum1, lut1, concat, dot, lut2]
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
graph_params.reverse_map.br_decomposition,
|
||||
vec![vec![OperatorIndex(2), OperatorIndex(3)]]
|
||||
vec![vec![lut1, lut2]]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
graph_params.reverse_map.ks_decomposition,
|
||||
vec![vec![OperatorIndex(2), OperatorIndex(3)]]
|
||||
vec![vec![lut1, lut2]]
|
||||
);
|
||||
// collectes l'ensemble des parametres
|
||||
// unify structurellement les parametres identiques
|
||||
@@ -278,19 +303,19 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_simple_lwe() {
|
||||
let mut graph = unparametrized::AtomicPatternDag::new();
|
||||
let input1 = graph.add_input(1);
|
||||
let _input2 = graph.add_input(2);
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
let _input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let graph_params = maximal_unify(graph);
|
||||
|
||||
let range_parametrized::AtomicPatternDag {
|
||||
let range_parametrized::OperationDag {
|
||||
operators,
|
||||
parameter_ranges,
|
||||
reverse_map: _,
|
||||
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
|
||||
|
||||
let input_1_lwe_params = match &operators[input1.0] {
|
||||
let input_1_lwe_params = match &operators[input1.i] {
|
||||
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
@@ -305,22 +330,25 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_simple_lwe2() {
|
||||
let mut graph = unparametrized::AtomicPatternDag::new();
|
||||
let input1 = graph.add_input(1);
|
||||
let input2 = graph.add_input(2);
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
let input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let atomic_pattern1 =
|
||||
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
|
||||
let cpx_add = LevelledComplexity::ADDITION;
|
||||
let concat =
|
||||
graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
|
||||
|
||||
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN);
|
||||
|
||||
let graph_params = maximal_unify(graph);
|
||||
|
||||
let range_parametrized::AtomicPatternDag {
|
||||
let range_parametrized::OperationDag {
|
||||
operators,
|
||||
parameter_ranges,
|
||||
reverse_map: _,
|
||||
} = domains_to_ranges(graph_params, DEFAUT_DOMAINS);
|
||||
|
||||
let input_1_lwe_params = match &operators[input1.0] {
|
||||
let input_1_lwe_params = match &operators[input1.i] {
|
||||
Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
@@ -329,22 +357,22 @@ mod tests {
|
||||
parameter_ranges.glwe[input_1_lwe_params]
|
||||
);
|
||||
|
||||
let atomic_pattern1_out_glwe_params = match &operators[atomic_pattern1.0] {
|
||||
Operator::AtomicPattern { extra_data, .. } => extra_data.output_glwe_params_index,
|
||||
let lut1_out_glwe_params = match &operators[lut1.i] {
|
||||
Operator::Lut { extra_data, .. } => extra_data.output_glwe_params_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert_eq!(
|
||||
DEFAUT_DOMAINS.glwe_pbs_constrained,
|
||||
parameter_ranges.glwe[atomic_pattern1_out_glwe_params]
|
||||
parameter_ranges.glwe[lut1_out_glwe_params]
|
||||
);
|
||||
|
||||
let atomic_pattern1_internal_glwe_params = match &operators[atomic_pattern1.0] {
|
||||
Operator::AtomicPattern { extra_data, .. } => extra_data.internal_lwe_dimension_index,
|
||||
let lut1_internal_glwe_params = match &operators[lut1.i] {
|
||||
Operator::Lut { extra_data, .. } => extra_data.internal_lwe_dimension_index,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
assert_eq!(
|
||||
DEFAUT_DOMAINS.free_glwe,
|
||||
parameter_ranges.glwe[atomic_pattern1_internal_glwe_params]
|
||||
parameter_ranges.glwe[lut1_internal_glwe_params]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
use crate::weight::Weight;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub(crate) enum Operator<InputExtraData, AtomicPatternExtraData> {
|
||||
Input {
|
||||
out_precision: u8,
|
||||
extra_data: InputExtraData,
|
||||
},
|
||||
AtomicPattern {
|
||||
in_precision: u8,
|
||||
out_precision: u8,
|
||||
multisum_inputs: Vec<(Weight, OperatorIndex)>,
|
||||
extra_data: AtomicPatternExtraData,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub struct OperatorIndex(pub usize);
|
||||
6
concrete-optimizer/src/graph/operator/mod.rs
Normal file
6
concrete-optimizer/src/graph/operator/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
#![allow(clippy::module_inception)]
|
||||
pub mod operator;
|
||||
pub mod tensor;
|
||||
|
||||
pub use self::operator::*;
|
||||
pub use self::tensor::*;
|
||||
85
concrete-optimizer/src/graph/operator/operator.rs
Normal file
85
concrete-optimizer/src/graph/operator/operator.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use crate::graph::operator::tensor::{ClearTensor, Shape};
|
||||
use derive_more::{Add, AddAssign};
|
||||
|
||||
pub type Weights = ClearTensor;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct FunctionTable {
|
||||
pub values: Vec<u64>,
|
||||
}
|
||||
|
||||
impl FunctionTable {
|
||||
pub const UNKWOWN: Self = Self { values: vec![] };
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Add, AddAssign, Debug)]
|
||||
pub struct LevelledComplexity {
|
||||
pub lwe_dim_cost_factor: f64,
|
||||
pub fixed_cost: f64,
|
||||
}
|
||||
|
||||
impl LevelledComplexity {
|
||||
pub const ZERO: Self = Self {
|
||||
lwe_dim_cost_factor: 0.0,
|
||||
fixed_cost: 0.0,
|
||||
};
|
||||
pub const ADDITION: Self = Self {
|
||||
lwe_dim_cost_factor: 1.0,
|
||||
fixed_cost: 0.0,
|
||||
};
|
||||
}
|
||||
|
||||
impl LevelledComplexity {
|
||||
pub fn cost(&self, lwe_dimension: u64) -> f64 {
|
||||
self.lwe_dim_cost_factor * (lwe_dimension as f64) + self.fixed_cost
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::AddAssign<&Self> for LevelledComplexity {
|
||||
fn add_assign(&mut self, rhs: &Self) {
|
||||
*self += *rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<u64> for LevelledComplexity {
|
||||
type Output = Self;
|
||||
fn mul(self, factor: u64) -> Self {
|
||||
Self {
|
||||
lwe_dim_cost_factor: self.lwe_dim_cost_factor * factor as f64,
|
||||
fixed_cost: self.fixed_cost * factor as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraData> {
|
||||
Input {
|
||||
out_precision: u8,
|
||||
out_shape: Shape,
|
||||
extra_data: InputExtraData,
|
||||
},
|
||||
Lut {
|
||||
input: OperatorIndex,
|
||||
table: FunctionTable,
|
||||
//reduced_precision: u64
|
||||
extra_data: LutExtraData,
|
||||
},
|
||||
Dot {
|
||||
inputs: Vec<OperatorIndex>,
|
||||
weights: Weights,
|
||||
extra_data: DotExtraData,
|
||||
},
|
||||
LevelledOp {
|
||||
inputs: Vec<OperatorIndex>,
|
||||
complexity: LevelledComplexity,
|
||||
manp: f64,
|
||||
out_shape: Shape,
|
||||
comment: String,
|
||||
extra_data: LevelledOpExtraData,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub struct OperatorIndex {
|
||||
pub i: usize,
|
||||
}
|
||||
86
concrete-optimizer/src/graph/operator/tensor.rs
Normal file
86
concrete-optimizer/src/graph/operator/tensor.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use delegate::delegate;
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub struct Shape {
|
||||
dimensions_size: Vec<u64>,
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.dimensions_size.len()
|
||||
}
|
||||
|
||||
pub fn flat_size(&self) -> u64 {
|
||||
let mut product = 1;
|
||||
for dim_size in &self.dimensions_size {
|
||||
product *= dim_size;
|
||||
}
|
||||
product
|
||||
}
|
||||
|
||||
pub fn number() -> Self {
|
||||
Self {
|
||||
dimensions_size: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_number(&self) -> bool {
|
||||
self.rank() == 0
|
||||
}
|
||||
|
||||
pub fn vector(size: u64) -> Self {
|
||||
Self {
|
||||
dimensions_size: vec![size],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_vector(&self) -> bool {
|
||||
self.rank() == 1
|
||||
}
|
||||
|
||||
pub fn duplicated(out_dim_size: u64, other: &Self) -> Self {
|
||||
let mut dimensions_size = Vec::with_capacity(other.rank() + 1);
|
||||
dimensions_size.push(out_dim_size as u64);
|
||||
dimensions_size.extend_from_slice(&other.dimensions_size);
|
||||
Self { dimensions_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub struct ClearTensor {
|
||||
pub shape: Shape,
|
||||
pub values: Vec<u64>,
|
||||
}
|
||||
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||
fn square(v: &u64) -> u64 {
|
||||
v * v
|
||||
}
|
||||
|
||||
impl ClearTensor {
|
||||
pub fn number(value: u64) -> Self {
|
||||
Self {
|
||||
shape: Shape::number(),
|
||||
values: vec![value],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector(values: &[u64]) -> Self {
|
||||
Self {
|
||||
shape: Shape::vector(values.len() as u64),
|
||||
values: values.to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
delegate! {
|
||||
to self.shape {
|
||||
pub fn is_number(&self) -> bool;
|
||||
pub fn is_vector(&self) -> bool;
|
||||
pub fn flat_size(&self) -> u64;
|
||||
pub fn rank(&self) -> usize;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn square_norm2(&self) -> u64 {
|
||||
self.values.iter().map(square).sum()
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,8 @@ pub struct InputParameterIndexed {
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct AtomicPatternParametersIndexed {
|
||||
pub input_lwe_dimensionlwe_dimension_index: usize,
|
||||
pub struct LutParametersIndexed {
|
||||
pub input_lwe_dimension_index: usize,
|
||||
pub ks_decomposition_parameter_index: usize,
|
||||
pub internal_lwe_dimension_index: usize,
|
||||
pub br_decomposition_parameter_index: usize,
|
||||
@@ -16,9 +16,9 @@ pub struct AtomicPatternParametersIndexed {
|
||||
}
|
||||
|
||||
pub(crate) type OperatorParameterIndexed =
|
||||
Operator<InputParameterIndexed, AtomicPatternParametersIndexed>;
|
||||
Operator<InputParameterIndexed, LutParametersIndexed, (), ()>;
|
||||
|
||||
pub struct AtomicPatternDag {
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<OperatorParameterIndexed>,
|
||||
pub(crate) parameters_count: ParameterCount,
|
||||
pub(crate) reverse_map: ParameterToOperation,
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::global_parameters::{ParameterRanges, ParameterToOperation};
|
||||
use crate::graph::parameter_indexed::OperatorParameterIndexed;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct AtomicPatternDag {
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<OperatorParameterIndexed>,
|
||||
pub(crate) parameter_ranges: ParameterRanges,
|
||||
pub(crate) reverse_map: ParameterToOperation,
|
||||
|
||||
@@ -1,46 +1,71 @@
|
||||
use super::operator::{Operator, OperatorIndex};
|
||||
use crate::weight::Weight;
|
||||
use crate::graph::operator::{
|
||||
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Shape, Weights,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
#[must_use]
|
||||
pub struct AtomicPatternDag {
|
||||
pub(crate) operators: Vec<Operator<(), ()>>,
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<UnparameterizedOperator>,
|
||||
}
|
||||
|
||||
impl AtomicPatternDag {
|
||||
impl OperationDag {
|
||||
pub const fn new() -> Self {
|
||||
Self { operators: vec![] }
|
||||
}
|
||||
|
||||
fn add_operator(&mut self, operator: Operator<(), ()>) -> OperatorIndex {
|
||||
let operator_index = self.operators.len();
|
||||
|
||||
fn add_operator(&mut self, operator: UnparameterizedOperator) -> OperatorIndex {
|
||||
let i = self.operators.len();
|
||||
self.operators.push(operator);
|
||||
|
||||
OperatorIndex(operator_index)
|
||||
OperatorIndex { i }
|
||||
}
|
||||
|
||||
pub fn add_input(&mut self, out_precision: u8) -> OperatorIndex {
|
||||
pub fn add_input(&mut self, out_precision: u8, out_shape: Shape) -> OperatorIndex {
|
||||
self.add_operator(Operator::Input {
|
||||
out_precision,
|
||||
out_shape,
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_atomic_pattern(
|
||||
&mut self,
|
||||
in_precision: u8,
|
||||
out_precision: u8,
|
||||
multisum_inputs: Vec<(Weight, OperatorIndex)>,
|
||||
) -> OperatorIndex {
|
||||
self.add_operator(Operator::AtomicPattern {
|
||||
in_precision,
|
||||
out_precision,
|
||||
multisum_inputs,
|
||||
pub fn add_lut(&mut self, input: OperatorIndex, table: FunctionTable) -> OperatorIndex {
|
||||
self.add_operator(Operator::Lut {
|
||||
input,
|
||||
table,
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_dot(&mut self, inputs: &[OperatorIndex], weights: &Weights) -> OperatorIndex {
|
||||
self.add_operator(Operator::Dot {
|
||||
inputs: inputs.to_vec(),
|
||||
weights: weights.clone(),
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_levelled_op(
|
||||
&mut self,
|
||||
inputs: &[OperatorIndex],
|
||||
complexity: LevelledComplexity,
|
||||
manp: f64,
|
||||
out_shape: Shape,
|
||||
comment: &str,
|
||||
) -> OperatorIndex {
|
||||
let inputs = inputs.to_vec();
|
||||
let comment = comment.to_string();
|
||||
let op = Operator::LevelledOp {
|
||||
inputs,
|
||||
complexity,
|
||||
manp,
|
||||
out_shape,
|
||||
comment,
|
||||
extra_data: (),
|
||||
};
|
||||
self.add_operator(op)
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
@@ -50,47 +75,80 @@ impl AtomicPatternDag {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::graph::operator::Shape;
|
||||
|
||||
#[test]
|
||||
fn graph_creation() {
|
||||
let mut graph = AtomicPatternDag::new();
|
||||
let mut graph = OperationDag::new();
|
||||
|
||||
let input1 = graph.add_input(1);
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
|
||||
let input2 = graph.add_input(2);
|
||||
let input2 = graph.add_input(2, Shape::number());
|
||||
|
||||
let atomic_pattern1 =
|
||||
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
|
||||
let cpx_add = LevelledComplexity::ADDITION;
|
||||
let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
|
||||
|
||||
let _atomic_pattern2 = graph.add_atomic_pattern(
|
||||
4,
|
||||
4,
|
||||
vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
|
||||
);
|
||||
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
|
||||
|
||||
let concat =
|
||||
graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
|
||||
|
||||
let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
|
||||
|
||||
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
|
||||
|
||||
let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2];
|
||||
for (expected_i, op_index) in ops_index.iter().enumerate() {
|
||||
assert_eq!(expected_i, op_index.i);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&graph.operators,
|
||||
&[
|
||||
graph.operators,
|
||||
vec![
|
||||
Operator::Input {
|
||||
out_precision: 1,
|
||||
out_shape: Shape::number(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Input {
|
||||
out_precision: 2,
|
||||
out_shape: Shape::number(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::AtomicPattern {
|
||||
in_precision: 3,
|
||||
out_precision: 3,
|
||||
multisum_inputs: vec![(Weight(1), input1), (Weight(2), input2)],
|
||||
Operator::LevelledOp {
|
||||
inputs: vec![input1, input2],
|
||||
complexity: cpx_add,
|
||||
manp: 1.0,
|
||||
out_shape: Shape::number(),
|
||||
comment: "sum".to_string(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::AtomicPattern {
|
||||
in_precision: 4,
|
||||
out_precision: 4,
|
||||
multisum_inputs: vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
|
||||
Operator::Lut {
|
||||
input: sum1,
|
||||
table: FunctionTable::UNKWOWN,
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::LevelledOp {
|
||||
inputs: vec![input1, lut1],
|
||||
complexity: cpx_add,
|
||||
manp: 1.0,
|
||||
out_shape: Shape::vector(2),
|
||||
comment: "concat".to_string(),
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Dot {
|
||||
inputs: vec![concat],
|
||||
weights: Weights {
|
||||
shape: Shape::vector(2),
|
||||
values: vec![1, 2]
|
||||
},
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Lut {
|
||||
input: dot,
|
||||
table: FunctionTable::UNKWOWN,
|
||||
extra_data: ()
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::global_parameters::{ParameterToOperation, ParameterValues};
|
||||
use crate::graph::parameter_indexed::OperatorParameterIndexed;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct AtomicPatternDag {
|
||||
pub struct OperationDag {
|
||||
pub(crate) operators: Vec<OperatorParameterIndexed>,
|
||||
pub(crate) parameter_ranges: ParameterValues,
|
||||
pub(crate) reverse_map: ParameterToOperation,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#![warn(clippy::style)]
|
||||
#![allow(clippy::cast_precision_loss)] // u64 to f64
|
||||
#![allow(clippy::cast_possible_truncation)] // u64 to usize
|
||||
#![allow(clippy::inline_always)] // needed by delegate
|
||||
#![allow(clippy::missing_panics_doc)]
|
||||
#![allow(clippy::missing_const_for_fn)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
Reference in New Issue
Block a user