mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
1 Commits
ac/panic-o
...
release-v2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5f01a5bc5 |
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -294,6 +294,8 @@ jobs:
|
||||
run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
|
||||
- name: public outputs and bounded lookup log
|
||||
run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
|
||||
- name: kzg inputs
|
||||
|
||||
@@ -69,7 +69,6 @@ impl Circuit<Fr> for MyCircuit {
|
||||
stride: vec![1, 1],
|
||||
kernel_shape: vec![2, 2],
|
||||
normalized: false,
|
||||
data_format: DataFormat::NCHW,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '20.1.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ use crate::circuit::modules::poseidon::{
|
||||
PoseidonChip,
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::graph::TestDataSource;
|
||||
@@ -155,6 +155,9 @@ impl pyo3::ToPyObject for PyG1Affine {
|
||||
#[derive(Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// float: The tolerance for error on model outputs
|
||||
pub tolerance: f32,
|
||||
#[pyo3(get, set)]
|
||||
/// int: The denominator in the fixed point representation used when quantizing inputs
|
||||
pub input_scale: crate::Scale,
|
||||
@@ -222,6 +225,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
fn from(py_run_args: PyRunArgs) -> Self {
|
||||
RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
tolerance: Tolerance::from(py_run_args.tolerance),
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
num_inner_cols: py_run_args.num_inner_cols,
|
||||
@@ -246,6 +250,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
tolerance: self.tolerance.val,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
num_inner_cols: self.num_inner_cols,
|
||||
|
||||
@@ -20,6 +20,7 @@ use crate::{
|
||||
circuit::{
|
||||
ops::base::BaseOp,
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
@@ -84,6 +85,55 @@ impl CheckMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
|
||||
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
|
||||
pub struct Tolerance {
|
||||
pub val: f32,
|
||||
pub scale: utils::F32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Tolerance {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:.2}", self.val)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl ToFlags for Tolerance {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Tolerance {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if let Ok(val) = s.parse::<f32>() {
|
||||
Ok(Tolerance {
|
||||
val,
|
||||
scale: utils::F32(1.0),
|
||||
})
|
||||
} else {
|
||||
Err(
|
||||
"Invalid tolerance value provided. It should expressed as a percentage (f32)."
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for Tolerance {
|
||||
fn from(value: f32) -> Self {
|
||||
Tolerance {
|
||||
val: value,
|
||||
scale: utils::F32(1.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
|
||||
impl IntoPy<PyObject> for CheckMode {
|
||||
@@ -108,6 +158,29 @@ impl<'source> FromPyObject<'source> for CheckMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
|
||||
impl IntoPy<PyObject> for Tolerance {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
(self.val, self.scale.0).to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Tolerance {
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
|
||||
Ok(Tolerance {
|
||||
val,
|
||||
scale: utils::F32(scale),
|
||||
})
|
||||
} else {
|
||||
Err(PyValueError::new_err("Invalid tolerance value provided. "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A struct representing the selectors for the dynamic lookup tables
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DynamicLookups {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils},
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
|
||||
@@ -79,6 +79,7 @@ pub enum HybridOp {
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
Output {
|
||||
tol: Tolerance,
|
||||
decomp: bool,
|
||||
},
|
||||
Greater,
|
||||
@@ -183,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
input_scale, output_scale, axes
|
||||
)
|
||||
}
|
||||
HybridOp::Output { decomp } => {
|
||||
format!("OUTPUT (decomp={})", decomp)
|
||||
HybridOp::Output { tol, decomp } => {
|
||||
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
|
||||
}
|
||||
HybridOp::Greater => "GREATER".to_string(),
|
||||
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
|
||||
@@ -325,9 +326,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
*output_scale,
|
||||
axes,
|
||||
)?,
|
||||
HybridOp::Output { decomp } => {
|
||||
layouts::output(config, region, values[..].try_into()?, *decomp)?
|
||||
}
|
||||
HybridOp::Output { tol, decomp } => layouts::output(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
tol.scale,
|
||||
tol.val,
|
||||
*decomp,
|
||||
)?,
|
||||
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
|
||||
HybridOp::GreaterEqual => {
|
||||
layouts::greater_equal(config, region, values[..].try_into()?)?
|
||||
|
||||
@@ -5838,12 +5838,14 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[101, 201, 302, 403, 503, 603]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], false).unwrap();
|
||||
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0, false).unwrap();
|
||||
/// ```
|
||||
pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
scale: utils::F32,
|
||||
tol: f32,
|
||||
decomp: bool,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut values = [values[0].clone(), values[1].clone()];
|
||||
@@ -5858,6 +5860,43 @@ pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values[1] = layouts::identity(config, region, &[values[1].clone()], decomp)?;
|
||||
}
|
||||
|
||||
// regular equality constraint
|
||||
return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
|
||||
if tol == 0.0 {
|
||||
// regular equality constraint
|
||||
return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
|
||||
}
|
||||
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
|
||||
|
||||
// integer scale
|
||||
let int_scale = scale.0 as IntegerRep;
|
||||
// felt scale
|
||||
let felt_scale = integer_rep_to_felt(int_scale);
|
||||
// input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
|
||||
let input_scale_ratio = (scale.0 * tol) as IntegerRep / 2 * 2;
|
||||
|
||||
let recip = recip(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone()],
|
||||
felt_scale,
|
||||
felt_scale * F::from(100),
|
||||
)?;
|
||||
|
||||
log::debug!("recip: {}", recip.show());
|
||||
|
||||
// Multiply the difference by the recip
|
||||
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
|
||||
|
||||
log::debug!("product: {}", product.show());
|
||||
let rebased_product = div(
|
||||
config,
|
||||
region,
|
||||
&[product],
|
||||
integer_rep_to_felt(input_scale_ratio),
|
||||
)?;
|
||||
log::debug!("rebased_product: {}", rebased_product.show());
|
||||
|
||||
// check that it is within the tolerance range
|
||||
range_check(config, region, &[rebased_product], &(-int_scale, int_scale))
|
||||
}
|
||||
|
||||
@@ -19,11 +19,6 @@ pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
/// Converts a PrimeField element to an f64.
|
||||
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN as f64;
|
||||
} else if x < -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
panic!("Felt value out of range for conversion to integer rep");
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -36,13 +31,11 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an integer rep.
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN;
|
||||
} else if x < -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
panic!("Felt value out of range for conversion to integer rep");
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
@@ -77,13 +70,6 @@ mod test {
|
||||
assert_eq!(res, F::from(131072));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn felttointegerrep_overflow() {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(IntegerRep::MIN) - F::ONE;
|
||||
let _xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use super::errors::GraphError;
|
||||
use super::extract_const_quantized_values;
|
||||
use super::node::*;
|
||||
use super::scale_to_multiplier;
|
||||
use super::vars::*;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
@@ -1172,10 +1173,17 @@ impl Model {
|
||||
})?;
|
||||
|
||||
if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
|
||||
let output_scales = self.graph.get_output_scales().map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
let res = outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, output)| {
|
||||
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
|
||||
tol.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
let comparators = if run_args.output_visibility == Visibility::Public {
|
||||
let res = vars
|
||||
.instance
|
||||
@@ -1198,6 +1206,7 @@ impl Model {
|
||||
&mut thread_safe_region,
|
||||
&[output.clone(), comparators],
|
||||
Box::new(HybridOp::Output {
|
||||
tol,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}),
|
||||
)
|
||||
@@ -1459,9 +1468,11 @@ impl Model {
|
||||
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
|
||||
|
||||
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
|
||||
let output_scales = self.graph.get_output_scales()?;
|
||||
let res = outputs
|
||||
.iter()
|
||||
.map(|output| {
|
||||
.enumerate()
|
||||
.map(|(i, output)| {
|
||||
let mut comparator: ValTensor<Fp> = (0..output.len())
|
||||
.map(|_| {
|
||||
if !self.visibility.output.is_fixed() {
|
||||
@@ -1474,10 +1485,14 @@ impl Model {
|
||||
.into();
|
||||
comparator.reshape(output.dims())?;
|
||||
|
||||
let mut tol = run_args.tolerance;
|
||||
tol.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[output.clone(), comparator],
|
||||
Box::new(HybridOp::Output {
|
||||
tol,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}),
|
||||
)
|
||||
|
||||
33
src/lib.rs
33
src/lib.rs
@@ -97,7 +97,7 @@ impl From<String> for EZKLError {
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use circuit::{table::Range, CheckMode};
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::Args;
|
||||
use fieldutils::IntegerRep;
|
||||
@@ -274,6 +274,10 @@ impl From<String> for Commitments {
|
||||
derive(Args, ToFlags)
|
||||
)]
|
||||
pub struct RunArgs {
|
||||
/// Error tolerance for model outputs
|
||||
/// Only applicable when outputs are public
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
|
||||
pub tolerance: Tolerance,
|
||||
/// Fixed point scaling factor for quantizing inputs
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
@@ -360,6 +364,7 @@ impl Default for RunArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
tolerance: Tolerance::default(),
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
scale_rebase_multiplier: 1,
|
||||
@@ -411,6 +416,10 @@ impl RunArgs {
|
||||
);
|
||||
}
|
||||
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
|
||||
}
|
||||
|
||||
// Scale validations
|
||||
if self.scale_rebase_multiplier < 1 {
|
||||
errors.push("scale_rebase_multiplier must be >= 1".to_string());
|
||||
@@ -459,6 +468,11 @@ impl RunArgs {
|
||||
warn!("logrows exceeds maximum public SRS size");
|
||||
}
|
||||
|
||||
// Validate tolerance is non-negative
|
||||
if self.tolerance.val < 0.0 {
|
||||
errors.push("tolerance cannot be negative".to_string());
|
||||
}
|
||||
|
||||
// Performance warnings
|
||||
if self.input_scale > 20 || self.param_scale > 20 {
|
||||
warn!("High scale values (>20) may impact performance");
|
||||
@@ -605,6 +619,23 @@ mod tests {
|
||||
assert!(err.contains("num_inner_cols must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = 1.0;
|
||||
args.output_visibility = Visibility::Private;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negative_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = -1.0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("tolerance cannot be negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_batch_size() {
|
||||
let mut args = RunArgs::default();
|
||||
|
||||
Binary file not shown.
@@ -1,7 +1,8 @@
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[cfg(test)]
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
@@ -522,7 +523,7 @@ mod native_tests {
|
||||
use crate::native_tests::run_js_tests;
|
||||
use crate::native_tests::render_circuit;
|
||||
use crate::native_tests::model_serialization_different_binaries;
|
||||
|
||||
use rand::Rng;
|
||||
use tempdir::TempDir;
|
||||
use ezkl::Commitments;
|
||||
|
||||
@@ -543,7 +544,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -608,7 +609,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -618,10 +619,22 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, true, Some(8194), Some(4));
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_tolerance_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
@@ -632,7 +645,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, false, None, None);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -642,7 +655,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -651,7 +664,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -660,7 +673,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -669,7 +682,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -678,7 +691,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -687,7 +700,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -696,7 +709,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -706,7 +719,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -716,7 +729,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -725,7 +738,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -735,7 +748,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -744,7 +757,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -754,7 +767,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -764,7 +777,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -774,7 +787,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -784,7 +797,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -794,7 +807,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, false, None, None);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -971,7 +984,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, false, None, Some(5));
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1447,10 +1460,12 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
example_name.clone(),
|
||||
@@ -1461,6 +1476,7 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
&mut tolerance,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
@@ -1468,17 +1484,128 @@ mod native_tests {
|
||||
decomp_legs,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
|
||||
"-M",
|
||||
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
if tolerance > 0.0 {
|
||||
// load witness and shift the output by a small amount that is less than tolerance percent
|
||||
let witness = GraphWitness::from_path(
|
||||
format!("{}/{}/witness.json", test_dir, example_name).into(),
|
||||
)
|
||||
.unwrap();
|
||||
let witness = witness.clone();
|
||||
let outputs = witness.outputs.clone();
|
||||
|
||||
// get values as i64
|
||||
let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
|
||||
.iter()
|
||||
.map(|sv| {
|
||||
sv.iter()
|
||||
.map(|v| {
|
||||
// randomly perturb by a small amount less than tolerance
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::zero()
|
||||
} else {
|
||||
integer_rep_to_felt(
|
||||
(felt_to_integer_rep(*v) as f32
|
||||
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
|
||||
as IntegerRep,
|
||||
)
|
||||
};
|
||||
|
||||
*v + perturbation
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// get values as i64
|
||||
let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
|
||||
.iter()
|
||||
.map(|sv| {
|
||||
sv.iter()
|
||||
.map(|v| {
|
||||
// randomly perturb by a small amount less than tolerance
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::from(2)
|
||||
} else {
|
||||
integer_rep_to_felt(
|
||||
(felt_to_integer_rep(*v) as f32
|
||||
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
|
||||
as IntegerRep,
|
||||
)
|
||||
};
|
||||
*v + perturbation
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let good_witness = GraphWitness {
|
||||
outputs: output_perturbed_safe,
|
||||
..witness.clone()
|
||||
};
|
||||
|
||||
// save
|
||||
good_witness
|
||||
.save(format!("{}/{}/witness_ok.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
|
||||
let bad_witness = GraphWitness {
|
||||
outputs: output_perturbed_bad,
|
||||
..witness.clone()
|
||||
};
|
||||
|
||||
// save
|
||||
bad_witness
|
||||
.save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
|
||||
"-M",
|
||||
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
format!("{}/{}/witness_ok.json", test_dir, example_name).as_str(),
|
||||
"-M",
|
||||
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
format!("{}/{}/witness_bad.json", test_dir, example_name).as_str(),
|
||||
"-M",
|
||||
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(!status.success());
|
||||
} else {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
|
||||
"-M",
|
||||
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -1492,6 +1619,7 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
tolerance: &mut f32,
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
@@ -1514,6 +1642,7 @@ mod native_tests {
|
||||
format!("--param-visibility={}", param_visibility),
|
||||
format!("--output-visibility={}", output_visibility),
|
||||
format!("--num-inner-cols={}", num_inner_columns),
|
||||
format!("--tolerance={}", tolerance),
|
||||
format!("--commitment={}", commitment),
|
||||
format!("--logrows={}", 22),
|
||||
];
|
||||
@@ -1571,6 +1700,24 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let mut settings =
|
||||
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
|
||||
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
|
||||
|
||||
if any_output_scales_smol {
|
||||
// set the tolerance to 0.0
|
||||
settings.run_args.tolerance = Tolerance {
|
||||
val: 0.0,
|
||||
scale: 0.0.into(),
|
||||
};
|
||||
settings
|
||||
.save(&format!("{}/{}/settings.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
*tolerance = 0.0;
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"compile-circuit",
|
||||
@@ -1625,6 +1772,7 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
@@ -1910,6 +2058,7 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
&mut 0.0,
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
@@ -2344,6 +2493,7 @@ mod native_tests {
|
||||
// we need the accuracy
|
||||
Some(vec![4]),
|
||||
1,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
|
||||
@@ -48,6 +48,7 @@ def test_py_run_args():
|
||||
run_args = ezkl.PyRunArgs()
|
||||
run_args.input_visibility = "hashed"
|
||||
run_args.output_visibility = "hashed"
|
||||
run_args.tolerance = 1.5
|
||||
|
||||
|
||||
def test_poseidon_hash():
|
||||
|
||||
Reference in New Issue
Block a user