chore: error bubbling (#93)

Co-authored-by: jason <jason.morton@gmail.com>
This commit is contained in:
dante
2023-01-15 09:02:39 +00:00
committed by GitHub
parent cc2cb51a88
commit 16f746b0d2
27 changed files with 1050 additions and 946 deletions

View File

@@ -6,12 +6,10 @@ on:
pull_request:
branches: [ "main" ]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
@@ -60,7 +58,7 @@ jobs:
library-tests:
runs-on: ubuntu-latest-16-cores
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@@ -108,9 +106,9 @@ jobs:
components: rustfmt, clippy
- name: IPA full-prove tests
run: cargo test --release --verbose tests::ipa_fullprove_ -- --test-threads 1
run: cargo test --release --verbose tests::ipa_fullprove_ -- --test-threads 4
- name: KZG full-prove tests
run: cargo test --release --verbose tests::kzg_fullprove_ -- --test-threads 1
run: cargo test --release --verbose tests::kzg_fullprove_ -- --test-threads 4
full-proving-evm-tests:
@@ -129,7 +127,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.17 && solc --version
- name: KZG full-prove tests (EVM)
run: cargo test --release --verbose tests_evm::kzg_evm_fullprove_ -- --test-threads 1
run: cargo test --release --verbose tests_evm::kzg_evm_fullprove_ -- --test-threads 3
prove-and-verify-tests:
@@ -145,9 +143,9 @@ jobs:
components: rustfmt, clippy
- name: IPA prove and verify tests
run: cargo test --release --verbose tests::ipa_prove_and_verify_ -- --test-threads 1
run: cargo test --release --verbose tests::ipa_prove_and_verify_ -- --test-threads 4
- name: KZG prove and verify tests
run: cargo test --release --verbose tests::kzg_prove_and_verify_ -- --test-threads 1
run: cargo test --release --verbose tests::kzg_prove_and_verify_ -- --test-threads 4
examples:

1
Cargo.lock generated
View File

@@ -1839,6 +1839,7 @@ dependencies = [
"tabled",
"tensorflow",
"test-case",
"thiserror",
"tract-onnx",
]

View File

@@ -26,6 +26,7 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", package = "e
plonk_verifier = { git = "https://github.com/zkonduit/plonk-verifier", branch = "main"}
colog = { version = "1.1.0", optional = true }
eq-float = "0.1.0"
thiserror = "1.0.38"
[dev-dependencies]
criterion = {version = "0.3", features = ["html_reports"]}

View File

@@ -55,14 +55,16 @@ impl<F: FieldExt + TensorType> Circuit<F> for MyCircuit<F> {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout(
&mut layouter,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
);
config
.layout(
&mut layouter,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
)
.unwrap();
Ok(())
}
}

View File

@@ -44,11 +44,13 @@ impl<F: FieldExt + TensorType> Circuit<F> for MyCircuit<F> {
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout(
layouter.namespace(|| "Assign value"),
self.input.clone(),
self.output.clone(),
);
config
.layout(
layouter.namespace(|| "Assign value"),
self.input.clone(),
self.output.clone(),
)
.unwrap();
Ok(())
}

View File

@@ -44,7 +44,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for NLCircuit<F> {
config: Self::Config,
mut layouter: impl Layouter<F>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout(&mut layouter, &self.input);
config.layout(&mut layouter, &self.input).unwrap();
Ok(())
}

View File

@@ -239,28 +239,30 @@ where
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let x = config.l0.layout(
&mut layouter,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
);
let mut x = config.l1.layout(&mut layouter, &x);
let x = config
.l0
.layout(
&mut layouter,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
)
.unwrap();
let mut x = config.l1.layout(&mut layouter, &x).unwrap();
x.flatten();
let l2out = config.l2.layout(
&mut layouter,
&[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
);
let l2out = config
.l2
.layout(
&mut layouter,
&[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
)
.unwrap();
match l2out {
ValTensor::PrevAssigned { inner: v, dims: _ } => v
.enum_map(|i, x| {
layouter
.constrain_instance(x.cell(), config.public_output, i)
.unwrap()
})
.enum_map(|i, x| layouter.constrain_instance(x.cell(), config.public_output, i))
.unwrap(),
_ => panic!("Should be assigned"),
};
@@ -310,10 +312,11 @@ pub fn runconv() {
let mut input: ValTensor<F> = train_data
.get_slice(&[0..1, 0..28, 0..28])
.unwrap()
.map(Value::known)
.into();
input.reshape(&[1, 28, 28]);
input.reshape(&[1, 28, 28]).unwrap();
let myparams = params::Params::new();
let mut l0_kernels: ValTensor<F> = Tensor::<Value<F>>::from(
@@ -334,7 +337,9 @@ pub fn runconv() {
)
.into();
l0_kernels.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH]);
l0_kernels
.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH])
.unwrap();
let l0_bias: ValTensor<F> = Tensor::<Value<F>>::from(
(0..OUT_CHANNELS).map(|_| Value::known(fieldutils::i32_to_felt(0))),
@@ -360,7 +365,7 @@ pub fn runconv() {
}))
.into();
l2_weights.reshape(&[CLASSES, LEN]);
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
let circuit = MyCircuit::<
F,

View File

@@ -83,7 +83,8 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
&output,
BITS,
&[LookupOp::ReLU { scale: 1 }],
);
)
.unwrap();
// sets up a new Divide by table
let l4 =
@@ -107,28 +108,30 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
let x = config.l0.layout(
&mut layouter,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
);
let x = config.l1.layout(&mut layouter, &x);
let x = config.l2.layout(
&mut layouter,
&[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
);
let x = config.l3.layout(&mut layouter, &x);
let x = config.l4.layout(&mut layouter, &x);
let x = config
.l0
.layout(
&mut layouter,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
)
.unwrap();
let x = config.l1.layout(&mut layouter, &x).unwrap();
let x = config
.l2
.layout(
&mut layouter,
&[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
)
.unwrap();
let x = config.l3.layout(&mut layouter, &x).unwrap();
let x = config.l4.layout(&mut layouter, &x).unwrap();
match x {
ValTensor::PrevAssigned { inner: v, dims: _ } => v
.enum_map(|i, x| {
layouter
.constrain_instance(x.cell(), config.public_output, i)
.unwrap()
})
.enum_map(|i, x| layouter.constrain_instance(x.cell(), config.public_output, i))
.unwrap(),
_ => panic!("Should be assigned"),
};

View File

@@ -1,14 +1,20 @@
use ezkl::commands::Cli;
use ezkl::execute::run;
use log::info;
use log::{error, info};
use rand::seq::SliceRandom;
use std::error::Error;
pub fn main() {
pub fn main() -> Result<(), Box<dyn Error>> {
let args = Cli::create();
colog::init();
banner();
info!("{}", &args.as_json());
run(args)
info!("{}", &args.as_json()?);
let res = run(args);
match &res {
Ok(_) => info!("verify succeeded"),
Err(e) => error!("verify failed: {}", e),
};
res
}
fn banner() {

View File

@@ -1,13 +1,13 @@
use super::*;
use crate::tensor::ops::activations::*;
use crate::{abort, fieldutils::felt_to_i32, fieldutils::i32_to_felt};
use crate::{fieldutils::felt_to_i32, fieldutils::i32_to_felt};
use halo2_proofs::{
arithmetic::{Field, FieldExt},
circuit::{Layouter, Value},
plonk::{ConstraintSystem, Expression, Selector, TableColumn},
poly::Rotation,
};
use log::error;
use std::error::Error;
use std::fmt;
use std::{cell::RefCell, marker::PhantomData, rc::Rc};
@@ -98,8 +98,11 @@ impl<F: FieldExt> Table<F> {
}
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) {
assert!(!self.is_assigned);
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
}
let base = 2i32;
let smallest = -base.pow(self.bits as u32 - 1);
let largest = base.pow(self.bits as u32 - 1);
@@ -108,36 +111,35 @@ impl<F: FieldExt> Table<F> {
for nl in self.nonlinearities.clone() {
evals = nl.f(inputs.clone());
}
self.is_assigned = true;
layouter
.assign_table(
|| "nl table",
|mut table| {
inputs
.enum_map(|row_offset, input| {
table
.assign_cell(
|| format!("nl_i_col row {}", row_offset),
self.table_input,
row_offset,
|| Value::known(i32_to_felt::<F>(input)),
)
.expect("failed to assign table input cell");
let _ = inputs
.iter()
.enumerate()
.map(|(row_offset, input)| {
table.assign_cell(
|| format!("nl_i_col row {}", row_offset),
self.table_input,
row_offset,
|| Value::known(i32_to_felt::<F>(*input)),
)?;
table
.assign_cell(
|| format!("nl_o_col row {}", row_offset),
self.table_output,
row_offset,
|| Value::known(i32_to_felt::<F>(evals[row_offset])),
)
.expect("failed to assign table output cell");
table.assign_cell(
|| format!("nl_o_col row {}", row_offset),
self.table_output,
row_offset,
|| Value::known(i32_to_felt::<F>(evals[row_offset])),
)?;
Ok(())
})
.expect("failed to assign table");
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
},
)
.expect("failed to layout table");
self.is_assigned = true;
.map_err(Box::<dyn Error>::from)
}
}
@@ -163,24 +165,24 @@ impl<F: FieldExt + TensorType> Config<F> {
output: &VarTensor,
bits: usize,
nonlinearitities: &[Op],
) -> [Self; NUM] {
) -> Result<[Self; NUM], Box<dyn Error>> {
let mut table: Option<Rc<RefCell<Table<F>>>> = None;
let configs = (0..NUM)
.map(|_| {
let l = match &table {
None => Self::configure(cs, input, output, bits, &nonlinearitities),
Some(t) => Self::configure_with_table(cs, input, output, t.clone()),
};
table = Some(l.table.clone());
l
})
.collect::<Vec<Config<F>>>()
.try_into();
match configs {
Ok(x) => x,
Err(_) => panic!("failed to initialize layers"),
let mut configs: Vec<Config<F>> = vec![];
for _ in 0..NUM {
let l = match &table {
None => Self::configure(cs, input, output, bits, nonlinearitities),
Some(t) => Self::configure_with_table(cs, input, output, t.clone()),
};
table = Some(l.table.clone());
configs.push(l);
}
let res: [Self; NUM] = match configs.try_into() {
Ok(a) => a,
Err(_) => {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
}
};
Ok(res)
}
/// Configures and creates an elementwise operation within a circuit using a supplied lookup table.
@@ -260,16 +262,20 @@ impl<F: FieldExt + TensorType> Config<F> {
let table = Rc::new(RefCell::new(Table::<F>::configure(
cs,
bits,
&nonlinearitities,
nonlinearitities,
)));
Self::configure_with_table(cs, input, output, table)
}
/// Assigns values to the variables created when calling `configure`.
/// Values are supplied as a 1-element array of `[input]` VarTensors.
pub fn layout(&self, layouter: &mut impl Layouter<F>, values: &ValTensor<F>) -> ValTensor<F> {
pub fn layout(
&self,
layouter: &mut impl Layouter<F>,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if !self.table.borrow().is_assigned {
self.table.borrow_mut().layout(layouter)
self.table.borrow_mut().layout(layouter)?
}
let mut t = ValTensor::from(
match layouter.assign_region(
@@ -277,7 +283,7 @@ impl<F: FieldExt + TensorType> Config<F> {
|mut region| {
self.qlookup.enable(&mut region, 0)?;
let w = self.input.assign(&mut region, 0, &values).unwrap();
let w = self.input.assign(&mut region, 0, values)?;
let mut res: Vec<i32> = vec![];
let _ = Tensor::from(w.iter().map(|acaf| (*acaf).value_field()).map(|vaf| {
@@ -298,20 +304,17 @@ impl<F: FieldExt + TensorType> Config<F> {
}
};
Ok(self
.output
.assign(&mut region, 0, &ValTensor::from(output))
.unwrap())
self.output.assign(&mut region, 0, &ValTensor::from(output))
},
) {
Ok(a) => a,
Err(e) => {
abort!("failed to assign elt-wise region {:?}", e);
return Err(Box::new(e));
}
},
);
t.reshape(values.dims());
t
t.reshape(values.dims())?;
Ok(t)
}
}

View File

@@ -7,3 +7,21 @@ pub mod polynomial;
pub mod range;
/// Utility functions for building gates.
pub mod utils;
use thiserror::Error;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum CircuitError {
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
}

View File

@@ -1,5 +1,4 @@
use super::*;
use crate::abort;
use crate::tensor::ops::*;
use crate::tensor::{Tensor, TensorType};
use halo2_proofs::{
@@ -8,7 +7,7 @@ use halo2_proofs::{
plonk::{ConstraintSystem, Constraints, Expression, Selector},
};
use itertools::Itertools;
use log::error;
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;
@@ -93,18 +92,18 @@ impl Op {
pub fn f<T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T>>(
&self,
mut inputs: Vec<Tensor<T>>,
) -> Tensor<T> {
) -> Result<Tensor<T>, TensorError> {
match &self {
Op::Identity => inputs[0].clone(),
Op::Identity => Ok(inputs[0].clone()),
Op::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims);
t
Ok(t)
}
Op::Flatten(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims);
t
Ok(t)
}
Op::Add => add(&inputs),
Op::Sub => sub(&inputs),
@@ -124,24 +123,27 @@ impl Op {
} => sumpool(&inputs[0], *padding, *stride, *kernel_shape),
Op::GlobalSumPool => unreachable!(),
Op::Pow(u) => {
assert_eq!(inputs.len(), 1);
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
}
pow(&inputs[0], *u)
}
Op::Sum => {
assert_eq!(inputs.len(), 1);
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("sum inputs".to_string()));
}
sum(&inputs[0])
}
Op::Rescaled { inner, scale } => {
assert_eq!(scale.len(), inputs.len());
if scale.len() != inputs.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
}
inner.f(inputs
.iter_mut()
.enumerate()
.map(|(i, ri)| {
assert_eq!(scale[i].0, i);
rescale(ri, scale[i].1)
})
.collect_vec())
let mut rescaled_inputs = vec![];
for (i, ri) in inputs.iter_mut().enumerate() {
rescaled_inputs.push(rescale(ri, scale[i].1)?);
}
Ok(inner.f(rescaled_inputs)?)
}
}
}
@@ -204,31 +206,24 @@ impl<F: FieldExt + TensorType> Config<F> {
let qis = config
.inputs
.iter()
.map(|input| match input.query(meta, 0) {
Ok(q) => q,
Err(e) => {
abort!("failed to query input {:?}", e);
}
})
.map(|input| input.query(meta, 0).expect("poly: input query failed"))
.collect::<Vec<_>>();
let mut config_outputs = vec![];
for node in config.nodes.iter_mut() {
Self::apply_op(node, &qis, &mut config_outputs);
Self::apply_op(node, &qis, &mut config_outputs).expect("poly: apply op failed");
}
let witnessed_output = &config_outputs[config.nodes.len() - 1];
// Get output expressions for each input channel
let expected_output: Tensor<Expression<F>> = match config.output.query(meta, 0) {
Ok(res) => res,
Err(e) => {
abort!("failed to query output during fused layer layout {:?}", e);
}
};
let expected_output: Tensor<Expression<F>> = config
.output
.query(meta, 0)
.expect("poly: output query failed");
let constraints = witnessed_output
.enum_map(|i, o| o - expected_output[i].clone())
.unwrap();
.enum_map::<_, _, CircuitError>(|i, o| Ok(o - expected_output[i].clone()))
.expect("poly: failed to create constraints");
Constraints::with_selector(selector, constraints)
});
@@ -244,8 +239,12 @@ impl<F: FieldExt + TensorType> Config<F> {
&mut self,
layouter: &mut impl Layouter<F>,
values: &[ValTensor<F>],
) -> ValTensor<F> {
assert_eq!(values.len(), self.inputs.len());
) -> Result<ValTensor<F>, Box<dyn Error>> {
if values.len() != self.inputs.len() {
return Err(Box::new(CircuitError::DimMismatch(
"polynomial layout".to_string(),
)));
}
let t = match layouter.assign_region(
|| "assign inputs",
@@ -258,15 +257,8 @@ impl<F: FieldExt + TensorType> Config<F> {
let inp = utils::value_muxer(
&self.inputs[i],
&{
match self.inputs[i].assign(&mut region, offset, input) {
Ok(res) => res.map(|e| e.value_field().evaluate()),
Err(e) => {
abort!(
"failed to assign inputs during fused layer layout {:?}",
e
);
}
}
let res = self.inputs[i].assign(&mut region, offset, input)?;
res.map(|e| e.value_field().evaluate())
},
input,
);
@@ -276,30 +268,27 @@ impl<F: FieldExt + TensorType> Config<F> {
let mut layout_outputs = vec![];
for node in self.nodes.iter_mut() {
Self::apply_op(node, &inputs, &mut layout_outputs);
Self::apply_op(node, &inputs, &mut layout_outputs)
.expect("poly: apply op failed");
}
let output: ValTensor<F> = match layout_outputs.last() {
Some(a) => a.clone().into(),
None => {
panic!("fused layer has empty outputs");
panic!("poly: empty outputs");
}
};
match self.output.assign(&mut region, offset, &output) {
Ok(a) => Ok(a),
Err(e) => {
abort!("failed to assign fused layer output {:?}", e);
}
}
let output = self.output.assign(&mut region, offset, &output)?;
Ok(output)
},
) {
Ok(a) => a,
Err(e) => {
abort!("failed to assign fused layer region {:?}", e);
return Err(Box::new(e));
}
};
ValTensor::from(t)
Ok(ValTensor::from(t))
}
/// Applies an operation represented by a [Op] to the set of inputs (both explicit and intermediate results) it indexes over.
@@ -307,7 +296,7 @@ impl<F: FieldExt + TensorType> Config<F> {
node: &mut Node,
inputs: &[Tensor<T>],
outputs: &mut Vec<Tensor<T>>,
) {
) -> Result<(), Box<dyn Error>> {
let op_inputs = node
.input_order
.iter()
@@ -316,7 +305,8 @@ impl<F: FieldExt + TensorType> Config<F> {
InputType::Inter(u) => outputs[*u].clone(),
})
.collect_vec();
outputs.push(node.op.f(op_inputs));
outputs.push(node.op.f(op_inputs)?);
Ok(())
}
}

View File

@@ -1,4 +1,4 @@
use crate::abort;
use super::CircuitError;
use crate::fieldutils::i32_to_felt;
use crate::tensor::{TensorType, ValTensor, VarTensor};
use halo2_proofs::{
@@ -6,7 +6,6 @@ use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector},
};
use log::error;
use std::marker::PhantomData;
/// Configuration for a range check on the difference between `input` and `expected`.
@@ -45,20 +44,12 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
// v | 1
let q = cs.query_selector(config.selector);
let witnessed = match input.query(cs, 0) {
Ok(q) => q,
Err(e) => {
abort!("failed to query input {:?}", e);
}
};
let witnessed = input.query(cs, 0).expect("range: failed to query input");
// Get output expressions for each input channel
let expected = match expected.query(cs, 0) {
Ok(q) => q,
Err(e) => {
abort!("failed to query input {:?}", e);
}
};
let expected = expected
.query(cs, 0)
.expect("range: failed to query expected value");
// Given a range R and a value v, returns the expression
// (v) * (1 - v) * (2 - v) * ... * (R - 1 - v)
@@ -69,8 +60,10 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
};
let constraints = witnessed
.enum_map(|i, o| range_check(tol as i32, o - expected[i].clone()))
.unwrap();
.enum_map::<_, _, CircuitError>(|i, o| {
Ok(range_check(tol as i32, o - expected[i].clone()))
})
.expect("range: failed to create constraints");
Constraints::with_selector(q, constraints)
});
@@ -86,7 +79,7 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
mut layouter: impl Layouter<F>,
input: ValTensor<F>,
output: ValTensor<F>,
) {
) -> Result<(), halo2_proofs::plonk::Error> {
match layouter.assign_region(
|| "range check layout",
|mut region| {
@@ -96,37 +89,22 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
self.selector.enable(&mut region, offset)?;
// assigns the instance to the advice.
match self.input.assign(&mut region, offset, &input) {
Ok(_) => {}
Err(e) => {
abort!("failed to assign inputs during range layer layout {:?}", e);
}
};
self.input.assign(&mut region, offset, &input)?;
match self.expected.assign(&mut region, offset, &output) {
Ok(_) => {}
Err(e) => {
abort!(
"failed to assign expected output during range layer layout {:?}",
e
);
}
};
self.expected.assign(&mut region, offset, &output)?;
Ok(())
},
) {
Ok(a) => a,
Err(e) => {
abort!("failed to assign fused layer region {:?}", e);
}
};
// ValTensor::from(t);
Ok(_) => Ok(()),
Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
use crate::tensor::Tensor;
use halo2_proofs::{
arithmetic::FieldExt,
@@ -169,17 +147,20 @@ mod tests {
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout(
layouter.namespace(|| "assign value"),
self.input.clone(),
self.output.clone(),
);
config
.layout(
layouter.namespace(|| "assign value"),
self.input.clone(),
self.output.clone(),
)
.unwrap();
Ok(())
}
}
#[test]
#[allow(clippy::assertions_on_constants)]
fn test_range_check() {
let k = 4;

View File

@@ -1,9 +1,9 @@
//use crate::onnx::OnnxModel;
use crate::abort;
use clap::{Parser, Subcommand, ValueEnum};
use log::{error, info};
use log::info;
use serde::{Deserialize, Serialize};
use std::env;
use std::error::Error;
use std::io::{stdin, stdout, Write};
use std::path::PathBuf;
@@ -44,14 +44,14 @@ pub struct Cli {
impl Cli {
/// Export the ezkl configuration as json
pub fn as_json(&self) -> String {
pub fn as_json(&self) -> Result<String, Box<dyn Error>> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
abort!("failed to convert Cli to string {:?}", e);
return Err(Box::new(e));
}
};
serialized
Ok(serialized)
}
/// Parse an ezkl configuration from a json
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {

View File

@@ -1,4 +1,3 @@
use crate::abort;
use crate::commands::{Cli, Commands, ProofSystem};
use crate::fieldutils::i32_to_felt;
use crate::graph::Model;
@@ -9,9 +8,10 @@ use crate::pfsys::aggregation::{
};
use crate::pfsys::{create_keys, load_params, load_vk, Proof};
use crate::pfsys::{
create_proof_model, parse_prover_errors, prepare_circuit_and_public_input, prepare_data,
save_params, save_vk, verify_proof_model,
create_proof_model, prepare_circuit_and_public_input, prepare_data, save_params, save_vk,
verify_proof_model,
};
use halo2_proofs::dev::VerifyFailure;
#[cfg(feature = "evm")]
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
@@ -34,46 +34,43 @@ use halo2curves::bn256::G1Affine;
use halo2curves::bn256::{Bn256, Fr};
use halo2curves::pasta::vesta;
use halo2curves::pasta::Fp;
use log::{error, info, trace};
use log::{info, trace};
#[cfg(feature = "evm")]
use plonk_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::error::Error;
#[cfg(feature = "evm")]
use std::time::Instant;
use tabled::Table;
use thiserror::Error;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum ExecutionError {
/// Shape mismatch in a operation
#[error("verification failed")]
VerifyError(Vec<VerifyFailure>),
}
/// Run an ezkl command with given args
pub fn run(args: Cli) {
pub fn run(args: Cli) -> Result<(), Box<dyn Error>> {
match args.command {
Commands::Table { model: _ } => {
let om = Model::from_ezkl_conf(args);
let om = Model::from_ezkl_conf(args)?;
println!("{}", Table::new(om.nodes.flatten()));
}
Commands::Mock { ref data, model: _ } => {
let data = prepare_data(data.to_string());
let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args);
let data = prepare_data(data.to_string())?;
let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args)?;
info!("Mock proof");
let pi: Vec<Vec<Fp>> = public_inputs
.into_iter()
.map(|i| i.into_iter().map(i32_to_felt::<Fp>).collect())
.collect();
let prover = match MockProver::run(args.logrows, &circuit, pi) {
Ok(p) => p,
Err(e) => {
abort!("mock prover failed to run {:?}", e);
}
};
match prover.verify() {
Ok(_) => {
info!("verify succeeded")
}
Err(v) => {
for e in v.iter() {
parse_prover_errors(e)
}
panic!()
}
}
let prover =
MockProver::run(args.logrows, &circuit, pi).map_err(Box::<dyn Error>::from)?;
prover
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
}
Commands::Fullprove {
@@ -83,16 +80,17 @@ pub fn run(args: Cli) {
} => {
// A direct proof
let data = prepare_data(data.to_string());
let data = prepare_data(data.to_string())?;
match pfsys {
ProofSystem::IPA => {
let (circuit, public_inputs) =
prepare_circuit_and_public_input::<Fp>(&data, &args);
prepare_circuit_and_public_input::<Fp>(&data, &args)?;
info!("full proof with {}", pfsys);
let params: ParamsIPA<vesta::Affine> = ParamsIPA::new(args.logrows);
let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params);
let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
let strategy = IPASingleStrategy::new(&params);
trace!("params computed");
@@ -102,17 +100,19 @@ pub fn run(args: Cli) {
ProverIPA<_>,
>(
&circuit, &public_inputs, &params, &pk
);
)
.map_err(Box::<dyn Error>::from)?;
assert!(verify_proof_model(proof, &params, pk.get_vk(), strategy));
verify_proof_model(proof, &params, pk.get_vk(), strategy)?;
}
#[cfg(not(feature = "evm"))]
ProofSystem::KZG => {
// A direct proof
let (circuit, public_inputs) =
prepare_circuit_and_public_input::<Fr>(&data, &args);
prepare_circuit_and_public_input::<Fr>(&data, &args)?;
let params: ParamsKZG<Bn256> = ParamsKZG::new(args.logrows);
let pk = create_keys::<KZGCommitmentScheme<_>, Fr>(&circuit, &params);
let pk = create_keys::<KZGCommitmentScheme<_>, Fr>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
let strategy = KZGSingleStrategy::new(&params);
trace!("params computed");
@@ -122,14 +122,15 @@ pub fn run(args: Cli) {
ProverGWC<_>,
>(
&circuit, &public_inputs, &params, &pk
);
)
.map_err(Box::<dyn Error>::from)?;
assert!(verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>(
verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>(
proof,
&params,
pk.get_vk(),
strategy
));
strategy,
)?;
}
#[cfg(feature = "evm")]
ProofSystem::KZG => {
@@ -144,16 +145,16 @@ pub fn run(args: Cli) {
params
};
let now = Instant::now();
let snarks = [(); 1].map(|_| gen_application_snark(&params_app, &data, &args));
let snarks = [gen_application_snark(&params_app, &data, &args)?];
info!("Application proof took {}", now.elapsed().as_secs());
let agg_circuit = AggregationCircuit::new(&params, snarks);
let pk = gen_pk(&params, &agg_circuit);
let agg_circuit = AggregationCircuit::new(&params, snarks)?;
let pk = gen_pk(&params, &agg_circuit)?;
let deployment_code = gen_aggregation_evm_verifier(
&params,
pk.get_vk(),
AggregationCircuit::num_instance(),
AggregationCircuit::accumulator_indices(),
);
)?;
let now = Instant::now();
let proof = gen_kzg_proof::<
_,
@@ -162,10 +163,10 @@ pub fn run(args: Cli) {
EvmTranscript<G1Affine, _, _, _>,
>(
&params, &pk, agg_circuit.clone(), agg_circuit.instances()
);
)?;
info!("Aggregation proof took {}", now.elapsed().as_secs());
let now = Instant::now();
evm_verify(deployment_code, agg_circuit.instances(), proof);
evm_verify(deployment_code, agg_circuit.instances(), proof)?;
info!("verify took {}", now.elapsed().as_secs());
}
}
@@ -178,33 +179,37 @@ pub fn run(args: Cli) {
ref params_path,
pfsys,
} => {
let data = prepare_data(data.to_string());
let data = prepare_data(data.to_string())?;
match pfsys {
ProofSystem::IPA => {
info!("proof with {}", pfsys);
let (circuit, public_inputs) =
prepare_circuit_and_public_input::<Fp>(&data, &args);
prepare_circuit_and_public_input::<Fp>(&data, &args)?;
let params: ParamsIPA<vesta::Affine> = ParamsIPA::new(args.logrows);
let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params);
let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
trace!("params computed");
let (proof, _) = create_proof_model::<IPACommitmentScheme<_>, Fp, ProverIPA<_>>(
&circuit,
&public_inputs,
&params,
&pk,
);
let (proof, _) =
create_proof_model::<IPACommitmentScheme<_>, Fp, ProverIPA<_>>(
&circuit,
&public_inputs,
&params,
&pk,
)
.map_err(Box::<dyn Error>::from)?;
proof.save(proof_path);
save_params::<IPACommitmentScheme<_>>(params_path, &params);
save_vk::<IPACommitmentScheme<_>>(vk_path, pk.get_vk());
proof.save(proof_path)?;
save_params::<IPACommitmentScheme<_>>(params_path, &params)?;
save_vk::<IPACommitmentScheme<_>>(vk_path, pk.get_vk())?;
}
ProofSystem::KZG => {
info!("proof with {}", pfsys);
let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args);
let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args)?;
let params: ParamsKZG<Bn256> = ParamsKZG::new(args.logrows);
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr>(&circuit, &params);
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
trace!("params computed");
let (proof, _input_dims) = create_proof_model::<
@@ -213,11 +218,12 @@ pub fn run(args: Cli) {
ProverGWC<'_, Bn256>,
>(
&circuit, &public_inputs, &params, &pk
);
)
.map_err(Box::<dyn Error>::from)?;
proof.save(proof_path);
save_params::<KZGCommitmentScheme<Bn256>>(params_path, &params);
save_vk::<KZGCommitmentScheme<Bn256>>(vk_path, pk.get_vk());
proof.save(proof_path)?;
save_params::<KZGCommitmentScheme<Bn256>>(params_path, &params)?;
save_vk::<KZGCommitmentScheme<Bn256>>(vk_path, pk.get_vk())?;
}
};
}
@@ -228,29 +234,31 @@ pub fn run(args: Cli) {
params_path,
pfsys,
} => {
let proof = Proof::load(&proof_path);
let proof = Proof::load(&proof_path)?;
match pfsys {
ProofSystem::IPA => {
let params: ParamsIPA<vesta::Affine> =
load_params::<IPACommitmentScheme<_>>(params_path);
load_params::<IPACommitmentScheme<_>>(params_path)?;
let strategy = IPASingleStrategy::new(&params);
let vk = load_vk::<IPACommitmentScheme<_>, Fp>(vk_path, &params);
let result = verify_proof_model(proof, &params, &vk, strategy);
let vk = load_vk::<IPACommitmentScheme<_>, Fp>(vk_path, &params)?;
let result = verify_proof_model(proof, &params, &vk, strategy).is_ok();
info!("verified: {}", result);
assert!(result);
}
ProofSystem::KZG => {
let params: ParamsKZG<Bn256> =
load_params::<KZGCommitmentScheme<Bn256>>(params_path);
load_params::<KZGCommitmentScheme<Bn256>>(params_path)?;
let strategy = KZGSingleStrategy::new(&params);
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr>(vk_path, &params);
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr>(vk_path, &params)?;
let result = verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>(
proof, &params, &vk, strategy,
);
)
.is_ok();
info!("verified: {}", result);
assert!(result);
}
}
}
}
Ok(())
}

View File

@@ -14,15 +14,57 @@ use anyhow::Result;
use halo2_proofs::{
arithmetic::FieldExt,
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use log::{info, trace};
pub use model::*;
pub use node::*;
use std::cmp::max;
use std::marker::PhantomData;
use thiserror::Error;
pub use vars::*;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, OpKind),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, OpKind),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, OpKind),
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(OpKind),
/// Error when attempting to load a model
#[error("failed to load model")]
ModelLoad,
}
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
#[derive(Clone, Debug)]
pub struct ModelCircuit<F: FieldExt> {
@@ -41,7 +83,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let model = Model::from_arg();
let model = Model::from_arg().expect("model should load from args");
let mut num_fixed = 0;
let row_cap = model.max_node_size();
@@ -94,7 +136,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
) -> Result<(), PlonkError> {
trace!("Setting input in synthesize");
let inputs = self
.inputs

View File

@@ -1,5 +1,6 @@
use super::node::*;
use super::vars::*;
use super::GraphError;
use crate::circuit::lookup::Config as LookupConfig;
use crate::circuit::lookup::Op as LookupOp;
use crate::circuit::lookup::Table as LookupTable;
@@ -14,8 +15,8 @@ use crate::circuit::range::*;
use crate::commands::{Cli, Commands};
use crate::tensor::TensorType;
use crate::tensor::{Tensor, ValTensor, VarTensor};
use anyhow::{Context, Result};
//use clap::Parser;
use anyhow::{Context, Error as AnyError};
use halo2_proofs::{
arithmetic::FieldExt,
circuit::{Layouter, Value},
@@ -26,13 +27,13 @@ use log::{debug, info, trace};
use std::cell::RefCell;
use std::cmp::max;
use std::collections::{BTreeMap, HashSet};
use std::error::Error;
use std::path::Path;
use std::rc::Rc;
use tabled::Table;
use tract_onnx;
use tract_onnx::prelude::{Framework, Graph, InferenceFact, Node as OnnxNode, OutletId};
use tract_onnx::tract_hir::internal::InferenceOp;
/// Mode we're using the model in.
#[derive(Clone, Debug)]
pub enum Mode {
@@ -96,6 +97,7 @@ impl Model {
/// * `tolerance` - How much each quantized output is allowed to be off by
/// * `mode` - The [Mode] we're using the model in.
/// * `visibility` - Which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility].
#[allow(clippy::too_many_arguments)]
pub fn new(
path: impl AsRef<Path>,
scale: i32,
@@ -105,26 +107,22 @@ impl Model {
tolerance: usize,
mode: Mode,
visibility: VarVisibility,
) -> Self {
let model = tract_onnx::onnx().model_for_path(path).unwrap();
) -> Result<Self, Box<dyn Error>> {
let model = tract_onnx::onnx()
.model_for_path(path)
.map_err(|_| GraphError::ModelLoad)?;
info!("visibility: {}", visibility);
let mut nodes = BTreeMap::<usize, Node>::new();
let _ = model
.nodes()
.iter()
.enumerate()
.map(|(i, n)| {
let n = Node::new(n.clone(), &mut nodes, scale, i);
nodes.insert(i, n);
})
.collect_vec();
for (i, n) in model.nodes.iter().enumerate() {
let n = Node::new(n.clone(), &mut nodes, scale, i)?;
nodes.insert(i, n);
}
let om = Model {
model: model.clone(),
scale,
tolerance,
nodes: Self::assign_execution_buckets(nodes)
.expect("failed to assign execution buckets"),
nodes: Self::assign_execution_buckets(nodes)?,
bits,
logrows,
max_rotations,
@@ -134,12 +132,12 @@ impl Model {
debug!("{}", Table::new(om.nodes.flatten()).to_string());
om
Ok(om)
}
/// Creates a `Model` from parsed CLI arguments
pub fn from_ezkl_conf(args: Cli) -> Self {
let visibility = VarVisibility::from_args(args.clone());
pub fn from_ezkl_conf(args: Cli) -> Result<Self, Box<dyn Error>> {
let visibility = VarVisibility::from_args(args.clone())?;
match args.command {
Commands::Table { model } => Model::new(
model,
@@ -195,7 +193,7 @@ impl Model {
}
/// Creates a `Model` based on CLI arguments
pub fn from_arg() -> Self {
pub fn from_arg() -> Result<Self, Box<dyn Error>> {
let args = Cli::create();
Self::from_ezkl_conf(args)
}
@@ -211,7 +209,7 @@ impl Model {
&self,
meta: &mut ConstraintSystem<F>,
vars: &mut ModelVars<F>,
) -> Result<ModelConfig<F>> {
) -> Result<ModelConfig<F>, Box<dyn Error>> {
info!("configuring model");
let mut results = BTreeMap::new();
let mut tables = BTreeMap::new();
@@ -224,7 +222,7 @@ impl Model {
.collect();
if !non_op_nodes.is_empty() {
for (i, node) in non_op_nodes {
let config = self.conf_non_op_node(&node);
let config = self.conf_non_op_node(node)?;
results.insert(*i, config);
}
}
@@ -236,7 +234,7 @@ impl Model {
if !lookup_ops.is_empty() {
for (i, node) in lookup_ops {
let config = self.conf_table(node, meta, vars, &mut tables);
let config = self.conf_table(node, meta, vars, &mut tables)?;
results.insert(*i, config);
}
}
@@ -248,7 +246,7 @@ impl Model {
.collect();
// preserves ordering
if !poly_ops.is_empty() {
let config = self.conf_poly_ops(&poly_ops, meta, vars);
let config = self.conf_poly_ops(&poly_ops, meta, vars)?;
results.insert(**poly_ops.keys().max().unwrap(), config);
let mut display: String = "Poly nodes: ".to_string();
@@ -301,24 +299,25 @@ impl Model {
configs
}
/// Configures non op related nodes (eg. representing an input or const value)
pub fn conf_non_op_node<F: FieldExt + TensorType>(&self, node: &Node) -> NodeConfig<F> {
pub fn conf_non_op_node<F: FieldExt + TensorType>(
&self,
node: &Node,
) -> Result<NodeConfig<F>, Box<dyn Error>> {
match &node.opkind {
OpKind::Const => {
// Typically parameters for one or more layers.
// Currently this is handled in the consuming node(s), but will be moved here.
NodeConfig::Const
Ok(NodeConfig::Const)
}
OpKind::Input => {
// This is the input to the model (e.g. the image).
// Currently this is handled in the consuming node(s), but will be moved here.
NodeConfig::Input
Ok(NodeConfig::Input)
}
OpKind::Unknown(_c) => {
unimplemented!()
}
c => {
panic!("wrong method called for {}", c)
}
c => Err(Box::new(GraphError::WrongMethod(node.idx, c.clone()))),
}
}
@@ -335,25 +334,27 @@ impl Model {
nodes: &BTreeMap<&usize, &Node>,
meta: &mut ConstraintSystem<F>,
vars: &mut ModelVars<F>,
) -> NodeConfig<F> {
let input_nodes: BTreeMap<(&usize, &PolyOp), Vec<Node>> = nodes
.iter()
.map(|(i, e)| {
(
(
*i,
match &e.opkind {
OpKind::Poly(f) => f,
_ => panic!(),
},
),
e.inputs
.iter()
.map(|i| self.nodes.filter(i.node))
.collect_vec(),
)
})
.collect();
) -> Result<NodeConfig<F>, Box<dyn Error>> {
let mut input_nodes: BTreeMap<(&usize, &PolyOp), Vec<Node>> = BTreeMap::new();
for (i, e) in nodes.iter() {
let key = (
*i,
match &e.opkind {
OpKind::Poly(f) => f,
_ => {
return Err(Box::new(GraphError::WrongMethod(e.idx, e.opkind.clone())));
}
},
);
let value = e
.inputs
.iter()
.map(|i| self.nodes.filter(i.node))
.collect_vec();
input_nodes.insert(key, value);
}
// This works because retain only keeps items for which the predicate returns true, and
// insert only returns true if the item was not previously present in the set.
// Since the vector is traversed in order, we end up keeping just the first occurrence of each item.
@@ -412,7 +413,7 @@ impl Model {
let inputs = inputs_to_layer.iter();
NodeConfig::Poly(
let config = NodeConfig::Poly(
PolyConfig::configure(
meta,
&inputs.clone().map(|x| x.1.clone()).collect_vec(),
@@ -420,7 +421,8 @@ impl Model {
&fused_nodes,
),
inputs.map(|x| x.0).collect_vec(),
)
);
Ok(config)
}
/// Configures a lookup table based operation. These correspond to operations that are represented in
@@ -436,7 +438,7 @@ impl Model {
meta: &mut ConstraintSystem<F>,
vars: &mut ModelVars<F>,
tables: &mut BTreeMap<Vec<LookupOp>, Rc<RefCell<LookupTable<F>>>>,
) -> NodeConfig<F> {
) -> Result<NodeConfig<F>, Box<dyn Error>> {
let input_len = node.in_dims[0].iter().product();
let input = &vars.advices[0].reshape(&[input_len]);
let output = &vars.advices[1].reshape(&[input_len]);
@@ -444,20 +446,24 @@ impl Model {
let op = match &node.opkind {
OpKind::Lookup(l) => l,
c => panic!("wrong method called for {}", c),
c => {
return Err(Box::new(GraphError::WrongMethod(node.idx, c.clone())));
}
};
if tables.contains_key(&vec![op.clone()]) {
let table = tables.get(&vec![op.clone()]).unwrap();
let conf: LookupConfig<F> =
LookupConfig::configure_with_table(meta, input, output, table.clone());
NodeConfig::Lookup(conf, node_inputs)
} else {
let conf: LookupConfig<F> =
LookupConfig::configure(meta, input, output, self.bits, &[op.clone()]);
tables.insert(vec![op.clone()], conf.table.clone());
NodeConfig::Lookup(conf, node_inputs)
}
let config =
if let std::collections::btree_map::Entry::Vacant(e) = tables.entry(vec![op.clone()]) {
let conf: LookupConfig<F> =
LookupConfig::configure(meta, input, output, self.bits, &[op.clone()]);
e.insert(conf.table.clone());
NodeConfig::Lookup(conf, node_inputs)
} else {
let table = tables.get(&vec![op.clone()]).unwrap();
let conf: LookupConfig<F> =
LookupConfig::configure_with_table(meta, input, output, table.clone());
NodeConfig::Lookup(conf, node_inputs)
};
Ok(config)
}
/// Assigns values to the regions created when calling `configure`.
@@ -472,7 +478,7 @@ impl Model {
layouter: &mut impl Layouter<F>,
inputs: &[ValTensor<F>],
vars: &ModelVars<F>,
) -> Result<()> {
) -> Result<(), Box<dyn Error>> {
info!("model layout");
let mut results = BTreeMap::<usize, ValTensor<F>>::new();
for i in inputs.iter().enumerate() {
@@ -533,7 +539,7 @@ impl Model {
layouter: &mut impl Layouter<F>,
inputs: &mut BTreeMap<usize, ValTensor<F>>,
config: &NodeConfig<F>,
) -> Result<Option<ValTensor<F>>> {
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
// The node kind and the config should be the same.
let res = match config.clone() {
NodeConfig::Poly(mut ac, idx) => {
@@ -555,17 +561,19 @@ impl Model {
})
.collect_vec();
Some(ac.layout(layouter, &values))
Some(ac.layout(layouter, &values)?)
}
NodeConfig::Lookup(rc, idx) => {
assert_eq!(idx.len(), 1);
if idx.len() != 1 {
return Err(Box::new(GraphError::InvalidLookupInputs));
}
// For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output.
Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap()))
Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap())?)
}
NodeConfig::Input => None,
NodeConfig::Const => None,
c => {
panic!("Not a configurable op {:?}", c)
_ => {
return Err(Box::new(GraphError::UnsupportedOp));
}
};
Ok(res)
@@ -580,28 +588,37 @@ impl Model {
/// # Arguments
///
/// * `nodes` - `BTreeMap` of (node index, [Node]) pairs.
pub fn assign_execution_buckets(mut nodes: BTreeMap<usize, Node>) -> Result<NodeGraph> {
pub fn assign_execution_buckets(
mut nodes: BTreeMap<usize, Node>,
) -> Result<NodeGraph, GraphError> {
info!("assigning configuration buckets to operations");
let mut bucketed_nodes = NodeGraph(BTreeMap::<Option<usize>, BTreeMap<usize, Node>>::new());
for (_, node) in nodes.iter_mut() {
let prev_bucket: Option<usize> = node
let mut prev_buckets = vec![];
for n in node
.inputs
.iter()
.filter(|n| !bucketed_nodes.filter(n.node).opkind.is_const())
.map(|n| match bucketed_nodes.filter(n.node).bucket {
Some(b) => b,
None => panic!(),
})
.max();
{
match bucketed_nodes.filter(n.node).bucket {
Some(b) => prev_buckets.push(b),
None => {
return Err(GraphError::MissingNode(n.node));
}
}
}
let prev_bucket: Option<&usize> = prev_buckets.iter().max();
match &node.opkind {
OpKind::Input => node.bucket = Some(0),
OpKind::Const => node.bucket = None,
OpKind::Poly(_) => node.bucket = Some(prev_bucket.unwrap()),
OpKind::Poly(_) => node.bucket = Some(*prev_bucket.unwrap()),
OpKind::Lookup(_) => node.bucket = Some(prev_bucket.unwrap() + 1),
_ => unimplemented!(),
op => {
return Err(GraphError::WrongMethod(node.idx, op.clone()));
}
}
bucketed_nodes.insert(node.bucket, node.idx, node.clone());
}
@@ -613,7 +630,7 @@ impl Model {
/// Note that this order is not stable over multiple reloads of the model. For example, it will freely
/// interchange the order of evaluation of fixed parameters. For example weight could have id 1 on one load,
/// and bias id 2, and vice versa on the next load of the same file. The ids are also not stable.
pub fn eval_order(&self) -> Result<Vec<usize>> {
pub fn eval_order(&self) -> Result<Vec<usize>, AnyError> {
self.model.eval_order()
}
@@ -623,12 +640,12 @@ impl Model {
}
/// Returns the ID of the computational graph's inputs
pub fn input_outlets(&self) -> Result<Vec<OutletId>> {
pub fn input_outlets(&self) -> Result<Vec<OutletId>, Box<dyn Error>> {
Ok(self.model.input_outlets()?.to_vec())
}
/// Returns the ID of the computational graph's outputs
pub fn output_outlets(&self) -> Result<Vec<OutletId>> {
pub fn output_outlets(&self) -> Result<Vec<OutletId>, Box<dyn Error>> {
Ok(self.model.output_outlets()?.to_vec())
}

View File

@@ -1,17 +1,18 @@
use super::utilities::{node_output_shapes, scale_to_multiplier, vector_to_quantized};
use crate::abort;
use crate::circuit::lookup::Config as LookupConfig;
use crate::circuit::lookup::Op as LookupOp;
use crate::circuit::polynomial::Config as PolyConfig;
use crate::circuit::polynomial::Op as PolyOp;
use crate::graph::GraphError;
use crate::tensor::ops::{add, const_mult, div, mult};
use crate::tensor::Tensor;
use crate::tensor::TensorType;
use anyhow::Result;
use halo2_proofs::arithmetic::FieldExt;
use itertools::Itertools;
use log::{error, info, trace, warn};
use log::{info, trace, warn};
use std::collections::{btree_map::Entry, BTreeMap};
use std::error::Error;
use std::fmt;
use std::ops::Deref;
use tabled::Tabled;
@@ -294,7 +295,7 @@ impl Node {
other_nodes: &mut BTreeMap<usize, Node>,
scale: i32,
idx: usize,
) -> Self {
) -> Result<Self, Box<dyn Error>> {
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
let output_shapes = match node_output_shapes(&node) {
@@ -302,20 +303,13 @@ impl Node {
_ => None,
};
let mut inputs: Vec<Node> = node
.inputs
.iter_mut()
// this shouldn't fail
.map(|i| {
match other_nodes.get(&i.node) {
Some(n) => n,
None => {
abort!("input {} has not been initialized", i.node);
}
}
.clone()
})
.collect();
let mut inputs = vec![];
for i in node.inputs.iter_mut() {
match other_nodes.get(&i.node) {
Some(n) => inputs.push(n.clone()),
None => return Err(Box::new(GraphError::MissingNode(i.node))),
}
}
let mut opkind = OpKind::new(node.op().name().as_ref()); // parses the op name
@@ -386,11 +380,11 @@ impl Node {
Some(b) => match (*b).as_any().downcast_ref() {
Some(b) => b,
None => {
panic!("not a leaky relu!");
return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
}
},
None => {
panic!("op is not a Tract Expansion!");
return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
}
};
@@ -445,7 +439,7 @@ impl Node {
opkind = OpKind::Lookup(LookupOp::PReLU {
scale: layer_scale,
slopes: slopes.clone(),
slopes,
}); // now the input will be scaled down to match
Node {
@@ -462,7 +456,7 @@ impl Node {
}
LookupOp::Div { .. } => {
if inputs[1].out_dims.clone() != [1] {
abort!("ezkl currently only supports division by a constant");
return Err(Box::new(GraphError::NonConstantDiv));
}
let mult = scale_to_multiplier(scale);
let div = inputs[1].output_max / mult;
@@ -516,27 +510,37 @@ impl Node {
Some(b) => match (*b).as_any().downcast_ref() {
Some(b) => b,
None => {
panic!("not a conv!");
return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
}
},
None => {
panic!("op is not a Tract Expansion!");
return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
}
};
// only support pytorch type formatting for now
assert_eq!(conv_node.data_format, DataFormat::NCHW);
assert_eq!(conv_node.kernel_fmt, KernelFormat::OIHW);
if (conv_node.data_format != DataFormat::NCHW)
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(Box::new(GraphError::MissingParams(
"data or kernel in wrong format".to_string(),
)));
}
let stride = match conv_node.strides.clone() {
Some(s) => s,
None => {
abort!("strides for node {} has not been initialized", idx);
return Err(Box::new(GraphError::MissingParams(
"strides".to_string(),
)));
}
};
let padding = match &conv_node.padding {
PaddingSpec::Explicit(p, _, _) => p,
_ => panic!("padding is not explicitly specified"),
_ => {
return Err(Box::new(GraphError::MissingParams(
"padding".to_string(),
)));
}
};
if inputs.len() == 3 {
@@ -544,11 +548,11 @@ impl Node {
let scale_diff =
weight_node.out_scale + input_node.out_scale - bias_node.out_scale;
let mut bias_node = other_nodes.get_mut(&node.inputs[2].node).unwrap();
bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff);
assert_eq!(
input_node.out_scale + weight_node.out_scale,
bias_node.out_scale
);
bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?;
if (input_node.out_scale + weight_node.out_scale) != bias_node.out_scale
{
return Err(Box::new(GraphError::RescalingError(opkind)));
}
}
let oihw = weight_node.out_dims.clone();
@@ -590,18 +594,28 @@ impl Node {
let op = Box::new(node.op());
let sumpool_node: &SumPool = match op.downcast_ref() {
Some(b) => b,
None => panic!("op isn't a SumPool!"),
None => {
return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
}
};
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
assert_eq!(pool_spec.data_format, DataFormat::NCHW);
if pool_spec.data_format != DataFormat::NCHW {
return Err(Box::new(GraphError::MissingParams(
"data in wrong format".to_string(),
)));
}
let stride = pool_spec.strides.clone().unwrap();
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(p, _, _) => p,
_ => panic!("padding is not explicitly specified"),
_ => {
return Err(Box::new(GraphError::MissingParams(
"padding".to_string(),
)));
}
};
let kernel_shape = &pool_spec.kernel_shape;
@@ -697,12 +711,10 @@ impl Node {
let scale_diff =
weight_node.out_scale + input_node.out_scale - bias_node.out_scale;
let mut bias_node = other_nodes.get_mut(&node.inputs[2].node).unwrap();
bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff);
assert_eq!(
input_node.out_scale + weight_node.out_scale,
bias_node.out_scale
);
bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?;
if (input_node.out_scale + weight_node.out_scale) != bias_node.out_scale {
return Err(Box::new(GraphError::RescalingError(opkind)));
}
let in_dim = weight_node.out_dims.clone()[1];
let out_dim = weight_node.out_dims.clone()[0];
@@ -727,34 +739,26 @@ impl Node {
PolyOp::BatchNorm => {
//Compute scale and shift from the four inputs,
// then replace the first two, and change this node to a ScaleAndShift
// let (input_node, mut gamma_node, mut beta_node, mean_node, var_node) = (
// &mut inputs[0],
// &mut inputs[1],
// &mut inputs[2],
// &mut inputs[3],
// &mut inputs[4],
// );
let gamma = inputs[1].raw_const_value.as_ref().unwrap();
let beta = inputs[2].raw_const_value.as_ref().unwrap();
let mu = inputs[3].raw_const_value.as_ref().unwrap();
let sigma = inputs[4].raw_const_value.as_ref().unwrap();
let num_entries = gamma.len();
let a = div(gamma.clone(), sigma.clone());
let amu: Tensor<f32> = mult(&vec![a.clone(), mu.clone()]);
let amupb: Tensor<f32> = add(&vec![amu, beta.clone()]);
let b = const_mult(&amupb, -1f32);
let a = div(gamma.clone(), sigma.clone())?;
let amu: Tensor<f32> = mult(&vec![a.clone(), mu.clone()])?;
let amupb: Tensor<f32> = add(&vec![amu, beta.clone()])?;
let b = const_mult(&amupb, -1f32)?;
let in_scale = inputs[0].out_scale;
let out_scale = 2 * inputs[0].out_scale;
// gamma node becomes the scale (weigh) in scale and shift
inputs[1].raw_const_value = Some(a);
inputs[1].quantize_const_to_scale(in_scale);
inputs[1].quantize_const_to_scale(in_scale)?;
// beta node becomes the shift (bias)
inputs[2].raw_const_value = Some(b);
inputs[2].quantize_const_to_scale(out_scale);
inputs[2].quantize_const_to_scale(out_scale)?;
Node {
idx,
@@ -772,7 +776,7 @@ impl Node {
}
PolyOp::Add => {
opkind = Self::homogenize_input_scales(opkind, inputs.clone());
opkind = Self::homogenize_input_scales(opkind, inputs.clone())?;
let output_max =
if let OpKind::Poly(PolyOp::Rescaled { scale, .. }) = &opkind {
(inputs
@@ -785,8 +789,7 @@ impl Node {
.unwrap() as f32)
* (inputs.len() as f32)
} else {
error!("failed to homogenize input scalings for node {}", idx);
panic!()
return Err(Box::new(GraphError::RescalingError(opkind)));
};
Node {
@@ -802,7 +805,9 @@ impl Node {
}
}
PolyOp::Sum => {
assert!(inputs.len() == 1);
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, opkind)));
};
Node {
idx,
@@ -818,7 +823,7 @@ impl Node {
}
}
PolyOp::Sub => {
opkind = Self::homogenize_input_scales(opkind, inputs.clone());
opkind = Self::homogenize_input_scales(opkind, inputs.clone())?;
let output_max =
if let OpKind::Poly(PolyOp::Rescaled { inner: _, scale }) = &opkind {
(inputs
@@ -831,8 +836,7 @@ impl Node {
.unwrap() as f32)
* (inputs.len() as f32)
} else {
error!("failed to homogenize input scalings for node {}", idx);
panic!()
return Err(Box::new(GraphError::RescalingError(opkind)));
};
Node {
@@ -874,10 +878,9 @@ impl Node {
let pow = inputs[1].clone().raw_const_value.unwrap()[0];
node.inputs.pop();
if inputs[1].out_dims != [1] {
error!(
"ezkl currently only supports raising to the power by a constant"
);
unimplemented!()
{
return Err(Box::new(GraphError::NonConstantPower));
}
}
Node {
@@ -900,8 +903,7 @@ impl Node {
}
}
PolyOp::Rescaled { .. } => {
error!("operations should not already be rescaled at this stage");
panic!()
return Err(Box::new(GraphError::RescalingError(opkind)));
}
PolyOp::Identity => {
let input_node = &inputs[0];
@@ -939,33 +941,42 @@ impl Node {
let shape_const = match shape_const_node.const_value.as_ref() {
Some(sc) => sc,
None => {
abort!("missing shape constant");
return Err(Box::new(GraphError::MissingParams(
"shape constant".to_string(),
)));
}
};
let shapes = &shape_const[0..];
let new_dims: Vec<usize> = if shapes.iter().all(|x| x > &0) {
shapes
.iter()
.map(|x| {
assert!(x > &0);
*x as usize
})
.collect()
} else {
let num_entries: usize = input_node.out_dims.iter().product();
let explicit_prod: i32 = shapes.iter().filter(|x| *x > &0).product();
assert!(explicit_prod > 0);
let inferred = num_entries / (explicit_prod as usize);
let mut new_dims: Vec<usize> = Vec::new();
for i in shapes {
match i {
-1 => new_dims.push(inferred),
0 => continue,
x => new_dims.push(*x as usize),
let new_dims: Result<Vec<usize>, Box<dyn Error>> =
if shapes.iter().all(|x| x > &0) {
let mut res = vec![];
for x in shapes.iter() {
if x <= &0 {
return Err(Box::new(GraphError::InvalidDims(idx, opkind)));
}
res.push(*x as usize);
}
}
new_dims
};
Ok(res)
} else {
let num_entries: usize = input_node.out_dims.iter().product();
let explicit_prod: i32 =
shapes.iter().filter(|x| *x > &0).product();
if explicit_prod <= 0 {
return Err(Box::new(GraphError::InvalidDims(idx, opkind)));
}
let inferred = num_entries / (explicit_prod as usize);
let mut new_dims: Vec<usize> = Vec::new();
for i in shapes {
match i {
-1 => new_dims.push(inferred),
0 => continue,
x => new_dims.push(*x as usize),
}
}
Ok(new_dims)
};
let new_dims = new_dims?;
Node {
idx,
@@ -986,7 +997,7 @@ impl Node {
let const_node: &Const = match op.as_any().downcast_ref() {
Some(b) => b,
None => {
abort!("op is not a const!");
return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
}
};
let dt = const_node.0.datum_type();
@@ -1081,16 +1092,18 @@ impl Node {
warn!("{:?} is unknown", opkind);
Node::default()
}
o => {
error!("unsupported op {:?}", o);
panic!()
_ => {
return Err(Box::new(GraphError::UnsupportedOp));
}
};
mn
Ok(mn)
}
/// Ensures all inputs to a node have the same fixed point denominator.
fn homogenize_input_scales(opkind: OpKind, inputs: Vec<Node>) -> OpKind {
fn homogenize_input_scales(
opkind: OpKind,
inputs: Vec<Node>,
) -> Result<OpKind, Box<dyn Error>> {
let mut multipliers = vec![1; inputs.len()];
let out_scales = inputs.windows(1).map(|w| w[0].out_scale).collect_vec();
if !out_scales.windows(2).all(|w| w[0] == w[1]) {
@@ -1114,32 +1127,42 @@ impl Node {
.collect_vec();
}
if let OpKind::Poly(c) = &opkind {
OpKind::Poly(PolyOp::Rescaled {
Ok(OpKind::Poly(PolyOp::Rescaled {
inner: Box::new(c.clone()),
scale: (0..inputs.len()).zip(multipliers).collect_vec(),
})
}))
} else {
error!("should not homegenize input scales for non fused ops.");
panic!()
Err(Box::new(GraphError::RescalingError(opkind)))
}
}
fn quantize_const_to_scale(&mut self, scale: i32) {
assert!(matches!(self.opkind, OpKind::Const));
fn quantize_const_to_scale(&mut self, scale: i32) -> Result<(), Box<dyn Error>> {
if !self.opkind.is_const() {
return Err(Box::new(GraphError::WrongMethod(
self.idx,
self.opkind.clone(),
)));
};
let raw = self.raw_const_value.as_ref().unwrap();
self.out_scale = scale;
let t = vector_to_quantized(raw, raw.dims(), 0f32, self.out_scale).unwrap();
self.output_max = 0f32; //t.iter().map(|x| x.abs()).max().unwrap() as f32;
self.const_value = Some(t);
Ok(())
}
/// Re-quantizes a constant value node to a new scale.
fn scale_up_const_node(node: &mut Node, scale: i32) -> &mut Node {
assert!(matches!(node.opkind, OpKind::Const));
fn scale_up_const_node(node: &mut Node, scale: i32) -> Result<&mut Node, Box<dyn Error>> {
if !node.opkind.is_const() {
return Err(Box::new(GraphError::WrongMethod(
node.idx,
node.opkind.clone(),
)));
};
if scale > 0 {
if let Some(raw) = &node.raw_const_value {
if let Some(val) = &node.raw_const_value {
let mult = scale_to_multiplier(scale);
let t = vector_to_quantized(&raw, raw.dims(), 0f32, scale).unwrap();
let t = vector_to_quantized(val, val.dims(), 0f32, scale)?;
node.const_value = Some(t);
info!(
"------ scaled const node {:?}: {:?} -> {:?}",
@@ -1149,6 +1172,6 @@ impl Node {
node.out_scale = scale;
}
}
node
Ok(node)
}
}

View File

@@ -1,12 +1,14 @@
use crate::abort;
use std::error::Error;
use crate::commands::Cli;
use crate::tensor::TensorType;
use crate::tensor::{ValTensor, VarTensor};
use halo2_proofs::{arithmetic::FieldExt, plonk::ConstraintSystem};
use itertools::Itertools;
use log::error;
use serde::Deserialize;
use super::GraphError;
/// Label Enum to track whether model input, model parameters, and model output are public or private
#[derive(Clone, Debug, Deserialize)]
pub enum Visibility {
@@ -53,7 +55,7 @@ impl std::fmt::Display for VarVisibility {
impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
pub fn from_args(args: Cli) -> Self {
pub fn from_args(args: Cli) -> Result<Self, Box<dyn Error>> {
let input_vis = if args.public_inputs {
Visibility::Public
} else {
@@ -70,13 +72,13 @@ impl VarVisibility {
Visibility::Private
};
if !output_vis.is_public() & !params_vis.is_public() & !input_vis.is_public() {
abort!("at least one set of variables should be public");
return Err(Box::new(GraphError::Visibility));
}
Self {
Ok(Self {
input: input_vis,
params: params_vis,
output: output_vis,
}
})
}
}

View File

@@ -51,12 +51,3 @@ pub mod graph;
pub mod pfsys;
/// An implementation of multi-dimensional tensors.
pub mod tensor;
/// A macro to abort concisely.
#[macro_export]
macro_rules! abort {
($msg:literal $(, $ex:expr)*) => {
error!($msg, $($ex,)*);
panic!();
};
}

View File

@@ -54,8 +54,10 @@ use plonk_verifier::{
Protocol,
};
use rand::rngs::OsRng;
use std::error::Error;
use std::io::Cursor;
use std::{iter, rc::Rc};
use thiserror::Error;
const LIMBS: usize = 4;
const BITS: usize = 68;
@@ -77,6 +79,26 @@ type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip
pub type PoseidonTranscript<L, S> =
system::halo2::transcript::halo2::PoseidonTranscript<G1Affine, L, S, T, RATE, R_F, R_P>;
#[derive(Error, Debug)]
/// Errors related to proof aggregation
pub enum AggregationError {
/// A KZG proof could not be verified
#[error("failed to verify KZG proof")]
KZGProofVerification,
/// EVM execution errors
#[error("EVM execution of raw code failed")]
EVMRawExecution,
/// proof read errors
#[error("Failed to read proof")]
ProofRead,
/// proof verification errors
#[error("Failed to verify proof")]
ProofVerify,
/// proof creation errors
#[error("Failed to create proof")]
ProofCreate,
}
/// An application snark with proof and instance variables ready for aggregation (raw field element)
#[derive(Debug)]
pub struct Snark {
@@ -142,7 +164,7 @@ pub fn aggregate<'a>(
loader: &Rc<Halo2Loader<'a>>,
snarks: &[SnarkWitness],
as_proof: Value<&'_ [u8]>,
) -> KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>> {
) -> Result<KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>, plonk::Error> {
let assign_instances = |instances: &[Vec<Value<Fr>>]| {
instances
.iter()
@@ -155,22 +177,24 @@ pub fn aggregate<'a>(
.collect_vec()
};
let accumulators = snarks
.iter()
.flat_map(|snark| {
let protocol = snark.protocol.loaded(loader);
let instances = assign_instances(&snark.instances);
let mut transcript =
PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap();
Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap()
})
.collect_vec();
let mut accumulators = vec![];
for snark in snarks.iter() {
let protocol = snark.protocol.loaded(loader);
let instances = assign_instances(&snark.instances);
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript)
.map_err(|_| plonk::Error::Synthesis)?;
let mut accum = Plonk::succinct_verify(svk, &protocol, &instances, &proof)
.map_err(|_| plonk::Error::Synthesis)?;
accumulators.append(&mut accum);
}
let accumulator = {
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, as_proof);
let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap();
As::verify(&Default::default(), &accumulators, &proof).unwrap()
let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript)
.map_err(|_| plonk::Error::Synthesis)?;
As::verify(&Default::default(), &accumulators, &proof).map_err(|_| plonk::Error::Synthesis)
};
accumulator
@@ -229,28 +253,31 @@ pub struct AggregationCircuit {
impl AggregationCircuit {
/// Create a new Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof.
pub fn new(params: &ParamsKZG<Bn256>, snarks: impl IntoIterator<Item = Snark>) -> Self {
pub fn new(
params: &ParamsKZG<Bn256>,
snarks: impl IntoIterator<Item = Snark>,
) -> Result<Self, AggregationError> {
let svk = params.get_g()[0].into();
let snarks = snarks.into_iter().collect_vec();
let accumulators = snarks
.iter()
.flat_map(|snark| {
trace!("Aggregating with snark instances {:?}", snark.instances);
let mut transcript =
PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
let proof =
Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript)
.unwrap();
Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap()
})
.collect_vec();
let mut accumulators = vec![];
for snark in snarks.iter() {
trace!("Aggregating with snark instances {:?}", snark.instances);
let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
let proof = Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript)
.map_err(|_| AggregationError::ProofRead)?;
let mut accum = Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof)
.map_err(|_| AggregationError::ProofVerify)?;
accumulators.append(&mut accum);
}
trace!("Accumulator");
let (accumulator, as_proof) = {
let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(Vec::new());
let accumulator =
As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng)
.unwrap();
.map_err(|_| AggregationError::ProofCreate)?;
(accumulator, transcript.finalize())
};
@@ -260,12 +287,12 @@ impl AggregationCircuit {
.map(fe_to_limbs::<_, _, LIMBS, BITS>)
.concat();
Self {
Ok(Self {
svk,
snarks: snarks.into_iter().map_into().collect(),
instances,
as_proof: Value::known(as_proof),
}
})
}
/// Accumulator indices used in generating verifier.
@@ -331,7 +358,7 @@ impl Circuit<Fr> for AggregationCircuit {
let ecc_chip = config.ecc_chip();
let loader = Halo2Loader::new(ecc_chip, ctx);
let KzgAccumulator { lhs, rhs } =
aggregate(&self.svk, &loader, &self.snarks, self.as_proof());
aggregate(&self.svk, &loader, &self.snarks, self.as_proof())?;
let lhs = lhs.assigned().clone();
let rhs = rhs.assigned().clone();
@@ -355,10 +382,14 @@ impl Circuit<Fr> for AggregationCircuit {
}
/// Create proof and instance variables for the application snark
pub fn gen_application_snark(params: &ParamsKZG<Bn256>, data: &ModelInput, args: &Cli) -> Snark {
let (circuit, public_inputs) = prepare_circuit_and_public_input::<Fr>(data, &args);
pub fn gen_application_snark(
params: &ParamsKZG<Bn256>,
data: &ModelInput,
args: &Cli,
) -> Result<Snark, Box<dyn Error>> {
let (circuit, public_inputs) = prepare_circuit_and_public_input::<Fr>(data, args)?;
let pk = gen_pk(params, &circuit);
let pk = gen_pk(params, &circuit)?;
let number_instance = public_inputs[0].len();
trace!("number_instance {:?}", number_instance);
let protocol = compile(
@@ -377,8 +408,8 @@ pub fn gen_application_snark(params: &ParamsKZG<Bn256>, data: &ModelInput, args:
_,
PoseidonTranscript<NativeLoader, _>,
PoseidonTranscript<NativeLoader, _>,
>(params, &pk, circuit, pi_inner.clone());
Snark::new(protocol, pi_inner, proof)
>(params, &pk, circuit, pi_inner.clone())?;
Ok(Snark::new(protocol, pi_inner, proof))
}
/// Create aggregation EVM verifier bytecode
@@ -387,7 +418,7 @@ pub fn gen_aggregation_evm_verifier(
vk: &VerifyingKey<G1Affine>,
num_instance: Vec<usize>,
accumulator_indices: Vec<(usize, usize)>,
) -> Vec<u8> {
) -> Result<Vec<u8>, AggregationError> {
let svk = params.get_g()[0].into();
let dk = (params.g2(), params.s_g2()).into();
let protocol = compile(
@@ -400,37 +431,40 @@ pub fn gen_aggregation_evm_verifier(
let loader = EvmLoader::new::<Fq, Fr>();
let protocol = protocol.loaded(&loader);
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader.clone());
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader);
let instances = transcript.load_instances(num_instance);
let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap();
Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap();
let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript)
.map_err(|_| AggregationError::ProofRead)?;
Plonk::verify(&svk, &dk, &protocol, &instances, &proof)
.map_err(|_| AggregationError::ProofVerify)?;
evm::compile_yul(&loader.yul_code())
Ok(evm::compile_yul(&loader.yul_code()))
}
/// Verify by executing bytecode with instance variables and proof as input
pub fn evm_verify(deployment_code: Vec<u8>, instances: Vec<Vec<Fr>>, proof: Vec<u8>) {
pub fn evm_verify(
deployment_code: Vec<u8>,
instances: Vec<Vec<Fr>>,
proof: Vec<u8>,
) -> Result<bool, Box<dyn Error>> {
let calldata = encode_calldata(&instances, &proof);
let success = {
let mut evm = ExecutorBuilder::default()
.with_gas_limit(u64::MAX.into())
.build(Backend::new(MultiFork::new().0, None));
let mut evm = ExecutorBuilder::default()
.with_gas_limit(u64::MAX.into())
.build(Backend::new(MultiFork::new().0, None));
let caller = Address::from_low_u64_be(0xfe);
let verifier = evm
.deploy(caller, deployment_code.into(), 0.into(), None)
.unwrap()
.address;
let result = evm
.call_raw(caller, verifier, calldata.into(), 0.into())
.unwrap();
let caller = Address::from_low_u64_be(0xfe);
let verifier = evm
.deploy(caller, deployment_code.into(), 0.into(), None)
.map_err(Box::new)?
.address;
let result = evm
.call_raw(caller, verifier, calldata.into(), 0.into())
.map_err(|_| Box::new(AggregationError::EVMRawExecution))?;
dbg!(result.gas_used);
dbg!(result.gas_used);
!result.reverted
};
assert!(success);
Ok(!result.reverted)
}
/// Generate a structured reference string for testing. Not secure, do not use in production.
@@ -439,9 +473,12 @@ pub fn gen_srs(k: u32) -> ParamsKZG<Bn256> {
}
/// Generate the proving key
pub fn gen_pk<C: Circuit<Fr>>(params: &ParamsKZG<Bn256>, circuit: &C) -> ProvingKey<G1Affine> {
let vk = keygen_vk(params, circuit).unwrap();
keygen_pk(params, vk, circuit).unwrap()
pub fn gen_pk<C: Circuit<Fr>>(
params: &ParamsKZG<Bn256>,
circuit: &C,
) -> Result<ProvingKey<G1Affine>, plonk::Error> {
let vk = keygen_vk(params, circuit)?;
keygen_pk(params, vk, circuit)
}
/// Generates proof for either application circuit (model) or aggregation circuit.
@@ -455,43 +492,41 @@ pub fn gen_kzg_proof<
pk: &ProvingKey<G1Affine>,
circuit: C,
instances: Vec<Vec<Fr>>,
) -> Vec<u8> {
) -> Result<Vec<u8>, Box<dyn Error>> {
MockProver::run(params.k(), &circuit, instances.clone())
.unwrap()
.map_err(Box::new)?
.assert_satisfied();
let instances = instances
.iter()
.map(|instances| instances.as_slice())
.collect_vec();
let proof = {
let mut transcript = TW::init(Vec::new());
create_proof::<KZGCommitmentScheme<Bn256>, ProverGWC<_>, _, _, TW, _>(
params,
pk,
&[circuit],
&[instances.as_slice()],
OsRng,
&mut transcript,
)
.unwrap();
transcript.finalize()
};
let mut proof = TW::init(Vec::new());
create_proof::<KZGCommitmentScheme<Bn256>, ProverGWC<_>, _, _, TW, _>(
params,
pk,
&[circuit],
&[instances.as_slice()],
OsRng,
&mut proof,
)
.map_err(Box::new)?;
let proof = proof.finalize();
let accept = {
let mut transcript = TR::init(Cursor::new(proof.clone()));
VerificationStrategy::<_, VerifierGWC<_>>::finalize(
verify_proof::<_, VerifierGWC<_>, _, TR, _>(
params.verifier_params(),
pk.get_vk(),
AccumulatorStrategy::new(params.verifier_params()),
&[instances.as_slice()],
&mut transcript,
)
.unwrap(),
)
};
assert!(accept);
let mut transcript = TR::init(Cursor::new(proof.clone()));
let verify = verify_proof::<_, VerifierGWC<_>, _, TR, _>(
params.verifier_params(),
pk.get_vk(),
AccumulatorStrategy::new(params.verifier_params()),
&[instances.as_slice()],
&mut transcript,
)
.map_err(Box::new)?;
proof
let accept = VerificationStrategy::<_, VerifierGWC<_>>::finalize(verify);
if !accept {
return Err(Box::new(AggregationError::KZGProofVerification));
}
Ok(proof)
}

View File

@@ -2,11 +2,11 @@
#[cfg(feature = "evm")]
pub mod aggregation;
use crate::abort;
use crate::commands::{data_path, Cli};
use crate::fieldutils::i32_to_felt;
use crate::graph::{utilities::vector_to_quantized, Model, ModelCircuit};
use crate::tensor::{Tensor, TensorType};
use halo2_proofs::arithmetic::FieldExt;
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
@@ -15,12 +15,12 @@ use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{
Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer,
};
use halo2_proofs::{arithmetic::FieldExt, dev::VerifyFailure};
use log::{error, info, trace};
use log::{info, trace};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::marker::PhantomData;
use std::ops::Deref;
use std::path::PathBuf;
@@ -49,180 +49,95 @@ pub struct Proof {
impl Proof {
/// Saves the Proof to a specified `proof_path`.
pub fn save(&self, proof_path: &PathBuf) {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
abort!("failed to convert proof json to string {:?}", e);
}
};
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
let serialized = serde_json::to_string(&self).map_err(Box::<dyn Error>::from)?;
let mut file = std::fs::File::create(proof_path).expect("create failed");
file.write_all(serialized.as_bytes()).expect("write failed");
let mut file = std::fs::File::create(proof_path).map_err(Box::<dyn Error>::from)?;
file.write_all(serialized.as_bytes())
.map_err(Box::<dyn Error>::from)
}
/// Load a json serialized proof from the provided path.
pub fn load(proof_path: &PathBuf) -> Self {
let mut file = match File::open(proof_path) {
Ok(f) => f,
Err(e) => {
abort!("failed to open proof file {:?}", e);
}
};
pub fn load(proof_path: &PathBuf) -> Result<Self, Box<dyn Error>> {
let mut file = File::open(proof_path).map_err(Box::<dyn Error>::from)?;
let mut data = String::new();
match file.read_to_string(&mut data) {
Ok(_) => {}
Err(e) => {
abort!("failed to read file {:?}", e);
}
};
serde_json::from_str(&data).expect("JSON was not well-formatted")
file.read_to_string(&mut data)
.map_err(Box::<dyn Error>::from)?;
serde_json::from_str(&data).map_err(Box::<dyn Error>::from)
}
}
/// Helper function to print helpful error messages after verification has failed.
pub fn parse_prover_errors(f: &VerifyFailure) {
match f {
VerifyFailure::Lookup {
name,
location,
lookup_index,
} => {
error!("lookup {:?} is out of range, try increasing 'bits' or reducing 'scale' ({} and lookup index {}).",
name, location, lookup_index);
}
VerifyFailure::ConstraintNotSatisfied {
constraint,
location,
cell_values: _,
} => {
error!("{} was not satisfied {}).", constraint, location);
}
VerifyFailure::ConstraintPoisoned { constraint } => {
error!("constraint {:?} was poisoned", constraint);
}
VerifyFailure::Permutation { column, location } => {
error!(
"permutation did not preserve column cell value (try increasing 'scale') ({} {}).",
column, location
);
}
VerifyFailure::CellNotAssigned {
gate,
region,
gate_offset,
column,
offset,
} => {
error!(
"Unnassigned value in {} ({}) and {} ({:?}, {})",
gate, region, gate_offset, column, offset
);
}
}
}
type CircuitInputs<F> = (ModelCircuit<F>, Vec<Tensor<i32>>);
/// Initialize the model circuit and quantize the provided float inputs from the provided `ModelInput`.
pub fn prepare_circuit_and_public_input<F: FieldExt>(
data: &ModelInput,
args: &Cli,
) -> (ModelCircuit<F>, Vec<Tensor<i32>>) {
let model = Model::from_ezkl_conf(args.clone());
) -> Result<CircuitInputs<F>, Box<dyn Error>> {
let model = Model::from_ezkl_conf(args.clone())?;
let out_scales = model.get_output_scales();
let circuit = prepare_circuit(data, args);
let circuit = prepare_circuit(data, args)?;
// quantize the supplied data using the provided scale.
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
let mut public_inputs = vec![];
if model.visibility.input.is_public() {
let mut res = data
.input_data
.iter()
.map(
|v| match vector_to_quantized(v, &Vec::from([v.len()]), 0.0, model.scale) {
Ok(q) => q,
Err(e) => {
abort!("failed to quantize vector {:?}", e);
}
},
)
.collect();
public_inputs.append(&mut res);
for v in data.input_data.iter() {
let t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, model.scale)?;
public_inputs.push(t);
}
}
if model.visibility.output.is_public() {
let mut res = data
.output_data
.iter()
.enumerate()
.map(|(idx, v)| {
match vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx]) {
Ok(q) => q,
Err(e) => {
abort!("failed to quantize vector {:?}", e);
}
}
})
.collect();
public_inputs.append(&mut res);
for (idx, v) in data.output_data.iter().enumerate() {
let t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx])?;
public_inputs.push(t);
}
}
info!(
"public inputs lengths: {:?}",
"public inputs lengths: {:?}",
public_inputs
.iter()
.map(|i| i.len())
.collect::<Vec<usize>>()
);
trace!("{:?}", public_inputs);
(circuit, public_inputs)
Ok((circuit, public_inputs))
}
/// Initialize the model circuit
pub fn prepare_circuit<F: FieldExt>(data: &ModelInput, args: &Cli) -> ModelCircuit<F> {
pub fn prepare_circuit<F: FieldExt>(
data: &ModelInput,
args: &Cli,
) -> Result<ModelCircuit<F>, Box<dyn Error>> {
// quantize the supplied data using the provided scale.
let inputs = data
.input_data
.iter()
.zip(data.input_shapes.clone())
.map(|(i, s)| match vector_to_quantized(i, &s, 0.0, args.scale) {
Ok(q) => q,
Err(e) => {
abort!("failed to quantize vector {:?}", e);
}
})
.collect();
let mut inputs: Vec<Tensor<i32>> = vec![];
for (input, shape) in data.input_data.iter().zip(data.input_shapes.clone()) {
let t = vector_to_quantized(input, &shape, 0.0, args.scale)?;
inputs.push(t);
}
ModelCircuit::<F> {
Ok(ModelCircuit::<F> {
inputs,
_marker: PhantomData,
}
})
}
/// Deserializes the required inputs to a model at path `datapath` to a [ModelInput] struct.
pub fn prepare_data(datapath: String) -> ModelInput {
let mut file = match File::open(data_path(datapath)) {
Ok(t) => t,
Err(e) => {
abort!("failed to open data file {:?}", e);
}
};
pub fn prepare_data(datapath: String) -> Result<ModelInput, Box<dyn Error>> {
let mut file = File::open(data_path(datapath)).map_err(Box::<dyn Error>::from)?;
let mut data = String::new();
match file.read_to_string(&mut data) {
Ok(_) => {}
Err(e) => {
abort!("failed to read file {:?}", e);
}
};
let data: ModelInput = serde_json::from_str(&data).expect("JSON was not well-formatted");
data
file.read_to_string(&mut data)
.map_err(Box::<dyn Error>::from)?;
serde_json::from_str(&data).map_err(Box::<dyn Error>::from)
}
/// Creates a [VerifyingKey] and [ProvingKey] for a [ModelCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`).
pub fn create_keys<Scheme: CommitmentScheme, F: FieldExt + TensorType>(
circuit: &ModelCircuit<F>,
params: &'_ Scheme::ParamsProver,
) -> ProvingKey<Scheme::Curve>
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
@@ -232,12 +147,12 @@ where
// Initialize the proving key
let now = Instant::now();
trace!("preparing VK");
let vk = keygen_vk(params, &empty_circuit).expect("keygen_vk should not fail");
let vk = keygen_vk(params, &empty_circuit)?;
info!("VK took {}", now.elapsed().as_secs());
let now = Instant::now();
let pk = keygen_pk(params, vk, &empty_circuit).expect("keygen_pk should not fail");
let pk = keygen_pk(params, vk, &empty_circuit)?;
info!("PK took {}", now.elapsed().as_secs());
pk
Ok(pk)
}
/// a wrapper around halo2's create_proof
@@ -251,7 +166,7 @@ pub fn create_proof_model<
public_inputs: &[Tensor<i32>],
params: &'params Scheme::ParamsProver,
pk: &ProvingKey<Scheme::Curve>,
) -> (Proof, Vec<Vec<usize>>)
) -> Result<(Proof, Vec<Vec<usize>>), halo2_proofs::plonk::Error>
where
ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
@@ -282,8 +197,7 @@ where
instances,
&mut rng,
&mut transcript,
)
.expect("proof generation should not fail");
)?;
let proof = transcript.finalize();
info!("Proof took {}", now.elapsed().as_secs());
@@ -295,7 +209,7 @@ where
proof,
};
(checkable_pf, dims)
Ok((checkable_pf, dims))
}
/// A wrapper around halo2's verify_proof
@@ -310,7 +224,7 @@ pub fn verify_proof_model<
params: &'params Scheme::ParamsVerifier,
vk: &VerifyingKey<Scheme::Curve>,
strategy: Strategy,
) -> bool
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
where
ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
@@ -332,60 +246,57 @@ where
let now = Instant::now();
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof.proof[..]);
let result =
verify_proof::<Scheme, V, _, _, _>(params, vk, strategy, instances, &mut transcript)
.is_ok();
info!("verify took {}", now.elapsed().as_secs());
result
verify_proof::<Scheme, V, _, _, _>(params, vk, strategy, instances, &mut transcript)
}
/// Loads a [VerifyingKey] at `path`.
pub fn load_vk<Scheme: CommitmentScheme, F: FieldExt + TensorType>(
path: PathBuf,
params: &'_ Scheme::ParamsVerifier,
) -> VerifyingKey<Scheme::Curve>
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
where
ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
info!("loading verification key from {:?}", path);
let f = match File::open(path) {
Ok(f) => f,
Err(e) => {
abort!("failed to load vk {}", e);
}
};
let f = File::open(path).map_err(Box::<dyn Error>::from)?;
let mut reader = BufReader::new(f);
VerifyingKey::<Scheme::Curve>::read::<_, ModelCircuit<F>>(&mut reader, params).unwrap()
VerifyingKey::<Scheme::Curve>::read::<_, ModelCircuit<F>>(&mut reader, params)
.map_err(Box::<dyn Error>::from)
}
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
pub fn load_params<Scheme: CommitmentScheme>(path: PathBuf) -> Scheme::ParamsVerifier {
pub fn load_params<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
info!("loading params from {:?}", path);
let f = match File::open(path) {
Ok(f) => f,
Err(e) => {
abort!("failed to load params {}", e);
}
};
let f = File::open(path).map_err(Box::<dyn Error>::from)?;
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).unwrap()
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
}
/// Saves a [VerifyingKey] to `path`.
pub fn save_vk<Scheme: CommitmentScheme>(path: &PathBuf, vk: &VerifyingKey<Scheme::Curve>) {
pub fn save_vk<Scheme: CommitmentScheme>(
path: &PathBuf,
vk: &VerifyingKey<Scheme::Curve>,
) -> Result<(), io::Error> {
info!("saving verification key 💾");
let f = File::create(path).unwrap();
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer).unwrap();
writer.flush().unwrap();
vk.write(&mut writer)?;
writer.flush()?;
Ok(())
}
/// Saves [CommitmentScheme] parameters to `path`.
pub fn save_params<Scheme: CommitmentScheme>(path: &PathBuf, params: &'_ Scheme::ParamsVerifier) {
pub fn save_params<Scheme: CommitmentScheme>(
path: &PathBuf,
params: &'_ Scheme::ParamsVerifier,
) -> Result<(), io::Error> {
info!("saving parameters 💾");
let f = File::create(path).unwrap();
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
params.write(&mut writer).unwrap();
writer.flush().unwrap();
params.write(&mut writer)?;
writer.flush()?;
Ok(())
}

View File

@@ -18,11 +18,26 @@ use halo2_proofs::{
};
use itertools::Itertools;
use std::cmp::max;
use std::error::Error;
use std::fmt::Debug;
use std::iter::Iterator;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Range;
use thiserror::Error;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
/// Shape mismatch in a operation
#[error("dimension mismatch in tensor op: {0}")]
DimMismatch(String),
/// Shape when instantiating
#[error("dimensionality error when manipulating a tensor")]
DimError,
/// wrong method was called on a tensor-like struct
#[error("wrong method called")]
WrongMethod,
}
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
@@ -173,10 +188,6 @@ impl TensorType for halo2curves::bn256::Fr {
}
}
/// A wrapper for tensor related errors.
#[derive(Debug)]
pub struct TensorError(String);
/// A generic multi-dimensional array representation of a Tensor.
/// The `inner` attribute contains a vector of values whereas `dims` corresponds to the dimensionality of the array
/// and as such determines how we index, query for values, or slice a Tensor.
@@ -217,13 +228,24 @@ impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
}
}
impl<I: Iterator, T: Clone + TensorType + From<I::Item>> From<I> for Tensor<T>
impl<I: Iterator> From<I> for Tensor<I::Item>
where
I::Item: Clone + TensorType,
Vec<T>: FromIterator<I::Item>,
I::Item: TensorType + Clone,
Vec<I::Item>: FromIterator<I::Item>,
{
fn from(value: I) -> Tensor<T> {
let data: Vec<T> = value.collect::<Vec<T>>();
fn from(value: I) -> Tensor<I::Item> {
let data: Vec<I::Item> = value.collect::<Vec<I::Item>>();
Tensor::new(Some(&data), &[data.len()]).unwrap()
}
}
impl<T> FromIterator<T> for Tensor<T>
where
T: TensorType + Clone,
Vec<T>: FromIterator<T>,
{
fn from_iter<I: IntoIterator<Item = T>>(value: I) -> Tensor<T> {
let data: Vec<I::Item> = value.into_iter().collect::<Vec<I::Item>>();
Tensor::new(Some(&data), &[data.len()]).unwrap()
}
}
@@ -292,9 +314,7 @@ impl<T: Clone + TensorType> Tensor<T> {
match values {
Some(v) => {
if total_dims != v.len() {
return Err(TensorError(
"length of values array is not equal to tensor total elements".to_string(),
));
return Err(TensorError::DimError);
}
Ok(Tensor {
inner: Vec::from(v),
@@ -355,10 +375,12 @@ impl<T: Clone + TensorType> Tensor<T> {
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[1, 2]), &[2]).unwrap();
///
/// assert_eq!(a.get_slice(&[0..2]), b);
/// assert_eq!(a.get_slice(&[0..2]).unwrap(), b);
/// ```
pub fn get_slice(&self, indices: &[Range<usize>]) -> Tensor<T> {
assert!(self.dims.len() >= indices.len());
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<Tensor<T>, TensorError> {
if self.dims.len() < indices.len() {
return Err(TensorError::DimError);
}
let mut res = Vec::new();
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
@@ -377,7 +399,7 @@ impl<T: Clone + TensorType> Tensor<T> {
}
}
Tensor::new(Some(&res), &dims).unwrap()
Tensor::new(Some(&res), &dims)
}
/// Get the array index from rows / columns indices.
@@ -446,16 +468,22 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors and enumerates
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::{Tensor, TensorError};
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.enum_map(|i, x| i32::pow(x + i as i32, 2)).unwrap();
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn enum_map<F: FnMut(usize, T) -> G, G: TensorType>(
pub fn enum_map<F: FnMut(usize, T) -> Result<G, E>, G: TensorType, E: Error>(
&self,
mut f: F,
) -> Result<Tensor<G>, TensorError> {
let mut t = Tensor::from(self.inner.iter().enumerate().map(|(i, e)| f(i, e.clone())));
) -> Result<Tensor<G>, E> {
let vec: Result<Vec<G>, E> = self
.inner
.iter()
.enumerate()
.map(|(i, e)| f(i, e.clone()))
.collect();
let mut t: Tensor<G> = Tensor::from(vec?.iter().cloned());
t.reshape(self.dims());
Ok(t)
}
@@ -540,6 +568,6 @@ mod tests {
fn tensor_slice() {
let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]), b);
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}
}

View File

@@ -1,3 +1,4 @@
use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
pub use std::ops::{Add, Div, Mul, Sub};
@@ -23,17 +24,20 @@ pub use std::ops::{Add, Div, Mul, Sub};
/// Some(&[0, 0]),
/// &[2],
/// ).unwrap();
/// let result = affine(&vec![x, k, b]);
/// let result = affine(&vec![x, k, b]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[26, 7, 11, 3, 15, 3, 7, 2]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn affine<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &Vec<Tensor<T>>,
) -> Tensor<T> {
assert_eq!(inputs.len(), 3);
) -> Result<Tensor<T>, TensorError> {
let (mut input, kernel, bias) = (inputs[0].clone(), inputs[1].clone(), inputs[2].clone());
assert_eq!(bias.dims()[0], kernel.dims()[0]);
assert_eq!(input.dims()[0], kernel.dims()[1]);
if (inputs.len() != 3)
|| (bias.dims()[0] != kernel.dims()[0])
|| (input.dims()[0] != kernel.dims()[1])
{
return Err(TensorError::DimMismatch("affine".to_string()));
}
// does matrix to vector multiplication
if input.dims().len() == 1 {
@@ -49,9 +53,9 @@ pub fn affine<T: TensorType + Mul<Output = T> + Add<Output = T>>(
for i in 0..kernel_dims[0] {
for j in 0..input_dims[1] {
let prod = dot(&vec![
&kernel.get_slice(&[i..i + 1]),
&input.get_slice(&[0..input_dims[0], j..j + 1]),
]);
&kernel.get_slice(&[i..i + 1])?,
&input.get_slice(&[0..input_dims[0], j..j + 1])?,
])?;
output.set(&[i, j], prod[0].clone() + bias[i].clone());
}
}
@@ -59,23 +63,50 @@ pub fn affine<T: TensorType + Mul<Output = T> + Add<Output = T>>(
if output.dims()[1] == 1 {
output.flatten();
}
output
Ok(output)
}
/// Scales and shifts a tensor.
/// Given inputs (x,k,b) computes k*x + b elementwise
/// # Arguments
///
/// * `inputs` - Vector of tensors of length 2
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scale_and_shift;
///
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let b = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = scale_and_shift(&vec![x, k, b]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[6, 2, 6, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn scale_and_shift<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &Vec<Tensor<T>>,
) -> Tensor<T> {
assert_eq!(inputs.len(), 3);
) -> Result<Tensor<T>, TensorError> {
if (inputs.len() != 3)
|| (inputs[1].dims() != inputs[2].dims())
|| (inputs[0].dims() != inputs[1].dims())
{
return Err(TensorError::DimMismatch("scale and shift".to_string()));
}
let (input, kernel, bias) = (inputs[0].clone(), inputs[1].clone(), inputs[2].clone());
assert_eq!(input.dims(), kernel.dims());
assert_eq!(bias.dims(), kernel.dims());
let mut output: Tensor<T> = input;
for (i, bias_i) in bias.iter().enumerate() {
output[i] = kernel[i].clone() * output[i].clone() + bias_i.clone()
}
output
Ok(output)
}
/// Matrix multiplies two 2D tensors.
@@ -95,20 +126,20 @@ pub fn scale_and_shift<T: TensorType + Mul<Output = T> + Add<Output = T>>(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = matmul(&vec![k, x]);
/// let result = matmul(&vec![k, x]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[26, 7, 11, 3, 15, 3, 7, 2]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn matmul<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &Vec<Tensor<T>>,
) -> Tensor<T> {
assert_eq!(inputs.len(), 2);
) -> Result<Tensor<T>, TensorError> {
let (a, b) = (inputs[0].clone(), inputs[1].clone());
assert_eq!(a.dims()[a.dims().len() - 1], b.dims()[a.dims().len() - 2]);
assert_eq!(
a.dims()[0..a.dims().len() - 2],
b.dims()[0..a.dims().len() - 2]
);
if (inputs.len() != 2)
|| (a.dims()[a.dims().len() - 1] != b.dims()[a.dims().len() - 2])
|| (a.dims()[0..a.dims().len() - 2] != b.dims()[0..a.dims().len() - 2])
{
return Err(TensorError::DimMismatch("matmul".to_string()));
}
let mut dims = Vec::from(&a.dims()[0..a.dims().len() - 2]);
dims.push(a.dims()[a.dims().len() - 2]);
@@ -128,11 +159,11 @@ pub fn matmul<T: TensorType + Mul<Output = T> + Add<Output = T>>(
.map(|&d| d..(d + 1))
.collect::<Vec<_>>();
col[coord.len() - 2] = 0..b.dims()[coord.len() - 2];
let prod = dot(&vec![&a.get_slice(&row[0..]), &b.get_slice(&col[0..])]);
let prod = dot(&vec![&a.get_slice(&row[0..])?, &b.get_slice(&col[0..])?])?;
output.set(&coord, prod[0].clone());
}
output
Ok(output)
}
/// Adds multiple tensors.
@@ -151,17 +182,19 @@ pub fn matmul<T: TensorType + Mul<Output = T> + Add<Output = T>>(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = add(&vec![x, k]);
/// let result = add(&vec![x, k]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Result<Tensor<T>, TensorError> {
// determines if we're multiplying by a 1D const
if t.len() == 2 && t[1].dims().len() == 1 && t[1].dims()[0] == 1 {
return const_add(&t[0], t[1][0].clone());
}
for e in t.iter() {
assert_eq!(t[0].dims(), e.dims());
if t[0].dims() != e.dims() {
return Err(TensorError::DimMismatch("add".to_string()));
}
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -172,7 +205,7 @@ pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
}
}
output
Ok(output)
}
/// Elementwise adds a tensor with a const element.
@@ -189,11 +222,14 @@ pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
/// &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = const_add(&x, k);
/// let result = const_add(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn const_add<T: TensorType + Add<Output = T>>(a: &Tensor<T>, b: T) -> Tensor<T> {
pub fn const_add<T: TensorType + Add<Output = T>>(
a: &Tensor<T>,
b: T,
) -> Result<Tensor<T>, TensorError> {
// calculate value of output
let mut output: Tensor<T> = a.clone();
@@ -201,7 +237,7 @@ pub fn const_add<T: TensorType + Add<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
output[i] = output[i].clone() + b.clone();
}
output
Ok(output)
}
/// Subtracts multiple tensors.
@@ -221,18 +257,20 @@ pub fn const_add<T: TensorType + Add<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = sub(&vec![x, k]);
/// let result = sub(&vec![x, k]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Result<Tensor<T>, TensorError> {
// determines if we're multiplying by a 1D const
if t.len() == 2 && t[1].dims().len() == 1 && t[1].dims()[0] == 1 {
return const_sub(&t[0], t[1][0].clone());
}
for e in t.iter() {
assert_eq!(t[0].dims(), e.dims());
if t[0].dims() != e.dims() {
return Err(TensorError::DimMismatch("sub".to_string()));
}
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -243,7 +281,7 @@ pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
}
}
output
Ok(output)
}
/// Elementwise subtracts a tensor with a const element.
@@ -260,11 +298,14 @@ pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
/// &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = const_sub(&x, k);
/// let result = const_sub(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn const_sub<T: TensorType + Sub<Output = T>>(a: &Tensor<T>, b: T) -> Tensor<T> {
pub fn const_sub<T: TensorType + Sub<Output = T>>(
a: &Tensor<T>,
b: T,
) -> Result<Tensor<T>, TensorError> {
// calculate value of output
let mut output: Tensor<T> = a.clone();
@@ -272,7 +313,7 @@ pub fn const_sub<T: TensorType + Sub<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
output[i] = output[i].clone() - b.clone();
}
output
Ok(output)
}
/// Elementwise multiplies two tensors.
@@ -292,18 +333,20 @@ pub fn const_sub<T: TensorType + Sub<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = mult(&vec![x, k]);
/// let result = mult(&vec![x, k]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Result<Tensor<T>, TensorError> {
// determines if we're multiplying by a 1D const
if t.len() == 2 && t[1].dims().len() == 1 && t[1].dims()[0] == 1 {
return const_mult(&t[0], t[1][0].clone());
}
for e in t.iter() {
assert_eq!(t[0].dims(), e.dims());
if t[0].dims() != e.dims() {
return Err(TensorError::DimMismatch("mult".to_string()));
}
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -314,7 +357,7 @@ pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
}
}
output
Ok(output)
}
/// Elementwise divide a tensor with another tensor.
@@ -334,19 +377,24 @@ pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = div(x, y);
/// let result = div(x, y).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn div<T: TensorType + Div<Output = T>>(t: Tensor<T>, d: Tensor<T>) -> Tensor<T> {
assert_eq!(t.dims(), d.dims());
pub fn div<T: TensorType + Div<Output = T>>(
t: Tensor<T>,
d: Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
if t.dims() != d.dims() {
return Err(TensorError::DimMismatch("div".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t;
for (i, d_i) in d.iter().enumerate() {
output[i] = output[i].clone() / d_i.clone()
}
output
Ok(output)
}
/// Elementwise multiplies a tensor with a const element.
@@ -363,11 +411,14 @@ pub fn div<T: TensorType + Div<Output = T>>(t: Tensor<T>, d: Tensor<T>) -> Tenso
/// &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = const_mult(&x, k);
/// let result = const_mult(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn const_mult<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, b: T) -> Tensor<T> {
pub fn const_mult<T: TensorType + Mul<Output = T>>(
a: &Tensor<T>,
b: T,
) -> Result<Tensor<T>, TensorError> {
// calculate value of output
let mut output: Tensor<T> = a.clone();
@@ -375,7 +426,7 @@ pub fn const_mult<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, b: T) -> Tenso
output[i] = output[i].clone() * b.clone();
}
output
Ok(output)
}
/// Rescale a tensor with a const integer (similar to const_mult).
@@ -392,11 +443,14 @@ pub fn const_mult<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, b: T) -> Tenso
/// &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = rescale(&x, k);
/// let result = rescale(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn rescale<T: TensorType + Add<Output = T>>(a: &Tensor<T>, mult: usize) -> Tensor<T> {
pub fn rescale<T: TensorType + Add<Output = T>>(
a: &Tensor<T>,
mult: usize,
) -> Result<Tensor<T>, TensorError> {
// calculate value of output
let mut output: Tensor<T> = a.clone();
for (i, a_i) in a.iter().enumerate() {
@@ -404,7 +458,7 @@ pub fn rescale<T: TensorType + Add<Output = T>>(a: &Tensor<T>, mult: usize) -> T
output[i] = output[i].clone() + a_i.clone();
}
}
output
Ok(output)
}
/// Elementwise raise a tensor to the nth power.
@@ -420,11 +474,14 @@ pub fn rescale<T: TensorType + Add<Output = T>>(a: &Tensor<T>, mult: usize) -> T
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = pow(&x, 3);
/// let result = pow(&x, 3).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn pow<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, pow: usize) -> Tensor<T> {
pub fn pow<T: TensorType + Mul<Output = T>>(
a: &Tensor<T>,
pow: usize,
) -> Result<Tensor<T>, TensorError> {
// calculate value of output
let mut output: Tensor<T> = a.clone();
for (i, a_i) in a.iter().enumerate() {
@@ -432,7 +489,7 @@ pub fn pow<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, pow: usize) -> Tensor
output[i] = output[i].clone() * a_i.clone();
}
}
output
Ok(output)
}
/// Sums a tensor.
@@ -448,15 +505,15 @@ pub fn pow<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, pow: usize) -> Tensor
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = sum(&x);
/// let result = sum(&x).unwrap();
/// let expected = 21;
/// assert_eq!(result[0], expected);
/// ```
pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Tensor<T> {
pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
// calculate value of output
let mut res = T::zero().unwrap();
let _ = a.map(|a_i| res = res.clone() + a_i);
Tensor::new(Some(&[res]), &[1]).unwrap()
Tensor::new(Some(&[res]), &[1])
}
/// Applies convolution over a 3D tensor of shape C x H x W (and adds a bias).
@@ -482,7 +539,7 @@ pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Tensor<T> {
/// Some(&[0]),
/// &[1],
/// ).unwrap();
/// let result = convolution::<i32>(&vec![x, k, b], (0, 0), (1, 1));
/// let result = convolution::<i32>(&vec![x, k, b], (0, 0), (1, 1)).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[31, 16, 8, 26]), &[1, 2, 2]).unwrap();
/// assert_eq!(result, expected);
/// ```
@@ -490,17 +547,22 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &Vec<Tensor<T>>,
padding: (usize, usize),
stride: (usize, usize),
) -> Tensor<T> {
) -> Result<Tensor<T>, TensorError> {
let has_bias = inputs.len() == 3;
let (image, kernel) = (inputs[0].clone(), inputs[1].clone());
assert_eq!(image.dims().len(), 3);
assert_eq!(kernel.dims().len(), 4);
assert_eq!(image.dims()[0], kernel.dims()[1]);
if (image.dims().len() != 3)
|| (kernel.dims().len() != 4)
|| (image.dims()[0] != kernel.dims()[1])
{
return Err(TensorError::DimMismatch("conv".to_string()));
}
if has_bias {
let bias = inputs[2].clone();
assert_eq!(bias.dims().len(), 1);
assert_eq!(bias.dims()[0], kernel.dims()[0]);
if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) {
return Err(TensorError::DimMismatch("conv bias".to_string()));
}
}
let image_dims = image.dims();
@@ -515,7 +577,7 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
let (image_height, image_width) = (image_dims[1], image_dims[2]);
let padded_image = pad::<T>(image.clone(), padding);
let padded_image = pad::<T>(image.clone(), padding)?;
let vert_slides = (image_height + 2 * padding.0 - kernel_height) / stride.0 + 1;
let horz_slides = (image_width + 2 * padding.1 - kernel_width) / stride.1 + 1;
@@ -530,13 +592,13 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
for k in 0..horz_slides {
let cs = k * stride.1;
let mut res = dot(&vec![
&kernel.get_slice(&[i..i + 1]).clone(),
&kernel.get_slice(&[i..i + 1])?.clone(),
&padded_image.get_slice(&[
0..input_channels,
rs..(rs + kernel_height),
cs..(cs + kernel_width),
]),
]);
])?,
])?;
if has_bias {
// increment result by the bias
@@ -547,7 +609,7 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
}
}
}
output
Ok(output)
}
/// Applies 2D sum pooling over a 3D tensor of shape C x H x W.
@@ -569,7 +631,7 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3],
/// ).unwrap();
/// let pooled = sumpool::<i32>(&x, (0, 0), (1, 1), (2, 2));
/// let pooled = sumpool::<i32>(&x, (0, 0), (1, 1), (2, 2)).unwrap();
/// let expected: Tensor<i32> = Tensor::<i32>::new(Some(&[11, 8, 8, 10]), &[1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -578,8 +640,10 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
padding: (usize, usize),
stride: (usize, usize),
kernel_shape: (usize, usize),
) -> Tensor<T> {
assert_eq!(image.dims().len(), 3);
) -> Result<Tensor<T>, TensorError> {
if image.dims().len() != 3 {
return Err(TensorError::DimMismatch("sumpool".to_string()));
}
let image_dims = image.dims();
let (image_channels, image_height, image_width) = (image_dims[0], image_dims[1], image_dims[2]);
@@ -587,7 +651,7 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
let (output_channels, kernel_height, kernel_width) =
(image_channels, kernel_shape.0, kernel_shape.1);
let padded_image = pad::<T>(image.clone(), padding);
let padded_image = pad::<T>(image.clone(), padding)?;
let vert_slides = (image_height + 2 * padding.0 - kernel_height) / stride.0 + 1;
let horz_slides = (image_width + 2 * padding.1 - kernel_width) / stride.1 + 1;
@@ -605,12 +669,12 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
i..i + 1,
rs..(rs + kernel_height),
cs..(cs + kernel_width),
]));
])?)?;
output.set(&[i, j, k], thesum[0].clone());
}
}
}
output
Ok(output)
}
/// Applies 2D max pooling over a 3D tensor of shape C x H x W.
@@ -632,7 +696,7 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3],
/// ).unwrap();
/// let pooled = max_pool2d::<i32>(&x, (0, 0), (1, 1), (2, 2));
/// let pooled = max_pool2d::<i32>(&x, (0, 0), (1, 1), (2, 2)).unwrap();
/// let expected: Tensor<i32> = Tensor::<i32>::new(Some(&[5, 4, 4, 6]), &[1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -641,14 +705,16 @@ pub fn max_pool2d<T: TensorType>(
padding: (usize, usize),
stride: (usize, usize),
pool_dims: (usize, usize),
) -> Tensor<T> {
) -> Result<Tensor<T>, TensorError> {
if image.dims().len() != 3 {
return Err(TensorError::DimMismatch("max_pool2d".to_string()));
}
let image_dims = image.dims();
assert_eq!(image_dims.len(), 3);
let input_channels = image_dims[0];
let (image_height, image_width) = (image_dims[1], image_dims[2]);
let padded_image = pad::<T>(image.clone(), padding);
let padded_image = pad::<T>(image.clone(), padding)?;
let horz_slides = (image_height + 2 * padding.0 - pool_dims.0) / stride.0 + 1;
let vert_slides = (image_width + 2 * padding.1 - pool_dims.1) / stride.1 + 1;
@@ -671,7 +737,7 @@ pub fn max_pool2d<T: TensorType>(
output.set(
&[i, j, k],
padded_image
.get_slice(&[i..(i + 1), rs..(rs + pool_dims.0), cs..(cs + pool_dims.1)])
.get_slice(&[i..(i + 1), rs..(rs + pool_dims.0), cs..(cs + pool_dims.1)])?
.into_iter()
.fold(None, fmax)
.unwrap(),
@@ -679,7 +745,7 @@ pub fn max_pool2d<T: TensorType>(
}
}
}
output
Ok(output)
}
/// Dot product of two tensors.
@@ -699,19 +765,20 @@ pub fn max_pool2d<T: TensorType>(
/// Some(&[5, 5, 10, -4, 2, -1, 2, 0, 1]),
/// &[1, 3, 3],
/// ).unwrap();
/// assert_eq!(dot(&vec![&x, &y])[0], 86);
/// assert_eq!(dot(&vec![&x, &y]).unwrap()[0], 86);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &Vec<&Tensor<T>>,
) -> Tensor<T> {
assert_eq!(inputs.len(), 2);
assert_eq!(inputs[0].clone().len(), inputs[1].clone().len());
) -> Result<Tensor<T>, TensorError> {
if (inputs.len() != 2) || (inputs[0].clone().len() != inputs[1].clone().len()) {
return Err(TensorError::DimMismatch("dot".to_string()));
}
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
let res = a
.iter()
.zip(b)
.fold(T::zero().unwrap(), |acc, (k, i)| acc + k.clone() * i);
Tensor::new(Some(&[res]), &[1]).unwrap()
Tensor::new(Some(&[res]), &[1])
}
/// Pads a 3D tensor of shape `C x H x W` to a tensor of shape `C x (H + 2xPADDING) x (W + 2xPADDING)` using 0 values.
@@ -728,15 +795,20 @@ pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3],
/// ).unwrap();
/// let result = pad::<i32>(x, (1, 1));
/// let result = pad::<i32>(x, (1, 1)).unwrap();
/// let expected = Tensor::<i32>::new(
/// Some(&[0, 0, 0, 0, 0, 0, 5, 2, 3, 0, 0, 0, 4, -1, 0, 0, 3, 1, 6, 0, 0, 0, 0, 0, 0]),
/// &[1, 5, 5],
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn pad<T: TensorType>(image: Tensor<T>, padding: (usize, usize)) -> Tensor<T> {
assert_eq!(image.dims().len(), 3);
pub fn pad<T: TensorType>(
image: Tensor<T>,
padding: (usize, usize),
) -> Result<Tensor<T>, TensorError> {
if image.dims().len() != 3 {
return Err(TensorError::DimMismatch("pad".to_string()));
}
let (channels, height, width) = (image.dims()[0], image.dims()[1], image.dims()[2]);
let padded_height = height + 2 * padding.0;
let padded_width = width + 2 * padding.1;
@@ -755,7 +827,7 @@ pub fn pad<T: TensorType>(image: Tensor<T>, padding: (usize, usize)) -> Tensor<T
}
output.reshape(&[channels, padded_height, padded_width]);
output
Ok(output)
}
// ---------------------------------------------------------------------------------------------------------

View File

@@ -73,35 +73,36 @@ impl<F: FieldExt + TensorType> ValTensor<F> {
}
/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> ValTensor<F> {
match self {
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, Box<dyn Error>> {
let slice = match self {
ValTensor::Value { inner: v, dims: _ } => {
let slice = v.get_slice(indices);
let slice = v.get_slice(indices)?;
ValTensor::Value {
inner: slice.clone(),
dims: slice.dims().to_vec(),
}
}
ValTensor::AssignedValue { inner: v, dims: _ } => {
let slice = v.get_slice(indices);
let slice = v.get_slice(indices)?;
ValTensor::AssignedValue {
inner: slice.clone(),
dims: slice.dims().to_vec(),
}
}
ValTensor::PrevAssigned { inner: v, dims: _ } => {
let slice = v.get_slice(indices);
let slice = v.get_slice(indices)?;
ValTensor::PrevAssigned {
inner: slice.clone(),
dims: slice.dims().to_vec(),
}
}
_ => unimplemented!(),
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(slice)
}
/// Sets the [ValTensor]'s shape.
pub fn reshape(&mut self, new_dims: &[usize]) {
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value { inner: v, dims: d } => {
v.reshape(new_dims);
@@ -116,13 +117,13 @@ impl<F: FieldExt + TensorType> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { dims: d, .. } => {
assert_eq!(
d.iter().product::<usize>(),
new_dims.iter().product::<usize>()
);
if d.iter().product::<usize>() != new_dims.iter().product::<usize>() {
return Err(Box::new(TensorError::DimError));
}
*d = new_dims.to_vec();
}
}
};
Ok(())
}
/// Calls `flatten` on the inner [Tensor].

View File

@@ -1,6 +1,4 @@
use super::*;
use crate::abort;
use log::error;
use std::cmp::min;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
/// The wrapper allows for `VarTensor`'s dimensions to differ from that of the inner (wrapped) columns.
@@ -184,7 +182,7 @@ impl VarTensor {
&self,
meta: &mut VirtualCells<'_, F>,
offset: usize,
) -> Result<Tensor<Expression<F>>, TensorError> {
) -> Result<Tensor<Expression<F>>, halo2_proofs::plonk::Error> {
match &self {
VarTensor::Fixed {
inner: fixed, dims, ..
@@ -224,87 +222,57 @@ impl VarTensor {
region: &mut Region<'_, F>,
offset: usize,
values: &ValTensor<F>,
) -> Result<Tensor<AssignedCell<F, F>>, TensorError> {
) -> Result<Tensor<AssignedCell<F, F>>, halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance {
inner: instance, ..
} => match &self {
VarTensor::Advice { inner: v, dims, .. } => {
let t = Tensor::new(None, dims).unwrap();
t.enum_map(|coord, _: usize| {
// this should never ever fail
let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
t.enum_map(|coord, _| {
let (x, y) = self.cartesian_coord(offset + coord);
match region.assign_advice_from_instance(
region.assign_advice_from_instance(
|| "pub input anchor",
*instance,
coord,
v[x],
y,
) {
Ok(v) => v,
Err(e) => {
abort!("failed to assign advice from instance {:?}", e);
}
}
)
})
}
_ => {
abort!("should be an advice");
}
_ => Err(halo2_proofs::plonk::Error::Synthesis),
},
ValTensor::Value { inner: v, dims: _ } => v.enum_map(|coord, k| match &self {
ValTensor::Value { inner: v, .. } => v.enum_map(|coord, k| match &self {
VarTensor::Fixed { inner: fixed, .. } => {
let (x, y) = self.cartesian_coord(offset + coord);
match region.assign_fixed(|| "k", fixed[x], y, || k) {
Ok(a) => a,
Err(e) => {
abort!("failed to assign ValTensor to VarTensor {:?}", e);
}
}
region.assign_fixed(|| "k", fixed[x], y, || k)
}
VarTensor::Advice { inner: advices, .. } => {
let (x, y) = self.cartesian_coord(offset + coord);
match region.assign_advice(|| "k", advices[x], y, || k) {
Ok(a) => a,
Err(e) => {
abort!("failed to assign ValTensor to VarTensor {:?}", e);
}
}
region.assign_advice(|| "k", advices[x], y, || k)
}
}),
ValTensor::PrevAssigned { inner: v, dims: _ } => {
v.enum_map(|coord, xcell| match &self {
VarTensor::Advice { inner: advices, .. } => {
let (x, y) = self.cartesian_coord(offset + coord);
match xcell.copy_advice(|| "k", region, advices[x], y) {
Ok(a) => a,
Err(e) => {
abort!("failed to copy ValTensor to VarTensor {:?}", e);
}
}
}
_ => {
unimplemented!()
}
})
}
ValTensor::AssignedValue { inner: v, dims: _ } => v.enum_map(|coord, k| match &self {
ValTensor::PrevAssigned { inner: v, .. } => v.enum_map(|coord, xcell| match &self {
VarTensor::Advice { inner: advices, .. } => {
let (x, y) = self.cartesian_coord(offset + coord);
xcell.copy_advice(|| "k", region, advices[x], y)
}
_ => Err(halo2_proofs::plonk::Error::Synthesis),
}),
ValTensor::AssignedValue { inner: v, .. } => v.enum_map(|coord, k| match &self {
VarTensor::Fixed { inner: fixed, .. } => {
let (x, y) = self.cartesian_coord(offset + coord);
match region.assign_fixed(|| "k", fixed[x], y, || k) {
Ok(a) => a.evaluate(),
Err(e) => {
abort!("failed to assign ValTensor to VarTensor {:?}", e);
}
Ok(a) => Ok(a.evaluate()),
Err(e) => Err(e),
}
}
VarTensor::Advice { inner: advices, .. } => {
let (x, y) = self.cartesian_coord(offset + coord);
match region.assign_advice(|| "k", advices[x], y, || k) {
Ok(a) => a.evaluate(),
Err(e) => {
abort!("failed to assign ValTensor to VarTensor {:?}", e);
}
Ok(a) => Ok(a.evaluate()),
Err(e) => Err(e),
}
}
}),

View File

@@ -3,7 +3,8 @@ use std::env::var;
use std::process::Command;
lazy_static! {
static ref CARGO_TARGET_DIR: String = var("CARGO_TARGET_DIR").unwrap_or("./target".to_string());
static ref CARGO_TARGET_DIR: String =
var("CARGO_TARGET_DIR").unwrap_or_else(|_| "./target".to_string());
}
#[cfg(test)]
@@ -183,12 +184,7 @@ fn neg_mock(example_name: String, counter_example: String) {
// Mock prove (fast, but does not cover some potential issues)
fn run_example(example_name: String) {
let status = Command::new("cargo")
.args([
"run",
"--release",
"--example",
format!("{}", example_name).as_str(),
])
.args(["run", "--release", "--example", example_name.as_str()])
.status()
.expect("failed to execute process");
assert!(status.success());