Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 00:08:12 -05:00)

Compare commits: 4 commits, release-v2...ac/panic-o
| Author | SHA1 | Date |
|---|---|---|
| | be5d241b42 | |
| | ae076aef09 | |
| | a7544f4060 | |
| | c19fa5218a | |
4 .github/workflows/rust.yml (vendored)
@@ -276,6 +276,8 @@ jobs:
          locked: true
      # - name: The Worm Mock
      #   run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
      - name: Large 1D Conv Mock
        run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
      - name: MNIST Gan Mock
        run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
      - name: NanoGPT Mock
@@ -292,8 +294,6 @@ jobs:
        run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
      - name: public outputs and bounded lookup log
        run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
      - name: public outputs and tolerance > 0
        run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
      - name: public outputs + batch size == 10
        run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
      - name: kzg inputs
2 Cargo.lock (generated)
@@ -1932,7 +1932,7 @@ dependencies = [
]

[[package]]
name = "ezkl-gpu"
name = "ezkl"
version = "0.0.0"
dependencies = [
 "alloy",
@@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
                padding: vec![(0, 0)],
                stride: vec![1; 2],
                group: 1,
                data_format: DataFormat::NCHW,
                kernel_format: KernelFormat::OIHW,
            }),
        )
        .unwrap();
@@ -69,6 +69,7 @@ impl Circuit<Fr> for MyCircuit {
                stride: vec![1, 1],
                kernel_shape: vec![2, 2],
                normalized: false,
                data_format: DataFormat::NCHW,
            }),
        )
        .unwrap();
@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;

mod params;

const K: usize = 20;
@@ -208,6 +209,8 @@ where
            padding: vec![(PADDING, PADDING); 2],
            stride: vec![STRIDE; 2],
            group: 1,
            data_format: DataFormat::NCHW,
            kernel_format: KernelFormat::OIHW,
        };
        let x = config
            .layer_config
106 examples/onnx/1d_conv/input.json (new file)
@@ -0,0 +1,106 @@
{
  "input_data": [
    [
      8761, 7654, 8501, 2404, 6929, 8858, 5946, 3673, 4131, 3854,
      8137, 8239, 9038, 6299, 1118, 9737, 208, 7954, 3691, 610,
      3468, 3314, 8658, 8366, 2850, 477, 6114, 232, 4601, 7420,
      5713, 2936, 6061, 2870, 8421, 177, 7107, 7382, 6115, 5487,
      8502, 2559, 1875, 129, 8533, 8201, 8414, 4775, 9817, 3127,
      8761, 7654, 8501, 2404, 6929, 8858, 5946, 3673, 4131, 3854,
      8137, 8239, 9038, 6299, 1118, 9737, 208, 7954, 3691, 610,
      3468, 3314, 8658, 8366, 2850, 477, 6114, 232, 4601, 7420,
      5713, 2936, 6061, 2870, 8421, 177, 7107, 7382, 6115, 5487,
      8502, 2559, 1875, 129, 8533, 8201, 8414, 4775, 9817, 3127
    ]
  ]
}
BIN examples/onnx/1d_conv/network.onnx (new file): binary file not shown.
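A minimal sketch of reading the new input file, assuming the usual ezkl input layout in which "input_data" holds one flat vector per model input. The struct name, file path handling, and the [1, 1, 100] shape remark are illustrative assumptions, not part of the diff:

use serde::Deserialize;

#[derive(Deserialize)]
struct GraphInput {
    // one flat vector of quantized values per model input
    input_data: Vec<Vec<i64>>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let raw = std::fs::read_to_string("examples/onnx/1d_conv/input.json")?;
    let input: GraphInput = serde_json::from_str(&raw)?;
    // The single input above carries 100 values; a 1D conv would consume it
    // as [batch, channels, length], e.g. [1, 1, 100], depending on the model.
    assert_eq!(input.input_data[0].len(), 100);
    Ok(())
}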
@@ -4,8 +4,8 @@ use crate::circuit::modules::poseidon::{
    PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
@@ -155,9 +155,6 @@ impl pyo3::ToPyObject for PyG1Affine {
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
    #[pyo3(get, set)]
    /// float: The tolerance for error on model outputs
    pub tolerance: f32,
    #[pyo3(get, set)]
    /// int: The denominator in the fixed point representation used when quantizing inputs
    pub input_scale: crate::Scale,
@@ -225,7 +222,6 @@ impl From<PyRunArgs> for RunArgs {
    fn from(py_run_args: PyRunArgs) -> Self {
        RunArgs {
            bounded_log_lookup: py_run_args.bounded_log_lookup,
            tolerance: Tolerance::from(py_run_args.tolerance),
            input_scale: py_run_args.input_scale,
            param_scale: py_run_args.param_scale,
            num_inner_cols: py_run_args.num_inner_cols,
@@ -250,7 +246,6 @@ impl Into<PyRunArgs> for RunArgs {
    fn into(self) -> PyRunArgs {
        PyRunArgs {
            bounded_log_lookup: self.bounded_log_lookup,
            tolerance: self.tolerance.val,
            input_scale: self.input_scale,
            param_scale: self.param_scale,
            num_inner_cols: self.num_inner_cols,
@@ -20,7 +20,6 @@ use crate::{
    circuit::{
        ops::base::BaseOp,
        table::{Range, RangeCheck, Table},
        utils,
    },
    tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
@@ -85,55 +84,6 @@ impl CheckMode {
    }
}

#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub struct Tolerance {
    pub val: f32,
    pub scale: utils::F32,
}

impl std::fmt::Display for Tolerance {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.2}", self.val)
    }
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
    /// Convert the struct to a subcommand string
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

impl FromStr for Tolerance {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Ok(val) = s.parse::<f32>() {
            Ok(Tolerance {
                val,
                scale: utils::F32(1.0),
            })
        } else {
            Err(
                "Invalid tolerance value provided. It should expressed as a percentage (f32)."
                    .to_string(),
            )
        }
    }
}

impl From<f32> for Tolerance {
    fn from(value: f32) -> Self {
        Tolerance {
            val: value,
            scale: utils::F32(1.0),
        }
    }
}

#[cfg(feature = "python-bindings")]
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
impl IntoPy<PyObject> for CheckMode {
@@ -158,29 +108,6 @@ impl<'source> FromPyObject<'source> for CheckMode {
    }
}

#[cfg(feature = "python-bindings")]
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
impl IntoPy<PyObject> for Tolerance {
    fn into_py(self, py: Python) -> PyObject {
        (self.val, self.scale.0).to_object(py)
    }
}

#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
    fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
        if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
            Ok(Tolerance {
                val,
                scale: utils::F32(scale),
            })
        } else {
            Err(PyValueError::new_err("Invalid tolerance value provided. "))
        }
    }
}

/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
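A minimal sketch, not part of the diff, of the semantics the deleted Tolerance type encoded: val was a percentage and scale a fixed-point multiplier, so the circuit accepted outputs within tol percent of the expected value. The branch drops this in favour of strict equality on outputs:

fn within_tolerance(expected: f32, actual: f32, tol_percent: f32) -> bool {
    if tol_percent == 0.0 {
        // the behaviour that now always applies on this branch
        return expected == actual;
    }
    // relative check: |expected - actual| <= tol% of |expected|
    (expected - actual).abs() <= expected.abs() * tol_percent / 100.0
}

fn main() {
    assert!(within_tolerance(100.0, 100.4, 0.5));
    assert!(!within_tolerance(100.0, 101.0, 0.5));
}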
@@ -1,9 +1,9 @@
use super::*;
use crate::{
    circuit::{layouts, utils, Tolerance},
    circuit::{layouts, utils},
    fieldutils::{integer_rep_to_felt, IntegerRep},
    graph::multiplier_to_scale,
    tensor::{self, Tensor, TensorType, ValTensor},
    tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -57,11 +57,13 @@ pub enum HybridOp {
        stride: Vec<usize>,
        kernel_shape: Vec<usize>,
        normalized: bool,
        data_format: DataFormat,
    },
    MaxPool {
        padding: Vec<(usize, usize)>,
        stride: Vec<usize>,
        pool_dims: Vec<usize>,
        data_format: DataFormat,
    },
    ReduceMin {
        axes: Vec<usize>,
@@ -77,7 +79,6 @@ pub enum HybridOp {
        axes: Vec<usize>,
    },
    Output {
        tol: Tolerance,
        decomp: bool,
    },
    Greater,
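A short construction sketch showing the new data_format field on the pooling variants; the module paths are assumed from the files in this diff and may differ in the crate:

use ezkl::circuit::ops::hybrid::HybridOp; // path assumed
use ezkl::tensor::DataFormat;

fn main() {
    // a 2x2 max pool over an NCHW tensor, as the enum now requires
    let _op = HybridOp::MaxPool {
        padding: vec![(0, 0); 2],
        stride: vec![1; 2],
        pool_dims: vec![2; 2],
        data_format: DataFormat::NCHW,
    };
}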
@@ -154,10 +155,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                kernel_shape,
                normalized,
                normalized, data_format
            } => format!(
                "SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
                padding, stride, kernel_shape, normalized
                "SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
                padding, stride, kernel_shape, normalized, data_format
            ),
            HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
            HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -165,9 +166,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                pool_dims,
                data_format,
            } => format!(
                "MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
                padding, stride, pool_dims
                "MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
                padding, stride, pool_dims, data_format
            ),
            HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
            HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -181,8 +183,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                    input_scale, output_scale, axes
                )
            }
            HybridOp::Output { tol, decomp } => {
                format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
            HybridOp::Output { decomp } => {
                format!("OUTPUT (decomp={})", decomp)
            }
            HybridOp::Greater => "GREATER".to_string(),
            HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
@@ -239,6 +241,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                stride,
                kernel_shape,
                normalized,
                data_format,
            } => layouts::sumpool(
                config,
                region,
@@ -247,6 +250,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                stride,
                kernel_shape,
                *normalized,
                *data_format,
            )?,
            HybridOp::Recip {
                input_scale,
@@ -287,6 +291,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                pool_dims,
                data_format,
            } => layouts::max_pool(
                config,
                region,
@@ -294,6 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                pool_dims,
                *data_format,
            )?,
            HybridOp::ReduceMax { axes } => {
                layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -319,14 +325,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                *output_scale,
                axes,
            )?,
            HybridOp::Output { tol, decomp } => layouts::output(
                config,
                region,
                values[..].try_into()?,
                tol.scale,
                tol.val,
                *decomp,
            )?,
            HybridOp::Output { decomp } => {
                layouts::output(config, region, values[..].try_into()?, *decomp)?
            }
            HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
            HybridOp::GreaterEqual => {
                layouts::greater_equal(config, region, values[..].try_into()?)?
@@ -24,6 +24,7 @@ use crate::{
        ops::{accumulated, add, mult, sub},
        Tensor, TensorError, ValType,
    },
    tensor::{DataFormat, KernelFormat},
};

use super::*;
@@ -156,25 +157,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    claimed_output.reshape(input_dims)?;
    // implicitly check if the prover provided output is within range
    let claimed_output = identity(config, region, &[claimed_output], true)?;
    // check if x is too large only if the decomp would support overflow in the previous op
    if F::from_u128(IntegerRep::MAX as u128)
        < F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
    {
        // here we decompose and extract the sign of the input
        let sign = sign(config, region, &[claimed_output.clone()])?;

        let abs_value = pairwise(
            config,
            region,
            &[claimed_output.clone(), sign],
            BaseOp::Mult,
        )?;
        let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
        let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
        // assert the result is 1
        let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
        enforce_equality(config, region, &[abs_value, comparison_unit])?;
    }

    let product = pairwise(
        config,
@@ -248,32 +230,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &[equal_zero_mask.clone(), equal_inverse_mask],
    )?;

    let masked_output = pairwise(
        config,
        region,
        &[claimed_output.clone(), not_equal_zero_mask.clone()],
        BaseOp::Mult,
    )?;

    // check if x is too large only if the decomp would support overflow in the previous op
    if F::from_u128(IntegerRep::MAX as u128)
        < F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
    {
        // here we decompose and extract the sign of the input
        let sign = sign(config, region, &[masked_output.clone()])?;
        let abs_value = pairwise(
            config,
            region,
            &[claimed_output.clone(), sign],
            BaseOp::Mult,
        )?;
        let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
        let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
        // assert the result is 1
        let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
        enforce_equality(config, region, &[abs_value, comparison_unit])?;
    }

    let err_func = |config: &BaseConfig<F>,
                    region: &mut RegionCtx<F>,
                    x: &ValTensor<F>|
@@ -3225,6 +3181,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
/// use ezkl::tensor::DataFormat;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
@@ -3234,12 +3191,12 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false, DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
/// // This time with normalization
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true, DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
/// ```
@@ -3251,9 +3208,19 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    stride: &[usize],
    kernel_shape: &[usize],
    normalized: bool,
    data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let batch_size = values[0].dims()[0];
    let image_channels = values[0].dims()[1];
    let mut image = values[0].clone();
    data_format.to_canonical(&mut image)?;

    if data_format.has_no_batch() {
        let mut dims = image.dims().to_vec();
        dims.insert(0, 1);
        image.reshape(&dims)?;
    }

    let batch_size = image.dims()[0];
    let image_channels = image.dims()[1];

    let kernel_len = kernel_shape.iter().product();
@@ -3278,7 +3245,16 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        .map(|coord| {
            let (b, i) = (coord[0], coord[1]);
            let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
            let output = conv(config, region, &[input, kernel.clone()], padding, stride, 1)?;
            let output = conv(
                config,
                region,
                &[input, kernel.clone()],
                padding,
                stride,
                1,
                DataFormat::default(),
                KernelFormat::default(),
            )?;
            res.push(output);
            Ok(())
        })
@@ -3293,6 +3269,9 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    if normalized {
        last_elem = div(config, region, &[last_elem], F::from(kernel_len as u64))?;
    }

    data_format.from_canonical(&mut last_elem)?;

    Ok(last_elem)
}
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::max_pool;
|
||||
/// use ezkl::tensor::DataFormat;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
@@ -3316,7 +3296,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
/// &[1, 1, 3, 3],
|
||||
/// ).unwrap());
|
||||
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2]).unwrap();
|
||||
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2], DataFormat::default()).unwrap();
|
||||
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(pooled.int_evals().unwrap(), expected);
|
||||
///
|
||||
@@ -3328,8 +3308,16 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    padding: &[(usize, usize)],
    stride: &[usize],
    pool_dims: &[usize],
    data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let image = values[0].clone();
    let mut image = values[0].clone();
    data_format.to_canonical(&mut image)?;

    if data_format.has_no_batch() {
        let mut dims = image.dims().to_vec();
        dims.insert(0, 1);
        image.reshape(&dims)?;
    }

    let image_dims = image.dims();
@@ -3388,38 +3376,38 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    region.apply_in_loop(&mut output, inner_loop_function)?;

    let res: ValTensor<F> = output.into();
    let mut res: ValTensor<F> = output.into();

    data_format.from_canonical(&mut res)?;

    Ok(res)
}

/// Performs a deconvolution on the given input tensor.
/// # Examples
/// ```
// // expected outputs are taken from pytorch torch.nn.functional.conv_transpose2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::deconv;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Original test case 1: Channel expansion
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 2: Basic deconvolution
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
@@ -3428,11 +3416,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[3, 1, 1, 5]),
///     &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 3: With padding
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
@@ -3441,11 +3429,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[3, 1, 1, 5]),
///     &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 4: With stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
@@ -3454,10 +3442,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[3, 1, 1, 5]),
///     &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 5: Zero padding with stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
@@ -3466,10 +3455,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[3, 1, 1, 5]),
///     &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 6: Different kernel shape
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
@@ -3478,10 +3468,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[3, 2]),
///     &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 7: Different kernel shape without padding
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
@@ -3490,20 +3481,21 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[3, 2]),
///     &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 8: Channel expansion with stride
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 9: With bias
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[3, 8, 0, 8, 4, 9, 8, 1, 8]),
///     &[1, 1, 3, 3],
@@ -3516,11 +3508,89 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[1]),
///     &[1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 1: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 2, 2, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1, 5, 3]),
///     &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[27]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 2: 1D deconvolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3]),
///     &[1, 1, 3], // NCH format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2]),
///     &[1, 1, 2], // OIH format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![0], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 7, 6]), &[1, 1, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 3: 3D deconvolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 4]),
///     &[1, 1, 2, 2, 1], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1]),
///     &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![0; 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4]), &[1, 1, 2, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 4: Multi-channel with NHWC format and OHWI kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1, 3, 2, 1, 4]), // 2 channels, 2x2 spatial
///     &[1, 2, 2, 2], // NHWC format [batch, height, width, channels]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///     &[1, 2, 2, 2], // OHWI format [out_channels, height, width, in_channels]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 24, 4, 41, 78, 27, 27, 66, 39]), &[1, 3, 3, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 5: CHW format (no batch dimension)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 4, 0, 1]),
///     &[1, 2, 2], // CHW format [channels, height, width]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 4]),
///     &[1, 1, 2, 2], // OIHW format [out_channels, in_channels, height, width]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::CHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 6: HWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[2, 3, 4, 1]),
///     &[2, 2, 1], // HWC format [height, width, channels]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 1, 2]),
///     &[2, 2, 1, 1], // HWIO format [height, width, in_channels, out_channels]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::HWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 3, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
pub fn deconv<
    F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
@@ -3531,9 +3601,14 @@ pub fn deconv<
    output_padding: &[usize],
    stride: &[usize],
    num_groups: usize,
    data_format: DataFormat,
    kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let has_bias = inputs.len() == 3;
    let (image, kernel) = (&inputs[0], &inputs[1]);
    let (mut working_image, mut working_kernel) = (inputs[0].clone(), inputs[1].clone());

    data_format.to_canonical(&mut working_image)?;
    kernel_format.to_canonical(&mut working_kernel)?;

    if stride.iter().any(|&s| s == 0) {
        return Err(TensorError::DimMismatch(
@@ -3543,26 +3618,23 @@
    }

    let null_val = ValType::Constant(F::ZERO);
    let mut expanded_image = working_image.clone();

    let mut expanded_image = image.clone();

    // Expand image by inserting zeros according to stride
    for (i, s) in stride.iter().enumerate() {
        expanded_image.intercalate_values(null_val.clone(), *s, 2 + i)?;
    }

    // Pad to kernel size for each spatial dimension
    expanded_image.pad(
        kernel.dims()[2..]
        working_kernel.dims()[2..]
            .iter()
            .map(|d| (d - 1, d - 1))
            .collect::<Vec<_>>(),
        2,
    )?; // pad to the kernel size

    // flip order
    let channel_coord = (0..kernel.dims()[0])
        .cartesian_product(0..kernel.dims()[1])
        .collect::<Vec<_>>();
    )?;

    // Calculate slice coordinates considering padding and output padding
    let slice_coord = expanded_image
        .dims()
        .iter()
@@ -3578,26 +3650,34 @@
    let sliced_expanded_image = expanded_image.get_slice(&slice_coord)?;

    let mut inverted_kernels = vec![];
    // Generate channel coordinates for kernel transformation
    let (in_ch_dim, out_ch_dim) =
        KernelFormat::default().get_channel_dims(working_kernel.dims().len());
    let channel_coord = (0..working_kernel.dims()[out_ch_dim])
        .cartesian_product(0..working_kernel.dims()[in_ch_dim])
        .collect::<Vec<_>>();

    // Invert kernels for deconvolution
    let mut inverted_kernels = vec![];
    for (i, j) in channel_coord {
        let channel = kernel.get_slice(&[i..i + 1, j..j + 1])?;
        let channel = working_kernel.get_slice(&[i..i + 1, j..j + 1])?;
        let mut channel = Tensor::from(channel.get_inner_tensor()?.clone().into_iter().rev());
        channel.reshape(&kernel.dims()[2..])?;
        channel.reshape(&working_kernel.dims()[2..])?;
        inverted_kernels.push(channel);
    }

    let mut deconv_kernel =
        Tensor::new(Some(&inverted_kernels), &[inverted_kernels.len()])?.combine()?;
    deconv_kernel.reshape(kernel.dims())?;
    deconv_kernel.reshape(working_kernel.dims())?;

    // tensorflow formatting patch
    if kernel.dims()[0] == sliced_expanded_image.dims()[1] {
    // Handle tensorflow-style input/output channel ordering
    if working_kernel.dims()[0] == sliced_expanded_image.dims()[1] {
        let mut dims = deconv_kernel.dims().to_vec();
        dims.swap(0, 1);
        deconv_kernel.reshape(&dims)?;
    }

    // Prepare inputs for convolution
    let conv_input = if has_bias {
        vec![
            sliced_expanded_image,
@@ -3608,28 +3688,32 @@
        vec![sliced_expanded_image, deconv_kernel.clone().into()]
    };

    let conv_dim = kernel.dims()[2..].len();
    let conv_dim = working_kernel.dims()[2..].len();

    let output = conv(
    // Perform convolution with canonical formats
    let mut output = conv(
        config,
        region,
        &conv_input,
        &vec![(0, 0); conv_dim],
        &vec![1; conv_dim],
        num_groups,
        data_format.canonical(), // Use canonical format
        kernel_format.canonical(), // Use canonical format
    )?;

    // Convert output back to requested format
    data_format.from_canonical(&mut output)?;

    Ok(output)
}
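A minimal 1D sketch, illustrative rather than circuit code, of the deconv strategy shown above: dilate the input with zeros according to the stride, pad to kernel size, flip the kernel, then run an ordinary convolution. It mirrors the intercalate_values / pad / kernel-inversion steps and reproduces the 1D doc-test above:

fn deconv1d(input: &[i64], kernel: &[i64], stride: usize) -> Vec<i64> {
    // dilate: insert (stride - 1) zeros between input values
    let mut dilated = vec![];
    for (i, v) in input.iter().enumerate() {
        dilated.push(*v);
        if i + 1 < input.len() {
            dilated.extend(std::iter::repeat(0).take(stride - 1));
        }
    }
    // pad by (kernel_len - 1) on both sides
    let k = kernel.len();
    let mut padded = vec![0; k - 1];
    padded.extend(&dilated);
    padded.extend(vec![0; k - 1]);
    // flip the kernel and convolve
    let flipped: Vec<i64> = kernel.iter().rev().copied().collect();
    (0..=padded.len() - k)
        .map(|i| (0..k).map(|j| padded[i + j] * flipped[j]).sum())
        .collect()
}

fn main() {
    // matches "Additional test case 2": deconv([1, 2, 3], [1, 2]) = [1, 4, 7, 6]
    assert_eq!(deconv1d(&[1, 2, 3], &[1, 2], 1), vec![1, 4, 7, 6]);
}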
/// Applies convolution over a ND tensor of shape C x H x D1...DN (and adds a bias).
/// ```
/// // expected outputs are taken from pytorch torch.nn.functional.conv2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::conv;
/// use ezkl::tensor::val::ValTensor;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3638,6 +3722,7 @@ pub fn deconv<
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Test case 1: Basic 2D convolution with NCHW format (default)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 1, 3, 3],
@@ -3650,44 +3735,64 @@
///     Some(&[0]),
///     &[1],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Now test single channel
/// // Test case 2: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 2, 3, 3],
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 1, 1, 1, 5, 2, 1, 1]),
///     &[2, 1, 2, 2],
///     Some(&[1, 1, 5, 1]),
///     &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[11, 24, 20, 14]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 3: Multi-channel NHWC with OHWI kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3, 2], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 1, 1, 2, 5, 2, 1, 2]),
///     &[1, 2, 2, 2], // OHWI format
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1]),
///     &[2],
/// ).unwrap());
///
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[64, 66, 46, 58]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Now test multi channel
/// // Test case 4: 1D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 2, 3, 3],
///     Some(&[1, 2, 3, 4, 5]),
///     &[1, 1, 5], // NCHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]),
///     &[4, 2, 2, 2],
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1, 1, 1]),
///     &[4],
///     Some(&[1, 2, 3]),
///     &[1, 1, 3], // OIHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
/// // Test case 5: 3D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///     &[1, 1, 2, 2, 2], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1]),
///     &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 7, 11, 15]), &[1, 1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
@@ -3700,9 +3805,14 @@ pub fn conv<
    padding: &[(usize, usize)],
    stride: &[usize],
    num_groups: usize,
    data_format: DataFormat,
    kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let has_bias = values.len() == 3;
    let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
    let (mut working_image, mut working_kernel) = (values[0].clone(), values[1].clone());

    data_format.to_canonical(&mut working_image)?;
    kernel_format.to_canonical(&mut working_kernel)?;

    if stride.iter().any(|&s| s == 0) {
        return Err(TensorError::DimMismatch(
@@ -3711,47 +3821,40 @@
            .into());
    }

    // we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them
    // 1. assign the kernel
    // Assign tensors
    let mut assigned_len = vec![];

    if !kernel.all_prev_assigned() {
        kernel = region.assign(&config.custom_gates.inputs[0], &kernel)?;
        assigned_len.push(kernel.len());
    if !working_kernel.all_prev_assigned() {
        working_kernel = region.assign(&config.custom_gates.inputs[0], &working_kernel)?;
        assigned_len.push(working_kernel.len());
    }
    // 2. assign the image
    if !image.all_prev_assigned() {
        image = region.assign(&config.custom_gates.inputs[1], &image)?;
        assigned_len.push(image.len());
    if !working_image.all_prev_assigned() {
        working_image = region.assign(&config.custom_gates.inputs[1], &working_image)?;
        assigned_len.push(working_image.len());
    }

    if !assigned_len.is_empty() {
        // safe to unwrap since we've just checked it has at least one element
        region.increment(*assigned_len.iter().max().unwrap());
    }

    // if image is 3d add a dummy batch dimension
    if image.dims().len() == kernel.dims().len() - 1 {
        image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
    if data_format.has_no_batch() {
        let mut dim = working_image.dims().to_vec();
        dim.insert(0, 1);
        working_image.reshape(&dim)?;
    }

    let image_dims = image.dims();
    let kernel_dims = kernel.dims();
    let image_dims = working_image.dims();
    let kernel_dims = working_kernel.dims();

    let mut padded_image = image.clone();
    // Apply padding
    let mut padded_image = working_image.clone();
    padded_image.pad(padding.to_vec(), 2)?;

    // Extract dimensions
    let batch_size = image_dims[0];
    let input_channels = image_dims[1];
    let output_channels = kernel_dims[0];

    log::debug!(
        "batch_size: {}, output_channels: {}, input_channels: {}",
        batch_size,
        output_channels,
        input_channels
    );

    // Calculate slides for each spatial dimension
    let slides = image_dims[2..]
        .iter()
        .enumerate()
@@ -3766,8 +3869,6 @@
        })
        .collect::<Result<Vec<_>, TensorError>>()?;

    log::debug!("slides: {:?}", slides);

    let input_channels_per_group = input_channels / num_groups;
    let output_channels_per_group = output_channels / num_groups;
@@ -3775,24 +3876,15 @@
        return Err(TensorError::DimMismatch(format!(
            "Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
            num_groups, input_channels, output_channels
        ))
        .into());
        )).into());
    }

    log::debug!(
        "num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
        num_groups,
        input_channels_per_group,
        output_channels_per_group
    );

    let num_outputs =
        batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();

    log::debug!("num_outputs: {}", num_outputs);

    let mut output: Tensor<ValType<F>> = Tensor::new(None, &[num_outputs])?;

    // Create iteration space
    let mut iterations = vec![0..batch_size, 0..num_groups, 0..output_channels_per_group];
    for slide in slides.iter() {
        iterations.push(0..*slide);
@@ -3804,6 +3896,13 @@
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    let batch_offset = if data_format.has_no_batch() {
        2 // No batch dimension, start coordinates after channels
    } else {
        3 // Has batch dimension, start coordinates after batch and channels
    };

    // Main convolution loop
    let inner_loop_function = |idx: usize, region: &mut RegionCtx<F>| {
        let cartesian_coord_per_group = &cartesian_coord[idx];
        let (batch, group, i) = (
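A small sketch, illustrative only, of the batch_offset choice above: spatial coordinates start right after the channel index when the data format carries no batch dimension, and one slot later when it does:

fn spatial_coords(coord: &[usize], has_batch: bool) -> &[usize] {
    let batch_offset = if has_batch { 3 } else { 2 };
    &coord[batch_offset..]
}

fn main() {
    // (batch, group, out_channel, h, w) vs (group, out_channel, h, w)
    assert_eq!(spatial_coords(&[0, 0, 1, 4, 5], true), &[4, 5]);
    assert_eq!(spatial_coords(&[0, 1, 4, 5], false), &[4, 5]);
}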
@@ -3817,22 +3916,19 @@
        let mut slices = vec![batch..batch + 1, start_channel..end_channel];
        for (i, stride) in stride.iter().enumerate() {
            let coord = cartesian_coord_per_group[3 + i] * stride;
            let coord = cartesian_coord_per_group[batch_offset + i] * stride;
            let kernel_dim = kernel_dims[2 + i];
            slices.push(coord..(coord + kernel_dim));
        }

        let mut local_image = padded_image.get_slice(&slices)?;

        local_image.flatten();

        let start_kernel_index = group * output_channels_per_group + i;
        let end_kernel_index = start_kernel_index + 1;
        let mut local_kernel = kernel.get_slice(&[start_kernel_index..end_kernel_index])?;

        let mut local_kernel = working_kernel.get_slice(&[start_kernel_index..end_kernel_index])?;
        local_kernel.flatten();

        // this is dot product notation in einsum format
        let mut res = einsum(config, region, &[local_image, local_kernel], "i,i->")?;

        if has_bias {
@@ -3853,21 +3949,16 @@
    region.flush()?;
    region.apply_in_loop(&mut output, inner_loop_function)?;

    let reshape_output = |output: &mut Tensor<ValType<F>>| -> Result<(), TensorError> {
        // remove dummy batch dimension if we added one
        let mut dims = vec![batch_size, output_channels];
        dims.extend(slides.iter().cloned());
        output.reshape(&dims)?;
    // Reshape output
    let mut dims = vec![batch_size, output_channels];
    dims.extend(slides.iter().cloned());
    output.reshape(&dims)?;

        Ok(())
    };
    // Convert output back to requested format
    let mut final_output: ValTensor<F> = output.into();
    data_format.from_canonical(&mut final_output)?;

    // remove dummy batch dimension if we added one
    reshape_output(&mut output)?;

    let output: ValTensor<_> = output.into();

    Ok(output)
    Ok(final_output)
}

/// Power accumulated layout
@@ -5747,14 +5838,12 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[101, 201, 302, 403, 503, 603]),
///     &[2, 3],
/// ).unwrap());
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0, false).unwrap();
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], false).unwrap();
/// ```
pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    scale: utils::F32,
    tol: f32,
    decomp: bool,
) -> Result<ValTensor<F>, CircuitError> {
    let mut values = [values[0].clone(), values[1].clone()];
@@ -5769,43 +5858,6 @@ pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        values[1] = layouts::identity(config, region, &[values[1].clone()], decomp)?;
    }

    if tol == 0.0 {
        // regular equality constraint
        return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
    }

    // Calculate the difference between the expected output and actual output
    let diff = pairwise(config, region, &values, BaseOp::Sub)?;

    // integer scale
    let int_scale = scale.0 as IntegerRep;
    // felt scale
    let felt_scale = integer_rep_to_felt(int_scale);
    // input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
    let input_scale_ratio = (scale.0 * tol) as IntegerRep / 2 * 2;

    let recip = recip(
        config,
        region,
        &[values[0].clone()],
        felt_scale,
        felt_scale * F::from(100),
    )?;

    log::debug!("recip: {}", recip.show());

    // Multiply the difference by the recip
    let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;

    log::debug!("product: {}", product.show());
    let rebased_product = div(
        config,
        region,
        &[product],
        integer_rep_to_felt(input_scale_ratio),
    )?;
    log::debug!("rebased_product: {}", rebased_product.show());

    // check that it is within the tolerance range
    range_check(config, region, &[rebased_product], &(-int_scale, int_scale))
    // regular equality constraint
    return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
}
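A plain-integer sketch, illustrative and not circuit code, of the tolerance path deleted above: the old output() took diff = a - b, multiplied it by a scaled reciprocal of a (with a factor of 100 to express percent), rebased by scale * tol, and range-checked the result in (-scale, scale). With the tolerance gone, the function reduces to enforce_equality:

fn old_output_check(a: i64, b: i64, scale: i64, tol_percent: f64) -> bool {
    if tol_percent == 0.0 {
        return a == b; // the only behaviour that remains on this branch
    }
    let diff = (a - b) as f64;
    // recip ~ scale^2 * 100 / a; product = diff * recip; rebased by scale * tol.
    // Algebraically: rebased = diff * scale * 100 / (a * tol)
    let rebased = diff * (scale as f64) * 100.0 / (a as f64 * tol_percent);
    rebased.abs() < scale as f64 // range_check in (-scale, scale): |a-b|/a < tol%
}

fn main() {
    assert!(old_output_check(1000, 1000, 1024, 1.0));
    assert!(old_output_check(1000, 1005, 1024, 1.0)); // 0.5% off, within 1%
    assert!(!old_output_check(1000, 1100, 1024, 1.0)); // 10% off, rejected
}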
@@ -4,6 +4,7 @@ use crate::{
        utils::{self, F32},
    },
    tensor::{self, Tensor, TensorError},
    tensor::{DataFormat, KernelFormat},
};

use super::{base::BaseOp, *};
@@ -43,6 +44,8 @@ pub enum PolyOp {
        padding: Vec<(usize, usize)>,
        stride: Vec<usize>,
        group: usize,
        data_format: DataFormat,
        kernel_format: KernelFormat,
    },
    Downsample {
        axis: usize,
@@ -54,6 +57,8 @@ pub enum PolyOp {
        output_padding: Vec<usize>,
        stride: Vec<usize>,
        group: usize,
        data_format: DataFormat,
        kernel_format: KernelFormat,
    },
    Add,
    Sub,
@@ -165,10 +170,12 @@ impl<
                stride,
                padding,
                group,
                data_format,
                kernel_format,
            } => {
                format!(
                    "CONV (stride={:?}, padding={:?}, group={})",
                    stride, padding, group
                    "CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
                    stride, padding, group, data_format, kernel_format
                )
            }
            PolyOp::DeConv {
@@ -176,11 +183,12 @@ impl<
                padding,
                output_padding,
                group,
                data_format,
                kernel_format,
            } => {
                format!(
                    "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
                    stride, padding, output_padding, group
                )
                    "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
                    stride, padding, output_padding, group, data_format, kernel_format)
            }
            PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
            PolyOp::Slice { axis, start, end } => {
@@ -242,6 +250,8 @@ impl<
                padding,
                stride,
                group,
                data_format,
                kernel_format,
            } => layouts::conv(
                config,
                region,
@@ -249,6 +259,8 @@ impl<
                padding,
                stride,
                *group,
                *data_format,
                *kernel_format,
            )?,
            PolyOp::GatherElements { dim, constant_idx } => {
                if let Some(idx) = constant_idx {
@@ -309,6 +321,8 @@ impl<
                output_padding,
                stride,
                group,
                data_format,
                kernel_format,
            } => layouts::deconv(
                config,
                region,
@@ -317,6 +331,8 @@ impl<
                output_padding,
                stride,
                *group,
                *data_format,
                *kernel_format,
            )?,
            PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
            PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -1,5 +1,6 @@
|
||||
use crate::circuit::ops::poly::PolyOp;
|
||||
use crate::circuit::*;
|
||||
use crate::tensor::{DataFormat, KernelFormat};
|
||||
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
@@ -1065,6 +1066,8 @@ mod conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
data_format: DataFormat::default(),
|
||||
kernel_format: KernelFormat::default(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1220,6 +1223,8 @@ mod conv_col_ultra_overflow {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
data_format: DataFormat::default(),
|
||||
kernel_format: KernelFormat::default(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1377,6 +1382,8 @@ mod conv_relu_col_ultra_overflow {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
data_format: DataFormat::default(),
|
||||
kernel_format: KernelFormat::default(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis);
|
||||
|
||||
@@ -19,6 +19,11 @@ pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
/// Converts a PrimeField element to an f64.
|
||||
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN as f64;
|
||||
} else if x < -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
panic!("Felt value out of range for conversion to integer rep");
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -31,11 +36,13 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an i64.
|
||||
/// Converts a PrimeField element to an integer rep.
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN;
|
||||
} else if x < -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
panic!("Felt value out of range for conversion to integer rep");
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
@@ -70,6 +77,13 @@ mod test {
|
||||
assert_eq!(res, F::from(131072));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn felttointegerrep_overflow() {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(IntegerRep::MIN) - F::ONE;
|
||||
let _xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
|
||||
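Both conversions above decode felts in the upper half of the field as negatives, which is why they take (-x).to_repr() before reading the bytes. A toy analog with a small prime standing in for the BN254 scalar field (hypothetical helper, not the ezkl API):

fn decode_signed(x: u64, p: u64) -> i64 {
    // residues above p/2 represent negative values
    if x > p / 2 {
        -((p - x) as i64) // e.g. p - 1 decodes to -1
    } else {
        x as i64
    }
}

fn main() {
    let p = 101;
    assert_eq!(decode_signed(p - 1, p), -1);
    assert_eq!(decode_signed(2, p), 2);
}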
@@ -33,7 +33,7 @@ pub enum GraphError {
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node is has misformed params: {0}")]
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]

@@ -609,8 +609,12 @@ impl GraphData {
if input.len() % input_size != 0 {
return Err(GraphError::InvalidDims(
0,
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
format!(
"calibration data length (={}) must be evenly divisible by the original input_size(={})",
input.len(),
input_size
),

));
}
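A standalone sketch of the divisibility check whose error message the hunk above enriches with the offending values (illustrative names, not the ezkl API):

fn check_calibration_len(data_len: usize, input_size: usize) -> Result<(), String> {
    // mirrors input.len() % input_size != 0 above
    if data_len % input_size != 0 {
        return Err(format!(
            "calibration data length (={}) must be evenly divisible by the original input_size(={})",
            data_len, input_size
        ));
    }
    Ok(())
}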
@@ -1,7 +1,6 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
@@ -1173,17 +1172,10 @@ impl Model {
})?;

if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
let output_scales = self.graph.get_output_scales().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();

let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
.instance
@@ -1206,7 +1198,6 @@ impl Model {
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
@@ -1468,11 +1459,9 @@ impl Model {
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
.map(|output| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
@@ -1485,14 +1474,10 @@ impl Model {
.into();
comparator.reshape(output.dims())?;

let mut tol = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();

dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)

@@ -39,9 +39,8 @@ use tract_onnx::tract_hir::{
ops::array::{Pad, PadMode, TypedConcat},
ops::cnn::PoolSpec,
ops::konst::Const,
ops::nn::DataFormat,
tract_core::ops::cast::Cast,
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
tract_core::ops::cnn::{MaxPool, SumPool},
};

/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
@@ -1146,13 +1145,6 @@ pub fn new_op_from_onnx(

let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}

let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
let kernel_shape = &pool_spec.kernel_shape;
@@ -1161,6 +1153,7 @@ pub fn new_op_from_onnx(
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
data_format: pool_spec.data_format.into(),
})
}
"Ceil" => {
@@ -1314,15 +1307,6 @@ pub fn new_op_from_onnx(
}
}

if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}

let pool_spec = &conv_node.pool_spec;

let stride = extract_strides(pool_spec)?;
@@ -1350,6 +1334,8 @@ pub fn new_op_from_onnx(
padding,
stride,
group,
data_format: conv_node.pool_spec.data_format.into(),
kernel_format: conv_node.kernel_fmt.into(),
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1373,14 +1359,6 @@ pub fn new_op_from_onnx(
}
}

if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}

let pool_spec = &deconv_node.pool_spec;

let stride = extract_strides(pool_spec)?;
@@ -1406,6 +1384,8 @@ pub fn new_op_from_onnx(
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
data_format: deconv_node.pool_spec.data_format.into(),
kernel_format: deconv_node.kernel_format.into(),
})
}
"Downsample" => {
@@ -1489,13 +1469,6 @@ pub fn new_op_from_onnx(

let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}

let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;

@@ -1504,6 +1477,7 @@ pub fn new_op_from_onnx(
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
data_format: pool_spec.data_format.into(),
})
}
"Pad" => {
44
src/lib.rs
44
src/lib.rs
@@ -97,10 +97,9 @@ impl From<String> for EZKLError {

use std::str::FromStr;

use circuit::{table::Range, CheckMode, Tolerance};
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
@@ -275,10 +274,6 @@ impl From<String> for Commitments {
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// Error tolerance for model outputs
/// Only applicable when outputs are public
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
@@ -365,7 +360,6 @@ impl Default for RunArgs {
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
scale_rebase_multiplier: 1,
@@ -399,6 +393,16 @@ impl RunArgs {
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();

// check if the largest represented integer in the decomposed form overflows IntegerRep
// try it with the largest possible value
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
if max_decomp.is_none() {
errors.push(format!(
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
self.decomp_base, self.decomp_legs
));
}

// Visibility validations
if self.param_visibility == Visibility::Public {
errors.push(
@@ -407,10 +411,6 @@ impl RunArgs {
);
}

if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
}

// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
@@ -459,11 +459,6 @@ impl RunArgs {
warn!("logrows exceeds maximum public SRS size");
}

// Validate tolerance is non-negative
if self.tolerance.val < 0.0 {
errors.push("tolerance cannot be negative".to_string());
}

// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
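The new guard above leans on checked_pow returning None on overflow instead of panicking. A standalone illustration (i128 stands in for IntegerRep here; that alias is an assumption):

fn max_decomposable(base: u32, legs: u32) -> Option<i128> {
    // None means base^legs does not fit in the integer type
    (base as i128).checked_pow(legs)
}

fn main() {
    assert!(max_decomposable(16384, 2).is_some()); // 16384^2 = 2^28 fits
    assert!(max_decomposable(16384, 12).is_none()); // 16384^12 = 2^168 overflows i128
}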
@@ -610,23 +605,6 @@ mod tests {
assert!(err.contains("num_inner_cols must be >= 1"));
}

#[test]
fn test_invalid_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = 1.0;
args.output_visibility = Visibility::Private;
let err = args.validate().unwrap_err();
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
}

#[test]
fn test_negative_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = -1.0;
let err = args.validate().unwrap_err();
assert!(err.contains("tolerance cannot be negative"));
}

#[test]
fn test_zero_batch_size() {
let mut args = RunArgs::default();

@@ -1,6 +1,6 @@
use thiserror::Error;

use super::ops::DecompositionError;
use super::{ops::DecompositionError, DataFormat};

/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -44,4 +44,7 @@ pub enum TensorError {
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
/// Invalid data conversion
#[error("invalid data conversion from format {0} to {1}")]
InvalidDataConversion(DataFormat, DataFormat),
}

@@ -9,6 +9,7 @@ pub mod var;

pub use errors::TensorError;

use core::hash::Hash;
use halo2curves::ff::PrimeField;
use maybe_rayon::{
prelude::{
@@ -1767,6 +1768,229 @@ pub fn get_broadcasted_shape(
}
}
////////////////////////
///

/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum DataFormat {
/// NCHW
#[default]
NCHW,
/// NHWC
NHWC,
/// CHW
CHW,
/// HWC
HWC,
}

// as str
impl core::fmt::Display for DataFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DataFormat::NCHW => write!(f, "NCHW"),
DataFormat::NHWC => write!(f, "NHWC"),
DataFormat::CHW => write!(f, "CHW"),
DataFormat::HWC => write!(f, "HWC"),
}
}
}

impl DataFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
}
}

/// no batch dim
pub fn has_no_batch(&self) -> bool {
match self {
DataFormat::CHW | DataFormat::HWC => true,
_ => false,
}
}

/// Convert tensor to canonical format (NCHW or CHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// For ND: Move channels from last axis to position after batch
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(ndims - 1, 1)?;
}
}
DataFormat::HWC => {
// For ND: Move channels from last axis to first position
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(ndims - 1, 0)?;
}
}
_ => {} // NCHW/CHW are already in canonical format
}
Ok(())
}

/// Convert tensor from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// Move channels from position 1 to end
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(1, ndims - 1)?;
}
}
DataFormat::HWC => {
// Move channels from position 0 to end
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(0, ndims - 1)?;
}
}
_ => {} // NCHW/CHW don't need conversion
}
Ok(())
}

/// Get the position of the channel dimension
pub fn get_channel_dim(&self, ndims: usize) -> usize {
match self {
DataFormat::NCHW => 1,
DataFormat::NHWC => ndims - 1,
DataFormat::CHW => 0,
DataFormat::HWC => ndims - 1,
}
}
}
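A dims-only illustration of what to_canonical does for NHWC above: the channel axis moves from last place to position 1, mirroring tensor.move_axis(ndims - 1, 1) (hypothetical helper, not the ValTensor API):

fn nhwc_to_nchw_dims(dims: &[usize]) -> Vec<usize> {
    let mut d = dims.to_vec();
    if d.len() > 2 {
        let c = d.remove(d.len() - 1); // channel axis sits last in NHWC
        d.insert(1, c); // place it right after the batch axis
    }
    d
}

fn main() {
    assert_eq!(nhwc_to_nchw_dims(&[1, 28, 28, 3]), vec![1, 3, 28, 28]);
}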
/// The shape of the kernel for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum KernelFormat {
/// HWIO
HWIO,
/// OIHW
#[default]
OIHW,
/// OHWI
OHWI,
}

impl core::fmt::Display for KernelFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KernelFormat::HWIO => write!(f, "HWIO"),
KernelFormat::OIHW => write!(f, "OIHW"),
KernelFormat::OHWI => write!(f, "OHWI"),
}
}
}

impl KernelFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
}
}

/// Convert kernel to canonical format (OIHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move output channels from last to first
kernel.move_axis(kdims - 1, 0)?;
// Move input channels from new last to second position
kernel.move_axis(kdims - 1, 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from last to second position
kernel.move_axis(kdims - 1, 1)?;
}
_ => {} // OIHW is already canonical
}
Ok(())
}

/// Convert kernel from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
// Move output channels from first to last
kernel.move_axis(0, kdims - 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
}
_ => {} // OIHW doesn't need conversion
}
Ok(())
}

/// Get the position of input and output channel dimensions
pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
// (input_ch, output_ch)
match self {
KernelFormat::OIHW => (1, 0),
KernelFormat::HWIO => (ndims - 2, ndims - 1),
KernelFormat::OHWI => (ndims - 1, 0),
}
}
}
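Likewise for kernels, the two move_axis calls in HWIO's to_canonical send [H, W, I, O] to [O, I, H, W]; a dims-only sketch (hypothetical helper, not the ValTensor API):

fn hwio_to_oihw_dims(dims: &[usize]) -> Vec<usize> {
    let mut d = dims.to_vec();
    let n = d.len();
    let o = d.remove(n - 1); // output channels: last -> first
    d.insert(0, o);
    let i = d.remove(n - 1); // input channels: (new) last -> second
    d.insert(1, i);
    d
}

fn main() {
    assert_eq!(hwio_to_oihw_dims(&[3, 3, 8, 16]), vec![16, 8, 3, 3]);
}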
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
match fmt {
tract_onnx::tract_hir::ops::nn::DataFormat::NCHW => DataFormat::NCHW,
tract_onnx::tract_hir::ops::nn::DataFormat::NHWC => DataFormat::NHWC,
tract_onnx::tract_hir::ops::nn::DataFormat::CHW => DataFormat::CHW,
tract_onnx::tract_hir::ops::nn::DataFormat::HWC => DataFormat::HWC,
}
}
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
match fmt {
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::HWIO => {
KernelFormat::HWIO
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OIHW => {
KernelFormat::OIHW
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OHWI => {
KernelFormat::OHWI
}
}
}
}

#[cfg(test)]
mod tests {
Binary file not shown.
@@ -1,8 +1,7 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(test)]
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};

// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -187,7 +186,7 @@ mod native_tests {

const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";

const LARGE_TESTS: [&str; 7] = [
const LARGE_TESTS: [&str; 8] = [
"self_attention",
"nanoGPT",
"multihead_attention",
@@ -195,6 +194,7 @@ mod native_tests {
"mnist_gan",
"smallworm",
"fr_age",
"1d_conv",
];

const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -522,7 +522,7 @@ mod native_tests {
use crate::native_tests::run_js_tests;
use crate::native_tests::render_circuit;
use crate::native_tests::model_serialization_different_binaries;
use rand::Rng;

use tempdir::TempDir;
use ezkl::Commitments;

@@ -543,7 +543,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, false, None, None);
test_dir.close().unwrap();
}
});
@@ -608,7 +608,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -618,22 +618,10 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, true, Some(8194), Some(4));
test_dir.close().unwrap();
}

#(#[test_case(TESTS[N])])*
fn mock_tolerance_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// gen random number between 0.0 and 1.0
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
test_dir.close().unwrap();
}


#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
@@ -644,7 +632,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, false, None, None);
test_dir.close().unwrap();
}
}
@@ -654,7 +642,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -663,7 +651,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -672,7 +660,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -681,7 +669,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -690,7 +678,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -699,7 +687,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -708,7 +696,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -718,7 +706,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -728,7 +716,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -737,7 +725,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -747,7 +735,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -756,7 +744,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -766,7 +754,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -776,7 +764,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -786,7 +774,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -796,7 +784,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -806,7 +794,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}

@@ -965,7 +953,7 @@ mod native_tests {

});

seq!(N in 0..=6 {
seq!(N in 0..=7 {

#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -983,7 +971,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, false, None, Some(5));
test_dir.close().unwrap();
}
});
@@ -1459,12 +1447,10 @@ mod native_tests {
batch_size: usize,
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
bounded_lookup_log: bool,
decomp_base: Option<usize>,
decomp_legs: Option<usize>,
) {
let mut tolerance = tolerance;
gen_circuit_settings_and_witness(
test_dir,
example_name.clone(),
@@ -1475,7 +1461,6 @@ mod native_tests {
cal_target,
scales_to_use,
2,
&mut tolerance,
Commitments::KZG,
2,
bounded_lookup_log,
@@ -1483,128 +1468,17 @@ mod native_tests {
decomp_legs,
);

if tolerance > 0.0 {
// load witness and shift the output by a small amount that is less than tolerance percent
let witness = GraphWitness::from_path(
format!("{}/{}/witness.json", test_dir, example_name).into(),
)
.unwrap();
let witness = witness.clone();
let outputs = witness.outputs.clone();

// get values as i64
let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
sv.iter()
.map(|v| {
// randomly perturb by a small amount less than tolerance
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::zero()
} else {
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
as IntegerRep,
)
};

*v + perturbation
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();

// get values as i64
let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
sv.iter()
.map(|v| {
// randomly perturb by a small amount less than tolerance
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::from(2)
} else {
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
as IntegerRep,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();

let good_witness = GraphWitness {
outputs: output_perturbed_safe,
..witness.clone()
};

// save
good_witness
.save(format!("{}/{}/witness_ok.json", test_dir, example_name).into())
.unwrap();

let bad_witness = GraphWitness {
outputs: output_perturbed_bad,
..witness.clone()
};

// save
bad_witness
.save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
.unwrap();

let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());

let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness_ok.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());

let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness_bad.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(!status.success());
} else {
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}

#[allow(clippy::too_many_arguments)]
@@ -1618,7 +1492,6 @@ mod native_tests {
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
tolerance: &mut f32,
commitment: Commitments,
lookup_safety_margin: usize,
bounded_lookup_log: bool,
@@ -1633,13 +1506,16 @@ mod native_tests {
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
format!("--variables=batch_size->{}", batch_size),
format!(
"--variables=batch_size->{},sequence_length->100,<Sym1>->1",
batch_size
),
format!("--input-visibility={}", input_visibility),
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
format!("--num-inner-cols={}", num_inner_columns),
format!("--tolerance={}", tolerance),
format!("--commitment={}", commitment),
format!("--logrows={}", 22),
];
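The widened --variables flag now carries several symbol bindings in one value; a hypothetical sketch of parsing that name->value syntax (not the actual ezkl CLI code):

fn parse_bindings(s: &str) -> Vec<(String, usize)> {
    s.split(',')
        .filter_map(|kv| kv.split_once("->"))
        .map(|(k, v)| (k.to_string(), v.trim().parse().unwrap_or(1)))
        .collect()
}

fn main() {
    let b = parse_bindings("batch_size->1,sequence_length->100,<Sym1>->1");
    assert_eq!(b[1], ("sequence_length".to_string(), 100));
}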
// if output-visibility is fixed set --range-check-inputs-outputs to False
@@ -1695,24 +1571,6 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());

let mut settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();

let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);

if any_output_scales_smol {
// set the tolerance to 0.0
settings.run_args.tolerance = Tolerance {
val: 0.0,
scale: 0.0.into(),
};
settings
.save(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
*tolerance = 0.0;
}

let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"compile-circuit",
@@ -1725,7 +1583,6 @@ mod native_tests {
test_dir, example_name
),
])
.stdout(std::process::Stdio::null())
.status()
.expect("failed to execute process");
assert!(status.success());
@@ -1768,7 +1625,6 @@ mod native_tests {
cal_target,
None,
2,
&mut 0.0,
Commitments::KZG,
2,
false,
@@ -2054,7 +1910,6 @@ mod native_tests {
target_str,
scales_to_use,
num_inner_columns,
&mut 0.0,
commitment,
lookup_safety_margin,
false,
@@ -2489,7 +2344,6 @@ mod native_tests {
// we need the accuracy
Some(vec![4]),
1,
&mut 0.0,
Commitments::KZG,
2,
false,

@@ -48,7 +48,6 @@ def test_py_run_args():
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "hashed"
run_args.output_visibility = "hashed"
run_args.tolerance = 1.5


def test_poseidon_hash():
@@ -873,7 +872,8 @@ def get_examples():
'linear_regression',
"mnist_gan",
"smallworm",
"fr_age"
"fr_age",
"1d_conv",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -900,7 +900,12 @@ async def test_all_examples(model_file, input_file):
proof_path = os.path.join(folder_path, 'proof.json')

print("Testing example: ", model_file)
res = ezkl.gen_settings(model_file, settings_path)

run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
run_args.logrows = 22

res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
assert res

res = await ezkl.calibrate_settings(