Compare commits

..

3 Commits

Author SHA1 Message Date
dante
1a75963705 refactor: DataSource enum -> struct 2025-04-29 12:18:24 -04:00
dante
0ef1f35e59 fix: uniffi bindings 2025-04-29 12:12:56 -04:00
dante
808ab7d0de chore: feature-gate eth (#978) 2025-04-29 12:02:23 -04:00
6 changed files with 81 additions and 160 deletions

View File

@@ -225,7 +225,7 @@ default = [
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings = ["eth-mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",

View File

@@ -1,34 +1,34 @@
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::commands::*;
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
use crate::graph::{
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
@@ -1040,25 +1040,22 @@ fn calibrate_settings(
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::calibrate(
model,
data,
settings,
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
max_logrows,
)
.await
.map_err(|e| {
let err_str = format!("Failed to calibrate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
crate::execute::calibrate(
model,
data,
settings,
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
max_logrows,
)
.map_err(|e| {
let err_str = format!("Failed to calibrate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
Ok(true)
}
/// Runs the forward pass operation to generate a witness
@@ -1101,15 +1098,12 @@ fn gen_witness(
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
.await
.map_err(|e| {
let err_str = format!("Failed to generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(output.to_object(py)))
})
let output =
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
let err_str = format!("Failed to generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Mocks the prover

View File

@@ -171,7 +171,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
scale_rebase_multiplier,
max_logrows,
)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::GenWitness {
data,
@@ -186,7 +185,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
vk_path,
srs_path,
)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(
model.unwrap_or(DEFAULT_MODEL.into()),
@@ -691,7 +689,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLErr
Ok(String::new())
}
pub(crate) async fn gen_witness(
pub(crate) fn gen_witness(
compiled_circuit_path: PathBuf,
data: String,
output: Option<PathBuf>,
@@ -713,7 +711,7 @@ pub(crate) async fn gen_witness(
None
};
let mut input = circuit.load_graph_input(&data).await?;
let mut input = circuit.load_graph_input(&data)?;
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data)?;
@@ -1024,7 +1022,7 @@ impl AccuracyResults {
/// Calibrate the circuit parameters to a given a dataset
#[allow(trivial_casts)]
#[allow(clippy::too_many_arguments)]
pub(crate) async fn calibrate(
pub(crate) fn calibrate(
model_path: PathBuf,
data: String,
settings_path: PathBuf,
@@ -1050,7 +1048,7 @@ pub(crate) async fn calibrate(
let input_shapes = model.graph.input_shapes()?;
let chunks = data.split_into_batches(input_shapes).await?;
let chunks = data.split_into_batches(input_shapes)?;
info!("num calibration batches: {}", chunks.len());
debug!("running onnx predictions...");
@@ -1161,7 +1159,7 @@ pub(crate) async fn calibrate(
let chunk = chunk.clone();
let data = circuit
.load_graph_from_file_exclusively(&chunk)
.load_graph_input(&chunk)
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit

View File

@@ -1,15 +1,15 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -17,7 +17,7 @@ use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::{
tract_data::{TVec, prelude::Tensor as TractTensor},
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
@@ -209,27 +209,30 @@ pub struct CallToAccount {
/// Represents different sources of input/output data for the EZKL model
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
#[serde(untagged)]
pub enum DataSource {
/// Data from a JSON file containing arrays of values
File(FileSource),
pub struct DataSource(FileSource);
impl DataSource {
/// Gets the underlying file source data
pub fn values(&self) -> &FileSource {
&self.0
}
}
impl Default for DataSource {
fn default() -> Self {
DataSource::File(vec![vec![]])
DataSource(vec![vec![]])
}
}
impl From<FileSource> for DataSource {
fn from(data: FileSource) -> Self {
DataSource::File(data)
DataSource(data)
}
}
impl From<Vec<Vec<Fp>>> for DataSource {
fn from(data: Vec<Vec<Fp>>) -> Self {
DataSource::File(
DataSource(
data.iter()
.map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect())
.collect(),
@@ -239,7 +242,7 @@ impl From<Vec<Vec<Fp>>> for DataSource {
impl From<Vec<Vec<f64>>> for DataSource {
fn from(data: Vec<Vec<f64>>) -> Self {
DataSource::File(
DataSource(
data.iter()
.map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect())
.collect(),
@@ -258,7 +261,7 @@ impl<'de> Deserialize<'de> for DataSource {
// Try deserializing as FileSource first
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = first_try {
return Ok(DataSource::File(t));
return Ok(DataSource(t));
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
@@ -294,19 +297,16 @@ impl GraphData {
datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, GraphError> {
let mut inputs = TVec::new();
match &self.input_data {
DataSource::File(data) => {
for (i, input) in data.iter().enumerate() {
if !input.is_empty() {
let dt = datum_types[i];
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
let tt = TractTensor::from_shape(&shapes[i], &input)?;
let tt = tt.cast_to_dt(dt)?;
inputs.push(tt.into_owned().into());
}
}
for (i, input) in self.input_data.values().iter().enumerate() {
if !input.is_empty() {
let dt = datum_types[i];
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
let tt = TractTensor::from_shape(&shapes[i], &input)?;
let tt = tt.cast_to_dt(dt)?;
inputs.push(tt.into_owned().into());
}
}
Ok(inputs)
}
@@ -338,7 +338,7 @@ impl GraphData {
}
}
Ok(GraphData {
input_data: DataSource::File(input_data),
input_data: DataSource(input_data),
output_data: None,
})
}
@@ -420,7 +420,7 @@ impl GraphData {
/// Returns error if:
/// - Data is from on-chain source
/// - Input size is not evenly divisible by batch size
pub async fn split_into_batches(
pub fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, GraphError> {
@@ -428,7 +428,7 @@ impl GraphData {
let iterable = match self {
GraphData {
input_data: DataSource::File(data),
input_data: DataSource(data),
output_data: _,
} => data.clone(),
};
@@ -476,12 +476,12 @@ impl GraphData {
for input in batched_inputs.iter() {
batch.push(input[i].clone());
}
input_batches.push(DataSource::File(batch));
input_batches.push(DataSource(batch));
}
// Ensure at least one batch exists
if input_batches.is_empty() {
input_batches.push(DataSource::File(vec![vec![]]));
input_batches.push(DataSource(vec![vec![]]));
}
// Create GraphData instance for each batch
@@ -556,9 +556,7 @@ impl ToPyObject for CallToAccount {
#[cfg(feature = "python-bindings")]
impl ToPyObject for DataSource {
fn to_object(&self, py: Python) -> PyObject {
match self {
DataSource::File(data) => data.to_object(py),
}
self.0.to_object(py)
}
}

View File

@@ -35,12 +35,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::{IntegerRep, felt_to_f64};
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{EZKL_BUF_CAPACITY, RunArgs};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use halo2_proofs::{
circuit::Layouter,
@@ -55,13 +55,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
@@ -952,77 +952,11 @@ impl GraphCircuit {
}
///
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
self.process_data_source(&data.input_data, shapes, scales, input_types)
}
///
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn load_graph_from_file_exclusively(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
match &data.input_data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
}
}
///
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
/// Process the data source for the model
fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Process the data source for the model
async fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
}
self.load_file_data(data.input_data.values(), &shapes, scales, input_types)
}
///

View File

@@ -4,7 +4,7 @@ mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, GraphData};
use ezkl::graph::{DataSource, GraphSettings};
use ezkl::graph::GraphSettings;
use ezkl::pfsys::Snark;
use ezkl::Commitments;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -163,17 +163,14 @@ mod native_tests {
let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into())
.expect("failed to load input data");
let input_data = match data.input_data {
DataSource::File(data) => data,
_ => panic!("Only File data sources support batching"),
};
let duplicated_input_data: FileSource = input_data
let duplicated_input_data: FileSource = data
.input_data
.values()
.iter()
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
.collect();
let duplicated_data = GraphData::new(DataSource::File(duplicated_input_data));
let duplicated_data = GraphData::new(duplicated_input_data.into());
let res =
duplicated_data.save(format!("{}/{}/input.json", test_dir, output_dir).into());