input graph definition

This commit is contained in:
Mayeul@Zama
2022-02-03 11:24:50 +01:00
committed by mayeul-zama
parent 75e4eb42c2
commit 53c74caa32
5 changed files with 105 additions and 8 deletions

2
src/graph.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod operator_index;
pub mod unparametrized;

View File

@@ -0,0 +1,2 @@
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct OperatorIndex(pub usize);

View File

@@ -0,0 +1,92 @@
use super::operator_index::OperatorIndex;
use crate::weight::Weight;
#[derive(Clone, PartialEq, Eq, Debug)]
pub(crate) enum Operator {
Input {
out_precision: u8,
},
AtomicPattern {
in_precision: u8,
out_precision: u8,
multisum_inputs: Vec<(Weight, OperatorIndex)>,
},
}
#[derive(Clone)]
#[must_use]
pub struct AtomicPatternDag {
operators: Vec<Operator>,
}
impl AtomicPatternDag {
pub const fn new() -> Self {
Self { operators: vec![] }
}
fn add_operator(&mut self, operator: Operator) -> OperatorIndex {
let operator_index = self.operators.len();
self.operators.push(operator);
OperatorIndex(operator_index)
}
pub fn add_input(&mut self, out_precision: u8) -> OperatorIndex {
self.add_operator(Operator::Input { out_precision })
}
pub fn add_atomic_pattern(
&mut self,
in_precision: u8,
out_precision: u8,
multisum_inputs: Vec<(Weight, OperatorIndex)>,
) -> OperatorIndex {
self.add_operator(Operator::AtomicPattern {
in_precision,
out_precision,
multisum_inputs,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn graph_creation() {
let mut graph = AtomicPatternDag::new();
let input1 = graph.add_input(1);
let input2 = graph.add_input(2);
let atomic_pattern1 =
graph.add_atomic_pattern(3, 3, vec![(Weight(1), input1), (Weight(2), input2)]);
let _atomic_pattern2 = graph.add_atomic_pattern(
4,
4,
vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
);
assert_eq!(
&graph.operators,
&[
Operator::Input { out_precision: 1 },
Operator::Input { out_precision: 2 },
Operator::AtomicPattern {
in_precision: 3,
out_precision: 3,
multisum_inputs: vec![(Weight(1), input1), (Weight(2), input2)]
},
Operator::AtomicPattern {
in_precision: 4,
out_precision: 4,
multisum_inputs: vec![(Weight(1), atomic_pattern1), (Weight(2), input2)]
},
]
);
}
}

View File

@@ -1,8 +1,7 @@
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
}
}
#![warn(clippy::nursery)]
#![warn(clippy::pedantic)]
#![warn(clippy::style)]
#![allow(clippy::missing_panics_doc)]
pub mod graph;
pub mod weight;

2
src/weight.rs Normal file
View File

@@ -0,0 +1,2 @@
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Weight(pub(crate) u32);