Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-09 14:28:00 -05:00)
feat: implement generalized Freivalds' algorithm for arbitrary einsum expressions (#1006)
Co-authored-by: therealyingtong <yingtong.lai@gmail.com>
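For orientation: this change generalizes Freivalds' classical probabilistic check from matrix multiplication to arbitrary einsum contractions. Rather than re-deriving the contraction inside the circuit at full cost, the prover commits to the claimed output; each output axis is then folded down with a verifier challenge (the RLC gates below), the same challenges are folded into the inputs, and a single scalar equality is enforced (see `assign_einsum`, which constrains `squashed_input == squashed_output`). The textbook matrix case looks like this out of circuit; a minimal sketch over a toy modulus, not ezkl's API:

// Freivalds' check for C = A * B over a toy prime field.
// All names and the modulus here are illustrative only.
const P: u64 = 65_521; // small prime standing in for the circuit's field

fn mat_vec(m: &[Vec<u64>], v: &[u64]) -> Vec<u64> {
    m.iter()
        .map(|row| row.iter().zip(v).fold(0, |acc, (a, b)| (acc + a * b) % P))
        .collect()
}

/// Accept iff A * (B * x) == C * x for a random challenge vector x.
fn freivalds_check(a: &[Vec<u64>], b: &[Vec<u64>], c: &[Vec<u64>], x: &[u64]) -> bool {
    mat_vec(a, &mat_vec(b, x)) == mat_vec(c, x)
}

With a correct C the identity always holds; with any wrong entry it survives a uniformly random x with probability at most 1/|field|, while the check costs O(n^2) field operations instead of the O(n^3) of recomputing the product.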
@@ -1,53 +1,132 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{
    criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
    Throughput,
};
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::pfsys::create_keys;
use ezkl::pfsys::srs::gen_srs;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::plonk::create_proof;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer};
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, SimpleFloorPlanner, Value},
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::collections::HashMap;
use std::marker::PhantomData;

static mut LEN: usize = 4;
const K: usize = 16;
static mut K: usize = 15;

#[derive(Clone)]
struct MyCircuit {
    inputs: [ValTensor<Fr>; 2],
    _marker: PhantomData<Fr>,
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

impl Circuit<Fr> for MyCircuit {
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = SimpleFloorPlanner;
    type Params = ();
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let mut config = Self::Config::default();

        let mut equations = HashMap::new();
        equations.insert(params.equation, params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 2;
        unsafe {
            config
                .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
                .unwrap();
        }

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        let len = unsafe { LEN };
        let mut config = Self::Config::default();

        let a = VarTensor::new_advice(cs, K, 1, len * len);
        let default_params = Self::Params::default();

        let b = VarTensor::new_advice(cs, K, 1, len * len);
        let mut equations = HashMap::new();
        equations.insert(default_params.equation, default_params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        unsafe {
            config
                .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
                .unwrap();
        }

        let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);

        Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
        config
    }

    fn synthesize(
@@ -55,16 +134,30 @@ impl Circuit<Fr> for MyCircuit {
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .challenges()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: "ab,bc->ac".to_string(),
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
@@ -77,68 +170,64 @@ impl Circuit<Fr> for MyCircuit {

fn runmatmul(c: &mut Criterion) {
    let mut group = c.benchmark_group("accum_einsum_matmul");
    let params = gen_srs::<KZGCommitmentScheme<_>>(17);
    for &len in [4, 32].iter() {
        unsafe {
            LEN = len;
    group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
    group.sampling_mode(criterion::SamplingMode::Flat);
    group.sample_size(10);
    let len = 512;
    unsafe {
        LEN = len;
    }
    for k in 19..20 {
        let params = unsafe {
            K = k;
            gen_srs::<KZGCommitmentScheme<_>>(K as u32)
        };

        let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
        a.reshape(&[len, len]).unwrap();

        // parameters
        let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
        b.reshape(&[len, len]).unwrap();

        let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();

        let circuit = MyCircuit {
            inputs: [ValTensor::from(a), ValTensor::from(b)],
            _marker: PhantomData,
            einsum,
        };

        group.throughput(Throughput::Elements(len as u64));
        group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
        group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
            b.iter(|| {
                create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
                create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, true)
                    .unwrap();
            });
        });

        let pk =
            create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
        let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, false)
            .unwrap();

        group.throughput(Throughput::Elements(len as u64));
        group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
        group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
            b.iter(|| {
                let prover = create_proof_circuit::<
                    KZGCommitmentScheme<_>,
                    MyCircuit,
                    ProverSHPLONK<_>,
                    VerifierSHPLONK<_>,
                    SingleStrategy<_>,
                    _,
                    EvmTranscript<_, _, _, _>,
                    EvmTranscript<_, _, _, _>,
                >(
                    circuit.clone(),
                    vec![],
                let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);

                create_proof::<KZGCommitmentScheme<_>, ProverSHPLONK<_>, _, _, _, _>(
                    &params,
                    &pk,
                    CheckMode::UNSAFE,
                    ezkl::Commitments::KZG,
                    TranscriptType::EVM,
                    None,
                    None,
                );
                prover.unwrap();
                    &[circuit.clone()],
                    &[&[]],
                    OsRng,
                    &mut transcript,
                )
                .expect("proof generation should not fail");

                transcript.finalize();
            });
        });
    }
    group.finish();
}

criterion_group! {
    name = benches;
    config = Criterion::default().with_plots();
    targets = runmatmul
}
criterion_group!(benches, runmatmul);
criterion_main!(benches);
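Assuming this file is registered as a Criterion bench target (the group above is named accum_einsum_matmul), it would be run with something like `cargo bench --bench accum_einsum_matmul`. Note that `K` and `LEN` are mutable statics set inside the bench loop, which is why the configure path reads them through `unsafe`.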
examples/accum_einsum_matmul.rs (new file, 180 lines)
@@ -0,0 +1,180 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;

const K: usize = 13;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let mut config = Self::Config::default();

        let mut equations = HashMap::new();
        equations.insert(params.equation, params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        let mut config = Self::Config::default();

        let default_params = Self::Params::default();

        let mut equations = HashMap::new();
        equations.insert(default_params.equation, default_params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .challenges()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}

fn runmatmul() {
    let len = 64;

    let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
    a.reshape(&[len, len]).unwrap();

    // parameters
    let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
    b.reshape(&[len, len]).unwrap();

    let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();

    let circuit = MyCircuit {
        inputs: [ValTensor::from(a), ValTensor::from(b)],
        einsum,
    };

    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}

pub fn main() {
    runmatmul()
}
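Since this lives under examples/, it should be runnable with `cargo run --example accum_einsum_matmul` (plus whatever crate features the examples require). The two examples that follow have the same skeleton and differ only in the equation and tensor shapes.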
examples/batch_mat_mul.rs (new file, 188 lines)
@@ -0,0 +1,188 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;

static mut LEN: usize = 4;
const K: usize = 11;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let len = unsafe { LEN };

        let a = VarTensor::new_advice(cs, K, 1, len);
        let b = VarTensor::new_advice(cs, K, 1, len);
        let output = VarTensor::new_advice(cs, K, 1, len);

        let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);

        let mut equations = HashMap::new();
        equations.insert(params.equation, params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        let mut config = Self::Config::default();

        let default_params = Self::Params::default();

        let mut equations = HashMap::new();
        equations.insert(default_params.equation, default_params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .challenges()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}

fn runbatchmatmul() {
    let batch_size = 5;
    let len = 12;

    let mut a = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
    a.reshape(&[batch_size, len, len]).unwrap();

    // parameters
    let mut b = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
    b.reshape(&[batch_size, len, len]).unwrap();

    let einsum = Einsum::<Fr>::new("ijk,ikl->ijl", &[&a, &b]).unwrap();

    let circuit = MyCircuit {
        inputs: [ValTensor::from(a), ValTensor::from(b)],
        einsum,
    };

    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}

pub fn main() {
    runbatchmatmul()
}
examples/tensor_contraction.rs (new file, 190 lines)
@@ -0,0 +1,190 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;

static mut LEN: usize = 4;
const K: usize = 11;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let len = unsafe { LEN };

        let a = VarTensor::new_advice(cs, K, 1, len);
        let b = VarTensor::new_advice(cs, K, 1, len);
        let output = VarTensor::new_advice(cs, K, 1, len);

        let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);

        let mut equations = HashMap::new();
        equations.insert(params.equation, params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 2;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        let mut config = Self::Config::default();

        let default_params = Self::Params::default();

        let mut equations = HashMap::new();
        equations.insert(default_params.equation, default_params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .challenges()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}

fn runmatmul() {
    let i = 10;
    let n = 10;
    let j = 40;
    let k = 10;

    let mut a = Tensor::from((0..i * n * j).map(|_| Value::known(Fr::random(OsRng))));
    a.reshape(&[i, n, j]).unwrap();

    // parameters
    let mut b = Tensor::from((0..j * k).map(|_| Value::known(Fr::random(OsRng))));
    b.reshape(&[j, k]).unwrap();

    let einsum = Einsum::<Fr>::new("inj,jk->ik", &[&a, &b]).unwrap();

    let circuit = MyCircuit {
        inputs: [ValTensor::from(a), ValTensor::from(b)],
        einsum,
    };

    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}

pub fn main() {
    runmatmul()
}
@@ -7,18 +7,14 @@ use halo2_proofs::{
};
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
    conversion::FromPyObject,
    exceptions::PyValueError,
    IntoPyObject,
    prelude::*,
};
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*, IntoPyObject};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;

use crate::{
    circuit::{
        chip::einsum::analysis::EinsumAnalysis,
        ops::base::BaseOp,
        table::{Range, RangeCheck, Table},
    },
@@ -29,6 +25,9 @@ use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};

///
pub mod einsum;

#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -271,6 +270,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
    pub range_checks: RangeChecks<F>,
    /// [Selector]s for the shuffles
    pub shuffles: Shuffles,
    /// Einsum-specific configuration
    pub einsums: einsum::Einsums<F>,
    /// Activate sanity checks
    pub check_mode: CheckMode,
    _marker: PhantomData<F>,
@@ -285,6 +286,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
            custom_gates: CustomGates::dummy(col_size, num_inner_cols),
            static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
            dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
            einsums: einsum::Einsums::<F>::dummy(col_size, num_inner_cols),
            shuffles: Shuffles::dummy(col_size, num_inner_cols),
            range_checks: RangeChecks::dummy(col_size, num_inner_cols),
            check_mode: CheckMode::SAFE,
@@ -419,6 +421,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
            },
            static_lookups: StaticLookups::default(),
            dynamic_lookups: DynamicLookups::default(),
            einsums: einsum::Einsums::<F>::default(),
            shuffles: Shuffles::default(),
            range_checks: RangeChecks::default(),
            shared_table_inputs: vec![],
@@ -693,6 +696,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
        Ok(())
    }

    /// Configures and creates einsums
    #[allow(clippy::too_many_arguments)]
    pub fn configure_einsums(
        &mut self,
        cs: &mut ConstraintSystem<F>,
        analysis: &EinsumAnalysis,
        num_inner_cols: usize,
        logrows: usize,
    ) -> Result<(), CircuitError>
    where
        F: Field,
    {
        self.einsums = einsum::Einsums::configure_universal(cs, analysis, num_inner_cols, logrows);
        Ok(())
    }

    /// Configures and creates lookup selectors
    #[allow(clippy::too_many_arguments)]
    pub fn configure_shuffles(
src/circuit/ops/chip/einsum/analysis.rs (new file, 175 lines)
@@ -0,0 +1,175 @@
use std::collections::HashMap;

use itertools::Itertools;

use crate::circuit::{
    einsum::reduction_planner::{self, Reduction},
    CircuitError,
};

///
#[derive(Debug, Clone)]
pub struct EinsumAnalysis {
    /// max size of input tensors
    pub max_input_size: usize,
    /// max size of output tensors
    pub max_output_size: usize,
    /// max number of input tensors
    pub max_num_inputs: usize,
    /// max number of output axes
    pub max_num_output_axes: usize,
    /// max length of any challenge vector needed for the RLCs
    pub longest_challenge_vector: usize,
    /// the length of dot product to compute all the reductions
    pub reduction_length: usize,
}

///
#[derive(Debug, Clone)]
pub struct SingleEquationAnalysis {
    ///
    pub equation: String,
    ///
    pub num_inputs: usize,
    ///
    pub max_input_size: usize,
    ///
    pub output_size: usize,
    ///
    pub num_output_axes: usize,
    ///
    pub output_indices: Vec<char>,
    ///
    pub longest_challenge_vector: usize,
    /// the length of dot product to compute all the reductions
    pub reduction_length: usize,
}

///
pub fn analyze_einsum_usage(
    equations: &HashMap<String, HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
    let mut max_num_inputs = 0;
    let mut max_input_size = 0;
    let mut max_output_size = 0;
    let mut max_num_output_axes = 0;
    let mut longest_challenge_vector = 0;
    let mut reduction_length = 0;

    for (equation, input_axes_to_dim) in equations.iter() {
        let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
        max_input_size = max_input_size.max(analysis.max_input_size);
        longest_challenge_vector = longest_challenge_vector.max(analysis.longest_challenge_vector);
        max_output_size = max_output_size.max(analysis.output_size);
        max_num_inputs = max_num_inputs.max(analysis.num_inputs);
        max_num_output_axes = max_num_output_axes.max(analysis.num_output_axes);
        reduction_length += analysis.reduction_length;
    }

    Ok(EinsumAnalysis {
        max_input_size,
        longest_challenge_vector,
        max_output_size,
        max_num_inputs,
        max_num_output_axes,
        reduction_length,
    })
}

///
pub fn analyze_single_equation(
    equation: &str,
    input_axes_to_dim: &HashMap<char, usize>,
) -> Result<SingleEquationAnalysis, CircuitError> {
    // Sanitise equation to remove trivial axes
    let equation = {
        let (inputs_str, output_str) = equation.split_once("->").unwrap();
        let input_equations: Vec<&str> = inputs_str.split(',').collect();

        let inputs: Vec<String> = input_equations
            .iter()
            .map(|input| {
                input
                    .chars()
                    .filter(|char| input_axes_to_dim.get(char).is_some())
                    .collect()
            })
            .collect();

        let output = output_str
            .chars()
            .filter(|c| input_axes_to_dim.get(c).is_some())
            .collect();

        [inputs.join(","), output].join("->")
    };

    let (inputs_str, output_str) = equation.split_once("->").unwrap();
    let input_equations: Vec<&str> = inputs_str.split(',').collect();

    let max_input_size = input_equations
        .iter()
        .map(|eqn| {
            eqn.chars()
                .map(|c| input_axes_to_dim.get(&c).unwrap())
                .product()
        })
        .max()
        .unwrap();

    let output_indices: Vec<char> = output_str.chars().collect();
    let output_dims = output_indices
        .iter()
        .map(|c| input_axes_to_dim.get(&c).unwrap());
    let output_size = output_dims.clone().product();
    let longest_challenge_vector = *output_dims.clone().max().unwrap();

    let output_reduction_length = {
        let mut output_dims = output_dims.rev().cloned().collect_vec();
        let mut total_length = 0;
        for _ in 0..output_dims.len() {
            let dot_product_len = output_dims.remove(0);
            let num_dot_products: usize = output_dims.iter().product();
            total_length += dot_product_len * num_dot_products;
        }
        total_length
    };

    let input_reductions_length = {
        let input_reductions = reduction_planner::input_reductions(&equation)?;
        input_reductions
            .into_iter()
            .map(|reduction| {
                let (_, output_expr) = reduction.expression().split_once("->").unwrap();
                let num_inputs = reduction.input_indices().len();
                let dot_product_len = match reduction {
                    Reduction::RLC { axis, .. } => *input_axes_to_dim.get(&axis).unwrap(),
                    Reduction::Contraction { axis, .. } => *axis
                        .and_then(|axis| input_axes_to_dim.get(&axis))
                        .unwrap_or(&1),
                };
                let num_dot_products: usize = output_expr
                    .chars()
                    .map(|c| input_axes_to_dim.get(&c).unwrap())
                    .product();
                // `multi_dot` does a pairwise mult between each input pair plus a final
                // summation, so reductions over more than two inputs cost an extra
                // factor of num_inputs
                if num_inputs <= 2 {
                    num_dot_products * dot_product_len
                } else {
                    num_dot_products * (dot_product_len * num_inputs)
                }
            })
            .sum::<usize>()
    };

    Ok(SingleEquationAnalysis {
        output_size,
        longest_challenge_vector,
        max_input_size,
        equation: equation.to_string(),
        num_inputs: input_equations.len(),
        num_output_axes: output_indices.len(),
        output_indices,
        reduction_length: output_reduction_length + input_reductions_length,
    })
}
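To make the row accounting concrete, here is the output-axis term of `reduction_length` worked by hand for one hypothetical equation; the helper below is a standalone restatement of the loop above, not part of the crate:

// For "ij,jk->ik" with hypothetical dims i = 3, k = 5: axis k is folded
// first (3 dot products of length 5), then axis i (one dot product of
// length 3), so the output term is 5 * 3 + 3 = 18 rows.
fn output_reduction_length(mut output_dims: Vec<usize>) -> usize {
    output_dims.reverse(); // innermost axis first, as in the loop above
    let mut total_length = 0;
    for _ in 0..output_dims.len() {
        let dot_product_len = output_dims.remove(0);
        let num_dot_products: usize = output_dims.iter().product();
        total_length += dot_product_len * num_dot_products;
    }
    total_length
}

fn main() {
    assert_eq!(output_reduction_length(vec![3, 5]), 18);
}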
src/circuit/ops/chip/einsum/layouts.rs (new file, 344 lines)
@@ -0,0 +1,344 @@
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use log::{error, trace};

use crate::{
    circuit::{base::BaseOp, region::RegionCtx, CircuitError},
    tensor::{
        get_broadcasted_shape,
        ops::{accumulated, add, mult, sub},
        TensorError, TensorType, ValTensor, ValType,
    },
};

use super::EinsumOpConfig;

/// Pairwise (elementwise) op layout
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 2],
    op: BaseOp,
    phases: &[usize; 2],
) -> Result<ValTensor<F>, CircuitError> {
    let (mut lhs, mut rhs) = if phases[0] <= phases[1] {
        (values[0].clone(), values[1].clone())
    } else {
        (values[1].clone(), values[0].clone())
    };
    let min_phase = std::cmp::min(phases[0], phases[1]);

    let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;

    lhs.expand(&broadcasted_shape)?;
    rhs.expand(&broadcasted_shape)?;

    if lhs.len() != rhs.len() {
        return Err(CircuitError::DimMismatch(format!(
            "pairwise {} layout",
            op.as_str()
        )));
    }

    region.flush_einsum()?;

    let inputs = [lhs, rhs]
        .iter()
        .zip(config.inputs.iter().skip(min_phase * 2))
        .map(|(val, var)| {
            let res = region.assign_einsum(var, val)?;
            Ok(res.get_inner()?)
        })
        .collect::<Result<Vec<_>, CircuitError>>()?;

    // Now we can compute and assign the result of the op
    let op_result = match op {
        BaseOp::Add => add(&inputs),
        BaseOp::Sub => sub(&inputs),
        BaseOp::Mult => mult(&inputs),
        _ => return Err(CircuitError::UnsupportedOp),
    }
    .map_err(|e| {
        error!("{}", e);
        halo2_proofs::plonk::Error::Synthesis
    })?;

    let assigned_len = op_result.len();
    let mut output = region.assign_einsum(&config.output, &op_result.into())?;

    // Enable the selectors
    if !region.is_dummy() {
        (0..assigned_len)
            .map(|i| {
                let (x, y, z) = config.inputs[0].cartesian_coord(region.einsum_col_coord() + i);
                let selector = config.selectors.get(&(min_phase, op.clone(), x, y));

                region.enable(selector, z)?;

                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }
    region.increment_einsum_col_coord(assigned_len);

    output.reshape(&broadcasted_shape)?;

    Ok(output)
}

pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 1],
    phase: usize,
) -> Result<ValTensor<F>, CircuitError> {
    if values[0].len() == 1 {
        return Ok(values[0].clone());
    }
    assert!(phase == 0 || phase == 1);

    region.flush_einsum()?;
    let mut input = values[0].clone();

    let block_width = config.output.num_inner_cols();

    let assigned_len: usize;
    let input = {
        // FIXME: we should pad with a constant zero, but that currently triggers
        // `NotEnoughColumnsForConstants` in halo2 (a constant assigned to an
        // advice column); how can we work around this?
        input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
        let (res, len) = region
            .assign_einsum_with_duplication_unconstrained(&config.inputs[phase * 2], &input)?;
        assigned_len = len;
        res.get_inner()?
    };

    // Now we can assign the running sum
    let accumulated_sum = accumulated::sum(&input, block_width)?;

    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        &config.output,
        &accumulated_sum.into(),
        &crate::circuit::CheckMode::UNSAFE,
    )?;

    // enable the selectors
    if !region.is_dummy() {
        for i in 0..output_assigned_len {
            let (x, _, z) = config
                .output
                .cartesian_coord(region.einsum_col_coord() + i * block_width);
            // skip over duplicates at start of column
            if z == 0 && i > 0 {
                continue;
            }
            let selector = if i == 0 {
                config.selectors.get(&(phase, BaseOp::SumInit, x, 0))
            } else {
                config.selectors.get(&(phase, BaseOp::Sum, x, 0))
            };

            region.enable(selector, z)?;
        }
    }

    let last_elem = output.last()?;

    region.increment_einsum_col_coord(assigned_len);

    // last element is the result
    Ok(last_elem)
}

pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 1],
    phase: usize,
) -> Result<ValTensor<F>, CircuitError> {
    assert!(phase == 0 || phase == 1);
    region.flush_einsum()?;
    let block_width = config.output.num_inner_cols();
    let assigned_len: usize;
    let input = {
        let mut input = values[0].clone();
        // FIXME: we should pad with a constant one, but that currently triggers
        // `NotEnoughColumnsForConstants` in halo2 (a constant assigned to an
        // advice column); how can we work around this?
        input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ONE)))?;
        let (res, len) = region
            .assign_einsum_with_duplication_unconstrained(&config.inputs[phase * 2], &input)?;
        assigned_len = len;
        res.get_inner()?
    };

    // Now we can assign the running product
    let accumulated_prod = accumulated::prod(&input, block_width)?;

    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        &config.output,
        &accumulated_prod.into(),
        &crate::circuit::CheckMode::UNSAFE,
    )?;

    // enable the selectors
    if !region.is_dummy() {
        (0..output_assigned_len)
            .map(|i| {
                let (x, _, z) = config
                    .output
                    .cartesian_coord(region.einsum_col_coord() + i * block_width);
                // skip over duplicates at start of column
                if z == 0 && i > 0 {
                    return Ok(());
                }
                let selector = if i == 0 {
                    config.selectors.get(&(phase, BaseOp::CumProdInit, x, 0))
                } else {
                    config.selectors.get(&(phase, BaseOp::CumProd, x, 0))
                };

                region.enable(selector, z)?;
                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }

    let last_elem = output.last()?;

    region.increment_einsum_col_coord(assigned_len);

    // last element is the result
    Ok(last_elem)
}

pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 2],
    phases: &[usize; 2],
) -> Result<ValTensor<F>, CircuitError> {
    if values[0].len() != values[1].len() {
        return Err(TensorError::DimMismatch("dot".to_string()).into());
    }

    region.flush_einsum()?;
    // time this entire function run
    let global_start = instant::Instant::now();

    let mut values = if phases[0] <= phases[1] {
        [values[0].clone(), values[1].clone()]
    } else {
        [values[1].clone(), values[0].clone()]
    };
    let min_phase = std::cmp::min(phases[0], phases[1]);

    let mut inputs = vec![];
    let block_width = config.output.num_inner_cols();

    let mut assigned_len = 0;
    for (val, var) in values
        .iter_mut()
        .zip(config.inputs.iter().skip(min_phase * 2))
    {
        // FIXME: we should pad with a constant zero, but that currently triggers
        // `NotEnoughColumnsForConstants` in halo2 (a constant assigned to an
        // advice column); how can we work around this?
        val.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
        let inp = {
            let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
            assigned_len = len;
            res.get_inner()?
        };
        inputs.push(inp);
    }

    // Now we can assign the dot product
    let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        &config.output,
        &accumulated_dot.into(),
        &crate::circuit::CheckMode::UNSAFE,
    )?;

    // enable the selectors
    if !region.is_dummy() {
        (0..output_assigned_len)
            .map(|i| {
                let (x, _, z) = config
                    .output
                    .cartesian_coord(region.einsum_col_coord() + i * block_width);
                // hop over duplicates at start of column
                if z == 0 && i > 0 {
                    return Ok(());
                }
                let selector = if i == 0 {
                    config.selectors.get(&(min_phase, BaseOp::DotInit, x, 0))
                } else {
                    config.selectors.get(&(min_phase, BaseOp::Dot, x, 0))
                };
                region.enable(selector, z)?;

                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }

    let last_elem = output.last()?;

    region.increment_einsum_col_coord(assigned_len);

    let elapsed = global_start.elapsed();
    trace!("dot layout took: {:?}, row {}", elapsed, region.row());
    trace!("----------------------------");
    Ok(last_elem)
}

/// Dot product of more than two tensors
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>],
    phases: &[usize],
) -> Result<ValTensor<F>, CircuitError> {
    assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
    if !values.iter().all(|value| value.len() == values[0].len()) {
        return Err(TensorError::DimMismatch("dot".to_string()).into());
    }
    // time this entire function run
    let global_start = instant::Instant::now();

    let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
    // do a pairwise mult between the intermediate tensor and the next tensor
    let (intermediate, _) = values
        .into_iter()
        .zip(phases.iter().cloned())
        .reduce(|(intermediate, intermediate_phase), (input, phase)| {
            (
                pairwise(
                    config,
                    region,
                    &[&intermediate, &input],
                    BaseOp::Mult,
                    &[intermediate_phase, phase],
                )
                .unwrap(),
                std::cmp::max(intermediate_phase, phase),
            )
        })
        .unwrap();

    // Sum the final tensor.
    // In the current Freivalds' argument we ensure that there is no tensor
    // contraction between phase 0 tensors, so the phase of the resulting
    // tensor is set to 1
    let accumulated_dot = sum(config, region, &[&intermediate], 1)?;
    let last_elem = accumulated_dot.last()?;

    let elapsed = global_start.elapsed();
    trace!("multi_dot layout took: {:?}, row {}", elapsed, region.row());
    trace!("----------------------------");
    Ok(last_elem)
}
src/circuit/ops/chip/einsum/mod.rs (new file, 649 lines)
@@ -0,0 +1,649 @@
|
||||
use crate::circuit::base::BaseOp;
|
||||
use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnalysis};
|
||||
use crate::circuit::einsum::layouts::{pairwise, sum};
|
||||
use crate::circuit::einsum::reduction_planner::Reduction;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::CircuitError;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
Challenge, ConstraintSystem, Constraints, Expression, FirstPhase, Selector,
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use layouts::{dot, multi_dot, prod};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
///
|
||||
pub mod analysis;
|
||||
mod layouts;
|
||||
mod reduction_planner;
|
||||
|
||||
/// A struct representing reductions for the einsums
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// custom gate to constrain tensor contractions
|
||||
custom_gate: EinsumOpConfig<F>,
|
||||
/// custom gate to constrain random linear combinations used by Freivalds' argument
|
||||
rlc_gates: Vec<RLCConfig<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
///
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
|
||||
let dummy_custom_gate = EinsumOpConfig {
|
||||
inputs: [
|
||||
dummy_var.clone(),
|
||||
dummy_var.clone(),
|
||||
dummy_var.clone(),
|
||||
dummy_var.clone(),
|
||||
],
|
||||
output: dummy_var.clone(),
|
||||
selectors: BTreeMap::default(),
|
||||
_marker: PhantomData,
|
||||
};
|
||||
Self {
|
||||
custom_gate: dummy_custom_gate,
|
||||
rlc_gates: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn challenges(&self) -> Vec<Challenge> {
|
||||
self.rlc_gates
|
||||
.iter()
|
||||
.map(|gate| gate.challenge)
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Configure the columns based on universal Einsum analysis
|
||||
pub fn configure_universal(
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
analysis: &EinsumAnalysis,
|
||||
num_inner_cols: usize,
|
||||
logrows: usize,
|
||||
) -> Self {
|
||||
let capacity = analysis.reduction_length;
|
||||
let inputs: [VarTensor; 4] = [
|
||||
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
|
||||
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
|
||||
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
|
||||
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
|
||||
];
|
||||
let output = VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity);
|
||||
let custom_gate = EinsumOpConfig::new(meta, &inputs, &output);
|
||||
|
||||
let mut rlc_gates = vec![];
|
||||
for _ in 0..analysis.max_num_output_axes {
|
||||
let rlc_gate = RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &output);
|
||||
rlc_gates.push(rlc_gate);
|
||||
}
|
||||
|
||||
Self {
|
||||
custom_gate,
|
||||
rlc_gates,
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn assign_einsum(
|
||||
&self,
|
||||
region: &mut RegionCtx<F>,
|
||||
input_tensors: &[&ValTensor<F>],
|
||||
output_tensor: &ValTensor<F>,
|
||||
equation: &str,
|
||||
) -> Result<(), CircuitError> {
|
||||
region.set_num_einsum_inner_cols(self.custom_gate.output.num_inner_cols());
|
||||
|
||||
let (input_exprs, _) = equation.split_once("->").unwrap();
|
||||
let input_exprs = input_exprs.split(",").collect_vec();
|
||||
assert_eq!(input_exprs.len(), input_tensors.len());
|
||||
|
||||
let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
|
||||
let mut output_tensor = output_tensor.clone();
|
||||
|
||||
// Remove trivial axes from tensors
|
||||
input_tensors
|
||||
.iter_mut()
|
||||
.map(|tensor| tensor.remove_trivial_axes())
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
output_tensor.remove_trivial_axes()?;
|
||||
|
||||
let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
|
||||
input_exprs
|
||||
.iter()
|
||||
.zip(input_tensors.iter())
|
||||
.for_each(|(indices, tensor)| {
|
||||
let tensor_dim = tensor.dims();
|
||||
indices
|
||||
.chars()
|
||||
.zip(tensor_dim.iter())
|
||||
.for_each(|(index, dim)| {
|
||||
if let std::collections::hash_map::Entry::Vacant(e) =
|
||||
input_axes_to_dim.entry(index)
|
||||
{
|
||||
e.insert(*dim);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?;
|
||||
let equation = equation_analysis.equation;
|
||||
|
||||
let output_shape = equation_analysis
|
||||
.output_indices
|
||||
.iter()
|
||||
.map(|c| input_axes_to_dim.get(c).copied().unwrap())
|
||||
.collect_vec();
|
||||
let squashed_output = self.assign_output(region, &output_tensor, output_shape)?;
|
||||
|
||||
// reorder the reduction of input tensors and reduce
|
||||
let reordered_input_reductions = reduction_planner::input_reductions(&equation).unwrap();
|
||||
let mut tensors = input_tensors;
|
||||
|
||||
for reduction in reordered_input_reductions.iter() {
|
||||
let (input_expr, output_expr) = reduction.expression().split_once("->").unwrap();
|
||||
let input_exprs = input_expr.split(",").collect_vec();
|
||||
|
||||
let remaining_axes = output_expr.chars().collect_vec();
|
||||
let mut remaining_axes_indices = remaining_axes
|
||||
.iter()
|
||||
.map(|c| 0..input_axes_to_dim[c])
|
||||
.multi_cartesian_product()
|
||||
.collect_vec();
|
||||
|
||||
// Dummy value to ensure the for loop runs at least once
|
||||
if remaining_axes.is_empty() {
|
||||
remaining_axes_indices.push(vec![]);
|
||||
}
|
||||
|
||||
let input_tensors = reduction
|
||||
.input_indices()
|
||||
.iter()
|
||||
.map(|idx| tensors[*idx].clone())
|
||||
.collect_vec();
|
||||
|
||||
let mut flattened_input_tensors: Vec<Vec<ValTensor<F>>> =
|
||||
vec![vec![]; input_tensors.len()];
|
||||
for remaining_axes_indices in remaining_axes_indices {
|
||||
// corresponds to 1 running sum of input tensors
|
||||
for (i, (input_tensor, input_expr)) in
|
||||
input_tensors.iter().zip(input_exprs.iter()).enumerate()
|
||||
{
|
||||
let mut sliced_dim = vec![];
|
||||
input_expr.chars().for_each(|axis| {
|
||||
if let Some(pos) = remaining_axes.iter().position(|c| *c == axis) {
|
||||
sliced_dim
|
||||
.push(remaining_axes_indices[pos]..remaining_axes_indices[pos] + 1);
|
||||
} else {
|
||||
// common axis
|
||||
sliced_dim.push(0..input_axes_to_dim[&axis]);
|
||||
}
|
||||
});
|
||||
let mut sliced_input_tensor = input_tensor.get_slice(&sliced_dim)?;
|
||||
sliced_input_tensor.flatten();
|
||||
flattened_input_tensors[i].push(sliced_input_tensor);
|
||||
}
|
||||
}
|
||||
let flattened_input_tensors = flattened_input_tensors
|
||||
.into_iter()
|
||||
.map(|tensors| {
|
||||
ValTensor::from(
|
||||
tensors
|
||||
.into_iter()
|
||||
.flat_map(|t| t.get_inner_tensor().unwrap().clone().into_iter())
|
||||
.collect_vec(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let output_dims = output_expr
|
||||
.chars()
|
||||
.map(|c| input_axes_to_dim[&c])
|
||||
.collect_vec();
|
||||
|
||||
let contracted_output = match reduction {
|
||||
Reduction::RLC {
|
||||
axis,
|
||||
input_phase,
|
||||
challenge_index,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(flattened_input_tensors.len(), 1);
|
||||
let rlc_len = input_axes_to_dim[axis];
|
||||
let mut result = self.rlc_gates[*challenge_index].assign_rlc(
|
||||
region,
|
||||
&flattened_input_tensors[0],
|
||||
region.challenges()[*challenge_index],
|
||||
rlc_len,
|
||||
*input_phase,
|
||||
)?;
|
||||
result.reshape(&output_dims)?;
|
||||
result
|
||||
}
|
||||
Reduction::Contraction {
|
||||
axis, input_phases, ..
|
||||
} => match axis {
|
||||
Some(axis) => {
|
||||
let dot_product_len = input_axes_to_dim[axis];
|
||||
assign_input_contraction(
|
||||
&self.custom_gate,
|
||||
region,
|
||||
flattened_input_tensors,
|
||||
dot_product_len,
|
||||
&output_dims,
|
||||
input_phases,
|
||||
)?
|
||||
}
|
||||
None => {
|
||||
let mut result = assign_pairwise_mult(
|
||||
&self.custom_gate,
|
||||
region,
|
||||
flattened_input_tensors,
|
||||
input_phases,
|
||||
)?;
|
||||
result.reshape(&output_dims)?;
|
||||
result
|
||||
}
|
||||
},
|
||||
};
|
||||
tensors.push(contracted_output);
|
||||
}
|
||||
tensors.retain(|tensor| tensor.is_singleton());
|
||||
|
||||
let scalars: ValTensor<F> = tensors
|
||||
.into_iter()
|
||||
.map(|t| t.get_inner_tensor().unwrap().get_scalar())
|
||||
.collect_vec()
|
||||
.into();
|
||||
let squashed_input = prod(&self.custom_gate, region, &[&scalars], 1)?;
|
||||
|
||||
region.constrain_equal(&squashed_input, &squashed_output)
|
||||
}

    fn assign_output(
        &self,
        region: &mut RegionCtx<F>,
        output: &ValTensor<F>,
        mut output_shape: Vec<usize>,
    ) -> Result<ValTensor<F>, CircuitError> {
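        // Fold the claimed output tensor axis-by-axis: each output axis is
        // collapsed by an RLC against its challenge, leaving a single scalar
        // that the input-side reductions must reproduce.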
        let mut intermediate_values = output.clone();

        let challenges = region
            .challenges()
            .iter()
            .take(output_shape.len())
            .copied()
            .collect_vec();
        // Intermediate values output from the previous reduction
        // Loop over the output axes
        for (idx, (rlc_config, challenge)) in self
            .rlc_gates
            .iter()
            .take(output_shape.len())
            .zip(challenges)
            .rev()
            .enumerate()
        {
            let rlc_len = output_shape.last().copied().unwrap();
            intermediate_values.flatten();
            let phase = if idx > 0 { 1 } else { 0 };
            intermediate_values =
                rlc_config.assign_rlc(region, &intermediate_values, challenge, rlc_len, phase)?;
            output_shape.pop();
        }

        Ok(intermediate_values)
    }
}

fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    input_phases: &[usize],
) -> Result<ValTensor<F>, CircuitError> {
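    // Fold the operands together with element-wise multiplications, tracking the
    // maximum phase seen so far so each pairwise gate is queried in the correct
    // phase.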
    assert_eq!(flattened_tensors.len(), input_phases.len());
    let (result, _) = flattened_tensors
        .into_iter()
        .zip(input_phases.iter().cloned())
        .reduce(|(acc, acc_phase), (input, phase)| {
            (
                pairwise(
                    config,
                    region,
                    &[&acc, &input],
                    BaseOp::Mult,
                    &[acc_phase, phase],
                )
                .unwrap(),
                std::cmp::max(acc_phase, phase),
            )
        })
        .unwrap();
    Ok(result)
}

fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &EinsumOpConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    dot_product_len: usize,
    output_shape: &[usize],
    input_phases: &[usize],
) -> Result<ValTensor<F>, CircuitError> {
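    // One dot product of length `dot_product_len` per output cell: a single
    // operand degenerates to a sum, two operands to a dot product, and more to a
    // multi-operand dot product.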
    assert_eq!(flattened_tensors.len(), input_phases.len());
    let num_dot_products = output_shape.iter().product();
    let mut dot_product_results = vec![];
    for chunk_idx in 0..num_dot_products {
        let start = chunk_idx * dot_product_len;
        let tensors: Vec<_> = flattened_tensors
            .iter()
            .map(|tensor| tensor.get_slice(&[start..(start + dot_product_len)]))
            .collect::<Result<Vec<_>, TensorError>>()?;
        let result = if tensors.len() == 1 {
            sum(config, region, &[&tensors[0]], input_phases[0])?
        } else if tensors.len() == 2 {
            dot(
                config,
                region,
                &[&tensors[0], &tensors[1]],
                &[input_phases[0], input_phases[1]],
            )?
        } else {
            multi_dot(
                config,
                region,
                tensors.iter().collect_vec().as_slice(),
                input_phases,
            )?
        };
        dot_product_results.push(result.get_inner_tensor()?.get_scalar());
    }
    let mut tensor = ValTensor::from(dot_product_results);
    tensor.reshape(output_shape)?;
    Ok(tensor)
}

/// `EinsumOpConfig` is the custom gate used for einsum contraction operations
#[derive(Clone, Debug, Default)]
struct EinsumOpConfig<F: PrimeField + TensorType + PartialOrd> {
    // [phase 0, phase 0, phase 1, phase 1]
    inputs: [VarTensor; 4],
    // phase 1
    output: VarTensor,
    // (phase, BaseOp, block index, inner column index) -> selector
    selectors: BTreeMap<(usize, BaseOp, usize, usize), Selector>,
    _marker: PhantomData<F>,
}
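
// Gate families laid out below: `Mult` gates constrain one element-wise product
// per inner column, the accumulating `Dot`/`Sum` gates constrain running sums
// over a whole block, and the `CumProd` gates on the last block chain together
// the final product of reduction scalars.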

impl<F: PrimeField + TensorType + PartialOrd> EinsumOpConfig<F> {
    fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 4], output: &VarTensor) -> Self {
        let mut selectors = BTreeMap::new();
        for phase in [0, 1] {
            for i in 0..output.num_blocks() {
                for j in 0..output.num_inner_cols() {
                    selectors.insert((phase, BaseOp::Mult, i, j), meta.selector());
                }
            }
        }

        for phase in [0, 1] {
            for i in 0..output.num_blocks() {
                selectors.insert((phase, BaseOp::DotInit, i, 0), meta.selector());
                selectors.insert((phase, BaseOp::Dot, i, 0), meta.selector());
                selectors.insert((phase, BaseOp::SumInit, i, 0), meta.selector());
                selectors.insert((phase, BaseOp::Sum, i, 0), meta.selector());
            }
        }
        selectors.insert(
            (1, BaseOp::CumProdInit, output.num_blocks() - 1, 0),
            meta.selector(),
        );
        selectors.insert(
            (1, BaseOp::CumProd, output.num_blocks() - 1, 0),
            meta.selector(),
        );
        for ((phase, base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
            match base_op {
                BaseOp::Mult => {
                    meta.create_gate(base_op.as_str(), |meta| {
                        let selector = meta.query_selector(*selector);

                        let zero = Expression::<F>::Constant(F::ZERO);
                        let mut qis = vec![zero; 4];
                        for (i, q_i) in qis
                            .iter_mut()
                            .enumerate()
                            .skip(*phase * 2)
                            .take(base_op.num_inputs())
                        {
                            *q_i = inputs[i]
                                .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
                                .expect("einsum op config: input query failed")[0]
                                .clone()
                        }
                        // Get output expressions for each input channel
                        let (rotation_offset, rng) = base_op.query_offset_rng();
                        let constraints = {
                            let expected_output: Tensor<Expression<F>> = output
                                .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
                                .expect("einsum op config: output query failed");

                            let res = base_op
                                .nonaccum_f((qis[2 * *phase].clone(), qis[2 * *phase + 1].clone()));
                            vec![expected_output[base_op.constraint_idx()].clone() - res]
                        };
                        Constraints::with_selector(selector, constraints)
                    });
                }
                _ => {
                    meta.create_gate(base_op.as_str(), |meta| {
                        let selector = meta.query_selector(*selector);
                        let mut qis = vec![vec![]; 4];
                        for (i, q_i) in qis
                            .iter_mut()
                            .enumerate()
                            .skip(*phase * 2)
                            .take(base_op.num_inputs())
                        {
                            *q_i = inputs[i]
                                .query_whole_block(meta, *block_idx, 0, 1)
                                .expect("einsum op config: input query failed")
                                .into_iter()
                                .collect()
                        }
                        // Get output expressions for each input channel
                        let (rotation_offset, rng) = base_op.query_offset_rng();
                        let expected_output: Tensor<Expression<F>> = output
                            .query_rng(meta, *block_idx, 0, rotation_offset, rng)
                            .expect("einsum op config: output query failed");

                        let res = base_op.accum_f(
                            expected_output[0].clone(),
                            qis[2 * *phase + 1].clone(),
                            qis[2 * *phase].clone(),
                        );
                        let constraints =
                            vec![expected_output[base_op.constraint_idx()].clone() - res];

                        Constraints::with_selector(selector, constraints)
                    });
                }
            }
        }

        Self {
            inputs: inputs.clone(),
            output: output.clone(),
            selectors,
            _marker: PhantomData,
        }
    }
}

/// `RLCConfig` is the custom gate used for random linear combination with the specific challenge
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
    pub challenge: Challenge,
    /// [phase 0, phase 1]
    pub inputs: [VarTensor; 2],
    pub output: VarTensor,
    /// (phase of input, block index) -> (init selector, acc selector)
    pub selectors: BTreeMap<(usize, usize), (Selector, Selector)>,
    _marker: PhantomData<F>,
}
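
// Sketch of the constraint: with block width w and challenge r, a block of
// inputs x_1..x_w contributes sum_i r^(w - i + 1) * x_i. The `init` selector
// pins the first running value to that sum, and the `acc` selector enforces
// acc_k = acc_{k-1} * r^w + sum_i r^(w - i + 1) * x_i, i.e. a blockwise Horner
// evaluation of the random linear combination.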

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
    fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 2], output: &VarTensor) -> Self {
        let challenge = meta.challenge_usable_after(FirstPhase);

        let mut selectors = BTreeMap::new();
        for (phase, input) in inputs.iter().enumerate() {
            for block_idx in 0..input.num_blocks() {
                let selector = (meta.selector(), meta.selector());
                selectors.insert((phase, block_idx), selector);
            }
        }
        let block_width = output.num_inner_cols();
        let powers_of_challenge = (0..block_width)
            .scan(Expression::Constant(F::ONE), |r_power, _| {
                *r_power = r_power.clone() * challenge.expr();
                Some(r_power.clone())
            })
            .collect_vec();
        for ((phase, block_idx), (init_selector, acc_selector)) in selectors.iter() {
            meta.create_gate("init", |meta| {
                let selector = meta.query_selector(*init_selector);
                let input_exprs = inputs[*phase]
                    .query_whole_block(meta, *block_idx, 0, 1)
                    .expect("rlc config: input query failed")
                    .into_iter()
                    .collect();
                let constraints = {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, 0, 0, 1)
                        .expect("rlc config: output query failed");

                    let res = BaseOp::Dot.accum_f(
                        Expression::Constant(F::ZERO),
                        powers_of_challenge.iter().cloned().rev().collect_vec(),
                        input_exprs,
                    );
                    vec![expected_output[0].clone() - res]
                };
                Constraints::with_selector(selector, constraints)
            });
            meta.create_gate("acc", |meta| {
                let selector = meta.query_selector(*acc_selector);
                let input_exprs = inputs[*phase]
                    .query_whole_block(meta, *block_idx, 0, 1)
                    .expect("rlc config: input query failed")
                    .into_iter()
                    .collect();
                let constraints = {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, 0, -1, 2)
                        .expect("rlc config: output query failed");

                    let res = BaseOp::Dot.accum_f(
                        expected_output[0].clone() * powers_of_challenge.last().cloned().unwrap(),
                        powers_of_challenge.iter().cloned().rev().collect_vec(),
                        input_exprs,
                    );
                    vec![expected_output[1].clone() - res]
                };
                Constraints::with_selector(selector, constraints)
            });
        }
        Self {
            inputs: inputs.clone(),
            output: output.clone(),
            selectors,
            challenge,
            _marker: PhantomData,
        }
    }

    fn assign_rlc(
        &self,
        region: &mut RegionCtx<F>,
        flattened_input: &ValTensor<F>,
        challenge: Value<F>,
        rlc_len: usize,
        phase: usize,
    ) -> Result<ValTensor<F>, CircuitError> {
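        // Witness-side mirror of the gates above: compute blockwise running sums
        // with powers of the challenge, assign the inputs (unconstrained) and the
        // running sums (constrained), then enable the init/acc selectors row by
        // row.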
        region.flush_einsum()?;
        let block_width = self.output.num_inner_cols();
        let powers_of_challenge = (0..block_width)
            .scan(Value::known(F::ONE), |challenge_power, _| {
                *challenge_power = challenge_power.clone() * challenge;
                Some(challenge_power.clone())
            })
            .collect_vec();
        let mut rlc_results: Vec<ValType<F>> = vec![];
        for tensor in flattened_input.get_inner_tensor()?.chunks_exact(rlc_len) {
            let running_sums = tensor
                .iter()
                .chunks(block_width)
                .into_iter()
                .scan(Value::known(F::ZERO), |state, val| {
                    let curr_sum: Value<F> = val
                        .into_iter()
                        .zip(powers_of_challenge.iter().rev())
                        .map(|(v, c_power)| {
                            c_power.and_then(|c_power| {
                                Value::known(c_power * v.get_felt_eval().unwrap())
                            })
                        })
                        .reduce(|acc, v| acc + v)
                        .unwrap();
                    *state = *state * powers_of_challenge.last().unwrap() + curr_sum;
                    Some(*state)
                })
                .collect_vec();

            let assigned_len = {
                let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
                input.pad_to_zero_rem(block_width, ValType::Value(Value::known(F::ZERO)))?;
                let (_, len) = region
                    .assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
                len
            };
            let (assigned_output, assigned_output_len) = {
                let running_sums = running_sums.into_iter().map(ValType::from).collect_vec();
                region.assign_einsum_with_duplication_constrained(
                    &self.output,
                    &running_sums.into(),
                    &crate::circuit::CheckMode::UNSAFE,
                )?
            };

            (0..assigned_output_len)
                .map(|i| {
                    let (block_idx, _, z) = self
                        .output
                        .cartesian_coord(region.einsum_col_coord() + i * block_width);
                    if z == 0 && i > 0 {
                        return Ok(());
                    }
                    let selector = if i == 0 {
                        self.selectors
                            .get(&(phase, block_idx))
                            .map(|(init, _)| init)
                    } else {
                        self.selectors.get(&(phase, block_idx)).map(|(_, acc)| acc)
                    };
                    region.enable(selector, z)?;
                    Ok(())
                })
                .collect::<Result<Vec<_>, CircuitError>>()?;
            rlc_results.push(assigned_output.last()?.get_inner_tensor()?.get_scalar());

            region.increment_einsum_col_coord(assigned_len);
        }
        Ok(rlc_results.into())
    }
}
191 src/circuit/ops/chip/einsum/reduction_planner.rs Normal file
@@ -0,0 +1,191 @@
use std::{collections::BTreeSet, ops::Index};

use halo2curves::ff::PrimeField;
use itertools::Itertools;

use crate::{
    circuit::CircuitError,
    tensor::{TensorType, ValTensor},
};

/// inj,jk->ik            [inj,jk]
/// inj,i->nj  => RLC         [jk,nj]
/// jk,k->j    => RLC         [nj,j]
/// nj,j->n    => Contraction [n]
/// n->        => Contraction []
///
/// bn,anm,bm->ba         [bn,anm,bm]
/// bn,bm->bnm => Contraction [anm,bnm]
/// bnm,b->nm  => RLC         [anm,nm]
/// anm,a->nm  => RLC         [nm,nm]
/// nm,nm->m   => Contraction [m]
/// m->        => Contraction []
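///
/// Each step either reduces an axis by contraction (a shared output axis is
/// first merged by pairwise multiplication, a non-output axis is summed away)
/// or folds a surviving output axis into a verifier challenge via a random
/// linear combination (RLC); the bracketed list shows the input expressions
/// that remain after the step.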

#[derive(Debug)]
pub enum Reduction {
    /// Random linear combination with powers of challenge along the axis
    RLC {
        expression: String,
        axis: char,
        /// Uniquely identifying index of input tensor to be reduced
        input_index: TensorIndex,
        /// phase of input tensor
        input_phase: usize,
        challenge_index: usize,
    },
    Contraction {
        expression: String,
        /// when axis is `None`, the contraction is pairwise multiplication
        axis: Option<char>,
        /// Uniquely identifying indices of input tensors to be contracted
        input_indices: Vec<TensorIndex>,
        /// phases of input tensors
        input_phases: Vec<usize>,
    },
}

#[derive(Clone, Copy, Debug)]
pub struct TensorIndex(usize);

impl<T: PrimeField + TensorType + PartialOrd> Index<TensorIndex> for Vec<ValTensor<T>> {
    type Output = ValTensor<T>;

    fn index(&self, index: TensorIndex) -> &Self::Output {
        &self[index.0]
    }
}

impl Reduction {
    pub fn expression(&self) -> &str {
        match self {
            Reduction::Contraction { expression, .. } => expression,
            Reduction::RLC { expression, .. } => expression,
        }
    }

    pub fn input_indices(&self) -> Vec<TensorIndex> {
        match self {
            Reduction::Contraction { input_indices, .. } => input_indices.clone(),
            Reduction::RLC { input_index, .. } => vec![*input_index],
        }
    }
}

pub fn input_reductions(expression: &str) -> Result<Vec<Reduction>, CircuitError> {
    let (input_exprs, output_expr) = expression
        .split_once("->")
        .ok_or(CircuitError::InvalidEinsum)?;
    let input_exprs: Vec<_> = input_exprs.split(',').map(|eq| eq.to_string()).collect();
    // (phase, expression)
    let input_exprs: Vec<(usize, String)> =
        input_exprs.into_iter().map(|expr| (0, expr)).collect_vec();

    let mut input_tensor_counter = input_exprs.len();
    let mut input_exprs: Vec<((usize, String), TensorIndex)> = input_exprs
        .into_iter()
        .zip((0..input_tensor_counter).map(TensorIndex))
        .collect();
    let mut reductions: Vec<Reduction> = vec![];

    // Reduce input_exprs along given axis
    let mut reduce = |input_exprs: Vec<((usize, String), TensorIndex)>,
                      axis: char|
     -> (Reduction, Vec<((usize, String), TensorIndex)>) {
        let inputs = input_exprs
            .iter()
            .filter(|((_, eq), _)| eq.chars().contains(&axis))
            .cloned()
            .collect_vec();
        let (inputs_axes, input_indices): (Vec<(usize, String)>, Vec<TensorIndex>) =
            inputs.iter().cloned().unzip();
        let (input_phases, inputs_axes): (Vec<usize>, Vec<String>) =
            inputs_axes.into_iter().unzip();

        let is_output_axis = output_expr.chars().contains(&axis);
        let output: String = if is_output_axis && inputs.len() > 1 {
            let output: BTreeSet<char> =
                inputs_axes.iter().flat_map(|input| input.chars()).collect();
            output.iter().collect()
        } else {
            let output: BTreeSet<char> = inputs_axes
                .iter()
                .flat_map(|input| input.chars().filter(|&c| c != axis))
                .collect();
            output.iter().collect()
        };
        let mut output_phase = input_phases.iter().copied().max().unwrap();

        let reduction = if is_output_axis && inputs.len() == 1 {
            output_phase = 1;
            let mut expression = inputs_axes.join(",");
            expression.push_str(format!(",{axis}").as_str());
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::RLC {
                expression,
                axis,
                input_index: input_indices[0],
                input_phase: input_phases[0],
                challenge_index: output_expr.chars().position(|c| c == axis).unwrap(),
            }
        } else if is_output_axis {
            let mut expression = inputs_axes.join(",");
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::Contraction {
                expression,
                axis: None,
                input_indices,
                input_phases,
            }
        } else {
            let mut expression = inputs_axes.join(",");
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::Contraction {
                expression,
                axis: Some(axis),
                input_indices,
                input_phases,
            }
        };

        // Mutate input_exprs
        let mut input_exprs = input_exprs.clone();
        input_exprs.retain(|((_, input_eq), _)| !inputs_axes.contains(input_eq));
        input_exprs.push((
            (output_phase, output.clone()),
            TensorIndex(input_tensor_counter),
        ));
        input_tensor_counter += 1;

        (reduction, input_exprs)
    };

    let mut output_axes = output_expr.chars().collect_vec();
    while let Some(axis) = output_axes.first().cloned() {
        let num_inputs = input_exprs
            .iter()
            .filter(|((_, eq), _)| eq.chars().contains(&axis))
            .count();
        if num_inputs == 0 {
            output_axes.remove(0);
        } else {
            let (reduction, new_input_exprs) = reduce(input_exprs, axis);
            reductions.push(reduction);
            input_exprs = new_input_exprs;
        }
    }

    // These are not output axes and were not contracted with random vectors
    let remaining_axes: BTreeSet<_> = input_exprs
        .iter()
        .flat_map(|((_, eq), _)| eq.chars())
        .collect();

    for axis in remaining_axes.iter() {
        let (reduction, new_input_exprs) = reduce(input_exprs, *axis);
        reductions.push(reduction);
        input_exprs = new_input_exprs;
    }

    Ok(reductions)
}
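
// A minimal sanity sketch (hypothetical test, not part of the original commit):
// for `ij,jk->ik` the planner should emit an RLC over each output axis (`i`,
// then `k`) followed by a contraction over the shared axis `j`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn plans_matmul_reductions() {
        let reductions = input_reductions("ij,jk->ik").unwrap();
        assert_eq!(reductions.len(), 3);
    }
}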
@@ -64,6 +64,9 @@ pub enum CircuitError {
    /// Missing product in einsum
    #[error("missing product in einsum")]
    MissingEinsumProduct,
    /// Missing config in einsum
    #[error("missing config in einsum")]
    MissingEinsumConfig,
    /// Mismatched lookup length
    #[error("mismatched lookup lengths: {0} and {1}")]
    MismatchedLookupLength(usize, usize),
@@ -1,8 +1,4 @@
use std::{
    collections::{HashMap, HashSet},
    f64::consts::E,
    ops::Range,
};
use std::{collections::HashMap, f64::consts::E, ops::Range};

use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
@@ -829,215 +825,31 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    inputs: &[&ValTensor<F>],
    input_tensors: &[&ValTensor<F>],
    equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
    // Track the einsum equation
    region.add_used_einsum_equation(equation.to_string())?;

    let mut equation = equation.split("->");
    let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
    let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
    let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

    // Check that the number of inputs matches the number of inputs in the equation
    if inputs.len() != inputs_eq.len() {
        return Err(TensorError::DimMismatch("einsum".to_string()).into());
    }

    let mut indices_to_size = HashMap::new();
    for (i, input) in inputs.iter().enumerate() {
        for j in 0..inputs_eq[i].len() {
            let c = inputs_eq[i]
                .chars()
                .nth(j)
                .ok_or(CircuitError::InvalidEinsum)?;
            if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) {
                e.insert(input.dims()[j]);
            } else if indices_to_size[&c] != input.dims()[j] {
                return Err(TensorError::DimMismatch("einsum".to_string()).into());
            }
        }
    }

    // maps unrepresented indices in the output to a trivial 1
    for c in output_eq.chars() {
        indices_to_size.entry(c).or_insert(1);
    }

    // Compute the output tensor shape
    let mut output_shape: Vec<usize> = output_eq
        .chars()
        .map(|c| {
            indices_to_size
                .get(&c)
                .ok_or(CircuitError::InvalidEinsum)
                .copied()
        })
        .collect::<Result<Vec<_>, _>>()?;

    if output_shape.is_empty() {
        output_shape.push(1);
    }

    // Create a new output tensor with the computed shape
    let mut output: Tensor<ValType<F>> = Tensor::new(None, &output_shape)?;

    let mut seen = HashSet::new();
    let mut common_indices_to_inputs = vec![];
    for input in inputs_eq.iter().take(inputs.len()) {
        for c in input.chars() {
            if !seen.contains(&c) {
                seen.insert(c);
            } else {
                common_indices_to_inputs.push(c);
            }
        }
    }

    let non_common_indices = indices_to_size
        .keys()
        .filter(|&x| !common_indices_to_inputs.contains(x))
        .collect::<Vec<_>>();

    let non_common_coord_size = non_common_indices
    let inputs = input_tensors
        .iter()
        .map(|d| {
            // If the current index is in the output equation, then the slice should be the current coordinate
            if output_eq.contains(**d) {
                Ok(1)
            // Otherwise, the slice should be the entire dimension of the input tensor
            } else {
                indices_to_size
                    .get(d)
                    .ok_or(CircuitError::InvalidEinsum)
                    .copied()
            }
        })
        .collect::<Result<Vec<_>, _>>()?
        .iter()
        .product::<usize>();
        .map(|t| t.get_inner())
        .collect::<Result<Vec<_>, TensorError>>()?;
    // Compute expected output using existing einsum logic
    // need to add this to ops
    let (output_tensor, _) =
        crate::tensor::ops::accumulated::einsum(equation, &inputs.iter().collect_vec())?;

    let cartesian_coord = output_shape
        .iter()
        .map(|d| 0..*d)
        .multi_cartesian_product()
        .collect::<Vec<_>>();
    config.einsums.assign_einsum(
        region,
        input_tensors,
        &output_tensor.clone().into(),
        equation,
    )?;

    // Get the indices common across input tensors
    let mut common_coord = common_indices_to_inputs
        .iter()
        .map(|d| {
            // If the current index is in the output equation, then the slice should be the current coordinate
            if output_eq.contains(*d) {
                Ok(0..1)
            // Otherwise, the slice should be the entire dimension of the input tensor
            } else {
                Ok(0..*indices_to_size.get(d).ok_or(CircuitError::InvalidEinsum)?)
            }
        })
        .collect::<Result<Vec<Range<_>>, CircuitError>>()?
        .into_iter()
        .multi_cartesian_product()
        .collect::<Vec<_>>();
    region.increment_einsum_index(1);

    // If there are no common indices, then we need to add an empty slice to force one iteration of the loop
    if common_coord.is_empty() {
        common_coord.push(vec![]);
    }

    let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
        let coord = cartesian_coord[i].clone();
        // Compute the slice of each input tensor given the current coordinate of the output tensor
        let inputs = (0..inputs.len())
            .map(|idx| {
                let mut slice = vec![];
                for (i, c) in inputs_eq[idx].chars().enumerate() {
                    // If the current index is in the output equation, then the slice should be the current coordinate
                    if let Some(idx) = output_eq.find(c) {
                        slice.push(coord[idx]..coord[idx] + 1);
                    // Otherwise, the slice should be the entire dimension of the input tensor
                    } else {
                        slice.push(0..inputs[idx].dims()[i]);
                    }
                }
                // Get the slice of the input tensor
                inputs[idx].get_slice(&slice)
            })
            .collect::<Result<Vec<_>, _>>()?;

        // in this case its just a dot product :)
        if non_common_coord_size == 1 && inputs.len() == 2 {
            Ok(dot(config, region, &[&inputs[0], &inputs[1]])?.get_inner_tensor()?[0].clone())
        } else {
            let mut prod_res = None;

            // Compute the cartesian product of all common indices
            for common_dim in &common_coord {
                let inputs = (0..inputs.len())
                    .map(|idx| {
                        let mut slice = vec![];
                        // Iterate over all indices in the input equation
                        for (i, c) in inputs_eq[idx].chars().enumerate() {
                            // If the current index is common to multiple inputs, then the slice should be the current coordinate
                            if let Some(j) = common_indices_to_inputs.iter().position(|&r| r == c) {
                                slice.push(common_dim[j]..common_dim[j] + 1);
                            } else {
                                slice.push(0..inputs[idx].dims()[i]);
                            }
                        }
                        // Get the slice of the input tensor
                        inputs[idx].get_slice(&slice).map_err(|e| {
                            error!("{}", e);
                            halo2_proofs::plonk::Error::Synthesis
                        })
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                let mut input_pairs = vec![];

                for input in &inputs {
                    input_pairs.push(input.get_inner_tensor()?.iter());
                }

                let input_pairs = input_pairs
                    .into_iter()
                    .multi_cartesian_product()
                    .collect::<Vec<_>>();

                // Compute the product of all input tensors
                for pair in input_pairs {
                    let product_across_pair = prod(config, region, &[&pair.into()])?;

                    if let Some(product) = prod_res {
                        prod_res = Some(
                            pairwise(
                                config,
                                region,
                                &[&product, &product_across_pair],
                                BaseOp::Add,
                            )
                            .map_err(|e| {
                                error!("{}", e);
                                halo2_proofs::plonk::Error::Synthesis
                            })?,
                        );
                    } else {
                        prod_res = Some(product_across_pair);
                    }
                }
            }
            Ok(prod_res
                .ok_or(CircuitError::MissingEinsumProduct)?
                .get_inner_tensor()?[0]
                .clone())
        }
    };

    region.flush()?;
    region.apply_in_loop(&mut output, inner_loop_function)?;

    let output: ValTensor<F> = output.into();
    let output: ValTensor<F> = output_tensor.into();

    Ok(output)
}
@@ -364,7 +364,15 @@ impl<
    };
    Ok(Some(if self.decomp {
        log::debug!("constraining constant to be decomp");
        super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
        super::layouts::decompose(
            config,
            region,
            &[&value],
            &region.base(),
            &region.legs(),
            false,
        )?
        .1
    } else {
        log::debug!("constraining constant to be identity");
        super::layouts::identity(config, region, &[&value])?
@@ -6,7 +6,7 @@ use crate::{
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored::Colorize;
use halo2_proofs::{
    circuit::Region,
    circuit::{Region, Value},
    plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
@@ -91,15 +91,17 @@ pub struct EinsumIndex {
    index: usize,
    col_coord: usize,
    equations: HashSet<String>,
    num_inner_cols: usize,
}

impl EinsumIndex {
    /// Create a new einsum index
    pub fn new(index: usize, col_coord: usize) -> EinsumIndex {
        EinsumIndex {
            index,
    pub fn new(index: usize, col_coord: usize, num_inner_cols: usize) -> EinsumIndex {
        EinsumIndex {
            index,
            col_coord,
            equations: HashSet::new(),
            num_inner_cols,
        }
    }

@@ -221,6 +223,7 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
    statistics: RegionStatistics,
    settings: RegionSettings,
    assigned_constants: ConstantsMap<F>,
    challenges: Vec<Value<F>>,
    max_dynamic_input_len: usize,
}

@@ -317,6 +320,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        &self.statistics
    }

    /// Get the verifier challenges available to this region
    pub fn challenges(&self) -> &[Value<F>] {
        &self.challenges
    }

    /// Create a new region context
    pub fn new(
        region: Region<'a, F>,
@@ -339,6 +347,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            statistics: RegionStatistics::default(),
            settings: RegionSettings::all_true(decomp_base, decomp_legs),
            assigned_constants: HashMap::new(),
            challenges: vec![],
            max_dynamic_input_len: 0,
        }
    }
@@ -357,6 +366,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        new_self
    }

    /// Create a new region context with challenges
    pub fn new_with_challenges(
        region: Region<'a, F>,
        row: usize,
        num_inner_cols: usize,
        decomp_base: usize,
        decomp_legs: usize,
        challenges: Vec<Value<F>>,
    ) -> RegionCtx<'a, F> {
        let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
        new_self.challenges = challenges;
        new_self
    }

    /// Create a new region context
    pub fn new_dummy(
        row: usize,
@@ -377,6 +400,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            statistics: RegionStatistics::default(),
            settings,
            assigned_constants: HashMap::new(),
            challenges: vec![],
            max_dynamic_input_len: 0,
        }
    }
@@ -400,6 +424,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            statistics: RegionStatistics::default(),
            settings,
            assigned_constants: HashMap::new(),
            challenges: vec![],
            max_dynamic_input_len: 0,
        }
    }
@@ -635,6 +660,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        self.einsum_index.equations.clone()
    }

    /// set the number of inner columns used in einsum custom gate
    pub fn set_num_einsum_inner_cols(&mut self, num_inner_cols: usize) {
        self.einsum_index.num_inner_cols = num_inner_cols;
    }

    /// number of inner columns used in einsum custom gate
    pub fn num_einsum_inner_cols(&self) -> usize {
        self.einsum_index.num_inner_cols
    }

    /// get used lookups
    pub fn used_lookups(&self) -> HashSet<LookupOp> {
        self.statistics.used_lookups.clone()
@@ -724,6 +759,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        self.assign_dynamic_lookup(var, values)
    }

    /// Assign a valtensor to a vartensor in einsum area
    pub fn assign_einsum(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, CircuitError> {
        if let Some(region) = &self.region {
            Ok(var.assign(
                &mut region.borrow_mut(),
                self.einsum_col_coord(),
                values,
                &mut self.assigned_constants,
            )?)
        } else {
            if !values.is_instance() {
                let values_map = values.create_constants_map_iterator();
                self.assigned_constants.par_extend(values_map);
            }
            Ok(values.clone())
        }
    }

    /// Assign a valtensor to a vartensor with duplication
    pub fn assign_with_duplication_unconstrained(
        &mut self,
@@ -781,6 +838,63 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        }
    }

    /// Assign a valtensor to a vartensor with duplication
    pub fn assign_einsum_with_duplication_unconstrained(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len) = var.assign_with_duplication_unconstrained(
                &mut region.borrow_mut(),
                self.einsum_col_coord(),
                values,
                &mut self.assigned_constants,
            )?;
            Ok((res, len))
        } else {
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.einsum_col_coord(),
                values,
                false,
                &mut self.assigned_constants,
            )?;
            Ok((values.clone(), len))
        }
    }

    /// Assign a valtensor to a vartensor with duplication
    pub fn assign_einsum_with_duplication_constrained(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
        check_mode: &crate::circuit::CheckMode,
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len) = var.assign_with_duplication_constrained(
                &mut region.borrow_mut(),
                self.row,
                self.einsum_col_coord(),
                values,
                check_mode,
                &mut self.assigned_constants,
            )?;
            Ok((res, len))
        } else {
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.einsum_col_coord(),
                values,
                true,
                &mut self.assigned_constants,
            )?;
            Ok((values.clone(), len))
        }
    }

    /// Enable a selector
    pub fn enable(&mut self, selector: Option<&Selector>, offset: usize) -> Result<(), Error> {
        match &self.region {
@@ -847,4 +961,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        }
        Ok(())
    }

    /// flush row to the next row in einsum area
    pub fn flush_einsum(&mut self) -> Result<(), CircuitError> {
        // increment by the difference between the current linear coord and the next row
        let num_einsum_inner_cols = self.num_einsum_inner_cols();
        let remainder = self.einsum_col_coord() % num_einsum_inner_cols;
        if remainder != 0 {
            let diff = num_einsum_inner_cols - remainder;
            self.increment_einsum_col_coord(diff);
        }
        if self.einsum_col_coord() % num_einsum_inner_cols != 0 {
            return Err(CircuitError::FlushError);
        }
        Ok(())
    }
}
@@ -171,6 +171,7 @@ pub mod pfsys;
pub mod srs_sha;
/// An implementation of multi-dimensional tensors.
pub mod tensor;

#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -480,6 +480,13 @@ impl<T: Clone + TensorType> Tensor<T> {
        self[index].clone()
    }

    /// Extracts a single value from the tensor
    pub fn get_scalar(&self) -> T {
        assert!(self.inner.len() == 1);
        assert!(self.dims.iter().all(|dim| *dim == 1));
        self.inner[0].clone()
    }

    /// Get a mutable array index from rows / columns indices.
    ///
    /// ```
@@ -901,6 +908,22 @@ impl<T: Clone + TensorType> Tensor<T> {
        Ok(())
    }

    /// remove axes that have dimensions 1
    /// ```
    /// use ezkl::tensor::Tensor;
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
    /// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap();
    /// let b = a.remove_trivial_axes().unwrap();
    /// assert_eq!(b, expected);
    /// ```
    pub fn remove_trivial_axes(&self) -> Result<Self, TensorError> {
        let mut result = self.clone();
        let new_dims: Vec<_> = self.dims.iter().copied().filter(|dim| *dim > 1).collect();
        result.reshape(&new_dims)?;
        Ok(result)
    }

    /// Move axis of the tensor
    /// ```
    /// use ezkl::tensor::Tensor;
@@ -5,6 +5,7 @@ use crate::{
};
use itertools::Itertools;
use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator};
use std::collections::{HashMap, HashSet};
pub use std::ops::{Add, Mul, Neg, Sub};

#[derive(Debug, Clone, PartialEq, thiserror::Error)]
@@ -2396,6 +2397,8 @@ pub mod nonlinearities {

/// Ops that return the transcript i.e intermediate calcs of an op
pub mod accumulated {
    use maybe_rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};

    use super::*;

    /// Dot product of two tensors.
@@ -2523,4 +2526,320 @@ pub mod accumulated {

        Ok(transcript)
    }

    #[inline]
    fn row_major_strides(dims: &[usize]) -> Vec<usize> {
        let mut s = vec![0; dims.len()];
        let mut acc = 1;
        for (i, &d) in dims.iter().enumerate().rev() {
            s[i] = acc;
            acc *= d;
        }
        s
    }
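
    // e.g. row_major_strides(&[2, 3, 4]) == vec![12, 4, 1]: the last axis is
    // contiguous and each earlier stride is the product of the later dims.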

    /// # Examples
    /// ```
    /// use ezkl::tensor::Tensor;
    /// use ezkl::fieldutils::IntegerRep;
    /// use ezkl::tensor::ops::accumulated::einsum;
    ///
    /// // matmul case
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[2, 1, 2, 1, 1, 1]),
    ///     &[2, 3],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[2, 3, 2, 1, 1, 1]),
    ///     &[3, 2],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("ij,jk->ik", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // element wise multiplication
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
    ///     &[3, 3],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("ij,ij->ij", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // dot product of A with the transpose of B.
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
    ///     &[3, 3],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("ik,jk->ij", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // row-wise dot product
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
    ///     &[3, 3],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("ik,ik->i", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[3]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // dot product
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3]),
    ///     &[3],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3]),
    ///     &[3],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("i,i->", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[14]), &[1]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // batched 3D-2D contraction
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3, 2],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5, 7, 8]),
    ///     &[2, 2],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("anm,bm->ba", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // three-input contraction
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3, 2],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5, 7, 8]),
    ///     &[2, 2],
    /// ).unwrap();
    /// let z = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5, 7, 8, 9, 9]),
    ///     &[2, 3],
    /// ).unwrap();
    ///
    /// let (result, _) = einsum::<IntegerRep>("bn,anm,bm->ba", &[&z, &x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // contraction with a single common axis
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3, 2],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5, 7, 8]),
    ///     &[2, 2],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("abc,cd->", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[648]), &[1]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // contraction with no common axes (outer product)
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
    ///     &[3, 3, 2],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5, 7, 8]),
    ///     &[2, 2],
    /// ).unwrap();
    /// let (result, _) = einsum::<IntegerRep>("abc,ed->", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1296]), &[1]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// // trivial axes mapping
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5, 7, 8]),
    ///     &[2, 2],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[4, 5]),
    ///     &[2],
    /// ).unwrap();
    ///
    /// let (result, _) = einsum::<IntegerRep>("mk,k->m", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// let (result, _) = einsum::<IntegerRep>("mk,k->mn", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2, 1]).unwrap();
    /// assert_eq!(result, expected);
    ///
    /// let x = Tensor::<IntegerRep>::new(
    ///     Some(&[0, 0, 0, 3]),
    ///     &[1, 4],
    /// ).unwrap();
    /// let k = Tensor::<IntegerRep>::new(
    ///     Some(&[213, 227, 74, 77]),
    ///     &[4],
    /// ).unwrap();
    ///
    /// let (result, _) = einsum::<IntegerRep>("mk,k->ma", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[231]), &[1, 1]).unwrap();
    /// assert_eq!(result, expected);
    /// // subtle difference: labelling the second operand `n` instead of `k`
    /// // means the operands no longer share an index, so the result is the
    /// // product of their independent sums.
    /// let (result, _) = einsum::<IntegerRep>("mk,n->ma", &[&x, &k]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
    /// assert_eq!(result, expected);
    /// ```
    ///
    pub fn einsum<T>(
        equation: &str,
        input_tensors: &[&Tensor<T>],
    ) -> Result<(Tensor<T>, HashMap<char, usize>), TensorError>
    where
        T: Clone + TensorType + Mul<Output = T> + Add<Output = T> + Send + Sync,
    {
        let (input_exprs, output_expr) = equation.split_once("->").unwrap();
        let input_exprs: Vec<&str> = input_exprs.split(',').collect();
        assert_eq!(input_exprs.len(), input_tensors.len());

        let mut dim_of: HashMap<char, usize> = HashMap::new();
        for (input_expr, t) in input_exprs.iter().zip(input_tensors.iter()) {
            for (c, &d) in input_expr.chars().zip(t.dims().iter()) {
                let e = dim_of.entry(c).or_insert(d);
                debug_assert!((*e == d) || (*e == 1) || (d == 1));
                *e = (*e).max(d);
            }
        }
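
        // Size-1 axes broadcast against larger ones (hence the `max` above);
        // their strides are zeroed later so the single element is reused along
        // the broadcast axis.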
        // Output dims
        let out_idx: Vec<char> = output_expr.chars().collect();
        let out_dims: Vec<usize> = out_idx.iter().map(|c| *dim_of.get(c).unwrap_or(&1)).collect();

        // Reduction indices
        let all_idx: HashSet<char> = dim_of.keys().copied().collect();
        let out_set: HashSet<char> = out_idx.iter().copied().collect();
        let red_idx: Vec<char> = all_idx.difference(&out_set).copied().collect();
        let red_dims: Vec<usize> = red_idx.iter().map(|c| dim_of[c]).collect();

        // Fast index->pos
        let out_pos: HashMap<char, usize> = out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
        let red_pos: HashMap<char, usize> = red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();

        // Precompute strides per input and contributions
        struct Contrib {
            out_stride: Vec<usize>,
            red_stride: Vec<usize>,
        }
        let contribs: Vec<Contrib> = input_exprs
            .iter()
            .zip(input_tensors.iter())
            .map(|(expr, t)| {
                let dims = t.dims().to_vec();
                let strides = row_major_strides(&dims);
                let mut out_stride = vec![0; out_idx.len()];
                let mut red_stride = vec![0; red_idx.len()];
                for (ax, (c, &d)) in expr.chars().zip(dims.iter()).enumerate() {
                    let s = if d == 1 { 0 } else { strides[ax] };
                    if let Some(&p) = out_pos.get(&c) {
                        out_stride[p] = s;
                    } else if let Some(&q) = red_pos.get(&c) {
                        red_stride[q] = s;
                    }
                }
                Contrib { out_stride, red_stride }
            })
            .collect();

        // Prepare output buffer
        let mut out = if out_dims.is_empty() {
            Tensor::<T>::new(None, &[1])?
        } else {
            Tensor::<T>::new(None, &out_dims)?
        };

        let out_rank = out_dims.len();
        let red_rank = red_dims.len();

        // Materialize output elements one by one
        out.par_iter_mut()
            .enumerate()
            .for_each(|(out_linear_coord, out)| {
                let mut out_index = vec![0usize; out_rank];
                {
                    let mut x = out_linear_coord;
                    for i in (0..out_rank).rev() {
                        let d = out_dims[i];
                        out_index[i] = x % d;
                        x /= d;
                    }
                }

                // Base offset per input from output coordinates
                let mut base_off = vec![0usize; input_tensors.len()];
                for (i, c) in contribs.iter().enumerate() {
                    let mut off = 0usize;
                    for p in 0..out_rank {
                        off += out_index[p] * c.out_stride[p];
                    }
                    base_off[i] = off;
                }

                let mut acc = T::zero().unwrap();

                if red_rank == 0 {
                    // No reduction -> just multiply corresponding elements
                    let mut prod = T::one().unwrap();
                    for (i, t) in input_tensors.iter().enumerate() {
                        let val = t.get_flat_index(base_off[i]);
                        prod = prod * val;
                    }
                    acc = acc + prod;
                } else {
                    // Iterate over all reduction coords
                    let red_size = red_dims.iter().product::<usize>();
                    let mut red_index = vec![0usize; red_rank];
                    for red_linear_coord in 0..red_size {
                        {
                            let mut x = red_linear_coord;
                            for q in (0..red_rank).rev() {
                                let d = red_dims[q];
                                red_index[q] = x % d;
                                x /= d;
                            }
                        }
                        let mut prod = T::one().unwrap();
                        for (i, (t, c)) in input_tensors.iter().zip(contribs.iter()).enumerate() {
                            let mut off = base_off[i];
                            for q in 0..red_rank {
                                off += red_index[q] * c.red_stride[q];
                            }
                            let val = t.get_flat_index(off);
                            prod = prod * val;
                        }
                        acc = acc + prod;
                    }
                }

                // write result
                *out = acc;
            });
        Ok((out, dim_of))
    }
}
@@ -940,6 +940,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
        Ok(())
    }

    /// remove axes that have dimensions 1
    pub fn remove_trivial_axes(&mut self) -> Result<(), TensorError> {
        match self {
            ValTensor::Value {
                inner: v, dims: d, ..
            } => {
                *v = v.remove_trivial_axes()?;
                *d = v.dims().to_vec();
            }
            ValTensor::Instance { .. } => {
                return Err(TensorError::WrongMethod);
            }
        };
        Ok(())
    }

    /// Takes a slice of the tensor along a given axis
    ///
    /// # Arguments
@@ -1,3 +1,4 @@
use halo2_proofs::plonk::SecondPhase;
use log::{debug, error, warn};

use crate::circuit::{region::ConstantsMap, CheckMode};
@@ -152,6 +153,52 @@ impl VarTensor {
        }
    }

    /// Creates a new VarTensor::Advice with standard (blinded) columns, used when
    /// the values need to be hidden in the proof.
    ///
    /// # Arguments
    /// * `cs` - The constraint system to create columns in
    /// * `logrows` - Log base 2 of the total number of rows
    /// * `num_inner_cols` - Number of columns in each inner block
    /// * `capacity` - Total number of advice cells to allocate
    ///
    /// # Returns
    /// A new VarTensor::Advice in SecondPhase with blinded columns enabled for equality constraints
    pub fn new_advice_in_second_phase<F: PrimeField>(
        cs: &mut ConstraintSystem<F>,
        logrows: usize,
        num_inner_cols: usize,
        capacity: usize,
    ) -> Self {
        let max_rows = Self::max_rows(cs, logrows);
        let max_assignments = max_rows * num_inner_cols;

        let mut modulo = (capacity / max_assignments) + 1;
        // we add a buffer for duplicated rows (we get at most 1 duplicated row per column)
        modulo = ((capacity + modulo) / max_assignments) + 1;
        let mut advices = vec![];

        if modulo > 1 {
            debug!("using column duplication for {} advice blocks", modulo - 1);
        }

        for _ in 0..modulo {
            let mut inner = vec![];
            for _ in 0..num_inner_cols {
                let col = cs.advice_column_in(SecondPhase);
                cs.enable_equality(col);
                inner.push(col);
            }
            advices.push(inner);
        }

        VarTensor::Advice {
            inner: advices,
            num_inner_cols,
            col_size: max_rows,
        }
    }

    /// Initializes fixed columns in the constraint system to support the VarTensor::Advice
    /// Fixed columns are used for constant values that are known at circuit creation time.
    ///
@@ -651,7 +698,7 @@ impl VarTensor {
    >(
        &self,
        region: &mut Region<F>,
        row: usize,
        _row: usize,
        offset: usize,
        values: &ValTensor<F>,
        check_mode: &CheckMode,
@@ -669,7 +716,7 @@ impl VarTensor {
            ValTensor::Value { inner: v, dims, .. } => {
                let duplication_freq = self.col_size();
                let num_repeats = 1;
                let duplication_offset = row;
                let (_, _, duplication_offset) = self.cartesian_coord(offset);

                // duplicates every nth element to adjust for column overflow
                let v = v