mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14786acb95 | ||
|
|
80a3c44cb4 |
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -2263,7 +2263,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.4",
|
||||
"bitvec 1.0.1",
|
||||
@@ -2280,7 +2280,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
|
||||
@@ -633,7 +633,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -664,7 +664,6 @@
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -198,6 +198,9 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub lookup_input: VarTensor,
|
||||
/// the (currently singular) output of the accumulated operations.
|
||||
pub output: VarTensor,
|
||||
/// The VarTensor reserved for dynamic lookup operations (could be an element of inputs or the same as output)
|
||||
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
|
||||
pub dynamic_lookup_tables: Vec<VarTensor>,
|
||||
/// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
|
||||
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
|
||||
pub lookup_output: VarTensor,
|
||||
@@ -207,6 +210,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
|
||||
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub dynamic_lookup_selectors: BTreeMap<(usize, usize), Vec<Selector>>,
|
||||
///
|
||||
pub dynamic_table_selectors: Vec<Selector>,
|
||||
///
|
||||
pub tables: BTreeMap<LookupOp, Table<F>>,
|
||||
///
|
||||
@@ -228,9 +235,12 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
lookup_input: dummy_var.clone(),
|
||||
output: dummy_var.clone(),
|
||||
lookup_output: dummy_var.clone(),
|
||||
dynamic_lookup_tables: vec![VarTensor::dummy(col_size, 2), dummy_var.clone()],
|
||||
lookup_index: dummy_var,
|
||||
selectors: BTreeMap::new(),
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
dynamic_lookup_selectors: BTreeMap::new(),
|
||||
dynamic_table_selectors: vec![],
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
@@ -376,10 +386,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
selectors,
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
dynamic_lookup_selectors: BTreeMap::new(),
|
||||
inputs: inputs.to_vec(),
|
||||
lookup_input: VarTensor::Empty,
|
||||
lookup_output: VarTensor::Empty,
|
||||
lookup_index: VarTensor::Empty,
|
||||
dynamic_table_selectors: vec![],
|
||||
dynamic_lookup_tables: vec![],
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
output: output.clone(),
|
||||
@@ -403,8 +416,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
let mut selectors = BTreeMap::new();
|
||||
|
||||
if !index.is_advice() {
|
||||
return Err("wrong input type for lookup index".into());
|
||||
}
|
||||
@@ -514,10 +525,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
selectors.insert((nl.clone(), x, y), multi_col_selector);
|
||||
self.lookup_selectors
|
||||
.insert((nl.clone(), x, y), multi_col_selector);
|
||||
}
|
||||
}
|
||||
self.lookup_selectors.extend(selectors);
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if let VarTensor::Empty = self.lookup_input {
|
||||
debug!("assigning lookup input");
|
||||
@@ -534,6 +545,85 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates lookup selectors
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_dynamic_lookup(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
lookups: &[VarTensor; 2],
|
||||
tables: &[VarTensor; 2],
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
for l in lookups.iter() {
|
||||
if !l.is_advice() {
|
||||
return Err("wrong input type for dynamic lookup".into());
|
||||
}
|
||||
}
|
||||
|
||||
for t in tables.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
return Err("wrong table type for dynamic lookup".into());
|
||||
}
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
let s_ltable = cs.complex_selector();
|
||||
|
||||
for x in 0..lookups[0].num_blocks() {
|
||||
for y in 0..lookups[0].num_inner_cols() {
|
||||
let s_lookup = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_lookupq = cs.query_selector(s_lookup);
|
||||
let mut expression = vec![];
|
||||
let s_ltableq = cs.query_selector(s_ltable);
|
||||
let mut lookup_queries = vec![one.clone()];
|
||||
|
||||
for lookup in lookups {
|
||||
lookup_queries.push(match lookup {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut table_queries = vec![one.clone()];
|
||||
for table in tables {
|
||||
table_queries.push(match table {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[0][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
|
||||
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
});
|
||||
self.dynamic_lookup_selectors
|
||||
.entry((x, y))
|
||||
.or_default()
|
||||
.push(s_lookup);
|
||||
}
|
||||
}
|
||||
self.dynamic_table_selectors.push(s_ltable);
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.dynamic_lookup_tables.is_empty() {
|
||||
debug!("assigning dynamic lookup table");
|
||||
self.dynamic_lookup_tables = tables.to_vec();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates lookup selectors
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_range_check(
|
||||
@@ -547,8 +637,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
let mut selectors = BTreeMap::new();
|
||||
|
||||
if !input.is_advice() {
|
||||
return Err("wrong input type for lookup input".into());
|
||||
}
|
||||
@@ -620,10 +708,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
selectors.insert((range, x, y), multi_col_selector);
|
||||
self.range_check_selectors
|
||||
.insert((range, x, y), multi_col_selector);
|
||||
}
|
||||
}
|
||||
self.range_check_selectors.extend(selectors);
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if let VarTensor::Empty = self.lookup_input {
|
||||
debug!("assigning lookup input");
|
||||
|
||||
@@ -277,7 +277,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::div(
|
||||
layouts::loop_div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
|
||||
@@ -18,10 +18,7 @@ use super::{
|
||||
region::RegionCtx,
|
||||
};
|
||||
use crate::{
|
||||
circuit::{
|
||||
ops::base::BaseOp,
|
||||
utils::{self, F32},
|
||||
},
|
||||
circuit::{ops::base::BaseOp, utils},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
tensor::{
|
||||
get_broadcasted_shape,
|
||||
@@ -54,6 +51,41 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
|
||||
total_len
|
||||
}
|
||||
|
||||
/// Same as div but splits the division into N parts
|
||||
pub fn loop_div<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
divisor: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
if divisor == F::ONE {
|
||||
return Ok(value[0].clone());
|
||||
}
|
||||
|
||||
// if integer val is divisible by 2, we can use a faster method and div > F::S
|
||||
let mut divisor = divisor;
|
||||
let mut num_parts = 1;
|
||||
|
||||
while felt_to_i128(divisor) % 2 == 0 && felt_to_i128(divisor) > (2_i128.pow(F::S - 4)) {
|
||||
divisor = i128_to_felt(felt_to_i128(divisor) / 2);
|
||||
num_parts += 1;
|
||||
}
|
||||
|
||||
let output = div(config, region, value, divisor)?;
|
||||
if num_parts == 1 {
|
||||
return Ok(output);
|
||||
}
|
||||
|
||||
let divisor_int = 2_i128.pow(num_parts - 1);
|
||||
let divisor_felt = i128_to_felt(divisor_int);
|
||||
if divisor_int <= 2_i128.pow(F::S - 3) {
|
||||
div(config, region, &[output], divisor_felt)
|
||||
} else {
|
||||
// keep splitting the divisor until it satisfies the condition
|
||||
loop_div(config, region, &[output], divisor_felt)
|
||||
}
|
||||
}
|
||||
|
||||
/// Div accumulated layout
|
||||
pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -61,6 +93,10 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
value: &[ValTensor<F>; 1],
|
||||
div: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
if div == F::ONE {
|
||||
return Ok(value[0].clone());
|
||||
}
|
||||
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
@@ -88,6 +124,8 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
region.assign(&config.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
@@ -96,8 +134,6 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -105,8 +141,6 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
|
||||
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
@@ -117,6 +151,46 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
fn recip_int<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
input: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// assert is boolean
|
||||
let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
|
||||
// get values where input is 0
|
||||
let zero_mask = equals_zero(config, region, input)?;
|
||||
|
||||
let one_minus_zero_mask = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
zero_mask.clone(),
|
||||
ValTensor::from(Tensor::from([ValType::Constant(F::ONE)].into_iter())),
|
||||
],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
let zero_inverse_val = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
zero_mask,
|
||||
ValTensor::from(Tensor::from(
|
||||
[ValType::Constant(i128_to_felt(zero_inverse_val))].into_iter(),
|
||||
)),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
pairwise(
|
||||
config,
|
||||
region,
|
||||
&[one_minus_zero_mask, zero_inverse_val],
|
||||
BaseOp::Add,
|
||||
)
|
||||
}
|
||||
|
||||
/// recip accumulated layout
|
||||
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -125,10 +199,23 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
input_scale: F,
|
||||
output_scale: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
if output_scale == F::ONE || output_scale == F::ZERO {
|
||||
return recip_int(config, region, value);
|
||||
}
|
||||
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
|
||||
let integer_input_scale = felt_to_i128(input_scale);
|
||||
let integer_output_scale = felt_to_i128(output_scale);
|
||||
|
||||
// range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
|
||||
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
|
||||
|
||||
let input_scale_ratio =
|
||||
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
|
||||
|
||||
let range_check_bracket = range_check_len / 2;
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
@@ -151,6 +238,8 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
let claimed_output = region.assign(&config.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// this is now of scale 2 * scale
|
||||
let product = pairwise(
|
||||
@@ -160,15 +249,46 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
// divide by input_scale
|
||||
let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?;
|
||||
|
||||
log::debug!("range_check_bracket: {:?}", range_check_bracket);
|
||||
let zero_inverse_val =
|
||||
tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0];
|
||||
let zero_inverse =
|
||||
Tensor::from([ValType::Constant(i128_to_felt::<F>(zero_inverse_val))].into_iter());
|
||||
|
||||
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
|
||||
|
||||
let equal_inverse_mask = equals(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), zero_inverse.into()],
|
||||
)?;
|
||||
|
||||
// assert the two masks are equal
|
||||
enforce_equality(
|
||||
config,
|
||||
region,
|
||||
&[equal_zero_mask.clone(), equal_inverse_mask],
|
||||
)?;
|
||||
|
||||
let unit_scale = Tensor::from([ValType::Constant(i128_to_felt(range_check_len))].into_iter());
|
||||
|
||||
let unit_mask = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[equal_zero_mask, unit_scale.into()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// now add the unit mask to the rebased_div
|
||||
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
|
||||
|
||||
// at most the error should be in the original unit scale's range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[product],
|
||||
&[rebased_offset_div],
|
||||
&(range_check_bracket, 3 * range_check_bracket),
|
||||
)?;
|
||||
|
||||
@@ -795,6 +915,78 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(assigned_output)
|
||||
}
|
||||
|
||||
/// Dynamic lookup
|
||||
pub fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
lookups: &[ValTensor<F>; 2],
|
||||
tables: &[ValTensor<F>; 2],
|
||||
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
|
||||
// if not all lookups same length err
|
||||
if lookups[0].len() != lookups[1].len() {
|
||||
return Err("lookups must be same length".into());
|
||||
}
|
||||
|
||||
// if not all inputs same length err
|
||||
if tables[0].len() != tables[1].len() {
|
||||
return Err("inputs must be same length".into());
|
||||
}
|
||||
|
||||
// now assert the inputs of a smaller length than the lookups
|
||||
if tables[0].len() > tables[0].len() {
|
||||
return Err("inputs must be smaller length than dynamic lookups".into());
|
||||
}
|
||||
|
||||
let (lookup_0, lookup_1) = (lookups[0].clone(), lookups[1].clone());
|
||||
let (table_0, table_1) = (tables[0].clone(), tables[1].clone());
|
||||
|
||||
let table_0 = region.assign_dynamic_lookup(&config.dynamic_lookup_tables[0], &table_0)?;
|
||||
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookup_tables[1], &table_1)?;
|
||||
let table_len = table_0.len();
|
||||
|
||||
let lookup_0 = region.assign(&config.inputs[0], &lookup_0)?;
|
||||
let lookup_1 = region.assign(&config.inputs[1], &lookup_1)?;
|
||||
|
||||
let lookup_len = lookup_0.len();
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..table_len)
|
||||
.map(|i| {
|
||||
let dynamic_lookup_index = region.dynamic_lookup_index();
|
||||
let table_selector = config.dynamic_table_selectors[dynamic_lookup_index];
|
||||
let (_, _, z) = config.dynamic_lookup_tables[0]
|
||||
.cartesian_coord(region.dynamic_lookup_col_coord() + i);
|
||||
region.enable(Some(&table_selector), z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
if !region.is_dummy() {
|
||||
// Enable the selectors
|
||||
(0..lookup_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[0].cartesian_coord(region.linear_coord() + i);
|
||||
let dynamic_lookup_index = region.dynamic_lookup_index();
|
||||
let lookup_selector = config
|
||||
.dynamic_lookup_selectors
|
||||
.get(&(x, y))
|
||||
.ok_or("missing selectors")?[dynamic_lookup_index];
|
||||
|
||||
region.enable(Some(&lookup_selector), z)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment_dynamic_lookup_col_coord(table_len);
|
||||
region.increment_dynamic_lookup_index(1);
|
||||
region.increment(lookup_len);
|
||||
|
||||
Ok((lookup_0, lookup_1))
|
||||
}
|
||||
|
||||
/// One hot accumulated layout
|
||||
pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -1677,9 +1869,23 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let diff_inverse = diff.inverse()?;
|
||||
let product_diff_and_invert =
|
||||
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
|
||||
equals_zero(config, region, &[diff])
|
||||
}
|
||||
|
||||
/// Equality boolean operation
|
||||
pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let values = values[0].clone();
|
||||
let values_inverse = values.inverse()?;
|
||||
let product_values_and_invert = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values.clone(), values_inverse],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// constant of 1
|
||||
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
|
||||
@@ -1689,12 +1895,12 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
let output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[ones.into(), product_diff_and_invert],
|
||||
&[ones.into(), product_values_and_invert],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// take the product of diff and output
|
||||
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
|
||||
let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;
|
||||
|
||||
is_zero_identity(config, region, &[prod_check], false)?;
|
||||
|
||||
@@ -1860,7 +2066,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
if normalized {
|
||||
last_elem = div(
|
||||
last_elem = loop_div(
|
||||
config,
|
||||
region,
|
||||
&[last_elem],
|
||||
@@ -2519,6 +2725,17 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
if region.throw_range_check_error() {
|
||||
// assert is within range
|
||||
let int_values = w.get_int_evals()?;
|
||||
for v in int_values {
|
||||
if v < range.0 || v > range.1 {
|
||||
log::debug!("Value ({:?}) out of range: {:?}", v, range);
|
||||
return Err(Box::new(TensorError::TableLookupError));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
region.increment(assigned_len);
|
||||
|
||||
let elapsed = timer.elapsed();
|
||||
@@ -2945,16 +3162,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
let denom = sum(config, region, &[ex.clone()])?;
|
||||
// get the inverse
|
||||
|
||||
let inv_denom = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[denom],
|
||||
// we set to input scale + output_scale so the output scale is output)scale
|
||||
&LookupOp::Recip {
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)?;
|
||||
let felt_scale = F::from(scale.0 as u64);
|
||||
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
|
||||
|
||||
// product of num * (1 / denom) = 2*output_scale
|
||||
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
|
||||
@@ -2989,29 +3198,44 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
|
||||
|
||||
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
|
||||
let recip = nonlinearity(
|
||||
// integer scale
|
||||
let int_scale = scale.0 as i128;
|
||||
// felt scale
|
||||
let felt_scale = i128_to_felt(int_scale);
|
||||
// range check len capped at 2^(S-3) and make it divisible 2
|
||||
let range_check_bracket = std::cmp::min(
|
||||
utils::F32(scale.0),
|
||||
utils::F32(2_f32.powf((F::S - 5) as f32)),
|
||||
)
|
||||
.0;
|
||||
|
||||
let range_check_bracket_int = range_check_bracket as i128;
|
||||
|
||||
// input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
|
||||
let input_scale_ratio = ((scale.0.powf(2.0) / range_check_bracket) * tol) as i128 / 2 * 2;
|
||||
|
||||
let recip = recip(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone()],
|
||||
&LookupOp::Recip {
|
||||
input_scale: scale,
|
||||
// multiply by 100 to get the percent error
|
||||
output_scale: F32(scale.0 * 100.0),
|
||||
},
|
||||
felt_scale,
|
||||
felt_scale * F::from(100),
|
||||
)?;
|
||||
|
||||
log::debug!("recip: {}", recip.show());
|
||||
|
||||
// Multiply the difference by the recip
|
||||
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
|
||||
let rebased_product = div(config, region, &[product], F::from(scale.0 as u64))?;
|
||||
|
||||
let scaled_tol = (tol * scale.0) as i128;
|
||||
log::debug!("product: {}", product.show());
|
||||
let rebased_product = loop_div(config, region, &[product], i128_to_felt(input_scale_ratio))?;
|
||||
log::debug!("rebased_product: {}", rebased_product.show());
|
||||
|
||||
// check that it is within the tolerance range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[rebased_product],
|
||||
&(-scaled_tol, scaled_tol),
|
||||
&(-range_check_bracket_int, range_check_bracket_int),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,33 @@ use portable_atomic::AtomicI128 as AtomicInt;
|
||||
|
||||
use super::lookup::LookupOp;
|
||||
|
||||
/// Dynamic lookup index
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DynamicLookupIndex {
|
||||
lookup_index: usize,
|
||||
col_coord: usize,
|
||||
}
|
||||
|
||||
impl DynamicLookupIndex {
|
||||
/// Create a new dynamic lookup index
|
||||
pub fn new(lookup_index: usize, col_coord: usize) -> DynamicLookupIndex {
|
||||
DynamicLookupIndex {
|
||||
lookup_index,
|
||||
col_coord,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the lookup index
|
||||
pub fn lookup_index(&self) -> usize {
|
||||
self.lookup_index
|
||||
}
|
||||
|
||||
/// Get the column coord
|
||||
pub fn col_coord(&self) -> usize {
|
||||
self.col_coord
|
||||
}
|
||||
}
|
||||
|
||||
/// Region error
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RegionError {
|
||||
@@ -66,12 +93,13 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
total_constants: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
max_lookup_inputs: i128,
|
||||
min_lookup_inputs: i128,
|
||||
min_range_check: i128,
|
||||
max_range_check: i128,
|
||||
max_range_size: i128,
|
||||
throw_range_check_error: bool,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
@@ -80,6 +108,21 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.total_constants += n;
|
||||
}
|
||||
|
||||
///
|
||||
pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
|
||||
self.dynamic_lookup_index.lookup_index += n;
|
||||
}
|
||||
|
||||
///
|
||||
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
|
||||
self.dynamic_lookup_index.col_coord += n;
|
||||
}
|
||||
|
||||
///
|
||||
pub fn throw_range_check_error(&self) -> bool {
|
||||
self.throw_range_check_error
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
|
||||
let region = Some(RefCell::new(region));
|
||||
@@ -91,12 +134,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
row,
|
||||
linear_coord,
|
||||
total_constants: 0,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_check: 0,
|
||||
min_range_check: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error: false,
|
||||
}
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
@@ -104,6 +148,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
region: Option<RefCell<Region<'a, F>>>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let linear_coord = row * num_inner_cols;
|
||||
RegionCtx {
|
||||
@@ -112,17 +157,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
dynamic_lookup_index,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_check: 0,
|
||||
min_range_check: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
|
||||
pub fn new_dummy(
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
throw_range_check_error: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
let linear_coord = row * num_inner_cols;
|
||||
|
||||
@@ -132,12 +182,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_check: 0,
|
||||
min_range_check: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,8 +198,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord: usize,
|
||||
total_constants: usize,
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
throw_range_check_error: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -157,12 +210,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants,
|
||||
dynamic_lookup_index,
|
||||
used_lookups,
|
||||
used_range_checks,
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_check: 0,
|
||||
min_range_check: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,6 +271,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
|
||||
|
||||
*output = output
|
||||
.par_enum_map(|idx, _| {
|
||||
@@ -232,8 +287,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
starting_linear_coord,
|
||||
starting_constants,
|
||||
self.num_inner_cols,
|
||||
DynamicLookupIndex::default(),
|
||||
HashSet::new(),
|
||||
HashSet::new(),
|
||||
self.throw_range_check_error,
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -254,6 +311,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
|
||||
dynamic_lookup_index.lookup_index += local_reg.dynamic_lookup_index.lookup_index;
|
||||
dynamic_lookup_index.col_coord += local_reg.dynamic_lookup_index.col_coord;
|
||||
|
||||
res
|
||||
})
|
||||
.map_err(|e| {
|
||||
@@ -282,6 +343,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?;
|
||||
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!(
|
||||
"dummy_loop: failed to get dynamic lookup index: {:?}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!(
|
||||
"dummy_loop: failed to get dynamic lookup index: {:?}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -310,8 +385,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
return Err("update_max_min_lookup_range: invalid range".into());
|
||||
}
|
||||
|
||||
self.max_range_check = self.max_range_check.max(range.1);
|
||||
self.min_range_check = self.min_range_check.min(range.0);
|
||||
let range_size = (range.1 - range.0).abs();
|
||||
|
||||
self.max_range_size = self.max_range_size.max(range_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -351,6 +427,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.total_constants
|
||||
}
|
||||
|
||||
/// Get the dynamic lookup index
|
||||
pub fn dynamic_lookup_index(&self) -> usize {
|
||||
self.dynamic_lookup_index.lookup_index
|
||||
}
|
||||
|
||||
/// Get the dynamic lookup column coordinate
|
||||
pub fn dynamic_lookup_col_coord(&self) -> usize {
|
||||
self.dynamic_lookup_index.col_coord
|
||||
}
|
||||
|
||||
/// get used lookups
|
||||
pub fn used_lookups(&self) -> HashSet<LookupOp> {
|
||||
self.used_lookups.clone()
|
||||
@@ -371,14 +457,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// min range check
|
||||
pub fn min_range_check(&self) -> i128 {
|
||||
self.min_range_check
|
||||
}
|
||||
|
||||
/// max range check
|
||||
pub fn max_range_check(&self) -> i128 {
|
||||
self.max_range_check
|
||||
pub fn max_range_size(&self) -> i128 {
|
||||
self.max_range_size
|
||||
}
|
||||
|
||||
/// Assign a constant value
|
||||
@@ -405,6 +486,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor
|
||||
pub fn assign_dynamic_lookup(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, Error> {
|
||||
self.total_constants += values.num_constants();
|
||||
if let Some(region) = &self.region {
|
||||
var.assign(
|
||||
&mut region.borrow_mut(),
|
||||
self.dynamic_lookup_col_coord(),
|
||||
values,
|
||||
)
|
||||
} else {
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor
|
||||
pub fn assign_with_omissions(
|
||||
&mut self,
|
||||
|
||||
@@ -133,9 +133,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
|
||||
// double it to be safe
|
||||
let range_len = range.1 - range.0;
|
||||
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as i128)) as usize + 1
|
||||
}
|
||||
@@ -152,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
|
||||
let col_size = Self::cal_col_size(logrows, factors);
|
||||
// number of cols needed to store the range
|
||||
let num_cols = num_cols_required(range, col_size);
|
||||
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
|
||||
|
||||
log::debug!("table range: {:?}", range);
|
||||
|
||||
@@ -313,7 +311,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
|
||||
let col_size = Self::cal_col_size(logrows, factors);
|
||||
// number of cols needed to store the range
|
||||
let num_cols = num_cols_required(range, col_size);
|
||||
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
|
||||
|
||||
let inputs = {
|
||||
let mut cols = vec![];
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::circuit::ops::hybrid::HybridOp;
|
||||
use crate::circuit::ops::poly::PolyOp;
|
||||
use crate::circuit::*;
|
||||
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
|
||||
@@ -1575,6 +1574,142 @@ mod add {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod dynamic_lookup {
|
||||
use super::*;
|
||||
|
||||
const K: usize = 6;
|
||||
const LEN: usize = 4;
|
||||
const NUM_LOOP: usize = 5;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
tables: [[ValTensor<F>; 2]; NUM_LOOP],
|
||||
lookups: [[ValTensor<F>; 2]; NUM_LOOP],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MyCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 2, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 2, LEN);
|
||||
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
|
||||
|
||||
let d = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let e = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
let mut config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
|
||||
for _ in 0..NUM_LOOP {
|
||||
config
|
||||
.configure_dynamic_lookup(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
for i in 0..NUM_LOOP {
|
||||
layouts::dynamic_lookup(
|
||||
&config,
|
||||
&mut region,
|
||||
&self.lookups[i],
|
||||
&self.tables[i],
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)?;
|
||||
}
|
||||
assert_eq!(
|
||||
region.dynamic_lookup_col_coord(),
|
||||
NUM_LOOP * self.tables[0][0].len()
|
||||
);
|
||||
assert_eq!(region.dynamic_lookup_index(), NUM_LOOP);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamiclookupcircuit() {
|
||||
// parameters
|
||||
let tables = (0..NUM_LOOP)
|
||||
.map(|loop_idx| {
|
||||
[
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
|
||||
)),
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..LEN).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
|
||||
)),
|
||||
]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lookups = (0..NUM_LOOP)
|
||||
.map(|loop_idx| {
|
||||
[
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..3).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
|
||||
)),
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..3).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
|
||||
)),
|
||||
]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
tables: tables.clone().try_into().unwrap(),
|
||||
lookups: lookups.try_into().unwrap(),
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
|
||||
let lookups = (0..NUM_LOOP)
|
||||
.map(|loop_idx| {
|
||||
[
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..2).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
|
||||
)),
|
||||
ValTensor::from(Tensor::from((0..2).map(|_| Value::known(F::from(10000))))),
|
||||
]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
tables: tables.try_into().unwrap(),
|
||||
lookups: lookups.try_into().unwrap(),
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
assert!(prover.verify().is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod add_with_overflow {
|
||||
use super::*;
|
||||
@@ -2338,113 +2473,3 @@ mod lookup_ultra_overflow {
|
||||
println!("done.");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod softmax {
|
||||
|
||||
use super::*;
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
dev::MockProver,
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
|
||||
const K: usize = 18;
|
||||
const LEN: usize = 3;
|
||||
const SCALE: f32 = 128.0;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SoftmaxCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for SoftmaxCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE);
|
||||
let advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&advices[0],
|
||||
&advices[1],
|
||||
&advices[2],
|
||||
(-32768, 32768),
|
||||
K,
|
||||
&LookupOp::Exp {
|
||||
scale: SCALE.into(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&advices[0],
|
||||
&advices[1],
|
||||
&advices[2],
|
||||
(-32768, 32768),
|
||||
K,
|
||||
&LookupOp::Recip {
|
||||
input_scale: SCALE.into(),
|
||||
output_scale: SCALE.into(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
config.layout_tables(&mut layouter).unwrap();
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let _output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(HybridOp::Softmax {
|
||||
scale: SCALE.into(),
|
||||
axes: vec![0],
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax_circuit() {
|
||||
let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
|
||||
let circuit = SoftmaxCircuit::<F> {
|
||||
input: ValTensor::from(input),
|
||||
_marker: PhantomData,
|
||||
};
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,7 +618,7 @@ pub(crate) async fn gen_witness(
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?;
|
||||
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref(), false)?;
|
||||
|
||||
// print each variable tuple (symbol, value) as symbol=value
|
||||
trace!(
|
||||
@@ -808,16 +808,7 @@ pub(crate) fn calibrate(
|
||||
// we load the model to get the input and output shapes
|
||||
// check if gag already exists
|
||||
|
||||
#[cfg(unix)]
|
||||
let _r = match Gag::stdout() {
|
||||
Ok(r) => Some(r),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
let model = Model::from_run_args(&settings.run_args, &model_path)?;
|
||||
// drop the gag
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
|
||||
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
|
||||
info!("num of calibration batches: {}", chunks.len());
|
||||
@@ -833,7 +824,7 @@ pub(crate) fn calibrate(
|
||||
let range = if let Some(scales) = scales {
|
||||
scales
|
||||
} else {
|
||||
(10..14).collect::<Vec<crate::Scale>>()
|
||||
(11..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
@@ -896,16 +887,6 @@ pub(crate) fn calibrate(
|
||||
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
|
||||
));
|
||||
|
||||
#[cfg(unix)]
|
||||
let _r = match Gag::stdout() {
|
||||
Ok(r) => Some(r),
|
||||
Err(_) => None,
|
||||
};
|
||||
#[cfg(unix)]
|
||||
let _q = match Gag::stderr() {
|
||||
Ok(r) => Some(r),
|
||||
Err(_) => None,
|
||||
};
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
@@ -920,17 +901,12 @@ pub(crate) fn calibrate(
|
||||
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
// drop the gag
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_q);
|
||||
debug!("circuit creation from run args failed: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
chunks
|
||||
let forward_res = chunks
|
||||
.iter()
|
||||
.map(|chunk| {
|
||||
let chunk = chunk.clone();
|
||||
@@ -940,7 +916,7 @@ pub(crate) fn calibrate(
|
||||
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
|
||||
|
||||
let forward_res = circuit
|
||||
.forward(&mut data.clone(), None, None)
|
||||
.forward(&mut data.clone(), None, None, true)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
// push result to the hashmap
|
||||
@@ -951,7 +927,16 @@ pub(crate) fn calibrate(
|
||||
|
||||
Ok(()) as Result<(), String>
|
||||
})
|
||||
.collect::<Result<Vec<()>, String>>()?;
|
||||
.collect::<Result<Vec<()>, String>>();
|
||||
|
||||
match forward_res {
|
||||
Ok(_) => (),
|
||||
// typically errors will be due to the circuit overflowing the i128 limit
|
||||
Err(e) => {
|
||||
debug!("forward pass failed: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let min_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
@@ -969,35 +954,21 @@ pub(crate) fn calibrate(
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let min_range_check = forward_pass_res
|
||||
let max_range_size = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| x.min_range_check)
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_range_check = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| x.max_range_check)
|
||||
.map(|x| x.max_range_size)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let res = circuit.calibrate_from_min_max(
|
||||
let res = circuit.calc_min_logrows(
|
||||
(min_lookup_range, max_lookup_range),
|
||||
(min_range_check, max_range_check),
|
||||
max_range_size,
|
||||
max_logrows,
|
||||
lookup_safety_margin,
|
||||
);
|
||||
|
||||
// // drop the gag
|
||||
// #[cfg(unix)]
|
||||
// std::mem::drop(_r);
|
||||
// #[cfg(unix)]
|
||||
// std::mem::drop(_q);
|
||||
|
||||
if res.is_ok() {
|
||||
let new_settings = circuit.settings().clone();
|
||||
|
||||
|
||||
310
src/graph/mod.rs
310
src/graph/mod.rs
@@ -61,8 +61,11 @@ use crate::pfsys::field_to_string;
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
|
||||
/// The maximum number of columns in a lookup table.
|
||||
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
|
||||
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
@@ -134,15 +137,16 @@ pub enum GraphError {
|
||||
MissingResults,
|
||||
}
|
||||
|
||||
const ASSUMED_BLINDING_FACTORS: usize = 5;
|
||||
///
|
||||
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
|
||||
/// The minimum number of rows in the grid
|
||||
pub const MIN_LOGROWS: u32 = 6;
|
||||
|
||||
/// 26
|
||||
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
|
||||
|
||||
/// Lookup deg
|
||||
pub const LOOKUP_DEG: usize = 5;
|
||||
///
|
||||
pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD;
|
||||
|
||||
use std::cell::RefCell;
|
||||
|
||||
@@ -171,10 +175,8 @@ pub struct GraphWitness {
|
||||
pub max_lookup_inputs: i128,
|
||||
/// max lookup input
|
||||
pub min_lookup_inputs: i128,
|
||||
/// max range check input
|
||||
pub max_range_check: i128,
|
||||
/// max range check input
|
||||
pub min_range_check: i128,
|
||||
/// max range check size
|
||||
pub max_range_size: i128,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -202,8 +204,7 @@ impl GraphWitness {
|
||||
processed_outputs: None,
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_check: 0,
|
||||
min_range_check: 0,
|
||||
max_range_size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,9 +377,7 @@ impl ToPyObject for GraphWitness {
|
||||
.unwrap();
|
||||
dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
|
||||
.unwrap();
|
||||
dict.set_item("max_range_check", self.max_range_check)
|
||||
.unwrap();
|
||||
dict.set_item("min_range_check", self.min_range_check)
|
||||
dict.set_item("max_range_size", self.max_range_size)
|
||||
.unwrap();
|
||||
|
||||
if let Some(processed_inputs) = &self.processed_inputs {
|
||||
@@ -450,6 +449,10 @@ pub struct GraphSettings {
|
||||
pub total_assignments: usize,
|
||||
/// total const size
|
||||
pub total_const_size: usize,
|
||||
/// total dynamic column size
|
||||
pub total_dynamic_col_size: usize,
|
||||
/// number of dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
/// the shape of public inputs to the model (in order of appearance)
|
||||
pub model_instance_shapes: Vec<Vec<usize>>,
|
||||
/// model output scales
|
||||
@@ -473,6 +476,24 @@ pub struct GraphSettings {
|
||||
}
|
||||
|
||||
impl GraphSettings {
|
||||
fn model_constraint_logrows(&self) -> u32 {
|
||||
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn dynamic_lookup_logrows(&self) -> u32 {
|
||||
(self.total_dynamic_col_size as f64).log2().ceil() as u32
|
||||
}
|
||||
|
||||
fn module_constraint_logrows(&self) -> u32 {
|
||||
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
|
||||
}
|
||||
|
||||
fn constants_logrows(&self) -> u32 {
|
||||
(self.total_const_size as f64).log2().ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the total number of instances
|
||||
pub fn total_instances(&self) -> Vec<usize> {
|
||||
let mut instances: Vec<usize> = self
|
||||
@@ -557,6 +578,11 @@ impl GraphSettings {
|
||||
|| self.run_args.param_visibility.is_hashed()
|
||||
}
|
||||
|
||||
/// requires dynamic lookup
|
||||
pub fn requires_dynamic_lookup(&self) -> bool {
|
||||
self.num_dynamic_lookups > 0
|
||||
}
|
||||
|
||||
/// any kzg visibility
|
||||
pub fn module_requires_kzg(&self) -> bool {
|
||||
self.run_args.input_visibility.is_kzgcommit()
|
||||
@@ -1005,10 +1031,6 @@ impl GraphCircuit {
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn reserved_blinding_rows() -> f64 {
|
||||
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
|
||||
}
|
||||
|
||||
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
|
||||
let mut margin = (
|
||||
lookup_safety_margin * min_max_lookup.0,
|
||||
@@ -1022,18 +1044,34 @@ impl GraphCircuit {
|
||||
margin
|
||||
}
|
||||
|
||||
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(
|
||||
max_logrows as usize,
|
||||
Self::reserved_blinding_rows() as usize,
|
||||
);
|
||||
num_cols_required(safe_range, max_col_size)
|
||||
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
|
||||
num_cols_required(range_len, max_col_size)
|
||||
}
|
||||
|
||||
fn calc_min_logrows(
|
||||
fn table_size_logrows(
|
||||
&self,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i128,
|
||||
) -> Result<u32, Box<dyn std::error::Error>> {
|
||||
// pick the range with the largest absolute size safe_lookup_range or max_range_size
|
||||
let safe_range = std::cmp::max(
|
||||
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
|
||||
max_range_size,
|
||||
);
|
||||
|
||||
let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.)
|
||||
.log2()
|
||||
.ceil() as u32;
|
||||
|
||||
Ok(min_bits)
|
||||
}
|
||||
|
||||
/// calculate the minimum logrows required for the circuit
|
||||
pub fn calc_min_logrows(
|
||||
&mut self,
|
||||
min_max_lookup: Range,
|
||||
min_max_range_checks: Range,
|
||||
max_range_size: i128,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: i128,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
@@ -1043,68 +1081,60 @@ impl GraphCircuit {
|
||||
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
|
||||
let mut min_logrows = MIN_LOGROWS;
|
||||
|
||||
let reserved_blinding_rows = Self::reserved_blinding_rows();
|
||||
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
|
||||
|
||||
// check if has overflowed max lookup input
|
||||
if min_max_lookup.1.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
|
||||
|| min_max_lookup.0.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
|
||||
{
|
||||
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
|
||||
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
if min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
|
||||
|| min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
|
||||
{
|
||||
let err_string = format!(
|
||||
"max range check input {:?} is too large",
|
||||
min_max_range_checks
|
||||
);
|
||||
if max_range_size.abs() > MAX_LOOKUP_ABS {
|
||||
let err_string = format!("max range check size {:?} is too large", max_range_size);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
|
||||
// pick the range with the largest absolute size between safe_lookup_range and min_max_range_checks
|
||||
let safe_range = if (safe_lookup_range.1 - safe_lookup_range.0)
|
||||
> (min_max_range_checks.1 - min_max_range_checks.0)
|
||||
{
|
||||
safe_lookup_range
|
||||
} else {
|
||||
min_max_range_checks
|
||||
};
|
||||
// These are hard lower limits, we can't overflow instances or modules constraints
|
||||
let instance_logrows = self.settings().log2_total_instances();
|
||||
let module_constraint_logrows = self.settings().module_constraint_logrows();
|
||||
let dynamic_lookup_logrows = self.settings().dynamic_lookup_logrows();
|
||||
min_logrows = std::cmp::max(
|
||||
min_logrows,
|
||||
// max of the instance logrows and the module constraint logrows and the dynamic lookup logrows is the lower limit
|
||||
*[
|
||||
instance_logrows,
|
||||
module_constraint_logrows,
|
||||
dynamic_lookup_logrows,
|
||||
]
|
||||
.iter()
|
||||
.max()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// These are upper limits, going above these is wasteful, but they are not hard limits
|
||||
let model_constraint_logrows = self.settings().model_constraint_logrows();
|
||||
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
|
||||
let constants_logrows = self.settings().constants_logrows();
|
||||
max_logrows = std::cmp::min(
|
||||
max_logrows,
|
||||
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
|
||||
*[model_constraint_logrows, min_bits, constants_logrows]
|
||||
.iter()
|
||||
.max()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// we now have a min and max logrows
|
||||
max_logrows = std::cmp::max(min_logrows, max_logrows);
|
||||
|
||||
// degrade the max logrows until the extended k is small enough
|
||||
while min_logrows < max_logrows
|
||||
&& !self.extended_k_is_small_enough(
|
||||
min_logrows,
|
||||
Self::calc_num_cols(safe_range, min_logrows),
|
||||
)
|
||||
{
|
||||
min_logrows += 1;
|
||||
}
|
||||
|
||||
if !self
|
||||
.extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows))
|
||||
{
|
||||
let err_string = format!(
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
min_logrows
|
||||
);
|
||||
debug!("{}", err_string);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
while min_logrows < max_logrows
|
||||
&& !self.extended_k_is_small_enough(
|
||||
max_logrows,
|
||||
Self::calc_num_cols(safe_range, max_logrows),
|
||||
)
|
||||
&& !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size)
|
||||
{
|
||||
max_logrows -= 1;
|
||||
}
|
||||
|
||||
if !self
|
||||
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
|
||||
{
|
||||
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
|
||||
let err_string = format!(
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
max_logrows
|
||||
@@ -1113,67 +1143,27 @@ impl GraphCircuit {
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
|
||||
.log2()
|
||||
.ceil() as usize;
|
||||
|
||||
let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows)
|
||||
.log2()
|
||||
.ceil() as usize;
|
||||
|
||||
let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints);
|
||||
|
||||
// if public input then public inputs col will have public inputs len
|
||||
if self.settings().run_args.input_visibility.is_public()
|
||||
|| self.settings().run_args.output_visibility.is_public()
|
||||
{
|
||||
let mut max_instance_len = self
|
||||
.model()
|
||||
.instance_shapes()?
|
||||
.iter()
|
||||
.fold(0, |acc, x| std::cmp::max(acc, x.iter().product::<usize>()))
|
||||
as f64
|
||||
+ reserved_blinding_rows;
|
||||
// if there are modules then we need to add the max module size
|
||||
if self.settings().uses_modules() {
|
||||
max_instance_len += self
|
||||
.settings()
|
||||
.module_sizes
|
||||
.num_instances()
|
||||
.iter()
|
||||
.sum::<usize>() as f64;
|
||||
}
|
||||
let instance_len_logrows = (max_instance_len).log2().ceil() as usize;
|
||||
logrows = std::cmp::max(logrows, instance_len_logrows);
|
||||
// this is for fixed const columns
|
||||
}
|
||||
|
||||
// ensure logrows is at least 4
|
||||
logrows = std::cmp::max(logrows, min_logrows as usize);
|
||||
logrows = std::cmp::min(logrows, max_logrows as usize);
|
||||
let logrows = max_logrows;
|
||||
|
||||
let model = self.model().clone();
|
||||
let settings_mut = self.settings_mut();
|
||||
settings_mut.run_args.lookup_range = safe_lookup_range;
|
||||
settings_mut.run_args.logrows = logrows as u32;
|
||||
settings_mut.run_args.logrows = logrows;
|
||||
|
||||
*settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
|
||||
.settings()
|
||||
.clone();
|
||||
|
||||
// recalculate the total const size give nthe new logrows
|
||||
let total_const_len = settings_mut.total_const_size;
|
||||
let const_len_logrows = (total_const_len as f64).log2().ceil() as u32;
|
||||
settings_mut.run_args.logrows =
|
||||
std::cmp::max(settings_mut.run_args.logrows, const_len_logrows);
|
||||
// recalculate the total number of constraints given the new logrows
|
||||
let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows)
|
||||
.log2()
|
||||
.ceil() as u32;
|
||||
settings_mut.run_args.logrows =
|
||||
std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);
|
||||
|
||||
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
|
||||
// recalculate the logrows if there has been overflow on the constants
|
||||
settings_mut.run_args.logrows = std::cmp::max(
|
||||
settings_mut.run_args.logrows,
|
||||
settings_mut.constants_logrows(),
|
||||
);
|
||||
// recalculate the logrows if there has been overflow for the model constraints
|
||||
settings_mut.run_args.logrows = std::cmp::max(
|
||||
settings_mut.run_args.logrows,
|
||||
settings_mut.model_constraint_logrows(),
|
||||
);
|
||||
|
||||
debug!(
|
||||
"setting lookup_range to: {:?}, setting logrows to: {}",
|
||||
@@ -1184,12 +1174,29 @@ impl GraphCircuit {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool {
|
||||
let max_degree = self.settings().run_args.num_inner_cols + 2;
|
||||
let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector
|
||||
fn extended_k_is_small_enough(
|
||||
&self,
|
||||
k: u32,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i128,
|
||||
) -> bool {
|
||||
// if num cols is too large then the extended k is too large
|
||||
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS {
|
||||
return false;
|
||||
} else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS {
|
||||
return false;
|
||||
}
|
||||
|
||||
let max_degree = std::cmp::max(max_degree, max_lookup_degree);
|
||||
let mut settings = self.settings().clone();
|
||||
settings.run_args.lookup_range = safe_lookup_range;
|
||||
settings.run_args.logrows = k;
|
||||
settings.required_range_checks = vec![(0, max_range_size)];
|
||||
let mut cs = ConstraintSystem::default();
|
||||
Self::configure_with_params(&mut cs, settings);
|
||||
#[cfg(feature = "mv-lookup")]
|
||||
let cs = cs.chunk_lookups();
|
||||
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
|
||||
let max_degree = cs.degree();
|
||||
let quotient_poly_degree = (max_degree - 1) as u64;
|
||||
// n = 2^k
|
||||
let n = 1u64 << k;
|
||||
@@ -1204,29 +1211,13 @@ impl GraphCircuit {
|
||||
true
|
||||
}
|
||||
|
||||
/// Calibrate the circuit to the supplied data.
|
||||
pub fn calibrate_from_min_max(
|
||||
&mut self,
|
||||
min_max_lookup: Range,
|
||||
min_max_range_checks: Range,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: i128,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.calc_min_logrows(
|
||||
min_max_lookup,
|
||||
min_max_range_checks,
|
||||
max_logrows,
|
||||
lookup_safety_margin,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs the forward pass of the model / graph of computations and any associated hashing.
|
||||
pub fn forward(
|
||||
&self,
|
||||
inputs: &mut [Tensor<Fp>],
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&ParamsKZG<Bn256>>,
|
||||
throw_range_check_error: bool,
|
||||
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
|
||||
let original_inputs = inputs.to_vec();
|
||||
|
||||
@@ -1267,7 +1258,9 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
|
||||
let mut model_results =
|
||||
self.model()
|
||||
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
@@ -1310,8 +1303,7 @@ impl GraphCircuit {
|
||||
processed_outputs,
|
||||
max_lookup_inputs: model_results.max_lookup_inputs,
|
||||
min_lookup_inputs: model_results.min_lookup_inputs,
|
||||
max_range_check: model_results.max_range_check,
|
||||
min_range_check: model_results.min_range_check,
|
||||
max_range_size: model_results.max_range_size,
|
||||
};
|
||||
|
||||
witness.generate_rescaled_elements(
|
||||
@@ -1518,34 +1510,18 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
params.run_args.logrows as usize,
|
||||
);
|
||||
|
||||
let mut vars = ModelVars::new(
|
||||
cs,
|
||||
params.run_args.logrows as usize,
|
||||
params.total_assignments,
|
||||
params.run_args.num_inner_cols,
|
||||
params.total_const_size,
|
||||
params.module_requires_fixed(),
|
||||
);
|
||||
let mut vars = ModelVars::new(cs, ¶ms);
|
||||
|
||||
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
|
||||
|
||||
vars.instantiate_instance(
|
||||
cs,
|
||||
params.model_instance_shapes,
|
||||
params.model_instance_shapes.clone(),
|
||||
params.run_args.input_scale,
|
||||
module_configs.instance,
|
||||
);
|
||||
|
||||
let base = Model::configure(
|
||||
cs,
|
||||
&vars,
|
||||
params.run_args.lookup_range,
|
||||
params.run_args.logrows as usize,
|
||||
params.required_lookups,
|
||||
params.required_range_checks,
|
||||
params.check_mode,
|
||||
)
|
||||
.unwrap();
|
||||
let base = Model::configure(cs, &vars, ¶ms).unwrap();
|
||||
|
||||
let model_config = ModelConfig { base, vars };
|
||||
|
||||
|
||||
@@ -67,10 +67,8 @@ pub struct ForwardResult {
|
||||
pub max_lookup_inputs: i128,
|
||||
/// The minimum value of any input to a lookup operation.
|
||||
pub min_lookup_inputs: i128,
|
||||
/// The max range check value
|
||||
pub max_range_check: i128,
|
||||
/// The min range check value
|
||||
pub min_range_check: i128,
|
||||
/// The max range check size
|
||||
pub max_range_size: i128,
|
||||
}
|
||||
|
||||
impl From<DummyPassRes> for ForwardResult {
|
||||
@@ -79,8 +77,7 @@ impl From<DummyPassRes> for ForwardResult {
|
||||
outputs: res.outputs,
|
||||
max_lookup_inputs: res.max_lookup_inputs,
|
||||
min_lookup_inputs: res.min_lookup_inputs,
|
||||
min_range_check: res.min_range_check,
|
||||
max_range_check: res.max_range_check,
|
||||
max_range_size: res.max_range_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -102,6 +99,10 @@ pub type NodeGraph = BTreeMap<usize, NodeType>;
|
||||
pub struct DummyPassRes {
|
||||
/// number of rows use
|
||||
pub num_rows: usize,
|
||||
/// num dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
/// dynamic lookup col size
|
||||
pub dynamic_lookup_col_coord: usize,
|
||||
/// linear coordinate
|
||||
pub linear_coord: usize,
|
||||
/// total const size
|
||||
@@ -115,9 +116,7 @@ pub struct DummyPassRes {
|
||||
/// min lookup inputs
|
||||
pub min_lookup_inputs: i128,
|
||||
/// min range check
|
||||
pub min_range_check: i128,
|
||||
/// max range check
|
||||
pub max_range_check: i128,
|
||||
pub max_range_size: i128,
|
||||
/// outputs
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
}
|
||||
@@ -531,7 +530,7 @@ impl Model {
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs)?;
|
||||
let res = self.dummy_layout(run_args, &inputs, false)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -545,6 +544,8 @@ impl Model {
|
||||
required_range_checks: res.range_checks.into_iter().collect(),
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
num_dynamic_lookups: res.num_dynamic_lookups,
|
||||
total_dynamic_col_size: res.dynamic_lookup_col_coord,
|
||||
total_const_size: res.total_const_size,
|
||||
check_mode,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
@@ -570,12 +571,13 @@ impl Model {
|
||||
&self,
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
throw_range_check_error: bool,
|
||||
) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs)?;
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
@@ -1007,24 +1009,25 @@ impl Model {
|
||||
/// # Arguments
|
||||
/// * `meta` - The constraint system.
|
||||
/// * `vars` - The variables for the circuit.
|
||||
/// * `run_args` - [RunArgs]
|
||||
/// * `required_lookups` - The required lookup operations for the circuit.
|
||||
/// * `settings` - [GraphSettings]
|
||||
pub fn configure(
|
||||
meta: &mut ConstraintSystem<Fp>,
|
||||
vars: &ModelVars<Fp>,
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
required_lookups: Vec<LookupOp>,
|
||||
required_range_checks: Vec<Range>,
|
||||
check_mode: CheckMode,
|
||||
settings: &GraphSettings,
|
||||
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
|
||||
info!("configuring model");
|
||||
debug!("configuring model");
|
||||
|
||||
let lookup_range = settings.run_args.lookup_range;
|
||||
let logrows = settings.run_args.logrows as usize;
|
||||
let num_dynamic_lookups = settings.num_dynamic_lookups;
|
||||
let required_lookups = settings.required_lookups.clone();
|
||||
let required_range_checks = settings.required_range_checks.clone();
|
||||
|
||||
let mut base_gate = PolyConfig::configure(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
&vars.advices[2],
|
||||
check_mode,
|
||||
settings.check_mode,
|
||||
);
|
||||
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
|
||||
let input = &vars.advices[0];
|
||||
@@ -1038,6 +1041,14 @@ impl Model {
|
||||
base_gate.configure_range_check(meta, input, index, range, logrows)?;
|
||||
}
|
||||
|
||||
for _ in 0..num_dynamic_lookups {
|
||||
base_gate.configure_dynamic_lookup(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
vars.advices[3..5].try_into()?,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(base_gate)
|
||||
}
|
||||
|
||||
@@ -1356,6 +1367,7 @@ impl Model {
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
throw_range_check_error: bool,
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
@@ -1374,7 +1386,7 @@ impl Model {
|
||||
vars: ModelVars::new_dummy(),
|
||||
};
|
||||
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols);
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
|
||||
|
||||
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
|
||||
|
||||
@@ -1441,8 +1453,9 @@ impl Model {
|
||||
range_checks: region.used_range_checks(),
|
||||
max_lookup_inputs: region.max_lookup_inputs(),
|
||||
min_lookup_inputs: region.min_lookup_inputs(),
|
||||
min_range_check: region.min_range_check(),
|
||||
max_range_check: region.max_range_check(),
|
||||
max_range_size: region.max_range_size(),
|
||||
num_dynamic_lookups: region.dynamic_lookup_index(),
|
||||
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
|
||||
outputs,
|
||||
};
|
||||
|
||||
|
||||
@@ -734,7 +734,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
use_range_check_for_int: false,
|
||||
use_range_check_for_int: true,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -420,20 +420,31 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
|
||||
}
|
||||
|
||||
/// Allocate all columns that will be assigned to by a model.
|
||||
pub fn new(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
var_len: usize,
|
||||
num_inner_cols: usize,
|
||||
num_constants: usize,
|
||||
module_requires_fixed: bool,
|
||||
) -> Self {
|
||||
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
|
||||
debug!("number of blinding factors: {}", cs.blinding_factors());
|
||||
|
||||
let advices = (0..3)
|
||||
let logrows = params.run_args.logrows as usize;
|
||||
let var_len = params.total_assignments;
|
||||
let num_inner_cols = params.run_args.num_inner_cols;
|
||||
let num_constants = params.total_const_size;
|
||||
let module_requires_fixed = params.module_requires_fixed();
|
||||
let requires_dynamic_lookup = params.requires_dynamic_lookup();
|
||||
let dynamic_lookup_size = params.total_dynamic_col_size;
|
||||
|
||||
let mut advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
|
||||
.collect_vec();
|
||||
|
||||
if requires_dynamic_lookup {
|
||||
for _ in 0..2 {
|
||||
let dynamic_lookup = VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_size);
|
||||
if dynamic_lookup.num_blocks() > 1 {
|
||||
panic!("dynamic lookup should only have one block");
|
||||
};
|
||||
advices.push(dynamic_lookup);
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"model uses {} advice blocks (size={})",
|
||||
advices.iter().map(|v| v.num_blocks()).sum::<usize>(),
|
||||
|
||||
@@ -180,6 +180,9 @@ impl RunArgs {
|
||||
if self.num_inner_cols < 1 {
|
||||
return Err("num_inner_cols must be >= 1".into());
|
||||
}
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
return Err("tolerance > 0.0 requires output_visibility to be public".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -3773,6 +3773,30 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise inverse.
|
||||
/// # Arguments
|
||||
/// * `out_scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
|
||||
/// let k = 2_f64;
|
||||
/// let result = zero_recip(1.0);
|
||||
/// let expected = Tensor::<i128>::new(Some(&[4503599627370496]), &[1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn zero_recip(out_scale: f64) -> Tensor<i128> {
|
||||
let a = Tensor::<i128>::new(Some(&[0]), &[1]).unwrap();
|
||||
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rescaled = a_i as f64;
|
||||
let denom = (1_f64) / (rescaled + f64::EPSILON);
|
||||
let d_inv_x = out_scale * denom;
|
||||
Ok::<_, TensorError>(d_inv_x.round() as i128)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
|
||||
@@ -211,7 +211,7 @@ pub fn genWitness(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward(&mut input, None, None)
|
||||
.forward(&mut input, None, None, false)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#[cfg(test)]
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
@@ -276,7 +277,7 @@ mod native_tests {
|
||||
"bitshift",
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 48] = [
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
"1l_mlp",
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -325,8 +326,6 @@ mod native_tests {
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"decision_tree", // "variable_cnn",
|
||||
"random_forest",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
// "xgboost",
|
||||
@@ -586,6 +585,8 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -841,7 +842,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=47 {
|
||||
seq!(N in 0..=45 {
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
fn kzg_prove_and_verify_with_overflow_(test: &str) {
|
||||
@@ -1288,6 +1289,7 @@ mod native_tests {
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
example_name.clone(),
|
||||
@@ -1299,16 +1301,10 @@ mod native_tests {
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
tolerance,
|
||||
&mut tolerance,
|
||||
);
|
||||
|
||||
let settings =
|
||||
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
|
||||
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
|
||||
|
||||
if tolerance > 0.0 && !any_output_scales_smol {
|
||||
if tolerance > 0.0 {
|
||||
// load witness and shift the output by a small amount that is less than tolerance percent
|
||||
let witness = GraphWitness::from_path(
|
||||
format!("{}/{}/witness.json", test_dir, example_name).into(),
|
||||
@@ -1333,7 +1329,7 @@ mod native_tests {
|
||||
as i128,
|
||||
)
|
||||
};
|
||||
|
||||
|
||||
*v + perturbation
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
@@ -1444,7 +1440,7 @@ mod native_tests {
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
tolerance: f32,
|
||||
tolerance: &mut f32,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1502,6 +1498,24 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let mut settings =
|
||||
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
|
||||
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
|
||||
|
||||
if any_output_scales_smol {
|
||||
// set the tolerance to 0.0
|
||||
settings.run_args.tolerance = Tolerance {
|
||||
val: 0.0,
|
||||
scale: 0.0.into(),
|
||||
};
|
||||
settings
|
||||
.save(&format!("{}/{}/settings.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
*tolerance = 0.0;
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"compile-circuit",
|
||||
@@ -1559,7 +1573,7 @@ mod native_tests {
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
0.0,
|
||||
&mut 0.0,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -1819,7 +1833,7 @@ mod native_tests {
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
0.0,
|
||||
&mut 0.0,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -1921,7 +1935,7 @@ mod native_tests {
|
||||
None,
|
||||
2,
|
||||
false,
|
||||
0.0,
|
||||
&mut 0.0,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -2198,7 +2212,7 @@ mod native_tests {
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
0.0,
|
||||
&mut 0.0,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
@@ -91,9 +91,7 @@ def compare_outputs(zk_output, onnx_output):
|
||||
print("------- zk_output: ", list1_i)
|
||||
print("------- onnx_output: ", list2_i)
|
||||
|
||||
|
||||
|
||||
return np.mean(np.abs(res))
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -113,6 +111,9 @@ if __name__ == '__main__':
|
||||
onnx_output = get_onnx_output(model_file, input_file)
|
||||
# compare the outputs
|
||||
percentage_difference = compare_outputs(ezkl_output, onnx_output)
|
||||
mean_percentage_difference = np.mean(np.abs(percentage_difference))
|
||||
max_percentage_difference = np.max(np.abs(percentage_difference))
|
||||
# print the percentage difference
|
||||
print("mean percent diff: ", percentage_difference)
|
||||
assert percentage_difference < target, "Percentage difference is too high"
|
||||
print("mean percent diff: ", mean_percentage_difference)
|
||||
print("max percent diff: ", max_percentage_difference)
|
||||
assert mean_percentage_difference < target, "Percentage difference is too high"
|
||||
|
||||
Binary file not shown.
@@ -27,6 +27,8 @@
|
||||
"check_mode": "UNSAFE"
|
||||
},
|
||||
"num_rows": 16,
|
||||
"total_dynamic_col_size": 0,
|
||||
"num_dynamic_lookups": 0,
|
||||
"total_assignments": 32,
|
||||
"total_const_size": 8,
|
||||
"model_instance_shapes": [
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_check":0,"min_range_check":0}
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}
|
||||
Reference in New Issue
Block a user