mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
1 Commits
v22.0.4
...
ac/cap-inf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c214838242 |
@@ -18,13 +18,12 @@ use self::tensor::{create_constant_tensor, create_zero_tensor};
|
||||
use super::{chip::BaseConfig, region::RegionCtx};
|
||||
use crate::{
|
||||
circuit::{ops::base::BaseOp, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
tensor::{
|
||||
create_unit_tensor, get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
Tensor, TensorError, ValType,
|
||||
},
|
||||
fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt},
|
||||
tensor::{DataFormat, KernelFormat},
|
||||
tensor::{
|
||||
Tensor, TensorError, ValType, create_unit_tensor, get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -1751,15 +1750,15 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
|
||||
// assert than res is less than the product of the dims
|
||||
if region.witness_gen() {
|
||||
assert!(
|
||||
res.int_evals()?
|
||||
.iter()
|
||||
.all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
|
||||
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
|
||||
dims.iter().product::<usize>(),
|
||||
index_val.show(),
|
||||
index_dim_multiplier.show(),
|
||||
res.show()
|
||||
);
|
||||
res.int_evals()?
|
||||
.iter()
|
||||
.all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
|
||||
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
|
||||
dims.iter().product::<usize>(),
|
||||
index_val.show(),
|
||||
index_dim_multiplier.show(),
|
||||
res.show()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2211,12 +2210,12 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
axes: &[usize],
|
||||
// generic layout op
|
||||
op: impl Fn(
|
||||
&BaseConfig<F>,
|
||||
&mut RegionCtx<F>,
|
||||
&[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError>
|
||||
+ Send
|
||||
+ Sync,
|
||||
&BaseConfig<F>,
|
||||
&mut RegionCtx<F>,
|
||||
&[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError>
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// calculate value of output
|
||||
|
||||
@@ -4101,6 +4100,11 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
decomp: bool,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut output = values[0].clone();
|
||||
|
||||
println!("output: {:?}", output);
|
||||
println!("all prev assigned: {:?}", output.all_prev_assigned());
|
||||
println!("decomp: {:?}", decomp);
|
||||
|
||||
if !output.all_prev_assigned() {
|
||||
// checks they are in range
|
||||
if decomp {
|
||||
@@ -5595,6 +5599,7 @@ pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
Tensor::from([alpha.0; 1].into_iter()),
|
||||
*input_scale,
|
||||
&crate::graph::Visibility::Fixed,
|
||||
false,
|
||||
)?;
|
||||
|
||||
let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);
|
||||
@@ -5617,12 +5622,12 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
axes: &[usize],
|
||||
op: impl Fn(
|
||||
&BaseConfig<F>,
|
||||
&mut RegionCtx<F>,
|
||||
&[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError>
|
||||
+ Send
|
||||
+ Sync,
|
||||
&BaseConfig<F>,
|
||||
&mut RegionCtx<F>,
|
||||
&[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError>
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut input = values[0].clone();
|
||||
|
||||
|
||||
@@ -301,7 +301,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
Some(v) => v,
|
||||
None => return Err(CircuitError::UnsetVisibility),
|
||||
};
|
||||
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
|
||||
self.quantized_values =
|
||||
quantize_tensor(self.raw_values.clone(), new_scale, &visibility, true)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -317,13 +318,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -14,14 +14,14 @@ use super::VarScales;
|
||||
use super::Visibility;
|
||||
|
||||
// Import operation types for different circuit components
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::CircuitError;
|
||||
use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
|
||||
// Import graph error types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
@@ -740,7 +740,7 @@ fn rescale_const_with_single_use(
|
||||
if scale_max > ¤t_scale {
|
||||
let raw_values = constant.raw_values.clone();
|
||||
constant.quantized_values =
|
||||
super::quantize_tensor(raw_values, *scale_max, param_visibility)?;
|
||||
super::quantize_tensor(raw_values, *scale_max, param_visibility, true)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
use super::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
use super::errors::GraphError;
|
||||
use super::{Rescaled, SupportedOp, Visibility};
|
||||
use crate::circuit::Op;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -22,6 +22,7 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
Downsample,
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
|
||||
einsum::EinSum,
|
||||
element_wise::ElementWiseOp,
|
||||
nn::{LeakyRelu, Reduce, Softmax},
|
||||
Downsample,
|
||||
};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::{
|
||||
@@ -68,6 +68,33 @@ pub fn quantize_float(
|
||||
Ok(scaled)
|
||||
}
|
||||
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// NAN gets mapped to 0. And values that exceed the max are capped.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `elem` - the element to quantize.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float_and_cap(
|
||||
elem: &f64,
|
||||
shift: f64,
|
||||
scale: crate::Scale,
|
||||
) -> Result<IntegerRep, TensorError> {
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
return Ok(IntegerRep::MAX);
|
||||
} else if *elem < -max_value {
|
||||
return Ok(IntegerRep::MIN);
|
||||
}
|
||||
|
||||
// we parallelize the quantization process as it seems to be quite slow at times
|
||||
let scaled = (mult * *elem + shift).round() as IntegerRep;
|
||||
|
||||
Ok(scaled)
|
||||
}
|
||||
|
||||
/// Dequantizes a field element to a f64 using a fixed point representation.
|
||||
/// Arguments
|
||||
/// * `felt` - the field element to dequantize.
|
||||
@@ -379,7 +406,7 @@ pub fn new_op_from_onnx(
|
||||
let range = (start..end).step_by(delta).collect::<Vec<_>>();
|
||||
let raw_value = range.iter().map(|x| *x as f32).collect::<Tensor<_>>();
|
||||
// Quantize the raw value (integers)
|
||||
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
|
||||
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed, false)?;
|
||||
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
@@ -711,6 +738,7 @@ pub fn new_op_from_onnx(
|
||||
raw_value.clone(),
|
||||
constant_scale,
|
||||
&run_args.param_visibility,
|
||||
true,
|
||||
)?;
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
@@ -1550,13 +1578,20 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
const_value: Tensor<f32>,
|
||||
scale: crate::Scale,
|
||||
visibility: &Visibility,
|
||||
cap: bool,
|
||||
) -> Result<Tensor<F>, TensorError> {
|
||||
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
|
||||
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
)?))
|
||||
if cap {
|
||||
Ok::<F, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(
|
||||
quantize_float_and_cap(&(x).into(), 0.0, scale)?,
|
||||
))
|
||||
} else {
|
||||
Ok(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
)?))
|
||||
}
|
||||
})?;
|
||||
|
||||
value.set_scale(scale);
|
||||
@@ -1644,7 +1679,7 @@ pub mod tests {
|
||||
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
let scale = 0;
|
||||
let visibility = &Visibility::Public;
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility, false).unwrap();
|
||||
assert_eq!(quantized.len(), 10);
|
||||
assert_eq!(quantized, reference);
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub use var::*;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
fieldutils::{IntegerRep, integer_rep_to_felt},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -415,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
|
||||
Err(_) => {
|
||||
return Err(TensorError::FileLoadError(
|
||||
"Failed to read tensor".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1254,7 +1254,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get last element of empty tensor".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1279,7 +1279,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get first element of empty tensor".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1692,8 +1692,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
|
||||
lhs.par_iter_mut()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| {
|
||||
match T::zero() { Some(zero) => {
|
||||
.map(|(o, r)| match T::zero() {
|
||||
Some(zero) => {
|
||||
if r != zero {
|
||||
*o = o.clone() % r;
|
||||
Ok(())
|
||||
@@ -1702,11 +1702,10 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
"Cannot divide by zero in remainder".to_string(),
|
||||
))
|
||||
}
|
||||
} _ => {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
))
|
||||
}}
|
||||
}
|
||||
_ => Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
@@ -1768,7 +1767,6 @@ pub fn get_broadcasted_shape(
|
||||
}
|
||||
}
|
||||
////////////////////////
|
||||
///
|
||||
|
||||
/// The shape of data for some operations
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
|
||||
|
||||
Reference in New Issue
Block a user