mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-10 06:48:01 -05:00)
chore: error bubbling (#93)
Co-authored-by: jason <jason.morton@gmail.com>
.github/workflows/rust.yml (vendored): 14 changes
@@ -6,12 +6,10 @@ on:
   pull_request:
     branches: [ "main" ]

 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
   cancel-in-progress: true

 env:
   CARGO_TERM_COLOR: always

@@ -60,7 +58,7 @@ jobs:
   library-tests:
-    runs-on: ubuntu-latest-16-cores
+    runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3

@@ -108,9 +106,9 @@ jobs:
       components: rustfmt, clippy

   - name: IPA full-prove tests
-    run: cargo test --release --verbose tests::ipa_fullprove_ -- --test-threads 1
+    run: cargo test --release --verbose tests::ipa_fullprove_ -- --test-threads 4
   - name: KZG full-prove tests
-    run: cargo test --release --verbose tests::kzg_fullprove_ -- --test-threads 1
+    run: cargo test --release --verbose tests::kzg_fullprove_ -- --test-threads 4

 full-proving-evm-tests:

@@ -129,7 +127,7 @@ jobs:
   - name: Install solc
     run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.17 && solc --version
   - name: KZG full-prove tests (EVM)
-    run: cargo test --release --verbose tests_evm::kzg_evm_fullprove_ -- --test-threads 1
+    run: cargo test --release --verbose tests_evm::kzg_evm_fullprove_ -- --test-threads 3

 prove-and-verify-tests:

@@ -145,9 +143,9 @@ jobs:
       components: rustfmt, clippy

   - name: IPA prove and verify tests
-    run: cargo test --release --verbose tests::ipa_prove_and_verify_ -- --test-threads 1
+    run: cargo test --release --verbose tests::ipa_prove_and_verify_ -- --test-threads 4
   - name: KZG prove and verify tests
-    run: cargo test --release --verbose tests::kzg_prove_and_verify_ -- --test-threads 1
+    run: cargo test --release --verbose tests::kzg_prove_and_verify_ -- --test-threads 4

 examples:
Cargo.lock (generated): 1 change
@@ -1839,6 +1839,7 @@ dependencies = [
  "tabled",
  "tensorflow",
  "test-case",
+ "thiserror",
  "tract-onnx",
 ]
@@ -26,6 +26,7 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", package = "e
 plonk_verifier = { git = "https://github.com/zkonduit/plonk-verifier", branch = "main"}
 colog = { version = "1.1.0", optional = true }
 eq-float = "0.1.0"
+thiserror = "1.0.38"

 [dev-dependencies]
 criterion = {version = "0.3", features = ["html_reports"]}
@@ -55,14 +55,16 @@ impl<F: FieldExt + TensorType> Circuit<F> for MyCircuit<F> {
         mut config: Self::Config,
         mut layouter: impl Layouter<F>,
     ) -> Result<(), Error> {
-        config.layout(
-            &mut layouter,
-            &[
-                self.input.clone(),
-                self.l0_params[0].clone(),
-                self.l0_params[1].clone(),
-            ],
-        );
+        config
+            .layout(
+                &mut layouter,
+                &[
+                    self.input.clone(),
+                    self.l0_params[0].clone(),
+                    self.l0_params[1].clone(),
+                ],
+            )
+            .unwrap();
         Ok(())
     }
 }
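This hunk and the similar example hunks below share one call-site pattern: `layout` now returns a `Result`, but `synthesize` must keep returning halo2's own `Error` type, so the examples surface layout failures with `.unwrap()` rather than `?`. A minimal sketch of that constraint, with stand-in types rather than the real ezkl/halo2 signatures:

```rust
use std::error::Error;

// Stand-in for a layout helper that now bubbles errors as Box<dyn Error>.
fn layout(ok: bool) -> Result<(), Box<dyn Error>> {
    if ok {
        Ok(())
    } else {
        Err("layout failed".into())
    }
}

// Stand-in for synthesize, which is pinned to a different error type.
fn synthesize() -> Result<(), &'static str> {
    // `?` would need Box<dyn Error> to convert into &'static str, and no such
    // From impl exists, so the call site unwraps (panicking on failure).
    layout(true).unwrap();
    Ok(())
}

fn main() {
    synthesize().unwrap();
}
```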
@@ -44,11 +44,13 @@ impl<F: FieldExt + TensorType> Circuit<F> for MyCircuit<F> {
         config: Self::Config,
         mut layouter: impl Layouter<F>,
     ) -> Result<(), Error> {
-        config.layout(
-            layouter.namespace(|| "Assign value"),
-            self.input.clone(),
-            self.output.clone(),
-        );
+        config
+            .layout(
+                layouter.namespace(|| "Assign value"),
+                self.input.clone(),
+                self.output.clone(),
+            )
+            .unwrap();

         Ok(())
     }
@@ -44,7 +44,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for NLCircuit<F> {
         config: Self::Config,
         mut layouter: impl Layouter<F>, // layouter is our 'write buffer' for the circuit
     ) -> Result<(), Error> {
-        config.layout(&mut layouter, &self.input);
+        config.layout(&mut layouter, &self.input).unwrap();

         Ok(())
     }
@@ -239,28 +239,30 @@ where
         mut config: Self::Config,
         mut layouter: impl Layouter<F>,
     ) -> Result<(), Error> {
-        let x = config.l0.layout(
-            &mut layouter,
-            &[
-                self.input.clone(),
-                self.l0_params[0].clone(),
-                self.l0_params[1].clone(),
-            ],
-        );
-        let mut x = config.l1.layout(&mut layouter, &x);
+        let x = config
+            .l0
+            .layout(
+                &mut layouter,
+                &[
+                    self.input.clone(),
+                    self.l0_params[0].clone(),
+                    self.l0_params[1].clone(),
+                ],
+            )
+            .unwrap();
+        let mut x = config.l1.layout(&mut layouter, &x).unwrap();
         x.flatten();
-        let l2out = config.l2.layout(
-            &mut layouter,
-            &[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
-        );
+        let l2out = config
+            .l2
+            .layout(
+                &mut layouter,
+                &[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
+            )
+            .unwrap();

         match l2out {
             ValTensor::PrevAssigned { inner: v, dims: _ } => v
-                .enum_map(|i, x| {
-                    layouter
-                        .constrain_instance(x.cell(), config.public_output, i)
-                        .unwrap()
-                })
+                .enum_map(|i, x| layouter.constrain_instance(x.cell(), config.public_output, i))
+                .unwrap(),
             _ => panic!("Should be assigned"),
         };
@@ -310,10 +312,11 @@ pub fn runconv() {

     let mut input: ValTensor<F> = train_data
         .get_slice(&[0..1, 0..28, 0..28])
         .unwrap()
         .map(Value::known)
         .into();

-    input.reshape(&[1, 28, 28]);
+    input.reshape(&[1, 28, 28]).unwrap();

     let myparams = params::Params::new();
     let mut l0_kernels: ValTensor<F> = Tensor::<Value<F>>::from(

@@ -334,7 +337,9 @@ pub fn runconv() {
     )
     .into();

-    l0_kernels.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH]);
+    l0_kernels
+        .reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH])
+        .unwrap();

     let l0_bias: ValTensor<F> = Tensor::<Value<F>>::from(
         (0..OUT_CHANNELS).map(|_| Value::known(fieldutils::i32_to_felt(0))),

@@ -360,7 +365,7 @@ pub fn runconv() {
     }))
     .into();

-    l2_weights.reshape(&[CLASSES, LEN]);
+    l2_weights.reshape(&[CLASSES, LEN]).unwrap();

     let circuit = MyCircuit::<
         F,
@@ -83,7 +83,8 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
             &output,
             BITS,
             &[LookupOp::ReLU { scale: 1 }],
-        );
+        )
+        .unwrap();

         // sets up a new Divide by table
         let l4 =
@@ -107,28 +108,30 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
         mut config: Self::Config,
         mut layouter: impl Layouter<F>,
     ) -> Result<(), Error> {
-        let x = config.l0.layout(
-            &mut layouter,
-            &[
-                self.input.clone(),
-                self.l0_params[0].clone(),
-                self.l0_params[1].clone(),
-            ],
-        );
-        let x = config.l1.layout(&mut layouter, &x);
-        let x = config.l2.layout(
-            &mut layouter,
-            &[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
-        );
-        let x = config.l3.layout(&mut layouter, &x);
-        let x = config.l4.layout(&mut layouter, &x);
+        let x = config
+            .l0
+            .layout(
+                &mut layouter,
+                &[
+                    self.input.clone(),
+                    self.l0_params[0].clone(),
+                    self.l0_params[1].clone(),
+                ],
+            )
+            .unwrap();
+        let x = config.l1.layout(&mut layouter, &x).unwrap();
+        let x = config
+            .l2
+            .layout(
+                &mut layouter,
+                &[x, self.l2_params[0].clone(), self.l2_params[1].clone()],
+            )
+            .unwrap();
+        let x = config.l3.layout(&mut layouter, &x).unwrap();
+        let x = config.l4.layout(&mut layouter, &x).unwrap();
         match x {
             ValTensor::PrevAssigned { inner: v, dims: _ } => v
-                .enum_map(|i, x| {
-                    layouter
-                        .constrain_instance(x.cell(), config.public_output, i)
-                        .unwrap()
-                })
+                .enum_map(|i, x| layouter.constrain_instance(x.cell(), config.public_output, i))
+                .unwrap(),
             _ => panic!("Should be assigned"),
         };
@@ -1,14 +1,20 @@
 use ezkl::commands::Cli;
 use ezkl::execute::run;
-use log::info;
+use log::{error, info};
 use rand::seq::SliceRandom;
+use std::error::Error;

-pub fn main() {
+pub fn main() -> Result<(), Box<dyn Error>> {
     let args = Cli::create();
     colog::init();
     banner();
-    info!("{}", &args.as_json());
-    run(args)
+    info!("{}", &args.as_json()?);
+    let res = run(args);
+    match &res {
+        Ok(_) => info!("verify succeeded"),
+        Err(e) => error!("verify failed: {}", e),
+    };
+    res
 }

 fn banner() {
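The new `main` follows the standard fallible-main pattern: returning `Result` lets an error set a nonzero exit code, while the match logs the outcome before propagating it. A self-contained sketch under the assumption of a plain std binary (println!/eprintln! stand in for the colog-backed log macros):

```rust
use std::error::Error;

// Stand-in for ezkl's run(args).
fn run(cmd: &str) -> Result<(), Box<dyn Error>> {
    if cmd.is_empty() {
        return Err("no command given".into());
    }
    Ok(())
}

fn main() -> Result<(), Box<dyn Error>> {
    let res = run("prove");
    // Log the outcome before handing the Result back to the runtime; an Err
    // return is printed via Debug and makes the process exit nonzero.
    match &res {
        Ok(_) => println!("run succeeded"),
        Err(e) => eprintln!("run failed: {}", e),
    }
    res
}
```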
@@ -1,13 +1,13 @@
 use super::*;
 use crate::tensor::ops::activations::*;
-use crate::{abort, fieldutils::felt_to_i32, fieldutils::i32_to_felt};
+use crate::{fieldutils::felt_to_i32, fieldutils::i32_to_felt};
 use halo2_proofs::{
     arithmetic::{Field, FieldExt},
     circuit::{Layouter, Value},
     plonk::{ConstraintSystem, Expression, Selector, TableColumn},
     poly::Rotation,
 };
-use log::error;
+use std::error::Error;
 use std::fmt;
 use std::{cell::RefCell, marker::PhantomData, rc::Rc};
@@ -98,8 +98,11 @@ impl<F: FieldExt> Table<F> {
         }
     }
     /// Assigns values to the constraints generated when calling `configure`.
-    pub fn layout(&mut self, layouter: &mut impl Layouter<F>) {
-        assert!(!self.is_assigned);
+    pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
+        if self.is_assigned {
+            return Err(Box::new(CircuitError::TableAlreadyAssigned));
+        }

         let base = 2i32;
         let smallest = -base.pow(self.bits as u32 - 1);
         let largest = base.pow(self.bits as u32 - 1);
@@ -108,36 +111,35 @@ impl<F: FieldExt> Table<F> {
         for nl in self.nonlinearities.clone() {
             evals = nl.f(inputs.clone());
         }
+        self.is_assigned = true;
         layouter
             .assign_table(
                 || "nl table",
                 |mut table| {
-                    inputs
-                        .enum_map(|row_offset, input| {
-                            table
-                                .assign_cell(
-                                    || format!("nl_i_col row {}", row_offset),
-                                    self.table_input,
-                                    row_offset,
-                                    || Value::known(i32_to_felt::<F>(input)),
-                                )
-                                .expect("failed to assign table input cell");
+                    let _ = inputs
+                        .iter()
+                        .enumerate()
+                        .map(|(row_offset, input)| {
+                            table.assign_cell(
+                                || format!("nl_i_col row {}", row_offset),
+                                self.table_input,
+                                row_offset,
+                                || Value::known(i32_to_felt::<F>(*input)),
+                            )?;

-                            table
-                                .assign_cell(
-                                    || format!("nl_o_col row {}", row_offset),
-                                    self.table_output,
-                                    row_offset,
-                                    || Value::known(i32_to_felt::<F>(evals[row_offset])),
-                                )
-                                .expect("failed to assign table output cell");
+                            table.assign_cell(
+                                || format!("nl_o_col row {}", row_offset),
+                                self.table_output,
+                                row_offset,
+                                || Value::known(i32_to_felt::<F>(evals[row_offset])),
+                            )?;
+                            Ok(())
                         })
-                        .expect("failed to assign table");
+                        .collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
                     Ok(())
                 },
             )
-            .expect("failed to layout table");
-        self.is_assigned = true;
+            .map_err(Box::<dyn Error>::from)
     }
 }
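The rework above replaces per-cell `.expect(...)` calls with `?` inside the closure, which only compiles because the iterator is collected into a `Result`: the first `Err` short-circuits and bubbles out. A self-contained sketch of that mechanism, with `assign_cell` as a stand-in for the halo2 table API:

```rust
// Stand-in for a fallible cell assignment.
fn assign_cell(row: usize, value: i32) -> Result<(), String> {
    if value < 0 {
        return Err(format!("bad value in row {}", row));
    }
    Ok(())
}

fn assign_all(inputs: &[i32]) -> Result<(), String> {
    inputs
        .iter()
        .enumerate()
        .map(|(row, v)| {
            assign_cell(row, *v)?; // `?` is legal: the closure itself returns a Result
            Ok(())
        })
        // Collecting into Result stops at the first Err and propagates it.
        .collect::<Result<Vec<()>, String>>()?;
    Ok(())
}

fn main() {
    assert!(assign_all(&[1, 2, 3]).is_ok());
    assert!(assign_all(&[1, -2, 3]).is_err());
}
```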
@@ -163,24 +165,24 @@ impl<F: FieldExt + TensorType> Config<F> {
         output: &VarTensor,
         bits: usize,
         nonlinearitities: &[Op],
-    ) -> [Self; NUM] {
+    ) -> Result<[Self; NUM], Box<dyn Error>> {
         let mut table: Option<Rc<RefCell<Table<F>>>> = None;
-        let configs = (0..NUM)
-            .map(|_| {
-                let l = match &table {
-                    None => Self::configure(cs, input, output, bits, &nonlinearitities),
-                    Some(t) => Self::configure_with_table(cs, input, output, t.clone()),
-                };
-                table = Some(l.table.clone());
-                l
-            })
-            .collect::<Vec<Config<F>>>()
-            .try_into();
-
-        match configs {
-            Ok(x) => x,
-            Err(_) => panic!("failed to initialize layers"),
-        }
+        let mut configs: Vec<Config<F>> = vec![];
+        for _ in 0..NUM {
+            let l = match &table {
+                None => Self::configure(cs, input, output, bits, nonlinearitities),
+                Some(t) => Self::configure_with_table(cs, input, output, t.clone()),
+            };
+            table = Some(l.table.clone());
+            configs.push(l);
+        }
+        let res: [Self; NUM] = match configs.try_into() {
+            Ok(a) => a,
+            Err(_) => {
+                return Err(Box::new(CircuitError::TableAlreadyAssigned));
+            }
+        };
+        Ok(res)
     }

     /// Configures and creates an elementwise operation within a circuit using a supplied lookup table.
@@ -260,16 +262,20 @@ impl<F: FieldExt + TensorType> Config<F> {
         let table = Rc::new(RefCell::new(Table::<F>::configure(
             cs,
             bits,
-            &nonlinearitities,
+            nonlinearitities,
         )));
         Self::configure_with_table(cs, input, output, table)
     }

     /// Assigns values to the variables created when calling `configure`.
     /// Values are supplied as a 1-element array of `[input]` VarTensors.
-    pub fn layout(&self, layouter: &mut impl Layouter<F>, values: &ValTensor<F>) -> ValTensor<F> {
+    pub fn layout(
+        &self,
+        layouter: &mut impl Layouter<F>,
+        values: &ValTensor<F>,
+    ) -> Result<ValTensor<F>, Box<dyn Error>> {
         if !self.table.borrow().is_assigned {
-            self.table.borrow_mut().layout(layouter)
+            self.table.borrow_mut().layout(layouter)?
         }
         let mut t = ValTensor::from(
             match layouter.assign_region(
@@ -277,7 +283,7 @@ impl<F: FieldExt + TensorType> Config<F> {
                 |mut region| {
                     self.qlookup.enable(&mut region, 0)?;

-                    let w = self.input.assign(&mut region, 0, &values).unwrap();
+                    let w = self.input.assign(&mut region, 0, values)?;

                     let mut res: Vec<i32> = vec![];
                     let _ = Tensor::from(w.iter().map(|acaf| (*acaf).value_field()).map(|vaf| {
@@ -298,20 +304,17 @@ impl<F: FieldExt + TensorType> Config<F> {
                     }
                 };

-                    Ok(self
-                        .output
-                        .assign(&mut region, 0, &ValTensor::from(output))
-                        .unwrap())
+                    self.output.assign(&mut region, 0, &ValTensor::from(output))
                 },
             ) {
                 Ok(a) => a,
                 Err(e) => {
-                    abort!("failed to assign elt-wise region {:?}", e);
+                    return Err(Box::new(e));
                 }
             },
         );
-        t.reshape(values.dims());
-        t
+        t.reshape(values.dims())?;
+        Ok(t)
     }
 }
@@ -7,3 +7,21 @@ pub mod polynomial;
 pub mod range;
 /// Utility functions for building gates.
 pub mod utils;
+
+use thiserror::Error;
+
+/// circuit related errors.
+#[derive(Debug, Error)]
+pub enum CircuitError {
+    /// Shape mismatch in circuit construction
+    #[error("dimension mismatch in circuit construction for op: {0}")]
+    DimMismatch(String),
+    /// Error when instantiating lookup tables
+    #[error("failed to instantiate lookup tables")]
+    LookupInstantiation,
+    /// A lookup table was already assigned
+    #[error("attempting to initialize an already instantiated lookup table")]
+    TableAlreadyAssigned,
+}
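`CircuitError` leans on thiserror's derive, which generates the `Display` and `std::error::Error` impls from the `#[error(...)]` attributes. A standalone sketch of the same pattern (assumes the thiserror crate; `DemoError` is a hypothetical name):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
pub enum DemoError {
    // The attribute string becomes the Display output, with {0} filled
    // from the variant's first positional field.
    #[error("dimension mismatch in circuit construction for op: {0}")]
    DimMismatch(String),
    #[error("attempting to initialize an already instantiated lookup table")]
    TableAlreadyAssigned,
}

fn main() {
    let e = DemoError::DimMismatch("matmul".to_string());
    assert_eq!(
        e.to_string(),
        "dimension mismatch in circuit construction for op: matmul"
    );
    // Because the derive also implements std::error::Error, the enum can be
    // boxed into the Box<dyn Error> returns used throughout this commit.
    let boxed: Box<dyn std::error::Error> = Box::new(DemoError::TableAlreadyAssigned);
    assert_eq!(
        boxed.to_string(),
        "attempting to initialize an already instantiated lookup table"
    );
}
```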
@@ -1,5 +1,4 @@
 use super::*;
-use crate::abort;
 use crate::tensor::ops::*;
 use crate::tensor::{Tensor, TensorType};
 use halo2_proofs::{

@@ -8,7 +7,7 @@ use halo2_proofs::{
     plonk::{ConstraintSystem, Constraints, Expression, Selector},
 };
 use itertools::Itertools;
-use log::error;
+use std::error::Error;
 use std::fmt;
 use std::marker::PhantomData;
@@ -93,18 +92,18 @@ impl Op {
     pub fn f<T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T>>(
         &self,
         mut inputs: Vec<Tensor<T>>,
-    ) -> Tensor<T> {
+    ) -> Result<Tensor<T>, TensorError> {
         match &self {
-            Op::Identity => inputs[0].clone(),
+            Op::Identity => Ok(inputs[0].clone()),
             Op::Reshape(new_dims) => {
                 let mut t = inputs[0].clone();
                 t.reshape(new_dims);
-                t
+                Ok(t)
             }
             Op::Flatten(new_dims) => {
                 let mut t = inputs[0].clone();
                 t.reshape(new_dims);
-                t
+                Ok(t)
             }
             Op::Add => add(&inputs),
             Op::Sub => sub(&inputs),

@@ -124,24 +123,27 @@ impl Op {
             } => sumpool(&inputs[0], *padding, *stride, *kernel_shape),
             Op::GlobalSumPool => unreachable!(),
             Op::Pow(u) => {
-                assert_eq!(inputs.len(), 1);
+                if 1 != inputs.len() {
+                    return Err(TensorError::DimMismatch("pow inputs".to_string()));
+                }
                 pow(&inputs[0], *u)
             }
             Op::Sum => {
-                assert_eq!(inputs.len(), 1);
+                if 1 != inputs.len() {
+                    return Err(TensorError::DimMismatch("sum inputs".to_string()));
+                }
                 sum(&inputs[0])
             }
             Op::Rescaled { inner, scale } => {
-                assert_eq!(scale.len(), inputs.len());
+                if scale.len() != inputs.len() {
+                    return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
+                }

-                inner.f(inputs
-                    .iter_mut()
-                    .enumerate()
-                    .map(|(i, ri)| {
-                        assert_eq!(scale[i].0, i);
-                        rescale(ri, scale[i].1)
-                    })
-                    .collect_vec())
+                let mut rescaled_inputs = vec![];
+                for (i, ri) in inputs.iter_mut().enumerate() {
+                    rescaled_inputs.push(rescale(ri, scale[i].1)?);
+                }
+                Ok(inner.f(rescaled_inputs)?)
             }
         }
     }
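Each arity check above follows the same mechanical recipe: an `assert_eq!`, which would panic, becomes an `if` that returns a typed error. A minimal sketch with a stand-in `TensorError`:

```rust
#[derive(Debug)]
enum TensorError {
    DimMismatch(String),
}

// Before: assert_eq!(inputs.len(), 1) aborted the process on bad arity.
// After: the caller receives an Err it can handle, log, or propagate.
fn sum(inputs: &[Vec<i32>]) -> Result<i32, TensorError> {
    if inputs.len() != 1 {
        return Err(TensorError::DimMismatch("sum inputs".to_string()));
    }
    Ok(inputs[0].iter().sum())
}

fn main() {
    assert_eq!(sum(&[vec![1, 2, 3]]).unwrap(), 6);
    assert!(sum(&[]).is_err());
}
```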
@@ -204,31 +206,24 @@ impl<F: FieldExt + TensorType> Config<F> {
         let qis = config
             .inputs
             .iter()
-            .map(|input| match input.query(meta, 0) {
-                Ok(q) => q,
-                Err(e) => {
-                    abort!("failed to query input {:?}", e);
-                }
-            })
+            .map(|input| input.query(meta, 0).expect("poly: input query failed"))
             .collect::<Vec<_>>();

         let mut config_outputs = vec![];
         for node in config.nodes.iter_mut() {
-            Self::apply_op(node, &qis, &mut config_outputs);
+            Self::apply_op(node, &qis, &mut config_outputs).expect("poly: apply op failed");
         }
         let witnessed_output = &config_outputs[config.nodes.len() - 1];

         // Get output expressions for each input channel
-        let expected_output: Tensor<Expression<F>> = match config.output.query(meta, 0) {
-            Ok(res) => res,
-            Err(e) => {
-                abort!("failed to query output during fused layer layout {:?}", e);
-            }
-        };
+        let expected_output: Tensor<Expression<F>> = config
+            .output
+            .query(meta, 0)
+            .expect("poly: output query failed");

         let constraints = witnessed_output
-            .enum_map(|i, o| o - expected_output[i].clone())
-            .unwrap();
+            .enum_map::<_, _, CircuitError>(|i, o| Ok(o - expected_output[i].clone()))
+            .expect("poly: failed to create constraints");

         Constraints::with_selector(selector, constraints)
     });
@@ -244,8 +239,12 @@ impl<F: FieldExt + TensorType> Config<F> {
         &mut self,
         layouter: &mut impl Layouter<F>,
         values: &[ValTensor<F>],
-    ) -> ValTensor<F> {
-        assert_eq!(values.len(), self.inputs.len());
+    ) -> Result<ValTensor<F>, Box<dyn Error>> {
+        if values.len() != self.inputs.len() {
+            return Err(Box::new(CircuitError::DimMismatch(
+                "polynomial layout".to_string(),
+            )));
+        }

         let t = match layouter.assign_region(
             || "assign inputs",
@@ -258,15 +257,8 @@ impl<F: FieldExt + TensorType> Config<F> {
                 let inp = utils::value_muxer(
                     &self.inputs[i],
                     &{
-                        match self.inputs[i].assign(&mut region, offset, input) {
-                            Ok(res) => res.map(|e| e.value_field().evaluate()),
-                            Err(e) => {
-                                abort!(
-                                    "failed to assign inputs during fused layer layout {:?}",
-                                    e
-                                );
-                            }
-                        }
+                        let res = self.inputs[i].assign(&mut region, offset, input)?;
+                        res.map(|e| e.value_field().evaluate())
                     },
                     input,
                 );
@@ -276,30 +268,27 @@ impl<F: FieldExt + TensorType> Config<F> {
                 let mut layout_outputs = vec![];

                 for node in self.nodes.iter_mut() {
-                    Self::apply_op(node, &inputs, &mut layout_outputs);
+                    Self::apply_op(node, &inputs, &mut layout_outputs)
+                        .expect("poly: apply op failed");
                 }
                 let output: ValTensor<F> = match layout_outputs.last() {
                     Some(a) => a.clone().into(),
                     None => {
-                        panic!("fused layer has empty outputs");
+                        panic!("poly: empty outputs");
                     }
                 };

-                match self.output.assign(&mut region, offset, &output) {
-                    Ok(a) => Ok(a),
-                    Err(e) => {
-                        abort!("failed to assign fused layer output {:?}", e);
-                    }
-                }
+                let output = self.output.assign(&mut region, offset, &output)?;
+                Ok(output)
             },
         ) {
             Ok(a) => a,
             Err(e) => {
-                abort!("failed to assign fused layer region {:?}", e);
+                return Err(Box::new(e));
             }
         };

-        ValTensor::from(t)
+        Ok(ValTensor::from(t))
     }

     /// Applies an operation represented by a [Op] to the set of inputs (both explicit and intermediate results) it indexes over.
@@ -307,7 +296,7 @@ impl<F: FieldExt + TensorType> Config<F> {
         node: &mut Node,
         inputs: &[Tensor<T>],
         outputs: &mut Vec<Tensor<T>>,
-    ) {
+    ) -> Result<(), Box<dyn Error>> {
         let op_inputs = node
             .input_order
             .iter()

@@ -316,7 +305,8 @@ impl<F: FieldExt + TensorType> Config<F> {
                 InputType::Inter(u) => outputs[*u].clone(),
             })
             .collect_vec();
-        outputs.push(node.op.f(op_inputs));
+        outputs.push(node.op.f(op_inputs)?);
+        Ok(())
     }
 }
@@ -1,4 +1,4 @@
-use crate::abort;
+use super::CircuitError;
 use crate::fieldutils::i32_to_felt;
 use crate::tensor::{TensorType, ValTensor, VarTensor};
 use halo2_proofs::{

@@ -6,7 +6,6 @@ use halo2_proofs::{
     circuit::Layouter,
     plonk::{ConstraintSystem, Constraints, Expression, Selector},
 };
-use log::error;
 use std::marker::PhantomData;

 /// Configuration for a range check on the difference between `input` and `expected`.
@@ -45,20 +44,12 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
         // v | 1

         let q = cs.query_selector(config.selector);
-        let witnessed = match input.query(cs, 0) {
-            Ok(q) => q,
-            Err(e) => {
-                abort!("failed to query input {:?}", e);
-            }
-        };
+        let witnessed = input.query(cs, 0).expect("range: failed to query input");

         // Get output expressions for each input channel
-        let expected = match expected.query(cs, 0) {
-            Ok(q) => q,
-            Err(e) => {
-                abort!("failed to query input {:?}", e);
-            }
-        };
+        let expected = expected
+            .query(cs, 0)
+            .expect("range: failed to query expected value");

         // Given a range R and a value v, returns the expression
         // (v) * (1 - v) * (2 - v) * ... * (R - 1 - v)

@@ -69,8 +60,10 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
         };

         let constraints = witnessed
-            .enum_map(|i, o| range_check(tol as i32, o - expected[i].clone()))
-            .unwrap();
+            .enum_map::<_, _, CircuitError>(|i, o| {
+                Ok(range_check(tol as i32, o - expected[i].clone()))
+            })
+            .expect("range: failed to create constraints");
         Constraints::with_selector(q, constraints)
     });
@@ -86,7 +79,7 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
         mut layouter: impl Layouter<F>,
         input: ValTensor<F>,
         output: ValTensor<F>,
-    ) {
+    ) -> Result<(), halo2_proofs::plonk::Error> {
         match layouter.assign_region(
             || "range check layout",
             |mut region| {

@@ -96,37 +89,22 @@ impl<F: FieldExt + TensorType> RangeCheckConfig<F> {
                 self.selector.enable(&mut region, offset)?;

                 // assigns the instance to the advice.
-                match self.input.assign(&mut region, offset, &input) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        abort!("failed to assign inputs during range layer layout {:?}", e);
-                    }
-                };
+                self.input.assign(&mut region, offset, &input)?;

-                match self.expected.assign(&mut region, offset, &output) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        abort!(
-                            "failed to assign expected output during range layer layout {:?}",
-                            e
-                        );
-                    }
-                };
+                self.expected.assign(&mut region, offset, &output)?;

                 Ok(())
             },
         ) {
-            Ok(a) => a,
-            Err(e) => {
-                abort!("failed to assign fused layer region {:?}", e);
-            }
-        };
-        // ValTensor::from(t);
+            Ok(_) => Ok(()),
+            Err(e) => Err(e),
+        }
     }
 }

 #[cfg(test)]
 mod tests {

     use crate::tensor::Tensor;
     use halo2_proofs::{
         arithmetic::FieldExt,
@@ -169,17 +147,20 @@ mod tests {
         config: Self::Config,
         mut layouter: impl Layouter<F>,
     ) -> Result<(), Error> {
-        config.layout(
-            layouter.namespace(|| "assign value"),
-            self.input.clone(),
-            self.output.clone(),
-        );
+        config
+            .layout(
+                layouter.namespace(|| "assign value"),
+                self.input.clone(),
+                self.output.clone(),
+            )
+            .unwrap();

         Ok(())
     }
 }

 #[test]
 #[allow(clippy::assertions_on_constants)]
 fn test_range_check() {
     let k = 4;
@@ -1,9 +1,9 @@
 //use crate::onnx::OnnxModel;
-use crate::abort;
 use clap::{Parser, Subcommand, ValueEnum};
-use log::{error, info};
+use log::info;
 use serde::{Deserialize, Serialize};
 use std::env;
+use std::error::Error;
 use std::io::{stdin, stdout, Write};
 use std::path::PathBuf;

@@ -44,14 +44,14 @@ pub struct Cli {

 impl Cli {
     /// Export the ezkl configuration as json
-    pub fn as_json(&self) -> String {
+    pub fn as_json(&self) -> Result<String, Box<dyn Error>> {
         let serialized = match serde_json::to_string(&self) {
             Ok(s) => s,
             Err(e) => {
-                abort!("failed to convert Cli to string {:?}", e);
+                return Err(Box::new(e));
             }
         };
-        serialized
+        Ok(serialized)
     }
     /// Parse an ezkl configuration from a json
     pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
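A standalone sketch of the new `as_json` shape, assuming the serde and serde_json crates (the real `Cli` has many more fields):

```rust
use serde::Serialize;
use std::error::Error;

#[derive(Serialize)]
struct Cli {
    logrows: u32,
}

impl Cli {
    /// serde_json::to_string is already fallible, so the error is boxed and
    /// bubbled to the caller instead of aborting the process.
    pub fn as_json(&self) -> Result<String, Box<dyn Error>> {
        let serialized = match serde_json::to_string(&self) {
            Ok(s) => s,
            Err(e) => return Err(Box::new(e)),
        };
        Ok(serialized)
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    let cli = Cli { logrows: 17 };
    println!("{}", cli.as_json()?); // prints {"logrows":17}
    Ok(())
}
```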
src/execute.rs: 142 changes
@@ -1,4 +1,3 @@
-use crate::abort;
 use crate::commands::{Cli, Commands, ProofSystem};
 use crate::fieldutils::i32_to_felt;
 use crate::graph::Model;

@@ -9,9 +8,10 @@ use crate::pfsys::aggregation::{
 };
 use crate::pfsys::{create_keys, load_params, load_vk, Proof};
 use crate::pfsys::{
-    create_proof_model, parse_prover_errors, prepare_circuit_and_public_input, prepare_data,
-    save_params, save_vk, verify_proof_model,
+    create_proof_model, prepare_circuit_and_public_input, prepare_data, save_params, save_vk,
+    verify_proof_model,
 };
+use halo2_proofs::dev::VerifyFailure;
 #[cfg(feature = "evm")]
 use halo2_proofs::poly::commitment::Params;
 use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;

@@ -34,46 +34,43 @@ use halo2curves::bn256::G1Affine;
 use halo2curves::bn256::{Bn256, Fr};
 use halo2curves::pasta::vesta;
 use halo2curves::pasta::Fp;
-use log::{error, info, trace};
+use log::{info, trace};
 #[cfg(feature = "evm")]
 use plonk_verifier::system::halo2::transcript::evm::EvmTranscript;
+use std::error::Error;
 #[cfg(feature = "evm")]
 use std::time::Instant;
 use tabled::Table;
+use thiserror::Error;
+
+/// A wrapper for tensor related errors.
+#[derive(Debug, Error)]
+pub enum ExecutionError {
+    /// Shape mismatch in a operation
+    #[error("verification failed")]
+    VerifyError(Vec<VerifyFailure>),
+}

 /// Run an ezkl command with given args
-pub fn run(args: Cli) {
+pub fn run(args: Cli) -> Result<(), Box<dyn Error>> {
     match args.command {
         Commands::Table { model: _ } => {
-            let om = Model::from_ezkl_conf(args);
+            let om = Model::from_ezkl_conf(args)?;
             println!("{}", Table::new(om.nodes.flatten()));
         }
         Commands::Mock { ref data, model: _ } => {
-            let data = prepare_data(data.to_string());
-            let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args);
+            let data = prepare_data(data.to_string())?;
+            let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args)?;
             info!("Mock proof");
             let pi: Vec<Vec<Fp>> = public_inputs
                 .into_iter()
                 .map(|i| i.into_iter().map(i32_to_felt::<Fp>).collect())
                 .collect();

-            let prover = match MockProver::run(args.logrows, &circuit, pi) {
-                Ok(p) => p,
-                Err(e) => {
-                    abort!("mock prover failed to run {:?}", e);
-                }
-            };
-            match prover.verify() {
-                Ok(_) => {
-                    info!("verify succeeded")
-                }
-                Err(v) => {
-                    for e in v.iter() {
-                        parse_prover_errors(e)
-                    }
-                    panic!()
-                }
-            }
+            let prover =
+                MockProver::run(args.logrows, &circuit, pi).map_err(Box::<dyn Error>::from)?;
+            prover
+                .verify()
+                .map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
        }

        Commands::Fullprove {

@@ -83,16 +80,17 @@ pub fn run(args: Cli) {
        } => {
            // A direct proof

-            let data = prepare_data(data.to_string());
+            let data = prepare_data(data.to_string())?;

            match pfsys {
                ProofSystem::IPA => {
                    let (circuit, public_inputs) =
-                        prepare_circuit_and_public_input::<Fp>(&data, &args);
+                        prepare_circuit_and_public_input::<Fp>(&data, &args)?;
                    info!("full proof with {}", pfsys);

                    let params: ParamsIPA<vesta::Affine> = ParamsIPA::new(args.logrows);
-                    let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params);
+                    let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params)
+                        .map_err(Box::<dyn Error>::from)?;
                    let strategy = IPASingleStrategy::new(&params);
                    trace!("params computed");

@@ -102,17 +100,19 @@ pub fn run(args: Cli) {
                        ProverIPA<_>,
                    >(
                        &circuit, &public_inputs, &params, &pk
-                    );
+                    )
+                    .map_err(Box::<dyn Error>::from)?;

-                    assert!(verify_proof_model(proof, &params, pk.get_vk(), strategy));
+                    verify_proof_model(proof, &params, pk.get_vk(), strategy)?;
                }
                #[cfg(not(feature = "evm"))]
                ProofSystem::KZG => {
                    // A direct proof
                    let (circuit, public_inputs) =
-                        prepare_circuit_and_public_input::<Fr>(&data, &args);
+                        prepare_circuit_and_public_input::<Fr>(&data, &args)?;
                    let params: ParamsKZG<Bn256> = ParamsKZG::new(args.logrows);
-                    let pk = create_keys::<KZGCommitmentScheme<_>, Fr>(&circuit, &params);
+                    let pk = create_keys::<KZGCommitmentScheme<_>, Fr>(&circuit, &params)
+                        .map_err(Box::<dyn Error>::from)?;
                    let strategy = KZGSingleStrategy::new(&params);
                    trace!("params computed");

@@ -122,14 +122,15 @@ pub fn run(args: Cli) {
                        ProverGWC<_>,
                    >(
                        &circuit, &public_inputs, &params, &pk
-                    );
+                    )
+                    .map_err(Box::<dyn Error>::from)?;

-                    assert!(verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>(
+                    verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>(
                        proof,
                        &params,
                        pk.get_vk(),
-                        strategy
-                    ));
+                        strategy,
+                    )?;
                }
                #[cfg(feature = "evm")]
                ProofSystem::KZG => {

@@ -144,16 +145,16 @@ pub fn run(args: Cli) {
                        params
                    };
                    let now = Instant::now();
-                    let snarks = [(); 1].map(|_| gen_application_snark(&params_app, &data, &args));
+                    let snarks = [gen_application_snark(&params_app, &data, &args)?];
                    info!("Application proof took {}", now.elapsed().as_secs());
-                    let agg_circuit = AggregationCircuit::new(&params, snarks);
-                    let pk = gen_pk(&params, &agg_circuit);
+                    let agg_circuit = AggregationCircuit::new(&params, snarks)?;
+                    let pk = gen_pk(&params, &agg_circuit)?;
                    let deployment_code = gen_aggregation_evm_verifier(
                        &params,
                        pk.get_vk(),
                        AggregationCircuit::num_instance(),
                        AggregationCircuit::accumulator_indices(),
-                    );
+                    )?;
                    let now = Instant::now();
                    let proof = gen_kzg_proof::<
                        _,

@@ -162,10 +163,10 @@ pub fn run(args: Cli) {
                        EvmTranscript<G1Affine, _, _, _>,
                    >(
                        &params, &pk, agg_circuit.clone(), agg_circuit.instances()
-                    );
+                    )?;
                    info!("Aggregation proof took {}", now.elapsed().as_secs());
                    let now = Instant::now();
-                    evm_verify(deployment_code, agg_circuit.instances(), proof);
+                    evm_verify(deployment_code, agg_circuit.instances(), proof)?;
                    info!("verify took {}", now.elapsed().as_secs());
                }
            }

@@ -178,33 +179,37 @@ pub fn run(args: Cli) {
            ref params_path,
            pfsys,
        } => {
-            let data = prepare_data(data.to_string());
+            let data = prepare_data(data.to_string())?;

            match pfsys {
                ProofSystem::IPA => {
                    info!("proof with {}", pfsys);
                    let (circuit, public_inputs) =
-                        prepare_circuit_and_public_input::<Fp>(&data, &args);
+                        prepare_circuit_and_public_input::<Fp>(&data, &args)?;
                    let params: ParamsIPA<vesta::Affine> = ParamsIPA::new(args.logrows);
-                    let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params);
+                    let pk = create_keys::<IPACommitmentScheme<_>, Fp>(&circuit, &params)
+                        .map_err(Box::<dyn Error>::from)?;
                    trace!("params computed");

-                    let (proof, _) = create_proof_model::<IPACommitmentScheme<_>, Fp, ProverIPA<_>>(
-                        &circuit,
-                        &public_inputs,
-                        &params,
-                        &pk,
-                    );
+                    let (proof, _) =
+                        create_proof_model::<IPACommitmentScheme<_>, Fp, ProverIPA<_>>(
+                            &circuit,
+                            &public_inputs,
+                            &params,
+                            &pk,
+                        )
+                        .map_err(Box::<dyn Error>::from)?;

-                    proof.save(proof_path);
-                    save_params::<IPACommitmentScheme<_>>(params_path, &params);
-                    save_vk::<IPACommitmentScheme<_>>(vk_path, pk.get_vk());
+                    proof.save(proof_path)?;
+                    save_params::<IPACommitmentScheme<_>>(params_path, &params)?;
+                    save_vk::<IPACommitmentScheme<_>>(vk_path, pk.get_vk())?;
                }
                ProofSystem::KZG => {
                    info!("proof with {}", pfsys);
-                    let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args);
+                    let (circuit, public_inputs) = prepare_circuit_and_public_input(&data, &args)?;
                    let params: ParamsKZG<Bn256> = ParamsKZG::new(args.logrows);
-                    let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr>(&circuit, &params);
+                    let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr>(&circuit, &params)
+                        .map_err(Box::<dyn Error>::from)?;
                    trace!("params computed");

                    let (proof, _input_dims) = create_proof_model::<

@@ -213,11 +218,12 @@ pub fn run(args: Cli) {
                        ProverGWC<'_, Bn256>,
                    >(
                        &circuit, &public_inputs, &params, &pk
-                    );
+                    )
+                    .map_err(Box::<dyn Error>::from)?;

-                    proof.save(proof_path);
-                    save_params::<KZGCommitmentScheme<Bn256>>(params_path, &params);
-                    save_vk::<KZGCommitmentScheme<Bn256>>(vk_path, pk.get_vk());
+                    proof.save(proof_path)?;
+                    save_params::<KZGCommitmentScheme<Bn256>>(params_path, &params)?;
+                    save_vk::<KZGCommitmentScheme<Bn256>>(vk_path, pk.get_vk())?;
                }
            };
        }

@@ -228,29 +234,31 @@ pub fn run(args: Cli) {
            params_path,
            pfsys,
        } => {
-            let proof = Proof::load(&proof_path);
+            let proof = Proof::load(&proof_path)?;
            match pfsys {
                ProofSystem::IPA => {
                    let params: ParamsIPA<vesta::Affine> =
-                        load_params::<IPACommitmentScheme<_>>(params_path);
+                        load_params::<IPACommitmentScheme<_>>(params_path)?;
                    let strategy = IPASingleStrategy::new(&params);
-                    let vk = load_vk::<IPACommitmentScheme<_>, Fp>(vk_path, &params);
-                    let result = verify_proof_model(proof, &params, &vk, strategy);
+                    let vk = load_vk::<IPACommitmentScheme<_>, Fp>(vk_path, &params)?;
+                    let result = verify_proof_model(proof, &params, &vk, strategy).is_ok();
                    info!("verified: {}", result);
                    assert!(result);
                }
                ProofSystem::KZG => {
                    let params: ParamsKZG<Bn256> =
-                        load_params::<KZGCommitmentScheme<Bn256>>(params_path);
+                        load_params::<KZGCommitmentScheme<Bn256>>(params_path)?;
                    let strategy = KZGSingleStrategy::new(&params);
-                    let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr>(vk_path, &params);
+                    let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr>(vk_path, &params)?;
                    let result = verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>(
                        proof, &params, &vk, strategy,
-                    );
+                    )
+                    .is_ok();
                    info!("verified: {}", result);
                    assert!(result);
                }
            }
        }
    }
+    Ok(())
 }
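Most branches of `run()` use the same widening idiom: a concrete library error is converted into `Box<dyn Error>` so every arm fits the function's single error type. A std-only sketch (ParseIntError stands in for the halo2 error types):

```rust
use std::error::Error;
use std::num::ParseIntError;

// Stand-in for a library call with its own concrete error type.
fn parse_logrows(s: &str) -> Result<u32, ParseIntError> {
    s.parse::<u32>()
}

fn run(arg: &str) -> Result<(), Box<dyn Error>> {
    // map_err widens ParseIntError into the boxed trait object; `?` then
    // propagates it through run's single Box<dyn Error> signature.
    let logrows = parse_logrows(arg).map_err(Box::<dyn Error>::from)?;
    println!("logrows = {}", logrows);
    Ok(())
}

fn main() {
    assert!(run("17").is_ok());
    assert!(run("seventeen").is_err());
}
```

Note that a bare `?` would often perform the same conversion via `From`; the diff spells out `map_err` where type inference needs the hint.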
@@ -14,15 +14,57 @@ use anyhow::Result;
 use halo2_proofs::{
     arithmetic::FieldExt,
     circuit::{Layouter, SimpleFloorPlanner, Value},
-    plonk::{Circuit, ConstraintSystem, Error},
+    plonk::{Circuit, ConstraintSystem, Error as PlonkError},
 };
 use log::{info, trace};
 pub use model::*;
 pub use node::*;
 use std::cmp::max;
 use std::marker::PhantomData;
+use thiserror::Error;
 pub use vars::*;

+/// circuit related errors.
+#[derive(Debug, Error)]
+pub enum GraphError {
+    /// The wrong inputs were passed to a lookup node
+    #[error("invalid inputs for a lookup node")]
+    InvalidLookupInputs,
+    /// Shape mismatch in circuit construction
+    #[error("invalid dimensions used for node {0} ({1})")]
+    InvalidDims(usize, OpKind),
+    /// Wrong method was called to configure an op
+    #[error("wrong method was called to configure node {0} ({1})")]
+    WrongMethod(usize, OpKind),
+    /// A requested node is missing in the graph
+    #[error("a requested node is missing in the graph: {0}")]
+    MissingNode(usize),
+    /// The wrong method was called on an operation
+    #[error("an unsupported method was called on node {0} ({1})")]
+    OpMismatch(usize, OpKind),
+    /// This operation is unsupported
+    #[error("unsupported operation in graph")]
+    UnsupportedOp,
+    /// A node has missing parameters
+    #[error("a node is missing required params: {0}")]
+    MissingParams(String),
+    /// Error in the configuration of the visibility of variables
+    #[error("there should be at least one set of public variables")]
+    Visibility,
+    /// Ezkl only supports divisions by constants
+    #[error("ezkl currently only supports division by constants")]
+    NonConstantDiv,
+    /// Ezkl only supports constant powers
+    #[error("ezkl currently only supports constant exponents")]
+    NonConstantPower,
+    /// Error when attempting to rescale an operation
+    #[error("failed to rescale inputs for {0}")]
+    RescalingError(OpKind),
+    /// Error when attempting to load a model
+    #[error("failed to load model")]
+    ModelLoad,
+}

 /// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
 #[derive(Clone, Debug)]
 pub struct ModelCircuit<F: FieldExt> {

@@ -41,7 +83,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
     }

     fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
-        let model = Model::from_arg();
+        let model = Model::from_arg().expect("model should load from args");
         let mut num_fixed = 0;
         let row_cap = model.max_node_size();

@@ -94,7 +136,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
         &self,
         config: Self::Config,
         mut layouter: impl Layouter<F>,
-    ) -> Result<(), Error> {
+    ) -> Result<(), PlonkError> {
         trace!("Setting input in synthesize");
         let inputs = self
             .inputs
@@ -1,5 +1,6 @@
 use super::node::*;
 use super::vars::*;
+use super::GraphError;
 use crate::circuit::lookup::Config as LookupConfig;
 use crate::circuit::lookup::Op as LookupOp;
 use crate::circuit::lookup::Table as LookupTable;

@@ -14,8 +15,8 @@ use crate::circuit::range::*;
 use crate::commands::{Cli, Commands};
 use crate::tensor::TensorType;
 use crate::tensor::{Tensor, ValTensor, VarTensor};
-use anyhow::{Context, Result};
 //use clap::Parser;
+use anyhow::{Context, Error as AnyError};
 use halo2_proofs::{
     arithmetic::FieldExt,
     circuit::{Layouter, Value},

@@ -26,13 +27,13 @@ use log::{debug, info, trace};
 use std::cell::RefCell;
 use std::cmp::max;
 use std::collections::{BTreeMap, HashSet};
+use std::error::Error;
 use std::path::Path;
 use std::rc::Rc;
 use tabled::Table;
 use tract_onnx;
 use tract_onnx::prelude::{Framework, Graph, InferenceFact, Node as OnnxNode, OutletId};
 use tract_onnx::tract_hir::internal::InferenceOp;

 /// Mode we're using the model in.
 #[derive(Clone, Debug)]
 pub enum Mode {

@@ -96,6 +97,7 @@ impl Model {
     /// * `tolerance` - How much each quantized output is allowed to be off by
     /// * `mode` - The [Mode] we're using the model in.
     /// * `visibility` - Which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility].
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         path: impl AsRef<Path>,
         scale: i32,

@@ -105,26 +107,22 @@ impl Model {
         tolerance: usize,
         mode: Mode,
         visibility: VarVisibility,
-    ) -> Self {
-        let model = tract_onnx::onnx().model_for_path(path).unwrap();
+    ) -> Result<Self, Box<dyn Error>> {
+        let model = tract_onnx::onnx()
+            .model_for_path(path)
+            .map_err(|_| GraphError::ModelLoad)?;
         info!("visibility: {}", visibility);

         let mut nodes = BTreeMap::<usize, Node>::new();
-        let _ = model
-            .nodes()
-            .iter()
-            .enumerate()
-            .map(|(i, n)| {
-                let n = Node::new(n.clone(), &mut nodes, scale, i);
-                nodes.insert(i, n);
-            })
-            .collect_vec();
+        for (i, n) in model.nodes.iter().enumerate() {
+            let n = Node::new(n.clone(), &mut nodes, scale, i)?;
+            nodes.insert(i, n);
+        }
         let om = Model {
             model: model.clone(),
             scale,
             tolerance,
-            nodes: Self::assign_execution_buckets(nodes)
-                .expect("failed to assign execution buckets"),
+            nodes: Self::assign_execution_buckets(nodes)?,
             bits,
             logrows,
             max_rotations,

@@ -134,12 +132,12 @@ impl Model {

         debug!("{}", Table::new(om.nodes.flatten()).to_string());

-        om
+        Ok(om)
     }

     /// Creates a `Model` from parsed CLI arguments
-    pub fn from_ezkl_conf(args: Cli) -> Self {
-        let visibility = VarVisibility::from_args(args.clone());
+    pub fn from_ezkl_conf(args: Cli) -> Result<Self, Box<dyn Error>> {
+        let visibility = VarVisibility::from_args(args.clone())?;
         match args.command {
             Commands::Table { model } => Model::new(
                 model,

@@ -195,7 +193,7 @@ impl Model {
     }

     /// Creates a `Model` based on CLI arguments
-    pub fn from_arg() -> Self {
+    pub fn from_arg() -> Result<Self, Box<dyn Error>> {
         let args = Cli::create();
         Self::from_ezkl_conf(args)
     }

@@ -211,7 +209,7 @@ impl Model {
         &self,
         meta: &mut ConstraintSystem<F>,
         vars: &mut ModelVars<F>,
-    ) -> Result<ModelConfig<F>> {
+    ) -> Result<ModelConfig<F>, Box<dyn Error>> {
         info!("configuring model");
         let mut results = BTreeMap::new();
         let mut tables = BTreeMap::new();

@@ -224,7 +222,7 @@ impl Model {
             .collect();
         if !non_op_nodes.is_empty() {
             for (i, node) in non_op_nodes {
-                let config = self.conf_non_op_node(&node);
+                let config = self.conf_non_op_node(node)?;
                 results.insert(*i, config);
             }
         }

@@ -236,7 +234,7 @@ impl Model {

         if !lookup_ops.is_empty() {
             for (i, node) in lookup_ops {
-                let config = self.conf_table(node, meta, vars, &mut tables);
+                let config = self.conf_table(node, meta, vars, &mut tables)?;
                 results.insert(*i, config);
             }
         }

@@ -248,7 +246,7 @@ impl Model {
             .collect();
         // preserves ordering
         if !poly_ops.is_empty() {
-            let config = self.conf_poly_ops(&poly_ops, meta, vars);
+            let config = self.conf_poly_ops(&poly_ops, meta, vars)?;
             results.insert(**poly_ops.keys().max().unwrap(), config);

             let mut display: String = "Poly nodes: ".to_string();

@@ -301,24 +299,25 @@ impl Model {
         configs
     }
     /// Configures non op related nodes (eg. representing an input or const value)
-    pub fn conf_non_op_node<F: FieldExt + TensorType>(&self, node: &Node) -> NodeConfig<F> {
+    pub fn conf_non_op_node<F: FieldExt + TensorType>(
+        &self,
+        node: &Node,
+    ) -> Result<NodeConfig<F>, Box<dyn Error>> {
         match &node.opkind {
             OpKind::Const => {
                 // Typically parameters for one or more layers.
                 // Currently this is handled in the consuming node(s), but will be moved here.
-                NodeConfig::Const
+                Ok(NodeConfig::Const)
             }
             OpKind::Input => {
                 // This is the input to the model (e.g. the image).
                 // Currently this is handled in the consuming node(s), but will be moved here.
-                NodeConfig::Input
+                Ok(NodeConfig::Input)
             }
             OpKind::Unknown(_c) => {
                 unimplemented!()
             }
-            c => {
-                panic!("wrong method called for {}", c)
-            }
+            c => Err(Box::new(GraphError::WrongMethod(node.idx, c.clone()))),
         }
     }

@@ -335,25 +334,27 @@ impl Model {
         nodes: &BTreeMap<&usize, &Node>,
         meta: &mut ConstraintSystem<F>,
         vars: &mut ModelVars<F>,
-    ) -> NodeConfig<F> {
-        let input_nodes: BTreeMap<(&usize, &PolyOp), Vec<Node>> = nodes
-            .iter()
-            .map(|(i, e)| {
-                (
-                    (
-                        *i,
-                        match &e.opkind {
-                            OpKind::Poly(f) => f,
-                            _ => panic!(),
-                        },
-                    ),
-                    e.inputs
-                        .iter()
-                        .map(|i| self.nodes.filter(i.node))
-                        .collect_vec(),
-                )
-            })
-            .collect();
+    ) -> Result<NodeConfig<F>, Box<dyn Error>> {
+        let mut input_nodes: BTreeMap<(&usize, &PolyOp), Vec<Node>> = BTreeMap::new();
+
+        for (i, e) in nodes.iter() {
+            let key = (
+                *i,
+                match &e.opkind {
+                    OpKind::Poly(f) => f,
+                    _ => {
+                        return Err(Box::new(GraphError::WrongMethod(e.idx, e.opkind.clone())));
+                    }
+                },
+            );
+            let value = e
+                .inputs
+                .iter()
+                .map(|i| self.nodes.filter(i.node))
+                .collect_vec();
+            input_nodes.insert(key, value);
+        }

         // This works because retain only keeps items for which the predicate returns true, and
         // insert only returns true if the item was not previously present in the set.
         // Since the vector is traversed in order, we end up keeping just the first occurrence of each item.

@@ -412,7 +413,7 @@ impl Model {

         let inputs = inputs_to_layer.iter();

-        NodeConfig::Poly(
+        let config = NodeConfig::Poly(
             PolyConfig::configure(
                 meta,
                 &inputs.clone().map(|x| x.1.clone()).collect_vec(),

@@ -420,7 +421,8 @@ impl Model {
                 &fused_nodes,
             ),
             inputs.map(|x| x.0).collect_vec(),
-        )
+        );
+        Ok(config)
     }

     /// Configures a lookup table based operation. These correspond to operations that are represented in

@@ -436,7 +438,7 @@ impl Model {
         meta: &mut ConstraintSystem<F>,
         vars: &mut ModelVars<F>,
         tables: &mut BTreeMap<Vec<LookupOp>, Rc<RefCell<LookupTable<F>>>>,
-    ) -> NodeConfig<F> {
+    ) -> Result<NodeConfig<F>, Box<dyn Error>> {
         let input_len = node.in_dims[0].iter().product();
         let input = &vars.advices[0].reshape(&[input_len]);
         let output = &vars.advices[1].reshape(&[input_len]);

@@ -444,20 +446,24 @@ impl Model {

         let op = match &node.opkind {
             OpKind::Lookup(l) => l,
-            c => panic!("wrong method called for {}", c),
+            c => {
+                return Err(Box::new(GraphError::WrongMethod(node.idx, c.clone())));
+            }
         };

-        if tables.contains_key(&vec![op.clone()]) {
-            let table = tables.get(&vec![op.clone()]).unwrap();
-            let conf: LookupConfig<F> =
-                LookupConfig::configure_with_table(meta, input, output, table.clone());
-            NodeConfig::Lookup(conf, node_inputs)
-        } else {
-            let conf: LookupConfig<F> =
-                LookupConfig::configure(meta, input, output, self.bits, &[op.clone()]);
-            tables.insert(vec![op.clone()], conf.table.clone());
-            NodeConfig::Lookup(conf, node_inputs)
-        }
+        let config =
+            if let std::collections::btree_map::Entry::Vacant(e) = tables.entry(vec![op.clone()]) {
+                let conf: LookupConfig<F> =
+                    LookupConfig::configure(meta, input, output, self.bits, &[op.clone()]);
+                e.insert(conf.table.clone());
+                NodeConfig::Lookup(conf, node_inputs)
+            } else {
+                let table = tables.get(&vec![op.clone()]).unwrap();
+                let conf: LookupConfig<F> =
+                    LookupConfig::configure_with_table(meta, input, output, table.clone());
+                NodeConfig::Lookup(conf, node_inputs)
+            };
+        Ok(config)
     }

     /// Assigns values to the regions created when calling `configure`.

@@ -472,7 +478,7 @@ impl Model {
         layouter: &mut impl Layouter<F>,
         inputs: &[ValTensor<F>],
         vars: &ModelVars<F>,
-    ) -> Result<()> {
+    ) -> Result<(), Box<dyn Error>> {
         info!("model layout");
         let mut results = BTreeMap::<usize, ValTensor<F>>::new();
         for i in inputs.iter().enumerate() {

@@ -533,7 +539,7 @@ impl Model {
         layouter: &mut impl Layouter<F>,
         inputs: &mut BTreeMap<usize, ValTensor<F>>,
         config: &NodeConfig<F>,
-    ) -> Result<Option<ValTensor<F>>> {
+    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
         // The node kind and the config should be the same.
         let res = match config.clone() {
             NodeConfig::Poly(mut ac, idx) => {

@@ -555,17 +561,19 @@ impl Model {
                 })
                 .collect_vec();

-                Some(ac.layout(layouter, &values))
+                Some(ac.layout(layouter, &values)?)
             }
             NodeConfig::Lookup(rc, idx) => {
-                assert_eq!(idx.len(), 1);
+                if idx.len() != 1 {
+                    return Err(Box::new(GraphError::InvalidLookupInputs));
+                }
                 // For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output.
-                Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap()))
+                Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap())?)
             }
             NodeConfig::Input => None,
             NodeConfig::Const => None,
-            c => {
-                panic!("Not a configurable op {:?}", c)
+            _ => {
+                return Err(Box::new(GraphError::UnsupportedOp));
             }
         };
         Ok(res)

@@ -580,28 +588,37 @@ impl Model {
     /// # Arguments
     ///
     /// * `nodes` - `BTreeMap` of (node index, [Node]) pairs.
-    pub fn assign_execution_buckets(mut nodes: BTreeMap<usize, Node>) -> Result<NodeGraph> {
+    pub fn assign_execution_buckets(
+        mut nodes: BTreeMap<usize, Node>,
+    ) -> Result<NodeGraph, GraphError> {
         info!("assigning configuration buckets to operations");

         let mut bucketed_nodes = NodeGraph(BTreeMap::<Option<usize>, BTreeMap<usize, Node>>::new());

         for (_, node) in nodes.iter_mut() {
-            let prev_bucket: Option<usize> = node
-                .inputs
-                .iter()
-                .filter(|n| !bucketed_nodes.filter(n.node).opkind.is_const())
-                .map(|n| match bucketed_nodes.filter(n.node).bucket {
-                    Some(b) => b,
-                    None => panic!(),
-                })
-                .max();
+            let mut prev_buckets = vec![];
+            for n in node
+                .inputs
+                .iter()
+                .filter(|n| !bucketed_nodes.filter(n.node).opkind.is_const())
+            {
+                match bucketed_nodes.filter(n.node).bucket {
+                    Some(b) => prev_buckets.push(b),
+                    None => {
+                        return Err(GraphError::MissingNode(n.node));
+                    }
+                }
+            }
+            let prev_bucket: Option<&usize> = prev_buckets.iter().max();

             match &node.opkind {
                 OpKind::Input => node.bucket = Some(0),
                 OpKind::Const => node.bucket = None,
-                OpKind::Poly(_) => node.bucket = Some(prev_bucket.unwrap()),
+                OpKind::Poly(_) => node.bucket = Some(*prev_bucket.unwrap()),
                 OpKind::Lookup(_) => node.bucket = Some(prev_bucket.unwrap() + 1),
-                _ => unimplemented!(),
+                op => {
+                    return Err(GraphError::WrongMethod(node.idx, op.clone()));
+                }
             }
             bucketed_nodes.insert(node.bucket, node.idx, node.clone());
         }

@@ -613,7 +630,7 @@ impl Model {
     /// Note that this order is not stable over multiple reloads of the model. For example, it will freely
     /// interchange the order of evaluation of fixed parameters. For example weight could have id 1 on one load,
     /// and bias id 2, and vice versa on the next load of the same file. The ids are also not stable.
-    pub fn eval_order(&self) -> Result<Vec<usize>> {
+    pub fn eval_order(&self) -> Result<Vec<usize>, AnyError> {
         self.model.eval_order()
     }

@@ -623,12 +640,12 @@ impl Model {
     }

     /// Returns the ID of the computational graph's inputs
-    pub fn input_outlets(&self) -> Result<Vec<OutletId>> {
+    pub fn input_outlets(&self) -> Result<Vec<OutletId>, Box<dyn Error>> {
         Ok(self.model.input_outlets()?.to_vec())
     }

     /// Returns the ID of the computational graph's outputs
-    pub fn output_outlets(&self) -> Result<Vec<OutletId>> {
+    pub fn output_outlets(&self) -> Result<Vec<OutletId>, Box<dyn Error>> {
         Ok(self.model.output_outlets()?.to_vec())
     }
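Several hunks in this file replace a side-effecting `.map(...).collect_vec()` with a plain `for` loop. The reason is that `?` cannot cross a closure boundary: inside a closure passed to `map` it would return from the closure, not from the enclosing function. A minimal sketch (make_node is a stand-in for the now-fallible Node::new):

```rust
use std::collections::BTreeMap;

// Stand-in for Node::new, which can now fail.
fn make_node(i: usize) -> Result<String, String> {
    if i >= 3 {
        return Err(format!("node {} failed to build", i));
    }
    Ok(format!("node-{}", i))
}

fn build_graph(n: usize) -> Result<BTreeMap<usize, String>, String> {
    let mut nodes = BTreeMap::new();
    // In a loop body, `?` bubbles the error out of build_graph directly.
    for i in 0..n {
        let node = make_node(i)?;
        nodes.insert(i, node);
    }
    Ok(nodes)
}

fn main() {
    assert_eq!(build_graph(3).unwrap().len(), 3);
    assert!(build_graph(5).is_err());
}
```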
@@ -1,17 +1,18 @@
use super::utilities::{node_output_shapes, scale_to_multiplier, vector_to_quantized};
use crate::abort;
use crate::circuit::lookup::Config as LookupConfig;
use crate::circuit::lookup::Op as LookupOp;
use crate::circuit::polynomial::Config as PolyConfig;
use crate::circuit::polynomial::Op as PolyOp;
use crate::graph::GraphError;
use crate::tensor::ops::{add, const_mult, div, mult};
use crate::tensor::Tensor;
use crate::tensor::TensorType;
use anyhow::Result;
use halo2_proofs::arithmetic::FieldExt;
use itertools::Itertools;
use log::{error, info, trace, warn};
use log::{info, trace, warn};
use std::collections::{btree_map::Entry, BTreeMap};
use std::error::Error;
use std::fmt;
use std::ops::Deref;
use tabled::Tabled;
@@ -294,7 +295,7 @@ impl Node {
        other_nodes: &mut BTreeMap<usize, Node>,
        scale: i32,
        idx: usize,
    ) -> Self {
    ) -> Result<Self, Box<dyn Error>> {
        trace!("Create {:?}", node);
        trace!("Create op {:?}", node.op);
        let output_shapes = match node_output_shapes(&node) {
@@ -302,20 +303,13 @@ impl Node {
            _ => None,
        };

        let mut inputs: Vec<Node> = node
            .inputs
            .iter_mut()
            // this shouldn't fail
            .map(|i| {
                match other_nodes.get(&i.node) {
                    Some(n) => n,
                    None => {
                        abort!("input {} has not been initialized", i.node);
                    }
                }
                .clone()
            })
            .collect();
        let mut inputs = vec![];
        for i in node.inputs.iter_mut() {
            match other_nodes.get(&i.node) {
                Some(n) => inputs.push(n.clone()),
                None => return Err(Box::new(GraphError::MissingNode(i.node))),
            }
        }

        let mut opkind = OpKind::new(node.op().name().as_ref()); // parses the op name

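This hunk is the commit's core pattern: a lookup that used to `abort!` now surfaces a typed error, which callers forward with `?`. A minimal sketch of that flow, using a hypothetical `MissingNode` error in the style of `GraphError`:

```rust
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;

// Hypothetical stand-in for graph::GraphError::MissingNode.
#[derive(Debug)]
struct MissingNode(usize);

impl fmt::Display for MissingNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "missing node {}", self.0)
    }
}

impl Error for MissingNode {}

// A fallible lookup that bubbles instead of panicking.
fn fetch(nodes: &BTreeMap<usize, String>, id: usize) -> Result<String, Box<dyn Error>> {
    match nodes.get(&id) {
        Some(n) => Ok(n.clone()),
        None => Err(Box::new(MissingNode(id))),
    }
}

fn build(nodes: &BTreeMap<usize, String>) -> Result<Vec<String>, Box<dyn Error>> {
    // `?` forwards MissingNode to the caller, replacing the old abort!.
    Ok(vec![fetch(nodes, 0)?, fetch(nodes, 1)?])
}
```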
@@ -386,11 +380,11 @@ impl Node {
                    Some(b) => match (*b).as_any().downcast_ref() {
                        Some(b) => b,
                        None => {
                            panic!("not a leaky relu!");
                            return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
                        }
                    },
                    None => {
                        panic!("op is not a Tract Expansion!");
                        return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
                    }
                };

@@ -445,7 +439,7 @@ impl Node {

                opkind = OpKind::Lookup(LookupOp::PReLU {
                    scale: layer_scale,
                    slopes: slopes.clone(),
                    slopes,
                }); // now the input will be scaled down to match

                Node {
@@ -462,7 +456,7 @@ impl Node {
                }
            }
            LookupOp::Div { .. } => {
                if inputs[1].out_dims.clone() != [1] {
                    abort!("ezkl currently only supports division by a constant");
                    return Err(Box::new(GraphError::NonConstantDiv));
                }
                let mult = scale_to_multiplier(scale);
                let div = inputs[1].output_max / mult;
@@ -516,27 +510,37 @@ impl Node {
                    Some(b) => match (*b).as_any().downcast_ref() {
                        Some(b) => b,
                        None => {
                            panic!("not a conv!");
                            return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
                        }
                    },
                    None => {
                        panic!("op is not a Tract Expansion!");
                        return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
                    }
                };

                // only support pytorch type formatting for now
                assert_eq!(conv_node.data_format, DataFormat::NCHW);
                assert_eq!(conv_node.kernel_fmt, KernelFormat::OIHW);
                if (conv_node.data_format != DataFormat::NCHW)
                    || (conv_node.kernel_fmt != KernelFormat::OIHW)
                {
                    return Err(Box::new(GraphError::MissingParams(
                        "data or kernel in wrong format".to_string(),
                    )));
                }

                let stride = match conv_node.strides.clone() {
                    Some(s) => s,
                    None => {
                        abort!("strides for node {} has not been initialized", idx);
                        return Err(Box::new(GraphError::MissingParams(
                            "strides".to_string(),
                        )));
                    }
                };
                let padding = match &conv_node.padding {
                    PaddingSpec::Explicit(p, _, _) => p,
                    _ => panic!("padding is not explicitly specified"),
                    _ => {
                        return Err(Box::new(GraphError::MissingParams(
                            "padding".to_string(),
                        )));
                    }
                };

                if inputs.len() == 3 {
@@ -544,11 +548,11 @@ impl Node {
                    let scale_diff =
                        weight_node.out_scale + input_node.out_scale - bias_node.out_scale;
                    let mut bias_node = other_nodes.get_mut(&node.inputs[2].node).unwrap();
                    bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff);
                    assert_eq!(
                        input_node.out_scale + weight_node.out_scale,
                        bias_node.out_scale
                    );
                    bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?;
                    if (input_node.out_scale + weight_node.out_scale) != bias_node.out_scale
                    {
                        return Err(Box::new(GraphError::RescalingError(opkind)));
                    }
                }

                let oihw = weight_node.out_dims.clone();
@@ -590,18 +594,28 @@ impl Node {
                let op = Box::new(node.op());
                let sumpool_node: &SumPool = match op.downcast_ref() {
                    Some(b) => b,
                    None => panic!("op isn't a SumPool!"),
                    None => {
                        return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
                    }
                };

                let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

                // only support pytorch type formatting for now
                assert_eq!(pool_spec.data_format, DataFormat::NCHW);
                if pool_spec.data_format != DataFormat::NCHW {
                    return Err(Box::new(GraphError::MissingParams(
                        "data in wrong format".to_string(),
                    )));
                }

                let stride = pool_spec.strides.clone().unwrap();
                let padding = match &pool_spec.padding {
                    PaddingSpec::Explicit(p, _, _) => p,
                    _ => panic!("padding is not explicitly specified"),
                    _ => {
                        return Err(Box::new(GraphError::MissingParams(
                            "padding".to_string(),
                        )));
                    }
                };
                let kernel_shape = &pool_spec.kernel_shape;

@@ -697,12 +711,10 @@ impl Node {
                    let scale_diff =
                        weight_node.out_scale + input_node.out_scale - bias_node.out_scale;
                    let mut bias_node = other_nodes.get_mut(&node.inputs[2].node).unwrap();
                    bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff);

                    assert_eq!(
                        input_node.out_scale + weight_node.out_scale,
                        bias_node.out_scale
                    );
                    bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?;
                    if (input_node.out_scale + weight_node.out_scale) != bias_node.out_scale {
                        return Err(Box::new(GraphError::RescalingError(opkind)));
                    }

                    let in_dim = weight_node.out_dims.clone()[1];
                    let out_dim = weight_node.out_dims.clone()[0];
@@ -727,34 +739,26 @@ impl Node {
            PolyOp::BatchNorm => {
                // Compute scale and shift from the four inputs,
                // then replace the first two, and change this node to a ScaleAndShift

                // let (input_node, mut gamma_node, mut beta_node, mean_node, var_node) = (
                //     &mut inputs[0],
                //     &mut inputs[1],
                //     &mut inputs[2],
                //     &mut inputs[3],
                //     &mut inputs[4],
                // );
                let gamma = inputs[1].raw_const_value.as_ref().unwrap();
                let beta = inputs[2].raw_const_value.as_ref().unwrap();
                let mu = inputs[3].raw_const_value.as_ref().unwrap();
                let sigma = inputs[4].raw_const_value.as_ref().unwrap();
                let num_entries = gamma.len();

                let a = div(gamma.clone(), sigma.clone());
                let amu: Tensor<f32> = mult(&vec![a.clone(), mu.clone()]);
                let amupb: Tensor<f32> = add(&vec![amu, beta.clone()]);
                let b = const_mult(&amupb, -1f32);
                let a = div(gamma.clone(), sigma.clone())?;
                let amu: Tensor<f32> = mult(&vec![a.clone(), mu.clone()])?;
                let amupb: Tensor<f32> = add(&vec![amu, beta.clone()])?;
                let b = const_mult(&amupb, -1f32)?;

                let in_scale = inputs[0].out_scale;
                let out_scale = 2 * inputs[0].out_scale;
                // gamma node becomes the scale (weight) in scale and shift
                inputs[1].raw_const_value = Some(a);
                inputs[1].quantize_const_to_scale(in_scale);
                inputs[1].quantize_const_to_scale(in_scale)?;

                // beta node becomes the shift (bias)
                inputs[2].raw_const_value = Some(b);
                inputs[2].quantize_const_to_scale(out_scale);
                inputs[2].quantize_const_to_scale(out_scale)?;

                Node {
                    idx,
@@ -772,7 +776,7 @@ impl Node {
                }

            PolyOp::Add => {
                opkind = Self::homogenize_input_scales(opkind, inputs.clone());
                opkind = Self::homogenize_input_scales(opkind, inputs.clone())?;
                let output_max =
                    if let OpKind::Poly(PolyOp::Rescaled { scale, .. }) = &opkind {
                        (inputs
@@ -785,8 +789,7 @@ impl Node {
                            .unwrap() as f32)
                            * (inputs.len() as f32)
                    } else {
                        error!("failed to homogenize input scalings for node {}", idx);
                        panic!()
                        return Err(Box::new(GraphError::RescalingError(opkind)));
                    };

                Node {
@@ -802,7 +805,9 @@ impl Node {
                }
            }
            PolyOp::Sum => {
                assert!(inputs.len() == 1);
                if inputs.len() != 1 {
                    return Err(Box::new(GraphError::InvalidDims(idx, opkind)));
                };

                Node {
                    idx,
@@ -818,7 +823,7 @@ impl Node {
                }
            }
            PolyOp::Sub => {
                opkind = Self::homogenize_input_scales(opkind, inputs.clone());
                opkind = Self::homogenize_input_scales(opkind, inputs.clone())?;
                let output_max =
                    if let OpKind::Poly(PolyOp::Rescaled { inner: _, scale }) = &opkind {
                        (inputs
@@ -831,8 +836,7 @@ impl Node {
                            .unwrap() as f32)
                            * (inputs.len() as f32)
                    } else {
                        error!("failed to homogenize input scalings for node {}", idx);
                        panic!()
                        return Err(Box::new(GraphError::RescalingError(opkind)));
                    };

                Node {
@@ -874,10 +878,9 @@ impl Node {
                let pow = inputs[1].clone().raw_const_value.unwrap()[0];
                node.inputs.pop();
                if inputs[1].out_dims != [1] {
                    error!(
                        "ezkl currently only supports raising to the power by a constant"
                    );
                    unimplemented!()
                {
                    return Err(Box::new(GraphError::NonConstantPower));
                }
                }

                Node {
@@ -900,8 +903,7 @@ impl Node {
                }
            }
            PolyOp::Rescaled { .. } => {
                error!("operations should not already be rescaled at this stage");
                panic!()
                return Err(Box::new(GraphError::RescalingError(opkind)));
            }
            PolyOp::Identity => {
                let input_node = &inputs[0];
@@ -939,33 +941,42 @@ impl Node {
                let shape_const = match shape_const_node.const_value.as_ref() {
                    Some(sc) => sc,
                    None => {
                        abort!("missing shape constant");
                        return Err(Box::new(GraphError::MissingParams(
                            "shape constant".to_string(),
                        )));
                    }
                };
                let shapes = &shape_const[0..];
                let new_dims: Vec<usize> = if shapes.iter().all(|x| x > &0) {
                    shapes
                        .iter()
                        .map(|x| {
                            assert!(x > &0);
                            *x as usize
                        })
                        .collect()
                } else {
                    let num_entries: usize = input_node.out_dims.iter().product();
                    let explicit_prod: i32 = shapes.iter().filter(|x| *x > &0).product();
                    assert!(explicit_prod > 0);
                    let inferred = num_entries / (explicit_prod as usize);
                    let mut new_dims: Vec<usize> = Vec::new();
                    for i in shapes {
                        match i {
                            -1 => new_dims.push(inferred),
                            0 => continue,
                            x => new_dims.push(*x as usize),
                let new_dims: Result<Vec<usize>, Box<dyn Error>> =
                    if shapes.iter().all(|x| x > &0) {
                        let mut res = vec![];
                        for x in shapes.iter() {
                            if x <= &0 {
                                return Err(Box::new(GraphError::InvalidDims(idx, opkind)));
                            }
                            res.push(*x as usize);
                        }
                    }
                    new_dims
                };
                        Ok(res)
                    } else {
                        let num_entries: usize = input_node.out_dims.iter().product();
                        let explicit_prod: i32 =
                            shapes.iter().filter(|x| *x > &0).product();
                        if explicit_prod <= 0 {
                            return Err(Box::new(GraphError::InvalidDims(idx, opkind)));
                        }
                        let inferred = num_entries / (explicit_prod as usize);
                        let mut new_dims: Vec<usize> = Vec::new();
                        for i in shapes {
                            match i {
                                -1 => new_dims.push(inferred),
                                0 => continue,
                                x => new_dims.push(*x as usize),
                            }
                        }
                        Ok(new_dims)
                    };

                let new_dims = new_dims?;

                Node {
                    idx,
@@ -986,7 +997,7 @@ impl Node {
                let const_node: &Const = match op.as_any().downcast_ref() {
                    Some(b) => b,
                    None => {
                        abort!("op is not a const!");
                        return Err(Box::new(GraphError::OpMismatch(idx, opkind)));
                    }
                };
                let dt = const_node.0.datum_type();
@@ -1081,16 +1092,18 @@ impl Node {
                warn!("{:?} is unknown", opkind);
                Node::default()
            }
            o => {
                error!("unsupported op {:?}", o);
                panic!()
            _ => {
                return Err(Box::new(GraphError::UnsupportedOp));
            }
        };
        mn
        Ok(mn)
    }

    /// Ensures all inputs to a node have the same fixed-point denominator.
    fn homogenize_input_scales(opkind: OpKind, inputs: Vec<Node>) -> OpKind {
    fn homogenize_input_scales(
        opkind: OpKind,
        inputs: Vec<Node>,
    ) -> Result<OpKind, Box<dyn Error>> {
        let mut multipliers = vec![1; inputs.len()];
        let out_scales = inputs.windows(1).map(|w| w[0].out_scale).collect_vec();
        if !out_scales.windows(2).all(|w| w[0] == w[1]) {
@@ -1114,32 +1127,42 @@ impl Node {
            .collect_vec();
        }
        if let OpKind::Poly(c) = &opkind {
            OpKind::Poly(PolyOp::Rescaled {
            Ok(OpKind::Poly(PolyOp::Rescaled {
                inner: Box::new(c.clone()),
                scale: (0..inputs.len()).zip(multipliers).collect_vec(),
            })
            }))
        } else {
            error!("should not homogenize input scales for non-fused ops.");
            panic!()
            Err(Box::new(GraphError::RescalingError(opkind)))
        }
    }

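For intuition, homogenizing scales means lifting lower-scale inputs up to the largest fixed-point scale before a fused add or sub. A standalone sketch, assuming the crate's convention that scale `s` means values are stored as `x * 2^s` (as in `scale_to_multiplier` elsewhere in this diff):

```rust
// Computes per-input multipliers that lift every input to the max scale.
// Assumes fixed-point scale s encodes x as x * 2^s.
fn input_multipliers(out_scales: &[i32]) -> Vec<i32> {
    let max_scale = *out_scales.iter().max().unwrap();
    out_scales
        .iter()
        .map(|s| 1 << (max_scale - s) as u32)
        .collect()
}

fn main() {
    // Inputs at scales 7 and 9: the first is multiplied by 2^2 = 4.
    assert_eq!(input_multipliers(&[7, 9]), vec![4, 1]);
}
```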
    fn quantize_const_to_scale(&mut self, scale: i32) {
        assert!(matches!(self.opkind, OpKind::Const));
    fn quantize_const_to_scale(&mut self, scale: i32) -> Result<(), Box<dyn Error>> {
        if !self.opkind.is_const() {
            return Err(Box::new(GraphError::WrongMethod(
                self.idx,
                self.opkind.clone(),
            )));
        };
        let raw = self.raw_const_value.as_ref().unwrap();
        self.out_scale = scale;
        let t = vector_to_quantized(raw, raw.dims(), 0f32, self.out_scale).unwrap();
        self.output_max = 0f32; //t.iter().map(|x| x.abs()).max().unwrap() as f32;
        self.const_value = Some(t);
        Ok(())
    }

    /// Re-quantizes a constant value node to a new scale.
    fn scale_up_const_node(node: &mut Node, scale: i32) -> &mut Node {
        assert!(matches!(node.opkind, OpKind::Const));
    fn scale_up_const_node(node: &mut Node, scale: i32) -> Result<&mut Node, Box<dyn Error>> {
        if !node.opkind.is_const() {
            return Err(Box::new(GraphError::WrongMethod(
                node.idx,
                node.opkind.clone(),
            )));
        };
        if scale > 0 {
            if let Some(raw) = &node.raw_const_value {
            if let Some(val) = &node.raw_const_value {
                let mult = scale_to_multiplier(scale);
                let t = vector_to_quantized(&raw, raw.dims(), 0f32, scale).unwrap();
                let t = vector_to_quantized(val, val.dims(), 0f32, scale)?;
                node.const_value = Some(t);
                info!(
                    "------ scaled const node {:?}: {:?} -> {:?}",
@@ -1149,6 +1172,6 @@ impl Node {
                node.out_scale = scale;
            }
        }
        node
        Ok(node)
    }
}

@@ -1,12 +1,14 @@
use crate::abort;
use std::error::Error;

use crate::commands::Cli;
use crate::tensor::TensorType;
use crate::tensor::{ValTensor, VarTensor};
use halo2_proofs::{arithmetic::FieldExt, plonk::ConstraintSystem};
use itertools::Itertools;
use log::error;
use serde::Deserialize;

use super::GraphError;

/// Label Enum to track whether model input, model parameters, and model output are public or private
#[derive(Clone, Debug, Deserialize)]
pub enum Visibility {
@@ -53,7 +55,7 @@ impl std::fmt::Display for VarVisibility {
impl VarVisibility {
    /// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
    /// Place in [VarVisibility] struct.
    pub fn from_args(args: Cli) -> Self {
    pub fn from_args(args: Cli) -> Result<Self, Box<dyn Error>> {
        let input_vis = if args.public_inputs {
            Visibility::Public
        } else {
@@ -70,13 +72,13 @@ impl VarVisibility {
            Visibility::Private
        };
        if !output_vis.is_public() & !params_vis.is_public() & !input_vis.is_public() {
            abort!("at least one set of variables should be public");
            return Err(Box::new(GraphError::Visibility));
        }
        Self {
        Ok(Self {
            input: input_vis,
            params: params_vis,
            output: output_vis,
        }
        })
    }
}


@@ -51,12 +51,3 @@ pub mod graph;
pub mod pfsys;
/// An implementation of multi-dimensional tensors.
pub mod tensor;

/// A macro to abort concisely.
#[macro_export]
macro_rules! abort {
    ($msg:literal $(, $ex:expr)*) => {
        error!($msg, $($ex,)*);
        panic!();
    };
}

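The deleted `abort!` macro logged and then panicked at the failure site; the commit's replacement pattern returns a typed error instead. A minimal before/after sketch, using a hypothetical error enum in the style of `GraphError::Visibility`:

```rust
use log::error;
use thiserror::Error;

#[derive(Error, Debug)]
enum VisibilityError {
    #[error("at least one set of variables should be public")]
    AllPrivate,
}

// Before: crash the process where the check fails.
fn check_old(any_public: bool) {
    if !any_public {
        error!("at least one set of variables should be public");
        panic!();
    }
}

// After: hand the caller a value it can log, convert, or test against.
fn check_new(any_public: bool) -> Result<(), VisibilityError> {
    if !any_public {
        return Err(VisibilityError::AllPrivate);
    }
    Ok(())
}
```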
@@ -54,8 +54,10 @@ use plonk_verifier::{
    Protocol,
};
use rand::rngs::OsRng;
use std::error::Error;
use std::io::Cursor;
use std::{iter, rc::Rc};
use thiserror::Error;

const LIMBS: usize = 4;
const BITS: usize = 68;
@@ -77,6 +79,26 @@ type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip
pub type PoseidonTranscript<L, S> =
    system::halo2::transcript::halo2::PoseidonTranscript<G1Affine, L, S, T, RATE, R_F, R_P>;

#[derive(Error, Debug)]
/// Errors related to proof aggregation
pub enum AggregationError {
    /// A KZG proof could not be verified
    #[error("failed to verify KZG proof")]
    KZGProofVerification,
    /// EVM execution errors
    #[error("EVM execution of raw code failed")]
    EVMRawExecution,
    /// proof read errors
    #[error("Failed to read proof")]
    ProofRead,
    /// proof verification errors
    #[error("Failed to verify proof")]
    ProofVerify,
    /// proof creation errors
    #[error("Failed to create proof")]
    ProofCreate,
}

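`thiserror` derives `Display` from the `#[error(...)]` attributes, so call sites can convert foreign errors into these variants with `map_err` and still get readable messages. A minimal sketch, with a hypothetical fallible reader standing in for `Plonk::read_proof`:

```rust
use thiserror::Error;

#[derive(Error, Debug)]
pub enum AggregationError {
    #[error("Failed to read proof")]
    ProofRead,
}

// Hypothetical stand-in for a library call that can fail.
fn read_bytes(buf: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    Ok(buf.to_vec())
}

fn read_proof(buf: &[u8]) -> Result<Vec<u8>, AggregationError> {
    // Discard the underlying error and surface a domain-specific one,
    // mirroring `.map_err(|_| AggregationError::ProofRead)?` in the diff.
    read_bytes(buf).map_err(|_| AggregationError::ProofRead)
}

fn main() {
    assert!(read_proof(&[1, 2, 3]).is_ok());
    assert_eq!(
        AggregationError::ProofRead.to_string(),
        "Failed to read proof"
    );
}
```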
/// An application snark with proof and instance variables ready for aggregation (raw field element)
#[derive(Debug)]
pub struct Snark {
@@ -142,7 +164,7 @@ pub fn aggregate<'a>(
    loader: &Rc<Halo2Loader<'a>>,
    snarks: &[SnarkWitness],
    as_proof: Value<&'_ [u8]>,
) -> KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>> {
) -> Result<KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>, plonk::Error> {
    let assign_instances = |instances: &[Vec<Value<Fr>>]| {
        instances
            .iter()
@@ -155,22 +177,24 @@ pub fn aggregate<'a>(
            .collect_vec()
    };

    let accumulators = snarks
        .iter()
        .flat_map(|snark| {
            let protocol = snark.protocol.loaded(loader);
            let instances = assign_instances(&snark.instances);
            let mut transcript =
                PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
            let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap();
            Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap()
        })
        .collect_vec();
    let mut accumulators = vec![];

    for snark in snarks.iter() {
        let protocol = snark.protocol.loaded(loader);
        let instances = assign_instances(&snark.instances);
        let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
        let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript)
            .map_err(|_| plonk::Error::Synthesis)?;
        let mut accum = Plonk::succinct_verify(svk, &protocol, &instances, &proof)
            .map_err(|_| plonk::Error::Synthesis)?;
        accumulators.append(&mut accum);
    }

    let accumulator = {
        let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, as_proof);
        let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap();
        As::verify(&Default::default(), &accumulators, &proof).unwrap()
        let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript)
            .map_err(|_| plonk::Error::Synthesis)?;
        As::verify(&Default::default(), &accumulators, &proof).map_err(|_| plonk::Error::Synthesis)
    };

    accumulator
@@ -229,28 +253,31 @@ pub struct AggregationCircuit {

impl AggregationCircuit {
    /// Create a new aggregation circuit from a SuccinctVerifyingKey and application snark witnesses (each with a proof and instance variables), producing the instance variables and the resulting aggregation circuit proof.
    pub fn new(params: &ParamsKZG<Bn256>, snarks: impl IntoIterator<Item = Snark>) -> Self {
    pub fn new(
        params: &ParamsKZG<Bn256>,
        snarks: impl IntoIterator<Item = Snark>,
    ) -> Result<Self, AggregationError> {
        let svk = params.get_g()[0].into();
        let snarks = snarks.into_iter().collect_vec();
        let accumulators = snarks
            .iter()
            .flat_map(|snark| {
                trace!("Aggregating with snark instances {:?}", snark.instances);
                let mut transcript =
                    PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
                let proof =
                    Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript)
                        .unwrap();
                Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap()
            })
            .collect_vec();

        let mut accumulators = vec![];

        for snark in snarks.iter() {
            trace!("Aggregating with snark instances {:?}", snark.instances);
            let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
            let proof = Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript)
                .map_err(|_| AggregationError::ProofRead)?;
            let mut accum = Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof)
                .map_err(|_| AggregationError::ProofVerify)?;
            accumulators.append(&mut accum);
        }

        trace!("Accumulator");
        let (accumulator, as_proof) = {
            let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(Vec::new());
            let accumulator =
                As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng)
                    .unwrap();
                    .map_err(|_| AggregationError::ProofCreate)?;
            (accumulator, transcript.finalize())
        };

@@ -260,12 +287,12 @@ impl AggregationCircuit {
            .map(fe_to_limbs::<_, _, LIMBS, BITS>)
            .concat();

        Self {
        Ok(Self {
            svk,
            snarks: snarks.into_iter().map_into().collect(),
            instances,
            as_proof: Value::known(as_proof),
        }
        })
    }

    /// Accumulator indices used in generating verifier.
@@ -331,7 +358,7 @@ impl Circuit<Fr> for AggregationCircuit {
        let ecc_chip = config.ecc_chip();
        let loader = Halo2Loader::new(ecc_chip, ctx);
        let KzgAccumulator { lhs, rhs } =
            aggregate(&self.svk, &loader, &self.snarks, self.as_proof());
            aggregate(&self.svk, &loader, &self.snarks, self.as_proof())?;

        let lhs = lhs.assigned().clone();
        let rhs = rhs.assigned().clone();
@@ -355,10 +382,14 @@ impl Circuit<Fr> for AggregationCircuit {
}

/// Create proof and instance variables for the application snark
pub fn gen_application_snark(params: &ParamsKZG<Bn256>, data: &ModelInput, args: &Cli) -> Snark {
    let (circuit, public_inputs) = prepare_circuit_and_public_input::<Fr>(data, &args);
pub fn gen_application_snark(
    params: &ParamsKZG<Bn256>,
    data: &ModelInput,
    args: &Cli,
) -> Result<Snark, Box<dyn Error>> {
    let (circuit, public_inputs) = prepare_circuit_and_public_input::<Fr>(data, args)?;

    let pk = gen_pk(params, &circuit);
    let pk = gen_pk(params, &circuit)?;
    let number_instance = public_inputs[0].len();
    trace!("number_instance {:?}", number_instance);
    let protocol = compile(
@@ -377,8 +408,8 @@ pub fn gen_application_snark(params: &ParamsKZG<Bn256>, data: &ModelInput, args:
        _,
        PoseidonTranscript<NativeLoader, _>,
        PoseidonTranscript<NativeLoader, _>,
    >(params, &pk, circuit, pi_inner.clone());
    Snark::new(protocol, pi_inner, proof)
    >(params, &pk, circuit, pi_inner.clone())?;
    Ok(Snark::new(protocol, pi_inner, proof))
}

/// Create aggregation EVM verifier bytecode
@@ -387,7 +418,7 @@ pub fn gen_aggregation_evm_verifier(
    vk: &VerifyingKey<G1Affine>,
    num_instance: Vec<usize>,
    accumulator_indices: Vec<(usize, usize)>,
) -> Vec<u8> {
) -> Result<Vec<u8>, AggregationError> {
    let svk = params.get_g()[0].into();
    let dk = (params.g2(), params.s_g2()).into();
    let protocol = compile(
@@ -400,37 +431,40 @@ pub fn gen_aggregation_evm_verifier(

    let loader = EvmLoader::new::<Fq, Fr>();
    let protocol = protocol.loaded(&loader);
    let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader.clone());
    let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader);

    let instances = transcript.load_instances(num_instance);
    let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap();
    Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap();
    let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript)
        .map_err(|_| AggregationError::ProofRead)?;
    Plonk::verify(&svk, &dk, &protocol, &instances, &proof)
        .map_err(|_| AggregationError::ProofVerify)?;

    evm::compile_yul(&loader.yul_code())
    Ok(evm::compile_yul(&loader.yul_code()))
}

/// Verify by executing bytecode with instance variables and proof as input
pub fn evm_verify(deployment_code: Vec<u8>, instances: Vec<Vec<Fr>>, proof: Vec<u8>) {
pub fn evm_verify(
    deployment_code: Vec<u8>,
    instances: Vec<Vec<Fr>>,
    proof: Vec<u8>,
) -> Result<bool, Box<dyn Error>> {
    let calldata = encode_calldata(&instances, &proof);
    let success = {
        let mut evm = ExecutorBuilder::default()
            .with_gas_limit(u64::MAX.into())
            .build(Backend::new(MultiFork::new().0, None));
    let mut evm = ExecutorBuilder::default()
        .with_gas_limit(u64::MAX.into())
        .build(Backend::new(MultiFork::new().0, None));

        let caller = Address::from_low_u64_be(0xfe);
        let verifier = evm
            .deploy(caller, deployment_code.into(), 0.into(), None)
            .unwrap()
            .address;
        let result = evm
            .call_raw(caller, verifier, calldata.into(), 0.into())
            .unwrap();
    let caller = Address::from_low_u64_be(0xfe);
    let verifier = evm
        .deploy(caller, deployment_code.into(), 0.into(), None)
        .map_err(Box::new)?
        .address;
    let result = evm
        .call_raw(caller, verifier, calldata.into(), 0.into())
        .map_err(|_| Box::new(AggregationError::EVMRawExecution))?;

        dbg!(result.gas_used);
    dbg!(result.gas_used);

        !result.reverted
    };
    assert!(success);
    Ok(!result.reverted)
}

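Since `evm_verify` now reports the outcome instead of asserting, the caller decides what a revert means. A sketch of a call site (this wrapper is illustrative, not part of the commit):

```rust
use std::error::Error;

// Hypothetical wrapper: treat a revert as a reportable failure rather than a panic.
fn check_on_evm(
    deployment_code: Vec<u8>,
    instances: Vec<Vec<Fr>>,
    proof: Vec<u8>,
) -> Result<(), Box<dyn Error>> {
    let verified = evm_verify(deployment_code, instances, proof)?;
    if !verified {
        return Err("EVM verifier reverted".into());
    }
    Ok(())
}
```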
/// Generate a structured reference string for testing. Not secure, do not use in production.
@@ -439,9 +473,12 @@ pub fn gen_srs(k: u32) -> ParamsKZG<Bn256> {
}

/// Generate the proving key
pub fn gen_pk<C: Circuit<Fr>>(params: &ParamsKZG<Bn256>, circuit: &C) -> ProvingKey<G1Affine> {
    let vk = keygen_vk(params, circuit).unwrap();
    keygen_pk(params, vk, circuit).unwrap()
pub fn gen_pk<C: Circuit<Fr>>(
    params: &ParamsKZG<Bn256>,
    circuit: &C,
) -> Result<ProvingKey<G1Affine>, plonk::Error> {
    let vk = keygen_vk(params, circuit)?;
    keygen_pk(params, vk, circuit)
}

/// Generates proof for either application circuit (model) or aggregation circuit.
@@ -455,43 +492,41 @@ pub fn gen_kzg_proof<
    pk: &ProvingKey<G1Affine>,
    circuit: C,
    instances: Vec<Vec<Fr>>,
) -> Vec<u8> {
) -> Result<Vec<u8>, Box<dyn Error>> {
    MockProver::run(params.k(), &circuit, instances.clone())
        .unwrap()
        .map_err(Box::new)?
        .assert_satisfied();

    let instances = instances
        .iter()
        .map(|instances| instances.as_slice())
        .collect_vec();
    let proof = {
        let mut transcript = TW::init(Vec::new());
        create_proof::<KZGCommitmentScheme<Bn256>, ProverGWC<_>, _, _, TW, _>(
            params,
            pk,
            &[circuit],
            &[instances.as_slice()],
            OsRng,
            &mut transcript,
        )
        .unwrap();
        transcript.finalize()
    };
    let mut proof = TW::init(Vec::new());
    create_proof::<KZGCommitmentScheme<Bn256>, ProverGWC<_>, _, _, TW, _>(
        params,
        pk,
        &[circuit],
        &[instances.as_slice()],
        OsRng,
        &mut proof,
    )
    .map_err(Box::new)?;
    let proof = proof.finalize();

    let accept = {
        let mut transcript = TR::init(Cursor::new(proof.clone()));
        VerificationStrategy::<_, VerifierGWC<_>>::finalize(
            verify_proof::<_, VerifierGWC<_>, _, TR, _>(
                params.verifier_params(),
                pk.get_vk(),
                AccumulatorStrategy::new(params.verifier_params()),
                &[instances.as_slice()],
                &mut transcript,
            )
            .unwrap(),
        )
    };
    assert!(accept);
    let mut transcript = TR::init(Cursor::new(proof.clone()));
    let verify = verify_proof::<_, VerifierGWC<_>, _, TR, _>(
        params.verifier_params(),
        pk.get_vk(),
        AccumulatorStrategy::new(params.verifier_params()),
        &[instances.as_slice()],
        &mut transcript,
    )
    .map_err(Box::new)?;

    proof
    let accept = VerificationStrategy::<_, VerifierGWC<_>>::finalize(verify);

    if !accept {
        return Err(Box::new(AggregationError::KZGProofVerification));
    }
    Ok(proof)
}

247
src/pfsys/mod.rs
@@ -2,11 +2,11 @@
#[cfg(feature = "evm")]
pub mod aggregation;

use crate::abort;
use crate::commands::{data_path, Cli};
use crate::fieldutils::i32_to_felt;
use crate::graph::{utilities::vector_to_quantized, Model, ModelCircuit};
use crate::tensor::{Tensor, TensorType};
use halo2_proofs::arithmetic::FieldExt;
use halo2_proofs::plonk::{
    create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
@@ -15,12 +15,12 @@ use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{
    Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer,
};
use halo2_proofs::{arithmetic::FieldExt, dev::VerifyFailure};
use log::{error, info, trace};
use log::{info, trace};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::marker::PhantomData;
use std::ops::Deref;
use std::path::PathBuf;
@@ -49,180 +49,95 @@ pub struct Proof {

impl Proof {
    /// Saves the Proof to a specified `proof_path`.
    pub fn save(&self, proof_path: &PathBuf) {
        let serialized = match serde_json::to_string(&self) {
            Ok(s) => s,
            Err(e) => {
                abort!("failed to convert proof json to string {:?}", e);
            }
        };
    pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
        let serialized = serde_json::to_string(&self).map_err(Box::<dyn Error>::from)?;

        let mut file = std::fs::File::create(proof_path).expect("create failed");
        file.write_all(serialized.as_bytes()).expect("write failed");
        let mut file = std::fs::File::create(proof_path).map_err(Box::<dyn Error>::from)?;
        file.write_all(serialized.as_bytes())
            .map_err(Box::<dyn Error>::from)
    }

    /// Load a JSON-serialized proof from the provided path.
    pub fn load(proof_path: &PathBuf) -> Self {
        let mut file = match File::open(proof_path) {
            Ok(f) => f,
            Err(e) => {
                abort!("failed to open proof file {:?}", e);
            }
        };
    pub fn load(proof_path: &PathBuf) -> Result<Self, Box<dyn Error>> {
        let mut file = File::open(proof_path).map_err(Box::<dyn Error>::from)?;
        let mut data = String::new();
        match file.read_to_string(&mut data) {
            Ok(_) => {}
            Err(e) => {
                abort!("failed to read file {:?}", e);
            }
        };
        serde_json::from_str(&data).expect("JSON was not well-formatted")
        file.read_to_string(&mut data)
            .map_err(Box::<dyn Error>::from)?;
        serde_json::from_str(&data).map_err(Box::<dyn Error>::from)
    }
}

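Both methods now compose with `?`, so a save/load roundtrip reads linearly. A sketch over the fallible API above (the path is illustrative):

```rust
use std::error::Error;
use std::path::PathBuf;

// Hypothetical roundtrip: serialize to disk, then read back.
fn roundtrip(proof: &Proof) -> Result<Proof, Box<dyn Error>> {
    let path = PathBuf::from("proof.json"); // illustrative path
    proof.save(&path)?;
    Proof::load(&path)
}
```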
/// Helper function to print helpful error messages after verification has failed.
pub fn parse_prover_errors(f: &VerifyFailure) {
    match f {
        VerifyFailure::Lookup {
            name,
            location,
            lookup_index,
        } => {
            error!("lookup {:?} is out of range, try increasing 'bits' or reducing 'scale' ({} and lookup index {}).",
            name, location, lookup_index);
        }
        VerifyFailure::ConstraintNotSatisfied {
            constraint,
            location,
            cell_values: _,
        } => {
            error!("{} was not satisfied ({}).", constraint, location);
        }
        VerifyFailure::ConstraintPoisoned { constraint } => {
            error!("constraint {:?} was poisoned", constraint);
        }
        VerifyFailure::Permutation { column, location } => {
            error!(
                "permutation did not preserve column cell value (try increasing 'scale') ({} {}).",
                column, location
            );
        }
        VerifyFailure::CellNotAssigned {
            gate,
            region,
            gate_offset,
            column,
            offset,
        } => {
            error!(
                "Unassigned value in {} ({}) and {} ({:?}, {})",
                gate, region, gate_offset, column, offset
            );
        }
    }
}
type CircuitInputs<F> = (ModelCircuit<F>, Vec<Tensor<i32>>);

/// Initialize the model circuit and quantize the provided float inputs from the provided `ModelInput`.
pub fn prepare_circuit_and_public_input<F: FieldExt>(
    data: &ModelInput,
    args: &Cli,
) -> (ModelCircuit<F>, Vec<Tensor<i32>>) {
    let model = Model::from_ezkl_conf(args.clone());
) -> Result<CircuitInputs<F>, Box<dyn Error>> {
    let model = Model::from_ezkl_conf(args.clone())?;
    let out_scales = model.get_output_scales();
    let circuit = prepare_circuit(data, args);
    let circuit = prepare_circuit(data, args)?;

    // quantize the supplied data using the provided scale.
    // the ordering here is important: we want the inputs to come before the outputs,
    // as they are configured in that order as Column<Instances>
    let mut public_inputs = vec![];
    if model.visibility.input.is_public() {
        let mut res = data
            .input_data
            .iter()
            .map(
                |v| match vector_to_quantized(v, &Vec::from([v.len()]), 0.0, model.scale) {
                    Ok(q) => q,
                    Err(e) => {
                        abort!("failed to quantize vector {:?}", e);
                    }
                },
            )
            .collect();
        public_inputs.append(&mut res);
        for v in data.input_data.iter() {
            let t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, model.scale)?;
            public_inputs.push(t);
        }
    }
    if model.visibility.output.is_public() {
        let mut res = data
            .output_data
            .iter()
            .enumerate()
            .map(|(idx, v)| {
                match vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx]) {
                    Ok(q) => q,
                    Err(e) => {
                        abort!("failed to quantize vector {:?}", e);
                    }
                }
            })
            .collect();
        public_inputs.append(&mut res);
        for (idx, v) in data.output_data.iter().enumerate() {
            let t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx])?;
            public_inputs.push(t);
        }
    }
    info!(
        "public inputs lengths: {:?}",
        public_inputs
            .iter()
            .map(|i| i.len())
            .collect::<Vec<usize>>()
    );
    trace!("{:?}", public_inputs);
    (circuit, public_inputs)

    Ok((circuit, public_inputs))
}

/// Initialize the model circuit
pub fn prepare_circuit<F: FieldExt>(data: &ModelInput, args: &Cli) -> ModelCircuit<F> {
pub fn prepare_circuit<F: FieldExt>(
    data: &ModelInput,
    args: &Cli,
) -> Result<ModelCircuit<F>, Box<dyn Error>> {
    // quantize the supplied data using the provided scale.
    let inputs = data
        .input_data
        .iter()
        .zip(data.input_shapes.clone())
        .map(|(i, s)| match vector_to_quantized(i, &s, 0.0, args.scale) {
            Ok(q) => q,
            Err(e) => {
                abort!("failed to quantize vector {:?}", e);
            }
        })
        .collect();
    let mut inputs: Vec<Tensor<i32>> = vec![];
    for (input, shape) in data.input_data.iter().zip(data.input_shapes.clone()) {
        let t = vector_to_quantized(input, &shape, 0.0, args.scale)?;
        inputs.push(t);
    }

    ModelCircuit::<F> {
    Ok(ModelCircuit::<F> {
        inputs,
        _marker: PhantomData,
    }
    })
}

/// Deserializes the required inputs to a model at path `datapath` to a [ModelInput] struct.
pub fn prepare_data(datapath: String) -> ModelInput {
    let mut file = match File::open(data_path(datapath)) {
        Ok(t) => t,
        Err(e) => {
            abort!("failed to open data file {:?}", e);
        }
    };
pub fn prepare_data(datapath: String) -> Result<ModelInput, Box<dyn Error>> {
    let mut file = File::open(data_path(datapath)).map_err(Box::<dyn Error>::from)?;
    let mut data = String::new();
    match file.read_to_string(&mut data) {
        Ok(_) => {}
        Err(e) => {
            abort!("failed to read file {:?}", e);
        }
    };
    let data: ModelInput = serde_json::from_str(&data).expect("JSON was not well-formatted");

    data
    file.read_to_string(&mut data)
        .map_err(Box::<dyn Error>::from)?;
    serde_json::from_str(&data).map_err(Box::<dyn Error>::from)
}

/// Creates a [VerifyingKey] and [ProvingKey] for a [ModelCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`).
pub fn create_keys<Scheme: CommitmentScheme, F: FieldExt + TensorType>(
    circuit: &ModelCircuit<F>,
    params: &'_ Scheme::ParamsProver,
) -> ProvingKey<Scheme::Curve>
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
    ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
@@ -232,12 +147,12 @@ where
    // Initialize the proving key
    let now = Instant::now();
    trace!("preparing VK");
    let vk = keygen_vk(params, &empty_circuit).expect("keygen_vk should not fail");
    let vk = keygen_vk(params, &empty_circuit)?;
    info!("VK took {}", now.elapsed().as_secs());
    let now = Instant::now();
    let pk = keygen_pk(params, vk, &empty_circuit).expect("keygen_pk should not fail");
    let pk = keygen_pk(params, vk, &empty_circuit)?;
    info!("PK took {}", now.elapsed().as_secs());
    pk
    Ok(pk)
}

/// A wrapper around halo2's create_proof
@@ -251,7 +166,7 @@ pub fn create_proof_model<
    public_inputs: &[Tensor<i32>],
    params: &'params Scheme::ParamsProver,
    pk: &ProvingKey<Scheme::Curve>,
) -> (Proof, Vec<Vec<usize>>)
) -> Result<(Proof, Vec<Vec<usize>>), halo2_proofs::plonk::Error>
where
    ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
@@ -282,8 +197,7 @@ where
        instances,
        &mut rng,
        &mut transcript,
    )
    .expect("proof generation should not fail");
    )?;
    let proof = transcript.finalize();
    info!("Proof took {}", now.elapsed().as_secs());

@@ -295,7 +209,7 @@ where
        proof,
    };

    (checkable_pf, dims)
    Ok((checkable_pf, dims))
}

/// A wrapper around halo2's verify_proof
@@ -310,7 +224,7 @@ pub fn verify_proof_model<
    params: &'params Scheme::ParamsVerifier,
    vk: &VerifyingKey<Scheme::Curve>,
    strategy: Strategy,
) -> bool
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
where
    ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
@@ -332,60 +246,57 @@ where

    let now = Instant::now();
    let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof.proof[..]);

    let result =
        verify_proof::<Scheme, V, _, _, _>(params, vk, strategy, instances, &mut transcript)
            .is_ok();
    info!("verify took {}", now.elapsed().as_secs());
    result
    verify_proof::<Scheme, V, _, _, _>(params, vk, strategy, instances, &mut transcript)
}

/// Loads a [VerifyingKey] at `path`.
pub fn load_vk<Scheme: CommitmentScheme, F: FieldExt + TensorType>(
    path: PathBuf,
    params: &'_ Scheme::ParamsVerifier,
) -> VerifyingKey<Scheme::Curve>
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
where
    ModelCircuit<F>: Circuit<Scheme::Scalar>,
{
    info!("loading verification key from {:?}", path);
    let f = match File::open(path) {
        Ok(f) => f,
        Err(e) => {
            abort!("failed to load vk {}", e);
        }
    };
    let f = File::open(path).map_err(Box::<dyn Error>::from)?;
    let mut reader = BufReader::new(f);
    VerifyingKey::<Scheme::Curve>::read::<_, ModelCircuit<F>>(&mut reader, params).unwrap()
    VerifyingKey::<Scheme::Curve>::read::<_, ModelCircuit<F>>(&mut reader, params)
        .map_err(Box::<dyn Error>::from)
}

/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
pub fn load_params<Scheme: CommitmentScheme>(path: PathBuf) -> Scheme::ParamsVerifier {
pub fn load_params<Scheme: CommitmentScheme>(
    path: PathBuf,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
    info!("loading params from {:?}", path);
    let f = match File::open(path) {
        Ok(f) => f,
        Err(e) => {
            abort!("failed to load params {}", e);
        }
    };
    let f = File::open(path).map_err(Box::<dyn Error>::from)?;
    let mut reader = BufReader::new(f);
    Params::<'_, Scheme::Curve>::read(&mut reader).unwrap()
    Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
}

/// Saves a [VerifyingKey] to `path`.
pub fn save_vk<Scheme: CommitmentScheme>(path: &PathBuf, vk: &VerifyingKey<Scheme::Curve>) {
pub fn save_vk<Scheme: CommitmentScheme>(
    path: &PathBuf,
    vk: &VerifyingKey<Scheme::Curve>,
) -> Result<(), io::Error> {
    info!("saving verification key 💾");
    let f = File::create(path).unwrap();
    let f = File::create(path)?;
    let mut writer = BufWriter::new(f);
    vk.write(&mut writer).unwrap();
    writer.flush().unwrap();
    vk.write(&mut writer)?;
    writer.flush()?;
    Ok(())
}

/// Saves [CommitmentScheme] parameters to `path`.
pub fn save_params<Scheme: CommitmentScheme>(path: &PathBuf, params: &'_ Scheme::ParamsVerifier) {
pub fn save_params<Scheme: CommitmentScheme>(
    path: &PathBuf,
    params: &'_ Scheme::ParamsVerifier,
) -> Result<(), io::Error> {
    info!("saving parameters 💾");
    let f = File::create(path).unwrap();
    let f = File::create(path)?;
    let mut writer = BufWriter::new(f);
    params.write(&mut writer).unwrap();
    writer.flush().unwrap();
    params.write(&mut writer)?;
    writer.flush()?;
    Ok(())
}

@@ -18,11 +18,26 @@ use halo2_proofs::{
};
use itertools::Itertools;
use std::cmp::max;
use std::error::Error;
use std::fmt::Debug;
use std::iter::Iterator;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Range;
use thiserror::Error;
/// A wrapper for tensor-related errors.
#[derive(Debug, Error)]
pub enum TensorError {
    /// Shape mismatch in an operation
    #[error("dimension mismatch in tensor op: {0}")]
    DimMismatch(String),
    /// Shape error when instantiating
    #[error("dimensionality error when manipulating a tensor")]
    DimError,
    /// wrong method was called on a tensor-like struct
    #[error("wrong method called")]
    WrongMethod,
}

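A quick sketch of how these variants read at a call site (the shape check itself is hypothetical):

```rust
use ezkl::tensor::TensorError;

fn check_same_dims(a: &[usize], b: &[usize]) -> Result<(), TensorError> {
    if a != b {
        // The String payload names the op that failed, as in add/sub/matmul below.
        return Err(TensorError::DimMismatch("example op".to_string()));
    }
    Ok(())
}

fn main() {
    let err = check_same_dims(&[2, 3], &[3, 2]).unwrap_err();
    // thiserror derives Display from the #[error(...)] attribute:
    assert_eq!(err.to_string(), "dimension mismatch in tensor op: example op");
}
```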
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
@@ -173,10 +188,6 @@ impl TensorType for halo2curves::bn256::Fr {
    }
}

/// A wrapper for tensor-related errors.
#[derive(Debug)]
pub struct TensorError(String);

/// A generic multi-dimensional array representation of a Tensor.
/// The `inner` attribute contains a vector of values whereas `dims` corresponds to the dimensionality of the array
/// and as such determines how we index, query for values, or slice a Tensor.
@@ -217,13 +228,24 @@ impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
    }
}

impl<I: Iterator, T: Clone + TensorType + From<I::Item>> From<I> for Tensor<T>
impl<I: Iterator> From<I> for Tensor<I::Item>
where
    I::Item: Clone + TensorType,
    Vec<T>: FromIterator<I::Item>,
    I::Item: TensorType + Clone,
    Vec<I::Item>: FromIterator<I::Item>,
{
    fn from(value: I) -> Tensor<T> {
        let data: Vec<T> = value.collect::<Vec<T>>();
    fn from(value: I) -> Tensor<I::Item> {
        let data: Vec<I::Item> = value.collect::<Vec<I::Item>>();
        Tensor::new(Some(&data), &[data.len()]).unwrap()
    }
}

impl<T> FromIterator<T> for Tensor<T>
where
    T: TensorType + Clone,
    Vec<T>: FromIterator<T>,
{
    fn from_iter<I: IntoIterator<Item = T>>(value: I) -> Tensor<T> {
        let data: Vec<I::Item> = value.into_iter().collect::<Vec<I::Item>>();
        Tensor::new(Some(&data), &[data.len()]).unwrap()
    }
}
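With the narrowed `From<I>` impl and the new `FromIterator`, both conversion styles below should work; a small sketch:

```rust
use ezkl::tensor::Tensor;

fn main() {
    // Directly from an iterator (the From<I> impl above):
    let a: Tensor<i32> = Tensor::from([1, 2, 3].into_iter());
    // Or via collect (the FromIterator impl above):
    let b: Tensor<i32> = (1..=3).collect();
    assert_eq!(a, b);
}
```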
@@ -292,9 +314,7 @@ impl<T: Clone + TensorType> Tensor<T> {
        match values {
            Some(v) => {
                if total_dims != v.len() {
                    return Err(TensorError(
                        "length of values array is not equal to tensor total elements".to_string(),
                    ));
                    return Err(TensorError::DimError);
                }
                Ok(Tensor {
                    inner: Vec::from(v),
@@ -355,10 +375,12 @@ impl<T: Clone + TensorType> Tensor<T> {
    /// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
    /// let mut b = Tensor::<i32>::new(Some(&[1, 2]), &[2]).unwrap();
    ///
    /// assert_eq!(a.get_slice(&[0..2]), b);
    /// assert_eq!(a.get_slice(&[0..2]).unwrap(), b);
    /// ```
    pub fn get_slice(&self, indices: &[Range<usize>]) -> Tensor<T> {
        assert!(self.dims.len() >= indices.len());
    pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<Tensor<T>, TensorError> {
        if self.dims.len() < indices.len() {
            return Err(TensorError::DimError);
        }
        let mut res = Vec::new();
        // if indices weren't specified we fill them in as required
        let mut full_indices = indices.to_vec();
@@ -377,7 +399,7 @@ impl<T: Clone + TensorType> Tensor<T> {
            }
        }

        Tensor::new(Some(&res), &dims).unwrap()
        Tensor::new(Some(&res), &dims)
    }

    /// Get the array index from rows / columns indices.
@@ -446,16 +468,22 @@ impl<T: Clone + TensorType> Tensor<T> {

    /// Maps a function over the tensor's entries, supplying the flat index as well.
    /// ```
    /// use ezkl::tensor::Tensor;
    /// use ezkl::tensor::{Tensor, TensorError};
    /// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
    /// let mut c = a.enum_map(|i, x| i32::pow(x + i as i32, 2)).unwrap();
    /// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
    /// assert_eq!(c, Tensor::from([1, 25].into_iter()));
    /// ```
    pub fn enum_map<F: FnMut(usize, T) -> G, G: TensorType>(
    pub fn enum_map<F: FnMut(usize, T) -> Result<G, E>, G: TensorType, E: Error>(
        &self,
        mut f: F,
    ) -> Result<Tensor<G>, TensorError> {
        let mut t = Tensor::from(self.inner.iter().enumerate().map(|(i, e)| f(i, e.clone())));
    ) -> Result<Tensor<G>, E> {
        let vec: Result<Vec<G>, E> = self
            .inner
            .iter()
            .enumerate()
            .map(|(i, e)| f(i, e.clone()))
            .collect();
        let mut t: Tensor<G> = Tensor::from(vec?.iter().cloned());
        t.reshape(self.dims());
        Ok(t)
    }
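The fallible closure lets a per-element conversion short-circuit on the first failure. A sketch (reusing `DimMismatch` purely for illustration — any `E: Error` works):

```rust
use ezkl::tensor::{Tensor, TensorError};

fn main() {
    let a = Tensor::<i32>::new(Some(&[10, 20, 0]), &[3]).unwrap();
    // Stop at the first zero entry instead of dividing by it.
    let inverted = a.enum_map::<_, _, TensorError>(|_, x| {
        if x == 0 {
            Err(TensorError::DimMismatch("zero entry".to_string()))
        } else {
            Ok(100 / x)
        }
    });
    assert!(inverted.is_err());
}
```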
@@ -540,6 +568,6 @@ mod tests {
    fn tensor_slice() {
        let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
        let b = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
        assert_eq!(a.get_slice(&[0..2, 0..1]), b);
        assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
    }
}

@@ -1,3 +1,4 @@
use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
pub use std::ops::{Add, Div, Mul, Sub};
@@ -23,17 +24,20 @@ pub use std::ops::{Add, Div, Mul, Sub};
///     Some(&[0, 0]),
///     &[2],
/// ).unwrap();
/// let result = affine(&vec![x, k, b]);
/// let result = affine(&vec![x, k, b]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[26, 7, 11, 3, 15, 3, 7, 2]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn affine<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    inputs: &Vec<Tensor<T>>,
) -> Tensor<T> {
    assert_eq!(inputs.len(), 3);
) -> Result<Tensor<T>, TensorError> {
    let (mut input, kernel, bias) = (inputs[0].clone(), inputs[1].clone(), inputs[2].clone());
    assert_eq!(bias.dims()[0], kernel.dims()[0]);
    assert_eq!(input.dims()[0], kernel.dims()[1]);
    if (inputs.len() != 3)
        || (bias.dims()[0] != kernel.dims()[0])
        || (input.dims()[0] != kernel.dims()[1])
    {
        return Err(TensorError::DimMismatch("affine".to_string()));
    }

    // does matrix to vector multiplication
    if input.dims().len() == 1 {
@@ -49,9 +53,9 @@ pub fn affine<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    for i in 0..kernel_dims[0] {
        for j in 0..input_dims[1] {
            let prod = dot(&vec![
                &kernel.get_slice(&[i..i + 1]),
                &input.get_slice(&[0..input_dims[0], j..j + 1]),
            ]);
                &kernel.get_slice(&[i..i + 1])?,
                &input.get_slice(&[0..input_dims[0], j..j + 1])?,
            ])?;
            output.set(&[i, j], prod[0].clone() + bias[i].clone());
        }
    }
@@ -59,23 +63,50 @@
    if output.dims()[1] == 1 {
        output.flatten();
    }
    output
    Ok(output)
}

/// Scales and shifts a tensor.
/// Given inputs (x,k,b) computes k*x + b elementwise
/// # Arguments
///
/// * `inputs` - Vector of tensors of length 3
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scale_and_shift;
///
/// let x = Tensor::<i32>::new(
///     Some(&[2, 1, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
///     Some(&[2, 1, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let b = Tensor::<i32>::new(
///     Some(&[2, 1, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let result = scale_and_shift(&vec![x, k, b]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[6, 2, 6, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn scale_and_shift<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    inputs: &Vec<Tensor<T>>,
) -> Tensor<T> {
    assert_eq!(inputs.len(), 3);
) -> Result<Tensor<T>, TensorError> {
    if (inputs.len() != 3)
        || (inputs[1].dims() != inputs[2].dims())
        || (inputs[0].dims() != inputs[1].dims())
    {
        return Err(TensorError::DimMismatch("scale and shift".to_string()));
    }
    let (input, kernel, bias) = (inputs[0].clone(), inputs[1].clone(), inputs[2].clone());
    assert_eq!(input.dims(), kernel.dims());
    assert_eq!(bias.dims(), kernel.dims());
    let mut output: Tensor<T> = input;
    for (i, bias_i) in bias.iter().enumerate() {
        output[i] = kernel[i].clone() * output[i].clone() + bias_i.clone()
    }
    output
    Ok(output)
}

/// Matrix multiplies two 2D tensors.
@@ -95,20 +126,20 @@ pub fn scale_and_shift<T: TensorType + Mul<Output = T> + Add<Output = T>>(
///     Some(&[2, 1, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let result = matmul(&vec![k, x]);
/// let result = matmul(&vec![k, x]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[26, 7, 11, 3, 15, 3, 7, 2]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn matmul<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    inputs: &Vec<Tensor<T>>,
) -> Tensor<T> {
    assert_eq!(inputs.len(), 2);
) -> Result<Tensor<T>, TensorError> {
    let (a, b) = (inputs[0].clone(), inputs[1].clone());
    assert_eq!(a.dims()[a.dims().len() - 1], b.dims()[a.dims().len() - 2]);
    assert_eq!(
        a.dims()[0..a.dims().len() - 2],
        b.dims()[0..a.dims().len() - 2]
    );
    if (inputs.len() != 2)
        || (a.dims()[a.dims().len() - 1] != b.dims()[a.dims().len() - 2])
        || (a.dims()[0..a.dims().len() - 2] != b.dims()[0..a.dims().len() - 2])
    {
        return Err(TensorError::DimMismatch("matmul".to_string()));
    }

    let mut dims = Vec::from(&a.dims()[0..a.dims().len() - 2]);
    dims.push(a.dims()[a.dims().len() - 2]);
@@ -128,11 +159,11 @@ pub fn matmul<T: TensorType + Mul<Output = T> + Add<Output = T>>(
            .map(|&d| d..(d + 1))
            .collect::<Vec<_>>();
        col[coord.len() - 2] = 0..b.dims()[coord.len() - 2];
        let prod = dot(&vec![&a.get_slice(&row[0..]), &b.get_slice(&col[0..])]);
        let prod = dot(&vec![&a.get_slice(&row[0..])?, &b.get_slice(&col[0..])?])?;
        output.set(&coord, prod[0].clone());
    }

    output
    Ok(output)
}

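Dimension checks that used to be asserts now return `DimMismatch`, so shape bugs surface as values. A sketch:

```rust
use ezkl::tensor::ops::matmul;
use ezkl::tensor::Tensor;

fn main() {
    let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
    let b = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
    // 2x2 times 3x1: inner dimensions disagree, so this is an Err, not a panic.
    assert!(matmul(&vec![a, b]).is_err());
}
```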
/// Adds multiple tensors.
|
||||
@@ -151,17 +182,19 @@ pub fn matmul<T: TensorType + Mul<Output = T> + Add<Output = T>>(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = add(&vec![x, k]);
|
||||
/// let result = add(&vec![x, k]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
|
||||
pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Result<Tensor<T>, TensorError> {
|
||||
// determines if we're multiplying by a 1D const
|
||||
if t.len() == 2 && t[1].dims().len() == 1 && t[1].dims()[0] == 1 {
|
||||
return const_add(&t[0], t[1][0].clone());
|
||||
}
|
||||
for e in t.iter() {
|
||||
assert_eq!(t[0].dims(), e.dims());
|
||||
if t[0].dims() != e.dims() {
|
||||
return Err(TensorError::DimMismatch("add".to_string()));
|
||||
}
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
@@ -172,7 +205,7 @@ pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
Ok(output)
|
||||
}
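
The two-element check at the top of `add` is a broadcast shortcut: when the second tensor has dims `[1]` it is treated as a scalar and dispatched to `const_add`, so the shapes need not match. A usage sketch in the style of the doc-tests (values chosen to mirror the `const_add` example below):

    let x = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 1]), &[2, 3]).unwrap();
    // a singleton tensor of dims [1] takes the const_add fast path
    let k = Tensor::<i32>::new(Some(&[2]), &[1]).unwrap();
    let result = add(&vec![x, k]).unwrap();
    let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
    assert_eq!(result, expected);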

/// Elementwise adds a const element to a tensor.
@@ -189,11 +222,14 @@ pub fn add<T: TensorType + Add<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
///     &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = const_add(&x, k);
/// let result = const_add(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn const_add<T: TensorType + Add<Output = T>>(a: &Tensor<T>, b: T) -> Tensor<T> {
pub fn const_add<T: TensorType + Add<Output = T>>(
    a: &Tensor<T>,
    b: T,
) -> Result<Tensor<T>, TensorError> {
    // calculate value of output
    let mut output: Tensor<T> = a.clone();

@@ -201,7 +237,7 @@ pub fn const_add<T: TensorType + Add<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
        output[i] = output[i].clone() + b.clone();
    }

    output
    Ok(output)
}

/// Subtracts multiple tensors.
@@ -221,18 +257,20 @@ pub fn const_add<T: TensorType + Add<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
///     Some(&[2, 3, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let result = sub(&vec![x, k]);
/// let result = sub(&vec![x, k]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Result<Tensor<T>, TensorError> {
    // determines if we're subtracting a 1D const
    if t.len() == 2 && t[1].dims().len() == 1 && t[1].dims()[0] == 1 {
        return const_sub(&t[0], t[1][0].clone());
    }

    for e in t.iter() {
        assert_eq!(t[0].dims(), e.dims());
        if t[0].dims() != e.dims() {
            return Err(TensorError::DimMismatch("sub".to_string()));
        }
    }
    // calculate value of output
    let mut output: Tensor<T> = t[0].clone();
@@ -243,7 +281,7 @@ pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
        }
    }

    output
    Ok(output)
}

/// Elementwise subtracts a const element from a tensor.
@@ -260,11 +298,14 @@ pub fn sub<T: TensorType + Sub<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
///     &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = const_sub(&x, k);
/// let result = const_sub(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn const_sub<T: TensorType + Sub<Output = T>>(a: &Tensor<T>, b: T) -> Tensor<T> {
pub fn const_sub<T: TensorType + Sub<Output = T>>(
    a: &Tensor<T>,
    b: T,
) -> Result<Tensor<T>, TensorError> {
    // calculate value of output
    let mut output: Tensor<T> = a.clone();

@@ -272,7 +313,7 @@ pub fn const_sub<T: TensorType + Sub<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
        output[i] = output[i].clone() - b.clone();
    }

    output
    Ok(output)
}

/// Elementwise multiplies two tensors.
@@ -292,18 +333,20 @@ pub fn const_sub<T: TensorType + Sub<Output = T>>(a: &Tensor<T>, b: T) -> Tensor
///     Some(&[2, 3, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let result = mult(&vec![x, k]);
/// let result = mult(&vec![x, k]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Result<Tensor<T>, TensorError> {
    // determines if we're multiplying by a 1D const
    if t.len() == 2 && t[1].dims().len() == 1 && t[1].dims()[0] == 1 {
        return const_mult(&t[0], t[1][0].clone());
    }

    for e in t.iter() {
        assert_eq!(t[0].dims(), e.dims());
        if t[0].dims() != e.dims() {
            return Err(TensorError::DimMismatch("mult".to_string()));
        }
    }
    // calculate value of output
    let mut output: Tensor<T> = t[0].clone();
@@ -314,7 +357,7 @@ pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
        }
    }

    output
    Ok(output)
}

/// Elementwise divides a tensor by another tensor.
@@ -334,19 +377,24 @@ pub fn mult<T: TensorType + Mul<Output = T>>(t: &Vec<Tensor<T>>) -> Tensor<T> {
///     Some(&[2, 1, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let result = div(x, y);
/// let result = div(x, y).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn div<T: TensorType + Div<Output = T>>(t: Tensor<T>, d: Tensor<T>) -> Tensor<T> {
    assert_eq!(t.dims(), d.dims());
pub fn div<T: TensorType + Div<Output = T>>(
    t: Tensor<T>,
    d: Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
    if t.dims() != d.dims() {
        return Err(TensorError::DimMismatch("div".to_string()));
    }
    // calculate value of output
    let mut output: Tensor<T> = t;

    for (i, d_i) in d.iter().enumerate() {
        output[i] = output[i].clone() / d_i.clone()
    }
    output
    Ok(output)
}
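
Because `div` now returns a `Result`, a shape mismatch is an ordinary recoverable value rather than a panic. A hedged sketch of a caller that matches on the error (assuming `Tensor` implements `Debug`, as the doc-tests' `assert_eq!` already requires):

    let x = Tensor::<i32>::new(Some(&[4, 1, 4, 1, 1, 4]), &[2, 3]).unwrap();
    let y = Tensor::<i32>::new(Some(&[2, 1, 2, 1]), &[2, 2]).unwrap();
    match div(x, y) {
        Ok(q) => println!("quotient: {:?}", q),
        // dims [2, 3] vs [2, 2] now trips the check instead of aborting
        Err(TensorError::DimMismatch(op)) => eprintln!("shape error in {}", op),
        Err(e) => eprintln!("other tensor error: {:?}", e),
    }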

/// Elementwise multiplies a tensor by a const element.
@@ -363,11 +411,14 @@ pub fn div<T: TensorType + Div<Output = T>>(t: Tensor<T>, d: Tensor<T>) -> Tenso
///     &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = const_mult(&x, k);
/// let result = const_mult(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn const_mult<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, b: T) -> Tensor<T> {
pub fn const_mult<T: TensorType + Mul<Output = T>>(
    a: &Tensor<T>,
    b: T,
) -> Result<Tensor<T>, TensorError> {
    // calculate value of output
    let mut output: Tensor<T> = a.clone();

@@ -375,7 +426,7 @@ pub fn const_mult<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, b: T) -> Tenso
        output[i] = output[i].clone() * b.clone();
    }

    output
    Ok(output)
}

/// Rescales a tensor by a const integer (similar to const_mult).
@@ -392,11 +443,14 @@ pub fn const_mult<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, b: T) -> Tenso
///     &[2, 3],
/// ).unwrap();
/// let k = 2;
/// let result = rescale(&x, k);
/// let result = rescale(&x, k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn rescale<T: TensorType + Add<Output = T>>(a: &Tensor<T>, mult: usize) -> Tensor<T> {
pub fn rescale<T: TensorType + Add<Output = T>>(
    a: &Tensor<T>,
    mult: usize,
) -> Result<Tensor<T>, TensorError> {
    // calculate value of output
    let mut output: Tensor<T> = a.clone();
    for (i, a_i) in a.iter().enumerate() {
@@ -404,7 +458,7 @@ pub fn rescale<T: TensorType + Add<Output = T>>(a: &Tensor<T>, mult: usize) -> T
            output[i] = output[i].clone() + a_i.clone();
        }
    }
    output
    Ok(output)
}
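
Note that `rescale` multiplies by `mult` using only an `Add` bound: the elided inner loop adds `a_i` onto the accumulator, so scaling by an integer never needs a `Mul` implementation. A self-contained sketch of the same trick on a plain integer (the loop bound is an assumption consistent with the doc-test, where `mult = 2` doubles each entry):

    // multiply-by-repeated-addition, mirroring rescale's Add-only bound
    fn rescale_scalar(a: i64, mult: usize) -> i64 {
        let mut out = a;
        for _ in 1..mult {
            out += a;
        }
        out
    }

    assert_eq!(rescale_scalar(2, 2), 4);
    assert_eq!(rescale_scalar(2, 3), 6);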

/// Elementwise raises a tensor to the nth power.
@@ -420,11 +474,14 @@ pub fn rescale<T: TensorType + Add<Output = T>>(a: &Tensor<T>, mult: usize) -> T
///     Some(&[2, 15, 2, 1, 1, 0]),
///     &[2, 3],
/// ).unwrap();
/// let result = pow(&x, 3);
/// let result = pow(&x, 3).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn pow<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, pow: usize) -> Tensor<T> {
pub fn pow<T: TensorType + Mul<Output = T>>(
    a: &Tensor<T>,
    pow: usize,
) -> Result<Tensor<T>, TensorError> {
    // calculate value of output
    let mut output: Tensor<T> = a.clone();
    for (i, a_i) in a.iter().enumerate() {
@@ -432,7 +489,7 @@ pub fn pow<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, pow: usize) -> Tensor
            output[i] = output[i].clone() * a_i.clone();
        }
    }
    output
    Ok(output)
}

/// Sums a tensor.
@@ -448,15 +505,15 @@ pub fn pow<T: TensorType + Mul<Output = T>>(a: &Tensor<T>, pow: usize) -> Tensor
///     Some(&[2, 15, 2, 1, 1, 0]),
///     &[2, 3],
/// ).unwrap();
/// let result = sum(&x);
/// let result = sum(&x).unwrap();
/// let expected = 21;
/// assert_eq!(result[0], expected);
/// ```
pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Tensor<T> {
pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
    // calculate value of output
    let mut res = T::zero().unwrap();
    let _ = a.map(|a_i| res = res.clone() + a_i);
    Tensor::new(Some(&[res]), &[1]).unwrap()
    Tensor::new(Some(&[res]), &[1])
}

/// Applies convolution over a 3D tensor of shape C x H x W (and adds a bias).
@@ -482,7 +539,7 @@ pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Tensor<T> {
///     Some(&[0]),
///     &[1],
/// ).unwrap();
/// let result = convolution::<i32>(&vec![x, k, b], (0, 0), (1, 1));
/// let result = convolution::<i32>(&vec![x, k, b], (0, 0), (1, 1)).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[31, 16, 8, 26]), &[1, 2, 2]).unwrap();
/// assert_eq!(result, expected);
/// ```
@@ -490,17 +547,22 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    inputs: &Vec<Tensor<T>>,
    padding: (usize, usize),
    stride: (usize, usize),
) -> Tensor<T> {
) -> Result<Tensor<T>, TensorError> {
    let has_bias = inputs.len() == 3;
    let (image, kernel) = (inputs[0].clone(), inputs[1].clone());

    assert_eq!(image.dims().len(), 3);
    assert_eq!(kernel.dims().len(), 4);
    assert_eq!(image.dims()[0], kernel.dims()[1]);
    if (image.dims().len() != 3)
        || (kernel.dims().len() != 4)
        || (image.dims()[0] != kernel.dims()[1])
    {
        return Err(TensorError::DimMismatch("conv".to_string()));
    }

    if has_bias {
        let bias = inputs[2].clone();
        assert_eq!(bias.dims().len(), 1);
        assert_eq!(bias.dims()[0], kernel.dims()[0]);
        if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) {
            return Err(TensorError::DimMismatch("conv bias".to_string()));
        }
    }

    let image_dims = image.dims();
@@ -515,7 +577,7 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(

    let (image_height, image_width) = (image_dims[1], image_dims[2]);

    let padded_image = pad::<T>(image.clone(), padding);
    let padded_image = pad::<T>(image.clone(), padding)?;

    let vert_slides = (image_height + 2 * padding.0 - kernel_height) / stride.0 + 1;
    let horz_slides = (image_width + 2 * padding.1 - kernel_width) / stride.1 + 1;
@@ -530,13 +592,13 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
            for k in 0..horz_slides {
                let cs = k * stride.1;
                let mut res = dot(&vec![
                    &kernel.get_slice(&[i..i + 1]).clone(),
                    &kernel.get_slice(&[i..i + 1])?.clone(),
                    &padded_image.get_slice(&[
                        0..input_channels,
                        rs..(rs + kernel_height),
                        cs..(cs + kernel_width),
                    ]),
                ]);
                    ])?,
                ])?;

                if has_bias {
                    // increment result by the bias
@@ -547,7 +609,7 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
            }
        }
    }
    output
    Ok(output)
}
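
The `vert_slides` / `horz_slides` expressions above encode the standard convolution output geometry. A small self-contained sketch of that shape arithmetic (`conv_out_dims` is a hypothetical helper, not part of this diff):

    /// Output (height, width) of a 2D convolution, mirroring the
    /// vert_slides / horz_slides expressions in the hunk above.
    fn conv_out_dims(
        (h, w): (usize, usize),
        (kh, kw): (usize, usize),
        padding: (usize, usize),
        stride: (usize, usize),
    ) -> (usize, usize) {
        (
            (h + 2 * padding.0 - kh) / stride.0 + 1,
            (w + 2 * padding.1 - kw) / stride.1 + 1,
        )
    }

    // e.g. a 3x3 image, 2x2 kernel, no padding, unit stride -> 2x2 output,
    // which is the [1, 2, 2] shape the doc-test expects
    assert_eq!(conv_out_dims((3, 3), (2, 2), (0, 0), (1, 1)), (2, 2));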

/// Applies 2D sum pooling over a 3D tensor of shape C x H x W.
@@ -569,7 +631,7 @@ pub fn convolution<T: TensorType + Mul<Output = T> + Add<Output = T>>(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3],
/// ).unwrap();
/// let pooled = sumpool::<i32>(&x, (0, 0), (1, 1), (2, 2));
/// let pooled = sumpool::<i32>(&x, (0, 0), (1, 1), (2, 2)).unwrap();
/// let expected: Tensor<i32> = Tensor::<i32>::new(Some(&[11, 8, 8, 10]), &[1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -578,8 +640,10 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    padding: (usize, usize),
    stride: (usize, usize),
    kernel_shape: (usize, usize),
) -> Tensor<T> {
    assert_eq!(image.dims().len(), 3);
) -> Result<Tensor<T>, TensorError> {
    if image.dims().len() != 3 {
        return Err(TensorError::DimMismatch("sumpool".to_string()));
    }
    let image_dims = image.dims();

    let (image_channels, image_height, image_width) = (image_dims[0], image_dims[1], image_dims[2]);
@@ -587,7 +651,7 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    let (output_channels, kernel_height, kernel_width) =
        (image_channels, kernel_shape.0, kernel_shape.1);

    let padded_image = pad::<T>(image.clone(), padding);
    let padded_image = pad::<T>(image.clone(), padding)?;

    let vert_slides = (image_height + 2 * padding.0 - kernel_height) / stride.0 + 1;
    let horz_slides = (image_width + 2 * padding.1 - kernel_width) / stride.1 + 1;
@@ -605,12 +669,12 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
                    i..i + 1,
                    rs..(rs + kernel_height),
                    cs..(cs + kernel_width),
                ]));
                ])?)?;
                output.set(&[i, j, k], thesum[0].clone());
            }
        }
    }
    output
    Ok(output)
}

/// Applies 2D max pooling over a 3D tensor of shape C x H x W.
@@ -632,7 +696,7 @@ pub fn sumpool<T: TensorType + Mul<Output = T> + Add<Output = T>>(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3],
/// ).unwrap();
/// let pooled = max_pool2d::<i32>(&x, (0, 0), (1, 1), (2, 2));
/// let pooled = max_pool2d::<i32>(&x, (0, 0), (1, 1), (2, 2)).unwrap();
/// let expected: Tensor<i32> = Tensor::<i32>::new(Some(&[5, 4, 4, 6]), &[1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -641,14 +705,16 @@ pub fn max_pool2d<T: TensorType>(
    padding: (usize, usize),
    stride: (usize, usize),
    pool_dims: (usize, usize),
) -> Tensor<T> {
) -> Result<Tensor<T>, TensorError> {
    if image.dims().len() != 3 {
        return Err(TensorError::DimMismatch("max_pool2d".to_string()));
    }
    let image_dims = image.dims();
    assert_eq!(image_dims.len(), 3);

    let input_channels = image_dims[0];
    let (image_height, image_width) = (image_dims[1], image_dims[2]);

    let padded_image = pad::<T>(image.clone(), padding);
    let padded_image = pad::<T>(image.clone(), padding)?;

    let horz_slides = (image_height + 2 * padding.0 - pool_dims.0) / stride.0 + 1;
    let vert_slides = (image_width + 2 * padding.1 - pool_dims.1) / stride.1 + 1;
@@ -671,7 +737,7 @@ pub fn max_pool2d<T: TensorType>(
                output.set(
                    &[i, j, k],
                    padded_image
                        .get_slice(&[i..(i + 1), rs..(rs + pool_dims.0), cs..(cs + pool_dims.1)])
                        .get_slice(&[i..(i + 1), rs..(rs + pool_dims.0), cs..(cs + pool_dims.1)])?
                        .into_iter()
                        .fold(None, fmax)
                        .unwrap(),
@@ -679,7 +745,7 @@ pub fn max_pool2d<T: TensorType>(
            }
        }
    }
    output
    Ok(output)
}
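
The `.fold(None, fmax)` call above computes a running maximum with an `Option` accumulator: `None` means "no elements seen yet", so the trailing `.unwrap()` is safe for any non-empty pooling window. A self-contained sketch of that pattern, with `fmax` reconstructed as an assumption (its real definition is not part of this diff):

    // plausible shape of fmax: a fold step that keeps the larger element
    fn fmax(acc: Option<i32>, x: i32) -> Option<i32> {
        match acc {
            None => Some(x),
            Some(m) => Some(m.max(x)),
        }
    }

    let window = vec![5, 2, 3, 4];
    assert_eq!(window.into_iter().fold(None, fmax), Some(5));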

/// Dot product of two tensors.
@@ -699,19 +765,20 @@ pub fn max_pool2d<T: TensorType>(
///     Some(&[5, 5, 10, -4, 2, -1, 2, 0, 1]),
///     &[1, 3, 3],
/// ).unwrap();
/// assert_eq!(dot(&vec![&x, &y])[0], 86);
/// assert_eq!(dot(&vec![&x, &y]).unwrap()[0], 86);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
    inputs: &Vec<&Tensor<T>>,
) -> Tensor<T> {
    assert_eq!(inputs.len(), 2);
    assert_eq!(inputs[0].clone().len(), inputs[1].clone().len());
) -> Result<Tensor<T>, TensorError> {
    if (inputs.len() != 2) || (inputs[0].clone().len() != inputs[1].clone().len()) {
        return Err(TensorError::DimMismatch("dot".to_string()));
    }
    let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
    let res = a
        .iter()
        .zip(b)
        .fold(T::zero().unwrap(), |acc, (k, i)| acc + k.clone() * i);
    Tensor::new(Some(&[res]), &[1]).unwrap()
    Tensor::new(Some(&[res]), &[1])
}

/// Pads a 3D tensor of shape `C x H x W` to a tensor of shape `C x (H + 2xPADDING) x (W + 2xPADDING)` using 0 values.
@@ -728,15 +795,20 @@ pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3],
/// ).unwrap();
/// let result = pad::<i32>(x, (1, 1));
/// let result = pad::<i32>(x, (1, 1)).unwrap();
/// let expected = Tensor::<i32>::new(
///     Some(&[0, 0, 0, 0, 0, 0, 5, 2, 3, 0, 0, 0, 4, -1, 0, 0, 3, 1, 6, 0, 0, 0, 0, 0, 0]),
///     &[1, 5, 5],
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn pad<T: TensorType>(image: Tensor<T>, padding: (usize, usize)) -> Tensor<T> {
    assert_eq!(image.dims().len(), 3);
pub fn pad<T: TensorType>(
    image: Tensor<T>,
    padding: (usize, usize),
) -> Result<Tensor<T>, TensorError> {
    if image.dims().len() != 3 {
        return Err(TensorError::DimMismatch("pad".to_string()));
    }
    let (channels, height, width) = (image.dims()[0], image.dims()[1], image.dims()[2]);
    let padded_height = height + 2 * padding.0;
    let padded_width = width + 2 * padding.1;
@@ -755,7 +827,7 @@ pub fn pad<T: TensorType>(image: Tensor<T>, padding: (usize, usize)) -> Tensor<T
    }

    output.reshape(&[channels, padded_height, padded_width]);
    output
    Ok(output)
}

// ---------------------------------------------------------------------------------------------------------

@@ -73,35 +73,36 @@ impl<F: FieldExt + TensorType> ValTensor<F> {
    }

    /// Calls `get_slice` on the inner tensor.
    pub fn get_slice(&self, indices: &[Range<usize>]) -> ValTensor<F> {
        match self {
    pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, Box<dyn Error>> {
        let slice = match self {
            ValTensor::Value { inner: v, dims: _ } => {
                let slice = v.get_slice(indices);
                let slice = v.get_slice(indices)?;
                ValTensor::Value {
                    inner: slice.clone(),
                    dims: slice.dims().to_vec(),
                }
            }
            ValTensor::AssignedValue { inner: v, dims: _ } => {
                let slice = v.get_slice(indices);
                let slice = v.get_slice(indices)?;
                ValTensor::AssignedValue {
                    inner: slice.clone(),
                    dims: slice.dims().to_vec(),
                }
            }
            ValTensor::PrevAssigned { inner: v, dims: _ } => {
                let slice = v.get_slice(indices);
                let slice = v.get_slice(indices)?;
                ValTensor::PrevAssigned {
                    inner: slice.clone(),
                    dims: slice.dims().to_vec(),
                }
            }
            _ => unimplemented!(),
        }
            _ => return Err(Box::new(TensorError::WrongMethod)),
        };
        Ok(slice)
    }
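
The shape of this refactor recurs below: bind the `match` to a local, turn the old `unimplemented!()` arm into an early `return Err(...)` with a typed error, and finish with a single `Ok(...)`. A standalone sketch of the same pattern on a toy enum (everything here is illustrative, not crate code):

    use std::error::Error;
    use std::fmt;

    #[derive(Debug)]
    struct WrongMethod;

    impl fmt::Display for WrongMethod {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "wrong method called for this variant")
        }
    }
    impl Error for WrongMethod {}

    enum Val {
        Value(i32),
        Unsupported,
    }

    fn get_inner(v: &Val) -> Result<i32, Box<dyn Error>> {
        // bind the match result; the unsupported arm becomes an early Err
        let inner = match v {
            Val::Value(x) => *x,
            Val::Unsupported => return Err(Box::new(WrongMethod)),
        };
        Ok(inner)
    }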

    /// Sets the [ValTensor]'s shape.
    pub fn reshape(&mut self, new_dims: &[usize]) {
    pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box<dyn Error>> {
        match self {
            ValTensor::Value { inner: v, dims: d } => {
                v.reshape(new_dims);
@@ -116,13 +117,13 @@ impl<F: FieldExt + TensorType> ValTensor<F> {
                *d = v.dims().to_vec();
            }
            ValTensor::Instance { dims: d, .. } => {
                assert_eq!(
                    d.iter().product::<usize>(),
                    new_dims.iter().product::<usize>()
                );
                if d.iter().product::<usize>() != new_dims.iter().product::<usize>() {
                    return Err(Box::new(TensorError::DimError));
                }
                *d = new_dims.to_vec();
            }
        }
        };
        Ok(())
    }

    /// Calls `flatten` on the inner [Tensor].

@@ -1,6 +1,4 @@
use super::*;
use crate::abort;
use log::error;
use std::cmp::min;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
/// The wrapper allows for `VarTensor`'s dimensions to differ from that of the inner (wrapped) columns.
@@ -184,7 +182,7 @@ impl VarTensor {
        &self,
        meta: &mut VirtualCells<'_, F>,
        offset: usize,
    ) -> Result<Tensor<Expression<F>>, TensorError> {
    ) -> Result<Tensor<Expression<F>>, halo2_proofs::plonk::Error> {
        match &self {
            VarTensor::Fixed {
                inner: fixed, dims, ..
@@ -224,87 +222,57 @@ impl VarTensor {
        region: &mut Region<'_, F>,
        offset: usize,
        values: &ValTensor<F>,
    ) -> Result<Tensor<AssignedCell<F, F>>, TensorError> {
    ) -> Result<Tensor<AssignedCell<F, F>>, halo2_proofs::plonk::Error> {
        match values {
            ValTensor::Instance {
                inner: instance, ..
            } => match &self {
                VarTensor::Advice { inner: v, dims, .. } => {
                    let t = Tensor::new(None, dims).unwrap();
                    t.enum_map(|coord, _: usize| {
                    // this should never ever fail
                    let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
                    t.enum_map(|coord, _| {
                        let (x, y) = self.cartesian_coord(offset + coord);

                        match region.assign_advice_from_instance(
                        region.assign_advice_from_instance(
                            || "pub input anchor",
                            *instance,
                            coord,
                            v[x],
                            y,
                        ) {
                            Ok(v) => v,
                            Err(e) => {
                                abort!("failed to assign advice from instance {:?}", e);
                            }
                        }
                        )
                    })
                }
                _ => {
                    abort!("should be an advice");
                }
                _ => Err(halo2_proofs::plonk::Error::Synthesis),
            },
            ValTensor::Value { inner: v, dims: _ } => v.enum_map(|coord, k| match &self {
            ValTensor::Value { inner: v, .. } => v.enum_map(|coord, k| match &self {
                VarTensor::Fixed { inner: fixed, .. } => {
                    let (x, y) = self.cartesian_coord(offset + coord);
                    match region.assign_fixed(|| "k", fixed[x], y, || k) {
                        Ok(a) => a,
                        Err(e) => {
                            abort!("failed to assign ValTensor to VarTensor {:?}", e);
                        }
                    }
                    region.assign_fixed(|| "k", fixed[x], y, || k)
                }
                VarTensor::Advice { inner: advices, .. } => {
                    let (x, y) = self.cartesian_coord(offset + coord);
                    match region.assign_advice(|| "k", advices[x], y, || k) {
                        Ok(a) => a,
                        Err(e) => {
                            abort!("failed to assign ValTensor to VarTensor {:?}", e);
                        }
                    }
                    region.assign_advice(|| "k", advices[x], y, || k)
                }
            }),
            ValTensor::PrevAssigned { inner: v, dims: _ } => {
                v.enum_map(|coord, xcell| match &self {
                    VarTensor::Advice { inner: advices, .. } => {
                        let (x, y) = self.cartesian_coord(offset + coord);
                        match xcell.copy_advice(|| "k", region, advices[x], y) {
                            Ok(a) => a,
                            Err(e) => {
                                abort!("failed to copy ValTensor to VarTensor {:?}", e);
                            }
                        }
                    }
                    _ => {
                        unimplemented!()
                    }
                })
            }
            ValTensor::AssignedValue { inner: v, dims: _ } => v.enum_map(|coord, k| match &self {
            ValTensor::PrevAssigned { inner: v, .. } => v.enum_map(|coord, xcell| match &self {
                VarTensor::Advice { inner: advices, .. } => {
                    let (x, y) = self.cartesian_coord(offset + coord);
                    xcell.copy_advice(|| "k", region, advices[x], y)
                }
                _ => Err(halo2_proofs::plonk::Error::Synthesis),
            }),
            ValTensor::AssignedValue { inner: v, .. } => v.enum_map(|coord, k| match &self {
                VarTensor::Fixed { inner: fixed, .. } => {
                    let (x, y) = self.cartesian_coord(offset + coord);
                    match region.assign_fixed(|| "k", fixed[x], y, || k) {
                        Ok(a) => a.evaluate(),
                        Err(e) => {
                            abort!("failed to assign ValTensor to VarTensor {:?}", e);
                        }
                        Ok(a) => Ok(a.evaluate()),
                        Err(e) => Err(e),
                    }
                }
                VarTensor::Advice { inner: advices, .. } => {
                    let (x, y) = self.cartesian_coord(offset + coord);
                    match region.assign_advice(|| "k", advices[x], y, || k) {
                        Ok(a) => a.evaluate(),
                        Err(e) => {
                            abort!("failed to assign ValTensor to VarTensor {:?}", e);
                        }
                        Ok(a) => Ok(a.evaluate()),
                        Err(e) => Err(e),
                    }
                }
            }),
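
This hunk is the heart of the commit: every `match ... { Ok(v) => v, Err(e) => abort!(...) }` either returns the fallible call directly or maps to an explicit `Err`, so synthesis failures reach the caller as `halo2_proofs::plonk::Error` instead of killing the process. The closures passed to `enum_map` now yield `Result`s that the map collects, stopping at the first failure. A self-contained sketch of that collection behavior on a plain `Vec` (assuming `enum_map` short-circuits the way `Iterator::collect` into `Result` does):

    // the closure returns Result; collect() gives Ok(all) or the first Err
    fn assign_all(values: &[i32]) -> Result<Vec<i32>, String> {
        values
            .iter()
            .map(|v| {
                if *v >= 0 {
                    Ok(v * 2)
                } else {
                    // previously: abort!(...); now the error bubbles up
                    Err(format!("failed to assign {}", v))
                }
            })
            .collect()
    }

    assert_eq!(assign_all(&[1, 2]), Ok(vec![2, 4]));
    assert!(assign_all(&[1, -2]).is_err());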

@@ -3,7 +3,8 @@ use std::env::var;
use std::process::Command;

lazy_static! {
    static ref CARGO_TARGET_DIR: String = var("CARGO_TARGET_DIR").unwrap_or("./target".to_string());
    static ref CARGO_TARGET_DIR: String =
        var("CARGO_TARGET_DIR").unwrap_or_else(|_| "./target".to_string());
}
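
The `unwrap_or` to `unwrap_or_else` change follows the standard clippy advice (`or_fun_call`): `unwrap_or` builds its fallback eagerly, allocating the `String` even when the variable is set, while `unwrap_or_else` defers the allocation to the error path. In isolation:

    use std::env::var;

    // eager: "./target".to_string() is allocated even if the lookup succeeds
    let eager = var("CARGO_TARGET_DIR").unwrap_or("./target".to_string());

    // lazy: the closure only runs (and only allocates) when the lookup fails
    let lazy = var("CARGO_TARGET_DIR").unwrap_or_else(|_| "./target".to_string());

    assert_eq!(eager, lazy);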

#[cfg(test)]
@@ -183,12 +184,7 @@ fn neg_mock(example_name: String, counter_example: String) {
// Mock prove (fast, but does not cover some potential issues)
fn run_example(example_name: String) {
    let status = Command::new("cargo")
        .args([
            "run",
            "--release",
            "--example",
            format!("{}", example_name).as_str(),
        ])
        .args(["run", "--release", "--example", example_name.as_str()])
        .status()
        .expect("failed to execute process");
    assert!(status.success());
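
The last hunk also picks up a clippy `useless_format` cleanup: `format!("{}", example_name)` on a value that is already a `String` just copies it, so the argument list can call `example_name.as_str()` directly. As a sketch (the example name here is a placeholder):

    let example_name = String::from("some_example");
    // identical values; the format! round-trip is a needless allocation
    assert_eq!(format!("{}", example_name).as_str(), example_name.as_str());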