mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f7b4067223 | ||
|
|
c19fa5218a | ||
|
|
eb205d0c73 |
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '20.0.2'
|
||||
release = '20.0.4'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -597,7 +597,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
/// Arguments
|
||||
/// -------
|
||||
/// message: list[str]
|
||||
/// List of field elements represnted as strings
|
||||
/// List of field elements represented as strings
|
||||
///
|
||||
/// vk_path: str
|
||||
/// Path to the verification key
|
||||
@@ -656,7 +656,7 @@ fn kzg_commit(
|
||||
/// Arguments
|
||||
/// -------
|
||||
/// message: list[str]
|
||||
/// List of field elements represnted as strings
|
||||
/// List of field elements represented as strings
|
||||
///
|
||||
/// vk_path: str
|
||||
/// Path to the verification key
|
||||
@@ -1950,7 +1950,7 @@ fn deploy_da_evm(
|
||||
/// does the verifier use data attestation ?
|
||||
///
|
||||
/// addr_vk: str
|
||||
/// The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
|
||||
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
|
||||
/// Returns
|
||||
/// -------
|
||||
/// bool
|
||||
|
||||
@@ -156,25 +156,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
claimed_output.reshape(input_dims)?;
|
||||
// implicitly check if the prover provided output is within range
|
||||
let claimed_output = identity(config, region, &[claimed_output], true)?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if F::from_u128(IntegerRep::MAX as u128)
|
||||
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
|
||||
{
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[claimed_output.clone()])?;
|
||||
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), sign],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
|
||||
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
|
||||
enforce_equality(config, region, &[abs_value, comparison_unit])?;
|
||||
}
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
@@ -248,32 +229,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&[equal_zero_mask.clone(), equal_inverse_mask],
|
||||
)?;
|
||||
|
||||
let masked_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), not_equal_zero_mask.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if F::from_u128(IntegerRep::MAX as u128)
|
||||
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
|
||||
{
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[masked_output.clone()])?;
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), sign],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
|
||||
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
|
||||
enforce_equality(config, region, &[abs_value, comparison_unit])?;
|
||||
}
|
||||
|
||||
let err_func = |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>|
|
||||
@@ -349,7 +304,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
// force the output to be positive or zero, also implicitly checks that the ouput is in range
|
||||
// force the output to be positive or zero, also implicitly checks that the output is in range
|
||||
let claimed_output = abs(config, region, &[claimed_output.clone()])?;
|
||||
// rescaled input
|
||||
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;
|
||||
@@ -1841,7 +1796,7 @@ pub(crate) fn get_missing_set_elements<
|
||||
|
||||
// get the difference between the two vectors
|
||||
for eval in input_evals.iter() {
|
||||
// delete first occurence of that value
|
||||
// delete first occurrence of that value
|
||||
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
|
||||
fullset_evals.remove(pos);
|
||||
}
|
||||
@@ -1869,7 +1824,7 @@ pub(crate) fn get_missing_set_elements<
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// input and claimed output should be the shuffles of fullset
|
||||
// concatentate input and claimed output
|
||||
// concatenate input and claimed output
|
||||
let input_and_claimed_output = input.concat(claimed_output.clone())?;
|
||||
|
||||
// assert that this is a permutation/shuffle
|
||||
@@ -3396,7 +3351,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Performs a deconvolution on the given input tensor.
|
||||
/// # Examples
|
||||
/// ```
|
||||
// // expected ouputs are taken from pytorch torch.nn.functional.conv_transpose2d
|
||||
// // expected outputs are taken from pytorch torch.nn.functional.conv_transpose2d
|
||||
///
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
@@ -3624,7 +3579,7 @@ pub fn deconv<
|
||||
|
||||
/// Applies convolution over a ND tensor of shape C x H x D1...DN (and adds a bias).
|
||||
/// ```
|
||||
/// // expected ouputs are taken from pytorch torch.nn.functional.conv2d
|
||||
/// // expected outputs are taken from pytorch torch.nn.functional.conv2d
|
||||
///
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
@@ -3908,7 +3863,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
Ok(rescaled_inputs)
|
||||
}
|
||||
|
||||
/// Dummy (no contraints) reshape layout
|
||||
/// Dummy (no constraints) reshape layout
|
||||
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
new_dims: &[usize],
|
||||
@@ -3918,7 +3873,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
/// Dummy (no contraints) move_axis layout
|
||||
/// Dummy (no constraints) move_axis layout
|
||||
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
source: usize,
|
||||
|
||||
11
src/lib.rs
11
src/lib.rs
@@ -100,7 +100,6 @@ use std::str::FromStr;
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::Args;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::{Visibility, MAX_PUBLIC_SRS};
|
||||
use halo2_proofs::poly::{
|
||||
@@ -399,6 +398,16 @@ impl RunArgs {
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// check if the largest represented integer in the decomposed form overflows IntegerRep
|
||||
// try it with the largest possible value
|
||||
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
|
||||
if max_decomp.is_none() {
|
||||
errors.push(format!(
|
||||
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
|
||||
self.decomp_base, self.decomp_legs
|
||||
));
|
||||
}
|
||||
|
||||
// Visibility validations
|
||||
if self.param_visibility == Visibility::Public {
|
||||
errors.push(
|
||||
|
||||
@@ -337,7 +337,7 @@ mod wasm32 {
|
||||
// Run compiled circuit validation on onnx network (should fail)
|
||||
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK.to_vec()));
|
||||
assert!(circuit.is_err());
|
||||
// Run compiled circuit validation on comiled network (should pass)
|
||||
// Run compiled circuit validation on compiled network (should pass)
|
||||
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()));
|
||||
assert!(circuit.is_ok());
|
||||
// Run input validation on witness (should fail)
|
||||
|
||||
Reference in New Issue
Block a user