Compare commits

...

7 Commits

Author               SHA1        Message                                              Date
github-actions[bot]  94abf91400  ci: update version string in docs                    2024-06-14 16:32:09 +00:00
dante                4771192823  fix: more verbose io / rw errors (#815)              2024-06-14 12:31:50 -04:00
dante                a863ccc868  chore: update cmd feature flag (#814)                2024-06-11 16:07:49 -04:00
Ethan Cemer          8e6ccc863d  feat: all file source kzg commit DA (#812)           2024-06-11 09:32:54 -04:00
dante                00d6873f9a  fix: should update using bash when possible (#813)   2024-06-10 12:13:59 -04:00
dante                c97ff84198  refactor: rm boxed errors (opaque) (#810)            2024-06-08 22:41:47 -04:00
dante                f5f8ef56f7  chore: ezkl self update (#809)                       2024-06-07 10:30:21 -04:00
44 changed files with 1621 additions and 1198 deletions

View File

@@ -345,6 +345,8 @@ jobs:
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)

Cargo.lock generated
View File

@@ -1878,6 +1878,7 @@ dependencies = [
"rand 0.8.5",
"regex",
"reqwest",
"semver 1.0.22",
"seq-macro",
"serde",
"serde-wasm-bindgen",

View File

@@ -45,6 +45,7 @@ num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
semver = "1.0.22"
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
@@ -202,6 +203,7 @@ det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
metal = ["dep:metal", "dep:objc"]
# icicle patch to 0.1.0 if feature icicle is enabled

View File

@@ -93,9 +93,6 @@ contract LoadInstances {
}
}
// Contract that checks that the COMMITMENT_KZG bytes is equal to the first part of the proof.
pragma solidity ^0.8.0;
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
@@ -163,7 +160,7 @@ contract SwapProofCommitments {
}
return equal; // Return true if the commitment comparison passed
}
} /// end checkKzgCommits
}
// This contract serves as a Data Attestation Verifier for the EZKL model.

View File

@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==11.4.2
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '11.4.2'
version = release

View File

@@ -482,7 +482,7 @@
"source": [
"import pytest\n",
"def test_verification():\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
" ezkl.verify(\n",
" proof_path_faulty,\n",
" settings_path,\n",
@@ -514,9 +514,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -478,12 +478,11 @@
"import pytest\n",
"\n",
"def test_verification():\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
" ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"# Run the test function\n",
@@ -510,9 +509,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -99,6 +99,10 @@ fi
echo "Removing old ezkl binary if it exists"
[ -e file ] && rm file
# echo platform and architecture
echo "Platform: $PLATFORM"
echo "Architecture: $ARCHITECTURE"
# download the release and unpack the right tarball
if [ "$PLATFORM" == "windows-msvc" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")

View File

@@ -17,17 +17,14 @@ use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "icicle")]
use std::env;
#[cfg(not(target_arch = "wasm32"))]
use std::error::Error;
#[tokio::main(flavor = "current_thread")]
#[cfg(not(target_arch = "wasm32"))]
pub async fn main() -> Result<(), Box<dyn Error>> {
pub async fn main() {
let args = Cli::parse();
if let Some(generator) = args.generator {
ezkl::commands::print_completions(generator, &mut Cli::command());
Ok(())
} else if let Some(command) = args.command {
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
@@ -38,15 +35,24 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
} else {
info!("Running with CPU");
}
info!("command: \n {}", &command.as_json().to_colored_json_auto()?);
info!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);
let res = run(command).await;
match &res {
Ok(_) => info!("succeeded"),
Err(e) => error!("failed: {}", e),
};
res.map(|_| ())
Ok(_) => {
info!("succeeded");
}
Err(e) => {
error!("{}", e);
std::process::exit(1)
}
}
} else {
Err("No command provided".into())
init_logger();
error!("No command provided");
std::process::exit(1)
}
}
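The rewritten main() above trades a `Result`-returning entry point for explicit logging plus `std::process::exit(1)`: errors are reported through their `Display` form rather than the `Debug` dump Rust prints when `main` returns `Err`, and callers get a clean non-zero exit status. A minimal sketch of the pattern (hypothetical `run`, with `eprintln!` standing in for the logger):

```rust
// Sketch: log the Display form of the error and exit non-zero,
// instead of returning Result from main (which prints Debug output).
fn run() -> Result<(), String> {
    Err("no command provided".to_string())
}

fn main() {
    match run() {
        Ok(_) => println!("succeeded"),
        Err(e) => {
            eprintln!("{}", e);
            std::process::exit(1);
        }
    }
}
```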

View File

@@ -0,0 +1,25 @@
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum ModuleError {
/// Halo 2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] PlonkError),
/// Wrong input type for a module
#[error("wrong input type {0} must be {1}")]
WrongInputType(String, String),
/// A constant was not previously assigned
#[error("constant was not previously assigned")]
ConstantNotAssigned,
/// Input length is wrong
#[error("input length is wrong {0}")]
InputWrongLength(usize),
}
impl From<ModuleError> for PlonkError {
fn from(_e: ModuleError) -> PlonkError {
PlonkError::Synthesis
}
}
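The `From<ModuleError> for PlonkError` impl is what keeps this new typed error compatible with halo2's API: module code can carry the descriptive variant internally and still use `?` where a `PlonkError` is required, collapsing into `Error::Synthesis` at the boundary. A small sketch, assuming the types above (the function itself is hypothetical):

```rust
// Sketch: `?` converts ModuleError into PlonkError via the From impl above,
// collapsing the detailed variant into PlonkError::Synthesis.
fn assign_constant_cell() -> Result<(), PlonkError> {
    let assigned: Result<(), ModuleError> = Err(ModuleError::ConstantNotAssigned);
    assigned?; // becomes PlonkError::Synthesis
    Ok(())
}
```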

View File

@@ -6,10 +6,11 @@ pub mod polycommit;
///
pub mod planner;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Error},
};
///
pub mod errors;
use halo2_proofs::{circuit::Layouter, plonk::ConstraintSystem};
use halo2curves::ff::PrimeField;
pub use planner::*;
@@ -35,14 +36,14 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Name
fn name(&self) -> &'static str;
/// Run the operation the module represents
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, Box<dyn std::error::Error>>;
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, errors::ModuleError>;
/// Layout inputs
fn layout_inputs(
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
) -> Result<Self::InputAssignments, errors::ModuleError>;
/// Layout
fn layout(
&self,
@@ -50,7 +51,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
) -> Result<ValTensor<F>, errors::ModuleError>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;
/// Number of rows used by the module

View File

@@ -18,6 +18,7 @@ use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::errors::ModuleError;
use super::Module;
/// The number of instance columns used by the PolyCommit hash function
@@ -110,7 +111,7 @@ impl Module<Fp> for PolyCommitChip {
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
) -> Result<Self::InputAssignments, ModuleError> {
Ok(())
}
@@ -123,28 +124,30 @@ impl Module<Fp> for PolyCommitChip {
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
) -> Result<ValTensor<Fp>, ModuleError> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
layouter
.assign_region(
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
.map_err(|e| e.into())
}
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
Ok(vec![message])
}

View File

@@ -21,6 +21,7 @@ use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::errors::ModuleError;
use super::Module;
/// The number of instance columns used by the Poseidon hash function
@@ -174,7 +175,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
) -> Result<Self::InputAssignments, ModuleError> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
@@ -185,78 +186,82 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let res = layouter.assign_region(
|| "load message",
|mut region| {
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, Error> = match &message {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
match &message {
ValTensor::Value { inner: v, .. } => {
v.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
match value {
ValType::Value(v) => region.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v)
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
constants.insert(
*f,
ValType::AssignedConstant(res.clone(), *f),
);
Ok(res)
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
)),
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
e
);
Err(Error::Synthesis)
}
}
})
.collect(),
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect()
}
};
})
.collect()
}
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
let offset = message.len() / WIDTH + 1;
@@ -277,7 +282,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
message.len(),
start_time.elapsed()
);
res
res.map_err(|e| e.into())
}
/// L is the number of inputs to the hash function
@@ -289,7 +294,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
) -> Result<ValTensor<Fp>, ModuleError> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
@@ -301,7 +306,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let mut one_iter = false;
// do the Tree dance baby
while input_cells.len() > 1 || !one_iter {
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, Error> = input_cells
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
.chunks(L)
.enumerate()
.map(|(i, block)| {
@@ -332,7 +337,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
hash
})
.collect();
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into());
log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
one_iter = true;
@@ -348,7 +354,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) => v,
_ => {
log::error!("wrong input type, must be previously assigned");
return Err(Error::Synthesis);
return Err(Error::Synthesis.into());
}
};
@@ -380,7 +386,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
}
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
let mut hash_inputs = message;
let len = hash_inputs.len();
@@ -400,7 +406,11 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
block.extend(vec![Fp::ZERO; L - remainder].iter());
}
let message = block.try_into().map_err(|_| Error::Synthesis)?;
let block_len = block.len();
let message = block
.try_into()
.map_err(|_| ModuleError::InputWrongLength(block_len))?;
Ok(halo2_gadgets::poseidon::primitives::Hash::<
_,
@@ -411,7 +421,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
>::init()
.hash(message))
})
.collect::<Result<Vec<_>, Error>>()?;
.collect::<Result<Vec<_>, ModuleError>>()?;
one_iter = true;
hash_inputs = hashes;
}

View File

@@ -1,7 +1,5 @@
use std::str::FromStr;
use thiserror::Error;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector},
@@ -26,31 +24,11 @@ use crate::{
},
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, Op};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};
/// circuit related errors.
#[derive(Debug, Error)]
pub enum CircuitError {
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
#[error("invalid einsum expression")]
InvalidEinsum,
}
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -513,18 +491,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
lookup_range: Range,
logrows: usize,
nl: &LookupOp,
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
return Err(CircuitError::WrongColumnType(index.name().to_string()));
}
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
return Err(CircuitError::WrongColumnType(input.name().to_string()));
}
if !output.is_advice() {
return Err("wrong input type for lookup output".into());
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
@@ -654,19 +632,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 3],
tables: &[VarTensor; 3],
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
@@ -737,19 +715,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
for l in inputs.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
@@ -822,12 +800,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
index: &VarTensor,
range: Range,
logrows: usize,
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
return Err(CircuitError::WrongColumnType(input.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
@@ -918,7 +896,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
}
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
if !table.is_assigned {
debug!(
@@ -939,7 +917,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
pub fn layout_range_checks(
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), CircuitError> {
for range_check in self.range_checks.ranges.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
@@ -959,7 +937,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)
}
}

src/circuit/ops/errors.rs (new file, 94 lines)
View File

@@ -0,0 +1,94 @@
use std::convert::Infallible;
use crate::tensor::TensorError;
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum CircuitError {
/// Halo 2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] PlonkError),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] TensorError),
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
#[error("failed to flush, linear coord is not aligned with the next row")]
FlushError,
/// Constrain error
#[error("constrain_equal: one of the tensors is assigned and the other is not")]
ConstrainError,
/// Failed to get lookups
#[error("failed to get lookups for op: {0}")]
GetLookupsError(String),
/// Failed to get range checks
#[error("failed to get range checks for op: {0}")]
GetRangeChecksError(String),
/// Failed to get dynamic lookup
#[error("failed to get dynamic lookup for op: {0}")]
GetDynamicLookupError(String),
/// Failed to get shuffle
#[error("failed to get shuffle for op: {0}")]
GetShuffleError(String),
/// Failed to get constants
#[error("failed to get constants for op: {0}")]
GetConstantsError(String),
/// Slice length mismatch
#[error("slice length mismatch: {0}")]
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
/// Bad conversion
#[error("invalid conversion: {0}")]
InvalidConversion(#[from] Infallible),
/// Invalid min/max lookup range
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
InvalidMinMaxRange(i64, i64),
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
/// Mismatched lookup length
#[error("mismatched lookup lengths: {0} and {1}")]
MismatchedLookupLength(usize, usize),
/// Mismatched shuffle length
#[error("mismatched shuffle lengths: {0} and {1}")]
MismatchedShuffleLength(usize, usize),
/// Mismatched lookup table lengths
#[error("mismatched lookup table lengths: {0} and {1}")]
MismatchedLookupTableLength(usize, usize),
/// Wrong column type for lookup
#[error("wrong column type for lookup: {0}")]
WrongColumnType(String),
/// Wrong column type for dynamic lookup
#[error("wrong column type for dynamic lookup: {0}")]
WrongDynamicColumnType(String),
/// Missing selectors
#[error("missing selectors for op: {0}")]
MissingSelectors(String),
/// Table lookup error
#[error("value ({0}) out of range: ({1}, {2})")]
TableOOR(i64, i64, i64),
/// Lookup not configured
#[error("lookup not configured: {0}")]
LookupNotConfigured(String),
/// Range check not configured
#[error("range check not configured: {0}")]
RangeCheckNotConfigured(String),
/// Missing layout
#[error("missing layout for op: {0}")]
MissingLayout(String),
}
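The `#[from]` attributes are what make the surrounding `Box<dyn Error>` → `CircuitError` migration so mechanical: thiserror generates the `From` impls, so halo2 and tensor failures propagate with a bare `?` or `.into()`. A hedged sketch mirroring the `PolyOp::Pad` change further down (hypothetical function, assuming the enum and imports above):

```rust
// Sketch: #[from] gives From<TensorError> for CircuitError, so tensor
// failures convert without manual boxing or string formatting.
fn pad_single_input(n_inputs: usize) -> Result<(), CircuitError> {
    if n_inputs != 1 {
        return Err(TensorError::DimError(
            "Pad operation requires a single input".to_string(),
        )
        .into()); // CircuitError::TensorError(..)
    }
    Ok(())
}
```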

View File

@@ -155,7 +155,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::SumPool {
padding,
@@ -287,7 +287,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
}))
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }

File diff suppressed because it is too large

View File

@@ -1,6 +1,5 @@
use super::*;
use serde::{Deserialize, Serialize};
use std::error::Error;
use crate::{
circuit::{layouts, table::Range, utils},
@@ -295,7 +294,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,
region,
@@ -305,7 +304,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
}
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];

View File

@@ -1,4 +1,4 @@
use std::{any::Any, error::Error};
use std::any::Any;
use serde::{Deserialize, Serialize};
@@ -15,6 +15,8 @@ pub mod base;
///
pub mod chip;
///
pub mod errors;
///
pub mod hybrid;
/// Layouts for specific functions (composed of base ops)
pub mod layouts;
@@ -25,6 +27,8 @@ pub mod poly;
///
pub mod region;
pub use errors::CircuitError;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
@@ -44,10 +48,10 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>>;
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>>;
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError>;
/// Do any of the inputs to this op require homogenous input scales?
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
@@ -139,7 +143,7 @@ pub struct Input {
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.scale)
}
@@ -156,7 +160,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
match self.datum_type {
@@ -194,7 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
@@ -209,8 +213,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
Err(Box::new(super::CircuitError::UnsupportedOp))
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
@@ -240,7 +244,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Consta
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), Box<dyn Error>> {
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = self.quantized_values.visibility().unwrap();
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
@@ -279,7 +283,7 @@ impl<
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
@@ -293,7 +297,7 @@ impl<
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.quantized_values.scale().unwrap())
}

View File

@@ -179,7 +179,7 @@ impl<
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
@@ -278,9 +278,10 @@ impl<
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
return Err(Box::new(TensorError::DimError(
return Err(TensorError::DimError(
"Pad operation requires a single input".to_string(),
)));
)
.into());
}
let mut input = values[0].clone();
input.pad(p.clone(), 0)?;
@@ -297,7 +298,7 @@ impl<
}))
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,

View File

@@ -1,6 +1,6 @@
use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
@@ -19,7 +19,7 @@ use std::{
},
};
use super::lookup::LookupOp;
use super::{lookup::LookupOp, CircuitError};
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
@@ -84,44 +84,6 @@ impl ShuffleIndex {
}
}
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
/// wrap other regions
#[error("Wrapped region: {0}")]
Wrapped(String),
}
impl From<String> for RegionError {
fn from(e: String) -> Self {
Self::Wrapped(e)
}
}
impl From<&str> for RegionError {
fn from(e: &str) -> Self {
Self::Wrapped(e.to_string())
}
}
impl From<TensorError> for RegionError {
fn from(e: TensorError) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
impl From<Error> for RegionError {
fn from(e: Error) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
impl From<Box<dyn std::error::Error>> for RegionError {
fn from(e: Box<dyn std::error::Error>) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
@@ -317,10 +279,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn apply_in_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
) -> Result<(), RegionError> {
) -> Result<(), CircuitError> {
if self.is_dummy() {
self.dummy_loop(output, inner_loop_function)?;
} else {
@@ -333,8 +295,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn real_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>,
) -> Result<(), RegionError> {
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>,
) -> Result<(), CircuitError> {
output
.iter_mut()
.enumerate()
@@ -342,7 +304,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
*o = inner_loop_function(i, self)?;
Ok(())
})
.collect::<Result<Vec<_>, RegionError>>()?;
.collect::<Result<Vec<_>, CircuitError>>()?;
Ok(())
}
@@ -353,10 +315,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn dummy_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
) -> Result<(), RegionError> {
) -> Result<(), CircuitError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
@@ -367,50 +329,48 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
// get inner value of the locked lookups
*output = output.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.witness_gen,
self.check_lookup_range,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
linear_coord.fetch_add(
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.witness_gen,
self.check_lookup_range,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
linear_coord.fetch_add(
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
res
})?;
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
@@ -419,49 +379,25 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
})?;
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
self.used_range_checks = Arc::try_unwrap(range_checks)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?;
self.shuffle_index = Arc::try_unwrap(shuffle_index)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?;
Ok(())
}
@@ -470,7 +406,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
@@ -482,12 +418,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(
&mut self,
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
if range.0 > range.1 {
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
return Err(CircuitError::InvalidMinMaxRange(range.0, range.1));
}
let range_size = (range.1 - range.0).abs();
@@ -506,13 +439,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), CircuitError> {
self.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
self.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
@@ -707,7 +640,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// constrain equal
pub fn constrain_equal(&mut self, a: &ValTensor<F>, b: &ValTensor<F>) -> Result<(), Error> {
pub fn constrain_equal(
&mut self,
a: &ValTensor<F>,
b: &ValTensor<F>,
) -> Result<(), CircuitError> {
if let Some(region) = &self.region {
let a = a.get_inner_tensor().unwrap();
let b = b.get_inner_tensor().unwrap();
@@ -717,12 +654,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let b = b.get_prev_assigned();
// if they're both assigned, we can constrain them
if let (Some(a), Some(b)) = (&a, &b) {
region.borrow_mut().constrain_equal(a.cell(), b.cell())
region
.borrow_mut()
.constrain_equal(a.cell(), b.cell())
.map_err(|e| e.into())
} else if a.is_some() || b.is_some() {
log::error!(
"constrain_equal: one of the tensors is assigned and the other is not"
);
return Err(Error::Synthesis);
return Err(CircuitError::ConstrainError);
} else {
Ok(())
}
@@ -748,7 +685,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// flush row to the next row
pub fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
pub fn flush(&mut self) -> Result<(), CircuitError> {
// increment by the difference between the current linear coord and the next row
let remainder = self.linear_coord % self.num_inner_cols;
if remainder != 0 {
@@ -756,7 +693,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.increment(diff);
}
if self.linear_coord % self.num_inner_cols != 0 {
return Err("flush: linear coord is not aligned with the next row".into());
return Err(CircuitError::FlushError);
}
Ok(())
}
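Besides deleting the string-wrapping `RegionError`, the dummy_loop rewrite above types each failure point where shared state is reclaimed after the parallel pass (`Arc::try_unwrap` can fail if a clone is still alive; `Mutex::into_inner` can fail if a thread panicked while holding the lock). A self-contained sketch of that reclaim-and-map pattern, with hypothetical types rather than the ezkl ones:

```rust
use std::sync::{Arc, Mutex};

#[derive(Debug)]
enum ReclaimError {
    StillShared, // another Arc clone is still alive
    Poisoned,    // a thread panicked while holding the lock
}

// Sketch: reclaim state collected behind Arc<Mutex<_>> during a parallel
// loop, mapping both failure points to typed errors instead of strings.
fn reclaim(shared: Arc<Mutex<Vec<u64>>>) -> Result<Vec<u64>, ReclaimError> {
    Arc::try_unwrap(shared)
        .map_err(|_| ReclaimError::StillShared)?
        .into_inner()
        .map_err(|_| ReclaimError::Poisoned)
}

fn main() {
    let shared = Arc::new(Mutex::new(vec![1, 2, 3]));
    shared.lock().unwrap().push(4);
    assert_eq!(reclaim(shared).unwrap(), vec![1, 2, 3, 4]);
}
```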

View File

@@ -1,4 +1,4 @@
use std::{error::Error, marker::PhantomData};
use std::marker::PhantomData;
use halo2curves::ff::PrimeField;
@@ -194,9 +194,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
&mut self,
layouter: &mut impl Layouter<F>,
preassigned_input: bool,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), CircuitError> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
return Err(CircuitError::TableAlreadyAssigned);
}
let smallest = self.range.0;
@@ -342,9 +342,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
return Err(CircuitError::TableAlreadyAssigned);
}
let smallest = self.range.0;

View File

@@ -868,6 +868,13 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(feature = "no-update"))]
/// Updates ezkl binary to version specified (or latest if not specified)
Update {
/// The version to update to
#[arg(value_hint = clap::ValueHint::Other, short='v', long)]
version: Option<String>,
},
}
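This `Update` subcommand is the user-facing side of the `semver` dependency and `no-update` feature flag added in Cargo.toml above. A hypothetical sketch of how an optional version argument might be validated with semver before self-updating — illustrative only, not ezkl's actual update logic:

```rust
use semver::Version;

// Sketch: parse an optional target version; None means "latest".
fn target_version(arg: Option<String>) -> Result<Option<Version>, semver::Error> {
    arg.map(|v| Version::parse(v.trim_start_matches('v'))).transpose()
}

fn main() -> Result<(), semver::Error> {
    let current = Version::parse("11.4.2")?;
    if let Some(target) = target_version(Some("12.0.0".to_string()))? {
        if target > current {
            // a real updater would download and swap the binary here
            println!("updating {} -> {}", current, target);
        }
    }
    Ok(())
}
```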

View File

@@ -16,7 +16,8 @@ use alloy::prelude::Wallet;
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::node_bindings::Anvil;
use alloy::primitives::{B256, I256};
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
@@ -25,10 +26,13 @@ use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
use alloy::signers::wallet::LocalWallet;
use alloy::signers::k256::ecdsa;
use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -36,7 +40,6 @@ use halo2curves::group::ff::PrimeField;
use itertools::Itertools;
use log::{debug, info, warn};
use reqwest::Client;
use std::error::Error;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
@@ -213,6 +216,57 @@ abigen!(
}
);
#[derive(Debug, thiserror::Error)]
pub enum EthError {
#[error("a transport error occurred: {0}")]
Transport(#[from] RpcError<TransportErrorKind>),
#[error("a contract error occurred: {0}")]
Contract(#[from] alloy::contract::Error),
#[error("a wallet error occurred: {0}")]
Wallet(#[from] WalletError),
#[error("failed to parse url {0}")]
UrlParse(String),
#[error("evm verification error: {0}")]
EvmVerification(#[from] EvmVerificationError),
#[error("Private key must be in hex format, 64 chars, without 0x prefix")]
PrivateKeyFormat,
#[error("failed to parse hex: {0}")]
HexParse(#[from] hex::FromHexError),
#[error("ecdsa error: {0}")]
Ecdsa(#[from] ecdsa::Error),
#[error("failed to load graph data")]
GraphData,
#[error("failed to load graph settings")]
GraphSettings,
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("Data source for either input_data or output_data must be OnChain")]
OnChainDataSource,
#[error("failed to parse signed integer: {0}")]
SignedIntegerParse(#[from] ParseSignedError),
#[error("failed to parse unsigned integer: {0}")]
UnSignedIntegerParse(#[from] ParseError),
#[error("updateAccountCalls should have failed")]
UpdateAccountCalls,
#[error("ethabi error: {0}")]
EthAbi(#[from] ethabi::Error),
#[error("conversion error: {0}")]
Conversion(#[from] std::convert::Infallible),
// Constructor arguments provided but no constructor found
#[error("constructor arguments provided but no constructor found")]
NoConstructor,
#[error("contract not found at path: {0}")]
ContractNotFound(String),
#[error("solc error: {0}")]
Solc(#[from] SolcError),
#[error("solc io error: {0}")]
SolcIo(#[from] SolcIoError),
#[error("svm error: {0}")]
Svm(String),
#[error("no contract output found")]
NoContractOutput,
}
// we have to generate these two contracts differently because they are generated dynamically, so the static compilation above does not suit them
const ATTESTDATA_SOL: &str = include_str!("../contracts/AttestData.sol");
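`EthError` follows the same convention as `CircuitError`: variants with `#[from]` propagate through a bare `?`, while shapeless failures (a malformed private key, unloadable graph data) get dedicated variants returned explicitly or via `map_err`, as the diffs below show. A hedged sketch, assuming the enum above and the `hex` crate (the helper itself is hypothetical):

```rust
// Sketch: #[from] variants convert through `?` (Io, HexParse), while
// structural checks return a dedicated unit variant directly.
fn load_private_key(path: &str) -> Result<Vec<u8>, EthError> {
    let key = std::fs::read_to_string(path)?; // io::Error -> EthError::Io
    if key.trim().len() != 64 {
        return Err(EthError::PrivateKeyFormat);
    }
    Ok(hex::decode(key.trim())?) // FromHexError -> EthError::HexParse
}
```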
@@ -235,7 +289,7 @@ pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
pub async fn setup_eth_backend(
rpc_url: Option<&str>,
private_key: Option<&str>,
) -> Result<(EthersClient, alloy::primitives::Address), Box<dyn Error>> {
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
// Launch anvil
let endpoint: String;
@@ -257,11 +311,8 @@ pub async fn setup_eth_backend(
let wallet: LocalWallet;
if let Some(private_key) = private_key {
debug!("using private key {}", private_key);
// Sanity checks for private_key
let private_key_format_error =
"Private key must be in hex format, 64 chars, without 0x prefix";
if private_key.len() != 64 {
return Err(private_key_format_error.into());
return Err(EthError::PrivateKeyFormat);
}
let private_key_buffer = hex::decode(private_key)?;
wallet = LocalWallet::from_slice(&private_key_buffer)?;
@@ -276,7 +327,7 @@ pub async fn setup_eth_backend(
ProviderBuilder::new()
.with_recommended_fillers()
.signer(EthereumSigner::from(wallet))
.on_http(endpoint.parse()?),
.on_http(endpoint.parse().map_err(|_| EthError::UrlParse(endpoint))?),
);
let chain_id = client.get_chain_id().await?;
@@ -292,15 +343,14 @@ pub async fn deploy_contract_via_solidity(
runs: usize,
private_key: Option<&str>,
contract_name: &str,
) -> Result<H160, Box<dyn Error>> {
) -> Result<H160, EthError> {
// anvil instance must be alive at least until the factory completes the deploy
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, contract_name, runs).await?;
let factory =
get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone(), None::<()>)?;
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client, None::<()>)?;
let contract = factory.deploy().await?;
Ok(contract)
@@ -314,12 +364,12 @@ pub async fn deploy_da_verifier_via_solidity(
rpc_url: Option<&str>,
runs: usize,
private_key: Option<&str>,
) -> Result<H160, Box<dyn Error>> {
) -> Result<H160, EthError> {
let (client, client_address) = setup_eth_backend(rpc_url, private_key).await?;
let input = GraphData::from_path(input)?;
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
let settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path).map_err(|_| EthError::GraphSettings)?;
let mut scales: Vec<u32> = vec![];
// The data that will be stored in the test contracts that will eventually be read from.
@@ -339,7 +389,7 @@ pub async fn deploy_da_verifier_via_solidity(
}
if settings.run_args.param_visibility.is_hashed() {
return Err(Box::new(EvmVerificationError::InvalidVisibility));
return Err(EvmVerificationError::InvalidVisibility.into());
}
if settings.run_args.output_visibility.is_hashed() {
@@ -397,20 +447,30 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
return Err("Data source for either input_data or output_data must be OnChain".into());
// if calls_to_accounts is empty then we need to check that at least one of input, output, or params has kzg (polycommit) visibility in the settings file
let kzg_visibility = settings.run_args.input_visibility.is_polycommit()
|| settings.run_args.output_visibility.is_polycommit()
|| settings.run_args.param_visibility.is_polycommit();
if !kzg_visibility {
return Err(EthError::OnChainDataSource);
}
let factory =
get_sol_contract_factory::<_, ()>(abi, bytecode, runtime_bytecode, client, None)?;
let contract = factory.deploy().await?;
return Ok(contract);
};
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
let factory = get_sol_contract_factory(
abi,
bytecode,
runtime_bytecode,
client.clone(),
client,
Some((
// address[] memory _contractAddresses,
DynSeqToken(
@@ -451,7 +511,7 @@ pub async fn deploy_da_verifier_via_solidity(
),
// uint8 _instanceOffset,
WordToken(U256::from(contract_instance_offset as u32).into()),
//address _admin
// address _admin
WordToken(client_address.into_word()),
)),
)?;
@@ -469,12 +529,12 @@ type ParsedCallsToAccount = (Vec<H160>, Vec<Vec<Bytes>>, Vec<Vec<U256>>);
fn parse_calls_to_accounts(
calls_to_accounts: Vec<CallsToAccount>,
) -> Result<ParsedCallsToAccount, Box<dyn Error>> {
) -> Result<ParsedCallsToAccount, EthError> {
let mut contract_addresses = vec![];
let mut call_data = vec![];
let mut decimals: Vec<Vec<U256>> = vec![];
for (i, val) in calls_to_accounts.iter().enumerate() {
let contract_address_bytes = hex::decode(val.address.clone())?;
let contract_address_bytes = hex::decode(&val.address)?;
let contract_address = H160::from_slice(&contract_address_bytes);
contract_addresses.push(contract_address);
call_data.push(vec![]);
@@ -492,8 +552,8 @@ pub async fn update_account_calls(
addr: H160,
input: PathBuf,
rpc_url: Option<&str>,
) -> Result<(), Box<dyn Error>> {
let input = GraphData::from_path(input)?;
) -> Result<(), EthError> {
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
// The data that will be stored in the test contracts that will eventually be read from.
let mut calls_to_accounts = vec![];
@@ -513,12 +573,12 @@ pub async fn update_account_calls(
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
return Err("Data source for either input_data or output_data must be OnChain".into());
return Err(EthError::OnChainDataSource);
};
let (client, client_address) = setup_eth_backend(rpc_url, None).await?;
let contract = DataAttestation::new(addr, client.clone());
let contract = DataAttestation::new(addr, &client);
info!("contract_addresses: {:#?}", contract_addresses);
@@ -547,7 +607,7 @@ pub async fn update_account_calls(
{
info!("updateAccountCalls failed as expected");
} else {
return Err("updateAccountCalls should have failed".into());
return Err(EthError::UpdateAccountCalls);
}
Ok(())
@@ -560,7 +620,7 @@ pub async fn verify_proof_via_solidity(
addr: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EthError> {
let flattened_instances = proof.instances.into_iter().flatten();
let encoded = encode_calldata(
@@ -579,15 +639,15 @@ pub async fn verify_proof_via_solidity(
let result = client.call(&tx).await;
if result.is_err() {
return Err(Box::new(EvmVerificationError::SolidityExecution));
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result.to_vec());
// decode return bytes value into uint8
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
if !result {
return Err(Box::new(EvmVerificationError::InvalidProof));
return Err(EvmVerificationError::InvalidProof.into());
}
let gas = client.estimate_gas(&tx).await?;
@@ -626,7 +686,7 @@ fn count_decimal_places(num: f32) -> usize {
pub async fn setup_test_contract<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), Box<dyn Error>> {
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), EthError> {
let mut decimals = vec![];
let mut scaled_by_decimals_data = vec![];
for input in &data[0] {
@@ -663,7 +723,7 @@ pub async fn verify_proof_with_data_attestation(
addr_da: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EthError> {
use ethabi::{Function, Param, ParamType, StateMutability, Token};
let mut public_inputs: Vec<U256> = vec![];
@@ -728,15 +788,15 @@ pub async fn verify_proof_with_data_attestation(
);
let result = client.call(&tx).await;
if result.is_err() {
return Err(Box::new(EvmVerificationError::SolidityExecution));
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result);
// decode return bytes value into uint8
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
if !result {
return Err(Box::new(EvmVerificationError::InvalidProof));
return Err(EvmVerificationError::InvalidProof.into());
}
Ok(true)
@@ -748,8 +808,8 @@ pub async fn verify_proof_with_data_attestation(
pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<Vec<CallsToAccount>, Box<dyn Error>> {
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
) -> Result<Vec<CallsToAccount>, EthError> {
let (contract, decimals) = setup_test_contract(client, data).await?;
// Get the encoded call data for each input
let mut calldata = vec![];
@@ -774,17 +834,17 @@ pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>
client: Arc<M>,
address: H160,
data: &Vec<CallsToAccount>,
) -> Result<(Vec<Bytes>, Vec<u8>), Box<dyn Error>> {
) -> Result<(Vec<Bytes>, Vec<u8>), EthError> {
// Iterate over all on-chain inputs
let mut fetched_inputs = vec![];
let mut decimals = vec![];
for on_chain_data in data {
// Construct the address
let contract_address_bytes = hex::decode(on_chain_data.address.clone())?;
let contract_address_bytes = hex::decode(&on_chain_data.address)?;
let contract_address = H160::from_slice(&contract_address_bytes);
for (call_data, decimal) in &on_chain_data.call_data {
let call_data_bytes = hex::decode(call_data.clone())?;
let call_data_bytes = hex::decode(call_data)?;
let input: TransactionInput = call_data_bytes.into();
let tx = TransactionRequest::default()
@@ -808,13 +868,11 @@ pub async fn evm_quantize<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
scales: Vec<crate::Scale>,
data: &(Vec<Bytes>, Vec<u8>),
) -> Result<Vec<Fr>, Box<dyn Error>> {
use alloy::primitives::ParseSignedError;
) -> Result<Vec<Fr>, EthError> {
let contract = QuantizeData::deploy(&client).await?;
let fetched_inputs = data.0.clone();
let decimals = data.1.clone();
let fetched_inputs = &data.0;
let decimals = &data.1;
let fetched_inputs = fetched_inputs
.iter()
@@ -870,7 +928,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
runtime_bytecode: Bytes,
client: Arc<M>,
params: Option<T>,
) -> Result<ContractFactory<M>, Box<dyn Error>> {
) -> Result<ContractFactory<M>, EthError> {
const MAX_RUNTIME_BYTECODE_SIZE: usize = 24577;
let size = runtime_bytecode.len();
debug!("runtime bytecode size: {:#?}", size);
@@ -888,9 +946,9 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
// Encode the constructor args & concatenate with the bytecode if necessary
let data: Bytes = match (abi.constructor(), params.is_none()) {
(None, false) => {
return Err("Constructor arguments provided but no constructor found".into())
return Err(EthError::NoConstructor);
}
(None, true) => bytecode.clone(),
(None, true) => bytecode,
(Some(_), _) => {
let mut data = bytecode.to_vec();
@@ -902,7 +960,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
}
};
Ok(CallBuilder::new_raw_deploy(client.clone(), data))
Ok(CallBuilder::new_raw_deploy(client, data))
}
/// Compiles a solidity verifier contract and returns the abi, bytecode, and runtime bytecode
@@ -911,7 +969,7 @@ pub async fn get_contract_artifacts(
sol_code_path: PathBuf,
contract_name: &str,
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), Box<dyn Error>> {
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
artifacts::{output_selection::OutputSelection, Optimizer},
compilers::CompilerInput,
@@ -919,7 +977,9 @@ pub async fn get_contract_artifacts(
};
if !sol_code_path.exists() {
return Err(format!("file not found: {:#?}", sol_code_path).into());
return Err(EthError::ContractNotFound(
sol_code_path.to_string_lossy().to_string(),
));
}
let settings = SolcSettings {
@@ -946,7 +1006,9 @@ pub async fn get_contract_artifacts(
Some(solc) => solc,
None => {
info!("required solc version is missing ... installing");
Solc::install(&SHANGHAI_SOLC).await?
Solc::install(&SHANGHAI_SOLC)
.await
.map_err(|e| EthError::Svm(e.to_string()))?
}
};
@@ -955,7 +1017,7 @@ pub async fn get_contract_artifacts(
let (abi, bytecode, runtime_bytecode) = match compiled.find(contract_name) {
Some(c) => c.into_parts_or_default(),
None => {
return Err("could not find contract".into());
return Err(EthError::ContractNotFound(contract_name.to_string()));
}
};
@@ -967,13 +1029,13 @@ pub fn fix_da_sol(
input_data: Option<Vec<CallsToAccount>>,
output_data: Option<Vec<CallsToAccount>>,
commitment_bytes: Option<Vec<u8>>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EthError> {
let mut accounts_len = 0;
let mut contract = ATTESTDATA_SOL.to_string();
// fill in the quantization params and total calls
// as constants to the contract to save on gas
if let Some(input_data) = input_data {
if let Some(input_data) = &input_data {
let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum();
accounts_len = input_data.len();
contract = contract.replace(
@@ -981,7 +1043,7 @@ pub fn fix_da_sol(
&format!("uint256 constant INPUT_CALLS = {};", input_calls),
);
}
if let Some(output_data) = output_data {
if let Some(output_data) = &output_data {
let output_calls: usize = output_data.iter().map(|v| v.call_data.len()).sum();
accounts_len += output_data.len();
contract = contract.replace(
@@ -991,8 +1053,9 @@ pub fn fix_da_sol(
}
contract = contract.replace("AccountCall[]", &format!("AccountCall[{}]", accounts_len));
if commitment_bytes.clone().is_some() && !commitment_bytes.clone().unwrap().is_empty() {
let commitment_bytes = commitment_bytes.unwrap();
// The case where a combination of on-chain data source + kzg commit is provided.
if commitment_bytes.is_some() && !commitment_bytes.as_ref().unwrap().is_empty() {
let commitment_bytes = commitment_bytes.as_ref().unwrap();
let hex_string = hex::encode(commitment_bytes);
contract = contract.replace(
"bytes constant COMMITMENT_KZG = hex\"\";",
@@ -1007,5 +1070,44 @@ pub fn fix_da_sol(
);
}
// if both input and output data are None, we only deploy the DataAttestation contract, adding in the verifyWithDataAttestation function
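(see the sketch after fix_da_sol below for the substitution mechanics)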
if input_data.is_none()
&& output_data.is_none()
&& commitment_bytes.as_ref().is_some()
&& !commitment_bytes.as_ref().unwrap().is_empty()
{
contract = contract.replace(
"contract SwapProofCommitments {",
"contract DataAttestation {",
);
// Remove everything past the end of the checkKzgCommits function
if let Some(pos) = contract.find(" } /// end checkKzgCommits") {
contract.truncate(pos);
contract.push('}');
}
// Add the Solidity function below checkKzgCommits
contract.push_str(
r#"
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}"#,
);
}
Ok(contract)
}
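
fix_da_sol works by textual substitution on the bundled Solidity template: constants such as COMMITMENT_KZG are placeholders that get overwritten before compilation. A minimal sketch of that substitution step, using the same hex encoding as above (the helper name is illustrative, not part of the crate):

/// Illustrative helper: splice a KZG commitment into the Solidity template.
fn inject_commitment(template: &str, commitment: &[u8]) -> String {
    template.replace(
        "bytes constant COMMITMENT_KZG = hex\"\";",
        &format!(
            "bytes constant COMMITMENT_KZG = hex\"{}\";",
            hex::encode(commitment)
        ),
    )
}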

View File

@@ -1,7 +1,6 @@
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::CalibrationTarget;
use crate::commands::*;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
@@ -23,6 +22,7 @@ use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
@@ -63,7 +63,6 @@ use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::error::Error;
use std::fs::File;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
@@ -92,12 +91,15 @@ lazy_static! {
}
/// A wrapper for tensor related errors.
/// A wrapper for execution errors
#[derive(Debug, Error)]
pub enum ExecutionError {
/// Shape mismatch in a operation
#[error("verification failed")]
/// verification failed
#[error("verification failed:\n{}", .0.iter().map(|e| e.to_string()).collect::<Vec<_>>().join("\n"))]
VerifyError(Vec<VerifyFailure>),
/// Prover error
#[error("[mock] {0}")]
MockProverError(String),
}
lazy_static::lazy_static! {
@@ -109,7 +111,7 @@ lazy_static::lazy_static! {
}
/// Run an ezkl command with given args
pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
pub async fn run(command: Commands) -> Result<String, EZKLError> {
// set working dir
std::env::set_current_dir(WORKING_DIR.as_path())?;
@@ -123,7 +125,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT)?),
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetSrs {
@@ -161,7 +163,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse()?),
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
max_logrows,
)
.await
@@ -200,7 +202,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
)
.await
}
@@ -265,8 +267,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_AGGREGATED.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
aggregation_settings,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
)
.await
}
@@ -292,7 +294,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path.unwrap_or(DEFAULT_VK.into()),
pk_path.unwrap_or(DEFAULT_PK.into()),
witness,
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
disable_selector_compression
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEvmData {
@@ -345,7 +348,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
Some(proof_path.unwrap_or(DEFAULT_PROOF.into())),
srs_path,
proof_type,
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse().unwrap()),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::MockAggregate {
@@ -354,8 +357,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
split_proofs,
} => mock_aggregate(
aggregation_snarks,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
),
Commands::SetupAggregate {
sample_snarks,
@@ -371,9 +374,10 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
srs_path,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
disable_selector_compression
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
commitment.into(),
),
Commands::Aggregate {
@@ -392,9 +396,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
srs_path,
transcript,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse().unwrap()),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -409,7 +413,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap()),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::VerifyAggr {
@@ -423,8 +427,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
srs_path,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap()),
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -502,6 +506,69 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
)
.await
}
#[cfg(not(feature = "no-update"))]
Commands::Update { version } => update_ezkl_binary(&version).map(|e| e.to_string()),
}
}
#[cfg(not(feature = "no-update"))]
/// Assert that the version is valid
fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
let err_string = "Invalid version string. Must be in the format v0.0.0";
if version.is_empty() {
return Err(err_string.into());
}
// must carry a 'v' prefix (e.g. v11.4.2); semver itself rejects the prefix
if !version.starts_with('v') {
return Err(err_string.into());
}
semver::Version::parse(&version[1..])
.map_err(|_| "Invalid version string. Must be in the format v0.0.0")?;
Ok(())
}
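
assert_version_is_valid leans on the semver crate: the leading 'v' is stripped by hand because semver itself rejects it. A quick check of that behavior (assuming semver 1.x):

fn main() {
    // semver parses bare versions...
    assert!(semver::Version::parse("11.4.2").is_ok());
    // ...but rejects a 'v' prefix, hence the starts_with('v') check above
    // followed by parsing only version[1..].
    assert!(semver::Version::parse("v11.4.2").is_err());
}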
#[cfg(not(feature = "no-update"))]
const INSTALL_BYTES: &[u8] = include_bytes!("../install_ezkl_cli.sh");
#[cfg(not(feature = "no-update"))]
fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
// run the install script with the version
let install_script = std::str::from_utf8(INSTALL_BYTES)?;
// run it as a shell script, passing the version as an argument
// check if bash is installed
let command = if std::process::Command::new("bash")
.arg("--version")
.status()
.is_err()
{
log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
"sh"
} else {
"bash"
};
let mut command = std::process::Command::new(command);
let mut command = command.arg("-c").arg(install_script);
if let Some(version) = version {
assert_version_is_valid(version)?;
command = command.arg(version)
};
let output = command.output()?;
if output.status.success() {
info!("updated binary");
Ok("".to_string())
} else {
Err(format!(
"failed to update binary: {}, {}",
std::str::from_utf8(&output.stdout)?,
std::str::from_utf8(&output.stderr)?
)
.into())
}
}
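
The updater shells out to the bundled install script rather than reimplementing it, preferring bash and degrading to sh. The detection logic boils down to the following standalone sketch:

use std::process::Command;

/// Pick bash when available, otherwise fall back to plain sh.
fn pick_shell() -> &'static str {
    match Command::new("bash").arg("--version").status() {
        Ok(_) => "bash",
        Err(_) => "sh",
    }
}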
@@ -528,7 +595,7 @@ pub(crate) fn gen_srs_cmd(
srs_path: PathBuf,
logrows: u32,
commitment: Commitments,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
match commitment {
Commitments::KZG => {
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -543,7 +610,7 @@ pub(crate) fn gen_srs_cmd(
}
#[cfg(not(target_arch = "wasm32"))]
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, EZKLError> {
let pb = {
let pb = init_spinner();
pb.set_message("Downloading SRS (this may take a while) ...");
@@ -563,7 +630,7 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, EZKLError> {
use std::io::Read;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
@@ -586,7 +653,7 @@ fn check_srs_hash(
logrows: u32,
srs_path: Option<PathBuf>,
commitment: Commitments,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let path = get_srs_path(logrows, srs_path, commitment);
let hash = get_file_hash(&path)?;
@@ -613,7 +680,7 @@ pub(crate) async fn get_srs_cmd(
settings_path: Option<PathBuf>,
logrows: Option<u32>,
commitment: Option<Commitments>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
// logrows overrides settings
let err_string = "You will need to provide a valid settings file to use the settings option. You should run gen-settings to generate a settings file (and calibrate-settings to pick optimal logrows).";
@@ -681,7 +748,7 @@ pub(crate) async fn get_srs_cmd(
Ok(String::new())
}
pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn Error>> {
pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLError> {
let model = Model::from_run_args(&run_args, &model)?;
info!("\n {}", model.table_nodes());
Ok(String::new())
@@ -693,7 +760,7 @@ pub(crate) async fn gen_witness(
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> Result<GraphWitness, Box<dyn Error>> {
) -> Result<GraphWitness, EZKLError> {
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
@@ -794,7 +861,7 @@ pub(crate) fn gen_circuit_settings(
model_path: PathBuf,
params_output: PathBuf,
run_args: RunArgs,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let circuit = GraphCircuit::from_run_args(&run_args, &model_path)?;
let params = circuit.settings();
params.save(&params_output)?;
@@ -862,7 +929,7 @@ impl AccuracyResults {
pub fn new(
mut original_preds: Vec<crate::tensor::Tensor<f32>>,
mut calibrated_preds: Vec<crate::tensor::Tensor<f32>>,
) -> Result<Self, Box<dyn Error>> {
) -> Result<Self, EZKLError> {
let mut errors = vec![];
let mut abs_errors = vec![];
let mut squared_errors = vec![];
@@ -951,7 +1018,7 @@ pub(crate) async fn calibrate(
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
) -> Result<GraphSettings, EZKLError> {
use log::error;
use std::collections::HashMap;
use tabled::Table;
@@ -1323,7 +1390,7 @@ pub(crate) async fn calibrate(
pub(crate) fn mock(
compiled_circuit_path: PathBuf,
data_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
// mock should catch any issues by default so we set it to safe
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
@@ -1340,10 +1407,9 @@ pub(crate) fn mock(
&circuit,
vec![public_inputs],
)
.map_err(Box::<dyn Error>::from)?;
prover
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
prover.verify().map_err(ExecutionError::VerifyError)?;
Ok(String::new())
}
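
mock now maps the two halo2 failure modes differently: prover construction errors are flattened to a string (MockProverError), while constraint failures stay structured so the VerifyError display above can list them line by line. A toy version of that split, with strings standing in for the halo2 types:

/// Toy stand-ins for the two MockProver failure modes.
fn run_mock(
    build: Result<(), String>,       // MockProver::run failure
    verify: Result<(), Vec<String>>, // constraint failures
) -> Result<(), String> {
    build.map_err(|e| format!("[mock] {}", e))?;
    verify.map_err(|es| format!("verification failed:\n{}", es.join("\n")))
}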
@@ -1355,7 +1421,7 @@ pub(crate) async fn create_evm_verifier(
sol_code_path: PathBuf,
abi_path: PathBuf,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
@@ -1399,7 +1465,7 @@ pub(crate) async fn create_evm_vk(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
@@ -1440,7 +1506,7 @@ pub(crate) async fn create_evm_data_attestation(
_abi_path: PathBuf,
_input: PathBuf,
_witness: Option<PathBuf>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
use crate::graph::{DataSource, VarVisibility};
use crate::{graph::Visibility, pfsys::get_proof_commitments};
@@ -1519,7 +1585,7 @@ pub(crate) async fn deploy_da_evm(
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let contract_address = deploy_da_verifier_via_solidity(
settings_path,
data,
@@ -1545,7 +1611,7 @@ pub(crate) async fn deploy_evm(
runs: usize,
private_key: Option<String>,
contract_name: &str,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
@@ -1567,7 +1633,7 @@ pub(crate) fn encode_evm_calldata(
proof_path: PathBuf,
calldata_path: PathBuf,
addr_vk: Option<H160Flag>,
) -> Result<Vec<u8>, Box<dyn Error>> {
) -> Result<Vec<u8>, EZKLError> {
let snark = Snark::load::<IPACommitmentScheme<G1Affine>>(&proof_path)?;
let flattened_instances = snark.instances.into_iter().flatten();
@@ -1595,7 +1661,7 @@ pub(crate) async fn verify_evm(
rpc_url: Option<String>,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
use crate::eth::verify_proof_with_data_attestation;
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
@@ -1637,7 +1703,7 @@ pub(crate) async fn create_evm_aggregate_verifier(
circuit_settings: Vec<PathBuf>,
logrows: u32,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;
@@ -1694,7 +1760,7 @@ pub(crate) fn compile_circuit(
model_path: PathBuf,
compiled_circuit: PathBuf,
settings_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let settings = GraphSettings::load(&settings_path)?;
let circuit = GraphCircuit::from_settings(&settings, &model_path, CheckMode::UNSAFE)?;
circuit.save(compiled_circuit)?;
@@ -1708,7 +1774,7 @@ pub(crate) fn setup(
pk_path: PathBuf,
witness: Option<PathBuf>,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1760,7 +1826,7 @@ pub(crate) async fn setup_test_evm_witness(
rpc_url: Option<String>,
input_source: TestDataSource,
output_source: TestDataSource,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
use crate::graph::TestOnChainData;
let mut data = GraphData::from_path(data_path)?;
@@ -1795,7 +1861,7 @@ pub(crate) async fn test_update_account_calls(
addr: H160Flag,
data: PathBuf,
rpc_url: Option<String>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
use crate::eth::update_account_calls;
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
@@ -1813,7 +1879,7 @@ pub(crate) fn prove(
srs_path: Option<PathBuf>,
proof_type: ProofType,
check_mode: CheckMode,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
let data = GraphWitness::from_path(data_path)?;
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
@@ -1965,7 +2031,7 @@ pub(crate) fn prove(
pub(crate) fn swap_proof_commitments_cmd(
proof_path: PathBuf,
witness: PathBuf,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
let snark = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let witness = GraphWitness::from_path(witness)?;
let commitments = witness.get_polycommitments();
@@ -1984,7 +2050,7 @@ pub(crate) fn mock_aggregate(
aggregation_snarks: Vec<PathBuf>,
logrows: u32,
split_proofs: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let mut snarks = vec![];
for proof_path in aggregation_snarks.iter() {
match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
@@ -2011,10 +2077,8 @@ pub(crate) fn mock_aggregate(
let circuit = AggregationCircuit::new(&G1Affine::generator().into(), snarks, split_proofs)?;
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
.map_err(Box::<dyn Error>::from)?;
prover
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
prover.verify().map_err(ExecutionError::VerifyError)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("Done.");
Ok(String::new())
@@ -2029,7 +2093,7 @@ pub(crate) fn setup_aggregate(
split_proofs: bool,
disable_selector_compression: bool,
commitment: Commitments,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let mut snarks = vec![];
for proof_path in sample_snarks.iter() {
match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
@@ -2092,7 +2156,7 @@ pub(crate) fn aggregate(
check_mode: CheckMode,
split_proofs: bool,
commitment: Commitments,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
let mut snarks = vec![];
for proof_path in aggregation_snarks.iter() {
match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
@@ -2272,7 +2336,7 @@ pub(crate) fn verify(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
reduced_srs: bool,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EZKLError> {
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
@@ -2366,7 +2430,7 @@ fn verify_commitment<
vk_path: PathBuf,
params: &'a Scheme::ParamsVerifier,
logrows: u32,
) -> Result<bool, Box<dyn Error>>
) -> Result<bool, EZKLError>
where
Scheme::Scalar: FromUniformBytes<64>
+ SerdeObject
@@ -2402,7 +2466,7 @@ pub(crate) fn verify_aggr(
logrows: u32,
reduced_srs: bool,
commitment: Commitments,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EZKLError> {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
@@ -2477,11 +2541,11 @@ pub(crate) fn load_params_verifier<Scheme: CommitmentScheme>(
srs_path: Option<PathBuf>,
logrows: u32,
commitment: Commitments,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
) -> Result<Scheme::ParamsVerifier, EZKLError> {
let srs_path = get_srs_path(logrows, srs_path, commitment);
let mut params = load_srs_verifier::<Scheme>(srs_path)?;
info!("downsizing params to {} logrows", logrows);
if logrows < params.k() {
info!("downsizing params to {} logrows", logrows);
params.downsize(logrows);
}
Ok(params)
@@ -2492,11 +2556,11 @@ pub(crate) fn load_params_prover<Scheme: CommitmentScheme>(
srs_path: Option<PathBuf>,
logrows: u32,
commitment: Commitments,
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
) -> Result<Scheme::ParamsProver, EZKLError> {
let srs_path = get_srs_path(logrows, srs_path, commitment);
let mut params = load_srs_prover::<Scheme>(srs_path)?;
info!("downsizing params to {} logrows", logrows);
if logrows < params.k() {
info!("downsizing params to {} logrows", logrows);
params.downsize(logrows);
}
Ok(params)

137
src/graph/errors.rs Normal file
View File

@@ -0,0 +1,137 @@
use std::convert::Infallible;
use thiserror::Error;
/// Graph-related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, String),
/// Unsupported datatype in a graph node
#[error("unsupported datatype in graph node {0} ({1})")]
UnsupportedDataType(usize, String),
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has misformed parameters
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Reading or writing a file failed
#[error("[io] ({0}) {1}")]
ReadWriteFileError(String, String),
/// Model serialization error
#[error("failed to ser/deser model: {0}")]
ModelSerialize(#[from] bincode::Error),
/// Tract error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[tract] {0}")]
TractError(#[from] tract_onnx::tract_core::anyhow::Error),
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,
/// Invalid Input Types
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),
/// Public visibility for params is deprecated
#[error("public visibility for params is deprecated, please use `fixed` instead")]
ParamsPublicVisibility,
/// Slice length mismatch
#[error("slice length mismatch: {0}")]
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
/// Bad conversion
#[error("invalid conversion: {0}")]
InvalidConversion(#[from] Infallible),
/// Circuit error
#[error("[circuit] {0}")]
CircuitError(#[from] crate::circuit::CircuitError),
/// Halo2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
/// System time error
#[error("[system time] {0}")]
SystemTimeError(#[from] std::time::SystemTimeError),
/// Missing Batch Size
#[error("unknown dimension batch_size in model inputs, set batch_size in variables")]
MissingBatchSize,
/// Tokio postgres error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[eth] {0}")]
EthError(#[from] crate::eth::EthError),
/// Json error
#[error("[json] {0}")]
JsonError(#[from] serde_json::Error),
/// Missing instances
#[error("missing instances")]
MissingInstances,
/// Missing constants
#[error("missing constants")]
MissingConstants,
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
/// Range only supports constant inputs
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
/// Trilu only supports constant diagonals
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
/// Insufficient witness values to generate a fixed output
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
#[error("missing scale")]
MissingScale,
/// Extended k is too large
#[error("extended k is too large to accommodate the quotient polynomial with logrows {0}")]
ExtendedKTooLarge(u32),
/// Max lookup input is too large
#[error("lookup range {0} is too large")]
LookupRangeTooLarge(usize),
/// Max range check input is too large
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
/// Cannot use an on-chain data source as private data
#[error("cannot use an on-chain data source 1) as output for an on-chain test, 2) as private data, or 3) as input when using wasm")]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]
MissingDataSource,
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
}
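
Most variants above use #[from], so call sites can rely on ? to convert library errors into GraphError without explicit map_err. A self-contained sketch of that mechanism, with demo types (assumes thiserror and serde_json as dependencies):

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    /// serde_json::Error converts automatically via the generated From impl.
    #[error("[json] {0}")]
    Json(#[from] serde_json::Error),
}

fn parse(s: &str) -> Result<serde_json::Value, DemoError> {
    Ok(serde_json::from_str(s)?) // `?` applies From<serde_json::Error>
}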

View File

@@ -1,5 +1,5 @@
use super::errors::GraphError;
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i64_to_felt;
#[cfg(not(target_arch = "wasm32"))]
@@ -211,9 +211,7 @@ impl PostgresSource {
}
/// Fetch data from postgres
pub async fn fetch(
&self,
) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// clone to move into thread
let user = self.user.clone();
let host = self.host.clone();
@@ -247,9 +245,7 @@ impl PostgresSource {
}
/// Fetch data from postgres and format it as a FileSource
pub async fn fetch_and_format_as_file(
&self,
) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
@@ -279,7 +275,7 @@ impl OnChainSource {
scales: Vec<crate::Scale>,
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
use crate::eth::{
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
};
@@ -455,7 +451,7 @@ impl GraphData {
&self,
shapes: &[Vec<usize>],
datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, Box<dyn std::error::Error>> {
) -> Result<TVec<TValue>, GraphError> {
let mut inputs = TVec::new();
match &self.input_data {
DataSource::File(data) => {
@@ -470,10 +466,10 @@ impl GraphData {
}
}
_ => {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
)))
))
}
}
Ok(inputs)
@@ -488,19 +484,26 @@ impl GraphData {
}
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let reader = std::fs::File::open(path)?;
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let reader = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
reader.read_to_string(&mut buf).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, self)?;
Ok(())
}
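
Both helpers now attach the offending path to io failures via ReadWriteFileError instead of surfacing a bare std::io::Error. The pattern in isolation, with a plain String standing in for the error type:

/// Read a file, tagging any io failure with the path that caused it.
fn read_with_path(path: &std::path::Path) -> Result<String, String> {
    std::fs::read_to_string(path)
        .map_err(|e| format!("[io] ({}) {}", path.display(), e))
}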
@@ -509,7 +512,7 @@ impl GraphData {
pub async fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
) -> Result<Vec<Self>, GraphError> {
// split input data into batches
let mut batched_inputs = vec![];
@@ -522,10 +525,10 @@ impl GraphData {
input_data: DataSource::OnChain(_),
output_data: _,
} => {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"on-chain data cannot be split into batches".to_string(),
)))
))
}
#[cfg(not(target_arch = "wasm32"))]
GraphData {
@@ -539,11 +542,11 @@ impl GraphData {
let input_size = shape.clone().iter().product::<usize>();
let input = &iterable[i];
if input.len() % input_size != 0 {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
)));
));
}
let mut batches = vec![];
for batch in input.chunks(input_size) {

View File

@@ -14,6 +14,9 @@ pub mod utilities;
/// Representations of a computational graph's variables.
pub mod vars;
/// Errors for the graph module
pub mod errors;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
@@ -24,6 +27,7 @@ pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;
use self::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
use self::input::{FileSource, GraphData};
@@ -58,7 +62,6 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
pub use vars::*;
@@ -88,62 +91,6 @@ lazy_static! {
#[cfg(target_arch = "wasm32")]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, String),
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
/// This operation is unsupported
#[error("unsupported datatype in graph")]
UnsupportedDataType,
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node is has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Error when attempting to load a model
#[error("failed to load")]
ModelLoad,
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,
/// Invalid Input Types
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
}
/// Number of blinding factors the circuit assumes
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
@@ -310,30 +257,31 @@ impl GraphWitness {
}
/// Export the ezkl witness as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
Err(e) => return Err(e.into()),
};
Ok(serialized)
}
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load {}", path.display()))?;
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let file = std::fs::File::open(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
@@ -595,11 +543,11 @@ impl GraphSettings {
}
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
return Err(e.into());
}
};
Ok(serialized)
@@ -695,17 +643,21 @@ impl GraphCircuit {
&self.core.model
}
/// Serialize the circuit to `path`
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
/// Deserialize a circuit from `path`
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let f = std::fs::File::open(path)?;
let f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
@@ -770,10 +722,7 @@ pub struct TestOnChainData {
impl GraphCircuit {
/// Create a circuit from a model and run args
pub fn new(
model: Model,
run_args: &RunArgs,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
pub fn new(model: Model, run_args: &RunArgs) -> Result<GraphCircuit, GraphError> {
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Vec<Fp>> = vec![];
for shape in model.graph.input_shapes()? {
@@ -820,7 +769,7 @@ impl GraphCircuit {
model: Model,
mut settings: GraphSettings,
check_mode: CheckMode,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
) -> Result<GraphCircuit, GraphError> {
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Vec<Fp>> = vec![];
for shape in model.graph.input_shapes()? {
@@ -844,20 +793,14 @@ impl GraphCircuit {
}
/// load inputs and outputs for the model
pub fn load_graph_witness(
&mut self,
data: &GraphWitness,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn load_graph_witness(&mut self, data: &GraphWitness) -> Result<(), GraphError> {
self.graph_witness = data.clone();
// load the module settings
Ok(())
}
/// Prepare the public inputs for the circuit.
pub fn prepare_public_inputs(
&self,
data: &GraphWitness,
) -> Result<Vec<Fp>, Box<dyn std::error::Error>> {
pub fn prepare_public_inputs(&self, data: &GraphWitness) -> Result<Vec<Fp>, GraphError> {
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
let mut public_inputs: Vec<Fp> = vec![];
@@ -890,7 +833,7 @@ impl GraphCircuit {
pub fn pretty_public_inputs(
&self,
data: &GraphWitness,
) -> Result<Option<PrettyElements>, Box<dyn std::error::Error>> {
) -> Result<Option<PrettyElements>, GraphError> {
// dequantize the supplied data using the provided scale.
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
@@ -932,10 +875,7 @@ impl GraphCircuit {
///
#[cfg(target_arch = "wasm32")]
pub fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -946,7 +886,7 @@ impl GraphCircuit {
pub fn load_graph_from_file_exclusively(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -956,7 +896,7 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => Err("Cannot use non-file data source as input for this method.".into()),
_ => unreachable!("cannot load from on-chain data"),
}
}
@@ -965,7 +905,7 @@ impl GraphCircuit {
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -983,14 +923,12 @@ impl GraphCircuit {
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::OnChain(_) => {
Err("Cannot use on-chain data source as input for this method.".into())
}
DataSource::OnChain(_) => Err(GraphError::OnChainDataSource),
}
}
@@ -1002,7 +940,7 @@ impl GraphCircuit {
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::OnChain(source) => {
let mut per_item_scale = vec![];
@@ -1030,7 +968,7 @@ impl GraphCircuit {
source: OnChainSource,
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
@@ -1054,7 +992,7 @@ impl GraphCircuit {
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (((d, shape), scale), input_type) in file_data
@@ -1085,7 +1023,7 @@ impl GraphCircuit {
&mut self,
file_data: &[Vec<Fp>],
shapes: &[Vec<usize>],
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (d, shape) in file_data.iter().zip(shapes) {
@@ -1112,7 +1050,7 @@ impl GraphCircuit {
&self,
safe_lookup_range: Range,
max_range_size: i64,
) -> Result<u32, Box<dyn std::error::Error>> {
) -> Result<u32, GraphError> {
// pick the larger of the two absolute sizes: safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
@@ -1133,7 +1071,7 @@ impl GraphCircuit {
max_range_size: i64,
max_logrows: Option<u32>,
lookup_safety_margin: i64,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), GraphError> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
@@ -1142,15 +1080,18 @@ impl GraphCircuit {
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
let lookup_size = (safe_lookup_range.1 - safe_lookup_range.0).abs();
// check if has overflowed max lookup input
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
if lookup_size > MAX_LOOKUP_ABS / lookup_safety_margin {
return Err(GraphError::LookupRangeTooLarge(
lookup_size.unsigned_abs() as usize
));
}
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
return Err(GraphError::RangeCheckTooLarge(
max_range_size.unsigned_abs() as usize,
));
}
// These are hard lower limits, we can't overflow instances or modules constraints
@@ -1194,12 +1135,7 @@ impl GraphCircuit {
}
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
debug!("{}", err_string);
return Err(err_string.into());
return Err(GraphError::ExtendedKTooLarge(max_logrows));
}
let logrows = max_logrows;
@@ -1286,7 +1222,7 @@ impl GraphCircuit {
srs: Option<&Scheme::ParamsProver>,
witness_gen: bool,
check_lookup: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
) -> Result<GraphWitness, GraphError> {
let original_inputs = inputs.to_vec();
let visibility = VarVisibility::from_args(&self.settings().run_args)?;
@@ -1401,7 +1337,7 @@ impl GraphCircuit {
pub fn from_run_args(
run_args: &RunArgs,
model_path: &std::path::Path,
) -> Result<Self, Box<dyn std::error::Error>> {
) -> Result<Self, GraphError> {
let model = Model::from_run_args(run_args, model_path)?;
Self::new(model, run_args)
}
@@ -1412,8 +1348,11 @@ impl GraphCircuit {
params: &GraphSettings,
model_path: &std::path::Path,
check_mode: CheckMode,
) -> Result<Self, Box<dyn std::error::Error>> {
params.run_args.validate()?;
) -> Result<Self, GraphError> {
params
.run_args
.validate()
.map_err(GraphError::InvalidRunArgs)?;
let model = Model::from_run_args(&params.run_args, model_path)?;
Self::new_from_settings(model, params.clone(), check_mode)
}
@@ -1424,7 +1363,7 @@ impl GraphCircuit {
&mut self,
data: &mut GraphData,
test_on_chain_data: TestOnChainData,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), GraphError> {
// Set up local anvil instance for reading on-chain data
let input_scales = self.model().graph.get_input_scales();
@@ -1438,15 +1377,13 @@ impl GraphCircuit {
) {
// if not public then fail
if self.settings().run_args.input_visibility.is_private() {
return Err("Cannot use on-chain data source as private data".into());
return Err(GraphError::OnChainDataSource);
}
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
_ => {
return Err("Cannot use non file source as input for on-chain test.
Manually populate on-chain data from file source instead"
.into())
return Err(GraphError::OnChainDataSource);
}
};
// Get the flatten length of input_data
@@ -1467,19 +1404,13 @@ impl GraphCircuit {
) {
// if not public then fail
if self.settings().run_args.output_visibility.is_private() {
return Err("Cannot use on-chain data source as private data".into());
return Err(GraphError::OnChainDataSource);
}
let output_data = match &data.output_data {
Some(DataSource::File(output_data)) => output_data,
Some(DataSource::OnChain(_)) => {
return Err(
"Cannot use on-chain data source as output for on-chain test.
Will manually populate on-chain data from file source instead"
.into(),
)
}
_ => return Err("No output data found".into()),
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
_ => return Err(GraphError::MissingDataSource),
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
@@ -1522,12 +1453,10 @@ impl CircuitSize {
#[cfg(not(target_arch = "wasm32"))]
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
Err(e) => return Err(e.into()),
};
Ok(serialized)
}

View File

@@ -1,8 +1,8 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
@@ -37,7 +37,6 @@ use std::collections::BTreeMap;
#[cfg(not(target_arch = "wasm32"))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::io::Read;
use std::path::PathBuf;
@@ -396,7 +395,7 @@ impl ParsedNodes {
}
/// Returns shapes of the computational graph's inputs
pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
let mut inputs = vec![];
for input in self.inputs.iter() {
@@ -470,7 +469,7 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `run_args` - [RunArgs]
#[cfg(not(target_arch = "wasm32"))]
pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, Box<dyn Error>> {
pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, GraphError> {
let visibility = VarVisibility::from_args(run_args)?;
let graph = Self::load_onnx_model(reader, run_args, &visibility)?;
@@ -483,20 +482,28 @@ impl Model {
}
///
pub fn save(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
let f = std::fs::File::create(path)?;
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let writer = std::io::BufWriter::new(f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
///
pub fn load(path: PathBuf) -> Result<Self, Box<dyn Error>> {
pub fn load(path: PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let mut f = std::fs::File::open(&path)?;
let metadata = fs::metadata(&path)?;
let mut f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let metadata = fs::metadata(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer)?;
f.read_exact(&mut buffer).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let result = bincode::deserialize(&buffer)?;
Ok(result)
}
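
Model::save and Model::load are a straight bincode round-trip over a buffered file handle. The serialization contract they depend on, in miniature (assuming bincode 1.x and serde derive):

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Demo {
    nodes: u32,
}

fn main() -> Result<(), bincode::Error> {
    let bytes = bincode::serialize(&Demo { nodes: 3 })?;
    let back: Demo = bincode::deserialize(&bytes)?;
    assert_eq!(back, Demo { nodes: 3 });
    Ok(())
}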
@@ -506,7 +513,7 @@ impl Model {
&self,
run_args: &RunArgs,
check_mode: CheckMode,
) -> Result<GraphSettings, Box<dyn Error>> {
) -> Result<GraphSettings, GraphError> {
let instance_shapes = self.instance_shapes()?;
#[cfg(not(target_arch = "wasm32"))]
debug!(
@@ -536,7 +543,7 @@ impl Model {
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.collect::<Result<Vec<_>, GraphError>>()?;
let res = self.dummy_layout(run_args, &inputs, false, false)?;
@@ -583,7 +590,7 @@ impl Model {
run_args: &RunArgs,
witness_gen: bool,
check_lookup: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
) -> Result<ForwardResult, GraphError> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
@@ -601,15 +608,12 @@ impl Model {
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<TractResult, Box<dyn Error>> {
) -> Result<TractResult, GraphError> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| {
error!("Error loading model: {}", e);
GraphError::ModelLoad
})?;
let mut model = tract_onnx::onnx().model_for_read(reader)?;
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(run_args.variables.clone());
@@ -622,7 +626,7 @@ impl Model {
if matches!(x, GenericFactoid::Any) {
let batch_size = match variables.get("batch_size") {
Some(x) => x,
None => return Err("Unknown dimension batch_size in model inputs, set batch_size in variables".into()),
None => return Err(GraphError::MissingBatchSize),
};
fact.shape
.set_dim(i, tract_onnx::prelude::TDim::Val(*batch_size as i64));
@@ -680,12 +684,12 @@ impl Model {
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
visibility: &VarVisibility,
) -> Result<ParsedNodes, Box<dyn Error>> {
) -> Result<ParsedNodes, GraphError> {
let start_time = instant::Instant::now();
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
let scales = VarScales::from_args(run_args)?;
let scales = VarScales::from_args(run_args);
let nodes = Self::nodes_from_graph(
&model,
run_args,
@@ -762,7 +766,7 @@ impl Model {
symbol_values: &SymbolValues,
override_input_scales: Option<Vec<crate::Scale>>,
override_output_scales: Option<HashMap<usize, crate::Scale>>,
) -> Result<BTreeMap<usize, NodeType>, Box<dyn Error>> {
) -> Result<BTreeMap<usize, NodeType>, GraphError> {
use crate::graph::node_output_shapes;
let mut nodes = BTreeMap::<usize, NodeType>::new();
@@ -976,14 +980,14 @@ impl Model {
model_path: &std::path::Path,
data_chunks: &[GraphData],
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Vec<Tensor<f32>>>, Box<dyn Error>> {
) -> Result<Vec<Vec<Tensor<f32>>>, GraphError> {
use tract_onnx::tract_core::internal::IntoArcTensor;
let (model, _) = Model::load_onnx_using_tract(
&mut std::fs::File::open(model_path)
.map_err(|_| format!("failed to load {}", model_path.display()))?,
run_args,
)?;
let mut file = std::fs::File::open(model_path).map_err(|e| {
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
})?;
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
let datum_types: Vec<DatumType> = model
.input_outlets()?
@@ -1011,15 +1015,11 @@ impl Model {
/// # Arguments
/// * `params` - A [GraphSettings] struct holding parsed CLI arguments.
#[cfg(not(target_arch = "wasm32"))]
pub fn from_run_args(
run_args: &RunArgs,
model: &std::path::Path,
) -> Result<Self, Box<dyn Error>> {
Model::new(
&mut std::fs::File::open(model)
.map_err(|_| format!("failed to load {}", model.display()))?,
run_args,
)
pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result<Self, GraphError> {
let mut file = std::fs::File::open(model).map_err(|e| {
GraphError::ReadWriteFileError(model.display().to_string(), e.to_string())
})?;
Model::new(&mut file, run_args)
}
/// Configures a model for the circuit
@@ -1031,7 +1031,7 @@ impl Model {
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
settings: &GraphSettings,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
) -> Result<PolyConfig<Fp>, GraphError> {
debug!("configuring model");
let lookup_range = settings.run_args.lookup_range;
@@ -1093,7 +1093,7 @@ impl Model {
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
info!("model layout...");
let start_time = instant::Instant::now();
@@ -1103,7 +1103,11 @@ impl Model {
let input_shapes = self.graph.input_shapes()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
if self.visibility.input.is_public() {
let instance = vars.instance.as_ref().ok_or("no instance")?.clone();
let instance = vars
.instance
.as_ref()
.ok_or(GraphError::MissingInstances)?
.clone();
results.insert(*input_idx, vec![instance]);
vars.increment_instance_idx();
} else {
@@ -1123,7 +1127,12 @@ impl Model {
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
let mut thread_safe_region = RegionCtx::new_with_constants(
region,
0,
run_args.num_inner_cols,
original_constants.clone(),
);
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1147,24 +1156,31 @@ impl Model {
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars.instance.as_ref().ok_or("no instance")?.clone();
let res = vars
.instance
.as_ref()
.ok_or(GraphError::MissingInstances)?
.clone();
vars.increment_instance_idx();
res
} else {
// error if witnessed_outputs has no element at index i
if witnessed_outputs.len() <= i {
return Err("you provided insufficient witness values to generate a fixed output".into());
return Err(GraphError::InsufficientWitnessValues);
}
witnessed_outputs[i].clone()
};
config.base.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::RangeCheck(tolerance)),
)
config
.base
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::RangeCheck(tolerance)),
)
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>,_>>();
.collect::<Result<Vec<_>, GraphError>>();
res.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
@@ -1178,7 +1194,6 @@ impl Model {
Ok(outputs)
},
)?;
let duration = start_time.elapsed();
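
The res.map_err above is the one boundary where the new typed errors must be erased again: halo2's assign_region closure is pinned to plonk::Error, so the GraphError gets logged and collapsed to Error::Synthesis. A sketch of that collapse with stand-in types (not halo2's real API):

#[derive(Debug)]
enum PlonkError {
    Synthesis,
}

#[derive(Debug)]
enum GraphError {
    InsufficientWitnessValues,
}

fn into_synthesis(res: Result<Vec<u64>, GraphError>) -> Result<Vec<u64>, PlonkError> {
    res.map_err(|e| {
        eprintln!("{:?}", e); // keep the detail in the log...
        PlonkError::Synthesis // ...then return what the closure signature allows
    })
}

fn main() {
    println!("{:?}", into_synthesis(Err(GraphError::InsufficientWitnessValues)));
}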
@@ -1192,7 +1207,7 @@ impl Model {
config: &mut ModelConfig,
region: &mut RegionCtx<Fp>,
results: &mut BTreeMap<usize, Vec<ValTensor<Fp>>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
// index over results to get original inputs
let orig_inputs: BTreeMap<usize, _> = results
.clone()
@@ -1237,7 +1252,10 @@ impl Model {
let res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node.opkind.get_mutable_constant().ok_or("no constant")?;
let c = node
.opkind
.get_mutable_constant()
.ok_or(GraphError::MissingConstants)?;
Some(c.quantized_values.clone().try_into()?)
} else {
config
@@ -1394,7 +1412,7 @@ impl Model {
inputs: &[ValTensor<Fp>],
witness_gen: bool,
check_lookup: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
) -> Result<DummyPassRes, GraphError> {
debug!("calculating num of constraints using dummy model layout...");
let start_time = instant::Instant::now();
@@ -1549,7 +1567,7 @@ impl Model {
}
/// Shapes of the computational graph's public inputs (if any)
pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
let mut instance_shapes = vec![];
if self.visibility.input.is_public() {
instance_shapes.extend(self.graph.input_shapes()?);


@@ -11,6 +11,7 @@ use halo2curves::bn256::{Fr as Fp, G1Affine};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
@@ -295,7 +296,7 @@ impl GraphModules {
element_visibility: &Visibility,
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
) -> Result<ModuleForwardResult, GraphError> {
let mut poseidon_hash = None;
let mut polycommit = None;


@@ -8,11 +8,14 @@ use super::Visibility;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::circuit::CircuitError;
use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +25,6 @@ use serde::Deserialize;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))]
use std::collections::BTreeMap;
use std::error::Error;
#[cfg(not(target_arch = "wasm32"))]
use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
@@ -65,7 +67,7 @@ impl Op<Fp> for Rescaled {
format!("RESCALED INPUT ({})", self.inner.as_string())
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let in_scales = in_scales
.into_iter()
.zip(self.scale.iter())
@@ -80,11 +82,9 @@ impl Op<Fp> for Rescaled {
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(Box::new(TensorError::DimMismatch(
"rescaled inputs".to_string(),
)));
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
@@ -210,7 +210,7 @@ impl Op<Fp> for RebaseScale {
)
}
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.target_scale)
}
@@ -219,11 +219,11 @@ impl Op<Fp> for RebaseScale {
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or("no inner layout")?;
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[original_res])
}
@@ -306,7 +306,7 @@ impl SupportedOp {
fn homogenous_rescale(
&self,
in_scales: Vec<crate::Scale>,
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let inputs_to_scale = self.requires_homogenous_input_scales();
// creates a rescaled op if the inputs are not homogenous
let op = self.clone_dyn();
@@ -372,7 +372,7 @@ impl Op<Fp> for SupportedOp {
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}
@@ -400,7 +400,7 @@ impl Op<Fp> for SupportedOp {
self
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
self.as_op().out_scale(in_scales)
}
}
@@ -478,7 +478,7 @@ impl Node {
symbol_values: &SymbolValues,
div_rebasing: bool,
rebase_frac_zero_constants: bool,
) -> Result<Self, Box<dyn Error>> {
) -> Result<Self, GraphError> {
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
@@ -504,10 +504,15 @@ impl Node {
input_ids
.iter()
.map(|(i, _)| {
inputs.push(other_nodes.get(i).ok_or("input not found")?.clone());
inputs.push(
other_nodes
.get(i)
.ok_or(GraphError::MissingInput(idx))?
.clone(),
);
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.collect::<Result<Vec<_>, GraphError>>()?;
let (mut opkind, deleted_indices) = new_op_from_onnx(
idx,
@@ -544,10 +549,10 @@ impl Node {
let idx = inputs
.iter()
.position(|x| *idx == x.idx())
.ok_or("input not found")?;
.ok_or(GraphError::MissingInput(*idx))?;
Ok(inputs[idx].out_scales()[*outlet])
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.collect::<Result<Vec<_>, GraphError>>()?;
let homogenous_inputs = opkind.requires_homogenous_input_scales();
// automatically increases a constant's scale if it is only used once and
@@ -558,7 +563,7 @@ impl Node {
if inputs.len() > input {
let input_node = other_nodes
.get_mut(&inputs[input].idx())
.ok_or("input not found")?;
.ok_or(GraphError::MissingInput(idx))?;
let input_opkind = &mut input_node.opkind();
if let Some(constant) = input_opkind.get_mutable_constant() {
rescale_const_with_single_use(
@@ -615,10 +620,10 @@ fn rescale_const_with_single_use(
in_scales: Vec<crate::Scale>,
param_visibility: &Visibility,
num_uses: usize,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), GraphError> {
if num_uses == 1 {
let current_scale = constant.out_scale(vec![])?;
let scale_max = in_scales.iter().max().ok_or("no scales")?;
let scale_max = in_scales.iter().max().ok_or(GraphError::MissingScale)?;
if scale_max > &current_scale {
let raw_values = constant.raw_values.clone();
constant.quantized_values =


@@ -1,5 +1,4 @@
#[cfg(not(target_arch = "wasm32"))]
use super::GraphError;
use super::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use super::VarScales;
use super::{Rescaled, SupportedOp, Visibility};
@@ -16,7 +15,6 @@ use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, warn};
use std::error::Error;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
@@ -92,7 +90,7 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
pub fn node_output_shapes(
node: &OnnxNode<TypedFact, Box<dyn TypedOp>>,
symbol_values: &SymbolValues,
) -> Result<Vec<Vec<usize>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Vec<usize>>, GraphError> {
let mut shapes = Vec::new();
let outputs = node.outputs.to_vec();
for output in outputs {
@@ -109,7 +107,7 @@ use tract_onnx::prelude::SymbolValues;
/// Extracts the raw values from a tensor.
pub fn extract_tensor_value(
input: Arc<tract_onnx::prelude::Tensor>,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
) -> Result<Tensor<f32>, GraphError> {
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
let dt = input.datum_type();
@@ -194,20 +192,20 @@ pub fn extract_tensor_value(
// Generally a shape or hyperparam
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
let cast: Result<Vec<f32>, &str> = vec
let cast: Result<Vec<f32>, GraphError> = vec
.par_iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err("could not evaluate tdim"),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
},
})
.collect();
const_value = Tensor::<f32>::new(Some(&cast?), &dims)?;
}
_ => return Err("unsupported data type".into()),
_ => return Err(GraphError::UnsupportedDataType(0, format!("{:?}", dt))),
}
const_value.reshape(&dims)?;
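
The TDim branch above leans on collect's Result impl: annotating the target as Result<Vec<_>, GraphError> makes the whole collection short-circuit on the first Err. The same idiom in miniature:

fn parse_all(xs: &[&str]) -> Result<Vec<i64>, String> {
    xs.iter()
        .map(|s| s.parse::<i64>().map_err(|e| format!("{}: {}", s, e)))
        .collect() // first Err aborts and becomes the function's Err
}

fn main() {
    assert_eq!(parse_all(&["1", "2"]), Ok(vec![1, 2]));
    assert!(parse_all(&["1", "x"]).is_err());
}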
@@ -219,12 +217,12 @@ fn load_op<C: tract_onnx::prelude::Op + Clone>(
op: &dyn tract_onnx::prelude::Op,
idx: usize,
name: String,
) -> Result<C, Box<dyn std::error::Error>> {
) -> Result<C, GraphError> {
// Extract the slope layer hyperparams
let op: &C = match op.downcast_ref::<C>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, name)));
return Err(GraphError::OpMismatch(idx, name));
}
};
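
load_op's new shape, downcast or return OpMismatch with the node index and name, generalizes cleanly. A sketch using std::any::Any in place of tract's Op trait (an assumption made purely for illustration):

use std::any::Any;

#[derive(Debug)]
enum GraphError {
    OpMismatch(usize, String),
}

fn load_op<C: 'static>(op: &dyn Any, idx: usize, name: String) -> Result<&C, GraphError> {
    op.downcast_ref::<C>()
        .ok_or(GraphError::OpMismatch(idx, name))
}

fn main() {
    let op: Box<dyn Any> = Box::new(42_u32);
    // wrong concrete type: a typed error instead of a boxed string
    println!("{:?}", load_op::<String>(op.as_ref(), 7, "einsum".into()).err());
}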
@@ -247,7 +245,7 @@ pub fn new_op_from_onnx(
inputs: &mut [super::NodeType],
symbol_values: &SymbolValues,
rebase_frac_zero_constants: bool,
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
@@ -260,7 +258,7 @@ pub fn new_op_from_onnx(
let mut replace_const = |scale: crate::Scale,
index: usize,
default_op: SupportedOp|
-> Result<SupportedOp, Box<dyn std::error::Error>> {
-> Result<SupportedOp, GraphError> {
let mut constant = inputs[index].opkind();
let constant = constant.get_mutable_constant();
if let Some(c) = constant {
@@ -285,19 +283,13 @@ pub fn new_op_from_onnx(
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"shift left".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
}
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
idx,
"ShiftLeft".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
}
}
"ShiftRight" => {
@@ -307,19 +299,13 @@ pub fn new_op_from_onnx(
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"shift right".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
}
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
idx,
"ShiftRight".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
}
}
"MultiBroadcastTo" => {
@@ -337,7 +323,7 @@ pub fn new_op_from_onnx(
for (i, input) in inputs.iter_mut().enumerate() {
if !input.opkind().is_constant() {
return Err("Range only supports constant inputs in a zk circuit".into());
return Err(GraphError::NonConstantRange);
} else {
input.decrement_use();
deleted_indices.push(i);
@@ -348,7 +334,7 @@ pub fn new_op_from_onnx(
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
let input_ops = input_ops
.iter()
.map(|x| x.get_constant().ok_or("Range requires constant inputs"))
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
.collect::<Result<Vec<_>, _>>()?;
let start = input_ops[0].raw_values.map(|x| x as usize)[0];
@@ -375,11 +361,11 @@ pub fn new_op_from_onnx(
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "trilu".to_string())));
return Err(GraphError::InvalidDims(idx, "trilu".to_string()));
}
raw_values[0] as i32
} else {
return Err("we only support constant inputs for trilu diagonal".into());
return Err(GraphError::NonConstantTrilu);
};
SupportedOp::Linear(PolyOp::Trilu { upper, k: diagonal })
@@ -387,7 +373,7 @@ pub fn new_op_from_onnx(
"Gather" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(idx, "gather".to_string())));
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
};
let op = load_op::<Gather>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
@@ -456,10 +442,7 @@ pub fn new_op_from_onnx(
}
"ScatterElements" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter elements".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
};
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
@@ -494,10 +477,7 @@ pub fn new_op_from_onnx(
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
@@ -529,10 +509,7 @@ pub fn new_op_from_onnx(
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
@@ -566,10 +543,7 @@ pub fn new_op_from_onnx(
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather elements".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
};
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
@@ -615,10 +589,7 @@ pub fn new_op_from_onnx(
}
_ => {
return Err(Box::new(GraphError::OpMismatch(
idx,
"MoveAxis".to_string(),
)))
return Err(GraphError::OpMismatch(idx, "MoveAxis".to_string()));
}
}
}
@@ -654,7 +625,9 @@ pub fn new_op_from_onnx(
| DatumType::U32
| DatumType::U64 => 0,
DatumType::F16 | DatumType::F32 | DatumType::F64 => scales.params,
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
_ => {
return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt)));
}
};
// if all raw_values are round then set scale to 0
@@ -672,7 +645,7 @@ pub fn new_op_from_onnx(
}
"Reduce<ArgMax(false)>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "argmax".to_string())));
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
@@ -682,7 +655,7 @@ pub fn new_op_from_onnx(
}
"Reduce<ArgMin(false)>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "argmin".to_string())));
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
@@ -692,7 +665,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Min>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
return Err(GraphError::InvalidDims(idx, "min".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -701,7 +674,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Max>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
return Err(GraphError::InvalidDims(idx, "max".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -710,7 +683,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Prod>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "prod".to_string())));
return Err(GraphError::InvalidDims(idx, "prod".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
@@ -727,7 +700,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Sum>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "sum".to_string())));
return Err(GraphError::InvalidDims(idx, "sum".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -736,10 +709,7 @@ pub fn new_op_from_onnx(
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "mean of squares".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -759,7 +729,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();
if const_inputs.len() != 1 {
return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string())));
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
}
let const_idx = const_inputs[0];
@@ -768,10 +738,10 @@ pub fn new_op_from_onnx(
if c.len() == 1 {
c[0]
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
} else {
return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string())));
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
};
if inputs.len() == 2 {
@@ -790,7 +760,7 @@ pub fn new_op_from_onnx(
})
}
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
}
"Min" => {
@@ -805,7 +775,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();
if const_inputs.len() != 1 {
return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string())));
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
}
let const_idx = const_inputs[0];
@@ -814,10 +784,10 @@ pub fn new_op_from_onnx(
if c.len() == 1 {
c[0]
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
return Err(GraphError::InvalidDims(idx, "min".to_string()));
}
} else {
return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string())));
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
};
if inputs.len() == 2 {
@@ -834,7 +804,7 @@ pub fn new_op_from_onnx(
a: crate::circuit::utils::F32(unit),
})
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
return Err(GraphError::InvalidDims(idx, "min".to_string()));
}
}
"Recip" => {
@@ -855,10 +825,7 @@ pub fn new_op_from_onnx(
let leaky_op: &LeakyRelu = match leaky_op.0.downcast_ref::<LeakyRelu>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(
idx,
"leaky relu".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "leaky relu".to_string()));
}
};
@@ -867,7 +834,7 @@ pub fn new_op_from_onnx(
})
}
"Scan" => {
return Err("scan should never be analyzed explicitly".into());
unreachable!();
}
"QuantizeLinearU8" | "DequantizeLinearF32" => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
@@ -932,7 +899,9 @@ pub fn new_op_from_onnx(
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Source" => {
let (scale, datum_type) = match node.outputs[0].fact.datum_type {
let dt = node.outputs[0].fact.datum_type;
let (scale, datum_type) = match dt {
DatumType::Bool => (0, InputType::Bool),
DatumType::TDim => (0, InputType::TDim),
DatumType::I64
@@ -946,7 +915,7 @@ pub fn new_op_from_onnx(
DatumType::F16 => (scales.input, InputType::F16),
DatumType::F32 => (scales.input, InputType::F32),
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
}
@@ -985,7 +954,7 @@ pub fn new_op_from_onnx(
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
}
}
"Add" => SupportedOp::Linear(PolyOp::Add),
@@ -1001,7 +970,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();
if const_idx.len() > 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "mul".to_string())));
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}
if const_idx.len() == 1 {
@@ -1027,17 +996,14 @@ pub fn new_op_from_onnx(
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "less".to_string())));
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"LessEqual" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(Box::new(GraphError::InvalidDims(
idx,
"less equal".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
"Greater" => {
@@ -1045,10 +1011,7 @@ pub fn new_op_from_onnx(
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
} else {
return Err(Box::new(GraphError::InvalidDims(
idx,
"greater".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
"GreaterEqual" => {
@@ -1056,10 +1019,7 @@ pub fn new_op_from_onnx(
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
} else {
return Err(Box::new(GraphError::InvalidDims(
idx,
"greater equal".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "greater equal".to_string()));
}
}
"EinSum" => {
@@ -1067,7 +1027,7 @@ pub fn new_op_from_onnx(
let op: &EinSum = match node.op().downcast_ref::<EinSum>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "einsum".to_string())));
return Err(GraphError::OpMismatch(idx, "einsum".to_string()));
}
};
@@ -1081,7 +1041,7 @@ pub fn new_op_from_onnx(
let softmax_op: &Softmax = match node.op().downcast_ref::<Softmax>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "softmax".to_string())));
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};
@@ -1100,7 +1060,7 @@ pub fn new_op_from_onnx(
let sumpool_node: &MaxPool = match op.downcast_ref() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "Maxpool".to_string())));
return Err(GraphError::OpMismatch(idx, "Maxpool".to_string()));
}
};
@@ -1108,9 +1068,9 @@ pub fn new_op_from_onnx(
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(Box::new(GraphError::MissingParams(
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
)));
));
}
let stride = pool_spec
@@ -1122,7 +1082,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let kernel_shape = &pool_spec.kernel_shape;
@@ -1170,15 +1130,15 @@ pub fn new_op_from_onnx(
let conv_node: &Conv = match node.op().downcast_ref::<Conv>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "conv".to_string())));
return Err(GraphError::OpMismatch(idx, "conv".to_string()));
}
};
if let Some(dilations) = &conv_node.pool_spec.dilations {
if dilations.iter().any(|x| *x != 1) {
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"non unit dilations not supported".to_string(),
)));
));
}
}
@@ -1186,15 +1146,15 @@ pub fn new_op_from_onnx(
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
)));
));
}
let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
return Err(GraphError::MissingParams("strides".to_string()));
}
};
@@ -1203,7 +1163,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};
@@ -1234,30 +1194,30 @@ pub fn new_op_from_onnx(
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "deconv".to_string())));
return Err(GraphError::OpMismatch(idx, "deconv".to_string()));
}
};
if let Some(dilations) = &deconv_node.pool_spec.dilations {
if dilations.iter().any(|x| *x != 1) {
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"non unit dilations not supported".to_string(),
)));
));
}
}
if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
)));
));
}
let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let padding = match &deconv_node.pool_spec.padding {
@@ -1265,7 +1225,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};
@@ -1295,10 +1255,7 @@ pub fn new_op_from_onnx(
let downsample_node: Downsample = match node.op().downcast_ref::<Downsample>() {
Some(b) => b.clone(),
None => {
return Err(Box::new(GraphError::OpMismatch(
idx,
"downsample".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "downsample".to_string()));
}
};
@@ -1323,7 +1280,7 @@ pub fn new_op_from_onnx(
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string())));
return Err(GraphError::OpMismatch(idx, "Resize".to_string()));
}
let scale_factor_node = // find optional_scales_input in the string and extract the value inside the Some
@@ -1337,7 +1294,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>()[1]
.split(')')
.collect::<Vec<_>>()[0]
.parse::<usize>()?)
.parse::<usize>().map_err(|_| GraphError::OpMismatch(idx, "Resize".to_string()))?)
};
let scale_factor = if let Some(scale_factor_node) = scale_factor_node {
@@ -1345,7 +1302,7 @@ pub fn new_op_from_onnx(
if let Some(c) = extract_const_raw_values(boxed_op) {
c.map(|x| x as usize).into_iter().collect::<Vec<usize>>()
} else {
return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string())));
return Err(GraphError::OpMismatch(idx, "Resize".to_string()));
}
} else {
// default
@@ -1369,7 +1326,7 @@ pub fn new_op_from_onnx(
let sumpool_node: &SumPool = match op.downcast_ref() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "sumpool".to_string())));
return Err(GraphError::OpMismatch(idx, "sumpool".to_string()));
}
};
@@ -1377,9 +1334,9 @@ pub fn new_op_from_onnx(
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(Box::new(GraphError::MissingParams(
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
)));
));
}
let stride = pool_spec
@@ -1391,7 +1348,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};
@@ -1411,7 +1368,7 @@ pub fn new_op_from_onnx(
let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "pad".to_string())));
return Err(GraphError::OpMismatch(idx, "pad".to_string()));
}
};
// we only support constant 0 padding
@@ -1420,9 +1377,9 @@ pub fn new_op_from_onnx(
tract_onnx::prelude::Tensor::zero::<f32>(&[])?,
))
{
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"pad mode or pad type".to_string(),
)));
));
}
SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
@@ -1473,7 +1430,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
const_value: Tensor<f32>,
scale: crate::Scale,
visibility: &Visibility,
) -> Result<Tensor<F>, Box<dyn std::error::Error>> {
) -> Result<Tensor<F>, TensorError> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
&(x).into(),
@@ -1492,7 +1449,7 @@ use crate::tensor::ValTensor;
pub(crate) fn split_valtensor(
values: &ValTensor<Fp>,
shapes: Vec<Vec<usize>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
let mut tensors: Vec<ValTensor<Fp>> = Vec::new();
let mut start = 0;
for shape in shapes {
@@ -1510,7 +1467,7 @@ pub fn homogenize_input_scales(
op: Box<dyn Op<Fp>>,
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = input_scales
.clone()
.into_iter()
@@ -1529,7 +1486,7 @@ pub fn homogenize_input_scales(
let mut multipliers: Vec<u128> = vec![1; input_scales.len()];
let max_scale = input_scales.iter().max().ok_or("no max scale")?;
let max_scale = input_scales.iter().max().ok_or(GraphError::MissingScale)?;
let _ = input_scales
.iter()
.enumerate()


@@ -1,4 +1,3 @@
use std::error::Error;
use std::fmt::Display;
use crate::tensor::TensorType;
@@ -17,6 +16,8 @@ use pyo3::{
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use self::errors::GraphError;
use super::*;
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
@@ -261,12 +262,12 @@ impl VarScales {
}
/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
Ok(Self {
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
params: args.param_scale,
rebase_multiplier: args.scale_rebase_multiplier,
})
}
}
}
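
VarScales::from_args dropping its Result is the inverse cleanup: once no field extraction can fail, an infallible constructor removes a ? from every call site (visible in the Model::new hunk earlier). The shape of the change, reduced, with field types assumed:

struct RunArgs {
    input_scale: i32,
    param_scale: i32,
    scale_rebase_multiplier: u32,
}

struct VarScales {
    input: i32,
    params: i32,
    rebase_multiplier: u32,
}

impl VarScales {
    fn from_args(a: &RunArgs) -> Self {
        Self {
            input: a.input_scale,
            params: a.param_scale,
            rebase_multiplier: a.scale_rebase_multiplier,
        }
    }
}

fn main() {
    let s = VarScales::from_args(&RunArgs {
        input_scale: 7,
        param_scale: 7,
        scale_rebase_multiplier: 1,
    });
    println!("{} {} {}", s.input, s.params, s.rebase_multiplier);
}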
@@ -303,15 +304,13 @@ impl Default for VarVisibility {
impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
let output_vis = &args.output_visibility;
if params_vis.is_public() {
return Err(
"public visibility for params is deprecated, please use `fixed` instead".into(),
);
return Err(GraphError::ParamsPublicVisibility);
}
if !output_vis.is_public()
@@ -327,7 +326,7 @@ impl VarVisibility {
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
{
return Err(Box::new(GraphError::Visibility));
return Err(GraphError::Visibility);
}
Ok(Self {
input: input_vis.clone(),

View File

@@ -28,6 +28,59 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
/// Error type
#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum EZKLError {
#[error("[aggregation] {0}")]
AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError),
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[eth] {0}")]
EthError(#[from] eth::EthError),
#[error("[graph] {0}")]
GraphError(#[from] graph::errors::GraphError),
#[error("[pfsys] {0}")]
PfsysError(#[from] pfsys::errors::PfsysError),
#[error("[circuit] {0}")]
CircuitError(#[from] circuit::errors::CircuitError),
#[error("[tensor] {0}")]
TensorError(#[from] tensor::errors::TensorError),
#[error("[module] {0}")]
ModuleError(#[from] circuit::modules::errors::ModuleError),
#[error("[io] {0}")]
IoError(#[from] std::io::Error),
#[error("[json] {0}")]
JsonError(#[from] serde_json::Error),
#[error("[utf8] {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[reqwest] {0}")]
ReqwestError(#[from] reqwest::Error),
#[error("[fmt] {0}")]
FmtError(#[from] std::fmt::Error),
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
#[error("[Uncategorized] {0}")]
UncategorizedError(String),
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[execute] {0}")]
ExecutionError(#[from] execute::ExecutionError),
#[error("[srs] {0}")]
SrsError(#[from] pfsys::srs::SrsError),
}
impl From<&str> for EZKLError {
fn from(s: &str) -> Self {
EZKLError::UncategorizedError(s.to_string())
}
}
impl From<String> for EZKLError {
fn from(s: String) -> Self {
EZKLError::UncategorizedError(s)
}
}
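
The two From impls are the migration escape hatch: any "...".into() or String error still left in the codebase lands in UncategorizedError rather than failing to compile. Reduced to its essentials with a stand-in enum:

#[derive(Debug)]
enum EZKLError {
    UncategorizedError(String),
}

impl From<&str> for EZKLError {
    fn from(s: &str) -> Self {
        EZKLError::UncategorizedError(s.to_string())
    }
}

fn legacy_call_site(ok: bool) -> Result<(), EZKLError> {
    if !ok {
        // pre-refactor string errors keep compiling, now typed
        return Err("unsupported configuration".into());
    }
    Ok(())
}

fn main() {
    println!("{:?}", legacy_call_site(false));
}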
use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
@@ -248,7 +301,7 @@ impl Default for RunArgs {
impl RunArgs {
///
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
pub fn validate(&self) -> Result<(), String> {
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"

src/pfsys/errors.rs (new file)

@@ -0,0 +1,27 @@
use thiserror::Error;
/// Error type for the pfsys module
#[derive(Error, Debug)]
pub enum PfsysError {
/// Failed to save the proof
#[error("failed to save the proof: {0}")]
SaveProof(String),
/// Failed to load the proof
#[error("failed to load the proof: {0}")]
LoadProof(String),
/// Halo2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
/// Failed to write point to transcript
#[error("failed to write point to transcript: {0}")]
WritePoint(String),
/// Invalid commitment scheme
#[error("invalid commitment scheme")]
InvalidCommitmentScheme,
/// Failed to load vk from file
#[error("failed to load vk from file: {0}")]
LoadVk(String),
/// Failed to load pk from file
#[error("failed to load pk from file: {0}")]
LoadPk(String),
}
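
Everything in the new pfsys/errors.rs falls out of two thiserror features: #[error("...")] derives Display, and #[from] derives the From impl that lets ? auto-convert. A runnable sketch of both (assumes a thiserror dependency; names are illustrative):

use thiserror::Error;

#[derive(Error, Debug)]
enum DemoError {
    #[error("failed to load the proof: {0}")]
    LoadProof(String),
    #[error("[io] {0}")]
    Io(#[from] std::io::Error),
}

fn open(path: &str) -> Result<std::fs::File, DemoError> {
    Ok(std::fs::File::open(path)?) // io::Error -> DemoError via #[from]
}

fn main() {
    if let Err(e) = open("missing.bin") {
        println!("{}", e); // Display text comes from the #[error(...)] attribute
    }
}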


@@ -10,17 +10,14 @@ pub enum EvmVerificationError {
#[error("Solidity verifier found the proof invalid")]
InvalidProof,
/// If the Solidity verifier threw an error (e.g. OutOfGas)
#[error("Execution of Solidity code failed")]
SolidityExecution,
/// EVM execution errors
#[error("EVM execution of raw code failed")]
RawExecution,
#[error("Execution of Solidity code failed: {0}")]
SolidityExecution(String),
/// EVM verify errors
#[error("evm verification reverted")]
Reverted,
#[error("evm verification reverted: {0}")]
Reverted(String),
/// EVM verify errors
#[error("evm deployment failed")]
Deploy,
#[error("evm deployment failed: {0}")]
DeploymentFailed(String),
/// Invalid Visibility
#[error("Invalid visibility")]
InvalidVisibility,


@@ -4,6 +4,11 @@ pub mod evm;
/// SRS generation, processing, verification and downloading
pub mod srs;
/// errors related to pfsys
pub mod errors;
pub use errors::PfsysError;
use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
@@ -32,7 +37,6 @@ use serde::{Deserialize, Serialize};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::verifier::plonk::PlonkProtocol;
use std::error::Error;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
@@ -364,24 +368,28 @@ where
}
/// Saves the Proof to a specified `proof_path`.
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
let file = std::fs::File::create(proof_path)?;
pub fn save(&self, proof_path: &PathBuf) -> Result<(), PfsysError> {
let file = std::fs::File::create(proof_path)
.map_err(|e| PfsysError::SaveProof(format!("{}", e)))?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(&mut writer, &self)?;
serde_json::to_writer(&mut writer, &self)
.map_err(|e| PfsysError::SaveProof(format!("{}", e)))?;
Ok(())
}
/// Load a json serialized proof from the provided path.
pub fn load<Scheme: CommitmentScheme<Curve = C, Scalar = F>>(
proof_path: &PathBuf,
) -> Result<Self, Box<dyn Error>>
) -> Result<Self, PfsysError>
where
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
{
trace!("reading proof");
let file = std::fs::File::open(proof_path)?;
let file =
std::fs::File::open(proof_path).map_err(|e| PfsysError::LoadProof(format!("{}", e)))?;
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let proof: Self = serde_json::from_reader(reader)?;
let proof: Self =
serde_json::from_reader(reader).map_err(|e| PfsysError::LoadProof(format!("{}", e)))?;
Ok(proof)
}
}
@@ -541,7 +549,7 @@ pub fn create_proof_circuit<
transcript_type: TranscriptType,
split: Option<ProofSplitCommit>,
protocol: Option<PlonkProtocol<Scheme::Curve>>,
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
where
Scheme::ParamsVerifier: 'params,
Scheme::Scalar: Serialize
@@ -626,7 +634,7 @@ pub fn swap_proof_commitments<
>(
snark: &Snark<Scheme::Scalar, Scheme::Curve>,
commitments: &[Scheme::Curve],
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
where
Scheme::Scalar: SerdeObject
+ PrimeField
@@ -654,7 +662,7 @@ pub fn get_proof_commitments<
TW: TranscriptWriterBuffer<Vec<u8>, Scheme::Curve, E>,
>(
commitments: &[Scheme::Curve],
) -> Result<Vec<u8>, Box<dyn Error>>
) -> Result<Vec<u8>, PfsysError>
where
Scheme::Scalar: SerdeObject
+ PrimeField
@@ -671,7 +679,7 @@ where
for commit in commitments {
transcript_new
.write_point(*commit)
.map_err(|_| "failed to write point")?;
.map_err(|e| PfsysError::WritePoint(format!("{}", e)))?;
}
let proof_first_bytes = transcript_new.finalize();
@@ -687,7 +695,7 @@ where
pub fn swap_proof_commitments_polycommit(
snark: &Snark<Fr, G1Affine>,
commitments: &[G1Affine],
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, PfsysError> {
let proof = match snark.commitment {
Some(Commitments::KZG) => match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
@@ -714,7 +722,7 @@ pub fn swap_proof_commitments_polycommit(
>(snark, commitments)?,
},
None => {
return Err("commitment scheme not found".into());
return Err(PfsysError::InvalidCommitmentScheme);
}
};
@@ -761,22 +769,22 @@ where
pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
path: PathBuf,
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
info!("loading verification key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
debug!("loading verification key from {:?}", path);
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadVk(format!("{}", e)))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)?;
info!("done loading verification key ✅");
)
.map_err(|e| PfsysError::LoadVk(format!("{}", e)))?;
info!("loaded verification key ✅");
Ok(vk)
}
@@ -784,22 +792,22 @@ where
pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
path: PathBuf,
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<ProvingKey<Scheme::Curve>, Box<dyn Error>>
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
info!("loading proving key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
debug!("loading proving key from {:?}", path);
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)?;
info!("done loading proving key ✅");
)
.map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
info!("loaded proving key ✅");
Ok(pk)
}
@@ -811,7 +819,7 @@ pub fn save_pk<C: SerdeObject + CurveAffine>(
where
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
{
info!("saving proving key 💾");
debug!("saving proving key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
@@ -828,7 +836,7 @@ pub fn save_vk<C: CurveAffine + SerdeObject>(
where
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
{
info!("saving verification key 💾");
debug!("saving verification key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
@@ -842,7 +850,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
path: &PathBuf,
params: &'_ Scheme::ParamsVerifier,
) -> Result<(), io::Error> {
info!("saving parameters 💾");
debug!("saving parameters 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
params.write(&mut writer)?;


@@ -1,8 +1,7 @@
use halo2_proofs::poly::commitment::CommitmentScheme;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::poly::commitment::ParamsProver;
use log::info;
use std::error::Error;
use log::debug;
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
@@ -16,24 +15,33 @@ pub fn gen_srs<Scheme: CommitmentScheme>(k: u32) -> Scheme::ParamsProver {
Scheme::ParamsProver::new(k)
}
#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum SrsError {
#[error("failed to download srs from {0}")]
DownloadError(String),
#[error("failed to load srs from {0}")]
LoadError(PathBuf),
#[error("failed to read srs {0}")]
ReadError(String),
}
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
pub fn load_srs_verifier<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
info!("loading srs from {:?}", path);
let f = File::open(path.clone())
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
) -> Result<Scheme::ParamsVerifier, SrsError> {
debug!("loading srs from {:?}", path);
let f = File::open(path.clone()).map_err(|_| SrsError::LoadError(path))?;
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(|e| SrsError::ReadError(e.to_string()))
}
/// Loads the [CommitmentScheme::ParamsProver] at `path`.
pub fn load_srs_prover<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
info!("loading srs from {:?}", path);
let f = File::open(path.clone())
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
) -> Result<Scheme::ParamsProver, SrsError> {
debug!("loading srs from {:?}", path);
let f = File::open(path.clone()).map_err(|_| SrsError::LoadError(path.clone()))?;
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(|e| SrsError::ReadError(e.to_string()))
}

src/tensor/errors.rs (new file)

@@ -0,0 +1,30 @@
use thiserror::Error;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
/// Shape mismatch in an operation
#[error("dimension mismatch in tensor op: {0}")]
DimMismatch(String),
/// Shape error when instantiating
#[error("dimensionality error when manipulating a tensor: {0}")]
DimError(String),
/// wrong method was called on a tensor-like struct
#[error("wrong method called")]
WrongMethod,
/// Significant bit truncation when instantiating
#[error("significant bit truncation when instantiating, try lowering the scale")]
SigBitTruncationError,
/// Failed to convert to field element tensor
#[error("failed to convert to field element tensor")]
FeltError,
/// Unsupported operation
#[error("unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
/// Unset visibility
#[error("unset visibility")]
UnsetVisibility,
}
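
With TensorError in its own module, the remaining glue is a #[from] variant on each wrapper error; that is what lets call sites in this diff write TensorError::DimMismatch(...).into() where a CircuitError or GraphError is expected. A self-contained sketch of the chain (the wrapper's exact shape is an assumption):

use thiserror::Error;

#[derive(Error, Debug)]
enum TensorError {
    #[error("dimension mismatch in tensor op: {0}")]
    DimMismatch(String),
}

#[derive(Error, Debug)]
enum CircuitError {
    #[error("[tensor] {0}")]
    Tensor(#[from] TensorError),
}

fn rescale(scales: usize, values: usize) -> Result<(), CircuitError> {
    if scales != values {
        // mirrors the rescaled-inputs check earlier in this diff
        return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
    }
    Ok(())
}

fn main() {
    println!("{}", rescale(1, 2).unwrap_err());
}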


@@ -1,3 +1,5 @@
/// Tensor related errors.
pub mod errors;
/// Implementations of common operations on tensors.
pub mod ops;
/// A wrapper around a tensor of circuit variables / advices.
@@ -5,6 +7,8 @@ pub mod val;
/// A wrapper around a tensor of Halo2 Value types.
pub mod var;
pub use errors::TensorError;
use halo2curves::{bn256::Fr, ff::PrimeField};
use maybe_rayon::{
prelude::{
@@ -40,40 +44,10 @@ use std::fmt::Debug;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
use thiserror::Error;
#[cfg(feature = "metal")]
use std::collections::HashMap;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
/// Shape mismatch in a operation
#[error("dimension mismatch in tensor op: {0}")]
DimMismatch(String),
/// Shape when instantiating
#[error("dimensionality error when manipulating a tensor: {0}")]
DimError(String),
/// wrong method was called on a tensor-like struct
#[error("wrong method called")]
WrongMethod,
/// Significant bit truncation when instantiating
#[error("Significant bit truncation when instantiating, try lowering the scale")]
SigBitTruncationError,
/// Failed to convert to field element tensor
#[error("Failed to convert to field element tensor")]
FeltError,
/// Table lookup error
#[error("Table lookup error")]
TableLookupError,
/// Unsupported operation
#[error("Unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("Unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
}
#[cfg(feature = "metal")]
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
@@ -400,9 +374,7 @@ impl IntoI64 for () {
fn into_i64(self) -> i64 {
0
}
fn from_i64(_: i64) -> Self {
}
fn from_i64(_: i64) -> Self {}
}
impl IntoI64 for Fr {
@@ -1852,7 +1824,7 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
pub fn get_broadcasted_shape(
shape_a: &[usize],
shape_b: &[usize],
) -> Result<Vec<usize>, Box<dyn Error>> {
) -> Result<Vec<usize>, TensorError> {
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();
@@ -1867,9 +1839,9 @@ pub fn get_broadcasted_shape(
}
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(Box::new(TensorError::DimError(
_ => Err(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
))),
)),
}
}
////////////////////////


@@ -256,23 +256,23 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
}
impl<F: PrimeField + TensorType + PartialOrd> TryFrom<Tensor<F>> for ValTensor<F> {
type Error = Box<dyn Error>;
fn try_from(t: Tensor<F>) -> Result<ValTensor<F>, Box<dyn Error>> {
type Error = TensorError;
fn try_from(t: Tensor<F>) -> Result<ValTensor<F>, TensorError> {
let visibility = t.visibility.clone();
let dims = t.dims().to_vec();
let inner = t.into_iter().map(|x| {
if let Some(vis) = &visibility {
match vis {
Visibility::Fixed => Ok(ValType::Constant(x)),
_ => {
Ok(Value::known(x).into())
let inner = t
.into_iter()
.map(|x| {
if let Some(vis) = &visibility {
match vis {
Visibility::Fixed => Ok(ValType::Constant(x)),
_ => Ok(Value::known(x).into()),
}
} else {
Err(TensorError::UnsetVisibility)
}
}
else {
Err("visibility should be set to convert a tensor of field elements to a ValTensor.".into())
}
}).collect::<Result<Vec<_>, Box<dyn Error>>>()?;
})
.collect::<Result<Vec<_>, TensorError>>()?;
let mut inner: Tensor<ValType<F>> = inner.into_iter().into();
inner.reshape(&dims)?;
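
Setting type Error = TensorError on the TryFrom impl keeps .try_into()? working anywhere the caller already returns TensorError, or anything with a From<TensorError>, with no boxing. A minimal sketch with stand-in types:

use std::convert::{TryFrom, TryInto};

#[derive(Debug)]
enum TensorError {
    UnsetVisibility,
}

struct Raw(Option<Vec<u8>>);
struct Checked(Vec<u8>);

impl TryFrom<Raw> for Checked {
    type Error = TensorError;
    fn try_from(r: Raw) -> Result<Self, Self::Error> {
        r.0.map(Checked).ok_or(TensorError::UnsetVisibility)
    }
}

fn consume(r: Raw) -> Result<usize, TensorError> {
    let c: Checked = r.try_into()?; // concrete error type, no Box<dyn Error>
    Ok(c.0.len())
}

fn main() {
    println!("{:?}", consume(Raw(Some(vec![1, 2, 3]))));
}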
@@ -378,13 +378,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// reverse order of elements whilst preserving the shape
pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
pub fn reverse(&mut self) -> Result<(), TensorError> {
match self {
ValTensor::Value { inner: v, .. } => {
v.reverse();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
@@ -420,7 +420,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
///
pub fn any_unknowns(&self) -> Result<bool, Box<dyn Error>> {
pub fn any_unknowns(&self) -> Result<bool, TensorError> {
match self {
ValTensor::Instance { .. } => Ok(true),
_ => Ok(self.get_inner()?.iter().any(|&x| {
@@ -491,7 +491,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Fetch the underlying [Tensor] of field elements.
pub fn get_felt_evals(&self) -> Result<Tensor<F>, Box<dyn Error>> {
pub fn get_felt_evals(&self) -> Result<Tensor<F>, TensorError> {
let mut felt_evals: Vec<F> = vec![];
match self {
ValTensor::Value {
@@ -504,7 +504,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
});
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
let mut res: Tensor<F> = felt_evals.into_iter().into();
@@ -521,7 +521,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Calls `int_evals` on the inner tensor.
pub fn get_int_evals(&self) -> Result<Tensor<i64>, Box<dyn Error>> {
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
// finally convert to vector of integers
let mut integer_evals: Vec<i64> = vec![];
match self {
@@ -547,7 +547,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
});
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
@@ -558,7 +558,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Calls `pad_to_zero_rem` on the inner tensor.
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -567,14 +567,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
}
/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, Box<dyn Error>> {
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
return Ok(self.clone());
}
@@ -592,13 +592,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
scale: *scale,
}
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}
/// Calls `get_single_elem` on the inner tensor.
pub fn get_single_elem(&self, index: usize) -> Result<ValTensor<F>, Box<dyn Error>> {
pub fn get_single_elem(&self, index: usize) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
@@ -612,7 +612,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
scale: *scale,
}
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}
@@ -648,7 +648,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
}
/// Calls `expand` on the inner tensor.
pub fn expand(&mut self, dims: &[usize]) -> Result<(), Box<dyn Error>> {
pub fn expand(&mut self, dims: &[usize]) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -657,14 +657,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
}
/// Calls `move_axis` on the inner tensor.
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), Box<dyn Error>> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -673,14 +673,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
}
/// Sets the [ValTensor]'s shape.
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box<dyn Error>> {
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -690,10 +690,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
ValTensor::Instance { dims: d, idx, .. } => {
if d[*idx].iter().product::<usize>() != new_dims.iter().product::<usize>() {
return Err(Box::new(TensorError::DimError(format!(
return Err(TensorError::DimError(format!(
"Cannot reshape {:?} to {:?} as they have number of elements",
d[*idx], new_dims
))));
)));
}
d[*idx] = new_dims.to_vec();
}
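
The reshape guard above reduces to a single invariant: the old and new shapes must describe the same number of elements. A minimal, self-contained sketch of that check (illustrative only, not the crate's code):

fn can_reshape(old_dims: &[usize], new_dims: &[usize]) -> bool {
    // A reshape is legal only when both shapes cover the same element count.
    old_dims.iter().product::<usize>() == new_dims.iter().product::<usize>()
}

fn main() {
    assert!(can_reshape(&[2, 3], &[6])); // 6 elements either way
    assert!(!can_reshape(&[2, 3], &[4])); // 6 vs 4: rejected
}
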
@@ -702,12 +702,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Sets the [ValTensor]'s shape.
pub fn slice(
&mut self,
axis: &usize,
start: &usize,
end: &usize,
) -> Result<(), Box<dyn Error>> {
pub fn slice(&mut self, axis: &usize, start: &usize, end: &usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -716,7 +711,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
@@ -982,7 +977,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
/// inverts the inner values
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
pub fn inverse(&self) -> Result<ValTensor<F>, TensorError> {
let mut cloned_self = self.clone();
match &mut cloned_self {
@@ -1000,7 +995,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(cloned_self)
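
Taken together, these hunks swap the opaque Box<dyn Error> returns for the concrete TensorError, letting callers match on variants instead of downcasting. A minimal sketch of the pattern, with a simplified stand-in enum (the real TensorError carries more variants):

use std::fmt;

#[derive(Debug)]
enum TensorError {
    WrongMethod,
    DimError(String),
}

impl fmt::Display for TensorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TensorError::WrongMethod => write!(f, "wrong method for this variant"),
            TensorError::DimError(msg) => write!(f, "dimension error: {}", msg),
        }
    }
}

impl std::error::Error for TensorError {}

// With a concrete error type the caller can branch on the variant directly,
// which a Box<dyn Error> return only allows after a downcast.
fn demo(wrong_method: bool) -> Result<(), TensorError> {
    if wrong_method {
        Err(TensorError::WrongMethod)
    } else {
        Err(TensorError::DimError("cannot reshape [2, 3] to [4]".to_string()))
    }
}

fn main() {
    for flag in [true, false] {
        match demo(flag) {
            Err(TensorError::WrongMethod) => eprintln!("wrong method"),
            Err(e) => eprintln!("{}", e),
            Ok(()) => println!("ok"),
        }
    }
}
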

View File

@@ -31,6 +31,15 @@ pub enum VarTensor {
}
impl VarTensor {
/// Returns the name of the [VarTensor] variant.
pub fn name(&self) -> &'static str {
match self {
VarTensor::Advice { .. } => "Advice",
VarTensor::Dummy { .. } => "Dummy",
VarTensor::Empty => "Empty",
}
}
/// Returns true if the [VarTensor] is an advice column.
pub fn is_advice(&self) -> bool {
matches!(self, VarTensor::Advice { .. })
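
One plausible consumer of the new name() accessor — hypothetical, not part of this diff — is log and error formatting, where a short variant name reads better than a Debug dump. Sketch with a minimal stand-in for the enum:

// Minimal stand-in for the VarTensor in this hunk (the real enum carries
// column data); just enough to exercise name() and is_advice().
enum VarTensor {
    Advice,
    Dummy,
    Empty,
}

impl VarTensor {
    fn name(&self) -> &'static str {
        match self {
            VarTensor::Advice => "Advice",
            VarTensor::Dummy => "Dummy",
            VarTensor::Empty => "Empty",
        }
    }

    fn is_advice(&self) -> bool {
        matches!(self, VarTensor::Advice)
    }
}

fn main() {
    let var = VarTensor::Empty;
    if !var.is_advice() {
        // name() keeps the message human-readable.
        eprintln!("expected an Advice tensor, got {}", var.name());
    }
}
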

View File

@@ -1066,6 +1066,15 @@ mod native_tests {
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "polycommit", "public", "polycommit");
test_dir.close().unwrap();
}
#(#[test_case(TESTS_ON_CHAIN_INPUT[N])])*
fn kzg_evm_on_chain_all_kzg_params_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap();
crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
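// All-KZG variant: both data sources stay plain files while inputs,
// outputs, and params are all committed with polycommit (argument order
// inferred from the sibling test above; the signature is not restated
// in this diff).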
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "file", "polycommit", "polycommit", "polycommit");
test_dir.close().unwrap();
}
});
@@ -2330,7 +2339,6 @@ mod native_tests {
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
init_params(settings_path.clone().into());
let data_path = format!("{}/{}/input.json", test_dir, example_name);
@@ -2342,62 +2350,6 @@ mod native_tests {
let test_input_source = format!("--input-source={}", input_source);
let test_output_source = format!("--output-source={}", output_source);
// load witness
let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap();
let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap();
if input_visibility == "hashed" {
let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap();
input.input_data = DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
);
}
if output_visibility == "hashed" {
let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap();
input.output_data = Some(DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
));
} else {
input.output_data = Some(DataSource::File(
witness
.pretty_elements
.unwrap()
.rescaled_outputs
.iter()
.map(|o| {
o.iter()
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
.collect()
})
.collect(),
));
}
input.save(data_path.clone().into()).unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup",
@@ -2412,6 +2364,82 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
// generate the witness, passing the vk path so the necessary kzg commitments can be computed
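// (this is also why witness generation now runs after `setup`, rather than
// before it as in the removed block above: the vk has to exist before the
// kzg commitments can be produced)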
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"gen-witness",
"-D",
&data_path,
"-M",
&model_path,
"-O",
&witness_path,
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
])
.status()
.expect("failed to execute process");
assert!(status.success());
// load witness
let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap();
// print out the witness
println!("WITNESS: {:?}", witness);
let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap();
if input_source != "file" || output_source != "file" {
println!("on chain input");
if input_visibility == "hashed" {
let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap();
input.input_data = DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
);
}
if output_visibility == "hashed" {
let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap();
input.output_data = Some(DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
));
} else {
input.output_data = Some(DataSource::File(
witness
.pretty_elements
.unwrap()
.rescaled_outputs
.iter()
.map(|o| {
o.iter()
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
.collect()
})
.collect(),
));
}
input.save(data_path.clone().into()).unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
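// The guard above gates all on-chain test-data setup behind a single
// predicate: at least one data source is not a plain file
// (input_source != "file" || output_source != "file"). When both are
// files, no on-chain fixtures are needed and the prove step below runs
// directly against the file data.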
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"prove",
@@ -2502,13 +2530,19 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let deploy_evm_data_path = if input_source != "file" || output_source != "file" {
test_on_chain_data_path.clone()
} else {
data_path.clone()
};
let addr_path_da_arg = format!("--addr-path={}/{}/addr_da.txt", test_dir, example_name);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"deploy-evm-da",
format!("--settings-path={}", settings_path).as_str(),
"-D",
test_on_chain_data_path.as_str(),
deploy_evm_data_path.as_str(),
"--sol-code-path",
sol_arg.as_str(),
rpc_arg.as_str(),
@@ -2546,40 +2580,42 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());
// Create a new set of test on chain data
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
// Create a new set of test on-chain data, only when the input or output source is on-chain
if input_source != "file" || output_source != "file" {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let deployed_addr_arg = format!("--addr={}", addr_da);
let args: Vec<&str> = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());
let deployed_addr_arg = format!("--addr={}", addr_da);
let args = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
"-D",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());
assert!(status.success());
}
// As a sanity check, add an example that should fail.
let args = vec![
"verify-evm",