Use Freivalds' as default when configuring graph circuit

This commit is contained in:
DoHoonKim
2025-08-28 15:52:21 +09:00
parent fa548efb7f
commit e3355dbf69
13 changed files with 269 additions and 196 deletions

View File

@@ -6,13 +6,14 @@ use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::{create_keys, create_proof_circuit, TranscriptType};
use ezkl::pfsys::srs::gen_srs;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::plonk::create_proof;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer};
use halo2_proofs::{
arithmetic::Field,
@@ -23,6 +24,7 @@ use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::collections::HashMap;
use std::marker::PhantomData;
@@ -48,13 +50,14 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 2;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
}
config
@@ -72,21 +75,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();
let default_params = Self::Params::default();
let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
}
config
unimplemented!("call configure_with_params instead")
}
fn synthesize(
@@ -133,11 +122,11 @@ fn runmatmul(c: &mut Criterion) {
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
group.sampling_mode(criterion::SamplingMode::Flat);
group.sample_size(10);
let len = 512;
let len = 128;
unsafe {
LEN = len;
}
for k in 19..20 {
for k in 15..16 {
let params = unsafe {
K = k;
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
@@ -170,24 +159,36 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
b.iter(|| {
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
create_proof::<KZGCommitmentScheme<_>, ProverSHPLONK<_>, _, _, _, _>(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit<Fr>,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
&[circuit.clone()],
&[&[]],
OsRng,
&mut transcript,
)
.expect("proof generation should not fail");
transcript.finalize();
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}
criterion_group!(benches, runmatmul);
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runmatmul
}
criterion_main!(benches);

View File

@@ -78,7 +78,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -99,20 +99,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
.unwrap()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();
let default_params = Self::Params::default();
let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(

View File

@@ -85,7 +85,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -106,20 +106,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
.unwrap()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();
let default_params = Self::Params::default();
let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(

View File

@@ -85,12 +85,13 @@ impl Circuit<Fr> for MyCircuit<Fr> {
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 2;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
config
}
@@ -106,20 +107,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
.unwrap()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let mut config = Self::Config::default();
let default_params = Self::Params::default();
let mut equations = HashMap::new();
equations.insert(default_params.equation, default_params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(

View File

@@ -47,7 +47,7 @@ pub struct SingleEquationAnalysis {
///
pub fn analyze_einsum_usage(
equations: &HashMap<String, HashMap<char, usize>>,
equations: &HashMap<(usize, String), HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
let mut max_num_inputs = 0;
let mut max_input_size = 0;
@@ -56,7 +56,7 @@ pub fn analyze_einsum_usage(
let mut longest_challenge_vector = 0;
let mut reduction_length = 0;
for (equation, input_axes_to_dim) in equations.iter() {
for ((_, equation), input_axes_to_dim) in equations.iter() {
let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
max_input_size = max_input_size.max(analysis.max_input_size);
longest_challenge_vector = longest_challenge_vector.max(analysis.longest_challenge_vector);
@@ -126,7 +126,7 @@ pub fn analyze_single_equation(
.iter()
.map(|c| input_axes_to_dim.get(&c).unwrap());
let output_size = output_dims.clone().product();
let longest_challenge_vector = *output_dims.clone().max().unwrap();
let longest_challenge_vector = *output_dims.clone().max().unwrap_or(&0);
let output_reduction_length = {
let mut output_dims = output_dims.rev().cloned().collect_vec();

View File

@@ -1,4 +1,3 @@
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use log::{error, trace};
@@ -41,11 +40,12 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.flush_einsum()?;
let vars = config.get_vartensors(phases.as_slice().into());
let input_vars = config.get_input_vars(phases.as_slice().into());
let output_var = config.get_output_var(phases.as_slice().into());
let inputs = [lhs, rhs]
.iter()
.zip(vars)
.zip(input_vars)
.map(|(val, var)| {
let res = region.assign_einsum(var, val)?;
Ok(res.get_inner()?)
@@ -66,13 +66,13 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
})?;
let assigned_len = op_result.len();
let mut output = region.assign_einsum(&config.output, &op_result.into())?;
let mut output = region.assign_einsum(output_var, &op_result.into())?;
// Enable the selectors
if !region.is_dummy() {
(0..assigned_len)
.map(|i| {
let (x, y, z) = config.output.cartesian_coord(region.einsum_col_coord() + i);
let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
let op_info = BaseOpInfo {
op_kind: op.clone(),
input_phases: phases.as_slice().into(),
@@ -107,15 +107,12 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.flush_einsum()?;
let mut input = values[0].clone();
let block_width = config.output.num_inner_cols();
let block_width = config.block_width();
let assigned_len: usize;
let input = {
// FIXME : should pad with constant zero but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
let var = config.get_vartensors([phase].as_slice().into())[0];
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region
.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
@@ -125,8 +122,9 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
let accumulated_sum = accumulated::sum(&input, block_width)?;
let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
output_var,
&accumulated_sum.into(),
check_mode,
)?;
@@ -134,9 +132,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// enable the selectors
if !region.is_dummy() {
for i in 0..output_assigned_len {
let (x, _, z) = config
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
continue;
@@ -176,15 +172,12 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
let block_width = config.output.num_inner_cols();
let block_width = config.block_width();
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
// FIXME : should pad with constant one but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ONE)))?;
let var = config.get_vartensors([phase].as_slice().into())[0];
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region
.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
@@ -194,8 +187,9 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
let accumulated_prod = accumulated::prod(&input, block_width)?;
let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
output_var,
&accumulated_prod.into(),
check_mode,
)?;
@@ -204,9 +198,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) = config
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
@@ -259,17 +251,14 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
} else {
[values[1].clone(), values[0].clone()]
};
let vars = config.get_vartensors(phases.as_slice().into());
let vars = config.get_input_vars(phases.as_slice().into());
let mut inputs = vec![];
let block_width = config.output.num_inner_cols();
let block_width = config.block_width();
let mut assigned_len = 0;
for (val, var) in values.iter_mut().zip(vars) {
// FIXME : should pad with constant zero but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
val.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
assigned_len = len;
@@ -281,8 +270,9 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
// time this step
let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
let output_var = config.get_output_var(phases.as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
output_var,
&accumulated_dot.into(),
check_mode,
).expect("failed to assign einsum with duplication constrained");
@@ -291,9 +281,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) = config
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// hop over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());

View File

@@ -26,7 +26,7 @@ mod reduction_planner;
#[derive(Clone, Debug, Default)]
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
/// custom gate to constrain tensor contractions
custom_gate: ContractionConfig<F>,
contraction_gate: ContractionConfig<F>,
/// custom gate to constrain random linear combinations used by Freivalds' argument
rlc_gates: Vec<RLCConfig<F>>,
}
@@ -35,17 +35,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
///
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let dummy_custom_gate = ContractionConfig {
let dummy_contraction_gate = ContractionConfig {
inputs: [
[dummy_var.clone(), dummy_var.clone()],
[dummy_var.clone(), dummy_var.clone()],
],
output: dummy_var.clone(),
outputs: [dummy_var.clone(), dummy_var.clone()],
selectors: BTreeMap::default(),
_marker: PhantomData,
};
Self {
custom_gate: dummy_custom_gate,
contraction_gate: dummy_contraction_gate,
rlc_gates: vec![],
}
}
@@ -72,21 +72,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
];
let output = VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity);
let custom_gate = ContractionConfig::new(
let outputs = [
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity)
];
let contraction_gate = ContractionConfig::new(
meta,
&[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
&output,
&[&outputs[0], &outputs[1]],
);
let mut rlc_gates = vec![];
for _ in 0..analysis.max_num_output_axes {
let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &output);
let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]);
rlc_gates.push(rlc_gate);
}
Self {
custom_gate,
contraction_gate,
rlc_gates,
}
}
@@ -100,7 +103,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
equation: &str,
check_mode: &CheckMode,
) -> Result<(), CircuitError> {
region.set_num_einsum_inner_cols(self.custom_gate.output.num_inner_cols());
region.set_num_einsum_inner_cols(self.contraction_gate.block_width());
let (input_exprs, _) = equation.split_once("->").unwrap();
let input_exprs = input_exprs.split(",").collect_vec();
@@ -109,22 +112,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
let mut output_tensor = output_tensor.clone();
// Remove trivial axes from tensors
input_tensors
.iter_mut()
.map(|tensor| tensor.remove_trivial_axes())
.collect::<Result<Vec<_>, TensorError>>()?;
output_tensor.remove_trivial_axes()?;
let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
input_exprs
.iter()
.zip(input_tensors.iter())
.for_each(|(indices, tensor)| {
let tensor_dim = tensor.dims();
indices
.chars()
.zip(tensor_dim.iter())
.zip(tensor.dims())
.for_each(|(index, dim)| {
if let std::collections::hash_map::Entry::Vacant(e) =
input_axes_to_dim.entry(index)
@@ -134,6 +129,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
});
});
// Remove trivial axes from tensors
input_tensors
.iter_mut()
.map(|tensor| tensor.remove_trivial_axes())
.collect::<Result<Vec<_>, TensorError>>()?;
output_tensor.remove_trivial_axes()?;
let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?;
let equation = equation_analysis.equation;
@@ -234,7 +236,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
Some(axis) => {
let dot_product_len = input_axes_to_dim[axis];
assign_input_contraction(
&self.custom_gate,
&self.contraction_gate,
region,
flattened_input_tensors,
dot_product_len,
@@ -245,7 +247,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
}
None => {
let mut result = assign_pairwise_mult(
&self.custom_gate,
&self.contraction_gate,
region,
flattened_input_tensors,
input_phases,
@@ -264,7 +266,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
.map(|t| t.get_inner_tensor().unwrap().get_scalar())
.collect_vec()
.into();
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1, check_mode)?;
let squashed_input = prod(&self.contraction_gate, region, &[&scalars], 1, check_mode)?;
region.constrain_equal(&squashed_input, &squashed_output)
}
@@ -404,20 +406,20 @@ struct BaseOpInfo {
pub input_phases: InputPhases,
}
/// `ContractionConfig` is the custom gate used for einsum contraction operations
/// `ContractionConfig` is the custom gate to constrain tensor contractions
#[derive(Clone, Debug, Default)]
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
// [[phase 0, phase 0], [phase 1, phase 1]]
// [[first phase, first phase], [second phase, second phase]]
inputs: [[VarTensor; 2]; 2],
// phase 1
output: VarTensor,
// [first phase, second phase]
outputs: [VarTensor; 2],
// (BaseOpInfo, block index, inner column index) -> selector
selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
fn get_vartensors(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
match input_phases {
InputPhases::FirstPhase => vec![&self.inputs[0][0]],
InputPhases::SecondPhase => vec![&self.inputs[1][0]],
@@ -427,19 +429,35 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
}
}
fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor {
match input_phases {
InputPhases::FirstPhase => &self.outputs[0],
InputPhases::SecondPhase => &self.outputs[1],
InputPhases::BothFirstPhase => &self.outputs[0],
InputPhases::Mixed => &self.outputs[1],
InputPhases::BothSecondPhase => &self.outputs[1],
}
}
fn block_width(&self) -> usize {
self.outputs[0].num_inner_cols()
}
fn new(
meta: &mut ConstraintSystem<F>,
inputs: &[&[&VarTensor; 2]; 2],
output: &VarTensor,
outputs: &[&VarTensor; 2],
) -> Self {
let mut selectors = BTreeMap::new();
let num_blocks = outputs[0].num_blocks();
let block_width = outputs[0].num_inner_cols();
for input_phases in [
InputPhases::BothFirstPhase,
InputPhases::Mixed,
InputPhases::BothSecondPhase,
] {
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
for i in 0..num_blocks {
for j in 0..block_width {
selectors.insert(
(
BaseOpInfo {
@@ -452,7 +470,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
meta.selector(),
);
}
for i in 0..output.num_blocks() {
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
@@ -483,7 +501,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
InputPhases::FirstPhase,
InputPhases::SecondPhase,
] {
for i in 0..output.num_blocks() {
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
@@ -538,6 +556,13 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
};
let output = match base_op.input_phases {
InputPhases::FirstPhase => outputs[0],
InputPhases::SecondPhase => outputs[1],
InputPhases::BothFirstPhase => outputs[0],
InputPhases::Mixed => outputs[1],
InputPhases::BothSecondPhase => outputs[1],
};
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
match base_op.op_kind {
BaseOp::Mult => {
@@ -549,7 +574,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("einsum op config: input query failed")[0]
.expect("contraction config: input query failed")[0]
.clone()
}
// Get output expressions for each input channel
@@ -557,7 +582,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("einsum op config: output query failed");
.expect("contraction config: output query failed");
let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
@@ -572,7 +597,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_whole_block(meta, *block_idx, 0, 1)
.expect("einsum op config: input query failed")
.expect("contraction config: input query failed")
.into_iter()
.collect()
}
@@ -580,7 +605,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
.expect("einsum op config: output query failed");
.expect("contraction config: output query failed");
let res = base_op.op_kind.accum_f(
expected_output[0].clone(),
@@ -613,7 +638,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
Self {
inputs: [first_phase_inputs, second_phase_inputs],
output: output.clone(),
outputs: [outputs[0].clone(), outputs[1].clone()],
selectors,
_marker: PhantomData,
}
@@ -624,7 +649,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
pub challenge: Challenge,
/// [phase 0, phase 1]
/// [first phase, second phase]
pub inputs: [VarTensor; 2],
pub output: VarTensor,
/// (phase of input, block index) -> (init selector, acc selector)
@@ -731,7 +756,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
.zip(powers_of_challenge.iter().rev())
.map(|(v, c_power)| {
c_power.and_then(|c_power| {
Value::known(c_power * v.get_felt_eval().unwrap())
v.get_felt_eval().and_then(|v| {
Some(Value::known(c_power * v))
}).unwrap_or(Value::unknown())
})
})
.reduce(|acc, v| acc + v)
@@ -743,7 +770,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
let assigned_len = {
let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (_, len) = region
.assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
len

View File

@@ -831,17 +831,9 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
// Track the einsum equation
region.add_used_einsum_equation(equation.to_string())?;
// dispatch to freivalds' argument
if !config.einsums.challenges().is_empty() {
return freivalds(config, region, inputs, equation);
}
let mut equation = equation.split("->");
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let output_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
@@ -864,6 +856,14 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
}
}
// Track the einsum equation
region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?;
// dispatch to freivalds' argument
if !config.einsums.challenges().is_empty() {
return freivalds(config, region, inputs, equation);
}
// maps unrepresented indices in the output to a trivial 1
for c in output_eq.chars() {
indices_to_size.entry(c).or_insert(1);
@@ -1043,6 +1043,8 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let output: ValTensor<F> = output.into();
region.increment_einsum_index(1);
Ok(output)
}
@@ -1068,10 +1070,10 @@ pub fn freivalds<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&config.check_mode,
)?;
region.increment_einsum_index(1);
let output: ValTensor<F> = output_tensor.into();
region.increment_einsum_index(1);
Ok(output)
}

View File

@@ -90,7 +90,8 @@ impl ShuffleIndex {
pub struct EinsumIndex {
index: usize,
col_coord: usize,
equations: HashSet<String>,
// (einsum index, einsum equation) -> (input axes to dimensions map)
equations: HashMap<(usize, String), HashMap<char, usize>>,
num_inner_cols: usize,
}
@@ -100,7 +101,7 @@ impl EinsumIndex {
EinsumIndex {
index,
col_coord,
equations: HashSet::new(),
equations: HashMap::new(),
num_inner_cols,
}
}
@@ -115,11 +116,6 @@ impl EinsumIndex {
self.col_coord
}
/// Get the equations
pub fn equations(&self) -> &HashSet<String> {
&self.equations
}
/// update with another einsum index
pub fn update(&mut self, other: &EinsumIndex) {
self.index += other.index;
@@ -605,8 +601,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// add used einsum equation
pub fn add_used_einsum_equation(&mut self, equation: String) -> Result<(), CircuitError> {
self.einsum_index.equations.insert(equation);
pub fn add_used_einsum_equation(
&mut self,
equation: String,
input_axes_to_dims: &HashMap<char, usize>,
) -> Result<(), CircuitError> {
self.einsum_index.equations.insert((self.einsum_index(), equation), input_axes_to_dims.clone());
Ok(())
}
@@ -656,7 +656,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// get used einsum equations
pub fn used_einsum_equations(&self) -> HashSet<String> {
pub fn used_einsum_equations(&self) -> HashMap<(usize, String), HashMap<char, usize>> {
self.einsum_index.equations.clone()
}

View File

@@ -54,7 +54,7 @@ mod matmul {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -175,7 +175,7 @@ mod matmul_col_overflow_double_col {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -300,7 +300,7 @@ mod matmul_col_overflow {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -417,11 +417,12 @@ mod matmul_col_ultra_overflow_double_col {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
config
.configure_einsums(cs, &analysis, NUM_INNER_COLS, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
config
}
@@ -577,7 +578,7 @@ mod matmul_col_ultra_overflow {
) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert(params.equation, params.input_axes_to_dims);
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
@@ -608,7 +609,6 @@ mod matmul_col_ultra_overflow {
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
println!("challenges: {:?}", challenges);
layouter
.assign_region(
@@ -2537,7 +2537,7 @@ mod matmul_relu {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
_marker: PhantomData<F>,
einsum_params: SingleEinsumParams<F>,
}
// A columnar ReLu MLP
@@ -2548,13 +2548,54 @@ mod matmul_relu {
impl Circuit<F> for MyCircuit<F> {
type Config = MyConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type FloorPlanner = V1;
type Params = SingleEinsumParams<F>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn params(&self) -> Self::Params {
SingleEinsumParams::<F>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure_with_params(
cs: &mut ConstraintSystem<F>,
params: Self::Params,
) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut base_config =
BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
base_config
.configure_range_check(cs, &a, &b, (-1, 1), K)
.unwrap();
base_config
.configure_range_check(cs, &a, &b, (0, 1023), K)
.unwrap();
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
base_config
.configure_einsums(cs, &analysis, 1, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 8, false);
MyConfig { base_config }
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
@@ -2585,12 +2626,26 @@ mod matmul_relu {
.base_config
.layout_range_checks(&mut layouter)
.unwrap();
let challenges = config
.base_config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
};
let output = config
.base_config
@@ -2625,9 +2680,11 @@ mod matmul_relu {
let mut b = Tensor::from((0..LEN).map(|_| Value::known(F::from(1))));
b.reshape(&[LEN, 1]).unwrap();
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
_marker: PhantomData,
einsum_params,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();

View File

@@ -64,6 +64,7 @@ use pyo3::types::PyDictMethods;
use pyo3::IntoPyObject;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
pub use utilities::*;
pub use vars::*;
@@ -442,7 +443,7 @@ pub struct ShuffleParams {
/// Parameters for einsum operations
pub struct EinsumParams {
/// einsum equations
pub equations: Vec<String>,
pub equations: Vec<(String, HashMap<char, usize>)>,
/// total einsum column size
pub total_einsum_col_size: usize,
}

View File

@@ -3,6 +3,7 @@ use super::extract_const_quantized_values;
use super::node::*;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::einsum::analysis::analyze_einsum_usage;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
@@ -37,7 +38,6 @@ use log::{debug, info, trace};
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
@@ -1050,6 +1050,7 @@ impl Model {
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_inner_cols = settings.run_args.num_inner_cols;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
@@ -1098,6 +1099,25 @@ impl Model {
)?;
}
// Configures the circuit to use Freivalds' argument
// If some models get slow down, conditionally configure to use Freivalds' argument
let used_einsums: HashMap<(usize, String), HashMap<char, usize>> = settings
.einsum_params
.equations
.iter()
.enumerate()
.map(|(idx, (equation, indices_to_dims))| {
((idx, equation.clone()), indices_to_dims.clone())
})
.collect();
let analysis = analyze_einsum_usage(&used_einsums)?;
base_gate.configure_einsums(
meta,
&analysis,
num_inner_cols,
logrows
)?;
Ok(base_gate)
}
@@ -1150,17 +1170,26 @@ impl Model {
let original_constants = constants.clone();
let challenges = config
.base
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new_with_constants(
let mut thread_safe_region = RegionCtx::new_with_challenges(
region,
0,
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
original_constants.clone(),
challenges.clone(),
);
thread_safe_region.update_constants(original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1533,7 +1562,11 @@ impl Model {
total_shuffle_col_size: region.shuffle_col_coord(),
},
einsum_params: crate::graph::EinsumParams {
equations: region.used_einsum_equations().into_iter().collect(),
equations: region
.used_einsum_equations()
.into_iter()
.map(|((_, equation), axes_to_dims)| (equation, axes_to_dims))
.collect(),
total_einsum_col_size: region.einsum_col_coord(),
},
total_const_size: region.total_constants(),

View File

@@ -171,7 +171,6 @@ pub mod pfsys;
pub mod srs_sha;
/// An implementation of multi-dimensional tensors.
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]