diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index 6d0cc9c8b..cdb8852a4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -1,3 +1,5 @@ +use std::fmt; + use crate::dag::operator::tensor::{ClearTensor, Shape}; pub type Weights = ClearTensor; @@ -103,3 +105,60 @@ pub enum Operator { pub struct OperatorIndex { pub i: usize, } + +impl fmt::Display for Operator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "")?; + match self { + Self::Input { + out_precision, + out_shape, + } => { + write!(f, "Input : u{out_precision} x {out_shape:?}")?; + } + Self::Dot { inputs, weights } => { + for (i, (input, weight)) in inputs.iter().zip(weights.values.iter()).enumerate() { + if i > 0 { + write!(f, " + ")?; + } + write!(f, "{weight} x %{}", input.i)?; + } + } + Self::UnsafeCast { + input, + out_precision, + } => { + write!(f, "%{} : u{out_precision}", input.i)?; + } + Self::Lut { + input, + out_precision, + .. + } => { + write!(f, "LUT[%{}] : u{out_precision}", input.i)?; + } + Self::LevelledOp { + inputs, + manp, + out_shape, + .. + } => { + write!(f, "LINEAR[")?; + for (i, input) in inputs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "%{}", input.i)?; + } + write!(f, "] : manp={manp} x {out_shape:?}")?; + } + Self::Round { + input, + out_precision, + } => { + write!(f, "ROUND[%{}] : u{out_precision}", input.i)?; + } + } + Ok(()) + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index c900e0000..b6d457d7f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::fmt::Write; use crate::dag::operator::{ @@ -17,6 +18,15 @@ pub struct OperationDag { pub(crate) out_precisions: Vec, } +impl fmt::Display for OperationDag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for (i, op) in self.operators.iter().enumerate() { + writeln!(f, "%{i} <- {op}")?; + } + Ok(()) + } +} + impl OperationDag { pub const fn new() -> Self { Self {