Compare commits

...

2 Commits

Author SHA1 Message Date
github-actions[bot]
17ad9f76ce ci: update version string in docs 2025-03-25 19:32:33 +00:00
dante
70469e3bf9 chore: add min/max to gen-random-data (#960) 2025-03-25 19:32:15 +00:00
4 changed files with 49 additions and 22 deletions

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '21.0.3'
version = release

View File

@@ -1,34 +1,34 @@
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
@@ -962,6 +962,8 @@ fn gen_settings(
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
min=None,
max=None
))]
#[gen_stub_pyfunction]
fn gen_random_data(
@@ -969,8 +971,10 @@ fn gen_random_data(
output: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
min: Option<f32>,
max: Option<f32>,
) -> Result<bool, PyErr> {
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
crate::execute::gen_random_data(model, output, variables, seed, min, max).map_err(|e| {
let err_str = format!("Failed to generate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;

View File

@@ -443,6 +443,12 @@ pub enum Commands {
/// random seed for reproducibility (optional)
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
seed: u64,
/// min value for random data
#[arg(long, value_hint = clap::ValueHint::Other)]
min: Option<f32>,
/// max value for random data
#[arg(long, value_hint = clap::ValueHint::Other)]
max: Option<f32>,
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {

View File

@@ -45,6 +45,7 @@ use halo2curves::serde::SerdeObject;
use indicatif::{ProgressBar, ProgressStyle};
use instant::Instant;
use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
@@ -65,8 +66,6 @@ use thiserror::Error;
use tract_onnx::prelude::IntoTensor;
use tract_onnx::prelude::Tensor as TractTensor;
use lazy_static::lazy_static;
lazy_static! {
#[derive(Debug)]
/// The path to the ezkl related data.
@@ -138,11 +137,15 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
data,
variables,
seed,
min,
max,
} => gen_random_data(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
variables,
seed,
min,
max,
),
Commands::CalibrateSettings {
model,
@@ -840,6 +843,8 @@ pub(crate) fn gen_random_data(
data_path: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
min: Option<f32>,
max: Option<f32>,
) -> Result<String, EZKLError> {
let mut file = std::fs::File::open(&model_path).map_err(|e| {
crate::graph::errors::GraphError::ReadWriteFileError(
@@ -858,22 +863,32 @@ pub(crate) fn gen_random_data(
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
.map_err(|e| EZKLError::from(e.to_string()))?;
let min = min.unwrap_or(0.0);
let max = max.unwrap_or(1.0);
/// Generates a random tensor of a given size and type.
fn random(
sizes: &[usize],
datum_type: tract_onnx::prelude::DatumType,
seed: u64,
min: f32,
max: f32,
) -> TractTensor {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.r#gen());
slice.iter_mut().for_each(|x| *x = rng.gen_range(min..max));
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
fn tensor_for_fact(
fact: &tract_onnx::prelude::TypedFact,
seed: u64,
min: f32,
max: f32,
) -> TractTensor {
if let Some(value) = &fact.konst {
return value.clone().into_tensor();
}
@@ -884,12 +899,14 @@ pub(crate) fn gen_random_data(
.expect("Expected concrete shape, found: {fact:?}"),
fact.datum_type,
seed,
min,
max,
)
}
let generated = input_facts
.iter()
.map(|v| tensor_for_fact(v, seed))
.map(|v| tensor_for_fact(v, seed, min, max))
.collect_vec();
let data = GraphData::from_tract_data(&generated)?;