Compare commits


1 Commit

Author  SHA1        Message  Date
dante   c214838242  cap inf  2025-02-28 14:16:19 -05:00
8 changed files with 165 additions and 1118 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -301,7 +301,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
self.quantized_values =
quantize_tensor(self.raw_values.clone(), new_scale, &visibility, true)?;
Ok(())
}
@@ -317,13 +318,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -14,14 +14,14 @@ use super::VarScales;
use super::Visibility;
// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::circuit::CircuitError;
use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -740,7 +740,7 @@ fn rescale_const_with_single_use(
if scale_max > &current_scale {
let raw_values = constant.raw_values.clone();
constant.quantized_values =
super::quantize_tensor(raw_values, *scale_max, param_visibility)?;
super::quantize_tensor(raw_values, *scale_max, param_visibility, true)?;
}
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,6 +22,7 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -68,6 +68,33 @@ pub fn quantize_float(
Ok(scaled)
}
/// Quantizes a single f64 to an IntegerRep using a fixed point representation.
/// NaN gets mapped to 0, and values that exceed the representable range are capped to IntegerRep::MAX / IntegerRep::MIN.
/// Arguments
///
/// * `elem` - the element to quantize.
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float_and_cap(
elem: &f64,
shift: f64,
scale: crate::Scale,
) -> Result<IntegerRep, TensorError> {
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
return Ok(IntegerRep::MAX);
} else if *elem < -max_value {
return Ok(IntegerRep::MIN);
}
// scale and shift the element, then round to the nearest integer representation
let scaled = (mult * *elem + shift).round() as IntegerRep;
Ok(scaled)
}
/// Dequantizes a field element to a f64 using a fixed point representation.
/// Arguments
/// * `felt` - the field element to dequantize.
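With shift = 0, the saturation threshold above works out to IntegerRep::MAX / 2^scale, so inputs beyond it, including the infinities that motivated this change, clamp to IntegerRep::MAX or IntegerRep::MIN instead of overflowing the cast. A minimal sketch of that behavior, assuming an arbitrary scale of 7 and that it runs inside this module where quantize_float_and_cap, scale_to_multiplier, and IntegerRep are already in scope:

// Illustrative only (not part of this commit): shows the capping behavior of
// quantize_float_and_cap with shift = 0 and an assumed scale of 7.
let scale: crate::Scale = 7;
let mult = scale_to_multiplier(scale); // 2^7 = 128.0
let max_value = (IntegerRep::MAX as f64 / mult).round();
// In-range values quantize as before.
assert_eq!(quantize_float_and_cap(&1.0, 0.0, scale).unwrap(), 128);
// Out-of-range and infinite values saturate instead of wrapping.
assert_eq!(quantize_float_and_cap(&(2.0 * max_value), 0.0, scale).unwrap(), IntegerRep::MAX);
assert_eq!(quantize_float_and_cap(&f64::INFINITY, 0.0, scale).unwrap(), IntegerRep::MAX);
assert_eq!(quantize_float_and_cap(&f64::NEG_INFINITY, 0.0, scale).unwrap(), IntegerRep::MIN);
// NaN falls past both range checks and the saturating float-to-int cast maps it to 0.
assert_eq!(quantize_float_and_cap(&f64::NAN, 0.0, scale).unwrap(), 0);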
@@ -379,7 +406,7 @@ pub fn new_op_from_onnx(
let range = (start..end).step_by(delta).collect::<Vec<_>>();
let raw_value = range.iter().map(|x| *x as f32).collect::<Tensor<_>>();
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed, false)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
@@ -711,6 +738,7 @@ pub fn new_op_from_onnx(
raw_value.clone(),
constant_scale,
&run_args.param_visibility,
true,
)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
@@ -1550,13 +1578,20 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
const_value: Tensor<f32>,
scale: crate::Scale,
visibility: &Visibility,
cap: bool,
) -> Result<Tensor<F>, TensorError> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,
)?))
if cap {
Ok::<F, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(
quantize_float_and_cap(&(x).into(), 0.0, scale)?,
))
} else {
Ok(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,
)?))
}
})?;
value.set_scale(scale);
@@ -1644,7 +1679,7 @@ pub mod tests {
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility, false).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}
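A hypothetical companion test, not part of this commit, sketching how the new flag behaves end to end: with cap = true and scale 0, infinite entries saturate to the field encodings of IntegerRep::MAX and IntegerRep::MIN rather than wrapping, while in-range entries quantize as before.

#[test]
fn quantize_tensor_caps_out_of_range_values() {
    // Illustrative only: assumes the same scope as the test above
    // (Tensor, Fp, Visibility, IntegerRep, quantize_tensor).
    let tensor: Tensor<f32> = [f32::INFINITY, f32::NEG_INFINITY, 1.0]
        .into_iter()
        .collect::<Tensor<_>>();
    let reference: Tensor<Fp> = [
        crate::fieldutils::integer_rep_to_felt::<Fp>(IntegerRep::MAX),
        crate::fieldutils::integer_rep_to_felt::<Fp>(IntegerRep::MIN),
        crate::fieldutils::integer_rep_to_felt::<Fp>(1),
    ]
    .into_iter()
    .collect::<Tensor<_>>();
    let quantized: Tensor<Fp> = quantize_tensor(tensor, 0, &Visibility::Public, true).unwrap();
    assert_eq!(quantized, reference);
}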

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::Visibility,
};
@@ -415,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
Err(_) => {
return Err(TensorError::FileLoadError(
"Failed to read tensor".to_string(),
))
));
}
}
}
@@ -926,9 +926,6 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
new_dims.iter().product::<usize>()
@@ -1107,10 +1104,6 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
} else if self.dims() == &[0] && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}
if self.dims().len() > shape.len() {
@@ -1261,7 +1254,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
));
}
};
@@ -1286,7 +1279,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get first element of empty tensor".to_string(),
))
));
}
};
@@ -1699,8 +1692,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
lhs.par_iter_mut()
.zip(rhs)
.map(|(o, r)| {
match T::zero() { Some(zero) => {
.map(|(o, r)| match T::zero() {
Some(zero) => {
if r != zero {
*o = o.clone() % r;
Ok(())
@@ -1709,11 +1702,10 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
"Cannot divide by zero in remainder".to_string(),
))
}
} _ => {
Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
))
}}
}
_ => Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
)),
})
.collect::<Result<Vec<_>, _>>()?;
@@ -1775,7 +1767,6 @@ pub fn get_broadcasted_shape(
}
}
////////////////////////
///
/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]

View File

@@ -1342,11 +1342,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Gets the total number of elements in the tensor
pub fn len(&self) -> usize {
match self {
ValTensor::Value { dims, inner, .. } => {
ValTensor::Value { dims, .. } => {
if !dims.is_empty() && (dims != &[0]) {
dims.iter().product::<usize>()
} else if dims.is_empty() {
inner.inner.len()
} else {
0
}