mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
3 Commits
ac/make-da
...
ac/feature
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0dbd2ee8a | ||
|
|
3f1b28b7a2 | ||
|
|
c47d3bf05c |
@@ -225,7 +225,7 @@ default = [
|
||||
]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
|
||||
ios-bindings = ["eth-mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
|
||||
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
|
||||
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
|
||||
ezkl = [
|
||||
"onnx",
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::{
|
||||
PoseidonChip,
|
||||
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
|
||||
};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
|
||||
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
|
||||
};
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::{
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
ProofType, TranscriptType,
|
||||
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
|
||||
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
};
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
|
||||
use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use pyo3_stub_gen::{
|
||||
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction, TypeInfo,
|
||||
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction,
|
||||
};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::collections::HashSet;
|
||||
@@ -1040,22 +1040,25 @@ fn calibrate_settings(
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
data,
|
||||
settings,
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to calibrate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
data,
|
||||
settings,
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to calibrate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// Runs the forward pass operation to generate a witness
|
||||
@@ -1098,12 +1101,15 @@ fn gen_witness(
|
||||
vk_path: Option<PathBuf>,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let output =
|
||||
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
|
||||
let err_str = format!("Failed to generate witness: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to generate witness: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
})
|
||||
}
|
||||
|
||||
/// Mocks the prover
|
||||
|
||||
@@ -171,6 +171,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::GenWitness {
|
||||
data,
|
||||
@@ -185,6 +186,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
vk_path,
|
||||
srs_path,
|
||||
)
|
||||
.await
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::Mock { model, witness } => mock(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
@@ -689,7 +691,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLErr
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) fn gen_witness(
|
||||
pub(crate) async fn gen_witness(
|
||||
compiled_circuit_path: PathBuf,
|
||||
data: String,
|
||||
output: Option<PathBuf>,
|
||||
@@ -711,7 +713,7 @@ pub(crate) fn gen_witness(
|
||||
None
|
||||
};
|
||||
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
@@ -1022,7 +1024,7 @@ impl AccuracyResults {
|
||||
/// Calibrate the circuit parameters to a given a dataset
|
||||
#[allow(trivial_casts)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn calibrate(
|
||||
pub(crate) async fn calibrate(
|
||||
model_path: PathBuf,
|
||||
data: String,
|
||||
settings_path: PathBuf,
|
||||
@@ -1048,7 +1050,7 @@ pub(crate) fn calibrate(
|
||||
|
||||
let input_shapes = model.graph.input_shapes()?;
|
||||
|
||||
let chunks = data.split_into_batches(input_shapes)?;
|
||||
let chunks = data.split_into_batches(input_shapes).await?;
|
||||
info!("num calibration batches: {}", chunks.len());
|
||||
|
||||
debug!("running onnx predictions...");
|
||||
@@ -1159,7 +1161,7 @@ pub(crate) fn calibrate(
|
||||
let chunk = chunk.clone();
|
||||
|
||||
let data = circuit
|
||||
.load_graph_input(&chunk)
|
||||
.load_graph_from_file_exclusively(&chunk)
|
||||
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
|
||||
|
||||
let forward_res = circuit
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -17,7 +17,7 @@ use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
tract_data::{TVec, prelude::Tensor as TractTensor},
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
@@ -209,30 +209,27 @@ pub struct CallToAccount {
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
pub struct DataSource(FileSource);
|
||||
|
||||
impl DataSource {
|
||||
/// Gets the underlying file source data
|
||||
pub fn values(&self) -> &FileSource {
|
||||
&self.0
|
||||
}
|
||||
#[serde(untagged)]
|
||||
pub enum DataSource {
|
||||
/// Data from a JSON file containing arrays of values
|
||||
File(FileSource),
|
||||
}
|
||||
|
||||
impl Default for DataSource {
|
||||
fn default() -> Self {
|
||||
DataSource(vec![vec![]])
|
||||
DataSource::File(vec![vec![]])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FileSource> for DataSource {
|
||||
fn from(data: FileSource) -> Self {
|
||||
DataSource(data)
|
||||
DataSource::File(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Vec<Fp>>> for DataSource {
|
||||
fn from(data: Vec<Vec<Fp>>) -> Self {
|
||||
DataSource(
|
||||
DataSource::File(
|
||||
data.iter()
|
||||
.map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect())
|
||||
.collect(),
|
||||
@@ -242,7 +239,7 @@ impl From<Vec<Vec<Fp>>> for DataSource {
|
||||
|
||||
impl From<Vec<Vec<f64>>> for DataSource {
|
||||
fn from(data: Vec<Vec<f64>>) -> Self {
|
||||
DataSource(
|
||||
DataSource::File(
|
||||
data.iter()
|
||||
.map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect())
|
||||
.collect(),
|
||||
@@ -261,7 +258,7 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
// Try deserializing as FileSource first
|
||||
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(DataSource(t));
|
||||
return Ok(DataSource::File(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
||||
@@ -297,16 +294,19 @@ impl GraphData {
|
||||
datum_types: &[tract_onnx::prelude::DatumType],
|
||||
) -> Result<TVec<TValue>, GraphError> {
|
||||
let mut inputs = TVec::new();
|
||||
for (i, input) in self.input_data.values().iter().enumerate() {
|
||||
if !input.is_empty() {
|
||||
let dt = datum_types[i];
|
||||
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
|
||||
let tt = TractTensor::from_shape(&shapes[i], &input)?;
|
||||
let tt = tt.cast_to_dt(dt)?;
|
||||
inputs.push(tt.into_owned().into());
|
||||
match &self.input_data {
|
||||
DataSource::File(data) => {
|
||||
for (i, input) in data.iter().enumerate() {
|
||||
if !input.is_empty() {
|
||||
let dt = datum_types[i];
|
||||
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
|
||||
let tt = TractTensor::from_shape(&shapes[i], &input)?;
|
||||
let tt = tt.cast_to_dt(dt)?;
|
||||
inputs.push(tt.into_owned().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(inputs)
|
||||
}
|
||||
|
||||
@@ -338,7 +338,7 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
Ok(GraphData {
|
||||
input_data: DataSource(input_data),
|
||||
input_data: DataSource::File(input_data),
|
||||
output_data: None,
|
||||
})
|
||||
}
|
||||
@@ -420,7 +420,7 @@ impl GraphData {
|
||||
/// Returns error if:
|
||||
/// - Data is from on-chain source
|
||||
/// - Input size is not evenly divisible by batch size
|
||||
pub fn split_into_batches(
|
||||
pub async fn split_into_batches(
|
||||
&self,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
@@ -428,7 +428,7 @@ impl GraphData {
|
||||
|
||||
let iterable = match self {
|
||||
GraphData {
|
||||
input_data: DataSource(data),
|
||||
input_data: DataSource::File(data),
|
||||
output_data: _,
|
||||
} => data.clone(),
|
||||
};
|
||||
@@ -476,12 +476,12 @@ impl GraphData {
|
||||
for input in batched_inputs.iter() {
|
||||
batch.push(input[i].clone());
|
||||
}
|
||||
input_batches.push(DataSource(batch));
|
||||
input_batches.push(DataSource::File(batch));
|
||||
}
|
||||
|
||||
// Ensure at least one batch exists
|
||||
if input_batches.is_empty() {
|
||||
input_batches.push(DataSource(vec![vec![]]));
|
||||
input_batches.push(DataSource::File(vec![vec![]]));
|
||||
}
|
||||
|
||||
// Create GraphData instance for each batch
|
||||
@@ -556,7 +556,9 @@ impl ToPyObject for CallToAccount {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for DataSource {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
self.0.to_object(py)
|
||||
match self {
|
||||
DataSource::File(data) => data.to_object(py),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::{ConstantsMap, RegionSettings};
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::{felt_to_f64, IntegerRep};
|
||||
use crate::fieldutils::{IntegerRep, felt_to_f64};
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
use crate::{EZKL_BUF_CAPACITY, RunArgs};
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
@@ -55,13 +55,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
pub use model::*;
|
||||
pub use node::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDictMethods;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Deref;
|
||||
@@ -952,11 +952,77 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
///
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
self.load_file_data(data.input_data.values(), &shapes, scales, input_types)
|
||||
self.process_data_source(&data.input_data, shapes, scales, input_types)
|
||||
}
|
||||
|
||||
///
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
pub fn load_graph_from_file_exclusively(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
match &data.input_data {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
pub async fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
self.process_data_source(&data.input_data, shapes, scales, input_types)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
/// Process the data source for the model
|
||||
fn process_data_source(
|
||||
&mut self,
|
||||
data: &DataSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
match &data {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Process the data source for the model
|
||||
async fn process_data_source(
|
||||
&mut self,
|
||||
data: &DataSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
match &data {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
@@ -4,7 +4,7 @@ mod native_tests {
|
||||
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, GraphData};
|
||||
use ezkl::graph::GraphSettings;
|
||||
use ezkl::graph::{DataSource, GraphSettings};
|
||||
use ezkl::pfsys::Snark;
|
||||
use ezkl::Commitments;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
@@ -163,14 +163,17 @@ mod native_tests {
|
||||
let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into())
|
||||
.expect("failed to load input data");
|
||||
|
||||
let duplicated_input_data: FileSource = data
|
||||
.input_data
|
||||
.values()
|
||||
let input_data = match data.input_data {
|
||||
DataSource::File(data) => data,
|
||||
_ => panic!("Only File data sources support batching"),
|
||||
};
|
||||
|
||||
let duplicated_input_data: FileSource = input_data
|
||||
.iter()
|
||||
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
|
||||
.collect();
|
||||
|
||||
let duplicated_data = GraphData::new(duplicated_input_data.into());
|
||||
let duplicated_data = GraphData::new(DataSource::File(duplicated_input_data));
|
||||
|
||||
let res =
|
||||
duplicated_data.save(format!("{}/{}/input.json", test_dir, output_dir).into());
|
||||
|
||||
Reference in New Issue
Block a user