feat: accum sum (#173)

This commit is contained in:
dante
2023-03-22 08:07:27 +00:00
committed by GitHub
parent bf9311d2f4
commit f579967ad1
8 changed files with 437 additions and 25 deletions

View File

@@ -89,6 +89,10 @@ jobs:
run: cargo bench --verbose --bench add
- name: Bench pairwise add
run: cargo bench --verbose --bench pairwise_add
- name: Bench sum
run: cargo bench --verbose --bench sum
- name: Bench accum sum
run: cargo bench --verbose --bench accum_sum
docs:
runs-on: ubuntu-latest

View File

@@ -66,6 +66,14 @@ harness = false
name = "accum_dot"
harness = false
[[bench]]
name = "sum"
harness = false
[[bench]]
name = "accum_sum"
harness = false
[[bench]]
name = "add"
harness = false

107
benches/accum_sum.rs Normal file
View File

@@ -0,0 +1,107 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl_lib::circuit::accumulated::*;
use ezkl_lib::commands::TranscriptType;
use ezkl_lib::execute::create_proof_circuit_kzg;
use ezkl_lib::pfsys::{create_keys, gen_srs};
use ezkl_lib::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 16;
/// Bench circuit: a single input tensor whose accumulated sum is laid out
/// in `synthesize`, so the benchmark isolates the `Op::Sum` layout cost.
#[derive(Clone)]
struct MyCircuit {
/// the one tensor of witness values to be summed
inputs: [ValTensor<Fr>; 1],
_marker: PhantomData<Fr>,
}
impl Circuit<Fr> for MyCircuit {
type Config = BaseConfig<Fr>;
type FloorPlanner = SimpleFloorPlanner;
// Bench shortcut: reuse the witnessed circuit as the "witness-free" copy.
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
// LEN is written by the bench loop before keygen/proving triggers
// configure. NOTE(review): mutable static — fine single-threaded,
// racy if criterion ever runs this concurrently.
let len = unsafe { LEN };
let a = VarTensor::new_advice(cs, K, len, vec![len], true, 1024);
let b = VarTensor::new_advice(cs, K, len, vec![len], true, 1024);
// output gets len + 1 cells — presumably one extra row for the running
// accumulator; confirm against VarTensor::new_advice's dims semantics.
let output = VarTensor::new_advice(cs, K, len, vec![len + 1], true, 1024);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
// Lay out the accumulated sum over the single input tensor.
config
.layout(&mut layouter, &self.inputs, 0, Op::Sum)
.unwrap();
Ok(())
}
}
/// Criterion benchmark for the accumulated-sum op: measures proving-key
/// generation ("pk") and proof creation ("prove") at each input length.
fn runsum(c: &mut Criterion) {
    let mut group = c.benchmark_group("accum_sum");
    let params = gen_srs::<KZGCommitmentScheme<_>>(17);
    for &size in &[16, 512] {
        // Publish the length to the circuit's `configure` via the static.
        unsafe {
            LEN = size;
        }
        // Random witness tensor of `size` field elements.
        let input = Tensor::from((0..size).map(|_| Value::known(Fr::random(OsRng))));
        let circuit = MyCircuit {
            inputs: [ValTensor::from(input)],
            _marker: PhantomData,
        };
        group.throughput(Throughput::Elements(size as u64));
        group.bench_with_input(BenchmarkId::new("pk", size), &size, |bench, _| {
            bench.iter(|| {
                create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
                    .unwrap();
            });
        });
        // One key generated outside the timing loop for the proving bench.
        let pk =
            create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
        group.throughput(Throughput::Elements(size as u64));
        group.bench_with_input(BenchmarkId::new("prove", size), &size, |bench, _| {
            bench.iter(|| {
                create_proof_circuit_kzg(
                    circuit.clone(),
                    &params,
                    vec![],
                    &pk,
                    TranscriptType::Blake,
                    SingleStrategy::new(&params),
                )
                .unwrap();
            });
        });
    }
    group.finish();
}
// Register the bench entry point; plots enabled for criterion's HTML report.
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runsum
}
criterion_main!(benches);

111
benches/sum.rs Normal file
View File

@@ -0,0 +1,111 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl_lib::circuit::polynomial::*;
use ezkl_lib::commands::TranscriptType;
use ezkl_lib::execute::create_proof_circuit_kzg;
use ezkl_lib::pfsys::{create_keys, gen_srs};
use ezkl_lib::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 16;
/// Bench circuit: two input tensors wired into a polynomial graph whose single
/// node sums input 0 (see `configure`).
#[derive(Clone)]
struct MyCircuit {
/// witness tensors; only inputs[0] is consumed by the Sum node
inputs: [ValTensor<Fr>; 2],
_marker: PhantomData<Fr>,
}
impl Circuit<Fr> for MyCircuit {
type Config = Config<Fr>;
type FloorPlanner = SimpleFloorPlanner;
// Bench shortcut: reuse the witnessed circuit as the "witness-free" copy.
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
// LEN is written by the bench loop before configure runs (unsafe static).
let len = unsafe { LEN };
let a = VarTensor::new_advice(cs, K, len, vec![len], true, 512);
let b = VarTensor::new_advice(cs, K, len, vec![len], true, 512);
// single-cell output column for the scalar sum result
let output = VarTensor::new_advice(cs, K, len, vec![1], true, 512);
// One graph node summing input 0. NOTE(review): column `b` is declared
// but no node consumes it — confirm Config requires two input columns.
let sum_node = Node {
op: Op::Sum,
input_order: vec![InputType::Input(0)],
};
Self::Config::configure(cs, &[a, b], &output, &[sum_node])
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
// Lay out the configured graph over both witness tensors.
config.layout(&mut layouter, &self.inputs).unwrap();
Ok(())
}
}
/// Criterion benchmark for the non-accumulated sum graph: measures proving-key
/// generation ("pk") and proof creation ("prove").
fn runsum(c: &mut Criterion) {
    let mut group = c.benchmark_group("sum");
    let params = gen_srs::<KZGCommitmentScheme<_>>(17);
    for &size in &[16] {
        // Publish the length to the circuit's `configure` via the static.
        unsafe {
            LEN = size;
        }
        // Two random witness tensors of `size` field elements each.
        let left = Tensor::from((0..size).map(|_| Value::known(Fr::random(OsRng))));
        let right = Tensor::from((0..size).map(|_| Value::known(Fr::random(OsRng))));
        let circuit = MyCircuit {
            inputs: [ValTensor::from(left), ValTensor::from(right)],
            _marker: PhantomData,
        };
        group.throughput(Throughput::Elements(size as u64));
        group.bench_with_input(BenchmarkId::new("pk", size), &size, |bench, _| {
            bench.iter(|| {
                create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
                    .unwrap();
            });
        });
        // One key generated outside the timing loop for the proving bench.
        let pk =
            create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
        group.throughput(Throughput::Elements(size as u64));
        group.bench_with_input(BenchmarkId::new("prove", size), &size, |bench, _| {
            bench.iter(|| {
                create_proof_circuit_kzg(
                    circuit.clone(),
                    &params,
                    vec![],
                    &pk,
                    TranscriptType::Blake,
                    SingleStrategy::new(&params),
                )
                .unwrap();
            });
        });
    }
    group.finish();
}
// Register the bench entry point; plots enabled for criterion's HTML report.
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runsum
}
criterion_main!(benches);

View File

@@ -85,7 +85,7 @@ fn runsumpool(c: &mut Criterion) {
let params = gen_srs::<KZGCommitmentScheme<_>>(K as u32);
for size in [1, 2].iter() {
for size in [1].iter() {
unsafe {
IMAGE_HEIGHT = size * 4;
IMAGE_WIDTH = size * 4;

View File

@@ -9,7 +9,8 @@ use crate::{
ops::{
accumulated, add, affine as non_accum_affine, convolution as non_accum_conv,
dot as non_accum_dot, matmul as non_accum_matmul, mult,
scale_and_shift as ref_scale_and_shift, sub, sumpool as non_accum_sumpool,
scale_and_shift as ref_scale_and_shift, sub, sum as non_accum_sum,
sumpool as non_accum_sumpool,
},
Tensor, TensorError,
},
@@ -27,12 +28,6 @@ pub fn dot<F: FieldExt + TensorType>(
values: &[ValTensor<F>; 2],
offset: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if values.len() != config.inputs.len() {
return Err(Box::new(CircuitError::DimMismatch(
"accum dot layout".to_string(),
)));
}
let t = match layouter.assign_region(
|| "assign inputs",
|mut region| {
@@ -100,6 +95,79 @@ pub fn dot<F: FieldExt + TensorType>(
Ok(ValTensor::from(t))
}
/// Assigns variables to the regions created when calling `configure`.
/// Lays out an accumulated sum over a single input tensor and returns the
/// final (last) accumulator cell.
/// # Arguments
/// * `config` - base config holding the advice columns and per-op selectors.
/// * `layouter` - A Halo2 Layouter.
/// * `values` - The explicit values to the operations.
/// * `offset` - row offset at which assignment starts.
pub fn sum<F: FieldExt + TensorType>(
    config: &mut BaseConfig<F>,
    layouter: &mut impl Layouter<F>,
    values: &[ValTensor<F>; 1],
    offset: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let t = match layouter.assign_region(
        || "assign inputs",
        |mut region| {
            let input = utils::value_muxer(
                &config.inputs[0],
                &{
                    let res = config.inputs[0].assign(&mut region, offset, &values[0])?;
                    res.map(|e| e.value_field().evaluate())
                },
                &values[0],
            );
            // Assign the running (accumulated) sum to the output column.
            // (Previous comment said "dot product" — copy-paste from the dot
            // layout; this is the sum op.)
            let accumulated_sum = accumulated::sum(&input)
                .expect("accum poly: sum op failed")
                .into();
            let output = config
                .output
                .assign(&mut region, offset, &accumulated_sum)?;
            // First row copies the first element (Identity); every later row
            // adds the current input to the previous accumulator (Sum).
            for i in 0..input.len() {
                let (_, y) = config.inputs[0].cartesian_coord(i);
                if y == 0 {
                    config
                        .selectors
                        .get(&BaseOp::Identity)
                        .unwrap()
                        .enable(&mut region, offset + y)?;
                } else {
                    config
                        .selectors
                        .get(&BaseOp::Sum)
                        .unwrap()
                        .enable(&mut region, offset + y)?;
                }
            }
            let last_elem = output
                .get_slice(&[output.len() - 1..output.len()])
                .expect("accum poly: failed to fetch last elem");
            // In SAFE mode, cross-check the in-circuit result against the
            // non-accumulated reference sum.
            if matches!(config.check_mode, CheckMode::SAFE) {
                let safe_sum =
                    non_accum_sum(&input).map_err(|_| halo2_proofs::plonk::Error::Synthesis)?;
                assert_eq!(
                    Into::<Tensor<i32>>::into(last_elem.clone()),
                    Into::<Tensor<i32>>::into(safe_sum),
                )
            }
            // last element is the result
            Ok(last_elem)
        },
    ) {
        Ok(a) => a,
        Err(e) => {
            return Err(Box::new(e));
        }
    };
    Ok(ValTensor::from(t))
}
/// Assigns variables to the regions created when calling `configure`.
/// # Arguments
/// * `values` - The explicit values to the operations.

View File

@@ -22,9 +22,11 @@ use std::{
pub enum BaseOp {
/// accumulation step: a * b + previous accumulator
Dot,
/// first dot-product step (no prior accumulator): a * b
InitDot,
/// passes the input through unchanged
Identity,
/// pairwise addition: a + b
Add,
/// pairwise multiplication: a * b
Mult,
/// pairwise subtraction: a - b
Sub,
/// accumulated-sum step: a + previous accumulator
Sum,
}
#[allow(missing_docs)]
@@ -47,6 +49,8 @@ impl BaseOp {
BaseOp::InitDot => a * b,
BaseOp::Dot => a * b + m,
BaseOp::Add => a + b,
BaseOp::Identity => a,
BaseOp::Sum => a + m,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
}
@@ -55,41 +59,52 @@ impl BaseOp {
/// Short uppercase tag for this op; used as the gate name and by `Display`.
fn as_str(&self) -> &'static str {
    match self {
        BaseOp::Dot => "DOT",
        BaseOp::InitDot => "INITDOT",
        BaseOp::Identity => "IDENTITY",
        BaseOp::Sum => "SUM",
        BaseOp::Add => "ADD",
        BaseOp::Sub => "SUB",
        BaseOp::Mult => "MULT",
    }
}
/// Rotation window the op's gate queries on the output column:
/// (starting rotation relative to the current row, number of rows).
/// Accumulating ops (Dot, Sum) use (-1, 2) so the gate can read the
/// previous row's accumulator as well as the current cell.
fn query_offset_rng(&self) -> (i32, usize) {
match self {
BaseOp::InitDot => (0, 1),
BaseOp::Identity => (0, 1),
BaseOp::Dot => (-1, 2),
BaseOp::Add => (0, 1),
BaseOp::Sub => (0, 1),
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
}
}
/// Number of input columns the op's gate reads; unary ops use one.
fn num_inputs(&self) -> usize {
    match self {
        BaseOp::Identity | BaseOp::Sum => 1,
        BaseOp::InitDot | BaseOp::Dot | BaseOp::Add | BaseOp::Sub | BaseOp::Mult => 2,
    }
}
/// Index into the queried output rows that this op's gate constrains;
/// accumulating ops (Dot, Sum) constrain the second queried row.
fn constraint_idx(&self) -> usize {
    match self {
        BaseOp::Dot | BaseOp::Sum => 1,
        BaseOp::InitDot | BaseOp::Identity | BaseOp::Add | BaseOp::Sub | BaseOp::Mult => 0,
    }
}
}
impl fmt::Display for BaseOp {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
BaseOp::InitDot => write!(f, "base accum init dot"),
BaseOp::Dot => write!(f, "base accum dot"),
BaseOp::Add => write!(f, "pairwise add"),
BaseOp::Sub => write!(f, "pairwise sub"),
BaseOp::Mult => write!(f, "pairwise mult"),
}
write!(f, "{}", self.as_str())
}
}
@@ -118,6 +133,7 @@ pub enum Op {
BatchNorm,
ScaleAndShift,
Pad(usize, usize),
Sum,
}
/// Configuration for an accumulated arg.
@@ -154,8 +170,10 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
selectors.insert(BaseOp::Add, meta.selector());
selectors.insert(BaseOp::Sub, meta.selector());
selectors.insert(BaseOp::Dot, meta.selector());
selectors.insert(BaseOp::Sum, meta.selector());
selectors.insert(BaseOp::Mult, meta.selector());
selectors.insert(BaseOp::InitDot, meta.selector());
selectors.insert(BaseOp::Identity, meta.selector());
let config = Self {
selectors,
@@ -169,16 +187,13 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
meta.create_gate(base_op.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let qis = config
.inputs
.iter()
.map(|input| {
input
.query_rng(meta, 0, 1)
.expect("accum: input query failed")[0]
.clone()
})
.collect::<Vec<_>>();
let mut qis = vec![Expression::<F>::zero().unwrap(); 2];
for i in 0..base_op.num_inputs() {
qis[i] = config.inputs[i]
.query_rng(meta, 0, 1)
.expect("accum: input query failed")[0]
.clone()
}
// Get output expressions for each input channel
let (offset, rng) = base_op.query_offset_rng();
@@ -213,6 +228,7 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
) -> Result<ValTensor<F>, Box<dyn Error>> {
match op {
Op::Dot => layouts::dot(self, layouter, values.try_into()?, offset),
Op::Sum => layouts::sum(self, layouter, values.try_into()?, offset),
Op::Matmul => layouts::matmul(self, layouter, values.try_into()?, offset),
Op::Affine => layouts::affine(self, layouter, values.try_into()?, offset),
Op::Conv { padding, stride } => {
@@ -387,6 +403,71 @@ mod dottest {
}
}
#[cfg(test)]
mod sumtest {
    use super::*;
    use halo2_proofs::{
        arithmetic::FieldExt,
        circuit::{Layouter, SimpleFloorPlanner, Value},
        dev::MockProver,
        plonk::{Circuit, ConstraintSystem, Error},
    };
    use halo2curves::pasta::Fp as F;

    const K: usize = 4;
    const LEN: usize = 4;

    /// Minimal circuit exercising the accumulated `Op::Sum` layout.
    #[derive(Clone)]
    struct MyCircuit<F: FieldExt + TensorType> {
        inputs: [ValTensor<F>; 1],
        _marker: PhantomData<F>,
    }

    impl<F: FieldExt + TensorType> Circuit<F> for MyCircuit<F> {
        type Config = BaseConfig<F>;
        type FloorPlanner = SimpleFloorPlanner;

        fn without_witnesses(&self) -> Self {
            self.clone()
        }

        fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
            let a = VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512);
            let b = VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512);
            let output = VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512);
            Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
        }

        fn synthesize(
            &self,
            mut config: Self::Config,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            // Borrow inputs directly — the previous `&self.inputs.clone()`
            // cloned the array only to take a reference to the copy.
            config
                .layout(&mut layouter, &self.inputs, 0, Op::Sum)
                .unwrap();
            Ok(())
        }
    }

    #[test]
    fn sumcircuit() {
        // Sum 1..=LEN; SAFE check mode verifies the accumulator in-circuit.
        let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
        let circuit = MyCircuit::<F> {
            inputs: [ValTensor::from(a)],
            _marker: PhantomData,
        };
        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        prover.assert_satisfied();
    }
}
#[cfg(test)]
mod batchnormtest {
use std::marker::PhantomData;

View File

@@ -924,6 +924,39 @@ pub mod accumulated {
Ok(transcript)
}
/// Sums a tensor, returning the running (accumulated) partial sums; the last
/// element is the total.
/// # Arguments
///
/// * `a` - Tensor
/// # Examples
/// ```
/// use ezkl_lib::tensor::Tensor;
/// use ezkl_lib::tensor::ops::accumulated::sum;
/// let x = Tensor::<i128>::new(
///     Some(&[2, 15, 2, 1, 1, 0]),
///     &[2, 3],
/// ).unwrap();
/// let result = sum(&x).unwrap();
/// let expected = Tensor::<i128>::new(
///     Some(&[2, 17, 19, 20, 21, 21]),
///     &[6],
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn sum<T: TensorType + Add<Output = T>>(
    a: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
    // `scan` threads the running total through the iteration. Only `Add` is
    // required, so the unused `Mul` bound (copied from the dot variant) has
    // been dropped — strictly weaker bounds keep all existing callers valid.
    let transcript: Tensor<T> = a
        .iter()
        .scan(T::zero().unwrap(), |acc, k| {
            *acc = acc.clone() + k.clone();
            Some(acc.clone())
        })
        .collect();
    Ok(transcript)
}
/// Matrix multiplies two 2D tensors.
/// # Arguments
///