Compare commits

...

3 Commits

Author SHA1 Message Date
dante
14786acb95 feat: dynamic lookups (#727) 2024-03-01 01:44:45 +00:00
dante
80a3c44cb4 feat: lookup-less recip by default (#725) 2024-02-28 16:35:20 +00:00
dante
1656846d1a fix: transcript should serialize as lc flag (#726) 2024-02-26 22:02:47 +00:00
23 changed files with 915 additions and 463 deletions

4
Cargo.lock generated
View File

@@ -2263,7 +2263,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
"arrayvec 0.7.4",
"bitvec 1.0.1",
@@ -2280,7 +2280,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
"blake2b_simd",
"env_logger",

View File

@@ -633,7 +633,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -664,7 +664,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
")"
]
},

View File

@@ -198,6 +198,9 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub lookup_input: VarTensor,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// The VarTensor reserved for dynamic lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub dynamic_lookup_tables: Vec<VarTensor>,
/// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_output: VarTensor,
@@ -207,6 +210,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub dynamic_lookup_selectors: BTreeMap<(usize, usize), Vec<Selector>>,
///
pub dynamic_table_selectors: Vec<Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
@@ -228,9 +235,12 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
lookup_input: dummy_var.clone(),
output: dummy_var.clone(),
lookup_output: dummy_var.clone(),
dynamic_lookup_tables: vec![VarTensor::dummy(col_size, 2), dummy_var.clone()],
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
dynamic_lookup_selectors: BTreeMap::new(),
dynamic_table_selectors: vec![],
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
@@ -376,10 +386,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
dynamic_lookup_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
dynamic_table_selectors: vec![],
dynamic_lookup_tables: vec![],
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
@@ -403,8 +416,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
}
@@ -514,10 +525,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
selectors.insert((nl.clone(), x, y), multi_col_selector);
self.lookup_selectors
.insert((nl.clone(), x, y), multi_col_selector);
}
}
self.lookup_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
@@ -534,6 +545,85 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_dynamic_lookup(
&mut self,
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 2],
tables: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
let s_ltable = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.dynamic_lookup_selectors
.entry((x, y))
.or_default()
.push(s_lookup);
}
}
self.dynamic_table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookup_tables.is_empty() {
debug!("assigning dynamic lookup table");
self.dynamic_lookup_tables = tables.to_vec();
}
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_range_check(
@@ -547,8 +637,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
}
@@ -620,10 +708,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
selectors.insert((range, x, y), multi_col_selector);
self.range_check_selectors
.insert((range, x, y), multi_col_selector);
}
}
self.range_check_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");

View File

@@ -277,7 +277,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::div(
layouts::loop_div(
config,
region,
values[..].try_into()?,

View File

@@ -18,10 +18,7 @@ use super::{
region::RegionCtx,
};
use crate::{
circuit::{
ops::base::BaseOp,
utils::{self, F32},
},
circuit::{ops::base::BaseOp, utils},
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{
get_broadcasted_shape,
@@ -54,6 +51,41 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
total_len
}
/// Same as div but splits the division into N parts
pub fn loop_div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
divisor: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if divisor == F::ONE {
return Ok(value[0].clone());
}
// if integer val is divisible by 2, we can use a faster method and div > F::S
let mut divisor = divisor;
let mut num_parts = 1;
while felt_to_i128(divisor) % 2 == 0 && felt_to_i128(divisor) > (2_i128.pow(F::S - 4)) {
divisor = i128_to_felt(felt_to_i128(divisor) / 2);
num_parts += 1;
}
let output = div(config, region, value, divisor)?;
if num_parts == 1 {
return Ok(output);
}
let divisor_int = 2_i128.pow(num_parts - 1);
let divisor_felt = i128_to_felt(divisor_int);
if divisor_int <= 2_i128.pow(F::S - 3) {
div(config, region, &[output], divisor_felt)
} else {
// keep splitting the divisor until it satisfies the condition
loop_div(config, region, &[output], divisor_felt)
}
}
/// Div accumulated layout
pub fn div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -61,6 +93,10 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
value: &[ValTensor<F>; 1],
div: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if div == F::ONE {
return Ok(value[0].clone());
}
let input = value[0].clone();
let input_dims = input.dims();
@@ -88,6 +124,8 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
.into()
};
claimed_output.reshape(input_dims)?;
region.assign(&config.output, &claimed_output)?;
region.increment(claimed_output.len());
let product = pairwise(
config,
@@ -96,8 +134,6 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Mult,
)?;
log::debug!("product: {:?}", product.get_int_evals()?);
let diff_with_input = pairwise(
config,
region,
@@ -105,8 +141,6 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Sub,
)?;
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
range_check(
config,
region,
@@ -117,6 +151,46 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
Ok(claimed_output)
}
fn recip_int<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assert is boolean
let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
// get values where input is 0
let zero_mask = equals_zero(config, region, input)?;
let one_minus_zero_mask = pairwise(
config,
region,
&[
zero_mask.clone(),
ValTensor::from(Tensor::from([ValType::Constant(F::ONE)].into_iter())),
],
BaseOp::Sub,
)?;
let zero_inverse_val = pairwise(
config,
region,
&[
zero_mask,
ValTensor::from(Tensor::from(
[ValType::Constant(i128_to_felt(zero_inverse_val))].into_iter(),
)),
],
BaseOp::Mult,
)?;
pairwise(
config,
region,
&[one_minus_zero_mask, zero_inverse_val],
BaseOp::Add,
)
}
/// recip accumulated layout
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -125,10 +199,23 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
input_scale: F,
output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if output_scale == F::ONE || output_scale == F::ZERO {
return recip_int(config, region, value);
}
let input = value[0].clone();
let input_dims = input.dims();
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
let integer_input_scale = felt_to_i128(input_scale);
let integer_output_scale = felt_to_i128(output_scale);
// range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
let input_scale_ratio =
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
let range_check_bracket = range_check_len / 2;
let is_assigned = !input.any_unknowns()?;
@@ -151,6 +238,8 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
.into()
};
claimed_output.reshape(input_dims)?;
let claimed_output = region.assign(&config.output, &claimed_output)?;
region.increment(claimed_output.len());
// this is now of scale 2 * scale
let product = pairwise(
@@ -160,15 +249,46 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Mult,
)?;
log::debug!("product: {:?}", product.get_int_evals()?);
// divide by input_scale
let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?;
log::debug!("range_check_bracket: {:?}", range_check_bracket);
let zero_inverse_val =
tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0];
let zero_inverse =
Tensor::from([ValType::Constant(i128_to_felt::<F>(zero_inverse_val))].into_iter());
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
let equal_inverse_mask = equals(
config,
region,
&[claimed_output.clone(), zero_inverse.into()],
)?;
// assert the two masks are equal
enforce_equality(
config,
region,
&[equal_zero_mask.clone(), equal_inverse_mask],
)?;
let unit_scale = Tensor::from([ValType::Constant(i128_to_felt(range_check_len))].into_iter());
let unit_mask = pairwise(
config,
region,
&[equal_zero_mask, unit_scale.into()],
BaseOp::Mult,
)?;
// now add the unit mask to the rebased_div
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
// at most the error should be in the original unit scale's range
range_check(
config,
region,
&[product],
&[rebased_offset_div],
&(range_check_bracket, 3 * range_check_bracket),
)?;
@@ -795,6 +915,78 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
Ok(assigned_output)
}
/// Dynamic lookup
pub fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
// if not all lookups same length err
if lookups[0].len() != lookups[1].len() {
return Err("lookups must be same length".into());
}
// if not all inputs same length err
if tables[0].len() != tables[1].len() {
return Err("inputs must be same length".into());
}
// now assert the inputs of a smaller length than the lookups
if tables[0].len() > tables[0].len() {
return Err("inputs must be smaller length than dynamic lookups".into());
}
let (lookup_0, lookup_1) = (lookups[0].clone(), lookups[1].clone());
let (table_0, table_1) = (tables[0].clone(), tables[1].clone());
let table_0 = region.assign_dynamic_lookup(&config.dynamic_lookup_tables[0], &table_0)?;
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookup_tables[1], &table_1)?;
let table_len = table_0.len();
let lookup_0 = region.assign(&config.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
if !region.is_dummy() {
(0..table_len)
.map(|i| {
let dynamic_lookup_index = region.dynamic_lookup_index();
let table_selector = config.dynamic_table_selectors[dynamic_lookup_index];
let (_, _, z) = config.dynamic_lookup_tables[0]
.cartesian_coord(region.dynamic_lookup_col_coord() + i);
region.enable(Some(&table_selector), z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if !region.is_dummy() {
// Enable the selectors
(0..lookup_len)
.map(|i| {
let (x, y, z) = config.inputs[0].cartesian_coord(region.linear_coord() + i);
let dynamic_lookup_index = region.dynamic_lookup_index();
let lookup_selector = config
.dynamic_lookup_selectors
.get(&(x, y))
.ok_or("missing selectors")?[dynamic_lookup_index];
region.enable(Some(&lookup_selector), z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment_dynamic_lookup_col_coord(table_len);
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);
Ok((lookup_0, lookup_1))
}
/// One hot accumulated layout
pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1677,9 +1869,23 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let diff_inverse = diff.inverse()?;
let product_diff_and_invert =
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
equals_zero(config, region, &[diff])
}
/// Equality boolean operation
pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let values = values[0].clone();
let values_inverse = values.inverse()?;
let product_values_and_invert = pairwise(
config,
region,
&[values.clone(), values_inverse],
BaseOp::Mult,
)?;
// constant of 1
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
@@ -1689,12 +1895,12 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
let output = pairwise(
config,
region,
&[ones.into(), product_diff_and_invert],
&[ones.into(), product_values_and_invert],
BaseOp::Sub,
)?;
// take the product of diff and output
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;
is_zero_identity(config, region, &[prod_check], false)?;
@@ -1860,7 +2066,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
if normalized {
last_elem = div(
last_elem = loop_div(
config,
region,
&[last_elem],
@@ -2519,6 +2725,17 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if region.throw_range_check_error() {
// assert is within range
let int_values = w.get_int_evals()?;
for v in int_values {
if v < range.0 || v > range.1 {
log::debug!("Value ({:?}) out of range: {:?}", v, range);
return Err(Box::new(TensorError::TableLookupError));
}
}
}
region.increment(assigned_len);
let elapsed = timer.elapsed();
@@ -2945,16 +3162,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
let denom = sum(config, region, &[ex.clone()])?;
// get the inverse
let inv_denom = nonlinearity(
config,
region,
&[denom],
// we set to input scale + output_scale so the output scale is output)scale
&LookupOp::Recip {
input_scale: scale,
output_scale: scale,
},
)?;
let felt_scale = F::from(scale.0 as u64);
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
// product of num * (1 / denom) = 2*output_scale
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
@@ -2989,29 +3198,44 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let recip = nonlinearity(
// integer scale
let int_scale = scale.0 as i128;
// felt scale
let felt_scale = i128_to_felt(int_scale);
// range check len capped at 2^(S-3) and make it divisible 2
let range_check_bracket = std::cmp::min(
utils::F32(scale.0),
utils::F32(2_f32.powf((F::S - 5) as f32)),
)
.0;
let range_check_bracket_int = range_check_bracket as i128;
// input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
let input_scale_ratio = ((scale.0.powf(2.0) / range_check_bracket) * tol) as i128 / 2 * 2;
let recip = recip(
config,
region,
&[values[0].clone()],
&LookupOp::Recip {
input_scale: scale,
// multiply by 100 to get the percent error
output_scale: F32(scale.0 * 100.0),
},
felt_scale,
felt_scale * F::from(100),
)?;
log::debug!("recip: {}", recip.show());
// Multiply the difference by the recip
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
let rebased_product = div(config, region, &[product], F::from(scale.0 as u64))?;
let scaled_tol = (tol * scale.0) as i128;
log::debug!("product: {}", product.show());
let rebased_product = loop_div(config, region, &[product], i128_to_felt(input_scale_ratio))?;
log::debug!("rebased_product: {}", rebased_product.show());
// check that it is within the tolerance range
range_check(
config,
region,
&[rebased_product],
&(-scaled_tol, scaled_tol),
&(-range_check_bracket_int, range_check_bracket_int),
)
}

View File

@@ -20,6 +20,33 @@ use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
lookup_index: usize,
col_coord: usize,
}
impl DynamicLookupIndex {
/// Create a new dynamic lookup index
pub fn new(lookup_index: usize, col_coord: usize) -> DynamicLookupIndex {
DynamicLookupIndex {
lookup_index,
col_coord,
}
}
/// Get the lookup index
pub fn lookup_index(&self) -> usize {
self.lookup_index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
}
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
@@ -66,12 +93,13 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
min_range_check: i128,
max_range_check: i128,
max_range_size: i128,
throw_range_check_error: bool,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -80,6 +108,21 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants += n;
}
///
pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
self.dynamic_lookup_index.lookup_index += n;
}
///
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
self.dynamic_lookup_index.col_coord += n;
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
@@ -91,12 +134,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context from a wrapped region
@@ -104,6 +148,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
@@ -112,17 +157,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context
pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -132,12 +182,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -147,8 +198,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -157,12 +210,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants,
dynamic_lookup_index,
used_lookups,
used_range_checks,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -217,6 +271,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
*output = output
.par_enum_map(|idx, _| {
@@ -232,8 +287,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
starting_linear_coord,
starting_constants,
self.num_inner_cols,
DynamicLookupIndex::default(),
HashSet::new(),
HashSet::new(),
self.throw_range_check_error,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -254,6 +311,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
lookups.extend(local_reg.used_lookups());
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.lookup_index += local_reg.dynamic_lookup_index.lookup_index;
dynamic_lookup_index.col_coord += local_reg.dynamic_lookup_index.col_coord;
res
})
.map_err(|e| {
@@ -282,6 +343,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
Ok(())
}
@@ -310,8 +385,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
return Err("update_max_min_lookup_range: invalid range".into());
}
self.max_range_check = self.max_range_check.max(range.1);
self.min_range_check = self.min_range_check.min(range.0);
let range_size = (range.1 - range.0).abs();
self.max_range_size = self.max_range_size.max(range_size);
Ok(())
}
@@ -351,6 +427,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants
}
/// Get the dynamic lookup index
pub fn dynamic_lookup_index(&self) -> usize {
self.dynamic_lookup_index.lookup_index
}
/// Get the dynamic lookup column coordinate
pub fn dynamic_lookup_col_coord(&self) -> usize {
self.dynamic_lookup_index.col_coord
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
@@ -371,14 +457,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.min_lookup_inputs
}
/// min range check
pub fn min_range_check(&self) -> i128 {
self.min_range_check
}
/// max range check
pub fn max_range_check(&self) -> i128 {
self.max_range_check
pub fn max_range_size(&self) -> i128 {
self.max_range_size
}
/// Assign a constant value
@@ -405,6 +486,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
}
/// Assign a valtensor to a vartensor
pub fn assign_dynamic_lookup(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.dynamic_lookup_col_coord(),
values,
)
} else {
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,

View File

@@ -133,9 +133,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
}
///
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
@@ -152,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required(range, col_size);
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
log::debug!("table range: {:?}", range);
@@ -313,7 +311,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required(range, col_size);
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
let inputs = {
let mut cols = vec![];

View File

@@ -1,4 +1,3 @@
use crate::circuit::ops::hybrid::HybridOp;
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
@@ -1575,6 +1574,142 @@ mod add {
}
}
#[cfg(test)]
mod dynamic_lookup {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
tables: [[ValTensor<F>; 2]; NUM_LOOP],
lookups: [[ValTensor<F>; 2]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
for _ in 0..NUM_LOOP {
config
.configure_dynamic_lookup(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
}
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.dynamic_lookup_col_coord(),
NUM_LOOP * self.tables[0][0].len()
);
assert_eq!(region.dynamic_lookup_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn dynamiclookupcircuit() {
// parameters
let tables = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.clone().try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..2).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from((0..2).map(|_| Value::known(F::from(10000))))),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
#[cfg(test)]
mod add_with_overflow {
use super::*;
@@ -2338,113 +2473,3 @@ mod lookup_ultra_overflow {
println!("done.");
}
}
#[cfg(test)]
mod softmax {
use super::*;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const K: usize = 18;
const LEN: usize = 3;
const SCALE: f32 = 128.0;
#[derive(Clone)]
struct SoftmaxCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for SoftmaxCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE);
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Exp {
scale: SCALE.into(),
},
)
.unwrap();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Recip {
input_scale: SCALE.into(),
output_scale: SCALE.into(),
},
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let _output = config
.layout(
&mut region,
&[self.input.clone()],
Box::new(HybridOp::Softmax {
scale: SCALE.into(),
axes: vec![0],
}),
)
.unwrap();
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn softmax_circuit() {
let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = SoftmaxCircuit::<F> {
input: ValTensor::from(input),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}

View File

@@ -471,7 +471,7 @@ pub enum Commands {
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
@@ -526,13 +526,13 @@ pub enum Commands {
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,

View File

@@ -618,7 +618,7 @@ pub(crate) async fn gen_witness(
let start_time = Instant::now();
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?;
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref(), false)?;
// print each variable tuple (symbol, value) as symbol=value
trace!(
@@ -808,16 +808,7 @@ pub(crate) fn calibrate(
// we load the model to get the input and output shapes
// check if gag already exists
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
let model = Model::from_run_args(&settings.run_args, &model_path)?;
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
@@ -833,7 +824,7 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
(10..14).collect::<Vec<crate::Scale>>()
(11..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if only_range_check_rebase {
@@ -896,16 +887,6 @@ pub(crate) fn calibrate(
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
#[cfg(unix)]
let _q = match Gag::stderr() {
Ok(r) => Some(r),
Err(_) => None,
};
let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);
@@ -920,17 +901,12 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
#[cfg(unix)]
std::mem::drop(_q);
debug!("circuit creation from run args failed: {:?}", e);
continue;
}
};
chunks
let forward_res = chunks
.iter()
.map(|chunk| {
let chunk = chunk.clone();
@@ -940,7 +916,7 @@ pub(crate) fn calibrate(
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
.forward(&mut data.clone(), None, None)
.forward(&mut data.clone(), None, None, true)
.map_err(|e| format!("failed to forward: {}", e))?;
// push result to the hashmap
@@ -951,7 +927,16 @@ pub(crate) fn calibrate(
Ok(()) as Result<(), String>
})
.collect::<Result<Vec<()>, String>>()?;
.collect::<Result<Vec<()>, String>>();
match forward_res {
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
continue;
}
}
let min_lookup_range = forward_pass_res
.get(&key)
@@ -969,35 +954,21 @@ pub(crate) fn calibrate(
.max()
.unwrap_or(0);
let min_range_check = forward_pass_res
let max_range_size = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.min_range_check)
.min()
.unwrap_or(0);
let max_range_check = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_check)
.map(|x| x.max_range_size)
.max()
.unwrap_or(0);
let res = circuit.calibrate_from_min_max(
let res = circuit.calc_min_logrows(
(min_lookup_range, max_lookup_range),
(min_range_check, max_range_check),
max_range_size,
max_logrows,
lookup_safety_margin,
);
// // drop the gag
// #[cfg(unix)]
// std::mem::drop(_r);
// #[cfg(unix)]
// std::mem::drop(_q);
if res.is_ok() {
let new_settings = circuit.settings().clone();

View File

@@ -61,8 +61,11 @@ use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -134,15 +137,16 @@ pub enum GraphError {
MissingResults,
}
const ASSUMED_BLINDING_FACTORS: usize = 5;
///
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
pub const MIN_LOGROWS: u32 = 6;
/// 26
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
/// Lookup deg
pub const LOOKUP_DEG: usize = 5;
///
pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD;
use std::cell::RefCell;
@@ -171,10 +175,8 @@ pub struct GraphWitness {
pub max_lookup_inputs: i128,
/// max lookup input
pub min_lookup_inputs: i128,
/// max range check input
pub max_range_check: i128,
/// max range check input
pub min_range_check: i128,
/// max range check size
pub max_range_size: i128,
}
impl GraphWitness {
@@ -202,8 +204,7 @@ impl GraphWitness {
processed_outputs: None,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
max_range_size: 0,
}
}
@@ -376,9 +377,7 @@ impl ToPyObject for GraphWitness {
.unwrap();
dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
.unwrap();
dict.set_item("max_range_check", self.max_range_check)
.unwrap();
dict.set_item("min_range_check", self.min_range_check)
dict.set_item("max_range_size", self.max_range_size)
.unwrap();
if let Some(processed_inputs) = &self.processed_inputs {
@@ -450,6 +449,10 @@ pub struct GraphSettings {
pub total_assignments: usize,
/// total const size
pub total_const_size: usize,
/// total dynamic column size
pub total_dynamic_col_size: usize,
/// number of dynamic lookups
pub num_dynamic_lookups: usize,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -473,6 +476,24 @@ pub struct GraphSettings {
}
impl GraphSettings {
fn model_constraint_logrows(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_logrows(&self) -> u32 {
(self.total_dynamic_col_size as f64).log2().ceil() as u32
}
fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64).log2().ceil() as u32
}
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self
@@ -557,6 +578,11 @@ impl GraphSettings {
|| self.run_args.param_visibility.is_hashed()
}
/// requires dynamic lookup
pub fn requires_dynamic_lookup(&self) -> bool {
self.num_dynamic_lookups > 0
}
/// any kzg visibility
pub fn module_requires_kzg(&self) -> bool {
self.run_args.input_visibility.is_kzgcommit()
@@ -1005,10 +1031,6 @@ impl GraphCircuit {
Ok(data)
}
fn reserved_blinding_rows() -> f64 {
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let mut margin = (
lookup_safety_margin * min_max_lookup.0,
@@ -1022,18 +1044,34 @@ impl GraphCircuit {
margin
}
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(
max_logrows as usize,
Self::reserved_blinding_rows() as usize,
);
num_cols_required(safe_range, max_col_size)
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
fn calc_min_logrows(
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
) -> Result<u32, Box<dyn std::error::Error>> {
// pick the range with the largest absolute size safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
max_range_size,
);
let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.)
.log2()
.ceil() as u32;
Ok(min_bits)
}
/// calculate the minimum logrows required for the circuit
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
min_max_range_checks: Range,
max_range_size: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -1043,68 +1081,60 @@ impl GraphCircuit {
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
let mut min_logrows = MIN_LOGROWS;
let reserved_blinding_rows = Self::reserved_blinding_rows();
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if has overflowed max lookup input
if min_max_lookup.1.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_max_lookup.0.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
{
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
}
if min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
|| min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
{
let err_string = format!(
"max range check input {:?} is too large",
min_max_range_checks
);
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
}
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// pick the range with the largest absolute size between safe_lookup_range and min_max_range_checks
let safe_range = if (safe_lookup_range.1 - safe_lookup_range.0)
> (min_max_range_checks.1 - min_max_range_checks.0)
{
safe_lookup_range
} else {
min_max_range_checks
};
// These are hard lower limits, we can't overflow instances or modules constraints
let instance_logrows = self.settings().log2_total_instances();
let module_constraint_logrows = self.settings().module_constraint_logrows();
let dynamic_lookup_logrows = self.settings().dynamic_lookup_logrows();
min_logrows = std::cmp::max(
min_logrows,
// max of the instance logrows and the module constraint logrows and the dynamic lookup logrows is the lower limit
*[
instance_logrows,
module_constraint_logrows,
dynamic_lookup_logrows,
]
.iter()
.max()
.unwrap(),
);
// These are upper limits, going above these is wasteful, but they are not hard limits
let model_constraint_logrows = self.settings().model_constraint_logrows();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
*[model_constraint_logrows, min_bits, constants_logrows]
.iter()
.max()
.unwrap(),
);
// we now have a min and max logrows
max_logrows = std::cmp::max(min_logrows, max_logrows);
// degrade the max logrows until the extended k is small enough
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
min_logrows,
Self::calc_num_cols(safe_range, min_logrows),
)
{
min_logrows += 1;
}
if !self
.extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows))
{
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
min_logrows
);
debug!("{}", err_string);
return Err(err_string.into());
}
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
max_logrows,
Self::calc_num_cols(safe_range, max_logrows),
)
&& !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size)
{
max_logrows -= 1;
}
if !self
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
{
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
@@ -1113,67 +1143,27 @@ impl GraphCircuit {
return Err(err_string.into());
}
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
.log2()
.ceil() as usize;
let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as usize;
let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints);
// if public input then public inputs col will have public inputs len
if self.settings().run_args.input_visibility.is_public()
|| self.settings().run_args.output_visibility.is_public()
{
let mut max_instance_len = self
.model()
.instance_shapes()?
.iter()
.fold(0, |acc, x| std::cmp::max(acc, x.iter().product::<usize>()))
as f64
+ reserved_blinding_rows;
// if there are modules then we need to add the max module size
if self.settings().uses_modules() {
max_instance_len += self
.settings()
.module_sizes
.num_instances()
.iter()
.sum::<usize>() as f64;
}
let instance_len_logrows = (max_instance_len).log2().ceil() as usize;
logrows = std::cmp::max(logrows, instance_len_logrows);
// this is for fixed const columns
}
// ensure logrows is at least 4
logrows = std::cmp::max(logrows, min_logrows as usize);
logrows = std::cmp::min(logrows, max_logrows as usize);
let logrows = max_logrows;
let model = self.model().clone();
let settings_mut = self.settings_mut();
settings_mut.run_args.lookup_range = safe_lookup_range;
settings_mut.run_args.logrows = logrows as u32;
settings_mut.run_args.logrows = logrows;
*settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
.settings()
.clone();
// recalculate the total const size give nthe new logrows
let total_const_len = settings_mut.total_const_size;
let const_len_logrows = (total_const_len as f64).log2().ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, const_len_logrows);
// recalculate the total number of constraints given the new logrows
let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
// recalculate the logrows if there has been overflow on the constants
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.constants_logrows(),
);
// recalculate the logrows if there has been overflow for the model constraints
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.model_constraint_logrows(),
);
debug!(
"setting lookup_range to: {:?}, setting logrows to: {}",
@@ -1184,12 +1174,29 @@ impl GraphCircuit {
Ok(())
}
fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool {
let max_degree = self.settings().run_args.num_inner_cols + 2;
let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector
fn extended_k_is_small_enough(
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS {
return false;
} else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS {
return false;
}
let max_degree = std::cmp::max(max_degree, max_lookup_degree);
let mut settings = self.settings().clone();
settings.run_args.lookup_range = safe_lookup_range;
settings.run_args.logrows = k;
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
Self::configure_with_params(&mut cs, settings);
#[cfg(feature = "mv-lookup")]
let cs = cs.chunk_lookups();
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
let max_degree = cs.degree();
let quotient_poly_degree = (max_degree - 1) as u64;
// n = 2^k
let n = 1u64 << k;
@@ -1204,29 +1211,13 @@ impl GraphCircuit {
true
}
/// Calibrate the circuit to the supplied data.
pub fn calibrate_from_min_max(
&mut self,
min_max_lookup: Range,
min_max_range_checks: Range,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_max_lookup,
min_max_range_checks,
max_logrows,
lookup_safety_margin,
)?;
Ok(())
}
/// Runs the forward pass of the model / graph of computations and any associated hashing.
pub fn forward(
&self,
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
throw_range_check_error: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1267,7 +1258,9 @@ impl GraphCircuit {
}
}
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1310,8 +1303,7 @@ impl GraphCircuit {
processed_outputs,
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_check: model_results.max_range_check,
min_range_check: model_results.min_range_check,
max_range_size: model_results.max_range_size,
};
witness.generate_rescaled_elements(
@@ -1518,34 +1510,18 @@ impl Circuit<Fp> for GraphCircuit {
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(
cs,
params.run_args.logrows as usize,
params.total_assignments,
params.run_args.num_inner_cols,
params.total_const_size,
params.module_requires_fixed(),
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
vars.instantiate_instance(
cs,
params.model_instance_shapes,
params.model_instance_shapes.clone(),
params.run_args.input_scale,
module_configs.instance,
);
let base = Model::configure(
cs,
&vars,
params.run_args.lookup_range,
params.run_args.logrows as usize,
params.required_lookups,
params.required_range_checks,
params.check_mode,
)
.unwrap();
let base = Model::configure(cs, &vars, &params).unwrap();
let model_config = ModelConfig { base, vars };

View File

@@ -67,10 +67,8 @@ pub struct ForwardResult {
pub max_lookup_inputs: i128,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
/// The max range check value
pub max_range_check: i128,
/// The min range check value
pub min_range_check: i128,
/// The max range check size
pub max_range_size: i128,
}
impl From<DummyPassRes> for ForwardResult {
@@ -79,8 +77,7 @@ impl From<DummyPassRes> for ForwardResult {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
min_range_check: res.min_range_check,
max_range_check: res.max_range_check,
max_range_size: res.max_range_size,
}
}
}
@@ -102,6 +99,10 @@ pub type NodeGraph = BTreeMap<usize, NodeType>;
pub struct DummyPassRes {
/// number of rows use
pub num_rows: usize,
/// num dynamic lookups
pub num_dynamic_lookups: usize,
/// dynamic lookup col size
pub dynamic_lookup_col_coord: usize,
/// linear coordinate
pub linear_coord: usize,
/// total const size
@@ -115,9 +116,7 @@ pub struct DummyPassRes {
/// min lookup inputs
pub min_lookup_inputs: i128,
/// min range check
pub min_range_check: i128,
/// max range check
pub max_range_check: i128,
pub max_range_size: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -531,7 +530,7 @@ impl Model {
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs)?;
let res = self.dummy_layout(run_args, &inputs, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -545,6 +544,8 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -570,12 +571,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
Ok(res.into())
}
@@ -1007,24 +1009,25 @@ impl Model {
/// # Arguments
/// * `meta` - The constraint system.
/// * `vars` - The variables for the circuit.
/// * `run_args` - [RunArgs]
/// * `required_lookups` - The required lookup operations for the circuit.
/// * `settings` - [GraphSettings]
pub fn configure(
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
lookup_range: Range,
logrows: usize,
required_lookups: Vec<LookupOp>,
required_range_checks: Vec<Range>,
check_mode: CheckMode,
settings: &GraphSettings,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
info!("configuring model");
debug!("configuring model");
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_dynamic_lookups = settings.num_dynamic_lookups;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
&vars.advices[2],
check_mode,
settings.check_mode,
);
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
let input = &vars.advices[0];
@@ -1038,6 +1041,14 @@ impl Model {
base_gate.configure_range_check(meta, input, index, range, logrows)?;
}
for _ in 0..num_dynamic_lookups {
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
)?;
}
Ok(base_gate)
}
@@ -1356,6 +1367,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1374,7 +1386,7 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
@@ -1441,8 +1453,9 @@ impl Model {
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
min_range_check: region.min_range_check(),
max_range_check: region.max_range_check(),
max_range_size: region.max_range_size(),
num_dynamic_lookups: region.dynamic_lookup_index(),
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
outputs,
};

View File

@@ -734,7 +734,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
use_range_check_for_int: false,
use_range_check_for_int: true,
})
}

View File

@@ -420,20 +420,31 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
}
/// Allocate all columns that will be assigned to by a model.
pub fn new(
cs: &mut ConstraintSystem<F>,
logrows: usize,
var_len: usize,
num_inner_cols: usize,
num_constants: usize,
module_requires_fixed: bool,
) -> Self {
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
let advices = (0..3)
let logrows = params.run_args.logrows as usize;
let var_len = params.total_assignments;
let num_inner_cols = params.run_args.num_inner_cols;
let num_constants = params.total_const_size;
let module_requires_fixed = params.module_requires_fixed();
let requires_dynamic_lookup = params.requires_dynamic_lookup();
let dynamic_lookup_size = params.total_dynamic_col_size;
let mut advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
.collect_vec();
if requires_dynamic_lookup {
for _ in 0..2 {
let dynamic_lookup = VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_size);
if dynamic_lookup.num_blocks() > 1 {
panic!("dynamic lookup should only have one block");
};
advices.push(dynamic_lookup);
}
}
debug!(
"model uses {} advice blocks (size={})",
advices.iter().map(|v| v.num_blocks()).sum::<usize>(),

View File

@@ -180,6 +180,9 @@ impl RunArgs {
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
Ok(())
}

View File

@@ -197,7 +197,11 @@ impl std::fmt::Display for TranscriptType {
}
}
impl ToFlags for TranscriptType {}
impl ToFlags for TranscriptType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {

View File

@@ -3773,6 +3773,30 @@ pub mod nonlinearities {
.unwrap()
}
/// Elementwise inverse.
/// # Arguments
/// * `out_scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let expected = Tensor::<i128>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<i128> {
let a = Tensor::<i128>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as i128)
})
.unwrap()
}
/// Elementwise greater than
/// # Arguments
///

View File

@@ -211,7 +211,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward(&mut input, None, None)
.forward(&mut input, None, None, false)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)

View File

@@ -2,6 +2,7 @@
#[cfg(test)]
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
@@ -276,7 +277,7 @@ mod native_tests {
"bitshift",
];
const WASM_TESTS: [&str; 48] = [
const WASM_TESTS: [&str; 46] = [
"1l_mlp",
"1l_slice",
"1l_concat",
@@ -325,8 +326,6 @@ mod native_tests {
"1l_where",
"boolean",
"boolean_identity",
"decision_tree", // "variable_cnn",
"random_forest",
"gradient_boosted_trees",
"1l_topk",
// "xgboost",
@@ -586,6 +585,8 @@ mod native_tests {
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
@@ -841,7 +842,7 @@ mod native_tests {
});
seq!(N in 0..=47 {
seq!(N in 0..=45 {
#(#[test_case(WASM_TESTS[N])])*
fn kzg_prove_and_verify_with_overflow_(test: &str) {
@@ -1288,6 +1289,7 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
) {
let mut tolerance = tolerance;
gen_circuit_settings_and_witness(
test_dir,
example_name.clone(),
@@ -1299,16 +1301,10 @@ mod native_tests {
scales_to_use,
2,
false,
tolerance,
&mut tolerance,
);
let settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
if tolerance > 0.0 && !any_output_scales_smol {
if tolerance > 0.0 {
// load witness and shift the output by a small amount that is less than tolerance percent
let witness = GraphWitness::from_path(
format!("{}/{}/witness.json", test_dir, example_name).into(),
@@ -1333,7 +1329,7 @@ mod native_tests {
as i128,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
@@ -1444,7 +1440,7 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
div_rebasing: bool,
tolerance: f32,
tolerance: &mut f32,
) {
let mut args = vec![
"gen-settings".to_string(),
@@ -1502,6 +1498,24 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let mut settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
if any_output_scales_smol {
// set the tolerance to 0.0
settings.run_args.tolerance = Tolerance {
val: 0.0,
scale: 0.0.into(),
};
settings
.save(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
*tolerance = 0.0;
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"compile-circuit",
@@ -1559,7 +1573,7 @@ mod native_tests {
None,
2,
div_rebasing,
0.0,
&mut 0.0,
);
println!(
@@ -1819,7 +1833,7 @@ mod native_tests {
scales_to_use,
num_inner_columns,
false,
0.0,
&mut 0.0,
);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
@@ -1921,7 +1935,7 @@ mod native_tests {
None,
2,
false,
0.0,
&mut 0.0,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -2198,7 +2212,7 @@ mod native_tests {
Some(vec![4]),
1,
false,
0.0,
&mut 0.0,
);
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);

View File

@@ -91,9 +91,7 @@ def compare_outputs(zk_output, onnx_output):
print("------- zk_output: ", list1_i)
print("------- onnx_output: ", list2_i)
return np.mean(np.abs(res))
return res
if __name__ == '__main__':
@@ -113,6 +111,9 @@ if __name__ == '__main__':
onnx_output = get_onnx_output(model_file, input_file)
# compare the outputs
percentage_difference = compare_outputs(ezkl_output, onnx_output)
mean_percentage_difference = np.mean(np.abs(percentage_difference))
max_percentage_difference = np.max(np.abs(percentage_difference))
# print the percentage difference
print("mean percent diff: ", percentage_difference)
assert percentage_difference < target, "Percentage difference is too high"
print("mean percent diff: ", mean_percentage_difference)
print("max percent diff: ", max_percentage_difference)
assert mean_percentage_difference < target, "Percentage difference is too high"

Binary file not shown.

View File

@@ -27,6 +27,8 @@
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,
"num_dynamic_lookups": 0,
"total_assignments": 32,
"total_const_size": 8,
"model_instance_shapes": [

View File

@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_check":0,"min_range_check":0}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}