mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-14 00:38:15 -05:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e16a482be9 | ||
|
|
9077b8debc | ||
|
|
28594c7651 |
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '22.1.2'
|
||||
release = '22.1.4'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -1266,7 +1266,8 @@ pub(crate) fn calibrate(
|
||||
num_rows: new_settings.num_rows,
|
||||
total_assignments: new_settings.total_assignments,
|
||||
total_const_size: new_settings.total_const_size,
|
||||
total_dynamic_col_size: new_settings.total_dynamic_col_size,
|
||||
dynamic_lookup_params: new_settings.dynamic_lookup_params,
|
||||
shuffle_params: new_settings.shuffle_params,
|
||||
..settings.clone()
|
||||
};
|
||||
|
||||
|
||||
@@ -21,10 +21,6 @@ use tract_onnx::tract_core::{
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
|
||||
/// Represents different types of values that can be stored in a file source
|
||||
/// Used for handling various input types in zero-knowledge proofs
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
@@ -37,6 +33,22 @@ pub enum FileSourceInner {
|
||||
Field(Fp),
|
||||
}
|
||||
|
||||
impl From<Fp> for FileSourceInner {
|
||||
fn from(value: Fp) -> Self {
|
||||
FileSourceInner::Field(value)
|
||||
}
|
||||
}
|
||||
impl From<bool> for FileSourceInner {
|
||||
fn from(value: bool) -> Self {
|
||||
FileSourceInner::Bool(value)
|
||||
}
|
||||
}
|
||||
impl From<f64> for FileSourceInner {
|
||||
fn from(value: f64) -> Self {
|
||||
FileSourceInner::Float(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FileSourceInner {
|
||||
/// Returns true if the value is a floating point number
|
||||
pub fn is_float(&self) -> bool {
|
||||
@@ -159,115 +171,11 @@ impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
|
||||
/// A collection of input values from a file source
|
||||
/// Organized as a vector of vectors where each inner vector represents a row/entry
|
||||
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
||||
pub type DataSource = Vec<Vec<FileSourceInner>>;
|
||||
|
||||
/// Represents which parts of the model (input/output) are attested to on-chain
|
||||
pub type InputOutput = (bool, bool);
|
||||
|
||||
/// Configuration for accessing on-chain data sources
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Call specifications for fetching data
|
||||
pub call: CallToAccount,
|
||||
/// RPC endpoint URL for accessing the chain
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Creates a new OnChainSource
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `call` - Call specification
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource { call, rpc }
|
||||
}
|
||||
}
|
||||
|
||||
/// Specification for view-only calls to fetch on-chain data
|
||||
/// Used for data attestation in smart contract verification
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallsToAccount {
|
||||
/// Vector of (call data, decimals) pairs
|
||||
/// call_data: ABI-encoded function call
|
||||
/// decimals: Number of decimal places for float conversion
|
||||
pub call_data: Vec<(Call, Decimals)>,
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Specification for a single view-only call returning an array
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallToAccount {
|
||||
/// ABI-encoded function call data
|
||||
pub call_data: Call,
|
||||
/// Number of decimal places for float conversion
|
||||
pub decimals: Vec<Decimals>,
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
pub struct DataSource(FileSource);
|
||||
|
||||
impl DataSource {
|
||||
/// Gets the underlying file source data
|
||||
pub fn values(&self) -> &FileSource {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DataSource {
|
||||
fn default() -> Self {
|
||||
DataSource(vec![vec![]])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FileSource> for DataSource {
|
||||
fn from(data: FileSource) -> Self {
|
||||
DataSource(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Vec<Fp>>> for DataSource {
|
||||
fn from(data: Vec<Vec<Fp>>) -> Self {
|
||||
DataSource(
|
||||
data.iter()
|
||||
.map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Vec<f64>>> for DataSource {
|
||||
fn from(data: Vec<Vec<f64>>) -> Self {
|
||||
DataSource(
|
||||
data.iter()
|
||||
.map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Always use JSON serialization for untagged enums
|
||||
impl<'de> Deserialize<'de> for DataSource {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
// Try deserializing as FileSource first
|
||||
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(DataSource(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for input and output data for graph computations
|
||||
///
|
||||
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
|
||||
@@ -297,7 +205,7 @@ impl GraphData {
|
||||
datum_types: &[tract_onnx::prelude::DatumType],
|
||||
) -> Result<TVec<TValue>, GraphError> {
|
||||
let mut inputs = TVec::new();
|
||||
for (i, input) in self.input_data.values().iter().enumerate() {
|
||||
for (i, input) in self.input_data.iter().enumerate() {
|
||||
if !input.is_empty() {
|
||||
let dt = datum_types[i];
|
||||
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
|
||||
@@ -338,7 +246,7 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
Ok(GraphData {
|
||||
input_data: DataSource(input_data),
|
||||
input_data,
|
||||
output_data: None,
|
||||
})
|
||||
}
|
||||
@@ -424,12 +332,7 @@ impl GraphData {
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
let mut batched_inputs = vec![];
|
||||
|
||||
let iterable = match self {
|
||||
GraphData {
|
||||
input_data: DataSource(data),
|
||||
output_data: _,
|
||||
} => data.clone(),
|
||||
};
|
||||
let iterable = self.input_data.clone();
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
for (i, shape) in input_shapes.iter().enumerate() {
|
||||
@@ -474,12 +377,12 @@ impl GraphData {
|
||||
for input in batched_inputs.iter() {
|
||||
batch.push(input[i].clone());
|
||||
}
|
||||
input_batches.push(DataSource(batch));
|
||||
input_batches.push(batch);
|
||||
}
|
||||
|
||||
// Ensure at least one batch exists
|
||||
if input_batches.is_empty() {
|
||||
input_batches.push(DataSource(vec![vec![]]));
|
||||
input_batches.push(vec![vec![]]);
|
||||
}
|
||||
|
||||
// Create GraphData instance for each batch
|
||||
@@ -498,28 +401,14 @@ impl GraphData {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
// Test serialization/deserialization of graph input
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
let file = GraphData::new(vec![vec![
|
||||
0.05326242372393608.into(),
|
||||
0.07497056573629379.into(),
|
||||
0.05235547572374344.into(),
|
||||
]]);
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
|
||||
@@ -26,11 +26,11 @@ use itertools::Itertools;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
use self::input::{FileSource, GraphData};
|
||||
use self::input::GraphData;
|
||||
|
||||
use self::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use self::input::{FileSource, GraphData};
|
||||
use self::input::GraphData;
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
@@ -413,6 +413,27 @@ fn insert_polycommit_pydict(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
/// Parameters for dynamic lookups
|
||||
/// serde should flatten this struct
|
||||
pub struct DynamicLookupParams {
|
||||
/// total dynamic column size
|
||||
pub total_dynamic_col_size: usize,
|
||||
/// max dynamic column input length
|
||||
pub max_dynamic_input_len: usize,
|
||||
/// number of dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
/// Parameters for shuffle operations
|
||||
pub struct ShuffleParams {
|
||||
/// number of shuffles
|
||||
pub num_shuffles: usize,
|
||||
/// total shuffle column size
|
||||
pub total_shuffle_col_size: usize,
|
||||
}
|
||||
|
||||
/// model parameters
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GraphSettings {
|
||||
@@ -424,16 +445,12 @@ pub struct GraphSettings {
|
||||
pub total_assignments: usize,
|
||||
/// total const size
|
||||
pub total_const_size: usize,
|
||||
/// total dynamic column size
|
||||
pub total_dynamic_col_size: usize,
|
||||
/// max dynamic column input length
|
||||
pub max_dynamic_input_len: usize,
|
||||
/// number of dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
/// number of shuffles
|
||||
pub num_shuffles: usize,
|
||||
/// total shuffle column size
|
||||
pub total_shuffle_col_size: usize,
|
||||
/// dynamic lookup parameters, flattened for backwards compatibility
|
||||
#[serde(flatten)]
|
||||
pub dynamic_lookup_params: DynamicLookupParams,
|
||||
/// shuffle parameters, flattened for backwards compatibility
|
||||
#[serde(flatten)]
|
||||
pub shuffle_params: ShuffleParams,
|
||||
/// the shape of public inputs to the model (in order of appearance)
|
||||
pub model_instance_shapes: Vec<Vec<usize>>,
|
||||
/// model output scales
|
||||
@@ -495,15 +512,16 @@ impl GraphSettings {
|
||||
}
|
||||
|
||||
fn dynamic_lookup_and_shuffle_logrows(&self) -> u32 {
|
||||
(self.total_dynamic_col_size as f64 + self.total_shuffle_col_size as f64)
|
||||
(self.dynamic_lookup_params.total_dynamic_col_size as f64
|
||||
+ self.shuffle_params.total_shuffle_col_size as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the number of rows required for the dynamic lookup and shuffle
|
||||
pub fn dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
|
||||
(self.total_dynamic_col_size as f64
|
||||
+ self.total_shuffle_col_size as f64
|
||||
(self.dynamic_lookup_params.total_dynamic_col_size as f64
|
||||
+ self.shuffle_params.total_shuffle_col_size as f64
|
||||
+ RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
@@ -511,13 +529,14 @@ impl GraphSettings {
|
||||
|
||||
/// calculate the number of rows required for the dynamic lookup and shuffle
|
||||
pub fn min_dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
|
||||
(self.max_dynamic_input_len as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
(self.dynamic_lookup_params.max_dynamic_input_len as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
|
||||
self.total_dynamic_col_size + self.total_shuffle_col_size
|
||||
self.dynamic_lookup_params.total_dynamic_col_size
|
||||
+ self.shuffle_params.total_shuffle_col_size
|
||||
}
|
||||
|
||||
/// calculate the number of rows required for the module constraints
|
||||
@@ -653,12 +672,12 @@ impl GraphSettings {
|
||||
|
||||
/// requires dynamic lookup
|
||||
pub fn requires_dynamic_lookup(&self) -> bool {
|
||||
self.num_dynamic_lookups > 0
|
||||
self.dynamic_lookup_params.num_dynamic_lookups > 0
|
||||
}
|
||||
|
||||
/// requires dynamic shuffle
|
||||
pub fn requires_shuffle(&self) -> bool {
|
||||
self.num_shuffles > 0
|
||||
self.shuffle_params.num_shuffles > 0
|
||||
}
|
||||
|
||||
/// any kzg visibility
|
||||
@@ -956,13 +975,13 @@ impl GraphCircuit {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
self.load_file_data(data.input_data.values(), &shapes, scales, input_types)
|
||||
self.load_file_data(&data.input_data, &shapes, scales, input_types)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn load_file_data(
|
||||
&mut self,
|
||||
file_data: &FileSource,
|
||||
file_data: &DataSource,
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
|
||||
@@ -12,6 +12,8 @@ use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::graph::DynamicLookupParams;
|
||||
use crate::graph::ShuffleParams;
|
||||
use crate::tensor::ValType;
|
||||
use crate::{
|
||||
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
|
||||
@@ -100,12 +102,10 @@ pub type NodeGraph = BTreeMap<usize, NodeType>;
|
||||
pub struct DummyPassRes {
|
||||
/// number of rows use
|
||||
pub num_rows: usize,
|
||||
/// num dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
/// max dynamic lookup input len
|
||||
pub max_dynamic_input_len: usize,
|
||||
/// dynamic lookup col size
|
||||
pub dynamic_lookup_col_coord: usize,
|
||||
/// dynamic lookup parameters
|
||||
pub dynamic_lookup_params: DynamicLookupParams,
|
||||
/// shuffle parameters
|
||||
pub shuffle_params: ShuffleParams,
|
||||
/// num shuffles
|
||||
pub num_shuffles: usize,
|
||||
/// shuffle
|
||||
@@ -585,16 +585,13 @@ impl Model {
|
||||
num_rows: res.num_rows,
|
||||
total_assignments: res.linear_coord,
|
||||
required_lookups: res.lookup_ops.into_iter().collect(),
|
||||
max_dynamic_input_len: res.max_dynamic_input_len,
|
||||
required_range_checks: res.range_checks.into_iter().collect(),
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
input_types: self.get_input_types().ok(),
|
||||
output_types: Some(self.get_output_types()),
|
||||
num_dynamic_lookups: res.num_dynamic_lookups,
|
||||
total_dynamic_col_size: res.dynamic_lookup_col_coord,
|
||||
num_shuffles: res.num_shuffles,
|
||||
total_shuffle_col_size: res.shuffle_col_coord,
|
||||
dynamic_lookup_params: res.dynamic_lookup_params,
|
||||
shuffle_params: res.shuffle_params,
|
||||
total_const_size: res.total_const_size,
|
||||
check_mode,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
@@ -1523,15 +1520,21 @@ impl Model {
|
||||
let res = DummyPassRes {
|
||||
num_rows: region.row(),
|
||||
linear_coord: region.linear_coord(),
|
||||
max_dynamic_input_len: region.max_dynamic_input_len(),
|
||||
dynamic_lookup_params: DynamicLookupParams {
|
||||
total_dynamic_col_size: region.dynamic_lookup_col_coord(),
|
||||
max_dynamic_input_len: region.max_dynamic_input_len(),
|
||||
num_dynamic_lookups: region.dynamic_lookup_index(),
|
||||
},
|
||||
shuffle_params: ShuffleParams {
|
||||
num_shuffles: region.shuffle_index(),
|
||||
total_shuffle_col_size: region.shuffle_col_coord(),
|
||||
},
|
||||
total_const_size: region.total_constants(),
|
||||
lookup_ops: region.used_lookups(),
|
||||
range_checks: region.used_range_checks(),
|
||||
max_lookup_inputs: region.max_lookup_inputs(),
|
||||
min_lookup_inputs: region.min_lookup_inputs(),
|
||||
max_range_size: region.max_range_size(),
|
||||
num_dynamic_lookups: region.dynamic_lookup_index(),
|
||||
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
|
||||
num_shuffles: region.shuffle_index(),
|
||||
shuffle_col_coord: region.shuffle_col_coord(),
|
||||
outputs,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
mod native_tests {
|
||||
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, GraphData};
|
||||
use ezkl::graph::input::GraphData;
|
||||
use ezkl::graph::GraphSettings;
|
||||
use ezkl::pfsys::Snark;
|
||||
use ezkl::Commitments;
|
||||
@@ -163,12 +163,7 @@ mod native_tests {
|
||||
let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into())
|
||||
.expect("failed to load input data");
|
||||
|
||||
let duplicated_input_data: FileSource = data
|
||||
.input_data
|
||||
.values()
|
||||
.iter()
|
||||
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
|
||||
.collect();
|
||||
let duplicated_input_data = data.input_data;
|
||||
|
||||
let duplicated_data = GraphData::new(duplicated_input_data.into());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user