mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
add parametrized graphs
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
pub mod operator_index;
|
||||
pub mod operator;
|
||||
pub mod unparametrized;
|
||||
|
||||
18
src/graph/operator.rs
Normal file
18
src/graph/operator.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use crate::weight::Weight;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub(crate) enum Operator<InputExtraData, AtomicPatternExtraData> {
|
||||
Input {
|
||||
out_precision: u8,
|
||||
extra_data: InputExtraData,
|
||||
},
|
||||
AtomicPattern {
|
||||
in_precision: u8,
|
||||
out_precision: u8,
|
||||
multisum_inputs: Vec<(Weight, OperatorIndex)>,
|
||||
extra_data: AtomicPatternExtraData,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub struct OperatorIndex(pub usize);
|
||||
@@ -1,2 +0,0 @@
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub struct OperatorIndex(pub usize);
|
||||
@@ -1,22 +1,10 @@
|
||||
use super::operator_index::OperatorIndex;
|
||||
use super::operator::{Operator, OperatorIndex};
|
||||
use crate::weight::Weight;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub(crate) enum Operator {
|
||||
Input {
|
||||
out_precision: u8,
|
||||
},
|
||||
AtomicPattern {
|
||||
in_precision: u8,
|
||||
out_precision: u8,
|
||||
multisum_inputs: Vec<(Weight, OperatorIndex)>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[must_use]
|
||||
pub struct AtomicPatternDag {
|
||||
operators: Vec<Operator>,
|
||||
operators: Vec<Operator<(), ()>>,
|
||||
}
|
||||
|
||||
impl AtomicPatternDag {
|
||||
@@ -24,7 +12,7 @@ impl AtomicPatternDag {
|
||||
Self { operators: vec![] }
|
||||
}
|
||||
|
||||
fn add_operator(&mut self, operator: Operator) -> OperatorIndex {
|
||||
fn add_operator(&mut self, operator: Operator<(), ()>) -> OperatorIndex {
|
||||
let operator_index = self.operators.len();
|
||||
|
||||
self.operators.push(operator);
|
||||
@@ -33,7 +21,10 @@ impl AtomicPatternDag {
|
||||
}
|
||||
|
||||
pub fn add_input(&mut self, out_precision: u8) -> OperatorIndex {
|
||||
self.add_operator(Operator::Input { out_precision })
|
||||
self.add_operator(Operator::Input {
|
||||
out_precision,
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_atomic_pattern(
|
||||
@@ -46,6 +37,7 @@ impl AtomicPatternDag {
|
||||
in_precision,
|
||||
out_precision,
|
||||
multisum_inputs,
|
||||
extra_data: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -74,17 +66,25 @@ mod tests {
|
||||
assert_eq!(
|
||||
&graph.operators,
|
||||
&[
|
||||
Operator::Input { out_precision: 1 },
|
||||
Operator::Input { out_precision: 2 },
|
||||
Operator::Input {
|
||||
out_precision: 1,
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::Input {
|
||||
out_precision: 2,
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::AtomicPattern {
|
||||
in_precision: 3,
|
||||
out_precision: 3,
|
||||
multisum_inputs: vec![(Weight(1), input1), (Weight(2), input2)]
|
||||
multisum_inputs: vec![(Weight(1), input1), (Weight(2), input2)],
|
||||
extra_data: ()
|
||||
},
|
||||
Operator::AtomicPattern {
|
||||
in_precision: 4,
|
||||
out_precision: 4,
|
||||
multisum_inputs: vec![(Weight(1), atomic_pattern1), (Weight(2), input2)]
|
||||
multisum_inputs: vec![(Weight(1), atomic_pattern1), (Weight(2), input2)],
|
||||
extra_data: ()
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![warn(clippy::style)]
|
||||
#![allow(clippy::missing_panics_doc)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod graph;
|
||||
pub mod parameters;
|
||||
pub mod weight;
|
||||
|
||||
70
src/parameters.rs
Normal file
70
src/parameters.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
pub use grouped::*;
|
||||
pub use individual::*;
|
||||
|
||||
mod individual {
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct KsDecompositionParameters<LogBase, Level> {
|
||||
pub log2_base: LogBase,
|
||||
pub level: Level,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct PbsDecompositionParameters<LogBase, Level> {
|
||||
pub log2_base: LogBase,
|
||||
pub level: Level,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||
pub struct GlweParameters<LogPolynomialSize, GlweDimension> {
|
||||
pub log2_polynomial_size: LogPolynomialSize,
|
||||
pub glwe_dimension: GlweDimension,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct LweDimension<LweDimension2> {
|
||||
pub lwe_dimension: LweDimension2,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct InputParameter<LweDimension> {
|
||||
pub lwe_dimension: LweDimension,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct AtomicPatternParameters<
|
||||
InputLweDimension,
|
||||
KsDecompositionParameter,
|
||||
InternalLweDimension,
|
||||
PbsDecompositionParameter,
|
||||
GlweDimensionAndPolynomialSize,
|
||||
> {
|
||||
pub input_lwe_dimension: InputLweDimension,
|
||||
pub ks_decomposition_parameter: KsDecompositionParameter,
|
||||
pub internal_lwe_dimension: InternalLweDimension,
|
||||
pub pbs_decomposition_parameter: PbsDecompositionParameter,
|
||||
pub output_glwe_params: GlweDimensionAndPolynomialSize,
|
||||
}
|
||||
}
|
||||
|
||||
mod grouped {
|
||||
use super::{
|
||||
GlweParameters, KsDecompositionParameters, LweDimension, PbsDecompositionParameters,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Parameters<
|
||||
LweDimension2,
|
||||
KsLogBase,
|
||||
KsLevel,
|
||||
PbsLogBase,
|
||||
PbsLevel,
|
||||
LogPolynomialSize,
|
||||
GlweDimension,
|
||||
> {
|
||||
pub lwe_dimension: Vec<LweDimension<LweDimension2>>,
|
||||
pub glwe_dimension_and_polynomial_size:
|
||||
Vec<GlweParameters<LogPolynomialSize, GlweDimension>>,
|
||||
pub pbs_decomposition_parameters: Vec<PbsDecompositionParameters<PbsLogBase, PbsLevel>>,
|
||||
pub ks_decomposition_parameters: Vec<KsDecompositionParameters<KsLogBase, KsLevel>>,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user