Compare commits

..

2 Commits

Author SHA1 Message Date
dante
5389012b68 fix: patch large batch ex (#763) 2024-04-03 02:33:57 +01:00
dante
48223cca11 fix: make commitment optional for backwards compat (#762) 2024-04-03 02:26:50 +01:00
9 changed files with 64 additions and 45 deletions

View File

@@ -444,7 +444,7 @@ pub enum Commands {
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
Aggregate {
@@ -479,7 +479,7 @@ pub enum Commands {
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -726,7 +726,7 @@ pub enum Commands {
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl

View File

@@ -339,7 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
split_proofs,
disable_selector_compression,
commitment,
commitment.into(),
),
Commands::Aggregate {
proof_path,
@@ -360,7 +360,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
check_mode,
split_proofs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Verify {
@@ -384,7 +384,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
reduced_srs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
@@ -586,7 +586,7 @@ pub(crate) async fn get_srs_cmd(
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.commitment
settings.run_args.commitment.into()
} else {
return Err(err_string.into());
}
@@ -666,21 +666,17 @@ pub(crate) async fn gen_witness(
// if any of the settings have kzg visibility then we need to load the srs
let commitment: Commitments = settings.run_args.commitment.into();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(
settings.run_args.logrows,
srs_path.clone(),
settings.run_args.commitment,
)
.exists()
{
match settings.run_args.commitment {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
match Commitments::from(settings.run_args.commitment) {
Commitments::KZG => {
let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<KZGCommitmentScheme<_>>(
&mut input,
@@ -694,7 +690,7 @@ pub(crate) async fn gen_witness(
load_params_prover::<IPACommitmentScheme<G1Affine>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<IPACommitmentScheme<_>>(
&mut input,
@@ -1303,17 +1299,19 @@ pub(crate) fn create_evm_verifier(
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1347,17 +1345,18 @@ pub(crate) fn create_evm_vk(
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1626,8 +1625,9 @@ pub(crate) fn setup(
}
let logrows = circuit.settings().run_args.logrows;
let commitment: Commitments = circuit.settings().run_args.commitment.into();
let pk = match circuit.settings().run_args.commitment {
let pk = match commitment {
Commitments::KZG => {
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
@@ -1736,7 +1736,8 @@ pub(crate) fn prove(
let transcript: TranscriptType = proof_type.into();
let proof_split_commits: Option<ProofSplitCommit> = data.into();
let commitment = circuit_settings.run_args.commitment;
let commitment = circuit_settings.run_args.commitment.into();
let logrows = circuit_settings.run_args.logrows;
// creates and verifies the proof
let mut snark = match commitment {
Commitments::KZG => {
@@ -1745,7 +1746,7 @@ pub(crate) fn prove(
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
logrows,
Commitments::KZG,
)?;
match strategy {
@@ -2187,8 +2188,9 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
match circuit_settings.run_args.commitment {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let params: ParamsKZG<Bn256> = if reduced_srs {

View File

@@ -114,6 +114,12 @@ pub enum Commitments {
IPA,
}
impl From<Option<Commitments>> for Commitments {
fn from(value: Option<Commitments>) -> Self {
value.unwrap_or(Commitments::KZG)
}
}
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -214,7 +220,7 @@ pub struct RunArgs {
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
pub commitment: Commitments,
pub commitment: Option<Commitments>,
}
impl Default for RunArgs {
@@ -234,7 +240,7 @@ impl Default for RunArgs {
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: Commitments::KZG,
commitment: None,
}
}
}

View File

@@ -197,7 +197,7 @@ impl From<PyRunArgs> for RunArgs {
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: py_run_args.commitment.into(),
commitment: Some(py_run_args.commitment.into()),
}
}
}
@@ -234,6 +234,16 @@ pub enum PyCommitments {
IPA,
}
impl From<Option<Commitments>> for PyCommitments {
fn from(commitment: Option<Commitments>) -> Self {
match commitment {
Some(Commitments::KZG) => PyCommitments::KZG,
Some(Commitments::IPA) => PyCommitments::IPA,
None => PyCommitments::KZG,
}
}
}
impl From<PyCommitments> for Commitments {
fn from(py_commitments: PyCommitments) -> Self {
match py_commitments {

View File

@@ -1,3 +1,5 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use super::{
@@ -450,10 +452,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns the number of constants in the [ValTensor].
pub fn create_constants_map_iterator(
&self,
) -> core::iter::FilterMap<
core::slice::Iter<'_, ValType<F>>,
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
> {
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {

View File

@@ -337,8 +337,10 @@ pub fn verify(
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match circuit_settings.run_args.commitment {
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
@@ -521,8 +523,9 @@ pub fn prove(
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
// creates and verifies the proof
let proof = match circuit.settings().run_args.commitment {
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)

View File

@@ -122,7 +122,7 @@ mod native_tests {
let settings: GraphSettings = serde_json::from_str(&settings).unwrap();
let logrows = settings.run_args.logrows;
download_srs(logrows, settings.run_args.commitment);
download_srs(logrows, settings.run_args.commitment.into());
}
fn mv_test_(test_dir: &str, test: &str) {
@@ -623,7 +623,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
if test != "gather_nd" && test != "lstm_large" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
@@ -1971,7 +1971,7 @@ mod native_tests {
.expect("failed to parse settings file");
// get_srs for the graph_settings_num_instances
download_srs(1, graph_settings.run_args.commitment);
download_srs(1, graph_settings.run_args.commitment.into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([

Binary file not shown.

View File

@@ -24,8 +24,7 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE",
"commitment": "KZG"
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,