mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
1 Commits
v21.0.1
...
ac/negativ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca354f9261 |
@@ -4914,7 +4914,7 @@ pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
axis: &usize,
|
||||
stride: &usize,
|
||||
stride: &isize,
|
||||
modulo: &usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input = region.assign(&config.custom_gates.inputs[0], &values[0])?;
|
||||
|
||||
@@ -49,7 +49,7 @@ pub enum PolyOp {
|
||||
},
|
||||
Downsample {
|
||||
axis: usize,
|
||||
stride: usize,
|
||||
stride: isize,
|
||||
modulo: usize,
|
||||
},
|
||||
DeConv {
|
||||
@@ -108,13 +108,8 @@ pub enum PolyOp {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
@@ -188,7 +183,8 @@ impl<
|
||||
} => {
|
||||
format!(
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
|
||||
stride, padding, output_padding, group, data_format, kernel_format)
|
||||
stride, padding, output_padding, group, data_format, kernel_format
|
||||
)
|
||||
}
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
use super::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
use super::errors::GraphError;
|
||||
use super::{Rescaled, SupportedOp, Visibility};
|
||||
use crate::circuit::Op;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -22,6 +22,7 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
Downsample,
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
|
||||
einsum::EinSum,
|
||||
element_wise::ElementWiseOp,
|
||||
nn::{LeakyRelu, Reduce, Softmax},
|
||||
Downsample,
|
||||
};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::{
|
||||
@@ -1398,7 +1398,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
SupportedOp::Linear(PolyOp::Downsample {
|
||||
axis: downsample_node.axis,
|
||||
stride: downsample_node.stride as usize,
|
||||
stride: downsample_node.stride,
|
||||
modulo: downsample_node.modulo,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -535,30 +535,101 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
|
||||
/// let result = downsample(&x, 1, 2, 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 6]), &[2, 1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// // Test case 1: Negative stride along dimension 0
|
||||
/// // This should flip the order along dimension 0
|
||||
/// let result = downsample(&x, 0, -1, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 6, 1, 2, 3]), // Flipped order of rows
|
||||
/// &[2, 3]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Test case 2: Negative stride along dimension 1
|
||||
/// // This should flip the order along dimension 1
|
||||
/// let result = downsample(&x, 1, -1, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 2, 1, 6, 5, 4]), // Flipped order of columns
|
||||
/// &[2, 3]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Test case 3: Negative stride with stride magnitude > 1
|
||||
/// // This should both skip and flip
|
||||
/// let result = downsample(&x, 1, -2, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 1, 6, 4]), // Take every 2nd element in reverse
|
||||
/// &[2, 2]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Test case 4: Negative stride with non-zero modulo
|
||||
/// // This should start at (size - 1 - modulo) and reverse
|
||||
/// let result = downsample(&x, 1, -2, 1).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 5]), // Start at second element from end, take every 2nd in reverse
|
||||
/// &[2, 1]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Create a larger test case for more complex downsampling
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
|
||||
/// &[3, 4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// // Test case 5: Negative stride with modulo on larger tensor
|
||||
/// let result = downsample(&y, 1, -2, 1).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 1, 7, 5, 11, 9]), // Start at one after reverse, take every 2nd
|
||||
/// &[3, 2]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn downsample<T: TensorType + Send + Sync>(
|
||||
input: &Tensor<T>,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
stride: isize, // Changed from usize to isize to support negative strides
|
||||
modulo: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mut output_shape = input.dims().to_vec();
|
||||
// now downsample along axis dim offset by modulo, rounding up (+1 if remaidner is non-zero)
|
||||
let remainder = (input.dims()[dim] - modulo) % stride;
|
||||
let div = (input.dims()[dim] - modulo) / stride;
|
||||
output_shape[dim] = div + (remainder > 0) as usize;
|
||||
let mut output = Tensor::<T>::new(None, &output_shape)?;
|
||||
// Handle negative stride case
|
||||
if stride == 0 {
|
||||
return Err(TensorError::DimMismatch(
|
||||
"downsample stride cannot be zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if modulo > input.dims()[dim] {
|
||||
let stride_abs = stride.unsigned_abs();
|
||||
let mut output_shape = input.dims().to_vec();
|
||||
|
||||
if modulo >= input.dims()[dim] {
|
||||
return Err(TensorError::DimMismatch("downsample".to_string()));
|
||||
}
|
||||
|
||||
// now downsample along axis dim offset by modulo
|
||||
// Calculate output shape based on the absolute value of stride
|
||||
let remainder = (input.dims()[dim] - modulo) % stride_abs;
|
||||
let div = (input.dims()[dim] - modulo) / stride_abs;
|
||||
output_shape[dim] = div + (remainder > 0) as usize;
|
||||
|
||||
let mut output = Tensor::<T>::new(None, &output_shape)?;
|
||||
|
||||
// Calculate indices based on stride direction
|
||||
let indices = (0..output_shape.len())
|
||||
.map(|i| {
|
||||
if i == dim {
|
||||
let mut index = vec![0; output_shape[i]];
|
||||
for (i, idx) in index.iter_mut().enumerate() {
|
||||
*idx = i * stride + modulo;
|
||||
for (j, idx) in index.iter_mut().enumerate() {
|
||||
if stride > 0 {
|
||||
// Positive stride: move forward from modulo
|
||||
*idx = j * stride_abs + modulo;
|
||||
} else {
|
||||
// Negative stride: move backward from (size - 1 - modulo)
|
||||
*idx = (input.dims()[dim] - 1 - modulo) - j * stride_abs;
|
||||
}
|
||||
}
|
||||
index
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user