diff --git a/benches/accum_einsum_matmul.rs b/benches/accum_einsum_matmul.rs index 5682749e..0abe0fbb 100644 --- a/benches/accum_einsum_matmul.rs +++ b/benches/accum_einsum_matmul.rs @@ -6,13 +6,14 @@ use ezkl::circuit::einsum::analysis::analyze_einsum_usage; use ezkl::circuit::einsum::circuit_params::SingleEinsumParams; use ezkl::circuit::poly::PolyOp; use ezkl::circuit::*; -use ezkl::pfsys::create_keys; +use ezkl::pfsys::{create_keys, create_proof_circuit, TranscriptType}; use ezkl::pfsys::srs::gen_srs; use ezkl::tensor::*; use halo2_proofs::circuit::floor_planner::V1; use halo2_proofs::plonk::create_proof; use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; -use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK; +use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; +use halo2_proofs::poly::kzg::strategy::SingleStrategy; use halo2_proofs::transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}; use halo2_proofs::{ arithmetic::Field, @@ -23,6 +24,7 @@ use halo2curves::bn256::{Bn256, Fr}; use halo2curves::ff::PrimeField; use itertools::Itertools; use rand::rngs::OsRng; +use snark_verifier::system::halo2::transcript::evm::EvmTranscript; use std::collections::HashMap; use std::marker::PhantomData; @@ -48,13 +50,14 @@ impl Circuit for MyCircuit { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 2; unsafe { config .configure_einsums(cs, &analysis, num_einsum_inner_cols, K) .unwrap(); + let _constant = VarTensor::constant_cols(cs, K, 2, false); } config @@ -72,21 +75,7 @@ impl Circuit for MyCircuit { } fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let mut config = Self::Config::default(); - - let default_params = Self::Params::default(); - - let mut equations = HashMap::new(); - equations.insert(default_params.equation, default_params.input_axes_to_dims); - let analysis = analyze_einsum_usage(&equations).unwrap(); - let num_einsum_inner_cols = 1; - unsafe { - config - .configure_einsums(cs, &analysis, num_einsum_inner_cols, K) - .unwrap(); - } - - config + unimplemented!("call configure_with_params instead") } fn synthesize( @@ -133,11 +122,11 @@ fn runmatmul(c: &mut Criterion) { group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear)); group.sampling_mode(criterion::SamplingMode::Flat); group.sample_size(10); - let len = 512; + let len = 128; unsafe { LEN = len; } - for k in 19..20 { + for k in 15..16 { let params = unsafe { K = k; gen_srs::>(K as u32) @@ -170,24 +159,36 @@ fn runmatmul(c: &mut Criterion) { group.throughput(Throughput::Elements(len as u64)); group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| { b.iter(|| { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - - create_proof::, ProverSHPLONK<_>, _, _, _, _>( + let prover = create_proof_circuit::< + KZGCommitmentScheme<_>, + MyCircuit, + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + SingleStrategy<_>, + _, + EvmTranscript<_, _, _, _>, + EvmTranscript<_, _, _, _>, + >( + circuit.clone(), + vec![], ¶ms, &pk, - &[circuit.clone()], - &[&[]], - OsRng, - &mut transcript, - ) - .expect("proof generation should not fail"); - - transcript.finalize(); + CheckMode::UNSAFE, + ezkl::Commitments::KZG, + TranscriptType::EVM, + None, + None, + ); + prover.unwrap(); }); }); } group.finish(); } -criterion_group!(benches, runmatmul); +criterion_group! { + name = benches; + config = Criterion::default().with_plots(); + targets = runmatmul +} criterion_main!(benches); diff --git a/examples/accum_einsum_matmul.rs b/examples/accum_einsum_matmul.rs index b7fbe73e..961c2f3a 100644 --- a/examples/accum_einsum_matmul.rs +++ b/examples/accum_einsum_matmul.rs @@ -78,7 +78,7 @@ impl Circuit for MyCircuit { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 1; config @@ -99,20 +99,8 @@ impl Circuit for MyCircuit { .unwrap() } - fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let mut config = Self::Config::default(); - - let default_params = Self::Params::default(); - - let mut equations = HashMap::new(); - equations.insert(default_params.equation, default_params.input_axes_to_dims); - let analysis = analyze_einsum_usage(&equations).unwrap(); - let num_einsum_inner_cols = 1; - config - .configure_einsums(cs, &analysis, num_einsum_inner_cols, K) - .unwrap(); - - config + fn configure(_cs: &mut ConstraintSystem) -> Self::Config { + unimplemented!("call configure_with_params instead") } fn synthesize( diff --git a/examples/batch_mat_mul.rs b/examples/batch_mat_mul.rs index 17d09cf9..826ea90d 100644 --- a/examples/batch_mat_mul.rs +++ b/examples/batch_mat_mul.rs @@ -85,7 +85,7 @@ impl Circuit for MyCircuit { let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 1; config @@ -106,20 +106,8 @@ impl Circuit for MyCircuit { .unwrap() } - fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let mut config = Self::Config::default(); - - let default_params = Self::Params::default(); - - let mut equations = HashMap::new(); - equations.insert(default_params.equation, default_params.input_axes_to_dims); - let analysis = analyze_einsum_usage(&equations).unwrap(); - let num_einsum_inner_cols = 1; - config - .configure_einsums(cs, &analysis, num_einsum_inner_cols, K) - .unwrap(); - - config + fn configure(_cs: &mut ConstraintSystem) -> Self::Config { + unimplemented!("call configure_with_params instead") } fn synthesize( diff --git a/examples/tensor_contraction.rs b/examples/tensor_contraction.rs index a8ec031d..fa606a60 100644 --- a/examples/tensor_contraction.rs +++ b/examples/tensor_contraction.rs @@ -85,12 +85,13 @@ impl Circuit for MyCircuit { let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 2; config .configure_einsums(cs, &analysis, num_einsum_inner_cols, K) .unwrap(); + let _constant = VarTensor::constant_cols(cs, K, 2, false); config } @@ -106,20 +107,8 @@ impl Circuit for MyCircuit { .unwrap() } - fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let mut config = Self::Config::default(); - - let default_params = Self::Params::default(); - - let mut equations = HashMap::new(); - equations.insert(default_params.equation, default_params.input_axes_to_dims); - let analysis = analyze_einsum_usage(&equations).unwrap(); - let num_einsum_inner_cols = 1; - config - .configure_einsums(cs, &analysis, num_einsum_inner_cols, K) - .unwrap(); - - config + fn configure(_cs: &mut ConstraintSystem) -> Self::Config { + unimplemented!("call configure_with_params instead") } fn synthesize( diff --git a/src/circuit/ops/chip/einsum/analysis.rs b/src/circuit/ops/chip/einsum/analysis.rs index 1a268372..cfde03c2 100644 --- a/src/circuit/ops/chip/einsum/analysis.rs +++ b/src/circuit/ops/chip/einsum/analysis.rs @@ -47,7 +47,7 @@ pub struct SingleEquationAnalysis { /// pub fn analyze_einsum_usage( - equations: &HashMap>, + equations: &HashMap<(usize, String), HashMap>, ) -> Result { let mut max_num_inputs = 0; let mut max_input_size = 0; @@ -56,7 +56,7 @@ pub fn analyze_einsum_usage( let mut longest_challenge_vector = 0; let mut reduction_length = 0; - for (equation, input_axes_to_dim) in equations.iter() { + for ((_, equation), input_axes_to_dim) in equations.iter() { let analysis = analyze_single_equation(equation, input_axes_to_dim)?; max_input_size = max_input_size.max(analysis.max_input_size); longest_challenge_vector = longest_challenge_vector.max(analysis.longest_challenge_vector); @@ -126,7 +126,7 @@ pub fn analyze_single_equation( .iter() .map(|c| input_axes_to_dim.get(&c).unwrap()); let output_size = output_dims.clone().product(); - let longest_challenge_vector = *output_dims.clone().max().unwrap(); + let longest_challenge_vector = *output_dims.clone().max().unwrap_or(&0); let output_reduction_length = { let mut output_dims = output_dims.rev().cloned().collect_vec(); diff --git a/src/circuit/ops/chip/einsum/layouts.rs b/src/circuit/ops/chip/einsum/layouts.rs index 2f316bd9..9e56e301 100644 --- a/src/circuit/ops/chip/einsum/layouts.rs +++ b/src/circuit/ops/chip/einsum/layouts.rs @@ -1,4 +1,3 @@ -use halo2_proofs::circuit::Value; use halo2curves::ff::PrimeField; use log::{error, trace}; @@ -41,11 +40,12 @@ pub fn pairwise( region.flush_einsum()?; - let vars = config.get_vartensors(phases.as_slice().into()); + let input_vars = config.get_input_vars(phases.as_slice().into()); + let output_var = config.get_output_var(phases.as_slice().into()); let inputs = [lhs, rhs] .iter() - .zip(vars) + .zip(input_vars) .map(|(val, var)| { let res = region.assign_einsum(var, val)?; Ok(res.get_inner()?) @@ -66,13 +66,13 @@ pub fn pairwise( })?; let assigned_len = op_result.len(); - let mut output = region.assign_einsum(&config.output, &op_result.into())?; + let mut output = region.assign_einsum(output_var, &op_result.into())?; // Enable the selectors if !region.is_dummy() { (0..assigned_len) .map(|i| { - let (x, y, z) = config.output.cartesian_coord(region.einsum_col_coord() + i); + let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i); let op_info = BaseOpInfo { op_kind: op.clone(), input_phases: phases.as_slice().into(), @@ -107,15 +107,12 @@ pub fn sum( region.flush_einsum()?; let mut input = values[0].clone(); - let block_width = config.output.num_inner_cols(); + let block_width = config.block_width(); let assigned_len: usize; let input = { - // FIXME : should pad with constant zero but currently this incurs an error - // `NotEnoughColumnsForConstants` in halo2 because trying to assign constant - // value to advice column, how to workaround this issue? - input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?; - let var = config.get_vartensors([phase].as_slice().into())[0]; + input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?; + let var = config.get_input_vars([phase].as_slice().into())[0]; let (res, len) = region .assign_einsum_with_duplication_unconstrained(var, &input)?; assigned_len = len; @@ -125,8 +122,9 @@ pub fn sum( // Now we can assign the dot product let accumulated_sum = accumulated::sum(&input, block_width)?; + let output_var = config.get_output_var([phase].as_slice().into()); let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained( - &config.output, + output_var, &accumulated_sum.into(), check_mode, )?; @@ -134,9 +132,7 @@ pub fn sum( // enable the selectors if !region.is_dummy() { for i in 0..output_assigned_len { - let (x, _, z) = config - .output - .cartesian_coord(region.einsum_col_coord() + i * block_width); + let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width); // skip over duplicates at start of column if z == 0 && i > 0 { continue; @@ -176,15 +172,12 @@ pub fn prod( ) -> Result, CircuitError> { assert!(phase == 0 || phase == 1); region.flush_einsum()?; - let block_width = config.output.num_inner_cols(); + let block_width = config.block_width(); let assigned_len: usize; let input = { let mut input = values[0].clone(); - // FIXME : should pad with constant one but currently this incurs an error - // `NotEnoughColumnsForConstants` in halo2 because trying to assign constant - // value to advice column, how to workaround this issue? - input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ONE)))?; - let var = config.get_vartensors([phase].as_slice().into())[0]; + input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?; + let var = config.get_input_vars([phase].as_slice().into())[0]; let (res, len) = region .assign_einsum_with_duplication_unconstrained(var, &input)?; assigned_len = len; @@ -194,8 +187,9 @@ pub fn prod( // Now we can assign the dot product let accumulated_prod = accumulated::prod(&input, block_width)?; + let output_var = config.get_output_var([phase].as_slice().into()); let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained( - &config.output, + output_var, &accumulated_prod.into(), check_mode, )?; @@ -204,9 +198,7 @@ pub fn prod( if !region.is_dummy() { (0..output_assigned_len) .map(|i| { - let (x, _, z) = config - .output - .cartesian_coord(region.einsum_col_coord() + i * block_width); + let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width); // skip over duplicates at start of column if z == 0 && i > 0 { return Ok(()); @@ -259,17 +251,14 @@ pub fn dot( } else { [values[1].clone(), values[0].clone()] }; - let vars = config.get_vartensors(phases.as_slice().into()); + let vars = config.get_input_vars(phases.as_slice().into()); let mut inputs = vec![]; - let block_width = config.output.num_inner_cols(); + let block_width = config.block_width(); let mut assigned_len = 0; for (val, var) in values.iter_mut().zip(vars) { - // FIXME : should pad with constant zero but currently this incurs an error - // `NotEnoughColumnsForConstants` in halo2 because trying to assign constant - // value to advice column, how to workaround this issue? - val.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?; + val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?; let inp = { let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?; assigned_len = len; @@ -281,8 +270,9 @@ pub fn dot( // Now we can assign the dot product // time this step let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?; + let output_var = config.get_output_var(phases.as_slice().into()); let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained( - &config.output, + output_var, &accumulated_dot.into(), check_mode, ).expect("failed to assign einsum with duplication constrained"); @@ -291,9 +281,7 @@ pub fn dot( if !region.is_dummy() { (0..output_assigned_len) .map(|i| { - let (x, _, z) = config - .output - .cartesian_coord(region.einsum_col_coord() + i * block_width); + let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width); // hop over duplicates at start of column if z == 0 && i > 0 { return Ok(()); diff --git a/src/circuit/ops/chip/einsum/mod.rs b/src/circuit/ops/chip/einsum/mod.rs index 23992e30..67a25364 100644 --- a/src/circuit/ops/chip/einsum/mod.rs +++ b/src/circuit/ops/chip/einsum/mod.rs @@ -26,7 +26,7 @@ mod reduction_planner; #[derive(Clone, Debug, Default)] pub struct Einsums { /// custom gate to constrain tensor contractions - custom_gate: ContractionConfig, + contraction_gate: ContractionConfig, /// custom gate to constrain random linear combinations used by Freivalds' argument rlc_gates: Vec>, } @@ -35,17 +35,17 @@ impl Einsums { /// pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self { let dummy_var = VarTensor::dummy(col_size, num_inner_cols); - let dummy_custom_gate = ContractionConfig { + let dummy_contraction_gate = ContractionConfig { inputs: [ [dummy_var.clone(), dummy_var.clone()], [dummy_var.clone(), dummy_var.clone()], ], - output: dummy_var.clone(), + outputs: [dummy_var.clone(), dummy_var.clone()], selectors: BTreeMap::default(), _marker: PhantomData, }; Self { - custom_gate: dummy_custom_gate, + contraction_gate: dummy_contraction_gate, rlc_gates: vec![], } } @@ -72,21 +72,24 @@ impl Einsums { VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity), VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity), ]; - let output = VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity); - let custom_gate = ContractionConfig::new( + let outputs = [ + VarTensor::new_advice(meta, logrows, num_inner_cols, capacity), + VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity) + ]; + let contraction_gate = ContractionConfig::new( meta, &[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]], - &output, + &[&outputs[0], &outputs[1]], ); let mut rlc_gates = vec![]; for _ in 0..analysis.max_num_output_axes { - let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &output); + let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]); rlc_gates.push(rlc_gate); } Self { - custom_gate, + contraction_gate, rlc_gates, } } @@ -100,7 +103,7 @@ impl Einsums { equation: &str, check_mode: &CheckMode, ) -> Result<(), CircuitError> { - region.set_num_einsum_inner_cols(self.custom_gate.output.num_inner_cols()); + region.set_num_einsum_inner_cols(self.contraction_gate.block_width()); let (input_exprs, _) = equation.split_once("->").unwrap(); let input_exprs = input_exprs.split(",").collect_vec(); @@ -109,22 +112,14 @@ impl Einsums { let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec(); let mut output_tensor = output_tensor.clone(); - // Remove trivial axes from tensors - input_tensors - .iter_mut() - .map(|tensor| tensor.remove_trivial_axes()) - .collect::, TensorError>>()?; - output_tensor.remove_trivial_axes()?; - let mut input_axes_to_dim: HashMap = HashMap::new(); input_exprs .iter() .zip(input_tensors.iter()) .for_each(|(indices, tensor)| { - let tensor_dim = tensor.dims(); indices .chars() - .zip(tensor_dim.iter()) + .zip(tensor.dims()) .for_each(|(index, dim)| { if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dim.entry(index) @@ -134,6 +129,13 @@ impl Einsums { }); }); + // Remove trivial axes from tensors + input_tensors + .iter_mut() + .map(|tensor| tensor.remove_trivial_axes()) + .collect::, TensorError>>()?; + output_tensor.remove_trivial_axes()?; + let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?; let equation = equation_analysis.equation; @@ -234,7 +236,7 @@ impl Einsums { Some(axis) => { let dot_product_len = input_axes_to_dim[axis]; assign_input_contraction( - &self.custom_gate, + &self.contraction_gate, region, flattened_input_tensors, dot_product_len, @@ -245,7 +247,7 @@ impl Einsums { } None => { let mut result = assign_pairwise_mult( - &self.custom_gate, + &self.contraction_gate, region, flattened_input_tensors, input_phases, @@ -264,7 +266,7 @@ impl Einsums { .map(|t| t.get_inner_tensor().unwrap().get_scalar()) .collect_vec() .into(); - let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1, check_mode)?; + let squashed_input = prod(&self.contraction_gate, region, &[&scalars], 1, check_mode)?; region.constrain_equal(&squashed_input, &squashed_output) } @@ -404,20 +406,20 @@ struct BaseOpInfo { pub input_phases: InputPhases, } -/// `ContractionConfig` is the custom gate used for einsum contraction operations +/// `ContractionConfig` is the custom gate to constrain tensor contractions #[derive(Clone, Debug, Default)] struct ContractionConfig { - // [[phase 0, phase 0], [phase 1, phase 1]] + // [[first phase, first phase], [second phase, second phase]] inputs: [[VarTensor; 2]; 2], - // phase 1 - output: VarTensor, + // [first phase, second phase] + outputs: [VarTensor; 2], // (BaseOpInfo, block index, inner column index) -> selector selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>, _marker: PhantomData, } impl ContractionConfig { - fn get_vartensors(&self, input_phases: InputPhases) -> Vec<&VarTensor> { + fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> { match input_phases { InputPhases::FirstPhase => vec![&self.inputs[0][0]], InputPhases::SecondPhase => vec![&self.inputs[1][0]], @@ -427,19 +429,35 @@ impl ContractionConfig { } } + fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor { + match input_phases { + InputPhases::FirstPhase => &self.outputs[0], + InputPhases::SecondPhase => &self.outputs[1], + InputPhases::BothFirstPhase => &self.outputs[0], + InputPhases::Mixed => &self.outputs[1], + InputPhases::BothSecondPhase => &self.outputs[1], + } + } + + fn block_width(&self) -> usize { + self.outputs[0].num_inner_cols() + } + fn new( meta: &mut ConstraintSystem, inputs: &[&[&VarTensor; 2]; 2], - output: &VarTensor, + outputs: &[&VarTensor; 2], ) -> Self { let mut selectors = BTreeMap::new(); + let num_blocks = outputs[0].num_blocks(); + let block_width = outputs[0].num_inner_cols(); for input_phases in [ InputPhases::BothFirstPhase, InputPhases::Mixed, InputPhases::BothSecondPhase, ] { - for i in 0..output.num_blocks() { - for j in 0..output.num_inner_cols() { + for i in 0..num_blocks { + for j in 0..block_width { selectors.insert( ( BaseOpInfo { @@ -452,7 +470,7 @@ impl ContractionConfig { meta.selector(), ); } - for i in 0..output.num_blocks() { + for i in 0..num_blocks { selectors.insert( ( BaseOpInfo { @@ -483,7 +501,7 @@ impl ContractionConfig { InputPhases::FirstPhase, InputPhases::SecondPhase, ] { - for i in 0..output.num_blocks() { + for i in 0..num_blocks { selectors.insert( ( BaseOpInfo { @@ -538,6 +556,13 @@ impl ContractionConfig { InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]], InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]], }; + let output = match base_op.input_phases { + InputPhases::FirstPhase => outputs[0], + InputPhases::SecondPhase => outputs[1], + InputPhases::BothFirstPhase => outputs[0], + InputPhases::Mixed => outputs[1], + InputPhases::BothSecondPhase => outputs[1], + }; assert_eq!(inputs.len(), base_op.op_kind.num_inputs()); match base_op.op_kind { BaseOp::Mult => { @@ -549,7 +574,7 @@ impl ContractionConfig { for (q_i, input) in qis.iter_mut().zip(inputs) { *q_i = input .query_rng(meta, *block_idx, *inner_col_idx, 0, 1) - .expect("einsum op config: input query failed")[0] + .expect("contraction config: input query failed")[0] .clone() } // Get output expressions for each input channel @@ -557,7 +582,7 @@ impl ContractionConfig { let constraints = { let expected_output: Tensor> = output .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng) - .expect("einsum op config: output query failed"); + .expect("contraction config: output query failed"); let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone())); vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res] @@ -572,7 +597,7 @@ impl ContractionConfig { for (q_i, input) in qis.iter_mut().zip(inputs) { *q_i = input .query_whole_block(meta, *block_idx, 0, 1) - .expect("einsum op config: input query failed") + .expect("contraction config: input query failed") .into_iter() .collect() } @@ -580,7 +605,7 @@ impl ContractionConfig { let (rotation_offset, rng) = base_op.op_kind.query_offset_rng(); let expected_output: Tensor> = output .query_rng(meta, *block_idx, 0, rotation_offset, rng) - .expect("einsum op config: output query failed"); + .expect("contraction config: output query failed"); let res = base_op.op_kind.accum_f( expected_output[0].clone(), @@ -613,7 +638,7 @@ impl ContractionConfig { Self { inputs: [first_phase_inputs, second_phase_inputs], - output: output.clone(), + outputs: [outputs[0].clone(), outputs[1].clone()], selectors, _marker: PhantomData, } @@ -624,7 +649,7 @@ impl ContractionConfig { #[derive(Clone, Debug)] struct RLCConfig { pub challenge: Challenge, - /// [phase 0, phase 1] + /// [first phase, second phase] pub inputs: [VarTensor; 2], pub output: VarTensor, /// (phase of input, block index) -> (init selector, acc selector) @@ -731,7 +756,9 @@ impl RLCConfig { .zip(powers_of_challenge.iter().rev()) .map(|(v, c_power)| { c_power.and_then(|c_power| { - Value::known(c_power * v.get_felt_eval().unwrap()) + v.get_felt_eval().and_then(|v| { + Some(Value::known(c_power * v)) + }).unwrap_or(Value::unknown()) }) }) .reduce(|acc, v| acc + v) @@ -743,7 +770,7 @@ impl RLCConfig { let assigned_len = { let mut input: ValTensor = tensor.iter().collect_vec().into(); - input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?; + input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?; let (_, len) = region .assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?; len diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 5c472611..b844bb9d 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -831,17 +831,9 @@ pub fn einsum( inputs: &[&ValTensor], equation: &str, ) -> Result, CircuitError> { - // Track the einsum equation - region.add_used_einsum_equation(equation.to_string())?; - - // dispatch to freivalds' argument - if !config.einsums.challenges().is_empty() { - return freivalds(config, region, inputs, equation); - } - - let mut equation = equation.split("->"); - let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?; - let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?; + let mut eq = equation.split("->"); + let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?; + let output_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?; let inputs_eq = inputs_eq.split(',').collect::>(); // Check that the number of inputs matches the number of inputs in the equation @@ -864,6 +856,14 @@ pub fn einsum( } } + // Track the einsum equation + region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?; + + // dispatch to freivalds' argument + if !config.einsums.challenges().is_empty() { + return freivalds(config, region, inputs, equation); + } + // maps unrepresented indices in the output to a trivial 1 for c in output_eq.chars() { indices_to_size.entry(c).or_insert(1); @@ -1043,6 +1043,8 @@ pub fn einsum( let output: ValTensor = output.into(); + region.increment_einsum_index(1); + Ok(output) } @@ -1068,10 +1070,10 @@ pub fn freivalds( &config.check_mode, )?; - region.increment_einsum_index(1); - let output: ValTensor = output_tensor.into(); + region.increment_einsum_index(1); + Ok(output) } diff --git a/src/circuit/ops/region.rs b/src/circuit/ops/region.rs index f32acc2f..fb6fe8df 100644 --- a/src/circuit/ops/region.rs +++ b/src/circuit/ops/region.rs @@ -90,7 +90,8 @@ impl ShuffleIndex { pub struct EinsumIndex { index: usize, col_coord: usize, - equations: HashSet, + // (einsum index, einsum equation) -> (input axes to dimensions map) + equations: HashMap<(usize, String), HashMap>, num_inner_cols: usize, } @@ -100,7 +101,7 @@ impl EinsumIndex { EinsumIndex { index, col_coord, - equations: HashSet::new(), + equations: HashMap::new(), num_inner_cols, } } @@ -115,11 +116,6 @@ impl EinsumIndex { self.col_coord } - /// Get the equations - pub fn equations(&self) -> &HashSet { - &self.equations - } - /// update with another einsum index pub fn update(&mut self, other: &EinsumIndex) { self.index += other.index; @@ -605,8 +601,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a } /// add used einsum equation - pub fn add_used_einsum_equation(&mut self, equation: String) -> Result<(), CircuitError> { - self.einsum_index.equations.insert(equation); + pub fn add_used_einsum_equation( + &mut self, + equation: String, + input_axes_to_dims: &HashMap, + ) -> Result<(), CircuitError> { + self.einsum_index.equations.insert((self.einsum_index(), equation), input_axes_to_dims.clone()); Ok(()) } @@ -656,7 +656,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a } /// get used einsum equations - pub fn used_einsum_equations(&self) -> HashSet { + pub fn used_einsum_equations(&self) -> HashMap<(usize, String), HashMap> { self.einsum_index.equations.clone() } diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 59ceb3ee..9407314e 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -54,7 +54,7 @@ mod matmul { ) -> Self::Config { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 1; config @@ -175,7 +175,7 @@ mod matmul_col_overflow_double_col { ) -> Self::Config { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 1; config @@ -300,7 +300,7 @@ mod matmul_col_overflow { ) -> Self::Config { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 1; config @@ -417,11 +417,12 @@ mod matmul_col_ultra_overflow_double_col { ) -> Self::Config { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); config .configure_einsums(cs, &analysis, NUM_INNER_COLS, K) .unwrap(); + let _constant = VarTensor::constant_cols(cs, K, 2, false); config } @@ -577,7 +578,7 @@ mod matmul_col_ultra_overflow { ) -> Self::Config { let mut config = Self::Config::default(); let mut equations = HashMap::new(); - equations.insert(params.equation, params.input_axes_to_dims); + equations.insert((0, params.equation), params.input_axes_to_dims); let analysis = analyze_einsum_usage(&equations).unwrap(); let num_einsum_inner_cols = 1; config @@ -608,7 +609,6 @@ mod matmul_col_ultra_overflow { .iter() .map(|c| layouter.get_challenge(*c)) .collect_vec(); - println!("challenges: {:?}", challenges); layouter .assign_region( @@ -2537,7 +2537,7 @@ mod matmul_relu { #[derive(Clone)] struct MyCircuit { inputs: [ValTensor; 2], - _marker: PhantomData, + einsum_params: SingleEinsumParams, } // A columnar ReLu MLP @@ -2548,13 +2548,54 @@ mod matmul_relu { impl Circuit for MyCircuit { type Config = MyConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = TestParams; + type FloorPlanner = V1; + type Params = SingleEinsumParams; fn without_witnesses(&self) -> Self { self.clone() } + fn params(&self) -> Self::Params { + SingleEinsumParams::::new( + &self.einsum_params.equation, + &[ + &self.inputs[0].get_inner().unwrap(), + &self.inputs[1].get_inner().unwrap(), + ], + ) + .unwrap() + } + + fn configure_with_params( + cs: &mut ConstraintSystem, + params: Self::Params, + ) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + let mut base_config = + BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE); + + base_config + .configure_range_check(cs, &a, &b, (-1, 1), K) + .unwrap(); + + base_config + .configure_range_check(cs, &a, &b, (0, 1023), K) + .unwrap(); + + let mut equations = HashMap::new(); + equations.insert((0, params.equation), params.input_axes_to_dims); + let analysis = analyze_einsum_usage(&equations).unwrap(); + base_config + .configure_einsums(cs, &analysis, 1, K) + .unwrap(); + let _constant = VarTensor::constant_cols(cs, K, 8, false); + + MyConfig { base_config } + } + fn configure(cs: &mut ConstraintSystem) -> Self::Config { let a = VarTensor::new_advice(cs, K, 1, LEN); let b = VarTensor::new_advice(cs, K, 1, LEN); @@ -2585,12 +2626,26 @@ mod matmul_relu { .base_config .layout_range_checks(&mut layouter) .unwrap(); + let challenges = config + .base_config + .einsums + .challenges() + .iter() + .map(|c| layouter.get_challenge(*c)) + .collect_vec(); layouter.assign_region( || "", |region| { - let mut region = RegionCtx::new(region, 0, 1, 1024, 2); + let mut region = RegionCtx::new_with_challenges( + region, + 0, + 1, + 1024, + 2, + challenges.clone(), + ); let op = PolyOp::Einsum { - equation: "ij,jk->ik".to_string(), + equation: self.einsum_params.equation.clone(), }; let output = config .base_config @@ -2625,9 +2680,11 @@ mod matmul_relu { let mut b = Tensor::from((0..LEN).map(|_| Value::known(F::from(1)))); b.reshape(&[LEN, 1]).unwrap(); + let einsum_params = SingleEinsumParams::::new("ij,jk->ik", &[&a, &b]).unwrap(); + let circuit = MyCircuit { inputs: [ValTensor::from(a), ValTensor::from(b)], - _marker: PhantomData, + einsum_params, }; let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 4029a2fe..850209fb 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -64,6 +64,7 @@ use pyo3::types::PyDictMethods; use pyo3::IntoPyObject; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::ops::Deref; pub use utilities::*; pub use vars::*; @@ -442,7 +443,7 @@ pub struct ShuffleParams { /// Parameters for einsum operations pub struct EinsumParams { /// einsum equations - pub equations: Vec, + pub equations: Vec<(String, HashMap)>, /// total einsum column size pub total_einsum_col_size: usize, } diff --git a/src/graph/model.rs b/src/graph/model.rs index 937b0506..92a8f5c9 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -3,6 +3,7 @@ use super::extract_const_quantized_values; use super::node::*; use super::vars::*; use super::GraphSettings; +use crate::circuit::einsum::analysis::analyze_einsum_usage; use crate::circuit::hybrid::HybridOp; use crate::circuit::region::ConstantsMap; use crate::circuit::region::RegionCtx; @@ -37,7 +38,6 @@ use log::{debug, info, trace}; use serde::Deserialize; use serde::Serialize; use std::collections::BTreeMap; -#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use std::collections::HashMap; use std::collections::HashSet; use std::fs; @@ -1050,6 +1050,7 @@ impl Model { let lookup_range = settings.run_args.lookup_range; let logrows = settings.run_args.logrows as usize; + let num_inner_cols = settings.run_args.num_inner_cols; let required_lookups = settings.required_lookups.clone(); let required_range_checks = settings.required_range_checks.clone(); @@ -1098,6 +1099,25 @@ impl Model { )?; } + // Configures the circuit to use Freivalds' argument + // If some models get slow down, conditionally configure to use Freivalds' argument + let used_einsums: HashMap<(usize, String), HashMap> = settings + .einsum_params + .equations + .iter() + .enumerate() + .map(|(idx, (equation, indices_to_dims))| { + ((idx, equation.clone()), indices_to_dims.clone()) + }) + .collect(); + let analysis = analyze_einsum_usage(&used_einsums)?; + base_gate.configure_einsums( + meta, + &analysis, + num_inner_cols, + logrows + )?; + Ok(base_gate) } @@ -1150,17 +1170,26 @@ impl Model { let original_constants = constants.clone(); + let challenges = config + .base + .einsums + .challenges() + .iter() + .map(|c| layouter.get_challenge(*c)) + .collect_vec(); + let outputs = layouter.assign_region( || "model", |region| { - let mut thread_safe_region = RegionCtx::new_with_constants( + let mut thread_safe_region = RegionCtx::new_with_challenges( region, 0, run_args.num_inner_cols, run_args.decomp_base, run_args.decomp_legs, - original_constants.clone(), + challenges.clone(), ); + thread_safe_region.update_constants(original_constants.clone()); // we need to do this as this loop is called multiple times vars.set_instance_idx(instance_idx); @@ -1533,7 +1562,11 @@ impl Model { total_shuffle_col_size: region.shuffle_col_coord(), }, einsum_params: crate::graph::EinsumParams { - equations: region.used_einsum_equations().into_iter().collect(), + equations: region + .used_einsum_equations() + .into_iter() + .map(|((_, equation), axes_to_dims)| (equation, axes_to_dims)) + .collect(), total_einsum_col_size: region.einsum_col_coord(), }, total_const_size: region.total_constants(), diff --git a/src/lib.rs b/src/lib.rs index daf9e92c..2ac80347 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -171,7 +171,6 @@ pub mod pfsys; pub mod srs_sha; /// An implementation of multi-dimensional tensors. pub mod tensor; - #[cfg(feature = "ios-bindings")] uniffi::setup_scaffolding!(); #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]