Compare commits

...

2 Commits

Author  SHA1        Message                                               Date
dante   80a3c44cb4  feat: lookup-less recip by default (#725)             2024-02-28 16:35:20 +00:00
dante   1656846d1a  fix: transcript should serialize as cli flag (#726)   2024-02-26 22:02:47 +00:00
20 changed files with 460 additions and 400 deletions

Cargo.lock generated
View File

@@ -2263,7 +2263,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
"arrayvec 0.7.4",
"bitvec 1.0.1",
@@ -2280,7 +2280,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
"blake2b_simd",
"env_logger",

View File

@@ -633,7 +633,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -664,7 +664,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
")"
]
},

View File

@@ -277,7 +277,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::div(
layouts::loop_div(
config,
region,
values[..].try_into()?,

View File

@@ -18,10 +18,7 @@ use super::{
region::RegionCtx,
};
use crate::{
circuit::{
ops::base::BaseOp,
utils::{self, F32},
},
circuit::{ops::base::BaseOp, utils},
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{
get_broadcasted_shape,
@@ -54,6 +51,41 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
total_len
}
/// Same as div but splits the division into N parts
pub fn loop_div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
divisor: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if divisor == F::ONE {
return Ok(value[0].clone());
}
// if the integer value of the divisor is divisible by 2 and still larger than 2^(F::S - 4), we can split the division into smaller parts
let mut divisor = divisor;
let mut num_parts = 1;
while felt_to_i128(divisor) % 2 == 0 && felt_to_i128(divisor) > (2_i128.pow(F::S - 4)) {
divisor = i128_to_felt(felt_to_i128(divisor) / 2);
num_parts += 1;
}
let output = div(config, region, value, divisor)?;
if num_parts == 1 {
return Ok(output);
}
let divisor_int = 2_i128.pow(num_parts - 1);
let divisor_felt = i128_to_felt(divisor_int);
if divisor_int <= 2_i128.pow(F::S - 3) {
div(config, region, &[output], divisor_felt)
} else {
// keep splitting the divisor until it satisfies the condition
loop_div(config, region, &[output], divisor_felt)
}
}
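As an aside, here is a minimal standalone sketch (plain i128 arithmetic, no circuit types; the cap value is an illustrative stand-in for 2^(F::S - 4)) of the divisor split performed above:

// Standalone sketch of the divisor split used by `loop_div`: halve the divisor
// while it is even and above the cap, then apply the remaining power-of-two
// factor as a second division.
fn split_divisor(mut divisor: i128, cap: i128) -> (i128, u32) {
    let mut num_parts: u32 = 1;
    while divisor % 2 == 0 && divisor > cap {
        divisor /= 2;
        num_parts += 1;
    }
    (divisor, num_parts)
}

fn main() {
    let cap = 2_i128.pow(24); // illustrative stand-in for 2^(F::S - 4)
    let (reduced, parts) = split_divisor(1_i128 << 30, cap);
    // the original divisor factors as reduced * 2^(parts - 1)
    assert_eq!(reduced * 2_i128.pow(parts - 1), 1_i128 << 30);
    println!("first divide by {reduced}, then by 2^{}", parts - 1);
}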
/// Div accumulated layout
pub fn div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -61,6 +93,10 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
value: &[ValTensor<F>; 1],
div: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if div == F::ONE {
return Ok(value[0].clone());
}
let input = value[0].clone();
let input_dims = input.dims();
@@ -88,6 +124,8 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
.into()
};
claimed_output.reshape(input_dims)?;
region.assign(&config.output, &claimed_output)?;
region.increment(claimed_output.len());
let product = pairwise(
config,
@@ -96,8 +134,6 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Mult,
)?;
log::debug!("product: {:?}", product.get_int_evals()?);
let diff_with_input = pairwise(
config,
region,
@@ -105,8 +141,6 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Sub,
)?;
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
range_check(
config,
region,
@@ -117,6 +151,46 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
Ok(claimed_output)
}
fn recip_int<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
// reciprocal value to substitute where the input is zero
let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
// get values where input is 0
let zero_mask = equals_zero(config, region, input)?;
let one_minus_zero_mask = pairwise(
config,
region,
&[
zero_mask.clone(),
ValTensor::from(Tensor::from([ValType::Constant(F::ONE)].into_iter())),
],
BaseOp::Sub,
)?;
let zero_inverse_val = pairwise(
config,
region,
&[
zero_mask,
ValTensor::from(Tensor::from(
[ValType::Constant(i128_to_felt(zero_inverse_val))].into_iter(),
)),
],
BaseOp::Mult,
)?;
pairwise(
config,
region,
&[one_minus_zero_mask, zero_inverse_val],
BaseOp::Add,
)
}
/// recip accumulated layout
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -125,10 +199,23 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
input_scale: F,
output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if output_scale == F::ONE || output_scale == F::ZERO {
return recip_int(config, region, value);
}
let input = value[0].clone();
let input_dims = input.dims();
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
let integer_input_scale = felt_to_i128(input_scale);
let integer_output_scale = felt_to_i128(output_scale);
// range_check_len is the min of output_scale and 2^(F::S - 4); range_check_bracket is half of it
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
let input_scale_ratio =
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
let range_check_bracket = range_check_len / 2;
let is_assigned = !input.any_unknowns()?;
@@ -151,6 +238,8 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
.into()
};
claimed_output.reshape(input_dims)?;
let claimed_output = region.assign(&config.output, &claimed_output)?;
region.increment(claimed_output.len());
// this is now of scale 2 * scale
let product = pairwise(
@@ -160,15 +249,46 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Mult,
)?;
log::debug!("product: {:?}", product.get_int_evals()?);
// rebase by dividing by the input/output scale ratio
let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?;
log::debug!("range_check_bracket: {:?}", range_check_bracket);
let zero_inverse_val =
tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0];
let zero_inverse =
Tensor::from([ValType::Constant(i128_to_felt::<F>(zero_inverse_val))].into_iter());
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
let equal_inverse_mask = equals(
config,
region,
&[claimed_output.clone(), zero_inverse.into()],
)?;
// assert the two masks are equal
enforce_equality(
config,
region,
&[equal_zero_mask.clone(), equal_inverse_mask],
)?;
let unit_scale = Tensor::from([ValType::Constant(i128_to_felt(range_check_len))].into_iter());
let unit_mask = pairwise(
config,
region,
&[equal_zero_mask, unit_scale.into()],
BaseOp::Mult,
)?;
// now add the unit mask to the rebased_div
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
// at most the error should be in the original unit scale's range
range_check(
config,
region,
&[product],
&[rebased_offset_div],
&(range_check_bracket, 3 * range_check_bracket),
)?;
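A plain-integer sketch of the identity this range check enforces (scales chosen for illustration, not library defaults): the claimed reciprocal times the input, rebased by the scale ratio, should land near range_check_len, i.e. inside (bracket, 3 * bracket):

// Witness-side arithmetic behind the lookup-less reciprocal, sketched with plain i128s.
fn main() {
    let input_scale: i128 = 1 << 12;
    let output_scale: i128 = 1 << 12;
    let range_check_len: i128 = std::cmp::min(output_scale, 1 << 24); // cap is an illustrative stand-in for 2^(F::S - 4)
    let bracket = range_check_len / 2;

    let x: i128 = 3 * input_scale / 2; // a fixed-point input, 1.5 at input_scale
    // prover's claimed reciprocal at output_scale (rounded, as in the witness generation)
    let claimed = ((input_scale * output_scale) as f64 / x as f64).round() as i128;

    // constrained product, then rebase by input_scale * output_scale / range_check_len
    let product = x * claimed;
    let ratio = input_scale * output_scale / range_check_len;
    let rebased = product / ratio;

    // an honest claim lands near range_check_len, i.e. inside (bracket, 3 * bracket)
    assert!(rebased > bracket && rebased < 3 * bracket);
    println!("rebased = {rebased}, expected near {range_check_len}");
}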
@@ -1677,9 +1797,23 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let diff_inverse = diff.inverse()?;
let product_diff_and_invert =
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
equals_zero(config, region, &[diff])
}
/// Equality boolean operation
pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let values = values[0].clone();
let values_inverse = values.inverse()?;
let product_values_and_invert = pairwise(
config,
region,
&[values.clone(), values_inverse],
BaseOp::Mult,
)?;
// constant of 1
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
@@ -1689,12 +1823,12 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
let output = pairwise(
config,
region,
&[ones.into(), product_diff_and_invert],
&[ones.into(), product_values_and_invert],
BaseOp::Sub,
)?;
// take the product of diff and output
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;
is_zero_identity(config, region, &[prod_check], false)?;
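equals_zero follows the standard is-zero pattern: witness an inverse, compute out = 1 - x * x_inv, and constrain x * out = 0. A self-contained sketch over a toy prime field (not the bn256 field the circuit uses):

// Is-zero gadget arithmetic over a toy prime field (p = 2^31 - 1), mirroring `equals_zero`.
const P: u64 = (1u64 << 31) - 1;

fn mul(a: u64, b: u64) -> u64 { (a as u128 * b as u128 % P as u128) as u64 }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn inv(a: u64) -> u64 {
    // Fermat inverse; returns 0 for a == 0, as `ValTensor::inverse` does for zeros.
    let (mut base, mut exp, mut acc) = (a % P, P - 2, 1u64);
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    if a % P == 0 { 0 } else { acc }
}

fn equals_zero(x: u64) -> u64 {
    let out = sub(1, mul(x, inv(x)));
    assert_eq!(mul(x, out), 0); // the constraint the circuit enforces
    out
}

fn main() {
    assert_eq!(equals_zero(0), 1);
    assert_eq!(equals_zero(42), 0);
}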
@@ -1860,7 +1994,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
if normalized {
last_elem = div(
last_elem = loop_div(
config,
region,
&[last_elem],
@@ -2519,6 +2653,17 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if region.throw_range_check_error() {
// assert is within range
let int_values = w.get_int_evals()?;
for v in int_values {
if v < range.0 || v > range.1 {
log::debug!("Value ({:?}) out of range: {:?}", v, range);
return Err(Box::new(TensorError::TableLookupError));
}
}
}
region.increment(assigned_len);
let elapsed = timer.elapsed();
@@ -2945,16 +3090,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
let denom = sum(config, region, &[ex.clone()])?;
// get the inverse
let inv_denom = nonlinearity(
config,
region,
&[denom],
// we set to input scale + output_scale so the output scale is output)scale
&LookupOp::Recip {
input_scale: scale,
output_scale: scale,
},
)?;
let felt_scale = F::from(scale.0 as u64);
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
// product of num * (1 / denom) = 2*output_scale
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
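A plain-float sketch of the fixed-point softmax arithmetic above (the scale is illustrative): exponentials quantised at `scale`, the reciprocal of their sum taken at the same scale, and products that end up at scale squared:

// Fixed-point softmax arithmetic sketched outside the circuit.
fn main() {
    let scale: f64 = 1024.0;
    let xs = [1.0_f64, 2.0, 3.0];

    // exponentials quantised at `scale`
    let ex: Vec<f64> = xs.iter().map(|x| (x.exp() * scale).round()).collect();
    let denom: f64 = ex.iter().sum();

    // reciprocal of the sum, with input scale = output scale = `scale`
    let inv_denom = (scale * scale / denom).round();

    // products are at scale^2; rescaling by scale^2 recovers the probabilities
    let softmax: Vec<f64> = ex.iter().map(|e| e * inv_denom / (scale * scale)).collect();
    let total: f64 = softmax.iter().sum();
    assert!((total - 1.0).abs() < 0.05);
    println!("{softmax:?}");
}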
@@ -2989,29 +3126,44 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let recip = nonlinearity(
// integer scale
let int_scale = scale.0 as i128;
// felt scale
let felt_scale = i128_to_felt(int_scale);
// range check length, capped at 2^(F::S - 5)
let range_check_bracket = std::cmp::min(
utils::F32(scale.0),
utils::F32(2_f32.powf((F::S - 5) as f32)),
)
.0;
let range_check_bracket_int = range_check_bracket as i128;
// scale ratio multiplied by tol so that, after rebasing, range_check_bracket corresponds to a tol-percent error (rounded down to an even integer)
let input_scale_ratio = ((scale.0.powf(2.0) / range_check_bracket) * tol) as i128 / 2 * 2;
let recip = recip(
config,
region,
&[values[0].clone()],
&LookupOp::Recip {
input_scale: scale,
// multiply by 100 to get the percent error
output_scale: F32(scale.0 * 100.0),
},
felt_scale,
felt_scale * F::from(100),
)?;
log::debug!("recip: {}", recip.show());
// Multiply the difference by the recip
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
let rebased_product = div(config, region, &[product], F::from(scale.0 as u64))?;
let scaled_tol = (tol * scale.0) as i128;
log::debug!("product: {}", product.show());
let rebased_product = loop_div(config, region, &[product], i128_to_felt(input_scale_ratio))?;
log::debug!("rebased_product: {}", rebased_product.show());
// check that it is within the tolerance range
range_check(
config,
region,
&[rebased_product],
&(-scaled_tol, scaled_tol),
&(-range_check_bracket_int, range_check_bracket_int),
)
}
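A plain-float sketch of the tolerance arithmetic above (scale and tol are illustrative values, not library defaults): the rebased product lands inside +/- range_check_bracket exactly when the relative error is within tol percent:

// Percent-tolerance check, sketched with plain arithmetic.
fn main() {
    let scale: f64 = 128.0;
    let tol: f64 = 1.0; // percent
    let range_check_bracket = scale.min(2_f64.powi(23)); // cap is an illustrative stand-in for 2^(F::S - 5)

    let expected: f64 = 1000.0;
    let actual: f64 = 1004.0; // 0.4% off

    // circuit-side quantities (everything is really an i128 felt in the circuit)
    let diff = (expected - actual) * scale;
    let recip = (scale * scale * 100.0 / (expected * scale)).round(); // reciprocal at 100x output scale
    let product = diff * recip;
    let input_scale_ratio = ((scale * scale / range_check_bracket) * tol).round();
    let rebased = product / input_scale_ratio;

    // within tolerance iff |rebased| <= range_check_bracket
    assert!(rebased.abs() <= range_check_bracket);
    println!("rebased = {rebased:.1}, bracket = {range_check_bracket}");
}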

View File

@@ -70,8 +70,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
min_range_check: i128,
max_range_check: i128,
max_range_size: i128,
throw_range_check_error: bool,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -80,6 +80,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants += n;
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
@@ -95,8 +100,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context from a wrapped region
@@ -116,13 +121,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context
pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -136,8 +145,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -149,6 +158,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -161,8 +171,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -234,6 +244,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.num_inner_cols,
HashSet::new(),
HashSet::new(),
self.throw_range_check_error,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -310,8 +321,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
return Err("update_max_min_lookup_range: invalid range".into());
}
self.max_range_check = self.max_range_check.max(range.1);
self.min_range_check = self.min_range_check.min(range.0);
let range_size = (range.1 - range.0).abs();
self.max_range_size = self.max_range_size.max(range_size);
Ok(())
}
@@ -371,14 +383,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.min_lookup_inputs
}
/// min range check
pub fn min_range_check(&self) -> i128 {
self.min_range_check
}
/// max range check
pub fn max_range_check(&self) -> i128 {
self.max_range_check
pub fn max_range_size(&self) -> i128 {
self.max_range_size
}
/// Assign a constant value

View File

@@ -133,9 +133,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
}
///
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
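A tiny self-contained check of the column-count formula (the function body is copied from above):

// Illustration of the column count formula: cols = range_len / col_size + 1.
fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    (range_len / (col_size as i128)) as usize + 1
}

fn main() {
    // e.g. a range of length 2^20 spread over columns of 2^17 rows needs 9 columns
    assert_eq!(num_cols_required(1 << 20, 1 << 17), 9);
}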
@@ -152,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required(range, col_size);
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
log::debug!("table range: {:?}", range);
@@ -313,7 +311,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required(range, col_size);
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
let inputs = {
let mut cols = vec![];

View File

@@ -1,4 +1,3 @@
use crate::circuit::ops::hybrid::HybridOp;
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
@@ -2338,113 +2337,3 @@ mod lookup_ultra_overflow {
println!("done.");
}
}
#[cfg(test)]
mod softmax {
use super::*;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const K: usize = 18;
const LEN: usize = 3;
const SCALE: f32 = 128.0;
#[derive(Clone)]
struct SoftmaxCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for SoftmaxCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE);
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Exp {
scale: SCALE.into(),
},
)
.unwrap();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Recip {
input_scale: SCALE.into(),
output_scale: SCALE.into(),
},
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let _output = config
.layout(
&mut region,
&[self.input.clone()],
Box::new(HybridOp::Softmax {
scale: SCALE.into(),
axes: vec![0],
}),
)
.unwrap();
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn softmax_circuit() {
let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = SoftmaxCircuit::<F> {
input: ValTensor::from(input),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}

View File

@@ -471,7 +471,7 @@ pub enum Commands {
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
@@ -526,13 +526,13 @@ pub enum Commands {
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,

View File

@@ -618,7 +618,7 @@ pub(crate) async fn gen_witness(
let start_time = Instant::now();
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?;
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref(), false)?;
// print each variable tuple (symbol, value) as symbol=value
trace!(
@@ -808,16 +808,7 @@ pub(crate) fn calibrate(
// we load the model to get the input and output shapes
// check if gag already exists
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
let model = Model::from_run_args(&settings.run_args, &model_path)?;
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
@@ -833,7 +824,7 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
(10..14).collect::<Vec<crate::Scale>>()
(11..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if only_range_check_rebase {
@@ -896,16 +887,6 @@ pub(crate) fn calibrate(
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
#[cfg(unix)]
let _q = match Gag::stderr() {
Ok(r) => Some(r),
Err(_) => None,
};
let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);
@@ -920,17 +901,12 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
#[cfg(unix)]
std::mem::drop(_q);
debug!("circuit creation from run args failed: {:?}", e);
continue;
}
};
chunks
let forward_res = chunks
.iter()
.map(|chunk| {
let chunk = chunk.clone();
@@ -940,7 +916,7 @@ pub(crate) fn calibrate(
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
.forward(&mut data.clone(), None, None)
.forward(&mut data.clone(), None, None, true)
.map_err(|e| format!("failed to forward: {}", e))?;
// push result to the hashmap
@@ -951,7 +927,16 @@ pub(crate) fn calibrate(
Ok(()) as Result<(), String>
})
.collect::<Result<Vec<()>, String>>()?;
.collect::<Result<Vec<()>, String>>();
match forward_res {
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
continue;
}
}
let min_lookup_range = forward_pass_res
.get(&key)
@@ -969,35 +954,21 @@ pub(crate) fn calibrate(
.max()
.unwrap_or(0);
let min_range_check = forward_pass_res
let max_range_size = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.min_range_check)
.min()
.unwrap_or(0);
let max_range_check = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_check)
.map(|x| x.max_range_size)
.max()
.unwrap_or(0);
let res = circuit.calibrate_from_min_max(
(min_lookup_range, max_lookup_range),
(min_range_check, max_range_check),
max_range_size,
max_logrows,
lookup_safety_margin,
);
// // drop the gag
// #[cfg(unix)]
// std::mem::drop(_r);
// #[cfg(unix)]
// std::mem::drop(_q);
if res.is_ok() {
let new_settings = circuit.settings().clone();

View File

@@ -61,8 +61,11 @@ use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -134,15 +137,16 @@ pub enum GraphError {
MissingResults,
}
const ASSUMED_BLINDING_FACTORS: usize = 5;
///
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
pub const MIN_LOGROWS: u32 = 6;
/// 26
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
/// Lookup deg
pub const LOOKUP_DEG: usize = 5;
///
pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD;
use std::cell::RefCell;
@@ -171,10 +175,8 @@ pub struct GraphWitness {
pub max_lookup_inputs: i128,
/// max lookup input
pub min_lookup_inputs: i128,
/// max range check input
pub max_range_check: i128,
/// max range check input
pub min_range_check: i128,
/// max range check size
pub max_range_size: i128,
}
impl GraphWitness {
@@ -202,8 +204,7 @@ impl GraphWitness {
processed_outputs: None,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
}
}
@@ -376,9 +377,7 @@ impl ToPyObject for GraphWitness {
.unwrap();
dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
.unwrap();
dict.set_item("max_range_check", self.max_range_check)
.unwrap();
dict.set_item("min_range_check", self.min_range_check)
dict.set_item("max_range_size", self.max_range_size)
.unwrap();
if let Some(processed_inputs) = &self.processed_inputs {
@@ -473,6 +472,20 @@ pub struct GraphSettings {
}
impl GraphSettings {
fn model_constraint_logrows(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64).log2().ceil() as u32
}
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self
@@ -1005,10 +1018,6 @@ impl GraphCircuit {
Ok(data)
}
fn reserved_blinding_rows() -> f64 {
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let mut margin = (
lookup_safety_margin * min_max_lookup.0,
@@ -1022,18 +1031,33 @@ impl GraphCircuit {
margin
}
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(
max_logrows as usize,
Self::reserved_blinding_rows() as usize,
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
) -> Result<u32, Box<dyn std::error::Error>> {
// pick the larger absolute size of the two: the span of safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
max_range_size,
);
num_cols_required(safe_range, max_col_size)
let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.)
.log2()
.ceil() as u32;
Ok(min_bits)
}
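A standalone sketch of this bound (reserved_blinding_rows here is an illustrative stand-in, not the crate constant): the logrows needed to hold the larger of the lookup range and the range-check size, plus the reserved blinding rows:

// Table-size bound sketched outside the circuit.
fn table_size_logrows(safe_lookup_range: (i128, i128), max_range_size: i128) -> u32 {
    let reserved_blinding_rows = 8_i128; // illustrative stand-in for RESERVED_BLINDING_ROWS
    let safe_range = std::cmp::max(
        (safe_lookup_range.1 - safe_lookup_range.0).abs(),
        max_range_size,
    );
    ((safe_range + reserved_blinding_rows + 1) as f64).log2().ceil() as u32
}

fn main() {
    // a lookup range of (-2^15, 2^15) dominates a range-check size of 2^12,
    // so the table needs ceil(log2(2^16 + 9)) = 17 logrows
    assert_eq!(table_size_logrows((-(1 << 15), 1 << 15), 1 << 12), 17);
}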
fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
min_max_range_checks: Range,
max_range_size: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -1043,68 +1067,57 @@ impl GraphCircuit {
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
let mut min_logrows = MIN_LOGROWS;
let reserved_blinding_rows = Self::reserved_blinding_rows();
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if has overflowed max lookup input
if min_max_lookup.1.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_max_lookup.0.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
{
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
}
if min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
|| min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
{
let err_string = format!(
"max range check input {:?} is too large",
min_max_range_checks
);
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
}
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// pick the range with the largest absolute size between safe_lookup_range and min_max_range_checks
let safe_range = if (safe_lookup_range.1 - safe_lookup_range.0)
> (min_max_range_checks.1 - min_max_range_checks.0)
{
safe_lookup_range
} else {
min_max_range_checks
};
// These are hard lower limits, we can't overflow instances or modules constraints
let instance_logrows = self.settings().log2_total_instances();
let module_constraint_logrows = self.settings().module_constraint_logrows();
min_logrows = std::cmp::max(
min_logrows,
// max of the instance logrows and the module constraint logrows is the lower limit
[instance_logrows, module_constraint_logrows]
.iter()
.max()
.unwrap()
.clone(),
);
// These are upper limits, going above these is wasteful, but they are not hard limits
let model_constraint_logrows = self.settings().model_constraint_logrows();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
[model_constraint_logrows, min_bits, constants_logrows]
.iter()
.max()
.unwrap()
.clone(),
);
// we now have a min and max logrows
max_logrows = std::cmp::max(min_logrows, max_logrows);
// degrade the max logrows until the extended k is small enough
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
min_logrows,
Self::calc_num_cols(safe_range, min_logrows),
)
{
min_logrows += 1;
}
if !self
.extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows))
{
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
min_logrows
);
debug!("{}", err_string);
return Err(err_string.into());
}
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
max_logrows,
Self::calc_num_cols(safe_range, max_logrows),
)
&& !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size)
{
max_logrows -= 1;
}
if !self
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
{
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
@@ -1113,67 +1126,27 @@ impl GraphCircuit {
return Err(err_string.into());
}
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
.log2()
.ceil() as usize;
let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as usize;
let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints);
// if public input then public inputs col will have public inputs len
if self.settings().run_args.input_visibility.is_public()
|| self.settings().run_args.output_visibility.is_public()
{
let mut max_instance_len = self
.model()
.instance_shapes()?
.iter()
.fold(0, |acc, x| std::cmp::max(acc, x.iter().product::<usize>()))
as f64
+ reserved_blinding_rows;
// if there are modules then we need to add the max module size
if self.settings().uses_modules() {
max_instance_len += self
.settings()
.module_sizes
.num_instances()
.iter()
.sum::<usize>() as f64;
}
let instance_len_logrows = (max_instance_len).log2().ceil() as usize;
logrows = std::cmp::max(logrows, instance_len_logrows);
// this is for fixed const columns
}
// ensure logrows is at least 4
logrows = std::cmp::max(logrows, min_logrows as usize);
logrows = std::cmp::min(logrows, max_logrows as usize);
let logrows = max_logrows;
let model = self.model().clone();
let settings_mut = self.settings_mut();
settings_mut.run_args.lookup_range = safe_lookup_range;
settings_mut.run_args.logrows = logrows as u32;
settings_mut.run_args.logrows = logrows;
*settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
.settings()
.clone();
// recalculate the total const size give nthe new logrows
let total_const_len = settings_mut.total_const_size;
let const_len_logrows = (total_const_len as f64).log2().ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, const_len_logrows);
// recalculate the total number of constraints given the new logrows
let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
// recalculate the logrows if there has been overflow on the constants
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.constants_logrows(),
);
// recalculate the logrows if there has been overflow for the model constraints
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.model_constraint_logrows(),
);
debug!(
"setting lookup_range to: {:?}, setting logrows to: {}",
@@ -1184,12 +1157,37 @@ impl GraphCircuit {
Ok(())
}
fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool {
let max_degree = self.settings().run_args.num_inner_cols + 2;
let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector
fn extended_k_is_small_enough(
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS {
return false;
} else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS {
return false;
}
let max_degree = std::cmp::max(max_degree, max_lookup_degree);
let mut settings = self.settings().clone();
settings.run_args.lookup_range = safe_lookup_range;
settings.run_args.logrows = k;
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
// fetch gag
#[cfg(unix)]
let _r = match gag::Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
Self::configure_with_params(&mut cs, settings);
#[cfg(feature = "mv-lookup")]
let cs = cs.chunk_lookups();
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
let max_degree = cs.degree();
#[cfg(unix)]
std::mem::drop(_r);
let quotient_poly_degree = (max_degree - 1) as u64;
// n = 2^k
let n = 1u64 << k;
@@ -1208,13 +1206,13 @@ impl GraphCircuit {
pub fn calibrate_from_min_max(
&mut self,
min_max_lookup: Range,
min_max_range_checks: Range,
max_range_size: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_max_lookup,
min_max_range_checks,
max_range_size,
max_logrows,
lookup_safety_margin,
)?;
@@ -1227,6 +1225,7 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
throw_range_check_error: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1267,7 +1266,9 @@ impl GraphCircuit {
}
}
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1310,8 +1311,7 @@ impl GraphCircuit {
processed_outputs,
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_check: model_results.max_range_check,
min_range_check: model_results.min_range_check,
max_range_size: model_results.max_range_size,
};
witness.generate_rescaled_elements(

View File

@@ -67,10 +67,8 @@ pub struct ForwardResult {
pub max_lookup_inputs: i128,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
/// The max range check value
pub max_range_check: i128,
/// The min range check value
pub min_range_check: i128,
/// The max range check size
pub max_range_size: i128,
}
impl From<DummyPassRes> for ForwardResult {
@@ -79,8 +77,7 @@ impl From<DummyPassRes> for ForwardResult {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
min_range_check: res.min_range_check,
max_range_check: res.max_range_check,
max_range_size: res.max_range_size,
}
}
}
@@ -115,9 +112,7 @@ pub struct DummyPassRes {
/// min lookup inputs
pub min_lookup_inputs: i128,
/// min range check
pub min_range_check: i128,
/// max range check
pub max_range_check: i128,
pub max_range_size: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -531,7 +526,7 @@ impl Model {
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs)?;
let res = self.dummy_layout(run_args, &inputs, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -570,12 +565,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
Ok(res.into())
}
@@ -1356,6 +1352,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1374,7 +1371,7 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
@@ -1441,8 +1438,7 @@ impl Model {
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
min_range_check: region.min_range_check(),
max_range_check: region.max_range_check(),
max_range_size: region.max_range_size(),
outputs,
};

View File

@@ -734,7 +734,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
use_range_check_for_int: false,
use_range_check_for_int: true,
})
}

View File

@@ -180,6 +180,11 @@ impl RunArgs {
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}
if self.tolerance.val > 0.0 {
if self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
}
Ok(())
}

View File

@@ -197,7 +197,11 @@ impl std::fmt::Display for TranscriptType {
}
}
impl ToFlags for TranscriptType {}
impl ToFlags for TranscriptType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {

View File

@@ -3773,6 +3773,30 @@ pub mod nonlinearities {
.unwrap()
}
/// Elementwise inverse.
/// # Arguments
/// * `out_scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let expected = Tensor::<i128>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<i128> {
let a = Tensor::<i128>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as i128)
})
.unwrap()
}
/// Elementwise greater than
/// # Arguments
///

View File

@@ -211,7 +211,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward(&mut input, None, None)
.forward(&mut input, None, None, false)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)

View File

@@ -2,6 +2,7 @@
#[cfg(test)]
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
@@ -276,7 +277,7 @@ mod native_tests {
"bitshift",
];
const WASM_TESTS: [&str; 48] = [
const WASM_TESTS: [&str; 46] = [
"1l_mlp",
"1l_slice",
"1l_concat",
@@ -325,8 +326,6 @@ mod native_tests {
"1l_where",
"boolean",
"boolean_identity",
"decision_tree", // "variable_cnn",
"random_forest",
"gradient_boosted_trees",
"1l_topk",
// "xgboost",
@@ -586,6 +585,8 @@ mod native_tests {
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
@@ -841,7 +842,7 @@ mod native_tests {
});
seq!(N in 0..=47 {
seq!(N in 0..=45 {
#(#[test_case(WASM_TESTS[N])])*
fn kzg_prove_and_verify_with_overflow_(test: &str) {
@@ -1288,6 +1289,7 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
) {
let mut tolerance = tolerance;
gen_circuit_settings_and_witness(
test_dir,
example_name.clone(),
@@ -1299,16 +1301,10 @@ mod native_tests {
scales_to_use,
2,
false,
tolerance,
&mut tolerance,
);
let settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
if tolerance > 0.0 && !any_output_scales_smol {
if tolerance > 0.0 {
// load witness and shift the output by a small amount that is less than tolerance percent
let witness = GraphWitness::from_path(
format!("{}/{}/witness.json", test_dir, example_name).into(),
@@ -1333,7 +1329,7 @@ mod native_tests {
as i128,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
@@ -1444,7 +1440,7 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
div_rebasing: bool,
tolerance: f32,
tolerance: &mut f32,
) {
let mut args = vec![
"gen-settings".to_string(),
@@ -1502,6 +1498,24 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let mut settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
if any_output_scales_smol {
// set the tolerance to 0.0
settings.run_args.tolerance = Tolerance {
val: 0.0.into(),
scale: 0.0.into(),
};
settings
.save(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
*tolerance = 0.0;
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"compile-circuit",
@@ -1559,7 +1573,7 @@ mod native_tests {
None,
2,
div_rebasing,
0.0,
&mut 0.0,
);
println!(
@@ -1819,7 +1833,7 @@ mod native_tests {
scales_to_use,
num_inner_columns,
false,
0.0,
&mut 0.0,
);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
@@ -1921,7 +1935,7 @@ mod native_tests {
None,
2,
false,
0.0,
&mut 0.0,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -2198,7 +2212,7 @@ mod native_tests {
Some(vec![4]),
1,
false,
0.0,
&mut 0.0,
);
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);

View File

@@ -91,9 +91,7 @@ def compare_outputs(zk_output, onnx_output):
print("------- zk_output: ", list1_i)
print("------- onnx_output: ", list2_i)
return np.mean(np.abs(res))
return res
if __name__ == '__main__':
@@ -113,6 +111,9 @@ if __name__ == '__main__':
onnx_output = get_onnx_output(model_file, input_file)
# compare the outputs
percentage_difference = compare_outputs(ezkl_output, onnx_output)
mean_percentage_difference = np.mean(np.abs(percentage_difference))
max_percentage_difference = np.max(np.abs(percentage_difference))
# print the percentage difference
print("mean percent diff: ", percentage_difference)
assert percentage_difference < target, "Percentage difference is too high"
print("mean percent diff: ", mean_percentage_difference)
print("max percent diff: ", max_percentage_difference)
assert mean_percentage_difference < target, "Percentage difference is too high"

Binary file not shown.

View File

@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_check":0,"min_range_check":0}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}