mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
1 Commits
ac/release
...
ac/cleanup
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba0a151544 |
@@ -21,10 +21,6 @@ use tract_onnx::tract_core::{
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
|
||||
/// Represents different types of values that can be stored in a file source
|
||||
/// Used for handling various input types in zero-knowledge proofs
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
@@ -37,6 +33,22 @@ pub enum FileSourceInner {
|
||||
Field(Fp),
|
||||
}
|
||||
|
||||
impl From<Fp> for FileSourceInner {
|
||||
fn from(value: Fp) -> Self {
|
||||
FileSourceInner::Field(value)
|
||||
}
|
||||
}
|
||||
impl From<bool> for FileSourceInner {
|
||||
fn from(value: bool) -> Self {
|
||||
FileSourceInner::Bool(value)
|
||||
}
|
||||
}
|
||||
impl From<f64> for FileSourceInner {
|
||||
fn from(value: f64) -> Self {
|
||||
FileSourceInner::Float(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FileSourceInner {
|
||||
/// Returns true if the value is a floating point number
|
||||
pub fn is_float(&self) -> bool {
|
||||
@@ -159,115 +171,11 @@ impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
|
||||
/// A collection of input values from a file source
|
||||
/// Organized as a vector of vectors where each inner vector represents a row/entry
|
||||
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
||||
pub type DataSource = Vec<Vec<FileSourceInner>>;
|
||||
|
||||
/// Represents which parts of the model (input/output) are attested to on-chain
|
||||
pub type InputOutput = (bool, bool);
|
||||
|
||||
/// Configuration for accessing on-chain data sources
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Call specifications for fetching data
|
||||
pub call: CallToAccount,
|
||||
/// RPC endpoint URL for accessing the chain
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Creates a new OnChainSource
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `call` - Call specification
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource { call, rpc }
|
||||
}
|
||||
}
|
||||
|
||||
/// Specification for view-only calls to fetch on-chain data
|
||||
/// Used for data attestation in smart contract verification
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallsToAccount {
|
||||
/// Vector of (call data, decimals) pairs
|
||||
/// call_data: ABI-encoded function call
|
||||
/// decimals: Number of decimal places for float conversion
|
||||
pub call_data: Vec<(Call, Decimals)>,
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Specification for a single view-only call returning an array
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallToAccount {
|
||||
/// ABI-encoded function call data
|
||||
pub call_data: Call,
|
||||
/// Number of decimal places for float conversion
|
||||
pub decimals: Vec<Decimals>,
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
pub struct DataSource(FileSource);
|
||||
|
||||
impl DataSource {
|
||||
/// Gets the underlying file source data
|
||||
pub fn values(&self) -> &FileSource {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DataSource {
|
||||
fn default() -> Self {
|
||||
DataSource(vec![vec![]])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FileSource> for DataSource {
|
||||
fn from(data: FileSource) -> Self {
|
||||
DataSource(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Vec<Fp>>> for DataSource {
|
||||
fn from(data: Vec<Vec<Fp>>) -> Self {
|
||||
DataSource(
|
||||
data.iter()
|
||||
.map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Vec<f64>>> for DataSource {
|
||||
fn from(data: Vec<Vec<f64>>) -> Self {
|
||||
DataSource(
|
||||
data.iter()
|
||||
.map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Always use JSON serialization for untagged enums
|
||||
impl<'de> Deserialize<'de> for DataSource {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
// Try deserializing as FileSource first
|
||||
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(DataSource(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for input and output data for graph computations
|
||||
///
|
||||
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
|
||||
@@ -297,7 +205,7 @@ impl GraphData {
|
||||
datum_types: &[tract_onnx::prelude::DatumType],
|
||||
) -> Result<TVec<TValue>, GraphError> {
|
||||
let mut inputs = TVec::new();
|
||||
for (i, input) in self.input_data.values().iter().enumerate() {
|
||||
for (i, input) in self.input_data.iter().enumerate() {
|
||||
if !input.is_empty() {
|
||||
let dt = datum_types[i];
|
||||
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
|
||||
@@ -338,7 +246,7 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
Ok(GraphData {
|
||||
input_data: DataSource(input_data),
|
||||
input_data,
|
||||
output_data: None,
|
||||
})
|
||||
}
|
||||
@@ -424,12 +332,7 @@ impl GraphData {
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
let mut batched_inputs = vec![];
|
||||
|
||||
let iterable = match self {
|
||||
GraphData {
|
||||
input_data: DataSource(data),
|
||||
output_data: _,
|
||||
} => data.clone(),
|
||||
};
|
||||
let iterable = self.input_data.clone();
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
for (i, shape) in input_shapes.iter().enumerate() {
|
||||
@@ -474,12 +377,12 @@ impl GraphData {
|
||||
for input in batched_inputs.iter() {
|
||||
batch.push(input[i].clone());
|
||||
}
|
||||
input_batches.push(DataSource(batch));
|
||||
input_batches.push(batch);
|
||||
}
|
||||
|
||||
// Ensure at least one batch exists
|
||||
if input_batches.is_empty() {
|
||||
input_batches.push(DataSource(vec![vec![]]));
|
||||
input_batches.push(vec![vec![]]);
|
||||
}
|
||||
|
||||
// Create GraphData instance for each batch
|
||||
@@ -498,28 +401,14 @@ impl GraphData {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
// Test serialization/deserialization of graph input
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
let file = GraphData::new(vec![vec![
|
||||
0.05326242372393608.into(),
|
||||
0.07497056573629379.into(),
|
||||
0.05235547572374344.into(),
|
||||
]]);
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
|
||||
@@ -26,11 +26,11 @@ use itertools::Itertools;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
use self::input::{FileSource, GraphData};
|
||||
use self::input::GraphData;
|
||||
|
||||
use self::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use self::input::{FileSource, GraphData};
|
||||
use self::input::GraphData;
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
@@ -956,13 +956,13 @@ impl GraphCircuit {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
self.load_file_data(data.input_data.values(), &shapes, scales, input_types)
|
||||
self.load_file_data(&data.input_data, &shapes, scales, input_types)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn load_file_data(
|
||||
&mut self,
|
||||
file_data: &FileSource,
|
||||
file_data: &DataSource,
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
input_types: Vec<InputType>,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
mod native_tests {
|
||||
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, GraphData};
|
||||
use ezkl::graph::input::GraphData;
|
||||
use ezkl::graph::GraphSettings;
|
||||
use ezkl::pfsys::Snark;
|
||||
use ezkl::Commitments;
|
||||
@@ -163,12 +163,7 @@ mod native_tests {
|
||||
let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into())
|
||||
.expect("failed to load input data");
|
||||
|
||||
let duplicated_input_data: FileSource = data
|
||||
.input_data
|
||||
.values()
|
||||
.iter()
|
||||
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
|
||||
.collect();
|
||||
let duplicated_input_data = data.input_data;
|
||||
|
||||
let duplicated_data = GraphData::new(duplicated_input_data.into());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user