Fix circuit tests; add a temporary Cargo patch for halo2 to work around cross-phase assignment issues

This commit is contained in:
DoHoonKim
2025-08-25 20:14:52 +09:00
parent 71e86ade32
commit 67b97f9ab8
10 changed files with 802 additions and 188 deletions

2
Cargo.lock generated
View File

@@ -2431,7 +2431,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
dependencies = [
"bincode",
"blake2b_simd",

View File

@@ -300,7 +300,6 @@ halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -3,6 +3,7 @@ use criterion::{
Throughput,
};
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
@@ -31,54 +32,13 @@ static mut K: usize = 15;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
einsum_params: SingleEinsumParams<F>,
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
type Params = SingleEinsumParams<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
@@ -101,8 +61,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
SingleEinsumParams::<Fr>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
@@ -157,7 +117,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
equation: self.einsum_params.equation.clone(),
}),
)
.unwrap();
@@ -189,11 +149,11 @@ fn runmatmul(c: &mut Criterion) {
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();
let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
einsum_params,
};
group.throughput(Throughput::Elements(len as u64));

View File

@@ -91,14 +91,18 @@ pub fn analyze_single_equation(
.map(|input| {
input
.chars()
.filter(|char| input_axes_to_dim.get(char).is_some())
.filter(|char| {
input_axes_to_dim.get(char).is_some() && *input_axes_to_dim.get(char).unwrap() > 1
})
.collect()
})
.collect();
let output = output_str
.chars()
.filter(|c| input_axes_to_dim.get(c).is_some())
.filter(|c| {
input_axes_to_dim.get(c).is_some() && *input_axes_to_dim.get(c).unwrap() > 1
})
.collect();
[inputs.join(","), output].join("->")

View File

@@ -0,0 +1,54 @@
use std::{collections::HashMap, marker::PhantomData};
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use crate::{
circuit::CircuitError,
tensor::{Tensor, TensorError, TensorType},
};
/// Circuit parameter for a single einsum equation
///
/// Holds the equation string together with the size bound to every axis
/// label that appears on the input side of the equation.
#[derive(Clone, Debug, Default)]
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
/// The einsum equation (e.g. `"ij,jk->ik"`), stored verbatim as given to the constructor
pub equation: String,
/// Map from input axes to dimensions
pub input_axes_to_dims: HashMap<char, usize>,
// Zero-sized marker so the struct can carry the field type `F` without storing a value of it
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
    /// Builds the parameters for `equation` from the shapes of `inputs`.
    ///
    /// Only the input side of the equation (the part before `->`) is inspected:
    /// it is split on `,` into one axis-label string per operand, and the `j`-th
    /// label of operand `i` is bound to `inputs[i].dims()[j]`.
    ///
    /// # Errors
    ///
    /// * [`CircuitError::InvalidEinsum`] if an operand contains non-ASCII axis labels.
    /// * `TensorError::DimMismatch("einsum")` (converted into [`CircuitError`]) if
    ///   - the number of operands in the equation differs from `inputs.len()`,
    ///   - an operand names more axes than its tensor has dimensions, or
    ///   - the same axis label is bound to two different sizes.
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let inputs_eq = equation
            .split("->")
            .next()
            .ok_or(CircuitError::InvalidEinsum)?
            .split(',')
            .collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (input, axes) in inputs.iter().zip(inputs_eq.iter()) {
            // Axis labels are single ASCII characters; anything else is rejected
            // (previously surfaced indirectly through a byte-length/char-count mismatch).
            if !axes.is_ascii() {
                return Err(CircuitError::InvalidEinsum);
            }

            let dims = input.dims();
            for (j, c) in axes.chars().enumerate() {
                // Checked lookup: an equation naming more axes than the tensor has
                // dimensions is reported as an error instead of panicking on indexing.
                let dim = *dims
                    .get(j)
                    .ok_or_else(|| TensorError::DimMismatch("einsum".to_string()))?;

                match input_axes_to_dims.entry(c) {
                    std::collections::hash_map::Entry::Vacant(slot) => {
                        slot.insert(dim);
                    }
                    std::collections::hash_map::Entry::Occupied(slot) => {
                        // A label seen before must keep the same size everywhere
                        if *slot.get() != dim {
                            return Err(TensorError::DimMismatch("einsum".to_string()).into());
                        }
                    }
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

View File

@@ -3,7 +3,7 @@ use halo2curves::ff::PrimeField;
use log::{error, trace};
use crate::{
circuit::{base::BaseOp, region::RegionCtx, CircuitError},
circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
@@ -11,11 +11,11 @@ use crate::{
},
};
use super::EinsumOpConfig;
use super::ContractionConfig;
/// Pairwise (elementwise) op layout
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 2],
op: BaseOp,
@@ -26,7 +26,6 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
} else {
(values[1].clone(), values[0].clone())
};
let min_phase = std::cmp::min(phases[0], phases[1]);
let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
@@ -42,9 +41,11 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.flush_einsum()?;
let vars = config.get_vartensors(phases.as_slice().into());
let inputs = [lhs, rhs]
.iter()
.zip(config.inputs.iter().skip(min_phase * 2))
.zip(vars)
.map(|(val, var)| {
let res = region.assign_einsum(var, val)?;
Ok(res.get_inner()?)
@@ -71,8 +72,12 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if !region.is_dummy() {
(0..assigned_len)
.map(|i| {
let (x, y, z) = config.inputs[0].cartesian_coord(region.einsum_col_coord() + i);
let selector = config.selectors.get(&(min_phase, op.clone(), x, y));
let (x, y, z) = config.output.cartesian_coord(region.einsum_col_coord() + i);
let op_info = BaseOpInfo {
op_kind: op.clone(),
input_phases: phases.as_slice().into(),
};
let selector = config.selectors.get(&(op_info, x, y));
region.enable(selector, z)?;
@@ -88,10 +93,11 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
}
pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 1],
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() == 1 {
return Ok(values[0].clone());
@@ -109,8 +115,9 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
let var = config.get_vartensors([phase].as_slice().into())[0];
let (res, len) = region
.assign_einsum_with_duplication_unconstrained(&config.inputs[phase * 2], &input)?;
.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
res.get_inner()?
};
@@ -121,7 +128,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
&accumulated_sum.into(),
&crate::circuit::CheckMode::UNSAFE,
check_mode,
)?;
// enable the selectors
@@ -135,9 +142,17 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
continue;
}
let selector = if i == 0 {
config.selectors.get(&(phase, BaseOp::SumInit, x, 0))
let op_info = BaseOpInfo {
op_kind: BaseOp::SumInit,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
config.selectors.get(&(phase, BaseOp::Sum, x, 0))
let op_info = BaseOpInfo {
op_kind: BaseOp::Sum,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
@@ -153,10 +168,11 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
}
pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 1],
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
@@ -168,8 +184,9 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ONE)))?;
let var = config.get_vartensors([phase].as_slice().into())[0];
let (res, len) = region
.assign_einsum_with_duplication_unconstrained(&config.inputs[phase * 2], &input)?;
.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
res.get_inner()?
};
@@ -180,7 +197,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
&accumulated_prod.into(),
&crate::circuit::CheckMode::UNSAFE,
check_mode,
)?;
// enable the selectors
@@ -195,9 +212,17 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
return Ok(());
}
let selector = if i == 0 {
config.selectors.get(&(phase, BaseOp::CumProdInit, x, 0))
let op_info = BaseOpInfo {
op_kind: BaseOp::CumProdInit,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
config.selectors.get(&(phase, BaseOp::CumProd, x, 0))
let op_info = BaseOpInfo {
op_kind: BaseOp::CumProd,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
@@ -215,10 +240,11 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
}
pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 2],
phases: &[usize; 2],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
@@ -233,16 +259,13 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
} else {
[values[1].clone(), values[0].clone()]
};
let min_phase = std::cmp::min(phases[0], phases[1]);
let vars = config.get_vartensors(phases.as_slice().into());
let mut inputs = vec![];
let block_width = config.output.num_inner_cols();
let mut assigned_len = 0;
for (val, var) in values
.iter_mut()
.zip(config.inputs.iter().skip(min_phase * 2))
{
for (val, var) in values.iter_mut().zip(vars) {
// FIXME : should pad with constant zero but currently this incurs an error
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
// value to advice column, how to workaround this issue?
@@ -261,8 +284,8 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
&config.output,
&accumulated_dot.into(),
&crate::circuit::CheckMode::UNSAFE,
)?;
check_mode,
).expect("failed to assign einsum with duplication constrained");
// enable the selectors
if !region.is_dummy() {
@@ -276,9 +299,17 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
return Ok(());
}
let selector = if i == 0 {
config.selectors.get(&(min_phase, BaseOp::DotInit, x, 0))
let op_info = BaseOpInfo {
op_kind: BaseOp::DotInit,
input_phases: phases.as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
config.selectors.get(&(min_phase, BaseOp::Dot, x, 0))
let op_info = BaseOpInfo {
op_kind: BaseOp::Dot,
input_phases: phases.as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
@@ -299,10 +330,11 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Dot product of more than two tensors
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
phases: &[usize],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
if !values.iter().all(|value| value.len() == values[0].len()) {
@@ -313,7 +345,7 @@ pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
// do pairwise dot product between intermediate tensor and the next tensor
let (intermediate, _) = values
let (intermediate, output_phase) = values
.into_iter()
.zip(phases.iter().cloned())
.reduce(|(intermediate, intermediate_phase), (input, phase)| {
@@ -331,10 +363,7 @@ pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
})
.unwrap();
// Sum the final tensor
// In current freivalds' algorithm, we ensure that there is no tensor contraction between phase 0 tensors,
// so the phase of the resulting tensor is set to 1
let accumulated_dot = sum(config, region, &[&intermediate], 1)?;
let accumulated_dot = sum(config, region, &[&intermediate], output_phase, check_mode)?;
let last_elem = accumulated_dot.last()?;
let elapsed = global_start.elapsed();

View File

@@ -3,7 +3,7 @@ use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnal
use crate::circuit::einsum::layouts::{pairwise, sum};
use crate::circuit::einsum::reduction_planner::Reduction;
use crate::circuit::region::RegionCtx;
use crate::circuit::CircuitError;
use crate::circuit::{CheckMode, CircuitError};
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
@@ -17,6 +17,8 @@ use std::marker::PhantomData;
///
pub mod analysis;
///
pub mod circuit_params;
mod layouts;
mod reduction_planner;
@@ -24,7 +26,7 @@ mod reduction_planner;
#[derive(Clone, Debug, Default)]
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
/// custom gate to constrain tensor contractions
custom_gate: EinsumOpConfig<F>,
custom_gate: ContractionConfig<F>,
/// custom gate to constrain random linear combinations used by Freivalds' argument
rlc_gates: Vec<RLCConfig<F>>,
}
@@ -33,12 +35,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
///
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let dummy_custom_gate = EinsumOpConfig {
let dummy_custom_gate = ContractionConfig {
inputs: [
dummy_var.clone(),
dummy_var.clone(),
dummy_var.clone(),
dummy_var.clone(),
[dummy_var.clone(), dummy_var.clone()],
[dummy_var.clone(), dummy_var.clone()],
],
output: dummy_var.clone(),
selectors: BTreeMap::default(),
@@ -73,7 +73,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
];
let output = VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity);
let custom_gate = EinsumOpConfig::new(meta, &inputs, &output);
let custom_gate = ContractionConfig::new(
meta,
&[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
&output,
);
let mut rlc_gates = vec![];
for _ in 0..analysis.max_num_output_axes {
@@ -94,6 +98,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
input_tensors: &[&ValTensor<F>],
output_tensor: &ValTensor<F>,
equation: &str,
check_mode: &CheckMode,
) -> Result<(), CircuitError> {
region.set_num_einsum_inner_cols(self.custom_gate.output.num_inner_cols());
@@ -235,6 +240,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
dot_product_len,
&output_dims,
input_phases,
check_mode,
)?
}
None => {
@@ -258,7 +264,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
.map(|t| t.get_inner_tensor().unwrap().get_scalar())
.collect_vec()
.into();
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1)?;
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1, check_mode)?;
region.constrain_equal(&squashed_input, &squashed_output)
}
@@ -300,7 +306,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
}
fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
flattened_tensors: Vec<ValTensor<F>>,
input_phases: &[usize],
@@ -327,12 +333,13 @@ fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Has
}
fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &EinsumOpConfig<F>,
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
flattened_tensors: Vec<ValTensor<F>>,
dot_product_len: usize,
output_shape: &[usize],
input_phases: &[usize],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert_eq!(flattened_tensors.len(), input_phases.len());
let num_dot_products = output_shape.iter().product();
@@ -344,13 +351,14 @@ fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash:
.map(|tensor| tensor.get_slice(&[start..(start + dot_product_len)]))
.collect::<Result<Vec<_>, TensorError>>()?;
let result = if tensors.len() == 1 {
sum(config, region, &[&tensors[0]], input_phases[0])?
sum(config, region, &[&tensors[0]], input_phases[0], check_mode)?
} else if tensors.len() == 2 {
dot(
config,
region,
&[&tensors[0], &tensors[1]],
&[input_phases[0], input_phases[1]],
check_mode,
)?
} else {
multi_dot(
@@ -358,6 +366,7 @@ fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash:
region,
tensors.iter().collect_vec().as_slice(),
input_phases,
check_mode,
)?
};
dot_product_results.push(result.get_inner_tensor()?.get_scalar());
@@ -367,107 +376,219 @@ fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash:
Ok(tensor)
}
/// `EinsumOpConfig` is the custom gate used for einsum contraction operations
/// Phase combination of the input(s) fed to a contraction gate.
///
/// Single-input ops use `FirstPhase` / `SecondPhase`; two-input ops use one of
/// the `Both*` / `Mixed` variants.
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)]
enum InputPhases {
    /// One input, phase `[0]`
    FirstPhase,
    /// One input, phase `[1]`
    SecondPhase,
    /// Two inputs, phases `[0, 0]`
    BothFirstPhase,
    /// Two inputs, phases `[0, 1]` or `[1, 0]`
    Mixed,
    /// Two inputs, phases `[1, 1]`
    BothSecondPhase,
}

impl From<&[usize]> for InputPhases {
    /// Converts a slice of one or two phase indices (each `0` or `1`) into the
    /// matching variant.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid phase combination"` for any other slice length or
    /// any phase index other than `0` or `1`.
    fn from(phases: &[usize]) -> Self {
        // Decode a single phase index; anything but 0 or 1 is invalid.
        let is_second = |phase: usize| match phase {
            0 => false,
            1 => true,
            _ => panic!("Invalid phase combination"),
        };

        match phases.len() {
            1 => {
                if is_second(phases[0]) {
                    Self::SecondPhase
                } else {
                    Self::FirstPhase
                }
            }
            2 => match (is_second(phases[0]), is_second(phases[1])) {
                (false, false) => Self::BothFirstPhase,
                (true, true) => Self::BothSecondPhase,
                _ => Self::Mixed,
            },
            _ => panic!("Invalid phase combination"),
        }
    }
}
/// Describes one flavour of base op gate: which op it is and which phase
/// combination its inputs come from. Used as part of the selector map key.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct BaseOpInfo {
/// The base operation the gate constrains
pub op_kind: BaseOp,
/// Phase combination of the gate's input column(s)
pub input_phases: InputPhases,
}
/// `ContractionConfig` is the custom gate used for einsum contraction operations
#[derive(Clone, Debug, Default)]
struct EinsumOpConfig<F: PrimeField + TensorType + PartialOrd> {
// [phase 0, phase 0, phase 1, phase 1]
inputs: [VarTensor; 4],
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
// [[phase 0, phase 0], [phase 1, phase 1]]
inputs: [[VarTensor; 2]; 2],
// phase 1
output: VarTensor,
// (phase, BaseOp, block index, inner column index) -> selector
selectors: BTreeMap<(usize, BaseOp, usize, usize), Selector>,
// (BaseOpInfo, block index, inner column index) -> selector
selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> EinsumOpConfig<F> {
fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 4], output: &VarTensor) -> Self {
impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
/// Returns the input column tensors to assign into for the given phase
/// combination: one tensor for the single-input variants, two for the rest.
/// `Mixed` yields the first tensor of each phase, first-phase tensor first.
fn get_vartensors(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
    // `inputs` is laid out as [[phase 0, phase 0], [phase 1, phase 1]]
    let [first_phase, second_phase] = &self.inputs;
    match input_phases {
        InputPhases::FirstPhase => vec![&first_phase[0]],
        InputPhases::SecondPhase => vec![&second_phase[0]],
        InputPhases::BothFirstPhase => vec![&first_phase[0], &first_phase[1]],
        InputPhases::Mixed => vec![&first_phase[0], &second_phase[0]],
        InputPhases::BothSecondPhase => vec![&second_phase[0], &second_phase[1]],
    }
}
fn new(
meta: &mut ConstraintSystem<F>,
inputs: &[&[&VarTensor; 2]; 2],
output: &VarTensor,
) -> Self {
let mut selectors = BTreeMap::new();
for phase in [0, 1] {
for input_phases in [
InputPhases::BothFirstPhase,
InputPhases::Mixed,
InputPhases::BothSecondPhase,
] {
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
selectors.insert((phase, BaseOp::Mult, i, j), meta.selector());
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Mult,
input_phases,
},
i,
j,
),
meta.selector(),
);
}
for i in 0..output.num_blocks() {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::DotInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Dot,
input_phases,
},
i,
0,
),
meta.selector(),
);
}
}
}
for phase in [0, 1] {
for input_phases in [
InputPhases::FirstPhase,
InputPhases::SecondPhase,
] {
for i in 0..output.num_blocks() {
selectors.insert((phase, BaseOp::DotInit, i, 0), meta.selector());
selectors.insert((phase, BaseOp::Dot, i, 0), meta.selector());
selectors.insert((phase, BaseOp::SumInit, i, 0), meta.selector());
selectors.insert((phase, BaseOp::Sum, i, 0), meta.selector());
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::SumInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Sum,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::CumProdInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::CumProd,
input_phases,
},
i,
0,
),
meta.selector(),
);
}
}
selectors.insert(
(1, BaseOp::CumProdInit, output.num_blocks() - 1, 0),
meta.selector(),
);
selectors.insert(
(1, BaseOp::CumProd, output.num_blocks() - 1, 0),
meta.selector(),
);
for ((phase, base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
match base_op {
for ((base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
let inputs = match base_op.input_phases {
InputPhases::FirstPhase => vec![inputs[0][0]],
InputPhases::SecondPhase => vec![inputs[1][0]],
InputPhases::BothFirstPhase => vec![inputs[0][0], inputs[0][1]],
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
};
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
match base_op.op_kind {
BaseOp::Mult => {
meta.create_gate(base_op.as_str(), |meta| {
meta.create_gate(base_op.op_kind.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let zero = Expression::<F>::Constant(F::ZERO);
let mut qis = vec![zero; 4];
for (i, q_i) in qis
.iter_mut()
.enumerate()
.skip(*phase * 2)
.take(base_op.num_inputs())
{
*q_i = inputs[i]
let mut qis = vec![zero; 2];
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("einsum op config: input query failed")[0]
.clone()
}
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.query_offset_rng();
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("einsum op config: output query failed");
let res = base_op
.nonaccum_f((qis[2 * *phase].clone(), qis[2 * *phase + 1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
};
Constraints::with_selector(selector, constraints)
});
}
_ => {
meta.create_gate(base_op.as_str(), |meta| {
meta.create_gate(base_op.op_kind.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let mut qis = vec![vec![]; 4];
for (i, q_i) in qis
.iter_mut()
.enumerate()
.skip(*phase * 2)
.take(base_op.num_inputs())
{
*q_i = inputs[i]
let mut qis = vec![vec![]; 2];
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_whole_block(meta, *block_idx, 0, 1)
.expect("einsum op config: input query failed")
.into_iter()
.collect()
}
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.query_offset_rng();
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
.expect("einsum op config: output query failed");
let res = base_op.accum_f(
let res = base_op.op_kind.accum_f(
expected_output[0].clone(),
qis[2 * phase + 1].clone(),
qis[2 * *phase].clone(),
qis[1].clone(),
qis[0].clone(),
);
let constraints =
vec![expected_output[base_op.constraint_idx()].clone() - res];
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res];
Constraints::with_selector(selector, constraints)
});
@@ -475,8 +596,23 @@ impl<F: PrimeField + TensorType + PartialOrd> EinsumOpConfig<F> {
}
}
let first_phase_inputs: [VarTensor; 2] = inputs[0]
.iter()
.copied()
.cloned()
.collect_vec()
.try_into()
.unwrap();
let second_phase_inputs: [VarTensor; 2] = inputs[1]
.iter()
.copied()
.cloned()
.collect_vec()
.try_into()
.unwrap();
Self {
inputs: inputs.clone(),
inputs: [first_phase_inputs, second_phase_inputs],
output: output.clone(),
selectors,
_marker: PhantomData,

View File

@@ -1,4 +1,8 @@
use std::{collections::HashMap, f64::consts::E, ops::Range};
use std::{
collections::{HashMap, HashSet},
f64::consts::E,
ops::Range,
};
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
@@ -18,9 +22,8 @@ use crate::{
tensor::{
create_unit_tensor, get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
DataFormat, KernelFormat, Tensor, TensorError, ValType,
},
tensor::{DataFormat, KernelFormat},
};
use super::*;
@@ -825,26 +828,244 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input_tensors: &[&ValTensor<F>],
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
// Track the einsum equation
region.add_used_einsum_equation(equation.to_string())?;
let inputs = input_tensors
// dispatch to freivalds' argument
if !config.einsums.challenges().is_empty() {
return freivalds(config, region, inputs, equation);
}
let mut equation = equation.split("->");
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut indices_to_size = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) {
e.insert(input.dims()[j]);
} else if indices_to_size[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
// maps unrepresented indices in the output to a trivial 1
for c in output_eq.chars() {
indices_to_size.entry(c).or_insert(1);
}
// Compute the output tensor shape
let mut output_shape: Vec<usize> = output_eq
.chars()
.map(|c| {
indices_to_size
.get(&c)
.ok_or(CircuitError::InvalidEinsum)
.copied()
})
.collect::<Result<Vec<_>, _>>()?;
if output_shape.is_empty() {
output_shape.push(1);
}
// Create a new output tensor with the computed shape
let mut output: Tensor<ValType<F>> = Tensor::new(None, &output_shape)?;
let mut seen = HashSet::new();
let mut common_indices_to_inputs = vec![];
for input in inputs_eq.iter().take(inputs.len()) {
for c in input.chars() {
if !seen.contains(&c) {
seen.insert(c);
} else {
common_indices_to_inputs.push(c);
}
}
}
let non_common_indices = indices_to_size
.keys()
.filter(|&x| !common_indices_to_inputs.contains(x))
.collect::<Vec<_>>();
let non_common_coord_size = non_common_indices
.iter()
.map(|d| {
// If the current index is in the output equation, then the slice should be the current coordinate
if output_eq.contains(**d) {
Ok(1)
// Otherwise, the slice should be the entire dimension of the input tensor
} else {
indices_to_size
.get(d)
.ok_or(CircuitError::InvalidEinsum)
.copied()
}
})
.collect::<Result<Vec<_>, _>>()?
.iter()
.product::<usize>();
let cartesian_coord = output_shape
.iter()
.map(|d| 0..*d)
.multi_cartesian_product()
.collect::<Vec<_>>();
// Get the indices common across input tensors
let mut common_coord = common_indices_to_inputs
.iter()
.map(|d| {
// If the current index is in the output equation, then the slice should be the current coordinate
if output_eq.contains(*d) {
Ok(0..1)
// Otherwise, the slice should be the entire dimension of the input tensor
} else {
Ok(0..*indices_to_size.get(d).ok_or(CircuitError::InvalidEinsum)?)
}
})
.collect::<Result<Vec<Range<_>>, CircuitError>>()?
.into_iter()
.multi_cartesian_product()
.collect::<Vec<_>>();
// If there are no common indices, then we need to add an empty slice to force one iteration of the loop
if common_coord.is_empty() {
common_coord.push(vec![]);
}
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
// Compute the slice of each input tensor given the current coordinate of the output tensor
let inputs = (0..inputs.len())
.map(|idx| {
let mut slice = vec![];
for (i, c) in inputs_eq[idx].chars().enumerate() {
// If the current index is in the output equation, then the slice should be the current coordinate
if let Some(idx) = output_eq.find(c) {
slice.push(coord[idx]..coord[idx] + 1);
// Otherwise, the slice should be the entire dimension of the input tensor
} else {
slice.push(0..inputs[idx].dims()[i]);
}
}
// Get the slice of the input tensor
inputs[idx].get_slice(&slice)
})
.collect::<Result<Vec<_>, _>>()?;
// in this case its just a dot product :)
if non_common_coord_size == 1 && inputs.len() == 2 {
Ok(dot(config, region, &[&inputs[0], &inputs[1]])?.get_inner_tensor()?[0].clone())
} else {
let mut prod_res = None;
// Compute the cartesian product of all common indices
for common_dim in &common_coord {
let inputs = (0..inputs.len())
.map(|idx| {
let mut slice = vec![];
// Iterate over all indices in the input equation
for (i, c) in inputs_eq[idx].chars().enumerate() {
// If the current index is common to multiple inputs, then the slice should be the current coordinate
if let Some(j) = common_indices_to_inputs.iter().position(|&r| r == c) {
slice.push(common_dim[j]..common_dim[j] + 1);
} else {
slice.push(0..inputs[idx].dims()[i]);
}
}
// Get the slice of the input tensor
inputs[idx].get_slice(&slice).map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})
})
.collect::<Result<Vec<_>, _>>()?;
let mut input_pairs = vec![];
for input in &inputs {
input_pairs.push(input.get_inner_tensor()?.iter());
}
let input_pairs = input_pairs
.into_iter()
.multi_cartesian_product()
.collect::<Vec<_>>();
// Compute the product of all input tensors
for pair in input_pairs {
let product_across_pair = prod(config, region, &[&pair.into()])?;
if let Some(product) = prod_res {
prod_res = Some(
pairwise(
config,
region,
&[&product, &product_across_pair],
BaseOp::Add,
)
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?,
);
} else {
prod_res = Some(product_across_pair);
}
}
}
Ok(prod_res
.ok_or(CircuitError::MissingEinsumProduct)?
.get_inner_tensor()?[0]
.clone())
}
};
region.flush()?;
region.apply_in_loop(&mut output, inner_loop_function)?;
let output: ValTensor<F> = output.into();
Ok(output)
}
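/// Lays out an einsum: computes the expected output with `crate::tensor::ops::accumulated::einsum` and assigns it through `config.einsums.assign_einsum`.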
pub fn freivalds<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
let input_values = inputs
.iter()
.map(|t| t.get_inner())
.collect::<Result<Vec<_>, TensorError>>()?;
// Compute expected output using existing einsum logic
// need to add this to ops
let (output_tensor, _) =
crate::tensor::ops::accumulated::einsum(equation, &inputs.iter().collect_vec())?;
crate::tensor::ops::accumulated::einsum(equation, &input_values.iter().collect_vec())?;
config.einsums.assign_einsum(
region,
input_tensors,
inputs,
&output_tensor.clone().into(),
equation,
&config.check_mode,
)?;
region.increment_einsum_index(1);

View File

@@ -1,9 +1,11 @@
use crate::circuit::einsum::analysis::analyze_einsum_usage;
use crate::circuit::einsum::circuit_params::SingleEinsumParams;
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
circuit::{floor_planner::V1, Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
@@ -17,6 +19,7 @@ use itertools::Itertools;
use ops::lookup::LookupOp;
use ops::region::RegionCtx;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
#[derive(Default)]
@@ -24,7 +27,6 @@ struct TestParams;
#[cfg(test)]
mod matmul {
use super::*;
const K: usize = 9;
@@ -33,18 +35,45 @@ mod matmul {
/// Test circuit that lays out a single einsum over two input tensors.
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
/// The two operand tensors handed to the einsum layout in `synthesize`.
inputs: [ValTensor<F>; 2],
/// Einsum parameters; the stored equation is the one laid out in `synthesize`.
einsum_params: SingleEinsumParams<F>,
/// Zero-sized marker for `F`.
_marker: PhantomData<F>,
}
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type Params = SingleEinsumParams<F>;
/// Returns a copy of this circuit; the test circuit is cloned unchanged.
fn without_witnesses(&self) -> Self {
    Clone::clone(self)
}
/// Builds the circuit config for the single einsum described by `params`.
///
/// The equation and its axis sizes are analysed on their own and the einsum
/// columns are configured with one inner column.
///
/// # Panics
/// Panics if the einsum analysis or the einsum configuration fails.
fn configure_with_params(
    cs: &mut ConstraintSystem<F>,
    params: Self::Params,
) -> Self::Config {
    // One-entry table: the equation mapped to its axis sizes.
    let equations = HashMap::from([(params.equation, params.input_axes_to_dims)]);
    let analysis = analyze_einsum_usage(&equations).unwrap();
    let num_einsum_inner_cols = 1;

    let mut config = Self::Config::default();
    config
        .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
        .unwrap();
    config
}
fn params(&self) -> Self::Params {
SingleEinsumParams::<F>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN);
@@ -57,17 +86,31 @@ mod matmul {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let challenges = config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
1,
128,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -89,8 +132,11 @@ mod matmul {
let mut w = Tensor::from((0..LEN + 1).map(|i| Value::known(F::from((i + 1) as u64))));
w.reshape(&[LEN + 1, 1]).unwrap();
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
let circuit = MatmulCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(w)],
einsum_params,
_marker: PhantomData,
};
@@ -110,18 +156,45 @@ mod matmul_col_overflow_double_col {
/// Test circuit that lays out a single einsum over two input tensors.
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
/// The two operand tensors handed to the einsum layout in `synthesize`.
inputs: [ValTensor<F>; 2],
/// Einsum parameters; the stored equation is the one laid out in `synthesize`.
einsum_params: SingleEinsumParams<F>,
/// Zero-sized marker for `F`.
_marker: PhantomData<F>,
}
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type Params = SingleEinsumParams<F>;
/// Returns a copy of this circuit; the test circuit is cloned unchanged.
fn without_witnesses(&self) -> Self {
    Clone::clone(self)
}
/// Builds the circuit config for the single einsum described by `params`.
///
/// # Panics
/// Panics if the einsum analysis or the einsum configuration fails.
fn configure_with_params(
    cs: &mut ConstraintSystem<F>,
    params: Self::Params,
) -> Self::Config {
    // One-entry table: the equation mapped to its axis sizes.
    let equations = HashMap::from([(params.equation, params.input_axes_to_dims)]);
    let analysis = analyze_einsum_usage(&equations).unwrap();
    // NOTE(review): the region in `synthesize` is created with NUM_INNER_COLS,
    // while the einsum columns are configured with a single inner column here.
    // Confirm this difference is intentional.
    let num_einsum_inner_cols = 1;

    let mut config = Self::Config::default();
    config
        .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
        .unwrap();
    config
}
fn params(&self) -> Self::Params {
SingleEinsumParams::<F>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
@@ -134,17 +207,31 @@ mod matmul_col_overflow_double_col {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let challenges = config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
NUM_INNER_COLS,
128,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -164,8 +251,11 @@ mod matmul_col_overflow_double_col {
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
w.reshape(&[LEN, 1]).unwrap();
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
let circuit = MatmulCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(w)],
einsum_params,
_marker: PhantomData,
};
@@ -184,13 +274,14 @@ mod matmul_col_overflow {
/// Test circuit that lays out a single einsum over two input tensors.
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
/// The two operand tensors handed to the einsum layout in `synthesize`.
inputs: [ValTensor<F>; 2],
/// Einsum parameters; returned by `params` and used for the laid-out equation.
einsum_params: SingleEinsumParams<F>,
/// Zero-sized marker for `F`.
_marker: PhantomData<F>,
}
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type FloorPlanner = V1;
type Params = SingleEinsumParams<F>;
fn without_witnesses(&self) -> Self {
self.clone()
@@ -203,22 +294,55 @@ mod matmul_col_overflow {
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
/// Builds the circuit config for the single einsum described by `params`.
///
/// The equation and its axis sizes are analysed on their own and the einsum
/// columns are configured with one inner column.
///
/// # Panics
/// Panics if the einsum analysis or the einsum configuration fails.
fn configure_with_params(
    cs: &mut ConstraintSystem<F>,
    params: Self::Params,
) -> Self::Config {
    // One-entry table: the equation mapped to its axis sizes.
    let equations = HashMap::from([(params.equation, params.input_axes_to_dims)]);
    let analysis = analyze_einsum_usage(&equations).unwrap();
    let num_einsum_inner_cols = 1;

    let mut config = Self::Config::default();
    config
        .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
        .unwrap();
    config
}
/// Returns a copy of the einsum parameters stored on the circuit.
fn params(&self) -> Self::Params {
    Clone::clone(&self.einsum_params)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let challenges = config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
1,
128,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -238,8 +362,11 @@ mod matmul_col_overflow {
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
w.reshape(&[LEN, 1]).unwrap();
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
let circuit = MatmulCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(w)],
einsum_params,
_marker: PhantomData,
};
@@ -271,18 +398,37 @@ mod matmul_col_ultra_overflow_double_col {
/// Test circuit that lays out a single einsum over two input tensors.
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
/// The two operand tensors handed to the einsum layout in `synthesize`.
inputs: [ValTensor<F>; 2],
/// Einsum parameters; returned by `params` and used for the laid-out equation.
einsum_params: SingleEinsumParams<F>,
/// Zero-sized marker for `F`.
_marker: PhantomData<F>,
}
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type Params = SingleEinsumParams<F>;
/// Returns a copy of this circuit; the test circuit is cloned unchanged.
fn without_witnesses(&self) -> Self {
    Clone::clone(self)
}
/// Builds the circuit config for the single einsum described by `params`.
///
/// The einsum columns are configured with `NUM_INNER_COLS` inner columns.
///
/// # Panics
/// Panics if the einsum analysis or the einsum configuration fails.
fn configure_with_params(
    cs: &mut ConstraintSystem<F>,
    params: Self::Params,
) -> Self::Config {
    // One-entry table: the equation mapped to its axis sizes.
    let equations = HashMap::from([(params.equation, params.input_axes_to_dims)]);
    let analysis = analyze_einsum_usage(&equations).unwrap();

    let mut config = Self::Config::default();
    config
        .configure_einsums(cs, &analysis, NUM_INNER_COLS, K)
        .unwrap();
    config
}
/// Returns a copy of the einsum parameters stored on the circuit.
fn params(&self) -> Self::Params {
    Clone::clone(&self.einsum_params)
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
@@ -295,17 +441,31 @@ mod matmul_col_ultra_overflow_double_col {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let challenges = config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
NUM_INNER_COLS,
128,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -328,8 +488,11 @@ mod matmul_col_ultra_overflow_double_col {
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
w.reshape(&[LEN, 1]).unwrap();
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
let circuit = MatmulCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(w)],
einsum_params,
_marker: PhantomData,
};
@@ -376,10 +539,13 @@ mod matmul_col_ultra_overflow_double_col {
))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
use halo2_proofs::{
circuit::floor_planner::V1,
poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
},
};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
@@ -392,18 +558,38 @@ mod matmul_col_ultra_overflow {
/// Test circuit that lays out a single einsum over two input tensors.
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
/// The two operand tensors handed to the einsum layout in `synthesize`.
inputs: [ValTensor<F>; 2],
/// Einsum parameters; returned by `params` and used for the laid-out equation.
einsum_params: SingleEinsumParams<F>,
/// Zero-sized marker for `F`.
_marker: PhantomData<F>,
}
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
type FloorPlanner = V1;
type Params = SingleEinsumParams<F>;
/// Returns a copy of this circuit; the test circuit is cloned unchanged.
fn without_witnesses(&self) -> Self {
    Clone::clone(self)
}
/// Builds the circuit config for the single einsum described by `params`.
///
/// The equation and its axis sizes are analysed on their own and the einsum
/// columns are configured with one inner column.
///
/// # Panics
/// Panics if the einsum analysis or the einsum configuration fails.
fn configure_with_params(
    cs: &mut ConstraintSystem<F>,
    params: Self::Params,
) -> Self::Config {
    // One-entry table: the equation mapped to its axis sizes.
    let equations = HashMap::from([(params.equation, params.input_axes_to_dims)]);
    let analysis = analyze_einsum_usage(&equations).unwrap();
    let num_einsum_inner_cols = 1;

    let mut config = Self::Config::default();
    config
        .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
        .unwrap();
    config
}
/// Returns a copy of the einsum parameters stored on the circuit.
fn params(&self) -> Self::Params {
    Clone::clone(&self.einsum_params)
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
@@ -416,17 +602,32 @@ mod matmul_col_ultra_overflow {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let challenges = config
.einsums
.challenges()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
println!("challenges: {:?}", challenges);
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new_with_challenges(
region,
0,
1,
128,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -449,8 +650,11 @@ mod matmul_col_ultra_overflow {
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
w.reshape(&[LEN, 1]).unwrap();
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
let circuit = MatmulCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(w)],
einsum_params,
_marker: PhantomData,
};

View File

@@ -2729,7 +2729,10 @@ pub mod accumulated {
// Output dims
let out_idx: Vec<char> = output_expr.chars().collect();
let out_dims: Vec<usize> = out_idx.iter().map(|c| *dim_of.get(c).unwrap_or(&1)).collect();
let out_dims: Vec<usize> = out_idx
.iter()
.map(|c| *dim_of.get(c).unwrap_or(&1))
.collect();
// Reduction indices
let all_idx: HashSet<char> = dim_of.keys().copied().collect();
@@ -2738,8 +2741,10 @@ pub mod accumulated {
let red_dims: Vec<usize> = red_idx.iter().map(|c| dim_of[c]).collect();
// Fast index->pos
let out_pos: HashMap<char, usize> = out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
let red_pos: HashMap<char, usize> = red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
let out_pos: HashMap<char, usize> =
out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
let red_pos: HashMap<char, usize> =
red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
// Precompute strides per input and contributions
struct Contrib {
@@ -2762,7 +2767,10 @@ pub mod accumulated {
red_stride[q] = s;
}
}
Contrib { out_stride, red_stride }
Contrib {
out_stride,
red_stride,
}
})
.collect();
@@ -2777,8 +2785,7 @@ pub mod accumulated {
let red_rank = red_dims.len();
// Materialize output elements one by one
out
.par_iter_mut()
out.par_iter_mut()
.enumerate()
.for_each(|(out_linear_coord, out)| {
let mut out_index = vec![0usize; out_rank];