mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-09 14:28:00 -05:00
Fix circuit tests, create temp cargo patch for halo2 due to cross phase assignment issues
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2431,7 +2431,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
|
||||
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"blake2b_simd",
|
||||
|
||||
@@ -300,7 +300,6 @@ halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
|
||||
[profile.release]
|
||||
# debug = true
|
||||
rustflags = ["-C", "relocation-model=pic"]
|
||||
|
||||
@@ -3,6 +3,7 @@ use criterion::{
|
||||
Throughput,
|
||||
};
|
||||
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_keys;
|
||||
@@ -31,54 +32,13 @@ static mut K: usize = 15;
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum: Einsum<F>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
|
||||
equation: String,
|
||||
input_axes_to_dims: HashMap<char, usize>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
|
||||
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut input_axes_to_dims = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if input_axes_to_dims[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
equation: equation.to_owned(),
|
||||
input_axes_to_dims,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
type Config = BaseConfig<Fr>;
|
||||
type FloorPlanner = V1;
|
||||
type Params = Einsum<Fr>;
|
||||
type Params = SingleEinsumParams<Fr>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
@@ -101,8 +61,8 @@ impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
Einsum::<Fr>::new(
|
||||
&self.einsum.equation,
|
||||
SingleEinsumParams::<Fr>::new(
|
||||
&self.einsum_params.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
@@ -157,7 +117,7 @@ impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: self.einsum.equation.clone(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
@@ -189,11 +149,11 @@ fn runmatmul(c: &mut Criterion) {
|
||||
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
b.reshape(&[len, len]).unwrap();
|
||||
|
||||
let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
|
||||
let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
einsum,
|
||||
einsum_params,
|
||||
};
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
|
||||
@@ -91,14 +91,18 @@ pub fn analyze_single_equation(
|
||||
.map(|input| {
|
||||
input
|
||||
.chars()
|
||||
.filter(|char| input_axes_to_dim.get(char).is_some())
|
||||
.filter(|char| {
|
||||
input_axes_to_dim.get(char).is_some() && *input_axes_to_dim.get(char).unwrap() > 1
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output = output_str
|
||||
.chars()
|
||||
.filter(|c| input_axes_to_dim.get(c).is_some())
|
||||
.filter(|c| {
|
||||
input_axes_to_dim.get(c).is_some() && *input_axes_to_dim.get(c).unwrap() > 1
|
||||
})
|
||||
.collect();
|
||||
|
||||
[inputs.join(","), output].join("->")
|
||||
|
||||
54
src/circuit/ops/chip/einsum/circuit_params.rs
Normal file
54
src/circuit/ops/chip/einsum/circuit_params.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use std::{collections::HashMap, marker::PhantomData};
|
||||
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
use crate::{
|
||||
circuit::CircuitError,
|
||||
tensor::{Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
/// Circuit parameter for a single einsum equation
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
|
||||
///
|
||||
pub equation: String,
|
||||
/// Map from input axes to dimensions
|
||||
pub input_axes_to_dims: HashMap<char, usize>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
|
||||
///
|
||||
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut input_axes_to_dims = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if input_axes_to_dims[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
equation: equation.to_owned(),
|
||||
input_axes_to_dims,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ use halo2curves::ff::PrimeField;
|
||||
use log::{error, trace};
|
||||
|
||||
use crate::{
|
||||
circuit::{base::BaseOp, region::RegionCtx, CircuitError},
|
||||
circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
|
||||
tensor::{
|
||||
get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
@@ -11,11 +11,11 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::EinsumOpConfig;
|
||||
use super::ContractionConfig;
|
||||
|
||||
/// Pairwise (elementwise) op layout
|
||||
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 2],
|
||||
op: BaseOp,
|
||||
@@ -26,7 +26,6 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
} else {
|
||||
(values[1].clone(), values[0].clone())
|
||||
};
|
||||
let min_phase = std::cmp::min(phases[0], phases[1]);
|
||||
|
||||
let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
|
||||
|
||||
@@ -42,9 +41,11 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
region.flush_einsum()?;
|
||||
|
||||
let vars = config.get_vartensors(phases.as_slice().into());
|
||||
|
||||
let inputs = [lhs, rhs]
|
||||
.iter()
|
||||
.zip(config.inputs.iter().skip(min_phase * 2))
|
||||
.zip(vars)
|
||||
.map(|(val, var)| {
|
||||
let res = region.assign_einsum(var, val)?;
|
||||
Ok(res.get_inner()?)
|
||||
@@ -71,8 +72,12 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
if !region.is_dummy() {
|
||||
(0..assigned_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[0].cartesian_coord(region.einsum_col_coord() + i);
|
||||
let selector = config.selectors.get(&(min_phase, op.clone(), x, y));
|
||||
let (x, y, z) = config.output.cartesian_coord(region.einsum_col_coord() + i);
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: op.clone(),
|
||||
input_phases: phases.as_slice().into(),
|
||||
};
|
||||
let selector = config.selectors.get(&(op_info, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
|
||||
@@ -88,10 +93,11 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
}
|
||||
|
||||
pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 1],
|
||||
phase: usize,
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if values[0].len() == 1 {
|
||||
return Ok(values[0].clone());
|
||||
@@ -109,8 +115,9 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
|
||||
// value to advice column, how to workaround this issue?
|
||||
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
|
||||
let var = config.get_vartensors([phase].as_slice().into())[0];
|
||||
let (res, len) = region
|
||||
.assign_einsum_with_duplication_unconstrained(&config.inputs[phase * 2], &input)?;
|
||||
.assign_einsum_with_duplication_unconstrained(var, &input)?;
|
||||
assigned_len = len;
|
||||
res.get_inner()?
|
||||
};
|
||||
@@ -121,7 +128,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
|
||||
&config.output,
|
||||
&accumulated_sum.into(),
|
||||
&crate::circuit::CheckMode::UNSAFE,
|
||||
check_mode,
|
||||
)?;
|
||||
|
||||
// enable the selectors
|
||||
@@ -135,9 +142,17 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
continue;
|
||||
}
|
||||
let selector = if i == 0 {
|
||||
config.selectors.get(&(phase, BaseOp::SumInit, x, 0))
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::SumInit,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
} else {
|
||||
config.selectors.get(&(phase, BaseOp::Sum, x, 0))
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::Sum,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
};
|
||||
|
||||
region.enable(selector, z)?;
|
||||
@@ -153,10 +168,11 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
}
|
||||
|
||||
pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 1],
|
||||
phase: usize,
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
assert!(phase == 0 || phase == 1);
|
||||
region.flush_einsum()?;
|
||||
@@ -168,8 +184,9 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
|
||||
// value to advice column, how to workaround this issue?
|
||||
input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ONE)))?;
|
||||
let var = config.get_vartensors([phase].as_slice().into())[0];
|
||||
let (res, len) = region
|
||||
.assign_einsum_with_duplication_unconstrained(&config.inputs[phase * 2], &input)?;
|
||||
.assign_einsum_with_duplication_unconstrained(var, &input)?;
|
||||
assigned_len = len;
|
||||
res.get_inner()?
|
||||
};
|
||||
@@ -180,7 +197,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
|
||||
&config.output,
|
||||
&accumulated_prod.into(),
|
||||
&crate::circuit::CheckMode::UNSAFE,
|
||||
check_mode,
|
||||
)?;
|
||||
|
||||
// enable the selectors
|
||||
@@ -195,9 +212,17 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
return Ok(());
|
||||
}
|
||||
let selector = if i == 0 {
|
||||
config.selectors.get(&(phase, BaseOp::CumProdInit, x, 0))
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::CumProdInit,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
} else {
|
||||
config.selectors.get(&(phase, BaseOp::CumProd, x, 0))
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::CumProd,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
};
|
||||
|
||||
region.enable(selector, z)?;
|
||||
@@ -215,10 +240,11 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
}
|
||||
|
||||
pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 2],
|
||||
phases: &[usize; 2],
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if values[0].len() != values[1].len() {
|
||||
return Err(TensorError::DimMismatch("dot".to_string()).into());
|
||||
@@ -233,16 +259,13 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
} else {
|
||||
[values[1].clone(), values[0].clone()]
|
||||
};
|
||||
let min_phase = std::cmp::min(phases[0], phases[1]);
|
||||
let vars = config.get_vartensors(phases.as_slice().into());
|
||||
|
||||
let mut inputs = vec![];
|
||||
let block_width = config.output.num_inner_cols();
|
||||
|
||||
let mut assigned_len = 0;
|
||||
for (val, var) in values
|
||||
.iter_mut()
|
||||
.zip(config.inputs.iter().skip(min_phase * 2))
|
||||
{
|
||||
for (val, var) in values.iter_mut().zip(vars) {
|
||||
// FIXME : should pad with constant zero but currently this incurs an error
|
||||
// `NotEnoughColumnsForConstants` in halo2 because trying to assign constant
|
||||
// value to advice column, how to workaround this issue?
|
||||
@@ -261,8 +284,8 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
|
||||
&config.output,
|
||||
&accumulated_dot.into(),
|
||||
&crate::circuit::CheckMode::UNSAFE,
|
||||
)?;
|
||||
check_mode,
|
||||
).expect("failed to assign einsum with duplication constrained");
|
||||
|
||||
// enable the selectors
|
||||
if !region.is_dummy() {
|
||||
@@ -276,9 +299,17 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
return Ok(());
|
||||
}
|
||||
let selector = if i == 0 {
|
||||
config.selectors.get(&(min_phase, BaseOp::DotInit, x, 0))
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::DotInit,
|
||||
input_phases: phases.as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
} else {
|
||||
config.selectors.get(&(min_phase, BaseOp::Dot, x, 0))
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::Dot,
|
||||
input_phases: phases.as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
};
|
||||
region.enable(selector, z)?;
|
||||
|
||||
@@ -299,10 +330,11 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
/// Dot product of more than two tensors
|
||||
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>],
|
||||
phases: &[usize],
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
|
||||
if !values.iter().all(|value| value.len() == values[0].len()) {
|
||||
@@ -313,7 +345,7 @@ pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
|
||||
// do pairwise dot product between intermediate tensor and the next tensor
|
||||
let (intermediate, _) = values
|
||||
let (intermediate, output_phase) = values
|
||||
.into_iter()
|
||||
.zip(phases.iter().cloned())
|
||||
.reduce(|(intermediate, intermediate_phase), (input, phase)| {
|
||||
@@ -331,10 +363,7 @@ pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Sum the final tensor
|
||||
// In current freivalds' algorithm, we ensure that there is no tensor contraction between phase 0 tensors,
|
||||
// so the phase of the resulting tensor is set to 1
|
||||
let accumulated_dot = sum(config, region, &[&intermediate], 1)?;
|
||||
let accumulated_dot = sum(config, region, &[&intermediate], output_phase, check_mode)?;
|
||||
let last_elem = accumulated_dot.last()?;
|
||||
|
||||
let elapsed = global_start.elapsed();
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnal
|
||||
use crate::circuit::einsum::layouts::{pairwise, sum};
|
||||
use crate::circuit::einsum::reduction_planner::Reduction;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::CircuitError;
|
||||
use crate::circuit::{CheckMode, CircuitError};
|
||||
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
@@ -17,6 +17,8 @@ use std::marker::PhantomData;
|
||||
|
||||
///
|
||||
pub mod analysis;
|
||||
///
|
||||
pub mod circuit_params;
|
||||
mod layouts;
|
||||
mod reduction_planner;
|
||||
|
||||
@@ -24,7 +26,7 @@ mod reduction_planner;
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// custom gate to constrain tensor contractions
|
||||
custom_gate: EinsumOpConfig<F>,
|
||||
custom_gate: ContractionConfig<F>,
|
||||
/// custom gate to constrain random linear combinations used by Freivalds' argument
|
||||
rlc_gates: Vec<RLCConfig<F>>,
|
||||
}
|
||||
@@ -33,12 +35,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
///
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
|
||||
let dummy_custom_gate = EinsumOpConfig {
|
||||
let dummy_custom_gate = ContractionConfig {
|
||||
inputs: [
|
||||
dummy_var.clone(),
|
||||
dummy_var.clone(),
|
||||
dummy_var.clone(),
|
||||
dummy_var.clone(),
|
||||
[dummy_var.clone(), dummy_var.clone()],
|
||||
[dummy_var.clone(), dummy_var.clone()],
|
||||
],
|
||||
output: dummy_var.clone(),
|
||||
selectors: BTreeMap::default(),
|
||||
@@ -73,7 +73,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
|
||||
];
|
||||
let output = VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity);
|
||||
let custom_gate = EinsumOpConfig::new(meta, &inputs, &output);
|
||||
let custom_gate = ContractionConfig::new(
|
||||
meta,
|
||||
&[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
|
||||
&output,
|
||||
);
|
||||
|
||||
let mut rlc_gates = vec![];
|
||||
for _ in 0..analysis.max_num_output_axes {
|
||||
@@ -94,6 +98,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
input_tensors: &[&ValTensor<F>],
|
||||
output_tensor: &ValTensor<F>,
|
||||
equation: &str,
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<(), CircuitError> {
|
||||
region.set_num_einsum_inner_cols(self.custom_gate.output.num_inner_cols());
|
||||
|
||||
@@ -235,6 +240,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
dot_product_len,
|
||||
&output_dims,
|
||||
input_phases,
|
||||
check_mode,
|
||||
)?
|
||||
}
|
||||
None => {
|
||||
@@ -258,7 +264,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
.map(|t| t.get_inner_tensor().unwrap().get_scalar())
|
||||
.collect_vec()
|
||||
.into();
|
||||
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1)?;
|
||||
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1, check_mode)?;
|
||||
|
||||
region.constrain_equal(&squashed_input, &squashed_output)
|
||||
}
|
||||
@@ -300,7 +306,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
}
|
||||
|
||||
fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
flattened_tensors: Vec<ValTensor<F>>,
|
||||
input_phases: &[usize],
|
||||
@@ -327,12 +333,13 @@ fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
}
|
||||
|
||||
fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &EinsumOpConfig<F>,
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
flattened_tensors: Vec<ValTensor<F>>,
|
||||
dot_product_len: usize,
|
||||
output_shape: &[usize],
|
||||
input_phases: &[usize],
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
assert_eq!(flattened_tensors.len(), input_phases.len());
|
||||
let num_dot_products = output_shape.iter().product();
|
||||
@@ -344,13 +351,14 @@ fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash:
|
||||
.map(|tensor| tensor.get_slice(&[start..(start + dot_product_len)]))
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
let result = if tensors.len() == 1 {
|
||||
sum(config, region, &[&tensors[0]], input_phases[0])?
|
||||
sum(config, region, &[&tensors[0]], input_phases[0], check_mode)?
|
||||
} else if tensors.len() == 2 {
|
||||
dot(
|
||||
config,
|
||||
region,
|
||||
&[&tensors[0], &tensors[1]],
|
||||
&[input_phases[0], input_phases[1]],
|
||||
check_mode,
|
||||
)?
|
||||
} else {
|
||||
multi_dot(
|
||||
@@ -358,6 +366,7 @@ fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash:
|
||||
region,
|
||||
tensors.iter().collect_vec().as_slice(),
|
||||
input_phases,
|
||||
check_mode,
|
||||
)?
|
||||
};
|
||||
dot_product_results.push(result.get_inner_tensor()?.get_scalar());
|
||||
@@ -367,107 +376,219 @@ fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash:
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// `EinsumOpConfig` is the custom gate used for einsum contraction operations
|
||||
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)]
|
||||
enum InputPhases {
|
||||
FirstPhase,
|
||||
SecondPhase,
|
||||
BothFirstPhase, // [0, 0]
|
||||
Mixed, // [0, 1] or [1, 0]
|
||||
BothSecondPhase, // [1, 1]
|
||||
}
|
||||
|
||||
impl From<&[usize]> for InputPhases {
|
||||
fn from(phases: &[usize]) -> Self {
|
||||
match phases {
|
||||
[0] => Self::FirstPhase,
|
||||
[1] => Self::SecondPhase,
|
||||
[0, 0] => Self::BothFirstPhase,
|
||||
[0, 1] | [1, 0] => Self::Mixed,
|
||||
[1, 1] => Self::BothSecondPhase,
|
||||
_ => panic!("Invalid phase combination"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||
struct BaseOpInfo {
|
||||
pub op_kind: BaseOp,
|
||||
pub input_phases: InputPhases,
|
||||
}
|
||||
|
||||
/// `ContractionConfig` is the custom gate used for einsum contraction operations
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct EinsumOpConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
// [phase 0, phase 0, phase 1, phase 1]
|
||||
inputs: [VarTensor; 4],
|
||||
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
// [[phase 0, phase 0], [phase 1, phase 1]]
|
||||
inputs: [[VarTensor; 2]; 2],
|
||||
// phase 1
|
||||
output: VarTensor,
|
||||
// (phase, BaseOp, block index, inner column index) -> selector
|
||||
selectors: BTreeMap<(usize, BaseOp, usize, usize), Selector>,
|
||||
// (BaseOpInfo, block index, inner column index) -> selector
|
||||
selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> EinsumOpConfig<F> {
|
||||
fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 4], output: &VarTensor) -> Self {
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
|
||||
fn get_vartensors(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
|
||||
match input_phases {
|
||||
InputPhases::FirstPhase => vec![&self.inputs[0][0]],
|
||||
InputPhases::SecondPhase => vec![&self.inputs[1][0]],
|
||||
InputPhases::BothFirstPhase => vec![&self.inputs[0][0], &self.inputs[0][1]],
|
||||
InputPhases::Mixed => vec![&self.inputs[0][0], &self.inputs[1][0]],
|
||||
InputPhases::BothSecondPhase => vec![&self.inputs[1][0], &self.inputs[1][1]],
|
||||
}
|
||||
}
|
||||
|
||||
fn new(
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
inputs: &[&[&VarTensor; 2]; 2],
|
||||
output: &VarTensor,
|
||||
) -> Self {
|
||||
let mut selectors = BTreeMap::new();
|
||||
for phase in [0, 1] {
|
||||
for input_phases in [
|
||||
InputPhases::BothFirstPhase,
|
||||
InputPhases::Mixed,
|
||||
InputPhases::BothSecondPhase,
|
||||
] {
|
||||
for i in 0..output.num_blocks() {
|
||||
for j in 0..output.num_inner_cols() {
|
||||
selectors.insert((phase, BaseOp::Mult, i, j), meta.selector());
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::Mult,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
j,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
}
|
||||
for i in 0..output.num_blocks() {
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::DotInit,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::Dot,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for phase in [0, 1] {
|
||||
for input_phases in [
|
||||
InputPhases::FirstPhase,
|
||||
InputPhases::SecondPhase,
|
||||
] {
|
||||
for i in 0..output.num_blocks() {
|
||||
selectors.insert((phase, BaseOp::DotInit, i, 0), meta.selector());
|
||||
selectors.insert((phase, BaseOp::Dot, i, 0), meta.selector());
|
||||
selectors.insert((phase, BaseOp::SumInit, i, 0), meta.selector());
|
||||
selectors.insert((phase, BaseOp::Sum, i, 0), meta.selector());
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::SumInit,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::Sum,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::CumProdInit,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::CumProd,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
}
|
||||
}
|
||||
selectors.insert(
|
||||
(1, BaseOp::CumProdInit, output.num_blocks() - 1, 0),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(1, BaseOp::CumProd, output.num_blocks() - 1, 0),
|
||||
meta.selector(),
|
||||
);
|
||||
for ((phase, base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
|
||||
match base_op {
|
||||
for ((base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
|
||||
let inputs = match base_op.input_phases {
|
||||
InputPhases::FirstPhase => vec![inputs[0][0]],
|
||||
InputPhases::SecondPhase => vec![inputs[1][0]],
|
||||
InputPhases::BothFirstPhase => vec![inputs[0][0], inputs[0][1]],
|
||||
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
|
||||
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
|
||||
};
|
||||
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
|
||||
match base_op.op_kind {
|
||||
BaseOp::Mult => {
|
||||
meta.create_gate(base_op.as_str(), |meta| {
|
||||
meta.create_gate(base_op.op_kind.as_str(), |meta| {
|
||||
let selector = meta.query_selector(*selector);
|
||||
|
||||
let zero = Expression::<F>::Constant(F::ZERO);
|
||||
let mut qis = vec![zero; 4];
|
||||
for (i, q_i) in qis
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.skip(*phase * 2)
|
||||
.take(base_op.num_inputs())
|
||||
{
|
||||
*q_i = inputs[i]
|
||||
let mut qis = vec![zero; 2];
|
||||
for (q_i, input) in qis.iter_mut().zip(inputs) {
|
||||
*q_i = input
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("einsum op config: input query failed")[0]
|
||||
.clone()
|
||||
}
|
||||
// Get output expressions for each input channel
|
||||
let (rotation_offset, rng) = base_op.query_offset_rng();
|
||||
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
|
||||
let constraints = {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
.expect("einsum op config: output query failed");
|
||||
|
||||
let res = base_op
|
||||
.nonaccum_f((qis[2 * *phase].clone(), qis[2 * *phase + 1].clone()));
|
||||
vec![expected_output[base_op.constraint_idx()].clone() - res]
|
||||
let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
|
||||
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
|
||||
};
|
||||
Constraints::with_selector(selector, constraints)
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
meta.create_gate(base_op.as_str(), |meta| {
|
||||
meta.create_gate(base_op.op_kind.as_str(), |meta| {
|
||||
let selector = meta.query_selector(*selector);
|
||||
let mut qis = vec![vec![]; 4];
|
||||
for (i, q_i) in qis
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.skip(*phase * 2)
|
||||
.take(base_op.num_inputs())
|
||||
{
|
||||
*q_i = inputs[i]
|
||||
let mut qis = vec![vec![]; 2];
|
||||
for (q_i, input) in qis.iter_mut().zip(inputs) {
|
||||
*q_i = input
|
||||
.query_whole_block(meta, *block_idx, 0, 1)
|
||||
.expect("einsum op config: input query failed")
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
// Get output expressions for each input channel
|
||||
let (rotation_offset, rng) = base_op.query_offset_rng();
|
||||
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
|
||||
.expect("einsum op config: output query failed");
|
||||
|
||||
let res = base_op.accum_f(
|
||||
let res = base_op.op_kind.accum_f(
|
||||
expected_output[0].clone(),
|
||||
qis[2 * phase + 1].clone(),
|
||||
qis[2 * *phase].clone(),
|
||||
qis[1].clone(),
|
||||
qis[0].clone(),
|
||||
);
|
||||
let constraints =
|
||||
vec![expected_output[base_op.constraint_idx()].clone() - res];
|
||||
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res];
|
||||
|
||||
Constraints::with_selector(selector, constraints)
|
||||
});
|
||||
@@ -475,8 +596,23 @@ impl<F: PrimeField + TensorType + PartialOrd> EinsumOpConfig<F> {
|
||||
}
|
||||
}
|
||||
|
||||
let first_phase_inputs: [VarTensor; 2] = inputs[0]
|
||||
.iter()
|
||||
.copied()
|
||||
.cloned()
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let second_phase_inputs: [VarTensor; 2] = inputs[1]
|
||||
.iter()
|
||||
.copied()
|
||||
.cloned()
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
inputs: inputs.clone(),
|
||||
inputs: [first_phase_inputs, second_phase_inputs],
|
||||
output: output.clone(),
|
||||
selectors,
|
||||
_marker: PhantomData,
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
use std::{collections::HashMap, f64::consts::E, ops::Range};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
f64::consts::E,
|
||||
ops::Range,
|
||||
};
|
||||
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2curves::ff::PrimeField;
|
||||
@@ -18,9 +22,8 @@ use crate::{
|
||||
tensor::{
|
||||
create_unit_tensor, get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
Tensor, TensorError, ValType,
|
||||
DataFormat, KernelFormat, Tensor, TensorError, ValType,
|
||||
},
|
||||
tensor::{DataFormat, KernelFormat},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -825,26 +828,244 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
input_tensors: &[&ValTensor<F>],
|
||||
inputs: &[&ValTensor<F>],
|
||||
equation: &str,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// Track the einsum equation
|
||||
region.add_used_einsum_equation(equation.to_string())?;
|
||||
|
||||
let inputs = input_tensors
|
||||
// dispatch to freivalds' argument
|
||||
if !config.einsums.challenges().is_empty() {
|
||||
return freivalds(config, region, inputs, equation);
|
||||
}
|
||||
|
||||
let mut equation = equation.split("->");
|
||||
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut indices_to_size = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if indices_to_size[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// maps unrepresented indices in the output to a trivial 1
|
||||
for c in output_eq.chars() {
|
||||
indices_to_size.entry(c).or_insert(1);
|
||||
}
|
||||
|
||||
// Compute the output tensor shape
|
||||
let mut output_shape: Vec<usize> = output_eq
|
||||
.chars()
|
||||
.map(|c| {
|
||||
indices_to_size
|
||||
.get(&c)
|
||||
.ok_or(CircuitError::InvalidEinsum)
|
||||
.copied()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
if output_shape.is_empty() {
|
||||
output_shape.push(1);
|
||||
}
|
||||
|
||||
// Create a new output tensor with the computed shape
|
||||
let mut output: Tensor<ValType<F>> = Tensor::new(None, &output_shape)?;
|
||||
|
||||
let mut seen = HashSet::new();
|
||||
let mut common_indices_to_inputs = vec![];
|
||||
for input in inputs_eq.iter().take(inputs.len()) {
|
||||
for c in input.chars() {
|
||||
if !seen.contains(&c) {
|
||||
seen.insert(c);
|
||||
} else {
|
||||
common_indices_to_inputs.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let non_common_indices = indices_to_size
|
||||
.keys()
|
||||
.filter(|&x| !common_indices_to_inputs.contains(x))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let non_common_coord_size = non_common_indices
|
||||
.iter()
|
||||
.map(|d| {
|
||||
// If the current index is in the output equation, then the slice should be the current coordinate
|
||||
if output_eq.contains(**d) {
|
||||
Ok(1)
|
||||
// Otherwise, the slice should be the entire dimension of the input tensor
|
||||
} else {
|
||||
indices_to_size
|
||||
.get(d)
|
||||
.ok_or(CircuitError::InvalidEinsum)
|
||||
.copied()
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.iter()
|
||||
.product::<usize>();
|
||||
|
||||
let cartesian_coord = output_shape
|
||||
.iter()
|
||||
.map(|d| 0..*d)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Get the indices common across input tensors
|
||||
let mut common_coord = common_indices_to_inputs
|
||||
.iter()
|
||||
.map(|d| {
|
||||
// If the current index is in the output equation, then the slice should be the current coordinate
|
||||
if output_eq.contains(*d) {
|
||||
Ok(0..1)
|
||||
// Otherwise, the slice should be the entire dimension of the input tensor
|
||||
} else {
|
||||
Ok(0..*indices_to_size.get(d).ok_or(CircuitError::InvalidEinsum)?)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Range<_>>, CircuitError>>()?
|
||||
.into_iter()
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// If there are no common indices, then we need to add an empty slice to force one iteration of the loop
|
||||
if common_coord.is_empty() {
|
||||
common_coord.push(vec![]);
|
||||
}
|
||||
|
||||
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
// Compute the slice of each input tensor given the current coordinate of the output tensor
|
||||
let inputs = (0..inputs.len())
|
||||
.map(|idx| {
|
||||
let mut slice = vec![];
|
||||
for (i, c) in inputs_eq[idx].chars().enumerate() {
|
||||
// If the current index is in the output equation, then the slice should be the current coordinate
|
||||
if let Some(idx) = output_eq.find(c) {
|
||||
slice.push(coord[idx]..coord[idx] + 1);
|
||||
// Otherwise, the slice should be the entire dimension of the input tensor
|
||||
} else {
|
||||
slice.push(0..inputs[idx].dims()[i]);
|
||||
}
|
||||
}
|
||||
// Get the slice of the input tensor
|
||||
inputs[idx].get_slice(&slice)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// in this case its just a dot product :)
|
||||
if non_common_coord_size == 1 && inputs.len() == 2 {
|
||||
Ok(dot(config, region, &[&inputs[0], &inputs[1]])?.get_inner_tensor()?[0].clone())
|
||||
} else {
|
||||
let mut prod_res = None;
|
||||
|
||||
// Compute the cartesian product of all common indices
|
||||
for common_dim in &common_coord {
|
||||
let inputs = (0..inputs.len())
|
||||
.map(|idx| {
|
||||
let mut slice = vec![];
|
||||
// Iterate over all indices in the input equation
|
||||
for (i, c) in inputs_eq[idx].chars().enumerate() {
|
||||
// If the current index is common to multiple inputs, then the slice should be the current coordinate
|
||||
if let Some(j) = common_indices_to_inputs.iter().position(|&r| r == c) {
|
||||
slice.push(common_dim[j]..common_dim[j] + 1);
|
||||
} else {
|
||||
slice.push(0..inputs[idx].dims()[i]);
|
||||
}
|
||||
}
|
||||
// Get the slice of the input tensor
|
||||
inputs[idx].get_slice(&slice).map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut input_pairs = vec![];
|
||||
|
||||
for input in &inputs {
|
||||
input_pairs.push(input.get_inner_tensor()?.iter());
|
||||
}
|
||||
|
||||
let input_pairs = input_pairs
|
||||
.into_iter()
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Compute the product of all input tensors
|
||||
for pair in input_pairs {
|
||||
let product_across_pair = prod(config, region, &[&pair.into()])?;
|
||||
|
||||
if let Some(product) = prod_res {
|
||||
prod_res = Some(
|
||||
pairwise(
|
||||
config,
|
||||
region,
|
||||
&[&product, &product_across_pair],
|
||||
BaseOp::Add,
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?,
|
||||
);
|
||||
} else {
|
||||
prod_res = Some(product_across_pair);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(prod_res
|
||||
.ok_or(CircuitError::MissingEinsumProduct)?
|
||||
.get_inner_tensor()?[0]
|
||||
.clone())
|
||||
}
|
||||
};
|
||||
|
||||
region.flush()?;
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let output: ValTensor<F> = output.into();
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn freivalds<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
inputs: &[&ValTensor<F>],
|
||||
equation: &str,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input_values = inputs
|
||||
.iter()
|
||||
.map(|t| t.get_inner())
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
// Compute expected output using existing einsum logic
|
||||
// need to add this to ops
|
||||
let (output_tensor, _) =
|
||||
crate::tensor::ops::accumulated::einsum(equation, &inputs.iter().collect_vec())?;
|
||||
crate::tensor::ops::accumulated::einsum(equation, &input_values.iter().collect_vec())?;
|
||||
|
||||
config.einsums.assign_einsum(
|
||||
region,
|
||||
input_tensors,
|
||||
inputs,
|
||||
&output_tensor.clone().into(),
|
||||
equation,
|
||||
&config.check_mode,
|
||||
)?;
|
||||
|
||||
region.increment_einsum_index(1);
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use crate::circuit::einsum::circuit_params::SingleEinsumParams;
|
||||
use crate::circuit::ops::poly::PolyOp;
|
||||
use crate::circuit::*;
|
||||
use crate::tensor::{DataFormat, KernelFormat};
|
||||
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
circuit::{floor_planner::V1, Layouter, SimpleFloorPlanner, Value},
|
||||
dev::MockProver,
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
@@ -17,6 +19,7 @@ use itertools::Itertools;
|
||||
use ops::lookup::LookupOp;
|
||||
use ops::region::RegionCtx;
|
||||
use rand::rngs::OsRng;
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -24,7 +27,6 @@ struct TestParams;
|
||||
|
||||
#[cfg(test)]
|
||||
mod matmul {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 9;
|
||||
@@ -33,18 +35,45 @@ mod matmul {
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MatmulCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
type Params = SingleEinsumParams<F>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
params: Self::Params,
|
||||
) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert(params.equation, params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
SingleEinsumParams::<F>::new(
|
||||
&self.einsum_params.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN);
|
||||
@@ -57,17 +86,31 @@ mod matmul {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.challenges()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
let mut region = RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
128,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -89,8 +132,11 @@ mod matmul {
|
||||
let mut w = Tensor::from((0..LEN + 1).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
w.reshape(&[LEN + 1, 1]).unwrap();
|
||||
|
||||
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
|
||||
|
||||
let circuit = MatmulCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(w)],
|
||||
einsum_params,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
@@ -110,18 +156,45 @@ mod matmul_col_overflow_double_col {
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MatmulCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
type Params = SingleEinsumParams<F>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
params: Self::Params,
|
||||
) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert(params.equation, params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
SingleEinsumParams::<F>::new(
|
||||
&self.einsum_params.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
|
||||
@@ -134,17 +207,31 @@ mod matmul_col_overflow_double_col {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.challenges()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
let mut region = RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
NUM_INNER_COLS,
|
||||
128,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -164,8 +251,11 @@ mod matmul_col_overflow_double_col {
|
||||
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
w.reshape(&[LEN, 1]).unwrap();
|
||||
|
||||
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
|
||||
|
||||
let circuit = MatmulCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(w)],
|
||||
einsum_params,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
@@ -184,13 +274,14 @@ mod matmul_col_overflow {
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MatmulCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
type FloorPlanner = V1;
|
||||
type Params = SingleEinsumParams<F>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
@@ -203,22 +294,55 @@ mod matmul_col_overflow {
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn configure_with_params(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
params: Self::Params,
|
||||
) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert(params.equation, params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
self.einsum_params.clone()
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.challenges()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
let mut region = RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
128,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -238,8 +362,11 @@ mod matmul_col_overflow {
|
||||
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
w.reshape(&[LEN, 1]).unwrap();
|
||||
|
||||
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
|
||||
|
||||
let circuit = MatmulCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(w)],
|
||||
einsum_params,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
@@ -271,18 +398,37 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MatmulCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
type Params = SingleEinsumParams<F>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
params: Self::Params,
|
||||
) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert(params.equation, params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
config
|
||||
.configure_einsums(cs, &analysis, NUM_INNER_COLS, K)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
self.einsum_params.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN);
|
||||
@@ -295,17 +441,31 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.challenges()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
let mut region = RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
NUM_INNER_COLS,
|
||||
128,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -328,8 +488,11 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
w.reshape(&[LEN, 1]).unwrap();
|
||||
|
||||
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
|
||||
|
||||
let circuit = MatmulCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(w)],
|
||||
einsum_params,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
@@ -376,10 +539,13 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
))]
|
||||
mod matmul_col_ultra_overflow {
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
use halo2_proofs::{
|
||||
circuit::floor_planner::V1,
|
||||
poly::kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
@@ -392,18 +558,38 @@ mod matmul_col_ultra_overflow {
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MatmulCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
type FloorPlanner = V1;
|
||||
type Params = SingleEinsumParams<F>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
params: Self::Params,
|
||||
) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert(params.equation, params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
self.einsum_params.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
@@ -416,17 +602,32 @@ mod matmul_col_ultra_overflow {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.challenges()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
println!("challenges: {:?}", challenges);
|
||||
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
let mut region = RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
128,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -449,8 +650,11 @@ mod matmul_col_ultra_overflow {
|
||||
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
w.reshape(&[LEN, 1]).unwrap();
|
||||
|
||||
let einsum_params = SingleEinsumParams::<F>::new("ij,jk->ik", &[&a, &w]).unwrap();
|
||||
|
||||
let circuit = MatmulCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(w)],
|
||||
einsum_params,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
|
||||
@@ -2729,7 +2729,10 @@ pub mod accumulated {
|
||||
|
||||
// Output dims
|
||||
let out_idx: Vec<char> = output_expr.chars().collect();
|
||||
let out_dims: Vec<usize> = out_idx.iter().map(|c| *dim_of.get(c).unwrap_or(&1)).collect();
|
||||
let out_dims: Vec<usize> = out_idx
|
||||
.iter()
|
||||
.map(|c| *dim_of.get(c).unwrap_or(&1))
|
||||
.collect();
|
||||
|
||||
// Reduction indices
|
||||
let all_idx: HashSet<char> = dim_of.keys().copied().collect();
|
||||
@@ -2738,8 +2741,10 @@ pub mod accumulated {
|
||||
let red_dims: Vec<usize> = red_idx.iter().map(|c| dim_of[c]).collect();
|
||||
|
||||
// Fast index->pos
|
||||
let out_pos: HashMap<char, usize> = out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
let red_pos: HashMap<char, usize> = red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
let out_pos: HashMap<char, usize> =
|
||||
out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
let red_pos: HashMap<char, usize> =
|
||||
red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
|
||||
// Precompute strides per input and contributions
|
||||
struct Contrib {
|
||||
@@ -2762,7 +2767,10 @@ pub mod accumulated {
|
||||
red_stride[q] = s;
|
||||
}
|
||||
}
|
||||
Contrib { out_stride, red_stride }
|
||||
Contrib {
|
||||
out_stride,
|
||||
red_stride,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -2777,8 +2785,7 @@ pub mod accumulated {
|
||||
let red_rank = red_dims.len();
|
||||
|
||||
// Materialize output elements one by one
|
||||
out
|
||||
.par_iter_mut()
|
||||
out.par_iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(out_linear_coord, out)| {
|
||||
let mut out_index = vec![0usize; out_rank];
|
||||
|
||||
Reference in New Issue
Block a user