Compare commits

...

1 Commit

Author SHA1 Message Date
dante
ba0a151544 chore: simplify DataSource 2025-07-03 10:22:32 -04:00
3 changed files with 33 additions and 149 deletions

View File

@@ -21,10 +21,6 @@ use tract_onnx::tract_core::{
value::TValue,
};
type Decimals = u8;
type Call = String;
type RPCUrl = String;
/// Represents different types of values that can be stored in a file source
/// Used for handling various input types in zero-knowledge proofs
#[derive(Clone, Debug, PartialOrd, PartialEq)]
@@ -37,6 +33,22 @@ pub enum FileSourceInner {
Field(Fp),
}
impl From<Fp> for FileSourceInner {
fn from(value: Fp) -> Self {
FileSourceInner::Field(value)
}
}
impl From<bool> for FileSourceInner {
fn from(value: bool) -> Self {
FileSourceInner::Bool(value)
}
}
impl From<f64> for FileSourceInner {
fn from(value: f64) -> Self {
FileSourceInner::Float(value)
}
}
impl FileSourceInner {
/// Returns true if the value is a floating point number
pub fn is_float(&self) -> bool {
@@ -159,115 +171,11 @@ impl<'de> Deserialize<'de> for FileSourceInner {
/// A collection of input values from a file source
/// Organized as a vector of vectors where each inner vector represents a row/entry
pub type FileSource = Vec<Vec<FileSourceInner>>;
/// Nested vectors of `FileSourceInner` values
/// Has the same underlying type as `FileSource`
pub type DataSource = Vec<Vec<FileSourceInner>>;
/// Represents which parts of the model (input/output) are attested to on-chain
/// A pair of boolean flags; presumably ordered as (input, output) -- confirm against callers
pub type InputOutput = (bool, bool);
/// Configuration for accessing on-chain data sources
/// Bundles one call specification together with an RPC endpoint URL
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct OnChainSource {
/// Call specification for fetching data (a single `CallToAccount`)
pub call: CallToAccount,
/// RPC endpoint URL for accessing the chain (`RPCUrl` is an alias for `String`)
pub rpc: RPCUrl,
}
impl OnChainSource {
/// Creates a new OnChainSource
///
/// # Arguments
/// * `call` - Call specification
/// * `rpc` - RPC endpoint URL
pub fn new(call: CallToAccount, rpc: RPCUrl) -> Self {
OnChainSource { call, rpc }
}
}
/// Specification for view-only calls to fetch on-chain data
/// Used for data attestation in smart contract verification
/// Holds a list of calls alongside a single contract address
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallsToAccount {
/// Vector of (call data, decimals) pairs
/// call_data: ABI-encoded function call (`Call` is an alias for `String`)
/// decimals: Number of decimal places for float conversion (`Decimals` is an alias for `u8`)
pub call_data: Vec<(Call, Decimals)>,
/// Contract address to call
pub address: String,
}
/// Specification for a single view-only call returning an array
/// Holds one call, a list of decimals, and a single contract address
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallToAccount {
/// ABI-encoded function call data (`Call` is an alias for `String`)
pub call_data: Call,
/// Number of decimal places for float conversion (`Decimals` is an alias for `u8`)
/// NOTE(review): presumably one entry per element of the returned array -- confirm
pub decimals: Vec<Decimals>,
/// Contract address to call
pub address: String,
}
/// Represents different sources of input/output data for the EZKL model
/// Tuple struct wrapping a single `FileSource`; the inner field is private
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
pub struct DataSource(FileSource);
impl DataSource {
/// Gets the underlying file source data
pub fn values(&self) -> &FileSource {
&self.0
}
}
impl Default for DataSource {
fn default() -> Self {
DataSource(vec![vec![]])
}
}
impl From<FileSource> for DataSource {
fn from(data: FileSource) -> Self {
DataSource(data)
}
}
impl From<Vec<Vec<Fp>>> for DataSource {
fn from(data: Vec<Vec<Fp>>) -> Self {
DataSource(
data.iter()
.map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect())
.collect(),
)
}
}
impl From<Vec<Vec<f64>>> for DataSource {
fn from(data: Vec<Vec<f64>>) -> Self {
DataSource(
data.iter()
.map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect())
.collect(),
)
}
}
// Note: Always use JSON serialization for untagged enums
impl<'de> Deserialize<'de> for DataSource {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
// Try deserializing as FileSource first
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = first_try {
return Ok(DataSource(t));
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
}
}
/// Container for input and output data for graph computations
///
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
@@ -297,7 +205,7 @@ impl GraphData {
datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, GraphError> {
let mut inputs = TVec::new();
for (i, input) in self.input_data.values().iter().enumerate() {
for (i, input) in self.input_data.iter().enumerate() {
if !input.is_empty() {
let dt = datum_types[i];
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
@@ -338,7 +246,7 @@ impl GraphData {
}
}
Ok(GraphData {
input_data: DataSource(input_data),
input_data,
output_data: None,
})
}
@@ -424,12 +332,7 @@ impl GraphData {
) -> Result<Vec<Self>, GraphError> {
let mut batched_inputs = vec![];
let iterable = match self {
GraphData {
input_data: DataSource(data),
output_data: _,
} => data.clone(),
};
let iterable = self.input_data.clone();
// Process each input tensor according to its shape
for (i, shape) in input_shapes.iter().enumerate() {
@@ -474,12 +377,12 @@ impl GraphData {
for input in batched_inputs.iter() {
batch.push(input[i].clone());
}
input_batches.push(DataSource(batch));
input_batches.push(batch);
}
// Ensure at least one batch exists
if input_batches.is_empty() {
input_batches.push(DataSource(vec![vec![]]));
input_batches.push(vec![vec![]]);
}
// Create GraphData instance for each batch
@@ -498,28 +401,14 @@ impl GraphData {
mod tests {
use super::*;
#[test]
fn test_data_source_serialization_round_trip() {
    // Test backwards compatibility with old format
    const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
    let original = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);

    // Serializing must yield the bare nested-array layout.
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, JSON);

    // Parsing that same layout must give back an equal value.
    let decoded = serde_json::from_str::<DataSource>(JSON)
        .map_err(|e| e.to_string())
        .unwrap();
    assert_eq!(decoded, original);
}
#[test]
fn test_graph_input_serialization_round_trip() {
// Test serialization/deserialization of graph input
let file = GraphData::new(DataSource::from(vec![vec![
0.05326242372393608,
0.07497056573629379,
0.05235547572374344,
]]));
let file = GraphData::new(vec![vec![
0.05326242372393608.into(),
0.07497056573629379.into(),
0.05235547572374344.into(),
]]);
let serialized = serde_json::to_string(&file).unwrap();
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;

View File

@@ -26,11 +26,11 @@ use itertools::Itertools;
use tosubcommand::ToFlags;
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
use self::input::{FileSource, GraphData};
use self::input::GraphData;
use self::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use self::input::{FileSource, GraphData};
use self::input::GraphData;
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
@@ -956,13 +956,13 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
self.load_file_data(data.input_data.values(), &shapes, scales, input_types)
self.load_file_data(&data.input_data, &shapes, scales, input_types)
}
///
pub fn load_file_data(
&mut self,
file_data: &FileSource,
file_data: &DataSource,
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,

View File

@@ -3,7 +3,7 @@
mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, GraphData};
use ezkl::graph::input::GraphData;
use ezkl::graph::GraphSettings;
use ezkl::pfsys::Snark;
use ezkl::Commitments;
@@ -163,12 +163,7 @@ mod native_tests {
let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into())
.expect("failed to load input data");
let duplicated_input_data: FileSource = data
.input_data
.values()
.iter()
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
.collect();
let duplicated_input_data = data.input_data;
let duplicated_data = GraphData::new(duplicated_input_data.into());