mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ec8d13082 | ||
|
|
12735aefd4 |
@@ -568,10 +568,10 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let sorted = if is_assigned {
|
||||
input
|
||||
.get_int_evals()?
|
||||
.iter()
|
||||
.sorted_by(|a, b| a.cmp(b))
|
||||
let mut int_evals = input.get_int_evals()?;
|
||||
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
|
||||
int_evals
|
||||
.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
} else {
|
||||
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
|
||||
let table_len = table_0.len();
|
||||
|
||||
trace!("assigning tables took: {:?}", start.elapsed());
|
||||
|
||||
// now create a vartensor of constants for the dynamic lookup index
|
||||
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
|
||||
let _table_index =
|
||||
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
|
||||
|
||||
trace!("assigning table index took: {:?}", start.elapsed());
|
||||
|
||||
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
|
||||
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
|
||||
let lookup_len = lookup_0.len();
|
||||
|
||||
trace!("assigning lookups took: {:?}", start.elapsed());
|
||||
|
||||
// now set the lookup index
|
||||
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
|
||||
|
||||
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
|
||||
|
||||
trace!("assigning lookup index took: {:?}", start.elapsed());
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..table_len)
|
||||
.map(|i| {
|
||||
@@ -3251,11 +3259,15 @@ pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// get the max then subtract it
|
||||
let max_val = max(config, region, values)?;
|
||||
// rebase the input to 0
|
||||
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
|
||||
// elementwise exponential
|
||||
let ex = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values,
|
||||
&[sub],
|
||||
&LookupOp::Exp { scale: input_scale },
|
||||
)?;
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
///
|
||||
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
|
||||
self.assigned_constants.extend(constants.into_iter());
|
||||
self.assigned_constants.extend(constants);
|
||||
}
|
||||
|
||||
///
|
||||
@@ -389,7 +389,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
shuffle_index.update(&local_reg.shuffle_index);
|
||||
// update the constants
|
||||
let mut constants = constants.lock().unwrap();
|
||||
constants.extend(local_reg.assigned_constants.into_iter());
|
||||
constants.extend(local_reg.assigned_constants);
|
||||
|
||||
res
|
||||
})
|
||||
@@ -574,8 +574,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
let values_map = values.create_constants_map();
|
||||
self.assigned_constants.extend(values_map);
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -599,8 +601,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
let values_map = values.create_constants_map();
|
||||
self.assigned_constants.extend(values_map);
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -630,9 +634,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
let mut values_map = values.create_constants_map();
|
||||
|
||||
let inner_tensor = values.get_inner_tensor().unwrap();
|
||||
let mut values_map = values.create_constants_map();
|
||||
|
||||
for o in ommissions {
|
||||
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
|
||||
|
||||
@@ -911,7 +911,7 @@ pub(crate) fn calibrate(
|
||||
let model = Model::from_run_args(&settings.run_args, &model_path)?;
|
||||
|
||||
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
|
||||
debug!("num of calibration batches: {}", chunks.len());
|
||||
info!("num calibration batches: {}", chunks.len());
|
||||
|
||||
debug!("running onnx predictions...");
|
||||
let original_predictions = Model::run_onnx_predictions(
|
||||
|
||||
@@ -448,25 +448,39 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Returns the number of constants in the [ValTensor].
|
||||
pub fn num_constants(&self) -> usize {
|
||||
pub fn create_constants_map_iterator(
|
||||
&self,
|
||||
) -> core::iter::FilterMap<
|
||||
core::slice::Iter<'_, ValType<F>>,
|
||||
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
|
||||
> {
|
||||
match self {
|
||||
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
|
||||
ValTensor::Instance { .. } => 0,
|
||||
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
|
||||
if let ValType::Constant(v) = x {
|
||||
Some((*v, x.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
ValTensor::Instance { .. } => {
|
||||
unreachable!("Instance tensors do not have constants")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of constants in the [ValTensor].
|
||||
pub fn create_constants_map(&self) -> ConstantsMap<F> {
|
||||
match self {
|
||||
ValTensor::Value { inner, .. } => {
|
||||
let map = inner.iter().fold(ConstantsMap::new(), |mut acc, x| {
|
||||
if let ValType::Constant(c) = x {
|
||||
acc.insert(*c, x.clone());
|
||||
ValTensor::Value { inner, .. } => inner
|
||||
.par_iter()
|
||||
.filter_map(|x| {
|
||||
if let ValType::Constant(v) = x {
|
||||
Some((*v, x.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
acc
|
||||
});
|
||||
map
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
ValTensor::Instance { .. } => ConstantsMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
87
src/wasm.rs
87
src/wasm.rs
@@ -8,12 +8,14 @@ use crate::graph::quantize_float;
|
||||
use crate::graph::scale_to_multiplier;
|
||||
use crate::graph::{GraphCircuit, GraphSettings};
|
||||
use crate::pfsys::create_proof_circuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
use crate::pfsys::verify_proof_circuit;
|
||||
use crate::pfsys::TranscriptType;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::CheckMode;
|
||||
use crate::Commitments;
|
||||
use console_error_panic_hook;
|
||||
use halo2_proofs::plonk::*;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
|
||||
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
|
||||
@@ -33,11 +35,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField};
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
|
||||
use console_error_panic_hook;
|
||||
|
||||
#[cfg(feature = "web")]
|
||||
pub use wasm_bindgen_rayon::init_thread_pool;
|
||||
|
||||
@@ -395,6 +396,88 @@ pub fn verify(
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
/// Verify aggregate proof in browser using wasm
|
||||
pub fn verifyAggr(
|
||||
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
logrows: u64,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
commitment: &str,
|
||||
) -> Result<bool, JsError> {
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
(),
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
let result = match commit {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows),
|
||||
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows)
|
||||
}
|
||||
}
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
let strategy = IPASingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows),
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => Err(JsError::new(&format!("{}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prove in browser using wasm
|
||||
#[wasm_bindgen]
|
||||
pub fn prove(
|
||||
|
||||
Reference in New Issue
Block a user