Compare commits

..

3 Commits

Author SHA1 Message Date
github-actions[bot]
f7b4067223 ci: update version string in docs 2025-02-09 21:15:59 +00:00
dante
c19fa5218a refactor: enforce max decomp base/legs in args (#936) 2025-02-09 16:15:40 -05:00
rebustron
eb205d0c73 chore: fix typos in comments and docs (#934) 2025-02-08 19:13:17 -05:00
5 changed files with 22 additions and 58 deletions

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '20.0.2'
release = '20.0.4'
version = release

View File

@@ -597,7 +597,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
@@ -656,7 +656,7 @@ fn kzg_commit(
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
@@ -1950,7 +1950,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
/// The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool

View File

@@ -156,25 +156,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
claimed_output.reshape(input_dims)?;
// implicitly check if the prover provided output is within range
let claimed_output = identity(config, region, &[claimed_output], true)?;
// check if x is too large only if the decomp would support overflow in the previous op
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[claimed_output.clone()])?;
let abs_value = pairwise(
config,
region,
&[claimed_output.clone(), sign],
BaseOp::Mult,
)?;
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
// assert the result is 1
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
enforce_equality(config, region, &[abs_value, comparison_unit])?;
}
let product = pairwise(
config,
@@ -248,32 +229,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&[equal_zero_mask.clone(), equal_inverse_mask],
)?;
let masked_output = pairwise(
config,
region,
&[claimed_output.clone(), not_equal_zero_mask.clone()],
BaseOp::Mult,
)?;
// check if x is too large only if the decomp would support overflow in the previous op
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[masked_output.clone()])?;
let abs_value = pairwise(
config,
region,
&[claimed_output.clone(), sign],
BaseOp::Mult,
)?;
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
// assert the result is 1
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
enforce_equality(config, region, &[abs_value, comparison_unit])?;
}
let err_func = |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
x: &ValTensor<F>|
@@ -349,7 +304,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.into()
};
claimed_output.reshape(input_dims)?;
// force the output to be positive or zero, also implicitly checks that the ouput is in range
// force the output to be positive or zero, also implicitly checks that the output is in range
let claimed_output = abs(config, region, &[claimed_output.clone()])?;
// rescaled input
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;
@@ -1841,7 +1796,7 @@ pub(crate) fn get_missing_set_elements<
// get the difference between the two vectors
for eval in input_evals.iter() {
// delete first occurence of that value
// delete first occurrence of that value
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
fullset_evals.remove(pos);
}
@@ -1869,7 +1824,7 @@ pub(crate) fn get_missing_set_elements<
region.increment(claimed_output.len());
// input and claimed output should be the shuffles of fullset
// concatentate input and claimed output
// concatenate input and claimed output
let input_and_claimed_output = input.concat(claimed_output.clone())?;
// assert that this is a permutation/shuffle
@@ -3396,7 +3351,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Performs a deconvolution on the given input tensor.
/// # Examples
/// ```
// // expected ouputs are taken from pytorch torch.nn.functional.conv_transpose2d
// // expected outputs are taken from pytorch torch.nn.functional.conv_transpose2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
@@ -3624,7 +3579,7 @@ pub fn deconv<
/// Applies convolution over a ND tensor of shape C x H x D1...DN (and adds a bias).
/// ```
/// // expected ouputs are taken from pytorch torch.nn.functional.conv2d
/// // expected outputs are taken from pytorch torch.nn.functional.conv2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
@@ -3908,7 +3863,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
Ok(rescaled_inputs)
}
/// Dummy (no contraints) reshape layout
/// Dummy (no constraints) reshape layout
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
new_dims: &[usize],
@@ -3918,7 +3873,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
Ok(t)
}
/// Dummy (no contraints) move_axis layout
/// Dummy (no constraints) move_axis layout
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
source: usize,

View File

@@ -100,7 +100,6 @@ use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
@@ -399,6 +398,16 @@ impl RunArgs {
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();
// check if the largest represented integer in the decomposed form overflows IntegerRep
// try it with the largest possible value
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
if max_decomp.is_none() {
errors.push(format!(
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
self.decomp_base, self.decomp_legs
));
}
// Visibility validations
if self.param_visibility == Visibility::Public {
errors.push(

View File

@@ -337,7 +337,7 @@ mod wasm32 {
// Run compiled circuit validation on onnx network (should fail)
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK.to_vec()));
assert!(circuit.is_err());
// Run compiled circuit validation on comiled network (should pass)
// Run compiled circuit validation on compiled network (should pass)
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()));
assert!(circuit.is_ok());
// Run input validation on witness (should fail)