Compare commits


2 Commits

Author               SHA1        Message                              Date
github-actions[bot]  2a1645cfff  ci: update version string in docs    2025-03-01 04:56:38 +00:00
dante                fcbb27677f  fix: empty dim len can be 1 (#949)   2025-02-28 23:56:19 -05:00
9 changed files with 64 additions and 95 deletions

View File

@@ -1,7 +1,7 @@
 import ezkl
 project = 'ezkl'
-release = '0.0.0'
+release = '20.2.2'
 version = release

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -18,12 +18,13 @@ use self::tensor::{create_constant_tensor, create_zero_tensor};
 use super::{chip::BaseConfig, region::RegionCtx};
 use crate::{
     circuit::{ops::base::BaseOp, utils},
-    fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt},
-    tensor::{DataFormat, KernelFormat},
+    fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
     tensor::{
-        Tensor, TensorError, ValType, create_unit_tensor, get_broadcasted_shape,
+        create_unit_tensor, get_broadcasted_shape,
         ops::{accumulated, add, mult, sub},
+        Tensor, TensorError, ValType,
     },
+    tensor::{DataFormat, KernelFormat},
 };
 use super::*;
@@ -1750,15 +1751,15 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
     // assert than res is less than the product of the dims
     if region.witness_gen() {
         assert!(
-                res.int_evals()?
-                    .iter()
-                    .all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
-                "res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
-                dims.iter().product::<usize>(),
-                index_val.show(),
-                index_dim_multiplier.show(),
-                res.show()
-            );
+            res.int_evals()?
+                .iter()
+                .all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
+            "res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
+            dims.iter().product::<usize>(),
+            index_val.show(),
+            index_dim_multiplier.show(),
+            res.show()
+        );
     }
 }
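Aside: linearize_nd_index flattens an n-d coordinate into a single row-major index, and the assertion above (whose body is only re-indented in this hunk) checks during witness generation that the result stays below the product of the dims. A minimal standalone sketch of that invariant in plain Rust, with illustrative names that are not ezkl's API:

    // Row-major linearization: the multiplier for axis i is the product of
    // the dimensions after i, so every in-range coordinate maps to an index
    // strictly less than the product of all dims (the bound asserted above).
    fn linearize(coord: &[usize], dims: &[usize]) -> usize {
        let mut index = 0;
        let mut multiplier = 1;
        for (c, d) in coord.iter().zip(dims).rev() {
            assert!(c < d, "coordinate out of range");
            index += c * multiplier;
            multiplier *= d;
        }
        index
    }

    fn main() {
        let dims = [2, 3, 4];
        let index = linearize(&[1, 2, 3], &dims);
        assert_eq!(index, 23);
        assert!(index < dims.iter().product::<usize>());
    }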
@@ -2210,12 +2211,12 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     axes: &[usize],
     // generic layout op
     op: impl Fn(
-            &BaseConfig<F>,
-            &mut RegionCtx<F>,
-            &[ValTensor<F>; 1],
-        ) -> Result<ValTensor<F>, CircuitError>
-        + Send
-        + Sync,
+        &BaseConfig<F>,
+        &mut RegionCtx<F>,
+        &[ValTensor<F>; 1],
+    ) -> Result<ValTensor<F>, CircuitError>
+    + Send
+    + Sync,
 ) -> Result<ValTensor<F>, CircuitError> {
     // calculate value of output
@@ -4100,11 +4101,6 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
     decomp: bool,
 ) -> Result<ValTensor<F>, CircuitError> {
     let mut output = values[0].clone();
-    println!("output: {:?}", output);
-    println!("all prev assigned: {:?}", output.all_prev_assigned());
-    println!("decomp: {:?}", decomp);
     if !output.all_prev_assigned() {
         // checks they are in range
         if decomp {
@@ -5599,7 +5595,6 @@ pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
         Tensor::from([alpha.0; 1].into_iter()),
         *input_scale,
         &crate::graph::Visibility::Fixed,
-        false,
     )?;
     let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);
@@ -5622,12 +5617,12 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     values: &[ValTensor<F>; 1],
     axes: &[usize],
     op: impl Fn(
-            &BaseConfig<F>,
-            &mut RegionCtx<F>,
-            &[ValTensor<F>; 1],
-        ) -> Result<ValTensor<F>, CircuitError>
-        + Send
-        + Sync,
+        &BaseConfig<F>,
+        &mut RegionCtx<F>,
+        &[ValTensor<F>; 1],
+    ) -> Result<ValTensor<F>, CircuitError>
+    + Send
+    + Sync,
 ) -> Result<ValTensor<F>, CircuitError> {
     let mut input = values[0].clone();

View File

@@ -301,8 +301,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
             Some(v) => v,
             None => return Err(CircuitError::UnsetVisibility),
         };
-        self.quantized_values =
-            quantize_tensor(self.raw_values.clone(), new_scale, &visibility, true)?;
+        self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
         Ok(())
     }
@@ -318,8 +317,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
 }
 impl<
-    F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
-> Op<F> for Constant<F>
+    F: PrimeField
+        + TensorType
+        + PartialOrd
+        + std::hash::Hash
+        + Serialize
+        + for<'de> Deserialize<'de>,
+> Op<F> for Constant<F>
 {
     fn as_any(&self) -> &dyn Any {
         self

View File

@@ -14,14 +14,14 @@ use super::VarScales;
 use super::Visibility;
 // Import operation types for different circuit components
+use crate::circuit::hybrid::HybridOp;
+use crate::circuit::lookup::LookupOp;
+use crate::circuit::poly::PolyOp;
 use crate::circuit::CircuitError;
 use crate::circuit::Constant;
 use crate::circuit::Input;
 use crate::circuit::Op;
 use crate::circuit::Unknown;
-use crate::circuit::hybrid::HybridOp;
-use crate::circuit::lookup::LookupOp;
-use crate::circuit::poly::PolyOp;
 // Import graph error types for EZKL
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -740,7 +740,7 @@ fn rescale_const_with_single_use(
     if scale_max > &current_scale {
         let raw_values = constant.raw_values.clone();
         constant.quantized_values =
-            super::quantize_tensor(raw_values, *scale_max, param_visibility, true)?;
+            super::quantize_tensor(raw_values, *scale_max, param_visibility)?;
     }
 }

View File

@@ -1,14 +1,14 @@
+use super::errors::GraphError;
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 use super::VarScales;
-use super::errors::GraphError;
 use super::{Rescaled, SupportedOp, Visibility};
-use crate::circuit::Op;
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 use crate::circuit::hybrid::HybridOp;
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 use crate::circuit::lookup::LookupOp;
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 use crate::circuit::poly::PolyOp;
+use crate::circuit::Op;
 use crate::fieldutils::IntegerRep;
 use crate::tensor::{Tensor, TensorError, TensorType};
 use halo2curves::bn256::Fr as Fp;
@@ -22,7 +22,6 @@ use std::sync::Arc;
 use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 use tract_onnx::tract_core::ops::{
-    Downsample,
     array::{
         Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
         Slice, Topk,

@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
     einsum::EinSum,
     element_wise::ElementWiseOp,
     nn::{LeakyRelu, Reduce, Softmax},
+    Downsample,
 };
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 use tract_onnx::tract_hir::{
@@ -68,33 +68,6 @@ pub fn quantize_float(
     Ok(scaled)
 }

-/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
-/// NAN gets mapped to 0. And values that exceed the max are capped.
-/// Arguments
-///
-/// * `elem` - the element to quantize.
-/// * `shift` - offset used in the fixed point representation.
-/// * `scale` - `2^scale` used in the fixed point representation.
-pub fn quantize_float_and_cap(
-    elem: &f64,
-    shift: f64,
-    scale: crate::Scale,
-) -> Result<IntegerRep, TensorError> {
-    let mult = scale_to_multiplier(scale);
-    let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
-    if *elem > max_value {
-        return Ok(IntegerRep::MAX);
-    } else if *elem < -max_value {
-        return Ok(IntegerRep::MIN);
-    }
-
-    // we parallelize the quantization process as it seems to be quite slow at times
-    let scaled = (mult * *elem + shift).round() as IntegerRep;
-    Ok(scaled)
-}
-
 /// Dequantizes a field element to a f64 using a fixed point representation.
 /// Arguments
 /// * `felt` - the field element to dequantize.
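With quantize_float_and_cap removed, the later hunks in this file drop the cap argument from quantize_tensor, so quantization always goes through quantize_float; the saturating fallback to IntegerRep::MAX/MIN shown above is gone. For reference, a standalone sketch of the fixed-point scheme both paths shared, with i128 standing in for IntegerRep (not ezkl's exact signature):

    // Fixed-point quantization, as in the removed function's core line:
    // q = round(x * 2^scale + shift)
    fn quantize(elem: f64, shift: f64, scale: u32) -> i128 {
        let mult = (1u64 << scale) as f64; // plays the role of scale_to_multiplier(scale)
        (mult * elem + shift).round() as i128
    }

    fn main() {
        // scale = 7 gives a multiplier of 2^7 = 128.
        assert_eq!(quantize(0.5, 0.0, 7), 64);
        assert_eq!(quantize(-1.25, 0.0, 7), -160);
        // The deleted variant additionally clamped any input whose scaled
        // magnitude exceeded IntegerRep::MAX instead of passing it through.
    }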
@@ -406,7 +379,7 @@ pub fn new_op_from_onnx(
             let range = (start..end).step_by(delta).collect::<Vec<_>>();
             let raw_value = range.iter().map(|x| *x as f32).collect::<Tensor<_>>();
             // Quantize the raw value (integers)
-            let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed, false)?;
+            let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
             let c = crate::circuit::ops::Constant::new(
                 quantized_value,
@@ -738,7 +711,6 @@ pub fn new_op_from_onnx(
                 raw_value.clone(),
                 constant_scale,
                 &run_args.param_visibility,
-                true,
             )?;
             let c = crate::circuit::ops::Constant::new(
                 quantized_value,
@@ -1578,20 +1550,13 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
     const_value: Tensor<f32>,
     scale: crate::Scale,
     visibility: &Visibility,
-    cap: bool,
 ) -> Result<Tensor<F>, TensorError> {
     let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
-        if cap {
-            Ok::<F, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(
-                quantize_float_and_cap(&(x).into(), 0.0, scale)?,
-            ))
-        } else {
-            Ok(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
-                &(x).into(),
-                0.0,
-                scale,
-            )?))
-        }
+        Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
+            &(x).into(),
+            0.0,
+            scale,
+        )?))
     })?;
     value.set_scale(scale);
@@ -1679,7 +1644,7 @@ pub mod tests {
         let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
         let scale = 0;
         let visibility = &Visibility::Public;
-        let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility, false).unwrap();
+        let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
         assert_eq!(quantized.len(), 10);
         assert_eq!(quantized, reference);
     }

View File

@@ -27,7 +27,7 @@ pub use var::*;
 use crate::{
     circuit::utils,
-    fieldutils::{IntegerRep, integer_rep_to_felt},
+    fieldutils::{integer_rep_to_felt, IntegerRep},
     graph::Visibility,
 };
@@ -415,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
             Err(_) => {
                 return Err(TensorError::FileLoadError(
                     "Failed to read tensor".to_string(),
-                ));
+                ))
             }
         }
     }
@@ -1254,7 +1254,7 @@ impl<T: Clone + TensorType> Tensor<T> {
             None => {
                 return Err(TensorError::DimError(
                     "Cannot get last element of empty tensor".to_string(),
-                ));
+                ))
             }
         };
@@ -1279,7 +1279,7 @@ impl<T: Clone + TensorType> Tensor<T> {
             None => {
                 return Err(TensorError::DimError(
                     "Cannot get first element of empty tensor".to_string(),
-                ));
+                ))
             }
         };
@@ -1692,8 +1692,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
         lhs.par_iter_mut()
             .zip(rhs)
-            .map(|(o, r)| match T::zero() {
-                Some(zero) => {
+            .map(|(o, r)| {
+                match T::zero() { Some(zero) => {
                     if r != zero {
                         *o = o.clone() % r;
                         Ok(())

@@ -1702,10 +1702,11 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
                             "Cannot divide by zero in remainder".to_string(),
                         ))
                     }
-                }
-                _ => Err(TensorError::InvalidArgument(
-                    "Undefined zero value".to_string(),
-                )),
+                } _ => {
+                    Err(TensorError::InvalidArgument(
+                        "Undefined zero value".to_string(),
+                    ))
+                }}
             })
             .collect::<Result<Vec<_>, _>>()?;
@@ -1767,6 +1768,7 @@ pub fn get_broadcasted_shape(
     }
 }
 ////////////////////////
+///
 /// The shape of data for some operations
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]

View File

@@ -1342,9 +1342,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
     /// Gets the total number of elements in the tensor
     pub fn len(&self) -> usize {
         match self {
-            ValTensor::Value { dims, .. } => {
+            ValTensor::Value { dims, inner, .. } => {
                 if !dims.is_empty() && (dims != &[0]) {
                     dims.iter().product::<usize>()
+                } else if dims.is_empty() {
+                    inner.inner.len()
                 } else {
                     0
                 }
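This final hunk is the behavioral fix named in the PR title: a ValTensor::Value with empty dims is a 0-d scalar, and len() previously fell through to 0 for it. Binding inner in the match lets the empty-dims case report the inner tensor's actual length (1 for a scalar), while dims == [0] still reports 0. A standalone mimic of the before/after behavior, with simplified stand-in types rather than ezkl's:

    // Minimal stand-in for ValTensor::Value { dims, inner, .. }.
    struct Value {
        dims: Vec<usize>,
        inner: Vec<i64>, // plays the role of the inner tensor storage
    }

    impl Value {
        // Old behavior: empty dims fell through to 0.
        fn len_old(&self) -> usize {
            if !self.dims.is_empty() && self.dims != [0] {
                self.dims.iter().product()
            } else {
                0
            }
        }

        // New behavior: empty dims defer to the inner storage length.
        fn len_new(&self) -> usize {
            if !self.dims.is_empty() && self.dims != [0] {
                self.dims.iter().product()
            } else if self.dims.is_empty() {
                self.inner.len()
            } else {
                0 // dims == [0]: a genuinely empty tensor
            }
        }
    }

    fn main() {
        let scalar = Value { dims: vec![], inner: vec![42] };
        assert_eq!(scalar.len_old(), 0); // the bug: a scalar reported length 0
        assert_eq!(scalar.len_new(), 1); // the fix: empty-dims len can be 1
    }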