mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
3 Commits
ac/cap-inf
...
ac/documen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d57285df3a | ||
|
|
61a246ee20 | ||
|
|
fcbb27677f |
1
examples/onnx/hierarchical_risk/input.json
Normal file
1
examples/onnx/hierarchical_risk/input.json
Normal file
File diff suppressed because one or more lines are too long
BIN
examples/onnx/hierarchical_risk/network.onnx
Normal file
BIN
examples/onnx/hierarchical_risk/network.onnx
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -301,8 +301,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
Some(v) => v,
|
||||
None => return Err(CircuitError::UnsetVisibility),
|
||||
};
|
||||
self.quantized_values =
|
||||
quantize_tensor(self.raw_values.clone(), new_scale, &visibility, true)?;
|
||||
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -318,8 +317,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -14,14 +14,14 @@ use super::VarScales;
|
||||
use super::Visibility;
|
||||
|
||||
// Import operation types for different circuit components
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::CircuitError;
|
||||
use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
|
||||
// Import graph error types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
@@ -740,7 +740,7 @@ fn rescale_const_with_single_use(
|
||||
if scale_max > ¤t_scale {
|
||||
let raw_values = constant.raw_values.clone();
|
||||
constant.quantized_values =
|
||||
super::quantize_tensor(raw_values, *scale_max, param_visibility, true)?;
|
||||
super::quantize_tensor(raw_values, *scale_max, param_visibility)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
use super::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
use super::errors::GraphError;
|
||||
use super::{Rescaled, SupportedOp, Visibility};
|
||||
use crate::circuit::Op;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -22,7 +22,6 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
Downsample,
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
|
||||
einsum::EinSum,
|
||||
element_wise::ElementWiseOp,
|
||||
nn::{LeakyRelu, Reduce, Softmax},
|
||||
Downsample,
|
||||
};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::{
|
||||
@@ -68,33 +68,6 @@ pub fn quantize_float(
|
||||
Ok(scaled)
|
||||
}
|
||||
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// NAN gets mapped to 0. And values that exceed the max are capped.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `elem` - the element to quantize.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float_and_cap(
|
||||
elem: &f64,
|
||||
shift: f64,
|
||||
scale: crate::Scale,
|
||||
) -> Result<IntegerRep, TensorError> {
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
return Ok(IntegerRep::MAX);
|
||||
} else if *elem < -max_value {
|
||||
return Ok(IntegerRep::MIN);
|
||||
}
|
||||
|
||||
// we parallelize the quantization process as it seems to be quite slow at times
|
||||
let scaled = (mult * *elem + shift).round() as IntegerRep;
|
||||
|
||||
Ok(scaled)
|
||||
}
|
||||
|
||||
/// Dequantizes a field element to a f64 using a fixed point representation.
|
||||
/// Arguments
|
||||
/// * `felt` - the field element to dequantize.
|
||||
@@ -406,7 +379,7 @@ pub fn new_op_from_onnx(
|
||||
let range = (start..end).step_by(delta).collect::<Vec<_>>();
|
||||
let raw_value = range.iter().map(|x| *x as f32).collect::<Tensor<_>>();
|
||||
// Quantize the raw value (integers)
|
||||
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed, false)?;
|
||||
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
|
||||
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
@@ -738,7 +711,6 @@ pub fn new_op_from_onnx(
|
||||
raw_value.clone(),
|
||||
constant_scale,
|
||||
&run_args.param_visibility,
|
||||
true,
|
||||
)?;
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
@@ -1578,20 +1550,13 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
const_value: Tensor<f32>,
|
||||
scale: crate::Scale,
|
||||
visibility: &Visibility,
|
||||
cap: bool,
|
||||
) -> Result<Tensor<F>, TensorError> {
|
||||
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
|
||||
if cap {
|
||||
Ok::<F, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(
|
||||
quantize_float_and_cap(&(x).into(), 0.0, scale)?,
|
||||
))
|
||||
} else {
|
||||
Ok(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
)?))
|
||||
}
|
||||
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
)?))
|
||||
})?;
|
||||
|
||||
value.set_scale(scale);
|
||||
@@ -1679,7 +1644,7 @@ pub mod tests {
|
||||
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
let scale = 0;
|
||||
let visibility = &Visibility::Public;
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility, false).unwrap();
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
|
||||
assert_eq!(quantized.len(), 10);
|
||||
assert_eq!(quantized, reference);
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub use var::*;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{IntegerRep, integer_rep_to_felt},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -415,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
|
||||
Err(_) => {
|
||||
return Err(TensorError::FileLoadError(
|
||||
"Failed to read tensor".to_string(),
|
||||
));
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1254,7 +1254,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get last element of empty tensor".to_string(),
|
||||
));
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1279,7 +1279,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get first element of empty tensor".to_string(),
|
||||
));
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1692,8 +1692,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
|
||||
lhs.par_iter_mut()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| match T::zero() {
|
||||
Some(zero) => {
|
||||
.map(|(o, r)| {
|
||||
match T::zero() { Some(zero) => {
|
||||
if r != zero {
|
||||
*o = o.clone() % r;
|
||||
Ok(())
|
||||
@@ -1702,10 +1702,11 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
"Cannot divide by zero in remainder".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
_ => Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
)),
|
||||
} _ => {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
))
|
||||
}}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
@@ -1767,6 +1768,7 @@ pub fn get_broadcasted_shape(
|
||||
}
|
||||
}
|
||||
////////////////////////
|
||||
///
|
||||
|
||||
/// The shape of data for some operations
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
|
||||
|
||||
@@ -1342,9 +1342,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Gets the total number of elements in the tensor
|
||||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
ValTensor::Value { dims, .. } => {
|
||||
ValTensor::Value { dims, inner, .. } => {
|
||||
if !dims.is_empty() && (dims != &[0]) {
|
||||
dims.iter().product::<usize>()
|
||||
} else if dims.is_empty() {
|
||||
inner.inner.len()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user