Compare commits

...

1 Commits

Author SHA1 Message Date
dante
c214838242 cap inf 2025-02-28 14:16:19 -05:00
5 changed files with 95 additions and 61 deletions

View File

@@ -18,13 +18,12 @@ use self::tensor::{create_constant_tensor, create_zero_tensor};
use super::{chip::BaseConfig, region::RegionCtx};
use crate::{
circuit::{ops::base::BaseOp, utils},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
tensor::{
create_unit_tensor, get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
},
fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt},
tensor::{DataFormat, KernelFormat},
tensor::{
Tensor, TensorError, ValType, create_unit_tensor, get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
},
};
use super::*;
@@ -1751,15 +1750,15 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
// assert that res is less than the product of the dims
if region.witness_gen() {
assert!(
res.int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
res.int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
}
@@ -2211,12 +2210,12 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
axes: &[usize],
// generic layout op
op: impl Fn(
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
) -> Result<ValTensor<F>, CircuitError> {
// calculate value of output
@@ -4101,6 +4100,11 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
decomp: bool,
) -> Result<ValTensor<F>, CircuitError> {
let mut output = values[0].clone();
println!("output: {:?}", output);
println!("all prev assigned: {:?}", output.all_prev_assigned());
println!("decomp: {:?}", decomp);
if !output.all_prev_assigned() {
// checks they are in range
if decomp {
@@ -5595,6 +5599,7 @@ pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
Tensor::from([alpha.0; 1].into_iter()),
*input_scale,
&crate::graph::Visibility::Fixed,
false,
)?;
let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);
@@ -5617,12 +5622,12 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
axes: &[usize],
op: impl Fn(
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
) -> Result<ValTensor<F>, CircuitError> {
let mut input = values[0].clone();

View File

@@ -301,7 +301,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
self.quantized_values =
quantize_tensor(self.raw_values.clone(), new_scale, &visibility, true)?;
Ok(())
}
@@ -317,13 +318,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -14,14 +14,14 @@ use super::VarScales;
use super::Visibility;
// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::circuit::CircuitError;
use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -740,7 +740,7 @@ fn rescale_const_with_single_use(
if scale_max > &current_scale {
let raw_values = constant.raw_values.clone();
constant.quantized_values =
super::quantize_tensor(raw_values, *scale_max, param_visibility)?;
super::quantize_tensor(raw_values, *scale_max, param_visibility, true)?;
}
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,6 +22,7 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -68,6 +68,33 @@ pub fn quantize_float(
Ok(scaled)
}
/// Quantizes a single `f64` to an [IntegerRep] using a fixed point representation,
/// saturating instead of overflowing.
///
/// Inputs above the representable bound are capped to `IntegerRep::MAX`, inputs below
/// the negated bound are capped to `IntegerRep::MIN`. NaN fails both comparisons and is
/// mapped to 0 by the float-to-integer `as` cast.
///
/// Arguments
///
/// * `elem` - the element to quantize.
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
///
/// Returns
///
/// The quantized value wrapped in `Ok`; the body shown here has no path that yields `Err`.
pub fn quantize_float_and_cap(
    elem: &f64,
    shift: f64,
    scale: crate::Scale,
) -> Result<IntegerRep, TensorError> {
    let multiplier = scale_to_multiplier(scale);
    let value = *elem;

    // Largest input magnitude that can be represented w/o sig bit truncation.
    let bound = ((IntegerRep::MAX as f64 - shift) / multiplier).round();

    let quantized = if value > bound {
        IntegerRep::MAX
    } else if value < -bound {
        IntegerRep::MIN
    } else {
        // In-range (or NaN) input: scale, offset, then round to the nearest integer.
        (multiplier * value + shift).round() as IntegerRep
    };

    Ok(quantized)
}
/// Dequantizes a field element to a f64 using a fixed point representation.
/// Arguments
/// * `felt` - the field element to dequantize.
@@ -379,7 +406,7 @@ pub fn new_op_from_onnx(
let range = (start..end).step_by(delta).collect::<Vec<_>>();
let raw_value = range.iter().map(|x| *x as f32).collect::<Tensor<_>>();
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed, false)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
@@ -711,6 +738,7 @@ pub fn new_op_from_onnx(
raw_value.clone(),
constant_scale,
&run_args.param_visibility,
true,
)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
@@ -1550,13 +1578,20 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
const_value: Tensor<f32>,
scale: crate::Scale,
visibility: &Visibility,
cap: bool,
) -> Result<Tensor<F>, TensorError> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,
)?))
if cap {
Ok::<F, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(
quantize_float_and_cap(&(x).into(), 0.0, scale)?,
))
} else {
Ok(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,
)?))
}
})?;
value.set_scale(scale);
@@ -1644,7 +1679,7 @@ pub mod tests {
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility, false).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::Visibility,
};
@@ -415,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
Err(_) => {
return Err(TensorError::FileLoadError(
"Failed to read tensor".to_string(),
))
));
}
}
}
@@ -1254,7 +1254,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
));
}
};
@@ -1279,7 +1279,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get first element of empty tensor".to_string(),
))
));
}
};
@@ -1692,8 +1692,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
lhs.par_iter_mut()
.zip(rhs)
.map(|(o, r)| {
match T::zero() { Some(zero) => {
.map(|(o, r)| match T::zero() {
Some(zero) => {
if r != zero {
*o = o.clone() % r;
Ok(())
@@ -1702,11 +1702,10 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
"Cannot divide by zero in remainder".to_string(),
))
}
} _ => {
Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
))
}}
}
_ => Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
)),
})
.collect::<Result<Vec<_>, _>>()?;
@@ -1768,7 +1767,6 @@ pub fn get_broadcasted_shape(
}
}
////////////////////////
///
/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]