Compare commits

...

4 Commits

Author SHA1 Message Date
dante
1656846d1a fix: transcript should serialize as lc flag (#726) 2024-02-26 22:02:47 +00:00
dante
88098b8190 fix!: cleanup felt serialization language in python and wasm (#724)
BREAKING CHANGE: python and wasm felt utilities have new names
2024-02-25 14:06:48 +00:00
dante
6c0c17c9be fix: include tol check in fwd pass (#723) 2024-02-23 01:28:59 +00:00
dante
bf69b16fc1 fix: rm optional bool flags (#722) 2024-02-21 12:45:42 +00:00
27 changed files with 645 additions and 459 deletions

View File

@@ -236,6 +236,8 @@ jobs:
with:
crate: cargo-nextest
locked: true
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
- name: kzg inputs
@@ -498,7 +500,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests, python-tests]
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -511,11 +513,11 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG )tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 8 -- --include-ignored
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests, python-tests]
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1

View File

@@ -309,7 +309,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
"print(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]))"
]
},
{
@@ -325,7 +325,7 @@
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider, utils\n",
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
@@ -338,7 +338,7 @@
"\n",
"def test_on_chain_data(res):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
" # Step 1: Prepare the data\n",
" # Step 2: Prepare and compile the contract.\n",
@@ -648,7 +648,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4
},

View File

@@ -695,7 +695,7 @@
"formatted_output = \"[\"\n",
"for i, value in enumerate(proof[\"instances\"]):\n",
" for j, field_element in enumerate(value):\n",
" onchain_input_array.append(ezkl.string_to_felt(field_element))\n",
" onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n",
" formatted_output += str(onchain_input_array[-1])\n",
" if j != len(value) - 1:\n",
" formatted_output += \", \"\n",
@@ -705,7 +705,7 @@
"# copy them over to remix and see if they verify\n",
"# What happens when you change a value?\n",
"print(\"pubInputs: \", formatted_output)\n",
"print(\"proof: \", \"0x\" + proof[\"proof\"])"
"print(\"proof: \", proof[\"proof\"])"
]
},
{

View File

@@ -122,8 +122,8 @@
"# Loop through each element in the y tensor\n",
"for e in y_input:\n",
" # Apply the custom function and append the result to the list\n",
" print(ezkl.float_to_string(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 7)])[0])\n",
" print(ezkl.float_to_felt(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_felt(e, 7)])[0])\n",
"\n",
"y = y.unsqueeze(0)\n",
"y = y.reshape(1, 9)\n",

View File

@@ -126,7 +126,7 @@
"# Loop through each element in the y tensor\n",
"for e in user_preimages:\n",
" # Apply the custom function and append the result to the list\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 0)])[0])\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_felt(e, 0)])[0])\n",
"\n",
"users_t = torch.tensor(user_preimages)\n",
"users_t = users_t.reshape(1, 6)\n",
@@ -303,7 +303,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_felt(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))"
]
},
@@ -417,7 +417,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_felt(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))\n"
]
},
@@ -510,7 +510,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -503,11 +503,11 @@
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
"\n",
"arrow_x = ezkl.string_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.string_to_float(witness['outputs'][0][1], out_scale)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][1], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)\n",
"arrow_x = ezkl.string_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.string_to_float(witness['outputs'][0][3], out_scale)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][3], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)"
]
}

View File

@@ -19,8 +19,8 @@ use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use crate::{
circuit::ops::base::BaseOp,
circuit::{
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
utils,
},
@@ -540,7 +540,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
&mut self,
cs: &mut ConstraintSystem<F>,
input: &VarTensor,
index: &VarTensor,
range: Range,
logrows: usize,
) -> Result<(), Box<dyn Error>>
where
F: Field,
@@ -556,7 +558,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
let range_check =
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
@@ -565,32 +567,60 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
let single_col_sel = cs.complex_selector();
let len = range_check.selector_constructor.degree;
let multi_col_selector = cs.complex_selector();
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(single_col_sel);
for (col_idx, input_col) in range_check.inputs.iter().enumerate() {
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(multi_col_selector);
let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let default_x = range_check.get_first_element();
let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};
let not_sel = Expression::Constant(F::ONE) - sel.clone();
let default_x = range_check.get_first_element(col_idx);
res.extend([(
sel.clone() * input_query.clone()
+ not_sel.clone() * Expression::Constant(default_x),
range_check.input,
)]);
let col_expr = sel.clone()
* range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
res
});
selectors.insert((range, x, y), single_col_sel);
let multiplier = range_check
.selector_constructor
.get_selector_val_at_idx(col_idx);
let not_expr = Expression::Constant(multiplier) - col_expr.clone();
res.extend([(
col_expr.clone() * input_query.clone()
+ not_expr.clone() * Expression::Constant(default_x),
*input_col,
)]);
log::trace!("---------------- col {:?} ------------------", col_idx,);
log::trace!("expr: {:?}", col_expr,);
log::trace!("multiplier: {:?}", multiplier);
log::trace!("not_expr: {:?}", not_expr);
log::trace!("default x: {:?}", default_x);
res
});
}
selectors.insert((range, x, y), multi_col_selector);
}
}
self.range_check_selectors.extend(selectors);
@@ -600,6 +630,11 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
self.lookup_input = input.clone();
}
if let VarTensor::Empty = self.lookup_index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
}
Ok(())
}

View File

@@ -20,7 +20,7 @@ use super::{
use crate::{
circuit::{
ops::base::BaseOp,
utils::{self},
utils::{self, F32},
},
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{
@@ -105,6 +105,8 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
BaseOp::Sub,
)?;
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
range_check(
config,
region,
@@ -2482,6 +2484,28 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
let is_dummy = region.is_dummy();
let table_index: ValTensor<F> = w
.get_inner_tensor()?
.par_enum_map(|_, e| {
Ok::<ValType<F>, TensorError>(if let Some(f) = e.get_felt_eval() {
let col_idx = if !is_dummy {
let table = config
.range_checks
.get(range)
.ok_or(TensorError::TableLookupError)?;
table.get_col_index(f)
} else {
F::ZERO
};
Value::known(col_idx).into()
} else {
Value::<F>::unknown().into()
})
})?
.into();
region.assign(&config.lookup_index, &table_index)?;
if !is_dummy {
(0..assigned_len)
.map(|i| {
@@ -2953,8 +2977,17 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
return enforce_equality(config, region, values);
}
let mut values = [values[0].clone(), values[1].clone()];
values[0] = region.assign(&config.inputs[0], &values[0])?;
values[1] = region.assign(&config.inputs[1], &values[1])?;
let total_assigned_0 = values[0].len();
let total_assigned_1 = values[1].len();
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
region.increment(total_assigned);
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let recip = nonlinearity(
@@ -2963,44 +2996,22 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
&[values[0].clone()],
&LookupOp::Recip {
input_scale: scale,
output_scale: scale,
// multiply by 100 to get the percent error
output_scale: F32(scale.0 * 100.0),
},
)?;
// Multiply the difference by the recip
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
let rebased_product = div(config, region, &[product], F::from(scale.0 as u64))?;
let scale_squared = scale.0 * scale.0;
let scaled_tol = (tol * scale.0) as i128;
// Use the greater than look up table to check if the percent error is within the tolerance for upper bound
let tol = tol / 100.0;
let upper_bound = nonlinearity(
// check that it is within the tolerance range
range_check(
config,
region,
&[product.clone()],
&LookupOp::GreaterThan {
a: utils::F32(tol * scale_squared),
},
)?;
// Negate the product
let neg_product = neg(config, region, &[product])?;
// Use the greater than look up table to check if the percent error is within the tolerance for lower bound
let lower_bound = nonlinearity(
config,
region,
&[neg_product],
&LookupOp::GreaterThan {
a: utils::F32(tol * scale_squared),
},
)?;
// Add the lower_bound and upper_bound
let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
// Constrain the sum to be all zeros
is_zero_identity(config, region, &[sum.clone()], false)?;
Ok(sum)
&[rebased_product],
&(-scaled_tol, scaled_tol),
)
}

View File

@@ -243,10 +243,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
LookupOp::LessThan { .. } => "LESS_THAN".into(),
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,

View File

@@ -70,6 +70,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
min_range_check: i128,
max_range_check: i128,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -93,6 +95,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
}
}
/// Create a new region context from a wrapped region
@@ -112,6 +116,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
}
}
@@ -130,6 +136,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
}
}
@@ -153,6 +161,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
used_range_checks,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
}
}
@@ -300,8 +310,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
return Err("update_max_min_lookup_range: invalid range".into());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
self.max_range_check = self.max_range_check.max(range.1);
self.min_range_check = self.min_range_check.min(range.0);
Ok(())
}
@@ -361,6 +371,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.min_lookup_inputs
}
/// min range check
pub fn min_range_check(&self) -> i128 {
self.min_range_check
}
/// max range check
pub fn max_range_check(&self) -> i128 {
self.max_range_check
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;

View File

@@ -130,14 +130,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
///
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
///
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
@@ -152,7 +152,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = Self::num_cols_required(range, col_size);
let num_cols = num_cols_required(range, col_size);
log::debug!("table range: {:?}", range);
@@ -265,7 +265,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
#[derive(Clone, Debug)]
pub struct RangeCheck<F: PrimeField> {
/// Input to table.
pub input: TableColumn,
pub inputs: Vec<TableColumn>,
/// col size
pub col_size: usize,
/// selector cn
pub selector_constructor: SelectorConstructor<F>,
/// Flags if table has been previously assigned to.
@@ -277,8 +279,10 @@ pub struct RangeCheck<F: PrimeField> {
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self) -> F {
i128_to_felt(self.range.0)
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
// we index from 1 to prevent soundness issues
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
}
///
@@ -290,24 +294,58 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
i128_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range) -> RangeCheck<F> {
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
let inputs = cs.lookup_table_column();
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required(range, col_size);
let inputs = {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
}
cols
};
let num_cols = inputs.len();
if num_cols > 1 {
warn!("Using {} columns for range-check.", num_cols);
}
RangeCheck {
input: inputs,
inputs,
col_size,
is_assigned: false,
selector_constructor: SelectorConstructor::new(2),
selector_constructor: SelectorConstructor::new(num_cols),
range,
_marker: PhantomData,
}
}
/// Take a linear coordinate and output the (column, row) position in the storage block.
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize) {
let x = linear_coord / self.col_size;
let y = linear_coord % self.col_size;
(x, y)
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
if self.is_assigned {
@@ -318,28 +356,43 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
layouter.assign_table(
|| "range check table",
|mut table| {
let _ = inputs
.iter()
.enumerate()
.map(|(row_offset, input)| {
table.assign_cell(
|| format!("rc_i_col row {}", row_offset),
self.input,
row_offset,
|| Value::known(*input),
)?;
let col_multipliers: Vec<F> = (0..chunked_inputs.len())
.map(|x| self.selector_constructor.get_selector_val_at_idx(x))
.collect();
let _ = chunked_inputs
.enumerate()
.map(|(chunk_idx, inputs)| {
layouter.assign_table(
|| "range check table",
|mut table| {
let _ = inputs
.iter()
.enumerate()
.map(|(mut row_offset, input)| {
let col_multiplier = col_multipliers[chunk_idx];
row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);
table.assign_cell(
|| format!("rc_i_col row {}", row_offset),
self.inputs[x],
y,
|| Value::known(*input * col_multiplier),
)?;
Ok(())
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
},
)?;
},
)
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
}
}

View File

@@ -2140,148 +2140,6 @@ mod matmul_relu {
}
}
#[cfg(test)]
mod rangecheckpercent {
use crate::circuit::Tolerance;
use crate::{circuit, tensor::Tensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const RANGE: f32 = 1.0; // 1 percent error tolerance
const K: usize = 18;
const LEN: usize = 1;
const SCALE: usize = i128::pow(2, 7) as usize;
use super::*;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
input: ValTensor<F>,
output: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let scale = utils::F32(SCALE as f32);
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// set up a new GreaterThan and Recip tables
let nl = &LookupOp::GreaterThan {
a: circuit::utils::F32((RANGE * SCALE.pow(2) as f32) / 100.0),
};
config
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl)
.unwrap();
config
.configure_lookup(
cs,
&b,
&output,
&a,
(-32768, 32768),
K,
&LookupOp::Recip {
input_scale: scale,
output_scale: scale,
},
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&[self.output.clone(), self.input.clone()],
Box::new(HybridOp::RangeCheck(Tolerance {
val: RANGE,
scale: SCALE.into(),
})),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
#[allow(clippy::assertions_on_constants)]
fn test_range_check_percent() {
// Successful cases
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(101_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(200_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(199_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
// Unsuccessful case
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(102_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
match prover.verify() {
Ok(_) => {
assert!(false)
}
Err(_) => {
assert!(true)
}
}
}
}
}
#[cfg(test)]
mod relu {
use super::*;

View File

@@ -88,6 +88,8 @@ pub const DEFAULT_VK_ABI: &str = "vk.abi";
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
@@ -371,9 +373,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
},
/// Generates a dummy SRS
@@ -469,7 +471,7 @@ pub enum Commands {
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
@@ -524,13 +526,13 @@ pub enum Commands {
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
@@ -732,7 +734,7 @@ pub enum Commands {
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: Option<bool>,
reduced_srs: bool,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {

View File

@@ -178,7 +178,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
scales,
scale_rebase_multiplier,
max_logrows,
div_rebasing,
only_range_check_rebase,
} => calibrate(
model,
data,
@@ -187,7 +187,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
only_range_check_rebase,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -488,8 +488,14 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let hash = sha256::digest(std::fs::read(path.clone())?);
let file = std::fs::File::open(path.clone())?;
let mut buffer = vec![];
let bytes_read = std::io::BufReader::new(file).read_to_end(&mut buffer)?;
debug!("read {} bytes from SRS file", bytes_read);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
@@ -789,7 +795,7 @@ pub(crate) fn calibrate(
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
div_rebasing: Option<bool>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -830,8 +836,8 @@ pub(crate) fn calibrate(
(10..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
vec![div_rebasing]
let div_rebasing = if only_range_check_rebase {
vec![false]
} else {
vec![true, false]
};
@@ -963,18 +969,34 @@ pub(crate) fn calibrate(
.max()
.unwrap_or(0);
let min_range_check = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.min_range_check)
.min()
.unwrap_or(0);
let max_range_check = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_check)
.max()
.unwrap_or(0);
let res = circuit.calibrate_from_min_max(
min_lookup_range,
max_lookup_range,
(min_lookup_range, max_lookup_range),
(min_range_check, max_range_check),
max_logrows,
lookup_safety_margin,
);
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
#[cfg(unix)]
std::mem::drop(_q);
// // drop the gag
// #[cfg(unix)]
// std::mem::drop(_r);
// #[cfg(unix)]
// std::mem::drop(_q);
if res.is_ok() {
let new_settings = circuit.settings().clone();
@@ -2044,16 +2066,12 @@ pub(crate) fn verify(
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
reduced_srs: Option<bool>,
reduced_srs: bool,
) -> Result<bool, Box<dyn Error>> {
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = if let Some(reduced_srs) = reduced_srs {
if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
}
let params = if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
};

View File

@@ -624,13 +624,13 @@ impl ToPyObject for DataSource {
}
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_string_montgomery;
use crate::pfsys::field_to_string;
#[cfg(feature = "python-bindings")]
impl ToPyObject for FileSourceInner {
fn to_object(&self, py: Python) -> PyObject {
match self {
FileSourceInner::Field(data) => field_to_string_montgomery(data).to_object(py),
FileSourceInner::Field(data) => field_to_string(data).to_object(py),
FileSourceInner::Bool(data) => data.to_object(py),
FileSourceInner::Float(data) => data.to_object(py),
}

View File

@@ -24,7 +24,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::table::{Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
@@ -56,7 +56,7 @@ pub use utilities::*;
pub use vars::*;
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_string_montgomery;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
@@ -171,6 +171,10 @@ pub struct GraphWitness {
pub max_lookup_inputs: i128,
/// max lookup input
pub min_lookup_inputs: i128,
/// max range check input
pub max_range_check: i128,
/// max range check input
pub min_range_check: i128,
}
impl GraphWitness {
@@ -198,6 +202,8 @@ impl GraphWitness {
processed_outputs: None,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_check: 0,
min_range_check: 0,
}
}
@@ -355,19 +361,25 @@ impl ToPyObject for GraphWitness {
let inputs: Vec<Vec<String>> = self
.inputs
.iter()
.map(|x| x.iter().map(field_to_string_montgomery).collect())
.map(|x| x.iter().map(field_to_string).collect())
.collect();
let outputs: Vec<Vec<String>> = self
.outputs
.iter()
.map(|x| x.iter().map(field_to_string_montgomery).collect())
.map(|x| x.iter().map(field_to_string).collect())
.collect();
dict.set_item("inputs", inputs).unwrap();
dict.set_item("outputs", outputs).unwrap();
dict.set_item("max_lookup_inputs", self.max_lookup_inputs)
.unwrap();
dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
.unwrap();
dict.set_item("max_range_check", self.max_range_check)
.unwrap();
dict.set_item("min_range_check", self.min_range_check)
.unwrap();
if let Some(processed_inputs) = &self.processed_inputs {
//poseidon_hash
@@ -409,10 +421,7 @@ impl ToPyObject for GraphWitness {
#[cfg(feature = "python-bindings")]
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
let poseidon_hash: Vec<String> = poseidon_hash
.iter()
.map(field_to_string_montgomery)
.collect();
let poseidon_hash: Vec<String> = poseidon_hash.iter().map(field_to_string).collect();
pydict.set_item("poseidon_hash", poseidon_hash)?;
Ok(())
@@ -1000,14 +1009,10 @@ impl GraphCircuit {
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
}
fn calc_safe_lookup_range(
min_lookup_inputs: i128,
max_lookup_inputs: i128,
lookup_safety_margin: i128,
) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let mut margin = (
lookup_safety_margin * min_lookup_inputs,
lookup_safety_margin * max_lookup_inputs,
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
if lookup_safety_margin == 1 {
margin.0 += 4;
@@ -1022,13 +1027,13 @@ impl GraphCircuit {
max_logrows as usize,
Self::reserved_blinding_rows() as usize,
);
Table::<Fp>::num_cols_required(safe_range, max_col_size)
num_cols_required(safe_range, max_col_size)
}
fn calc_min_logrows(
&mut self,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
min_max_lookup: Range,
min_max_range_checks: Range,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -1040,18 +1045,32 @@ impl GraphCircuit {
let reserved_blinding_rows = Self::reserved_blinding_rows();
// check if has overflowed max lookup input
if max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
if min_max_lookup.1.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_max_lookup.0.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
{
let err_string = format!("max lookup input ({}) is too large", max_lookup_inputs);
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
}
let safe_range = Self::calc_safe_lookup_range(
min_lookup_inputs,
max_lookup_inputs,
lookup_safety_margin,
);
if min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
|| min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
{
let err_string = format!(
"max range check input {:?} is too large",
min_max_range_checks
);
return Err(err_string.into());
}
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// pick the range with the largest absolute size between safe_lookup_range and min_max_range_checks
let safe_range = if (safe_lookup_range.1 - safe_lookup_range.0)
> (min_max_range_checks.1 - min_max_range_checks.0)
{
safe_lookup_range
} else {
min_max_range_checks
};
// degrade the max logrows until the extended k is small enough
while min_logrows < max_logrows
@@ -1135,7 +1154,7 @@ impl GraphCircuit {
let model = self.model().clone();
let settings_mut = self.settings_mut();
settings_mut.run_args.lookup_range = safe_range;
settings_mut.run_args.lookup_range = safe_lookup_range;
settings_mut.run_args.logrows = logrows as u32;
*settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
@@ -1188,14 +1207,14 @@ impl GraphCircuit {
/// Calibrate the circuit to the supplied data.
pub fn calibrate_from_min_max(
&mut self,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
min_max_lookup: Range,
min_max_range_checks: Range,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_lookup_inputs,
max_lookup_inputs,
min_max_lookup,
min_max_range_checks,
max_logrows,
lookup_safety_margin,
)?;
@@ -1248,7 +1267,7 @@ impl GraphCircuit {
}
}
let mut model_results = self.model().forward(inputs)?;
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1291,6 +1310,8 @@ impl GraphCircuit {
processed_outputs,
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_check: model_results.max_range_check,
min_range_check: model_results.min_range_check,
};
witness.generate_rescaled_elements(

View File

@@ -67,6 +67,10 @@ pub struct ForwardResult {
pub max_lookup_inputs: i128,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
/// The max range check value
pub max_range_check: i128,
/// The min range check value
pub min_range_check: i128,
}
impl From<DummyPassRes> for ForwardResult {
@@ -75,6 +79,8 @@ impl From<DummyPassRes> for ForwardResult {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
min_range_check: res.min_range_check,
max_range_check: res.max_range_check,
}
}
}
@@ -108,6 +114,10 @@ pub struct DummyPassRes {
pub max_lookup_inputs: i128,
/// min lookup inputs
pub min_lookup_inputs: i128,
/// min range check
pub min_range_check: i128,
/// max range check
pub max_range_check: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -556,12 +566,16 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
pub fn forward(
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
let res = self.dummy_layout(run_args, &valtensor_inputs)?;
Ok(res.into())
}
@@ -1021,7 +1035,7 @@ impl Model {
}
for range in required_range_checks {
base_gate.configure_range_check(meta, input, range)?;
base_gate.configure_range_check(meta, input, index, range, logrows)?;
}
Ok(base_gate)
@@ -1371,27 +1385,26 @@ impl Model {
ValType::Constant(Fp::ONE)
};
let comparator = outputs
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.map(|x| {
let mut v: ValTensor<Fp> =
vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
v.reshape(x.dims())?;
Ok(v)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
let _ = outputs
.iter()
.zip(comparator)
.map(|(o, c)| {
dummy_config.layout(
&mut region,
&[o.clone(), c],
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
&[output.clone(), comparator],
Box::new(HybridOp::RangeCheck(tolerance)),
)
})
.collect::<Result<Vec<_>, _>>()?;
.collect::<Result<Vec<_>, _>>();
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
@@ -1428,6 +1441,8 @@ impl Model {
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
min_range_check: region.min_range_check(),
max_range_check: region.max_range_check(),
outputs,
};

View File

@@ -197,7 +197,11 @@ impl std::fmt::Display for TranscriptType {
}
}
impl ToFlags for TranscriptType {}
impl ToFlags for TranscriptType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {
@@ -212,8 +216,8 @@ impl ToPyObject for TranscriptType {
#[cfg(feature = "python-bindings")]
///
pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
let g1affine_x = field_to_string_montgomery(&g1affine.x);
let g1affine_y = field_to_string_montgomery(&g1affine.y);
let g1affine_x = field_to_string(&g1affine.x);
let g1affine_y = field_to_string(&g1affine.y);
g1affine_dict.set_item("x", g1affine_x).unwrap();
g1affine_dict.set_item("y", g1affine_y).unwrap();
}
@@ -223,23 +227,23 @@ use halo2curves::bn256::G1;
#[cfg(feature = "python-bindings")]
///
pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) {
let g1_x = field_to_string_montgomery(&g1.x);
let g1_y = field_to_string_montgomery(&g1.y);
let g1_z = field_to_string_montgomery(&g1.z);
let g1_x = field_to_string(&g1.x);
let g1_y = field_to_string(&g1.y);
let g1_z = field_to_string(&g1.z);
g1_dict.set_item("x", g1_x).unwrap();
g1_dict.set_item("y", g1_y).unwrap();
g1_dict.set_item("z", g1_z).unwrap();
}
/// converts fp into `Vec<u64>` in Montgomery form
pub fn field_to_string_montgomery<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> String {
/// converts fp into a little endian Hex string
pub fn field_to_string<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> String {
let repr = serde_json::to_string(&fp).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
b
}
/// converts `Vec<u64>` in Montgomery form into fp
pub fn string_to_field_montgomery<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
/// converts a little endian Hex string into a field element
pub fn string_to_field<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
b: &String,
) -> F {
let repr = serde_json::to_string(&b).unwrap();
@@ -304,7 +308,7 @@ where
let field_elems: Vec<Vec<String>> = self
.instances
.iter()
.map(|x| x.iter().map(|fp| field_to_string_montgomery(fp)).collect())
.map(|x| x.iter().map(|fp| field_to_string(fp)).collect())
.collect::<Vec<_>>();
dict.set_item("instances", field_elems).unwrap();
let hex_proof = hex::encode(&self.proof);
@@ -728,12 +732,13 @@ where
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
VerifyingKey::<Scheme::Curve>::read::<_, C>(
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading verification key ✅");
Ok(vk)
}
/// Loads a [ProvingKey] at `path`.
@@ -750,12 +755,13 @@ where
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
ProvingKey::<Scheme::Curve>::read::<_, C>(
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading proving key ✅");
Ok(pk)
}
/// Saves a [ProvingKey] to `path`.
@@ -772,6 +778,7 @@ where
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving proving key ✅");
Ok(())
}
@@ -789,6 +796,7 @@ where
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving verification key ✅");
Ok(())
}

View File

@@ -63,9 +63,9 @@ struct PyG1 {
impl From<G1> for PyG1 {
fn from(g1: G1) -> Self {
PyG1 {
x: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.x),
y: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.y),
z: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.z),
x: crate::pfsys::field_to_string::<Fq>(&g1.x),
y: crate::pfsys::field_to_string::<Fq>(&g1.y),
z: crate::pfsys::field_to_string::<Fq>(&g1.z),
}
}
}
@@ -73,9 +73,9 @@ impl From<G1> for PyG1 {
impl From<PyG1> for G1 {
fn from(val: PyG1) -> Self {
G1 {
x: crate::pfsys::string_to_field_montgomery::<Fq>(&val.x),
y: crate::pfsys::string_to_field_montgomery::<Fq>(&val.y),
z: crate::pfsys::string_to_field_montgomery::<Fq>(&val.z),
x: crate::pfsys::string_to_field::<Fq>(&val.x),
y: crate::pfsys::string_to_field::<Fq>(&val.y),
z: crate::pfsys::string_to_field::<Fq>(&val.z),
}
}
}
@@ -106,8 +106,8 @@ pub struct PyG1Affine {
impl From<G1Affine> for PyG1Affine {
fn from(g1: G1Affine) -> Self {
PyG1Affine {
x: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.x),
y: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.y),
x: crate::pfsys::field_to_string::<Fq>(&g1.x),
y: crate::pfsys::field_to_string::<Fq>(&g1.y),
}
}
}
@@ -115,8 +115,8 @@ impl From<G1Affine> for PyG1Affine {
impl From<PyG1Affine> for G1Affine {
fn from(val: PyG1Affine) -> Self {
G1Affine {
x: crate::pfsys::string_to_field_montgomery::<Fq>(&val.x),
y: crate::pfsys::string_to_field_montgomery::<Fq>(&val.y),
x: crate::pfsys::string_to_field::<Fq>(&val.x),
y: crate::pfsys::string_to_field::<Fq>(&val.y),
}
}
}
@@ -217,53 +217,51 @@ impl Into<PyRunArgs> for RunArgs {
}
}
/// Converts 4 u64s to a field element
/// Converts a felt to big endian
#[pyfunction(signature = (
array,
felt,
))]
fn string_to_felt(array: PyFelt) -> PyResult<String> {
Ok(format!(
"{:?}",
crate::pfsys::string_to_field_montgomery::<Fr>(&array)
))
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
Ok(format!("{:?}", felt))
}
/// Converts 4 u64s representing a field element directly to an integer
/// Converts a field element hex string to an integer
#[pyfunction(signature = (
array,
))]
fn string_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field_montgomery::<Fr>(&array);
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
/// Converts a field eleement hex string to a floating point number
#[pyfunction(signature = (
array,
scale
))]
fn string_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field_montgomery::<Fr>(&array);
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
Ok(float_rep)
}
/// Converts a floating point element to 4 u64s representing a fixed point field element
/// Converts a floating point element to a field element hex string
#[pyfunction(signature = (
input,
scale
))]
fn float_to_string(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = i128_to_felt(int_rep);
Ok(crate::pfsys::field_to_string_montgomery::<Fr>(&felt))
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
/// Converts a buffer to vector of field elements
#[pyfunction(signature = (
buffer
))]
@@ -316,7 +314,10 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
.collect();
let field_elements: Vec<String> = field_elements.iter().map(|x| format!("{:?}", x)).collect();
let field_elements: Vec<String> = field_elements
.iter()
.map(|x| crate::pfsys::field_to_string::<Fr>(x))
.collect();
Ok(field_elements)
}
@@ -328,7 +329,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field_montgomery::<Fr>)
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let output =
@@ -339,7 +340,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let hash = output[0]
.iter()
.map(crate::pfsys::field_to_string_montgomery::<Fr>)
.map(crate::pfsys::field_to_string::<Fr>)
.collect::<Vec<_>>();
Ok(hash)
}
@@ -359,7 +360,7 @@ fn kzg_commit(
) -> PyResult<Vec<PyG1Affine>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field_montgomery::<Fr>)
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let settings = GraphSettings::load(&settings_path)
@@ -523,7 +524,7 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
div_rebasing = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
data: PathBuf,
@@ -534,7 +535,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
div_rebasing: Option<bool>,
only_range_check_rebase: bool,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -544,7 +545,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
only_range_check_rebase,
max_logrows,
)
.map_err(|e| {
@@ -687,14 +688,14 @@ fn prove(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_path=PathBuf::from(DEFAULT_VK),
srs_path=None,
non_reduced_srs=Some(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap()),
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
non_reduced_srs: Option<bool>,
non_reduced_srs: bool,
) -> Result<bool, PyErr> {
crate::execute::verify(
proof_path,
@@ -1103,13 +1104,13 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_function(wrap_pyfunction!(string_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(string_to_int, m)?)?;
m.add_function(wrap_pyfunction!(string_to_float, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
m.add_function(wrap_pyfunction!(float_to_string, m)?)?;
m.add_function(wrap_pyfunction!(float_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?;

View File

@@ -672,7 +672,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -690,7 +690,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -709,7 +709,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
if indices.is_empty() {
return Ok(());
} else {
return Err(TensorError::WrongMethod);
}
}
}
Ok(())

View File

@@ -69,19 +69,19 @@ pub fn encodeVerifierCalldata(
Ok(encoded)
}
/// Converts 4 u64s to a field element
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn stringToFelt(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(format!("{:?}", felt))
}
/// Converts 4 u64s representing a field element directly to an integer
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn stringToInt(
pub fn feltToInt(
array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
@@ -92,10 +92,10 @@ pub fn stringToInt(
))
}
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
/// Converts felts to a floating point element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn stringToFloat(
pub fn feltToFloat(
array: wasm_bindgen::Clamped<Vec<u8>>,
scale: crate::Scale,
) -> Result<f64, JsError> {
@@ -106,26 +106,26 @@ pub fn stringToFloat(
Ok(int_rep as f64 / multiplier)
}
/// Converts a floating point element to 4 u64s representing a fixed point field element
/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatTostring(
pub fn floatToFelt(
input: f64,
scale: crate::Scale,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = i128_to_felt(int_rep);
let vec = crate::pfsys::field_to_string_montgomery::<halo2curves::bn256::Fr>(&felt);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize string_montgomery{}", e)),
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
)?))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfstring(
pub fn bufferToVecOfFelt(
buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
// Convert the buffer to a slice

View File

@@ -2,6 +2,7 @@
#[cfg(test)]
mod native_tests {
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -477,6 +478,7 @@ mod native_tests {
use crate::native_tests::kzg_fuzz;
use crate::native_tests::render_circuit;
use crate::native_tests::model_serialization_different_binaries;
use rand::Rng;
use tempdir::TempDir;
#[test]
@@ -496,7 +498,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
test_dir.close().unwrap();
}
});
@@ -569,7 +571,18 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_tolerance_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// gen random number between 0.0 and 1.0
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
test_dir.close().unwrap();
}
@@ -580,7 +593,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -589,7 +602,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -598,7 +611,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -607,7 +620,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -616,7 +629,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -625,7 +638,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -634,7 +647,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None);
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -644,7 +657,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -654,7 +667,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -663,7 +676,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -673,7 +686,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None);
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -682,7 +695,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -692,7 +705,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None);
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -702,7 +715,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None);
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -712,7 +725,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -722,7 +735,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -732,7 +745,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -876,7 +889,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
});
@@ -1273,6 +1286,7 @@ mod native_tests {
batch_size: usize,
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
) {
gen_circuit_settings_and_witness(
test_dir,
@@ -1285,19 +1299,137 @@ mod native_tests {
scales_to_use,
2,
false,
tolerance,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
if tolerance > 0.0 && !any_output_scales_smol {
// load witness and shift the output by a small amount that is less than tolerance percent
let witness = GraphWitness::from_path(
format!("{}/{}/witness.json", test_dir, example_name).into(),
)
.unwrap();
let witness = witness.clone();
let outputs = witness.outputs.clone();
// get values as i128
let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
sv.iter()
.map(|v| {
// randomly perturb by a small amount less than tolerance
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::zero()
} else {
i128_to_felt(
(felt_to_i128(*v) as f32
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
as i128,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// get values as i128
let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
sv.iter()
.map(|v| {
// randomly perturb by a small amount less than tolerance
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::from(2)
} else {
i128_to_felt(
(felt_to_i128(*v) as f32
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
as i128,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let good_witness = GraphWitness {
outputs: output_perturbed_safe,
..witness.clone()
};
// save
good_witness
.save(format!("{}/{}/witness_ok.json", test_dir, example_name).into())
.unwrap();
let bad_witness = GraphWitness {
outputs: output_perturbed_bad,
..witness.clone()
};
// save
bad_witness
.save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
.unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"mock",
"-W",
format!("{}/{}/witness_ok.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"mock",
"-W",
format!("{}/{}/witness_bad.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(!status.success());
} else {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
}
#[allow(clippy::too_many_arguments)]
@@ -1312,6 +1444,7 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
div_rebasing: bool,
tolerance: f32,
) {
let mut args = vec![
"gen-settings".to_string(),
@@ -1326,6 +1459,7 @@ mod native_tests {
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
format!("--num-inner-cols={}", num_inner_columns),
format!("--tolerance={}", tolerance),
];
if div_rebasing {
@@ -1425,6 +1559,7 @@ mod native_tests {
None,
2,
div_rebasing,
0.0,
);
println!(
@@ -1684,6 +1819,7 @@ mod native_tests {
scales_to_use,
num_inner_columns,
false,
0.0,
);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
@@ -1765,7 +1901,7 @@ mod native_tests {
&format!("{}/{}/proof.pf", test_dir, example_name),
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
"--reduced-srs=true",
"--reduced-srs",
])
.status()
.expect("failed to execute process");
@@ -1785,6 +1921,7 @@ mod native_tests {
None,
2,
false,
0.0,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -2061,6 +2198,7 @@ mod native_tests {
Some(vec![4]),
1,
false,
0.0,
);
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);

View File

@@ -12,7 +12,7 @@ def get_ezkl_output(witness_file, settings_file):
outputs = witness_output['outputs']
with open(settings_file) as f:
settings = json.load(f)
ezkl_outputs = [[ezkl.string_to_float(
ezkl_outputs = [[ezkl.felt_to_float(
outputs[i][j], settings['model_output_scales'][i]) for j in range(len(outputs[i]))] for i in range(len(outputs))]
return ezkl_outputs

View File

@@ -56,9 +56,9 @@ def test_poseidon_hash():
Test for poseidon_hash
"""
message = [1.0, 2.0, 3.0, 4.0]
message = [ezkl.float_to_string(x, 7) for x in message]
message = [ezkl.float_to_felt(x, 7) for x in message]
res = ezkl.poseidon_hash(message)
assert ezkl.string_to_felt(
assert ezkl.felt_to_big_endian(
res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"
@@ -70,14 +70,14 @@ def test_field_serialization():
input = 890
scale = 7
felt = ezkl.float_to_string(input, scale)
roundtrip_input = ezkl.string_to_float(felt, scale)
felt = ezkl.float_to_felt(input, scale)
roundtrip_input = ezkl.felt_to_float(felt, scale)
assert input == roundtrip_input
input = -700
scale = 7
felt = ezkl.float_to_string(input, scale)
roundtrip_input = ezkl.string_to_float(felt, scale)
felt = ezkl.float_to_felt(input, scale)
roundtrip_input = ezkl.felt_to_float(felt, scale)
assert input == roundtrip_input
@@ -88,12 +88,12 @@ def test_buffer_to_felts():
buffer = bytearray("a sample string!", 'utf-8')
felts = ezkl.buffer_to_felts(buffer)
ref_felt_1 = "0x0000000000000000000000000000000021676e6972747320656c706d61732061"
assert felts == [ref_felt_1]
assert ezkl.felt_to_big_endian(felts[0]) == ref_felt_1
buffer = bytearray("a sample string!"+"high", 'utf-8')
felts = ezkl.buffer_to_felts(buffer)
ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968"
assert felts == [ref_felt_1, ref_felt_2]
assert [ezkl.felt_to_big_endian(felts[0]), ezkl.felt_to_big_endian(felts[1])] == [ref_felt_1, ref_felt_2]
def test_gen_srs():

View File

@@ -8,9 +8,9 @@ mod wasm32 {
use ezkl::graph::GraphWitness;
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
genWitness, inputValidation, pkValidation, poseidonHash, proofValidation, prove,
settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, genPk, genVk, genWitness, inputValidation, pkValidation,
poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
@@ -76,21 +76,21 @@ mod wasm32 {
for i in 0..32 {
let field_element = Fr::from(i);
let serialized = serde_json::to_vec(&field_element).unwrap();
let clamped = wasm_bindgen::Clamped(serialized);
let scale = 2;
let floating_point = stringToFloat(clamped.clone(), scale)
let floating_point = feltToFloat(clamped.clone(), scale)
.map_err(|_| "failed")
.unwrap();
assert_eq!(floating_point, (i as f64) / 4.0);
let integer: i128 = serde_json::from_slice(
&stringToInt(clamped.clone()).map_err(|_| "failed").unwrap(),
)
.unwrap();
let integer: i128 =
serde_json::from_slice(&feltToInt(clamped.clone()).map_err(|_| "failed").unwrap())
.unwrap();
assert_eq!(integer, i as i128);
let hex_string = format!("{:?}", field_element);
let returned_string = stringToFelt(clamped).map_err(|_| "failed").unwrap();
let returned_string: String = feltToBigEndian(clamped).map_err(|_| "failed").unwrap();
assert_eq!(hex_string, returned_string);
}
}
@@ -101,7 +101,7 @@ mod wasm32 {
let mut buffer = string_high.clone().into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -118,7 +118,7 @@ mod wasm32 {
let buffer = string_sample.clone().into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -133,7 +133,7 @@ mod wasm32 {
let buffer = string_concat.into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();

Binary file not shown.

View File

@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_check":0,"min_range_check":0}