mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c97ff84198 |
@@ -482,7 +482,7 @@
|
||||
"source": [
|
||||
"import pytest\n",
|
||||
"def test_verification():\n",
|
||||
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
|
||||
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
|
||||
" ezkl.verify(\n",
|
||||
" proof_path_faulty,\n",
|
||||
" settings_path,\n",
|
||||
@@ -514,9 +514,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -478,12 +478,11 @@
|
||||
"import pytest\n",
|
||||
"\n",
|
||||
"def test_verification():\n",
|
||||
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
|
||||
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
|
||||
" ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Run the test function\n",
|
||||
@@ -510,9 +509,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,11 +41,18 @@ pub async fn main() {
|
||||
);
|
||||
let res = run(command).await;
|
||||
match &res {
|
||||
Ok(_) => info!("succeeded"),
|
||||
Err(e) => error!("failed: {}", e),
|
||||
};
|
||||
Ok(_) => {
|
||||
info!("succeeded");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("{}", e);
|
||||
std::process::exit(1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
error!("no command provided");
|
||||
init_logger();
|
||||
error!("No command provided");
|
||||
std::process::exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
25
src/circuit/modules/errors.rs
Normal file
25
src/circuit/modules/errors.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use halo2_proofs::plonk::Error as PlonkError;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Error type for the circuit module
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ModuleError {
|
||||
/// Halo 2 error
|
||||
#[error("[halo2] {0}")]
|
||||
Halo2Error(#[from] PlonkError),
|
||||
/// Wrong input type for a module
|
||||
#[error("wrong input type {0} must be {1}")]
|
||||
WrongInputType(String, String),
|
||||
/// A constant was not previously assigned
|
||||
#[error("constant was not previously assigned")]
|
||||
ConstantNotAssigned,
|
||||
/// Input length is wrong
|
||||
#[error("input length is wrong {0}")]
|
||||
InputWrongLength(usize),
|
||||
}
|
||||
|
||||
impl From<ModuleError> for PlonkError {
|
||||
fn from(_e: ModuleError) -> PlonkError {
|
||||
PlonkError::Synthesis
|
||||
}
|
||||
}
|
||||
@@ -6,10 +6,11 @@ pub mod polycommit;
|
||||
|
||||
///
|
||||
pub mod planner;
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{ConstraintSystem, Error},
|
||||
};
|
||||
|
||||
///
|
||||
pub mod errors;
|
||||
|
||||
use halo2_proofs::{circuit::Layouter, plonk::ConstraintSystem};
|
||||
use halo2curves::ff::PrimeField;
|
||||
pub use planner::*;
|
||||
|
||||
@@ -35,14 +36,14 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Name
|
||||
fn name(&self) -> &'static str;
|
||||
/// Run the operation the module represents
|
||||
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, Box<dyn std::error::Error>>;
|
||||
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, errors::ModuleError>;
|
||||
/// Layout inputs
|
||||
fn layout_inputs(
|
||||
&self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
input: &[ValTensor<F>],
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<Self::InputAssignments, Error>;
|
||||
) -> Result<Self::InputAssignments, errors::ModuleError>;
|
||||
/// Layout
|
||||
fn layout(
|
||||
&self,
|
||||
@@ -50,7 +51,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
input: &[ValTensor<F>],
|
||||
row_offset: usize,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValTensor<F>, Error>;
|
||||
) -> Result<ValTensor<F>, errors::ModuleError>;
|
||||
/// Number of instance values the module uses every time it is applied
|
||||
fn instance_increment_input(&self) -> Vec<usize>;
|
||||
/// Number of rows used by the module
|
||||
|
||||
@@ -18,6 +18,7 @@ use halo2curves::CurveAffine;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
|
||||
|
||||
use super::errors::ModuleError;
|
||||
use super::Module;
|
||||
|
||||
/// The number of instance columns used by the PolyCommit hash function
|
||||
@@ -110,7 +111,7 @@ impl Module<Fp> for PolyCommitChip {
|
||||
_: &mut impl Layouter<Fp>,
|
||||
_: &[ValTensor<Fp>],
|
||||
_: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
) -> Result<Self::InputAssignments, ModuleError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -123,28 +124,30 @@ impl Module<Fp> for PolyCommitChip {
|
||||
input: &[ValTensor<Fp>],
|
||||
_: usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, Error> {
|
||||
) -> Result<ValTensor<Fp>, ModuleError> {
|
||||
assert_eq!(input.len(), 1);
|
||||
|
||||
let local_constants = constants.clone();
|
||||
layouter.assign_region(
|
||||
|| "PolyCommit",
|
||||
|mut region| {
|
||||
let mut local_inner_constants = local_constants.clone();
|
||||
let res = self.config.inputs.assign(
|
||||
&mut region,
|
||||
0,
|
||||
&input[0],
|
||||
&mut local_inner_constants,
|
||||
)?;
|
||||
*constants = local_inner_constants;
|
||||
Ok(res)
|
||||
},
|
||||
)
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "PolyCommit",
|
||||
|mut region| {
|
||||
let mut local_inner_constants = local_constants.clone();
|
||||
let res = self.config.inputs.assign(
|
||||
&mut region,
|
||||
0,
|
||||
&input[0],
|
||||
&mut local_inner_constants,
|
||||
)?;
|
||||
*constants = local_inner_constants;
|
||||
Ok(res)
|
||||
},
|
||||
)
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
///
|
||||
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
|
||||
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
|
||||
Ok(vec![message])
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ use std::marker::PhantomData;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType};
|
||||
|
||||
use super::errors::ModuleError;
|
||||
use super::Module;
|
||||
|
||||
/// The number of instance columns used by the Poseidon hash function
|
||||
@@ -174,7 +175,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
) -> Result<Self::InputAssignments, ModuleError> {
|
||||
assert_eq!(message.len(), 1);
|
||||
let message = message[0].clone();
|
||||
|
||||
@@ -185,78 +186,82 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let res = layouter.assign_region(
|
||||
|| "load message",
|
||||
|mut region| {
|
||||
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, Error> = match &message {
|
||||
ValTensor::Value { inner: v, .. } => v
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, value)| {
|
||||
let x = i % WIDTH;
|
||||
let y = i / WIDTH;
|
||||
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
|
||||
match &message {
|
||||
ValTensor::Value { inner: v, .. } => {
|
||||
v.iter()
|
||||
.enumerate()
|
||||
.map(|(i, value)| {
|
||||
let x = i % WIDTH;
|
||||
let y = i / WIDTH;
|
||||
|
||||
match value {
|
||||
ValType::Value(v) => region.assign_advice(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
|| *v,
|
||||
),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
Ok(v.clone())
|
||||
}
|
||||
ValType::Constant(f) => {
|
||||
if local_constants.contains_key(f) {
|
||||
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
|
||||
log::error!("constant not previously assigned");
|
||||
Error::Synthesis
|
||||
})?)
|
||||
} else {
|
||||
let res = region.assign_advice_from_constant(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
*f,
|
||||
)?;
|
||||
match value {
|
||||
ValType::Value(v) => region
|
||||
.assign_advice(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
|| *v,
|
||||
)
|
||||
.map_err(|e| e.into()),
|
||||
ValType::PrevAssigned(v)
|
||||
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
|
||||
ValType::Constant(f) => {
|
||||
if local_constants.contains_key(f) {
|
||||
Ok(constants
|
||||
.get(f)
|
||||
.unwrap()
|
||||
.assigned_cell()
|
||||
.ok_or(ModuleError::ConstantNotAssigned)?)
|
||||
} else {
|
||||
let res = region.assign_advice_from_constant(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
*f,
|
||||
)?;
|
||||
|
||||
constants
|
||||
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
|
||||
constants.insert(
|
||||
*f,
|
||||
ValType::AssignedConstant(res.clone(), *f),
|
||||
);
|
||||
|
||||
Ok(res)
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
e => Err(ModuleError::WrongInputType(
|
||||
format!("{:?}", e),
|
||||
"PrevAssigned".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
e => {
|
||||
log::error!(
|
||||
"wrong input type {:?}, must be previously assigned",
|
||||
e
|
||||
);
|
||||
Err(Error::Synthesis)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
ValTensor::Instance {
|
||||
dims,
|
||||
inner: col,
|
||||
idx,
|
||||
initial_offset,
|
||||
..
|
||||
} => {
|
||||
// this should never ever fail
|
||||
let num_elems = dims[*idx].iter().product::<usize>();
|
||||
(0..num_elems)
|
||||
.map(|i| {
|
||||
let x = i % WIDTH;
|
||||
let y = i / WIDTH;
|
||||
region.assign_advice_from_instance(
|
||||
|| "pub input anchor",
|
||||
*col,
|
||||
initial_offset + i,
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
ValTensor::Instance {
|
||||
dims,
|
||||
inner: col,
|
||||
idx,
|
||||
initial_offset,
|
||||
..
|
||||
} => {
|
||||
// this should never ever fail
|
||||
let num_elems = dims[*idx].iter().product::<usize>();
|
||||
(0..num_elems)
|
||||
.map(|i| {
|
||||
let x = i % WIDTH;
|
||||
let y = i / WIDTH;
|
||||
region.assign_advice_from_instance(
|
||||
|| "pub input anchor",
|
||||
*col,
|
||||
initial_offset + i,
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
};
|
||||
|
||||
let offset = message.len() / WIDTH + 1;
|
||||
|
||||
@@ -277,7 +282,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
message.len(),
|
||||
start_time.elapsed()
|
||||
);
|
||||
res
|
||||
res.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// L is the number of inputs to the hash function
|
||||
@@ -289,7 +294,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
input: &[ValTensor<Fp>],
|
||||
row_offset: usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, Error> {
|
||||
) -> Result<ValTensor<Fp>, ModuleError> {
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
|
||||
// extract the values from the input cells
|
||||
let mut assigned_input: Tensor<ValType<Fp>> =
|
||||
@@ -301,7 +306,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let mut one_iter = false;
|
||||
// do the Tree dance baby
|
||||
while input_cells.len() > 1 || !one_iter {
|
||||
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, Error> = input_cells
|
||||
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
|
||||
.chunks(L)
|
||||
.enumerate()
|
||||
.map(|(i, block)| {
|
||||
@@ -332,7 +337,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
|
||||
hash
|
||||
})
|
||||
.collect();
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| e.into());
|
||||
|
||||
log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
|
||||
one_iter = true;
|
||||
@@ -348,7 +354,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
ValType::PrevAssigned(v) => v,
|
||||
_ => {
|
||||
log::error!("wrong input type, must be previously assigned");
|
||||
return Err(Error::Synthesis);
|
||||
return Err(Error::Synthesis.into());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -380,7 +386,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
}
|
||||
|
||||
///
|
||||
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
|
||||
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
|
||||
let mut hash_inputs = message;
|
||||
|
||||
let len = hash_inputs.len();
|
||||
@@ -400,7 +406,11 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
block.extend(vec![Fp::ZERO; L - remainder].iter());
|
||||
}
|
||||
|
||||
let message = block.try_into().map_err(|_| Error::Synthesis)?;
|
||||
let block_len = block.len();
|
||||
|
||||
let message = block
|
||||
.try_into()
|
||||
.map_err(|_| ModuleError::InputWrongLength(block_len))?;
|
||||
|
||||
Ok(halo2_gadgets::poseidon::primitives::Hash::<
|
||||
_,
|
||||
@@ -411,7 +421,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
>::init()
|
||||
.hash(message))
|
||||
})
|
||||
.collect::<Result<Vec<_>, Error>>()?;
|
||||
.collect::<Result<Vec<_>, ModuleError>>()?;
|
||||
one_iter = true;
|
||||
hash_inputs = hashes;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector},
|
||||
@@ -26,31 +24,11 @@ use crate::{
|
||||
},
|
||||
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
|
||||
use std::{collections::BTreeMap, marker::PhantomData};
|
||||
|
||||
use super::{lookup::LookupOp, region::RegionCtx, Op};
|
||||
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
|
||||
use halo2curves::ff::{Field, PrimeField};
|
||||
|
||||
/// circuit related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CircuitError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("dimension mismatch in circuit construction for op: {0}")]
|
||||
DimMismatch(String),
|
||||
/// Error when instantiating lookup tables
|
||||
#[error("failed to instantiate lookup tables")]
|
||||
LookupInstantiation,
|
||||
/// A lookup table was was already assigned
|
||||
#[error("attempting to initialize an already instantiated lookup table")]
|
||||
TableAlreadyAssigned,
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported operation in graph")]
|
||||
UnsupportedOp,
|
||||
///
|
||||
#[error("invalid einsum expression")]
|
||||
InvalidEinsum,
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
|
||||
#[derive(
|
||||
@@ -513,18 +491,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
nl: &LookupOp,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
if !index.is_advice() {
|
||||
return Err("wrong input type for lookup index".into());
|
||||
return Err(CircuitError::WrongColumnType(index.name().to_string()));
|
||||
}
|
||||
if !input.is_advice() {
|
||||
return Err("wrong input type for lookup input".into());
|
||||
return Err(CircuitError::WrongColumnType(input.name().to_string()));
|
||||
}
|
||||
if !output.is_advice() {
|
||||
return Err("wrong input type for lookup output".into());
|
||||
return Err(CircuitError::WrongColumnType(output.name().to_string()));
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
@@ -654,19 +632,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
lookups: &[VarTensor; 3],
|
||||
tables: &[VarTensor; 3],
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
for l in lookups.iter() {
|
||||
if !l.is_advice() {
|
||||
return Err("wrong input type for dynamic lookup".into());
|
||||
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
for t in tables.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
return Err("wrong table type for dynamic lookup".into());
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -737,19 +715,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
inputs: &[VarTensor; 2],
|
||||
references: &[VarTensor; 2],
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
for l in inputs.iter() {
|
||||
if !l.is_advice() {
|
||||
return Err("wrong input type for dynamic lookup".into());
|
||||
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
for t in references.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
return Err("wrong table type for dynamic lookup".into());
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -822,12 +800,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
index: &VarTensor,
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
if !input.is_advice() {
|
||||
return Err("wrong input type for lookup input".into());
|
||||
return Err(CircuitError::WrongColumnType(input.name().to_string()));
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
@@ -918,7 +896,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
}
|
||||
|
||||
/// layout_tables must be called before layout.
|
||||
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
|
||||
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
|
||||
if !table.is_assigned {
|
||||
debug!(
|
||||
@@ -939,7 +917,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
pub fn layout_range_checks(
|
||||
&mut self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
) -> Result<(), CircuitError> {
|
||||
for range_check in self.range_checks.ranges.values_mut() {
|
||||
if !range_check.is_assigned {
|
||||
debug!("laying out range check for {:?}", range_check.range);
|
||||
@@ -959,7 +937,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
op: Box<dyn Op<F>>,
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
op.layout(self, region, values)
|
||||
}
|
||||
}
|
||||
|
||||
94
src/circuit/ops/errors.rs
Normal file
94
src/circuit/ops/errors.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use crate::tensor::TensorError;
|
||||
use halo2_proofs::plonk::Error as PlonkError;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Error type for the circuit module
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CircuitError {
|
||||
/// Halo 2 error
|
||||
#[error("[halo2] {0}")]
|
||||
Halo2Error(#[from] PlonkError),
|
||||
/// Tensor error
|
||||
#[error("[tensor] {0}")]
|
||||
TensorError(#[from] TensorError),
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("dimension mismatch in circuit construction for op: {0}")]
|
||||
DimMismatch(String),
|
||||
/// Error when instantiating lookup tables
|
||||
#[error("failed to instantiate lookup tables")]
|
||||
LookupInstantiation,
|
||||
/// A lookup table was was already assigned
|
||||
#[error("attempting to initialize an already instantiated lookup table")]
|
||||
TableAlreadyAssigned,
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported operation in graph")]
|
||||
UnsupportedOp,
|
||||
///
|
||||
#[error("invalid einsum expression")]
|
||||
InvalidEinsum,
|
||||
/// Flush error
|
||||
#[error("failed to flush, linear coord is not aligned with the next row")]
|
||||
FlushError,
|
||||
/// Constrain error
|
||||
#[error("constrain_equal: one of the tensors is assigned and the other is not")]
|
||||
ConstrainError,
|
||||
/// Failed to get lookups
|
||||
#[error("failed to get lookups for op: {0}")]
|
||||
GetLookupsError(String),
|
||||
/// Failed to get range checks
|
||||
#[error("failed to get range checks for op: {0}")]
|
||||
GetRangeChecksError(String),
|
||||
/// Failed to get dynamic lookup
|
||||
#[error("failed to get dynamic lookup for op: {0}")]
|
||||
GetDynamicLookupError(String),
|
||||
/// Failed to get shuffle
|
||||
#[error("failed to get shuffle for op: {0}")]
|
||||
GetShuffleError(String),
|
||||
/// Failed to get constants
|
||||
#[error("failed to get constants for op: {0}")]
|
||||
GetConstantsError(String),
|
||||
/// Slice length mismatch
|
||||
#[error("slice length mismatch: {0}")]
|
||||
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
|
||||
/// Bad conversion
|
||||
#[error("invalid conversion: {0}")]
|
||||
InvalidConversion(#[from] Infallible),
|
||||
/// Invalid min/max lookup range
|
||||
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
|
||||
InvalidMinMaxRange(i64, i64),
|
||||
/// Missing product in einsum
|
||||
#[error("missing product in einsum")]
|
||||
MissingEinsumProduct,
|
||||
/// Mismatched lookup length
|
||||
#[error("mismatched lookup lengths: {0} and {1}")]
|
||||
MismatchedLookupLength(usize, usize),
|
||||
/// Mismatched shuffle length
|
||||
#[error("mismatched shuffle lengths: {0} and {1}")]
|
||||
MismatchedShuffleLength(usize, usize),
|
||||
/// Mismatched lookup table lengths
|
||||
#[error("mismatched lookup table lengths: {0} and {1}")]
|
||||
MismatchedLookupTableLength(usize, usize),
|
||||
/// Wrong column type for lookup
|
||||
#[error("wrong column type for lookup: {0}")]
|
||||
WrongColumnType(String),
|
||||
/// Wrong column type for dynamic lookup
|
||||
#[error("wrong column type for dynamic lookup: {0}")]
|
||||
WrongDynamicColumnType(String),
|
||||
/// Missing selectors
|
||||
#[error("missing selectors for op: {0}")]
|
||||
MissingSelectors(String),
|
||||
/// Table lookup error
|
||||
#[error("value ({0}) out of range: ({1}, {2})")]
|
||||
TableOOR(i64, i64, i64),
|
||||
/// Loookup not configured
|
||||
#[error("lookup not configured: {0}")]
|
||||
LookupNotConfigured(String),
|
||||
/// Range check not configured
|
||||
#[error("range check not configured: {0}")]
|
||||
RangeCheckNotConfigured(String),
|
||||
/// Missing layout
|
||||
#[error("missing layout for op: {0}")]
|
||||
MissingLayout(String),
|
||||
}
|
||||
@@ -155,7 +155,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
@@ -287,7 +287,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
}))
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
HybridOp::Greater { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
@@ -295,7 +294,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
@@ -305,7 +304,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
}
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
LookupOp::Cast { scale } => {
|
||||
let in_scale = inputs_scale[0];
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{any::Any, error::Error};
|
||||
use std::any::Any;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -15,6 +15,8 @@ pub mod base;
|
||||
///
|
||||
pub mod chip;
|
||||
///
|
||||
pub mod errors;
|
||||
///
|
||||
pub mod hybrid;
|
||||
/// Layouts for specific functions (composed of base ops)
|
||||
pub mod layouts;
|
||||
@@ -25,6 +27,8 @@ pub mod poly;
|
||||
///
|
||||
pub mod region;
|
||||
|
||||
pub use errors::CircuitError;
|
||||
|
||||
/// A struct representing the result of a forward pass.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
@@ -44,10 +48,10 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>>;
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError>;
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>>;
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError>;
|
||||
|
||||
/// Do any of the inputs to this op require homogenous input scales?
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
@@ -139,7 +143,7 @@ pub struct Input {
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.scale)
|
||||
}
|
||||
|
||||
@@ -156,7 +160,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
let value = values[0].clone();
|
||||
if !value.all_prev_assigned() {
|
||||
match self.datum_type {
|
||||
@@ -194,7 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
pub struct Unknown;
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(0)
|
||||
}
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
@@ -209,8 +213,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
_: &mut crate::circuit::BaseConfig<F>,
|
||||
_: &mut RegionCtx<F>,
|
||||
_: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
Err(Box::new(super::CircuitError::UnsupportedOp))
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Err(super::CircuitError::UnsupportedOp)
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
@@ -240,7 +244,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Consta
|
||||
}
|
||||
}
|
||||
/// Rebase the scale of the constant
|
||||
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), Box<dyn Error>> {
|
||||
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
|
||||
let visibility = self.quantized_values.visibility().unwrap();
|
||||
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
|
||||
Ok(())
|
||||
@@ -279,7 +283,7 @@ impl<
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
_: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
let value = if let Some(value) = &self.pre_assigned_val {
|
||||
value.clone()
|
||||
} else {
|
||||
@@ -293,7 +297,7 @@ impl<
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.quantized_values.scale().unwrap())
|
||||
}
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ impl<
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
@@ -278,9 +278,10 @@ impl<
|
||||
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
|
||||
PolyOp::Pad(p) => {
|
||||
if values.len() != 1 {
|
||||
return Err(Box::new(TensorError::DimError(
|
||||
return Err(TensorError::DimError(
|
||||
"Pad operation requires a single input".to_string(),
|
||||
)));
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let mut input = values[0].clone();
|
||||
input.pad(p.clone(), 0)?;
|
||||
@@ -297,7 +298,7 @@ impl<
|
||||
}))
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
circuit::table::Range,
|
||||
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
|
||||
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored::Colorize;
|
||||
@@ -19,7 +19,7 @@ use std::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::lookup::LookupOp;
|
||||
use super::{lookup::LookupOp, CircuitError};
|
||||
|
||||
/// Constants map
|
||||
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
|
||||
@@ -84,44 +84,6 @@ impl ShuffleIndex {
|
||||
}
|
||||
}
|
||||
|
||||
/// Region error
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RegionError {
|
||||
/// wrap other regions
|
||||
#[error("Wrapped region: {0}")]
|
||||
Wrapped(String),
|
||||
}
|
||||
|
||||
impl From<String> for RegionError {
|
||||
fn from(e: String) -> Self {
|
||||
Self::Wrapped(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for RegionError {
|
||||
fn from(e: &str) -> Self {
|
||||
Self::Wrapped(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TensorError> for RegionError {
|
||||
fn from(e: TensorError) -> Self {
|
||||
Self::Wrapped(format!("{:?}", e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for RegionError {
|
||||
fn from(e: Error) -> Self {
|
||||
Self::Wrapped(format!("{:?}", e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Box<dyn std::error::Error>> for RegionError {
|
||||
fn from(e: Box<dyn std::error::Error>) -> Self {
|
||||
Self::Wrapped(format!("{:?}", e))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A context for a region
|
||||
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
@@ -317,10 +279,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
pub fn apply_in_loop<T: TensorType + Send + Sync>(
|
||||
&mut self,
|
||||
output: &mut Tensor<T>,
|
||||
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
|
||||
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Result<(), RegionError> {
|
||||
) -> Result<(), CircuitError> {
|
||||
if self.is_dummy() {
|
||||
self.dummy_loop(output, inner_loop_function)?;
|
||||
} else {
|
||||
@@ -333,8 +295,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
pub fn real_loop<T: TensorType + Send + Sync>(
|
||||
&mut self,
|
||||
output: &mut Tensor<T>,
|
||||
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>,
|
||||
) -> Result<(), RegionError> {
|
||||
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>,
|
||||
) -> Result<(), CircuitError> {
|
||||
output
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
@@ -342,7 +304,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
*o = inner_loop_function(i, self)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, RegionError>>()?;
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -353,10 +315,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
pub fn dummy_loop<T: TensorType + Send + Sync>(
|
||||
&mut self,
|
||||
output: &mut Tensor<T>,
|
||||
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
|
||||
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Result<(), RegionError> {
|
||||
) -> Result<(), CircuitError> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
|
||||
@@ -367,50 +329,48 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
|
||||
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
|
||||
|
||||
*output = output
|
||||
.par_enum_map(|idx, _| {
|
||||
// we kick off the loop with the current offset
|
||||
let starting_offset = row.load(Ordering::SeqCst);
|
||||
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
|
||||
// get inner value of the locked lookups
|
||||
*output = output.par_enum_map(|idx, _| {
|
||||
// we kick off the loop with the current offset
|
||||
let starting_offset = row.load(Ordering::SeqCst);
|
||||
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
|
||||
// get inner value of the locked lookups
|
||||
|
||||
// we need to make sure that the region is not shared between threads
|
||||
let mut local_reg = Self::new_dummy_with_linear_coord(
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
self.num_inner_cols,
|
||||
self.witness_gen,
|
||||
self.check_lookup_range,
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
|
||||
linear_coord.fetch_add(
|
||||
local_reg.linear_coord() - starting_linear_coord,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
// we need to make sure that the region is not shared between threads
|
||||
let mut local_reg = Self::new_dummy_with_linear_coord(
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
self.num_inner_cols,
|
||||
self.witness_gen,
|
||||
self.check_lookup_range,
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
|
||||
linear_coord.fetch_add(
|
||||
local_reg.linear_coord() - starting_linear_coord,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
// update the range checks
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
// update the dynamic lookup index
|
||||
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
|
||||
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
|
||||
// update the shuffle index
|
||||
let mut shuffle_index = shuffle_index.lock().unwrap();
|
||||
shuffle_index.update(&local_reg.shuffle_index);
|
||||
// update the constants
|
||||
let mut constants = constants.lock().unwrap();
|
||||
constants.extend(local_reg.assigned_constants);
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
// update the range checks
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
// update the dynamic lookup index
|
||||
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
|
||||
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
|
||||
// update the shuffle index
|
||||
let mut shuffle_index = shuffle_index.lock().unwrap();
|
||||
shuffle_index.update(&local_reg.shuffle_index);
|
||||
// update the constants
|
||||
let mut constants = constants.lock().unwrap();
|
||||
constants.extend(local_reg.assigned_constants);
|
||||
|
||||
res
|
||||
})
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
|
||||
res
|
||||
})?;
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
@@ -419,49 +379,25 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
|
||||
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
|
||||
})?;
|
||||
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
|
||||
self.used_range_checks = Arc::try_unwrap(range_checks)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?
|
||||
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?;
|
||||
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
|
||||
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!(
|
||||
"dummy_loop: failed to get dynamic lookup index: {:?}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!(
|
||||
"dummy_loop: failed to get dynamic lookup index: {:?}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?;
|
||||
self.shuffle_index = Arc::try_unwrap(shuffle_index)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
|
||||
})?
|
||||
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
|
||||
})?;
|
||||
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
|
||||
self.assigned_constants = Arc::try_unwrap(constants)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
|
||||
})?
|
||||
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
|
||||
})?;
|
||||
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -470,7 +406,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
pub fn update_max_min_lookup_inputs(
|
||||
&mut self,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
) -> Result<(), CircuitError> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
|
||||
@@ -482,12 +418,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(
|
||||
&mut self,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
if range.0 > range.1 {
|
||||
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
|
||||
return Err(CircuitError::InvalidMinMaxRange(range.0, range.1));
|
||||
}
|
||||
|
||||
let range_size = (range.1 - range.0).abs();
|
||||
@@ -506,13 +439,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self,
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
) -> Result<(), CircuitError> {
|
||||
self.used_lookups.insert(lookup);
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
self.used_range_checks.insert(range);
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
@@ -707,7 +640,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// constrain equal
|
||||
pub fn constrain_equal(&mut self, a: &ValTensor<F>, b: &ValTensor<F>) -> Result<(), Error> {
|
||||
pub fn constrain_equal(
|
||||
&mut self,
|
||||
a: &ValTensor<F>,
|
||||
b: &ValTensor<F>,
|
||||
) -> Result<(), CircuitError> {
|
||||
if let Some(region) = &self.region {
|
||||
let a = a.get_inner_tensor().unwrap();
|
||||
let b = b.get_inner_tensor().unwrap();
|
||||
@@ -717,12 +654,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
let b = b.get_prev_assigned();
|
||||
// if they're both assigned, we can constrain them
|
||||
if let (Some(a), Some(b)) = (&a, &b) {
|
||||
region.borrow_mut().constrain_equal(a.cell(), b.cell())
|
||||
region
|
||||
.borrow_mut()
|
||||
.constrain_equal(a.cell(), b.cell())
|
||||
.map_err(|e| e.into())
|
||||
} else if a.is_some() || b.is_some() {
|
||||
log::error!(
|
||||
"constrain_equal: one of the tensors is assigned and the other is not"
|
||||
);
|
||||
return Err(Error::Synthesis);
|
||||
return Err(CircuitError::ConstrainError);
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
@@ -748,7 +685,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// flush row to the next row
|
||||
pub fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn flush(&mut self) -> Result<(), CircuitError> {
|
||||
// increment by the difference between the current linear coord and the next row
|
||||
let remainder = self.linear_coord % self.num_inner_cols;
|
||||
if remainder != 0 {
|
||||
@@ -756,7 +693,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.increment(diff);
|
||||
}
|
||||
if self.linear_coord % self.num_inner_cols != 0 {
|
||||
return Err("flush: linear coord is not aligned with the next row".into());
|
||||
return Err(CircuitError::FlushError);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{error::Error, marker::PhantomData};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
@@ -194,9 +194,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
|
||||
&mut self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
preassigned_input: bool,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
) -> Result<(), CircuitError> {
|
||||
if self.is_assigned {
|
||||
return Err(Box::new(CircuitError::TableAlreadyAssigned));
|
||||
return Err(CircuitError::TableAlreadyAssigned);
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
@@ -342,9 +342,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
|
||||
}
|
||||
|
||||
/// Assigns values to the constraints generated when calling `configure`.
|
||||
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
|
||||
if self.is_assigned {
|
||||
return Err(Box::new(CircuitError::TableAlreadyAssigned));
|
||||
return Err(CircuitError::TableAlreadyAssigned);
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
|
||||
143
src/eth.rs
143
src/eth.rs
@@ -16,7 +16,8 @@ use alloy::prelude::Wallet;
|
||||
// use alloy::providers::Middleware;
|
||||
use alloy::json_abi::JsonAbi;
|
||||
use alloy::node_bindings::Anvil;
|
||||
use alloy::primitives::{B256, I256};
|
||||
use alloy::primitives::ruint::ParseError;
|
||||
use alloy::primitives::{ParseSignedError, B256, I256};
|
||||
use alloy::providers::fillers::{
|
||||
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
|
||||
};
|
||||
@@ -25,10 +26,13 @@ use alloy::providers::ProviderBuilder;
|
||||
use alloy::providers::{Identity, Provider, RootProvider};
|
||||
use alloy::rpc::types::eth::TransactionInput;
|
||||
use alloy::rpc::types::eth::TransactionRequest;
|
||||
use alloy::signers::wallet::LocalWallet;
|
||||
use alloy::signers::k256::ecdsa;
|
||||
use alloy::signers::wallet::{LocalWallet, WalletError};
|
||||
use alloy::sol as abigen;
|
||||
use alloy::transports::http::Http;
|
||||
use alloy::transports::{RpcError, TransportErrorKind};
|
||||
use foundry_compilers::artifacts::Settings as SolcSettings;
|
||||
use foundry_compilers::error::{SolcError, SolcIoError};
|
||||
use foundry_compilers::Solc;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
@@ -36,7 +40,6 @@ use halo2curves::group::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use log::{debug, info, warn};
|
||||
use reqwest::Client;
|
||||
use std::error::Error;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
@@ -213,6 +216,57 @@ abigen!(
|
||||
}
|
||||
);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EthError {
|
||||
#[error("a transport error occurred: {0}")]
|
||||
Transport(#[from] RpcError<TransportErrorKind>),
|
||||
#[error("a contract error occurred: {0}")]
|
||||
Contract(#[from] alloy::contract::Error),
|
||||
#[error("a wallet error occurred: {0}")]
|
||||
Wallet(#[from] WalletError),
|
||||
#[error("failed to parse url {0}")]
|
||||
UrlParse(String),
|
||||
#[error("evm verification error: {0}")]
|
||||
EvmVerification(#[from] EvmVerificationError),
|
||||
#[error("Private key must be in hex format, 64 chars, without 0x prefix")]
|
||||
PrivateKeyFormat,
|
||||
#[error("failed to parse hex: {0}")]
|
||||
HexParse(#[from] hex::FromHexError),
|
||||
#[error("ecdsa error: {0}")]
|
||||
Ecdsa(#[from] ecdsa::Error),
|
||||
#[error("failed to load graph data")]
|
||||
GraphData,
|
||||
#[error("failed to load graph settings")]
|
||||
GraphSettings,
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Data source for either input_data or output_data must be OnChain")]
|
||||
OnChainDataSource,
|
||||
#[error("failed to parse signed integer: {0}")]
|
||||
SignedIntegerParse(#[from] ParseSignedError),
|
||||
#[error("failed to parse unsigned integer: {0}")]
|
||||
UnSignedIntegerParse(#[from] ParseError),
|
||||
#[error("updateAccountCalls should have failed")]
|
||||
UpdateAccountCalls,
|
||||
#[error("ethabi error: {0}")]
|
||||
EthAbi(#[from] ethabi::Error),
|
||||
#[error("conversion error: {0}")]
|
||||
Conversion(#[from] std::convert::Infallible),
|
||||
// Constructor arguments provided but no constructor found
|
||||
#[error("constructor arguments provided but no constructor found")]
|
||||
NoConstructor,
|
||||
#[error("contract not found at path: {0}")]
|
||||
ContractNotFound(String),
|
||||
#[error("solc error: {0}")]
|
||||
Solc(#[from] SolcError),
|
||||
#[error("solc io error: {0}")]
|
||||
SolcIo(#[from] SolcIoError),
|
||||
#[error("svm error: {0}")]
|
||||
Svm(String),
|
||||
#[error("no contract output found")]
|
||||
NoContractOutput,
|
||||
}
|
||||
|
||||
// we have to generate these two contract differently because they are generated dynamically ! and hence the static compilation from above does not suit
|
||||
const ATTESTDATA_SOL: &str = include_str!("../contracts/AttestData.sol");
|
||||
|
||||
@@ -235,7 +289,7 @@ pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
|
||||
pub async fn setup_eth_backend(
|
||||
rpc_url: Option<&str>,
|
||||
private_key: Option<&str>,
|
||||
) -> Result<(EthersClient, alloy::primitives::Address), Box<dyn Error>> {
|
||||
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
|
||||
// Launch anvil
|
||||
|
||||
let endpoint: String;
|
||||
@@ -257,11 +311,8 @@ pub async fn setup_eth_backend(
|
||||
let wallet: LocalWallet;
|
||||
if let Some(private_key) = private_key {
|
||||
debug!("using private key {}", private_key);
|
||||
// Sanity checks for private_key
|
||||
let private_key_format_error =
|
||||
"Private key must be in hex format, 64 chars, without 0x prefix";
|
||||
if private_key.len() != 64 {
|
||||
return Err(private_key_format_error.into());
|
||||
return Err(EthError::PrivateKeyFormat);
|
||||
}
|
||||
let private_key_buffer = hex::decode(private_key)?;
|
||||
wallet = LocalWallet::from_slice(&private_key_buffer)?;
|
||||
@@ -276,7 +327,11 @@ pub async fn setup_eth_backend(
|
||||
ProviderBuilder::new()
|
||||
.with_recommended_fillers()
|
||||
.signer(EthereumSigner::from(wallet))
|
||||
.on_http(endpoint.parse()?),
|
||||
.on_http(
|
||||
endpoint
|
||||
.parse()
|
||||
.map_err(|_| EthError::UrlParse(endpoint.clone()))?,
|
||||
),
|
||||
);
|
||||
|
||||
let chain_id = client.get_chain_id().await?;
|
||||
@@ -292,7 +347,7 @@ pub async fn deploy_contract_via_solidity(
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
) -> Result<H160, Box<dyn Error>> {
|
||||
) -> Result<H160, EthError> {
|
||||
// anvil instance must be alive at least until the factory completes the deploy
|
||||
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
@@ -314,12 +369,12 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
) -> Result<H160, Box<dyn Error>> {
|
||||
) -> Result<H160, EthError> {
|
||||
let (client, client_address) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
let input = GraphData::from_path(input)?;
|
||||
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
|
||||
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let settings = GraphSettings::load(&settings_path).map_err(|_| EthError::GraphSettings)?;
|
||||
|
||||
let mut scales: Vec<u32> = vec![];
|
||||
// The data that will be stored in the test contracts that will eventually be read from.
|
||||
@@ -339,7 +394,7 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
}
|
||||
|
||||
if settings.run_args.param_visibility.is_hashed() {
|
||||
return Err(Box::new(EvmVerificationError::InvalidVisibility));
|
||||
return Err(EvmVerificationError::InvalidVisibility.into());
|
||||
}
|
||||
|
||||
if settings.run_args.output_visibility.is_hashed() {
|
||||
@@ -400,7 +455,7 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
|
||||
parse_calls_to_accounts(calls_to_accounts)?
|
||||
} else {
|
||||
return Err("Data source for either input_data or output_data must be OnChain".into());
|
||||
return Err(EthError::OnChainDataSource);
|
||||
};
|
||||
|
||||
let (abi, bytecode, runtime_bytecode) =
|
||||
@@ -469,7 +524,7 @@ type ParsedCallsToAccount = (Vec<H160>, Vec<Vec<Bytes>>, Vec<Vec<U256>>);
|
||||
|
||||
fn parse_calls_to_accounts(
|
||||
calls_to_accounts: Vec<CallsToAccount>,
|
||||
) -> Result<ParsedCallsToAccount, Box<dyn Error>> {
|
||||
) -> Result<ParsedCallsToAccount, EthError> {
|
||||
let mut contract_addresses = vec![];
|
||||
let mut call_data = vec![];
|
||||
let mut decimals: Vec<Vec<U256>> = vec![];
|
||||
@@ -492,8 +547,8 @@ pub async fn update_account_calls(
|
||||
addr: H160,
|
||||
input: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let input = GraphData::from_path(input)?;
|
||||
) -> Result<(), EthError> {
|
||||
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
|
||||
|
||||
// The data that will be stored in the test contracts that will eventually be read from.
|
||||
let mut calls_to_accounts = vec![];
|
||||
@@ -513,7 +568,7 @@ pub async fn update_account_calls(
|
||||
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
|
||||
parse_calls_to_accounts(calls_to_accounts)?
|
||||
} else {
|
||||
return Err("Data source for either input_data or output_data must be OnChain".into());
|
||||
return Err(EthError::OnChainDataSource);
|
||||
};
|
||||
|
||||
let (client, client_address) = setup_eth_backend(rpc_url, None).await?;
|
||||
@@ -547,7 +602,7 @@ pub async fn update_account_calls(
|
||||
{
|
||||
info!("updateAccountCalls failed as expected");
|
||||
} else {
|
||||
return Err("updateAccountCalls should have failed".into());
|
||||
return Err(EthError::UpdateAccountCalls);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -560,7 +615,7 @@ pub async fn verify_proof_via_solidity(
|
||||
addr: H160,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
) -> Result<bool, EthError> {
|
||||
let flattened_instances = proof.instances.into_iter().flatten();
|
||||
|
||||
let encoded = encode_calldata(
|
||||
@@ -579,15 +634,15 @@ pub async fn verify_proof_via_solidity(
|
||||
|
||||
let result = client.call(&tx).await;
|
||||
|
||||
if result.is_err() {
|
||||
return Err(Box::new(EvmVerificationError::SolidityExecution));
|
||||
if let Err(e) = result {
|
||||
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
|
||||
}
|
||||
let result = result?;
|
||||
debug!("result: {:#?}", result.to_vec());
|
||||
// decode return bytes value into uint8
|
||||
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
|
||||
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
|
||||
if !result {
|
||||
return Err(Box::new(EvmVerificationError::InvalidProof));
|
||||
return Err(EvmVerificationError::InvalidProof.into());
|
||||
}
|
||||
|
||||
let gas = client.estimate_gas(&tx).await?;
|
||||
@@ -626,7 +681,7 @@ fn count_decimal_places(num: f32) -> usize {
|
||||
pub async fn setup_test_contract<M: 'static + Provider<Http<Client>, Ethereum>>(
|
||||
client: Arc<M>,
|
||||
data: &[Vec<FileSourceInner>],
|
||||
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), Box<dyn Error>> {
|
||||
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), EthError> {
|
||||
let mut decimals = vec![];
|
||||
let mut scaled_by_decimals_data = vec![];
|
||||
for input in &data[0] {
|
||||
@@ -663,7 +718,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
addr_da: H160,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
) -> Result<bool, EthError> {
|
||||
use ethabi::{Function, Param, ParamType, StateMutability, Token};
|
||||
|
||||
let mut public_inputs: Vec<U256> = vec![];
|
||||
@@ -728,15 +783,15 @@ pub async fn verify_proof_with_data_attestation(
|
||||
);
|
||||
|
||||
let result = client.call(&tx).await;
|
||||
if result.is_err() {
|
||||
return Err(Box::new(EvmVerificationError::SolidityExecution));
|
||||
if let Err(e) = result {
|
||||
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
|
||||
}
|
||||
let result = result?;
|
||||
debug!("result: {:#?}", result);
|
||||
// decode return bytes value into uint8
|
||||
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
|
||||
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
|
||||
if !result {
|
||||
return Err(Box::new(EvmVerificationError::InvalidProof));
|
||||
return Err(EvmVerificationError::InvalidProof.into());
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
@@ -748,7 +803,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
|
||||
client: Arc<M>,
|
||||
data: &[Vec<FileSourceInner>],
|
||||
) -> Result<Vec<CallsToAccount>, Box<dyn Error>> {
|
||||
) -> Result<Vec<CallsToAccount>, EthError> {
|
||||
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
|
||||
|
||||
// Get the encoded call data for each input
|
||||
@@ -774,7 +829,7 @@ pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>
|
||||
client: Arc<M>,
|
||||
address: H160,
|
||||
data: &Vec<CallsToAccount>,
|
||||
) -> Result<(Vec<Bytes>, Vec<u8>), Box<dyn Error>> {
|
||||
) -> Result<(Vec<Bytes>, Vec<u8>), EthError> {
|
||||
// Iterate over all on-chain inputs
|
||||
|
||||
let mut fetched_inputs = vec![];
|
||||
@@ -808,9 +863,7 @@ pub async fn evm_quantize<M: 'static + Provider<Http<Client>, Ethereum>>(
|
||||
client: Arc<M>,
|
||||
scales: Vec<crate::Scale>,
|
||||
data: &(Vec<Bytes>, Vec<u8>),
|
||||
) -> Result<Vec<Fr>, Box<dyn Error>> {
|
||||
use alloy::primitives::ParseSignedError;
|
||||
|
||||
) -> Result<Vec<Fr>, EthError> {
|
||||
let contract = QuantizeData::deploy(&client).await?;
|
||||
|
||||
let fetched_inputs = data.0.clone();
|
||||
@@ -870,7 +923,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
|
||||
runtime_bytecode: Bytes,
|
||||
client: Arc<M>,
|
||||
params: Option<T>,
|
||||
) -> Result<ContractFactory<M>, Box<dyn Error>> {
|
||||
) -> Result<ContractFactory<M>, EthError> {
|
||||
const MAX_RUNTIME_BYTECODE_SIZE: usize = 24577;
|
||||
let size = runtime_bytecode.len();
|
||||
debug!("runtime bytecode size: {:#?}", size);
|
||||
@@ -888,7 +941,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
|
||||
// Encode the constructor args & concatenate with the bytecode if necessary
|
||||
let data: Bytes = match (abi.constructor(), params.is_none()) {
|
||||
(None, false) => {
|
||||
return Err("Constructor arguments provided but no constructor found".into())
|
||||
return Err(EthError::NoConstructor);
|
||||
}
|
||||
(None, true) => bytecode.clone(),
|
||||
(Some(_), _) => {
|
||||
@@ -911,7 +964,7 @@ pub async fn get_contract_artifacts(
|
||||
sol_code_path: PathBuf,
|
||||
contract_name: &str,
|
||||
runs: usize,
|
||||
) -> Result<(JsonAbi, Bytes, Bytes), Box<dyn Error>> {
|
||||
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
|
||||
use foundry_compilers::{
|
||||
artifacts::{output_selection::OutputSelection, Optimizer},
|
||||
compilers::CompilerInput,
|
||||
@@ -919,7 +972,9 @@ pub async fn get_contract_artifacts(
|
||||
};
|
||||
|
||||
if !sol_code_path.exists() {
|
||||
return Err(format!("file not found: {:#?}", sol_code_path).into());
|
||||
return Err(EthError::ContractNotFound(
|
||||
sol_code_path.to_string_lossy().to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let settings = SolcSettings {
|
||||
@@ -946,7 +1001,9 @@ pub async fn get_contract_artifacts(
|
||||
Some(solc) => solc,
|
||||
None => {
|
||||
info!("required solc version is missing ... installing");
|
||||
Solc::install(&SHANGHAI_SOLC).await?
|
||||
Solc::install(&SHANGHAI_SOLC)
|
||||
.await
|
||||
.map_err(|e| EthError::Svm(e.to_string()))?
|
||||
}
|
||||
};
|
||||
|
||||
@@ -955,7 +1012,7 @@ pub async fn get_contract_artifacts(
|
||||
let (abi, bytecode, runtime_bytecode) = match compiled.find(contract_name) {
|
||||
Some(c) => c.into_parts_or_default(),
|
||||
None => {
|
||||
return Err("could not find contract".into());
|
||||
return Err(EthError::ContractNotFound(contract_name.to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -967,7 +1024,7 @@ pub fn fix_da_sol(
|
||||
input_data: Option<Vec<CallsToAccount>>,
|
||||
output_data: Option<Vec<CallsToAccount>>,
|
||||
commitment_bytes: Option<Vec<u8>>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EthError> {
|
||||
let mut accounts_len = 0;
|
||||
let mut contract = ATTESTDATA_SOL.to_string();
|
||||
|
||||
|
||||
141
src/execute.rs
141
src/execute.rs
@@ -1,7 +1,6 @@
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::commands::*;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -23,6 +22,7 @@ use crate::pfsys::{save_vk, srs::*};
|
||||
use crate::tensor::TensorError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::{commands::*, EZKLError};
|
||||
use crate::{Commitments, RunArgs};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored::Colorize;
|
||||
@@ -63,7 +63,6 @@ use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::compile;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use snark_verifier::system::halo2::Config;
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::BufWriter;
|
||||
@@ -92,12 +91,15 @@ lazy_static! {
|
||||
|
||||
}
|
||||
|
||||
/// A wrapper for tensor related errors.
|
||||
/// A wrapper for execution errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ExecutionError {
|
||||
/// Shape mismatch in a operation
|
||||
#[error("verification failed")]
|
||||
/// verification failed
|
||||
#[error("verification failed:\n{}", .0.iter().map(|e| e.to_string()).collect::<Vec<_>>().join("\n"))]
|
||||
VerifyError(Vec<VerifyFailure>),
|
||||
/// Prover error
|
||||
#[error("[mock] {0}")]
|
||||
MockProverError(String),
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
@@ -109,7 +111,7 @@ lazy_static::lazy_static! {
|
||||
}
|
||||
|
||||
/// Run an ezkl command with given args
|
||||
pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
// set working dir
|
||||
std::env::set_current_dir(WORKING_DIR.as_path())?;
|
||||
|
||||
@@ -123,7 +125,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
} => gen_srs_cmd(
|
||||
srs_path,
|
||||
logrows as u32,
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT)?),
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GetSrs {
|
||||
@@ -161,7 +163,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse()?),
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
@@ -200,7 +202,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -265,8 +267,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_AGGREGATED.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
|
||||
aggregation_settings,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -292,7 +294,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
pk_path.unwrap_or(DEFAULT_PK.into()),
|
||||
witness,
|
||||
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
|
||||
disable_selector_compression
|
||||
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SetupTestEvmData {
|
||||
@@ -345,7 +348,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
Some(proof_path.unwrap_or(DEFAULT_PROOF.into())),
|
||||
srs_path,
|
||||
proof_type,
|
||||
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
|
||||
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse().unwrap()),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::MockAggregate {
|
||||
@@ -354,8 +357,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
split_proofs,
|
||||
} => mock_aggregate(
|
||||
aggregation_snarks,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
|
||||
),
|
||||
Commands::SetupAggregate {
|
||||
sample_snarks,
|
||||
@@ -371,9 +374,10 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
|
||||
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
|
||||
srs_path,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
|
||||
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
|
||||
disable_selector_compression
|
||||
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
|
||||
commitment.into(),
|
||||
),
|
||||
Commands::Aggregate {
|
||||
@@ -392,9 +396,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
|
||||
srs_path,
|
||||
transcript,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse().unwrap()),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -409,7 +413,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
|
||||
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap()),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::VerifyAggr {
|
||||
@@ -423,8 +427,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
|
||||
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
|
||||
srs_path,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap()),
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -507,13 +511,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
}
|
||||
|
||||
/// Assert that the version is valid
|
||||
fn assert_version_is_valid(version: &str) -> Result<(), Box<dyn Error>> {
|
||||
fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
|
||||
let err_string = "Invalid version string. Must be in the format v0.0.0";
|
||||
if version.is_empty() {
|
||||
return Err(err_string.into());
|
||||
}
|
||||
// safe to unwrap since we know the length is not 0
|
||||
if version.chars().nth(0).unwrap() != 'v' {
|
||||
if !version.starts_with('v') {
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
@@ -525,7 +529,7 @@ fn assert_version_is_valid(version: &str) -> Result<(), Box<dyn Error>> {
|
||||
|
||||
const INSTALL_BYTES: &[u8] = include_bytes!("../install_ezkl_cli.sh");
|
||||
|
||||
fn update_ezkl_binary(version: &Option<String>) -> Result<String, Box<dyn Error>> {
|
||||
fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
|
||||
// run the install script with the version
|
||||
let install_script = std::str::from_utf8(INSTALL_BYTES)?;
|
||||
// now run as sh script with the version as an argument
|
||||
@@ -574,7 +578,7 @@ pub(crate) fn gen_srs_cmd(
|
||||
srs_path: PathBuf,
|
||||
logrows: u32,
|
||||
commitment: Commitments,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
match commitment {
|
||||
Commitments::KZG => {
|
||||
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
|
||||
@@ -589,7 +593,7 @@ pub(crate) fn gen_srs_cmd(
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, EZKLError> {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Downloading SRS (this may take a while) ...");
|
||||
@@ -609,7 +613,7 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
|
||||
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, EZKLError> {
|
||||
use std::io::Read;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
@@ -632,7 +636,7 @@ fn check_srs_hash(
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
commitment: Commitments,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let path = get_srs_path(logrows, srs_path, commitment);
|
||||
let hash = get_file_hash(&path)?;
|
||||
|
||||
@@ -659,7 +663,7 @@ pub(crate) async fn get_srs_cmd(
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
commitment: Option<Commitments>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
// logrows overrides settings
|
||||
|
||||
let err_string = "You will need to provide a valid settings file to use the settings option. You should run gen-settings to generate a settings file (and calibrate-settings to pick optimal logrows).";
|
||||
@@ -727,7 +731,7 @@ pub(crate) async fn get_srs_cmd(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn Error>> {
|
||||
pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLError> {
|
||||
let model = Model::from_run_args(&run_args, &model)?;
|
||||
info!("\n {}", model.table_nodes());
|
||||
Ok(String::new())
|
||||
@@ -739,7 +743,7 @@ pub(crate) async fn gen_witness(
|
||||
output: Option<PathBuf>,
|
||||
vk_path: Option<PathBuf>,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> Result<GraphWitness, Box<dyn Error>> {
|
||||
) -> Result<GraphWitness, EZKLError> {
|
||||
// these aren't real values so the sanity checks are mostly meaningless
|
||||
|
||||
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
|
||||
@@ -840,7 +844,7 @@ pub(crate) fn gen_circuit_settings(
|
||||
model_path: PathBuf,
|
||||
params_output: PathBuf,
|
||||
run_args: RunArgs,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let circuit = GraphCircuit::from_run_args(&run_args, &model_path)?;
|
||||
let params = circuit.settings();
|
||||
params.save(¶ms_output)?;
|
||||
@@ -908,7 +912,7 @@ impl AccuracyResults {
|
||||
pub fn new(
|
||||
mut original_preds: Vec<crate::tensor::Tensor<f32>>,
|
||||
mut calibrated_preds: Vec<crate::tensor::Tensor<f32>>,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
) -> Result<Self, EZKLError> {
|
||||
let mut errors = vec![];
|
||||
let mut abs_errors = vec![];
|
||||
let mut squared_errors = vec![];
|
||||
@@ -997,7 +1001,7 @@ pub(crate) async fn calibrate(
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
) -> Result<GraphSettings, EZKLError> {
|
||||
use log::error;
|
||||
use std::collections::HashMap;
|
||||
use tabled::Table;
|
||||
@@ -1369,7 +1373,7 @@ pub(crate) async fn calibrate(
|
||||
pub(crate) fn mock(
|
||||
compiled_circuit_path: PathBuf,
|
||||
data_path: PathBuf,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
// mock should catch any issues by default so we set it to safe
|
||||
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
|
||||
|
||||
@@ -1386,10 +1390,9 @@ pub(crate) fn mock(
|
||||
&circuit,
|
||||
vec![public_inputs],
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
prover
|
||||
.verify()
|
||||
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
|
||||
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
|
||||
|
||||
prover.verify().map_err(ExecutionError::VerifyError)?;
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
@@ -1401,7 +1404,7 @@ pub(crate) async fn create_evm_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
|
||||
@@ -1445,7 +1448,7 @@ pub(crate) async fn create_evm_vk(
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
|
||||
@@ -1486,7 +1489,7 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
_abi_path: PathBuf,
|
||||
_input: PathBuf,
|
||||
_witness: Option<PathBuf>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
#[allow(unused_imports)]
|
||||
use crate::graph::{DataSource, VarVisibility};
|
||||
use crate::{graph::Visibility, pfsys::get_proof_commitments};
|
||||
@@ -1565,7 +1568,7 @@ pub(crate) async fn deploy_da_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let contract_address = deploy_da_verifier_via_solidity(
|
||||
settings_path,
|
||||
data,
|
||||
@@ -1591,7 +1594,7 @@ pub(crate) async fn deploy_evm(
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
@@ -1613,7 +1616,7 @@ pub(crate) fn encode_evm_calldata(
|
||||
proof_path: PathBuf,
|
||||
calldata_path: PathBuf,
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
let snark = Snark::load::<IPACommitmentScheme<G1Affine>>(&proof_path)?;
|
||||
|
||||
let flattened_instances = snark.instances.into_iter().flatten();
|
||||
@@ -1641,7 +1644,7 @@ pub(crate) async fn verify_evm(
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<H160Flag>,
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
use crate::eth::verify_proof_with_data_attestation;
|
||||
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
@@ -1683,7 +1686,7 @@ pub(crate) async fn create_evm_aggregate_verifier(
|
||||
circuit_settings: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
|
||||
let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;
|
||||
|
||||
@@ -1740,7 +1743,7 @@ pub(crate) fn compile_circuit(
|
||||
model_path: PathBuf,
|
||||
compiled_circuit: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let circuit = GraphCircuit::from_settings(&settings, &model_path, CheckMode::UNSAFE)?;
|
||||
circuit.save(compiled_circuit)?;
|
||||
@@ -1754,7 +1757,7 @@ pub(crate) fn setup(
|
||||
pk_path: PathBuf,
|
||||
witness: Option<PathBuf>,
|
||||
disable_selector_compression: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
// these aren't real values so the sanity checks are mostly meaningless
|
||||
|
||||
let mut circuit = GraphCircuit::load(compiled_circuit)?;
|
||||
@@ -1806,7 +1809,7 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
rpc_url: Option<String>,
|
||||
input_source: TestDataSource,
|
||||
output_source: TestDataSource,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
use crate::graph::TestOnChainData;
|
||||
|
||||
let mut data = GraphData::from_path(data_path)?;
|
||||
@@ -1841,7 +1844,7 @@ pub(crate) async fn test_update_account_calls(
|
||||
addr: H160Flag,
|
||||
data: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
use crate::eth::update_account_calls;
|
||||
|
||||
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
|
||||
@@ -1859,7 +1862,7 @@ pub(crate) fn prove(
|
||||
srs_path: Option<PathBuf>,
|
||||
proof_type: ProofType,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
|
||||
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
|
||||
let data = GraphWitness::from_path(data_path)?;
|
||||
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
|
||||
|
||||
@@ -2011,7 +2014,7 @@ pub(crate) fn prove(
|
||||
pub(crate) fn swap_proof_commitments_cmd(
|
||||
proof_path: PathBuf,
|
||||
witness: PathBuf,
|
||||
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
|
||||
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
|
||||
let snark = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
let witness = GraphWitness::from_path(witness)?;
|
||||
let commitments = witness.get_polycommitments();
|
||||
@@ -2030,7 +2033,7 @@ pub(crate) fn mock_aggregate(
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
split_proofs: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let mut snarks = vec![];
|
||||
for proof_path in aggregation_snarks.iter() {
|
||||
match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
|
||||
@@ -2057,10 +2060,8 @@ pub(crate) fn mock_aggregate(
|
||||
let circuit = AggregationCircuit::new(&G1Affine::generator().into(), snarks, split_proofs)?;
|
||||
|
||||
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
prover
|
||||
.verify()
|
||||
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
|
||||
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
|
||||
prover.verify().map_err(ExecutionError::VerifyError)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("Done.");
|
||||
Ok(String::new())
|
||||
@@ -2075,7 +2076,7 @@ pub(crate) fn setup_aggregate(
|
||||
split_proofs: bool,
|
||||
disable_selector_compression: bool,
|
||||
commitment: Commitments,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let mut snarks = vec![];
|
||||
for proof_path in sample_snarks.iter() {
|
||||
match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
|
||||
@@ -2138,7 +2139,7 @@ pub(crate) fn aggregate(
|
||||
check_mode: CheckMode,
|
||||
split_proofs: bool,
|
||||
commitment: Commitments,
|
||||
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
|
||||
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
|
||||
let mut snarks = vec![];
|
||||
for proof_path in aggregation_snarks.iter() {
|
||||
match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
|
||||
@@ -2318,7 +2319,7 @@ pub(crate) fn verify(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
reduced_srs: bool,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
) -> Result<bool, EZKLError> {
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
|
||||
let logrows = circuit_settings.run_args.logrows;
|
||||
@@ -2412,7 +2413,7 @@ fn verify_commitment<
|
||||
vk_path: PathBuf,
|
||||
params: &'a Scheme::ParamsVerifier,
|
||||
logrows: u32,
|
||||
) -> Result<bool, Box<dyn Error>>
|
||||
) -> Result<bool, EZKLError>
|
||||
where
|
||||
Scheme::Scalar: FromUniformBytes<64>
|
||||
+ SerdeObject
|
||||
@@ -2448,7 +2449,7 @@ pub(crate) fn verify_aggr(
|
||||
logrows: u32,
|
||||
reduced_srs: bool,
|
||||
commitment: Commitments,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
) -> Result<bool, EZKLError> {
|
||||
match commitment {
|
||||
Commitments::KZG => {
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
@@ -2523,11 +2524,11 @@ pub(crate) fn load_params_verifier<Scheme: CommitmentScheme>(
|
||||
srs_path: Option<PathBuf>,
|
||||
logrows: u32,
|
||||
commitment: Commitments,
|
||||
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
|
||||
) -> Result<Scheme::ParamsVerifier, EZKLError> {
|
||||
let srs_path = get_srs_path(logrows, srs_path, commitment);
|
||||
let mut params = load_srs_verifier::<Scheme>(srs_path)?;
|
||||
info!("downsizing params to {} logrows", logrows);
|
||||
if logrows < params.k() {
|
||||
info!("downsizing params to {} logrows", logrows);
|
||||
params.downsize(logrows);
|
||||
}
|
||||
Ok(params)
|
||||
@@ -2538,11 +2539,11 @@ pub(crate) fn load_params_prover<Scheme: CommitmentScheme>(
|
||||
srs_path: Option<PathBuf>,
|
||||
logrows: u32,
|
||||
commitment: Commitments,
|
||||
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
|
||||
) -> Result<Scheme::ParamsProver, EZKLError> {
|
||||
let srs_path = get_srs_path(logrows, srs_path, commitment);
|
||||
let mut params = load_srs_prover::<Scheme>(srs_path)?;
|
||||
info!("downsizing params to {} logrows", logrows);
|
||||
if logrows < params.k() {
|
||||
info!("downsizing params to {} logrows", logrows);
|
||||
params.downsize(logrows);
|
||||
}
|
||||
Ok(params)
|
||||
|
||||
137
src/graph/errors.rs
Normal file
137
src/graph/errors.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// circuit related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GraphError {
|
||||
/// The wrong inputs were passed to a lookup node
|
||||
#[error("invalid inputs for a lookup node")]
|
||||
InvalidLookupInputs,
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
/// A requested node is missing in the graph
|
||||
#[error("a requested node is missing in the graph: {0}")]
|
||||
MissingNode(usize),
|
||||
/// The wrong method was called on an operation
|
||||
#[error("an unsupported method was called on node {0} ({1})")]
|
||||
OpMismatch(usize, String),
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported datatype in graph node {0} ({1})")]
|
||||
UnsupportedDataType(usize, String),
|
||||
/// A node has missing parameters
|
||||
#[error("a node is missing required params: {0}")]
|
||||
MissingParams(String),
|
||||
/// A node has missing parameters
|
||||
#[error("a node is has misformed params: {0}")]
|
||||
MisformedParams(String),
|
||||
/// Error in the configuration of the visibility of variables
|
||||
#[error("there should be at least one set of public variables")]
|
||||
Visibility,
|
||||
/// Ezkl only supports divisions by constants
|
||||
#[error("ezkl currently only supports division by constants")]
|
||||
NonConstantDiv,
|
||||
/// Ezkl only supports constant powers
|
||||
#[error("ezkl currently only supports constant exponents")]
|
||||
NonConstantPower,
|
||||
/// Error when attempting to rescale an operation
|
||||
#[error("failed to rescale inputs for {0}")]
|
||||
RescalingError(String),
|
||||
/// Error when attempting to load a model from a file
|
||||
#[error("failed to load model")]
|
||||
ModelLoad(#[from] std::io::Error),
|
||||
/// Model serialization error
|
||||
#[error("failed to ser/deser model: {0}")]
|
||||
ModelSerialize(#[from] bincode::Error),
|
||||
/// Tract error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[tract] {0}")]
|
||||
TractError(#[from] tract_onnx::tract_core::anyhow::Error),
|
||||
/// Packing exponent is too large
|
||||
#[error("largest packing exponent exceeds max. try reducing the scale")]
|
||||
PackingExponent,
|
||||
/// Invalid Input Types
|
||||
#[error("invalid input types")]
|
||||
InvalidInputTypes,
|
||||
/// Missing results
|
||||
#[error("missing results")]
|
||||
MissingResults,
|
||||
/// Tensor error
|
||||
#[error("[tensor] {0}")]
|
||||
TensorError(#[from] crate::tensor::TensorError),
|
||||
/// Public visibility for params is deprecated
|
||||
#[error("public visibility for params is deprecated, please use `fixed` instead")]
|
||||
ParamsPublicVisibility,
|
||||
/// Slice length mismatch
|
||||
#[error("slice length mismatch: {0}")]
|
||||
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
|
||||
/// Bad conversion
|
||||
#[error("invalid conversion: {0}")]
|
||||
InvalidConversion(#[from] Infallible),
|
||||
/// Circuit error
|
||||
#[error("[circuit] {0}")]
|
||||
CircuitError(#[from] crate::circuit::CircuitError),
|
||||
/// Halo2 error
|
||||
#[error("[halo2] {0}")]
|
||||
Halo2Error(#[from] halo2_proofs::plonk::Error),
|
||||
/// System time error
|
||||
#[error("[system time] {0}")]
|
||||
SystemTimeError(#[from] std::time::SystemTimeError),
|
||||
/// Missing Batch Size
|
||||
#[error("unknown dimension batch_size in model inputs, set batch_size in variables")]
|
||||
MissingBatchSize,
|
||||
/// Tokio postgres error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[tokio postgres] {0}")]
|
||||
TokioPostgresError(#[from] tokio_postgres::Error),
|
||||
/// Eth error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[eth] {0}")]
|
||||
EthError(#[from] crate::eth::EthError),
|
||||
/// Json error
|
||||
#[error("[json] {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
/// Missing instances
|
||||
#[error("missing instances")]
|
||||
MissingInstances,
|
||||
/// Missing constants
|
||||
#[error("missing constants")]
|
||||
MissingConstants,
|
||||
/// Missing input for a node
|
||||
#[error("missing input for node {0}")]
|
||||
MissingInput(usize),
|
||||
///
|
||||
#[error("range only supports constant inputs in a zk circuit")]
|
||||
NonConstantRange,
|
||||
///
|
||||
#[error("trilu only supports constant diagonals in a zk circuit")]
|
||||
NonConstantTrilu,
|
||||
///
|
||||
#[error("insufficient witness values to generate a fixed output")]
|
||||
InsufficientWitnessValues,
|
||||
/// Missing scale
|
||||
#[error("missing scale")]
|
||||
MissingScale,
|
||||
/// Extended k is too large
|
||||
#[error("extended k is too large to accommodate the quotient polynomial with logrows {0}")]
|
||||
ExtendedKTooLarge(u32),
|
||||
/// Max lookup input is too large
|
||||
#[error("lookup range {0} is too large")]
|
||||
LookupRangeTooLarge(usize),
|
||||
/// Max range check input is too large
|
||||
#[error("range check {0} is too large")]
|
||||
RangeCheckTooLarge(usize),
|
||||
///Cannot use on-chain data source as private data
|
||||
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
|
||||
OnChainDataSource,
|
||||
/// Missing data source
|
||||
#[error("missing data source")]
|
||||
MissingDataSource,
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use super::GraphError;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -211,9 +211,7 @@ impl PostgresSource {
|
||||
}
|
||||
|
||||
/// Fetch data from postgres
|
||||
pub async fn fetch(
|
||||
&self,
|
||||
) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// clone to move into thread
|
||||
let user = self.user.clone();
|
||||
let host = self.host.clone();
|
||||
@@ -247,9 +245,7 @@ impl PostgresSource {
|
||||
}
|
||||
|
||||
/// Fetch data from postgres and format it as a FileSource
|
||||
pub async fn fetch_and_format_as_file(
|
||||
&self,
|
||||
) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
@@ -279,7 +275,7 @@ impl OnChainSource {
|
||||
scales: Vec<crate::Scale>,
|
||||
mut shapes: Vec<Vec<usize>>,
|
||||
rpc: Option<&str>,
|
||||
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
|
||||
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
|
||||
use crate::eth::{
|
||||
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
|
||||
};
|
||||
@@ -455,7 +451,7 @@ impl GraphData {
|
||||
&self,
|
||||
shapes: &[Vec<usize>],
|
||||
datum_types: &[tract_onnx::prelude::DatumType],
|
||||
) -> Result<TVec<TValue>, Box<dyn std::error::Error>> {
|
||||
) -> Result<TVec<TValue>, GraphError> {
|
||||
let mut inputs = TVec::new();
|
||||
match &self.input_data {
|
||||
DataSource::File(data) => {
|
||||
@@ -470,10 +466,10 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
"non file data cannot be split into batches".to_string(),
|
||||
)))
|
||||
))
|
||||
}
|
||||
}
|
||||
Ok(inputs)
|
||||
@@ -488,7 +484,7 @@ impl GraphData {
|
||||
}
|
||||
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
||||
let reader = std::fs::File::open(path)?;
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
|
||||
let mut buf = String::new();
|
||||
@@ -498,7 +494,7 @@ impl GraphData {
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
||||
// buf writer
|
||||
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
serde_json::to_writer(writer, self)?;
|
||||
@@ -509,7 +505,7 @@ impl GraphData {
|
||||
pub async fn split_into_batches(
|
||||
&self,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
// split input data into batches
|
||||
let mut batched_inputs = vec![];
|
||||
|
||||
@@ -522,10 +518,10 @@ impl GraphData {
|
||||
input_data: DataSource::OnChain(_),
|
||||
output_data: _,
|
||||
} => {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
"on-chain data cannot be split into batches".to_string(),
|
||||
)))
|
||||
))
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
GraphData {
|
||||
@@ -539,11 +535,11 @@ impl GraphData {
|
||||
let input_size = shape.clone().iter().product::<usize>();
|
||||
let input = &iterable[i];
|
||||
if input.len() % input_size != 0 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
"calibration data length must be evenly divisible by the original input_size"
|
||||
.to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
let mut batches = vec![];
|
||||
for batch in input.chunks(input_size) {
|
||||
|
||||
187
src/graph/mod.rs
187
src/graph/mod.rs
@@ -14,6 +14,9 @@ pub mod utilities;
|
||||
/// Representations of a computational graph's variables.
|
||||
pub mod vars;
|
||||
|
||||
/// errors for the graph
|
||||
pub mod errors;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored_json::ToColoredJson;
|
||||
#[cfg(unix)]
|
||||
@@ -24,6 +27,7 @@ pub use input::DataSource;
|
||||
use itertools::Itertools;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use self::errors::GraphError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use self::input::OnChainSource;
|
||||
use self::input::{FileSource, GraphData};
|
||||
@@ -58,7 +62,6 @@ use pyo3::types::PyDict;
|
||||
use pyo3::ToPyObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Deref;
|
||||
use thiserror::Error;
|
||||
pub use utilities::*;
|
||||
pub use vars::*;
|
||||
|
||||
@@ -88,62 +91,6 @@ lazy_static! {
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
|
||||
|
||||
/// circuit related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GraphError {
|
||||
/// The wrong inputs were passed to a lookup node
|
||||
#[error("invalid inputs for a lookup node")]
|
||||
InvalidLookupInputs,
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
/// A requested node is missing in the graph
|
||||
#[error("a requested node is missing in the graph: {0}")]
|
||||
MissingNode(usize),
|
||||
/// The wrong method was called on an operation
|
||||
#[error("an unsupported method was called on node {0} ({1})")]
|
||||
OpMismatch(usize, String),
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported operation in graph")]
|
||||
UnsupportedOp,
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported datatype in graph")]
|
||||
UnsupportedDataType,
|
||||
/// A node has missing parameters
|
||||
#[error("a node is missing required params: {0}")]
|
||||
MissingParams(String),
|
||||
/// A node has missing parameters
|
||||
#[error("a node is has misformed params: {0}")]
|
||||
MisformedParams(String),
|
||||
/// Error in the configuration of the visibility of variables
|
||||
#[error("there should be at least one set of public variables")]
|
||||
Visibility,
|
||||
/// Ezkl only supports divisions by constants
|
||||
#[error("ezkl currently only supports division by constants")]
|
||||
NonConstantDiv,
|
||||
/// Ezkl only supports constant powers
|
||||
#[error("ezkl currently only supports constant exponents")]
|
||||
NonConstantPower,
|
||||
/// Error when attempting to rescale an operation
|
||||
#[error("failed to rescale inputs for {0}")]
|
||||
RescalingError(String),
|
||||
/// Error when attempting to load a model
|
||||
#[error("failed to load")]
|
||||
ModelLoad,
|
||||
/// Packing exponent is too large
|
||||
#[error("largest packing exponent exceeds max. try reducing the scale")]
|
||||
PackingExponent,
|
||||
/// Invalid Input Types
|
||||
#[error("invalid input types")]
|
||||
InvalidInputTypes,
|
||||
/// Missing results
|
||||
#[error("missing results")]
|
||||
MissingResults,
|
||||
}
|
||||
|
||||
///
|
||||
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
|
||||
/// The minimum number of rows in the grid
|
||||
@@ -310,27 +257,24 @@ impl GraphWitness {
|
||||
}
|
||||
|
||||
/// Export the ezkl witness as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
pub fn as_json(&self) -> Result<String, GraphError> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let file = std::fs::File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load {}", path.display()))?;
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
||||
let file = std::fs::File::open(path.clone())?;
|
||||
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::from_reader(reader).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
||||
// use buf writer
|
||||
let writer =
|
||||
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
@@ -595,11 +539,11 @@ impl GraphSettings {
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
pub fn as_json(&self) -> Result<String, GraphError> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
@@ -695,7 +639,7 @@ impl GraphCircuit {
|
||||
&self.core.model
|
||||
}
|
||||
///
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
||||
let f = std::fs::File::create(path)?;
|
||||
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
bincode::serialize_into(writer, &self)?;
|
||||
@@ -703,7 +647,7 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
||||
// read bytes from file
|
||||
let f = std::fs::File::open(path)?;
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
@@ -770,10 +714,7 @@ pub struct TestOnChainData {
|
||||
|
||||
impl GraphCircuit {
|
||||
///
|
||||
pub fn new(
|
||||
model: Model,
|
||||
run_args: &RunArgs,
|
||||
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
|
||||
pub fn new(model: Model, run_args: &RunArgs) -> Result<GraphCircuit, GraphError> {
|
||||
// // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
|
||||
let mut inputs: Vec<Vec<Fp>> = vec![];
|
||||
for shape in model.graph.input_shapes()? {
|
||||
@@ -820,7 +761,7 @@ impl GraphCircuit {
|
||||
model: Model,
|
||||
mut settings: GraphSettings,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
|
||||
) -> Result<GraphCircuit, GraphError> {
|
||||
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
|
||||
let mut inputs: Vec<Vec<Fp>> = vec![];
|
||||
for shape in model.graph.input_shapes()? {
|
||||
@@ -844,20 +785,14 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
/// load inputs and outputs for the model
|
||||
pub fn load_graph_witness(
|
||||
&mut self,
|
||||
data: &GraphWitness,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn load_graph_witness(&mut self, data: &GraphWitness) -> Result<(), GraphError> {
|
||||
self.graph_witness = data.clone();
|
||||
// load the module settings
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prepare the public inputs for the circuit.
|
||||
pub fn prepare_public_inputs(
|
||||
&self,
|
||||
data: &GraphWitness,
|
||||
) -> Result<Vec<Fp>, Box<dyn std::error::Error>> {
|
||||
pub fn prepare_public_inputs(&self, data: &GraphWitness) -> Result<Vec<Fp>, GraphError> {
|
||||
// the ordering here is important, we want the inputs to come before the outputs
|
||||
// as they are configured in that order as Column<Instances>
|
||||
let mut public_inputs: Vec<Fp> = vec![];
|
||||
@@ -890,7 +825,7 @@ impl GraphCircuit {
|
||||
pub fn pretty_public_inputs(
|
||||
&self,
|
||||
data: &GraphWitness,
|
||||
) -> Result<Option<PrettyElements>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Option<PrettyElements>, GraphError> {
|
||||
// dequantize the supplied data using the provided scale.
|
||||
// the ordering here is important, we want the inputs to come before the outputs
|
||||
// as they are configured in that order as Column<Instances>
|
||||
@@ -932,10 +867,7 @@ impl GraphCircuit {
|
||||
|
||||
///
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
@@ -946,7 +878,7 @@ impl GraphCircuit {
|
||||
pub fn load_graph_from_file_exclusively(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
@@ -956,7 +888,7 @@ impl GraphCircuit {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
_ => Err("Cannot use non-file data source as input for this method.".into()),
|
||||
_ => unreachable!("cannot load from on-chain data"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -965,7 +897,7 @@ impl GraphCircuit {
|
||||
pub async fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
@@ -983,14 +915,12 @@ impl GraphCircuit {
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
match &data {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
DataSource::OnChain(_) => {
|
||||
Err("Cannot use on-chain data source as input for this method.".into())
|
||||
}
|
||||
DataSource::OnChain(_) => Err(GraphError::OnChainDataSource),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1002,7 +932,7 @@ impl GraphCircuit {
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
match &data {
|
||||
DataSource::OnChain(source) => {
|
||||
let mut per_item_scale = vec![];
|
||||
@@ -1030,7 +960,7 @@ impl GraphCircuit {
|
||||
source: OnChainSource,
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
|
||||
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
|
||||
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
|
||||
@@ -1054,7 +984,7 @@ impl GraphCircuit {
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
// quantize the supplied data using the provided scale.
|
||||
let mut data: Vec<Tensor<Fp>> = vec![];
|
||||
for (((d, shape), scale), input_type) in file_data
|
||||
@@ -1085,7 +1015,7 @@ impl GraphCircuit {
|
||||
&mut self,
|
||||
file_data: &[Vec<Fp>],
|
||||
shapes: &[Vec<usize>],
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
// quantize the supplied data using the provided scale.
|
||||
let mut data: Vec<Tensor<Fp>> = vec![];
|
||||
for (d, shape) in file_data.iter().zip(shapes) {
|
||||
@@ -1112,7 +1042,7 @@ impl GraphCircuit {
|
||||
&self,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i64,
|
||||
) -> Result<u32, Box<dyn std::error::Error>> {
|
||||
) -> Result<u32, GraphError> {
|
||||
// pick the range with the largest absolute size safe_lookup_range or max_range_size
|
||||
let safe_range = std::cmp::max(
|
||||
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
|
||||
@@ -1133,7 +1063,7 @@ impl GraphCircuit {
|
||||
max_range_size: i64,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: i64,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
) -> Result<(), GraphError> {
|
||||
// load the max logrows
|
||||
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
|
||||
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
|
||||
@@ -1142,15 +1072,18 @@ impl GraphCircuit {
|
||||
|
||||
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
|
||||
|
||||
let lookup_size = (safe_lookup_range.1 - safe_lookup_range.0).abs();
|
||||
// check if has overflowed max lookup input
|
||||
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
|
||||
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
|
||||
return Err(err_string.into());
|
||||
if lookup_size > MAX_LOOKUP_ABS / lookup_safety_margin {
|
||||
return Err(GraphError::LookupRangeTooLarge(
|
||||
lookup_size.unsigned_abs() as usize
|
||||
));
|
||||
}
|
||||
|
||||
if max_range_size.abs() > MAX_LOOKUP_ABS {
|
||||
let err_string = format!("max range check size {:?} is too large", max_range_size);
|
||||
return Err(err_string.into());
|
||||
return Err(GraphError::RangeCheckTooLarge(
|
||||
max_range_size.unsigned_abs() as usize,
|
||||
));
|
||||
}
|
||||
|
||||
// These are hard lower limits, we can't overflow instances or modules constraints
|
||||
@@ -1194,12 +1127,7 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
|
||||
let err_string = format!(
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
max_logrows
|
||||
);
|
||||
debug!("{}", err_string);
|
||||
return Err(err_string.into());
|
||||
return Err(GraphError::ExtendedKTooLarge(max_logrows));
|
||||
}
|
||||
|
||||
let logrows = max_logrows;
|
||||
@@ -1286,7 +1214,7 @@ impl GraphCircuit {
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
|
||||
) -> Result<GraphWitness, GraphError> {
|
||||
let original_inputs = inputs.to_vec();
|
||||
|
||||
let visibility = VarVisibility::from_args(&self.settings().run_args)?;
|
||||
@@ -1401,7 +1329,7 @@ impl GraphCircuit {
|
||||
pub fn from_run_args(
|
||||
run_args: &RunArgs,
|
||||
model_path: &std::path::Path,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
) -> Result<Self, GraphError> {
|
||||
let model = Model::from_run_args(run_args, model_path)?;
|
||||
Self::new(model, run_args)
|
||||
}
|
||||
@@ -1412,8 +1340,11 @@ impl GraphCircuit {
|
||||
params: &GraphSettings,
|
||||
model_path: &std::path::Path,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
params.run_args.validate()?;
|
||||
) -> Result<Self, GraphError> {
|
||||
params
|
||||
.run_args
|
||||
.validate()
|
||||
.map_err(GraphError::InvalidRunArgs)?;
|
||||
let model = Model::from_run_args(¶ms.run_args, model_path)?;
|
||||
Self::new_from_settings(model, params.clone(), check_mode)
|
||||
}
|
||||
@@ -1424,7 +1355,7 @@ impl GraphCircuit {
|
||||
&mut self,
|
||||
data: &mut GraphData,
|
||||
test_on_chain_data: TestOnChainData,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
) -> Result<(), GraphError> {
|
||||
// Set up local anvil instance for reading on-chain data
|
||||
|
||||
let input_scales = self.model().graph.get_input_scales();
|
||||
@@ -1438,15 +1369,13 @@ impl GraphCircuit {
|
||||
) {
|
||||
// if not public then fail
|
||||
if self.settings().run_args.input_visibility.is_private() {
|
||||
return Err("Cannot use on-chain data source as private data".into());
|
||||
return Err(GraphError::OnChainDataSource);
|
||||
}
|
||||
|
||||
let input_data = match &data.input_data {
|
||||
DataSource::File(input_data) => input_data,
|
||||
_ => {
|
||||
return Err("Cannot use non file source as input for on-chain test.
|
||||
Manually populate on-chain data from file source instead"
|
||||
.into())
|
||||
return Err(GraphError::OnChainDataSource);
|
||||
}
|
||||
};
|
||||
// Get the flatten length of input_data
|
||||
@@ -1467,19 +1396,13 @@ impl GraphCircuit {
|
||||
) {
|
||||
// if not public then fail
|
||||
if self.settings().run_args.output_visibility.is_private() {
|
||||
return Err("Cannot use on-chain data source as private data".into());
|
||||
return Err(GraphError::OnChainDataSource);
|
||||
}
|
||||
|
||||
let output_data = match &data.output_data {
|
||||
Some(DataSource::File(output_data)) => output_data,
|
||||
Some(DataSource::OnChain(_)) => {
|
||||
return Err(
|
||||
"Cannot use on-chain data source as output for on-chain test.
|
||||
Will manually populate on-chain data from file source instead"
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
_ => return Err("No output data found".into()),
|
||||
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
|
||||
_ => return Err(GraphError::MissingDataSource),
|
||||
};
|
||||
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
|
||||
output_data,
|
||||
@@ -1522,12 +1445,10 @@ impl CircuitSize {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
pub fn as_json(&self) -> Result<String, GraphError> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use super::errors::GraphError;
|
||||
use super::extract_const_quantized_values;
|
||||
use super::node::*;
|
||||
use super::scale_to_multiplier;
|
||||
use super::vars::*;
|
||||
use super::GraphError;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
@@ -37,7 +37,6 @@ use std::collections::BTreeMap;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::error::Error;
|
||||
use std::fs;
|
||||
use std::io::Read;
|
||||
use std::path::PathBuf;
|
||||
@@ -396,7 +395,7 @@ impl ParsedNodes {
|
||||
}
|
||||
|
||||
/// Returns shapes of the computational graph's inputs
|
||||
pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
|
||||
pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
|
||||
let mut inputs = vec![];
|
||||
|
||||
for input in self.inputs.iter() {
|
||||
@@ -470,7 +469,7 @@ impl Model {
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
/// * `run_args` - [RunArgs]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, Box<dyn Error>> {
|
||||
pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, GraphError> {
|
||||
let visibility = VarVisibility::from_args(run_args)?;
|
||||
|
||||
let graph = Self::load_onnx_model(reader, run_args, &visibility)?;
|
||||
@@ -483,7 +482,7 @@ impl Model {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn save(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
|
||||
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
|
||||
let f = std::fs::File::create(path)?;
|
||||
let writer = std::io::BufWriter::new(f);
|
||||
bincode::serialize_into(writer, &self)?;
|
||||
@@ -491,7 +490,7 @@ impl Model {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn load(path: PathBuf) -> Result<Self, Box<dyn Error>> {
|
||||
pub fn load(path: PathBuf) -> Result<Self, GraphError> {
|
||||
// read bytes from file
|
||||
let mut f = std::fs::File::open(&path)?;
|
||||
let metadata = fs::metadata(&path)?;
|
||||
@@ -506,7 +505,7 @@ impl Model {
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
) -> Result<GraphSettings, GraphError> {
|
||||
let instance_shapes = self.instance_shapes()?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
debug!(
|
||||
@@ -536,7 +535,7 @@ impl Model {
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs, false, false)?;
|
||||
|
||||
@@ -583,7 +582,7 @@ impl Model {
|
||||
run_args: &RunArgs,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
) -> Result<ForwardResult, GraphError> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
@@ -601,15 +600,12 @@ impl Model {
|
||||
fn load_onnx_using_tract(
|
||||
reader: &mut dyn std::io::Read,
|
||||
run_args: &RunArgs,
|
||||
) -> Result<TractResult, Box<dyn Error>> {
|
||||
) -> Result<TractResult, GraphError> {
|
||||
use tract_onnx::{
|
||||
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
|
||||
};
|
||||
|
||||
let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| {
|
||||
error!("Error loading model: {}", e);
|
||||
GraphError::ModelLoad
|
||||
})?;
|
||||
let mut model = tract_onnx::onnx().model_for_read(reader)?;
|
||||
|
||||
let variables: std::collections::HashMap<String, usize> =
|
||||
std::collections::HashMap::from_iter(run_args.variables.clone());
|
||||
@@ -622,7 +618,7 @@ impl Model {
|
||||
if matches!(x, GenericFactoid::Any) {
|
||||
let batch_size = match variables.get("batch_size") {
|
||||
Some(x) => x,
|
||||
None => return Err("Unknown dimension batch_size in model inputs, set batch_size in variables".into()),
|
||||
None => return Err(GraphError::MissingBatchSize),
|
||||
};
|
||||
fact.shape
|
||||
.set_dim(i, tract_onnx::prelude::TDim::Val(*batch_size as i64));
|
||||
@@ -680,12 +676,12 @@ impl Model {
|
||||
reader: &mut dyn std::io::Read,
|
||||
run_args: &RunArgs,
|
||||
visibility: &VarVisibility,
|
||||
) -> Result<ParsedNodes, Box<dyn Error>> {
|
||||
) -> Result<ParsedNodes, GraphError> {
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
|
||||
|
||||
let scales = VarScales::from_args(run_args)?;
|
||||
let scales = VarScales::from_args(run_args);
|
||||
let nodes = Self::nodes_from_graph(
|
||||
&model,
|
||||
run_args,
|
||||
@@ -762,7 +758,7 @@ impl Model {
|
||||
symbol_values: &SymbolValues,
|
||||
override_input_scales: Option<Vec<crate::Scale>>,
|
||||
override_output_scales: Option<HashMap<usize, crate::Scale>>,
|
||||
) -> Result<BTreeMap<usize, NodeType>, Box<dyn Error>> {
|
||||
) -> Result<BTreeMap<usize, NodeType>, GraphError> {
|
||||
use crate::graph::node_output_shapes;
|
||||
|
||||
let mut nodes = BTreeMap::<usize, NodeType>::new();
|
||||
@@ -976,14 +972,11 @@ impl Model {
|
||||
model_path: &std::path::Path,
|
||||
data_chunks: &[GraphData],
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Vec<Tensor<f32>>>, Box<dyn Error>> {
|
||||
) -> Result<Vec<Vec<Tensor<f32>>>, GraphError> {
|
||||
use tract_onnx::tract_core::internal::IntoArcTensor;
|
||||
|
||||
let (model, _) = Model::load_onnx_using_tract(
|
||||
&mut std::fs::File::open(model_path)
|
||||
.map_err(|_| format!("failed to load {}", model_path.display()))?,
|
||||
run_args,
|
||||
)?;
|
||||
let (model, _) =
|
||||
Model::load_onnx_using_tract(&mut std::fs::File::open(model_path)?, run_args)?;
|
||||
|
||||
let datum_types: Vec<DatumType> = model
|
||||
.input_outlets()?
|
||||
@@ -1011,15 +1004,8 @@ impl Model {
|
||||
/// # Arguments
|
||||
/// * `params` - A [GraphSettings] struct holding parsed CLI arguments.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn from_run_args(
|
||||
run_args: &RunArgs,
|
||||
model: &std::path::Path,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
Model::new(
|
||||
&mut std::fs::File::open(model)
|
||||
.map_err(|_| format!("failed to load {}", model.display()))?,
|
||||
run_args,
|
||||
)
|
||||
pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result<Self, GraphError> {
|
||||
Model::new(&mut std::fs::File::open(model)?, run_args)
|
||||
}
|
||||
|
||||
/// Configures a model for the circuit
|
||||
@@ -1031,7 +1017,7 @@ impl Model {
|
||||
meta: &mut ConstraintSystem<Fp>,
|
||||
vars: &ModelVars<Fp>,
|
||||
settings: &GraphSettings,
|
||||
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
|
||||
) -> Result<PolyConfig<Fp>, GraphError> {
|
||||
debug!("configuring model");
|
||||
|
||||
let lookup_range = settings.run_args.lookup_range;
|
||||
@@ -1093,7 +1079,7 @@ impl Model {
|
||||
vars: &mut ModelVars<Fp>,
|
||||
witnessed_outputs: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
|
||||
info!("model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
@@ -1103,7 +1089,11 @@ impl Model {
|
||||
let input_shapes = self.graph.input_shapes()?;
|
||||
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
|
||||
if self.visibility.input.is_public() {
|
||||
let instance = vars.instance.as_ref().ok_or("no instance")?.clone();
|
||||
let instance = vars
|
||||
.instance
|
||||
.as_ref()
|
||||
.ok_or(GraphError::MissingInstances)?
|
||||
.clone();
|
||||
results.insert(*input_idx, vec![instance]);
|
||||
vars.increment_instance_idx();
|
||||
} else {
|
||||
@@ -1123,7 +1113,12 @@ impl Model {
|
||||
let outputs = layouter.assign_region(
|
||||
|| "model",
|
||||
|region| {
|
||||
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
|
||||
let mut thread_safe_region = RegionCtx::new_with_constants(
|
||||
region,
|
||||
0,
|
||||
run_args.num_inner_cols,
|
||||
original_constants.clone(),
|
||||
);
|
||||
// we need to do this as this loop is called multiple times
|
||||
vars.set_instance_idx(instance_idx);
|
||||
|
||||
@@ -1147,24 +1142,31 @@ impl Model {
|
||||
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
let comparators = if run_args.output_visibility == Visibility::Public {
|
||||
let res = vars.instance.as_ref().ok_or("no instance")?.clone();
|
||||
let res = vars
|
||||
.instance
|
||||
.as_ref()
|
||||
.ok_or(GraphError::MissingInstances)?
|
||||
.clone();
|
||||
vars.increment_instance_idx();
|
||||
res
|
||||
} else {
|
||||
// if witnessed_outputs is of len less than i error
|
||||
if witnessed_outputs.len() <= i {
|
||||
return Err("you provided insufficient witness values to generate a fixed output".into());
|
||||
return Err(GraphError::InsufficientWitnessValues);
|
||||
}
|
||||
witnessed_outputs[i].clone()
|
||||
};
|
||||
|
||||
config.base.layout(
|
||||
&mut thread_safe_region,
|
||||
&[output.clone(), comparators],
|
||||
Box::new(HybridOp::RangeCheck(tolerance)),
|
||||
)
|
||||
config
|
||||
.base
|
||||
.layout(
|
||||
&mut thread_safe_region,
|
||||
&[output.clone(), comparators],
|
||||
Box::new(HybridOp::RangeCheck(tolerance)),
|
||||
)
|
||||
.map_err(|e| e.into())
|
||||
})
|
||||
.collect::<Result<Vec<_>,_>>();
|
||||
.collect::<Result<Vec<_>, GraphError>>();
|
||||
res.map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
@@ -1178,7 +1180,6 @@ impl Model {
|
||||
|
||||
Ok(outputs)
|
||||
},
|
||||
|
||||
)?;
|
||||
|
||||
let duration = start_time.elapsed();
|
||||
@@ -1192,7 +1193,7 @@ impl Model {
|
||||
config: &mut ModelConfig,
|
||||
region: &mut RegionCtx<Fp>,
|
||||
results: &mut BTreeMap<usize, Vec<ValTensor<Fp>>>,
|
||||
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
|
||||
// index over results to get original inputs
|
||||
let orig_inputs: BTreeMap<usize, _> = results
|
||||
.clone()
|
||||
@@ -1237,7 +1238,10 @@ impl Model {
|
||||
let res = if node.is_constant() && node.num_uses() == 1 {
|
||||
log::debug!("node {} is a constant with 1 use", n.idx);
|
||||
let mut node = n.clone();
|
||||
let c = node.opkind.get_mutable_constant().ok_or("no constant")?;
|
||||
let c = node
|
||||
.opkind
|
||||
.get_mutable_constant()
|
||||
.ok_or(GraphError::MissingConstants)?;
|
||||
Some(c.quantized_values.clone().try_into()?)
|
||||
} else {
|
||||
config
|
||||
@@ -1394,7 +1398,7 @@ impl Model {
|
||||
inputs: &[ValTensor<Fp>],
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
) -> Result<DummyPassRes, GraphError> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
@@ -1549,7 +1553,7 @@ impl Model {
|
||||
}
|
||||
|
||||
/// Shapes of the computational graph's public inputs (if any)
|
||||
pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
|
||||
pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
|
||||
let mut instance_shapes = vec![];
|
||||
if self.visibility.input.is_public() {
|
||||
instance_shapes.extend(self.graph.input_shapes()?);
|
||||
|
||||
@@ -11,6 +11,7 @@ use halo2curves::bn256::{Fr as Fp, G1Affine};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::errors::GraphError;
|
||||
use super::{VarVisibility, Visibility};
|
||||
|
||||
/// poseidon len to hash in tree
|
||||
@@ -295,7 +296,7 @@ impl GraphModules {
|
||||
element_visibility: &Visibility,
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
|
||||
) -> Result<ModuleForwardResult, GraphError> {
|
||||
let mut poseidon_hash = None;
|
||||
let mut polycommit = None;
|
||||
|
||||
|
||||
@@ -8,11 +8,14 @@ use super::Visibility;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::CircuitError;
|
||||
use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::errors::GraphError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
use crate::tensor::TensorError;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -22,7 +25,6 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::collections::BTreeMap;
|
||||
use std::error::Error;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::fmt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -65,7 +67,7 @@ impl Op<Fp> for Rescaled {
|
||||
format!("RESCALED INPUT ({})", self.inner.as_string())
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let in_scales = in_scales
|
||||
.into_iter()
|
||||
.zip(self.scale.iter())
|
||||
@@ -80,11 +82,9 @@ impl Op<Fp> for Rescaled {
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
if self.scale.len() != values.len() {
|
||||
return Err(Box::new(TensorError::DimMismatch(
|
||||
"rescaled inputs".to_string(),
|
||||
)));
|
||||
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
|
||||
}
|
||||
|
||||
let res =
|
||||
@@ -210,7 +210,7 @@ impl Op<Fp> for RebaseScale {
|
||||
)
|
||||
}
|
||||
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.target_scale)
|
||||
}
|
||||
|
||||
@@ -219,11 +219,11 @@ impl Op<Fp> for RebaseScale {
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
let original_res = self
|
||||
.inner
|
||||
.layout(config, region, values)?
|
||||
.ok_or("no inner layout")?;
|
||||
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
@@ -306,7 +306,7 @@ impl SupportedOp {
|
||||
fn homogenous_rescale(
|
||||
&self,
|
||||
in_scales: Vec<crate::Scale>,
|
||||
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
let op = self.clone_dyn();
|
||||
@@ -372,7 +372,7 @@ impl Op<Fp> for SupportedOp {
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
self.as_op().layout(config, region, values)
|
||||
}
|
||||
|
||||
@@ -400,7 +400,7 @@ impl Op<Fp> for SupportedOp {
|
||||
self
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
}
|
||||
@@ -478,7 +478,7 @@ impl Node {
|
||||
symbol_values: &SymbolValues,
|
||||
div_rebasing: bool,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
) -> Result<Self, GraphError> {
|
||||
trace!("Create {:?}", node);
|
||||
trace!("Create op {:?}", node.op);
|
||||
|
||||
@@ -504,10 +504,15 @@ impl Node {
|
||||
input_ids
|
||||
.iter()
|
||||
.map(|(i, _)| {
|
||||
inputs.push(other_nodes.get(i).ok_or("input not found")?.clone());
|
||||
inputs.push(
|
||||
other_nodes
|
||||
.get(i)
|
||||
.ok_or(GraphError::MissingInput(idx))?
|
||||
.clone(),
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let (mut opkind, deleted_indices) = new_op_from_onnx(
|
||||
idx,
|
||||
@@ -544,10 +549,10 @@ impl Node {
|
||||
let idx = inputs
|
||||
.iter()
|
||||
.position(|x| *idx == x.idx())
|
||||
.ok_or("input not found")?;
|
||||
.ok_or(GraphError::MissingInput(*idx))?;
|
||||
Ok(inputs[idx].out_scales()[*outlet])
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let homogenous_inputs = opkind.requires_homogenous_input_scales();
|
||||
// automatically increases a constant's scale if it is only used once and
|
||||
@@ -558,7 +563,7 @@ impl Node {
|
||||
if inputs.len() > input {
|
||||
let input_node = other_nodes
|
||||
.get_mut(&inputs[input].idx())
|
||||
.ok_or("input not found")?;
|
||||
.ok_or(GraphError::MissingInput(idx))?;
|
||||
let input_opkind = &mut input_node.opkind();
|
||||
if let Some(constant) = input_opkind.get_mutable_constant() {
|
||||
rescale_const_with_single_use(
|
||||
@@ -615,10 +620,10 @@ fn rescale_const_with_single_use(
|
||||
in_scales: Vec<crate::Scale>,
|
||||
param_visibility: &Visibility,
|
||||
num_uses: usize,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
) -> Result<(), GraphError> {
|
||||
if num_uses == 1 {
|
||||
let current_scale = constant.out_scale(vec![])?;
|
||||
let scale_max = in_scales.iter().max().ok_or("no scales")?;
|
||||
let scale_max = in_scales.iter().max().ok_or(GraphError::MissingScale)?;
|
||||
if scale_max > ¤t_scale {
|
||||
let raw_values = constant.raw_values.clone();
|
||||
constant.quantized_values =
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::GraphError;
|
||||
use super::errors::GraphError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::VarScales;
|
||||
use super::{Rescaled, SupportedOp, Visibility};
|
||||
@@ -16,7 +15,6 @@ use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::{debug, warn};
|
||||
use std::error::Error;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::sync::Arc;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -92,7 +90,7 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
|
||||
pub fn node_output_shapes(
|
||||
node: &OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
symbol_values: &SymbolValues,
|
||||
) -> Result<Vec<Vec<usize>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Vec<usize>>, GraphError> {
|
||||
let mut shapes = Vec::new();
|
||||
let outputs = node.outputs.to_vec();
|
||||
for output in outputs {
|
||||
@@ -109,7 +107,7 @@ use tract_onnx::prelude::SymbolValues;
|
||||
/// Extracts the raw values from a tensor.
|
||||
pub fn extract_tensor_value(
|
||||
input: Arc<tract_onnx::prelude::Tensor>,
|
||||
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Tensor<f32>, GraphError> {
|
||||
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
|
||||
let dt = input.datum_type();
|
||||
@@ -194,20 +192,20 @@ pub fn extract_tensor_value(
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
|
||||
|
||||
let cast: Result<Vec<f32>, &str> = vec
|
||||
let cast: Result<Vec<f32>, GraphError> = vec
|
||||
.par_iter()
|
||||
.map(|x| match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err("could not evaluate tdim"),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
const_value = Tensor::<f32>::new(Some(&cast?), &dims)?;
|
||||
}
|
||||
_ => return Err("unsupported data type".into()),
|
||||
_ => return Err(GraphError::UnsupportedDataType(0, format!("{:?}", dt))),
|
||||
}
|
||||
const_value.reshape(&dims)?;
|
||||
|
||||
@@ -219,12 +217,12 @@ fn load_op<C: tract_onnx::prelude::Op + Clone>(
|
||||
op: &dyn tract_onnx::prelude::Op,
|
||||
idx: usize,
|
||||
name: String,
|
||||
) -> Result<C, Box<dyn std::error::Error>> {
|
||||
) -> Result<C, GraphError> {
|
||||
// Extract the slope layer hyperparams
|
||||
let op: &C = match op.downcast_ref::<C>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, name)));
|
||||
return Err(GraphError::OpMismatch(idx, name));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -247,7 +245,7 @@ pub fn new_op_from_onnx(
|
||||
inputs: &mut [super::NodeType],
|
||||
symbol_values: &SymbolValues,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
|
||||
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
|
||||
use tract_onnx::tract_core::ops::array::Trilu;
|
||||
|
||||
use crate::circuit::InputType;
|
||||
@@ -260,7 +258,7 @@ pub fn new_op_from_onnx(
|
||||
let mut replace_const = |scale: crate::Scale,
|
||||
index: usize,
|
||||
default_op: SupportedOp|
|
||||
-> Result<SupportedOp, Box<dyn std::error::Error>> {
|
||||
-> Result<SupportedOp, GraphError> {
|
||||
let mut constant = inputs[index].opkind();
|
||||
let constant = constant.get_mutable_constant();
|
||||
if let Some(c) = constant {
|
||||
@@ -285,19 +283,13 @@ pub fn new_op_from_onnx(
|
||||
deleted_indices.push(1);
|
||||
let raw_values = &c.raw_values;
|
||||
if raw_values.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"shift left".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
|
||||
}
|
||||
SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
idx,
|
||||
"ShiftLeft".to_string(),
|
||||
)));
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
|
||||
}
|
||||
}
|
||||
"ShiftRight" => {
|
||||
@@ -307,19 +299,13 @@ pub fn new_op_from_onnx(
|
||||
deleted_indices.push(1);
|
||||
let raw_values = &c.raw_values;
|
||||
if raw_values.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"shift right".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
|
||||
}
|
||||
SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
idx,
|
||||
"ShiftRight".to_string(),
|
||||
)));
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
|
||||
}
|
||||
}
|
||||
"MultiBroadcastTo" => {
|
||||
@@ -337,7 +323,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
for (i, input) in inputs.iter_mut().enumerate() {
|
||||
if !input.opkind().is_constant() {
|
||||
return Err("Range only supports constant inputs in a zk circuit".into());
|
||||
return Err(GraphError::NonConstantRange);
|
||||
} else {
|
||||
input.decrement_use();
|
||||
deleted_indices.push(i);
|
||||
@@ -348,7 +334,7 @@ pub fn new_op_from_onnx(
|
||||
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
|
||||
let input_ops = input_ops
|
||||
.iter()
|
||||
.map(|x| x.get_constant().ok_or("Range requires constant inputs"))
|
||||
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let start = input_ops[0].raw_values.map(|x| x as usize)[0];
|
||||
@@ -375,11 +361,11 @@ pub fn new_op_from_onnx(
|
||||
deleted_indices.push(1);
|
||||
let raw_values = &c.raw_values;
|
||||
if raw_values.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "trilu".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "trilu".to_string()));
|
||||
}
|
||||
raw_values[0] as i32
|
||||
} else {
|
||||
return Err("we only support constant inputs for trilu diagonal".into());
|
||||
return Err(GraphError::NonConstantTrilu);
|
||||
};
|
||||
|
||||
SupportedOp::Linear(PolyOp::Trilu { upper, k: diagonal })
|
||||
@@ -387,7 +373,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
"Gather" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "gather".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
|
||||
};
|
||||
let op = load_op::<Gather>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
@@ -456,10 +442,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"ScatterElements" => {
|
||||
if inputs.len() != 3 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"scatter elements".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
|
||||
};
|
||||
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
@@ -494,10 +477,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"ScatterNd" => {
|
||||
if inputs.len() != 3 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"scatter nd".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
|
||||
};
|
||||
// just verify it deserializes correctly
|
||||
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
|
||||
@@ -529,10 +509,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
"GatherNd" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"gather nd".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
|
||||
};
|
||||
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
|
||||
let batch_dims = op.batch_dims;
|
||||
@@ -566,10 +543,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
"GatherElements" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"gather elements".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
|
||||
};
|
||||
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
@@ -615,10 +589,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
idx,
|
||||
"MoveAxis".to_string(),
|
||||
)))
|
||||
return Err(GraphError::OpMismatch(idx, "MoveAxis".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -654,7 +625,9 @@ pub fn new_op_from_onnx(
|
||||
| DatumType::U32
|
||||
| DatumType::U64 => 0,
|
||||
DatumType::F16 | DatumType::F32 | DatumType::F64 => scales.params,
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
_ => {
|
||||
return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt)));
|
||||
}
|
||||
};
|
||||
|
||||
// if all raw_values are round then set scale to 0
|
||||
@@ -672,7 +645,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<ArgMax(false)>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "argmax".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
@@ -682,7 +655,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<ArgMin(false)>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "argmin".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
@@ -692,7 +665,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<Min>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
@@ -701,7 +674,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<Max>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
@@ -710,7 +683,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<Prod>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "prod".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "prod".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
@@ -727,7 +700,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<Sum>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "sum".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "sum".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
@@ -736,10 +709,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Reduce<MeanOfSquares>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"mean of squares".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "mean of squares".to_string()));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
@@ -759,7 +729,7 @@ pub fn new_op_from_onnx(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
@@ -768,10 +738,10 @@ pub fn new_op_from_onnx(
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
@@ -790,7 +760,7 @@ pub fn new_op_from_onnx(
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
}
|
||||
"Min" => {
|
||||
@@ -805,7 +775,7 @@ pub fn new_op_from_onnx(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
@@ -814,10 +784,10 @@ pub fn new_op_from_onnx(
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
@@ -834,7 +804,7 @@ pub fn new_op_from_onnx(
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
@@ -855,10 +825,7 @@ pub fn new_op_from_onnx(
|
||||
let leaky_op: &LeakyRelu = match leaky_op.0.downcast_ref::<LeakyRelu>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
idx,
|
||||
"leaky relu".to_string(),
|
||||
)));
|
||||
return Err(GraphError::OpMismatch(idx, "leaky relu".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -867,7 +834,7 @@ pub fn new_op_from_onnx(
|
||||
})
|
||||
}
|
||||
"Scan" => {
|
||||
return Err("scan should never be analyzed explicitly".into());
|
||||
unreachable!();
|
||||
}
|
||||
"QuantizeLinearU8" | "DequantizeLinearF32" => {
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
@@ -932,7 +899,9 @@ pub fn new_op_from_onnx(
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Source" => {
|
||||
let (scale, datum_type) = match node.outputs[0].fact.datum_type {
|
||||
let dt = node.outputs[0].fact.datum_type;
|
||||
|
||||
let (scale, datum_type) = match dt {
|
||||
DatumType::Bool => (0, InputType::Bool),
|
||||
DatumType::TDim => (0, InputType::TDim),
|
||||
DatumType::I64
|
||||
@@ -946,7 +915,7 @@ pub fn new_op_from_onnx(
|
||||
DatumType::F16 => (scales.input, InputType::F16),
|
||||
DatumType::F32 => (scales.input, InputType::F32),
|
||||
DatumType::F64 => (scales.input, InputType::F64),
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
|
||||
};
|
||||
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
|
||||
}
|
||||
@@ -985,7 +954,7 @@ pub fn new_op_from_onnx(
|
||||
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
|
||||
}
|
||||
}
|
||||
"Add" => SupportedOp::Linear(PolyOp::Add),
|
||||
@@ -1001,7 +970,7 @@ pub fn new_op_from_onnx(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_idx.len() > 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "mul".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
|
||||
}
|
||||
|
||||
if const_idx.len() == 1 {
|
||||
@@ -1027,17 +996,14 @@ pub fn new_op_from_onnx(
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::Less)
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "less".to_string())));
|
||||
return Err(GraphError::InvalidDims(idx, "less".to_string()));
|
||||
}
|
||||
}
|
||||
"LessEqual" => {
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::LessEqual)
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"less equal".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
|
||||
}
|
||||
}
|
||||
"Greater" => {
|
||||
@@ -1045,10 +1011,7 @@ pub fn new_op_from_onnx(
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::Greater)
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"greater".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
|
||||
}
|
||||
}
|
||||
"GreaterEqual" => {
|
||||
@@ -1056,10 +1019,7 @@ pub fn new_op_from_onnx(
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::GreaterEqual)
|
||||
} else {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"greater equal".to_string(),
|
||||
)));
|
||||
return Err(GraphError::InvalidDims(idx, "greater equal".to_string()));
|
||||
}
|
||||
}
|
||||
"EinSum" => {
|
||||
@@ -1067,7 +1027,7 @@ pub fn new_op_from_onnx(
|
||||
let op: &EinSum = match node.op().downcast_ref::<EinSum>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "einsum".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "einsum".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1081,7 +1041,7 @@ pub fn new_op_from_onnx(
|
||||
let softmax_op: &Softmax = match node.op().downcast_ref::<Softmax>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "softmax".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1100,7 +1060,7 @@ pub fn new_op_from_onnx(
|
||||
let sumpool_node: &MaxPool = match op.downcast_ref() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Maxpool".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Maxpool".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1108,9 +1068,9 @@ pub fn new_op_from_onnx(
|
||||
|
||||
// only support pytorch type formatting for now
|
||||
if pool_spec.data_format != DataFormat::NCHW {
|
||||
return Err(Box::new(GraphError::MissingParams(
|
||||
return Err(GraphError::MissingParams(
|
||||
"data in wrong format".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
|
||||
let stride = pool_spec
|
||||
@@ -1122,7 +1082,7 @@ pub fn new_op_from_onnx(
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
return Err(GraphError::MissingParams("padding".to_string()));
|
||||
}
|
||||
};
|
||||
let kernel_shape = &pool_spec.kernel_shape;
|
||||
@@ -1170,15 +1130,15 @@ pub fn new_op_from_onnx(
|
||||
let conv_node: &Conv = match node.op().downcast_ref::<Conv>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "conv".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "conv".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(dilations) = &conv_node.pool_spec.dilations {
|
||||
if dilations.iter().any(|x| *x != 1) {
|
||||
return Err(Box::new(GraphError::MisformedParams(
|
||||
return Err(GraphError::MisformedParams(
|
||||
"non unit dilations not supported".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1186,15 +1146,15 @@ pub fn new_op_from_onnx(
|
||||
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|
||||
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
|
||||
{
|
||||
return Err(Box::new(GraphError::MisformedParams(
|
||||
return Err(GraphError::MisformedParams(
|
||||
"data or kernel in wrong format".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
|
||||
let stride = match conv_node.pool_spec.strides.clone() {
|
||||
Some(s) => s.to_vec(),
|
||||
None => {
|
||||
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
|
||||
return Err(GraphError::MissingParams("strides".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1203,7 +1163,7 @@ pub fn new_op_from_onnx(
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
return Err(GraphError::MissingParams("padding".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1234,30 +1194,30 @@ pub fn new_op_from_onnx(
|
||||
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "deconv".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "deconv".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(dilations) = &deconv_node.pool_spec.dilations {
|
||||
if dilations.iter().any(|x| *x != 1) {
|
||||
return Err(Box::new(GraphError::MisformedParams(
|
||||
return Err(GraphError::MisformedParams(
|
||||
"non unit dilations not supported".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|
||||
|| (deconv_node.kernel_format != KernelFormat::OIHW)
|
||||
{
|
||||
return Err(Box::new(GraphError::MisformedParams(
|
||||
return Err(GraphError::MisformedParams(
|
||||
"data or kernel in wrong format".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
|
||||
let stride = match deconv_node.pool_spec.strides.clone() {
|
||||
Some(s) => s.to_vec(),
|
||||
None => {
|
||||
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
|
||||
return Err(GraphError::MissingParams("strides".to_string()));
|
||||
}
|
||||
};
|
||||
let padding = match &deconv_node.pool_spec.padding {
|
||||
@@ -1265,7 +1225,7 @@ pub fn new_op_from_onnx(
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
return Err(GraphError::MissingParams("padding".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1295,10 +1255,7 @@ pub fn new_op_from_onnx(
|
||||
let downsample_node: Downsample = match node.op().downcast_ref::<Downsample>() {
|
||||
Some(b) => b.clone(),
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
idx,
|
||||
"downsample".to_string(),
|
||||
)));
|
||||
return Err(GraphError::OpMismatch(idx, "downsample".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1323,7 +1280,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
// check if optional scale factor is present
|
||||
if inputs.len() != 2 && inputs.len() != 3 {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Resize".to_string()));
|
||||
}
|
||||
|
||||
let scale_factor_node = // find optional_scales_input in the string and extract the value inside the Some
|
||||
@@ -1337,7 +1294,7 @@ pub fn new_op_from_onnx(
|
||||
.collect::<Vec<_>>()[1]
|
||||
.split(')')
|
||||
.collect::<Vec<_>>()[0]
|
||||
.parse::<usize>()?)
|
||||
.parse::<usize>().map_err(|_| GraphError::OpMismatch(idx, "Resize".to_string()))?)
|
||||
};
|
||||
|
||||
let scale_factor = if let Some(scale_factor_node) = scale_factor_node {
|
||||
@@ -1345,7 +1302,7 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
c.map(|x| x as usize).into_iter().collect::<Vec<usize>>()
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "Resize".to_string()));
|
||||
}
|
||||
} else {
|
||||
// default
|
||||
@@ -1369,7 +1326,7 @@ pub fn new_op_from_onnx(
|
||||
let sumpool_node: &SumPool = match op.downcast_ref() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "sumpool".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "sumpool".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1377,9 +1334,9 @@ pub fn new_op_from_onnx(
|
||||
|
||||
// only support pytorch type formatting for now
|
||||
if pool_spec.data_format != DataFormat::NCHW {
|
||||
return Err(Box::new(GraphError::MissingParams(
|
||||
return Err(GraphError::MissingParams(
|
||||
"data in wrong format".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
|
||||
let stride = pool_spec
|
||||
@@ -1391,7 +1348,7 @@ pub fn new_op_from_onnx(
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
return Err(GraphError::MissingParams("padding".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1411,7 +1368,7 @@ pub fn new_op_from_onnx(
|
||||
let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, "pad".to_string())));
|
||||
return Err(GraphError::OpMismatch(idx, "pad".to_string()));
|
||||
}
|
||||
};
|
||||
// we only support constant 0 padding
|
||||
@@ -1420,9 +1377,9 @@ pub fn new_op_from_onnx(
|
||||
tract_onnx::prelude::Tensor::zero::<f32>(&[])?,
|
||||
))
|
||||
{
|
||||
return Err(Box::new(GraphError::MisformedParams(
|
||||
return Err(GraphError::MisformedParams(
|
||||
"pad mode or pad type".to_string(),
|
||||
)));
|
||||
));
|
||||
}
|
||||
|
||||
SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
|
||||
@@ -1473,7 +1430,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
const_value: Tensor<f32>,
|
||||
scale: crate::Scale,
|
||||
visibility: &Visibility,
|
||||
) -> Result<Tensor<F>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Tensor<F>, TensorError> {
|
||||
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
|
||||
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
@@ -1492,7 +1449,7 @@ use crate::tensor::ValTensor;
|
||||
pub(crate) fn split_valtensor(
|
||||
values: &ValTensor<Fp>,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<ValTensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
|
||||
let mut tensors: Vec<ValTensor<Fp>> = Vec::new();
|
||||
let mut start = 0;
|
||||
for shape in shapes {
|
||||
@@ -1510,7 +1467,7 @@ pub fn homogenize_input_scales(
|
||||
op: Box<dyn Op<Fp>>,
|
||||
input_scales: Vec<crate::Scale>,
|
||||
inputs_to_scale: Vec<usize>,
|
||||
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let relevant_input_scales = input_scales
|
||||
.clone()
|
||||
.into_iter()
|
||||
@@ -1529,7 +1486,7 @@ pub fn homogenize_input_scales(
|
||||
|
||||
let mut multipliers: Vec<u128> = vec![1; input_scales.len()];
|
||||
|
||||
let max_scale = input_scales.iter().max().ok_or("no max scale")?;
|
||||
let max_scale = input_scales.iter().max().ok_or(GraphError::MissingScale)?;
|
||||
let _ = input_scales
|
||||
.iter()
|
||||
.enumerate()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::error::Error;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::tensor::TensorType;
|
||||
@@ -17,6 +16,8 @@ use pyo3::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use self::errors::GraphError;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
|
||||
@@ -261,12 +262,12 @@ impl VarScales {
|
||||
}
|
||||
|
||||
/// Place in [VarScales] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
|
||||
Ok(Self {
|
||||
pub fn from_args(args: &RunArgs) -> Self {
|
||||
Self {
|
||||
input: args.input_scale,
|
||||
params: args.param_scale,
|
||||
rebase_multiplier: args.scale_rebase_multiplier,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,15 +304,13 @@ impl Default for VarVisibility {
|
||||
impl VarVisibility {
|
||||
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
/// Place in [VarVisibility] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
|
||||
let input_vis = &args.input_visibility;
|
||||
let params_vis = &args.param_visibility;
|
||||
let output_vis = &args.output_visibility;
|
||||
|
||||
if params_vis.is_public() {
|
||||
return Err(
|
||||
"public visibility for params is deprecated, please use `fixed` instead".into(),
|
||||
);
|
||||
return Err(GraphError::ParamsPublicVisibility);
|
||||
}
|
||||
|
||||
if !output_vis.is_public()
|
||||
@@ -327,7 +326,7 @@ impl VarVisibility {
|
||||
& !params_vis.is_polycommit()
|
||||
& !input_vis.is_polycommit()
|
||||
{
|
||||
return Err(Box::new(GraphError::Visibility));
|
||||
return Err(GraphError::Visibility);
|
||||
}
|
||||
Ok(Self {
|
||||
input: input_vis.clone(),
|
||||
|
||||
55
src/lib.rs
55
src/lib.rs
@@ -28,6 +28,59 @@
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
/// Error type
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum EZKLError {
|
||||
#[error("[aggregation] {0}")]
|
||||
AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError),
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[eth] {0}")]
|
||||
EthError(#[from] eth::EthError),
|
||||
#[error("[graph] {0}")]
|
||||
GraphError(#[from] graph::errors::GraphError),
|
||||
#[error("[pfsys] {0}")]
|
||||
PfsysError(#[from] pfsys::errors::PfsysError),
|
||||
#[error("[circuit] {0}")]
|
||||
CircuitError(#[from] circuit::errors::CircuitError),
|
||||
#[error("[tensor] {0}")]
|
||||
TensorError(#[from] tensor::errors::TensorError),
|
||||
#[error("[module] {0}")]
|
||||
ModuleError(#[from] circuit::modules::errors::ModuleError),
|
||||
#[error("[io] {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
#[error("[json] {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
#[error("[utf8] {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[reqwest] {0}")]
|
||||
ReqwestError(#[from] reqwest::Error),
|
||||
#[error("[fmt] {0}")]
|
||||
FmtError(#[from] std::fmt::Error),
|
||||
#[error("[halo2] {0}")]
|
||||
Halo2Error(#[from] halo2_proofs::plonk::Error),
|
||||
#[error("[Uncategorized] {0}")]
|
||||
UncategorizedError(String),
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[execute] {0}")]
|
||||
ExecutionError(#[from] execute::ExecutionError),
|
||||
#[error("[srs] {0}")]
|
||||
SrsError(#[from] pfsys::srs::SrsError),
|
||||
}
|
||||
|
||||
impl From<&str> for EZKLError {
|
||||
fn from(s: &str) -> Self {
|
||||
EZKLError::UncategorizedError(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for EZKLError {
|
||||
fn from(s: String) -> Self {
|
||||
EZKLError::UncategorizedError(s)
|
||||
}
|
||||
}
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
@@ -248,7 +301,7 @@ impl Default for RunArgs {
|
||||
|
||||
impl RunArgs {
|
||||
///
|
||||
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.param_visibility == Visibility::Public {
|
||||
return Err(
|
||||
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
|
||||
|
||||
27
src/pfsys/errors.rs
Normal file
27
src/pfsys/errors.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Error type for the pfsys module
|
||||
#[derive(Error, Debug)]
|
||||
pub enum PfsysError {
|
||||
/// Failed to save the proof
|
||||
#[error("failed to save the proof: {0}")]
|
||||
SaveProof(String),
|
||||
/// Failed to load the proof
|
||||
#[error("failed to load the proof: {0}")]
|
||||
LoadProof(String),
|
||||
/// Halo2 error
|
||||
#[error("[halo2] {0}")]
|
||||
Halo2Error(#[from] halo2_proofs::plonk::Error),
|
||||
/// Failed to write point to transcript
|
||||
#[error("failed to write point to transcript: {0}")]
|
||||
WritePoint(String),
|
||||
/// Invalid commitment scheme
|
||||
#[error("invalid commitment scheme")]
|
||||
InvalidCommitmentScheme,
|
||||
/// Failed to load vk from file
|
||||
#[error("failed to load vk from file: {0}")]
|
||||
LoadVk(String),
|
||||
/// Failed to load pk from file
|
||||
#[error("failed to load pk from file: {0}")]
|
||||
LoadPk(String),
|
||||
}
|
||||
@@ -10,17 +10,14 @@ pub enum EvmVerificationError {
|
||||
#[error("Solidity verifier found the proof invalid")]
|
||||
InvalidProof,
|
||||
/// If the Solidity verifier threw and error (e.g. OutOfGas)
|
||||
#[error("Execution of Solidity code failed")]
|
||||
SolidityExecution,
|
||||
/// EVM execution errors
|
||||
#[error("EVM execution of raw code failed")]
|
||||
RawExecution,
|
||||
#[error("Execution of Solidity code failed: {0}")]
|
||||
SolidityExecution(String),
|
||||
/// EVM verify errors
|
||||
#[error("evm verification reverted")]
|
||||
Reverted,
|
||||
#[error("evm verification reverted: {0}")]
|
||||
Reverted(String),
|
||||
/// EVM verify errors
|
||||
#[error("evm deployment failed")]
|
||||
Deploy,
|
||||
#[error("evm deployment failed: {0}")]
|
||||
DeploymentFailed(String),
|
||||
/// Invalid Visibility
|
||||
#[error("Invalid visibility")]
|
||||
InvalidVisibility,
|
||||
|
||||
@@ -4,6 +4,11 @@ pub mod evm;
|
||||
/// SRS generation, processing, verification and downloading
|
||||
pub mod srs;
|
||||
|
||||
/// errors related to pfsys
|
||||
pub mod errors;
|
||||
|
||||
pub use errors::PfsysError;
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::graph::GraphWitness;
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
@@ -32,7 +37,6 @@ use serde::{Deserialize, Serialize};
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use snark_verifier::verifier::plonk::PlonkProtocol;
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufReader, BufWriter, Cursor, Write};
|
||||
use std::ops::Deref;
|
||||
@@ -364,24 +368,28 @@ where
|
||||
}
|
||||
|
||||
/// Saves the Proof to a specified `proof_path`.
|
||||
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
|
||||
let file = std::fs::File::create(proof_path)?;
|
||||
pub fn save(&self, proof_path: &PathBuf) -> Result<(), PfsysError> {
|
||||
let file = std::fs::File::create(proof_path)
|
||||
.map_err(|e| PfsysError::SaveProof(format!("{}", e)))?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(&mut writer, &self)?;
|
||||
serde_json::to_writer(&mut writer, &self)
|
||||
.map_err(|e| PfsysError::SaveProof(format!("{}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load a json serialized proof from the provided path.
|
||||
pub fn load<Scheme: CommitmentScheme<Curve = C, Scalar = F>>(
|
||||
proof_path: &PathBuf,
|
||||
) -> Result<Self, Box<dyn Error>>
|
||||
) -> Result<Self, PfsysError>
|
||||
where
|
||||
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
|
||||
{
|
||||
trace!("reading proof");
|
||||
let file = std::fs::File::open(proof_path)?;
|
||||
let file =
|
||||
std::fs::File::open(proof_path).map_err(|e| PfsysError::LoadProof(format!("{}", e)))?;
|
||||
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
let proof: Self = serde_json::from_reader(reader)?;
|
||||
let proof: Self =
|
||||
serde_json::from_reader(reader).map_err(|e| PfsysError::LoadProof(format!("{}", e)))?;
|
||||
Ok(proof)
|
||||
}
|
||||
}
|
||||
@@ -541,7 +549,7 @@ pub fn create_proof_circuit<
|
||||
transcript_type: TranscriptType,
|
||||
split: Option<ProofSplitCommit>,
|
||||
protocol: Option<PlonkProtocol<Scheme::Curve>>,
|
||||
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
|
||||
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
|
||||
where
|
||||
Scheme::ParamsVerifier: 'params,
|
||||
Scheme::Scalar: Serialize
|
||||
@@ -626,7 +634,7 @@ pub fn swap_proof_commitments<
|
||||
>(
|
||||
snark: &Snark<Scheme::Scalar, Scheme::Curve>,
|
||||
commitments: &[Scheme::Curve],
|
||||
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
|
||||
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
|
||||
where
|
||||
Scheme::Scalar: SerdeObject
|
||||
+ PrimeField
|
||||
@@ -654,7 +662,7 @@ pub fn get_proof_commitments<
|
||||
TW: TranscriptWriterBuffer<Vec<u8>, Scheme::Curve, E>,
|
||||
>(
|
||||
commitments: &[Scheme::Curve],
|
||||
) -> Result<Vec<u8>, Box<dyn Error>>
|
||||
) -> Result<Vec<u8>, PfsysError>
|
||||
where
|
||||
Scheme::Scalar: SerdeObject
|
||||
+ PrimeField
|
||||
@@ -671,7 +679,7 @@ where
|
||||
for commit in commitments {
|
||||
transcript_new
|
||||
.write_point(*commit)
|
||||
.map_err(|_| "failed to write point")?;
|
||||
.map_err(|e| PfsysError::WritePoint(format!("{}", e)))?;
|
||||
}
|
||||
|
||||
let proof_first_bytes = transcript_new.finalize();
|
||||
@@ -687,7 +695,7 @@ where
|
||||
pub fn swap_proof_commitments_polycommit(
|
||||
snark: &Snark<Fr, G1Affine>,
|
||||
commitments: &[G1Affine],
|
||||
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
|
||||
) -> Result<Snark<Fr, G1Affine>, PfsysError> {
|
||||
let proof = match snark.commitment {
|
||||
Some(Commitments::KZG) => match snark.transcript_type {
|
||||
TranscriptType::EVM => swap_proof_commitments::<
|
||||
@@ -714,7 +722,7 @@ pub fn swap_proof_commitments_polycommit(
|
||||
>(snark, commitments)?,
|
||||
},
|
||||
None => {
|
||||
return Err("commitment scheme not found".into());
|
||||
return Err(PfsysError::InvalidCommitmentScheme);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -761,22 +769,22 @@ where
|
||||
pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
path: PathBuf,
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
info!("loading verification key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
|
||||
debug!("loading verification key from {:?}", path);
|
||||
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadVk(format!("{}", e)))?;
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
serde_format_from_str(&EZKL_KEY_FORMAT),
|
||||
params,
|
||||
)?;
|
||||
info!("done loading verification key ✅");
|
||||
)
|
||||
.map_err(|e| PfsysError::LoadVk(format!("{}", e)))?;
|
||||
info!("loaded verification key ✅");
|
||||
Ok(vk)
|
||||
}
|
||||
|
||||
@@ -784,22 +792,22 @@ where
|
||||
pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
path: PathBuf,
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, Box<dyn Error>>
|
||||
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
info!("loading proving key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
|
||||
debug!("loading proving key from {:?}", path);
|
||||
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
serde_format_from_str(&EZKL_KEY_FORMAT),
|
||||
params,
|
||||
)?;
|
||||
info!("done loading proving key ✅");
|
||||
)
|
||||
.map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
|
||||
info!("loaded proving key ✅");
|
||||
Ok(pk)
|
||||
}
|
||||
|
||||
@@ -811,7 +819,7 @@ pub fn save_pk<C: SerdeObject + CurveAffine>(
|
||||
where
|
||||
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
|
||||
{
|
||||
info!("saving proving key 💾");
|
||||
debug!("saving proving key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
|
||||
@@ -828,7 +836,7 @@ pub fn save_vk<C: CurveAffine + SerdeObject>(
|
||||
where
|
||||
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
|
||||
{
|
||||
info!("saving verification key 💾");
|
||||
debug!("saving verification key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
|
||||
@@ -842,7 +850,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
|
||||
path: &PathBuf,
|
||||
params: &'_ Scheme::ParamsVerifier,
|
||||
) -> Result<(), io::Error> {
|
||||
info!("saving parameters 💾");
|
||||
debug!("saving parameters 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
params.write(&mut writer)?;
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use halo2_proofs::poly::commitment::CommitmentScheme;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use log::info;
|
||||
use std::error::Error;
|
||||
use log::debug;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::path::PathBuf;
|
||||
@@ -16,24 +15,33 @@ pub fn gen_srs<Scheme: CommitmentScheme>(k: u32) -> Scheme::ParamsProver {
|
||||
Scheme::ParamsProver::new(k)
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum SrsError {
|
||||
#[error("failed to download srs from {0}")]
|
||||
DownloadError(String),
|
||||
#[error("failed to load srs from {0}")]
|
||||
LoadError(PathBuf),
|
||||
#[error("failed to read srs {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
|
||||
pub fn load_srs_verifier<Scheme: CommitmentScheme>(
|
||||
path: PathBuf,
|
||||
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
|
||||
info!("loading srs from {:?}", path);
|
||||
let f = File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
|
||||
) -> Result<Scheme::ParamsVerifier, SrsError> {
|
||||
debug!("loading srs from {:?}", path);
|
||||
let f = File::open(path.clone()).map_err(|_| SrsError::LoadError(path))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
|
||||
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(|e| SrsError::ReadError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
|
||||
pub fn load_srs_prover<Scheme: CommitmentScheme>(
|
||||
path: PathBuf,
|
||||
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
|
||||
info!("loading srs from {:?}", path);
|
||||
let f = File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
|
||||
) -> Result<Scheme::ParamsProver, SrsError> {
|
||||
debug!("loading srs from {:?}", path);
|
||||
let f = File::open(path.clone()).map_err(|_| SrsError::LoadError(path.clone()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
|
||||
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(|e| SrsError::ReadError(e.to_string()))
|
||||
}
|
||||
|
||||
30
src/tensor/errors.rs
Normal file
30
src/tensor/errors.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorError {
|
||||
/// Shape mismatch in a operation
|
||||
#[error("dimension mismatch in tensor op: {0}")]
|
||||
DimMismatch(String),
|
||||
/// Shape when instantiating
|
||||
#[error("dimensionality error when manipulating a tensor: {0}")]
|
||||
DimError(String),
|
||||
/// wrong method was called on a tensor-like struct
|
||||
#[error("wrong method called")]
|
||||
WrongMethod,
|
||||
/// Significant bit truncation when instantiating
|
||||
#[error("significant bit truncation when instantiating, try lowering the scale")]
|
||||
SigBitTruncationError,
|
||||
/// Failed to convert to field element tensor
|
||||
#[error("failed to convert to field element tensor")]
|
||||
FeltError,
|
||||
/// Unsupported operation
|
||||
#[error("unsupported operation on a tensor type")]
|
||||
Unsupported,
|
||||
/// Overflow
|
||||
#[error("unsigned integer overflow or underflow error in op: {0}")]
|
||||
Overflow(String),
|
||||
/// Unset visibility
|
||||
#[error("unset visibility")]
|
||||
UnsetVisibility,
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
/// Tensor related errors.
|
||||
pub mod errors;
|
||||
/// Implementations of common operations on tensors.
|
||||
pub mod ops;
|
||||
/// A wrapper around a tensor of circuit variables / advices.
|
||||
@@ -5,6 +7,8 @@ pub mod val;
|
||||
/// A wrapper around a tensor of Halo2 Value types.
|
||||
pub mod var;
|
||||
|
||||
pub use errors::TensorError;
|
||||
|
||||
use halo2curves::{bn256::Fr, ff::PrimeField};
|
||||
use maybe_rayon::{
|
||||
prelude::{
|
||||
@@ -40,40 +44,10 @@ use std::fmt::Debug;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorError {
|
||||
/// Shape mismatch in a operation
|
||||
#[error("dimension mismatch in tensor op: {0}")]
|
||||
DimMismatch(String),
|
||||
/// Shape when instantiating
|
||||
#[error("dimensionality error when manipulating a tensor: {0}")]
|
||||
DimError(String),
|
||||
/// wrong method was called on a tensor-like struct
|
||||
#[error("wrong method called")]
|
||||
WrongMethod,
|
||||
/// Significant bit truncation when instantiating
|
||||
#[error("Significant bit truncation when instantiating, try lowering the scale")]
|
||||
SigBitTruncationError,
|
||||
/// Failed to convert to field element tensor
|
||||
#[error("Failed to convert to field element tensor")]
|
||||
FeltError,
|
||||
/// Table lookup error
|
||||
#[error("Table lookup error")]
|
||||
TableLookupError,
|
||||
/// Unsupported operation
|
||||
#[error("Unsupported operation on a tensor type")]
|
||||
Unsupported,
|
||||
/// Overflow
|
||||
#[error("Unsigned integer overflow or underflow error in op: {0}")]
|
||||
Overflow(String),
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
|
||||
|
||||
@@ -400,9 +374,7 @@ impl IntoI64 for () {
|
||||
fn into_i64(self) -> i64 {
|
||||
0
|
||||
}
|
||||
fn from_i64(_: i64) -> Self {
|
||||
|
||||
}
|
||||
fn from_i64(_: i64) -> Self {}
|
||||
}
|
||||
|
||||
impl IntoI64 for Fr {
|
||||
@@ -1852,7 +1824,7 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
|
||||
pub fn get_broadcasted_shape(
|
||||
shape_a: &[usize],
|
||||
shape_b: &[usize],
|
||||
) -> Result<Vec<usize>, Box<dyn Error>> {
|
||||
) -> Result<Vec<usize>, TensorError> {
|
||||
let num_dims_a = shape_a.len();
|
||||
let num_dims_b = shape_b.len();
|
||||
|
||||
@@ -1867,9 +1839,9 @@ pub fn get_broadcasted_shape(
|
||||
}
|
||||
(a, b) if a < b => Ok(shape_b.to_vec()),
|
||||
(a, b) if a > b => Ok(shape_a.to_vec()),
|
||||
_ => Err(Box::new(TensorError::DimError(
|
||||
_ => Err(TensorError::DimError(
|
||||
"Unknown condition for broadcasting".to_string(),
|
||||
))),
|
||||
)),
|
||||
}
|
||||
}
|
||||
////////////////////////
|
||||
|
||||
@@ -256,23 +256,23 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> TryFrom<Tensor<F>> for ValTensor<F> {
|
||||
type Error = Box<dyn Error>;
|
||||
fn try_from(t: Tensor<F>) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
type Error = TensorError;
|
||||
fn try_from(t: Tensor<F>) -> Result<ValTensor<F>, TensorError> {
|
||||
let visibility = t.visibility.clone();
|
||||
let dims = t.dims().to_vec();
|
||||
let inner = t.into_iter().map(|x| {
|
||||
if let Some(vis) = &visibility {
|
||||
match vis {
|
||||
Visibility::Fixed => Ok(ValType::Constant(x)),
|
||||
_ => {
|
||||
Ok(Value::known(x).into())
|
||||
let inner = t
|
||||
.into_iter()
|
||||
.map(|x| {
|
||||
if let Some(vis) = &visibility {
|
||||
match vis {
|
||||
Visibility::Fixed => Ok(ValType::Constant(x)),
|
||||
_ => Ok(Value::known(x).into()),
|
||||
}
|
||||
} else {
|
||||
Err(TensorError::UnsetVisibility)
|
||||
}
|
||||
}
|
||||
else {
|
||||
Err("visibility should be set to convert a tensor of field elements to a ValTensor.".into())
|
||||
}
|
||||
}).collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
})
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
|
||||
let mut inner: Tensor<ValType<F>> = inner.into_iter().into();
|
||||
inner.reshape(&dims)?;
|
||||
@@ -378,13 +378,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// reverse order of elements whilst preserving the shape
|
||||
pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
pub fn reverse(&mut self) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value { inner: v, .. } => {
|
||||
v.reverse();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
@@ -420,7 +420,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn any_unknowns(&self) -> Result<bool, Box<dyn Error>> {
|
||||
pub fn any_unknowns(&self) -> Result<bool, TensorError> {
|
||||
match self {
|
||||
ValTensor::Instance { .. } => Ok(true),
|
||||
_ => Ok(self.get_inner()?.iter().any(|&x| {
|
||||
@@ -491,7 +491,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Fetch the underlying [Tensor] of field elements.
|
||||
pub fn get_felt_evals(&self) -> Result<Tensor<F>, Box<dyn Error>> {
|
||||
pub fn get_felt_evals(&self) -> Result<Tensor<F>, TensorError> {
|
||||
let mut felt_evals: Vec<F> = vec![];
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
@@ -504,7 +504,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
|
||||
let mut res: Tensor<F> = felt_evals.into_iter().into();
|
||||
@@ -521,7 +521,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `int_evals` on the inner tensor.
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i64>, Box<dyn Error>> {
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
|
||||
// finally convert to vector of integers
|
||||
let mut integer_evals: Vec<i64> = vec![];
|
||||
match self {
|
||||
@@ -547,7 +547,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
|
||||
match tensor.reshape(self.dims()) {
|
||||
@@ -558,7 +558,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `pad_to_zero_rem` on the inner tensor.
|
||||
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
|
||||
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
@@ -567,14 +567,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calls `get_slice` on the inner tensor.
|
||||
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
|
||||
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
@@ -592,13 +592,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
scale: *scale,
|
||||
}
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
Ok(slice)
|
||||
}
|
||||
|
||||
/// Calls `get_single_elem` on the inner tensor.
|
||||
pub fn get_single_elem(&self, index: usize) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
pub fn get_single_elem(&self, index: usize) -> Result<ValTensor<F>, TensorError> {
|
||||
let slice = match self {
|
||||
ValTensor::Value {
|
||||
inner: v,
|
||||
@@ -612,7 +612,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
scale: *scale,
|
||||
}
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
Ok(slice)
|
||||
}
|
||||
@@ -648,7 +648,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
})
|
||||
}
|
||||
/// Calls `expand` on the inner tensor.
|
||||
pub fn expand(&mut self, dims: &[usize]) -> Result<(), Box<dyn Error>> {
|
||||
pub fn expand(&mut self, dims: &[usize]) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
@@ -657,14 +657,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calls `move_axis` on the inner tensor.
|
||||
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), Box<dyn Error>> {
|
||||
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
@@ -673,14 +673,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the [ValTensor]'s shape.
|
||||
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box<dyn Error>> {
|
||||
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
@@ -690,10 +690,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
ValTensor::Instance { dims: d, idx, .. } => {
|
||||
if d[*idx].iter().product::<usize>() != new_dims.iter().product::<usize>() {
|
||||
return Err(Box::new(TensorError::DimError(format!(
|
||||
return Err(TensorError::DimError(format!(
|
||||
"Cannot reshape {:?} to {:?} as they have number of elements",
|
||||
d[*idx], new_dims
|
||||
))));
|
||||
)));
|
||||
}
|
||||
d[*idx] = new_dims.to_vec();
|
||||
}
|
||||
@@ -702,12 +702,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Sets the [ValTensor]'s shape.
|
||||
pub fn slice(
|
||||
&mut self,
|
||||
axis: &usize,
|
||||
start: &usize,
|
||||
end: &usize,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
pub fn slice(&mut self, axis: &usize, start: &usize, end: &usize) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
@@ -716,7 +711,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
@@ -982,7 +977,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
/// inverts the inner values
|
||||
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
pub fn inverse(&self) -> Result<ValTensor<F>, TensorError> {
|
||||
let mut cloned_self = self.clone();
|
||||
|
||||
match &mut cloned_self {
|
||||
@@ -1000,7 +995,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(cloned_self)
|
||||
|
||||
@@ -31,6 +31,15 @@ pub enum VarTensor {
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// name of the tensor
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
VarTensor::Advice { .. } => "Advice",
|
||||
VarTensor::Dummy { .. } => "Dummy",
|
||||
VarTensor::Empty => "Empty",
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn is_advice(&self) -> bool {
|
||||
matches!(self, VarTensor::Advice { .. })
|
||||
|
||||
Reference in New Issue
Block a user