mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 06:48:01 -05:00
feat: accum sum (#173)
This commit is contained in:
4
.github/workflows/rust.yml
vendored
4
.github/workflows/rust.yml
vendored
@@ -89,6 +89,10 @@ jobs:
|
||||
run: cargo bench --verbose --bench add
|
||||
- name: Bench pairwise add
|
||||
run: cargo bench --verbose --bench pairwise_add
|
||||
- name: Bench sum
|
||||
run: cargo bench --verbose --bench sum
|
||||
- name: Bench accum sum
|
||||
run: cargo bench --verbose --bench accum_sum
|
||||
|
||||
docs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -66,6 +66,14 @@ harness = false
|
||||
name = "accum_dot"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "sum"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_sum"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "add"
|
||||
harness = false
|
||||
|
||||
107
benches/accum_sum.rs
Normal file
107
benches/accum_sum.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl_lib::circuit::accumulated::*;
|
||||
use ezkl_lib::commands::TranscriptType;
|
||||
use ezkl_lib::execute::create_proof_circuit_kzg;
|
||||
use ezkl_lib::pfsys::{create_keys, gen_srs};
|
||||
use ezkl_lib::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit {
|
||||
inputs: [ValTensor<Fr>; 1],
|
||||
_marker: PhantomData<Fr>,
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit {
|
||||
type Config = BaseConfig<Fr>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
let len = unsafe { LEN };
|
||||
|
||||
let a = VarTensor::new_advice(cs, K, len, vec![len], true, 1024);
|
||||
let b = VarTensor::new_advice(cs, K, len, vec![len], true, 1024);
|
||||
let output = VarTensor::new_advice(cs, K, len, vec![len + 1], true, 1024);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>,
|
||||
) -> Result<(), Error> {
|
||||
config
|
||||
.layout(&mut layouter, &self.inputs, 0, Op::Sum)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn runsum(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("accum_sum");
|
||||
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
|
||||
for &len in [16, 512].iter() {
|
||||
unsafe {
|
||||
LEN = len;
|
||||
};
|
||||
|
||||
// parameters
|
||||
let a = Tensor::from((0..len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
circuit.clone(),
|
||||
¶ms,
|
||||
vec![],
|
||||
&pk,
|
||||
TranscriptType::Blake,
|
||||
SingleStrategy::new(¶ms),
|
||||
);
|
||||
prover.unwrap();
|
||||
});
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default().with_plots();
|
||||
targets = runsum
|
||||
}
|
||||
criterion_main!(benches);
|
||||
111
benches/sum.rs
Normal file
111
benches/sum.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl_lib::circuit::polynomial::*;
|
||||
use ezkl_lib::commands::TranscriptType;
|
||||
use ezkl_lib::execute::create_proof_circuit_kzg;
|
||||
use ezkl_lib::pfsys::{create_keys, gen_srs};
|
||||
use ezkl_lib::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit {
|
||||
inputs: [ValTensor<Fr>; 2],
|
||||
_marker: PhantomData<Fr>,
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit {
|
||||
type Config = Config<Fr>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
let len = unsafe { LEN };
|
||||
|
||||
let a = VarTensor::new_advice(cs, K, len, vec![len], true, 512);
|
||||
let b = VarTensor::new_advice(cs, K, len, vec![len], true, 512);
|
||||
let output = VarTensor::new_advice(cs, K, len, vec![1], true, 512);
|
||||
let sum_node = Node {
|
||||
op: Op::Sum,
|
||||
input_order: vec![InputType::Input(0)],
|
||||
};
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, &[sum_node])
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>,
|
||||
) -> Result<(), Error> {
|
||||
config.layout(&mut layouter, &self.inputs).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn runsum(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("sum");
|
||||
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
|
||||
for &len in [16].iter() {
|
||||
unsafe {
|
||||
LEN = len;
|
||||
};
|
||||
|
||||
// parameters
|
||||
let a = Tensor::from((0..len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
|
||||
let b = Tensor::from((0..len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
circuit.clone(),
|
||||
¶ms,
|
||||
vec![],
|
||||
&pk,
|
||||
TranscriptType::Blake,
|
||||
SingleStrategy::new(¶ms),
|
||||
);
|
||||
prover.unwrap();
|
||||
});
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default().with_plots();
|
||||
targets = runsum
|
||||
}
|
||||
criterion_main!(benches);
|
||||
@@ -85,7 +85,7 @@ fn runsumpool(c: &mut Criterion) {
|
||||
|
||||
let params = gen_srs::<KZGCommitmentScheme<_>>(K as u32);
|
||||
|
||||
for size in [1, 2].iter() {
|
||||
for size in [1].iter() {
|
||||
unsafe {
|
||||
IMAGE_HEIGHT = size * 4;
|
||||
IMAGE_WIDTH = size * 4;
|
||||
|
||||
@@ -9,7 +9,8 @@ use crate::{
|
||||
ops::{
|
||||
accumulated, add, affine as non_accum_affine, convolution as non_accum_conv,
|
||||
dot as non_accum_dot, matmul as non_accum_matmul, mult,
|
||||
scale_and_shift as ref_scale_and_shift, sub, sumpool as non_accum_sumpool,
|
||||
scale_and_shift as ref_scale_and_shift, sub, sum as non_accum_sum,
|
||||
sumpool as non_accum_sumpool,
|
||||
},
|
||||
Tensor, TensorError,
|
||||
},
|
||||
@@ -27,12 +28,6 @@ pub fn dot<F: FieldExt + TensorType>(
|
||||
values: &[ValTensor<F>; 2],
|
||||
offset: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
if values.len() != config.inputs.len() {
|
||||
return Err(Box::new(CircuitError::DimMismatch(
|
||||
"accum dot layout".to_string(),
|
||||
)));
|
||||
}
|
||||
|
||||
let t = match layouter.assign_region(
|
||||
|| "assign inputs",
|
||||
|mut region| {
|
||||
@@ -100,6 +95,79 @@ pub fn dot<F: FieldExt + TensorType>(
|
||||
Ok(ValTensor::from(t))
|
||||
}
|
||||
|
||||
/// Assigns variables to the regions created when calling `configure`.
|
||||
/// # Arguments
|
||||
/// * `values` - The explicit values to the operations.
|
||||
/// * `layouter` - A Halo2 Layouter.
|
||||
pub fn sum<F: FieldExt + TensorType>(
|
||||
config: &mut BaseConfig<F>,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
offset: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let t = match layouter.assign_region(
|
||||
|| "assign inputs",
|
||||
|mut region| {
|
||||
let input = utils::value_muxer(
|
||||
&config.inputs[0],
|
||||
&{
|
||||
let res = config.inputs[0].assign(&mut region, offset, &values[0])?;
|
||||
res.map(|e| e.value_field().evaluate())
|
||||
},
|
||||
&values[0],
|
||||
);
|
||||
|
||||
// Now we can assign the dot product
|
||||
let accumulated_sum = accumulated::sum(&input)
|
||||
.expect("accum poly: sum op failed")
|
||||
.into();
|
||||
let output = config
|
||||
.output
|
||||
.assign(&mut region, offset, &accumulated_sum)?;
|
||||
|
||||
for i in 0..input.len() {
|
||||
let (_, y) = config.inputs[0].cartesian_coord(i);
|
||||
if y == 0 {
|
||||
config
|
||||
.selectors
|
||||
.get(&BaseOp::Identity)
|
||||
.unwrap()
|
||||
.enable(&mut region, offset + y)?;
|
||||
} else {
|
||||
config
|
||||
.selectors
|
||||
.get(&BaseOp::Sum)
|
||||
.unwrap()
|
||||
.enable(&mut region, offset + y)?;
|
||||
}
|
||||
}
|
||||
|
||||
let last_elem = output
|
||||
.get_slice(&[output.len() - 1..output.len()])
|
||||
.expect("accum poly: failed to fetch last elem");
|
||||
|
||||
if matches!(config.check_mode, CheckMode::SAFE) {
|
||||
let safe_dot =
|
||||
non_accum_sum(&input).map_err(|_| halo2_proofs::plonk::Error::Synthesis)?;
|
||||
|
||||
assert_eq!(
|
||||
Into::<Tensor<i32>>::into(last_elem.clone()),
|
||||
Into::<Tensor<i32>>::into(safe_dot),
|
||||
)
|
||||
}
|
||||
// last element is the result
|
||||
Ok(last_elem)
|
||||
},
|
||||
) {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ValTensor::from(t))
|
||||
}
|
||||
|
||||
/// Assigns variables to the regions created when calling `configure`.
|
||||
/// # Arguments
|
||||
/// * `values` - The explicit values to the operations.
|
||||
|
||||
@@ -22,9 +22,11 @@ use std::{
|
||||
pub enum BaseOp {
|
||||
Dot,
|
||||
InitDot,
|
||||
Identity,
|
||||
Add,
|
||||
Mult,
|
||||
Sub,
|
||||
Sum,
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
@@ -47,6 +49,8 @@ impl BaseOp {
|
||||
BaseOp::InitDot => a * b,
|
||||
BaseOp::Dot => a * b + m,
|
||||
BaseOp::Add => a + b,
|
||||
BaseOp::Identity => a,
|
||||
BaseOp::Sum => a + m,
|
||||
BaseOp::Sub => a - b,
|
||||
BaseOp::Mult => a * b,
|
||||
}
|
||||
@@ -55,41 +59,52 @@ impl BaseOp {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
BaseOp::InitDot => "INITDOT",
|
||||
BaseOp::Identity => "IDENTITY",
|
||||
BaseOp::Dot => "DOT",
|
||||
BaseOp::Add => "ADD",
|
||||
BaseOp::Sub => "SUB",
|
||||
BaseOp::Mult => "MULT",
|
||||
BaseOp::Sum => "SUM",
|
||||
}
|
||||
}
|
||||
fn query_offset_rng(&self) -> (i32, usize) {
|
||||
match self {
|
||||
BaseOp::InitDot => (0, 1),
|
||||
BaseOp::Identity => (0, 1),
|
||||
BaseOp::Dot => (-1, 2),
|
||||
BaseOp::Add => (0, 1),
|
||||
BaseOp::Sub => (0, 1),
|
||||
BaseOp::Mult => (0, 1),
|
||||
BaseOp::Sum => (-1, 2),
|
||||
}
|
||||
}
|
||||
fn num_inputs(&self) -> usize {
|
||||
match self {
|
||||
BaseOp::InitDot => 2,
|
||||
BaseOp::Identity => 1,
|
||||
BaseOp::Dot => 2,
|
||||
BaseOp::Add => 2,
|
||||
BaseOp::Sub => 2,
|
||||
BaseOp::Mult => 2,
|
||||
BaseOp::Sum => 1,
|
||||
}
|
||||
}
|
||||
fn constraint_idx(&self) -> usize {
|
||||
match self {
|
||||
BaseOp::InitDot => 0,
|
||||
BaseOp::Identity => 0,
|
||||
BaseOp::Dot => 1,
|
||||
BaseOp::Add => 0,
|
||||
BaseOp::Sub => 0,
|
||||
BaseOp::Mult => 0,
|
||||
BaseOp::Sum => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BaseOp {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
BaseOp::InitDot => write!(f, "base accum init dot"),
|
||||
BaseOp::Dot => write!(f, "base accum dot"),
|
||||
BaseOp::Add => write!(f, "pairwise add"),
|
||||
BaseOp::Sub => write!(f, "pairwise sub"),
|
||||
BaseOp::Mult => write!(f, "pairwise mult"),
|
||||
}
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,6 +133,7 @@ pub enum Op {
|
||||
BatchNorm,
|
||||
ScaleAndShift,
|
||||
Pad(usize, usize),
|
||||
Sum,
|
||||
}
|
||||
|
||||
/// Configuration for an accumulated arg.
|
||||
@@ -154,8 +170,10 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
|
||||
selectors.insert(BaseOp::Add, meta.selector());
|
||||
selectors.insert(BaseOp::Sub, meta.selector());
|
||||
selectors.insert(BaseOp::Dot, meta.selector());
|
||||
selectors.insert(BaseOp::Sum, meta.selector());
|
||||
selectors.insert(BaseOp::Mult, meta.selector());
|
||||
selectors.insert(BaseOp::InitDot, meta.selector());
|
||||
selectors.insert(BaseOp::Identity, meta.selector());
|
||||
|
||||
let config = Self {
|
||||
selectors,
|
||||
@@ -169,16 +187,13 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
|
||||
meta.create_gate(base_op.as_str(), |meta| {
|
||||
let selector = meta.query_selector(*selector);
|
||||
|
||||
let qis = config
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
input
|
||||
.query_rng(meta, 0, 1)
|
||||
.expect("accum: input query failed")[0]
|
||||
.clone()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut qis = vec![Expression::<F>::zero().unwrap(); 2];
|
||||
for i in 0..base_op.num_inputs() {
|
||||
qis[i] = config.inputs[i]
|
||||
.query_rng(meta, 0, 1)
|
||||
.expect("accum: input query failed")[0]
|
||||
.clone()
|
||||
}
|
||||
|
||||
// Get output expressions for each input channel
|
||||
let (offset, rng) = base_op.query_offset_rng();
|
||||
@@ -213,6 +228,7 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
match op {
|
||||
Op::Dot => layouts::dot(self, layouter, values.try_into()?, offset),
|
||||
Op::Sum => layouts::sum(self, layouter, values.try_into()?, offset),
|
||||
Op::Matmul => layouts::matmul(self, layouter, values.try_into()?, offset),
|
||||
Op::Affine => layouts::affine(self, layouter, values.try_into()?, offset),
|
||||
Op::Conv { padding, stride } => {
|
||||
@@ -387,6 +403,71 @@ mod dottest {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod sumtest {
|
||||
use super::*;
|
||||
use halo2_proofs::{
|
||||
arithmetic::FieldExt,
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
dev::MockProver,
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
// use halo2curves::pasta::pallas;
|
||||
use halo2curves::pasta::Fp as F;
|
||||
// use rand::rngs::OsRng;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 4;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: FieldExt + TensorType> {
|
||||
inputs: [ValTensor<F>; 1],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: FieldExt + TensorType> Circuit<F> for MyCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512);
|
||||
let b = VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512);
|
||||
let output = VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
let _ = config
|
||||
.layout(&mut layouter, &self.inputs.clone(), 0, Op::Sum)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sumcircuit() {
|
||||
// parameters
|
||||
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
inputs: [ValTensor::from(a)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod batchnormtest {
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -924,6 +924,39 @@ pub mod accumulated {
|
||||
Ok(transcript)
|
||||
}
|
||||
|
||||
/// Sums a tensor.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl_lib::tensor::Tensor;
|
||||
/// use ezkl_lib::tensor::ops::accumulated::sum;
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = sum(&x).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(
|
||||
/// Some(&[2, 17, 19, 20, 21, 21]),
|
||||
/// &[6],
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn sum<T: TensorType + Mul<Output = T> + Add<Output = T>>(
|
||||
a: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let transcript: Tensor<T> = a
|
||||
.iter()
|
||||
.scan(T::zero().unwrap(), |acc, k| {
|
||||
*acc = acc.clone() + k.clone();
|
||||
Some(acc.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(transcript)
|
||||
}
|
||||
|
||||
/// Matrix multiplies two 2D tensors.
|
||||
/// # Arguments
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user