Compare commits

..

2 Commits

Author SHA1 Message Date
dante
4ec8d13082 chore: verify aggr in wasm (#758) 2024-03-29 23:28:20 +00:00
dante
12735aefd4 chore: reduce softmax recip DR (#756) 2024-03-27 01:14:29 +00:00
5 changed files with 139 additions and 27 deletions

View File

@@ -568,10 +568,10 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let is_assigned = !input.any_unknowns()?;
let sorted = if is_assigned {
input
.get_int_evals()?
.iter()
.sorted_by(|a, b| a.cmp(b))
let mut int_evals = input.get_int_evals()?;
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
int_evals
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
} else {
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
let table_len = table_0.len();
trace!("assigning tables took: {:?}", start.elapsed());
// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
trace!("assigning table index took: {:?}", start.elapsed());
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
trace!("assigning lookups took: {:?}", start.elapsed());
// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
trace!("assigning lookup index took: {:?}", start.elapsed());
if !region.is_dummy() {
(0..table_len)
.map(|i| {
@@ -3251,11 +3259,15 @@ pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// get the max then subtract it
let max_val = max(config, region, values)?;
// rebase the input to 0
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
// elementwise exponential
let ex = nonlinearity(
config,
region,
values,
&[sub],
&LookupOp::Exp { scale: input_scale },
)?;

View File

@@ -163,7 +163,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants.into_iter());
self.assigned_constants.extend(constants);
}
///
@@ -389,7 +389,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants.into_iter());
constants.extend(local_reg.assigned_constants);
res
})
@@ -574,8 +574,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -599,8 +601,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -630,9 +634,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let mut values_map = values.create_constants_map();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {

View File

@@ -911,7 +911,7 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
debug!("num of calibration batches: {}", chunks.len());
info!("num calibration batches: {}", chunks.len());
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(

View File

@@ -448,25 +448,39 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Returns the number of constants in the [ValTensor].
pub fn num_constants(&self) -> usize {
pub fn create_constants_map_iterator(
&self,
) -> core::iter::FilterMap<
core::slice::Iter<'_, ValType<F>>,
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
ValTensor::Instance { .. } => 0,
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
}),
ValTensor::Instance { .. } => {
unreachable!("Instance tensors do not have constants")
}
}
}
/// Returns the number of constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => {
let map = inner.iter().fold(ConstantsMap::new(), |mut acc, x| {
if let ValType::Constant(c) = x {
acc.insert(*c, x.clone());
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
acc
});
map
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}

View File

@@ -8,12 +8,14 @@ use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
@@ -33,11 +35,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
@@ -395,6 +396,88 @@ pub fn verify(
}
}
#[wasm_bindgen]
#[allow(non_snake_case)]
/// Verify aggregate proof in browser using wasm
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, logrows),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, logrows)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, logrows),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, logrows)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(JsError::new(&format!("{}", e))),
}
}
/// Prove in browser using wasm
#[wasm_bindgen]
pub fn prove(