mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-14 08:48:01 -05:00
Compare commits
2 Commits
ac/make-da
...
ac/lookupl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1099e4b036 | ||
|
|
c1f0f9314b |
@@ -45,6 +45,8 @@ pub enum HybridOp {
|
||||
ReduceArgMin {
|
||||
dim: usize,
|
||||
},
|
||||
Max,
|
||||
Min,
|
||||
Softmax {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
@@ -79,6 +81,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::Equals { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
| HybridOp::Max
|
||||
| HybridOp::Min
|
||||
| HybridOp::LessEqual { .. } => {
|
||||
vec![0, 1]
|
||||
}
|
||||
@@ -93,6 +97,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Max => format!("MAX"),
|
||||
HybridOp::Min => format!("MIN"),
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
@@ -162,6 +168,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
|
||||
@@ -4155,6 +4155,48 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(assigned_argmin)
|
||||
}
|
||||
|
||||
/// max layout
|
||||
pub(crate) fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let is_greater = greater(config, region, values)?;
|
||||
let is_less = not(config, region, &[is_greater.clone()])?;
|
||||
|
||||
let max_val_p1 = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone(), is_greater],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let max_val_p2 = pairwise(config, region, &[values[1].clone(), is_less], BaseOp::Mult)?;
|
||||
|
||||
pairwise(config, region, &[max_val_p1, max_val_p2], BaseOp::Add)
|
||||
}
|
||||
|
||||
/// min comp layout
|
||||
pub(crate) fn min_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let is_greater = greater(config, region, values)?;
|
||||
let is_less = not(config, region, &[is_greater.clone()])?;
|
||||
|
||||
let min_val_p1 = pairwise(config, region, &[values[0].clone(), is_less], BaseOp::Mult)?;
|
||||
|
||||
let min_val_p2 = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values[1].clone(), is_greater],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
pairwise(config, region, &[min_val_p1, min_val_p2], BaseOp::Add)
|
||||
}
|
||||
|
||||
/// max layout
|
||||
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
|
||||
@@ -21,14 +21,6 @@ pub enum LookupOp {
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
@@ -129,8 +121,6 @@ impl LookupOp {
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
@@ -186,12 +176,6 @@ impl LookupOp {
|
||||
LookupOp::KroneckerDelta => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
|
||||
}
|
||||
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
@@ -289,8 +273,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::KroneckerDelta => "K_DELTA".into(),
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
|
||||
@@ -763,81 +763,38 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
if unit == 0. {
|
||||
SupportedOp::Linear(PolyOp::ReLU)
|
||||
if const_inputs.len() > 0 {
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
if unit == 0. {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
SupportedOp::Linear(PolyOp::ReLU)
|
||||
} else {
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
SupportedOp::Nonlinear(LookupOp::Max {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
}
|
||||
"Min" => {
|
||||
// Extract the min value
|
||||
// first find the input that is a constant
|
||||
// and then extract the value
|
||||
let const_inputs = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.is_constant())
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Min {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Min)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
|
||||
@@ -124,41 +124,40 @@ mod py_tests {
|
||||
}
|
||||
|
||||
const TESTS: [&str; 34] = [
|
||||
"ezkl_demo_batch.ipynb",
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
// "mnist_vae.ipynb",
|
||||
"keras_simple_demo.ipynb",
|
||||
"mnist_gan_proof_splitting.ipynb", // 4
|
||||
"hashed_vis.ipynb", // 5
|
||||
"simple_demo_all_public.ipynb",
|
||||
"data_attest.ipynb",
|
||||
"little_transformer.ipynb",
|
||||
"simple_demo_aggregated_proofs.ipynb",
|
||||
"ezkl_demo.ipynb", // 10
|
||||
"lstm.ipynb",
|
||||
"set_membership.ipynb", // 12
|
||||
"decision_tree.ipynb",
|
||||
"random_forest.ipynb",
|
||||
"gradient_boosted_trees.ipynb", // 15
|
||||
"xgboost.ipynb",
|
||||
"lightgbm.ipynb",
|
||||
"svm.ipynb",
|
||||
"simple_demo_public_input_output.ipynb",
|
||||
"simple_demo_public_network_output.ipynb", // 20
|
||||
"gcn.ipynb",
|
||||
"linear_regression.ipynb",
|
||||
"stacked_regression.ipynb",
|
||||
"data_attest_hashed.ipynb",
|
||||
"kzg_vis.ipynb", // 25
|
||||
"kmeans.ipynb",
|
||||
"solvency.ipynb",
|
||||
"sklearn_mlp.ipynb",
|
||||
"generalized_inverse.ipynb",
|
||||
"mnist_classifier.ipynb", // 30
|
||||
"world_rotation.ipynb",
|
||||
"logistic_regression.ipynb",
|
||||
"ezkl_demo_batch.ipynb", // 0
|
||||
"proof_splitting.ipynb", // 1
|
||||
"variance.ipynb", // 2
|
||||
"mnist_gan.ipynb", // 3
|
||||
"keras_simple_demo.ipynb", // 4
|
||||
"mnist_gan_proof_splitting.ipynb", // 5
|
||||
"hashed_vis.ipynb", // 6
|
||||
"simple_demo_all_public.ipynb", // 7
|
||||
"data_attest.ipynb", // 8
|
||||
"little_transformer.ipynb", // 9
|
||||
"simple_demo_aggregated_proofs.ipynb", // 10
|
||||
"ezkl_demo.ipynb", // 11
|
||||
"lstm.ipynb", // 12
|
||||
"set_membership.ipynb", // 13
|
||||
"decision_tree.ipynb", // 14
|
||||
"random_forest.ipynb", // 15
|
||||
"gradient_boosted_trees.ipynb", // 16
|
||||
"xgboost.ipynb", // 17
|
||||
"lightgbm.ipynb", // 18
|
||||
"svm.ipynb", // 19
|
||||
"simple_demo_public_input_output.ipynb", // 20
|
||||
"simple_demo_public_network_output.ipynb", // 21
|
||||
"gcn.ipynb", // 22
|
||||
"linear_regression.ipynb", // 23
|
||||
"stacked_regression.ipynb", // 24
|
||||
"data_attest_hashed.ipynb", // 25
|
||||
"kzg_vis.ipynb", // 26
|
||||
"kmeans.ipynb", // 27
|
||||
"solvency.ipynb", // 28
|
||||
"sklearn_mlp.ipynb", // 29
|
||||
"generalized_inverse.ipynb", // 30
|
||||
"mnist_classifier.ipynb", // 31
|
||||
"world_rotation.ipynb", // 32
|
||||
"logistic_regression.ipynb", // 33
|
||||
];
|
||||
|
||||
macro_rules! test_func {
|
||||
|
||||
Reference in New Issue
Block a user