Use Freivalds' argument by default when configuring the graph circuit
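For context, here is a minimal sketch of the check that Freivalds' argument performs for a matmul C = A·B. This is an illustration only, not the ezkl API; the toy prime, the helper names, and the plain-u64 arithmetic are assumptions. The idea is that instead of re-computing every multiply-add of C inside the circuit, one samples a random challenge vector r and checks A·(B·r) = C·r, which needs only matrix-vector work and fails to catch a wrong C with probability on the order of 1/|field|.

```rust
// Minimal sketch of the Freivalds check (illustration only, not the ezkl API).
// Arithmetic is over a toy 64-bit prime; ezkl works over bn256's scalar field.
const P: u64 = 0xffff_ffff_0000_0001; // assumed toy prime for the example

fn mat_vec(m: &[Vec<u64>], v: &[u64]) -> Vec<u64> {
    m.iter()
        .map(|row| {
            row.iter()
                .zip(v)
                .fold(0u128, |acc, (a, b)| (acc + (*a as u128) * (*b as u128)) % P as u128)
                as u64
        })
        .collect()
}

/// Returns true if `c` is (with high probability) equal to `a * b`.
fn freivalds_check(a: &[Vec<u64>], b: &[Vec<u64>], c: &[Vec<u64>], r: &[u64]) -> bool {
    let br = mat_vec(b, r);    // B * r: one matrix-vector product
    let abr = mat_vec(a, &br); // A * (B * r)
    let cr = mat_vec(c, r);    // C * r
    abr == cr                  // wrong C slips through only with probability ~ 1/|field|
}
```

As I read the diff below, the challenge vector comes from second-phase challenges pulled out of the transcript, so the circuit only has to constrain the matrix-vector contractions and the random linear combinations; that is what the `ContractionConfig` and `RLCConfig` gates implement.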
@@ -6,13 +6,14 @@ use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::{create_keys, create_proof_circuit, TranscriptType};
use ezkl::pfsys::srs::gen_srs;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::plonk::create_proof;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer};
use halo2_proofs::{
arithmetic::Field,
@@ -23,6 +24,7 @@ use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::collections::HashMap;
use std::marker::PhantomData;
@@ -48,13 +50,14 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::default();

let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 2;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
}

config
@@ -72,21 +75,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
}

fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();

let default_params = Self::Params::default();

let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
}

config
unimplemented!("call configure_with_params instead")
}

fn synthesize(
@@ -133,11 +122,11 @@ fn runmatmul(c: &mut Criterion) {
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
group.sampling_mode(criterion::SamplingMode::Flat);
group.sample_size(10);
let len = 512;
let len = 128;
unsafe {
LEN = len;
}
for k in 19..20 {
for k in 15..16 {
let params = unsafe {
K = k;
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
@@ -170,24 +159,36 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
b.iter(|| {
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);

create_proof::<KZGCommitmentScheme<_>, ProverSHPLONK<_>, _, _, _, _>(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit<Fr>,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
&[circuit.clone()],
&[&[]],
OsRng,
&mut transcript,
)
.expect("proof generation should not fail");

transcript.finalize();
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}

criterion_group!(benches, runmatmul);
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runmatmul
}
criterion_main!(benches);
@@ -78,7 +78,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::default();

let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -99,20 +99,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
.unwrap()
}

fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();

let default_params = Self::Params::default();

let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();

config
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}

fn synthesize(
@@ -85,7 +85,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);

let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -106,20 +106,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
.unwrap()
}

fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();

let default_params = Self::Params::default();

let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();

config
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}

fn synthesize(
@@ -85,12 +85,13 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);

let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 2;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);

config
}
@@ -106,20 +107,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
.unwrap()
}

fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();

let default_params = Self::Params::default();

let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();

config
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}

fn synthesize(
@@ -47,7 +47,7 @@ pub struct SingleEquationAnalysis {

///
pub fn analyze_einsum_usage(
equations: &HashMap<String, HashMap<char, usize>>,
equations: &HashMap<(usize, String), HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
let mut max_num_inputs = 0;
let mut max_input_size = 0;
@@ -56,7 +56,7 @@ pub fn analyze_einsum_usage(
let mut longest_challenge_vector = 0;
let mut reduction_length = 0;

for (equation, input_axes_to_dim) in equations.iter() {
for ((_, equation), input_axes_to_dim) in equations.iter() {
let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
max_input_size = max_input_size.max(analysis.max_input_size);
longest_challenge_vector = longest_challenge_vector.max(analysis.longest_challenge_vector);
@@ -126,7 +126,7 @@ pub fn analyze_single_equation(
.iter()
.map(|c| input_axes_to_dim.get(&c).unwrap());
let output_size = output_dims.clone().product();
let longest_challenge_vector = *output_dims.clone().max().unwrap();
let longest_challenge_vector = *output_dims.clone().max().unwrap_or(&0);

let output_reduction_length = {
let mut output_dims = output_dims.rev().cloned().collect_vec();
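The signature change above keys each einsum occurrence by `(occurrence index, equation)` and attaches its axis-to-dimension map. A small standalone illustration of the new shape (the equations and dimensions here are made up; only the map type comes from the hunk above):

```rust
use std::collections::HashMap;

fn main() {
    // (einsum index, einsum equation) -> (input axes to dimensions map)
    let mut equations: HashMap<(usize, String), HashMap<char, usize>> = HashMap::new();
    // first matmul occurrence: "ij,jk->ik" with i=4, j=8, k=2
    equations.insert(
        (0, "ij,jk->ik".to_string()),
        HashMap::from([('i', 4), ('j', 8), ('k', 2)]),
    );
    // a second occurrence of the same equation with different shapes is kept separate
    equations.insert(
        (1, "ij,jk->ik".to_string()),
        HashMap::from([('i', 4), ('j', 2), ('k', 16)]),
    );
    // this is the map `analyze_einsum_usage` now consumes
    assert_eq!(equations.len(), 2);
}
```

Keeping duplicate equations with distinct shapes is what lets the analysis size the challenge vectors and reductions per occurrence rather than per unique equation string.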
@@ -1,4 +1,3 @@
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use log::{error, trace};
@@ -41,11 +40,12 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

region.flush_einsum()?;

let vars = config.get_vartensors(phases.as_slice().into());
let input_vars = config.get_input_vars(phases.as_slice().into());
let output_var = config.get_output_var(phases.as_slice().into());

let inputs = [lhs, rhs]
.iter()
.zip(vars)
.zip(input_vars)
.map(|(val, var)| {
let res = region.assign_einsum(var, val)?;
Ok(res.get_inner()?)
@@ -66,13 +66,13 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
})?;

let assigned_len = op_result.len();
let mut output = region.assign_einsum(&config.output, &op_result.into())?;
let mut output = region.assign_einsum(output_var, &op_result.into())?;

// Enable the selectors
if !region.is_dummy() {
(0..assigned_len)
.map(|i| {
let (x, y, z) = config.output.cartesian_coord(region.einsum_col_coord() + i);
let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
let op_info = BaseOpInfo {
op_kind: op.clone(),
input_phases: phases.as_slice().into(),
@@ -107,15 +107,12 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.flush_einsum()?;
let mut input = values[0].clone();

let block_width = config.output.num_inner_cols();
let block_width = config.block_width();

let assigned_len: usize;
let input = {
// FIXME : should pad with constant zero but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
let var = config.get_vartensors([phase].as_slice().into())[0];
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region
.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
@@ -125,8 +122,9 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
let accumulated_sum = accumulated::sum(&input, block_width)?;

let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
output_var,
&accumulated_sum.into(),
check_mode,
)?;
@@ -134,9 +132,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// enable the selectors
if !region.is_dummy() {
for i in 0..output_assigned_len {
let (x, _, z) = config
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
continue;
@@ -176,15 +172,12 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
let block_width = config.output.num_inner_cols();
let block_width = config.block_width();
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
// FIXME : should pad with constant one but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ONE)))?;
let var = config.get_vartensors([phase].as_slice().into())[0];
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region
.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
@@ -194,8 +187,9 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
let accumulated_prod = accumulated::prod(&input, block_width)?;

let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
output_var,
&accumulated_prod.into(),
check_mode,
)?;
@@ -204,9 +198,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) = config
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
@@ -259,17 +251,14 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
} else {
[values[1].clone(), values[0].clone()]
};
let vars = config.get_vartensors(phases.as_slice().into());
let vars = config.get_input_vars(phases.as_slice().into());

let mut inputs = vec![];
let block_width = config.output.num_inner_cols();
let block_width = config.block_width();

let mut assigned_len = 0;
for (val, var) in values.iter_mut().zip(vars) {
// FIXME : should pad with constant zero but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
val.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
assigned_len = len;
@@ -281,8 +270,9 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
// time this step
let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
let output_var = config.get_output_var(phases.as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
output_var,
&accumulated_dot.into(),
check_mode,
).expect("failed to assign einsum with duplication constrained");
@@ -291,9 +281,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) = config
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// hop over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
@@ -26,7 +26,7 @@ mod reduction_planner;
#[derive(Clone, Debug, Default)]
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
/// custom gate to constrain tensor contractions
custom_gate: ContractionConfig<F>,
contraction_gate: ContractionConfig<F>,
/// custom gate to constrain random linear combinations used by Freivalds' argument
rlc_gates: Vec<RLCConfig<F>>,
}
@@ -35,17 +35,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
///
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let dummy_custom_gate = ContractionConfig {
let dummy_contraction_gate = ContractionConfig {
inputs: [
[dummy_var.clone(), dummy_var.clone()],
[dummy_var.clone(), dummy_var.clone()],
],
output: dummy_var.clone(),
outputs: [dummy_var.clone(), dummy_var.clone()],
selectors: BTreeMap::default(),
_marker: PhantomData,
};
Self {
custom_gate: dummy_custom_gate,
contraction_gate: dummy_contraction_gate,
rlc_gates: vec![],
}
}
@@ -72,21 +72,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
];
let output = VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity);
let custom_gate = ContractionConfig::new(
let outputs = [
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity)
];
let contraction_gate = ContractionConfig::new(
meta,
&[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
&output,
&[&outputs[0], &outputs[1]],
);

let mut rlc_gates = vec![];
for _ in 0..analysis.max_num_output_axes {
let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &output);
let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]);
rlc_gates.push(rlc_gate);
}

Self {
custom_gate,
contraction_gate,
rlc_gates,
}
}
@@ -100,7 +103,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
equation: &str,
check_mode: &CheckMode,
) -> Result<(), CircuitError> {
region.set_num_einsum_inner_cols(self.custom_gate.output.num_inner_cols());
region.set_num_einsum_inner_cols(self.contraction_gate.block_width());

let (input_exprs, _) = equation.split_once("->").unwrap();
let input_exprs = input_exprs.split(",").collect_vec();
@@ -109,22 +112,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
let mut output_tensor = output_tensor.clone();

// Remove trivial axes from tensors
input_tensors
.iter_mut()
.map(|tensor| tensor.remove_trivial_axes())
.collect::<Result<Vec<_>, TensorError>>()?;
output_tensor.remove_trivial_axes()?;

let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
input_exprs
.iter()
.zip(input_tensors.iter())
.for_each(|(indices, tensor)| {
let tensor_dim = tensor.dims();
indices
.chars()
.zip(tensor_dim.iter())
.zip(tensor.dims())
.for_each(|(index, dim)| {
if let std::collections::hash_map::Entry::Vacant(e) =
input_axes_to_dim.entry(index)
@@ -134,6 +129,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
});
});

// Remove trivial axes from tensors
input_tensors
.iter_mut()
.map(|tensor| tensor.remove_trivial_axes())
.collect::<Result<Vec<_>, TensorError>>()?;
output_tensor.remove_trivial_axes()?;

let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?;
let equation = equation_analysis.equation;
@@ -234,7 +236,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
Some(axis) => {
let dot_product_len = input_axes_to_dim[axis];
assign_input_contraction(
&self.custom_gate,
&self.contraction_gate,
region,
flattened_input_tensors,
dot_product_len,
@@ -245,7 +247,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
}
None => {
let mut result = assign_pairwise_mult(
&self.custom_gate,
&self.contraction_gate,
region,
flattened_input_tensors,
input_phases,
@@ -264,7 +266,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
.map(|t| t.get_inner_tensor().unwrap().get_scalar())
.collect_vec()
.into();
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1, check_mode)?;
let squashed_input = prod(&self.contraction_gate, region, &[&scalars], 1, check_mode)?;

region.constrain_equal(&squashed_input, &squashed_output)
}
@@ -404,20 +406,20 @@ struct BaseOpInfo {
pub input_phases: InputPhases,
}

/// `ContractionConfig` is the custom gate used for einsum contraction operations
/// `ContractionConfig` is the custom gate to constrain tensor contractions
#[derive(Clone, Debug, Default)]
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
// [[phase 0, phase 0], [phase 1, phase 1]]
// [[first phase, first phase], [second phase, second phase]]
inputs: [[VarTensor; 2]; 2],
// phase 1
output: VarTensor,
// [first phase, second phase]
outputs: [VarTensor; 2],
// (BaseOpInfo, block index, inner column index) -> selector
selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
_marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
fn get_vartensors(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
match input_phases {
InputPhases::FirstPhase => vec![&self.inputs[0][0]],
InputPhases::SecondPhase => vec![&self.inputs[1][0]],
@@ -427,19 +429,35 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
}
}

fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor {
match input_phases {
InputPhases::FirstPhase => &self.outputs[0],
InputPhases::SecondPhase => &self.outputs[1],
InputPhases::BothFirstPhase => &self.outputs[0],
InputPhases::Mixed => &self.outputs[1],
InputPhases::BothSecondPhase => &self.outputs[1],
}
}

fn block_width(&self) -> usize {
self.outputs[0].num_inner_cols()
}
fn new(
meta: &mut ConstraintSystem<F>,
inputs: &[&[&VarTensor; 2]; 2],
output: &VarTensor,
outputs: &[&VarTensor; 2],
) -> Self {
let mut selectors = BTreeMap::new();
let num_blocks = outputs[0].num_blocks();
let block_width = outputs[0].num_inner_cols();
for input_phases in [
InputPhases::BothFirstPhase,
InputPhases::Mixed,
InputPhases::BothSecondPhase,
] {
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
for i in 0..num_blocks {
for j in 0..block_width {
selectors.insert(
(
BaseOpInfo {
@@ -452,7 +470,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
meta.selector(),
);
}
for i in 0..output.num_blocks() {
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
@@ -483,7 +501,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
InputPhases::FirstPhase,
InputPhases::SecondPhase,
] {
for i in 0..output.num_blocks() {
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
@@ -538,6 +556,13 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
};
let output = match base_op.input_phases {
InputPhases::FirstPhase => outputs[0],
InputPhases::SecondPhase => outputs[1],
InputPhases::BothFirstPhase => outputs[0],
InputPhases::Mixed => outputs[1],
InputPhases::BothSecondPhase => outputs[1],
};
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
match base_op.op_kind {
BaseOp::Mult => {
@@ -549,7 +574,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("einsum op config: input query failed")[0]
.expect("contraction config: input query failed")[0]
.clone()
}
// Get output expressions for each input channel
@@ -557,7 +582,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("einsum op config: output query failed");
.expect("contraction config: output query failed");

let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
@@ -572,7 +597,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_whole_block(meta, *block_idx, 0, 1)
.expect("einsum op config: input query failed")
.expect("contraction config: input query failed")
.into_iter()
.collect()
}
@@ -580,7 +605,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
.expect("einsum op config: output query failed");
.expect("contraction config: output query failed");

let res = base_op.op_kind.accum_f(
expected_output[0].clone(),
@@ -613,7 +638,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {

Self {
inputs: [first_phase_inputs, second_phase_inputs],
output: output.clone(),
outputs: [outputs[0].clone(), outputs[1].clone()],
selectors,
_marker: PhantomData,
}
@@ -624,7 +649,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
pub challenge: Challenge,
/// [phase 0, phase 1]
/// [first phase, second phase]
pub inputs: [VarTensor; 2],
pub output: VarTensor,
/// (phase of input, block index) -> (init selector, acc selector)
@@ -731,7 +756,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
.zip(powers_of_challenge.iter().rev())
.map(|(v, c_power)| {
c_power.and_then(|c_power| {
Value::known(c_power * v.get_felt_eval().unwrap())
v.get_felt_eval().and_then(|v| {
Some(Value::known(c_power * v))
}).unwrap_or(Value::unknown())
})
})
.reduce(|acc, v| acc + v)
@@ -743,7 +770,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {

let assigned_len = {
let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (_, len) = region
.assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
len
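The value the RLC gates witness above is a random linear combination of a tensor's entries with powers of a transcript challenge. A minimal sketch of that quantity, written with Horner's rule (an assumption here: an ff 0.13-style `Field` trait with an associated `ZERO`, as re-exported by halo2curves; the ordering of the powers in the diff is reversed relative to this, which does not affect the argument):

```rust
use halo2curves::ff::Field;

/// values[0] + values[1] * c + values[2] * c^2 + ... evaluated via Horner's rule.
fn rlc<F: Field>(values: &[F], challenge: F) -> F {
    values
        .iter()
        .rev()
        .fold(F::ZERO, |acc, v| acc * challenge + *v)
}
```

Contracting each output axis of C = A·B against such vectors of challenge powers is what reduces the full matmul constraint to the matrix-vector contractions handled by `ContractionConfig`.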
@@ -831,17 +831,9 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
// Track the einsum equation
region.add_used_einsum_equation(equation.to_string())?;

// dispatch to freivalds' argument
if !config.einsums.challenges().is_empty() {
return freivalds(config, region, inputs, equation);
}

let mut equation = equation.split("->");
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let output_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

// Check that the number of inputs matches the number of inputs in the equation
@@ -864,6 +856,14 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
}
}

// Track the einsum equation
region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?;

// dispatch to freivalds' argument
if !config.einsums.challenges().is_empty() {
return freivalds(config, region, inputs, equation);
}

// maps unrepresented indices in the output to a trivial 1
for c in output_eq.chars() {
indices_to_size.entry(c).or_insert(1);
@@ -1043,6 +1043,8 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let output: ValTensor<F> = output.into();

region.increment_einsum_index(1);

Ok(output)
}

@@ -1068,10 +1070,10 @@ pub fn freivalds<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&config.check_mode,
)?;

region.increment_einsum_index(1);

let output: ValTensor<F> = output_tensor.into();

region.increment_einsum_index(1);

Ok(output)
}
@@ -90,7 +90,8 @@ impl ShuffleIndex {
pub struct EinsumIndex {
index: usize,
col_coord: usize,
equations: HashSet<String>,
// (einsum index, einsum equation) -> (input axes to dimensions map)
equations: HashMap<(usize, String), HashMap<char, usize>>,
num_inner_cols: usize,
}

@@ -100,7 +101,7 @@ impl EinsumIndex {
EinsumIndex {
index,
col_coord,
equations: HashSet::new(),
equations: HashMap::new(),
num_inner_cols,
}
}
@@ -115,11 +116,6 @@ impl EinsumIndex {
self.col_coord
}

/// Get the equations
pub fn equations(&self) -> &HashSet<String> {
&self.equations
}

/// update with another einsum index
pub fn update(&mut self, other: &EinsumIndex) {
self.index += other.index;
@@ -605,8 +601,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}

/// add used einsum equation
pub fn add_used_einsum_equation(&mut self, equation: String) -> Result<(), CircuitError> {
self.einsum_index.equations.insert(equation);
pub fn add_used_einsum_equation(
&mut self,
equation: String,
input_axes_to_dims: &HashMap<char, usize>,
) -> Result<(), CircuitError> {
self.einsum_index.equations.insert((self.einsum_index(), equation), input_axes_to_dims.clone());
Ok(())
}

@@ -656,7 +656,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}

/// get used einsum equations
pub fn used_einsum_equations(&self) -> HashSet<String> {
pub fn used_einsum_equations(&self) -> HashMap<(usize, String), HashMap<char, usize>> {
self.einsum_index.equations.clone()
}
@@ -54,7 +54,7 @@ mod matmul {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -175,7 +175,7 @@ mod matmul_col_overflow_double_col {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -300,7 +300,7 @@ mod matmul_col_overflow {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -417,11 +417,12 @@ mod matmul_col_ultra_overflow_double_col {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
config
.configure_einsums(cs, &analysis, NUM_INNER_COLS, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
config
}

@@ -577,7 +578,7 @@ mod matmul_col_ultra_overflow {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -608,7 +609,6 @@ mod matmul_col_ultra_overflow {
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
println!("challenges: {:?}", challenges);

layouter
.assign_region(
@@ -2537,7 +2537,7 @@ mod matmul_relu {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
_marker: PhantomData<F>,
einsum_params: SingleEinsumParams<F>,
}

// A columnar ReLu MLP
@@ -2548,13 +2548,54 @@ mod matmul_relu {

impl Circuit<F> for MyCircuit<F> {
type Config = MyConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type FloorPlanner = V1;
type Params = SingleEinsumParams<F>;

fn without_witnesses(&self) -> Self {
self.clone()
}

fn params(&self) -> Self::Params {
SingleEinsumParams::<F>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}

fn configure_with_params(
cs: &mut ConstraintSystem<F>,
params: Self::Params,
) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);

let mut base_config =
BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);

base_config
.configure_range_check(cs, &a, &b, (-1, 1), K)
.unwrap();

base_config
.configure_range_check(cs, &a, &b, (0, 1023), K)
.unwrap();

let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
base_config
.configure_einsums(cs, &analysis, 1, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 8, false);

MyConfig { base_config }
}

fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
@@ -2585,12 +2626,26 @@ mod matmul_relu {
.base_config
.layout_range_checks(&mut layouter)
.unwrap();
let challenges = config
.base_config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
};
let output = config
.base_config
@@ -2625,9 +2680,11 @@ mod matmul_relu {
let mut b = Tensor::from((0..LEN).map(|_| Value::known(F::from(1))));
b.reshape(&[LEN, 1]).unwrap();

let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &b]).unwrap();

let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
_marker: PhantomData,
einsum_params,
};

let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
@@ -64,6 +64,7 @@ use pyo3::types::PyDictMethods;
use pyo3::IntoPyObject;

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
pub use utilities::*;
pub use vars::*;
@@ -442,7 +443,7 @@ pub struct ShuffleParams {
/// Parameters for einsum operations
pub struct EinsumParams {
/// einsum equations
pub equations: Vec<String>,
pub equations: Vec<(String, HashMap<char, usize>)>,
/// total einsum column size
pub total_einsum_col_size: usize,
}
@@ -3,6 +3,7 @@ use super::extract_const_quantized_values;
use super::node::*;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::einsum::analysis::analyze_einsum_usage;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
@@ -37,7 +38,6 @@ use log::{debug, info, trace};
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
@@ -1050,6 +1050,7 @@ impl Model {

let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_inner_cols = settings.run_args.num_inner_cols;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();

@@ -1098,6 +1099,25 @@ impl Model {
)?;
}

// Configures the circuit to use Freivalds' argument.
// If some models slow down as a result, configure Freivalds' argument conditionally instead.
let used_einsums: HashMap<(usize, String), HashMap<char, usize>> = settings
.einsum_params
.equations
.iter()
.enumerate()
.map(|(idx, (equation, indices_to_dims))| {
((idx, equation.clone()), indices_to_dims.clone())
})
.collect();
let analysis = analyze_einsum_usage(&used_einsums)?;
base_gate.configure_einsums(
meta,
&analysis,
num_inner_cols,
logrows
)?;

Ok(base_gate)
}
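Following up on the comment in the hunk above, a conditional configuration might gate Freivalds' argument on the size of the largest contraction, since the RLC overhead only pays off for sufficiently large matmuls. This is a hypothetical sketch, not part of the diff; the helper name and threshold are made up:

```rust
// Hypothetical predicate (not in the commit): only opt a model's einsums into
// Freivalds' argument when the contraction is large enough to justify the
// extra second-phase columns and RLC gates. The cut-off is an assumption.
fn should_use_freivalds(axes_to_dims: &std::collections::HashMap<char, usize>) -> bool {
    const MIN_CONTRACTION_SIZE: usize = 1 << 10; // made-up threshold
    axes_to_dims.values().product::<usize>() >= MIN_CONTRACTION_SIZE
}
```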
@@ -1150,17 +1170,26 @@ impl Model {

let original_constants = constants.clone();

let challenges = config
.base
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();

let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new_with_constants(
let mut thread_safe_region = RegionCtx::new_with_challenges(
region,
0,
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
original_constants.clone(),
challenges.clone(),
);
thread_safe_region.update_constants(original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);

@@ -1533,7 +1562,11 @@ impl Model {
total_shuffle_col_size: region.shuffle_col_coord(),
},
einsum_params: crate::graph::EinsumParams {
equations: region.used_einsum_equations().into_iter().collect(),
equations: region
.used_einsum_equations()
.into_iter()
.map(|((_, equation), axes_to_dims)| (equation, axes_to_dims))
.collect(),
total_einsum_col_size: region.einsum_col_coord(),
},
total_const_size: region.total_constants(),
@@ -171,7 +171,6 @@ pub mod pfsys;
pub mod srs_sha;
/// An implementation of multi-dimensional tensors.
pub mod tensor;

#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]