Compare commits


12 Commits

Author SHA1 Message Date
dante c2d0bbb60f bump to 1d_conv 2025-02-09 14:16:34 -05:00
dante 27465a89d6 move to large test 2025-02-09 09:30:26 -05:00
dante a8cccee4c8 Update binding_tests.py 2025-02-08 20:38:28 -05:00
dante 6a8358d471 bump tests 2025-02-08 20:08:19 -05:00
dante 07d192a630 Merge branch 'main' into ac/layoutytpe 2025-02-08 20:06:59 -05:00
dante 70b49538c1 Update main.rs 2025-02-08 19:29:49 -05:00
dante 6c038f2623 Update integration_tests.rs 2025-02-08 18:48:15 -05:00
dante 3cd9d2ad80 patch conv 2025-02-08 17:14:22 -05:00
dante 8f0c460f6a patch 2025-02-08 17:08:09 -05:00
dante 03d25cdbcb generalized layout and conv nd 2025-02-08 16:50:05 -05:00
dante 27ed57ad83 moar 2025-02-07 14:30:36 -05:00
dante e320e07711 dataformat 2025-02-07 14:30:23 -05:00
17 changed files with 673 additions and 179 deletions

View File

@@ -276,6 +276,8 @@ jobs:
locked: true
# - name: The Worm Mock
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: Large 1D Conv Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
- name: MNIST Gan Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock

Cargo.lock generated
View File

@@ -1932,7 +1932,7 @@ dependencies = [
]
[[package]]
name = "ezkl-gpu"
name = "ezkl"
version = "0.0.0"
dependencies = [
"alloy",

View File

@@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();

View File

@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -208,6 +209,8 @@ where
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
};
let x = config
.layer_config

View File

@@ -0,0 +1,106 @@
{
"input_data": [
[
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127,
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127
]
]
}

Binary file not shown.

View File

@@ -3,7 +3,7 @@ use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -57,11 +57,13 @@ pub enum HybridOp {
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
data_format: DataFormat,
},
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
data_format: DataFormat,
},
ReduceMin {
axes: Vec<usize>,
@@ -154,10 +156,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
kernel_shape,
normalized,
normalized, data_format
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
padding, stride, kernel_shape, normalized
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -165,9 +167,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => format!(
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
padding, stride, pool_dims, data_format
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -239,6 +242,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
normalized,
data_format,
} => layouts::sumpool(
config,
region,
@@ -247,6 +251,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
*normalized,
*data_format,
)?,
HybridOp::Recip {
input_scale,
@@ -287,6 +292,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => layouts::max_pool(
config,
region,
@@ -294,6 +300,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
*data_format,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?

View File

@@ -24,6 +24,7 @@ use crate::{
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
},
tensor::{DataFormat, KernelFormat},
};
use super::*;
@@ -3225,6 +3226,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
/// use ezkl::tensor::DataFormat;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
@@ -3234,12 +3236,12 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false, DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
/// // This time with normalization
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true, DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
/// ```
@@ -3251,9 +3253,19 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
stride: &[usize],
kernel_shape: &[usize],
normalized: bool,
data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
let batch_size = values[0].dims()[0];
let image_channels = values[0].dims()[1];
let mut image = values[0].clone();
data_format.to_canonical(&mut image)?;
if data_format.has_no_batch() {
let mut dims = image.dims().to_vec();
dims.insert(0, 1);
image.reshape(&dims)?;
}
let batch_size = image.dims()[0];
let image_channels = image.dims()[1];
let kernel_len = kernel_shape.iter().product();
@@ -3278,7 +3290,16 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.map(|coord| {
let (b, i) = (coord[0], coord[1]);
let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
let output = conv(config, region, &[input, kernel.clone()], padding, stride, 1)?;
let output = conv(
config,
region,
&[input, kernel.clone()],
padding,
stride,
1,
DataFormat::default(),
KernelFormat::default(),
)?;
res.push(output);
Ok(())
})
@@ -3293,6 +3314,9 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if normalized {
last_elem = div(config, region, &[last_elem], F::from(kernel_len as u64))?;
}
data_format.from_canonical(&mut last_elem)?;
Ok(last_elem)
}
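Sum-pooling as laid out above amounts to convolving each (batch, channel) slice with an all-ones kernel and then optionally dividing by the window size. A minimal 1D sketch with plain integers (the circuit itself works on ValTensor values and uses a rounding division rather than plain integer division):

fn sumpool1d(x: &[i64], window: usize, normalized: bool) -> Vec<i64> {
    // sum each sliding window; this is conv with an all-ones kernel
    x.windows(window)
        .map(|w| {
            let s: i64 = w.iter().sum();
            if normalized { s / window as i64 } else { s }
        })
        .collect()
}

fn main() {
    assert_eq!(sumpool1d(&[5, 2, 3, 0], 2, false), vec![7, 5, 3]);
    // integer division here; the circuit's normalization uses a rounding div
    assert_eq!(sumpool1d(&[5, 2, 3, 0], 2, true), vec![3, 2, 1]);
}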
@@ -3302,6 +3326,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::max_pool;
/// use ezkl::tensor::DataFormat;
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3316,7 +3341,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2]).unwrap();
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2], DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
@@ -3328,8 +3353,16 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
padding: &[(usize, usize)],
stride: &[usize],
pool_dims: &[usize],
data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
let image = values[0].clone();
let mut image = values[0].clone();
data_format.to_canonical(&mut image)?;
if data_format.has_no_batch() {
let mut dims = image.dims().to_vec();
dims.insert(0, 1);
image.reshape(&dims)?;
}
let image_dims = image.dims();
@@ -3388,38 +3421,38 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.apply_in_loop(&mut output, inner_loop_function)?;
let res: ValTensor<F> = output.into();
let mut res: ValTensor<F> = output.into();
data_format.from_canonical(&mut res)?;
Ok(res)
}
/// Performs a deconvolution on the given input tensor.
/// # Examples
/// ```
// // expected outputs are taken from pytorch torch.nn.functional.conv_transpose2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::deconv;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Original test case 1: Channel expansion
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 2: Basic deconvolution
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3428,11 +3461,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 3: With padding
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3441,11 +3474,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 4: With stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3454,10 +3487,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 5: Zero padding with stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3466,10 +3500,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 6: Different kernel shape
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3478,10 +3513,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 7: Different kernel shape without padding
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3490,20 +3526,21 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 8: Channel expansion with stride
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 9: With bias
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[3, 8, 0, 8, 4, 9, 8, 1, 8]),
/// &[1, 1, 3, 3],
@@ -3516,11 +3553,89 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[1]),
/// &[1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 1: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 2, 2, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 5, 3]),
/// &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[27]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 2: 1D deconvolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3]),
/// &[1, 1, 3], // NCH format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2]),
/// &[1, 1, 2], // OIH format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![0], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 7, 6]), &[1, 1, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 3: 3D deconvolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4]),
/// &[1, 1, 2, 2, 1], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1]),
/// &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![0; 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4]), &[1, 1, 2, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 4: Multi-channel with NHWC format and OHWI kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1, 3, 2, 1, 4]), // 2 channels, 2x2 spatial
/// &[1, 2, 2, 2], // NHWC format [batch, height, width, channels]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[1, 2, 2, 2], // OHWI format [out_channels, height, width, in_channels]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 24, 4, 41, 78, 27, 27, 66, 39]), &[1, 3, 3, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 5: CHW format (no batch dimension)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 2, 2], // CHW format [channels, height, width]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4]),
/// &[1, 1, 2, 2], // OIHW format [out_channels, in_channels, height, width]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::CHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 6: HWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 3, 4, 1]),
/// &[2, 2, 1], // HWC format [height, width, channels]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 1, 2]),
/// &[2, 2, 1, 1], // HWIO format [height, width, in_channels, out_channels]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::HWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 3, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
pub fn deconv<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
@@ -3531,9 +3646,14 @@ pub fn deconv<
output_padding: &[usize],
stride: &[usize],
num_groups: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = inputs.len() == 3;
let (image, kernel) = (&inputs[0], &inputs[1]);
let (mut working_image, mut working_kernel) = (inputs[0].clone(), inputs[1].clone());
data_format.to_canonical(&mut working_image)?;
kernel_format.to_canonical(&mut working_kernel)?;
if stride.iter().any(|&s| s == 0) {
return Err(TensorError::DimMismatch(
@@ -3543,26 +3663,23 @@ pub fn deconv<
}
let null_val = ValType::Constant(F::ZERO);
let mut expanded_image = working_image.clone();
let mut expanded_image = image.clone();
// Expand image by inserting zeros according to stride
for (i, s) in stride.iter().enumerate() {
expanded_image.intercalate_values(null_val.clone(), *s, 2 + i)?;
}
// Pad to kernel size for each spatial dimension
expanded_image.pad(
kernel.dims()[2..]
working_kernel.dims()[2..]
.iter()
.map(|d| (d - 1, d - 1))
.collect::<Vec<_>>(),
2,
)?; // pad to the kernel size
// flip order
let channel_coord = (0..kernel.dims()[0])
.cartesian_product(0..kernel.dims()[1])
.collect::<Vec<_>>();
)?;
// Calculate slice coordinates considering padding and output padding
let slice_coord = expanded_image
.dims()
.iter()
@@ -3578,26 +3695,34 @@ pub fn deconv<
let sliced_expanded_image = expanded_image.get_slice(&slice_coord)?;
let mut inverted_kernels = vec![];
// Generate channel coordinates for kernel transformation
let (in_ch_dim, out_ch_dim) =
KernelFormat::default().get_channel_dims(working_kernel.dims().len());
let channel_coord = (0..working_kernel.dims()[out_ch_dim])
.cartesian_product(0..working_kernel.dims()[in_ch_dim])
.collect::<Vec<_>>();
// Invert kernels for deconvolution
let mut inverted_kernels = vec![];
for (i, j) in channel_coord {
let channel = kernel.get_slice(&[i..i + 1, j..j + 1])?;
let channel = working_kernel.get_slice(&[i..i + 1, j..j + 1])?;
let mut channel = Tensor::from(channel.get_inner_tensor()?.clone().into_iter().rev());
channel.reshape(&kernel.dims()[2..])?;
channel.reshape(&working_kernel.dims()[2..])?;
inverted_kernels.push(channel);
}
let mut deconv_kernel =
Tensor::new(Some(&inverted_kernels), &[inverted_kernels.len()])?.combine()?;
deconv_kernel.reshape(kernel.dims())?;
deconv_kernel.reshape(working_kernel.dims())?;
// tensorflow formatting patch
if kernel.dims()[0] == sliced_expanded_image.dims()[1] {
// Handle tensorflow-style input/output channel ordering
if working_kernel.dims()[0] == sliced_expanded_image.dims()[1] {
let mut dims = deconv_kernel.dims().to_vec();
dims.swap(0, 1);
deconv_kernel.reshape(&dims)?;
}
// Prepare inputs for convolution
let conv_input = if has_bias {
vec![
sliced_expanded_image,
@@ -3608,28 +3733,32 @@ pub fn deconv<
vec![sliced_expanded_image, deconv_kernel.clone().into()]
};
let conv_dim = kernel.dims()[2..].len();
let conv_dim = working_kernel.dims()[2..].len();
let output = conv(
// Perform convolution with canonical formats
let mut output = conv(
config,
region,
&conv_input,
&vec![(0, 0); conv_dim],
&vec![1; conv_dim],
num_groups,
data_format.canonical(), // Use canonical format
kernel_format.canonical(), // Use canonical format
)?;
// Convert output back to requested format
data_format.from_canonical(&mut output)?;
Ok(output)
}
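For intuition, a minimal standalone sketch of the deconv-as-conv rewrite used above: zero-interleave the input by the stride, pad each side by kernel_len - 1, reverse the kernel, then run an ordinary valid convolution. It uses plain Vec<i64> rather than ezkl's ValTensor, and the helper names are illustrative only:

fn conv1d_valid(x: &[i64], k: &[i64]) -> Vec<i64> {
    (0..=x.len() - k.len())
        .map(|i| k.iter().enumerate().map(|(j, kv)| kv * x[i + j]).sum())
        .collect()
}

fn deconv1d(x: &[i64], k: &[i64], stride: usize) -> Vec<i64> {
    // 1. zero-interleave the input according to the stride
    let mut expanded = Vec::new();
    for (i, v) in x.iter().enumerate() {
        expanded.push(*v);
        if i + 1 < x.len() {
            expanded.extend(std::iter::repeat(0).take(stride.saturating_sub(1)));
        }
    }
    // 2. pad both sides up to the kernel size
    let pad = k.len() - 1;
    let mut padded = vec![0; pad];
    padded.extend(expanded);
    padded.extend(vec![0; pad]);
    // 3. flip the kernel and run a plain valid convolution
    let flipped: Vec<i64> = k.iter().rev().copied().collect();
    conv1d_valid(&padded, &flipped)
}

fn main() {
    // agrees with a transposed 1D convolution of input [2, 4] with kernel [3, 1], stride 1
    assert_eq!(deconv1d(&[2, 4], &[3, 1], 1), vec![6, 14, 4]);
}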
/// Applies convolution over a ND tensor of shape C x H x D1...DN (and adds a bias).
/// ```
/// // expected outputs are taken from pytorch torch.nn.functional.conv2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::conv;
/// use ezkl::tensor::val::ValTensor;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3638,6 +3767,7 @@ pub fn deconv<
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Test case 1: Basic 2D convolution with NCHW format (default)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
@@ -3650,44 +3780,64 @@ pub fn deconv<
/// Some(&[0]),
/// &[1],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Now test single channel
/// // Test case 2: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 2, 3, 3],
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 1, 1, 1, 5, 2, 1, 1]),
/// &[2, 1, 2, 2],
/// Some(&[1, 1, 5, 1]),
/// &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[11, 24, 20, 14]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 3: Multi-channel NHWC with OHWI kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3, 2], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 1, 1, 2, 5, 2, 1, 2]),
/// &[1, 2, 2, 2], // OHWI format
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1]),
/// &[2],
/// ).unwrap());
///
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[64, 66, 46, 58]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Now test multi channel
/// // Test case 4: 1D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 2, 3, 3],
/// Some(&[1, 2, 3, 4, 5]),
/// &[1, 1, 5], // NCHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]),
/// &[4, 2, 2, 2],
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1]),
/// &[4],
/// Some(&[1, 2, 3]),
/// &[1, 1, 3], // OIHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
/// // Test case 5: 3D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[1, 1, 2, 2, 2], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1]),
/// &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 7, 11, 15]), &[1, 1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
@@ -3700,9 +3850,14 @@ pub fn conv<
padding: &[(usize, usize)],
stride: &[usize],
num_groups: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = values.len() == 3;
let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
let (mut working_image, mut working_kernel) = (values[0].clone(), values[1].clone());
data_format.to_canonical(&mut working_image)?;
kernel_format.to_canonical(&mut working_kernel)?;
if stride.iter().any(|&s| s == 0) {
return Err(TensorError::DimMismatch(
@@ -3711,47 +3866,40 @@ pub fn conv<
.into());
}
// we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them
// 1. assign the kernel
// Assign tensors
let mut assigned_len = vec![];
if !kernel.all_prev_assigned() {
kernel = region.assign(&config.custom_gates.inputs[0], &kernel)?;
assigned_len.push(kernel.len());
if !working_kernel.all_prev_assigned() {
working_kernel = region.assign(&config.custom_gates.inputs[0], &working_kernel)?;
assigned_len.push(working_kernel.len());
}
// 2. assign the image
if !image.all_prev_assigned() {
image = region.assign(&config.custom_gates.inputs[1], &image)?;
assigned_len.push(image.len());
if !working_image.all_prev_assigned() {
working_image = region.assign(&config.custom_gates.inputs[1], &working_image)?;
assigned_len.push(working_image.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
// if image is 3d add a dummy batch dimension
if image.dims().len() == kernel.dims().len() - 1 {
image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
if data_format.has_no_batch() {
let mut dim = working_image.dims().to_vec();
dim.insert(0, 1);
working_image.reshape(&dim)?;
}
let image_dims = image.dims();
let kernel_dims = kernel.dims();
let image_dims = working_image.dims();
let kernel_dims = working_kernel.dims();
let mut padded_image = image.clone();
// Apply padding
let mut padded_image = working_image.clone();
padded_image.pad(padding.to_vec(), 2)?;
// Extract dimensions
let batch_size = image_dims[0];
let input_channels = image_dims[1];
let output_channels = kernel_dims[0];
log::debug!(
"batch_size: {}, output_channels: {}, input_channels: {}",
batch_size,
output_channels,
input_channels
);
// Calculate slides for each spatial dimension
let slides = image_dims[2..]
.iter()
.enumerate()
@@ -3766,8 +3914,6 @@ pub fn conv<
})
.collect::<Result<Vec<_>, TensorError>>()?;
log::debug!("slides: {:?}", slides);
let input_channels_per_group = input_channels / num_groups;
let output_channels_per_group = output_channels / num_groups;
@@ -3775,24 +3921,15 @@ pub fn conv<
return Err(TensorError::DimMismatch(format!(
"Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
num_groups, input_channels, output_channels
))
.into());
)).into());
}
log::debug!(
"num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
num_groups,
input_channels_per_group,
output_channels_per_group
);
let num_outputs =
batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();
log::debug!("num_outputs: {}", num_outputs);
let mut output: Tensor<ValType<F>> = Tensor::new(None, &[num_outputs])?;
// Create iteration space
let mut iterations = vec![0..batch_size, 0..num_groups, 0..output_channels_per_group];
for slide in slides.iter() {
iterations.push(0..*slide);
@@ -3804,6 +3941,13 @@ pub fn conv<
.multi_cartesian_product()
.collect::<Vec<_>>();
let batch_offset = if data_format.has_no_batch() {
2 // No batch dimension, start coordinates after channels
} else {
3 // Has batch dimension, start coordinates after batch and channels
};
// Main convolution loop
let inner_loop_function = |idx: usize, region: &mut RegionCtx<F>| {
let cartesian_coord_per_group = &cartesian_coord[idx];
let (batch, group, i) = (
@@ -3817,22 +3961,19 @@ pub fn conv<
let mut slices = vec![batch..batch + 1, start_channel..end_channel];
for (i, stride) in stride.iter().enumerate() {
let coord = cartesian_coord_per_group[3 + i] * stride;
let coord = cartesian_coord_per_group[batch_offset + i] * stride;
let kernel_dim = kernel_dims[2 + i];
slices.push(coord..(coord + kernel_dim));
}
let mut local_image = padded_image.get_slice(&slices)?;
local_image.flatten();
let start_kernel_index = group * output_channels_per_group + i;
let end_kernel_index = start_kernel_index + 1;
let mut local_kernel = kernel.get_slice(&[start_kernel_index..end_kernel_index])?;
let mut local_kernel = working_kernel.get_slice(&[start_kernel_index..end_kernel_index])?;
local_kernel.flatten();
// this is dot product notation in einsum format
let mut res = einsum(config, region, &[local_image, local_kernel], "i,i->")?;
if has_bias {
@@ -3853,21 +3994,16 @@ pub fn conv<
region.flush()?;
region.apply_in_loop(&mut output, inner_loop_function)?;
let reshape_output = |output: &mut Tensor<ValType<F>>| -> Result<(), TensorError> {
// remove dummy batch dimension if we added one
let mut dims = vec![batch_size, output_channels];
dims.extend(slides.iter().cloned());
output.reshape(&dims)?;
// Reshape output
let mut dims = vec![batch_size, output_channels];
dims.extend(slides.iter().cloned());
output.reshape(&dims)?;
Ok(())
};
// Convert output back to requested format
let mut final_output: ValTensor<F> = output.into();
data_format.from_canonical(&mut final_output)?;
// remove dummy batch dimension if we added one
reshape_output(&mut output)?;
let output: ValTensor<_> = output.into();
Ok(output)
Ok(final_output)
}
/// Power accumulated layout

View File

@@ -4,6 +4,7 @@ use crate::{
utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
tensor::{DataFormat, KernelFormat},
};
use super::{base::BaseOp, *};
@@ -43,6 +44,8 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Downsample {
axis: usize,
@@ -54,6 +57,8 @@ pub enum PolyOp {
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Add,
Sub,
@@ -165,10 +170,12 @@ impl<
stride,
padding,
group,
data_format,
kernel_format,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
"CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, group, data_format, kernel_format
)
}
PolyOp::DeConv {
@@ -176,11 +183,12 @@ impl<
padding,
output_padding,
group,
data_format,
kernel_format,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
)
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, output_padding, group, data_format, kernel_format)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
@@ -242,6 +250,8 @@ impl<
padding,
stride,
group,
data_format,
kernel_format,
} => layouts::conv(
config,
region,
@@ -249,6 +259,8 @@ impl<
padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -309,6 +321,8 @@ impl<
output_padding,
stride,
group,
data_format,
kernel_format,
} => layouts::deconv(
config,
region,
@@ -317,6 +331,8 @@ impl<
output_padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,

View File

@@ -1,5 +1,6 @@
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -1065,6 +1066,8 @@ mod conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1220,6 +1223,8 @@ mod conv_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1377,6 +1382,8 @@ mod conv_relu_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -33,7 +33,7 @@ pub enum GraphError {
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node is has misformed params: {0}")]
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]

View File

@@ -609,8 +609,12 @@ impl GraphData {
if input.len() % input_size != 0 {
return Err(GraphError::InvalidDims(
0,
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
format!(
"calibration data length (={}) must be evenly divisible by the original input_size(={})",
input.len(),
input_size
),
));
}

View File

@@ -39,9 +39,8 @@ use tract_onnx::tract_hir::{
ops::array::{Pad, PadMode, TypedConcat},
ops::cnn::PoolSpec,
ops::konst::Const,
ops::nn::DataFormat,
tract_core::ops::cast::Cast,
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
tract_core::ops::cnn::{MaxPool, SumPool},
};
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
@@ -1146,13 +1145,6 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
let kernel_shape = &pool_spec.kernel_shape;
@@ -1161,6 +1153,7 @@ pub fn new_op_from_onnx(
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
data_format: pool_spec.data_format.into(),
})
}
"Ceil" => {
@@ -1314,15 +1307,6 @@ pub fn new_op_from_onnx(
}
}
if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &conv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1350,6 +1334,8 @@ pub fn new_op_from_onnx(
padding,
stride,
group,
data_format: conv_node.pool_spec.data_format.into(),
kernel_format: conv_node.kernel_fmt.into(),
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1373,14 +1359,6 @@ pub fn new_op_from_onnx(
}
}
if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &deconv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1406,6 +1384,8 @@ pub fn new_op_from_onnx(
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
data_format: deconv_node.pool_spec.data_format.into(),
kernel_format: deconv_node.kernel_format.into(),
})
}
"Downsample" => {
@@ -1489,13 +1469,6 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
@@ -1504,6 +1477,7 @@ pub fn new_op_from_onnx(
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
data_format: pool_spec.data_format.into(),
})
}
"Pad" => {

View File

@@ -1,6 +1,6 @@
use thiserror::Error;
use super::ops::DecompositionError;
use super::{ops::DecompositionError, DataFormat};
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -44,4 +44,7 @@ pub enum TensorError {
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
/// Invalid data conversion
#[error("invalid data conversion from format {0} to {1}")]
InvalidDataConversion(DataFormat, DataFormat),
}

View File

@@ -9,6 +9,7 @@ pub mod var;
pub use errors::TensorError;
use core::hash::Hash;
use halo2curves::ff::PrimeField;
use maybe_rayon::{
prelude::{
@@ -1767,6 +1768,229 @@ pub fn get_broadcasted_shape(
}
}
////////////////////////
///
/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum DataFormat {
/// NCHW
#[default]
NCHW,
/// NHWC
NHWC,
/// CHW
CHW,
/// HWC
HWC,
}
// as str
impl core::fmt::Display for DataFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DataFormat::NCHW => write!(f, "NCHW"),
DataFormat::NHWC => write!(f, "NHWC"),
DataFormat::CHW => write!(f, "CHW"),
DataFormat::HWC => write!(f, "HWC"),
}
}
}
impl DataFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
}
}
/// no batch dim
pub fn has_no_batch(&self) -> bool {
match self {
DataFormat::CHW | DataFormat::HWC => true,
_ => false,
}
}
/// Convert tensor to canonical format (NCHW or CHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// For ND: Move channels from last axis to position after batch
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(ndims - 1, 1)?;
}
}
DataFormat::HWC => {
// For ND: Move channels from last axis to first position
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(ndims - 1, 0)?;
}
}
_ => {} // NCHW/CHW are already in canonical format
}
Ok(())
}
/// Convert tensor from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// Move channels from position 1 to end
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(1, ndims - 1)?;
}
}
DataFormat::HWC => {
// Move channels from position 0 to end
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(0, ndims - 1)?;
}
}
_ => {} // NCHW/CHW don't need conversion
}
Ok(())
}
/// Get the position of the channel dimension
pub fn get_channel_dim(&self, ndims: usize) -> usize {
match self {
DataFormat::NCHW => 1,
DataFormat::NHWC => ndims - 1,
DataFormat::CHW => 0,
DataFormat::HWC => ndims - 1,
}
}
}
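A shape-level sketch of what to_canonical amounts to for the data formats (plain shape vectors and hypothetical helper names; the real methods move tensor axes with move_axis):

// NHWC -> NCHW: channel axis moves from last position to just after the batch
fn nhwc_to_nchw(dims: &[usize]) -> Vec<usize> {
    let mut out = dims.to_vec();
    let c = out.pop().unwrap();
    out.insert(1, c);
    out
}

// HWC -> CHW: channel axis moves from last position to the front (no batch dim)
fn hwc_to_chw(dims: &[usize]) -> Vec<usize> {
    let mut out = dims.to_vec();
    let c = out.pop().unwrap();
    out.insert(0, c);
    out
}

fn main() {
    assert_eq!(nhwc_to_nchw(&[1, 3, 3, 2]), vec![1, 2, 3, 3]);
    assert_eq!(hwc_to_chw(&[2, 2, 1]), vec![1, 2, 2]);
}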
/// The shape of the kernel for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum KernelFormat {
/// HWIO
HWIO,
/// OIHW
#[default]
OIHW,
/// OHWI
OHWI,
}
impl core::fmt::Display for KernelFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KernelFormat::HWIO => write!(f, "HWIO"),
KernelFormat::OIHW => write!(f, "OIHW"),
KernelFormat::OHWI => write!(f, "OHWI"),
}
}
}
impl KernelFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
}
}
/// Convert kernel to canonical format (OIHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move output channels from last to first
kernel.move_axis(kdims - 1, 0)?;
// Move input channels from new last to second position
kernel.move_axis(kdims - 1, 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from last to second position
kernel.move_axis(kdims - 1, 1)?;
}
_ => {} // OIHW is already canonical
}
Ok(())
}
/// Convert kernel from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
// Move output channels from first to last
kernel.move_axis(0, kdims - 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
}
_ => {} // OIHW doesn't need conversion
}
Ok(())
}
/// Get the position of input and output channel dimensions
pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
// (input_ch, output_ch)
match self {
KernelFormat::OIHW => (1, 0),
KernelFormat::HWIO => (ndims - 2, ndims - 1),
KernelFormat::OHWI => (ndims - 1, 0),
}
}
}
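And the corresponding shape-level sketch for the kernel formats (again hypothetical helpers over plain shape vectors):

// HWIO -> OIHW: output channels move from last to first, then input channels
// move from the (new) last position to second
fn hwio_to_oihw(dims: &[usize]) -> Vec<usize> {
    let mut out = dims.to_vec();
    let o = out.pop().unwrap();
    out.insert(0, o);
    let i = out.pop().unwrap();
    out.insert(1, i);
    out
}

// OHWI -> OIHW: input channels move from last to second
fn ohwi_to_oihw(dims: &[usize]) -> Vec<usize> {
    let mut out = dims.to_vec();
    let i = out.pop().unwrap();
    out.insert(1, i);
    out
}

fn main() {
    assert_eq!(hwio_to_oihw(&[2, 2, 1, 4]), vec![4, 1, 2, 2]);
    assert_eq!(ohwi_to_oihw(&[4, 2, 2, 1]), vec![4, 1, 2, 2]);
}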
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
match fmt {
tract_onnx::tract_hir::ops::nn::DataFormat::NCHW => DataFormat::NCHW,
tract_onnx::tract_hir::ops::nn::DataFormat::NHWC => DataFormat::NHWC,
tract_onnx::tract_hir::ops::nn::DataFormat::CHW => DataFormat::CHW,
tract_onnx::tract_hir::ops::nn::DataFormat::HWC => DataFormat::HWC,
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
match fmt {
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::HWIO => {
KernelFormat::HWIO
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OIHW => {
KernelFormat::OIHW
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OHWI => {
KernelFormat::OHWI
}
}
}
}
#[cfg(test)]
mod tests {

View File

@@ -187,7 +187,7 @@ mod native_tests {
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
const LARGE_TESTS: [&str; 7] = [
const LARGE_TESTS: [&str; 8] = [
"self_attention",
"nanoGPT",
"multihead_attention",
@@ -195,6 +195,7 @@ mod native_tests {
"mnist_gan",
"smallworm",
"fr_age",
"1d_conv",
];
const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -965,7 +966,7 @@ mod native_tests {
});
seq!(N in 0..=6 {
seq!(N in 0..=7 {
#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -1633,13 +1634,17 @@ mod native_tests {
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
format!("--variables=batch_size->{}", batch_size),
format!(
"--variables=batch_size->{},sequence_length->100,<Sym1>->1",
batch_size
),
format!("--input-visibility={}", input_visibility),
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
format!("--num-inner-cols={}", num_inner_columns),
format!("--tolerance={}", tolerance),
format!("--commitment={}", commitment),
format!("--logrows={}", 22),
];
// if output-visibility is fixed set --range-check-inputs-outputs to False
@@ -1725,7 +1730,6 @@ mod native_tests {
test_dir, example_name
),
])
.stdout(std::process::Stdio::null())
.status()
.expect("failed to execute process");
assert!(status.success());

View File

@@ -873,7 +873,8 @@ def get_examples():
'linear_regression',
"mnist_gan",
"smallworm",
"fr_age"
"fr_age",
"1d_conv",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -900,7 +901,12 @@ async def test_all_examples(model_file, input_file):
proof_path = os.path.join(folder_path, 'proof.json')
print("Testing example: ", model_file)
res = ezkl.gen_settings(model_file, settings_path)
run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
run_args.logrows = 22
res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
assert res
res = await ezkl.calibrate_settings(